summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorDanglewood <85772166+deeleeramone@users.noreply.github.com>2024-06-27 21:59:08 -0700
committerDanglewood <85772166+deeleeramone@users.noreply.github.com>2024-06-27 21:59:08 -0700
commit880822ea00e64c46b5537f45c09d3d9df1d78e6c (patch)
treee2f210f8199dcfe1b60fbffc7b8fbbe2e55f18d3
parentff6dc848d25d86068c48fd1310ef42d601114a03 (diff)
multiple dates allowed for yfinance
-rw-r--r--openbb_platform/core/openbb_core/provider/standard_models/futures_curve.py15
-rw-r--r--openbb_platform/providers/yfinance/openbb_yfinance/models/futures_curve.py49
-rw-r--r--openbb_platform/providers/yfinance/openbb_yfinance/utils/helpers.py343
3 files changed, 310 insertions, 97 deletions
diff --git a/openbb_platform/core/openbb_core/provider/standard_models/futures_curve.py b/openbb_platform/core/openbb_core/provider/standard_models/futures_curve.py
index 266de2e79e2..da8a29fbd64 100644
--- a/openbb_platform/core/openbb_core/provider/standard_models/futures_curve.py
+++ b/openbb_platform/core/openbb_core/provider/standard_models/futures_curve.py
@@ -1,7 +1,7 @@
"""Futures Curve Standard Model."""
from datetime import date as dateType
-from typing import Optional
+from typing import Optional, Union
from pydantic import Field, field_validator
@@ -17,14 +17,14 @@ class FuturesCurveQueryParams(QueryParams):
"""Futures Curve Query."""
symbol: str = Field(description=QUERY_DESCRIPTIONS.get("symbol", ""))
- date: Optional[dateType] = Field(
+ date: Optional[Union[dateType, str]] = Field(
default=None,
description=QUERY_DESCRIPTIONS.get("date", ""),
)
@field_validator("symbol", mode="before", check_fields=False)
@classmethod
- def to_upper(cls, v: str) -> str:
+ def to_upper(cls, v):
"""Convert field to uppercase."""
return v.upper()
@@ -32,7 +32,12 @@ class FuturesCurveQueryParams(QueryParams):
class FuturesCurveData(Data):
"""Futures Curve Data."""
+ date: Optional[dateType] = Field(
+ default=None, description=DATA_DESCRIPTIONS.get("date", "")
+ )
expiration: str = Field(description="Futures expiration month.")
- price: Optional[float] = Field(
- default=None, description=DATA_DESCRIPTIONS.get("close", "")
+ price: float = Field(
+ default=None,
+ description="The priec of the futures contract.",
+ json_schema_extra={"x-unit_measurement": "currency"},
)
diff --git a/openbb_platform/providers/yfinance/openbb_yfinance/models/futures_curve.py b/openbb_platform/providers/yfinance/openbb_yfinance/models/futures_curve.py
index 99307b0570a..ac35121ea2b 100644
--- a/openbb_platform/providers/yfinance/openbb_yfinance/models/futures_curve.py
+++ b/openbb_platform/providers/yfinance/openbb_yfinance/models/futures_curve.py
@@ -1,8 +1,7 @@
"""Yahoo Finance Futures Curve Model."""
-# ruff: noqa: SIM105
+# pylint: disable=unused-argument
-from datetime import datetime
from typing import Any, Dict, List, Optional
from openbb_core.provider.abstract.fetcher import Fetcher
@@ -11,21 +10,22 @@ from openbb_core.provider.standard_models.futures_curve import (
FuturesCurveQueryParams,
)
from openbb_core.provider.utils.errors import EmptyDataError
-from openbb_yfinance.utils.helpers import get_futures_curve
class YFinanceFuturesCurveQueryParams(FuturesCurveQueryParams):
"""Yahoo Finance Futures Curve Query.
- Source: https://finance.yahoo.com/crypto/
+ Source: https://finance.yahoo.com/
"""
+ __json_schema_extra__ = {
+ "date": {"multiple_items_allowed": True},
+ }
+
class YFinanceFuturesCurveData(FuturesCurveData):
"""Yahoo Finance Futures Curve Data."""
- __alias_dict__ = {"price": "Last Price"}
-
class YFinanceFuturesCurveFetcher(
Fetcher[
@@ -33,27 +33,26 @@ class YFinanceFuturesCurveFetcher(
List[YFinanceFuturesCurveData],
]
):
- """Transform the query, extract and transform the data from the Yahoo Finance endpoints."""
+ """YFiannce Futures Curve Fetcher."""
@staticmethod
def transform_query(params: Dict[str, Any]) -> YFinanceFuturesCurveQueryParams:
"""Transform the query."""
- transformed_params = params
-
- now = datetime.now().date()
- if params.get("date") is None:
- transformed_params["date"] = now
-
- return YFinanceFuturesCurveQueryParams(**transformed_params)
+ return YFinanceFuturesCurveQueryParams(**params)
@staticmethod
- def extract_data(
- query: YFinanceFuturesCurveQueryParams, # pylint: disable=unused-argument
+ async def aextract_data(
+ query: YFinanceFuturesCurveQueryParams,
credentials: Optional[Dict[str, str]],
**kwargs: Any,
- ) -> List[dict]:
- """Return the raw data from the Yahoo Finance endpoint."""
- data = get_futures_curve(query.symbol, query.date).to_dict(orient="records")
+ ) -> List[Dict]:
+ """Extract the data from Yahoo."""
+ # pylint: disable=import-outside-toplevel
+ from openbb_yfinance.utils.helpers import get_futures_curve
+
+ # TODO: Find a better way to do this.
+ data = await get_futures_curve(query.symbol, query.date) # type: ignore
+ data = data.to_dict(orient="records")
if not data:
raise EmptyDataError()
@@ -62,15 +61,9 @@ class YFinanceFuturesCurveFetcher(
@staticmethod
def transform_data(
- query: YFinanceFuturesCurveQueryParams, # pylint: disable=unused-argument
- data: dict,
+ query: YFinanceFuturesCurveQueryParams,
+ data: List[Dict],
**kwargs: Any,
) -> List[YFinanceFuturesCurveData]:
"""Transform the data to the standard format."""
- return [
- YFinanceFuturesCurveData(
- expiration=curve["expiration"],
- price=curve["Last Price"],
- )
- for curve in data
- ]
+ return [YFinanceFuturesCurveData.model_validate(d) for d in data]
diff --git a/openbb_platform/providers/yfinance/openbb_yfinance/utils/helpers.py b/openbb_platform/providers/yfinance/openbb_yfinance/utils/helpers.py
index 0965aefbd54..e36af38ca93 100644
--- a/openbb_platform/providers/yfinance/openbb_yfinance/utils/helpers.py
+++ b/openbb_platform/providers/yfinance/openbb_yfinance/utils/helpers.py
@@ -1,26 +1,130 @@
"""Yahoo Finance helpers module."""
-# pylint: disable=unused-argument
-from datetime import (
- date as dateType,
- datetime,
-)
-from pathlib import Path
-from typing import Any, Literal, Optional, Union
-
-import pandas as pd
-import yfinance as yf
-from dateutil.relativedelta import relativedelta
+# pylint: disable=unused-argument,too-many-arguments,too-many-branches,too-many-locals,too-many-statements
+
+from datetime import datetime
+from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Union
+
from openbb_core.provider.utils.errors import EmptyDataError
from openbb_yfinance.utils.references import INTERVALS, MONTHS, PERIODS
+if TYPE_CHECKING:
+ from datetime import date # noqa
+ from pandas import DataFrame
+
+
+MONTH_MAP = {
+ "F": "01",
+ "G": "02",
+ "H": "03",
+ "J": "04",
+ "K": "05",
+ "M": "06",
+ "N": "07",
+ "Q": "08",
+ "U": "09",
+ "V": "10",
+ "X": "11",
+ "Z": "12",
+}
+
+
+def get_expiration_month(symbol: str) -> str:
+ """Get the expiration month for a given symbol."""
+ month = symbol.split(".")[0][-3]
+ year = "20" + symbol.split(".")[0][-2:]
+ return f"{year}-{MONTH_MAP[month]}"
-def get_futures_data() -> pd.DataFrame:
+
+def get_futures_data() -> "DataFrame":
"""Return the dataframe of the futures csv file."""
- return pd.read_csv(Path(__file__).resolve().parent / "futures.csv")
+ # pylint: disable=import-outside-toplevel
+ from pathlib import Path # noqa
+ from pandas import read_csv # noqa
+
+ return read_csv(Path(__file__).resolve().parent / "futures.csv")
+
+
+def get_futures_symbols(symbol: str) -> List:
+ """Get the list of futures symbols from the continuation symbol."""
+ # pylint: disable=import-outside-toplevel
+ from yfinance.data import YfData
+
+ _symbol = symbol.upper() + "%3DF"
+ URL = f"https://query2.finance.yahoo.com/v10/finance/quoteSummary/{_symbol}"
+ params = {"modules": "futuresChain"}
+ response: Dict = YfData(session=None).get_raw_json(url=URL, params=params)
+ futures_symbols: List = []
+ if "quoteSummary" in response:
+ result = response["quoteSummary"].get("result", [])
+ if not result:
+ raise ValueError(f"No futures chain found for, {symbol}")
+ futures = result[0].get("futuresChain", {})
+ if futures:
+ futures_symbols = futures.get("futures", [])
+ return futures_symbols
+
+
+async def get_futures_quotes(symbols: List) -> "DataFrame":
+ """Get the current futures quotes for a list of symbols."""
+ # pylint: disable=import-outside-toplevel
+ import os # noqa
+ from contextlib import (
+ contextmanager,
+ redirect_stderr,
+ redirect_stdout,
+ suppress,
+ ) # noqa
+ from aiohttp import ClientError # noqa
+ from openbb_yfinance.models.equity_quote import YFinanceEquityQuoteFetcher # noqa
+ from pandas import DataFrame # noqa
+
+ @contextmanager
+ def suppress_all_output():
+ with open(os.devnull, "w") as devnull, redirect_stdout(
+ devnull
+ ), redirect_stderr(devnull):
+ yield
+
+ with suppress_all_output(), suppress(ClientError):
+ fetcher = YFinanceEquityQuoteFetcher()
+ data = await fetcher.fetch_data(
+ params={"symbol": ",".join(symbols)}, credentials={}
+ )
+
+ df = DataFrame([d.model_dump() for d in data]) # type: ignore
+ prices = df[["symbol", "bid", "ask", "prev_close"]].copy()
+ prices.loc[:, "price"] = round((prices.ask + prices.bid) / 2, 2)
+ prices.price = prices.price.fillna(prices.prev_close)
+ prices["expiration"] = [get_expiration_month(symbol) for symbol in prices.symbol]
+
+ return prices[["expiration", "price"]]
+
+
+async def get_historical_futures_prices(
+ symbols: List, start_date: "date", end_date: "date"
+):
+ """Get historical futures prices for the list of symbols."""
+ # pylint: disable=import-outside-toplevel
+ from openbb_yfinance.models.equity_historical import ( # noqa
+ YFinanceEquityHistoricalFetcher,
+ )
+
+ fetcher = YFinanceEquityHistoricalFetcher()
+
+ return await fetcher.fetch_data(
+ params={
+ "symbol": ",".join(symbols),
+ "start_date": start_date,
+ "end_date": end_date,
+ },
+ credentials={},
+ )
-def get_futures_curve(symbol: str, date: Optional[dateType]) -> pd.DataFrame:
+async def get_futures_curve( # pylint: disable=too-many-return-statements
+ symbol: str, date: Optional[Union[str, List]] = None
+) -> "DataFrame":
"""Get the futures curve for a given symbol.
Parameters
@@ -32,60 +136,167 @@ def get_futures_curve(symbol: str, date: Optional[dateType]) -> pd.DataFrame:
Returns
-------
- pd.DataFrame
+ DataFrame
DataFrame with futures curve
"""
- futures_data = get_futures_data()
- try:
- exchange = futures_data[futures_data["Ticker"] == symbol]["Exchange"].values[0]
- except IndexError:
- return pd.DataFrame({"Last Price": [], "expiration": []})
-
- today = datetime.today()
- futures_index = []
- futures_curve = []
- historical_curve = []
- i = 0
- empty_count = 0
- # Loop through until we find 12 consecutive empty months
- while empty_count < 12:
- future = today + relativedelta(months=i)
- future_symbol = (
- f"{symbol}{MONTHS[future.month]}{str(future.year)[-2:]}.{exchange}"
+ # pylint: disable=import-outside-toplevel
+ from datetime import date as dateType # noqa
+ from dateutil.relativedelta import relativedelta # noqa
+ from pandas import Categorical, DataFrame, DatetimeIndex, to_datetime # noqa
+
+ futures_symbols = get_futures_symbols(symbol)
+ today = datetime.today().date()
+ dates: List = []
+ if date:
+ if isinstance(date, dateType):
+ date = date.strftime("%Y-%m-%d")
+ if isinstance(date, list) and isinstance(date[0], dateType):
+ date = [d.strftime("%Y-%m-%d") for d in date]
+ dates = date.split(",") if isinstance(date, str) else date
+ dates = sorted([to_datetime(d).date() for d in dates])
+
+ if futures_symbols and (not date or len(dates) == 1 and dates[0] >= today):
+ futures_quotes = await get_futures_quotes(futures_symbols)
+ return futures_quotes
+
+ if dates and futures_symbols:
+ historical_futures_prices = await get_historical_futures_prices(
+ futures_symbols, dates[0], dates[-1]
+ )
+ df = DataFrame([d.model_dump() for d in historical_futures_prices]) # type: ignore
+ df = df.set_index("date").sort_index()
+ df.index = df.index.astype(str)
+ df.index = DatetimeIndex(df.index)
+ dates_list = DatetimeIndex(dates)
+ symbols = df.symbol.unique().tolist()
+ expiration_dict = {symbol: get_expiration_month(symbol) for symbol in symbols}
+ df = (
+ df.reset_index()
+ .pivot(columns="symbol", values="close", index="date") # type: ignore
+ .copy()
)
- data = yf.download(future_symbol, progress=False, ignore_tz=True, threads=False)
+ df = df.rename(columns=expiration_dict)
+ df.columns.name = "expiration"
- if data.empty:
- empty_count += 1
+ # Find the nearest date in the DataFrame to each date in dates_list
+ nearest_dates = [df.index.asof(date) for date in dates_list]
- else:
- empty_count = 0
- futures_index.append(future.strftime("%b-%Y"))
- futures_curve.append(data["Adj Close"].values[-1])
- if date is not None:
- historical_curve.append(
- data["Adj Close"].get(date.strftime("%Y-%m-%d"), None)
+ # Filter for only the nearest dates
+ df = df[df.index.isin(nearest_dates)]
+
+ df = df.fillna("N/A").replace("N/A", None)
+
+ # Flatten the DataFrame
+ flattened_data = df.reset_index().melt(
+ id_vars="date", var_name="expiration", value_name="price"
+ )
+ flattened_data = flattened_data.sort_values("date")
+ flattened_data["expiration"] = Categorical(
+ flattened_data["expiration"],
+ categories=sorted(list(expiration_dict.values())),
+ ordered=True,
+ )
+ flattened_data = flattened_data.sort_values(
+ by=["date", "expiration"]
+ ).reset_index(drop=True)
+ flattened_data.loc[:, "date"] = flattened_data["date"].dt.strftime("%Y-%m-%d")
+
+ return flattened_data
+
+ if not futures_symbols:
+ # pylint: disable=import-outside-toplevel
+ import os # noqa
+ from contextlib import contextmanager, redirect_stderr, redirect_stdout # noqa
+
+ futures_data = get_futures_data()
+ try:
+ exchange = futures_data[futures_data["Ticker"] == symbol][
+ "Exchange"
+ ].values[0]
+ except IndexError as exc:
+ raise ValueError(f"Symbol {symbol} was not found.") from exc
+
+ futures_index: List = []
+ futures_curve: List = []
+ futures_date: List = []
+ historical_curve: List = []
+ if dates:
+ dates = [d.strftime("%Y-%m-%d") for d in dates]
+ dates_list = DatetimeIndex(dates)
+
+ i = 0
+ empty_count = 0
+
+ @contextmanager
+ def suppress_all_output():
+ with open(os.devnull, "w") as devnull, redirect_stdout(
+ devnull
+ ), redirect_stderr(devnull):
+ yield
+
+ with suppress_all_output():
+ while empty_count < 12:
+ future = today + relativedelta(months=i)
+ future_symbol = (
+ f"{symbol}{MONTHS[future.month]}{str(future.year)[-2:]}.{exchange}"
)
+ data = yf_download(future_symbol)
+ if data.empty:
+ empty_count += 1
+ else:
+ empty_count = 0
+ if dates:
+ data = data.set_index("date").sort_index()
+ data.index = DatetimeIndex(data.index)
+ nearest_dates = [data.index.asof(date) for date in dates_list]
+ data = data[data.index.isin(nearest_dates)]
+ data.index = data.index.strftime("%Y-%m-%d")
+ for date in dates:
+ try:
+ historical_curve.append(data.loc[date, "close"])
+ futures_date.append(date)
+ futures_index.append(future.strftime("%Y-%m"))
+ except KeyError:
+ historical_curve.append(None)
+ else:
+ futures_index.append(future.strftime("%Y-%m"))
+ futures_curve.append(
+ data.query("close.notnull()")["close"].values[-1]
+ )
- i += 1
+ i += 1
- if not futures_index:
- return pd.DataFrame({"date": [], "Last Price": []})
+ if not futures_index:
+ raise EmptyDataError()
- if historical_curve:
- return pd.DataFrame(
- {"Last Price": historical_curve, "expiration": futures_index}
- )
- return pd.DataFrame({"Last Price": futures_curve, "expiration": futures_index})
+ if historical_curve:
+ df = DataFrame(
+ {
+ "date": futures_date,
+ "price": historical_curve,
+ "expiration": futures_index,
+ }
+ )
+ df["expiration"] = Categorical(
+ df["expiration"],
+ categories=sorted(list(set(futures_index))),
+ ordered=True,
+ )
+ df = df.sort_values(by=["date", "expiration"]).reset_index(drop=True)
+ if len(df.date.unique()) == 1:
+ df = df.drop(columns=["date"])
+
+ return df
+
+ return DataFrame({"price": futures_curve, "expiration": futures_index})
-# pylint: disable=too-many-arguments,unused-argument
def yf_download(
symbol: str,
- start_date: Optional[Union[str, dateType]] = None,
- end_date: Optional[Union[str, dateType]] = None,
+ start_date: Optional[Union[str, "date"]] = None,
+ end_date: Optional[Union[str, "date"]] = None,
interval: INTERVALS = "1d",
- period: PERIODS = "max",
+ period: Optional[PERIODS] = None,
prepost: bool = False,
actions: bool = False,
progress: bool = False,
@@ -96,8 +307,13 @@ def yf_download(
group_by: Literal["ticker", "column"] = "ticker",
adjusted: bool = False,
**kwargs: Any,
-) -> pd.DataFrame:
+) -> "DataFrame":
"""Get yFinance OHLC data for any ticker and interval available."""
+ # pylint: disable=import-outside-toplevel
+ from dateutil.relativedelta import relativedelta # noqa
+ import yfinance as yf
+ from pandas import DataFrame, concat, to_datetime
+
symbol = symbol.upper()
_start_date = start_date
intraday = False
@@ -118,7 +334,7 @@ def yf_download(
intraday = True
if adjusted is False:
- kwargs = dict(auto_adjust=False, back_adjust=False)
+ kwargs = dict(auto_adjust=False, back_adjust=False, period=period)
try:
data = yf.download(
@@ -126,7 +342,6 @@ def yf_download(
start=_start_date,
end=None,
interval=interval,
- period=period,
prepost=prepost,
actions=actions,
progress=progress,
@@ -143,7 +358,7 @@ def yf_download(
tickers = symbol.split(",")
if len(tickers) > 1:
- _data = pd.DataFrame()
+ _data = DataFrame()
for ticker in tickers:
temp = data[ticker].copy().dropna(how="all")
if len(temp) > 0:
@@ -151,7 +366,7 @@ def yf_download(
temp = temp.reset_index().rename(
columns={"Date": "date", "Datetime": "date", "index": "date"}
)
- _data = pd.concat([_data, temp])
+ _data = concat([_data, temp])
if not _data.empty:
index_keys = ["date", "symbol"] if "symbol" in _data.columns else "date"
_data = _data.set_index(index_keys).sort_index()
@@ -159,19 +374,19 @@ def yf_download(
if not data.empty:
data = data.reset_index()
data = data.rename(columns={"Date": "date", "Datetime": "date"})
- data["date"] = data["date"].apply(pd.to_datetime)
+ data["date"] = data["date"].apply(to_datetime)
data = data[data["Open"] > 0]
if start_date is not None:
- data = data[data["date"] >= pd.to_datetime(start_date)]
+ data = data[data["date"] >= to_datetime(start_date)] # type: ignore
if (
end_date is not None
and start_date is not None
- and pd.to_datetime(end_date) > pd.to_datetime(start_date)
+ and to_datetime(end_date) > to_datetime(start_date) # type: ignore
):
data = data[
data["date"]
<= (
- pd.to_datetime(end_date)
+ to_datetime(end_date) # type: ignore
+ relativedelta(minutes=719 if intraday is True else 0)
)
]
@@ -185,7 +400,7 @@ def yf_download(
return data
-def df_transform_numbers(data: pd.DataFrame, columns: list) -> pd.DataFrame:
+def df_transform_numbers(data: "DataFrame", columns: list) -> "DataFrame":
"""Replace abbreviations of numbers with actual numbers."""
multipliers = {"M": 1e6, "B": 1e9, "T": 1e12}