summary | refs | log | tree | commit | diff | stats
path: root/atuin-client/src/database.rs
diff options
context:
space:
mode:
author: Patrick <pmarschik@users.noreply.github.com> 2022-03-18 12:37:27 +0100
committer: GitHub <noreply@github.com> 2022-03-18 11:37:27 +0000
commit: fae118a46ba23da5aed9f4436e16ba7677ecbb84 (patch)
tree: 5bb01db0358ac1d6fbba74d79b372b06019c0f2e /atuin-client/src/database.rs
parent: 7cde55a7514c8bce3379847bbcad84bd83cfa42b (diff)
Improve fuzzy search (#279)
* Add SearchMode fzf.

  Add a new search mode "fzf" that tries to mimic the search syntax of
  <https://github.com/junegunn/fzf#search-syntax>.
  This search mode splits the query into terms where each term is matched
  individually. Terms can have operators like prefix, suffix, exact match
  only and can be inverted.

  Additionally, smart-case matching is performed: if a term contains a
  non-lowercase letter the match will be case-sensitive.

* PR feedback.
  - Use SearchMode::Fuzzy instead of SearchMode::Fzf
  - update docs
  - re-order tests so previous fuzzy tests come first, add more tests for each operator

* PR comments: remove named arguments, match expression

* PR comments: macro -> async func
Diffstat (limited to 'atuin-client/src/database.rs')
-rw-r--r-- atuin-client/src/database.rs 260
1 files changed, 197 insertions, 63 deletions
diff --git a/atuin-client/src/database.rs b/atuin-client/src/database.rs
index 3cb89472f..3ca617175 100644
--- a/atuin-client/src/database.rs
+++ b/atuin-client/src/database.rs
@@ -7,6 +7,7 @@ use chrono::Utc;
use eyre::Result;
use itertools::Itertools;
+use regex::Regex;
use sqlx::sqlite::{
SqliteConnectOptions, SqliteJournalMode, SqlitePool, SqlitePoolOptions, SqliteRow,
@@ -286,27 +287,89 @@ impl Database for Sqlite {
let query = query.to_string().replace('*', "%"); // allow wildcard char
let limit = limit.map_or("".to_owned(), |l| format!("limit {}", l));
- let query = match search_mode {
- SearchMode::Prefix => query,
- SearchMode::FullText => format!("%{}", query),
- SearchMode::Fuzzy => query.split("").join("%"),
+ let (query_sql, query_params) = match search_mode {
+ SearchMode::Prefix => ("command like ?1".to_string(), vec![format!("{}%", query)]),
+ SearchMode::FullText => ("command like ?1".to_string(), vec![format!("%{}%", query)]),
+ SearchMode::Fuzzy => {
+ let split_regex = Regex::new(r" +").unwrap();
+ let terms: Vec<&str> = split_regex.split(query.as_str()).collect();
+ let mut query_sql = std::string::String::new();
+ let mut query_params = Vec::with_capacity(terms.len());
+ let mut was_or = false;
+ for (i, query_part) in terms.into_iter().enumerate() {
+ // TODO smart case mode could be made configurable like in fzf
+ let (operator, glob) = if query_part.contains(char::is_uppercase) {
+ ("glob", '*')
+ } else {
+ ("like", '%')
+ };
+ let (is_inverse, query_part) = match query_part.strip_prefix('!') {
+ Some(stripped) => (true, stripped),
+ None => (false, query_part),
+ };
+ match query_part {
+ "|" => {
+ if !was_or {
+ query_sql.push_str(" OR ");
+ was_or = true;
+ continue;
+ } else {
+ query_params.push(format!("{glob}|{glob}"));
+ }
+ }
+ exact_prefix if query_part.starts_with('^') => query_params.push(format!(
+ "{term}{glob}",
+ term = exact_prefix.strip_prefix('^').unwrap()
+ )),
+ exact_suffix if query_part.ends_with('$') => query_params.push(format!(
+ "{glob}{term}",
+ term = exact_suffix.strip_suffix('$').unwrap()
+ )),
+ exact if query_part.starts_with('\'') => query_params.push(format!(
+ "{glob}{term}{glob}",
+ term = exact.strip_prefix('\'').unwrap()
+ )),
+ exact if is_inverse => {
+ query_params.push(format!("{glob}{term}{glob}", term = exact))
+ }
+ _ => {
+ query_params.push(query_part.split("").join(glob.to_string().as_str()))
+ }
+ }
+ if i > 0 && !was_or {
+ query_sql.push_str(" AND ");
+ }
+ if is_inverse {
+ query_sql.push_str("NOT ");
+ }
+ query_sql
+ .push_str(format!("command {} ?{}", operator, query_params.len()).as_str());
+ was_or = false;
+ }
+ (query_sql, query_params)
+ }
};
- let res = sqlx::query(
- format!(
- "select * from history h
- where command like ?1 || '%'
- group by command
- having max(timestamp)
- order by timestamp desc {}",
- limit.clone()
+ let res = query_params
+ .iter()
+ .fold(
+ sqlx::query(
+ format!(
+ "select * from history h
+ where {}
+ group by command
+ having max(timestamp)
+ order by timestamp desc {}",
+ query_sql.as_str(),
+ limit.clone()
+ )
+ .as_str(),
+ ),
+ |query, query_param| query.bind(query_param),
)
- .as_str(),
- )
- .bind(query)
- .map(Self::query_history)
- .fetch_all(&self.pool)
- .await?;
+ .map(Self::query_history)
+ .fetch_all(&self.pool)
+ .await?;
Ok(ordering::reorder_fuzzy(search_mode, orig_query, res))
}
@@ -326,6 +389,36 @@ mod test {
use super::*;
use std::time::{Duration, Instant};
+ async fn assert_search_eq<'a>(
+ db: &impl Database,
+ mode: SearchMode,
+ query: &str,
+ expected: usize,
+ ) -> Result<Vec<History>> {
+ let results = db.search(None, mode, query).await?;
+ assert_eq!(
+ results.len(),
+ expected,
+ "query \"{}\", commands: {:?}",
+ query,
+ results.iter().map(|a| &a.command).collect::<Vec<&String>>()
+ );
+ Ok(results)
+ }
+
+ async fn assert_search_commands(
+ db: &impl Database,
+ mode: SearchMode,
+ query: &str,
+ expected_commands: Vec<&str>,
+ ) {
+ let results = assert_search_eq(db, mode, query, expected_commands.len())
+ .await
+ .unwrap();
+ let commands: Vec<&str> = results.iter().map(|a| a.command.as_str()).collect();
+ assert_eq!(commands, expected_commands);
+ }
+
async fn new_history_item(db: &mut impl Database, cmd: &str) -> Result<()> {
let history = History::new(
chrono::Utc::now(),
@@ -344,14 +437,15 @@ mod test {
let mut db = Sqlite::new("sqlite::memory:").await.unwrap();
new_history_item(&mut db, "ls /home/ellie").await.unwrap();
- let mut results = db.search(None, SearchMode::Prefix, "ls").await.unwrap();
- assert_eq!(results.len(), 1);
-
- results = db.search(None, SearchMode::Prefix, "/home").await.unwrap();
- assert_eq!(results.len(), 0);
-
- results = db.search(None, SearchMode::Prefix, "ls ").await.unwrap();
- assert_eq!(results.len(), 0);
+ assert_search_eq(&db, SearchMode::Prefix, "ls", 1)
+ .await
+ .unwrap();
+ assert_search_eq(&db, SearchMode::Prefix, "/home", 0)
+ .await
+ .unwrap();
+ assert_search_eq(&db, SearchMode::Prefix, "ls ", 0)
+ .await
+ .unwrap();
}
#[tokio::test(flavor = "multi_thread")]
@@ -359,17 +453,15 @@ mod test {
let mut db = Sqlite::new("sqlite::memory:").await.unwrap();
new_history_item(&mut db, "ls /home/ellie").await.unwrap();
- let mut results = db.search(None, SearchMode::FullText, "ls").await.unwrap();
- assert_eq!(results.len(), 1);
-
- results = db
- .search(None, SearchMode::FullText, "/home")
+ assert_search_eq(&db, SearchMode::FullText, "ls", 1)
+ .await
+ .unwrap();
+ assert_search_eq(&db, SearchMode::FullText, "/home", 1)
+ .await
+ .unwrap();
+ assert_search_eq(&db, SearchMode::FullText, "ls ", 0)
.await
.unwrap();
- assert_eq!(results.len(), 1);
-
- results = db.search(None, SearchMode::FullText, "ls ").await.unwrap();
- assert_eq!(results.len(), 0);
}
#[tokio::test(flavor = "multi_thread")]
@@ -377,34 +469,77 @@ mod test {
let mut db = Sqlite::new("sqlite::memory:").await.unwrap();
new_history_item(&mut db, "ls /home/ellie").await.unwrap();
new_history_item(&mut db, "ls /home/frank").await.unwrap();
- new_history_item(&mut db, "cd /home/ellie").await.unwrap();
+ new_history_item(&mut db, "cd /home/Ellie").await.unwrap();
new_history_item(&mut db, "/home/ellie/.bin/rustup")
.await
.unwrap();
- let mut results = db.search(None, SearchMode::Fuzzy, "ls /").await.unwrap();
- assert_eq!(results.len(), 2);
-
- results = db.search(None, SearchMode::Fuzzy, "l/h/").await.unwrap();
- assert_eq!(results.len(), 2);
-
- results = db.search(None, SearchMode::Fuzzy, "/h/e").await.unwrap();
- assert_eq!(results.len(), 3);
-
- results = db.search(None, SearchMode::Fuzzy, "/hmoe/").await.unwrap();
- assert_eq!(results.len(), 0);
+ assert_search_eq(&db, SearchMode::Fuzzy, "ls /", 3)
+ .await
+ .unwrap();
+ assert_search_eq(&db, SearchMode::Fuzzy, "ls/", 2)
+ .await
+ .unwrap();
+ assert_search_eq(&db, SearchMode::Fuzzy, "l/h/", 2)
+ .await
+ .unwrap();
+ assert_search_eq(&db, SearchMode::Fuzzy, "/h/e", 3)
+ .await
+ .unwrap();
+ assert_search_eq(&db, SearchMode::Fuzzy, "/hmoe/", 0)
+ .await
+ .unwrap();
+ assert_search_eq(&db, SearchMode::Fuzzy, "ellie/home", 0)
+ .await
+ .unwrap();
+ assert_search_eq(&db, SearchMode::Fuzzy, "lsellie", 1)
+ .await
+ .unwrap();
+ assert_search_eq(&db, SearchMode::Fuzzy, " ", 4)
+ .await
+ .unwrap();
- results = db
- .search(None, SearchMode::Fuzzy, "ellie/home")
+ // single term operators
+ assert_search_eq(&db, SearchMode::Fuzzy, "^ls", 2)
+ .await
+ .unwrap();
+ assert_search_eq(&db, SearchMode::Fuzzy, "'ls", 2)
+ .await
+ .unwrap();
+ assert_search_eq(&db, SearchMode::Fuzzy, "ellie$", 2)
+ .await
+ .unwrap();
+ assert_search_eq(&db, SearchMode::Fuzzy, "!^ls", 2)
+ .await
+ .unwrap();
+ assert_search_eq(&db, SearchMode::Fuzzy, "!ellie", 1)
+ .await
+ .unwrap();
+ assert_search_eq(&db, SearchMode::Fuzzy, "!ellie$", 2)
.await
.unwrap();
- assert_eq!(results.len(), 0);
- results = db.search(None, SearchMode::Fuzzy, "lsellie").await.unwrap();
- assert_eq!(results.len(), 1);
+ // multiple terms
+ assert_search_eq(&db, SearchMode::Fuzzy, "ls !ellie", 1)
+ .await
+ .unwrap();
+ assert_search_eq(&db, SearchMode::Fuzzy, "^ls !e$", 1)
+ .await
+ .unwrap();
+ assert_search_eq(&db, SearchMode::Fuzzy, "home !^ls", 2)
+ .await
+ .unwrap();
+ assert_search_eq(&db, SearchMode::Fuzzy, "'frank | 'rustup", 2)
+ .await
+ .unwrap();
+ assert_search_eq(&db, SearchMode::Fuzzy, "'frank | 'rustup 'ls", 1)
+ .await
+ .unwrap();
- results = db.search(None, SearchMode::Fuzzy, " ").await.unwrap();
- assert_eq!(results.len(), 3);
+ // case matching
+ assert_search_eq(&db, SearchMode::Fuzzy, "Ellie", 1)
+ .await
+ .unwrap();
}
#[tokio::test(flavor = "multi_thread")]
@@ -414,17 +549,16 @@ mod test {
new_history_item(&mut db, "curl").await.unwrap();
new_history_item(&mut db, "corburl").await.unwrap();
- // if fuzzy reordering is on, it should come back in a more sensible order
- let mut results = db.search(None, SearchMode::Fuzzy, "curl").await.unwrap();
- assert_eq!(results.len(), 2);
- let commands: Vec<&String> = results.iter().map(|a| &a.command).collect();
- assert_eq!(commands, vec!["curl", "corburl"]);
- results = db.search(None, SearchMode::Fuzzy, "xxxx").await.unwrap();
- assert_eq!(results.len(), 0);
+ // if fuzzy reordering is on, it should come back in a more sensible order
+ assert_search_commands(&db, SearchMode::Fuzzy, "curl", vec!["curl", "corburl"]).await;
- results = db.search(None, SearchMode::Fuzzy, "").await.unwrap();
- assert_eq!(results.len(), 2);
+ assert_search_eq(&db, SearchMode::Fuzzy, "xxxx", 0)
+ .await
+ .unwrap();
+ assert_search_eq(&db, SearchMode::Fuzzy, "", 2)
+ .await
+ .unwrap();
}
#[tokio::test(flavor = "multi_thread")]