Skip to content

Commit

Permalink
Fix handling of multi-node job statuses and logs (#198)
Browse files Browse the repository at this point in the history
The status needs to come from the overall job, but the log
has to come from the main node.
  • Loading branch information
pconrad-insitro authored Sep 16, 2022
1 parent 0f6339c commit 0cd06c8
Show file tree
Hide file tree
Showing 2 changed files with 103 additions and 11 deletions.
39 changes: 28 additions & 11 deletions redun/executors/aws_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -391,10 +391,6 @@ def apply_resources(container_properties: Dict) -> None:
**batch_job_args,
)

# For multi-node jobs, the rank 0 job id is sufficient for monitoring.
if num_nodes is not None:
batch_run["jobId"] = f"{batch_run['jobId']}#0"

return batch_run


Expand Down Expand Up @@ -618,6 +614,7 @@ def aws_describe_jobs(
for i in range(0, len(job_ids), chunk_size):
chunk_job_ids = job_ids[i : i + chunk_size]
response = batch_client.describe_jobs(jobs=chunk_job_ids)

for job in response["jobs"]:
yield job

Expand Down Expand Up @@ -656,6 +653,28 @@ def iter_batch_job_status(
break


def get_job_log_stream(job: Optional[dict], aws_region: str) -> Optional[str]:
    """Extract the CloudWatch log stream name from a Batch `JobDetail` dict.

    For single-node and array jobs, the log stream is simply a field on the
    job detail's `container` section. For multi-node jobs the overall job
    detail carries no log stream, so a second `describe_jobs` query is made
    for the main node and the stream is read from that node's detail.

    Parameters
    ----------
    job : Optional[dict]
        A Batch `JobDetail` dictionary, or None when the job is unknown.
    aws_region : str
        AWS region for the follow-up describe call (multi-node jobs only).

    Returns
    -------
    Optional[str]
        The log stream name, or None when the job or its stream is unavailable.
    """
    if job and "nodeProperties" in job:
        # There isn't a log stream on the main job detail object. However, we can find the right
        # node and query it:
        main_node = job["nodeProperties"]["mainNode"]
        job_id = job["jobId"]
        # The docs indicate we can rely on this format for getting the per-worker jobs ids:
        # `jobId#worker_id`.
        jobs: Iterator[Optional[Dict[Any, Any]]] = aws_describe_jobs(
            [f"{job_id}#{main_node}"], aws_region=aws_region
        )
        job = next(jobs, None)

    # Guard covers both an initially-None `job` (the original code raised
    # AttributeError here despite the Optional annotation) and a main-node
    # lookup that found nothing in the AWS API.
    if not job:
        # Job is no longer present in AWS API. Return no logs.
        return None

    return job.get("container", {}).get("logStreamName")


def format_log_stream_event(event: dict) -> str:
"""
Format a logStream event as a line.
Expand All @@ -677,10 +696,8 @@ def iter_batch_job_logs(
"""
# Get job's log stream.
job = next(aws_describe_jobs([job_id], aws_region=aws_region), None)
if not job:
# Job is no longer present in AWS API. Return no logs.
return
log_stream = job.get("container", {}).get("logStreamName")

log_stream = get_job_log_stream(job, aws_region)
if not log_stream:
# No log stream is present. Return no logs.
return
Expand Down Expand Up @@ -993,7 +1010,7 @@ def _process_job_status(self, job: dict) -> None:
redun_job = self.pending_batch_jobs.pop(job["jobId"])
job_tags = []
job_tags.append(("aws_batch_job", job["jobId"]))
log_stream = job.get("container", {}).get("logStreamName")
log_stream = get_job_log_stream(job, aws_region=self.aws_region)
if log_stream:
job_tags.append(("aws_log_stream", log_stream))

Expand Down Expand Up @@ -1216,7 +1233,7 @@ def _submit_array_job(
" array_job_name = {job_name}\n"
" array_size = {array_size}\n"
" s3_scratch_path = {job_dir}\n"
" retry_attempts = {retries}\n".format(
" submit_retry_attempts = {retries}\n".format(
array_job_id=array_job_id,
array_size=array_size,
job_type=job_type,
Expand Down Expand Up @@ -1273,7 +1290,7 @@ def _submit_single_job(self, job: Job, args: Tuple, kwargs: dict) -> None:
" job_id = {batch_job}\n"
" job_name = {job_name}\n"
" s3_scratch_path = {job_dir}\n"
" retry_attempts = {retries}\n".format(
" submit_retry_attempts = {retries}\n".format(
redun_job=job.id,
job_type=job_type,
batch_job=batch_resp["jobId"],
Expand Down
75 changes: 75 additions & 0 deletions redun/tests/test_aws_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import os
import pickle
import time
import unittest.mock
import uuid
from typing import cast
from unittest.mock import Mock, patch
Expand All @@ -26,6 +27,7 @@
get_batch_job_name,
get_hash_from_job_name,
get_job_definition,
get_job_log_stream,
get_or_create_job_definition,
iter_batch_job_log_lines,
iter_batch_job_logs,
Expand Down Expand Up @@ -2238,3 +2240,76 @@ def other_task(x, y):
)
error, _ = pickle.loads(cast(bytes, error_file.read("rb")))
assert isinstance(error, ZeroDivisionError)


@patch("redun.executors.aws_batch.aws_describe_jobs")
def test_get_log_stream(describe_jobs_mock) -> None:
    """Check that log-stream extraction handles both single-node and
    multi-node jobs; the latter triggers a follow-up describe call for
    the main node."""
    region = "fake_region"

    # Trimmed-down `JobDetail` payload for a single-node job: the log
    # stream lives directly on the `container` section.
    single_node = {
        "jobArn": "arn:aws:batch:us-west-2:298579124006:job/89ee416c",
        "jobName": "redun-testing-c813121061b56b5934d5db78730ada2e2ae11e44",
        "jobId": "single_node_id",
        "status": "FAILED",
        "container": {
            "logStreamName": "log_stream_arn",
        },
        "nodeDetails": {"nodeIndex": 0, "isMainNode": True},
    }

    # Trimmed-down overall `JobDetail` for a multi-node job: it carries a
    # `nodeProperties` section but no usable top-level log stream.
    multi_node = {
        "jobArn": "arn:aws:batch:us-west-2:298579124006:job/89ee416c-1ab1-457c-b728-1f10318f61bb",
        "jobName": "redun-testing-c813121061b56b5934d5db78730ada2e2ae11e44",
        "jobId": "multi_node_id",
        "attempts": [
            {
                "container": {
                    "exitCode": 1,
                    "logStreamName": "cannot_be_returned",
                    "networkInterfaces": [],
                },
                "startedAt": 1663274825002,
                "stoppedAt": 1663274840060,
                "statusReason": "Essential container in task exited",
            }
        ],
        "createdAt": 1663274377312,
        "retryStrategy": {"attempts": 2, "evaluateOnExit": []},
        "startedAt": 1663274957408,
        "parameters": {},
        "nodeProperties": {
            "numNodes": 3,
            "mainNode": 0,
            "nodeRangeProperties": [
                {
                    "targetNodes": "0",
                    "container": {
                        "image": "image_arn",
                    },
                },
                {
                    "targetNodes": "1:",
                    "container": {
                        "image": "image_arn",
                    },
                },
            ],
        },
    }

    # Single-node jobs never need a second describe call.
    describe_jobs_mock.return_value = None
    assert get_job_log_stream(single_node, region) == "log_stream_arn"

    # Multi-node jobs re-query the main node and read its log stream.
    describe_jobs_mock.return_value = iter([single_node])
    assert get_job_log_stream(multi_node, region) == "log_stream_arn"
    assert describe_jobs_mock.call_args == unittest.mock.call(
        ["multi_node_id#0"], aws_region=region
    )

0 comments on commit 0cd06c8

Please sign in to comment.