summaryrefslogtreecommitdiffstats
path: root/openbb_platform/core/openbb_core/app/model/credentials.py
blob: 9e1de94b8600a4ec7f0e07e8ee51d656b99988c9 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
import traceback
import warnings
from typing import Any, Dict, Optional, Set, Tuple, Type

from importlib_metadata import entry_points
from pydantic import (
    BaseModel,
    ConfigDict,
    Field,
    SecretStr,
    create_model,
)
from pydantic.functional_serializers import PlainSerializer
from typing_extensions import Annotated

from openbb_core.app.model.abstract.warning import OpenBBWarning
from openbb_core.app.model.extension import Extension
from openbb_core.app.provider_interface import ProviderInterface
from openbb_core.env import Env


class LoadingError(Exception):
    """Raised when an OBBject extension entry point cannot be loaded."""


# Secret string type whose value is unwrapped by a field-level PlainSerializer.
# A model-level ``@model_serializer`` is avoided on purpose: it would stop
# ``model_dump`` from honouring pydantic's ``include``/``exclude`` parameters.
# The serializer only applies to JSON-mode dumps of non-None values.
OBBSecretStr = Annotated[
    SecretStr,
    PlainSerializer(
        lambda secret: secret.get_secret_value(),
        return_type=str,
        when_used="json-unless-none",
    ),
]


class CredentialsLoader:
    """Collect credential names and build the dynamic Credentials model"""

    # Maps an origin label ("providers" or "obbject") to the credential names it
    # declares. NOTE(review): this is a class attribute, so every instance shares
    # the same dict; each ``from_*`` method resets only its own key, which keeps
    # repeated loads consistent.
    credentials: Dict[str, Set[str]] = {}

    @staticmethod
    def prepare(
        credentials: Dict[str, Set[str]],
    ) -> Dict[str, Tuple[object, Any]]:
        """Turn the origin -> names map into ``create_model`` field definitions.

        Parameters
        ----------
        credentials : Dict[str, Set[str]]
            Credential names grouped by the origin that declared them.

        Returns
        -------
        Dict[str, Tuple[object, Any]]
            One ``(type, Field(...))`` pair per credential name. Every field is an
            optional secret defaulting to ``None`` whose description records the
            origin. If a name appears under several origins, the origin iterated
            last overwrites the earlier entry.
        """
        formatted: Dict[str, Tuple[object, Any]] = {}
        for origin, creds in credentials.items():
            for c in creds:
                # Duplicate names are allowed on purpose: rejecting them would break
                # any extension that requires the same credential as a provider.
                formatted[c] = (
                    Optional[OBBSecretStr],
                    Field(
                        default=None, description=origin
                    ),  # register the credential origin (obbject, providers)
                )

        return formatted

    def from_obbject(self) -> None:
        """Collect the credential names declared by installed OBBject extensions.

        Entry points in the ``openbb_obbject_extension`` group are loaded in sorted
        order. Each one that resolves to an ``Extension`` adds its ``credentials``
        to ``self.credentials["obbject"]``.

        Raises
        ------
        LoadingError
            When an entry point fails to load and ``Env().DEBUG_MODE`` is set.
            Otherwise the failure is reported as an ``OpenBBWarning`` and skipped.
        """
        self.credentials["obbject"] = set()
        for entry_point in sorted(entry_points(group="openbb_obbject_extension")):
            try:
                entry = entry_point.load()
                if isinstance(entry, Extension):
                    self.credentials["obbject"].update(entry.credentials)
            except Exception as e:
                msg = f"Error loading extension: {entry_point.name}\n"
                if Env().DEBUG_MODE:
                    traceback.print_exception(type(e), e, e.__traceback__)
                    raise LoadingError(msg + f"\033[91m{e}\033[0m") from e
                # Outside debug mode a broken extension must not prevent startup.
                warnings.warn(
                    message=msg,
                    category=OpenBBWarning,
                )

    def from_providers(self) -> None:
        """Collect the credential names declared by registered providers"""
        self.credentials["providers"] = set()
        self.credentials["providers"].update(ProviderInterface().credentials)

    def load(self) -> Type[BaseModel]:
        """Build and return the dynamic ``Credentials`` model class.

        Credential names are gathered from providers and OBBject extensions. They
        are then turned into optional secret fields on a new pydantic model
        configured with ``validate_assignment=True``.
        """
        # We load providers first to give them priority choosing credential names.
        # NOTE(review): ``prepare`` lets later origins overwrite earlier ones, so a
        # name declared by both origins ends up described as "obbject"; confirm
        # which origin is meant to win.
        self.from_providers()
        self.from_obbject()
        return create_model(  # type: ignore
            "Credentials",
            __config__=ConfigDict(validate_assignment=True),
            **self.prepare(self.credentials),
        )


# Build the dynamic pydantic model class from every credential name declared by
# providers and OBBject extensions. This runs once, at module import time, so the
# set of fields is fixed at that moment.
_Credentials = CredentialsLoader().load()


class Credentials(_Credentials):  # type: ignore
    """Credentials model built from provider and OBBject extension credentials.

    Every field is an optional secret. ``repr`` renders the values through their
    secret wrapper, while ``show`` prints them unmasked.
    """

    def __repr__(self) -> str:
        """Return the class name, a blank line, then sorted ``name: value`` lines"""
        rows = (f"{name}: {value}" for name, value in sorted(self.__dict__.items()))
        return f"{self.__class__.__name__}\n\n" + "\n".join(rows)

    def show(self):
        """Print every credential, sorted by name, with its secret value revealed"""
        revealed = self.model_dump(mode="json")
        rows = "\n".join(
            f"{name}: {value}" for name, value in sorted(revealed.items())
        )
        print(f"{self.__class__.__name__}\n\n{rows}")  # noqa: T201