summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorPratyush Shukla <ps4534@nyu.edu>2024-03-17 00:45:17 +0530
committerPratyush Shukla <ps4534@nyu.edu>2024-03-17 00:45:17 +0530
commit3599691dc03e0c86683a555658243bb256a897ad (patch)
treea976ba612cd08bed07c3cdf2c08ff224f3ff9a4c
parent9ae13079f3a96d831b466ceff088a0e57f8038ac (diff)
include overriden standard params in map
-rw-r--r--openbb_platform/core/openbb_core/app/provider_interface.py15
-rw-r--r--openbb_platform/core/openbb_core/provider/registry_map.py10
2 files changed, 15 insertions, 10 deletions
diff --git a/openbb_platform/core/openbb_core/app/provider_interface.py b/openbb_platform/core/openbb_core/app/provider_interface.py
index eab7ea1e807..a920c141d2b 100644
--- a/openbb_platform/core/openbb_core/app/provider_interface.py
+++ b/openbb_platform/core/openbb_core/app/provider_interface.py
@@ -249,22 +249,22 @@ class ProviderInterface(metaclass=SingletonMeta):
if query:
# We need to use query if we want the field description to show up in the
# swagger, it's a fastapi limitation
- default = Query(
+ default = Query( # type: ignore
default=default,
title=provider_name,
description=description,
alias=field.alias or None,
- json_schema_extra=field.json_schema_extra,
+ json_schema_extra=getattr(field, "json_schema_extra", {}),
)
elif provider_name:
- default: FieldInfo = Field(
+ default: FieldInfo = Field( # type: ignore
default=default or None,
title=provider_name,
description=description,
- json_schema_extra=field.json_schema_extra,
+ json_schema_extra=getattr(field, "json_schema_extra", {}),
)
- return DataclassField(new_name, type_, default)
+ return DataclassField(new_name, type_, default) # type: ignore
@classmethod
def _extract_params(
@@ -287,7 +287,10 @@ class ProviderInterface(metaclass=SingletonMeta):
)
else:
for name, field in model_details["QueryParams"]["fields"].items():
- if name not in providers["openbb"]["QueryParams"]["fields"]:
+ if (name not in providers["openbb"]["QueryParams"]["fields"]) or (
+ field.annotation
+ != providers["openbb"]["QueryParams"]["fields"][name].annotation
+ ):
s_name = to_snake_case(name)
incoming = cls._create_field(
s_name,
diff --git a/openbb_platform/core/openbb_core/provider/registry_map.py b/openbb_platform/core/openbb_core/provider/registry_map.py
index 38c49aadef2..b01d208c0e7 100644
--- a/openbb_platform/core/openbb_core/provider/registry_map.py
+++ b/openbb_platform/core/openbb_core/provider/registry_map.py
@@ -110,7 +110,7 @@ class RegistryMap:
fetcher: Fetcher,
model_map: dict,
):
- """Merge json schema extra for different providers"""
+ """Merge json schema extra for different providers."""
model: BaseModel = RegistryMap._get_model(fetcher, "query_params")
std_fields = model_map["openbb"]["QueryParams"]["fields"]
extra_fields = model_map[provider]["QueryParams"]["fields"]
@@ -162,7 +162,7 @@ class RegistryMap:
)
provider_model = create_model(
- model.__name__.replace("Data", ""),
+ model.__name__.replace("Data", ""), # type: ignore
__base__=model,
__doc__=model.__doc__,
__module__=model.__module__,
@@ -171,7 +171,7 @@ class RegistryMap:
# Replace the provider models in the modules with the new models we created
# To make sure provider field is defined to be the provider string
- setattr(sys.modules[model.__module__], model.__name__, provider_model)
+ setattr(sys.modules[model.__module__], model.__name__, provider_model) # type: ignore
return provider_model
@@ -204,7 +204,9 @@ class RegistryMap:
# We ignore fields that are already in the standard model
for name, field in all_fields.items():
- if name not in standard_info["fields"]:
+ if (name not in standard_info["fields"]) or (
+ standard_info["fields"][name].annotation != field.annotation
+ ):
extra_info["fields"][name] = field
return standard_info, extra_info