From d6742378c64ab583b26dd42a762f0d414b73d7ed Mon Sep 17 00:00:00 2001 From: teh_coderer Date: Tue, 6 Feb 2024 04:46:21 -0500 Subject: fix openapi schema fields `to_snake` (#6036) --- openbb_platform/core/openbb_core/app/router.py | 22 ++++++++-------- .../core/openbb_core/provider/registry_map.py | 29 ++++++++++++++++------ 2 files changed, 33 insertions(+), 18 deletions(-) diff --git a/openbb_platform/core/openbb_core/app/router.py b/openbb_platform/core/openbb_core/app/router.py index bb626caba4d..6e2e23de9b7 100644 --- a/openbb_platform/core/openbb_core/app/router.py +++ b/openbb_platform/core/openbb_core/app/router.py @@ -46,7 +46,7 @@ def model_t_discriminator(v: Any) -> str: if isinstance(v, dict): return v.get("provider", "openbb") - if isinstance(v, list): + if isinstance(v, list) and v: return model_t_discriminator(v[0]) return getattr(v, "provider", "openbb") @@ -397,7 +397,7 @@ class SignatureInspector: Also updates __name__ and __doc__ for API schemas. """ - union_models = [Annotated[None, Tag("openbb")]] + union_models = [Annotated[Union[list, dict], Tag("openbb")]] for provider, return_type in return_map.items(): union_models.append(Annotated[return_type, Tag(provider)]) @@ -406,15 +406,17 @@ class SignatureInspector: f"OBBject_{model}", __base__=OBBject, results=( - Annotated[ - Union[tuple(union_models)], # type: ignore - Field( - ..., - description="Serializable results.", - discriminator=Discriminator(model_t_discriminator), - ), + Optional[ + Annotated[ + Union[tuple(union_models)], # type: ignore + Field( + None, + description="Serializable results.", + discriminator=Discriminator(model_t_discriminator), + ), + ] ], - Field(..., description="Serializable results."), + Field(None, description="Serializable results."), ), ) diff --git a/openbb_platform/core/openbb_core/provider/registry_map.py b/openbb_platform/core/openbb_core/provider/registry_map.py index ff9b710a3b8..5a2cd8e7c51 100644 --- a/openbb_platform/core/openbb_core/provider/registry_map.py +++ b/openbb_platform/core/openbb_core/provider/registry_map.py @@ -113,17 +113,30 @@ class RegistryMap: """Extract info (fields and docstring) from fetcher query params or data.""" model: BaseModel = RegistryMap._get_model(fetcher, "data") - class DataModel(model): - model_config = ConfigDict(alias_generator=alias_generators.to_snake) + fields = {} + for field_name, field in model.model_fields.items(): + field.alias_priority = None + fields[field_name] = (field.annotation, field) - provider: Literal[provider_str, "openbb"] = Field( # type: ignore - default=provider_str, - description="The data provider for the data.", - exclude=True, - ) + fields.pop("provider", None) return create_model( - model.__name__, __base__=DataModel, __module__=model.__module__ + model.__name__.replace("Data", ""), + __doc__=model.__doc__, + __config__=ConfigDict( + extra="allow", + alias_generator=alias_generators.to_snake, + populate_by_name=True, + ), + provider=( + Literal[provider_str, "openbb"], # type: ignore + Field( + default=provider_str, + description="The data provider for the data.", + exclude=True, + ), + ), + **fields, ) @staticmethod -- cgit v1.2.3