diff options
author | Pratyush Shukla <ps4534@nyu.edu> | 2024-03-17 00:45:17 +0530 |
---|---|---|
committer | Pratyush Shukla <ps4534@nyu.edu> | 2024-03-17 00:45:17 +0530 |
commit | 3599691dc03e0c86683a555658243bb256a897ad (patch) | |
tree | a976ba612cd08bed07c3cdf2c08ff224f3ff9a4c | |
parent | 9ae13079f3a96d831b466ceff088a0e57f8038ac (diff) |
include overriden standard params in map
-rw-r--r-- | openbb_platform/core/openbb_core/app/provider_interface.py | 15 | ||||
-rw-r--r-- | openbb_platform/core/openbb_core/provider/registry_map.py | 10 |
2 files changed, 15 insertions, 10 deletions
diff --git a/openbb_platform/core/openbb_core/app/provider_interface.py b/openbb_platform/core/openbb_core/app/provider_interface.py index eab7ea1e807..a920c141d2b 100644 --- a/openbb_platform/core/openbb_core/app/provider_interface.py +++ b/openbb_platform/core/openbb_core/app/provider_interface.py @@ -249,22 +249,22 @@ class ProviderInterface(metaclass=SingletonMeta): if query: # We need to use query if we want the field description to show up in the # swagger, it's a fastapi limitation - default = Query( + default = Query( # type: ignore default=default, title=provider_name, description=description, alias=field.alias or None, - json_schema_extra=field.json_schema_extra, + json_schema_extra=getattr(field, "json_schema_extra", {}), ) elif provider_name: - default: FieldInfo = Field( + default: FieldInfo = Field( # type: ignore default=default or None, title=provider_name, description=description, - json_schema_extra=field.json_schema_extra, + json_schema_extra=getattr(field, "json_schema_extra", {}), ) - return DataclassField(new_name, type_, default) + return DataclassField(new_name, type_, default) # type: ignore @classmethod def _extract_params( @@ -287,7 +287,10 @@ class ProviderInterface(metaclass=SingletonMeta): ) else: for name, field in model_details["QueryParams"]["fields"].items(): - if name not in providers["openbb"]["QueryParams"]["fields"]: + if (name not in providers["openbb"]["QueryParams"]["fields"]) or ( + field.annotation + != providers["openbb"]["QueryParams"]["fields"][name].annotation + ): s_name = to_snake_case(name) incoming = cls._create_field( s_name, diff --git a/openbb_platform/core/openbb_core/provider/registry_map.py b/openbb_platform/core/openbb_core/provider/registry_map.py index 38c49aadef2..b01d208c0e7 100644 --- a/openbb_platform/core/openbb_core/provider/registry_map.py +++ b/openbb_platform/core/openbb_core/provider/registry_map.py @@ -110,7 +110,7 @@ class RegistryMap: fetcher: Fetcher, model_map: dict, ): - """Merge json schema extra for different providers""" + """Merge json schema extra for different providers.""" model: BaseModel = RegistryMap._get_model(fetcher, "query_params") std_fields = model_map["openbb"]["QueryParams"]["fields"] extra_fields = model_map[provider]["QueryParams"]["fields"] @@ -162,7 +162,7 @@ class RegistryMap: ) provider_model = create_model( - model.__name__.replace("Data", ""), + model.__name__.replace("Data", ""), # type: ignore __base__=model, __doc__=model.__doc__, __module__=model.__module__, @@ -171,7 +171,7 @@ class RegistryMap: # Replace the provider models in the modules with the new models we created # To make sure provider field is defined to be the provider string - setattr(sys.modules[model.__module__], model.__name__, provider_model) + setattr(sys.modules[model.__module__], model.__name__, provider_model) # type: ignore return provider_model @@ -204,7 +204,9 @@ class RegistryMap: # We ignore fields that are already in the standard model for name, field in all_fields.items(): - if name not in standard_info["fields"]: + if (name not in standard_info["fields"]) or ( + standard_info["fields"][name].annotation != field.annotation + ): extra_info["fields"][name] = field return standard_info, extra_info |