diff options
Diffstat (limited to 'atuin-client/src/record/sqlite_store.rs')
-rw-r--r-- | atuin-client/src/record/sqlite_store.rs | 99 |
1 files changed, 65 insertions, 34 deletions
diff --git a/atuin-client/src/record/sqlite_store.rs b/atuin-client/src/record/sqlite_store.rs index f692c0c2..14a7e277 100644 --- a/atuin-client/src/record/sqlite_store.rs +++ b/atuin-client/src/record/sqlite_store.rs @@ -8,12 +8,14 @@ use std::str::FromStr; use async_trait::async_trait; use eyre::{eyre, Result}; use fs_err as fs; +use futures::TryStreamExt; use sqlx::{ sqlite::{SqliteConnectOptions, SqliteJournalMode, SqlitePool, SqlitePoolOptions, SqliteRow}, Row, }; -use atuin_common::record::{EncryptedData, Record}; +use atuin_common::record::{EncryptedData, HostId, Record, RecordId, RecordIndex}; +use uuid::Uuid; use super::store::Store; @@ -62,11 +64,11 @@ impl SqliteStore { "insert or ignore into records(id, host, tag, timestamp, parent, version, data, cek) values(?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8)", ) - .bind(r.id.as_str()) - .bind(r.host.as_str()) + .bind(r.id.0.as_simple().to_string()) + .bind(r.host.0.as_simple().to_string()) .bind(r.tag.as_str()) .bind(r.timestamp as i64) - .bind(r.parent.as_ref()) + .bind(r.parent.map(|p| p.0.as_simple().to_string())) .bind(r.version.as_str()) .bind(r.data.data.as_str()) .bind(r.data.content_encryption_key.as_str()) @@ -79,10 +81,18 @@ impl SqliteStore { fn query_row(row: SqliteRow) -> Record<EncryptedData> { let timestamp: i64 = row.get("timestamp"); + // tbh at this point things are pretty fucked so just panic + let id = Uuid::from_str(row.get("id")).expect("invalid id UUID format in sqlite DB"); + let host = Uuid::from_str(row.get("host")).expect("invalid host UUID format in sqlite DB"); + let parent: Option<&str> = row.get("parent"); + + let parent = parent + .map(|parent| Uuid::from_str(parent).expect("invalid parent UUID format in sqlite DB")); + Record { - id: row.get("id"), - host: row.get("host"), - parent: row.get("parent"), + id: RecordId(id), + host: HostId(host), + parent: parent.map(RecordId), timestamp: timestamp as u64, tag: row.get("tag"), version: row.get("version"), @@ -111,9 +121,9 @@ impl Store for SqliteStore { Ok(()) } - async fn get(&self, id: &str) -> Result<Record<EncryptedData>> { + async fn get(&self, id: RecordId) -> Result<Record<EncryptedData>> { let res = sqlx::query("select * from records where id = ?1") - .bind(id) + .bind(id.0.as_simple().to_string()) .map(Self::query_row) .fetch_one(&self.pool) .await?; @@ -121,10 +131,10 @@ impl Store for SqliteStore { Ok(res) } - async fn len(&self, host: &str, tag: &str) -> Result<u64> { + async fn len(&self, host: HostId, tag: &str) -> Result<u64> { let res: (i64,) = sqlx::query_as("select count(1) from records where host = ?1 and tag = ?2") - .bind(host) + .bind(host.0.as_simple().to_string()) .bind(tag) .fetch_one(&self.pool) .await?; @@ -134,7 +144,7 @@ impl Store for SqliteStore { async fn next(&self, record: &Record<EncryptedData>) -> Result<Option<Record<EncryptedData>>> { let res = sqlx::query("select * from records where parent = ?1") - .bind(record.id.clone()) + .bind(record.id.0.as_simple().to_string()) .map(Self::query_row) .fetch_one(&self.pool) .await; @@ -146,11 +156,11 @@ impl Store for SqliteStore { } } - async fn first(&self, host: &str, tag: &str) -> Result<Option<Record<EncryptedData>>> { + async fn head(&self, host: HostId, tag: &str) -> Result<Option<Record<EncryptedData>>> { let res = sqlx::query( "select * from records where host = ?1 and tag = ?2 and parent is null limit 1", ) - .bind(host) + .bind(host.0.as_simple().to_string()) .bind(tag) .map(Self::query_row) .fetch_optional(&self.pool) @@ -159,23 +169,53 @@ impl Store for SqliteStore { Ok(res) } - async fn last(&self, host: &str, tag: &str) -> Result<Option<Record<EncryptedData>>> { + async fn tail(&self, host: HostId, tag: &str) -> Result<Option<Record<EncryptedData>>> { let res = sqlx::query( "select * from records rp where tag=?1 and host=?2 and (select count(1) from records where parent=rp.id) = 0;", ) .bind(tag) - .bind(host) + .bind(host.0.as_simple().to_string()) .map(Self::query_row) .fetch_optional(&self.pool) .await?; Ok(res) } + + async fn tag_tails(&self, tag: &str) -> Result<Vec<Record<EncryptedData>>> { + let res = sqlx::query( + "select * from records rp where tag=?1 and (select count(1) from records where parent=rp.id) = 0;", + ) + .bind(tag) + .map(Self::query_row) + .fetch_all(&self.pool) + .await?; + + Ok(res) + } + + async fn tail_records(&self) -> Result<RecordIndex> { + let res = sqlx::query( + "select host, tag, id from records rp where (select count(1) from records where parent=rp.id) = 0;", + ) + .map(|row: SqliteRow| { + let host: Uuid= Uuid::from_str(row.get("host")).expect("invalid uuid in db host"); + let tag: String= row.get("tag"); + let id: Uuid= Uuid::from_str(row.get("id")).expect("invalid uuid in db id"); + + (HostId(host), tag, RecordId(id)) + }) + .fetch(&self.pool) + .try_collect() + .await?; + + Ok(res) + } } #[cfg(test)] mod tests { - use atuin_common::record::{EncryptedData, Record}; + use atuin_common::record::{EncryptedData, HostId, Record}; use crate::record::{encryption::PASETO_V4, store::Store}; @@ -183,7 +223,7 @@ mod tests { fn test_record() -> Record<EncryptedData> { Record::builder() - .host(atuin_common::utils::uuid_v7().simple().to_string()) + .host(HostId(atuin_common::utils::uuid_v7())) .version("v1".into()) .tag(atuin_common::utils::uuid_v7().simple().to_string()) .data(EncryptedData { @@ -218,10 +258,7 @@ mod tests { let record = test_record(); db.push(&record).await.unwrap(); - let new_record = db - .get(record.id.as_str()) - .await - .expect("failed to fetch record"); + let new_record = db.get(record.id).await.expect("failed to fetch record"); assert_eq!(record, new_record, "records are not equal"); } @@ -233,7 +270,7 @@ mod tests { db.push(&record).await.unwrap(); let len = db - .len(record.host.as_str(), record.tag.as_str()) + .len(record.host, record.tag.as_str()) .await .expect("failed to get store len"); @@ -253,14 +290,8 @@ mod tests { db.push(&first).await.unwrap(); db.push(&second).await.unwrap(); - let first_len = db - .len(first.host.as_str(), first.tag.as_str()) - .await - .unwrap(); - let second_len = db - .len(second.host.as_str(), second.tag.as_str()) - .await - .unwrap(); + let first_len = db.len(first.host, first.tag.as_str()).await.unwrap(); + let second_len = db.len(second.host, second.tag.as_str()).await.unwrap(); assert_eq!(first_len, 1, "expected length of 1 after insert"); assert_eq!(second_len, 1, "expected length of 1 after insert"); @@ -281,7 +312,7 @@ mod tests { } assert_eq!( - db.len(tail.host.as_str(), tail.tag.as_str()).await.unwrap(), + db.len(tail.host, tail.tag.as_str()).await.unwrap(), 100, "failed to insert 100 records" ); @@ -304,7 +335,7 @@ mod tests { db.push_batch(records.iter()).await.unwrap(); assert_eq!( - db.len(tail.host.as_str(), tail.tag.as_str()).await.unwrap(), + db.len(tail.host, tail.tag.as_str()).await.unwrap(), 10000, "failed to insert 10k records" ); @@ -327,7 +358,7 @@ mod tests { db.push_batch(records.iter()).await.unwrap(); let mut record = db - .first(tail.host.as_str(), tail.tag.as_str()) + .head(tail.host, tail.tag.as_str()) .await .expect("in memory sqlite should not fail") .expect("entry exists"); |