summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorPratyush Shukla <ps4534@nyu.edu>2024-03-13 23:00:10 +0530
committerGitHub <noreply@github.com>2024-03-13 17:30:10 +0000
commit540f5788529be2d1f8874af17152f6aafd759b2d (patch)
treeefc5a87b8a9dcfd652836577181fdba4a5dd3382
parentd984637844be34beab090db0b7d3dfd0b4ba419f (diff)
[Enhancement] `ReferenceGenerator` class in `package_builder.py` (#6179)
* fix docstrings in quantitative,technical,econometrics routers * add reference generator class to create reference.json file in openbb/assets folder * modify platform markdown generator script to not generate the reference.json file * get properly formatted examples for the website * modify get_field_type function * remove TODO comment * path change / to . * remove extra '-' below 'Args' in function docstring * get obbject extensions in extension_map.json * skew in stats_router had extra . rip * black * extra space (sigh) * modify get_provider_parameter_info func to use model_provider from provider interface * update lock files * black * Revert "update lock files" This reverts commit 13e7d8b280a910acd19959ae2ffac3cdab5a2818. * add function to check all extensions are installed before the markdow generator runs * better comment * fix docstring fix mypy obbject description function * black * 'static * remove duplicate standard params & data from reference data add standard params & data fields while generating markdown files
-rw-r--r--openbb_platform/core/openbb_core/app/static/package_builder.py451
-rw-r--r--openbb_platform/core/tests/app/static/test_package_builder.py5
-rw-r--r--openbb_platform/extensions/econometrics/openbb_econometrics/econometrics_router.py28
-rw-r--r--openbb_platform/extensions/quantitative/openbb_quantitative/rolling/rolling_router.py24
-rw-r--r--openbb_platform/extensions/quantitative/openbb_quantitative/stats/stats_router.py9
-rw-r--r--openbb_platform/extensions/technical/openbb_technical/technical_router.py2
-rw-r--r--openbb_platform/openbb/assets/extension_map.json1
-rw-r--r--openbb_platform/openbb/assets/module_map.json152
-rw-r--r--openbb_platform/openbb/assets/reference.json41065
-rw-r--r--website/generate_platform_v4_markdown.py817
10 files changed, 41803 insertions, 751 deletions
diff --git a/openbb_platform/core/openbb_core/app/static/package_builder.py b/openbb_platform/core/openbb_core/app/static/package_builder.py
index 20b2fc433a8..9de60e14837 100644
--- a/openbb_platform/core/openbb_core/app/static/package_builder.py
+++ b/openbb_platform/core/openbb_core/app/static/package_builder.py
@@ -3,13 +3,15 @@
# pylint: disable=too-many-lines
import builtins
import inspect
+import re
import shutil
import sys
from dataclasses import Field
from inspect import Parameter, _empty, isclass, signature
-from json import dumps, load
+from json import dump, dumps, load
from pathlib import Path
from typing import (
+ Any,
Callable,
Dict,
List,
@@ -112,9 +114,9 @@ class PackageBuilder:
self._clean(modules)
ext_map = self._get_extension_map()
self._save_extension_map(ext_map)
- self._save_module_map()
self._save_modules(modules, ext_map)
self._save_package()
+ self._save_reference_file()
if self.lint:
self._run_linters()
@@ -147,17 +149,6 @@ class PackageBuilder:
self.console.log("Writing extension map...")
self._write(code=code, name="extension_map", extension="json", folder="assets")
- def _save_module_map(self):
- """Save the module map."""
- route_map = PathHandler.build_route_map()
- path_list = PathHandler.build_path_list(route_map=route_map)
- module_map = {
- PathHandler.build_module_name(path=path): path for path in path_list
- }
- code = dumps(obj=dict(sorted(module_map.items())), indent=4)
- self.console.log("\nWriting module map...")
- self._write(code=code, name="module_map", extension="json", folder="assets")
-
def _save_modules(
self,
modules: Optional[Union[str, List[str]]] = None,
@@ -194,6 +185,16 @@ class PackageBuilder:
code = "### THIS FILE IS AUTO-GENERATED. DO NOT EDIT. ###\n"
self._write(code=code, name="__init__")
+ def _save_reference_file(self):
+ """Save the reference.json file."""
+ self.console.log("\nWriting reference file...")
+ data = ReferenceGenerator.get_reference_data()
+ file_path = self.directory / "assets" / "reference.json"
+ # Dumping the reference dictionary as a JSON file
+ self.console.log(str(file_path))
+ with open(file_path, "w", encoding="utf-8") as f:
+ dump(data, f, indent=4)
+
def _run_linters(self):
"""Run the linters."""
self.console.log("\nRunning linters...")
@@ -886,12 +887,14 @@ class DocstringGenerator:
@staticmethod
def get_field_type(
- field: FieldInfo, target: Literal["docstring", "website"] = "docstring"
+ field_type: Any,
+ is_required: bool,
+ target: Literal["docstring", "website"] = "docstring",
) -> str:
"""Get the implicit data type of a defined Pydantic field.
- Args
- ----
+ Parameters
+ ----------
field (FieldInfo): Pydantic field object containing field information.
target (Literal["docstring", "website"], optional): Target to return type for. Defaults to "docstring".
@@ -899,10 +902,10 @@ class DocstringGenerator:
-------
str: String representation of the field type.
"""
- is_optional = not field.is_required() if target == "docstring" else False
+ is_optional = not is_required
try:
- _type = field.annotation
+ _type = field_type
if "BeforeValidator" in str(_type):
_type = "Optional[int]" if is_optional else "int" # type: ignore
@@ -918,38 +921,47 @@ class DocstringGenerator:
.replace("NoneType", "None")
.replace(", None", "")
)
+
field_type = (
f"Optional[{field_type}]"
if is_optional and "Optional" not in str(_type)
else field_type
)
+
+ if target == "website":
+ field_type = re.sub(r"Optional\[(.*)\]", r"\1", field_type)
+ field_type = re.sub(r"Annotated\[(.*)\]", r"\1", field_type)
+
+ return field_type
+
except TypeError:
# Fallback to the annotation if the repr fails
- field_type = field.annotation # type: ignore
-
- return field_type
+ return field_type # type: ignore
@staticmethod
def get_OBBject_description(
results_type: str,
providers: Optional[str],
+ target: Literal["docstring", "website"] = "docstring",
) -> str:
"""Get the command output description."""
available_providers = providers or "Optional[str]"
+ indent = 2 if target == "docstring" else 0
obbject_description = (
- f"{create_indent(2)}OBBject\n"
- f"{create_indent(3)}results : {results_type}\n"
- f"{create_indent(4)}Serializable results.\n"
- f"{create_indent(3)}provider : {available_providers}\n"
- f"{create_indent(4)}Provider name.\n"
- f"{create_indent(3)}warnings : Optional[List[Warning_]]\n"
- f"{create_indent(4)}List of warnings.\n"
- f"{create_indent(3)}chart : Optional[Chart]\n"
- f"{create_indent(4)}Chart object.\n"
- f"{create_indent(3)}extra : Dict[str, Any]\n"
- f"{create_indent(4)}Extra info.\n"
+ f"{create_indent(indent)}OBBject\n"
+ f"{create_indent(indent+1)}results : {results_type}\n"
+ f"{create_indent(indent+2)}Serializable results.\n"
+ f"{create_indent(indent+1)}provider : {available_providers}\n"
+ f"{create_indent(indent+2)}Provider name.\n"
+ f"{create_indent(indent+1)}warnings : Optional[List[Warning_]]\n"
+ f"{create_indent(indent+2)}List of warnings.\n"
+ f"{create_indent(indent+1)}chart : Optional[Chart]\n"
+ f"{create_indent(indent+2)}Chart object.\n"
+ f"{create_indent(indent+1)}extra : Dict[str, Any]\n"
+ f"{create_indent(indent+2)}Extra info.\n"
)
+
obbject_description = obbject_description.replace("NoneType", "None")
return obbject_description
@@ -1066,7 +1078,7 @@ class DocstringGenerator:
docstring += f"{create_indent(2)}{underline}\n"
for name, field in returns.items():
- field_type = cls.get_field_type(field)
+ field_type = cls.get_field_type(field.annotation, field.is_required())
description = getattr(field, "description", "")
docstring += f"{create_indent(2)}{field.alias or name} : {field_type}\n"
docstring += f"{create_indent(3)}{format_description(description)}\n"
@@ -1190,3 +1202,376 @@ class PathHandler:
if not path:
return "Extensions"
return f"ROUTER_{cls.clean_path(path=path)}"
+
+
+class ReferenceGenerator:
+ """Generate the reference for the Platform."""
+
+ REFERENCE_FIELDS = [
+ "deprecated",
+ "description",
+ "examples",
+ "parameters",
+ "returns",
+ "data",
+ ]
+
+ # pylint: disable=protected-access
+ pi = DocstringGenerator.provider_interface
+
+ @classmethod
+ def get_endpoint_examples(
+ cls,
+ path: str,
+ func: Callable,
+ examples: Optional[List[Example]],
+ ) -> str:
+ """Get the examples for the given standard model or function.
+
+ For a given standard model or function, the examples are fetched from the
+ list of Example objects and formatted into a string.
+
+ Parameters
+ ----------
+ path (str):
+ Path of the router.
+ func (Callable):
+ Router endpoint function.
+ examples (Optional[List[Example]]):
+ List of Examples (APIEx or PythonEx type)
+ for the endpoint.
+
+ Returns
+ -------
+ str:
+ Formatted string containing the examples for the endpoint.
+ """
+ sig = signature(func)
+ parameter_map = dict(sig.parameters)
+ formatted_params = MethodDefinition.format_params(
+ path=path, parameter_map=parameter_map
+ )
+ explicit_params = dict(formatted_params)
+ explicit_params.pop("extra_params", None)
+ param_types = {k: v.annotation for k, v in explicit_params.items()}
+
+ return DocstringGenerator.build_examples(
+ path.replace("/", "."),
+ param_types,
+ examples,
+ "website",
+ )
+
+ @classmethod
+ def get_provider_parameter_info(cls, model: str) -> Dict[str, str]:
+ """Get the name, type, description, default value and optionality information for the provider parameter.
+
+ Parameters
+ ----------
+ model (str):
+ Standard model to access the model providers.
+
+ Returns
+ -------
+ Dict[str, str]:
+ Dictionary of the provider parameter information
+ """
+ pi_model_provider = cls.pi.model_providers[model]
+ provider_params_field = pi_model_provider.__dataclass_fields__["provider"]
+
+ name = provider_params_field.name
+ field_type = DocstringGenerator.get_field_type(
+ provider_params_field.type, False, "website"
+ )
+ default = provider_params_field.type.__args__[0]
+ description = (
+ "The provider to use for the query, by default None. "
+ "If None, the provider specified in defaults is selected "
+ f"or '{default}' if there is no default."
+ )
+
+ provider_parameter_info = {
+ "name": name,
+ "type": field_type,
+ "description": description,
+ "default": default,
+ "optional": True,
+ }
+
+ return provider_parameter_info
+
+ @classmethod
+ def get_provider_field_params(
+ cls,
+ model: str,
+ params_type: str,
+ provider: str = "openbb",
+ ) -> List[Dict[str, Any]]:
+ """Get the fields of the given parameter type for the given provider of the standard_model.
+
+ Parameters
+ ----------
+ model (str):
+ Model name to access the provider interface
+ params_type (str):
+ Parameters to fetch data for (QueryParams or Data)
+ provider (str, optional):
+ Provider name. Defaults to "openbb".
+
+ Returns
+ -------
+ List[Dict[str, str]]:
+ List of dictionaries containing the field name, type, description, default,
+ optional flag and standard flag for each provider.
+ """
+ provider_field_params = []
+ expanded_types = MethodDefinition.TYPE_EXPANSION
+ model_map = cls.pi._map[model] # pylint: disable=protected-access
+
+ for field, field_info in model_map[provider][params_type]["fields"].items():
+ # Determine the field type, expanding it if necessary and if params_type is "Parameters"
+ field_type = field_info.annotation
+ is_required = field_info.is_required()
+ field_type = DocstringGenerator.get_field_type(
+ field_type, is_required, "website"
+ )
+
+ if params_type == "QueryParams" and field in expanded_types:
+ expanded_type = DocstringGenerator.get_field_type(
+ expanded_types[field], is_required, "website"
+ )
+ field_type = f"Union[{field_type}, {expanded_type}]"
+
+ cleaned_description = (
+ str(field_info.description)
+ .strip().replace("\n", " ").replace(" ", " ").replace('"', "'")
+ ) # fmt: skip
+
+ # Add information for the providers supporting multiple symbols
+ if params_type == "QueryParams" and field_info.json_schema_extra:
+ multiple_items_list = field_info.json_schema_extra.get(
+ "multiple_items_allowed", None
+ )
+ if multiple_items_list:
+ multiple_items = ", ".join(multiple_items_list)
+ cleaned_description += (
+ f" Multiple items allowed for provider(s): {multiple_items}."
+ )
+ # Manually setting to List[<field_type>] for multiple items
+ # Should be removed if TYPE_EXPANSION is updated to include this
+ field_type = f"Union[{field_type}, List[{field_type}]]"
+
+ default_value = "" if field_info.default is PydanticUndefined else str(field_info.default) # fmt: skip
+
+ provider_field_params.append(
+ {
+ "name": field,
+ "type": field_type,
+ "description": cleaned_description,
+ "default": default_value,
+ "optional": not is_required,
+ }
+ )
+
+ return provider_field_params
+
+ @staticmethod
+ def get_post_method_parameters_info(
+ docstring: str,
+ ) -> List[Dict[str, Union[bool, str]]]:
+ """Get the parameters for the POST method endpoints.
+
+ Parameters
+ ----------
+ docstring (str):
+ Router endpoint function's docstring
+
+ Returns
+ -------
+ List[Dict[str, str]]:
+ List of dictionaries containing the name,type, description, default
+ and optionality of each parameter.
+ """
+ # Define a regex pattern to match parameter blocks
+ # This pattern looks for a parameter name followed by " : ", then captures the type and description
+ pattern = re.compile(
+ r"\n\s*(?P<name>\w+)\s*:\s*(?P<type>[^\n]+?)(?:\s*=\s*(?P<default>[^\n]+))?\n\s*(?P<description>[^\n]+)"
+ )
+
+ # Find all matches in the docstring
+ matches = pattern.finditer(docstring)
+
+ # Initialize an empty list to store parameter dictionaries
+ parameters_list = []
+
+ # Iterate over the matches to extract details
+ for match in matches:
+ # Extract named groups as a dictionary
+ param_info = match.groupdict()
+
+ # Determine if the parameter is optional
+ is_optional = "Optional" in param_info["type"]
+
+ # If no default value is captured, set it to an empty string
+ default_value = (
+ param_info["default"] if param_info["default"] is not None else ""
+ )
+
+ # Create a new dictionary with fields in the desired order
+ param_dict = {
+ "name": param_info["name"],
+ "type": param_info["type"],
+ "description": param_info["description"],
+ "default": default_value,
+ "optional": is_optional,
+ }
+
+ # Append the dictionary to the list
+ parameters_list.append(param_dict)
+
+ return parameters_list
+
+ @staticmethod
+ def get_post_method_returns_info(docstring: str) -> str:
+ """Get the returns information for the POST method endpoints.
+
+ Parameters
+ ----------
+ docstring (str):
+ Router endpoint function's docstring
+
+ Returns
+ -------
+ Dict[str, str]:
+ Dictionary containing the name, type, description of the return value
+ """
+ # Define a regex pattern to match the Returns section
+ # This pattern captures the model name inside "OBBject[]" and its description
+ match = re.search(r"Returns\n\s*-------\n\s*([^\n]+)\n\s*([^\n]+)", docstring)
+ return_type = match.group(1).strip() # type: ignore
+ # Remove newlines and indentation from the description
+ description = match.group(2).strip().replace("\n", "").replace(" ", "") # type: ignore
+ # Adjust regex to correctly capture content inside brackets, including nested brackets
+ content_inside_brackets = re.search(
+ r"OBBject\[\s*((?:[^\[\]]|\[[^\[\]]*\])*)\s*\]", return_type
+ )
+ return_type_content = content_inside_brackets.group(1) # type: ignore
+
+ return_info = (
+ f"OBBject\n"
+ f"{create_indent(1)}results : {return_type_content}\n"
+ f"{create_indent(2)}{description}"
+ )
+
+ return return_info
+
+ @classmethod
+ def get_reference_data(cls) -> Dict[str, Dict[str, Any]]:
+ """Get the reference data for the Platform.
+
+ The reference data is a dictionary containing the description, parameters,
+ returns and examples for each endpoint. This is currently useful for
+ automating the creation of the website documentation files.
+
+ Returns
+ -------
+ Dict[str, Dict[str, Any]]:
+ Dictionary containing the description, parameters, returns and
+ examples for each endpoint.
+ """
+ reference: Dict[str, Dict] = {}
+ route_map = PathHandler.build_route_map()
+
+ for path, route in route_map.items():
+ # Initialize the reference fields as empty dictionaries
+ reference[path] = {field: {} for field in cls.REFERENCE_FIELDS}
+ # Route method is used to distinguish between GET and POST methods
+ route_method = getattr(route, "methods", None)
+ # Route endpoint is the callable function
+ route_func = getattr(route, "endpoint", None)
+ # Attribute contains the model and examples info for the endpoint
+ openapi_extra = getattr(route, "openapi_extra", {})
+ # Standard model is used as the key for the ProviderInterface Map dictionary
+ standard_model = openapi_extra.get("model", "")
+ # Add endpoint model for GET methods
+ reference[path]["model"] = standard_model
+ # Add endpoint deprecation details
+ reference[path]["deprecated"] = {
+ "flag": MethodDefinition.is_deprecated_function(path),
+ "message": MethodDefinition.get_deprecation_message(path),
+ }
+ # Add endpoint examples
+ examples = openapi_extra.get("examples", [])
+ reference[path]["examples"] = cls.get_endpoint_examples(
+ path, route_func, examples # type: ignore
+ )
+ # Add data for the endpoints having a standard model
+ if route_method == {"GET"}:
+ reference[path]["description"] = getattr(
+ route, "description", "No description available."
+ )
+ # Access model map from the ProviderInterface
+ model_map = cls.pi._map[
+ standard_model
+ ] # pylint: disable=protected-access
+
+ for provider in model_map:
+ if provider == "openbb":
+ # openbb provider is always present hence its the standard field
+ reference[path]["parameters"]["standard"] = (
+ cls.get_provider_field_params(standard_model, "QueryParams")
+ )
+ # Add `provider` parameter fields to the openbb provider
+ provider_parameter_fields = cls.get_provider_parameter_info(
+ standard_model
+ )
+ reference[path]["parameters"]["standard"].append(
+ provider_parameter_fields
+ )
+
+ # Add endpoint data fields for standard provider
+ reference[path]["data"]["standard"] = (
+ cls.get_provider_field_params(standard_model, "Data")
+ )
+ continue
+ # Adds provider specific parameter fields to the reference
+ reference[path]["parameters"][provider] = (
+ cls.get_provider_field_params(
+ standard_model, "QueryParams", provider
+ )
+ )
+ # Adds provider specific data fields to the reference
+ reference[path]["data"][provider] = cls.get_provider_field_params(
+ standard_model, "Data", provider
+ )
+ # Add endpoint returns data
+ # Currently only OBBject object is returned
+ providers = provider_parameter_fields["type"]
+ reference[path]["returns"]["OBBject"] = (
+ DocstringGenerator.get_OBBject_description(
+ standard_model, providers, "website"
+ )
+ )
+ # Add data for the endpoints without a standard model (data processing endpoints)
+ elif route_method == {"POST"}:
+ # POST method router `description` attribute is unreliable as it may or
+ # may not contain the "Parameters" and "Returns" sections. Hence, the
+ # endpoint function docstring is used instead.
+ description = route_func.__doc__.split("Parameters")[0].strip() # type: ignore
+ # Remove extra spaces in between the string
+ reference[path]["description"] = re.sub(" +", " ", description)
+ # Add endpoint parameters fields for POST methods
+ reference[path]["parameters"][
+ "standard"
+ ] = ReferenceGenerator.get_post_method_parameters_info(
+ route_func.__doc__ # type: ignore
+ )
+ # Add endpoint returns data
+ # Currently only OBBject object is returned
+ reference[path]["returns"][
+ "OBBject"
+ ] = cls.get_post_method_returns_info(
+ route_func.__doc__ # type: ignore
+ )
+
+ return reference
diff --git a/openbb_platform/core/tests/app/static/test_package_builder.py b/openbb_platform/core/tests/app/static/test_package_builder.py
index eb1fe3a1d2e..aa7c0029475 100644
--- a/openbb_platform/core/tests/app/static/test_package_builder.py
+++ b/openbb_platform/core/tests/app/static/test_package_builder.py
@@ -46,11 +46,6 @@ def test_package_builder_build(package_builder):
package_builder.build()
-def test_save_module_map(package_builder):
- """Test save module map."""
- package_builder._save_module_map()
-
-
def test_save_modules(package_builder):
"""Test save module."""
package_builder._save_modules()
diff --git a/openbb_platform/extensions/econometrics/openbb_econometrics/econometrics_router.py b/openbb_platform/extensions/econometrics/openbb_econometrics/econometrics_router.py
index bb9f2bd9a43..f4c8f48386e 100644
--- a/openbb_platform/extensions/econometrics/openbb_econometrics/econometrics_router.py
+++ b/openbb_platform/extensions/econometrics/openbb_econometrics/econometrics_router.py
@@ -58,7 +58,7 @@ def correlation_matrix(data: List[Data]) -> OBBject[List[Data]]:
Returns
-------
- OBBject[List[Data]]:
+ OBBject[List[Data]]
Correlation matrix.
"""
df = basemodel_to_df(data)
@@ -120,7 +120,7 @@ def ols_regression(
Returns
-------
- OBBject[Dict]:
+ OBBject[Dict]
OBBject with the results being model and results objects.
"""
X = sm.add_constant(get_target_columns(basemodel_to_df(data), x_columns))
@@ -169,7 +169,7 @@ def ols_regression_summary(
Returns
-------
- OBBject[Data]:
+ OBBject[Data]
OBBject with the results being summary object.
"""
X = sm.add_constant(get_target_columns(basemodel_to_df(data), x_columns))
@@ -260,7 +260,7 @@ def autocorrelation(
Returns
-------
- OBBject[Dict]:
+ OBBject[Dict]
OBBject with the results being the score from the test.
"""
X = sm.add_constant(get_target_columns(basemodel_to_df(data), x_columns))
@@ -317,7 +317,7 @@ def residual_autocorrelation(
Returns
-------
- OBBject[Data]:
+ OBBject[Data]
OBBject with the results being the score from the test.
"""
X = sm.add_constant(get_target_columns(basemodel_to_df(data), x_columns))
@@ -374,7 +374,7 @@ def cointegration(
Returns
-------
- OBBject[Data]:
+ OBBject[Data]
OBBject with the results being the score from the test.
"""
pairs = list(combinations(columns, 2))
@@ -450,7 +450,7 @@ def causality(
Returns
-------
- OBBject[Data]:
+ OBBject[Data]
OBBject with the results being the score from the test.
"""
X = get_target_column(basemodel_to_df(data), x_column)
@@ -518,7 +518,7 @@ def unit_root(
Returns
-------
- OBBject[Data]:
+ OBBject[Data]
OBBject with the results being the score from the test.
"""
dataset = get_target_column(basemodel_to_df(data), column)
@@ -568,7 +568,7 @@ def panel_random_effects(
Returns
-------
- OBBject[Dict]:
+ OBBject[Dict]
OBBject with the fit model returned
"""
X = get_target_columns(basemodel_to_df(data), x_columns)
@@ -615,7 +615,7 @@ def panel_between(
Returns
-------
- OBBject[Dict]:
+ OBBject[Dict]
OBBject with the fit model returned
"""
X = get_target_columns(basemodel_to_df(data), x_columns)
@@ -661,7 +661,7 @@ def panel_pooled(
Returns
-------
- OBBject[Dict]:
+ OBBject[Dict]
OBBject with the fit model returned
"""
X = get_target_columns(basemodel_to_df(data), x_columns)
@@ -706,7 +706,7 @@ def panel_fixed(
Returns
-------
- OBBject[Dict]:
+ OBBject[Dict]
OBBject with the fit model returned
"""
X = get_target_columns(basemodel_to_df(data), x_columns)
@@ -751,7 +751,7 @