summaryrefslogtreecommitdiffstats
path: root/openbb_platform/providers/seeking_alpha/openbb_seeking_alpha/models/calendar_earnings.py
blob: 67b6ee133398efdf2ce6639e6f904b7b8ad33f8a (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
"""Seeking Alpha Calendar Earnings Model."""

# pylint: disable=unused-argument

import asyncio
import json
from datetime import datetime, timedelta
from typing import Any, Dict, List, Literal, Optional
from warnings import warn

from openbb_core.provider.abstract.fetcher import Fetcher
from openbb_core.provider.standard_models.calendar_earnings import (
    CalendarEarningsData,
    CalendarEarningsQueryParams,
)
from openbb_core.provider.utils.helpers import amake_request
from openbb_seeking_alpha.utils.helpers import HEADERS, date_range
from pydantic import Field, field_validator


class SACalendarEarningsQueryParams(CalendarEarningsQueryParams):
    """Seeking Alpha Calendar Earnings Query.

    Extends the standard calendar earnings query with a ``country`` selector.

    Source: https://seekingalpha.com/earnings/earnings-calendar
    """

    # Restricted to the two values the fetcher maps to a currency filter
    # ("us" -> USD, anything else -> CAD).
    country: Literal["us", "ca"] = Field(
        description="The country to get calendar data for.",
        json_schema_extra={"choices": ["us", "ca"]},
        default="us",
    )


class SACalendarEarningsData(CalendarEarningsData):
    """Seeking Alpha Calendar Earnings Data.

    Adds the Seeking Alpha specific fields (market cap, reporting time,
    exchange and sector ID) to the standard calendar earnings record.
    """

    market_cap: Optional[float] = Field(
        default=None,
        description="Market cap of the entity.",
    )
    reporting_time: Optional[str] = Field(
        default=None,
        description="The reporting time - e.g. after market close.",
    )
    exchange: Optional[str] = Field(
        default=None,
        description="The primary trading exchange.",
    )
    sector_id: Optional[int] = Field(
        default=None,
        description="The Seeking Alpha Sector ID.",
    )

    @field_validator("report_date", mode="before", check_fields=False)
    @classmethod
    def validate_release_date(cls, v):
        """Normalize the release date to a ``date`` before field validation.

        Parameters
        ----------
        v : Any
            The raw ``report_date`` value. Strings are expected to start with
            ``YYYY-MM-DD`` and may carry a ``T``-separated time component.

        Returns
        -------
        Any
            A ``date`` for string and ``datetime`` inputs; any other value is
            returned unchanged so the field's own validation decides on it.

        Raises
        ------
        ValueError
            If a string input does not start with a ``YYYY-MM-DD`` date.
        """
        if isinstance(v, datetime):
            return v.date()
        if isinstance(v, str):
            # Keep only the calendar-date part of an ISO-style timestamp.
            return datetime.strptime(v.split("T")[0], "%Y-%m-%d").date()
        # Non-string input (None, an existing date, ...) has no ``split``;
        # hand it to the regular field validation instead of failing here
        # with an AttributeError.
        return v


class SACalendarEarningsFetcher(
    Fetcher[
        SACalendarEarningsQueryParams,
        List[SACalendarEarningsData],
    ]
):
    """Seeking Alpha Calendar Earnings Fetcher.

    Builds the query, downloads one earnings-calendar page per requested day
    and converts the returned rows into ``SACalendarEarningsData`` models.
    """

    @staticmethod
    def transform_query(params: Dict[str, Any]) -> SACalendarEarningsQueryParams:
        """Transform the query.

        Missing (or falsy) ``start_date`` defaults to today and ``end_date``
        to three days from today. The caller's ``params`` dict is not modified.
        """
        now = datetime.today().date()
        # Work on a shallow copy so the defaults do not leak into the
        # dictionary owned by the caller.
        transformed_params = dict(params)
        if not transformed_params.get("start_date"):
            transformed_params["start_date"] = now
        if not transformed_params.get("end_date"):
            transformed_params["end_date"] = now + timedelta(days=3)
        return SACalendarEarningsQueryParams(**transformed_params)

    @staticmethod
    async def aextract_data(
        query: SACalendarEarningsQueryParams,
        credentials: Optional[Dict[str, str]],
        **kwargs: Any,
    ) -> List[Dict]:
        """Return the raw data from the Seeking Alpha endpoint.

        One request is issued per day between ``query.start_date`` and
        ``query.end_date`` (as produced by ``date_range``) and the requests
        are awaited concurrently.

        Raises
        ------
        RuntimeError
            If no rows were collected for any of the requested days. The
            message carries the blocked responses gathered along the way.
        """
        results: List[Dict] = []
        dates = [
            date.strftime("%Y-%m-%d")
            for date in date_range(query.start_date, query.end_date)
        ]
        currency = "USD" if query.country == "us" else "CAD"
        messages: List = []

        async def get_date(date, currency):
            """Get the data for one date and add its rows to ``results``."""
            url = (
                f"https://seekingalpha.com/api/v3/earnings_calendar/tickers?"
                f"filter%5Bselected_date%5D={date}"
                f"&filter%5Bwith_rating%5D=false&filter%5Bcurrency%5D={currency}"
            )
            response = await amake_request(url=url, headers=HEADERS)
            # Try again if the response is blocked.
            if "blockScript" in response:
                response = await amake_request(url=url, headers=HEADERS)
                if "blockScript" in response:
                    message = json.dumps(response)
                    messages.append(message)
                    warn(message)
            # Only a dict payload can carry rows, and "data" may be present
            # but null; extending with None would raise a TypeError.
            rows = response.get("data") if isinstance(response, dict) else None
            if rows:
                results.extend(rows)

        await asyncio.gather(*[get_date(date, currency) for date in dates])

        if not results:
            raise RuntimeError(f"Error with the Seeking Alpha request -> {messages}")

        return results

    @staticmethod
    def transform_data(
        query: SACalendarEarningsQueryParams,
        data: List[Dict],
        **kwargs: Any,
    ) -> List[SACalendarEarningsData]:
        """Transform the data to the standard format.

        Rows are ordered by their ``release_date`` attribute. Rows that have
        no ``attributes`` or no ``release_date`` are skipped with a single
        warning, because they can neither be ordered nor be given a report
        date.
        """
        dated_rows: List[Dict] = []
        skipped = 0
        for row in data:
            attributes = row.get("attributes") or {}
            if attributes.get("release_date"):
                dated_rows.append(attributes)
            else:
                skipped += 1

        if skipped:
            warn(f"Skipped {skipped} Seeking Alpha row(s) with no release date.")

        # list.sort is stable, so rows sharing a release date keep the order
        # in which they were received.
        dated_rows.sort(key=lambda attributes: attributes["release_date"])

        return [
            SACalendarEarningsData.model_validate(
                {
                    "report_date": attributes.get("release_date"),
                    "reporting_time": attributes.get("release_time"),
                    "symbol": attributes.get("slug"),
                    "name": attributes.get("name"),
                    "market_cap": attributes.get("marketcap"),
                    "exchange": attributes.get("exchange"),
                    "sector_id": attributes.get("sector_id"),
                }
            )
            for attributes in dated_rows
        ]