diff options
Diffstat (limited to 'openbb_platform/providers/oecd/openbb_oecd/models/unemployment.py')
-rw-r--r-- | openbb_platform/providers/oecd/openbb_oecd/models/unemployment.py | 181 |
1 files changed, 103 insertions, 78 deletions
diff --git a/openbb_platform/providers/oecd/openbb_oecd/models/unemployment.py b/openbb_platform/providers/oecd/openbb_oecd/models/unemployment.py index b85098c69e0..2f91d735b5c 100644 --- a/openbb_platform/providers/oecd/openbb_oecd/models/unemployment.py +++ b/openbb_platform/providers/oecd/openbb_oecd/models/unemployment.py @@ -1,23 +1,46 @@ """OECD Unemployment Data.""" -import re -from datetime import date, timedelta -from typing import Any, Dict, List, Literal, Optional, Union +# pylint: disable=unused-argument + +from datetime import date +from io import StringIO +from typing import Any, Dict, List, Literal, Optional +from warnings import warn from openbb_core.provider.abstract.fetcher import Fetcher from openbb_core.provider.standard_models.unemployment import ( UnemploymentData, UnemploymentQueryParams, ) +from openbb_core.provider.utils.descriptions import QUERY_DESCRIPTIONS +from openbb_core.provider.utils.errors import EmptyDataError +from openbb_core.provider.utils.helpers import check_item, make_request from openbb_oecd.utils import helpers from openbb_oecd.utils.constants import ( CODE_TO_COUNTRY_UNEMPLOYMENT, COUNTRY_TO_CODE_UNEMPLOYMENT, ) +from pandas import read_csv from pydantic import Field, field_validator countries = tuple(CODE_TO_COUNTRY_UNEMPLOYMENT.values()) + ("all",) -CountriesLiteral = Literal[countries] # type: ignore +CountriesList = sorted(list(countries)) # type: ignore +AGES = [ + "total", + "15-24", + "25-54", + "55-64", + "15-64", + "15-74", +] +AgesLiteral = Literal[ + "total", + "15-24", + "25-54", + "55-64", + "15-64", + "15-74", +] class OECDUnemploymentQueryParams(UnemploymentQueryParams): @@ -26,61 +49,55 @@ class OECDUnemploymentQueryParams(UnemploymentQueryParams): Source: https://data-explorer.oecd.org/?lc=en """ - country: CountriesLiteral = Field( - description="Country to get GDP for.", default="united_states" + __json_schema_extra__ = {"country": ["multiple_items_allowed"]} + + country: str = Field( + description=QUERY_DESCRIPTIONS.get("country", ""), + default="united_states", + choices=CountriesList, ) sex: Literal["total", "male", "female"] = Field( - description="Sex to get unemployment for.", default="total" - ) - frequency: Literal["monthly", "quarterly", "annual"] = Field( - description="Frequency to get unemployment for.", default="monthly" + description="Sex to get unemployment for.", + default="total", + json_schema_extra={"choices": ["total", "male", "female"]}, ) - age: Literal["total", "15-24", "15-64", "25-54", "55-64"] = Field( + age: Literal[AgesLiteral] = Field( description="Age group to get unemployment for. Total indicates 15 years or over", default="total", + json_schema_extra={"choices": AGES}, ) seasonal_adjustment: bool = Field( description="Whether to get seasonally adjusted unemployment. Defaults to False.", default=False, ) + @field_validator("country", mode="before", check_fields=False) + @classmethod + def validate_country(cls, c): + """Validate country.""" + result: List = [] + values = c.replace(" ", "_").split(",") + for v in values: + if v.upper() in CODE_TO_COUNTRY_UNEMPLOYMENT: + result.append(CODE_TO_COUNTRY_UNEMPLOYMENT.get(v.upper())) + continue + try: + check_item(v.lower(), CountriesList) + except Exception as e: + if len(values) == 1: + raise e from e + else: + warn(f"Invalid country: {v}. Skipping...") + continue + result.append(v.lower()) + if result: + return ",".join(result) + raise ValueError(f"No valid country found. -> {values}") + class OECDUnemploymentData(UnemploymentData): """OECD Unemployment Data.""" - @field_validator("date", mode="before") - @classmethod - def date_validate(cls, in_date: Union[date, str]): # pylint: disable=E0213 - """Validate value.""" - if isinstance(in_date, str): - # i.e 2022-Q1 - if re.match(r"\d{4}-Q[1-4]$", in_date): - year, quarter = in_date.split("-") - _year = int(year) - if quarter == "Q1": - return date(_year, 3, 31) - if quarter == "Q2": - return date(_year, 6, 30) - if quarter == "Q3": - return date(_year, 9, 30) - if quarter == "Q4": - return date(_year, 12, 31) - # Now match if it is monthly, i.e 2022-01 - elif re.match(r"\d{4}-\d{2}$", in_date): - year, month = map(int, in_date.split("-")) # type: ignore - if month == 12: - return date(year, month, 31) # type: ignore - next_month = date(year, month + 1, 1) # type: ignore - return date(next_month.year, next_month.month, 1) - timedelta(days=1) - # Now match if it is yearly, i.e 2022 - elif re.match(r"\d{4}$", in_date): - return date(int(in_date), 12, 31) - # If the input date is a year - if isinstance(in_date, int): - return date(in_date, 12, 31) - - return in_date - class OECDUnemploymentFetcher( Fetcher[OECDUnemploymentQueryParams, List[OECDUnemploymentData]] @@ -92,13 +109,16 @@ class OECDUnemploymentFetcher( """Transform the query.""" transformed_params = params.copy() if transformed_params["start_date"] is None: - transformed_params["start_date"] = date(1950, 1, 1) + transformed_params["start_date"] = ( + date(2010, 1, 1) + if transformed_params.get("country") == "all" + else date(1950, 1, 1) + ) if transformed_params["end_date"] is None: transformed_params["end_date"] = date(date.today().year, 12, 31) return OECDUnemploymentQueryParams(**transformed_params) - # pylint: disable=unused-argument @staticmethod def extract_data( query: OECDUnemploymentQueryParams, @@ -112,49 +132,54 @@ class OECDUnemploymentFetcher( "total": "Y_GE15", "15-24": "Y15T24", "15-64": "Y15T64", + "15-74": "Y15T74", "25-54": "Y25T54", "55-64": "Y55T64", }[query.age] seasonal_adjustment = "Y" if query.seasonal_adjustment else "N" - country = ( - "" - if query.country == "all" - else COUNTRY_TO_CODE_UNEMPLOYMENT[query.country] - ) - # For caching, include this in the key - query_dict = { - k: v - for k, v in query.__dict__.items() - if k not in ["start_date", "end_date"] - } + def country_string(input_str: str): + if input_str == "all": + return "" + countries = input_str.split(",") + return "+".join( + [COUNTRY_TO_CODE_UNEMPLOYMENT[country] for country in countries] + ) + + country = country_string(query.country) + start_date = query.start_date.strftime("%Y-%m") if query.start_date else "" + end_date = query.end_date.strftime("%Y-%m") if query.end_date else "" url = ( - f"https://sdmx.oecd.org/public/rest/data/OECD.SDD.TPS,DSD_LFS@DF_IALFS_INDIC," - f"1.0/{country}.UNE_LF...{seasonal_adjustment}.{sex}.{age}..." + "https://sdmx.oecd.org/public/rest/data/OECD.SDD.TPS,DSD_LFS@DF_IALFS_UNE_M,1.0/" + + f"{country}..._Z.{seasonal_adjustment}.{sex}.{age}..{frequency}" + + f"?startPeriod={start_date}&endPeriod={end_date}" + + "&dimensionAtObservation=TIME_PERIOD&detail=dataonly" ) - data = helpers.get_possibly_cached_data( - url, function="economy_unemployment", query_dict=query_dict + headers = {"Accept": "application/vnd.sdmx.data+csv; charset=utf-8"} + response = make_request(url, headers=headers, timeout=20) + if response.status_code != 200: + raise Exception(f"Error: {response.status_code}") + df = read_csv(StringIO(response.text)).get( + ["REF_AREA", "TIME_PERIOD", "OBS_VALUE"] ) - url_query = f"AGE=='{age}' & SEX=='{sex}' & FREQ=='{frequency}' & ADJUSTMENT=='{seasonal_adjustment}'" - url_query = url_query + f" & REF_AREA=='{country}'" if country else url_query - # Filter down - data = ( - data.query(url_query) - .reset_index(drop=True)[["REF_AREA", "TIME_PERIOD", "VALUE"]] - .rename( - columns={"REF_AREA": "country", "TIME_PERIOD": "date", "VALUE": "value"} - ) + if df.empty: + raise EmptyDataError() + df = df.rename( + columns={"REF_AREA": "country", "TIME_PERIOD": "date", "OBS_VALUE": "value"} ) - data["country"] = data["country"].map(CODE_TO_COUNTRY_UNEMPLOYMENT) - - data["date"] = data["date"].apply(helpers.oecd_date_to_python_date) - data = data[ - (data["date"] <= query.end_date) & (data["date"] >= query.start_date) - ] + df["value"] = df["value"].astype(float) / 100 + df["country"] = df["country"].map(CODE_TO_COUNTRY_UNEMPLOYMENT) + df["date"] = df["date"].apply(helpers.oecd_date_to_python_date) + df = ( + df.query("value.notnull()") + .set_index(["date", "country"]) + .sort_index() + .reset_index() + ) + df = df[(df["date"] <= query.end_date) & (df["date"] >= query.start_date)] - return data.to_dict(orient="records") + return df.to_dict(orient="records") - # pylint: disable=unused-argument @staticmethod def transform_data( query: OECDUnemploymentQueryParams, data: List[Dict], **kwargs: Any |