summaryrefslogtreecommitdiffstats
path: root/openbb_platform/core/openbb_core/app/router.py
diff options
context:
space:
mode:
Diffstat (limited to 'openbb_platform/core/openbb_core/app/router.py')
-rw-r--r--openbb_platform/core/openbb_core/app/router.py61
1 files changed, 53 insertions, 8 deletions
diff --git a/openbb_platform/core/openbb_core/app/router.py b/openbb_platform/core/openbb_core/app/router.py
index 3968d5892a0..51bfc836d07 100644
--- a/openbb_platform/core/openbb_core/app/router.py
+++ b/openbb_platform/core/openbb_core/app/router.py
@@ -203,15 +203,33 @@ class Router:
"""API Router."""
return self._api_router
+ @property
+ def prefix(self) -> str:
+ """Prefix."""
+ return self._api_router.prefix
+
+ @property
+ def description(self) -> Optional[str]:
+ """Description."""
+ return self._description
+
+ @property
+ def routers(self) -> Dict[str, "Router"]:
+ """Routers nested within the Router, i.e. sub-routers."""
+ return self._routers
+
def __init__(
self,
prefix: str = "",
+ description: Optional[str] = None,
) -> None:
"""Initialize Router."""
self._api_router = APIRouter(
prefix=prefix,
responses={404: {"description": "Not found"}},
)
+ self._description = description
+ self._routers: Dict[str, Router] = {}
@overload
def command(self, func: Optional[Callable[P, OBBject]]) -> Callable[P, OBBject]:
@@ -290,10 +308,41 @@ class Router:
prefix: str = "",
):
"""Include router."""
- tags = [prefix[1:]] if prefix else None
+ tags = [prefix.strip("/")] if prefix else None
self._api_router.include_router(
router=router.api_router, prefix=prefix, tags=tags # type: ignore
)
+ name = prefix if prefix else router.prefix
+ self._routers[name.strip("/")] = router
+
+ def get_attr(self, path: str, attr: str) -> Any:
+ """Get router attribute from path.
+
+ Parameters
+ ----------
+ path : str
+ Path to the router or nested router.
+ E.g. "/equity" or "/equity/price".
+ attr : str
+ Attribute to get.
+
+ Returns
+ -------
+ Any
+ Attribute value.
+ """
+ return self._search_attr(self, path, attr)
+
+ @staticmethod
+ def _search_attr(router: "Router", path: str, attr: str) -> Any:
+ """Recursively search router attribute from path."""
+ path = path.strip("/")
+ first = path.split("/")[0]
+ if first in router.routers:
+ return Router._search_attr(
+ router.routers[first], "/".join(path.split("/")[1:]), attr
+ )
+ return getattr(router, attr, None)
class SignatureInspector:
@@ -350,7 +399,7 @@ class SignatureInspector:
func = cls.inject_return_type(
func=func,
- return_map=provider_interface.return_map.get(model),
+ return_map=provider_interface.return_map.get(model, {}),
model=model,
)
@@ -374,11 +423,7 @@ class SignatureInspector:
return_map: Dict[str, dict],
model: str,
) -> Callable[P, OBBject]:
- """
- Inject full return model into the function.
- Also updates __name__ and __doc__ for API schemas.
- """
-
+ """Inject full return model into the function. Also updates __name__ and __doc__ for API schemas."""
results: Dict[str, Any] = {"list_type": [], "dict_type": []}
for provider, return_data in return_map.items():
@@ -397,7 +442,7 @@ class SignatureInspector:
if not v:
continue
- inner_type = SerializeAsAny[
+ inner_type: Any = SerializeAsAny[ # type: ignore[misc,valid-type]
Annotated[
Union[tuple(v)], # type: ignore
Field(discriminator="provider"),