summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authormontezdesousa <79287829+montezdesousa@users.noreply.github.com>2024-04-12 13:08:55 +0100
committerGitHub <noreply@github.com>2024-04-12 12:08:55 +0000
commitb8d1846d46374d5229b74afa14333ea8468c9d74 (patch)
treeb85f66cb30412fe7268db3a2c8b82078d00f8419
parent3172b9e31268299586420360925aec28e07184ee (diff)
[BugFix] - Provider is added to every response item (#6305)
* Exclude provider and don't model_dump * still missing docstrings * undo package changes * minor fix * mypy * minor fix * cleaner * private var * docstring * docstrings * add package builder tests * ruff * rename * rename * update tests * minor fix * fix test * handle docstring edge cases * test
-rw-r--r--openbb_platform/core/openbb_core/api/router/commands.py6
-rw-r--r--openbb_platform/core/openbb_core/app/model/obbject.py35
-rw-r--r--openbb_platform/core/openbb_core/app/provider_interface.py93
-rw-r--r--openbb_platform/core/openbb_core/app/router.py70
-rw-r--r--openbb_platform/core/openbb_core/app/static/package_builder.py87
-rw-r--r--openbb_platform/core/openbb_core/provider/registry_map.py102
-rw-r--r--openbb_platform/core/tests/app/static/test_package_builder.py36
-rw-r--r--openbb_platform/core/tests/provider/test_registry_map.py12
-rw-r--r--openbb_platform/openbb/assets/reference.json38
-rw-r--r--openbb_platform/openbb/package/economy.py30
-rw-r--r--openbb_platform/openbb/package/economy_gdp.py6
-rw-r--r--openbb_platform/openbb/package/equity.py6
-rw-r--r--openbb_platform/openbb/package/equity_compare.py1
-rw-r--r--openbb_platform/openbb/package/equity_discovery.py5
-rw-r--r--openbb_platform/openbb/package/equity_ownership.py4
-rw-r--r--openbb_platform/openbb/package/index.py4
-rw-r--r--openbb_platform/openbb/package/regulators_sec.py4
17 files changed, 316 insertions, 223 deletions
diff --git a/openbb_platform/core/openbb_core/api/router/commands.py b/openbb_platform/core/openbb_core/api/router/commands.py
index 2d2ea3e6a11..ce1cab484a8 100644
--- a/openbb_platform/core/openbb_core/api/router/commands.py
+++ b/openbb_platform/core/openbb_core/api/router/commands.py
@@ -117,7 +117,7 @@ def build_new_signature(path: str, func: Callable) -> Signature:
)
-def validate_output(c_out: OBBject) -> Dict:
+def validate_output(c_out: OBBject) -> OBBject:
"""
Validate OBBject object.
@@ -170,7 +170,7 @@ def validate_output(c_out: OBBject) -> Dict:
for k, v in c_out.model_copy():
exclude_fields_from_api(k, v)
- return c_out.model_dump()
+ return c_out
def build_api_wrapper(
@@ -188,7 +188,7 @@ def build_api_wrapper(
func.__annotations__ = new_annotations_map
@wraps(wrapped=func)
- async def wrapper(*args: Tuple[Any], **kwargs: Dict[str, Any]) -> Dict:
+ async def wrapper(*args: Tuple[Any], **kwargs: Dict[str, Any]) -> OBBject:
user_settings: UserSettings = UserSettings.model_validate(
kwargs.pop(
"__authenticated_user_settings",
diff --git a/openbb_platform/core/openbb_core/app/model/obbject.py b/openbb_platform/core/openbb_core/app/model/obbject.py
index 8d4ba26839c..92038ea4970 100644
--- a/openbb_platform/core/openbb_core/app/model/obbject.py
+++ b/openbb_platform/core/openbb_core/app/model/obbject.py
@@ -1,6 +1,5 @@
"""The OBBject."""
-from re import sub
from typing import (
TYPE_CHECKING,
Any,
@@ -82,40 +81,6 @@ class OBBject(Tagged, Generic[T]):
]
return f"{self.__class__.__name__}\n\n" + "\n".join(items)
- @classmethod
- def results_type_repr(cls, params: Optional[Any] = None) -> str:
- """Return the results type representation."""
- results_field = cls.model_fields.get("results")
- type_repr = "Any"
- if results_field:
- type_ = params[0] if params else results_field.annotation
- type_repr = getattr(type_, "__name__", str(type_))
-
- if json_schema_extra := getattr(results_field, "json_schema_extra", {}):
- model = json_schema_extra.get("model", "Any")
-
- if json_schema_extra.get("is_union"):
- return f"Union[List[{model}], {model}]"
- if json_schema_extra.get("has_list"):
- return f"List[{model}]"
-
- return model
-
- if "typing." in str(type_):
- unpack_optional = sub(r"Optional\[(.*)\]", r"\1", str(type_))
- type_repr = sub(
- r"(\w+\.)*(\w+)?(\, NoneType)?",
- r"\2",
- unpack_optional,
- )
-
- return type_repr
-
- @classmethod
- def model_parametrized_name(cls, params: Any) -> str:
- """Return the model name with the parameters."""
- return f"OBBject[{cls.results_type_repr(params)}]"
-
def to_df(
self, index: Optional[Union[str, None]] = "date", sort_by: Optional[str] = None
) -> pd.DataFrame:
diff --git a/openbb_platform/core/openbb_core/app/provider_interface.py b/openbb_platform/core/openbb_core/app/provider_interface.py
index 5bef4ac6e6c..0932704cb96 100644
--- a/openbb_platform/core/openbb_core/app/provider_interface.py
+++ b/openbb_platform/core/openbb_core/app/provider_interface.py
@@ -2,18 +2,33 @@
from dataclasses import dataclass, make_dataclass
from difflib import SequenceMatcher
-from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Type, Union
+from typing import (
+ Annotated,
+ Any,
+ Callable,
+ Dict,
+ List,
+ Literal,
+ Optional,
+ Tuple,
+ Type,
+ Union,
+)
from fastapi import Query
from pydantic import (
BaseModel,
ConfigDict,
+ Discriminator,
Field,
+ SerializeAsAny,
+ Tag,
create_model,
)
from pydantic.fields import FieldInfo
from openbb_core.app.model.abstract.singleton import SingletonMeta
+from openbb_core.app.model.obbject import OBBject
from openbb_core.provider.query_executor import QueryExecutor
from openbb_core.provider.registry_map import MapType, RegistryMap
from openbb_core.provider.utils.helpers import to_snake_case
@@ -92,12 +107,15 @@ class ProviderInterface(metaclass=SingletonMeta):
self._registry_map = registry_map or RegistryMap()
self._query_executor = query_executor or QueryExecutor
- self._map = self._registry_map.map
+ self._map = self._registry_map.standard_extra
# TODO: Try these 4 methods in a single iteration
self._model_providers_map = self._generate_model_providers_dc(self._map)
self._params = self._generate_params_dc(self._map)
self._data = self._generate_data_dc(self._map)
self._return_schema = self._generate_return_schema(self._data)
+ self._return_annotations = self._generate_return_annotations(
+ self._registry_map.original_models
+ )
self._available_providers = self._registry_map.available_providers
self._provider_choices = self._get_provider_choices(self._available_providers)
@@ -148,9 +166,9 @@ class ProviderInterface(metaclass=SingletonMeta):
return self._registry_map.models
@property
- def return_map(self) -> Dict[str, Dict[str, Any]]:
+ def return_annotations(self) -> Dict[str, Type[OBBject]]:
"""Return map."""
- return self._registry_map.return_map
+ return self._return_annotations
def create_executor(self) -> QueryExecutor:
"""Get query executor."""
@@ -242,7 +260,9 @@ class ProviderInterface(metaclass=SingletonMeta):
additional_description += " Multiple comma separated items allowed."
else:
additional_description += (
- " Multiple comma separated items allowed for provider(s): " + ", ".join(multiple) + "." # type: ignore
+ " Multiple comma separated items allowed for provider(s): "
+ + ", ".join(multiple) # type: ignore[arg-type]
+ + "."
)
provider_field = (
@@ -396,7 +416,7 @@ class ProviderInterface(metaclass=SingletonMeta):
This creates a dictionary of dataclasses that can be injected as a FastAPI
dependency.
- Example:
+ Example
-------
@dataclass
class CompanyNews(StandardParams):
@@ -437,7 +457,7 @@ class ProviderInterface(metaclass=SingletonMeta):
This creates a dictionary that maps model names to dataclasses that can be
injected as a FastAPI dependency.
- Example:
+ Example
-------
@dataclass
class CompanyNews(ProviderChoices):
@@ -471,7 +491,7 @@ class ProviderInterface(metaclass=SingletonMeta):
This creates a dictionary of dataclasses.
- Example:
+ Example
-------
class EquityHistoricalData(StandardData):
date: date
@@ -546,3 +566,60 @@ class ProviderInterface(metaclass=SingletonMeta):
fields=[("provider", Literal[tuple(available_providers)])], # type: ignore
bases=(ProviderChoices,),
)
+
+ def _generate_return_annotations(
+ self, original_models: Dict[str, Dict[str, Any]]
+ ) -> Dict[str, Type[OBBject]]:
+ """Generate return annotations for FastAPI.
+
+ Example
+ -------
+ class Data(BaseModel):
+ ...
+
+ class EquityData(Data):
+ price: float
+
+ class YFEquityData(EquityData):
+ yf_field: str
+
+ class AVEquityData(EquityData):
+ av_field: str
+
+ class OBBject(BaseModel):
+ results: List[
+ SerializeAsAny[
+ Annotated[
+ Union[
+ Annotated[YFEquityData, Tag("yf")],
+ Annotated[AVEquityData, Tag("av")],
+ ],
+ Discriminator(get_provider),
+ ]
+ ]
+ ]
+ """
+
+ def get_provider(v: Type[BaseModel]):
+ """Callable to discriminate which BaseModel to use."""
+ return getattr(v, "_provider", None)
+
+ annotations = {}
+ for name, models in original_models.items():
+ outer = set()
+ args = set()
+ for provider, model in models.items():
+ data = model["data"]
+ outer.add(model["results_type"])
+ args.add(Annotated[data, Tag(provider)])
+ # We set the provider to use it in discriminator function
+ setattr(data, "_provider", provider)
+ meta = Discriminator(get_provider) if len(args) > 1 else None
+ inner = SerializeAsAny[Annotated[Union[tuple(args)], meta]] # type: ignore[misc,valid-type]
+ full = Union[tuple((o[inner] if o else inner) for o in outer)] # type: ignore[valid-type]
+ annotations[name] = create_model(
+ f"OBBject_{name}",
+ __base__=OBBject[full], # type: ignore[valid-type]
+ __doc__=f"OBBject with results of type {name}",
+ )
+ return annotations
diff --git a/openbb_platform/core/openbb_core/app/router.py b/openbb_platform/core/openbb_core/app/router.py
index 51bfc836d07..c9d196216aa 100644
--- a/openbb_platform/core/openbb_core/app/router.py
+++ b/openbb_platform/core/openbb_core/app/router.py
@@ -12,7 +12,6 @@ from typing import (
Mapping,
Optional,
Type,
- Union,
get_args,
get_origin,
get_type_hints,
@@ -20,7 +19,7 @@ from typing import (
)
from fastapi import APIRouter, Depends
-from pydantic import BaseModel, Field, SerializeAsAny, Tag, create_model
+from pydantic import BaseModel
from pydantic.v1.validators import find_validators
from typing_extensions import Annotated, ParamSpec, _AnnotatedAlias
@@ -397,10 +396,9 @@ class SignatureInspector:
callable_=provider_interface.params[model]["extra"],
)
- func = cls.inject_return_type(
+ func = cls.inject_return_annotation(
func=func,
- return_map=provider_interface.return_map.get(model, {}),
- model=model,
+ annotation=provider_interface.return_annotations[model],
)
else:
@@ -418,60 +416,6 @@ class SignatureInspector:
return func
@staticmethod
- def inject_return_type(
- func: Callable[P, OBBject],
- return_map: Dict[str, dict],
- model: str,
- ) -> Callable[P, OBBject]:
- """Inject full return model into the function. Also updates __name__ and __doc__ for API schemas."""
- results: Dict[str, Any] = {"list_type": [], "dict_type": []}
-
- for provider, return_data in return_map.items():
- if return_data["is_list"]:
- results["list_type"].append(
- Annotated[return_data["model"], Tag(provider)]
- )
- continue
-
- results["dict_type"].append(Annotated[return_data["model"], Tag(provider)])
-
- list_models, union_models = results.values()
-
- return_types = []
- for t, v in results.items():
- if not v:
- continue
-
- inner_type: Any = SerializeAsAny[ # type: ignore[misc,valid-type]
- Annotated[
- Union[tuple(v)], # type: ignore
- Field(discriminator="provider"),
- ]
- ]
- return_types.append(List[inner_type] if t == "list_type" else inner_type)
-
- return_type = create_model(
- f"OBBject_{model}",
- __base__=OBBject,
- __doc__=f"OBBject with results of type {model}",
- results=(
- Optional[Union[tuple(return_types)]], # type: ignore
- Field(
- None,
- description="Serializable results.",
- json_schema_extra={
- "model": model,
- "has_list": bool(len(list_models) > 0),
- "is_union": bool(list_models and union_models),
- },
- ),
- ),
- )
-
- func.__annotations__["return"] = return_type
- return func
-
- @staticmethod
def polish_return_schema(func: Callable[P, OBBject]) -> Callable[P, OBBject]:
"""Polish API schemas by filling `__doc__` and `__name__`."""
return_type = func.__annotations__["return"]
@@ -518,6 +462,14 @@ class SignatureInspector:
return func
@staticmethod
+ def inject_return_annotation(
+ func: Callable[P, OBBject], annotation: Type[OBBject]
+ ) -> Callable[P, OBBject]:
+ """Annotate function with return annotation."""
+ func.__annotations__["return"] = annotation
+ return func
+
+ @staticmethod
def get_description(func: Callable) -> str:
"""Get description from docstring."""
doc = func.__doc__
diff --git a/openbb_platform/core/openbb_core/app/static/package_builder.py b/openbb_platform/core/openbb_core/app/static/package_builder.py
index 4559d0762b8..df445cb3e50 100644
--- a/openbb_platform/core/openbb_core/app/static/package_builder.py
+++ b/openbb_platform/core/openbb_core/app/static/package_builder.py
@@ -7,6 +7,7 @@ import re
import shutil
import sys
from dataclasses import Field
+from functools import partial
from inspect import Parameter, _empty, isclass, signature
from json import dumps, load
from pathlib import Path
@@ -24,6 +25,7 @@ from typing import (
TypeVar,
Union,
get_args,
+ get_origin,
get_type_hints,
)
@@ -41,6 +43,7 @@ from openbb_core.app.model.custom_parameter import (
OpenBBCustomParameter,
)
from openbb_core.app.model.example import Example
+from openbb_core.app.model.obbject import OBBject
from openbb_core.app.provider_interface import ProviderInterface
from openbb_core.app.router import RouterLoader
from openbb_core.app.static.utils.console import Console
@@ -1146,18 +1149,25 @@ class DocstringGenerator:
# Parameters passed as **kwargs
kwarg_params = params["extra"].__dataclass_fields__
param_types.update({k: v.type for k, v in kwarg_params.items()})
-
- returns = return_schema.model_fields
- results_type = func.__annotations__.get("return", model_name)
- if hasattr(results_type, "results_type_repr"):
- results_type = results_type.results_type_repr()
-
+ # Format the annotation to hide the metadata, tags, etc.
+ annotation = func.__annotations__.get("return")
+ results_type = (
+ cls._get_repr(
+ cls._get_generic_types(
+ annotation.model_fields["results"].annotation, # type: ignore[union-attr]
+ [],
+ ),
+ model_name,
+ )
+ if isclass(annotation) and issubclass(annotation, OBBject) # type: ignore[arg-type]
+ else model_name
+ )
doc = cls.generate_model_docstring(
model_name=model_name,
summary=func.__doc__ or "",
explicit_params=explicit_params,
kwarg_params=kwarg_params,
- returns=returns,
+ returns=return_schema.model_fields,
results_type=results_type,
)
else:
@@ -1172,6 +1182,65 @@ class DocstringGenerator:
return doc
+ @classmethod
+ def _get_generic_types(cls, type_: type, items: list) -> List[str]:
+ """Unpack generic types recursively.
+
+ Parameters
+ ----------
+ type_ : type
+ Type to unpack.
+ items : list
+ List to store the unpacked types.
+
+ Returns
+ -------
+ List[str]
+ List of unpacked type names.
+
+ Examples
+ --------
+ Union[List[str], Dict[str, str], Tuple[str]] -> ["List", "Dict", "Tuple"]
+ """
+ if hasattr(type_, "__args__"):
+ origin = get_origin(type_)
+ # pylint: disable=unidiomatic-typecheck
+ if (
+ type(origin) is type
+ and origin is not Annotated
+ and (name := getattr(type_, "_name", getattr(type_, "__name__", None)))
+ ):
+ items.append(name.title())
+ func = partial(cls._get_generic_types, items=items)
+ set().union(*map(func, type_.__args__), items)
+ return items
+
+ @staticmethod
+ def _get_repr(items: List[str], model: str) -> str:
+ """Get the string representation of the types list with the model name.
+
+ Parameters
+ ----------
+ items : List[str]
+ List of type names.
+ model : str
+ Model name to access the model providers.
+
+ Returns
+ -------
+ str
+ String representation of the unpacked types list.
+
+ Examples
+ --------
+ [List, Dict, Tuple], M -> "Union[List[M], Dict[str, M], Tuple[M]]"
+ """
+ if s := [
+ f"Dict[str, {model}]" if i == "Dict" else f"{i}[{model}]" for i in items
+ ]:
+ return f"Union[{', '.join(s)}]" if len(s) > 1 else s[0]
+ return model
+
class PathHandler:
"""Handle the paths for the Platform."""
@@ -1362,7 +1431,7 @@ class ReferenceGenerator:
"""
provider_field_params = []
expanded_types = MethodDefinition.TYPE_EXPANSION
- model_map = cls.pi._map[model] # pylint: disable=protected-access
+ model_map = cls.pi.map[model]
for field, field_info in model_map[provider][params_type]["fields"].items():
# Determine the field type, expanding it if necessary and if params_type is "Parameters"
@@ -1605,7 +1674,7 @@ class ReferenceGenerator:
route, "description", "No description available."
)
# Access model map from the ProviderInterface
- model_map = cls.pi._map[
+ model_map = cls.pi.map[
standard_model
] # pylint: disable=protected-access
diff --git a/openbb_platform/core/openbb_core/provider/registry_map.py b/openbb_platform/core/openbb_core/provider/registry_map.py
index b3463b5da79..05dff9ef4f6 100644
--- a/openbb_platform/core/openbb_core/provider/registry_map.py
+++ b/openbb_platform/core/openbb_core/provider/registry_map.py
@@ -1,12 +1,11 @@
"""Provider registry map."""
-import sys
from copy import deepcopy
from inspect import getfile, isclass
from pathlib import Path
from typing import Any, Dict, List, Literal, Optional, Tuple, get_origin
-from pydantic import BaseModel, Field, create_model
+from pydantic import BaseModel
from openbb_core.provider.abstract.data import Data
from openbb_core.provider.abstract.fetcher import Fetcher
@@ -27,8 +26,8 @@ class RegistryMap:
self._registry = registry or RegistryLoader.from_extensions()
self._credentials = self._get_credentials(self._registry)
self._available_providers = self._get_available_providers(self._registry)
- self._map, self._return_map = self._get_map(self._registry)
- self._models = self._get_models(self._map)
+ self._standard_extra, self._original_models = self._get_maps(self._registry)
+ self._models = self._get_models(self._standard_extra)
@property
def registry(self) -> Registry:
@@ -46,14 +45,14 @@ class RegistryMap:
return self._credentials
@property
- def map(self) -> MapType:
- """Get provider registry map."""
- return self._map
+ def standard_extra(self) -> MapType:
+ """Get standard extra map."""
+ return self._standard_extra
@property
- def return_map(self) -> MapType:
- """Get provider registry return map."""
- return self._return_map
+ def original_models(self) -> MapType:
+ """Get original models."""
+ return self._original_models
@property
def models(self) -> List[str]:
@@ -72,37 +71,42 @@ class RegistryMap:
"""Get list of available providers."""
return sorted(list(registry.providers.keys()))
- def _get_map(self, registry: Registry) -> Tuple[MapType, Dict[str, Dict]]:
+ def _get_maps(self, registry: Registry) -> Tuple[MapType, Dict[str, Dict]]:
"""Generate map for the provider package."""
- map_: MapType = {}
- return_schemas: Dict[str, Dict] = {}
+ standard_extra: MapType = {}
+ original_models: Dict[str, Dict] = {}
for p in registry.providers:
for model_name, fetcher in registry.providers[p].fetcher_dict.items():
- standard_query, extra_query = self.extract_info(fetcher, "query_params")
- standard_data, extra_data = self.extract_info(fetcher, "data")
- if model_name not in map_:
- map_[model_name] = {}
+ standard_query, extra_query = self._extract_info(
+ fetcher, "query_params"
+ )
+ standard_data, extra_data = self._extract_info(fetcher, "data")
+ if model_name not in standard_extra:
+ standard_extra[model_name] = {}
# The deepcopy avoids modifications from one model to affect another
- map_[model_name]["openbb"] = {
+ standard_extra[model_name]["openbb"] = {
"QueryParams": deepcopy(standard_query),
"Data": deepcopy(standard_data),
}
- map_[model_name][p] = {
+ standard_extra[model_name][p] = {
"QueryParams": extra_query,
"Data": extra_data,
}
- if provider_model := self.extract_data_model(fetcher, p):
- is_list = get_origin(self.extract_return_type(fetcher)) == list
-
- return_schemas.setdefault(model_name, {}).update(
- {p: {"model": provider_model, "is_list": is_list}}
- )
+ original_models.setdefault(model_name, {}).update(
+ {
+ p: {
+ "query": self._get_model(fetcher, "query_params"),
+ "data": self._get_model(fetcher, "data"),
+ "results_type": self._get_results_type(fetcher),
+ }
+ }
+ )
- self._merge_json_schema_extra(p, fetcher, map_[model_name])
+ self._merge_json_schema_extra(p, fetcher, standard_extra[model_name])
- return map_, return_schemas
+ return standard_extra, original_models
def _merge_json_schema_extra(
self,
@@ -138,48 +142,14 @@ class RegistryMap:
return list(map_.keys())
@staticmethod
- def extract_return_type(fetcher: Fetcher):
+ def _get_results_type(fetcher: Fetcher) -> Any:
"""Extract return info from fetcher."""
- return getattr(fetcher, "return_type", None)
-
- @staticmethod
- def extract_data_model(fetcher: Fetcher, provider_str: str) -> BaseModel:
- """Extract info (fields and docstring) from fetcher query params or data."""
- model: BaseModel = RegistryMap._get_model(fetcher, "data")
- model_name = getattr(model, "__name__", "")
- fields = {}
- for field_name, field in model.model_fields.items():
- field.serialization_alias = field_name
- fields[field_name] = (field.annotation, field)
-