Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/main' into sft
Browse files Browse the repository at this point in the history
  • Loading branch information
ahmeda14960 committed Oct 25, 2024
2 parents 18a5352 + 0f2f326 commit f1ef2c7
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 4 deletions.
7 changes: 5 additions & 2 deletions src/levanter/infra/cli_helpers.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import argparse
import base64
import os
import shlex
import subprocess
from typing import Optional

Expand Down Expand Up @@ -64,7 +65,7 @@ def make_docker_run_command(image_id, command, *, foreground, env, name="levante
"docker",
"run",
"-t" if foreground else "-d",
f"--name={name}",
f"--name={shlex.quote(name)}",
"--privileged",
"--shm-size=32gb",
"--net=host",
Expand All @@ -76,7 +77,9 @@ def make_docker_run_command(image_id, command, *, foreground, env, name="levante
]

for k, v in env.items():
docker_command.extend(["-e", k + f"={str(v)}"])
v = shlex.quote(str(v))
k = shlex.quote(str(k))
docker_command.extend(["-e", f"{k}={v}"])

docker_command.extend([image_id, *command])
return docker_command
Expand Down
12 changes: 10 additions & 2 deletions src/levanter/infra/ray_tpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ def run_on_pod(remote_fn: RemoteFunction | Callable, tpu_type: str) -> ray.Objec

@ray.remote(resources={f"TPU-{tpu_type}-head": 1})
def do_run(remote_fn) -> _TpuRunResult:
logging.basicConfig(level=logging.INFO)
num_hosts = ray.util.accelerators.tpu.get_current_pod_worker_count() # -> 4
remote_fn, tpu_name = _redecorate_remote_fn_for_tpu(remote_fn, num_hosts)

Expand All @@ -92,6 +93,13 @@ def do_run(remote_fn) -> _TpuRunResult:
except Exception:
logger.exception("Failed to kill job after primary failure")
return _handle_ray_error(info, e)
except Exception as e:
for f in futures:
try:
ray.cancel(f)
except Exception:
logger.exception("Failed to kill job after primary failure")
return TpuFailed(info, e)

return do_run.remote(remote_fn)

Expand Down Expand Up @@ -144,12 +152,12 @@ def run_on_pod_resumable(remote_fn, tpu_type, max_retries_preemption=1e6, max_re
out = ray.get(run_on_pod(remote_fn, tpu_type))
except ray.exceptions.RayTaskError as e:
problem = e
if "preempted" in str(e):
if "preempted" in str(e).lower():
num_preemptions += 1
logger.warning(f"Preempted {num_preemptions} times, {e}")
else:
num_failures += 1
logger.warning(f"Failed {num_failures} times")
logger.warning(f"Failed {num_failures} times", exc_info=e)
continue
except Exception as e:
problem = e
Expand Down

0 comments on commit f1ef2c7

Please sign in to comment.