upgrading init2winit from pmap to jit #687

Open · wants to merge 1 commit into base: master
2 changes: 1 addition & 1 deletion init2winit/base_callback.py
@@ -39,7 +39,7 @@ class BaseCallBack:

def __init__(self, model, params, batch_stats, optimizer_state,
optimizer_update_fn, dataset, hps, callback_config, train_dir,
rng):
rng, mesh):
"""Defines the API for callback construction."""
pass

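The callback constructor now takes the mesh used for jit-based sharding in place of pmap's implicit device axis. A minimal sketch of a subclass accepting the new argument; the LoggingCallback name and the single 'devices' mesh axis are assumptions for illustration, not part of this PR:

```python
import jax
import numpy as np
from jax.sharding import Mesh
from init2winit.base_callback import BaseCallBack


class LoggingCallback(BaseCallBack):  # hypothetical subclass, illustration only
  """Keeps the mesh around so later calls can operate on sharded state."""

  def __init__(self, model, params, batch_stats, optimizer_state,
               optimizer_update_fn, dataset, hps, callback_config, train_dir,
               rng, mesh):
    self.mesh = mesh  # the new trailing argument introduced by this change


# A 1-D mesh over all local devices; the axis name is an assumption.
mesh = Mesh(np.array(jax.devices()), ('devices',))
```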
55 changes: 25 additions & 30 deletions init2winit/checkpoint.py
@@ -18,6 +18,7 @@
This is useful for training neural networks with stax, where model parameters
are nested numpy arrays.
"""

import os
import sys
from typing import Sequence
@@ -26,8 +27,10 @@
from absl import logging
from flax import jax_utils
from flax.training import checkpoints as flax_checkpoints
from init2winit import utils
import jax


FLAGS = flags.FLAGS


@@ -85,6 +88,7 @@ def replicate_and_maybe_restore_checkpoint(
unreplicated_params,
unreplicated_batch_stats,
unreplicated_training_metrics_state,
mesh,
train_dir,
external_checkpoint_path=None,
orbax_checkpointer=None):
@@ -104,6 +108,7 @@
unreplicated_params: unreplicated params
unreplicated_batch_stats: unreplicated batch stats
unreplicated_training_metrics_state: unreplicated metrics state
mesh: Mesh specification to use for sharding.
train_dir: (str) The training directory where we will look for a checkpoint.
external_checkpoint_path: (str) If this argument is set, then we will load
the external checkpoint stored there.
@@ -165,43 +170,34 @@
# Handle failure to load from external_checkpoint_path.
if ckpt_to_return['global_step'] == -1:
return (
jax_utils.replicate(unreplicated_optimizer_state),
jax_utils.replicate(unreplicated_params),
jax_utils.replicate(unreplicated_batch_stats),
jax_utils.replicate(unreplicated_training_metrics_state),
utils.shard_pytree(unreplicated_optimizer_state, mesh),
utils.shard_pytree(unreplicated_params, mesh),
utils.shard_pytree(unreplicated_batch_stats, mesh),
utils.shard_pytree(unreplicated_training_metrics_state, mesh),
0, # global_step
jax_utils.replicate(0), # sum_train_cost
0, # preemption_count
False) # is_restored
else: # Else, don't restore from any checkpoint.
return (
jax_utils.replicate(unreplicated_optimizer_state),
jax_utils.replicate(unreplicated_params),
jax_utils.replicate(unreplicated_batch_stats),
jax_utils.replicate(unreplicated_training_metrics_state),
utils.shard_pytree(unreplicated_optimizer_state, mesh),
utils.shard_pytree(unreplicated_params, mesh),
utils.shard_pytree(unreplicated_batch_stats, mesh),
utils.shard_pytree(unreplicated_training_metrics_state, mesh),
0, # global_step
jax_utils.replicate(0), # sum_train_cost
0, # preemption_count
False) # is_restored

pytree_dict, extra_state = replicate_checkpoint(
ckpt_to_return,
pytree_keys=[
'optimizer_state',
'params',
'batch_stats',
'training_metrics_grabber',
'sum_train_cost',
])
return (
pytree_dict['optimizer_state'],
pytree_dict['params'],
pytree_dict['batch_stats'],
pytree_dict['training_metrics_grabber'],
extra_state['global_step'],
pytree_dict['sum_train_cost'],
extra_state['preemption_count'],
is_restored)
utils.shard_pytree(ckpt_to_return['optimizer_state'], mesh),
utils.shard_pytree(ckpt_to_return['params'], mesh),
utils.shard_pytree(ckpt_to_return['batch_stats'], mesh),
utils.shard_pytree(ckpt_to_return['training_metrics_grabber'], mesh),
ckpt_to_return['global_step'], # global_step
jax_utils.replicate(ckpt_to_return['sum_train_cost']),
ckpt_to_return['preemption_count'], # preemption_count
is_restored) # is_restored


def save_unreplicated_checkpoint(
@@ -217,12 +213,11 @@ def save_unreplicated_checkpoint(
max_to_keep=1):
"""Saves pytree, step, preemption_count, and sum_train_cost to train_dir."""
logging.info('Saving checkpoint to ckpt_%d', global_step)
unreplicated_optimizer_state = jax.device_get(
jax_utils.unreplicate(optimizer_state))
unreplicated_params = jax.device_get(jax_utils.unreplicate(params))
unreplicated_batch_stats = jax.device_get(jax_utils.unreplicate(batch_stats))
unreplicated_optimizer_state = jax.device_get(optimizer_state)
unreplicated_params = jax.device_get(params)
unreplicated_batch_stats = jax.device_get(batch_stats)
unreplicated_training_metrics_state = jax.device_get(
jax_utils.unreplicate(training_metrics_state))
training_metrics_state)
unreplicated_sum_train_cost = jax.device_get(
jax_utils.unreplicate(sum_train_cost))
state = dict(global_step=global_step,
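The replicate-to-shard_pytree swap above relies on a helper that is not shown in this diff. The following is only a sketch, under the assumption that init2winit.utils.shard_pytree places every leaf of a pytree on the mesh (fully replicated here) so it can feed a jit-compiled train step; the real helper may derive per-leaf partition specs instead:

```python
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P


def shard_pytree(pytree, mesh):
  """Sketch only: put each leaf on `mesh`, replicated across all devices."""
  replicated = NamedSharding(mesh, P())  # empty PartitionSpec == replicate
  return jax.tree.map(lambda x: jax.device_put(x, replicated), pytree)


mesh = Mesh(np.array(jax.devices()), ('devices',))  # axis name is an assumption
params = {'w': np.ones((4, 4)), 'b': np.zeros(4)}
sharded_params = shard_pytree(params, mesh)
```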
35 changes: 0 additions & 35 deletions init2winit/dataset_lib/data_utils.py
@@ -143,40 +143,6 @@ def zero_pad(ar, pad_axis):
return padded_batch


def shard(batch, n_devices=None):
"""Prepares the batch for pmap by adding a leading n_devices dimension.

If all the entries are lists, assume they are already divided into n_devices
smaller arrays and stack them for pmapping. If all the entries are arrays,
assume they have leading dimension divisible by n_devices and reshape.

Args:
batch: A dict of arrays or lists of arrays
n_devices: If None, this will be set to jax.local_device_count().

Returns:
Sharded data.
"""
if n_devices is None:
n_devices = jax.local_device_count()

# TODO(mbadura): Specify a sharding function per dataset instead
# If entries in the batch dict are lists, then the data is already divided
# into n_devices chunks, so we need to stack them.
if all((isinstance(v, list) for v in batch.values())):
assert all(len(v) == n_devices for v in batch.values())
# transpose a dict of lists to a list of dicts
shards = [{k: v[i] for (k, v) in batch.items()} for i in range(n_devices)]
return jax.tree.map(lambda *vals: np.stack(vals, axis=0), shards[0],
*shards[1:])

# Otherwise, the entries are arrays, so just reshape them.
def _shard_array(array):
return array.reshape((n_devices, -1) + array.shape[1:])

return jax.tree.map(_shard_array, batch)


def tf_to_numpy(tfds_data):
# Safe because we won't mutate. Avoids an extra copy from tfds.
convert_data = lambda x: x._numpy() # pylint: disable=protected-access
@@ -187,4 +153,3 @@ def tf_to_numpy(tfds_data):
def convert_jax_to_tf_random_seed(jax_prng_key: jax.random.PRNGKey) -> int:
tf_seed = jax.random.bits(jax_prng_key)
return tf_seed
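With the pmap-oriented shard() helper deleted above, batches no longer need a leading device axis; instead a host batch can be placed as one global array sharded along its batch dimension. A minimal single-process sketch; the 'batch' axis name and shapes are assumptions, the leading dimension must divide evenly across the devices, and multi-host setups would typically go through jax.make_array_from_process_local_data instead:

```python
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()), ('batch',))
batch_sharding = NamedSharding(mesh, P('batch'))  # shard the leading axis only

# Dummy batch; 32 must be divisible by the number of devices in the mesh.
batch = {'inputs': np.ones((32, 28, 28)), 'targets': np.zeros((32,))}
global_batch = jax.tree.map(lambda x: jax.device_put(x, batch_sharding), batch)
```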

44 changes: 9 additions & 35 deletions init2winit/dataset_lib/ogbg_molpcba.py
@@ -185,8 +185,7 @@ def _get_batch_iterator(dataset_iter,
edges_per_graph,
add_bidirectional_edges,
add_self_loops,
add_virtual_node,
num_shards=None):
add_virtual_node):
"""Turns a TFDS per-example iterator into a batched iterator in the init2winit format.

Constructs the batch from num_shards smaller batches, so that we can easily
@@ -205,18 +204,11 @@
receiver.
add_self_loops: If True, add a self-loop for each node.
add_virtual_node: If True, add a new node connected to all nodes.
num_shards: How many devices we should be able to shard the batch into.

Yields:
Batch in the init2winit format. Each field is a list of num_shards separate
smaller batches.
Batch in the init2winit format.

"""
if not num_shards:
num_shards = jax.local_device_count()

# We will construct num_shards smaller batches and then put them together.
batch_size /= num_shards

max_n_nodes = nodes_per_graph * batch_size
max_n_edges = edges_per_graph * batch_size
@@ -228,39 +220,21 @@
add_virtual_node=add_virtual_node,
add_self_loops=add_self_loops)
jraph_iter = map(to_jraph_partial, dataset_iter)
batched_iter = jraph.dynamically_batch(jraph_iter, max_n_nodes + 1,
max_n_edges, max_n_graphs + 1)

count = 0
graphs_shards = []
labels_shards = []
weights_shards = []
batched_iter = jraph.dynamically_batch(jraph_iter, max_n_nodes,
max_n_edges, max_n_graphs)

for batched_graph in batched_iter:
count += 1

# Separate the labels from the graph
labels = batched_graph.globals
graph = batched_graph._replace(globals={})

replaced_labels, weights = _get_weights_by_nan_and_padding(
labels, jraph.get_graph_padding_mask(graph))

graphs_shards.append(graph)
labels_shards.append(replaced_labels)
weights_shards.append(weights)

if count == num_shards:
yield {
'inputs': graphs_shards,
'targets': labels_shards,
'weights': weights_shards
}

count = 0
graphs_shards = []
labels_shards = []
weights_shards = []
yield {
'inputs': graph,
'targets': replaced_labels,
'weights': weights
}


def get_ogbg_molpcba(shuffle_rng, batch_size, eval_batch_size, hps=None):
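Since the iterator now yields one globally batched graph per step instead of a list of per-device shards, the batching reduces to a single jraph.dynamically_batch call. A standalone sketch of that call with made-up toy graphs and padding budgets (not taken from this PR):

```python
import jraph
import numpy as np


def tiny_graph(n):
  """A dummy self-loop graph with n nodes and n edges, for illustration only."""
  return jraph.GraphsTuple(
      nodes=np.ones((n, 4)), edges=np.ones((n, 4)),
      senders=np.arange(n), receivers=np.arange(n),
      globals=np.zeros((1, 1)), n_node=np.array([n]), n_edge=np.array([n]))


graphs = iter([tiny_graph(3), tiny_graph(5), tiny_graph(2)])
# Budgets (nodes, edges, graphs) must leave room for the padding graph that
# jraph appends to every batch.
for batched_graph in jraph.dynamically_batch(graphs, 16, 16, 4):
  labels = batched_graph.globals
  graph = batched_graph._replace(globals={})  # mirror the label split above
  print(graph.n_node, labels.shape)
```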
10 changes: 5 additions & 5 deletions init2winit/model_lib/base_model.py
@@ -67,8 +67,8 @@ def _evaluate_batch(flax_module, params, batch_stats, batch, metrics_bundle,

# We don't use CLU's `mask` argument here, we handle it ourselves through
# `weights`.
return metrics_bundle.gather_from_model_output(
logits=logits, targets=targets, weights=weights, axis_name='batch')
return metrics_bundle.single_from_model_output(
logits=logits, targets=targets, weights=weights)


class BaseModel(object):
@@ -300,9 +300,9 @@ def training_objective_fn(self, params, logits, targets, weights):
logits, targets, weights
)

(objective_numerator, objective_denominator) = jax.lax.psum(
(objective_numerator, objective_denominator), axis_name='batch'
)
# (objective_numerator, objective_denominator) = jax.lax.psum(
# (objective_numerator, objective_denominator), axis_name='batch'
# )

# epsilon added to handle empty batch case if we encounter one.
objective_value = objective_numerator / (objective_denominator + 1e-9)
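The commented-out psum above is the heart of the pmap-to-jit migration: once the per-example losses live in a single array sharded across the mesh, an ordinary jnp.sum inside jit already produces the global numerator and denominator, so no explicit cross-device reduction is needed. A minimal sketch under that assumption; the shapes and 'batch' axis name are made up, and the batch size must divide evenly across the devices:

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()), ('batch',))
sharding = NamedSharding(mesh, P('batch'))

per_example_loss = jax.device_put(np.arange(32.0), sharding)
weights = jax.device_put(np.ones(32), sharding)


@jax.jit
def objective(loss, w):
  numerator = jnp.sum(loss * w)      # global sum; no jax.lax.psum required
  denominator = jnp.sum(w)
  return numerator / (denominator + 1e-9)  # epsilon guards an empty batch


print(objective(per_example_loss, weights))
```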
12 changes: 4 additions & 8 deletions init2winit/model_lib/conformer.py
@@ -855,19 +855,15 @@ def evaluate_batch(self, params, batch_stats, batch):
(objective_numerator, objective_denominator) = self.loss_fn(
logits, logit_paddings, labels, label_paddings)

(objective_numerator, objective_denominator) = jax.lax.psum(
(objective_numerator, objective_denominator), axis_name='batch')

normalized_loss = (objective_numerator / (objective_denominator))
hyps, hyp_paddings = self.greedy_decode(logits, logit_paddings)

return self.metrics_bundle.gather_from_model_output(
return self.metrics_bundle.single_from_model_output(
normalized_loss=normalized_loss,
hyps=hyps,
hyp_paddings=hyp_paddings,
targets=labels,
target_paddings=label_paddings,
axis_name='batch')
target_paddings=label_paddings)

def training_cost(self, params, batch, batch_stats=None, dropout_rng=None):
"""Return CTC loss."""
@@ -891,8 +887,8 @@ def training_cost(self, params, batch, batch_stats=None, dropout_rng=None):
(objective_numerator, objective_denominator) = self.loss_fn(
outputs, output_paddings, labels, label_paddings)

(objective_numerator, objective_denominator) = jax.lax.psum(
(objective_numerator, objective_denominator), axis_name='batch')
# (objective_numerator, objective_denominator) = jax.lax.psum(
# (objective_numerator, objective_denominator), axis_name='batch')

# epsilon added to handle empty batch case if we encounter one.
objective_value = (objective_numerator / (objective_denominator + 1e-9))
18 changes: 2 additions & 16 deletions init2winit/model_lib/deepspeech.py
@@ -405,10 +405,6 @@ def __call__(self, inputs, input_paddings=None, train=False):
count_v = jnp.sum(
jnp.ones_like(inputs) * mask, axis=reduce_over_dims, keepdims=True)

if self.enable_synced_batchnorm:
sum_v = jax.lax.psum(sum_v, axis_name='batch')
count_v = jax.lax.psum(count_v, axis_name='batch')

count_v = jnp.maximum(count_v, 1.0)
mean = sum_v / count_v

@@ -417,9 +413,6 @@
axis=reduce_over_dims,
keepdims=True)

if self.enable_synced_batchnorm:
sum_vv = jax.lax.psum(sum_vv, axis_name='batch')

var = sum_vv / count_v

self.ra_mean.value = momentum * self.ra_mean.value + (1 - momentum) * mean
@@ -959,20 +952,16 @@ def evaluate_batch(self, params, batch_stats, batch):
(objective_numerator, objective_denominator) = self.loss_fn(
logits, logit_paddings, labels, label_paddings)

(objective_numerator, objective_denominator) = jax.lax.psum(
(objective_numerator, objective_denominator), axis_name='batch')

# epsilon added to handle empty batch case if we encounter one.
normalized_loss = (objective_numerator / (objective_denominator + 1e-9))
hyps, hyp_paddings = self.greedy_decode(logits, logit_paddings)

return self.metrics_bundle.gather_from_model_output(
return self.metrics_bundle.single_from_model_output(
normalized_loss=normalized_loss,
hyps=hyps,
hyp_paddings=hyp_paddings,
targets=labels,
target_paddings=label_paddings,
axis_name='batch')
target_paddings=label_paddings)

def training_cost(self, params, batch, batch_stats=None, dropout_rng=None):
"""Return CTC loss."""
@@ -996,9 +985,6 @@ def training_cost(self, params, batch, batch_stats=None, dropout_rng=None):
(objective_numerator, objective_denominator) = self.loss_fn(
outputs, output_paddings, labels, label_paddings)

(objective_numerator, objective_denominator) = jax.lax.psum(
(objective_numerator, objective_denominator), axis_name='batch')

objective_value = (objective_numerator / (objective_denominator))
return objective_value, new_batch_stats

2 changes: 1 addition & 1 deletion init2winit/model_lib/unet.py
@@ -333,7 +333,7 @@ def evaluate_batch(self, params, batch_stats, batch):

# We don't use CLU's `mask` argument here, we handle it ourselves through
# `weights`.
return self.metrics_bundle.gather_from_model_output(
return self.metrics_bundle.single_from_model_output(
logits=logits,
targets=targets,
weights=weights,
6 changes: 3 additions & 3 deletions init2winit/model_lib/xformer_translate.py
@@ -1045,7 +1045,7 @@ def evaluate_batch(self, params, batch_stats, batch):
targets = one_hot(batch['targets'], logits.shape[-1])

# Add log-perplexity metric.
return self.metrics_bundle.gather_from_model_output(
return self.metrics_bundle.single_from_model_output(
logits=logits, targets=targets, weights=weights, axis_name='batch')

def apply_on_batch(self,
@@ -1101,8 +1101,8 @@ def training_cost(self, params, batch, batch_stats=None, dropout_rng=None):
(total_loss, total_weight) = self.loss_fn(
logits, targets, weights)

(total_loss, total_weight) = lax.psum(
(total_loss, total_weight), axis_name='batch')
# (total_loss, total_weight) = lax.psum(
# (total_loss, total_weight), axis_name='batch')

total_loss = (total_loss / total_weight)
