diff options
author | montezdesousa <79287829+montezdesousa@users.noreply.github.com> | 2024-03-27 15:34:53 +0000 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-03-27 15:34:53 +0000 |
commit | 4b5787b11db395375ddbee0acaccdce79df6e740 (patch) | |
tree | 41039e81529c7e96d5037117b88e149fe1c3970a | |
parent | 657fd1fbcca1ccf908e7c9cc31b34a985c2d4e0e (diff) |
fix: move extra_params warning to query.py (#6259)
-rw-r--r-- | openbb_platform/core/openbb_core/app/command_runner.py | 21 | ||||
-rw-r--r-- | openbb_platform/core/openbb_core/app/query.py | 48 | ||||
-rw-r--r-- | openbb_platform/core/tests/app/test_command_runner.py | 45 | ||||
-rw-r--r-- | openbb_platform/core/tests/app/test_query.py | 26 |
4 files changed, 77 insertions, 63 deletions
diff --git a/openbb_platform/core/openbb_core/app/command_runner.py b/openbb_platform/core/openbb_core/app/command_runner.py index 9f8b52c8714..3a43aba3e8a 100644 --- a/openbb_platform/core/openbb_core/app/command_runner.py +++ b/openbb_platform/core/openbb_core/app/command_runner.py @@ -9,7 +9,6 @@ from time import perf_counter_ns from typing import Any, Callable, Dict, List, Optional, Tuple, Type from warnings import catch_warnings, showwarning, warn -from fastapi.params import Query from pydantic import BaseModel, ConfigDict, create_model from openbb_core.app.logs.logging_service import LoggingService @@ -179,7 +178,6 @@ class ParametersBuilder: @staticmethod def _warn_kwargs( - provider_choices: Dict[str, Any], extra_params: Dict[str, Any], model: Type[BaseModel], ) -> None: @@ -192,25 +190,9 @@ class ParametersBuilder: if is_dataclass(annotation) and any( t is ExtraParams for t in getattr(annotation, "__bases__", []) ): - # We only warn when endpoint defines ExtraParams, so we need - # to check if the annotation is a dataclass and child of ExtraParams valid = asdict(annotation()) # type: ignore - provider = provider_choices.get("provider", None) for p in extra_params: - if field := valid.get(p): - if provider: - providers = ( - field.title - if isinstance(field, Query) and isinstance(field.title, str) - else "" - ).split(",") - if provider not in providers: - warn( - message=f"Parameter '{p}' is not supported by '{provider}'." - f" Available for: {', '.join(providers)}.", - category=OpenBBWarning, - ) - else: + if p not in valid: warn( message=f"Parameter '{p}' not found.", category=OpenBBWarning, @@ -246,7 +228,6 @@ class ParametersBuilder: # Validate and coerce model = ValidationModel(**kwargs) ParametersBuilder._warn_kwargs( - ParametersBuilder._as_dict(kwargs.get("provider_choices", {})), ParametersBuilder._as_dict(kwargs.get("extra_params", {})), ValidationModel, ) diff --git a/openbb_platform/core/openbb_core/app/query.py b/openbb_platform/core/openbb_core/app/query.py index 1a23ea572fb..4867e2005ed 100644 --- a/openbb_platform/core/openbb_core/app/query.py +++ b/openbb_platform/core/openbb_core/app/query.py @@ -1,8 +1,10 @@ """Query class.""" +import warnings from dataclasses import asdict -from typing import Any +from typing import Any, Dict +from openbb_core.app.model.abstract.warning import OpenBBWarning from openbb_core.app.model.command_context import CommandContext from openbb_core.app.provider_interface import ( ExtraParams, @@ -28,14 +30,52 @@ class Query: self.standard_params = standard_params self.extra_params = extra_params self.name = self.standard_params.__class__.__name__ - self.query_executor = ProviderInterface().create_executor() + self.provider_interface = ProviderInterface() + + def filter_extra_params( + self, + extra_params: ExtraParams, + provider_name: str, + ) -> Dict[str, Any]: + """Filter extra params based on the provider and warn if not supported.""" + original = asdict(extra_params) + filtered = {} + + query = extra_params.__class__.__name__ + fields = asdict(self.provider_interface.params[query]["extra"]()) # type: ignore + + for k, v in original.items(): + f = fields[k] + providers = f.title.split(",") if hasattr(f, "title") else [] + + # We only filter/warn if the value is not the default, because fastapi + # Depends always sends the default value, even if it's not in the request. + if v != f.default: + if provider_name in providers: + filtered[k] = v + else: + available = ", ".join(providers) + warnings.warn( + message=f"Parameter '{k}' is not supported by {provider_name}. Available for: {available}.", + category=OpenBBWarning, + ) + + return filtered async def execute(self) -> Any: """Execute the query.""" - return await self.query_executor.execute( + standard_dict = asdict(self.standard_params) + extra_dict = ( + self.filter_extra_params(self.extra_params, self.provider) + if self.extra_params + else {} + ) + query_executor = self.provider_interface.create_executor() + + return await query_executor.execute( provider_name=self.provider, model_name=self.name, - params={**asdict(self.standard_params), **asdict(self.extra_params)}, + params={**standard_dict, **extra_dict}, credentials=self.cc.user_settings.credentials.model_dump(), preferences=self.cc.user_settings.preferences.model_dump(), ) diff --git a/openbb_platform/core/tests/app/test_command_runner.py b/openbb_platform/core/tests/app/test_command_runner.py index cd751e9c592..4b1b35da92d 100644 --- a/openbb_platform/core/tests/app/test_command_runner.py +++ b/openbb_platform/core/tests/app/test_command_runner.py @@ -229,61 +229,28 @@ def test_parameters_builder_validate_kwargs(mock_func): @pytest.mark.parametrize( - "provider_choices, extra_params, base, expect", + "extra_params, base, expect", [ ( - {"provider": "provider1"}, - {"exists_in_2": ...}, + {"exists": ...}, ExtraParams, - OpenBBWarning, - ), - ( - {"provider": "inexistent_provider"}, - {"exists_in_both": ...}, - ExtraParams, - OpenBBWarning, + None, ), ( - {}, {"inexistent_field": ...}, ExtraParams, OpenBBWarning, ), - ( - {}, - {"inexistent_field": ...}, - object, - None, - ), - ( - {"provider": "provider2"}, - {"exists_in_2": ...}, - ExtraParams, - None, - ), - ( - {"provider": "provider2"}, - {"exists_in_both": ...}, - ExtraParams, - None, - ), - ( - {}, - {"exists_in_both": ...}, - ExtraParams, - None, - ), ], ) -def test_parameters_builder__warn_kwargs(provider_choices, extra_params, base, expect): +def test_parameters_builder__warn_kwargs(extra_params, base, expect): """Test _warn_kwargs.""" @dataclass class SomeModel(base): """SomeModel""" - exists_in_2: QueryParam = Query(..., title="provider2") - exists_in_both: QueryParam = Query(..., title="provider1,provider2") + exists: QueryParam = Query(...) class Model(BaseModel): """Model""" @@ -293,7 +260,7 @@ def test_parameters_builder__warn_kwargs(provider_choices, extra_params, base, e with pytest.warns(expect) as warning_info: # pylint: disable=protected-access - ParametersBuilder._warn_kwargs(provider_choices, extra_params, Model) + ParametersBuilder._warn_kwargs(extra_params, Model) if not expect: assert len(warning_info) == 0 diff --git a/openbb_platform/core/tests/app/test_query.py b/openbb_platform/core/tests/app/test_query.py index 95155415db8..082ab43db4e 100644 --- a/openbb_platform/core/tests/app/test_query.py +++ b/openbb_platform/core/tests/app/test_query.py @@ -95,6 +95,32 @@ def query_instance(): ) +def test_filter_extra_params(query): + """Test filter_extra_params.""" + extra_params = create_mock_extra_params() + extra_params = query.filter_extra_params(extra_params, "fmp") + + assert isinstance(extra_params, dict) + assert len(extra_params) == 0 + + +def test_filter_extra_params_wrong_param(query): + """Test filter_extra_params.""" + + @dataclass + class EquityHistorical: + """Mock ExtraParams dataclass.""" + + sort: str = "desc" + limit: int = 4 + + extra_params = EquityHistorical() + + extra = query.filter_extra_params(extra_params, "fmp") + assert isinstance(extra, dict) + assert len(extra) == 0 + + @pytest.mark.asyncio async def test_execute_method_fake_credentials(query_instance: Query, mock_registry): """Test execute method without setting credentials.""" |