summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authormontezdesousa <79287829+montezdesousa@users.noreply.github.com>2024-03-28 15:37:28 +0000
committerGitHub <noreply@github.com>2024-03-28 15:37:28 +0000
commitce52ef33ec25dd97a6623b1fe22b5df9f1fe224e (patch)
tree7b8977a0ec212de4ddc0715147d06877949e8f92
parent27a6fbbf1694d3f737b0fa040817c9e68ce03343 (diff)
[Bugfix] - Test parametrize skips charting tests (#6264)
* don't skipt charting tests * multiples does not have a charting implementation * fix list index out of range * mypy * mypy * mypy ta class + base * indicators mypy * sync precommit with ci * sync ci with precommit * revert * revert * mypy indicators * fix mypy * fix hastype * type ignore * mypy * mypy * mypy * does this work? * remove post inits * this * this * and this --------- Co-authored-by: hjoaquim <h.joaquim@campus.fct.unl.pt>
-rw-r--r--.github/workflows/linting.yml2
-rw-r--r--.pre-commit-config.yaml2
-rw-r--r--openbb_platform/extensions/tests/conftest.py8
-rw-r--r--openbb_platform/obbject_extensions/charting/integration/test_charting_api.py23
-rw-r--r--openbb_platform/obbject_extensions/charting/openbb_charting/charting_router.py24
-rw-r--r--openbb_platform/obbject_extensions/charting/openbb_charting/core/backend.py2
-rw-r--r--openbb_platform/obbject_extensions/charting/openbb_charting/core/plotly_ta/base.py24
-rw-r--r--openbb_platform/obbject_extensions/charting/openbb_charting/core/plotly_ta/data_classes.py56
-rw-r--r--openbb_platform/obbject_extensions/charting/openbb_charting/core/plotly_ta/ta_class.py12
-rw-r--r--openbb_terminal/core/plots/plotly_ta/ta_class.py2
10 files changed, 75 insertions, 80 deletions
diff --git a/.github/workflows/linting.yml b/.github/workflows/linting.yml
index ccabe3a349d..ff98d8e8f6e 100644
--- a/.github/workflows/linting.yml
+++ b/.github/workflows/linting.yml
@@ -69,7 +69,7 @@ jobs:
# Run linters for openbb_platform
if [ -n "${{ env.platform_files }}" ]; then
pylint ${{ env.platform_files }}
- mypy ${{ env.platform_files }} --ignore-missing-imports --check-untyped-defs
+ mypy ${{ env.platform_files }} --ignore-missing-imports --scripts-are-modules --check-untyped-defs
else
echo "No Python files changed in openbb_platform"
fi
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 60e62785b57..da3cfd4d3e9 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -50,7 +50,7 @@ repos:
entry: mypy
language: python
"types_or": [python, pyi]
- args: ["--ignore-missing-imports", "--scripts-are-modules"]
+ args: ["--ignore-missing-imports", "--scripts-are-modules", "--check-untyped-defs"]
require_serial: true
- repo: https://github.com/kynan/nbstripout
rev: 0.6.1
diff --git a/openbb_platform/extensions/tests/conftest.py b/openbb_platform/extensions/tests/conftest.py
index 7e3fbce9dc2..b8849169816 100644
--- a/openbb_platform/extensions/tests/conftest.py
+++ b/openbb_platform/extensions/tests/conftest.py
@@ -1,6 +1,6 @@
"""Custom pytest configuration for the extensions."""
-from typing import Any, Dict, List
+from typing import Dict, List
import pytest
from openbb_core.app.router import CommandMap
@@ -13,7 +13,7 @@ commands = list(cm.map.keys())
# ruff: noqa: SIM114
-def parametrize(argnames: str, argvalues: List[Dict[str, Any]], **kwargs):
+def parametrize(argnames: str, argvalues: List, **kwargs):
"""Custom parametrize decorator that filters test cases based on the environment."""
routers, providers, obbject_ext = list_openbb_extensions()
@@ -49,6 +49,10 @@ def parametrize(argnames: str, argvalues: List[Dict[str, Any]], **kwargs):
elif "provider" not in args and function_name_v3 in commands:
# Handle edge case
filtered_argvalues.append(args)
+ elif extension_name in obbject_ext:
+ filtered_argvalues.append(args)
+
+ # If filtered_argvalues is empty, pytest will skip the test!
return pytest.mark.parametrize(argnames, filtered_argvalues, **kwargs)(
function
)
diff --git a/openbb_platform/obbject_extensions/charting/integration/test_charting_api.py b/openbb_platform/obbject_extensions/charting/integration/test_charting_api.py
index cacfac05daa..f89c1b318bf 100644
--- a/openbb_platform/obbject_extensions/charting/integration/test_charting_api.py
+++ b/openbb_platform/obbject_extensions/charting/integration/test_charting_api.py
@@ -77,29 +77,6 @@ def test_charting_equity_price_historical(params, headers):
@parametrize(
"params",
- [({"symbol": "AAPL", "limit": 100, "chart": True})],
-)
-@pytest.mark.integration
-def test_charting_equity_fundamental_multiples(params, headers):
- """Test chart equity multiples."""
- params = {p: v for p, v in params.items() if v}
-
- query_str = get_querystring(params, [])
- url = f"http://0.0.0.0:8000/api/v1/equity/fundamental/multiples?{query_str}"
- result = requests.get(url, headers=headers, timeout=10)
- assert isinstance(result, requests.Response)
- assert result.status_code == 200
-
- chart = result.json()["chart"]
- fig = chart.pop("fig", {})
-
- assert chart
- assert not fig
- assert list(chart.keys()) == ["content", "format"]
-
-
-@parametrize(
- "params",
[
(
{
diff --git a/openbb_platform/obbject_extensions/charting/openbb_charting/charting_router.py b/openbb_platform/obbject_extensions/charting/openbb_charting/charting_router.py
index 44b92a2c6c5..08c27d8c12c 100644
--- a/openbb_platform/obbject_extensions/charting/openbb_charting/charting_router.py
+++ b/openbb_platform/obbject_extensions/charting/openbb_charting/charting_router.py
@@ -1,7 +1,7 @@
"""Charting router."""
import json
-from typing import Any, Dict, Tuple
+from typing import Any, Dict, Optional, Tuple, Union
import pandas as pd
from openbb_core.app.model.charts.chart import ChartFormat
@@ -188,7 +188,6 @@ def technical_cones(
**kwargs: TechnicalConesChartQueryParams,
) -> Tuple["OpenBBFigure", Dict[str, Any]]:
"""Volatility Cones Chart."""
-
data = kwargs.get("data")
if isinstance(data, pd.DataFrame) and not data.empty and "window" in data.columns:
@@ -286,10 +285,9 @@ def technical_cones(
def economy_fred_series(
- **kwargs: FredSeriesChartQueryParams,
+ **kwargs: Union[Any, FredSeriesChartQueryParams],
) -> Tuple["OpenBBFigure", Dict[str, Any]]:
"""FRED Series Chart."""
-
ytitle_dict = {
"chg": "Change",
"ch1": "Change From Year Ago",
@@ -385,12 +383,9 @@ def economy_fred_series(
+ " Override this error by setting `allow_unsafe = True`."
)
- y1_units = y_units[0]
-
+ y1_units = y_units[0] if y_units else None
y1title = y1_units
-
y2title = y_units[1] if len(y_units) > 1 else None
-
xtitle = ""
# If the request was transformed, the y-axis will be shared under these conditions.
@@ -401,8 +396,9 @@ def economy_fred_series(
y2title = None
# Set the title for the chart.
- if kwargs.get("title"):
- title = kwargs.get("title")
+ title: str = ""
+ if isinstance(kwargs, dict) and title in kwargs:
+ title = kwargs["title"]
else:
if metadata.get(columns[0]):
title = metadata.get(columns[0]).get("title") if len(columns) == 1 else "FRED Series" # type: ignore
@@ -412,7 +408,7 @@ def economy_fred_series(
title = f"{title} - {transform_title}" if transform_title else title
# Define this to use as a check.
- y3title = ""
+ y3title: Optional[str] = ""
# Create the figure object with subplots.
fig = OpenBBFigure().create_subplots(
@@ -456,14 +452,14 @@ def economy_fred_series(
if kwargs.get("y2title") and y2title is not None:
y2title = kwargs.get("y2title")
# Set the x-axis title, if suppiled.
- if kwargs.get("xtitle"):
- xtitle = kwargs.get("xtitle")
+ if isinstance(kwargs, dict) and "xtitle" in kwargs:
+ xtitle = kwargs["xtitle"]
# If the data was normalized, set the title to reflect this.
if normalize:
y1title = None
y2title = None
y3title = None
- title = f"{title} - Normalized"
+ title = f"{title} - Normalized" if title else "Normalized"
# Now update the layout of the complete figure.
fig.update_layout(
diff --git a/openbb_platform/obbject_extensions/charting/openbb_charting/core/backend.py b/openbb_platform/obbject_extensions/charting/openbb_charting/core/backend.py
index 457df7d5981..2307f3993fc 100644
--- a/openbb_platform/obbject_extensions/charting/openbb_charting/core/backend.py
+++ b/openbb_platform/obbject_extensions/charting/openbb_charting/core/backend.py
@@ -221,7 +221,7 @@ class Backend(PyWry):
self.send_outgoing(outgoing)
if export_image and isinstance(export_image, Path):
- if self.loop.is_closed():
+ if self.loop.is_closed(): # type: ignore[has-type]
# Create a new event loop
self.loop = asyncio.new_event_loop()
asyncio.set_event_loop(self.loop)
diff --git a/openbb_platform/obbject_extensions/charting/openbb_charting/core/plotly_ta/base.py b/openbb_platform/obbject_extensions/charting/openbb_charting/core/plotly_ta/base.py
index 126f3b195fc..1d0ec43eb51 100644
--- a/openbb_platform/obbject_extensions/charting/openbb_charting/core/plotly_ta/base.py
+++ b/openbb_platform/obbject_extensions/charting/openbb_charting/core/plotly_ta/base.py
@@ -1,4 +1,6 @@
-from typing import Any, Callable, Dict, Iterator, List, Optional, Type
+"""Base class for charting plugins."""
+
+from typing import Any, Callable, Dict, Iterator, List, Optional, Type, Union
import pandas as pd
@@ -6,7 +8,7 @@ from .data_classes import ChartIndicators, TAIndicator
def columns_regex(df_ta: pd.DataFrame, name: str) -> List[str]:
- """Return columns that match regex name"""
+ """Return columns that match regex name."""
column_name = df_ta.filter(regex=rf"{name}(?=[^\d]|$)").columns.tolist()
return column_name
@@ -26,6 +28,7 @@ class Indicator:
self.attrs = attrs
def __call__(self, *args: Any, **kwargs: Any) -> Any:
+ """Call the indicator function."""
return self.func(*args, **kwargs)
@@ -39,6 +42,7 @@ class PluginMeta(type):
__subplots__: List[str] = []
def __new__(mcs: Type["PluginMeta"], *args: Any, **kwargs: Any) -> "PluginMeta":
+ """Create a new instance of the class."""
name, bases, attrs = args
indicators: Dict[str, Indicator] = {}
cls_attrs: Dict[str, list] = {
@@ -76,6 +80,7 @@ class PluginMeta(type):
return new_cls
def __iter__(cls: Type["PluginMeta"]) -> Iterator[Indicator]: # type: ignore
+ """Iterate over the indicators."""
return iter(cls.__indicators__)
# pylint: disable=unused-argument
@@ -88,11 +93,11 @@ class PltTA(metaclass=PluginMeta):
indicators: ChartIndicators
intraday: bool = False
- df_stock: pd.DataFrame
- df_ta: pd.DataFrame
+ df_stock: Union[pd.DataFrame, pd.Series]
+ df_ta: Optional[pd.DataFrame] = None
df_fib: pd.DataFrame
close_column: Optional[str] = "close"
- params: Dict[str, TAIndicator] = {}
+ params: Optional[Dict[str, TAIndicator]] = {}
inchart_colors: List[str] = []
show_volume: bool = True
@@ -104,6 +109,7 @@ class PltTA(metaclass=PluginMeta):
# pylint: disable=unused-argument
def __new__(cls, *args: Any, **kwargs: Any) -> "PltTA":
+ """Create a new instance of the class."""
if cls is PltTA:
raise TypeError("Can't instantiate abstract class Plugin directly")
self = super().__new__(cls)
@@ -132,6 +138,7 @@ class PltTA(metaclass=PluginMeta):
@property
def ma_mode(self) -> List[str]:
+ """Moving average mode."""
return list(set(self.__ma_mode__))
@ma_mode.setter
@@ -139,7 +146,7 @@ class PltTA(metaclass=PluginMeta):
self.__ma_mode__ = value
def add_plugins(self, plugins: List["PltTA"]) -> None:
- """Add plugins to current instance"""
+ """Add plugins to current instance."""
for plugin in plugins:
for item in plugin.__indicators__:
# pylint: disable=unnecessary-dunder-call
@@ -161,7 +168,7 @@ class PltTA(metaclass=PluginMeta):
getattr(self, attr).extend(value)
def remove_plugins(self, plugins: List["PltTA"]) -> None:
- """Remove plugins from current instance"""
+ """Remove plugins from current instance."""
for plugin in plugins:
for item in plugin.__indicators__:
delattr(self, item.name)
@@ -171,10 +178,11 @@ class PltTA(metaclass=PluginMeta):
delattr(self, static_method)
def __iter__(self) -> Iterator[Indicator]:
+ """Iterate over the indicators."""
return iter(self.__indicators__)
def get_float_precision(self) -> str:
- """Returns f-string precision format"""
+ """Returns f-string precision format."""
price = self.df_stock[self.close_column].tail(1).values[0]
float_precision = (
",.2f" if price > 1.10 else "" if len(str(price)) < 8 else ".6f"
diff --git a/openbb_platform/obbject_extensions/charting/openbb_charting/core/plotly_ta/data_classes.py b/openbb_platform/obbject_extensions/charting/openbb_charting/core/plotly_ta/data_classes.py
index aa0223bfdc9..43836fc6c46 100644
--- a/openbb_platform/obbject_extensions/charting/openbb_charting/core/plotly_ta/data_classes.py
+++ b/openbb_platform/obbject_extensions/charting/openbb_charting/core/plotly_ta/data_classes.py
@@ -67,10 +67,6 @@ class TAIndicator:
]
args: List[Arguments]
- def __post_init__(self):
- """Post init."""
- self.args = [Arguments(**arg) for arg in self.args]
-
def __iter__(self):
"""Return iterator."""
return iter(self.args)
@@ -98,14 +94,6 @@ class ChartIndicators:
indicators: Optional[List[TAIndicator]] = None
- def __post_init__(self):
- """Post init."""
- self.indicators = (
- [TAIndicator(**indicator) for indicator in self.indicators]
- if self.indicators
- else []
- )
-
def get_indicator(self, name: str) -> Union[TAIndicator, None]:
"""Return indicator with given name."""
output = None
@@ -165,21 +153,43 @@ class ChartIndicators:
@staticmethod
def get_available_indicators() -> Tuple[str, ...]:
"""Return tuple of available indicators."""
- return list(
+ return tuple(
TAIndicator.__annotations__["name"].__args__ # pylint: disable=E1101
)
@classmethod
- def from_dict(cls, indicators: Dict[str, Dict[str, Any]]) -> "ChartIndicators":
- """Return ChartIndicators from dictionary."""
- data = []
- for indicator in indicators:
- args = []
- for arg in indicators[indicator]:
- args.append({"label": arg, "values": indicators[indicator][arg]})
- data.append({"name": indicator, "args": args})
-
- return cls(indicators=data) # type: ignore
+ def from_dict(
+ cls, indicators: Dict[str, Dict[str, List[Dict[str, Any]]]]
+ ) -> "ChartIndicators":
+ """Return ChartIndicators from dictionary.
+
+ Example
+ -------
+ ChartIndicators.from_dict(
+ {
+ "ad": {
+ "args": [
+ {
+ "label": "AD_LABEL",
+ "values": [1, 2, 3],
+ }
+ ]
+ }
+ }
+ )
+ """
+ return cls(
+ indicators=[
+ TAIndicator(
+ name=name, # type: ignore[arg-type]
+ args=[
+ Arguments(label=label, values=values)
+ for label, values in args.items()
+ ],
+ )
+ for name, args in indicators.items()
+ ]
+ )
def to_dataframe(
self, df_ta: pd.DataFrame, ma_mode: Optional[List[str]] = None
diff --git a/openbb_platform/obbject_extensions/charting/openbb_charting/core/plotly_ta/ta_class.py b/openbb_platform/obbject_extensions/charting/openbb_charting/core/plotly_ta/ta_class.py
index ebe12307ed3..46b2af6cd6b 100644
--- a/openbb_platform/obbject_extensions/charting/openbb_charting/core/plotly_ta/ta_class.py
+++ b/openbb_platform/obbject_extensions/charting/openbb_charting/core/plotly_ta/ta_class.py
@@ -88,7 +88,7 @@ class PlotlyTA(PltTA):
inchart_colors: List[str] = []
plugins: List[Type[PltTA]] = []
- df_ta: pd.DataFrame = None
+ df_ta: Optional[pd.DataFrame] = None
close_column: Optional[str] = "close"
has_volume: bool = True
show_volume: bool = True
@@ -112,11 +112,11 @@ class PlotlyTA(PltTA):
# Creates the instance of the class and loads the plugins
# We set the global variable to the instance of the class so that
# the plugins are only loaded once
- PLOTLY_TA = super().__new__(cls)
- PLOTLY_TA._locate_plugins(
+ PLOTLY_TA = super().__new__(cls) # type: ignore[attr-defined, assignment]
+ PLOTLY_TA._locate_plugins( # type: ignore[attr-defined]
getattr(cls.charting_settings, "debug_mode", False)
)
- PLOTLY_TA.add_plugins(PLOTLY_TA.plugins)
+ PLOTLY_TA.add_plugins(PLOTLY_TA.plugins) # type: ignore[attr-defined, assignment]
return PLOTLY_TA
@@ -180,7 +180,7 @@ class PlotlyTA(PltTA):
df_stock = df_stock.to_frame()
if not isinstance(indicators, ChartIndicators):
- indicators = ChartIndicators.from_dict(indicators or dict(dict()))
+ indicators = ChartIndicators.from_dict(indicators or {})
# Apply to_datetime to the index in a way that handles daylight savings.
df_stock.loc[:, "date"] = df_stock.index # type: ignore
@@ -289,7 +289,7 @@ class PlotlyTA(PltTA):
def _clear_data(self):
"""Clear and reset all data to default values."""
self.df_stock = None
- self.indicators = {}
+ self.indicators = ChartIndicators.from_dict({})
self.params = None
self.intraday = False
self.show_volume = True
diff --git a/openbb_terminal/core/plots/plotly_ta/ta_class.py b/openbb_terminal/core/plots/plotly_ta/ta_class.py
index 98dba3a8046..0456cec5fa4 100644
--- a/openbb_terminal/core/plots/plotly_ta/ta_class.py
+++ b/openbb_terminal/core/plots/plotly_ta/ta_class.py
@@ -155,7 +155,7 @@ class PlotlyTA(PltTA):
df_stock = df_stock.to_frame()
if not isinstance(indicators, ChartIndicators):
- indicators = ChartIndicators.from_dict(indicators or dict(dict()))
+ indicators = ChartIndicators.from_dict(indicators or {})
self.indicators = indicators
self.intraday = df_stock.index[-2].time() != df_stock.index[-1].time()