diff options
-rw-r--r-- | cli/openbb_cli/argparse_translator/argparse_argument.py | 63 | ||||
-rw-r--r-- | cli/openbb_cli/argparse_translator/argparse_class_processor.py | 14 | ||||
-rw-r--r-- | cli/openbb_cli/argparse_translator/argparse_translator.py | 322 | ||||
-rw-r--r-- | cli/openbb_cli/argparse_translator/reference_processor.py | 137 | ||||
-rw-r--r-- | cli/openbb_cli/argparse_translator/utils.py | 71 | ||||
-rw-r--r-- | cli/openbb_cli/controllers/base_controller.py | 11 | ||||
-rw-r--r-- | cli/tests/test_argparse_translator.py | 14 |
7 files changed, 366 insertions, 266 deletions
diff --git a/cli/openbb_cli/argparse_translator/argparse_argument.py b/cli/openbb_cli/argparse_translator/argparse_argument.py new file mode 100644 index 00000000000..2a40d9f1e4c --- /dev/null +++ b/cli/openbb_cli/argparse_translator/argparse_argument.py @@ -0,0 +1,63 @@ +"""Pydantic models for argparse arguments and argument groups.""" + +from typing import ( + Any, + List, + Literal, + Optional, + Tuple, +) + +from pydantic import BaseModel, model_validator + +SEP = "__" + + +class ArgparseArgumentModel(BaseModel): + """Pydantic model for an argparse argument.""" + + name: str + type: Any + dest: str + default: Any + required: bool + action: Literal["store_true", "store"] + help: Optional[str] + nargs: Optional[Literal["+"]] + choices: Optional[Tuple] + + @model_validator(mode="after") # type: ignore + @classmethod + def validate_action(cls, values: "ArgparseArgumentModel"): + """Validate the action based on the type.""" + if values.type is bool and values.action != "store_true": + raise ValueError('If type is bool, action must be "store_true"') + return values + + @model_validator(mode="after") # type: ignore + @classmethod + def remove_props_on_store_true(cls, values: "ArgparseArgumentModel"): + """Remove type, nargs, and choices if action is store_true.""" + if values.action == "store_true": + values.type = None + values.nargs = None + values.choices = None + return values + + # override + def model_dump(self, **kwargs): + """Override the model_dump method to remove empty choices.""" + res = super().model_dump(**kwargs) + + # Check if choices is present and if it's an empty tuple remove it + if "choices" in res and not res["choices"]: + del res["choices"] + + return res + + +class ArgparseArgumentGroupModel(BaseModel): + """Pydantic model for a custom argument group.""" + + name: str + arguments: List[ArgparseArgumentModel] diff --git a/cli/openbb_cli/argparse_translator/argparse_class_processor.py b/cli/openbb_cli/argparse_translator/argparse_class_processor.py index d7ddcdafcd2..af7b9e56ab4 100644 --- a/cli/openbb_cli/argparse_translator/argparse_class_processor.py +++ b/cli/openbb_cli/argparse_translator/argparse_class_processor.py @@ -1,19 +1,19 @@ +"""Module for the ArgparseClassProcessor class.""" + import inspect from typing import Any, Dict, Optional, Type # TODO: this needs to be done differently from openbb_core.app.static.container import Container -from openbb_cli.argparse_translator.argparse_translator import ( - ArgparseTranslator, - ReferenceToCustomArgumentsProcessor, +from openbb_cli.argparse_translator.argparse_translator import ArgparseTranslator +from openbb_cli.argparse_translator.reference_processor import ( + ReferenceToArgumentsProcessor, ) class ArgparseClassProcessor: - """ - Process a target class to create ArgparseTranslators for its methods. - """ + """Process a target class to create ArgparseTranslators for its methods.""" # reference variable used to create custom groups for the ArgpaseTranslators _reference: Dict[str, Any] = {} @@ -77,7 +77,7 @@ class ArgparseClassProcessor: reference = {route: cls._reference[route]} if route in cls._reference else {} if not reference: return {} - rp = ReferenceToCustomArgumentsProcessor(reference) + rp = ReferenceToArgumentsProcessor(reference) return rp.custom_groups.get(route, {}) # type: ignore @classmethod diff --git a/cli/openbb_cli/argparse_translator/argparse_translator.py b/cli/openbb_cli/argparse_translator/argparse_translator.py index 605fbe6b41f..2715f2b3901 100644 --- a/cli/openbb_cli/argparse_translator/argparse_translator.py +++ b/cli/openbb_cli/argparse_translator/argparse_translator.py @@ -1,8 +1,9 @@ +"""Module for translating a function into an argparse program.""" + import argparse import inspect import re from copy import deepcopy -from enum import Enum from typing import ( Any, Callable, @@ -19,181 +20,37 @@ from typing import ( ) from openbb_core.app.model.field import OpenBBField -from pydantic import BaseModel, model_validator +from pydantic import BaseModel from typing_extensions import Annotated +from openbb_cli.argparse_translator.argparse_argument import ( + ArgparseArgumentGroupModel, + ArgparseArgumentModel, +) +from openbb_cli.argparse_translator.utils import ( + get_argument_choices, + get_argument_optional_choices, + in_group, + remove_argument, + set_optional_choices, +) + # pylint: disable=protected-access SEP = "__" -class ArgparseActionType(Enum): - store = "store" - store_true = "store_true" - - -class CustomArgument(BaseModel): - name: str - type: Optional[Any] - dest: str - default: Any - required: bool - action: Literal["store_true", "store"] - help: str - nargs: Optional[Literal["+"]] - choices: Optional[Tuple] - - @model_validator(mode="after") # type: ignore - @classmethod - def validate_action(cls, values: "CustomArgument"): - if values.type is bool and values.action != "store_true": - raise ValueError('If type is bool, action must be "store_true"') - return values - - @model_validator(mode="after") # type: ignore - @classmethod - def remove_props_on_store_true(cls, values: "CustomArgument"): - if values.action == "store_true": - values.type = None - values.nargs = None - values.choices = None - return values - - # override - def model_dump(self, **kwargs): - res = super().model_dump(**kwargs) - - # Check if choices is present and if it's an empty tuple remove it - if "choices" in res and not res["choices"]: - del res["choices"] - - return res - - -class CustomArgumentGroup(BaseModel): - name: str - arguments: List[CustomArgument] - - -class ReferenceToCustomArgumentsProcessor: - def __init__(self, reference: Dict[str, Dict]): - """Initializes the ReferenceToCustomArgumentsProcessor.""" - self.reference = reference - self.custom_groups: Dict[str, List[CustomArgumentGroup]] = {} - - self.build_custom_groups() - - @staticmethod - def _make_type_parsable(type_: str) -> type: - """Make the type parsable by removing the annotations.""" - if "Union" in type_ and "str" in type_: - return str - if "Union" in type_ and "int" in type_: - return int - if type_ in ["date", "datetime.time", "time"]: - return str - - if any(x in type_ for x in ["gt=", "ge=", "lt=", "le="]): - if "Annotated" in type_: - type_ = type_.replace("Annotated[", "").replace("]", "") - type_ = type_.split(",")[0] - - return eval(type_) # noqa: S307, E501 pylint: disable=eval-used - - def _parse_type(self, type_: str) -> type: - """Parse the type from the string representation.""" - type_ = self._make_type_parsable(type_) # type: ignore - - if get_origin(type_) is Literal: - type_ = type(get_args(type_)[0]) # type: ignore - - return type_ # type: ignore - - def _get_nargs(self, type_: type) -> Optional[Union[int, str]]: - """Get the nargs for the given type.""" - if get_origin(type_) is list: - return "+" - return None - - def _get_choices(self, type_: str, custom_choices: Any) -> Tuple: - """Get the choices for the given type.""" - type_ = self._make_type_parsable(type_) # type: ignore - type_origin = get_origin(type_) - - choices = () - - if type_origin is Literal: - choices = get_args(type_) - - if type_origin is list: - type_ = get_args(type_)[0] - - if get_origin(type_) is Literal: - choices = get_args(type_) - - if type_origin is Union and type(None) in get_args(type_): - # remove NoneType from the args - args = [arg for arg in get_args(type_) if arg != type(None)] - # if there is only one arg left, use it - if len(args) > 1: - raise ValueError("Union with NoneType should have only one type left") - type_ = args[0] - - if get_origin(type_) is Literal: - choices = get_args(type_) - - if custom_choices: - return tuple(custom_choices) - - return choices - - def build_custom_groups(self): - """Build the custom groups from the reference.""" - for route, v in self.reference.items(): - for provider, args in v["parameters"].items(): - if provider == "standard": - continue - - custom_arguments = [] - for arg in args: - if arg.get("standard"): - continue - - type_ = self._parse_type(arg["type"]) - - custom_arguments.append( - CustomArgument( - name=arg["name"], - type=type_, - dest=arg["name"], - default=arg["default"], - required=not (arg["optional"]), - action="store" if type_ != bool else "store_true", - help=arg["description"], - nargs=self._get_nargs(type_), # type: ignore - choices=self._get_choices( - arg["type"], custom_choices=arg["choices"] - ), - ) - ) - - group = CustomArgumentGroup(name=provider, arguments=custom_arguments) - - if route not in self.custom_groups: - self.custom_groups[route] = [] - - self.custom_groups[route].append(group) - - class ArgparseTranslator: + """Class to translate a function into an argparse program.""" + def __init__( self, func: Callable, - custom_argument_groups: Optional[List[CustomArgumentGroup]] = None, + custom_argument_groups: Optional[List[ArgparseArgumentGroupModel]] = None, add_help: Optional[bool] = True, ): """ - Initializes the ArgparseTranslator. + Initialize the ArgparseTranslator. Args: func (Callable): The function to translate into an argparse program. @@ -225,45 +82,6 @@ class ArgparseTranslator: def _handle_argument_in_groups(self, argument, group): """Handle the argument and add it to the parser.""" - def _in_group(arg, group_title): - for action_group in self._parser._action_groups: - if action_group.title == group_title: - for action in action_group._group_actions: - opts = action.option_strings - if (opts and opts[0] == arg) or action.dest == arg: - return True - return False - - def _remove_argument(arg) -> List[Optional[str]]: - groups_w_arg = [] - - # remove the argument from the parser - for action in self._parser._actions: - opts = action.option_strings - if (opts and opts[0] == arg) or action.dest == arg: - self._parser._remove_action(action) - break - - # remove from all groups - for action_group in self._parser._action_groups: - for action in action_group._group_actions: - opts = action.option_strings - if (opts and opts[0] == arg) or action.dest == arg: - action_group._group_actions.remove(action) - groups_w_arg.append(action_group.title) - - # remove from _action_groups dict - self._parser._option_string_actions.pop(f"--{arg}", None) - - return groups_w_arg - - def _get_arg_choices(arg) -> Tuple: - for action in self._parser._actions: - opts = action.option_strings - if (opts and opts[0] == arg) or action.dest == arg: - return tuple(action.choices or ()) - return () - def _update_providers( input_string: str, new_provider: List[Optional[str]] ) -> str: @@ -285,23 +103,27 @@ class ArgparseTranslator: kwargs = argument.model_dump(exclude={"name"}, exclude_none=True) model_choices = kwargs.get("choices", ()) or () # extend choices - choices = tuple(set(_get_arg_choices(argument.name) + model_choices)) + existing_choices = get_argument_choices(self._parser, argument.name) + choices = tuple(set(existing_choices + model_choices)) + optional_choices = bool(existing_choices and not model_choices) # check if the argument is in the required arguments - if _in_group(argument.name, group_title="required arguments"): + if in_group(self._parser, argument.name, group_title="required arguments"): for action in self._required._group_actions: if action.dest == argument.name and choices: # update choices action.choices = choices + set_optional_choices(action, optional_choices) return # check if the argument is in the optional arguments - if _in_group(argument.name, group_title="optional arguments"): + if in_group(self._parser, argument.name, group_title="optional arguments"): for action in self._parser._actions: if action.dest == argument.name: # update choices if choices: action.choices = choices + set_optional_choices(action, optional_choices) if argument.name not in self.signature.parameters: # update help action.help = _update_providers( @@ -309,9 +131,16 @@ class ArgparseTranslator: ) return + # we need to check if the optional choices were set in other group + # before we remove the argument from the group, otherwise we will lose info + if not optional_choices: + optional_choices = get_argument_optional_choices( + self._parser, argument.name + ) + # if the argument is in use, remove it from all groups # and return the groups that had the argument - groups_w_arg = _remove_argument(argument.name) + groups_w_arg = remove_argument(self._parser, argument.name) groups_w_arg.append(group.title) # add current group # add it to the optional arguments group instead @@ -319,16 +148,17 @@ class ArgparseTranslator: kwargs["choices"] = choices # update choices # add provider info to the help kwargs["help"] = _update_providers(argument.help or "", groups_w_arg) - self._parser.add_argument(f"--{argument.name}", **kwargs) + action = self._parser.add_argument(f"--{argument.name}", **kwargs) + set_optional_choices(action, optional_choices) @property def parser(self) -> argparse.ArgumentParser: + """Get the argparse parser.""" return deepcopy(self._parser) @staticmethod def _build_description(func_doc: str) -> str: - """Builds the description of the argparse program from the function docstring.""" - + """Build the description of the argparse program from the function docstring.""" patterns = ["openbb\n ======", "Parameters\n ----------"] if func_doc: @@ -341,31 +171,34 @@ class ArgparseTranslator: @staticmethod def _param_is_default(param: inspect.Parameter) -> bool: - """Returns True if the parameter has a default value.""" + """Return True if the parameter has a default value.""" return param.default != inspect.Parameter.empty def _get_action_type(self, param: inspect.Parameter) -> str: - """Returns the argparse action type for the given parameter.""" + """Return the argparse action type for the given parameter.""" param_type = self.type_hints[param.name] + type_origin = get_origin(param_type) - if param_type == bool: - return ArgparseActionType.store_true.value - return ArgparseActionType.store.value + if param_type == bool or ( + type_origin is Union and bool in get_args(param_type) + ): + return "store_true" + return "store" def _get_type_and_choices( self, param: inspect.Parameter ) -> Tuple[Type[Any], Tuple[Any, ...]]: - """Returns the type and choices for the given parameter.""" + """Return the type and choices for the given parameter.""" param_type = self.type_hints[param.name] type_origin = get_origin(param_type) - choices = () + choices: tuple[Any, ...] = () if type_origin is Literal: choices = get_args(param_type) param_type = type(choices[0]) # type: ignore - if type_origin is list: # TODO: dict should also go here + if type_origin is list: param_type = get_args(param_type)[0] if get_origin(param_type) is Literal: @@ -413,32 +246,26 @@ class ArgparseTranslator: @classmethod def _get_argument_custom_help(cls, param: inspect.Parameter) -> Optional[str]: - """Returns the help annotation for the given parameter.""" + """Return the help annotation for the given parameter.""" base_annotation = param.annotation _, custom_annotations = cls._split_annotation(base_annotation, OpenBBField) help_annotation = ( custom_annotations[0].description if custom_annotations else None ) - if not help_annotation: - # try to get it from the docstring - pass return help_annotation @classmethod def _get_argument_custom_choices(cls, param: inspect.Parameter) -> Optional[str]: - """Returns the help annotation for the given parameter.""" + """Return the help annotation for the given parameter.""" base_annotation = param.annotation _, custom_annotations = cls._split_annotation(base_annotation, OpenBBField) choices_annotation = ( custom_annotations[0].choices if custom_annotations else None ) - if not choices_annotation: - # try to get it from the docstring - pass return choices_annotation def _get_nargs(self, param: inspect.Parameter) -> Optional[str]: - """Returns the nargs annotation for the given parameter.""" + """Return the nargs annotation for the given parameter.""" param_type = self.type_hints[param.name] origin = get_origin(param_type) @@ -453,11 +280,8 @@ class ArgparseTranslator: return None def _generate_argparse_arguments(self, parameters) -> None: - """Generates the argparse arguments from the function parameters.""" + """Generate the argparse arguments from the function parameters.""" for param in parameters.values(): - # TODO : how to handle kwargs? - # it's possible to add unknown arguments when parsing as follows: - # args, unknown_args = parser.parse_known_args() if param.name == "kwargs": continue @@ -505,32 +329,27 @@ class ArgparseTranslator: required = not self._param_is_default(param) - kwargs = { - "type": param_type, - "dest": param.name, - "default": param.default, - "required": required, - "action": self._get_action_type(param), - "help": self._get_argument_custom_help(param), - "nargs": self._get_nargs(param), - } - - if choices: - kwargs["choices"] = choices - - if param_type == bool: - # store_true action does not accept the below kwargs - kwargs.pop("type") - kwargs.pop("nargs") + argument = ArgparseArgumentModel( + name=param.name, + type=param_type, + dest=param.name, + default=param.default, + required=required, + action=self._get_action_type(param), + help=self._get_argument_custom_help(param), + nargs=self._get_nargs(param), + choices=choices, + ) + kwargs = argument.model_dump(exclude={"name"}, exclude_none=True) if required: self._required.add_argument( - f"--{param.name}", + f"--{argument.name}", **kwargs, ) else: self._parser.add_argument( - f"--{param.name}", + f"--{argument.name}", **kwargs, ) @@ -556,7 +375,6 @@ class ArgparseTranslator: # for each argument in the signature that is a custom type, we need to # update the kwargs with the custom type kwargs for param in self.signature.parameters.values(): - # TODO : how to handle kwargs? if param.name == "kwargs": continue param_type, _ = self._get_type_and_choices(param) @@ -571,7 +389,7 @@ class ArgparseTranslator: parsed_args: Optional[argparse.Namespace] = None, ) -> Any: """ - Executes the original function with the parsed arguments. + Execute the original function with the parsed arguments. Args: parsed_args (Optional[argparse.Namespace], optional): The parsed arguments. Defaults to None. @@ -602,7 +420,7 @@ class ArgparseTranslator: def parse_args_and_execute(self) -> Any: """ - Parses the arguments and executes the original function. + Parse the arguments and executes the original function. Returns: Any: The return value of the original function. @@ -612,7 +430,7 @@ class ArgparseTranslator: def translate(self) -> Callable: """ - Wraps the original function with an argparse program. + Wrap the original function with an argparse program. Returns: Callable: The original function wrapped with an argparse program. diff --git a/cli/openbb_cli/argparse_translator/reference_processor.py b/cli/openbb_cli/argparse_translator/reference_processor.py new file mode 100644 index 00000000000..53cba266cf5 --- /dev/null +++ b/cli/openbb_cli/argparse_translator/reference_processor.py @@ -0,0 +1,137 @@ +"""Module for the ReferenceToArgumentsProcessor class.""" + +from typing import ( + Any, + Dict, + List, + Literal, + Optional, + Tuple, + Union, + get_args, + get_origin, +) + +from openbb_cli.argparse_translator.argparse_argument import ( + ArgparseArgumentGroupModel, + ArgparseArgumentModel, +) + + +class ReferenceToArgumentsProcessor: + """Class to process the reference and build custom argument groups.""" + + def __init__(self, reference: Dict[str, Dict]): + """Initialize the ReferenceToArgumentsProcessor.""" + self._reference = reference + self._custom_groups: Dict[str, List[ArgparseArgumentGroupModel]] = {} + + self._build_custom_groups() + + @property + def custom_groups(self) -> Dict[str, List[ArgparseArgumentGroupModel]]: + """Get the custom groups.""" + return self._custom_groups + + @staticmethod + def _make_type_parsable(type_: str) -> type: + """Make the type parsable by removing the annotations.""" + if "Union" in type_ and "str" in type_: + return str + if "Union" in type_ and "int" in type_: + return int + if type_ in ["date", "datetime.time", "time"]: + return str + + if any(x in type_ for x in ["gt=", "ge=", "lt=", "le="]): + if "Annotated" in type_: + type_ = type_.replace("Annotated[", "").replace("]", "") + type_ = type_.split(",")[0] + + return eval(type_) # noqa: S307, E501 pylint: disable=eval-used + + def _parse_type(self, type_: str) -> type: + """Parse the type from the string representation.""" + type_ = self._make_type_parsable(type_) # type: ignore + + if get_origin(type_) is Literal: + type_ = type(get_args(type_)[0]) # type: ignore + + return type_ # type: ignore + + def _get_nargs(self, type_: type) -> Optional[Union[int, str]]: + """Get the nargs for the given type.""" + if get_origin(type_) is list: + return "+" + return None + + def _get_choices(self, type_: str, custom_choices: Any) -> Tuple: + """Get the choices for the given type.""" + type_ = self._make_type_parsable(type_) # type: ignore + type_origin = get_origin(type_) + + choices: tuple[Any, ...] = () + + if type_origin is Literal: + choices = get_args(type_) + + if type_origin is list: + type_ = get_args(type_)[0] + + if get_origin(type_) is Literal: + choices = get_args(type_) + + if type_origin is Union and type(None) in get_args(type_): + # remove NoneType from the args + args = [arg for arg in get_args(type_) if arg != type(None)] + # if there is only one arg left, use it + if len(args) > 1: + raise ValueError("Union with NoneType should have only one type left") + type_ = args[0] + + if get_origin(type_) is Literal: + choices = get_args(type_) + + if custom_choices: + return tuple(custom_choices) + + return choices + + def _build_custom_groups(self): + """Build the custom groups from the reference.""" + for route, v in self._reference.items(): + for provider, args in v["parameters"].items(): + if provider == "standard": + continue + + custom_arguments = [] + for arg in args: + if arg.get("standard"): + continue + + type_ = self._parse_type(arg["type"]) + + custom_arguments.append( + ArgparseArgumentModel( + name=arg["name"], + type=type_, + dest=arg["name"], + default=arg["default"], + required=not (arg["optional"]), + action="store" if type_ != bool else "store_true", + help=arg["description"], + nargs=self._get_nargs(type_), # type: ignore + choices=self._get_choices( + arg["type"], custom_choices=arg["choices"] + ), + ) + ) + + group = ArgparseArgumentGroupModel( + name=provider, arguments=custom_arguments + ) |