diff options
author | montezdesousa <79287829+montezdesousa@users.noreply.github.com> | 2024-04-12 13:08:55 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-04-12 12:08:55 +0000 |
commit | b8d1846d46374d5229b74afa14333ea8468c9d74 (patch) | |
tree | b85f66cb30412fe7268db3a2c8b82078d00f8419 | |
parent | 3172b9e31268299586420360925aec28e07184ee (diff) |
[BugFix] - Provider is added to every response item (#6305)
* Exclude provider and don't model_dump
* still missing docstrings
* undo package changes
* minor fix
* mypy
* minor fix
* cleaner
* private var
* docstring
* docstrings
* add package builder tests
* ruff
* rename
* rename
* update tests
* minor fix
* fix test
* handle docstring edge cases
* test
17 files changed, 316 insertions, 223 deletions
diff --git a/openbb_platform/core/openbb_core/api/router/commands.py b/openbb_platform/core/openbb_core/api/router/commands.py index 2d2ea3e6a11..ce1cab484a8 100644 --- a/openbb_platform/core/openbb_core/api/router/commands.py +++ b/openbb_platform/core/openbb_core/api/router/commands.py @@ -117,7 +117,7 @@ def build_new_signature(path: str, func: Callable) -> Signature: ) -def validate_output(c_out: OBBject) -> Dict: +def validate_output(c_out: OBBject) -> OBBject: """ Validate OBBject object. @@ -170,7 +170,7 @@ def validate_output(c_out: OBBject) -> Dict: for k, v in c_out.model_copy(): exclude_fields_from_api(k, v) - return c_out.model_dump() + return c_out def build_api_wrapper( @@ -188,7 +188,7 @@ def build_api_wrapper( func.__annotations__ = new_annotations_map @wraps(wrapped=func) - async def wrapper(*args: Tuple[Any], **kwargs: Dict[str, Any]) -> Dict: + async def wrapper(*args: Tuple[Any], **kwargs: Dict[str, Any]) -> OBBject: user_settings: UserSettings = UserSettings.model_validate( kwargs.pop( "__authenticated_user_settings", diff --git a/openbb_platform/core/openbb_core/app/model/obbject.py b/openbb_platform/core/openbb_core/app/model/obbject.py index 8d4ba26839c..92038ea4970 100644 --- a/openbb_platform/core/openbb_core/app/model/obbject.py +++ b/openbb_platform/core/openbb_core/app/model/obbject.py @@ -1,6 +1,5 @@ """The OBBject.""" -from re import sub from typing import ( TYPE_CHECKING, Any, @@ -82,40 +81,6 @@ class OBBject(Tagged, Generic[T]): ] return f"{self.__class__.__name__}\n\n" + "\n".join(items) - @classmethod - def results_type_repr(cls, params: Optional[Any] = None) -> str: - """Return the results type representation.""" - results_field = cls.model_fields.get("results") - type_repr = "Any" - if results_field: - type_ = params[0] if params else results_field.annotation - type_repr = getattr(type_, "__name__", str(type_)) - - if json_schema_extra := getattr(results_field, "json_schema_extra", {}): - model = json_schema_extra.get("model", "Any") - - if json_schema_extra.get("is_union"): - return f"Union[List[{model}], {model}]" - if json_schema_extra.get("has_list"): - return f"List[{model}]" - - return model - - if "typing." in str(type_): - unpack_optional = sub(r"Optional\[(.*)\]", r"\1", str(type_)) - type_repr = sub( - r"(\w+\.)*(\w+)?(\, NoneType)?", - r"\2", - unpack_optional, - ) - - return type_repr - - @classmethod - def model_parametrized_name(cls, params: Any) -> str: - """Return the model name with the parameters.""" - return f"OBBject[{cls.results_type_repr(params)}]" - def to_df( self, index: Optional[Union[str, None]] = "date", sort_by: Optional[str] = None ) -> pd.DataFrame: diff --git a/openbb_platform/core/openbb_core/app/provider_interface.py b/openbb_platform/core/openbb_core/app/provider_interface.py index 5bef4ac6e6c..0932704cb96 100644 --- a/openbb_platform/core/openbb_core/app/provider_interface.py +++ b/openbb_platform/core/openbb_core/app/provider_interface.py @@ -2,18 +2,33 @@ from dataclasses import dataclass, make_dataclass from difflib import SequenceMatcher -from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Type, Union +from typing import ( + Annotated, + Any, + Callable, + Dict, + List, + Literal, + Optional, + Tuple, + Type, + Union, +) from fastapi import Query from pydantic import ( BaseModel, ConfigDict, + Discriminator, Field, + SerializeAsAny, + Tag, create_model, ) from pydantic.fields import FieldInfo from openbb_core.app.model.abstract.singleton import SingletonMeta +from openbb_core.app.model.obbject import OBBject from openbb_core.provider.query_executor import QueryExecutor from openbb_core.provider.registry_map import MapType, RegistryMap from openbb_core.provider.utils.helpers import to_snake_case @@ -92,12 +107,15 @@ class ProviderInterface(metaclass=SingletonMeta): self._registry_map = registry_map or RegistryMap() self._query_executor = query_executor or QueryExecutor - self._map = self._registry_map.map + self._map = self._registry_map.standard_extra # TODO: Try these 4 methods in a single iteration self._model_providers_map = self._generate_model_providers_dc(self._map) self._params = self._generate_params_dc(self._map) self._data = self._generate_data_dc(self._map) self._return_schema = self._generate_return_schema(self._data) + self._return_annotations = self._generate_return_annotations( + self._registry_map.original_models + ) self._available_providers = self._registry_map.available_providers self._provider_choices = self._get_provider_choices(self._available_providers) @@ -148,9 +166,9 @@ class ProviderInterface(metaclass=SingletonMeta): return self._registry_map.models @property - def return_map(self) -> Dict[str, Dict[str, Any]]: + def return_annotations(self) -> Dict[str, Type[OBBject]]: """Return map.""" - return self._registry_map.return_map + return self._return_annotations def create_executor(self) -> QueryExecutor: """Get query executor.""" @@ -242,7 +260,9 @@ class ProviderInterface(metaclass=SingletonMeta): additional_description += " Multiple comma separated items allowed." else: additional_description += ( - " Multiple comma separated items allowed for provider(s): " + ", ".join(multiple) + "." # type: ignore + " Multiple comma separated items allowed for provider(s): " + + ", ".join(multiple) # type: ignore[arg-type] + + "." ) provider_field = ( @@ -396,7 +416,7 @@ class ProviderInterface(metaclass=SingletonMeta): This creates a dictionary of dataclasses that can be injected as a FastAPI dependency. - Example: + Example ------- @dataclass class CompanyNews(StandardParams): @@ -437,7 +457,7 @@ class ProviderInterface(metaclass=SingletonMeta): This creates a dictionary that maps model names to dataclasses that can be injected as a FastAPI dependency. - Example: + Example ------- @dataclass class CompanyNews(ProviderChoices): @@ -471,7 +491,7 @@ class ProviderInterface(metaclass=SingletonMeta): This creates a dictionary of dataclasses. - Example: + Example ------- class EquityHistoricalData(StandardData): date: date @@ -546,3 +566,60 @@ class ProviderInterface(metaclass=SingletonMeta): fields=[("provider", Literal[tuple(available_providers)])], # type: ignore bases=(ProviderChoices,), ) + + def _generate_return_annotations( + self, original_models: Dict[str, Dict[str, Any]] + ) -> Dict[str, Type[OBBject]]: + """Generate return annotations for FastAPI. + + Example + ------- + class Data(BaseModel): + ... + + class EquityData(Data): + price: float + + class YFEquityData(EquityData): + yf_field: str + + class AVEquityData(EquityData): + av_field: str + + class OBBject(BaseModel): + results: List[ + SerializeAsAny[ + Annotated[ + Union[ + Annotated[YFEquityData, Tag("yf")], + Annotated[AVEquityData, Tag("av")], + ], + Discriminator(get_provider), + ] + ] + ] + """ + + def get_provider(v: Type[BaseModel]): + """Callable to discriminate which BaseModel to use.""" + return getattr(v, "_provider", None) + + annotations = {} + for name, models in original_models.items(): + outer = set() + args = set() + for provider, model in models.items(): + data = model["data"] + outer.add(model["results_type"]) + args.add(Annotated[data, Tag(provider)]) + # We set the provider to use it in discriminator function + setattr(data, "_provider", provider) + meta = Discriminator(get_provider) if len(args) > 1 else None + inner = SerializeAsAny[Annotated[Union[tuple(args)], meta]] # type: ignore[misc,valid-type] + full = Union[tuple((o[inner] if o else inner) for o in outer)] # type: ignore[valid-type] + annotations[name] = create_model( + f"OBBject_{name}", + __base__=OBBject[full], # type: ignore[valid-type] + __doc__=f"OBBject with results of type {name}", + ) + return annotations diff --git a/openbb_platform/core/openbb_core/app/router.py b/openbb_platform/core/openbb_core/app/router.py index 51bfc836d07..c9d196216aa 100644 --- a/openbb_platform/core/openbb_core/app/router.py +++ b/openbb_platform/core/openbb_core/app/router.py @@ -12,7 +12,6 @@ from typing import ( Mapping, Optional, Type, - Union, get_args, get_origin, get_type_hints, @@ -20,7 +19,7 @@ from typing import ( ) from fastapi import APIRouter, Depends -from pydantic import BaseModel, Field, SerializeAsAny, Tag, create_model +from pydantic import BaseModel from pydantic.v1.validators import find_validators from typing_extensions import Annotated, ParamSpec, _AnnotatedAlias @@ -397,10 +396,9 @@ class SignatureInspector: callable_=provider_interface.params[model]["extra"], ) - func = cls.inject_return_type( + func = cls.inject_return_annotation( func=func, - return_map=provider_interface.return_map.get(model, {}), - model=model, + annotation=provider_interface.return_annotations[model], ) else: @@ -418,60 +416,6 @@ class SignatureInspector: return func @staticmethod - def inject_return_type( - func: Callable[P, OBBject], - return_map: Dict[str, dict], - model: str, - ) -> Callable[P, OBBject]: - """Inject full return model into the function. Also updates __name__ and __doc__ for API schemas.""" - results: Dict[str, Any] = {"list_type": [], "dict_type": []} - - for provider, return_data in return_map.items(): - if return_data["is_list"]: - results["list_type"].append( - Annotated[return_data["model"], Tag(provider)] - ) - continue - - results["dict_type"].append(Annotated[return_data["model"], Tag(provider)]) - - list_models, union_models = results.values() - - return_types = [] - for t, v in results.items(): - if not v: - continue - - inner_type: Any = SerializeAsAny[ # type: ignore[misc,valid-type] - Annotated[ - Union[tuple(v)], # type: ignore - Field(discriminator="provider"), - ] - ] - return_types.append(List[inner_type] if t == "list_type" else inner_type) - - return_type = create_model( - f"OBBject_{model}", - __base__=OBBject, - __doc__=f"OBBject with results of type {model}", - results=( - Optional[Union[tuple(return_types)]], # type: ignore - Field( - None, - description="Serializable results.", - json_schema_extra={ - "model": model, - "has_list": bool(len(list_models) > 0), - "is_union": bool(list_models and union_models), - }, - ), - ), - ) - - func.__annotations__["return"] = return_type - return func - - @staticmethod def polish_return_schema(func: Callable[P, OBBject]) -> Callable[P, OBBject]: """Polish API schemas by filling `__doc__` and `__name__`.""" return_type = func.__annotations__["return"] @@ -518,6 +462,14 @@ class SignatureInspector: return func @staticmethod + def inject_return_annotation( + func: Callable[P, OBBject], annotation: Type[OBBject] + ) -> Callable[P, OBBject]: + """Annotate function with return annotation.""" + func.__annotations__["return"] = annotation + return func + + @staticmethod def get_description(func: Callable) -> str: """Get description from docstring.""" doc = func.__doc__ diff --git a/openbb_platform/core/openbb_core/app/static/package_builder.py b/openbb_platform/core/openbb_core/app/static/package_builder.py index 4559d0762b8..df445cb3e50 100644 --- a/openbb_platform/core/openbb_core/app/static/package_builder.py +++ b/openbb_platform/core/openbb_core/app/static/package_builder.py @@ -7,6 +7,7 @@ import re import shutil import sys from dataclasses import Field +from functools import partial from inspect import Parameter, _empty, isclass, signature from json import dumps, load from pathlib import Path @@ -24,6 +25,7 @@ from typing import ( TypeVar, Union, get_args, + get_origin, get_type_hints, ) @@ -41,6 +43,7 @@ from openbb_core.app.model.custom_parameter import ( OpenBBCustomParameter, ) from openbb_core.app.model.example import Example +from openbb_core.app.model.obbject import OBBject from openbb_core.app.provider_interface import ProviderInterface from openbb_core.app.router import RouterLoader from openbb_core.app.static.utils.console import Console @@ -1146,18 +1149,25 @@ class DocstringGenerator: # Parameters passed as **kwargs kwarg_params = params["extra"].__dataclass_fields__ param_types.update({k: v.type for k, v in kwarg_params.items()}) - - returns = return_schema.model_fields - results_type = func.__annotations__.get("return", model_name) - if hasattr(results_type, "results_type_repr"): - results_type = results_type.results_type_repr() - + # Format the annotation to hide the metadata, tags, etc. + annotation = func.__annotations__.get("return") + results_type = ( + cls._get_repr( + cls._get_generic_types( + annotation.model_fields["results"].annotation, # type: ignore[union-attr] + [], + ), + model_name, + ) + if isclass(annotation) and issubclass(annotation, OBBject) # type: ignore[arg-type] + else model_name + ) doc = cls.generate_model_docstring( model_name=model_name, summary=func.__doc__ or "", explicit_params=explicit_params, kwarg_params=kwarg_params, - returns=returns, + returns=return_schema.model_fields, results_type=results_type, ) else: @@ -1172,6 +1182,65 @@ class DocstringGenerator: return doc + @classmethod + def _get_generic_types(cls, type_: type, items: list) -> List[str]: + """Unpack generic types recursively. + + Parameters + ---------- + type_ : type + Type to unpack. + items : list + List to store the unpacked types. + + Returns + ------- + List[str] + List of unpacked type names. + + Examples + -------- + Union[List[str], Dict[str, str], Tuple[str]] -> ["List", "Dict", "Tuple"] + """ + if hasattr(type_, "__args__"): + origin = get_origin(type_) + # pylint: disable=unidiomatic-typecheck + if ( + type(origin) is type + and origin is not Annotated + and (name := getattr(type_, "_name", getattr(type_, "__name__", None))) + ): + items.append(name.title()) + func = partial(cls._get_generic_types, items=items) + set().union(*map(func, type_.__args__), items) + return items + + @staticmethod + def _get_repr(items: List[str], model: str) -> str: + """Get the string representation of the types list with the model name. + + Parameters + ---------- + items : List[str] + List of type names. + model : str + Model name to access the model providers. + + Returns + ------- + str + String representation of the unpacked types list. + + Examples + -------- + [List, Dict, Tuple], M -> "Union[List[M], Dict[str, M], Tuple[M]]" + """ + if s := [ + f"Dict[str, {model}]" if i == "Dict" else f"{i}[{model}]" for i in items + ]: + return f"Union[{', '.join(s)}]" if len(s) > 1 else s[0] + return model + class PathHandler: """Handle the paths for the Platform.""" @@ -1362,7 +1431,7 @@ class ReferenceGenerator: """ provider_field_params = [] expanded_types = MethodDefinition.TYPE_EXPANSION - model_map = cls.pi._map[model] # pylint: disable=protected-access + model_map = cls.pi.map[model] for field, field_info in model_map[provider][params_type]["fields"].items(): # Determine the field type, expanding it if necessary and if params_type is "Parameters" @@ -1605,7 +1674,7 @@ class ReferenceGenerator: route, "description", "No description available." ) # Access model map from the ProviderInterface - model_map = cls.pi._map[ + model_map = cls.pi.map[ standard_model ] # pylint: disable=protected-access diff --git a/openbb_platform/core/openbb_core/provider/registry_map.py b/openbb_platform/core/openbb_core/provider/registry_map.py index b3463b5da79..05dff9ef4f6 100644 --- a/openbb_platform/core/openbb_core/provider/registry_map.py +++ b/openbb_platform/core/openbb_core/provider/registry_map.py @@ -1,12 +1,11 @@ """Provider registry map.""" -import sys from copy import deepcopy from inspect import getfile, isclass from pathlib import Path from typing import Any, Dict, List, Literal, Optional, Tuple, get_origin -from pydantic import BaseModel, Field, create_model +from pydantic import BaseModel from openbb_core.provider.abstract.data import Data from openbb_core.provider.abstract.fetcher import Fetcher @@ -27,8 +26,8 @@ class RegistryMap: self._registry = registry or RegistryLoader.from_extensions() self._credentials = self._get_credentials(self._registry) self._available_providers = self._get_available_providers(self._registry) - self._map, self._return_map = self._get_map(self._registry) - self._models = self._get_models(self._map) + self._standard_extra, self._original_models = self._get_maps(self._registry) + self._models = self._get_models(self._standard_extra) @property def registry(self) -> Registry: @@ -46,14 +45,14 @@ class RegistryMap: return self._credentials @property - def map(self) -> MapType: - """Get provider registry map.""" - return self._map + def standard_extra(self) -> MapType: + """Get standard extra map.""" + return self._standard_extra @property - def return_map(self) -> MapType: - """Get provider registry return map.""" - return self._return_map + def original_models(self) -> MapType: + """Get original models.""" + return self._original_models @property def models(self) -> List[str]: @@ -72,37 +71,42 @@ class RegistryMap: """Get list of available providers.""" return sorted(list(registry.providers.keys())) - def _get_map(self, registry: Registry) -> Tuple[MapType, Dict[str, Dict]]: + def _get_maps(self, registry: Registry) -> Tuple[MapType, Dict[str, Dict]]: """Generate map for the provider package.""" - map_: MapType = {} - return_schemas: Dict[str, Dict] = {} + standard_extra: MapType = {} + original_models: Dict[str, Dict] = {} for p in registry.providers: for model_name, fetcher in registry.providers[p].fetcher_dict.items(): - standard_query, extra_query = self.extract_info(fetcher, "query_params") - standard_data, extra_data = self.extract_info(fetcher, "data") - if model_name not in map_: - map_[model_name] = {} + standard_query, extra_query = self._extract_info( + fetcher, "query_params" + ) + standard_data, extra_data = self._extract_info(fetcher, "data") + if model_name not in standard_extra: + standard_extra[model_name] = {} # The deepcopy avoids modifications from one model to affect another - map_[model_name]["openbb"] = { + standard_extra[model_name]["openbb"] = { "QueryParams": deepcopy(standard_query), "Data": deepcopy(standard_data), } - map_[model_name][p] = { + standard_extra[model_name][p] = { "QueryParams": extra_query, "Data": extra_data, } - if provider_model := self.extract_data_model(fetcher, p): - is_list = get_origin(self.extract_return_type(fetcher)) == list - - return_schemas.setdefault(model_name, {}).update( - {p: {"model": provider_model, "is_list": is_list}} - ) + original_models.setdefault(model_name, {}).update( + { + p: { + "query": self._get_model(fetcher, "query_params"), + "data": self._get_model(fetcher, "data"), + "results_type": self._get_results_type(fetcher), + } + } + ) - self._merge_json_schema_extra(p, fetcher, map_[model_name]) + self._merge_json_schema_extra(p, fetcher, standard_extra[model_name]) - return map_, return_schemas + return standard_extra, original_models |