summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
author依云 <lilydjwg@gmail.com>2024-03-01 21:21:53 +0800
committerGitHub <noreply@github.com>2024-03-01 13:21:53 +0000
commitaec5df4123823fdaf8db3d714d0d826ac04ceca4 (patch)
treefdac93f22c47cc533a1288eabd97525f64f20716
parent897af9a326960a67fedd68ae85ed8aae0b19db97 (diff)
feat: support regex with r/.../ syntax (#1745)
* feat: support regex with r/.../ syntax * cargo fmt * feat(tests): add some tests for regex matching
-rw-r--r--Cargo.lock2
-rw-r--r--atuin-client/Cargo.toml3
-rw-r--r--atuin-client/src/database.rs132
-rw-r--r--atuin/src/command/client/search/engines/db.rs6
4 files changed, 125 insertions, 18 deletions
diff --git a/Cargo.lock b/Cargo.lock
index ab0eaa31..50ffc518 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -245,7 +245,6 @@ dependencies = [
"indicatif",
"interim",
"itertools",
- "lazy_static",
"log",
"memchr",
"minspan",
@@ -3510,6 +3509,7 @@ dependencies = [
"libsqlite3-sys",
"log",
"percent-encoding",
+ "regex",
"serde",
"sqlx-core",
"time",
diff --git a/atuin-client/Cargo.toml b/atuin-client/Cargo.toml
index 51227044..e2353daf 100644
--- a/atuin-client/Cargo.toml
+++ b/atuin-client/Cargo.toml
@@ -37,13 +37,12 @@ async-trait = { workspace = true }
itertools = { workspace = true }
rand = { workspace = true }
shellexpand = "3"
-sqlx = { workspace = true, features = ["sqlite"] }
+sqlx = { workspace = true, features = ["sqlite", "regexp"] }
minspan = "0.1.1"
regex = "1.9.1"
serde_regex = "1.1.0"
fs-err = { workspace = true }
sql-builder = "3"
-lazy_static = "1"
memchr = "2.5"
rmp = { version = "0.8.11" }
typed-builder = { workspace = true }
diff --git a/atuin-client/src/database.rs b/atuin-client/src/database.rs
index 572955f8..2be27ac8 100644
--- a/atuin-client/src/database.rs
+++ b/atuin-client/src/database.rs
@@ -1,4 +1,5 @@
use std::{
+ borrow::Cow,
env,
path::{Path, PathBuf},
str::FromStr,
@@ -9,10 +10,8 @@ use async_trait::async_trait;
use atuin_common::utils;
use fs_err as fs;
use itertools::Itertools;
-use lazy_static::lazy_static;
use rand::{distributions::Alphanumeric, Rng};
-use regex::Regex;
-use sql_builder::{esc, quote, SqlBuilder, SqlName};
+use sql_builder::{bind::Bind, esc, quote, SqlBuilder, SqlName};
use sqlx::{
sqlite::{
SqliteConnectOptions, SqliteJournalMode, SqlitePool, SqlitePoolOptions, SqliteRow,
@@ -142,6 +141,7 @@ impl Sqlite {
.journal_mode(SqliteJournalMode::Wal)
.optimize_on_close(true, None)
.synchronous(SqliteSynchronous::Normal)
+ .with_regexp()
.create_if_missing(true);
let pool = SqlitePoolOptions::new()
@@ -428,18 +428,42 @@ impl Database for Sqlite {
};
let orig_query = query;
- let query = query.replace('*', "%"); // allow wildcard char
+ let mut regexes = Vec::new();
match search_mode {
- SearchMode::Prefix => sql.and_where_like_left("command", query),
+ SearchMode::Prefix => sql.and_where_like_left("command", query.replace('*', "%")),
_ => {
- // don't recompile the regex on successive calls!
- lazy_static! {
- static ref SPLIT_REGEX: Regex = Regex::new(r" +").unwrap();
- }
-
let mut is_or = false;
- for query_part in SPLIT_REGEX.split(query.as_str()) {
+ let mut regex = None;
+ for part in query.split_inclusive(' ') {
+ let query_part: Cow<str> = match (&mut regex, part.starts_with("r/")) {
+ (None, false) => {
+ if part.trim_end().is_empty() {
+ continue;
+ }
+ Cow::Owned(part.trim_end().replace('*', "%")) // allow wildcard char
+ }
+ (None, true) => {
+ if part[2..].trim_end().ends_with('/') {
+ let end_pos = part.trim_end().len() - 1;
+ regexes.push(String::from(&part[2..end_pos]));
+ } else {
+ regex = Some(String::from(&part[2..]));
+ }
+ continue;
+ }
+ (Some(r), _) => {
+ if part.trim_end().ends_with('/') {
+ let end_pos = part.trim_end().len() - 1;
+ r.push_str(&part.trim_end()[..end_pos]);
+ regexes.push(regex.take().unwrap());
+ } else {
+ r.push_str(part);
+ }
+ continue;
+ }
+ };
+
// TODO smart case mode could be made configurable like in fzf
let (is_glob, glob) = if query_part.contains(char::is_uppercase) {
(true, "*")
@@ -448,7 +472,7 @@ impl Database for Sqlite {
};
let (is_inverse, query_part) = match query_part.strip_prefix('!') {
- Some(stripped) => (true, stripped),
+ Some(stripped) => (true, Cow::Borrowed(stripped)),
None => (false, query_part),
};
@@ -477,10 +501,18 @@ impl Database for Sqlite {
sql.fuzzy_condition("command", param, is_inverse, is_glob, is_or);
is_or = false;
}
+ if let Some(r) = regex {
+ regexes.push(r);
+ }
+
&mut sql
}
};
+ for regex in regexes {
+ sql.and_where("command regexp ?".bind(&regex));
+ }
+
filter_options
.exit
.map(|exit| sql.and_where_eq("exit", exit));
@@ -825,6 +857,71 @@ mod test {
assert_search_eq(&db, SearchMode::FullText, FilterMode::Global, "hm", 0)
.await
.unwrap();
+
+ // regex
+ assert_search_eq(&db, SearchMode::FullText, FilterMode::Global, "r/^ls ", 1)
+ .await
+ .unwrap();
+ assert_search_eq(
+ &db,
+ SearchMode::FullText,
+ FilterMode::Global,
+ "r/ls / ie$",
+ 1,
+ )
+ .await
+ .unwrap();
+ assert_search_eq(
+ &db,
+ SearchMode::FullText,
+ FilterMode::Global,
+ "r/ls / !ie",
+ 0,
+ )
+ .await
+ .unwrap();
+ assert_search_eq(
+ &db,
+ SearchMode::FullText,
+ FilterMode::Global,
+ "meow r/ls/",
+ 0,
+ )
+ .await
+ .unwrap();
+ assert_search_eq(&db, SearchMode::FullText, FilterMode::Global, "r//hom/", 1)
+ .await
+ .unwrap();
+ assert_search_eq(
+ &db,
+ SearchMode::FullText,
+ FilterMode::Global,
+ "r//home//",
+ 1,
+ )
+ .await
+ .unwrap();
+ assert_search_eq(
+ &db,
+ SearchMode::FullText,
+ FilterMode::Global,
+ "r//home///",
+ 0,
+ )
+ .await
+ .unwrap();
+ assert_search_eq(&db, SearchMode::FullText, FilterMode::Global, "/home.*e", 0)
+ .await
+ .unwrap();
+ assert_search_eq(
+ &db,
+ SearchMode::FullText,
+ FilterMode::Global,
+ "r/home.*e",
+ 1,
+ )
+ .await
+ .unwrap();
}
#[tokio::test(flavor = "multi_thread")]
@@ -915,6 +1012,17 @@ mod test {
assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "Ellie", 1)
.await
.unwrap();
+
+ // regex
+ assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "r/^ls ", 2)
+ .await
+ .unwrap();
+ assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "r/[Ee]llie", 3)
+ .await
+ .unwrap();
+ assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "/h/e r/^ls ", 1)
+ .await
+ .unwrap();
}
#[tokio::test(flavor = "multi_thread")]
diff --git a/atuin/src/command/client/search/engines/db.rs b/atuin/src/command/client/search/engines/db.rs
index b4f24561..e638f9d9 100644
--- a/atuin/src/command/client/search/engines/db.rs
+++ b/atuin/src/command/client/search/engines/db.rs
@@ -26,8 +26,8 @@ impl SearchEngine for Search {
..Default::default()
},
)
- .await?
- .into_iter()
- .collect::<Vec<_>>())
+ .await
+ // ignore errors as it may be caused by incomplete regex
+ .map_or(Vec::new(), |r| r.into_iter().collect()))
}
}