diff options
author | Henrique Joaquim <h.joaquim@campus.fct.unl.pt> | 2024-01-10 15:23:14 +0000 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-01-10 15:23:14 +0000 |
commit | c3abc418a80c0d8071503158c5d70c1228915dbe (patch) | |
tree | 8388c34009fa81e42229d6bcd44c04c01e05b3c2 | |
parent | ee5388a81deed0d2ed608b57102501c334055d65 (diff) |
Improving Platform's import time (#5894)
* missing decorator
* missing optional statements
* us-gov to optional and right position
* removing redundant call
* lazy load mappings
* docstrings
* introducing the extension loader class
* misleading docstring
* using the extension loader on the router
* using the extension loader on the credentials
* docstrings
* lazy load of entry points
* checkout dev_install
* using the extension loader to get the entry points instead of recalling the function
* improved auth service and charting service
* using the extension loader on the registry
* typos
* adding properties to extension loader for easy access
* using the extension loader on the package builder
* avoiding circular imports by using forward references
* fix charting service tests
* test for the extension loader
* Update openbb_platform/core/openbb_core/app/model/credentials.py
Co-authored-by: montezdesousa <79287829+montezdesousa@users.noreply.github.com>
* Update openbb_platform/core/openbb_core/app/router.py
Co-authored-by: montezdesousa <79287829+montezdesousa@users.noreply.github.com>
* Update openbb_platform/core/openbb_core/provider/registry.py
Co-authored-by: montezdesousa <79287829+montezdesousa@users.noreply.github.com>
* adjustments
---------
Co-authored-by: montezdesousa <79287829+montezdesousa@users.noreply.github.com>
13 files changed, 525 insertions, 132 deletions
diff --git a/openbb_platform/core/openbb_core/app/charting_service.py b/openbb_platform/core/openbb_core/app/charting_service.py index 13a8d80999e..1c8a3373840 100644 --- a/openbb_platform/core/openbb_core/app/charting_service.py +++ b/openbb_platform/core/openbb_core/app/charting_service.py @@ -1,9 +1,9 @@ +"""Charting service.""" from importlib import import_module from inspect import getmembers, getsource, isfunction from typing import Callable, List, Optional, Tuple, TypeVar -from importlib_metadata import entry_points - +from openbb_core.app.extension_loader import ExtensionLoader from openbb_core.app.model.abstract.singleton import SingletonMeta from openbb_core.app.model.charts.chart import Chart, ChartFormat from openbb_core.app.model.charts.charting_settings import ChartingSettings @@ -13,18 +13,18 @@ from openbb_core.env import Env T = TypeVar("T") -POETRY_PLUGIN = "openbb_core_extension" # this is needed because static assets and api endpoints are built before any user is instantiated EXTENSION_NAME = Env().CHARTING_EXTENSION class ChartingServiceError(Exception): - pass + """Charting service error.""" class ChartingService(metaclass=SingletonMeta): """ - Charting service class. + Charting Service. + It is responsible for retrieving and executing the charting function, corresponding to a given route, from the user's preferred charting extension. @@ -53,6 +53,7 @@ class ChartingService(metaclass=SingletonMeta): user_settings: Optional[UserSettings] = None, system_settings: Optional[SystemSettings] = None, ) -> None: + """Initializes ChartingService.""" # Although the __init__ method states that both the user_settings and the system_settings # are optional, they are actually required for the first initialization of the ChartingService. # This is because the ChartingService is a singleton and it is initialized only once. @@ -73,6 +74,7 @@ class ChartingService(metaclass=SingletonMeta): @property def charting_settings(self) -> ChartingSettings: + """Gets charting settings.""" return self._charting_settings @charting_settings.setter @@ -87,9 +89,7 @@ class ChartingService(metaclass=SingletonMeta): def _check_and_get_charting_extension_name( user_preferences_charting_extension: str, ): - """ - Checks if the charting extension defined on user preferences is the same as the one defined in the env file. - """ + """Checks if the charting extension defined on user preferences is the same as the one defined in the env file.""" if user_preferences_charting_extension != EXTENSION_NAME: raise ChartingServiceError( f"The charting extension defined on user preferences must be the same as the one defined in the env file." @@ -98,49 +98,16 @@ class ChartingService(metaclass=SingletonMeta): return user_preferences_charting_extension @staticmethod - def _check_charting_extension_installed( - charting_extension: str, plugin: str = POETRY_PLUGIN - ) -> bool: - """ - Checks if charting extension is installed. - Given a charting extension name, it checks if it is installed under the given plugin. - - Parameters - ---------- - charting_extension : str - Charting extension name. - plugin : Optional[str] - Plugin name. - Returns - ------- - bool - Either charting extension is installed or not. - """ - extensions = [ext.name for ext in entry_points(group=plugin)] - - return charting_extension in extensions - - @staticmethod - def _get_extension_router( - extension_name: str, plugin: Optional[str] = POETRY_PLUGIN - ): - """ - Get the module of the given extension. - """ - entry_points_ = entry_points(group=plugin) - entry_point = next( - (ep for ep in entry_points_ if ep.name == extension_name), None - ) - if entry_point is None: - raise ChartingServiceError( - f"Extension '{extension_name}' is not installed." - ) - return import_module(entry_point.module) + def _check_charting_extension_installed(ext_name: str) -> bool: + """Checks if a given extension is installed.""" + extension = ExtensionLoader().get_core_entry_point(ext_name) or False + return extension and ext_name == extension.name # type: ignore @staticmethod def _handle_backend(charting_extension: str, charting_settings: ChartingSettings): """ Handles the backend of the given charting extension. + This function that the module expose in its root (__init__.py) the following functions: - `create_backend(charting_settings: ChartingSettings)` - `get_backend()` @@ -162,9 +129,20 @@ class ChartingService(metaclass=SingletonMeta): get_backend_func().start(debug=charting_settings.debug_mode) @classmethod + def _get_extension_router(cls, extension_name: str): + """Get the module of the given extension.""" + extension = ExtensionLoader().get_core_entry_point(extension_name) + if not extension or extension_name != extension.name: + raise ChartingServiceError( + f"Extension '{extension_name}' is not installed." + ) + return import_module(extension.module) + + @classmethod def _get_chart_format(cls, extension_name: str) -> ChartFormat: """ Given an extension name, it returns the chart format. + The module must contain the `CHART_FORMAT` attribute. """ module = cls._get_extension_router(extension_name) @@ -174,6 +152,7 @@ class ChartingService(metaclass=SingletonMeta): def _get_chart_function(cls, extension_name: str, route: str) -> Callable: """ Given an extension name and a route, it returns the chart function. + The module must contain the given route. """ adjusted_route = route.replace("/", "_")[1:] @@ -184,9 +163,7 @@ class ChartingService(metaclass=SingletonMeta): def get_implemented_charting_functions( cls, extension_name: str = EXTENSION_NAME ) -> List[str]: - """ - Given an extension name, it returns the implemented charting functions from its router. - """ + """Given an extension name, it returns the implemented charting functions from its router.""" implemented_functions = [] try: @@ -226,7 +203,6 @@ class ChartingService(metaclass=SingletonMeta): Exception If the charting extension module does not contain the `to_chart` function. """ - if not self._charting_extension_installed: raise ChartingServiceError( f"Charting extension `{self._charting_extension}` is not installed" @@ -258,6 +234,8 @@ class ChartingService(metaclass=SingletonMeta): **kwargs, ) -> Chart: """ + Given a route and an obbject item, it returns the chart object. + If the charting extension is not installed, an error is raised. Otherwise, a charting function will be retrieved and executed from the user's preferred charting extension. This function assumes that, in order to successfully retrieve the charting function, @@ -276,6 +254,7 @@ class ChartingService(metaclass=SingletonMeta): Route name, example: `/stocks/load`. obbject_item Command output item. + Returns ------- Chart diff --git a/openbb_platform/core/openbb_core/app/extension_loader.py b/openbb_platform/core/openbb_core/app/extension_loader.py new file mode 100644 index 00000000000..4381714f9aa --- /dev/null +++ b/openbb_platform/core/openbb_core/app/extension_loader.py @@ -0,0 +1,157 @@ +"""Extension Loader.""" + +from enum import Enum +from functools import lru_cache +from typing import TYPE_CHECKING, Any, Dict, Optional + +from importlib_metadata import EntryPoint, EntryPoints, entry_points + +from openbb_core.app.model.abstract.singleton import SingletonMeta +from openbb_core.app.model.extension import Extension + +if TYPE_CHECKING: + from openbb_core.app.router import Router + from openbb_core.provider.abstract.provider import Provider + + +class OpenBBGroups(Enum): + """OpenBB Extension Groups.""" + + core = "openbb_core_extension" + provider = "openbb_provider_extension" + obbject = "openbb_obbject_extension" + + +class ExtensionLoader(metaclass=SingletonMeta): + """Extension loader class.""" + + def __init__( + self, + ) -> None: + """Initialize the extension loader.""" + self._obbject_entry_points: EntryPoints = self._sorted_entry_points( + group=OpenBBGroups.obbject.value + ) + self._core_entry_points: EntryPoints = self._sorted_entry_points( + group=OpenBBGroups.core.value + ) + self._provider_entry_points: EntryPoints = self._sorted_entry_points( + group=OpenBBGroups.provider.value + ) + self._obbject_objects: Dict[str, Extension] = {} + self._core_objects: Dict[str, Router] = {} + self._provider_objects: Dict[str, Provider] = {} + + @property + def obbject_entry_points(self) -> EntryPoints: + """Return the obbject entry points.""" + return self._obbject_entry_points + + @property + def core_entry_points(self) -> EntryPoints: + """Return the core entry points.""" + return self._core_entry_points + + @property + def provider_entry_points(self) -> EntryPoints: + """Return the provider entry points.""" + return self._provider_entry_points + + @staticmethod + def _get_entry_point( + entry_points_: EntryPoints, ext_name: str + ) -> Optional[EntryPoint]: + """Given an extension name and a list of entry points, return the corresponding entry point.""" + return next((ep for ep in entry_points_ if ep.name == ext_name), None) + + def get_obbject_entry_point(self, ext_name: str) -> Optional[EntryPoint]: + """Given an extension name, return the corresponding entry point.""" + return self._get_entry_point(self._obbject_entry_points, ext_name) + + def get_core_entry_point(self, ext_name: str) -> Optional[EntryPoint]: + """Given an extension name, return the corresponding entry point.""" + return self._get_entry_point(self._core_entry_points, ext_name) + + def get_provider_entry_point(self, ext_name: str) -> Optional[EntryPoint]: + """Given an extension name, return the corresponding entry point.""" + return self._get_entry_point(self._provider_entry_points, ext_name) + + @property + @lru_cache + def obbject_objects(self) -> Dict[str, Extension]: + """Return a dict of obbject extension objects.""" + self._obbject_objects = self._load_entry_points( + self._obbject_entry_points, OpenBBGroups.obbject + ) + return self._obbject_objects + + @property + @lru_cache + def core_objects(self) -> Dict[str, "Router"]: + """Return a dict of core extension objects.""" + self._core_objects = self._load_entry_points( + self._core_entry_points, OpenBBGroups.core + ) + return self._core_objects + + @property + @lru_cache + def provider_objects(self) -> Dict[str, "Provider"]: + """Return a dict of provider extension objects.""" + self._provider_objects = self._load_entry_points( + self._provider_entry_points, OpenBBGroups.provider + ) + return self._provider_objects + + @staticmethod + def _sorted_entry_points(group: str) -> EntryPoints: + """Return a sorted dictionary of entry points.""" + return sorted(entry_points(group=group)) # type: ignore + + def _load_entry_points( + self, entry_points_: EntryPoints, group: OpenBBGroups + ) -> Dict[str, Any]: + """Return a dict of objects matching the entry points.""" + + def load_obbject(eps: EntryPoints) -> Dict[str, Extension]: + """ + Return a dictionary of obbject objects. + + Keys are entry point names and values are instances of the Extension class. + """ + return { + ep.name: entry + for ep in eps + if isinstance((entry := ep.load()), Extension) + } + + def load_core(eps: EntryPoints) -> Dict[str, "Router"]: + """Return a dictionary of core objects.""" + # pylint: disable=import-outside-toplevel + from openbb_core.app.router import Router + + return { + ep.name: entry for ep in eps if isinstance((entry := ep.load()), Router) + } + + def load_provider(eps: EntryPoints) -> Dict[str, "Provider"]: + """ + Return a dictionary of provider objects. + + Keys are entry point names and values are instances of the Provider class. + """ + # pylint: disable=import-outside-toplevel + from openbb_core.provider.abstract.provider import Provider + + return { + ep.name: entry + for ep in eps + if isinstance((entry := ep.load()), Provider) + } + + func = { + OpenBBGroups.obbject: load_obbject, + OpenBBGroups.core: load_core, + OpenBBGroups.provider: load_provider, + } + return func[group](entry_points_) # type: ignore diff --git a/openbb_platform/core/openbb_core/app/model/credentials.py b/openbb_platform/core/openbb_core/app/model/credentials.py index 9e1de94b860..4841a69a21e 100644 --- a/openbb_platform/core/openbb_core/app/model/credentials.py +++ b/openbb_platform/core/openbb_core/app/model/credentials.py @@ -1,8 +1,8 @@ +"""Credentials model and its utilities.""" import traceback import warnings from typing import Dict, Optional, Set, Tuple -from importlib_metadata import entry_points from pydantic import ( BaseModel, ConfigDict, @@ -13,8 +13,8 @@ from pydantic import ( from pydantic.functional_serializers import PlainSerializer from typing_extensions import Annotated +from openbb_core.app.extension_loader import ExtensionLoader from openbb_core.app.model.abstract.warning import OpenBBWarning -from openbb_core.app.model.extension import Extension from openbb_core.app.provider_interface import ProviderInterface from openbb_core.env import Env @@ -33,7 +33,7 @@ OBBSecretStr = Annotated[ class CredentialsLoader: - """Here we create the Credentials model""" + """Here we create the Credentials model.""" credentials: Dict[str, Set[str]] = {} @@ -41,7 +41,7 @@ class CredentialsLoader: def prepare( credentials: Dict[str, Set[str]], ) -> Dict[str, Tuple[object, None]]: - """Prepare credentials map to be used in the Credentials model""" + """Prepare credentials map to be used in the Credentials model.""" formatted: Dict[str, Tuple[object, None]] = {} for origin, creds in credentials.items(): for c in creds: @@ -58,16 +58,14 @@ class CredentialsLoader: return formatted def from_obbject(self) -> None: - """Load credentials from OBBject extensions""" + """Load credentials from OBBject extensions.""" self.credentials["obbject"] = set() - for entry_point in sorted(entry_points(group="openbb_obbject_extension")): + for name, entry in ExtensionLoader().obbject_objects.items(): try: - entry = entry_point.load() - if isinstance(entry, Extension): - for c in entry.credentials: - self.credentials["obbject"].add(c) + for c in entry.credentials: + self.credentials["obbject"].add(c) except Exception as e: - msg = f"Error loading extension: {entry_point.name}\n" + msg = f"Error loading extension: {name}\n" if Env().DEBUG_MODE: traceback.print_exception(type(e), e, e.__traceback__) raise LoadingError(msg + f"\033[91m{e}\033[0m") from e @@ -77,13 +75,13 @@ class CredentialsLoader: ) def from_providers(self) -> None: - """Load credentials from providers""" + """Load credentials from providers.""" self.credentials["providers"] = set() for c in ProviderInterface().credentials: self.credentials["providers"].add(c) def load(self) -> BaseModel: - """Load credentials from providers""" + """Load credentials from providers.""" # We load providers first to give them priority choosing credential names self.from_providers() self.from_obbject() @@ -98,10 +96,10 @@ _Credentials = CredentialsLoader().load() class Credentials(_Credentials): # type: ignore - """Credentials model used to store provider credentials""" + """Credentials model used to store provider credentials.""" def __repr__(self) -> str: - """String representation of the credentials""" + """String representation of the credentials.""" return ( self.__class__.__name__ + "\n\n" @@ -109,7 +107,7 @@ class Credentials(_Credentials): # type: ignore ) def show(self): - """Unmask credentials and print them""" + """Unmask credentials and print them.""" print( # noqa: T201 self.__class__.__name__ + "\n\n" diff --git a/openbb_platform/core/openbb_core/app/model/extension.py b/openbb_platform/core/openbb_core/app/model/extension.py index d5acda39bfe..73f9484ce87 100644 --- a/openbb_platform/core/openbb_core/app/model/extension.py +++ b/openbb_platform/core/openbb_core/app/model/extension.py @@ -1,11 +1,13 @@ +"""Extension class for OBBject extensions.""" import warnings from typing import Callable, List, Optional class Extension: - """Serves as extension entry point and must be created by each extension package. + """ + Serves as OBBject extension entry point and must be created by each extension package. - See README.md for more information on how to create an extension. + See https://docs.openbb.co/platform/development/developer-guidelines/obbject_extensions. """ def __init__( @@ -37,7 +39,7 @@ class Extension: @staticmethod def register_accessor(name, cls) -> Callable: - """Register a custom accessor""" + """Register a custom accessor.""" def decorator(accessor): if hasattr(cls, name): @@ -56,13 +58,15 @@ class Extension: class CachedAccessor: - """CachedAccessor""" + """CachedAccessor.""" def __init__(self, name: str, accessor) -> None: + """Initialize the cached accessor.""" self._name = name self._accessor = accessor def __get__(self, obj, cls): + """Get the cached accessor.""" if obj is None: return self._accessor accessor_obj = self._accessor(obj) diff --git a/openbb_platform/core/openbb_core/app/router.py b/openbb_platform/core/openbb_core/app/router.py index f9a116e9950..9f240f2b513 100644 --- a/openbb_platform/core/openbb_core/app/router.py +++ b/openbb_platform/core/openbb_core/app/router.py @@ -1,3 +1,4 @@ +"""OpenBB Router.""" import traceback import warnings from functools import lru_cache, partial @@ -18,11 +19,11 @@ from typing import ( ) from fastapi import APIRouter, Depends -from importlib_metadata import entry_points from pydantic import BaseModel from pydantic.v1.validators import find_validators from typing_extensions import Annotated, ParamSpec, _AnnotatedAlias +from openbb_core.app.extension_loader import ExtensionLoader from openbb_core.app.model.abstract.warning import OpenBBWarning from openbb_core.app.model.command_context import CommandContext from openbb_core.app.model.obbject import OBBject @@ -45,12 +46,14 @@ class OpenBBErrorResponse(BaseModel): class CommandValidator: + """Validate Command.""" + @staticmethod def is_standard_pydantic_type(value_type: Type) -> bool: """Check whether or not a parameter type is a valid Pydantic Standard Type.""" try: func = next( - find_validators(value_type, config=dict(arbitrary_types_allowed=True)) + find_validators(value_type, config=dict(arbitrary_types_allowed=True)) # type: ignore ) valid_type = func.__name__ != "arbitrary_type_validator" except Exception: @@ -60,6 +63,7 @@ class CommandValidator: @staticmethod def is_valid_pydantic_model_type(model_type: Type) -> bool: + """Check whether or not a parameter type is a valid Pydantic Model Type.""" if not isclass(model_type): return False @@ -73,12 +77,14 @@ class CommandValidator: @classmethod def is_serializable_value_type(cls, value_type: Type) -> bool: + """Check whether or not a parameter type is a valid serializable type.""" return cls.is_standard_pydantic_type( value_type=value_type ) or cls.is_valid_pydantic_model_type(model_type=value_type) @staticmethod def is_annotated_dc(annotation) -> bool: + """Check whether or not a parameter type is an annotated dataclass.""" return isinstance(annotation, _AnnotatedAlias) and hasattr( annotation.__args__[0], "__dataclass_fields__" ) @@ -91,6 +97,7 @@ class CommandValidator: func: Callable, sig: Signature, ): + """Check whether or not a parameter is reserved.""" if name in parameter_map: annotation = getattr(parameter_map[name], "annotation", None) if annotation is not None and CommandValidator.is_annotated_dc(annotation): @@ -105,6 +112,7 @@ class CommandValidator: @classmethod def check_parameters(cls, func: Callable): + """Check whether or not a parameter is a valid.""" sig = signature(func) parameter_map = sig.parameters @@ -129,6 +137,7 @@ class CommandValidator: @classmethod def check_return(cls, func: Callable): + """Check whether or not a return type is a valid.""" sig = signature(func) return_type = sig.return_annotation @@ -162,6 +171,7 @@ class CommandValidator: @classmethod def check(cls, func: Callable, model: str = ""): + """Check whether or not a function is valid.""" if model and not iscoroutinefunction(func): raise TypeError( f"Invalid function: {func.__module__}.{func.__name__}\n" @@ -183,14 +193,18 @@ class CommandValidator: class Router: + """OpenBB Router Class.""" + @property def api_router(self) -> APIRouter: + """API Router.""" return self._api_router def __init__( self, prefix: str = "", ) -> None: + """Initialize Router.""" self._api_router = APIRouter( prefix=prefix, responses={404: {"description": "Not found"}}, @@ -209,6 +223,7 @@ class Router: func: Optional[Callable[P, OBBject]] = None, **kwargs, ) -> Optional[Callable]: + """Command decorator for routes.""" if func is None: return lambda f: self.command(f, **kwargs) @@ -262,6 +277,7 @@ class Router: router: "Router", prefix: str = "", ): + """Include router.""" tags = [prefix[1:]] if prefix else None self._api_router.include_router( router=router.api_router, prefix=prefix, tags=tags # type: ignore @@ -269,12 +285,13 @@ class Router: class SignatureInspector: + """Inspect function signature.""" + @classmethod def complete_signature( cls, func: Callable[P, OBBject], model: str ) -> Optional[Callable[P, OBBject]]: """Complete function signature.""" - if isclass(return_type := func.__annotations__["return"]) and not issubclass( return_type, OBBject ): @@ -343,35 +360,41 @@ class SignatureInspector: def inject_return_type( func: Callable[P, OBBject], inner_type: Any, outer_type: Any ) -> Callable[P, OBBject]: - """Inject full return model into the function. - Also updates __name__ and __doc__ for API schemas.""" + """ + Inject full return model into the function. + + Also updates __name__ and __doc__ for API schemas. + """ ReturnModel = inner_type - if get_origin(outer_type) == l |