ray fault tolerance #1032 (Draft)
wants to merge 2 commits into base: main
2 changes: 2 additions & 0 deletions MaxText/checkpointing.py
@@ -47,6 +47,7 @@ def create_orbax_checkpoint_manager(
enable_checkpointing: bool,
use_async: bool,
save_interval_steps: int,
max_to_keep: Optional[int] = None,
dataset_type: Optional[str] = "tfds",
orbax_logger: Optional[abstract_logger.AbstractLogger] = None,
use_ocdbt: bool = True,
@@ -77,6 +78,7 @@ def create_orbax_checkpoint_manager(
create=True,
save_interval_steps=save_interval_steps,
enable_async_checkpointing=use_async,
max_to_keep=max_to_keep,
),
logger=orbax_logger,
)
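The new max_to_keep argument feeds Orbax's CheckpointManagerOptions.max_to_keep, which prunes older checkpoints once the limit is exceeded (None keeps everything). A minimal sketch of a call site, assuming the checkpoint directory remains the first positional argument and that the new num_checkpoints_to_keep key from base.yml (below) is the value threaded through; the actual call-site wiring is not part of this diff:

checkpoint_manager = checkpointing.create_orbax_checkpoint_manager(
    config.checkpoint_dir,                        # existing MaxText argument (assumed)
    enable_checkpointing=config.enable_checkpointing,
    use_async=config.async_checkpointing,
    save_interval_steps=config.checkpoint_period,
    max_to_keep=config.num_checkpoints_to_keep,   # new: None keeps every checkpoint
)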
10 changes: 10 additions & 0 deletions MaxText/configs/base.yml
@@ -41,6 +41,7 @@ load_full_state_path: ""
enable_checkpointing: True
async_checkpointing: True
checkpoint_period: 10_000
num_checkpoints_to_keep: 5
# enables one replica to read the ckpt then broadcast to the rest
enable_single_replica_ckpt_restoring: False

@@ -464,3 +465,12 @@ ragged_block_size: 256
sa_block_q: 512
sa_block_q_dkv: 512
sa_block_q_dq: 512

# Ray
use_ray: False
failure_sim_time: 300
crash_prob: 0.5
hang_prob: 0.5

# Logging
log_hps: False
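
The new keys gate the Ray fault-tolerance path: use_ray switches launching over to Ray, while failure_sim_time, crash_prob, and hang_prob appear to drive failure injection in the training entrypoint (presumably MaxText/resilient_train.py, which the launcher below submits but which is not shown in this diff). Like any MaxText key they can be overridden on the command line; an illustrative invocation, with run_name and values chosen as placeholders:

python3 MaxText/launch_ray_maxtext.py MaxText/configs/base.yml run_name=ray_ft_test \
    use_ray=True failure_sim_time=600 crash_prob=0.25 hang_prob=0.25 log_hps=True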
37 changes: 37 additions & 0 deletions MaxText/launch_ray_maxtext.py
@@ -0,0 +1,37 @@
import time
import os
from absl import app
from typing import Sequence
from ray.job_submission import JobSubmissionClient, JobStatus

def main(argv: Sequence[str]) -> None:
client = JobSubmissionClient("http://127.0.0.1:8265")
print("Connected to head!", flush=True)

maxtext_cmd_args = " ".join(argv[1:])
job_id = client.submit_job(
entrypoint=f"RAY_DEDUP_LOGS=0 python3 MaxText/resilient_train.py {maxtext_cmd_args}",
runtime_env={"working_dir" : "./",
"excludes" : ["MaxText/test_assets", ".git"]}
)

print(f"Launched job: {job_id}", flush=True)
prev_logs = ''
while True:
status = client.get_job_status(job_id)
if status in {JobStatus.SUCCEEDED, JobStatus.STOPPED, JobStatus.FAILED}:
if status in {JobStatus.STOPPED, JobStatus.FAILED}:
logs = client.get_job_logs(job_id)
print(logs, flush=True)
break
time.sleep(5)
if status == JobStatus.RUNNING:
logs = client.get_job_logs(job_id)
print(logs[len(prev_logs):], flush=True)
prev_logs = logs


if __name__ == "__main__":
app.run(main)
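
The launcher connects to the Ray job-submission server on the default dashboard port (8265), submits resilient_train.py as a Ray job with the repository root as the working directory, then polls every five seconds, streaming only new log output while the job is RUNNING and dumping the full log if it ends as STOPPED or FAILED. A hedged setup sketch; addresses, ports, and resource counts are illustrative, and NGPUS, GPUS_PER_NODE, and REDIS_ADDR must be visible to the submitted job because ray_cluster.py (below) reads them from the environment:

# on the head node (job-submission server listens on 8265 by default)
ray start --head --dashboard-host=0.0.0.0
# on each worker node: join the cluster and advertise the custom "worker_units" resource
ray start --address=<head-ip>:6379 --resources='{"worker_units": 8}'
# submit the training job through the launcher
python3 MaxText/launch_ray_maxtext.py MaxText/configs/base.yml run_name=ray_ft_test use_ray=True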
12 changes: 7 additions & 5 deletions MaxText/pyconfig.py
@@ -312,7 +312,8 @@ def __init__(self, argv: list[str], **kwargs):
validate_no_keys_overwritten_twice(keys_from_env_and_command_line, keys_from_model)

# We initialize the jax distributed system here because it must be done before device backend is initialized.
max_utils.maybe_initialize_jax_distributed_system(raw_keys)
if not raw_keys["use_ray"]:
max_utils.maybe_initialize_jax_distributed_system(raw_keys)

if raw_keys["jax_cache_dir"]:
compilation_cache.set_cache_dir(os.path.expanduser(raw_keys["jax_cache_dir"]))
@@ -332,10 +333,11 @@ def __init__(self, argv: list[str], **kwargs):
raw_keys["tokenizer_path"] = tokenizer_path

self.keys = raw_keys
keys = [k for k in raw_keys] # pylint: disable=unnecessary-comprehension
keys.sort()
for k in keys:
max_logging.log(f"Config param {k}: {raw_keys[k]}")
if raw_keys["log_hps"]:
keys = [k for k in raw_keys] # pylint: disable=unnecessary-comprehension
keys.sort()
for k in keys:
max_logging.log(f"Config param {k}: {raw_keys[k]}")

@staticmethod
def user_init(raw_keys):
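With use_ray=True, the JAX distributed runtime is initialized per worker inside ray_cluster.py (ResilientWorker.initialize calls jax.distributed.initialize), so the early initialization here is skipped; the per-key config dump is likewise gated behind the new log_hps flag, presumably so that every Ray worker does not print the full config.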
189 changes: 189 additions & 0 deletions MaxText/ray_cluster.py
@@ -0,0 +1,189 @@
import ray
import traceback
import os
import jax
import random
import redis
import datetime
import asyncio
from contextlib import contextmanager
from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy as NASS

import max_logging

class RayClusterCoordinator:
def __init__(self, worker_cls, hang_time_threshold) -> None:
self.worker_cls = worker_cls
self.num_workers = int(os.environ.get('NGPUS'))
self.num_workers_per_node = int(os.environ.get('GPUS_PER_NODE'))
self.workers_initialized = False
self.log = lambda user_str: max_logging.log(f"[RayClusterCoordinator] {user_str}")
self.hang_time_threshold = hang_time_threshold if hang_time_threshold is not None else 300

self.redis_addr = os.environ.get('REDIS_ADDR').split(':')

worker_node_info, self.num_physical_nodes = self._get_schedulable_worker_info()
self.workers = [worker_cls.options(num_gpus=1,
num_cpus=16,
resources={"worker_units": 1},
scheduling_strategy=NASS(node_id=worker_node_info[i][0], soft=False)).remote(i,
worker_node_info[i][1],
worker_node_info[i][2])
for i in range(self.num_workers)]

self.jax_coordinator_ip = worker_node_info[0][2]
self.redis = redis.Redis(host=self.redis_addr[0], port=int(self.redis_addr[1]), decode_responses=True, password=None)
self._init_sync_dict()

def _get_schedulable_worker_info(self):
worker_node_info = []
worker_nodes = sorted([node for node in ray.nodes() if (node['Alive'] and 'worker_units' in node['Resources'])],
key=lambda x: x['NodeID'])

num_nodes_required = self.num_workers // self.num_workers_per_node
num_nodes_available = len(worker_nodes)
assert num_nodes_required <= num_nodes_available

worker_nodes = worker_nodes[:num_nodes_required]
for worker_node_id, worker_node in enumerate(worker_nodes):
for _ in range(self.num_workers_per_node):
worker_node_info.append((worker_node['NodeID'], worker_node_id, worker_node['NodeName']))

return worker_node_info, num_nodes_required

def _init_sync_dict(self):
self.redis.flushdb()
init_time = datetime.datetime.now().isoformat()
for pid in range(self.num_workers):
self.redis.set(pid, init_time)

def initialize_workers(self, **kwargs):
self.worker_init_kwargs = kwargs
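# choose a high JAX coordinator port in the range [61440, 65535]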
coordinator_port = random.randint(1, 100000) % 2**12 + (65535 - 2**12 + 1)
self.jax_coordinator_addr = f"{self.jax_coordinator_ip}:{coordinator_port}"

ray.get([w.initialize.remote(self.jax_coordinator_addr, self.num_workers, **kwargs) for w in self.workers])
self.workers_initialized = True

async def _run_workers_async(self, *args, **kwargs):
worker_run_futures = [w.run.remote(*args, **kwargs) for w in self.workers]
while True:
completed_worker_results = []
for _, wf in enumerate(worker_run_futures):
try:
worker_result = ray.get(wf, timeout=0)
completed_worker_results.append(worker_result)
except ray.exceptions.GetTimeoutError:
continue

if len(completed_worker_results) < len(self.workers):
self.log(f"All workers seem to be alive, but only {len(completed_worker_results)} completed")
await asyncio.sleep(30)
else:
self.log(f"All {len(completed_worker_results)} workers completed. Returning results.")
return completed_worker_results

async def _detect_worker_hang_async(self):
# Check if processes are hanging
while True:
await asyncio.sleep(30)
for pid in range(self.num_workers):
current_time = datetime.datetime.now()
last_heartbeat_time = datetime.datetime.fromisoformat(self.redis.get(pid))
elapsed = (current_time - last_heartbeat_time).total_seconds()
if elapsed > self.hang_time_threshold:
self.log(f"Worker {pid} has not sent a heartbeat for {elapsed / 60:.1f} minutes")
raise Exception(f"Worker {pid} appears to be hung")

self.log("No hangs detected")

async def run(self, *args, **kwargs):
if not self.workers_initialized:
raise ValueError("""Cannot run workers without initializing them first.
Please call the initialize_workers method of your cluster coordinator first.""")

runners = asyncio.create_task(self._run_workers_async(*args, **kwargs))
hang_detector = asyncio.create_task(self._detect_worker_hang_async())
while True:
try:
done, _ = await asyncio.wait({runners, hang_detector}, return_when=asyncio.FIRST_COMPLETED)
for task in done:
# If the runners finish with an exception first, task.result() re-raises it here.
# If the hang detector finishes (with an exception) first, the same happens.
# The only case in which task.result() does not raise is when the runners
# finish first without an exception; in that case, cancel the hang detector
# and return the workers' results.
result = task.result()
hang_detector.cancel()
return result
except Exception as e:
self.log(f"Encountered exception {type(e).__name__}")
self.log(traceback.format_exc())

self.log("Cancelling all tasks in event loop...")
runners.cancel()
hang_detector.cancel()
self.log("Done cancelling all tasks in event loop")

self.log("Killing all ray actors...")
for w in self.workers:
ray.kill(w)
self.workers_initialized = False
del self.workers
self.log("Done killing all ray actors")

# Restart workers and reinitialize tasks
self.log("Restarting all actors")
worker_node_info, self.num_physical_nodes = self._get_schedulable_worker_info()
self.workers = [self.worker_cls.options(num_gpus=1,
num_cpus=16,
resources={"worker_units": 1},
scheduling_strategy=NASS(node_id=worker_node_info[i][0], soft=False)).remote(i,
worker_node_info[i][1],
worker_node_info[i][2])
for i in range(self.num_workers)]
self.jax_coordinator_ip = worker_node_info[0][2]
self._init_sync_dict()
self.initialize_workers(**self.worker_init_kwargs)

self.log("Reinitializing tasks")
runners = asyncio.create_task(self._run_workers_async(*args, **kwargs))
hang_detector = asyncio.create_task(self._detect_worker_hang_async())

class ResilientWorker:
def __init__(self, process_id, physical_node_id, physical_node_ip):
self.process_id = process_id
self.physical_node_id = physical_node_id
self.host_ip = physical_node_ip

self.redis_addr = os.environ.get('REDIS_ADDR').split(':')
self.logical_gpu_id = int(os.environ.get('CUDA_VISIBLE_DEVICES'))
self.redis = redis.Redis(host=self.redis_addr[0], port=int(self.redis_addr[1]), decode_responses=True, password=None)

def get_process_id(self):
return self.process_id

def get_host_ip(self):
return self.host_ip

def get_logical_gpu_id(self):
return self.logical_gpu_id

def get_physical_node_id(self):
return self.physical_node_id

def initialize(self, coordinator_addr, num_processes):
jax.distributed.initialize(coordinator_address=coordinator_addr, num_processes=num_processes, process_id=self.process_id, local_device_ids=0)

@contextmanager
def EnableHeartbeat(self):
try:
yield
finally:
current_time = datetime.datetime.now().isoformat()
self.redis.set(self.process_id, current_time)

def run(self, *args, **kwargs):
raise NotImplementedError
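
A minimal sketch of how a training script might plug into these classes; the real MaxText/resilient_train.py is not part of this diff, and TrainWorker, make_train_state, and train_step below are placeholders. Each step runs inside EnableHeartbeat so the worker's Redis timestamp is refreshed and the coordinator's hang detector sees progress:

import asyncio
import ray
from ray_cluster import RayClusterCoordinator, ResilientWorker

@ray.remote
class TrainWorker(ResilientWorker):
    def run(self, config):
        state = make_train_state(config)           # placeholder for model/optimizer setup
        for step in range(config.steps):
            with self.EnableHeartbeat():           # refresh this worker's heartbeat in Redis
                state = train_step(state, config)  # placeholder for one training step
        return self.get_process_id()

def main(config):
    ray.init(address="auto")                       # connect to the existing Ray cluster
    coordinator = RayClusterCoordinator(TrainWorker, hang_time_threshold=300)
    coordinator.initialize_workers()
    # coordinator.run is a coroutine: it races the workers against the hang detector
    # and, on crash or hang, kills and recreates every actor before rerunning
    return asyncio.run(coordinator.run(config))

On failure the coordinator recreates all actors rather than repairing anything in place, so resuming appears to rely on the training loop restoring the latest Orbax checkpoint (hence the max_to_keep plumbing above).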