diff options
Diffstat (limited to 'openbb_platform/providers/sec/openbb_sec/models/equity_ftd.py')
-rw-r--r-- | openbb_platform/providers/sec/openbb_sec/models/equity_ftd.py | 39 |
1 files changed, 27 insertions, 12 deletions
diff --git a/openbb_platform/providers/sec/openbb_sec/models/equity_ftd.py b/openbb_platform/providers/sec/openbb_sec/models/equity_ftd.py index 097a73d3a4d..4280c344401 100644 --- a/openbb_platform/providers/sec/openbb_sec/models/equity_ftd.py +++ b/openbb_platform/providers/sec/openbb_sec/models/equity_ftd.py @@ -1,6 +1,8 @@ """SEC Equity FTD Model.""" -import concurrent.futures +# pylint: disable=unused-argument + +import asyncio from typing import Any, Dict, List, Optional from openbb_core.provider.abstract.fetcher import Fetcher @@ -8,6 +10,7 @@ from openbb_core.provider.standard_models.equity_ftd import ( EquityFtdData, EquityFtdQueryParams, ) +from openbb_core.provider.utils.errors import EmptyDataError from openbb_sec.utils.helpers import download_zip_file, get_ftd_urls from pydantic import Field @@ -31,11 +34,18 @@ class SecEquityFtdQueryParams(EquityFtdQueryParams): """, default=0, ) + use_cache: Optional[bool] = Field( + default=True, + description="Whether or not to use cache for the request, default is True." + + " Each reporting period is a separate URL, new reports will be added to the cache.", + ) class SecEquityFtdData(EquityFtdData): """SEC Equity FTD Data.""" + __alias_dict__ = {"settlement_date": "date"} + class SecEquityFtdFetcher( Fetcher[ @@ -51,17 +61,15 @@ class SecEquityFtdFetcher( return SecEquityFtdQueryParams(**params) @staticmethod - def extract_data( - query: SecEquityFtdQueryParams, # pylint: disable=unused-argument + async def aextract_data( + query: SecEquityFtdQueryParams, credentials: Optional[Dict[str, str]], **kwargs: Any, ) -> List[Dict]: """Extract the data from the SEC website.""" results = [] limit = query.limit if query.limit is not None and query.limit > 0 else 0 - symbol = query.symbol.upper() - - urls_data = get_ftd_urls() + urls_data = await get_ftd_urls() urls = list(urls_data.values()) if limit > 0: urls = ( @@ -70,14 +78,21 @@ class SecEquityFtdFetcher( else urls[query.skip_reports : limit + query.skip_reports] # noqa: E203 ) - with concurrent.futures.ThreadPoolExecutor() as executor: - executor.map( - lambda url: results.extend(download_zip_file(url, symbol)), urls - ) + async def get_one(url): + """Get data for one URL as a task.""" + data = await download_zip_file(url, query.symbol, query.use_cache) + results.extend(data) + + tasks = [get_one(url) for url in urls] - results = sorted(results, key=lambda d: d["date"], reverse=True) + await asyncio.gather(*tasks) + + if not results: + raise EmptyDataError( + "There was an error collecting data, no results were returned." + ) - return results + return sorted(results, key=lambda d: d["date"], reverse=True) @staticmethod def transform_data( |