summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorJames Maslek <jmaslek11@gmail.com>2024-03-12 16:06:05 -0400
committerJames Maslek <jmaslek11@gmail.com>2024-03-12 16:06:05 -0400
commitc1d518ab2daf9381a48213885ce0cd08e6a61986 (patch)
treeacdc63b5b5b7492c4e5cd21dddd4358e6b43b3ef
parent58d48dbe4e2c06fb8cc0f75f1adda4169f6f9b23 (diff)
Clean up the CPI oecd model with the choices and allow expensitures to be "all"
-rw-r--r--openbb_platform/providers/oecd/openbb_oecd/models/consumer_price_index.py45
1 files changed, 30 insertions, 15 deletions
diff --git a/openbb_platform/providers/oecd/openbb_oecd/models/consumer_price_index.py b/openbb_platform/providers/oecd/openbb_oecd/models/consumer_price_index.py
index dfb1a1810e3..b039d4d2acc 100644
--- a/openbb_platform/providers/oecd/openbb_oecd/models/consumer_price_index.py
+++ b/openbb_platform/providers/oecd/openbb_oecd/models/consumer_price_index.py
@@ -17,9 +17,9 @@ from openbb_oecd.utils.constants import (
from pydantic import Field, field_validator
countries = tuple(CODE_TO_COUNTRY_CPI.values()) + ("all",)
-CountriesLiteral = Literal[countries] # type: ignore
+CountriesList = list(countries) # type: ignore
-expendature_dict = {
+expenditure_dict_rev = {
"_T": "total",
"CP01": "food_non_alcoholic_beverages",
"CP02": "alcoholic_beverages_tobacco_narcotics",
@@ -35,7 +35,7 @@ expendature_dict = {
"CP12": "miscellaneous_goods_services",
"CP045_0722": "energy",
"GD": "goods",
- "CP014T043": "housing",
+ "CP041T043": "housing",
"CP041T043X042": "housing_excluding_rentals" "",
"_TXCP01_NRG": "all_non_food_non_energy",
"SERVXCP041_042_0432": "services_less_housing",
@@ -50,8 +50,9 @@ expendature_dict = {
"CP044": "water_supply_other_services",
"CP045": "electricity_gas_other_fuels",
}
-expendature_dict = {v: k for k, v in expendature_dict.items()}
-ExpendatureLiteral = Literal[tuple(expendature_dict.values())] # type: ignore
+expenditure_dict = {v: k for k, v in expenditure_dict_rev.items()}
+expenditures = tuple(expenditure_dict.keys()) + ("all",)
+ExpenditureLiteral = Literal[expenditures] # type: ignore
class OECDCPIQueryParams(ConsumerPriceIndexQueryParams):
@@ -60,8 +61,10 @@ class OECDCPIQueryParams(ConsumerPriceIndexQueryParams):
Source: https://data-explorer.oecd.org/?lc=en
"""
- country: CountriesLiteral = Field(
- description="Country to get CPI for.", default="united_states"
+ country: str = Field(
+ description="Country to get CPI for. This is the list of OECD supported countries",
+ default="united_states",
+ choices=CountriesList,
)
seasonal_adjustment: bool = Field(
@@ -74,8 +77,8 @@ class OECDCPIQueryParams(ConsumerPriceIndexQueryParams):
default="yoy",
)
- expendature: ExpendatureLiteral = Field(
- description="Expendature component of CPI.",
+ expenditure: ExpenditureLiteral = Field(
+ description="Expenditure component of CPI.",
default="total",
)
@@ -83,6 +86,8 @@ class OECDCPIQueryParams(ConsumerPriceIndexQueryParams):
class OECDCPIData(ConsumerPriceIndexData):
"""OECD CPI Data."""
+ expenditure: str = Field(description="Expenditure component of CPI.")
+
@field_validator("date", mode="before")
@classmethod
def date_validate(cls, in_date: Union[date, str]): # pylint: disable=E0213
@@ -148,7 +153,7 @@ class OECDCPIFetcher(Fetcher[OECDCPIQueryParams, List[OECDCPIData]]):
"yoy": "PA",
"mom": "PC",
}[query.units]
- expendature = expendature_dict[query.expendature]
+ expenditure = "" if query.expenditure == "all" else expenditure_dict[query.expenditure]
seasonal_adjustment = "Y" if query.seasonal_adjustment else "N"
country = "" if query.country == "all" else COUNTRY_TO_CODE_CPI[query.country]
# For caching, include this in the key
@@ -160,26 +165,36 @@ class OECDCPIFetcher(Fetcher[OECDCPIQueryParams, List[OECDCPIData]]):
url = (
f"https://sdmx.oecd.org/public/rest/data/OECD.SDD.TPS,DSD_PRICES@DF_PRICES_ALL,1.0/"
- f"{country}.{frequency}.{methodology}.CPI.{units}.{expendature}.{seasonal_adjustment}."
+ f"{country}.{frequency}.{methodology}.CPI.{units}.{expenditure}.{seasonal_adjustment}."
)
data = helpers.get_possibly_cached_data(
url, function="economy_cpi", query_dict=query_dict
)
url_query = (
f"METHODOLOGY=='{methodology}' & UNIT_MEASURE=='{units}' & FREQ=='{frequency}' & "
- f"ADJUSTMENT=='{seasonal_adjustment}' & EXPENDITURE=='{expendature}'"
+ f"ADJUSTMENT=='{seasonal_adjustment}' "
)
url_query = url_query + f" & REF_AREA=='{country}'" if country else url_query
-
+ url_query = (
+ url_query + f" & EXPENDITURE=='{expenditure}'"
+ if query.expenditure != "all"
+ else url_query
+ )
# Filter down
data = (
data.query(url_query)
- .reset_index(drop=True)[["REF_AREA", "TIME_PERIOD", "VALUE"]]
+ .reset_index(drop=True)[["REF_AREA", "TIME_PERIOD", "VALUE", "EXPENDITURE"]]
.rename(
- columns={"REF_AREA": "country", "TIME_PERIOD": "date", "VALUE": "value"}
+ columns={
+ "REF_AREA": "country",
+ "TIME_PERIOD": "date",
+ "VALUE": "value",
+ "EXPENDITURE": "expenditure",
+ }
)
)
data["country"] = data["country"].map(CODE_TO_COUNTRY_CPI)
+ data["expenditure"] = data["expenditure"].map(expenditure_dict_rev)
data["date"] = data["date"].apply(helpers.oecd_date_to_python_date)
data = data[