diff --git a/redun/executors/aws_batch.py b/redun/executors/aws_batch.py
index ad7cfff1..4031af50 100644
--- a/redun/executors/aws_batch.py
+++ b/redun/executors/aws_batch.py
@@ -391,10 +391,6 @@ def apply_resources(container_properties: Dict) -> None:
         **batch_job_args,
     )
 
-    # For multi-node jobs, the rank 0 job id is sufficient for monitoring.
-    if num_nodes is not None:
-        batch_run["jobId"] = f"{batch_run['jobId']}#0"
-
     return batch_run
 
 
@@ -618,6 +614,7 @@ def aws_describe_jobs(
     for i in range(0, len(job_ids), chunk_size):
         chunk_job_ids = job_ids[i : i + chunk_size]
         response = batch_client.describe_jobs(jobs=chunk_job_ids)
+
         for job in response["jobs"]:
             yield job
 
@@ -656,6 +653,28 @@ def iter_batch_job_status(
             break
 
 
+def get_job_log_stream(job: Optional[dict], aws_region: str) -> Optional[str]:
+    """Extract the log stream from a `JobDetail` status dictionary. For non-multi-node jobs
+    (i.e., single-node and array jobs), this is simply a field in the detail dictionary. But for
+    multi-node jobs, this requires another query to get the log stream for the main node."""
+    if job and "nodeProperties" in job:
+        # There isn't a log stream on the main job detail object. However, we can find the right
+        # node and query it:
+        main_node = job["nodeProperties"]["mainNode"]
+        job_id = job["jobId"]
+        # The docs indicate we can rely on this format for getting the per-worker job ids:
+        # `jobId#worker_id`.
+        jobs: Iterator[Optional[Dict[Any, Any]]] = aws_describe_jobs(
+            [f"{job_id}#{main_node}"], aws_region=aws_region
+        )
+        job = next(jobs, None)
+    if not job:
+        # Job is no longer present in AWS API. Return no logs.
+        return None
+
+    return job.get("container", {}).get("logStreamName")
+
+
 def format_log_stream_event(event: dict) -> str:
     """
     Format a logStream event as a line.
@@ -677,10 +696,8 @@ def iter_batch_job_logs(
     """
     # Get job's log stream.
     job = next(aws_describe_jobs([job_id], aws_region=aws_region), None)
-    if not job:
-        # Job is no longer present in AWS API. Return no logs.
-        return
-    log_stream = job.get("container", {}).get("logStreamName")
+
+    log_stream = get_job_log_stream(job, aws_region)
     if not log_stream:
         # No log stream is present. Return no logs.
         return
@@ -993,7 +1010,7 @@ def _process_job_status(self, job: dict) -> None:
         redun_job = self.pending_batch_jobs.pop(job["jobId"])
         job_tags = []
         job_tags.append(("aws_batch_job", job["jobId"]))
-        log_stream = job.get("container", {}).get("logStreamName")
+        log_stream = get_job_log_stream(job, aws_region=self.aws_region)
         if log_stream:
             job_tags.append(("aws_log_stream", log_stream))
 
@@ -1216,7 +1233,7 @@ def _submit_array_job(
             "  array_job_name = {job_name}\n"
             "  array_size = {array_size}\n"
             "  s3_scratch_path = {job_dir}\n"
-            "  retry_attempts = {retries}\n".format(
+            "  submit_retry_attempts = {retries}\n".format(
                 array_job_id=array_job_id,
                 array_size=array_size,
                 job_type=job_type,
@@ -1273,7 +1290,7 @@ def _submit_single_job(self, job: Job, args: Tuple, kwargs: dict) -> None:
             "  job_id = {batch_job}\n"
             "  job_name = {job_name}\n"
             "  s3_scratch_path = {job_dir}\n"
-            "  retry_attempts = {retries}\n".format(
+            "  submit_retry_attempts = {retries}\n".format(
                 redun_job=job.id,
                 job_type=job_type,
                 batch_job=batch_resp["jobId"],
diff --git a/redun/tests/test_aws_batch.py b/redun/tests/test_aws_batch.py
index 246de3a5..090da129 100644
--- a/redun/tests/test_aws_batch.py
+++ b/redun/tests/test_aws_batch.py
@@ -2,6 +2,7 @@
 import os
 import pickle
 import time
+import unittest.mock
 import uuid
 from typing import cast
 from unittest.mock import Mock, patch
@@ -26,6 +27,7 @@
     get_batch_job_name,
     get_hash_from_job_name,
     get_job_definition,
+    get_job_log_stream,
     get_or_create_job_definition,
     iter_batch_job_log_lines,
     iter_batch_job_logs,
@@ -2238,3 +2240,76 @@ def other_task(x, y):
     )
     error, _ = pickle.loads(cast(bytes, error_file.read("rb")))
     assert isinstance(error, ZeroDivisionError)
+
+
+@patch("redun.executors.aws_batch.aws_describe_jobs")
+def test_get_log_stream(aws_describe_jobs_mock) -> None:
+    """Test the log stream fetching utility, which handles the difference in behavior for
+    multi-node jobs."""
+
+    aws_region = "fake_region"
+
+    # Heavily pruned `JobDetail` messages.
+    single_node_status = {
+        "jobArn": "arn:aws:batch:us-west-2:298579124006:job/89ee416c",
+        "jobName": "redun-testing-c813121061b56b5934d5db78730ada2e2ae11e44",
+        "jobId": "single_node_id",
+        "status": "FAILED",
+        "container": {
+            "logStreamName": "log_stream_arn",
+        },
+        "nodeDetails": {"nodeIndex": 0, "isMainNode": True},
+    }
+
+    # This is the overall message for a multi-node job. There is no log stream, but there is
+    # a `nodeProperties` field.
+    multi_node_status = {
+        "jobArn": "arn:aws:batch:us-west-2:298579124006:job/89ee416c-1ab1-457c-b728-1f10318f61bb",
+        "jobName": "redun-testing-c813121061b56b5934d5db78730ada2e2ae11e44",
+        "jobId": "multi_node_id",
+        "attempts": [
+            {
+                "container": {
+                    "exitCode": 1,
+                    "logStreamName": "cannot_be_returned",
+                    "networkInterfaces": [],
+                },
+                "startedAt": 1663274825002,
+                "stoppedAt": 1663274840060,
+                "statusReason": "Essential container in task exited",
+            }
+        ],
+        "createdAt": 1663274377312,
+        "retryStrategy": {"attempts": 2, "evaluateOnExit": []},
+        "startedAt": 1663274957408,
+        "parameters": {},
+        "nodeProperties": {
+            "numNodes": 3,
+            "mainNode": 0,
+            "nodeRangeProperties": [
+                {
+                    "targetNodes": "0",
+                    "container": {
+                        "image": "image_arn",
+                    },
+                },
+                {
+                    "targetNodes": "1:",
+                    "container": {
+                        "image": "image_arn",
+                    },
+                },
+            ],
+        },
+    }
+
+    # No query is required for single-node jobs.
+    aws_describe_jobs_mock.return_value = None
+    assert get_job_log_stream(single_node_status, aws_region) == "log_stream_arn"
+
+    # Multi-node jobs will query again and fetch from there.
+    aws_describe_jobs_mock.return_value = iter([single_node_status])
+    assert get_job_log_stream(multi_node_status, aws_region) == "log_stream_arn"
+    assert aws_describe_jobs_mock.call_args == unittest.mock.call(
+        ["multi_node_id#0"], aws_region=aws_region
+    )
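
Reviewer note (illustration only, not part of the patch): a minimal sketch of how
`get_job_log_stream` composes with `aws_describe_jobs` once this change lands. The
job id and region below are placeholders, and running it assumes boto3 credentials
with AWS Batch access.

    from redun.executors.aws_batch import aws_describe_jobs, get_job_log_stream

    job_id = "00000000-0000-0000-0000-000000000000"  # placeholder Batch job id
    aws_region = "us-west-2"  # placeholder region

    # Describe the job, then resolve its log stream. For a multi-node parallel
    # job, get_job_log_stream issues one extra describe call against the main
    # node's child job ("<jobId>#<mainNode>"); for single-node and array jobs
    # the stream name is read straight off the job detail.
    job = next(aws_describe_jobs([job_id], aws_region=aws_region), None)
    log_stream = get_job_log_stream(job, aws_region)
    print(log_stream)  # CloudWatch log stream name, or None if unavailable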