summaryrefslogtreecommitdiffstats
path: root/openbb_platform/providers/sec/openbb_sec/models/sic_search.py
blob: a130bddead65db3a4c8887633bc50f9ca19d2a08 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
"""SEC Standard Industrial Classification Code (SIC) Model."""

# pylint: disable=unused-argument

from io import StringIO
from typing import Any, Dict, List, Optional, Union

import pandas as pd
from aiohttp_client_cache import SQLiteBackend
from aiohttp_client_cache.session import CachedSession
from openbb_core.app.utils import get_user_cache_directory
from openbb_core.provider.abstract.data import Data
from openbb_core.provider.abstract.fetcher import Fetcher
from openbb_core.provider.standard_models.cot_search import CotSearchQueryParams
from openbb_core.provider.utils.helpers import amake_request
from openbb_sec.utils.helpers import SEC_HEADERS, sec_callback
from pydantic import Field
from pydantic import Field


class SecSicSearchQueryParams(CotSearchQueryParams):
    """Query parameters for the SEC Standard Industrial Classification (SIC) code search.

    All fields are inherited unchanged from ``CotSearchQueryParams``.

    Source: https://sec.gov/
    """


class SecSicSearchData(Data):
    """SEC Standard Industrial Classification Code (SIC) Data.

    One record per row of the SEC's SIC code list table. The field aliases are
    the column headers of that table ("SIC Code", "Industry Title", "Office").
    """

    sic: int = Field(
        description="Standard Industrial Classification Code (SIC).",
        alias="SIC Code",
    )
    industry: str = Field(description="Industry title.", alias="Industry Title")
    office: str = Field(
        description="Reporting office within the SEC Division of Corporation Finance.",
        alias="Office",
    )


class SecSicSearchFetcher(
    Fetcher[
        SecSicSearchQueryParams,
        List[SecSicSearchData],
    ]
):
    """Transform the query, extract and transform the data from the SEC endpoints."""

    @staticmethod
    def transform_query(
        params: Dict[str, Any], **kwargs: Any
    ) -> SecSicSearchQueryParams:
        """Transform the query."""
        return SecSicSearchQueryParams(**params)

    @staticmethod
    async def aextract_data(
        query: SecSicSearchQueryParams,
        credentials: Optional[Dict[str, str]],
        **kwargs: Any,
    ) -> List[Dict]:
        """Extract data from the SEC website table.

        Downloads the SEC's SIC code list page, optionally through a 30-day
        SQLite-backed HTTP cache. Parses the first HTML table on the page and
        keeps the rows where the search text appears in the "SIC Code",
        "Office", or "Industry Title" column. Matching is a case-insensitive,
        literal substring match. An empty search text keeps every row.

        Parameters
        ----------
        query : SecSicSearchQueryParams
            ``query.query`` is the search text; ``query.use_cache`` toggles the
            on-disk HTTP cache.
        credentials : Optional[Dict[str, str]]
            Not used by this fetcher.

        Returns
        -------
        List[Dict]
            One dict per matching table row, keyed by the table's column
            headers, with "SIC Code" converted to ``int``.
        """
        url = (
            "https://www.sec.gov/corpfin/"
            "division-of-corporation-finance-standard-industrial-classification-sic-code-list"
        )
        response: Union[dict, List[dict], str] = {}
        if query.use_cache is True:
            cache_dir = f"{get_user_cache_directory()}/http/sec_sic"
            # Exiting ``async with`` closes the session; no explicit close needed.
            async with CachedSession(
                cache=SQLiteBackend(cache_dir, expire_after=3600 * 24 * 30)
            ) as session:
                response = await amake_request(
                    url, headers=SEC_HEADERS, session=session, response_callback=sec_callback  # type: ignore
                )
        else:
            response = await amake_request(url, headers=SEC_HEADERS, response_callback=sec_callback)  # type: ignore

        # Passing literal HTML to ``read_html`` is deprecated (pandas >= 2.1);
        # wrap it in a file-like object instead.
        data = pd.read_html(StringIO(response))[0].astype(str)  # type: ignore
        if data.empty:
            return []

        search = query.query
        if search:
            # regex=False: the search text is user input and must be matched
            # literally; otherwise "(" raises and "." matches any character.
            mask = (
                data["SIC Code"].str.contains(search, case=False, regex=False)
                | data["Office"].str.contains(search, case=False, regex=False)
                | data["Industry Title"].str.contains(
                    search, case=False, regex=False
                )
            )
            # Copy so the column assignment below targets a new frame rather
            # than a slice of ``data`` (avoids SettingWithCopyWarning).
            data = data[mask].copy()

        data["SIC Code"] = data["SIC Code"].astype(int)

        return data.to_dict("records")

    @staticmethod
    def transform_data(
        query: SecSicSearchQueryParams, data: List[Dict], **kwargs: Any
    ) -> List[SecSicSearchData]:
        """Validate each raw table row into a ``SecSicSearchData`` record."""
        return [SecSicSearchData.model_validate(d) for d in data]