diff options
author | Theodore Aptekarev <aptekarev@gmail.com> | 2023-03-14 18:54:15 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-03-14 17:54:15 +0000 |
commit | a276dc4dfe61b9f8018e7675c72a5387c171c920 (patch) | |
tree | 818f95a5b76f014f3d1d5e5c650b26224f9a05c1 | |
parent | 7ec63f827b2fa0fa48ffeb36f0e4f318d2439308 (diff) |
Update reddit connection client (#4478)
* Change deprecated psaw dependency to the maintained pmaw
* Fix docstring linting errors
* Remove deprecated watchlist command because it's taking too long to run
* Remove deprecated spac command because it's taking too long to run
* Update SDK and trailmap
* Fix legacy seaborn chart popping up when calling redditsent
* Clean up spacc code
* Refactor reddit ba commands to reflect the dependency change
* Change regex that searches for tickers
* Fix popular command error when no stock tickers found
* Update tests
* Update requirements.txt files
* Sync dependencies with upstream
---------
Co-authored-by: James Maslek <jmaslek11@gmail.com>
-rw-r--r-- | openbb_terminal/common/behavioural_analysis/reddit_helpers.py | 6 | ||||
-rw-r--r-- | openbb_terminal/common/behavioural_analysis/reddit_model.py | 398 | ||||
-rw-r--r-- | openbb_terminal/common/behavioural_analysis/reddit_view.py | 147 | ||||
-rw-r--r-- | openbb_terminal/core/sdk/controllers/stocks_sdk_controller.py | 6 | ||||
-rw-r--r-- | openbb_terminal/core/sdk/models/stocks_sdk_model.py | 9 | ||||
-rw-r--r-- | openbb_terminal/core/sdk/trail_map.csv | 14 | ||||
-rw-r--r-- | openbb_terminal/miscellaneous/data_sources_default.json | 2 | ||||
-rw-r--r-- | openbb_terminal/stocks/behavioural_analysis/ba_controller.py | 131 | ||||
-rw-r--r-- | poetry.lock | 20489 | ||||
-rw-r--r-- | pyproject.toml | 2 | ||||
-rw-r--r-- | requirements-full.txt | 822 | ||||
-rw-r--r-- | requirements.txt | 587 | ||||
-rw-r--r-- | tests/openbb_terminal/stocks/behavioural_analysis/test_ba_controller.py | 23 | ||||
-rw-r--r-- | tests/openbb_terminal/stocks/behavioural_analysis/txt/test_ba_controller/test_print_help.txt | 2 |
14 files changed, 11071 insertions, 11567 deletions
diff --git a/openbb_terminal/common/behavioural_analysis/reddit_helpers.py b/openbb_terminal/common/behavioural_analysis/reddit_helpers.py index 40b9bcaebd1..7c087b91ca8 100644 --- a/openbb_terminal/common/behavioural_analysis/reddit_helpers.py +++ b/openbb_terminal/common/behavioural_analysis/reddit_helpers.py @@ -1,4 +1,4 @@ -"""Reddit Helpers""" +"""Reddit Helpers.""" __docformat__ = "numpy" import re @@ -8,7 +8,7 @@ import praw def find_tickers(submission: praw.models.reddit.submission.Submission) -> List[str]: - """Extracts potential tickers from reddit submission. + """Extract potential tickers from reddit submission. Parameters ---------- @@ -27,7 +27,7 @@ def find_tickers(submission: praw.models.reddit.submission.Submission) -> List[s l_tickers_found = [] for s_text in ls_text: - for s_ticker in set(re.findall(r"([A-Z]{3,5} )", s_text)): + for s_ticker in set(re.findall(r"([A-Z]{3,5})", s_text)): l_tickers_found.append(s_ticker.strip()) return l_tickers_found diff --git a/openbb_terminal/common/behavioural_analysis/reddit_model.py b/openbb_terminal/common/behavioural_analysis/reddit_model.py index df1e6defb23..ef983301047 100644 --- a/openbb_terminal/common/behavioural_analysis/reddit_model.py +++ b/openbb_terminal/common/behavioural_analysis/reddit_model.py @@ -1,4 +1,4 @@ -"""Reddit Model""" +"""Reddit Model.""" __docformat__ = "numpy" # pylint:disable=C0302 @@ -10,8 +10,8 @@ from typing import List, Tuple import finviz import pandas as pd import praw +from pmaw import PushshiftAPI from prawcore.exceptions import ResponseException -from psaw import PushshiftAPI from requests import HTTPError from sklearn.feature_extraction import _stop_words from tqdm import tqdm @@ -46,114 +46,6 @@ l_sub_reddits = [ "API_REDDIT_PASSWORD", ] ) -def get_watchlists( - limit: int = 5, -) -> Tuple[List[praw.models.reddit.submission.Submission], dict, int]: - """Get reddit users watchlists [Source: reddit]. - - Parameters - ---------- - limit : int - Number of posts to look through - - Returns - ------- - Tuple[List[praw.models.reddit.submission.Submission], dict, int] - List of reddit submissions, - Dictionary of tickers and their count, - Count of how many posts were analyzed. - """ - - current_user = get_current_user() - - d_watchlist_tickers: dict = {} - l_watchlist_author = [] - subs = [] - - praw_api = praw.Reddit( - client_id=current_user.credentials.API_REDDIT_CLIENT_ID, - client_secret=current_user.credentials.API_REDDIT_CLIENT_SECRET, - username=current_user.credentials.API_REDDIT_USERNAME, - user_agent=current_user.credentials.API_REDDIT_USER_AGENT, - password=current_user.credentials.API_REDDIT_PASSWORD, - check_for_updates=False, - comment_kind="t1", - message_kind="t4", - redditor_kind="t2", - submission_kind="t3", - subreddit_kind="t5", - trophy_kind="t6", - oauth_url="https://oauth.reddit.com", - reddit_url="https://www.reddit.com", - short_url="https://redd.it", - ratelimit_seconds=5, - timeout=16, - ) - try: - praw_api.user.me() - except (Exception, ResponseException): - console.print("[red]Wrong Reddit API keys[/red]\n") - return [], {}, 0 - - psaw_api = PushshiftAPI() - submissions = psaw_api.search_submissions( - subreddit=l_sub_reddits, - q="WATCHLIST|Watchlist|watchlist", - filter=["id"], - ) - n_flair_posts_found = 0 - - try: - for sub in submissions: - submission = praw_api.submission(id=sub.id) - if ( - not submission.removed_by_category - and submission.selftext - and submission.link_flair_text not in ["Yolo", "Meme"] - and submission.author.name not in l_watchlist_author - ): - l_tickers_found = find_tickers(submission) - - if l_tickers_found: - # Add another author's name to the parsed watchlists - l_watchlist_author.append(submission.author.name) - - # Lookup stock tickers within a watchlist - for key in l_tickers_found: - if key in d_watchlist_tickers: - # Increment stock ticker found - d_watchlist_tickers[key] += 1 - else: - # Initialize stock ticker found - d_watchlist_tickers[key] = 1 - - # Increment count of valid posts found - n_flair_posts_found += 1 - subs.append(submission) - if n_flair_posts_found > limit - 1: - break - - except ResponseException as e: - logger.exception("Invalid response: %s", str(e)) - - if "received 401 HTTP response" in str(e): - console.print("[red]Invalid API Key[/red]\n") - else: - console.print(f"[red]Invalid response: {str(e)}[/red]\n") - - return subs, d_watchlist_tickers, n_flair_posts_found - - -@log_start_end(log=logger) -@check_api_key( - [ - "API_REDDIT_CLIENT_ID", - "API_REDDIT_CLIENT_SECRET", - "API_REDDIT_USERNAME", - "API_REDDIT_USER_AGENT", - "API_REDDIT_PASSWORD", - ] -) def get_popular_tickers( limit: int = 10, post_limit: int = 50, subreddits: str = "" ) -> pd.DataFrame: @@ -207,43 +99,57 @@ def get_popular_tickers( console.print("[red]Wrong Reddit API keys[/red]\n") return pd.DataFrame() - psaw_api = PushshiftAPI() + pmaw_api = PushshiftAPI() for s_sub_reddit in sub_reddit_list: console.print( - f"Searching for latest tickers for {post_limit} '{s_sub_reddit}' posts" + f"Searching for tickers in latest {post_limit} '{s_sub_reddit}' posts" ) warnings.filterwarnings( "ignore", message=".*Not all PushShift shards are active.*" ) - submissions = psaw_api.search_submissions( + submissions = pmaw_api.search_submissions( subreddit=s_sub_reddit, limit=post_limit, filter=["id"], ) n_tickers = 0 - for submission in submissions: + for submission in submissions.responses: try: # Get more information about post using PRAW api - submission = praw_api.submission(id=submission.id) + submission_ = praw_api.submission(id=submission["id"]) # Ensure that the post hasn't been removed by moderator in the meanwhile, # that there is a description and it's not just an image, that the flair is # meaningful, and that we aren't re-considering same author's content + def has_author(submission_) -> bool: + """Check if submission has author.""" + return ( + hasattr(submission_, "author") + and hasattr(submission_.author, "name") + and submission_.author.name not in l_watchlist_author + ) + + def has_content(submission_) -> bool: + """Check if submission has text or title.""" + return hasattr(submission_, "selftext") or hasattr( + submission_, "title" + ) + if ( - submission is not None - and not submission.removed_by_category - and (submission.selftext or submission.title) - and submission.author.name not in l_watchlist_author + submission_ is not None + and not submission_.removed_by_category + and has_content(submission_) + and has_author(submission_) ): - l_tickers_found = find_tickers(submission) + l_tickers_found = find_tickers(submission_) if l_tickers_found: n_tickers += len(l_tickers_found) # Add another author's name to the parsed watchlists - l_watchlist_author.append(submission.author.name) + l_watchlist_author.append(submission_.author.name) # Lookup stock tickers within a watchlist for key in l_tickers_found: @@ -315,6 +221,8 @@ def get_popular_tickers( "URL", ], ) + else: + popular_tickers_df = pd.DataFrame() return popular_tickers_df @@ -346,7 +254,6 @@ def get_spac_community( Dataframe of reddit submission, Dictionary of tickers and number of mentions. """ - current_user = get_current_user() praw_api = praw.Reddit( @@ -470,154 +377,6 @@ def get_spac_community( "API_REDDIT_PASSWORD", ] ) -def get_spac( - limit: int = 5, -) -> Tuple[pd.DataFrame, dict, int]: - """Get posts containing SPAC from top subreddits [Source: reddit]. - - Parameters - ---------- - limit : int, optional - Number of posts to get for each subreddit, by default 5 - - Returns - ------- - Tuple[pd.DataFrame, dict, int] - Dataframe of reddit submission, - Dictionary of tickers and counts, - Number of posts found. - """ - - current_user = get_current_user() - - praw_api = praw.Reddit( - client_id=current_user.credentials.API_REDDIT_CLIENT_ID, - client_secret=current_user.credentials.API_REDDIT_CLIENT_SECRET, - username=current_user.credentials.API_REDDIT_USERNAME, - user_agent=current_user.credentials.API_REDDIT_USER_AGENT, - password=current_user.credentials.API_REDDIT_PASSWORD, - check_for_updates=False, - comment_kind="t1", - message_kind="t4", - redditor_kind="t2", - submission_kind="t3", - subreddit_kind="t5", - trophy_kind="t6", - oauth_url="https://oauth.reddit.com", - reddit_url="https://www.reddit.com", - short_url="https://redd.it", - ratelimit_seconds=5, - timeout=16, - ) - try: - praw_api.user.me() - except (Exception, ResponseException): - console.print("[red]Wrong Reddit API keys[/red]\n") - return pd.DataFrame(), {}, 0 - - d_watchlist_tickers: dict = {} - l_watchlist_author = [] - columns = [ - "Date", - "Subreddit", - "Flair", - "Title", - "Score", - "# Comments", - "Upvote %", - "Awards", - "Link", - ] - subs = pd.DataFrame(columns=columns) - psaw_api = PushshiftAPI() - submissions = psaw_api.search_submissions( - subreddit=l_sub_reddits, - q="SPAC|Spac|spac|Spacs|spacs", - filter=["id"], - ) - n_flair_posts_found = 0 - - try: - for submission in submissions: - # Get more information about post using PRAW api - submission = praw_api.submission(id=submission.id) - - # Ensure that the post hasn't been removed by moderator in the meanwhile, - # that there is a description and it's not just an image, that the flair is - # meaningful, and that we aren't re-considering same author's watchlist - if ( - not submission.removed_by_category - and submission.selftext - and submission.link_flair_text not in ["Yolo", "Meme"] - and submission.author.name not in l_watchlist_author - ): - l_tickers_found = find_tickers(submission) - - s_datetime = datetime.utcfromtimestamp(submission.created_utc).strftime( - "%Y-%m-%d %H:%M:%S" - ) - s_link = f"https://old.reddit.com{submission.permalink}" - s_all_awards = "".join( - f"{award['count']} {award['name']}\n" - for award in submission.all_awardings - ) - - s_all_awards = s_all_awards[:-2] - - data = [ - s_datetime, - submission.subreddit, - submission.link_flair_text, - submission.title, - submission.score, - submission.num_comments, - f"{round(100 * submission.upvote_ratio)}%", - s_all_awards, - s_link, - ] - subs.loc[len(subs)] = data - - if l_tickers_found: - # Add another author's name to the parsed watchlists - l_watchlist_author.append(submission.author.name) - - # Lookup stock tickers within a watchlist - for key in l_tickers_found: - if key in d_watchlist_tickers: - # Increment stock ticker found - d_watchlist_tickers[key] += 1 - else: - # Initialize stock ticker found - d_watchlist_tickers[key] = 1 - - # Increment count of valid posts found - n_flair_posts_found += 1 - - # Check if number of wanted posts found has been reached - if n_flair_posts_found > limit - 1: - break - - except ResponseException as e: - logger.exception("Invalid response: %s", str(e)) - - if "received 401 HTTP response" in str(e): - console.print("[red]Invalid API Key[/red]\n") - else: - console.print(f"[red]Invalid response: {str(e)}[/red]\n") - - return subs, d_watchlist_tickers, n_flair_posts_found - - -@log_start_end(log=logger) -@check_api_key( - [ - "API_REDDIT_CLIENT_ID", - "API_REDDIT_CLIENT_SECRET", - "API_REDDIT_USERNAME", - "API_REDDIT_USER_AGENT", - "API_REDDIT_PASSWORD", - ] -) def get_wsb_community(limit: int = 10, new: bool = False) -> pd.DataFrame: """Get wsb posts [Source: reddit]. @@ -633,7 +392,6 @@ def get_wsb_community(limit: int = 10, new: bool = False) -> pd.DataFrame: pd.DataFrame Dataframe of reddit submissions """ - current_user = get_current_user() # See https://github.com/praw-dev/praw/issues/1016 regarding praw arguments @@ -682,30 +440,30 @@ def get_wsb_community(limit: int = 10, new: bool = False) -> pd.DataFrame: try: for submission in submissions: - submission = praw_api.submission(id=submission.id) + submission_ = praw_api.submission(submission.id) # Ensure that the post hasn't been removed by moderator in the meanwhile, # that there is a description and it's not just an image, that the flair is # meaningful, and that we aren't re-considering same author's watchlist - if not submission.removed_by_category: - s_datetime = datetime.utcfromtimestamp(submission.created_utc).strftime( - "%Y-%m-%d %H:%M:%S" - ) - s_link = f"https://old.reddit.com{submission.permalink}" + if not submission_.removed_by_category: + s_datetime = datetime.utcfromtimestamp( + submission_.created_utc + ).strftime("%Y-%m-%d %H:%M:%S") + s_link = f"https://old.reddit.com{submission_.permalink}" s_all_awards = "".join( f"{award['count']} {award['name']}\n" - for award in submission.all_awardings + for award in submission_.all_awardings ) s_all_awards = s_all_awards[:-2] data = [ s_datetime, - submission.subreddit, - submission.link_flair_text, - submission.title, - submission.score, - submission.num_comments, - f"{round(100 * submission.upvote_ratio)}%", + submission_.subreddit, + submission_.link_flair_text, + submission_.title, + submission_.score, + submission_.num_comments, + f"{round(100 * submission_.upvote_ratio)}%", s_all_awards, s_link, ] @@ -731,14 +489,12 @@ def get_wsb_community(limit: int = 10, new: bool = False) -> pd.DataFrame: ] ) def get_due_dilligence( - symbol: str, limit: int = 5, n_days: int = 3, show_all_flairs: bool = False + limit: int = 5, n_days: int = 3, show_all_flairs: bool = False ) -> pd.DataFrame: - """Gets due diligence posts from list of subreddits [Source: reddit]. + """Get due diligence posts from list of subreddits [Source: reddit]. Parameters ---------- - symbol: str - Stock ticker limit: int Number of posts to get n_days: int @@ -751,7 +507,6 @@ def get_due_dilligence( pd.DataFrame Dataframe of submissions """ - current_user = get_current_user() praw_api = praw.Reddit( @@ -779,7 +534,7 @@ def get_due_dilligence( console.print("[red]Wrong Reddit API keys[/red]\n") return pd.DataFrame() - psaw_api = PushshiftAPI() + pmaw_api = PushshiftAPI() n_ts_after = int((datetime.today() - timedelta(days=n_days)).timestamp()) l_flair_text = [ @@ -806,8 +561,10 @@ def get_due_dilligence( "Forexstrategy", ] - submissions = psaw_api.search_submissions( - after=int(n_ts_after), subreddit=l_sub_reddits_dd, q=symbol, filter=["id"] + submissions = pmaw_api.search_submissions( + after=int(n_ts_after), + subreddit=l_sub_reddits_dd, + filter=["id"], ) n_flair_posts_found = 0 columns = [ @@ -824,36 +581,36 @@ def get_due_dilligence( subs = pd.DataFrame(columns=columns) try: - for submission in submissions: + for submission in submissions.responses: # Get more information about post using PRAW api - submission = praw_api.submission(id=submission.id) + submission_ = praw_api.submission(id=submission["id"]) # Ensure that the post hasn't been removed in the meanwhile # Either just filter out Yolo, and Meme flairs, or focus on DD, based on b_DD flag if ( - not submission.removed_by_category - and submission.link_flair_text in l_flair_text, - submission.link_flair_text not in ["Yolo", "Meme"], + not submission_.removed_by_category + and submission_.link_flair_text in l_flair_text, + submission_.link_flair_text not in ["Yolo", "Meme"], )[show_all_flairs]: - s_datetime = datetime.utcfromtimestamp(submission.created_utc).strftime( - "%Y-%m-%d %H:%M:%S" - ) - s_link = f"https://old.reddit.com{submission.permalink}" + s_datetime = datetime.utcfromtimestamp( + submission_.created_utc + ).strftime("%Y-%m-%d %H:%M:%S") + s_link = f"https://old.reddit.com{submission_.permalink}" s_all_awards = "".join( f"{award['count']} {award['name']}\n" - for award in submission.all_awardings + for award in submission_.all_awardings ) s_all_awards = s_all_awards[:-2] data = [ s_datetime, - submission.subreddit, - submission.link_flair_text, - submission.title, - submission.score, - submission.num_comments, - f"{round(100 * submission.upvote_ratio)}%", + submission_.subreddit, + submission_.link_flair_text, + submission_.title, + submission_.score, + submission_.num_comments, + f"{round(100 * submission_.upvote_ratio)}%", s_all_awards, s_link, ] @@ -892,7 +649,7 @@ def get_posts_about( full_search: bool = True, subreddits: str = "all", ) -> Tuple[pd.DataFrame, list, float]: - """Finds posts related to a specific search term in Reddit. + """Find posts related to a specific search term in Reddit. Parameters ---------- @@ -916,7 +673,6 @@ def get_posts_about( List of polarity scores, Average polarity score. """ - current_user = get_current_user() praw_api = praw.Reddit( @@ -976,8 +732,8 @@ def get_posts_about( for p in tqdm(posts): texts = [p.title, p.selftext] if full_search: - tlcs = get_comments(p) - texts.extend(tlcs) + top_level_comments = get_comments(p) + texts.extend(top_level_comments) preprocessed_text = clean_reddit_text(texts) sentiment = get_sentiment(preprocessed_text) polarity_scores.append(sentiment) @@ -1018,14 +774,14 @@ def get_comments( """ def get_more_comments(comments): - sub_tlcs = [] + sub_top_level_comments = [] for comment in comments: if isinstance(comment, praw.models.reddit.comment.Comment): - sub_tlcs.append(comment.body) + sub_top_level_comments.append(comment.body) else: sub_comments = get_more_comments(comment.comments()) - sub_tlcs.extend(sub_comments) - return sub_tlcs + sub_top_level_comments.extend(sub_comments) + return sub_top_level_comments if post.comments: return get_more_comments(post.comments) @@ -1034,7 +790,7 @@ def get_comments( @log_start_end(log=logger) def clean_reddit_text(docs: List[str]) -> List[str]: - """Tokenizes and cleans a list of documents for sentiment analysis. + """Tokenize and clean a list of documents for sentiment analysis. Parameters ---------- @@ -1046,7 +802,7 @@ def clean_reddit_text(docs: List[str]) -> List[str]: list[str] List of cleaned and prepared docs """ - stopwords = _stop_words.ENGLISH_STOP_WORDS + stop_words = _stop_words.ENGLISH_STOP_WORDS clean_docs = [] docs = [doc.lower().strip() for doc in docs] @@ -1055,9 +811,9 @@ def clean_reddit_text(docs: List[str]) -> List[str]: tokens = doc.split() for tok in tokens: clean_tok = [c for c in tok if c.isalpha()] - tok = "".join(clean_tok) - if tok not in stopwords: - clean_doc.append(tok) + tok_ = "".join(clean_tok) + if tok_ not in stop_words: + clean_doc.append(tok_) clean_docs.append(" ".join(clean_doc)) return clean_docs diff --git a/openbb_terminal/common/behavioural_analysis/reddit_view.py b/openbb_terminal/common/behavioural_analysis/reddit_view.py index 02b183e7d76..0019e36119e 100644 --- a/openbb_terminal/common/behavioural_analysis/reddit_view.py +++ b/openbb_terminal/common/behavioural_analysis/reddit_view.py @@ -1,24 +1,19 @@ -"""Reddit View""" +"""Reddit View.""" __docformat__ = "numpy" import logging import os -import warnings from datetime import datetime from typing import Dict, Optional, Union import finviz -import matplotlib.pyplot as plt import pandas as pd import praw -import seaborn as sns from openbb_terminal import OpenBBFigure from openbb_terminal.common.behavioural_analysis import reddit_model -from openbb_terminal.config_terminal import theme -from openbb_terminal.core.session.current_user import get_current_user from openbb_terminal.decorators import check_api_key, log_start_end -from openbb_terminal.helper_funcs import export_data, plot_autoscale, print_rich_table +from openbb_terminal.helper_funcs import export_data, print_rich_table from openbb_terminal.rich_config import console # pylint: disable=R0913,C0302 @@ -41,7 +36,7 @@ logger = logging.getLogger(__name__) def print_and_record_reddit_post( submissions_dict: Dict, submission: praw.models.reddit.submission.Submission ): - """Prints reddit submission. + """Print reddit submission. Parameters ---------- @@ -101,14 +96,13 @@ def print_and_record_reddit_post( ] ) def print_reddit_post(sub: tuple): - """Prints reddit submission. + """Print reddit submission. Parameters ---------- sub : tuple Row from submissions dataframe """ - sub_list = list(sub[1]) date = sub_list[0] title = sub_list[3] @@ -142,54 +136,6 @@ def print_reddit_post(sub: tuple): "API_REDDIT_PASSWORD", ] ) -def display_watchlist(limit: int = 5): - """Prints other users watchlist. [Source: Reddit]. - - Parameters - ---------- - limit: int - Maximum number of submissions to look at - """ - subs, d_watchlist_tickers, n_flair_posts_found = reddit_model.get_watchlists(limit) - if subs: - for sub in subs: - print_and_record_reddit_post({}, sub) - console.print("") - - if n_flair_posts_found > 0: - lt_watchlist_sorted = sorted( - d_watchlist_tickers.items(), key=lambda item: item[1], reverse=True - ) - s_watchlist_tickers = "" - n_tickers = 0 - for t_ticker in lt_watchlist_sorted: - try: - # If try doesn't trigger exception, it means that this stock exists on finviz - # thus we can print it. - finviz.get_stock(t_ticker[0]) - if int(t_ticker[1]) > 1: - s_watchlist_tickers += f"{t_ticker[1]} {t_ticker[0]}, " - n_tickers += 1 - except Exception: - # console.print(e, "\n") |