summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--beets/dbcore/database.py20
-rwxr-xr-xbeets/dbcore/db.py19
-rw-r--r--beets/dbcore/query.py9
-rw-r--r--beetsplug/bareasc.py7
4 files changed, 42 insertions, 13 deletions
diff --git a/beets/dbcore/database.py b/beets/dbcore/database.py
new file mode 100644
index 000000000..511f9ab99
--- /dev/null
+++ b/beets/dbcore/database.py
@@ -0,0 +1,20 @@
+import inspect
+from typing import Callable, NamedTuple, Optional, Set
+
+
+class FunctionDef(NamedTuple):
+ name: str
+ arg_count: int
+ func: Callable
+
+
+FUNCTIONS_TO_REGISTER: Set[FunctionDef] = set()
+
+
+def add_db_function(name: Optional[str] = None) -> Callable:
+ def add_db_function_to_register(func: Callable) -> Callable:
+ num_args = len(inspect.signature(func).parameters)
+ FUNCTIONS_TO_REGISTER.add(FunctionDef(name or func.__name__, num_args, func))
+ return func
+
+ return add_db_function_to_register
diff --git a/beets/dbcore/db.py b/beets/dbcore/db.py
index 94396f81b..c01be9669 100755
--- a/beets/dbcore/db.py
+++ b/beets/dbcore/db.py
@@ -23,15 +23,15 @@ import threading
import sqlite3
import contextlib
-from unidecode import unidecode
-
import beets
from beets.util import functemplate
from beets.util import py3_path
from beets.dbcore import types
-from .query import MatchQuery, NullSort, TrueQuery, AndQuery
from collections.abc import Mapping
+from .database import FUNCTIONS_TO_REGISTER
+from .query import MatchQuery, NullSort, TrueQuery, AndQuery
+
class DBAccessError(Exception):
"""The SQLite database became inaccessible.
@@ -977,7 +977,9 @@ class Database:
conn = sqlite3.connect(
py3_path(self.path), timeout=self.timeout
)
- self.add_functions(conn)
+
+ for function_def in FUNCTIONS_TO_REGISTER:
+ conn.create_function(*function_def)
if self.supports_extensions:
conn.enable_load_extension(True)
@@ -990,15 +992,6 @@ class Database:
conn.row_factory = sqlite3.Row
return conn
- def add_functions(self, conn):
- def regexp(value, pattern):
- if isinstance(value, bytes):
- value = value.decode()
- return re.search(pattern, str(value)) is not None
-
- conn.create_function("regexp", 2, regexp)
- conn.create_function("unidecode", 1, unidecode)
-
def _close(self):
"""Close the all connections to the underlying SQLite database
from all threads. This does not render the database object
diff --git a/beets/dbcore/query.py b/beets/dbcore/query.py
index 016fe2c1a..166b385cf 100644
--- a/beets/dbcore/query.py
+++ b/beets/dbcore/query.py
@@ -22,6 +22,8 @@ from datetime import datetime, timedelta
import unicodedata
from functools import reduce
+from .database import add_db_function
+
class ParsingError(ValueError):
"""Abstract class for any unparseable user-requested album/query
@@ -231,6 +233,13 @@ class RegexpQuery(StringFieldQuery):
"a regular expression",
format(exc))
+ @staticmethod
+ @add_db_function()
+ def regexp(value, pattern) -> bool:
+ if isinstance(value, bytes):
+ value = value.decode()
+ return re.search(pattern, str(value)) is not None
+
def col_clause(self):
return f" regexp({self.field}, ?)", [self.pattern.pattern]
diff --git a/beetsplug/bareasc.py b/beetsplug/bareasc.py
index 3343786f9..669644894 100644
--- a/beetsplug/bareasc.py
+++ b/beetsplug/bareasc.py
@@ -22,7 +22,9 @@
from beets import ui
from beets.ui import print_, decargs
from beets.plugins import BeetsPlugin
+from beets.dbcore.database import add_db_function
from beets.dbcore.query import StringFieldQuery
+
from unidecode import unidecode
@@ -42,6 +44,11 @@ class BareascQuery(StringFieldQuery):
val = unidecode(val)
return pattern in val
+ @staticmethod
+ @add_db_function(name="unidecode")
+ def _unidecode(value):
+ return unidecode(value)
+
def col_clause(self):
"""Compare ascii version of the pattern."""
clause = f"unidecode({self.field})"