diff options
Diffstat (limited to 'openbb_platform/providers/sec/openbb_sec/models/sic_search.py')
-rw-r--r-- | openbb_platform/providers/sec/openbb_sec/models/sic_search.py | 53 |
1 files changed, 31 insertions, 22 deletions
diff --git a/openbb_platform/providers/sec/openbb_sec/models/sic_search.py b/openbb_platform/providers/sec/openbb_sec/models/sic_search.py index 3ea086aa648..a605f39831e 100644 --- a/openbb_platform/providers/sec/openbb_sec/models/sic_search.py +++ b/openbb_platform/providers/sec/openbb_sec/models/sic_search.py @@ -1,13 +1,18 @@ """SEC Standard Industrial Classification Code (SIC) Model.""" +# pylint: disable=unused-argument + from typing import Any, Dict, List, Optional import pandas as pd -import requests +from aiohttp_client_cache import SQLiteBackend +from aiohttp_client_cache.session import CachedSession +from openbb_core.app.utils import get_user_cache_directory from openbb_core.provider.abstract.data import Data from openbb_core.provider.abstract.fetcher import Fetcher from openbb_core.provider.standard_models.cot_search import CotSearchQueryParams -from openbb_sec.utils.helpers import SEC_HEADERS, sec_session_companies +from openbb_core.provider.utils.helpers import amake_request +from openbb_sec.utils.helpers import SEC_HEADERS, sec_callback from pydantic import Field @@ -37,7 +42,6 @@ class SecSicSearchFetcher( ): """Transform the query, extract and transform the data from the SEC endpoints.""" - # pylint: disable=unused-argument @staticmethod def transform_query( params: Dict[str, Any], **kwargs: Any @@ -45,9 +49,8 @@ class SecSicSearchFetcher( """Transform the query.""" return SecSicSearchQueryParams(**params) - # pylint: disable=unused-argument @staticmethod - def extract_data( + async def aextract_data( query: SecSicSearchQueryParams, credentials: Optional[Dict[str, str]], **kwargs: Any, @@ -59,28 +62,34 @@ class SecSicSearchFetcher( "https://www.sec.gov/corpfin/" "division-of-corporation-finance-standard-industrial-classification-sic-code-list" ) - r = ( - sec_session_companies.get(url, timeout=5, headers=SEC_HEADERS) - if query.use_cache is True - else requests.get(url, timeout=5, headers=SEC_HEADERS) - ) + if query.use_cache is True: + cache_dir = f"{get_user_cache_directory()}/http/sec_sic" + async with CachedSession( + cache=SQLiteBackend(cache_dir, expire_after=3600 * 24 * 30) + ) as session: + try: + response = await amake_request( + url, headers=SEC_HEADERS, session=session, response_callback=sec_callback # type: ignore + ) + finally: + await session.close() + else: + response = await amake_request(url, headers=SEC_HEADERS, response_callback=sec_callback) # type: ignore - if r.status_code == 200: - data = pd.read_html(r.content.decode())[0].astype(str) - if len(data) == 0: - return results - if query: - data = data[ - data["SIC Code"].str.contains(query.query, case=False) - | data["Office"].str.contains(query.query, case=False) - | data["Industry Title"].str.contains(query.query, case=False) - ] - data["SIC Code"] = data["SIC Code"].astype(int) + data = pd.read_html(response)[0].astype(str) + if len(data) == 0: + return results + if query: + data = data[ + data["SIC Code"].str.contains(query.query, case=False) + | data["Office"].str.contains(query.query, case=False) + | data["Industry Title"].str.contains(query.query, case=False) + ] + data["SIC Code"] = data["SIC Code"].astype(int) results = data.to_dict("records") return results - # pylint: disable=unused-argument @staticmethod def transform_data( query: SecSicSearchQueryParams, data: List[Dict], **kwargs: Any |