diff options
author | Igor Radovanovic <74266147+IgorWounds@users.noreply.github.com> | 2024-01-27 00:22:21 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-01-26 23:22:21 +0000 |
commit | c6eefd26b92b44b32ac13151fce44d6e0bc0e80f (patch) | |
tree | 29708ff0a33fcfb81bdad24ed53676a3e0cd992b | |
parent | 34312d236a54fc7212063cbb8fcbeb2c69d50b58 (diff) |
[Feature] - Support for custom examples in router commands (#5993)
* Disable auto_build on test run
* Add support for custom router examples
* Revert odd file change
* Fix API pollution
* Refactor examples with @montezdesousa
* Check if model is inside the PI
* Fix
* feat: add & fix examples (#6001)
* feat: add & fix examples
* fix: ruff + comment
* feat: read parameter pool from file
* feat: typing + unit test
* Disable auto_build on test run
* Add field order to OBBject
* Revert
* lint and revert
* fix test
---------
Co-authored-by: montezdesousa <79287829+montezdesousa@users.noreply.github.com>
Co-authored-by: Diogo Sousa <montezdesousa@gmail.com>
17 files changed, 294 insertions, 187 deletions
diff --git a/openbb_platform/core/openbb_core/app/assets/parameter_pool.json b/openbb_platform/core/openbb_core/app/assets/parameter_pool.json new file mode 100644 index 00000000000..ae6beee70e9 --- /dev/null +++ b/openbb_platform/core/openbb_core/app/assets/parameter_pool.json @@ -0,0 +1,50 @@ +{ + "crypto": { + "symbol": "BTCUSD" + }, + "currency": { + "symbol": "EURUSD" + }, + "derivatives": { + "symbol": "AAPL" + }, + "economy": { + "country": "portugal", + "countries": ["portugal", "spain"] + }, + "economy.fred_series": { + "symbol": "GFDGDPA188S" + }, + "equity": { + "symbol": "AAPL", + "symbols": "AAPL,MSFT", + "query": "AAPL" + }, + "equity.fundamental.historical_attributes": { + "tag": "ebitda" + }, + "equity.fundamental.latest_attributes": { + "tag": "ceo" + }, + "equity.fundamental.transcript": { + "year": 2020 + }, + "etf": { + "symbol": "SPY", + "query": "Vanguard" + }, + "futures": { + "symbol": "ES" + }, + "index": { + "symbol": "SPX", + "index": "^IBEX" + }, + "news": { + "symbols": "AAPL,MSFT" + }, + "regulators": { + "symbol": "AAPL", + "query": "AAPL" + } +} diff --git a/openbb_platform/core/openbb_core/app/constants.py b/openbb_platform/core/openbb_core/app/constants.py index 91996980b23..74ad9bcf6de 100644 --- a/openbb_platform/core/openbb_core/app/constants.py +++ b/openbb_platform/core/openbb_core/app/constants.py @@ -2,6 +2,7 @@ from pathlib import Path +ASSETS_DIRECTORY = Path(__file__).parent / "assets" HOME_DIRECTORY = Path.home() OPENBB_DIRECTORY = Path(HOME_DIRECTORY, ".openbb_platform") USER_SETTINGS_PATH = Path(OPENBB_DIRECTORY, "user_settings.json") diff --git a/openbb_platform/core/openbb_core/app/example_generator.py b/openbb_platform/core/openbb_core/app/example_generator.py new file mode 100644 index 00000000000..c661936621d --- /dev/null +++ b/openbb_platform/core/openbb_core/app/example_generator.py @@ -0,0 +1,83 @@ +"""OpenBB Platform example generator.""" + +import json +from pathlib import Path +from typing import ( + Any, + Dict, +) + +from pydantic.fields import FieldInfo +from pydantic_core import PydanticUndefined + +from openbb_core.app.constants import ASSETS_DIRECTORY +from openbb_core.app.provider_interface import ProviderInterface + +try: + with Path(ASSETS_DIRECTORY, "parameter_pool.json").open() as f: + PARAMETER_POOL = json.load(f) +except Exception: + PARAMETER_POOL = {} + + +class ExampleGenerator: + """Generate examples for the API.""" + + @staticmethod + def _get_value_from_pool(pool: dict, route: str, param: str) -> str: + """Get the value from the pool. + + The example parameters can be defined for: + - route: "crypto.historical.price": {"symbol": "CRYPTO_HISTORICAL_PRICE_SYMBOL"} + - sub-router: "crypto.historical": {"symbol": "CRYPTO_HISTORICAL_SYMBOL"} + - router: "crypto": {"symbol": "CRYPTO_SYMBOL"} + + The search for the 'key' is done in the following order: + - route + - sub-router + - router + """ + parts = route.split(".") + for i in range(len(parts), 0, -1): + partial_route = ".".join(parts[:i]) + if partial_route in pool and param in pool[partial_route]: + return pool[partial_route][param] + return "VALUE_NOT_FOUND" + + @classmethod + def generate( + cls, + route: str, + model: str, + ) -> str: + """Generate the example for the command.""" + if not route or not model: + return "" + + standard_params: Dict[str, FieldInfo] = ( + ProviderInterface() + .map.get(model, {}) + .get("openbb", {}) + .get("QueryParams", {}) + .get("fields", {}) + ) + + eg_params: Dict[str, Any] = {} + for p, v in standard_params.items(): + if v.default is not None: + if v.default is not PydanticUndefined and v.default != "": + eg_params[p] = v.default + else: + eg_params[p] = cls._get_value_from_pool(PARAMETER_POOL, route, p) + + example = f"obb.{route}(" + for n, v in eg_params.items(): + if isinstance(v, str): + v = f'"{v}"' # noqa: PLW2901 + example += f"{n}={v}, " + if eg_params: + example = example[:-2] + ")" + else: + example += ")" + + return example diff --git a/openbb_platform/core/openbb_core/app/router.py b/openbb_platform/core/openbb_core/app/router.py index 210001d2bf3..2426f2f6da9 100644 --- a/openbb_platform/core/openbb_core/app/router.py +++ b/openbb_platform/core/openbb_core/app/router.py @@ -24,6 +24,7 @@ from pydantic import BaseModel from pydantic.v1.validators import find_validators from typing_extensions import Annotated, ParamSpec, _AnnotatedAlias +from openbb_core.app.example_generator import ExampleGenerator from openbb_core.app.extension_loader import ExtensionLoader from openbb_core.app.model.abstract.warning import OpenBBWarning from openbb_core.app.model.command_context import CommandContext @@ -232,16 +233,23 @@ class Router: model = kwargs.pop("model", "") deprecation_message = kwargs.pop("deprecation_message", None) + examples = kwargs.pop("examples", []) + exclude_auto_examples = kwargs.pop("exclude_auto_examples", False) + + if func := SignatureInspector.complete(func, model): + if not exclude_auto_examples: + examples.insert( + 0, + ExampleGenerator.generate( + route=SignatureInspector.get_operation_id(func, sep="."), + model=model, + ), + ) - if model: kwargs["response_model_exclude_unset"] = True - kwargs["openapi_extra"] = {"model": model} - - func = SignatureInspector.complete_signature(func, model) - - if func: - CommandValidator.check(func=func, model=model) - + kwargs["openapi_extra"] = kwargs.get("openapi_extra", {}) + kwargs["openapi_extra"]["model"] = model + kwargs["openapi_extra"]["examples"] = examples kwargs["operation_id"] = kwargs.get( "operation_id", SignatureInspector.get_operation_id(func) ) @@ -300,7 +308,7 @@ class SignatureInspector: """Inspect function signature.""" @classmethod - def complete_signature( + def complete( cls, func: Callable[P, OBBject], model: str ) -> Optional[Callable[P, OBBject]]: """Complete function signature.""" @@ -321,7 +329,6 @@ class SignatureInspector: category=OpenBBWarning, ) return None - cls.validate_signature( func, { @@ -445,19 +452,20 @@ class SignatureInspector: if doc: description = doc.split(" Parameters\n ----------")[0] description = description.split(" Returns\n -------")[0] + description = description.split(" Example\n -------")[0] description = "\n".join([line.strip() for line in description.split("\n")]) return description return "" @staticmethod - def get_operation_id(func: Callable) -> str: + def get_operation_id(func: Callable, sep: str = "_") -> str: """Get operation id.""" operation_id = [ t.replace("_router", "").replace("openbb_", "") for t in func.__module__.split(".") + [func.__name__] ] - cleaned_id = "_".join({c: "" for c in operation_id if c}.keys()) + cleaned_id = sep.join({c: "" for c in operation_id if c}.keys()) return cleaned_id diff --git a/openbb_platform/core/openbb_core/app/static/package_builder.py b/openbb_platform/core/openbb_core/app/static/package_builder.py index 01b8946015b..7648ed89511 100644 --- a/openbb_platform/core/openbb_core/app/static/package_builder.py +++ b/openbb_platform/core/openbb_core/app/static/package_builder.py @@ -9,11 +9,9 @@ from inspect import Parameter, _empty, isclass, signature from json import dumps, load from pathlib import Path from typing import ( - Any, Callable, Dict, List, - Literal, Optional, OrderedDict, Set, @@ -22,7 +20,6 @@ from typing import ( TypeVar, Union, get_args, - get_origin, get_type_hints, ) @@ -38,7 +35,7 @@ from openbb_core.app.charting_service import ChartingService from openbb_core.app.extension_loader import ExtensionLoader, OpenBBGroups from openbb_core.app.model.custom_parameter import OpenBBCustomParameter from openbb_core.app.provider_interface import ProviderInterface -from openbb_core.app.router import CommandMap, RouterLoader +from openbb_core.app.router import RouterLoader from openbb_core.app.static.utils.console import Console from openbb_core.app.static.utils.linters import Linters from openbb_core.env import Env @@ -389,6 +386,7 @@ class ClassDefinition: if route.openapi_extra else None ), + examples=route.openapi_extra.get("examples", None), ) # type: ignore else: doc += " /" if path else " /" @@ -664,12 +662,16 @@ class MethodDefinition: func: Callable, formatted_params: OrderedDict[str, Parameter], model_name: Optional[str] = None, + examples: Optional[List[str]] = None, ): """Build the command method docstring.""" doc = func.__doc__ if model_name: doc = DocstringGenerator.generate( - func=func, formatted_params=formatted_params, model_name=model_name + func=func, + formatted_params=formatted_params, + model_name=model_name, + examples=examples, ) code = f' """{doc} """ # noqa: E501\n\n' if doc else "" @@ -728,7 +730,11 @@ class MethodDefinition: @classmethod def build_command_method( - cls, path: str, func: Callable, model_name: Optional[str] = None + cls, + path: str, + func: Callable, + model_name: Optional[str] = None, + examples: Optional[List[str]] = None, ) -> str: """Build the command method.""" func_name = func.__name__ @@ -745,7 +751,10 @@ class MethodDefinition: model_name=model_name, ) code += cls.build_command_method_doc( - func=func, formatted_params=formatted_params, model_name=model_name + func=func, + formatted_params=formatted_params, + model_name=model_name, + examples=examples, ) code += cls.build_command_method_body(path=path, func=func) @@ -780,114 +789,6 @@ class DocstringGenerator: return obbject_description - @staticmethod - def get_model_standard_params(param_fields: Dict[str, FieldInfo]) -> Dict[str, Any]: - """Get the test params for the fetcher based on the required standard params.""" - test_params: Dict[str, Any] = {} - for field_name, field in param_fields.items(): - if field.default and field.default is not PydanticUndefined: - test_params[field_name] = field.default - elif field.default and field.default is PydanticUndefined: - example_dict = { - "symbol": "AAPL", - "symbols": "AAPL,MSFT", - "start_date": "2023-01-01", - "end_date": "2023-06-06", - "country": "Portugal", - "date": "2023-01-01", - "countries": ["portugal", "spain"], - } - if field_name in example_dict: - test_params[field_name] = example_dict[field_name] - elif field.annotation == str: - test_params[field_name] = "TEST_STRING" - elif field.annotation == int: - test_params[field_name] = 1 - elif field.annotation == float: - test_params[field_name] = 1.0 - elif field.annotation == bool: - test_params[field_name] = True - elif get_origin(field.annotation) is Literal: # type: ignore - option = field.annotation.__args__[0] # type: ignore - if isinstance(option, str): - test_params[field_name] = f'"{option}"' - else: - test_params[field_name] = option - - return test_params - - @staticmethod - def get_full_command_name(route: str) -> str: - """Get the full command name.""" - cmd_parts = route.split("/") - del cmd_parts[0] - - menu = cmd_parts[0] - command = cmd_parts[-1] - sub_menus = cmd_parts[1:-1] - - sub_menu_str_cmd = f".{'.'.join(sub_menus)}" if sub_menus else "" - - full_command = f"{menu}{sub_menu_str_cmd}.{command}" - - return full_command - - @classmethod - def generate_example( - cls, - model_name: str, - standard_params: Dict[str, FieldInfo], - ) -> str: - """Generate the example for the command.""" - # find the model router here - cm = CommandMap() - commands_model = cm.commands_model - route = [k for k, v in commands_model.items() if v == model_name] - - if not route: - return "" - - full_command_name = cls.get_full_command_name(route=route[0]) - example_params = cls.get_model_standard_params(param_fields=standard_params) - - # Edge cases (might find more) - if "crypto" in route[0] and "symbol" in example_params: - example_params["symbol"] = "BTCUSD" - elif "currency" in route[0] and "symbol" in example_params: - example_params["symbol"] = "EURUSD" - elif ( - "index" in route[0] - and "european" not in route[0] - and "symbol" in example_params - ): - example_params["symbol"] = "SPX" - elif ( - "index" in route[0] - and "european" in route[0] - and "symbol" in example_params - ): - example_params["symbol"] = "BUKBUS" - elif ( - "futures" in route[0] and "curve" in route[0] and "symbol" in example_params - ): - example_params["symbol"] = "VX" - elif "futures" in route[0] and "symbol" in example_params: - example_params["symbol"] = "ES" - - example = "\n Example\n -------\n" - example += " >>> from openbb import obb\n" - example += f" >>> obb.{full_command_name}(" - for param_name, param_value in example_params.items(): - if isinstance(param_value, str): - param_value = f'"{param_value}"' # noqa: PLW2901 - example += f"{param_name}={param_value}, " - if example_params: - example = example[:-2] + ")\n" - else: - example += ")\n" - - return example - @classmethod def generate_model_docstring( cls, @@ -897,6 +798,7 @@ class DocstringGenerator: params: dict, returns: Dict[str, FieldInfo], results_type: str, + examples: Optional[List[str]] = None, ) -> str: """Create the docstring for model.""" @@ -922,9 +824,11 @@ class DocstringGenerator: "openbb" ]["QueryParams"]["fields"] - example_docstring = cls.generate_example( - model_name=model_name, standard_params=obb_query_fields - ) + if examples: + example_docstring = "\n Example\n -------\n" + example_docstring += " >>> from openbb import obb\n" + for example in examples: + example_docstring += f" >>> {example}\n" docstring = summary docstring += "\n\n" @@ -1014,7 +918,9 @@ class DocstringGenerator: docstring += f" {field.alias or name} : {field_type}\n" docstring += f" {format_description(description)}\n" - docstring += example_docstring + if examples: + docstring += example_docstring + return docstring @classmethod @@ -1023,6 +929,7 @@ class DocstringGenerator: func: Callable, formatted_params: OrderedDict[str, Parameter], model_name: Optional[str] = None, + examples: Optional[List[str]] = None, ) -> Optional[str]: """Generate the docstring for the function.""" doc = func.__doc__ @@ -1045,6 +952,7 @@ class DocstringGenerator: params=params, returns=returns, results_type=results_type, + examples=examples, ) return doc return doc diff --git a/openbb_platform/core/openbb_core/provider/standard_models/crypto_search.py b/openbb_platform/core/openbb_core/provider/standard_models/crypto_search.py index 943ecb27907..b1334471d8b 100644 --- a/openbb_platform/core/openbb_core/provider/standard_models/crypto_search.py +++ b/openbb_platform/core/openbb_core/provider/standard_models/crypto_search.py @@ -12,7 +12,7 @@ from openbb_core.provider.utils.descriptions import DATA_DESCRIPTIONS class CryptoSearchQueryParams(QueryParams): """Crypto Search Query.""" - query: Optional[str] = Field(description="Search query.", default="") + query: Optional[str] = Field(description="Search query.", default=None) class CryptoSearchData(Data): diff --git a/openbb_platform/core/tests/app/static/test_example_generator.py b/openbb_platform/core/tests/app/static/test_example_generator.py new file mode 100644 index 00000000000..86029fabe13 --- /dev/null +++ b/openbb_platform/core/tests/app/static/test_example_generator.py @@ -0,0 +1,59 @@ +"""Test the example_generator.py file.""" + +# pylint: disable=redefined-outer-name, protected-access + + +import pytest +from openbb_core.app.example_generator import ExampleGenerator + + +@pytest.fixture(scope="module") +def example_generator(): + """Return example generator.""" + return ExampleGenerator() + + +def test_docstring_generator_init(example_generator): + """Test example generator init.""" + assert example_generator + + +TEST_POOL = { + "crypto": {"symbol": "CRYPTO_SYMBOL"}, + "crypto.search": {"symbol": "CRYPTO_SEARCH_SYMBOL"}, + "crypto.price.historical": {"symbol": "CRYPTO_HISTORICAL_PRICE_SYMBOL"}, +} + + +@pytest.mark.parametrize( + "route, param, expected", + [ + ("", "", "VALUE_NOT_FOUND"), + ("random_route", "", "VALUE_NOT_FOUND"), + ("crypto", "symbol", "CRYPTO_SYMBOL"), + ("crypto.search", "symbol", "CRYPTO_SEARCH_SYMBOL"), + ("crypto.price.historical", "symbol", "CRYPTO_HISTORICAL_PRICE_SYMBOL"), + ("crypto.price.historical", "random_param", "VALUE_NOT_FOUND"), + ], +) +def test_get_value_from_pool(example_generator, route, param, expected): + """Test get value from pool.""" + assert example_generator._get_value_from_pool(TEST_POOL, route, param) == expected + + +@pytest.mark.parametrize( + "route, model, expected", + [ + ("", "", ""), + ("random", "test", "obb.random()"), + ("crypto.search", "CryptoSearch", "obb.crypto.search()"), + ( + "crypto.price.historical", + "CryptoHistorical", + 'obb.crypto.price.historical(symbol="BTCUSD")', + ), + ], +) +def test_generate(example_generator, route, model, expected): + """Test generate example.""" + assert example_generator.generate(route, model) == expected diff --git a/openbb_platform/core/tests/app/test_platform_router.py b/openbb_platform/core/tests/app/test_platform_router.py index 14238ff0c11..590a2709aa6 100644 --- a/openbb_platform/core/tests/app/test_platform_router.py +++ b/openbb_platform/core/tests/app/test_platform_router.py @@ -197,7 +197,7 @@ def test_complete_signature(signature_inspector): model = "EquityHistorical" - assert signature_inspector.complete_signature(sample_function, model) + assert signature_inspector.complete(sample_function, model) def test_complete_signature_error(signature_inspector): @@ -206,9 +206,7 @@ def test_complete_signature_error(signature_inspector): async def valid_function() -> OBBject[Optional[List[int]]]: return OBBject(results=[1, 2, 3]) - assert ( - signature_inspector.complete_signature(valid_function, "invalid_model") is None - ) + assert signature_inspector.complete(valid_function, "invalid_model") is None def test_validate_signature(signature_inspector): diff --git a/openbb_platform/openbb/package/crypto.py b/openbb_platform/openbb/package/crypto.py index 3839ceb2608..85a700d2531 100644 --- a/openbb_platform/openbb/package/crypto.py +++ b/openbb_platform/openbb/package/crypto.py @@ -31,7 +31,7 @@ class ROUTER_crypto(Container): self, query: Annotated[ Optional[str], OpenBBCustomParameter(description="Search query.") - ] = "", + ] = None, provider: Optional[Literal["fmp"]] = None, **kwargs ) -> OBBject: diff --git a/openbb_platform/openbb/package/economy.py b/openbb_platform/openbb/package/economy.py index e90862e7041..de1285f8335 100644 --- a/openbb_platform/openbb/package/economy.py +++ b/openbb_platform/openbb/package/economy.py @@ -277,7 +277,7 @@ class ROUTER_economy(Container): Example ------- >>> from openbb import obb - >>> obb.economy.cpi(countries=['portugal', 'spain'], units="growth_same", frequency="monthly") + >>> obb.economy.cpi(countries=['portugal', 'spain'], units="growth_same", frequency="monthly", harmonized=False) """ # noqa: E501 return self._run( @@ -323,7 +323,7 @@ class ROUTER_economy(Container): no default. is_release : Optional[bool] Is release? If True, other search filter variables are ignored. If no query text or release_id is supplied, this defaults to True. (provider: fred) - release_id : Optional[Union[str, int]] + release_id : Optional[Union[int, str]] A specific release ID to target. (provider: fred) limit : Optional[int] The number of data entries to return. (1-1000) (provider: fred) @@ -354,7 +354,7 @@ class ROUTER_economy(Container): FredSearch ---------- - release_id : Optional[Union[str, int]] + release_id : Optional[Union[int, str]] The release ID for queries. series_id : Optional[str] The series ID for the item in the release. @@ -521,7 +521,7 @@ class ROUTER_economy(Container): Example ------- >>> fr |