diff options
author | montezdesousa <79287829+montezdesousa@users.noreply.github.com> | 2024-02-08 16:47:55 +0000 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-02-08 16:47:55 +0000 |
commit | 6088c98ceb290ed8c9f843455050a2bee5560331 (patch) | |
tree | 59dc51faae3a25788aa339f4285a12c6bbe91568 | |
parent | 3ca266d9b4491f1d86392b39beb8020ee0442acc (diff) |
[Feature] - Standardise multiple symbols input (#6056)
* feat: remove list of str
* fix: update cpi
* package builder + cpi
* build
* fix: update unittests
* fix: check_item
* fix: unittests
26 files changed, 180 insertions, 249 deletions
diff --git a/openbb_platform/core/openbb_core/app/static/package_builder.py b/openbb_platform/core/openbb_core/app/static/package_builder.py index 8b6f14d3fdf..966726d6a8f 100644 --- a/openbb_platform/core/openbb_core/app/static/package_builder.py +++ b/openbb_platform/core/openbb_core/app/static/package_builder.py @@ -520,7 +520,6 @@ class MethodDefinition: # Be careful, if the type is not coercible by pydantic to the original type, you # will need to add some conversion code in the input filter. TYPE_EXPANSION = { - "symbol": List[str], "data": DataProcessingSupportedTypes, "start_date": str, "end_date": str, @@ -734,9 +733,6 @@ class MethodDefinition: value = {k: k for k in fields} code += f" {name}={{\n" for k, v in value.items(): - if k == "symbol": - code += f' "{k}": ",".join(symbol) if isinstance(symbol, list) else symbol, \n' - continue code += f' "{k}": {v},\n' code += " },\n" else: diff --git a/openbb_platform/core/openbb_core/provider/standard_models/bond_prices.py b/openbb_platform/core/openbb_core/provider/standard_models/bond_prices.py index dcb3c0d8248..7af10935681 100644 --- a/openbb_platform/core/openbb_core/provider/standard_models/bond_prices.py +++ b/openbb_platform/core/openbb_core/provider/standard_models/bond_prices.py @@ -16,7 +16,7 @@ class BondPricesQueryParams(QueryParams): country: Optional[str] = Field( default=None, - description="Country of the bond issuer. Matches partial name.", + description="The country to get data. Matches partial name.", ) issuer_name: Optional[str] = Field( default=None, diff --git a/openbb_platform/core/openbb_core/provider/standard_models/bond_reference.py b/openbb_platform/core/openbb_core/provider/standard_models/bond_reference.py index 37c0ee4a622..3724a2b7882 100644 --- a/openbb_platform/core/openbb_core/provider/standard_models/bond_reference.py +++ b/openbb_platform/core/openbb_core/provider/standard_models/bond_reference.py @@ -16,7 +16,7 @@ class BondReferenceQueryParams(QueryParams): country: Optional[str] = Field( default=None, - description="Country of the bond issuer. Matches partial name.", + description="The country to get data. Matches partial name.", ) issuer_name: Optional[str] = Field( default=None, diff --git a/openbb_platform/core/openbb_core/provider/standard_models/bond_trades.py b/openbb_platform/core/openbb_core/provider/standard_models/bond_trades.py index 29ac4487871..41ccd455211 100644 --- a/openbb_platform/core/openbb_core/provider/standard_models/bond_trades.py +++ b/openbb_platform/core/openbb_core/provider/standard_models/bond_trades.py @@ -21,7 +21,7 @@ class BondTradesQueryParams(QueryParams): country: Optional[str] = Field( default=None, - description="Country of the bond issuer. Matches partial name.", + description="The country to get data. Matches partial name.", ) isin: Optional[str] = Field( default=None, diff --git a/openbb_platform/core/openbb_core/provider/standard_models/cpi.py b/openbb_platform/core/openbb_core/provider/standard_models/cpi.py index fa4115d75aa..24a94721864 100644 --- a/openbb_platform/core/openbb_core/provider/standard_models/cpi.py +++ b/openbb_platform/core/openbb_core/provider/standard_models/cpi.py @@ -1,7 +1,7 @@ """CPI Standard Model.""" from datetime import date as dateType -from typing import List, Literal, Optional +from typing import Literal, Optional from dateutil import parser from pydantic import Field, field_validator @@ -12,8 +12,9 @@ from openbb_core.provider.utils.descriptions import ( DATA_DESCRIPTIONS, QUERY_DESCRIPTIONS, ) +from openbb_core.provider.utils.helpers import check_item -CPI_COUNTRIES = Literal[ +CPI_COUNTRIES = [ "australia", "austria", "belgium", @@ -73,9 +74,7 @@ CPI_FREQUENCY = Literal["monthly", "quarter", "annual"] class ConsumerPriceIndexQueryParams(QueryParams): """CPI Query.""" - countries: List[CPI_COUNTRIES] = Field( - description=QUERY_DESCRIPTIONS.get("countries") - ) + country: str = Field(description=QUERY_DESCRIPTIONS.get("country")) units: CPI_UNITS = Field( default="growth_same", description=QUERY_DESCRIPTIONS.get("units", "") @@ -103,6 +102,16 @@ class ConsumerPriceIndexQueryParams(QueryParams): default=None, description=QUERY_DESCRIPTIONS.get("end_date") ) + @field_validator("country", mode="before", check_fields=False) + def validate_country(cls, c: str): # pylint: disable=E0213 + """Validate country.""" + result = [] + values = c.split(",") + for v in values: + check_item(v.lower(), CPI_COUNTRIES) + result.append(v.lower()) + return ",".join(result) + class ConsumerPriceIndexData(Data): """CPI data.""" diff --git a/openbb_platform/core/openbb_core/provider/utils/descriptions.py b/openbb_platform/core/openbb_core/provider/utils/descriptions.py index 2ad3aeca59e..70758ef111f 100644 --- a/openbb_platform/core/openbb_core/provider/utils/descriptions.py +++ b/openbb_platform/core/openbb_core/provider/utils/descriptions.py @@ -8,6 +8,7 @@ QUERY_DESCRIPTIONS = { "period": "Time period of the data to return.", "date": "A specific date to get data for.", "limit": "The number of data entries to return.", + "country": "The country to get data.", "countries": "The country or countries to get data.", "units": "The unit of measurement for the data.", "frequency": "The frequency of the data.", diff --git a/openbb_platform/core/openbb_core/provider/utils/helpers.py b/openbb_platform/core/openbb_core/provider/utils/helpers.py index e1977ea8ac7..99064cd8396 100644 --- a/openbb_platform/core/openbb_core/provider/utils/helpers.py +++ b/openbb_platform/core/openbb_core/provider/utils/helpers.py @@ -3,6 +3,7 @@ import asyncio import re from datetime import datetime +from difflib import SequenceMatcher from functools import partial from inspect import iscoroutinefunction from typing import Awaitable, Callable, List, Literal, Optional, TypeVar, Union, cast @@ -22,6 +23,33 @@ T = TypeVar("T") P = ParamSpec("P") +def check_item(item: str, allowed: List[str], threshold: float = 0.75) -> None: + """Check if an item is in a list of allowed items and raise an error if not. + + Parameters + ---------- + item : str + The item to check. + allowed : List[str] + The list of allowed items. + threshold : float, optional + The similarity threshold for the error message, by default 0.75 + + Raises + ------ + ValueError + If the item is not in the allowed list. + """ + if item not in allowed: + similarities = map( + lambda c: (c, SequenceMatcher(None, item, c).ratio()), allowed + ) + similar, score = max(similarities, key=lambda x: x[1]) + if score > threshold: + raise ValueError(f"'{item}' is not available. Did you mean '{similar}'?") + raise ValueError(f"'{item}' is not available.") + + def get_querystring(items: dict, exclude: List[str]) -> str: """Turn a dictionary into a querystring, excluding the keys in the exclude list. diff --git a/openbb_platform/extensions/economy/integration/test_economy_api.py b/openbb_platform/extensions/economy/integration/test_economy_api.py index 0ba2eb5ca7b..dcdcb21ef1b 100644 --- a/openbb_platform/extensions/economy/integration/test_economy_api.py +++ b/openbb_platform/extensions/economy/integration/test_economy_api.py @@ -63,7 +63,7 @@ def test_economy_calendar(params, headers): [ ( { - "countries": "spain", + "country": "spain", "units": "growth_same", "frequency": "monthly", "harmonized": True, @@ -74,7 +74,7 @@ def test_economy_calendar(params, headers): ), ( { - "countries": ["portugal", "spain"], + "country": "portugal,spain", "units": "growth_same", "frequency": "monthly", "harmonized": True, diff --git a/openbb_platform/extensions/economy/integration/test_economy_python.py b/openbb_platform/extensions/economy/integration/test_economy_python.py index d7e5dbf64fd..95c73732055 100644 --- a/openbb_platform/extensions/economy/integration/test_economy_python.py +++ b/openbb_platform/extensions/economy/integration/test_economy_python.py @@ -57,7 +57,7 @@ def test_economy_calendar(params, obb): [ ( { - "countries": ["portugal", "spain"], + "country": "portugal,spain", "units": "growth_same", "frequency": "monthly", "harmonized": True, diff --git a/openbb_platform/openbb/package/crypto_price.py b/openbb_platform/openbb/package/crypto_price.py index 549d3b2f60c..28673b562c5 100644 --- a/openbb_platform/openbb/package/crypto_price.py +++ b/openbb_platform/openbb/package/crypto_price.py @@ -1,7 +1,7 @@ ### THIS FILE IS AUTO-GENERATED. DO NOT EDIT. ### import datetime -from typing import List, Literal, Optional, Union +from typing import Literal, Optional, Union from openbb_core.app.model.custom_parameter import OpenBBCustomParameter from openbb_core.app.model.obbject import OBBject @@ -23,7 +23,7 @@ class ROUTER_crypto_price(Container): def historical( self, symbol: Annotated[ - Union[str, List[str]], + str, OpenBBCustomParameter( description="Symbol to get data for. Can use CURR1-CURR2 or CURR1CURR2 format." ), @@ -134,7 +134,7 @@ class ROUTER_crypto_price(Container): "provider": provider, }, standard_params={ - "symbol": ",".join(symbol) if isinstance(symbol, list) else symbol, + "symbol": symbol, "start_date": start_date, "end_date": end_date, }, diff --git a/openbb_platform/openbb/package/currency_price.py b/openbb_platform/openbb/package/currency_price.py index c3a0562390a..1e757613380 100644 --- a/openbb_platform/openbb/package/currency_price.py +++ b/openbb_platform/openbb/package/currency_price.py @@ -1,7 +1,7 @@ ### THIS FILE IS AUTO-GENERATED. DO NOT EDIT. ### import datetime -from typing import List, Literal, Optional, Union +from typing import Literal, Optional, Union from openbb_core.app.model.custom_parameter import OpenBBCustomParameter from openbb_core.app.model.obbject import OBBject @@ -23,7 +23,7 @@ class ROUTER_currency_price(Container): def historical( self, symbol: Annotated[ - Union[str, List[str]], + str, OpenBBCustomParameter( description="Symbol to get data for. Can use CURR1-CURR2 or CURR1CURR2 format." ), @@ -136,7 +136,7 @@ class ROUTER_currency_price(Container): "provider": provider, }, standard_params={ - "symbol": ",".join(symbol) if isinstance(symbol, list) else symbol, + "symbol": symbol, "start_date": start_date, "end_date": end_date, }, diff --git a/openbb_platform/openbb/package/derivatives_options.py b/openbb_platform/openbb/package/derivatives_options.py index 43077b374bb..84e5eb1d26a 100644 --- a/openbb_platform/openbb/package/derivatives_options.py +++ b/openbb_platform/openbb/package/derivatives_options.py @@ -1,6 +1,6 @@ ### THIS FILE IS AUTO-GENERATED. DO NOT EDIT. ### -from typing import List, Literal, Optional, Union +from typing import Literal, Optional from openbb_core.app.model.custom_parameter import OpenBBCustomParameter from openbb_core.app.model.obbject import OBBject @@ -23,8 +23,7 @@ class ROUTER_derivatives_options(Container): def chains( self, symbol: Annotated[ - Union[str, List[str]], - OpenBBCustomParameter(description="Symbol to get data for."), + str, OpenBBCustomParameter(description="Symbol to get data for.") ], provider: Optional[Literal["intrinio"]] = None, **kwargs @@ -160,7 +159,7 @@ class ROUTER_derivatives_options(Container): "provider": provider, }, standard_params={ - "symbol": ",".join(symbol) if isinstance(symbol, list) else symbol, + "symbol": symbol, }, extra_params=kwargs, ) @@ -170,7 +169,7 @@ class ROUTER_derivatives_options(Container): def unusual( self, symbol: Annotated[ - Union[str, None, List[str]], + Optional[str], OpenBBCustomParameter( description="Symbol to get data for. (the underlying symbol)" ), @@ -243,7 +242,7 @@ class ROUTER_derivatives_options(Container): "provider": provider, }, standard_params={ - "symbol": ",".join(symbol) if isinstance(symbol, list) else symbol, + "symbol": symbol, }, extra_params=kwargs, ) diff --git a/openbb_platform/openbb/package/economy.py b/openbb_platform/openbb/package/economy.py index 95cb79a8422..df8b70aae7a 100644 --- a/openbb_platform/openbb/package/economy.py +++ b/openbb_platform/openbb/package/economy.py @@ -1,7 +1,7 @@ ### THIS FILE IS AUTO-GENERATED. DO NOT EDIT. ### import datetime -from typing import List, Literal, Optional, Union +from typing import Literal, Optional, Union from openbb_core.app.model.custom_parameter import OpenBBCustomParameter from openbb_core.app.model.obbject import OBBject @@ -221,61 +221,8 @@ class ROUTER_economy(Container): @validate def cpi( self, - countries: Annotated[ - List[ - Literal[ - "australia", - "austria", - "belgium", - "brazil", - "bulgaria", - "canada", - "chile", - "china", - "croatia", - "cyprus", - "czech_republic", - "denmark", - "estonia", - "euro_area", - "finland", - "france", - "germany", - "greece", - "hungary", - "iceland", - "india", - "indonesia", - "ireland", - "israel", - "italy", - "japan", - "korea", - "latvia", - "lithuania", - "luxembourg", - "malta", - "mexico", - "netherlands", - "new_zealand", - "norway", - "poland", - "portugal", - "romania", - "russian_federation", - "slovak_republic", - "slovakia", - "slovenia", - "south_africa", - "spain", - "sweden", - "switzerland", - "turkey", - "united_kingdom", - "united_states", - ] - ], - OpenBBCustomParameter(description="The country or countries to get data."), + country: Annotated[ + str, OpenBBCustomParameter(description="The country to get data.") ], units: Annotated[ Literal["growth_previous", "growth_same", "index_2015"], @@ -314,8 +261,8 @@ class ROUTER_economy(Container): Parameters ---------- - countries : List[Literal['australia', 'austria', 'belgium', 'brazil', 'bulgar... - The country or countries to get data. + country : str + The country to get data. units : Literal['growth_previous', 'growth_same', 'index_2015'] The unit of measurement for the data. Options: @@ -360,7 +307,7 @@ class ROUTER_economy(Container): Example ------- >>> from openbb import obb - >>> obb.economy.cpi(countries=['portugal', 'spain'], units="growth_same", frequency="monthly", harmonized=False) + >>> obb.economy.cpi(country="portugal", units="growth_same", frequency="monthly", harmonized=False) """ # noqa: E501 return self._run( @@ -370,7 +317,7 @@ class ROUTER_economy(Container): "provider": provider, }, standard_params={ - "countries": countries, + "country": country, "units": units, "frequency": frequency, "harmonized": harmonized, @@ -497,8 +444,7 @@ class ROUTER_economy(Container): def fred_series( self, symbol: Annotated[ - Union[str, List[str]], - OpenBBCustomParameter(description="Symbol to get data for."), + str, OpenBBCustomParameter(description="Symbol to get data for.") ], start_date: Annotated[ Union[datetime.date, None, str], @@ -614,7 +560,7 @@ class ROUTER_economy(Container): "provider": provider, }, standard_params={ - "symbol": ",".join(symbol) if isinstance(symbol, list) else symbol, + "symbol": symbol, "start_date": start_date, "end_date": end_date, "limit": limit, diff --git a/openbb_platform/openbb/package/equity.py b/openbb_platform/openbb/package/equity.py index 4b9d5c6e744..185c13df4fd 100644 --- a/openbb_platform/openbb/package/equity.py +++ b/openbb_platform/openbb/package/equity.py @@ -1,6 +1,6 @@ ### THIS FILE IS AUTO-GENERATED. DO NOT EDIT. ### -from typing import List, Literal, Optional, Union +from typing import Literal, Optional from openbb_core.app.model.custom_parameter import OpenBBCustomParameter from openbb_core.app.model.obbject import OBBject @@ -221,8 +221,7 @@ class ROUTER_equity(Container): def profile( self, symbol: Annotated[ - Union[str, List[str]], - OpenBBCustomParameter(description="Symbol to get data for."), + str, OpenBBCustomParameter(description="Symbol to get data for.") ], provider: Optional[Literal["fmp", "intrinio", "yfinance"]] = None, **kwargs @@ -389,7 +388,7 @@ class ROUTER_equity(Container): "provider": provider, }, standard_params={ - "symbol": ",".join(symbol) if isinstance(symbol, list) else symbol, + "symbol": symbol, }, extra_params=kwargs, ) diff --git a/openbb_platform/openbb/package/equity_calendar.py b/openbb_platform/openbb/package/equity_calendar.py index 1985dbcd6d4..6d8c5cbd085 100644 --- a/openbb_platform/openbb/package/equity_calendar.py +++ b/openbb_platform/openbb/package/equity_calendar.py @@ -1,7 +1,7 @@ ### THIS FILE IS AUTO-GENERATED. DO NOT EDIT. ### import datetime -from typing import List, Literal, Optional, Union +from typing import Literal, Optional, Union from openbb_core.app.model.custom_parameter import OpenBBCustomParameter from openbb_core.app.model.obbject import OBBject @@ -202,8 +202,7 @@ class ROUTER_equity_calendar(Container): def ipo( self, |