summaryrefslogtreecommitdiffstats
path: root/openbb_platform/core/openbb_core/api/router/commands.py
blob: d524156ddbd2fac3b4d6bd843014be09572e2696 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
"""Commands: generates the command map."""

import inspect
from functools import partial, wraps
from inspect import Parameter, Signature, signature
from typing import Any, Callable, Dict, Optional, Tuple, TypeVar

from fastapi import APIRouter, Depends, Header
from fastapi.routing import APIRoute
from openbb_core.app.command_runner import CommandRunner
from openbb_core.app.model.command_context import CommandContext
from openbb_core.app.model.obbject import OBBject
from openbb_core.app.model.user_settings import UserSettings
from openbb_core.app.router import RouterLoader
from openbb_core.app.service.auth_service import AuthService
from openbb_core.app.service.system_service import SystemService
from openbb_core.app.service.user_service import UserService
from openbb_core.env import Env
from pydantic import BaseModel
from typing_extensions import Annotated, ParamSpec

# Charting support is optional: the "chart" parameter is only added to a
# route's signature when the openbb_charting extension can be imported.
try:
    from openbb_charting import Charting
except ImportError:
    CHARTING_INSTALLED = False
else:
    CHARTING_INSTALLED = True

T = TypeVar("T")
P = ParamSpec("P")

router = APIRouter(prefix="")


def build_new_annotation_map(sig: Signature) -> Dict[str, Any]:
    """Build an ``__annotations__``-style mapping from a signature.

    Every parameter name of `sig` maps to its annotation, in signature order.
    The final ``"return"`` key holds the signature's return annotation.
    Unannotated entries keep the ``inspect`` "empty" marker.
    """
    annotations: Dict[str, Any] = {
        name: param.annotation for name, param in sig.parameters.items()
    }
    annotations["return"] = sig.return_annotation
    return annotations


def build_new_signature(path: str, func: Callable) -> Signature:
    """Build new function signature.

    Starts from the signature of `func` and drops any ``cc`` parameter
    annotated with ``CommandContext``. Depending on configuration, it then
    adds these parameters, in this order:

    - ``chart: bool = False`` when ``openbb_charting`` is installed and the
      route maps to a registered charting function;
    - one ``Optional[str]`` header parameter per configured custom API header
      (dashes in the header name become underscores);
    - ``__authenticated_user_settings``, resolved through
      ``AuthService().user_settings_hook``, when ``Env().API_AUTH`` is set.

    The added parameters are placed before a ``**kwargs`` parameter, if there
    is one. ``Signature`` rejects any parameter that follows VAR_KEYWORD.

    Parameters
    ----------
    path : str
        Route path of the command. It is also used to look up a charting
        function: slashes become underscores and the leading one is dropped.
    func : Callable
        Endpoint whose signature is rebuilt.

    Returns
    -------
    Signature
        The new signature, keeping the original return annotation.
    """
    sig = signature(func)
    new_parameter_list = []
    # Index in `new_parameter_list` where the extra parameters are spliced in.
    var_kw_pos = None
    for parameter in sig.parameters.values():
        if parameter.name == "cc" and parameter.annotation == CommandContext:
            # The command context is not exposed as an API parameter.
            continue

        if parameter.kind == Parameter.VAR_KEYWORD:
            # Position in the *filtered* list. The index in the original
            # signature is off by one whenever `cc` was skipped, which would
            # place the extra parameters after **kwargs (a ValueError).
            var_kw_pos = len(new_parameter_list)

        # Parameter objects are immutable, so the original can be reused as-is.
        new_parameter_list.append(parameter)

    if var_kw_pos is None:
        # No **kwargs: append the extra parameters at the end.
        var_kw_pos = len(new_parameter_list)

    extra_parameters = []

    if CHARTING_INSTALLED and path.replace("/", "_")[1:] in Charting.functions():
        extra_parameters.append(
            Parameter(
                "chart",
                kind=Parameter.POSITIONAL_OR_KEYWORD,
                default=False,
                annotation=bool,
            )
        )

    if custom_headers := SystemService().system_settings.api_settings.custom_headers:
        for name, default in custom_headers.items():
            extra_parameters.append(
                Parameter(
                    name.replace("-", "_"),
                    kind=Parameter.POSITIONAL_OR_KEYWORD,
                    default=default,
                    annotation=Annotated[
                        Optional[str], Header(include_in_schema=False)
                    ],
                )
            )

    if Env().API_AUTH:
        extra_parameters.append(
            Parameter(
                "__authenticated_user_settings",
                kind=Parameter.POSITIONAL_OR_KEYWORD,
                default=UserSettings(),
                annotation=Annotated[
                    UserSettings, Depends(AuthService().user_settings_hook)
                ],
            )
        )

    new_parameter_list[var_kw_pos:var_kw_pos] = extra_parameters

    return Signature(
        parameters=new_parameter_list,
        return_annotation=sig.return_annotation,
    )


def validate_output(c_out: OBBject) -> OBBject:
    """
    Validate OBBject object.

    Checks against the OBBject schema and removes fields that contain the
    `exclude_from_api` extra `pydantic.Field` kwarg.
    Note that the modification to the `OBBject` object is done in-place.

    Top-level fields are checked against the OBBject's own field
    definitions. Fields of nested pydantic models are checked against the
    nested model's field definitions. Recursion continues into any field
    whose declared annotation is itself a model class.

    Parameters
    ----------
    c_out : OBBject
        OBBject object to validate.

    Returns
    -------
    OBBject
        Validated OBBject object.
    """

    def is_model(type_: Any) -> bool:
        """Return True if `type_` is a pydantic model class."""
        return inspect.isclass(type_) and issubclass(type_, BaseModel)

    def is_excluded(field: Any) -> bool:
        """Return True if `field` has a truthy ``exclude_from_api`` extra."""
        extra = getattr(field, "json_schema_extra", None)
        return isinstance(extra, dict) and bool(extra.get("exclude_from_api", None))

    def exclude_nested_fields(value: Any) -> None:
        """Remove excluded fields from `value` in place, if it is a model."""
        model_type = type(value)
        if not is_model(model_type):
            return
        # Read `model_fields` from the class. `__fields__` is the deprecated
        # pydantic v1 name.
        for field_name, field in model_type.model_fields.items():
            if is_excluded(field):
                delattr(value, field_name)
            elif is_model(getattr(field, "annotation", None)):
                # A nested model's fields are judged against its own
                # definitions, never against the top-level OBBject's fields.
                exclude_nested_fields(getattr(value, field_name))

    top_level_fields = type(c_out).model_fields
    # Iterate a shallow copy so deleting attributes from `c_out` does not
    # affect the iteration.
    for key, value in c_out.model_copy():
        if is_excluded(top_level_fields.get(key)):
            delattr(c_out, key)
        else:
            exclude_nested_fields(value)

    return c_out


def build_api_wrapper(
    command_runner: CommandRunner,
    route: APIRoute,
) -> Callable:
    """Build API wrapper for a command.

    The route's endpoint has its ``__signature__`` and ``__annotations__``
    replaced in place with the result of ``build_new_signature``.
    ``functools.wraps`` then copies them onto the returned coroutine
    function. Calling the wrapper does four things:

    - resolves the user settings;
    - runs the command at the route's path through ``command_runner.run``;
    - passes the resulting ``OBBject`` through ``validate_output``;
    - returns that result.

    Parameters
    ----------
    command_runner : CommandRunner
        Runner used to execute the command.
    route : APIRoute
        Route whose endpoint and path define the command.

    Returns
    -------
    Callable
        Async endpoint function to install on the route.
    """
    func: Callable = route.endpoint  # type: ignore
    path: str = route.path  # type: ignore

    new_signature = build_new_signature(path=path, func=func)
    new_annotations_map = build_new_annotation_map(sig=new_signature)

    func.__signature__ = new_signature  # type: ignore
    func.__annotations__ = new_annotations_map

    # Sentinel to tell "settings not passed" apart from any real value.
    missing = object()

    @wraps(wrapped=func)
    async def wrapper(*args: Any, **kwargs: Any):
        # `__authenticated_user_settings` is only in the signature when
        # API_AUTH is enabled. Otherwise fall back to the default user
        # settings, loading them only when they are actually needed.
        raw_settings = kwargs.pop("__authenticated_user_settings", missing)
        if raw_settings is missing:
            raw_settings = UserService.read_default_user_settings()
        user_settings: UserSettings = UserSettings.model_validate(raw_settings)
        execute = partial(command_runner.run, path, user_settings)
        output: OBBject = await execute(*args, **kwargs)

        return validate_output(output)

    return wrapper


def add_command_map(command_runner: CommandRunner, api_router: APIRouter) -> None:
    """Wrap every extension route's endpoint and mount the routes on `api_router`.

    Each route from ``RouterLoader.from_extensions()`` gets its endpoint
    replaced by the wrapper that ``build_api_wrapper`` builds around
    `command_runner`. The extension router is then included in `api_router`.
    """
    extensions = RouterLoader.from_extensions()
    routes = extensions.api_router.routes
    for current_route in routes:
        wrapped_endpoint = build_api_wrapper(
            command_runner=command_runner, route=current_route
        )
        current_route.endpoint = wrapped_endpoint  # type: ignore
    api_router.include_router(router=extensions.api_router)


# Module-level wiring, executed on import. Build the API's CommandRunner from
# the system settings (SystemService created with logging_sub_app="api"),
# then register every extension command on the module-level `router`.
system_settings = SystemService(logging_sub_app="api").system_settings
command_runner_instance = CommandRunner(system_settings=system_settings)
add_command_map(command_runner=command_runner_instance, api_router=router)