diff options
author | montezdesousa <79287829+montezdesousa@users.noreply.github.com> | 2024-03-20 13:16:50 +0000 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-03-20 13:16:50 +0000 |
commit | 7f4007a7514410a1cbd9734f8d06cbaa0e1bc554 (patch) | |
tree | db045b881ac76a1eabbecd40d87ef92c6dfd2876 | |
parent | bab42a0178507f3b259065804f2f03f41d5110fb (diff) |
[Feature] - Warn if inexistent kwargs (#6236)
* feat: warn if wrong kwargs
* fix: remove query test
* fix: add tests and fix bug
---------
Co-authored-by: Igor Radovanovic <74266147+IgorWounds@users.noreply.github.com>
6 files changed, 190 insertions, 161 deletions
diff --git a/openbb_platform/core/openbb_core/app/command_runner.py b/openbb_platform/core/openbb_core/app/command_runner.py index b6d06ed0f8e..c8d7f43e377 100644 --- a/openbb_platform/core/openbb_core/app/command_runner.py +++ b/openbb_platform/core/openbb_core/app/command_runner.py @@ -1,18 +1,20 @@ """Command runner module.""" -import warnings from copy import deepcopy +from dataclasses import asdict, is_dataclass from datetime import datetime from inspect import Parameter, signature from sys import exc_info from time import perf_counter_ns from typing import Any, Callable, Dict, List, Optional, Tuple +from warnings import catch_warnings, showwarning, warn -from pydantic import ConfigDict, create_model +from fastapi.params import Query +from pydantic import BaseModel, ConfigDict, create_model from openbb_core.app.logs.logging_service import LoggingService from openbb_core.app.model.abstract.error import OpenBBError -from openbb_core.app.model.abstract.warning import cast_warning +from openbb_core.app.model.abstract.warning import OpenBBWarning, cast_warning from openbb_core.app.model.command_context import CommandContext from openbb_core.app.model.metadata import Metadata from openbb_core.app.model.obbject import OBBject @@ -176,13 +178,52 @@ class ParametersBuilder: return kwargs @staticmethod + def _warn_kwargs( + provider_choices: Dict[str, Any], + extra_params: Dict[str, Any], + model: BaseModel, + ) -> None: + """Warn if kwargs received and ignored by the validation model.""" + # We only check the extra_params annotation because ignored fields + # will always be kwargs + annotation = getattr( + model.model_fields.get("extra_params", None), "annotation", None + ) + if annotation: + # When there is no annotation there is nothing to warn + valid = asdict(annotation()) if is_dataclass(annotation) else {} # type: ignore + provider = provider_choices.get("provider", None) + for p in extra_params: + if field := valid.get(p): + if provider: + providers = ( + field.title + if isinstance(field, Query) and isinstance(field.title, str) + else "" + ).split(",") + if provider not in providers: + warn( + message=f"Parameter '{p}' is not supported by '{provider}'." + f" Available for: {', '.join(providers)}.", + category=OpenBBWarning, + ) + else: + warn( + message=f"Parameter '{p}' not found.", + category=OpenBBWarning, + ) + + @staticmethod + def _as_dict(obj: Any) -> Dict[str, Any]: + """Safely convert an object to a dict.""" + return asdict(obj) if is_dataclass(obj) else dict(obj) + + @staticmethod def validate_kwargs( func: Callable, kwargs: Dict[str, Any], ) -> Dict[str, Any]: """Validate kwargs and if possible coerce to the correct type.""" - config = ConfigDict(extra="allow", arbitrary_types_allowed=True) - sig = signature(func) fields = { n: ( @@ -191,11 +232,17 @@ class ParametersBuilder: ) for n, p in sig.parameters.items() } + # We allow extra fields to return with model with 'cc: CommandContext' + config = ConfigDict(extra="allow", arbitrary_types_allowed=True) ValidationModel = create_model(func.__name__, __config__=config, **fields) # type: ignore + # Validate and coerce model = ValidationModel(**kwargs) - result = dict(model) - - return result + ParametersBuilder._warn_kwargs( + ParametersBuilder._as_dict(kwargs.get("provider_choices", {})), + ParametersBuilder._as_dict(kwargs.get("extra_params", {})), + ValidationModel, + ) + return dict(model) @classmethod def build( @@ -230,7 +277,10 @@ class ParametersBuilder: kwargs=kwargs, route_default=user_settings.defaults.routes.get(route, None), ) - kwargs = cls.validate_kwargs(func=func, kwargs=kwargs) + kwargs = cls.validate_kwargs( + func=func, + kwargs=kwargs, + ) return kwargs @@ -242,30 +292,12 @@ class StaticCommandRunner: cls, func: Callable, kwargs: Dict[str, Any], - show_warnings: bool = True, ) -> OBBject: """Run a command and return the output.""" - - with warnings.catch_warnings(record=True) as warning_list: - obbject = await maybe_coroutine(func, **kwargs) - obbject.provider = getattr( - kwargs.get("provider_choices", None), "provider", None - ) - - if warning_list: - obbject.warnings = [] - for w in warning_list: - obbject.warnings.append(cast_warning(w)) - if show_warnings: - warnings.showwarning( - message=w.message, - category=w.category, - filename=w.filename, - lineno=w.lineno, - file=w.file, - line=w.line, - ) - + obbject = await maybe_coroutine(func, **kwargs) + obbject.provider = getattr( + kwargs.get("provider_choices", None), "provider", None + ) return obbject @classmethod @@ -294,68 +326,70 @@ class StaticCommandRunner: user_settings = execution_context.user_settings system_settings = execution_context.system_settings - # If we're on Jupyter we need to pop here because we will lose "chart" after - # ParametersBuilder.build. This needs to be fixed in a way that chart is - # added to the function signature and shared for jupyter and api - # We can check in the router decorator if the given function has a chart - # in the charting extension then we add it there. This way we can remove - # the chart parameter from the commands.py and package_builder, it will be - # added to the function signature in the router decorator - chart = kwargs.pop("chart", False) + with catch_warnings(record=True) as warning_list: + # If we're on Jupyter we need to pop here because we will lose "chart" after + # ParametersBuilder.build. This needs to be fixed in a way that chart is + # added to the function signature and shared for jupyter and api + # We can check in the router decorator if the given function has a chart + # in the charting extension then we add it there. This way we can remove + # the chart parameter from the commands.py and package_builder, it will be + # added to the function signature in the router decorator + chart = kwargs.pop("chart", False) + + kwargs = ParametersBuilder.build( + args=args, + execution_context=execution_context, + func=func, + route=route, + kwargs=kwargs, + ) - kwargs = ParametersBuilder.build( - args=args, - execution_context=execution_context, - func=func, - route=route, - kwargs=kwargs, - ) + # If we're on the api we need to remove "chart" here because the parameter is added on + # commands.py and the function signature does not expect "chart" + kwargs.pop("chart", None) + # We also pop custom headers + model_headers = system_settings.api_settings.custom_headers or {} + custom_headers = { + name: kwargs.pop(name.replace("-", "_"), default) + for name, default in model_headers.items() or {} + } or None - # If we're on the api we need to remove "chart" here because the parameter is added on - # commands.py and the function signature does not expect "chart" - kwargs.pop("chart", None) - # We also pop custom headers + try: + obbject = await cls._command(func, kwargs) + # pylint: disable=protected-access + obbject._route = route + obbject._standard_params = kwargs.get("standard_params", None) - model_headers = ( - SystemService().system_settings.api_settings.custom_headers or {} - ) - custom_headers = { - name: kwargs.pop(name.replace("-", "_"), default) - for name, default in model_headers.items() or {} - } or None + if chart and obbject.results: + cls._chart(obbject, **kwargs) - try: - obbject = await cls._command( - func=func, - kwargs=kwargs, - show_warnings=user_settings.preferences.show_warnings, - ) - # pylint: disable=protected-access - obbject._route = route - obbject._standard_params = kwargs.get("standard_params", None) - - if chart and obbject.results: - cls._chart( - obbject=obbject, - **kwargs, + except Exception as e: + raise OpenBBError(e) from e + finally: + ls = LoggingService(system_settings, user_settings) + ls.log( + user_settings=user_settings, + system_settings=system_settings, + route=route, + func=func, + kwargs=kwargs, + exec_info=exc_info(), + custom_headers=custom_headers, ) - except Exception as e: - raise OpenBBError(e) from e - finally: - ls = LoggingService( - user_settings=user_settings, system_settings=system_settings - ) - ls.log( - user_settings=user_settings, - system_settings=system_settings, - route=route, - func=func, - kwargs=kwargs, - exec_info=exc_info(), - custom_headers=custom_headers, - ) - + if warning_list: + obbject.warnings = [] + for w in warning_list: + obbject.warnings.append(cast_warning(w)) + if user_settings.preferences.show_warnings: + showwarning( + message=w.message, + category=w.category, + filename=w.filename, + lineno=w.lineno, + file=w.file, + line=w.line, + ) return obbject @classmethod diff --git a/openbb_platform/core/openbb_core/app/model/user_settings.py b/openbb_platform/core/openbb_core/app/model/user_settings.py index 1df4c286c58..ea3ada12541 100644 --- a/openbb_platform/core/openbb_core/app/model/user_settings.py +++ b/openbb_platform/core/openbb_core/app/model/user_settings.py @@ -19,8 +19,6 @@ class UserSettings(Tagged): def __repr__(self) -> str: """Human readable representation of the object.""" - # We use the __dict__ because Credentials.model_dump() will use the serializer - # and unmask the credentials return f"{self.__class__.__name__}\n\n" + "\n".join( f"{k}: {v}" for k, v in self.model_dump().items() ) diff --git a/openbb_platform/core/openbb_core/app/query.py b/openbb_platform/core/openbb_core/app/query.py index 3f2cfc878d0..1a23ea572fb 100644 --- a/openbb_platform/core/openbb_core/app/query.py +++ b/openbb_platform/core/openbb_core/app/query.py @@ -1,10 +1,8 @@ """Query class.""" -import warnings from dataclasses import asdict -from typing import Any, Dict +from typing import Any -from openbb_core.app.model.abstract.warning import OpenBBWarning from openbb_core.app.model.command_context import CommandContext from openbb_core.app.provider_interface import ( ExtraParams, @@ -30,49 +28,14 @@ class Query: self.standard_params = standard_params self.extra_params = extra_params self.name = self.standard_params.__class__.__name__ - self.provider_interface = ProviderInterface() - - def filter_extra_params( - self, - extra_params: ExtraParams, - provider_name: str, - ) -> Dict[str, Any]: - """Filter extra params based on the provider and warn if not supported.""" - original = asdict(extra_params) - filtered = {} - - query = extra_params.__class__.__name__ - fields = asdict(self.provider_interface.params[query]["extra"]()) # type: ignore - - for k, v in original.items(): - f = fields[k] - providers = f.title.split(",") if hasattr(f, "title") else [] - if v != f.default: - if provider_name in providers: - filtered[k] = v - else: - available = ", ".join(providers) - warnings.warn( - message=f"Parameter '{k}' is not supported by {provider_name}. Available for: {available}.", - category=OpenBBWarning, - ) - - return filtered + self.query_executor = ProviderInterface().create_executor() async def execute(self) -> Any: """Execute the query.""" - standard_dict = asdict(self.standard_params) - extra_dict = ( - self.filter_extra_params(self.extra_params, self.provider) - if self.extra_params - else {} - ) - query_executor = self.provider_interface.create_executor() - - return await query_executor.execute( + return await self.query_executor.execute( provider_name=self.provider, model_name=self.name, - params={**standard_dict, **extra_dict}, + params={**asdict(self.standard_params), **asdict(self.extra_params)}, credentials=self.cc.user_settings.credentials.model_dump(), preferences=self.cc.user_settings.preferences.model_dump(), ) diff --git a/openbb_platform/core/openbb_core/provider/query_executor.py b/openbb_platform/core/openbb_core/provider/query_executor.py index 5d2458b9f54..7a742a0c3bd 100644 --- a/openbb_platform/core/openbb_core/provider/query_executor.py +++ b/openbb_platform/core/openbb_core/provider/query_executor.py @@ -93,8 +93,4 @@ class QueryExecutor: filtered_credentials = self.filter_credentials( credentials, provider, fetcher.require_credentials ) - - try: - return await fetcher.fetch_data(params, filtered_credentials, **kwargs) - except Exception as e: - raise OpenBBError(e) from e + return await fetcher.fetch_data(params, filtered_credentials, **kwargs) diff --git a/openbb_platform/core/tests/app/test_command_runner.py b/openbb_platform/core/tests/app/test_command_runner.py index 00659a6e0c5..97bfc6df04d 100644 --- a/openbb_platform/core/tests/app/test_command_runner.py +++ b/openbb_platform/core/tests/app/test_command_runner.py @@ -1,18 +1,23 @@ +from dataclasses import dataclass from inspect import Parameter from typing import Dict, List from unittest.mock import Mock, patch import pytest +from fastapi import Query +from fastapi.params import Query as QueryParam from openbb_core.app.command_runner import ( CommandRunner, ExecutionContext, ParametersBuilder, StaticCommandRunner, ) +from openbb_core.app.model.abstract.warning import OpenBBWarning from openbb_core.app.model.command_context import CommandContext from openbb_core.app.model.system_settings import SystemSettings from openbb_core.app.model.user_settings import UserSettings from openbb_core.app.router import CommandMap +from pydantic import BaseModel, ConfigDict @pytest.fixture() @@ -224,6 +229,65 @@ def test_parameters_builder_validate_kwargs(mock_func): assert result == {"a": 1, "b": 2, "c": 3.0, "d": 4, "provider_choices": {}} +@pytest.mark.parametrize( + "provider_choices, extra_params, expect", + [ + ( + {"provider": "provider1"}, + {"exists_in_2": ...}, + OpenBBWarning, + ), + ( + {"provider": "inexistent_provider"}, + {"exists_in_both": ...}, + OpenBBWarning, + ), + ( + {}, + {"inexistent_field": ...}, + OpenBBWarning, + ), + ( + {"provider": "provider2"}, + {"exists_in_2": ...}, + None, + ), + ( + {"provider": "provider2"}, + {"exists_in_both": ...}, + None, + ), + ( + {}, + {"exists_in_both": ...}, + None, + ), + ], +) +def test_parameters_builder__warn_kwargs(provider_choices, extra_params, expect): + """Test _warn_kwargs.""" + + @dataclass + class SomeModel: + """SomeModel""" + + exists_in_2: QueryParam = Query(..., title="provider2") + exists_in_both: QueryParam = Query(..., title="provider1,provider2") + + class Model(BaseModel): + """Model""" + + model_config = ConfigDict(arbitrary_types_allowed=True) + extra_params: SomeModel + + with pytest.warns(expect) as warning_info: + # pylint: disable=protected-access + ParametersBuilder._warn_kwargs(provider_choices, extra_params, Model) + + if not expect: + assert len(warning_info) == 0 + + def test_parameters_builder_build(mock_func, execution_context): """Test build.""" diff --git a/openbb_platform/core/tests/app/test_query.py b/openbb_platform/core/tests/app/test_query.py index 42cee25f79b..95155415db8 100644 --- a/openbb_platform/core/tests/app/test_query.py +++ b/openbb_platform/core/tests/app/test_query.py @@ -63,32 +63,6 @@ def test_init(query): assert query -def test_filter_extra_params(query): - """Test filter_extra_params.""" - extra_params = create_mock_extra_params() - extra_params = query.filter_extra_params(extra_params, "fmp") - - assert isinstance(extra_params, dict) - assert len(extra_params) == 0 - - -def test_filter_extra_params_wrong_param(query): - """Test filter_extra_params.""" - - @dataclass - class EquityHistorical: - """Mock ExtraParams dataclass.""" - - sort: str = "desc" - limit: int = 4 - - extra_params = EquityHistorical() - - extra = query.filter_extra_params(extra_params, "fmp") - assert isinstance(extra, dict) - assert len(extra) == 0 - - @pytest.fixture def mock_registry(): """Mock registry.""" |