summaryrefslogtreecommitdiffstats
path: root/openbb_platform/extensions/tests/utils/helpers.py
diff options
context:
space:
mode:
Diffstat (limited to 'openbb_platform/extensions/tests/utils/helpers.py')
-rw-r--r--openbb_platform/extensions/tests/utils/helpers.py35
1 files changed, 29 insertions, 6 deletions
diff --git a/openbb_platform/extensions/tests/utils/helpers.py b/openbb_platform/extensions/tests/utils/helpers.py
index 35f73ca8b69..a00ba7c5583 100644
--- a/openbb_platform/extensions/tests/utils/helpers.py
+++ b/openbb_platform/extensions/tests/utils/helpers.py
@@ -9,10 +9,12 @@ import os
import re
from ast import AsyncFunctionDef, Call, FunctionDef, Name, parse, unparse
from dataclasses import dataclass
-from importlib.metadata import entry_points
+from importlib.metadata import EntryPoint, entry_points
from inspect import getmembers, isfunction
-from typing import Any, Dict, List, Optional, Set, Tuple
+from sys import version_info
+from typing import Any, Dict, List, Optional, Set, Tuple, Union
+from importlib_metadata import EntryPoints
from openbb_core.app.provider_interface import ProviderInterface
pi = ProviderInterface()
@@ -68,12 +70,18 @@ def check_docstring_examples() -> List[str]:
return errors
+def filter_eps(eps: Union[EntryPoints, dict], group: str) -> Tuple[EntryPoint, ...]:
+ if version_info[:2] == (3, 12):
+ return eps.select(group=group) or () # type: ignore[union-attr]
+ return eps.get(group, ()) # type: ignore[union-attr]
+
+
def list_openbb_extensions() -> Tuple[Set[str], Set[str], Set[str]]:
"""List installed openbb extensions and providers.
Returns
-------
- Tuple[Set[str], Set[str]]
+ Tuple[Set[str], Set[str], Set[str]]
First element: set of installed core extensions.
Second element: set of installed provider extensions.
Third element: set of installed obbject extensions.
@@ -82,15 +90,30 @@ def list_openbb_extensions() -> Tuple[Set[str], Set[str], Set[str]]:
core_extensions = set()
provider_extensions = set()
obbject_extensions = set()
+
entry_points_dict = entry_points()
- for entry_point in entry_points_dict.get("openbb_core_extension", []):
+ # Compatibility for different Python versions
+ if hasattr(entry_points_dict, "select"): # Python 3.12+
+ core_entry_points = entry_points_dict.select(group="openbb_core_extension")
+ provider_entry_points = entry_points_dict.select(
+ group="openbb_provider_extension"
+ )
+ obbject_entry_points = entry_points_dict.select(
+ group="openbb_obbject_extension"
+ )
+ else:
+ core_entry_points = entry_points_dict.get("openbb_core_extension", [])
+ provider_entry_points = entry_points_dict.get("openbb_provider_extension", [])
+ obbject_entry_points = entry_points_dict.get("openbb_obbject_extension", [])
+
+ for entry_point in core_entry_points:
core_extensions.add(f"{entry_point.name}")
- for entry_point in entry_points_dict.get("openbb_provider_extension", []):
+ for entry_point in provider_entry_points:
provider_extensions.add(f"{entry_point.name}")
- for entry_point in entry_points_dict.get("openbb_obbject_extension", []):
+ for entry_point in obbject_entry_points:
obbject_extensions.add(f"{entry_point.name}")
return core_extensions, provider_extensions, obbject_extensions