summary refs log tree commit diff stats
path: root/atuin-client
diff options
context:
space:
mode:
author Ellie Huxtable <ellie@elliehuxtable.com> 2023-07-14 20:44:08 +0100
committer GitHub <noreply@github.com> 2023-07-14 20:44:08 +0100
commit 97e24d0d41bb743833e457de5ba49c5c233eb3b3 (patch)
tree f0cfefd9048df83d3029cb0b0d21f1f88813fe2e /atuin-client
parent 3d4302ded148c13b302fb317240342a303308c7e (diff)
Add new sync (#1093)
* Add record migration
* Add database functions for inserting history
  No real tests yet :( I would like to avoid running postgres lol
* Add index handler, use UUIDs not strings
* Fix a bunch of tests, remove Option<Uuid>
* Add tests, all passing
* Working upload sync
* Record downloading works
* Sync download works
* Don't waste requests
* Use a page size for uploads, make it variable later
* Aaaaaand they're encrypted now too
* Add cek
* Allow reading tail across hosts
* Revert "Allow reading tail across hosts"
  Not like that
  This reverts commit 7b0c72e7e050c358172f9b53cbd21b9e44cf4931.
* Handle multiple shards properly
* format
* Format and make clippy happy
* use some fancy types (#1098)
* use some fancy types
* fmt
* Goodbye horrible tuple
* Update atuin-server-postgres/migrations/20230623070418_records.sql
  Co-authored-by: Conrad Ludgate <conradludgate@gmail.com>
* fmt
* Sort tests too because time sucks
* fix features
---------
Co-authored-by: Conrad Ludgate <conradludgate@gmail.com>
Diffstat (limited to 'atuin-client')
-rw-r--r-- atuin-client/Cargo.toml | 7
-rw-r--r-- atuin-client/record-migrations/20230531212437_create-records.sql | 3
-rw-r--r-- atuin-client/record-migrations/20230619235421_add_content_encrytion_key.sql | 3
-rw-r--r-- atuin-client/src/api_client.rs | 59
-rw-r--r-- atuin-client/src/database.rs | 2
-rw-r--r-- atuin-client/src/kv.rs | 22
-rw-r--r-- atuin-client/src/record/encryption.rs | 65
-rw-r--r-- atuin-client/src/record/mod.rs | 2
-rw-r--r-- atuin-client/src/record/sqlite_store.rs | 99
-rw-r--r-- atuin-client/src/record/store.rs | 18
-rw-r--r-- atuin-client/src/record/sync.rs | 421
-rw-r--r-- atuin-client/src/settings.rs | 13
12 files changed, 618 insertions, 96 deletions
diff --git a/atuin-client/Cargo.toml b/atuin-client/Cargo.toml
index fa539662..ca620677 100644
--- a/atuin-client/Cargo.toml
+++ b/atuin-client/Cargo.toml
@@ -54,10 +54,14 @@ rmp = { version = "0.8.11" }
typed-builder = "0.14.0"
tokio = { workspace = true }
semver = { workspace = true }
+futures = "0.3"
# encryption
rusty_paseto = { version = "0.5.0", default-features = false }
-rusty_paserk = { version = "0.2.0", default-features = false, features = ["v4", "serde"] }
+rusty_paserk = { version = "0.2.0", default-features = false, features = [
+ "v4",
+ "serde",
+] }
# sync
urlencoding = { version = "2.1.0", optional = true }
@@ -69,3 +73,4 @@ generic-array = { version = "0.14", optional = true, features = ["serde"] }
[dev-dependencies]
tokio = { version = "1", features = ["full"] }
+pretty_assertions = { workspace = true }
diff --git a/atuin-client/record-migrations/20230531212437_create-records.sql b/atuin-client/record-migrations/20230531212437_create-records.sql
index 46963358..4f4b304a 100644
--- a/atuin-client/record-migrations/20230531212437_create-records.sql
+++ b/atuin-client/record-migrations/20230531212437_create-records.sql
@@ -7,7 +7,8 @@ create table if not exists records (
timestamp integer not null,
tag text not null,
version text not null,
- data blob not null
+ data blob not null,
+ cek blob not null
);
create index host_idx on records (host);
diff --git a/atuin-client/record-migrations/20230619235421_add_content_encrytion_key.sql b/atuin-client/record-migrations/20230619235421_add_content_encrytion_key.sql
deleted file mode 100644
index 86bf6844..00000000
--- a/atuin-client/record-migrations/20230619235421_add_content_encrytion_key.sql
+++ /dev/null
@@ -1,3 +0,0 @@
--- store content encryption keys in the record
-alter table records
- add column cek text;
diff --git a/atuin-client/src/api_client.rs b/atuin-client/src/api_client.rs
index 350c419d..5ae1ed0a 100644
--- a/atuin-client/src/api_client.rs
+++ b/atuin-client/src/api_client.rs
@@ -8,9 +8,13 @@ use reqwest::{
StatusCode, Url,
};
-use atuin_common::api::{
- AddHistoryRequest, CountResponse, DeleteHistoryRequest, ErrorResponse, IndexResponse,
- LoginRequest, LoginResponse, RegisterResponse, StatusResponse, SyncHistoryResponse,
+use atuin_common::record::{EncryptedData, HostId, Record, RecordId};
+use atuin_common::{
+ api::{
+ AddHistoryRequest, CountResponse, DeleteHistoryRequest, ErrorResponse, IndexResponse,
+ LoginRequest, LoginResponse, RegisterResponse, StatusResponse, SyncHistoryResponse,
+ },
+ record::RecordIndex,
};
use semver::Version;
@@ -195,6 +199,55 @@ impl<'a> Client<'a> {
Ok(())
}
+ pub async fn post_records(&self, records: &[Record<EncryptedData>]) -> Result<()> {
+ let url = format!("{}/record", self.sync_addr);
+ let url = Url::parse(url.as_str())?;
+
+ self.client.post(url).json(records).send().await?;
+
+ Ok(())
+ }
+
+ pub async fn next_records(
+ &self,
+ host: HostId,
+ tag: String,
+ start: Option<RecordId>,
+ count: u64,
+ ) -> Result<Vec<Record<EncryptedData>>> {
+ let url = format!(
+ "{}/record/next?host={}&tag={}&count={}",
+ self.sync_addr, host.0, tag, count
+ );
+ let mut url = Url::parse(url.as_str())?;
+
+ if let Some(start) = start {
+ url.set_query(Some(
+ format!(
+ "host={}&tag={}&count={}&start={}",
+ host.0, tag, count, start.0
+ )
+ .as_str(),
+ ));
+ }
+
+ let resp = self.client.get(url).send().await?;
+
+ let records = resp.json::<Vec<Record<EncryptedData>>>().await?;
+
+ Ok(records)
+ }
+
+ pub async fn record_index(&self) -> Result<RecordIndex> {
+ let url = format!("{}/record", self.sync_addr);
+ let url = Url::parse(url.as_str())?;
+
+ let resp = self.client.get(url).send().await?;
+ let index = resp.json().await?;
+
+ Ok(index)
+ }
+
pub async fn delete(&self) -> Result<()> {
let url = format!("{}/account", self.sync_addr);
let url = Url::parse(url.as_str())?;
diff --git a/atuin-client/src/database.rs b/atuin-client/src/database.rs
index b7b44409..218c1d6e 100644
--- a/atuin-client/src/database.rs
+++ b/atuin-client/src/database.rs
@@ -57,7 +57,7 @@ pub fn current_context() -> Context {
session,
hostname,
cwd,
- host_id,
+ host_id: host_id.0.as_simple().to_string(),
}
}
diff --git a/atuin-client/src/kv.rs b/atuin-client/src/kv.rs
index c365a385..30018d63 100644
--- a/atuin-client/src/kv.rs
+++ b/atuin-client/src/kv.rs
@@ -101,10 +101,7 @@ impl KvStore {
let bytes = record.serialize()?;
- let parent = store
- .last(host_id.as_str(), KV_TAG)
- .await?
- .map(|entry| entry.id);
+ let parent = store.tail(host_id, KV_TAG).await?.map(|entry| entry.id);
let record = atuin_common::record::Record::builder()
.host(host_id)
@@ -130,17 +127,22 @@ impl KvStore {
namespace: &str,
key: &str,
) -> Result<Option<KvRecord>> {
- // TODO: don't load this from disk so much
- let host_id = Settings::host_id().expect("failed to get host_id");
-
// Currently, this is O(n). When we have an actual KV store, it can be better
// Just a poc for now!
// iterate records to find the value we want
// start at the end, so we get the most recent version
- let Some(mut record) = store.last(host_id.as_str(), KV_TAG).await? else {
+ let tails = store.tag_tails(KV_TAG).await?;
+
+ if tails.is_empty() {
return Ok(None);
- };
+ }
+
+ // first, decide on a record.
+ // try getting the newest first
+ // we always need a way of deciding the "winner" of a write
+ // TODO(ellie): something better than last-write-wins, what if two write at the same time?
+ let mut record = tails.iter().max_by_key(|r| r.timestamp).unwrap().clone();
loop {
let decrypted = match record.version.as_str() {
@@ -154,7 +156,7 @@ impl KvStore {
}
if let Some(parent) = decrypted.parent {
- record = store.get(parent.as_str()).await?;
+ record = store.get(parent).await?;
} else {
break;
}
diff --git a/atuin-client/src/record/encryption.rs b/atuin-client/src/record/encryption.rs
index f14bf027..6760d97b 100644
--- a/atuin-client/src/record/encryption.rs
+++ b/atuin-client/src/record/encryption.rs
@@ -1,4 +1,6 @@
-use atuin_common::record::{AdditionalData, DecryptedData, EncryptedData, Encryption};
+use atuin_common::record::{
+ AdditionalData, DecryptedData, EncryptedData, Encryption, HostId, RecordId,
+};
use base64::{engine::general_purpose, Engine};
use eyre::{ensure, Context, Result};
use rusty_paserk::{Key, KeyId, Local, PieWrappedKey};
@@ -158,10 +160,11 @@ struct AtuinFooter {
// This cannot be changed, otherwise it breaks the authenticated encryption.
#[derive(Debug, Copy, Clone, Serialize)]
struct Assertions<'a> {
- id: &'a str,
+ id: &'a RecordId,
version: &'a str,
tag: &'a str,
- host: &'a str,
+ host: &'a HostId,
+ parent: Option<&'a RecordId>,
}
impl<'a> From<AdditionalData<'a>> for Assertions<'a> {
@@ -171,6 +174,7 @@ impl<'a> From<AdditionalData<'a>> for Assertions<'a> {
version: ad.version,
tag: ad.tag,
host: ad.host,
+ parent: ad.parent,
}
}
}
@@ -183,7 +187,7 @@ impl Assertions<'_> {
#[cfg(test)]
mod tests {
- use atuin_common::record::Record;
+ use atuin_common::{record::Record, utils::uuid_v7};
use super::*;
@@ -192,10 +196,11 @@ mod tests {
let key = Key::<V4, Local>::new_os_random();
let ad = AdditionalData {
- id: "foo",
+ id: &RecordId(uuid_v7()),
version: "v0",
tag: "kv",
- host: "1234",
+ host: &HostId(uuid_v7()),
+ parent: None,
};
let data = DecryptedData(vec![1, 2, 3, 4]);
@@ -210,10 +215,11 @@ mod tests {
let key = Key::<V4, Local>::new_os_random();
let ad = AdditionalData {
- id: "foo",
+ id: &RecordId(uuid_v7()),
version: "v0",
tag: "kv",
- host: "1234",
+ host: &HostId(uuid_v7()),
+ parent: None,
};
let data = DecryptedData(vec![1, 2, 3, 4]);
@@ -233,10 +239,11 @@ mod tests {
let fake_key = Key::<V4, Local>::new_os_random();
let ad = AdditionalData {
- id: "foo",
+ id: &RecordId(uuid_v7()),
version: "v0",
tag: "kv",
- host: "1234",
+ host: &HostId(uuid_v7()),
+ parent: None,
};
let data = DecryptedData(vec![1, 2, 3, 4]);
@@ -250,10 +257,11 @@ mod tests {
let key = Key::<V4, Local>::new_os_random();
let ad = AdditionalData {
- id: "foo",
+ id: &RecordId(uuid_v7()),
version: "v0",
tag: "kv",
- host: "1234",
+ host: &HostId(uuid_v7()),
+ parent: None,
};
let data = DecryptedData(vec![1, 2, 3, 4]);
@@ -261,10 +269,8 @@ mod tests {
let encrypted = PASETO_V4::encrypt(data, ad, &key.to_bytes());
let ad = AdditionalData {
- id: "foo1",
- version: "v0",
- tag: "kv",
- host: "1234",
+ id: &RecordId(uuid_v7()),
+ ..ad
};
let _ = PASETO_V4::decrypt(encrypted, ad, &key.to_bytes()).unwrap_err();
}
@@ -275,10 +281,11 @@ mod tests {
let key2 = Key::<V4, Local>::new_os_random();
let ad = AdditionalData {
- id: "foo",
+ id: &RecordId(uuid_v7()),
version: "v0",
tag: "kv",
- host: "1234",
+ host: &HostId(uuid_v7()),
+ parent: None,
};
let data = DecryptedData(vec![1, 2, 3, 4]);
@@ -304,10 +311,10 @@ mod tests {
fn full_record_round_trip() {
let key = [0x55; 32];
let record = Record::builder()
- .id("1".to_owned())
+ .id(RecordId(uuid_v7()))
.version("v0".to_owned())
.tag("kv".to_owned())
- .host("host1".to_owned())
+ .host(HostId(uuid_v7()))
.timestamp(1687244806000000)
.data(DecryptedData(vec![1, 2, 3, 4]))
.build();
@@ -316,30 +323,20 @@ mod tests {
assert!(!encrypted.data.data.is_empty());
assert!(!encrypted.data.content_encryption_key.is_empty());
- assert_eq!(encrypted.id, "1");
- assert_eq!(encrypted.host, "host1");
- assert_eq!(encrypted.version, "v0");
- assert_eq!(encrypted.tag, "kv");
- assert_eq!(encrypted.timestamp, 1687244806000000);
let decrypted = encrypted.decrypt::<PASETO_V4>(&key).unwrap();
assert_eq!(decrypted.data.0, [1, 2, 3, 4]);
- assert_eq!(decrypted.id, "1");
- assert_eq!(decrypted.host, "host1");
- assert_eq!(decrypted.version, "v0");
- assert_eq!(decrypted.tag, "kv");
- assert_eq!(decrypted.timestamp, 1687244806000000);
}
#[test]
fn full_record_round_trip_fail() {
let key = [0x55; 32];
let record = Record::builder()
- .id("1".to_owned())
+ .id(RecordId(uuid_v7()))
.version("v0".to_owned())
.tag("kv".to_owned())
- .host("host1".to_owned())
+ .host(HostId(uuid_v7()))
.timestamp(1687244806000000)
.data(DecryptedData(vec![1, 2, 3, 4]))
.build();
@@ -347,13 +344,13 @@ mod tests {
let encrypted = record.encrypt::<PASETO_V4>(&key);
let mut enc1 = encrypted.clone();
- enc1.host = "host2".to_owned();
+ enc1.host = HostId(uuid_v7());
let _ = enc1
.decrypt::<PASETO_V4>(&key)
.expect_err("tampering with the host should result in auth failure");
let mut enc2 = encrypted;
- enc2.id = "2".to_owned();
+ enc2.id = RecordId(uuid_v7());
let _ = enc2
.decrypt::<PASETO_V4>(&key)
.expect_err("tampering with the id should result in auth failure");
diff --git a/atuin-client/src/record/mod.rs b/atuin-client/src/record/mod.rs
index 9ac2c541..8bc816ae 100644
--- a/atuin-client/src/record/mod.rs
+++ b/atuin-client/src/record/mod.rs
@@ -1,3 +1,5 @@
pub mod encryption;
pub mod sqlite_store;
pub mod store;
+#[cfg(feature = "sync")]
+pub mod sync;
diff --git a/atuin-client/src/record/sqlite_store.rs b/atuin-client/src/record/sqlite_store.rs
index f692c0c2..14a7e277 100644
--- a/atuin-client/src/record/sqlite_store.rs
+++ b/atuin-client/src/record/sqlite_store.rs
@@ -8,12 +8,14 @@ use std::str::FromStr;
use async_trait::async_trait;
use eyre::{eyre, Result};
use fs_err as fs;
+use futures::TryStreamExt;
use sqlx::{
sqlite::{SqliteConnectOptions, SqliteJournalMode, SqlitePool, SqlitePoolOptions, SqliteRow},
Row,
};
-use atuin_common::record::{EncryptedData, Record};
+use atuin_common::record::{EncryptedData, HostId, Record, RecordId, RecordIndex};
+use uuid::Uuid;
use super::store::Store;
@@ -62,11 +64,11 @@ impl SqliteStore {
"insert or ignore into records(id, host, tag, timestamp, parent, version, data, cek)
values(?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8)",
)
- .bind(r.id.as_str())
- .bind(r.host.as_str())
+ .bind(r.id.0.as_simple().to_string())
+ .bind(r.host.0.as_simple().to_string())
.bind(r.tag.as_str())
.bind(r.timestamp as i64)
- .bind(r.parent.as_ref())
+ .bind(r.parent.map(|p| p.0.as_simple().to_string()))
.bind(r.version.as_str())
.bind(r.data.data.as_str())
.bind(r.data.content_encryption_key.as_str())
@@ -79,10 +81,18 @@ impl SqliteStore {
fn query_row(row: SqliteRow) -> Record<EncryptedData> {
let timestamp: i64 = row.get("timestamp");
+ // a malformed UUID here means the local sqlite store is corrupt, so there is nothing sensible to recover to: just panic
+ let id = Uuid::from_str(row.get("id")).expect("invalid id UUID format in sqlite DB");
+ let host = Uuid::from_str(row.get("host")).expect("invalid host UUID format in sqlite DB");
+ let parent: Option<&str> = row.get("parent");
+
+ let parent = parent
+ .map(|parent| Uuid::from_str(parent).expect("invalid parent UUID format in sqlite DB"));
+
Record {
- id: row.get("id"),
- host: row.get("host"),
- parent: row.get("parent"),
+ id: RecordId(id),
+ host: HostId(host),
+ parent: parent.map(RecordId),
timestamp: timestamp as u64,
tag: row.get("tag"),
version: row.get("version"),
@@ -111,9 +121,9 @@ impl Store for SqliteStore {
Ok(())
}
- async fn get(&self, id: &str) -> Result<Record<EncryptedData>> {
+ async fn get(&self, id: RecordId) -> Result<Record<EncryptedData>> {
let res = sqlx::query("select * from records where id = ?1")
- .bind(id)
+ .bind(id.0.as_simple().to_string())
.map(Self::query_row)
.fetch_one(&self.pool)
.await?;
@@ -121,10 +131,10 @@ impl Store for SqliteStore {
Ok(res)
}
- async fn len(&self, host: &str, tag: &str) -> Result<u64> {
+ async fn len(&self, host: HostId, tag: &str) -> Result<u64> {
let res: (i64,) =
sqlx::query_as("select count(1) from records where host = ?1 and tag = ?2")
- .bind(host)
+ .bind(host.0.as_simple().to_string())
.bind(tag)
.fetch_one(&self.pool)
.await?;
@@ -134,7 +144,7 @@ impl Store for SqliteStore {
async fn next(&self, record: &Record<EncryptedData>) -> Result<Option<Record<EncryptedData>>> {
let res = sqlx::query("select * from records where parent = ?1")
- .bind(record.id.clone())
+ .bind(record.id.0.as_simple().to_string())
.map(Self::query_row)
.fetch_one(&self.pool)
.await;
@@ -146,11 +156,11 @@ impl Store for SqliteStore {
}
}
- async fn first(&self, host: &str, tag: &str) -> Result<Option<Record<EncryptedData>>> {
+ async fn head(&self, host: HostId, tag: &str) -> Result<Option<Record<EncryptedData>>> {
let res = sqlx::query(
"select * from records where host = ?1 and tag = ?2 and parent is null limit 1",
)
- .bind(host)
+ .bind(host.0.as_simple().to_string())
.bind(tag)
.map(Self::query_row)
.fetch_optional(&self.pool)
@@ -159,23 +169,53 @@ impl Store for SqliteStore {
Ok(res)
}
- async fn last(&self, host: &str, tag: &str) -> Result<Option<Record<EncryptedData>>> {
+ async fn tail(&self, host: HostId, tag: &str) -> Result<Option<Record<EncryptedData>>> {
let res = sqlx::query(
"select * from records rp where tag=?1 and host=?2 and (select count(1) from records where parent=rp.id) = 0;",
)
.bind(tag)
- .bind(host)
+ .bind(host.0.as_simple().to_string())
.map(Self::query_row)
.fetch_optional(&self.pool)
.await?;
Ok(res)
}
+
+ async fn tag_tails(&self, tag: &str) -> Result<Vec<Record<EncryptedData>>> {
+ let res = sqlx::query(
+ "select * from records rp where tag=?1 and (select count(1) from records where parent=rp.id) = 0;",
+ )
+ .bind(tag)
+ .map(Self::query_row)
+ .fetch_all(&self.pool)
+ .await?;
+
+ Ok(res)
+ }
+
+ async fn tail_records(&self) -> Result<RecordIndex> {
+ let res = sqlx::query(
+ "select host, tag, id from records rp where (select count(1) from records where parent=rp.id) = 0;",
+ )
+ .map(|row: SqliteRow| {
+ let host: Uuid= Uuid::from_str(row.get("host")).expect("invalid uuid in db host");
+ let tag: String= row.get("tag");
+ let id: Uuid= Uuid::from_str(row.get("id")).expect("invalid uuid in db id");
+
+ (HostId(host), tag, RecordId(id))
+ })
+ .fetch(&self.pool)
+ .try_collect()
+ .await?;
+
+ Ok(res)
+ }
}
#[cfg(test)]
mod tests {
- use atuin_common::record::{EncryptedData, Record};
+ use atuin_common::record::{EncryptedData, HostId, Record};
use crate::record::{encryption::PASETO_V4, store::Store};
@@ -183,7 +223,7 @@ mod tests {
fn test_record() -> Record<EncryptedData> {
Record::builder()
- .host(atuin_common::utils::uuid_v7().simple().to_string())
+ .host(HostId(atuin_common::utils::uuid_v7()))
.version("v1".into())
.tag(atuin_common::utils::uuid_v7().simple().to_string())
.data(EncryptedData {
@@ -218,10 +258,7 @@ mod tests {
let record = test_record();
db.push(&record).await.unwrap();
- let new_record = db
- .get(record.id.as_str())
- .await
- .expect("failed to fetch record");
+ let new_record = db.get(record.id).await.expect("failed to fetch record");
assert_eq!(record, new_record, "records are not equal");
}
@@ -233,7 +270,7 @@ mod tests {
db.push(&record).await.unwrap();
let len = db
- .len(record.host.as_str(), record.tag.as_str())
+ .len(record.host, record.tag.as_str())
.await
.expect("failed to get store len");
@@ -253,14 +290,8 @@ mod tests {
db.push(&first).await.unwrap();
db.push(&second).await.unwrap();
- let first_len = db
- .len(first.host.as_str(), first.tag.as_str())
- .await
- .unwrap();
- let second_len = db
- .len(second.host.as_str(), second.tag.as_str())
- .await
- .unwrap();
+ let first_len = db.len(first.host, first.tag.as_str()).await.unwrap();
+ let second_len = db.len(second.host, second.tag.as_str()).await.unwrap();
assert_eq!(first_len, 1, "expected length of 1 after insert");
assert_eq!(second_len, 1, "expected length of 1 after insert");
@@ -281,7 +312,7 @@ mod tests {
}
assert_eq!(
- db.len(tail.host.as_str(), tail.tag.as_str()).await.unwrap(),
+ db.len(tail.host, tail.tag.as_str()).await.unwrap(),
100,
"failed to insert 100 records"
);
@@ -304,7 +335,7 @@ mod tests {
db.push_batch(records.iter()).await.unwrap();
assert_eq!(
- db.len(tail.host.as_str(), tail.tag.as_str()).await.unwrap(),
+ db.len(tail.host, tail.tag.as_str()).await.unwrap(),
10000,
"failed to insert 10k records"
);
@@ -327,7 +358,7 @@ mod tests {
db.push_batch(records.iter()).await.unwrap();
let mut record = db
- .first(tail.host.as_str(), tail.tag.as_str())
+ .head(tail.host, tail.tag.as_str())
.await
.expect("in memory sqlite should not fail")
.expect("entry exists");
diff --git a/atuin-client/src/record/store.rs b/atuin-client/src/record/store.rs
index 9ea7007a..45d554ef 100644
--- a/atuin-client/src/record/store.rs
+++ b/atuin-client/src/record/store.rs
@@ -1,7 +1,7 @@
use async_trait::async_trait;
use eyre::Result;
-use atuin_common::record::{EncryptedData, Record};
+use atuin_common::record::{EncryptedData, HostId, Record, RecordId, RecordIndex};
/// A record store stores records
/// In more detail - we tend to need to process this into _another_ format to actually query it.
@@ -20,14 +20,22 @@ pub trait Store {
records: impl Iterator<Item = &Record<EncryptedData>> + Send + Sync,
) -> Result<()>;
- async fn get(&self, id: &str) -> Result<Record<EncryptedData>>;
- async fn len(&self, host: &str, tag: &str) -> Result<u64>;
+ async fn get(&self, id: RecordId) -> Result<Record<EncryptedData>>;
+ async fn len(&self, host: HostId, tag: &str) -> Result<u64>;
/// Get the record that follows this record
async fn next(&self, record: &Record<EncryptedData>) -> Result<Option<Record<EncryptedData>>>;
/// Get the first record for a given host and tag
- async fn first(&self, host: &str, tag: &str) -> Result<Option<Record<EncryptedData>>>;
+ async fn head(&self, host: HostId, tag: &str) -> Result<Option<Record<EncryptedData>>>;
+
/// Get the last record for a given host and tag
- async fn last(&self, host: &str, tag: &str) -> Result<Option<Record<EncryptedData>>>;
+ async fn tail(&self, host: HostId, tag: &str) -> Result<Option<Record<EncryptedData>>>;
+
+ /// Get the last record of every host for a given tag. Useful for the read path of apps.
+ async fn tag_tails(&self, tag: &str) -> Result<Vec<Record<EncryptedData>>>;
+
+ /// Get the latest host/tag/record tuple for every host/tag pair in the store. Useful for
+ /// building an index.
+ async fn tail_records(&self) -> Result<RecordIndex>;
}
diff --git a/atuin-client/src/record/sync.rs b/atuin-client/src/record/sync.rs
new file mode 100644
index 00000000..ebdb8eb2
--- /dev/null
+++ b/atuin-client/src/record/sync.rs
@@ -0,0 +1,421 @@
+// do a sync :O
+use eyre::Result;
+
+use super::store::Store;
+use crate::{api_client::Client, settings::Settings};
+
+use atuin_common::record::{Diff, HostId, RecordId, RecordIndex};
+
+#[derive(Debug, Eq, PartialEq)]
+pub enum Operation {
+ // Either upload or download until the tail matches the below
+ Upload {
+ tail: RecordId,
+ host: HostId,
+ tag: String,
+ },
+ Download {
+ tail: RecordId,
+ host: HostId,
+ tag: String,
+ },
+}
+
+pub async fn diff(settings: &Settings, store: &mut impl Store) -> Result<(Vec<Diff>, RecordIndex)> {
+ let client = Client::new(&settings.sync_address, &settings.session_token)?;
+
+ let local_index = store.tail_records().await?;
+ let remote_index = client.record_index().await?;
+
+ let diff = local_index.diff(&remote_index);
+
+ Ok((diff, remote_index))
+}
+
+// Take a diff, along with a local store, and resolve it into a set of operations.
+// With the store as context, we can determine whether a tail exists locally and therefore whether it needs uploading or downloading.
+// In theory this could be done as a part of the diffing stage, but it's easier to reason
+// about and test this way
+pub async fn operations(diffs: Vec<Diff>, store: &impl Store) -> Result<Vec<Operation>> {
+ let mut operations = Vec::with_capacity(diffs.len());
+
+ for diff in diffs {
+ // First, try to fetch the tail
+ // If it exists locally, then that means we need to update the remote
+ // host until it has the same tail. Ie, upload.
+ // If it does not exist locally, that means remote is ahead of us.
+ // Therefore, we need to download until our local tail matches
+ let record = store.get(diff.tail).await;
+
+ let op = if record.is_ok() {
+ // if local has the ID, then we should find the actual tail of this
+ // store, so we know what we need to update the remote to.
+ let tail = store
+ .tail(diff.host, diff.tag.as_str())
+ .await?
+ .expect("failed to fetch last record, expected tag/host to exist");
+
+ // TODO(ellie) update the diffing so that it stores the context of the current tail
+ // that way, we can determine how much we need to upload.
+ // For now just keep uploading until tails match
+
+ Operation::Upload {
+ tail: tail.id,
+ host: diff.host,
+ tag: diff.tag,
+ }
+ } else {
+ Operation::Download {
+ tail: diff.tail,
+ host: diff.host,
+ tag: diff.tag,
+ }
+ };
+
+ operations.push(op);
+ }
+
+ // sort them - purely so we have a stable testing order, and can rely on
+ // same input = same output
+ // We can sort by ID so long as we continue to use UUIDv7 or something
+ // with the same properties
+
+ operations.sort_by_key(|op| match op {
+ Operation::Upload { tail, host, .. } => ("upload", *host, *tail),
+ Operation::Download { tail, host, .. } => ("download", *host, *tail),
+ });
+
+ Ok(operations)
+}
+
+async fn sync_upload(
+ store: &mut impl Store,
+ remote_index: &RecordIndex,
+ client: &Client<'_>,
+ op: (HostId, String, RecordId),
+) -> Result<i64> {
+ let upload_page_size = 100;
+ let mut total = 0;
+
+ // so. we have an upload operation, with the tail representing the state
+ // we want to get the remote to
+ let current_tail = remote_index.get(op.0, op.1.clone());
+
+ println!(
+ "Syncing local {:?}/{}/{:?}, remote has {:?}",
+ op.0, op.1, op.2, current_tail
+ );
+
+ let start = if let Some(current_tail) = current_tail {
+ current_tail
+ } else {
+ store
+ .head(op.0, op.1.as_str())
+ .await
+ .expect("failed to fetch host/tag head")
+ .expect("host/tag not in current index")
+ .id
+ };
+
+ debug!("starting push to remote from: {:?}", start);
+
+ // we have the start point for sync. it is either the head of the store if
+ // the remote has no data for it, or the tail that the remote has
+ // we need to iterate from the remote tail, and keep going until
+ // remote tail = current local tail
+
+ let mut record = Some(store.get(start).await.unwrap());
+
+ let mut buf = Vec::with_capacity(upload_page_size);
+
+ while let Some(r) = record {
+ if buf.len() < upload_page_size {
+ buf.push(r.clone());
+ } else {
+ client.post_records(&buf).await?;
+
+ // can we reset what we have? len = 0 but keep capacity
+ buf = Vec::with_capacity(upload_page_size);
+ }
+ record = store.next(&r).await?;
+
+ total += 1;
+ }
+