diff options
author | Ellie Huxtable <e@elm.sh> | 2021-04-25 18:21:52 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2021-04-25 17:21:52 +0000 |
commit | 156893d774b4da5b541fdbb08428f9ec392949a0 (patch) | |
tree | 9185d94384aa62eb6eb099ddc4ca9408df6f90d1 /atuin-client | |
parent | 4210e8de5a29eb389b753adf8df47d2c449a2eeb (diff) |
Update docs, unify on SQLx, bugfixes (#40)
* Begin moving to sqlx for local too
* Stupid scanners should just have a nice cup of tea
Random internet shit searching for /.env or whatever
* Remove diesel and rusqlite fully
Diffstat (limited to 'atuin-client')
-rw-r--r-- | atuin-client/Cargo.toml | 2 | ||||
-rw-r--r-- | atuin-client/migrations/20210422143411_create_history.sql | 16 | ||||
-rw-r--r-- | atuin-client/src/database.rs | 355 | ||||
-rw-r--r-- | atuin-client/src/encryption.rs | 2 | ||||
-rw-r--r-- | atuin-client/src/history.rs | 2 | ||||
-rw-r--r-- | atuin-client/src/settings.rs | 46 | ||||
-rw-r--r-- | atuin-client/src/sync.rs | 10 |
7 files changed, 181 insertions, 252 deletions
diff --git a/atuin-client/Cargo.toml b/atuin-client/Cargo.toml index 4d3e91301..bd09ca42d 100644 --- a/atuin-client/Cargo.toml +++ b/atuin-client/Cargo.toml @@ -37,6 +37,6 @@ tokio = { version = "1", features = ["full"] } async-trait = "0.1.49" urlencoding = "1.1.1" humantime = "2.1.0" -rusqlite= { version = "0.25", features = ["bundled"] } itertools = "0.10.0" shellexpand = "2" +sqlx = { version = "0.5", features = [ "runtime-tokio-rustls", "uuid", "chrono", "sqlite" ] } diff --git a/atuin-client/migrations/20210422143411_create_history.sql b/atuin-client/migrations/20210422143411_create_history.sql new file mode 100644 index 000000000..23c63a4f4 --- /dev/null +++ b/atuin-client/migrations/20210422143411_create_history.sql @@ -0,0 +1,16 @@ +-- Add migration script here +create table if not exists history ( + id text primary key, + timestamp text not null, + duration integer not null, + exit integer not null, + command text not null, + cwd text not null, + session text not null, + hostname text not null, + + unique(timestamp, cwd, command) +); + +create index if not exists idx_history_timestamp on history(timestamp); +create index if not exists idx_history_command on history(command); diff --git a/atuin-client/src/database.rs b/atuin-client/src/database.rs index 0855359b8..754a0ecfc 100644 --- a/atuin-client/src/database.rs +++ b/atuin-client/src/database.rs @@ -1,44 +1,48 @@ -use chrono::prelude::*; -use chrono::Utc; use std::path::Path; +use std::str::FromStr; + +use async_trait::async_trait; +use chrono::Utc; -use eyre::{eyre, Result}; +use eyre::Result; -use rusqlite::{params, Connection}; -use rusqlite::{Params, Transaction}; +use sqlx::sqlite::{SqliteConnectOptions, SqliteJournalMode, SqlitePool, SqlitePoolOptions}; use super::history::History; +#[async_trait] pub trait Database { - fn save(&mut self, h: &History) -> Result<()>; - fn save_bulk(&mut self, h: &[History]) -> Result<()>; + async fn save(&mut self, h: &History) -> Result<()>; + async fn save_bulk(&mut self, h: &[History]) -> Result<()>; - fn load(&self, id: &str) -> Result<History>; - fn list(&self, max: Option<usize>, unique: bool) -> Result<Vec<History>>; - fn range(&self, from: chrono::DateTime<Utc>, to: chrono::DateTime<Utc>) - -> Result<Vec<History>>; + async fn load(&self, id: &str) -> Result<History>; + async fn list(&self, max: Option<usize>, unique: bool) -> Result<Vec<History>>; + async fn range( + &self, + from: chrono::DateTime<Utc>, + to: chrono::DateTime<Utc>, + ) -> Result<Vec<History>>; - fn query(&self, query: &str, params: impl Params) -> Result<Vec<History>>; - fn update(&self, h: &History) -> Result<()>; - fn history_count(&self) -> Result<i64>; + async fn update(&self, h: &History) -> Result<()>; + async fn history_count(&self) -> Result<i64>; - fn first(&self) -> Result<History>; - fn last(&self) -> Result<History>; - fn before(&self, timestamp: chrono::DateTime<Utc>, count: i64) -> Result<Vec<History>>; + async fn first(&self) -> Result<History>; + async fn last(&self) -> Result<History>; + async fn before(&self, timestamp: chrono::DateTime<Utc>, count: i64) -> Result<Vec<History>>; - fn prefix_search(&self, query: &str) -> Result<Vec<History>>; + async fn search(&self, limit: Option<i64>, query: &str) -> Result<Vec<History>>; - fn search(&self, cwd: Option<String>, exit: Option<i64>, query: &str) -> Result<Vec<History>>; + async fn query_history(&self, query: &str) -> Result<Vec<History>>; } // Intended for use on a developer machine and not a sync server. // TODO: implement IntoIterator pub struct Sqlite { - conn: Connection, + pool: SqlitePool, } impl Sqlite { - pub fn new(path: impl AsRef<Path>) -> Result<Self> { + pub async fn new(path: impl AsRef<Path>) -> Result<Self> { let path = path.as_ref(); debug!("opening sqlite database at {:?}", path); @@ -49,137 +53,106 @@ impl Sqlite { } } - let conn = Connection::open(path)?; + let opts = SqliteConnectOptions::from_str(path.as_os_str().to_str().unwrap())? + .journal_mode(SqliteJournalMode::Wal) + .create_if_missing(true); + + let pool = SqlitePoolOptions::new().connect_with(opts).await?; - Self::setup_db(&conn)?; + Self::setup_db(&pool).await?; - Ok(Self { conn }) + Ok(Self { pool }) } - fn setup_db(conn: &Connection) -> Result<()> { + async fn setup_db(pool: &SqlitePool) -> Result<()> { debug!("running sqlite database setup"); - conn.execute( - "create table if not exists history ( - id text primary key, - timestamp integer not null, - duration integer not null, - exit integer not null, - command text not null, - cwd text not null, - session text not null, - hostname text not null, - - unique(timestamp, cwd, command) - )", - [], - )?; - - conn.execute( - "create table if not exists history_encrypted ( - id text primary key, - data blob not null - )", - [], - )?; - - conn.execute( - "create index if not exists idx_history_timestamp on history(timestamp)", - [], - )?; - - conn.execute( - "create index if not exists idx_history_command on history(command)", - [], - )?; + sqlx::migrate!("./migrations").run(pool).await?; Ok(()) } - fn save_raw(tx: &Transaction, h: &History) -> Result<()> { - tx.execute( - "insert or ignore into history ( - id, - timestamp, - duration, - exit, - command, - cwd, - session, - hostname - ) values (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8)", - params![ - h.id, - h.timestamp.timestamp_nanos(), - h.duration, - h.exit, - h.command, - h.cwd, - h.session, - h.hostname - ], - )?; + async fn save_raw(tx: &mut sqlx::Transaction<'_, sqlx::Sqlite>, h: &History) -> Result<()> { + sqlx::query( + "insert or ignore into history(id, timestamp, duration, exit, command, cwd, session, hostname) + values(?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8)", + ) + .bind(h.id.as_str()) + .bind(h.timestamp.to_rfc3339()) + .bind(h.duration) + .bind(h.exit) + .bind(h.command.as_str()) + .bind(h.cwd.as_str()) + .bind(h.session.as_str()) + .bind(h.hostname.as_str()) + .execute(tx) + .await?; Ok(()) } } +#[async_trait] impl Database for Sqlite { - fn save(&mut self, h: &History) -> Result<()> { + async fn save(&mut self, h: &History) -> Result<()> { debug!("saving history to sqlite"); - let tx = self.conn.transaction()?; - Self::save_raw(&tx, h)?; - tx.commit()?; + let mut tx = self.pool.begin().await?; + Self::save_raw(&mut tx, h).await?; + tx.commit().await?; Ok(()) } - fn save_bulk(&mut self, h: &[History]) -> Result<()> { + async fn save_bulk(&mut self, h: &[History]) -> Result<()> { debug!("saving history to sqlite"); - let tx = self.conn.transaction()?; + let mut tx = self.pool.begin().await?; + for i in h { - Self::save_raw(&tx, i)? + Self::save_raw(&mut tx, i).await? } - tx.commit()?; + + tx.commit().await?; Ok(()) } - fn load(&self, id: &str) -> Result<History> { + async fn load(&self, id: &str) -> Result<History> { debug!("loading history item {}", id); - let history = self.query( - "select id, timestamp, duration, exit, command, cwd, session, hostname from history - where id = ?1 limit 1", - &[id], - )?; + let res = sqlx::query_as::<_, History>("select * from history where id = ?1") + .bind(id) + .fetch_one(&self.pool) + .await?; - if history.is_empty() { - return Err(eyre!("could not find history with id {}", id)); - } - - let history = history[0].clone(); - - Ok(history) + Ok(res) } - fn update(&self, h: &History) -> Result<()> { + async fn update(&self, h: &History) -> Result<()> { debug!("updating sqlite history"); - self.conn.execute( + sqlx::query( "update history set timestamp = ?2, duration = ?3, exit = ?4, command = ?5, cwd = ?6, session = ?7, hostname = ?8 where id = ?1", - params![h.id, h.timestamp.timestamp_nanos(), h.duration, h.exit, h.command, h.cwd, h.session, h.hostname], - )?; + ) + .bind(h.id.as_str()) + .bind(h.timestamp.to_rfc3339()) + .bind(h.duration) + .bind(h.exit) + .bind(h.command.as_str()) + .bind(h.cwd.as_str()) + .bind(h.session.as_str()) + .bind(h.hostname.as_str()) + .execute(&self.pool) + .await?; Ok(()) } // make a unique list, that only shows the *newest* version of things - fn list(&self, max: Option<usize>, unique: bool) -> Result<Vec<History>> { + async fn list(&self, max: Option<usize>, unique: bool) -> Result<Vec<History>> { debug!("listing history"); // very likely vulnerable to SQL injection @@ -208,144 +181,96 @@ impl Database for Sqlite { } ); - let history = self.query(query.as_str(), params![])?; + let res = sqlx::query_as::<_, History>(query.as_str()) + .fetch_all(&self.pool) + .await?; - Ok(history) + Ok(res) } - fn range( + async fn range( &self, from: chrono::DateTime<Utc>, to: chrono::DateTime<Utc>, ) -> Result<Vec<History>> { debug!("listing history from {:?} to {:?}", from, to); - let mut stmt = self.conn.prepare( - "SELECT * FROM history where timestamp >= ?1 and timestamp <= ?2 order by timestamp asc", - )?; - - let history_iter = stmt.query_map( - params![from.timestamp_nanos(), to.timestamp_nanos()], - |row| history_from_sqlite_row(None, row), - )?; + let res = sqlx::query_as::<_, History>( + "select * from history where timestamp >= ?1 and timestamp <= ?2 order by timestamp asc", + ) + .bind(from) + .bind(to) + .fetch_all(&self.pool) + .await?; - Ok(history_iter.filter_map(Result::ok).collect()) + Ok(res) } - fn first(&self) -> Result<History> { - let mut stmt = self - .conn - .prepare("SELECT * FROM history order by timestamp asc limit 1")?; - - let history = stmt.query_row(params![], |row| history_from_sqlite_row(None, row))?; + async fn first(&self) -> Result<History> { + let res = sqlx::query_as::<_, History>( + "select * from history where duration >= 0 order by timestamp asc limit 1", + ) + .fetch_one(&self.pool) + .await?; - Ok(history) + Ok(res) } - fn last(&self) -> Result<History> { - let mut stmt = self - .conn - .prepare("SELECT * FROM history where duration >= 0 order by timestamp desc limit 1")?; - - let history = stmt.query_row(params![], |row| history_from_sqlite_row(None, row))?; + async fn last(&self) -> Result<History> { + let res = sqlx::query_as::<_, History>( + "select * from history where duration >= 0 order by timestamp desc limit 1", + ) + .fetch_one(&self.pool) + .await?; - Ok(history) + Ok(res) } - fn before(&self, timestamp: chrono::DateTime<Utc>, count: i64) -> Result<Vec<History>> { - let mut stmt = self - .conn - .prepare("SELECT * FROM history where timestamp < ? order by timestamp desc limit ?")?; - - let history_iter = stmt.query_map(params![timestamp.timestamp_nanos(), count], |row| { - history_from_sqlite_row(None, row) - })?; + async fn before(&self, timestamp: chrono::DateTime<Utc>, count: i64) -> Result<Vec<History>> { + let res = sqlx::query_as::<_, History>( + "select * from history where timestamp < ?1 order by timestamp desc limit ?2", + ) + .bind(timestamp) + .bind(count) + .fetch_all(&self.pool) + .await?; - Ok(history_iter.filter_map(Result::ok).collect()) + Ok(res) } - fn query(&self, query: &str, params: impl Params) -> Result<Vec<History>> { - let mut stmt = self.conn.prepare(query)?; + async fn history_count(&self) -> Result<i64> { + let res: (i64,) = sqlx::query_as("select count(1) from history") + .fetch_one(&self.pool) + .await?; - let history_iter = stmt.query_map(params, |row| history_from_sqlite_row(None, row))?; - - Ok(history_iter.filter_map(Result::ok).collect()) + Ok(res.0) } - fn prefix_search(&self, query: &str) -> Result<Vec<History>> { + async fn search(&self, limit: Option<i64>, query: &str) -> Result<Vec<History>> { let query = query.to_string().replace("*", "%"); // allow wildcard char + let limit = limit.map_or("".to_owned(), |l| format!("limit {}", l)); - self.query( - "select * from history h - where command like ?1 || '%' - and timestamp = ( - select max(timestamp) from history - where h.command = history.command - ) - order by timestamp desc limit 200", - &[query.as_str()], + let res = sqlx::query_as::<_, History>( + format!( + "select * from history + where command like ?1 || '%' + order by timestamp desc {}", + limit.clone() + ) + .as_str(), ) - } - - fn history_count(&self) -> Result<i64> { - let res: i64 = - self.conn - .query_row_and_then("select count(1) from history;", params![], |row| row.get(0))?; + .bind(query) + .fetch_all(&self.pool) + .await?; Ok(res) } - fn search(&self, cwd: Option<String>, exit: Option<i64>, query: &str) -> Result<Vec<History>> { - match (cwd, exit) { - (Some(cwd), Some(exit)) => self.query( - "select * from history - where command like ?1 || '%' - and cwd = ?2 - and exit = ?3 - order by timestamp asc limit 1000", - &[query, cwd.as_str(), exit.to_string().as_str()], - ), - (Some(cwd), None) => self.query( - "select * from history - where command like ?1 || '%' - and cwd = ?2 - order by timestamp asc limit 1000", - &[query, cwd.as_str()], - ), - (None, Some(exit)) => self.query( - "select * from history - where command like ?1 || '%' - and exit = ?2 - order by timestamp asc limit 1000", - &[query, exit.to_string().as_str()], - ), - (None, None) => self.query( - "select * from history - where command like ?1 || '%' - order by timestamp asc limit 1000", - &[query], - ), - } - } -} + async fn query_history(&self, query: &str) -> Result<Vec<History>> { + let res = sqlx::query_as::<_, History>(query) + .fetch_all(&self.pool) + .await?; -fn history_from_sqlite_row( - id: Option<String>, - row: &rusqlite::Row, -) -> Result<History, rusqlite::Error> { - let id = match id { - Some(id) => id, - None => row.get(0)?, - }; - - Ok(History { - id, - timestamp: Utc.timestamp_nanos(row.get(1)?), - duration: row.get(2)?, - exit: row.get(3)?, - command: row.get(4)?, - cwd: row.get(5)?, - session: row.get(6)?, - hostname: row.get(7)?, - }) + Ok(res) + } } diff --git a/atuin-client/src/encryption.rs b/atuin-client/src/encryption.rs index 19b773ab9..9cb8d3ea1 100644 --- a/atuin-client/src/encryption.rs +++ b/atuin-client/src/encryption.rs @@ -98,7 +98,7 @@ pub fn decrypt(encrypted_history: &EncryptedHistory, key: &secretbox::Key) -> Re mod test { use sodiumoxide::crypto::secretbox; - use crate::local::history::History; + use crate::history::History; use super::{decrypt, encrypt}; diff --git a/atuin-client/src/history.rs b/atuin-client/src/history.rs index 8dd161dba..92e92ddfe 100644 --- a/atuin-client/src/history.rs +++ b/atuin-client/src/history.rs @@ -6,7 +6,7 @@ use chrono::Utc; use atuin_common::utils::uuid_v4; // Any new fields MUST be Optional<>! -#[derive(Debug, Clone, Serialize, Deserialize, Ord, PartialOrd)] +#[derive(Debug, Clone, Serialize, Deserialize, Ord, PartialOrd, sqlx::FromRow)] pub struct History { pub id: String, pub timestamp: chrono::DateTime<Utc>, diff --git a/atuin-client/src/settings.rs b/atuin-client/src/settings.rs index 254bca6d5..4ea4be841 100644 --- a/atuin-client/src/settings.rs +++ b/atuin-client/src/settings.rs @@ -5,7 +5,6 @@ use std::path::{Path, PathBuf}; use chrono::prelude::*; use chrono::Utc; use config::{Config, Environment, File as ConfigFile}; -use directories::ProjectDirs; use eyre::{eyre, Result}; use parse_duration::parse; @@ -28,9 +27,10 @@ pub struct Settings { impl Settings { pub fn save_sync_time() -> Result<()> { - let sync_time_path = ProjectDirs::from("com", "elliehuxtable", "atuin") - .ok_or_else(|| eyre!("could not determine key file location"))?; - let sync_time_path = sync_time_path.data_dir().join("last_sync_time"); + let data_dir = atuin_common::utils::data_dir(); + let data_dir = data_dir.as_path(); + + let sync_time_path = data_dir.join("last_sync_time"); std::fs::write(sync_time_path, Utc::now().to_rfc3339())?; @@ -38,15 +38,10 @@ impl Settings { } pub fn last_sync() -> Result<chrono::DateTime<Utc>> { - let sync_time_path = ProjectDirs::from("com", "elliehuxtable", "atuin"); - - if sync_time_path.is_none() { - debug!("failed to load projectdirs, not syncing"); - return Err(eyre!("could not load project dirs")); - } + let data_dir = atuin_common::utils::data_dir(); + let data_dir = data_dir.as_path(); - let sync_time_path = sync_time_path.unwrap(); - let sync_time_path = sync_time_path.data_dir().join("last_sync_time"); + let sync_time_path = data_dir.join("last_sync_time"); if !sync_time_path.exists() { return Ok(Utc.ymd(1970, 1, 1).and_hms(0, 0, 0)); @@ -73,10 +68,14 @@ impl Settings { } pub fn new() -> Result<Self> { - let config_dir = ProjectDirs::from("com", "elliehuxtable", "atuin").unwrap(); - let config_dir = config_dir.config_dir(); + let config_dir = atuin_common::utils::config_dir(); + let config_dir = config_dir.as_path(); + + let data_dir = atuin_common::utils::data_dir(); + let data_dir = data_dir.as_path(); create_dir_all(config_dir)?; + create_dir_all(data_dir)?; let mut config_file = if let Ok(p) = std::env::var("ATUIN_CONFIG_DIR") { PathBuf::from(p) @@ -90,27 +89,16 @@ impl Settings { let mut s = Config::new(); - let db_path = ProjectDirs::from("com", "elliehuxtable", "atuin") - .ok_or_else(|| eyre!("could not determine db file location"))? - .data_dir() - .join("history.db"); - - let key_path = ProjectDirs::from("com", "elliehuxtable", "atuin") - .ok_or_else(|| eyre!("could not determine key file location"))? - .data_dir() - .join("key"); - - let session_path = ProjectDirs::from("com", "elliehuxtable", "atuin") - .ok_or_else(|| eyre!("could not determine session file location"))? - .data_dir() - .join("session"); + let db_path = data_dir.join("history.db"); + let key_path = data_dir.join("key"); + let session_path = data_dir.join("session"); s.set_default("db_path", db_path.to_str())?; s.set_default("key_path", key_path.to_str())?; s.set_default("session_path", session_path.to_str())?; s.set_default("dialect", "us")?; s.set_default("auto_sync", true)?; - s.set_default("sync_frequency", "5m")?; + s.set_default("sync_frequency", "1h")?; s.set_default("sync_address", "https://api.atuin.sh")?; if config_file.exists() { diff --git a/atuin-client/src/sync.rs b/atuin-client/src/sync.rs index 5d81a5e6e..944080183 100644 --- a/atuin-client/src/sync.rs +++ b/atuin-client/src/sync.rs @@ -30,7 +30,7 @@ async fn sync_download( let remote_count = client.count().await?; - let initial_local = db.history_count()?; + let initial_local = db.history_count().await?; let mut local_count = initial_local; let mut last_sync = if force { @@ -48,9 +48,9 @@ async fn sync_download( .get_history(last_sync, last_timestamp, host.clone()) .await?; - db.save_bulk(&page)?; + db.save_bulk(&page).await?; - local_count = db.history_count()?; + local_count = db.history_count().await?; if page.len() < HISTORY_PAGE_SIZE.try_into().unwrap() { break; @@ -87,7 +87,7 @@ async fn sync_upload( let initial_remote_count = client.count().await?; let mut remote_count = initial_remote_count; - let local_count = db.history_count()?; + let local_count = db.history_count().await?; debug!("remote has {}, we have {}", remote_count, local_count); @@ -98,7 +98,7 @@ async fn sync_upload( let mut cursor = Utc::now(); while local_count > remote_count { - let last = db.before(cursor, HISTORY_PAGE_SIZE)?; + let last = db.before(cursor, HISTORY_PAGE_SIZE).await?; let mut buffer = Vec::<AddHistoryRequest>::new(); if last.is_empty() { |