summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorDidierRLopes <dro.lopes@campus.fct.unl.pt>2021-09-23 22:04:38 +0100
committerGitHub <noreply@github.com>2021-09-23 22:04:38 +0100
commit942219bd6e0dbf7a49b5d6f288f14fce86d343d1 (patch)
treeba1429071b8c88a1c0c667e8aa3c6e3e87cf93bd
parenta73375a3862a19902db8f1455828d88229d41ce3 (diff)
Fix tsne command #744 (#759)
* Fix #755 ba/popularsi bug * Fix tnse command - #744 * Inner join and re-add normalization
-rw-r--r--gamestonk_terminal/stocks/comparison_analysis/yahoo_finance_model.py21
1 files changed, 11 insertions, 10 deletions
diff --git a/gamestonk_terminal/stocks/comparison_analysis/yahoo_finance_model.py b/gamestonk_terminal/stocks/comparison_analysis/yahoo_finance_model.py
index e98f5032e3f..6e4e7b318e4 100644
--- a/gamestonk_terminal/stocks/comparison_analysis/yahoo_finance_model.py
+++ b/gamestonk_terminal/stocks/comparison_analysis/yahoo_finance_model.py
@@ -99,22 +99,23 @@ def get_sp500_comps_tsne(
# Adding the type makes pylint stop yelling
close_vals: pd.DataFrame = get_1y_sp500()
if ticker not in close_vals.columns:
- close_vals[ticker] = yf.download(
- ticker, start=close_vals.index[0], progress=False
- )["Adj Close"]
- rets = (
- close_vals.fillna(method="ffill").fillna(method="bfill").pct_change().dropna().T
- )
- normalized_movements = normalize(rets)
- companies = rets.index
+ df_ticker = yf.download(ticker, start=close_vals.index[0], progress=False)[
+ "Adj Close"
+ ].to_frame()
+ df_ticker.columns = [ticker]
+ close_vals = close_vals.join(df_ticker, how="inner")
+
+ close_vals = close_vals.dropna(how="all").fillna(method="bfill")
+ rets = close_vals.pct_change()[1:].T
+
model = TSNE(learning_rate=lr)
- tsne_features = model.fit_transform(normalized_movements)
+ tsne_features = model.fit_transform(normalize(rets))
xs = tsne_features[:, 0]
ys = tsne_features[:, 1]
if not no_plot:
fig, ax = plt.subplots(figsize=plot_autoscale(), dpi=PLOT_DPI)
ax.scatter(xs, ys, alpha=0.5)
- for x, y, company in zip(xs, ys, companies):
+ for x, y, company in zip(xs, ys, rets.index):
if company != ticker:
ax.annotate(company, (x, y), fontsize=9, alpha=0.75)
else: