summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorDidierRLopes <dro.lopes@campus.fct.unl.pt>2023-08-25 10:28:34 -0500
committerGitHub <noreply@github.com>2023-08-25 15:28:34 +0000
commit1036308f358babb1cabf9ca059df488701518af9 (patch)
treed72185426c19a4c713dc0c3009ccf96bd0d04f2d
parent141dd68e68b6cb76abdd1a9ef5c16bf545d020a8 (diff)
1st integration with TimeGPT-1 (Beta), from Nixtla (#5292)
* integrate time gpt from Nixtla * add to reqs * add residuals to timegpt * fix some bugs * remove twitter functions * remove twitter functions and all references to it * start work to integrate date features * add date_features to timegpt * remove twitter functions * remove twitter functions and all references to it * Update deps for timegpt + askobb * Update forecast load to set an index to date * lint * spell * mypy * mypy2 * blac k * pylint --------- Co-authored-by: James Maslek <jmaslek11@gmail.com>
-rw-r--r--openbb_terminal/forecast/forecast_controller.py130
-rw-r--r--openbb_terminal/forecast/helpers.py13
-rw-r--r--openbb_terminal/forecast/timegpt_model.py91
-rw-r--r--openbb_terminal/forecast/timegpt_view.py244
-rw-r--r--openbb_terminal/helper_funcs.py26
-rw-r--r--openbb_terminal/keys_controller.py30
-rw-r--r--openbb_terminal/keys_model.py70
-rw-r--r--openbb_terminal/miscellaneous/i18n/en.yml1
-rw-r--r--openbb_terminal/miscellaneous/models/all_api_keys.json6
-rw-r--r--openbb_terminal/miscellaneous/models/hub_credentials.json7
-rw-r--r--openbb_terminal/parent_classes.py7
-rw-r--r--openbb_terminal/stocks/stocks_controller.py5
-rw-r--r--poetry.lock612
-rw-r--r--pyproject.toml8
-rw-r--r--requirements-full.txt1
-rw-r--r--requirements.txt1
16 files changed, 740 insertions, 512 deletions
diff --git a/openbb_terminal/forecast/forecast_controller.py b/openbb_terminal/forecast/forecast_controller.py
index 218fcd66d40..ad223eac852 100644
--- a/openbb_terminal/forecast/forecast_controller.py
+++ b/openbb_terminal/forecast/forecast_controller.py
@@ -72,6 +72,10 @@ from openbb_terminal.helper_funcs import (
log_and_raise,
valid_date,
parse_and_split_input,
+ check_non_negative,
+ check_positive_float_list,
+ check_list_values,
+ check_valid_date,
)
from openbb_terminal.menu import session
@@ -103,6 +107,7 @@ from openbb_terminal.forecast import (
theta_view,
trans_view,
whisper_model,
+ timegpt_view,
)
logger = logging.getLogger(__name__)
@@ -166,6 +171,7 @@ class ForecastController(BaseController):
"nhits",
"anom",
"whisper",
+ "timegpt",
]
pandas_plot_choices = [
"line",
@@ -268,6 +274,7 @@ class ForecastController(BaseController):
Overrides the parent class function to handle YouTube video URL conventions.
See `BaseController.parse_input()` for details.
"""
+
# Filtering out YouTube video parameters like "v=" and removing the domain name
youtube_filter = r"(youtube\.com/watch\?v=)"
@@ -304,6 +311,7 @@ class ForecastController(BaseController):
def print_help(self):
"""Print help"""
+ self.update_runtime_choices()
current_user = get_current_user()
mt = MenuText("forecast/")
mt.add_param("_disclaimer_", self.disclaimer)
@@ -364,6 +372,7 @@ class ForecastController(BaseController):
mt.add_cmd("anom", self.files)
mt.add_raw("\n")
mt.add_info("_misc_")
+ mt.add_cmd("timegpt", self.files)
mt.add_cmd("whisper", WHISPER_AVAILABLE)
console.print(text=mt.menu_text, menu="Forecast")
@@ -709,7 +718,7 @@ class ForecastController(BaseController):
"""Loads news dataframes into memory"""
# check if data has minimum number of rows
- if len(data) < self.MINIMUM_DATA_LENGTH:
+ if ticker and len(data) < self.MINIMUM_DATA_LENGTH:
console.print(
f"[red]Dataset is smaller than recommended minimum {self.MINIMUM_DATA_LENGTH} data points. [/red]"
)
@@ -732,6 +741,9 @@ class ForecastController(BaseController):
# if we import a custom dataset, remove the old index "unnamed:_0"
if "unnamed:_0" in data.columns:
+ # Some loaded datasets have the date as unnamed, which is not helpful
+ if check_valid_date(data["unnamed:_0"].iloc[0]):
+ data["date"] = data["unnamed:_0"].copy()
data = data.drop(columns=["unnamed:_0"])
self.files.append(ticker)
@@ -3441,3 +3453,119 @@ class ForecastController(BaseController):
breaklines=ns_parser.breaklines,
output_dir=ns_parser.save,
)
+
+ # TimeGPT Model
+ @log_start_end(log=logger)
+ def call_timegpt(self, other_args: List[str]):
+ """Process expo command"""
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+ add_help=False,
+ prog="timegpt",
+ description="""
+ TODO: Update me
+ """,
+ )
+ parser.add_argument(
+ "--horizon",
+ action="store",
+ dest="horizon",
+ type=check_positive,
+ default=12,
+ help="Forecasting horizon",
+ )
+ parser.add_argument(
+ "--freq",
+ action="store",
+ dest="freq",
+ choices=["H", "D", "W", "M", "MS", "B"],
+ default=None,
+ help="Frequency of the data.",
+ )
+ parser.add_argument(
+ "--finetune",
+ action="store",
+ dest="finetune",
+ type=check_non_negative,
+ default=0,
+ help="Number of steps used to finetune TimeGPT in the new data.",
+ )
+ parser.add_argument(
+ "--ci",
+ action="store",
+ dest="confidence",
+ type=check_positive_float_list,
+ default=[80, 90],
+ help="Number of steps used to finetune TimeGPT in the new data.",
+ )
+ parser.add_argument(
+ "--cleanex",
+ action="store_false",
+ help="Clean exogenous signal before making forecasts using TimeGPT.",
+ dest="cleanex",
+ default=True,
+ )
+ parser.add_argument(
+ "--timecol",
+ action="store",
+ dest="timecol",
+ default="ds",
+ type=str,
+ help="Dataframe column that represents datetime",
+ )
+ parser.add_argument(
+ "--targetcol",
+ action="store",
+ dest="targetcol",
+ default="y",
+ type=str,
+ help="Dataframe column that represents the target to forecast for",
+ )
+ parser.add_argument(
+ "--sheet-name",
+ help="The name of the sheet to export to when type is XLSX.",
+ dest="sheet_name",
+ type=str,
+ default="",
+ )
+ parser.add_argument(
+ "--datefeatures",
+ help="Specifies which date attributes have highest weight according to model.",
+ dest="date_features",
+ type=check_list_values(["auto", "year", "month", "week", "day", "weekday"]),
+ default=[],
+ )
+ parser = self.add_standard_args(
+ parser,
+ target_dataset=True,
+ target_column=True,
+ start=True,
+ end=True,
+ residuals=True,
+ )
+ if other_args and "-" not in other_args[0][0]:
+ other_args.insert(0, "--dataset")
+
+ ns_parser = self.parse_known_args_and_warn(
+ parser,
+ other_args,
+ export_allowed=EXPORT_ONLY_FIGURES_ALLOWED,
+ )
+ if ns_parser:
+ timegpt_view.display_timegpt_forecast(
+ data=self.datasets[ns_parser.target_dataset],
+ dataset_name=ns_parser.target_dataset,
+ time_col=ns_parser.timecol,
+ target_col=ns_parser.targetcol,
+ forecast_horizon=ns_parser.horizon,
+ freq=ns_parser.freq,
+ levels=ns_parser.confidence,
+ finetune_steps=ns_parser.finetune,
+ clean_ex_first=ns_parser.cleanex,
+ export=ns_parser.export,
+ sheet_name=ns_parser.sheet_name,
+ start_date=ns_parser.s_start_date,
+ end_date=ns_parser.s_end_date,
+ residuals=ns_parser.residuals,
+ date_features=ns_parser.date_features,
+ )
diff --git a/openbb_terminal/forecast/helpers.py b/openbb_terminal/forecast/helpers.py
index da42d8309fd..cfefcb8191f 100644
--- a/openbb_terminal/forecast/helpers.py
+++ b/openbb_terminal/forecast/helpers.py
@@ -1253,9 +1253,15 @@ def filter_dates(
console.print("[red]The start date must be before the end date.[/red]\n")
return data
if end_date:
- data = data[data["date"] <= end_date]
+ if isinstance(data["date"].values[0], str):
+ data = data[pd.to_datetime(data["date"]) <= end_date]
+ else:
+ data = data[data["date"] <= end_date]
if start_date:
- data = data[data["date"] >= start_date]
+ if isinstance(data["date"].values[0], str):
+ data = data[pd.to_datetime(data["date"]) >= start_date]
+ else:
+ data = data[data["date"] >= start_date]
return data
@@ -1341,6 +1347,9 @@ def check_data(
" Change the 'target_column' parameter.[/red]\n"
)
return False
+ if data.empty:
+ console.print("[red]The data provided is empty.[/red]\n")
+ return False
if past_covariates is not None:
covariates = past_covariates.split(",")
for covariate in covariates:
diff --git a/openbb_terminal/forecast/timegpt_model.py b/openbb_terminal/forecast/timegpt_model.py
new file mode 100644
index 00000000000..d973329e59f
--- /dev/null
+++ b/openbb_terminal/forecast/timegpt_model.py
@@ -0,0 +1,91 @@
+# pylint: disable=too-many-arguments
+"""Probabilistic Exponential Smoothing Model"""
+__docformat__ = "numpy"
+
+import logging
+from typing import List, Optional, Union
+
+import numpy
+import pandas as pd
+from nixtlats import TimeGPT
+
+from openbb_terminal.core.session.current_user import get_current_user
+from openbb_terminal.decorators import check_api_key, log_start_end
+
+logger = logging.getLogger(__name__)
+
+
+@log_start_end(log=logger)
+@check_api_key(["API_KEY_NIXTLA"])
+def get_timegpt_model(
+ data: Union[pd.Series, pd.DataFrame],
+ time_col: str = "ds",
+ target_col: str = "y",
+ forecast_horizon: int = 12,
+ levels: Optional[List[float]] = None,
+ freq: Union[str, None] = None,
+ finetune_steps: int = 0,
+ clean_ex_first: bool = True,
+ residuals: bool = False,
+ date_features: Optional[List[str]] = None,
+) -> pd.DataFrame:
+ """TimeGPT was trained on the largest collection of data in history -
+ over 100 billion rows of financial, weather, energy, and web data -
+ and democratizes the power of time-series analysis.
+
+ Parameters
+ ----------
+ data : Union[pd.Series, pd.DataFrame]
+ Input data.
+ time_col: str:
+ Column that identifies each timestep, its values can be timestamps or integers. Defaults to "ds".
+ target_column: str:
+ Target column to forecast. Defaults to "y".
+ forecast_horizon: int
+ Number of days to forecast. Defaults to 12.
+ levels: List[float]
+ Confidence levels between 0 and 100 for prediction intervals.
+ freq: Optional[str, None]
+ Frequency of the data. By default, the freq will be inferred automatically.
+ finetune_steps: int
+ Number of steps used to finetune TimeGPT in the new data.
+ clean_ex_first: bool
+ Clean exogenous signal before making forecasts using TimeGPT.
+ residuals: bool
+ Whether to show residuals for the model. Defaults to False.
+ date_features: Optional[List[str]]
+ Specifies which date attributes have highest weight according to model.
+
+ Returns
+ -------
+ pd.DataFrame
+ Forecasted values.
+ """
+ timegpt = TimeGPT(
+ token=get_current_user().credentials.API_KEY_NIXTLA,
+ )
+
+ if levels is None:
+ levels = [80, 95]
+
+ if isinstance(data[time_col].values[0], pd.Timestamp):
+ data[time_col] = data[time_col].dt.strftime("%Y-%m-%d")
+ elif isinstance(data[time_col].values[0], numpy.datetime64):
+ data[time_col] = pd.to_datetime(data[time_col]).dt.strftime("%Y-%m-%d")
+
+ date_features_param = True if "auto" in date_features else date_features # type: ignore
+
+ fcst_df = timegpt.forecast(
+ data,
+ time_col=time_col,
+ target_col=target_col,
+ h=forecast_horizon,
+ freq=freq,
+ level=levels,
+ finetune_steps=finetune_steps,
+ clean_ex_first=clean_ex_first,
+ add_history=residuals,
+ date_features=date_features_param if date_features else False,
+ )
+
+ return fcst_df, timegpt.weights_x
diff --git a/openbb_terminal/forecast/timegpt_view.py b/openbb_terminal/forecast/timegpt_view.py
new file mode 100644
index 00000000000..6b275fe851c
--- /dev/null
+++ b/openbb_terminal/forecast/timegpt_view.py
@@ -0,0 +1,244 @@
+"""Probabilistic Exponential Smoothing View"""
+__docformat__ = "numpy"
+
+import logging
+import os
+from datetime import datetime
+from typing import List, Optional, Union
+
+import pandas as pd
+import plotly.graph_objs as go
+
+from openbb_terminal import OpenBBFigure, theme
+from openbb_terminal.decorators import check_api_key, log_start_end
+from openbb_terminal.forecast import helpers, timegpt_model
+from openbb_terminal.helper_funcs import export_data
+from openbb_terminal.rich_config import console
+
+logger = logging.getLogger(__name__)
+# pylint: disable=too-many-arguments
+
+
+@log_start_end(log=logger)
+@check_api_key(["API_KEY_NIXTLA"])
+def display_timegpt_forecast(
+ data: Union[pd.DataFrame, pd.Series],
+ dataset_name: str = "",
+ time_col: str = "ds",
+ target_col: str = "y",
+ forecast_horizon: int = 12,
+ levels: Optional[List[float]] = None,
+ freq: Union[str, None] = None,
+ finetune_steps: int = 0,
+ clean_ex_first: bool = True,
+ export: str = "",
+ sheet_name: Optional[str] = None,
+ external_axes: bool = False,
+ start_date: Optional[datetime] = None,
+ end_date: Optional[datetime] = None,
+ residuals: bool = False,
+ date_features: Optional[List[str]] = None,
+) -> Union[OpenBBFigure, None]:
+ """TimeGPT was trained on the largest collection of data in history -
+ over 100 billion rows of financial, weather, energy, and web data -
+ and democratizes the power of time-series analysis.
+
+ Parameters
+ ----------
+ data : Union[pd.Series, pd.DataFrame]
+ Input data.
+ dataset_name : str
+ Dataset name
+ time_col: str:
+ Column that identifies each timestep, its values can be timestamps or integers. Defaults to "ds".
+ target_column: str:
+ Target column to forecast. Defaults to "y".
+ forecast_horizon: int
+ Number of days to forecast. Defaults to 12.
+ levels: List[float]
+ Confidence levels between 0 and 100 for prediction intervals.
+ freq: Optional[str, None]
+ Frequency of the data. By default, the freq will be inferred automatically.
+ finetune_steps: int
+ Number of steps used to finetune TimeGPT in the new data.
+ clean_ex_first: bool
+ Clean exogenous signal before making forecasts using TimeGPT.
+ export: str
+ Format to export data
+ sheet_name: str
+ Optionally specify the name of the sheet the data is exported to.
+ external_axes: Optional[List[plt.axes]]
+ External axes to plot on
+ start_date: Optional[datetime]
+ The starting date to perform analysis, data before this is trimmed. Defaults to None.
+ end_date: Optional[datetime]
+ The ending date to perform analysis, data after this is trimmed. Defaults to None.
+ residuals: bool
+ Whether to show residuals for the model. Defaults to False.
+ date_features: Optional[List[str]]
+ Specifies which date attributes have highest weight according to model.
+
+ Returns
+ -------
+ pd.DataFrame
+ Forecasted values.
+ """
+ if levels is None:
+ levels = [80, 95]
+
+ if time_col not in data.columns:
+ if time_col == "ds": # means that the user has not set it yet, bc default
+ console.print(
+ f"[red]Set the time column for '{dataset_name}' between: {', '.join(data.columns)}[/red]"
+ )
+ else:
+ console.print(f"[red]Time column '{time_col}' not found in data[/red]")
+ return None
+
+ if target_col not in data.columns:
+ if target_col == "y": # means that the user has not set it yet, bc default
+ console.print(
+ f"[red]Set the target column for '{dataset_name}' between: {', '.join(data.columns)}[/red]"
+ )
+ else:
+ console.print(f"[red]Target column '{target_col}' not found in data[/red]")
+ return None
+
+ data = helpers.clean_data(data, start_date, end_date, target_col, None)
+ if not helpers.check_data(data, target_col, None):
+ return None
+
+ df, datefeatures_df = timegpt_model.get_timegpt_model(
+ data=data,
+ time_col=time_col,
+ target_col=target_col,
+ forecast_horizon=forecast_horizon,
+ levels=levels,
+ freq=freq,
+ finetune_steps=finetune_steps,
+ clean_ex_first=clean_ex_first,
+ residuals=residuals,
+ date_features=date_features,
+ )
+
+ fig = OpenBBFigure(xaxis_title="Date")
+
+ fig.set_title(f"TimeGPT-1 (Beta) on {target_col} with horizon {forecast_horizon}")
+
+ if residuals:
+ xds = list(pd.to_datetime(df[time_col].values))[:-forecast_horizon]
+ xds_reverse = list(pd.to_datetime(df[time_col].values))[:-forecast_horizon]
+ xds_reverse.reverse()
+ xds += xds_reverse
+
+ xds_forecast = list(pd.to_datetime(df[time_col].values))[-forecast_horizon:]
+ xds_forecast_reverse = list(pd.to_datetime(df[time_col].values))[-forecast_horizon:]
+ xds_forecast_reverse.reverse()
+ xds_forecast += xds_forecast_reverse
+
+ # this is done so the confidence levels are displayed correctly
+ levels.sort()
+ for count, lvl in enumerate(levels):
+ lvl_name = (
+ str(lvl)
+ if isinstance(lvl, int)
+ else str(int(lvl) if lvl.is_integer() else lvl)
+ )
+
+ if residuals:
+ ylo = list(df[f"TimeGPT-lo-{lvl_name}"].values)[:-forecast_horizon]
+ yhigh = list(df[f"TimeGPT-hi-{lvl_name}"].values)[:-forecast_horizon]
+ yhigh.reverse()
+ ylo += yhigh
+
+ fig.add_traces(
+ [
+ go.Scatter(
+ x=xds,
+ y=ylo,
+ mode="lines",
+ line_color=f"rgba(255,127,14,{.2+(len(levels)-count)*(.6/(len(levels)+1))})",
+ name=f"{lvl_name}% confidence interval historical",
+ fill="toself",
+ fillcolor=f"rgba(255,127,14,{.2+(len(levels)-count)*(.6/(len(levels)+1))})",
+ )
+ ]
+ )
+
+ ylo_forecast = list(df[f"TimeGPT-lo-{lvl_name}"].values)[-forecast_horizon:]
+ yhigh_forecast = list(df[f"TimeGPT-hi-{lvl_name}"].values)[-forecast_horizon:]
+ yhigh_forecast.reverse()
+ ylo_forecast += yhigh_forecast
+
+ fig.add_traces(
+ [
+ go.Scatter(
+ x=xds_forecast,
+ y=ylo_forecast,
+ mode="lines",
+ line_color=f"rgba(0,172,255,{.2+(len(levels)-count)*(.6/(len(levels)+1))})",
+ name=f"{lvl_name}% confidence interval",
+ fill="toself",
+ fillcolor=f"rgba(0,172,255,{.2+(len(levels)-count)*(.6/(len(levels)+1))})",
+ )
+ ]
+ )
+
+ if residuals:
+ # TimeGPT prediction - historical
+ fig.add_scatter(
+ x=list(pd.to_datetime(df[time_col].values))[:-forecast_horizon],
+ y=list(df["TimeGPT"].values)[:-forecast_horizon],
+ name="TimeGPT historical forecast",
+ mode="lines",
+ line=dict(
+ color="rgba(255,127,14,1)",
+ width=3,
+ ),
+ )
+
+ # TimeGPT prediction
+ fig.add_scatter(
+ x=list(pd.to_datetime(df[time_col].values))[-forecast_horizon:],
+ y=list(df["TimeGPT"].values)[-forecast_horizon:],
+ name="TimeGPT forecast",
+ mode="markers+lines",
+ line=dict(
+ color="rgba(0,172,255,1)",
+ width=3,
+ ),
+ )
+
+ # Current data
+ fig.add_scatter(
+ x=list(pd.to_datetime(data[time_col].values)),
+ y=list(data[target_col].values),
+ name="Actual",
+ line_color="gold",
+ mode="lines",
+ )
+ fig.show(external=external_axes)
+
+ if date_features:
+ fig2 = OpenBBFigure(xaxis_title="Weights")
+
+ fig2.set_title("Date features weight")
+
+ fig2.add_bar(
+ y=datefeatures_df["features"],
+ x=datefeatures_df["weights"],
+ marker_color=theme.up_color,
+ orientation="h",
+ )
+ fig2.show(external=external_axes)
+
+ export_data(
+ export,
+ os.path.dirname(os.path.abspath(__file__)),
+ f"timegpt_forecast_{target_col}",
+ df=df,
+ sheet_name=sheet_name,
+ figure=fig,
+ )
+
+ return fig
diff --git a/openbb_terminal/helper_funcs.py b/openbb_terminal/helper_funcs.py
index ab1b83457e3..700ef0b7316 100644
--- a/openbb_terminal/helper_funcs.py
+++ b/openbb_terminal/helper_funcs.py
@@ -540,6 +540,20 @@ def check_positive_list(value) -> List[int]:
return list_of_pos
+def check_positive_float_list(value) -> List[float]:
+ """Argparse type to return list of positive floats."""
+ list_of_nums = value.split(",")
+ list_of_pos = []
+ for a_value in list_of_nums:
+ new_value = float(a_value)
+ if new_value <= 0:
+ log_and_raise(
+ argparse.ArgumentTypeError(f"{value} is an invalid positive int value")
+ )
+ list_of_pos.append(new_value)
+ return list_of_pos
+
+
def check_positive(value) -> int:
"""Argparse type to check positive int."""
new_value = int(value)
@@ -2281,3 +2295,15 @@ def query_LLM_remote(query_text: str):
return None, None
return ask_obbrequest_data["response"], ask_obbrequest_data["source_nodes"]
+
+
+def check_valid_date(date_string) -> bool:
+ """ "Helper to see if we can parse the string to a date"""
+ try:
+ # Try to parse the string with strptime()
+ datetime.strptime(
+ date_string, "%Y-%m-%d"
+ ) # Use the format your dates are expected to be in
+ return True # If it can be parsed, then it is a valid date string
+ except ValueError: # strptime() throws a ValueError if the string can't be parsed
+ return False
diff --git a/openbb_terminal/keys_controller.py b/openbb_terminal/keys_controller.py
index ec8b430ee28..558faf4c014 100644
--- a/openbb_terminal/keys_controller.py
+++ b/openbb_terminal/keys_controller.py
@@ -1280,3 +1280,33 @@ class KeysController(BaseController): # pylint: disable=too-many-public-methods
self.status_dict["dappradar"] = keys_model.set_dappradar_key(
key=ns_parser.key, persist=True, show_output=True
)
+
+ @log_start_end(log=logger)
+ def call_nixtla(self, other_args: List[str]):
+ """Process nixtla command"""
+ parser = argparse.ArgumentParser(
+ add_help=False,
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+ prog="nixtla",
+ description="Set Nixtla API key.",
+ )
+ parser.add_argument(
+ "-k",
+ "--key",
+ type=str,
+ dest="key",
+ help="key",
+ )
+ if not other_args:
+ console.print(
+ "For your API Key, visit: https://docs.nixtla.io/docs/getting-started"
+ )
+ return
+
+ if other_args and "-" not in other_args[0][0]:
+ other_args.insert(0, "-k")
+ ns_parser = self.parse_simple_args(parser, other_args)
+ if ns_parser:
+ self.status_dict["nixtla"] = keys_model.set_nixtla_key(
+ key=ns_parser.key, persist=True, show_output=True
+ )
diff --git a/openbb_terminal/keys_model.py b/openbb_terminal/keys_model.py
index a3f40a9c846..7e62daa9b99 100644
--- a/openbb_terminal/keys_model.py
+++ b/openbb_terminal/keys_model.py
@@ -20,6 +20,7 @@ import quandl
import stocksera
from alpha_vantage.timeseries import TimeSeries
from coinmarketcapapi import CoinMarketCapAPI
+from nixtlats import TimeGPT
from oandapyV20 import API as oanda_API
from prawcore.exceptions import ResponseException
from tokenterminal import TokenTerminal
@@ -83,6 +84,7 @@ API_DICT: Dict = {
"stocksera": "STOCKSERA",
"dappradar": "DAPPRADAR",
"openai": "OPENAI",
+ "nixtla": "NIXTLA",
}
# sorting api key section by name
@@ -2829,3 +2831,71 @@ def check_openai_key(show_output: bool = False) -> str:
console.print(status.colorize())
return str(status)
+
+
+def set_nixtla_key(key: str, persist: bool = False, show_output: bool = False) -> str:
+ """Set Nixtla API key
+
+ Parameters
+ ----------
+ key: str
+ API key
+ persist: bool, optional
+ If False,