diff options
Diffstat (limited to 'openbb_platform/core/openbb_core/app/router.py')
-rw-r--r-- | openbb_platform/core/openbb_core/app/router.py | 61 |
1 files changed, 53 insertions, 8 deletions
diff --git a/openbb_platform/core/openbb_core/app/router.py b/openbb_platform/core/openbb_core/app/router.py index 3968d5892a0..51bfc836d07 100644 --- a/openbb_platform/core/openbb_core/app/router.py +++ b/openbb_platform/core/openbb_core/app/router.py @@ -203,15 +203,33 @@ class Router: """API Router.""" return self._api_router + @property + def prefix(self) -> str: + """Prefix.""" + return self._api_router.prefix + + @property + def description(self) -> Optional[str]: + """Description.""" + return self._description + + @property + def routers(self) -> Dict[str, "Router"]: + """Routers nested within the Router, i.e. sub-routers.""" + return self._routers + def __init__( self, prefix: str = "", + description: Optional[str] = None, ) -> None: """Initialize Router.""" self._api_router = APIRouter( prefix=prefix, responses={404: {"description": "Not found"}}, ) + self._description = description + self._routers: Dict[str, Router] = {} @overload def command(self, func: Optional[Callable[P, OBBject]]) -> Callable[P, OBBject]: @@ -290,10 +308,41 @@ class Router: prefix: str = "", ): """Include router.""" - tags = [prefix[1:]] if prefix else None + tags = [prefix.strip("/")] if prefix else None self._api_router.include_router( router=router.api_router, prefix=prefix, tags=tags # type: ignore ) + name = prefix if prefix else router.prefix + self._routers[name.strip("/")] = router + + def get_attr(self, path: str, attr: str) -> Any: + """Get router attribute from path. + + Parameters + ---------- + path : str + Path to the router or nested router. + E.g. "/equity" or "/equity/price". + attr : str + Attribute to get. + + Returns + ------- + Any + Attribute value. + """ + return self._search_attr(self, path, attr) + + @staticmethod + def _search_attr(router: "Router", path: str, attr: str) -> Any: + """Recursively search router attribute from path.""" + path = path.strip("/") + first = path.split("/")[0] + if first in router.routers: + return Router._search_attr( + router.routers[first], "/".join(path.split("/")[1:]), attr + ) + return getattr(router, attr, None) class SignatureInspector: @@ -350,7 +399,7 @@ class SignatureInspector: func = cls.inject_return_type( func=func, - return_map=provider_interface.return_map.get(model), + return_map=provider_interface.return_map.get(model, {}), model=model, ) @@ -374,11 +423,7 @@ class SignatureInspector: return_map: Dict[str, dict], model: str, ) -> Callable[P, OBBject]: - """ - Inject full return model into the function. - Also updates __name__ and __doc__ for API schemas. - """ - + """Inject full return model into the function. Also updates __name__ and __doc__ for API schemas.""" results: Dict[str, Any] = {"list_type": [], "dict_type": []} for provider, return_data in return_map.items(): @@ -397,7 +442,7 @@ class SignatureInspector: if not v: continue - inner_type = SerializeAsAny[ + inner_type: Any = SerializeAsAny[ # type: ignore[misc,valid-type] Annotated[ Union[tuple(v)], # type: ignore Field(discriminator="provider"), |