summaryrefslogtreecommitdiffstats
path: root/cli/openbb_cli/controllers/utils.py
diff options
context:
space:
mode:
Diffstat (limited to 'cli/openbb_cli/controllers/utils.py')
-rw-r--r--cli/openbb_cli/controllers/utils.py1056
1 files changed, 1056 insertions, 0 deletions
diff --git a/cli/openbb_cli/controllers/utils.py b/cli/openbb_cli/controllers/utils.py
new file mode 100644
index 00000000000..a85751a58fa
--- /dev/null
+++ b/cli/openbb_cli/controllers/utils.py
@@ -0,0 +1,1056 @@
+"""Utils."""
+
+import argparse
+import os
+import random
+import re
+import shutil
+import sys
+from contextlib import contextmanager
+from datetime import (
+ datetime,
+)
+from pathlib import Path
+from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
+
+import numpy as np
+import pandas as pd
+import requests
+from openbb import obb
+from openbb_charting.core.backend import create_backend, get_backend
+from openbb_cli.config.constants import AVAILABLE_FLAIRS, ENV_FILE_SETTINGS
+from openbb_cli.session import Session
+from openbb_core.app.model.charts.charting_settings import ChartingSettings
+from packaging import version
+from pytz import all_timezones, timezone
+from rich.table import Table
+
+if TYPE_CHECKING:
+ from openbb_charting.core.openbb_figure import OpenBBFigure
+
+# pylint: disable=R1702,R0912
+
+
+# pylint: disable=too-many-statements,no-member,too-many-branches,C0302
+
+
+def remove_file(path: Path) -> bool:
+ """Remove path.
+
+ Parameters
+ ----------
+ path : Path
+ The file path.
+
+ Returns
+ -------
+ bool
+ The status of the removal.
+ """
+ # TODO: Check why module level import leads to circular import.
+ try:
+ if os.path.isfile(path):
+ os.remove(path)
+ elif os.path.isdir(path):
+ shutil.rmtree(path)
+ return True
+ except Exception:
+ Session().console.print(
+ f"\n[bold red]Failed to remove {path}"
+ "\nPlease delete this manually![/bold red]"
+ )
+ return False
+
+
+def print_goodbye():
+ """Print a goodbye message when quitting the terminal."""
+ # LEGACY GOODBYE MESSAGES - You'll live in our hearts forever.
+ # "An informed ape, is a strong ape."
+ # "Remember that stonks only go up."
+ # "Diamond hands."
+ # "Apes together strong."
+ # "This is our way."
+ # "Keep the spacesuit ape, we haven't reached the moon yet."
+ # "I am not a cat. I'm an ape."
+ # "We like the terminal."
+ # "...when offered a flight to the moon, nobody asks about what seat."
+
+ text = """
+[param]Thank you for using the OpenBB Platform CLI and being part of this journey.[/param]
+
+We hope you'll find the new CLI as valuable as this. To stay tuned, sign up for our newsletter: [cmds]https://openbb.co/newsletter.[/]
+
+In the meantime, check out our other products:
+
+[bold]OpenBB Terminal Pro[/]: [cmds]https://openbb.co/products/pro[/cmds]
+[bold]OpenBB Platform:[/] [cmds]https://openbb.co/products/platform[/cmds]
+[bold]OpenBB Bot[/]: [cmds]https://openbb.co/products/bot[/cmds]
+ """
+ Session().console.print(text)
+
+
+def hide_splashscreen():
+ """Hide the splashscreen on Windows bundles.
+
+ `pyi_splash` is a PyInstaller "fake-package" that's used to communicate
+ with the splashscreen on Windows.
+ Sending the `close` signal to the splash screen is required.
+ The splash screen remains open until this function is called or the Python
+ program is terminated.
+ """
+ try:
+ import pyi_splash # type: ignore # pylint: disable=import-outside-toplevel
+
+ pyi_splash.update_text("CLI Loaded!")
+ pyi_splash.close()
+ except Exception as e:
+ Session().console.print(f"Error: Unable to hide splashscreen: {e}")
+
+
+def print_guest_block_msg():
+ """Block guest users from using the cli."""
+ if Session().is_local():
+ Session().console.print(
+ "[info]You are currently logged as a guest.[/info]\n"
+ "[info]Login to use this feature.[/info]\n\n"
+ "[info]If you don't have an account, you can create one here: [/info]"
+ f"[cmds]{Session().settings.HUB_URL + '/register'}\n[/cmds]"
+ )
+
+
+def is_installer() -> bool:
+ """Check whether or not it is a packaged version (Windows or Mac installer."""
+ return getattr(sys, "frozen", False) and hasattr(sys, "_MEIPASS")
+
+
+def bootup():
+ """Bootup the cli."""
+ if sys.platform == "win32":
+ # Enable VT100 Escape Sequence for WINDOWS 10 Ver. 1607
+ os.system("") # nosec # noqa: S605,S607
+ # Hide splashscreen loader of the packaged app
+ if is_installer():
+ hide_splashscreen()
+
+ try:
+ if os.name == "nt":
+ # pylint: disable=E1101
+ sys.stdin.reconfigure(encoding="utf-8") # type: ignore
+ # pylint: disable=E1101
+ sys.stdout.reconfigure(encoding="utf-8") # type: ignore
+ except Exception as e:
+ Session().console.print(e, "\n")
+
+
+def check_for_updates() -> None:
+ """Check if the latest version is running.
+
+ Checks github for the latest release version and compares it to cfg.VERSION.
+ """
+ # The commit has was commented out because the terminal was crashing due to git import for multiple users
+ # ({str(git.Repo('.').head.commit)[:7]})
+ try:
+ r = request(
+ "https://api.github.com/repos/openbb-finance/openbbterminal/releases/latest"
+ )
+ except Exception:
+ r = None
+
+ if r and r.status_code == 200:
+ latest_tag_name = r.json()["tag_name"]
+ latest_version = version.parse(latest_tag_name)
+ current_version = version.parse(Session().settings.VERSION)
+
+ if check_valid_versions(latest_version, current_version):
+ if current_version == latest_version:
+ Session().console.print(
+ "[green]You are using the latest stable version[/green]"
+ )
+ else:
+ Session().console.print(
+ "[yellow]You are not using the latest stable version[/yellow]"
+ )
+ if current_version < latest_version:
+ Session().console.print(
+ "[yellow]Check for updates at https://my.openbb.co/app/terminal/download[/yellow]"
+ )
+
+ else:
+ Session().console.print(
+ "[yellow]You are using an unreleased version[/yellow]"
+ )
+
+ else:
+ Session().console.print("[red]You are using an unrecognized version.[/red]")
+ else:
+ Session().console.print(
+ "[yellow]Unable to check for updates... "
+ + "Check your internet connection and try again...[/yellow]"
+ )
+ Session().console.print("\n")
+
+
+def check_valid_versions(
+ latest_version: version.Version,
+ current_version: version.Version,
+) -> bool:
+ """Check if the versions are valid."""
+ if (
+ not latest_version
+ or not current_version
+ or not isinstance(latest_version, version.Version)
+ or not isinstance(current_version, version.Version)
+ ):
+ return False
+ return True
+
+
+def welcome_message():
+ """Print the welcome message.
+
+ Prints first welcome message, help and a notification if updates are available.
+ """
+ Session().console.print(
+ f"\nWelcome to OpenBB Platform CLI v{Session().settings.VERSION}"
+ )
+
+
+def reset(queue: Optional[List[str]] = None):
+ """Reset the CLI.
+
+ Allows for checking code without quitting.
+ """
+ Session().console.print("resetting...")
+ Session().reset()
+ debug = Session().settings.DEBUG_MODE
+ dev = Session().settings.DEV_BACKEND
+
+ try:
+ # remove the hub routines
+ if not Session().is_local():
+ remove_file(
+ Path(Session().user.preferences.export_directory, "routines", "hub")
+ )
+
+ # if not get_current_user().profile.remember:
+ # Local.remove(HIST_FILE_PROMPT)
+
+ # we clear all openbb_cli modules from sys.modules
+ for module in list(sys.modules.keys()):
+ parts = module.split(".")
+ if parts[0] == "openbb_cli":
+ del sys.modules[module]
+
+ queue_list = ["/".join(queue) if len(queue) > 0 else ""] # type: ignore
+ # pylint: disable=import-outside-toplevel
+ # we run the cli again
+ if Session().is_local():
+ from openbb_cli.controllers.cli_controller import main
+
+ main(debug, dev, queue_list, module="") # type: ignore
+ else:
+ from openbb_cli.controllers.cli_controller import launch
+
+ launch(queue=queue_list)
+
+ except Exception as e:
+ Session().console.print(f"Unfortunately, resetting wasn't possible: {e}\n")
+ print_goodbye()
+
+
+@contextmanager
+def suppress_stdout():
+ """Suppress the stdout."""
+ with open(os.devnull, "w") as devnull:
+ old_stdout = sys.stdout
+ old_stderr = sys.stderr
+ sys.stdout = devnull
+ sys.stderr = devnull
+ try:
+ yield
+ finally:
+ sys.stdout = old_stdout
+ sys.stderr = old_stderr
+
+
+def first_time_user() -> bool:
+ """Check whether a user is a first time user.
+
+ A first time user is someone with an empty .env file.
+ If this is true, it also adds an env variable to make sure this does not run again.
+
+ Returns
+ -------
+ bool
+ Whether or not the user is a first time user
+ """
+ if ENV_FILE_SETTINGS.stat().st_size == 0:
+ Session().settings.set_item("PREVIOUS_USE", True)
+ return True
+ return False
+
+
+def parse_and_split_input(an_input: str, custom_filters: List) -> List[str]:
+ """Filter and split the input queue.
+
+ Uses regex to filters command arguments that have forward slashes so that it doesn't
+ break the execution of the command queue.
+ Currently handles unix paths and sorting settings for screener menus.
+
+ Parameters
+ ----------
+ an_input : str
+ User input as string
+ custom_filters : List
+ Additional regular expressions to match
+
+ Returns
+ -------
+ List[str]
+ Command queue as list
+ """
+ # Make sure that the user can go back to the root when doing "/"
+ if an_input and an_input == "/":
+ an_input = "home"
+
+ # everything from ` -f ` to the next known extension
+ file_flag = r"(\ -f |\ --file )"
+ up_to = r".*?"
+ known_extensions = r"(\.(xlsx|csv|xls|tsv|json|yaml|ini|openbb|ipynb))"
+ unix_path_arg_exp = f"({file_flag}{up_to}{known_extensions})"
+
+ # Add custom expressions to handle edge cases of individual controllers
+ custom_filter = ""
+ for exp in custom_filters:
+ if exp is not None:
+ custom_filter += f"|{exp}"
+ del exp
+
+ slash_filter_exp = f"({unix_path_arg_exp}){custom_filter}"
+
+ filter_input = True
+ placeholders: Dict[str, str] = {}
+ while filter_input:
+ match = re.search(pattern=slash_filter_exp, string=an_input)
+ if match is not None:
+ placeholder = f"{{placeholder{len(placeholders)+1}}}"
+ placeholders[placeholder] = an_input[
+ match.span()[0] : match.span()[1] # noqa:E203
+ ]
+ an_input = (
+ an_input[: match.span()[0]]
+ + placeholder
+ + an_input[match.span()[1] :] # noqa:E203
+ )
+ else:
+ filter_input = False
+
+ commands = an_input.split("/") if "-t" not in an_input else [an_input]
+
+ for command_num, command in enumerate(commands):
+ if command == commands[command_num] == commands[-1] == "":
+ return list(filter(None, commands))
+ matching_placeholders = [tag for tag in placeholders if tag in command]
+ if len(matching_placeholders) > 0:
+ for tag in matching_placeholders:
+ commands[command_num] = command.replace(tag, placeholders[tag])
+ return commands
+
+
+def return_colored_value(value: str):
+ """Return the string value based on condition.
+
+ Return it with green, yellow, red or white color based on
+ whether the number is positive, negative, zero or other, respectively.
+
+ Parameters
+ ----------
+ value: str
+ string to be checked
+
+ Returns
+ -------
+ value: str
+ string with color based on value of number if it exists
+ """
+ values = re.findall(r"[-+]?(?:\d*\.\d+|\d+)", value)
+
+ # Finds exactly 1 number in the string
+ if len(values) == 1:
+ if float(values[0]) > 0:
+ return f"[green]{value}[/green]"
+
+ if float(values[0]) < 0:
+ return f"[red]{value}[/red]"
+
+ if float(values[0]) == 0:
+ return f"[yellow]{value}[/yellow]"
+
+ return f"{value}"
+
+
+def _get_backend():
+ """Get the Platform charting backend."""
+ try:
+ return get_backend()
+ except ValueError:
+ # backend might not be created yet
+ charting_settings = ChartingSettings(
+ system_settings=obb.system, user_settings=obb.user # type: ignore
+ )
+ create_backend(charting_settings)
+ get_backend().start(debug=charting_settings.debug_mode)
+ return get_backend()
+
+
+# pylint: disable=too-many-arguments
+def print_rich_table( # noqa: PLR0912
+ df: pd.DataFrame,
+ show_index: bool = False,
+ title: str = "",
+ index_name: str = "",
+ headers: Optional[Union[List[str], pd.Index]] = None,
+ floatfmt: Union[str, List[str]] = ".2f",
+ show_header: bool = True,
+ automatic_coloring: bool = False,
+ columns_to_auto_color: Optional[List[str]] = None,
+ rows_to_auto_color: Optional[List[str]] = None,
+ export: bool = False,
+ limit: Optional[int] = 1000,
+ columns_keep_types: Optional[List[str]] = None,
+ use_tabulate_df: bool = True,
+):
+ """Prepare a table from df in rich.
+
+ Parameters
+ ----------
+ df: pd.DataFrame
+ Dataframe to turn into table
+ show_index: bool
+ Whether to include index
+ title: str
+ Title for table
+ index_name : str
+ Title for index column
+ headers: List[str]
+ Titles for columns
+ floatfmt: Union[str, List[str]]
+ Float number formatting specs as string or list of strings. Defaults to ".2f"
+ show_header: bool
+ Whether to show the header row.
+ automatic_coloring: bool
+ Automatically color a table based on positive and negative values
+ columns_to_auto_color: List[str]
+ Columns to automatically color
+ rows_to_auto_color: List[str]
+ Rows to automatically color
+ export: bool
+ Whether we are exporting the table to a file. If so, we don't want to print it.
+ limit: Optional[int]
+ Limit the number of rows to show.
+ columns_keep_types: Optional[List[str]]
+ Columns to keep their types, i.e. not convert to numeric
+ """
+ if export:
+ return
+
+ # Make a copy of the dataframe to avoid SettingWithCopyWarning
+ df = df.copy()
+
+ show_index = not isinstance(df.index, pd.RangeIndex) and show_index
+ # convert non-str that are not timestamp or int into str
+ # eg) praw.models.reddit.subreddit.Subreddit
+ for col in df.columns:
+ if columns_keep_types is not None and col in columns_keep_types:
+ continue
+ try:
+ if not any(
+ isinstance(df[col].iloc[x], pd.Timestamp)
+ for x in range(min(10, len(df)))
+ ):
+ df[col] = pd.to_numeric(df[col], errors="ignore")
+ except (ValueError, TypeError):
+ df[col] = df[col].astype(str)
+
+ def _get_headers(_headers: Union[List[str], pd.Index]) -> List[str]:
+ """Check if headers are valid and return them."""
+ output = _headers
+ if isinstance(_headers, pd.Index):
+ output = list(_headers)
+ if len(output) != len(df.columns):
+ raise ValueError("Length of headers does not match length of DataFrame.")
+ return output
+
+ if Session().settings.USE_INTERACTIVE_DF:
+ df_outgoing = df.copy()
+ # If headers are provided, use them
+ if headers is not None:
+ # We check if headers are valid
+ df_outgoing.columns = _get_headers(headers)
+
+ if show_index and index_name not in df_outgoing.columns:
+ # If index name is provided, we use it
+ df_outgoing.index.name = index_name or "Index"
+ df_outgoing = df_outgoing.reset_index()
+
+ for col in df_outgoing.columns:
+ if col == "":
+ df_outgoing = df_outgoing.rename(columns={col: " "})
+
+ # ensure everything on the dataframe is a string
+ df_outgoing = df_outgoing.applymap(str)
+
+ _get_backend().send_table(
+ df_table=df_outgoing,
+ title=title,
+ theme=Session().user.preferences.table_style,
+ )
+ return
+
+ df = df.copy() if not limit else df.copy().iloc[:limit]
+ if automatic_coloring:
+ if columns_to_auto_color:
+ for col in columns_to_auto_color:
+ # checks whether column exists
+ if col in df.columns:
+ df[col] = df[col].apply(lambda x: return_colored_value(str(x)))
+ if rows_to_auto_color:
+ for row in rows_to_auto_color:
+ # checks whether row exists
+ if row in df.index:
+ df.loc[row] = df.loc[row].apply(
+ lambda x: return_colored_value(str(x))
+ )
+
+ if columns_to_auto_color is None and rows_to_auto_color is None:
+ df = df.applymap(lambda x: return_colored_value(str(x)))
+
+ exceeds_allowed_columns = (
+ len(df.columns) > Session().settings.ALLOWED_NUMBER_OF_COLUMNS
+ )
+ exceeds_allowed_rows = len(df) > Session().settings.ALLOWED_NUMBER_OF_ROWS
+
+ if exceeds_allowed_columns:
+ original_columns = df.columns.tolist()
+ trimmed_columns = df.columns.tolist()[
+ : Session().settings.ALLOWED_NUMBER_OF_COLUMNS
+ ]
+ df = df[trimmed_columns]
+ trimmed_columns = [
+ col for col in original_columns if col not in trimmed_columns
+ ]
+
+ if exceeds_allowed_rows:
+ n_rows = len(df.index)
+ trimmed_rows = df.index.tolist()[: Session().settings.ALLOWED_NUMBER_OF_ROWS]
+ df = df.loc[trimmed_rows]
+ trimmed_rows_count = n_rows - Session().settings.ALLOWED_NUMBER_OF_ROWS
+
+ if use_tabulate_df:
+ table = Table(title=title, show_lines=True, show_header=show_header)
+
+ if show_index:
+ table.add_column(index_name)
+
+ if headers is not None:
+ headers = _get_headers(headers)
+ for header in headers:
+ table.add_column(str(header))
+ else:
+ for column in df.columns:
+ table.add_column(str(column))
+
+ if isinstance(floatfmt, list) and len(floatfmt) != len(df.columns):
+ raise (
+ ValueError(
+ "Length of floatfmt list does not match length of DataFrame columns."
+ )
+ )
+ if isinstance(floatfmt, str):
+ floatfmt = [floatfmt for _ in range(len(df.columns))]
+
+ for idx, values in zip(df.index.tolist(), df.values.tolist()):
+ # remove hour/min/sec from timestamp index - Format: YYYY-MM-DD # make better
+ row_idx = [str(idx)] if show_index else []
+ row_idx += [
+ (
+ str(x)
+ if not isinstance(x, float) and not isinstance(x, np.float64)
+ else (
+ f"{x:{floatfmt[idx]}}"
+ if isinstance(floatfmt, list)
+ else (
+ f"{x:.2e}"
+ if 0 < abs(float(x)) <= 0.0001
+ else f"{x:floatfmt}"
+ )
+ )
+ )
+ for idx, x in enumerate(values)
+ ]
+ table.add_row(*row_idx)
+ Session().console.print(table)
+ else:
+ Session().console.print(df.to_string(col_space=0))
+
+ if exceeds_allowed_columns:
+ Session().console.print(
+ f"[yellow]\nAllowed number of columns exceeded ({Session().settings.ALLOWED_NUMBER_OF_COLUMNS}).\n"
+ f"The following columns were removed from the output: {', '.join(trimmed_columns)}.\n[/yellow]"
+ )
+
+ if exceeds_allowed_rows:
+ Session().console.print(
+ f"[yellow]\nAllowed number of rows exceeded ({Session().settings.ALLOWED_NUMBER_OF_ROWS}).\n"
+ f"{trimmed_rows_count} rows were removed from the output.\n[/yellow]"
+ )
+
+ if exceeds_allowed_columns or exceeds_allowed_rows:
+ Session().console.print(
+ "Use the `--export` flag to analyse the full output on a file."
+ )
+
+
+def check_non_negative(value) -> int:
+ """Argparse type to check non negative int."""
+ new_value = int(value)
+ if new_value < 0:
+ raise argparse.ArgumentTypeError(f"{value} is negative")
+ return new_value
+
+
+def check_positive(value) -> int:
+ """Argparse type to check positive int."""
+ new_value = int(value)
+ if new_value <= 0:
+ raise argparse.ArgumentTypeError(f"{value} is an invalid positive int value")
+ return new_value
+
+
+def get_user_agent() -> str:
+ """Get a not very random user agent."""
+ user_agent_strings = [
+ "Mozilla/5.0 (Macintosh; U; Intel Mac OS X 10.10; rv:86.1) Gecko/20100101 Firefox/86.1",
+ "Mozilla/5.0 (Windows NT 6.1; WOW64; rv:86.1) Gecko/20100101 Firefox/86.1",
+ "Mozilla/5.0 (Macintosh; Intel Mac OS X 10.10; rv:82.1) Gecko/20100101 Firefox/82.1",
+ "Mozilla/5.0 (Macintosh; Intel Mac OS X 10.13; rv:86.0) Gecko/20100101 Firefox/86.0",
+ "Mozilla/5.0 (Windows NT 10.0; WOW64; rv:86.0) Gecko/20100101 Firefox/86.0",
+ "Mozilla/5.0 (Macintosh; U; Intel Mac OS X 10.10; rv:83.0) Gecko/20100101 Firefox/83.0",
+ "Mozilla/5.0 (Windows NT 6.1; WOW64; rv:84.0) Gecko/20100101 Firefox/84.0",
+ ]
+
+ return random.choice(user_agent_strings) # nosec # noqa: S311
+
+
+def get_flair() -> str:
+ """Get a flair icon."""
+ current_flair = str(Session().settings.FLAIR)
+ flair = AVAILABLE_FLAIRS.get(current_flair, current_flair)
+ return flair
+
+
+def get_dtime() -> str:
+ """Get a datetime string."""
+ dtime = ""
+ if Session().settings.USE_DATETIME and get_user_timezone_or_invalid() != "INVALID":
+ dtime = datetime.now(timezone(get_user_timezone())).strftime("%Y %b %d, %H:%M")
+ return dtime
+
+
+def get_flair_and_username() -> str:
+ """Get a flair icon and username."""
+ flair = get_flair()
+ dtime = get_dtime()
+
+ if dtime:
+ dtime = f"{dtime} "
+
+ username = getattr(Session().user.profile.hub_session, "username", "")
+ if username:
+ username = f"[{username}] "
+
+ return f"{dtime}{username}{flair}"
+
+
+def is_timezone_valid(user_tz: str) -> bool:
+ """Check whether user timezone is valid.
+
+ Parameters
+ ----------
+ user_tz: str
+ Timezone to check for validity
+
+ Returns
+ -------
+ bool
+ True if timezone provided is valid
+ """
+ return user_tz in all_timezones
+
+
+def get_user_timezone() -> str:
+ """Get user timezone if it is a valid one.
+
+ Returns
+ -------
+ str
+ user timezone based on .env file
+ """
+ return Session().settings.TIMEZONE
+
+
+def get_user_timezone_or_invalid() -> str:
+ """Get user timezone if it is a valid one.
+
+ Returns
+ -------
+ str
+ user timezone based on timezone.openbb file or INVALID
+ """
+ user_tz = get_user_timezone()
+ if is_timezone_valid(user_tz):
+ return f"{user_tz}"
+ return "INVALID"
+
+
+def check_file_type_saved(valid_types: Optional[List[str]] = None):
+ """Provide valid types for the user to be able to select.
+
+ Parameters
+ ----------
+ valid_types: List[str]
+ List of valid types to export data
+
+ Returns
+ -------
+ check_filenames: Optional[List[str]]
+ Function that returns list of filenames to export data
+ """
+
+ def check_filenames(filenames: str = "") -> str:
+ """Check if filenames are valid.
+
+ Parameters
+ ----------
+ filenames: str
+ filenames to be saved separated with comma
+
+ Returns
+ ----------
+ str
+ valid filenames separated with comma
+ """
+ if not filenames or not valid_types:
+ return ""
+ valid_filenames = list()
+ for filename in filenames.split(","):
+ if filename.endswith(tuple(valid_types)):
+ valid_filenames.append(filename)
+ else:
+ Session().console.print(
+ f"[red]Filename '{filename}' provided is not valid!\nPlease use one of the following file types:"
+ f"{','.join(valid_types)}[/red]\n"
+ )
+ return ",".join(valid_filenames)
+
+ return check_filenames
+
+
+def remove_timezone_from_dataframe(df: pd.DataFrame) -> pd.DataFrame:
+ """Remove timezone information from a dataframe.
+
+ Parameters
+ ----------
+ df : pd.DataFrame
+ The dataframe to remove timezone information from
+
+ Returns
+ -------
+ pd.DataFrame
+ The dataframe with timezone information removed
+ """
+
+ date_cols = []
+ index_is_date = False
+
+ # Find columns and index containing date data
+ if (
+ df.index.dtype.kind == "M"
+ and hasattr(df.index.dtype, "tz")
+ and df.index.dtype.tz is not None
+ ):
+ index_is_date = True
+
+ for col, dtype in df.dtypes.items():
+ if dtype.kind == "M" and hasattr(df.index.dtype, "tz") and dtype.tz is not None:
+ date_cols.append(col)
+
+ # Remove the timezone information
+ for col in date_cols:
+ df[col] = df[col].dt.date
+
+ if index_is_date:
+ index_name = df.index.name
+ df.index = df.index.date
+ df.index.name = index_name
+
+ return df
+
+
+def compose_export_path(func_name: str, dir_path: str) -> Path:
+ """Compose export path for data from the terminal.
+
+ Creates a path to a folder and a filename based on conditions.
+
+ Parameters
+ ----------
+ func_name : str
+ Name of the command that invokes this function
+ dir_path : str
+ Path of directory from where this function is called
+
+ Returns
+ -------
+ Path
+ Path variable containing the path of the exported file
+ """
+ now = datetime.now()
+ # Resolving all symlinks and also normalizing path.
+ resolve_path = Path(dir_path).resolve()
+ # Getting the directory names from the path. Instead of using split/replace (Windows doesn't like that)
+ # check if this is done in a main context to avoid saving with openbb_cli
+ if resolve_path.parts[-2] == "openbb_cli":
+ path_cmd = f"{resolve_path.parts[-1]}"
+ else:
+ path_cmd = f"{resolve_path.parts[-2]}_{resolve_path.parts[-1]}"
+
+ default_filename = f"{now.strftime('%Y%m%d_%H%M%S')}_{path_cmd}_{func_name}"
+
+ full_path = Path(Session().user.preferences.export_directory) / default_filename
+
+ return full_path
+
+
+def ask_file_overwrite(file_path: Path) -> Tuple[bool, bool]:
+ """Provide a prompt for overwriting existing files.
+
+ Returns two values, the first is a boolean indicating if the file exists and the
+ second is a boolean indicating if the user wants to overwrite the file.
+ """
+ if Session().settings.FILE_OVERWRITE:
+ return False, True
+ if Session().settings.TEST_MODE:
+ return False, True
+ if file_path.exists():
+ overwrite = input("\nFile already exists. Overwrite? [y/n]: ").lower()
+ if overwrite == "y":
+ file_path.unlink(missing_ok=True)
+ # File exists and user wants to overwrite
+ return True, True
+ # File exists and user does not want to overwrite
+ return True, False
+ # File does not exist
+ return False, True
+
+
+# This is a false positive on pylint and being tracked in pylint #3060
+# pylint: disable=abstract-class-instantiated
+def save_to_excel(df, saved_path, sheet_name, start_row=0, index=True, header=True):
+ """Save a Pandas DataFrame to an Excel file.
+
+ Args:
+ df: A Pandas DataFrame.
+ saved_path: The path to the Excel file to save to.
+ sheet_name: The name of the sheet to save the DataFrame to.
+ start_row: The row number to start writing the DataFrame at.
+ index: Whether to write the DataFrame index to the Excel file.
+ header: Whether to write the DataFrame header to the Excel file.
+ """
+ overwrite_options = {
+ "o": "replace",
+ "a": "overlay",
+ "n": "new",
+ }
+
+ if not saved_path.exists():
+ with pd.ExcelWriter(saved_path, engine="openpyxl") as writer:
+ df.to_excel(writer, sheet_name=sheet_name, index=index, header=header)
+
+ else:
+ with pd.ExcelFile(saved_path) as reader:
+ overwrite_option = "n"
+ if sheet_name in reader.sheet_names:
+ overwrite_option = input(
+ "\nSheet already exists. Overwrite/Append/New? [o/a/n]: "
+ ).lower()
+ start_row = 0
+ if overwrite_option == "a":
+ existing_df = pd.read_excel(saved_path, sheet_name=sheet_name)
+ start_row = existing_df.shape[0] + 1
+
+ with pd.ExcelWriter(
+ saved_path,
+ mode="a",
+ if_sheet_exists=overwrite_options[overwrite_option],
+ engine="openpyxl",
+ ) as writer:
+ df.to_excel(
+ writer,
+ sheet_name=sheet_name,
+ startrow=start_row,
+ index=index,
+ header=False if overwrite_option == "a" else header,
+ )
+
+
+# This is a false positive on pylint and being tracked in pylint #3060
+# pylint: disable=abstract-class-instantiated
+def export_data(
+ export_type: str,
+ dir_path: str,
+ func_name: str,
+ df: pd.DataFrame = pd.DataFrame(),
+ sheet_name: Optional[str] = None,
+ figure: Optional["OpenBBFigure"] = None,
+ margin: bool = True,
+) -> None:
+ """Export data to a file.
+
+ Parameters
+ ----------
+ export_type : str
+ Type of export between: csv,json,xlsx,xls
+ dir_path : str
+ Path of directory from where this function is called
+ func_name : str
+ Name of the command that invokes this function
+ df : pd.Dataframe
+ Dataframe of data to save
+ sheet_name : str
+ If provided. The name of the sheet to save in excel file
+ figure : Optional[OpenBBFigure]
+ Figure object to save as image file
+ margin : bool
+ Automatically adjust subplot parameters to give specified padding.
+ """
+
+ if export_type:
+ saved_path = compose_export_path(func_name, dir_path).resolve()
+ saved_path.parent.mkdir(parents=True, exist_ok=True)
+ for exp_type in export_type.split(","):
+ # In this scenario the path was provided, e.g. -