diff options
Diffstat (limited to 'cli/openbb_cli/argparse_translator/argparse_translator.py')
-rw-r--r-- | cli/openbb_cli/argparse_translator/argparse_translator.py | 530 |
1 files changed, 530 insertions, 0 deletions
diff --git a/cli/openbb_cli/argparse_translator/argparse_translator.py b/cli/openbb_cli/argparse_translator/argparse_translator.py new file mode 100644 index 00000000000..45fa4ce5d6b --- /dev/null +++ b/cli/openbb_cli/argparse_translator/argparse_translator.py @@ -0,0 +1,530 @@ +import argparse +import inspect +from copy import deepcopy +from enum import Enum +from typing import ( + Any, + Callable, + Dict, + List, + Literal, + Optional, + Tuple, + Type, + Union, + get_args, + get_origin, + get_type_hints, +) + +from openbb_core.app.model.field import OpenBBField +from pydantic import BaseModel, model_validator +from typing_extensions import Annotated + +SEP = "__" + + +class ArgparseActionType(Enum): + store = "store" + store_true = "store_true" + + +class CustomArgument(BaseModel): + + name: str + type: Optional[Any] + dest: str + default: Any + required: bool + action: Literal["store_true", "store"] + help: str + nargs: Optional[Literal["+"]] + choices: Optional[Any] + + @model_validator(mode="after") # type: ignore + @classmethod + def validate_action(cls, values: "CustomArgument"): + if values.type is bool and values.action != "store_true": + raise ValueError('If type is bool, action must be "store_true"') + return values + + @model_validator(mode="after") # type: ignore + @classmethod + def remove_props_on_store_true(cls, values: "CustomArgument"): + if values.action == "store_true": + values.type = None + values.nargs = None + values.choices = None + return values + + # override + def model_dump(self, **kwargs): + + res = super().model_dump(**kwargs) + + # Check if choices is present and if it's an empty tuple remove it + if "choices" in res and not res["choices"]: + del res["choices"] + + return res + + +class CustomArgumentGroup(BaseModel): + name: str + arguments: List[CustomArgument] + + +class ReferenceToCustomArgumentsProcessor: + def __init__(self, reference: Dict[str, Dict]): + """Initializes the ReferenceToCustomArgumentsProcessor.""" + self.reference = reference + self.custom_groups: Dict[str, List[CustomArgumentGroup]] = {} + + self.build_custom_groups() + + @staticmethod + def _make_type_parsable(type_: str) -> type: + """Make the type parsable by removing the annotations.""" + if "Union" in type_ and "str" in type_: + return str + if "Union" in type_ and "int" in type_: + return int + if type_ in ["date", "datetime.time", "time"]: + return str + + if any(x in type_ for x in ["gt=", "ge=", "lt=", "le="]): + if "Annotated" in type_: + type_ = type_.replace("Annotated[", "").replace("]", "") + type_ = type_.split(",")[0] + + return eval(type_) # noqa: S307, E501 pylint: disable=eval-used + + def _parse_type(self, type_: str) -> type: + """Parse the type from the string representation.""" + type_ = self._make_type_parsable(type_) # type: ignore + + if get_origin(type_) is Literal: + type_ = type(get_args(type_)[0]) # type: ignore + + return type_ # type: ignore + + def _get_nargs(self, type_: type) -> Optional[Union[int, str]]: + """Get the nargs for the given type.""" + if get_origin(type_) is list: + return "+" + return None + + def _get_choices(self, type_: str) -> Tuple: + """Get the choices for the given type.""" + type_ = self._make_type_parsable(type_) # type: ignore + type_origin = get_origin(type_) + + choices = () + + if type_origin is Literal: + choices = get_args(type_) + # param_type = type(choices[0]) + + if type_origin is list: + type_ = get_args(type_)[0] + + if get_origin(type_) is Literal: + choices = get_args(type_) + # param_type = type(choices[0]) + + if type_origin is Union and type(None) in get_args(type_): + # remove NoneType from the args + args = [arg for arg in get_args(type_) if arg != type(None)] + # if there is only one arg left, use it + if len(args) > 1: + raise ValueError("Union with NoneType should have only one type left") + type_ = args[0] + + if get_origin(type_) is Literal: + choices = get_args(type_) + # param_type = type(choices[0]) + + return choices + + def build_custom_groups(self): + """Build the custom groups from the reference.""" + for route, v in self.reference.items(): + + for provider, args in v["parameters"].items(): + if provider == "standard": + continue + + custom_arguments = [] + for arg in args: + if arg.get("standard"): + continue + + type_ = self._parse_type(arg["type"]) + + custom_arguments.append( + CustomArgument( + name=arg["name"], + type=type_, + dest=arg["name"], + default=arg["default"], + required=not (arg["optional"]), + action="store" if type_ != bool else "store_true", + help=arg["description"], + nargs=self._get_nargs(type_), # type: ignore + choices=self._get_choices(arg["type"]), + ) + ) + + group = CustomArgumentGroup(name=provider, arguments=custom_arguments) + + if route not in self.custom_groups: + self.custom_groups[route] = [] + + self.custom_groups[route].append(group) + + +class ArgparseTranslator: + def __init__( + self, + func: Callable, + custom_argument_groups: Optional[List[CustomArgumentGroup]] = None, + add_help: Optional[bool] = True, + ): + """ + Initializes the ArgparseTranslator. + + Args: + func (Callable): The function to translate into an argparse program. + add_help (Optional[bool], optional): Whether to add the help argument. Defaults to False. + """ + self.func = func + self.signature = inspect.signature(func) + self.type_hints = get_type_hints(func) + self.provider_parameters = [] + + self._parser = argparse.ArgumentParser( + prog=func.__name__, + description=self._build_description(func.__doc__), # type: ignore + formatter_class=argparse.RawTextHelpFormatter, + add_help=add_help if add_help else False, + ) + self._required = self._parser.add_argument_group("required arguments") + + if any(param in self.type_hints for param in self.signature.parameters): + self._generate_argparse_arguments(self.signature.parameters) + + if custom_argument_groups: + for group in custom_argument_groups: + argparse_group = self._parser.add_argument_group(group.name) + for argument in group.arguments: + kwargs = argument.model_dump(exclude={"name"}, exclude_none=True) + + # If the argument is already in use, we can't repeat it + if f"--{argument.name}" not in self._parser_arguments(): + argparse_group.add_argument(f"--{argument.name}", **kwargs) + self.provider_parameters.append(argument.name) + + def _parser_arguments(self) -> List[str]: + """Get all the arguments from all groups currently defined on the parser.""" + arguments_in_use: List[str] = [] + + # pylint: disable=protected-access + for action_group in self._parser._action_groups: + for action in action_group._group_actions: + arguments_in_use.extend(action.option_strings) + + return arguments_in_use + + @property + def parser(self) -> argparse.ArgumentParser: + return deepcopy(self._parser) + + @staticmethod + def _build_description(func_doc: str) -> str: + """Builds the description of the argparse program from the function docstring.""" + + patterns = ["openbb\n ======", "Parameters\n ----------"] + + if func_doc: + for pattern in patterns: + if pattern in func_doc: + func_doc = func_doc[: func_doc.index(pattern)].strip() + break + + return func_doc + + @staticmethod + def _param_is_default(param: inspect.Parameter) -> bool: + """Returns True if the parameter has a default value.""" + return param.default != inspect.Parameter.empty + + def _get_action_type(self, param: inspect.Parameter) -> str: + """Returns the argparse action type for the given parameter.""" + param_type = self.type_hints[param.name] + + if param_type == bool: + return ArgparseActionType.store_true.value + return ArgparseActionType.store.value + + def _get_type_and_choices( + self, param: inspect.Parameter + ) -> Tuple[Type[Any], Tuple[Any, ...]]: + """Returns the type and choices for the given parameter.""" + param_type = self.type_hints[param.name] + type_origin = get_origin(param_type) + + choices = () + + if type_origin is Literal: + choices = get_args(param_type) + param_type = type(choices[0]) # type: ignore + + if type_origin is list: # TODO: dict should also go here + param_type = get_args(param_type)[0] + + if get_origin(param_type) is Literal: + choices = get_args(param_type) + param_type = type(choices[0]) # type: ignore + + if type_origin is Union: + union_args = get_args(param_type) + if str in union_args: + param_type = str + + # check if it's an Optional, which would be a Union with NoneType + if type(None) in get_args(param_type): + # remove NoneType from the args + args = [arg for arg in get_args(param_type) if arg != type(None)] + # if there is only one arg left, use it + if len(args) > 1: + raise ValueError( + "Union with NoneType should have only one type left" + ) + param_type = args[0] + + if get_origin(param_type) is Literal: + choices = get_args(param_type) + param_type = type(choices[0]) # type: ignore + + # if there are custom choices, override + choices = self._get_argument_custom_choices(param) or choices # type: ignore + + return param_type, choices + + @staticmethod + def _split_annotation( + base_annotation: Type[Any], custom_annotation_type: Type + ) -> Tuple[Type[Any], List[Any]]: + """Find the base annotation and the custom annotations, namely the OpenBBField.""" + if get_origin(base_annotation) is not Annotated: + return base_annotation, [] + base_annotation, *maybe_custom_annotations = get_args(base_annotation) + return base_annotation, [ + annotation + for annotation in maybe_custom_annotations + if isinstance(annotation, custom_annotation_type) + ] + + @classmethod + def _get_argument_custom_help(cls, param: inspect.Parameter) -> Optional[str]: + """Returns the help annotation for the given parameter.""" + base_annotation = param.annotation + _, custom_annotations = cls._split_annotation(base_annotation, OpenBBField) + help_annotation = ( + custom_annotations[0].description if custom_annotations else None + ) + if not help_annotation: + # try to get it from the docstring + pass + return help_annotation + + @classmethod + def _get_argument_custom_choices(cls, param: inspect.Parameter) -> Optional[str]: + """Returns the help annotation for the given parameter.""" + base_annotation = param.annotation + _, custom_annotations = cls._split_annotation(base_annotation, OpenBBField) + choices_annotation = ( + custom_annotations[0].choices if custom_annotations else None + ) + if not choices_annotation: + # try to get it from the docstring + pass + return choices_annotation + + def _get_nargs(self, param: inspect.Parameter) -> Optional[str]: + """Returns the nargs annotation for the given parameter.""" + param_type = self.type_hints[param.name] + origin = get_origin(param_type) + + if origin is list: + return "+" + + if origin is Union and any( + get_origin(arg) is list for arg in get_args(param_type) + ): + return "+" + + return None + + def _generate_argparse_arguments(self, parameters) -> None: + """Generates the argparse arguments from the function parameters.""" + for param in parameters.values(): + # TODO : how to handle kwargs? + # it's possible to add unknown arguments when parsing as follows: + # args, unknown_args = parser.parse_known_args() + if param.name == "kwargs": + continue + + param_type, choices = self._get_type_and_choices(param) + + # if the param is a custom type, we need to flatten it + if inspect.isclass(param_type) and issubclass(param_type, BaseModel): + # update type hints with the custom type fields + type_hints = get_type_hints(param_type) + # prefix the type hints keys with the param name + type_hints = { + f"{param.name}{SEP}{key}": value + for key, value in type_hints.items() + } + self.type_hints.update(type_hints) + # create a signature from the custom type + sig = inspect.signature(param_type) + + # add help to the annotation + annotated_parameters: List[inspect.Parameter] = [] + for child_param in sig.parameters.values(): + new_child_param = child_param.replace( + name=f"{param.name}{SEP}{child_param.name}", + annotation=Annotated[ + child_param.annotation, + OpenBBField( + description=param_type.model_json_schema()[ + "properties" + ][child_param.name].get("description", None) + ), + ], + kind=inspect.Parameter.KEYWORD_ONLY, + ) + annotated_parameters.append(new_child_param) + + # replacing with the annotated parameters + new_signature = inspect.Signature( + parameters=annotated_parameters, + return_annotation=sig.return_annotation, + ) + self._generate_argparse_arguments(new_signature.parameters) + + # the custom type itself should not be added as an argument + continue + + required = not self._param_is_default(param) + + kwargs = { + "type": param_type, + "dest": param.name, + "default": param.default, + "required": required, + "action": self._get_action_type(param), + "help": self._get_argument_custom_help(param), + "nargs": self._get_nargs(param), + } + + if choices: + kwargs["choices"] = choices + + if param_type == bool: + # store_true action does not accept the below kwargs + kwargs.pop("type") + kwargs.pop("nargs") + + if required: + self._required.add_argument( + f"--{param.name}", + **kwargs, + ) + else: + self._parser.add_argument( + f"--{param.name}", + **kwargs, + ) + + @staticmethod + def _unflatten_args(args: dict) -> Dict[str, Any]: + """Unflatten the args that were flattened by the custom types.""" + result: Dict[str, Any] = {} + for key, value in args.items(): + if SEP in key: + parts = key.split(SEP) + nested_dict = result + for part in parts[:-1]: + if part not in nested_dict: + nested_dict[part] = {} + nested_dict = nested_dict[part] + nested_dict[parts[-1]] = value + else: + result[key] = value + return result + + def _update_with_custom_types(self, kwargs: Dict[str, Any]) -> Dict[str, Any]: + """Update the kwargs with the custom types.""" + # for each argument in the signature that is a custom type, we need to + # update the kwargs with the custom type kwargs + for param in self.signature.parameters.values(): + # TODO : how to handle kwargs? + if param.name == "kwargs": + continue + param_type, _ = self._get_type_and_choices(param) + if inspect.isclass(param_type) and issubclass(param_type, BaseModel): + custom_type_kwargs = kwargs[param.name] + kwargs[param.name] = param_type(**custom_type_kwargs) + + return kwargs + + def execute_func( + self, + parsed_args: Optional[argparse.Namespace] = None, + ) -> Any: + """ + Executes the original function with the parsed arguments. + + Args: + parsed_args (Optional[argparse.Namespace], optional): The parsed arguments. Defaults to None. + + Returns: + Any: The return value of the original function. + + """ + kwargs = self._unflatten_args(vars(parsed_args)) + kwargs = self._update_with_custom_types(kwargs) + + # remove kwargs that doesn't match the signature or provider parameters + kwargs = { + key: value + for key, value in kwargs.items() + if key in self.signature.parameters or key in self.provider_parameters + } + + return self.func(**kwargs) + + def parse_args_and_execute(self) -> Any: + """ + Parses the arguments and executes the original function. + + Returns: + Any: The return value of the original function. + """ + parsed_args = self._parser.parse_args() + return self.execute_func(parsed_args) + + def translate(self) -> Callable: + """ + Wraps the original function with an argparse program. + + Returns: + Callable: The original function wrapped with an argparse program. + """ + + def wrapper_func(): + return self.parse_args_and_execute() + + return wrapper_func |