diff options
Diffstat (limited to 'cli/openbb_cli/argparse_translator/argparse_translator.py')
-rw-r--r-- | cli/openbb_cli/argparse_translator/argparse_translator.py | 16 |
1 files changed, 13 insertions, 3 deletions
diff --git a/cli/openbb_cli/argparse_translator/argparse_translator.py b/cli/openbb_cli/argparse_translator/argparse_translator.py index 001e7afce04..724c2bea306 100644 --- a/cli/openbb_cli/argparse_translator/argparse_translator.py +++ b/cli/openbb_cli/argparse_translator/argparse_translator.py @@ -203,7 +203,7 @@ class ArgparseTranslator: self.func = func self.signature = inspect.signature(func) self.type_hints = get_type_hints(func) - self.provider_parameters: List[str] = [] + self.provider_parameters: Dict[str, List[str]] = {} self._parser = argparse.ArgumentParser( prog=func.__name__, @@ -218,6 +218,7 @@ class ArgparseTranslator: if custom_argument_groups: for group in custom_argument_groups: + self.provider_parameters[group.name] = [] argparse_group = self._parser.add_argument_group(group.name) for argument in group.arguments: self._handle_argument_in_groups(argument, argparse_group) @@ -278,7 +279,8 @@ class ArgparseTranslator: if f"--{argument.name}" not in self._parser._option_string_actions: kwargs = argument.model_dump(exclude={"name"}, exclude_none=True) group.add_argument(f"--{argument.name}", **kwargs) - self.provider_parameters.append(argument.name) + if group.title in self.provider_parameters: + self.provider_parameters[group.title].append(argument.name) else: kwargs = argument.model_dump(exclude={"name"}, exclude_none=True) @@ -582,11 +584,19 @@ class ArgparseTranslator: kwargs = self._unflatten_args(vars(parsed_args)) kwargs = self._update_with_custom_types(kwargs) + provider = kwargs.get("provider") + provider_args = [] + if provider and provider in self.provider_parameters: + provider_args = self.provider_parameters[provider] + else: + for args in self.provider_parameters.values(): + provider_args.extend(args) + # remove kwargs that doesn't match the signature or provider parameters kwargs = { key: value for key, value in kwargs.items() - if key in self.signature.parameters or key in self.provider_parameters + if key in self.signature.parameters or key in provider_args } return self.func(**kwargs) |