diff options
author | teh_coderer <me@tehcoderer.com> | 2024-02-08 04:58:14 -0500 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-02-08 09:58:14 +0000 |
commit | 7c68832379bf9869517114732d2888f5e10c436a (patch) | |
tree | 30aec21d2f23bc5d6d696e9427acfebd0cc0a37d | |
parent | 7d668b10c87a8d8909459bc7ee0e9f28b981da15 (diff) |
improve discriminator logic, fix package return type docs (#6052)
* improve discriminator logic, fix package return type docs
* Update registry_map.py
* build package
* Update registry_map.py
* defaults
* Update registry_map.py
---------
Co-authored-by: Igor Radovanovic <74266147+IgorWounds@users.noreply.github.com>
30 files changed, 196 insertions, 223 deletions
diff --git a/openbb_platform/core/openbb_core/app/model/obbject.py b/openbb_platform/core/openbb_core/app/model/obbject.py index 3f87c6d5e20..30a21b61c9a 100644 --- a/openbb_platform/core/openbb_core/app/model/obbject.py +++ b/openbb_platform/core/openbb_core/app/model/obbject.py @@ -73,20 +73,19 @@ class OBBject(Tagged, Generic[T]): @classmethod def results_type_repr(cls, params: Optional[Any] = None) -> str: """Return the results type name.""" - type_ = params[0] if params else cls.model_fields["results"].annotation + results_field = cls.model_fields.get("results") + type_ = params[0] if params else results_field.annotation name = type_.__name__ if hasattr(type_, "__name__") else str(type_) - if "Annotated" in str(type_): - annotated_inners = [] - for inner in cls.model_fields["results"].annotation.__args__: - if hasattr(inner, "__args__") and inner.__args__[0] == type(None): - continue + if (json_schema_extra := results_field.json_schema_extra) is not None: + model = json_schema_extra.get("model") - annotated_inners.append( - inner.__args__[0] if hasattr(inner, "__args__") else inner - ) + if json_schema_extra.get("is_union"): + return f"Union[List[{model}], {model}]" + if json_schema_extra.get("has_list"): + return f"List[{model}]" - type_ = Union[tuple(annotated_inners)] if len(annotated_inners) > 1 else annotated_inners[0] # type: ignore + return model if "typing." in str(type_): unpack_optional = sub(r"Optional\[(.*)\]", r"\1", str(type_)) @@ -323,15 +322,4 @@ class OBBject(Tagged, Generic[T]): OBBject[ResultsType] OBBject with results. """ - data: List[BaseModel] = await query.execute() - - if isinstance(data, list) and data: - if isinstance(data[0], BaseModel): - data[0] = data[0].model_copy(update={"provider": query.provider}) - if isinstance(data[0], dict): - data[0] = data[0].update({"provider": query.provider}) - - if isinstance(data, BaseModel): - data = data.model_copy(update={"provider": query.provider}) - - return cls(results=data) + return cls(results=await query.execute()) diff --git a/openbb_platform/core/openbb_core/app/router.py b/openbb_platform/core/openbb_core/app/router.py index 6e2e23de9b7..038697fdd83 100644 --- a/openbb_platform/core/openbb_core/app/router.py +++ b/openbb_platform/core/openbb_core/app/router.py @@ -20,7 +20,7 @@ from typing import ( ) from fastapi import APIRouter, Depends -from pydantic import BaseModel, Discriminator, Field, Tag, create_model +from pydantic import BaseModel, Field, SerializeAsAny, Tag, create_model from pydantic.v1.validators import find_validators from typing_extensions import Annotated, ParamSpec, _AnnotatedAlias @@ -41,17 +41,6 @@ from openbb_core.env import Env P = ParamSpec("P") -def model_t_discriminator(v: Any) -> str: - """Discriminator function for the results field.""" - if isinstance(v, dict): - return v.get("provider", "openbb") - - if isinstance(v, list) and v: - return model_t_discriminator(v[0]) - - return getattr(v, "provider", "openbb") - - class OpenBBErrorResponse(BaseModel): """OpenBB Error Response.""" @@ -389,7 +378,7 @@ class SignatureInspector: @staticmethod def inject_return_type( func: Callable[P, OBBject], - return_map: Dict[str, Any], + return_map: Dict[str, dict], model: str, ) -> Callable[P, OBBject]: """ @@ -397,34 +386,50 @@ class SignatureInspector: Also updates __name__ and __doc__ for API schemas. """ - union_models = [Annotated[Union[list, dict], Tag("openbb")]] + results: Dict[str, Any] = {"list_type": [], "dict_type": []} + + for provider, return_data in return_map.items(): + if return_data["is_list"]: + results["list_type"].append( + Annotated[return_data["model"], Tag(provider)] + ) + continue + + results["dict_type"].append(Annotated[return_data["model"], Tag(provider)]) + + list_models, union_models = results.values() + + return_types = [] + for t, v in results.items(): + if not v: + continue - for provider, return_type in return_map.items(): - union_models.append(Annotated[return_type, Tag(provider)]) + inner_type = SerializeAsAny[ + Annotated[ + Union[tuple(v)], # type: ignore + Field(discriminator="provider"), + ] + ] + return_types.append(List[inner_type] if t == "list_type" else inner_type) return_type = create_model( f"OBBject_{model}", __base__=OBBject, + __doc__=f"OBBject with results of type {model}", results=( - Optional[ - Annotated[ - Union[tuple(union_models)], # type: ignore - Field( - None, - description="Serializable results.", - discriminator=Discriminator(model_t_discriminator), - ), - ] - ], - Field(None, description="Serializable results."), + Optional[Union[tuple(return_types)]], # type: ignore + Field( + None, + description="Serializable results.", + json_schema_extra={ + "model": model, + "has_list": bool(len(list_models) > 0), + "is_union": bool(list_models and union_models), + }, + ), ), ) - return_type.__name__ = f"OBBject_{model.replace('Data', '')}" - return_type.__doc__ = ( - f"OBBject with results of type {model.replace('Data', '')}" - ) - func.__annotations__["return"] = return_type return func diff --git a/openbb_platform/core/openbb_core/provider/abstract/data.py b/openbb_platform/core/openbb_core/provider/abstract/data.py index 6fd7b3a183f..80731a16537 100644 --- a/openbb_platform/core/openbb_core/provider/abstract/data.py +++ b/openbb_platform/core/openbb_core/provider/abstract/data.py @@ -1,12 +1,12 @@ """The OpenBB Standardized Data Model.""" -from typing import Dict, Literal +from typing import Dict from pydantic import ( + AliasGenerator, BaseModel, BeforeValidator, ConfigDict, - Field, alias_generators, model_validator, ) @@ -70,11 +70,6 @@ class Data(BaseModel): """ __alias_dict__: Dict[str, str] = {} - provider: Literal["openbb"] = Field( - "openbb", - description="The data provider for the data.", - exclude=True, - ) def __repr__(self): """Return a string representation of the object.""" @@ -83,8 +78,11 @@ class Data(BaseModel): model_config = ConfigDict( extra="allow", populate_by_name=True, - alias_generator=alias_generators.to_camel, strict=False, + alias_generator=AliasGenerator( + validation_alias=alias_generators.to_camel, + serialization_alias=alias_generators.to_snake, + ), ) @model_validator(mode="before") diff --git a/openbb_platform/core/openbb_core/provider/registry_map.py b/openbb_platform/core/openbb_core/provider/registry_map.py index 5a2cd8e7c51..be433e41a29 100644 --- a/openbb_platform/core/openbb_core/provider/registry_map.py +++ b/openbb_platform/core/openbb_core/provider/registry_map.py @@ -1,10 +1,11 @@ """Provider registry map.""" +import sys from inspect import getfile, isclass from pathlib import Path from typing import Any, Dict, List, Literal, Optional, Tuple, get_origin -from pydantic import BaseModel, ConfigDict, Field, alias_generators, create_model +from pydantic import BaseModel, Field, create_model from openbb_core.provider.abstract.data import Data from openbb_core.provider.abstract.fetcher import Fetcher @@ -94,7 +95,7 @@ class RegistryMap: is_list = get_origin(self.extract_return_type(fetcher)) == list return_schemas.setdefault(model_name, {}).update( - {p: List[provider_model] if is_list else provider_model} + {p: {"model": provider_model, "is_list": is_list}} ) return map_, return_schemas @@ -115,49 +116,30 @@ class RegistryMap: fields = {} for field_name, field in model.model_fields.items(): - field.alias_priority = None + field.serialization_alias = field_name fields[field_name] = (field.annotation, field) - fields.pop("provider", None) - - return create_model( - model.__name__.replace("Data", ""), - __doc__=model.__doc__, - __config__=ConfigDict( - extra="allow", - alias_generator=alias_generators.to_snake, - populate_by_name=True, - ), - provider=( - Literal[provider_str, "openbb"], # type: ignore - Field( - default=provider_str, - description="The data provider for the data.", - exclude=True, - ), + fields["provider"] = ( + Literal[provider_str], # type: ignore + Field( + default=provider_str, + description="The data provider for the data.", + exclude=True, ), - **fields, ) - @staticmethod - def extract_query_model(fetcher: Fetcher, provider: str) -> BaseModel: - """Extract info (fields and docstring) from fetcher query params or data.""" - model: BaseModel = RegistryMap._get_model(fetcher, "query_params") - provider_model = create_model( - model.__name__, + model.__name__.replace("Data", ""), __base__=model, + __doc__=model.__doc__, __module__=model.__module__, - provider=( - Literal[provider], # type: ignore - Field( - default=provider, - description="The data provider for the data.", - exclude=True, - ), - ), + **fields, ) + # Replace the provider models in the modules with the new models we created + # To make sure provider field is defined to be the provider string + setattr(sys.modules[model.__module__], model.__name__, provider_model) + return provider_model @staticmethod diff --git a/openbb_platform/openbb/package/crypto.py b/openbb_platform/openbb/package/crypto.py index 7bb5228c299..88f66999759 100644 --- a/openbb_platform/openbb/package/crypto.py +++ b/openbb_platform/openbb/package/crypto.py @@ -49,7 +49,7 @@ class ROUTER_crypto(Container): Returns ------- OBBject - results : Union[Annotated[Union[list, dict], Tag(tag='openbb')], Annotated[List[FMPCryptoSearch], Tag(tag='fmp')]] + results : List[CryptoSearch] Serializable results. provider : Optional[Literal['fmp']] Provider name. diff --git a/openbb_platform/openbb/package/crypto_price.py b/openbb_platform/openbb/package/crypto_price.py index 42ae9ca3dba..549d3b2f60c 100644 --- a/openbb_platform/openbb/package/crypto_price.py +++ b/openbb_platform/openbb/package/crypto_price.py @@ -73,7 +73,7 @@ class ROUTER_crypto_price(Container): Returns ------- OBBject - results : Union[Annotated[Union[list, dict], Tag(tag='openbb')], Annotated[List[FMPCryptoHistorical], Tag(tag='fmp')], Annotated[List[PolygonCryptoHistorical], Tag(tag='polygon')], Annotated[List[TiingoCryptoHistorical], Tag(tag='tiingo')], Annotated[List[YFinanceCryptoHistorical], Tag(tag='yfinance')]] + results : List[CryptoHistorical] Serializable results. provider : Optional[Literal['fmp', 'polygon', 'tiingo', 'yfinance']] Provider name. diff --git a/openbb_platform/openbb/package/currency.py b/openbb_platform/openbb/package/currency.py index b9f97368152..ea4a2c1f5f5 100644 --- a/openbb_platform/openbb/package/currency.py +++ b/openbb_platform/openbb/package/currency.py @@ -64,7 +64,7 @@ class ROUTER_currency(Container): Returns ------- OBBject - results : Union[Annotated[Union[list, dict], Tag(tag='openbb')], Annotated[List[FMPCurrencyPairs], Tag(tag='fmp')], Annotated[List[IntrinioCurrencyPairs], Tag(tag='intrinio')], Annotated[List[PolygonCurrencyPairs], Tag(tag='polygon')]] + results : List[CurrencyPairs] Serializable results. provider : Optional[Literal['fmp', 'intrinio', 'polygon']] Provider name. diff --git a/openbb_platform/openbb/package/currency_price.py b/openbb_platform/openbb/package/currency_price.py index 0301e483613..c3a0562390a 100644 --- a/openbb_platform/openbb/package/currency_price.py +++ b/openbb_platform/openbb/package/currency_price.py @@ -77,7 +77,7 @@ class ROUTER_currency_price(Container): Returns ------- OBBject - results : Union[Annotated[Union[list, dict], Tag(tag='openbb')], Annotated[List[FMPCurrencyHistorical], Tag(tag='fmp')], Annotated[List[PolygonCurrencyHistorical], Tag(tag='polygon')], Annotated[List[TiingoCurrencyHistorical], Tag(tag='tiingo')], Annotated[List[YFinanceCurrencyHistorical], Tag(tag='yfinance')]] + results : List[CurrencyHistorical] Serializable results. provider : Optional[Literal['fmp', 'polygon', 'tiingo', 'yfinance']] Provider name. diff --git a/openbb_platform/openbb/package/derivatives_options.py b/openbb_platform/openbb/package/derivatives_options.py index 5e5bba0ae17..43077b374bb 100644 --- a/openbb_platform/openbb/package/derivatives_options.py +++ b/openbb_platform/openbb/package/derivatives_options.py @@ -45,7 +45,7 @@ class ROUTER_derivatives_options(Container): Returns ------- OBBject - results : Union[Annotated[Union[list, dict], Tag(tag='openbb')], Annotated[List[IntrinioOptionsChains], Tag(tag='intrinio')]] + results : List[OptionsChains] Serializable results. provider : Optional[Literal['intrinio']] Provider name. @@ -194,7 +194,7 @@ class ROUTER_derivatives_options(Container): Returns ------- OBBject - results : Union[Annotated[Union[list, dict], Tag(tag='openbb')], Annotated[List[IntrinioOptionsUnusual], Tag(tag='intrinio')]] + results : List[OptionsUnusual] Serializable results. provider : Optional[Literal['intrinio']] Provider name. diff --git a/openbb_platform/openbb/package/economy.py b/openbb_platform/openbb/package/economy.py index e7784ce4b95..95cb79a8422 100644 --- a/openbb_platform/openbb/package/economy.py +++ b/openbb_platform/openbb/package/economy.py @@ -69,7 +69,7 @@ class ROUTER_economy(Container): Returns ------- OBBject - results : Union[Annotated[Union[list, dict], Tag(tag='openbb')], Annotated[List[FMPEconomicCalendar], Tag(tag='fmp')], Annotated[List[TEEconomicCalendar], Tag(tag='tradingeconomics')]] + results : List[EconomicCalendar] Serializable results. provider : Optional[Literal['fmp', 'tradingeconomics']] Provider name. @@ -178,7 +178,7 @@ class ROUTER_economy(Container): Returns ------- OBBject - results : Union[Annotated[Union[list, dict], Tag(tag='openbb')], Annotated[List[OECDCLI], Tag(tag='oecd')]] + results : List[CLI] Serializable results. provider : Optional[Literal['oecd']] Provider name. @@ -341,7 +341,7 @@ class ROUTER_economy(Container): Returns ------- OBBject - results : Union[Annotated[Union[list, dict], Tag(tag='openbb')], Annotated[List[FREDConsumerPriceIndex], Tag(tag='fred')]] + results : List[ConsumerPriceIndex] Serializable results. provider : Optional[Literal['fred']] Provider name. @@ -424,7 +424,7 @@ class ROUTER_economy(Container): Returns ------- OBBject - results : Union[Annotated[Union[list, dict], Tag(tag='openbb')], Annotated[List[FredSearch], Tag(tag='fred')]] + results : List[FredSearch] Serializable results. provider : Optional[Literal['fred']] Provider name. @@ -583,7 +583,7 @@ class ROUTER_economy(Container): Returns ------- OBBject - results : Union[Annotated[Union[list, dict], Tag(tag='openbb')], Annotated[List[FredSeries], Tag(tag='fred')], Annotated[List[IntrinioFredSeries], Tag(tag='intrinio')]] + results : List[FredSeries] Serializable results. provider : Optional[Literal['fred', 'intrinio']] Provider name. @@ -677,7 +677,7 @@ class ROUTER_economy(Container): Returns ------- OBBject - results : Union[Annotated[Union[list, dict], Tag(tag='openbb')], Annotated[List[OECDSTIR], Tag(tag='oecd')]] + results : List[STIR] Serializable results. provider : Optional[Literal['oecd']] Provider name. @@ -759,7 +759,7 @@ class ROUTER_economy(Container): Returns ------- OBBject - results : Union[Annotated[Union[list, dict], Tag(tag='openbb')], Annotated[List[FederalReserveMoneyMeasures], Tag(tag='federal_reserve')]] + results : List[MoneyMeasures] Serializable results. provider : Optional[Literal['federal_reserve']] Provider name. @@ -826,7 +826,7 @@ class ROUTER_economy(Container): Returns ------- OBBject - results : Union[Annotated[Union[list, dict], Tag(tag='openbb')], Annotated[List[FMPRiskPremium], Tag(tag='fmp')]] + results : List[RiskPremium] Serializable results. provider : Optional[Literal['fmp']] Provider name. @@ -909,7 +909,7 @@ class ROUTER_economy(Container): Returns ------- OBBject - results : Union[Annotated[Union[list, dict], Tag(tag='openbb')], Annotated[List[OECDSTIR], Tag(tag='oecd')]] + results : List[STIR] Serializable results. provider : Optional[Literal['oecd']] Provider name. @@ -993,7 +993,7 @@ class ROUTER_economy(Container): Returns ------- OBBject - results : Union[Annotated[Union[list, dict], Tag(tag='openbb')], Annotated[List[OECDUnemployment], Tag(tag='oecd')]] + results : List[Unemployment] Serializable results. provider : Optional[Literal['oecd']] Provider name. diff --git a/openbb_platform/openbb/package/economy_gdp.py b/openbb_platform/openbb/package/economy_gdp.py index 7febbd6ceca..5cb9b376d80 100644 --- a/openbb_platform/openbb/package/economy_gdp.py +++ b/openbb_platform/openbb/package/economy_gdp.py @@ -73,7 +73,7 @@ class ROUTER_economy_gdp(Container): Returns ------- OBBject - results : Union[Annotated[Union[list, dict], Tag(tag='openbb')], Annotated[List[OECDGdpForecast], Tag(tag='oecd')]] + results : List[GdpForecast] Serializable results. provider : Optional[Literal['oecd']] |