summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authormartinb-bb <105685594+martinb-bb@users.noreply.github.com>2022-11-28 09:27:11 -0500
committerGitHub <noreply@github.com>2022-11-28 09:27:11 -0500
commitb1f3dc1bf143c1280e40ec36ce9c877977d0c454 (patch)
tree9a44491df79af198e143c55b96fd3c02da7e8f54
parentcaf2f853c3b9b4163fc2d7c4901753f17c4a6d8a (diff)
fixing parameters (#3602)
Co-authored-by: James Maslek <jmaslek11@gmail.com>
-rw-r--r--openbb_terminal/forecast/brnn_view.py2
-rw-r--r--openbb_terminal/forecast/forecast_controller.py6
-rw-r--r--openbb_terminal/forecast/nhits_model.py2
-rw-r--r--openbb_terminal/forecast/nhits_view.py2
-rw-r--r--openbb_terminal/forecast/tcn_model.py7
-rw-r--r--openbb_terminal/forecast/tcn_view.py6
-rw-r--r--openbb_terminal/forecast/trans_model.py3
-rw-r--r--openbb_terminal/forecast/trans_view.py6
-rw-r--r--website/content/terminal/reference/forecast/nhits.md6
9 files changed, 19 insertions, 21 deletions
diff --git a/openbb_terminal/forecast/brnn_view.py b/openbb_terminal/forecast/brnn_view.py
index 21daa9a80ae..cff4d7fd05a 100644
--- a/openbb_terminal/forecast/brnn_view.py
+++ b/openbb_terminal/forecast/brnn_view.py
@@ -33,7 +33,7 @@ def display_brnn_forecast(
batch_size: int = 32,
n_epochs: int = 100,
learning_rate: float = 1e-3,
- model_save_name: str = "rnn_model",
+ model_save_name: str = "brnn_model",
force_reset: bool = True,
save_checkpoints: bool = True,
export: str = "",
diff --git a/openbb_terminal/forecast/forecast_controller.py b/openbb_terminal/forecast/forecast_controller.py
index 448ebec0f06..cf183430c60 100644
--- a/openbb_terminal/forecast/forecast_controller.py
+++ b/openbb_terminal/forecast/forecast_controller.py
@@ -3032,7 +3032,7 @@ class ForecastController(BaseController):
"--layer_widths",
dest="layer_widths",
type=check_positive,
- default=3,
+ default=512,
help="The number of neurons in each layer",
)
parser.add_argument(
@@ -3056,7 +3056,7 @@ class ForecastController(BaseController):
"--max_pool_1d",
action="store_true",
dest="maxpool1d",
- default=False,
+ default=True,
help="Whether to use max_pool_1d or AvgPool1d",
)
if other_args and "-" not in other_args[0][0]:
@@ -3068,7 +3068,7 @@ class ForecastController(BaseController):
target_dataset=True,
n_days=True,
force_reset=True,
- model_save_name="tft_model",
+ model_save_name="nhits_model",
train_split=True,
dropout=0.1,
input_chunk_length=True,
diff --git a/openbb_terminal/forecast/nhits_model.py b/openbb_terminal/forecast/nhits_model.py
index 84786465c55..8b3fa948832 100644
--- a/openbb_terminal/forecast/nhits_model.py
+++ b/openbb_terminal/forecast/nhits_model.py
@@ -40,7 +40,7 @@ def get_nhits_data(
batch_size: int = 32,
n_epochs: int = 100,
learning_rate: float = 1e-3,
- model_save_name: str = "brnn_model",
+ model_save_name: str = "nhits_model",
force_reset: bool = True,
save_checkpoints: bool = True,
) -> Tuple[
diff --git a/openbb_terminal/forecast/nhits_view.py b/openbb_terminal/forecast/nhits_view.py
index d0c82250e6f..caf2f07bd98 100644
--- a/openbb_terminal/forecast/nhits_view.py
+++ b/openbb_terminal/forecast/nhits_view.py
@@ -39,7 +39,7 @@ def display_nhits_forecast(
batch_size: int = 32,
n_epochs: int = 100,
learning_rate: float = 1e-3,
- model_save_name: str = "rnn_model",
+ model_save_name: str = "nhits_model",
force_reset: bool = True,
save_checkpoints: bool = True,
export: str = "",
diff --git a/openbb_terminal/forecast/tcn_model.py b/openbb_terminal/forecast/tcn_model.py
index ef1534aea49..9c0de7d2dbe 100644
--- a/openbb_terminal/forecast/tcn_model.py
+++ b/openbb_terminal/forecast/tcn_model.py
@@ -27,12 +27,12 @@ def get_tcn_data(
input_chunk_length: int = 14,
output_chunk_length: int = 5,
dropout: float = 0.1,
- num_filters: int = 6,
+ num_filters: int = 3,
weight_norm: bool = True,
dilation_base: int = 2,
- n_epochs: int = 100,
+ n_epochs: int = 300,
learning_rate: float = 1e-3,
- batch_size: int = 800,
+ batch_size: int = 32,
model_save_name: str = "tcn_model",
force_reset: bool = True,
save_checkpoints: bool = True,
@@ -95,7 +95,6 @@ def get_tcn_data(
Mean average precision error,
Best TCN Model.
"""
-
# TODO Check if torch GPU AVAILABLE
use_scalers = True
diff --git a/openbb_terminal/forecast/tcn_view.py b/openbb_terminal/forecast/tcn_view.py
index 8b93028df33..793f35246ce 100644
--- a/openbb_terminal/forecast/tcn_view.py
+++ b/openbb_terminal/forecast/tcn_view.py
@@ -28,12 +28,12 @@ def display_tcn_forecast(
input_chunk_length: int = 14,
output_chunk_length: int = 5,
dropout: float = 0.1,
- num_filters: int = 6,
+ num_filters: int = 3,
weight_norm: bool = True,
dilation_base: int = 2,
- n_epochs: int = 100,
+ n_epochs: int = 300,
learning_rate: float = 1e-3,
- batch_size: int = 800,
+ batch_size: int = 32,
model_save_name: str = "tcn_model",
force_reset: bool = True,
save_checkpoints: bool = True,
diff --git a/openbb_terminal/forecast/trans_model.py b/openbb_terminal/forecast/trans_model.py
index ebdd621ae3a..dc94111571e 100644
--- a/openbb_terminal/forecast/trans_model.py
+++ b/openbb_terminal/forecast/trans_model.py
@@ -34,7 +34,7 @@ def get_trans_data(
activation: str = "relu",
dropout: float = 0.0,
batch_size: int = 32,
- n_epochs: int = 100,
+ n_epochs: int = 300,
learning_rate: float = 1e-3,
model_save_name: str = "trans_model",
force_reset: bool = True,
@@ -103,7 +103,6 @@ def get_trans_data(
Mean average precision error,
Best transformer Model.
"""
-
# TODO Check if torch GPU AVAILABLE
use_scalers = True
diff --git a/openbb_terminal/forecast/trans_view.py b/openbb_terminal/forecast/trans_view.py
index f59c615f8c7..5fde027f746 100644
--- a/openbb_terminal/forecast/trans_view.py
+++ b/openbb_terminal/forecast/trans_view.py
@@ -33,9 +33,9 @@ def display_trans_forecast(
num_decoder_layers: int = 3,
dim_feedforward: int = 512,
activation: str = "relu",
- dropout: float = 0.1,
- batch_size: int = 16,
- n_epochs: int = 100,
+ dropout: float = 0.0,
+ batch_size: int = 32,
+ n_epochs: int = 300,
learning_rate: float = 1e-3,
model_save_name: str = "trans_model",
force_reset: bool = True,
diff --git a/website/content/terminal/reference/forecast/nhits.md b/website/content/terminal/reference/forecast/nhits.md
index adc4c134654..53b062ddbe4 100644
--- a/website/content/terminal/reference/forecast/nhits.md
+++ b/website/content/terminal/reference/forecast/nhits.md
@@ -22,9 +22,9 @@ nhits [--num-stacks NUM_STACKS] [--num-blocks NUM_BLOCKS] [--num-layers NUM_LAYE
| num_stacks | The number of stacks that make up the model | 3 | True | None |
| num_blocks | The number of blocks making up every stack | 1 | True | None |
| num_layers | The number of fully connected layers | 2 | True | None |
-| layer_widths | The number of neurons in each layer | 3 | True | None |
+| layer_widths | The number of neurons in each layer | 512 | True | None |
| activation | The desired activation | ReLU | True | ReLU, RReLU, PReLU, Softplus, Tanh, SELU, LeakyReLU, Sigmoid |
-| maxpool1d | Whether to use max_pool_1d or AvgPool1d | False | True | None |
+| maxpool1d | Whether to use max_pool_1d or AvgPool1d | True | True | None |
| past_covariates | Past covariates(columns/features) in same dataset. Comma separated. | None | True | None |
| all_past_covariates | Adds all rows as past covariates except for date and the target column. | False | True | None |
| naive | Show the naive baseline for a model. | False | True | None |
@@ -36,7 +36,7 @@ nhits [--num-stacks NUM_STACKS] [--num-blocks NUM_BLOCKS] [--num-layers NUM_LAYE
| output_chunk_length | The length of the forecast of the model. | 5 | True | None |
| force_reset | If set to True, any previously-existing model with the same name will be reset (all checkpoints will be discarded). | True | True | None |
| save_checkpoints | Whether to automatically save the untrained model and checkpoints. | True | True | None |
-| model_save_name | Name of the model to save. | tft_model | True | None |
+| model_save_name | Name of the model to save. | nhits_model | True | None |
| n_epochs | Number of epochs over which to train the model. | 300 | True | None |
| dropout | Fraction of neurons affected by Dropout, from 0 to 1. | 0.1 | True | None |
| batch_size | Number of time series (input and output) used in each training pass | 32 | True | None |