summary | refs | log | tree | commit | diff | stats
diff options
context:
space:
mode:
author Šarūnas Nejus <snejus@protonmail.com> 2024-05-01 05:10:03 +0100
committer Šarūnas Nejus <snejus@protonmail.com> 2024-05-04 13:04:17 +0100
commit adac9260b26f9092ac4b4bb8c710e66725dd73d2 (patch)
tree 494b3659afcdbe69d4ead0bf15f3c1fb65eff0d6
parent 329098a0b4f009b91d0a981112ffe0a9f96484f3 (diff)
Fix types in beets.util.__init__
-rw-r--r-- beets/util/__init__.py 134
1 files changed, 71 insertions, 63 deletions
diff --git a/beets/util/__init__.py b/beets/util/__init__.py
index 4335e0f3e..87bb96ed4 100644
--- a/beets/util/__init__.py
+++ b/beets/util/__init__.py
@@ -24,12 +24,13 @@ import shlex
import shutil
import subprocess
import sys
-import tempfile
import traceback
from collections import Counter, namedtuple
+from contextlib import suppress
from enum import Enum
from logging import Logger
from multiprocessing.pool import ThreadPool
+from tempfile import NamedTemporaryFile
from typing import (
Any,
AnyStr,
@@ -54,7 +55,7 @@ from beets.util import hidden
MAX_FILENAME_LENGTH = 200
WINDOWS_MAGIC_PREFIX = "\\\\?\\"
T = TypeVar("T")
-Bytes_or_String: TypeAlias = Union[str, bytes]
+BytesOrStr: TypeAlias = Union[str, bytes]
class HumanReadableException(Exception):
@@ -154,7 +155,7 @@ class MoveOperation(Enum):
REFLINK_AUTO = 5
-def normpath(path: bytes) -> bytes:
+def normpath(path: AnyStr) -> bytes:
"""Provide the canonical form of the path suitable for storing in
the database.
"""
@@ -163,7 +164,7 @@ def normpath(path: bytes) -> bytes:
return bytestring_path(path)
-def ancestry(path: bytes) -> List[str]:
+def ancestry(path: AnyStr) -> List[AnyStr]:
"""Return a list consisting of path's parent directory, its
grandparent, and so on. For instance:
@@ -172,7 +173,7 @@ def ancestry(path: bytes) -> List[str]:
The argument should *not* be the result of a call to `syspath`.
"""
- out = []
+ out: List[AnyStr] = []
last_path = None
while path:
path = os.path.dirname(path)
@@ -188,28 +189,28 @@ def ancestry(path: bytes) -> List[str]:
def sorted_walk(
- path: AnyStr,
- ignore: Sequence = (),
+ path: BytesOrStr,
+ ignore: Sequence[BytesOrStr] = (),
ignore_hidden: bool = False,
logger: Optional[Logger] = None,
-) -> Generator[Tuple, None, None]:
+) -> Generator[Tuple[bytes, Sequence[bytes], Sequence[bytes]], None, None]:
"""Like `os.walk`, but yields things in case-insensitive sorted,
breadth-first order. Directory and file names matching any glob
pattern in `ignore` are skipped. If `logger` is provided, then
warning messages are logged there when a directory cannot be listed.
"""
# Make sure the paths aren't Unicode strings.
- path = bytestring_path(path)
- ignore = [bytestring_path(i) for i in ignore]
+ bytes_path = bytestring_path(path)
+ bytes_ignore = [bytestring_path(i) for i in ignore]
# Get all the directories and files at this level.
try:
- contents = os.listdir(syspath(path))
+ contents = os.listdir(syspath(bytes_path))
except OSError as exc:
if logger:
logger.warning(
"could not list directory {}: {}".format(
- displayable_path(path), exc.strerror
+ displayable_path(bytes_path), exc.strerror
)
)
return
@@ -220,11 +221,11 @@ def sorted_walk(
# Skip ignored filenames.
skip = False
- for pat in ignore:
+ for pat in bytes_ignore:
if fnmatch.fnmatch(base, pat):
if logger:
logger.debug(
- "ignoring {} due to ignore rule {}".format(base, pat)
+ "ignoring '{}' due to ignore rule '{}'", base, pat
)
skip = True
break
@@ -232,7 +233,7 @@ def sorted_walk(
continue
# Add to output as either a file or a directory.
- cur = os.path.join(path, base)
+ cur = os.path.join(bytes_path, base)
if (ignore_hidden and not hidden.is_hidden(cur)) or not ignore_hidden:
if os.path.isdir(syspath(cur)):
dirs.append(base)
@@ -242,11 +243,11 @@ def sorted_walk(
# Sort lists (case-insensitive) and yield the current level.
dirs.sort(key=bytes.lower)
files.sort(key=bytes.lower)
- yield (path, dirs, files)
+ yield (bytes_path, dirs, files)
# Recurse into directories.
for base in dirs:
- cur = os.path.join(path, base)
+ cur = os.path.join(bytes_path, base)
# yield from sorted_walk(...)
yield from sorted_walk(cur, ignore, ignore_hidden, logger)
@@ -289,7 +290,7 @@ def fnmatch_all(names: Sequence[bytes], patterns: Sequence[bytes]) -> bool:
def prune_dirs(
path: str,
- root: Optional[Bytes_or_String] = None,
+ root: Optional[AnyStr] = None,
clutter: Sequence[str] = (".DS_Store", "Thumbs.db"),
):
"""If path is an empty directory, then remove it. Recursively remove
@@ -299,33 +300,33 @@ def prune_dirs(
emptiness. If root is not provided, then only path may be removed
(i.e., no recursive removal).
"""
- path = normpath(path)
- if root is not None:
- root = normpath(root)
+ bytes_path = normpath(path)
+ root_bytes = normpath(root) if root else None
+ ancestors = ancestry(bytes_path)
- ancestors = ancestry(path)
if root is None:
# Only remove the top directory.
ancestors = []
- elif root in ancestors:
- # Only remove directories below the root.
- ancestors = ancestors[ancestors.index(root) + 1 :]
+ elif root_bytes in ancestors:
+ # Only remove directories below the root_bytes.
+ ancestors = ancestors[ancestors.index(root_bytes) + 1 :]
else:
# Remove nothing.
return
+ bytes_clutter = [bytestring_path(c) for c in clutter]
+
# Traverse upward from path.
- ancestors.append(path)
+ ancestors.append(bytes_path)
ancestors.reverse()
for directory in ancestors:
directory = syspath(directory)
if not os.path.exists(directory):
# Directory gone already.
continue
- clutter: List[bytes] = [bytestring_path(c) for c in clutter]
match_paths = [bytestring_path(d) for d in os.listdir(directory)]
try:
- if fnmatch_all(match_paths, clutter):
+ if fnmatch_all(match_paths, bytes_clutter):
# Directory contains only clutter (or nothing).
shutil.rmtree(directory)
else:
@@ -380,7 +381,7 @@ def _fsencoding() -> str:
return encoding
-def bytestring_path(path: Bytes_or_String) -> bytes:
+def bytestring_path(path: BytesOrStr) -> bytes:
"""Given a path, which is either a bytes or a unicode, returns a str
path (ensuring that we never deal with Unicode pathnames). Path should be
bytes but has safeguards for strings to be converted.
@@ -427,7 +428,7 @@ def displayable_path(
return path.decode("utf-8", "ignore")
-def syspath(path: Bytes_or_String, prefix: bool = True) -> Bytes_or_String:
+def syspath(path: AnyStr, prefix: bool = True) -> AnyStr:
"""Convert a path for use by the operating system. In particular,
paths on Windows must receive a magic prefix and must be converted
to Unicode before they are sent to the OS. To disable the magic
@@ -438,51 +439,57 @@ def syspath(path: Bytes_or_String, prefix: bool = True) -> Bytes_or_String:
if os.path.__name__ != "ntpath":
return path
- if not isinstance(path, str):
+ if isinstance(path, bytes):
# Beets currently represents Windows paths internally with UTF-8
# arbitrarily. But earlier versions used MBCS because it is
# reported as the FS encoding by Windows. Try both.
try:
- path = path.decode("utf-8")
+ str_path = path.decode("utf-8")
except UnicodeError:
# The encoding should always be MBCS, Windows' broken
# Unicode representation.
assert isinstance(path, bytes)
encoding = sys.getfilesystemencoding() or sys.getdefaultencoding()
- path = path.decode(encoding, "replace")
+ str_path = path.decode(encoding, "replace")
+ else:
+ str_path = path
# Add the magic prefix if it isn't already there.
# https://msdn.microsoft.com/en-us/library/windows/desktop/aa365247.aspx
- if prefix and not path.startswith(WINDOWS_MAGIC_PREFIX):
- if path.startswith("\\\\"):
+ if prefix and not str_path.startswith(WINDOWS_MAGIC_PREFIX):
+ if str_path.startswith("\\\\"):
# UNC path. Final path should look like \\?\UNC\...
- path = "UNC" + path[1:]
- path = WINDOWS_MAGIC_PREFIX + path
+ str_path = "UNC" + str_path[1:]
+ str_path = WINDOWS_MAGIC_PREFIX + str_path
- return path
+ return str_path if isinstance(path, str) else str_path.encode()
-def samefile(p1: bytes, p2: bytes) -> bool:
+def samefile(p1: AnyStr, p2: AnyStr) -> bool:
"""Safer equality for paths."""
if p1 == p2:
return True
- return shutil._samefile(syspath(p1), syspath(p2))
+ with suppress(OSError):
+ return os.path.samefile(syspath(p1), syspath(p2))
+
+ return False
-def remove(path: Optional[bytes], soft: bool = True):
+def remove(path: AnyStr, soft: bool = True):
"""Remove the file. If `soft`, then no error will be raised if the
file does not exist.
"""
- path = syspath(path)
if not path or (soft and not os.path.exists(path)):
return
+
+ path = syspath(path)
try:
os.remove(path)
except OSError as exc:
raise FilesystemError(exc, "delete", (path,), traceback.format_exc())
-def copy(path: bytes, dest: bytes, replace: bool = False):
+def copy(path: AnyStr, dest: AnyStr, replace: bool = False):
"""Copy a plain file. Permissions are not copied. If `dest` already
exists, raises a FilesystemError unless `replace` is True. Has no
effect if `path` is the same as `dest`. Paths are translated to
@@ -524,30 +531,31 @@ def move(path: bytes, dest: bytes, replace: bool = False):
# Copy the file to a temporary destination.
basename = os.path.basename(bytestring_path(dest))
dirname = os.path.dirname(bytestring_path(dest))
- tmp = tempfile.NamedTemporaryFile(
+ tmp = NamedTemporaryFile(
suffix=syspath(b".beets", prefix=False),
prefix=syspath(b"." + basename, prefix=False),
dir=syspath(dirname),
delete=False,
)
try:
with open(syspath(path), "rb") as f:
shutil.copyfileobj(f, tmp)
finally:
tmp.close()
# Move the copied file into place.
+ tmp_filename = tmp.name
try:
- os.replace(tmp.name, syspath(dest))
- tmp = None
+ os.replace(tmp_filename, syspath(dest))
+ tmp_filename = ""
os.remove(syspath(path))
except OSError as exc:
raise FilesystemError(
exc, "move", (path, dest), traceback.format_exc()
)
finally:
- if tmp is not None:
- os.remove(tmp)
+ if tmp_filename:
+ os.remove(tmp_filename)
def link(path: bytes, dest: bytes, replace: bool = False):
@@ -669,7 +677,7 @@ def unique_path(path: bytes) -> bytes:
# Unix. They are forbidden here because they cause problems on Samba
# shares, which are sufficiently common as to cause frequent problems.
# https://msdn.microsoft.com/en-us/library/windows/desktop/aa365247.aspx
-CHAR_REPLACE: List[Tuple[Pattern, str]] = [
+CHAR_REPLACE: List[Tuple[Pattern[str], str]] = [
(re.compile(r"[\\/]"), "_"), # / and \ -- forbidden everywhere.
(re.compile(r"^\."), "_"), # Leading dot (hidden files on Unix).
(re.compile(r"[\x00-\x1f]"), ""), # Control characters.
@@ -681,7 +689,7 @@ CHAR_REPLACE: List[Tuple[Pattern, str]] = [
def sanitize_path(
path: str,
- replacements: Optional[Sequence[Sequence[Union[Pattern, str]]]] = None,
+ replacements: Optional[Sequence[Tuple[Pattern[str], str]]] = None,
) -> str:
"""Takes a path (as a Unicode string) and makes sure that it is
legal. Returns a new path. Only works with fragments; won't work
@@ -722,11 +730,11 @@ def truncate_path(path: AnyStr, length: int = MAX_FILENAME_LENGTH) -> AnyStr:
def _legalize_stage(
path: str,
- replacements: Optional[Sequence[Sequence[Union[Pattern, str]]]],
+ replacements: Optional[Sequence[Tuple[Pattern[str], str]]],
length: int,
extension: str,
fragment: bool,
-) -> Tuple[Bytes_or_String, bool]:
+) -> Tuple[BytesOrStr, bool]:
"""Perform a single round of path legalization steps
(sanitation/replacement, encoding from Unicode to bytes,
extension-appending, and truncation). Return the path (Unicode if
@@ -752,11 +760,11 @@ def _legalize_stage(
def legalize_path(
path: str,
- replacements: Optional[Sequence[Sequence[Union[Pattern, str]]]],
+ replacements: Optional[Sequence[Tuple[Pattern[str], str]]],
length: int,
extension: bytes,
fragment: bool,
-) -> Tuple[Union[Bytes_or_String, bool]]:
+) -> Tuple[Union[BytesOrStr, bool]]:
"""Given a path-like Unicode string, produce a legal path. Return
the path and a flag indicating whether some replacements had to be
ignored (see below).
@@ -838,7 +846,7 @@ def as_string(value: Any) -> str:
return str(value)
-def plurality(objs: Sequence[T]) -> T:
+def plurality(objs: Sequence[T]) -> Tuple[T, int]:
"""Given a sequence of hashble objects, returns the object that
is most common in the set and the its number of appearance. The
sequence must contain at least one object.
@@ -884,7 +892,7 @@ def cpu_count() -> int:
return 1
-def convert_command_args(args: List[bytes]) -> List[str]:
+def convert_command_args(args: List[AnyStr]) -> List[str]:
"""Convert command arguments, which may either be `bytes` or `str`
objects, to uniformly surrogate-escaped strings."""
assert isinstance(args, list)
@@ -902,7 +910,7 @@ CommandOutput = namedtuple("CommandOutput", ("stdout", "stderr"))
def command_output(
- cmd: List[Bytes_or_String],
+ cmd: List[AnyStr],
shell: bool = False,
) -> CommandOutput:
"""Runs the command and returns its output after it has exited.
@@ -922,7 +930,7 @@ def command_output(
This replaces `subprocess.check_output` which can have problems if lots of
output is sent to stderr.
"""
- cmd = convert_command_args(cmd)
+ converted_cmd = convert_command_args(cmd)
devnull = subprocess.DEVNULL
@@ -938,7 +946,7 @@ def command_output(
if proc.returncode:
raise subprocess.CalledProcessError(
returncode=proc.returncode,
- cmd=" ".join(cmd),
+ cmd=" ".join(converted_cmd),
output=stdout + stderr,
)
return CommandOutput(stdout, stderr)
@@ -1084,7 +1092,7 @@ def asciify_path(path: str, sep_replace: str) -> str:
# if this platform has an os.altsep, change it to os.sep.
if os.altsep:
path = path.replace(os.altsep, os.sep)
- path_components: List[Bytes_or_String] = path.split(os.sep)
+ path_components: List[str] = path.split(os.sep)
for index, item in enumerate(path_components):
path_components[index] = unidecode(item).replace(os.sep, sep_replace)
if os.altsep:
@@ -1094,7 +1102,7 @@ def asciify_path(path: str, sep_replace: str) -> str:
return os.sep.join(path_components)
-def par_map(transform: Callable, items: Iterable):
+def par_map(transform: Callable[..., Any], items: Iterable[Any]) -> None:
"""Apply the function `transform` to all the elements in the
iterable `items`, like `map(transform, items)` but with no return
value.