summaryrefslogtreecommitdiffstats
path: root/openbb_platform/providers/econdb/openbb_econdb/models/yield_curve.py
diff options
context:
space:
mode:
Diffstat (limited to 'openbb_platform/providers/econdb/openbb_econdb/models/yield_curve.py')
-rw-r--r--openbb_platform/providers/econdb/openbb_econdb/models/yield_curve.py195
1 files changed, 195 insertions, 0 deletions
diff --git a/openbb_platform/providers/econdb/openbb_econdb/models/yield_curve.py b/openbb_platform/providers/econdb/openbb_econdb/models/yield_curve.py
new file mode 100644
index 00000000000..3fa44526705
--- /dev/null
+++ b/openbb_platform/providers/econdb/openbb_econdb/models/yield_curve.py
@@ -0,0 +1,195 @@
+"""EconDB Yield Curve."""
+
+# pylint: disable=unused-argument
+
+from datetime import datetime
+from typing import Any, Dict, List, Optional, Union
+
+from openbb_core.provider.abstract.annotated_result import AnnotatedResult
+from openbb_core.provider.abstract.fetcher import Fetcher
+from openbb_core.provider.standard_models.yield_curve import (
+ YieldCurveData,
+ YieldCurveQueryParams,
+)
+from openbb_core.provider.utils.descriptions import QUERY_DESCRIPTIONS
+from openbb_core.provider.utils.errors import EmptyDataError
+from openbb_econdb.utils import helpers
+from openbb_econdb.utils.yield_curves import COUNTRIES, COUNTRIES_DICT, COUNTRIES_LIST
+from pandas import Categorical, DataFrame, DatetimeIndex
+from pydantic import Field, field_validator
+
+
+class EconDbYieldCurveQueryParams(YieldCurveQueryParams):
+ """EconDB Economic Indicators Query."""
+
+ __json_schema_extra__ = {
+ "date": ["multiple_items_allowed"],
+ }
+ country: COUNTRIES = Field(
+ default="united_states",
+ description=QUERY_DESCRIPTIONS.get("country", "")
+ + " New Zealand, Mexico, Singapore, and Thailand have only monthly data."
+ + " The nearest date to the requested one will be used.",
+ json_schema_extra={"choices": COUNTRIES_LIST},
+ )
+ use_cache: bool = Field(
+ default=True,
+ description="If true, cache the request for four hours.",
+ )
+
+ @field_validator("country", mode="before", check_fields=False)
+ @classmethod
+ def validate_country(cls, v) -> str:
+ """Validate the country."""
+ if v is None:
+ return "united_states"
+ if v not in COUNTRIES_DICT:
+ raise ValueError(f"Country must be one of {COUNTRIES_DICT}")
+ return v
+
+
+class EconDbYieldCurveData(YieldCurveData):
+ """EconDB Yield Curve Data."""
+
+
+class EconDbYieldCurveFetcher(
+ Fetcher[EconDbYieldCurveQueryParams, List[EconDbYieldCurveData]]
+):
+ """EconDB Yield Curve Fetcher."""
+
+ require_credentials = False
+
+ @staticmethod
+ def transform_query(params: Dict[str, Any]) -> EconDbYieldCurveQueryParams:
+ """Transform the query parameters."""
+ if not params.get("date"):
+ params["date"] = datetime.now().strftime("%Y-%m-%d")
+ return EconDbYieldCurveQueryParams(**params)
+
+ @staticmethod
+ async def aextract_data( # pylint: disable=R0914.R0912,R0915
+ query: EconDbYieldCurveQueryParams,
+ credentials: Optional[Dict[str, str]],
+ **kwargs: Any,
+ ) -> List[Dict]:
+ """Extract the data."""
+ token = credentials.get("econdb_api_key") if credentials else "" # type: ignore
+ # Attempt to create a temporary token if one is not supplied.
+ if not token:
+ token = await helpers.create_token(use_cache=query.use_cache)
+ credentials.update({"econdb_api_key": token}) # type: ignore
+ base_url = "https://www.econdb.com/api/series/?ticker="
+ symbols = list(COUNTRIES_DICT[query.country].keys())
+ url = (
+ base_url
+ + f"%5B{','.join(symbols)}%5D&page_size=50&format=json&token={token}"
+ )
+ data: List[Dict] = []
+ response: Union[Dict, List[Dict]] = {}
+ if query.use_cache is True:
+ cache_dir = f"{helpers.get_user_cache_directory()}/http/econdb_yield_curve"
+ async with helpers.CachedSession(
+ cache=helpers.SQLiteBackend(
+ cache_dir, expire_after=3600 * 4, ignored_params=["token"]
+ )
+ ) as session:
+ try:
+ response = await helpers.amake_request( # type: ignore
+ url,
+ session=session,
+ timeout=20,
+ **kwargs,
+ )
+ finally:
+ await session.close()
+ else:
+ response = await helpers.amake_request( # type: ignore
+ url, timeout=20, **kwargs
+ )
+ if not response:
+ raise RuntimeError("Error: No data was returned.")
+ data = response.get("results") # type: ignore
+ if not data:
+ raise EmptyDataError("The response was returned empty.")
+ return data
+
+ @staticmethod
+ def transform_data(
+ query: EconDbYieldCurveQueryParams,
+ data: List[Dict],
+ **kwargs: Any,
+ ) -> AnnotatedResult[List[EconDbYieldCurveData]]:
+ """Transform the data."""
+ maturity_order = list(COUNTRIES_DICT[query.country].values())
+ dates = query.date.split(",") # type: ignore
+ dates_list = DatetimeIndex(dates)
+ new_data: Dict = {}
+ metadata: Dict = {}
+ # Unpack the data for each maturity.
+ for item in data:
+ ticker = item.get("ticker")
+ maturity = COUNTRIES_DICT[query.country].get(ticker)
+ dataset = item.get("dataset")
+ additional_metadata = item.get("additional_metadata")
+ units = "percent"
+ description = item.get("description")
+ dates = item["data"]["dates"]
+ rates = item["data"]["values"]
+ meta = {
+ "symbol": ticker,
+ "description": description,
+ "dataset": dataset,
+ "units": units,
+ "additional_metadata": additional_metadata,
+ }
+ metadata[maturity] = meta
+ for d, r in zip(dates, rates):
+ if maturity not in new_data:
+ new_data[maturity] = {}
+ new_data[maturity][d] = r
+
+ # Create a DataFrame from the data
+ df = DataFrame(new_data)
+
+ # Convert the index to a DatetimeIndex
+ df.index = DatetimeIndex(df.index)
+
+ # Find the nearest date in the DataFrame to each date in dates_list
+ def get_nearest_date(df, target_date):
+ differences = (df.index - target_date).to_series().abs()
+ nearest_date_index = differences.argmin()
+ nearest_date = df.index[nearest_date_index]
+ return nearest_date
+
+ nearest_dates = [get_nearest_date(df, date) for date in dates_list]
+
+ # Filter for only the nearest dates
+ df = df[df.index.isin(nearest_dates)]
+
+ # Flatten the DataFrame
+ flattened_data = df.reset_index().melt(
+ id_vars="index", var_name="maturity", value_name="rate"
+ )
+ flattened_data = flattened_data.rename(columns={"index": "date"}).sort_values(
+ "date"
+ )
+ flattened_data["maturity"] = Categorical(
+ flattened_data["maturity"], categories=maturity_order, ordered=True
+ )
+ flattened_data["rate"] = flattened_data["rate"].astype(float) / 100
+ flattened_data = flattened_data.sort_values(
+ by=["date", "maturity"]
+ ).reset_index(drop=True)
+ flattened_data.loc[:, "date"] = flattened_data["date"].dt.strftime("%Y-%m-%d")
+ records = flattened_data.copy()
+ records["rate"] = records["rate"].fillna("N/A").replace("N/A", None)
+ records = records.to_dict("records")
+ results: List[EconDbYieldCurveData] = [
+ EconDbYieldCurveData.model_validate(r) for r in records if r.get("rate")
+ ]
+ if not results:
+ raise EmptyDataError(
+ f"No data was found for the country, {query.country},"
+ f" and dates, {query.date}"
+ )
+ return AnnotatedResult(result=results, metadata=metadata)