summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorŠarūnas Nejus <snejus@protonmail.com>2023-04-08 05:12:21 +0100
committerŠarūnas Nejus <snejus@protonmail.com>2023-04-09 20:25:33 +0100
commit2a6fd33d56895fa52df464e2849db8084c148590 (patch)
tree4e61666ad67f6a1010f8bf548441ae90693178da
parentd05c34ec4e9d8fa9e34706ac2279b7315b36e1e0 (diff)
Add the ability to register arbitrary db functionsgeneric-db-functions
Add a decorator which records the function it wraps. In the DB initialization stage, define each recorded function as an SQLite function. This way, the SQLite function definition stays together with the SQL query that uses it. It also allows flexibility, since 'add_db_function' can be imported from anywhere. Its implementation had to be moved away from `beets/dbcore/db.py` because it would have caused circular import issues between `db.py` and `query.py`. See https://github.com/beetbox/beets/pull/4741 for the context.
-rw-r--r--beets/dbcore/database.py20
-rwxr-xr-xbeets/dbcore/db.py19
-rw-r--r--beets/dbcore/query.py9
-rw-r--r--beetsplug/bareasc.py7
4 files changed, 42 insertions, 13 deletions
diff --git a/beets/dbcore/database.py b/beets/dbcore/database.py
new file mode 100644
index 000000000..511f9ab99
--- /dev/null
+++ b/beets/dbcore/database.py
@@ -0,0 +1,20 @@
+import inspect
+from typing import Callable, NamedTuple, Optional, Set
+
+
+class FunctionDef(NamedTuple):
+ name: str
+ arg_count: int
+ func: Callable
+
+
+FUNCTIONS_TO_REGISTER: Set[FunctionDef] = set()
+
+
+def add_db_function(name: Optional[str] = None) -> Callable:
+ def add_db_function_to_register(func: Callable) -> Callable:
+ num_args = len(inspect.signature(func).parameters)
+ FUNCTIONS_TO_REGISTER.add(FunctionDef(name or func.__name__, num_args, func))
+ return func
+
+ return add_db_function_to_register
diff --git a/beets/dbcore/db.py b/beets/dbcore/db.py
index 94396f81b..c01be9669 100755
--- a/beets/dbcore/db.py
+++ b/beets/dbcore/db.py
@@ -23,15 +23,15 @@ import threading
import sqlite3
import contextlib
-from unidecode import unidecode
-
import beets
from beets.util import functemplate
from beets.util import py3_path
from beets.dbcore import types
-from .query import MatchQuery, NullSort, TrueQuery, AndQuery
from collections.abc import Mapping
+from .database import FUNCTIONS_TO_REGISTER
+from .query import MatchQuery, NullSort, TrueQuery, AndQuery
+
class DBAccessError(Exception):
"""The SQLite database became inaccessible.
@@ -977,7 +977,9 @@ class Database:
conn = sqlite3.connect(
py3_path(self.path), timeout=self.timeout
)
- self.add_functions(conn)
+
+ for function_def in FUNCTIONS_TO_REGISTER:
+ conn.create_function(*function_def)
if self.supports_extensions:
conn.enable_load_extension(True)
@@ -990,15 +992,6 @@ class Database:
conn.row_factory = sqlite3.Row
return conn
- def add_functions(self, conn):
- def regexp(value, pattern):
- if isinstance(value, bytes):
- value = value.decode()
- return re.search(pattern, str(value)) is not None
-
- conn.create_function("regexp", 2, regexp)
- conn.create_function("unidecode", 1, unidecode)
-
def _close(self):
"""Close the all connections to the underlying SQLite database
from all threads. This does not render the database object
diff --git a/beets/dbcore/query.py b/beets/dbcore/query.py
index 016fe2c1a..166b385cf 100644
--- a/beets/dbcore/query.py
+++ b/beets/dbcore/query.py
@@ -22,6 +22,8 @@ from datetime import datetime, timedelta
import unicodedata
from functools import reduce
+from .database import add_db_function
+
class ParsingError(ValueError):
"""Abstract class for any unparseable user-requested album/query
@@ -231,6 +233,13 @@ class RegexpQuery(StringFieldQuery):
"a regular expression",
format(exc))
+ @staticmethod
+ @add_db_function()
+ def regexp(value, pattern) -> bool:
+ if isinstance(value, bytes):
+ value = value.decode()
+ return re.search(pattern, str(value)) is not None
+
def col_clause(self):
return f" regexp({self.field}, ?)", [self.pattern.pattern]
diff --git a/beetsplug/bareasc.py b/beetsplug/bareasc.py
index 3343786f9..669644894 100644
--- a/beetsplug/bareasc.py
+++ b/beetsplug/bareasc.py
@@ -22,7 +22,9 @@
from beets import ui
from beets.ui import print_, decargs
from beets.plugins import BeetsPlugin
+from beets.dbcore.database import add_db_function
from beets.dbcore.query import StringFieldQuery
+
from unidecode import unidecode
@@ -42,6 +44,11 @@ class BareascQuery(StringFieldQuery):
val = unidecode(val)
return pattern in val
+ @staticmethod
+ @add_db_function(name="unidecode")
+ def _unidecode(value):
+ return unidecode(value)
+
def col_clause(self):
"""Compare ascii version of the pattern."""
clause = f"unidecode({self.field})"