diff options
author | Šarūnas Nejus <snejus@protonmail.com> | 2023-04-08 05:12:21 +0100 |
---|---|---|
committer | Šarūnas Nejus <snejus@protonmail.com> | 2023-04-09 20:25:33 +0100 |
commit | 2a6fd33d56895fa52df464e2849db8084c148590 (patch) | |
tree | 4e61666ad67f6a1010f8bf548441ae90693178da | |
parent | d05c34ec4e9d8fa9e34706ac2279b7315b36e1e0 (diff) |
Add the ability to register arbitrary db functionsgeneric-db-functions
Add a decorator which records the function it wraps.
In the DB initialization stage, define each recorded function as an
SQLite function.
This way, the SQLite function definition stays together with the SQL
query that uses it. It also allows flexibility, since 'add_db_function'
can be imported from anywhere.
Its implementation had to be moved away from `beets/dbcore/db.py`
because it would have caused circular import issues between `db.py` and
`query.py`.
See https://github.com/beetbox/beets/pull/4741 for the context.
-rw-r--r-- | beets/dbcore/database.py | 20 | ||||
-rwxr-xr-x | beets/dbcore/db.py | 19 | ||||
-rw-r--r-- | beets/dbcore/query.py | 9 | ||||
-rw-r--r-- | beetsplug/bareasc.py | 7 |
4 files changed, 42 insertions, 13 deletions
diff --git a/beets/dbcore/database.py b/beets/dbcore/database.py new file mode 100644 index 000000000..511f9ab99 --- /dev/null +++ b/beets/dbcore/database.py @@ -0,0 +1,20 @@ +import inspect +from typing import Callable, NamedTuple, Optional, Set + + +class FunctionDef(NamedTuple): + name: str + arg_count: int + func: Callable + + +FUNCTIONS_TO_REGISTER: Set[FunctionDef] = set() + + +def add_db_function(name: Optional[str] = None) -> Callable: + def add_db_function_to_register(func: Callable) -> Callable: + num_args = len(inspect.signature(func).parameters) + FUNCTIONS_TO_REGISTER.add(FunctionDef(name or func.__name__, num_args, func)) + return func + + return add_db_function_to_register diff --git a/beets/dbcore/db.py b/beets/dbcore/db.py index 94396f81b..c01be9669 100755 --- a/beets/dbcore/db.py +++ b/beets/dbcore/db.py @@ -23,15 +23,15 @@ import threading import sqlite3 import contextlib -from unidecode import unidecode - import beets from beets.util import functemplate from beets.util import py3_path from beets.dbcore import types -from .query import MatchQuery, NullSort, TrueQuery, AndQuery from collections.abc import Mapping +from .database import FUNCTIONS_TO_REGISTER +from .query import MatchQuery, NullSort, TrueQuery, AndQuery + class DBAccessError(Exception): """The SQLite database became inaccessible. @@ -977,7 +977,9 @@ class Database: conn = sqlite3.connect( py3_path(self.path), timeout=self.timeout ) - self.add_functions(conn) + + for function_def in FUNCTIONS_TO_REGISTER: + conn.create_function(*function_def) if self.supports_extensions: conn.enable_load_extension(True) @@ -990,15 +992,6 @@ class Database: conn.row_factory = sqlite3.Row return conn - def add_functions(self, conn): - def regexp(value, pattern): - if isinstance(value, bytes): - value = value.decode() - return re.search(pattern, str(value)) is not None - - conn.create_function("regexp", 2, regexp) - conn.create_function("unidecode", 1, unidecode) - def _close(self): """Close the all connections to the underlying SQLite database from all threads. This does not render the database object diff --git a/beets/dbcore/query.py b/beets/dbcore/query.py index 016fe2c1a..166b385cf 100644 --- a/beets/dbcore/query.py +++ b/beets/dbcore/query.py @@ -22,6 +22,8 @@ from datetime import datetime, timedelta import unicodedata from functools import reduce +from .database import add_db_function + class ParsingError(ValueError): """Abstract class for any unparseable user-requested album/query @@ -231,6 +233,13 @@ class RegexpQuery(StringFieldQuery): "a regular expression", format(exc)) + @staticmethod + @add_db_function() + def regexp(value, pattern) -> bool: + if isinstance(value, bytes): + value = value.decode() + return re.search(pattern, str(value)) is not None + def col_clause(self): return f" regexp({self.field}, ?)", [self.pattern.pattern] diff --git a/beetsplug/bareasc.py b/beetsplug/bareasc.py index 3343786f9..669644894 100644 --- a/beetsplug/bareasc.py +++ b/beetsplug/bareasc.py @@ -22,7 +22,9 @@ from beets import ui from beets.ui import print_, decargs from beets.plugins import BeetsPlugin +from beets.dbcore.database import add_db_function from beets.dbcore.query import StringFieldQuery + from unidecode import unidecode @@ -42,6 +44,11 @@ class BareascQuery(StringFieldQuery): val = unidecode(val) return pattern in val + @staticmethod + @add_db_function(name="unidecode") + def _unidecode(value): + return unidecode(value) + def col_clause(self): """Compare ascii version of the pattern.""" clause = f"unidecode({self.field})" |