diff options
author | DidierRLopes <dro.lopes@campus.fct.unl.pt> | 2021-09-23 22:04:38 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2021-09-23 22:04:38 +0100 |
commit | 942219bd6e0dbf7a49b5d6f288f14fce86d343d1 (patch) | |
tree | ba1429071b8c88a1c0c667e8aa3c6e3e87cf93bd | |
parent | a73375a3862a19902db8f1455828d88229d41ce3 (diff) |
Fix tsne command #744 (#759)
* Fix #755 ba/popularsi bug
* Fix tnse command - #744
* Inner join and re-add normalization
-rw-r--r-- | gamestonk_terminal/stocks/comparison_analysis/yahoo_finance_model.py | 21 |
1 files changed, 11 insertions, 10 deletions
diff --git a/gamestonk_terminal/stocks/comparison_analysis/yahoo_finance_model.py b/gamestonk_terminal/stocks/comparison_analysis/yahoo_finance_model.py index e98f5032e3f..6e4e7b318e4 100644 --- a/gamestonk_terminal/stocks/comparison_analysis/yahoo_finance_model.py +++ b/gamestonk_terminal/stocks/comparison_analysis/yahoo_finance_model.py @@ -99,22 +99,23 @@ def get_sp500_comps_tsne( # Adding the type makes pylint stop yelling close_vals: pd.DataFrame = get_1y_sp500() if ticker not in close_vals.columns: - close_vals[ticker] = yf.download( - ticker, start=close_vals.index[0], progress=False - )["Adj Close"] - rets = ( - close_vals.fillna(method="ffill").fillna(method="bfill").pct_change().dropna().T - ) - normalized_movements = normalize(rets) - companies = rets.index + df_ticker = yf.download(ticker, start=close_vals.index[0], progress=False)[ + "Adj Close" + ].to_frame() + df_ticker.columns = [ticker] + close_vals = close_vals.join(df_ticker, how="inner") + + close_vals = close_vals.dropna(how="all").fillna(method="bfill") + rets = close_vals.pct_change()[1:].T + model = TSNE(learning_rate=lr) - tsne_features = model.fit_transform(normalized_movements) + tsne_features = model.fit_transform(normalize(rets)) xs = tsne_features[:, 0] ys = tsne_features[:, 1] if not no_plot: fig, ax = plt.subplots(figsize=plot_autoscale(), dpi=PLOT_DPI) ax.scatter(xs, ys, alpha=0.5) - for x, y, company in zip(xs, ys, companies): + for x, y, company in zip(xs, ys, rets.index): if company != ticker: ax.annotate(company, (x, y), fontsize=9, alpha=0.75) else: |