summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorDiogo Sousa <montezdesousa@gmail.com>2023-09-21 10:34:12 +0100
committerDiogo Sousa <montezdesousa@gmail.com>2023-09-21 10:34:12 +0100
commit0fc3185139dd9df9b5717111c15d9616a1fa8d42 (patch)
tree8d0d8c82ad3098a1634d4883854c9177fc3fd8b1
parent4f4e941267d564fa73a6773c132839d6345b944e (diff)
make extension test less verbose
-rw-r--r--openbb_sdk/tests/test_extension_map.py71
1 files changed, 33 insertions, 38 deletions
diff --git a/openbb_sdk/tests/test_extension_map.py b/openbb_sdk/tests/test_extension_map.py
index bf80a89657e..4a135b74a61 100644
--- a/openbb_sdk/tests/test_extension_map.py
+++ b/openbb_sdk/tests/test_extension_map.py
@@ -2,62 +2,57 @@ import json
from pathlib import Path
from typing import Dict
-from poetry.core.constraints.version import Version, parse_constraint
+from poetry.core.constraints.version import Version, VersionConstraint, parse_constraint
from poetry.core.pyproject.toml import PyProjectTOML
-def load_extension_map(file: Path) -> Dict[str, Version]:
+def load_ext_map(file: Path) -> Dict[str, Version]:
"""Load the extension map from extension_map.json."""
- extension_map = {}
+ ext_map = {}
with open(file) as f:
- extension_map_json = json.load(f)
-
- for _, v in extension_map_json.items():
+ ext_map_json = json.load(f)
+ for _, v in ext_map_json.items():
for value in v:
name, version = value.split("@")
- extension_map[name] = Version.parse(version)
-
- return extension_map
+ ext_map[name] = Version.parse(version)
+ return ext_map
-def load_required_extensions(file: Path) -> Dict[str, parse_constraint]:
+def load_req_ext(file: Path) -> Dict[str, VersionConstraint]:
"""Load the required extensions from pyproject.toml."""
pyproject = PyProjectTOML(file)
- dependencies = pyproject.data["tool"]["poetry"]["dependencies"]
- required_extensions = {}
-
- for k, v in dependencies.items():
+ deps = pyproject.data["tool"]["poetry"]["dependencies"]
+ req_ext = {}
+ for k, v in deps.items():
if k.startswith("openbb-") and k not in ("openbb-core", "openbb-provider"):
name = k[7:]
if isinstance(v, str):
- required_extensions[name] = parse_constraint(v)
+ req_ext[name] = parse_constraint(v)
elif isinstance(v, dict) and not v.get("optional", False):
- required_extensions[name] = parse_constraint(v["version"])
-
- return required_extensions
+ req_ext[name] = parse_constraint(v["version"])
+ return req_ext
def test_extension_map():
- """Ensure that only required extensions are built and that the versions are compatible."""
+ """Ensure only required extensions are built and versions respect pyproject.toml"""
this_dir = Path(__file__).parent
- extension_map = load_extension_map(
+ ext_map = load_ext_map(
Path(this_dir, "..", "openbb", "package", "extension_map.json")
)
- required_extensions = load_required_extensions(
- Path(this_dir, "..", "pyproject.toml")
- )
-
- # Check that all required extensions are built
- for name in required_extensions:
- assert (
- name in extension_map
- ), f"Extension '{name}' is required in pyproject.toml but is not built, install it and rebuild or remove it from pyproject.toml" # noqa: E501
-
- # Check that all built extensions are required and that the versions are compatible
- for name, version in extension_map.items():
- assert (
- name in required_extensions
- ), f"'{name}' is not a required extension in pyproject.toml, uninstall it and rebuild or add it to pyproject.toml" # noqa: E501
- assert required_extensions[name].allows(
- version
- ), f"Version '{version}' of extension '{name}' is not compatible with the version '{required_extensions[name]}' required in pyproject.toml" # noqa: E501
+ req_ext = load_req_ext(Path(this_dir, "..", "pyproject.toml"))
+
+ for ext in req_ext:
+ assert ext in ext_map, (
+ f"Extension '{ext}' is required in pyproject.toml but is not built, install"
+ " it and rebuild or remove it mandatory requirements in pyproject.toml"
+ )
+
+ for name, version in ext_map.items():
+ assert name in req_ext, (
+ f"'{name}' is not a required extension in pyproject.toml, uninstall it and"
+ " rebuild or add it to pyproject.toml"
+ )
+ assert req_ext[name].allows(version), (
+ f"Version '{version}' of extension '{name}' is not compatible with the"
+ " version '{req_ext[name]}' required in pyproject.toml"
+ )