diff options
author | James Maslek <jmaslek11@gmail.com> | 2024-03-12 17:02:24 -0400 |
---|---|---|
committer | James Maslek <jmaslek11@gmail.com> | 2024-03-12 17:02:24 -0400 |
commit | 2319164af3d5c5ead27867492253bd1145df7cb5 (patch) | |
tree | 79089fd3075fc1634584ab66614fa75941b81a7a | |
parent | 895b20a19582826785a7c10f0d0c1f0ebf56355e (diff) |
Correctly handle country "choices" and let oecd do multiple countries
3 files changed, 135 insertions, 42 deletions
diff --git a/openbb_platform/core/openbb_core/provider/standard_models/consumer_price_index.py b/openbb_platform/core/openbb_core/provider/standard_models/consumer_price_index.py index 539ad2b4f07..0fa3675d700 100644 --- a/openbb_platform/core/openbb_core/provider/standard_models/consumer_price_index.py +++ b/openbb_platform/core/openbb_core/provider/standard_models/consumer_price_index.py @@ -14,56 +14,48 @@ from openbb_core.provider.utils.descriptions import ( ) from openbb_core.provider.utils.helpers import check_item -CPI_COUNTRIES = [ - "australia", - "austria", - "belgium", - "brazil", - "bulgaria", - "canada", +CPI_STANDARD_COUNTRIES = [ + "israel", + "portugal", "chile", - "china", - "croatia", - "cyprus", + "finland", + "japan", "czech_republic", "denmark", - "estonia", - "euro_area", - "finland", - "france", - "germany", - "greece", - "hungary", - "iceland", - "india", + "poland", "indonesia", - "ireland", - "israel", "italy", - "japan", + "spain", "korea", + "iceland", + "slovak_republic", "latvia", + "turkey", + "hungary", + "united_kingdom", + "india", + "norway", + "australia", + "estonia", + "netherlands", + "germany", + "greece", + "china", "lithuania", + "united_states", "luxembourg", - "malta", + "france", + "sweden", + "switzerland", + "slovenia", "mexico", - "netherlands", "new_zealand", - "norway", - "poland", - "portugal", - "romania", - "russian_federation", - "slovak_republic", - "slovakia", - "slovenia", + "canada", + "austria", + "belgium", + "ireland", + "brazil", "south_africa", - "spain", - "sweden", - "switzerland", - "turkey", - "united_kingdom", - "united_states", ] @@ -76,7 +68,7 @@ class ConsumerPriceIndexQueryParams(QueryParams): country: str = Field( description=QUERY_DESCRIPTIONS.get("country"), default="united_states", - choices=CPI_COUNTRIES, # type: ignore + choices=CPI_STANDARD_COUNTRIES, # type: ignore ) units: Literal["index", "yoy", "mom"] = Field( description="Units to get CPI for. Either index, month over month or year over year. Defaults to year over year.", @@ -104,7 +96,7 @@ class ConsumerPriceIndexQueryParams(QueryParams): result = [] values = c.replace(" ", "_").split(",") for v in values: - check_item(v.lower(), CPI_COUNTRIES) + check_item(v.lower(), CPI_STANDARD_COUNTRIES) result.append(v.lower()) return ",".join(result) diff --git a/openbb_platform/providers/fred/openbb_fred/models/consumer_price_index.py b/openbb_platform/providers/fred/openbb_fred/models/consumer_price_index.py index cdb95018cce..64712ac15ac 100644 --- a/openbb_platform/providers/fred/openbb_fred/models/consumer_price_index.py +++ b/openbb_platform/providers/fred/openbb_fred/models/consumer_price_index.py @@ -7,15 +7,87 @@ from openbb_core.provider.standard_models.consumer_price_index import ( ConsumerPriceIndexData, ConsumerPriceIndexQueryParams, ) +from openbb_core.provider.utils.descriptions import ( + QUERY_DESCRIPTIONS, +) +from openbb_core.provider.utils.helpers import check_item from openbb_fred.utils.fred_base import Fred from openbb_fred.utils.fred_helpers import all_cpi_options +from pydantic import Field, field_validator + +CPI_COUNTRIES = [ + "australia", + "austria", + "belgium", + "brazil", + "bulgaria", + "canada", + "chile", + "china", + "croatia", + "cyprus", + "czech_republic", + "denmark", + "estonia", + "euro_area", + "finland", + "france", + "germany", + "greece", + "hungary", + "iceland", + "india", + "indonesia", + "ireland", + "israel", + "italy", + "japan", + "korea", + "latvia", + "lithuania", + "luxembourg", + "malta", + "mexico", + "netherlands", + "new_zealand", + "norway", + "poland", + "portugal", + "romania", + "russian_federation", + "slovak_republic", + "slovakia", + "slovenia", + "south_africa", + "spain", + "sweden", + "switzerland", + "turkey", + "united_kingdom", + "united_states", +] class FREDConsumerPriceIndexQueryParams(ConsumerPriceIndexQueryParams): """FRED Consumer Price Index Query.""" + country: str = Field( + description=QUERY_DESCRIPTIONS.get("country"), + default="united_states", + choices=CPI_COUNTRIES, # type: ignore + ) __json_schema_extra__ = {"country": ["multiple_items_allowed"]} + @field_validator("country", mode="before", check_fields=False) + def validate_country(cls, c: str): # pylint: disable=E0213 + """Validate country.""" + result = [] + values = c.replace(" ", "_").split(",") + for v in values: + check_item(v.lower(), CPI_COUNTRIES) + result.append(v.lower()) + return ",".join(result) + class FREDConsumerPriceIndexData(ConsumerPriceIndexData): """FRED Consumer Price Index Data.""" diff --git a/openbb_platform/providers/oecd/openbb_oecd/models/consumer_price_index.py b/openbb_platform/providers/oecd/openbb_oecd/models/consumer_price_index.py index 185c270e450..408c065ae89 100644 --- a/openbb_platform/providers/oecd/openbb_oecd/models/consumer_price_index.py +++ b/openbb_platform/providers/oecd/openbb_oecd/models/consumer_price_index.py @@ -9,6 +9,7 @@ from openbb_core.provider.standard_models.consumer_price_index import ( ConsumerPriceIndexData, ConsumerPriceIndexQueryParams, ) +from openbb_core.provider.utils.helpers import check_item from openbb_oecd.utils import helpers from openbb_oecd.utils.constants import ( CODE_TO_COUNTRY_CPI, @@ -82,6 +83,18 @@ class OECDCPIQueryParams(ConsumerPriceIndexQueryParams): default="total", ) + @field_validator("country", mode="before", check_fields=False) + def validate_country(cls, c: str): # pylint: disable=E0213 + """Validate country.""" + result = [] + values = c.replace(" ", "_").split(",") + for v in values: + check_item(v.lower(), CountriesList) + result.append(v.lower()) + return ",".join(result) + + __json_schema_extra__ = {"country": ["multiple_items_allowed"]} + class OECDCPIData(ConsumerPriceIndexData): """OECD CPI Data.""" @@ -157,7 +170,14 @@ class OECDCPIFetcher(Fetcher[OECDCPIQueryParams, List[OECDCPIData]]): "" if query.expenditure == "all" else expenditure_dict[query.expenditure] ) seasonal_adjustment = "Y" if query.seasonal_adjustment else "N" - country = "" if query.country == "all" else COUNTRY_TO_CODE_CPI[query.country] + + def country_string(input_str: str): + if input_str == "all": + return "" + countries = input_str.split(",") + return "+".join([COUNTRY_TO_CODE_CPI[country] for country in countries]) + + country = country_string(query.country) # For caching, include this in the key query_dict = { k: v @@ -176,7 +196,16 @@ class OECDCPIFetcher(Fetcher[OECDCPIQueryParams, List[OECDCPIData]]): f"METHODOLOGY=='{methodology}' & UNIT_MEASURE=='{units}' & FREQ=='{frequency}' & " f"ADJUSTMENT=='{seasonal_adjustment}' " ) - url_query = url_query + f" & REF_AREA=='{country}'" if country else url_query + + if country != "all": + if "+" in country: + countries = country.split("+") + country_conditions = " or ".join( + [f"REF_AREA=='{c}'" for c in countries] + ) + url_query += f" & ({country_conditions})" + else: + url_query = url_query + f" & REF_AREA=='{country}'" url_query = ( url_query + f" & EXPENDITURE=='{expenditure}'" if query.expenditure != "all" |