diff options
Diffstat (limited to 'openbb_platform/core/openbb_core/app/model/credentials.py')
-rw-r--r-- | openbb_platform/core/openbb_core/app/model/credentials.py | 54 |
1 files changed, 28 insertions, 26 deletions
diff --git a/openbb_platform/core/openbb_core/app/model/credentials.py b/openbb_platform/core/openbb_core/app/model/credentials.py index 8979db5957a..f54198e9495 100644 --- a/openbb_platform/core/openbb_core/app/model/credentials.py +++ b/openbb_platform/core/openbb_core/app/model/credentials.py @@ -2,7 +2,7 @@ import traceback import warnings -from typing import Dict, Optional, Set, Tuple +from typing import Dict, List, Optional, Tuple from pydantic import ( BaseModel, @@ -36,37 +36,39 @@ OBBSecretStr = Annotated[ class CredentialsLoader: """Here we create the Credentials model.""" - credentials: Dict[str, Set[str]] = {} + credentials: Dict[str, List[str]] = {} - @staticmethod - def prepare( - credentials: Dict[str, Set[str]], - ) -> Dict[str, Tuple[object, None]]: + def format_credentials(self) -> Dict[str, Tuple[object, None]]: """Prepare credentials map to be used in the Credentials model.""" formatted: Dict[str, Tuple[object, None]] = {} - for origin, creds in credentials.items(): - for c in creds: - # Not sure we should do this, if you require the same credential it breaks - # if c in formatted: - # raise ValueError(f"Credential '{c}' already in use.") - formatted[c] = ( + for c_origin, c_list in self.credentials.items(): + for c_name in c_list: + if c_name in formatted: + warnings.warn( + message=f"Skipping '{c_name}', credential already in use.", + category=OpenBBWarning, + ) + continue + formatted[c_name] = ( Optional[OBBSecretStr], - Field( - default=None, description=origin, alias=c.upper() - ), # register the credential origin (obbject, providers) + Field(default=None, description=c_origin, alias=c_name.upper()), ) - return formatted + return dict(sorted(formatted.items())) def from_obbject(self) -> None: """Load credentials from OBBject extensions.""" - self.credentials["obbject"] = set() - for name, entry in ExtensionLoader().obbject_objects.items(): # type: ignore[attr-defined] + for ext_name, ext in ExtensionLoader().obbject_objects.items(): # type: ignore[attr-defined] try: - for c in entry.credentials: - self.credentials["obbject"].add(c) + if ext_name in self.credentials: + warnings.warn( + message=f"Skipping '{ext_name}', name already in user.", + category=OpenBBWarning, + ) + continue + self.credentials[ext_name] = ext.credentials except Exception as e: - msg = f"Error loading extension: {name}\n" + msg = f"Error loading extension: {ext_name}\n" if Env().DEBUG_MODE: traceback.print_exception(type(e), e, e.__traceback__) raise LoadingError(msg + f"\033[91m{e}\033[0m") from e @@ -77,20 +79,20 @@ class CredentialsLoader: def from_providers(self) -> None: """Load credentials from providers.""" - self.credentials["providers"] = set() - for c in ProviderInterface().credentials: - self.credentials["providers"].add(c) + self.credentials = ProviderInterface().credentials def load(self) -> BaseModel: """Load credentials from providers.""" # We load providers first to give them priority choosing credential names self.from_providers() self.from_obbject() - return create_model( # type: ignore + model = create_model( # type: ignore "Credentials", __config__=ConfigDict(validate_assignment=True, populate_by_name=True), - **self.prepare(self.credentials), + **self.format_credentials(), ) + model.origins = self.credentials + return model _Credentials = CredentialsLoader().load() |