diff options
Diffstat (limited to 'openbb_platform/core/openbb_core/provider/abstract/fetcher.py')
-rw-r--r-- | openbb_platform/core/openbb_core/provider/abstract/fetcher.py | 17 |
1 files changed, 11 insertions, 6 deletions
diff --git a/openbb_platform/core/openbb_core/provider/abstract/fetcher.py b/openbb_platform/core/openbb_core/provider/abstract/fetcher.py index 4ff3482118a..f2cc759821a 100644 --- a/openbb_platform/core/openbb_core/provider/abstract/fetcher.py +++ b/openbb_platform/core/openbb_core/provider/abstract/fetcher.py @@ -9,12 +9,14 @@ from typing import ( Generic, Optional, TypeVar, + Union, get_args, get_origin, ) from pandas import DataFrame +from openbb_core.provider.abstract.annotated_result import AnnotatedResult from openbb_core.provider.abstract.data import Data from openbb_core.provider.abstract.query_params import QueryParams from openbb_core.provider.utils.helpers import maybe_coroutine, run_async @@ -56,7 +58,7 @@ class Fetcher(Generic[Q, R]): """Extract the data from the provider.""" @staticmethod - def transform_data(query: Q, data: Any, **kwargs) -> R: + def transform_data(query: Q, data: Any, **kwargs) -> Union[R, AnnotatedResult[R]]: """Transform the provider-specific data.""" raise NotImplementedError @@ -65,7 +67,7 @@ class Fetcher(Generic[Q, R]): super().__init_subclass__(*args, **kwargs) if cls.aextract_data != Fetcher.aextract_data: - cls.extract_data = cls.aextract_data + cls.extract_data = cls.aextract_data # type: ignore[method-assign] elif cls.extract_data == Fetcher.extract_data: raise NotImplementedError( "Fetcher subclass must implement either extract_data or aextract_data" @@ -79,7 +81,7 @@ class Fetcher(Generic[Q, R]): params: Dict[str, Any], credentials: Optional[Dict[str, str]] = None, **kwargs, - ) -> R: + ) -> Union[R, AnnotatedResult[R]]: """Fetch data from a provider.""" query = cls.transform_query(params=params) data = await maybe_coroutine( @@ -139,7 +141,7 @@ class Fetcher(Generic[Q, R]): data = run_async( cls.extract_data, query=query, credentials=credentials, **kwargs ) - transformed_data = cls.transform_data(query=query, data=data, **kwargs) + result = cls.transform_data(query=query, data=data, **kwargs) # Class Assertions assert isinstance( @@ -184,10 +186,13 @@ class Fetcher(Generic[Q, R]): assert len(data) > 0, "Data must not be empty." # Transformed Data Assertions + transformed_data = ( + result.result if isinstance(result, AnnotatedResult) else result + ) + assert transformed_data, "Transformed data must not be None." - is_list = isinstance(transformed_data, list) - if is_list: + if isinstance(transformed_data, list): return_type_args = cls.return_type.__args__[0] return_type_is_dict = ( hasattr(return_type_args, "__origin__") |