summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorHenrique Joaquim <h.joaquim@campus.fct.unl.pt>2024-03-07 07:26:48 +0000
committerGitHub <noreply@github.com>2024-03-07 07:26:48 +0000
commit48c508b292a49cdfdf29cde547fe0cb4b32e0466 (patch)
treecc804b40debf7fe4b9357a8bb91c705c8674f283
parenta8122e9c037a1ded402adc2fb80cf160fa45688a (diff)
[Feature] Custom choices (#6169)
* custom choices * remove unnecessary choices from the extra info * static assets * simpler way to pop choices on extra * Getting json_schema_extra without changing the original dict --------- Co-authored-by: montezdesousa <79287829+montezdesousa@users.noreply.github.com>
-rw-r--r--openbb_platform/core/openbb_core/app/model/custom_parameter.py9
-rw-r--r--openbb_platform/core/openbb_core/app/static/package_builder.py48
-rw-r--r--openbb_platform/core/openbb_core/provider/standard_models/cpi.py5
-rw-r--r--openbb_platform/openbb/package/economy.py58
4 files changed, 102 insertions, 18 deletions
diff --git a/openbb_platform/core/openbb_core/app/model/custom_parameter.py b/openbb_platform/core/openbb_core/app/model/custom_parameter.py
index 0df8298e59e..06b947e0c55 100644
--- a/openbb_platform/core/openbb_core/app/model/custom_parameter.py
+++ b/openbb_platform/core/openbb_core/app/model/custom_parameter.py
@@ -2,6 +2,8 @@ import sys
from dataclasses import dataclass
from typing import Dict, Optional
+from typing_extensions import LiteralString
+
# `slots` is available on Python >= 3.10
if sys.version_info >= (3, 10):
slots_true = {"slots": True}
@@ -24,3 +26,10 @@ class OpenBBCustomParameter(BaseMetadata):
"""Custom parameter for OpenBB."""
description: Optional[str] = None
+
+
+@dataclass(frozen=True, **slots_true)
+class OpenBBCustomChoices(BaseMetadata):
+ """Custom choices for OpenBB."""
+
+ choices: Optional[LiteralString] = None
diff --git a/openbb_platform/core/openbb_core/app/static/package_builder.py b/openbb_platform/core/openbb_core/app/static/package_builder.py
index 924baa69347..a64c0fd1fd2 100644
--- a/openbb_platform/core/openbb_core/app/static/package_builder.py
+++ b/openbb_platform/core/openbb_core/app/static/package_builder.py
@@ -36,7 +36,10 @@ from starlette.routing import BaseRoute
from typing_extensions import Annotated, _AnnotatedAlias
from openbb_core.app.extension_loader import ExtensionLoader, OpenBBGroups
-from openbb_core.app.model.custom_parameter import OpenBBCustomParameter
+from openbb_core.app.model.custom_parameter import (
+ OpenBBCustomChoices,
+ OpenBBCustomParameter,
+)
from openbb_core.app.provider_interface import ProviderInterface
from openbb_core.app.router import CommandMap, RouterLoader
from openbb_core.app.static.utils.console import Console
@@ -330,9 +333,7 @@ class ImportDefinition:
hint_type_list = cls.get_path_hint_type_list(path=path)
code = "from openbb_core.app.static.container import Container"
code += "\nfrom openbb_core.app.model.obbject import OBBject"
- code += (
- "\nfrom openbb_core.app.model.custom_parameter import OpenBBCustomParameter"
- )
+ code += "\nfrom openbb_core.app.model.custom_parameter import OpenBBCustomParameter, OpenBBCustomChoices"
# These imports were not detected before build, so we add them manually and
# ruff --fix the resulting code to remove unused imports.
@@ -500,7 +501,10 @@ class MethodDefinition:
"""Get json schema extra."""
field_default = getattr(field, "default", None)
if field_default:
- return getattr(field_default, "json_schema_extra", {})
+ # Getting json_schema_extra without changing the original dict
+ json_schema_extra = getattr(field_default, "json_schema_extra", {}).copy()
+ json_schema_extra.pop("choices", None)
+ return json_schema_extra
return {}
@staticmethod
@@ -604,10 +608,10 @@ class MethodDefinition:
return MethodDefinition.reorder_params(params=formatted)
@staticmethod
- def add_field_descriptions(
+ def add_field_custom_annotations(
od: OrderedDict[str, Parameter], model_name: Optional[str] = None
):
- """Add the field description to the param signature."""
+ """Add the field custom description and choices to the param signature as annotations."""
if model_name:
available_fields: Dict[str, Field] = (
ProviderInterface().params[model_name]["standard"].__dataclass_fields__
@@ -617,16 +621,28 @@ class MethodDefinition:
if param not in available_fields:
continue
- field = available_fields[param]
+ field_default = available_fields[param].default
- new_value = value.replace(
- annotation=Annotated[
- value.annotation,
- OpenBBCustomParameter(
- description=getattr(field.default, "description", "")
- ),
- ],
+ choices = getattr(field_default, "json_schema_extra", {}).get(
+ "choices", []
)
+ description = getattr(field_default, "description", "")
+
+ if choices:
+ new_value = value.replace(
+ annotation=Annotated[
+ value.annotation,
+ OpenBBCustomParameter(description=description),
+ OpenBBCustomChoices(choices=choices),
+ ],
+ )
+ else:
+ new_value = value.replace(
+ annotation=Annotated[
+ value.annotation,
+ OpenBBCustomParameter(description=description),
+ ],
+ )
od[param] = new_value
@@ -667,7 +683,7 @@ class MethodDefinition:
model_name: Optional[str] = None,
) -> str:
"""Build the command method signature."""
- MethodDefinition.add_field_descriptions(
+ MethodDefinition.add_field_custom_annotations(
od=formatted_params, model_name=model_name
) # this modified `od` in place
func_params = MethodDefinition.build_func_params(formatted_params)
diff --git a/openbb_platform/core/openbb_core/provider/standard_models/cpi.py b/openbb_platform/core/openbb_core/provider/standard_models/cpi.py
index cb4d7afbdde..a1e37f81ba6 100644
--- a/openbb_platform/core/openbb_core/provider/standard_models/cpi.py
+++ b/openbb_platform/core/openbb_core/provider/standard_models/cpi.py
@@ -74,7 +74,10 @@ CPI_FREQUENCY = Literal["monthly", "quarter", "annual"]
class ConsumerPriceIndexQueryParams(QueryParams):
"""CPI Query."""
- country: str = Field(description=QUERY_DESCRIPTIONS.get("country"))
+ country: str = Field(
+ description=QUERY_DESCRIPTIONS.get("country"),
+ choices=CPI_COUNTRIES, # type: ignore
+ )
units: CPI_UNITS = Field(
default="growth_same",
description=QUERY_DESCRIPTIONS.get("units", "")
diff --git a/openbb_platform/openbb/package/economy.py b/openbb_platform/openbb/package/economy.py
index 9f4556f72fb..4b0149ad315 100644
--- a/openbb_platform/openbb/package/economy.py
+++ b/openbb_platform/openbb/package/economy.py
@@ -3,7 +3,10 @@
import datetime
from typing import List, Literal, Optional, Union
-from openbb_core.app.model.custom_parameter import OpenBBCustomParameter
+from openbb_core.app.model.custom_parameter import (
+ OpenBBCustomChoices,
+ OpenBBCustomParameter,
+)
from openbb_core.app.model.obbject import OBBject
from openbb_core.app.static.container import Container
from openbb_core.app.static.utils.decorators import exception_handler, validate
@@ -240,6 +243,59 @@ class ROUTER_economy(Container):
OpenBBCustomParameter(
description="The country to get data. Multiple items allowed for provider(s): fred."
),
+ OpenBBCustomChoices(
+ choices=[
+ "australia",
+ "austria",
+ "belgium",
+ "brazil",
+ "bulgaria",
+ "canada",
+ "chile",
+ "china",
+ "croatia",
+ "cyprus",
+ "czech_republic",
+ "denmark",
+ "estonia",
+ "euro_area",
+ "finland",
+ "france",
+ "germany",
+ "greece",
+ "hungary",
+ "iceland",
+ "india",
+ "indonesia",
+ "ireland",
+ "israel",
+ "italy",
+ "japan",
+ "korea",
+ "latvia",
+ "lithuania",
+ "luxembourg",
+ "malta",
+ "mexico",
+ "netherlands",
+ "new_zealand",
+ "norway",
+ "poland",
+ "portugal",
+ "romania",
+ "russian_federation",
+ "slovak_republic",
+ "slovakia",
+ "slovenia",
+ "south_africa",
+ "spain",
+ "sweden",
+ "switzerland",
+ "turkey",
+ "united_kingdom",
+ "united_states",
+ ]
+ ),
],
units: Annotated[
Literal["growth_previous", "growth_same", "index_2015"],