diff options
Diffstat (limited to 'openbb_platform/core/openbb_core/app/provider_interface.py')
-rw-r--r-- | openbb_platform/core/openbb_core/app/provider_interface.py | 143 |
1 files changed, 84 insertions, 59 deletions
diff --git a/openbb_platform/core/openbb_core/app/provider_interface.py b/openbb_platform/core/openbb_core/app/provider_interface.py index eab7ea1e807..5bef4ac6e6c 100644 --- a/openbb_platform/core/openbb_core/app/provider_interface.py +++ b/openbb_platform/core/openbb_core/app/provider_interface.py @@ -18,7 +18,7 @@ from openbb_core.provider.query_executor import QueryExecutor from openbb_core.provider.registry_map import MapType, RegistryMap from openbb_core.provider.utils.helpers import to_snake_case -TupleFieldType = Tuple[str, type, Any] +TupleFieldType = Tuple[str, Optional[Type], Optional[Any]] @dataclass @@ -26,8 +26,8 @@ class DataclassField: """Dataclass field.""" name: str - type_: type - default: Any + annotation: Optional[Type] + default: Optional[Any] @dataclass @@ -154,20 +154,21 @@ class ProviderInterface(metaclass=SingletonMeta): def create_executor(self) -> QueryExecutor: """Get query executor.""" - return self._query_executor(self._registry_map.registry) # type: ignore + return self._query_executor(self._registry_map.registry) # type: ignore[operator] @staticmethod def _merge_fields( current: DataclassField, incoming: DataclassField, query: bool = False ) -> DataclassField: - current_name = current.name - current_type = current.type_ - current_desc = getattr(current.default, "description", "") + """Merge 2 dataclass fields.""" + curr_name = current.name + curr_type: Optional[Type] = current.annotation + curr_desc = getattr(current.default, "description", "") + curr_json_schema_extra = getattr(current.default, "json_schema_extra", {}) - incoming_type = incoming.type_ - incoming_desc = getattr(incoming.default, "description", "") - - F: Union[Callable, object] = Query if query else FieldInfo + inc_type: Optional[Type] = incoming.annotation + inc_desc = getattr(incoming.default, "description", "") + inc_json_schema_extra = getattr(incoming.default, "json_schema_extra", {}) def split_desc(desc: str) -> str: """Split field description.""" @@ -175,34 +176,52 @@ class ProviderInterface(metaclass=SingletonMeta): detail = item[0] if item else "" return detail - curr_detail = split_desc(current_desc) - inc_detail = split_desc(incoming_desc) + def merge_json_schema_extra(curr: dict, inc: dict) -> dict: + """Merge json schema extra.""" + for key in curr.keys() & inc.keys(): + # Merge keys that are in both dictionaries if both are lists + curr_value = curr[key] + inc_value = inc[key] + if isinstance(curr_value, list) and isinstance(inc_value, list): + curr[key] = list(set(curr.get(key, []) + inc.get(key, []))) + inc.pop(key) + + # Add any remaining keys from inc to curr + curr.update(inc) + return curr + + json_schema_extra: dict = merge_json_schema_extra( + curr=curr_json_schema_extra or {}, inc=inc_json_schema_extra or {} + ) + + curr_detail = split_desc(curr_desc) + inc_detail = split_desc(inc_desc) - providers = f"{current.default.title},{incoming.default.title}" + curr_title = getattr(current.default, "title", "") + inc_title = getattr(incoming.default, "title", "") + providers = ",".join([curr_title, inc_title]) formatted_prov = providers.replace(",", ", ") if SequenceMatcher(None, curr_detail, inc_detail).ratio() > 0.8: new_desc = f"{curr_detail} (provider: {formatted_prov})" else: - new_desc = f"{current_desc};\n {incoming_desc}" + new_desc = f"{curr_desc};\n {inc_desc}" - merged_default = F( # type: ignore - default=current.default.default, + QF: Callable = Query if query else FieldInfo # type: ignore[assignment] + merged_default = QF( + default=getattr(current.default, "default", None), title=providers, description=new_desc, + json_schema_extra=json_schema_extra, ) - merged_type = ( - Union[current_type, incoming_type] - if current_type != incoming_type - else current_type + merged_type: Optional[Type] = ( + Union[curr_type, inc_type] # type: ignore[assignment] + if curr_type != inc_type + else curr_type ) - return DataclassField( - name=current_name, - type_=merged_type, # type: ignore - default=merged_default, - ) + return DataclassField(curr_name, merged_type, merged_default) @staticmethod def _create_field( @@ -213,19 +232,17 @@ class ProviderInterface(metaclass=SingletonMeta): force_optional: bool = False, ) -> DataclassField: new_name = name.replace(".", "_") - # field.type_ don't work for nested types - # field.outer_type_ don't work for Optional nested types - type_ = field.annotation + annotation = field.annotation additional_description = "" if (extra := field.json_schema_extra) and ( multiple := extra.get("multiple_items_allowed") # type: ignore ): if provider_name: - additional_description += " Multiple items allowed." + additional_description += " Multiple comma separated items allowed." else: additional_description += ( - " Multiple items allowed for provider(s): " + ", ".join(multiple) + "." # type: ignore + " Multiple comma separated items allowed for provider(s): " + ", ".join(multiple) + "." # type: ignore ) provider_field = ( @@ -239,7 +256,7 @@ class ProviderInterface(metaclass=SingletonMeta): if field.is_required(): if force_optional: - type_ = Optional[type_] # type: ignore + annotation = Optional[annotation] # type: ignore default = None else: default = ... @@ -247,24 +264,32 @@ class ProviderInterface(metaclass=SingletonMeta): default = field.default if query: - # We need to use query if we want the field description to show up in the - # swagger, it's a fastapi limitation - default = Query( - default=default, - title=provider_name, - description=description, - alias=field.alias or None, - json_schema_extra=field.json_schema_extra, + # We need to use query if we want the field description to show + # up in the swagger, it's a fastapi limitation + return DataclassField( + new_name, + annotation, + Query( + default=default, + title=provider_name, + description=description, + alias=field.alias or None, + json_schema_extra=getattr(field, "json_schema_extra", None), + ), ) - elif provider_name: - default: FieldInfo = Field( - default=default or None, - title=provider_name, - description=description, - json_schema_extra=field.json_schema_extra, + if provider_name: + return DataclassField( + new_name, + annotation, + Field( + default=default or None, + title=provider_name, + description=description, + json_schema_extra=field.json_schema_extra, + ), ) - return DataclassField(new_name, type_, default) + return DataclassField(new_name, annotation, default) @classmethod def _extract_params( @@ -282,7 +307,7 @@ class ProviderInterface(metaclass=SingletonMeta): standard[incoming.name] = ( incoming.name, - incoming.type_, + incoming.annotation, incoming.default, ) else: @@ -305,7 +330,7 @@ class ProviderInterface(metaclass=SingletonMeta): extra[updated.name] = ( updated.name, - updated.type_, + updated.annotation, updated.default, ) @@ -331,7 +356,7 @@ class ProviderInterface(metaclass=SingletonMeta): standard[incoming.name] = ( incoming.name, - incoming.type_, + incoming.annotation, incoming.default, ) else: @@ -357,7 +382,7 @@ class ProviderInterface(metaclass=SingletonMeta): extra[updated.name] = ( updated.name, - updated.type_, + updated.annotation, updated.default, ) @@ -393,14 +418,14 @@ class ProviderInterface(metaclass=SingletonMeta): standard, extra = self._extract_params(providers) result[model_name] = { - "standard": make_dataclass( # type: ignore + "standard": make_dataclass( cls_name=model_name, - fields=list(standard.values()), + fields=list(standard.values()), # type: ignore[arg-type] bases=(StandardParams,), ), - "extra": make_dataclass( # type: ignore + "extra": make_dataclass( cls_name=model_name, - fields=list(extra.values()), + fields=list(extra.values()), # type: ignore[arg-type] bases=(ExtraParams,), ), } @@ -464,14 +489,14 @@ class ProviderInterface(metaclass=SingletonMeta): extra: dict standard, extra = self._extract_data(providers) result[model_name] = { - "standard": make_dataclass( # type: ignore + "standard": make_dataclass( cls_name=model_name, - fields=list(standard.values()), + fields=list(standard.values()), # type: ignore[arg-type] bases=(StandardData,), ), - "extra": make_dataclass( # type: ignore + "extra": make_dataclass( cls_name=model_name, - fields=list(extra.values()), + fields=list(extra.values()), # type: ignore[arg-type] bases=(ExtraData,), ), } |