diff options
author | Ellie Huxtable <ellie@elliehuxtable.com> | 2023-07-14 20:44:08 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-07-14 20:44:08 +0100 |
commit | 97e24d0d41bb743833e457de5ba49c5c233eb3b3 (patch) | |
tree | f0cfefd9048df83d3029cb0b0d21f1f88813fe2e | |
parent | 3d4302ded148c13b302fb317240342a303308c7e (diff) |
Add new sync (#1093)
* Add record migration
* Add database functions for inserting history
No real tests yet :( I would like to avoid running postgres lol
* Add index handler, use UUIDs not strings
* Fix a bunch of tests, remove Option<Uuid>
* Add tests, all passing
* Working upload sync
* Record downloading works
* Sync download works
* Don't waste requests
* Use a page size for uploads, make it variable later
* Aaaaaand they're encrypted now too
* Add cek (content encryption key)
* Allow reading tail across hosts
* Revert "Allow reading tail across hosts"
Not like that
This reverts commit 7b0c72e7e050c358172f9b53cbd21b9e44cf4931.
* Handle multiple shards properly
* format
* Format and make clippy happy
* use some fancy types (#1098)
* use some fancy types
* fmt
* Goodbye horrible tuple
* Update atuin-server-postgres/migrations/20230623070418_records.sql
Co-authored-by: Conrad Ludgate <conradludgate@gmail.com>
* fmt
* Sort tests too because time sucks
* fix features
---------
Co-authored-by: Conrad Ludgate <conradludgate@gmail.com>
29 files changed, 1094 insertions, 143 deletions
@@ -142,6 +142,7 @@ dependencies = [ "directories", "eyre", "fs-err", + "futures", "generic-array", "hex", "interim", @@ -151,6 +152,7 @@ dependencies = [ "memchr", "minspan", "parse_duration", + "pretty_assertions", "rand 0.8.5", "regex", "reqwest", @@ -182,6 +184,7 @@ dependencies = [ "pretty_assertions", "rand 0.8.5", "serde", + "sqlx", "typed-builder", "uuid", ] @@ -240,6 +243,7 @@ dependencies = [ "serde", "sqlx", "tracing", + "uuid", ] [[package]] @@ -951,6 +955,21 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0845fa252299212f0389d64ba26f34fa32cfe41588355f21ed507c59a0f64541" [[package]] +name = "futures" +version = "0.3.24" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7f21eda599937fba36daeb58a22e8f5cee2d14c4a17b5b7739c7c8e5e3b8230c" +dependencies = [ + "futures-channel", + "futures-core", + "futures-executor", + "futures-io", + "futures-sink", + "futures-task", + "futures-util", +] + +[[package]] name = "futures-channel" version = "0.3.24" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -989,6 +1008,12 @@ dependencies = [ ] [[package]] +name = "futures-io" +version = "0.3.28" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4fff74096e71ed47f8e023204cfd0aa1289cd54ae5430a9523be060cdb849964" + +[[package]] name = "futures-macro" version = "0.3.24" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -1017,10 +1042,13 @@ version = "0.3.24" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "44fb6cb1be61cc1d2e43b262516aafcf63b241cffdb1d3fa115f91d9c7b09c90" dependencies = [ + "futures-channel", "futures-core", + "futures-io", "futures-macro", "futures-sink", "futures-task", + "memchr", "pin-project-lite", "pin-utils", "slab", @@ -2567,6 +2595,7 @@ dependencies = [ "thiserror", "tokio-stream", "url", + "uuid", "webpki-roots", "whoami", ] @@ -3037,6 +3066,7 @@ source = 
"registry+https://github.com/rust-lang/crates.io-index" checksum = "0fa2982af2eec27de306107c027578ff7f423d65f7250e40ce0fea8f45248b81" dependencies = [ "getrandom 0.2.7", + "serde", ] [[package]] @@ -1,11 +1,11 @@ [workspace] members = [ - "atuin", - "atuin-client", - "atuin-server", - "atuin-server-postgres", - "atuin-server-database", - "atuin-common", + "atuin", + "atuin-client", + "atuin-server", + "atuin-server-postgres", + "atuin-server-database", + "atuin-common", ] [workspace.package] @@ -35,9 +35,10 @@ semver = "1.0.14" serde = { version = "1.0.145", features = ["derive"] } serde_json = "1.0.86" tokio = { version = "1", features = ["full"] } -uuid = { version = "1.3", features = ["v4"] } +uuid = { version = "1.3", features = ["v4", "serde"] } whoami = "1.1.2" typed-builder = "0.14.0" +pretty_assertions = "1.3.0" [workspace.dependencies.reqwest] version = "0.11" @@ -46,4 +47,4 @@ default-features = false [workspace.dependencies.sqlx] version = "0.6" -features = ["runtime-tokio-rustls", "chrono", "postgres"] +features = ["runtime-tokio-rustls", "chrono", "postgres", "uuid"] diff --git a/atuin-client/Cargo.toml b/atuin-client/Cargo.toml index fa539662..ca620677 100644 --- a/atuin-client/Cargo.toml +++ b/atuin-client/Cargo.toml @@ -54,10 +54,14 @@ rmp = { version = "0.8.11" } typed-builder = "0.14.0" tokio = { workspace = true } semver = { workspace = true } +futures = "0.3" # encryption rusty_paseto = { version = "0.5.0", default-features = false } -rusty_paserk = { version = "0.2.0", default-features = false, features = ["v4", "serde"] } +rusty_paserk = { version = "0.2.0", default-features = false, features = [ + "v4", + "serde", +] } # sync urlencoding = { version = "2.1.0", optional = true } @@ -69,3 +73,4 @@ generic-array = { version = "0.14", optional = true, features = ["serde"] } [dev-dependencies] tokio = { version = "1", features = ["full"] } +pretty_assertions = { workspace = true } diff --git 
a/atuin-client/record-migrations/20230531212437_create-records.sql b/atuin-client/record-migrations/20230531212437_create-records.sql index 46963358..4f4b304a 100644 --- a/atuin-client/record-migrations/20230531212437_create-records.sql +++ b/atuin-client/record-migrations/20230531212437_create-records.sql @@ -7,7 +7,8 @@ create table if not exists records ( timestamp integer not null, tag text not null, version text not null, - data blob not null + data blob not null, + cek blob not null ); create index host_idx on records (host); diff --git a/atuin-client/record-migrations/20230619235421_add_content_encrytion_key.sql b/atuin-client/record-migrations/20230619235421_add_content_encrytion_key.sql deleted file mode 100644 index 86bf6844..00000000 --- a/atuin-client/record-migrations/20230619235421_add_content_encrytion_key.sql +++ /dev/null @@ -1,3 +0,0 @@ --- store content encryption keys in the record -alter table records - add column cek text; diff --git a/atuin-client/src/api_client.rs b/atuin-client/src/api_client.rs index 350c419d..5ae1ed0a 100644 --- a/atuin-client/src/api_client.rs +++ b/atuin-client/src/api_client.rs @@ -8,9 +8,13 @@ use reqwest::{ StatusCode, Url, }; -use atuin_common::api::{ - AddHistoryRequest, CountResponse, DeleteHistoryRequest, ErrorResponse, IndexResponse, - LoginRequest, LoginResponse, RegisterResponse, StatusResponse, SyncHistoryResponse, +use atuin_common::record::{EncryptedData, HostId, Record, RecordId}; +use atuin_common::{ + api::{ + AddHistoryRequest, CountResponse, DeleteHistoryRequest, ErrorResponse, IndexResponse, + LoginRequest, LoginResponse, RegisterResponse, StatusResponse, SyncHistoryResponse, + }, + record::RecordIndex, }; use semver::Version; @@ -195,6 +199,55 @@ impl<'a> Client<'a> { Ok(()) } + pub async fn post_records(&self, records: &[Record<EncryptedData>]) -> Result<()> { + let url = format!("{}/record", self.sync_addr); + let url = Url::parse(url.as_str())?; + + 
self.client.post(url).json(records).send().await?; + + Ok(()) + } + + pub async fn next_records( + &self, + host: HostId, + tag: String, + start: Option<RecordId>, + count: u64, + ) -> Result<Vec<Record<EncryptedData>>> { + let url = format!( + "{}/record/next?host={}&tag={}&count={}", + self.sync_addr, host.0, tag, count + ); + let mut url = Url::parse(url.as_str())?; + + if let Some(start) = start { + url.set_query(Some( + format!( + "host={}&tag={}&count={}&start={}", + host.0, tag, count, start.0 + ) + .as_str(), + )); + } + + let resp = self.client.get(url).send().await?; + + let records = resp.json::<Vec<Record<EncryptedData>>>().await?; + + Ok(records) + } + + pub async fn record_index(&self) -> Result<RecordIndex> { + let url = format!("{}/record", self.sync_addr); + let url = Url::parse(url.as_str())?; + + let resp = self.client.get(url).send().await?; + let index = resp.json().await?; + + Ok(index) + } + pub async fn delete(&self) -> Result<()> { let url = format!("{}/account", self.sync_addr); let url = Url::parse(url.as_str())?; diff --git a/atuin-client/src/database.rs b/atuin-client/src/database.rs index b7b44409..218c1d6e 100644 --- a/atuin-client/src/database.rs +++ b/atuin-client/src/database.rs @@ -57,7 +57,7 @@ pub fn current_context() -> Context { session, hostname, cwd, - host_id, + host_id: host_id.0.as_simple().to_string(), } } diff --git a/atuin-client/src/kv.rs b/atuin-client/src/kv.rs index c365a385..30018d63 100644 --- a/atuin-client/src/kv.rs +++ b/atuin-client/src/kv.rs @@ -101,10 +101,7 @@ impl KvStore { let bytes = record.serialize()?; - let parent = store - .last(host_id.as_str(), KV_TAG) - .await? 
- .map(|entry| entry.id); + let parent = store.tail(host_id, KV_TAG).await?.map(|entry| entry.id); let record = atuin_common::record::Record::builder() .host(host_id) @@ -130,17 +127,22 @@ impl KvStore { namespace: &str, key: &str, ) -> Result<Option<KvRecord>> { - // TODO: don't load this from disk so much - let host_id = Settings::host_id().expect("failed to get host_id"); - // Currently, this is O(n). When we have an actual KV store, it can be better // Just a poc for now! // iterate records to find the value we want // start at the end, so we get the most recent version - let Some(mut record) = store.last(host_id.as_str(), KV_TAG).await? else { + let tails = store.tag_tails(KV_TAG).await?; + + if tails.is_empty() { return Ok(None); - }; + } + + // first, decide on a record. + // try getting the newest first + // we always need a way of deciding the "winner" of a write + // TODO(ellie): something better than last-write-wins, what if two write at the same time? + let mut record = tails.iter().max_by_key(|r| r.timestamp).unwrap().clone(); loop { let decrypted = match record.version.as_str() { @@ -154,7 +156,7 @@ impl KvStore { } if let Some(parent) = decrypted.parent { - record = store.get(parent.as_str()).await?; + record = store.get(parent).await?; } else { break; } diff --git a/atuin-client/src/record/encryption.rs b/atuin-client/src/record/encryption.rs index f14bf027..6760d97b 100644 --- a/atuin-client/src/record/encryption.rs +++ b/atuin-client/src/record/encryption.rs @@ -1,4 +1,6 @@ -use atuin_common::record::{AdditionalData, DecryptedData, EncryptedData, Encryption}; +use atuin_common::record::{ + AdditionalData, DecryptedData, EncryptedData, Encryption, HostId, RecordId, +}; use base64::{engine::general_purpose, Engine}; use eyre::{ensure, Context, Result}; use rusty_paserk::{Key, KeyId, Local, PieWrappedKey}; @@ -158,10 +160,11 @@ struct AtuinFooter { // This cannot be changed, otherwise it breaks the authenticated encryption. 
#[derive(Debug, Copy, Clone, Serialize)] struct Assertions<'a> { - id: &'a str, + id: &'a RecordId, version: &'a str, tag: &'a str, - host: &'a str, + host: &'a HostId, + parent: Option<&'a RecordId>, } impl<'a> From<AdditionalData<'a>> for Assertions<'a> { @@ -171,6 +174,7 @@ impl<'a> From<AdditionalData<'a>> for Assertions<'a> { version: ad.version, tag: ad.tag, host: ad.host, + parent: ad.parent, } } } @@ -183,7 +187,7 @@ impl Assertions<'_> { #[cfg(test)] mod tests { - use atuin_common::record::Record; + use atuin_common::{record::Record, utils::uuid_v7}; use super::*; @@ -192,10 +196,11 @@ mod tests { let key = Key::<V4, Local>::new_os_random(); let ad = AdditionalData { - id: "foo", + id: &RecordId(uuid_v7()), version: "v0", tag: "kv", - host: "1234", + host: &HostId(uuid_v7()), + parent: None, }; let data = DecryptedData(vec![1, 2, 3, 4]); @@ -210,10 +215,11 @@ mod tests { let key = Key::<V4, Local>::new_os_random(); let ad = AdditionalData { - id: "foo", + id: &RecordId(uuid_v7()), version: "v0", tag: "kv", - host: "1234", + host: &HostId(uuid_v7()), + parent: None, }; let data = DecryptedData(vec![1, 2, 3, 4]); @@ -233,10 +239,11 @@ mod tests { let fake_key = Key::<V4, Local>::new_os_random(); let ad = AdditionalData { - id: "foo", + id: &RecordId(uuid_v7()), version: "v0", tag: "kv", - host: "1234", + host: &HostId(uuid_v7()), + parent: None, }; let data = DecryptedData(vec![1, 2, 3, 4]); @@ -250,10 +257,11 @@ mod tests { let key = Key::<V4, Local>::new_os_random(); let ad = AdditionalData { - id: "foo", + id: &RecordId(uuid_v7()), version: "v0", tag: "kv", - host: "1234", + host: &HostId(uuid_v7()), + parent: None, }; let data = DecryptedData(vec![1, 2, 3, 4]); @@ -261,10 +269,8 @@ mod tests { let encrypted = PASETO_V4::encrypt(data, ad, &key.to_bytes()); let ad = AdditionalData { - id: "foo1", - version: "v0", - tag: "kv", - host: "1234", + id: &RecordId(uuid_v7()), + ..ad }; let _ = PASETO_V4::decrypt(encrypted, ad, &key.to_bytes()).unwrap_err(); } @@ 
-275,10 +281,11 @@ mod tests { let key2 = Key::<V4, Local>::new_os_random(); let ad = AdditionalData { - id: "foo", + id: &RecordId(uuid_v7()), version: "v0", tag: "kv", - host: "1234", + host: &HostId(uuid_v7()), + parent: None, }; let data = DecryptedData(vec![1, 2, 3, 4]); @@ -304,10 +311,10 @@ mod tests { fn full_record_round_trip() { let key = [0x55; 32]; let record = Record::builder() - .id("1".to_owned()) + .id(RecordId(uuid_v7())) .version("v0".to_owned()) .tag("kv".to_owned()) - .host("host1".to_owned()) + .host(HostId(uuid_v7())) .timestamp(1687244806000000) .data(DecryptedData(vec![1, 2, 3, 4])) .build(); @@ -316,30 +323,20 @@ mod tests { assert!(!encrypted.data.data.is_empty()); assert!(!encrypted.data.content_encryption_key.is_empty()); - assert_eq!(encrypted.id, "1"); - assert_eq!(encrypted.host, "host1"); - assert_eq!(encrypted.version, "v0"); - assert_eq!(encrypted.tag, "kv"); - assert_eq!(encrypted.timestamp, 1687244806000000); let decrypted = encrypted.decrypt::<PASETO_V4>(&key).unwrap(); assert_eq!(decrypted.data.0, [1, 2, 3, 4]); - assert_eq!(decrypted.id, "1"); - assert_eq!(decrypted.host, "host1"); - assert_eq!(decrypted.version, "v0"); - assert_eq!(decrypted.tag, "kv"); - assert_eq!(decrypted.timestamp, 1687244806000000); } #[test] fn full_record_round_trip_fail() { let key = [0x55; 32]; let record = Record::builder() - .id("1".to_owned()) + .id(RecordId(uuid_v7())) .version("v0".to_owned()) .tag("kv".to_owned()) - .host("host1".to_owned()) + .host(HostId(uuid_v7())) .timestamp(1687244806000000) .data(DecryptedData(vec![1, 2, 3, 4])) .build(); @@ -347,13 +344,13 @@ mod tests { let encrypted = record.encrypt::<PASETO_V4>(&key); let mut enc1 = encrypted.clone(); - enc1.host = "host2".to_owned(); + enc1.host = HostId(uuid_v7()); let _ = enc1 .decrypt::<PASETO_V4>(&key) .expect_err("tampering with the host should result in auth failure"); let mut enc2 = encrypted; - enc2.id = "2".to_owned(); + enc2.id = RecordId(uuid_v7()); let _ = enc2 
.decrypt::<PASETO_V4>(&key) .expect_err("tampering with the id should result in auth failure"); diff --git a/atuin-client/src/record/mod.rs b/atuin-client/src/record/mod.rs index 9ac2c541..8bc816ae 100644 --- a/atuin-client/src/record/mod.rs +++ b/atuin-client/src/record/mod.rs @@ -1,3 +1,5 @@ pub mod encryption; pub mod sqlite_store; pub mod store; +#[cfg(feature = "sync")] +pub mod sync; diff --git a/atuin-client/src/record/sqlite_store.rs b/atuin-client/src/record/sqlite_store.rs index f692c0c2..14a7e277 100644 --- a/atuin-client/src/record/sqlite_store.rs +++ b/atuin-client/src/record/sqlite_store.rs @@ -8,12 +8,14 @@ use std::str::FromStr; use async_trait::async_trait; use eyre::{eyre, Result}; use fs_err as fs; +use futures::TryStreamExt; use sqlx::{ sqlite::{SqliteConnectOptions, SqliteJournalMode, SqlitePool, SqlitePoolOptions, SqliteRow}, Row, }; -use atuin_common::record::{EncryptedData, Record}; +use atuin_common::record::{EncryptedData, HostId, Record, RecordId, RecordIndex}; +use uuid::Uuid; use super::store::Store; @@ -62,11 +64,11 @@ impl SqliteStore { "insert or ignore into records(id, host, tag, timestamp, parent, version, data, cek) values(?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8)", ) - .bind(r.id.as_str()) - .bind(r.host.as_str()) + .bind(r.id.0.as_simple().to_string()) + .bind(r.host.0.as_simple().to_string()) .bind(r.tag.as_str()) .bind(r.timestamp as i64) - .bind(r.parent.as_ref()) + .bind(r.parent.map(|p| p.0.as_simple().to_string())) .bind(r.version.as_str()) .bind(r.data.data.as_str()) .bind(r.data.content_encryption_key.as_str()) @@ -79,10 +81,18 @@ impl SqliteStore { fn query_row(row: SqliteRow) -> Record<EncryptedData> { let timestamp: i64 = row.get("timestamp"); + // tbh at this point things are pretty fucked so just panic + let id = Uuid::from_str(row.get("id")).expect("invalid id UUID format in sqlite DB"); + let host = Uuid::from_str(row.get("host")).expect("invalid host UUID format in sqlite DB"); + let parent: Option<&str> = 
row.get("parent"); + + let parent = parent + .map(|parent| Uuid::from_str(parent).expect("invalid parent UUID format in sqlite DB")); + Record { - id: row.get("id"), - host: row.get("host"), - parent: row.get("parent"), + id: RecordId(id), + host: HostId(host), + parent: parent.map(RecordId), timestamp: timestamp as u64, tag: row.get("tag"), version: row.get("version"), @@ -111,9 +121,9 @@ impl Store for SqliteStore { Ok(()) } - async fn get(&self, id: &str) -> Result<Record<EncryptedData>> { + async fn get(&self, id: RecordId) -> Result<Record<EncryptedData>> { let res = sqlx::query("select * from records where id = ?1") - .bind(id) + .bind(id.0.as_simple().to_string()) .map(Self::query_row) .fetch_one(&self.pool) .await?; @@ -121,10 +131,10 @@ impl Store for SqliteStore { Ok(res) } - async fn len(&self, host: &str, tag: &str) -> Result<u64> { + async fn len(&self, host: HostId, tag: &str) -> Result<u64> { let res: (i64,) = sqlx::query_as("select count(1) from records where host = ?1 and tag = ?2") - .bind(host) + .bind(host.0.as_simple().to_string()) .bind(tag) .fetch_one(&self.pool) .await?; @@ -134,7 +144,7 @@ impl Store for SqliteStore { async fn next(&self, record: &Record<EncryptedData>) -> Result<Option<Record<EncryptedData>>> { let res = sqlx::query("select * from records where parent = ?1") - .bind(record.id.clone()) + .bind(record.id.0.as_simple().to_string()) .map(Self::query_row) .fetch_one(&self.pool) .await; @@ -146,11 +156,11 @@ impl Store for SqliteStore { } } - async fn first(&self, host: &str, tag: &str) -> Result<Option<Record<EncryptedData>>> { + async fn head(&self, host: HostId, tag: &str) -> Result<Option<Record<EncryptedData>>> { let res = sqlx::query( "select * from records where host = ?1 and tag = ?2 and parent is null limit 1", ) - .bind(host) + .bind(host.0.as_simple().to_string()) .bind(tag) .map(Self::query_row) .fetch_optional(&self.pool) @@ -159,23 +169,53 @@ impl Store for SqliteStore { Ok(res) } - async fn last(&self, host: &str, 
tag: &str) -> Result<Option<Record<EncryptedData>>> { + async fn tail(&self, host: HostId, tag: &str) -> Result<Option<Record<EncryptedData>>> { let res = sqlx::query( "select * from records rp where tag=?1 and host=?2 and (select count(1) from records where parent=rp.id) = 0;", ) .bind(tag) - .bind(host) + .bind(host.0.as_simple().to_string()) .map(Self::query_row) .fetch_optional(&self.pool) .await?; Ok(res) } + + async fn tag_tails(&self, tag: &str) -> Result<Vec<Record<EncryptedData>>> { + let res = sqlx::query( + "select * from records rp where tag=?1 and (select count(1) from records where parent=rp.id) = 0;", + ) + .bind(tag) + .map(Self::query_row) + .fetch_all(&self.pool) + .await?; + + Ok(res) + } + + async fn tail_records(&self) -> Result<RecordIndex> { + let res = sqlx::query( + "select host, tag, id from records rp where (select count(1) from records where parent=rp.id) = 0;", + ) + .map(|row: SqliteRow| { + let host: Uuid= Uuid::from_str(row.get("host")).expect("invalid uuid in db host"); + let tag: String= row.get("tag"); + let id: Uuid= Uuid::from_str(row.get("id")).expect("invalid uuid in db id"); + + (HostId(host), tag, RecordId(id)) + }) + .fetch(&self.pool) + .try_collect() + .await?; + + Ok(res) + } } #[cfg(test)] mod tests { - use atuin_common::record::{EncryptedData, Record}; + use atuin_common::record::{EncryptedData, HostId, Record}; use crate::record::{encryption::PASETO_V4, store::Store}; @@ -183,7 +223,7 @@ mod tests { fn test_record() -> Record<EncryptedData> { Record::builder() - .host(atuin_common::utils::uuid_v7().simple().to_string()) + .host(HostId(atuin_common::utils::uuid_v7())) .version("v1".into()) .tag(atuin_common::utils::uuid_v7().simple().to_string()) .data(EncryptedData { @@ -218,10 +258,7 @@ mod tests { let record = test_record(); db.push(&record).await.unwrap(); - let new_record = db - .get(record.id.as_str()) - .await - .expect("failed to fetch record"); + let new_record = db.get(record.id).await.expect("failed to fetch 
record"); |