summaryrefslogtreecommitdiffstats
path: root/atuin-client/src/record/sqlite_store.rs
diff options
context:
space:
mode:
Diffstat (limited to 'atuin-client/src/record/sqlite_store.rs')
-rw-r--r--atuin-client/src/record/sqlite_store.rs99
1 files changed, 65 insertions, 34 deletions
diff --git a/atuin-client/src/record/sqlite_store.rs b/atuin-client/src/record/sqlite_store.rs
index f692c0c2..14a7e277 100644
--- a/atuin-client/src/record/sqlite_store.rs
+++ b/atuin-client/src/record/sqlite_store.rs
@@ -8,12 +8,14 @@ use std::str::FromStr;
use async_trait::async_trait;
use eyre::{eyre, Result};
use fs_err as fs;
+use futures::TryStreamExt;
use sqlx::{
sqlite::{SqliteConnectOptions, SqliteJournalMode, SqlitePool, SqlitePoolOptions, SqliteRow},
Row,
};
-use atuin_common::record::{EncryptedData, Record};
+use atuin_common::record::{EncryptedData, HostId, Record, RecordId, RecordIndex};
+use uuid::Uuid;
use super::store::Store;
@@ -62,11 +64,11 @@ impl SqliteStore {
"insert or ignore into records(id, host, tag, timestamp, parent, version, data, cek)
values(?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8)",
)
- .bind(r.id.as_str())
- .bind(r.host.as_str())
+ .bind(r.id.0.as_simple().to_string())
+ .bind(r.host.0.as_simple().to_string())
.bind(r.tag.as_str())
.bind(r.timestamp as i64)
- .bind(r.parent.as_ref())
+ .bind(r.parent.map(|p| p.0.as_simple().to_string()))
.bind(r.version.as_str())
.bind(r.data.data.as_str())
.bind(r.data.content_encryption_key.as_str())
@@ -79,10 +81,18 @@ impl SqliteStore {
fn query_row(row: SqliteRow) -> Record<EncryptedData> {
let timestamp: i64 = row.get("timestamp");
+ // tbh at this point things are pretty fucked so just panic
+ let id = Uuid::from_str(row.get("id")).expect("invalid id UUID format in sqlite DB");
+ let host = Uuid::from_str(row.get("host")).expect("invalid host UUID format in sqlite DB");
+ let parent: Option<&str> = row.get("parent");
+
+ let parent = parent
+ .map(|parent| Uuid::from_str(parent).expect("invalid parent UUID format in sqlite DB"));
+
Record {
- id: row.get("id"),
- host: row.get("host"),
- parent: row.get("parent"),
+ id: RecordId(id),
+ host: HostId(host),
+ parent: parent.map(RecordId),
timestamp: timestamp as u64,
tag: row.get("tag"),
version: row.get("version"),
@@ -111,9 +121,9 @@ impl Store for SqliteStore {
Ok(())
}
- async fn get(&self, id: &str) -> Result<Record<EncryptedData>> {
+ async fn get(&self, id: RecordId) -> Result<Record<EncryptedData>> {
let res = sqlx::query("select * from records where id = ?1")
- .bind(id)
+ .bind(id.0.as_simple().to_string())
.map(Self::query_row)
.fetch_one(&self.pool)
.await?;
@@ -121,10 +131,10 @@ impl Store for SqliteStore {
Ok(res)
}
- async fn len(&self, host: &str, tag: &str) -> Result<u64> {
+ async fn len(&self, host: HostId, tag: &str) -> Result<u64> {
let res: (i64,) =
sqlx::query_as("select count(1) from records where host = ?1 and tag = ?2")
- .bind(host)
+ .bind(host.0.as_simple().to_string())
.bind(tag)
.fetch_one(&self.pool)
.await?;
@@ -134,7 +144,7 @@ impl Store for SqliteStore {
async fn next(&self, record: &Record<EncryptedData>) -> Result<Option<Record<EncryptedData>>> {
let res = sqlx::query("select * from records where parent = ?1")
- .bind(record.id.clone())
+ .bind(record.id.0.as_simple().to_string())
.map(Self::query_row)
.fetch_one(&self.pool)
.await;
@@ -146,11 +156,11 @@ impl Store for SqliteStore {
}
}
- async fn first(&self, host: &str, tag: &str) -> Result<Option<Record<EncryptedData>>> {
+ async fn head(&self, host: HostId, tag: &str) -> Result<Option<Record<EncryptedData>>> {
let res = sqlx::query(
"select * from records where host = ?1 and tag = ?2 and parent is null limit 1",
)
- .bind(host)
+ .bind(host.0.as_simple().to_string())
.bind(tag)
.map(Self::query_row)
.fetch_optional(&self.pool)
@@ -159,23 +169,53 @@ impl Store for SqliteStore {
Ok(res)
}
- async fn last(&self, host: &str, tag: &str) -> Result<Option<Record<EncryptedData>>> {
+ async fn tail(&self, host: HostId, tag: &str) -> Result<Option<Record<EncryptedData>>> {
let res = sqlx::query(
"select * from records rp where tag=?1 and host=?2 and (select count(1) from records where parent=rp.id) = 0;",
)
.bind(tag)
- .bind(host)
+ .bind(host.0.as_simple().to_string())
.map(Self::query_row)
.fetch_optional(&self.pool)
.await?;
Ok(res)
}
+
+ async fn tag_tails(&self, tag: &str) -> Result<Vec<Record<EncryptedData>>> {
+ let res = sqlx::query(
+ "select * from records rp where tag=?1 and (select count(1) from records where parent=rp.id) = 0;",
+ )
+ .bind(tag)
+ .map(Self::query_row)
+ .fetch_all(&self.pool)
+ .await?;
+
+ Ok(res)
+ }
+
+ async fn tail_records(&self) -> Result<RecordIndex> {
+ let res = sqlx::query(
+ "select host, tag, id from records rp where (select count(1) from records where parent=rp.id) = 0;",
+ )
+ .map(|row: SqliteRow| {
+ let host: Uuid= Uuid::from_str(row.get("host")).expect("invalid uuid in db host");
+ let tag: String= row.get("tag");
+ let id: Uuid= Uuid::from_str(row.get("id")).expect("invalid uuid in db id");
+
+ (HostId(host), tag, RecordId(id))
+ })
+ .fetch(&self.pool)
+ .try_collect()
+ .await?;
+
+ Ok(res)
+ }
}
#[cfg(test)]
mod tests {
- use atuin_common::record::{EncryptedData, Record};
+ use atuin_common::record::{EncryptedData, HostId, Record};
use crate::record::{encryption::PASETO_V4, store::Store};
@@ -183,7 +223,7 @@ mod tests {
fn test_record() -> Record<EncryptedData> {
Record::builder()
- .host(atuin_common::utils::uuid_v7().simple().to_string())
+ .host(HostId(atuin_common::utils::uuid_v7()))
.version("v1".into())
.tag(atuin_common::utils::uuid_v7().simple().to_string())
.data(EncryptedData {
@@ -218,10 +258,7 @@ mod tests {
let record = test_record();
db.push(&record).await.unwrap();
- let new_record = db
- .get(record.id.as_str())
- .await
- .expect("failed to fetch record");
+ let new_record = db.get(record.id).await.expect("failed to fetch record");
assert_eq!(record, new_record, "records are not equal");
}
@@ -233,7 +270,7 @@ mod tests {
db.push(&record).await.unwrap();
let len = db
- .len(record.host.as_str(), record.tag.as_str())
+ .len(record.host, record.tag.as_str())
.await
.expect("failed to get store len");
@@ -253,14 +290,8 @@ mod tests {
db.push(&first).await.unwrap();
db.push(&second).await.unwrap();
- let first_len = db
- .len(first.host.as_str(), first.tag.as_str())
- .await
- .unwrap();
- let second_len = db
- .len(second.host.as_str(), second.tag.as_str())
- .await
- .unwrap();
+ let first_len = db.len(first.host, first.tag.as_str()).await.unwrap();
+ let second_len = db.len(second.host, second.tag.as_str()).await.unwrap();
assert_eq!(first_len, 1, "expected length of 1 after insert");
assert_eq!(second_len, 1, "expected length of 1 after insert");
@@ -281,7 +312,7 @@ mod tests {
}
assert_eq!(
- db.len(tail.host.as_str(), tail.tag.as_str()).await.unwrap(),
+ db.len(tail.host, tail.tag.as_str()).await.unwrap(),
100,
"failed to insert 100 records"
);
@@ -304,7 +335,7 @@ mod tests {
db.push_batch(records.iter()).await.unwrap();
assert_eq!(
- db.len(tail.host.as_str(), tail.tag.as_str()).await.unwrap(),
+ db.len(tail.host, tail.tag.as_str()).await.unwrap(),
10000,
"failed to insert 10k records"
);
@@ -327,7 +358,7 @@ mod tests {
db.push_batch(records.iter()).await.unwrap();
let mut record = db
- .first(tail.host.as_str(), tail.tag.as_str())
+ .head(tail.host, tail.tag.as_str())
.await
.expect("in memory sqlite should not fail")
.expect("entry exists");