From 3fae9c20b2fa3f229ef8d296a724b39fc607c618 Mon Sep 17 00:00:00 2001 From: montezdesousa <79287829+montezdesousa@users.noreply.github.com> Date: Fri, 9 Feb 2024 15:04:35 +0000 Subject: [Bugfix] - Fix economic calendar country (#6059) * fix economic calendar country * lint --------- Co-authored-by: Igor Radovanovic <74266147+IgorWounds@users.noreply.github.com> --- .../openbb_core/provider/standard_models/cpi.py | 2 +- openbb_platform/openbb/package/economy.py | 6 ++-- .../models/economic_calendar.py | 38 ++++++++-------------- .../openbb_tradingeconomics/utils/countries.py | 8 +++-- .../openbb_tradingeconomics/utils/url_generator.py | 2 +- 5 files changed, 25 insertions(+), 31 deletions(-) diff --git a/openbb_platform/core/openbb_core/provider/standard_models/cpi.py b/openbb_platform/core/openbb_core/provider/standard_models/cpi.py index 24a94721864..cb4d7afbdde 100644 --- a/openbb_platform/core/openbb_core/provider/standard_models/cpi.py +++ b/openbb_platform/core/openbb_core/provider/standard_models/cpi.py @@ -106,7 +106,7 @@ class ConsumerPriceIndexQueryParams(QueryParams): def validate_country(cls, c: str): # pylint: disable=E0213 """Validate country.""" result = [] - values = c.split(",") + values = c.replace(" ", "_").split(",") for v in values: check_item(v.lower(), CPI_COUNTRIES) result.append(v.lower()) diff --git a/openbb_platform/openbb/package/economy.py b/openbb_platform/openbb/package/economy.py index 2a708f8a9d7..456f02fd9a8 100644 --- a/openbb_platform/openbb/package/economy.py +++ b/openbb_platform/openbb/package/economy.py @@ -59,8 +59,8 @@ class ROUTER_economy(Container): The provider to use for the query, by default None. If None, the provider specified in defaults is selected or 'fmp' if there is no default. - country : Optional[Union[str, List[str]]] - Country of the event (provider: tradingeconomics) + country : Optional[str] + Country of the event. (provider: tradingeconomics) importance : Optional[Literal['Low', 'Medium', 'High']] Importance of the event. (provider: tradingeconomics) group : Optional[Literal['interest rate', 'inflation', 'bonds', 'consumer', 'gdp', 'government', 'housing', 'labour', 'markets', 'money', 'prices', 'trade', 'business']] @@ -69,7 +69,7 @@ class ROUTER_economy(Container): Returns ------- OBBject - results : List[EconomicCalendar] + results : EconomicCalendar Serializable results. provider : Optional[Literal['fmp', 'tradingeconomics']] Provider name. diff --git a/openbb_platform/providers/tradingeconomics/openbb_tradingeconomics/models/economic_calendar.py b/openbb_platform/providers/tradingeconomics/openbb_tradingeconomics/models/economic_calendar.py index 3991e7e76bc..df4a823da0b 100644 --- a/openbb_platform/providers/tradingeconomics/openbb_tradingeconomics/models/economic_calendar.py +++ b/openbb_platform/providers/tradingeconomics/openbb_tradingeconomics/models/economic_calendar.py @@ -1,16 +1,16 @@ """Trading Economics Economic Calendar Model.""" from datetime import datetime -from typing import Any, Dict, List, Literal, Optional, Set, Union +from typing import Any, Dict, List, Literal, Optional, Union from openbb_core.provider.abstract.fetcher import Fetcher from openbb_core.provider.standard_models.economic_calendar import ( EconomicCalendarData, EconomicCalendarQueryParams, ) -from openbb_core.provider.utils.helpers import ClientResponse, amake_request +from openbb_core.provider.utils.helpers import ClientResponse, amake_request, check_item from openbb_tradingeconomics.utils import url_generator -from openbb_tradingeconomics.utils.countries import country_list +from openbb_tradingeconomics.utils.countries import COUNTRIES from pandas import to_datetime from pydantic import Field, field_validator @@ -40,9 +40,7 @@ class TEEconomicCalendarQueryParams(EconomicCalendarQueryParams): """ # TODO: Probably want to figure out the list we can use. - country: Optional[Union[str, List[str]]] = Field( - default=None, description="Country of the event" - ) + country: Optional[str] = Field(default=None, description="Country of the event.") importance: Optional[IMPORTANCE] = Field( default=None, description="Importance of the event." ) @@ -50,11 +48,14 @@ class TEEconomicCalendarQueryParams(EconomicCalendarQueryParams): @field_validator("country", mode="before", check_fields=False) @classmethod - def validate_country(cls, v: Union[str, List[str], Set[str]]): - """Validate the country input.""" - if isinstance(v, str): - return v.lower().replace(" ", "_") - return ",".join([country.lower().replace(" ", "_") for country in list(v)]) + def validate_country(cls, c: str): # pylint: disable=E0213 + """Validate country.""" + result = [] + values = c.replace(" ", "_").split(",") + for v in values: + check_item(v.lower(), COUNTRIES) + result.append(v.lower()) + return ",".join(result) @field_validator("importance") @classmethod @@ -111,21 +112,10 @@ class TEEconomicCalendarFetcher( query: TEEconomicCalendarQueryParams, credentials: Optional[Dict[str, str]], **kwargs: Any, - ) -> List[Dict]: + ) -> Union[dict, List[dict]]: """Return the raw data from the TE endpoint.""" api_key = credentials.get("tradingeconomics_api_key") if credentials else "" - if query.country is not None: - country = ( - query.country.split(",") if "," in query.country else query.country - ) - country = [country] if isinstance(country, str) else country - - for c in country: - if c.replace("_", " ").lower() not in country_list: - raise ValueError(f"{c} is not a valid country") - query.country = country - url = url_generator.generate_url(query) if not url: raise RuntimeError( @@ -133,7 +123,7 @@ class TEEconomicCalendarFetcher( ) url = f"{url}{api_key}" - async def callback(response: ClientResponse, _: Any) -> List[Dict]: + async def callback(response: ClientResponse, _: Any) -> Union[dict, List[dict]]: """Return the response.""" if response.status != 200: raise RuntimeError(f"Error in TE request -> {await response.text()}") diff --git a/openbb_platform/providers/tradingeconomics/openbb_tradingeconomics/utils/countries.py b/openbb_platform/providers/tradingeconomics/openbb_tradingeconomics/utils/countries.py index 56efb1e2626..120774e68c4 100644 --- a/openbb_platform/providers/tradingeconomics/openbb_tradingeconomics/utils/countries.py +++ b/openbb_platform/providers/tradingeconomics/openbb_tradingeconomics/utils/countries.py @@ -223,6 +223,10 @@ country_dict = { ], } -country_list = list( - set([item.lower() for sublist in country_dict.values() for item in sublist]) +COUNTRIES = list( + { + item.lower().replace(" ", "_") + for sublist in country_dict.values() + for item in sublist + } ) diff --git a/openbb_platform/providers/tradingeconomics/openbb_tradingeconomics/utils/url_generator.py b/openbb_platform/providers/tradingeconomics/openbb_tradingeconomics/utils/url_generator.py index 354bcf346fb..0bcb2a46005 100644 --- a/openbb_platform/providers/tradingeconomics/openbb_tradingeconomics/utils/url_generator.py +++ b/openbb_platform/providers/tradingeconomics/openbb_tradingeconomics/utils/url_generator.py @@ -35,7 +35,7 @@ def generate_url(in_query): # Handle the formatting for the api if "country" in query: - country = quote(",".join(query["country"]).replace("_", " ")) + country = quote(query["country"].replace("_", " ")) if "group" in query: group = quote(query["group"]) -- cgit v1.2.3