summaryrefslogtreecommitdiffstats
path: root/openbb_platform/core/openbb_core/app/router.py
diff options
context:
space:
mode:
Diffstat (limited to 'openbb_platform/core/openbb_core/app/router.py')
-rw-r--r--openbb_platform/core/openbb_core/app/router.py95
1 files changed, 68 insertions, 27 deletions
diff --git a/openbb_platform/core/openbb_core/app/router.py b/openbb_platform/core/openbb_core/app/router.py
index f9a116e9950..9f240f2b513 100644
--- a/openbb_platform/core/openbb_core/app/router.py
+++ b/openbb_platform/core/openbb_core/app/router.py
@@ -1,3 +1,4 @@
+"""OpenBB Router."""
import traceback
import warnings
from functools import lru_cache, partial
@@ -18,11 +19,11 @@ from typing import (
)
from fastapi import APIRouter, Depends
-from importlib_metadata import entry_points
from pydantic import BaseModel
from pydantic.v1.validators import find_validators
from typing_extensions import Annotated, ParamSpec, _AnnotatedAlias
+from openbb_core.app.extension_loader import ExtensionLoader
from openbb_core.app.model.abstract.warning import OpenBBWarning
from openbb_core.app.model.command_context import CommandContext
from openbb_core.app.model.obbject import OBBject
@@ -45,12 +46,14 @@ class OpenBBErrorResponse(BaseModel):
class CommandValidator:
+ """Validate Command."""
+
@staticmethod
def is_standard_pydantic_type(value_type: Type) -> bool:
"""Check whether or not a parameter type is a valid Pydantic Standard Type."""
try:
func = next(
- find_validators(value_type, config=dict(arbitrary_types_allowed=True))
+ find_validators(value_type, config=dict(arbitrary_types_allowed=True)) # type: ignore
)
valid_type = func.__name__ != "arbitrary_type_validator"
except Exception:
@@ -60,6 +63,7 @@ class CommandValidator:
@staticmethod
def is_valid_pydantic_model_type(model_type: Type) -> bool:
+ """Check whether or not a parameter type is a valid Pydantic Model Type."""
if not isclass(model_type):
return False
@@ -73,12 +77,14 @@ class CommandValidator:
@classmethod
def is_serializable_value_type(cls, value_type: Type) -> bool:
+ """Check whether or not a parameter type is a valid serializable type."""
return cls.is_standard_pydantic_type(
value_type=value_type
) or cls.is_valid_pydantic_model_type(model_type=value_type)
@staticmethod
def is_annotated_dc(annotation) -> bool:
+ """Check whether or not a parameter type is an annotated dataclass."""
return isinstance(annotation, _AnnotatedAlias) and hasattr(
annotation.__args__[0], "__dataclass_fields__"
)
@@ -91,6 +97,7 @@ class CommandValidator:
func: Callable,
sig: Signature,
):
+ """Check whether or not a parameter is reserved."""
if name in parameter_map:
annotation = getattr(parameter_map[name], "annotation", None)
if annotation is not None and CommandValidator.is_annotated_dc(annotation):
@@ -105,6 +112,7 @@ class CommandValidator:
@classmethod
def check_parameters(cls, func: Callable):
+ """Check whether or not a parameter is a valid."""
sig = signature(func)
parameter_map = sig.parameters
@@ -129,6 +137,7 @@ class CommandValidator:
@classmethod
def check_return(cls, func: Callable):
+ """Check whether or not a return type is a valid."""
sig = signature(func)
return_type = sig.return_annotation
@@ -162,6 +171,7 @@ class CommandValidator:
@classmethod
def check(cls, func: Callable, model: str = ""):
+ """Check whether or not a function is valid."""
if model and not iscoroutinefunction(func):
raise TypeError(
f"Invalid function: {func.__module__}.{func.__name__}\n"
@@ -183,14 +193,18 @@ class CommandValidator:
class Router:
+ """OpenBB Router Class."""
+
@property
def api_router(self) -> APIRouter:
+ """API Router."""
return self._api_router
def __init__(
self,
prefix: str = "",
) -> None:
+ """Initialize Router."""
self._api_router = APIRouter(
prefix=prefix,
responses={404: {"description": "Not found"}},
@@ -209,6 +223,7 @@ class Router:
func: Optional[Callable[P, OBBject]] = None,
**kwargs,
) -> Optional[Callable]:
+ """Command decorator for routes."""
if func is None:
return lambda f: self.command(f, **kwargs)
@@ -262,6 +277,7 @@ class Router:
router: "Router",
prefix: str = "",
):
+ """Include router."""
tags = [prefix[1:]] if prefix else None
self._api_router.include_router(
router=router.api_router, prefix=prefix, tags=tags # type: ignore
@@ -269,12 +285,13 @@ class Router:
class SignatureInspector:
+ """Inspect function signature."""
+
@classmethod
def complete_signature(
cls, func: Callable[P, OBBject], model: str
) -> Optional[Callable[P, OBBject]]:
"""Complete function signature."""
-
if isclass(return_type := func.__annotations__["return"]) and not issubclass(
return_type, OBBject
):
@@ -343,35 +360,41 @@ class SignatureInspector:
def inject_return_type(
func: Callable[P, OBBject], inner_type: Any, outer_type: Any
) -> Callable[P, OBBject]:
- """Inject full return model into the function.
- Also updates __name__ and __doc__ for API schemas."""
+ """
+ Inject full return model into the function.
+
+ Also updates __name__ and __doc__ for API schemas.
+ """
ReturnModel = inner_type
- if get_origin(outer_type) == list:
+ outer_type_origin = get_origin(outer_type)
+
+ if outer_type_origin == list:
ReturnModel = List[inner_type] # type: ignore
- elif get_origin(outer_type) == Union:
+ elif outer_type_origin == Union:
ReturnModel = Union[List[inner_type], inner_type] # type: ignore
return_type = OBBject[ReturnModel] # type: ignore
return_type.__name__ = f"OBBject[{inner_type.__name__}]"
return_type.__doc__ = f"OBBject with results of type '{inner_type.__name__}'."
- return_type.model_rebuild(force=True)
func.__annotations__["return"] = return_type
return func
@staticmethod
def polish_return_schema(func: Callable[P, OBBject]) -> Callable[P, OBBject]:
- """Polish API schemas by filling __doc__ and __name__"""
+ """Polish API schemas by filling `__doc__` and `__name__`."""
return_type = func.__annotations__["return"]
is_list = False
results_type = get_type_hints(return_type)["results"]
+ results_type_args = get_args(results_type)
if not isinstance(results_type, type(None)):
- results_type = get_args(results_type)[0]
+ results_type = results_type_args[0]
is_list = get_origin(results_type) == list
- args = get_args(results_type)
- inner_type = args[0] if is_list and args else results_type
+ inner_type = (
+ results_type_args[0] if is_list and results_type_args else results_type
+ )
inner_type_name = getattr(inner_type, "__name__", inner_type)
func.__annotations__["return"].__doc__ = "OBBject"
@@ -417,7 +440,7 @@ class SignatureInspector:
@staticmethod
def get_operation_id(func: Callable) -> str:
- """Get operation id"""
+ """Get operation id."""
operation_id = [
t.replace("_router", "").replace("openbb_", "")
for t in func.__module__.split(".") + [func.__name__]
@@ -432,38 +455,51 @@ class CommandMap:
def __init__(
self, router: Optional[Router] = None, coverage_sep: Optional[str] = None
) -> None:
+ """Initialize CommandMap."""
self._router = router or RouterLoader.from_extensions()
self._map = self.get_command_map(router=self._router)
- self._provider_coverage = self.get_provider_coverage(
- router=self._router, sep=coverage_sep
- )
- self._command_coverage = self.get_command_coverage(
- router=self._router, sep=coverage_sep
- )
- self._commands_model = self.get_commands_model(
- router=self._router, sep=coverage_sep
- )
+ self._provider_coverage: Dict[str, List[str]] = {}
+ self._command_coverage: Dict[str, List[str]] = {}
+ self._commands_model: Dict[str, str] = {}
+ self._coverage_sep = coverage_sep
@property
def map(self) -> Dict[str, Callable]:
+ """Get command map."""
return self._map
@property
def provider_coverage(self) -> Dict[str, List[str]]:
+ """Get provider coverage."""
+ if not self._provider_coverage:
+ self._provider_coverage = self.get_provider_coverage(
+ router=self._router, sep=self._coverage_sep
+ )
return self._provider_coverage
@property
def command_coverage(self) -> Dict[str, List[str]]:
+ """Get command coverage."""
+ if not self._command_coverage:
+ self._command_coverage = self.get_command_coverage(
+ router=self._router, sep=self._coverage_sep
+ )
return self._command_coverage
@property
def commands_model(self) -> Dict[str, str]:
+ """Get commands model."""
+ if not self._commands_model:
+ self._commands_model = self.get_commands_model(
+ router=self._router, sep=self._coverage_sep
+ )
return self._commands_model
@staticmethod
def get_command_map(
router: Router,
) -> Dict[str, Callable]:
+ """Get command map."""
api_router = router.api_router
command_map = {route.path: route.endpoint for route in api_router.routes} # type: ignore
return command_map
@@ -472,6 +508,7 @@ class CommandMap:
def get_provider_coverage(
router: Router, sep: Optional[str] = None
) -> Dict[str, List[str]]:
+ """Get provider coverage."""
api_router = router.api_router
mapping = ProviderInterface().map
@@ -502,6 +539,7 @@ class CommandMap:
def get_command_coverage(
router: Router, sep: Optional[str] = None
) -> Dict[str, List[str]]:
+ """Get command coverage."""
api_router = router.api_router
mapping = ProviderInterface().map
@@ -527,6 +565,7 @@ class CommandMap:
@staticmethod
def get_commands_model(router: Router, sep: Optional[str] = None) -> Dict[str, str]:
+ """Get commands model."""
api_router = router.api_router
coverage_map: Dict[Any, Any] = {}
@@ -544,6 +583,7 @@ class CommandMap:
return coverage_map
def get_command(self, route: str) -> Optional[Callable]:
+ """Get command from route."""
return self._map.get(route, None)
@@ -552,18 +592,19 @@ class LoadingError(Exception):
class RouterLoader:
+ """Router Loader."""
+
@staticmethod
@lru_cache
def from_extensions() -> Router:
+ """Load routes from extensions."""
router = Router()
- for entry_point in sorted(entry_points(group="openbb_core_extension")):
+ for name, entry in ExtensionLoader().core_objects.items():
try:
- entry = entry_point.load()
- if isinstance(entry, Router):
- router.include_router(router=entry, prefix=f"/{entry_point.name}")
+ router.include_router(router=entry, prefix=f"/{name}")
except Exception as e:
- msg = f"Error loading extension: {entry_point.name}\n"
+ msg = f"Error loading extension: {name}\n"
if Env().DEBUG_MODE:
traceback.print_exception(type(e), e, e.__traceback__)
raise LoadingError(msg + f"\033[91m{e}\033[0m") from e