diff options
Diffstat (limited to 'openbb_platform/core/openbb_core/provider/standard_models/cpi.py')
-rw-r--r-- | openbb_platform/core/openbb_core/provider/standard_models/cpi.py | 20 |
1 files changed, 15 insertions, 5 deletions
diff --git a/openbb_platform/core/openbb_core/provider/standard_models/cpi.py b/openbb_platform/core/openbb_core/provider/standard_models/cpi.py index f43524063eb..24a94721864 100644 --- a/openbb_platform/core/openbb_core/provider/standard_models/cpi.py +++ b/openbb_platform/core/openbb_core/provider/standard_models/cpi.py @@ -1,6 +1,7 @@ """CPI Standard Model.""" + from datetime import date as dateType -from typing import List, Literal, Optional +from typing import Literal, Optional from dateutil import parser from pydantic import Field, field_validator @@ -11,8 +12,9 @@ from openbb_core.provider.utils.descriptions import ( DATA_DESCRIPTIONS, QUERY_DESCRIPTIONS, ) +from openbb_core.provider.utils.helpers import check_item -CPI_COUNTRIES = Literal[ +CPI_COUNTRIES = [ "australia", "austria", "belgium", @@ -72,9 +74,7 @@ CPI_FREQUENCY = Literal["monthly", "quarter", "annual"] class ConsumerPriceIndexQueryParams(QueryParams): """CPI Query.""" - countries: List[CPI_COUNTRIES] = Field( - description=QUERY_DESCRIPTIONS.get("countries") - ) + country: str = Field(description=QUERY_DESCRIPTIONS.get("country")) units: CPI_UNITS = Field( default="growth_same", description=QUERY_DESCRIPTIONS.get("units", "") @@ -102,6 +102,16 @@ class ConsumerPriceIndexQueryParams(QueryParams): default=None, description=QUERY_DESCRIPTIONS.get("end_date") ) + @field_validator("country", mode="before", check_fields=False) + def validate_country(cls, c: str): # pylint: disable=E0213 + """Validate country.""" + result = [] + values = c.split(",") + for v in values: + check_item(v.lower(), CPI_COUNTRIES) + result.append(v.lower()) + return ",".join(result) + class ConsumerPriceIndexData(Data): """CPI data.""" |