diff options
Diffstat (limited to 'openbb_platform/core/openbb_core/app/router.py')
-rw-r--r-- | openbb_platform/core/openbb_core/app/router.py | 95 |
1 files changed, 68 insertions, 27 deletions
diff --git a/openbb_platform/core/openbb_core/app/router.py b/openbb_platform/core/openbb_core/app/router.py index f9a116e9950..9f240f2b513 100644 --- a/openbb_platform/core/openbb_core/app/router.py +++ b/openbb_platform/core/openbb_core/app/router.py @@ -1,3 +1,4 @@ +"""OpenBB Router.""" import traceback import warnings from functools import lru_cache, partial @@ -18,11 +19,11 @@ from typing import ( ) from fastapi import APIRouter, Depends -from importlib_metadata import entry_points from pydantic import BaseModel from pydantic.v1.validators import find_validators from typing_extensions import Annotated, ParamSpec, _AnnotatedAlias +from openbb_core.app.extension_loader import ExtensionLoader from openbb_core.app.model.abstract.warning import OpenBBWarning from openbb_core.app.model.command_context import CommandContext from openbb_core.app.model.obbject import OBBject @@ -45,12 +46,14 @@ class OpenBBErrorResponse(BaseModel): class CommandValidator: + """Validate Command.""" + @staticmethod def is_standard_pydantic_type(value_type: Type) -> bool: """Check whether or not a parameter type is a valid Pydantic Standard Type.""" try: func = next( - find_validators(value_type, config=dict(arbitrary_types_allowed=True)) + find_validators(value_type, config=dict(arbitrary_types_allowed=True)) # type: ignore ) valid_type = func.__name__ != "arbitrary_type_validator" except Exception: @@ -60,6 +63,7 @@ class CommandValidator: @staticmethod def is_valid_pydantic_model_type(model_type: Type) -> bool: + """Check whether or not a parameter type is a valid Pydantic Model Type.""" if not isclass(model_type): return False @@ -73,12 +77,14 @@ class CommandValidator: @classmethod def is_serializable_value_type(cls, value_type: Type) -> bool: + """Check whether or not a parameter type is a valid serializable type.""" return cls.is_standard_pydantic_type( value_type=value_type ) or cls.is_valid_pydantic_model_type(model_type=value_type) @staticmethod def is_annotated_dc(annotation) -> bool: + """Check whether or not a parameter type is an annotated dataclass.""" return isinstance(annotation, _AnnotatedAlias) and hasattr( annotation.__args__[0], "__dataclass_fields__" ) @@ -91,6 +97,7 @@ class CommandValidator: func: Callable, sig: Signature, ): + """Check whether or not a parameter is reserved.""" if name in parameter_map: annotation = getattr(parameter_map[name], "annotation", None) if annotation is not None and CommandValidator.is_annotated_dc(annotation): @@ -105,6 +112,7 @@ class CommandValidator: @classmethod def check_parameters(cls, func: Callable): + """Check whether or not a parameter is a valid.""" sig = signature(func) parameter_map = sig.parameters @@ -129,6 +137,7 @@ class CommandValidator: @classmethod def check_return(cls, func: Callable): + """Check whether or not a return type is a valid.""" sig = signature(func) return_type = sig.return_annotation @@ -162,6 +171,7 @@ class CommandValidator: @classmethod def check(cls, func: Callable, model: str = ""): + """Check whether or not a function is valid.""" if model and not iscoroutinefunction(func): raise TypeError( f"Invalid function: {func.__module__}.{func.__name__}\n" @@ -183,14 +193,18 @@ class CommandValidator: class Router: + """OpenBB Router Class.""" + @property def api_router(self) -> APIRouter: + """API Router.""" return self._api_router def __init__( self, prefix: str = "", ) -> None: + """Initialize Router.""" self._api_router = APIRouter( prefix=prefix, responses={404: {"description": "Not found"}}, @@ -209,6 +223,7 @@ class Router: func: Optional[Callable[P, OBBject]] = None, **kwargs, ) -> Optional[Callable]: + """Command decorator for routes.""" if func is None: return lambda f: self.command(f, **kwargs) @@ -262,6 +277,7 @@ class Router: router: "Router", prefix: str = "", ): + """Include router.""" tags = [prefix[1:]] if prefix else None self._api_router.include_router( router=router.api_router, prefix=prefix, tags=tags # type: ignore @@ -269,12 +285,13 @@ class Router: class SignatureInspector: + """Inspect function signature.""" + @classmethod def complete_signature( cls, func: Callable[P, OBBject], model: str ) -> Optional[Callable[P, OBBject]]: """Complete function signature.""" - if isclass(return_type := func.__annotations__["return"]) and not issubclass( return_type, OBBject ): @@ -343,35 +360,41 @@ class SignatureInspector: def inject_return_type( func: Callable[P, OBBject], inner_type: Any, outer_type: Any ) -> Callable[P, OBBject]: - """Inject full return model into the function. - Also updates __name__ and __doc__ for API schemas.""" + """ + Inject full return model into the function. + + Also updates __name__ and __doc__ for API schemas. + """ ReturnModel = inner_type - if get_origin(outer_type) == list: + outer_type_origin = get_origin(outer_type) + + if outer_type_origin == list: ReturnModel = List[inner_type] # type: ignore - elif get_origin(outer_type) == Union: + elif outer_type_origin == Union: ReturnModel = Union[List[inner_type], inner_type] # type: ignore return_type = OBBject[ReturnModel] # type: ignore return_type.__name__ = f"OBBject[{inner_type.__name__}]" return_type.__doc__ = f"OBBject with results of type '{inner_type.__name__}'." - return_type.model_rebuild(force=True) func.__annotations__["return"] = return_type return func @staticmethod def polish_return_schema(func: Callable[P, OBBject]) -> Callable[P, OBBject]: - """Polish API schemas by filling __doc__ and __name__""" + """Polish API schemas by filling `__doc__` and `__name__`.""" return_type = func.__annotations__["return"] is_list = False results_type = get_type_hints(return_type)["results"] + results_type_args = get_args(results_type) if not isinstance(results_type, type(None)): - results_type = get_args(results_type)[0] + results_type = results_type_args[0] is_list = get_origin(results_type) == list - args = get_args(results_type) - inner_type = args[0] if is_list and args else results_type + inner_type = ( + results_type_args[0] if is_list and results_type_args else results_type + ) inner_type_name = getattr(inner_type, "__name__", inner_type) func.__annotations__["return"].__doc__ = "OBBject" @@ -417,7 +440,7 @@ class SignatureInspector: @staticmethod def get_operation_id(func: Callable) -> str: - """Get operation id""" + """Get operation id.""" operation_id = [ t.replace("_router", "").replace("openbb_", "") for t in func.__module__.split(".") + [func.__name__] @@ -432,38 +455,51 @@ class CommandMap: def __init__( self, router: Optional[Router] = None, coverage_sep: Optional[str] = None ) -> None: + """Initialize CommandMap.""" self._router = router or RouterLoader.from_extensions() self._map = self.get_command_map(router=self._router) - self._provider_coverage = self.get_provider_coverage( - router=self._router, sep=coverage_sep - ) - self._command_coverage = self.get_command_coverage( - router=self._router, sep=coverage_sep - ) - self._commands_model = self.get_commands_model( - router=self._router, sep=coverage_sep - ) + self._provider_coverage: Dict[str, List[str]] = {} + self._command_coverage: Dict[str, List[str]] = {} + self._commands_model: Dict[str, str] = {} + self._coverage_sep = coverage_sep @property def map(self) -> Dict[str, Callable]: + """Get command map.""" return self._map @property def provider_coverage(self) -> Dict[str, List[str]]: + """Get provider coverage.""" + if not self._provider_coverage: + self._provider_coverage = self.get_provider_coverage( + router=self._router, sep=self._coverage_sep + ) return self._provider_coverage @property def command_coverage(self) -> Dict[str, List[str]]: + """Get command coverage.""" + if not self._command_coverage: + self._command_coverage = self.get_command_coverage( + router=self._router, sep=self._coverage_sep + ) return self._command_coverage @property def commands_model(self) -> Dict[str, str]: + """Get commands model.""" + if not self._commands_model: + self._commands_model = self.get_commands_model( + router=self._router, sep=self._coverage_sep + ) return self._commands_model @staticmethod def get_command_map( router: Router, ) -> Dict[str, Callable]: + """Get command map.""" api_router = router.api_router command_map = {route.path: route.endpoint for route in api_router.routes} # type: ignore return command_map @@ -472,6 +508,7 @@ class CommandMap: def get_provider_coverage( router: Router, sep: Optional[str] = None ) -> Dict[str, List[str]]: + """Get provider coverage.""" api_router = router.api_router mapping = ProviderInterface().map @@ -502,6 +539,7 @@ class CommandMap: def get_command_coverage( router: Router, sep: Optional[str] = None ) -> Dict[str, List[str]]: + """Get command coverage.""" api_router = router.api_router mapping = ProviderInterface().map @@ -527,6 +565,7 @@ class CommandMap: @staticmethod def get_commands_model(router: Router, sep: Optional[str] = None) -> Dict[str, str]: + """Get commands model.""" api_router = router.api_router coverage_map: Dict[Any, Any] = {} @@ -544,6 +583,7 @@ class CommandMap: return coverage_map def get_command(self, route: str) -> Optional[Callable]: + """Get command from route.""" return self._map.get(route, None) @@ -552,18 +592,19 @@ class LoadingError(Exception): class RouterLoader: + """Router Loader.""" + @staticmethod @lru_cache def from_extensions() -> Router: + """Load routes from extensions.""" router = Router() - for entry_point in sorted(entry_points(group="openbb_core_extension")): + for name, entry in ExtensionLoader().core_objects.items(): try: - entry = entry_point.load() - if isinstance(entry, Router): - router.include_router(router=entry, prefix=f"/{entry_point.name}") + router.include_router(router=entry, prefix=f"/{name}") except Exception as e: - msg = f"Error loading extension: {entry_point.name}\n" + msg = f"Error loading extension: {name}\n" if Env().DEBUG_MODE: traceback.print_exception(type(e), e, e.__traceback__) raise LoadingError(msg + f"\033[91m{e}\033[0m") from e |