Skip to content

Commit

Permalink
fix(engine): Add run_id to ActionRun store identifiers
Browse files Browse the repository at this point in the history
  • Loading branch information
daryllimyt committed Mar 4, 2024
1 parent 379bad5 commit c6d3ca1
Show file tree
Hide file tree
Showing 2 changed files with 98 additions and 37 deletions.
115 changes: 88 additions & 27 deletions tracecat/actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,9 +63,30 @@
class ActionRun(BaseModel):
"""A run of an action to be executed as part of a workflow run."""

run_id: str
run_id: str = Field(frozen=True)
run_kwargs: dict[str, Any] | None = None
action_key: str = Field(pattern=ACTION_KEY_PATTERN, max_length=50)
action_key: str = Field(pattern=ACTION_KEY_PATTERN, max_length=50, frozen=True)

@property
def id(self) -> str:
"""The unique identifier of the action run.
The action key tells us where to find the action in the workflow graph.
The run ID tells us which workflow run the action is part of.
We need both to uniquely identify an action run.
"""
return get_action_run_id(self.run_id, self.action_key)

def __hash__(self) -> int:
return hash(f"{self.run_id}:{self.action_key}")

def __eq__(self, other: Any) -> bool:
match other:
case ActionRun(run_id=self.run_id, action_key=self.action_key):
return True
case _:
return False


class ActionRunStatus(StrEnum):
Expand Down Expand Up @@ -178,6 +199,35 @@ class LLMAction(Action):
"llm": LLMAction,
}

ACTION_RUN_ID_PREFIX = "ar"


def get_action_run_id(
run_id: str, action_key: str, *, prefix: str = ACTION_RUN_ID_PREFIX
) -> str:
return f"{prefix}:{action_key}:{run_id}"


def parse_action_run_id(ar_id: str, component: Literal["action_key", "run_id"]) -> str:
"""Parse an action run ID and return the action key or the run ID.
Example
-------
>>> parse_action_run_id("ar:TEST-WORKFLOW-ID.receive_sentry_event:RUN_ID", "action_key")
"TEST-WORKFLOW-ID.receive_sentry_event"
>>> parse_action_run_id("ar:TEST-WORKFLOW-ID.receive_sentry_event:RUN_ID", "run_id")
"RUN_ID"
"""
if not ar_id.startswith(f"{ACTION_RUN_ID_PREFIX}:"):
raise ValueError(f"Invalid action run ID {ar_id!r}")
match component:
case "action_key":
return ar_id.split(":")[1]
case "run_id":
return ar_id.split(":")[2]
case _:
raise ValueError(f"Invalid component {component!r}")


def action_key_to_workflow_id(action_key: str) -> str:
return action_key.split(".")[0]
Expand Down Expand Up @@ -277,9 +327,13 @@ def _get_dependencies_results(


async def _wait_for_dependencies(
dependencies: Iterable[str], task_status: dict[str, ActionRunStatus]
upstream_deps_ar_ids: Iterable[str],
action_run_status_store: dict[str, ActionRunStatus],
) -> None:
while not all(task_status.get(d) == ActionRunStatus.SUCCESS for d in dependencies):
while not all(
action_run_status_store.get(ar_id) == ActionRunStatus.SUCCESS
for ar_id in upstream_deps_ar_ids
):
await asyncio.sleep(random.uniform(0, 0.5))


Expand All @@ -295,26 +349,29 @@ async def start_action_run(
pending_timeout: float | None = None,
custom_logger: logging.Logger | None = None,
) -> None:
ar_id = action_run.id
action_key = action_run.action_key
upstream_deps = workflow_ref.action_dependencies[action_key]
downstream_deps = workflow_ref.adj_list[action_key]
upstream_deps_ar_ids = [
get_action_run_id(action_run.run_id, k)
for k in workflow_ref.action_dependencies[action_key]
]

custom_logger = custom_logger or logger
custom_logger.debug(
f"Action {action_key} waiting for dependencies {upstream_deps}."
f"Action run {ar_id} waiting for dependencies {upstream_deps_ar_ids}."
)
try:
await asyncio.wait_for(
_wait_for_dependencies(upstream_deps, action_run_status_store),
_wait_for_dependencies(upstream_deps_ar_ids, action_run_status_store),
timeout=pending_timeout,
)

action_trail = _get_dependencies_results(upstream_deps, action_result_store)

custom_logger.debug(
f"Running action {action_key!r}. Trail {action_trail.keys()}."
action_trail = _get_dependencies_results(
upstream_deps_ar_ids, action_result_store
)
action_run_status_store[action_key] = ActionRunStatus.RUNNING

custom_logger.debug(f"Running action {ar_id!r}. Trail {action_trail.keys()}.")
action_run_status_store[ar_id] = ActionRunStatus.RUNNING
action_ref = workflow_ref.action_map[action_key]
result = await run_action(
custom_logger=custom_logger,
Expand All @@ -324,39 +381,43 @@ async def start_action_run(
)

# Mark the action as completed
action_run_status_store[action_key] = ActionRunStatus.SUCCESS
action_run_status_store[action_run.id] = ActionRunStatus.SUCCESS

# Store the result in the action result store.
# Every action has its own result and the trail of actions that led to it.
# The schema is {<action ID> : <action result>, ...}
action_result_store[action_key] = action_trail | {action_key: result}
custom_logger.debug(f"Action {action_key!r} completed with result {result}.")
action_result_store[ar_id] = action_trail | {ar_id: result}
custom_logger.debug(f"Action run {ar_id!r} completed with result {result}.")

downstream_deps_ar_ids = [
get_action_run_id(action_run.run_id, k)
for k in workflow_ref.adj_list[action_key]
]
# Broadcast the results to the next actions and enqueue them
for next_action_key in downstream_deps:
if next_action_key not in action_run_status_store:
action_run_status_store[next_action_key] = ActionRunStatus.QUEUED
for next_ar_id in downstream_deps_ar_ids:
if next_ar_id not in action_run_status_store:
action_run_status_store[next_ar_id] = ActionRunStatus.QUEUED
ready_jobs_queue.put_nowait(
ActionRun(
run_id=action_run.run_id,
action_key=next_action_key,
action_key=parse_action_run_id(next_ar_id, "action_key"),
)
)

except TimeoutError:
custom_logger.error(
f"Action {action_key} timed out waiting for dependencies {upstream_deps}."
f"Action run {ar_id} timed out waiting for dependencies {upstream_deps_ar_ids}."
)
except asyncio.CancelledError:
custom_logger.warning(f"Action {action_key!r} was cancelled.")
custom_logger.warning(f"Action run {ar_id!r} was cancelled.")
except Exception as e:
custom_logger.error(f"Action {action_key!r} failed with error {e}.")
custom_logger.error(f"Action run {ar_id!r} failed with error {e}.")
finally:
if action_run_status_store[action_key] != ActionRunStatus.SUCCESS:
if action_run_status_store[ar_id] != ActionRunStatus.SUCCESS:
# Exception was raised before the action was marked as successful
action_run_status_store[action_key] = ActionRunStatus.FAILURE
running_jobs_store.pop(action_key, None)
custom_logger.debug(f"Remaining tasks: {running_jobs_store.keys()}")
action_run_status_store[ar_id] = ActionRunStatus.FAILURE
running_jobs_store.pop(ar_id, None)
custom_logger.debug(f"Remaining acrion runs: {running_jobs_store.keys()}")


async def run_action(
Expand Down
20 changes: 10 additions & 10 deletions tracecat/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,27 +305,27 @@ async def run_workflow(
not ready_jobs_queue.empty() or running_jobs_store
) and runner_status == RunnerStatus.RUNNING:
try:
curr_action_run = await asyncio.wait_for(
ready_jobs_queue.get(), timeout=3
)
action_run = await asyncio.wait_for(ready_jobs_queue.get(), timeout=3)
except TimeoutError:
continue
action_key = curr_action_run.action_key
# Defensive: Deduplicate tasks
if action_key in running_jobs_store or action_key in action_result_store:
if (
action_run.id in running_jobs_store
or action_run.id in action_result_store
):
run_logger.debug(
f"Action {action_key!r} already running or completed. Skipping."
f"Action {action_run.id!r} already running or completed. Skipping."
)
continue

run_logger.info(
f"{workflow.action_map[action_key].__class__.__name__} {action_key!r} ready. Running."
f"{workflow.action_map[action_run.action_key].__class__.__name__} {action_run.id!r} ready. Running."
)
action_run_status_store[action_key] = ActionRunStatus.PENDING
action_run_status_store[action_run.id] = ActionRunStatus.PENDING
# Schedule a new action run
running_jobs_store[action_key] = asyncio.create_task(
running_jobs_store[action_run.id] = asyncio.create_task(
start_action_run(
action_run=curr_action_run,
action_run=action_run,
workflow_ref=workflow,
ready_jobs_queue=ready_jobs_queue,
running_jobs_store=running_jobs_store,
Expand Down

0 comments on commit c6d3ca1

Please sign in to comment.