summaryrefslogtreecommitdiffstats
path: root/openbb_platform/core/openbb_core/app/query.py
diff options
context:
space:
mode:
Diffstat (limited to 'openbb_platform/core/openbb_core/app/query.py')
-rw-r--r--openbb_platform/core/openbb_core/app/query.py48
1 files changed, 44 insertions, 4 deletions
diff --git a/openbb_platform/core/openbb_core/app/query.py b/openbb_platform/core/openbb_core/app/query.py
index 1a23ea572fb..4867e2005ed 100644
--- a/openbb_platform/core/openbb_core/app/query.py
+++ b/openbb_platform/core/openbb_core/app/query.py
@@ -1,8 +1,10 @@
"""Query class."""
+import warnings
from dataclasses import asdict
-from typing import Any
+from typing import Any, Dict
+from openbb_core.app.model.abstract.warning import OpenBBWarning
from openbb_core.app.model.command_context import CommandContext
from openbb_core.app.provider_interface import (
ExtraParams,
@@ -28,14 +30,52 @@ class Query:
self.standard_params = standard_params
self.extra_params = extra_params
self.name = self.standard_params.__class__.__name__
- self.query_executor = ProviderInterface().create_executor()
+ self.provider_interface = ProviderInterface()
+
+ def filter_extra_params(
+ self,
+ extra_params: ExtraParams,
+ provider_name: str,
+ ) -> Dict[str, Any]:
+ """Filter extra params based on the provider and warn if not supported."""
+ original = asdict(extra_params)
+ filtered = {}
+
+ query = extra_params.__class__.__name__
+ fields = asdict(self.provider_interface.params[query]["extra"]()) # type: ignore
+
+ for k, v in original.items():
+ f = fields[k]
+ providers = f.title.split(",") if hasattr(f, "title") else []
+
+ # We only filter/warn if the value is not the default, because fastapi
+ # Depends always sends the default value, even if it's not in the request.
+ if v != f.default:
+ if provider_name in providers:
+ filtered[k] = v
+ else:
+ available = ", ".join(providers)
+ warnings.warn(
+ message=f"Parameter '{k}' is not supported by {provider_name}. Available for: {available}.",
+ category=OpenBBWarning,
+ )
+
+ return filtered
async def execute(self) -> Any:
"""Execute the query."""
- return await self.query_executor.execute(
+ standard_dict = asdict(self.standard_params)
+ extra_dict = (
+ self.filter_extra_params(self.extra_params, self.provider)
+ if self.extra_params
+ else {}
+ )
+ query_executor = self.provider_interface.create_executor()
+
+ return await query_executor.execute(
provider_name=self.provider,
model_name=self.name,
- params={**asdict(self.standard_params), **asdict(self.extra_params)},
+ params={**standard_dict, **extra_dict},
credentials=self.cc.user_settings.credentials.model_dump(),
preferences=self.cc.user_settings.preferences.model_dump(),
)