Diffstat (limited to 'atuin-client/src/record/sync.rs')
-rw-r--r-- | atuin-client/src/record/sync.rs | 421 |
1 file changed, 421 insertions, 0 deletions
diff --git a/atuin-client/src/record/sync.rs b/atuin-client/src/record/sync.rs
new file mode 100644
index 00000000..ebdb8eb2
--- /dev/null
+++ b/atuin-client/src/record/sync.rs
@@ -0,0 +1,421 @@
+// do a sync :O
+use eyre::Result;
+
+use super::store::Store;
+use crate::{api_client::Client, settings::Settings};
+
+use atuin_common::record::{Diff, HostId, RecordId, RecordIndex};
+
+#[derive(Debug, Eq, PartialEq)]
+pub enum Operation {
+    // Either upload or download until the tail matches the below
+    Upload {
+        tail: RecordId,
+        host: HostId,
+        tag: String,
+    },
+    Download {
+        tail: RecordId,
+        host: HostId,
+        tag: String,
+    },
+}
+
+pub async fn diff(settings: &Settings, store: &mut impl Store) -> Result<(Vec<Diff>, RecordIndex)> {
+    let client = Client::new(&settings.sync_address, &settings.session_token)?;
+
+    let local_index = store.tail_records().await?;
+    let remote_index = client.record_index().await?;
+
+    let diff = local_index.diff(&remote_index);
+
+    Ok((diff, remote_index))
+}
+
+// Take a diff, along with a local store, and resolve it into a set of operations.
+// With the store as context, we can determine if a tail exists locally or not, and
+// therefore whether it needs uploading or downloading.
+// In theory this could be done as a part of the diffing stage, but it's easier to
+// reason about and test this way.
+pub async fn operations(diffs: Vec<Diff>, store: &impl Store) -> Result<Vec<Operation>> {
+    let mut operations = Vec::with_capacity(diffs.len());
+
+    for diff in diffs {
+        // First, try to fetch the tail.
+        // If it exists locally, then that means we need to update the remote
+        // host until it has the same tail, i.e. upload.
+        // If it does not exist locally, that means the remote is ahead of us.
+        // Therefore, we need to download until our local tail matches.
+        let record = store.get(diff.tail).await;
+
+        let op = if record.is_ok() {
+            // If local has the ID, then we should find the actual tail of this
+            // store, so we know what we need to update the remote to.
+            let tail = store
+                .tail(diff.host, diff.tag.as_str())
+                .await?
+                .expect("failed to fetch last record, expected tag/host to exist");
+
+            // TODO(ellie): update the diffing so that it stores the context of the current tail.
+            // That way, we can determine how much we need to upload.
+            // For now, just keep uploading until the tails match.
+            Operation::Upload {
+                tail: tail.id,
+                host: diff.host,
+                tag: diff.tag,
+            }
+        } else {
+            Operation::Download {
+                tail: diff.tail,
+                host: diff.host,
+                tag: diff.tag,
+            }
+        };
+
+        operations.push(op);
+    }
+
+    // Sort the operations - purely so we have a stable testing order, and can
+    // rely on same input = same output.
+    // We can sort by ID so long as we continue to use UUIDv7 or something
+    // with the same properties.
+    operations.sort_by_key(|op| match op {
+        Operation::Upload { tail, host, .. } => ("upload", *host, *tail),
+        Operation::Download { tail, host, .. } => ("download", *host, *tail),
+    });
+
+    Ok(operations)
+}
+
+async fn sync_upload(
+    store: &mut impl Store,
+    remote_index: &RecordIndex,
+    client: &Client<'_>,
+    op: (HostId, String, RecordId),
+) -> Result<i64> {
+    let upload_page_size = 100;
+    let mut total = 0;
+
+    // So: we have an upload operation, with the tail representing the state
+    // we want to get the remote to.
+    let current_tail = remote_index.get(op.0, op.1.clone());
+
+    println!(
+        "Syncing local {:?}/{}/{:?}, remote has {:?}",
+        op.0, op.1, op.2, current_tail
+    );
+
+    let start = if let Some(current_tail) = current_tail {
+        current_tail
+    } else {
+        store
+            .head(op.0, op.1.as_str())
+            .await
+            .expect("failed to fetch host/tag head")
+            .expect("host/tag not in current index")
+            .id
+    };
+
+    debug!("starting push to remote from: {:?}", start);
+
+    // We have the start point for sync. It is either the head of the store, if
+    // the remote has no data for it, or the tail that the remote already has.
+    // We need to iterate from the remote tail, and keep going until
+    // remote tail = current local tail.
+    let mut record = Some(store.get(start).await.unwrap());
+
+    let mut buf = Vec::with_capacity(upload_page_size);
+
+    while let Some(r) = record {
+        if buf.len() == upload_page_size {
+            client.post_records(&buf).await?;
+
+            // Reset the length to 0, but keep the allocated capacity
+            buf.clear();
+        }
+        buf.push(r.clone());
+
+        record = store.next(&r).await?;
+
+        total += 1;
+    }
+
+    if !buf.is_empty() {
+        client.post_records(&buf).await?;
+    }
+
+    Ok(total)
+}
+
+async fn sync_download(
+    store: &mut impl Store,
+    remote_index: &RecordIndex,
+    client: &Client<'_>,
+    op: (HostId, String, RecordId),
+) -> Result<i64> {
+    // TODO(ellie): implement variable page sizing like on history sync
+    let download_page_size = 1000;
+
+    let mut total = 0;
+
+    // We know that the remote is ahead of us, so let's keep downloading until both
+    // 1) The remote stops returning full pages
+    // 2) The tail equals what we expect
+    //
+    // If (1) occurs without (2), then something is wrong with our index calculation
+    // and we should bail.
+    let remote_tail = remote_index
+        .get(op.0, op.1.clone())
+        .expect("remote index does not contain expected tail during download");
+    let local_tail = store.tail(op.0, op.1.as_str()).await?;
+
+    // We expect that the operations diff will represent the desired state.
+    // In this case, that contains the remote tail.
+    assert_eq!(remote_tail, op.2);
+
+    println!("Downloading {:?}/{}/{:?} to local", op.0, op.1, op.2);
+
+    let mut records = client
+        .next_records(
+            op.0,
+            op.1.clone(),
+            local_tail.map(|r| r.id),
+            download_page_size,
+        )
+        .await?;
+
+    while !records.is_empty() {
+        total += std::cmp::min(download_page_size, records.len() as u64);
+        store.push_batch(records.iter()).await?;
+
+        if records.last().unwrap().id == remote_tail {
+            break;
+        }
+
+        records = client
+            .next_records(
+                op.0,
+                op.1.clone(),
+                records.last().map(|r| r.id),
+                download_page_size,
+            )
+            .await?;
+    }
+
+    Ok(total as i64)
+}
+
+pub async fn sync_remote(
+    operations: Vec<Operation>,
+    remote_index: &RecordIndex,
+    local_store: &mut impl Store,
+    settings: &Settings,
+) -> Result<(i64, i64)> {
+    let client = Client::new(&settings.sync_address, &settings.session_token)?;
+
+    let mut uploaded = 0;
+    let mut downloaded = 0;
+
+    // This could run in parallel, but let's get it working first
+    for i in operations {
+        match i {
+            Operation::Upload { tail, host, tag } => {
+                uploaded +=
+                    sync_upload(local_store, remote_index, &client, (host, tag, tail)).await?
+            }
+            Operation::Download { tail, host, tag } => {
+                downloaded +=
+                    sync_download(local_store, remote_index, &client, (host, tag, tail)).await?
+            }
+        }
+    }
+
+    Ok((uploaded, downloaded))
+}
+
+#[cfg(test)]
+mod tests {
+    use atuin_common::record::{Diff, EncryptedData, HostId, Record};
+    use pretty_assertions::assert_eq;
+
+    use crate::record::{
+        encryption::PASETO_V4,
+        sqlite_store::SqliteStore,
+        store::Store,
+        sync::{self, Operation},
+    };
+
+    fn test_record() -> Record<EncryptedData> {
+        Record::builder()
+            .host(HostId(atuin_common::utils::uuid_v7()))
+            .version("v1".into())
+            .tag(atuin_common::utils::uuid_v7().simple().to_string())
+            .data(EncryptedData {
+                data: String::new(),
+                content_encryption_key: String::new(),
+            })
+            .build()
+    }
+
+    // Take a list of local records and a list of remote records.
+    // Return the local database, and a diff of local/remote, ready to build
+    // operations.
+    async fn build_test_diff(
+        local_records: Vec<Record<EncryptedData>>,
+        remote_records: Vec<Record<EncryptedData>>,
+    ) -> (SqliteStore, Vec<Diff>) {
+        let local_store = SqliteStore::new(":memory:")
+            .await
+            .expect("failed to open in memory sqlite");
+        let remote_store = SqliteStore::new(":memory:")
+            .await
+            .expect("failed to open in memory sqlite"); // "remote"
+
+        for i in local_records {
+            local_store.push(&i).await.unwrap();
+        }
+
+        for i in remote_records {
+            remote_store.push(&i).await.unwrap();
+        }
+
+        let local_index = local_store.tail_records().await.unwrap();
+        let remote_index = remote_store.tail_records().await.unwrap();
+
+        let diff = local_index.diff(&remote_index);
+
+        (local_store, diff)
+    }
+
+    #[tokio::test]
+    async fn test_basic_diff() {
+        // A diff where local is ahead of remote. Nothing else.
+        let record = test_record();
+        let (store, diff) = build_test_diff(vec![record.clone()], vec![]).await;
+
+        assert_eq!(diff.len(), 1);
+
+        let operations = sync::operations(diff, &store).await.unwrap();
+
+        assert_eq!(operations.len(), 1);
+
+        assert_eq!(
+            operations[0],
+            Operation::Upload {
+                host: record.host,
+                tag: record.tag,
+                tail: record.id
+            }
+        );
+    }
+
+    #[tokio::test]
+    async fn build_two_way_diff() {
+        // A diff where local is ahead of remote for one store, and remote is
+        // ahead for another. One upload, one download.
+        let shared_record = test_record();
+
+        let remote_ahead = test_record();
+        let local_ahead = shared_record
+            .new_child(vec![1, 2, 3])
+            .encrypt::<PASETO_V4>(&[0; 32]);
+
+        // local knows about the already-synced record, and something newer in the same store
+        let local = vec![shared_record.clone(), local_ahead.clone()];
+        // remote knows about the already-synced record, and one new record in a new store
+        let remote = vec![shared_record.clone(), remote_ahead.clone()];
+
+        let (store, diff) = build_test_diff(local, remote).await;
+        let operations = sync::operations(diff, &store).await.unwrap();
+
+        assert_eq!(operations.len(), 2);
+
+        assert_eq!(
+            operations,
+            vec![
+                Operation::Download {
+                    tail: remote_ahead.id,
+                    host: remote_ahead.host,
+                    tag: remote_ahead.tag,
+                },
+                Operation::Upload {
+                    tail: local_ahead.id,
+                    host: local_ahead.host,
+                    tag: local_ahead.tag,
+                },
+            ]
+        );
+    }
+
+    #[tokio::test]
+    async fn build_complex_diff() {
+        // One shared store, ahead but known only by remote
+        // One store known only by local
+        // One store known only by remote
+        let shared_record = test_record();
+
+        let remote_known = test_record();
+        let local_known = test_record();
+
+        let second_shared = test_record();
+        let second_shared_remote_ahead = second_shared
+            .new_child(vec![1, 2, 3])
+            .encrypt::<PASETO_V4>(&[0; 32]);
+
+        let local_ahead = shared_record
+            .new_child(vec![1, 2, 3])
+            .encrypt::<PASETO_V4>(&[0; 32]);
+
+        let local = vec![
+            shared_record.clone(),
+            second_shared.clone(),
+            local_known.clone(),
+            local_ahead.clone(),
+        ];
+
+        // remote knows about the already-synced records, and one new record in a new store
+        let remote = vec![
+            shared_record.clone(),
+            second_shared.clone(),
+            second_shared_remote_ahead.clone(),
+            remote_known.clone(),
+        ];
+
+        let (store, diff) = build_test_diff(local, remote).await;
+        let operations = sync::operations(diff, &store).await.unwrap();
+
+        assert_eq!(operations.len(), 4);
+
+        let mut result_ops = vec![
+            Operation::Download {
+                tail: remote_known.id,
+                host: remote_known.host,
+                tag: remote_known.tag,
+            },
+            Operation::Download {
+                tail: second_shared_remote_ahead.id,
+                host: second_shared.host,
+                tag: second_shared.tag,
+            },
+            Operation::Upload {
+                tail: local_ahead.id,
+                host: local_ahead.host,
+                tag: local_ahead.tag,
+            },
+            Operation::Upload {
+                tail: local_known.id,
+                host: local_known.host,
+                tag: local_known.tag,
+            },
+        ];
+
+        result_ops.sort_by_key(|op| match op {
+            Operation::Upload { tail, host, .. } => ("upload", *host, *tail),
+            Operation::Download { tail, host, .. } => ("download", *host, *tail),
+        });
+
+        assert_eq!(operations, result_ops);
+    }
+}
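
Taken together, the new module gives a three-step flow: diff() compares the local tail index against the remote's, operations() resolves the resulting diff into upload/download operations using the local store as context, and sync_remote() executes them. The sketch below is illustrative only and is not part of this commit: it assumes the record module is exported from atuin_client, uses an in-memory SqliteStore as the tests do, and takes an already-loaded Settings; the function name run_record_sync is made up for the example.

use atuin_client::record::{sqlite_store::SqliteStore, sync};
use atuin_client::settings::Settings;
use eyre::Result;

// Hypothetical driver; store location and wiring are assumptions for illustration.
async fn run_record_sync(settings: &Settings) -> Result<()> {
    let mut store = SqliteStore::new(":memory:").await?;

    // 1) Compare local and remote tail indices.
    let (diff, remote_index) = sync::diff(settings, &mut store).await?;

    // 2) Resolve the diff into upload/download operations against the local store.
    let operations = sync::operations(diff, &store).await?;

    // 3) Execute the operations, returning (uploaded, downloaded) record counts.
    let (uploaded, downloaded) =
        sync::sync_remote(operations, &remote_index, &mut store, settings).await?;
    println!("uploaded {uploaded} records, downloaded {downloaded}");

    Ok(())
}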