Skip to content

Commit

Permalink
Add Pydantic validation test (#159)
Browse files Browse the repository at this point in the history
  • Loading branch information
amh4r authored Sep 13, 2024
1 parent e3c9b1f commit 28c3a4f
Show file tree
Hide file tree
Showing 2 changed files with 110 additions and 0 deletions.
2 changes: 2 additions & 0 deletions tests/test_function/cases/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
on_failure,
parallel_step_disappears,
parallel_steps,
pydantic_event,
pydantic_output,
retry_after_error,
sleep_until,
Expand Down Expand Up @@ -73,6 +74,7 @@
on_failure,
parallel_step_disappears,
parallel_steps,
pydantic_event,
pydantic_output,
retry_after_error,
sleep_until,
Expand Down
108 changes: 108 additions & 0 deletions tests/test_function/cases/pydantic_event.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
"""
We don't support returning Pydantic models in steps or functions. This may
change in the future.
"""

import json
import typing

import pydantic
from typing_extensions import assert_type

import inngest
import tests.helper
from inngest._internal import server_lib

from . import base

TEvent = typing.TypeVar("TEvent", bound="BaseEvent")


class BaseEvent(pydantic.BaseModel):
data: pydantic.BaseModel
id: str = ""
name: typing.ClassVar[str]
ts: int = 0

@classmethod
def from_event(cls: type[TEvent], event: inngest.Event) -> TEvent:
return cls.model_validate(event.model_dump(mode="json"))

def to_event(self) -> inngest.Event:
return inngest.Event(
name=self.name,
data=self.data.model_dump(mode="json"),
id=self.id,
ts=self.ts,
)


class MyEventData(pydantic.BaseModel):
count: int


def create(
client: inngest.Inngest,
framework: server_lib.Framework,
is_sync: bool,
) -> base.Case:
test_name = base.create_test_name(__file__)

class MyEvent(BaseEvent):
data: MyEventData
name = base.create_event_name(framework, test_name)

fn_id = base.create_fn_id(test_name)
state = base.BaseState()

@client.create_function(
fn_id=fn_id,
retries=0,
trigger=inngest.TriggerEvent(event=MyEvent.name),
)
def fn_sync(
ctx: inngest.Context,
step: inngest.StepSync,
) -> typing.Any:
state.run_id = ctx.run_id
event = MyEvent.from_event(ctx.event)
assert_type(event, MyEvent)
return event.model_dump(mode="json")

@client.create_function(
fn_id=fn_id,
retries=0,
trigger=inngest.TriggerEvent(event=MyEvent.name),
)
async def fn_async(
ctx: inngest.Context,
step: inngest.Step,
) -> typing.Any:
state.run_id = ctx.run_id
event = MyEvent.from_event(ctx.event)
assert_type(event, MyEvent)
return event.model_dump(mode="json")

async def run_test(self: base.TestClass) -> None:
self.client.send_sync(MyEvent(data=MyEventData(count=1)).to_event())
run = tests.helper.client.wait_for_run_status(
state.wait_for_run_id(),
tests.helper.RunStatus.COMPLETED,
)

assert run.output is not None
event = MyEvent.model_validate(json.loads(run.output))
assert event.data.count == 1
assert event.id != ""
assert event.ts > 0

if is_sync:
fn = fn_sync
else:
fn = fn_async

return base.Case(
fn=fn,
run_test=run_test,
name=test_name,
)

0 comments on commit 28c3a4f

Please sign in to comment.