diff options
author | montezdesousa <79287829+montezdesousa@users.noreply.github.com> | 2024-04-05 12:21:35 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-04-05 11:21:35 +0000 |
commit | 33d186f8eaab1e1836db17f2d1c58ef46ba63bbf (patch) | |
tree | 77b31d03df079b35ac90eec430bcdc992581aa5a | |
parent | 3f08e0f3420c07cfe98a2b0e81d629c1d2ca7978 (diff) |
[Feature] - Router descriptions (#6288)
* First try
* add routers to reference.json
* fix apploader
* add some router descriptions
* rebuild
* remove pydocstyle fixes
* minor fix
* docstring
* new descriptions
* undo sub routers
---------
Co-authored-by: Henrique Joaquim <henriquecjoaquim@gmail.com>
19 files changed, 776 insertions, 111 deletions
diff --git a/openbb_platform/core/openbb_core/api/app_loader.py b/openbb_platform/core/openbb_core/api/app_loader.py index 9c8330af10b..d6ff19878ca 100644 --- a/openbb_platform/core/openbb_core/api/app_loader.py +++ b/openbb_platform/core/openbb_core/api/app_loader.py @@ -1,12 +1,30 @@ +"""App loader module.""" + from typing import List, Optional from fastapi import APIRouter, FastAPI +from openbb_core.app.router import RouterLoader class AppLoader: """App loader.""" @staticmethod + def get_openapi_tags() -> List[dict]: + """Get openapi tags.""" + main_router = RouterLoader.from_extensions() + openapi_tags = [] + # Add tag data for each router in the main router + for r in main_router.routers: + openapi_tags.append( + { + "name": r, + "description": main_router.get_attr(r, "description"), + } + ) + return openapi_tags + + @staticmethod def from_routers( app: FastAPI, routers: List[Optional[APIRouter]], prefix: str ) -> FastAPI: diff --git a/openbb_platform/core/openbb_core/api/rest_api.py b/openbb_platform/core/openbb_core/api/rest_api.py index f166a91be4e..0ad92f613b6 100644 --- a/openbb_platform/core/openbb_core/api/rest_api.py +++ b/openbb_platform/core/openbb_core/api/rest_api.py @@ -67,13 +67,13 @@ app = FastAPI( ], lifespan=lifespan, ) - app.add_middleware( CORSMiddleware, allow_origins=system.api_settings.cors.allow_origins, allow_methods=system.api_settings.cors.allow_methods, allow_headers=system.api_settings.cors.allow_headers, ) +app.openapi_tags = AppLoader.get_openapi_tags() AppLoader.from_routers( app=app, routers=( diff --git a/openbb_platform/core/openbb_core/app/router.py b/openbb_platform/core/openbb_core/app/router.py index 218629be8ba..68ac8dec3ec 100644 --- a/openbb_platform/core/openbb_core/app/router.py +++ b/openbb_platform/core/openbb_core/app/router.py @@ -203,15 +203,33 @@ class Router: """API Router.""" return self._api_router + @property + def prefix(self) -> str: + """Prefix.""" + return self._api_router.prefix + + @property + def description(self) -> str: + """Description.""" + return self._description + + @property + def routers(self) -> Dict[str, "Router"]: + """Routers nested within the Router, i.e. sub-routers.""" + return self._routers + def __init__( self, prefix: str = "", + description: str = "", ) -> None: """Initialize Router.""" self._api_router = APIRouter( prefix=prefix, responses={404: {"description": "Not found"}}, ) + self._description = description + self._routers: Dict[str, Router] = {} @overload def command(self, func: Optional[Callable[P, OBBject]]) -> Callable[P, OBBject]: @@ -290,10 +308,41 @@ class Router: prefix: str = "", ): """Include router.""" - tags = [prefix[1:]] if prefix else None + tags = [prefix.strip("/")] if prefix else None self._api_router.include_router( router=router.api_router, prefix=prefix, tags=tags # type: ignore ) + name = prefix if prefix else router.prefix + self._routers[name.strip("/")] = router + + def get_attr(self, path: str, attr: str) -> Any: + """Get router attribute from path. + + Parameters + ---------- + path : str + Path to the router or nested router. + E.g. "/equity" or "/equity/price". + attr : str + Attribute to get. + + Returns + ------- + Any + Attribute value. + """ + return self._search_attr(self, path, attr) + + @staticmethod + def _search_attr(router: "Router", path: str, attr: str) -> Any: + """Recursively search router attribute from path.""" + path = path.strip("/") + first = path.split("/")[0] + if first in router.routers: + return Router._search_attr( + router.routers[first], "/".join(path.split("/")[1:]), attr + ) + return getattr(router, attr, None) class SignatureInspector: diff --git a/openbb_platform/core/openbb_core/app/static/package_builder.py b/openbb_platform/core/openbb_core/app/static/package_builder.py index f0a8e157b55..98a6996b8c5 100644 --- a/openbb_platform/core/openbb_core/app/static/package_builder.py +++ b/openbb_platform/core/openbb_core/app/static/package_builder.py @@ -87,6 +87,8 @@ class PackageBuilder: self.lint = lint self.verbose = verbose self.console = Console(verbose) + self.route_map = PathHandler.build_route_map() + self.path_list = PathHandler.build_path_list(route_map=self.route_map) def auto_build(self) -> None: """Trigger build if there are differences between built and installed extensions.""" @@ -152,28 +154,26 @@ class PackageBuilder: ): """Save the modules.""" self.console.log("\nWriting modules...") - route_map = PathHandler.build_route_map() - path_list = PathHandler.build_path_list(route_map=route_map) - if not path_list: + if not self.path_list: self.console.log("\nThere is nothing to write.") return - MAX_LEN = max([len(path) for path in path_list if path != "/"]) + MAX_LEN = max([len(path) for path in self.path_list if path != "/"]) - if modules: - path_list = [path for path in path_list if path in modules] + _path_list = ( + [path for path in self.path_list if path in modules] + if modules + else self.path_list + ) - for path in path_list: - route = PathHandler.get_route(path=path, route_map=route_map) + for path in _path_list: + route = PathHandler.get_route(path, self.route_map) if route is None: - module_code = ModuleBuilder.build( - path=path, - ext_map=ext_map, - ) - module_name = PathHandler.build_module_name(path=path) + code = ModuleBuilder.build(path, ext_map) + name = PathHandler.build_module_name(path) self.console.log(f"({path})", end=" " * (MAX_LEN - len(path))) - self._write(code=module_code, name=module_name) + self._write(code, name) def _save_package(self): """Save the package.""" @@ -184,7 +184,6 @@ class PackageBuilder: def _save_reference_file(self, ext_map: Optional[Dict[str, List[str]]] = None): """Save the reference.json file.""" self.console.log("\nWriting reference file...") - data = ReferenceGenerator.get_reference_data() code = dumps( obj={ "openbb": VERSION.replace("dev", ""), @@ -194,7 +193,8 @@ class PackageBuilder: "core": CORE_VERSION.replace("dev", ""), "extensions": ext_map, }, - "paths": data, + "paths": ReferenceGenerator.get_paths(self.route_map), + "routers": ReferenceGenerator.get_routers(self.route_map), }, indent=4, ) @@ -284,7 +284,7 @@ class ModuleBuilder: def build(path: str, ext_map: Optional[Dict[str, List[str]]] = None) -> str: """Build the module.""" code = "### THIS FILE IS AUTO-GENERATED. DO NOT EDIT. ###\n\n" - code += ImportDefinition.build(path=path) + code += ImportDefinition.build(path) code += ClassDefinition.build(path, ext_map) return code @@ -321,7 +321,7 @@ class ImportDefinition: hint_type = get_args(get_type_hints(return_type)["results"])[0] hint_type_list.append(hint_type) - hint_type_list = cls.filter_hint_type_list(hint_type_list=hint_type_list) + hint_type_list = cls.filter_hint_type_list(hint_type_list) return hint_type_list @@ -401,7 +401,7 @@ class ClassDefinition: code = f"class {class_name}(Container):\n" route_map = PathHandler.build_route_map() - path_list = PathHandler.build_path_list(route_map=route_map) + path_list = PathHandler.build_path_list(route_map) child_path_list = sorted( PathHandler.get_child_path_list( path=path, @@ -412,7 +412,7 @@ class ClassDefinition: doc = f' """{path}\n' if path else ' # fmt: off\n """\nRouters:\n' methods = "" for c in child_path_list: - route = PathHandler.get_route(path=c, route_map=route_map) + route = PathHandler.get_route(c, route_map) if route: doc += f" {route.name}\n" methods += MethodDefinition.build_command_method( @@ -922,17 +922,17 @@ class DocstringGenerator: Parameters ---------- - field_type (Any): - Typing object containing the field type. - is_required (bool): - Flag to indicate if the field is required. - target (Literal["docstring", "website"], optional): - Target to return type for. Defaults to "docstring". + field_type : Any + Typing object containing the field type. + is_required : bool + Flag to indicate if the field is required. + target : Literal["docstring", "website"] + Target to return type for. Defaults to "docstring". Returns ------- - str: - String representation of the field type. + str + String representation of the field type. """ is_optional = not is_required @@ -1256,7 +1256,7 @@ class ReferenceGenerator: pi = DocstringGenerator.provider_interface @classmethod - def get_endpoint_examples( + def _get_endpoint_examples( cls, path: str, func: Callable, @@ -1269,18 +1269,18 @@ class ReferenceGenerator: Parameters ---------- - path (str): - Path of the router. - func (Callable): - Router endpoint function. - examples (Optional[List[Example]]): - List of Examples (APIEx or PythonEx type) - for the endpoint. + path : str + Path of the router. + func : Callable + Router endpoint function. + examples : Optional[List[Example]] + List of Examples (APIEx or PythonEx type) + for the endpoint. Returns ------- - str: - Formatted string containing the examples for the endpoint. + str: + Formatted string containing the examples for the endpoint. """ sig = signature(func) parameter_map = dict(sig.parameters) @@ -1299,18 +1299,18 @@ class ReferenceGenerator: ) @classmethod - def get_provider_parameter_info(cls, model: str) -> Dict[str, str]: + def _get_provider_parameter_info(cls, model: str) -> Dict[str, str]: """Get the name, type, description, default value and optionality information for the provider parameter. Parameters ---------- - model (str): - Standard model to access the model providers. + model : str + Standard model to access the model providers. Returns ------- - Dict[str, str]: - Dictionary of the provider parameter information + Dict[str, str] + Dictionary of the provider parameter information """ pi_model_provider = cls.pi.model_providers[model] provider_params_field = pi_model_provider.__dataclass_fields__["provider"] @@ -1337,7 +1337,7 @@ class ReferenceGenerator: return provider_parameter_info @classmethod - def get_provider_field_params( + def _get_provider_field_params( cls, model: str, params_type: str, @@ -1347,18 +1347,18 @@ class ReferenceGenerator: Parameters ---------- - model (str): - Model name to access the provider interface - params_type (str): - Parameters to fetch data for (QueryParams or Data) - provider (str, optional): - Provider name. Defaults to "openbb". + model : str + Model name to access the provider interface + params_type : str + Parameters to fetch data for (QueryParams or Data) + provider : str + Provider name. Defaults to "openbb". Returns ------- - List[Dict[str, str]]: - List of dictionaries containing the field name, type, description, default, - optional flag and standard flag for each provider. + List[Dict[str, str]] + List of dictionaries containing the field name, type, description, default, + optional flag and standard flag for each provider. """ provider_field_params = [] expanded_types = MethodDefinition.TYPE_EXPANSION @@ -1412,24 +1412,24 @@ class ReferenceGenerator: return provider_field_params @staticmethod - def get_obbject_returns_fields( + def _get_obbject_returns_fields( model: str, providers: str, ) -> List[Dict[str, str]]: """Get the fields of the OBBject returns object for the given standard_model. - Args - ---- - model (str): - Standard model of the returned object. - providers (str): - Available providers for the model. + Parameters + ---------- + model : str + Standard model of the returned object. + providers : str + Available providers for the model. Returns ------- - List[Dict[str, str]]: - List of dictionaries containing the field name, type, description, default - and optionality of each field. + List[Dict[str, str]] + List of dictionaries containing the field name, type, description, default + and optionality of each field. """ obbject_list = [ { @@ -1462,21 +1462,21 @@ class ReferenceGenerator: return obbject_list @staticmethod - def get_post_method_parameters_info( + def _get_post_method_parameters_info( docstring: str, ) -> List[Dict[str, Union[bool, str]]]: """Get the parameters for the POST method endpoints. Parameters ---------- - docstring (str): - Router endpoint function's docstring + docstring : str + Router endpoint function's docstring Returns ------- - List[Dict[str, str]]: - List of dictionaries containing the name,type, description, default - and optionality of each parameter. + List[Dict[str, str]] + List of dictionaries containing the name,type, description, default + and optionality of each parameter. """ parameters_list = [] @@ -1518,19 +1518,19 @@ class ReferenceGenerator: return parameters_list @staticmethod - def get_post_method_returns_info(docstring: str) -> List[Dict[str, str]]: + def _get_post_method_returns_info(docstring: str) -> List[Dict[str, str]]: """Get the returns information for the POST method endpoints. Parameters ---------- - docstring (str): - Router endpoint function's docstring + docstring: str + Router endpoint function's docstring Returns ------- - List[Dict[str, str]]: - Single element list having a dictionary containing the name, type, - description of the return value + List[Dict[str, str]] + Single element list having a dictionary containing the name, type, + description of the return value """ returns_list = [] @@ -1559,8 +1559,8 @@ class ReferenceGenerator: return returns_list @classmethod - def get_reference_data(cls) -> Dict[str, Dict[str, Any]]: - """Get the reference data for the Platform. + def get_paths(cls, route_map: Dict[str, BaseRoute]) -> Dict[str, Dict[str, Any]]: + """Get path reference data. The reference data is a dictionary containing the description, parameters, returns and examples for each endpoint. This is currently useful for @@ -1568,12 +1568,11 @@ class ReferenceGenerator: Returns ------- - Dict[str, Dict[str, Any]]: - Dictionary containing the description, parameters, returns and - examples for each endpoint. + Dict[str, Dict[str, Any]] + Dictionary containing the description, parameters, returns and + examples for each endpoint. """ reference: Dict[str, Dict] = {} - route_map = PathHandler.build_route_map() for path, route in route_map.items(): # Initialize the reference fields as empty dictionaries @@ -1595,7 +1594,7 @@ class ReferenceGenerator: } # Add endpoint examples examples = openapi_extra.get("examples", []) - reference[path]["examples"] = cls.get_endpoint_examples( + reference[path]["examples"] = cls._get_endpoint_examples( path, route_func, examples, # type: ignore @@ -1614,10 +1613,12 @@ class ReferenceGenerator: if provider == "openbb": # openbb provider is always present hence its the standard field reference[path]["parameters"]["standard"] = ( - cls.get_provider_field_params(standard_model, "QueryParams") + cls._get_provider_field_params( + standard_model, "QueryParams" + ) ) # Add `provider` parameter fields to the openbb provider - provider_parameter_fields = cls.get_provider_parameter_info( + provider_parameter_fields = cls._get_provider_parameter_info( standard_model ) reference[path]["parameters"]["standard"].append( @@ -1626,23 +1627,23 @@ class ReferenceGenerator: # Add endpoint data fields for standard provider reference[path]["data"]["standard"] = ( - cls.get_provider_field_params(standard_model, "Data") + cls._get_provider_field_params(standard_model, "Data") ) continue # Adds provider specific parameter fields to the reference reference[path]["parameters"][provider] = ( - cls.get_provider_field_params( + cls._get_provider_field_params( standard_model, "QueryParams", provider ) ) # Adds provider specific data fields to the reference - reference[path]["data"][provider] = cls.get_provider_field_params( + reference[path]["data"][provider] = cls._get_provider_field_params( standard_model, "Data", provider ) # Add endpoint returns data # Currently only OBBject object is returned providers = provider_parameter_fields["type"] - reference[path]["returns"]["OBBject"] = cls.get_obbject_returns_fields( + reference[path]["returns"]["OBBject"] = cls._get_obbject_returns_fields( standard_model, providers ) # Add data for the endpoints without a standard model (data processing endpoints) @@ -1656,12 +1657,35 @@ class ReferenceGenerator: reference[path]["description"] = re.sub(" +", " ", description) # Add endpoint parameters fields for POST methods reference[path]["parameters"]["standard"] = ( - ReferenceGenerator.get_post_method_parameters_info(docstring) + cls._get_post_method_parameters_info(docstring) ) # Add endpoint returns data # Currently only OBBject object is returned reference[path]["returns"]["OBBject"] = ( - cls.get_post_method_returns_info(docstring) + cls._get_post_method_returns_info(docstring) ) return reference + + @classmethod + def get_routers(cls, route_map: Dict[str, BaseRoute]) -> Dict[str, Dict[str, Any]]: + """Get router reference data. + + Parameters + ---------- + route_map : Dict[str, BaseRoute] + Dictionary containing the path and route object for the router. + + Returns + ------- + Dict[str, Dict[str, Any]] + Dictionary containing the description for each router. + """ + main_router = RouterLoader.from_extensions() + routers = {} + for path in route_map: + # Strip the command name from the path + _path = "/".join(path.split("/")[:-1]) + if description := main_router.get_attr(_path, "description"): + routers[_path] = {"description": description} + return routers diff --git a/openbb_platform/extensions/commodity/openbb_commodity/commodity_router.py b/openbb_platform/extensions/commodity/openbb_commodity/commodity_router.py index ebf5401c28f..3d8482d6852 100644 --- a/openbb_platform/extensions/commodity/openbb_commodity/commodity_router.py +++ b/openbb_platform/extensions/commodity/openbb_commodity/commodity_router.py @@ -12,7 +12,7 @@ from openbb_core.app.query import Query from openbb_core.app.router import Router from pydantic import BaseModel -router = Router(prefix="") +router = Router(prefix="", description="Commodity market data.") # pylint: disable=unused-argument diff --git a/openbb_platform/extensions/crypto/openbb_crypto/crypto_router.py b/openbb_platform/extensions/crypto/openbb_crypto/crypto_router.py index 1f92af05c97..bc81b7db88c 100644 --- a/openbb_platform/extensions/crypto/openbb_crypto/crypto_router.py +++ b/openbb_platform/extensions/crypto/openbb_crypto/crypto_router.py @@ -13,7 +13,7 @@ from openbb_core.app.router import Router from openbb_crypto.price.price_router import router as price_router -router = Router(prefix="") +router = Router(prefix="", description="Cryptocurrency market data.") router.include_router(price_router) diff --git a/openbb_platform/extensions/currency/openbb_currency/currency_router.py |