summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authormontezdesousa <79287829+montezdesousa@users.noreply.github.com>2024-02-09 15:04:35 +0000
committerGitHub <noreply@github.com>2024-02-09 15:04:35 +0000
commit3fae9c20b2fa3f229ef8d296a724b39fc607c618 (patch)
tree476bf98e4c0c1b9fe0aa80be36758e4cab8e16cb
parentf08b0a6082ec5f63c696c865300191e0585fd8a5 (diff)
[Bugfix] - Fix economic calendar country (#6059)
* fix economic calendar country * lint --------- Co-authored-by: Igor Radovanovic <74266147+IgorWounds@users.noreply.github.com>
-rw-r--r--openbb_platform/core/openbb_core/provider/standard_models/cpi.py2
-rw-r--r--openbb_platform/openbb/package/economy.py6
-rw-r--r--openbb_platform/providers/tradingeconomics/openbb_tradingeconomics/models/economic_calendar.py38
-rw-r--r--openbb_platform/providers/tradingeconomics/openbb_tradingeconomics/utils/countries.py8
-rw-r--r--openbb_platform/providers/tradingeconomics/openbb_tradingeconomics/utils/url_generator.py2
5 files changed, 25 insertions, 31 deletions
diff --git a/openbb_platform/core/openbb_core/provider/standard_models/cpi.py b/openbb_platform/core/openbb_core/provider/standard_models/cpi.py
index 24a94721864..cb4d7afbdde 100644
--- a/openbb_platform/core/openbb_core/provider/standard_models/cpi.py
+++ b/openbb_platform/core/openbb_core/provider/standard_models/cpi.py
@@ -106,7 +106,7 @@ class ConsumerPriceIndexQueryParams(QueryParams):
def validate_country(cls, c: str): # pylint: disable=E0213
"""Validate country."""
result = []
- values = c.split(",")
+ values = c.replace(" ", "_").split(",")
for v in values:
check_item(v.lower(), CPI_COUNTRIES)
result.append(v.lower())
diff --git a/openbb_platform/openbb/package/economy.py b/openbb_platform/openbb/package/economy.py
index 2a708f8a9d7..456f02fd9a8 100644
--- a/openbb_platform/openbb/package/economy.py
+++ b/openbb_platform/openbb/package/economy.py
@@ -59,8 +59,8 @@ class ROUTER_economy(Container):
The provider to use for the query, by default None.
If None, the provider specified in defaults is selected or 'fmp' if there is
no default.
- country : Optional[Union[str, List[str]]]
- Country of the event (provider: tradingeconomics)
+ country : Optional[str]
+ Country of the event. (provider: tradingeconomics)
importance : Optional[Literal['Low', 'Medium', 'High']]
Importance of the event. (provider: tradingeconomics)
group : Optional[Literal['interest rate', 'inflation', 'bonds', 'consumer', 'gdp', 'government', 'housing', 'labour', 'markets', 'money', 'prices', 'trade', 'business']]
@@ -69,7 +69,7 @@ class ROUTER_economy(Container):
Returns
-------
OBBject
- results : List[EconomicCalendar]
+ results : EconomicCalendar
Serializable results.
provider : Optional[Literal['fmp', 'tradingeconomics']]
Provider name.
diff --git a/openbb_platform/providers/tradingeconomics/openbb_tradingeconomics/models/economic_calendar.py b/openbb_platform/providers/tradingeconomics/openbb_tradingeconomics/models/economic_calendar.py
index 3991e7e76bc..df4a823da0b 100644
--- a/openbb_platform/providers/tradingeconomics/openbb_tradingeconomics/models/economic_calendar.py
+++ b/openbb_platform/providers/tradingeconomics/openbb_tradingeconomics/models/economic_calendar.py
@@ -1,16 +1,16 @@
"""Trading Economics Economic Calendar Model."""
from datetime import datetime
-from typing import Any, Dict, List, Literal, Optional, Set, Union
+from typing import Any, Dict, List, Literal, Optional, Union
from openbb_core.provider.abstract.fetcher import Fetcher
from openbb_core.provider.standard_models.economic_calendar import (
EconomicCalendarData,
EconomicCalendarQueryParams,
)
-from openbb_core.provider.utils.helpers import ClientResponse, amake_request
+from openbb_core.provider.utils.helpers import ClientResponse, amake_request, check_item
from openbb_tradingeconomics.utils import url_generator
-from openbb_tradingeconomics.utils.countries import country_list
+from openbb_tradingeconomics.utils.countries import COUNTRIES
from pandas import to_datetime
from pydantic import Field, field_validator
@@ -40,9 +40,7 @@ class TEEconomicCalendarQueryParams(EconomicCalendarQueryParams):
"""
# TODO: Probably want to figure out the list we can use.
- country: Optional[Union[str, List[str]]] = Field(
- default=None, description="Country of the event"
- )
+ country: Optional[str] = Field(default=None, description="Country of the event.")
importance: Optional[IMPORTANCE] = Field(
default=None, description="Importance of the event."
)
@@ -50,11 +48,14 @@ class TEEconomicCalendarQueryParams(EconomicCalendarQueryParams):
@field_validator("country", mode="before", check_fields=False)
@classmethod
- def validate_country(cls, v: Union[str, List[str], Set[str]]):
- """Validate the country input."""
- if isinstance(v, str):
- return v.lower().replace(" ", "_")
- return ",".join([country.lower().replace(" ", "_") for country in list(v)])
+ def validate_country(cls, c: str): # pylint: disable=E0213
+ """Validate country."""
+ result = []
+ values = c.replace(" ", "_").split(",")
+ for v in values:
+ check_item(v.lower(), COUNTRIES)
+ result.append(v.lower())
+ return ",".join(result)
@field_validator("importance")
@classmethod
@@ -111,21 +112,10 @@ class TEEconomicCalendarFetcher(
query: TEEconomicCalendarQueryParams,
credentials: Optional[Dict[str, str]],
**kwargs: Any,
- ) -> List[Dict]:
+ ) -> Union[dict, List[dict]]:
"""Return the raw data from the TE endpoint."""
api_key = credentials.get("tradingeconomics_api_key") if credentials else ""
- if query.country is not None:
- country = (
- query.country.split(",") if "," in query.country else query.country
- )
- country = [country] if isinstance(country, str) else country
-
- for c in country:
- if c.replace("_", " ").lower() not in country_list:
- raise ValueError(f"{c} is not a valid country")
- query.country = country
-
url = url_generator.generate_url(query)
if not url:
raise RuntimeError(
@@ -133,7 +123,7 @@ class TEEconomicCalendarFetcher(
)
url = f"{url}{api_key}"
- async def callback(response: ClientResponse, _: Any) -> List[Dict]:
+ async def callback(response: ClientResponse, _: Any) -> Union[dict, List[dict]]:
"""Return the response."""
if response.status != 200:
raise RuntimeError(f"Error in TE request -> {await response.text()}")
diff --git a/openbb_platform/providers/tradingeconomics/openbb_tradingeconomics/utils/countries.py b/openbb_platform/providers/tradingeconomics/openbb_tradingeconomics/utils/countries.py
index 56efb1e2626..120774e68c4 100644
--- a/openbb_platform/providers/tradingeconomics/openbb_tradingeconomics/utils/countries.py
+++ b/openbb_platform/providers/tradingeconomics/openbb_tradingeconomics/utils/countries.py
@@ -223,6 +223,10 @@ country_dict = {
],
}
-country_list = list(
- set([item.lower() for sublist in country_dict.values() for item in sublist])
+COUNTRIES = list(
+ {
+ item.lower().replace(" ", "_")
+ for sublist in country_dict.values()
+ for item in sublist
+ }
)
diff --git a/openbb_platform/providers/tradingeconomics/openbb_tradingeconomics/utils/url_generator.py b/openbb_platform/providers/tradingeconomics/openbb_tradingeconomics/utils/url_generator.py
index 354bcf346fb..0bcb2a46005 100644
--- a/openbb_platform/providers/tradingeconomics/openbb_tradingeconomics/utils/url_generator.py
+++ b/openbb_platform/providers/tradingeconomics/openbb_tradingeconomics/utils/url_generator.py
@@ -35,7 +35,7 @@ def generate_url(in_query):
# Handle the formatting for the api
if "country" in query:
- country = quote(",".join(query["country"]).replace("_", " "))
+ country = quote(query["country"].replace("_", " "))
if "group" in query:
group = quote(query["group"])