add multislice support in ray (#771)
blahBlahhhJ authored Oct 30, 2024
1 parent 331c0aa commit 5ebf8ce
Showing 4 changed files with 244 additions and 10 deletions.
45 changes: 42 additions & 3 deletions infra/cluster/job-cluster.yaml
@@ -14,8 +14,8 @@ cluster_name: levanter-cluster
# Configure GCP
provider:
  type: gcp
  region: us-central2
  availability_zone: us-central2-b
  region: us-west4
  availability_zone: us-west4-a
  project_id: hai-gcp-models

# Maximum Workers (excluding Head Node)
@@ -126,6 +126,45 @@ available_node_types:
    schedulingConfig:
      preemptible: true

  tpu_slice_v5e_16:
    min_workers: 0
    max_workers: 1024
    resources: { "CPU": 120, "TPU": 4 }

    node_config:
      acceleratorType: v5litepod-16
      runtimeVersion: tpu-ubuntu2204-base

    # [IMPORTANT] Configure all TPU Workers to be Preemptible!
    schedulingConfig:
      preemptible: true

  tpu_slice_v5e_64:
    min_workers: 0
    max_workers: 1024
    resources: { "CPU": 120, "TPU": 4 }

    node_config:
      acceleratorType: v5litepod-64
      runtimeVersion: tpu-ubuntu2204-base

    # [IMPORTANT] Configure all TPU Workers to be Preemptible!
    schedulingConfig:
      preemptible: true

  tpu_slice_v5e_256:
    min_workers: 0
    max_workers: 1024
    resources: { "CPU": 120, "TPU": 4 }

    node_config:
      acceleratorType: v5litepod-256
      runtimeVersion: tpu-ubuntu2204-base

    # [IMPORTANT] Configure all TPU Workers to be Preemptible!
    schedulingConfig:
      preemptible: true

docker:
image: "ghcr.io/stanford-crfm/levanter-cluster:latest"
container_name: "ray_docker"
@@ -140,7 +179,7 @@ docker:
- -v "/var/run/docker.sock:/var/run/docker.sock"

initialization_commands:
  - yes | gcloud auth configure-docker us-central2-docker.pkg.dev
  - yes | gcloud auth configure-docker us-west4-docker.pkg.dev
  - "export TPU_WORKER_ID=$(curl -H 'Metadata-Flavor: Google' http://metadata.google.internal/computeMetadata/v1/instance/attributes/agent-worker-number) || true"
  - which docker || (curl -fsSL https://get.docker.com -o get-docker.sh; sudo sh get-docker.sh; sudo usermod -aG docker $USER; sudo systemctl restart docker -f)
  # always run this because ray doesn't run with sudo
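Each new v5e node type advertises the per-worker resources ("CPU": 120, "TPU": 4), and Ray's TPU integration additionally exposes a TPU-<accelerator_type>-head marker on the first host of each slice, which is the resource the multislice code in this commit requests. A minimal sketch, not part of the diff, of pinning a task to one of these workers, assuming a cluster launched from this YAML:

import ray

ray.init(address="auto")  # connect to the running cluster brought up from job-cluster.yaml

# Requesting the head resource for a slice type schedules the task onto the
# first host of an available v5litepod-16 slice (hypothetical usage).
@ray.remote(resources={"TPU-v5litepod-16-head": 1})
def where_am_i():
    import socket
    return socket.gethostname()

print(ray.get(where_am_i.remote()))
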
3 changes: 2 additions & 1 deletion infra/launch_on_ray.py
@@ -27,7 +27,7 @@ def main():
cli.add_arg(parser, config, ["--project"], default=cli.gcloud_config()["project"])
cli.add_arg(parser, config, ["--tpu_type"], required=True)
# TODO: bring node_count to Ray
# cli.add_arg(parser, config, ["--node_count"], default=1, type=int)
cli.add_arg(parser, config, ["--node_count"], default=1, type=int)
cli.add_arg(parser, config, ["--foreground"], default=False, action="store_true")
cli.add_arg(parser, config, ["--retries"], default=10, type=int)
cli.add_arg(parser, config, ["--run_id"], default=cli.default_run_id(), type=str)
@@ -122,6 +122,7 @@ def main():
        env=env,
        name="levanter",
        retries=retries,
        node_count=args.node_count,
    )

    address = args.address or os.getenv("RAY_ADDRESS")
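The re-enabled --node_count flag is plain CLI plumbing: it lands on the launch config and is later forwarded as num_slices on the Ray side. A minimal stand-in sketch (argparse is used directly here only for illustration; the real script goes through cli.add_arg and levanter's config machinery):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--tpu_type", required=True)
parser.add_argument("--node_count", default=1, type=int)

args = parser.parse_args(["--tpu_type", "v5litepod-16", "--node_count", "2"])
# args.node_count ends up on RunDockerOnPodConfig.node_count and is forwarded
# to run_docker_on_pod(..., num_slices=...) by the Ray job entry point.
print(args.node_count)  # -> 2
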
5 changes: 5 additions & 0 deletions src/levanter/infra/cli_helpers.py
@@ -76,6 +76,11 @@ def make_docker_run_command(image_id, command, *, foreground, env, name="levante
"/tmp:/tmp",
]

# optionally add multislice env vars (if set by ray runtime env vars)
for v in ["MEGASCALE_COORDINATOR_ADDRESS", "MEGASCALE_NUM_SLICES", "MEGASCALE_PORT", "MEGASCALE_SLICE_ID"]:
v = shlex.quote(str(v))
docker_command.extend(["-e", v])

for k, v in env.items():
v = shlex.quote(str(v))
k = shlex.quote(str(k))
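Passing "-e NAME" with no value tells docker run to forward NAME from the launching environment into the container, so the MEGASCALE settings injected via Ray's runtime_env reach the training process without being serialized into the command line. A rough sketch, not part of the diff, of what this loop contributes to the docker invocation:

import shlex

docker_command = ["docker", "run", "-t", "--name", "levanter"]  # abbreviated

# Only the variable names are added; docker resolves their values from the
# environment of the process that invokes "docker run".
for v in ["MEGASCALE_COORDINATOR_ADDRESS", "MEGASCALE_NUM_SLICES", "MEGASCALE_PORT", "MEGASCALE_SLICE_ID"]:
    docker_command.extend(["-e", shlex.quote(v)])

print(" ".join(docker_command))
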
201 changes: 195 additions & 6 deletions src/levanter/infra/ray_tpu.py
@@ -3,6 +3,7 @@
import logging
import multiprocessing
import os
import socket
import subprocess
import tempfile
import time
@@ -104,7 +105,83 @@ def do_run(remote_fn) -> _TpuRunResult:
    return do_run.remote(remote_fn)


def _redecorate_remote_fn_for_tpu(remote_fn, num_hosts):
def run_on_pod_multislice(remote_fn: RemoteFunction | Callable, tpu_type: str, num_slices: int) -> ray.ObjectRef:
    """
    Run a remote function on multiple TPU slices.

    Args:
        remote_fn: A remote function that takes no arguments
        tpu_type: The type of TPU to run on, e.g. "v4-32"
        num_slices: The number of slices to run

    Returns:
        A Ray ObjectRef that represents the result of the function
    """

    @ray.remote(resources={f"TPU-{tpu_type}-head": 1})
    class MultisliceActor:
        def __init__(self):
            self.pod_name = ray.util.accelerators.tpu.get_current_pod_name()
            self.num_hosts = ray.util.accelerators.tpu.get_current_pod_worker_count()
            self.ip = socket.gethostbyname(socket.gethostname())

        def get_slice_info(self):
            return self.pod_name, self.num_hosts, self.ip

        def do_run(self, remote_fn, coordinator_ip, slice_id, num_slices) -> _TpuRunResult:
            port = 8081
            mxla_env = {
                "MEGASCALE_COORDINATOR_ADDRESS": f"{coordinator_ip}:{port}",
                "MEGASCALE_NUM_SLICES": str(num_slices),
                "MEGASCALE_PORT": f"{port}",
                "MEGASCALE_SLICE_ID": str(slice_id),
            }

            remote_fn, tpu_name = _redecorate_remote_fn_for_tpu(remote_fn, self.num_hosts, env_vars=mxla_env)

            info = _TpuInfo(tpu_name, "ACTIVE", "TPU")
            futures = [remote_fn.remote() for _ in range(self.num_hosts)]
            try:
                out = ray.get(futures)
                logger.info("TPU job finished")
                return TpuSuccess(info, out)
            except RayError as e:
                for f in futures:
                    try:
                        ray.cancel(f)
                    except Exception:
                        logger.exception("Failed to kill job after primary failure")
                return _handle_ray_error(info, e)
            except Exception as e:
                for f in futures:
                    try:
                        ray.cancel(f)
                    except Exception:
                        logger.exception("Failed to kill job after primary failure")
                return TpuFailed(info, e)

    actors = [MultisliceActor.remote() for _ in range(num_slices)]  # type: ignore
    futures = [actor.get_slice_info.remote() for actor in actors]
    try:
        logger.info("Getting slice infos...")
        # also act as a sync step
        slice_infos = ray.get(futures)
        logger.info(f"TPU slice infos {slice_infos}")
    except RayError as e:
        logger.exception(e)
        for actor in actors:
            try:
                ray.cancel(actor)
            except Exception:
                logger.exception("Failed to kill actor after primary failure")
        return futures

    coordinator_ip = slice_infos[0][2]

    return [actor.do_run.remote(remote_fn, coordinator_ip, i, num_slices) for i, actor in enumerate(actors)]
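For illustration, in a two-slice run where slice 0's host resolves to 10.0.0.5 (a made-up address), each slice's hosts would receive environment along these lines; the coordinator is always slice 0's IP on port 8081:

# Hypothetical values for a 2-slice run.
slice_0_env = {
    "MEGASCALE_COORDINATOR_ADDRESS": "10.0.0.5:8081",
    "MEGASCALE_NUM_SLICES": "2",
    "MEGASCALE_PORT": "8081",
    "MEGASCALE_SLICE_ID": "0",
}
slice_1_env = {**slice_0_env, "MEGASCALE_SLICE_ID": "1"}
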


def _redecorate_remote_fn_for_tpu(remote_fn, num_hosts, **runtime_env):
"""
Redecorate a remote function to run on a TPU pod.
@@ -120,7 +197,11 @@ def _redecorate_remote_fn_for_tpu(remote_fn, num_hosts):

    tpu_name = ray.util.accelerators.tpu.get_current_pod_name()  # -> my-tpu
    num_tpus_per_host = TPUAcceleratorManager.get_current_node_num_accelerators()  # -> 8
    remote_fn = remote_fn.options(resources={tpu_name: 1, "TPU": num_tpus_per_host})

    remote_fn = remote_fn.options(
        runtime_env=runtime_env,
        resources={tpu_name: 1, "TPU": num_tpus_per_host},
    )
    logger.info(f"Running on TPU {tpu_name} with {num_hosts} hosts and {num_tpus_per_host} TPUs per host")
    return remote_fn, tpu_name

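Because the keyword arguments are gathered into runtime_env and handed to .options() unchanged, the multislice caller above (which passes env_vars=mxla_env) ends up using Ray's standard runtime_env env_vars mechanism. A self-contained sketch of that pattern with placeholder values, not part of the diff:

import os
import ray

@ray.remote
def report_slice_id():
    # Visible because runtime_env env_vars set it for this worker process.
    return os.environ.get("MEGASCALE_SLICE_ID")

ray.init()
configured = report_slice_id.options(
    runtime_env={"env_vars": {"MEGASCALE_SLICE_ID": "0", "MEGASCALE_NUM_SLICES": "1"}},
)
print(ray.get(configured.remote()))  # -> "0"
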
@@ -193,11 +274,107 @@ def run_on_pod_resumable(remote_fn, tpu_type, max_retries_preemption=1e6, max_re
raise RuntimeError("Failed too many times") from problem


def run_on_pod_multislice_resumable(
    remote_fn, tpu_type, num_slices, max_retries_preemption=1e6, max_retries_failure=10
):
    """
    Repeatedly run a function on a TPU pod until it succeeds or a maximum number of retries is reached.

    Args:
        remote_fn: A remote function that takes no arguments
        tpu_type: The type of TPU to run on, e.g. "v4-32"
        num_slices: The number of slices to run
        max_retries_preemption: The maximum number of times to retry if the job is preempted
        max_retries_failure: The maximum number of times to retry if the job fails

    Returns:
        The result of the function (not an ObjectRef)
    """
    num_failures = 0
    num_preemptions = 0
    attempt = 0
    problem: Exception | None = None

    while num_failures < max_retries_failure and num_preemptions < max_retries_preemption:
        logger.info(f"Running on TPU {tpu_type}. Attempt {attempt}")
        attempt += 1
        problem = None
        futures = run_on_pod_multislice(remote_fn, tpu_type, num_slices)
        try:
            outs = ray.get(futures)
        except ray.exceptions.RayTaskError as e:
            for f in futures:
                try:
                    ray.cancel(f)
                except Exception:
                    logger.exception("Failed to kill job after primary failure")
            problem = e
            if "preempted" in str(e).lower():
                num_preemptions += 1
                logger.warning(f"Preempted {num_preemptions} times, {e}")
            else:
                num_failures += 1
                logger.warning(f"Failed {num_failures} times", exc_info=e)
            continue
        except Exception as e:
            for f in futures:
                try:
                    ray.cancel(f)
                except Exception:
                    logger.exception("Failed to kill job after primary failure")
            problem = e
            num_failures += 1
            if num_failures >= max_retries_failure:
                logger.exception("Failed too many times", exc_info=e)
                raise e
            else:
                logger.warning(f"Failed {num_failures} times", exc_info=e)
                continue

        if all(isinstance(out, TpuSuccess) for out in outs):
            results = [out.result for out in outs]
            logger.info("Success")
            return results
        elif any(isinstance(out, TpuPreempted) for out in outs):
            out = None
            for o in outs:
                if isinstance(o, TpuPreempted):
                    out = o
            assert out is not None
            problem = out.error
            num_preemptions += 1
            logger.warning(f"Preempted {num_preemptions} times. {problem}", exc_info=problem)
        elif any(isinstance(out, TpuFailed) for out in outs):
            num_preemptions += 1
            logger.warning(f"TPU node failure. Treating as preempted: {num_preemptions} times")
        elif any(isinstance(out, TpuRunError) for out in outs):
            out = None
            for o in outs:
                if isinstance(o, TpuRunError):
                    out = o
            assert out is not None
            problem = out.error
            num_preemptions += 1
            problem = out.error
            num_failures += 1
            logger.warning(f"Failed {num_failures} times", exc_info=problem)
        else:
            raise RuntimeError(f"Unexpected result: {out}")

    if num_preemptions >= max_retries_preemption:
        raise RuntimeError("Preempted too many times") from problem
    elif num_failures >= max_retries_failure:
        raise RuntimeError("Failed too many times") from problem

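A minimal sketch of driving the new resumable multislice entry point directly, with a hypothetical training function and sizes (the docker-based path below is what the launcher actually uses):

import ray
from levanter.infra.ray_tpu import run_on_pod_multislice_resumable

@ray.remote
def train():
    return "ok"  # placeholder for the per-host training body

ray.init(address="auto")
# Preemptions are retried essentially forever, hard failures up to 10 times;
# on success this returns one result per slice.
results = run_on_pod_multislice_resumable(train, tpu_type="v5litepod-16", num_slices=2)
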

def _run_command(*args, **kwargs):
    return subprocess.check_call(args, **kwargs)


def run_docker_on_pod(image_id: str, command: Sequence[str], *, tpu_type: str, env: dict, name="levanter", retries=10):
def run_docker_on_pod(
    image_id: str, command: Sequence[str], *, tpu_type: str, num_slices: int, env: dict, name="levanter", retries=10
):
    env = _massage_env(env)

    docker_cmd = make_docker_run_command(image_id, command, env=env, foreground=True, name=name)
@@ -210,9 +387,18 @@ def run_docker():
logger.exception("Failed to run docker command")
raise e

    run_on_pod_resumable(
        ray.remote(run_docker), tpu_type=tpu_type, max_retries_failure=retries, max_retries_preemption=10000
    )
    if num_slices == 1:
        run_on_pod_resumable(
            ray.remote(run_docker), tpu_type=tpu_type, max_retries_failure=retries, max_retries_preemption=10000
        )
    else:
        run_on_pod_multislice_resumable(
            ray.remote(run_docker),
            tpu_type=tpu_type,
            num_slices=num_slices,
            max_retries_failure=retries,
            max_retries_preemption=10000,
        )

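Under this dispatch a single-slice job behaves exactly as before, while num_slices > 1 routes through the multislice path. A usage sketch with placeholder image, command, and env (not part of the diff):

from levanter.infra.ray_tpu import run_docker_on_pod

run_docker_on_pod(
    image_id="ghcr.io/stanford-crfm/levanter-cluster:latest",  # placeholder image
    command=["python", "-m", "levanter.main.train_lm"],        # placeholder command
    tpu_type="v5litepod-16",
    num_slices=2,               # >1 selects run_on_pod_multislice_resumable
    env={"WANDB_MODE": "offline"},                             # hypothetical env
    retries=10,
)
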

def _kill_old_container(name):
@@ -351,6 +537,7 @@ class RunDockerOnPodConfig:
    env: dict = dataclasses.field(default_factory=dict)
    name: str = "levanter"
    retries: int = 10
    node_count: int = 1

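The new node_count field is what the Ray job entry point turns into num_slices. A sketch of submitting a two-slice job with this config; the image, command, and dashboard address are placeholders, and the image_id/command fields are assumed from the surrounding code rather than shown in this hunk:

from levanter.infra.ray_tpu import RunDockerOnPodConfig, submit_tpu_job_on_ray

config = RunDockerOnPodConfig(
    image_id="us-west4-docker.pkg.dev/hai-gcp-models/levanter/train:latest",  # hypothetical
    command=["python", "-m", "levanter.main.train_lm"],                       # hypothetical
    tpu_type="v5litepod-64",
    env={"RUN_ID": "demo-run"},
    name="levanter",
    retries=10,
    node_count=2,  # new field: number of TPU slices
)
submit_tpu_job_on_ray(config, ray_address="http://127.0.0.1:8265")
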

def submit_tpu_job_on_ray(config: RunDockerOnPodConfig, ray_address: str, run_id: Optional[str] = None):
@@ -419,6 +606,8 @@ def main(args: RunDockerOnPodConfig):
        tpu_type=args.tpu_type,
        env=args.env,
        name=args.name,
        retries=args.retries,
        num_slices=args.node_count,
    )


