1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
|
import traceback
import warnings
from typing import Dict, Optional, Set, Tuple, Type

from importlib_metadata import entry_points
from pydantic import (
    BaseModel,
    ConfigDict,
    Field,
    SecretStr,
    create_model,
)
from pydantic.functional_serializers import PlainSerializer
from typing_extensions import Annotated

from openbb_core.app.model.abstract.warning import OpenBBWarning
from openbb_core.app.model.extension import Extension
from openbb_core.app.provider_interface import ProviderInterface
from openbb_core.env import Env
class LoadingError(Exception):
    """Raised when an extension cannot be loaded."""
# Secret type used for every credential field.
# Unmasking is done with a per-field PlainSerializer instead of a
# @model_serializer, because a @model_serializer blocks model_dump from
# honouring its pydantic parameters (include, exclude).
# when_used="json-unless-none": the raw secret string is emitted only for
# JSON-mode serialization (e.g. model_dump(mode="json")) and only when the
# value is not None.
OBBSecretStr = Annotated[
    SecretStr,
    PlainSerializer(
        lambda x: x.get_secret_value(), return_type=str, when_used="json-unless-none"
    ),
]
class CredentialsLoader:
    """Build the dynamic ``Credentials`` pydantic model.

    Credential names are collected per origin (``"providers"`` and
    ``"obbject"``) into ``credentials`` and then turned into optional,
    secret-valued model fields by ``prepare``.
    """

    # NOTE(review): this dict is a class attribute, so every CredentialsLoader
    # instance reads and writes the same mapping. It is kept at class level for
    # backward compatibility with code that may read
    # ``CredentialsLoader.credentials`` directly.
    credentials: Dict[str, Set[str]] = {}

    @staticmethod
    def prepare(
        credentials: Dict[str, Set[str]],
    ) -> Dict[str, Tuple[object, object]]:
        """Prepare credentials map to be used in the Credentials model.

        Parameters
        ----------
        credentials : Dict[str, Set[str]]
            Mapping of origin (e.g. ``"providers"``, ``"obbject"``) to the
            credential names that origin requires.

        Returns
        -------
        Dict[str, Tuple[object, object]]
            Field definitions for ``pydantic.create_model``. Each credential
            name maps to ``(Optional[OBBSecretStr], Field(...))`` with a
            ``None`` default and the origin stored as the field description.

        Notes
        -----
        The same credential may legitimately be required by more than one
        origin (e.g. a provider and an OBBject extension sharing an API key),
        so duplicates are not an error. The first origin in iteration order
        keeps the name, which is what lets ``load`` give providers priority
        by collecting them first.
        """
        formatted: Dict[str, Tuple[object, object]] = {}
        for origin, creds in credentials.items():
            for c in creds:
                if c in formatted:
                    # Already claimed by an earlier origin; keep that one.
                    continue
                formatted[c] = (
                    Optional[OBBSecretStr],
                    # The description records the credential origin
                    # (obbject, providers).
                    Field(default=None, description=origin),
                )
        return formatted

    def from_obbject(self) -> None:
        """Collect credential names required by OBBject extensions.

        Every entry point in the ``openbb_obbject_extension`` group is loaded.
        When it yields an ``Extension``, its ``credentials`` are added under
        the ``"obbject"`` origin.

        Raises
        ------
        LoadingError
            If an entry point fails to load while ``Env().DEBUG_MODE`` is set.
            Otherwise the failure is reported as an ``OpenBBWarning`` and the
            entry point is skipped.
        """
        self.credentials["obbject"] = set()
        for entry_point in sorted(entry_points(group="openbb_obbject_extension")):
            try:
                entry = entry_point.load()
                if isinstance(entry, Extension):
                    self.credentials["obbject"].update(entry.credentials)
            except Exception as e:
                msg = f"Error loading extension: {entry_point.name}\n"
                if Env().DEBUG_MODE:
                    traceback.print_exception(type(e), e, e.__traceback__)
                    raise LoadingError(msg + f"\033[91m{e}\033[0m") from e
                # Include the underlying error so the warning is actionable
                # without having to enable debug mode.
                warnings.warn(
                    message=msg + str(e),
                    category=OpenBBWarning,
                )

    def from_providers(self) -> None:
        """Collect credential names required by the installed providers.

        Names come from ``ProviderInterface().credentials`` and are stored
        under the ``"providers"`` origin.
        """
        self.credentials["providers"] = set(ProviderInterface().credentials)

    def load(self) -> Type[BaseModel]:
        """Collect credentials from all origins and build the Credentials model.

        Returns
        -------
        Type[BaseModel]
            A dynamically created pydantic model class named ``"Credentials"``,
            configured with ``validate_assignment=True``, with one optional
            secret field per collected credential name.
        """
        # Providers are collected first so that, when an OBBject extension
        # requires the same credential, ``prepare`` keeps the providers origin.
        self.from_providers()
        self.from_obbject()
        return create_model(  # type: ignore
            "Credentials",
            __config__=ConfigDict(validate_assignment=True),
            **self.prepare(self.credentials),
        )
# Built once, at import time. This collects provider credentials and loads
# every ``openbb_obbject_extension`` entry point. Importing this module can
# therefore emit an OpenBBWarning, or raise LoadingError when
# Env().DEBUG_MODE is set, if an extension fails to load.
_Credentials = CredentialsLoader().load()
class Credentials(_Credentials):  # type: ignore
    """Credentials model used to store provider credentials"""

    # NOTE: the class docstring above is kept verbatim on purpose. Pydantic
    # uses a model's docstring as its JSON-schema description.

    def __repr__(self) -> str:
        """Return the class name, a blank line, then sorted ``name: value`` rows.

        Values are formatted exactly as stored on the instance, so SecretStr
        fields stay masked.
        """
        rows = (f"{name}: {value}" for name, value in sorted(self.__dict__.items()))
        return f"{self.__class__.__name__}\n\n" + "\n".join(rows)

    def show(self):
        """Print every credential with its secret value revealed.

        ``model_dump(mode="json")`` runs the JSON-mode PlainSerializer attached
        to OBBSecretStr fields, which returns the raw secret string.
        """
        dumped = self.model_dump(mode="json")
        rows = [f"{name}: {value}" for name, value in sorted(dumped.items())]
        text = f"{self.__class__.__name__}\n\n" + "\n".join(rows)
        print(text)  # noqa: T201
|