Skip to content

Commit

Permalink
add multislice support in ray
Browse files Browse the repository at this point in the history
  • Loading branch information
blahBlahhhJ committed Oct 19, 2024
1 parent d4df28b commit b383698
Show file tree
Hide file tree
Showing 3 changed files with 244 additions and 24 deletions.
3 changes: 2 additions & 1 deletion infra/launch_on_ray.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def main():
cli.add_arg(parser, config, ["--project"], default=cli.gcloud_config()["project"])
cli.add_arg(parser, config, ["--tpu_type"], required=True)
# TODO: bring node_count to Ray
# cli.add_arg(parser, config, ["--node_count"], default=1, type=int)
cli.add_arg(parser, config, ["--node_count"], default=1, type=int)
cli.add_arg(parser, config, ["--foreground"], default=False, action="store_true")
cli.add_arg(parser, config, ["--retries"], default=10, type=int)
cli.add_arg(parser, config, ["--run_id"], default=cli.default_run_id(), type=str)
Expand Down Expand Up @@ -122,6 +122,7 @@ def main():
env=env,
name="levanter",
retries=retries,
node_count=args.node_count,
)

address = args.address or os.getenv("RAY_ADDRESS")
Expand Down
31 changes: 31 additions & 0 deletions src/levanter/infra/cli_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,37 @@ def get_git_commit():
return subprocess.check_output(["git", "rev-parse", "HEAD"]).decode("utf-8").strip()


class DockerRunCommand:
    """A ``docker run`` invocation kept as three separately editable pieces.

    The argument list is split into the fixed ``docker run`` flags (``base_part``),
    the ``-e KEY=VALUE`` environment flags (``env_part``) and the image followed by
    the command to run inside it (``cmd_part``). Keeping the environment separate
    lets callers append more variables with :meth:`add_env` after construction.
    """

    def __init__(self, image_id, command, *, foreground, env, name="levanter"):
        """
        Args:
            image_id: The docker image to run
            command: The command (sequence of arguments) to run inside the container
            foreground: If true pass ``-t`` to docker, otherwise ``-d``
            env: Mapping of environment variables to pass with ``-e``
            name: Value used for docker's ``--name`` flag
        """
        mode_flag = "-d"
        if foreground:
            mode_flag = "-t"

        self.base_part = ["docker", "run", mode_flag, f"--name={name}"]
        self.base_part += ["--privileged", "--shm-size=32gb", "--net=host", "--init"]
        self.base_part += ["--mount", "type=volume,source=levanter,target=/home/levanter"]
        self.base_part += ["-v", "/tmp:/tmp"]

        self.env_part = []
        self.add_env(env)

        self.cmd_part = [image_id] + list(command)

    def add_env(self, env):
        """Append one ``-e KEY=VALUE`` pair to the command for every item in ``env``."""
        for key, value in env.items():
            self.env_part += ["-e", key + "=" + str(value)]

    @property
    def full_cmd(self):
        """The complete argument list: base flags, then env flags, then image and command."""
        return [*self.base_part, *self.env_part, *self.cmd_part]


def make_docker_run_command(image_id, command, *, foreground, env, name="levanter"):
docker_command = [
"docker",
Expand Down
234 changes: 211 additions & 23 deletions src/levanter/infra/ray_tpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import logging
import multiprocessing
import os
import socket
import subprocess
import tempfile
import time
Expand All @@ -16,7 +17,7 @@
from ray.exceptions import NodeDiedError, RayError, RaySystemError, RayTaskError, WorkerCrashedError
from ray.remote_function import RemoteFunction

from levanter.infra.cli_helpers import make_docker_run_command
from levanter.infra.cli_helpers import DockerRunCommand
from levanter.utils.ray_utils import ser_exc_info


Expand Down Expand Up @@ -62,21 +63,28 @@ class TpuRunError(_TpuRunResult):
error: Exception


def run_on_pod(remote_fn: RemoteFunction | Callable, tpu_type: str) -> ray.ObjectRef:
def run_on_pod(docker_cmd: DockerRunCommand, name: str, tpu_type: str) -> ray.ObjectRef:
"""
Run a remote function on a TPU pod.
Args:
remote_fn: A remote function that takes no arguments
docker_cmd: A DockerRunCommand object that holds a docker command to run
name: docker image name
tpu_type: The type of TPU to run on, e.g. "v4-32"
Returns:
A Ray ObjectRef that represents the result of the function
"""

@ray.remote(resources={f"TPU-{tpu_type}-head": 1})
def do_run(remote_fn) -> _TpuRunResult:
def do_run(docker_cmd: DockerRunCommand, name: str) -> _TpuRunResult:
num_hosts = ray.util.accelerators.tpu.get_current_pod_worker_count() # -> 4

def _run_docker():
run_docker(docker_cmd=docker_cmd.full_cmd, name=name)

remote_fn = ray.remote(_run_docker)

remote_fn, tpu_name = _redecorate_remote_fn_for_tpu(remote_fn, num_hosts)

info = _TpuInfo(tpu_name, "ACTIVE", "TPU")
Expand All @@ -93,10 +101,87 @@ def do_run(remote_fn) -> _TpuRunResult:
logger.exception("Failed to kill job after primary failure")
return _handle_ray_error(info, e)

return do_run.remote(remote_fn)
return do_run.remote(docker_cmd, name)


def run_on_pod_multislice(
    docker_cmd: DockerRunCommand, name: str, tpu_type: str, num_slices: int
) -> list[ray.ObjectRef]:
    """
    Run a docker command on multiple TPU slices.

    One actor is created per slice, each reserving a ``TPU-{tpu_type}-head`` resource.
    The actors first report their pod name, host count and IP; the IP of the first
    actor is used as the megascale coordinator address for every slice. Each actor then
    launches the docker command on all hosts of its slice.

    Args:
        docker_cmd: A DockerRunCommand object that holds a docker command to run
        name: docker container name
        tpu_type: The type of TPU to run on, e.g. "v4-32"
        num_slices: The number of slices to run

    Returns:
        A list of Ray ObjectRefs, one per slice, each resolving to a _TpuRunResult.
        If gathering the slice infos fails, a single-element list whose ObjectRef
        resolves to the corresponding error result.

    Raises:
        ValueError: If ``num_slices`` is less than 1.
    """
    if num_slices < 1:
        # Without any slice there is no coordinator to pick below.
        raise ValueError(f"num_slices must be at least 1, got {num_slices}")

    @ray.remote(resources={f"TPU-{tpu_type}-head": 1})
    class MultisliceActor:
        def __init__(self):
            self.pod_name = ray.util.accelerators.tpu.get_current_pod_name()
            self.num_hosts = ray.util.accelerators.tpu.get_current_pod_worker_count()
            self.ip = socket.gethostbyname(socket.gethostname())

        def get_slice_info(self):
            return self.pod_name, self.num_hosts, self.ip

        def do_run(self, docker_cmd, name, coordinator_ip, slice_id, num_slices) -> _TpuRunResult:
            port = 8081
            mxla_env = {
                "MEGASCALE_COORDINATOR_ADDRESS": f"{coordinator_ip}:{port}",
                "MEGASCALE_NUM_SLICES": str(num_slices),
                "MEGASCALE_PORT": f"{port}",
                "MEGASCALE_SLICE_ID": str(slice_id),
            }

            # The megascale settings are passed both into the container (-e flags) and
            # to the Ray tasks that launch it (runtime env, via env_vars below).
            docker_cmd.add_env(mxla_env)

            def _run_docker():
                run_docker(docker_cmd=docker_cmd.full_cmd, name=name)

            remote_fn = ray.remote(_run_docker)

            remote_fn, tpu_name = _redecorate_remote_fn_for_tpu(remote_fn, self.num_hosts, env_vars=mxla_env)

            info = _TpuInfo(tpu_name, "ACTIVE", "TPU")
            futures = [remote_fn.remote() for _ in range(self.num_hosts)]
            try:
                out = ray.get(futures)
                logger.info("TPU job finished")
                return TpuSuccess(info, out)
            except RayError as e:
                for f in futures:
                    try:
                        ray.cancel(f)
                    except Exception:
                        logger.exception("Failed to kill job after primary failure")
                return _handle_ray_error(info, e)

    actors = [MultisliceActor.remote() for _ in range(num_slices)]  # type: ignore
    info = _TpuInfo("get_slice_info", "ACTIVE", "TPU")
    futures = [actor.get_slice_info.remote() for actor in actors]
    try:
        logger.info("Getting slice infos...")
        # also act as a sync step
        slice_infos = ray.get(futures)
        logger.info(f"TPU slice infos {slice_infos}")
    except RayError as e:
        for actor in actors:
            try:
                # ray.cancel only accepts ObjectRefs; actors are terminated with ray.kill.
                ray.kill(actor)
            except Exception:
                logger.exception("Failed to kill actor after primary failure")
        # Callers pass the returned list to ray.get, which only accepts ObjectRefs,
        # so the error result is stored in the object store first.
        return [ray.put(_handle_ray_error(info, e))]

    # Each slice info is (pod_name, num_hosts, ip); slice 0 acts as the coordinator.
    coordinator_ip = slice_infos[0][2]

    return [actor.do_run.remote(docker_cmd, name, coordinator_ip, i, num_slices) for i, actor in enumerate(actors)]


def _redecorate_remote_fn_for_tpu(remote_fn, num_hosts, **runtime_env):
"""
Redecorate a remote function to run on a TPU pod.
Expand All @@ -112,17 +197,21 @@ def _redecorate_remote_fn_for_tpu(remote_fn, num_hosts):

tpu_name = ray.util.accelerators.tpu.get_current_pod_name() # -> my-tpu
num_tpus_per_host = TPUAcceleratorManager.get_current_node_num_accelerators() # -> 8
remote_fn = remote_fn.options(resources={tpu_name: 1, "TPU": num_tpus_per_host})
remote_fn = remote_fn.options(
runtime_env=runtime_env,
resources={tpu_name: 1, "TPU": num_tpus_per_host},
)
logger.info(f"Running on TPU {tpu_name} with {num_hosts} hosts and {num_tpus_per_host} TPUs per host")
return remote_fn, tpu_name


def run_on_pod_resumable(remote_fn, tpu_type, max_retries_preemption=1e6, max_retries_failure=10):
def run_on_pod_resumable(docker_cmd, name, tpu_type, max_retries_preemption=1e6, max_retries_failure=10):
"""
Repeatedly run a function on a TPU pod until it succeeds or a maximum number of retries is reached.
Args:
remote_fn: A remote function that takes no arguments
docker_cmd: A DockerRunCommand object that holds a docker command to run
name: docker image name
tpu_type: The type of TPU to run on, e.g. "v4-32"
max_retries_preemption: The maximum number of times to retry if the job is preempted
max_retries_failure: The maximum number of times to retry if the job fails
Expand All @@ -141,7 +230,7 @@ def run_on_pod_resumable(remote_fn, tpu_type, max_retries_preemption=1e6, max_re
attempt += 1
problem = None
try:
out = ray.get(run_on_pod(remote_fn, tpu_type))
out = ray.get(run_on_pod(docker_cmd, name, tpu_type))
except ray.exceptions.RayTaskError as e:
problem = e
if "preempted" in str(e):
Expand Down Expand Up @@ -185,26 +274,123 @@ def run_on_pod_resumable(remote_fn, tpu_type, max_retries_preemption=1e6, max_re
raise RuntimeError("Failed too many times") from problem


def run_on_pod_multislice_resumable(
    docker_cmd, name, tpu_type, num_slices, max_retries_preemption=1e6, max_retries_failure=10
):
    """
    Repeatedly run a docker command on multiple TPU slices until every slice succeeds or a
    maximum number of retries is reached.

    Preemptions (and TPU node failures, which are treated as preemptions) and ordinary
    failures are counted against separate budgets.

    Args:
        docker_cmd: A DockerRunCommand object that holds a docker command to run
        name: docker container name
        tpu_type: The type of TPU to run on, e.g. "v4-32"
        num_slices: The number of slices to run
        max_retries_preemption: The maximum number of times to retry if the job is preempted
        max_retries_failure: The maximum number of times to retry if the job fails

    Returns:
        A list with the result of each slice (not ObjectRefs)

    Raises:
        RuntimeError: If the job was preempted or failed too many times, or if a slice
            returned a result of an unexpected type.
    """
    num_failures = 0
    num_preemptions = 0
    attempt = 0
    problem: Exception | None = None

    while num_failures < max_retries_failure and num_preemptions < max_retries_preemption:
        logger.info(f"Running on TPU {tpu_type}. Attempt {attempt}")
        attempt += 1
        problem = None
        try:
            outs = ray.get(run_on_pod_multislice(docker_cmd, name, tpu_type, num_slices))
        except ray.exceptions.RayTaskError as e:
            problem = e
            if "preempted" in str(e):
                num_preemptions += 1
                logger.warning(f"Preempted {num_preemptions} times, {e}")
            else:
                num_failures += 1
                logger.warning(f"Failed {num_failures} times")
            continue
        except Exception as e:
            problem = e
            num_failures += 1
            if num_failures >= max_retries_failure:
                logger.exception("Failed too many times", exc_info=e)
                raise
            logger.warning(f"Failed {num_failures} times", exc_info=e)
            continue

        if all(isinstance(out, TpuSuccess) for out in outs):
            logger.info("Success")
            return [out.result for out in outs]

        # Classify the attempt by the most significant non-success result. The first
        # matching slice result supplies the error that is reported.
        preempted = next((o for o in outs if isinstance(o, TpuPreempted)), None)
        run_error = next((o for o in outs if isinstance(o, TpuRunError)), None)

        if preempted is not None:
            problem = preempted.error
            num_preemptions += 1
            logger.warning(f"Preempted {num_preemptions} times. {problem}", exc_info=problem)
        elif any(isinstance(out, TpuFailed) for out in outs):
            num_preemptions += 1
            logger.warning(f"TPU node failure. Treating as preempted: {num_preemptions} times")
        elif run_error is not None:
            # A run error is an ordinary failure: it only counts against the failure budget.
            problem = run_error.error
            num_failures += 1
            logger.warning(f"Failed {num_failures} times", exc_info=problem)
        else:
            raise RuntimeError(f"Unexpected result: {outs}")

    if num_preemptions >= max_retries_preemption:
        raise RuntimeError("Preempted too many times") from problem
    elif num_failures >= max_retries_failure:
        raise RuntimeError("Failed too many times") from problem


def _run_command(*args, **kwargs):
    """Run ``args`` as one command via ``subprocess.check_call``, forwarding ``kwargs``.

    Returns whatever ``check_call`` returns; a non-zero exit status surfaces as
    ``subprocess.CalledProcessError``.
    """
    argv = args
    return subprocess.check_call(argv, **kwargs)


def run_docker_on_pod(image_id: str, command: Sequence[str], *, tpu_type: str, env: dict, name="levanter", retries=10):
env = _massage_env(env)
def run_docker(docker_cmd, name="levanter"):
    """
    Run a docker command, given as a sequence of arguments, on the current host.

    ``_kill_old_container(name)`` is called first (presumably to clear a leftover
    container with the same name; the helper is not shown here).

    Args:
        docker_cmd: The full docker command as a sequence of arguments
        name: docker container name handed to ``_kill_old_container``

    Returns:
        The return value of ``_run_command``

    Raises:
        subprocess.CalledProcessError: If the docker command fails; it is logged and re-raised.
    """
    _kill_old_container(name)
    try:
        exit_status = _run_command(*docker_cmd)
    except subprocess.CalledProcessError as err:
        logger.exception("Failed to run docker command")
        raise err
    return exit_status

docker_cmd = make_docker_run_command(image_id, command, env=env, foreground=True, name=name)

def run_docker():
_kill_old_container(name)
try:
return _run_command(*docker_cmd)
except subprocess.CalledProcessError as e:
logger.exception("Failed to run docker command")
raise e
def run_docker_on_pod(
    image_id: str,
    command: Sequence[str],
    *,
    tpu_type: str,
    num_slices: int = 1,
    env: dict,
    name="levanter",
    retries=10,
):
    """
    Run a docker image on a TPU pod, or on several TPU slices, retrying on preemption and failure.

    Args:
        image_id: The docker image to run
        command: The command to run inside the container
        tpu_type: The type of TPU to run on, e.g. "v4-32"
        num_slices: The number of slices to run on. Defaults to 1 (single pod) so that
            callers that do not pass it keep working.
        env: Environment variables for the container
        name: docker container name
        retries: The maximum number of times to retry if the job fails

    Raises:
        ValueError: If ``num_slices`` is less than 1.
    """
    if num_slices < 1:
        raise ValueError(f"num_slices must be at least 1, got {num_slices}")

    env = _massage_env(env)

    docker_cmd = DockerRunCommand(image_id, command, env=env, foreground=True, name=name)

    if num_slices == 1:
        run_on_pod_resumable(
            docker_cmd, name=name, tpu_type=tpu_type, max_retries_failure=retries, max_retries_preemption=10000
        )
    else:
        run_on_pod_multislice_resumable(
            docker_cmd,
            name=name,
            tpu_type=tpu_type,
            num_slices=num_slices,
            max_retries_failure=retries,
            max_retries_preemption=10000,
        )


def _kill_old_container(name):
Expand Down Expand Up @@ -343,6 +529,7 @@ class RunDockerOnPodConfig:
env: dict = dataclasses.field(default_factory=dict)
name: str = "levanter"
retries: int = 10
node_count: int = 1


def submit_tpu_job_on_ray(config: RunDockerOnPodConfig, ray_address: str, run_id: Optional[str] = None):
Expand Down Expand Up @@ -411,6 +598,7 @@ def main(args: RunDockerOnPodConfig):
tpu_type=args.tpu_type,
env=args.env,
name=args.name,
num_slices=args.node_count,
)


Expand Down

0 comments on commit b383698

Please sign in to comment.