summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorteh_coderer <me@tehcoderer.com>2024-02-08 04:58:14 -0500
committerGitHub <noreply@github.com>2024-02-08 09:58:14 +0000
commit7c68832379bf9869517114732d2888f5e10c436a (patch)
tree30aec21d2f23bc5d6d696e9427acfebd0cc0a37d
parent7d668b10c87a8d8909459bc7ee0e9f28b981da15 (diff)
improve discriminator logic, fix package return type docs (#6052)
* improve discriminator logic, fix package return type docs * Update registry_map.py * build package * Update registry_map.py * defaults * Update registry_map.py --------- Co-authored-by: Igor Radovanovic <74266147+IgorWounds@users.noreply.github.com>
-rw-r--r--openbb_platform/core/openbb_core/app/model/obbject.py32
-rw-r--r--openbb_platform/core/openbb_core/app/router.py69
-rw-r--r--openbb_platform/core/openbb_core/provider/abstract/data.py14
-rw-r--r--openbb_platform/core/openbb_core/provider/registry_map.py52
-rw-r--r--openbb_platform/openbb/package/crypto.py2
-rw-r--r--openbb_platform/openbb/package/crypto_price.py2
-rw-r--r--openbb_platform/openbb/package/currency.py2
-rw-r--r--openbb_platform/openbb/package/currency_price.py2
-rw-r--r--openbb_platform/openbb/package/derivatives_options.py4
-rw-r--r--openbb_platform/openbb/package/economy.py20
-rw-r--r--openbb_platform/openbb/package/economy_gdp.py6
-rw-r--r--openbb_platform/openbb/package/equity.py8
-rw-r--r--openbb_platform/openbb/package/equity_calendar.py8
-rw-r--r--openbb_platform/openbb/package/equity_compare.py2
-rw-r--r--openbb_platform/openbb/package/equity_discovery.py16
-rw-r--r--openbb_platform/openbb/package/equity_estimates.py6
-rw-r--r--openbb_platform/openbb/package/equity_fundamental.py50
-rw-r--r--openbb_platform/openbb/package/equity_ownership.py8
-rw-r--r--openbb_platform/openbb/package/equity_price.py8
-rw-r--r--openbb_platform/openbb/package/equity_shorts.py2
-rw-r--r--openbb_platform/openbb/package/etf.py18
-rw-r--r--openbb_platform/openbb/package/fixedincome.py2
-rw-r--r--openbb_platform/openbb/package/fixedincome_corporate.py10
-rw-r--r--openbb_platform/openbb/package/fixedincome_government.py4
-rw-r--r--openbb_platform/openbb/package/fixedincome_rate.py16
-rw-r--r--openbb_platform/openbb/package/fixedincome_spreads.py6
-rw-r--r--openbb_platform/openbb/package/index.py6
-rw-r--r--openbb_platform/openbb/package/news.py4
-rw-r--r--openbb_platform/openbb/package/regulators_sec.py12
-rw-r--r--openbb_platform/providers/fmp/openbb_fmp/models/etf_info.py28
30 files changed, 196 insertions, 223 deletions
diff --git a/openbb_platform/core/openbb_core/app/model/obbject.py b/openbb_platform/core/openbb_core/app/model/obbject.py
index 3f87c6d5e20..30a21b61c9a 100644
--- a/openbb_platform/core/openbb_core/app/model/obbject.py
+++ b/openbb_platform/core/openbb_core/app/model/obbject.py
@@ -73,20 +73,19 @@ class OBBject(Tagged, Generic[T]):
@classmethod
def results_type_repr(cls, params: Optional[Any] = None) -> str:
"""Return the results type name."""
- type_ = params[0] if params else cls.model_fields["results"].annotation
+ results_field = cls.model_fields.get("results")
+ type_ = params[0] if params else results_field.annotation
name = type_.__name__ if hasattr(type_, "__name__") else str(type_)
- if "Annotated" in str(type_):
- annotated_inners = []
- for inner in cls.model_fields["results"].annotation.__args__:
- if hasattr(inner, "__args__") and inner.__args__[0] == type(None):
- continue
+ if (json_schema_extra := results_field.json_schema_extra) is not None:
+ model = json_schema_extra.get("model")
- annotated_inners.append(
- inner.__args__[0] if hasattr(inner, "__args__") else inner
- )
+ if json_schema_extra.get("is_union"):
+ return f"Union[List[{model}], {model}]"
+ if json_schema_extra.get("has_list"):
+ return f"List[{model}]"
- type_ = Union[tuple(annotated_inners)] if len(annotated_inners) > 1 else annotated_inners[0] # type: ignore
+ return model
if "typing." in str(type_):
unpack_optional = sub(r"Optional\[(.*)\]", r"\1", str(type_))
@@ -323,15 +322,4 @@ class OBBject(Tagged, Generic[T]):
OBBject[ResultsType]
OBBject with results.
"""
- data: List[BaseModel] = await query.execute()
-
- if isinstance(data, list) and data:
- if isinstance(data[0], BaseModel):
- data[0] = data[0].model_copy(update={"provider": query.provider})
- if isinstance(data[0], dict):
- data[0] = data[0].update({"provider": query.provider})
-
- if isinstance(data, BaseModel):
- data = data.model_copy(update={"provider": query.provider})
-
- return cls(results=data)
+ return cls(results=await query.execute())
diff --git a/openbb_platform/core/openbb_core/app/router.py b/openbb_platform/core/openbb_core/app/router.py
index 6e2e23de9b7..038697fdd83 100644
--- a/openbb_platform/core/openbb_core/app/router.py
+++ b/openbb_platform/core/openbb_core/app/router.py
@@ -20,7 +20,7 @@ from typing import (
)
from fastapi import APIRouter, Depends
-from pydantic import BaseModel, Discriminator, Field, Tag, create_model
+from pydantic import BaseModel, Field, SerializeAsAny, Tag, create_model
from pydantic.v1.validators import find_validators
from typing_extensions import Annotated, ParamSpec, _AnnotatedAlias
@@ -41,17 +41,6 @@ from openbb_core.env import Env
P = ParamSpec("P")
-def model_t_discriminator(v: Any) -> str:
- """Discriminator function for the results field."""
- if isinstance(v, dict):
- return v.get("provider", "openbb")
-
- if isinstance(v, list) and v:
- return model_t_discriminator(v[0])
-
- return getattr(v, "provider", "openbb")
-
-
class OpenBBErrorResponse(BaseModel):
"""OpenBB Error Response."""
@@ -389,7 +378,7 @@ class SignatureInspector:
@staticmethod
def inject_return_type(
func: Callable[P, OBBject],
- return_map: Dict[str, Any],
+ return_map: Dict[str, dict],
model: str,
) -> Callable[P, OBBject]:
"""
@@ -397,34 +386,50 @@ class SignatureInspector:
Also updates __name__ and __doc__ for API schemas.
"""
- union_models = [Annotated[Union[list, dict], Tag("openbb")]]
+ results: Dict[str, Any] = {"list_type": [], "dict_type": []}
+
+ for provider, return_data in return_map.items():
+ if return_data["is_list"]:
+ results["list_type"].append(
+ Annotated[return_data["model"], Tag(provider)]
+ )
+ continue
+
+ results["dict_type"].append(Annotated[return_data["model"], Tag(provider)])
+
+ list_models, union_models = results.values()
+
+ return_types = []
+ for t, v in results.items():
+ if not v:
+ continue
- for provider, return_type in return_map.items():
- union_models.append(Annotated[return_type, Tag(provider)])
+ inner_type = SerializeAsAny[
+ Annotated[
+ Union[tuple(v)], # type: ignore
+ Field(discriminator="provider"),
+ ]
+ ]
+ return_types.append(List[inner_type] if t == "list_type" else inner_type)
return_type = create_model(
f"OBBject_{model}",
__base__=OBBject,
+ __doc__=f"OBBject with results of type {model}",
results=(
- Optional[
- Annotated[
- Union[tuple(union_models)], # type: ignore
- Field(
- None,
- description="Serializable results.",
- discriminator=Discriminator(model_t_discriminator),
- ),
- ]
- ],
- Field(None, description="Serializable results."),
+ Optional[Union[tuple(return_types)]], # type: ignore
+ Field(
+ None,
+ description="Serializable results.",
+ json_schema_extra={
+ "model": model,
+ "has_list": bool(len(list_models) > 0),
+ "is_union": bool(list_models and union_models),
+ },
+ ),
),
)
- return_type.__name__ = f"OBBject_{model.replace('Data', '')}"
- return_type.__doc__ = (
- f"OBBject with results of type {model.replace('Data', '')}"
- )
-
func.__annotations__["return"] = return_type
return func
diff --git a/openbb_platform/core/openbb_core/provider/abstract/data.py b/openbb_platform/core/openbb_core/provider/abstract/data.py
index 6fd7b3a183f..80731a16537 100644
--- a/openbb_platform/core/openbb_core/provider/abstract/data.py
+++ b/openbb_platform/core/openbb_core/provider/abstract/data.py
@@ -1,12 +1,12 @@
"""The OpenBB Standardized Data Model."""
-from typing import Dict, Literal
+from typing import Dict
from pydantic import (
+ AliasGenerator,
BaseModel,
BeforeValidator,
ConfigDict,
- Field,
alias_generators,
model_validator,
)
@@ -70,11 +70,6 @@ class Data(BaseModel):
"""
__alias_dict__: Dict[str, str] = {}
- provider: Literal["openbb"] = Field(
- "openbb",
- description="The data provider for the data.",
- exclude=True,
- )
def __repr__(self):
"""Return a string representation of the object."""
@@ -83,8 +78,11 @@ class Data(BaseModel):
model_config = ConfigDict(
extra="allow",
populate_by_name=True,
- alias_generator=alias_generators.to_camel,
strict=False,
+ alias_generator=AliasGenerator(
+ validation_alias=alias_generators.to_camel,
+ serialization_alias=alias_generators.to_snake,
+ ),
)
@model_validator(mode="before")
diff --git a/openbb_platform/core/openbb_core/provider/registry_map.py b/openbb_platform/core/openbb_core/provider/registry_map.py
index 5a2cd8e7c51..be433e41a29 100644
--- a/openbb_platform/core/openbb_core/provider/registry_map.py
+++ b/openbb_platform/core/openbb_core/provider/registry_map.py
@@ -1,10 +1,11 @@
"""Provider registry map."""
+import sys
from inspect import getfile, isclass
from pathlib import Path
from typing import Any, Dict, List, Literal, Optional, Tuple, get_origin
-from pydantic import BaseModel, ConfigDict, Field, alias_generators, create_model
+from pydantic import BaseModel, Field, create_model
from openbb_core.provider.abstract.data import Data
from openbb_core.provider.abstract.fetcher import Fetcher
@@ -94,7 +95,7 @@ class RegistryMap:
is_list = get_origin(self.extract_return_type(fetcher)) == list
return_schemas.setdefault(model_name, {}).update(
- {p: List[provider_model] if is_list else provider_model}
+ {p: {"model": provider_model, "is_list": is_list}}
)
return map_, return_schemas
@@ -115,49 +116,30 @@ class RegistryMap:
fields = {}
for field_name, field in model.model_fields.items():
- field.alias_priority = None
+ field.serialization_alias = field_name
fields[field_name] = (field.annotation, field)
- fields.pop("provider", None)
-
- return create_model(
- model.__name__.replace("Data", ""),
- __doc__=model.__doc__,
- __config__=ConfigDict(
- extra="allow",
- alias_generator=alias_generators.to_snake,
- populate_by_name=True,
- ),
- provider=(
- Literal[provider_str, "openbb"], # type: ignore
- Field(
- default=provider_str,
- description="The data provider for the data.",
- exclude=True,
- ),
+ fields["provider"] = (
+ Literal[provider_str], # type: ignore
+ Field(
+ default=provider_str,
+ description="The data provider for the data.",
+ exclude=True,
),
- **fields,
)
- @staticmethod
- def extract_query_model(fetcher: Fetcher, provider: str) -> BaseModel:
- """Extract info (fields and docstring) from fetcher query params or data."""
- model: BaseModel = RegistryMap._get_model(fetcher, "query_params")
-
provider_model = create_model(
- model.__name__,
+ model.__name__.replace("Data", ""),
__base__=model,
+ __doc__=model.__doc__,
__module__=model.__module__,
- provider=(
- Literal[provider], # type: ignore
- Field(
- default=provider,
- description="The data provider for the data.",
- exclude=True,
- ),
- ),
+ **fields,
)
+ # Replace the provider models in the modules with the new models we created
+ # To make sure provider field is defined to be the provider string
+ setattr(sys.modules[model.__module__], model.__name__, provider_model)
+
return provider_model
@staticmethod
diff --git a/openbb_platform/openbb/package/crypto.py b/openbb_platform/openbb/package/crypto.py
index 7bb5228c299..88f66999759 100644
--- a/openbb_platform/openbb/package/crypto.py
+++ b/openbb_platform/openbb/package/crypto.py
@@ -49,7 +49,7 @@ class ROUTER_crypto(Container):
Returns
-------
OBBject
- results : Union[Annotated[Union[list, dict], Tag(tag='openbb')], Annotated[List[FMPCryptoSearch], Tag(tag='fmp')]]
+ results : List[CryptoSearch]
Serializable results.
provider : Optional[Literal['fmp']]
Provider name.
diff --git a/openbb_platform/openbb/package/crypto_price.py b/openbb_platform/openbb/package/crypto_price.py
index 42ae9ca3dba..549d3b2f60c 100644
--- a/openbb_platform/openbb/package/crypto_price.py
+++ b/openbb_platform/openbb/package/crypto_price.py
@@ -73,7 +73,7 @@ class ROUTER_crypto_price(Container):
Returns
-------
OBBject
- results : Union[Annotated[Union[list, dict], Tag(tag='openbb')], Annotated[List[FMPCryptoHistorical], Tag(tag='fmp')], Annotated[List[PolygonCryptoHistorical], Tag(tag='polygon')], Annotated[List[TiingoCryptoHistorical], Tag(tag='tiingo')], Annotated[List[YFinanceCryptoHistorical], Tag(tag='yfinance')]]
+ results : List[CryptoHistorical]
Serializable results.
provider : Optional[Literal['fmp', 'polygon', 'tiingo', 'yfinance']]
Provider name.
diff --git a/openbb_platform/openbb/package/currency.py b/openbb_platform/openbb/package/currency.py
index b9f97368152..ea4a2c1f5f5 100644
--- a/openbb_platform/openbb/package/currency.py
+++ b/openbb_platform/openbb/package/currency.py
@@ -64,7 +64,7 @@ class ROUTER_currency(Container):
Returns
-------
OBBject
- results : Union[Annotated[Union[list, dict], Tag(tag='openbb')], Annotated[List[FMPCurrencyPairs], Tag(tag='fmp')], Annotated[List[IntrinioCurrencyPairs], Tag(tag='intrinio')], Annotated[List[PolygonCurrencyPairs], Tag(tag='polygon')]]
+ results : List[CurrencyPairs]
Serializable results.
provider : Optional[Literal['fmp', 'intrinio', 'polygon']]
Provider name.
diff --git a/openbb_platform/openbb/package/currency_price.py b/openbb_platform/openbb/package/currency_price.py
index 0301e483613..c3a0562390a 100644
--- a/openbb_platform/openbb/package/currency_price.py
+++ b/openbb_platform/openbb/package/currency_price.py
@@ -77,7 +77,7 @@ class ROUTER_currency_price(Container):
Returns
-------
OBBject
- results : Union[Annotated[Union[list, dict], Tag(tag='openbb')], Annotated[List[FMPCurrencyHistorical], Tag(tag='fmp')], Annotated[List[PolygonCurrencyHistorical], Tag(tag='polygon')], Annotated[List[TiingoCurrencyHistorical], Tag(tag='tiingo')], Annotated[List[YFinanceCurrencyHistorical], Tag(tag='yfinance')]]
+ results : List[CurrencyHistorical]
Serializable results.
provider : Optional[Literal['fmp', 'polygon', 'tiingo', 'yfinance']]
Provider name.
diff --git a/openbb_platform/openbb/package/derivatives_options.py b/openbb_platform/openbb/package/derivatives_options.py
index 5e5bba0ae17..43077b374bb 100644
--- a/openbb_platform/openbb/package/derivatives_options.py
+++ b/openbb_platform/openbb/package/derivatives_options.py
@@ -45,7 +45,7 @@ class ROUTER_derivatives_options(Container):
Returns
-------
OBBject
- results : Union[Annotated[Union[list, dict], Tag(tag='openbb')], Annotated[List[IntrinioOptionsChains], Tag(tag='intrinio')]]
+ results : List[OptionsChains]
Serializable results.
provider : Optional[Literal['intrinio']]
Provider name.
@@ -194,7 +194,7 @@ class ROUTER_derivatives_options(Container):
Returns
-------
OBBject
- results : Union[Annotated[Union[list, dict], Tag(tag='openbb')], Annotated[List[IntrinioOptionsUnusual], Tag(tag='intrinio')]]
+ results : List[OptionsUnusual]
Serializable results.
provider : Optional[Literal['intrinio']]
Provider name.
diff --git a/openbb_platform/openbb/package/economy.py b/openbb_platform/openbb/package/economy.py
index e7784ce4b95..95cb79a8422 100644
--- a/openbb_platform/openbb/package/economy.py
+++ b/openbb_platform/openbb/package/economy.py
@@ -69,7 +69,7 @@ class ROUTER_economy(Container):
Returns
-------
OBBject
- results : Union[Annotated[Union[list, dict], Tag(tag='openbb')], Annotated[List[FMPEconomicCalendar], Tag(tag='fmp')], Annotated[List[TEEconomicCalendar], Tag(tag='tradingeconomics')]]
+ results : List[EconomicCalendar]
Serializable results.
provider : Optional[Literal['fmp', 'tradingeconomics']]
Provider name.
@@ -178,7 +178,7 @@ class ROUTER_economy(Container):
Returns
-------
OBBject
- results : Union[Annotated[Union[list, dict], Tag(tag='openbb')], Annotated[List[OECDCLI], Tag(tag='oecd')]]
+ results : List[CLI]
Serializable results.
provider : Optional[Literal['oecd']]
Provider name.
@@ -341,7 +341,7 @@ class ROUTER_economy(Container):
Returns
-------
OBBject
- results : Union[Annotated[Union[list, dict], Tag(tag='openbb')], Annotated[List[FREDConsumerPriceIndex], Tag(tag='fred')]]
+ results : List[ConsumerPriceIndex]
Serializable results.
provider : Optional[Literal['fred']]
Provider name.
@@ -424,7 +424,7 @@ class ROUTER_economy(Container):
Returns
-------
OBBject
- results : Union[Annotated[Union[list, dict], Tag(tag='openbb')], Annotated[List[FredSearch], Tag(tag='fred')]]
+ results : List[FredSearch]
Serializable results.
provider : Optional[Literal['fred']]
Provider name.
@@ -583,7 +583,7 @@ class ROUTER_economy(Container):
Returns
-------
OBBject
- results : Union[Annotated[Union[list, dict], Tag(tag='openbb')], Annotated[List[FredSeries], Tag(tag='fred')], Annotated[List[IntrinioFredSeries], Tag(tag='intrinio')]]
+ results : List[FredSeries]
Serializable results.
provider : Optional[Literal['fred', 'intrinio']]
Provider name.
@@ -677,7 +677,7 @@ class ROUTER_economy(Container):
Returns
-------
OBBject
- results : Union[Annotated[Union[list, dict], Tag(tag='openbb')], Annotated[List[OECDSTIR], Tag(tag='oecd')]]
+ results : List[STIR]
Serializable results.
provider : Optional[Literal['oecd']]
Provider name.
@@ -759,7 +759,7 @@ class ROUTER_economy(Container):
Returns
-------
OBBject
- results : Union[Annotated[Union[list, dict], Tag(tag='openbb')], Annotated[List[FederalReserveMoneyMeasures], Tag(tag='federal_reserve')]]
+ results : List[MoneyMeasures]
Serializable results.
provider : Optional[Literal['federal_reserve']]
Provider name.
@@ -826,7 +826,7 @@ class ROUTER_economy(Container):
Returns
-------
OBBject
- results : Union[Annotated[Union[list, dict], Tag(tag='openbb')], Annotated[List[FMPRiskPremium], Tag(tag='fmp')]]
+ results : List[RiskPremium]
Serializable results.
provider : Optional[Literal['fmp']]
Provider name.
@@ -909,7 +909,7 @@ class ROUTER_economy(Container):
Returns
-------
OBBject
- results : Union[Annotated[Union[list, dict], Tag(tag='openbb')], Annotated[List[OECDSTIR], Tag(tag='oecd')]]
+ results : List[STIR]
Serializable results.
provider : Optional[Literal['oecd']]
Provider name.
@@ -993,7 +993,7 @@ class ROUTER_economy(Container):
Returns
-------
OBBject
- results : Union[Annotated[Union[list, dict], Tag(tag='openbb')], Annotated[List[OECDUnemployment], Tag(tag='oecd')]]
+ results : List[Unemployment]
Serializable results.
provider : Optional[Literal['oecd']]
Provider name.
diff --git a/openbb_platform/openbb/package/economy_gdp.py b/openbb_platform/openbb/package/economy_gdp.py
index 7febbd6ceca..5cb9b376d80 100644
--- a/openbb_platform/openbb/package/economy_gdp.py
+++ b/openbb_platform/openbb/package/economy_gdp.py
@@ -73,7 +73,7 @@ class ROUTER_economy_gdp(Container):
Returns
-------
OBBject
- results : Union[Annotated[Union[list, dict], Tag(tag='openbb')], Annotated[List[OECDGdpForecast], Tag(tag='oecd')]]
+ results : List[GdpForecast]
Serializable results.
provider : Optional[Literal['oecd']]