diff options
Diffstat (limited to 'atuin-client/src/database.rs')
-rw-r--r-- | atuin-client/src/database.rs | 355 |
1 files changed, 140 insertions, 215 deletions
diff --git a/atuin-client/src/database.rs b/atuin-client/src/database.rs index 0855359b..754a0ecf 100644 --- a/atuin-client/src/database.rs +++ b/atuin-client/src/database.rs @@ -1,44 +1,48 @@ -use chrono::prelude::*; -use chrono::Utc; use std::path::Path; +use std::str::FromStr; + +use async_trait::async_trait; +use chrono::Utc; -use eyre::{eyre, Result}; +use eyre::Result; -use rusqlite::{params, Connection}; -use rusqlite::{Params, Transaction}; +use sqlx::sqlite::{SqliteConnectOptions, SqliteJournalMode, SqlitePool, SqlitePoolOptions}; use super::history::History; +#[async_trait] pub trait Database { - fn save(&mut self, h: &History) -> Result<()>; - fn save_bulk(&mut self, h: &[History]) -> Result<()>; + async fn save(&mut self, h: &History) -> Result<()>; + async fn save_bulk(&mut self, h: &[History]) -> Result<()>; - fn load(&self, id: &str) -> Result<History>; - fn list(&self, max: Option<usize>, unique: bool) -> Result<Vec<History>>; - fn range(&self, from: chrono::DateTime<Utc>, to: chrono::DateTime<Utc>) - -> Result<Vec<History>>; + async fn load(&self, id: &str) -> Result<History>; + async fn list(&self, max: Option<usize>, unique: bool) -> Result<Vec<History>>; + async fn range( + &self, + from: chrono::DateTime<Utc>, + to: chrono::DateTime<Utc>, + ) -> Result<Vec<History>>; - fn query(&self, query: &str, params: impl Params) -> Result<Vec<History>>; - fn update(&self, h: &History) -> Result<()>; - fn history_count(&self) -> Result<i64>; + async fn update(&self, h: &History) -> Result<()>; + async fn history_count(&self) -> Result<i64>; - fn first(&self) -> Result<History>; - fn last(&self) -> Result<History>; - fn before(&self, timestamp: chrono::DateTime<Utc>, count: i64) -> Result<Vec<History>>; + async fn first(&self) -> Result<History>; + async fn last(&self) -> Result<History>; + async fn before(&self, timestamp: chrono::DateTime<Utc>, count: i64) -> Result<Vec<History>>; - fn prefix_search(&self, query: &str) -> Result<Vec<History>>; + async fn search(&self, limit: Option<i64>, query: &str) -> Result<Vec<History>>; - fn search(&self, cwd: Option<String>, exit: Option<i64>, query: &str) -> Result<Vec<History>>; + async fn query_history(&self, query: &str) -> Result<Vec<History>>; } // Intended for use on a developer machine and not a sync server. // TODO: implement IntoIterator pub struct Sqlite { - conn: Connection, + pool: SqlitePool, } impl Sqlite { - pub fn new(path: impl AsRef<Path>) -> Result<Self> { + pub async fn new(path: impl AsRef<Path>) -> Result<Self> { let path = path.as_ref(); debug!("opening sqlite database at {:?}", path); @@ -49,137 +53,106 @@ impl Sqlite { } } - let conn = Connection::open(path)?; + let opts = SqliteConnectOptions::from_str(path.as_os_str().to_str().unwrap())? + .journal_mode(SqliteJournalMode::Wal) + .create_if_missing(true); + + let pool = SqlitePoolOptions::new().connect_with(opts).await?; - Self::setup_db(&conn)?; + Self::setup_db(&pool).await?; - Ok(Self { conn }) + Ok(Self { pool }) } - fn setup_db(conn: &Connection) -> Result<()> { + async fn setup_db(pool: &SqlitePool) -> Result<()> { debug!("running sqlite database setup"); - conn.execute( - "create table if not exists history ( - id text primary key, - timestamp integer not null, - duration integer not null, - exit integer not null, - command text not null, - cwd text not null, - session text not null, - hostname text not null, - - unique(timestamp, cwd, command) - )", - [], - )?; - - conn.execute( - "create table if not exists history_encrypted ( - id text primary key, - data blob not null - )", - [], - )?; - - conn.execute( - "create index if not exists idx_history_timestamp on history(timestamp)", - [], - )?; - - conn.execute( - "create index if not exists idx_history_command on history(command)", - [], - )?; + sqlx::migrate!("./migrations").run(pool).await?; Ok(()) } - fn save_raw(tx: &Transaction, h: &History) -> Result<()> { - tx.execute( - "insert or ignore into history ( - id, - timestamp, - duration, - exit, - command, - cwd, - session, - hostname - ) values (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8)", - params![ - h.id, - h.timestamp.timestamp_nanos(), - h.duration, - h.exit, - h.command, - h.cwd, - h.session, - h.hostname - ], - )?; + async fn save_raw(tx: &mut sqlx::Transaction<'_, sqlx::Sqlite>, h: &History) -> Result<()> { + sqlx::query( + "insert or ignore into history(id, timestamp, duration, exit, command, cwd, session, hostname) + values(?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8)", + ) + .bind(h.id.as_str()) + .bind(h.timestamp.to_rfc3339()) + .bind(h.duration) + .bind(h.exit) + .bind(h.command.as_str()) + .bind(h.cwd.as_str()) + .bind(h.session.as_str()) + .bind(h.hostname.as_str()) + .execute(tx) + .await?; Ok(()) } } +#[async_trait] impl Database for Sqlite { - fn save(&mut self, h: &History) -> Result<()> { + async fn save(&mut self, h: &History) -> Result<()> { debug!("saving history to sqlite"); - let tx = self.conn.transaction()?; - Self::save_raw(&tx, h)?; - tx.commit()?; + let mut tx = self.pool.begin().await?; + Self::save_raw(&mut tx, h).await?; + tx.commit().await?; Ok(()) } - fn save_bulk(&mut self, h: &[History]) -> Result<()> { + async fn save_bulk(&mut self, h: &[History]) -> Result<()> { debug!("saving history to sqlite"); - let tx = self.conn.transaction()?; + let mut tx = self.pool.begin().await?; + for i in h { - Self::save_raw(&tx, i)? + Self::save_raw(&mut tx, i).await? } - tx.commit()?; + + tx.commit().await?; Ok(()) } - fn load(&self, id: &str) -> Result<History> { + async fn load(&self, id: &str) -> Result<History> { debug!("loading history item {}", id); - let history = self.query( - "select id, timestamp, duration, exit, command, cwd, session, hostname from history - where id = ?1 limit 1", - &[id], - )?; + let res = sqlx::query_as::<_, History>("select * from history where id = ?1") + .bind(id) + .fetch_one(&self.pool) + .await?; - if history.is_empty() { - return Err(eyre!("could not find history with id {}", id)); - } - - let history = history[0].clone(); - - Ok(history) + Ok(res) } - fn update(&self, h: &History) -> Result<()> { + async fn update(&self, h: &History) -> Result<()> { debug!("updating sqlite history"); - self.conn.execute( + sqlx::query( "update history set timestamp = ?2, duration = ?3, exit = ?4, command = ?5, cwd = ?6, session = ?7, hostname = ?8 where id = ?1", - params![h.id, h.timestamp.timestamp_nanos(), h.duration, h.exit, h.command, h.cwd, h.session, h.hostname], - )?; + ) + .bind(h.id.as_str()) + .bind(h.timestamp.to_rfc3339()) + .bind(h.duration) + .bind(h.exit) + .bind(h.command.as_str()) + .bind(h.cwd.as_str()) + .bind(h.session.as_str()) + .bind(h.hostname.as_str()) + .execute(&self.pool) + .await?; Ok(()) } // make a unique list, that only shows the *newest* version of things - fn list(&self, max: Option<usize>, unique: bool) -> Result<Vec<History>> { + async fn list(&self, max: Option<usize>, unique: bool) -> Result<Vec<History>> { debug!("listing history"); // very likely vulnerable to SQL injection @@ -208,144 +181,96 @@ impl Database for Sqlite { } ); - let history = self.query(query.as_str(), params![])?; + let res = sqlx::query_as::<_, History>(query.as_str()) + .fetch_all(&self.pool) + .await?; - Ok(history) + Ok(res) } - fn range( + async fn range( &self, from: chrono::DateTime<Utc>, to: chrono::DateTime<Utc>, ) -> Result<Vec<History>> { debug!("listing history from {:?} to {:?}", from, to); - let mut stmt = self.conn.prepare( - "SELECT * FROM history where timestamp >= ?1 and timestamp <= ?2 order by timestamp asc", - )?; - - let history_iter = stmt.query_map( - params![from.timestamp_nanos(), to.timestamp_nanos()], - |row| history_from_sqlite_row(None, row), - )?; + let res = sqlx::query_as::<_, History>( + "select * from history where timestamp >= ?1 and timestamp <= ?2 order by timestamp asc", + ) + .bind(from) + .bind(to) + .fetch_all(&self.pool) + .await?; - Ok(history_iter.filter_map(Result::ok).collect()) + Ok(res) } - fn first(&self) -> Result<History> { - let mut stmt = self - .conn - .prepare("SELECT * FROM history order by timestamp asc limit 1")?; - - let history = stmt.query_row(params![], |row| history_from_sqlite_row(None, row))?; + async fn first(&self) -> Result<History> { + let res = sqlx::query_as::<_, History>( + "select * from history where duration >= 0 order by timestamp asc limit 1", + ) + .fetch_one(&self.pool) + .await?; - Ok(history) + Ok(res) } - fn last(&self) -> Result<History> { - let mut stmt = self - .conn - .prepare("SELECT * FROM history where duration >= 0 order by timestamp desc limit 1")?; - - let history = stmt.query_row(params![], |row| history_from_sqlite_row(None, row))?; + async fn last(&self) -> Result<History> { + let res = sqlx::query_as::<_, History>( + "select * from history where duration >= 0 order by timestamp desc limit 1", + ) + .fetch_one(&self.pool) + .await?; - Ok(history) + Ok(res) } - fn before(&self, timestamp: chrono::DateTime<Utc>, count: i64) -> Result<Vec<History>> { - let mut stmt = self - .conn - .prepare("SELECT * FROM history where timestamp < ? order by timestamp desc limit ?")?; - - let history_iter = stmt.query_map(params![timestamp.timestamp_nanos(), count], |row| { - history_from_sqlite_row(None, row) - })?; + async fn before(&self, timestamp: chrono::DateTime<Utc>, count: i64) -> Result<Vec<History>> { + let res = sqlx::query_as::<_, History>( + "select * from history where timestamp < ?1 order by timestamp desc limit ?2", + ) + .bind(timestamp) + .bind(count) + .fetch_all(&self.pool) + .await?; - Ok(history_iter.filter_map(Result::ok).collect()) + Ok(res) } - fn query(&self, query: &str, params: impl Params) -> Result<Vec<History>> { - let mut stmt = self.conn.prepare(query)?; + async fn history_count(&self) -> Result<i64> { + let res: (i64,) = sqlx::query_as("select count(1) from history") + .fetch_one(&self.pool) + .await?; - let history_iter = stmt.query_map(params, |row| history_from_sqlite_row(None, row))?; - - Ok(history_iter.filter_map(Result::ok).collect()) + Ok(res.0) } - fn prefix_search(&self, query: &str) -> Result<Vec<History>> { + async fn search(&self, limit: Option<i64>, query: &str) -> Result<Vec<History>> { let query = query.to_string().replace("*", "%"); // allow wildcard char + let limit = limit.map_or("".to_owned(), |l| format!("limit {}", l)); - self.query( - "select * from history h - where command like ?1 || '%' - and timestamp = ( - select max(timestamp) from history - where h.command = history.command - ) - order by timestamp desc limit 200", - &[query.as_str()], + let res = sqlx::query_as::<_, History>( + format!( + "select * from history + where command like ?1 || '%' + order by timestamp desc {}", + limit.clone() + ) + .as_str(), ) - } - - fn history_count(&self) -> Result<i64> { - let res: i64 = - self.conn - .query_row_and_then("select count(1) from history;", params![], |row| row.get(0))?; + .bind(query) + .fetch_all(&self.pool) + .await?; Ok(res) } - fn search(&self, cwd: Option<String>, exit: Option<i64>, query: &str) -> Result<Vec<History>> { - match (cwd, exit) { - (Some(cwd), Some(exit)) => self.query( - "select * from history - where command like ?1 || '%' - and cwd = ?2 - and exit = ?3 - order by timestamp asc limit 1000", - &[query, cwd.as_str(), exit.to_string().as_str()], - ), - (Some(cwd), None) => self.query( - "select * from history - where command like ?1 || '%' - and cwd = ?2 - order by timestamp asc limit 1000", - &[query, cwd.as_str()], - ), - (None, Some(exit)) => self.query( - "select * from history - where command like ?1 || '%' - and exit = ?2 - order by timestamp asc limit 1000", - &[query, exit.to_string().as_str()], - ), - (None, None) => self.query( - "select * from history - where command like ?1 || '%' - order by timestamp asc limit 1000", - &[query], - ), - } - } -} + async fn query_history(&self, query: &str) -> Result<Vec<History>> { + let res = sqlx::query_as::<_, History>(query) + .fetch_all(&self.pool) + .await?; -fn history_from_sqlite_row( - id: Option<String>, - row: &rusqlite::Row, -) -> Result<History, rusqlite::Error> { - let id = match id { - Some(id) => id, - None => row.get(0)?, - }; - - Ok(History { - id, - timestamp: Utc.timestamp_nanos(row.get(1)?), - duration: row.get(2)?, - exit: row.get(3)?, - command: row.get(4)?, - cwd: row.get(5)?, - session: row.get(6)?, - hostname: row.get(7)?, - }) + Ok(res) + } } |