diff --git a/inngest/_internal/comm.py b/inngest/_internal/comm.py index 292f81da..60f0431c 100644 --- a/inngest/_internal/comm.py +++ b/inngest/_internal/comm.py @@ -39,12 +39,77 @@ def __init__( def is_success(self) -> bool: return self.status_code < 400 + @classmethod + def from_call_result( + cls, + logger: logging.Logger, + framework: const.Framework, + call_res: execution.CallResult, + ) -> CommResponse: + headers = { + **net.create_headers(framework=framework), + const.HeaderKey.SERVER_TIMING.value: "handler", + } + + if execution.is_step_call_responses(call_res): + out: list[dict[str, object]] = [] + for item in call_res: + match item.to_dict(): + case result.Ok(d): + out.append(d) + case result.Err(err): + return cls.from_error( + logger, + framework, + err, + ) + + return cls( + body=transforms.prep_body(out), + headers=headers, + status_code=http.HTTPStatus.PARTIAL_CONTENT.value, + ) + + if isinstance(call_res, execution.CallError): + logger.error(call_res.stack) + + match call_res.to_dict(): + case result.Ok(d): + body = transforms.prep_body(d) + case result.Err(err): + return cls.from_error( + logger, + framework, + err, + ) + + if call_res.is_retriable is False: + headers[const.HeaderKey.NO_RETRY.value] = "true" + + return cls( + body=body, + headers=headers, + status_code=http.HTTPStatus.INTERNAL_SERVER_ERROR.value, + ) + + if isinstance(call_res, execution.FunctionCallResponse): + return cls( + body=call_res.data, + headers=headers, + ) + + return cls.from_error( + logger, + framework, + errors.UnknownError("unknown call result"), + ) + @classmethod def from_error( cls, logger: logging.Logger, - err: Exception, framework: const.Framework, + err: Exception, ) -> CommResponse: code: str | None = None status_code = http.HTTPStatus.INTERNAL_SERVER_ERROR.value @@ -187,7 +252,9 @@ async def call_function( extra["code"] = err.code self._logger.error(err, extra=extra) comm_res = CommResponse.from_error( - self._logger, err, self._framework + self._logger, + self._framework, + err, ) else: match self._get_function(fn_id): @@ -199,7 +266,11 @@ async def call_function( # level. self._client.middleware.after_run_execution_sync() - comm_res = self._create_response(call_res) + comm_res = CommResponse.from_call_result( + self._logger, + self._framework, + call_res, + ) case result.Err(err): extra = {} if isinstance(err, errors.InternalError): @@ -207,8 +278,8 @@ async def call_function( self._logger.error(err, extra=extra) comm_res = CommResponse.from_error( self._logger, - err, self._framework, + err, ) self._client.middleware.before_response_sync() @@ -239,7 +310,9 @@ def call_function_sync( extra["code"] = err.code self._logger.error(err, extra=extra) comm_res = CommResponse.from_error( - self._logger, err, self._framework + self._logger, + self._framework, + err, ) else: match self._get_function(fn_id): @@ -251,7 +324,11 @@ def call_function_sync( # level. self._client.middleware.after_run_execution_sync() - comm_res = self._create_response(call_res) + comm_res = CommResponse.from_call_result( + self._logger, + self._framework, + call_res, + ) case result.Err(err): extra = {} if isinstance(err, errors.InternalError): @@ -259,75 +336,13 @@ def call_function_sync( self._logger.error(err, extra=extra) comm_res = CommResponse.from_error( self._logger, - err, self._framework, + err, ) self._client.middleware.before_response_sync() return comm_res - def _create_response( - self, - call_res: execution.CallResult, - ) -> CommResponse: - headers = { - **net.create_headers(framework=self._framework), - const.HeaderKey.SERVER_TIMING.value: "handler", - } - - if execution.is_step_call_responses(call_res): - out: list[dict[str, object]] = [] - for item in call_res: - match item.to_dict(): - case result.Ok(d): - out.append(d) - case result.Err(err): - return CommResponse.from_error( - self._logger, - err, - self._framework, - ) - - return CommResponse( - body=transforms.prep_body(out), - headers=headers, - status_code=http.HTTPStatus.PARTIAL_CONTENT.value, - ) - - if isinstance(call_res, execution.CallError): - self._logger.error(call_res.stack) - - match call_res.to_dict(): - case result.Ok(d): - body = transforms.prep_body(d) - case result.Err(err): - return CommResponse.from_error( - self._logger, - err, - self._framework, - ) - - if call_res.is_retriable is False: - headers[const.HeaderKey.NO_RETRY.value] = "true" - - return CommResponse( - body=body, - headers=headers, - status_code=http.HTTPStatus.INTERNAL_SERVER_ERROR.value, - ) - - if isinstance(call_res, execution.FunctionCallResponse): - return CommResponse( - body=call_res.data, - headers=headers, - ) - - return CommResponse.from_error( - self._logger, - errors.UnknownError("unknown call result"), - self._framework, - ) - def _get_function( self, fn_id: str ) -> result.Result[function.Function, Exception]: @@ -386,15 +401,15 @@ def _parse_registration_response( except Exception: return CommResponse.from_error( self._logger, - errors.RegistrationError("response is not valid JSON"), self._framework, + errors.RegistrationError("response is not valid JSON"), ) if not isinstance(server_res_body, dict): return CommResponse.from_error( self._logger, - errors.RegistrationError("response is not an object"), self._framework, + errors.RegistrationError("response is not an object"), ) if server_res.status_code < 400: @@ -409,8 +424,8 @@ def _parse_registration_response( msg = "registration failed" comm_res = CommResponse.from_error( self._logger, - errors.RegistrationError(msg.strip()), self._framework, + errors.RegistrationError(msg.strip()), ) comm_res.status_code = server_res.status_code return comm_res @@ -432,8 +447,8 @@ async def register( self._logger.error(err) return CommResponse.from_error( self._logger, - err, self._framework, + err, ) async with httpx.AsyncClient() as client: @@ -446,8 +461,8 @@ async def register( self._logger.error(err) return CommResponse.from_error( self._logger, - err, self._framework, + err, ) return res @@ -469,8 +484,8 @@ def register_sync( self._logger.error(err) return CommResponse.from_error( self._logger, - err, self._framework, + err, ) with httpx.Client() as client: @@ -481,8 +496,8 @@ def register_sync( self._logger.error(err) return CommResponse.from_error( self._logger, - err, self._framework, + err, ) return res