diff options
author | martinb-bb <105685594+martinb-bb@users.noreply.github.com> | 2022-12-16 12:00:29 -0500 |
---|---|---|
committer | GitHub <noreply@github.com> | 2022-12-16 12:00:29 -0500 |
commit | 4c3c87d11b89893213927e27c3b7edc7214d2a0e (patch) | |
tree | c1ef8bc17ae0381d558b581b3730f3dfe56920b6 | |
parent | bc48e8051a7b0515a96c406a5f121e8b9dc3ee9e (diff) |
hot fix: True Range Feature (#3793)
* fix atr feature eng
* fix
Co-authored-by: James Maslek <jmaslek11@gmail.com>
-rw-r--r-- | openbb_terminal/forecast/forecast_controller.py | 6 | ||||
-rw-r--r-- | openbb_terminal/forecast/forecast_model.py | 6 |
2 files changed, 5 insertions, 7 deletions
diff --git a/openbb_terminal/forecast/forecast_controller.py b/openbb_terminal/forecast/forecast_controller.py index ff252507bfe..879ac2ca8e0 100644 --- a/openbb_terminal/forecast/forecast_controller.py +++ b/openbb_terminal/forecast/forecast_controller.py @@ -1556,14 +1556,14 @@ class ForecastController(BaseController): if not helpers.check_parser_input(ns_parser, self.datasets): return - check = False - self.datasets[ns_parser.target_dataset], check = forecast_model.add_atr( + self.datasets[ns_parser.target_dataset] = forecast_model.add_atr( self.datasets[ns_parser.target_dataset], close_column=ns_parser.close_col, high_column=ns_parser.high_col, low_column=ns_parser.low_col, ) - if check: + # check if true range was added + if "true_range" in self.datasets[ns_parser.target_dataset].columns: console.print( f"Successfully added 'Average True Range' to '{ns_parser.target_dataset}' dataset" ) diff --git a/openbb_terminal/forecast/forecast_model.py b/openbb_terminal/forecast/forecast_model.py index b87e4d8c31e..06f2f687a41 100644 --- a/openbb_terminal/forecast/forecast_model.py +++ b/openbb_terminal/forecast/forecast_model.py @@ -362,7 +362,7 @@ def add_atr( Calculate the Average True Range of a variable based on a a specific stock ticker. """ - if "high" in dataset and "low" in dataset and "close" in dataset: + if close_column in dataset and high_column in dataset and low_column in dataset: dataset["ATR1"] = abs(dataset[high_column] - dataset[low_column]) dataset["ATR2"] = abs(dataset[high_column] - dataset[close_column].shift()) dataset["ATR3"] = abs(dataset[low_column] - dataset[close_column].shift()) @@ -371,9 +371,7 @@ def add_atr( # drop ATR1, ATR2, ATR3 dataset = dataset.drop(["ATR1", "ATR2", "ATR3"], axis=1) - return dataset, True - - return dataset, False + return dataset @log_start_end(log=logger) |