diff options
author | Henrique Joaquim <h.joaquim@campus.fct.unl.pt> | 2024-03-07 07:26:48 +0000 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-03-07 07:26:48 +0000 |
commit | 48c508b292a49cdfdf29cde547fe0cb4b32e0466 (patch) | |
tree | cc804b40debf7fe4b9357a8bb91c705c8674f283 | |
parent | a8122e9c037a1ded402adc2fb80cf160fa45688a (diff) |
[Feature] Custom choices (#6169)
* custom choices
* remove unnecessary choices from the extra info
* static assets
* simpler way to pop choices on extra
* Getting json_schema_extra without changing the original dict
---------
Co-authored-by: montezdesousa <79287829+montezdesousa@users.noreply.github.com>
4 files changed, 102 insertions, 18 deletions
diff --git a/openbb_platform/core/openbb_core/app/model/custom_parameter.py b/openbb_platform/core/openbb_core/app/model/custom_parameter.py index 0df8298e59e..06b947e0c55 100644 --- a/openbb_platform/core/openbb_core/app/model/custom_parameter.py +++ b/openbb_platform/core/openbb_core/app/model/custom_parameter.py @@ -2,6 +2,8 @@ import sys from dataclasses import dataclass from typing import Dict, Optional +from typing_extensions import LiteralString + # `slots` is available on Python >= 3.10 if sys.version_info >= (3, 10): slots_true = {"slots": True} @@ -24,3 +26,10 @@ class OpenBBCustomParameter(BaseMetadata): """Custom parameter for OpenBB.""" description: Optional[str] = None + + +@dataclass(frozen=True, **slots_true) +class OpenBBCustomChoices(BaseMetadata): + """Custom choices for OpenBB.""" + + choices: Optional[LiteralString] = None diff --git a/openbb_platform/core/openbb_core/app/static/package_builder.py b/openbb_platform/core/openbb_core/app/static/package_builder.py index 924baa69347..a64c0fd1fd2 100644 --- a/openbb_platform/core/openbb_core/app/static/package_builder.py +++ b/openbb_platform/core/openbb_core/app/static/package_builder.py @@ -36,7 +36,10 @@ from starlette.routing import BaseRoute from typing_extensions import Annotated, _AnnotatedAlias from openbb_core.app.extension_loader import ExtensionLoader, OpenBBGroups -from openbb_core.app.model.custom_parameter import OpenBBCustomParameter +from openbb_core.app.model.custom_parameter import ( + OpenBBCustomChoices, + OpenBBCustomParameter, +) from openbb_core.app.provider_interface import ProviderInterface from openbb_core.app.router import CommandMap, RouterLoader from openbb_core.app.static.utils.console import Console @@ -330,9 +333,7 @@ class ImportDefinition: hint_type_list = cls.get_path_hint_type_list(path=path) code = "from openbb_core.app.static.container import Container" code += "\nfrom openbb_core.app.model.obbject import OBBject" - code += ( - "\nfrom openbb_core.app.model.custom_parameter import OpenBBCustomParameter" - ) + code += "\nfrom openbb_core.app.model.custom_parameter import OpenBBCustomParameter, OpenBBCustomChoices" # These imports were not detected before build, so we add them manually and # ruff --fix the resulting code to remove unused imports. @@ -500,7 +501,10 @@ class MethodDefinition: """Get json schema extra.""" field_default = getattr(field, "default", None) if field_default: - return getattr(field_default, "json_schema_extra", {}) + # Getting json_schema_extra without changing the original dict + json_schema_extra = getattr(field_default, "json_schema_extra", {}).copy() + json_schema_extra.pop("choices", None) + return json_schema_extra return {} @staticmethod @@ -604,10 +608,10 @@ class MethodDefinition: return MethodDefinition.reorder_params(params=formatted) @staticmethod - def add_field_descriptions( + def add_field_custom_annotations( od: OrderedDict[str, Parameter], model_name: Optional[str] = None ): - """Add the field description to the param signature.""" + """Add the field custom description and choices to the param signature as annotations.""" if model_name: available_fields: Dict[str, Field] = ( ProviderInterface().params[model_name]["standard"].__dataclass_fields__ @@ -617,16 +621,28 @@ class MethodDefinition: if param not in available_fields: continue - field = available_fields[param] + field_default = available_fields[param].default - new_value = value.replace( - annotation=Annotated[ - value.annotation, - OpenBBCustomParameter( - description=getattr(field.default, "description", "") - ), - ], + choices = getattr(field_default, "json_schema_extra", {}).get( + "choices", [] ) + description = getattr(field_default, "description", "") + + if choices: + new_value = value.replace( + annotation=Annotated[ + value.annotation, + OpenBBCustomParameter(description=description), + OpenBBCustomChoices(choices=choices), + ], + ) + else: + new_value = value.replace( + annotation=Annotated[ + value.annotation, + OpenBBCustomParameter(description=description), + ], + ) od[param] = new_value @@ -667,7 +683,7 @@ class MethodDefinition: model_name: Optional[str] = None, ) -> str: """Build the command method signature.""" - MethodDefinition.add_field_descriptions( + MethodDefinition.add_field_custom_annotations( od=formatted_params, model_name=model_name ) # this modified `od` in place func_params = MethodDefinition.build_func_params(formatted_params) diff --git a/openbb_platform/core/openbb_core/provider/standard_models/cpi.py b/openbb_platform/core/openbb_core/provider/standard_models/cpi.py index cb4d7afbdde..a1e37f81ba6 100644 --- a/openbb_platform/core/openbb_core/provider/standard_models/cpi.py +++ b/openbb_platform/core/openbb_core/provider/standard_models/cpi.py @@ -74,7 +74,10 @@ CPI_FREQUENCY = Literal["monthly", "quarter", "annual"] class ConsumerPriceIndexQueryParams(QueryParams): """CPI Query.""" - country: str = Field(description=QUERY_DESCRIPTIONS.get("country")) + country: str = Field( + description=QUERY_DESCRIPTIONS.get("country"), + choices=CPI_COUNTRIES, # type: ignore + ) units: CPI_UNITS = Field( default="growth_same", description=QUERY_DESCRIPTIONS.get("units", "") diff --git a/openbb_platform/openbb/package/economy.py b/openbb_platform/openbb/package/economy.py index 9f4556f72fb..4b0149ad315 100644 --- a/openbb_platform/openbb/package/economy.py +++ b/openbb_platform/openbb/package/economy.py @@ -3,7 +3,10 @@ import datetime from typing import List, Literal, Optional, Union -from openbb_core.app.model.custom_parameter import OpenBBCustomParameter +from openbb_core.app.model.custom_parameter import ( + OpenBBCustomChoices, + OpenBBCustomParameter, +) from openbb_core.app.model.obbject import OBBject from openbb_core.app.static.container import Container from openbb_core.app.static.utils.decorators import exception_handler, validate @@ -240,6 +243,59 @@ class ROUTER_economy(Container): OpenBBCustomParameter( description="The country to get data. Multiple items allowed for provider(s): fred." ), + OpenBBCustomChoices( + choices=[ + "australia", + "austria", + "belgium", + "brazil", + "bulgaria", + "canada", + "chile", + "china", + "croatia", + "cyprus", + "czech_republic", + "denmark", + "estonia", + "euro_area", + "finland", + "france", + "germany", + "greece", + "hungary", + "iceland", + "india", + "indonesia", + "ireland", + "israel", + "italy", + "japan", + "korea", + "latvia", + "lithuania", + "luxembourg", + "malta", + "mexico", + "netherlands", + "new_zealand", + "norway", + "poland", + "portugal", + "romania", + "russian_federation", + "slovak_republic", + "slovakia", + "slovenia", + "south_africa", + "spain", + "sweden", + "switzerland", + "turkey", + "united_kingdom", + "united_states", + ] + ), ], units: Annotated[ Literal["growth_previous", "growth_same", "index_2015"], |