summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authormontezdesousa <79287829+montezdesousa@users.noreply.github.com>2024-03-27 15:34:53 +0000
committerGitHub <noreply@github.com>2024-03-27 15:34:53 +0000
commit4b5787b11db395375ddbee0acaccdce79df6e740 (patch)
tree41039e81529c7e96d5037117b88e149fe1c3970a
parent657fd1fbcca1ccf908e7c9cc31b34a985c2d4e0e (diff)
fix: move extra_params warning to query.py (#6259)
-rw-r--r--openbb_platform/core/openbb_core/app/command_runner.py21
-rw-r--r--openbb_platform/core/openbb_core/app/query.py48
-rw-r--r--openbb_platform/core/tests/app/test_command_runner.py45
-rw-r--r--openbb_platform/core/tests/app/test_query.py26
4 files changed, 77 insertions, 63 deletions
diff --git a/openbb_platform/core/openbb_core/app/command_runner.py b/openbb_platform/core/openbb_core/app/command_runner.py
index 9f8b52c8714..3a43aba3e8a 100644
--- a/openbb_platform/core/openbb_core/app/command_runner.py
+++ b/openbb_platform/core/openbb_core/app/command_runner.py
@@ -9,7 +9,6 @@ from time import perf_counter_ns
from typing import Any, Callable, Dict, List, Optional, Tuple, Type
from warnings import catch_warnings, showwarning, warn
-from fastapi.params import Query
from pydantic import BaseModel, ConfigDict, create_model
from openbb_core.app.logs.logging_service import LoggingService
@@ -179,7 +178,6 @@ class ParametersBuilder:
@staticmethod
def _warn_kwargs(
- provider_choices: Dict[str, Any],
extra_params: Dict[str, Any],
model: Type[BaseModel],
) -> None:
@@ -192,25 +190,9 @@ class ParametersBuilder:
if is_dataclass(annotation) and any(
t is ExtraParams for t in getattr(annotation, "__bases__", [])
):
- # We only warn when endpoint defines ExtraParams, so we need
- # to check if the annotation is a dataclass and child of ExtraParams
valid = asdict(annotation()) # type: ignore
- provider = provider_choices.get("provider", None)
for p in extra_params:
- if field := valid.get(p):
- if provider:
- providers = (
- field.title
- if isinstance(field, Query) and isinstance(field.title, str)
- else ""
- ).split(",")
- if provider not in providers:
- warn(
- message=f"Parameter '{p}' is not supported by '{provider}'."
- f" Available for: {', '.join(providers)}.",
- category=OpenBBWarning,
- )
- else:
+ if p not in valid:
warn(
message=f"Parameter '{p}' not found.",
category=OpenBBWarning,
@@ -246,7 +228,6 @@ class ParametersBuilder:
# Validate and coerce
model = ValidationModel(**kwargs)
ParametersBuilder._warn_kwargs(
- ParametersBuilder._as_dict(kwargs.get("provider_choices", {})),
ParametersBuilder._as_dict(kwargs.get("extra_params", {})),
ValidationModel,
)
diff --git a/openbb_platform/core/openbb_core/app/query.py b/openbb_platform/core/openbb_core/app/query.py
index 1a23ea572fb..4867e2005ed 100644
--- a/openbb_platform/core/openbb_core/app/query.py
+++ b/openbb_platform/core/openbb_core/app/query.py
@@ -1,8 +1,10 @@
"""Query class."""
+import warnings
from dataclasses import asdict
-from typing import Any
+from typing import Any, Dict
+from openbb_core.app.model.abstract.warning import OpenBBWarning
from openbb_core.app.model.command_context import CommandContext
from openbb_core.app.provider_interface import (
ExtraParams,
@@ -28,14 +30,52 @@ class Query:
self.standard_params = standard_params
self.extra_params = extra_params
self.name = self.standard_params.__class__.__name__
- self.query_executor = ProviderInterface().create_executor()
+ self.provider_interface = ProviderInterface()
+
+ def filter_extra_params(
+ self,
+ extra_params: ExtraParams,
+ provider_name: str,
+ ) -> Dict[str, Any]:
+ """Filter extra params based on the provider and warn if not supported."""
+ original = asdict(extra_params)
+ filtered = {}
+
+ query = extra_params.__class__.__name__
+ fields = asdict(self.provider_interface.params[query]["extra"]()) # type: ignore
+
+ for k, v in original.items():
+ f = fields[k]
+ providers = f.title.split(",") if hasattr(f, "title") else []
+
+ # We only filter/warn if the value is not the default, because fastapi
+ # Depends always sends the default value, even if it's not in the request.
+ if v != f.default:
+ if provider_name in providers:
+ filtered[k] = v
+ else:
+ available = ", ".join(providers)
+ warnings.warn(
+ message=f"Parameter '{k}' is not supported by {provider_name}. Available for: {available}.",
+ category=OpenBBWarning,
+ )
+
+ return filtered
async def execute(self) -> Any:
"""Execute the query."""
- return await self.query_executor.execute(
+ standard_dict = asdict(self.standard_params)
+ extra_dict = (
+ self.filter_extra_params(self.extra_params, self.provider)
+ if self.extra_params
+ else {}
+ )
+ query_executor = self.provider_interface.create_executor()
+
+ return await query_executor.execute(
provider_name=self.provider,
model_name=self.name,
- params={**asdict(self.standard_params), **asdict(self.extra_params)},
+ params={**standard_dict, **extra_dict},
credentials=self.cc.user_settings.credentials.model_dump(),
preferences=self.cc.user_settings.preferences.model_dump(),
)
diff --git a/openbb_platform/core/tests/app/test_command_runner.py b/openbb_platform/core/tests/app/test_command_runner.py
index cd751e9c592..4b1b35da92d 100644
--- a/openbb_platform/core/tests/app/test_command_runner.py
+++ b/openbb_platform/core/tests/app/test_command_runner.py
@@ -229,61 +229,28 @@ def test_parameters_builder_validate_kwargs(mock_func):
@pytest.mark.parametrize(
- "provider_choices, extra_params, base, expect",
+ "extra_params, base, expect",
[
(
- {"provider": "provider1"},
- {"exists_in_2": ...},
+ {"exists": ...},
ExtraParams,
- OpenBBWarning,
- ),
- (
- {"provider": "inexistent_provider"},
- {"exists_in_both": ...},
- ExtraParams,
- OpenBBWarning,
+ None,
),
(
- {},
{"inexistent_field": ...},
ExtraParams,
OpenBBWarning,
),
- (
- {},
- {"inexistent_field": ...},
- object,
- None,
- ),
- (
- {"provider": "provider2"},
- {"exists_in_2": ...},
- ExtraParams,
- None,
- ),
- (
- {"provider": "provider2"},
- {"exists_in_both": ...},
- ExtraParams,
- None,
- ),
- (
- {},
- {"exists_in_both": ...},
- ExtraParams,
- None,
- ),
],
)
-def test_parameters_builder__warn_kwargs(provider_choices, extra_params, base, expect):
+def test_parameters_builder__warn_kwargs(extra_params, base, expect):
"""Test _warn_kwargs."""
@dataclass
class SomeModel(base):
"""SomeModel"""
- exists_in_2: QueryParam = Query(..., title="provider2")
- exists_in_both: QueryParam = Query(..., title="provider1,provider2")
+ exists: QueryParam = Query(...)
class Model(BaseModel):
"""Model"""
@@ -293,7 +260,7 @@ def test_parameters_builder__warn_kwargs(provider_choices, extra_params, base, e
with pytest.warns(expect) as warning_info:
# pylint: disable=protected-access
- ParametersBuilder._warn_kwargs(provider_choices, extra_params, Model)
+ ParametersBuilder._warn_kwargs(extra_params, Model)
if not expect:
assert len(warning_info) == 0
diff --git a/openbb_platform/core/tests/app/test_query.py b/openbb_platform/core/tests/app/test_query.py
index 95155415db8..082ab43db4e 100644
--- a/openbb_platform/core/tests/app/test_query.py
+++ b/openbb_platform/core/tests/app/test_query.py
@@ -95,6 +95,32 @@ def query_instance():
)
+def test_filter_extra_params(query):
+ """Test filter_extra_params."""
+ extra_params = create_mock_extra_params()
+ extra_params = query.filter_extra_params(extra_params, "fmp")
+
+ assert isinstance(extra_params, dict)
+ assert len(extra_params) == 0
+
+
+def test_filter_extra_params_wrong_param(query):
+ """Test filter_extra_params."""
+
+ @dataclass
+ class EquityHistorical:
+ """Mock ExtraParams dataclass."""
+
+ sort: str = "desc"
+ limit: int = 4
+
+ extra_params = EquityHistorical()
+
+ extra = query.filter_extra_params(extra_params, "fmp")
+ assert isinstance(extra, dict)
+ assert len(extra) == 0
+
+
@pytest.mark.asyncio
async def test_execute_method_fake_credentials(query_instance: Query, mock_registry):
"""Test execute method without setting credentials."""