diff options
40 files changed, 950 insertions, 1086 deletions
diff --git a/openbb_platform/core/openbb_core/app/command_runner.py b/openbb_platform/core/openbb_core/app/command_runner.py index 89433e45156..e3dce19f465 100644 --- a/openbb_platform/core/openbb_core/app/command_runner.py +++ b/openbb_platform/core/openbb_core/app/command_runner.py @@ -21,7 +21,7 @@ from openbb_core.app.model.metadata import Metadata from openbb_core.app.model.obbject import OBBject from openbb_core.app.model.system_settings import SystemSettings from openbb_core.app.model.user_settings import UserSettings -from openbb_core.app.provider_interface import ExtraParams, ProviderInterface +from openbb_core.app.provider_interface import ExtraParams from openbb_core.app.router import CommandMap from openbb_core.app.service.system_service import SystemService from openbb_core.app.service.user_service import UserService @@ -118,68 +118,6 @@ class ParametersBuilder: return kwargs @staticmethod - def update_provider_choices( - func: Callable, - command_coverage: Dict[str, List[str]], - route: str, - kwargs: Dict[str, Any], - route_default: Optional[Dict[str, Optional[str]]], - ) -> Dict[str, Any]: - """Update the provider choices with the available providers and set default provider.""" - - def _needs_provider(func: Callable) -> bool: - """Check if the function needs a provider.""" - parameters = signature(func).parameters.keys() - return "provider_choices" in parameters - - def _has_provider(kwargs: Dict[str, Any]) -> bool: - """Check if the kwargs already have a provider.""" - provider_choices = kwargs.get("provider_choices") - - if isinstance(provider_choices, dict): # when in python - return provider_choices.get("provider", None) is not None - if isinstance(provider_choices, object): # when running as fastapi - return getattr(provider_choices, "provider", None) is not None - return False - - def _get_first_provider() -> Optional[str]: - """Get the first available provider.""" - available_providers = ProviderInterface().available_providers - return available_providers[0] if available_providers else None - - def _get_default_provider( - command_coverage: Dict[str, List[str]], - route_default: Optional[Dict[str, Optional[str]]], - ) -> Optional[str]: - """ - Get the default provider for the given route. - - Either pick it from the user defaults or from the command coverage. - """ - cmd_cov_given_route = command_coverage.get(route) - command_cov_provider = ( - cmd_cov_given_route[0] if cmd_cov_given_route else None - ) - - if route_default: - return route_default.get("provider", None) or command_cov_provider # type: ignore - - return command_cov_provider - - if not _has_provider(kwargs) and _needs_provider(func): - provider = ( - _get_default_provider( - command_coverage, - route_default, - ) - if route in command_coverage - else _get_first_provider() - ) - kwargs["provider_choices"] = {"provider": provider} - - return kwargs - - @staticmethod def _warn_kwargs( extra_params: Dict[str, Any], model: Type[BaseModel], @@ -246,14 +184,12 @@ class ParametersBuilder: args: Tuple[Any, ...], execution_context: ExecutionContext, func: Callable, - route: str, kwargs: Dict[str, Any], ) -> Dict[str, Any]: """Build the parameters for a function.""" func = cls.get_polished_func(func=func) system_settings = execution_context.system_settings user_settings = execution_context.user_settings - command_map = execution_context.command_map kwargs = cls.merge_args_and_kwargs( func=func, @@ -266,13 +202,6 @@ class ParametersBuilder: system_settings=system_settings, user_settings=user_settings, ) - kwargs = cls.update_provider_choices( - func=func, - command_coverage=command_map.command_coverage, - route=route, - kwargs=kwargs, - route_default=user_settings.defaults.routes.get(route, None), - ) kwargs = cls.validate_kwargs( func=func, kwargs=kwargs, @@ -364,7 +293,6 @@ class StaticCommandRunner: args=args, execution_context=execution_context, func=func, - route=route, kwargs=kwargs, ) diff --git a/openbb_platform/core/openbb_core/app/model/credentials.py b/openbb_platform/core/openbb_core/app/model/credentials.py index 8979db5957a..f54198e9495 100644 --- a/openbb_platform/core/openbb_core/app/model/credentials.py +++ b/openbb_platform/core/openbb_core/app/model/credentials.py @@ -2,7 +2,7 @@ import traceback import warnings -from typing import Dict, Optional, Set, Tuple +from typing import Dict, List, Optional, Tuple from pydantic import ( BaseModel, @@ -36,37 +36,39 @@ OBBSecretStr = Annotated[ class CredentialsLoader: """Here we create the Credentials model.""" - credentials: Dict[str, Set[str]] = {} + credentials: Dict[str, List[str]] = {} - @staticmethod - def prepare( - credentials: Dict[str, Set[str]], - ) -> Dict[str, Tuple[object, None]]: + def format_credentials(self) -> Dict[str, Tuple[object, None]]: """Prepare credentials map to be used in the Credentials model.""" formatted: Dict[str, Tuple[object, None]] = {} - for origin, creds in credentials.items(): - for c in creds: - # Not sure we should do this, if you require the same credential it breaks - # if c in formatted: - # raise ValueError(f"Credential '{c}' already in use.") - formatted[c] = ( + for c_origin, c_list in self.credentials.items(): + for c_name in c_list: + if c_name in formatted: + warnings.warn( + message=f"Skipping '{c_name}', credential already in use.", + category=OpenBBWarning, + ) + continue + formatted[c_name] = ( Optional[OBBSecretStr], - Field( - default=None, description=origin, alias=c.upper() - ), # register the credential origin (obbject, providers) + Field(default=None, description=c_origin, alias=c_name.upper()), ) - return formatted + return dict(sorted(formatted.items())) def from_obbject(self) -> None: """Load credentials from OBBject extensions.""" - self.credentials["obbject"] = set() - for name, entry in ExtensionLoader().obbject_objects.items(): # type: ignore[attr-defined] + for ext_name, ext in ExtensionLoader().obbject_objects.items(): # type: ignore[attr-defined] try: - for c in entry.credentials: - self.credentials["obbject"].add(c) + if ext_name in self.credentials: + warnings.warn( + message=f"Skipping '{ext_name}', name already in user.", + category=OpenBBWarning, + ) + continue + self.credentials[ext_name] = ext.credentials except Exception as e: - msg = f"Error loading extension: {name}\n" + msg = f"Error loading extension: {ext_name}\n" if Env().DEBUG_MODE: traceback.print_exception(type(e), e, e.__traceback__) raise LoadingError(msg + f"\033[91m{e}\033[0m") from e @@ -77,20 +79,20 @@ class CredentialsLoader: def from_providers(self) -> None: """Load credentials from providers.""" - self.credentials["providers"] = set() - for c in ProviderInterface().credentials: - self.credentials["providers"].add(c) + self.credentials = ProviderInterface().credentials def load(self) -> BaseModel: """Load credentials from providers.""" # We load providers first to give them priority choosing credential names self.from_providers() self.from_obbject() - return create_model( # type: ignore + model = create_model( # type: ignore "Credentials", __config__=ConfigDict(validate_assignment=True, populate_by_name=True), - **self.prepare(self.credentials), + **self.format_credentials(), ) + model.origins = self.credentials + return model _Credentials = CredentialsLoader().load() diff --git a/openbb_platform/core/openbb_core/app/model/defaults.py b/openbb_platform/core/openbb_core/app/model/defaults.py index 52007edf5de..1f4a176674f 100644 --- a/openbb_platform/core/openbb_core/app/model/defaults.py +++ b/openbb_platform/core/openbb_core/app/model/defaults.py @@ -1,19 +1,46 @@ """Defaults model.""" -from typing import Dict, Optional +from typing import Dict, List, Optional +from warnings import warn -from pydantic import BaseModel, ConfigDict, Field +from pydantic import BaseModel, ConfigDict, Field, model_validator + +from openbb_core.app.model.abstract.warning import OpenBBWarning class Defaults(BaseModel): """Defaults.""" - model_config = ConfigDict(validate_assignment=True) + model_config = ConfigDict(validate_assignment=True, populate_by_name=True) - routes: Dict[str, Dict[str, Optional[str]]] = Field(default_factory=dict) + commands: Dict[str, Dict[str, Optional[List[str]]]] = Field( + default_factory=dict, + alias="routes", + ) def __repr__(self) -> str: """Return string representation.""" return f"{self.__class__.__name__}\n\n" + "\n".join( f"{k}: {v}" for k, v in self.model_dump().items() ) + + @model_validator(mode="before") + @classmethod + def validate_before(cls, values: dict) -> dict: + """Validate model (before).""" + key = "commands" + if "routes" in values: + warn( + message="'routes' is deprecated. Use 'commands' instead.", + category=OpenBBWarning, + ) + key = "routes" + + new_values: Dict[str, Dict[str, Optional[List[str]]]] = {"commands": {}} + for k, v in values.get(key, {}).items(): + clean_k = k.strip("/").replace("/", ".") + provider = v.get("provider") if v else None + if isinstance(provider, str): + v["provider"] = [provider] + new_values["commands"][clean_k] = v + return new_values diff --git a/openbb_platform/core/openbb_core/app/provider_interface.py b/openbb_platform/core/openbb_core/app/provider_interface.py index 11bb21d5f77..014b7c255d7 100644 --- a/openbb_platform/core/openbb_core/app/provider_interface.py +++ b/openbb_platform/core/openbb_core/app/provider_interface.py @@ -126,8 +126,8 @@ class ProviderInterface(metaclass=SingletonMeta): return self._map @property - def credentials(self) -> List[str]: - """Dictionary of required credentials by provider.""" + def credentials(self) -> Dict[str, List[str]]: + """Map providers to credentials.""" return self._registry_map.credentials @property diff --git a/openbb_platform/core/openbb_core/app/static/container.py b/openbb_platform/core/openbb_core/app/static/container.py index 60fc1def9bb..bcabe6507a2 100644 --- a/openbb_platform/core/openbb_core/app/static/container.py +++ b/openbb_platform/core/openbb_core/app/static/container.py @@ -24,18 +24,59 @@ class Container: return obbject return getattr(obbject, "to_" + output_type)() + def _check_credentials(self, provider: str) -> Optional[bool]: + """Check required credentials are populated.""" + credentials = self._command_runner.user_settings.credentials + if provider not in credentials.origins: + return None + required = credentials.origins.get(provider) + return all(getattr(credentials, r, None) for r in required) + def _get_provider( - self, choice: Optional[str], cmd: str, available: Tuple[str, ...] + self, choice: Optional[str], command: str, default_priority: Tuple[str, ...] ) -> str: - """Get the provider to use in execution.""" + """Get the provider to use in execution. + + If no choice is specified, the configured priority list is used. A provider is used + when all of its required credentials are populated. + + Parameters + ---------- + choice: Optional[str] + The provider choice, for example 'fmp'. + command: str + The command to get the provider for, for example 'equity.price.historical' + default_priority: Tuple[str, ...] + A tuple of available providers for the given command to use as default priority list. + + Returns + ------- + str + The provider to use in the command. + + Raises + ------ + OpenBBError + Raises error when all the providers in the priority list failed. + """ if choice is None: - if config_default := self._command_runner.user_settings.defaults.routes.get( - cmd, {} - ).get("provider"): - if config_default in available: - return config_default - raise OpenBBError( - f"provider '{config_default}' is not available. Choose from: {', '.join(available)}." - ) - return available[0] + commands = self._command_runner.user_settings.defaults.commands + providers = ( + commands.get(command, {}).get("provider", []) or default_priority + ) + tries = [] + for p in providers: + result = self._check_credentials(p) + if result: + return p + elif result is False: + tries.append((p, "missing credentials")) + else: + tries.append((p, "not found")) + + msg = "\n ".join([f"* '{pair[0]}' -> {pair[1]}" for pair in tries]) + raise OpenBBError( + f"Provider fallback failed, please specify the provider or update credentials.\n" + f"[Providers]\n {msg}" + ) return choice diff --git a/openbb_platform/core/openbb_core/app/static/package_builder.py b/openbb_platform/core/openbb_core/app/static/package_builder.py index 758f2be6f9a..84901b0d4d6 100644 --- a/openbb_platform/core/openbb_core/app/static/package_builder.py +++ b/openbb_platform/core/openbb_core/app/static/package_builder.py @@ -599,8 +599,7 @@ class MethodDefinition: fields = param.annotation.__args__[0].__dataclass_fields__ field = fields["provider"] type_ = getattr(field, "type") - args = getattr(type_, "__args__") - first = args[0] if args else None + default_priority = getattr(type_, "__args__") formatted["provider"] = Parameter( name="provider", kind=Parameter.POSITIONAL_OR_KEYWORD, @@ -608,10 +607,9 @@ class MethodDefinition: Optional[MethodDefinition.get_type(field)], OpenBBField( description=( - "The provider to use for the query, by default None.\n" - f" If None, the provider specified in defaults is selected or '{first}' if there is\n" - " no default." - "" + "The provider to use, by default None. " + "If None, the priority list configured in the settings is used. " + f"Default priority: {', '.join(default_priority)}." ), ), ], @@ -828,10 +826,11 @@ class MethodDefinition: elif name == "provider_choices": field = param.annotation.__args__[0].__dataclass_fields__["provider"] available = field.type.__args__ + cmd = path.strip("/").replace("/", ".") code += " provider_choices={\n" code += ' "provider": self._get_provider(\n' code += " provider,\n" - code += f' "{path}",\n' + code += f' "{cmd}",\n' code += f" {available},\n" code += " )\n" code += " },\n" @@ -1397,7 +1396,7 @@ class ReferenceGenerator: ) @classmethod - def _get_provider_parameter_info(cls, model: str) -> Dict[str, str]: + def _get_provider_parameter_info(cls, model: str) -> Dict[str, Any]: """Get the name, type, description, default value and optionality information for the provider parameter. Parameters @@ -1407,7 +1406,7 @@ class ReferenceGenerator: Returns ------- - Dict[str, str] + Dict[str, Any] Dictionary of the provider parameter information |