diff options
author | montezdesousa <79287829+montezdesousa@users.noreply.github.com> | 2024-03-06 19:30:36 +0000 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-03-06 19:30:36 +0000 |
commit | a8122e9c037a1ded402adc2fb80cf160fa45688a (patch) | |
tree | 74d0675a20ea438adb9d7a7f9c4cae33e9df286c | |
parent | 76556dfd28137c6ed0fd25ac883d8359edef7c37 (diff) |
[Bug fix] - Handle multiple items with arbitrary type (#6171)
* handle multiple items with arbitrary type
* minor fix
* ruff
* inequality
* integration tests
* test
* pylint
* fix tests
* fix category
* ruff
12 files changed, 118 insertions, 46 deletions
diff --git a/openbb_platform/core/openbb_core/app/static/package_builder.py b/openbb_platform/core/openbb_core/app/static/package_builder.py index abfad97eda6..924baa69347 100644 --- a/openbb_platform/core/openbb_core/app/static/package_builder.py +++ b/openbb_platform/core/openbb_core/app/static/package_builder.py @@ -571,7 +571,9 @@ class MethodDefinition: type_ = MethodDefinition.get_type(field) default = MethodDefinition.get_default(field) extra = MethodDefinition.get_extra(field) - new_type = MethodDefinition.get_expanded_type(field_name, extra) + new_type = MethodDefinition.get_expanded_type( + field_name, extra, type_ + ) updated_type = type_ if new_type is ... else Union[type_, new_type] formatted[field_name] = Parameter( @@ -782,10 +784,19 @@ class MethodDefinition: return code @classmethod - def get_expanded_type(cls, field_name: str, extra: Optional[dict] = None) -> object: + def get_expanded_type( + cls, + field_name: str, + extra: Optional[dict] = None, + original_type: Optional[type] = None, + ) -> object: """Expand the original field type.""" if extra and "multiple_items_allowed" in extra: - return List[str] + if original_type is None: + raise ValueError( + "multiple_items_allowed requires the original type to be specified." + ) + return List[original_type] return cls.TYPE_EXPANSION.get(field_name, ...) @classmethod diff --git a/openbb_platform/core/openbb_core/app/static/utils/decorators.py b/openbb_platform/core/openbb_core/app/static/utils/decorators.py index febb23cd3b5..18eadeffa69 100644 --- a/openbb_platform/core/openbb_core/app/static/utils/decorators.py +++ b/openbb_platform/core/openbb_core/app/static/utils/decorators.py @@ -80,8 +80,9 @@ def exception_handler(func: Callable[P, R]) -> Callable[P, R]: ).with_traceback(tb) from None # If the error is not a ValidationError, then it is a generic exception + error_type = getattr(e, "original", e).__class__.__name__ raise OpenBBError( - f"\nType -> {e.original.__class__.__name__}\n\nDetail -> {str(e)}" + f"\nType -> {error_type}\n\nDetail -> {str(e)}" ).with_traceback(tb) from None return wrapper diff --git a/openbb_platform/core/openbb_core/app/static/utils/filters.py b/openbb_platform/core/openbb_core/app/static/utils/filters.py index aaefb9a040f..36e9d3164db 100644 --- a/openbb_platform/core/openbb_core/app/static/utils/filters.py +++ b/openbb_platform/core/openbb_core/app/static/utils/filters.py @@ -29,7 +29,9 @@ def filter_inputs( if field in kwargs.get(p, {}): current = kwargs[p][field] new = ( - ",".join(current) if isinstance(current, list) else current + ",".join(map(str, current)) + if isinstance(current, list) + else current ) if provider and provider not in props[PROPERTY]: diff --git a/openbb_platform/core/openbb_core/provider/standard_models/spot.py b/openbb_platform/core/openbb_core/provider/standard_models/spot.py index f976fe9cdcc..4717a677bd7 100644 --- a/openbb_platform/core/openbb_core/provider/standard_models/spot.py +++ b/openbb_platform/core/openbb_core/provider/standard_models/spot.py @@ -3,9 +3,9 @@ from datetime import ( date as dateType, ) -from typing import List, Literal, Optional +from typing import Optional, Union -from pydantic import Field, field_validator +from pydantic import Field from openbb_core.provider.abstract.data import Data from openbb_core.provider.abstract.query_params import QueryParams @@ -26,25 +26,15 @@ class SpotRateQueryParams(QueryParams): default=None, description=QUERY_DESCRIPTIONS.get("end_date", ""), ) - maturity: List[float] = Field( - default=[10.0], description="The maturities in years." + maturity: Union[float, str] = Field( + default=10.0, description="Maturities in years." ) - category: List[Literal["par_yield", "spot_rate"]] = Field( - default=["spot_rate"], - description="The category.", + category: str = Field( + default="spot_rate", + description="Rate category. Options: spot_rate, par_yield.", + choices=["par_yield", "spot_rate"], ) - @field_validator("maturity") - @classmethod - def maturity_validate(cls, v): - """Validate maturity.""" - for i in v: - if not isinstance(i, float): - raise ValueError("`maturity` must be a float") - if not 1 <= i <= 100: - raise ValueError("`maturity` must be between 1 and 100") - return v - class SpotRateData(Data): """Spot Rate Data.""" diff --git a/openbb_platform/extensions/fixedincome/integration/test_fixedincome_api.py b/openbb_platform/extensions/fixedincome/integration/test_fixedincome_api.py index 87539ad0eef..95b759beaac 100644 --- a/openbb_platform/extensions/fixedincome/integration/test_fixedincome_api.py +++ b/openbb_platform/extensions/fixedincome/integration/test_fixedincome_api.py @@ -341,10 +341,26 @@ def test_fixedincome_corporate_commercial_paper(params, headers): "start_date": "2023-01-01", "end_date": "2023-06-06", "maturity": [10.0], - "category": ["spot_rate"], + "category": "spot_rate", "provider": "fred", } - ) + ), + ( + { + "start_date": None, + "end_date": None, + "maturity": 5.5, + "category": ["spot_rate"], + } + ), + ( + { + "start_date": None, + "end_date": None, + "maturity": "1,5.5,10", + "category": "spot_rate,par_yield", + } + ), ], ) @pytest.mark.integration diff --git a/openbb_platform/extensions/fixedincome/integration/test_fixedincome_python.py b/openbb_platform/extensions/fixedincome/integration/test_fixedincome_python.py index df53341fb85..40c73bf077d 100644 --- a/openbb_platform/extensions/fixedincome/integration/test_fixedincome_python.py +++ b/openbb_platform/extensions/fixedincome/integration/test_fixedincome_python.py @@ -312,10 +312,26 @@ def test_fixedincome_corporate_commercial_paper(params, obb): "start_date": "2023-01-01", "end_date": "2023-06-06", "maturity": [10.0], - "category": ["spot_rate"], + "category": "spot_rate", "provider": "fred", } - ) + ), + ( + { + "start_date": None, + "end_date": None, + "maturity": 5.5, + "category": ["spot_rate"], + } + ), + ( + { + "start_date": None, + "end_date": None, + "maturity": "1,5.5,10", + "category": "spot_rate,par_yield", + } + ), ], ) @pytest.mark.integration diff --git a/openbb_platform/extensions/tests/utils/helpers.py b/openbb_platform/extensions/tests/utils/helpers.py index 264bdb4958d..ad15ae75fb5 100644 --- a/openbb_platform/extensions/tests/utils/helpers.py +++ b/openbb_platform/extensions/tests/utils/helpers.py @@ -89,13 +89,13 @@ def list_openbb_extensions() -> Tuple[Set[str], Set[str], Set[str]]: obbject_extensions = set() entry_points_dict = entry_points() - for entry_point in entry_points_dict["openbb_core_extension"]: + for entry_point in entry_points_dict.get("openbb_core_extension", []): core_extensions.add(f"{entry_point.name}") - for entry_point in entry_points_dict["openbb_provider_extension"]: + for entry_point in entry_points_dict.get("openbb_provider_extension", []): provider_extensions.add(f"{entry_point.name}") - for entry_point in entry_points_dict["openbb_obbject_extension"]: + for entry_point in entry_points_dict.get("openbb_obbject_extension", []): obbject_extensions.add(f"{entry_point.name}") return core_extensions, provider_extensions, obbject_extensions diff --git a/openbb_platform/openbb/package/equity_estimates.py b/openbb_platform/openbb/package/equity_estimates.py index 1880f70d440..a1d21d83b4d 100644 --- a/openbb_platform/openbb/package/equity_estimates.py +++ b/openbb_platform/openbb/package/equity_estimates.py @@ -401,7 +401,7 @@ class ROUTER_equity_estimates(Container): def price_target( self, symbol: Annotated[ - Union[str, None, List[str]], + Union[str, None, List[Optional[str]]], OpenBBCustomParameter( description="Symbol to get data for. Multiple items allowed for provider(s): benzinga." ), @@ -417,7 +417,7 @@ class ROUTER_equity_estimates(Container): Parameters ---------- - symbol : Union[str, None, List[str]] + symbol : Union[str, None, List[Optional[str]]] Symbol to get data for. Multiple items allowed for provider(s): benzinga. limit : int The number of data entries to return. diff --git a/openbb_platform/openbb/package/fixedincome_corporate.py b/openbb_platform/openbb/package/fixedincome_corporate.py index 6d8d8778a8d..b651e91c163 100644 --- a/openbb_platform/openbb/package/fixedincome_corporate.py +++ b/openbb_platform/openbb/package/fixedincome_corporate.py @@ -419,12 +419,17 @@ class ROUTER_fixedincome_corporate(Container): ), ] = None, maturity: Annotated[ - List[float], OpenBBCustomParameter(description="The maturities in years.") - ] = [10.0], + Union[float, str, List[Union[float, str]]], + OpenBBCustomParameter( + description="Maturities in years. Multiple items allowed for provider(s): fred." + ), + ] = 10.0, category: Annotated[ - List[Literal["par_yield", "spot_rate"]], - OpenBBCustomParameter(description="The category."), - ] = ["spot_rate"], + Union[str, List[str]], + OpenBBCustomParameter( + description="Rate category. Options: spot_rate, par_yield. Multiple items allowed for provider(s): fred." + ), + ] = "spot_rate", provider: Optional[Literal["fred"]] = None, **kwargs ) -> OBBject: @@ -442,10 +447,10 @@ class ROUTER_fixedincome_corporate(Container): Start date of the data, in YYYY-MM-DD format. end_date : Union[datetime.date, None, str] End date of the data, in YYYY-MM-DD format. - maturity : List[float] - The maturities in years. - category : List[Literal['par_yield', 'spot_rate']] - The category. + maturity : Union[float, str, List[Union[float, str]]] + Maturities in years. Multiple items allowed for provider(s): fred. + category : Union[str, List[str]] + Rate category. Options: spot_rate, par_yield. Multiple items allowed for provider(s): fred. provider : Optional[Literal['fred']] The provider to use for the query, by default None. If None, the provider specified in defaults is selected or 'fred' if there is @@ -495,5 +500,12 @@ class ROUTER_fixedincome_corporate(Container): "category": category, }, extra_params=kwargs, + extra_info={ + "maturity": {"multiple_items_allowed": ["fred"]}, + "category": { + "choices": ["par_yield", "spot_rate"], + "multiple_items_allowed": ["fred"], + }, + }, ) ) diff --git a/openbb_platform/openbb/package/news.py b/openbb_platform/openbb/package/news.py index fcb394cb2ca..0f1f140b580 100644 --- a/openbb_platform/openbb/package/news.py +++ b/openbb_platform/openbb/package/news.py @@ -26,7 +26,7 @@ class ROUTER_news(Container): def company( self, symbol: Annotated[ - Union[str, None, List[str]], + Union[str, None, List[Optional[str]]], OpenBBCustomParameter( description="Symbol to get data for. This endpoint will accept multiple symbols separated by commas. Multiple items allowed for provider(s): benzinga, fmp, intrinio, polygon, tiingo, yfinance." ), @@ -56,7 +56,7 @@ class ROUTER_news(Container): Parameters ---------- - symbol : Union[str, None, List[str]] + symbol : Union[str, None, List[Optional[str]]] Symbol to get data for. This endpoint will accept multiple symbols separated by commas. Multiple items allowed for provider(s): benzinga, fmp, intrinio, polygon, tiingo, yfinance. start_date : Union[datetime.date, None, str] Start date of the data, in YYYY-MM-DD format. diff --git a/openbb_platform/providers/fred/openbb_fred/models/spot.py b/openbb_platform/providers/fred/openbb_fred/models/spot.py index 2dae3a230ec..4db9cebe4e0 100644 --- a/openbb_platform/providers/fred/openbb_fred/models/spot.py +++ b/openbb_platform/providers/fred/openbb_fred/models/spot.py @@ -8,13 +8,18 @@ from openbb_core.provider.standard_models.spot import ( SpotRateQueryParams, ) from openbb_fred.utils.fred_base import Fred -from openbb_fred.utils.fred_helpers import get_spot_series_id +from openbb_fred.utils.fred_helpers import comma_to_float_list, get_spot_series_id from pydantic import field_validator class FREDSpotRateQueryParams(SpotRateQueryParams): """FRED Spot Rate Query.""" + __json_schema_extra__ = { + "maturity": ["multiple_items_allowed"], + "category": ["multiple_items_allowed"], + } + class FREDSpotRateData(SpotRateData): """FRED Spot Rate Data.""" @@ -56,9 +61,17 @@ class FREDSpotRateFetcher( key = credentials.get("fred_api_key") if credentials else "" fred = Fred(key) + maturity = ( + comma_to_float_list(query.maturity) + if isinstance(query.maturity, str) + else [query.maturity] + ) + if any(1 > m > 100 for m in maturity): + raise ValueError("Maturity must be between 1 and 100") + series = get_spot_series_id( - maturity=query.maturity, - category=query.category, + maturity=maturity, + category=query.category.split(","), ) data = [] diff --git a/openbb_platform/providers/fred/openbb_fred/utils/fred_helpers.py b/openbb_platform/providers/fred/openbb_fred/utils/fred_helpers.py index 568017249ef..b61cc26625e 100644 --- a/openbb_platform/providers/fred/openbb_fred/utils/fred_helpers.py +++ b/openbb_platform/providers/fred/openbb_fred/utils/fred_helpers.py @@ -51,6 +51,16 @@ YIELD_CURVE_SERIES_CORPORATE_PAR = { } +def comma_to_float_list(v: str) -> List[float]: + """Convert comma-separated string to list of floats.""" + try: + return [float(m) for m in v.split(",")] + except ValueError as e: + raise ValueError( + "'maturity' must be a float or a comma-separated string of floats" + ) from e + + def all_cpi_options(harmonized: bool = False) -> List[dict]: """Get all CPI options.""" data = [] @@ -136,6 +146,7 @@ def get_ice_bofa_series_id( units = "index" if type_ == "total_return" else "percent" for s in series: + # pylint: disable=too-many-boolean-expressions if ( s["Type"] == type_ and s["Units"] == units |