diff options
Diffstat (limited to 'openbb_platform/extensions/tests/utils/helpers.py')
-rw-r--r-- | openbb_platform/extensions/tests/utils/helpers.py | 35 |
1 files changed, 29 insertions, 6 deletions
diff --git a/openbb_platform/extensions/tests/utils/helpers.py b/openbb_platform/extensions/tests/utils/helpers.py index 35f73ca8b69..a00ba7c5583 100644 --- a/openbb_platform/extensions/tests/utils/helpers.py +++ b/openbb_platform/extensions/tests/utils/helpers.py @@ -9,10 +9,12 @@ import os import re from ast import AsyncFunctionDef, Call, FunctionDef, Name, parse, unparse from dataclasses import dataclass -from importlib.metadata import entry_points +from importlib.metadata import EntryPoint, entry_points from inspect import getmembers, isfunction -from typing import Any, Dict, List, Optional, Set, Tuple +from sys import version_info +from typing import Any, Dict, List, Optional, Set, Tuple, Union +from importlib_metadata import EntryPoints from openbb_core.app.provider_interface import ProviderInterface pi = ProviderInterface() @@ -68,12 +70,18 @@ def check_docstring_examples() -> List[str]: return errors +def filter_eps(eps: Union[EntryPoints, dict], group: str) -> Tuple[EntryPoint, ...]: + if version_info[:2] == (3, 12): + return eps.select(group=group) or () # type: ignore[union-attr] + return eps.get(group, ()) # type: ignore[union-attr] + + def list_openbb_extensions() -> Tuple[Set[str], Set[str], Set[str]]: """List installed openbb extensions and providers. Returns ------- - Tuple[Set[str], Set[str]] + Tuple[Set[str], Set[str], Set[str]] First element: set of installed core extensions. Second element: set of installed provider extensions. Third element: set of installed obbject extensions. @@ -82,15 +90,30 @@ def list_openbb_extensions() -> Tuple[Set[str], Set[str], Set[str]]: core_extensions = set() provider_extensions = set() obbject_extensions = set() + entry_points_dict = entry_points() - for entry_point in entry_points_dict.get("openbb_core_extension", []): + # Compatibility for different Python versions + if hasattr(entry_points_dict, "select"): # Python 3.12+ + core_entry_points = entry_points_dict.select(group="openbb_core_extension") + provider_entry_points = entry_points_dict.select( + group="openbb_provider_extension" + ) + obbject_entry_points = entry_points_dict.select( + group="openbb_obbject_extension" + ) + else: + core_entry_points = entry_points_dict.get("openbb_core_extension", []) + provider_entry_points = entry_points_dict.get("openbb_provider_extension", []) + obbject_entry_points = entry_points_dict.get("openbb_obbject_extension", []) + + for entry_point in core_entry_points: core_extensions.add(f"{entry_point.name}") - for entry_point in entry_points_dict.get("openbb_provider_extension", []): + for entry_point in provider_entry_points: provider_extensions.add(f"{entry_point.name}") - for entry_point in entry_points_dict.get("openbb_obbject_extension", []): + for entry_point in obbject_entry_points: obbject_extensions.add(f"{entry_point.name}") return core_extensions, provider_extensions, obbject_extensions |