summaryrefslogtreecommitdiffstats
path: root/openbb_terminal/openbb_terminal/controllers/utils.py
diff options
context:
space:
mode:
Diffstat (limited to 'openbb_terminal/openbb_terminal/controllers/utils.py')
-rw-r--r--openbb_terminal/openbb_terminal/controllers/utils.py1059
1 files changed, 0 insertions, 1059 deletions
diff --git a/openbb_terminal/openbb_terminal/controllers/utils.py b/openbb_terminal/openbb_terminal/controllers/utils.py
deleted file mode 100644
index 1ef3fb4e4e4..00000000000
--- a/openbb_terminal/openbb_terminal/controllers/utils.py
+++ /dev/null
@@ -1,1059 +0,0 @@
-"""Terminal utils."""
-
-import argparse
-import os
-import random
-import re
-import shutil
-import sys
-from contextlib import contextmanager
-from datetime import (
- datetime,
-)
-from pathlib import Path
-from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
-
-import numpy as np
-import pandas as pd
-import requests
-from openbb import obb
-from openbb_charting.core.backend import create_backend, get_backend
-from openbb_core.app.model.charts.charting_settings import ChartingSettings
-from packaging import version
-from pytz import all_timezones, timezone
-from rich.table import Table
-
-from openbb_terminal.config.constants import AVAILABLE_FLAIRS, ENV_FILE_SETTINGS
-from openbb_terminal.session import Session
-
-if TYPE_CHECKING:
- from openbb_charting.core.openbb_figure import OpenBBFigure
-
-# pylint: disable=R1702,R0912
-
-
-# pylint: disable=too-many-statements,no-member,too-many-branches,C0302
-
-
-def remove_file(path: Path) -> bool:
- """Remove path.
-
- Parameters
- ----------
- path : Path
- The file path.
-
- Returns
- -------
- bool
- The status of the removal.
- """
- # TODO: Check why module level import leads to circular import.
- try:
- if os.path.isfile(path):
- os.remove(path)
- elif os.path.isdir(path):
- shutil.rmtree(path)
- return True
- except Exception:
- Session().console.print(
- f"\n[bold red]Failed to remove {path}"
- "\nPlease delete this manually![/bold red]"
- )
- return False
-
-
-def print_goodbye():
- """Print a goodbye message when quitting the terminal."""
- # LEGACY GOODBYE MESSAGES - You'll live in our hearts forever.
- # "An informed ape, is a strong ape."
- # "Remember that stonks only go up."
- # "Diamond hands."
- # "Apes together strong."
- # "This is our way."
- # "Keep the spacesuit ape, we haven't reached the moon yet."
- # "I am not a cat. I'm an ape."
- # "We like the terminal."
- # "...when offered a flight to the moon, nobody asks about what seat."
-
- text = """
-[param]Thank you for using the OpenBB Platform CLI and being part of this journey.[/param]
-
-We hope you'll find the new CLI as valuable as this. To stay tuned, sign up for our newsletter: [cmds]https://openbb.co/newsletter.[/]
-
-In the meantime, check out our other products:
-
-[bold]OpenBB Terminal Pro[/]: [cmds]https://openbb.co/products/pro[/cmds]
-[bold]OpenBB Platform:[/] [cmds]https://openbb.co/products/platform[/cmds]
-[bold]OpenBB Bot[/]: [cmds]https://openbb.co/products/bot[/cmds]
- """
- Session().console.print(text)
-
-
-def hide_splashscreen():
- """Hide the splashscreen on Windows bundles.
-
- `pyi_splash` is a PyInstaller "fake-package" that's used to communicate
- with the splashscreen on Windows.
- Sending the `close` signal to the splash screen is required.
- The splash screen remains open until this function is called or the Python
- program is terminated.
- """
- try:
- import pyi_splash # type: ignore # pylint: disable=import-outside-toplevel
-
- pyi_splash.update_text("Terminal Loaded!")
- pyi_splash.close()
- except Exception as e:
- Session().console.print(f"Error: Unable to hide splashscreen: {e}")
-
-
-def print_guest_block_msg():
- """Block guest users from using the terminal."""
- if Session().is_local():
- Session().console.print(
- "[info]You are currently logged as a guest.[/info]\n"
- "[info]Login to use this feature.[/info]\n\n"
- "[info]If you don't have an account, you can create one here: [/info]"
- f"[cmds]{Session().settings.HUB_URL + '/register'}\n[/cmds]"
- )
-
-
-def is_installer() -> bool:
- """Check whether or not it is a packaged version (Windows or Mac installer."""
- return getattr(sys, "frozen", False) and hasattr(sys, "_MEIPASS")
-
-
-def bootup():
- """Bootup the terminal."""
- if sys.platform == "win32":
- # Enable VT100 Escape Sequence for WINDOWS 10 Ver. 1607
- os.system("") # nosec # noqa: S605,S607
- # Hide splashscreen loader of the packaged app
- if is_installer():
- hide_splashscreen()
-
- try:
- if os.name == "nt":
- # pylint: disable=E1101
- sys.stdin.reconfigure(encoding="utf-8") # type: ignore
- # pylint: disable=E1101
- sys.stdout.reconfigure(encoding="utf-8") # type: ignore
- except Exception as e:
- Session().console.print(e, "\n")
-
-
-def check_for_updates() -> None:
- """Check if the latest version is running.
-
- Checks github for the latest release version and compares it to cfg.VERSION.
- """
- # The commit has was commented out because the terminal was crashing due to git import for multiple users
- # ({str(git.Repo('.').head.commit)[:7]})
- try:
- r = request(
- "https://api.github.com/repos/openbb-finance/openbbterminal/releases/latest"
- )
- except Exception:
- r = None
-
- if r and r.status_code == 200:
- latest_tag_name = r.json()["tag_name"]
- latest_version = version.parse(latest_tag_name)
- current_version = version.parse(Session().settings.VERSION)
-
- if check_valid_versions(latest_version, current_version):
- if current_version == latest_version:
- Session().console.print(
- "[green]You are using the latest stable version[/green]"
- )
- else:
- Session().console.print(
- "[yellow]You are not using the latest stable version[/yellow]"
- )
- if current_version < latest_version:
- Session().console.print(
- "[yellow]Check for updates at https://my.openbb.co/app/terminal/download[/yellow]"
- )
-
- else:
- Session().console.print(
- "[yellow]You are using an unreleased version[/yellow]"
- )
-
- else:
- Session().console.print("[red]You are using an unrecognized version.[/red]")
- else:
- Session().console.print(
- "[yellow]Unable to check for updates... "
- + "Check your internet connection and try again...[/yellow]"
- )
- Session().console.print("\n")
-
-
-def check_valid_versions(
- latest_version: version.Version,
- current_version: version.Version,
-) -> bool:
- """Check if the versions are valid."""
- if (
- not latest_version
- or not current_version
- or not isinstance(latest_version, version.Version)
- or not isinstance(current_version, version.Version)
- ):
- return False
- return True
-
-
-def welcome_message():
- """Print the welcome message.
-
- Prints first welcome message, help and a notification if updates are available.
- """
- Session().console.print(
- f"\nWelcome to OpenBB Platform CLI v{Session().settings.VERSION}"
- )
-
-
-def reset(queue: Optional[List[str]] = None):
- """Reset the CLI.
-
- Allows for checking code without quitting.
- """
- Session().console.print("resetting...")
- Session().reset()
- debug = Session().settings.DEBUG_MODE
- dev = Session().settings.DEV_BACKEND
-
- try:
- # remove the hub routines
- if not Session().is_local():
- remove_file(
- Path(Session().user.preferences.export_directory, "routines", "hub")
- )
-
- # if not get_current_user().profile.remember:
- # Local.remove(HIST_FILE_PROMPT)
-
- # we clear all openbb_terminal modules from sys.modules
- for module in list(sys.modules.keys()):
- parts = module.split(".")
- if parts[0] == "openbb_terminal":
- del sys.modules[module]
-
- queue_list = ["/".join(queue) if len(queue) > 0 else ""] # type: ignore
- # pylint: disable=import-outside-toplevel
- # we run the terminal again
- if Session().is_local():
- from openbb_terminal.controllers.terminal_controller import main
-
- main(debug, dev, queue_list, module="") # type: ignore
- else:
- from openbb_terminal.controllers.terminal_controller import launch
-
- launch(queue=queue_list)
-
- except Exception as e:
- Session().console.print(f"Unfortunately, resetting wasn't possible: {e}\n")
- print_goodbye()
-
-
-@contextmanager
-def suppress_stdout():
- """Suppress the stdout."""
- with open(os.devnull, "w") as devnull:
- old_stdout = sys.stdout
- old_stderr = sys.stderr
- sys.stdout = devnull
- sys.stderr = devnull
- try:
- yield
- finally:
- sys.stdout = old_stdout
- sys.stderr = old_stderr
-
-
-def first_time_user() -> bool:
- """Check whether a user is a first time user.
-
- A first time user is someone with an empty .env file.
- If this is true, it also adds an env variable to make sure this does not run again.
-
- Returns
- -------
- bool
- Whether or not the user is a first time user
- """
- if ENV_FILE_SETTINGS.stat().st_size == 0:
- Session().settings.set_item("PREVIOUS_USE", True)
- return True
- return False
-
-
-def parse_and_split_input(an_input: str, custom_filters: List) -> List[str]:
- """Filter and split the input queue.
-
- Uses regex to filters command arguments that have forward slashes so that it doesn't
- break the execution of the command queue.
- Currently handles unix paths and sorting settings for screener menus.
-
- Parameters
- ----------
- an_input : str
- User input as string
- custom_filters : List
- Additional regular expressions to match
-
- Returns
- -------
- List[str]
- Command queue as list
- """
- # Make sure that the user can go back to the root when doing "/"
- if an_input and an_input == "/":
- an_input = "home"
-
- # everything from ` -f ` to the next known extension
- file_flag = r"(\ -f |\ --file )"
- up_to = r".*?"
- known_extensions = r"(\.(xlsx|csv|xls|tsv|json|yaml|ini|openbb|ipynb))"
- unix_path_arg_exp = f"({file_flag}{up_to}{known_extensions})"
-
- # Add custom expressions to handle edge cases of individual controllers
- custom_filter = ""
- for exp in custom_filters:
- if exp is not None:
- custom_filter += f"|{exp}"
- del exp
-
- slash_filter_exp = f"({unix_path_arg_exp}){custom_filter}"
-
- filter_input = True
- placeholders: Dict[str, str] = {}
- while filter_input:
- match = re.search(pattern=slash_filter_exp, string=an_input)
- if match is not None:
- placeholder = f"{{placeholder{len(placeholders)+1}}}"
- placeholders[placeholder] = an_input[
- match.span()[0] : match.span()[1] # noqa:E203
- ]
- an_input = (
- an_input[: match.span()[0]]
- + placeholder
- + an_input[match.span()[1] :] # noqa:E203
- )
- else:
- filter_input = False
-
- commands = an_input.split("/") if "-t" not in an_input else [an_input]
-
- for command_num, command in enumerate(commands):
- if command == commands[command_num] == commands[-1] == "":
- return list(filter(None, commands))
- matching_placeholders = [tag for tag in placeholders if tag in command]
- if len(matching_placeholders) > 0:
- for tag in matching_placeholders:
- commands[command_num] = command.replace(tag, placeholders[tag])
- return commands
-
-
-def return_colored_value(value: str):
- """Return the string value based on condition.
-
- Return it with green, yellow, red or white color based on
- whether the number is positive, negative, zero or other, respectively.
-
- Parameters
- ----------
- value: str
- string to be checked
-
- Returns
- -------
- value: str
- string with color based on value of number if it exists
- """
- values = re.findall(r"[-+]?(?:\d*\.\d+|\d+)", value)
-
- # Finds exactly 1 number in the string
- if len(values) == 1:
- if float(values[0]) > 0:
- return f"[green]{value}[/green]"
-
- if float(values[0]) < 0:
- return f"[red]{value}[/red]"
-
- if float(values[0]) == 0:
- return f"[yellow]{value}[/yellow]"
-
- return f"{value}"
-
-
-def _get_backend():
- """Get the Platform charting backend."""
- try:
- return get_backend()
- except ValueError:
- # backend might not be created yet
- charting_settings = ChartingSettings(
- system_settings=obb.system, user_settings=obb.user # type: ignore
- )
- create_backend(charting_settings)
- get_backend().start(debug=charting_settings.debug_mode)
- return get_backend()
-
-
-# pylint: disable=too-many-arguments
-def print_rich_table( # noqa: PLR0912
- df: pd.DataFrame,
- show_index: bool = False,
- title: str = "",
- index_name: str = "",
- headers: Optional[Union[List[str], pd.Index]] = None,
- floatfmt: Union[str, List[str]] = ".2f",
- show_header: bool = True,
- automatic_coloring: bool = False,
- columns_to_auto_color: Optional[List[str]] = None,
- rows_to_auto_color: Optional[List[str]] = None,
- export: bool = False,
- limit: Optional[int] = 1000,
- columns_keep_types: Optional[List[str]] = None,
- use_tabulate_df: bool = True,
-):
- """Prepare a table from df in rich.
-
- Parameters
- ----------
- df: pd.DataFrame
- Dataframe to turn into table
- show_index: bool
- Whether to include index
- title: str
- Title for table
- index_name : str
- Title for index column
- headers: List[str]
- Titles for columns
- floatfmt: Union[str, List[str]]
- Float number formatting specs as string or list of strings. Defaults to ".2f"
- show_header: bool
- Whether to show the header row.
- automatic_coloring: bool
- Automatically color a table based on positive and negative values
- columns_to_auto_color: List[str]
- Columns to automatically color
- rows_to_auto_color: List[str]
- Rows to automatically color
- export: bool
- Whether we are exporting the table to a file. If so, we don't want to print it.
- limit: Optional[int]
- Limit the number of rows to show.
- columns_keep_types: Optional[List[str]]
- Columns to keep their types, i.e. not convert to numeric
- """
- if export:
- return
-
- # Make a copy of the dataframe to avoid SettingWithCopyWarning
- df = df.copy()
-
- show_index = not isinstance(df.index, pd.RangeIndex) and show_index
- # convert non-str that are not timestamp or int into str
- # eg) praw.models.reddit.subreddit.Subreddit
- for col in df.columns:
- if columns_keep_types is not None and col in columns_keep_types:
- continue
- try:
- if not any(
- isinstance(df[col].iloc[x], pd.Timestamp)
- for x in range(min(10, len(df)))
- ):
- df[col] = pd.to_numeric(df[col], errors="ignore")
- except (ValueError, TypeError):
- df[col] = df[col].astype(str)
-
- def _get_headers(_headers: Union[List[str], pd.Index]) -> List[str]:
- """Check if headers are valid and return them."""
- output = _headers
- if isinstance(_headers, pd.Index):
- output = list(_headers)
- if len(output) != len(df.columns):
- raise ValueError("Length of headers does not match length of DataFrame.")
- return output
-
- if Session().settings.USE_INTERACTIVE_DF:
- df_outgoing = df.copy()
- # If headers are provided, use them
- if headers is not None:
- # We check if headers are valid
- df_outgoing.columns = _get_headers(headers)
-
- if show_index and index_name not in df_outgoing.columns:
- # If index name is provided, we use it
- df_outgoing.index.name = index_name or "Index"
- df_outgoing = df_outgoing.reset_index()
-
- for col in df_outgoing.columns:
- if col == "":
- df_outgoing = df_outgoing.rename(columns={col: " "})
-
- # ensure everything on the dataframe is a string
- df_outgoing = df_outgoing.applymap(str)
-
- _get_backend().send_table(
- df_table=df_outgoing,
- title=title,
- theme=Session().user.preferences.table_style,
- )
- return
-
- df = df.copy() if not limit else df.copy().iloc[:limit]
- if automatic_coloring:
- if columns_to_auto_color:
- for col in columns_to_auto_color:
- # checks whether column exists
- if col in df.columns:
- df[col] = df[col].apply(lambda x: return_colored_value(str(x)))
- if rows_to_auto_color:
- for row in rows_to_auto_color:
- # checks whether row exists
- if row in df.index:
- df.loc[row] = df.loc[row].apply(
- lambda x: return_colored_value(str(x))
- )
-
- if columns_to_auto_color is None and rows_to_auto_color is None:
- df = df.applymap(lambda x: return_colored_value(str(x)))
-
- exceeds_allowed_columns = (
- len(df.columns) > Session().settings.ALLOWED_NUMBER_OF_COLUMNS
- )
- exceeds_allowed_rows = len(df) > Session().settings.ALLOWED_NUMBER_OF_ROWS
-
- if exceeds_allowed_columns:
- original_columns = df.columns.tolist()
- trimmed_columns = df.columns.tolist()[
- : Session().settings.ALLOWED_NUMBER_OF_COLUMNS
- ]
- df = df[trimmed_columns]
- trimmed_columns = [
- col for col in original_columns if col not in trimmed_columns
- ]
-
- if exceeds_allowed_rows:
- n_rows = len(df.index)
- trimmed_rows = df.index.tolist()[: Session().settings.ALLOWED_NUMBER_OF_ROWS]
- df = df.loc[trimmed_rows]
- trimmed_rows_count = n_rows - Session().settings.ALLOWED_NUMBER_OF_ROWS
-
- if use_tabulate_df:
- table = Table(title=title, show_lines=True, show_header=show_header)
-
- if show_index:
- table.add_column(index_name)
-
- if headers is not None:
- headers = _get_headers(headers)
- for header in headers:
- table.add_column(str(header))
- else:
- for column in df.columns:
- table.add_column(str(column))
-
- if isinstance(floatfmt, list) and len(floatfmt) != len(df.columns):
- raise (
- ValueError(
- "Length of floatfmt list does not match length of DataFrame columns."
- )
- )
- if isinstance(floatfmt, str):
- floatfmt = [floatfmt for _ in range(len(df.columns))]
-
- for idx, values in zip(df.index.tolist(), df.values.tolist()):
- # remove hour/min/sec from timestamp index - Format: YYYY-MM-DD # make better
- row_idx = [str(idx)] if show_index else []
- row_idx += [
- (
- str(x)
- if not isinstance(x, float) and not isinstance(x, np.float64)
- else (
- f"{x:{floatfmt[idx]}}"
- if isinstance(floatfmt, list)
- else (
- f"{x:.2e}"
- if 0 < abs(float(x)) <= 0.0001
- else f"{x:floatfmt}"
- )
- )
- )
- for idx, x in enumerate(values)
- ]
- table.add_row(*row_idx)
- Session().console.print(table)
- else:
- Session().console.print(df.to_string(col_space=0))
-
- if exceeds_allowed_columns:
- Session().console.print(
- f"[yellow]\nAllowed number of columns exceeded ({Session().settings.ALLOWED_NUMBER_OF_COLUMNS}).\n"
- f"The following columns were removed from the output: {', '.join(trimmed_columns)}.\n[/yellow]"
- )
-
- if exceeds_allowed_rows:
- Session().console.print(
- f"[yellow]\nAllowed number of rows exceeded ({Session().settings.ALLOWED_NUMBER_OF_ROWS}).\n"
- f"{trimmed_rows_count} rows were removed from the output.\n[/yellow]"
- )
-
- if exceeds_allowed_columns or exceeds_allowed_rows:
- Session().console.print(
- "Use the `--export` flag to analyse the full output on a file."
- )
-
-
-def check_non_negative(value) -> int:
- """Argparse type to check non negative int."""
- new_value = int(value)
- if new_value < 0:
- raise argparse.ArgumentTypeError(f"{value} is negative")
- return new_value
-
-
-def check_positive(value) -> int:
- """Argparse type to check positive int."""
- new_value = int(value)
- if new_value <= 0:
- raise argparse.ArgumentTypeError(f"{value} is an invalid positive int value")
- return new_value
-
-
-def get_user_agent() -> str:
- """Get a not very random user agent."""
- user_agent_strings = [
- "Mozilla/5.0 (Macintosh; U; Intel Mac OS X 10.10; rv:86.1) Gecko/20100101 Firefox/86.1",
- "Mozilla/5.0 (Windows NT 6.1; WOW64; rv:86.1) Gecko/20100101 Firefox/86.1",
- "Mozilla/5.0 (Macintosh; Intel Mac OS X 10.10; rv:82.1) Gecko/20100101 Firefox/82.1",
- "Mozilla/5.0 (Macintosh; Intel Mac OS X 10.13; rv:86.0) Gecko/20100101 Firefox/86.0",
- "Mozilla/5.0 (Windows NT 10.0; WOW64; rv:86.0) Gecko/20100101 Firefox/86.0",
- "Mozilla/5.0 (Macintosh; U; Intel Mac OS X 10.10; rv:83.0) Gecko/20100101 Firefox/83.0",
- "Mozilla/5.0 (Windows NT 6.1; WOW64; rv:84.0) Gecko/20100101 Firefox/84.0",
- ]
-
- return random.choice(user_agent_strings) # nosec # noqa: S311
-
-
-def get_flair() -> str:
- """Get a flair icon."""
- current_flair = str(Session().settings.FLAIR)
- flair = AVAILABLE_FLAIRS.get(current_flair, current_flair)
- return flair
-
-
-def get_dtime() -> str:
- """Get a datetime string."""
- dtime = ""
- if Session().settings.USE_DATETIME and get_user_timezone_or_invalid() != "INVALID":
- dtime = datetime.now(timezone(get_user_timezone())).strftime("%Y %b %d, %H:%M")
- return dtime
-
-
-def get_flair_and_username() -> str:
- """Get a flair icon and username."""
- flair = get_flair()
- dtime = get_dtime()
-
- if dtime:
- dtime = f"{dtime} "
-
- username = getattr(Session().user.profile.hub_session, "username", "")
- if username:
- username = f"[{username}] "
-
- return f"{dtime}{username}{flair}"
-
-
-def is_timezone_valid(user_tz: str) -> bool:
- """Check whether user timezone is valid.
-
- Parameters
- ----------
- user_tz: str
- Timezone to check for validity
-
- Returns
- -------
- bool
- True if timezone provided is valid
- """
- return user_tz in all_timezones
-
-
-def get_user_timezone() -> str:
- """Get user timezone if it is a valid one.
-
- Returns
- -------
- str
- user timezone based on .env file
- """
- return Session().settings.TIMEZONE
-
-
-def get_user_timezone_or_invalid() -> str:
- """Get user timezone if it is a valid one.
-
- Returns
- -------
- str
- user timezone based on timezone.openbb file or INVALID
- """
- user_tz = get_user_timezone()
- if is_timezone_valid(user_tz):
- return f"{user_tz}"
- return "INVALID"
-
-
-def check_file_type_saved(valid_types: Optional[List[str]] = None):
- """Provide valid types for the user to be able to select.
-
- Parameters
- ----------
- valid_types: List[str]
- List of valid types to export data
-
- Returns
- -------
- check_filenames: Optional[List[str]]
- Function that returns list of filenames to export data
- """
-
- def check_filenames(filenames: str = "") -> str:
- """Check if filenames are valid.
-
- Parameters
- ----------
- filenames: str
- filenames to be saved separated with comma
-
- Returns
- ----------
- str
- valid filenames separated with comma
- """
- if not filenames or not valid_types:
- return ""
- valid_filenames = list()
- for filename in filenames.split(","):
- if filename.endswith(tuple(valid_types)):
- valid_filenames.append(filename)
- else:
- Session().console.print(
- f"[red]Filename '{filename}' provided is not valid!\nPlease use one of the following file types:"
- f"{','.join(valid_types)}[/red]\n"
- )
- return ",".join(valid_filenames)
-
- return check_filenames
-
-
-def remove_timezone_from_dataframe(df: pd.DataFrame) -> pd.DataFrame:
- """Remove timezone information from a dataframe.
-
- Parameters
- ----------
- df : pd.DataFrame
- The dataframe to remove timezone information from
-
- Returns
- -------
- pd.DataFrame
- The dataframe with timezone information removed
- """
-
- date_cols = []
- index_is_date = False
-
- # Find columns and index containing date data
- if (
- df.index.dtype.kind == "M"
- and hasattr(df.index.dtype, "tz")
- and df.index.dtype.tz is not None
- ):
- index_is_date = True
-
- for col, dtype in df.dtypes.items():
- if dtype.kind == "M" and hasattr(df.index.dtype, "tz") and dtype.tz is not None:
- date_cols.append(col)
-
- # Remove the timezone information
- for col in date_cols:
- df[col] = df[col].dt.date
-
- if index_is_date:
- index_name = df.index.name
- df.index = df.index.date
- df.index.name = index_name
-
- return df
-
-
-def compose_export_path(func_name: str, dir_path: str) -> Path:
- """Compose export path for data from the terminal.
-
- Creates a path to a folder and a filename based on conditions.
-
- Parameters
- ----------
- func_name : str
- Name of the command that invokes this function
- dir_path : str
- Path of directory from where this function is called
-
- Returns
- -------
- Path
- Path variable containing the path of the exported file
- """
- now = datetime.now()
- # Resolving all symlinks and also normalizing path.
- resolve_path = Path(dir_path).resolve()
- # Getting the directory names from the path. Instead of using split/replace (Windows doesn't like that)
- # check if this is done in a main context to avoid saving with openbb_terminal
- if resolve_path.parts[-2] == "openbb_terminal":
- path_cmd = f"{resolve_path.parts[-1]}"
- else:
- path_cmd = f"{resolve_path.parts[-2]}_{resolve_path.parts[-1]}"
-
- default_filename = f"{now.strftime('%Y%m%d_%H%M%S')}_{path_cmd}_{func_name}"
-
- full_path = Path(Session().user.preferences.export_directory) / default_filename
-
- return full_path
-
-
-def ask_file_overwrite(file_path: Path) -> Tuple[bool, bool]:
- """Provide a prompt for overwriting existing files.
-
- Returns two values, the first is a boolean indicating if the file exists and the
- second is a boolean indicating if the user wants to overwrite the file.
- """
- if Session().settings.FILE_OVERWRITE:
- return False, True
- if Session().settings.TEST_MODE:
- return False, True
- if file_path.exists():
- overwrite = input("\nFile already exists. Overwrite? [y/n]: ").lower()
- if overwrite == "y":
- file_path.unlink(missing_ok=True)
- # File exists and user wants to overwrite
- return True, True
- # File exists and user does not want to overwrite
- return True, False
- # File does not exist
- return False, True
-
-
-# This is a false positive on pylint and being tracke