summaryrefslogtreecommitdiffstats
path: root/openbb_platform/core/openbb_core/api/router/commands.py
diff options
context:
space:
mode:
Diffstat (limited to 'openbb_platform/core/openbb_core/api/router/commands.py')
-rw-r--r--openbb_platform/core/openbb_core/api/router/commands.py55
1 files changed, 36 insertions, 19 deletions
diff --git a/openbb_platform/core/openbb_core/api/router/commands.py b/openbb_platform/core/openbb_core/api/router/commands.py
index 96334f549fa..2d2ea3e6a11 100644
--- a/openbb_platform/core/openbb_core/api/router/commands.py
+++ b/openbb_platform/core/openbb_core/api/router/commands.py
@@ -51,11 +51,16 @@ def build_new_signature(path: str, func: Callable) -> Signature:
parameter_list = sig.parameters.values()
return_annotation = sig.return_annotation
new_parameter_list = []
-
- for parameter in parameter_list:
+ var_kw_pos = len(parameter_list)
+ for pos, parameter in enumerate(parameter_list):
if parameter.name == "cc" and parameter.annotation == CommandContext:
continue
+ if parameter.kind == Parameter.VAR_KEYWORD:
+ # We track VAR_KEYWORD parameter to insert the any additional
+ # parameters we need to add before it and avoid a SyntaxError
+ var_kw_pos = pos
+
new_parameter_list.append(
Parameter(
parameter.name,
@@ -66,18 +71,21 @@ def build_new_signature(path: str, func: Callable) -> Signature:
)
if CHARTING_INSTALLED and path.replace("/", "_")[1:] in Charting.functions():
- new_parameter_list.append(
+ new_parameter_list.insert(
+ var_kw_pos,
Parameter(
"chart",
kind=Parameter.POSITIONAL_OR_KEYWORD,
default=False,
annotation=bool,
- )
+ ),
)
+ var_kw_pos += 1
if custom_headers := SystemService().system_settings.api_settings.custom_headers:
for name, default in custom_headers.items():
- new_parameter_list.append(
+ new_parameter_list.insert(
+ var_kw_pos,
Parameter(
name.replace("-", "_"),
kind=Parameter.POSITIONAL_OR_KEYWORD,
@@ -85,11 +93,13 @@ def build_new_signature(path: str, func: Callable) -> Signature:
annotation=Annotated[
Optional[str], Header(include_in_schema=False)
],
- )
+ ),
)
+ var_kw_pos += 1
if Env().API_AUTH:
- new_parameter_list.append(
+ new_parameter_list.insert(
+ var_kw_pos,
Parameter(
"__authenticated_user_settings",
kind=Parameter.POSITIONAL_OR_KEYWORD,
@@ -97,8 +107,9 @@ def build_new_signature(path: str, func: Callable) -> Signature:
annotation=Annotated[
UserSettings, Depends(AuthService().user_settings_hook)
],
- )
+ ),
)
+ var_kw_pos += 1
return Signature(
parameters=new_parameter_list,
@@ -106,7 +117,7 @@ def build_new_signature(path: str, func: Callable) -> Signature:
)
-def validate_output(c_out: OBBject) -> OBBject:
+def validate_output(c_out: OBBject) -> Dict:
"""
Validate OBBject object.
@@ -121,8 +132,8 @@ def validate_output(c_out: OBBject) -> OBBject:
Returns
-------
- OBBject
- Validated OBBject object.
+ Dict
+ Serialized OBBject.
"""
def is_model(type_):
@@ -134,25 +145,32 @@ def validate_output(c_out: OBBject) -> OBBject:
json_schema_extra = field.json_schema_extra if field else None
# case where 1st layer field needs to be excluded
- if json_schema_extra and json_schema_extra.get("exclude_from_api", None):
+ if (
+ json_schema_extra
+ and isinstance(json_schema_extra, dict)
+ and json_schema_extra.get("exclude_from_api", None)
+ ):
delattr(c_out, key)
# if it's a model with nested fields
elif is_model(type_):
for field_name, field in type_.__fields__.items():
- if field.json_schema_extra and field.json_schema_extra.get(
- "exclude_from_api", None
+ extra = getattr(field, "json_schema_extra", None)
+ if (
+ extra
+ and isinstance(extra, dict)
+ and extra.get("exclude_from_api", None)
):
delattr(value, field_name)
# if it's a yet a nested model we need to go deeper in the recursion
- elif is_model(field.annotation):
+ elif is_model(getattr(field, "annotation", None)):
exclude_fields_from_api(field_name, getattr(value, field_name))
for k, v in c_out.model_copy():
exclude_fields_from_api(k, v)
- return c_out
+ return c_out.model_dump()
def build_api_wrapper(
@@ -170,7 +188,7 @@ def build_api_wrapper(
func.__annotations__ = new_annotations_map
@wraps(wrapped=func)
- async def wrapper(*args: Tuple[Any], **kwargs: Dict[str, Any]):
+ async def wrapper(*args: Tuple[Any], **kwargs: Dict[str, Any]) -> Dict:
user_settings: UserSettings = UserSettings.model_validate(
kwargs.pop(
"__authenticated_user_settings",
@@ -180,8 +198,7 @@ def build_api_wrapper(
execute = partial(command_runner.run, path, user_settings)
output: OBBject = await execute(*args, **kwargs)
- output = validate_output(output)
- return output
+ return validate_output(output)
return wrapper