summaryrefslogtreecommitdiffstats
path: root/cli/openbb_cli/argparse_translator/reference_processor.py
blob: 53cba266cf5bdd8a8f1f34a2284885a6b1f226be (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
"""Module for the ReferenceToArgumentsProcessor class."""

from typing import (
    Any,
    Dict,
    List,
    Literal,
    Optional,
    Tuple,
    Union,
    get_args,
    get_origin,
)

from openbb_cli.argparse_translator.argparse_argument import (
    ArgparseArgumentGroupModel,
    ArgparseArgumentModel,
)


class ReferenceToArgumentsProcessor:
    """Translate a command reference into provider-specific argument groups.

    The reference maps each route to a dictionary whose ``"parameters"`` entry
    maps provider names to lists of parameter descriptions (dictionaries).
    For every provider other than ``"standard"``, the parameters that are not
    flagged as standard are converted into ``ArgparseArgumentModel`` objects
    and bundled into one ``ArgparseArgumentGroupModel`` per provider.
    """

    def __init__(self, reference: Dict[str, Dict]):
        """Store the reference and build the custom groups right away."""
        self._reference = reference
        self._custom_groups: Dict[str, List[ArgparseArgumentGroupModel]] = {}

        self._build_custom_groups()

    @property
    def custom_groups(self) -> "Dict[str, List[ArgparseArgumentGroupModel]]":
        """Return the argument groups built from the reference, keyed by route."""
        return self._custom_groups

    @staticmethod
    def _make_type_parsable(type_: str) -> type:
        """Turn the textual type of a parameter into a Python type object.

        The conversion is heuristic and based on substring checks:

        * a string mentioning ``Union`` collapses to ``str`` if it also
          contains ``"str"``, otherwise to ``int`` if it contains ``"int"``;
        * ``"date"``, ``"datetime.time"`` and ``"time"`` are handled as ``str``;
        * when a numeric constraint (``gt=``, ``ge=``, ``lt=``, ``le=``) is
          present, any ``Annotated[`` wrapper and every ``]`` are removed and
          only the text before the first comma is kept;
        * whatever remains is evaluated as a Python expression.

        Args:
            type_: Type representation taken from the reference.

        Returns:
            The resulting type object (a builtin type or a ``typing`` construct).

        Raises:
            Exception: Whatever ``eval`` raises when the remaining text is not a
                valid expression or uses a name that is not in scope.
        """
        if "Union" in type_:
            # "str" is checked first, so a union mentioning both becomes str.
            if "str" in type_:
                return str
            if "int" in type_:
                return int
        if type_ in ("date", "datetime.time", "time"):
            return str

        if any(marker in type_ for marker in ("gt=", "ge=", "lt=", "le=")):
            if "Annotated" in type_:
                # NOTE: every "]" is stripped, so this only copes with a flat
                # ``Annotated[<simple type>, <constraint>]`` representation.
                type_ = type_.replace("Annotated[", "").replace("]", "")
            type_ = type_.split(",")[0]

        # SECURITY: the string is executed as Python code. The reference must
        # come from a trusted source; never feed user-controlled text here.
        return eval(type_)  # noqa: S307, E501 pylint: disable=eval-used

    def _parse_type(self, type_: str) -> type:
        """Parse the type from its string representation.

        A ``Literal[...]`` is reduced to the Python type of its first value;
        any other parsed type is returned unchanged.
        """
        parsed = self._make_type_parsable(type_)

        if get_origin(parsed) is Literal:
            return type(get_args(parsed)[0])

        return parsed

    def _get_nargs(self, type_: type) -> Optional[Union[int, str]]:
        """Return ``"+"`` for list types and ``None`` for everything else."""
        return "+" if get_origin(type_) is list else None

    def _get_choices(self, type_: str, custom_choices: Any) -> Tuple:
        """Get the allowed values for a parameter.

        Args:
            type_: Type representation taken from the reference.
            custom_choices: Explicit choices; when truthy they take precedence
                over anything derived from the type.

        Returns:
            ``custom_choices`` as a tuple when provided. Otherwise the values
            of a ``Literal`` found directly, as the item type of a list, or as
            the single non-``None`` member of an optional type. An empty tuple
            when no literal values are found.

        Raises:
            ValueError: If no custom choices are given and an optional union
                keeps more than one member once ``None`` is removed.
        """
        # Explicit choices win, so there is no reason to parse the type (and
        # possibly fail on it) when they are available.
        if custom_choices:
            return tuple(custom_choices)

        parsed = self._make_type_parsable(type_)
        origin = get_origin(parsed)
        members = get_args(parsed)

        if origin is Literal:
            return members

        inner = None
        if origin is list:
            # An unparametrised list has no item type to inspect.
            inner = members[0] if members else None
        elif origin is Union and type(None) in members:
            # Drop NoneType; exactly one real type must remain.
            remaining = [member for member in members if member is not type(None)]
            if len(remaining) > 1:
                raise ValueError("Union with NoneType should have only one type left")
            inner = remaining[0]

        if inner is not None and get_origin(inner) is Literal:
            return get_args(inner)

        return ()

    def _build_custom_groups(self):
        """Build the custom groups from the reference.

        One group is appended per non-standard provider of every route, even
        when that provider contributes no custom argument.
        """
        for route, route_info in self._reference.items():
            for provider, params in route_info["parameters"].items():
                if provider == "standard":
                    continue

                custom_arguments = []
                for param in params:
                    # Parameters flagged as standard are skipped here.
                    if param.get("standard"):
                        continue

                    parsed_type = self._parse_type(param["type"])

                    custom_arguments.append(
                        ArgparseArgumentModel(
                            name=param["name"],
                            type=parsed_type,
                            dest=param["name"],
                            default=param["default"],
                            required=not param["optional"],
                            # Booleans become flags that take no value.
                            action=(
                                "store_true" if parsed_type is bool else "store"
                            ),
                            help=param["description"],
                            nargs=self._get_nargs(parsed_type),  # type: ignore
                            choices=self._get_choices(
                                param["type"], custom_choices=param["choices"]
                            ),
                        )
                    )

                group = ArgparseArgumentGroupModel(
                    name=provider, arguments=custom_arguments
                )
                self._custom_groups.setdefault(route, []).append(group)