summaryrefslogtreecommitdiffstats
path: root/openbb_platform/core/openbb_core/app/provider_interface.py
diff options
context:
space:
mode:
Diffstat (limited to 'openbb_platform/core/openbb_core/app/provider_interface.py')
-rw-r--r--openbb_platform/core/openbb_core/app/provider_interface.py143
1 files changed, 84 insertions, 59 deletions
diff --git a/openbb_platform/core/openbb_core/app/provider_interface.py b/openbb_platform/core/openbb_core/app/provider_interface.py
index eab7ea1e807..5bef4ac6e6c 100644
--- a/openbb_platform/core/openbb_core/app/provider_interface.py
+++ b/openbb_platform/core/openbb_core/app/provider_interface.py
@@ -18,7 +18,7 @@ from openbb_core.provider.query_executor import QueryExecutor
from openbb_core.provider.registry_map import MapType, RegistryMap
from openbb_core.provider.utils.helpers import to_snake_case
-TupleFieldType = Tuple[str, type, Any]
+TupleFieldType = Tuple[str, Optional[Type], Optional[Any]]
@dataclass
@@ -26,8 +26,8 @@ class DataclassField:
"""Dataclass field."""
name: str
- type_: type
- default: Any
+ annotation: Optional[Type]
+ default: Optional[Any]
@dataclass
@@ -154,20 +154,21 @@ class ProviderInterface(metaclass=SingletonMeta):
def create_executor(self) -> QueryExecutor:
"""Get query executor."""
- return self._query_executor(self._registry_map.registry) # type: ignore
+ return self._query_executor(self._registry_map.registry) # type: ignore[operator]
@staticmethod
def _merge_fields(
current: DataclassField, incoming: DataclassField, query: bool = False
) -> DataclassField:
- current_name = current.name
- current_type = current.type_
- current_desc = getattr(current.default, "description", "")
+ """Merge 2 dataclass fields."""
+ curr_name = current.name
+ curr_type: Optional[Type] = current.annotation
+ curr_desc = getattr(current.default, "description", "")
+ curr_json_schema_extra = getattr(current.default, "json_schema_extra", {})
- incoming_type = incoming.type_
- incoming_desc = getattr(incoming.default, "description", "")
-
- F: Union[Callable, object] = Query if query else FieldInfo
+ inc_type: Optional[Type] = incoming.annotation
+ inc_desc = getattr(incoming.default, "description", "")
+ inc_json_schema_extra = getattr(incoming.default, "json_schema_extra", {})
def split_desc(desc: str) -> str:
"""Split field description."""
@@ -175,34 +176,52 @@ class ProviderInterface(metaclass=SingletonMeta):
detail = item[0] if item else ""
return detail
- curr_detail = split_desc(current_desc)
- inc_detail = split_desc(incoming_desc)
+ def merge_json_schema_extra(curr: dict, inc: dict) -> dict:
+ """Merge json schema extra."""
+ for key in curr.keys() & inc.keys():
+ # Merge keys that are in both dictionaries if both are lists
+ curr_value = curr[key]
+ inc_value = inc[key]
+ if isinstance(curr_value, list) and isinstance(inc_value, list):
+ curr[key] = list(set(curr.get(key, []) + inc.get(key, [])))
+ inc.pop(key)
+
+ # Add any remaining keys from inc to curr
+ curr.update(inc)
+ return curr
+
+ json_schema_extra: dict = merge_json_schema_extra(
+ curr=curr_json_schema_extra or {}, inc=inc_json_schema_extra or {}
+ )
+
+ curr_detail = split_desc(curr_desc)
+ inc_detail = split_desc(inc_desc)
- providers = f"{current.default.title},{incoming.default.title}"
+ curr_title = getattr(current.default, "title", "")
+ inc_title = getattr(incoming.default, "title", "")
+ providers = ",".join([curr_title, inc_title])
formatted_prov = providers.replace(",", ", ")
if SequenceMatcher(None, curr_detail, inc_detail).ratio() > 0.8:
new_desc = f"{curr_detail} (provider: {formatted_prov})"
else:
- new_desc = f"{current_desc};\n {incoming_desc}"
+ new_desc = f"{curr_desc};\n {inc_desc}"
- merged_default = F( # type: ignore
- default=current.default.default,
+ QF: Callable = Query if query else FieldInfo # type: ignore[assignment]
+ merged_default = QF(
+ default=getattr(current.default, "default", None),
title=providers,
description=new_desc,
+ json_schema_extra=json_schema_extra,
)
- merged_type = (
- Union[current_type, incoming_type]
- if current_type != incoming_type
- else current_type
+ merged_type: Optional[Type] = (
+ Union[curr_type, inc_type] # type: ignore[assignment]
+ if curr_type != inc_type
+ else curr_type
)
- return DataclassField(
- name=current_name,
- type_=merged_type, # type: ignore
- default=merged_default,
- )
+ return DataclassField(curr_name, merged_type, merged_default)
@staticmethod
def _create_field(
@@ -213,19 +232,17 @@ class ProviderInterface(metaclass=SingletonMeta):
force_optional: bool = False,
) -> DataclassField:
new_name = name.replace(".", "_")
- # field.type_ don't work for nested types
- # field.outer_type_ don't work for Optional nested types
- type_ = field.annotation
+ annotation = field.annotation
additional_description = ""
if (extra := field.json_schema_extra) and (
multiple := extra.get("multiple_items_allowed") # type: ignore
):
if provider_name:
- additional_description += " Multiple items allowed."
+ additional_description += " Multiple comma separated items allowed."
else:
additional_description += (
- " Multiple items allowed for provider(s): " + ", ".join(multiple) + "." # type: ignore
+ " Multiple comma separated items allowed for provider(s): " + ", ".join(multiple) + "." # type: ignore
)
provider_field = (
@@ -239,7 +256,7 @@ class ProviderInterface(metaclass=SingletonMeta):
if field.is_required():
if force_optional:
- type_ = Optional[type_] # type: ignore
+ annotation = Optional[annotation] # type: ignore
default = None
else:
default = ...
@@ -247,24 +264,32 @@ class ProviderInterface(metaclass=SingletonMeta):
default = field.default
if query:
- # We need to use query if we want the field description to show up in the
- # swagger, it's a fastapi limitation
- default = Query(
- default=default,
- title=provider_name,
- description=description,
- alias=field.alias or None,
- json_schema_extra=field.json_schema_extra,
+ # We need to use query if we want the field description to show
+ # up in the swagger, it's a fastapi limitation
+ return DataclassField(
+ new_name,
+ annotation,
+ Query(
+ default=default,
+ title=provider_name,
+ description=description,
+ alias=field.alias or None,
+ json_schema_extra=getattr(field, "json_schema_extra", None),
+ ),
)
- elif provider_name:
- default: FieldInfo = Field(
- default=default or None,
- title=provider_name,
- description=description,
- json_schema_extra=field.json_schema_extra,
+ if provider_name:
+ return DataclassField(
+ new_name,
+ annotation,
+ Field(
+ default=default or None,
+ title=provider_name,
+ description=description,
+ json_schema_extra=field.json_schema_extra,
+ ),
)
- return DataclassField(new_name, type_, default)
+ return DataclassField(new_name, annotation, default)
@classmethod
def _extract_params(
@@ -282,7 +307,7 @@ class ProviderInterface(metaclass=SingletonMeta):
standard[incoming.name] = (
incoming.name,
- incoming.type_,
+ incoming.annotation,
incoming.default,
)
else:
@@ -305,7 +330,7 @@ class ProviderInterface(metaclass=SingletonMeta):
extra[updated.name] = (
updated.name,
- updated.type_,
+ updated.annotation,
updated.default,
)
@@ -331,7 +356,7 @@ class ProviderInterface(metaclass=SingletonMeta):
standard[incoming.name] = (
incoming.name,
- incoming.type_,
+ incoming.annotation,
incoming.default,
)
else:
@@ -357,7 +382,7 @@ class ProviderInterface(metaclass=SingletonMeta):
extra[updated.name] = (
updated.name,
- updated.type_,
+ updated.annotation,
updated.default,
)
@@ -393,14 +418,14 @@ class ProviderInterface(metaclass=SingletonMeta):
standard, extra = self._extract_params(providers)
result[model_name] = {
- "standard": make_dataclass( # type: ignore
+ "standard": make_dataclass(
cls_name=model_name,
- fields=list(standard.values()),
+ fields=list(standard.values()), # type: ignore[arg-type]
bases=(StandardParams,),
),
- "extra": make_dataclass( # type: ignore
+ "extra": make_dataclass(
cls_name=model_name,
- fields=list(extra.values()),
+ fields=list(extra.values()), # type: ignore[arg-type]
bases=(ExtraParams,),
),
}
@@ -464,14 +489,14 @@ class ProviderInterface(metaclass=SingletonMeta):
extra: dict
standard, extra = self._extract_data(providers)
result[model_name] = {
- "standard": make_dataclass( # type: ignore
+ "standard": make_dataclass(
cls_name=model_name,
- fields=list(standard.values()),
+ fields=list(standard.values()), # type: ignore[arg-type]
bases=(StandardData,),
),
- "extra": make_dataclass( # type: ignore
+ "extra": make_dataclass(
cls_name=model_name,
- fields=list(extra.values()),
+ fields=list(extra.values()), # type: ignore[arg-type]
bases=(ExtraData,),
),
}