diff options
author | wisp3rwind <17089248+wisp3rwind@users.noreply.github.com> | 2023-07-22 13:27:53 +0200 |
---|---|---|
committer | wisp3rwind <17089248+wisp3rwind@users.noreply.github.com> | 2023-11-05 08:25:39 +0100 |
commit | 05383a0dab8ce17683226f7adc3574ac1610cb5c (patch) | |
tree | 7b347fb2c5a259d83d221b005f53757ef4fbd3b4 | |
parent | 1fdf075402415fd0b58776a7de760a161b684a60 (diff) |
replaygain: typings
also, minor clean-up (remove unused function after_version, f-string
conversion)
-rw-r--r-- | beets/dbcore/__init__.py | 2 | ||||
-rw-r--r-- | beets/library.py | 6 | ||||
-rw-r--r-- | beets/ui/__init__.py | 3 | ||||
-rw-r--r-- | beets/util/__init__.py | 7 | ||||
-rw-r--r-- | beetsplug/replaygain.py | 304 |
5 files changed, 207 insertions, 115 deletions
diff --git a/beets/dbcore/__init__.py b/beets/dbcore/__init__.py index 985b4eb80..baeb10d26 100644 --- a/beets/dbcore/__init__.py +++ b/beets/dbcore/__init__.py @@ -16,7 +16,7 @@ Library. """ -from .db import Database, Model +from .db import Database, Model, Results from .query import ( AndQuery, FieldQuery, diff --git a/beets/library.py b/beets/library.py index 7507f5d34..5ce59852b 100644 --- a/beets/library.py +++ b/beets/library.py @@ -27,7 +27,7 @@ from mediafile import MediaFile, UnreadableFileError import beets from beets import dbcore, logging, plugins, util -from beets.dbcore import types +from beets.dbcore import Results, types from beets.util import ( MoveOperation, bytestring_path, @@ -1665,11 +1665,11 @@ class Library(dbcore.Database): Item, beets.config["sort_item"].as_str_seq() ) - def albums(self, query=None, sort=None): + def albums(self, query=None, sort=None) -> Results[Album]: """Get :class:`Album` objects matching the query.""" return self._fetch(Album, query, sort or self.get_default_album_sort()) - def items(self, query=None, sort=None): + def items(self, query=None, sort=None) -> Results[Item]: """Get :class:`Item` objects matching the query.""" return self._fetch(Item, query, sort or self.get_default_item_sort()) diff --git a/beets/ui/__init__.py b/beets/ui/__init__.py index ae68e6413..ef96c9c38 100644 --- a/beets/ui/__init__.py +++ b/beets/ui/__init__.py @@ -28,6 +28,7 @@ import sys import textwrap import traceback from difflib import SequenceMatcher +from typing import Any, Callable, List import confuse @@ -1450,6 +1451,8 @@ class Subcommand: invoked by a SubcommandOptionParser. """ + func: Callable[[library.Library, optparse.Values, List[str]], Any] + def __init__(self, name, parser=None, help="", aliases=(), hide=False): """Creates a new subcommand. name is the primary way to invoke the subcommand; aliases are alternate names. parser is an diff --git a/beets/util/__init__.py b/beets/util/__init__.py index fb07d7abc..00558e90a 100644 --- a/beets/util/__init__.py +++ b/beets/util/__init__.py @@ -405,7 +405,10 @@ def bytestring_path(path: Bytes_or_String) -> bytes: PATH_SEP: bytes = bytestring_path(os.sep) -def displayable_path(path: bytes, separator: str = "; ") -> str: +def displayable_path( + path: Union[bytes, str, Tuple[Union[bytes, str], ...]], + separator: str = "; ", +) -> str: """Attempts to decode a bytestring path to a unicode object for the purpose of displaying it to the user. If the `path` argument is a list or a tuple, the elements are joined with `separator`. @@ -801,7 +804,7 @@ def legalize_path( return second_stage_path, retruncated -def py3_path(path: AnyStr) -> str: +def py3_path(path: Union[bytes, str]) -> str: """Convert a bytestring path to Unicode. This helps deal with APIs on Python 3 that *only* accept Unicode diff --git a/beetsplug/replaygain.py b/beetsplug/replaygain.py index 639bb3754..b04c40862 100644 --- a/beetsplug/replaygain.py +++ b/beetsplug/replaygain.py @@ -16,16 +16,38 @@ import collections import enum import math +import optparse import os import queue import signal import subprocess import sys import warnings +from abc import ABC, abstractmethod +from dataclasses import dataclass +from logging import Logger from multiprocessing.pool import ThreadPool from threading import Event, Thread +from typing import ( + Any, + Callable, + DefaultDict, + Dict, + List, + Optional, + Sequence, + Tuple, + Type, + TypeVar, + Union, + cast, +) + +from confuse import ConfigView from beets import ui +from beets.importer import ImportSession, ImportTask +from beets.library import Album, Item, Library from beets.plugins import BeetsPlugin from beets.util import ( command_output, @@ -53,7 +75,7 @@ class FatalGstreamerPluginReplayGainError(FatalReplayGainError): loading the required plugins.""" -def call(args, log, **kwargs): +def call(args: List[Any], log: Logger, **kwargs: Any): """Execute the command and return its output or raise a ReplayGainError on failure. """ @@ -71,13 +93,7 @@ def call(args, log, **kwargs): raise ReplayGainError("argument encoding failed") -def after_version(version_a, version_b): - return tuple(int(s) for s in version_a.split(".")) >= tuple( - int(s) for s in version_b.split(".") - ) - - -def db_to_lufs(db): +def db_to_lufs(db: float) -> float: """Convert db to LUFS. According to https://wiki.hydrogenaud.io/index.php?title= @@ -86,7 +102,7 @@ def db_to_lufs(db): return db - 107 -def lufs_to_db(db): +def lufs_to_db(db: float) -> float: """Convert LUFS to db. According to https://wiki.hydrogenaud.io/index.php?title= @@ -97,9 +113,13 @@ def lufs_to_db(db): # Backend base and plumbing classes. -# gain: in LU to reference level -# peak: part of full scale (FS is 1.0) -Gain = collections.namedtuple("Gain", "gain peak") + +@dataclass +class Gain: + # gain: in LU to reference level + gain: float + # peak: part of full scale (FS is 1.0) + peak: float class PeakMethod(enum.Enum): @@ -118,7 +138,13 @@ class RgTask: """ def __init__( - self, items, album, target_level, peak_method, backend_name, log + self, + items: Sequence[Item], + album: Optional[Album], + target_level: float, + peak_method: Optional[PeakMethod], + backend_name: str, + log: Logger, ): self.items = items self.album = album @@ -126,10 +152,10 @@ class RgTask: self.peak_method = peak_method self.backend_name = backend_name self._log = log - self.album_gain = None - self.track_gains = None + self.album_gain: Optional[Gain] = None + self.track_gains: Optional[List[Gain]] = None - def _store_track_gain(self, item, track_gain): + def _store_track_gain(self, item: Item, track_gain: Gain): """Store track gain for a single item in the database.""" item.rg_track_gain = track_gain.gain item.rg_track_peak = track_gain.peak @@ -140,13 +166,13 @@ class RgTask: item.rg_track_peak, ) - def _store_album_gain(self, item): + def _store_album_gain(self, item: Item, album_gain: Gain): """Store album gain for a single item in the database. The caller needs to ensure that `self.album_gain is not None`. """ - item.rg_album_gain = self.album_gain.gain - item.rg_album_peak = self.album_gain.peak + item.rg_album_gain = album_gain.gain + item.rg_album_peak = album_gain.peak item.store() self._log.debug( "applied album gain {0} LU, peak {1} of FS", @@ -154,7 +180,7 @@ class RgTask: item.rg_album_peak, ) - def _store_track(self, write): + def _store_track(self, write: bool): """Store track gain for the first track of the task in the database.""" item = self.items[0] if self.track_gains is None or len(self.track_gains) != 1: @@ -172,7 +198,7 @@ class RgTask: item.try_write() self._log.debug("done analyzing {0}", item) - def _store_album(self, write): + def _store_album(self, write: bool): """Store track/album gains for all tracks of the task in the database.""" if ( self.album_gain is None @@ -190,12 +216,12 @@ class RgTask: ) for item, track_gain in zip(self.items, self.track_gains): self._store_track_gain(item, track_gain) - self._store_album_gain(item) + self._store_album_gain(item, self.album_gain) if write: item.try_write() self._log.debug("done analyzing {0}", item) - def store(self, write): + def store(self, write: bool): """Store computed gains for the items of this task in the database.""" if self.album is not None: self._store_album(write) @@ -213,44 +239,56 @@ class R128Task(RgTask): tags. """ - def __init__(self, items, album, target_level, backend_name, log): + def __init__( + self, + items: Sequence[Item], + album: Optional[Album], + target_level: float, + backend_name: str, + log: Logger, + ): # R128_* tags do not store the track/album peak super().__init__(items, album, target_level, None, backend_name, log) - def _store_track_gain(self, item, track_gain): + def _store_track_gain(self, item: Item, track_gain: Gain): item.r128_track_gain = track_gain.gain item.store() self._log.debug("applied r128 track gain {0} LU", item.r128_track_gain) - def _store_album_gain(self, item): + def _store_album_gain(self, item: Item, album_gain: Gain): """ The caller needs to ensure that `self.album_gain is not None`. """ - item.r128_album_gain = self.album_gain.gain + item.r128_album_gain = album_gain.gain item.store() self._log.debug("applied r128 album gain {0} LU", item.r128_album_gain) -class Backend: +AnyRgTask = TypeVar("AnyRgTask", bound=RgTask) + + +class Backend(ABC): """An abstract class representing engine for calculating RG values.""" NAME = "" do_parallel = False - def __init__(self, config, log): + def __init__(self, config: ConfigView, log: Logger): """Initialize the backend with the configuration view for the plugin. """ self._log = log - def compute_track_gain(self, task): + @abstractmethod + def compute_track_gain(self, task: AnyRgTask) -> AnyRgTask: """Computes the track gain for the tracks belonging to `task`, and sets the `track_gains` attribute on the task. Returns `task`. """ raise NotImplementedError() - def compute_album_gain(self, task): + @abstractmethod + def compute_album_gain(self, task: AnyRgTask) -> AnyRgTask: """Computes the album gain for the album belonging to `task`, and sets the `album_gain` attribute on the task. Returns `task`. """ @@ -264,7 +302,7 @@ class FfmpegBackend(Backend): NAME = "ffmpeg" do_parallel = True - def __init__(self, config, log): + def __init__(self, config: ConfigView, log: Logger): super().__init__(config, log) self._ffmpeg_path = "ffmpeg" @@ -292,7 +330,7 @@ class FfmpegBackend(Backend): "the --enable-libebur128 configuration option is required." ) - def compute_track_gain(self, task): + def compute_track_gain(self, task: AnyRgTask) -> AnyRgTask: """Computes the track gain for the tracks belonging to `task`, and sets the `track_gains` attribute on the task. Returns `task`. """ @@ -310,7 +348,7 @@ class FfmpegBackend(Backend): return task - def compute_album_gain(self, task): + def compute_album_gain(self, task: AnyRgTask) -> AnyRgTask: """Computes the album gain for the album belonging to `task`, and sets the `album_gain` attribute on the task. Returns `task`. """ @@ -318,7 +356,7 @@ class FfmpegBackend(Backend): # analyse tracks # Gives a list of tuples (track_gain, track_n_blocks) - track_results = [ + track_results: List[Tuple[Gain, int]] = [ self._analyse_item( item, task.target_level, @@ -328,8 +366,7 @@ class FfmpegBackend(Backend): for item in task.items ] - # list of track Gain objects - track_gains = [tg for tg, _nb in track_results] + track_gains: List[Gain] = [tg for tg, _nb in track_results] # Album peak is maximum track peak album_peak = max(tg.peak for tg in track_gains) @@ -337,7 +374,7 @@ class FfmpegBackend(Backend): # Total number of BS.1770 gating blocks n_blocks = sum(nb for _tg, nb in track_results) - def sum_of_track_powers(track_gain, track_n_blocks): + def sum_of_track_powers(track_gain: Gain, track_n_blocks: int): # convert `LU to target_level` -> LUFS loudness = target_level_lufs - track_gain.gain @@ -363,6 +400,7 @@ class FfmpegBackend(Backend): album_gain = -0.691 + 10 * math.log10(sum_powers / n_blocks) else: album_gain = -70 + # convert LUFS -> `LU to target_level` album_gain = target_level_lufs - album_gain @@ -378,7 +416,9 @@ class FfmpegBackend(Backend): return task - def _construct_cmd(self, item, peak_method): + def _construct_cmd( + self, item: Item, peak_method: Optional[PeakMethod] + ) -> List[Union[str, bytes]]: """Construct the shell command to analyse items.""" return [ self._ffmpeg_path, @@ -397,7 +437,13 @@ class FfmpegBackend(Backend): "-", ] - def _analyse_item(self, item, target_level, peak_method, count_blocks=True): + def _analyse_item( + self, + item: Item, + target_level: float, + peak_method: Optional[PeakMethod], + count_blocks: bool = True, + ) -> Tuple[Gain, int]: """Analyse item. Return a pair of a Gain object and the number of gating blocks above the threshold. @@ -415,7 +461,7 @@ class FfmpegBackend(Backend): # parse output if peak_method is None: - peak = 0 + peak = 0.0 else: line_peak = self._find_line( output, @@ -486,7 +532,13 @@ class FfmpegBackend(Backend): return Gain(gain, peak), n_blocks - def _find_line(self, output, search, start_line=0, step_size=1): + def _find_line( + self, + output: Sequence[bytes], + search: bytes, + start_line: int = 0, + step_size: int = 1, + ) -> int: """Return index of line beginning with `search`. Begins searching at index `start_line` in `output`. @@ -501,19 +553,19 @@ class FfmpegBackend(Backend): ) ) - def _parse_float(self, line): + def _parse_float(self, line: bytes) -> float: """Extract a float from a key value pair in `line`. This format is expected: /[^:]:[[:space:]]*value.*/, where `value` is the float. """ # extract value - value = line.split(b":", 1) - if len(value) < 2: + parts = line.split(b":", 1) + if len(parts) < 2: raise ReplayGainError( - "ffmpeg output: expected key value pair, found {}".format(line) + f"ffmpeg output: expected key value pair, found {line!r}" ) - value = value[1].lstrip() + value = parts[1].lstrip() # strip unit value = value.split(b" ", 1)[0] # cast value to float @@ -521,7 +573,7 @@ class FfmpegBackend(Backend): return float(value) except ValueError: raise ReplayGainError( - "ffmpeg output: expected float value, found {}".format(value) + f"ffmpeg output: expected float value, found {value!r}" ) @@ -530,7 +582,7 @@ class CommandBackend(Backend): NAME = "command" do_parallel = True - def __init__(self, config, log): + def __init__(self, config: ConfigView, log: Logger): super().__init__(config, log) config.add( { @@ -539,7 +591,7 @@ class CommandBackend(Backend): } ) - self.command = config["command"].as_str() + self.command = cast(str, config["command"].as_str()) if self.command: # Explicit executable path. @@ -562,7 +614,7 @@ class CommandBackend(Backend): self.noclip = config["noclip"].get(bool) - def compute_track_gain(self, task): + def compute_track_gain(self, task: AnyRgTask) -> AnyRgTask: """Computes the track gain for the tracks belonging to `task`, and sets the `track_gains` attribute on the task. Returns `task`. """ @@ -571,7 +623,7 @@ class CommandBackend(Backend): task.track_gains = output return task - def compute_album_gain(self, task): + def compute_album_gain(self, task: AnyRgTask) -> AnyRgTask: """Computes the album gain for the album belonging to `task`, and sets the `album_gain` attribute on the task. Returns `task`. """ @@ -590,7 +642,7 @@ class CommandBackend(Backend): task.track_gains = output[:-1] return task - def format_supported(self, item): + def format_supported(self, item: Item) -> bool: """Checks whether the given item is supported by the selected tool.""" if "mp3gain" in self.command and item.format != "MP3": return False @@ -598,7 +650,12 @@ class CommandBackend(Backend): return False return True - def compute_gain(self, items, target_level, is_album): + def compute_gain( + self, + items: Sequence[Item], + target_level: float, + is_album: bool, + ) -> List[Gain]: """Computes the track or album gain of a list of items, returns a list of TrackGain objects. @@ -618,7 +675,7 @@ class CommandBackend(Backend): # tag-writing; this turns the mp3gain/aacgain tool into a gain # calculator rather than a tag manipulator because we take care # of changing tags ourselves. - cmd = [self.command, "-o", "-s", "s"] + cmd: List[Union[bytes, str]] = [self.command, "-o", "-s", "s"] if self.noclip: # Adjust to avoid clipping. cmd = cmd + ["-k"] @@ -636,7 +693,7 @@ class CommandBackend(Backend): output, len(items) + (1 if is_album else 0) ) - def parse_tool_output(self, text, num_lines): + def parse_tool_output(self, text: bytes, num_lines: int) -> List[Gain]: """Given the tab-delimited output from an invocation of mp3gain or aacgain, parse the text and return a list of dictionaries containing information about each analyzed file. @@ -647,15 +704,15 @@ class CommandBackend(Backend): if len(parts) != 6 or parts[0] == b"File": self._log.debug("bad tool output: {0}", text) raise ReplayGainError("mp3gain failed") - d = { - "file": parts[0], - "mp3gain": int(parts[1]), - "gain": float(parts[2]), - "peak": float(parts[3]) / (1 << 15), - "maxgain": int(parts[4]), - "mingain": int(parts[5]), - } - out.append(Gain(d["gain"], d["peak"])) + + # _file = parts[0] + # _mp3gain = int(parts[1]) + gain = float(parts[2]) + peak = float(parts[3]) / (1 << 15) + # _maxgain = int(parts[4]) + # _mingain = int(parts[5]) + + out.append(Gain(gain, peak)) return out @@ -665,7 +722,7 @@ class CommandBackend(Backend): class GStreamerBackend(Backend): NAME = "gstreamer" - def __init__(self, config, log): + def __init__(self, config: ConfigView, log: Logger): super().__init__(config, log) self._import_gst() @@ -722,7 +779,7 @@ class GStreamerBackend(Backend): self._main_loop = self.GLib.MainLoop() - self._files = [] + self._files: List[bytes] = [] def _import_gst(self): """Import the necessary GObject-related modules and assign `Gst` @@ -754,14 +811,17 @@ class GStreamerBackend(Backend): self.GLib = GLib self.Gst = Gst - def compute(self, items, target_level, album): + def compute(self, items: Sequence[Item], target_level: float, album: bool): if len(items) == 0: return self._error = None self._files = [i.path for i in items] - self._file_tags = collections.defaultdict(dict) + # FIXME: Turn this into DefaultDict[bytes, Gain] + self._file_tags: DefaultDict[ + bytes, Dict[str, float] + ] = collections.defaultdict(dict) self._rg.set_property("reference-level", target_level) @@ -773,7 +833,7 @@ class GStreamerBackend(Backend): if self._error is not None: raise self._error - def compute_track_gain(self, task): + def compute_track_gain(self, task: AnyRgTask) -> AnyRgTask: """Computes the track gain for the tracks belonging to `task`, and sets the `track_gains` attribute on the task. Returns `task`. """ @@ -793,7 +853,7 @@ class GStreamerBackend(Backend): task.track_gains = ret return task - def compute_album_gain(self, task): + def compute_album_gain(self, task: AnyRgTask) -> AnyRgTask: """Computes the album gain for the album belonging to `task`, and sets the `album_gain` attribute on the task. Returns `task`. """ @@ -876,7 +936,7 @@ class GStreamerBackend(Backend): tags.foreach(handle_tag, None) - def _set_first_file(self): + def _set_first_file(self) -> bool: if len(self._files) == 0: return False @@ -886,7 +946,7 @@ class GStreamerBackend(Backend): self._pipe.set_state(self.Gst.State.PLAYING) return True - def _set_file(self): + def _set_file(self) -> bool: """Initialize the filesrc element with the next file to be analyzed.""" # No more files, we're done if len(self._files) == 0: @@ -919,7 +979,7 @@ class GStreamerBackend(Backend): return True - def _set_next_file(self): + def _set_next_file(self) -> bool: """Set the next file to be analyzed while keeping the pipeline in the PAUSED state so that the rganalysis element can correctly handle album gain. @@ -960,7 +1020,7 @@ class AudioToolsBackend(Backend): NAME = "audiotools" - def __init__(self, config, log): + def __init__(self, config: ConfigView, log: Logger): super().__init__(config, log) self._import_audiotools() @@ -980,7 +1040,7 @@ class AudioToolsBackend(Backend): self._mod_audiotools = audiotools self._mod_replaygain = audiotools.replaygain - def open_audio_file(self, item): + def open_audio_file(self, item: Item): """Open the file to read the PCM stream from the using ``item.path``. @@ -998,7 +1058,7 @@ class AudioToolsBackend(Backend): return audiofile - def init_replaygain(self, audiofile, item): + def init_replaygain(self, audiofile, item: Item): """Return an initialized :class:`audiotools.replaygain.ReplayGain` instance, which requires the sample rate of the song(s) on which the ReplayGain values will be computed. The item is passed in case @@ -1015,7 +1075,7 @@ class AudioToolsBackend(Backend): return return rg - def compute_track_gain(self, task): + def compute_track_gain(self, task: AnyRgTask) -> AnyRgTask: """Computes the track gain for the tracks belonging to `task`, and sets the `track_gains` attribute on the task. Returns `task`. """ @@ -1025,14 +1085,14 @@ class AudioToolsBackend(Backend): task.track_gains = gains return task - def _with_target_level(self, gain, target_level): + def _with_target_level(self, gain: float, target_level: float): """Return `gain` relative to `target_level`. Assumes `gain` is relative to 89 db. """ return gain + (target_level - 89) - def _title_gain(self, rg, audiofile, target_level): + def _title_gain(self, rg, audiofile, target_level: float): """Get the gain result pair from PyAudioTools using the `ReplayGain` instance `rg` for the given `audiofile`. @@ -1050,7 +1110,7 @@ class AudioToolsBackend(Backend): raise ReplayGainError("audiotools audio data error") return self._with_target_level(gain, target_level), peak - def _compute_track_gain(self, item, target_level): + def _compute_track_gain(self, item: Item, target_level: float): """Compute ReplayGain value for the requested item. :rtype: :class:`Gain` @@ -1073,7 +1133,7 @@ class AudioToolsBackend(Backend): ) return Gain(gain=rg_track_gain, peak=rg_track_peak) - def compute_album_gain(self, task): + def compute_album_gain(self, task: AnyRgTask) -> AnyRgTask: """Computes the album gain for the album belonging to `task`, and sets the `album_gain` attribute on the task. Returns `task`. """ @@ -1121,7 +1181,7 @@ class ExceptionWatcher(Thread): Once an exception occurs, raise it and execute a callback. """ - def __init__(self, queue, callback): + def __init__(self, queue: queue.Queue, callback: Callable[[], None]): self._queue = queue self._callback = callback self._stopevent = Event() @@ -1138,20 +1198,20 @@ class ExceptionWatcher(Thread): # whether `_stopevent` is set pass - def join(self, timeout=None): + def join(self, timeout: Optional[float] = None): self._stopevent.set() Thread.join(self, timeout) # Main plugin logic. -BACKEND_CLASSES = [ +BACKEND_CLASSES: List[Type[Backend]] = [ CommandBackend, GStreamerBackend, AudioToolsBackend, FfmpegBackend, ] -BACKENDS = {b.NAME: b for b in BACKEND_CLASSES} +BACKENDS: Dict[str, Type[Backend]] = {b.NAME: b for b in BACKEND_CLASSES} class ReplayGainPlugin(BeetsPlugin): @@ -1178,7 +1238,7 @@ class ReplayGainPlugin(BeetsPlugin): # FIXME: Consider renaming the configuration option and deprecating the # old name 'overwrite'. - self.force_on_import = self.config["overwrite"].get(bool) + self.force_on_import = cast(bool, self.config["overwrite"].get(bool)) # Remember which backend is used for CLI feedback self.backend_name = self.config["backend"].as_str() @@ -1224,21 +1284,21 @@ class ReplayGainPlugin(BeetsPlugin): # Start threadpool lazily. self.pool = None - def should_use_r128(self, item): + def should_use_r128(self, item: Item) -> bool: """Checks the plugin setting to decide whether the calculation should be done using the EBU R128 standard and use R128_ tags instead. """ return item.format in self.r128_whitelist @staticmethod - def has_r128_track_data(item): + def has_r128_track_data(item: Item) -> bool: return item.r128_track_gain is not None @staticmethod - def has_rg_track_data(item): + def has_rg_track_data(item: Item) -> bool: return item.rg_track_gain is not None and item.rg_track_peak is not None - def track_requires_gain(self, item): + def track_requires_gain(self, item: Item) -> bool: if self.should_use_r128(item): if not self.has_r128_track_data(item): return True @@ -1249,17 +1309,17 @@ class ReplayGainPlugin(BeetsPlugin): return False @staticmethod - def has_r128_album_data(item): + def has_r128_album_data(item: Item) -> bool: return ( item.r128_track_gain is not None and item.r128_album_gain is not None ) @staticmethod - def has_rg_album_data(item): + def has_rg_album_data(item: Item) -> bool: return item.rg_album_gain is not None and item.rg_album_peak is not None - def album_requires_gain(self, album): + def album_requires_gain(self, album: Album) -> bool: # Skip calculating gain only when *all* files don't need # recalculation. This way, if any file among an album's tracks # needs recalculation, we still get an accurate album gain @@ -1274,7 +1334,12 @@ class ReplayGainPlugin(BeetsPlugin): return False - def create_task(self, items, use_r128, album=None): + def create_task( + self, + items: Sequence[Item], + use_r128: bool, + album: Optional[Album] = None, + ) -> RgTask: if use_r128: return R128Task( items, @@ -1293,7 +1358,7 @@ class ReplayGainPlugin(BeetsPlugin): self._log, ) - def handle_album(self, album, write, force=False): + def handle_album(self, album: Album, write: bool, force: bool = False): """Compute album and track replay gain store it in all of the album's items. @@ -1316,7 +1381,7 @@ class ReplayGainPlugin(BeetsPlugin): self._log.info("analyzing {0}", album) - discs = {} + discs: Dict[int, List[Item]] = {} if self.config["per_disc"].get(bool): for item in album.items(): if discs.get(item.disc) is None: @@ -1325,6 +1390,9 @@ class ReplayGainPlugin(BeetsPlugin): else: discs[1] = album.items() + def store_cb(task: RgTask): + task.store(write) + for discnumber, items in discs.items(): task = self.create_task(items, use_r128, album=album) try: @@ -1332,14 +1400,14 @@ class ReplayGainPlugin(BeetsPlugin): self.backend_instance.compute_album_gain, args=[task], kwds={}, - callback=lambda task: task.store(write), + callback=store_cb, ) except ReplayGainError as e: self._log.info("ReplayGain error: {0}", e) except FatalReplayGainError as e: raise ui.UserError(f"Fatal replay gain error: {e}") - def handle_track(self, item, write, force=False): + def handle_track(self, item: Item, write: bool, force: bool = False): """Compute track replay gain and store it in the item. If ``write`` is truthy then ``item.write()`` is called to write @@ -1352,24 +1420,27 @@ class ReplayGainPlugin(BeetsPlugin): use_r128 = self.should_use_r128(item) + def store_cb(task: RgTask): + task.store(write) + task = self.create_task([item], use_r128) try: self._apply( self.backend_instance.compute_track_gain, args=[task], kwds={}, - callback=lambda task: task.store(write), + callback=store_cb, ) except ReplayGainError as e: self._log.info("ReplayGain error: {0}", e) except FatalReplayGainError as e: raise ui.UserError(f"Fatal replay gain error: {e}") - def open_pool(self, threads): + def open_pool(self, threads: int): """Open a `ThreadPool` instance in `self.pool`""" if self.pool is None and self.backend_instance.do_parallel: self.pool = ThreadPool(threads) - self.exc_queue = queue.Queue() + self.exc_queue: queue.Queue[Exception] = queue.Queue() signal.signal(signal.SIGINT, self._interrupt) @@ -1379,7 +1450,13 @@ class ReplayGainPlugin(BeetsPlugin): ) self.exc_watcher.start() - def _apply(self, func, args, kwds, callback): + def _apply( + self, + func: Callable[..., AnyRgTask], + args: List[Any], + kwds: Dict[str, Any], + callback: Callable[[AnyRgTask], Any], + ): if self.pool is not None: def handle_exc(exc): @@ -1425,9 +1502,9 @@ class ReplayGainPlugin(BeetsPlugin): self.exc_watcher.join() self.pool = None - def import_begin(self, session): + def import_begin(self, session: ImportSession): """Handle `import_begin` event -> open pool""" - threads = self.config["threads"].get(int) + threads = cast(int, self.config["threads"].get(int)) if ( self.config["parallel_on_import"] @@ -1440,22 +1517,31 @@ class ReplayGainPlugin(BeetsPlugin): """Handle `import` event -> close pool""" self.close_pool() - def imported(self, session, task): + def imported(self, session: ImportSession, task: ImportTask): """Add replay gain info to items or albums of ``task``.""" if self.config["auto"]: if task.is_album: self.handle_album(task.album, False, self.force_on_import) else: + # Should be a SingletonImportTask + assert hasattr(task, "item") self.handle_track(task.item, False, self.force_on_import) - def command_func(self, lib, opts, args): + def command_func( + self, + lib: Library, + opts: optparse.Values, + args: List[str], + ): try: write = ui.should_write(opts.write) force = opts.force # Bypass self.open_pool() if called with `--threads 0` if opts.threads != 0: - threads = opts.thread |