diff --git a/src/levanter/infra/cli_helpers.py b/src/levanter/infra/cli_helpers.py index eef8fa969..b92b6efb5 100644 --- a/src/levanter/infra/cli_helpers.py +++ b/src/levanter/infra/cli_helpers.py @@ -1,6 +1,7 @@ import argparse import base64 import os +import shlex import subprocess from typing import Optional @@ -64,7 +65,7 @@ def make_docker_run_command(image_id, command, *, foreground, env, name="levante "docker", "run", "-t" if foreground else "-d", - f"--name={name}", + f"--name={shlex.quote(name)}", "--privileged", "--shm-size=32gb", "--net=host", @@ -76,7 +77,9 @@ def make_docker_run_command(image_id, command, *, foreground, env, name="levante ] for k, v in env.items(): - docker_command.extend(["-e", k + f"={str(v)}"]) + v = shlex.quote(str(v)) + k = shlex.quote(str(k)) + docker_command.extend(["-e", f"{k}={v}"]) docker_command.extend([image_id, *command]) return docker_command diff --git a/src/levanter/infra/ray_tpu.py b/src/levanter/infra/ray_tpu.py index 3ae5d0105..2dc554808 100644 --- a/src/levanter/infra/ray_tpu.py +++ b/src/levanter/infra/ray_tpu.py @@ -76,6 +76,7 @@ def run_on_pod(remote_fn: RemoteFunction | Callable, tpu_type: str) -> ray.Objec @ray.remote(resources={f"TPU-{tpu_type}-head": 1}) def do_run(remote_fn) -> _TpuRunResult: + logging.basicConfig(level=logging.INFO) num_hosts = ray.util.accelerators.tpu.get_current_pod_worker_count() # -> 4 remote_fn, tpu_name = _redecorate_remote_fn_for_tpu(remote_fn, num_hosts) @@ -92,6 +93,13 @@ def do_run(remote_fn) -> _TpuRunResult: except Exception: logger.exception("Failed to kill job after primary failure") return _handle_ray_error(info, e) + except Exception as e: + for f in futures: + try: + ray.cancel(f) + except Exception: + logger.exception("Failed to kill job after primary failure") + return TpuFailed(info, e) return do_run.remote(remote_fn) @@ -144,12 +152,12 @@ def run_on_pod_resumable(remote_fn, tpu_type, max_retries_preemption=1e6, max_re out = ray.get(run_on_pod(remote_fn, tpu_type)) except ray.exceptions.RayTaskError as e: problem = e - if "preempted" in str(e): + if "preempted" in str(e).lower(): num_preemptions += 1 logger.warning(f"Preempted {num_preemptions} times, {e}") else: num_failures += 1 - logger.warning(f"Failed {num_failures} times") + logger.warning(f"Failed {num_failures} times", exc_info=e) continue except Exception as e: problem = e