diff options
Diffstat (limited to 'openbb_platform/core/openbb_core/app/command_runner.py')
-rw-r--r-- | openbb_platform/core/openbb_core/app/command_runner.py | 78 |
1 files changed, 47 insertions, 31 deletions
diff --git a/openbb_platform/core/openbb_core/app/command_runner.py b/openbb_platform/core/openbb_core/app/command_runner.py index c8d7f43e377..a8c368839f8 100644 --- a/openbb_platform/core/openbb_core/app/command_runner.py +++ b/openbb_platform/core/openbb_core/app/command_runner.py @@ -1,15 +1,16 @@ """Command runner module.""" +# pylint: disable=R0903 + from copy import deepcopy from dataclasses import asdict, is_dataclass from datetime import datetime from inspect import Parameter, signature from sys import exc_info from time import perf_counter_ns -from typing import Any, Callable, Dict, List, Optional, Tuple +from typing import Any, Callable, Dict, List, Optional, Tuple, Type from warnings import catch_warnings, showwarning, warn -from fastapi.params import Query from pydantic import BaseModel, ConfigDict, create_model from openbb_core.app.logs.logging_service import LoggingService @@ -20,7 +21,7 @@ from openbb_core.app.model.metadata import Metadata from openbb_core.app.model.obbject import OBBject from openbb_core.app.model.system_settings import SystemSettings from openbb_core.app.model.user_settings import UserSettings -from openbb_core.app.provider_interface import ProviderInterface +from openbb_core.app.provider_interface import ExtraParams, ProviderInterface from openbb_core.app.router import CommandMap from openbb_core.app.service.system_service import SystemService from openbb_core.app.service.user_service import UserService @@ -38,6 +39,7 @@ class ExecutionContext: system_settings: SystemSettings, user_settings: UserSettings, ) -> None: + """Initialize the execution context.""" self.command_map = command_map self.route = route self.system_settings = system_settings @@ -179,35 +181,23 @@ class ParametersBuilder: @staticmethod def _warn_kwargs( - provider_choices: Dict[str, Any], extra_params: Dict[str, Any], - model: BaseModel, + model: Type[BaseModel], ) -> None: """Warn if kwargs received and ignored by the validation model.""" # We only check the extra_params annotation because ignored fields - # will always be kwargs + # will always be there annotation = getattr( model.model_fields.get("extra_params", None), "annotation", None ) - if annotation: - # When there is no annotation there is nothing to warn - valid = asdict(annotation()) if is_dataclass(annotation) else {} # type: ignore - provider = provider_choices.get("provider", None) + if is_dataclass(annotation) and any( + t is ExtraParams for t in getattr(annotation, "__bases__", []) + ): + valid = asdict(annotation()) # type: ignore for p in extra_params: - if field := valid.get(p): - if provider: - providers = ( - field.title - if isinstance(field, Query) and isinstance(field.title, str) - else "" - ).split(",") - if provider not in providers: - warn( - message=f"Parameter '{p}' is not supported by '{provider}'." - f" Available for: {', '.join(providers)}.", - category=OpenBBWarning, - ) - else: + if "chart_params" in p: + continue + if p not in valid: warn( message=f"Parameter '{p}' not found.", category=OpenBBWarning, @@ -216,7 +206,12 @@ class ParametersBuilder: @staticmethod def _as_dict(obj: Any) -> Dict[str, Any]: """Safely convert an object to a dict.""" - return asdict(obj) if is_dataclass(obj) else dict(obj) + try: + if isinstance(obj, dict): + return obj + return asdict(obj) if is_dataclass(obj) else dict(obj) + except Exception: + return {} @staticmethod def validate_kwargs( @@ -227,27 +222,27 @@ class ParametersBuilder: sig = signature(func) fields = { n: ( - p.annotation, + Any if p.annotation is Parameter.empty else p.annotation, ... if p.default is Parameter.empty else p.default, ) for n, p in sig.parameters.items() } # We allow extra fields to return with model with 'cc: CommandContext' config = ConfigDict(extra="allow", arbitrary_types_allowed=True) - ValidationModel = create_model(func.__name__, __config__=config, **fields) # type: ignore + ValidationModel = create_model(func.__name__, __config__=config, **fields) # type: ignore # pylint: disable=C0103 # Validate and coerce model = ValidationModel(**kwargs) ParametersBuilder._warn_kwargs( - ParametersBuilder._as_dict(kwargs.get("provider_choices", {})), ParametersBuilder._as_dict(kwargs.get("extra_params", {})), ValidationModel, ) return dict(model) + # pylint: disable=R0913 @classmethod def build( cls, - args: Tuple[Any], + args: Tuple[Any, ...], execution_context: ExecutionContext, func: Callable, route: str, @@ -284,6 +279,7 @@ class ParametersBuilder: return kwargs +# pylint: disable=too-few-public-methods class StaticCommandRunner: """Static Command Runner.""" @@ -292,6 +288,7 @@ class StaticCommandRunner: cls, func: Callable, kwargs: Dict[str, Any], + show_warnings: bool = True, # pylint: disable=unused-argument # type: ignore ) -> OBBject: """Run a command and return the output.""" obbject = await maybe_coroutine(func, **kwargs) @@ -311,13 +308,31 @@ class StaticCommandRunner: raise OpenBBError( "Charting is not installed. Please install `openbb-charting`." ) - obbject.charting.show(render=False, **kwargs) # type: ignore + chart_params = {} + extra_params = kwargs.get("extra_params", {}) + + if hasattr(extra_params, "__dict__") and hasattr(extra_params, "chart_params"): + chart_params = kwargs["extra_params"].__dict__.get("chart_params", {}) + elif isinstance(extra_params, dict) and "chart_params" in extra_params: + chart_params = kwargs["extra_params"].get("chart_params", {}) + + if "chart_params" in kwargs: + chart_params.update(kwargs.pop("chart_params", {})) + if "kwargs" in kwargs: + chart_params.update(kwargs.pop("kwargs", {}).get("chart_params", {})) + + if chart_params: + kwargs.update(chart_params) + + obbject.charting.show(render=False, **kwargs) + + # pylint: disable=R0913, R0914 @classmethod async def _execute_func( cls, route: str, - args: Tuple[Any], + args: Tuple[Any, ...], execution_context: ExecutionContext, func: Callable, kwargs: Dict[str, Any], @@ -392,6 +407,7 @@ class StaticCommandRunner: ) return obbject + # pylint: disable=W0718 @classmethod async def run( cls, |