summaryrefslogtreecommitdiffstats
path: root/openbb_platform/providers/oecd/openbb_oecd/models/short_term_interest_rate.py
blob: f06dd92910170216e8d69c8d7f762164d4403aa4 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
"""OECD Short Term Interest Rate Rate Data."""

# pylint: disable=unused-argument

import re
from datetime import date, timedelta
from typing import Any, Dict, List, Literal, Optional, Union

from openbb_core.provider.abstract.fetcher import Fetcher
from openbb_core.provider.standard_models.short_term_interest_rate import (
    STIRData,
    STIRQueryParams,
)
from openbb_oecd.utils import helpers
from pydantic import Field, field_validator

# OECD REF_AREA codes, as found in the OECD response, mapped to the snake_case
# country names this provider exposes. Insertion order matters: it fixes the
# order of the values inside ``CountriesLiteral``.
stir_mapping = {
    "BEL": "belgium",
    "IRL": "ireland",
    "MEX": "mexico",
    "IDN": "indonesia",
    "NZL": "new_zealand",
    "JPN": "japan",
    "GBR": "united_kingdom",
    "FRA": "france",
    "CHL": "chile",
    "CAN": "canada",
    "NLD": "netherlands",
    "USA": "united_states",
    "KOR": "south_korea",
    "NOR": "norway",
    "AUT": "austria",
    "ZAF": "south_africa",
    "DNK": "denmark",
    "CHE": "switzerland",
    "HUN": "hungary",
    "LUX": "luxembourg",
    "AUS": "australia",
    "DEU": "germany",
    "SWE": "sweden",
    "ISL": "iceland",
    "TUR": "turkey",
    "GRC": "greece",
    "ISR": "israel",
    "CZE": "czech_republic",
    "LVA": "latvia",
    "SVN": "slovenia",
    "POL": "poland",
    "EST": "estonia",
    "LTU": "lithuania",
    "PRT": "portugal",
    "CRI": "costa_rica",
    "SVK": "slovakia",
    "FIN": "finland",
    "ESP": "spain",
    "RUS": "russia",
    "EA19": "euro_area19",
    "COL": "colombia",
    "ITA": "italy",
    "IND": "india",
    "CHN": "china",
    "HRV": "croatia",
}

# Every accepted ``country`` value, followed by the "all" sentinel, which
# disables the per-country filter in the fetcher.
countries = (*stir_mapping.values(), "all")
CountriesLiteral = Literal[countries]  # type: ignore
# Reverse lookup: provider country name -> OECD REF_AREA code.
country_to_code = dict(zip(stir_mapping.values(), stir_mapping))


class OECDSTIRQueryParams(STIRQueryParams):
    """OECD Short Term Interest Rate Query."""

    # One of the names in ``CountriesLiteral``; "all" skips the country filter.
    country: CountriesLiteral = Field(
        description="Country to get interest rate for.", default="united_states"
    )

    # The fetcher uses the upper-cased first letter ("M", "Q", "A") as the
    # OECD FREQ code.
    frequency: Literal["monthly", "quarterly", "annual"] = Field(
        description="Frequency to get interest rate for.", default="monthly"
    )


class OECDSTIRData(STIRData):
    """Data model for OECD short-term interest rate observations."""

    @field_validator("date", mode="before")
    @classmethod
    def date_validate(cls, in_date: Union[date, str]):  # pylint: disable=E0213
        """Map an OECD period label onto the last calendar day of that period.

        Recognised string shapes are ``YYYY-Qn`` (quarter), ``YYYY-MM``
        (month) and ``YYYY`` (year). A bare ``int`` is treated as a year.
        Every other input, including ``date`` objects and unmatched strings,
        is handed back unchanged for pydantic to parse.
        """
        # Integer years close on December 31st.
        if isinstance(in_date, int):
            return date(in_date, 12, 31)
        if not isinstance(in_date, str):
            return in_date

        # Quarterly label, e.g. "2022-Q1".
        if re.match(r"\d{4}-Q[1-4]$", in_date):
            year_part, quarter_part = in_date.split("-")
            quarter_end = {
                "Q1": (3, 31),
                "Q2": (6, 30),
                "Q3": (9, 30),
                "Q4": (12, 31),
            }.get(quarter_part)
            # The regex's `$` also accepts a trailing newline; such a label
            # has no quarter-end entry and is passed through untouched.
            if quarter_end is None:
                return in_date
            return date(int(year_part), *quarter_end)

        # Monthly label, e.g. "2022-01": the day before the next month starts.
        if re.match(r"\d{4}-\d{2}$", in_date):
            year, month = (int(piece) for piece in in_date.split("-"))
            if month == 12:
                return date(year, month, 31)
            return date(year, month + 1, 1) - timedelta(days=1)

        # Annual label, e.g. "2022".
        if re.match(r"\d{4}$", in_date):
            return date(int(in_date), 12, 31)

        return in_date


class OECDSTIRFetcher(Fetcher[OECDSTIRQueryParams, List[OECDSTIRData]]):
    """Transform the query, extract and transform the data from the OECD endpoints."""

    @staticmethod
    def transform_query(params: Dict[str, Any]) -> OECDSTIRQueryParams:
        """Transform the query.

        A missing or ``None`` ``start_date`` becomes 1950-01-01. A missing or
        ``None`` ``end_date`` becomes December 31st of the current year.
        """
        transformed_params = params.copy()
        # `.get` so callers that omit the keys entirely receive the defaults
        # instead of a KeyError.
        if transformed_params.get("start_date") is None:
            transformed_params["start_date"] = date(1950, 1, 1)
        if transformed_params.get("end_date") is None:
            transformed_params["end_date"] = date(date.today().year, 12, 31)

        return OECDSTIRQueryParams(**transformed_params)

    @staticmethod
    def extract_data(
        query: OECDSTIRQueryParams,
        credentials: Optional[Dict[str, str]],
        **kwargs: Any,
    ) -> List[Dict]:
        """Return the raw OECD records for the requested frequency and country.

        The whole IR3TIB series is obtained through
        ``helpers.get_possibly_cached_data`` and then filtered locally. The
        filter is on ``FREQ`` and, unless ``query.country == "all"``, also on
        ``REF_AREA``. REF_AREA codes are translated back to names via
        ``stir_mapping``. Unmapped codes and missing values become ``None``.
        The date window is applied later, in ``transform_data``.
        """
        frequency = query.frequency[0].upper()  # "M", "Q" or "A"
        country = "" if query.country == "all" else country_to_code[query.country]
        url = "https://sdmx.oecd.org/public/rest/data/OECD.SDD.STES,DSD_KEI@DF_KEI,4.0/..IR3TIB...."
        data = helpers.get_possibly_cached_data(
            url, function="economy_short_term_interest_rate"
        )
        # Boolean masks rather than a `DataFrame.query` string. This also
        # avoids rebinding the `query` argument to a str partway through.
        mask = data["FREQ"] == frequency
        if country:
            mask &= data["REF_AREA"] == country
        data = (
            data[mask]
            .reset_index(drop=True)[["REF_AREA", "TIME_PERIOD", "VALUE"]]
            .rename(
                columns={"REF_AREA": "country", "TIME_PERIOD": "date", "VALUE": "value"}
            )
        )
        data["country"] = data["country"].map(stir_mapping)
        data = data.fillna("N/A").replace("N/A", None)
        return data.to_dict(orient="records")

    @staticmethod
    def transform_data(
        query: OECDSTIRQueryParams, data: List[Dict], **kwargs: Any
    ) -> List[OECDSTIRData]:
        """Validate the raw records and keep only those inside the date window.

        Filtering happens after validation, so the comparison uses the
        period-end ``date`` produced by ``OECDSTIRData.date_validate``. A
        ``None`` bound (possible when ``transform_query`` was bypassed) leaves
        that side of the window open.
        """
        start = query.start_date
        end = query.end_date
        results: List[OECDSTIRData] = []
        for record in data:
            item = OECDSTIRData.model_validate(record)
            if start is not None and item.date < start:
                continue
            if end is not None and item.date > end:
                continue
            results.append(item)
        return results