summaryrefslogtreecommitdiffstats
path: root/openbb_platform/core/openbb_core/app/model/credentials.py
diff options
context:
space:
mode:
Diffstat (limited to 'openbb_platform/core/openbb_core/app/model/credentials.py')
-rw-r--r--openbb_platform/core/openbb_core/app/model/credentials.py54
1 files changed, 28 insertions, 26 deletions
diff --git a/openbb_platform/core/openbb_core/app/model/credentials.py b/openbb_platform/core/openbb_core/app/model/credentials.py
index 8979db5957a..f54198e9495 100644
--- a/openbb_platform/core/openbb_core/app/model/credentials.py
+++ b/openbb_platform/core/openbb_core/app/model/credentials.py
@@ -2,7 +2,7 @@
import traceback
import warnings
-from typing import Dict, Optional, Set, Tuple
+from typing import Dict, List, Optional, Tuple
from pydantic import (
BaseModel,
@@ -36,37 +36,39 @@ OBBSecretStr = Annotated[
class CredentialsLoader:
"""Here we create the Credentials model."""
- credentials: Dict[str, Set[str]] = {}
+ credentials: Dict[str, List[str]] = {}
- @staticmethod
- def prepare(
- credentials: Dict[str, Set[str]],
- ) -> Dict[str, Tuple[object, None]]:
+ def format_credentials(self) -> Dict[str, Tuple[object, None]]:
"""Prepare credentials map to be used in the Credentials model."""
formatted: Dict[str, Tuple[object, None]] = {}
- for origin, creds in credentials.items():
- for c in creds:
- # Not sure we should do this, if you require the same credential it breaks
- # if c in formatted:
- # raise ValueError(f"Credential '{c}' already in use.")
- formatted[c] = (
+ for c_origin, c_list in self.credentials.items():
+ for c_name in c_list:
+ if c_name in formatted:
+ warnings.warn(
+ message=f"Skipping '{c_name}', credential already in use.",
+ category=OpenBBWarning,
+ )
+ continue
+ formatted[c_name] = (
Optional[OBBSecretStr],
- Field(
- default=None, description=origin, alias=c.upper()
- ), # register the credential origin (obbject, providers)
+ Field(default=None, description=c_origin, alias=c_name.upper()),
)
- return formatted
+ return dict(sorted(formatted.items()))
def from_obbject(self) -> None:
"""Load credentials from OBBject extensions."""
- self.credentials["obbject"] = set()
- for name, entry in ExtensionLoader().obbject_objects.items(): # type: ignore[attr-defined]
+ for ext_name, ext in ExtensionLoader().obbject_objects.items(): # type: ignore[attr-defined]
try:
- for c in entry.credentials:
- self.credentials["obbject"].add(c)
+ if ext_name in self.credentials:
+ warnings.warn(
+ message=f"Skipping '{ext_name}', name already in user.",
+ category=OpenBBWarning,
+ )
+ continue
+ self.credentials[ext_name] = ext.credentials
except Exception as e:
- msg = f"Error loading extension: {name}\n"
+ msg = f"Error loading extension: {ext_name}\n"
if Env().DEBUG_MODE:
traceback.print_exception(type(e), e, e.__traceback__)
raise LoadingError(msg + f"\033[91m{e}\033[0m") from e
@@ -77,20 +79,20 @@ class CredentialsLoader:
def from_providers(self) -> None:
"""Load credentials from providers."""
- self.credentials["providers"] = set()
- for c in ProviderInterface().credentials:
- self.credentials["providers"].add(c)
+ self.credentials = ProviderInterface().credentials
def load(self) -> BaseModel:
"""Load credentials from providers."""
# We load providers first to give them priority choosing credential names
self.from_providers()
self.from_obbject()
- return create_model( # type: ignore
+ model = create_model( # type: ignore
"Credentials",
__config__=ConfigDict(validate_assignment=True, populate_by_name=True),
- **self.prepare(self.credentials),
+ **self.format_credentials(),
)
+ model.origins = self.credentials
+ return model
_Credentials = CredentialsLoader().load()