summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authormartinb-bb <105685594+martinb-bb@users.noreply.github.com>2022-12-16 12:00:29 -0500
committerGitHub <noreply@github.com>2022-12-16 12:00:29 -0500
commit4c3c87d11b89893213927e27c3b7edc7214d2a0e (patch)
treec1ef8bc17ae0381d558b581b3730f3dfe56920b6
parentbc48e8051a7b0515a96c406a5f121e8b9dc3ee9e (diff)
hot fix: True Range Feature (#3793)
* fix atr feature eng * fix Co-authored-by: James Maslek <jmaslek11@gmail.com>
-rw-r--r--openbb_terminal/forecast/forecast_controller.py6
-rw-r--r--openbb_terminal/forecast/forecast_model.py6
2 files changed, 5 insertions, 7 deletions
diff --git a/openbb_terminal/forecast/forecast_controller.py b/openbb_terminal/forecast/forecast_controller.py
index ff252507bfe..879ac2ca8e0 100644
--- a/openbb_terminal/forecast/forecast_controller.py
+++ b/openbb_terminal/forecast/forecast_controller.py
@@ -1556,14 +1556,14 @@ class ForecastController(BaseController):
if not helpers.check_parser_input(ns_parser, self.datasets):
return
- check = False
- self.datasets[ns_parser.target_dataset], check = forecast_model.add_atr(
+ self.datasets[ns_parser.target_dataset] = forecast_model.add_atr(
self.datasets[ns_parser.target_dataset],
close_column=ns_parser.close_col,
high_column=ns_parser.high_col,
low_column=ns_parser.low_col,
)
- if check:
+ # check if true range was added
+ if "true_range" in self.datasets[ns_parser.target_dataset].columns:
console.print(
f"Successfully added 'Average True Range' to '{ns_parser.target_dataset}' dataset"
)
diff --git a/openbb_terminal/forecast/forecast_model.py b/openbb_terminal/forecast/forecast_model.py
index b87e4d8c31e..06f2f687a41 100644
--- a/openbb_terminal/forecast/forecast_model.py
+++ b/openbb_terminal/forecast/forecast_model.py
@@ -362,7 +362,7 @@ def add_atr(
Calculate the Average True Range of a variable based on a a specific stock ticker.
"""
- if "high" in dataset and "low" in dataset and "close" in dataset:
+ if close_column in dataset and high_column in dataset and low_column in dataset:
dataset["ATR1"] = abs(dataset[high_column] - dataset[low_column])
dataset["ATR2"] = abs(dataset[high_column] - dataset[close_column].shift())
dataset["ATR3"] = abs(dataset[low_column] - dataset[close_column].shift())
@@ -371,9 +371,7 @@ def add_atr(
# drop ATR1, ATR2, ATR3
dataset = dataset.drop(["ATR1", "ATR2", "ATR3"], axis=1)
- return dataset, True
-
- return dataset, False
+ return dataset
@log_start_end(log=logger)