Skip to content

Commit

Permalink
Merge pull request #4 from inngest/fixes
Browse files Browse the repository at this point in the history
More async/sync unification; fix function output
  • Loading branch information
amh4r authored Nov 1, 2023
2 parents 6fc74d9 + 8b31acc commit 54584ab
Show file tree
Hide file tree
Showing 12 changed files with 146 additions and 73 deletions.
38 changes: 12 additions & 26 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@

# Inngest Python SDK

## 🚧 Currently in alpha! Not guaranteed to be production ready! 🚧
> 🚧 Currently in beta! It hasn't been battle-tested in production environments yet.
Supported frameworks:

Expand All @@ -33,9 +33,11 @@ Supported frameworks:

## Usage

> 💡 Most of these examples don't show `async` functions but you can mix `async` and non-`async` functions in the same app!
- [Basic](#basic-no-steps)
- [Step run](#step-run)
- [Avoiding async/await](#avoiding-async-functions)
- [Async function](#async-function)

### Basic (no steps)

Expand All @@ -46,7 +48,6 @@ import flask
import inngest.flask
import requests


@inngest.create_function(
fn_id="find_person",
trigger=inngest.TriggerEvent(event="app/person.find"),
Expand All @@ -61,7 +62,6 @@ async def fetch_person(
res = requests.get(f"https://swapi.dev/api/people/{person_id}")
return res.json()


app = flask.Flask(__name__)
inngest_client = inngest.Inngest(app_id="flask_example")

Expand Down Expand Up @@ -89,10 +89,10 @@ The following example registers a function that will:
fn_id="find_ships",
trigger=inngest.TriggerEvent(event="app/ships.find"),
)
async def fetch_ships(
def fetch_ships(
*,
event: inngest.Event,
step: inngest.Step,
step: inngest.StepSync,
**_kwargs: object,
) -> dict:
"""
Expand Down Expand Up @@ -125,37 +125,23 @@ async def fetch_ships(
}
```

### Avoiding async functions

Completely avoiding `async`` functions only requires 2 differences:

1. Use `step: inngest.StepSync` instead of `step: inngest.Step`
2. Use `serve_sync` instead of `serve`
### Async functions

```py
@inngest.create_function(
fn_id="find_person",
trigger=inngest.TriggerEvent(event="app/person.find"),
)
def fetch_person(
async def fetch_person(
*,
event: inngest.Event,
step: inngest.StepSync,
step: inngest.Step,
**_kwargs: object,
) -> dict:
person_id = event.data["person_id"]
res = requests.get(f"https://swapi.dev/api/people/{person_id}")
return res.json()


app = flask.Flask(__name__)
inngest_client = inngest.Inngest(app_id="flask_example")

inngest.flask.serve_sync(
app,
inngest_client,
[fetch_person],
)
async with httpx.AsyncClient() as client:
res = await client.get(f"https://swapi.dev/api/people/{person_id}")
return res.json()
```

> 💡 You can mix `async` and non-`async` functions in the same app!
7 changes: 5 additions & 2 deletions inngest/_internal/comm.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
registration,
result,
transforms,
types,
)


Expand Down Expand Up @@ -226,7 +227,9 @@ def call_function_sync(

def _create_response(
self,
call_res: list[execution.CallResponse] | str | execution.CallError,
call_res: list[execution.CallResponse]
| types.Serializable
| execution.CallError,
) -> CommResponse:
comm_res = CommResponse(
headers={
Expand All @@ -235,7 +238,7 @@ def _create_response(
}
)

if isinstance(call_res, list):
if execution.is_call_responses(call_res):
out: list[dict[str, object]] = []
for item in call_res:
match item.to_dict():
Expand Down
9 changes: 9 additions & 0 deletions inngest/_internal/execution.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import enum
import typing

from . import errors, event_lib, transforms, types

Expand Down Expand Up @@ -49,6 +50,14 @@ class CallResponse(types.BaseModel):
opts: dict[str, object] | None = None


def is_call_responses(
value: object,
) -> typing.TypeGuard[list[CallResponse]]:
if not isinstance(value, list):
return False
return all(isinstance(item, CallResponse) for item in value)


class Opcode(enum.Enum):
SLEEP = "Sleep"
STEP = "Step"
Expand Down
43 changes: 33 additions & 10 deletions inngest/_internal/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import dataclasses
import hashlib
import inspect
import json
import typing

import pydantic
Expand All @@ -16,6 +15,7 @@
execution,
function_config,
step_lib,
transforms,
types,
)

Expand All @@ -39,7 +39,7 @@ def __call__(
events: list[event_lib.Event],
run_id: str,
step: step_lib.Step,
) -> typing.Awaitable[types.JSONSerializableOutput]:
) -> typing.Awaitable[types.Serializable]:
...


Expand All @@ -53,7 +53,7 @@ def __call__(
events: list[event_lib.Event],
run_id: str,
step: step_lib.StepSync,
) -> types.JSONSerializableOutput:
) -> types.Serializable:
...


Expand Down Expand Up @@ -134,8 +134,22 @@ def id(self) -> str:

@property
def is_handler_async(self) -> bool:
"""
Whether the main handler is async.
"""
return _is_function_handler_async(self._handler)

@property
def is_on_failure_handler_async(self) -> bool | None:
"""
Whether the on_failure handler is async. Returns None if there isn't an
on_failure handler.
"""

if self._opts.on_failure is None:
return None
return _is_function_handler_async(self._opts.on_failure)

@property
def on_failure_fn_id(self) -> str | None:
return self._on_failure_fn_id
Expand All @@ -162,7 +176,9 @@ async def call(
call: execution.Call,
client: client_lib.Inngest,
fn_id: str,
) -> list[execution.CallResponse] | str | execution.CallError:
) -> list[
execution.CallResponse
] | types.Serializable | execution.CallError:
try:
handler: FunctionHandlerAsync | FunctionHandlerSync
if self.id == fn_id:
Expand Down Expand Up @@ -213,7 +229,7 @@ async def call(
)
)

return json.dumps(res)
return res
except step_lib.Interrupt as out:
return [
execution.CallResponse(
Expand All @@ -226,14 +242,20 @@ async def call(
)
]
except Exception as err:
# An error occurred with the user's code. Print the traceback to
# help them debug.
print(transforms.get_traceback(err))

return execution.CallError.from_error(err)

def call_sync(
self,
call: execution.Call,
client: client_lib.Inngest,
fn_id: str,
) -> list[execution.CallResponse] | str | execution.CallError:
) -> (
list[execution.CallResponse] | types.Serializable | execution.CallError
):
try:
handler: FunctionHandlerAsync | FunctionHandlerSync
if self.id == fn_id:
Expand All @@ -250,7 +272,7 @@ def call_sync(
)

if _is_function_handler_sync(handler):
res = handler(
return handler(
attempt=call.ctx.attempt,
event=call.event,
events=call.events,
Expand All @@ -261,9 +283,6 @@ def call_sync(
step_lib.StepIDCounter(),
),
)

return json.dumps(res)

return execution.CallError.from_error(
errors.MismatchedSync(
"encountered async function in non-async context"
Expand All @@ -281,6 +300,10 @@ def call_sync(
)
]
except Exception as err:
# An error occurred with the user's code. Print the traceback to
# help them debug.
print(transforms.get_traceback(err))

return execution.CallError.from_error(err)

def get_config(self, app_url: str) -> _Config:
Expand Down
44 changes: 35 additions & 9 deletions inngest/_internal/step_lib/step_async.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
import datetime
import json
import inspect
import typing

from inngest._internal import (
client_lib,
errors,
event_lib,
execution,
result,
Expand All @@ -26,13 +25,35 @@ def __init__(
self._memos = memos
self._step_id_counter = step_id_counter

@typing.overload
async def run(
self,
step_id: str,
handler: typing.Callable[[], typing.Awaitable[types.T]],
) -> types.T:
handler: typing.Callable[[], typing.Awaitable[types.SerializableT]],
) -> types.SerializableT:
...

@typing.overload
async def run(
self,
step_id: str,
handler: typing.Callable[[], types.SerializableT],
) -> types.SerializableT:
...

async def run(
self,
step_id: str,
handler: typing.Callable[[], typing.Awaitable[types.SerializableT]]
| typing.Callable[[], types.SerializableT],
) -> types.SerializableT:
"""
Run logic that should be retried on error and memoized after success.
Args:
step_id: Unique step ID within the function. If the same step ID is
encountered multiple times then it'll get an index suffix.
handler: The logic to run. Can be async or sync.
"""

hashed_id = self._get_hashed_id(step_id)
Expand All @@ -41,12 +62,17 @@ async def run(
if memo is not types.EmptySentinel:
return memo # type: ignore

output = await handler()
if inspect.iscoroutinefunction(handler):
output = await handler()
else:
output = handler()

try:
json.dumps(output)
except TypeError as err:
raise errors.UnserializableOutput(str(err)) from None
# Check whether output is serializable
match transforms.dump_json(output):
case result.Ok(_):
pass
case result.Err(err):
raise err

raise base.Interrupt(
hashed_id=hashed_id,
Expand Down
17 changes: 11 additions & 6 deletions inngest/_internal/step_lib/step_sync.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
import datetime
import json
import typing

from inngest._internal import (
client_lib,
errors,
event_lib,
execution,
result,
Expand Down Expand Up @@ -33,6 +31,11 @@ def run(
) -> types.T:
"""
Run logic that should be retried on error and memoized after success.
Args:
step_id: Unique step ID within the function. If the same step ID is
encountered multiple times then it'll get an index suffix.
handler: The logic to run.
"""

hashed_id = self._get_hashed_id(step_id)
Expand All @@ -43,10 +46,12 @@ def run(

output = handler()

try:
json.dumps(output)
except TypeError as err:
raise errors.UnserializableOutput(str(err)) from None
# Check whether output is serializable
match transforms.dump_json(output):
case result.Ok(_):
pass
case result.Err(err):
raise err

raise base.Interrupt(
hashed_id=hashed_id,
Expand Down
Loading

0 comments on commit 54584ab

Please sign in to comment.