summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorJames Maslek <jmaslek11@gmail.com>2024-03-12 17:02:24 -0400
committerJames Maslek <jmaslek11@gmail.com>2024-03-12 17:02:24 -0400
commit2319164af3d5c5ead27867492253bd1145df7cb5 (patch)
tree79089fd3075fc1634584ab66614fa75941b81a7a
parent895b20a19582826785a7c10f0d0c1f0ebf56355e (diff)
Correctly handle country "choices" and let oecd do multiple countries
-rw-r--r--openbb_platform/core/openbb_core/provider/standard_models/consumer_price_index.py72
-rw-r--r--openbb_platform/providers/fred/openbb_fred/models/consumer_price_index.py72
-rw-r--r--openbb_platform/providers/oecd/openbb_oecd/models/consumer_price_index.py33
3 files changed, 135 insertions, 42 deletions
diff --git a/openbb_platform/core/openbb_core/provider/standard_models/consumer_price_index.py b/openbb_platform/core/openbb_core/provider/standard_models/consumer_price_index.py
index 539ad2b4f07..0fa3675d700 100644
--- a/openbb_platform/core/openbb_core/provider/standard_models/consumer_price_index.py
+++ b/openbb_platform/core/openbb_core/provider/standard_models/consumer_price_index.py
@@ -14,56 +14,48 @@ from openbb_core.provider.utils.descriptions import (
)
from openbb_core.provider.utils.helpers import check_item
-CPI_COUNTRIES = [
- "australia",
- "austria",
- "belgium",
- "brazil",
- "bulgaria",
- "canada",
+CPI_STANDARD_COUNTRIES = [
+ "israel",
+ "portugal",
"chile",
- "china",
- "croatia",
- "cyprus",
+ "finland",
+ "japan",
"czech_republic",
"denmark",
- "estonia",
- "euro_area",
- "finland",
- "france",
- "germany",
- "greece",
- "hungary",
- "iceland",
- "india",
+ "poland",
"indonesia",
- "ireland",
- "israel",
"italy",
- "japan",
+ "spain",
"korea",
+ "iceland",
+ "slovak_republic",
"latvia",
+ "turkey",
+ "hungary",
+ "united_kingdom",
+ "india",
+ "norway",
+ "australia",
+ "estonia",
+ "netherlands",
+ "germany",
+ "greece",
+ "china",
"lithuania",
+ "united_states",
"luxembourg",
- "malta",
+ "france",
+ "sweden",
+ "switzerland",
+ "slovenia",
"mexico",
- "netherlands",
"new_zealand",
- "norway",
- "poland",
- "portugal",
- "romania",
- "russian_federation",
- "slovak_republic",
- "slovakia",
- "slovenia",
+ "canada",
+ "austria",
+ "belgium",
+ "ireland",
+ "brazil",
"south_africa",
- "spain",
- "sweden",
- "switzerland",
- "turkey",
- "united_kingdom",
- "united_states",
]
@@ -76,7 +68,7 @@ class ConsumerPriceIndexQueryParams(QueryParams):
country: str = Field(
description=QUERY_DESCRIPTIONS.get("country"),
default="united_states",
- choices=CPI_COUNTRIES, # type: ignore
+ choices=CPI_STANDARD_COUNTRIES, # type: ignore
)
units: Literal["index", "yoy", "mom"] = Field(
description="Units to get CPI for. Either index, month over month or year over year. Defaults to year over year.",
@@ -104,7 +96,7 @@ class ConsumerPriceIndexQueryParams(QueryParams):
result = []
values = c.replace(" ", "_").split(",")
for v in values:
- check_item(v.lower(), CPI_COUNTRIES)
+ check_item(v.lower(), CPI_STANDARD_COUNTRIES)
result.append(v.lower())
return ",".join(result)
diff --git a/openbb_platform/providers/fred/openbb_fred/models/consumer_price_index.py b/openbb_platform/providers/fred/openbb_fred/models/consumer_price_index.py
index cdb95018cce..64712ac15ac 100644
--- a/openbb_platform/providers/fred/openbb_fred/models/consumer_price_index.py
+++ b/openbb_platform/providers/fred/openbb_fred/models/consumer_price_index.py
@@ -7,15 +7,87 @@ from openbb_core.provider.standard_models.consumer_price_index import (
ConsumerPriceIndexData,
ConsumerPriceIndexQueryParams,
)
+from openbb_core.provider.utils.descriptions import (
+ QUERY_DESCRIPTIONS,
+)
+from openbb_core.provider.utils.helpers import check_item
from openbb_fred.utils.fred_base import Fred
from openbb_fred.utils.fred_helpers import all_cpi_options
+from pydantic import Field, field_validator
+
+CPI_COUNTRIES = [
+ "australia",
+ "austria",
+ "belgium",
+ "brazil",
+ "bulgaria",
+ "canada",
+ "chile",
+ "china",
+ "croatia",
+ "cyprus",
+ "czech_republic",
+ "denmark",
+ "estonia",
+ "euro_area",
+ "finland",
+ "france",
+ "germany",
+ "greece",
+ "hungary",
+ "iceland",
+ "india",
+ "indonesia",
+ "ireland",
+ "israel",
+ "italy",
+ "japan",
+ "korea",
+ "latvia",
+ "lithuania",
+ "luxembourg",
+ "malta",
+ "mexico",
+ "netherlands",
+ "new_zealand",
+ "norway",
+ "poland",
+ "portugal",
+ "romania",
+ "russian_federation",
+ "slovak_republic",
+ "slovakia",
+ "slovenia",
+ "south_africa",
+ "spain",
+ "sweden",
+ "switzerland",
+ "turkey",
+ "united_kingdom",
+ "united_states",
+]
class FREDConsumerPriceIndexQueryParams(ConsumerPriceIndexQueryParams):
"""FRED Consumer Price Index Query."""
+ country: str = Field(
+ description=QUERY_DESCRIPTIONS.get("country"),
+ default="united_states",
+ choices=CPI_COUNTRIES, # type: ignore
+ )
__json_schema_extra__ = {"country": ["multiple_items_allowed"]}
+ @field_validator("country", mode="before", check_fields=False)
+ def validate_country(cls, c: str): # pylint: disable=E0213
+ """Validate country."""
+ result = []
+ values = c.replace(" ", "_").split(",")
+ for v in values:
+ check_item(v.lower(), CPI_COUNTRIES)
+ result.append(v.lower())
+ return ",".join(result)
+
class FREDConsumerPriceIndexData(ConsumerPriceIndexData):
"""FRED Consumer Price Index Data."""
diff --git a/openbb_platform/providers/oecd/openbb_oecd/models/consumer_price_index.py b/openbb_platform/providers/oecd/openbb_oecd/models/consumer_price_index.py
index 185c270e450..408c065ae89 100644
--- a/openbb_platform/providers/oecd/openbb_oecd/models/consumer_price_index.py
+++ b/openbb_platform/providers/oecd/openbb_oecd/models/consumer_price_index.py
@@ -9,6 +9,7 @@ from openbb_core.provider.standard_models.consumer_price_index import (
ConsumerPriceIndexData,
ConsumerPriceIndexQueryParams,
)
+from openbb_core.provider.utils.helpers import check_item
from openbb_oecd.utils import helpers
from openbb_oecd.utils.constants import (
CODE_TO_COUNTRY_CPI,
@@ -82,6 +83,18 @@ class OECDCPIQueryParams(ConsumerPriceIndexQueryParams):
default="total",
)
+ @field_validator("country", mode="before", check_fields=False)
+ def validate_country(cls, c: str): # pylint: disable=E0213
+ """Validate country."""
+ result = []
+ values = c.replace(" ", "_").split(",")
+ for v in values:
+ check_item(v.lower(), CountriesList)
+ result.append(v.lower())
+ return ",".join(result)
+
+ __json_schema_extra__ = {"country": ["multiple_items_allowed"]}
+
class OECDCPIData(ConsumerPriceIndexData):
"""OECD CPI Data."""
@@ -157,7 +170,14 @@ class OECDCPIFetcher(Fetcher[OECDCPIQueryParams, List[OECDCPIData]]):
"" if query.expenditure == "all" else expenditure_dict[query.expenditure]
)
seasonal_adjustment = "Y" if query.seasonal_adjustment else "N"
- country = "" if query.country == "all" else COUNTRY_TO_CODE_CPI[query.country]
+
+ def country_string(input_str: str):
+ if input_str == "all":
+ return ""
+ countries = input_str.split(",")
+ return "+".join([COUNTRY_TO_CODE_CPI[country] for country in countries])
+
+ country = country_string(query.country)
# For caching, include this in the key
query_dict = {
k: v
@@ -176,7 +196,16 @@ class OECDCPIFetcher(Fetcher[OECDCPIQueryParams, List[OECDCPIData]]):
f"METHODOLOGY=='{methodology}' & UNIT_MEASURE=='{units}' & FREQ=='{frequency}' & "
f"ADJUSTMENT=='{seasonal_adjustment}' "
)
- url_query = url_query + f" & REF_AREA=='{country}'" if country else url_query
+
+ if country != "all":
+ if "+" in country:
+ countries = country.split("+")
+ country_conditions = " or ".join(
+ [f"REF_AREA=='{c}'" for c in countries]
+ )
+ url_query += f" & ({country_conditions})"
+ else:
+ url_query = url_query + f" & REF_AREA=='{country}'"
url_query = (
url_query + f" & EXPENDITURE=='{expenditure}'"
if query.expenditure != "all"