diff options
author | Diogo Sousa <montezdesousa@gmail.com> | 2024-06-06 21:11:49 +0100 |
---|---|---|
committer | Diogo Sousa <montezdesousa@gmail.com> | 2024-06-06 21:11:49 +0100 |
commit | a839c228cc133ee6a744b27e5df5c4e2a5f33159 (patch) | |
tree | 76c6d9f359fdc58ab5ab5fd3a7a9de7d642a4618 | |
parent | a52b24c16b3c59048ad94574cc2ff8d0b0784fce (diff) |
Generate settings menu from pydantic model
-rw-r--r-- | cli/openbb_cli/argparse_translator/utils.py | 2 | ||||
-rw-r--r-- | cli/openbb_cli/controllers/base_platform_controller.py | 3 | ||||
-rw-r--r-- | cli/openbb_cli/controllers/cli_controller.py | 3 | ||||
-rw-r--r-- | cli/openbb_cli/controllers/settings_controller.py | 442 | ||||
-rw-r--r-- | cli/openbb_cli/controllers/utils.py | 90 | ||||
-rw-r--r-- | cli/openbb_cli/models/settings.py | 147 |
6 files changed, 266 insertions, 421 deletions
diff --git a/cli/openbb_cli/argparse_translator/utils.py b/cli/openbb_cli/argparse_translator/utils.py index d9ab81c5327..2e09f166667 100644 --- a/cli/openbb_cli/argparse_translator/utils.py +++ b/cli/openbb_cli/argparse_translator/utils.py @@ -64,7 +64,7 @@ def get_argument_optional_choices(parser: ArgumentParser, argument_name: str) -> or action.dest == argument_name and hasattr(action, "optional_choices") ): - return action.optional_choices + return action.optional_choices # type: ignore[attr-defined] return False diff --git a/cli/openbb_cli/controllers/base_platform_controller.py b/cli/openbb_cli/controllers/base_platform_controller.py index acbec34d165..33cc4700849 100644 --- a/cli/openbb_cli/controllers/base_platform_controller.py +++ b/cli/openbb_cli/controllers/base_platform_controller.py @@ -6,6 +6,7 @@ from types import MethodType from typing import Dict, List, Optional import pandas as pd +from openbb import obb from openbb_charting.core.openbb_figure import OpenBBFigure from openbb_cli.argparse_translator.argparse_class_processor import ( ArgparseClassProcessor, @@ -16,8 +17,6 @@ from openbb_cli.controllers.utils import export_data, print_rich_table from openbb_cli.session import Session from openbb_core.app.model.obbject import OBBject -from openbb import obb - session = Session() diff --git a/cli/openbb_cli/controllers/cli_controller.py b/cli/openbb_cli/controllers/cli_controller.py index 31d32499659..9af3c74aa12 100644 --- a/cli/openbb_cli/controllers/cli_controller.py +++ b/cli/openbb_cli/controllers/cli_controller.py @@ -18,6 +18,7 @@ from typing import Any, Dict, List, Optional import pandas as pd import requests +from openbb import obb from openbb_cli.config import constants from openbb_cli.config.constants import ( ASSETS_DIRECTORY, @@ -47,8 +48,6 @@ from prompt_toolkit.formatted_text import HTML from prompt_toolkit.styles import Style from pydantic import BaseModel -from openbb import obb - PLATFORM_ROUTERS = { d: "menu" if not isinstance(getattr(obb, d), BaseModel) else "command" for d in dir(obb) diff --git a/cli/openbb_cli/controllers/settings_controller.py b/cli/openbb_cli/controllers/settings_controller.py index f93501a5063..4dacfb76d1e 100644 --- a/cli/openbb_cli/controllers/settings_controller.py +++ b/cli/openbb_cli/controllers/settings_controller.py @@ -1,15 +1,12 @@ """Settings Controller Module.""" import argparse -from typing import List, Optional +from functools import partial, update_wrapper +from types import MethodType +from typing import List, Literal, Optional, get_origin -from openbb_cli.config.constants import AVAILABLE_FLAIRS from openbb_cli.config.menu_text import MenuText - -# pylint: disable=too-many-lines,no-member,too-many-public-methods,C0302 -# pylint:disable=import-outside-toplevel from openbb_cli.controllers.base_controller import BaseController -from openbb_cli.controllers.utils import all_timezones, is_timezone_valid from openbb_cli.session import Session session = Session() @@ -18,356 +15,119 @@ session = Session() class SettingsController(BaseController): """Settings Controller class.""" - CHOICES_COMMANDS: List[str] = [ - "interactive", - "cls", - "promptkit", - "exithelp", - "rcontext", - "richpanel", - "tbhint", - "overwrite", - "version", - "console_style", - "flair", - "timezone", - "n_rows", - "n_cols", - "obbject_msg", - "obbject_res", - "obbject_display", - ] + _COMMANDS = { + v.json_schema_extra.get("command"): { + "command": (v.json_schema_extra or {}).get("command"), + "group": (v.json_schema_extra or {}).get("group"), + "description": v.description, + "annotation": v.annotation, + "name": k, + } + for k, v in sorted( + session.settings.model_fields.items(), + key=lambda item: (item[1].json_schema_extra or {}).get("command", ""), + ) + if v.json_schema_extra + } + CHOICES_COMMANDS: List[str] = list(_COMMANDS.keys()) PATH = "/settings/" CHOICES_GENERATION = True def __init__(self, queue: Optional[List[str]] = None): """Initialize the Constructor.""" super().__init__(queue) - + for cmd, field in self._COMMANDS.items(): + group = field.get("group") + if group == "feature-flags": + self._generate_command(cmd, field, "toggle") + elif group == "preferences": + self._generate_command(cmd, field, "set") self.update_completer(self.choices_default) def print_help(self): """Print help.""" - settings = session.settings - mt = MenuText("settings/") - mt.add_info("Feature flags") - mt.add_setting( - "interactive", - settings.USE_INTERACTIVE_DF, - description="open dataframes in interactive window", - ) - mt.add_setting( - "cls", - settings.USE_CLEAR_AFTER_CMD, - description="clear console after each command", - ) - mt.add_setting( - "promptkit", - settings.USE_PROMPT_TOOLKIT, - description="enable prompt toolkit (autocomplete and history)", - ) - mt.add_setting( - "exithelp", - settings.ENABLE_EXIT_AUTO_HELP, - description="automatically print help when quitting menu", - ) - mt.add_setting( - "rcontext", - settings.REMEMBER_CONTEXTS, - description="remember contexts between menus", - ) - mt.add_setting( - "richpanel", - settings.ENABLE_RICH_PANEL, - description="colorful rich CLI panel", - ) - mt.add_setting( - "tbhint", - settings.TOOLBAR_HINT, - description="displays usage hints in the bottom toolbar", - ) - mt.add_setting( - "overwrite", - settings.FILE_OVERWRITE, - description="whether to overwrite Excel files if they already exists", - ) - mt.add_setting( - "version", - settings.SHOW_VERSION, - description="whether to show the version in the bottom right corner", - ) - mt.add_setting( - "obbject_msg", - settings.SHOW_MSG_OBBJECT_REGISTRY, - description="show obbject registry message after a new result is added", - ) + mt.add_info("Feature Flags") + for k, f in self._COMMANDS.items(): + if f.get("group") == "feature-flags": + mt.add_setting( + name=k, + status=getattr(session.settings, f["name"]), + description=f["description"], + ) mt.add_raw("\n") mt.add_info("Preferences") - mt.add_cmd("console_style", description="apply a custom rich style to the CLI") - mt.add_cmd("flair", description="choose flair icon") - mt.add_cmd("timezone", description="pick timezone") - mt.add_cmd( - "n_rows", description="number of rows to show on non interactive tables" - ) - mt.add_cmd( - "n_cols", description="number of columns to show on non interactive tables" - ) - mt.add_cmd( - "obbject_res", - description="define the maximum number of obbjects allowed in the registry", - ) - mt.add_cmd( - "obbject_display", - description="define the maximum number of cached results to display on the help menu", - ) - - session.console.print(text=mt.menu_text, menu="Settings") - - def call_overwrite(self, _): - """Process overwrite command.""" - session.settings.set_item("FILE_OVERWRITE", not session.settings.FILE_OVERWRITE) - - def call_version(self, _): - """Process version command.""" - session.settings.SHOW_VERSION = not session.settings.SHOW_VERSION - - def call_interactive(self, _): - """Process interactive command.""" - session.settings.set_item( - "USE_INTERACTIVE_DF", not session.settings.USE_INTERACTIVE_DF - ) - - def call_cls(self, _): - """Process cls command.""" - session.settings.set_item( - "USE_CLEAR_AFTER_CMD", not session.settings.USE_CLEAR_AFTER_CMD - ) - - def call_promptkit(self, _): - """Process promptkit command.""" - session.settings.set_item( - "USE_PROMPT_TOOLKIT", not session.settings.USE_PROMPT_TOOLKIT - ) - - def call_exithelp(self, _): - """Process exithelp command.""" - session.settings.set_item( - "ENABLE_EXIT_AUTO_HELP", not session.settings.ENABLE_EXIT_AUTO_HELP - ) - - def call_rcontext(self, _): - """Process rcontext command.""" - session.settings.set_item( - "REMEMBER_CONTEXTS", not session.settings.REMEMBER_CONTEXTS - ) - - def call_dt(self, _): - """Process dt command.""" - session.settings.set_item("USE_DATETIME", not session.settings.USE_DATETIME) - - def call_richpanel(self, _): - """Process richpanel command.""" - session.settings.set_item( - "ENABLE_RICH_PANEL", not session.settings.ENABLE_RICH_PANEL - ) - - def call_tbhint(self, _): - """Process tbhint command.""" - if session.settings.TOOLBAR_HINT: - session.console.print("Will take effect when running CLI again.") - session.settings.set_item("TOOLBAR_HINT", not session.settings.TOOLBAR_HINT) - - def call_obbject_msg(self, _): - """Process obbject_msg command.""" - session.settings.set_item( - "SHOW_MSG_OBBJECT_REGISTRY", - not session.settings.SHOW_MSG_OBBJECT_REGISTRY, - ) - - def call_console_style(self, other_args: List[str]) -> None: - """Process cosole_style command.""" - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter, - prog="console_style", - description="Change your custom console style.", - add_help=False, - ) - parser.add_argument( - "-s", - "--style", - dest="style", - action="store", - required=False, - choices=session.style.available_styles, - ) - ns_parser = self.parse_simple_args(parser, other_args) - - if ns_parser and ns_parser.style: - session.style.apply(ns_parser.style) - session.settings.set_item("RICH_STYLE", ns_parser.style) - elif not other_args: - session.console.print( - f"Current console style: {session.settings.RICH_STYLE}" - ) - - def call_flair(self, other_args: List[str]) -> None: - """Process flair command.""" - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter, - prog="flair", - description="Change your custom flair.", - add_help=False, - ) - parser.add_argument( - "-f", - "--flair", - dest="flair", - action="store", - required=False, - choices=list(AVAILABLE_FLAIRS.keys()), - ) - ns_parser = self.parse_simple_args(parser, other_args) - - if ns_parser and ns_parser.flair: - session.settings.set_item("FLAIR", ns_parser.flair) - elif not other_args: - session.console.print(f"Current flair: {session.settings.FLAIR}") - - def call_timezone(self, other_args: List[str]) -> None: - """Process timezone command.""" - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter, - prog="timezone", - description="Change your custom timezone.", - add_help=False, - ) - parser.add_argument( - "-t", - "--timezone", - dest="timezone", - action="store", - required=False, - type=str, - choices=all_timezones, - ) - ns_parser = self.parse_simple_args(parser, other_args) - - if ns_parser and ns_parser.timezone: - if is_timezone_valid(ns_parser.timezone): - session.settings.set_item("TIMEZONE", ns_parser.timezone) - else: - session.console.print( - "Invalid timezone. Please enter a valid timezone." + for k, f in self._COMMANDS.items(): + if f.get("group") == "preferences": + mt.add_cmd( + name=k, + description=f["description"], ) - session.console.print( - f"Available timezones are: {', '.join(all_timezones)}" - ) - elif not other_args: - session.console.print(f"Current timezone: {session.settings.TIMEZONE}") - - def call_n_rows(self, other_args: List[str]) -> None: - """Process n_rows command.""" - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter, - prog="n_rows", - description="Number of rows to show (when not using interactive tables).", - add_help=False, - ) - parser.add_argument( - "-r", - "--rows", - dest="rows", - action="store", - required=False, - type=int, - ) - ns_parser = self.parse_simple_args(parser, other_args) - - if ns_parser and ns_parser.rows: - session.settings.set_item("ALLOWED_NUMBER_OF_ROWS", ns_parser.rows) - - elif not other_args: - session.console.print( - f"Current number of rows: {session.settings.ALLOWED_NUMBER_OF_ROWS}" - ) - - def call_n_cols(self, other_args: List[str]) -> None: - """Process n_cols command.""" - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter, - prog="n_cols", - description="Number of columns to show (when not using interactive tables).", - add_help=False, - ) - parser.add_argument( - "-c", - "--columns", - dest="columns", - action="store", - required=False, - type=int, - ) - ns_parser = self.parse_simple_args(parser, other_args) - - if ns_parser and ns_parser.columns: - session.settings.set_item("ALLOWED_NUMBER_OF_COLUMNS", ns_parser.columns) + session.console.print(text=mt.menu_text, menu="Settings") - elif not other_args: - session.console.print( - f"Current number of columns: {session.settings.ALLOWED_NUMBER_OF_COLUMNS}" + def _generate_command( + self, name: str, field: dict, action_type: Literal["toggle", "set"] + ): + """Generate command call.""" + + def _toggle(self, other_args: List[str], field=field) -> None: + """Toggle setting value.""" + name = field["name"] + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + prog=field["command"], + description=field["description"], + add_help=False, ) - - def call_obbject_res(self, other_args: List[str]): - """Process obbject_res command.""" - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter, - prog="obbject_res", - description="Maximum allowed number of results to keep in the OBBject Registry.", - add_help=False, - ) - parser.add_argument( - "-n", - "--number", - dest="number", - action="store", - required=False, - type=int, - ) - ns_parser = self.parse_simple_args(parser, other_args) - - if ns_parser and ns_parser.number: - session.settings.set_item("N_TO_KEEP_OBBJECT_REGISTRY", ns_parser.number) - - elif not other_args: - session.console.print( - f"Current maximum allowed number of results to keep in the OBBject registry:" - f" {session.settings.N_TO_KEEP_OBBJECT_REGISTRY}" + ns_parser = self.parse_simple_args(parser, other_args) + if ns_parser: + session.settings.set_item(name, not getattr(session.settings, name)) + + def _set(self, other_args: List[str], field=field) -> None: + """Set preference value.""" + name = field["name"] + annotation = field["annotation"] + command = field["command"] + choices = None + if get_origin(annotation) is Literal: + choices = annotation.__args__ + elif command == "console_style": + choices = session.style.available_styles + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + prog=command, + description=field["description"], + add_help=False, ) - - def call_obbject_display(self, other_args: List[str]): - """Process obbject_display command.""" - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter, - prog="obbject_display", - description="Number of results to display from the OBBject Registry.", - add_help=False, - ) - parser.add_argument( - "-n", - "--number", - dest="number", - action="store", - required=False, - type=int, - ) - ns_parser = self.parse_simple_args(parser, other_args) - - if ns_parser and ns_parser.number: - session.settings.set_item("N_TO_DISPLAY_OBBJECT_REGISTRY", ns_parser.number) - - elif not other_args: - session.console.print( - f"Current number of results to display from the OBBject registry:" - f" {session.settings.N_TO_DISPLAY_OBBJECT_REGISTRY}" + parser.add_argument( + "-v", + "--value", + dest="value", + action="store", + required=False, + type=None if get_origin(annotation) is Literal else annotation, # type: ignore[arg-type] + choices=choices, ) + ns_parser = self.parse_simple_args(parser, other_args) + if ns_parser: + if ns_parser.value: + if command == "console_style": + session.style.apply(ns_parser.value) + session.settings.set_item(name, ns_parser.value) + elif not other_args: + session.console.print( + f"Current value: {getattr(session.settings, name)}" + ) + + action = None + if action_type == "toggle": + action = _toggle + elif action_type == "set": + action = _set + else: + raise ValueError(f"Action type '{action_type}' not allowed.") + + bound_method = update_wrapper( + partial(MethodType(action, self), field=field), action + ) + setattr(self, f"call_{name}", bound_method) diff --git a/cli/openbb_cli/controllers/utils.py b/cli/openbb_cli/controllers/utils.py index a77c93eb683..b0f7b220f0d 100644 --- a/cli/openbb_cli/controllers/utils.py +++ b/cli/openbb_cli/controllers/utils.py @@ -16,6 +16,7 @@ from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union import numpy as np import pandas as pd import requests +from openbb import obb from openbb_charting.core.backend import create_backend, get_backend from openbb_cli.config.constants import AVAILABLE_FLAIRS, ENV_FILE_SETTINGS from openbb_cli.session import Session @@ -23,8 +24,6 @@ from openbb_core.app.model.charts.charting_settings import ChartingSettings from pytz import all_timezones, timezone from rich.table import Table -from openbb import obb - if TYPE_CHECKING: from openbb_charting.core.openbb_figure import OpenBBFigure @@ -33,6 +32,8 @@ if TYPE_CHECKING: # pylint: disable=too-many-statements,no-member,too-many-branches,C0302 +session = Session() + def remove_file(path: Path) -> bool: """Remove path. @@ -55,7 +56,7 @@ def remove_file(path: Path) -> bool: shutil.rmtree(path) return True except Exception: - Session().console.print( + session.console.print( f"\n[bold red]Failed to remove {path}" "\nPlease delete this manually![/bold red]" ) @@ -88,17 +89,17 @@ Please feel free to check out our other products: [bold]OpenBB Platform:[/] [cmds]https://openbb.co/products/platform[/cmds] [bold]OpenBB Bot[/]: [cmds]https://openbb.co/products/bot[/cmds] """ - Session().console.print(text) + session.console.print(text) def print_guest_block_msg(): """Block guest users from using the cli.""" - if Session().is_local(): - Session().console.print( + if session.is_local(): + session.console.print( "[info]You are currently logged as a guest.[/info]\n" "[info]Login to use this feature.[/info]\n\n" "[info]If you don't have an account, you can create one here: [/info]" - f"[cmds]{Session().settings.HUB_URL + '/register'}\n[/cmds]" + f"[cmds]{session.settings.HUB_URL + '/register'}\n[/cmds]" ) @@ -115,7 +116,7 @@ def bootup(): # pylint: disable=E1101 sys.stdout.reconfigure(encoding="utf-8") # type: ignore except Exception as e: - Session().console.print(e, "\n") + session.console.print(e, "\n") def welcome_message(): @@ -123,8 +124,8 @@ def welcome_message(): Prints first welcome message, help and a notification if updates are available. """ - Session().console.print( - f"\nWelcome to OpenBB Platform CLI v{Session().settings.VERSION}" + session.console.print( + f"\nWelcome to OpenBB Platform CLI v{session.settings.VERSION}" ) @@ -133,16 +134,15 @@ def reset(queue: Optional[List[str]] = None): Allows for checking code without quitting. """ - Session().console.print("resetting...") - Session().reset() - debug = Session().settings.DEBUG_MODE - dev = Session().settings.DEV_BACKEND + session.console.print("resetting...") + debug = session.settings.DEBUG_MODE + dev = session.settings.DEV_BACKEND try: # remove the hub routines - if not Session().is_local(): + if not session.is_local(): remove_file( - Path(Session().user.preferences.export_directory, "routines", "hub") + Path(session.user.preferences.export_directory, "routines", "hub") ) # if not get_current_user().profile.remember: @@ -157,7 +157,7 @@ def reset(queue: Optional[List[str]] = None): queue_list = ["/".join(queue) if len(queue) > 0 else ""] # type: ignore # pylint: disable=import-outside-toplevel # we run the cli again - if Session().is_local(): + if session.is_local(): from openbb_cli.controllers.cli_controller import main main(debug, dev, queue_list, module="") # type: ignore @@ -167,7 +167,7 @@ def reset(queue: Optional[List[str]] = None): launch(queue=queue_list) except Exception as e: - Session().console.print(f"Unfortunately, resetting wasn't possible: {e}\n") + session.console.print(f"Unfortunately, resetting wasn't possible: {e}\n") print_goodbye() @@ -198,7 +198,7 @@ def first_time_user() -> bool: Whether or not the user is a first time user """ if ENV_FILE_SETTINGS.stat().st_size == 0: - Session().settings.set_item("PREVIOUS_USE", True) + session.settings.set_item("PREVIOUS_USE", True) return True return False @@ -367,8 +367,8 @@ def print_rich_table( # noqa: PLR0912 if export: return - MAX_COLS = Session().settings.ALLOWED_NUMBER_OF_COLUMNS - MAX_ROWS = Session().settings.ALLOWED_NUMBER_OF_ROWS + MAX_COLS = session.settings.ALLOWED_NUMBER_OF_COLUMNS + MAX_ROWS = session.settings.ALLOWED_NUMBER_OF_ROWS # Make a copy of the dataframe to avoid SettingWithCopyWarning df = df.copy() @@ -397,7 +397,7 @@ def print_rich_table( # noqa: PLR0912 raise ValueError("Length of headers does not match length of DataFrame.") return output - if Session().settings.USE_INTERACTIVE_DF: + if session.settings.USE_INTERACTIVE_DF: df_outgoing = df.copy() # If headers are provided, use them if headers is not None: @@ -419,7 +419,7 @@ def print_rich_table( # noqa: PLR0912 _get_backend().send_table( df_table=df_outgoing, title=title, - theme=Session().user.preferences.table_style, + theme=session.user.preferences.table_style, ) return @@ -501,24 +501,24 @@ def print_rich_table( # noqa: PLR0912 for idx, x in enumerate(values) ] table.add_row(*row_idx) - Session().console.print(table) + session.console.print(table) else: - Session().console.print(df.to_string(col_space=0)) + session.console.print(df.to_string(col_space=0)) if exceeds_allowed_columns: - Session().console.print( - f"[yellow]\nAllowed number of columns exceeded ({Session().settings.ALLOWED_NUMBER_OF_COLUMNS}).\n" + session.console.print( + f"[yellow]\nAllowed number of columns exceeded ({session.settings.ALLOWED_NUMBER_OF_COLUMNS}).\n" f"The following columns were removed from the output: {', '.join(trimmed_columns)}.\n[/yellow]" ) if exceeds_allowed_rows: - Session().console.print( - f"[yellow]\nAllowed number of rows exceeded ({Session().settings.ALLOWED_NUMBER_OF_ROWS}).\n" + session.console.print( + f"[yellow]\nAllowed number of rows exceeded ({session.settings.ALLOWED_NUMBER_OF_ROWS}).\n" f"{trimmed_rows_count} rows were removed from the output.\n[/yellow]" ) if exceeds_allowed_columns or exceeds_allowed_rows: - Session().console.print( + session.console.print( "Use the `--export` flag to analyse the full output on a file." ) @@ -556,7 +556,7 @@ def get_user_agent() -> str: def get_flair() -> str: """Get a flair icon.""" - current_flair = str(Session().settings.FLAIR) + current_flair = str(session.settings.FLAIR) flair = AVAILABLE_FLAIRS.get(current_flair, current_flair) return flair @@ -564,7 +564,7 @@ def get_flair() -> str: def get_dtime() -> str: """Get a datetime string.""" dtime = "" - if Session().settings.USE_DATETIME and get_user_timezone_or_invalid() != "INVALID": + if session.settings.USE_DATETIME and get_user_timezone_or_invalid() != "INVALID": dtime = datetime.now(timezone(get_user_timezone())).strftime("%Y %b %d, %H:%M") return dtime @@ -572,12 +572,10 @@ def get_dtime() -> str: def get_flair_and_username() -> str: """Get a flair icon and username.""" |