summaryrefslogtreecommitdiffstats
path: root/atuin-client/src/database.rs
diff options
context:
space:
mode:
Diffstat (limited to 'atuin-client/src/database.rs')
-rw-r--r--atuin-client/src/database.rs132
1 files changed, 120 insertions, 12 deletions
diff --git a/atuin-client/src/database.rs b/atuin-client/src/database.rs
index 572955f8..2be27ac8 100644
--- a/atuin-client/src/database.rs
+++ b/atuin-client/src/database.rs
@@ -1,4 +1,5 @@
use std::{
+ borrow::Cow,
env,
path::{Path, PathBuf},
str::FromStr,
@@ -9,10 +10,8 @@ use async_trait::async_trait;
use atuin_common::utils;
use fs_err as fs;
use itertools::Itertools;
-use lazy_static::lazy_static;
use rand::{distributions::Alphanumeric, Rng};
-use regex::Regex;
-use sql_builder::{esc, quote, SqlBuilder, SqlName};
+use sql_builder::{bind::Bind, esc, quote, SqlBuilder, SqlName};
use sqlx::{
sqlite::{
SqliteConnectOptions, SqliteJournalMode, SqlitePool, SqlitePoolOptions, SqliteRow,
@@ -142,6 +141,7 @@ impl Sqlite {
.journal_mode(SqliteJournalMode::Wal)
.optimize_on_close(true, None)
.synchronous(SqliteSynchronous::Normal)
+ .with_regexp()
.create_if_missing(true);
let pool = SqlitePoolOptions::new()
@@ -428,18 +428,42 @@ impl Database for Sqlite {
};
let orig_query = query;
- let query = query.replace('*', "%"); // allow wildcard char
+ let mut regexes = Vec::new();
match search_mode {
- SearchMode::Prefix => sql.and_where_like_left("command", query),
+ SearchMode::Prefix => sql.and_where_like_left("command", query.replace('*', "%")),
_ => {
- // don't recompile the regex on successive calls!
- lazy_static! {
- static ref SPLIT_REGEX: Regex = Regex::new(r" +").unwrap();
- }
-
let mut is_or = false;
- for query_part in SPLIT_REGEX.split(query.as_str()) {
+ let mut regex = None;
+ for part in query.split_inclusive(' ') {
+ let query_part: Cow<str> = match (&mut regex, part.starts_with("r/")) {
+ (None, false) => {
+ if part.trim_end().is_empty() {
+ continue;
+ }
+ Cow::Owned(part.trim_end().replace('*', "%")) // allow wildcard char
+ }
+ (None, true) => {
+ if part[2..].trim_end().ends_with('/') {
+ let end_pos = part.trim_end().len() - 1;
+ regexes.push(String::from(&part[2..end_pos]));
+ } else {
+ regex = Some(String::from(&part[2..]));
+ }
+ continue;
+ }
+ (Some(r), _) => {
+ if part.trim_end().ends_with('/') {
+ let end_pos = part.trim_end().len() - 1;
+ r.push_str(&part.trim_end()[..end_pos]);
+ regexes.push(regex.take().unwrap());
+ } else {
+ r.push_str(part);
+ }
+ continue;
+ }
+ };
+
// TODO smart case mode could be made configurable like in fzf
let (is_glob, glob) = if query_part.contains(char::is_uppercase) {
(true, "*")
@@ -448,7 +472,7 @@ impl Database for Sqlite {
};
let (is_inverse, query_part) = match query_part.strip_prefix('!') {
- Some(stripped) => (true, stripped),
+ Some(stripped) => (true, Cow::Borrowed(stripped)),
None => (false, query_part),
};
@@ -477,10 +501,18 @@ impl Database for Sqlite {
sql.fuzzy_condition("command", param, is_inverse, is_glob, is_or);
is_or = false;
}
+ if let Some(r) = regex {
+ regexes.push(r);
+ }
+
&mut sql
}
};
+ for regex in regexes {
+ sql.and_where("command regexp ?".bind(&regex));
+ }
+
filter_options
.exit
.map(|exit| sql.and_where_eq("exit", exit));
@@ -825,6 +857,71 @@ mod test {
assert_search_eq(&db, SearchMode::FullText, FilterMode::Global, "hm", 0)
.await
.unwrap();
+
+ // regex
+ assert_search_eq(&db, SearchMode::FullText, FilterMode::Global, "r/^ls ", 1)
+ .await
+ .unwrap();
+ assert_search_eq(
+ &db,
+ SearchMode::FullText,
+ FilterMode::Global,
+ "r/ls / ie$",
+ 1,
+ )
+ .await
+ .unwrap();
+ assert_search_eq(
+ &db,
+ SearchMode::FullText,
+ FilterMode::Global,
+ "r/ls / !ie",
+ 0,
+ )
+ .await
+ .unwrap();
+ assert_search_eq(
+ &db,
+ SearchMode::FullText,
+ FilterMode::Global,
+ "meow r/ls/",
+ 0,
+ )
+ .await
+ .unwrap();
+ assert_search_eq(&db, SearchMode::FullText, FilterMode::Global, "r//hom/", 1)
+ .await
+ .unwrap();
+ assert_search_eq(
+ &db,
+ SearchMode::FullText,
+ FilterMode::Global,
+ "r//home//",
+ 1,
+ )
+ .await
+ .unwrap();
+ assert_search_eq(
+ &db,
+ SearchMode::FullText,
+ FilterMode::Global,
+ "r//home///",
+ 0,
+ )
+ .await
+ .unwrap();
+ assert_search_eq(&db, SearchMode::FullText, FilterMode::Global, "/home.*e", 0)
+ .await
+ .unwrap();
+ assert_search_eq(
+ &db,
+ SearchMode::FullText,
+ FilterMode::Global,
+ "r/home.*e",
+ 1,
+ )
+ .await
+ .unwrap();
}
#[tokio::test(flavor = "multi_thread")]
@@ -915,6 +1012,17 @@ mod test {
assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "Ellie", 1)
.await
.unwrap();
+
+ // regex
+ assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "r/^ls ", 2)
+ .await
+ .unwrap();
+ assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "r/[Ee]llie", 3)
+ .await
+ .unwrap();
+ assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "/h/e r/^ls ", 1)
+ .await
+ .unwrap();
}
#[tokio::test(flavor = "multi_thread")]