diff options
Diffstat (limited to 'cli/openbb_cli/controllers/base_platform_controller.py')
-rw-r--r-- | cli/openbb_cli/controllers/base_platform_controller.py | 78 |
1 files changed, 59 insertions, 19 deletions
diff --git a/cli/openbb_cli/controllers/base_platform_controller.py b/cli/openbb_cli/controllers/base_platform_controller.py index 658e73b8a9f..a3cd0d1527c 100644 --- a/cli/openbb_cli/controllers/base_platform_controller.py +++ b/cli/openbb_cli/controllers/base_platform_controller.py @@ -75,20 +75,31 @@ class PlatformController(BaseController): for _, trl in self.translators.items(): for action in trl._parser._actions: # pylint: disable=protected-access if action.dest == "data": + # Generate choices by combining indexed and key-based choices action.choices = [ "OBB" + str(i) for i in range(len(session.obbject_registry.obbjects)) + ] + [ + obbject.extra["register_key"] + for obbject in session.obbject_registry.obbjects + if "register_key" in obbject.extra ] + action.type = str action.nargs = None def _intersect_data_processing_commands(self, ns_parser): """Intersect data processing commands and change the obbject id into an actual obbject.""" if hasattr(ns_parser, "data"): - ns_parser.data = int(ns_parser.data.replace("OBB", "")) - if ns_parser.data in range(len(session.obbject_registry.obbjects)): + if "OBB" in ns_parser.data: + ns_parser.data = int(ns_parser.data.replace("OBB", "")) + + if (ns_parser.data in range(len(session.obbject_registry.obbjects))) or ( + ns_parser.data in session.obbject_registry.obbject_keys + ): obbject = session.obbject_registry.get(ns_parser.data) - setattr(ns_parser, "data", obbject.results) + if obbject and isinstance(obbject, OBBject): + setattr(ns_parser, "data", obbject.results) return ns_parser @@ -152,6 +163,11 @@ class PlatformController(BaseController): try: ns_parser = self._intersect_data_processing_commands(ns_parser) + store_obbject = ( + hasattr(ns_parser, "register_obbject") + and ns_parser.register_obbject + ) + obbject = translator.execute_func(parsed_args=ns_parser) df: pd.DataFrame = pd.DataFrame() fig: Optional[OpenBBFigure] = None @@ -159,7 +175,11 @@ class PlatformController(BaseController): if obbject: if isinstance(obbject, OBBject): - if session.max_obbjects_exceeded() and obbject.results: + if ( + session.max_obbjects_exceeded() + and obbject.results + and store_obbject + ): session.obbject_registry.remove() session.console.print( "[yellow]Maximum number of OBBjects reached. The oldest entry was removed.[yellow]" @@ -167,25 +187,45 @@ class PlatformController(BaseController): # use the obbject to store the command so we can display it later on results obbject.extra["command"] = f"{title} {' '.join(other_args)}" - - register_result = session.obbject_registry.register(obbject) - - # we need to force to re-link so that the new obbject - # is immediately available for data processing commands - self._link_obbject_to_data_processing_commands() - # also update the completer - self.update_completer(self.choices_default) - + # if there is a registry key in the parser, store to the obbject if ( - session.settings.SHOW_MSG_OBBJECT_REGISTRY - and register_result + hasattr(ns_parser, "register_key") + and ns_parser.register_key ): - session.console.print( - "Added `OBBject` to cached results." + if ( + ns_parser.register_key + not in session.obbject_registry.obbject_keys + ): + obbject.extra["register_key"] = str( + ns_parser.register_key + ) + else: + session.console.print( + f"[yellow]Key `{ns_parser.register_key}` already exists in the registry." + "The `OBBject` was kept without the key.[/yellow]" + ) + + if store_obbject: + # store the obbject in the registry + register_result = session.obbject_registry.register( + obbject ) - # making the dataframe available - # either for printing or exporting + # we need to force to re-link so that the new obbject + # is immediately available for data processing commands + self._link_obbject_to_data_processing_commands() + # also update the completer + self.update_completer(self.choices_default) + + if ( + session.settings.SHOW_MSG_OBBJECT_REGISTRY + and register_result + ): + session.console.print( + "Added `OBBject` to cached results." + ) + + # making the dataframe available either for printing or exporting df = obbject.to_dataframe() export = hasattr(ns_parser, "export") and ns_parser.export |