summaryrefslogtreecommitdiffstats
path: root/atuin-client/src/database.rs
diff options
context:
space:
mode:
Diffstat (limited to 'atuin-client/src/database.rs')
-rw-r--r--atuin-client/src/database.rs355
1 files changed, 140 insertions, 215 deletions
diff --git a/atuin-client/src/database.rs b/atuin-client/src/database.rs
index 0855359b..754a0ecf 100644
--- a/atuin-client/src/database.rs
+++ b/atuin-client/src/database.rs
@@ -1,44 +1,48 @@
-use chrono::prelude::*;
-use chrono::Utc;
use std::path::Path;
+use std::str::FromStr;
+
+use async_trait::async_trait;
+use chrono::Utc;
-use eyre::{eyre, Result};
+use eyre::Result;
-use rusqlite::{params, Connection};
-use rusqlite::{Params, Transaction};
+use sqlx::sqlite::{SqliteConnectOptions, SqliteJournalMode, SqlitePool, SqlitePoolOptions};
use super::history::History;
+#[async_trait]
pub trait Database {
- fn save(&mut self, h: &History) -> Result<()>;
- fn save_bulk(&mut self, h: &[History]) -> Result<()>;
+ async fn save(&mut self, h: &History) -> Result<()>;
+ async fn save_bulk(&mut self, h: &[History]) -> Result<()>;
- fn load(&self, id: &str) -> Result<History>;
- fn list(&self, max: Option<usize>, unique: bool) -> Result<Vec<History>>;
- fn range(&self, from: chrono::DateTime<Utc>, to: chrono::DateTime<Utc>)
- -> Result<Vec<History>>;
+ async fn load(&self, id: &str) -> Result<History>;
+ async fn list(&self, max: Option<usize>, unique: bool) -> Result<Vec<History>>;
+ async fn range(
+ &self,
+ from: chrono::DateTime<Utc>,
+ to: chrono::DateTime<Utc>,
+ ) -> Result<Vec<History>>;
- fn query(&self, query: &str, params: impl Params) -> Result<Vec<History>>;
- fn update(&self, h: &History) -> Result<()>;
- fn history_count(&self) -> Result<i64>;
+ async fn update(&self, h: &History) -> Result<()>;
+ async fn history_count(&self) -> Result<i64>;
- fn first(&self) -> Result<History>;
- fn last(&self) -> Result<History>;
- fn before(&self, timestamp: chrono::DateTime<Utc>, count: i64) -> Result<Vec<History>>;
+ async fn first(&self) -> Result<History>;
+ async fn last(&self) -> Result<History>;
+ async fn before(&self, timestamp: chrono::DateTime<Utc>, count: i64) -> Result<Vec<History>>;
- fn prefix_search(&self, query: &str) -> Result<Vec<History>>;
+ async fn search(&self, limit: Option<i64>, query: &str) -> Result<Vec<History>>;
- fn search(&self, cwd: Option<String>, exit: Option<i64>, query: &str) -> Result<Vec<History>>;
+ async fn query_history(&self, query: &str) -> Result<Vec<History>>;
}
// Intended for use on a developer machine and not a sync server.
// TODO: implement IntoIterator
pub struct Sqlite {
- conn: Connection,
+ pool: SqlitePool,
}
impl Sqlite {
- pub fn new(path: impl AsRef<Path>) -> Result<Self> {
+ pub async fn new(path: impl AsRef<Path>) -> Result<Self> {
let path = path.as_ref();
debug!("opening sqlite database at {:?}", path);
@@ -49,137 +53,106 @@ impl Sqlite {
}
}
- let conn = Connection::open(path)?;
+ let opts = SqliteConnectOptions::from_str(path.as_os_str().to_str().unwrap())?
+ .journal_mode(SqliteJournalMode::Wal)
+ .create_if_missing(true);
+
+ let pool = SqlitePoolOptions::new().connect_with(opts).await?;
- Self::setup_db(&conn)?;
+ Self::setup_db(&pool).await?;
- Ok(Self { conn })
+ Ok(Self { pool })
}
- fn setup_db(conn: &Connection) -> Result<()> {
+ async fn setup_db(pool: &SqlitePool) -> Result<()> {
debug!("running sqlite database setup");
- conn.execute(
- "create table if not exists history (
- id text primary key,
- timestamp integer not null,
- duration integer not null,
- exit integer not null,
- command text not null,
- cwd text not null,
- session text not null,
- hostname text not null,
-
- unique(timestamp, cwd, command)
- )",
- [],
- )?;
-
- conn.execute(
- "create table if not exists history_encrypted (
- id text primary key,
- data blob not null
- )",
- [],
- )?;
-
- conn.execute(
- "create index if not exists idx_history_timestamp on history(timestamp)",
- [],
- )?;
-
- conn.execute(
- "create index if not exists idx_history_command on history(command)",
- [],
- )?;
+ sqlx::migrate!("./migrations").run(pool).await?;
Ok(())
}
- fn save_raw(tx: &Transaction, h: &History) -> Result<()> {
- tx.execute(
- "insert or ignore into history (
- id,
- timestamp,
- duration,
- exit,
- command,
- cwd,
- session,
- hostname
- ) values (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8)",
- params![
- h.id,
- h.timestamp.timestamp_nanos(),
- h.duration,
- h.exit,
- h.command,
- h.cwd,
- h.session,
- h.hostname
- ],
- )?;
+ async fn save_raw(tx: &mut sqlx::Transaction<'_, sqlx::Sqlite>, h: &History) -> Result<()> {
+ sqlx::query(
+ "insert or ignore into history(id, timestamp, duration, exit, command, cwd, session, hostname)
+ values(?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8)",
+ )
+ .bind(h.id.as_str())
+ .bind(h.timestamp.to_rfc3339())
+ .bind(h.duration)
+ .bind(h.exit)
+ .bind(h.command.as_str())
+ .bind(h.cwd.as_str())
+ .bind(h.session.as_str())
+ .bind(h.hostname.as_str())
+ .execute(tx)
+ .await?;
Ok(())
}
}
+#[async_trait]
impl Database for Sqlite {
- fn save(&mut self, h: &History) -> Result<()> {
+ async fn save(&mut self, h: &History) -> Result<()> {
debug!("saving history to sqlite");
- let tx = self.conn.transaction()?;
- Self::save_raw(&tx, h)?;
- tx.commit()?;
+ let mut tx = self.pool.begin().await?;
+ Self::save_raw(&mut tx, h).await?;
+ tx.commit().await?;
Ok(())
}
- fn save_bulk(&mut self, h: &[History]) -> Result<()> {
+ async fn save_bulk(&mut self, h: &[History]) -> Result<()> {
debug!("saving history to sqlite");
- let tx = self.conn.transaction()?;
+ let mut tx = self.pool.begin().await?;
+
for i in h {
- Self::save_raw(&tx, i)?
+ Self::save_raw(&mut tx, i).await?
}
- tx.commit()?;
+
+ tx.commit().await?;
Ok(())
}
- fn load(&self, id: &str) -> Result<History> {
+ async fn load(&self, id: &str) -> Result<History> {
debug!("loading history item {}", id);
- let history = self.query(
- "select id, timestamp, duration, exit, command, cwd, session, hostname from history
- where id = ?1 limit 1",
- &[id],
- )?;
+ let res = sqlx::query_as::<_, History>("select * from history where id = ?1")
+ .bind(id)
+ .fetch_one(&self.pool)
+ .await?;
- if history.is_empty() {
- return Err(eyre!("could not find history with id {}", id));
- }
-
- let history = history[0].clone();
-
- Ok(history)
+ Ok(res)
}
- fn update(&self, h: &History) -> Result<()> {
+ async fn update(&self, h: &History) -> Result<()> {
debug!("updating sqlite history");
- self.conn.execute(
+ sqlx::query(
"update history
set timestamp = ?2, duration = ?3, exit = ?4, command = ?5, cwd = ?6, session = ?7, hostname = ?8
where id = ?1",
- params![h.id, h.timestamp.timestamp_nanos(), h.duration, h.exit, h.command, h.cwd, h.session, h.hostname],
- )?;
+ )
+ .bind(h.id.as_str())
+ .bind(h.timestamp.to_rfc3339())
+ .bind(h.duration)
+ .bind(h.exit)
+ .bind(h.command.as_str())
+ .bind(h.cwd.as_str())
+ .bind(h.session.as_str())
+ .bind(h.hostname.as_str())
+ .execute(&self.pool)
+ .await?;
Ok(())
}
// make a unique list, that only shows the *newest* version of things
- fn list(&self, max: Option<usize>, unique: bool) -> Result<Vec<History>> {
+ async fn list(&self, max: Option<usize>, unique: bool) -> Result<Vec<History>> {
debug!("listing history");
// very likely vulnerable to SQL injection
@@ -208,144 +181,96 @@ impl Database for Sqlite {
}
);
- let history = self.query(query.as_str(), params![])?;
+ let res = sqlx::query_as::<_, History>(query.as_str())
+ .fetch_all(&self.pool)
+ .await?;
- Ok(history)
+ Ok(res)
}
- fn range(
+ async fn range(
&self,
from: chrono::DateTime<Utc>,
to: chrono::DateTime<Utc>,
) -> Result<Vec<History>> {
debug!("listing history from {:?} to {:?}", from, to);
- let mut stmt = self.conn.prepare(
- "SELECT * FROM history where timestamp >= ?1 and timestamp <= ?2 order by timestamp asc",
- )?;
-
- let history_iter = stmt.query_map(
- params![from.timestamp_nanos(), to.timestamp_nanos()],
- |row| history_from_sqlite_row(None, row),
- )?;
+ let res = sqlx::query_as::<_, History>(
+ "select * from history where timestamp >= ?1 and timestamp <= ?2 order by timestamp asc",
+ )
+ .bind(from)
+ .bind(to)
+ .fetch_all(&self.pool)
+ .await?;
- Ok(history_iter.filter_map(Result::ok).collect())
+ Ok(res)
}
- fn first(&self) -> Result<History> {
- let mut stmt = self
- .conn
- .prepare("SELECT * FROM history order by timestamp asc limit 1")?;
-
- let history = stmt.query_row(params![], |row| history_from_sqlite_row(None, row))?;
+ async fn first(&self) -> Result<History> {
+ let res = sqlx::query_as::<_, History>(
+ "select * from history where duration >= 0 order by timestamp asc limit 1",
+ )
+ .fetch_one(&self.pool)
+ .await?;
- Ok(history)
+ Ok(res)
}
- fn last(&self) -> Result<History> {
- let mut stmt = self
- .conn
- .prepare("SELECT * FROM history where duration >= 0 order by timestamp desc limit 1")?;
-
- let history = stmt.query_row(params![], |row| history_from_sqlite_row(None, row))?;
+ async fn last(&self) -> Result<History> {
+ let res = sqlx::query_as::<_, History>(
+ "select * from history where duration >= 0 order by timestamp desc limit 1",
+ )
+ .fetch_one(&self.pool)
+ .await?;
- Ok(history)
+ Ok(res)
}
- fn before(&self, timestamp: chrono::DateTime<Utc>, count: i64) -> Result<Vec<History>> {
- let mut stmt = self
- .conn
- .prepare("SELECT * FROM history where timestamp < ? order by timestamp desc limit ?")?;
-
- let history_iter = stmt.query_map(params![timestamp.timestamp_nanos(), count], |row| {
- history_from_sqlite_row(None, row)
- })?;
+ async fn before(&self, timestamp: chrono::DateTime<Utc>, count: i64) -> Result<Vec<History>> {
+ let res = sqlx::query_as::<_, History>(
+ "select * from history where timestamp < ?1 order by timestamp desc limit ?2",
+ )
+ .bind(timestamp)
+ .bind(count)
+ .fetch_all(&self.pool)
+ .await?;
- Ok(history_iter.filter_map(Result::ok).collect())
+ Ok(res)
}
- fn query(&self, query: &str, params: impl Params) -> Result<Vec<History>> {
- let mut stmt = self.conn.prepare(query)?;
+ async fn history_count(&self) -> Result<i64> {
+ let res: (i64,) = sqlx::query_as("select count(1) from history")
+ .fetch_one(&self.pool)
+ .await?;
- let history_iter = stmt.query_map(params, |row| history_from_sqlite_row(None, row))?;
-
- Ok(history_iter.filter_map(Result::ok).collect())
+ Ok(res.0)
}
- fn prefix_search(&self, query: &str) -> Result<Vec<History>> {
+ async fn search(&self, limit: Option<i64>, query: &str) -> Result<Vec<History>> {
let query = query.to_string().replace("*", "%"); // allow wildcard char
+ let limit = limit.map_or("".to_owned(), |l| format!("limit {}", l));
- self.query(
- "select * from history h
- where command like ?1 || '%'
- and timestamp = (
- select max(timestamp) from history
- where h.command = history.command
- )
- order by timestamp desc limit 200",
- &[query.as_str()],
+ let res = sqlx::query_as::<_, History>(
+ format!(
+ "select * from history
+ where command like ?1 || '%'
+ order by timestamp desc {}",
+ limit.clone()
+ )
+ .as_str(),
)
- }
-
- fn history_count(&self) -> Result<i64> {
- let res: i64 =
- self.conn
- .query_row_and_then("select count(1) from history;", params![], |row| row.get(0))?;
+ .bind(query)
+ .fetch_all(&self.pool)
+ .await?;
Ok(res)
}
- fn search(&self, cwd: Option<String>, exit: Option<i64>, query: &str) -> Result<Vec<History>> {
- match (cwd, exit) {
- (Some(cwd), Some(exit)) => self.query(
- "select * from history
- where command like ?1 || '%'
- and cwd = ?2
- and exit = ?3
- order by timestamp asc limit 1000",
- &[query, cwd.as_str(), exit.to_string().as_str()],
- ),
- (Some(cwd), None) => self.query(
- "select * from history
- where command like ?1 || '%'
- and cwd = ?2
- order by timestamp asc limit 1000",
- &[query, cwd.as_str()],
- ),
- (None, Some(exit)) => self.query(
- "select * from history
- where command like ?1 || '%'
- and exit = ?2
- order by timestamp asc limit 1000",
- &[query, exit.to_string().as_str()],
- ),
- (None, None) => self.query(
- "select * from history
- where command like ?1 || '%'
- order by timestamp asc limit 1000",
- &[query],
- ),
- }
- }
-}
+ async fn query_history(&self, query: &str) -> Result<Vec<History>> {
+ let res = sqlx::query_as::<_, History>(query)
+ .fetch_all(&self.pool)
+ .await?;
-fn history_from_sqlite_row(
- id: Option<String>,
- row: &rusqlite::Row,
-) -> Result<History, rusqlite::Error> {
- let id = match id {
- Some(id) => id,
- None => row.get(0)?,
- };
-
- Ok(History {
- id,
- timestamp: Utc.timestamp_nanos(row.get(1)?),
- duration: row.get(2)?,
- exit: row.get(3)?,
- command: row.get(4)?,
- cwd: row.get(5)?,
- session: row.get(6)?,
- hostname: row.get(7)?,
- })
+ Ok(res)
+ }
}