Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

More async/sync unification; fix function output #4

Merged
merged 2 commits into from
Nov 1, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 12 additions & 26 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@

# Inngest Python SDK

## 🚧 Currently in alpha! Not guaranteed to be production ready! 🚧
> 🚧 Currently in beta! It hasn't been battle-tested in production environments yet.

Supported frameworks:

Expand All @@ -33,9 +33,11 @@ Supported frameworks:

## Usage

> 💡 Most of these examples don't show `async` functions but you can mix `async` and non-`async` functions in the same app!

- [Basic](#basic-no-steps)
- [Step run](#step-run)
- [Avoiding async/await](#avoiding-async-functions)
- [Async function](#async-function)

### Basic (no steps)

Expand All @@ -46,7 +48,6 @@ import flask
import inngest.flask
import requests


@inngest.create_function(
fn_id="find_person",
trigger=inngest.TriggerEvent(event="app/person.find"),
Expand All @@ -61,7 +62,6 @@ async def fetch_person(
res = requests.get(f"https://swapi.dev/api/people/{person_id}")
return res.json()


app = flask.Flask(__name__)
inngest_client = inngest.Inngest(app_id="flask_example")

Expand Down Expand Up @@ -89,10 +89,10 @@ The following example registers a function that will:
fn_id="find_ships",
trigger=inngest.TriggerEvent(event="app/ships.find"),
)
async def fetch_ships(
def fetch_ships(
*,
event: inngest.Event,
step: inngest.Step,
step: inngest.StepSync,
**_kwargs: object,
) -> dict:
"""
Expand Down Expand Up @@ -125,37 +125,23 @@ async def fetch_ships(
}
```

### Avoiding async functions

Completely avoiding `async`` functions only requires 2 differences:

1. Use `step: inngest.StepSync` instead of `step: inngest.Step`
2. Use `serve_sync` instead of `serve`
### Async functions

```py
@inngest.create_function(
fn_id="find_person",
trigger=inngest.TriggerEvent(event="app/person.find"),
)
def fetch_person(
async def fetch_person(
*,
event: inngest.Event,
step: inngest.StepSync,
step: inngest.Step,
**_kwargs: object,
) -> dict:
person_id = event.data["person_id"]
res = requests.get(f"https://swapi.dev/api/people/{person_id}")
return res.json()


app = flask.Flask(__name__)
inngest_client = inngest.Inngest(app_id="flask_example")

inngest.flask.serve_sync(
app,
inngest_client,
[fetch_person],
)
async with httpx.AsyncClient() as client:
res = await client.get(f"https://swapi.dev/api/people/{person_id}")
return res.json()
```

> 💡 You can mix `async` and non-`async` functions in the same app!
7 changes: 5 additions & 2 deletions inngest/_internal/comm.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
registration,
result,
transforms,
types,
)


Expand Down Expand Up @@ -226,7 +227,9 @@ def call_function_sync(

def _create_response(
self,
call_res: list[execution.CallResponse] | str | execution.CallError,
call_res: list[execution.CallResponse]
| types.Serializable
| execution.CallError,
) -> CommResponse:
comm_res = CommResponse(
headers={
Expand All @@ -235,7 +238,7 @@ def _create_response(
}
)

if isinstance(call_res, list):
if execution.is_call_responses(call_res):
out: list[dict[str, object]] = []
for item in call_res:
match item.to_dict():
Expand Down
9 changes: 9 additions & 0 deletions inngest/_internal/execution.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import enum
import typing

from . import errors, event_lib, transforms, types

Expand Down Expand Up @@ -49,6 +50,14 @@ class CallResponse(types.BaseModel):
opts: dict[str, object] | None = None


def is_call_responses(
value: object,
) -> typing.TypeGuard[list[CallResponse]]:
if not isinstance(value, list):
return False
return all(isinstance(item, CallResponse) for item in value)


class Opcode(enum.Enum):
SLEEP = "Sleep"
STEP = "Step"
Expand Down
43 changes: 33 additions & 10 deletions inngest/_internal/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import dataclasses
import hashlib
import inspect
import json
import typing

import pydantic
Expand All @@ -16,6 +15,7 @@
execution,
function_config,
step_lib,
transforms,
types,
)

Expand All @@ -39,7 +39,7 @@ def __call__(
events: list[event_lib.Event],
run_id: str,
step: step_lib.Step,
) -> typing.Awaitable[types.JSONSerializableOutput]:
) -> typing.Awaitable[types.Serializable]:
...


Expand All @@ -53,7 +53,7 @@ def __call__(
events: list[event_lib.Event],
run_id: str,
step: step_lib.StepSync,
) -> types.JSONSerializableOutput:
) -> types.Serializable:
...


Expand Down Expand Up @@ -134,8 +134,22 @@ def id(self) -> str:

@property
def is_handler_async(self) -> bool:
"""
Whether the main handler is async.
"""
return _is_function_handler_async(self._handler)

@property
def is_on_failure_handler_async(self) -> bool | None:
"""
Whether the on_failure handler is async. Returns None if there isn't an
on_failure handler.
"""

if self._opts.on_failure is None:
return None
return _is_function_handler_async(self._opts.on_failure)

@property
def on_failure_fn_id(self) -> str | None:
return self._on_failure_fn_id
Expand All @@ -162,7 +176,9 @@ async def call(
call: execution.Call,
client: client_lib.Inngest,
fn_id: str,
) -> list[execution.CallResponse] | str | execution.CallError:
) -> list[
execution.CallResponse
] | types.Serializable | execution.CallError:
try:
handler: FunctionHandlerAsync | FunctionHandlerSync
if self.id == fn_id:
Expand Down Expand Up @@ -213,7 +229,7 @@ async def call(
)
)

return json.dumps(res)
return res
except step_lib.Interrupt as out:
return [
execution.CallResponse(
Expand All @@ -226,14 +242,20 @@ async def call(
)
]
except Exception as err:
# An error occurred with the user's code. Print the traceback to
# help them debug.
print(transforms.get_traceback(err))

return execution.CallError.from_error(err)

def call_sync(
self,
call: execution.Call,
client: client_lib.Inngest,
fn_id: str,
) -> list[execution.CallResponse] | str | execution.CallError:
) -> (
list[execution.CallResponse] | types.Serializable | execution.CallError
):
try:
handler: FunctionHandlerAsync | FunctionHandlerSync
if self.id == fn_id:
Expand All @@ -250,7 +272,7 @@ def call_sync(
)

if _is_function_handler_sync(handler):
res = handler(
return handler(
attempt=call.ctx.attempt,
event=call.event,
events=call.events,
Expand All @@ -261,9 +283,6 @@ def call_sync(
step_lib.StepIDCounter(),
),
)

return json.dumps(res)

return execution.CallError.from_error(
errors.MismatchedSync(
"encountered async function in non-async context"
Expand All @@ -281,6 +300,10 @@ def call_sync(
)
]
except Exception as err:
# An error occurred with the user's code. Print the traceback to
# help them debug.
print(transforms.get_traceback(err))

return execution.CallError.from_error(err)

def get_config(self, app_url: str) -> _Config:
Expand Down
44 changes: 35 additions & 9 deletions inngest/_internal/step_lib/step_async.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
import datetime
import json
import inspect
import typing

from inngest._internal import (
client_lib,
errors,
event_lib,
execution,
result,
Expand All @@ -26,13 +25,35 @@ def __init__(
self._memos = memos
self._step_id_counter = step_id_counter

@typing.overload
async def run(
self,
step_id: str,
handler: typing.Callable[[], typing.Awaitable[types.T]],
) -> types.T:
handler: typing.Callable[[], typing.Awaitable[types.SerializableT]],
) -> types.SerializableT:
...

@typing.overload
async def run(
self,
step_id: str,
handler: typing.Callable[[], types.SerializableT],
) -> types.SerializableT:
...

async def run(
self,
step_id: str,
handler: typing.Callable[[], typing.Awaitable[types.SerializableT]]
| typing.Callable[[], types.SerializableT],
) -> types.SerializableT:
"""
Run logic that should be retried on error and memoized after success.

Args:
step_id: Unique step ID within the function. If the same step ID is
encountered multiple times then it'll get an index suffix.
handler: The logic to run. Can be async or sync.
"""

hashed_id = self._get_hashed_id(step_id)
Expand All @@ -41,12 +62,17 @@ async def run(
if memo is not types.EmptySentinel:
return memo # type: ignore

output = await handler()
if inspect.iscoroutinefunction(handler):
output = await handler()
else:
output = handler()

try:
json.dumps(output)
except TypeError as err:
raise errors.UnserializableOutput(str(err)) from None
# Check whether output is serializable
match transforms.dump_json(output):
case result.Ok(_):
pass
case result.Err(err):
raise err

raise base.Interrupt(
hashed_id=hashed_id,
Expand Down
17 changes: 11 additions & 6 deletions inngest/_internal/step_lib/step_sync.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
import datetime
import json
import typing

from inngest._internal import (
client_lib,
errors,
event_lib,
execution,
result,
Expand Down Expand Up @@ -33,6 +31,11 @@ def run(
) -> types.T:
"""
Run logic that should be retried on error and memoized after success.

Args:
step_id: Unique step ID within the function. If the same step ID is
encountered multiple times then it'll get an index suffix.
handler: The logic to run.
"""

hashed_id = self._get_hashed_id(step_id)
Expand All @@ -43,10 +46,12 @@ def run(

output = handler()

try:
json.dumps(output)
except TypeError as err:
raise errors.UnserializableOutput(str(err)) from None
# Check whether output is serializable
match transforms.dump_json(output):
case result.Ok(_):
pass
case result.Err(err):
raise err

raise base.Interrupt(
hashed_id=hashed_id,
Expand Down
Loading
Loading