summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorDiogo Sousa <montezdesousa@gmail.com>2024-05-21 14:06:55 +0100
committerDiogo Sousa <montezdesousa@gmail.com>2024-05-21 14:06:55 +0100
commit52148a8a0ade800b6df7182b76491bed93667f4a (patch)
tree8f4bfb3b72787b45f9cfdc23e29670b844dff2b1
parent1994b34866eb621d92d6a88f75bcf9845d3dbce4 (diff)
feat: allow provider fallback based on credentials
-rw-r--r--openbb_platform/core/openbb_core/app/command_runner.py74
-rw-r--r--openbb_platform/core/openbb_core/app/model/credentials.py54
-rw-r--r--openbb_platform/core/openbb_core/app/model/defaults.py6
-rw-r--r--openbb_platform/core/openbb_core/app/provider_interface.py4
-rw-r--r--openbb_platform/core/openbb_core/app/static/container.py26
-rw-r--r--openbb_platform/core/openbb_core/app/static/utils/decorators.py7
-rw-r--r--openbb_platform/core/openbb_core/provider/registry_map.py16
7 files changed, 66 insertions, 121 deletions
diff --git a/openbb_platform/core/openbb_core/app/command_runner.py b/openbb_platform/core/openbb_core/app/command_runner.py
index 89433e45156..e3dce19f465 100644
--- a/openbb_platform/core/openbb_core/app/command_runner.py
+++ b/openbb_platform/core/openbb_core/app/command_runner.py
@@ -21,7 +21,7 @@ from openbb_core.app.model.metadata import Metadata
from openbb_core.app.model.obbject import OBBject
from openbb_core.app.model.system_settings import SystemSettings
from openbb_core.app.model.user_settings import UserSettings
-from openbb_core.app.provider_interface import ExtraParams, ProviderInterface
+from openbb_core.app.provider_interface import ExtraParams
from openbb_core.app.router import CommandMap
from openbb_core.app.service.system_service import SystemService
from openbb_core.app.service.user_service import UserService
@@ -118,68 +118,6 @@ class ParametersBuilder:
return kwargs
@staticmethod
- def update_provider_choices(
- func: Callable,
- command_coverage: Dict[str, List[str]],
- route: str,
- kwargs: Dict[str, Any],
- route_default: Optional[Dict[str, Optional[str]]],
- ) -> Dict[str, Any]:
- """Update the provider choices with the available providers and set default provider."""
-
- def _needs_provider(func: Callable) -> bool:
- """Check if the function needs a provider."""
- parameters = signature(func).parameters.keys()
- return "provider_choices" in parameters
-
- def _has_provider(kwargs: Dict[str, Any]) -> bool:
- """Check if the kwargs already have a provider."""
- provider_choices = kwargs.get("provider_choices")
-
- if isinstance(provider_choices, dict): # when in python
- return provider_choices.get("provider", None) is not None
- if isinstance(provider_choices, object): # when running as fastapi
- return getattr(provider_choices, "provider", None) is not None
- return False
-
- def _get_first_provider() -> Optional[str]:
- """Get the first available provider."""
- available_providers = ProviderInterface().available_providers
- return available_providers[0] if available_providers else None
-
- def _get_default_provider(
- command_coverage: Dict[str, List[str]],
- route_default: Optional[Dict[str, Optional[str]]],
- ) -> Optional[str]:
- """
- Get the default provider for the given route.
-
- Either pick it from the user defaults or from the command coverage.
- """
- cmd_cov_given_route = command_coverage.get(route)
- command_cov_provider = (
- cmd_cov_given_route[0] if cmd_cov_given_route else None
- )
-
- if route_default:
- return route_default.get("provider", None) or command_cov_provider # type: ignore
-
- return command_cov_provider
-
- if not _has_provider(kwargs) and _needs_provider(func):
- provider = (
- _get_default_provider(
- command_coverage,
- route_default,
- )
- if route in command_coverage
- else _get_first_provider()
- )
- kwargs["provider_choices"] = {"provider": provider}
-
- return kwargs
-
- @staticmethod
def _warn_kwargs(
extra_params: Dict[str, Any],
model: Type[BaseModel],
@@ -246,14 +184,12 @@ class ParametersBuilder:
args: Tuple[Any, ...],
execution_context: ExecutionContext,
func: Callable,
- route: str,
kwargs: Dict[str, Any],
) -> Dict[str, Any]:
"""Build the parameters for a function."""
func = cls.get_polished_func(func=func)
system_settings = execution_context.system_settings
user_settings = execution_context.user_settings
- command_map = execution_context.command_map
kwargs = cls.merge_args_and_kwargs(
func=func,
@@ -266,13 +202,6 @@ class ParametersBuilder:
system_settings=system_settings,
user_settings=user_settings,
)
- kwargs = cls.update_provider_choices(
- func=func,
- command_coverage=command_map.command_coverage,
- route=route,
- kwargs=kwargs,
- route_default=user_settings.defaults.routes.get(route, None),
- )
kwargs = cls.validate_kwargs(
func=func,
kwargs=kwargs,
@@ -364,7 +293,6 @@ class StaticCommandRunner:
args=args,
execution_context=execution_context,
func=func,
- route=route,
kwargs=kwargs,
)
diff --git a/openbb_platform/core/openbb_core/app/model/credentials.py b/openbb_platform/core/openbb_core/app/model/credentials.py
index 8979db5957a..7e72cad1441 100644
--- a/openbb_platform/core/openbb_core/app/model/credentials.py
+++ b/openbb_platform/core/openbb_core/app/model/credentials.py
@@ -2,7 +2,7 @@
import traceback
import warnings
-from typing import Dict, Optional, Set, Tuple
+from typing import Dict, List, Optional, Tuple
from pydantic import (
BaseModel,
@@ -36,37 +36,39 @@ OBBSecretStr = Annotated[
class CredentialsLoader:
"""Here we create the Credentials model."""
- credentials: Dict[str, Set[str]] = {}
+ credentials: Dict[str, List[str]] = {}
- @staticmethod
- def prepare(
- credentials: Dict[str, Set[str]],
- ) -> Dict[str, Tuple[object, None]]:
+ def format_credentials(self) -> Dict[str, Tuple[object, None]]:
"""Prepare credentials map to be used in the Credentials model."""
formatted: Dict[str, Tuple[object, None]] = {}
- for origin, creds in credentials.items():
- for c in creds:
- # Not sure we should do this, if you require the same credential it breaks
- # if c in formatted:
- # raise ValueError(f"Credential '{c}' already in use.")
- formatted[c] = (
+ for c_origin, c_list in self.credentials.items():
+ for c_name in c_list:
+ if c_name in formatted:
+ warnings.warn(
+ message=f"Skipping '{c_name}', credential already in use.",
+ category=OpenBBWarning,
+ )
+ continue
+ formatted[c_name] = (
Optional[OBBSecretStr],
- Field(
- default=None, description=origin, alias=c.upper()
- ), # register the credential origin (obbject, providers)
+ Field(default=None, description=c_origin, alias=c_name.upper()),
)
- return formatted
+ return dict(sorted(formatted.items()))
def from_obbject(self) -> None:
"""Load credentials from OBBject extensions."""
- self.credentials["obbject"] = set()
- for name, entry in ExtensionLoader().obbject_objects.items(): # type: ignore[attr-defined]
+ for ext_name, ext in ExtensionLoader().obbject_objects.items(): # type: ignore[attr-defined]
try:
- for c in entry.credentials:
- self.credentials["obbject"].add(c)
+ if ext_name in self.credentials:
+ warnings.warn(
+ message=f"Skipping '{ext_name}', name already in user.",
+ category=OpenBBWarning,
+ )
+ continue
+ self.credentials[ext_name] = ext.credentials
except Exception as e:
- msg = f"Error loading extension: {name}\n"
+ msg = f"Error loading extension: {ext_name}\n"
if Env().DEBUG_MODE:
traceback.print_exception(type(e), e, e.__traceback__)
raise LoadingError(msg + f"\033[91m{e}\033[0m") from e
@@ -77,20 +79,20 @@ class CredentialsLoader:
def from_providers(self) -> None:
"""Load credentials from providers."""
- self.credentials["providers"] = set()
- for c in ProviderInterface().credentials:
- self.credentials["providers"].add(c)
+ self.credentials = ProviderInterface().credentials
def load(self) -> BaseModel:
"""Load credentials from providers."""
# We load providers first to give them priority choosing credential names
self.from_providers()
self.from_obbject()
- return create_model( # type: ignore
+ model = create_model( # type: ignore
"Credentials",
__config__=ConfigDict(validate_assignment=True, populate_by_name=True),
- **self.prepare(self.credentials),
+ **self.format_credentials(),
)
+ model.providers = self.credentials
+ return model
_Credentials = CredentialsLoader().load()
diff --git a/openbb_platform/core/openbb_core/app/model/defaults.py b/openbb_platform/core/openbb_core/app/model/defaults.py
index 52007edf5de..89cea5cbfb4 100644
--- a/openbb_platform/core/openbb_core/app/model/defaults.py
+++ b/openbb_platform/core/openbb_core/app/model/defaults.py
@@ -1,6 +1,6 @@
"""Defaults model."""
-from typing import Dict, Optional
+from typing import Dict, List, Optional, Union
from pydantic import BaseModel, ConfigDict, Field
@@ -10,7 +10,9 @@ class Defaults(BaseModel):
model_config = ConfigDict(validate_assignment=True)
- routes: Dict[str, Dict[str, Optional[str]]] = Field(default_factory=dict)
+ routes: Dict[str, Dict[str, Optional[Union[str, List[str]]]]] = Field(
+ default_factory=dict
+ )
def __repr__(self) -> str:
"""Return string representation."""
diff --git a/openbb_platform/core/openbb_core/app/provider_interface.py b/openbb_platform/core/openbb_core/app/provider_interface.py
index 11bb21d5f77..05c5db13ba6 100644
--- a/openbb_platform/core/openbb_core/app/provider_interface.py
+++ b/openbb_platform/core/openbb_core/app/provider_interface.py
@@ -126,8 +126,8 @@ class ProviderInterface(metaclass=SingletonMeta):
return self._map
@property
- def credentials(self) -> List[str]:
- """Dictionary of required credentials by provider."""
+ def credentials(self) -> Dict[str, List[str]]:
+ """Map credentials to providers."""
return self._registry_map.credentials
@property
diff --git a/openbb_platform/core/openbb_core/app/static/container.py b/openbb_platform/core/openbb_core/app/static/container.py
index 60fc1def9bb..ed673b3c5fc 100644
--- a/openbb_platform/core/openbb_core/app/static/container.py
+++ b/openbb_platform/core/openbb_core/app/static/container.py
@@ -3,7 +3,6 @@
from typing import Any, Optional, Tuple
from openbb_core.app.command_runner import CommandRunner
-from openbb_core.app.model.abstract.error import OpenBBError
from openbb_core.app.model.obbject import OBBject
@@ -24,18 +23,29 @@ class Container:
return obbject
return getattr(obbject, "to_" + output_type)()
+ def _check_credentials(self, provider: str) -> bool:
+ """Check required credentials are populated."""
+ credentials = self._command_runner.user_settings.credentials
+ required = credentials.providers.get(provider, [])
+ current = credentials.model_dump(exclude_none=True)
+ return all(item in current for item in required)
+
def _get_provider(
self, choice: Optional[str], cmd: str, available: Tuple[str, ...]
) -> str:
"""Get the provider to use in execution."""
if choice is None:
- if config_default := self._command_runner.user_settings.defaults.routes.get(
- cmd, {}
- ).get("provider"):
- if config_default in available:
- return config_default
- raise OpenBBError(
- f"provider '{config_default}' is not available. Choose from: {', '.join(available)}."
+ routes = self._command_runner.user_settings.defaults.routes
+ if provider := (routes.get(cmd, {}).get("provider") or available):
+ provider_iterable = (
+ [provider] if isinstance(provider, str) else provider
)
+ for p in provider_iterable:
+ if self._check_credentials(p):
+ return p
+ continue
+ # Warn that that no provider with keys was found
+ # We fallback to the first provider that does not need keys
+ # We fallback to the first provider
return available[0]
return choice
diff --git a/openbb_platform/core/openbb_core/app/static/utils/decorators.py b/openbb_platform/core/openbb_core/app/static/utils/decorators.py
index 8daefae0575..6155e932147 100644
--- a/openbb_platform/core/openbb_core/app/static/utils/decorators.py
+++ b/openbb_platform/core/openbb_core/app/static/utils/decorators.py
@@ -65,7 +65,12 @@ def exception_handler(func: Callable[P, R]) -> Callable[P, R]:
[
str(i)
for i in err.get("loc", ())
- if i not in ("standard_params", "extra_params")
+ if i
+ not in (
+ "standard_params",
+ "extra_params",
+ "provider_choices",
+ )
]
)
_input = err.get("input", "")
diff --git a/openbb_platform/core/openbb_core/provider/registry_map.py b/openbb_platform/core/openbb_core/provider/registry_map.py
index a3a7fcc32d4..c3ccdb1b6f6 100644
--- a/openbb_platform/core/openbb_core/provider/registry_map.py
+++ b/openbb_platform/core/openbb_core/provider/registry_map.py
@@ -40,8 +40,8 @@ class RegistryMap:
return self._available_providers
@property
- def credentials(self) -> List[str]:
- """Get list of required credentials."""
+ def credentials(self) -> Dict[str, List[str]]:
+ """Get map of providers to credentials."""
return self._credentials
@property
@@ -59,13 +59,11 @@ class RegistryMap:
"""Get available models."""
return self._models
- def _get_credentials(self, registry: Registry) -> List[str]:
- """Get list of required credentials."""
- cred_list = []
- for provider in registry.providers.values():
- for c in provider.credentials:
- cred_list.append(c)
- return cred_list
+ def _get_credentials(self, registry: Registry) -> Dict[str, List[str]]:
+ """Get map of providers to credentials."""
+ return {
+ name: provider.credentials for name, provider in registry.providers.items()
+ }
def _get_available_providers(self, registry: Registry) -> List[str]:
"""Get list of available providers."""