From 97e24d0d41bb743833e457de5ba49c5c233eb3b3 Mon Sep 17 00:00:00 2001
From: Ellie Huxtable
Date: Fri, 14 Jul 2023 20:44:08 +0100
Subject: Add new sync (#1093)

* Add record migration
* Add database functions for inserting history

No real tests yet :( I would like to avoid running postgres lol

* Add index handler, use UUIDs not strings
* Fix a bunch of tests, remove Option
* Add tests, all passing
* Working upload sync
* Record downloading works
* Sync download works
* Don't waste requests
* Use a page size for uploads, make it variable later
* Aaaaaand they're encrypted now too
* Add cek
* Allow reading tail across hosts
* Revert "Allow reading tail across hosts"

Not like that

This reverts commit 7b0c72e7e050c358172f9b53cbd21b9e44cf4931.

* Handle multiple shards properly
* format
* Format and make clippy happy
* use some fancy types (#1098)
* use some fancy types
* fmt
* Goodbye horrible tuple
* Update atuin-server-postgres/migrations/20230623070418_records.sql

Co-authored-by: Conrad Ludgate

* fmt
* Sort tests too because time sucks
* fix features

---------

Co-authored-by: Conrad Ludgate
---
 Cargo.lock                                         |  30 ++
 Cargo.toml                                         |  17 +-
 atuin-client/Cargo.toml                            |   7 +-
 .../20230531212437_create-records.sql              |   3 +-
 .../20230619235421_add_content_encrytion_key.sql   |   3 -
 atuin-client/src/api_client.rs                     |  59 ++-
 atuin-client/src/database.rs                       |   2 +-
 atuin-client/src/kv.rs                             |  22 +-
 atuin-client/src/record/encryption.rs              |  65 ++--
 atuin-client/src/record/mod.rs                     |   2 +
 atuin-client/src/record/sqlite_store.rs            |  99 +++--
 atuin-client/src/record/store.rs                   |  18 +-
 atuin-client/src/record/sync.rs                    | 421 +++++++++++++++++++++
 atuin-client/src/settings.rs                       |  13 +-
 atuin-common/Cargo.toml                            |   3 +-
 atuin-common/src/lib.rs                            |  52 +++
 atuin-common/src/record.rs                         | 106 ++++--
 atuin-server-database/src/lib.rs                   |  18 +-
 atuin-server-postgres/Cargo.toml                   |   1 +
 atuin-server-postgres/build.rs                     |   5 +
 .../migrations/20230623070418_records.sql          |  15 +
 atuin-server-postgres/src/lib.rs                   | 102 ++++-
 atuin-server-postgres/src/wrappers.rs              |  29 ++
 atuin-server/src/handlers/mod.rs                   |   1 +
 atuin-server/src/handlers/record.rs                | 104 +++++
 atuin-server/src/router.rs                         |   3 +
 atuin-server/src/settings.rs                       |   2 +
 atuin/src/command/client.rs                        |   2 +-
 atuin/src/command/client/sync.rs                   |  33 +-
 29 files changed, 1094 insertions(+), 143 deletions(-)
 delete mode 100644 atuin-client/record-migrations/20230619235421_add_content_encrytion_key.sql
 create mode 100644 atuin-client/src/record/sync.rs
 create mode 100644 atuin-server-postgres/build.rs
 create mode 100644 atuin-server-postgres/migrations/20230623070418_records.sql
 create mode 100644 atuin-server/src/handlers/record.rs

diff --git a/Cargo.lock b/Cargo.lock
index 12f699bc..a0f2f5ef 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -142,6 +142,7 @@ dependencies = [
  "directories",
  "eyre",
  "fs-err",
+ "futures",
  "generic-array",
  "hex",
  "interim",
@@ -151,6 +152,7 @@ dependencies = [
  "memchr",
  "minspan",
  "parse_duration",
+ "pretty_assertions",
  "rand 0.8.5",
  "regex",
  "reqwest",
@@ -182,6 +184,7 @@ dependencies = [
  "pretty_assertions",
  "rand 0.8.5",
  "serde",
+ "sqlx",
  "typed-builder",
  "uuid",
 ]
@@ -240,6 +243,7 @@ dependencies = [
  "serde",
  "sqlx",
  "tracing",
+ "uuid",
 ]
 
 [[package]]
@@ -950,6 +954,21 @@ version = "2.9.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "0845fa252299212f0389d64ba26f34fa32cfe41588355f21ed507c59a0f64541"
 
+[[package]]
+name = "futures"
+version = "0.3.24"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "7f21eda599937fba36daeb58a22e8f5cee2d14c4a17b5b7739c7c8e5e3b8230c"
+dependencies = [
+ "futures-channel",
+ "futures-core",
+ "futures-executor",
+ "futures-io",
+ "futures-sink",
+ "futures-task",
+ "futures-util",
+]
+
 [[package]]
 name = "futures-channel"
 version = "0.3.24"
@@ -988,6 +1007,12 @@ dependencies = [
  "parking_lot 0.11.2",
 ]
 
+[[package]]
+name = "futures-io"
+version = "0.3.28"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "4fff74096e71ed47f8e023204cfd0aa1289cd54ae5430a9523be060cdb849964"
+
 [[package]]
 name = "futures-macro"
 version = "0.3.24"
@@ -1017,10 +1042,13 @@ version = "0.3.24"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "44fb6cb1be61cc1d2e43b262516aafcf63b241cffdb1d3fa115f91d9c7b09c90"
 dependencies = [
+ "futures-channel",
  "futures-core",
+ "futures-io",
  "futures-macro",
  "futures-sink",
  "futures-task",
+ "memchr",
  "pin-project-lite",
  "pin-utils",
  "slab",
@@ -2567,6 +2595,7 @@ dependencies = [
  "thiserror",
  "tokio-stream",
  "url",
+ "uuid",
  "webpki-roots",
  "whoami",
 ]
@@ -3037,6 +3066,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "0fa2982af2eec27de306107c027578ff7f423d65f7250e40ce0fea8f45248b81"
 dependencies = [
  "getrandom 0.2.7",
+ "serde",
 ]
 
 [[package]]
diff --git a/Cargo.toml b/Cargo.toml
index bde7ed67..017aa7e2 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -1,11 +1,11 @@
 [workspace]
 members = [
-  "atuin",
-  "atuin-client",
-  "atuin-server",
-  "atuin-server-postgres",
-  "atuin-server-database",
-  "atuin-common",
+    "atuin",
+    "atuin-client",
+    "atuin-server",
+    "atuin-server-postgres",
+    "atuin-server-database",
+    "atuin-common",
 ]
 
 [workspace.package]
@@ -35,9 +35,10 @@ semver = "1.0.14"
 serde = { version = "1.0.145", features = ["derive"] }
 serde_json = "1.0.86"
 tokio = { version = "1", features = ["full"] }
-uuid = { version = "1.3", features = ["v4"] }
+uuid = { version = "1.3", features = ["v4", "serde"] }
 whoami = "1.1.2"
 typed-builder = "0.14.0"
+pretty_assertions = "1.3.0"
 
 [workspace.dependencies.reqwest]
 version = "0.11"
@@ -46,4 +47,4 @@ default-features = false
 
 [workspace.dependencies.sqlx]
 version = "0.6"
-features = ["runtime-tokio-rustls", "chrono", "postgres"]
+features = ["runtime-tokio-rustls", "chrono", "postgres", "uuid"]
diff --git a/atuin-client/Cargo.toml b/atuin-client/Cargo.toml
index fa539662..ca620677 100644
--- a/atuin-client/Cargo.toml
+++ b/atuin-client/Cargo.toml
@@ -54,10 +54,14 @@ rmp = { version = "0.8.11" }
 typed-builder = "0.14.0"
 tokio = { workspace = true }
 semver = { workspace = true }
+futures = "0.3"
 
 # encryption
 rusty_paseto = { version = "0.5.0", default-features = false }
-rusty_paserk = { version = "0.2.0", default-features = false, features = ["v4", "serde"] }
+rusty_paserk = { version = "0.2.0", default-features = false, features = [
+    "v4",
+    "serde",
+] }
 
 # sync
 urlencoding = { version = "2.1.0", optional = true }
@@ -69,3 +73,4 @@ generic-array = { version = "0.14", optional = true, features = ["serde"] }
 
 [dev-dependencies]
 tokio = { version = "1", features = ["full"] }
+pretty_assertions = { workspace = true }
diff --git a/atuin-client/record-migrations/20230531212437_create-records.sql b/atuin-client/record-migrations/20230531212437_create-records.sql
index 46963358..4f4b304a 100644
--- a/atuin-client/record-migrations/20230531212437_create-records.sql
+++ b/atuin-client/record-migrations/20230531212437_create-records.sql
@@ -7,7 +7,8 @@ create table if not exists records (
     timestamp integer not null,
     tag text not null,
     version text not null,
-    data blob not null
+    data blob not null,
+    cek blob not null
 );
 
 create index host_idx on records (host);
diff --git a/atuin-client/record-migrations/20230619235421_add_content_encrytion_key.sql b/atuin-client/record-migrations/20230619235421_add_content_encrytion_key.sql
deleted file mode 100644
index 86bf6844..00000000
--- a/atuin-client/record-migrations/20230619235421_add_content_encrytion_key.sql
+++ /dev/null
@@ -1,3 +0,0 @@
--- store content encryption keys in the record
-alter table records
-    add column cek text;
diff --git a/atuin-client/src/api_client.rs b/atuin-client/src/api_client.rs
index 350c419d..5ae1ed0a 100644
--- a/atuin-client/src/api_client.rs
+++ b/atuin-client/src/api_client.rs
@@ -8,9 +8,13 @@ use reqwest::{
     StatusCode, Url,
 };
 
-use atuin_common::api::{
-    AddHistoryRequest, CountResponse, DeleteHistoryRequest, ErrorResponse, IndexResponse,
-    LoginRequest, LoginResponse, RegisterResponse, StatusResponse, SyncHistoryResponse,
+use atuin_common::record::{EncryptedData, HostId, Record, RecordId};
+use atuin_common::{
+    api::{
+        AddHistoryRequest, CountResponse, DeleteHistoryRequest, ErrorResponse, IndexResponse,
+        LoginRequest, LoginResponse, RegisterResponse, StatusResponse, SyncHistoryResponse,
+    },
+    record::RecordIndex,
 };
 use semver::Version;
@@ -195,6 +199,55 @@ impl<'a> Client<'a> {
         Ok(())
     }
 
+    pub async fn post_records(&self, records: &[Record<EncryptedData>]) -> Result<()> {
+        let url = format!("{}/record", self.sync_addr);
+        let url = Url::parse(url.as_str())?;
+
+        self.client.post(url).json(records).send().await?;
+
+        Ok(())
+    }
+
+    pub async fn next_records(
+        &self,
+        host: HostId,
+        tag: String,
+        start: Option<RecordId>,
+        count: u64,
+    ) -> Result<Vec<Record<EncryptedData>>> {
+        let url = format!(
+            "{}/record/next?host={}&tag={}&count={}",
+            self.sync_addr, host.0, tag, count
+        );
+        let mut url = Url::parse(url.as_str())?;
+
+        if let Some(start) = start {
+            url.set_query(Some(
+                format!(
+                    "host={}&tag={}&count={}&start={}",
+                    host.0, tag, count, start.0
+                )
+                .as_str(),
+            ));
+        }
+
+        let resp = self.client.get(url).send().await?;
+
+        let records = resp.json::<Vec<Record<EncryptedData>>>().await?;
+
+        Ok(records)
+    }
+
+    pub async fn record_index(&self) -> Result<RecordIndex> {
+        let url = format!("{}/record", self.sync_addr);
+        let url = Url::parse(url.as_str())?;
+
+        let resp = self.client.get(url).send().await?;
+        let index = resp.json().await?;
+
+        Ok(index)
+    }
+
     pub async fn delete(&self) -> Result<()> {
         let url = format!("{}/account", self.sync_addr);
         let url = Url::parse(url.as_str())?;
diff --git a/atuin-client/src/database.rs b/atuin-client/src/database.rs
index b7b44409..218c1d6e 100644
--- a/atuin-client/src/database.rs
+++ b/atuin-client/src/database.rs
@@ -57,7 +57,7 @@ pub fn current_context() -> Context {
         session,
         hostname,
         cwd,
-        host_id,
+        host_id: host_id.0.as_simple().to_string(),
     }
 }
diff --git a/atuin-client/src/kv.rs b/atuin-client/src/kv.rs
index c365a385..30018d63 100644
--- a/atuin-client/src/kv.rs
+++ b/atuin-client/src/kv.rs
@@ -101,10 +101,7 @@ impl KvStore {
 
         let bytes = record.serialize()?;
 
-        let parent = store
-            .last(host_id.as_str(), KV_TAG)
-            .await?
-            .map(|entry| entry.id);
+        let parent = store.tail(host_id, KV_TAG).await?.map(|entry| entry.id);
 
         let record = atuin_common::record::Record::builder()
             .host(host_id)
@@ -130,17 +127,22 @@ impl KvStore {
         namespace: &str,
         key: &str,
     ) -> Result<Option<KvRecord>> {
-        // TODO: don't load this from disk so much
-        let host_id = Settings::host_id().expect("failed to get host_id");
-
         // Currently, this is O(n). When we have an actual KV store, it can be better
         // Just a poc for now!
 
         // iterate records to find the value we want
         // start at the end, so we get the most recent version
-        let Some(mut record) = store.last(host_id.as_str(), KV_TAG).await? else {
+        let tails = store.tag_tails(KV_TAG).await?;
+
+        if tails.is_empty() {
             return Ok(None);
-        };
+        }
+
+        // first, decide on a record.
+        // try getting the newest first
+        // we always need a way of deciding the "winner" of a write
+        // TODO(ellie): something better than last-write-wins, what if two write at the same time?
+        let mut record = tails.iter().max_by_key(|r| r.timestamp).unwrap().clone();
 
         loop {
             let decrypted = match record.version.as_str() {
@@ -154,7 +156,7 @@ impl KvStore {
             }
 
             if let Some(parent) = decrypted.parent {
-                record = store.get(parent.as_str()).await?;
+                record = store.get(parent).await?;
             } else {
                 break;
             }
diff --git a/atuin-client/src/record/encryption.rs b/atuin-client/src/record/encryption.rs
index f14bf027..6760d97b 100644
--- a/atuin-client/src/record/encryption.rs
+++ b/atuin-client/src/record/encryption.rs
@@ -1,4 +1,6 @@
-use atuin_common::record::{AdditionalData, DecryptedData, EncryptedData, Encryption};
+use atuin_common::record::{
+    AdditionalData, DecryptedData, EncryptedData, Encryption, HostId, RecordId,
+};
 use base64::{engine::general_purpose, Engine};
 use eyre::{ensure, Context, Result};
 use rusty_paserk::{Key, KeyId, Local, PieWrappedKey};
@@ -158,10 +160,11 @@ struct AtuinFooter {
 // This cannot be changed, otherwise it breaks the authenticated encryption.
 #[derive(Debug, Copy, Clone, Serialize)]
 struct Assertions<'a> {
-    id: &'a str,
+    id: &'a RecordId,
     version: &'a str,
     tag: &'a str,
-    host: &'a str,
+    host: &'a HostId,
+    parent: Option<&'a RecordId>,
 }
 
 impl<'a> From<AdditionalData<'a>> for Assertions<'a> {
@@ -171,6 +174,7 @@ impl<'a> From<AdditionalData<'a>> for Assertions<'a> {
             version: ad.version,
             tag: ad.tag,
             host: ad.host,
+            parent: ad.parent,
         }
     }
 }
@@ -183,7 +187,7 @@ impl Assertions<'_> {
 
 #[cfg(test)]
 mod tests {
-    use atuin_common::record::Record;
+    use atuin_common::{record::Record, utils::uuid_v7};
 
     use super::*;
 
@@ -192,10 +196,11 @@ mod tests {
         let key = Key::<V4, Local>::new_os_random();
 
         let ad = AdditionalData {
-            id: "foo",
+            id: &RecordId(uuid_v7()),
             version: "v0",
             tag: "kv",
-            host: "1234",
+            host: &HostId(uuid_v7()),
+            parent: None,
         };
 
         let data = DecryptedData(vec![1, 2, 3, 4]);
@@ -210,10 +215,11 @@ mod tests {
         let key = Key::<V4, Local>::new_os_random();
 
         let ad = AdditionalData {
-            id: "foo",
+            id: &RecordId(uuid_v7()),
             version: "v0",
             tag: "kv",
-            host: "1234",
+            host: &HostId(uuid_v7()),
+            parent: None,
         };
 
         let data = DecryptedData(vec![1, 2, 3, 4]);
@@ -233,10 +239,11 @@ mod tests {
         let fake_key = Key::<V4, Local>::new_os_random();
 
         let ad = AdditionalData {
-            id: "foo",
+            id: &RecordId(uuid_v7()),
             version: "v0",
             tag: "kv",
-            host: "1234",
+            host: &HostId(uuid_v7()),
+            parent: None,
         };
 
         let data = DecryptedData(vec![1, 2, 3, 4]);
@@ -250,10 +257,11 @@ mod tests {
         let key = Key::<V4, Local>::new_os_random();
 
         let ad = AdditionalData {
-            id: "foo",
+            id: &RecordId(uuid_v7()),
             version: "v0",
             tag: "kv",
-            host: "1234",
+            host: &HostId(uuid_v7()),
+            parent: None,
         };
 
         let data = DecryptedData(vec![1, 2, 3, 4]);
@@ -261,10 +269,8 @@ mod tests {
         let encrypted = PASETO_V4::encrypt(data, ad, &key.to_bytes());
 
         let ad = AdditionalData {
-            id: "foo1",
-            version: "v0",
-            tag: "kv",
-            host: "1234",
+            id: &RecordId(uuid_v7()),
+            ..ad
         };
         let _ = PASETO_V4::decrypt(encrypted, ad, &key.to_bytes()).unwrap_err();
     }
@@ -275,10 +281,11 @@ mod tests {
         let key2 = Key::<V4, Local>::new_os_random();
 
         let ad = AdditionalData {
-            id: "foo",
+            id: &RecordId(uuid_v7()),
            version: "v0",
             tag: "kv",
-            host: "1234",
+            host: &HostId(uuid_v7()),
+            parent: None,
         };
 
         let data = DecryptedData(vec![1, 2, 3, 4]);
@@ -304,10 +311,10 @@ mod tests {
     fn full_record_round_trip() {
         let key = [0x55; 32];
         let record = Record::builder()
-            .id("1".to_owned())
+            .id(RecordId(uuid_v7()))
             .version("v0".to_owned())
             .tag("kv".to_owned())
-            .host("host1".to_owned())
+            .host(HostId(uuid_v7()))
             .timestamp(1687244806000000)
             .data(DecryptedData(vec![1, 2, 3, 4]))
             .build();
@@ -316,30 +323,20 @@ mod tests {
 
         assert!(!encrypted.data.data.is_empty());
         assert!(!encrypted.data.content_encryption_key.is_empty());
-        assert_eq!(encrypted.id, "1");
-        assert_eq!(encrypted.host, "host1");
-        assert_eq!(encrypted.version, "v0");
-        assert_eq!(encrypted.tag, "kv");
-        assert_eq!(encrypted.timestamp, 1687244806000000);
 
         let decrypted = encrypted.decrypt::<PASETO_V4>(&key).unwrap();
 
         assert_eq!(decrypted.data.0, [1, 2, 3, 4]);
-        assert_eq!(decrypted.id, "1");
-        assert_eq!(decrypted.host, "host1");
-        assert_eq!(decrypted.version, "v0");
-        assert_eq!(decrypted.tag, "kv");
-        assert_eq!(decrypted.timestamp, 1687244806000000);
     }
 
     #[test]
     fn full_record_round_trip_fail() {
         let key = [0x55; 32];
         let record = Record::builder()
-            .id("1".to_owned())
+            .id(RecordId(uuid_v7()))
             .version("v0".to_owned())
             .tag("kv".to_owned())
-            .host("host1".to_owned())
+            .host(HostId(uuid_v7()))
             .timestamp(1687244806000000)
             .data(DecryptedData(vec![1, 2, 3, 4]))
             .build();
@@ -347,13 +344,13 @@ mod tests {
         let encrypted = record.encrypt::<PASETO_V4>(&key);
 
         let mut enc1 = encrypted.clone();
-        enc1.host = "host2".to_owned();
+        enc1.host = HostId(uuid_v7());
         let _ = enc1
             .decrypt::<PASETO_V4>(&key)
             .expect_err("tampering with the host should result in auth failure");
 
         let mut enc2 = encrypted;
-        enc2.id = "2".to_owned();
+        enc2.id = RecordId(uuid_v7());
         let _ = enc2
             .decrypt::<PASETO_V4>(&key)
             .expect_err("tampering with the id should result in auth failure");
diff --git a/atuin-client/src/record/mod.rs b/atuin-client/src/record/mod.rs
index 9ac2c541..8bc816ae 100644
--- a/atuin-client/src/record/mod.rs
+++ b/atuin-client/src/record/mod.rs
@@ -1,3 +1,5 @@
 pub mod encryption;
 pub mod sqlite_store;
 pub mod store;
+#[cfg(feature = "sync")]
+pub mod sync;
diff --git a/atuin-client/src/record/sqlite_store.rs b/atuin-client/src/record/sqlite_store.rs
index f692c0c2..14a7e277 100644
--- a/atuin-client/src/record/sqlite_store.rs
+++ b/atuin-client/src/record/sqlite_store.rs
@@ -8,12 +8,14 @@ use std::str::FromStr;
 use async_trait::async_trait;
 use eyre::{eyre, Result};
 use fs_err as fs;
+use futures::TryStreamExt;
 use sqlx::{
     sqlite::{SqliteConnectOptions, SqliteJournalMode, SqlitePool, SqlitePoolOptions, SqliteRow},
     Row,
 };
 
-use atuin_common::record::{EncryptedData, Record};
+use atuin_common::record::{EncryptedData, HostId, Record, RecordId, RecordIndex};
+use uuid::Uuid;
 
 use super::store::Store;
@@ -62,11 +64,11 @@ impl SqliteStore {
             "insert or ignore into records(id, host, tag, timestamp, parent, version, data, cek)
                 values(?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8)",
         )
-        .bind(r.id.as_str())
-        .bind(r.host.as_str())
+        .bind(r.id.0.as_simple().to_string())
+        .bind(r.host.0.as_simple().to_string())
         .bind(r.tag.as_str())
         .bind(r.timestamp as i64)
-        .bind(r.parent.as_ref())
+        .bind(r.parent.map(|p| p.0.as_simple().to_string()))
         .bind(r.version.as_str())
         .bind(r.data.data.as_str())
         .bind(r.data.content_encryption_key.as_str())
@@ -79,10 +81,18 @@ impl SqliteStore {
     fn query_row(row: SqliteRow) -> Record<EncryptedData> {
         let timestamp: i64 = row.get("timestamp");
 
+        // tbh at this point things are pretty fucked so just panic
+        let id = Uuid::from_str(row.get("id")).expect("invalid id UUID format in sqlite DB");
+        let host = Uuid::from_str(row.get("host")).expect("invalid host UUID format in sqlite DB");
+        let parent: Option<&str> = row.get("parent");
+
+        let parent = parent
+            .map(|parent| Uuid::from_str(parent).expect("invalid parent UUID format in sqlite DB"));
+
         Record {
-            id: row.get("id"),
-            host: row.get("host"),
-            parent: row.get("parent"),
+            id: RecordId(id),
+            host: HostId(host),
+            parent: parent.map(RecordId),
             timestamp: timestamp as u64,
             tag: row.get("tag"),
             version: row.get("version"),
@@ -111,9 +121,9 @@ impl Store for SqliteStore {
         Ok(())
     }
 
-    async fn get(&self, id: &str) -> Result<Record<EncryptedData>> {
+    async fn get(&self, id: RecordId) -> Result<Record<EncryptedData>> {
         let res = sqlx::query("select * from records where id = ?1")
-            .bind(id)
+            .bind(id.0.as_simple().to_string())
             .map(Self::query_row)
             .fetch_one(&self.pool)
             .await?;
 
         Ok(res)
     }
 
-    async fn len(&self, host: &str, tag: &str) -> Result<u64> {
+    async fn len(&self, host: HostId, tag: &str) -> Result<u64> {
         let res: (i64,) =
             sqlx::query_as("select count(1) from records where host = ?1 and tag = ?2")
-                .bind(host)
+                .bind(host.0.as_simple().to_string())
                 .bind(tag)
                 .fetch_one(&self.pool)
                 .await?;
@@ -134,7 +144,7 @@ impl Store for SqliteStore {
 
     async fn next(&self, record: &Record<EncryptedData>) -> Result<Option<Record<EncryptedData>>> {
         let res = sqlx::query("select * from records where parent = ?1")
-            .bind(record.id.clone())
+            .bind(record.id.0.as_simple().to_string())
             .map(Self::query_row)
             .fetch_one(&self.pool)
             .await;
@@ -146,11 +156,11 @@ impl Store for SqliteStore {
         }
     }
 
-    async fn first(&self, host: &str, tag: &str) -> Result<Option<Record<EncryptedData>>> {
+    async fn head(&self, host: HostId, tag: &str) -> Result<Option<Record<EncryptedData>>> {
         let res = sqlx::query(
             "select * from records where host = ?1 and tag = ?2 and parent is null limit 1",
         )
-        .bind(host)
+        .bind(host.0.as_simple().to_string())
         .bind(tag)
         .map(Self::query_row)
         .fetch_optional(&self.pool)
         .await?;
@@ -159,23 +169,53 @@ impl Store for SqliteStore {
         Ok(res)
     }
 
-    async fn last(&self, host: &str, tag: &str) -> Result<Option<Record<EncryptedData>>> {
+    async fn tail(&self, host: HostId, tag: &str) -> Result<Option<Record<EncryptedData>>> {
         let res = sqlx::query(
             "select * from records rp where tag=?1 and host=?2 and (select count(1) from records where parent=rp.id) = 0;",
         )
         .bind(tag)
-        .bind(host)
+        .bind(host.0.as_simple().to_string())
         .map(Self::query_row)
         .fetch_optional(&self.pool)
         .await?;
 
         Ok(res)
     }
+
+    async fn tag_tails(&self, tag: &str) -> Result<Vec<Record<EncryptedData>>> {
+        let res = sqlx::query(
+            "select * from records rp where tag=?1 and (select count(1) from records where parent=rp.id) = 0;",
+        )
+        .bind(tag)
+        .map(Self::query_row)
+        .fetch_all(&self.pool)
+        .await?;
+
+        Ok(res)
+    }
+
+    async fn tail_records(&self) -> Result<RecordIndex> {
+        let res = sqlx::query(
+            "select host, tag, id from records rp where (select count(1) from records where parent=rp.id) = 0;",
+        )
+        .map(|row: SqliteRow| {
+            let host: Uuid = Uuid::from_str(row.get("host")).expect("invalid uuid in db host");
+            let tag: String = row.get("tag");
+            let id: Uuid = Uuid::from_str(row.get("id")).expect("invalid uuid in db id");
+
+            (HostId(host), tag, RecordId(id))
+        })
+        .fetch(&self.pool)
+        .try_collect()
+        .await?;
+
+        Ok(res)
+    }
 }
 
 #[cfg(test)]
 mod tests {
-    use atuin_common::record::{EncryptedData, Record};
+    use atuin_common::record::{EncryptedData, HostId, Record};
 
     use crate::record::{encryption::PASETO_V4, store::Store};
 
     fn test_record() -> Record<EncryptedData> {
         Record::builder()
-            .host(atuin_common::utils::uuid_v7().simple().to_string())
+            .host(HostId(atuin_common::utils::uuid_v7()))
             .version("v1".into())
             .tag(atuin_common::utils::uuid_v7().simple().to_string())
             .data(EncryptedData {
@@ -218,10 +258,7 @@ mod tests {
         let record = test_record();
         db.push(&record).await.unwrap();
 
-        let new_record = db
-            .get(record.id.as_str())
-            .await
-            .expect("failed to fetch record");
+        let new_record = db.get(record.id).await.expect("failed to fetch record");
 
         assert_eq!(record, new_record, "records are not equal");
     }
@@ -233,7 +270,7 @@ mod tests {
         let record = test_record();
         db.push(&record).await.unwrap();
 
         let len = db
-            .len(record.host.as_str(), record.tag.as_str())
+            .len(record.host, record.tag.as_str())
             .await
             .expect("failed to get store len");
@@ -253,14 +290,8 @@ mod tests {
         db.push(&first).await.unwrap();
         db.push(&second).await.unwrap();
 
-        let first_len = db
-            .len(first.host.as_str(), first.tag.as_str())
-            .await
-            .unwrap();
-        let second_len = db
-            .len(second.host.as_str(), second.tag.as_str())
-            .await
-            .unwrap();
+        let first_len = db.len(first.host, first.tag.as_str()).await.unwrap();
+        let second_len = db.len(second.host, second.tag.as_str()).await.unwrap();
 
         assert_eq!(first_len, 1, "expected length of 1 after insert");
         assert_eq!(second_len, 1, "expected length of 1 after insert");
@@ -281,7 +312,7 @@ mod tests {
         }
 
         assert_eq!(
-            db.len(tail.host.as_str(), tail.tag.as_str()).await.unwrap(),
+            db.len(tail.host, tail.tag.as_str()).await.unwrap(),
             100,
             "failed to insert 100 records"
         );
@@ -304,7 +335,7 @@ mod tests {
         db.push_batch(records.iter()).await.unwrap();
 
         assert_eq!(
-            db.len(tail.host.as_str(), tail.tag.as_str()).await.unwrap(),
+            db.len(tail.host, tail.tag.as_str()).await.unwrap(),
             10000,
             "failed to insert 10k records"
         );
@@ -327,7 +358,7 @@ mod tests {
         db.push_batch(records.iter()).await.unwrap();
 
         let mut record = db
-            .first(tail.host.as_str(), tail.tag.as_str())
+            .head(tail.host, tail.tag.as_str())
             .await
             .expect("in memory sqlite should not fail")
             .expect("entry exists");
diff --git a/atuin-client/src/record/store.rs b/atuin-client/src/record/store.rs
index 9ea7007a..45d554ef 100644
--- a/atuin-client/src/record/store.rs
+++ b/atuin-client/src/record/store.rs
@@ -1,7 +1,7 @@
 use async_trait::async_trait;
 use eyre::Result;
 
-use atuin_common::record::{EncryptedData, Record};
+use atuin_common::record::{EncryptedData, HostId, Record, RecordId, RecordIndex};
 
 /// A record store stores records
 /// In more detail - we tend to need to process this into _another_ format to actually query it.
@@ -20,14 +20,22 @@ pub trait Store {
         records: impl Iterator<Item = &Record<EncryptedData>> + Send + Sync,
     ) -> Result<()>;
 
-    async fn get(&self, id: &str) -> Result<Record<EncryptedData>>;
-    async fn len(&self, host: &str, tag: &str) -> Result<u64>;
+    async fn get(&self, id: RecordId) -> Result<Record<EncryptedData>>;
+    async fn len(&self, host: HostId, tag: &str) -> Result<u64>;
 
     /// Get the record that follows this record
     async fn next(&self, record: &Record<EncryptedData>) -> Result<Option<Record<EncryptedData>>>;
 
     /// Get the first record for a given host and tag
-    async fn first(&self, host: &str, tag: &str) -> Result<Option<Record<EncryptedData>>>;
+    async fn head(&self, host: HostId, tag: &str) -> Result<Option<Record<EncryptedData>>>;
+
     /// Get the last record for a given host and tag
-    async fn last(&self, host: &str, tag: &str) -> Result<Option<Record<EncryptedData>>>;
+    async fn tail(&self, host: HostId, tag: &str) -> Result<Option<Record<EncryptedData>>>;
+
+    // Get the last record for all hosts for a given tag, useful for the read path of apps.
+    async fn tag_tails(&self, tag: &str) -> Result<Vec<Record<EncryptedData>>>;
+
+    // Get the latest host/tag/record tuple for every set in the store. useful for building an
+    // index
+    async fn tail_records(&self) -> Result<RecordIndex>;
 }
diff --git a/atuin-client/src/record/sync.rs b/atuin-client/src/record/sync.rs
new file mode 100644
index 00000000..ebdb8eb2
--- /dev/null
+++ b/atuin-client/src/record/sync.rs
@@ -0,0 +1,421 @@
+// do a sync :O
+use eyre::Result;
+
+use super::store::Store;
+use crate::{api_client::Client, settings::Settings};
+
+use atuin_common::record::{Diff, HostId, RecordId, RecordIndex};
+
+#[derive(Debug, Eq, PartialEq)]
+pub enum Operation {
+    // Either upload or download until the tail matches the below
+    Upload {
+        tail: RecordId,
+        host: HostId,
+        tag: String,
+    },
+
+    Download {
+        tail: RecordId,
+        host: HostId,
+        tag: String,
+    },
+}
+
+pub async fn diff(settings: &Settings, store: &mut impl Store) -> Result<(Vec<Diff>, RecordIndex)> {
+    let client = Client::new(&settings.sync_address, &settings.session_token)?;
+
+    let local_index = store.tail_records().await?;
+    let remote_index = client.record_index().await?;
+
+    let diff = local_index.diff(&remote_index);
+
+    Ok((diff, remote_index))
+}
+
+// Take a diff, along with a local store, and resolve it into a set of operations.
+// With the store as context, we can determine if a tail exists locally or not and therefore if it needs uploading or download.
+// In theory this could be done as a part of the diffing stage, but it's easier to reason
+// about and test this way
+pub async fn operations(diffs: Vec<Diff>, store: &impl Store) -> Result<Vec<Operation>> {
+    let mut operations = Vec::with_capacity(diffs.len());
+
+    for diff in diffs {
+        // First, try to fetch the tail
+        // If it exists locally, then that means we need to update the remote
+        // host until it has the same tail. Ie, upload.
+        // If it does not exist locally, that means remote is ahead of us.
+        // Therefore, we need to download until our local tail matches
+        let record = store.get(diff.tail).await;
+
+        let op = if record.is_ok() {
+            // if local has the ID, then we should find the actual tail of this
+            // store, so we know what we need to update the remote to.
+            let tail = store
+                .tail(diff.host, diff.tag.as_str())
+                .await?
+                .expect("failed to fetch last record, expected tag/host to exist");
+
+            // TODO(ellie) update the diffing so that it stores the context of the current tail
+            // that way, we can determine how much we need to upload.
+            // For now just keep uploading until tails match
+
+            Operation::Upload {
+                tail: tail.id,
+                host: diff.host,
+                tag: diff.tag,
+            }
+        } else {
+            Operation::Download {
+                tail: diff.tail,
+                host: diff.host,
+                tag: diff.tag,
+            }
+        };
+
+        operations.push(op);
+    }
+
+    // sort them - purely so we have a stable testing order, and can rely on
+    // same input = same output
+    // We can sort by ID so long as we continue to use UUIDv7 or something
+    // with the same properties
+
+    operations.sort_by_key(|op| match op {
+        Operation::Upload { tail, host, .. } => ("upload", *host, *tail),
+        Operation::Download { tail, host, .. } => ("download", *host, *tail),
+    });
+
+    Ok(operations)
+}
+
+async fn sync_upload(
+    store: &mut impl Store,
+    remote_index: &RecordIndex,
+    client: &Client<'_>,
+    op: (HostId, String, RecordId),
+) -> Result<i64> {
+    let upload_page_size = 100;
+    let mut total = 0;
+
+    // so. we have an upload operation, with the tail representing the state
+    // we want to get the remote to
+    let current_tail = remote_index.get(op.0, op.1.clone());
+
+    println!(
+        "Syncing local {:?}/{}/{:?}, remote has {:?}",
+        op.0, op.1, op.2, current_tail
+    );
+
+    let start = if let Some(current_tail) = current_tail {
+        current_tail
+    } else {
+        store
+            .head(op.0, op.1.as_str())
+            .await
+            .expect("failed to fetch host/tag head")
+            .expect("host/tag not in current index")
+            .id
+    };
+
+    debug!("starting push to remote from: {:?}", start);
+
+    // we have the start point for sync. it is either the head of the store if
+    // the remote has no data for it, or the tail that the remote has
+    // we need to iterate from the remote tail, and keep going until
+    // remote tail = current local tail
+
+    let mut record = Some(store.get(start).await.unwrap());
+
+    let mut buf = Vec::with_capacity(upload_page_size);
+
+    while let Some(r) = record {
+        if buf.len() < upload_page_size {
+            buf.push(r.clone());
+        } else {
+            client.post_records(&buf).await?;
+
+            // can we reset what we have? len = 0 but keep capacity
+            buf = Vec::with_capacity(upload_page_size);
+        }
+        record = store.next(&r).await?;
+
+        total += 1;
+    }
+
+    if !buf.is_empty() {
+        client.post_records(&buf).await?;
+    }
+
+    Ok(total)
+}
+
+async fn sync_download(
+    store: &mut impl Store,
+    remote_index: &RecordIndex,
+    client: &Client<'_>,
+    op: (HostId, String, RecordId),
+) -> Result<i64> {
+    // TODO(ellie): implement variable page sizing like on history sync
+    let download_page_size = 1000;
+
+    let mut total = 0;
+
+    // We know that the remote is ahead of us, so let's keep downloading until both
+    // 1) The remote stops returning full pages
+    // 2) The tail equals what we expect
+    //
+    // If (1) occurs without (2), then something is wrong with our index calculation
+    // and we should bail.
+    let remote_tail = remote_index
+        .get(op.0, op.1.clone())
+        .expect("remote index does not contain expected tail during download");
+    let local_tail = store.tail(op.0, op.1.as_str()).await?;
+    //
+    // We expect that the operations diff will represent the desired state
+    // In this case, that contains the remote tail.
+    assert_eq!(remote_tail, op.2);
+
+    println!("Downloading {:?}/{}/{:?} to local", op.0, op.1, op.2);
+
+    let mut records = client
+        .next_records(
+            op.0,
+            op.1.clone(),
+            local_tail.map(|r| r.id),
+            download_page_size,
+        )
+        .await?;
+
+    while !records.is_empty() {
+        total += std::cmp::min(download_page_size, records.len() as u64);
+        store.push_batch(records.iter()).await?;
+
+        if records.last().unwrap().id == remote_tail {
+            break;
+        }
+
+        records = client
+            .next_records(
+                op.0,
+                op.1.clone(),
+                records.last().map(|r| r.id),
+                download_page_size,
+            )
+            .await?;
+    }
+
+    Ok(total as i64)
+}
+
+pub async fn sync_remote(
+    operations: Vec<Operation>,
+    remote_index: &RecordIndex,
+    local_store: &mut impl Store,
+    settings: &Settings,
+) -> Result<(i64, i64)> {
+    let client = Client::new(&settings.sync_address, &settings.session_token)?;
+
+    let mut uploaded = 0;
+    let mut downloaded = 0;
+
+    // this can totally run in parallel, but lets get it working first
+    for i in operations {
+        match i {
+            Operation::Upload { tail, host, tag } => {
+                uploaded +=
+                    sync_upload(local_store, remote_index, &client, (host, tag, tail)).await?
+            }
+            Operation::Download { tail, host, tag } => {
+                downloaded +=
+                    sync_download(local_store, remote_index, &client, (host, tag, tail)).await?
+            }
+        }
+    }
+
+    Ok((uploaded, downloaded))
+}
+
+#[cfg(test)]
+mod tests {
+    use atuin_common::record::{Diff, EncryptedData, HostId, Record};
+    use pretty_assertions::assert_eq;
+
+    use crate::record::{
+        encryption::PASETO_V4,
+        sqlite_store::SqliteStore,
+        store::Store,
+        sync::{self, Operation},
+    };
+
+    fn test_record() -> Record<EncryptedData> {
+        Record::builder()
+            .host(HostId(atuin_common::utils::uuid_v7()))
+            .version("v1".into())
+            .tag(atuin_common::utils::uuid_v7().simple().to_string())
+            .data(EncryptedData {
+                data: String::new(),
+                content_encryption_key: String::new(),
+            })
+            .build()
+    }
+
+    // Take a list of local records, and a list of remote records.
+    // Return the local database, and a diff of local/remote, ready to build
+    // ops
+    async fn build_test_diff(
+        local_records: Vec<Record<EncryptedData>>,
+        remote_records: Vec<Record<EncryptedData>>,
+    ) -> (SqliteStore, Vec<Diff>) {
+        let local_store = SqliteStore::new(":memory:")
+            .await
+            .expect("failed to open in memory sqlite");
+        let remote_store = SqliteStore::new(":memory:")
+            .await
+            .expect("failed to open in memory sqlite"); // "remote"
+
+        for i in local_records {
+            local_store.push(&i).await.unwrap();
+        }
+
+        for i in remote_records {
+            remote_store.push(&i).await.unwrap();
+        }
+
+        let local_index = local_store.tail_records().await.unwrap();
+        let remote_index = remote_store.tail_records().await.unwrap();
+
+        let diff = local_index.diff(&remote_index);
+
+        (local_store, diff)
+    }
+
+    #[tokio::test]
+    async fn test_basic_diff() {
+        // a diff where local is ahead of remote. nothing else.
+
+        let record = test_record();
+        let (store, diff) = build_test_diff(vec![record.clone()], vec![]).await;
+
+        assert_eq!(diff.len(), 1);
+
+        let operations = sync::operations(diff, &store).await.unwrap();
+
+        assert_eq!(operations.len(), 1);
+
+        assert_eq!(
+            operations[0],
+            Operation::Upload {
+                host: record.host,
+                tag: record.tag,
+                tail: record.id
+            }
+        );
+    }
+
+    #[tokio::test]
+    async fn build_two_way_diff() {
+        // a diff where local is ahead of remote for one, and remote for
+        // another. One upload, one download
+
+        let shared_record = test_record();
+
+        let remote_ahead = test_record();
+        let local_ahead = shared_record
+            .new_child(vec![1, 2, 3])
+            .encrypt::<PASETO_V4>(&[0; 32]);
+
+        let local = vec![shared_record.clone(), local_ahead.clone()]; // local knows about the already synced, and something newer in the same store
+        let remote = vec![shared_record.clone(), remote_ahead.clone()]; // remote knows about the already-synced, and one new record in a new store
+
+        let (store, diff) = build_test_diff(local, remote).await;
+        let operations = sync::operations(diff, &store).await.unwrap();
+
+        assert_eq!(operations.len(), 2);
+
+        assert_eq!(
+            operations,
+            vec![
+                Operation::Download {
+                    tail: remote_ahead.id,
+                    host: remote_ahead.host,
+                    tag: remote_ahead.tag,
+                },
+                Operation::Upload {
+                    tail: local_ahead.id,
+                    host: local_ahead.host,
+                    tag: local_ahead.tag,
+                },
+            ]
+        );
+    }
+
+    #[tokio::test]
+    async fn build_complex_diff() {
+        // One shared, ahead but known only by remote
+        // One known only by local
+        // One known only by remote
+
+        let shared_record = test_record();
+
+        let remote_known = test_record();
+        let local_known = test_record();
+
+        let second_shared = test_record();
+        let second_shared_remote_ahead = second_shared
+            .new_child(vec![1, 2, 3])
+            .encrypt::<PASETO_V4>(&[0; 32]);
+
+        let local_ahead = shared_record
+            .new_child(vec![1, 2, 3])
+            .encrypt::<PASETO_V4>(&[0; 32]);
+
+        let local = vec![
+            shared_record.clone(),
+            second_shared.clone(),
+            local_known.clone(),
+            local_ahead.clone(),
+        ];
+
+        let remote = vec![
+            shared_record.clone(),
+            second_shared.clone(),
+            second_shared_remote_ahead.clone(),
+            remote_known.clone(),
+        ]; // remote knows about the already-synced, and one new record in a new store
+
+        let (store, diff) = build_test_diff(local, remote).await;
+        let operations = sync::operations(diff, &store).await.unwrap();
+
+        assert_eq!(operations.len(), 4);
+
+        let mut result_ops = vec![
+            Operation::Download {
+                tail: remote_known.id,
+                host: remote_known.host,
+                tag: remote_known.tag,
+            },
+            Operation::Download {
+                tail: second_shared_remote_ahead.id,
+                host: second_shared.host,
+                tag: second_shared.tag,
+            },
+            Operation::Upload {
+                tail: local_ahead.id,
+                host: local_ahead.host,
+                tag: local_ahead.tag,
+            },
+            Operation::Upload {
+                tail: local_known.id,
+                host: local_known.host,
+                tag: local_known.tag,
+            },
+        ];
+
+        result_ops.sort_by_key(|op| match op {
+            Operation::Upload { tail, host, .. } => ("upload", *host, *tail),
+            Operation::Download { tail, host, .. } => ("download", *host, *tail),
+        });
+
+        assert_eq!(operations, result_ops);
+    }
+}
diff --git a/atuin-client/src/settings.rs b/atuin-client/src/settings.rs
index dd072451..bb41a890 100644
--- a/atuin-client/src/settings.rs
+++ b/atuin-client/src/settings.rs
@@ -1,8 +1,10 @@
 use std::{
     io::prelude::*,
     path::{Path, PathBuf},
+    str::FromStr,
 };
 
+use atuin_common::record::HostId;
 use chrono::{prelude::*, Utc};
 use clap::ValueEnum;
 use config::{Config, Environment, File as ConfigFile, FileFormat};
@@ -12,6 +14,7 @@ use parse_duration::parse;
 use regex::RegexSet;
 use semver::Version;
 use serde::Deserialize;
+use uuid::Uuid;
 
 pub const HISTORY_PAGE_SIZE: i64 = 100;
 pub const LAST_SYNC_FILENAME: &str = "last_sync_time";
@@ -228,11 +231,13 @@ impl Settings {
         Settings::load_time_from_file(LAST_VERSION_CHECK_FILENAME)
     }
 
-    pub fn host_id() -> Option<String> {
+    pub fn host_id() -> Option<HostId> {
         let id = Settings::read_from_data_dir(HOST_ID_FILENAME);
 
-        if id.is_some() {
-            return id;
+        if let Some(id) = id {
+            let parsed =
+                Uuid::from_str(id.as_str()).expect("failed to parse host ID from local directory");
+            return Some(HostId(parsed));
         }
 
         let uuid = atuin_common::utils::uuid_v7();
@@ -240,7 +245,7 @@ impl Settings {
         Settings::save_to_data_dir(HOST_ID_FILENAME, uuid.as_simple().to_string().as_ref())
             .expect("Could not write host ID to data dir");
 
-        Some(uuid.as_simple().to_string())
+        Some(HostId(uuid))
     }
 
     pub fn should_sync(&self) -> Result<bool> {
diff --git a/atuin-common/Cargo.toml b/atuin-common/Cargo.toml
index ead3df84..a610584d 100644
--- a/atuin-common/Cargo.toml
+++ b/atuin-common/Cargo.toml
@@ -18,6 +18,7 @@ uuid = { workspace = true }
 rand = { workspace = true }
 typed-builder = { workspace = true }
 eyre = { workspace = true }
+sqlx = { workspace = true }
 
 [dev-dependencies]
-pretty_assertions = "1.3.0"
+pretty_assertions = { workspace = true }
diff --git a/atuin-common/src/lib.rs b/atuin-common/src/lib.rs
index b332e234..d4513ee0 100644
--- a/atuin-common/src/lib.rs
+++ b/atuin-common/src/lib.rs
@@ -1,5 +1,57 @@
 #![forbid(unsafe_code)]
 
+/// Defines a new UUID type wrapper
+macro_rules! new_uuid {
+    ($name:ident) => {
+        #[derive(
+            Debug,
+            Copy,
+            Clone,
+            PartialEq,
+            Eq,
+            Hash,
+            PartialOrd,
+            Ord,
+            serde::Serialize,
+            serde::Deserialize,
+        )]
+        #[serde(transparent)]
+        pub struct $name(pub Uuid);
+
+        impl<DB: sqlx::Database> sqlx::Type<DB> for $name
+        where
+            Uuid: sqlx::Type<DB>,
+        {
+            fn type_info() -> <DB as sqlx::Database>::TypeInfo {
+                Uuid::type_info()
+            }
+        }
+
+        impl<'r, DB: sqlx::Database> sqlx::Decode<'r, DB> for $name
+        where
+            Uuid: sqlx::Decode<'r, DB>,
+        {
+            fn decode(
+                value: <DB as sqlx::database::HasValueRef<'r>>::ValueRef,
+            ) -> std::result::Result<Self, sqlx::error::BoxDynError> {
+                Uuid::decode(value).map(Self)
+            }
+        }
+
+        impl<'q, DB: sqlx::Database> sqlx::Encode<'q, DB> for $name
+        where
+            Uuid: sqlx::Encode<'q, DB>,
+        {
+            fn encode_by_ref(
+                &self,
+                buf: &mut <DB as sqlx::database::HasArguments<'q>>::ArgumentBuffer,
+            ) -> sqlx::encode::IsNull {
+                self.0.encode_by_ref(buf)
+            }
+        }
+    };
+}
+
 pub mod api;
 pub mod record;
 pub mod utils;
diff --git a/atuin-common/src/record.rs b/atuin-common/src/record.rs
index b46647c3..b00c03c4 100644
--- a/atuin-common/src/record.rs
+++ b/atuin-common/src/record.rs
@@ -3,35 +3,43 @@ use std::collections::HashMap;
 use eyre::Result;
 use serde::{Deserialize, Serialize};
 use typed_builder::TypedBuilder;
+use uuid::Uuid;
 
 #[derive(Clone, Debug, PartialEq)]
 pub struct DecryptedData(pub Vec<u8>);
 
-#[derive(Debug, Clone, PartialEq)]
+#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
 pub struct EncryptedData {
     pub data: String,
     pub content_encryption_key: String,
 }
 
+#[derive(Debug, PartialEq)]
+pub struct Diff {
+    pub host: HostId,
+    pub tag: String,
+    pub tail: RecordId,
+}
+
 /// A single record stored inside of our local database
 #[derive(Debug, Clone, PartialEq, Serialize, Deserialize, TypedBuilder)]
 pub struct Record<Data> {
     /// a unique ID
-    #[builder(default = crate::utils::uuid_v7().as_simple().to_string())]
-    pub id: String,
+    #[builder(default = RecordId(crate::utils::uuid_v7()))]
+    pub id: RecordId,
 
     /// The unique ID of the host.
     // TODO(ellie): Optimize the storage here. We use a bunch of IDs, and currently store
     // as strings. I would rather avoid normalization, so store as UUID binary instead of
     // encoding to a string and wasting much more storage.
-    pub host: String,
+    pub host: HostId,
 
     /// The ID of the parent entry
     // A store is technically just a double linked list
     // We can do some cheating with the timestamps, but should not rely upon them.
     // Clocks are tricksy.
     #[builder(default)]
-    pub parent: Option<String>,
+    pub parent: Option<RecordId>,
 
     /// The creation time in nanoseconds since unix epoch
     #[builder(default = chrono::Utc::now().timestamp_nanos() as u64)]
@@ -48,21 +56,25 @@ pub struct Record<Data> {
     pub data: Data,
 }
 
+new_uuid!(RecordId);
+new_uuid!(HostId);
+
 /// Extra data from the record that should be encoded in the data
 #[derive(Debug, Copy, Clone)]
 pub struct AdditionalData<'a> {
-    pub id: &'a str,
+    pub id: &'a RecordId,
     pub version: &'a str,
     pub tag: &'a str,
-    pub host: &'a str,
+    pub host: &'a HostId,
+    pub parent: Option<&'a RecordId>,
 }
 
 impl<Data> Record<Data> {
     pub fn new_child(&self, data: Vec<u8>) -> Record<DecryptedData> {
         Record::builder()
-            .host(self.host.clone())
+            .host(self.host)
             .version(self.version.clone())
-            .parent(Some(self.id.clone()))
+            .parent(Some(self.id))
             .tag(self.tag.clone())
             .data(DecryptedData(data))
             .build()
@@ -71,9 +83,10 @@ impl<Data> Record<Data> {
 
 /// An index representing the current state of the record stores
 /// This can be both remote, or local, and compared in either direction
+#[derive(Debug, Serialize, Deserialize)]
 pub struct RecordIndex {
     // A map of host -> tag -> tail
-    pub hosts: HashMap<String, HashMap<String, String>>,
+    pub hosts: HashMap<HostId, HashMap<String, RecordId>>,
 }
 
 impl Default for RecordIndex {
@@ -82,6 +95,14 @@ impl Default for RecordIndex {
     }
 }
 
+impl Extend<(HostId, String, RecordId)> for RecordIndex {
+    fn extend<T: IntoIterator<Item = (HostId, String, RecordId)>>(&mut self, iter: T) {
+        for (host, tag, tail_id) in iter {
+            self.set_raw(host, tag, tail_id);
+        }
+    }
+}
+
 impl RecordIndex {
     pub fn new() -> RecordIndex {
         RecordIndex {
@@ -91,13 +112,14 @@ impl RecordIndex {
 
     /// Insert a new tail record into the store
     pub fn set(&mut self, tail: Record<DecryptedData>) {
-        self.hosts
-            .entry(tail.host)
-            .or_default()
-            .insert(tail.tag, tail.id);
+        self.set_raw(tail.host, tail.tag, tail.id)
     }
 
-    pub fn get(&self, host: String, tag: String) -> Option<String> {
+    pub fn set_raw(&mut self, host: HostId, tag: String, tail_id: RecordId) {
+        self.hosts.entry(host).or_default().insert(tag, tail_id);
+    }
+
+    pub fn get(&self, host: HostId, tag: String) -> Option<RecordId> {
         self.hosts.get(&host).and_then(|v| v.get(&tag)).cloned()
     }
 
@@ -108,21 +130,29 @@ impl RecordIndex {
     /// other machine has a different tail, it will be the differing tail. This is useful to
     /// check if the other index is ahead of us, or behind.
     /// If the other index does not have the (host, tag) pair, then the other value will be None.
-    pub fn diff(&self, other: &Self) -> Vec<(String, String, Option<String>)> {
+    pub fn diff(&self, other: &Self) -> Vec<Diff> {
         let mut ret = Vec::new();
 
         // First, we check if other has everything that self has
         for (host, tag_map) in self.hosts.iter() {
             for (tag, tail) in tag_map.iter() {
-                match other.get(host.clone(), tag.clone()) {
+                match other.get(*host, tag.clone()) {
                     // The other store is all up to date! No diff.
                     Some(t) if t.eq(tail) => continue,
 
                     // The other store does exist, but it is either ahead or behind us. A diff regardless
-                    Some(t) => ret.push((host.clone(), tag.clone(), Some(t))),
+                    Some(t) => ret.push(Diff {
+                        host: *host,
+                        tag: tag.clone(),
+                        tail: t,
+                    }),
 
                     // The other store does not exist :O
-                    None => ret.push((host.clone(), tag.clone(), None)),
+                    None => ret.push(Diff {
+                        host: *host,
+                        tag: tag.clone(),
+                        tail: *tail,
+                    }),
                 };
             }
         }
@@ -133,16 +163,20 @@ impl RecordIndex {
         // account for that!
         for (host, tag_map) in other.hosts.iter() {
             for (tag, tail) in tag_map.iter() {
-                match self.get(host.clone(), tag.clone()) {
+                match self.get(*host, tag.clone()) {
                     // If we have this host/tag combo, the comparison and diff will have already happened above
                     Some(_) => continue,
 
-                    None => ret.push((host.clone(), tag.clone(), Some(tail.clone()))),
+                    None => ret.push(Diff {
+                        host: *host,
+                        tag: tag.clone(),
+                        tail: *tail,
+                    }),
                };
             }
         }
 
-        ret.sort();
+        ret.sort_by(|a, b| (a.host, a.tag.clone(), a.tail).cmp(&(b.host, b.tag.clone(), b.tail)));
         ret
     }
 }
@@ -168,6 +202,7 @@ impl Record<DecryptedData> {
             version: &self.version,
             tag: &self.tag,
             host: &self.host,
+            parent: self.parent.as_ref(),
         };
         Record {
             data: E::encrypt(self.data, ad, key),
@@ -188,6 +223,7 @@ impl Record<EncryptedData> {
             version: &self.version,
             tag: &self.tag,
             host: &self.host,
+            parent: self.parent.as_ref(),
         };
         Ok(Record {
             data: E::decrypt(self.data, ad, key)?,
@@ -210,6 +246,7 @@ impl Record<EncryptedData> {
             version: &self.version,
             tag: &self.tag,
             host: &self.host,
+            parent: self.parent.as_ref(),
         };
         Ok(Record {
             data: E::re_encrypt(self.data, ad, old_key, new_key)?,
@@ -225,12 +262,14 @@ impl Record<EncryptedData> {
 
 #[cfg(test)]
 mod tests {
-    use super::{DecryptedData, Record, RecordIndex};
+    use crate::record::HostId;
+
+    use super::{DecryptedData, Diff, Record, RecordIndex};
     use pretty_assertions::assert_eq;
 
     fn test_record() -> Record<DecryptedData> {
         Record::builder()
-            .host(crate::utils::uuid_v7().simple().to_string())
+            .host(HostId(crate::utils::uuid_v7()))
             .version("v1".into())
             .tag(crate::utils::uuid_v7().simple().to_string())
             .data(DecryptedData(vec![0, 1, 2, 3]))
@@ -304,7 +343,14 @@ mod tests {
         let diff = index1.diff(&index2);
 
         assert_eq!(1, diff.len(), "expected single diff");
-        assert_eq!(diff[0], (record2.host, record2.tag, Some(record2.id)));
+        assert_eq!(
+            diff[0],
+            Diff {
+                host: record2.host,
+                tag: record2.tag,
+                tail: record2.id
+            }
+        );
     }
 
     #[test]
@@ -342,12 +388,14 @@ mod tests {
         assert_eq!(4, diff1.len());
         assert_eq!(4, diff2.len());
 
+        dbg!(&diff1, &diff2);
+
         // both diffs should be ALMOST the same. They will agree on which hosts and tags
         // require updating, but the "other" value will not be the same.
-        let smol_diff_1: Vec<(String, String)> =
-            diff1.iter().map(|v| (v.0.clone(), v.1.clone())).collect();
-        let smol_diff_2: Vec<(String, String)> =
-            diff1.iter().map(|v| (v.0.clone(), v.1.clone())).collect();
+        let smol_diff_1: Vec<(HostId, String)> =
+            diff1.iter().map(|v| (v.host, v.tag.clone())).collect();
+        let smol_diff_2: Vec<(HostId, String)> =
+            diff1.iter().map(|v| (v.host, v.tag.clone())).collect();
 
         assert_eq!(smol_diff_1, smol_diff_2);
diff --git a/atuin-server-database/src/lib.rs b/atuin-server-database/src/lib.rs
index de33ba44..cdff90a2 100644
--- a/atuin-server-database/src/lib.rs
+++ b/atuin-server-database/src/lib.rs
@@ -13,7 +13,10 @@ use self::{
     models::{History, NewHistory, NewSession, NewUser, Session, User},
 };
 use async_trait::async_trait;
-use atuin_common::utils::get_days_from_month;
+use atuin_common::{
+    record::{EncryptedData, HostId, Record, RecordId, RecordIndex},
+    utils::get_days_from_month,
+};
 use chrono::{Datelike, TimeZone};
 use chronoutil::RelativeDuration;
 use serde::{de::DeserializeOwned, Serialize};
@@ -55,6 +58,19 @@ pub trait Database: Sized + Clone + Send + Sync + 'static {
     async fn delete_history(&self, user: &User, id: String) -> DbResult<()>;
     async fn deleted_history(&self, user: &User) -> DbResult<Vec<String>>;
 
+    async fn add_records(&self, user: &User, record: &[Record<EncryptedData>]) -> DbResult<()>;
+    async fn next_records(
+        &self,
+        user: &User,
+        host: HostId,
+        tag: String,
+        start: Option<RecordId>,
+        count: u64,
+    ) -> DbResult<Vec<Record<EncryptedData>>>;
+
+    // Return the tail record ID for each store, so (HostID, Tag, TailRecordID)
+    async fn tail_records(&self, user: &User) -> DbResult<RecordIndex>;
+
     async fn count_history_range(
         &self,
         user: &User,
diff --git a/atuin-server-postgres/Cargo.toml b/atuin-server-postgres/Cargo.toml
index 18864f6c..bfec70a2 100644
--- a/atuin-server-postgres/Cargo.toml
+++ b/atuin-server-postgres/Cargo.toml
@@ -18,4 +18,5 @@ chrono = { workspace = true }
 serde = { workspace = true }
 sqlx = { workspace = true }
 async-trait = { workspace = true }
+uuid = { workspace = true }
 futures-util = "0.3"
diff --git a/atuin-server-postgres/build.rs b/atuin-server-postgres/build.rs
new file mode 100644
index 00000000..d5068697
--- /dev/null
+++ b/atuin-server-postgres/build.rs
@@ -0,0 +1,5 @@
+// generated by `sqlx migrate build-script`
+fn main() {
+    // trigger recompilation when a new migration is added
+    println!("cargo:rerun-if-changed=migrations");
+}
diff --git a/atuin-server-postgres/migrations/20230623070418_records.sql b/atuin-server-postgres/migrations/20230623070418_records.sql
new file mode 100644
index 00000000..22437595
--- /dev/null
+++ b/atuin-server-postgres/migrations/20230623070418_records.sql
@@ -0,0 +1,15 @@
+-- Add migration script here
+create table records (
+    id uuid primary key, -- remember to use uuidv7 for happy indices <3
+    client_id uuid not null, -- I am too uncomfortable with the idea of a client-generated primary key
+    host uuid not null, -- a unique identifier for the host
+    parent uuid default null, -- the ID of the parent record, bearing in mind this is a linked list
+    timestamp bigint not null, -- not a timestamp type, as those do not have nanosecond precision
+    version text not null,
+    tag text not null, -- what is this? history, kv, whatever. Remember clients get a log per tag per host
+    data text not null, -- store the actual history data, encrypted. I don't wanna know!
+    cek text not null,
+
+    user_id bigint not null, -- allow multiple users
+    created_at timestamp not null default current_timestamp
+);
diff --git a/atuin-server-postgres/src/lib.rs b/atuin-server-postgres/src/lib.rs
index 0dc51daf..404188b0 100644
--- a/atuin-server-postgres/src/lib.rs
+++ b/atuin-server-postgres/src/lib.rs
@@ -1,14 +1,14 @@
 use async_trait::async_trait;
+use atuin_common::record::{EncryptedData, HostId, Record, RecordId, RecordIndex};
 use atuin_server_database::models::{History, NewHistory, NewSession, NewUser, Session, User};
 use atuin_server_database::{Database, DbError, DbResult};
 use futures_util::TryStreamExt;
 use serde::{Deserialize, Serialize};
 use sqlx::postgres::PgPoolOptions;
-
 use sqlx::Row;
 
 use tracing::instrument;
-use wrappers::{DbHistory, DbSession, DbUser};
+use wrappers::{DbHistory, DbRecord, DbSession, DbUser};
 
 mod wrappers;
@@ -329,4 +329,102 @@ impl Database for Postgres {
         .map_err(fix_error)
         .map(|DbHistory(h)| h)
     }
+
+    #[instrument(skip_all)]
+    async fn add_records(&self, user: &User, records: &[Record<EncryptedData>]) -> DbResult<()> {
+        let mut tx = self.pool.begin().await.map_err(fix_error)?;
+
+        for i in records {
+            let id = atuin_common::utils::uuid_v7();
+
+            sqlx::query(
+                "insert into records
+                    (id, client_id, host, parent, timestamp, version, tag, data, cek, user_id)
+                values ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)
+                on conflict do nothing
+                ",
+            )
+            .bind(id)
+            .bind(i.id)
+            .bind(i.host)
+            .bind(i.parent)
+            .bind(i.timestamp as i64) // throwing away some data, but i64 is still big in terms of time
+            .bind(&i.version)
+            .bind(&i.tag)
+            .bind(&i.data.data)
+            .bind(&i.data.content_encryption_key)
+            .bind(user.id)
+            .execute(&mut tx)
+            .await
+            .map_err(fix_error)?;
+        }
+
+        tx.commit().await.map_err(fix_error)?;
+
+        Ok(())
+    }
+
+    #[instrument(skip_all)]
+    async fn next_records(
+        &self,
+        user: &User,
+        host: HostId,
+        tag: String,
+        start: Option<RecordId>,
+        count: u64,
+    ) -> DbResult<Vec<Record<EncryptedData>>> {
+        tracing::debug!("{:?} - {:?} - {:?}", host, tag, start);
+        let mut ret = Vec::with_capacity(count as usize);
+        let mut parent = start;
+
+        // yeah let's do something better
+        for _ in 0..count {
+            // a very much not ideal query. but it's simple at least?
+            // we are basically using postgres as a kv store here, so... maybe consider using an actual
+            // kv store?
+            let record: Result<DbRecord, DbError> = sqlx::query_as(
+                "select client_id, host, parent, timestamp, version, tag, data, cek from records
+                    where user_id = $1
+                    and tag = $2
+                    and host = $3
+                    and parent is not distinct from $4",
+            )
+            .bind(user.id)
+            .bind(tag.clone())
+            .bind(host)
+            .bind(parent)
+            .fetch_one(&self.pool)
+            .await
+            .map_err(fix_error);
+
+            match record {
+                Ok(record) => {
+                    let record: Record<EncryptedData> = record.into();
+                    ret.push(record.clone());
+
+                    parent = Some(record.id);
+                }
+                Err(DbError::NotFound) => {
+                    tracing::debug!("hit tail of store: {:?}/{}", host, tag);
+                    return Ok(ret);
+                }
+                Err(e) => return Err(e),
+            }
+        }
+
+        Ok(ret)
+    }
+
+    async fn tail_records(&self, user: &User) -> DbResult<RecordIndex> {
+        const TAIL_RECORDS_SQL: &str = "select host, tag, client_id from records rp where (select count(1) from records where parent=rp.client_id and user_id = $1) = 0;";
+
+        let res = sqlx::query_as(TAIL_RECORDS_SQL)
+            .bind(user.id)
+            .fetch(&self.pool)
+            .try_collect()
+            .await
+            .map_err(fix_error)?;
+
+        Ok(res)
+    }
 }
diff --git a/atuin-server-postgres/src/wrappers.rs
index cb3d5a96..8bd482b1 100644
--- a/atuin-server-postgres/src/wra