diff options
Diffstat (limited to 'crates/atuin-client/src')
29 files changed, 7831 insertions, 0 deletions
diff --git a/crates/atuin-client/src/api_client.rs b/crates/atuin-client/src/api_client.rs new file mode 100644 index 00000000..f31a796e --- /dev/null +++ b/crates/atuin-client/src/api_client.rs @@ -0,0 +1,415 @@ +use std::collections::HashMap; +use std::env; +use std::time::Duration; + +use eyre::{bail, Result}; +use reqwest::{ + header::{HeaderMap, AUTHORIZATION, USER_AGENT}, + Response, StatusCode, Url, +}; + +use atuin_common::{ + api::{ + AddHistoryRequest, ChangePasswordRequest, CountResponse, DeleteHistoryRequest, + ErrorResponse, LoginRequest, LoginResponse, MeResponse, RegisterResponse, StatusResponse, + SyncHistoryResponse, + }, + record::RecordStatus, +}; +use atuin_common::{ + api::{ATUIN_CARGO_VERSION, ATUIN_HEADER_VERSION, ATUIN_VERSION}, + record::{EncryptedData, HostId, Record, RecordIdx}, +}; + +use semver::Version; +use time::format_description::well_known::Rfc3339; +use time::OffsetDateTime; + +use crate::{history::History, sync::hash_str, utils::get_host_user}; + +static APP_USER_AGENT: &str = concat!("atuin/", env!("CARGO_PKG_VERSION"),); + +pub struct Client<'a> { + sync_addr: &'a str, + client: reqwest::Client, +} + +pub async fn register( + address: &str, + username: &str, + email: &str, + password: &str, +) -> Result<RegisterResponse> { + let mut map = HashMap::new(); + map.insert("username", username); + map.insert("email", email); + map.insert("password", password); + + let url = format!("{address}/user/{username}"); + let resp = reqwest::get(url).await?; + + if resp.status().is_success() { + bail!("username already in use"); + } + + let url = format!("{address}/register"); + let client = reqwest::Client::new(); + let resp = client + .post(url) + .header(USER_AGENT, APP_USER_AGENT) + .header(ATUIN_HEADER_VERSION, ATUIN_CARGO_VERSION) + .json(&map) + .send() + .await?; + + if !ensure_version(&resp)? { + bail!("could not register user due to version mismatch"); + } + + if !resp.status().is_success() { + let error = resp.json::<ErrorResponse>().await?; + bail!("failed to register user: {}", error.reason); + } + + let session = resp.json::<RegisterResponse>().await?; + Ok(session) +} + +pub async fn login(address: &str, req: LoginRequest) -> Result<LoginResponse> { + let url = format!("{address}/login"); + let client = reqwest::Client::new(); + + let resp = client + .post(url) + .header(USER_AGENT, APP_USER_AGENT) + .json(&req) + .send() + .await?; + + if !ensure_version(&resp)? { + bail!("could not login due to version mismatch"); + } + + if resp.status() != reqwest::StatusCode::OK { + let error = resp.json::<ErrorResponse>().await?; + bail!("invalid login details: {}", error.reason); + } + + let session = resp.json::<LoginResponse>().await?; + Ok(session) +} + +#[cfg(feature = "check-update")] +pub async fn latest_version() -> Result<Version> { + use atuin_common::api::IndexResponse; + + let url = "https://api.atuin.sh"; + let client = reqwest::Client::new(); + + let resp = client + .get(url) + .header(USER_AGENT, APP_USER_AGENT) + .send() + .await?; + + if resp.status() != reqwest::StatusCode::OK { + let error = resp.json::<ErrorResponse>().await?; + bail!("failed to check latest version: {}", error.reason); + } + + let index = resp.json::<IndexResponse>().await?; + let version = Version::parse(index.version.as_str())?; + + Ok(version) +} + +pub fn ensure_version(response: &Response) -> Result<bool> { + let version = response.headers().get(ATUIN_HEADER_VERSION); + + let version = if let Some(version) = version { + match version.to_str() { + Ok(v) => Version::parse(v), + Err(e) => bail!("failed to parse server version: {:?}", e), + } + } else { + // if there is no version header, then the newest this server can possibly be is 17.1.0 + Version::parse("17.1.0") + }?; + + // If the client is newer than the server + if version.major < ATUIN_VERSION.major { + println!("Atuin version mismatch! In order to successfully sync, the server needs to run a newer version of Atuin"); + println!("Client: {}", ATUIN_CARGO_VERSION); + println!("Server: {}", version); + + return Ok(false); + } + + Ok(true) +} + +async fn handle_resp_error(resp: Response) -> Result<Response> { + let status = resp.status(); + + if status == StatusCode::SERVICE_UNAVAILABLE { + bail!( + "Service unavailable: check https://status.atuin.sh (or get in touch with your host)" + ); + } + + if !status.is_success() { + if let Ok(error) = resp.json::<ErrorResponse>().await { + let reason = error.reason; + + if status.is_client_error() { + bail!("Could not fetch history, client error {status}: {reason}.") + } + + bail!("There was an error with the atuin sync service, server error {status}: {reason}.\nIf the problem persists, contact the host") + } + + bail!("There was an error with the atuin sync service: Status {status:?}.\nIf the problem persists, contact the host") + } + + Ok(resp) +} + +impl<'a> Client<'a> { + pub fn new( + sync_addr: &'a str, + session_token: &str, + connect_timeout: u64, + timeout: u64, + ) -> Result<Self> { + let mut headers = HeaderMap::new(); + headers.insert(AUTHORIZATION, format!("Token {session_token}").parse()?); + + // used for semver server check + headers.insert(ATUIN_HEADER_VERSION, ATUIN_CARGO_VERSION.parse()?); + + Ok(Client { + sync_addr, + client: reqwest::Client::builder() + .user_agent(APP_USER_AGENT) + .default_headers(headers) + .connect_timeout(Duration::new(connect_timeout, 0)) + .timeout(Duration::new(timeout, 0)) + .build()?, + }) + } + + pub async fn count(&self) -> Result<i64> { + let url = format!("{}/sync/count", self.sync_addr); + let url = Url::parse(url.as_str())?; + + let resp = self.client.get(url).send().await?; + let resp = handle_resp_error(resp).await?; + + if !ensure_version(&resp)? { + bail!("could not sync due to version mismatch"); + } + + if resp.status() != StatusCode::OK { + bail!("failed to get count (are you logged in?)"); + } + + let count = resp.json::<CountResponse>().await?; + + Ok(count.count) + } + + pub async fn status(&self) -> Result<StatusResponse> { + let url = format!("{}/sync/status", self.sync_addr); + let url = Url::parse(url.as_str())?; + + let resp = self.client.get(url).send().await?; + let resp = handle_resp_error(resp).await?; + + if !ensure_version(&resp)? { + bail!("could not sync due to version mismatch"); + } + + let status = resp.json::<StatusResponse>().await?; + + Ok(status) + } + + pub async fn me(&self) -> Result<MeResponse> { + let url = format!("{}/api/v0/me", self.sync_addr); + let url = Url::parse(url.as_str())?; + + let resp = self.client.get(url).send().await?; + let resp = handle_resp_error(resp).await?; + + let status = resp.json::<MeResponse>().await?; + + Ok(status) + } + + pub async fn get_history( + &self, + sync_ts: OffsetDateTime, + history_ts: OffsetDateTime, + host: Option<String>, + ) -> Result<SyncHistoryResponse> { + let host = host.unwrap_or_else(|| hash_str(&get_host_user())); + + let url = format!( + "{}/sync/history?sync_ts={}&history_ts={}&host={}", + self.sync_addr, + urlencoding::encode(sync_ts.format(&Rfc3339)?.as_str()), + urlencoding::encode(history_ts.format(&Rfc3339)?.as_str()), + host, + ); + + let resp = self.client.get(url).send().await?; + let resp = handle_resp_error(resp).await?; + + let history = resp.json::<SyncHistoryResponse>().await?; + Ok(history) + } + + pub async fn post_history(&self, history: &[AddHistoryRequest]) -> Result<()> { + let url = format!("{}/history", self.sync_addr); + let url = Url::parse(url.as_str())?; + + let resp = self.client.post(url).json(history).send().await?; + handle_resp_error(resp).await?; + + Ok(()) + } + + pub async fn delete_history(&self, h: History) -> Result<()> { + let url = format!("{}/history", self.sync_addr); + let url = Url::parse(url.as_str())?; + + let resp = self + .client + .delete(url) + .json(&DeleteHistoryRequest { + client_id: h.id.to_string(), + }) + .send() + .await?; + + handle_resp_error(resp).await?; + + Ok(()) + } + + pub async fn delete_store(&self) -> Result<()> { + let url = format!("{}/api/v0/store", self.sync_addr); + let url = Url::parse(url.as_str())?; + + let resp = self.client.delete(url).send().await?; + + handle_resp_error(resp).await?; + + Ok(()) + } + + pub async fn post_records(&self, records: &[Record<EncryptedData>]) -> Result<()> { + let url = format!("{}/api/v0/record", self.sync_addr); + let url = Url::parse(url.as_str())?; + + debug!("uploading {} records to {url}", records.len()); + + let resp = self.client.post(url).json(records).send().await?; + handle_resp_error(resp).await?; + + Ok(()) + } + + pub async fn next_records( + &self, + host: HostId, + tag: String, + start: RecordIdx, + count: u64, + ) -> Result<Vec<Record<EncryptedData>>> { + debug!( + "fetching record/s from host {}/{}/{}", + host.0.to_string(), + tag, + start + ); + + let url = format!( + "{}/api/v0/record/next?host={}&tag={}&count={}&start={}", + self.sync_addr, host.0, tag, count, start + ); + + let url = Url::parse(url.as_str())?; + + let resp = self.client.get(url).send().await?; + let resp = handle_resp_error(resp).await?; + + let records = resp.json::<Vec<Record<EncryptedData>>>().await?; + + Ok(records) + } + + pub async fn record_status(&self) -> Result<RecordStatus> { + let url = format!("{}/api/v0/record", self.sync_addr); + let url = Url::parse(url.as_str())?; + + let resp = self.client.get(url).send().await?; + let resp = handle_resp_error(resp).await?; + + if !ensure_version(&resp)? { + bail!("could not sync records due to version mismatch"); + } + + let index = resp.json().await?; + + debug!("got remote index {:?}", index); + + Ok(index) + } + + pub async fn delete(&self) -> Result<()> { + let url = format!("{}/account", self.sync_addr); + let url = Url::parse(url.as_str())?; + + let resp = self.client.delete(url).send().await?; + + if resp.status() == 403 { + bail!("invalid login details"); + } else if resp.status() == 200 { + Ok(()) + } else { + bail!("Unknown error"); + } + } + + pub async fn change_password( + &self, + current_password: String, + new_password: String, + ) -> Result<()> { + let url = format!("{}/account/password", self.sync_addr); + let url = Url::parse(url.as_str())?; + + let resp = self + .client + .patch(url) + .json(&ChangePasswordRequest { + current_password, + new_password, + }) + .send() + .await?; + + dbg!(&resp); + + if resp.status() == 401 { + bail!("current password is incorrect") + } else if resp.status() == 403 { + bail!("invalid login details"); + } else if resp.status() == 200 { + Ok(()) + } else { + bail!("Unknown error"); + } + } +} diff --git a/crates/atuin-client/src/database.rs b/crates/atuin-client/src/database.rs new file mode 100644 index 00000000..7faa3802 --- /dev/null +++ b/crates/atuin-client/src/database.rs @@ -0,0 +1,1128 @@ +use std::{ + borrow::Cow, + env, + path::{Path, PathBuf}, + str::FromStr, + time::Duration, +}; + +use async_trait::async_trait; +use atuin_common::utils; +use fs_err as fs; +use itertools::Itertools; +use rand::{distributions::Alphanumeric, Rng}; +use sql_builder::{bind::Bind, esc, quote, SqlBuilder, SqlName}; +use sqlx::{ + sqlite::{ + SqliteConnectOptions, SqliteJournalMode, SqlitePool, SqlitePoolOptions, SqliteRow, + SqliteSynchronous, + }, + Result, Row, +}; +use time::OffsetDateTime; + +use crate::{ + history::{HistoryId, HistoryStats}, + utils::get_host_user, +}; + +use super::{ + history::History, + ordering, + settings::{FilterMode, SearchMode, Settings}, +}; + +pub struct Context { + pub session: String, + pub cwd: String, + pub hostname: String, + pub host_id: String, + pub git_root: Option<PathBuf>, +} + +#[derive(Default, Clone)] +pub struct OptFilters { + pub exit: Option<i64>, + pub exclude_exit: Option<i64>, + pub cwd: Option<String>, + pub exclude_cwd: Option<String>, + pub before: Option<String>, + pub after: Option<String>, + pub limit: Option<i64>, + pub offset: Option<i64>, + pub reverse: bool, +} + +pub fn current_context() -> Context { + let Ok(session) = env::var("ATUIN_SESSION") else { + eprintln!("ERROR: Failed to find $ATUIN_SESSION in the environment. Check that you have correctly set up your shell."); + std::process::exit(1); + }; + let hostname = get_host_user(); + let cwd = utils::get_current_dir(); + let host_id = Settings::host_id().expect("failed to load host ID"); + let git_root = utils::in_git_repo(cwd.as_str()); + + Context { + session, + hostname, + cwd, + git_root, + host_id: host_id.0.as_simple().to_string(), + } +} + +#[async_trait] +pub trait Database: Send + Sync + 'static { + async fn save(&self, h: &History) -> Result<()>; + async fn save_bulk(&self, h: &[History]) -> Result<()>; + + async fn load(&self, id: &str) -> Result<Option<History>>; + async fn list( + &self, + filters: &[FilterMode], + context: &Context, + max: Option<usize>, + unique: bool, + include_deleted: bool, + ) -> Result<Vec<History>>; + async fn range(&self, from: OffsetDateTime, to: OffsetDateTime) -> Result<Vec<History>>; + + async fn update(&self, h: &History) -> Result<()>; + async fn history_count(&self, include_deleted: bool) -> Result<i64>; + + async fn last(&self) -> Result<Option<History>>; + async fn before(&self, timestamp: OffsetDateTime, count: i64) -> Result<Vec<History>>; + + async fn delete(&self, h: History) -> Result<()>; + async fn delete_rows(&self, ids: &[HistoryId]) -> Result<()>; + async fn deleted(&self) -> Result<Vec<History>>; + + // Yes I know, it's a lot. + // Could maybe break it down to a searchparams struct or smth but that feels a little... pointless. + // Been debating maybe a DSL for search? eg "before:time limit:1 the query" + #[allow(clippy::too_many_arguments)] + async fn search( + &self, + search_mode: SearchMode, + filter: FilterMode, + context: &Context, + query: &str, + filter_options: OptFilters, + ) -> Result<Vec<History>>; + + async fn query_history(&self, query: &str) -> Result<Vec<History>>; + + async fn all_with_count(&self) -> Result<Vec<(History, i32)>>; + + async fn stats(&self, h: &History) -> Result<HistoryStats>; +} + +// Intended for use on a developer machine and not a sync server. +// TODO: implement IntoIterator +pub struct Sqlite { + pub pool: SqlitePool, +} + +impl Sqlite { + pub async fn new(path: impl AsRef<Path>, timeout: f64) -> Result<Self> { + let path = path.as_ref(); + debug!("opening sqlite database at {:?}", path); + + let create = !path.exists(); + if create { + if let Some(dir) = path.parent() { + fs::create_dir_all(dir)?; + } + } + + let opts = SqliteConnectOptions::from_str(path.as_os_str().to_str().unwrap())? + .journal_mode(SqliteJournalMode::Wal) + .optimize_on_close(true, None) + .synchronous(SqliteSynchronous::Normal) + .with_regexp() + .create_if_missing(true); + + let pool = SqlitePoolOptions::new() + .acquire_timeout(Duration::from_secs_f64(timeout)) + .connect_with(opts) + .await?; + + Self::setup_db(&pool).await?; + + Ok(Self { pool }) + } + + async fn setup_db(pool: &SqlitePool) -> Result<()> { + debug!("running sqlite database setup"); + + sqlx::migrate!("./migrations").run(pool).await?; + + Ok(()) + } + + async fn save_raw(tx: &mut sqlx::Transaction<'_, sqlx::Sqlite>, h: &History) -> Result<()> { + sqlx::query( + "insert or ignore into history(id, timestamp, duration, exit, command, cwd, session, hostname, deleted_at) + values(?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9)", + ) + .bind(h.id.0.as_str()) + .bind(h.timestamp.unix_timestamp_nanos() as i64) + .bind(h.duration) + .bind(h.exit) + .bind(h.command.as_str()) + .bind(h.cwd.as_str()) + .bind(h.session.as_str()) + .bind(h.hostname.as_str()) + .bind(h.deleted_at.map(|t|t.unix_timestamp_nanos() as i64)) + .execute(&mut **tx) + .await?; + + Ok(()) + } + + async fn delete_row_raw( + tx: &mut sqlx::Transaction<'_, sqlx::Sqlite>, + id: HistoryId, + ) -> Result<()> { + sqlx::query("delete from history where id = ?1") + .bind(id.0.as_str()) + .execute(&mut **tx) + .await?; + + Ok(()) + } + + fn query_history(row: SqliteRow) -> History { + let deleted_at: Option<i64> = row.get("deleted_at"); + + History::from_db() + .id(row.get("id")) + .timestamp( + OffsetDateTime::from_unix_timestamp_nanos(row.get::<i64, _>("timestamp") as i128) + .unwrap(), + ) + .duration(row.get("duration")) + .exit(row.get("exit")) + .command(row.get("command")) + .cwd(row.get("cwd")) + .session(row.get("session")) + .hostname(row.get("hostname")) + .deleted_at( + deleted_at.and_then(|t| OffsetDateTime::from_unix_timestamp_nanos(t as i128).ok()), + ) + .build() + .into() + } +} + +#[async_trait] +impl Database for Sqlite { + async fn save(&self, h: &History) -> Result<()> { + debug!("saving history to sqlite"); + let mut tx = self.pool.begin().await?; + Self::save_raw(&mut tx, h).await?; + tx.commit().await?; + + Ok(()) + } + + async fn save_bulk(&self, h: &[History]) -> Result<()> { + debug!("saving history to sqlite"); + + let mut tx = self.pool.begin().await?; + + for i in h { + Self::save_raw(&mut tx, i).await?; + } + + tx.commit().await?; + + Ok(()) + } + + async fn load(&self, id: &str) -> Result<Option<History>> { + debug!("loading history item {}", id); + + let res = sqlx::query("select * from history where id = ?1") + .bind(id) + .map(Self::query_history) + .fetch_optional(&self.pool) + .await?; + + Ok(res) + } + + async fn update(&self, h: &History) -> Result<()> { + debug!("updating sqlite history"); + + sqlx::query( + "update history + set timestamp = ?2, duration = ?3, exit = ?4, command = ?5, cwd = ?6, session = ?7, hostname = ?8, deleted_at = ?9 + where id = ?1", + ) + .bind(h.id.0.as_str()) + .bind(h.timestamp.unix_timestamp_nanos() as i64) + .bind(h.duration) + .bind(h.exit) + .bind(h.command.as_str()) + .bind(h.cwd.as_str()) + .bind(h.session.as_str()) + .bind(h.hostname.as_str()) + .bind(h.deleted_at.map(|t|t.unix_timestamp_nanos() as i64)) + .execute(&self.pool) + .await?; + + Ok(()) + } + + // make a unique list, that only shows the *newest* version of things + async fn list( + &self, + filters: &[FilterMode], + context: &Context, + max: Option<usize>, + unique: bool, + include_deleted: bool, + ) -> Result<Vec<History>> { + debug!("listing history"); + + let mut query = SqlBuilder::select_from(SqlName::new("history").alias("h").baquoted()); + query.field("*").order_desc("timestamp"); + if !include_deleted { + query.and_where_is_null("deleted_at"); + } + + let git_root = if let Some(git_root) = context.git_root.clone() { + git_root.to_str().unwrap_or("/").to_string() + } else { + context.cwd.clone() + }; + + for filter in filters { + match filter { + FilterMode::Global => &mut query, + FilterMode::Host => query.and_where_eq("hostname", quote(&context.hostname)), + FilterMode::Session => query.and_where_eq("session", quote(&context.session)), + FilterMode::Directory => query.and_where_eq("cwd", quote(&context.cwd)), + FilterMode::Workspace => query.and_where_like_left("cwd", &git_root), + }; + } + + if unique { + query.group_by("command").having("max(timestamp)"); + } + + if let Some(max) = max { + query.limit(max); + } + + let query = query.sql().expect("bug in list query. please report"); + + let res = sqlx::query(&query) + .map(Self::query_history) + .fetch_all(&self.pool) + .await?; + + Ok(res) + } + + async fn range(&self, from: OffsetDateTime, to: OffsetDateTime) -> Result<Vec<History>> { + debug!("listing history from {:?} to {:?}", from, to); + + let res = sqlx::query( + "select * from history where timestamp >= ?1 and timestamp <= ?2 order by timestamp asc", + ) + .bind(from.unix_timestamp_nanos() as i64) + .bind(to.unix_timestamp_nanos() as i64) + .map(Self::query_history) + .fetch_all(&self.pool) + .await?; + + Ok(res) + } + + async fn last(&self) -> Result<Option<History>> { + let res = sqlx::query( + "select * from history where duration >= 0 order by timestamp desc limit 1", + ) + .map(Self::query_history) + .fetch_optional(&self.pool) + .await?; + + Ok(res) + } + + async fn before(&self, timestamp: OffsetDateTime, count: i64) -> Result<Vec<History>> { + let res = sqlx::query( + "select * from history where timestamp < ?1 order by timestamp desc limit ?2", + ) + .bind(timestamp.unix_timestamp_nanos() as i64) + .bind(count) + .map(Self::query_history) + .fetch_all(&self.pool) + .await?; + + Ok(res) + } + + async fn deleted(&self) -> Result<Vec<History>> { + let res = sqlx::query("select * from history where deleted_at is not null") + .map(Self::query_history) + .fetch_all(&self.pool) + .await?; + + Ok(res) + } + + async fn history_count(&self, include_deleted: bool) -> Result<i64> { + let query = if include_deleted { + "select count(1) from history" + } else { + "select count(1) from history where deleted_at is null" + }; + + let res: (i64,) = sqlx::query_as(query).fetch_one(&self.pool).await?; + Ok(res.0) + } + + async fn search( + &self, + search_mode: SearchMode, + filter: FilterMode, + context: &Context, + query: &str, + filter_options: OptFilters, + ) -> Result<Vec<History>> { + let mut sql = SqlBuilder::select_from("history"); + + sql.group_by("command").having("max(timestamp)"); + + if let Some(limit) = filter_options.limit { + sql.limit(limit); + } + + if let Some(offset) = filter_options.offset { |