Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[PoC]: Declarative Workflows in V2 SDK #292

Open
wants to merge 13 commits into
base: main
Choose a base branch
from
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -162,3 +162,4 @@ cython_debug/
#.idea/

openapitools.json
.python-version
3 changes: 3 additions & 0 deletions examples/v2/declarative_workflows/client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from hatchet_sdk.v2.hatchet import Hatchet

hatchet = Hatchet(debug=True)
63 changes: 63 additions & 0 deletions examples/v2/declarative_workflows/worker.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
from collections import Counter
from typing import Literal

from examples.v2.declarative_workflows.client import hatchet
from examples.v2.declarative_workflows.workflows import (
Greeting,
Language,
greet_workflow,
language_counter_workflow,
)
from hatchet_sdk import Context


def complete_greeting(greeting: Greeting) -> str:
match greeting:
case "Hello":
return "world!"
case "Ciao":
return "mondo!"
case "Hej":
return "världen!"


@greet_workflow.declare()
async def greet(ctx: Context) -> dict[Literal["message"], str]:
Comment on lines +24 to +25
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

single decorator to register a function in Hatchet

workflow_input = greet_workflow.workflow_input(ctx)
greeting = workflow_input.greeting
Comment on lines +26 to +27
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

these type check correctly, so no need to cast or parse from a dict to an ExampleInput


await language_counter_workflow.spawn(
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

refer to the workflow as an object instead of passing a string with the name

context=ctx,
input=language_counter_workflow.construct_spawn_workflow_input(
input=workflow_input
),
Comment on lines +31 to +33
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

types are enforced here, so you know the shape of input in your IDE + at type checking time

)

return {"message": greeting + " " + complete_greeting(greeting)}


## Imagine this is a metric in a monitoring system
counter: Counter[Language] = Counter()


@language_counter_workflow.declare()
async def language_counter(
ctx: Context,
) -> dict[Language, int]:
greeting = language_counter_workflow.workflow_input(ctx).greeting

match greeting:
Comment on lines +47 to +49
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

type checking here for greeting

case "Hello":
counter["English"] += 1
case "Ciao":
counter["Italian"] += 1
case "Hej":
counter["Swedish"] += 1

return dict(counter)


if __name__ == "__main__":
worker = hatchet.worker("my-worker")

worker.start()
18 changes: 18 additions & 0 deletions examples/v2/declarative_workflows/workflows.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
from typing import Literal

from pydantic import BaseModel

from examples.v2.declarative_workflows.client import hatchet

Greeting = Literal["Hello", "Ciao", "Hej"]
Language = Literal["English", "Swedish", "Italian"]


class ExampleInput(BaseModel):
greeting: Greeting


greet_workflow = hatchet.declare_workflow(input_validator=ExampleInput)
language_counter_workflow = hatchet.declare_workflow(
input_validator=ExampleInput,
)
6 changes: 6 additions & 0 deletions hatchet_sdk/clients/admin.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import grpc
from google.protobuf import timestamp_pb2
from pydantic import BaseModel

from hatchet_sdk.clients.rest.models.workflow_run import WorkflowRun
from hatchet_sdk.clients.rest.tenacity_utils import tenacity_retry
Expand Down Expand Up @@ -54,6 +55,11 @@ class ChildTriggerWorkflowOptions(TypedDict, total=False):
sticky: bool | None = None


class ChildTriggerWorkflowOptionsV2(BaseModel):
additional_metadata: dict[str, str] | None = None
sticky: bool | None = None


class ChildWorkflowRunDict(TypedDict, total=False):
workflow_name: str
input: Any
Expand Down
21 changes: 9 additions & 12 deletions hatchet_sdk/v2/callable.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,11 @@
import asyncio
from typing import (
Any,
Callable,
Dict,
Generic,
List,
Optional,
TypedDict,
TypeVar,
Union,
)
from typing import Any, Callable, Generic, List, Type, TypeVar, Union

from pydantic import BaseModel, ConfigDict

from hatchet_sdk.clients.admin import ChildTriggerWorkflowOptions
from hatchet_sdk.context.context import Context
from hatchet_sdk.contracts.workflows_pb2 import ( # type: ignore[attr-defined]
CreateStepRateLimit,
CreateWorkflowJobOpts,
CreateWorkflowStepOpts,
CreateWorkflowVersionOpts,
Expand All @@ -32,10 +23,15 @@
T = TypeVar("T")


class EmptyModel(BaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True, extra="allow")


class HatchetCallable(Generic[T]):
def __init__(
self,
func: Callable[[Context], T],
input_validator: Type[BaseModel] = EmptyModel,
durable: bool = False,
name: str = "",
auto_register: bool = True,
Expand Down Expand Up @@ -87,6 +83,7 @@ def __init__(
self.function_on_failure = on_failure
self.function_namespace = "default"
self.function_auto_register = auto_register
self.input_validator = input_validator

self.is_coroutine = False

Expand Down
97 changes: 94 additions & 3 deletions hatchet_sdk/v2/hatchet.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
from typing import Any, Callable, TypeVar, Union
from typing import Any, Callable, Generic, Type, TypeVar, Union, cast

from pydantic import BaseModel, ConfigDict

from hatchet_sdk import Worker
from hatchet_sdk.clients.admin import ChildTriggerWorkflowOptionsV2
from hatchet_sdk.context.context import Context
from hatchet_sdk.contracts.workflows_pb2 import ( # type: ignore[attr-defined]
ConcurrencyLimitStrategy,
Expand All @@ -10,11 +13,27 @@
from hatchet_sdk.hatchet import workflow
from hatchet_sdk.labels import DesiredWorkerLabel
from hatchet_sdk.rate_limit import RateLimit
from hatchet_sdk.v2.callable import DurableContext, HatchetCallable
from hatchet_sdk.v2.callable import DurableContext, EmptyModel, HatchetCallable
from hatchet_sdk.v2.concurrency import ConcurrencyFunction
from hatchet_sdk.worker.worker import register_on_worker
from hatchet_sdk.workflow_run import WorkflowRunRef

T = TypeVar("T")
TWorkflowInput = TypeVar("TWorkflowInput", bound=BaseModel)


class DeclarativeWorkflowConfig(BaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True)

input_validator: Type[BaseModel] = EmptyModel
name: str = ""
on_events: list[str] | None = None
on_crons: list[str] | None = None
version: str = ""
timeout: str = "60m"
schedule_timeout: str = "5m"
concurrency: ConcurrencyFunction | None = None
default_priority: int | None = None


def function(
Expand All @@ -32,6 +51,7 @@ def function(
concurrency: ConcurrencyFunction | None = None,
on_failure: Union["HatchetCallable[T]", None] = None,
default_priority: int | None = None,
input_validator: Type[BaseModel] = EmptyModel,
) -> Callable[[Callable[[Context], str]], HatchetCallable[T]]:
def inner(func: Callable[[Context], T]) -> HatchetCallable[T]:
return HatchetCallable(
Expand All @@ -50,6 +70,7 @@ def inner(func: Callable[[Context], T]) -> HatchetCallable[T]:
concurrency=concurrency,
on_failure=on_failure,
default_priority=default_priority,
input_validator=input_validator,
)

return inner
Expand All @@ -70,6 +91,7 @@ def durable(
concurrency: ConcurrencyFunction | None = None,
on_failure: HatchetCallable[T] | None = None,
default_priority: int | None = None,
input_validator: Type[BaseModel] = EmptyModel,
) -> Callable[[HatchetCallable[T]], HatchetCallable[T]]:
def inner(func: HatchetCallable[T]) -> HatchetCallable[T]:
func.durable = True
Expand All @@ -89,6 +111,7 @@ def inner(func: HatchetCallable[T]) -> HatchetCallable[T]:
concurrency=concurrency,
on_failure=on_failure,
default_priority=default_priority,
input_validator=input_validator,
)

resp = f(func)
Expand All @@ -111,6 +134,45 @@ def inner(func: Callable[[Context], str]) -> ConcurrencyFunction:
return inner


class SpawnWorkflowInput(BaseModel):
workflow_name: str
input: BaseModel
key: str | None = None
options: ChildTriggerWorkflowOptionsV2 | None = None


class DeclarativeWorkflow(Generic[TWorkflowInput]):
def __init__(self, config: DeclarativeWorkflowConfig, hatchet: "Hatchet"):
self.config = config
self.hatchet = hatchet

def run(self, input: TWorkflowInput) -> WorkflowRunRef:
return self.hatchet.admin.run_workflow(
workflow_name=self.config.name, input=input.model_dump()
)

async def spawn(
self, context: Context, input: SpawnWorkflowInput
) -> WorkflowRunRef:
return await context.aio.spawn_workflow(
workflow_name=input.workflow_name,
input=input.input.model_dump(),
key=input.key,
options=input.options,
)

def construct_spawn_workflow_input(
self, input: TWorkflowInput
) -> SpawnWorkflowInput:
return SpawnWorkflowInput(workflow_name=self.config.name, input=input)

def declare(self) -> Callable[[Callable[[Context], Any]], Callable[[Context], Any]]:
return self.hatchet.function(**self.config.model_dump())

def workflow_input(self, ctx: Context) -> TWorkflowInput:
return cast(TWorkflowInput, ctx.workflow_input())


class Hatchet(HatchetV1):
dag = staticmethod(workflow)
concurrency = staticmethod(concurrency)
Expand All @@ -119,6 +181,7 @@ class Hatchet(HatchetV1):

def function(
self,
input_validator: Type[BaseModel] = EmptyModel,
name: str = "",
auto_register: bool = True,
on_events: list[str] | None = None,
Expand Down Expand Up @@ -147,9 +210,10 @@ def function(
concurrency=concurrency,
on_failure=on_failure,
default_priority=default_priority,
input_validator=input_validator,
)

def wrapper(func: Callable[[Context], str]) -> HatchetCallable[T]:
def wrapper(func: Callable[[Context], Any]) -> HatchetCallable[T]:
wrapped_resp = resp(func)

if wrapped_resp.function_auto_register:
Expand Down Expand Up @@ -222,3 +286,30 @@ def worker(
register_on_worker(func, worker)

return worker

def declare_workflow(
self,
input_validator: Type[TWorkflowInput],
name: str = "",
on_events: list[str] | None = None,
on_crons: list[str] | None = None,
version: str = "",
timeout: str = "60m",
schedule_timeout: str = "5m",
concurrency: ConcurrencyFunction | None = None,
default_priority: int | None = None,
) -> DeclarativeWorkflow[TWorkflowInput]:
return DeclarativeWorkflow[input_validator](
hatchet=self,
config=DeclarativeWorkflowConfig(
input_validator=input_validator,
name=name,
on_events=on_events,
on_crons=on_crons,
version=version,
timeout=timeout,
schedule_timeout=schedule_timeout,
concurrency=concurrency,
default_priority=default_priority,
),
)
4 changes: 4 additions & 0 deletions hatchet_sdk/worker/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,10 @@ def __init__(

def register_function(self, action: str, func: Callable[[Context], Any]) -> None:
self.action_registry[action] = func
self.validator_registry[action] = WorkflowValidator(
workflow_input=getattr(func, "input_validator"),
step_output=None,
)

def register_workflow_from_opts(
self, name: str, opts: CreateWorkflowVersionOpts
Expand Down
Loading