def trigger_workflow(
    self,
    workflow_name: str,
    input,
    options: TriggerWorkflowOptions = None,
) -> str:
    """Trigger a workflow run and return the id of the new run.

    The request proto is built by ``_prepare_workflow_request`` and sent via
    the ``TriggerWorkflow`` RPC with metadata derived from this client's
    token. Nothing is caught here, so any error raised by the RPC reaches
    the caller unchanged.
    """
    proto = self._prepare_workflow_request(workflow_name, input, options)
    logger.trace("trigger proto: {}", proto)

    rpc = self.client.TriggerWorkflow
    response: TriggerWorkflowResponse = rpc(proto, metadata=get_metadata(self.token))
    return response.workflow_run_id
def _sourceloc(fn) -> str:
    """Return ``"<source file>:<first line>"`` for *fn*, or ``""`` if unknown.

    This is best-effort debugging metadata: builtins, C extensions and other
    objects without retrievable source yield an empty string instead of an
    error.
    """
    try:
        return "{}:{}".format(
            inspect.getsourcefile(fn),
            inspect.getsourcelines(fn)[1],
        )
    except Exception:
        # Deliberately broad (inspect raises TypeError/OSError, among others),
        # but no longer a bare ``except`` so KeyboardInterrupt/SystemExit
        # are not swallowed.
        return ""


# Note: this should be language independent, and useable by Go/Typescript, etc.
@dataclass
class _CallableInput:
    """The input to a Hatchet callable."""

    args: Tuple = field(default_factory=tuple)
    kwargs: Dict[str, Any] = field(default_factory=dict)

    def dumps(self) -> str:
        """Serialize to a JSON object with ``args`` and ``kwargs`` keys."""
        return json.dumps(asdict(self))

    @staticmethod
    def loads(s: str) -> "_CallableInput":
        """Rebuild the input from an action payload.

        Note that this is not the inverse of ``dumps``: it expects the
        serialized input to be nested under an ``"input"`` key.

        Raises:
            json.JSONDecodeError: *s* is not valid JSON.
            KeyError: the payload has no ``"input"`` entry.
            TypeError: the input object has keys other than args/kwargs.
        """
        # NOTE: AssignedAction.actionPayload looks like the following
        # '{"input": , "parents": {}, "overrides": {}, "user_data": {}, "triggered_by": "manual"}'
        loaded = _CallableInput(**(json.loads(s)["input"]))
        # JSON has no tuple type, so ``args`` comes back as a list; restore
        # the declared tuple so a round trip compares equal to the original.
        loaded.args = tuple(loaded.args)
        return loaded
@dataclass
class _CallableOutput(Generic[T]):
    """Envelope holding the value returned by a Hatchet callable."""

    output: T

    def dumps(self) -> str:
        """Serialize the envelope as a JSON object."""
        as_mapping = asdict(self)
        return json.dumps(as_mapping)

    @staticmethod
    def loads(s: str) -> "_CallableOutput[T]":
        """Rebuild an envelope from the JSON text produced by ``dumps``."""
        decoded = json.loads(s)
        return _CallableOutput(**decoded)
def _to_workflow_proto(self) -> CreateWorkflowVersionOpts:
    """Build the workflow-version registration proto for this callable.

    The workflow holds a single job built by ``_to_job_proto``. When an
    ``on_failure`` callable is configured, its own job proto is attached as
    the failure job.
    """
    opts = self._hatchet.options
    # TODO: handle concurrency function and on failure function
    if opts.durable:
        kind = WorkflowKind.DURABLE
    else:
        kind = WorkflowKind.FUNCTION

    failure_job = opts.on_failure._to_job_proto() if opts.on_failure else None

    # Note that the failure job is also a HatchetCallable, and it should manage its own name.
    return CreateWorkflowVersionOpts(
        name=self._hatchet.name,
        kind=kind,
        version=opts.version,
        event_triggers=opts.on_events,
        cron_triggers=opts.on_crons,
        schedule_timeout=opts.schedule_timeout,
        sticky=opts.sticky,
        on_failure_job=failure_job,
        concurrency=None,  # TODO
        jobs=[self._to_job_proto()],
        default_priority=opts.priority,
    )
def _decode_context(
    self, action: AssignedAction
) -> Optional["context.BackgroundContext"]:
    """Reconstruct the background context using the assigned action protobuf.

    Returns:
        The decoded context bound to this callable's client, or ``None`` when
        the action carries no metadata, the metadata is not a JSON object, or
        it has no ``"_hatchet_background_context"`` entry.
    """
    raw = action.additional_metadata
    if not raw:
        return None

    try:
        metadata = json.loads(raw)
    except json.JSONDecodeError:
        logger.warning("failed to decode additional metadata from assigned action")
        return None

    # The metadata arrives from outside this process. Valid JSON that is not
    # an object (a list, a number, ...) cannot carry a context; treat it like
    # undecodable metadata instead of relying on an ``assert``, which is
    # stripped under ``python -O``.
    if not isinstance(metadata, dict):
        logger.warning("additional metadata from assigned action is not a JSON object")
        return None

    if "_hatchet_background_context" not in metadata:
        return None

    # ``fromdict`` already binds the client, so no reassignment is needed.
    return context.BackgroundContext.fromdict(
        client=self._hatchet.client,
        data=metadata["_hatchet_background_context"],
    )
def _trigger(self, *args: P.args, **kwargs: P.kwargs) -> TriggerWorkflowResponse:
    """Send a trigger request for this callable and return the RPC response.

    The call arguments are wrapped in a ``_CallableInput`` and combined with
    the current background context by ``_to_trigger_proto``.
    """
    background = context.ensure_background_context()
    call_input = _CallableInput(args=args, kwargs=kwargs)
    request = self._to_trigger_proto(background, inputs=call_input)
    logger.trace("triggering: {}", MessageToDict(request))

    hatchet_client = self._hatchet.client
    rpc = hatchet_client.admin.client.TriggerWorkflow
    response: TriggerWorkflowResponse = rpc(
        request, metadata=hatchet_client._grpc_metadata()
    )
    logger.trace("runid: {}", response)
    return response
class HatchetCallable(HatchetCallableBase[P, T]):
    """A Hatchet callable wrapping a non-asyncio free function."""

    def __call__(self, *args: P.args, **kwargs: P.kwargs) -> Future[T]:
        """Trigger a workflow run and return a future for its result.

        A ``Future`` is returned rather than the value itself so that callers
        can start several runs and decide when to synchronize, e.g.

            concurrent.futures.as_completed([wf1(), wf2(), wf3()])
        """
        run = self._trigger(*args, **kwargs)
        client = self._hatchet.client

        # Subscribe to the run that was just triggered.
        subscription = SubscribeToWorkflowRunsRequest(workflowRunId=run.workflow_run_id)

        # TODO: expose a better interface on the Hatchet client for waiting on results.
        event_future = client.worker()._wfr_futures.submit(subscription)

        # utils.MapFuture receives the decoder, the run-event future and the
        # client's executor; its result is what the caller waits on.
        result: Future[T] = utils.MapFuture(
            self._decode_output, event_future, client.executor
        )
        return result

    def _run(self, action: AssignedAction) -> str:
        """Run the wrapped function for *action* and return the serialized output."""
        logger.trace("invoking: {}", MessageToDict(action))
        assert action.actionId == self._hatchet.action

        run_ctx = self._make_ctx(action)
        with context.WithContext(run_ctx):
            call = _CallableInput.loads(action.actionPayload)
            returned = self._hatchet.func(*call.args, **call.kwargs)
            output = _CallableOutput(output=returned)
            logger.trace("output: {}", output)
            return output.dumps()
class Options(BaseModel):
    """The options for a Hatchet function (aka workflow).

    ``execution_timeout`` and ``priority`` are populated through their
    aliases, ``timeout`` and ``default_priority``.
    """

    # pydantic configuration
    model_config = ConfigDict(arbitrary_types_allowed=True)

    durable: bool = Field(default=False)
    auto_register: bool = Field(default=True)
    on_failure: Optional[HatchetCallableBase] = Field(default=None, exclude=True)

    # triggering options
    on_events: List[str] = Field(default_factory=list)
    on_crons: List[str] = Field(default_factory=list)

    # metadata
    version: str = Field(default="")

    # timeout
    execution_timeout: str = Field(default="60m", alias="timeout")
    schedule_timeout: str = Field(default="5m")

    # execution
    sticky: Optional[StickyStrategy] = Field(default=None)
    retries: int = Field(default=0, ge=0)
    ratelimits: List[RateLimit] = Field(default_factory=list)
    priority: Optional[int] = Field(default=None, alias="default_priority", ge=1, le=3)
    desired_worker_labels: Dict[str, DesiredWorkerLabel] = Field(default_factory=dict)
    concurrency: Optional[ConcurrencyFunction] = Field(default=None)

    @computed_field
    @property
    def ratelimits_proto(self) -> List[CreateStepRateLimit]:
        """The configured rate limits converted to their protobuf form."""
        return [
            CreateStepRateLimit(key=limit.key, units=limit.units)
            for limit in self.ratelimits
        ]

    @computed_field
    @property
    def desired_worker_labels_proto(self) -> Dict[str, DesiredWorkerLabels]:
        """The desired worker labels converted to their protobuf form.

        Integer values are sent as ``intValue`` and any other present value
        as ``strValue``. A label without a value sets neither field.
        """
        # TODO: double check the default values
        labels: Dict[str, DesiredWorkerLabels] = {}
        for key, label in self.desired_worker_labels.items():
            value = label.get("value")
            is_int = isinstance(value, int)
            labels[key] = DesiredWorkerLabels(
                # ``str(None)`` is the literal "None"; leave the field unset
                # when no value was supplied.
                strValue=None if (is_int or value is None) else str(value),
                intValue=value if is_int else None,
                required=label.get("required") or False,
                weight=label.get("weight") or 0,
                # Pass the comparator through untouched: wrapping it in
                # ``str()`` made a missing comparator the truthy string
                # "None", so the ``or None`` fallback could never apply.
                comparator=label.get("comparator"),
            )
        return labels
CallableMetadata(Generic[P, T]): + """Metadata field for a decorated Hatchet workflow.""" + + func: Callable[P, T] # the original function + + name: str + namespace: str + action: str + sourceloc: str # source location of the callable + + options: "Options" + client: "hatchet.Hatchet" + + def _debug(self): + return { + "func": repr(self.func), + "name": self.name, + "namespace": self.namespace, + "action": self.action, + "sourceloc": self.sourceloc, + "client": repr(self.client), + "options": self.options.model_dump(), + } diff --git a/hatchet_sdk/v2/hatchet.py b/hatchet_sdk/v2/hatchet.py index 9c866ba8..4998e490 100644 --- a/hatchet_sdk/v2/hatchet.py +++ b/hatchet_sdk/v2/hatchet.py @@ -1,222 +1,108 @@ -from typing import Callable, List, Optional, TypeVar - -from hatchet_sdk.context import Context -from hatchet_sdk.contracts.workflows_pb2 import ConcurrencyLimitStrategy, StickyStrategy -from hatchet_sdk.hatchet import Hatchet as HatchetV1 -from hatchet_sdk.hatchet import workflow -from hatchet_sdk.labels import DesiredWorkerLabel -from hatchet_sdk.rate_limit import RateLimit -from hatchet_sdk.v2.callable import HatchetCallable -from hatchet_sdk.v2.concurrency import ConcurrencyFunction -from hatchet_sdk.worker.worker import register_on_worker - -from ..worker import Worker +import asyncio +import functools +import inspect +from concurrent.futures import ThreadPoolExecutor, Future +from contextlib import suppress +from typing import Callable, List, Optional, ParamSpec, Tuple, TypeVar + +import hatchet_sdk.hatchet as v1 +import hatchet_sdk.v2.callable as callable +import hatchet_sdk.v2.runtime.config as config +import hatchet_sdk.v2.runtime.context as context +import hatchet_sdk.v2.runtime.logging as logging +import hatchet_sdk.v2.runtime.registry as registry +import hatchet_sdk.v2.runtime.runner as runner +import hatchet_sdk.v2.runtime.runtime as runtime +import hatchet_sdk.v2.runtime.worker as worker T = TypeVar("T") +P = ParamSpec("P") -def function( - name: str = 
"", - auto_register: bool = True, - on_events: list | None = None, - on_crons: list | None = None, - version: str = "", - timeout: str = "60m", - schedule_timeout: str = "5m", - sticky: StickyStrategy = None, - retries: int = 0, - rate_limits: List[RateLimit] | None = None, - desired_worker_labels: dict[str:DesiredWorkerLabel] = {}, - concurrency: ConcurrencyFunction | None = None, - on_failure: Optional["HatchetCallable"] = None, - default_priority: int | None = None, -): - def inner(func: Callable[[Context], T]) -> HatchetCallable[T]: - return HatchetCallable( - func=func, - name=name, - auto_register=auto_register, - on_events=on_events, - on_crons=on_crons, - version=version, - timeout=timeout, - schedule_timeout=schedule_timeout, - sticky=sticky, - retries=retries, - rate_limits=rate_limits, - desired_worker_labels=desired_worker_labels, - concurrency=concurrency, - on_failure=on_failure, - default_priority=default_priority, - ) - - return inner - - -def durable( - name: str = "", - auto_register: bool = True, - on_events: list | None = None, - on_crons: list | None = None, - version: str = "", - timeout: str = "60m", - schedule_timeout: str = "5m", - sticky: StickyStrategy = None, - retries: int = 0, - rate_limits: List[RateLimit] | None = None, - desired_worker_labels: dict[str:DesiredWorkerLabel] = {}, - concurrency: ConcurrencyFunction | None = None, - on_failure: HatchetCallable | None = None, - default_priority: int | None = None, -): - def inner(func: HatchetCallable) -> HatchetCallable: - func.durable = True - - f = function( - name=name, - auto_register=auto_register, - on_events=on_events, - on_crons=on_crons, - version=version, - timeout=timeout, - schedule_timeout=schedule_timeout, - sticky=sticky, - retries=retries, - rate_limits=rate_limits, - desired_worker_labels=desired_worker_labels, - concurrency=concurrency, - on_failure=on_failure, - default_priority=default_priority, - ) - - resp = f(func) - - resp.durable = True - - return resp - - 
class Hatchet:
    """The v2 Hatchet client.

    Wraps a v1 client (``self.v1``), owns the action registry used by the
    ``@function`` decorator, and lazily creates a single runtime via
    ``worker()``.
    """

    # NOTE: ``__init__`` must stay above the ``config`` property. Its
    # annotation refers to the ``config`` module, and inside the class body
    # that name is rebound to the property once the property is defined.
    def __init__(
        self,
        config: Optional[config.ClientConfig] = None,
        debug=False,
        executor: Optional[ThreadPoolExecutor] = None,
    ):
        """Create a client.

        Args:
            config: defaults handed to the v1 client loader; a fresh
                ``ClientConfig`` is created when omitted.
            debug: forwarded to the v1 client.
            executor: thread pool stored on the client; a new
                ``ThreadPoolExecutor`` is created per client when omitted.
        """
        # Defaults are built per instance. As signature defaults they were
        # evaluated once at class-definition time, so every client shared
        # one config object and one thread pool.
        if config is None:
            # The ``config`` parameter shadows the module of the same name
            # inside this function, hence the local import.
            from hatchet_sdk.v2.runtime.config import ClientConfig

            config = ClientConfig()
        if executor is None:
            executor = ThreadPoolExecutor()

        # ensure a event loop is created before gRPC
        with suppress(RuntimeError):
            asyncio.get_event_loop()

        self.registry = registry.ActionRegistry()
        self.v1: v1.Hatchet = v1.Hatchet.from_environment(
            defaults=config,
            debug=debug,
        )
        self.executor = executor

        self._runtime: Optional["runtime.Runtime"] = None

        context.ensure_background_context(client=self)

    @property
    def admin(self):
        """The admin client of the wrapped v1 client."""
        return self.v1.admin

    @property
    def dispatcher(self):
        """The dispatcher client of the wrapped v1 client."""
        return self.v1.dispatcher

    @property
    def config(self):
        """The configuration of the wrapped v1 client."""
        return self.v1.config

    # FIXME: consider separating this into @func and @afunc for better type hints.
    # Right now, the type hint for the return type is (P -> T) | (P -> Future[T]) and this is because we
    # don't statically know whether "func" is a def or an async def.
    def function(
        self,
        *,
        name: str = "",
        namespace: str = "default",
        options: Optional["callable.Options"] = None,
    ):
        """Return a decorator turning a function into a Hatchet callable.

        ``async def`` functions are wrapped in ``HatchetAwaitable`` and plain
        functions in ``HatchetCallable``.

        Args:
            name: passed to the callable as its name.
            namespace: passed to the callable as its namespace.
            options: workflow options; a fresh ``Options()`` is created for
                each decorated function when omitted.

        Raises:
            TypeError: (from the decorator) the target is neither a function
                nor an async function.
        """

        # TODO: needs to detect and reject an already decorated free function.
        # TODO: needs to detect and reject a classmethod/staticmethod.
        def inner(func: Callable[P, T]):
            # Resolved at decoration time rather than as a signature default,
            # so decorated functions do not share one mutable Options
            # instance and importing this module does not require
            # ``callable.Options`` to exist yet.
            opts = options if options is not None else callable.Options()

            if inspect.iscoroutinefunction(func):
                wrapped = callable.HatchetAwaitable[P, T](
                    func=func,
                    name=name,
                    namespace=namespace,
                    client=self,
                    options=opts,
                )
                # TODO: investigate the type error here.
                aret: Callable[P, T] = functools.update_wrapper(wrapped, func)
                return aret
            elif inspect.isfunction(func):
                wrapped = callable.HatchetCallable(
                    func=func,
                    name=name,
                    namespace=namespace,
                    client=self,
                    options=opts,
                )
                ret: Callable[P, Future[T]] = functools.update_wrapper(wrapped, func)
                return ret
            else:
                raise TypeError(
                    "the @function decorator can only be applied to functions (def) and async functions (async def)"
                )

        return inner

    # TODO: make it 1 worker : 1 client, which means moving the options to the initializer, and cache the result.
    # TODO: rename it to runtime
    def worker(
        self, *, options: Optional["worker.WorkerOptions"] = None
    ) -> "runtime.Runtime":
        """Return this client's runtime, creating it on first use.

        Args:
            options: required for the call that creates the runtime and
                ignored once a runtime exists.

        Raises:
            ValueError: no runtime exists yet and *options* was not given.
        """
        if self._runtime is None:
            # Explicit check instead of ``assert``, which is stripped under
            # ``python -O`` and would let ``options=None`` through.
            if options is None:
                raise ValueError(
                    "worker options are required the first time worker() is called"
                )
            self._runtime = runtime.Runtime(client=self, options=options)
        return self._runtime

    def _grpc_metadata(self) -> List[Tuple]:
        """Return the gRPC metadata carrying the bearer token."""
        return [("authorization", f"bearer {self.config.token}")]
from dataclasses import field  # used for the per-instance defaults of RunInfo


def _loopid() -> Optional[int]:
    """Return ``id()`` of the running asyncio event loop, or ``None`` if there is none."""
    try:
        return id(asyncio.get_running_loop())
    except RuntimeError:
        # get_running_loop() raises RuntimeError when no loop is running.
        return None


_ctxvar: ContextVar[Optional["BackgroundContext"]] = ContextVar(
    "hatchet_background_context", default=None
)


@dataclass
class RunInfo:
    """Identifies a workflow/step run and where it was created."""

    workflow_run_id: Optional[str] = None
    step_run_id: Optional[str] = None

    namespace: str = ""
    name: str = ""

    # TODO, pid/tid/loopid is not propagated to the engine, we are not able to restore them
    # These use default factories so they are captured when each RunInfo is
    # created. As plain defaults they were evaluated once at import time, so
    # every instance reported the importing process/thread and no loop.
    pid: int = field(default_factory=os.getpid)
    tid: int = field(default_factory=threading.get_ident)
    loopid: Optional[int] = field(default_factory=_loopid)

    def copy(self):
        """Return a deep copy of this RunInfo."""
        return copy.deepcopy(self)
def ensure_background_context(
    client: Optional["hatchet.Hatchet"] = None,
) -> BackgroundContext:
    """Return the current background context, creating one if none is set.

    Args:
        client: the Hatchet client used to create a new context. It is only
            consulted when no context is currently set.

    Raises:
        RuntimeError: no context is set and no client was given.
    """
    ctx = BackgroundContext.get()
    if ctx is not None:
        return ctx

    # Explicit check instead of ``assert``: under ``python -O`` the assert is
    # stripped and a context with ``client=None`` would be installed.
    if client is None:
        raise RuntimeError(
            "no background context is set and no Hatchet client was provided to create one"
        )

    ctx = BackgroundContext(client=client)
    BackgroundContext.set(ctx)
    return ctx
+ with WithContext(ctx): + # code in the correct context here + + """ + prev = BackgroundContext.get() + + # NOTE: ctx.current could be None, which means there's no parent. + + child = ctx.copy() + child.parent = ctx.current.copy() if ctx.current else None + child.current = None + BackgroundContext.set(child) + try: + yield child + finally: + BackgroundContext.set(prev) diff --git a/hatchet_sdk/v2/runtime/future.py b/hatchet_sdk/v2/runtime/future.py new file mode 100644 index 00000000..f83727ef --- /dev/null +++ b/hatchet_sdk/v2/runtime/future.py @@ -0,0 +1,220 @@ +import asyncio +import multiprocessing.queues as mpq +import queue +import threading +import time +from collections.abc import Callable, MutableSet +from concurrent.futures import CancelledError, Future, ThreadPoolExecutor +from contextlib import suppress +from typing import Dict, Generic, Optional, TypeAlias, TypeVar + +from google.protobuf.json_format import MessageToDict +from loguru import logger + +import hatchet_sdk.v2.runtime.messages as messages +import hatchet_sdk.v2.runtime.utils as utils +from hatchet_sdk.contracts.dispatcher_pb2 import ( + SubscribeToWorkflowRunsRequest, + WorkflowRunEvent, +) + +# TODO: use better generics for Python >= 3.12 +T = TypeVar("T") +RespT = TypeVar("RespT") +ReqT = TypeVar("ReqT") + + +_ThreadSafeQueue: TypeAlias = queue.Queue[T] | mpq.Queue[T] + + +class RequestResponseBroker(Generic[ReqT, RespT]): + def __init__( + self, + *, + inbound: _ThreadSafeQueue[RespT], + outbound: _ThreadSafeQueue[ReqT], + req_key: Callable[[ReqT], str], + resp_key: Callable[[RespT], str], + executor: ThreadPoolExecutor, + ): + """A broker that can send/forward a request and returns a future for the caller to wait upon. + + This is to be used in the main process. The broker loop runs forever and quits upon asyncio.CancelledError. + The broker is essentially an adaptor from server-streams to either concurrent.futures.Future or asyncio.Future. + For the blocking case (i.e. 
concurrent.futures.Future), the broker uses polling. + + The class needs to be thread-safe for the concurrent.futures.Future case. + + Args: + outbound: a thread-safe blocking queue to which the request should be forwarded to + inbound: a thread-safe blocking queue from which the responses will come + req_key: a function that computes the key of the request, which is used to match the responses + resp_key: a function that computes the key of the response, which is used to match the requests + executor: a thread pool for running any blocking code + """ + logger.trace("init broker") + self._inbound = inbound + self._outbound = outbound + self._req_key = req_key + self._resp_key = resp_key + + # NOTE: this is used for running the polling tasks for results. + # The tasks we submit to the executor (or any executor) should NOT wait indefinitely. + # We must provide it with a way to self-cancelling. + self._executor = executor + + # Used to signal to the tasks on the executor to quit + self._shutdown = False + + self._lock = threading.Lock() # lock for self._keys and self._futures + self._keys: MutableSet[str] = set() + self._futures: Dict[str, Optional[RespT]] = dict() + + self._akeys: MutableSet[str] = set() + self._afutures: Dict[str, asyncio.Future[RespT]] = dict() + + self._task: Optional[asyncio.Task] = None + + def start(self): + logger.trace("starting broker") + self._task = asyncio.create_task(self._loop()) + return + + async def shutdown(self): + if self._task: + self._task.cancel() + with suppress(asyncio.CancelledError): + await self._task + self._task = None + + async def _loop(self): + """The main broker loop. + + The loop listens for any responses and resolves the corresponding futures. + """ + logger.trace("broker started") + try: + async for resp in utils.QueueAgen(self._inbound): + logger.trace("broker got: {}", resp) + key = self._resp_key(resp) + + # if the response is for a concurrent.futures.Future, + # finds/resolves it and return True. 
+ def update(): + with self._lock: + if key in self._futures: + self._futures[key] = resp + return True + # NOTE: the clean up happens at submission time + # See self.submit() + return False + + if await asyncio.to_thread(update): + continue + + # if the previous step didn't find a corresponding future, + # looks for the asyncio.Future instead. + if key in self._afutures: + self._afutures[key].set_result(resp) + + # clean up + self._akeys.remove(key) + del self._afutures[key] + continue + + raise KeyError(f"key not found: {key}") + finally: + logger.trace("broker shutting down") + self._shutdown = True + + async def asubmit(self, req: ReqT) -> asyncio.Future[RespT]: + """Submits a request for an asyncio.Future.""" + key = self._req_key(req) + assert key not in self._keys + + f = self._afutures.get(key, None) + if f is None: + self._afutures[key] = asyncio.Future() + f = self._afutures[key] + self._akeys.add(key) + # TODO: pyright can't figure out that both alternatives in the union type is individualy type-checked + await asyncio.to_thread(self._outbound.put, req) # type: ignore + + return f + + def submit(self, req: ReqT) -> Future[RespT]: + """Submits a request for a concurrent.futures.Future. + + The future may raise CancelledError if the broker is shutting down. 
+ """ + key = self._req_key(req) + assert key not in self._akeys + + def poll(): + with self._lock: + if key not in self._keys: + self._futures[key] = None + self._keys.add(key) + self._outbound.put(req) + + resp = None + while resp is None and not self._shutdown: + while self._futures.get(key, None) is None: + time.sleep(1) + with self._lock: + resp = self._futures.get(key, None) + if resp is not None: + self._keys.remove(key) + del self._futures[key] + + if self._shutdown: + logger.trace("broker polling task shutting down") + raise CancelledError("shutting down") + + assert resp is not None + return resp + + return self._executor.submit(poll) + + +class WorkflowRunFutures: + """A workflow run listener to be used in the main process. + + It is a high-level interface that wraps a RequestResponseBroker. + """ + + def __init__( + self, + *, + executor: ThreadPoolExecutor, + broker: RequestResponseBroker["messages.Message", "messages.Message"], + ): + self._broker = broker + self._executor = executor + + def start(self): + logger.trace("starting main-process workflow run listener") + self._broker.start() + + async def shutdown(self): + logger.trace("shutting down main-process workflow run listener") + await self._broker.shutdown() + logger.trace("bye: main-process workflow run listener") + + def submit(self, req: SubscribeToWorkflowRunsRequest) -> Future[WorkflowRunEvent]: + logger.trace("requesting workflow run result: {}", MessageToDict(req)) + f = self._broker.submit( + messages.Message(_subscribe_to_workflow_run=MessageToDict(req)) + ) + return self._executor.submit(lambda: f.result().workflow_run_event) + + async def asubmit( + self, req: SubscribeToWorkflowRunsRequest + ) -> asyncio.Future[WorkflowRunEvent]: + logger.trace("requesting workflow run result: {}", MessageToDict(req)) + f = await self._broker.asubmit( + messages.Message(_subscribe_to_workflow_run=MessageToDict(req)) + ) + event: asyncio.Future[WorkflowRunEvent] = asyncio.Future() + 
f.add_done_callback(lambda f: event.set_result(f.result().workflow_run_event)) + return event diff --git a/hatchet_sdk/v2/runtime/listeners.py b/hatchet_sdk/v2/runtime/listeners.py new file mode 100644 index 00000000..d7bcadd5 --- /dev/null +++ b/hatchet_sdk/v2/runtime/listeners.py @@ -0,0 +1,277 @@ +import asyncio +from collections.abc import AsyncGenerator, Callable +from contextlib import suppress +from dataclasses import dataclass +from typing import Any, Dict, Generic, Set, TypeVar + +import grpc +from google.protobuf.json_format import MessageToDict +from loguru import logger + +import hatchet_sdk.v2.runtime.connection as connection +import hatchet_sdk.v2.runtime.context as context +import hatchet_sdk.v2.runtime.messages as messages +import hatchet_sdk.v2.runtime.utils as utils +import hatchet_sdk.v2.runtime.worker as worker +from hatchet_sdk.contracts.dispatcher_pb2 import ( + AssignedAction, + StepActionEvent, + SubscribeToWorkflowRunsRequest, + WorkerListenRequest, + WorkflowRunEvent, + WorkflowRunEventType, +) +from hatchet_sdk.contracts.dispatcher_pb2_grpc import DispatcherStub + + +class WorkflowRunEventListener: + """A multiplexing workflow run event listener. It should only be used in the sidecar process.""" + + @dataclass + class Sub: + """A subscription for a workflow run. This is only to be used in the sidecar process.""" + + id: str # TODO: the id is not used right now since one can only subscribe a run_id once. 
+ run_id: str + future: asyncio.Future[WorkflowRunEvent] + + def __hash__(self): + return hash(self.id) + + def __init__(self): + logger.trace("init workflow run event listener") + + # the set of active subscriptions + self._subs: Set[WorkflowRunEventListener.Sub] = set() + + # counter used for generating subscription ids + # not thread safe + self._counter = 0 + + # index from run id to subscriptions + self._by_run_id: Dict[str, WorkflowRunEventListener.Sub] = dict() + + # queue used for iterating requests + # must be created inside the loop + self._q_request: asyncio.Queue[SubscribeToWorkflowRunsRequest] = asyncio.Queue() + + self._task = None + + def start(self): + logger.trace("starting workflow run event listener") + self._task = asyncio.create_task( + self._loop(), name="workflow run event listener loop" + ) + + async def shutdown(self): + logger.trace("shutting down workflow run event listener") + if self._task: + self._task.cancel() + with suppress(asyncio.CancelledError): + await self._task + self._task = None + + async def _loop(self): + """The main listener loop. + + The loop forwards subscription requests over the grpc stream to the server while giving + out a future to the caller. Then it listens for workflow run events and resolves the futures. 
+ """ + logger.trace("started workflow run event listener") + try: + agen = utils.ForeverAgen(self._events, exceptions=(grpc.aio.AioRpcError,)) + async for event in agen: + if isinstance(event, grpc.aio.AioRpcError): + logger.trace("encountered error, retrying: {}", event) + await self._resubscribe() + + else: + self._by_run_id[event.workflowRunId].future.set_result(event) + self._unsubscribe(event.workflowRunId) + finally: + logger.trace("bye: workflow run event listner shuts down") + + async def _events(self) -> AsyncGenerator[WorkflowRunEvent]: + """The async generator backed by server-streamed WorkflowRunEvents.""" + # keep trying until asyncio.CancelledError is raised into this coroutine + # TODO: handle retry, backoff, etc. + stub = DispatcherStub(channel=connection.ensure_background_achannel()) + requests = utils.QueueAgen(self._q_request) + + stream: grpc.aio.StreamStreamCall[ + SubscribeToWorkflowRunsRequest, WorkflowRunEvent + ] = stub.SubscribeToWorkflowRuns( + requests, + metadata=context.ensure_background_context().client._grpc_metadata(), + ) + logger.trace("stream established") + async for event in stream: + logger.trace("received workflow run event: {}", MessageToDict(event)) + assert ( + event.eventType == WorkflowRunEventType.WORKFLOW_RUN_EVENT_TYPE_FINISHED + ) + yield event + + async def _resubscribe(self): + logger.trace("re-subscribing all") + async with asyncio.TaskGroup() as tg: + for id in self._by_run_id.keys(): + tg.create_task( + self._q_request.put( + SubscribeToWorkflowRunsRequest(workflowRunId=id) + ) + ) + + async def subscribe(self, run_id: str) -> "WorkflowRunEventListener.Sub": + if run_id in self._by_run_id: + return self._by_run_id[run_id] + logger.trace("subscribing: {}", run_id) + await self._q_request.put(SubscribeToWorkflowRunsRequest(workflowRunId=run_id)) + sub = self.Sub(id=str(self._counter), run_id=run_id, future=asyncio.Future()) + self._subs.add(sub) + self._by_run_id[run_id] = sub + self._counter += 1 + return sub + 
+ def _unsubscribe(self, run_id: str): + logger.trace("unsubscribing: {}", run_id) + sub = self._by_run_id.get(run_id, None) + if sub is None: + return + self._subs.remove(sub) + del self._by_run_id[run_id] + + +# TODO: use better generics with Python >= 3.12 +T = TypeVar("T") + + +class AssignedActionListner(Generic[T]): + """An assigned action listener that runs a callback on every server-streamed assigned actions.""" + + def __init__(self, *, worker: "worker.Worker", interrupt: asyncio.Queue[T]): + logger.trace("init assigned action listener") + + # used to get the worker id, which is not immediately available. + self._worker = worker + + # used to interrupt the action listener + self._interrupt = interrupt + + self._task = None + + def start( + self, async_on: Callable[[AssignedAction | grpc.aio.AioRpcError | T], Any] + ): + """Starts the assigned action listener loop. + + Args: + async_on: the callback to be invoked when an assigned action is received. + """ + logger.trace("starting assigned action listener") + self._task = asyncio.create_task(self._loop(async_on)) + + async def shutdown(self): + logger.trace("shutting down assigned action listener") + if self._task: + self._task.cancel() + with suppress(asyncio.CancelledError): + await self._task + self._task = None + + async def _action_stream(self) -> AsyncGenerator[AssignedAction]: + """The async generator backed by the server-streamed assigend actions.""" + stub = DispatcherStub(connection.ensure_background_achannel()) + proto = WorkerListenRequest(workerId=self._worker.id) + resp = stub.ListenV2( + proto, + metadata=context.ensure_background_context(None).client._grpc_metadata(), + ) + logger.trace("connection established") + async for action in resp: + logger.trace("assigned action: {}", MessageToDict(action)) + yield action + + async def _listen( + self, + ) -> AsyncGenerator[AssignedAction | grpc.aio.AioRpcError | T]: + """The wrapped assigned action async generator that handles retries, etc.""" + + 
def agen_factory(): + return utils.InterruptableAgen( + self._action_stream(), interrupt=self._interrupt, timeout=5 + ) + + agen = utils.ForeverAgen(agen_factory, exceptions=(grpc.aio.AioRpcError,)) + async for action in agen: + if isinstance(action, grpc.aio.AioRpcError): + logger.trace("encountered error, retrying: {}", action) + yield action + else: + yield action + + async def _loop( + self, async_on: Callable[[AssignedAction | grpc.aio.AioRpcError | T], Any] + ): + """The main assigned action listener loop.""" + try: + logger.trace("started assigned action listener") + async for event in self._listen(): + await async_on(event) + finally: + logger.trace("bye: assigned action listener") + + +class StepEventListener: + """A step event listener that forwards the step event from the main process to the server.""" + + def __init__(self, *, inbound: asyncio.Queue["messages.Message"]): + logger.trace("init step event listener") + self._inbound = inbound + self._stub = DispatcherStub(connection.ensure_background_channel()) + self._task = None + + def start(self): + logger.trace("starting step event listener") + self._task = asyncio.create_task(self._listen()) + + async def shutdown(self): + logger.trace("shutting down step event listener") + if self._task: + self._task.cancel() + with suppress(asyncio.CancelledError): + await self._task + self._task = None + + async def _message_stream(self) -> AsyncGenerator["messages.Message"]: + while True: + msg: "messages.Message" = await self._inbound.get() + assert msg.kind in [messages.MessageKind.STEP_EVENT] + logger.trace("event: {}", msg) + yield msg + + async def _listen(self): + """The main listener loop.""" + logger.trace("step event listener started") + try: + async for msg in self._message_stream(): + match msg.kind: + case messages.MessageKind.STEP_EVENT: + await self._on_step_event(msg.step_event) + case _: + raise NotImplementedError(msg.kind) + except Exception as e: + logger.exception(e) + raise + finally: + 
logger.debug("bye: step event listener") + + async def _on_step_event(self, e: StepActionEvent): + # TODO: need retry + logger.trace("emit step action: {}", MessageToDict(e)) + resp = await asyncio.to_thread( + self._stub.SendStepActionEvent, + e, + metadata=context.ensure_background_context().client._grpc_metadata(), + ) + logger.trace("resp: {}", MessageToDict(resp)) diff --git a/hatchet_sdk/v2/runtime/logging.py b/hatchet_sdk/v2/runtime/logging.py new file mode 100644 index 00000000..d76da215 --- /dev/null +++ b/hatchet_sdk/v2/runtime/logging.py @@ -0,0 +1,30 @@ +import asyncio +import os +import threading + +import hatchet_sdk.logger as v1 + + +def _loopid(): + try: + return id(asyncio.get_running_loop()) + except: + return -1 + + +class HatchetLogger: + def log(self, *args, **kwargs): + v1.logger.log(*args, **kwargs) + + def debug(self, *args, **kwargs): + v1.logger.debug(*args, **kwargs) + + def info(self, *args, **kwargs): + pid = str(os.getpid()) + tid = str(threading.get_ident()) + loopid = str(_loopid()) + v1.logger.info(f"{pid}, {tid}, {loopid}") + v1.logger.info(*args, **kwargs) + + +logger = HatchetLogger() diff --git a/hatchet_sdk/v2/runtime/messages.py b/hatchet_sdk/v2/runtime/messages.py new file mode 100644 index 00000000..5aa4b8e8 --- /dev/null +++ b/hatchet_sdk/v2/runtime/messages.py @@ -0,0 +1,82 @@ +from dataclasses import dataclass +from enum import Enum +from typing import Any, Dict, Optional + +from google.protobuf.json_format import ParseDict + +from hatchet_sdk.contracts.dispatcher_pb2 import ( + GROUP_KEY_EVENT_TYPE_COMPLETED, + GROUP_KEY_EVENT_TYPE_FAILED, + GROUP_KEY_EVENT_TYPE_STARTED, + STEP_EVENT_TYPE_COMPLETED, + STEP_EVENT_TYPE_FAILED, + STEP_EVENT_TYPE_STARTED, + ActionType, + AssignedAction, + GroupKeyActionEventType, + StepActionEvent, + StepActionEventType, + SubscribeToWorkflowRunsRequest, + WorkflowRunEvent, + WorkflowRunEventType, +) +from hatchet_sdk.worker.action_listener_process import Action + + +class MessageKind(Enum): 
+ UNKNOWN = 0 + ACTION = 1 + STEP_EVENT = 2 + WORKFLOW_RUN_EVENT = 3 + SUBSCRIBE_TO_WORKFLOW_RUN = 4 + WORKER_ID = 5 + + +@dataclass +class Message: + """The runtime IPC message format. Note that it has to be trivially pickle-able.""" + + _action: Optional[Dict] = None + _step_event: Optional[Dict] = None + _workflow_run_event: Optional[Dict] = None + _subscribe_to_workflow_run: Optional[Dict] = None + + worker_id: Optional[str] = None + + @property + def kind(self) -> MessageKind: + if self._action is not None: + return MessageKind.ACTION + if self._step_event is not None: + return MessageKind.STEP_EVENT + if self._workflow_run_event is not None: + return MessageKind.WORKFLOW_RUN_EVENT + if self._subscribe_to_workflow_run is not None: + return MessageKind.SUBSCRIBE_TO_WORKFLOW_RUN + if self.worker_id: + return MessageKind.WORKER_ID + return MessageKind.UNKNOWN + + @property + def action(self) -> AssignedAction: + assert self._action is not None + ret = AssignedAction() + return ParseDict(self._action, ret) + + @property + def step_event(self) -> StepActionEvent: + assert self._step_event is not None + ret = StepActionEvent() + return ParseDict(self._step_event, ret) + + @property + def workflow_run_event(self) -> WorkflowRunEvent: + assert self._workflow_run_event is not None + ret = WorkflowRunEvent() + return ParseDict(self._workflow_run_event, ret) + + @property + def subscribe_to_workflow_run(self) -> SubscribeToWorkflowRunsRequest: + assert self._subscribe_to_workflow_run is not None + ret = SubscribeToWorkflowRunsRequest() + return ParseDict(self._subscribe_to_workflow_run, ret) diff --git a/hatchet_sdk/v2/runtime/registry.py b/hatchet_sdk/v2/runtime/registry.py new file mode 100644 index 00000000..724fe68e --- /dev/null +++ b/hatchet_sdk/v2/runtime/registry.py @@ -0,0 +1,32 @@ +import sys +from typing import Dict + +from loguru import logger + +import hatchet_sdk.v2.callable as callable +import hatchet_sdk.v2.hatchet as hatchet + + +class ActionRegistry: + 
"""A registry from action names (e.g. 'namespace:func') to Hatchet's callables. + + This is intended to be used per Hatchet client instance. + """ + + def __init__(self): + self.registry: Dict[str, "callable.HatchetCallableBase"] = dict() + + def add(self, key: str, callable: "callable.HatchetCallableBase"): + if key in self.registry: + raise KeyError(f"duplicated Hatchet callable: {key}") + self.registry[key] = callable + + def register_all(self, client: "hatchet.Hatchet"): + for callable in self.registry.values(): + proto = callable._to_workflow_proto() + try: + client.admin.put_workflow(proto.name, proto) + except Exception as e: + logger.error("failed to register workflow: {}", proto.name) + logger.exception(e) + sys.exit(1) diff --git a/hatchet_sdk/v2/runtime/runner.py b/hatchet_sdk/v2/runtime/runner.py new file mode 100644 index 00000000..4beaa940 --- /dev/null +++ b/hatchet_sdk/v2/runtime/runner.py @@ -0,0 +1,191 @@ +import asyncio +import multiprocessing.queues as mpq +import queue +import time +import traceback +from contextlib import suppress +from typing import Dict, Optional, Tuple, TypeAlias, TypeVar + +from google.protobuf.json_format import MessageToDict +from google.protobuf.timestamp_pb2 import Timestamp +from loguru import logger + +import hatchet_sdk.v2.callable as callable +import hatchet_sdk.v2.runtime.messages as messages +import hatchet_sdk.v2.runtime.registry as registry +import hatchet_sdk.v2.runtime.utils as utils +from hatchet_sdk.contracts.dispatcher_pb2 import ( + ActionType, + AssignedAction, + StepActionEvent, + StepActionEventType, +) + + +def _timestamp(): + ns = time.time_ns() + return Timestamp(seconds=int(ns // 1e9), nanos=int(ns % 1e9)) + + +def _format_exc(e: Exception): + trace = "".join(traceback.format_exception(e)) + return "\n".join([str(e), trace]) + + +async def _invoke( + action: AssignedAction, registry: Dict[str, "callable.HatchetCallableBase"] +) -> Tuple[str, None] | Tuple[None, Exception]: + key = action.actionId + 
# TODO: handle cases when it's not registered more gracefully + fn: "callable.HatchetCallableBase" = registry[key] + logger.trace("invoking: {}", repr(fn)) + try: + if isinstance(fn, callable.HatchetCallable): + logger.trace("invoking {} on a separate thread", fn._hatchet.name) + return await asyncio.to_thread(fn._run, action), None + elif isinstance(fn, callable.HatchetAwaitable): + return await fn._run(action), None + else: + raise NotImplementedError(f"unsupported callable case: {type(fn)}") + except asyncio.CancelledError: + raise + except Exception as e: + logger.exception(e) + return None, e + + +# TODO: Use better generics for Python >= 3.12 +T = TypeVar("T") +_ThreadSafeQueue: TypeAlias = queue.Queue[T] | mpq.Queue[T] + + +class RunnerLoop: + def __init__( + self, + *, + reg: "registry.ActionRegistry", + inbound: _ThreadSafeQueue["messages.Message"], # inbound queue, not owned + outbound: _ThreadSafeQueue["messages.Message"], # outbound queue, not owned + ): + logger.trace("init runner loop") + self.worker_id: Optional[str] = None + + self._registry: Dict[str, "callable.HatchetCallableBase"] = reg.registry + self._inbound = inbound + self._outbound = outbound + self._loop_task: Optional[asyncio.Task] = None + + # a dict from StepRunId to its tasks + self._tasks: Dict[str, asyncio.Task] = dict() + + def start(self): + logger.trace("starting runner loop") + self._loop_task = asyncio.create_task(self._loop(), name="runner loop") + + async def shutdown(self): + logger.trace("shutting down runner loop") + # finishing all the tasks + t = asyncio.gather(*self._tasks.values()) + await t + + if self._loop_task is not None: + self._loop_task.cancel() + with suppress(asyncio.CancelledError): + await self._loop_task + self._loop_task = None + logger.trace("bye: runner loop") + + async def _loop(self): + """The main runner loop. + + It listens for actions from the sidecar process and executes them. 
+ """ + async for msg in utils.QueueAgen(self._inbound): + logger.trace("received: {}", msg) + assert msg.kind == messages.MessageKind.ACTION + match msg.action.actionType: + case ActionType.START_STEP_RUN: + self._on_run(msg) + case ActionType.CANCEL_STEP_RUN: + self._on_cancel(msg) + case _: + raise NotImplementedError(msg) + + def _on_run(self, msg: "messages.Message"): + async def task(): + logger.trace("running {}", msg.action.stepRunId) + try: + await self._emit_started(msg) + result, e = await _invoke(msg.action, self._registry) + if e is None: + assert result is not None + await self._emit_finished(msg, result) + else: + assert result is None + await self._emit_failed(msg, _format_exc(e)) + finally: + del self._tasks[msg.action.stepRunId] + + self._tasks[msg.action.stepRunId] = asyncio.create_task( + task(), name=msg.action.stepRunId + ) + + def _step_event(self, msg: "messages.Message", **kwargs) -> StepActionEvent: + """Makes a StepActionEvent proto.""" + base = StepActionEvent( + jobId=msg.action.jobId, + jobRunId=msg.action.jobRunId, + stepId=msg.action.stepId, + stepRunId=msg.action.stepRunId, + actionId=msg.action.actionId, + eventTimestamp=_timestamp(), + ) + base.MergeFrom(StepActionEvent(**kwargs)) + return base + + def _on_cancel(self, msg: "messages.Message"): + # TODO + pass + + async def _emit_started(self, msg: "messages.Message"): + await self._send( + messages.Message( + _step_event=MessageToDict( + self._step_event( + msg, eventType=StepActionEventType.STEP_EVENT_TYPE_STARTED + ) + ) + ) + ) + + async def _emit_finished(self, msg: "messages.Message", payload: str): + await self._send( + messages.Message( + _step_event=MessageToDict( + self._step_event( + msg, + eventType=StepActionEventType.STEP_EVENT_TYPE_COMPLETED, + eventPayload=payload, + ) + ) + ) + ) + + async def _emit_failed(self, msg: "messages.Message", payload: str): + await self._send( + messages.Message( + _step_event=MessageToDict( + self._step_event( + msg, + 
eventType=StepActionEventType.STEP_EVENT_TYPE_FAILED, + eventPayload=payload, + ) + ) + ) + ) + + async def _send(self, msg: "messages.Message"): + """Sends a message to the sidecar process.""" + logger.trace("send: {}", msg) + # TODO: pyright could not figure this out + await asyncio.to_thread(self._outbound.put, msg) # type: ignore diff --git a/hatchet_sdk/v2/runtime/runtime.py b/hatchet_sdk/v2/runtime/runtime.py new file mode 100644 index 00000000..72470b71 --- /dev/null +++ b/hatchet_sdk/v2/runtime/runtime.py @@ -0,0 +1,130 @@ +import asyncio +import multiprocessing as mp +import queue +import threading +from concurrent.futures import CancelledError +from contextlib import suppress + +from loguru import logger + +import hatchet_sdk.v2.hatchet as hatchet +import hatchet_sdk.v2.runtime.future as future +import hatchet_sdk.v2.runtime.messages as messages +import hatchet_sdk.v2.runtime.runner as runner +import hatchet_sdk.v2.runtime.utils as utils +import hatchet_sdk.v2.runtime.worker as worker + + +class Runtime: + """The Hatchet runtime. + + The runtime is managine the runner on the main process, the run event listener on the main process, + and the worker on the sidecar process, together with the queues among them. A Hatchet client should + only contain one Runtime object. The behavior will be undefined if there are multiple Runtime per + Hatchet client. + """ + + # TODO: rename WorkerOptions to RuntimeOptions. 
+ def __init__(self, *, client: "hatchet.Hatchet", options: "worker.WorkerOptions"): + logger.trace("init runtime") + + self._client = client + self._executor = client.executor + + # the main queues between the sidecar process and the main process + self._to_worker: mp.Queue["messages.Message"] = mp.Queue() + self._from_worker: mp.Queue["messages.Message"] = mp.Queue() + + # the queue to the runner on the main process + self._to_runner = queue.Queue() + + # the queue to the workflow run event listener on the main process + self._to_wfr_futures = queue.Queue() + + # the worker on the sidecar process + self._worker = worker.WorkerProcess( + config=client.config, + inbound=self._to_worker, + outbound=self._from_worker, + options=options, + ) + self.worker_id = None + + # the runner on the main process + self._runner = runner.RunnerLoop( + reg=client.registry, + inbound=self._to_runner, + outbound=self._to_worker, + ) + + # the workflow run event listener on the main process + self._wfr_futures = future.WorkflowRunFutures( + executor=self._executor, + broker=future.RequestResponseBroker( + inbound=self._to_wfr_futures, + outbound=self._to_worker, + req_key=lambda msg: msg.subscribe_to_workflow_run.workflowRunId, + resp_key=lambda msg: msg.workflow_run_event.workflowRunId, + executor=self._executor, + ), + ) + + # the shutdown signal + self._shutdown = threading.Event() + self._loop_task = None + + async def _loop(self): + async for msg in utils.QueueAgen(self._from_worker): + match msg.kind: + case messages.MessageKind.ACTION: + await asyncio.to_thread(self._to_runner.put, msg) + case messages.MessageKind.WORKFLOW_RUN_EVENT: + await asyncio.to_thread(self._to_wfr_futures.put, msg) + case messages.MessageKind.WORKER_ID: + self._runner.worker_id = msg.worker_id + self.worker_id = msg.worker_id + case _: + raise NotImplementedError + if self._shutdown.is_set(): + break + + logger.trace("bye: runtime") + + async def start(self): + logger.debug("starting runtime") + + # 
NOTE: the order matters, we should start things in topological order + self._runner.start() + self._wfr_futures.start() + + # schedule the runtime on a separate thread + self._loop_task = self._executor.submit(asyncio.run, self._loop()) + + self._worker.start() + while self.worker_id is None: + await asyncio.sleep(1) + + logger.debug("runtime started") + return self.worker_id + + async def shutdown(self): + logger.trace("shutting down runtime") + + # NOTE: the order matters, we should shut things down in topological order + self._worker.shutdown() + self._from_worker.close() + self._from_worker.join_thread() + + await self._runner.shutdown() + self._to_worker.close() + self._to_worker.join_thread() + + await self._wfr_futures.shutdown() + + self._shutdown.set() + + if self._loop_task is not None: + with suppress(CancelledError): + self._loop_task.result(timeout=10) + + logger.debug("bye: runtime") diff --git a/hatchet_sdk/v2/runtime/utils.py b/hatchet_sdk/v2/runtime/utils.py new file mode 100644 index 00000000..9112a0e6 --- /dev/null +++ b/hatchet_sdk/v2/runtime/utils.py @@ -0,0 +1,102 @@ +import asyncio +import multiprocessing.queues as mpq +import queue +from collections.abc import AsyncGenerator, Callable +from concurrent.futures import Future, ThreadPoolExecutor +from contextlib import suppress +from typing import Tuple, Type, TypeVar + +T = TypeVar("T") +I = TypeVar("I") +R = TypeVar("R") + + +async def InterruptableAgen( + agen: AsyncGenerator[T], + interrupt: asyncio.Queue[I], + timeout: float, +) -> AsyncGenerator[T | I]: + queue: asyncio.Queue[T | StopAsyncIteration] = asyncio.Queue() + + async def producer(): + async for item in agen: + await queue.put(item) + await queue.put(StopAsyncIteration()) + + producer_task = None + try: + producer_task = asyncio.create_task(producer()) + while True: + with suppress(asyncio.TimeoutError): + item = await asyncio.wait_for(queue.get(), timeout=timeout) + # it is not timeout if we reach this line + if isinstance(item, 
StopAsyncIteration): + break + else: + yield item + + with suppress(asyncio.QueueEmpty): + v = interrupt.get_nowait() + # we are interrupted if we reach this line + yield v + break + + finally: + if producer_task: + producer_task.cancel() + await producer_task + + +E = TypeVar("E") + + +async def ForeverAgen( + agen_factory: Callable[[], AsyncGenerator[T]], exceptions: Tuple[Type[E]] +) -> AsyncGenerator[T | E]: + """Run a async generator forever until its cancelled. + + Args: + agen_factory: a callable that returns the async generator of type T + exceptions: a tuple of exceptions that should be suppressed and yielded. + Exceptions not listed here will be re-raised. + + Returns: + An async generator that yields T or yields the suppressed exceptions. + """ + while True: + agen = agen_factory() + try: + async for item in agen: + yield item + except Exception as e: + if isinstance(e, exceptions): + yield e + else: + raise + + +async def QueueAgen( + inbound: queue.Queue[T] | asyncio.Queue[T] | mpq.Queue[T], +) -> AsyncGenerator[T]: + if isinstance(inbound, asyncio.Queue): + while True: + yield await inbound.get() + inbound.task_done() + elif isinstance(inbound, queue.Queue): + while True: + yield await asyncio.to_thread(inbound.get) + inbound.task_done() + elif isinstance(inbound, mpq.Queue): + while True: + yield await asyncio.to_thread(inbound.get) + else: + raise TypeError(f"unsupported queue type: {type(inbound)}") + + +def MapFuture( + fn: Callable[[T], R], fut: Future[T], pool: ThreadPoolExecutor +) -> Future[R]: + def task(fn: Callable[[T], R], fut: Future[T]) -> R: + return fn(fut.result()) + + return pool.submit(task, fn, fut) diff --git a/hatchet_sdk/v2/runtime/worker.py b/hatchet_sdk/v2/runtime/worker.py new file mode 100644 index 00000000..b00365ea --- /dev/null +++ b/hatchet_sdk/v2/runtime/worker.py @@ -0,0 +1,344 @@ +import asyncio +import multiprocessing as mp +import multiprocessing.queues as mpq +import multiprocessing.synchronize as mps +import sys 
class HeartBeater:
    """Periodically sends heartbeats for a worker to the dispatcher.

    Counters exposed for observability:
        last_heartbeat: unix epoch (seconds) of the most recent heartbeat
            attempt, or -1 before the first one.
        missed: number of times the gap between two consecutive attempts
            exceeded the configured heartbeat period.
        error: number of heartbeat RPCs that failed with ``grpc.RpcError``.
    """

    def __init__(self, worker: "Worker"):
        logger.trace("init heartbeater")
        self._worker = worker  # used to access worker id
        self._stub = DispatcherStub(connection.ensure_background_channel())

        self.last_heartbeat: int = -1  # unix epoch in seconds
        self.missed = 0
        self.error = 0

        self._task = None

    async def start(self):
        """Start the heartbeat loop and wait for the first heartbeat attempt.

        Raises:
            Exception: whatever stopped the heartbeat loop before its first
                attempt completed, instead of waiting forever.
        """
        logger.trace("starting heart beater")
        self._task = asyncio.create_task(self._heartbeat())
        while self.last_heartbeat < 0:
            if self._task.done():
                # The loop never returns normally; result() re-raises the
                # exception (or CancelledError) that ended it.
                self._task.result()
                raise RuntimeError("heart beater stopped before the first heartbeat")
            await asyncio.sleep(1)

    async def shutdown(self):
        """Cancel the heartbeat loop and wait for it to finish."""
        logger.trace("shutting down heart beater")
        if self._task:
            self._task.cancel()
            with suppress(asyncio.CancelledError):
                await self._task
            self._task = None

    async def _heartbeat(self):
        """The main heart beater loop."""
        try:
            while True:
                now = int(time.time())
                proto = HeartbeatRequest(
                    workerId=self._worker.id,
                    heartbeatAt=timestamp_pb2.Timestamp(seconds=now),  # TODO
                )
                try:
                    # NOTE(review): if the background channel is a synchronous
                    # grpc channel this call blocks the event loop for up to
                    # 5 seconds -- confirm against connection.py.
                    _ = self._stub.Heartbeat(
                        proto,
                        timeout=5,
                        metadata=context.ensure_background_context().client._grpc_metadata(),
                    )
                    logger.debug("heartbeat")
                except grpc.RpcError as e:
                    # TODO
                    logger.exception(e)
                    self.error += 1

                if self.last_heartbeat >= 0:
                    gap = now - self.last_heartbeat
                    if gap > self._worker.options.heartbeat:
                        self.missed += 1
                # Always advance the reference point. Otherwise the gap is
                # measured from the very first heartbeat and every beat after
                # the first period is counted as missed.
                self.last_heartbeat = now
                await asyncio.sleep(self._worker.options.heartbeat)
        except Exception as e:
            logger.exception(e)
            raise

        finally:
            logger.debug("bye")
async def shutdown(self):
    """Stop the main loop, then the heart beater and every listener.

    All components are always asked to shut down and are all awaited, even
    when the main loop has already crashed or one of them fails.

    Raises:
        BaseException: the first error raised by a component's shutdown,
            re-raised only after every component has finished.
    """
    logger.trace("shutting down worker {}", self.id)

    if self._main_loop_task:
        self._main_loop_task.cancel()
        try:
            await self._main_loop_task
        except asyncio.CancelledError:
            pass
        except Exception as e:
            # The main loop had already failed (``_loop`` logged it when it
            # happened). Do not let that old error abort the shutdown of the
            # remaining components.
            logger.debug("main loop had already failed: {!r}", e)
        self._main_loop_task = None

    # return_exceptions=True makes gather wait for *all* shutdowns instead of
    # returning on the first failure while the others are still running.
    results = await asyncio.gather(
        self._heartbeater.shutdown(),
        self._event_listner.shutdown(),
        self._action_listener.shutdown(),
        self._workflow_run_event_listener.shutdown(),
        return_exceptions=True,
    )
    errors = [r for r in results if isinstance(r, BaseException)]
    for err in errors:
        logger.error("error while shutting down worker {}: {!r}", self.id, err)
    if errors:
        raise errors[0]
    logger.debug("bye: worker {}", self.id)
async def _on_workflow_run_subscription(self, msg: "messages.Message"):
    """Subscribe to a workflow run and forward its result to the main process.

    Args:
        msg: a SUBSCRIBE_TO_WORKFLOW_RUN message; its
            ``subscribe_to_workflow_run.workflowRunId`` names the run.
    """

    def callback(f: asyncio.Future[WorkflowRunEvent]):
        # Done-callbacks also fire for cancelled and failed futures, where
        # ``f.result()`` would raise inside the event-loop callback. Only a
        # successfully resolved future carries an event to forward.
        if f.cancelled():
            logger.trace("workflow run event future cancelled")
            return
        error = f.exception()
        if error is not None:
            logger.error("workflow run event future failed: {!r}", error)
            return
        logger.trace("workflow run event future resolved")
        self._outbound.put(
            messages.Message(_workflow_run_event=MessageToDict(f.result()))
        )

    sub = await self._workflow_run_event_listener.subscribe(
        msg.subscribe_to_workflow_run.workflowRunId
    )
    sub.future.add_done_callback(callback)
class WorkerProcess:
    """A wrapper to control the sidecar worker process."""

    def __init__(
        self,
        *,
        config: "config.ClientConfig",
        options: WorkerOptions,
        inbound: mpq.Queue["messages.Message"],
        outbound: mpq.Queue["messages.Message"],
    ):
        """Prepare (but do not start) the sidecar process.

        Args:
            config: client configuration handed to the child process.
            options: worker options handed to the child process.
            inbound: queue carrying messages from the main process to the child.
            outbound: queue carrying messages from the child to the main process.
        """
        self._to_worker = inbound
        self._shutdown_ev = mp.Event()
        self.proc = mp.Process(
            target=_worker_process,
            kwargs={
                "config": config,
                "options": options,
                "inbound": inbound,
                "outbound": outbound,
                "shutdown": self._shutdown_ev,
            },
        )

    def start(self):
        """Start the sidecar process."""
        logger.debug("starting worker process")
        self.proc.start()

    def shutdown(self, timeout: float = 10):
        """Signal the sidecar process to stop and reap it.

        Args:
            timeout: seconds to wait for a graceful exit before the process
                is terminated.
        """
        self._shutdown_ev.set()
        # ``pid`` stays None until start() was called; joining a process that
        # was never started is an error.
        if self.proc.pid is not None:
            self.proc.join(timeout)
            if self.proc.is_alive():
                # Without this a stuck child is never reaped, and because it
                # is not a daemon, multiprocessing would block on it when the
                # parent interpreter exits.
                logger.warning(
                    "worker process did not exit within {}s, terminating", timeout
                )
                self.proc.terminate()
                self.proc.join()
        logger.debug("worker process shuts down")
"is_coroutine") and action_func.is_coroutine @@ -326,32 +330,32 @@ async def handle_start_step_run(self, action: Action): # Find the corresponding action function from the registry action_func = self.action_registry.get(action_name) - context: Context | DurableContext - - if hasattr(action_func, "durable") and action_func.durable: - context = DurableContext( - action, - self.dispatcher_client, - self.admin_client, - self.client.event, - self.client.rest, - self.client.workflow_listener, - self.workflow_run_event_listener, - self.worker_context, - self.client.config.namespace, - ) - else: - context = Context( - action, - self.dispatcher_client, - self.admin_client, - self.client.event, - self.client.rest, - self.client.workflow_listener, - self.workflow_run_event_listener, - self.worker_context, - self.client.config.namespace, - ) + context: Context #| DurableContext + + # if hasattr(action_func, "durable") and action_func.durable: + # context = DurableContext( + # action, + # self.dispatcher_client, + # self.admin_client, + # self.client.event, + # self.client.rest, + # self.client.workflow_listener, + # self.workflow_run_event_listener, + # self.worker_context, + # self.client.config.namespace, + # ) + # else: + context = Context( + action, + self.dispatcher_client, + self.admin_client, + self.client.event, + self.client.rest, + self.client.workflow_listener, + self.workflow_run_event_listener, + self.worker_context, + self.client.config.namespace, + ) self.contexts[action.step_run_id] = context diff --git a/hatchet_sdk/worker/worker.py b/hatchet_sdk/worker/worker.py index 315f2f4a..36c6b07d 100644 --- a/hatchet_sdk/worker/worker.py +++ b/hatchet_sdk/worker/worker.py @@ -13,7 +13,8 @@ from hatchet_sdk.contracts.workflows_pb2 import CreateWorkflowVersionOpts from hatchet_sdk.loader import ClientConfig from hatchet_sdk.logger import logger -from hatchet_sdk.v2.callable import HatchetCallable + +# from hatchet_sdk.v2.callable import HatchetCallable from 
hatchet_sdk.worker.action_listener_process import worker_action_listener_process from hatchet_sdk.worker.runner.run_loop_manager import WorkerActionRunLoopManager from hatchet_sdk.workflow import WorkflowMeta @@ -62,9 +63,10 @@ def __post_init__(self): self.name = self.client.config.namespace + self.name self._setup_signal_handlers() - def register_function(self, action: str, func: HatchetCallable): + def register_function(self, action: str, func): self.action_registry[action] = func + # TODO: why do it on the worker, it seems unrelated. we should do that on the registry def register_workflow_from_opts(self, name: str, opts: CreateWorkflowVersionOpts): try: self.client.admin.put_workflow(opts.name, opts) @@ -118,7 +120,7 @@ def setup_loop(self, loop: asyncio.AbstractEventLoop = None): def start(self, options: WorkerStartOptions = WorkerStartOptions()): created_loop = self.setup_loop(options.loop) - f = asyncio.run_coroutine_threadsafe( + self.result_f = asyncio.run_coroutine_threadsafe( self.async_start(options, _from_start=True), self.loop ) # start the loop and wait until its closed @@ -127,7 +129,7 @@ def start(self, options: WorkerStartOptions = WorkerStartOptions()): if self.handle_kill: sys.exit(0) - return f + return self.result_f ## Start methods async def async_start( @@ -263,7 +265,7 @@ async def exit_gracefully(self): self.action_listener_process.kill() await self.close() - + # self.result_f.set_result("") if self.loop: self.loop.stop() @@ -285,20 +287,19 @@ def exit_forcefully(self): ) # Exit immediately TODO - should we exit with 1 here, there may be other workers to cleanup -def register_on_worker(callable: HatchetCallable, worker: Worker): - worker.register_function(callable.get_action_name(), callable) +def register_on_worker(callable, worker: Worker): + worker.register_function(callable.action_name, callable) - if callable.function_on_failure is not None: - worker.register_function( - callable.function_on_failure.get_action_name(), 
callable.function_on_failure - ) - - if callable.function_concurrency is not None: - worker.register_function( - callable.function_concurrency.get_action_name(), - callable.function_concurrency, - ) + # if callable.function_on_failure is not None: + # worker.register_function( + # callable.function_on_failure.action_name, callable.function_on_failure + # ) - opts = callable.to_workflow_opts() + # if callable.function_concurrency is not None: + # worker.register_function( + # callable.function_concurrency.action_name, + # callable.function_concurrency, + # ) + opts = callable._to_workflow_proto() worker.register_workflow_from_opts(opts.name, opts) diff --git a/pyproject.toml b/pyproject.toml index 4e94f09c..eb96e1bb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -47,6 +47,9 @@ known_third_party = [ "pyyaml", "urllib3", ] +skip = [ + "hatchet_sdk/contracts", +] [tool.poetry.scripts] api = "examples.api.api:main" diff --git a/tests/v2/__init__.py b/tests/v2/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/v2/test_broker.py b/tests/v2/test_broker.py new file mode 100644 index 00000000..5e61d593 --- /dev/null +++ b/tests/v2/test_broker.py @@ -0,0 +1,61 @@ +import asyncio +import logging +import sys +import queue +import threading + +from concurrent.futures import ThreadPoolExecutor + +# import dotenv +import pytest +from loguru import logger + +# from hatchet_sdk.v2.hatchet import Hatchet +from hatchet_sdk.v2.runtime.broker import QueueToFutureBroker + +logger.remove() +logger.add(sys.stdout, level="TRACE") + +# dotenv.load_dotenv() + +# hatchet = Hatchet(debug=True) + +logging.getLogger("asyncio").setLevel(logging.DEBUG) + + +to_broker = queue.Queue() +to_server = queue.Queue() +exec = ThreadPoolExecutor() +broker = QueueToFutureBroker( + inbound=to_broker, + outbound=to_server, + req_key=lambda x: x, + resp_key=lambda x: x, + executor=exec, +) + + +def echo(p: queue.Queue, q: queue.Queue): + while True: + item = p.get() + 
logger.trace("echo {}", item) + q.put(item) + + +echo_f = exec.submit(echo, to_server, to_broker) + + +# def test_broker(): +# fut = exec.submit(asyncio.run, broker.loop()) +# f = broker.submit(1) +# print(f.result()) +# fut.cancel() + + +@pytest.mark.asyncio +async def test_broker_async(): + task = asyncio.create_task(broker.loop()) + f = await broker.asubmit(2) + print(await f) + task.cancel() + diff --git a/tests/v2/test_listeners.py b/tests/v2/test_listeners.py new file mode 100644 index 00000000..713b2098 --- /dev/null +++ b/tests/v2/test_listeners.py @@ -0,0 +1,36 @@ +import asyncio +import logging +import sys + +import dotenv +import pytest +from loguru import logger + +from hatchet_sdk.v2.hatchet import Hatchet +from hatchet_sdk.v2.runtime.listeners import WorkflowRunEventListener + +logger.remove() +logger.add(sys.stdout, level="TRACE") + +dotenv.load_dotenv() + +hatchet = Hatchet(debug=True) + +logging.getLogger("asyncio").setLevel(logging.DEBUG) + + +async def interrupt(listener): + await asyncio.sleep(2) + logger.trace("interupt") + await listener._interrupt() + logger.trace("interrupted") + + +@pytest.mark.asyncio +async def test_listener_shutdown(): + listener = WorkflowRunEventListener() + task = asyncio.create_task(listener.loop()) + task2 = asyncio.create_task(interrupt(listener)) + sub = await listener.subscribe("bar-vj13ex/bar") + await sub.future + await task diff --git a/tests/v2/test_simple.py b/tests/v2/test_simple.py new file mode 100644 index 00000000..0db79a25 --- /dev/null +++ b/tests/v2/test_simple.py @@ -0,0 +1,41 @@ +import asyncio + +import pytest + + +def get_client(): + import dotenv + + from hatchet_sdk.v2.hatchet import Hatchet + + dotenv.load_dotenv() + return Hatchet(debug=True) + + +hatchet = get_client() + + +@hatchet.function() +async def foo(a: int): + print(f"in foo: a={a}") + return bar(b=3) + + +@hatchet.function() +def bar(b: int): + print(f"in bar: b={b}") + return b + + +# def test_trace(): +# import json + +# 
@pytest.mark.asyncio
async def test_interruptable_agen():
    """An interrupt value put on the queue is relayed to the consumer."""
    interrupt_q = asyncio.Queue()
    sentinel = {}
    seen = []

    def agen_factory():
        return InterruptableAgen(producer(), interrupt_q, 1)

    async def collect():
        # ForeverAgen requires the tuple of exceptions to suppress; none are
        # expected here, so anything raised fails the test.
        async for item in ForeverAgen(agen_factory, exceptions=()):
            logger.info("consuming: {}", item)
            seen.append(item)

    task = asyncio.create_task(collect())
    await asyncio.sleep(2)
    await interrupt_q.put(sentinel)
    # The interrupt queue is polled after every item (one every 0.5s), so
    # this leaves ample time for the sentinel to be relayed.
    await asyncio.sleep(1.5)

    # ForeverAgen never finishes on its own; stop the consumer explicitly so
    # the test terminates.
    task.cancel()
    try:
        await task
    except asyncio.CancelledError:
        pass

    assert seen and seen[0] == 0
    assert any(item is sentinel for item in seen)
+@hatchet.function() +def bar(x): + print("entering Bar") + print("arguments for bar: ", x) + return "bar" + + +@pytest.mark.asyncio +async def test_worker(): + + worker = hatchet.worker( + options=WorkerOptions(name="worker", actions=["default:foo", "default:bar"]) + ) + await worker.start() + print("result from foo: ", await asyncio.to_thread(foo().result)) + await asyncio.sleep(10) + await worker.shutdown() + return None + + +# def test_worker_process(): +# to_worker = mp.Queue() +# from_worker = mp.Queue() +# p = WorkerProcess( +# config=hatchet.config, +# options=WorkerOptions(name="worker", actions=[]), +# inbound=to_worker, +# outbound=from_worker, +# ) + +# pool = ThreadPoolExecutor() +# id = pool.submit(from_worker.get) +# print(p.start()) +# print(id.result()) +# time.sleep(10) +# print("shutting down") +# p.shutdown() + +# to_worker.close() +# from_worker.close()