diff options
author | Šarūnas Nejus <snejus@protonmail.com> | 2024-05-01 05:10:03 +0100 |
---|---|---|
committer | Šarūnas Nejus <snejus@protonmail.com> | 2024-05-04 13:04:17 +0100 |
commit | adac9260b26f9092ac4b4bb8c710e66725dd73d2 (patch) | |
tree | 494b3659afcdbe69d4ead0bf15f3c1fb65eff0d6 | |
parent | 329098a0b4f009b91d0a981112ffe0a9f96484f3 (diff) |
Fix types in beets.util.__init__
-rw-r--r-- | beets/util/__init__.py | 134 |
1 files changed, 71 insertions, 63 deletions
diff --git a/beets/util/__init__.py b/beets/util/__init__.py index 4335e0f3e..87bb96ed4 100644 --- a/beets/util/__init__.py +++ b/beets/util/__init__.py @@ -24,12 +24,13 @@ import shlex import shutil import subprocess import sys -import tempfile import traceback from collections import Counter, namedtuple +from contextlib import suppress from enum import Enum from logging import Logger from multiprocessing.pool import ThreadPool +from tempfile import NamedTemporaryFile from typing import ( Any, AnyStr, @@ -54,7 +55,7 @@ from beets.util import hidden MAX_FILENAME_LENGTH = 200 WINDOWS_MAGIC_PREFIX = "\\\\?\\" T = TypeVar("T") -Bytes_or_String: TypeAlias = Union[str, bytes] +BytesOrStr: TypeAlias = Union[str, bytes] class HumanReadableException(Exception): @@ -154,7 +155,7 @@ class MoveOperation(Enum): REFLINK_AUTO = 5 -def normpath(path: bytes) -> bytes: +def normpath(path: AnyStr) -> bytes: """Provide the canonical form of the path suitable for storing in the database. """ @@ -163,7 +164,7 @@ def normpath(path: bytes) -> bytes: return bytestring_path(path) -def ancestry(path: bytes) -> List[str]: +def ancestry(path: AnyStr) -> List[AnyStr]: """Return a list consisting of path's parent directory, its grandparent, and so on. For instance: @@ -172,7 +173,7 @@ def ancestry(path: bytes) -> List[str]: The argument should *not* be the result of a call to `syspath`. """ - out = [] + out: List[AnyStr] = [] last_path = None while path: path = os.path.dirname(path) @@ -188,28 +189,28 @@ def ancestry(path: bytes) -> List[str]: def sorted_walk( - path: AnyStr, - ignore: Sequence = (), + path: BytesOrStr, + ignore: Sequence[BytesOrStr] = (), ignore_hidden: bool = False, logger: Optional[Logger] = None, -) -> Generator[Tuple, None, None]: +) -> Generator[Tuple[bytes, Sequence[bytes], Sequence[bytes]], None, None]: """Like `os.walk`, but yields things in case-insensitive sorted, breadth-first order. Directory and file names matching any glob pattern in `ignore` are skipped. If `logger` is provided, then warning messages are logged there when a directory cannot be listed. """ # Make sure the paths aren't Unicode strings. - path = bytestring_path(path) - ignore = [bytestring_path(i) for i in ignore] + bytes_path = bytestring_path(path) + bytes_ignore = [bytestring_path(i) for i in ignore] # Get all the directories and files at this level. try: - contents = os.listdir(syspath(path)) + contents = os.listdir(syspath(bytes_path)) except OSError as exc: if logger: logger.warning( "could not list directory {}: {}".format( - displayable_path(path), exc.strerror + displayable_path(bytes_path), exc.strerror ) ) return @@ -220,11 +221,11 @@ def sorted_walk( # Skip ignored filenames. skip = False - for pat in ignore: + for pat in bytes_ignore: if fnmatch.fnmatch(base, pat): if logger: - logger.debug( - "ignoring {} due to ignore rule {}".format(base, pat) + logger.error( + "ignoring '{}' due to ignore rule '{}'", base, pat ) skip = True break @@ -232,7 +233,7 @@ def sorted_walk( continue # Add to output as either a file or a directory. - cur = os.path.join(path, base) + cur = os.path.join(bytes_path, base) if (ignore_hidden and not hidden.is_hidden(cur)) or not ignore_hidden: if os.path.isdir(syspath(cur)): dirs.append(base) @@ -242,11 +243,11 @@ def sorted_walk( # Sort lists (case-insensitive) and yield the current level. dirs.sort(key=bytes.lower) files.sort(key=bytes.lower) - yield (path, dirs, files) + yield (bytes_path, dirs, files) # Recurse into directories. for base in dirs: - cur = os.path.join(path, base) + cur = os.path.join(bytes_path, base) # yield from sorted_walk(...) yield from sorted_walk(cur, ignore, ignore_hidden, logger) @@ -289,7 +290,7 @@ def fnmatch_all(names: Sequence[bytes], patterns: Sequence[bytes]) -> bool: def prune_dirs( path: str, - root: Optional[Bytes_or_String] = None, + root: Optional[AnyStr] = None, clutter: Sequence[str] = (".DS_Store", "Thumbs.db"), ): """If path is an empty directory, then remove it. Recursively remove @@ -299,33 +300,33 @@ def prune_dirs( emptiness. If root is not provided, then only path may be removed (i.e., no recursive removal). """ - path = normpath(path) - if root is not None: - root = normpath(root) + bytes_path = normpath(path) + root_bytes = normpath(root) if root else None + ancestors = ancestry(bytes_path) - ancestors = ancestry(path) if root is None: # Only remove the top directory. ancestors = [] - elif root in ancestors: - # Only remove directories below the root. - ancestors = ancestors[ancestors.index(root) + 1 :] + elif root_bytes in ancestors: + # Only remove directories below the root_bytes. + ancestors = ancestors[ancestors.index(root_bytes) + 1 :] else: # Remove nothing. return + bytes_clutter = [bytestring_path(c) for c in clutter] + # Traverse upward from path. - ancestors.append(path) + ancestors.append(bytes_path) ancestors.reverse() for directory in ancestors: directory = syspath(directory) if not os.path.exists(directory): # Directory gone already. continue - clutter: List[bytes] = [bytestring_path(c) for c in clutter] match_paths = [bytestring_path(d) for d in os.listdir(directory)] try: - if fnmatch_all(match_paths, clutter): + if fnmatch_all(match_paths, bytes_clutter): # Directory contains only clutter (or nothing). shutil.rmtree(directory) else: @@ -380,7 +381,7 @@ def _fsencoding() -> str: return encoding -def bytestring_path(path: Bytes_or_String) -> bytes: +def bytestring_path(path: BytesOrStr) -> bytes: """Given a path, which is either a bytes or a unicode, returns a str path (ensuring that we never deal with Unicode pathnames). Path should be bytes but has safeguards for strings to be converted. @@ -427,7 +428,7 @@ def displayable_path( return path.decode("utf-8", "ignore") -def syspath(path: Bytes_or_String, prefix: bool = True) -> Bytes_or_String: +def syspath(path: AnyStr, prefix: bool = True) -> AnyStr: """Convert a path for use by the operating system. In particular, paths on Windows must receive a magic prefix and must be converted to Unicode before they are sent to the OS. To disable the magic @@ -438,51 +439,57 @@ def syspath(path: Bytes_or_String, prefix: bool = True) -> Bytes_or_String: if os.path.__name__ != "ntpath": return path - if not isinstance(path, str): + if isinstance(path, bytes): # Beets currently represents Windows paths internally with UTF-8 # arbitrarily. But earlier versions used MBCS because it is # reported as the FS encoding by Windows. Try both. try: - path = path.decode("utf-8") + str_path = path.decode("utf-8") except UnicodeError: # The encoding should always be MBCS, Windows' broken # Unicode representation. assert isinstance(path, bytes) encoding = sys.getfilesystemencoding() or sys.getdefaultencoding() - path = path.decode(encoding, "replace") + str_path = path.decode(encoding, "replace") + else: + str_path = path # Add the magic prefix if it isn't already there. # https://msdn.microsoft.com/en-us/library/windows/desktop/aa365247.aspx - if prefix and not path.startswith(WINDOWS_MAGIC_PREFIX): - if path.startswith("\\\\"): + if prefix and not str_path.startswith(WINDOWS_MAGIC_PREFIX): + if str_path.startswith("\\\\"): # UNC path. Final path should look like \\?\UNC\... - path = "UNC" + path[1:] - path = WINDOWS_MAGIC_PREFIX + path + str_path = "UNC" + str_path[1:] + str_path = WINDOWS_MAGIC_PREFIX + str_path - return path + return str_path if isinstance(path, str) else str_path.encode() -def samefile(p1: bytes, p2: bytes) -> bool: +def samefile(p1: AnyStr, p2: AnyStr) -> bool: """Safer equality for paths.""" if p1 == p2: return True - return shutil._samefile(syspath(p1), syspath(p2)) + with suppress(OSError): + return os.path.samefile(syspath(p1), syspath(p2)) + + return False -def remove(path: Optional[bytes], soft: bool = True): +def remove(path: AnyStr, soft: bool = True): """Remove the file. If `soft`, then no error will be raised if the file does not exist. """ - path = syspath(path) if not path or (soft and not os.path.exists(path)): return + + path = syspath(path) try: os.remove(path) except OSError as exc: raise FilesystemError(exc, "delete", (path,), traceback.format_exc()) -def copy(path: bytes, dest: bytes, replace: bool = False): +def copy(path: AnyStr, dest: AnyStr, replace: bool = False): """Copy a plain file. Permissions are not copied. If `dest` already exists, raises a FilesystemError unless `replace` is True. Has no effect if `path` is the same as `dest`. Paths are translated to @@ -524,30 +531,31 @@ def move(path: bytes, dest: bytes, replace: bool = False): # Copy the file to a temporary destination. basename = os.path.basename(bytestring_path(dest)) dirname = os.path.dirname(bytestring_path(dest)) - tmp = tempfile.NamedTemporaryFile( + tmp = NamedTemporaryFile( suffix=syspath(b".beets", prefix=False), prefix=syspath(b"." + basename, prefix=False), dir=syspath(dirname), delete=False, ) try: - with open(syspath(path), "rb") as f: + with open(syspath(path)) as f: shutil.copyfileobj(f, tmp) finally: tmp.close() # Move the copied file into place. + tmp_filename = tmp.name try: - os.replace(tmp.name, syspath(dest)) - tmp = None + os.replace(tmp_filename, syspath(dest)) + tmp_filename = "" os.remove(syspath(path)) except OSError as exc: raise FilesystemError( exc, "move", (path, dest), traceback.format_exc() ) finally: - if tmp is not None: - os.remove(tmp) + if tmp_filename: + os.remove(tmp_filename) def link(path: bytes, dest: bytes, replace: bool = False): @@ -669,7 +677,7 @@ def unique_path(path: bytes) -> bytes: # Unix. They are forbidden here because they cause problems on Samba # shares, which are sufficiently common as to cause frequent problems. # https://msdn.microsoft.com/en-us/library/windows/desktop/aa365247.aspx -CHAR_REPLACE: List[Tuple[Pattern, str]] = [ +CHAR_REPLACE: List[Tuple[Pattern[str], str]] = [ (re.compile(r"[\\/]"), "_"), # / and \ -- forbidden everywhere. (re.compile(r"^\."), "_"), # Leading dot (hidden files on Unix). (re.compile(r"[\x00-\x1f]"), ""), # Control characters. @@ -681,7 +689,7 @@ CHAR_REPLACE: List[Tuple[Pattern, str]] = [ def sanitize_path( path: str, - replacements: Optional[Sequence[Sequence[Union[Pattern, str]]]] = None, + replacements: Optional[Sequence[Tuple[Pattern[str], str]]] = None, ) -> str: """Takes a path (as a Unicode string) and makes sure that it is legal. Returns a new path. Only works with fragments; won't work @@ -722,11 +730,11 @@ def truncate_path(path: AnyStr, length: int = MAX_FILENAME_LENGTH) -> AnyStr: def _legalize_stage( path: str, - replacements: Optional[Sequence[Sequence[Union[Pattern, str]]]], + replacements: Optional[Sequence[Tuple[Pattern[str], str]]], length: int, extension: str, fragment: bool, -) -> Tuple[Bytes_or_String, bool]: +) -> Tuple[BytesOrStr, bool]: """Perform a single round of path legalization steps (sanitation/replacement, encoding from Unicode to bytes, extension-appending, and truncation). Return the path (Unicode if @@ -752,11 +760,11 @@ def _legalize_stage( def legalize_path( path: str, - replacements: Optional[Sequence[Sequence[Union[Pattern, str]]]], + replacements: Optional[Sequence[Tuple[Pattern[str], str]]], length: int, extension: bytes, fragment: bool, -) -> Tuple[Union[Bytes_or_String, bool]]: +) -> Tuple[Union[BytesOrStr, bool]]: """Given a path-like Unicode string, produce a legal path. Return the path and a flag indicating whether some replacements had to be ignored (see below). @@ -838,7 +846,7 @@ def as_string(value: Any) -> str: return str(value) -def plurality(objs: Sequence[T]) -> T: +def plurality(objs: Sequence[T]) -> Tuple[T, int]: """Given a sequence of hashble objects, returns the object that is most common in the set and the its number of appearance. The sequence must contain at least one object. @@ -884,7 +892,7 @@ def cpu_count() -> int: return 1 -def convert_command_args(args: List[bytes]) -> List[str]: +def convert_command_args(args: List[AnyStr]) -> List[str]: """Convert command arguments, which may either be `bytes` or `str` objects, to uniformly surrogate-escaped strings.""" assert isinstance(args, list) @@ -902,7 +910,7 @@ CommandOutput = namedtuple("CommandOutput", ("stdout", "stderr")) def command_output( - cmd: List[Bytes_or_String], + cmd: List[AnyStr], shell: bool = False, ) -> CommandOutput: """Runs the command and returns its output after it has exited. @@ -922,7 +930,7 @@ def command_output( This replaces `subprocess.check_output` which can have problems if lots of output is sent to stderr. """ - cmd = convert_command_args(cmd) + converted_cmd = convert_command_args(cmd) devnull = subprocess.DEVNULL @@ -938,7 +946,7 @@ def command_output( if proc.returncode: raise subprocess.CalledProcessError( returncode=proc.returncode, - cmd=" ".join(cmd), + cmd=" ".join(converted_cmd), output=stdout + stderr, ) return CommandOutput(stdout, stderr) @@ -1084,7 +1092,7 @@ def asciify_path(path: str, sep_replace: str) -> str: # if this platform has an os.altsep, change it to os.sep. if os.altsep: path = path.replace(os.altsep, os.sep) - path_components: List[Bytes_or_String] = path.split(os.sep) + path_components: List[str] = path.split(os.sep) for index, item in enumerate(path_components): path_components[index] = unidecode(item).replace(os.sep, sep_replace) if os.altsep: @@ -1094,7 +1102,7 @@ def asciify_path(path: str, sep_replace: str) -> str: return os.sep.join(path_components) -def par_map(transform: Callable, items: Iterable): +def par_map(transform: Callable[..., Any], items: Iterable[Any]) -> None: """Apply the function `transform` to all the elements in the iterable `items`, like `map(transform, items)` but with no return value. |