diff options
author | Henrique Joaquim <h.joaquim@campus.fct.unl.pt> | 2023-09-20 15:46:31 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-09-20 15:46:31 +0100 |
commit | 826d0136f22de1492b6d8db1dbe7628147f20233 (patch) | |
tree | 61cdafdc6c2b85e98ff1b972297ef1d5b0249b44 | |
parent | 3aac0baa8bdc41c0f0b3db45f59e398b004f6bb7 (diff) |
Metadata input scaling (#5448)
* adding scalling of arguments in metadata
* adjustments + tests for metada
* tests
* minor fix so we don't get repeated columns
* serie to series
* ruff and one more test
* docstring formatting
-rw-r--r-- | openbb_sdk/sdk/core/openbb_core/app/command_runner.py | 13 | ||||
-rw-r--r-- | openbb_sdk/sdk/core/openbb_core/app/model/metadata.py | 98 | ||||
-rw-r--r-- | openbb_sdk/sdk/core/tests/app/model/test_metadata.py | 117 |
3 files changed, 221 insertions, 7 deletions
diff --git a/openbb_sdk/sdk/core/openbb_core/app/command_runner.py b/openbb_sdk/sdk/core/openbb_core/app/command_runner.py index 96ab424404c..17cd48b69ec 100644 --- a/openbb_sdk/sdk/core/openbb_core/app/command_runner.py +++ b/openbb_sdk/sdk/core/openbb_core/app/command_runner.py @@ -14,6 +14,7 @@ from openbb_core.app.logs.logging_service import LoggingService from openbb_core.app.model.abstract.error import OpenBBError from openbb_core.app.model.abstract.warning import cast_warning from openbb_core.app.model.command_context import CommandContext +from openbb_core.app.model.metadata import Metadata from openbb_core.app.model.obbject import OBBject from openbb_core.app.model.system_settings import SystemSettings from openbb_core.app.model.user_settings import UserSettings @@ -366,12 +367,12 @@ class StaticCommandRunner: duration = perf_counter_ns() - start_ns if execution_context.user_settings.preferences.metadata: - obbject.metadata = { # type: ignore - "arguments": kwargs, - "duration": duration, - "route": route, - "timestamp": timestamp, - } + obbject.metadata = Metadata( + arguments=kwargs, + duration=duration, + route=route, + timestamp=timestamp, + ) return obbject diff --git a/openbb_sdk/sdk/core/openbb_core/app/model/metadata.py b/openbb_sdk/sdk/core/openbb_core/app/model/metadata.py index f667fab3f4c..bc263ea9c9c 100644 --- a/openbb_sdk/sdk/core/openbb_core/app/model/metadata.py +++ b/openbb_sdk/sdk/core/openbb_core/app/model/metadata.py @@ -1,7 +1,11 @@ from datetime import datetime +from inspect import isclass from typing import Any, Dict -from pydantic import BaseModel, Field +import numpy as np +import pandas as pd +from openbb_provider.abstract.data import Data +from pydantic import BaseModel, Field, validator class Metadata(BaseModel): @@ -19,3 +23,95 @@ class Metadata(BaseModel): return f"{self.__class__.__name__}\n\n" + "\n".join( f"{k}: {v}" for k, v in self.dict().items() ) + + @validator("arguments") + @classmethod + def scale_arguments(cls, v): + """Scale arguments. + This function is meant to limit the size of the input arguments of a command. + If the type is one of the following: `Data`, `List[Data]`, `DataFrame`, `List[DataFrame]`, + `Series`, `List[Series]` or `ndarray`, the value of the argument is swapped by a dictionary + containing the type and the columns. If the type is not one of the previous, the + value is kept or trimmed to 80 characters. + """ + for arg, arg_val in v.items(): + new_arg_val = None + + # Data + if isclass(type(arg_val)) and issubclass(type(arg_val), Data): + new_arg_val = { + "type": f"{type(arg_val).__name__}", + "columns": list(arg_val.dict().keys()), + } + + # List[Data] + if isinstance(arg_val, list) and issubclass(type(arg_val[0]), Data): + columns = [list(d.dict().keys()) for d in arg_val] + columns = (item for sublist in columns for item in sublist) # flatten + new_arg_val = { + "type": f"List[{type(arg_val[0]).__name__}]", + "columns": list(set(columns)), + } + + # DataFrame + elif isinstance(arg_val, pd.DataFrame): + columns = ( + list(arg_val.index.names) + arg_val.columns.tolist() + if any(index is not None for index in list(arg_val.index.names)) + else arg_val.columns.tolist() + ) + new_arg_val = { + "type": f"{type(arg_val).__name__}", + "columns": columns, + } + + # List[DataFrame] + elif isinstance(arg_val, list) and issubclass( + type(arg_val[0]), pd.DataFrame + ): + columns = [ + list(df.index.names) + df.columns.tolist() + if any(index is not None for index in list(df.index.names)) + else df.columns.tolist() + for df in arg_val + ] + new_arg_val = { + "type": f"List[{type(arg_val[0]).__name__}]", + "columns": columns, + } + + # Series + elif isinstance(arg_val, pd.Series): + new_arg_val = { + "type": f"{type(arg_val).__name__}", + "columns": list(arg_val.index.names) + [arg_val.name], + } + + # List[Series] + elif isinstance(arg_val, list) and isinstance(arg_val[0], pd.Series): + columns = [ + list(series.index.names) + [series.name] + if any(index is not None for index in list(series.index.names)) + else series.name + for series in arg_val + ] + new_arg_val = { + "type": f"List[{type(arg_val[0]).__name__}]", + "columns": columns, + } + + # ndarray + elif isinstance(arg_val, np.ndarray): + new_arg_val = { + "type": f"{type(arg_val).__name__}", + "columns": list(arg_val.dtype.names or []), + } + + else: + str_repr_arg_val = str(arg_val) + if len(str_repr_arg_val) > 80: + new_arg_val = str_repr_arg_val[:80] + + v[arg] = new_arg_val or arg_val + + return v diff --git a/openbb_sdk/sdk/core/tests/app/model/test_metadata.py b/openbb_sdk/sdk/core/tests/app/model/test_metadata.py new file mode 100644 index 00000000000..9f6c9487a3e --- /dev/null +++ b/openbb_sdk/sdk/core/tests/app/model/test_metadata.py @@ -0,0 +1,117 @@ +from datetime import datetime + +import numpy as np +import pandas as pd +import pytest +from openbb_core.app.model.metadata import Metadata +from openbb_provider.abstract.data import Data + + +def test_Metadata(): + "Smoke test" + m = Metadata( + arguments={"test": "test"}, + route="test", + timestamp=datetime.now(), + duration=0, + ) + assert m + assert isinstance(m, Metadata) + + +def test_fields(): + "Smoke test" + fields = Metadata.__fields__.keys() + assert "arguments" in fields + assert "duration" in fields + assert "route" in fields + assert "timestamp" in fields + + +@pytest.mark.parametrize( + "input_data, expected_output", + [ + # Test cases for various input types + ({"data": Data()}, {"data": {"type": "Data", "columns": []}}), + ( + {"data": Data(open=123, close=456)}, + {"data": {"type": "Data", "columns": ["open", "close"]}}, + ), + ( + {"data_list": [Data(open=123, close=456), Data(volume=789)]}, + { + "data_list": { + "type": "List[Data]", + "columns": ["open", "close", "volume"], + } + }, + ), + ( + {"data_list": [Data(open=123, close=456), Data(open=321, volume=789)]}, + { + "data_list": { + "type": "List[Data]", + "columns": ["open", "close", "volume"], + } + }, + ), + ( + {"data_frame": pd.DataFrame({"A": [1, 2], "B": [3, 4]})}, + {"data_frame": {"type": "DataFrame", "columns": ["A", "B"]}}, + ), + ( + { + "data_frame_list": [ + pd.DataFrame({"A": [1, 2], "B": [3, 4]}), + pd.DataFrame({"C": [5, 6]}), + ], + "data_series_list": [ + pd.Series([1, 2], name="X"), + pd.Series([3, 4], name="Y"), + ], + }, + { + "data_frame_list": { + "type": "List[DataFrame]", + "columns": [["A", "B"], ["C"]], + }, + "data_series_list": {"type": "List[Series]", "columns": ["X", "Y"]}, + }, + ), + ( + { + "numpy_array": np.array( + [(1, "Alice"), (2, "Bob")], dtype=[("id", int), ("name", "U10")] + ) + }, + {"numpy_array": {"type": "ndarray", "columns": ["id", "name"]}}, + ), + # Test case for long string input + ( + { + "long_string": "This is a very long string that exceeds 80 characters in length and should be trimmed." + }, + { + "long_string": "This is a very long string that exceeds 80 characters in length and should be tr" + }, + ), + ], +) +def test_scale_arguments(input_data, expected_output): + m = Metadata( + arguments=input_data, + route="test", + timestamp=datetime.now(), + duration=0, + ) + arguments = m.arguments + + for arg in arguments: + if "columns" in arguments[arg]: + # compare the column names disregarding the order with the expected output + assert sorted(arguments[arg]["columns"]) == sorted( + expected_output[arg]["columns"] + ) + assert arguments[arg]["type"] == expected_output[arg]["type"] + else: + assert m.arguments == expected_output |