diff --git a/MaxText/checkpointing.py b/MaxText/checkpointing.py
index 7f2c08072..eca1c3daa 100644
--- a/MaxText/checkpointing.py
+++ b/MaxText/checkpointing.py
@@ -47,6 +47,7 @@ def create_orbax_checkpoint_manager(
     enable_checkpointing: bool,
     use_async: bool,
     save_interval_steps: int,
+    max_to_keep: Optional[int] = None,
     dataset_type: Optional[str] = "tfds",
     orbax_logger: Optional[abstract_logger.AbstractLogger] = None,
     use_ocdbt: bool = True,
@@ -77,6 +78,7 @@ def create_orbax_checkpoint_manager(
         create=True,
         save_interval_steps=save_interval_steps,
         enable_async_checkpointing=use_async,
+        max_to_keep=max_to_keep,
     ),
     logger=orbax_logger,
   )
diff --git a/MaxText/configs/base.yml b/MaxText/configs/base.yml
index 5caaa9ef2..a9500513c 100644
--- a/MaxText/configs/base.yml
+++ b/MaxText/configs/base.yml
@@ -41,6 +41,7 @@ load_full_state_path: ""
 enable_checkpointing: True
 async_checkpointing: True
 checkpoint_period: 10_000
+num_checkpoints_to_keep: 5
 # enables one replica to read the ckpt then broadcast to the rest
 enable_single_replica_ckpt_restoring: False
 
@@ -464,3 +465,12 @@ ragged_block_size: 256
 sa_block_q: 512
 sa_block_q_dkv: 512
 sa_block_q_dq: 512
+
+# Ray
+use_ray: False
+failure_sim_time: 300
+crash_prob: 0.5
+hang_prob: 0.5
+
+# Logging
+log_hps: False
diff --git a/MaxText/launch_ray_maxtext.py b/MaxText/launch_ray_maxtext.py
new file mode 100644
index 000000000..ece80c7fa
--- /dev/null
+++ b/MaxText/launch_ray_maxtext.py
@@ -0,0 +1,34 @@
+import time
+from absl import app
+from typing import Sequence
+from ray.job_submission import JobSubmissionClient, JobStatus
+
+def main(argv: Sequence[str]) -> None:
+  client = JobSubmissionClient("http://127.0.0.1:8265")
+  print("Connected to head!", flush=True)
+
+  maxtext_cmd_args = " ".join(argv[1:])
+  job_id = client.submit_job(
+      entrypoint=f"RAY_DEDUP_LOGS=0 python3 MaxText/resilient_train.py {maxtext_cmd_args}",
+      runtime_env={"working_dir": "./",
+                   "excludes": ["MaxText/test_assets", ".git"]}
+  )
+
+  print(f"Launched job: {job_id}", flush=True)
+  prev_logs = ""
+  while True:
+    status = client.get_job_status(job_id)
+    if status in {JobStatus.SUCCEEDED, JobStatus.STOPPED, JobStatus.FAILED}:
+      if status in {JobStatus.STOPPED, JobStatus.FAILED}:
+        logs = client.get_job_logs(job_id)
+        print(logs, flush=True)
+      break
+    time.sleep(5)
+    if status == JobStatus.RUNNING:
+      logs = client.get_job_logs(job_id)
+      print(logs[len(prev_logs):], flush=True)
+      prev_logs = logs
+
+
+if __name__ == "__main__":
+  app.run(main)
diff --git a/MaxText/pyconfig.py b/MaxText/pyconfig.py
index 5adc88199..bc114634c 100644
--- a/MaxText/pyconfig.py
+++ b/MaxText/pyconfig.py
@@ -312,7 +312,8 @@ def __init__(self, argv: list[str], **kwargs):
     validate_no_keys_overwritten_twice(keys_from_env_and_command_line, keys_from_model)
 
     # We initialize the jax distributed system here because it must be done before device backend is initialized.
-    max_utils.maybe_initialize_jax_distributed_system(raw_keys)
+    if not raw_keys["use_ray"]:
+      max_utils.maybe_initialize_jax_distributed_system(raw_keys)
 
     if raw_keys["jax_cache_dir"]:
       compilation_cache.set_cache_dir(os.path.expanduser(raw_keys["jax_cache_dir"]))
@@ -332,10 +333,11 @@ def __init__(self, argv: list[str], **kwargs):
       raw_keys["tokenizer_path"] = tokenizer_path
 
     self.keys = raw_keys
-    keys = [k for k in raw_keys]  # pylint: disable=unnecessary-comprehension
-    keys.sort()
-    for k in keys:
-      max_logging.log(f"Config param {k}: {raw_keys[k]}")
+    if raw_keys["log_hps"]:
+      keys = [k for k in raw_keys]  # pylint: disable=unnecessary-comprehension
+      keys.sort()
+      for k in keys:
+        max_logging.log(f"Config param {k}: {raw_keys[k]}")
 
   @staticmethod
   def user_init(raw_keys):
diff --git a/MaxText/ray_cluster.py b/MaxText/ray_cluster.py
new file mode 100644
index 000000000..4353d145f
--- /dev/null
+++ b/MaxText/ray_cluster.py
@@ -0,0 +1,189 @@
+import ray
+import traceback
+import os
+import jax
+import random
+import redis
+import datetime
+import asyncio
+from contextlib import contextmanager
+from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy as NASS
+
+import max_logging
+
+class RayClusterCoordinator:
+  def __init__(self, worker_cls, hang_time_threshold) -> None:
+    self.worker_cls = worker_cls
+    self.num_workers = int(os.environ.get('NGPUS'))
+    self.num_workers_per_node = int(os.environ.get('GPUS_PER_NODE'))
+    self.workers_initialized = False
+    self.log = lambda user_str: max_logging.log(f"[RayClusterCoordinator] {user_str}")
+    self.hang_time_threshold = hang_time_threshold if hang_time_threshold is not None else 300
+
+    self.redis_addr = os.environ.get('REDIS_ADDR').split(':')
+
+    worker_node_info, self.num_physical_nodes = self._get_schedulable_worker_info()
+    self.workers = [worker_cls.options(num_gpus=1,
+                                       num_cpus=16,
+                                       resources={"worker_units": 1},
+                                       scheduling_strategy=NASS(node_id=worker_node_info[i][0], soft=False)).remote(i,
+                                                                                                                    worker_node_info[i][1],
+                                                                                                                    worker_node_info[i][2])
+                    for i in range(self.num_workers)]
+
+    self.jax_coordinator_ip = worker_node_info[0][2]
+    self.redis = redis.Redis(host=self.redis_addr[0], port=int(self.redis_addr[1]), decode_responses=True, password=None)
+    self._init_sync_dict()
+
+  def _get_schedulable_worker_info(self):
+    worker_node_info = []
+    worker_nodes = sorted([node for node in ray.nodes() if (node['Alive'] and 'worker_units' in node['Resources'])],
+                          key=lambda x: x['NodeID'])
+
+    num_nodes_required = self.num_workers // self.num_workers_per_node
+    num_nodes_available = len(worker_nodes)
+    assert num_nodes_required <= num_nodes_available
+
+    worker_nodes = worker_nodes[:num_nodes_required]
+    for worker_node_id, worker_node in enumerate(worker_nodes):
+      for _ in range(self.num_workers_per_node):
+        worker_node_info.append((worker_node['NodeID'], worker_node_id, worker_node['NodeName']))
+
+    return worker_node_info, num_nodes_required
+
+  def _init_sync_dict(self):
+    self.redis.flushdb()
+    init_time = datetime.datetime.now().isoformat()
+    for pid in range(self.num_workers):
+      self.redis.set(pid, init_time)
+
+  def initialize_workers(self, **kwargs):
+    self.worker_init_kwargs = kwargs
+    # Pick a pseudo-random coordinator port in [61440, 65535], the top 2**12 ports.
+    coordinator_port = random.randint(1, 100000) % 2**12 + (65535 - 2**12 + 1)
+    self.jax_coordinator_addr = f"{self.jax_coordinator_ip}:{coordinator_port}"
+
+    ray.get([w.initialize.remote(self.jax_coordinator_addr, self.num_workers, **kwargs) for w in self.workers])
+    self.workers_initialized = True
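
An editorial aside on the port arithmetic above, since it is easy to misread: `2**12` is 4096, so the modulo keeps a value in `[0, 4095]` and the constant offset shifts it to the top of the 16-bit port range. A standalone sketch (the helper name is ours, not part of the diff):

```python
import random

def pick_coordinator_port() -> int:
  """Mirrors the expression in initialize_workers."""
  # randint(1, 100000) % 2**12 lies in [0, 4095]; 65535 - 2**12 + 1 == 61440,
  # so the sum is always a valid high port in [61440, 65535].
  return random.randint(1, 100000) % 2**12 + (65535 - 2**12 + 1)

assert 61440 <= pick_coordinator_port() <= 65535
```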
+
+  async def _run_workers_async(self, *args, **kwargs):
+    worker_run_futures = [w.run.remote(*args, **kwargs) for w in self.workers]
+    while True:
+      completed_worker_results = []
+      for wf in worker_run_futures:
+        try:
+          worker_result = ray.get(wf, timeout=0)
+          completed_worker_results.append(worker_result)
+        except ray.exceptions.GetTimeoutError:
+          continue
+
+      if len(completed_worker_results) < len(self.workers):
+        self.log(f"All workers seem to be alive, but only {len(completed_worker_results)} completed")
+        await asyncio.sleep(30)
+      else:
+        self.log(f"All {len(completed_worker_results)} workers completed. Returning results.")
+        return completed_worker_results
+
+  async def _detect_worker_hang_async(self):
+    # Check whether workers are hanging by comparing each worker's last
+    # Redis heartbeat against the hang-time threshold.
+    while True:
+      await asyncio.sleep(30)
+      for pid in range(self.num_workers):
+        current_time = datetime.datetime.now()
+        last_heartbeat_time = datetime.datetime.fromisoformat(self.redis.get(pid))
+        elapsed = (current_time - last_heartbeat_time).total_seconds()
+        if elapsed > self.hang_time_threshold:
+          self.log(f"Worker {pid} has been hanging for {elapsed / 60} minutes")
+          raise Exception(f"Worker {pid} appears to be hanging")
+
+      self.log("No hangs detected")
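
`_run_workers_async` leans on `ray.get(..., timeout=0)`, which returns immediately if the result is ready and otherwise raises `GetTimeoutError`; that is what lets the coordinator poll run futures without blocking the event loop. A minimal self-contained sketch of the same pattern (toy task, shorter poll interval than the coordinator's 30s):

```python
import time
import ray

@ray.remote
def toy_task(i):
  time.sleep(i)
  return i

ray.init()
futures = [toy_task.remote(i) for i in range(3)]
while True:
  done = []
  for fut in futures:
    try:
      done.append(ray.get(fut, timeout=0))  # non-blocking readiness probe
    except ray.exceptions.GetTimeoutError:
      continue  # this future is still running
  if len(done) == len(futures):
    break
  time.sleep(1)
print(done)
```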
+
+  async def run(self, *args, **kwargs):
+    if not self.workers_initialized:
+      raise ValueError("""Cannot run workers without initializing them first.
+                          Please call the initialize_workers method of your cluster coordinator first.""")
+
+    runners = asyncio.create_task(self._run_workers_async(*args, **kwargs))
+    hang_detector = asyncio.create_task(self._detect_worker_hang_async())
+    while True:
+      try:
+        done, _ = await asyncio.wait({runners, hang_detector}, return_when=asyncio.FIRST_COMPLETED)
+        for task in done:
+          # If the runners finish with an exception first, this raises that exception.
+          # If the hang detector finishes with an exception first, this raises that exception.
+          # The only case in which task.result() does not raise is when the runners finish
+          # first without raising an exception. In that case, get the results from the
+          # runners and cancel the hang detector task before returning.
+          result = task.result()
+          hang_detector.cancel()
+          return result
+      except Exception as e:
+        self.log(f"Encountered exception {type(e).__name__}")
+        self.log(traceback.format_exc())
+
+        self.log("Cancelling all tasks in event loop...")
+        runners.cancel()
+        hang_detector.cancel()
+        self.log("Done cancelling all tasks in event loop")
+
+        self.log("Killing all ray actors...")
+        for w in self.workers:
+          ray.kill(w)
+        self.workers_initialized = False
+        del self.workers
+        self.log("Done killing all ray actors")
+
+        # Restart workers and reinitialize tasks
+        self.log("Restarting all actors")
+        worker_node_info, self.num_physical_nodes = self._get_schedulable_worker_info()
+        self.workers = [self.worker_cls.options(num_gpus=1,
+                                                num_cpus=16,
+                                                resources={"worker_units": 1},
+                                                scheduling_strategy=NASS(node_id=worker_node_info[i][0], soft=False)).remote(i,
+                                                                                                                             worker_node_info[i][1],
+                                                                                                                             worker_node_info[i][2])
+                        for i in range(self.num_workers)]
+        self.jax_coordinator_ip = worker_node_info[0][2]
+        self._init_sync_dict()
+        self.initialize_workers(**self.worker_init_kwargs)
+
+        self.log("Reinitializing tasks")
+        runners = asyncio.create_task(self._run_workers_async(*args, **kwargs))
+        hang_detector = asyncio.create_task(self._detect_worker_hang_async())
+
+class ResilientWorker:
+  def __init__(self, process_id, physical_node_id, physical_node_ip):
+    self.process_id = process_id
+    self.physical_node_id = physical_node_id
+    self.host_ip = physical_node_ip
+
+    self.redis_addr = os.environ.get('REDIS_ADDR').split(':')
+    self.logical_gpu_id = int(os.environ.get('CUDA_VISIBLE_DEVICES'))
+    self.redis = redis.Redis(host=self.redis_addr[0], port=int(self.redis_addr[1]), decode_responses=True, password=None)
+
+  def get_process_id(self):
+    return self.process_id
+
+  def get_host_ip(self):
+    return self.host_ip
+
+  def get_logical_gpu_id(self):
+    return self.logical_gpu_id
+
+  def get_physical_node_id(self):
+    return self.physical_node_id
+
+  def initialize(self, coordinator_addr, num_processes):
+    jax.distributed.initialize(coordinator_address=coordinator_addr, num_processes=num_processes, process_id=self.process_id, local_device_ids=0)
+
+  @contextmanager
+  def EnableHeartbeat(self):
+    try:
+      yield
+    finally:
+      current_time = datetime.datetime.now().isoformat()
+      self.redis.set(self.process_id, current_time)
+
+  def run(self, *args, **kwargs):
+    raise NotImplementedError
\ No newline at end of file
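
The heartbeat contract between the two classes deserves a note: `EnableHeartbeat` writes a fresh timestamp to Redis only when the wrapped block exits, so a worker stuck inside the block stops heartbeating and will eventually trip `_detect_worker_hang_async`. A minimal conforming worker might look like this (`ToyWorker` and `do_one_step` are illustrative, not part of the diff):

```python
import ray
import ray_cluster

@ray.remote
class ToyWorker(ray_cluster.ResilientWorker):
  """Heartbeats once per unit of work, like MaxtextTrainer's train loop."""

  def run(self, num_steps):
    for step in range(num_steps):
      # If do_one_step hangs, the Redis timestamp goes stale and the
      # coordinator's hang detector eventually raises.
      with self.EnableHeartbeat():
        do_one_step(step)  # hypothetical per-step work
    return self.get_process_id()
```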
diff --git a/MaxText/resilient_train.py b/MaxText/resilient_train.py
new file mode 100644
index 000000000..3a3a01395
--- /dev/null
+++ b/MaxText/resilient_train.py
@@ -0,0 +1,829 @@
+"""
+Copyright 2023 Google LLC
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+     https://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
+# pylint: disable=g-bad-todo, abstract-method, consider-using-with, ungrouped-imports
+"""Training loop and Decoding of the model."""
+
+# Calling jax.device_count here prevents a "TPU platform already registered" error.
+# See github.com/google/maxtext/issues/20 for more
+
+import datetime
+import os
+import sys
+import functools
+import time
+import ray
+import asyncio
+import random as py_rand
+
+from typing import Sequence, Optional
+from absl import app
+from flax import linen as nn
+from flax.linen import partitioning as nn_partitioning
+import grain.python as grain
+import jax
+import numpy as np
+import orbax.checkpoint
+import orbax.checkpoint.experimental.emergency.checkpoint_manager as emergency_checkpoint_manager
+
+import checkpointing
+import max_utils
+import maxtext_utils
+import max_logging
+import optimizers
+import profiler
+import pyconfig
+import pathwaysutils  # pylint: disable=unused-import
+import ray_cluster
+
+from vertex_tensorboard import VertexTensorboardManager
+# Placeholder: internal
+
+from input_pipeline.input_pipeline_interface import create_data_iterator
+from layers import models
+
+import jax.numpy as jnp
+from jax import random
+from jax.sharding import Mesh
+from jax.experimental import checkify
+
+from cloud_tpu_diagnostics import diagnostic
+from cloud_tpu_diagnostics.configuration import debug_configuration
+from cloud_tpu_diagnostics.configuration import diagnostic_configuration
+from cloud_tpu_diagnostics.configuration import stack_trace_configuration
+
+from layers import quantizations
+
+from ml_goodput_measurement import goodput
+from ml_goodput_measurement import monitoring
+
+# pylint: disable=too-many-positional-arguments
+
+Transformer = models.Transformer
+EPS = 1e-8
+_DEFAULT_OCDBT_TARGET_DATA_FILE_SIZE = 2 * 1024**3
+
+
+def validate_train_config(config):
+  """Validates the configuration is set correctly for train.py."""
+
+  assert config.run_name, "Erroring out, need a real run_name"
+  if not config.dataset_path.startswith("gs://"):
+    max_logging.log("WARNING: 'dataset_path' might be pointing to your local file system")
+  if not config.base_output_directory.startswith("gs://"):
+    max_logging.log("WARNING: 'base_output_directory' might be pointing to your local file system")
+  assert config.steps > 0, "You must set steps or learning_rate_schedule_steps to a positive integer."
+  if config.quantization == "fp8":
+    # pylint: disable=line-too-long
+    assert (
+        config.gradient_accumulation_steps == 1
+    ), "fp8 can't be used with gradient_accumulation_steps right now. Please use other quantization or set gradient_accumulation_steps to 1"
+
+
+def get_first_step(state):
+  with jax.spmd_mode("allow_all"):
+    return int(state.step)
+
+
+def load_next_batch(train_iter, example_batch, config):
+  """Loads the next batch. Can keep reusing the same batch for performance reasons."""
+
+  if config.reuse_example_batch and example_batch is not None:
+    return example_batch
+  else:
+    return next(train_iter)
+
+
+def record_scalar_metrics(metrics, step_time_delta, per_device_tflops, lr, per_device_tokens):
+  """Records scalar metrics to be written to tensorboard"""
+  metrics["scalar"].update({"perf/step_time_seconds": step_time_delta.total_seconds()})
+  metrics["scalar"].update({"perf/per_device_tflops": per_device_tflops})
+  metrics["scalar"].update({"perf/per_device_tflops_per_sec": per_device_tflops / step_time_delta.total_seconds()})
+  metrics["scalar"].update({"perf/per_device_tokens": per_device_tokens})
+  metrics["scalar"].update({"perf/per_device_tokens_per_sec": per_device_tokens / step_time_delta.total_seconds()})
+  metrics["scalar"].update({"learning/current_learning_rate": lr})
+
+
+_buffered_step = None
+_buffered_metrics = None
+
+
+def write_metrics(writer, local_metrics_file, running_gcs_metrics, metrics, step, config, is_training=True):
+  """Entry point for all metrics writing in Train's Main.
+  TODO: would be better as a Class in the future (that initialized all state!)
+
+  To avoid introducing an unnecessary dependency, we "double buffer" -- we hold
+  onto the last metrics and step and only publish when we receive a new metrics and step.
+  The logic is that this ensures that Jax is able to queue train_steps and we
+  don't block when turning "lazy" Jax arrays into real Python numbers.
+  """
+  metrics_to_write, steps_to_write = None, None
+  if is_training:
+    global _buffered_step, _buffered_metrics
+    if _buffered_metrics is not None:
+      if _buffered_step is None:
+        raise ValueError(f"When writing metrics, {_buffered_step=} was none")
+      metrics_to_write = _buffered_metrics
+      steps_to_write = _buffered_step
+  else:
+    metrics_to_write = metrics
+    steps_to_write = step
+
+  if metrics_to_write:
+    write_metrics_to_tensorboard(writer, metrics_to_write, steps_to_write, config, is_training)
+
+    if config.metrics_file:
+      max_utils.write_metrics_locally(metrics_to_write, steps_to_write, config, local_metrics_file, is_training)
+
+    if config.gcs_metrics and jax.process_index() == 0:
+      running_gcs_metrics = max_utils.write_metrics_for_gcs(
+          metrics_to_write, steps_to_write, config, running_gcs_metrics, is_training
+      )
+
+  if is_training:
+    _buffered_step = step
+    _buffered_metrics = metrics
+
+
+def write_metrics_to_tensorboard(writer, metrics, step, config, is_training=True):
+  """Writes metrics to tensorboard"""
+  with jax.spmd_mode("allow_all"):
+    if jax.process_index() == 0:
+      for metric_name in metrics.get("scalar", []):
+        writer.add_scalar(metric_name, np.array(metrics["scalar"][metric_name]), step)
+      for metric_name in metrics.get("scalars", []):
+        writer.add_scalars(metric_name, metrics["scalars"][metric_name], step)
+
+    if is_training:
+      full_log = step % config.log_period == 0
+
+      max_logging.log(
+          f"completed step: {step}, seconds: {metrics['scalar']['perf/step_time_seconds']:.3f}, "
+          f"TFLOP/s/device: {metrics['scalar']['perf/per_device_tflops_per_sec']:.3f}, "
+          f"Tokens/s/device: {metrics['scalar']['perf/per_device_tokens_per_sec']:.3f}, "
+          f"total_weights: {metrics['scalar']['learning/total_weights']}, "
+          f"loss: {metrics['scalar']['learning/loss']:.3f}"
+      )
+
+      if full_log and jax.process_index() == 0:
+        max_logging.log(f"To see full metrics 'tensorboard --logdir={config.tensorboard_dir}'")
+        writer.flush()
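
The double-buffering comment above is the key performance point: metrics for step N are only materialized (which forces a device sync via `np.array`/`float`) after step N+1 has been dispatched, so logging overlaps with useful accelerator work instead of stalling it. A toy illustration of the difference:

```python
import jax
import jax.numpy as jnp

@jax.jit
def step_fn(x):
  return x * 2.0

x = jnp.ones((1024, 1024))

# Eager logging: float() blocks the host until the step has finished.
loss_now = float(step_fn(x).sum())

# Buffered logging: keep the lazy array, dispatch the next step first,
# then materialize the previous metric while the device is busy.
buffered = step_fn(x).sum()   # no sync yet
x = step_fn(x)                # next step already enqueued
loss_prev = float(buffered)   # sync overlaps with useful work
```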
+
+
+def clear_buffered_metrics():
+  global _buffered_step
+  global _buffered_metrics
+  _buffered_step = None
+  _buffered_metrics = None
+
+
+def save_checkpoint(
+    checkpoint_manager,
+    step,
+    state,
+    dataset_type="c4",
+    data_iterator=None,
+    config: Optional[pyconfig.config] = None,
+) -> bool:
+  """Wrapper for saving checkpoint."""
+  if config and config.enable_checkpointing:
+    if (step % config.checkpoint_period == 0) or (
+        config.enable_emergency_checkpoint and step % config.local_checkpoint_period == 0
+    ):
+      blocking_until_ready_start = time.time()
+      max_logging.log(f"Waiting for step {step} to finish before checkpoint...")
+      # We block here on the step finishing so that our checkpointing metrics
+      # measure only checkpointing time, not training time.
+      jax.block_until_ready(state)
+      max_logging.log(
+          f"Waited {time.time() - blocking_until_ready_start} seconds for step "
+          f"{step} to finish before starting checkpointing."
+      )
+
+  # specify chunk_byte_size to force orbax to control maximum file size in checkpoint
+  chunk_byte_size = _DEFAULT_OCDBT_TARGET_DATA_FILE_SIZE
+  if config:
+    chunk_byte_size = config.checkpoint_storage_target_data_file_size_bytes
+  save_args = jax.tree.map(lambda _: orbax.checkpoint.SaveArgs(chunk_byte_size=chunk_byte_size), state)
+
+  if isinstance(checkpoint_manager, emergency_checkpoint_manager.CheckpointManager):
+    return checkpoint_manager.save(
+        step,
+        args=orbax.checkpoint.args.PyTreeSave(item=state, save_args=save_args, ocdbt_target_data_file_size=chunk_byte_size),
+    )
+
+  if dataset_type == "grain":
+    return checkpoint_manager.save(
+        step,
+        args=orbax.checkpoint.args.Composite(
+            items=orbax.checkpoint.args.PyTreeSave(
+                item=state, save_args=save_args, ocdbt_target_data_file_size=chunk_byte_size
+            ),
+            iter=grain.PyGrainCheckpointSave(data_iterator.local_iterator),
+        ),
+    )
+  else:
+    return checkpoint_manager.save(
+        step,
+        args=orbax.checkpoint.args.Composite(
+            items=orbax.checkpoint.args.PyTreeSave(
+                item=state, save_args=save_args, ocdbt_target_data_file_size=chunk_byte_size
+            )
+        ),
+    )
+
+
+# -----------------------------------------------------------------------------
+# Top-level Functions
+# -----------------------------------------------------------------------------
+
+
+def record_activation_metrics(output_metrics, intermediate_outputs, config):
+  """Adds the activation metrics to the metrics dict"""
+
+  if config.scan_layers:
+    metrics_dict = intermediate_outputs["intermediates"]["decoder"]["decoder"]
+
+    for layer_num in range(config.num_decoder_layers):
+      output_metrics["scalar"][f"activ_fraction_zero/layer_{layer_num:03d}"] = metrics_dict["activation_fraction_zero"][0][
+          layer_num
+      ]
+      output_metrics["scalar"][f"activ_mean/layer_{layer_num:03d}"] = metrics_dict["activation_mean"][0][layer_num]
+      output_metrics["scalar"][f"activ_stdev/layer_{layer_num:03d}"] = metrics_dict["activation_stdev"][0][layer_num]
+  else:
+    for layer_num in range(config.num_decoder_layers):
+      layer = intermediate_outputs["intermediates"]["decoder"][f"layers_{layer_num}"]
+      output_metrics["scalar"][f"activ_fraction_zero/layer_{layer_num:03d}"] = layer["activation_fraction_zero"][0]
+      output_metrics["scalar"][f"activ_mean/layer_{layer_num:03d}"] = layer["activation_mean"][0]
+      output_metrics["scalar"][f"activ_stdev/layer_{layer_num:03d}"] = layer["activation_stdev"][0]
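
Retention for the checkpoints written by `save_checkpoint` above is handled by the checkpoint manager, which is where the new `max_to_keep` argument in `checkpointing.py` (fed by `num_checkpoints_to_keep` from `base.yml`) ends up. For reference, a sketch of how Orbax interprets these options (directory path illustrative):

```python
import orbax.checkpoint as ocp

options = ocp.CheckpointManagerOptions(
    create=True,
    save_interval_steps=10_000,  # checkpoint_period in base.yml
    max_to_keep=5,               # num_checkpoints_to_keep; None keeps everything
)
# With max_to_keep=5, Orbax garbage-collects older step directories so at
# most five checkpoints remain under the given path.
manager = ocp.CheckpointManager("/tmp/ckpts", options=options)
```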
+def loss_fn(model, config, data, dropout_rng, params, is_train=True):
+  """loss_fn for both train and eval.
+
+  Args:
+    model: A nn.Module
+    config: Config of parameters
+    data: Batch of data to apply to the model
+    dropout_rng: A key to use to generate rng for dropout
+    params: Model params
+    is_train: True for train_step and False for eval_step
+
+  Returns:
+    loss: average loss
+    aux: a dictionary including intermediate_outputs, total_loss, and total_weights
+  """
+  # inputs, targets, segments, positions = apply_args
+  rng1, aqt_rng = jax.random.split(dropout_rng)
+
+  # decimate proportion of data when per_device_batch_size<1
+  if is_train:
+    for k, v in data.items():
+      data[k] = v[: config.micro_batch_size_to_train_on, :]
+  else:
+    for k, v in data.items():
+      data[k] = v[: config.micro_batch_size_to_eval_on, :]
+
+  logits, intermediate_outputs = model.apply(
+      params,
+      data["inputs"],
+      data["inputs_position"],
+      decoder_segment_ids=data["inputs_segmentation"],
+      enable_dropout=config.enable_dropout if is_train else False,
+      rngs={"dropout": rng1, "params": aqt_rng},
+      mutable="intermediates",
+  )
+  one_hot_targets = jax.nn.one_hot(data["targets"], config.vocab_size)
+  xent, _ = max_utils.cross_entropy_with_logits(logits, one_hot_targets, 0.0)
+  xent = nn.with_logical_constraint(xent, ("activation_embed_and_logits_batch", "activation_length"))
+  # Mask out paddings at the end of each example.
+  xent = xent * (data["targets_segmentation"] != 0)
+  total_loss = jnp.sum(xent)
+  total_weights = jnp.sum(data["targets_segmentation"] != 0)
+  loss = total_loss / (total_weights + EPS)
+  # get moe load balance loss
+  moe_lb_loss = 0.0
+  if config.num_experts > 1:
+    nested_key = ("intermediates", "decoder", "layers", "moe_lb_loss")
+    total_moe_lb_loss = maxtext_utils.get_nested_value(intermediate_outputs, nested_key, 0.0)
+    moe_lb_loss = jnp.mean(jnp.array(total_moe_lb_loss))
+    loss += moe_lb_loss
+  aux = {
+      "intermediate_outputs": intermediate_outputs,
+      "total_loss": total_loss,
+      "total_weights": total_weights,
+      "moe_lb_loss": moe_lb_loss,
+  }
+  return loss, aux
+ + """ + if config.gradient_accumulation_steps > 1: + + def accumulate_gradient(acc_grad_and_loss, data): + grad_func = jax.value_and_grad(loss_fn, argnums=4, has_aux=True) + (_, aux), cur_batch_gradient = grad_func(model, config, data, dropout_rng, state.params, is_train=True) + acc_grad_and_loss["loss"] += aux["total_loss"] + acc_grad_and_loss["moe_lb_loss"] += aux["moe_lb_loss"] + acc_grad_and_loss["grad"] = jax.tree_util.tree_map( + lambda x, y: x * aux["total_weights"] + y, cur_batch_gradient, acc_grad_and_loss["grad"] + ) + acc_grad_and_loss["total_weights"] += aux["total_weights"] + return acc_grad_and_loss, aux + + def reshape_to_microbatch_accumulations(batch_arr): + """Reshape global batch to microbatches, assuming batch axis is leading.""" + microbatches = config.gradient_accumulation_steps + microbatch_shape = (microbatches, batch_arr.shape[0] // microbatches) + batch_arr.shape[1:] + return jnp.reshape(batch_arr, microbatch_shape) + + data = jax.tree_util.tree_map(reshape_to_microbatch_accumulations, data) + init_grad = jax.tree_util.tree_map(jnp.zeros_like, state.params) + init_grad_and_loss = {"loss": 0.0, "grad": init_grad, "total_weights": 0, "moe_lb_loss": 0.0} + + grad_and_loss, aux = jax.lax.scan( + accumulate_gradient, init_grad_and_loss, data, length=config.gradient_accumulation_steps + ) + loss = ( + grad_and_loss["loss"] / grad_and_loss["total_weights"] + + grad_and_loss["moe_lb_loss"] / config.gradient_accumulation_steps + ) + raw_grads = jax.tree_util.tree_map(lambda arr: arr / grad_and_loss["total_weights"], grad_and_loss["grad"]) + aux = jax.tree_map(lambda x: jnp.sum(x, axis=0), aux) + else: + grad_func = jax.value_and_grad(loss_fn, argnums=4, has_aux=True) + (loss, aux), raw_grads = grad_func(model, config, data, dropout_rng, state.params, is_train=True) + intermediate_outputs = aux["intermediate_outputs"] + total_weights = aux["total_weights"] + moe_lb_loss = aux["moe_lb_loss"] + + if config.gradient_clipping_threshold > 0: + grads = maxtext_utils.apply_gradient_clipping(raw_grads, state, config.gradient_clipping_threshold) + else: + grads = raw_grads + new_state = state.apply_gradients(grads=grads) + metrics = { + "scalar": { + "learning/loss": loss, + "learning/moe_lb_loss": moe_lb_loss, + "learning/total_weights": total_weights, + "learning/grad_norm": max_utils.l2norm_pytree(grads), + "learning/raw_grad_norm": max_utils.l2norm_pytree(raw_grads), + "learning/param_norm": max_utils.l2norm_pytree(new_state.params), + }, + "scalars": {}, + } + + if config.record_internal_nn_metrics: + record_activation_metrics(metrics, intermediate_outputs, config) + + return new_state, metrics + + +def eval_step(model, config, state, data, dropout_rng): + """eval_step no backprop and new state compared with train_step.""" + eval_loss_fn = functools.partial(loss_fn, model, config, data, dropout_rng, is_train=False) + loss, aux = eval_loss_fn(state.params) + total_loss = aux["total_loss"] + total_weights = aux["total_weights"] + moe_lb_loss = aux["moe_lb_loss"] + metrics = { + "scalar": { + "evaluation/loss": loss, + "evaluation/total_loss": total_loss, + "evaluation/total_weights": total_weights, + "evaluation/moe_lb_loss": moe_lb_loss, + }, + } + + return metrics + + +def create_goodput_recorder(config): + if config.enable_goodput_recording: + logger_name = f"goodput_{config.run_name}" + recorder = goodput.GoodputRecorder(config.run_name, logger_name, jax.process_index() == 0) + return recorder + return None + + +def record_goodput( + recorder, + config, + record_func, + 
+
+
+def record_goodput(
+    recorder,
+    config,
+    record_func,
+    *args,
+):
+  """Record data for Goodput and Badput computation."""
+  if recorder and config.enable_goodput_recording:
+    record_func(*args)
+
+
+def check_example_batch(config, example_batch):
+  if config.max_checkify:
+    jittable_f = checkify.checkify(lambda x: checkify.check(jnp.any(x > -1), "Batch contains bad synthetic data!"))
+    # Check if inputs in batch contain bad synthetic data.
+    # pylint: disable=not-callable
+    err, _ = jax.jit(jittable_f)(example_batch["inputs"][: config.global_batch_size_to_train_on, :])
+    err.throw()
+
+
+def setup_mesh_and_model(config):
+  """Set up the mesh and the model for training
+
+  Args:
+    config
+
+  Returns:
+    init_rng: RNG key
+    writer: Summary writer for tensorboard
+    checkpoint_manager: Orbax checkpointer
+    state_mesh_annotations: the mesh annotations for the train state
+    model:
+    mesh:
+    learning_rate_schedule:
+    tx:
+  """
+
+  init_rng = random.PRNGKey(config.init_weights_seed)
+  writer = max_utils.initialize_summary_writer(config)
+
+  # Mesh definition
+  devices_array = max_utils.create_device_mesh(config)
+  mesh = Mesh(devices_array, config.mesh_axes)
+
+  # Model and Optimizer definition
+  quant = quantizations.configure_quantization(config)
+  model = Transformer(config, mesh, quant=quant)
+  learning_rate_schedule = max_utils.create_learning_rate_schedule(config)
+  tx = optimizers.get_optimizer(config, learning_rate_schedule)
+  logger = checkpointing.setup_checkpoint_logger(config)
+  if config.enable_emergency_checkpoint:
+    abstract_state, _, _ = max_utils.get_abstract_state(model, tx, config, init_rng, mesh, is_training=True)
+    checkpoint_manager = checkpointing.create_orbax_emergency_checkpoint_manager(
+        config.local_checkpoint_directory,
+        config.checkpoint_dir,
+        mesh,
+        abstract_state,
+        config.local_checkpoint_period,
+        config.checkpoint_period,
+        logger,
+    )
+  else:
+    # TODO(b/368121306): Remove this once zarr3 support is plumbed on the backend
+    use_ocdbt = config.checkpoint_storage_use_ocdbt
+    use_zarr3 = config.checkpoint_storage_use_zarr3
+    if config.enable_single_controller:
+      use_ocdbt, use_zarr3 = False, False
+    checkpoint_manager = checkpointing.create_orbax_checkpoint_manager(
+        config.checkpoint_dir,
+        config.enable_checkpointing,
+        config.async_checkpointing,
+        config.checkpoint_period,
+        config.num_checkpoints_to_keep,
+        config.dataset_type,
+        logger,
+        use_ocdbt,
+        use_zarr3,
+    )
+
+  return init_rng, writer, checkpoint_manager, mesh, model, learning_rate_schedule, tx
+
+
+def setup_train_loop(config):
+  """Set up prerequisites for the training loop -
+  checkpoint_manager, PRNG keys, Mesh, Model and optimizer.
+  Set up data iterator and tokenizer, initialize the model.
+
+  Args:
+    config
+
+  Returns:
+    init_rng:
+    writer: Summary writer for tensorboard
+    checkpoint_manager: Orbax checkpointer
+    state_mesh_annotations: the mesh annotations for the train state
+    model:
+    mesh:
+    learning_rate_schedule:
+    data_iterator:
+    state: the initialized train state
+  """
+  recorder = create_goodput_recorder(config)
+  record_goodput(recorder, config, recorder.record_tpu_init_start_time if recorder else None)
+  init_rng, writer, checkpoint_manager, mesh, model, learning_rate_schedule, tx = setup_mesh_and_model(config)
+  record_goodput(recorder, config, recorder.record_tpu_init_end_time if recorder else None)
+  record_goodput(recorder, config, recorder.record_training_preparation_start_time if recorder else None)
+  data_iterator, eval_data_iterator = create_data_iterator(config, mesh)
+
+  state, state_mesh_annotations, data_iterator = max_utils.setup_training_state(
+      model, data_iterator, tx, config, init_rng, mesh, checkpoint_manager
+  )
+
+  if not config.using_pipeline_parallelism:
+    # The vocab tensor(s) of shape [vocab, embed] (and transpose) are not sharded by stage
+    maxtext_utils.assert_params_sufficiently_sharded(state.params, mesh, tolerance=0.02)
+  record_goodput(recorder, config, recorder.record_training_preparation_end_time if recorder else None)
+  return (
+      init_rng,
+      writer,
+      checkpoint_manager,
+      state_mesh_annotations,
+      model,
+      mesh,
+      learning_rate_schedule,
+      data_iterator,
+      eval_data_iterator,
+      state,
+  )
+
+@ray.remote
+class MaxtextTrainer(ray_cluster.ResilientWorker):
+  def __init__(self, process_id, physical_node_id, physical_node_ip):
+    super().__init__(process_id, physical_node_id, physical_node_ip)
+
+  def initialize(self, coordinator_addr, num_processes, **kwargs):
+    jax.config.update("jax_default_prng_impl", "unsafe_rbg")
+    os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0"
+    super().initialize(coordinator_addr, num_processes)
+    maxtext_args = kwargs['maxtext_args']
+    if "xla_tpu_spmd_rng_bit_generator_unsafe" not in os.environ.get("LIBTPU_INIT_ARGS", ""):
+      os.environ["LIBTPU_INIT_ARGS"] = os.environ.get("LIBTPU_INIT_ARGS", "") + " --xla_tpu_spmd_rng_bit_generator_unsafe=true"
+    pyconfig.initialize(maxtext_args)
+    max_utils.print_system_information()
+    self.config = pyconfig.config
+    validate_train_config(self.config)
+    os.environ["TFDS_DATA_DIR"] = self.config.dataset_path
+    self.vertex_tensorboard_manager = VertexTensorboardManager()
+    if self.config.use_vertex_tensorboard or os.environ.get("UPLOAD_DATA_TO_TENSORBOARD"):
+      self.vertex_tensorboard_manager.configure_vertex_tensorboard(self.config)
+
+    if self.config.monitor_goodput and jax.process_index() == 0:
+      logger_name = f"goodput_{self.config.run_name}"
+      self.goodput_monitor = monitoring.GoodputMonitor(
+          job_name=self.config.run_name,
+          logger_name=logger_name,
+          tensorboard_dir=self.config.tensorboard_dir,
+          upload_interval=self.config.goodput_upload_interval_seconds,
+          monitoring_enabled=True,
+          include_badput_breakdown=True,
+      )
+
+    debug_config = debug_configuration.DebugConfig(
+        stack_trace_config=stack_trace_configuration.StackTraceConfig(
+            collect_stack_trace=self.config.collect_stack_trace,
+            stack_trace_to_cloud=self.config.stack_trace_to_cloud,
+            stack_trace_interval_seconds=self.config.stack_trace_interval_seconds,
+        )
+    )
+    self.diagnostic_config = diagnostic_configuration.DiagnosticConfig(debug_config)
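
The `_fail` helper that follows drives the failure simulation. The pattern `py_rand.random() >= (1 - p)` is simply "true with probability p", so `crash_prob` and `hang_prob` from `base.yml` are independent per-check probabilities, evaluated once `failure_sim_time` seconds have elapsed. As a sanity check:

```python
import random

def happens_with_probability(p: float) -> bool:
  # random() is uniform on [0, 1), so random() >= 1 - p holds with probability p.
  return random.random() >= (1 - p)

hits = sum(happens_with_probability(0.5) for _ in range(10_000))
print(hits / 10_000)  # ~0.5
```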
+
+  def _fail(self, fail_prob, hang_prob):
+    # With probability hang_prob, simulate a hang by sleeping for an hour.
+    if py_rand.random() >= (1 - hang_prob):
+      max_logging.log("Gonna hang")
+      time.sleep(3600)
+
+    # With probability fail_prob, crash: half the time via a Python exception,
+    # half the time via a hard segfault with no graceful propagation.
+    if py_rand.random() >= (1 - fail_prob):
+      exception = False if py_rand.random() < 0.5 else True
+      max_logging.log(f"Failing with exception = {exception}")
+      if exception:
+        raise Exception("Failure")
+      else:
+        # Cause a seg fault, no graceful exception propagation
+        eval((lambda: 0).__code__.replace(co_consts=()))
+
+  def _train_loop(self, state=None):
+    """Main Training loop.
+    Args:
+      config:
+      state:
+      ckpt_path:
+    Returns:
+    """
+    # Create a GoodputRecorder to log information
+    recorder = create_goodput_recorder(self.config)
+    record_goodput(recorder, self.config, recorder.record_job_start_time if recorder else None)
+
+    (
+        init_rng,
+        writer,
+        checkpoint_manager,
+        state_mesh_annotations,
+        model,
+        mesh,
+        learning_rate_schedule,
+        data_iterator,
+        eval_data_iterator,
+        state,
+    ) = setup_train_loop(self.config)
+    # pylint: disable=line-too-long
+    (
+        functional_train,
+        in_shard_train,
+        out_shard_train,
+        static_argnums_train,
+        donate_argnums_train,
+    ) = maxtext_utils.get_functional_train_with_signature(train_step, mesh, state_mesh_annotations, model, self.config)
+
+    if eval_data_iterator:
+      # pylint: disable=line-too-long
+      (
+          functional_eval,
+          in_shard_eval,
+          out_shard_eval,
+          static_argnums_eval,
+          donate_argnums_eval,
+      ) = maxtext_utils.get_functional_eval_with_signature(eval_step, mesh, state_mesh_annotations, model, self.config)
+
+    num_model_parameters = max_utils.calculate_num_params_from_pytree(state.params)
+    max_logging.log(f"number parameters: {num_model_parameters/1e9:.3f} billion")
+    per_device_tflops, _, _ = maxtext_utils.calculate_tflops_training_per_device(self.config)
+    per_device_tokens = maxtext_utils.calculate_tokens_training_per_device(self.config)
+
+    # Write train config params, num model params, and XLA flags to tensorboard
+    max_utils.add_text_to_summary_writer("num_model_parameters", str(num_model_parameters), writer)
+    max_utils.add_text_to_summary_writer("libtpu_init_args", os.environ["LIBTPU_INIT_ARGS"], writer)
+    max_utils.add_config_to_summary_writer(self.config, writer)
+
+    # Define the compilation of functional_train, either by loading the compiled version or wrapping a new one in a jit
+    if self.config.compiled_trainstep_file != "":
+      print("Loading the compiled function...", flush=True)
+      # Need to pass train signature and state to determine i/o shapes of train_state for now.
+      p_train_step = maxtext_utils.load_compiled(self.config, functional_train, state)
+      # TODO: p_eval_step is not yet supported in load_compiled
+      p_eval_step = None
+      print("Loaded compiled function!", flush=True)
+    else:
+      p_train_step = jax.jit(
+          functional_train,
+          in_shardings=in_shard_train,
+          out_shardings=out_shard_train,
+          static_argnums=static_argnums_train,
+          donate_argnums=donate_argnums_train,
+      )
+
+      if eval_data_iterator:
+        p_eval_step = jax.jit(
+            functional_eval,
+            in_shardings=in_shard_eval,
+            out_shardings=out_shard_eval,
+            static_argnums=static_argnums_eval,
+            donate_argnums=donate_argnums_eval,
+        )
+      else:
+        p_eval_step = None
+
+    local_metrics_file = open(self.config.metrics_file, "a", encoding="utf8") if self.config.metrics_file else None
+    running_gcs_metrics = [] if self.config.gcs_metrics else None
+
+    start_step = get_first_step(state)  # this is the start_step for training
+    first_profiling_step = start_step + self.config.skip_first_n_steps_for_profiler
+    if self.config.profiler != "" and first_profiling_step >= self.config.steps:
+      raise ValueError("Profiling requested but initial profiling step set past training final step")
+    last_profiling_step = np.clip(first_profiling_step + self.config.profiler_steps - 1, first_profiling_step, self.config.steps - 1)
+
+    example_batch = None
+    last_step_completion = datetime.datetime.now()
+    failure_timer_start = datetime.datetime.now()
+    prof = profiler.Profiler(self.config)
+    for step in np.arange(start_step, self.config.steps):
+      with self.EnableHeartbeat():
+        if step == first_profiling_step:
+          if self.config.profile_cleanly:
+            jax.block_until_ready(state)  # Block until previous state finishes to start profile cleanly
+          prof.activate()
+
+        with jax.profiler.StepTraceAnnotation("train", step_num=step):
+          record_goodput(recorder, self.config, recorder.record_data_loading_start_time if recorder else None)
+          example_batch = load_next_batch(data_iterator, example_batch, self.config)
+          record_goodput(recorder, self.config, recorder.record_data_loading_end_time if recorder else None)
+          check_example_batch(self.config, example_batch=example_batch)
+          # pylint: disable=not-callable
+          nextrng = jax.jit(jax.random.fold_in)(init_rng, step)
+          record_goodput(recorder, self.config, recorder.record_step_start_time if recorder else None, step)
+          with mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
+            state, metrics = p_train_step(state, example_batch, nextrng)
+
+        new_time = datetime.datetime.now()
+        record_scalar_metrics(
+            metrics, new_time - last_step_completion, per_device_tflops, learning_rate_schedule(step), per_device_tokens
+        )
+        last_step_completion = new_time
+
+        if checkpoint_manager is not None:
+          if save_checkpoint(checkpoint_manager, int(step), state, self.config.dataset_type, data_iterator, self.config):
+            max_logging.log(f"saved a checkpoint at step {step}")
+
+          # Upon preemption, exit when and only when all ongoing saves are complete.
+          if checkpoint_manager.reached_preemption(step):
+            checkpoint_manager.wait_until_finished()
+            sys.exit()
+
+        write_metrics(writer, local_metrics_file, running_gcs_metrics, metrics, step, self.config)
+
+        if self.config.eval_interval > 0 and step > start_step and (step + 1) % self.config.eval_interval == 0:
+          assert eval_data_iterator
+          cumulative_eval_metrics = {
+              "scalar": {
+                  "eval/total_loss": 0.0,
+                  "eval/total_weights": 0.0,
+                  "eval/avg_loss": 0.0,
+                  "eval/moe_lb_loss": 0.0,
+              }
+          }
+          eval_step_count = 0
+          # pylint: disable=not-callable
+          for eval_batch in eval_data_iterator:
+            if self.config.eval_steps > 0 and eval_step_count >= self.config.eval_steps:
+              break
+            with mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
+              eval_metrics = p_eval_step(state, eval_batch, nextrng)
+            cumulative_eval_metrics["scalar"]["eval/total_loss"] += float(eval_metrics["scalar"]["evaluation/total_loss"])
+            cumulative_eval_metrics["scalar"]["eval/total_weights"] += float(eval_metrics["scalar"]["evaluation/total_weights"])
+            cumulative_eval_metrics["scalar"]["eval/moe_lb_loss"] += float(eval_metrics["scalar"]["evaluation/moe_lb_loss"])
+            max_logging.log(f"Completed eval step {eval_step_count}")
+            eval_step_count += 1
+          eval_loss = (
+              cumulative_eval_metrics["scalar"]["eval/total_loss"]
+              / (cumulative_eval_metrics["scalar"]["eval/total_weights"] + EPS)
+              + cumulative_eval_metrics["scalar"]["eval/moe_lb_loss"] / eval_step_count
+          )
+          cumulative_eval_metrics["scalar"]["eval/avg_loss"] = eval_loss
+          write_metrics(
+              writer, local_metrics_file, running_gcs_metrics, cumulative_eval_metrics, step, self.config, is_training=False
+          )
+          max_logging.log(
+              f"average loss after {step=}: {eval_step_count=}, {eval_loss=}, total_weights={cumulative_eval_metrics['scalar']['eval/total_weights']}"
+          )
+          if eval_loss <= self.config.target_eval_loss:
+            max_logging.log(f"Early stop and exit loop after reaching {self.config.target_eval_loss=}")
+            prof.deactivate()
+            break
+
+        if step == last_profiling_step:
+          if self.config.profile_cleanly:
+            jax.block_until_ready(state)  # Block until current state finishes to end profile cleanly
+          prof.deactivate()
+
+        current_time = datetime.datetime.now()
+        time_since_failure_sim = (current_time - failure_timer_start).total_seconds()
+        if time_since_failure_sim >= self.config.failure_sim_time:
+          self._fail(self.config.crash_prob, self.config.hang_prob)
+
+    if checkpoint_manager is not None:
+      checkpoint_manager.wait_until_finished()
+    write_metrics(writer, local_metrics_file, running_gcs_metrics, metrics, self.config.steps - 1, self.config)  # final step metrics
+    max_utils.close_summary_writer(writer)
+    record_goodput(recorder, self.config, recorder.record_job_end_time if recorder else None)
+    clear_buffered_metrics()
+    return state
+
+  def run(self):
+    with diagnostic.diagnose(self.diagnostic_config):
+      self._train_loop()
+
+def main(argv: Sequence[str]) -> None:
+  ray.init(address='auto', logging_level=0)
+
+  hang_time_threshold = None
+  # Get hang time threshold
+  for arg in argv:
+    if arg.startswith('hang_time_threshold='):
+      hang_time_threshold = int(arg.split('=')[1])
+      break
+
+  cluster_coordinator = ray_cluster.RayClusterCoordinator(MaxtextTrainer, hang_time_threshold=hang_time_threshold)
+  cluster_coordinator.initialize_workers(maxtext_args=argv)
+  cluster_coordinator.log("Initialized workers")
+  asyncio.run(cluster_coordinator.run())
+
+
+if __name__ == "__main__":
+  app.run(main)
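
Putting the pieces together, the flow this diff implies is: a Ray cluster whose worker nodes expose a `worker_units` custom resource, with `NGPUS`, `GPUS_PER_NODE`, and `REDIS_ADDR` exported; a job submitted through `launch_ray_maxtext.py`; and `RayClusterCoordinator` supervising the actors, restarting them on crashes or stale heartbeats. A hedged end-to-end sketch with a toy worker standing in for `MaxtextTrainer` (`EchoWorker` is ours, not part of the diff):

```python
import asyncio
import ray
import ray_cluster

@ray.remote
class EchoWorker(ray_cluster.ResilientWorker):
  """Toy stand-in for MaxtextTrainer; just exercises the plumbing."""

  def initialize(self, coordinator_addr, num_processes, **kwargs):
    super().initialize(coordinator_addr, num_processes)

  def run(self):
    with self.EnableHeartbeat():
      return f"worker {self.get_process_id()} on {self.get_host_ip()}"

# Assumes NGPUS, GPUS_PER_NODE and REDIS_ADDR are exported, and that worker
# nodes were started with a "worker_units" custom resource.
ray.init(address="auto")
coordinator = ray_cluster.RayClusterCoordinator(EchoWorker, hang_time_threshold=300)
coordinator.initialize_workers()
print(asyncio.run(coordinator.run()))
```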