diff options
-rw-r--r-- | beets/dbcore/database.py | 20 | ||||
-rwxr-xr-x | beets/dbcore/db.py | 19 | ||||
-rw-r--r-- | beets/dbcore/query.py | 9 | ||||
-rw-r--r-- | beetsplug/bareasc.py | 7 |
4 files changed, 42 insertions, 13 deletions
diff --git a/beets/dbcore/database.py b/beets/dbcore/database.py new file mode 100644 index 000000000..511f9ab99 --- /dev/null +++ b/beets/dbcore/database.py @@ -0,0 +1,20 @@ +import inspect +from typing import Callable, NamedTuple, Optional, Set + + +class FunctionDef(NamedTuple): + name: str + arg_count: int + func: Callable + + +FUNCTIONS_TO_REGISTER: Set[FunctionDef] = set() + + +def add_db_function(name: Optional[str] = None) -> Callable: + def add_db_function_to_register(func: Callable) -> Callable: + num_args = len(inspect.signature(func).parameters) + FUNCTIONS_TO_REGISTER.add(FunctionDef(name or func.__name__, num_args, func)) + return func + + return add_db_function_to_register diff --git a/beets/dbcore/db.py b/beets/dbcore/db.py index 94396f81b..c01be9669 100755 --- a/beets/dbcore/db.py +++ b/beets/dbcore/db.py @@ -23,15 +23,15 @@ import threading import sqlite3 import contextlib -from unidecode import unidecode - import beets from beets.util import functemplate from beets.util import py3_path from beets.dbcore import types -from .query import MatchQuery, NullSort, TrueQuery, AndQuery from collections.abc import Mapping +from .database import FUNCTIONS_TO_REGISTER +from .query import MatchQuery, NullSort, TrueQuery, AndQuery + class DBAccessError(Exception): """The SQLite database became inaccessible. @@ -977,7 +977,9 @@ class Database: conn = sqlite3.connect( py3_path(self.path), timeout=self.timeout ) - self.add_functions(conn) + + for function_def in FUNCTIONS_TO_REGISTER: + conn.create_function(*function_def) if self.supports_extensions: conn.enable_load_extension(True) @@ -990,15 +992,6 @@ class Database: conn.row_factory = sqlite3.Row return conn - def add_functions(self, conn): - def regexp(value, pattern): - if isinstance(value, bytes): - value = value.decode() - return re.search(pattern, str(value)) is not None - - conn.create_function("regexp", 2, regexp) - conn.create_function("unidecode", 1, unidecode) - def _close(self): """Close the all connections to the underlying SQLite database from all threads. This does not render the database object diff --git a/beets/dbcore/query.py b/beets/dbcore/query.py index 016fe2c1a..166b385cf 100644 --- a/beets/dbcore/query.py +++ b/beets/dbcore/query.py @@ -22,6 +22,8 @@ from datetime import datetime, timedelta import unicodedata from functools import reduce +from .database import add_db_function + class ParsingError(ValueError): """Abstract class for any unparseable user-requested album/query @@ -231,6 +233,13 @@ class RegexpQuery(StringFieldQuery): "a regular expression", format(exc)) + @staticmethod + @add_db_function() + def regexp(value, pattern) -> bool: + if isinstance(value, bytes): + value = value.decode() + return re.search(pattern, str(value)) is not None + def col_clause(self): return f" regexp({self.field}, ?)", [self.pattern.pattern] diff --git a/beetsplug/bareasc.py b/beetsplug/bareasc.py index 3343786f9..669644894 100644 --- a/beetsplug/bareasc.py +++ b/beetsplug/bareasc.py @@ -22,7 +22,9 @@ from beets import ui from beets.ui import print_, decargs from beets.plugins import BeetsPlugin +from beets.dbcore.database import add_db_function from beets.dbcore.query import StringFieldQuery + from unidecode import unidecode @@ -42,6 +44,11 @@ class BareascQuery(StringFieldQuery): val = unidecode(val) return pattern in val + @staticmethod + @add_db_function(name="unidecode") + def _unidecode(value): + return unidecode(value) + def col_clause(self): """Compare ascii version of the pattern.""" clause = f"unidecode({self.field})" |