diff options
author | DidierRLopes <dro.lopes@campus.fct.unl.pt> | 2023-08-25 10:28:34 -0500 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-08-25 15:28:34 +0000 |
commit | 1036308f358babb1cabf9ca059df488701518af9 (patch) | |
tree | d72185426c19a4c713dc0c3009ccf96bd0d04f2d | |
parent | 141dd68e68b6cb76abdd1a9ef5c16bf545d020a8 (diff) |
1st integration with TimeGPT-1 (Beta), from Nixtla (#5292)
* integrate time gpt from Nixtla
* add to reqs
* add residuals to timegpt
* fix some bugs
* remove twitter functions
* remove twitter functions and all references to it
* start work to integrate date features
* add date_features to timegpt
* remove twitter functions
* remove twitter functions and all references to it
* Update deps for timegpt + askobb
* Update forecast load to set an index to date
* lint
* spell
* mypy
* mypy2
* black
* pylint
---------
Co-authored-by: James Maslek <jmaslek11@gmail.com>
-rw-r--r-- | openbb_terminal/forecast/forecast_controller.py | 130 | ||||
-rw-r--r-- | openbb_terminal/forecast/helpers.py | 13 | ||||
-rw-r--r-- | openbb_terminal/forecast/timegpt_model.py | 91 | ||||
-rw-r--r-- | openbb_terminal/forecast/timegpt_view.py | 244 | ||||
-rw-r--r-- | openbb_terminal/helper_funcs.py | 26 | ||||
-rw-r--r-- | openbb_terminal/keys_controller.py | 30 | ||||
-rw-r--r-- | openbb_terminal/keys_model.py | 70 | ||||
-rw-r--r-- | openbb_terminal/miscellaneous/i18n/en.yml | 1 | ||||
-rw-r--r-- | openbb_terminal/miscellaneous/models/all_api_keys.json | 6 | ||||
-rw-r--r-- | openbb_terminal/miscellaneous/models/hub_credentials.json | 7 | ||||
-rw-r--r-- | openbb_terminal/parent_classes.py | 7 | ||||
-rw-r--r-- | openbb_terminal/stocks/stocks_controller.py | 5 | ||||
-rw-r--r-- | poetry.lock | 612 | ||||
-rw-r--r-- | pyproject.toml | 8 | ||||
-rw-r--r-- | requirements-full.txt | 1 | ||||
-rw-r--r-- | requirements.txt | 1 |
16 files changed, 740 insertions, 512 deletions
diff --git a/openbb_terminal/forecast/forecast_controller.py b/openbb_terminal/forecast/forecast_controller.py index 218fcd66d40..ad223eac852 100644 --- a/openbb_terminal/forecast/forecast_controller.py +++ b/openbb_terminal/forecast/forecast_controller.py @@ -72,6 +72,10 @@ from openbb_terminal.helper_funcs import ( log_and_raise, valid_date, parse_and_split_input, + check_non_negative, + check_positive_float_list, + check_list_values, + check_valid_date, ) from openbb_terminal.menu import session @@ -103,6 +107,7 @@ from openbb_terminal.forecast import ( theta_view, trans_view, whisper_model, + timegpt_view, ) logger = logging.getLogger(__name__) @@ -166,6 +171,7 @@ class ForecastController(BaseController): "nhits", "anom", "whisper", + "timegpt", ] pandas_plot_choices = [ "line", @@ -268,6 +274,7 @@ class ForecastController(BaseController): Overrides the parent class function to handle YouTube video URL conventions. See `BaseController.parse_input()` for details. """ + # Filtering out YouTube video parameters like "v=" and removing the domain name youtube_filter = r"(youtube\.com/watch\?v=)" @@ -304,6 +311,7 @@ class ForecastController(BaseController): def print_help(self): """Print help""" + self.update_runtime_choices() current_user = get_current_user() mt = MenuText("forecast/") mt.add_param("_disclaimer_", self.disclaimer) @@ -364,6 +372,7 @@ class ForecastController(BaseController): mt.add_cmd("anom", self.files) mt.add_raw("\n") mt.add_info("_misc_") + mt.add_cmd("timegpt", self.files) mt.add_cmd("whisper", WHISPER_AVAILABLE) console.print(text=mt.menu_text, menu="Forecast") @@ -709,7 +718,7 @@ class ForecastController(BaseController): """Loads news dataframes into memory""" # check if data has minimum number of rows - if len(data) < self.MINIMUM_DATA_LENGTH: + if ticker and len(data) < self.MINIMUM_DATA_LENGTH: console.print( f"[red]Dataset is smaller than recommended minimum {self.MINIMUM_DATA_LENGTH} data points. 
[/red]" ) @@ -732,6 +741,9 @@ class ForecastController(BaseController): # if we import a custom dataset, remove the old index "unnamed:_0" if "unnamed:_0" in data.columns: + # Some loaded datasets have the date as unnamed, which is not helpful + if check_valid_date(data["unnamed:_0"].iloc[0]): + data["date"] = data["unnamed:_0"].copy() data = data.drop(columns=["unnamed:_0"]) self.files.append(ticker) @@ -3441,3 +3453,119 @@ class ForecastController(BaseController): breaklines=ns_parser.breaklines, output_dir=ns_parser.save, ) + + # TimeGPT Model + @log_start_end(log=logger) + def call_timegpt(self, other_args: List[str]): + """Process expo command""" + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + add_help=False, + prog="timegpt", + description=""" + TODO: Update me + """, + ) + parser.add_argument( + "--horizon", + action="store", + dest="horizon", + type=check_positive, + default=12, + help="Forecasting horizon", + ) + parser.add_argument( + "--freq", + action="store", + dest="freq", + choices=["H", "D", "W", "M", "MS", "B"], + default=None, + help="Frequency of the data.", + ) + parser.add_argument( + "--finetune", + action="store", + dest="finetune", + type=check_non_negative, + default=0, + help="Number of steps used to finetune TimeGPT in the new data.", + ) + parser.add_argument( + "--ci", + action="store", + dest="confidence", + type=check_positive_float_list, + default=[80, 90], + help="Number of steps used to finetune TimeGPT in the new data.", + ) + parser.add_argument( + "--cleanex", + action="store_false", + help="Clean exogenous signal before making forecasts using TimeGPT.", + dest="cleanex", + default=True, + ) + parser.add_argument( + "--timecol", + action="store", + dest="timecol", + default="ds", + type=str, + help="Dataframe column that represents datetime", + ) + parser.add_argument( + "--targetcol", + action="store", + dest="targetcol", + default="y", + type=str, + help="Dataframe column that 
represents the target to forecast for", + ) + parser.add_argument( + "--sheet-name", + help="The name of the sheet to export to when type is XLSX.", + dest="sheet_name", + type=str, + default="", + ) + parser.add_argument( + "--datefeatures", + help="Specifies which date attributes have highest weight according to model.", + dest="date_features", + type=check_list_values(["auto", "year", "month", "week", "day", "weekday"]), + default=[], + ) + parser = self.add_standard_args( + parser, + target_dataset=True, + target_column=True, + start=True, + end=True, + residuals=True, + ) + if other_args and "-" not in other_args[0][0]: + other_args.insert(0, "--dataset") + + ns_parser = self.parse_known_args_and_warn( + parser, + other_args, + export_allowed=EXPORT_ONLY_FIGURES_ALLOWED, + ) + if ns_parser: + timegpt_view.display_timegpt_forecast( + data=self.datasets[ns_parser.target_dataset], + dataset_name=ns_parser.target_dataset, + time_col=ns_parser.timecol, + target_col=ns_parser.targetcol, + forecast_horizon=ns_parser.horizon, + freq=ns_parser.freq, + levels=ns_parser.confidence, + finetune_steps=ns_parser.finetune, + clean_ex_first=ns_parser.cleanex, + export=ns_parser.export, + sheet_name=ns_parser.sheet_name, + start_date=ns_parser.s_start_date, + end_date=ns_parser.s_end_date, + residuals=ns_parser.residuals, + date_features=ns_parser.date_features, + ) diff --git a/openbb_terminal/forecast/helpers.py b/openbb_terminal/forecast/helpers.py index da42d8309fd..cfefcb8191f 100644 --- a/openbb_terminal/forecast/helpers.py +++ b/openbb_terminal/forecast/helpers.py @@ -1253,9 +1253,15 @@ def filter_dates( console.print("[red]The start date must be before the end date.[/red]\n") return data if end_date: - data = data[data["date"] <= end_date] + if isinstance(data["date"].values[0], str): + data = data[pd.to_datetime(data["date"]) <= end_date] + else: + data = data[data["date"] <= end_date] if start_date: - data = data[data["date"] >= start_date] + if 
isinstance(data["date"].values[0], str): + data = data[pd.to_datetime(data["date"]) >= start_date] + else: + data = data[data["date"] >= start_date] return data @@ -1341,6 +1347,9 @@ def check_data( " Change the 'target_column' parameter.[/red]\n" ) return False + if data.empty: + console.print("[red]The data provided is empty.[/red]\n") + return False if past_covariates is not None: covariates = past_covariates.split(",") for covariate in covariates: diff --git a/openbb_terminal/forecast/timegpt_model.py b/openbb_terminal/forecast/timegpt_model.py new file mode 100644 index 00000000000..d973329e59f --- /dev/null +++ b/openbb_terminal/forecast/timegpt_model.py @@ -0,0 +1,91 @@ +# pylint: disable=too-many-arguments +"""Probabilistic Exponential Smoothing Model""" +__docformat__ = "numpy" + +import logging +from typing import List, Optional, Union + +import numpy +import pandas as pd +from nixtlats import TimeGPT + +from openbb_terminal.core.session.current_user import get_current_user +from openbb_terminal.decorators import check_api_key, log_start_end + +logger = logging.getLogger(__name__) + + +@log_start_end(log=logger) +@check_api_key(["API_KEY_NIXTLA"]) +def get_timegpt_model( + data: Union[pd.Series, pd.DataFrame], + time_col: str = "ds", + target_col: str = "y", + forecast_horizon: int = 12, + levels: Optional[List[float]] = None, + freq: Union[str, None] = None, + finetune_steps: int = 0, + clean_ex_first: bool = True, + residuals: bool = False, + date_features: Optional[List[str]] = None, +) -> pd.DataFrame: + """TimeGPT was trained on the largest collection of data in history - + over 100 billion rows of financial, weather, energy, and web data - + and democratizes the power of time-series analysis. + + Parameters + ---------- + data : Union[pd.Series, pd.DataFrame] + Input data. + time_col: str: + Column that identifies each timestep, its values can be timestamps or integers. Defaults to "ds". + target_column: str: + Target column to forecast. 
Defaults to "y". + forecast_horizon: int + Number of days to forecast. Defaults to 12. + levels: List[float] + Confidence levels between 0 and 100 for prediction intervals. + freq: Optional[str, None] + Frequency of the data. By default, the freq will be inferred automatically. + finetune_steps: int + Number of steps used to finetune TimeGPT in the new data. + clean_ex_first: bool + Clean exogenous signal before making forecasts using TimeGPT. + residuals: bool + Whether to show residuals for the model. Defaults to False. + date_features: Optional[List[str]] + Specifies which date attributes have highest weight according to model. + + Returns + ------- + pd.DataFrame + Forecasted values. + """ + timegpt = TimeGPT( + token=get_current_user().credentials.API_KEY_NIXTLA, + ) + + if levels is None: + levels = [80, 95] + + if isinstance(data[time_col].values[0], pd.Timestamp): + data[time_col] = data[time_col].dt.strftime("%Y-%m-%d") + elif isinstance(data[time_col].values[0], numpy.datetime64): + data[time_col] = pd.to_datetime(data[time_col]).dt.strftime("%Y-%m-%d") + + date_features_param = True if "auto" in date_features else date_features # type: ignore + + fcst_df = timegpt.forecast( + data, + time_col=time_col, + target_col=target_col, + h=forecast_horizon, + freq=freq, + level=levels, + finetune_steps=finetune_steps, + clean_ex_first=clean_ex_first, + add_history=residuals, + date_features=date_features_param if date_features else False, + ) + + return fcst_df, timegpt.weights_x diff --git a/openbb_terminal/forecast/timegpt_view.py b/openbb_terminal/forecast/timegpt_view.py new file mode 100644 index 00000000000..6b275fe851c --- /dev/null +++ b/openbb_terminal/forecast/timegpt_view.py @@ -0,0 +1,244 @@ +"""Probabilistic Exponential Smoothing View""" +__docformat__ = "numpy" + +import logging +import os +from datetime import datetime +from typing import List, Optional, Union + +import pandas as pd +import plotly.graph_objs as go + +from openbb_terminal import 
OpenBBFigure, theme +from openbb_terminal.decorators import check_api_key, log_start_end +from openbb_terminal.forecast import helpers, timegpt_model +from openbb_terminal.helper_funcs import export_data +from openbb_terminal.rich_config import console + +logger = logging.getLogger(__name__) +# pylint: disable=too-many-arguments + + +@log_start_end(log=logger) +@check_api_key(["API_KEY_NIXTLA"]) +def display_timegpt_forecast( + data: Union[pd.DataFrame, pd.Series], + dataset_name: str = "", + time_col: str = "ds", + target_col: str = "y", + forecast_horizon: int = 12, + levels: Optional[List[float]] = None, + freq: Union[str, None] = None, + finetune_steps: int = 0, + clean_ex_first: bool = True, + export: str = "", + sheet_name: Optional[str] = None, + external_axes: bool = False, + start_date: Optional[datetime] = None, + end_date: Optional[datetime] = None, + residuals: bool = False, + date_features: Optional[List[str]] = None, +) -> Union[OpenBBFigure, None]: + """TimeGPT was trained on the largest collection of data in history - + over 100 billion rows of financial, weather, energy, and web data - + and democratizes the power of time-series analysis. + + Parameters + ---------- + data : Union[pd.Series, pd.DataFrame] + Input data. + dataset_name : str + Dataset name + time_col: str: + Column that identifies each timestep, its values can be timestamps or integers. Defaults to "ds". + target_column: str: + Target column to forecast. Defaults to "y". + forecast_horizon: int + Number of days to forecast. Defaults to 12. + levels: List[float] + Confidence levels between 0 and 100 for prediction intervals. + freq: Optional[str, None] + Frequency of the data. By default, the freq will be inferred automatically. + finetune_steps: int + Number of steps used to finetune TimeGPT in the new data. + clean_ex_first: bool + Clean exogenous signal before making forecasts using TimeGPT. 
+ export: str + Format to export data + sheet_name: str + Optionally specify the name of the sheet the data is exported to. + external_axes: Optional[List[plt.axes]] + External axes to plot on + start_date: Optional[datetime] + The starting date to perform analysis, data before this is trimmed. Defaults to None. + end_date: Optional[datetime] + The ending date to perform analysis, data after this is trimmed. Defaults to None. + residuals: bool + Whether to show residuals for the model. Defaults to False. + date_features: Optional[List[str]] + Specifies which date attributes have highest weight according to model. + + Returns + ------- + pd.DataFrame + Forecasted values. + """ + if levels is None: + levels = [80, 95] + + if time_col not in data.columns: + if time_col == "ds": # means that the user has not set it yet, bc default + console.print( + f"[red]Set the time column for '{dataset_name}' between: {', '.join(data.columns)}[/red]" + ) + else: + console.print(f"[red]Time column '{time_col}' not found in data[/red]") + return None + + if target_col not in data.columns: + if target_col == "y": # means that the user has not set it yet, bc default + console.print( + f"[red]Set the target column for '{dataset_name}' between: {', '.join(data.columns)}[/red]" + ) + else: + console.print(f"[red]Target column '{target_col}' not found in data[/red]") + return None + + data = helpers.clean_data(data, start_date, end_date, target_col, None) + if not helpers.check_data(data, target_col, None): + return None + + df, datefeatures_df = timegpt_model.get_timegpt_model( + data=data, + time_col=time_col, + target_col=target_col, + forecast_horizon=forecast_horizon, + levels=levels, + freq=freq, + finetune_steps=finetune_steps, + clean_ex_first=clean_ex_first, + residuals=residuals, + date_features=date_features, + ) + + fig = OpenBBFigure(xaxis_title="Date") + + fig.set_title(f"TimeGPT-1 (Beta) on {target_col} with horizon {forecast_horizon}") + + if residuals: + xds = 
list(pd.to_datetime(df[time_col].values))[:-forecast_horizon] + xds_reverse = list(pd.to_datetime(df[time_col].values))[:-forecast_horizon] + xds_reverse.reverse() + xds += xds_reverse + + xds_forecast = list(pd.to_datetime(df[time_col].values))[-forecast_horizon:] + xds_forecast_reverse = list(pd.to_datetime(df[time_col].values))[-forecast_horizon:] + xds_forecast_reverse.reverse() + xds_forecast += xds_forecast_reverse + + # this is done so the confidence levels are displayed correctly + levels.sort() + for count, lvl in enumerate(levels): + lvl_name = ( + str(lvl) + if isinstance(lvl, int) + else str(int(lvl) if lvl.is_integer() else lvl) + ) + + if residuals: + ylo = list(df[f"TimeGPT-lo-{lvl_name}"].values)[:-forecast_horizon] + yhigh = list(df[f"TimeGPT-hi-{lvl_name}"].values)[:-forecast_horizon] + yhigh.reverse() + ylo += yhigh + + fig.add_traces( + [ + go.Scatter( + x=xds, + y=ylo, + mode="lines", + line_color=f"rgba(255,127,14,{.2+(len(levels)-count)*(.6/(len(levels)+1))})", + name=f"{lvl_name}% confidence interval historical", + fill="toself", + fillcolor=f"rgba(255,127,14,{.2+(len(levels)-count)*(.6/(len(levels)+1))})", + ) + ] + ) + + ylo_forecast = list(df[f"TimeGPT-lo-{lvl_name}"].values)[-forecast_horizon:] + yhigh_forecast = list(df[f"TimeGPT-hi-{lvl_name}"].values)[-forecast_horizon:] + yhigh_forecast.reverse() + ylo_forecast += yhigh_forecast + + fig.add_traces( + [ + go.Scatter( + x=xds_forecast, + y=ylo_forecast, + mode="lines", + line_color=f"rgba(0,172,255,{.2+(len(levels)-count)*(.6/(len(levels)+1))})", + name=f"{lvl_name}% confidence interval", + fill="toself", + fillcolor=f"rgba(0,172,255,{.2+(len(levels)-count)*(.6/(len(levels)+1))})", + ) + ] + ) + + if residuals: + # TimeGPT prediction - historical + fig.add_scatter( + x=list(pd.to_datetime(df[time_col].values))[:-forecast_horizon], + y=list(df["TimeGPT"].values)[:-forecast_horizon], + name="TimeGPT historical forecast", + mode="lines", + line=dict( + color="rgba(255,127,14,1)", + 
width=3, + ), + ) + + # TimeGPT prediction + fig.add_scatter( + x=list(pd.to_datetime(df[time_col].values))[-forecast_horizon:], + y=list(df["TimeGPT"].values)[-forecast_horizon:], + name="TimeGPT forecast", + mode="markers+lines", + line=dict( + color="rgba(0,172,255,1)", + width=3, + ), + ) + + # Current data + fig.add_scatter( + x=list(pd.to_datetime(data[time_col].values)), + y=list(data[target_col].values), + name="Actual", + line_color="gold", + mode="lines", + ) + fig.show(external=external_axes) + + if date_features: + fig2 = OpenBBFigure(xaxis_title="Weights") + + fig2.set_title("Date features weight") + + fig2.add_bar( + y=datefeatures_df["features"], + x=datefeatures_df["weights"], + marker_color=theme.up_color, + orientation="h", + ) + fig2.show(external=external_axes) + + export_data( + export, + os.path.dirname(os.path.abspath(__file__)), + f"timegpt_forecast_{target_col}", + df=df, + sheet_name=sheet_name, + figure=fig, + ) + + return fig diff --git a/openbb_terminal/helper_funcs.py b/openbb_terminal/helper_funcs.py index ab1b83457e3..700ef0b7316 100644 --- a/openbb_terminal/helper_funcs.py +++ b/openbb_terminal/helper_funcs.py @@ -540,6 +540,20 @@ def check_positive_list(value) -> List[int]: return list_of_pos +def check_positive_float_list(value) -> List[float]: + """Argparse type to return list of positive floats.""" + list_of_nums = value.split(",") + list_of_pos = [] + for a_value in list_of_nums: + new_value = float(a_value) + if new_value <= 0: + log_and_raise( + argparse.ArgumentTypeError(f"{value} is an invalid positive int value") + ) + list_of_pos.append(new_value) + return list_of_pos + + def check_positive(value) -> int: """Argparse type to check positive int.""" new_value = int(value) @@ -2281,3 +2295,15 @@ def query_LLM_remote(query_text: str): return None, None return ask_obbrequest_data["response"], ask_obbrequest_data["source_nodes"] + + +def check_valid_date(date_string) -> bool: + """ "Helper to see if we can parse the string to a 
date""" + try: + # Try to parse the string with strptime() + datetime.strptime( + date_string, "%Y-%m-%d" + ) # Use the format your dates are expected to be in + return True # If it can be parsed, then it is a valid date string + except ValueError: # strptime() throws a ValueError if the string can't be parsed + return False diff --git a/openbb_terminal/keys_controller.py b/openbb_terminal/keys_controller.py index ec8b430ee28..558faf4c014 100644 --- a/openbb_terminal/keys_controller.py +++ b/openbb_terminal/keys_controller.py @@ -1280,3 +1280,33 @@ class KeysController(BaseController): # pylint: disable=too-many-public-methods self.status_dict["dappradar"] = keys_model.set_dappradar_key( key=ns_parser.key, persist=True, show_output=True ) + + @log_start_end(log=logger) + def call_nixtla(self, other_args: List[str]): + """Process nixtla command""" + parser = argparse.ArgumentParser( + add_help=False, + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + prog="nixtla", + description="Set Nixtla API key.", + ) + parser.add_argument( + "-k", + "--key", + type=str, + dest="key", + help="key", + ) + if not other_args: + console.print( + "For your API Key, visit: https://docs.nixtla.io/docs/getting-started" + ) + return + + if other_args and "-" not in other_args[0][0]: + other_args.insert(0, "-k") + ns_parser = self.parse_simple_args(parser, other_args) + if ns_parser: + self.status_dict["nixtla"] = keys_model.set_nixtla_key( + key=ns_parser.key, persist=True, show_output=True + ) diff --git a/openbb_terminal/keys_model.py b/openbb_terminal/keys_model.py index a3f40a9c846..7e62daa9b99 100644 --- a/openbb_terminal/keys_model.py +++ b/openbb_terminal/keys_model.py @@ -20,6 +20,7 @@ import quandl import stocksera from alpha_vantage.timeseries import TimeSeries from coinmarketcapapi import CoinMarketCapAPI +from nixtlats import TimeGPT from oandapyV20 import API as oanda_API from prawcore.exceptions import ResponseException from tokenterminal import TokenTerminal @@ 
-83,6 +84,7 @@ API_DICT: Dict = { "stocksera": "STOCKSERA", "dappradar": "DAPPRADAR", "openai": "OPENAI", + "nixtla": "NIXTLA", } # sorting api key section by name @@ -2829,3 +2831,71 @@ def check_openai_key(show_output: bool = False) -> str: console.print(status.colorize()) return str(status) + + +def set_nixtla_key(key: str, persist: bool = False, show_output: bool = False) -> str: + """Set Nixtla API key + + Parameters + ---------- + key: str + API key + persist: bool, optional + If False, |