summaryrefslogtreecommitdiffstats
path: root/atuin-client
diff options
context:
space:
mode:
authorEllie Huxtable <e@elm.sh>2021-04-21 18:13:51 +0100
committerEllie Huxtable <e@elm.sh>2021-04-21 21:26:44 +0100
commit4a50ce366639ca9dac7324d6a47d6a0e6c7fccdf (patch)
tree7ffd8848f675e1377f750cc0757768d074a5ac05 /atuin-client
parenta9b117aad7e6bd09c7ea188258924dc02855db05 (diff)
Bugfixes, show time ago, perf improvements
Also allow unique listing and more ergonomic cwd usage
Diffstat (limited to 'atuin-client')
-rw-r--r--atuin-client/Cargo.toml2
-rw-r--r--atuin-client/src/api_client.rs120
-rw-r--r--atuin-client/src/database.rs113
-rw-r--r--atuin-client/src/encryption.rs39
-rw-r--r--atuin-client/src/history.rs2
-rw-r--r--atuin-client/src/settings.rs5
-rw-r--r--atuin-client/src/sync.rs20
7 files changed, 259 insertions, 42 deletions
diff --git a/atuin-client/Cargo.toml b/atuin-client/Cargo.toml
index 9d639d18..09cf9c47 100644
--- a/atuin-client/Cargo.toml
+++ b/atuin-client/Cargo.toml
@@ -26,7 +26,7 @@ serde = "1.0.125"
serde_json = "1.0.64"
rmp-serde = "0.15.4"
sodiumoxide = "0.2.6"
-reqwest = { version = "0.11", features = ["blocking", "json"] }
+reqwest = { version = "0.11", features = ["blocking", "json", "rustls-tls"], default-features = false }
base64 = "0.13.0"
parse_duration = "2.1.1"
rand = "0.8.3"
diff --git a/atuin-client/src/api_client.rs b/atuin-client/src/api_client.rs
index db2802c3..a8ce7b27 100644
--- a/atuin-client/src/api_client.rs
+++ b/atuin-client/src/api_client.rs
@@ -1,15 +1,24 @@
+use std::collections::HashMap;
+
use chrono::Utc;
-use eyre::Result;
-use reqwest::header::{HeaderMap, AUTHORIZATION};
-use reqwest::Url;
+use eyre::{eyre, Result};
+use reqwest::header::{HeaderMap, AUTHORIZATION, USER_AGENT};
+use reqwest::{StatusCode, Url};
use sodiumoxide::crypto::secretbox;
-use atuin_common::api::{AddHistoryRequest, CountResponse, SyncHistoryResponse};
+use atuin_common::api::{
+ AddHistoryRequest, CountResponse, LoginResponse, RegisterResponse, SyncHistoryResponse,
+};
use atuin_common::utils::hash_str;
-use crate::encryption::decrypt;
+use crate::encryption::{decode_key, decrypt};
use crate::history::History;
+const VERSION: &str = env!("CARGO_PKG_VERSION");
+
+// TODO: remove all references to the encryption key from this
+// It should be handled *elsewhere*
+
pub struct Client<'a> {
sync_addr: &'a str,
token: &'a str,
@@ -17,14 +26,70 @@ pub struct Client<'a> {
client: reqwest::Client,
}
+pub fn register(
+ address: &str,
+ username: &str,
+ email: &str,
+ password: &str,
+) -> Result<RegisterResponse> {
+ let mut map = HashMap::new();
+ map.insert("username", username);
+ map.insert("email", email);
+ map.insert("password", password);
+
+ let url = format!("{}/user/{}", address, username);
+ let resp = reqwest::blocking::get(url)?;
+
+ if resp.status().is_success() {
+ return Err(eyre!("username already in use"));
+ }
+
+ let url = format!("{}/register", address);
+ let client = reqwest::blocking::Client::new();
+ let resp = client
+ .post(url)
+ .header(USER_AGENT, format!("atuin/{}", VERSION))
+ .json(&map)
+ .send()?;
+
+ if !resp.status().is_success() {
+ return Err(eyre!("failed to register user"));
+ }
+
+ let session = resp.json::<RegisterResponse>()?;
+ Ok(session)
+}
+
+pub fn login(address: &str, username: &str, password: &str) -> Result<LoginResponse> {
+ let mut map = HashMap::new();
+ map.insert("username", username);
+ map.insert("password", password);
+
+ let url = format!("{}/login", address);
+ let client = reqwest::blocking::Client::new();
+
+ let resp = client
+ .post(url)
+ .header(USER_AGENT, format!("atuin/{}", VERSION))
+ .json(&map)
+ .send()?;
+
+ if resp.status() != reqwest::StatusCode::OK {
+ return Err(eyre!("invalid login details"));
+ }
+
+ let session = resp.json::<LoginResponse>()?;
+ Ok(session)
+}
+
impl<'a> Client<'a> {
- pub fn new(sync_addr: &'a str, token: &'a str, key: secretbox::Key) -> Self {
- Client {
+ pub fn new(sync_addr: &'a str, token: &'a str, key: String) -> Result<Self> {
+ Ok(Client {
sync_addr,
token,
- key,
+ key: decode_key(key)?,
client: reqwest::Client::new(),
- }
+ })
}
pub async fn count(&self) -> Result<i64> {
@@ -36,7 +101,17 @@ impl<'a> Client<'a> {
let mut headers = HeaderMap::new();
headers.insert(AUTHORIZATION, token);
- let resp = self.client.get(url).headers(headers).send().await?;
+ let resp = self
+ .client
+ .get(url)
+ .header(USER_AGENT, format!("atuin/{}", VERSION))
+ .headers(headers)
+ .send()
+ .await?;
+
+ if resp.status() != StatusCode::OK {
+ return Err(eyre!("failed to get count (are you logged in?)"));
+ }
let count = resp.json::<CountResponse>().await?;
@@ -66,6 +141,7 @@ impl<'a> Client<'a> {
.client
.get(url)
.header(AUTHORIZATION, format!("Token {}", self.token))
+ .header(USER_AGENT, format!("atuin/{}", VERSION))
.send()
.await?;
@@ -88,9 +164,33 @@ impl<'a> Client<'a> {
.post(url)
.json(history)
.header(AUTHORIZATION, format!("Token {}", self.token))
+ .header(USER_AGENT, format!("atuin/{}", VERSION))
.send()
.await?;
Ok(())
}
+
+ pub async fn login(&self, username: &str, password: &str) -> Result<LoginResponse> {
+ let mut map = HashMap::new();
+ map.insert("username", username);
+ map.insert("password", password);
+
+ let url = format!("{}/login", self.sync_addr);
+ let resp = self
+ .client
+ .post(url)
+ .json(&map)
+ .header(USER_AGENT, format!("atuin/{}", VERSION))
+ .send()
+ .await?;
+
+ if resp.status() != reqwest::StatusCode::OK {
+ return Err(eyre!("invalid login details"));
+ }
+
+ let session = resp.json::<LoginResponse>().await?;
+
+ Ok(session)
+ }
}
diff --git a/atuin-client/src/database.rs b/atuin-client/src/database.rs
index abc22bb8..0855359b 100644
--- a/atuin-client/src/database.rs
+++ b/atuin-client/src/database.rs
@@ -2,7 +2,7 @@ use chrono::prelude::*;
use chrono::Utc;
use std::path::Path;
-use eyre::Result;
+use eyre::{eyre, Result};
use rusqlite::{params, Connection};
use rusqlite::{Params, Transaction};
@@ -14,7 +14,7 @@ pub trait Database {
fn save_bulk(&mut self, h: &[History]) -> Result<()>;
fn load(&self, id: &str) -> Result<History>;
- fn list(&self) -> Result<Vec<History>>;
+ fn list(&self, max: Option<usize>, unique: bool) -> Result<Vec<History>>;
fn range(&self, from: chrono::DateTime<Utc>, to: chrono::DateTime<Utc>)
-> Result<Vec<History>>;
@@ -27,6 +27,8 @@ pub trait Database {
fn before(&self, timestamp: chrono::DateTime<Utc>, count: i64) -> Result<Vec<History>>;
fn prefix_search(&self, query: &str) -> Result<Vec<History>>;
+
+ fn search(&self, cwd: Option<String>, exit: Option<i64>, query: &str) -> Result<Vec<History>>;
}
// Intended for use on a developer machine and not a sync server.
@@ -81,6 +83,16 @@ impl Sqlite {
[],
)?;
+ conn.execute(
+ "create index if not exists idx_history_timestamp on history(timestamp)",
+ [],
+ )?;
+
+ conn.execute(
+ "create index if not exists idx_history_command on history(command)",
+ [],
+ )?;
+
Ok(())
}
@@ -136,16 +148,19 @@ impl Database for Sqlite {
}
fn load(&self, id: &str) -> Result<History> {
- debug!("loading history item");
+ debug!("loading history item {}", id);
- let mut stmt = self.conn.prepare(
+ let history = self.query(
"select id, timestamp, duration, exit, command, cwd, session, hostname from history
- where id = ?1",
+ where id = ?1 limit 1",
+ &[id],
)?;
- let history = stmt.query_row(params![id], |row| {
- history_from_sqlite_row(Some(id.to_string()), row)
- })?;
+ if history.is_empty() {
+ return Err(eyre!("could not find history with id {}", id));
+ }
+
+ let history = history[0].clone();
Ok(history)
}
@@ -163,16 +178,39 @@ impl Database for Sqlite {
Ok(())
}
- fn list(&self) -> Result<Vec<History>> {
+ // make a unique list, that only shows the *newest* version of things
+ fn list(&self, max: Option<usize>, unique: bool) -> Result<Vec<History>> {
debug!("listing history");
- let mut stmt = self
- .conn
- .prepare("SELECT * FROM history order by timestamp asc")?;
+ // very likely vulnerable to SQL injection
+ // however, this is client side, and only used by the client, on their
+ // own data. They can just open the db file...
+ // otherwise building the query is awkward
+ let query = format!(
+ "select * from history h
+ {}
+ order by timestamp desc
+ {}",
+ // inject the unique check
+ if unique {
+ "where timestamp = (
+ select max(timestamp) from history
+ where h.command = history.command
+ )"
+ } else {
+ ""
+ },
+ // inject the limit
+ if let Some(max) = max {
+ format!("limit {}", max)
+ } else {
+ "".to_string()
+ }
+ );
- let history_iter = stmt.query_map(params![], |row| history_from_sqlite_row(None, row))?;
+ let history = self.query(query.as_str(), params![])?;
- Ok(history_iter.filter_map(Result::ok).collect())
+ Ok(history)
}
fn range(
@@ -207,7 +245,7 @@ impl Database for Sqlite {
fn last(&self) -> Result<History> {
let mut stmt = self
.conn
- .prepare("SELECT * FROM history order by timestamp desc limit 1")?;
+ .prepare("SELECT * FROM history where duration >= 0 order by timestamp desc limit 1")?;
let history = stmt.query_row(params![], |row| history_from_sqlite_row(None, row))?;
@@ -235,9 +273,17 @@ impl Database for Sqlite {
}
fn prefix_search(&self, query: &str) -> Result<Vec<History>> {
+ let query = query.to_string().replace("*", "%"); // allow wildcard char
+
self.query(
- "select * from history where command like ?1 || '%' order by timestamp asc limit 1000",
- &[query],
+ "select * from history h
+ where command like ?1 || '%'
+ and timestamp = (
+ select max(timestamp) from history
+ where h.command = history.command
+ )
+ order by timestamp desc limit 200",
+ &[query.as_str()],
)
}
@@ -248,6 +294,39 @@ impl Database for Sqlite {
Ok(res)
}
+
+ fn search(&self, cwd: Option<String>, exit: Option<i64>, query: &str) -> Result<Vec<History>> {
+ match (cwd, exit) {
+ (Some(cwd), Some(exit)) => self.query(
+ "select * from history
+ where command like ?1 || '%'
+ and cwd = ?2
+ and exit = ?3
+ order by timestamp asc limit 1000",
+ &[query, cwd.as_str(), exit.to_string().as_str()],
+ ),
+ (Some(cwd), None) => self.query(
+ "select * from history
+ where command like ?1 || '%'
+ and cwd = ?2
+ order by timestamp asc limit 1000",
+ &[query, cwd.as_str()],
+ ),
+ (None, Some(exit)) => self.query(
+ "select * from history
+ where command like ?1 || '%'
+ and exit = ?2
+ order by timestamp asc limit 1000",
+ &[query, exit.to_string().as_str()],
+ ),
+ (None, None) => self.query(
+ "select * from history
+ where command like ?1 || '%'
+ order by timestamp asc limit 1000",
+ &[query],
+ ),
+ }
+ }
}
fn history_from_sqlite_row(
diff --git a/atuin-client/src/encryption.rs b/atuin-client/src/encryption.rs
index 37153f94..19b773ab 100644
--- a/atuin-client/src/encryption.rs
+++ b/atuin-client/src/encryption.rs
@@ -29,20 +29,51 @@ pub fn load_key(settings: &Settings) -> Result<secretbox::Key> {
let path = settings.key_path.as_str();
if PathBuf::from(path).exists() {
- let bytes = std::fs::read(path)?;
- let key: secretbox::Key = rmp_serde::from_read_ref(&bytes)?;
+ let key = std::fs::read_to_string(path)?;
+ let key = decode_key(key)?;
Ok(key)
} else {
let key = secretbox::gen_key();
- let buf = rmp_serde::to_vec(&key)?;
+ let encoded = encode_key(key.clone())?;
let mut file = File::create(path)?;
- file.write_all(&buf)?;
+ file.write_all(encoded.as_bytes())?;
Ok(key)
}
}
+pub fn load_encoded_key(settings: &Settings) -> Result<String> {
+ let path = settings.key_path.as_str();
+
+ if PathBuf::from(path).exists() {
+ let key = std::fs::read_to_string(path)?;
+ Ok(key)
+ } else {
+ let key = secretbox::gen_key();
+ let encoded = encode_key(key)?;
+
+ let mut file = File::create(path)?;
+ file.write_all(encoded.as_bytes())?;
+
+ Ok(encoded)
+ }
+}
+
+pub fn encode_key(key: secretbox::Key) -> Result<String> {
+ let buf = rmp_serde::to_vec(&key)?;
+ let buf = base64::encode(buf);
+
+ Ok(buf)
+}
+
+pub fn decode_key(key: String) -> Result<secretbox::Key> {
+ let buf = base64::decode(key)?;
+ let buf: secretbox::Key = rmp_serde::from_read_ref(&buf)?;
+
+ Ok(buf)
+}
+
pub fn encrypt(history: &History, key: &secretbox::Key) -> Result<EncryptedHistory> {
// serialize with msgpack
let buf = rmp_serde::to_vec(history)?;
diff --git a/atuin-client/src/history.rs b/atuin-client/src/history.rs
index 7f607784..8dd161db 100644
--- a/atuin-client/src/history.rs
+++ b/atuin-client/src/history.rs
@@ -6,7 +6,7 @@ use chrono::Utc;
use atuin_common::utils::uuid_v4;
// Any new fields MUST be Optional<>!
-#[derive(Debug, Clone, Serialize, Deserialize)]
+#[derive(Debug, Clone, Serialize, Deserialize, Ord, PartialOrd)]
pub struct History {
pub id: String,
pub timestamp: chrono::DateTime<Utc>,
diff --git a/atuin-client/src/settings.rs b/atuin-client/src/settings.rs
index e28963c0..254bca6d 100644
--- a/atuin-client/src/settings.rs
+++ b/atuin-client/src/settings.rs
@@ -78,15 +78,16 @@ impl Settings {
create_dir_all(config_dir)?;
- let config_file = if let Ok(p) = std::env::var("ATUIN_CONFIG") {
+ let mut config_file = if let Ok(p) = std::env::var("ATUIN_CONFIG_DIR") {
PathBuf::from(p)
} else {
let mut config_file = PathBuf::new();
config_file.push(config_dir);
- config_file.push("config.toml");
config_file
};
+ config_file.push("config.toml");
+
let mut s = Config::new();
let db_path = ProjectDirs::from("com", "elliehuxtable", "atuin")
diff --git a/atuin-client/src/sync.rs b/atuin-client/src/sync.rs
index 0ca8d3a6..5d81a5e6 100644
--- a/atuin-client/src/sync.rs
+++ b/atuin-client/src/sync.rs
@@ -7,7 +7,7 @@ use atuin_common::{api::AddHistoryRequest, utils::hash_str};
use crate::api_client;
use crate::database::Database;
-use crate::encryption::{encrypt, load_key};
+use crate::encryption::{encrypt, load_encoded_key, load_key};
use crate::settings::{Settings, HISTORY_PAGE_SIZE};
// Currently sync is kinda naive, and basically just pages backwards through
@@ -26,6 +26,8 @@ async fn sync_download(
client: &api_client::Client<'_>,
db: &mut (impl Database + Send),
) -> Result<(i64, i64)> {
+ debug!("starting sync download");
+
let remote_count = client.count().await?;
let initial_local = db.history_count()?;
@@ -46,14 +48,14 @@ async fn sync_download(
.get_history(last_sync, last_timestamp, host.clone())
.await?;
- if page.len() < HISTORY_PAGE_SIZE.try_into().unwrap() {
- break;
- }
-
db.save_bulk(&page)?;
local_count = db.history_count()?;
+ if page.len() < HISTORY_PAGE_SIZE.try_into().unwrap() {
+ break;
+ }
+
let page_last = page
.last()
.expect("could not get last element of page")
@@ -80,11 +82,15 @@ async fn sync_upload(
client: &api_client::Client<'_>,
db: &mut (impl Database + Send),
) -> Result<()> {
+ debug!("starting sync upload");
+
let initial_remote_count = client.count().await?;
let mut remote_count = initial_remote_count;
let local_count = db.history_count()?;
+ debug!("remote has {}, we have {}", remote_count, local_count);
+
let key = load_key(settings)?; // encryption key
// first just try the most recent set
@@ -127,8 +133,8 @@ pub async fn sync(settings: &Settings, force: bool, db: &mut (impl Database + Se
let client = api_client::Client::new(
settings.sync_address.as_str(),
settings.session_token.as_str(),
- load_key(settings)?,
- );
+ load_encoded_key(settings)?,
+ )?;
sync_upload(settings, force, &client, db).await?;