use std::path::Path; use std::str::FromStr; use async_trait::async_trait; use chrono::prelude::*; use chrono::Utc; use eyre::Result; use sqlx::sqlite::{ SqliteConnectOptions, SqliteJournalMode, SqlitePool, SqlitePoolOptions, SqliteRow, }; use sqlx::Row; use super::history::History; use super::settings::SearchMode; #[async_trait] pub trait Database { async fn save(&mut self, h: &History) -> Result<()>; async fn save_bulk(&mut self, h: &[History]) -> Result<()>; async fn load(&self, id: &str) -> Result; async fn list(&self, max: Option, unique: bool) -> Result>; async fn range( &self, from: chrono::DateTime, to: chrono::DateTime, ) -> Result>; async fn update(&self, h: &History) -> Result<()>; async fn history_count(&self) -> Result; async fn first(&self) -> Result; async fn last(&self) -> Result; async fn before(&self, timestamp: chrono::DateTime, count: i64) -> Result>; async fn search( &self, limit: Option, search_mode: SearchMode, query: &str, ) -> Result>; async fn query_history(&self, query: &str) -> Result>; } // Intended for use on a developer machine and not a sync server. // TODO: implement IntoIterator pub struct Sqlite { pool: SqlitePool, } impl Sqlite { pub async fn new(path: impl AsRef) -> Result { let path = path.as_ref(); debug!("opening sqlite database at {:?}", path); let create = !path.exists(); if create { if let Some(dir) = path.parent() { std::fs::create_dir_all(dir)?; } } let opts = SqliteConnectOptions::from_str(path.as_os_str().to_str().unwrap())? .journal_mode(SqliteJournalMode::Wal) .create_if_missing(true); let pool = SqlitePoolOptions::new().connect_with(opts).await?; Self::setup_db(&pool).await?; Ok(Self { pool }) } async fn setup_db(pool: &SqlitePool) -> Result<()> { debug!("running sqlite database setup"); sqlx::migrate!("./migrations").run(pool).await?; Ok(()) } async fn save_raw(tx: &mut sqlx::Transaction<'_, sqlx::Sqlite>, h: &History) -> Result<()> { sqlx::query( "insert or ignore into history(id, timestamp, duration, exit, command, cwd, session, hostname) values(?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8)", ) .bind(h.id.as_str()) .bind(h.timestamp.timestamp_nanos()) .bind(h.duration) .bind(h.exit) .bind(h.command.as_str()) .bind(h.cwd.as_str()) .bind(h.session.as_str()) .bind(h.hostname.as_str()) .execute(tx) .await?; Ok(()) } fn query_history(row: SqliteRow) -> History { History { id: row.get("id"), timestamp: Utc.timestamp_nanos(row.get("timestamp")), duration: row.get("duration"), exit: row.get("exit"), command: row.get("command"), cwd: row.get("cwd"), session: row.get("session"), hostname: row.get("hostname"), } } } #[async_trait] impl Database for Sqlite { async fn save(&mut self, h: &History) -> Result<()> { debug!("saving history to sqlite"); let mut tx = self.pool.begin().await?; Self::save_raw(&mut tx, h).await?; tx.commit().await?; Ok(()) } async fn save_bulk(&mut self, h: &[History]) -> Result<()> { debug!("saving history to sqlite"); let mut tx = self.pool.begin().await?; for i in h { Self::save_raw(&mut tx, i).await? } tx.commit().await?; Ok(()) } async fn load(&self, id: &str) -> Result { debug!("loading history item {}", id); let res = sqlx::query("select * from history where id = ?1") .bind(id) .map(Self::query_history) .fetch_one(&self.pool) .await?; Ok(res) } async fn update(&self, h: &History) -> Result<()> { debug!("updating sqlite history"); sqlx::query( "update history set timestamp = ?2, duration = ?3, exit = ?4, command = ?5, cwd = ?6, session = ?7, hostname = ?8 where id = ?1", ) .bind(h.id.as_str()) .bind(h.timestamp.timestamp_nanos()) .bind(h.duration) .bind(h.exit) .bind(h.command.as_str()) .bind(h.cwd.as_str()) .bind(h.session.as_str()) .bind(h.hostname.as_str()) .execute(&self.pool) .await?; Ok(()) } // make a unique list, that only shows the *newest* version of things async fn list(&self, max: Option, unique: bool) -> Result> { debug!("listing history"); // very likely vulnerable to SQL injection // however, this is client side, and only used by the client, on their // own data. They can just open the db file... // otherwise building the query is awkward let query = format!( "select * from history h {} order by timestamp desc {}", // inject the unique check if unique { "where timestamp = ( select max(timestamp) from history where h.command = history.command )" } else { "" }, // inject the limit if let Some(max) = max { format!("limit {}", max) } else { "".to_string() } ); let res = sqlx::query(query.as_str()) .map(Self::query_history) .fetch_all(&self.pool) .await?; Ok(res) } async fn range( &self, from: chrono::DateTime, to: chrono::DateTime, ) -> Result> { debug!("listing history from {:?} to {:?}", from, to); let res = sqlx::query( "select * from history where timestamp >= ?1 and timestamp <= ?2 order by timestamp asc", ) .bind(from) .bind(to) .map(Self::query_history) .fetch_all(&self.pool) .await?; Ok(res) } async fn first(&self) -> Result { let res = sqlx::query("select * from history where duration >= 0 order by timestamp asc limit 1") .map(Self::query_history) .fetch_one(&self.pool) .await?; Ok(res) } async fn last(&self) -> Result { let res = sqlx::query( "select * from history where duration >= 0 order by timestamp desc limit 1", ) .map(Self::query_history) .fetch_one(&self.pool) .await?; Ok(res) } async fn before(&self, timestamp: chrono::DateTime, count: i64) -> Result> { let res = sqlx::query( "select * from history where timestamp < ?1 order by timestamp desc limit ?2", ) .bind(timestamp.timestamp_nanos()) .bind(count) .map(Self::query_history) .fetch_all(&self.pool) .await?; Ok(res) } async fn history_count(&self) -> Result { let res: (i64,) = sqlx::query_as("select count(1) from history") .fetch_one(&self.pool) .await?; Ok(res.0) } async fn search( &self, limit: Option, search_mode: SearchMode, query: &str, ) -> Result> { let query = query.to_string().replace("*", "%"); // allow wildcard char let limit = limit.map_or("".to_owned(), |l| format!("limit {}", l)); let query = match search_mode { SearchMode::Prefix => query, SearchMode::FullText => format!("%{}", query), }; let res = sqlx::query( format!( "select * from history h where command like ?1 || '%' and timestamp = ( select max(timestamp) from history where h.command = history.command ) order by timestamp desc {}", limit.clone() ) .as_str(), ) .bind(query) .map(Self::query_history) .fetch_all(&self.pool) .await?; Ok(res) } async fn query_history(&self, query: &str) -> Result> { let res = sqlx::query(query) .map(Self::query_history) .fetch_all(&self.pool) .await?; Ok(res) } }