summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authormontezdesousa <79287829+montezdesousa@users.noreply.github.com>2024-03-25 12:14:27 +0000
committerGitHub <noreply@github.com>2024-03-25 12:14:27 +0000
commit1794d6973f35c6e85ebb8a5be06c536f1c4248a7 (patch)
treee5d48df85576d928111ab16859c9a07b410bc250
parent0cec1180b08876eebc796081df7900479fb6bc95 (diff)
[BugFix] - Untyped variadic keyword arguments break during execution (#6250)
* fix: ensure var_kw args come last in signatures * remove package builder change * Update package_builder.py * fix: warn only when ExtraParams * track VAR_KEYWORD * minor fix * type Any if no type provided * minor fix * fix test * fix: _as_dict * ruff * update reorder_params unit test * update func default and tests * typing * update comment * update comment * rename var * Update command_runner.py --------- Co-authored-by: Danglewood <85772166+deeleeramone@users.noreply.github.com>
-rw-r--r--openbb_platform/core/openbb_core/api/router/commands.py27
-rw-r--r--openbb_platform/core/openbb_core/app/command_runner.py22
-rw-r--r--openbb_platform/core/openbb_core/app/static/package_builder.py21
-rw-r--r--openbb_platform/core/tests/app/static/test_package_builder.py60
-rw-r--r--openbb_platform/core/tests/app/test_command_runner.py19
5 files changed, 111 insertions, 38 deletions
diff --git a/openbb_platform/core/openbb_core/api/router/commands.py b/openbb_platform/core/openbb_core/api/router/commands.py
index 7b644cca372..d524156ddbd 100644
--- a/openbb_platform/core/openbb_core/api/router/commands.py
+++ b/openbb_platform/core/openbb_core/api/router/commands.py
@@ -51,11 +51,16 @@ def build_new_signature(path: str, func: Callable) -> Signature:
parameter_list = sig.parameters.values()
return_annotation = sig.return_annotation
new_parameter_list = []
-
- for parameter in parameter_list:
+ var_kw_pos = len(parameter_list)
+ for pos, parameter in enumerate(parameter_list):
if parameter.name == "cc" and parameter.annotation == CommandContext:
continue
+ if parameter.kind == Parameter.VAR_KEYWORD:
+ # We track VAR_KEYWORD parameter to insert the any additional
+ # parameters we need to add before it and avoid a SyntaxError
+ var_kw_pos = pos
+
new_parameter_list.append(
Parameter(
parameter.name,
@@ -66,18 +71,21 @@ def build_new_signature(path: str, func: Callable) -> Signature:
)
if CHARTING_INSTALLED and path.replace("/", "_")[1:] in Charting.functions():
- new_parameter_list.append(
+ new_parameter_list.insert(
+ var_kw_pos,
Parameter(
"chart",
kind=Parameter.POSITIONAL_OR_KEYWORD,
default=False,
annotation=bool,
- )
+ ),
)
+ var_kw_pos += 1
if custom_headers := SystemService().system_settings.api_settings.custom_headers:
for name, default in custom_headers.items():
- new_parameter_list.append(
+ new_parameter_list.insert(
+ var_kw_pos,
Parameter(
name.replace("-", "_"),
kind=Parameter.POSITIONAL_OR_KEYWORD,
@@ -85,11 +93,13 @@ def build_new_signature(path: str, func: Callable) -> Signature:
annotation=Annotated[
Optional[str], Header(include_in_schema=False)
],
- )
+ ),
)
+ var_kw_pos += 1
if Env().API_AUTH:
- new_parameter_list.append(
+ new_parameter_list.insert(
+ var_kw_pos,
Parameter(
"__authenticated_user_settings",
kind=Parameter.POSITIONAL_OR_KEYWORD,
@@ -97,8 +107,9 @@ def build_new_signature(path: str, func: Callable) -> Signature:
annotation=Annotated[
UserSettings, Depends(AuthService().user_settings_hook)
],
- )
+ ),
)
+ var_kw_pos += 1
return Signature(
parameters=new_parameter_list,
diff --git a/openbb_platform/core/openbb_core/app/command_runner.py b/openbb_platform/core/openbb_core/app/command_runner.py
index 33fb4def99f..9f8b52c8714 100644
--- a/openbb_platform/core/openbb_core/app/command_runner.py
+++ b/openbb_platform/core/openbb_core/app/command_runner.py
@@ -20,7 +20,7 @@ from openbb_core.app.model.metadata import Metadata
from openbb_core.app.model.obbject import OBBject
from openbb_core.app.model.system_settings import SystemSettings
from openbb_core.app.model.user_settings import UserSettings
-from openbb_core.app.provider_interface import ProviderInterface
+from openbb_core.app.provider_interface import ExtraParams, ProviderInterface
from openbb_core.app.router import CommandMap
from openbb_core.app.service.system_service import SystemService
from openbb_core.app.service.user_service import UserService
@@ -185,13 +185,16 @@ class ParametersBuilder:
) -> None:
"""Warn if kwargs received and ignored by the validation model."""
# We only check the extra_params annotation because ignored fields
- # will always be kwargs
+ # will always be there
annotation = getattr(
model.model_fields.get("extra_params", None), "annotation", None
)
- if annotation:
- # When there is no annotation there is nothing to warn
- valid = asdict(annotation()) if is_dataclass(annotation) else {} # type: ignore
+ if is_dataclass(annotation) and any(
+ t is ExtraParams for t in getattr(annotation, "__bases__", [])
+ ):
+ # We only warn when endpoint defines ExtraParams, so we need
+ # to check if the annotation is a dataclass and child of ExtraParams
+ valid = asdict(annotation()) # type: ignore
provider = provider_choices.get("provider", None)
for p in extra_params:
if field := valid.get(p):
@@ -216,7 +219,12 @@ class ParametersBuilder:
@staticmethod
def _as_dict(obj: Any) -> Dict[str, Any]:
"""Safely convert an object to a dict."""
- return asdict(obj) if is_dataclass(obj) else dict(obj)
+ try:
+ if isinstance(obj, dict):
+ return obj
+ return asdict(obj) if is_dataclass(obj) else dict(obj)
+ except Exception:
+ return {}
@staticmethod
def validate_kwargs(
@@ -227,7 +235,7 @@ class ParametersBuilder:
sig = signature(func)
fields = {
n: (
- p.annotation,
+ Any if p.annotation is Parameter.empty else p.annotation,
... if p.default is Parameter.empty else p.default,
)
for n, p in sig.parameters.items()
diff --git a/openbb_platform/core/openbb_core/app/static/package_builder.py b/openbb_platform/core/openbb_core/app/static/package_builder.py
index 793f82731c6..0faab062605 100644
--- a/openbb_platform/core/openbb_core/app/static/package_builder.py
+++ b/openbb_platform/core/openbb_core/app/static/package_builder.py
@@ -528,10 +528,12 @@ class MethodDefinition:
return getattr(PathHandler.build_route_map()[path], "summary", "")
@staticmethod
- def reorder_params(params: Dict[str, Parameter]) -> "OrderedDict[str, Parameter]":
- """Reorder the params."""
+ def reorder_params(
+ params: Dict[str, Parameter], var_kw: Optional[List[str]] = None
+ ) -> "OrderedDict[str, Parameter]":
+ """Reorder the params and make sure VAR_KEYWORD come after 'provider."""
formatted_keys = list(params.keys())
- for k in ["provider", "extra_params"]:
+ for k in ["provider"] + (var_kw or []):
if k in formatted_keys:
formatted_keys.remove(k)
formatted_keys.append(k)
@@ -563,14 +565,11 @@ class MethodDefinition:
)
formatted: Dict[str, Parameter] = {}
-
+ var_kw = []
for name, param in parameter_map.items():
if name == "extra_params":
formatted[name] = Parameter(name="kwargs", kind=Parameter.VAR_KEYWORD)
- elif name == "kwargs":
- formatted["**" + name] = Parameter(
- name="kwargs", kind=Parameter.VAR_KEYWORD, annotation=Any
- )
+ var_kw.append(name)
elif name == "provider_choices":
fields = param.annotation.__args__[0].__dataclass_fields__
field = fields["provider"]
@@ -624,12 +623,14 @@ class MethodDefinition:
formatted[name] = Parameter(
name=name,
- kind=Parameter.POSITIONAL_OR_KEYWORD,
+ kind=param.kind,
annotation=updated_type,
default=param.default,
)
+ if param.kind == Parameter.VAR_KEYWORD:
+ var_kw.append(name)
- return MethodDefinition.reorder_params(params=formatted)
+ return MethodDefinition.reorder_params(params=formatted, var_kw=var_kw)
@staticmethod
def add_field_custom_annotations(
diff --git a/openbb_platform/core/tests/app/static/test_package_builder.py b/openbb_platform/core/tests/app/static/test_package_builder.py
index aa7c0029475..029c200fa45 100644
--- a/openbb_platform/core/tests/app/static/test_package_builder.py
+++ b/openbb_platform/core/tests/app/static/test_package_builder.py
@@ -206,17 +206,57 @@ def test_is_annotated_dc_annotated(method_definition):
assert result
-def test_reorder_params(method_definition):
- """Test reorder params."""
- params = {
- "provider": Parameter.empty,
- "extra_params": Parameter.empty,
- "param1": Parameter.empty,
- "param2": Parameter.empty,
- }
- result = method_definition.reorder_params(params=params)
+@pytest.mark.parametrize(
+ "params, var_kw, expected",
+ [
+ (
+ {
+ "provider": Parameter.empty,
+ "extra_params": Parameter.empty,
+ "param1": Parameter.empty,
+ "param2": Parameter.empty,
+ },
+ None,
+ ["extra_params", "param1", "param2", "provider"],
+ ),
+ (
+ {
+ "param1": Parameter.empty,
+ "provider": Parameter.empty,
+ "extra_params": Parameter.empty,
+ "param2": Parameter.empty,
+ },
+ ["extra_params"],
+ ["param1", "param2", "provider", "extra_params"],
+ ),
+ (
+ {
+ "param2": Parameter.empty,
+ "any_kwargs": Parameter.empty,
+ "provider": Parameter.empty,
+ "param1": Parameter.empty,
+ },
+ ["any_kwargs"],
+ ["param2", "param1", "provider", "any_kwargs"],
+ ),
+ (
+ {
+ "any_kwargs": Parameter.empty,
+ "extra_params": Parameter.empty,
+ "provider": Parameter.empty,
+ "param1": Parameter.empty,
+ "param2": Parameter.empty,
+ },
+ ["any_kwargs", "extra_params"],
+ ["param1", "param2", "provider", "any_kwargs", "extra_params"],
+ ),
+ ],
+)
+def test_reorder_params(method_definition, params, var_kw, expected):
+ """Test reorder params, ensure var_kw are last after 'provider'."""
+ result = method_definition.reorder_params(params, var_kw)
assert result
- assert list(result.keys()) == ["param1", "param2", "provider", "extra_params"]
+ assert list(result.keys()) == expected
def test_build_func_params(method_definition):
diff --git a/openbb_platform/core/tests/app/test_command_runner.py b/openbb_platform/core/tests/app/test_command_runner.py
index 3447d66407b..cd751e9c592 100644
--- a/openbb_platform/core/tests/app/test_command_runner.py
+++ b/openbb_platform/core/tests/app/test_command_runner.py
@@ -16,6 +16,7 @@ from openbb_core.app.model.abstract.warning import OpenBBWarning
from openbb_core.app.model.command_context import CommandContext
from openbb_core.app.model.system_settings import SystemSettings
from openbb_core.app.model.user_settings import UserSettings
+from openbb_core.app.provider_interface import ExtraParams
from openbb_core.app.router import CommandMap
from pydantic import BaseModel, ConfigDict
@@ -228,45 +229,57 @@ def test_parameters_builder_validate_kwargs(mock_func):
@pytest.mark.parametrize(
- "provider_choices, extra_params, expect",
+ "provider_choices, extra_params, base, expect",
[
(
{"provider": "provider1"},
{"exists_in_2": ...},
+ ExtraParams,
OpenBBWarning,
),
(
{"provider": "inexistent_provider"},
{"exists_in_both": ...},
+ ExtraParams,
OpenBBWarning,
),
(
{},
{"inexistent_field": ...},
+ ExtraParams,
OpenBBWarning,
),
(
+ {},
+ {"inexistent_field": ...},
+ object,
+ None,
+ ),
+ (
{"provider": "provider2"},
{"exists_in_2": ...},
+ ExtraParams,
None,
),
(
{"provider": "provider2"},
{"exists_in_both": ...},
+ ExtraParams,
None,
),
(
{},
{"exists_in_both": ...},
+ ExtraParams,
None,
),
],
)
-def test_parameters_builder__warn_kwargs(provider_choices, extra_params, expect):
+def test_parameters_builder__warn_kwargs(provider_choices, extra_params, base, expect):
"""Test _warn_kwargs."""
@dataclass
- class SomeModel:
+ class SomeModel(base):
"""SomeModel"""
exists_in_2: QueryParam = Query(..., title="provider2")