diff options
Diffstat (limited to 'openbb_platform/core/openbb_core/app/model/obbject.py')
-rw-r--r-- | openbb_platform/core/openbb_core/app/model/obbject.py | 81 |
1 files changed, 43 insertions, 38 deletions
diff --git a/openbb_platform/core/openbb_core/app/model/obbject.py b/openbb_platform/core/openbb_core/app/model/obbject.py index 9f2a2da8746..8d4ba26839c 100644 --- a/openbb_platform/core/openbb_core/app/model/obbject.py +++ b/openbb_platform/core/openbb_core/app/model/obbject.py @@ -8,6 +8,7 @@ from typing import ( ClassVar, Dict, Generic, + Hashable, List, Literal, Optional, @@ -25,6 +26,7 @@ from openbb_core.app.model.abstract.tagged import Tagged from openbb_core.app.model.abstract.warning import Warning_ from openbb_core.app.model.charts.chart import Chart from openbb_core.app.utils import basemodel_to_df +from openbb_core.provider.abstract.annotated_result import AnnotatedResult from openbb_core.provider.abstract.data import Data if TYPE_CHECKING: @@ -82,30 +84,32 @@ class OBBject(Tagged, Generic[T]): @classmethod def results_type_repr(cls, params: Optional[Any] = None) -> str: - """Return the results type name.""" + """Return the results type representation.""" results_field = cls.model_fields.get("results") - type_ = params[0] if params else results_field.annotation - name = type_.__name__ if hasattr(type_, "__name__") else str(type_) + type_repr = "Any" + if results_field: + type_ = params[0] if params else results_field.annotation + type_repr = getattr(type_, "__name__", str(type_)) - if (json_schema_extra := results_field.json_schema_extra) is not None: - model = json_schema_extra.get("model") + if json_schema_extra := getattr(results_field, "json_schema_extra", {}): + model = json_schema_extra.get("model", "Any") - if json_schema_extra.get("is_union"): - return f"Union[List[{model}], {model}]" - if json_schema_extra.get("has_list"): - return f"List[{model}]" + if json_schema_extra.get("is_union"): + return f"Union[List[{model}], {model}]" + if json_schema_extra.get("has_list"): + return f"List[{model}]" - return model + return model - if "typing." in str(type_): - unpack_optional = sub(r"Optional\[(.*)\]", r"\1", str(type_)) - name = sub( - r"(\w+\.)*(\w+)?(\, NoneType)?", - r"\2", - unpack_optional, - ) + if "typing." in str(type_): + unpack_optional = sub(r"Optional\[(.*)\]", r"\1", str(type_)) + type_repr = sub( + r"(\w+\.)*(\w+)?(\, NoneType)?", + r"\2", + unpack_optional, + ) - return name + return type_repr @classmethod def model_parametrized_name(cls, params: Any) -> str: @@ -189,9 +193,9 @@ class OBBject(Tagged, Generic[T]): # List[List | str | int | float] | Dict[str, Dict | List | BaseModel] else: try: - df = pd.DataFrame(res) + df = pd.DataFrame(res) # type: ignore[call-overload] # Set index, if any - if index is not None and index in df.columns: + if df is not None and index is not None and index in df.columns: df.set_index(index, inplace=True) except ValueError: @@ -245,7 +249,7 @@ class OBBject(Tagged, Generic[T]): orient: Literal[ "dict", "list", "series", "split", "tight", "records", "index" ] = "list", - ) -> Dict[str, List]: + ) -> Union[Dict[Hashable, Any], List[Dict[Hashable, Any]]]: """Convert results field to a dictionary using any of pandas to_dict options. Parameters @@ -256,25 +260,21 @@ class OBBject(Tagged, Generic[T]): Returns ------- - Dict[str, List] - Dictionary of lists. + Union[Dict[Hashable, Any], List[Dict[Hashable, Any]]] + Dictionary of lists or list of dictionaries if orient is "records". """ - df = self.to_dataframe(index=None) # type: ignore - transpose = False - if orient == "list": - transpose = True - if not isinstance(self.results, dict): - transpose = False - else: # Only enter the loop if self.results is a dictionary - self.results: Dict[str, Any] = self.results # type: ignore - for _, value in self.results.items(): - if not isinstance(value, dict): - transpose = False - break - if transpose: + df = self.to_dataframe(index=None) + if ( + orient == "list" + and isinstance(self.results, dict) + and all( + isinstance(value, dict) + for value in self.results.values() # pylint: disable=no-member + ) + ): df = df.T results = df.to_dict(orient=orient) - if orient == "list" and "index" in results: + if isinstance(results, dict) and orient == "list" and "index" in results: del results["index"] return results @@ -300,4 +300,9 @@ class OBBject(Tagged, Generic[T]): OBBject[ResultsType] OBBject with results. """ - return cls(results=await query.execute()) + results = await query.execute() + if isinstance(results, AnnotatedResult): + return cls( + results=results.result, extra={"results_metadata": results.metadata} + ) + return cls(results=results) |