summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorteh_coderer <me@tehcoderer.com>2024-02-06 04:46:21 -0500
committerGitHub <noreply@github.com>2024-02-06 09:46:21 +0000
commitd6742378c64ab583b26dd42a762f0d414b73d7ed (patch)
tree6488fd8053700d12ea5edfe936fb649f078b69a6
parentb5f41fe411ea2895da1e16ff307f324514154c24 (diff)
fix openapi schema fields `to_snake` (#6036)
-rw-r--r--openbb_platform/core/openbb_core/app/router.py22
-rw-r--r--openbb_platform/core/openbb_core/provider/registry_map.py29
2 files changed, 33 insertions, 18 deletions
diff --git a/openbb_platform/core/openbb_core/app/router.py b/openbb_platform/core/openbb_core/app/router.py
index bb626caba4d..6e2e23de9b7 100644
--- a/openbb_platform/core/openbb_core/app/router.py
+++ b/openbb_platform/core/openbb_core/app/router.py
@@ -46,7 +46,7 @@ def model_t_discriminator(v: Any) -> str:
if isinstance(v, dict):
return v.get("provider", "openbb")
- if isinstance(v, list):
+ if isinstance(v, list) and v:
return model_t_discriminator(v[0])
return getattr(v, "provider", "openbb")
@@ -397,7 +397,7 @@ class SignatureInspector:
Also updates __name__ and __doc__ for API schemas.
"""
- union_models = [Annotated[None, Tag("openbb")]]
+ union_models = [Annotated[Union[list, dict], Tag("openbb")]]
for provider, return_type in return_map.items():
union_models.append(Annotated[return_type, Tag(provider)])
@@ -406,15 +406,17 @@ class SignatureInspector:
f"OBBject_{model}",
__base__=OBBject,
results=(
- Annotated[
- Union[tuple(union_models)], # type: ignore
- Field(
- ...,
- description="Serializable results.",
- discriminator=Discriminator(model_t_discriminator),
- ),
+ Optional[
+ Annotated[
+ Union[tuple(union_models)], # type: ignore
+ Field(
+ None,
+ description="Serializable results.",
+ discriminator=Discriminator(model_t_discriminator),
+ ),
+ ]
],
- Field(..., description="Serializable results."),
+ Field(None, description="Serializable results."),
),
)
diff --git a/openbb_platform/core/openbb_core/provider/registry_map.py b/openbb_platform/core/openbb_core/provider/registry_map.py
index ff9b710a3b8..5a2cd8e7c51 100644
--- a/openbb_platform/core/openbb_core/provider/registry_map.py
+++ b/openbb_platform/core/openbb_core/provider/registry_map.py
@@ -113,17 +113,30 @@ class RegistryMap:
"""Extract info (fields and docstring) from fetcher query params or data."""
model: BaseModel = RegistryMap._get_model(fetcher, "data")
- class DataModel(model):
- model_config = ConfigDict(alias_generator=alias_generators.to_snake)
+ fields = {}
+ for field_name, field in model.model_fields.items():
+ field.alias_priority = None
+ fields[field_name] = (field.annotation, field)
- provider: Literal[provider_str, "openbb"] = Field( # type: ignore
- default=provider_str,
- description="The data provider for the data.",
- exclude=True,
- )
+ fields.pop("provider", None)
return create_model(
- model.__name__, __base__=DataModel, __module__=model.__module__
+ model.__name__.replace("Data", ""),
+ __doc__=model.__doc__,
+ __config__=ConfigDict(
+ extra="allow",
+ alias_generator=alias_generators.to_snake,
+ populate_by_name=True,
+ ),
+ provider=(
+ Literal[provider_str, "openbb"], # type: ignore
+ Field(
+ default=provider_str,
+ description="The data provider for the data.",
+ exclude=True,
+ ),
+ ),
+ **fields,
)
@staticmethod