diff options
Diffstat (limited to 'openbb_platform/core/openbb_core/api/router/commands.py')
-rw-r--r-- | openbb_platform/core/openbb_core/api/router/commands.py | 55 |
1 files changed, 36 insertions, 19 deletions
diff --git a/openbb_platform/core/openbb_core/api/router/commands.py b/openbb_platform/core/openbb_core/api/router/commands.py index 96334f549fa..2d2ea3e6a11 100644 --- a/openbb_platform/core/openbb_core/api/router/commands.py +++ b/openbb_platform/core/openbb_core/api/router/commands.py @@ -51,11 +51,16 @@ def build_new_signature(path: str, func: Callable) -> Signature: parameter_list = sig.parameters.values() return_annotation = sig.return_annotation new_parameter_list = [] - - for parameter in parameter_list: + var_kw_pos = len(parameter_list) + for pos, parameter in enumerate(parameter_list): if parameter.name == "cc" and parameter.annotation == CommandContext: continue + if parameter.kind == Parameter.VAR_KEYWORD: + # We track VAR_KEYWORD parameter to insert the any additional + # parameters we need to add before it and avoid a SyntaxError + var_kw_pos = pos + new_parameter_list.append( Parameter( parameter.name, @@ -66,18 +71,21 @@ def build_new_signature(path: str, func: Callable) -> Signature: ) if CHARTING_INSTALLED and path.replace("/", "_")[1:] in Charting.functions(): - new_parameter_list.append( + new_parameter_list.insert( + var_kw_pos, Parameter( "chart", kind=Parameter.POSITIONAL_OR_KEYWORD, default=False, annotation=bool, - ) + ), ) + var_kw_pos += 1 if custom_headers := SystemService().system_settings.api_settings.custom_headers: for name, default in custom_headers.items(): - new_parameter_list.append( + new_parameter_list.insert( + var_kw_pos, Parameter( name.replace("-", "_"), kind=Parameter.POSITIONAL_OR_KEYWORD, @@ -85,11 +93,13 @@ def build_new_signature(path: str, func: Callable) -> Signature: annotation=Annotated[ Optional[str], Header(include_in_schema=False) ], - ) + ), ) + var_kw_pos += 1 if Env().API_AUTH: - new_parameter_list.append( + new_parameter_list.insert( + var_kw_pos, Parameter( "__authenticated_user_settings", kind=Parameter.POSITIONAL_OR_KEYWORD, @@ -97,8 +107,9 @@ def build_new_signature(path: str, func: Callable) -> Signature: annotation=Annotated[ UserSettings, Depends(AuthService().user_settings_hook) ], - ) + ), ) + var_kw_pos += 1 return Signature( parameters=new_parameter_list, @@ -106,7 +117,7 @@ def build_new_signature(path: str, func: Callable) -> Signature: ) -def validate_output(c_out: OBBject) -> OBBject: +def validate_output(c_out: OBBject) -> Dict: """ Validate OBBject object. @@ -121,8 +132,8 @@ def validate_output(c_out: OBBject) -> OBBject: Returns ------- - OBBject - Validated OBBject object. + Dict + Serialized OBBject. """ def is_model(type_): @@ -134,25 +145,32 @@ def validate_output(c_out: OBBject) -> OBBject: json_schema_extra = field.json_schema_extra if field else None # case where 1st layer field needs to be excluded - if json_schema_extra and json_schema_extra.get("exclude_from_api", None): + if ( + json_schema_extra + and isinstance(json_schema_extra, dict) + and json_schema_extra.get("exclude_from_api", None) + ): delattr(c_out, key) # if it's a model with nested fields elif is_model(type_): for field_name, field in type_.__fields__.items(): - if field.json_schema_extra and field.json_schema_extra.get( - "exclude_from_api", None + extra = getattr(field, "json_schema_extra", None) + if ( + extra + and isinstance(extra, dict) + and extra.get("exclude_from_api", None) ): delattr(value, field_name) # if it's a yet a nested model we need to go deeper in the recursion - elif is_model(field.annotation): + elif is_model(getattr(field, "annotation", None)): exclude_fields_from_api(field_name, getattr(value, field_name)) for k, v in c_out.model_copy(): exclude_fields_from_api(k, v) - return c_out + return c_out.model_dump() def build_api_wrapper( @@ -170,7 +188,7 @@ def build_api_wrapper( func.__annotations__ = new_annotations_map @wraps(wrapped=func) - async def wrapper(*args: Tuple[Any], **kwargs: Dict[str, Any]): + async def wrapper(*args: Tuple[Any], **kwargs: Dict[str, Any]) -> Dict: user_settings: UserSettings = UserSettings.model_validate( kwargs.pop( "__authenticated_user_settings", @@ -180,8 +198,7 @@ def build_api_wrapper( execute = partial(command_runner.run, path, user_settings) output: OBBject = await execute(*args, **kwargs) - output = validate_output(output) - return output + return validate_output(output) return wrapper |