diff options
Diffstat (limited to 'openbb_platform/core/openbb_core/app/query.py')
-rw-r--r-- | openbb_platform/core/openbb_core/app/query.py | 48 |
1 files changed, 44 insertions, 4 deletions
diff --git a/openbb_platform/core/openbb_core/app/query.py b/openbb_platform/core/openbb_core/app/query.py index 1a23ea572fb..4867e2005ed 100644 --- a/openbb_platform/core/openbb_core/app/query.py +++ b/openbb_platform/core/openbb_core/app/query.py @@ -1,8 +1,10 @@ """Query class.""" +import warnings from dataclasses import asdict -from typing import Any +from typing import Any, Dict +from openbb_core.app.model.abstract.warning import OpenBBWarning from openbb_core.app.model.command_context import CommandContext from openbb_core.app.provider_interface import ( ExtraParams, @@ -28,14 +30,52 @@ class Query: self.standard_params = standard_params self.extra_params = extra_params self.name = self.standard_params.__class__.__name__ - self.query_executor = ProviderInterface().create_executor() + self.provider_interface = ProviderInterface() + + def filter_extra_params( + self, + extra_params: ExtraParams, + provider_name: str, + ) -> Dict[str, Any]: + """Filter extra params based on the provider and warn if not supported.""" + original = asdict(extra_params) + filtered = {} + + query = extra_params.__class__.__name__ + fields = asdict(self.provider_interface.params[query]["extra"]()) # type: ignore + + for k, v in original.items(): + f = fields[k] + providers = f.title.split(",") if hasattr(f, "title") else [] + + # We only filter/warn if the value is not the default, because fastapi + # Depends always sends the default value, even if it's not in the request. + if v != f.default: + if provider_name in providers: + filtered[k] = v + else: + available = ", ".join(providers) + warnings.warn( + message=f"Parameter '{k}' is not supported by {provider_name}. Available for: {available}.", + category=OpenBBWarning, + ) + + return filtered async def execute(self) -> Any: """Execute the query.""" - return await self.query_executor.execute( + standard_dict = asdict(self.standard_params) + extra_dict = ( + self.filter_extra_params(self.extra_params, self.provider) + if self.extra_params + else {} + ) + query_executor = self.provider_interface.create_executor() + + return await query_executor.execute( provider_name=self.provider, model_name=self.name, - params={**asdict(self.standard_params), **asdict(self.extra_params)}, + params={**standard_dict, **extra_dict}, credentials=self.cc.user_settings.credentials.model_dump(), preferences=self.cc.user_settings.preferences.model_dump(), ) |