[Jobs] --sync-down execution log support for managed jobs (#4527)
* add support for managed jobs

* support --name to download logs

* format

* fix job_id

* add support for downloading controller logs separately from job logs

* remove the interface for getting complete job_ids, parse table instead

* code formatting

* documentation for sync down logs

* restore the job id to run timestamp lookup via a new interface

* make the logs work

* revert managed job yaml

* code formatting and cleanup

* Fix formatting in managed_job.yaml by removing unnecessary whitespace

* Remove unnecessary whitespace in managed_job.yaml for improved formatting

* add sync down smoke tests

* Enhance smoke tests for managed jobs by adding a command that waits for the job status. This improves the reliability of the test by ensuring the job is in the RUNNING state before proceeding with log synchronization.

* refactor the sync-down logic to remove the dependency on getting the job status

* formatting

* fix rsync bug due to refactor

* clean up printing

* Update sky/jobs/core.py

Co-authored-by: Christopher Cooper <[email protected]>

* Refactor job ID retrieval and update CLI option flag

- Changed the method name from `get_job_ids_by_name` to `get_all_job_ids_by_name` for clarity in `ManagedJobCodeGen`.
- Updated the CLI option flag for `sync-down` from `-d` to `-s` for better alignment with common conventions.

* format & linting

---------

Co-authored-by: Christopher Cooper <[email protected]>
KeplerC and cg505 authored Jan 9, 2025
1 parent bc777e2 commit 1578108
Showing 7 changed files with 272 additions and 7 deletions.
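
In addition to the new `--sync-down`/`-s` flag on `sky jobs logs`, the commit exports a `sync_down_logs` entrypoint from `sky.jobs`. Below is a minimal sketch of calling it from Python; the job name `my-job` is hypothetical, and `sky/jobs/core.py` further down shows the actual signature.

```python
# Minimal sketch of the new Python API added in sky/jobs/core.py. The job
# name 'my-job' is hypothetical; the equivalent CLI invocation would be
# `sky jobs logs --name my-job --sync-down`.
from sky import jobs as managed_jobs

# Download the logs of the latest managed job named 'my-job' to the default
# local log directory. Pass job_id instead of name to pick a specific job,
# controller=True for controller logs, and refresh=True to restart a stopped
# jobs controller first.
managed_jobs.sync_down_logs(name='my-job',
                            job_id=None,
                            refresh=False,
                            controller=False)
```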
146 changes: 146 additions & 0 deletions sky/backends/cloud_vm_ray_backend.py
@@ -3891,6 +3891,152 @@ def tail_managed_job_logs(self,
stdin=subprocess.DEVNULL,
)

def sync_down_managed_job_logs(
self,
handle: CloudVmRayResourceHandle,
job_id: Optional[int] = None,
job_name: Optional[str] = None,
controller: bool = False,
local_dir: str = constants.SKY_LOGS_DIRECTORY) -> Dict[str, str]:
"""Sync down logs for a managed job.
Args:
handle: The handle to the cluster.
job_id: The job ID to sync down logs for.
job_name: The job name to sync down logs for.
controller: Whether to sync down logs for the controller.
local_dir: The local directory to sync down logs to.
Returns:
A dictionary mapping job_id to log path.
"""
# if job_name is not None, job_id should be None
assert job_name is None or job_id is None, (job_name, job_id)
if job_id is None and job_name is not None:
# generate code to get the job_id
code = managed_jobs.ManagedJobCodeGen.get_all_job_ids_by_name(
job_name=job_name)
returncode, run_timestamps, stderr = self.run_on_head(
handle,
code,
stream_logs=False,
require_outputs=True,
separate_stderr=True)
subprocess_utils.handle_returncode(returncode, code,
'Failed to sync down logs.',
stderr)
job_ids = common_utils.decode_payload(run_timestamps)
if not job_ids:
logger.info(f'{colorama.Fore.YELLOW}'
'No matching job found'
f'{colorama.Style.RESET_ALL}')
return {}
elif len(job_ids) > 1:
logger.info(
f'{colorama.Fore.YELLOW}'
f'Multiple job IDs found under the name {job_name}. '
'Downloading the latest job logs.'
f'{colorama.Style.RESET_ALL}')
job_ids = [job_ids[0]] # descending order
else:
job_ids = [job_id]

# get the run_timestamp
# the function takes in [job_id]
code = job_lib.JobLibCodeGen.get_run_timestamp_with_globbing(job_ids)
returncode, run_timestamps, stderr = self.run_on_head(
handle,
code,
stream_logs=False,
require_outputs=True,
separate_stderr=True)
subprocess_utils.handle_returncode(returncode, code,
'Failed to sync logs.', stderr)
# returns with a dict of {job_id: run_timestamp}
run_timestamps = common_utils.decode_payload(run_timestamps)
if not run_timestamps:
logger.info(f'{colorama.Fore.YELLOW}'
'No matching log directories found'
f'{colorama.Style.RESET_ALL}')
return {}

run_timestamp = list(run_timestamps.values())[0]
job_id = list(run_timestamps.keys())[0]
local_log_dir = ''
if controller: # download controller logs
remote_log_dir = os.path.join(constants.SKY_LOGS_DIRECTORY,
run_timestamp)
local_log_dir = os.path.expanduser(
os.path.join(local_dir, run_timestamp))

logger.info(f'{colorama.Fore.CYAN}'
f'Job {job_ids} local logs: {local_log_dir}'
f'{colorama.Style.RESET_ALL}')

runners = handle.get_command_runners()

def _rsync_down(args) -> None:
"""Rsync down logs from remote nodes.
Args:
args: A tuple of (runner, local_log_dir, remote_log_dir)
"""
(runner, local_log_dir, remote_log_dir) = args
try:
os.makedirs(local_log_dir, exist_ok=True)
runner.rsync(
source=f'{remote_log_dir}/',
target=local_log_dir,
up=False,
stream_logs=False,
)
except exceptions.CommandError as e:
if e.returncode == exceptions.RSYNC_FILE_NOT_FOUND_CODE:
# Raised by rsync_down. Remote log dir may not exist
# since the job can be run on some part of the nodes.
logger.debug(
f'{runner.node_id} does not have the tasks/*.')
else:
raise

parallel_args = [[runner, *item]
for item in zip([local_log_dir], [remote_log_dir])
for runner in runners]
subprocess_utils.run_in_parallel(_rsync_down, parallel_args)
else: # download job logs
local_log_dir = os.path.expanduser(
os.path.join(local_dir, 'managed_jobs', run_timestamp))
os.makedirs(os.path.dirname(local_log_dir), exist_ok=True)
log_file = os.path.join(local_log_dir, 'run.log')

code = managed_jobs.ManagedJobCodeGen.stream_logs(job_name=None,
job_id=job_id,
follow=False,
controller=False)

# With the stdin=subprocess.DEVNULL, the ctrl-c will not
# kill the process, so we need to handle it manually here.
if threading.current_thread() is threading.main_thread():
signal.signal(signal.SIGINT, backend_utils.interrupt_handler)
signal.signal(signal.SIGTSTP, backend_utils.stop_handler)

# We redirect the output to the log file
# and disable the STDOUT and STDERR
self.run_on_head(
handle,
code,
log_path=log_file,
stream_logs=False,
process_stream=False,
ssh_mode=command_runner.SshMode.INTERACTIVE,
stdin=subprocess.DEVNULL,
)

logger.info(f'{colorama.Fore.CYAN}'
f'Job {job_id} logs: {local_log_dir}'
f'{colorama.Style.RESET_ALL}')
return {str(job_id): local_log_dir}

def tail_serve_logs(self, handle: CloudVmRayResourceHandle,
service_name: str, target: serve_lib.ServiceComponent,
replica_id: Optional[int], follow: bool) -> None:
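
An illustrative, runnable sketch (not part of the diff) of the two local layouts the new method produces, assuming `SKY_LOGS_DIRECTORY` resolves to `~/sky_logs` and using a hypothetical run timestamp:

```python
# Illustrative sketch of the two local log layouts produced by
# sync_down_managed_job_logs above. SKY_LOGS_DIRECTORY is assumed to be
# '~/sky_logs'; the run timestamp below is hypothetical.
import os

local_dir = os.path.expanduser('~/sky_logs')       # assumed SKY_LOGS_DIRECTORY
run_timestamp = 'sky-2025-01-09-12-00-00-000000'   # hypothetical

# controller=True: the remote log directory is rsync'd down from every node.
controller_log_dir = os.path.join(local_dir, run_timestamp)
# controller=False: the job's streamed output is written to a single run.log.
job_log_file = os.path.join(local_dir, 'managed_jobs', run_timestamp, 'run.log')

print(controller_log_dir)
print(job_log_file)
```

In both cases the method returns `{str(job_id): local_log_dir}`, so callers know where the logs landed.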
26 changes: 19 additions & 7 deletions sky/cli.py
@@ -3933,17 +3933,29 @@ def jobs_cancel(name: Optional[str], job_ids: Tuple[int], all: bool, yes: bool):
required=False,
help='Query the latest job logs, restarting the jobs controller if stopped.'
)
@click.option('--sync-down',
'-s',
default=False,
is_flag=True,
required=False,
help='Download logs for all jobs shown in the queue.')
@click.argument('job_id', required=False, type=int)
@usage_lib.entrypoint
def jobs_logs(name: Optional[str], job_id: Optional[int], follow: bool,
controller: bool, refresh: bool):
"""Tail the log of a managed job."""
controller: bool, refresh: bool, sync_down: bool):
"""Tail or sync down the log of a managed job."""
try:
managed_jobs.tail_logs(name=name,
job_id=job_id,
follow=follow,
controller=controller,
refresh=refresh)
if sync_down:
managed_jobs.sync_down_logs(name=name,
job_id=job_id,
controller=controller,
refresh=refresh)
else:
managed_jobs.tail_logs(name=name,
job_id=job_id,
follow=follow,
controller=controller,
refresh=refresh)
except exceptions.ClusterNotUpError:
with ux_utils.print_exception_no_traceback():
raise
2 changes: 2 additions & 0 deletions sky/jobs/__init__.py
@@ -9,6 +9,7 @@
from sky.jobs.core import launch
from sky.jobs.core import queue
from sky.jobs.core import queue_from_kubernetes_pod
from sky.jobs.core import sync_down_logs
from sky.jobs.core import tail_logs
from sky.jobs.recovery_strategy import DEFAULT_RECOVERY_STRATEGY
from sky.jobs.recovery_strategy import RECOVERY_STRATEGIES
@@ -37,6 +38,7 @@
'queue',
'queue_from_kubernetes_pod',
'tail_logs',
'sync_down_logs',
# utils
'ManagedJobCodeGen',
'format_job_table',
46 changes: 46 additions & 0 deletions sky/jobs/core.py
@@ -427,6 +427,52 @@ def tail_logs(name: Optional[str], job_id: Optional[int], follow: bool,
controller=controller)


@usage_lib.entrypoint
def sync_down_logs(
name: Optional[str],
job_id: Optional[int],
refresh: bool,
controller: bool,
local_dir: str = skylet_constants.SKY_LOGS_DIRECTORY) -> None:
"""Sync down logs of managed jobs.
Please refer to sky.cli.jobs_logs for documentation.
Raises:
ValueError: invalid arguments.
sky.exceptions.ClusterNotUpError: the jobs controller is not up.
"""
# TODO(zhwu): Automatically restart the jobs controller
if name is not None and job_id is not None:
with ux_utils.print_exception_no_traceback():
raise ValueError('Cannot specify both name and job_id.')

jobs_controller_type = controller_utils.Controllers.JOBS_CONTROLLER
job_name_or_id_str = ''
if job_id is not None:
job_name_or_id_str = str(job_id)
elif name is not None:
job_name_or_id_str = f'-n {name}'
else:
job_name_or_id_str = ''
handle = _maybe_restart_controller(
refresh,
stopped_message=(
f'{jobs_controller_type.value.name.capitalize()} is stopped. To '
f'get the logs, run: {colorama.Style.BRIGHT}sky jobs logs '
f'-r --sync-down {job_name_or_id_str}{colorama.Style.RESET_ALL}'),
spinner_message='Retrieving job logs')

backend = backend_utils.get_backend_from_handle(handle)
assert isinstance(backend, backends.CloudVmRayBackend), backend

backend.sync_down_managed_job_logs(handle,
job_id=job_id,
job_name=name,
controller=controller,
local_dir=local_dir)


spot_launch = common_utils.deprecated_function(
launch,
name='sky.jobs.launch',
27 changes: 27 additions & 0 deletions sky/jobs/state.py
@@ -564,6 +564,33 @@ def get_nonterminal_job_ids_by_name(name: Optional[str]) -> List[int]:
return job_ids


def get_all_job_ids_by_name(name: Optional[str]) -> List[int]:
"""Get all job ids by name."""
name_filter = ''
field_values = []
if name is not None:
# We match the job name from `job_info` for the jobs submitted after
# #1982, and from `spot` for the jobs submitted before #1982, whose
# job_info is not available.
name_filter = ('WHERE (job_info.name=(?) OR '
'(job_info.name IS NULL AND spot.task_name=(?)))')
field_values = [name, name]

# Left outer join is used here instead of join, because the job_info does
# not contain the managed jobs submitted before #1982.
with db_utils.safe_cursor(_DB_PATH) as cursor:
rows = cursor.execute(
f"""\
SELECT DISTINCT spot.spot_job_id
FROM spot
LEFT OUTER JOIN job_info
ON spot.spot_job_id=job_info.spot_job_id
{name_filter}
ORDER BY spot.spot_job_id DESC""", field_values).fetchall()
job_ids = [row[0] for row in rows if row[0] is not None]
return job_ids


def _get_all_task_ids_statuses(
job_id: int) -> List[Tuple[int, ManagedJobStatus]]:
with db_utils.safe_cursor(_DB_PATH) as cursor:
9 changes: 9 additions & 0 deletions sky/jobs/utils.py
@@ -855,6 +855,15 @@ def cancel_job_by_name(cls, job_name: str) -> str:
""")
return cls._build(code)

@classmethod
def get_all_job_ids_by_name(cls, job_name: str) -> str:
code = textwrap.dedent(f"""\
from sky.utils import common_utils
job_id = managed_job_state.get_all_job_ids_by_name({job_name!r})
print(common_utils.encode_payload(job_id), end="", flush=True)
""")
return cls._build(code)

@classmethod
def stream_logs(cls,
job_name: Optional[str],
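
The codegen above prints `common_utils.encode_payload(...)` on the jobs controller, and `sync_down_managed_job_logs` decodes the output after `run_on_head` returns. A minimal round-trip sketch, assuming the encode/decode helpers are symmetric as they are used throughout this diff:

```python
# Minimal sketch of the payload round-trip used by get_all_job_ids_by_name:
# the generated code prints encode_payload(job_ids) on the remote side, and
# the backend recovers the list with decode_payload.
from sky.utils import common_utils

job_ids = [3, 2, 1]  # e.g. what get_all_job_ids_by_name returns (descending)
payload = common_utils.encode_payload(job_ids)
assert common_utils.decode_payload(payload) == job_ids
```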
23 changes: 23 additions & 0 deletions tests/smoke_tests/test_managed_job.py
@@ -871,3 +871,26 @@ def test_managed_jobs_inline_env(generic_cloud: str):
timeout=20 * 60,
)
smoke_tests_utils.run_one_test(test)


@pytest.mark.managed_jobs
def test_managed_jobs_logs_sync_down():
name = smoke_tests_utils.get_cluster_name()
test = smoke_tests_utils.Test(
'test-managed-jobs-logs-sync-down',
[
f'sky jobs launch -n {name} -y examples/managed_job.yaml -d',
smoke_tests_utils.
get_cmd_wait_until_managed_job_status_contains_matching_job_name(
job_name=f'{name}',
job_status=[sky.ManagedJobStatus.RUNNING],
timeout=300 + smoke_tests_utils.BUMP_UP_SECONDS),
f'sky jobs logs --controller 1 --sync-down',
f'sky jobs logs 1 --sync-down',
f'sky jobs logs --controller --name minimal --sync-down',
f'sky jobs logs --name minimal --sync-down',
],
f'sky jobs cancel -y -n {name}',
timeout=20 * 60,
)
smoke_tests_utils.run_one_test(test)
