diff options
author | teh_coderer <me@tehcoderer.com> | 2024-02-06 04:46:21 -0500 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-02-06 09:46:21 +0000 |
commit | d6742378c64ab583b26dd42a762f0d414b73d7ed (patch) | |
tree | 6488fd8053700d12ea5edfe936fb649f078b69a6 | |
parent | b5f41fe411ea2895da1e16ff307f324514154c24 (diff) |
fix openapi schema fields `to_snake` (#6036)
-rw-r--r-- | openbb_platform/core/openbb_core/app/router.py | 22 | ||||
-rw-r--r-- | openbb_platform/core/openbb_core/provider/registry_map.py | 29 |
2 files changed, 33 insertions, 18 deletions
diff --git a/openbb_platform/core/openbb_core/app/router.py b/openbb_platform/core/openbb_core/app/router.py index bb626caba4d..6e2e23de9b7 100644 --- a/openbb_platform/core/openbb_core/app/router.py +++ b/openbb_platform/core/openbb_core/app/router.py @@ -46,7 +46,7 @@ def model_t_discriminator(v: Any) -> str: if isinstance(v, dict): return v.get("provider", "openbb") - if isinstance(v, list): + if isinstance(v, list) and v: return model_t_discriminator(v[0]) return getattr(v, "provider", "openbb") @@ -397,7 +397,7 @@ class SignatureInspector: Also updates __name__ and __doc__ for API schemas. """ - union_models = [Annotated[None, Tag("openbb")]] + union_models = [Annotated[Union[list, dict], Tag("openbb")]] for provider, return_type in return_map.items(): union_models.append(Annotated[return_type, Tag(provider)]) @@ -406,15 +406,17 @@ class SignatureInspector: f"OBBject_{model}", __base__=OBBject, results=( - Annotated[ - Union[tuple(union_models)], # type: ignore - Field( - ..., - description="Serializable results.", - discriminator=Discriminator(model_t_discriminator), - ), + Optional[ + Annotated[ + Union[tuple(union_models)], # type: ignore + Field( + None, + description="Serializable results.", + discriminator=Discriminator(model_t_discriminator), + ), + ] ], - Field(..., description="Serializable results."), + Field(None, description="Serializable results."), ), ) diff --git a/openbb_platform/core/openbb_core/provider/registry_map.py b/openbb_platform/core/openbb_core/provider/registry_map.py index ff9b710a3b8..5a2cd8e7c51 100644 --- a/openbb_platform/core/openbb_core/provider/registry_map.py +++ b/openbb_platform/core/openbb_core/provider/registry_map.py @@ -113,17 +113,30 @@ class RegistryMap: """Extract info (fields and docstring) from fetcher query params or data.""" model: BaseModel = RegistryMap._get_model(fetcher, "data") - class DataModel(model): - model_config = ConfigDict(alias_generator=alias_generators.to_snake) + fields = {} + for field_name, field in model.model_fields.items(): + field.alias_priority = None + fields[field_name] = (field.annotation, field) - provider: Literal[provider_str, "openbb"] = Field( # type: ignore - default=provider_str, - description="The data provider for the data.", - exclude=True, - ) + fields.pop("provider", None) return create_model( - model.__name__, __base__=DataModel, __module__=model.__module__ + model.__name__.replace("Data", ""), + __doc__=model.__doc__, + __config__=ConfigDict( + extra="allow", + alias_generator=alias_generators.to_snake, + populate_by_name=True, + ), + provider=( + Literal[provider_str, "openbb"], # type: ignore + Field( + default=provider_str, + description="The data provider for the data.", + exclude=True, + ), + ), + **fields, ) @staticmethod |