summaryrefslogtreecommitdiffstats
path: root/openbb_platform/providers/sec/openbb_sec/models/sic_search.py
diff options
context:
space:
mode:
Diffstat (limited to 'openbb_platform/providers/sec/openbb_sec/models/sic_search.py')
-rw-r--r--openbb_platform/providers/sec/openbb_sec/models/sic_search.py53
1 files changed, 31 insertions, 22 deletions
diff --git a/openbb_platform/providers/sec/openbb_sec/models/sic_search.py b/openbb_platform/providers/sec/openbb_sec/models/sic_search.py
index 3ea086aa648..a605f39831e 100644
--- a/openbb_platform/providers/sec/openbb_sec/models/sic_search.py
+++ b/openbb_platform/providers/sec/openbb_sec/models/sic_search.py
@@ -1,13 +1,18 @@
"""SEC Standard Industrial Classification Code (SIC) Model."""
+# pylint: disable=unused-argument
+
from typing import Any, Dict, List, Optional
import pandas as pd
-import requests
+from aiohttp_client_cache import SQLiteBackend
+from aiohttp_client_cache.session import CachedSession
+from openbb_core.app.utils import get_user_cache_directory
from openbb_core.provider.abstract.data import Data
from openbb_core.provider.abstract.fetcher import Fetcher
from openbb_core.provider.standard_models.cot_search import CotSearchQueryParams
-from openbb_sec.utils.helpers import SEC_HEADERS, sec_session_companies
+from openbb_core.provider.utils.helpers import amake_request
+from openbb_sec.utils.helpers import SEC_HEADERS, sec_callback
from pydantic import Field
@@ -37,7 +42,6 @@ class SecSicSearchFetcher(
):
"""Transform the query, extract and transform the data from the SEC endpoints."""
- # pylint: disable=unused-argument
@staticmethod
def transform_query(
params: Dict[str, Any], **kwargs: Any
@@ -45,9 +49,8 @@ class SecSicSearchFetcher(
"""Transform the query."""
return SecSicSearchQueryParams(**params)
- # pylint: disable=unused-argument
@staticmethod
- def extract_data(
+ async def aextract_data(
query: SecSicSearchQueryParams,
credentials: Optional[Dict[str, str]],
**kwargs: Any,
@@ -59,28 +62,34 @@ class SecSicSearchFetcher(
"https://www.sec.gov/corpfin/"
"division-of-corporation-finance-standard-industrial-classification-sic-code-list"
)
- r = (
- sec_session_companies.get(url, timeout=5, headers=SEC_HEADERS)
- if query.use_cache is True
- else requests.get(url, timeout=5, headers=SEC_HEADERS)
- )
+ if query.use_cache is True:
+ cache_dir = f"{get_user_cache_directory()}/http/sec_sic"
+ async with CachedSession(
+ cache=SQLiteBackend(cache_dir, expire_after=3600 * 24 * 30)
+ ) as session:
+ try:
+ response = await amake_request(
+ url, headers=SEC_HEADERS, session=session, response_callback=sec_callback # type: ignore
+ )
+ finally:
+ await session.close()
+ else:
+ response = await amake_request(url, headers=SEC_HEADERS, response_callback=sec_callback) # type: ignore
- if r.status_code == 200:
- data = pd.read_html(r.content.decode())[0].astype(str)
- if len(data) == 0:
- return results
- if query:
- data = data[
- data["SIC Code"].str.contains(query.query, case=False)
- | data["Office"].str.contains(query.query, case=False)
- | data["Industry Title"].str.contains(query.query, case=False)
- ]
- data["SIC Code"] = data["SIC Code"].astype(int)
+ data = pd.read_html(response)[0].astype(str)
+ if len(data) == 0:
+ return results
+ if query:
+ data = data[
+ data["SIC Code"].str.contains(query.query, case=False)
+ | data["Office"].str.contains(query.query, case=False)
+ | data["Industry Title"].str.contains(query.query, case=False)
+ ]
+ data["SIC Code"] = data["SIC Code"].astype(int)
results = data.to_dict("records")
return results
- # pylint: disable=unused-argument
@staticmethod
def transform_data(
query: SecSicSearchQueryParams, data: List[Dict], **kwargs: Any