summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorDanglewood <85772166+deeleeramone@users.noreply.github.com>2024-06-13 01:12:55 -0700
committerGitHub <noreply@github.com>2024-06-13 08:12:55 +0000
commit8b9f461dcaffabdd4609d942795ba9de8e261a2d (patch)
tree846aa89248e2979ac24f7ba1844c0ffb0e5e6a13
parent99d2256e0feb812bd343e9a6c1899b268f8424f0 (diff)
[Enhancement] Convert Params Models To Dictionary Before Assigning As Private Attribute In OBBject. (#6492)
* convert params models to dict * update cli test * fix test_static_command_runner_chart * mock_obbject results * minor linting adjustments * review changes * charting test mock obbject --------- Co-authored-by: Henrique Joaquim <henriquecjoaquim@gmail.com>
-rw-r--r--cli/openbb_cli/argparse_translator/obbject_registry.py8
-rw-r--r--cli/tests/test_argparse_translator_obbject_registry.py2
-rw-r--r--openbb_platform/core/openbb_core/app/command_runner.py39
-rw-r--r--openbb_platform/core/openbb_core/app/model/obbject.py3
-rw-r--r--openbb_platform/core/tests/app/test_command_runner.py12
-rw-r--r--openbb_platform/obbject_extensions/charting/openbb_charting/__init__.py53
-rw-r--r--openbb_platform/obbject_extensions/charting/tests/test_charting.py11
7 files changed, 73 insertions, 55 deletions
diff --git a/cli/openbb_cli/argparse_translator/obbject_registry.py b/cli/openbb_cli/argparse_translator/obbject_registry.py
index 372254b4b54..aa0f876942b 100644
--- a/cli/openbb_cli/argparse_translator/obbject_registry.py
+++ b/cli/openbb_cli/argparse_translator/obbject_registry.py
@@ -51,10 +51,12 @@ class Registry:
def _handle_standard_params(obbject: OBBject) -> str:
"""Handle standard params for obbjects"""
standard_params_json = ""
- std_params = obbject._standard_params # pylint: disable=protected-access
- if hasattr(std_params, "__dict__"):
+ std_params = getattr(
+ obbject, "_standard_params", {}
+ ) # pylint: disable=protected-access
+ if std_params:
standard_params = {
- k: str(v)[:30] for k, v in std_params.__dict__.items() if v
+ k: str(v)[:30] for k, v in std_params.items() if v and k != "data"
}
standard_params_json = json.dumps(standard_params)
diff --git a/cli/tests/test_argparse_translator_obbject_registry.py b/cli/tests/test_argparse_translator_obbject_registry.py
index a37e5a33541..53a4a1cea80 100644
--- a/cli/tests/test_argparse_translator_obbject_registry.py
+++ b/cli/tests/test_argparse_translator_obbject_registry.py
@@ -35,7 +35,7 @@ def mock_obbject():
obb.extra = {"command": "test_command"}
obb._route = "/test/route"
obb._standard_params = Mock()
- obb._standard_params.__dict__ = {}
+ obb._standard_params = {}
obb.results = [MockModel(1), MockModel(2)]
return obb
diff --git a/openbb_platform/core/openbb_core/app/command_runner.py b/openbb_platform/core/openbb_core/app/command_runner.py
index e3dce19f465..edfe22436cc 100644
--- a/openbb_platform/core/openbb_core/app/command_runner.py
+++ b/openbb_platform/core/openbb_core/app/command_runner.py
@@ -237,23 +237,20 @@ class StaticCommandRunner:
raise OpenBBError(
"Charting is not installed. Please install `openbb-charting`."
)
+ # Here we will pop the chart_params kwargs and flatten them into the kwargs.
chart_params = {}
- extra_params = kwargs.get("extra_params", {})
+ extra_params = getattr(obbject, "_extra_params", {})
- if hasattr(extra_params, "__dict__") and hasattr(
- extra_params, "chart_params"
- ):
- chart_params = kwargs["extra_params"].__dict__.get("chart_params", {})
- elif isinstance(extra_params, dict) and "chart_params" in extra_params:
- chart_params = kwargs["extra_params"].get("chart_params", {})
+ if extra_params and "chart_params" in extra_params:
+ chart_params = extra_params.get("chart_params", {})
- if "chart_params" in kwargs and kwargs["chart_params"] is not None:
+ if kwargs.get("chart_params"):
chart_params.update(kwargs.pop("chart_params", {}))
-
+ # Verify that kwargs is not nested as kwargs so we don't miss any chart params.
if (
"kwargs" in kwargs
and "chart_params" in kwargs["kwargs"]
- and kwargs["kwargs"].get("chart_params") is not None
+ and kwargs["kwargs"].get("chart_params")
):
chart_params.update(kwargs.pop("kwargs", {}).get("chart_params", {}))
@@ -265,6 +262,14 @@ class StaticCommandRunner:
raise OpenBBError(e) from e
warn(str(e), OpenBBWarning)
+ @classmethod
+ def _extract_params(cls, kwargs, key) -> Dict:
+ """Extract params models from kwargs and convert to a dictionary."""
+ params = kwargs.get(key, {})
+ if hasattr(params, "__dict__"):
+ return params.__dict__
+ return params
+
# pylint: disable=R0913, R0914
@classmethod
async def _execute_func(
@@ -308,9 +313,17 @@ class StaticCommandRunner:
try:
obbject = await cls._command(func, kwargs)
- # pylint: disable=protected-access
- obbject._route = route
- obbject._standard_params = kwargs.get("standard_params", None)
+
+ # This section prepares the obbject to pass to the charting service.
+ obbject._route = route # pylint: disable=protected-access
+ std_params = cls._extract_params(kwargs, "standard_params") or (
+ kwargs if "data" in kwargs else {}
+ )
+ extra_params = cls._extract_params(kwargs, "extra_params")
+ obbject._standard_params = ( # pylint: disable=protected-access
+ std_params
+ )
+ obbject._extra_params = extra_params # pylint: disable=protected-access
if chart and obbject.results:
cls._chart(obbject, **kwargs)
finally:
diff --git a/openbb_platform/core/openbb_core/app/model/obbject.py b/openbb_platform/core/openbb_core/app/model/obbject.py
index 67f41e9d15d..75078ff2919 100644
--- a/openbb_platform/core/openbb_core/app/model/obbject.py
+++ b/openbb_platform/core/openbb_core/app/model/obbject.py
@@ -72,6 +72,9 @@ class OBBject(Tagged, Generic[T]):
_standard_params: Optional[Dict[str, Any]] = PrivateAttr(
default_factory=dict,
)
+ _standard_params: Optional[Dict[str, Any]] = PrivateAttr(
+ default_factory=dict,
+ )
def __repr__(self) -> str:
"""Human readable representation of the object."""
diff --git a/openbb_platform/core/tests/app/test_command_runner.py b/openbb_platform/core/tests/app/test_command_runner.py
index 41205ca75e5..7c20059c100 100644
--- a/openbb_platform/core/tests/app/test_command_runner.py
+++ b/openbb_platform/core/tests/app/test_command_runner.py
@@ -16,6 +16,7 @@ from openbb_core.app.command_runner import (
)
from openbb_core.app.model.abstract.warning import OpenBBWarning
from openbb_core.app.model.command_context import CommandContext
+from openbb_core.app.model.obbject import OBBject
from openbb_core.app.model.system_settings import SystemSettings
from openbb_core.app.model.user_settings import UserSettings
from openbb_core.app.provider_interface import ExtraParams
@@ -364,8 +365,15 @@ async def test_static_command_runner_execute_func(
def test_static_command_runner_chart():
"""Test _chart method when charting is in obbject.accessors."""
- mock_obbject = Mock()
- mock_obbject.accessors = ["charting"]
+
+ mock_obbject = OBBject(
+ results=[
+ {"date": "1990", "value": 100},
+ {"date": "1991", "value": 200},
+ {"date": "1992", "value": 300},
+ ],
+ accessors={"charting": Mock()},
+ )
mock_obbject.charting.show = Mock()
StaticCommandRunner._chart(mock_obbject) # pylint: disable=protected-access
diff --git a/openbb_platform/obbject_extensions/charting/openbb_charting/__init__.py b/openbb_platform/obbject_extensions/charting/openbb_charting/__init__.py
index 7560b6fb369..d20908557d4 100644
--- a/openbb_platform/obbject_extensions/charting/openbb_charting/__init__.py
+++ b/openbb_platform/obbject_extensions/charting/openbb_charting/__init__.py
@@ -325,23 +325,20 @@ class Charting:
charting_function = self._get_chart_function(
self._obbject._route # pylint: disable=protected-access
)
- kwargs["obbject_item"] = self._obbject.results
- kwargs["charting_settings"] = self._charting_settings
- if (
- hasattr(self._obbject, "_standard_params")
- and self._obbject._standard_params # pylint: disable=protected-access
- ):
- kwargs["standard_params"] = (
- self._obbject._standard_params.__dict__ # pylint: disable=protected-access
- )
+ kwargs["obbject_item"] = self._obbject # pylint: disable=protected-access
+ kwargs["charting_settings"] = (
+ self._charting_settings
+ ) # pylint: disable=protected-access
+ kwargs["standard_params"] = (
+ self._obbject._standard_params
+ ) # pylint: disable=protected-access
+ kwargs["extra_params"] = (
+ self._obbject._extra_params
+ ) # pylint: disable=protected-access
kwargs["provider"] = (
self._obbject.provider
) # pylint: disable=protected-access
kwargs["extra"] = self._obbject.extra # pylint: disable=protected-access
-
- if "kwargs" in kwargs:
- _kwargs = kwargs.pop("kwargs")
- kwargs.update(_kwargs.get("chart_params", {}))
fig, content = charting_function(**kwargs)
fig = self._set_chart_style(fig)
content = fig.show(external=True, **kwargs).to_plotly_json()
@@ -448,24 +445,18 @@ class Charting:
kwargs["symbol"] = symbol
kwargs["target"] = target
kwargs["index"] = index
- kwargs["obbject_item"] = self._obbject.results
- kwargs["charting_settings"] = self._charting_settings
- if (
- hasattr(self._obbject, "_standard_params")
- and self._obbject._standard_params # pylint: disable=protected-access
- ):
- kwargs["standard_params"] = (
- self._obbject._standard_params.__dict__ # pylint: disable=protected-access
- )
+ kwargs["obbject_item"] = self._obbject # pylint: disable=protected-access
+ kwargs["charting_settings"] = (
+ self._charting_settings
+ ) # pylint: disable=protected-access
+ kwargs["standard_params"] = (
+ self._obbject._standard_params
+ ) # pylint: disable=protected-access
+ kwargs["extra_params"] = (
+ self._obbject._extra_params
+ ) # pylint: disable=protected-access
kwargs["provider"] = self._obbject.provider # pylint: disable=protected-access
kwargs["extra"] = self._obbject.extra # pylint: disable=protected-access
- metadata = kwargs["extra"].get("metadata")
- kwargs["extra_params"] = (
- metadata.arguments.get("extra_params") if metadata else None
- )
- if "kwargs" in kwargs:
- _kwargs = kwargs.pop("kwargs")
- kwargs.update(_kwargs.get("chart_params", {}))
try:
if has_data:
self.show(data=data_as_df, render=render, **kwargs)
@@ -488,7 +479,7 @@ class Charting:
def _set_chart_style(self, figure: Figure):
"""Set the user preference for light or dark mode."""
- style = self._charting_settings.chart_style # pylint: disable=protected-access
+ style = self._charting_settings.chart_style
font_color = "black" if style == "light" else "white"
paper_bgcolor = "white" if style == "light" else "black"
figure = figure.update_layout(
@@ -498,7 +489,7 @@ class Charting:
)
return figure
- def toggle_chart_style(self): # pylint: disable=protected-access
+ def toggle_chart_style(self):
"""Toggle the chart style between light and dark mode."""
if not hasattr(self._obbject.chart, "fig"):
raise ValueError(
diff --git a/openbb_platform/obbject_extensions/charting/tests/test_charting.py b/openbb_platform/obbject_extensions/charting/tests/test_charting.py
index 5849386987f..9e148c72808 100644
--- a/openbb_platform/obbject_extensions/charting/tests/test_charting.py
+++ b/openbb_platform/obbject_extensions/charting/tests/test_charting.py
@@ -24,7 +24,7 @@ mock_dataframe = MockDataframe()
@pytest.fixture()
def obbject():
- """Mock OOBject."""
+ """Mock OBBject."""
class MockStdParams(BaseModel):
"""Mock Standard Parameters."""
@@ -32,17 +32,18 @@ def obbject():
param1: str
param2: str
- class MockOOBject:
- """Mock OOBject."""
+ class MockOBBject:
+ """Mock OBBject."""
def __init__(self):
- """Mock OOBject."""
+ """Mock OBBject."""
self._user_settings = UserSettings()
self._system_settings = SystemSettings()
self._route = "mock/route"
self._standard_params = MockStdParams(
param1="mock_param1", param2="mock_param2"
)
+ self._extra_params = {}
self.results = "mock_results"
self.provider = "mock_provider"
@@ -54,7 +55,7 @@ def obbject():
"""Mock to_dataframe."""
return mock_dataframe
- return MockOOBject()
+ return MockOBBject()
def test_charting_settings(obbject):