summaryrefslogtreecommitdiffstats
path: root/openbb_platform/core/openbb_core/provider/abstract/fetcher.py
diff options
context:
space:
mode:
Diffstat (limited to 'openbb_platform/core/openbb_core/provider/abstract/fetcher.py')
-rw-r--r--openbb_platform/core/openbb_core/provider/abstract/fetcher.py17
1 files changed, 11 insertions, 6 deletions
diff --git a/openbb_platform/core/openbb_core/provider/abstract/fetcher.py b/openbb_platform/core/openbb_core/provider/abstract/fetcher.py
index 4ff3482118a..f2cc759821a 100644
--- a/openbb_platform/core/openbb_core/provider/abstract/fetcher.py
+++ b/openbb_platform/core/openbb_core/provider/abstract/fetcher.py
@@ -9,12 +9,14 @@ from typing import (
Generic,
Optional,
TypeVar,
+ Union,
get_args,
get_origin,
)
from pandas import DataFrame
+from openbb_core.provider.abstract.annotated_result import AnnotatedResult
from openbb_core.provider.abstract.data import Data
from openbb_core.provider.abstract.query_params import QueryParams
from openbb_core.provider.utils.helpers import maybe_coroutine, run_async
@@ -56,7 +58,7 @@ class Fetcher(Generic[Q, R]):
"""Extract the data from the provider."""
@staticmethod
- def transform_data(query: Q, data: Any, **kwargs) -> R:
+ def transform_data(query: Q, data: Any, **kwargs) -> Union[R, AnnotatedResult[R]]:
"""Transform the provider-specific data."""
raise NotImplementedError
@@ -65,7 +67,7 @@ class Fetcher(Generic[Q, R]):
super().__init_subclass__(*args, **kwargs)
if cls.aextract_data != Fetcher.aextract_data:
- cls.extract_data = cls.aextract_data
+ cls.extract_data = cls.aextract_data # type: ignore[method-assign]
elif cls.extract_data == Fetcher.extract_data:
raise NotImplementedError(
"Fetcher subclass must implement either extract_data or aextract_data"
@@ -79,7 +81,7 @@ class Fetcher(Generic[Q, R]):
params: Dict[str, Any],
credentials: Optional[Dict[str, str]] = None,
**kwargs,
- ) -> R:
+ ) -> Union[R, AnnotatedResult[R]]:
"""Fetch data from a provider."""
query = cls.transform_query(params=params)
data = await maybe_coroutine(
@@ -139,7 +141,7 @@ class Fetcher(Generic[Q, R]):
data = run_async(
cls.extract_data, query=query, credentials=credentials, **kwargs
)
- transformed_data = cls.transform_data(query=query, data=data, **kwargs)
+ result = cls.transform_data(query=query, data=data, **kwargs)
# Class Assertions
assert isinstance(
@@ -184,10 +186,13 @@ class Fetcher(Generic[Q, R]):
assert len(data) > 0, "Data must not be empty."
# Transformed Data Assertions
+ transformed_data = (
+ result.result if isinstance(result, AnnotatedResult) else result
+ )
+
assert transformed_data, "Transformed data must not be None."
- is_list = isinstance(transformed_data, list)
- if is_list:
+ if isinstance(transformed_data, list):
return_type_args = cls.return_type.__args__[0]
return_type_is_dict = (
hasattr(return_type_args, "__origin__")