summaryrefslogtreecommitdiffstats
path: root/cli/openbb_cli/argparse_translator/argparse_translator.py
diff options
context:
space:
mode:
Diffstat (limited to 'cli/openbb_cli/argparse_translator/argparse_translator.py')
-rw-r--r--cli/openbb_cli/argparse_translator/argparse_translator.py530
1 files changed, 530 insertions, 0 deletions
diff --git a/cli/openbb_cli/argparse_translator/argparse_translator.py b/cli/openbb_cli/argparse_translator/argparse_translator.py
new file mode 100644
index 00000000000..45fa4ce5d6b
--- /dev/null
+++ b/cli/openbb_cli/argparse_translator/argparse_translator.py
@@ -0,0 +1,530 @@
+import argparse
+import inspect
+from copy import deepcopy
+from enum import Enum
+from typing import (
+ Any,
+ Callable,
+ Dict,
+ List,
+ Literal,
+ Optional,
+ Tuple,
+ Type,
+ Union,
+ get_args,
+ get_origin,
+ get_type_hints,
+)
+
+from openbb_core.app.model.field import OpenBBField
+from pydantic import BaseModel, model_validator
+from typing_extensions import Annotated
+
+SEP = "__"
+
+
+class ArgparseActionType(Enum):
+ store = "store"
+ store_true = "store_true"
+
+
+class CustomArgument(BaseModel):
+
+ name: str
+ type: Optional[Any]
+ dest: str
+ default: Any
+ required: bool
+ action: Literal["store_true", "store"]
+ help: str
+ nargs: Optional[Literal["+"]]
+ choices: Optional[Any]
+
+ @model_validator(mode="after") # type: ignore
+ @classmethod
+ def validate_action(cls, values: "CustomArgument"):
+ if values.type is bool and values.action != "store_true":
+ raise ValueError('If type is bool, action must be "store_true"')
+ return values
+
+ @model_validator(mode="after") # type: ignore
+ @classmethod
+ def remove_props_on_store_true(cls, values: "CustomArgument"):
+ if values.action == "store_true":
+ values.type = None
+ values.nargs = None
+ values.choices = None
+ return values
+
+ # override
+ def model_dump(self, **kwargs):
+
+ res = super().model_dump(**kwargs)
+
+ # Check if choices is present and if it's an empty tuple remove it
+ if "choices" in res and not res["choices"]:
+ del res["choices"]
+
+ return res
+
+
+class CustomArgumentGroup(BaseModel):
+ name: str
+ arguments: List[CustomArgument]
+
+
+class ReferenceToCustomArgumentsProcessor:
+ def __init__(self, reference: Dict[str, Dict]):
+ """Initializes the ReferenceToCustomArgumentsProcessor."""
+ self.reference = reference
+ self.custom_groups: Dict[str, List[CustomArgumentGroup]] = {}
+
+ self.build_custom_groups()
+
+ @staticmethod
+ def _make_type_parsable(type_: str) -> type:
+ """Make the type parsable by removing the annotations."""
+ if "Union" in type_ and "str" in type_:
+ return str
+ if "Union" in type_ and "int" in type_:
+ return int
+ if type_ in ["date", "datetime.time", "time"]:
+ return str
+
+ if any(x in type_ for x in ["gt=", "ge=", "lt=", "le="]):
+ if "Annotated" in type_:
+ type_ = type_.replace("Annotated[", "").replace("]", "")
+ type_ = type_.split(",")[0]
+
+ return eval(type_) # noqa: S307, E501 pylint: disable=eval-used
+
+ def _parse_type(self, type_: str) -> type:
+ """Parse the type from the string representation."""
+ type_ = self._make_type_parsable(type_) # type: ignore
+
+ if get_origin(type_) is Literal:
+ type_ = type(get_args(type_)[0]) # type: ignore
+
+ return type_ # type: ignore
+
+ def _get_nargs(self, type_: type) -> Optional[Union[int, str]]:
+ """Get the nargs for the given type."""
+ if get_origin(type_) is list:
+ return "+"
+ return None
+
+ def _get_choices(self, type_: str) -> Tuple:
+ """Get the choices for the given type."""
+ type_ = self._make_type_parsable(type_) # type: ignore
+ type_origin = get_origin(type_)
+
+ choices = ()
+
+ if type_origin is Literal:
+ choices = get_args(type_)
+ # param_type = type(choices[0])
+
+ if type_origin is list:
+ type_ = get_args(type_)[0]
+
+ if get_origin(type_) is Literal:
+ choices = get_args(type_)
+ # param_type = type(choices[0])
+
+ if type_origin is Union and type(None) in get_args(type_):
+ # remove NoneType from the args
+ args = [arg for arg in get_args(type_) if arg != type(None)]
+ # if there is only one arg left, use it
+ if len(args) > 1:
+ raise ValueError("Union with NoneType should have only one type left")
+ type_ = args[0]
+
+ if get_origin(type_) is Literal:
+ choices = get_args(type_)
+ # param_type = type(choices[0])
+
+ return choices
+
+ def build_custom_groups(self):
+ """Build the custom groups from the reference."""
+ for route, v in self.reference.items():
+
+ for provider, args in v["parameters"].items():
+ if provider == "standard":
+ continue
+
+ custom_arguments = []
+ for arg in args:
+ if arg.get("standard"):
+ continue
+
+ type_ = self._parse_type(arg["type"])
+
+ custom_arguments.append(
+ CustomArgument(
+ name=arg["name"],
+ type=type_,
+ dest=arg["name"],
+ default=arg["default"],
+ required=not (arg["optional"]),
+ action="store" if type_ != bool else "store_true",
+ help=arg["description"],
+ nargs=self._get_nargs(type_), # type: ignore
+ choices=self._get_choices(arg["type"]),
+ )
+ )
+
+ group = CustomArgumentGroup(name=provider, arguments=custom_arguments)
+
+ if route not in self.custom_groups:
+ self.custom_groups[route] = []
+
+ self.custom_groups[route].append(group)
+
+
+class ArgparseTranslator:
+ def __init__(
+ self,
+ func: Callable,
+ custom_argument_groups: Optional[List[CustomArgumentGroup]] = None,
+ add_help: Optional[bool] = True,
+ ):
+ """
+ Initializes the ArgparseTranslator.
+
+ Args:
+ func (Callable): The function to translate into an argparse program.
+ add_help (Optional[bool], optional): Whether to add the help argument. Defaults to False.
+ """
+ self.func = func
+ self.signature = inspect.signature(func)
+ self.type_hints = get_type_hints(func)
+ self.provider_parameters = []
+
+ self._parser = argparse.ArgumentParser(
+ prog=func.__name__,
+ description=self._build_description(func.__doc__), # type: ignore
+ formatter_class=argparse.RawTextHelpFormatter,
+ add_help=add_help if add_help else False,
+ )
+ self._required = self._parser.add_argument_group("required arguments")
+
+ if any(param in self.type_hints for param in self.signature.parameters):
+ self._generate_argparse_arguments(self.signature.parameters)
+
+ if custom_argument_groups:
+ for group in custom_argument_groups:
+ argparse_group = self._parser.add_argument_group(group.name)
+ for argument in group.arguments:
+ kwargs = argument.model_dump(exclude={"name"}, exclude_none=True)
+
+ # If the argument is already in use, we can't repeat it
+ if f"--{argument.name}" not in self._parser_arguments():
+ argparse_group.add_argument(f"--{argument.name}", **kwargs)
+ self.provider_parameters.append(argument.name)
+
+ def _parser_arguments(self) -> List[str]:
+ """Get all the arguments from all groups currently defined on the parser."""
+ arguments_in_use: List[str] = []
+
+ # pylint: disable=protected-access
+ for action_group in self._parser._action_groups:
+ for action in action_group._group_actions:
+ arguments_in_use.extend(action.option_strings)
+
+ return arguments_in_use
+
+ @property
+ def parser(self) -> argparse.ArgumentParser:
+ return deepcopy(self._parser)
+
+ @staticmethod
+ def _build_description(func_doc: str) -> str:
+ """Builds the description of the argparse program from the function docstring."""
+
+ patterns = ["openbb\n ======", "Parameters\n ----------"]
+
+ if func_doc:
+ for pattern in patterns:
+ if pattern in func_doc:
+ func_doc = func_doc[: func_doc.index(pattern)].strip()
+ break
+
+ return func_doc
+
+ @staticmethod
+ def _param_is_default(param: inspect.Parameter) -> bool:
+ """Returns True if the parameter has a default value."""
+ return param.default != inspect.Parameter.empty
+
+ def _get_action_type(self, param: inspect.Parameter) -> str:
+ """Returns the argparse action type for the given parameter."""
+ param_type = self.type_hints[param.name]
+
+ if param_type == bool:
+ return ArgparseActionType.store_true.value
+ return ArgparseActionType.store.value
+
+ def _get_type_and_choices(
+ self, param: inspect.Parameter
+ ) -> Tuple[Type[Any], Tuple[Any, ...]]:
+ """Returns the type and choices for the given parameter."""
+ param_type = self.type_hints[param.name]
+ type_origin = get_origin(param_type)
+
+ choices = ()
+
+ if type_origin is Literal:
+ choices = get_args(param_type)
+ param_type = type(choices[0]) # type: ignore
+
+ if type_origin is list: # TODO: dict should also go here
+ param_type = get_args(param_type)[0]
+
+ if get_origin(param_type) is Literal:
+ choices = get_args(param_type)
+ param_type = type(choices[0]) # type: ignore
+
+ if type_origin is Union:
+ union_args = get_args(param_type)
+ if str in union_args:
+ param_type = str
+
+ # check if it's an Optional, which would be a Union with NoneType
+ if type(None) in get_args(param_type):
+ # remove NoneType from the args
+ args = [arg for arg in get_args(param_type) if arg != type(None)]
+ # if there is only one arg left, use it
+ if len(args) > 1:
+ raise ValueError(
+ "Union with NoneType should have only one type left"
+ )
+ param_type = args[0]
+
+ if get_origin(param_type) is Literal:
+ choices = get_args(param_type)
+ param_type = type(choices[0]) # type: ignore
+
+ # if there are custom choices, override
+ choices = self._get_argument_custom_choices(param) or choices # type: ignore
+
+ return param_type, choices
+
+ @staticmethod
+ def _split_annotation(
+ base_annotation: Type[Any], custom_annotation_type: Type
+ ) -> Tuple[Type[Any], List[Any]]:
+ """Find the base annotation and the custom annotations, namely the OpenBBField."""
+ if get_origin(base_annotation) is not Annotated:
+ return base_annotation, []
+ base_annotation, *maybe_custom_annotations = get_args(base_annotation)
+ return base_annotation, [
+ annotation
+ for annotation in maybe_custom_annotations
+ if isinstance(annotation, custom_annotation_type)
+ ]
+
+ @classmethod
+ def _get_argument_custom_help(cls, param: inspect.Parameter) -> Optional[str]:
+ """Returns the help annotation for the given parameter."""
+ base_annotation = param.annotation
+ _, custom_annotations = cls._split_annotation(base_annotation, OpenBBField)
+ help_annotation = (
+ custom_annotations[0].description if custom_annotations else None
+ )
+ if not help_annotation:
+ # try to get it from the docstring
+ pass
+ return help_annotation
+
+ @classmethod
+ def _get_argument_custom_choices(cls, param: inspect.Parameter) -> Optional[str]:
+ """Returns the help annotation for the given parameter."""
+ base_annotation = param.annotation
+ _, custom_annotations = cls._split_annotation(base_annotation, OpenBBField)
+ choices_annotation = (
+ custom_annotations[0].choices if custom_annotations else None
+ )
+ if not choices_annotation:
+ # try to get it from the docstring
+ pass
+ return choices_annotation
+
+ def _get_nargs(self, param: inspect.Parameter) -> Optional[str]:
+ """Returns the nargs annotation for the given parameter."""
+ param_type = self.type_hints[param.name]
+ origin = get_origin(param_type)
+
+ if origin is list:
+ return "+"
+
+ if origin is Union and any(
+ get_origin(arg) is list for arg in get_args(param_type)
+ ):
+ return "+"
+
+ return None
+
+ def _generate_argparse_arguments(self, parameters) -> None:
+ """Generates the argparse arguments from the function parameters."""
+ for param in parameters.values():
+ # TODO : how to handle kwargs?
+ # it's possible to add unknown arguments when parsing as follows:
+ # args, unknown_args = parser.parse_known_args()
+ if param.name == "kwargs":
+ continue
+
+ param_type, choices = self._get_type_and_choices(param)
+
+ # if the param is a custom type, we need to flatten it
+ if inspect.isclass(param_type) and issubclass(param_type, BaseModel):
+ # update type hints with the custom type fields
+ type_hints = get_type_hints(param_type)
+ # prefix the type hints keys with the param name
+ type_hints = {
+ f"{param.name}{SEP}{key}": value
+ for key, value in type_hints.items()
+ }
+ self.type_hints.update(type_hints)
+ # create a signature from the custom type
+ sig = inspect.signature(param_type)
+
+ # add help to the annotation
+ annotated_parameters: List[inspect.Parameter] = []
+ for child_param in sig.parameters.values():
+ new_child_param = child_param.replace(
+ name=f"{param.name}{SEP}{child_param.name}",
+ annotation=Annotated[
+ child_param.annotation,
+ OpenBBField(
+ description=param_type.model_json_schema()[
+ "properties"
+ ][child_param.name].get("description", None)
+ ),
+ ],
+ kind=inspect.Parameter.KEYWORD_ONLY,
+ )
+ annotated_parameters.append(new_child_param)
+
+ # replacing with the annotated parameters
+ new_signature = inspect.Signature(
+ parameters=annotated_parameters,
+ return_annotation=sig.return_annotation,
+ )
+ self._generate_argparse_arguments(new_signature.parameters)
+
+ # the custom type itself should not be added as an argument
+ continue
+
+ required = not self._param_is_default(param)
+
+ kwargs = {
+ "type": param_type,
+ "dest": param.name,
+ "default": param.default,
+ "required": required,
+ "action": self._get_action_type(param),
+ "help": self._get_argument_custom_help(param),
+ "nargs": self._get_nargs(param),
+ }
+
+ if choices:
+ kwargs["choices"] = choices
+
+ if param_type == bool:
+ # store_true action does not accept the below kwargs
+ kwargs.pop("type")
+ kwargs.pop("nargs")
+
+ if required:
+ self._required.add_argument(
+ f"--{param.name}",
+ **kwargs,
+ )
+ else:
+ self._parser.add_argument(
+ f"--{param.name}",
+ **kwargs,
+ )
+
+ @staticmethod
+ def _unflatten_args(args: dict) -> Dict[str, Any]:
+ """Unflatten the args that were flattened by the custom types."""
+ result: Dict[str, Any] = {}
+ for key, value in args.items():
+ if SEP in key:
+ parts = key.split(SEP)
+ nested_dict = result
+ for part in parts[:-1]:
+ if part not in nested_dict:
+ nested_dict[part] = {}
+ nested_dict = nested_dict[part]
+ nested_dict[parts[-1]] = value
+ else:
+ result[key] = value
+ return result
+
+ def _update_with_custom_types(self, kwargs: Dict[str, Any]) -> Dict[str, Any]:
+ """Update the kwargs with the custom types."""
+ # for each argument in the signature that is a custom type, we need to
+ # update the kwargs with the custom type kwargs
+ for param in self.signature.parameters.values():
+ # TODO : how to handle kwargs?
+ if param.name == "kwargs":
+ continue
+ param_type, _ = self._get_type_and_choices(param)
+ if inspect.isclass(param_type) and issubclass(param_type, BaseModel):
+ custom_type_kwargs = kwargs[param.name]
+ kwargs[param.name] = param_type(**custom_type_kwargs)
+
+ return kwargs
+
+ def execute_func(
+ self,
+ parsed_args: Optional[argparse.Namespace] = None,
+ ) -> Any:
+ """
+ Executes the original function with the parsed arguments.
+
+ Args:
+ parsed_args (Optional[argparse.Namespace], optional): The parsed arguments. Defaults to None.
+
+ Returns:
+ Any: The return value of the original function.
+
+ """
+ kwargs = self._unflatten_args(vars(parsed_args))
+ kwargs = self._update_with_custom_types(kwargs)
+
+ # remove kwargs that doesn't match the signature or provider parameters
+ kwargs = {
+ key: value
+ for key, value in kwargs.items()
+ if key in self.signature.parameters or key in self.provider_parameters
+ }
+
+ return self.func(**kwargs)
+
+ def parse_args_and_execute(self) -> Any:
+ """
+ Parses the arguments and executes the original function.
+
+ Returns:
+ Any: The return value of the original function.
+ """
+ parsed_args = self._parser.parse_args()
+ return self.execute_func(parsed_args)
+
+ def translate(self) -> Callable:
+ """
+ Wraps the original function with an argparse program.
+
+ Returns:
+ Callable: The original function wrapped with an argparse program.
+ """
+
+ def wrapper_func():
+ return self.parse_args_and_execute()
+
+ return wrapper_func