diff options
Diffstat (limited to 'openbb_platform/core/openbb_core/app/static/package_builder.py')
-rw-r--r-- | openbb_platform/core/openbb_core/app/static/package_builder.py | 304 |
1 files changed, 181 insertions, 123 deletions
diff --git a/openbb_platform/core/openbb_core/app/static/package_builder.py b/openbb_platform/core/openbb_core/app/static/package_builder.py index 793f82731c6..4559d0762b8 100644 --- a/openbb_platform/core/openbb_core/app/static/package_builder.py +++ b/openbb_platform/core/openbb_core/app/static/package_builder.py @@ -45,6 +45,7 @@ from openbb_core.app.provider_interface import ProviderInterface from openbb_core.app.router import RouterLoader from openbb_core.app.static.utils.console import Console from openbb_core.app.static.utils.linters import Linters +from openbb_core.app.version import CORE_VERSION, VERSION from openbb_core.env import Env from openbb_core.provider.abstract.data import Data @@ -86,13 +87,17 @@ class PackageBuilder: self.lint = lint self.verbose = verbose self.console = Console(verbose) + self.route_map = PathHandler.build_route_map() + self.path_list = PathHandler.build_path_list(route_map=self.route_map) def auto_build(self) -> None: """Trigger build if there are differences between built and installed extensions.""" if Env().AUTO_BUILD: - add, remove = PackageBuilder._diff( - self.directory / "assets" / "extension_map.json" + reference = PackageBuilder._read( + self.directory / "assets" / "reference.json" ) + ext_map = reference.get("info", {}).get("extensions", {}) + add, remove = PackageBuilder._diff(ext_map) if add: a = ", ".join(sorted(add)) print(f"Extensions to add: {a}") # noqa: T201 @@ -113,10 +118,9 @@ class PackageBuilder: self.console.log("\nBuilding extensions package...\n") self._clean(modules) ext_map = self._get_extension_map() - self._save_extension_map(ext_map) self._save_modules(modules, ext_map) self._save_package() - self._save_reference_file() + self._save_reference_file(ext_map) if self.lint: self._run_linters() @@ -143,12 +147,6 @@ class PackageBuilder: ] return ext_map - def _save_extension_map(self, ext_map: Dict[str, List[str]]) -> None: - """Save the map of extensions available at build time.""" - code = dumps(obj=dict(sorted(ext_map.items())), indent=4) - self.console.log("Writing extension map...") - self._write(code=code, name="extension_map", extension="json", folder="assets") - def _save_modules( self, modules: Optional[Union[str, List[str]]] = None, @@ -156,28 +154,26 @@ class PackageBuilder: ): """Save the modules.""" self.console.log("\nWriting modules...") - route_map = PathHandler.build_route_map() - path_list = PathHandler.build_path_list(route_map=route_map) - if not path_list: + if not self.path_list: self.console.log("\nThere is nothing to write.") return - MAX_LEN = max([len(path) for path in path_list if path != "/"]) + MAX_LEN = max([len(path) for path in self.path_list if path != "/"]) - if modules: - path_list = [path for path in path_list if path in modules] + _path_list = ( + [path for path in self.path_list if path in modules] + if modules + else self.path_list + ) - for path in path_list: - route = PathHandler.get_route(path=path, route_map=route_map) + for path in _path_list: + route = PathHandler.get_route(path, self.route_map) if route is None: - module_code = ModuleBuilder.build( - path=path, - ext_map=ext_map, - ) - module_name = PathHandler.build_module_name(path=path) + code = ModuleBuilder.build(path, ext_map) + name = PathHandler.build_module_name(path) self.console.log(f"({path})", end=" " * (MAX_LEN - len(path))) - self._write(code=module_code, name=module_name) + self._write(code, name) def _save_package(self): """Save the package.""" @@ -185,11 +181,23 @@ class PackageBuilder: code = "### THIS FILE IS AUTO-GENERATED. DO NOT EDIT. ###\n" self._write(code=code, name="__init__") - def _save_reference_file(self): + def _save_reference_file(self, ext_map: Optional[Dict[str, List[str]]] = None): """Save the reference.json file.""" self.console.log("\nWriting reference file...") - data = ReferenceGenerator.get_reference_data() - code = dumps(obj=data, indent=4) + code = dumps( + obj={ + "openbb": VERSION.replace("dev", ""), + "info": { + "title": "OpenBB Platform (Python)", + "description": "This is the OpenBB Platform (Python).", + "core": CORE_VERSION.replace("dev", ""), + "extensions": ext_map, + }, + "paths": ReferenceGenerator.get_paths(self.route_map), + "routers": ReferenceGenerator.get_routers(self.route_map), + }, + indent=4, + ) self._write(code=code, name="reference", extension="json", folder="assets") def _run_linters(self): @@ -224,13 +232,28 @@ class PackageBuilder: return content @staticmethod - def _diff(path: Path) -> Tuple[Set[str], Set[str]]: + def _diff(ext_map: Dict[str, List[str]]) -> Tuple[Set[str], Set[str]]: """Check differences between built and installed extensions. Parameters ---------- - path: Path - The path to the folder where the extension map is stored. + ext_map: Dict[str, List[str]] + Dictionary containing the extensions. + Example: + { + "openbb_core_extension": [ + "commodity@1.0.1", + ... + ], + "openbb_provider_extension": [ + "benzinga@1.1.3", + ... + ], + "openbb_obbject_extension": [ + "openbb_charting@1.0.0", + ... + ] + } Returns ------- @@ -238,8 +261,6 @@ class PackageBuilder: First element: set of installed extensions that are not in the package. Second element: set of extensions in the package that are not installed. """ - ext_map = PackageBuilder._read(path) - add: Set[str] = set() remove: Set[str] = set() groups = OpenBBGroups.groups() @@ -263,7 +284,7 @@ class ModuleBuilder: def build(path: str, ext_map: Optional[Dict[str, List[str]]] = None) -> str: """Build the module.""" code = "### THIS FILE IS AUTO-GENERATED. DO NOT EDIT. ###\n\n" - code += ImportDefinition.build(path=path) + code += ImportDefinition.build(path) code += ClassDefinition.build(path, ext_map) return code @@ -300,7 +321,7 @@ class ImportDefinition: hint_type = get_args(get_type_hints(return_type)["results"])[0] hint_type_list.append(hint_type) - hint_type_list = cls.filter_hint_type_list(hint_type_list=hint_type_list) + hint_type_list = cls.filter_hint_type_list(hint_type_list) return hint_type_list @@ -380,7 +401,7 @@ class ClassDefinition: code = f"class {class_name}(Container):\n" route_map = PathHandler.build_route_map() - path_list = PathHandler.build_path_list(route_map=route_map) + path_list = PathHandler.build_path_list(route_map) child_path_list = sorted( PathHandler.get_child_path_list( path=path, @@ -391,7 +412,7 @@ class ClassDefinition: doc = f' """{path}\n' if path else ' # fmt: off\n """\nRouters:\n' methods = "" for c in child_path_list: - route = PathHandler.get_route(path=c, route_map=route_map) + route = PathHandler.get_route(c, route_map) if route: doc += f" {route.name}\n" methods += MethodDefinition.build_command_method( @@ -528,10 +549,12 @@ class MethodDefinition: return getattr(PathHandler.build_route_map()[path], "summary", "") @staticmethod - def reorder_params(params: Dict[str, Parameter]) -> "OrderedDict[str, Parameter]": - """Reorder the params.""" + def reorder_params( + params: Dict[str, Parameter], var_kw: Optional[List[str]] = None + ) -> "OrderedDict[str, Parameter]": + """Reorder the params and make sure VAR_KEYWORD come after 'provider.""" formatted_keys = list(params.keys()) - for k in ["provider", "extra_params"]: + for k in ["provider"] + (var_kw or []): if k in formatted_keys: formatted_keys.remove(k) formatted_keys.append(k) @@ -563,14 +586,11 @@ class MethodDefinition: ) formatted: Dict[str, Parameter] = {} - + var_kw = [] for name, param in parameter_map.items(): if name == "extra_params": formatted[name] = Parameter(name="kwargs", kind=Parameter.VAR_KEYWORD) - elif name == "kwargs": - formatted["**" + name] = Parameter( - name="kwargs", kind=Parameter.VAR_KEYWORD, annotation=Any - ) + var_kw.append(name) elif name == "provider_choices": fields = param.annotation.__args__[0].__dataclass_fields__ field = fields["provider"] @@ -624,12 +644,14 @@ class MethodDefinition: formatted[name] = Parameter( name=name, - kind=Parameter.POSITIONAL_OR_KEYWORD, + kind=param.kind, annotation=updated_type, default=param.default, ) + if param.kind == Parameter.VAR_KEYWORD: + var_kw.append(name) - return MethodDefinition.reorder_params(params=formatted) + return MethodDefinition.reorder_params(params=formatted, var_kw=var_kw) @staticmethod def add_field_custom_annotations( @@ -787,13 +809,18 @@ class MethodDefinition: code += " simplefilter('always', DeprecationWarning)\n" code += f""" warn("{deprecation_message}", category=DeprecationWarning, stacklevel=2)\n\n""" - extra_info = {} + info = {} code += " return self._run(\n" code += f""" "{path}",\n""" code += " **filter_inputs(\n" for name, param in parameter_map.items(): if name == "extra_params": + fields = param.annotation.__args__[0].__dataclass_fields__ + values = {k: k for k in fields} + for k in values: + if extra := MethodDefinition.get_extra(fields[k]): + info[k] = extra code += f" {name}=kwargs,\n" elif name == "provider_choices": field = param.annotation.__args__[0].__dataclass_fields__["provider"] @@ -807,19 +834,18 @@ class MethodDefinition: code += " },\n" elif MethodDefinition.is_annotated_dc(param.annotation): fields = param.annotation.__args__[0].__dataclass_fields__ - value = {k: k for k in fields} + values = {k: k for k in fields} code += f" {name}={{\n" - for k, v in value.items(): + for k, v in values.items(): code += f' "{k}": {v},\n' - # TODO: Extend this to extra_params if extra := MethodDefinition.get_extra(fields[k]): - extra_info[k] = extra + info[k] = extra code += " },\n" else: code += f" {name}={name},\n" - if extra_info: - code += f" extra_info={extra_info},\n" + if info: + code += f" info={info},\n" if MethodDefinition.is_data_processing_function(path): code += " data_processing=True,\n" @@ -896,17 +922,17 @@ class DocstringGenerator: Parameters ---------- - field_type (Any): - Typing object containing the field type. - is_required (bool): - Flag to indicate if the field is required. - target (Literal["docstring", "website"], optional): - Target to return type for. Defaults to "docstring". + field_type : Any + Typing object containing the field type. + is_required : bool + Flag to indicate if the field is required. + target : Literal["docstring", "website"] + Target to return type for. Defaults to "docstring". Returns ------- - str: - String representation of the field type. + str + String representation of the field type. """ is_optional = not is_required @@ -1230,7 +1256,7 @@ class ReferenceGenerator: pi = DocstringGenerator.provider_interface @classmethod - def get_endpoint_examples( + def _get_endpoint_examples( cls, path: str, func: Callable, @@ -1243,18 +1269,18 @@ class ReferenceGenerator: Parameters ---------- - path (str): - Path of the router. - func (Callable): - Router endpoint function. - examples (Optional[List[Example]]): - List of Examples (APIEx or PythonEx type) - for the endpoint. + path : str + Path of the router. + func : Callable + Router endpoint function. + examples : Optional[List[Example]] + List of Examples (APIEx or PythonEx type) + for the endpoint. Returns ------- - str: - Formatted string containing the examples for the endpoint. + str: + Formatted string containing the examples for the endpoint. """ sig = signature(func) parameter_map = dict(sig.parameters) @@ -1273,18 +1299,18 @@ class ReferenceGenerator: ) @classmethod - def get_provider_parameter_info(cls, model: str) -> Dict[str, str]: + def _get_provider_parameter_info(cls, model: str) -> Dict[str, str]: """Get the name, type, description, default value and optionality information for the provider parameter. Parameters ---------- - model (str): - Standard model to access the model providers. + model : str + Standard model to access the model providers. Returns ------- - Dict[str, str]: - Dictionary of the provider parameter information + Dict[str, str] + Dictionary of the provider parameter information """ pi_model_provider = cls.pi.model_providers[model] provider_params_field = pi_model_provider.__dataclass_fields__["provider"] @@ -1311,7 +1337,7 @@ class ReferenceGenerator: return provider_parameter_info @classmethod - def get_provider_field_params( + def _get_provider_field_params( cls, model: str, params_type: str, @@ -1321,18 +1347,18 @@ class ReferenceGenerator: Parameters ---------- - model (str): - Model name to access the provider interface - params_type (str): - Parameters to fetch data for (QueryParams or Data) - provider (str, optional): - Provider name. Defaults to "openbb". + model : str + Model name to access the provider interface + params_type : str + Parameters to fetch data for (QueryParams or Data) + provider : str + Provider name. Defaults to "openbb". Returns ------- - List[Dict[str, str]]: - List of dictionaries containing the field name, type, description, default, - optional flag and standard flag for each provider. + List[Dict[str, str]] + List of dictionaries containing the field name, type, description, default, + optional flag and standard flag for each provider. """ provider_field_params = [] expanded_types = MethodDefinition.TYPE_EXPANSION @@ -1386,24 +1412,24 @@ class ReferenceGenerator: return provider_field_params @staticmethod - def get_obbject_returns_fields( + def _get_obbject_returns_fields( model: str, providers: str, ) -> List[Dict[str, str]]: """Get the fields of the OBBject returns object for the given standard_model. - Args - ---- - model (str): - Standard model of the returned object. - providers (str): - Available providers for the model. + Parameters + ---------- + model : str + Standard model of the returned object. + providers : str + Available providers for the model. Returns ------- - List[Dict[str, str]]: - List of dictionaries containing the field name, type, description, default - and optionality of each field. + List[Dict[str, str]] + List of dictionaries containing the field name, type, description, default + and optionality of each field. """ obbject_list = [ { @@ -1436,21 +1462,21 @@ class ReferenceGenerator: return obbject_list @staticmethod - def get_post_method_parameters_info( + def _get_post_method_parameters_info( docstring: str, ) -> List[Dict[str, Union[bool, str]]]: """Get the parameters for the POST method endpoints. Parameters ---------- - docstring (str): - Router endpoint function's docstring + docstring : str + Router endpoint function's docstring Returns ------- - List[Dict[str, str]]: - List of dictionaries containing the name,type, description, default - and optionality of each parameter. + List[Dict[str, str]] + List of dictionaries containing the name,type, description, default + and optionality of each parameter. """ parameters_list = [] @@ -1492,19 +1518,19 @@ class ReferenceGenerator: return parameters_list @staticmethod - def get_post_method_returns_info(docstring: str) -> List[Dict[str, str]]: + def _get_post_method_returns_info(docstring: str) -> List[Dict[str, str]]: """Get the returns information for the POST method endpoints. Parameters ---------- - docstring (str): - Router endpoint function's docstring + docstring: str + Router endpoint function's docstring Returns ------- - List[Dict[str, str]]: - Single element list having a dictionary containing the name, type, - description of the return value + List[Dict[str, str]] + Single element list having a dictionary containing the name, type, + description of the return value """ returns_list = [] @@ -1533,8 +1559,8 @@ class ReferenceGenerator: return returns_list @classmethod - def get_reference_data(cls) -> Dict[str, Dict[str, Any]]: - """Get the reference data for the Platform. + def get_paths(cls, route_map: Dict[str, BaseRoute]) -> Dict[str, Dict[str, Any]]: + """Get path reference data. The reference data is a dictionary containing the description, parameters, returns and examples for each endpoint. This is currently useful for @@ -1542,12 +1568,11 @@ class ReferenceGenerator: Returns ------- - Dict[str, Dict[str, Any]]: - Dictionary containing the description, parameters, returns and - examples for each endpoint. + Dict[str, Dict[str, Any]] + Dictionary containing the description, parameters, returns and + examples for each endpoint. """ reference: Dict[str, Dict] = {} - route_map = PathHandler.build_route_map() for path, route in route_map.items(): # Initialize the reference fields as empty dictionaries @@ -1569,7 +1594,7 @@ class ReferenceGenerator: } # Add endpoint examples examples = openapi_extra.get("examples", []) - reference[path]["examples"] = cls.get_endpoint_examples( + reference[path]["examples"] = cls._get_endpoint_examples( path, route_func, examples, # type: ignore @@ -1588,10 +1613,12 @@ class ReferenceGenerator: if provider == "openbb": # openbb provider is always present hence its the standard field reference[path]["parameters"]["standard"] = ( - cls.get_provider_field_params(standard_model, "QueryParams") + cls._get_provider_field_params( + standard_model, "QueryParams" + ) ) # Add `provider` parameter fields to the openbb provider - provider_parameter_fields = cls.get_provider_parameter_info( + provider_parameter_fields = cls._get_provider_parameter_info( standard_model ) reference[path]["parameters"]["standard"].append( @@ -1600,23 +1627,23 @@ class ReferenceGenerator: # Add endpoint data fields for standard provider reference[path]["data"]["standard"] = ( - cls.get_provider_field_params(standard_model, "Data") + cls._get_provider_field_params(standard_model, "Data") ) continue # Adds provider specific parameter fields to the reference reference[path]["parameters"][provider] = ( - cls.get_provider_field_params( + cls._get_provider_field_params( standard_model, "QueryParams", provider ) ) # Adds provider specific data fields to the reference - reference[path]["data"][provider] = cls.get_provider_field_params( + reference[path]["data"][provider] = cls._get_provider_field_params( standard_model, "Data", provider ) # Add endpoint returns data # Currently only OBBject object is returned providers = provider_parameter_fields["type"] - reference[path]["returns"]["OBBject"] = cls.get_obbject_returns_fields( + reference[path]["returns"]["OBBject"] = cls._get_obbject_returns_fields( standard_model, providers ) # Add data for the endpoints without a standard model (data processing endpoints) @@ -1630,12 +1657,43 @@ class ReferenceGenerator: reference[path]["description"] = re.sub(" +", " ", description) # Add endpoint parameters fields for POST methods reference[path]["parameters"]["standard"] = ( - ReferenceGenerator.get_post_method_parameters_info(docstring) + cls._get_post_method_parameters_info(docstring) ) # Add endpoint returns data # Currently only OBBject object is returned reference[path]["returns"]["OBBject"] = ( - cls.get_post_method_returns_info(docstring) + cls._get_post_method_returns_info(docstring) ) return reference + + @classmethod + def get_routers(cls, route_map: Dict[str, BaseRoute]) -> Dict[str, Dict[str, Any]]: + """Get router reference data. + + Parameters + ---------- + route_map : Dict[str, BaseRoute] + Dictionary containing the path and route object for the router. + + Returns + ------- + Dict[str, Dict[str, Any]] + Dictionary containing the description for each router. + """ + main_router = RouterLoader.from_extensions() + routers = {} + for path in route_map: + path_parts = path.split("/") + # We start at 2: ["/", "some_router"] "/some_router" + i = 2 + p = "/".join(path_parts[:i]) + while p != path: + if p not in routers: + description = main_router.get_attr(p, "description") + if description is not None: + routers[p] = {"description": description} + # We go down the path to include sub-routers + i += 1 + p = "/".join(path_parts[:i]) + return routers |