summaryrefslogtreecommitdiffstats
path: root/openbb_platform/providers/oecd/openbb_oecd/models/consumer_price_index.py
diff options
context:
space:
mode:
Diffstat (limited to 'openbb_platform/providers/oecd/openbb_oecd/models/consumer_price_index.py')
-rw-r--r--openbb_platform/providers/oecd/openbb_oecd/models/consumer_price_index.py238
1 files changed, 238 insertions, 0 deletions
diff --git a/openbb_platform/providers/oecd/openbb_oecd/models/consumer_price_index.py b/openbb_platform/providers/oecd/openbb_oecd/models/consumer_price_index.py
new file mode 100644
index 00000000000..51ebaecc78a
--- /dev/null
+++ b/openbb_platform/providers/oecd/openbb_oecd/models/consumer_price_index.py
@@ -0,0 +1,238 @@
+"""OECD CPI Data."""
+
+# pylint: disable=unused-argument
+
+from datetime import date
+from typing import Any, Dict, List, Literal, Optional
+
+from openbb_core.provider.abstract.fetcher import Fetcher
+from openbb_core.provider.standard_models.consumer_price_index import (
+ ConsumerPriceIndexData,
+ ConsumerPriceIndexQueryParams,
+)
+from openbb_core.provider.utils.helpers import check_item
+from openbb_oecd.utils import helpers
+from openbb_oecd.utils.constants import (
+ CODE_TO_COUNTRY_CPI,
+ COUNTRY_TO_CODE_CPI,
+)
+from pydantic import Field, field_validator
+from requests.exceptions import HTTPError
+
+countries = tuple(CODE_TO_COUNTRY_CPI.values()) + ("all",)
+CountriesList = list(countries) # type: ignore
+
+expenditure_dict_rev = {
+ "_T": "total",
+ "CP01": "food_non_alcoholic_beverages",
+ "CP02": "alcoholic_beverages_tobacco_narcotics",
+ "CP03": "clothing_footwear",
+ "CP04": "housing_water_electricity_gas",
+ "CP05": "furniture_household_equipment",
+ "CP06": "health",
+ "CP07": "transport",
+ "CP08": "communication",
+ "CP09": "recreation_culture",
+ "CP10": "education",
+ "CP11": "restaurants_hotels",
+ "CP12": "miscellaneous_goods_services",
+ "CP045_0722": "energy",
+ "GD": "goods",
+ "CP041T043": "housing",
+ "CP041T043X042": "housing_excluding_rentals" "",
+ "_TXCP01_NRG": "all_non_food_non_energy",
+ "SERVXCP041_042_0432": "services_less_housing",
+ "SERVXCP041_0432": "services_less_house_excl_rentals",
+ "SERV": "services",
+ "_TXNRG_01_02": "overall_excl_energy_food_alcohol_tobacco",
+ "CPRES": "residuals",
+ "CP0722": "fuels_lubricants_personal",
+ "CP041": "actual_rentals",
+ "CP042": "imputed_rentals",
+ "CP043": "maintenance_repair_dwelling",
+ "CP044": "water_supply_other_services",
+ "CP045": "electricity_gas_other_fuels",
+}
+expenditure_dict = {v: k for k, v in expenditure_dict_rev.items()}
+expenditures = tuple(expenditure_dict.keys()) + ("all",)
+ExpenditureChoices = Literal[
+ "total",
+ "all",
+ "actual_rentals",
+ "alcoholic_beverages_tobacco_narcotics",
+ "all_non_food_non_energy",
+ "clothing_footwear",
+ "communication",
+ "education",
+ "electricity_gas_other_fuels",
+ "energy",
+ "overall_excl_energy_food_alcohol_tobacco",
+ "food_non_alcoholic_beverages",
+ "fuels_lubricants_personal",
+ "furniture_household_equipment",
+ "goods",
+ "housing",
+ "housing_excluding_rentals",
+ "housing_water_electricity_gas",
+ "health",
+ "imputed_rentals",
+ "maintenance_repair_dwelling",
+ "miscellaneous_goods_services",
+ "recreation_culture",
+ "residuals",
+ "restaurants_hotels",
+ "services_less_housing",
+ "services_less_house_excl_rentals",
+ "services",
+ "transport",
+ "water_supply_other_services",
+]
+
+
+class OECDCPIQueryParams(ConsumerPriceIndexQueryParams):
+ """OECD CPI Query.
+
+ Source: https://data-explorer.oecd.org/?lc=en
+ """
+
+ __json_schema_extra__ = {"country": ["multiple_items_allowed"]}
+
+ country: str = Field(
+ description="Country to get CPI for. This is the list of OECD supported countries",
+ default="united_states",
+ choices=CountriesList,
+ )
+ expenditure: ExpenditureChoices = Field(
+ description="Expenditure component of CPI.",
+ default="total",
+ json_schema_extra={"choices": list(expenditures)},
+ )
+
+ @field_validator("country", mode="before", check_fields=False)
+ def validate_country(cls, c: str): # pylint: disable=E0213
+ """Validate country."""
+ result: List = []
+ values = c.replace(" ", "_").split(",")
+ for v in values:
+ check_item(v.lower(), CountriesList)
+ result.append(v.lower())
+ return ",".join(result)
+
+
+class OECDCPIData(ConsumerPriceIndexData):
+ """OECD CPI Data."""
+
+ expenditure: str = Field(description="Expenditure component of CPI.")
+
+
+class OECDCPIFetcher(Fetcher[OECDCPIQueryParams, List[OECDCPIData]]):
+ """OECD CPI Fetcher."""
+
+ @staticmethod
+ def transform_query(params: Dict[str, Any]) -> OECDCPIQueryParams:
+ """Transform the query."""
+ transformed_params = params.copy()
+ if transformed_params.get("start_date") is None:
+ transformed_params["start_date"] = date(1950, 1, 1)
+ if transformed_params.get("end_date") is None:
+ transformed_params["end_date"] = date(date.today().year, 12, 31)
+ if transformed_params.get("country") is None:
+ transformed_params["country"] = "united_states"
+
+ return OECDCPIQueryParams(**transformed_params)
+
+ @staticmethod
+ def extract_data(
+ query: OECDCPIQueryParams,
+ credentials: Optional[Dict[str, str]],
+ **kwargs: Any,
+ ) -> List[Dict]:
+ """Return the raw data from the OECD endpoint."""
+ methodology = "HICP" if query.harmonized is True else "N"
+ query.units = "mom" if query.transform == "period" else query.transform
+ query.frequency = (
+ "monthly"
+ if query.harmonized is True and query.frequency == "quarter"
+ else query.frequency
+ )
+ frequency = query.frequency[0].upper()
+ units = {
+ "index": "IX",
+ "yoy": "PA",
+ "mom": "PC",
+ }[query.units]
+ expenditure = (
+ "" if query.expenditure == "all" else expenditure_dict[query.expenditure]
+ )
+
+ def country_string(input_str: str):
+ if input_str == "all":
+ return ""
+ countries = input_str.split(",")
+ return "+".join([COUNTRY_TO_CODE_CPI[country] for country in countries])
+
+ country = country_string(query.country)
+ # For caching, include this in the key
+ query_dict = {
+ k: v
+ for k, v in query.__dict__.items()
+ if k not in ["start_date", "end_date"]
+ }
+
+ url = (
+ f"https://sdmx.oecd.org/public/rest/data/OECD.SDD.TPS,DSD_PRICES@DF_PRICES_ALL,1.0/"
+ f"{country}.{frequency}.{methodology}.CPI.{units}.{expenditure}.N."
+ )
+ try:
+ data = helpers.get_possibly_cached_data(
+ url, function="economy_cpi", query_dict=query_dict
+ )
+ except HTTPError:
+ raise ValueError("No data found for the given query.")
+ url_query = f"METHODOLOGY=='{methodology}' & UNIT_MEASURE=='{units}' & FREQ=='{frequency}'"
+
+ if country != "all":
+ if "+" in country:
+ countries = country.split("+")
+ country_conditions = " or ".join(
+ [f"REF_AREA=='{c}'" for c in countries]
+ )
+ url_query += f" & ({country_conditions})"
+ else:
+ url_query = url_query + f" & REF_AREA=='{country}'"
+ url_query = (
+ url_query + f" & EXPENDITURE=='{expenditure}'"
+ if query.expenditure != "all"
+ else url_query
+ )
+ # Filter down
+ data = (
+ data.query(url_query)
+ .reset_index(drop=True)[["REF_AREA", "TIME_PERIOD", "VALUE", "EXPENDITURE"]]
+ .rename(
+ columns={
+ "REF_AREA": "country",
+ "TIME_PERIOD": "date",
+ "VALUE": "value",
+ "EXPENDITURE": "expenditure",
+ }
+ )
+ )
+ data["country"] = data["country"].map(CODE_TO_COUNTRY_CPI)
+ data["expenditure"] = data["expenditure"].map(expenditure_dict_rev)
+ data["date"] = data["date"].apply(helpers.oecd_date_to_python_date)
+ data = data[
+ (data["date"] <= query.end_date) & (data["date"] >= query.start_date)
+ ]
+ # Normalize the percent value.
+ if query.transform in ("yoy", "period"):
+ data["value"] = data["value"].astype(float) / 100
+
+ return data.fillna("N/A").replace("N/A", None).to_dict(orient="records")
+
+ @staticmethod
+ def transform_data(
+ query: OECDCPIQueryParams, data: List[Dict], **kwargs: Any
+ ) -> List[OECDCPIData]:
+ """Transform the data from the OECD endpoint."""
+ return [OECDCPIData.model_validate(d) for d in data]