summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authormontezdesousa <79287829+montezdesousa@users.noreply.github.com>2024-03-06 19:30:36 +0000
committerGitHub <noreply@github.com>2024-03-06 19:30:36 +0000
commita8122e9c037a1ded402adc2fb80cf160fa45688a (patch)
tree74d0675a20ea438adb9d7a7f9c4cae33e9df286c
parent76556dfd28137c6ed0fd25ac883d8359edef7c37 (diff)
[Bug fix] - Handle multiple items with arbitrary type (#6171)
* handle multiple items with arbitrary type * minor fix * ruff * inequality * integration tests * test * pylint * fix tests * fix category * ruff
-rw-r--r--openbb_platform/core/openbb_core/app/static/package_builder.py17
-rw-r--r--openbb_platform/core/openbb_core/app/static/utils/decorators.py3
-rw-r--r--openbb_platform/core/openbb_core/app/static/utils/filters.py4
-rw-r--r--openbb_platform/core/openbb_core/provider/standard_models/spot.py26
-rw-r--r--openbb_platform/extensions/fixedincome/integration/test_fixedincome_api.py20
-rw-r--r--openbb_platform/extensions/fixedincome/integration/test_fixedincome_python.py20
-rw-r--r--openbb_platform/extensions/tests/utils/helpers.py6
-rw-r--r--openbb_platform/openbb/package/equity_estimates.py4
-rw-r--r--openbb_platform/openbb/package/fixedincome_corporate.py30
-rw-r--r--openbb_platform/openbb/package/news.py4
-rw-r--r--openbb_platform/providers/fred/openbb_fred/models/spot.py19
-rw-r--r--openbb_platform/providers/fred/openbb_fred/utils/fred_helpers.py11
12 files changed, 118 insertions, 46 deletions
diff --git a/openbb_platform/core/openbb_core/app/static/package_builder.py b/openbb_platform/core/openbb_core/app/static/package_builder.py
index abfad97eda6..924baa69347 100644
--- a/openbb_platform/core/openbb_core/app/static/package_builder.py
+++ b/openbb_platform/core/openbb_core/app/static/package_builder.py
@@ -571,7 +571,9 @@ class MethodDefinition:
type_ = MethodDefinition.get_type(field)
default = MethodDefinition.get_default(field)
extra = MethodDefinition.get_extra(field)
- new_type = MethodDefinition.get_expanded_type(field_name, extra)
+ new_type = MethodDefinition.get_expanded_type(
+ field_name, extra, type_
+ )
updated_type = type_ if new_type is ... else Union[type_, new_type]
formatted[field_name] = Parameter(
@@ -782,10 +784,19 @@ class MethodDefinition:
return code
@classmethod
- def get_expanded_type(cls, field_name: str, extra: Optional[dict] = None) -> object:
+ def get_expanded_type(
+ cls,
+ field_name: str,
+ extra: Optional[dict] = None,
+ original_type: Optional[type] = None,
+ ) -> object:
"""Expand the original field type."""
if extra and "multiple_items_allowed" in extra:
- return List[str]
+ if original_type is None:
+ raise ValueError(
+ "multiple_items_allowed requires the original type to be specified."
+ )
+ return List[original_type]
return cls.TYPE_EXPANSION.get(field_name, ...)
@classmethod
diff --git a/openbb_platform/core/openbb_core/app/static/utils/decorators.py b/openbb_platform/core/openbb_core/app/static/utils/decorators.py
index febb23cd3b5..18eadeffa69 100644
--- a/openbb_platform/core/openbb_core/app/static/utils/decorators.py
+++ b/openbb_platform/core/openbb_core/app/static/utils/decorators.py
@@ -80,8 +80,9 @@ def exception_handler(func: Callable[P, R]) -> Callable[P, R]:
).with_traceback(tb) from None
# If the error is not a ValidationError, then it is a generic exception
+ error_type = getattr(e, "original", e).__class__.__name__
raise OpenBBError(
- f"\nType -> {e.original.__class__.__name__}\n\nDetail -> {str(e)}"
+ f"\nType -> {error_type}\n\nDetail -> {str(e)}"
).with_traceback(tb) from None
return wrapper
diff --git a/openbb_platform/core/openbb_core/app/static/utils/filters.py b/openbb_platform/core/openbb_core/app/static/utils/filters.py
index aaefb9a040f..36e9d3164db 100644
--- a/openbb_platform/core/openbb_core/app/static/utils/filters.py
+++ b/openbb_platform/core/openbb_core/app/static/utils/filters.py
@@ -29,7 +29,9 @@ def filter_inputs(
if field in kwargs.get(p, {}):
current = kwargs[p][field]
new = (
- ",".join(current) if isinstance(current, list) else current
+ ",".join(map(str, current))
+ if isinstance(current, list)
+ else current
)
if provider and provider not in props[PROPERTY]:
diff --git a/openbb_platform/core/openbb_core/provider/standard_models/spot.py b/openbb_platform/core/openbb_core/provider/standard_models/spot.py
index f976fe9cdcc..4717a677bd7 100644
--- a/openbb_platform/core/openbb_core/provider/standard_models/spot.py
+++ b/openbb_platform/core/openbb_core/provider/standard_models/spot.py
@@ -3,9 +3,9 @@
from datetime import (
date as dateType,
)
-from typing import List, Literal, Optional
+from typing import Optional, Union
-from pydantic import Field, field_validator
+from pydantic import Field
from openbb_core.provider.abstract.data import Data
from openbb_core.provider.abstract.query_params import QueryParams
@@ -26,25 +26,15 @@ class SpotRateQueryParams(QueryParams):
default=None,
description=QUERY_DESCRIPTIONS.get("end_date", ""),
)
- maturity: List[float] = Field(
- default=[10.0], description="The maturities in years."
+ maturity: Union[float, str] = Field(
+ default=10.0, description="Maturities in years."
)
- category: List[Literal["par_yield", "spot_rate"]] = Field(
- default=["spot_rate"],
- description="The category.",
+ category: str = Field(
+ default="spot_rate",
+ description="Rate category. Options: spot_rate, par_yield.",
+ choices=["par_yield", "spot_rate"],
)
- @field_validator("maturity")
- @classmethod
- def maturity_validate(cls, v):
- """Validate maturity."""
- for i in v:
- if not isinstance(i, float):
- raise ValueError("`maturity` must be a float")
- if not 1 <= i <= 100:
- raise ValueError("`maturity` must be between 1 and 100")
- return v
-
class SpotRateData(Data):
"""Spot Rate Data."""
diff --git a/openbb_platform/extensions/fixedincome/integration/test_fixedincome_api.py b/openbb_platform/extensions/fixedincome/integration/test_fixedincome_api.py
index 87539ad0eef..95b759beaac 100644
--- a/openbb_platform/extensions/fixedincome/integration/test_fixedincome_api.py
+++ b/openbb_platform/extensions/fixedincome/integration/test_fixedincome_api.py
@@ -341,10 +341,26 @@ def test_fixedincome_corporate_commercial_paper(params, headers):
"start_date": "2023-01-01",
"end_date": "2023-06-06",
"maturity": [10.0],
- "category": ["spot_rate"],
+ "category": "spot_rate",
"provider": "fred",
}
- )
+ ),
+ (
+ {
+ "start_date": None,
+ "end_date": None,
+ "maturity": 5.5,
+ "category": ["spot_rate"],
+ }
+ ),
+ (
+ {
+ "start_date": None,
+ "end_date": None,
+ "maturity": "1,5.5,10",
+ "category": "spot_rate,par_yield",
+ }
+ ),
],
)
@pytest.mark.integration
diff --git a/openbb_platform/extensions/fixedincome/integration/test_fixedincome_python.py b/openbb_platform/extensions/fixedincome/integration/test_fixedincome_python.py
index df53341fb85..40c73bf077d 100644
--- a/openbb_platform/extensions/fixedincome/integration/test_fixedincome_python.py
+++ b/openbb_platform/extensions/fixedincome/integration/test_fixedincome_python.py
@@ -312,10 +312,26 @@ def test_fixedincome_corporate_commercial_paper(params, obb):
"start_date": "2023-01-01",
"end_date": "2023-06-06",
"maturity": [10.0],
- "category": ["spot_rate"],
+ "category": "spot_rate",
"provider": "fred",
}
- )
+ ),
+ (
+ {
+ "start_date": None,
+ "end_date": None,
+ "maturity": 5.5,
+ "category": ["spot_rate"],
+ }
+ ),
+ (
+ {
+ "start_date": None,
+ "end_date": None,
+ "maturity": "1,5.5,10",
+ "category": "spot_rate,par_yield",
+ }
+ ),
],
)
@pytest.mark.integration
diff --git a/openbb_platform/extensions/tests/utils/helpers.py b/openbb_platform/extensions/tests/utils/helpers.py
index 264bdb4958d..ad15ae75fb5 100644
--- a/openbb_platform/extensions/tests/utils/helpers.py
+++ b/openbb_platform/extensions/tests/utils/helpers.py
@@ -89,13 +89,13 @@ def list_openbb_extensions() -> Tuple[Set[str], Set[str], Set[str]]:
obbject_extensions = set()
entry_points_dict = entry_points()
- for entry_point in entry_points_dict["openbb_core_extension"]:
+ for entry_point in entry_points_dict.get("openbb_core_extension", []):
core_extensions.add(f"{entry_point.name}")
- for entry_point in entry_points_dict["openbb_provider_extension"]:
+ for entry_point in entry_points_dict.get("openbb_provider_extension", []):
provider_extensions.add(f"{entry_point.name}")
- for entry_point in entry_points_dict["openbb_obbject_extension"]:
+ for entry_point in entry_points_dict.get("openbb_obbject_extension", []):
obbject_extensions.add(f"{entry_point.name}")
return core_extensions, provider_extensions, obbject_extensions
diff --git a/openbb_platform/openbb/package/equity_estimates.py b/openbb_platform/openbb/package/equity_estimates.py
index 1880f70d440..a1d21d83b4d 100644
--- a/openbb_platform/openbb/package/equity_estimates.py
+++ b/openbb_platform/openbb/package/equity_estimates.py
@@ -401,7 +401,7 @@ class ROUTER_equity_estimates(Container):
def price_target(
self,
symbol: Annotated[
- Union[str, None, List[str]],
+ Union[str, None, List[Optional[str]]],
OpenBBCustomParameter(
description="Symbol to get data for. Multiple items allowed for provider(s): benzinga."
),
@@ -417,7 +417,7 @@ class ROUTER_equity_estimates(Container):
Parameters
----------
- symbol : Union[str, None, List[str]]
+ symbol : Union[str, None, List[Optional[str]]]
Symbol to get data for. Multiple items allowed for provider(s): benzinga.
limit : int
The number of data entries to return.
diff --git a/openbb_platform/openbb/package/fixedincome_corporate.py b/openbb_platform/openbb/package/fixedincome_corporate.py
index 6d8d8778a8d..b651e91c163 100644
--- a/openbb_platform/openbb/package/fixedincome_corporate.py
+++ b/openbb_platform/openbb/package/fixedincome_corporate.py
@@ -419,12 +419,17 @@ class ROUTER_fixedincome_corporate(Container):
),
] = None,
maturity: Annotated[
- List[float], OpenBBCustomParameter(description="The maturities in years.")
- ] = [10.0],
+ Union[float, str, List[Union[float, str]]],
+ OpenBBCustomParameter(
+ description="Maturities in years. Multiple items allowed for provider(s): fred."
+ ),
+ ] = 10.0,
category: Annotated[
- List[Literal["par_yield", "spot_rate"]],
- OpenBBCustomParameter(description="The category."),
- ] = ["spot_rate"],
+ Union[str, List[str]],
+ OpenBBCustomParameter(
+ description="Rate category. Options: spot_rate, par_yield. Multiple items allowed for provider(s): fred."
+ ),
+ ] = "spot_rate",
provider: Optional[Literal["fred"]] = None,
**kwargs
) -> OBBject:
@@ -442,10 +447,10 @@ class ROUTER_fixedincome_corporate(Container):
Start date of the data, in YYYY-MM-DD format.
end_date : Union[datetime.date, None, str]
End date of the data, in YYYY-MM-DD format.
- maturity : List[float]
- The maturities in years.
- category : List[Literal['par_yield', 'spot_rate']]
- The category.
+ maturity : Union[float, str, List[Union[float, str]]]
+ Maturities in years. Multiple items allowed for provider(s): fred.
+ category : Union[str, List[str]]
+ Rate category. Options: spot_rate, par_yield. Multiple items allowed for provider(s): fred.
provider : Optional[Literal['fred']]
The provider to use for the query, by default None.
If None, the provider specified in defaults is selected or 'fred' if there is
@@ -495,5 +500,12 @@ class ROUTER_fixedincome_corporate(Container):
"category": category,
},
extra_params=kwargs,
+ extra_info={
+ "maturity": {"multiple_items_allowed": ["fred"]},
+ "category": {
+ "choices": ["par_yield", "spot_rate"],
+ "multiple_items_allowed": ["fred"],
+ },
+ },
)
)
diff --git a/openbb_platform/openbb/package/news.py b/openbb_platform/openbb/package/news.py
index fcb394cb2ca..0f1f140b580 100644
--- a/openbb_platform/openbb/package/news.py
+++ b/openbb_platform/openbb/package/news.py
@@ -26,7 +26,7 @@ class ROUTER_news(Container):
def company(
self,
symbol: Annotated[
- Union[str, None, List[str]],
+ Union[str, None, List[Optional[str]]],
OpenBBCustomParameter(
description="Symbol to get data for. This endpoint will accept multiple symbols separated by commas. Multiple items allowed for provider(s): benzinga, fmp, intrinio, polygon, tiingo, yfinance."
),
@@ -56,7 +56,7 @@ class ROUTER_news(Container):
Parameters
----------
- symbol : Union[str, None, List[str]]
+ symbol : Union[str, None, List[Optional[str]]]
Symbol to get data for. This endpoint will accept multiple symbols separated by commas. Multiple items allowed for provider(s): benzinga, fmp, intrinio, polygon, tiingo, yfinance.
start_date : Union[datetime.date, None, str]
Start date of the data, in YYYY-MM-DD format.
diff --git a/openbb_platform/providers/fred/openbb_fred/models/spot.py b/openbb_platform/providers/fred/openbb_fred/models/spot.py
index 2dae3a230ec..4db9cebe4e0 100644
--- a/openbb_platform/providers/fred/openbb_fred/models/spot.py
+++ b/openbb_platform/providers/fred/openbb_fred/models/spot.py
@@ -8,13 +8,18 @@ from openbb_core.provider.standard_models.spot import (
SpotRateQueryParams,
)
from openbb_fred.utils.fred_base import Fred
-from openbb_fred.utils.fred_helpers import get_spot_series_id
+from openbb_fred.utils.fred_helpers import comma_to_float_list, get_spot_series_id
from pydantic import field_validator
class FREDSpotRateQueryParams(SpotRateQueryParams):
"""FRED Spot Rate Query."""
+ __json_schema_extra__ = {
+ "maturity": ["multiple_items_allowed"],
+ "category": ["multiple_items_allowed"],
+ }
+
class FREDSpotRateData(SpotRateData):
"""FRED Spot Rate Data."""
@@ -56,9 +61,17 @@ class FREDSpotRateFetcher(
key = credentials.get("fred_api_key") if credentials else ""
fred = Fred(key)
+ maturity = (
+ comma_to_float_list(query.maturity)
+ if isinstance(query.maturity, str)
+ else [query.maturity]
+ )
+ if any(1 > m > 100 for m in maturity):
+ raise ValueError("Maturity must be between 1 and 100")
+
series = get_spot_series_id(
- maturity=query.maturity,
- category=query.category,
+ maturity=maturity,
+ category=query.category.split(","),
)
data = []
diff --git a/openbb_platform/providers/fred/openbb_fred/utils/fred_helpers.py b/openbb_platform/providers/fred/openbb_fred/utils/fred_helpers.py
index 568017249ef..b61cc26625e 100644
--- a/openbb_platform/providers/fred/openbb_fred/utils/fred_helpers.py
+++ b/openbb_platform/providers/fred/openbb_fred/utils/fred_helpers.py
@@ -51,6 +51,16 @@ YIELD_CURVE_SERIES_CORPORATE_PAR = {
}
+def comma_to_float_list(v: str) -> List[float]:
+ """Convert comma-separated string to list of floats."""
+ try:
+ return [float(m) for m in v.split(",")]
+ except ValueError as e:
+ raise ValueError(
+ "'maturity' must be a float or a comma-separated string of floats"
+ ) from e
+
+
def all_cpi_options(harmonized: bool = False) -> List[dict]:
"""Get all CPI options."""
data = []
@@ -136,6 +146,7 @@ def get_ice_bofa_series_id(
units = "index" if type_ == "total_return" else "percent"
for s in series:
+ # pylint: disable=too-many-boolean-expressions
if (
s["Type"] == type_
and s["Units"] == units