diff options
Diffstat (limited to 'openbb_platform/core/openbb_core/provider/registry_map.py')
-rw-r--r-- | openbb_platform/core/openbb_core/provider/registry_map.py | 15 |
1 files changed, 9 insertions, 6 deletions
diff --git a/openbb_platform/core/openbb_core/provider/registry_map.py b/openbb_platform/core/openbb_core/provider/registry_map.py index 38c49aadef2..b3463b5da79 100644 --- a/openbb_platform/core/openbb_core/provider/registry_map.py +++ b/openbb_platform/core/openbb_core/provider/registry_map.py @@ -110,7 +110,7 @@ class RegistryMap: fetcher: Fetcher, model_map: dict, ): - """Merge json schema extra for different providers""" + """Merge json schema extra for different providers.""" model: BaseModel = RegistryMap._get_model(fetcher, "query_params") std_fields = model_map["openbb"]["QueryParams"]["fields"] extra_fields = model_map[provider]["QueryParams"]["fields"] @@ -146,7 +146,7 @@ class RegistryMap: def extract_data_model(fetcher: Fetcher, provider_str: str) -> BaseModel: """Extract info (fields and docstring) from fetcher query params or data.""" model: BaseModel = RegistryMap._get_model(fetcher, "data") - + model_name = getattr(model, "__name__", "") fields = {} for field_name, field in model.model_fields.items(): field.serialization_alias = field_name @@ -157,12 +157,11 @@ class RegistryMap: Field( default=provider_str, description="The data provider for the data.", - exclude=True, ), ) - provider_model = create_model( - model.__name__.replace("Data", ""), + provider_model = create_model( # type: ignore[call-overload] + model_name.replace("Data", ""), __base__=model, __doc__=model.__doc__, __module__=model.__module__, @@ -171,7 +170,11 @@ class RegistryMap: # Replace the provider models in the modules with the new models we created # To make sure provider field is defined to be the provider string - setattr(sys.modules[model.__module__], model.__name__, provider_model) + # This is hacky, but we need to have `provider: Literal['provider_name']` + # in the model to serve as union discriminator for the API validation + # the alternative would be to specify it manually in all the models + if model_name: + setattr(sys.modules[model.__module__], model_name, provider_model) return provider_model |