summaryrefslogtreecommitdiffstats
path: root/openbb_platform/core/openbb_core/app/command_runner.py
diff options
context:
space:
mode:
Diffstat (limited to 'openbb_platform/core/openbb_core/app/command_runner.py')
-rw-r--r--openbb_platform/core/openbb_core/app/command_runner.py78
1 files changed, 47 insertions, 31 deletions
diff --git a/openbb_platform/core/openbb_core/app/command_runner.py b/openbb_platform/core/openbb_core/app/command_runner.py
index c8d7f43e377..a8c368839f8 100644
--- a/openbb_platform/core/openbb_core/app/command_runner.py
+++ b/openbb_platform/core/openbb_core/app/command_runner.py
@@ -1,15 +1,16 @@
"""Command runner module."""
+# pylint: disable=R0903
+
from copy import deepcopy
from dataclasses import asdict, is_dataclass
from datetime import datetime
from inspect import Parameter, signature
from sys import exc_info
from time import perf_counter_ns
-from typing import Any, Callable, Dict, List, Optional, Tuple
+from typing import Any, Callable, Dict, List, Optional, Tuple, Type
from warnings import catch_warnings, showwarning, warn
-from fastapi.params import Query
from pydantic import BaseModel, ConfigDict, create_model
from openbb_core.app.logs.logging_service import LoggingService
@@ -20,7 +21,7 @@ from openbb_core.app.model.metadata import Metadata
from openbb_core.app.model.obbject import OBBject
from openbb_core.app.model.system_settings import SystemSettings
from openbb_core.app.model.user_settings import UserSettings
-from openbb_core.app.provider_interface import ProviderInterface
+from openbb_core.app.provider_interface import ExtraParams, ProviderInterface
from openbb_core.app.router import CommandMap
from openbb_core.app.service.system_service import SystemService
from openbb_core.app.service.user_service import UserService
@@ -38,6 +39,7 @@ class ExecutionContext:
system_settings: SystemSettings,
user_settings: UserSettings,
) -> None:
+ """Initialize the execution context."""
self.command_map = command_map
self.route = route
self.system_settings = system_settings
@@ -179,35 +181,23 @@ class ParametersBuilder:
@staticmethod
def _warn_kwargs(
- provider_choices: Dict[str, Any],
extra_params: Dict[str, Any],
- model: BaseModel,
+ model: Type[BaseModel],
) -> None:
"""Warn if kwargs received and ignored by the validation model."""
# We only check the extra_params annotation because ignored fields
- # will always be kwargs
+ # will always be there
annotation = getattr(
model.model_fields.get("extra_params", None), "annotation", None
)
- if annotation:
- # When there is no annotation there is nothing to warn
- valid = asdict(annotation()) if is_dataclass(annotation) else {} # type: ignore
- provider = provider_choices.get("provider", None)
+ if is_dataclass(annotation) and any(
+ t is ExtraParams for t in getattr(annotation, "__bases__", [])
+ ):
+ valid = asdict(annotation()) # type: ignore
for p in extra_params:
- if field := valid.get(p):
- if provider:
- providers = (
- field.title
- if isinstance(field, Query) and isinstance(field.title, str)
- else ""
- ).split(",")
- if provider not in providers:
- warn(
- message=f"Parameter '{p}' is not supported by '{provider}'."
- f" Available for: {', '.join(providers)}.",
- category=OpenBBWarning,
- )
- else:
+ if "chart_params" in p:
+ continue
+ if p not in valid:
warn(
message=f"Parameter '{p}' not found.",
category=OpenBBWarning,
@@ -216,7 +206,12 @@ class ParametersBuilder:
@staticmethod
def _as_dict(obj: Any) -> Dict[str, Any]:
"""Safely convert an object to a dict."""
- return asdict(obj) if is_dataclass(obj) else dict(obj)
+ try:
+ if isinstance(obj, dict):
+ return obj
+ return asdict(obj) if is_dataclass(obj) else dict(obj)
+ except Exception:
+ return {}
@staticmethod
def validate_kwargs(
@@ -227,27 +222,27 @@ class ParametersBuilder:
sig = signature(func)
fields = {
n: (
- p.annotation,
+ Any if p.annotation is Parameter.empty else p.annotation,
... if p.default is Parameter.empty else p.default,
)
for n, p in sig.parameters.items()
}
# We allow extra fields to return with model with 'cc: CommandContext'
config = ConfigDict(extra="allow", arbitrary_types_allowed=True)
- ValidationModel = create_model(func.__name__, __config__=config, **fields) # type: ignore
+ ValidationModel = create_model(func.__name__, __config__=config, **fields) # type: ignore # pylint: disable=C0103
# Validate and coerce
model = ValidationModel(**kwargs)
ParametersBuilder._warn_kwargs(
- ParametersBuilder._as_dict(kwargs.get("provider_choices", {})),
ParametersBuilder._as_dict(kwargs.get("extra_params", {})),
ValidationModel,
)
return dict(model)
+ # pylint: disable=R0913
@classmethod
def build(
cls,
- args: Tuple[Any],
+ args: Tuple[Any, ...],
execution_context: ExecutionContext,
func: Callable,
route: str,
@@ -284,6 +279,7 @@ class ParametersBuilder:
return kwargs
+# pylint: disable=too-few-public-methods
class StaticCommandRunner:
"""Static Command Runner."""
@@ -292,6 +288,7 @@ class StaticCommandRunner:
cls,
func: Callable,
kwargs: Dict[str, Any],
+ show_warnings: bool = True, # pylint: disable=unused-argument # type: ignore
) -> OBBject:
"""Run a command and return the output."""
obbject = await maybe_coroutine(func, **kwargs)
@@ -311,13 +308,31 @@ class StaticCommandRunner:
raise OpenBBError(
"Charting is not installed. Please install `openbb-charting`."
)
- obbject.charting.show(render=False, **kwargs) # type: ignore
+ chart_params = {}
+ extra_params = kwargs.get("extra_params", {})
+
+ if hasattr(extra_params, "__dict__") and hasattr(extra_params, "chart_params"):
+ chart_params = kwargs["extra_params"].__dict__.get("chart_params", {})
+ elif isinstance(extra_params, dict) and "chart_params" in extra_params:
+ chart_params = kwargs["extra_params"].get("chart_params", {})
+
+ if "chart_params" in kwargs:
+ chart_params.update(kwargs.pop("chart_params", {}))
+ if "kwargs" in kwargs:
+ chart_params.update(kwargs.pop("kwargs", {}).get("chart_params", {}))
+
+ if chart_params:
+ kwargs.update(chart_params)
+
+ obbject.charting.show(render=False, **kwargs)
+
+ # pylint: disable=R0913, R0914
@classmethod
async def _execute_func(
cls,
route: str,
- args: Tuple[Any],
+ args: Tuple[Any, ...],
execution_context: ExecutionContext,
func: Callable,
kwargs: Dict[str, Any],
@@ -392,6 +407,7 @@ class StaticCommandRunner:
)
return obbject
+ # pylint: disable=W0718
@classmethod
async def run(
cls,