summaryrefslogtreecommitdiffstats
path: root/openbb_platform/providers/oecd/openbb_oecd/models/unemployment.py
diff options
context:
space:
mode:
Diffstat (limited to 'openbb_platform/providers/oecd/openbb_oecd/models/unemployment.py')
-rw-r--r--openbb_platform/providers/oecd/openbb_oecd/models/unemployment.py181
1 files changed, 103 insertions, 78 deletions
diff --git a/openbb_platform/providers/oecd/openbb_oecd/models/unemployment.py b/openbb_platform/providers/oecd/openbb_oecd/models/unemployment.py
index b85098c69e0..2f91d735b5c 100644
--- a/openbb_platform/providers/oecd/openbb_oecd/models/unemployment.py
+++ b/openbb_platform/providers/oecd/openbb_oecd/models/unemployment.py
@@ -1,23 +1,46 @@
"""OECD Unemployment Data."""
-import re
-from datetime import date, timedelta
-from typing import Any, Dict, List, Literal, Optional, Union
+# pylint: disable=unused-argument
+
+from datetime import date
+from io import StringIO
+from typing import Any, Dict, List, Literal, Optional
+from warnings import warn
from openbb_core.provider.abstract.fetcher import Fetcher
from openbb_core.provider.standard_models.unemployment import (
UnemploymentData,
UnemploymentQueryParams,
)
+from openbb_core.provider.utils.descriptions import QUERY_DESCRIPTIONS
+from openbb_core.provider.utils.errors import EmptyDataError
+from openbb_core.provider.utils.helpers import check_item, make_request
from openbb_oecd.utils import helpers
from openbb_oecd.utils.constants import (
CODE_TO_COUNTRY_UNEMPLOYMENT,
COUNTRY_TO_CODE_UNEMPLOYMENT,
)
+from pandas import read_csv
from pydantic import Field, field_validator
countries = tuple(CODE_TO_COUNTRY_UNEMPLOYMENT.values()) + ("all",)
-CountriesLiteral = Literal[countries] # type: ignore
+CountriesList = sorted(list(countries)) # type: ignore
+AGES = [
+ "total",
+ "15-24",
+ "25-54",
+ "55-64",
+ "15-64",
+ "15-74",
+]
+AgesLiteral = Literal[
+ "total",
+ "15-24",
+ "25-54",
+ "55-64",
+ "15-64",
+ "15-74",
+]
class OECDUnemploymentQueryParams(UnemploymentQueryParams):
@@ -26,61 +49,55 @@ class OECDUnemploymentQueryParams(UnemploymentQueryParams):
Source: https://data-explorer.oecd.org/?lc=en
"""
- country: CountriesLiteral = Field(
- description="Country to get GDP for.", default="united_states"
+ __json_schema_extra__ = {"country": ["multiple_items_allowed"]}
+
+ country: str = Field(
+ description=QUERY_DESCRIPTIONS.get("country", ""),
+ default="united_states",
+ choices=CountriesList,
)
sex: Literal["total", "male", "female"] = Field(
- description="Sex to get unemployment for.", default="total"
- )
- frequency: Literal["monthly", "quarterly", "annual"] = Field(
- description="Frequency to get unemployment for.", default="monthly"
+ description="Sex to get unemployment for.",
+ default="total",
+ json_schema_extra={"choices": ["total", "male", "female"]},
)
- age: Literal["total", "15-24", "15-64", "25-54", "55-64"] = Field(
+ age: Literal[AgesLiteral] = Field(
description="Age group to get unemployment for. Total indicates 15 years or over",
default="total",
+ json_schema_extra={"choices": AGES},
)
seasonal_adjustment: bool = Field(
description="Whether to get seasonally adjusted unemployment. Defaults to False.",
default=False,
)
+ @field_validator("country", mode="before", check_fields=False)
+ @classmethod
+ def validate_country(cls, c):
+ """Validate country."""
+ result: List = []
+ values = c.replace(" ", "_").split(",")
+ for v in values:
+ if v.upper() in CODE_TO_COUNTRY_UNEMPLOYMENT:
+ result.append(CODE_TO_COUNTRY_UNEMPLOYMENT.get(v.upper()))
+ continue
+ try:
+ check_item(v.lower(), CountriesList)
+ except Exception as e:
+ if len(values) == 1:
+ raise e from e
+ else:
+ warn(f"Invalid country: {v}. Skipping...")
+ continue
+ result.append(v.lower())
+ if result:
+ return ",".join(result)
+ raise ValueError(f"No valid country found. -> {values}")
+
class OECDUnemploymentData(UnemploymentData):
"""OECD Unemployment Data."""
- @field_validator("date", mode="before")
- @classmethod
- def date_validate(cls, in_date: Union[date, str]): # pylint: disable=E0213
- """Validate value."""
- if isinstance(in_date, str):
- # i.e 2022-Q1
- if re.match(r"\d{4}-Q[1-4]$", in_date):
- year, quarter = in_date.split("-")
- _year = int(year)
- if quarter == "Q1":
- return date(_year, 3, 31)
- if quarter == "Q2":
- return date(_year, 6, 30)
- if quarter == "Q3":
- return date(_year, 9, 30)
- if quarter == "Q4":
- return date(_year, 12, 31)
- # Now match if it is monthly, i.e 2022-01
- elif re.match(r"\d{4}-\d{2}$", in_date):
- year, month = map(int, in_date.split("-")) # type: ignore
- if month == 12:
- return date(year, month, 31) # type: ignore
- next_month = date(year, month + 1, 1) # type: ignore
- return date(next_month.year, next_month.month, 1) - timedelta(days=1)
- # Now match if it is yearly, i.e 2022
- elif re.match(r"\d{4}$", in_date):
- return date(int(in_date), 12, 31)
- # If the input date is a year
- if isinstance(in_date, int):
- return date(in_date, 12, 31)
-
- return in_date
-
class OECDUnemploymentFetcher(
Fetcher[OECDUnemploymentQueryParams, List[OECDUnemploymentData]]
@@ -92,13 +109,16 @@ class OECDUnemploymentFetcher(
"""Transform the query."""
transformed_params = params.copy()
if transformed_params["start_date"] is None:
- transformed_params["start_date"] = date(1950, 1, 1)
+ transformed_params["start_date"] = (
+ date(2010, 1, 1)
+ if transformed_params.get("country") == "all"
+ else date(1950, 1, 1)
+ )
if transformed_params["end_date"] is None:
transformed_params["end_date"] = date(date.today().year, 12, 31)
return OECDUnemploymentQueryParams(**transformed_params)
- # pylint: disable=unused-argument
@staticmethod
def extract_data(
query: OECDUnemploymentQueryParams,
@@ -112,49 +132,54 @@ class OECDUnemploymentFetcher(
"total": "Y_GE15",
"15-24": "Y15T24",
"15-64": "Y15T64",
+ "15-74": "Y15T74",
"25-54": "Y25T54",
"55-64": "Y55T64",
}[query.age]
seasonal_adjustment = "Y" if query.seasonal_adjustment else "N"
- country = (
- ""
- if query.country == "all"
- else COUNTRY_TO_CODE_UNEMPLOYMENT[query.country]
- )
- # For caching, include this in the key
- query_dict = {
- k: v
- for k, v in query.__dict__.items()
- if k not in ["start_date", "end_date"]
- }
+ def country_string(input_str: str):
+ if input_str == "all":
+ return ""
+ countries = input_str.split(",")
+ return "+".join(
+ [COUNTRY_TO_CODE_UNEMPLOYMENT[country] for country in countries]
+ )
+
+ country = country_string(query.country)
+ start_date = query.start_date.strftime("%Y-%m") if query.start_date else ""
+ end_date = query.end_date.strftime("%Y-%m") if query.end_date else ""
url = (
- f"https://sdmx.oecd.org/public/rest/data/OECD.SDD.TPS,DSD_LFS@DF_IALFS_INDIC,"
- f"1.0/{country}.UNE_LF...{seasonal_adjustment}.{sex}.{age}..."
+ "https://sdmx.oecd.org/public/rest/data/OECD.SDD.TPS,DSD_LFS@DF_IALFS_UNE_M,1.0/"
+ + f"{country}..._Z.{seasonal_adjustment}.{sex}.{age}..{frequency}"
+ + f"?startPeriod={start_date}&endPeriod={end_date}"
+ + "&dimensionAtObservation=TIME_PERIOD&detail=dataonly"
)
- data = helpers.get_possibly_cached_data(
- url, function="economy_unemployment", query_dict=query_dict
+ headers = {"Accept": "application/vnd.sdmx.data+csv; charset=utf-8"}
+ response = make_request(url, headers=headers, timeout=20)
+ if response.status_code != 200:
+ raise Exception(f"Error: {response.status_code}")
+ df = read_csv(StringIO(response.text)).get(
+ ["REF_AREA", "TIME_PERIOD", "OBS_VALUE"]
)
- url_query = f"AGE=='{age}' & SEX=='{sex}' & FREQ=='{frequency}' & ADJUSTMENT=='{seasonal_adjustment}'"
- url_query = url_query + f" & REF_AREA=='{country}'" if country else url_query
- # Filter down
- data = (
- data.query(url_query)
- .reset_index(drop=True)[["REF_AREA", "TIME_PERIOD", "VALUE"]]
- .rename(
- columns={"REF_AREA": "country", "TIME_PERIOD": "date", "VALUE": "value"}
- )
+ if df.empty:
+ raise EmptyDataError()
+ df = df.rename(
+ columns={"REF_AREA": "country", "TIME_PERIOD": "date", "OBS_VALUE": "value"}
)
- data["country"] = data["country"].map(CODE_TO_COUNTRY_UNEMPLOYMENT)
-
- data["date"] = data["date"].apply(helpers.oecd_date_to_python_date)
- data = data[
- (data["date"] <= query.end_date) & (data["date"] >= query.start_date)
- ]
+ df["value"] = df["value"].astype(float) / 100
+ df["country"] = df["country"].map(CODE_TO_COUNTRY_UNEMPLOYMENT)
+ df["date"] = df["date"].apply(helpers.oecd_date_to_python_date)
+ df = (
+ df.query("value.notnull()")
+ .set_index(["date", "country"])
+ .sort_index()
+ .reset_index()
+ )
+ df = df[(df["date"] <= query.end_date) & (df["date"] >= query.start_date)]
- return data.to_dict(orient="records")
+ return df.to_dict(orient="records")
- # pylint: disable=unused-argument
@staticmethod
def transform_data(
query: OECDUnemploymentQueryParams, data: List[Dict], **kwargs: Any