summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authormontezdesousa <79287829+montezdesousa@users.noreply.github.com>2024-03-20 13:16:50 +0000
committerGitHub <noreply@github.com>2024-03-20 13:16:50 +0000
commit7f4007a7514410a1cbd9734f8d06cbaa0e1bc554 (patch)
treedb045b881ac76a1eabbecd40d87ef92c6dfd2876
parentbab42a0178507f3b259065804f2f03f41d5110fb (diff)
[Feature] - Warn if inexistent kwargs (#6236)
* feat: warn if wrong kwargs * fix: remove query test * fix: add tests and fix bug --------- Co-authored-by: Igor Radovanovic <74266147+IgorWounds@users.noreply.github.com>
-rw-r--r--openbb_platform/core/openbb_core/app/command_runner.py208
-rw-r--r--openbb_platform/core/openbb_core/app/model/user_settings.py2
-rw-r--r--openbb_platform/core/openbb_core/app/query.py45
-rw-r--r--openbb_platform/core/openbb_core/provider/query_executor.py6
-rw-r--r--openbb_platform/core/tests/app/test_command_runner.py64
-rw-r--r--openbb_platform/core/tests/app/test_query.py26
6 files changed, 190 insertions, 161 deletions
diff --git a/openbb_platform/core/openbb_core/app/command_runner.py b/openbb_platform/core/openbb_core/app/command_runner.py
index b6d06ed0f8e..c8d7f43e377 100644
--- a/openbb_platform/core/openbb_core/app/command_runner.py
+++ b/openbb_platform/core/openbb_core/app/command_runner.py
@@ -1,18 +1,20 @@
"""Command runner module."""
-import warnings
from copy import deepcopy
+from dataclasses import asdict, is_dataclass
from datetime import datetime
from inspect import Parameter, signature
from sys import exc_info
from time import perf_counter_ns
from typing import Any, Callable, Dict, List, Optional, Tuple
+from warnings import catch_warnings, showwarning, warn
-from pydantic import ConfigDict, create_model
+from fastapi.params import Query
+from pydantic import BaseModel, ConfigDict, create_model
from openbb_core.app.logs.logging_service import LoggingService
from openbb_core.app.model.abstract.error import OpenBBError
-from openbb_core.app.model.abstract.warning import cast_warning
+from openbb_core.app.model.abstract.warning import OpenBBWarning, cast_warning
from openbb_core.app.model.command_context import CommandContext
from openbb_core.app.model.metadata import Metadata
from openbb_core.app.model.obbject import OBBject
@@ -176,13 +178,52 @@ class ParametersBuilder:
return kwargs
@staticmethod
+ def _warn_kwargs(
+ provider_choices: Dict[str, Any],
+ extra_params: Dict[str, Any],
+ model: BaseModel,
+ ) -> None:
+ """Warn if kwargs received and ignored by the validation model."""
+ # We only check the extra_params annotation because ignored fields
+ # will always be kwargs
+ annotation = getattr(
+ model.model_fields.get("extra_params", None), "annotation", None
+ )
+ if annotation:
+ # When there is no annotation there is nothing to warn
+ valid = asdict(annotation()) if is_dataclass(annotation) else {} # type: ignore
+ provider = provider_choices.get("provider", None)
+ for p in extra_params:
+ if field := valid.get(p):
+ if provider:
+ providers = (
+ field.title
+ if isinstance(field, Query) and isinstance(field.title, str)
+ else ""
+ ).split(",")
+ if provider not in providers:
+ warn(
+ message=f"Parameter '{p}' is not supported by '{provider}'."
+ f" Available for: {', '.join(providers)}.",
+ category=OpenBBWarning,
+ )
+ else:
+ warn(
+ message=f"Parameter '{p}' not found.",
+ category=OpenBBWarning,
+ )
+
+ @staticmethod
+ def _as_dict(obj: Any) -> Dict[str, Any]:
+ """Safely convert an object to a dict."""
+ return asdict(obj) if is_dataclass(obj) else dict(obj)
+
+ @staticmethod
def validate_kwargs(
func: Callable,
kwargs: Dict[str, Any],
) -> Dict[str, Any]:
"""Validate kwargs and if possible coerce to the correct type."""
- config = ConfigDict(extra="allow", arbitrary_types_allowed=True)
-
sig = signature(func)
fields = {
n: (
@@ -191,11 +232,17 @@ class ParametersBuilder:
)
for n, p in sig.parameters.items()
}
+ # We allow extra fields to return with model with 'cc: CommandContext'
+ config = ConfigDict(extra="allow", arbitrary_types_allowed=True)
ValidationModel = create_model(func.__name__, __config__=config, **fields) # type: ignore
+ # Validate and coerce
model = ValidationModel(**kwargs)
- result = dict(model)
-
- return result
+ ParametersBuilder._warn_kwargs(
+ ParametersBuilder._as_dict(kwargs.get("provider_choices", {})),
+ ParametersBuilder._as_dict(kwargs.get("extra_params", {})),
+ ValidationModel,
+ )
+ return dict(model)
@classmethod
def build(
@@ -230,7 +277,10 @@ class ParametersBuilder:
kwargs=kwargs,
route_default=user_settings.defaults.routes.get(route, None),
)
- kwargs = cls.validate_kwargs(func=func, kwargs=kwargs)
+ kwargs = cls.validate_kwargs(
+ func=func,
+ kwargs=kwargs,
+ )
return kwargs
@@ -242,30 +292,12 @@ class StaticCommandRunner:
cls,
func: Callable,
kwargs: Dict[str, Any],
- show_warnings: bool = True,
) -> OBBject:
"""Run a command and return the output."""
-
- with warnings.catch_warnings(record=True) as warning_list:
- obbject = await maybe_coroutine(func, **kwargs)
- obbject.provider = getattr(
- kwargs.get("provider_choices", None), "provider", None
- )
-
- if warning_list:
- obbject.warnings = []
- for w in warning_list:
- obbject.warnings.append(cast_warning(w))
- if show_warnings:
- warnings.showwarning(
- message=w.message,
- category=w.category,
- filename=w.filename,
- lineno=w.lineno,
- file=w.file,
- line=w.line,
- )
-
+ obbject = await maybe_coroutine(func, **kwargs)
+ obbject.provider = getattr(
+ kwargs.get("provider_choices", None), "provider", None
+ )
return obbject
@classmethod
@@ -294,68 +326,70 @@ class StaticCommandRunner:
user_settings = execution_context.user_settings
system_settings = execution_context.system_settings
- # If we're on Jupyter we need to pop here because we will lose "chart" after
- # ParametersBuilder.build. This needs to be fixed in a way that chart is
- # added to the function signature and shared for jupyter and api
- # We can check in the router decorator if the given function has a chart
- # in the charting extension then we add it there. This way we can remove
- # the chart parameter from the commands.py and package_builder, it will be
- # added to the function signature in the router decorator
- chart = kwargs.pop("chart", False)
+ with catch_warnings(record=True) as warning_list:
+ # If we're on Jupyter we need to pop here because we will lose "chart" after
+ # ParametersBuilder.build. This needs to be fixed in a way that chart is
+ # added to the function signature and shared for jupyter and api
+ # We can check in the router decorator if the given function has a chart
+ # in the charting extension then we add it there. This way we can remove
+ # the chart parameter from the commands.py and package_builder, it will be
+ # added to the function signature in the router decorator
+ chart = kwargs.pop("chart", False)
+
+ kwargs = ParametersBuilder.build(
+ args=args,
+ execution_context=execution_context,
+ func=func,
+ route=route,
+ kwargs=kwargs,
+ )
- kwargs = ParametersBuilder.build(
- args=args,
- execution_context=execution_context,
- func=func,
- route=route,
- kwargs=kwargs,
- )
+ # If we're on the api we need to remove "chart" here because the parameter is added on
+ # commands.py and the function signature does not expect "chart"
+ kwargs.pop("chart", None)
+ # We also pop custom headers
+ model_headers = system_settings.api_settings.custom_headers or {}
+ custom_headers = {
+ name: kwargs.pop(name.replace("-", "_"), default)
+ for name, default in model_headers.items() or {}
+ } or None
- # If we're on the api we need to remove "chart" here because the parameter is added on
- # commands.py and the function signature does not expect "chart"
- kwargs.pop("chart", None)
- # We also pop custom headers
+ try:
+ obbject = await cls._command(func, kwargs)
+ # pylint: disable=protected-access
+ obbject._route = route
+ obbject._standard_params = kwargs.get("standard_params", None)
- model_headers = (
- SystemService().system_settings.api_settings.custom_headers or {}
- )
- custom_headers = {
- name: kwargs.pop(name.replace("-", "_"), default)
- for name, default in model_headers.items() or {}
- } or None
+ if chart and obbject.results:
+ cls._chart(obbject, **kwargs)
- try:
- obbject = await cls._command(
- func=func,
- kwargs=kwargs,
- show_warnings=user_settings.preferences.show_warnings,
- )
- # pylint: disable=protected-access
- obbject._route = route
- obbject._standard_params = kwargs.get("standard_params", None)
-
- if chart and obbject.results:
- cls._chart(
- obbject=obbject,
- **kwargs,
+ except Exception as e:
+ raise OpenBBError(e) from e
+ finally:
+ ls = LoggingService(system_settings, user_settings)
+ ls.log(
+ user_settings=user_settings,
+ system_settings=system_settings,
+ route=route,
+ func=func,
+ kwargs=kwargs,
+ exec_info=exc_info(),
+ custom_headers=custom_headers,
)
- except Exception as e:
- raise OpenBBError(e) from e
- finally:
- ls = LoggingService(
- user_settings=user_settings, system_settings=system_settings
- )
- ls.log(
- user_settings=user_settings,
- system_settings=system_settings,
- route=route,
- func=func,
- kwargs=kwargs,
- exec_info=exc_info(),
- custom_headers=custom_headers,
- )
-
+ if warning_list:
+ obbject.warnings = []
+ for w in warning_list:
+ obbject.warnings.append(cast_warning(w))
+ if user_settings.preferences.show_warnings:
+ showwarning(
+ message=w.message,
+ category=w.category,
+ filename=w.filename,
+ lineno=w.lineno,
+ file=w.file,
+ line=w.line,
+ )
return obbject
@classmethod
diff --git a/openbb_platform/core/openbb_core/app/model/user_settings.py b/openbb_platform/core/openbb_core/app/model/user_settings.py
index 1df4c286c58..ea3ada12541 100644
--- a/openbb_platform/core/openbb_core/app/model/user_settings.py
+++ b/openbb_platform/core/openbb_core/app/model/user_settings.py
@@ -19,8 +19,6 @@ class UserSettings(Tagged):
def __repr__(self) -> str:
"""Human readable representation of the object."""
- # We use the __dict__ because Credentials.model_dump() will use the serializer
- # and unmask the credentials
return f"{self.__class__.__name__}\n\n" + "\n".join(
f"{k}: {v}" for k, v in self.model_dump().items()
)
diff --git a/openbb_platform/core/openbb_core/app/query.py b/openbb_platform/core/openbb_core/app/query.py
index 3f2cfc878d0..1a23ea572fb 100644
--- a/openbb_platform/core/openbb_core/app/query.py
+++ b/openbb_platform/core/openbb_core/app/query.py
@@ -1,10 +1,8 @@
"""Query class."""
-import warnings
from dataclasses import asdict
-from typing import Any, Dict
+from typing import Any
-from openbb_core.app.model.abstract.warning import OpenBBWarning
from openbb_core.app.model.command_context import CommandContext
from openbb_core.app.provider_interface import (
ExtraParams,
@@ -30,49 +28,14 @@ class Query:
self.standard_params = standard_params
self.extra_params = extra_params
self.name = self.standard_params.__class__.__name__
- self.provider_interface = ProviderInterface()
-
- def filter_extra_params(
- self,
- extra_params: ExtraParams,
- provider_name: str,
- ) -> Dict[str, Any]:
- """Filter extra params based on the provider and warn if not supported."""
- original = asdict(extra_params)
- filtered = {}
-
- query = extra_params.__class__.__name__
- fields = asdict(self.provider_interface.params[query]["extra"]()) # type: ignore
-
- for k, v in original.items():
- f = fields[k]
- providers = f.title.split(",") if hasattr(f, "title") else []
- if v != f.default:
- if provider_name in providers:
- filtered[k] = v
- else:
- available = ", ".join(providers)
- warnings.warn(
- message=f"Parameter '{k}' is not supported by {provider_name}. Available for: {available}.",
- category=OpenBBWarning,
- )
-
- return filtered
+ self.query_executor = ProviderInterface().create_executor()
async def execute(self) -> Any:
"""Execute the query."""
- standard_dict = asdict(self.standard_params)
- extra_dict = (
- self.filter_extra_params(self.extra_params, self.provider)
- if self.extra_params
- else {}
- )
- query_executor = self.provider_interface.create_executor()
-
- return await query_executor.execute(
+ return await self.query_executor.execute(
provider_name=self.provider,
model_name=self.name,
- params={**standard_dict, **extra_dict},
+ params={**asdict(self.standard_params), **asdict(self.extra_params)},
credentials=self.cc.user_settings.credentials.model_dump(),
preferences=self.cc.user_settings.preferences.model_dump(),
)
diff --git a/openbb_platform/core/openbb_core/provider/query_executor.py b/openbb_platform/core/openbb_core/provider/query_executor.py
index 5d2458b9f54..7a742a0c3bd 100644
--- a/openbb_platform/core/openbb_core/provider/query_executor.py
+++ b/openbb_platform/core/openbb_core/provider/query_executor.py
@@ -93,8 +93,4 @@ class QueryExecutor:
filtered_credentials = self.filter_credentials(
credentials, provider, fetcher.require_credentials
)
-
- try:
- return await fetcher.fetch_data(params, filtered_credentials, **kwargs)
- except Exception as e:
- raise OpenBBError(e) from e
+ return await fetcher.fetch_data(params, filtered_credentials, **kwargs)
diff --git a/openbb_platform/core/tests/app/test_command_runner.py b/openbb_platform/core/tests/app/test_command_runner.py
index 00659a6e0c5..97bfc6df04d 100644
--- a/openbb_platform/core/tests/app/test_command_runner.py
+++ b/openbb_platform/core/tests/app/test_command_runner.py
@@ -1,18 +1,23 @@
+from dataclasses import dataclass
from inspect import Parameter
from typing import Dict, List
from unittest.mock import Mock, patch
import pytest
+from fastapi import Query
+from fastapi.params import Query as QueryParam
from openbb_core.app.command_runner import (
CommandRunner,
ExecutionContext,
ParametersBuilder,
StaticCommandRunner,
)
+from openbb_core.app.model.abstract.warning import OpenBBWarning
from openbb_core.app.model.command_context import CommandContext
from openbb_core.app.model.system_settings import SystemSettings
from openbb_core.app.model.user_settings import UserSettings
from openbb_core.app.router import CommandMap
+from pydantic import BaseModel, ConfigDict
@pytest.fixture()
@@ -224,6 +229,65 @@ def test_parameters_builder_validate_kwargs(mock_func):
assert result == {"a": 1, "b": 2, "c": 3.0, "d": 4, "provider_choices": {}}
+@pytest.mark.parametrize(
+ "provider_choices, extra_params, expect",
+ [
+ (
+ {"provider": "provider1"},
+ {"exists_in_2": ...},
+ OpenBBWarning,
+ ),
+ (
+ {"provider": "inexistent_provider"},
+ {"exists_in_both": ...},
+ OpenBBWarning,
+ ),
+ (
+ {},
+ {"inexistent_field": ...},
+ OpenBBWarning,
+ ),
+ (
+ {"provider": "provider2"},
+ {"exists_in_2": ...},
+ None,
+ ),
+ (
+ {"provider": "provider2"},
+ {"exists_in_both": ...},
+ None,
+ ),
+ (
+ {},
+ {"exists_in_both": ...},
+ None,
+ ),
+ ],
+)
+def test_parameters_builder__warn_kwargs(provider_choices, extra_params, expect):
+ """Test _warn_kwargs."""
+
+ @dataclass
+ class SomeModel:
+ """SomeModel"""
+
+ exists_in_2: QueryParam = Query(..., title="provider2")
+ exists_in_both: QueryParam = Query(..., title="provider1,provider2")
+
+ class Model(BaseModel):
+ """Model"""
+
+ model_config = ConfigDict(arbitrary_types_allowed=True)
+ extra_params: SomeModel
+
+ with pytest.warns(expect) as warning_info:
+ # pylint: disable=protected-access
+ ParametersBuilder._warn_kwargs(provider_choices, extra_params, Model)
+
+ if not expect:
+ assert len(warning_info) == 0
+
+
def test_parameters_builder_build(mock_func, execution_context):
"""Test build."""
diff --git a/openbb_platform/core/tests/app/test_query.py b/openbb_platform/core/tests/app/test_query.py
index 42cee25f79b..95155415db8 100644
--- a/openbb_platform/core/tests/app/test_query.py
+++ b/openbb_platform/core/tests/app/test_query.py
@@ -63,32 +63,6 @@ def test_init(query):
assert query
-def test_filter_extra_params(query):
- """Test filter_extra_params."""
- extra_params = create_mock_extra_params()
- extra_params = query.filter_extra_params(extra_params, "fmp")
-
- assert isinstance(extra_params, dict)
- assert len(extra_params) == 0
-
-
-def test_filter_extra_params_wrong_param(query):
- """Test filter_extra_params."""
-
- @dataclass
- class EquityHistorical:
- """Mock ExtraParams dataclass."""
-
- sort: str = "desc"
- limit: int = 4
-
- extra_params = EquityHistorical()
-
- extra = query.filter_extra_params(extra_params, "fmp")
- assert isinstance(extra, dict)
- assert len(extra) == 0
-
-
@pytest.fixture
def mock_registry():
"""Mock registry."""