summaryrefslogtreecommitdiffstats
path: root/openbb_platform/core/openbb_core/provider/registry_map.py
diff options
context:
space:
mode:
Diffstat (limited to 'openbb_platform/core/openbb_core/provider/registry_map.py')
-rw-r--r--openbb_platform/core/openbb_core/provider/registry_map.py15
1 files changed, 9 insertions, 6 deletions
diff --git a/openbb_platform/core/openbb_core/provider/registry_map.py b/openbb_platform/core/openbb_core/provider/registry_map.py
index 38c49aadef2..b3463b5da79 100644
--- a/openbb_platform/core/openbb_core/provider/registry_map.py
+++ b/openbb_platform/core/openbb_core/provider/registry_map.py
@@ -110,7 +110,7 @@ class RegistryMap:
fetcher: Fetcher,
model_map: dict,
):
- """Merge json schema extra for different providers"""
+ """Merge json schema extra for different providers."""
model: BaseModel = RegistryMap._get_model(fetcher, "query_params")
std_fields = model_map["openbb"]["QueryParams"]["fields"]
extra_fields = model_map[provider]["QueryParams"]["fields"]
@@ -146,7 +146,7 @@ class RegistryMap:
def extract_data_model(fetcher: Fetcher, provider_str: str) -> BaseModel:
"""Extract info (fields and docstring) from fetcher query params or data."""
model: BaseModel = RegistryMap._get_model(fetcher, "data")
-
+ model_name = getattr(model, "__name__", "")
fields = {}
for field_name, field in model.model_fields.items():
field.serialization_alias = field_name
@@ -157,12 +157,11 @@ class RegistryMap:
Field(
default=provider_str,
description="The data provider for the data.",
- exclude=True,
),
)
- provider_model = create_model(
- model.__name__.replace("Data", ""),
+ provider_model = create_model( # type: ignore[call-overload]
+ model_name.replace("Data", ""),
__base__=model,
__doc__=model.__doc__,
__module__=model.__module__,
@@ -171,7 +170,11 @@ class RegistryMap:
# Replace the provider models in the modules with the new models we created
# To make sure provider field is defined to be the provider string
- setattr(sys.modules[model.__module__], model.__name__, provider_model)
+ # This is hacky, but we need to have `provider: Literal['provider_name']`
+ # in the model to serve as union discriminator for the API validation
+ # the alternative would be to specify it manually in all the models
+ if model_name:
+ setattr(sys.modules[model.__module__], model_name, provider_model)
return provider_model