diff options
author | James Maslek <jmaslek11@gmail.com> | 2024-03-12 16:06:05 -0400 |
---|---|---|
committer | James Maslek <jmaslek11@gmail.com> | 2024-03-12 16:06:05 -0400 |
commit | c1d518ab2daf9381a48213885ce0cd08e6a61986 (patch) | |
tree | acdc63b5b5b7492c4e5cd21dddd4358e6b43b3ef | |
parent | 58d48dbe4e2c06fb8cc0f75f1adda4169f6f9b23 (diff) |
Clean up the CPI oecd model with the choices and allow expensitures to be "all"
-rw-r--r-- | openbb_platform/providers/oecd/openbb_oecd/models/consumer_price_index.py | 45 |
1 files changed, 30 insertions, 15 deletions
diff --git a/openbb_platform/providers/oecd/openbb_oecd/models/consumer_price_index.py b/openbb_platform/providers/oecd/openbb_oecd/models/consumer_price_index.py index dfb1a1810e3..b039d4d2acc 100644 --- a/openbb_platform/providers/oecd/openbb_oecd/models/consumer_price_index.py +++ b/openbb_platform/providers/oecd/openbb_oecd/models/consumer_price_index.py @@ -17,9 +17,9 @@ from openbb_oecd.utils.constants import ( from pydantic import Field, field_validator countries = tuple(CODE_TO_COUNTRY_CPI.values()) + ("all",) -CountriesLiteral = Literal[countries] # type: ignore +CountriesList = list(countries) # type: ignore -expendature_dict = { +expenditure_dict_rev = { "_T": "total", "CP01": "food_non_alcoholic_beverages", "CP02": "alcoholic_beverages_tobacco_narcotics", @@ -35,7 +35,7 @@ expendature_dict = { "CP12": "miscellaneous_goods_services", "CP045_0722": "energy", "GD": "goods", - "CP014T043": "housing", + "CP041T043": "housing", "CP041T043X042": "housing_excluding_rentals" "", "_TXCP01_NRG": "all_non_food_non_energy", "SERVXCP041_042_0432": "services_less_housing", @@ -50,8 +50,9 @@ expendature_dict = { "CP044": "water_supply_other_services", "CP045": "electricity_gas_other_fuels", } -expendature_dict = {v: k for k, v in expendature_dict.items()} -ExpendatureLiteral = Literal[tuple(expendature_dict.values())] # type: ignore +expenditure_dict = {v: k for k, v in expenditure_dict_rev.items()} +expenditures = tuple(expenditure_dict.keys()) + ("all",) +ExpenditureLiteral = Literal[expenditures] # type: ignore class OECDCPIQueryParams(ConsumerPriceIndexQueryParams): @@ -60,8 +61,10 @@ class OECDCPIQueryParams(ConsumerPriceIndexQueryParams): Source: https://data-explorer.oecd.org/?lc=en """ - country: CountriesLiteral = Field( - description="Country to get CPI for.", default="united_states" + country: str = Field( + description="Country to get CPI for. This is the list of OECD supported countries", + default="united_states", + choices=CountriesList, ) seasonal_adjustment: bool = Field( @@ -74,8 +77,8 @@ class OECDCPIQueryParams(ConsumerPriceIndexQueryParams): default="yoy", ) - expendature: ExpendatureLiteral = Field( - description="Expendature component of CPI.", + expenditure: ExpenditureLiteral = Field( + description="Expenditure component of CPI.", default="total", ) @@ -83,6 +86,8 @@ class OECDCPIQueryParams(ConsumerPriceIndexQueryParams): class OECDCPIData(ConsumerPriceIndexData): """OECD CPI Data.""" + expenditure: str = Field(description="Expenditure component of CPI.") + @field_validator("date", mode="before") @classmethod def date_validate(cls, in_date: Union[date, str]): # pylint: disable=E0213 @@ -148,7 +153,7 @@ class OECDCPIFetcher(Fetcher[OECDCPIQueryParams, List[OECDCPIData]]): "yoy": "PA", "mom": "PC", }[query.units] - expendature = expendature_dict[query.expendature] + expenditure = "" if query.expenditure == "all" else expenditure_dict[query.expenditure] seasonal_adjustment = "Y" if query.seasonal_adjustment else "N" country = "" if query.country == "all" else COUNTRY_TO_CODE_CPI[query.country] # For caching, include this in the key @@ -160,26 +165,36 @@ class OECDCPIFetcher(Fetcher[OECDCPIQueryParams, List[OECDCPIData]]): url = ( f"https://sdmx.oecd.org/public/rest/data/OECD.SDD.TPS,DSD_PRICES@DF_PRICES_ALL,1.0/" - f"{country}.{frequency}.{methodology}.CPI.{units}.{expendature}.{seasonal_adjustment}." + f"{country}.{frequency}.{methodology}.CPI.{units}.{expenditure}.{seasonal_adjustment}." ) data = helpers.get_possibly_cached_data( url, function="economy_cpi", query_dict=query_dict ) url_query = ( f"METHODOLOGY=='{methodology}' & UNIT_MEASURE=='{units}' & FREQ=='{frequency}' & " - f"ADJUSTMENT=='{seasonal_adjustment}' & EXPENDITURE=='{expendature}'" + f"ADJUSTMENT=='{seasonal_adjustment}' " ) url_query = url_query + f" & REF_AREA=='{country}'" if country else url_query - + url_query = ( + url_query + f" & EXPENDITURE=='{expenditure}'" + if query.expenditure != "all" + else url_query + ) # Filter down data = ( data.query(url_query) - .reset_index(drop=True)[["REF_AREA", "TIME_PERIOD", "VALUE"]] + .reset_index(drop=True)[["REF_AREA", "TIME_PERIOD", "VALUE", "EXPENDITURE"]] .rename( - columns={"REF_AREA": "country", "TIME_PERIOD": "date", "VALUE": "value"} + columns={ + "REF_AREA": "country", + "TIME_PERIOD": "date", + "VALUE": "value", + "EXPENDITURE": "expenditure", + } ) ) data["country"] = data["country"].map(CODE_TO_COUNTRY_CPI) + data["expenditure"] = data["expenditure"].map(expenditure_dict_rev) data["date"] = data["date"].apply(helpers.oecd_date_to_python_date) data = data[ |