summaryrefslogtreecommitdiffstats
path: root/cli/openbb_cli/argparse_translator/argparse_class_processor.py
blob: 6fff39aa18e33b03b62469ce6d9469dc79eac023 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
import inspect
from typing import Any, Dict, Optional, Type

# TODO: this needs to be done differently
from openbb_core.app.static.container import Container

from openbb_cli.argparse_translator.argparse_translator import (
    ArgparseTranslator,
    ReferenceToCustomArgumentsProcessor,
)


class ArgparseClassProcessor:
    """
    Process a target class to create ArgparseTranslators for its methods.

    Public (non-underscore) bound methods of the target become
    ``ArgparseTranslator`` objects keyed by ``"<class_name>_<method_name>"``.
    Members that are ``Container`` instances are walked recursively, and their
    attribute names are recorded in ``paths`` with a depth-based label.
    """

    # Reference variable used to create custom groups for the ArgparseTranslators.
    # NOTE: this is class-level state shared by all instances; every call to
    # __init__ overwrites it.
    _reference: Dict[str, Any] = {}

    def __init__(
        self,
        target_class: Type,
        add_help: bool = False,
        reference: Optional[Dict[str, Any]] = None,
    ) -> None:
        """
        Initialize the ArgparseClassProcessor.

        Parameters
        ----------
        target_class : Type
            The target whose methods will be processed.
            NOTE(review): despite the ``Type`` annotation, the code treats this
            as an instance (names come from ``type(target)`` and only bound
            methods are collected). Confirm against callers.
        add_help : bool
            Whether to add help to the ArgparseTranslators.
        reference : Optional[Dict[str, Any]]
            Mapping from route strings of the form ``"/<class path>/<function>"``
            to reference data, used to build custom argument groups.
            ``None`` is treated as an empty mapping.
        """
        self._target_class: Type = target_class
        self._add_help: bool = add_help
        self._paths: Dict[str, str] = {}

        # Must be set before _process_class, which reads it via
        # _custom_groups_from_reference.
        ArgparseClassProcessor._reference = reference or {}

        self._translators: Dict[str, ArgparseTranslator] = self._process_class(
            target=self._target_class, add_help=self._add_help
        )
        self._paths[self._get_class_name(self._target_class)] = "path"
        self._build_paths(target=self._target_class)

    @property
    def translators(self) -> Dict[str, ArgparseTranslator]:
        """
        Get the ArgparseTranslators associated with the target class.

        Returns
        -------
        Dict[str, ArgparseTranslator]
            Translators keyed by ``"<class_name>_<method_name>"``.
        """
        return self._translators

    @property
    def paths(self) -> Dict[str, str]:
        """
        Get the paths associated with the target class.

        Returns
        -------
        Dict[str, str]
            Maps the top-level class name to ``"path"`` and each nested
            Container attribute name to ``"sub" * depth + "path"``.
        """
        return self._paths

    @classmethod
    def _custom_groups_from_reference(cls, class_name: str, function_name: str) -> Dict:
        """Return the custom argument groups for one command, or ``{}``.

        The route is ``"/" + class_name`` with underscores turned into slashes,
        followed by ``"/" + function_name``. If that route is not present in
        the shared reference, no processor is built and ``{}`` is returned.
        """
        route = f"/{class_name.replace('_', '/')}/{function_name}"
        if route not in cls._reference:
            return {}
        rp = ReferenceToCustomArgumentsProcessor({route: cls._reference[route]})
        return rp.custom_groups.get(route, {})  # type: ignore

    @classmethod
    def _process_class(
        cls,
        target: type,
        add_help: bool = False,
    ) -> Dict[str, ArgparseTranslator]:
        """Build translators for ``target``'s public methods, recursing into Containers.

        Parameters
        ----------
        target : type
            The object to inspect (treated as an instance; see ``__init__``).
        add_help : bool
            Forwarded to every ``ArgparseTranslator``.

        Returns
        -------
        Dict[str, ArgparseTranslator]
            Translators keyed by ``"<class_name>_<method_name>"``. On key
            collisions, entries from nested Containers processed later win.
        """
        methods: Dict[str, ArgparseTranslator] = {}
        # The class name is the same for every member of this target.
        class_name = cls._get_class_name(target)

        for name, member in inspect.getmembers(target):
            # Skip private and dunder members ("__x" also starts with "_").
            if name.startswith("_"):
                continue
            if inspect.ismethod(member):
                methods[f"{class_name}_{name}"] = ArgparseTranslator(
                    func=member,
                    add_help=add_help,
                    custom_argument_groups=cls._custom_groups_from_reference(  # type: ignore
                        class_name=class_name, function_name=name
                    ),
                )
            elif isinstance(member, Container):
                # getmembers already fetched the attribute; merge in place
                # instead of rebuilding the dict for every nested Container.
                methods.update(cls._process_class(target=member, add_help=add_help))

        return methods

    @staticmethod
    def _get_class_name(target: type) -> str:
        """Derive a lowercase name from ``target``'s class.

        The name is the last dotted component of ``str(type(target))``, with
        the trailing ``'>`` and any ``ROUTER_`` prefix removed. For example,
        ``<class 'pkg.mod.ROUTER_equity_price'>`` becomes ``"equity_price"``.
        """
        return (
            str(type(target))
            .rsplit(".", maxsplit=1)[-1]
            .replace("'>", "")
            .replace("ROUTER_", "")
            .lower()
        )

    def get_translator(self, command: str) -> ArgparseTranslator:
        """
        Retrieve the ArgparseTranslator object associated with a command.

        Parameters
        ----------
        command : str
            The translator key, of the form ``"<class_name>_<method_name>"``.

        Returns
        -------
        ArgparseTranslator
            The ArgparseTranslator associated with the specified command.

        Raises
        ------
        KeyError
            If no translator is registered under ``command``.
        """
        return self._translators[command]

    def _build_paths(self, target: type, depth: int = 1) -> None:
        """Record nested Container attribute names in ``self._paths``.

        A Container found at recursion depth ``depth`` is labelled
        ``"sub" * depth + "path"`` (``"subpath"``, ``"subsubpath"``, ...).
        Children are recorded before their parent.
        """
        for name, member in inspect.getmembers(target):
            # Only public Container members contribute paths. Bound methods
            # are never Containers, so they fall through here.
            if name.startswith("_") or not isinstance(member, Container):
                continue
            self._build_paths(target=member, depth=depth + 1)
            self._paths[name] = "sub" * depth + "path"