summaryrefslogtreecommitdiffstats
path: root/openbb_terminal/forecast/helpers.py
diff options
context:
space:
mode:
Diffstat (limited to 'openbb_terminal/forecast/helpers.py')
-rw-r--r--openbb_terminal/forecast/helpers.py19
1 files changed, 9 insertions, 10 deletions
diff --git a/openbb_terminal/forecast/helpers.py b/openbb_terminal/forecast/helpers.py
index 7c5d4193963..0e6dbcb8d6c 100644
--- a/openbb_terminal/forecast/helpers.py
+++ b/openbb_terminal/forecast/helpers.py
@@ -782,7 +782,7 @@ def plotly_shap_scatter_plot(
fig.set_xaxis_title("SHAP value (impact on model output)")
for pos, i in enumerate(feature_order):
- pos += 2
+ new_pos = pos + 2
shaps = shap_values[:, i]
values = None if features is None else features[:, i]
inds = np.arange(len(shaps))
@@ -838,7 +838,7 @@ def plotly_shap_scatter_plot(
nan_mask = np.isnan(values)
fig.add_scattergl(
x=shaps[nan_mask],
- y=pos + ys[nan_mask],
+ y=new_pos + ys[nan_mask],
mode="markers",
marker=dict(
color="#777777",
@@ -860,12 +860,12 @@ def plotly_shap_scatter_plot(
fig.add_scattergl(
x=shaps[np.invert(nan_mask)],
- y=pos + ys[np.invert(nan_mask)],
+ y=new_pos + ys[np.invert(nan_mask)],
mode="markers",
marker=dict(
color=cvals,
colorscale="Bluered",
- showscale=bool(pos == 2),
+ showscale=bool(new_pos == 2),
colorbar=dict(
x=-0.05,
thickness=10,
@@ -1119,13 +1119,12 @@ def get_prediction(
past_covariates=past_covariate_whole,
n=n_predict,
)
+ elif probabilistic:
+ prediction = best_model.predict(
+ series=ticker_series, n=n_predict, num_samples=500
+ )
else:
- if probabilistic:
- prediction = best_model.predict(
- series=ticker_series, n=n_predict, num_samples=500
- )
- else:
- prediction = best_model.predict(series=ticker_series, n=n_predict)
+ prediction = best_model.predict(series=ticker_series, n=n_predict)
# calculate precision based on metric (rmse, mse, mape)
if metric == "rmse":