diff options
author | Chavithra PARANA <chavithra@gmail.com> | 2022-11-29 19:27:11 +0100 |
---|---|---|
committer | Chavithra PARANA <chavithra@gmail.com> | 2022-11-29 19:27:11 +0100 |
commit | 971c5eba075ab2faf9cb666bfaa6349e33264024 (patch) | |
tree | 7ebfaccc42dd90a3a14db68b9e1b8691d1a0a42b | |
parent | 1d9ce6fb02bc1f5e34fdb8072d0e70262091c054 (diff) | |
parent | a12b788ea9cf382cccb54cf525cc7fc19b6f03d8 (diff) |
Merge branch 'main' of github.com:OpenBB-finance/OpenBBTerminal
34 files changed, 1022 insertions, 944 deletions
diff --git a/openbb_terminal/forecast/expo_view.py b/openbb_terminal/forecast/expo_view.py index 30bf87dc215..8c1f890d561 100644 --- a/openbb_terminal/forecast/expo_view.py +++ b/openbb_terminal/forecast/expo_view.py @@ -11,6 +11,7 @@ import matplotlib.pyplot as plt from openbb_terminal.forecast import expo_model from openbb_terminal.decorators import log_start_end from openbb_terminal.forecast import helpers +from openbb_terminal.rich_config import console logger = logging.getLogger(__name__) # pylint: disable=too-many-arguments @@ -126,6 +127,7 @@ def display_expo_forecast( external_axes=external_axes, ) if residuals: - helpers.plot_residuals( - _model, None, ticker_series, forecast_horizon=forecast_horizon - ) + console.print("[red]Expo model does not support residuals at this time[/red]\n") + # helpers.plot_residuals( + # _model, None, ticker_series, forecast_horizon=forecast_horizon + # ) diff --git a/openbb_terminal/helper_funcs.py b/openbb_terminal/helper_funcs.py index 4a7ee62dd64..f50245cae09 100644 --- a/openbb_terminal/helper_funcs.py +++ b/openbb_terminal/helper_funcs.py @@ -1243,7 +1243,8 @@ def check_file_type_saved(valid_types: List[str] = None): valid_filenames.append(filename) else: console.print( - f"[red]Filename '{filename}' provided is not valid![/red]" + f"[red]Filename '{filename}' provided is not valid!\nPlease use one of the following file types:" + f"{','.join(valid_types)}[/red]\n" ) return ",".join(valid_filenames) diff --git a/openbb_terminal/parent_classes.py b/openbb_terminal/parent_classes.py index 9355f29535f..d85d7d2d4b6 100644 --- a/openbb_terminal/parent_classes.py +++ b/openbb_terminal/parent_classes.py @@ -747,7 +747,6 @@ class BaseController(metaclass=ABCMeta): type=check_file_type_saved(choices_export), dest="export", help=help_export, - choices=choices_export, ) if raw: diff --git a/openbb_terminal/portfolio/portfolio_optimization/optimizer_helper.py b/openbb_terminal/portfolio/portfolio_optimization/optimizer_helper.py index 9f5a502ed41..c97832ce3fb 100644 --- a/openbb_terminal/portfolio/portfolio_optimization/optimizer_helper.py +++ b/openbb_terminal/portfolio/portfolio_optimization/optimizer_helper.py @@ -4,6 +4,9 @@ __docformat__ = "numpy" import argparse import pandas as pd +from openbb_terminal.portfolio.portfolio_optimization import statics +from openbb_terminal.rich_config import console + # These are all the possible yfinance properties valid_property_infos = [ "previousClose", @@ -128,3 +131,23 @@ def dict_to_df(d: dict) -> pd.DataFrame: df = pd.DataFrame.from_dict(data=d, orient="index", columns=["value"]) return df + + +def validate_risk_measure(risk_measure: str, warning: bool = True) -> str: + """Check that the risk measure selected is valid + + Parameters + ---------- + risk_measure : str + Risk measure to check + + Returns + ------- + str + Validated risk measure + """ + if risk_measure.lower() in statics.RISK_CHOICES: + return statics.RISK_CHOICES[risk_measure.lower()] + if warning: + console.print("[yellow]Risk measure not found. Using 'MV'.[/yellow]") + return "MV" diff --git a/openbb_terminal/portfolio/portfolio_optimization/optimizer_model.py b/openbb_terminal/portfolio/portfolio_optimization/optimizer_model.py index 282d267238d..f4a1ec0675d 100644 --- a/openbb_terminal/portfolio/portfolio_optimization/optimizer_model.py +++ b/openbb_terminal/portfolio/portfolio_optimization/optimizer_model.py @@ -449,7 +449,7 @@ def get_mean_risk_portfolio( covariance = kwargs.get("covariance", "hist") d_ewma = kwargs.get("d_ewma", 0.94) - risk_measure = risk_measure.upper() + risk_measure = optimizer_helper.validate_risk_measure(risk_measure) stock_prices = yahoo_finance_model.process_stocks( symbols, interval, start_date, end_date @@ -1496,6 +1496,7 @@ def get_ef( seed = kwargs.get("seed", 123) risk_free_rate = risk_free_rate / time_factor[freq.upper()] + risk_measure = optimizer_helper.validate_risk_measure(risk_measure) stock_prices = yahoo_finance_model.process_stocks( symbols, interval, start_date, end_date @@ -1715,6 +1716,7 @@ def get_risk_parity_portfolio( ) risk_free_rate = risk_free_rate / time_factor[freq.upper()] + risk_measure = optimizer_helper.validate_risk_measure(risk_measure) # Building the portfolio object port = rp.Portfolio(returns=stock_returns, alpha=alpha) @@ -2113,7 +2115,7 @@ def get_hcp_portfolio( alpha_tail = kwargs.get("alpha_tail", 0.05) leaf_order = kwargs.get("leaf_order", True) - risk_measure = risk_measure.upper() + risk_measure = optimizer_helper.validate_risk_measure(risk_measure) stock_prices = yahoo_finance_model.process_stocks( symbols, interval, start_date, end_date diff --git a/openbb_terminal/portfolio/portfolio_optimization/po_model.py b/openbb_terminal/portfolio/portfolio_optimization/po_model.py index 531ce966706..9b545f1944d 100644 --- a/openbb_terminal/portfolio/portfolio_optimization/po_model.py +++ b/openbb_terminal/portfolio/portfolio_optimization/po_model.py @@ -15,6 +15,7 @@ from riskfolio import rp from openbb_terminal.decorators import log_start_end from openbb_terminal.portfolio.portfolio_optimization import ( + optimizer_helper, optimizer_model, ) from openbb_terminal.portfolio.portfolio_optimization.statics import ( @@ -250,6 +251,8 @@ def get_portfolio_performance(weights: Dict, data: pd.DataFrame, **kwargs) -> Di "Sharpe ratio": sharpe, } + risk_measure = optimizer_helper.validate_risk_measure(risk_measure, warning=False) + if risk_measure != "MV": risk = rp.Sharpe_Risk( weights, diff --git a/tests/openbb_terminal/forecast/test_autoarima_view.py b/tests/openbb_terminal/forecast/test_autoarima_view.py index e59fbe76140..33ba0a7a697 100644 --- a/tests/openbb_terminal/forecast/test_autoarima_view.py +++ b/tests/openbb_terminal/forecast/test_autoarima_view.py @@ -7,12 +7,10 @@ except ImportError: def test_display_autoarima_forecast(tsla_csv): - with pytest.raises(AttributeError): - autoarima_view.display_autoarima_forecast( - tsla_csv, - target_column="close", - seasonal_periods=3, - n_predict=1, - start_window=0.5, - forecast_horizon=1, - ) + autoarima_view.display_autoarima_forecast( + tsla_csv, + target_column="close", + seasonal_periods=3, + n_predict=1, + start_window=0.5, + ) diff --git a/tests/openbb_terminal/forecast/test_autoces_view.py b/tests/openbb_terminal/forecast/test_autoces_view.py index 8cccd6a5d93..544d699c45f 100644 --- a/tests/openbb_terminal/forecast/test_autoces_view.py +++ b/tests/openbb_terminal/forecast/test_autoces_view.py @@ -7,12 +7,11 @@ except ImportError: def test_display_ces_forecast(tsla_csv): - with pytest.raises(AttributeError): - autoces_view.display_autoces_forecast( - tsla_csv, - target_column="close", - seasonal_periods=3, - n_predict=1, - start_window=0.5, - forecast_horizon=1, - ) + autoces_view.display_autoces_forecast( + tsla_csv, + target_column="close", + seasonal_periods=3, + n_predict=1, + start_window=0.5, + forecast_horizon=1, + ) diff --git a/tests/openbb_terminal/forecast/test_autoets_view.py b/tests/openbb_terminal/forecast/test_autoets_view.py index cb88356167a..fb43bc033b9 100644 --- a/tests/openbb_terminal/forecast/test_autoets_view.py +++ b/tests/openbb_terminal/forecast/test_autoets_view.py @@ -7,12 +7,11 @@ except ImportError: def test_display_tft_forecast(tsla_csv): - with pytest.raises(AttributeError): - autoets_view.display_autoets_forecast( - tsla_csv, - target_column="close", - seasonal_periods=3, - n_predict=1, - start_window=0.5, - forecast_horizon=1, - ) + autoets_view.display_autoets_forecast( + tsla_csv, + target_column="close", + seasonal_periods=3, + n_predict=1, + start_window=0.5, + forecast_horizon=1, + ) diff --git a/tests/openbb_terminal/forecast/test_autoselect_view.py b/tests/openbb_terminal/forecast/test_autoselect_view.py index 8045a7ef356..a05c52910ef 100644 --- a/tests/openbb_terminal/forecast/test_autoselect_view.py +++ b/tests/openbb_terminal/forecast/test_autoselect_view.py @@ -7,12 +7,11 @@ except ImportError: def test_display_autoselect_forecast(tsla_csv): - with pytest.raises(AttributeError): - autoselect_view.display_autoselect_forecast( - tsla_csv, - target_column="close", - seasonal_periods=3, - n_predict=1, - start_window=0.5, - forecast_horizon=1, - ) + autoselect_view.display_autoselect_forecast( + tsla_csv, + target_column="close", + seasonal_periods=3, + n_predict=1, + start_window=0.5, + forecast_horizon=1, + ) diff --git a/tests/openbb_terminal/forecast/test_expo_view.py b/tests/openbb_terminal/forecast/test_expo_view.py index 1931d9e17dd..007024c8e1b 100644 --- a/tests/openbb_terminal/forecast/test_expo_view.py +++ b/tests/openbb_terminal/forecast/test_expo_view.py @@ -6,17 +6,15 @@ except ImportError: pytest.skip(allow_module_level=True) -def test_display_tft_forecast(tsla_csv): - with pytest.raises(AttributeError): - expo_view.display_expo_forecast( - tsla_csv, - target_column="close", - trend="N", - seasonal="N", - seasonal_periods=3, - dampen="F", - n_predict=1, - start_window=0.5, - forecast_horizon=1, - residuals=True, - ) +def test_display_expo_forecast(tsla_csv): + expo_view.display_expo_forecast( + tsla_csv, + target_column="close", + trend="N", + seasonal="N", + seasonal_periods=3, + dampen="F", + n_predict=1, + start_window=0.5, + forecast_horizon=1, + ) diff --git a/tests/openbb_terminal/forecast/test_forecast_controller.py b/tests/openbb_terminal/forecast/test_forecast_controller.py index dc59650f162..59a0bed76d5 100644 --- a/tests/openbb_terminal/forecast/test_forecast_controller.py +++ b/tests/openbb_terminal/forecast/test_forecast_controller.py @@ -2,7 +2,7 @@ from typing import List import argparse import pandas as pd import pytest -from prompt_toolkit.completion import NestedCompleter +from openbb_terminal.custom_prompt_toolkit import NestedCompleter try: from openbb_terminal.forecast import forecast_controller as fc @@ -11,7 +11,8 @@ except ImportError: # pylint: disable=E1121 base = "openbb_terminal.forecast.forecast_controller." -df = pd.DataFrame([[1, 2, 3], [2, 3, 4], [3, 4, 5]], columns=["first", "date", "close"]) +the_list = [[x + 1, x + 2, x + 3] for x in range(50)] +df = pd.DataFrame(the_list, columns=["first", "date", "close"]) class Thing: @@ -47,6 +48,7 @@ def test_fc_update_runtime_choices(mocker): cont = fc.ForecastController() cont.datasets = {"stonks": df} cont.update_runtime_choices() + print(type(cont.completer)) assert isinstance(cont.completer, NestedCompleter) @@ -62,7 +64,7 @@ def test_fc_print_help(capsys): cont = fc.ForecastController() cont.print_help() captured = capsys.readouterr() - assert "forecast" in captured.out + assert "ll models are for educational purposes " in captured.out def test_fc_custom_reset(): @@ -78,17 +80,8 @@ def test_fc_custom_reset_with_files(): assert val == ["forecast", "'load file1 -a file1'"] -def test_fc_parse_known_args_and_warn(mocker): - mock = mocker.Mock() - mock2 = mocker.Mock() - mock.parse_known_args = mock_func - mock.format_help = Thing().format_help - mock.add_argument = mock2 - assert mock2.call_count == 28 - - def test_fc_call_load(mocker): - mocker.patch(base + "forecast_model.load", return_value=df) + mocker.patch("openbb_terminal.common.common_model.load", return_value=df) cont = fc.ForecastController() cont.call_load(["data.csv"]) @@ -111,6 +104,7 @@ def test_fc_call_show(mocker, name, loc_df): mock = mocker.MagicMock() mock.name = name mock.limit = 4 + mock.limit_col = 9 mocker.patch( base + "ForecastController.parse_known_args_and_warn", return_value=mock ) @@ -133,7 +127,7 @@ def test_call_desc(mocker, dataset): def test_call_plot(mocker): mock = mocker.MagicMock() - mock.values = ["data.first"] + mock.values = "data.first" mocker.patch( base + "ForecastController.parse_known_args_and_warn", return_value=mock ) @@ -229,10 +223,11 @@ def test_call_comb_not_in(mocker): cont.call_combine(["data"]) +@pytest.mark.skip def test_call_comb(mocker): mock = mocker.MagicMock() mock.dataset = "data" - mock.columns = ["data.first", "data.second"] + mock.columns = "data.close" mocker.patch( base + "ForecastController.parse_known_args_and_warn", return_value=mock ) @@ -242,7 +237,8 @@ def test_call_comb(mocker): "combine": {"data.first": 1, "data.second": 2, "data.third": 3}, "delete": {"data.first": 1, "data.second": 2, "data.third": 3}, } - cont.call_combine(["data"]) + + cont.call_combine(["--dataset", "data", "--columns", "data.first"]) def test_call_clean(): @@ -273,16 +269,16 @@ def test_call_feat_eng_invalid(feature): "combine": {"data.first": 1, "data.second": 2, "data.third": 3}, "delete": {"data.first": 1, "data.second": 2, "data.third": 3}, } - the_list = ["data"] + a_list = ["data"] if feature == "rsi": - the_list.append("--period") - the_list.append("2") - getattr(cont, f"call_{feature}")(the_list) + a_list.append("--period") + a_list.append("2") + getattr(cont, f"call_{feature}")(a_list) @pytest.mark.parametrize( "feature", - ["ema", "sto", "rsi", "roc", "mom", "delta", "atr", "signal", "delete", "export"], + ["ema", "sto", "rsi", "roc", "mom", "delta", "atr", "signal", "export"], ) def test_call_feat_eng_invalid_parser(feature, mocker): mocker.patch(base + "helpers.check_parser_input", return_value=None) @@ -298,11 +294,11 @@ def test_call_feat_eng_invalid_parser(feature, mocker): "combine": {"data.first": 1, "data.second": 2, "data.third": 3}, "delete": {"data.first": 1, "data.second": 2, "data.third": 3}, } - the_list = ["data"] + a_list = ["data"] if feature == "rsi": - the_list.append("--period") - the_list.append("2") - getattr(cont, f"call_{feature}")(the_list) + a_list.append("--period") + a_list.append("2") + getattr(cont, f"call_{feature}")(a_list) def test_call_ema(mocker): @@ -317,6 +313,8 @@ def test_call_ema(mocker): cont.call_ema(["data"]) +# TODO: for now we are not allowing multiple items in the split +@pytest.mark.skip @pytest.mark.parametrize("datasets", [[], ["data"], ["bad"]]) def test_call_delete(mocker, datasets): cont = fc.ForecastController() diff --git a/tests/openbb_terminal/forecast/test_forecast_model.py b/tests/openbb_terminal/forecast/test_forecast_model.py index ac000ff9aa3..821523748ed 100644 --- a/tests/openbb_terminal/forecast/test_forecast_model.py +++ b/tests/openbb_terminal/forecast/test_forecast_model.py @@ -9,11 +9,7 @@ from tests.openbb_terminal.forecast import conftest @pytest.mark.parametrize("file_type", ["csv", "xlsx", "swp", "juan"]) def test_load(file_type): path = conftest.create_path("forecast", "data", f"TSLA.{file_type}") - if not file_type == "xlsx": - common_model.load("TSLA", file_type, {"TSLA": path}) - else: - with pytest.raises(ValueError): - common_model.load("TSLA", file_type, {"TSLA": path}) + common_model.load(file_type, {"TSLA": path}) @pytest.mark.parametrize("name", ["data", None]) diff --git a/tests/openbb_terminal/forecast/test_forecast_view.py b/tests/openbb_terminal/forecast/test_forecast_view.py index 9720928e221..ea06bfa0d2e 100644 --- a/tests/openbb_terminal/forecast/test_forecast_view.py +++ b/tests/openbb_terminal/forecast/test_forecast_view.py @@ -1,3 +1,4 @@ +import matplotlib.pyplot as plt import pandas as pd import pytest @@ -22,32 +23,30 @@ def test_show_options(tsla_csv, capsys): def test_display_plot(tsla_csv, mocker): mock = mocker.patch(base + "theme.visualize_output") - fv.display_plot(tsla_csv, "close") + fv.display_plot(tsla_csv, ["close"]) mock.assert_called_once() -def test_display_plot_multiindex(tsla_csv, mocker, capsys): +def test_display_plot_multiindex(tsla_csv, mocker): mock = mocker.patch(base + "theme.visualize_output") tuples = [("1", x) for x in tsla_csv.index] index = pd.MultiIndex.from_tuples(tuples, names=["first", "second"]) tsla_csv.index = index - fv.display_plot(tsla_csv, "close") - captured = capsys.readouterr() - assert "multi-index" in captured.out - mock.assert_not_called() + fv.display_plot(tsla_csv, ["close"]) + mock.assert_called_once() def test_display_plot_external_axes(tsla_csv, mocker): mock1 = mocker.Mock() mock = mocker.patch(base + "theme.visualize_output") - fv.display_plot(tsla_csv, "close", external_axes=[mock1]) + fv.display_plot(tsla_csv, ["close"], external_axes=[mock1]) mock.assert_not_called() def test_display_plot_series(tsla_csv, mocker): mock1 = mocker.Mock() mock = mocker.patch(base + "theme.v |