summaryrefslogtreecommitdiffstats
path: root/openbb_platform/core/openbb_core/app/model/obbject.py
diff options
context:
space:
mode:
Diffstat (limited to 'openbb_platform/core/openbb_core/app/model/obbject.py')
-rw-r--r--openbb_platform/core/openbb_core/app/model/obbject.py81
1 files changed, 43 insertions, 38 deletions
diff --git a/openbb_platform/core/openbb_core/app/model/obbject.py b/openbb_platform/core/openbb_core/app/model/obbject.py
index 9f2a2da8746..8d4ba26839c 100644
--- a/openbb_platform/core/openbb_core/app/model/obbject.py
+++ b/openbb_platform/core/openbb_core/app/model/obbject.py
@@ -8,6 +8,7 @@ from typing import (
ClassVar,
Dict,
Generic,
+ Hashable,
List,
Literal,
Optional,
@@ -25,6 +26,7 @@ from openbb_core.app.model.abstract.tagged import Tagged
from openbb_core.app.model.abstract.warning import Warning_
from openbb_core.app.model.charts.chart import Chart
from openbb_core.app.utils import basemodel_to_df
+from openbb_core.provider.abstract.annotated_result import AnnotatedResult
from openbb_core.provider.abstract.data import Data
if TYPE_CHECKING:
@@ -82,30 +84,32 @@ class OBBject(Tagged, Generic[T]):
@classmethod
def results_type_repr(cls, params: Optional[Any] = None) -> str:
- """Return the results type name."""
+ """Return the results type representation."""
results_field = cls.model_fields.get("results")
- type_ = params[0] if params else results_field.annotation
- name = type_.__name__ if hasattr(type_, "__name__") else str(type_)
+ type_repr = "Any"
+ if results_field:
+ type_ = params[0] if params else results_field.annotation
+ type_repr = getattr(type_, "__name__", str(type_))
- if (json_schema_extra := results_field.json_schema_extra) is not None:
- model = json_schema_extra.get("model")
+ if json_schema_extra := getattr(results_field, "json_schema_extra", {}):
+ model = json_schema_extra.get("model", "Any")
- if json_schema_extra.get("is_union"):
- return f"Union[List[{model}], {model}]"
- if json_schema_extra.get("has_list"):
- return f"List[{model}]"
+ if json_schema_extra.get("is_union"):
+ return f"Union[List[{model}], {model}]"
+ if json_schema_extra.get("has_list"):
+ return f"List[{model}]"
- return model
+ return model
- if "typing." in str(type_):
- unpack_optional = sub(r"Optional\[(.*)\]", r"\1", str(type_))
- name = sub(
- r"(\w+\.)*(\w+)?(\, NoneType)?",
- r"\2",
- unpack_optional,
- )
+ if "typing." in str(type_):
+ unpack_optional = sub(r"Optional\[(.*)\]", r"\1", str(type_))
+ type_repr = sub(
+ r"(\w+\.)*(\w+)?(\, NoneType)?",
+ r"\2",
+ unpack_optional,
+ )
- return name
+ return type_repr
@classmethod
def model_parametrized_name(cls, params: Any) -> str:
@@ -189,9 +193,9 @@ class OBBject(Tagged, Generic[T]):
# List[List | str | int | float] | Dict[str, Dict | List | BaseModel]
else:
try:
- df = pd.DataFrame(res)
+ df = pd.DataFrame(res) # type: ignore[call-overload]
# Set index, if any
- if index is not None and index in df.columns:
+ if df is not None and index is not None and index in df.columns:
df.set_index(index, inplace=True)
except ValueError:
@@ -245,7 +249,7 @@ class OBBject(Tagged, Generic[T]):
orient: Literal[
"dict", "list", "series", "split", "tight", "records", "index"
] = "list",
- ) -> Dict[str, List]:
+ ) -> Union[Dict[Hashable, Any], List[Dict[Hashable, Any]]]:
"""Convert results field to a dictionary using any of pandas to_dict options.
Parameters
@@ -256,25 +260,21 @@ class OBBject(Tagged, Generic[T]):
Returns
-------
- Dict[str, List]
- Dictionary of lists.
+ Union[Dict[Hashable, Any], List[Dict[Hashable, Any]]]
+ Dictionary of lists or list of dictionaries if orient is "records".
"""
- df = self.to_dataframe(index=None) # type: ignore
- transpose = False
- if orient == "list":
- transpose = True
- if not isinstance(self.results, dict):
- transpose = False
- else: # Only enter the loop if self.results is a dictionary
- self.results: Dict[str, Any] = self.results # type: ignore
- for _, value in self.results.items():
- if not isinstance(value, dict):
- transpose = False
- break
- if transpose:
+ df = self.to_dataframe(index=None)
+ if (
+ orient == "list"
+ and isinstance(self.results, dict)
+ and all(
+ isinstance(value, dict)
+ for value in self.results.values() # pylint: disable=no-member
+ )
+ ):
df = df.T
results = df.to_dict(orient=orient)
- if orient == "list" and "index" in results:
+ if isinstance(results, dict) and orient == "list" and "index" in results:
del results["index"]
return results
@@ -300,4 +300,9 @@ class OBBject(Tagged, Generic[T]):
OBBject[ResultsType]
OBBject with results.
"""
- return cls(results=await query.execute())
+ results = await query.execute()
+ if isinstance(results, AnnotatedResult):
+ return cls(
+ results=results.result, extra={"results_metadata": results.metadata}
+ )
+ return cls(results=results)