summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorEllie Huxtable <ellie@elliehuxtable.com>2024-01-29 16:38:24 +0000
committerGitHub <noreply@github.com>2024-01-29 16:38:24 +0000
commit366b8ea97bbe36ad5e3dd8d45f1e787ee2a7f223 (patch)
treed889d76c73176def805c45fd10f72d4cd3a5a93a
parent15bad15f48fa714a7d99ac782cec9b5d11cf336b (diff)
feat: automatically init history store when record sync is enabled (#1634)
* add support for getting the total length of a store * tidy up sync * auto call init if history is ahead * fix import order, key regen * fix import order, key regen * do not delete key when user deletes account * message output * remote init store command; this is now automatic * should probs make that function return u64 at some point
-rw-r--r--atuin-client/src/history/store.rs30
-rw-r--r--atuin-client/src/record/sqlite_store.rs32
-rw-r--r--atuin-client/src/record/store.rs2
-rw-r--r--atuin-client/src/record/sync.rs49
-rw-r--r--atuin/src/command/client/account/delete.rs5
-rw-r--r--atuin/src/command/client/account/register.rs3
-rw-r--r--atuin/src/command/client/history.rs43
-rw-r--r--atuin/src/command/client/sync.rs24
8 files changed, 120 insertions, 68 deletions
diff --git a/atuin-client/src/history/store.rs b/atuin-client/src/history/store.rs
index 442da45d..0a2a2312 100644
--- a/atuin-client/src/history/store.rs
+++ b/atuin-client/src/history/store.rs
@@ -4,7 +4,7 @@ use eyre::{bail, eyre, Result};
use rmp::decode::Bytes;
use crate::{
- database::Database,
+ database::{self, Database},
record::{encryption::PASETO_V4, sqlite_store::SqliteStore, store::Store},
};
use atuin_common::record::{DecryptedData, Host, HostId, Record, RecordId, RecordIdx};
@@ -255,6 +255,34 @@ impl HistoryStore {
Ok(ret)
}
+
+ pub async fn init_store(&self, context: database::Context, db: &impl Database) -> Result<()> {
+ println!("Importing all history.db data into records.db");
+
+ println!("Fetching history from old database");
+ let history = db.list(&[], &context, None, false, true).await?;
+
+ println!("Fetching history already in store");
+ let store_ids = self.history_ids().await?;
+
+ for i in history {
+ println!("loaded {}", i.id);
+
+ if store_ids.contains(&i.id) {
+ println!("skipping {} - already exists", i.id);
+ continue;
+ }
+
+ if i.deleted_at.is_some() {
+ self.push(i.clone()).await?;
+ self.delete(i.id).await?;
+ } else {
+ self.push(i).await?;
+ }
+ }
+
+ Ok(())
+ }
}
#[cfg(test)]
diff --git a/atuin-client/src/record/sqlite_store.rs b/atuin-client/src/record/sqlite_store.rs
index 50f30d76..e9d7ff59 100644
--- a/atuin-client/src/record/sqlite_store.rs
+++ b/atuin-client/src/record/sqlite_store.rs
@@ -155,6 +155,18 @@ impl Store for SqliteStore {
self.idx(host, tag, 0).await
}
+ async fn len_tag(&self, tag: &str) -> Result<u64> {
+ let res: Result<(i64,), sqlx::Error> =
+ sqlx::query_as("select count(*) from store where tag=?1")
+ .bind(tag)
+ .fetch_one(&self.pool)
+ .await;
+ match res {
+ Err(e) => Err(eyre!("failed to fetch local store len: {}", e)),
+ Ok(v) => Ok(v.0 as u64),
+ }
+ }
+
async fn len(&self, host: HostId, tag: &str) -> Result<u64> {
let last = self.last(host, tag).await?;
@@ -343,6 +355,20 @@ mod tests {
}
#[tokio::test]
+ async fn len_tag() {
+ let db = SqliteStore::new(":memory:", 0.1).await.unwrap();
+ let record = test_record();
+ db.push(&record).await.unwrap();
+
+ let len = db
+ .len_tag(record.tag.as_str())
+ .await
+ .expect("failed to get store len");
+
+ assert_eq!(len, 1, "expected length of 1 after insert");
+ }
+
+ #[tokio::test]
async fn len_different_tags() {
let db = SqliteStore::new(":memory:", 0.1).await.unwrap();
@@ -379,6 +405,12 @@ mod tests {
100,
"failed to insert 100 records"
);
+
+ assert_eq!(
+ db.len_tag(tail.tag.as_str()).await.unwrap(),
+ 100,
+ "failed to insert 100 records"
+ );
}
#[tokio::test]
diff --git a/atuin-client/src/record/store.rs b/atuin-client/src/record/store.rs
index efe2eb4a..40c1224b 100644
--- a/atuin-client/src/record/store.rs
+++ b/atuin-client/src/record/store.rs
@@ -21,7 +21,9 @@ pub trait Store {
) -> Result<()>;
async fn get(&self, id: RecordId) -> Result<Record<EncryptedData>>;
+
async fn len(&self, host: HostId, tag: &str) -> Result<u64>;
+ async fn len_tag(&self, tag: &str) -> Result<u64>;
async fn last(&self, host: HostId, tag: &str) -> Result<Option<Record<EncryptedData>>>;
async fn first(&self, host: HostId, tag: &str) -> Result<Option<Record<EncryptedData>>>;
diff --git a/atuin-client/src/record/sync.rs b/atuin-client/src/record/sync.rs
index 97152f79..eca0c930 100644
--- a/atuin-client/src/record/sync.rs
+++ b/atuin-client/src/record/sync.rs
@@ -14,14 +14,17 @@ pub enum SyncError {
#[error("the local store is ahead of the remote, but for another host. has remote lost data?")]
LocalAheadOtherHost,
- #[error("an issue with the local database occured")]
- LocalStoreError,
+ #[error("an issue with the local database occured: {msg:?}")]
+ LocalStoreError { msg: String },
#[error("something has gone wrong with the sync logic: {msg:?}")]
SyncLogicError { msg: String },
- #[error("a request to the sync server failed")]
- RemoteRequestError,
+ #[error("operational error: {msg:?}")]
+ OperationalError { msg: String },
+
+ #[error("a request to the sync server failed: {msg:?}")]
+ RemoteRequestError { msg: String },
}
#[derive(Debug, Eq, PartialEq)]
@@ -45,16 +48,27 @@ pub enum Operation {
},
}
-pub async fn diff(settings: &Settings, store: &impl Store) -> Result<(Vec<Diff>, RecordStatus)> {
+pub async fn diff(
+ settings: &Settings,
+ store: &impl Store,
+) -> Result<(Vec<Diff>, RecordStatus), SyncError> {
let client = Client::new(
&settings.sync_address,
&settings.session_token,
settings.network_connect_timeout,
settings.network_timeout,
- )?;
+ )
+ .map_err(|e| SyncError::OperationalError { msg: e.to_string() })?;
+
+ let local_index = store
+ .status()
+ .await
+ .map_err(|e| SyncError::LocalStoreError { msg: e.to_string() })?;
- let local_index = store.status().await?;
- let remote_index = client.record_status().await?;
+ let remote_index = client
+ .record_status()
+ .await
+ .map_err(|e| SyncError::RemoteRequestError { msg: e.to_string() })?;
let diff = local_index.diff(&remote_index);
@@ -166,13 +180,13 @@ async fn sync_upload(
.map_err(|e| {
error!("failed to read upload page: {e:?}");
- SyncError::LocalStoreError
+ SyncError::LocalStoreError { msg: e.to_string() }
})?;
client.post_records(&page).await.map_err(|e| {
error!("failed to post records: {e:?}");
- SyncError::RemoteRequestError
+ SyncError::RemoteRequestError { msg: e.to_string() }
})?;
println!(
@@ -217,12 +231,12 @@ async fn sync_download(
let page = client
.next_records(host, tag.clone(), local + progress, download_page_size)
.await
- .map_err(|_| SyncError::RemoteRequestError)?;
+ .map_err(|e| SyncError::RemoteRequestError { msg: e.to_string() })?;
store
.push_batch(page.iter())
.await
- .map_err(|_| SyncError::LocalStoreError)?;
+ .map_err(|e| SyncError::LocalStoreError { msg: e.to_string() })?;
println!(
"downloaded {} records from remote, progress {}/{}",
@@ -283,6 +297,17 @@ pub async fn sync_remote(
Ok((uploaded, downloaded))
}
+pub async fn sync(
+ settings: &Settings,
+ store: &impl Store,
+) -> Result<(i64, Vec<RecordId>), SyncError> {
+ let (diff, _) = diff(settings, store).await?;
+ let operations = operations(diff, store).await?;
+ let (uploaded, downloaded) = sync_remote(operations, store, settings).await?;
+
+ Ok((uploaded, downloaded))
+}
+
#[cfg(test)]
mod tests {
use atuin_common::record::{Diff, EncryptedData, HostId, Record};
diff --git a/atuin/src/command/client/account/delete.rs b/atuin/src/command/client/account/delete.rs
index 6a4b1406..3591c6f3 100644
--- a/atuin/src/command/client/account/delete.rs
+++ b/atuin/src/command/client/account/delete.rs
@@ -5,7 +5,6 @@ use std::path::PathBuf;
pub async fn run(settings: &Settings) -> Result<()> {
let session_path = settings.session_path.as_str();
- let key_path = settings.key_path.as_str();
if !PathBuf::from(session_path).exists() {
bail!("You are not logged in");
@@ -25,10 +24,6 @@ pub async fn run(settings: &Settings) -> Result<()> {
remove_file(PathBuf::from(session_path))?;
}
- if PathBuf::from(key_path).exists() {
- remove_file(PathBuf::from(key_path))?;
- }
-
println!("Your account is deleted");
Ok(())
diff --git a/atuin/src/command/client/account/register.rs b/atuin/src/command/client/account/register.rs
index 0523dced..96b7d7d6 100644
--- a/atuin/src/command/client/account/register.rs
+++ b/atuin/src/command/client/account/register.rs
@@ -49,8 +49,7 @@ pub async fn run(
let mut file = File::create(path).await?;
file.write_all(session.session.as_bytes()).await?;
- // Create a new key, and save it to disk
- let _key = atuin_client::encryption::new_key(settings)?;
+ let _key = atuin_client::encryption::load_key(settings)?;
Ok(())
}
diff --git a/atuin/src/command/client/history.rs b/atuin/src/command/client/history.rs
index e983cc7b..18ae17cf 100644
--- a/atuin/src/command/client/history.rs
+++ b/atuin/src/command/client/history.rs
@@ -88,10 +88,6 @@ pub enum Cmd {
#[arg(long, short)]
format: Option<String>,
},
-
- /// Import all old history.db data into the record store. Do not run more than once, and do not
- /// run unless you know what you're doing (or the docs ask you to)
- InitStore,
}
#[derive(Clone, Copy, Debug)]
@@ -321,10 +317,7 @@ impl Cmd {
#[cfg(feature = "sync")]
{
if settings.sync.records {
- let (diff, _) = record::sync::diff(settings, &store).await?;
- let operations = record::sync::operations(diff, &store).await?;
- let (_, downloaded) =
- record::sync::sync_remote(operations, &store, settings).await?;
+ let (_, downloaded) = record::sync::sync(settings, &store).await?;
history_store.incremental_build(db, &downloaded).await?;
} else {
@@ -380,38 +373,6 @@ impl Cmd {
Ok(())
}
- async fn init_store(
- context: atuin_client::database::Context,
- db: &impl Database,
- store: HistoryStore,
- ) -> Result<()> {
- println!("Importing all history.db data into records.db");
-
- println!("Fetching history from old database");
- let history = db.list(&[], &context, None, false, true).await?;
-
- println!("Fetching history already in store");
- let store_ids = store.history_ids().await?;
-
- for i in history {
- println!("loaded {}", i.id);
-
- if store_ids.contains(&i.id) {
- println!("skipping {} - already exists", i.id);
- continue;
- }
-
- if i.deleted_at.is_some() {
- store.push(i.clone()).await?;
- store.delete(i.id).await?;
- } else {
- store.push(i).await?;
- }
- }
-
- Ok(())
- }
-
pub async fn run(
self,
settings: &Settings,
@@ -468,8 +429,6 @@ impl Cmd {
Ok(())
}
-
- Self::InitStore => Self::init_store(context, db, history_store).await,
}
}
}
diff --git a/atuin/src/command/client/sync.rs b/atuin/src/command/client/sync.rs
index 2e58f07d..5b438453 100644
--- a/atuin/src/command/client/sync.rs
+++ b/atuin/src/command/client/sync.rs
@@ -2,10 +2,10 @@ use clap::Subcommand;
use eyre::{Result, WrapErr};
use atuin_client::{
- database::Database,
+ database::{current_context, Database},
encryption,
history::store::HistoryStore,
- record::{sqlite_store::SqliteStore, sync},
+ record::{sqlite_store::SqliteStore, store::Store, sync},
settings::Settings,
};
@@ -80,10 +80,6 @@ async fn run(
store: SqliteStore,
) -> Result<()> {
if settings.sync.records {
- let (diff, _) = sync::diff(settings, &store).await?;
- let operations = sync::operations(diff, &store).await?;
- let (uploaded, downloaded) = sync::sync_remote(operations, &store, settings).await?;
-
let encryption_key: [u8; 32] = encryption::load_key(settings)
.context("could not load encryption key")?
.into();
@@ -91,6 +87,22 @@ async fn run(
let host_id = Settings::host_id().expect("failed to get host_id");
let history_store = HistoryStore::new(store.clone(), host_id, encryption_key);
+ let history_length = db.history_count(true).await?;
+ let store_history_length = store.len_tag("history").await?;
+
+ #[allow(clippy::cast_sign_loss)]
+ if history_length as u64 > store_history_length {
+ println!("History DB is longer than history record store");
+ println!("This happens when you used Atuin pre-record-store");
+
+ let context = current_context();
+ history_store.init_store(context, db).await?;
+
+ println!("\n");
+ }
+
+ let (uploaded, downloaded) = sync::sync(settings, &store).await?;
+
history_store.incremental_build(db, &downloaded).await?;
println!("{uploaded}/{} up/down to record store", downloaded.len());