summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorIgor Radovanovic <74266147+IgorWounds@users.noreply.github.com>2024-01-27 00:22:21 +0100
committerGitHub <noreply@github.com>2024-01-26 23:22:21 +0000
commitc6eefd26b92b44b32ac13151fce44d6e0bc0e80f (patch)
tree29708ff0a33fcfb81bdad24ed53676a3e0cd992b
parent34312d236a54fc7212063cbb8fcbeb2c69d50b58 (diff)
[Feature] - Support for custom examples in router commands (#5993)
* Disable auto_build on test run * Add support for custom router examples * Revert odd file change * Fix API pollution * Refactor examples with @montezdesousa * Check if model is inside the PI * Fix * feat: add & fix examples (#6001) * feat: add & fix examples * fix: ruff + comment * feat: read parameter pool from file * feat: typing + unit test * Disable auto_build on test run * Add field order to OBBject * Revert * lint and revert * fix test --------- Co-authored-by: montezdesousa <79287829+montezdesousa@users.noreply.github.com> Co-authored-by: Diogo Sousa <montezdesousa@gmail.com>
-rw-r--r--openbb_platform/core/openbb_core/app/assets/parameter_pool.json50
-rw-r--r--openbb_platform/core/openbb_core/app/constants.py1
-rw-r--r--openbb_platform/core/openbb_core/app/example_generator.py83
-rw-r--r--openbb_platform/core/openbb_core/app/router.py32
-rw-r--r--openbb_platform/core/openbb_core/app/static/package_builder.py146
-rw-r--r--openbb_platform/core/openbb_core/provider/standard_models/crypto_search.py2
-rw-r--r--openbb_platform/core/tests/app/static/test_example_generator.py59
-rw-r--r--openbb_platform/core/tests/app/test_platform_router.py6
-rw-r--r--openbb_platform/openbb/package/crypto.py2
-rw-r--r--openbb_platform/openbb/package/economy.py8
-rw-r--r--openbb_platform/openbb/package/equity.py2
-rw-r--r--openbb_platform/openbb/package/equity_fundamental.py8
-rw-r--r--openbb_platform/openbb/package/equity_ownership.py50
-rw-r--r--openbb_platform/openbb/package/etf.py18
-rw-r--r--openbb_platform/openbb/package/fixedincome_government.py2
-rw-r--r--openbb_platform/openbb/package/index.py4
-rw-r--r--openbb_platform/openbb/package/regulators_sec.py8
17 files changed, 294 insertions, 187 deletions
diff --git a/openbb_platform/core/openbb_core/app/assets/parameter_pool.json b/openbb_platform/core/openbb_core/app/assets/parameter_pool.json
new file mode 100644
index 00000000000..ae6beee70e9
--- /dev/null
+++ b/openbb_platform/core/openbb_core/app/assets/parameter_pool.json
@@ -0,0 +1,50 @@
+{
+ "crypto": {
+ "symbol": "BTCUSD"
+ },
+ "currency": {
+ "symbol": "EURUSD"
+ },
+ "derivatives": {
+ "symbol": "AAPL"
+ },
+ "economy": {
+ "country": "portugal",
+ "countries": ["portugal", "spain"]
+ },
+ "economy.fred_series": {
+ "symbol": "GFDGDPA188S"
+ },
+ "equity": {
+ "symbol": "AAPL",
+ "symbols": "AAPL,MSFT",
+ "query": "AAPL"
+ },
+ "equity.fundamental.historical_attributes": {
+ "tag": "ebitda"
+ },
+ "equity.fundamental.latest_attributes": {
+ "tag": "ceo"
+ },
+ "equity.fundamental.transcript": {
+ "year": 2020
+ },
+ "etf": {
+ "symbol": "SPY",
+ "query": "Vanguard"
+ },
+ "futures": {
+ "symbol": "ES"
+ },
+ "index": {
+ "symbol": "SPX",
+ "index": "^IBEX"
+ },
+ "news": {
+ "symbols": "AAPL,MSFT"
+ },
+ "regulators": {
+ "symbol": "AAPL",
+ "query": "AAPL"
+ }
+}
diff --git a/openbb_platform/core/openbb_core/app/constants.py b/openbb_platform/core/openbb_core/app/constants.py
index 91996980b23..74ad9bcf6de 100644
--- a/openbb_platform/core/openbb_core/app/constants.py
+++ b/openbb_platform/core/openbb_core/app/constants.py
@@ -2,6 +2,7 @@
from pathlib import Path
+ASSETS_DIRECTORY = Path(__file__).parent / "assets"
HOME_DIRECTORY = Path.home()
OPENBB_DIRECTORY = Path(HOME_DIRECTORY, ".openbb_platform")
USER_SETTINGS_PATH = Path(OPENBB_DIRECTORY, "user_settings.json")
diff --git a/openbb_platform/core/openbb_core/app/example_generator.py b/openbb_platform/core/openbb_core/app/example_generator.py
new file mode 100644
index 00000000000..c661936621d
--- /dev/null
+++ b/openbb_platform/core/openbb_core/app/example_generator.py
@@ -0,0 +1,83 @@
+"""OpenBB Platform example generator."""
+
+import json
+from pathlib import Path
+from typing import (
+ Any,
+ Dict,
+)
+
+from pydantic.fields import FieldInfo
+from pydantic_core import PydanticUndefined
+
+from openbb_core.app.constants import ASSETS_DIRECTORY
+from openbb_core.app.provider_interface import ProviderInterface
+
+try:
+ with Path(ASSETS_DIRECTORY, "parameter_pool.json").open() as f:
+ PARAMETER_POOL = json.load(f)
+except Exception:
+ PARAMETER_POOL = {}
+
+
+class ExampleGenerator:
+ """Generate examples for the API."""
+
+ @staticmethod
+ def _get_value_from_pool(pool: dict, route: str, param: str) -> str:
+ """Get the value from the pool.
+
+ The example parameters can be defined for:
+ - route: "crypto.historical.price": {"symbol": "CRYPTO_HISTORICAL_PRICE_SYMBOL"}
+ - sub-router: "crypto.historical": {"symbol": "CRYPTO_HISTORICAL_SYMBOL"}
+ - router: "crypto": {"symbol": "CRYPTO_SYMBOL"}
+
+ The search for the 'key' is done in the following order:
+ - route
+ - sub-router
+ - router
+ """
+ parts = route.split(".")
+ for i in range(len(parts), 0, -1):
+ partial_route = ".".join(parts[:i])
+ if partial_route in pool and param in pool[partial_route]:
+ return pool[partial_route][param]
+ return "VALUE_NOT_FOUND"
+
+ @classmethod
+ def generate(
+ cls,
+ route: str,
+ model: str,
+ ) -> str:
+ """Generate the example for the command."""
+ if not route or not model:
+ return ""
+
+ standard_params: Dict[str, FieldInfo] = (
+ ProviderInterface()
+ .map.get(model, {})
+ .get("openbb", {})
+ .get("QueryParams", {})
+ .get("fields", {})
+ )
+
+ eg_params: Dict[str, Any] = {}
+ for p, v in standard_params.items():
+ if v.default is not None:
+ if v.default is not PydanticUndefined and v.default != "":
+ eg_params[p] = v.default
+ else:
+ eg_params[p] = cls._get_value_from_pool(PARAMETER_POOL, route, p)
+
+ example = f"obb.{route}("
+ for n, v in eg_params.items():
+ if isinstance(v, str):
+ v = f'"{v}"' # noqa: PLW2901
+ example += f"{n}={v}, "
+ if eg_params:
+ example = example[:-2] + ")"
+ else:
+ example += ")"
+
+ return example
diff --git a/openbb_platform/core/openbb_core/app/router.py b/openbb_platform/core/openbb_core/app/router.py
index 210001d2bf3..2426f2f6da9 100644
--- a/openbb_platform/core/openbb_core/app/router.py
+++ b/openbb_platform/core/openbb_core/app/router.py
@@ -24,6 +24,7 @@ from pydantic import BaseModel
from pydantic.v1.validators import find_validators
from typing_extensions import Annotated, ParamSpec, _AnnotatedAlias
+from openbb_core.app.example_generator import ExampleGenerator
from openbb_core.app.extension_loader import ExtensionLoader
from openbb_core.app.model.abstract.warning import OpenBBWarning
from openbb_core.app.model.command_context import CommandContext
@@ -232,16 +233,23 @@ class Router:
model = kwargs.pop("model", "")
deprecation_message = kwargs.pop("deprecation_message", None)
+ examples = kwargs.pop("examples", [])
+ exclude_auto_examples = kwargs.pop("exclude_auto_examples", False)
+
+ if func := SignatureInspector.complete(func, model):
+ if not exclude_auto_examples:
+ examples.insert(
+ 0,
+ ExampleGenerator.generate(
+ route=SignatureInspector.get_operation_id(func, sep="."),
+ model=model,
+ ),
+ )
- if model:
kwargs["response_model_exclude_unset"] = True
- kwargs["openapi_extra"] = {"model": model}
-
- func = SignatureInspector.complete_signature(func, model)
-
- if func:
- CommandValidator.check(func=func, model=model)
-
+ kwargs["openapi_extra"] = kwargs.get("openapi_extra", {})
+ kwargs["openapi_extra"]["model"] = model
+ kwargs["openapi_extra"]["examples"] = examples
kwargs["operation_id"] = kwargs.get(
"operation_id", SignatureInspector.get_operation_id(func)
)
@@ -300,7 +308,7 @@ class SignatureInspector:
"""Inspect function signature."""
@classmethod
- def complete_signature(
+ def complete(
cls, func: Callable[P, OBBject], model: str
) -> Optional[Callable[P, OBBject]]:
"""Complete function signature."""
@@ -321,7 +329,6 @@ class SignatureInspector:
category=OpenBBWarning,
)
return None
-
cls.validate_signature(
func,
{
@@ -445,19 +452,20 @@ class SignatureInspector:
if doc:
description = doc.split(" Parameters\n ----------")[0]
description = description.split(" Returns\n -------")[0]
+ description = description.split(" Example\n -------")[0]
description = "\n".join([line.strip() for line in description.split("\n")])
return description
return ""
@staticmethod
- def get_operation_id(func: Callable) -> str:
+ def get_operation_id(func: Callable, sep: str = "_") -> str:
"""Get operation id."""
operation_id = [
t.replace("_router", "").replace("openbb_", "")
for t in func.__module__.split(".") + [func.__name__]
]
- cleaned_id = "_".join({c: "" for c in operation_id if c}.keys())
+ cleaned_id = sep.join({c: "" for c in operation_id if c}.keys())
return cleaned_id
diff --git a/openbb_platform/core/openbb_core/app/static/package_builder.py b/openbb_platform/core/openbb_core/app/static/package_builder.py
index 01b8946015b..7648ed89511 100644
--- a/openbb_platform/core/openbb_core/app/static/package_builder.py
+++ b/openbb_platform/core/openbb_core/app/static/package_builder.py
@@ -9,11 +9,9 @@ from inspect import Parameter, _empty, isclass, signature
from json import dumps, load
from pathlib import Path
from typing import (
- Any,
Callable,
Dict,
List,
- Literal,
Optional,
OrderedDict,
Set,
@@ -22,7 +20,6 @@ from typing import (
TypeVar,
Union,
get_args,
- get_origin,
get_type_hints,
)
@@ -38,7 +35,7 @@ from openbb_core.app.charting_service import ChartingService
from openbb_core.app.extension_loader import ExtensionLoader, OpenBBGroups
from openbb_core.app.model.custom_parameter import OpenBBCustomParameter
from openbb_core.app.provider_interface import ProviderInterface
-from openbb_core.app.router import CommandMap, RouterLoader
+from openbb_core.app.router import RouterLoader
from openbb_core.app.static.utils.console import Console
from openbb_core.app.static.utils.linters import Linters
from openbb_core.env import Env
@@ -389,6 +386,7 @@ class ClassDefinition:
if route.openapi_extra
else None
),
+ examples=route.openapi_extra.get("examples", None),
) # type: ignore
else:
doc += " /" if path else " /"
@@ -664,12 +662,16 @@ class MethodDefinition:
func: Callable,
formatted_params: OrderedDict[str, Parameter],
model_name: Optional[str] = None,
+ examples: Optional[List[str]] = None,
):
"""Build the command method docstring."""
doc = func.__doc__
if model_name:
doc = DocstringGenerator.generate(
- func=func, formatted_params=formatted_params, model_name=model_name
+ func=func,
+ formatted_params=formatted_params,
+ model_name=model_name,
+ examples=examples,
)
code = f' """{doc} """ # noqa: E501\n\n' if doc else ""
@@ -728,7 +730,11 @@ class MethodDefinition:
@classmethod
def build_command_method(
- cls, path: str, func: Callable, model_name: Optional[str] = None
+ cls,
+ path: str,
+ func: Callable,
+ model_name: Optional[str] = None,
+ examples: Optional[List[str]] = None,
) -> str:
"""Build the command method."""
func_name = func.__name__
@@ -745,7 +751,10 @@ class MethodDefinition:
model_name=model_name,
)
code += cls.build_command_method_doc(
- func=func, formatted_params=formatted_params, model_name=model_name
+ func=func,
+ formatted_params=formatted_params,
+ model_name=model_name,
+ examples=examples,
)
code += cls.build_command_method_body(path=path, func=func)
@@ -780,114 +789,6 @@ class DocstringGenerator:
return obbject_description
- @staticmethod
- def get_model_standard_params(param_fields: Dict[str, FieldInfo]) -> Dict[str, Any]:
- """Get the test params for the fetcher based on the required standard params."""
- test_params: Dict[str, Any] = {}
- for field_name, field in param_fields.items():
- if field.default and field.default is not PydanticUndefined:
- test_params[field_name] = field.default
- elif field.default and field.default is PydanticUndefined:
- example_dict = {
- "symbol": "AAPL",
- "symbols": "AAPL,MSFT",
- "start_date": "2023-01-01",
- "end_date": "2023-06-06",
- "country": "Portugal",
- "date": "2023-01-01",
- "countries": ["portugal", "spain"],
- }
- if field_name in example_dict:
- test_params[field_name] = example_dict[field_name]
- elif field.annotation == str:
- test_params[field_name] = "TEST_STRING"
- elif field.annotation == int:
- test_params[field_name] = 1
- elif field.annotation == float:
- test_params[field_name] = 1.0
- elif field.annotation == bool:
- test_params[field_name] = True
- elif get_origin(field.annotation) is Literal: # type: ignore
- option = field.annotation.__args__[0] # type: ignore
- if isinstance(option, str):
- test_params[field_name] = f'"{option}"'
- else:
- test_params[field_name] = option
-
- return test_params
-
- @staticmethod
- def get_full_command_name(route: str) -> str:
- """Get the full command name."""
- cmd_parts = route.split("/")
- del cmd_parts[0]
-
- menu = cmd_parts[0]
- command = cmd_parts[-1]
- sub_menus = cmd_parts[1:-1]
-
- sub_menu_str_cmd = f".{'.'.join(sub_menus)}" if sub_menus else ""
-
- full_command = f"{menu}{sub_menu_str_cmd}.{command}"
-
- return full_command
-
- @classmethod
- def generate_example(
- cls,
- model_name: str,
- standard_params: Dict[str, FieldInfo],
- ) -> str:
- """Generate the example for the command."""
- # find the model router here
- cm = CommandMap()
- commands_model = cm.commands_model
- route = [k for k, v in commands_model.items() if v == model_name]
-
- if not route:
- return ""
-
- full_command_name = cls.get_full_command_name(route=route[0])
- example_params = cls.get_model_standard_params(param_fields=standard_params)
-
- # Edge cases (might find more)
- if "crypto" in route[0] and "symbol" in example_params:
- example_params["symbol"] = "BTCUSD"
- elif "currency" in route[0] and "symbol" in example_params:
- example_params["symbol"] = "EURUSD"
- elif (
- "index" in route[0]
- and "european" not in route[0]
- and "symbol" in example_params
- ):
- example_params["symbol"] = "SPX"
- elif (
- "index" in route[0]
- and "european" in route[0]
- and "symbol" in example_params
- ):
- example_params["symbol"] = "BUKBUS"
- elif (
- "futures" in route[0] and "curve" in route[0] and "symbol" in example_params
- ):
- example_params["symbol"] = "VX"
- elif "futures" in route[0] and "symbol" in example_params:
- example_params["symbol"] = "ES"
-
- example = "\n Example\n -------\n"
- example += " >>> from openbb import obb\n"
- example += f" >>> obb.{full_command_name}("
- for param_name, param_value in example_params.items():
- if isinstance(param_value, str):
- param_value = f'"{param_value}"' # noqa: PLW2901
- example += f"{param_name}={param_value}, "
- if example_params:
- example = example[:-2] + ")\n"
- else:
- example += ")\n"
-
- return example
-
@classmethod
def generate_model_docstring(
cls,
@@ -897,6 +798,7 @@ class DocstringGenerator:
params: dict,
returns: Dict[str, FieldInfo],
results_type: str,
+ examples: Optional[List[str]] = None,
) -> str:
"""Create the docstring for model."""
@@ -922,9 +824,11 @@ class DocstringGenerator:
"openbb"
]["QueryParams"]["fields"]
- example_docstring = cls.generate_example(
- model_name=model_name, standard_params=obb_query_fields
- )
+ if examples:
+ example_docstring = "\n Example\n -------\n"
+ example_docstring += " >>> from openbb import obb\n"
+ for example in examples:
+ example_docstring += f" >>> {example}\n"
docstring = summary
docstring += "\n\n"
@@ -1014,7 +918,9 @@ class DocstringGenerator:
docstring += f" {field.alias or name} : {field_type}\n"
docstring += f" {format_description(description)}\n"
- docstring += example_docstring
+ if examples:
+ docstring += example_docstring
+
return docstring
@classmethod
@@ -1023,6 +929,7 @@ class DocstringGenerator:
func: Callable,
formatted_params: OrderedDict[str, Parameter],
model_name: Optional[str] = None,
+ examples: Optional[List[str]] = None,
) -> Optional[str]:
"""Generate the docstring for the function."""
doc = func.__doc__
@@ -1045,6 +952,7 @@ class DocstringGenerator:
params=params,
returns=returns,
results_type=results_type,
+ examples=examples,
)
return doc
return doc
diff --git a/openbb_platform/core/openbb_core/provider/standard_models/crypto_search.py b/openbb_platform/core/openbb_core/provider/standard_models/crypto_search.py
index 943ecb27907..b1334471d8b 100644
--- a/openbb_platform/core/openbb_core/provider/standard_models/crypto_search.py
+++ b/openbb_platform/core/openbb_core/provider/standard_models/crypto_search.py
@@ -12,7 +12,7 @@ from openbb_core.provider.utils.descriptions import DATA_DESCRIPTIONS
class CryptoSearchQueryParams(QueryParams):
"""Crypto Search Query."""
- query: Optional[str] = Field(description="Search query.", default="")
+ query: Optional[str] = Field(description="Search query.", default=None)
class CryptoSearchData(Data):
diff --git a/openbb_platform/core/tests/app/static/test_example_generator.py b/openbb_platform/core/tests/app/static/test_example_generator.py
new file mode 100644
index 00000000000..86029fabe13
--- /dev/null
+++ b/openbb_platform/core/tests/app/static/test_example_generator.py
@@ -0,0 +1,59 @@
+"""Test the example_generator.py file."""
+
+# pylint: disable=redefined-outer-name, protected-access
+
+
+import pytest
+from openbb_core.app.example_generator import ExampleGenerator
+
+
+@pytest.fixture(scope="module")
+def example_generator():
+ """Return example generator."""
+ return ExampleGenerator()
+
+
+def test_docstring_generator_init(example_generator):
+ """Test example generator init."""
+ assert example_generator
+
+
+TEST_POOL = {
+ "crypto": {"symbol": "CRYPTO_SYMBOL"},
+ "crypto.search": {"symbol": "CRYPTO_SEARCH_SYMBOL"},
+ "crypto.price.historical": {"symbol": "CRYPTO_HISTORICAL_PRICE_SYMBOL"},
+}
+
+
+@pytest.mark.parametrize(
+ "route, param, expected",
+ [
+ ("", "", "VALUE_NOT_FOUND"),
+ ("random_route", "", "VALUE_NOT_FOUND"),
+ ("crypto", "symbol", "CRYPTO_SYMBOL"),
+ ("crypto.search", "symbol", "CRYPTO_SEARCH_SYMBOL"),
+ ("crypto.price.historical", "symbol", "CRYPTO_HISTORICAL_PRICE_SYMBOL"),
+ ("crypto.price.historical", "random_param", "VALUE_NOT_FOUND"),
+ ],
+)
+def test_get_value_from_pool(example_generator, route, param, expected):
+ """Test get value from pool."""
+ assert example_generator._get_value_from_pool(TEST_POOL, route, param) == expected
+
+
+@pytest.mark.parametrize(
+ "route, model, expected",
+ [
+ ("", "", ""),
+ ("random", "test", "obb.random()"),
+ ("crypto.search", "CryptoSearch", "obb.crypto.search()"),
+ (
+ "crypto.price.historical",
+ "CryptoHistorical",
+ 'obb.crypto.price.historical(symbol="BTCUSD")',
+ ),
+ ],
+)
+def test_generate(example_generator, route, model, expected):
+ """Test generate example."""
+ assert example_generator.generate(route, model) == expected
diff --git a/openbb_platform/core/tests/app/test_platform_router.py b/openbb_platform/core/tests/app/test_platform_router.py
index 14238ff0c11..590a2709aa6 100644
--- a/openbb_platform/core/tests/app/test_platform_router.py
+++ b/openbb_platform/core/tests/app/test_platform_router.py
@@ -197,7 +197,7 @@ def test_complete_signature(signature_inspector):
model = "EquityHistorical"
- assert signature_inspector.complete_signature(sample_function, model)
+ assert signature_inspector.complete(sample_function, model)
def test_complete_signature_error(signature_inspector):
@@ -206,9 +206,7 @@ def test_complete_signature_error(signature_inspector):
async def valid_function() -> OBBject[Optional[List[int]]]:
return OBBject(results=[1, 2, 3])
- assert (
- signature_inspector.complete_signature(valid_function, "invalid_model") is None
- )
+ assert signature_inspector.complete(valid_function, "invalid_model") is None
def test_validate_signature(signature_inspector):
diff --git a/openbb_platform/openbb/package/crypto.py b/openbb_platform/openbb/package/crypto.py
index 3839ceb2608..85a700d2531 100644
--- a/openbb_platform/openbb/package/crypto.py
+++ b/openbb_platform/openbb/package/crypto.py
@@ -31,7 +31,7 @@ class ROUTER_crypto(Container):
self,
query: Annotated[
Optional[str], OpenBBCustomParameter(description="Search query.")
- ] = "",
+ ] = None,
provider: Optional[Literal["fmp"]] = None,
**kwargs
) -> OBBject:
diff --git a/openbb_platform/openbb/package/economy.py b/openbb_platform/openbb/package/economy.py
index e90862e7041..de1285f8335 100644
--- a/openbb_platform/openbb/package/economy.py
+++ b/openbb_platform/openbb/package/economy.py
@@ -277,7 +277,7 @@ class ROUTER_economy(Container):
Example
-------
>>> from openbb import obb
- >>> obb.economy.cpi(countries=['portugal', 'spain'], units="growth_same", frequency="monthly")
+ >>> obb.economy.cpi(countries=['portugal', 'spain'], units="growth_same", frequency="monthly", harmonized=False)
""" # noqa: E501
return self._run(
@@ -323,7 +323,7 @@ class ROUTER_economy(Container):
no default.
is_release : Optional[bool]
Is release? If True, other search filter variables are ignored. If no query text or release_id is supplied, this defaults to True. (provider: fred)
- release_id : Optional[Union[str, int]]
+ release_id : Optional[Union[int, str]]
A specific release ID to target. (provider: fred)
limit : Optional[int]
The number of data entries to return. (1-1000) (provider: fred)
@@ -354,7 +354,7 @@ class ROUTER_economy(Container):
FredSearch
----------
- release_id : Optional[Union[str, int]]
+ release_id : Optional[Union[int, str]]
The release ID for queries.
series_id : Optional[str]
The series ID for the item in the release.
@@ -521,7 +521,7 @@ class ROUTER_economy(Container):
Example
-------
>>> fr