diff --git a/init2winit/base_callback.py b/init2winit/base_callback.py index 09f435d7..65588751 100644 --- a/init2winit/base_callback.py +++ b/init2winit/base_callback.py @@ -20,7 +20,7 @@ callback_builder = callbacks.get_callback(config['callback_name']) callback = callback_builder(model, params, batch_stats, optimizer_state, - dataset, hps, config, train_dir, rng) + dataset, hps, config, train_dir, rng, mesh) callback_metrics = callback.run_eval(params, batch_stats, optimizer_state, global_step). @@ -39,7 +39,7 @@ class BaseCallBack: def __init__(self, model, params, batch_stats, optimizer_state, optimizer_update_fn, dataset, hps, callback_config, train_dir, - rng): + rng, mesh): """Defines the API for callback construction.""" pass diff --git a/init2winit/checkpoint.py b/init2winit/checkpoint.py index 46500fc1..f9bc0c8a 100644 --- a/init2winit/checkpoint.py +++ b/init2winit/checkpoint.py @@ -20,12 +20,11 @@ """ import os import sys -from typing import Sequence from absl import flags from absl import logging -from flax import jax_utils from flax.training import checkpoints as flax_checkpoints +from init2winit.dataset_lib import data_utils import jax FLAGS = flags.FLAGS @@ -44,47 +43,12 @@ def load_pytree(pytree_file, orbax_checkpointer=None): return None -def replicate_checkpoint( - latest, - pytree_keys: Sequence[str], - replicate=True): - """Restores from the provided checkpoint. - - Args: - latest: A dict representing the state of the - checkpoint we want to restore. - pytree_keys: A sequence of keys into `latest` that are pytrees, which will - be replicated if replicate=True. - replicate: If set, replicate the state across devices. - - Returns: - Tuple of (pytree, extra_dict) where pytree is a JAX pytree holding the - arrays that need to be replicated/unreplicated and extra_dict holds any - additional python state. We expect extra_dict to have the keys of - 'global_step', 'preemption_count', 'sum_train_cost', but old checkpoints - might be missing something. - """ - logging.info('Loaded model parameters from latest checkpoint.') - # Old checkpoints without 'sum_train_cost' can still be restored, but the - # train() function will break. Evals and curvature stuff should be fine, - # however. - expected = ['global_step', 'preemption_count', 'sum_train_cost'] - if any(k not in latest for k in expected): - logging.warn('Checkpoint state missing keys, obtained %s expected %s', - list(latest.keys()), expected) - - pytree = {k: latest[k] for k in pytree_keys} - if replicate: - pytree = jax_utils.replicate(pytree) - extra_dict = {k: latest[k] for k in latest.keys() if k not in pytree_keys} - return pytree, extra_dict - - def replicate_and_maybe_restore_checkpoint( unreplicated_optimizer_state, unreplicated_params, unreplicated_batch_stats, unreplicated_training_metrics_state, + mesh, train_dir, external_checkpoint_path=None, orbax_checkpointer=None): @@ -104,6 +68,7 @@ def replicate_and_maybe_restore_checkpoint( unreplicated_params: unreplicated params unreplicated_batch_stats: unreplicated batch stats unreplicated_training_metrics_state: unreplicated metrics state + mesh: Mesh specification to use for sharding. train_dir: (str) The training directory where we will look for a checkpoint. external_checkpoint_path: (str) If this argument is set, then we will load the external checkpoint stored there. @@ -165,43 +130,34 @@ def replicate_and_maybe_restore_checkpoint( # Handle failure to load from external_checkpoint_path. if ckpt_to_return['global_step'] == -1: return ( - jax_utils.replicate(unreplicated_optimizer_state), - jax_utils.replicate(unreplicated_params), - jax_utils.replicate(unreplicated_batch_stats), - jax_utils.replicate(unreplicated_training_metrics_state), + data_utils.shard_pytree(unreplicated_optimizer_state, mesh), + data_utils.shard_pytree(unreplicated_params, mesh), + data_utils.shard_pytree(unreplicated_batch_stats, mesh), + data_utils.shard_pytree(unreplicated_training_metrics_state, mesh), 0, # global_step - jax_utils.replicate(0), # sum_train_cost + 0, # sum_train_cost 0, # preemption_count False) # is_restored else: # Else, don't restore from any checkpoint. return ( - jax_utils.replicate(unreplicated_optimizer_state), - jax_utils.replicate(unreplicated_params), - jax_utils.replicate(unreplicated_batch_stats), - jax_utils.replicate(unreplicated_training_metrics_state), + data_utils.shard_pytree(unreplicated_optimizer_state, mesh), + data_utils.shard_pytree(unreplicated_params, mesh), + data_utils.shard_pytree(unreplicated_batch_stats, mesh), + data_utils.shard_pytree(unreplicated_training_metrics_state, mesh), 0, # global_step - jax_utils.replicate(0), # sum_train_cost + 0, # sum_train_cost 0, # preemption_count False) # is_restored - pytree_dict, extra_state = replicate_checkpoint( - ckpt_to_return, - pytree_keys=[ - 'optimizer_state', - 'params', - 'batch_stats', - 'training_metrics_grabber', - 'sum_train_cost', - ]) return ( - pytree_dict['optimizer_state'], - pytree_dict['params'], - pytree_dict['batch_stats'], - pytree_dict['training_metrics_grabber'], - extra_state['global_step'], - pytree_dict['sum_train_cost'], - extra_state['preemption_count'], - is_restored) + data_utils.shard_pytree(ckpt_to_return['optimizer_state'], mesh), + data_utils.shard_pytree(ckpt_to_return['params'], mesh), + data_utils.shard_pytree(ckpt_to_return['batch_stats'], mesh), + data_utils.shard_pytree(ckpt_to_return['training_metrics_grabber'], mesh), + ckpt_to_return['global_step'], # global_step + ckpt_to_return['sum_train_cost'], + ckpt_to_return['preemption_count'], # preemption_count + is_restored) # is_restored def save_unreplicated_checkpoint( @@ -217,14 +173,12 @@ def save_unreplicated_checkpoint( max_to_keep=1): """Saves pytree, step, preemption_count, and sum_train_cost to train_dir.""" logging.info('Saving checkpoint to ckpt_%d', global_step) - unreplicated_optimizer_state = jax.device_get( - jax_utils.unreplicate(optimizer_state)) - unreplicated_params = jax.device_get(jax_utils.unreplicate(params)) - unreplicated_batch_stats = jax.device_get(jax_utils.unreplicate(batch_stats)) + unreplicated_optimizer_state = jax.device_get(optimizer_state) + unreplicated_params = jax.device_get(params) + unreplicated_batch_stats = jax.device_get(batch_stats) unreplicated_training_metrics_state = jax.device_get( - jax_utils.unreplicate(training_metrics_state)) - unreplicated_sum_train_cost = jax.device_get( - jax_utils.unreplicate(sum_train_cost)) + training_metrics_state) + unreplicated_sum_train_cost = jax.device_get(sum_train_cost) state = dict(global_step=global_step, preemption_count=preemption_count, sum_train_cost=unreplicated_sum_train_cost, diff --git a/init2winit/dataset_lib/data_utils.py b/init2winit/dataset_lib/data_utils.py index 53ecaf42..53fabc7b 100644 --- a/init2winit/dataset_lib/data_utils.py +++ b/init2winit/dataset_lib/data_utils.py @@ -16,10 +16,14 @@ """Common code used by different models.""" import collections + +import flax.linen as nn import jax from jax.nn import one_hot +from jax.sharding import PartitionSpec as P import numpy as np + Dataset = collections.namedtuple('Dataset', [ 'train_iterator_fn', 'eval_train_epoch', @@ -143,40 +147,6 @@ def zero_pad(ar, pad_axis): return padded_batch -def shard(batch, n_devices=None): - """Prepares the batch for pmap by adding a leading n_devices dimension. - - If all the entries are lists, assume they are already divided into n_devices - smaller arrays and stack them for pmapping. If all the entries are arrays, - assume they have leading dimension divisible by n_devices and reshape. - - Args: - batch: A dict of arrays or lists of arrays - n_devices: If None, this will be set to jax.local_device_count(). - - Returns: - Sharded data. - """ - if n_devices is None: - n_devices = jax.local_device_count() - - # TODO(mbadura): Specify a sharding function per dataset instead - # If entries in the batch dict are lists, then the data is already divided - # into n_devices chunks, so we need to stack them. - if all((isinstance(v, list) for v in batch.values())): - assert all(len(v) == n_devices for v in batch.values()) - # transpose a dict of lists to a list of dicts - shards = [{k: v[i] for (k, v) in batch.items()} for i in range(n_devices)] - return jax.tree.map(lambda *vals: np.stack(vals, axis=0), shards[0], - *shards[1:]) - - # Otherwise, the entries are arrays, so just reshape them. - def _shard_array(array): - return array.reshape((n_devices, -1) + array.shape[1:]) - - return jax.tree.map(_shard_array, batch) - - def tf_to_numpy(tfds_data): # Safe because we won't mutate. Avoids an extra copy from tfds. convert_data = lambda x: x._numpy() # pylint: disable=protected-access @@ -187,4 +157,32 @@ def tf_to_numpy(tfds_data): def convert_jax_to_tf_random_seed(jax_prng_key: jax.random.PRNGKey) -> int: tf_seed = jax.random.bits(jax_prng_key) return tf_seed - \ No newline at end of file + + +def make_global_array(local_data, mesh): + """Util to combine per-host batches into a global batch array. + + Args: + local_data: local data batch on host. + mesh: mesh specification to shard the data. + + Returns: + global_array: global data batch. + """ + global_shape = ( + local_data.shape[0] * jax.process_count(), + *local_data.shape[1:], + ) + sharding = jax.NamedSharding(mesh, P('devices')) + + global_array = jax.make_array_from_process_local_data( + sharding, local_data, global_shape + ) + return global_array + + +def shard_pytree(pytree, mesh): + shardings = nn.get_sharding(pytree, mesh) + pytree = jax.device_put(pytree, shardings) + + return shardings, pytree diff --git a/init2winit/dataset_lib/ogbg_molpcba.py b/init2winit/dataset_lib/ogbg_molpcba.py index edb5748b..9d7e9d64 100644 --- a/init2winit/dataset_lib/ogbg_molpcba.py +++ b/init2winit/dataset_lib/ogbg_molpcba.py @@ -208,8 +208,7 @@ def _get_batch_iterator(dataset_iter, num_shards: How many devices we should be able to shard the batch into. Yields: - Batch in the init2winit format. Each field is a list of num_shards separate - smaller batches. + Batch in the init2winit format. """ if not num_shards: @@ -252,9 +251,9 @@ def _get_batch_iterator(dataset_iter, if count == num_shards: yield { - 'inputs': graphs_shards, - 'targets': labels_shards, - 'weights': weights_shards + 'inputs': jraph.batch(graphs_shards), + 'targets': np.vstack(labels_shards), + 'weights': np.vstack(weights_shards) } count = 0 diff --git a/init2winit/dataset_lib/test_ogbg_molpcba.py b/init2winit/dataset_lib/test_ogbg_molpcba.py index 70d9aef7..6c047b56 100644 --- a/init2winit/dataset_lib/test_ogbg_molpcba.py +++ b/init2winit/dataset_lib/test_ogbg_molpcba.py @@ -117,10 +117,9 @@ def test_get_batch_pads_correctly(self): dataset = _get_dataset(jax.random.PRNGKey(0)) batch = next(dataset.valid_epoch()) - inputs = batch['inputs'][0] + inputs = batch['inputs'] # The first two graphs are in the first batch - self.assertLen(batch['inputs'], 1) self.assertNDArrayNear(inputs.n_node[:2], np.array(NUMS_NODES[:2]), 1e-3) # The graphs are padded to the right size @@ -130,9 +129,9 @@ def test_get_batch_pads_correctly(self): self.assertEqual(np.sum(inputs.n_edge), BATCH_SIZE * EDGES_SIZE_MULTIPLIER) # Weights are zero at NaN labels and in padded examples - self.assertNDArrayNear(batch['weights'][0], + self.assertNDArrayNear(batch['weights'], np.array([[1, 1], [0, 1], [0, 0]]), 1e-3) - self.assertFalse(np.any(np.isnan(batch['targets'][0]))) + self.assertFalse(np.any(np.isnan(batch['targets']))) def test_train_shuffle_is_deterministic(self): """Tests that shuffling of the train split is deterministic.""" @@ -144,19 +143,18 @@ def test_train_shuffle_is_deterministic(self): batch_same = next(dataset_same.train_iterator_fn()) batch_different = next(dataset_different.train_iterator_fn()) - self.assertAllClose(batch['inputs'][0], batch_same['inputs'][0]) - self.assertNotAllClose(batch['inputs'][0], batch_different['inputs'][0]) + self.assertAllClose(batch['inputs'], batch_same['inputs']) + self.assertNotAllClose(batch['inputs'], batch_different['inputs']) def test_add_virtual_node(self): """Tests that adding a virtual node works correctly.""" dataset = _get_dataset(jax.random.PRNGKey(0), {'add_virtual_node': True}) batch = next(dataset.valid_epoch()) - inputs = batch['inputs'][0] + inputs = batch['inputs'] num_nodes = np.array(NUMS_NODES[0]) num_edges = np.array(NUMS_EDGES[0]) - self.assertLen(batch['inputs'], 1) self.assertNDArrayNear( inputs.n_node[0], np.array(num_nodes + 1), 1e-3) self.assertNDArrayNear( @@ -173,11 +171,10 @@ def test_add_bidirectional_edges(self): jax.random.PRNGKey(0), {'add_bidirectional_edges': True}) batch = next(dataset.valid_epoch()) - inputs = batch['inputs'][0] + inputs = batch['inputs'] num_nodes = np.array(NUMS_NODES[0]) num_edges = np.array(NUMS_EDGES[0]) - self.assertLen(batch['inputs'], 1) self.assertNDArrayNear( inputs.n_node[0], np.array(num_nodes), 1e-3) self.assertNDArrayNear( @@ -188,11 +185,10 @@ def test_add_self_loops(self): dataset = _get_dataset(jax.random.PRNGKey(0), {'add_self_loops': True}) batch = next(dataset.valid_epoch()) - inputs = batch['inputs'][0] + inputs = batch['inputs'] num_nodes = np.array(NUMS_NODES[0]) num_edges = np.array(NUMS_EDGES[0]) - self.assertLen(batch['inputs'], 1) self.assertNDArrayNear( inputs.n_node[0], np.array(num_nodes), 1e-3) self.assertNDArrayNear( diff --git a/init2winit/model_lib/base_model.py b/init2winit/model_lib/base_model.py index ea398897..6e78f305 100644 --- a/init2winit/model_lib/base_model.py +++ b/init2winit/model_lib/base_model.py @@ -67,8 +67,8 @@ def _evaluate_batch(flax_module, params, batch_stats, batch, metrics_bundle, # We don't use CLU's `mask` argument here, we handle it ourselves through # `weights`. - return metrics_bundle.gather_from_model_output( - logits=logits, targets=targets, weights=weights, axis_name='batch') + return metrics_bundle.single_from_model_output( + logits=logits, targets=targets, weights=weights) class BaseModel(object): @@ -300,10 +300,6 @@ def training_objective_fn(self, params, logits, targets, weights): logits, targets, weights ) - (objective_numerator, objective_denominator) = jax.lax.psum( - (objective_numerator, objective_denominator), axis_name='batch' - ) - # epsilon added to handle empty batch case if we encounter one. objective_value = objective_numerator / (objective_denominator + 1e-9) if self.hps.get('l2_decay_factor'): diff --git a/init2winit/model_lib/conformer.py b/init2winit/model_lib/conformer.py index e2353dcd..66115654 100644 --- a/init2winit/model_lib/conformer.py +++ b/init2winit/model_lib/conformer.py @@ -855,19 +855,15 @@ def evaluate_batch(self, params, batch_stats, batch): (objective_numerator, objective_denominator) = self.loss_fn( logits, logit_paddings, labels, label_paddings) - (objective_numerator, objective_denominator) = jax.lax.psum( - (objective_numerator, objective_denominator), axis_name='batch') - normalized_loss = (objective_numerator / (objective_denominator)) hyps, hyp_paddings = self.greedy_decode(logits, logit_paddings) - return self.metrics_bundle.gather_from_model_output( + return self.metrics_bundle.single_from_model_output( normalized_loss=normalized_loss, hyps=hyps, hyp_paddings=hyp_paddings, targets=labels, - target_paddings=label_paddings, - axis_name='batch') + target_paddings=label_paddings) def training_cost(self, params, batch, batch_stats=None, dropout_rng=None): """Return CTC loss.""" @@ -891,9 +887,6 @@ def training_cost(self, params, batch, batch_stats=None, dropout_rng=None): (objective_numerator, objective_denominator) = self.loss_fn( outputs, output_paddings, labels, label_paddings) - (objective_numerator, objective_denominator) = jax.lax.psum( - (objective_numerator, objective_denominator), axis_name='batch') - # epsilon added to handle empty batch case if we encounter one. objective_value = (objective_numerator / (objective_denominator + 1e-9)) return objective_value, new_batch_stats diff --git a/init2winit/model_lib/deepspeech.py b/init2winit/model_lib/deepspeech.py index 0ba1c5a9..73626c41 100644 --- a/init2winit/model_lib/deepspeech.py +++ b/init2winit/model_lib/deepspeech.py @@ -405,10 +405,6 @@ def __call__(self, inputs, input_paddings=None, train=False): count_v = jnp.sum( jnp.ones_like(inputs) * mask, axis=reduce_over_dims, keepdims=True) - if self.enable_synced_batchnorm: - sum_v = jax.lax.psum(sum_v, axis_name='batch') - count_v = jax.lax.psum(count_v, axis_name='batch') - count_v = jnp.maximum(count_v, 1.0) mean = sum_v / count_v @@ -417,9 +413,6 @@ def __call__(self, inputs, input_paddings=None, train=False): axis=reduce_over_dims, keepdims=True) - if self.enable_synced_batchnorm: - sum_vv = jax.lax.psum(sum_vv, axis_name='batch') - var = sum_vv / count_v self.ra_mean.value = momentum * self.ra_mean.value + (1 - momentum) * mean @@ -959,20 +952,16 @@ def evaluate_batch(self, params, batch_stats, batch): (objective_numerator, objective_denominator) = self.loss_fn( logits, logit_paddings, labels, label_paddings) - (objective_numerator, objective_denominator) = jax.lax.psum( - (objective_numerator, objective_denominator), axis_name='batch') - # epsilon added to handle empty batch case if we encounter one. normalized_loss = (objective_numerator / (objective_denominator + 1e-9)) hyps, hyp_paddings = self.greedy_decode(logits, logit_paddings) - return self.metrics_bundle.gather_from_model_output( + return self.metrics_bundle.single_from_model_output( normalized_loss=normalized_loss, hyps=hyps, hyp_paddings=hyp_paddings, targets=labels, - target_paddings=label_paddings, - axis_name='batch') + target_paddings=label_paddings) def training_cost(self, params, batch, batch_stats=None, dropout_rng=None): """Return CTC loss.""" @@ -996,9 +985,6 @@ def training_cost(self, params, batch, batch_stats=None, dropout_rng=None): (objective_numerator, objective_denominator) = self.loss_fn( outputs, output_paddings, labels, label_paddings) - (objective_numerator, objective_denominator) = jax.lax.psum( - (objective_numerator, objective_denominator), axis_name='batch') - objective_value = (objective_numerator / (objective_denominator)) return objective_value, new_batch_stats diff --git a/init2winit/model_lib/unet.py b/init2winit/model_lib/unet.py index 950c6c4f..ecffab64 100644 --- a/init2winit/model_lib/unet.py +++ b/init2winit/model_lib/unet.py @@ -333,7 +333,7 @@ def evaluate_batch(self, params, batch_stats, batch): # We don't use CLU's `mask` argument here, we handle it ourselves through # `weights`. - return self.metrics_bundle.gather_from_model_output( + return self.metrics_bundle.single_from_model_output( logits=logits, targets=targets, weights=weights, diff --git a/init2winit/model_lib/xformer_translate.py b/init2winit/model_lib/xformer_translate.py index 7508ac0e..c6ae1ce8 100644 --- a/init2winit/model_lib/xformer_translate.py +++ b/init2winit/model_lib/xformer_translate.py @@ -1045,7 +1045,7 @@ def evaluate_batch(self, params, batch_stats, batch): targets = one_hot(batch['targets'], logits.shape[-1]) # Add log-perplexity metric. - return self.metrics_bundle.gather_from_model_output( + return self.metrics_bundle.single_from_model_output( logits=logits, targets=targets, weights=weights, axis_name='batch') def apply_on_batch(self, @@ -1101,8 +1101,8 @@ def training_cost(self, params, batch, batch_stats=None, dropout_rng=None): (total_loss, total_weight) = self.loss_fn( logits, targets, weights) - (total_loss, total_weight) = lax.psum( - (total_loss, total_weight), axis_name='batch') + # (total_loss, total_weight) = lax.psum( + # (total_loss, total_weight), axis_name='batch') total_loss = (total_loss / total_weight) diff --git a/init2winit/mt_eval/eval_utils.py b/init2winit/mt_eval/eval_utils.py index 0f1dcfb1..4d0a05c8 100644 --- a/init2winit/mt_eval/eval_utils.py +++ b/init2winit/mt_eval/eval_utils.py @@ -57,7 +57,7 @@ def save_evals(ckpt_dir, ckpt_step, eval_split, bleu_score): f.write(str(bleu_score)) -def _load_checkpoint(checkpoint_path, params, replicate=True): +def _load_checkpoint(checkpoint_path, params): """Load model (and batch stats) from checkpoint.""" target = dict( params=params, @@ -68,12 +68,7 @@ def _load_checkpoint(checkpoint_path, params, replicate=True): checkpoint_path, target=target, ) - results = checkpoint.replicate_checkpoint( - ckpt, - pytree_keys=['params'], - replicate=replicate, - ) - params = results[0]['params'] + params = ckpt['params'] return params @@ -81,12 +76,13 @@ def average_checkpoints(checkpoint_paths, params): """Averages a set of checkpoints in input checkpoints.""" assert len(checkpoint_paths) >= 1 # Sum parameters of separate models together. - params = _load_checkpoint(checkpoint_paths[0], params, replicate=False) + params = _load_checkpoint(checkpoint_paths[0], params) for checkpoint_path in checkpoint_paths[1:]: - params_update = _load_checkpoint(checkpoint_path, params, replicate=False) + params_update = _load_checkpoint( + checkpoint_path, params + ) # TODO(dxin): Make this averaging process more numerically stable. - params = jax.tree.map( - lambda x, y: x + y, params, params_update) + params = jax.tree.map(lambda x, y: x + y, params, params_update) # Average checkpoints. params = jax.tree.map(lambda x: x / float(len(checkpoint_paths)), params) diff --git a/init2winit/mt_eval/inference.py b/init2winit/mt_eval/inference.py index 5f0ad410..6c4003cd 100644 --- a/init2winit/mt_eval/inference.py +++ b/init2winit/mt_eval/inference.py @@ -14,6 +14,7 @@ # limitations under the License. r"""BLEU evaluator container class.""" + import copy import dataclasses import functools @@ -21,9 +22,7 @@ from typing import Any, Sequence from absl import logging -from flax import jax_utils -from flax.training import common_utils -from init2winit import utils +from init2winit.dataset_lib import data_utils as utils from init2winit.dataset_lib import mt_tokenizer from init2winit.mt_eval import decode from init2winit.mt_eval import eval_utils @@ -32,7 +31,6 @@ import numpy as np from tensorflow.io import gfile -glob = gfile.glob DEFAULT_EVAL_CONFIG = { 'eval_batch_size': 16, @@ -68,6 +66,7 @@ def __init__(self, *args, **kwargs): if kwargs['mode'] not in ['offline', 'online']: raise ValueError('BLEU score computation only support online or ' 'offline modes.') + self.mesh = kwargs['mesh'] if kwargs['mode'] == 'offline': self.init_offline_evaluator(*args) else: @@ -168,12 +167,12 @@ def initialize_model(self, model_cls, dataset_meta_data, dropout_rng, params = init_dict['params'] self.flax_module = model.flax_module self.params = params - self.pmapped_init_cache = jax.pmap( + self.init_cache = jax.jit( functools.partial( self.initialize_cache, max_length=self.max_length, params_rng=params_rng, - dropout_rng=dropout_rng), axis_name='gather') + dropout_rng=dropout_rng)) def initialize_cache(self, inputs, max_length, params_rng, dropout_rng): """Initialize a cache for a given input shape and max decode length.""" @@ -217,7 +216,7 @@ def build_predictor(self): eos_id=self.eos_id, beam_size=self.mt_eval_config.get('beam_size'), offset=self.mt_eval_config.get('scan_over_layers_offset', 0)) - self.pmapped_predictor = jax.pmap(decoder, static_broadcasted_argnums=()) + self.predictor = jax.jit(decoder) def translate_and_calculate_bleu(self): """Iterate over all checkpoints and calculate BLEU.""" @@ -233,7 +232,7 @@ def translate_and_calculate_bleu(self): params = eval_utils.average_checkpoints( checkpoint_paths=ckpt_paths, params=self.params) - params_replicated = jax_utils.replicate(params) + _, params_replicated = utils.shard_pytree(params, self.mesh) decoding_output = self.translate_and_calculate_bleu_single_model( params_replicated, self.eval_split) logging.info('Sacre bleu score at step %d: %f', step, @@ -246,20 +245,23 @@ def translate_and_calculate_bleu_single_model(self, params, eval_split): self.build_predictor() decode_output = DecodingOutput() logging.info('Starting decoding..') + + make_global_array_fn = functools.partial( + utils.make_global_array, mesh=self.mesh + ) + for batch in self.get_ds_iter(eval_split): - pred_batch = common_utils.shard(batch) - cache = self.pmapped_init_cache(pred_batch['inputs']) - predicted = utils.data_gather( - self.pmapped_predictor(pred_batch, params, cache), - axis_name='gather') - inputs = utils.data_gather(pred_batch['inputs'], axis_name='gather') - targets = utils.data_gather(pred_batch['targets'], axis_name='gather') - weights = utils.data_gather(pred_batch['weights'], axis_name='gather') + pred_batch = jax.tree_util.tree_map(make_global_array_fn, batch) + cache = self.init_cache(pred_batch['inputs']) + predicted = self.predictor(pred_batch, params, cache) + inputs = pred_batch['inputs'] + targets = pred_batch['targets'] + weights = pred_batch['weights'] - predicted = utils.combine_gathered(np.array(predicted)) - inputs = utils.combine_gathered(np.array(inputs)) - targets = utils.combine_gathered(np.array(targets)) - weights = utils.combine_gathered(np.array(weights)) + predicted = np.array(predicted) + inputs = np.array(inputs) + targets = np.array(targets) + weights = np.array(weights) current_batch_size = int(weights[:, 0].sum()) if self.mt_eval_config.get('decoding_type') == 'beam_search': self.process_beam_search_output(inputs, targets, predicted, diff --git a/init2winit/mt_eval/mt_callback.py b/init2winit/mt_eval/mt_callback.py index 1473a554..a8341023 100644 --- a/init2winit/mt_eval/mt_callback.py +++ b/init2winit/mt_eval/mt_callback.py @@ -40,13 +40,13 @@ 'scan_over_layers_offset' equal to the length of that tuple. """ +import functools + from absl import logging from init2winit import base_callback from init2winit import utils -from init2winit.dataset_lib import data_utils from init2winit.dataset_lib import datasets from init2winit.model_lib import models - from init2winit.mt_eval import inference import jax from ml_collections.config_dict import config_dict @@ -74,7 +74,8 @@ def __init__(self, hps, callback_config, train_dir, - rng): + rng, + mesh): del optimizer_state del optimizer_update_fn del train_dir @@ -87,12 +88,14 @@ def __init__(self, self.callback_config = merged_callback_config self._validate_callback_config() - self.evaluate_batch_pmapped = jax.pmap( - model.evaluate_batch, axis_name='batch') + self.evaluate_batch_pmapped = jax.jit( + model.evaluate_batch, donate_argnums=(2,) + ) self.batch_stats = batch_stats dataset, dataset_metadata = self._get_dataset(hps, rng) self.dataset = dataset + self.mesh = mesh model_class = models.get_model(callback_config['model_name']) self.inference_manager = inference.InferenceManager( @@ -102,7 +105,8 @@ def __init__(self, dataset, dataset_metadata, self.callback_config, - mode='online') + mode='online', + mesh=mesh) def _validate_callback_config(self): assert all(key in self.callback_config for key in _REQUIRED_KEYS), ( @@ -137,7 +141,7 @@ def _evaluate(self, params, batch_stats, batch_iter, - evaluate_batch_pmapped): + evaluate_batch_jitted): """Compute aggregated metrics on the given data iterator. This function is taken as is from trainer.py to avoid circular dependency. @@ -148,19 +152,25 @@ def _evaluate(self, batch_stats: A dict of batch_stats. batch_iter: Generator which yields batches. Must support the API for b in batch_iter: - evaluate_batch_pmapped: A function with API - evaluate_batch_pmapped(params, batch_stats, batch). Returns a dictionary - mapping keys to the metric values across the sharded batch. + evaluate_batch_jitted: A function with API evaluate_batch_jitted(params, + batch_stats, batch). Returns a dictionary mapping keys to the metric + values across the sharded batch. Returns: A dictionary of aggregated metrics. The keys will match the keys returned - by evaluate_batch_pmapped. + by evaluate_batch_jitted. """ metrics = None + make_global_array_fn = functools.partial( + utils.make_global_array, mesh=self.mesh + ) + for batch in batch_iter: - batch = data_utils.shard(batch) - computed_metrics = evaluate_batch_pmapped( - params=params, batch_stats=batch_stats, batch=batch) + batch = utils.maybe_remove_leading_dimension(batch) + batch = jax.tree_util.tree_map(make_global_array_fn, batch) + computed_metrics = evaluate_batch_jitted( + params=params, batch_stats=batch_stats, batch=batch + ) if metrics is None: metrics = computed_metrics else: @@ -169,7 +179,7 @@ def _evaluate(self, # For data splits with no data (e.g. Imagenet no test set) no values # will appear for that split. if metrics is not None: - metrics = metrics.unreplicate().compute() + metrics = metrics.compute() for key, val in metrics.items(): if np.isnan(val): raise utils.TrainingDivergedError('NaN detected in {}'.format(key)) @@ -191,7 +201,8 @@ def _merge_and_apply_prefix(self, d1, d2, prefix): d1[prefix+key] = d2[key] return d1 - def run_eval(self, params, batch_stats, optimizer_state, global_step): + def run_eval( + self, params, batch_stats, optimizer_state, global_step): """Runs the MT models to evals specified by MT model. Args: @@ -230,7 +241,7 @@ def run_eval(self, params, batch_stats, optimizer_state, global_step): self.inference_manager.translate_and_calculate_bleu_single_model( params, split_name)) split_metrics = self._evaluate(params, batch_stats, split_iter, - self.evaluate_batch_pmapped) + self.evaluate_batch_jitted) split_metrics['bleu_score'] = decoding_output.bleu_score metrics = self._merge_and_apply_prefix( diff --git a/init2winit/optimizer_lib/gradient_accumulator.py b/init2winit/optimizer_lib/gradient_accumulator.py index b4190949..3033a7f5 100644 --- a/init2winit/optimizer_lib/gradient_accumulator.py +++ b/init2winit/optimizer_lib/gradient_accumulator.py @@ -64,7 +64,6 @@ def accumulate_gradients( virtual_batch_size: Optional[int], base_opt_init_fn: optax.TransformInitFn, base_opt_update_fn: optax.TransformUpdateFn, - batch_axis_name: Optional[str] = None, ) -> optax.GradientTransformationExtraArgs: """Accumulate gradients. @@ -80,9 +79,6 @@ def accumulate_gradients( generate updates given the total gradient. base_opt_update_fn: The update function for the base optimizer used to generate updates given the total gradient. - batch_axis_name: the name of the axis to pmap over. Used to run a pmean - before applying the optimizer update. - Returns: An (init_fn, update_fn) tuple. """ @@ -124,11 +120,6 @@ def total_batch_update(total_gradients, params, state): # batches. total_gradients = jax.tree.map( lambda x: x / steps_per_update, total_gradients) - if batch_axis_name: - # We only sync gradients when we are about to update the model, in order - # to avoid unnecessary cross replica communications. - total_gradients = jax.lax.pmean( - total_gradients, axis_name=batch_axis_name) updates, updated_base_state = base_opt_update_fn( total_gradients, state.base_state, params=params, **extra_args) diff --git a/init2winit/optimizer_lib/optimizers.py b/init2winit/optimizer_lib/optimizers.py index 2120fae8..3de50d82 100644 --- a/init2winit/optimizer_lib/optimizers.py +++ b/init2winit/optimizer_lib/optimizers.py @@ -218,66 +218,9 @@ def get_optimizer(hps, model=None, batch_axis_name=None): ) elif hps.optimizer == 'distributed_shampoo': - if hps.opt_hparams.get('frequent_directions', False): - statistics_compute_steps = hps.opt_hparams[ - 'preconditioning_compute_steps'] - else: - statistics_compute_steps = hps.opt_hparams['statistics_compute_steps'] - # pylint: disable=line-too-long - opt_init, opt_update = utils.static_inject_hyperparams( - distributed_shampoo.distributed_shampoo - )( - learning_rate=0.0, - block_size=hps.opt_hparams['block_size'], - beta1=hps.opt_hparams['beta1'], - beta2=hps.opt_hparams['beta2'], - diagonal_epsilon=hps.opt_hparams['diagonal_epsilon'], - matrix_epsilon=hps.opt_hparams['matrix_epsilon'], - weight_decay=hps.opt_hparams['weight_decay'], - start_preconditioning_step=hps - .opt_hparams['start_preconditioning_step'], - preconditioning_compute_steps=hps - .opt_hparams['preconditioning_compute_steps'], - decay_preconditioning_compute_steps=hps - .opt_hparams.get('decay_preconditioning_compute_steps', False), - end_preconditioning_compute_steps=hps - .opt_hparams.get('end_preconditioning_compute_steps', None), - statistics_compute_steps=statistics_compute_steps, - best_effort_shape_interpretation=hps - .opt_hparams['best_effort_shape_interpretation'], - nesterov=hps.opt_hparams['nesterov'], - exponent_override=hps.opt_hparams['exponent_override'], - batch_axis_name=batch_axis_name, - graft_type=hps.opt_hparams['graft_type'], - num_devices_for_pjit=hps.opt_hparams['num_devices_for_pjit'], - shard_optimizer_states=hps.opt_hparams['shard_optimizer_states'], - best_effort_memory_usage_reduction=hps - .opt_hparams['best_effort_memory_usage_reduction'], - inverse_failure_threshold=hps.opt_hparams['inverse_failure_threshold'], - moving_average_for_momentum=hps - .opt_hparams['moving_average_for_momentum'], - skip_preconditioning_dim_size_gt=hps - .opt_hparams['skip_preconditioning_dim_size_gt'], - relative_matrix_epsilon=hps.opt_hparams.get('relative_matrix_epsilon', - True), - clip_by_scaled_gradient_norm=hps - .opt_hparams['clip_by_scaled_gradient_norm'], - merge_small_dims_block_size=hps.opt_hparams.get( - 'merge_small_dims_block_size', 4096), - generate_fd_metrics=hps.opt_hparams.get('generate_fd_metrics', False), - compression_rank=hps.opt_hparams.get('compression_rank', 0), - frequent_directions=hps.opt_hparams.get('frequent_directions', False), - average_grad=hps.opt_hparams.get('average_grad', False), - eigh=hps.opt_hparams.get('eigh', False), - skip_preconditioning_rank_lt=hps.opt_hparams.get( - 'skip_preconditioning_rank_lt', 1), - decoupled_learning_rate=hps.opt_hparams.get('decoupled_learning_rate', - True), - decoupled_weight_decay=hps.opt_hparams.get('decoupled_weight_decay', - False), - generate_training_metrics=hps.opt_hparams.get( - 'generate_training_metrics', True), - reuse_preconditioner=hps.opt_hparams.get('reuse_preconditioner', False), + raise ValueError( + 'distributed_shampoo implementation is broken in init2winit after we' + ' migrated to jit, do not use it for the time being.' ) # pylint: enable=line-too-long elif hps.optimizer == 'adam': @@ -414,8 +357,7 @@ def get_optimizer(hps, model=None, batch_axis_name=None): total_batch_size=hps.total_accumulated_batch_size, virtual_batch_size=virtual_batch_size, base_opt_init_fn=opt_init, - base_opt_update_fn=opt_update, - batch_axis_name=batch_axis_name) + base_opt_update_fn=opt_update) if hps.opt_hparams.get('use_sam', False): opt_init, opt_update = sharpness_aware_minimization.sharpness_aware_minimization( diff --git a/init2winit/optimizer_lib/test_optimizers.py b/init2winit/optimizer_lib/test_optimizers.py index 5c47b636..b3ea16e0 100644 --- a/init2winit/optimizer_lib/test_optimizers.py +++ b/init2winit/optimizer_lib/test_optimizers.py @@ -19,20 +19,13 @@ import tempfile from absl.testing import absltest -from init2winit import hyperparameters -from init2winit import utils -from init2winit.dataset_lib import datasets -from init2winit.init_lib import initializers -from init2winit.model_lib import models from init2winit.optimizer_lib import optimizers from init2winit.optimizer_lib import utils as optimizers_utils -from init2winit.trainer_lib import trainer -import jax -from jax import lax from ml_collections import config_dict # import pandas # import tensorflow.compat.v1 as tf +# TODO(b/385225663): add test for nadamw. class OptimizersTrainerTest(absltest.TestCase): @@ -47,169 +40,6 @@ def tearDown(self): shutil.rmtree(self.test_dir) super().tearDown() - def test_shampoo_wrn(self): - """Test distributed shampoo on fake dataset.""" - model_name = 'simple_cnn' - model_cls = models.get_model(model_name) - hparam_overrides = { - 'optimizer': 'distributed_shampoo', - 'batch_size': 1, - 'train_size': 10, - 'valid_size': 10, - 'input_shape': (32, 32, 3), - 'output_shape': (10,), - 'opt_hparams': { - 'block_size': 32, - 'beta1': 0.9, - 'beta2': 0.999, - 'diagonal_epsilon': 1e-10, - 'matrix_epsilon': 1e-6, - 'weight_decay': 0.0, - 'start_preconditioning_step': 5, - 'preconditioning_compute_steps': 1, - 'statistics_compute_steps': 1, - 'best_effort_shape_interpretation': True, - 'graft_type': distributed_shampoo.GraftingType.SGD, - 'nesterov': True, - 'exponent_override': 0, - 'batch_axis_name': 'batch', - 'num_devices_for_pjit': None, - 'shard_optimizer_states': False, - 'inverse_failure_threshold': 0.1, - 'clip_by_scaled_gradient_norm': None, - 'precision': lax.Precision.HIGHEST, - 'moving_average_for_momentum': False, - 'skip_preconditioning_dim_size_gt': 4096, - 'best_effort_memory_usage_reduction': False, - }, - } - input_pipeline_hps = config_dict.ConfigDict(dict( - num_tf_data_prefetches=-1, - num_device_prefetches=0, - num_tf_data_map_parallel_calls=-1, - )) - hps = hyperparameters.build_hparams( - model_name, - initializer_name='noop', - dataset_name='fake', - hparam_file=None, - hparam_overrides=hparam_overrides, - input_pipeline_hps=input_pipeline_hps) - initializer = initializers.get_initializer('noop') - dataset_builder = datasets.get_dataset('fake') - dataset = dataset_builder( - shuffle_rng=jax.random.PRNGKey(0), - batch_size=hps.batch_size, - eval_batch_size=hps.batch_size, - hps=hps) - - loss_name = 'cross_entropy' - metrics_name = 'classification_metrics' - dataset_meta_data = datasets.get_dataset_meta_data('fake') - model = model_cls(hps, dataset_meta_data, loss_name, metrics_name) - - metrics_logger, init_logger = utils.set_up_loggers(self.test_dir) - self.trainer = trainer.Trainer( - train_dir=self.test_dir, - model=model, - dataset_builder=lambda *unused_args, **unused_kwargs: dataset, - initializer=initializer, - num_train_steps=1, - hps=hps, - rng=jax.random.PRNGKey(12), - eval_batch_size=hps.batch_size, - eval_use_ema=False, - eval_num_batches=None, - test_num_batches=0, - eval_train_num_batches=None, - eval_frequency=10, - checkpoint_steps=[], - metrics_logger=metrics_logger, - init_logger=init_logger, - ) - _ = list(self.trainer.train()) - - # TODO(b/373658570) - # NOTE(levskaya): this test is -wildly- sensitive to trainer PRNG key. - # with tf.io.gfile.GFile(os.path.join(self.test_dir, - # 'measurements.csv')) as f: - # df = pandas.read_csv(f) - # valid_ce_loss = df['valid/ce_loss'].values[-1] - # self.assertLess(valid_ce_loss, 1e-3) - - def test_clip_raises_when_no_aggregation(self): - """Test that gradient clipping raises when no gradient aggregation.""" - model_name = 'wide_resnet' - model_cls = models.get_model(model_name) - hparam_overrides = { - 'grad_clip': 0.1, - 'total_accumulated_batch_size': 1024, # Use gradient accumulation. - } - input_pipeline_hps = config_dict.ConfigDict(dict( - num_tf_data_prefetches=-1, - num_device_prefetches=0, - num_tf_data_map_parallel_calls=-1, - )) - hps = hyperparameters.build_hparams( - model_name, - initializer_name='noop', - dataset_name='fake', - hparam_file=None, - hparam_overrides=hparam_overrides, - input_pipeline_hps=input_pipeline_hps) - initializer = initializers.get_initializer('noop') - dataset_builder = datasets.get_dataset('fake') - dataset = dataset_builder( - shuffle_rng=jax.random.PRNGKey(0), - batch_size=hps.batch_size, - eval_batch_size=hps.batch_size, - hps=hps) - - loss_name = 'cross_entropy' - metrics_name = 'classification_metrics' - dataset_meta_data = datasets.get_dataset_meta_data('fake') - model = model_cls(hps, dataset_meta_data, loss_name, metrics_name) - - self.trainer = trainer.Trainer( - train_dir=self.test_dir, - model=model, - dataset_builder=lambda *unused_args, **unused_kwargs: dataset, - initializer=initializer, - num_train_steps=10, - hps=hps, - rng=jax.random.PRNGKey(42), - eval_batch_size=hps.batch_size, - eval_use_ema=False, - eval_num_batches=None, - test_num_batches=0, - eval_train_num_batches=None, - eval_frequency=10, - checkpoint_steps=[], - ) - with self.assertRaises(NotImplementedError): - _ = list(self.trainer.train()) - - -class OptimizersTest(absltest.TestCase): - """Tests for optimizers.py.""" - - def test_no_cross_device_gradient_aggregation(self): - """Test that no_cross_device_gradient_aggregation propagates correctly.""" - _, update_fn = optimizers.get_optimizer( - config_dict.ConfigDict({ - 'optimizer': 'adam', - 'l2_decay_factor': None, - 'batch_size': 50, - 'total_accumulated_batch_size': 100, # Use gradient accumulation. - 'opt_hparams': { - 'beta1': 0.9, - 'beta2': 0.999, - 'epsilon': 1e-7, - 'weight_decay': 0.0, - } - })) - # The gradient accumulation performs gradient aggregation internally. - self.assertFalse(optimizers_utils.requires_gradient_aggregation(update_fn)) if __name__ == '__main__': diff --git a/init2winit/test_checkpoint.py b/init2winit/test_checkpoint.py index a499ddc0..bb528874 100644 --- a/init2winit/test_checkpoint.py +++ b/init2winit/test_checkpoint.py @@ -24,16 +24,17 @@ from absl import flags from absl.testing import absltest from absl.testing import parameterized -from flax import jax_utils from init2winit import checkpoint from init2winit.model_lib import models from init2winit.shared_test_utilities import pytree_equal +from jax.experimental import mesh_utils import jax.numpy as jnp import jax.tree_util import numpy as np import orbax.checkpoint as orbax_checkpoint from tensorflow.io import gfile + FLAGS = flags.FLAGS INPUT_SHAPE = [10, 28, 28, 1] @@ -63,6 +64,12 @@ def setUp(self): orbax_checkpoint.PyTreeCheckpointHandler(), timeout_secs=60) self.params = init_dict['params'] + mesh_shape = (jax.device_count(),) + self.mesh = jax.sharding.Mesh( + mesh_utils.create_device_mesh(mesh_shape, devices=jax.devices()), + axis_names=('devices',), + ) + def tearDown(self): shutil.rmtree(self.test_dir) super(CheckpointTest, self).tearDown() @@ -146,25 +153,39 @@ def test_all_variables_restored(self): orbax_checkpointer=self.orbax_checkpointer, max_to_keep=1) - (ret_state, ret_params, ret_batch_stats, ret_training_metrics, - ret_global_step, ret_sum_train_cost, ret_preemption_count, ret_is_restored, - ) = checkpoint.replicate_and_maybe_restore_checkpoint( - initial_optimizer_state, initial_params, initial_batch_stats, - initial_training_metrics, fresh_train_dir, - orbax_checkpointer=self.orbax_checkpointer) + ( + (_, ret_state), + (_, ret_params), + (_, ret_batch_stats), + (_, ret_training_metrics), + ret_global_step, + ret_sum_train_cost, + ret_preemption_count, + ret_is_restored, + ) = checkpoint.replicate_and_maybe_restore_checkpoint( + initial_optimizer_state, + initial_params, + initial_batch_stats, + initial_training_metrics, + self.mesh, + fresh_train_dir, + orbax_checkpointer=self.orbax_checkpointer, + ) assert pytree_equal( - jax.device_get(jax_utils.unreplicate(ret_state)), - saved_optimizer_state) + jax.device_get(ret_state), saved_optimizer_state + ) assert pytree_equal( - jax.device_get(jax_utils.unreplicate(ret_params)), - saved_params) + jax.device_get(ret_params), saved_params + ) assert pytree_equal( - jax.device_get(jax_utils.unreplicate(ret_batch_stats)), - saved_batch_stats) + jax.device_get(ret_batch_stats), + saved_batch_stats, + ) assert pytree_equal( - jax.device_get(jax_utils.unreplicate(ret_training_metrics)), - saved_training_metrics) + jax.device_get(ret_training_metrics), + saved_training_metrics, + ) self.assertEqual(ret_sum_train_cost, sum_train_cost) self.assertEqual(ret_preemption_count, preemption_count) self.assertEqual(ret_global_step, global_step) @@ -217,13 +238,13 @@ def save_checkpoint(train_dir, global_step, preemption_count, def maybe_restore_checkpoint(params, train_dir, external_checkpoint_path): """Helper function to replicate_and_maybe_restore a checkpoint.""" - (_, ret_params, _, _, + (_, (_, ret_params), _, _, ret_global_step, ret_sum_train_cost, ret_preemption_count, ret_is_restored) = checkpoint.replicate_and_maybe_restore_checkpoint( - {}, params, {}, {}, train_dir, external_checkpoint_path, + {}, params, {}, {}, self.mesh, train_dir, external_checkpoint_path, orbax_checkpointer=self.orbax_checkpointer) - ret_params_unrep = jax.device_get(jax_utils.unreplicate(ret_params)) + ret_params_unrep = jax.device_get(ret_params) return (ret_params_unrep, ret_global_step, ret_sum_train_cost, ret_preemption_count, ret_is_restored) diff --git a/init2winit/trainer_lib/base_trainer.py b/init2winit/trainer_lib/base_trainer.py index e2a0700a..112f9b56 100644 --- a/init2winit/trainer_lib/base_trainer.py +++ b/init2winit/trainer_lib/base_trainer.py @@ -21,7 +21,6 @@ import time from absl import logging -from flax import jax_utils from init2winit import callbacks from init2winit import checkpoint from init2winit import schedules @@ -31,10 +30,13 @@ from init2winit.trainer_lib import trainer_utils from init2winit.training_metrics_grabber import make_training_metrics import jax +from jax.experimental import mesh_utils import numpy as np import optax import orbax.checkpoint as orbax_checkpoint +NamedSharding = jax.sharding.NamedSharding + class BaseTrainer(metaclass=abc.ABCMeta): """Abstract parent class for all trainers.""" @@ -200,13 +202,20 @@ def __init__( # During eval, we can donate the 'batch' buffer. We don't donate the # 'params' and 'batch_stats' buffers as we don't re-assign those values in # eval, we do that only in train. - self._evaluate_batch_pmapped = jax.pmap( - self._model.evaluate_batch, axis_name='batch', donate_argnums=(2,)) + self._evaluate_batch_jitted = jax.jit( + self._model.evaluate_batch, donate_argnums=(2,)) # Numpy array of range(0, local_device_count) to send to each device to be # folded into the RNG inside each train step to get a unique per-device RNG. self._local_device_indices = np.arange(jax.local_device_count()) + # Creates a 1-d mesh with all devices available globally. + mesh_shape = (jax.device_count(),) + self._mesh = jax.sharding.Mesh( + mesh_utils.create_device_mesh(mesh_shape, devices=jax.devices()), + axis_names=('devices',), + ) + def wait_until_orbax_checkpointer_finished(self): self._orbax_checkpointer.wait_until_finished() @@ -224,7 +233,7 @@ def setup_and_maybe_restore(self, init_rng, data_rng, trainer_update_fn): data_rng: the jax PRNGKey used for dataset randomness. Should be *different* across hosts! trainer_update_fn: the function for updating the model. If None, this will - skip pmapping the update function. + skip jitting the update function. Returns: A long tuple of the following: @@ -232,10 +241,14 @@ def setup_and_maybe_restore(self, init_rng, data_rng, trainer_update_fn): optimizer_update_fn: the optax update fn. metrics_update_fn: the optional metrics update fn. metrics_summary_fn: the optional metrics summary fn. - optimizer_state: the replicated optimizer state. - params: the replicated model parameters. - batch_stats: the replicated (optional) model batch statistics. - metrics_state: the replicated metric states. + (optimizer_state_sharding, optimizer_state): the replicated optimizer + state and corresponding sharding annotations. + (params_sharding, params): the replicated model parameters and + corresponding sharding annotations. + (batch_stats_sharding, batch_stats): the replicated (optional) model + batch statistics and corresponding sharding annotations. + (metrics_state_sharding, metrics_state) : the replicated metric states + and corresponding sharding annotations. global_step: the global step to start training at. sum_train_cost: the sum of the train costs. preemption_count: the number of times training has been preempted. @@ -299,10 +312,10 @@ def setup_and_maybe_restore(self, init_rng, data_rng, trainer_update_fn): unreplicated_batch_stats) ( - optimizer_state, - params, - batch_stats, - metrics_state, + (optimizer_state_sharding, optimizer_state), + (params_sharding, params), + (batch_stats_sharding, batch_stats), + (metrics_state_sharding, metrics_state), global_step, sum_train_cost, preemption_count, @@ -312,6 +325,7 @@ def setup_and_maybe_restore(self, init_rng, data_rng, trainer_update_fn): unreplicated_params, unreplicated_batch_stats, unreplicated_metrics_state, + self._mesh, train_dir=self._train_dir, external_checkpoint_path=self._external_checkpoint_path, orbax_checkpointer=self._orbax_checkpointer, @@ -360,6 +374,12 @@ def setup_and_maybe_restore(self, init_rng, data_rng, trainer_update_fn): hps=self._hps, ) + if self._hps.get('grad_clip') and self._hps.get('total_accumulated_batch_size'): # pylint: disable=line-too-long + raise NotImplementedError( + 'Gradient clipping is not supported when gradient accumulation is' + ' performed internally by the optimizer.' + ) + if trainer_update_fn is not None: update_fn = functools.partial( trainer_update_fn, @@ -367,31 +387,32 @@ def setup_and_maybe_restore(self, init_rng, data_rng, trainer_update_fn): grad_clip=self._hps.get('grad_clip'), optimizer_update_fn=optimizer_update_fn, metrics_update_fn=metrics_update_fn) - # in_axes = ( - # optimizer_state = 0, - # params = 0, - # batch_stats = 0, - # metrics_state = 0, - # batch = 0, - # step = None, - # lr = None, - # rng = None, - # local_device_index = 0, - # running_train_cost = 0, - # training_cost, - # grad_clip, - # optimizer_update_fn, - # metrics_state_update_fn) - # Also, we can donate buffers for 'optimizer', 'batch_stats', - # 'batch' and 'training_metrics_state' for update's pmapped computation. - update_pmapped = jax.pmap( + + # We donate optimizer_state, params and batch_stats in jitted computation. + # This helps reduce memory usage as outputs corresponding to these inputs + # arguments can re-use the memory. + update_jitted = jax.jit( update_fn, - axis_name='batch', - in_axes=(0, 0, 0, 0, 0, None, None, None, 0, 0), - donate_argnums=(0, 1, 2, 8), + donate_argnums=(0, 1, 2), + in_shardings=( + optimizer_state_sharding, + params_sharding, + batch_stats_sharding, + metrics_state_sharding, + NamedSharding(self._mesh, jax.sharding.PartitionSpec('devices')), + None, None, None, None + ), + out_shardings=( + optimizer_state_sharding, + params_sharding, + batch_stats_sharding, + None, + metrics_state_sharding, + None + ), ) else: - update_pmapped = None + update_jitted = None return ( lr_fn, @@ -399,14 +420,18 @@ def setup_and_maybe_restore(self, init_rng, data_rng, trainer_update_fn): metrics_update_fn, metrics_summary_fn, optimizer_state, + optimizer_state_sharding, params, + params_sharding, batch_stats, + batch_stats_sharding, metrics_state, + metrics_state_sharding, global_step, sum_train_cost, preemption_count, dataset, - update_pmapped) + update_jitted) def _setup_and_maybe_restore( self, init_rng, data_rng, callback_rng, trainer_update_fn): @@ -425,7 +450,7 @@ def _setup_and_maybe_restore( - initializing and maybe restoring self._sum_train_cost. - initializing and maybe restoring self._preemption_count. - setting self._dataset - - setting self._update_pmapped + - setting self._update_jitted - setting self._eval_callbacks Args: @@ -436,21 +461,25 @@ def _setup_and_maybe_restore( callback_rng: the jax PRNGKey used for eval callbacks. Should be *different* across hosts! trainer_update_fn: the function for updating the model. If None, this will - skip pmapping the update function. + skip jitting the update function. """ (self._lr_fn, self._optimizer_update_fn, self._metrics_update_fn, self._metrics_summary_fn, self._optimizer_state, + self._optimizer_state_sharding, self._params, + self._params_sharding, self._batch_stats, + self._batch_stats_sharding, self._metrics_state, + self._metrics_state_sharding, self._global_step, self._sum_train_cost, self._preemption_count, self._dataset, - self._update_pmapped) = self.setup_and_maybe_restore( + self._update_jitted) = self.setup_and_maybe_restore( init_rng, data_rng, trainer_update_fn) self._eval_callbacks = self._setup_eval_callbacks(callback_rng) @@ -482,7 +511,7 @@ def _setup_eval_callbacks(self, callback_rng): eval_callback = callbacks.get_callback(config['callback_name'])( self._model, self._params, self._batch_stats, self._optimizer_state, self._optimizer_update_fn, self._dataset, self._hps, config, - self._train_dir, rng) + self._train_dir, rng, self._mesh) eval_callbacks.append(eval_callback) return eval_callbacks @@ -546,9 +575,6 @@ def _eval( """ time_since_last_eval = time.time() - self._time_at_prev_eval_end - self._batch_stats = trainer_utils.maybe_sync_batchnorm_stats( - self._batch_stats - ) if self._eval_use_ema: if isinstance( @@ -574,7 +600,8 @@ def _eval( self._eval_num_batches, self._test_num_batches, self._eval_train_num_batches, - self._evaluate_batch_pmapped) + self._evaluate_batch_jitted, + self._mesh) self._run_eval_callbacks(report) if save: self._save(self._train_dir) @@ -583,10 +610,10 @@ def _eval( run_time = time.time() - self._time_at_prev_eval_end steps_per_sec = steps_since_last_eval / run_time - mean_train_cost = jax.lax.pmean(self._sum_train_cost, axis_name=[])[ - 0 - ].item() / max(1, self._global_step - self._prev_eval_step) - self._sum_train_cost = jax_utils.replicate(0.0) + mean_train_cost = self._sum_train_cost / max( + 1, self._global_step - self._prev_eval_step + ) + self._sum_train_cost = 0.0 epoch = self._global_step * self._hps.batch_size // self._hps.train_size overall_steps_per_sec = self._get_step_frequency( self._global_step, start_step, start_time) diff --git a/init2winit/trainer_lib/test_trainer.py b/init2winit/trainer_lib/test_trainer.py index 7ba594c6..05d91136 100644 --- a/init2winit/trainer_lib/test_trainer.py +++ b/init2winit/trainer_lib/test_trainer.py @@ -16,7 +16,6 @@ """Unit tests for trainer.py.""" import copy -import functools import itertools import os import shutil @@ -25,10 +24,10 @@ from absl import flags from absl.testing import absltest from absl.testing import parameterized -from flax import jax_utils from flax import linen as nn from init2winit import hyperparameters from init2winit import utils +from init2winit.dataset_lib import data_utils from init2winit.dataset_lib import datasets from init2winit.dataset_lib.small_image_datasets import Dataset from init2winit.init_lib import initializers @@ -37,6 +36,8 @@ from init2winit.model_lib import models from init2winit.trainer_lib import trainer from init2winit.trainer_lib import trainer_utils +import jax +from jax.experimental import mesh_utils import jax.numpy as jnp import jax.random import jraph @@ -45,6 +46,8 @@ import pandas import tensorflow.compat.v1 as tf # importing this is needed for tfds mocking. import tensorflow_datasets as tfds + + FLAGS = flags.FLAGS _VOCAB_SIZE = 4 @@ -196,25 +199,25 @@ def _get_fake_dlrm_dataset(batch_size, eval_num_batches, hps): def train_iterator_fn(): while True: - yield batch + yield copy.deepcopy(batch) def eval_train_epoch(num_batches=None): if num_batches is None: num_batches = eval_num_batches for _ in range(num_batches): - yield batch + yield copy.deepcopy(batch) def valid_epoch(num_batches=None): if num_batches is None: num_batches = eval_num_batches for _ in range(num_batches): - yield batch + yield copy.deepcopy(batch) def test_epoch(num_batches=None): if num_batches is None: num_batches = eval_num_batches for _ in range(num_batches): - yield batch + yield copy.deepcopy(batch) meta_data = { 'apply_one_hot_in_loss': False, @@ -303,7 +306,12 @@ def __call__(self, x, train): rngs={'params': params_rng, 'dropout': dropout_rng}, x=None, train=False) - params = jax_utils.replicate(init_dict['params']) + mesh_shape = (jax.device_count(),) + mesh = jax.sharding.Mesh( + mesh_utils.create_device_mesh(mesh_shape, devices=jax.devices()), + axis_names=('devices',), + ) + _, params = data_utils.shard_pytree(init_dict['params'], mesh) batch_stats = init_dict.get('batch_stats', {}) # 4 evaluation batches of size 4. @@ -336,18 +344,22 @@ def fake_batches_gen(): yield batch # pylint: disable=protected-access - eval_fn = functools.partial( - base_model._evaluate_batch, - flax_module=fake_flax_module, - metrics_bundle=metrics.get_metrics('classification_metrics'), - apply_one_hot_in_loss=True) - evaluate_batch_pmapped = jax.pmap(eval_fn, axis_name='batch') + eval_fn = lambda params, batch_stats, batch: base_model._evaluate_batch( + fake_flax_module, + params, + batch_stats, + batch, + metrics.get_metrics('classification_metrics'), + True) + + evaluate_batch_jitted = jax.jit(eval_fn) # pylint: enable=protected-access evaluated_metrics = trainer_utils.evaluate( params, batch_stats, fake_batches_gen(), - evaluate_batch_pmapped) + evaluate_batch_jitted, + mesh) def batch_ce_loss(logits, targets): one_hot_targets = np.eye(4)[targets] @@ -890,7 +902,7 @@ def mock_evaluate_batch(params, batch_stats, batch): del params, batch_stats metrics_bundle = metrics.get_metrics(metrics_name) - return metrics_bundle.gather_from_model_output( + return metrics_bundle.single_from_model_output( logits=batch.get('logits'), targets=batch.get('targets'), weights=batch.get('weights')) @@ -906,11 +918,18 @@ def mock_evaluate_batch(params, batch_stats, batch): 'weights': ws } for ls, ts, ws in zip(logits, targets, weights)] + mesh_shape = (jax.device_count(),) + mesh = jax.sharding.Mesh( + mesh_utils.create_device_mesh(mesh_shape, devices=jax.devices()), + axis_names=('devices',), + ) + result = trainer_utils.evaluate( params=None, batch_stats=None, batch_iter=batch_iter, - evaluate_batch_pmapped=jax.pmap(mock_evaluate_batch, axis_name='batch')) + evaluate_batch_jitted=jax.jit(mock_evaluate_batch), + mesh=mesh) for metric, val in zip(test_metric_names, test_metric_vals): self.assertAlmostEqual(result[metric], val, places=5) diff --git a/init2winit/trainer_lib/trainer.py b/init2winit/trainer_lib/trainer.py index e4ddbff3..660ed738 100644 --- a/init2winit/trainer_lib/trainer.py +++ b/init2winit/trainer_lib/trainer.py @@ -14,17 +14,17 @@ # limitations under the License. """Standard trainer for the init2winit project.""" +import functools import itertools import time from absl import logging from init2winit import utils +from init2winit.dataset_lib import data_utils from init2winit.model_lib import model_utils -from init2winit.optimizer_lib import utils as optimizer_utils from init2winit.trainer_lib import base_trainer from init2winit.trainer_lib import trainer_utils import jax -from jax import lax import jax.numpy as jnp import optax @@ -40,16 +40,14 @@ def update( step, lr, rng, - local_device_index, running_train_cost, training_cost, grad_clip, optimizer_update_fn, - metrics_update_fn, - axis_name='batch'): + metrics_update_fn): """Single step of the training loop. - This function will later be pmapped so we keep it outside of the Trainer class + This function will later be jitted so we keep it outside of the Trainer class to avoid the temptation to introduce side-effects. Args: @@ -67,9 +65,6 @@ def update( lr: the floating point learning rate for this step. rng: the RNG used for calling the model. `step` and `local_device_index` will be folded into this to produce a unique per-device, per-step RNG. - local_device_index: an integer that is unique to this device amongst all - devices on this host, usually in the range [0, jax.local_device_count()]. - It is folded in to `rng` to produce a unique per-device, per-step RNG. running_train_cost: the cumulative train cost over some past number of train steps. Reset at evaluation time. training_cost: a function used to calculate the training objective that will @@ -80,7 +75,6 @@ def update( value g / ||g||_2 * grad_clip. If None, then no clipping will be applied. optimizer_update_fn: the optimizer update function. metrics_update_fn: the training metrics update function. - axis_name: axis_name used by pmap. Returns: A tuple of the new optimizer, the new batch stats, the scalar training cost, @@ -89,7 +83,6 @@ def update( # `jax.random.split` is very slow outside the train step, so instead we do a # `jax.random.fold_in` here. rng = jax.random.fold_in(rng, step) - rng = jax.random.fold_in(rng, local_device_index) optimizer_state = trainer_utils.inject_learning_rate(optimizer_state, lr) @@ -101,20 +94,6 @@ def opt_cost(params): (cost_value, new_batch_stats), grad = grad_fn(params) new_batch_stats = new_batch_stats.get('batch_stats', None) - if axis_name is not None: - if optimizer_utils.requires_gradient_aggregation(optimizer_update_fn): - grad = lax.pmean((grad), axis_name=axis_name) - else: - # Skip gradient aggregationas it'll be handled in gradient_accumulator. - if grad_clip: - # Calculating the gradient norm requires cross-device aggregation, - # performed, in this case, inside the optimizer. Calculating it again - # at this point may be inefficient. - raise NotImplementedError( - 'Gradient clipping is not supported when gradient aggregation is' - ' performed internally by the optimizer.' - ) - grad_norm = jnp.sqrt(model_utils.l2_regularization(grad, 0)) # TODO(znado): move to inside optax gradient clipping. if grad_clip: @@ -183,9 +162,6 @@ def train(self): ) - train_iter = trainer_utils.prefetch_input_pipeline( - train_iter, self._hps.num_device_prefetches) - if self._data_selector: train_iter = self._data_selector( train_iter, @@ -207,25 +183,40 @@ def train(self): if self._global_step in self._checkpoint_steps: self._save(self._checkpoint_dir, max_to_keep=None) + make_global_array_fn = functools.partial( + data_utils.make_global_array, mesh=self._mesh + ) + for _ in range(start_step, self._num_train_steps): - with jax.profiler.StepTraceAnnotation('train', - step_num=self._global_step): + with jax.profiler.StepTraceAnnotation( + 'train', step_num=self._global_step + ): # NOTE(dsuo): to properly profile each step, we must include batch # creation in the StepTraceContext (as opposed to putting `train_iter` # directly in the top-level for loop). batch = next(train_iter) + batch = jax.tree_util.tree_map(make_global_array_fn, batch) lr = self._lr_fn(self._global_step) # It looks like we are reusing an rng key, but we aren't. - # TODO(gdahl): Make it more obvious that passing rng is safe. - # TODO(gdahl,gilmer,znado): investigate possibly merging the member - # variable inputs/outputs of this function into a named tuple. - (self._optimizer_state, self._params, self._batch_stats, - self._sum_train_cost, - self._metrics_state, self._grad_norm) = self._update_pmapped( - self._optimizer_state, self._params, self._batch_stats, - self._metrics_state, batch, self._global_step, lr, rng, - self._local_device_indices, self._sum_train_cost) + ( + self._optimizer_state, + self._params, + self._batch_stats, + self._sum_train_cost, + self._metrics_state, + self._grad_norm, + ) = self._update_jitted( + self._optimizer_state, + self._params, + self._batch_stats, + self._metrics_state, + batch, + self._global_step, + lr, + rng, + self._sum_train_cost, + ) self._global_step += 1 if self._global_step in self._checkpoint_steps: self._save(self._checkpoint_dir, max_to_keep=None) diff --git a/init2winit/trainer_lib/trainer_utils.py b/init2winit/trainer_lib/trainer_utils.py index ca3afa09..11c1f830 100644 --- a/init2winit/trainer_lib/trainer_utils.py +++ b/init2winit/trainer_lib/trainer_utils.py @@ -14,14 +14,14 @@ # limitations under the License. """Utility functions related to training.""" + +import functools import time from absl import logging - from flax import jax_utils from init2winit import utils from init2winit.dataset_lib import data_utils -from init2winit.model_lib import model_utils import jax import jax.numpy as jnp import numpy as np @@ -89,20 +89,6 @@ def maybe_log_training_metrics(metrics_state, prefix='metrics_state') -def maybe_sync_batchnorm_stats(batch_stats): - """Sync batch_stats across devices.""" - # We first check that batch_stats is used (pmap will throw an error if - # it's a non batch norm model). If batch norm is not used then - # batch_stats = None. Note that, in the case of using our implementation of - # virtual batch norm, this will also handle synchronizing the multiple moving - # averages on each device before doing a cross-host sync. - if batch_stats: - batch_stats = jax.pmap( - model_utils.sync_batchnorm_stats, axis_name='batch')( - batch_stats) - return batch_stats - - def should_eval(global_step, eval_frequency, eval_steps): on_step = eval_steps and global_step in eval_steps on_freq = (global_step % eval_frequency == 0) @@ -141,30 +127,12 @@ def check_for_early_stopping( ) -def prefetch_input_pipeline(ds, n_prefetch=0, devices=None): - """Modify input pipeline to prefetch from host to device. - - Args: - ds: tf.data pipeline - n_prefetch: number of items to prefetch - devices: devices to prefetch to - - Returns: - prefetching ds - - """ - it = iter(ds) - it = (data_utils.shard(x) for x in it) - if n_prefetch > 0: - it = jax_utils.prefetch_to_device(it, n_prefetch, devices=devices) - return it - - def evaluate( params, batch_stats, batch_iter, - evaluate_batch_pmapped): + evaluate_batch_jitted, + mesh): """Compute aggregated metrics on the given data iterator. WARNING: The caller is responsible for synchronizing the batch norm statistics @@ -184,25 +152,27 @@ def evaluate( {'batch_stats': batch_stats} into flax_module.apply(). batch_iter: Generator which yields batches. Must support the API for b in batch_iter: - evaluate_batch_pmapped: A function with API - evaluate_batch_pmapped(params, batch_stats, batch). Returns a dictionary - mapping keys to the metric values across the sharded batch. + evaluate_batch_jitted: A function with API evaluate_batch_jitted(params, + batch_stats, batch). Returns a dictionary mapping keys to the metric + values across the sharded batch. + mesh: Mesh specification to use for sharding. Returns: A dictionary of aggregated metrics. The keys will match the keys returned by - evaluate_batch_pmapped. + evaluate_batch_jitted. """ metrics = None + make_global_array_fn = functools.partial( + data_utils.make_global_array, mesh=mesh + ) + for batch in batch_iter: - batch = data_utils.shard(batch) + batch = jax.tree_util.tree_map(make_global_array_fn, batch) # Returns a clu.metrics.Collection object. We assume that - # `evaluate_batch_pmpapped` calls CLU's `gather_from_model_outputs`, - # which includes an `all_gather` to replicate the values on all devices. - # We need to `unreplicate` before merging the results across batches to - # accommodate CollectingMetric, which concatenates the values across the - # leading dimension, so we need to remove the leading shard dimension first. - computed_metrics = evaluate_batch_pmapped( - params=params, batch_stats=batch_stats, batch=batch).unreplicate() + # `evaluate_batch_jitted` calls CLU's `single_from_model_outputs`. + computed_metrics = evaluate_batch_jitted( + params=params, batch_stats=batch_stats, batch=batch + ) if metrics is None: metrics = computed_metrics else: @@ -266,7 +236,7 @@ def fetch_learning_rate(optimizer_state): ) if all_equal: lr_array = lrs_with_path[0][1] - return lr_array[0] + return lr_array else: raise ValueError( 'All learning rates in the optimizer state must be the same.' @@ -284,7 +254,7 @@ def _merge_and_apply_prefix(d1, d2, prefix): @utils.timed def eval_metrics(params, batch_stats, dataset, eval_num_batches, test_num_batches, eval_train_num_batches, - evaluate_batch_pmapped): + evaluate_batch_jitted, mesh): """Evaluates the given network on the train, validation, and test sets. WARNING: we assume that `batch_stats` has already been synchronized across @@ -307,7 +277,8 @@ def eval_metrics(params, batch_stats, dataset, eval_num_batches, sets. Set to None to evaluate on the whole test set. eval_train_num_batches: (int) The batch size used for evaluating on train set. Set to None to evaluate on the whole training set. - evaluate_batch_pmapped: Computes the metrics on a sharded batch. + evaluate_batch_jitted: Computes the metrics on a sharded batch. + mesh: Mesh specification to use for sharding. Returns: A dictionary of all computed metrics. @@ -320,7 +291,7 @@ def eval_metrics(params, batch_stats, dataset, eval_num_batches, for split_iter, split_name in zip([train_iter, valid_iter, test_iter], ['train', 'valid', 'test']): split_metrics = evaluate(params, batch_stats, split_iter, - evaluate_batch_pmapped) + evaluate_batch_jitted, mesh) # Metrics are None if the dataset doesn't have that split if split_metrics is not None: metrics = _merge_and_apply_prefix(metrics, split_metrics,