summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorHenrique Joaquim <henriquecjoaquim@gmail.com>2024-06-04 08:57:36 +0100
committerGitHub <noreply@github.com>2024-06-04 07:57:36 +0000
commitf860c6d7979ba9944d75ea32bab259566a6cc8c5 (patch)
treecaa0178197da9490743a4d824a61433ed9d966c5
parent8a63c2df3e8831d52af4f1fca05c65e859e6f31b (diff)
[Feature] Optional choices (#6463)
* using the custom argument to leverage validation * restructure * better handling of variables and props * optional choices * utilities module * proper handling of optional choices attr * setattr --------- Co-authored-by: Danglewood <85772166+deeleeramone@users.noreply.github.com> Co-authored-by: Igor Radovanovic <74266147+IgorWounds@users.noreply.github.com>
-rw-r--r--cli/openbb_cli/argparse_translator/argparse_argument.py63
-rw-r--r--cli/openbb_cli/argparse_translator/argparse_class_processor.py14
-rw-r--r--cli/openbb_cli/argparse_translator/argparse_translator.py322
-rw-r--r--cli/openbb_cli/argparse_translator/reference_processor.py137
-rw-r--r--cli/openbb_cli/argparse_translator/utils.py71
-rw-r--r--cli/openbb_cli/controllers/base_controller.py11
-rw-r--r--cli/tests/test_argparse_translator.py14
7 files changed, 366 insertions, 266 deletions
diff --git a/cli/openbb_cli/argparse_translator/argparse_argument.py b/cli/openbb_cli/argparse_translator/argparse_argument.py
new file mode 100644
index 00000000000..2a40d9f1e4c
--- /dev/null
+++ b/cli/openbb_cli/argparse_translator/argparse_argument.py
@@ -0,0 +1,63 @@
+"""Pydantic models for argparse arguments and argument groups."""
+
+from typing import (
+ Any,
+ List,
+ Literal,
+ Optional,
+ Tuple,
+)
+
+from pydantic import BaseModel, model_validator
+
+SEP = "__"
+
+
+class ArgparseArgumentModel(BaseModel):
+ """Pydantic model for an argparse argument."""
+
+ name: str
+ type: Any
+ dest: str
+ default: Any
+ required: bool
+ action: Literal["store_true", "store"]
+ help: Optional[str]
+ nargs: Optional[Literal["+"]]
+ choices: Optional[Tuple]
+
+ @model_validator(mode="after") # type: ignore
+ @classmethod
+ def validate_action(cls, values: "ArgparseArgumentModel"):
+ """Validate the action based on the type."""
+ if values.type is bool and values.action != "store_true":
+ raise ValueError('If type is bool, action must be "store_true"')
+ return values
+
+ @model_validator(mode="after") # type: ignore
+ @classmethod
+ def remove_props_on_store_true(cls, values: "ArgparseArgumentModel"):
+ """Remove type, nargs, and choices if action is store_true."""
+ if values.action == "store_true":
+ values.type = None
+ values.nargs = None
+ values.choices = None
+ return values
+
+ # override
+ def model_dump(self, **kwargs):
+ """Override the model_dump method to remove empty choices."""
+ res = super().model_dump(**kwargs)
+
+ # Check if choices is present and if it's an empty tuple remove it
+ if "choices" in res and not res["choices"]:
+ del res["choices"]
+
+ return res
+
+
+class ArgparseArgumentGroupModel(BaseModel):
+ """Pydantic model for a custom argument group."""
+
+ name: str
+ arguments: List[ArgparseArgumentModel]
diff --git a/cli/openbb_cli/argparse_translator/argparse_class_processor.py b/cli/openbb_cli/argparse_translator/argparse_class_processor.py
index d7ddcdafcd2..af7b9e56ab4 100644
--- a/cli/openbb_cli/argparse_translator/argparse_class_processor.py
+++ b/cli/openbb_cli/argparse_translator/argparse_class_processor.py
@@ -1,19 +1,19 @@
+"""Module for the ArgparseClassProcessor class."""
+
import inspect
from typing import Any, Dict, Optional, Type
# TODO: this needs to be done differently
from openbb_core.app.static.container import Container
-from openbb_cli.argparse_translator.argparse_translator import (
- ArgparseTranslator,
- ReferenceToCustomArgumentsProcessor,
+from openbb_cli.argparse_translator.argparse_translator import ArgparseTranslator
+from openbb_cli.argparse_translator.reference_processor import (
+ ReferenceToArgumentsProcessor,
)
class ArgparseClassProcessor:
- """
- Process a target class to create ArgparseTranslators for its methods.
- """
+ """Process a target class to create ArgparseTranslators for its methods."""
# reference variable used to create custom groups for the ArgpaseTranslators
_reference: Dict[str, Any] = {}
@@ -77,7 +77,7 @@ class ArgparseClassProcessor:
reference = {route: cls._reference[route]} if route in cls._reference else {}
if not reference:
return {}
- rp = ReferenceToCustomArgumentsProcessor(reference)
+ rp = ReferenceToArgumentsProcessor(reference)
return rp.custom_groups.get(route, {}) # type: ignore
@classmethod
diff --git a/cli/openbb_cli/argparse_translator/argparse_translator.py b/cli/openbb_cli/argparse_translator/argparse_translator.py
index 605fbe6b41f..2715f2b3901 100644
--- a/cli/openbb_cli/argparse_translator/argparse_translator.py
+++ b/cli/openbb_cli/argparse_translator/argparse_translator.py
@@ -1,8 +1,9 @@
+"""Module for translating a function into an argparse program."""
+
import argparse
import inspect
import re
from copy import deepcopy
-from enum import Enum
from typing import (
Any,
Callable,
@@ -19,181 +20,37 @@ from typing import (
)
from openbb_core.app.model.field import OpenBBField
-from pydantic import BaseModel, model_validator
+from pydantic import BaseModel
from typing_extensions import Annotated
+from openbb_cli.argparse_translator.argparse_argument import (
+ ArgparseArgumentGroupModel,
+ ArgparseArgumentModel,
+)
+from openbb_cli.argparse_translator.utils import (
+ get_argument_choices,
+ get_argument_optional_choices,
+ in_group,
+ remove_argument,
+ set_optional_choices,
+)
+
# pylint: disable=protected-access
SEP = "__"
-class ArgparseActionType(Enum):
- store = "store"
- store_true = "store_true"
-
-
-class CustomArgument(BaseModel):
- name: str
- type: Optional[Any]
- dest: str
- default: Any
- required: bool
- action: Literal["store_true", "store"]
- help: str
- nargs: Optional[Literal["+"]]
- choices: Optional[Tuple]
-
- @model_validator(mode="after") # type: ignore
- @classmethod
- def validate_action(cls, values: "CustomArgument"):
- if values.type is bool and values.action != "store_true":
- raise ValueError('If type is bool, action must be "store_true"')
- return values
-
- @model_validator(mode="after") # type: ignore
- @classmethod
- def remove_props_on_store_true(cls, values: "CustomArgument"):
- if values.action == "store_true":
- values.type = None
- values.nargs = None
- values.choices = None
- return values
-
- # override
- def model_dump(self, **kwargs):
- res = super().model_dump(**kwargs)
-
- # Check if choices is present and if it's an empty tuple remove it
- if "choices" in res and not res["choices"]:
- del res["choices"]
-
- return res
-
-
-class CustomArgumentGroup(BaseModel):
- name: str
- arguments: List[CustomArgument]
-
-
-class ReferenceToCustomArgumentsProcessor:
- def __init__(self, reference: Dict[str, Dict]):
- """Initializes the ReferenceToCustomArgumentsProcessor."""
- self.reference = reference
- self.custom_groups: Dict[str, List[CustomArgumentGroup]] = {}
-
- self.build_custom_groups()
-
- @staticmethod
- def _make_type_parsable(type_: str) -> type:
- """Make the type parsable by removing the annotations."""
- if "Union" in type_ and "str" in type_:
- return str
- if "Union" in type_ and "int" in type_:
- return int
- if type_ in ["date", "datetime.time", "time"]:
- return str
-
- if any(x in type_ for x in ["gt=", "ge=", "lt=", "le="]):
- if "Annotated" in type_:
- type_ = type_.replace("Annotated[", "").replace("]", "")
- type_ = type_.split(",")[0]
-
- return eval(type_) # noqa: S307, E501 pylint: disable=eval-used
-
- def _parse_type(self, type_: str) -> type:
- """Parse the type from the string representation."""
- type_ = self._make_type_parsable(type_) # type: ignore
-
- if get_origin(type_) is Literal:
- type_ = type(get_args(type_)[0]) # type: ignore
-
- return type_ # type: ignore
-
- def _get_nargs(self, type_: type) -> Optional[Union[int, str]]:
- """Get the nargs for the given type."""
- if get_origin(type_) is list:
- return "+"
- return None
-
- def _get_choices(self, type_: str, custom_choices: Any) -> Tuple:
- """Get the choices for the given type."""
- type_ = self._make_type_parsable(type_) # type: ignore
- type_origin = get_origin(type_)
-
- choices = ()
-
- if type_origin is Literal:
- choices = get_args(type_)
-
- if type_origin is list:
- type_ = get_args(type_)[0]
-
- if get_origin(type_) is Literal:
- choices = get_args(type_)
-
- if type_origin is Union and type(None) in get_args(type_):
- # remove NoneType from the args
- args = [arg for arg in get_args(type_) if arg != type(None)]
- # if there is only one arg left, use it
- if len(args) > 1:
- raise ValueError("Union with NoneType should have only one type left")
- type_ = args[0]
-
- if get_origin(type_) is Literal:
- choices = get_args(type_)
-
- if custom_choices:
- return tuple(custom_choices)
-
- return choices
-
- def build_custom_groups(self):
- """Build the custom groups from the reference."""
- for route, v in self.reference.items():
- for provider, args in v["parameters"].items():
- if provider == "standard":
- continue
-
- custom_arguments = []
- for arg in args:
- if arg.get("standard"):
- continue
-
- type_ = self._parse_type(arg["type"])
-
- custom_arguments.append(
- CustomArgument(
- name=arg["name"],
- type=type_,
- dest=arg["name"],
- default=arg["default"],
- required=not (arg["optional"]),
- action="store" if type_ != bool else "store_true",
- help=arg["description"],
- nargs=self._get_nargs(type_), # type: ignore
- choices=self._get_choices(
- arg["type"], custom_choices=arg["choices"]
- ),
- )
- )
-
- group = CustomArgumentGroup(name=provider, arguments=custom_arguments)
-
- if route not in self.custom_groups:
- self.custom_groups[route] = []
-
- self.custom_groups[route].append(group)
-
-
class ArgparseTranslator:
+ """Class to translate a function into an argparse program."""
+
def __init__(
self,
func: Callable,
- custom_argument_groups: Optional[List[CustomArgumentGroup]] = None,
+ custom_argument_groups: Optional[List[ArgparseArgumentGroupModel]] = None,
add_help: Optional[bool] = True,
):
"""
- Initializes the ArgparseTranslator.
+ Initialize the ArgparseTranslator.
Args:
func (Callable): The function to translate into an argparse program.
@@ -225,45 +82,6 @@ class ArgparseTranslator:
def _handle_argument_in_groups(self, argument, group):
"""Handle the argument and add it to the parser."""
- def _in_group(arg, group_title):
- for action_group in self._parser._action_groups:
- if action_group.title == group_title:
- for action in action_group._group_actions:
- opts = action.option_strings
- if (opts and opts[0] == arg) or action.dest == arg:
- return True
- return False
-
- def _remove_argument(arg) -> List[Optional[str]]:
- groups_w_arg = []
-
- # remove the argument from the parser
- for action in self._parser._actions:
- opts = action.option_strings
- if (opts and opts[0] == arg) or action.dest == arg:
- self._parser._remove_action(action)
- break
-
- # remove from all groups
- for action_group in self._parser._action_groups:
- for action in action_group._group_actions:
- opts = action.option_strings
- if (opts and opts[0] == arg) or action.dest == arg:
- action_group._group_actions.remove(action)
- groups_w_arg.append(action_group.title)
-
- # remove from _action_groups dict
- self._parser._option_string_actions.pop(f"--{arg}", None)
-
- return groups_w_arg
-
- def _get_arg_choices(arg) -> Tuple:
- for action in self._parser._actions:
- opts = action.option_strings
- if (opts and opts[0] == arg) or action.dest == arg:
- return tuple(action.choices or ())
- return ()
-
def _update_providers(
input_string: str, new_provider: List[Optional[str]]
) -> str:
@@ -285,23 +103,27 @@ class ArgparseTranslator:
kwargs = argument.model_dump(exclude={"name"}, exclude_none=True)
model_choices = kwargs.get("choices", ()) or ()
# extend choices
- choices = tuple(set(_get_arg_choices(argument.name) + model_choices))
+ existing_choices = get_argument_choices(self._parser, argument.name)
+ choices = tuple(set(existing_choices + model_choices))
+ optional_choices = bool(existing_choices and not model_choices)
# check if the argument is in the required arguments
- if _in_group(argument.name, group_title="required arguments"):
+ if in_group(self._parser, argument.name, group_title="required arguments"):
for action in self._required._group_actions:
if action.dest == argument.name and choices:
# update choices
action.choices = choices
+ set_optional_choices(action, optional_choices)
return
# check if the argument is in the optional arguments
- if _in_group(argument.name, group_title="optional arguments"):
+ if in_group(self._parser, argument.name, group_title="optional arguments"):
for action in self._parser._actions:
if action.dest == argument.name:
# update choices
if choices:
action.choices = choices
+ set_optional_choices(action, optional_choices)
if argument.name not in self.signature.parameters:
# update help
action.help = _update_providers(
@@ -309,9 +131,16 @@ class ArgparseTranslator:
)
return
+ # we need to check if the optional choices were set in other group
+ # before we remove the argument from the group, otherwise we will lose info
+ if not optional_choices:
+ optional_choices = get_argument_optional_choices(
+ self._parser, argument.name
+ )
+
# if the argument is in use, remove it from all groups
# and return the groups that had the argument
- groups_w_arg = _remove_argument(argument.name)
+ groups_w_arg = remove_argument(self._parser, argument.name)
groups_w_arg.append(group.title) # add current group
# add it to the optional arguments group instead
@@ -319,16 +148,17 @@ class ArgparseTranslator:
kwargs["choices"] = choices # update choices
# add provider info to the help
kwargs["help"] = _update_providers(argument.help or "", groups_w_arg)
- self._parser.add_argument(f"--{argument.name}", **kwargs)
+ action = self._parser.add_argument(f"--{argument.name}", **kwargs)
+ set_optional_choices(action, optional_choices)
@property
def parser(self) -> argparse.ArgumentParser:
+ """Get the argparse parser."""
return deepcopy(self._parser)
@staticmethod
def _build_description(func_doc: str) -> str:
- """Builds the description of the argparse program from the function docstring."""
-
+ """Build the description of the argparse program from the function docstring."""
patterns = ["openbb\n ======", "Parameters\n ----------"]
if func_doc:
@@ -341,31 +171,34 @@ class ArgparseTranslator:
@staticmethod
def _param_is_default(param: inspect.Parameter) -> bool:
- """Returns True if the parameter has a default value."""
+ """Return True if the parameter has a default value."""
return param.default != inspect.Parameter.empty
def _get_action_type(self, param: inspect.Parameter) -> str:
- """Returns the argparse action type for the given parameter."""
+ """Return the argparse action type for the given parameter."""
param_type = self.type_hints[param.name]
+ type_origin = get_origin(param_type)
- if param_type == bool:
- return ArgparseActionType.store_true.value
- return ArgparseActionType.store.value
+ if param_type == bool or (
+ type_origin is Union and bool in get_args(param_type)
+ ):
+ return "store_true"
+ return "store"
def _get_type_and_choices(
self, param: inspect.Parameter
) -> Tuple[Type[Any], Tuple[Any, ...]]:
- """Returns the type and choices for the given parameter."""
+ """Return the type and choices for the given parameter."""
param_type = self.type_hints[param.name]
type_origin = get_origin(param_type)
- choices = ()
+ choices: tuple[Any, ...] = ()
if type_origin is Literal:
choices = get_args(param_type)
param_type = type(choices[0]) # type: ignore
- if type_origin is list: # TODO: dict should also go here
+ if type_origin is list:
param_type = get_args(param_type)[0]
if get_origin(param_type) is Literal:
@@ -413,32 +246,26 @@ class ArgparseTranslator:
@classmethod
def _get_argument_custom_help(cls, param: inspect.Parameter) -> Optional[str]:
- """Returns the help annotation for the given parameter."""
+ """Return the help annotation for the given parameter."""
base_annotation = param.annotation
_, custom_annotations = cls._split_annotation(base_annotation, OpenBBField)
help_annotation = (
custom_annotations[0].description if custom_annotations else None
)
- if not help_annotation:
- # try to get it from the docstring
- pass
return help_annotation
@classmethod
def _get_argument_custom_choices(cls, param: inspect.Parameter) -> Optional[str]:
- """Returns the help annotation for the given parameter."""
+ """Return the help annotation for the given parameter."""
base_annotation = param.annotation
_, custom_annotations = cls._split_annotation(base_annotation, OpenBBField)
choices_annotation = (
custom_annotations[0].choices if custom_annotations else None
)
- if not choices_annotation:
- # try to get it from the docstring
- pass
return choices_annotation
def _get_nargs(self, param: inspect.Parameter) -> Optional[str]:
- """Returns the nargs annotation for the given parameter."""
+ """Return the nargs annotation for the given parameter."""
param_type = self.type_hints[param.name]
origin = get_origin(param_type)
@@ -453,11 +280,8 @@ class ArgparseTranslator:
return None
def _generate_argparse_arguments(self, parameters) -> None:
- """Generates the argparse arguments from the function parameters."""
+ """Generate the argparse arguments from the function parameters."""
for param in parameters.values():
- # TODO : how to handle kwargs?
- # it's possible to add unknown arguments when parsing as follows:
- # args, unknown_args = parser.parse_known_args()
if param.name == "kwargs":
continue
@@ -505,32 +329,27 @@ class ArgparseTranslator:
required = not self._param_is_default(param)
- kwargs = {
- "type": param_type,
- "dest": param.name,
- "default": param.default,
- "required": required,
- "action": self._get_action_type(param),
- "help": self._get_argument_custom_help(param),
- "nargs": self._get_nargs(param),
- }
-
- if choices:
- kwargs["choices"] = choices
-
- if param_type == bool:
- # store_true action does not accept the below kwargs
- kwargs.pop("type")
- kwargs.pop("nargs")
+ argument = ArgparseArgumentModel(
+ name=param.name,
+ type=param_type,
+ dest=param.name,
+ default=param.default,
+ required=required,
+ action=self._get_action_type(param),
+ help=self._get_argument_custom_help(param),
+ nargs=self._get_nargs(param),
+ choices=choices,
+ )
+ kwargs = argument.model_dump(exclude={"name"}, exclude_none=True)
if required:
self._required.add_argument(
- f"--{param.name}",
+ f"--{argument.name}",
**kwargs,
)
else:
self._parser.add_argument(
- f"--{param.name}",
+ f"--{argument.name}",
**kwargs,
)
@@ -556,7 +375,6 @@ class ArgparseTranslator:
# for each argument in the signature that is a custom type, we need to
# update the kwargs with the custom type kwargs
for param in self.signature.parameters.values():
- # TODO : how to handle kwargs?
if param.name == "kwargs":
continue
param_type, _ = self._get_type_and_choices(param)
@@ -571,7 +389,7 @@ class ArgparseTranslator:
parsed_args: Optional[argparse.Namespace] = None,
) -> Any:
"""
- Executes the original function with the parsed arguments.
+ Execute the original function with the parsed arguments.
Args:
parsed_args (Optional[argparse.Namespace], optional): The parsed arguments. Defaults to None.
@@ -602,7 +420,7 @@ class ArgparseTranslator:
def parse_args_and_execute(self) -> Any:
"""
- Parses the arguments and executes the original function.
+ Parse the arguments and executes the original function.
Returns:
Any: The return value of the original function.
@@ -612,7 +430,7 @@ class ArgparseTranslator:
def translate(self) -> Callable:
"""
- Wraps the original function with an argparse program.
+ Wrap the original function with an argparse program.
Returns:
Callable: The original function wrapped with an argparse program.
diff --git a/cli/openbb_cli/argparse_translator/reference_processor.py b/cli/openbb_cli/argparse_translator/reference_processor.py
new file mode 100644
index 00000000000..53cba266cf5
--- /dev/null
+++ b/cli/openbb_cli/argparse_translator/reference_processor.py
@@ -0,0 +1,137 @@
+"""Module for the ReferenceToArgumentsProcessor class."""
+
+from typing import (
+ Any,
+ Dict,
+ List,
+ Literal,
+ Optional,
+ Tuple,
+ Union,
+ get_args,
+ get_origin,
+)
+
+from openbb_cli.argparse_translator.argparse_argument import (
+ ArgparseArgumentGroupModel,
+ ArgparseArgumentModel,
+)
+
+
+class ReferenceToArgumentsProcessor:
+ """Class to process the reference and build custom argument groups."""
+
+ def __init__(self, reference: Dict[str, Dict]):
+ """Initialize the ReferenceToArgumentsProcessor."""
+ self._reference = reference
+ self._custom_groups: Dict[str, List[ArgparseArgumentGroupModel]] = {}
+
+ self._build_custom_groups()
+
+ @property
+ def custom_groups(self) -> Dict[str, List[ArgparseArgumentGroupModel]]:
+ """Get the custom groups."""
+ return self._custom_groups
+
+ @staticmethod
+ def _make_type_parsable(type_: str) -> type:
+ """Make the type parsable by removing the annotations."""
+ if "Union" in type_ and "str" in type_:
+ return str
+ if "Union" in type_ and "int" in type_:
+ return int
+ if type_ in ["date", "datetime.time", "time"]:
+ return str
+
+ if any(x in type_ for x in ["gt=", "ge=", "lt=", "le="]):
+ if "Annotated" in type_:
+ type_ = type_.replace("Annotated[", "").replace("]", "")
+ type_ = type_.split(",")[0]
+
+ return eval(type_) # noqa: S307, E501 pylint: disable=eval-used
+
+ def _parse_type(self, type_: str) -> type:
+ """Parse the type from the string representation."""
+ type_ = self._make_type_parsable(type_) # type: ignore
+
+ if get_origin(type_) is Literal:
+ type_ = type(get_args(type_)[0]) # type: ignore
+
+ return type_ # type: ignore
+
+ def _get_nargs(self, type_: type) -> Optional[Union[int, str]]:
+ """Get the nargs for the given type."""
+ if get_origin(type_) is list:
+ return "+"
+ return None
+
+ def _get_choices(self, type_: str, custom_choices: Any) -> Tuple:
+ """Get the choices for the given type."""
+ type_ = self._make_type_parsable(type_) # type: ignore
+ type_origin = get_origin(type_)
+
+ choices: tuple[Any, ...] = ()
+
+ if type_origin is Literal:
+ choices = get_args(type_)
+
+ if type_origin is list:
+ type_ = get_args(type_)[0]
+
+ if get_origin(type_) is Literal:
+ choices = get_args(type_)
+
+ if type_origin is Union and type(None) in get_args(type_):
+ # remove NoneType from the args
+ args = [arg for arg in get_args(type_) if arg != type(None)]
+ # if there is only one arg left, use it
+ if len(args) > 1:
+ raise ValueError("Union with NoneType should have only one type left")
+ type_ = args[0]
+
+ if get_origin(type_) is Literal:
+ choices = get_args(type_)
+
+ if custom_choices:
+ return tuple(custom_choices)
+
+ return choices
+
+ def _build_custom_groups(self):
+ """Build the custom groups from the reference."""
+ for route, v in self._reference.items():
+ for provider, args in v["parameters"].items():
+ if provider == "standard":
+ continue
+
+ custom_arguments = []
+ for arg in args:
+ if arg.get("standard"):
+ continue
+
+ type_ = self._parse_type(arg["type"])
+
+ custom_arguments.append(
+ ArgparseArgumentModel(
+ name=arg["name"],
+ type=type_,
+ dest=arg["name"],
+ default=arg["default"],
+ required=not (arg["optional"]),
+ action="store" if type_ != boo