diff options
Diffstat (limited to 'cli/openbb_cli/controllers/utils.py')
-rw-r--r-- | cli/openbb_cli/controllers/utils.py | 1056 |
1 files changed, 1056 insertions, 0 deletions
diff --git a/cli/openbb_cli/controllers/utils.py b/cli/openbb_cli/controllers/utils.py new file mode 100644 index 00000000000..a85751a58fa --- /dev/null +++ b/cli/openbb_cli/controllers/utils.py @@ -0,0 +1,1056 @@ +"""Utils.""" + +import argparse +import os +import random +import re +import shutil +import sys +from contextlib import contextmanager +from datetime import ( + datetime, +) +from pathlib import Path +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union + +import numpy as np +import pandas as pd +import requests +from openbb import obb +from openbb_charting.core.backend import create_backend, get_backend +from openbb_cli.config.constants import AVAILABLE_FLAIRS, ENV_FILE_SETTINGS +from openbb_cli.session import Session +from openbb_core.app.model.charts.charting_settings import ChartingSettings +from packaging import version +from pytz import all_timezones, timezone +from rich.table import Table + +if TYPE_CHECKING: + from openbb_charting.core.openbb_figure import OpenBBFigure + +# pylint: disable=R1702,R0912 + + +# pylint: disable=too-many-statements,no-member,too-many-branches,C0302 + + +def remove_file(path: Path) -> bool: + """Remove path. + + Parameters + ---------- + path : Path + The file path. + + Returns + ------- + bool + The status of the removal. + """ + # TODO: Check why module level import leads to circular import. + try: + if os.path.isfile(path): + os.remove(path) + elif os.path.isdir(path): + shutil.rmtree(path) + return True + except Exception: + Session().console.print( + f"\n[bold red]Failed to remove {path}" + "\nPlease delete this manually![/bold red]" + ) + return False + + +def print_goodbye(): + """Print a goodbye message when quitting the terminal.""" + # LEGACY GOODBYE MESSAGES - You'll live in our hearts forever. + # "An informed ape, is a strong ape." + # "Remember that stonks only go up." + # "Diamond hands." + # "Apes together strong." + # "This is our way." + # "Keep the spacesuit ape, we haven't reached the moon yet." + # "I am not a cat. I'm an ape." + # "We like the terminal." + # "...when offered a flight to the moon, nobody asks about what seat." + + text = """ +[param]Thank you for using the OpenBB Platform CLI and being part of this journey.[/param] + +We hope you'll find the new CLI as valuable as this. To stay tuned, sign up for our newsletter: [cmds]https://openbb.co/newsletter.[/] + +In the meantime, check out our other products: + +[bold]OpenBB Terminal Pro[/]: [cmds]https://openbb.co/products/pro[/cmds] +[bold]OpenBB Platform:[/] [cmds]https://openbb.co/products/platform[/cmds] +[bold]OpenBB Bot[/]: [cmds]https://openbb.co/products/bot[/cmds] + """ + Session().console.print(text) + + +def hide_splashscreen(): + """Hide the splashscreen on Windows bundles. + + `pyi_splash` is a PyInstaller "fake-package" that's used to communicate + with the splashscreen on Windows. + Sending the `close` signal to the splash screen is required. + The splash screen remains open until this function is called or the Python + program is terminated. + """ + try: + import pyi_splash # type: ignore # pylint: disable=import-outside-toplevel + + pyi_splash.update_text("CLI Loaded!") + pyi_splash.close() + except Exception as e: + Session().console.print(f"Error: Unable to hide splashscreen: {e}") + + +def print_guest_block_msg(): + """Block guest users from using the cli.""" + if Session().is_local(): + Session().console.print( + "[info]You are currently logged as a guest.[/info]\n" + "[info]Login to use this feature.[/info]\n\n" + "[info]If you don't have an account, you can create one here: [/info]" + f"[cmds]{Session().settings.HUB_URL + '/register'}\n[/cmds]" + ) + + +def is_installer() -> bool: + """Check whether or not it is a packaged version (Windows or Mac installer.""" + return getattr(sys, "frozen", False) and hasattr(sys, "_MEIPASS") + + +def bootup(): + """Bootup the cli.""" + if sys.platform == "win32": + # Enable VT100 Escape Sequence for WINDOWS 10 Ver. 1607 + os.system("") # nosec # noqa: S605,S607 + # Hide splashscreen loader of the packaged app + if is_installer(): + hide_splashscreen() + + try: + if os.name == "nt": + # pylint: disable=E1101 + sys.stdin.reconfigure(encoding="utf-8") # type: ignore + # pylint: disable=E1101 + sys.stdout.reconfigure(encoding="utf-8") # type: ignore + except Exception as e: + Session().console.print(e, "\n") + + +def check_for_updates() -> None: + """Check if the latest version is running. + + Checks github for the latest release version and compares it to cfg.VERSION. + """ + # The commit has was commented out because the terminal was crashing due to git import for multiple users + # ({str(git.Repo('.').head.commit)[:7]}) + try: + r = request( + "https://api.github.com/repos/openbb-finance/openbbterminal/releases/latest" + ) + except Exception: + r = None + + if r and r.status_code == 200: + latest_tag_name = r.json()["tag_name"] + latest_version = version.parse(latest_tag_name) + current_version = version.parse(Session().settings.VERSION) + + if check_valid_versions(latest_version, current_version): + if current_version == latest_version: + Session().console.print( + "[green]You are using the latest stable version[/green]" + ) + else: + Session().console.print( + "[yellow]You are not using the latest stable version[/yellow]" + ) + if current_version < latest_version: + Session().console.print( + "[yellow]Check for updates at https://my.openbb.co/app/terminal/download[/yellow]" + ) + + else: + Session().console.print( + "[yellow]You are using an unreleased version[/yellow]" + ) + + else: + Session().console.print("[red]You are using an unrecognized version.[/red]") + else: + Session().console.print( + "[yellow]Unable to check for updates... " + + "Check your internet connection and try again...[/yellow]" + ) + Session().console.print("\n") + + +def check_valid_versions( + latest_version: version.Version, + current_version: version.Version, +) -> bool: + """Check if the versions are valid.""" + if ( + not latest_version + or not current_version + or not isinstance(latest_version, version.Version) + or not isinstance(current_version, version.Version) + ): + return False + return True + + +def welcome_message(): + """Print the welcome message. + + Prints first welcome message, help and a notification if updates are available. + """ + Session().console.print( + f"\nWelcome to OpenBB Platform CLI v{Session().settings.VERSION}" + ) + + +def reset(queue: Optional[List[str]] = None): + """Reset the CLI. + + Allows for checking code without quitting. + """ + Session().console.print("resetting...") + Session().reset() + debug = Session().settings.DEBUG_MODE + dev = Session().settings.DEV_BACKEND + + try: + # remove the hub routines + if not Session().is_local(): + remove_file( + Path(Session().user.preferences.export_directory, "routines", "hub") + ) + + # if not get_current_user().profile.remember: + # Local.remove(HIST_FILE_PROMPT) + + # we clear all openbb_cli modules from sys.modules + for module in list(sys.modules.keys()): + parts = module.split(".") + if parts[0] == "openbb_cli": + del sys.modules[module] + + queue_list = ["/".join(queue) if len(queue) > 0 else ""] # type: ignore + # pylint: disable=import-outside-toplevel + # we run the cli again + if Session().is_local(): + from openbb_cli.controllers.cli_controller import main + + main(debug, dev, queue_list, module="") # type: ignore + else: + from openbb_cli.controllers.cli_controller import launch + + launch(queue=queue_list) + + except Exception as e: + Session().console.print(f"Unfortunately, resetting wasn't possible: {e}\n") + print_goodbye() + + +@contextmanager +def suppress_stdout(): + """Suppress the stdout.""" + with open(os.devnull, "w") as devnull: + old_stdout = sys.stdout + old_stderr = sys.stderr + sys.stdout = devnull + sys.stderr = devnull + try: + yield + finally: + sys.stdout = old_stdout + sys.stderr = old_stderr + + +def first_time_user() -> bool: + """Check whether a user is a first time user. + + A first time user is someone with an empty .env file. + If this is true, it also adds an env variable to make sure this does not run again. + + Returns + ------- + bool + Whether or not the user is a first time user + """ + if ENV_FILE_SETTINGS.stat().st_size == 0: + Session().settings.set_item("PREVIOUS_USE", True) + return True + return False + + +def parse_and_split_input(an_input: str, custom_filters: List) -> List[str]: + """Filter and split the input queue. + + Uses regex to filters command arguments that have forward slashes so that it doesn't + break the execution of the command queue. + Currently handles unix paths and sorting settings for screener menus. + + Parameters + ---------- + an_input : str + User input as string + custom_filters : List + Additional regular expressions to match + + Returns + ------- + List[str] + Command queue as list + """ + # Make sure that the user can go back to the root when doing "/" + if an_input and an_input == "/": + an_input = "home" + + # everything from ` -f ` to the next known extension + file_flag = r"(\ -f |\ --file )" + up_to = r".*?" + known_extensions = r"(\.(xlsx|csv|xls|tsv|json|yaml|ini|openbb|ipynb))" + unix_path_arg_exp = f"({file_flag}{up_to}{known_extensions})" + + # Add custom expressions to handle edge cases of individual controllers + custom_filter = "" + for exp in custom_filters: + if exp is not None: + custom_filter += f"|{exp}" + del exp + + slash_filter_exp = f"({unix_path_arg_exp}){custom_filter}" + + filter_input = True + placeholders: Dict[str, str] = {} + while filter_input: + match = re.search(pattern=slash_filter_exp, string=an_input) + if match is not None: + placeholder = f"{{placeholder{len(placeholders)+1}}}" + placeholders[placeholder] = an_input[ + match.span()[0] : match.span()[1] # noqa:E203 + ] + an_input = ( + an_input[: match.span()[0]] + + placeholder + + an_input[match.span()[1] :] # noqa:E203 + ) + else: + filter_input = False + + commands = an_input.split("/") if "-t" not in an_input else [an_input] + + for command_num, command in enumerate(commands): + if command == commands[command_num] == commands[-1] == "": + return list(filter(None, commands)) + matching_placeholders = [tag for tag in placeholders if tag in command] + if len(matching_placeholders) > 0: + for tag in matching_placeholders: + commands[command_num] = command.replace(tag, placeholders[tag]) + return commands + + +def return_colored_value(value: str): + """Return the string value based on condition. + + Return it with green, yellow, red or white color based on + whether the number is positive, negative, zero or other, respectively. + + Parameters + ---------- + value: str + string to be checked + + Returns + ------- + value: str + string with color based on value of number if it exists + """ + values = re.findall(r"[-+]?(?:\d*\.\d+|\d+)", value) + + # Finds exactly 1 number in the string + if len(values) == 1: + if float(values[0]) > 0: + return f"[green]{value}[/green]" + + if float(values[0]) < 0: + return f"[red]{value}[/red]" + + if float(values[0]) == 0: + return f"[yellow]{value}[/yellow]" + + return f"{value}" + + +def _get_backend(): + """Get the Platform charting backend.""" + try: + return get_backend() + except ValueError: + # backend might not be created yet + charting_settings = ChartingSettings( + system_settings=obb.system, user_settings=obb.user # type: ignore + ) + create_backend(charting_settings) + get_backend().start(debug=charting_settings.debug_mode) + return get_backend() + + +# pylint: disable=too-many-arguments +def print_rich_table( # noqa: PLR0912 + df: pd.DataFrame, + show_index: bool = False, + title: str = "", + index_name: str = "", + headers: Optional[Union[List[str], pd.Index]] = None, + floatfmt: Union[str, List[str]] = ".2f", + show_header: bool = True, + automatic_coloring: bool = False, + columns_to_auto_color: Optional[List[str]] = None, + rows_to_auto_color: Optional[List[str]] = None, + export: bool = False, + limit: Optional[int] = 1000, + columns_keep_types: Optional[List[str]] = None, + use_tabulate_df: bool = True, +): + """Prepare a table from df in rich. + + Parameters + ---------- + df: pd.DataFrame + Dataframe to turn into table + show_index: bool + Whether to include index + title: str + Title for table + index_name : str + Title for index column + headers: List[str] + Titles for columns + floatfmt: Union[str, List[str]] + Float number formatting specs as string or list of strings. Defaults to ".2f" + show_header: bool + Whether to show the header row. + automatic_coloring: bool + Automatically color a table based on positive and negative values + columns_to_auto_color: List[str] + Columns to automatically color + rows_to_auto_color: List[str] + Rows to automatically color + export: bool + Whether we are exporting the table to a file. If so, we don't want to print it. + limit: Optional[int] + Limit the number of rows to show. + columns_keep_types: Optional[List[str]] + Columns to keep their types, i.e. not convert to numeric + """ + if export: + return + + # Make a copy of the dataframe to avoid SettingWithCopyWarning + df = df.copy() + + show_index = not isinstance(df.index, pd.RangeIndex) and show_index + # convert non-str that are not timestamp or int into str + # eg) praw.models.reddit.subreddit.Subreddit + for col in df.columns: + if columns_keep_types is not None and col in columns_keep_types: + continue + try: + if not any( + isinstance(df[col].iloc[x], pd.Timestamp) + for x in range(min(10, len(df))) + ): + df[col] = pd.to_numeric(df[col], errors="ignore") + except (ValueError, TypeError): + df[col] = df[col].astype(str) + + def _get_headers(_headers: Union[List[str], pd.Index]) -> List[str]: + """Check if headers are valid and return them.""" + output = _headers + if isinstance(_headers, pd.Index): + output = list(_headers) + if len(output) != len(df.columns): + raise ValueError("Length of headers does not match length of DataFrame.") + return output + + if Session().settings.USE_INTERACTIVE_DF: + df_outgoing = df.copy() + # If headers are provided, use them + if headers is not None: + # We check if headers are valid + df_outgoing.columns = _get_headers(headers) + + if show_index and index_name not in df_outgoing.columns: + # If index name is provided, we use it + df_outgoing.index.name = index_name or "Index" + df_outgoing = df_outgoing.reset_index() + + for col in df_outgoing.columns: + if col == "": + df_outgoing = df_outgoing.rename(columns={col: " "}) + + # ensure everything on the dataframe is a string + df_outgoing = df_outgoing.applymap(str) + + _get_backend().send_table( + df_table=df_outgoing, + title=title, + theme=Session().user.preferences.table_style, + ) + return + + df = df.copy() if not limit else df.copy().iloc[:limit] + if automatic_coloring: + if columns_to_auto_color: + for col in columns_to_auto_color: + # checks whether column exists + if col in df.columns: + df[col] = df[col].apply(lambda x: return_colored_value(str(x))) + if rows_to_auto_color: + for row in rows_to_auto_color: + # checks whether row exists + if row in df.index: + df.loc[row] = df.loc[row].apply( + lambda x: return_colored_value(str(x)) + ) + + if columns_to_auto_color is None and rows_to_auto_color is None: + df = df.applymap(lambda x: return_colored_value(str(x))) + + exceeds_allowed_columns = ( + len(df.columns) > Session().settings.ALLOWED_NUMBER_OF_COLUMNS + ) + exceeds_allowed_rows = len(df) > Session().settings.ALLOWED_NUMBER_OF_ROWS + + if exceeds_allowed_columns: + original_columns = df.columns.tolist() + trimmed_columns = df.columns.tolist()[ + : Session().settings.ALLOWED_NUMBER_OF_COLUMNS + ] + df = df[trimmed_columns] + trimmed_columns = [ + col for col in original_columns if col not in trimmed_columns + ] + + if exceeds_allowed_rows: + n_rows = len(df.index) + trimmed_rows = df.index.tolist()[: Session().settings.ALLOWED_NUMBER_OF_ROWS] + df = df.loc[trimmed_rows] + trimmed_rows_count = n_rows - Session().settings.ALLOWED_NUMBER_OF_ROWS + + if use_tabulate_df: + table = Table(title=title, show_lines=True, show_header=show_header) + + if show_index: + table.add_column(index_name) + + if headers is not None: + headers = _get_headers(headers) + for header in headers: + table.add_column(str(header)) + else: + for column in df.columns: + table.add_column(str(column)) + + if isinstance(floatfmt, list) and len(floatfmt) != len(df.columns): + raise ( + ValueError( + "Length of floatfmt list does not match length of DataFrame columns." + ) + ) + if isinstance(floatfmt, str): + floatfmt = [floatfmt for _ in range(len(df.columns))] + + for idx, values in zip(df.index.tolist(), df.values.tolist()): + # remove hour/min/sec from timestamp index - Format: YYYY-MM-DD # make better + row_idx = [str(idx)] if show_index else [] + row_idx += [ + ( + str(x) + if not isinstance(x, float) and not isinstance(x, np.float64) + else ( + f"{x:{floatfmt[idx]}}" + if isinstance(floatfmt, list) + else ( + f"{x:.2e}" + if 0 < abs(float(x)) <= 0.0001 + else f"{x:floatfmt}" + ) + ) + ) + for idx, x in enumerate(values) + ] + table.add_row(*row_idx) + Session().console.print(table) + else: + Session().console.print(df.to_string(col_space=0)) + + if exceeds_allowed_columns: + Session().console.print( + f"[yellow]\nAllowed number of columns exceeded ({Session().settings.ALLOWED_NUMBER_OF_COLUMNS}).\n" + f"The following columns were removed from the output: {', '.join(trimmed_columns)}.\n[/yellow]" + ) + + if exceeds_allowed_rows: + Session().console.print( + f"[yellow]\nAllowed number of rows exceeded ({Session().settings.ALLOWED_NUMBER_OF_ROWS}).\n" + f"{trimmed_rows_count} rows were removed from the output.\n[/yellow]" + ) + + if exceeds_allowed_columns or exceeds_allowed_rows: + Session().console.print( + "Use the `--export` flag to analyse the full output on a file." + ) + + +def check_non_negative(value) -> int: + """Argparse type to check non negative int.""" + new_value = int(value) + if new_value < 0: + raise argparse.ArgumentTypeError(f"{value} is negative") + return new_value + + +def check_positive(value) -> int: + """Argparse type to check positive int.""" + new_value = int(value) + if new_value <= 0: + raise argparse.ArgumentTypeError(f"{value} is an invalid positive int value") + return new_value + + +def get_user_agent() -> str: + """Get a not very random user agent.""" + user_agent_strings = [ + "Mozilla/5.0 (Macintosh; U; Intel Mac OS X 10.10; rv:86.1) Gecko/20100101 Firefox/86.1", + "Mozilla/5.0 (Windows NT 6.1; WOW64; rv:86.1) Gecko/20100101 Firefox/86.1", + "Mozilla/5.0 (Macintosh; Intel Mac OS X 10.10; rv:82.1) Gecko/20100101 Firefox/82.1", + "Mozilla/5.0 (Macintosh; Intel Mac OS X 10.13; rv:86.0) Gecko/20100101 Firefox/86.0", + "Mozilla/5.0 (Windows NT 10.0; WOW64; rv:86.0) Gecko/20100101 Firefox/86.0", + "Mozilla/5.0 (Macintosh; U; Intel Mac OS X 10.10; rv:83.0) Gecko/20100101 Firefox/83.0", + "Mozilla/5.0 (Windows NT 6.1; WOW64; rv:84.0) Gecko/20100101 Firefox/84.0", + ] + + return random.choice(user_agent_strings) # nosec # noqa: S311 + + +def get_flair() -> str: + """Get a flair icon.""" + current_flair = str(Session().settings.FLAIR) + flair = AVAILABLE_FLAIRS.get(current_flair, current_flair) + return flair + + +def get_dtime() -> str: + """Get a datetime string.""" + dtime = "" + if Session().settings.USE_DATETIME and get_user_timezone_or_invalid() != "INVALID": + dtime = datetime.now(timezone(get_user_timezone())).strftime("%Y %b %d, %H:%M") + return dtime + + +def get_flair_and_username() -> str: + """Get a flair icon and username.""" + flair = get_flair() + dtime = get_dtime() + + if dtime: + dtime = f"{dtime} " + + username = getattr(Session().user.profile.hub_session, "username", "") + if username: + username = f"[{username}] " + + return f"{dtime}{username}{flair}" + + +def is_timezone_valid(user_tz: str) -> bool: + """Check whether user timezone is valid. + + Parameters + ---------- + user_tz: str + Timezone to check for validity + + Returns + ------- + bool + True if timezone provided is valid + """ + return user_tz in all_timezones + + +def get_user_timezone() -> str: + """Get user timezone if it is a valid one. + + Returns + ------- + str + user timezone based on .env file + """ + return Session().settings.TIMEZONE + + +def get_user_timezone_or_invalid() -> str: + """Get user timezone if it is a valid one. + + Returns + ------- + str + user timezone based on timezone.openbb file or INVALID + """ + user_tz = get_user_timezone() + if is_timezone_valid(user_tz): + return f"{user_tz}" + return "INVALID" + + +def check_file_type_saved(valid_types: Optional[List[str]] = None): + """Provide valid types for the user to be able to select. + + Parameters + ---------- + valid_types: List[str] + List of valid types to export data + + Returns + ------- + check_filenames: Optional[List[str]] + Function that returns list of filenames to export data + """ + + def check_filenames(filenames: str = "") -> str: + """Check if filenames are valid. + + Parameters + ---------- + filenames: str + filenames to be saved separated with comma + + Returns + ---------- + str + valid filenames separated with comma + """ + if not filenames or not valid_types: + return "" + valid_filenames = list() + for filename in filenames.split(","): + if filename.endswith(tuple(valid_types)): + valid_filenames.append(filename) + else: + Session().console.print( + f"[red]Filename '{filename}' provided is not valid!\nPlease use one of the following file types:" + f"{','.join(valid_types)}[/red]\n" + ) + return ",".join(valid_filenames) + + return check_filenames + + +def remove_timezone_from_dataframe(df: pd.DataFrame) -> pd.DataFrame: + """Remove timezone information from a dataframe. + + Parameters + ---------- + df : pd.DataFrame + The dataframe to remove timezone information from + + Returns + ------- + pd.DataFrame + The dataframe with timezone information removed + """ + + date_cols = [] + index_is_date = False + + # Find columns and index containing date data + if ( + df.index.dtype.kind == "M" + and hasattr(df.index.dtype, "tz") + and df.index.dtype.tz is not None + ): + index_is_date = True + + for col, dtype in df.dtypes.items(): + if dtype.kind == "M" and hasattr(df.index.dtype, "tz") and dtype.tz is not None: + date_cols.append(col) + + # Remove the timezone information + for col in date_cols: + df[col] = df[col].dt.date + + if index_is_date: + index_name = df.index.name + df.index = df.index.date + df.index.name = index_name + + return df + + +def compose_export_path(func_name: str, dir_path: str) -> Path: + """Compose export path for data from the terminal. + + Creates a path to a folder and a filename based on conditions. + + Parameters + ---------- + func_name : str + Name of the command that invokes this function + dir_path : str + Path of directory from where this function is called + + Returns + ------- + Path + Path variable containing the path of the exported file + """ + now = datetime.now() + # Resolving all symlinks and also normalizing path. + resolve_path = Path(dir_path).resolve() + # Getting the directory names from the path. Instead of using split/replace (Windows doesn't like that) + # check if this is done in a main context to avoid saving with openbb_cli + if resolve_path.parts[-2] == "openbb_cli": + path_cmd = f"{resolve_path.parts[-1]}" + else: + path_cmd = f"{resolve_path.parts[-2]}_{resolve_path.parts[-1]}" + + default_filename = f"{now.strftime('%Y%m%d_%H%M%S')}_{path_cmd}_{func_name}" + + full_path = Path(Session().user.preferences.export_directory) / default_filename + + return full_path + + +def ask_file_overwrite(file_path: Path) -> Tuple[bool, bool]: + """Provide a prompt for overwriting existing files. + + Returns two values, the first is a boolean indicating if the file exists and the + second is a boolean indicating if the user wants to overwrite the file. + """ + if Session().settings.FILE_OVERWRITE: + return False, True + if Session().settings.TEST_MODE: |