summaryrefslogtreecommitdiffstats
path: root/cli/openbb_cli/argparse_translator/argparse_translator.py
diff options
context:
space:
mode:
Diffstat (limited to 'cli/openbb_cli/argparse_translator/argparse_translator.py')
-rw-r--r--cli/openbb_cli/argparse_translator/argparse_translator.py16
1 files changed, 13 insertions, 3 deletions
diff --git a/cli/openbb_cli/argparse_translator/argparse_translator.py b/cli/openbb_cli/argparse_translator/argparse_translator.py
index 001e7afce04..724c2bea306 100644
--- a/cli/openbb_cli/argparse_translator/argparse_translator.py
+++ b/cli/openbb_cli/argparse_translator/argparse_translator.py
@@ -203,7 +203,7 @@ class ArgparseTranslator:
self.func = func
self.signature = inspect.signature(func)
self.type_hints = get_type_hints(func)
- self.provider_parameters: List[str] = []
+ self.provider_parameters: Dict[str, List[str]] = {}
self._parser = argparse.ArgumentParser(
prog=func.__name__,
@@ -218,6 +218,7 @@ class ArgparseTranslator:
if custom_argument_groups:
for group in custom_argument_groups:
+ self.provider_parameters[group.name] = []
argparse_group = self._parser.add_argument_group(group.name)
for argument in group.arguments:
self._handle_argument_in_groups(argument, argparse_group)
@@ -278,7 +279,8 @@ class ArgparseTranslator:
if f"--{argument.name}" not in self._parser._option_string_actions:
kwargs = argument.model_dump(exclude={"name"}, exclude_none=True)
group.add_argument(f"--{argument.name}", **kwargs)
- self.provider_parameters.append(argument.name)
+ if group.title in self.provider_parameters:
+ self.provider_parameters[group.title].append(argument.name)
else:
kwargs = argument.model_dump(exclude={"name"}, exclude_none=True)
@@ -582,11 +584,19 @@ class ArgparseTranslator:
kwargs = self._unflatten_args(vars(parsed_args))
kwargs = self._update_with_custom_types(kwargs)
+ provider = kwargs.get("provider")
+ provider_args = []
+ if provider and provider in self.provider_parameters:
+ provider_args = self.provider_parameters[provider]
+ else:
+ for args in self.provider_parameters.values():
+ provider_args.extend(args)
+
# remove kwargs that doesn't match the signature or provider parameters
kwargs = {
key: value
for key, value in kwargs.items()
- if key in self.signature.parameters or key in self.provider_parameters
+ if key in self.signature.parameters or key in provider_args
}
return self.func(**kwargs)