diff options
Diffstat (limited to 'openbb_terminal/forecast/helpers.py')
-rw-r--r-- | openbb_terminal/forecast/helpers.py | 19 |
1 files changed, 9 insertions, 10 deletions
diff --git a/openbb_terminal/forecast/helpers.py b/openbb_terminal/forecast/helpers.py index 7c5d4193963..0e6dbcb8d6c 100644 --- a/openbb_terminal/forecast/helpers.py +++ b/openbb_terminal/forecast/helpers.py @@ -782,7 +782,7 @@ def plotly_shap_scatter_plot( fig.set_xaxis_title("SHAP value (impact on model output)") for pos, i in enumerate(feature_order): - pos += 2 + new_pos = pos + 2 shaps = shap_values[:, i] values = None if features is None else features[:, i] inds = np.arange(len(shaps)) @@ -838,7 +838,7 @@ def plotly_shap_scatter_plot( nan_mask = np.isnan(values) fig.add_scattergl( x=shaps[nan_mask], - y=pos + ys[nan_mask], + y=new_pos + ys[nan_mask], mode="markers", marker=dict( color="#777777", @@ -860,12 +860,12 @@ def plotly_shap_scatter_plot( fig.add_scattergl( x=shaps[np.invert(nan_mask)], - y=pos + ys[np.invert(nan_mask)], + y=new_pos + ys[np.invert(nan_mask)], mode="markers", marker=dict( color=cvals, colorscale="Bluered", - showscale=bool(pos == 2), + showscale=bool(new_pos == 2), colorbar=dict( x=-0.05, thickness=10, @@ -1119,13 +1119,12 @@ def get_prediction( past_covariates=past_covariate_whole, n=n_predict, ) + elif probabilistic: + prediction = best_model.predict( + series=ticker_series, n=n_predict, num_samples=500 + ) else: - if probabilistic: - prediction = best_model.predict( - series=ticker_series, n=n_predict, num_samples=500 - ) - else: - prediction = best_model.predict(series=ticker_series, n=n_predict) + prediction = best_model.predict(series=ticker_series, n=n_predict) # calculate precision based on metric (rmse, mse, mape) if metric == "rmse": |