Skip to content

Commit

Permalink
Add from_call_result
Browse files Browse the repository at this point in the history
  • Loading branch information
amh4r committed Nov 2, 2023
1 parent 307e4af commit 7583e2a
Showing 1 changed file with 91 additions and 76 deletions.
167 changes: 91 additions & 76 deletions inngest/_internal/comm.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,77 @@ def __init__(
def is_success(self) -> bool:
return self.status_code < 400

@classmethod
def from_call_result(
cls,
logger: logging.Logger,
framework: const.Framework,
call_res: execution.CallResult,
) -> CommResponse:
headers = {
**net.create_headers(framework=framework),
const.HeaderKey.SERVER_TIMING.value: "handler",
}

if execution.is_step_call_responses(call_res):
out: list[dict[str, object]] = []
for item in call_res:
match item.to_dict():
case result.Ok(d):
out.append(d)
case result.Err(err):
return cls.from_error(
logger,
framework,
err,
)

return cls(
body=transforms.prep_body(out),
headers=headers,
status_code=http.HTTPStatus.PARTIAL_CONTENT.value,
)

if isinstance(call_res, execution.CallError):
logger.error(call_res.stack)

match call_res.to_dict():
case result.Ok(d):
body = transforms.prep_body(d)
case result.Err(err):
return cls.from_error(
logger,
framework,
err,
)

if call_res.is_retriable is False:
headers[const.HeaderKey.NO_RETRY.value] = "true"

return cls(
body=body,
headers=headers,
status_code=http.HTTPStatus.INTERNAL_SERVER_ERROR.value,
)

if isinstance(call_res, execution.FunctionCallResponse):
return cls(
body=call_res.data,
headers=headers,
)

return cls.from_error(
logger,
framework,
errors.UnknownError("unknown call result"),
)

@classmethod
def from_error(
cls,
logger: logging.Logger,
err: Exception,
framework: const.Framework,
err: Exception,
) -> CommResponse:
code: str | None = None
status_code = http.HTTPStatus.INTERNAL_SERVER_ERROR.value
Expand Down Expand Up @@ -187,7 +252,9 @@ async def call_function(
extra["code"] = err.code
self._logger.error(err, extra=extra)
comm_res = CommResponse.from_error(
self._logger, err, self._framework
self._logger,
self._framework,
err,
)
else:
match self._get_function(fn_id):
Expand All @@ -199,16 +266,20 @@ async def call_function(
# level.
self._client.middleware.after_run_execution_sync()

comm_res = self._create_response(call_res)
comm_res = CommResponse.from_call_result(
self._logger,
self._framework,
call_res,
)
case result.Err(err):
extra = {}
if isinstance(err, errors.InternalError):
extra["code"] = err.code
self._logger.error(err, extra=extra)
comm_res = CommResponse.from_error(
self._logger,
err,
self._framework,
err,
)

self._client.middleware.before_response_sync()
Expand Down Expand Up @@ -239,7 +310,9 @@ def call_function_sync(
extra["code"] = err.code
self._logger.error(err, extra=extra)
comm_res = CommResponse.from_error(
self._logger, err, self._framework
self._logger,
self._framework,
err,
)
else:
match self._get_function(fn_id):
Expand All @@ -251,83 +324,25 @@ def call_function_sync(
# level.
self._client.middleware.after_run_execution_sync()

comm_res = self._create_response(call_res)
comm_res = CommResponse.from_call_result(
self._logger,
self._framework,
call_res,
)
case result.Err(err):
extra = {}
if isinstance(err, errors.InternalError):
extra["code"] = err.code
self._logger.error(err, extra=extra)
comm_res = CommResponse.from_error(
self._logger,
err,
self._framework,
err,
)

self._client.middleware.before_response_sync()
return comm_res

def _create_response(
self,
call_res: execution.CallResult,
) -> CommResponse:
headers = {
**net.create_headers(framework=self._framework),
const.HeaderKey.SERVER_TIMING.value: "handler",
}

if execution.is_step_call_responses(call_res):
out: list[dict[str, object]] = []
for item in call_res:
match item.to_dict():
case result.Ok(d):
out.append(d)
case result.Err(err):
return CommResponse.from_error(
self._logger,
err,
self._framework,
)

return CommResponse(
body=transforms.prep_body(out),
headers=headers,
status_code=http.HTTPStatus.PARTIAL_CONTENT.value,
)

if isinstance(call_res, execution.CallError):
self._logger.error(call_res.stack)

match call_res.to_dict():
case result.Ok(d):
body = transforms.prep_body(d)
case result.Err(err):
return CommResponse.from_error(
self._logger,
err,
self._framework,
)

if call_res.is_retriable is False:
headers[const.HeaderKey.NO_RETRY.value] = "true"

return CommResponse(
body=body,
headers=headers,
status_code=http.HTTPStatus.INTERNAL_SERVER_ERROR.value,
)

if isinstance(call_res, execution.FunctionCallResponse):
return CommResponse(
body=call_res.data,
headers=headers,
)

return CommResponse.from_error(
self._logger,
errors.UnknownError("unknown call result"),
self._framework,
)

def _get_function(
self, fn_id: str
) -> result.Result[function.Function, Exception]:
Expand Down Expand Up @@ -386,15 +401,15 @@ def _parse_registration_response(
except Exception:
return CommResponse.from_error(
self._logger,
errors.RegistrationError("response is not valid JSON"),
self._framework,
errors.RegistrationError("response is not valid JSON"),
)

if not isinstance(server_res_body, dict):
return CommResponse.from_error(
self._logger,
errors.RegistrationError("response is not an object"),
self._framework,
errors.RegistrationError("response is not an object"),
)

if server_res.status_code < 400:
Expand All @@ -409,8 +424,8 @@ def _parse_registration_response(
msg = "registration failed"
comm_res = CommResponse.from_error(
self._logger,
errors.RegistrationError(msg.strip()),
self._framework,
errors.RegistrationError(msg.strip()),
)
comm_res.status_code = server_res.status_code
return comm_res
Expand All @@ -432,8 +447,8 @@ async def register(
self._logger.error(err)
return CommResponse.from_error(
self._logger,
err,
self._framework,
err,
)

async with httpx.AsyncClient() as client:
Expand All @@ -446,8 +461,8 @@ async def register(
self._logger.error(err)
return CommResponse.from_error(
self._logger,
err,
self._framework,
err,
)

return res
Expand All @@ -469,8 +484,8 @@ def register_sync(
self._logger.error(err)
return CommResponse.from_error(
self._logger,
err,
self._framework,
err,
)

with httpx.Client() as client:
Expand All @@ -481,8 +496,8 @@ def register_sync(
self._logger.error(err)
return CommResponse.from_error(
self._logger,
err,
self._framework,
err,
)

return res
Expand Down

0 comments on commit 7583e2a

Please sign in to comment.