diff options
author | Diogo Sousa <montezdesousa@gmail.com> | 2024-05-21 14:06:55 +0100 |
---|---|---|
committer | Diogo Sousa <montezdesousa@gmail.com> | 2024-05-21 14:06:55 +0100 |
commit | 52148a8a0ade800b6df7182b76491bed93667f4a (patch) | |
tree | 8f4bfb3b72787b45f9cfdc23e29670b844dff2b1 | |
parent | 1994b34866eb621d92d6a88f75bcf9845d3dbce4 (diff) |
feat: allow provider fallback based on credentials
7 files changed, 66 insertions, 121 deletions
diff --git a/openbb_platform/core/openbb_core/app/command_runner.py b/openbb_platform/core/openbb_core/app/command_runner.py index 89433e45156..e3dce19f465 100644 --- a/openbb_platform/core/openbb_core/app/command_runner.py +++ b/openbb_platform/core/openbb_core/app/command_runner.py @@ -21,7 +21,7 @@ from openbb_core.app.model.metadata import Metadata from openbb_core.app.model.obbject import OBBject from openbb_core.app.model.system_settings import SystemSettings from openbb_core.app.model.user_settings import UserSettings -from openbb_core.app.provider_interface import ExtraParams, ProviderInterface +from openbb_core.app.provider_interface import ExtraParams from openbb_core.app.router import CommandMap from openbb_core.app.service.system_service import SystemService from openbb_core.app.service.user_service import UserService @@ -118,68 +118,6 @@ class ParametersBuilder: return kwargs @staticmethod - def update_provider_choices( - func: Callable, - command_coverage: Dict[str, List[str]], - route: str, - kwargs: Dict[str, Any], - route_default: Optional[Dict[str, Optional[str]]], - ) -> Dict[str, Any]: - """Update the provider choices with the available providers and set default provider.""" - - def _needs_provider(func: Callable) -> bool: - """Check if the function needs a provider.""" - parameters = signature(func).parameters.keys() - return "provider_choices" in parameters - - def _has_provider(kwargs: Dict[str, Any]) -> bool: - """Check if the kwargs already have a provider.""" - provider_choices = kwargs.get("provider_choices") - - if isinstance(provider_choices, dict): # when in python - return provider_choices.get("provider", None) is not None - if isinstance(provider_choices, object): # when running as fastapi - return getattr(provider_choices, "provider", None) is not None - return False - - def _get_first_provider() -> Optional[str]: - """Get the first available provider.""" - available_providers = ProviderInterface().available_providers - return available_providers[0] if available_providers else None - - def _get_default_provider( - command_coverage: Dict[str, List[str]], - route_default: Optional[Dict[str, Optional[str]]], - ) -> Optional[str]: - """ - Get the default provider for the given route. - - Either pick it from the user defaults or from the command coverage. - """ - cmd_cov_given_route = command_coverage.get(route) - command_cov_provider = ( - cmd_cov_given_route[0] if cmd_cov_given_route else None - ) - - if route_default: - return route_default.get("provider", None) or command_cov_provider # type: ignore - - return command_cov_provider - - if not _has_provider(kwargs) and _needs_provider(func): - provider = ( - _get_default_provider( - command_coverage, - route_default, - ) - if route in command_coverage - else _get_first_provider() - ) - kwargs["provider_choices"] = {"provider": provider} - - return kwargs - - @staticmethod def _warn_kwargs( extra_params: Dict[str, Any], model: Type[BaseModel], @@ -246,14 +184,12 @@ class ParametersBuilder: args: Tuple[Any, ...], execution_context: ExecutionContext, func: Callable, - route: str, kwargs: Dict[str, Any], ) -> Dict[str, Any]: """Build the parameters for a function.""" func = cls.get_polished_func(func=func) system_settings = execution_context.system_settings user_settings = execution_context.user_settings - command_map = execution_context.command_map kwargs = cls.merge_args_and_kwargs( func=func, @@ -266,13 +202,6 @@ class ParametersBuilder: system_settings=system_settings, user_settings=user_settings, ) - kwargs = cls.update_provider_choices( - func=func, - command_coverage=command_map.command_coverage, - route=route, - kwargs=kwargs, - route_default=user_settings.defaults.routes.get(route, None), - ) kwargs = cls.validate_kwargs( func=func, kwargs=kwargs, @@ -364,7 +293,6 @@ class StaticCommandRunner: args=args, execution_context=execution_context, func=func, - route=route, kwargs=kwargs, ) diff --git a/openbb_platform/core/openbb_core/app/model/credentials.py b/openbb_platform/core/openbb_core/app/model/credentials.py index 8979db5957a..7e72cad1441 100644 --- a/openbb_platform/core/openbb_core/app/model/credentials.py +++ b/openbb_platform/core/openbb_core/app/model/credentials.py @@ -2,7 +2,7 @@ import traceback import warnings -from typing import Dict, Optional, Set, Tuple +from typing import Dict, List, Optional, Tuple from pydantic import ( BaseModel, @@ -36,37 +36,39 @@ OBBSecretStr = Annotated[ class CredentialsLoader: """Here we create the Credentials model.""" - credentials: Dict[str, Set[str]] = {} + credentials: Dict[str, List[str]] = {} - @staticmethod - def prepare( - credentials: Dict[str, Set[str]], - ) -> Dict[str, Tuple[object, None]]: + def format_credentials(self) -> Dict[str, Tuple[object, None]]: """Prepare credentials map to be used in the Credentials model.""" formatted: Dict[str, Tuple[object, None]] = {} - for origin, creds in credentials.items(): - for c in creds: - # Not sure we should do this, if you require the same credential it breaks - # if c in formatted: - # raise ValueError(f"Credential '{c}' already in use.") - formatted[c] = ( + for c_origin, c_list in self.credentials.items(): + for c_name in c_list: + if c_name in formatted: + warnings.warn( + message=f"Skipping '{c_name}', credential already in use.", + category=OpenBBWarning, + ) + continue + formatted[c_name] = ( Optional[OBBSecretStr], - Field( - default=None, description=origin, alias=c.upper() - ), # register the credential origin (obbject, providers) + Field(default=None, description=c_origin, alias=c_name.upper()), ) - return formatted + return dict(sorted(formatted.items())) def from_obbject(self) -> None: """Load credentials from OBBject extensions.""" - self.credentials["obbject"] = set() - for name, entry in ExtensionLoader().obbject_objects.items(): # type: ignore[attr-defined] + for ext_name, ext in ExtensionLoader().obbject_objects.items(): # type: ignore[attr-defined] try: - for c in entry.credentials: - self.credentials["obbject"].add(c) + if ext_name in self.credentials: + warnings.warn( + message=f"Skipping '{ext_name}', name already in user.", + category=OpenBBWarning, + ) + continue + self.credentials[ext_name] = ext.credentials except Exception as e: - msg = f"Error loading extension: {name}\n" + msg = f"Error loading extension: {ext_name}\n" if Env().DEBUG_MODE: traceback.print_exception(type(e), e, e.__traceback__) raise LoadingError(msg + f"\033[91m{e}\033[0m") from e @@ -77,20 +79,20 @@ class CredentialsLoader: def from_providers(self) -> None: """Load credentials from providers.""" - self.credentials["providers"] = set() - for c in ProviderInterface().credentials: - self.credentials["providers"].add(c) + self.credentials = ProviderInterface().credentials def load(self) -> BaseModel: """Load credentials from providers.""" # We load providers first to give them priority choosing credential names self.from_providers() self.from_obbject() - return create_model( # type: ignore + model = create_model( # type: ignore "Credentials", __config__=ConfigDict(validate_assignment=True, populate_by_name=True), - **self.prepare(self.credentials), + **self.format_credentials(), ) + model.providers = self.credentials + return model _Credentials = CredentialsLoader().load() diff --git a/openbb_platform/core/openbb_core/app/model/defaults.py b/openbb_platform/core/openbb_core/app/model/defaults.py index 52007edf5de..89cea5cbfb4 100644 --- a/openbb_platform/core/openbb_core/app/model/defaults.py +++ b/openbb_platform/core/openbb_core/app/model/defaults.py @@ -1,6 +1,6 @@ """Defaults model.""" -from typing import Dict, Optional +from typing import Dict, List, Optional, Union from pydantic import BaseModel, ConfigDict, Field @@ -10,7 +10,9 @@ class Defaults(BaseModel): model_config = ConfigDict(validate_assignment=True) - routes: Dict[str, Dict[str, Optional[str]]] = Field(default_factory=dict) + routes: Dict[str, Dict[str, Optional[Union[str, List[str]]]]] = Field( + default_factory=dict + ) def __repr__(self) -> str: """Return string representation.""" diff --git a/openbb_platform/core/openbb_core/app/provider_interface.py b/openbb_platform/core/openbb_core/app/provider_interface.py index 11bb21d5f77..05c5db13ba6 100644 --- a/openbb_platform/core/openbb_core/app/provider_interface.py +++ b/openbb_platform/core/openbb_core/app/provider_interface.py @@ -126,8 +126,8 @@ class ProviderInterface(metaclass=SingletonMeta): return self._map @property - def credentials(self) -> List[str]: - """Dictionary of required credentials by provider.""" + def credentials(self) -> Dict[str, List[str]]: + """Map credentials to providers.""" return self._registry_map.credentials @property diff --git a/openbb_platform/core/openbb_core/app/static/container.py b/openbb_platform/core/openbb_core/app/static/container.py index 60fc1def9bb..ed673b3c5fc 100644 --- a/openbb_platform/core/openbb_core/app/static/container.py +++ b/openbb_platform/core/openbb_core/app/static/container.py @@ -3,7 +3,6 @@ from typing import Any, Optional, Tuple from openbb_core.app.command_runner import CommandRunner -from openbb_core.app.model.abstract.error import OpenBBError from openbb_core.app.model.obbject import OBBject @@ -24,18 +23,29 @@ class Container: return obbject return getattr(obbject, "to_" + output_type)() + def _check_credentials(self, provider: str) -> bool: + """Check required credentials are populated.""" + credentials = self._command_runner.user_settings.credentials + required = credentials.providers.get(provider, []) + current = credentials.model_dump(exclude_none=True) + return all(item in current for item in required) + def _get_provider( self, choice: Optional[str], cmd: str, available: Tuple[str, ...] ) -> str: """Get the provider to use in execution.""" if choice is None: - if config_default := self._command_runner.user_settings.defaults.routes.get( - cmd, {} - ).get("provider"): - if config_default in available: - return config_default - raise OpenBBError( - f"provider '{config_default}' is not available. Choose from: {', '.join(available)}." + routes = self._command_runner.user_settings.defaults.routes + if provider := (routes.get(cmd, {}).get("provider") or available): + provider_iterable = ( + [provider] if isinstance(provider, str) else provider ) + for p in provider_iterable: + if self._check_credentials(p): + return p + continue + # Warn that that no provider with keys was found + # We fallback to the first provider that does not need keys + # We fallback to the first provider return available[0] return choice diff --git a/openbb_platform/core/openbb_core/app/static/utils/decorators.py b/openbb_platform/core/openbb_core/app/static/utils/decorators.py index 8daefae0575..6155e932147 100644 --- a/openbb_platform/core/openbb_core/app/static/utils/decorators.py +++ b/openbb_platform/core/openbb_core/app/static/utils/decorators.py @@ -65,7 +65,12 @@ def exception_handler(func: Callable[P, R]) -> Callable[P, R]: [ str(i) for i in err.get("loc", ()) - if i not in ("standard_params", "extra_params") + if i + not in ( + "standard_params", + "extra_params", + "provider_choices", + ) ] ) _input = err.get("input", "") diff --git a/openbb_platform/core/openbb_core/provider/registry_map.py b/openbb_platform/core/openbb_core/provider/registry_map.py index a3a7fcc32d4..c3ccdb1b6f6 100644 --- a/openbb_platform/core/openbb_core/provider/registry_map.py +++ b/openbb_platform/core/openbb_core/provider/registry_map.py @@ -40,8 +40,8 @@ class RegistryMap: return self._available_providers @property - def credentials(self) -> List[str]: - """Get list of required credentials.""" + def credentials(self) -> Dict[str, List[str]]: + """Get map of providers to credentials.""" return self._credentials @property @@ -59,13 +59,11 @@ class RegistryMap: """Get available models.""" return self._models - def _get_credentials(self, registry: Registry) -> List[str]: - """Get list of required credentials.""" - cred_list = [] - for provider in registry.providers.values(): - for c in provider.credentials: - cred_list.append(c) - return cred_list + def _get_credentials(self, registry: Registry) -> Dict[str, List[str]]: + """Get map of providers to credentials.""" + return { + name: provider.credentials for name, provider in registry.providers.items() + } def _get_available_providers(self, registry: Registry) -> List[str]: """Get list of available providers.""" |