diff --git a/trax/models/README.md b/trax/models/README.md new file mode 100644 index 000000000..3da3e0f82 --- /dev/null +++ b/trax/models/README.md @@ -0,0 +1,16 @@ +# Constructing T2T Models. + +This directory contains T2T models, their hyperparameters, and a number +of common layers and hyperparameter settings to help construct new models. +Common building blocks are in `common_layers.py` and `common_attention.py`. +Common hyperparameters are in `common_hparams.py`. Models are imported in +`__init__.py`. + +## Adding a new model. + +To add a model to the built-in set, create a new file (see, e.g., +`neural_gpu.py`) and write your model class inheriting from `T2TModel` there and +decorate it with `registry.register_model`. Import it in `__init__.py`. + +It is now available to use with the trainer binary (`t2t-trainer`) using the +`--model=model_name` flag. diff --git a/trax/models/__init__.py b/trax/models/__init__.py index f827f1d94..62d059134 100644 --- a/trax/models/__init__.py +++ b/trax/models/__init__.py @@ -1,5 +1,5 @@ # coding=utf-8 -# Copyright 2022 The Trax Authors. +# Copyright 2023 The Tensor2Tensor Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,80 +13,87 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Models defined in trax.""" -import gin +"""Models defined in T2T. 
Imports here force registration.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function -from trax.models import atari_cnn -from trax.models import mlp -from trax.models import neural_gpu -from trax.models import resnet -from trax.models import rl -from trax.models import rnn -from trax.models import transformer -from trax.models.reformer import reformer -from trax.models.research import bert -from trax.models.research import configurable_transformer -from trax.models.research import hourglass -from trax.models.research import layerdrop_transformer -from trax.models.research import rezero -from trax.models.research import rse -from trax.models.research import terraformer -from trax.models.research import transformer2 +import six +# pylint: disable=unused-import -# Ginify -def model_configure(*args, **kwargs): - kwargs['module'] = 'trax.models' - return gin.external_configurable(*args, **kwargs) +from tensor2tensor.layers import modalities # pylint: disable=g-import-not-at-top +from tensor2tensor.models import basic +from tensor2tensor.models import bytenet +from tensor2tensor.models import distillation +from tensor2tensor.models import evolved_transformer +from tensor2tensor.models import image_transformer +from tensor2tensor.models import image_transformer_2d +from tensor2tensor.models import lstm +from tensor2tensor.models import neural_assistant +from tensor2tensor.models import neural_gpu +from tensor2tensor.models import resnet +from tensor2tensor.models import revnet +from tensor2tensor.models import shake_shake +from tensor2tensor.models import slicenet +from tensor2tensor.models import text_cnn +from tensor2tensor.models import transformer +from tensor2tensor.models import vanilla_gan +from tensor2tensor.models import xception +from tensor2tensor.models.neural_architecture_search import nas_model +from tensor2tensor.models.research import adafactor_experiments +from tensor2tensor.models.research 
import aligned +from tensor2tensor.models.research import autoencoders +from tensor2tensor.models.research import cycle_gan +from tensor2tensor.models.research import gene_expression +from tensor2tensor.models.research import neural_stack +from tensor2tensor.models.research import residual_shuffle_exchange +from tensor2tensor.models.research import rl +from tensor2tensor.models.research import shuffle_network +from tensor2tensor.models.research import similarity_transformer +from tensor2tensor.models.research import super_lm +from tensor2tensor.models.research import transformer_moe +from tensor2tensor.models.research import transformer_nat +from tensor2tensor.models.research import transformer_parallel +from tensor2tensor.models.research import transformer_revnet +from tensor2tensor.models.research import transformer_seq2edits +from tensor2tensor.models.research import transformer_sketch +from tensor2tensor.models.research import transformer_symshard +from tensor2tensor.models.research import transformer_vae +from tensor2tensor.models.research import universal_transformer +from tensor2tensor.models.video import basic_deterministic +from tensor2tensor.models.video import basic_recurrent +from tensor2tensor.models.video import basic_stochastic +from tensor2tensor.models.video import emily +from tensor2tensor.models.video import savp +from tensor2tensor.models.video import sv2p +from tensor2tensor.utils import contrib +from tensor2tensor.utils import registry +# The following models can't be imported under TF2 +if not contrib.is_tf2: + # pylint: disable=g-import-not-at-top + from tensor2tensor.models.research import attention_lm + from tensor2tensor.models.research import attention_lm_moe + from tensor2tensor.models.research import glow + from tensor2tensor.models.research import lm_experiments + from tensor2tensor.models.research import moe_experiments + from tensor2tensor.models.research import multiquery_paper + from tensor2tensor.models import 
mtf_image_transformer + from tensor2tensor.models import mtf_resnet + from tensor2tensor.models import mtf_transformer + from tensor2tensor.models import mtf_transformer2 + from tensor2tensor.models.research import vqa_attention + from tensor2tensor.models.research import vqa_recurrent_self_attention + from tensor2tensor.models.research import vqa_self_attention + from tensor2tensor.models.video import epva + from tensor2tensor.models.video import next_frame_glow + # pylint: enable=g-import-not-at-top -# pylint: disable=invalid-name -AtariCnn = model_configure(atari_cnn.AtariCnn) -AtariCnnBody = model_configure(atari_cnn.AtariCnnBody) -FrameStackMLP = model_configure(atari_cnn.FrameStackMLP) -BERT = model_configure(bert.BERT) -BERTClassifierHead = model_configure(bert.BERTClassifierHead) -BERTRegressionHead = model_configure(bert.BERTRegressionHead) -ConfigurableTerraformer = model_configure(terraformer.ConfigurableTerraformer) -ConfigurableTransformer = model_configure( - configurable_transformer.ConfigurableTransformer) -ConfigurableTransformerEncoder = model_configure( - configurable_transformer.ConfigurableTransformerEncoder) -ConfigurableTransformerLM = model_configure( - configurable_transformer.ConfigurableTransformerLM) -MLP = model_configure(mlp.MLP) -NeuralGPU = model_configure(neural_gpu.NeuralGPU) -Reformer = model_configure(reformer.Reformer) -ReformerLM = model_configure(reformer.ReformerLM) -ReformerShortenLM = model_configure(reformer.ReformerShortenLM) -Resnet50 = model_configure(resnet.Resnet50) -ReZeroTransformer = model_configure( - rezero.ReZeroTransformer) -ReZeroTransformerDecoder = model_configure( - rezero.ReZeroTransformerDecoder) -ReZeroTransformerEncoder = model_configure( - rezero.ReZeroTransformerEncoder) -ReZeroTransformerLM = model_configure( - rezero.ReZeroTransformerLM) -SkippingTransformerLM = model_configure( - layerdrop_transformer.SkippingTransformerLM) -LayerDropTransformerLM = model_configure( - 
layerdrop_transformer.LayerDropTransformerLM) -EveryOtherLayerDropTransformerLM = model_configure( - layerdrop_transformer.EveryOtherLayerDropTransformerLM) -Transformer = model_configure(transformer.Transformer) -TransformerDecoder = model_configure(transformer.TransformerDecoder) -TransformerEncoder = model_configure(transformer.TransformerEncoder) -TransformerLM = model_configure(transformer.TransformerLM) -Transformer2 = model_configure( - transformer2.Transformer2) -WideResnet = model_configure(resnet.WideResnet) -Policy = model_configure(rl.Policy) -PolicyAndValue = model_configure(rl.PolicyAndValue) -Value = model_configure(rl.Value) -Quality = model_configure(rl.Quality) -RNNLM = model_configure(rnn.RNNLM) -GRULM = model_configure(rnn.GRULM) -LSTMSeq2SeqAttn = model_configure(rnn.LSTMSeq2SeqAttn) -ResidualShuffleExchange = model_configure(rse.ResidualShuffleExchange) -HourglassLM = model_configure(hourglass.HourglassLM) +# pylint: disable=unused-import + +# pylint: enable=unused-import + + +def model(name): + return registry.model(name) diff --git a/trax/models/basic.py b/trax/models/basic.py new file mode 100644 index 000000000..4a3209022 --- /dev/null +++ b/trax/models/basic.py @@ -0,0 +1,58 @@ +# coding=utf-8 +# Copyright 2023 The Tensor2Tensor Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Basic models for testing simple tasks.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensor2tensor.layers import common_hparams +from tensor2tensor.layers import common_layers +from tensor2tensor.utils import registry +from tensor2tensor.utils import t2t_model + +import tensorflow.compat.v1 as tf + + +@registry.register_model +class BasicFcRelu(t2t_model.T2TModel): + """Basic fully-connected + ReLU model.""" + + def body(self, features): + hparams = self.hparams + x = features["inputs"] + shape = common_layers.shape_list(x) + x = tf.reshape(x, [-1, shape[1] * shape[2] * shape[3]]) + for i in range(hparams.num_hidden_layers): + x = tf.layers.dense(x, hparams.hidden_size, name="layer_%d" % i) + x = tf.nn.dropout(x, keep_prob=1.0 - hparams.dropout) + x = tf.nn.relu(x) + return tf.expand_dims(tf.expand_dims(x, axis=1), axis=1) # 4D For T2T. + + +@registry.register_hparams +def basic_fc_small(): + """Small fully connected model.""" + hparams = common_hparams.basic_params1() + hparams.learning_rate = 0.1 + hparams.batch_size = 128 + hparams.hidden_size = 256 + hparams.num_hidden_layers = 2 + hparams.initializer = "uniform_unit_scaling" + hparams.initializer_gain = 1.0 + hparams.weight_decay = 0.0 + hparams.dropout = 0.0 + return hparams diff --git a/trax/models/basic_test.py b/trax/models/basic_test.py new file mode 100644 index 000000000..3f6b4affd --- /dev/null +++ b/trax/models/basic_test.py @@ -0,0 +1,51 @@ +# coding=utf-8 +# Copyright 2023 The Tensor2Tensor Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Basic nets tests.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +import numpy as np + +from tensor2tensor.data_generators import mnist # pylint: disable=unused-import +from tensor2tensor.models import basic +from tensor2tensor.utils import trainer_lib + +import tensorflow.compat.v1 as tf +from tensorflow.compat.v1 import estimator as tf_estimator + + +class BasicTest(tf.test.TestCase): + + def testBasicFcRelu(self): + x = np.random.randint(256, size=(1, 28, 28, 1)) + y = np.random.randint(10, size=(1, 1)) + hparams = trainer_lib.create_hparams( + "basic_fc_small", problem_name="image_mnist", data_dir=".") + with self.test_session() as session: + features = { + "inputs": tf.constant(x, dtype=tf.int32), + "targets": tf.constant(y, dtype=tf.int32), + } + model = basic.BasicFcRelu(hparams, tf_estimator.ModeKeys.TRAIN) + logits, _ = model(features) + session.run(tf.global_variables_initializer()) + res = session.run(logits) + self.assertEqual(res.shape, (1, 1, 1, 1, 10)) + + +if __name__ == "__main__": + tf.test.main() diff --git a/trax/models/bytenet.py b/trax/models/bytenet.py new file mode 100644 index 000000000..84594f36a --- /dev/null +++ b/trax/models/bytenet.py @@ -0,0 +1,109 @@ +# coding=utf-8 +# Copyright 2023 The Tensor2Tensor Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""ByteNet.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from six.moves import range # pylint: disable=redefined-builtin + +from tensor2tensor.layers import common_hparams +from tensor2tensor.layers import common_layers +from tensor2tensor.utils import registry +from tensor2tensor.utils import t2t_model + +import tensorflow.compat.v1 as tf + + +def residual_dilated_conv(x, repeat, padding, name, hparams): + """A stack of convolution blocks with residual connections.""" + with tf.variable_scope(name): + k = (hparams.kernel_height, hparams.kernel_width) + dilations_and_kernels = [((2**i, 1), k) + for i in range(hparams.num_hidden_layers)] + for i in range(repeat): + with tf.variable_scope("repeat_%d" % i): + y = common_layers.conv_block( + common_layers.layer_norm(x, hparams.hidden_size, name="lnorm"), + hparams.hidden_size, + dilations_and_kernels, + padding=padding, + name="residual_conv") + y = tf.nn.dropout(y, 1.0 - hparams.dropout) + x += y + return x + + +def bytenet_internal(inputs, targets, hparams): + """ByteNet, main step used for training.""" + with tf.variable_scope("bytenet"): + # Flatten inputs and extend length by 50%. + inputs = tf.expand_dims(common_layers.flatten4d3d(inputs), axis=2) + extend_length = tf.to_int32(0.5 * tf.to_float(tf.shape(inputs)[1])) + inputs_shape = inputs.shape.as_list() + inputs = tf.pad(inputs, [[0, 0], [0, extend_length], [0, 0], [0, 0]]) + inputs_shape[1] = None + inputs.set_shape(inputs_shape) # Don't lose the other shapes when padding. 
+ # Pad inputs and targets to be the same length, divisible by 50. + inputs, targets = common_layers.pad_to_same_length( + inputs, targets, final_length_divisible_by=50) + final_encoder = residual_dilated_conv(inputs, hparams.num_block_repeat, + "SAME", "encoder", hparams) + + shifted_targets = common_layers.shift_right(targets) + kernel = (hparams.kernel_height, hparams.kernel_width) + decoder_start = common_layers.conv_block( + tf.concat([final_encoder, shifted_targets], axis=3), + hparams.hidden_size, [((1, 1), kernel)], + padding="LEFT") + + return residual_dilated_conv(decoder_start, hparams.num_block_repeat, + "LEFT", "decoder", hparams) + + +@registry.register_model +class ByteNet(t2t_model.T2TModel): + + def body(self, features): + return bytenet_internal(features["inputs"], features["targets"], + self._hparams) + + +@registry.register_hparams +def bytenet_base(): + """Set of hyperparameters.""" + hparams = common_hparams.basic_params1() + hparams.batch_size = 2048 + hparams.hidden_size = 768 + hparams.dropout = 0.2 + hparams.symbol_dropout = 0.2 + hparams.label_smoothing = 0.1 + hparams.clip_grad_norm = 2.0 + hparams.num_hidden_layers = 4 + hparams.kernel_height = 3 + hparams.kernel_width = 1 + hparams.learning_rate_decay_scheme = "exp" + hparams.learning_rate = 0.05 + hparams.learning_rate_warmup_steps = 3000 + hparams.initializer_gain = 1.0 + hparams.weight_decay = 3.0 + hparams.num_sampled_classes = 0 + hparams.sampling_method = "argmax" + hparams.optimizer_adam_epsilon = 1e-6 + hparams.optimizer_adam_beta1 = 0.85 + hparams.optimizer_adam_beta2 = 0.997 + hparams.add_hparam("num_block_repeat", 4) + return hparams diff --git a/trax/models/bytenet_test.py b/trax/models/bytenet_test.py new file mode 100644 index 000000000..204d54bc1 --- /dev/null +++ b/trax/models/bytenet_test.py @@ -0,0 +1,54 @@ +# coding=utf-8 +# Copyright 2023 The Tensor2Tensor Authors. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""ByteNet tests.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +import numpy as np + +from tensor2tensor.data_generators import problem_hparams +from tensor2tensor.models import bytenet + +import tensorflow.compat.v1 as tf +from tensorflow.compat.v1 import estimator as tf_estimator + + +class ByteNetTest(tf.test.TestCase): + + def testByteNet(self): + vocab_size = 9 + x = np.random.randint(1, high=vocab_size, size=(3, 5, 1, 1)) + y = np.random.randint(1, high=vocab_size, size=(3, 6, 1, 1)) + hparams = bytenet.bytenet_base() + p_hparams = problem_hparams.test_problem_hparams(vocab_size, + vocab_size, + hparams) + with self.test_session() as session: + features = { + "inputs": tf.constant(x, dtype=tf.int32), + "targets": tf.constant(y, dtype=tf.int32), + } + model = bytenet.ByteNet( + hparams, tf_estimator.ModeKeys.TRAIN, p_hparams) + logits, _ = model(features) + session.run(tf.global_variables_initializer()) + res = session.run(logits) + self.assertEqual(res.shape, (3, 50, 1, 1, vocab_size)) + + +if __name__ == "__main__": + tf.test.main() diff --git a/trax/models/distillation.py b/trax/models/distillation.py new file mode 100644 index 000000000..9d8ccb849 --- /dev/null +++ b/trax/models/distillation.py @@ -0,0 +1,199 @@ +# coding=utf-8 +# Copyright 2023 The Tensor2Tensor Authors. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Traditional Student-Teacher Distillation.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from tensor2tensor.layers import common_hparams +from tensor2tensor.utils import registry +from tensor2tensor.utils import t2t_model + +import tensorflow.compat.v1 as tf +from tensorflow.compat.v1 import estimator as tf_estimator + + +@registry.register_model +class Distillation(t2t_model.T2TModel): + """Distillation from a teacher to student network. + + First, a teacher is trained on a task; second, a student is trained to perform + the task while matching the teacher's softened outputs. For more details, see + the paper below. + + In the hparams passed to this model, include the desired + {teacher/student}_model and {teacher/student}_hparams to be used. Also, + specify the distillation temperature and task-distillation balance.
+ + Distilling the Knowledge in a Neural Network + Hinton, Vinyals and Dean + https://arxiv.org/abs/1503.02531 + """ + + def __init__(self, + hparams, + mode=tf_estimator.ModeKeys.TRAIN, + problem_hparams=None, + data_parallelism=None, + decode_hparams=None, + **kwargs): + assert hparams.distill_phase in ["train", "distill"] + + if hparams.distill_phase == "train" and hparams.teacher_learning_rate: + hparams.learning_rate = hparams.teacher_learning_rate + elif hparams.distill_phase == "distill" and hparams.student_learning_rate: + hparams.learning_rate = hparams.student_learning_rate + + self.teacher_hparams = registry.hparams(hparams.teacher_hparams) + self.teacher_model = registry.model( + hparams.teacher_model)(self.teacher_hparams, mode, problem_hparams, + data_parallelism, decode_hparams) + self.student_hparams = registry.hparams(hparams.student_hparams) + self.student_model = registry.model( + hparams.student_model)(self.student_hparams, mode, problem_hparams, + data_parallelism, decode_hparams) + super(Distillation, + self).__init__(hparams, mode, problem_hparams, data_parallelism, + decode_hparams, **kwargs) + + def body(self, features): + hp = self.hparams + is_distill = hp.distill_phase == "distill" + + targets = features["targets_raw"] + targets = tf.squeeze(targets, [1, 2, 3]) + one_hot_targets = tf.one_hot(targets, hp.num_classes, dtype=tf.float32) + + # Teacher Network + with tf.variable_scope("teacher"): + teacher_outputs = self.teacher_model.body(features) + tf.logging.info("teacher output shape: %s" % teacher_outputs.get_shape()) + teacher_outputs = tf.reduce_mean(teacher_outputs, axis=[1, 2]) + teacher_logits = tf.layers.dense(teacher_outputs, hp.num_classes) + + teacher_task_xent = tf.nn.softmax_cross_entropy_with_logits_v2( + labels=one_hot_targets, logits=teacher_logits) + outputs = teacher_logits + + if is_distill: + # Load teacher weights + tf.train.init_from_checkpoint(hp.teacher_dir, {"teacher/": "teacher/"}) + # Do not train the teacher + 
trainable_vars = tf.get_collection_ref(tf.GraphKeys.TRAINABLE_VARIABLES) + del trainable_vars[:] + + # Student Network + if is_distill: + with tf.variable_scope("student"): + student_outputs = self.student_model.body(features) + tf.logging.info( + "student output shape: %s" % student_outputs.get_shape()) + student_outputs = tf.reduce_mean(student_outputs, axis=[1, 2]) + student_logits = tf.layers.dense(student_outputs, hp.num_classes) + + student_task_xent = tf.nn.softmax_cross_entropy_with_logits_v2( + labels=one_hot_targets, logits=student_logits) + teacher_targets = tf.nn.softmax(teacher_logits / hp.distill_temperature) + student_distill_xent = tf.nn.softmax_cross_entropy_with_logits_v2( + labels=tf.stop_gradient(teacher_targets), + logits=student_logits / hp.distill_temperature) + # scale soft target obj. to match hard target obj. scale + student_distill_xent *= hp.distill_temperature**2 + + outputs = student_logits + + # Summaries + tf.summary.scalar("distill_xent", student_distill_xent) + + if not is_distill: + phase_loss = teacher_task_xent + else: + phase_loss = hp.task_balance * student_task_xent + phase_loss += (1 - hp.task_balance) * student_distill_xent + + losses = {"training": phase_loss} + outputs = tf.reshape(outputs, [-1, 1, 1, 1, outputs.shape[1]]) + + return outputs, losses + + def top(self, body_output, features): + return body_output + + +def distill_base(): + """Set of hyperparameters.""" + # Base + hparams = common_hparams.basic_params1() + + # teacher/student parameters + hparams.add_hparam("teacher_model", "") + hparams.add_hparam("teacher_hparams", "") + hparams.add_hparam("student_model", "") + hparams.add_hparam("student_hparams", "") + + # Distillation parameters + # WARNING: distill_phase hparam will be overwritten in /bin/t2t_distill.py + hparams.add_hparam("distill_phase", None) + hparams.add_hparam("task_balance", 1.0) + hparams.add_hparam("distill_temperature", 1.0) + hparams.add_hparam("num_classes", 10) + + # Optional 
Phase-specific hyperparameters + hparams.add_hparam("teacher_learning_rate", None) + hparams.add_hparam("student_learning_rate", None) + + # Training parameters (stolen from ResNet) + hparams.batch_size = 128 + hparams.optimizer = "Momentum" + hparams.optimizer_momentum_momentum = 0.9 + hparams.optimizer_momentum_nesterov = True + hparams.weight_decay = 1e-4 + hparams.clip_grad_norm = 0.0 + # (base_lr=0.1) * (batch_size=128*8 (on TPU, or 8 GPUs)=1024) / (256.) + hparams.learning_rate = 0.4 + hparams.learning_rate_decay_scheme = "cosine" + # For image_imagenet224, 120k training steps, which effectively makes this a + # cosine decay (i.e. no cycles). + hparams.learning_rate_cosine_cycle_steps = 120000 + hparams.initializer = "normal_unit_scaling" + hparams.initializer_gain = 2. + + return hparams + + +@registry.register_hparams +def distill_resnet_32_to_15_cifar20x5(): + """Set of hyperparameters.""" + hparams = distill_base() + hparams.teacher_model = "resnet" + hparams.teacher_hparams = "resnet_cifar_32" + hparams.student_model = "resnet" + hparams.student_hparams = "resnet_cifar_15" + + hparams.optimizer_momentum_nesterov = True + # (base_lr=0.1) * (batch_size=128*8 (on TPU, or 8 GPUs)=1024) / (256.) + hparams.teacher_learning_rate = 0.25 * 128. * 8. / 256. + hparams.student_learning_rate = 0.2 * 128. * 8. / 256. + hparams.learning_rate_decay_scheme = "piecewise" + hparams.add_hparam("learning_rate_boundaries", [40000, 60000, 80000]) + hparams.add_hparam("learning_rate_multiples", [0.1, 0.01, 0.001]) + + hparams.task_balance = 0.28 + hparams.distill_temperature = 2.0 + + hparams.num_classes = 20 + + return hparams diff --git a/trax/models/evolved_transformer.py b/trax/models/evolved_transformer.py new file mode 100644 index 000000000..bac01a3cf --- /dev/null +++ b/trax/models/evolved_transformer.py @@ -0,0 +1,833 @@ +# coding=utf-8 +# Copyright 2023 The Tensor2Tensor Authors. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Evolved Transformer model. + +This implements the model described in arxiv.org/abs/1901.11117 . +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensor2tensor.layers import common_attention +from tensor2tensor.layers import common_layers +from tensor2tensor.models import transformer +from tensor2tensor.utils import registry +from tensor2tensor.utils import t2t_model + +import tensorflow.compat.v1 as tf + +# pylint: disable=g-direct-tensorflow-import +from tensorflow.python.ops import inplace_ops +# pylint: enable=g-direct-tensorflow-import + +_CONV_BRANCHES_NAME = "conv_branches" +_CONV_BRANCHES_FIRST_LAYER_NAME = _CONV_BRANCHES_NAME + "_first" +_CONV_BRANCHES_SECOND_LAYER_NAME = _CONV_BRANCHES_NAME + "_second" +_FIRST_ATTEND_TO_ENCODER_NAME = "first_attend_to_encoder" +_SECOND_ATTEND_TO_ENCODER_NAME = "second_attend_to_encoder" +_SIXTEEN_HEAD_ATTENTION_NAME = "16_head_self_attention" +_VANILLA_ATTENTION_NAME = "self_attention" + +_DECODER_LEFT_CONV_PADDING = 10 +_DECODER_RIGHT_CONV_PADDING = 6 +_DECODER_FINAL_CONV_PADDING = 6 + + +def _capped_double_heads(num_heads, cap=16): + """Calculate the number of heads for the attention layers with more heads. + + The number of heads will be twice the normal amount (num_heads), until it + reaches |cap| heads. + + Args: + num_heads: the num_heads hparam for the model. 
+ cap: the maximum number of heads |num_heads| will be doubled to. + + Returns: + The number of heads for the attention layers that have more heads. + """ + return max(min(num_heads * 2, cap), num_heads) + + +@registry.register_model +class EvolvedTransformer(transformer.Transformer): + """The Evolved Transformer from arxiv.org/abs/1901.11117 .""" + + def __init__(self, *args, **kwargs): + super(EvolvedTransformer, self).__init__(*args, **kwargs) + self._encoder_function = evolved_transformer_encoder + self._decoder_function = evolved_transformer_decoder + self._init_cache_fn = init_evolved_transformer_cache + + # -1 means train all weights. + if self.hparams.get("num_trainable_top_decoder_layers", -1) < 0: + t2t_model.log_info( + "num_trainable_top_decoder_layers is negative so training all weights." + ) + elif self.hparams.shared_embedding_and_softmax_weights: + t2t_model.log_info( + "Setting hparams.shared_embedding_and_softmax_weights to False, " + "because hparam.num_trainable_top_decoder_layers is being used.") + + # When hparam.num_trainable_top_decoder_layers is set to N >= 0 we will + # freeze (not train) every variable except the N top decoder layers and + # the (pre-)softmax matrix. For any N >= 0 we will freeze the encoder and + # input/target embeddings. This also means we will not share the + # (pre-)softmax matrix with input/target embeddings otherwise they will be + # trained as well. + self.hparams.shared_embedding_and_softmax_weights = False + + # If hparams.shared_embedding_and_softmax_weights was previously True, + # then input and target embeddings were being shared. + # To make sure the embeddings continue to be shared, we need to set + # hparams.shared_embedding to True.
+ self.hparams.shared_embedding = True + self._init_cache_fn = init_evolved_transformer_cache + + +def evolved_transformer_encoder(encoder_input, + encoder_self_attention_bias, + hparams, + name="encoder", + nonpadding=None, + save_weights_to=None, + make_image_summary=True, + losses=None, + attn_bias_for_padding=None): + """Evolved Transformer encoder. See arxiv.org/abs/1901.11117 for more details. + + Note: Pad remover is not supported. + + Args: + encoder_input: a Tensor. + encoder_self_attention_bias: bias Tensor for self-attention (see + common_attention.attention_bias()). + hparams: hyperparameters for model. + name: a string. + nonpadding: optional Tensor with shape [batch_size, encoder_length] + indicating what positions are not padding. This must either be passed in, + which we do for "packed" datasets, or inferred from + encoder_self_attention_bias. The knowledge about padding is used to mask + out padding in convolutional layers (pad removal is not supported here). + save_weights_to: an optional dictionary to capture attention weights for + visualization; the weights tensor will be appended there under a string + key created from the variable scope (including name). + make_image_summary: Whether to make an attention image summary. + losses: Not used. + attn_bias_for_padding: Padded attention bias in case a unidirectional + encoder is being used where future attention is masked. + + Returns: + Tensor encoder output. + """ + del losses + + hidden_state = encoder_input + attention_dropout_broadcast_dims = ( + common_layers.comma_separated_string_to_integer_list( + getattr(hparams, "attention_dropout_broadcast_dims", ""))) + + with tf.variable_scope(name): + if nonpadding is not None: + padding = 1.0 - nonpadding + else: + attention_bias = encoder_self_attention_bias + if attn_bias_for_padding is not None: + attention_bias = attn_bias_for_padding + # Only bfloat16 and float32 supported.
+ float_type = hparams.get("activation_dtype", "float32") + if float_type == "bfloat16": + cast_fn = tf.to_bfloat16 + else: + assert float_type == "float32" + cast_fn = tf.to_float + padding = common_attention.attention_bias_to_padding( + attention_bias, cast_fn) + nonpadding = 1.0 - padding + + for layer in range(hparams.num_encoder_layers or hparams.num_hidden_layers): + with tf.variable_scope("layer_%d" % layer): + + with tf.variable_scope("gated_linear_unit"): + + residual_state = hidden_state + hidden_state = common_layers.layer_preprocess(hidden_state, hparams) + + values = common_layers.layers().Dense( + hparams.hidden_size)(hidden_state) + gates = common_layers.layers().Dense( + hparams.hidden_size, activation=tf.nn.sigmoid)(hidden_state) + hidden_state = values * gates + + hidden_state = common_layers.layer_postprocess( + residual_state, hidden_state, hparams) + + with tf.variable_scope("conv_branches"): + + residual_state = hidden_state + hidden_state = common_layers.layer_preprocess(hidden_state, hparams) + # Mask padding from conv layers. + mask = tf.tile( + tf.expand_dims(nonpadding, 2), [1, 1, hparams.hidden_size]) + hidden_state *= mask + + left_output_dim = int(hparams.hidden_size * 4) + left_state = common_layers.layers().Dense( + left_output_dim, activation=tf.nn.relu)(hidden_state) + left_state = tf.nn.dropout(left_state, + 1 - hparams.layer_prepostprocess_dropout) + + right_output_dim = int(hparams.hidden_size / 2) + right_state = common_layers.layers().Conv1D( + right_output_dim, + 3, + padding="SAME", + name="standard_conv_3x1", + activation=tf.nn.relu)(hidden_state) + right_state = tf.nn.dropout(right_state, + 1 - hparams.layer_prepostprocess_dropout) + + right_state = tf.pad( + right_state, + [[0, 0], [0, 0], [0, left_output_dim - right_output_dim]], + constant_values=0) + hidden_state = left_state + right_state + + hidden_state = common_layers.layer_preprocess(hidden_state, hparams) + # Mask padding from conv layer. 
+ mask = tf.tile(tf.expand_dims(nonpadding, 2), [1, 1, left_output_dim]) + hidden_state *= mask + + separable_conv_9x1 = common_layers.layers().SeparableConv1D( + right_output_dim, 9, padding="SAME", name="separable_conv_9x1") + hidden_state = separable_conv_9x1(hidden_state) + hidden_state = tf.pad( + hidden_state, + [[0, 0], [0, 0], [0, hparams.hidden_size - right_output_dim]], + constant_values=0) + + hidden_state = common_layers.layer_postprocess( + residual_state, hidden_state, hparams) + + if hparams.get("et_encoder_self_attention", True): + with tf.variable_scope("self_attention"): + residual_state = hidden_state + hidden_state = common_layers.layer_preprocess(hidden_state, hparams) + + hidden_state = common_attention.multihead_attention( + hidden_state, + None, + encoder_self_attention_bias, + hparams.attention_key_channels or hparams.hidden_size, + hparams.attention_value_channels or hparams.hidden_size, + hparams.hidden_size, + hparams.num_heads, + hparams.attention_dropout, + attention_type=hparams.self_attention_type, + max_relative_position=hparams.max_relative_position, + heads_share_relative_embedding=( + hparams.heads_share_relative_embedding), + add_relative_to_values=hparams.add_relative_to_values, + save_weights_to=save_weights_to, + make_image_summary=make_image_summary, + dropout_broadcast_dims=attention_dropout_broadcast_dims, + max_length=hparams.get("max_length"), + vars_3d=hparams.get("attention_variables_3d"), + activation_dtype=hparams.get("activation_dtype", "float32"), + weight_dtype=hparams.get("weight_dtype", "float32")) + + hidden_state = common_layers.layer_postprocess( + residual_state, hidden_state, hparams) + + with tf.variable_scope("dense_layers"): + residual_state = hidden_state + hidden_state = common_layers.layer_preprocess(hidden_state, hparams) + + hidden_state = common_layers.layers().Dense( + int(hparams.hidden_size * 4), activation=tf.nn.relu)(hidden_state) + hidden_state = tf.nn.dropout(hidden_state, + 1 - 
hparams.layer_prepostprocess_dropout) + + hidden_state = common_layers.layers().Dense( + hparams.hidden_size)(hidden_state) + hidden_state = common_layers.layer_postprocess( + residual_state, hidden_state, hparams) + + # If normalization is done in layer_preprocess, then it should also be done + # on the output, since the output can grow very large, being the sum of + # a whole stack of unnormalized layer outputs. + return common_layers.layer_preprocess(hidden_state, hparams) + + +def evolved_transformer_decoder(decoder_input, + encoder_output, + decoder_self_attention_bias, + encoder_decoder_attention_bias, + hparams, + cache=None, + decode_loop_step=None, + name="decoder", + nonpadding=None, + save_weights_to=None, + make_image_summary=True, + losses=None): + """Evolved Transformer decoder. See arxiv.org/abs/1901.11117 for more details. + + Args: + decoder_input: a Tensor. + encoder_output: a Tensor. + decoder_self_attention_bias: bias Tensor for self-attention (see + common_attention.attention_bias()). + encoder_decoder_attention_bias: bias Tensor for encoder-decoder attention + (see common_attention.attention_bias()). + hparams: hyperparameters for model. + cache: dict, containing tensors which are the results of previous + layers, used for fast decoding. + decode_loop_step: An integer, step number of the decoding loop. Only used + for inference on TPU. + name: a string. + nonpadding: optional Tensor with shape [batch_size, encoder_length] + indicating what positions are not padding. This is used to mask out + padding in convolutional layers. We generally only need this mask for + "packed" datasets, because for ordinary datasets, no padding is ever + followed by nonpadding. + save_weights_to: an optional dictionary to capture attention weights for + visualization; the weights tensor will be appended there under a string + key created from the variable scope (including name). + make_image_summary: Whether to make an attention image summary. 
+ losses: Not supported. + + Returns: + Decoder output tensor. + """ + del losses + + num_trainable_top_decoder_layers = hparams.get( + "num_trainable_top_decoder_layers", -1) # -1 means train all weights. + + if num_trainable_top_decoder_layers >= 0: + encoder_output = tf.stop_gradient(encoder_output) + + attention_dropout_broadcast_dims = ( + common_layers.comma_separated_string_to_integer_list( + getattr(hparams, "attention_dropout_broadcast_dims", ""))) + + with tf.variable_scope(name): + hidden_state = decoder_input + + num_layers = hparams.num_decoder_layers or hparams.num_hidden_layers + for layer in range(num_layers): + if num_trainable_top_decoder_layers == num_layers - layer: + hidden_state = tf.stop_gradient(hidden_state) + layer_name = "layer_%d" % layer + layer_cache = cache[layer_name] if cache is not None else None + with tf.variable_scope(layer_name): + + with tf.variable_scope(_SIXTEEN_HEAD_ATTENTION_NAME): + residual_state = hidden_state + hidden_state = common_layers.layer_preprocess(hidden_state, hparams) + + attention_cache = layer_cache[ + _SIXTEEN_HEAD_ATTENTION_NAME] if layer_cache is not None else None + left_state = common_attention.multihead_attention( + hidden_state, + None, + decoder_self_attention_bias, + hparams.attention_key_channels or hparams.hidden_size, + hparams.attention_value_channels or hparams.hidden_size, + hparams.hidden_size, + _capped_double_heads(hparams.num_heads), + hparams.attention_dropout, + attention_type=hparams.self_attention_type, + max_relative_position=hparams.max_relative_position, + heads_share_relative_embedding=( + hparams.heads_share_relative_embedding), + add_relative_to_values=hparams.add_relative_to_values, + save_weights_to=save_weights_to, + cache=attention_cache, + make_image_summary=make_image_summary, + dropout_broadcast_dims=attention_dropout_broadcast_dims, + max_length=hparams.get("max_length"), + decode_loop_step=decode_loop_step, + vars_3d=hparams.get("attention_variables_3d"), + 
activation_dtype=hparams.get("activation_dtype", "float32"), + weight_dtype=hparams.get("weight_dtype", "float32")) + + if encoder_output is not None: + with tf.variable_scope(_FIRST_ATTEND_TO_ENCODER_NAME): + attention_cache = ( + layer_cache[_FIRST_ATTEND_TO_ENCODER_NAME] + if layer_cache is not None else None) + right_state = common_attention.multihead_attention( + hidden_state, + encoder_output, + encoder_decoder_attention_bias, + hparams.attention_key_channels or hparams.hidden_size, + hparams.attention_value_channels or hparams.hidden_size, + hparams.hidden_size, + hparams.num_heads, + hparams.attention_dropout, + max_relative_position=hparams.max_relative_position, + heads_share_relative_embedding=( + hparams.heads_share_relative_embedding), + add_relative_to_values=hparams.add_relative_to_values, + save_weights_to=save_weights_to, + cache=attention_cache, + make_image_summary=make_image_summary, + dropout_broadcast_dims=attention_dropout_broadcast_dims, + max_length=hparams.get("max_length"), + vars_3d=hparams.get("attention_variables_3d"), + activation_dtype=hparams.get("activation_dtype", "float32"), + weight_dtype=hparams.get("weight_dtype", "float32")) + + left_state = tf.nn.dropout(left_state, + 1 - hparams.layer_prepostprocess_dropout) + right_state = tf.nn.dropout( + right_state, 1 - hparams.layer_prepostprocess_dropout) + + hidden_state = residual_state + left_state + right_state + + else: + hidden_state = common_layers.layer_postprocess( + residual_state, left_state, hparams) + + with tf.variable_scope(_CONV_BRANCHES_NAME): + residual_state = hidden_state + hidden_state = common_layers.layer_preprocess(hidden_state, hparams) + + if nonpadding is not None: + # Mask padding from conv layers. 
+ mask = tf.tile( + tf.expand_dims(nonpadding, 2), [1, 1, hparams.hidden_size]) + hidden_state *= mask + + if layer_cache: + if decode_loop_step is None: + hidden_state = layer_cache[ + _CONV_BRANCHES_FIRST_LAYER_NAME] = tf.concat( + [ + layer_cache[_CONV_BRANCHES_FIRST_LAYER_NAME], + hidden_state + ], + axis=1)[:, -1 * _DECODER_LEFT_CONV_PADDING - 1:, :] + left_state = hidden_state + right_state = hidden_state[:, _DECODER_LEFT_CONV_PADDING - + _DECODER_RIGHT_CONV_PADDING:, :] + + else: + # Inplace update is required for inference on TPU. + # Inplace_ops only supports inplace_update on the first dimension. + tmp = tf.transpose( + layer_cache[_CONV_BRANCHES_FIRST_LAYER_NAME], perm=[1, 0, 2]) + tmp = tf.expand_dims(tmp, axis=1) + tmp = inplace_ops.alias_inplace_update( + tmp, + decode_loop_step * tf.shape(hidden_state)[1] + + _DECODER_LEFT_CONV_PADDING, + tf.transpose(hidden_state, perm=[1, 0, 2])) + tmp = tf.squeeze(tmp, axis=1) + hidden_state = layer_cache[ + _CONV_BRANCHES_FIRST_LAYER_NAME] = tf.transpose( + tmp, perm=[1, 0, 2]) + + batch_size = hidden_state.shape.as_list()[0] + left_state = tf.slice(hidden_state, [0, decode_loop_step, 0], [ + batch_size, _DECODER_LEFT_CONV_PADDING + 1, + hparams.hidden_size + ]) + right_state = tf.slice(hidden_state, [ + 0, decode_loop_step + _DECODER_LEFT_CONV_PADDING - + _DECODER_RIGHT_CONV_PADDING, 0 + ], [ + batch_size, _DECODER_RIGHT_CONV_PADDING + 1, + hparams.hidden_size + ]) + + else: # No caching. 
+ left_state = tf.pad( + hidden_state, + paddings=[[0, 0], [_DECODER_LEFT_CONV_PADDING, 0], [0, 0]]) + right_state = tf.pad( + hidden_state, + paddings=[[0, 0], [_DECODER_RIGHT_CONV_PADDING, 0], [0, 0]]) + + left_output_dim = int(hparams.hidden_size * 2) + separable_conv_11x1 = tf.layers.SeparableConv1D( + left_output_dim, + 11, + padding="VALID", + name="separable_conv11x1", + activation=tf.nn.relu) + left_state = separable_conv_11x1.apply(left_state) + left_state = tf.nn.dropout(left_state, + 1 - hparams.layer_prepostprocess_dropout) + + right_output_dim = int(hparams.hidden_size / 2) + separable_conv_7x1_1 = tf.layers.SeparableConv1D( + right_output_dim, 7, padding="VALID", name="separable_conv_7x1_1") + right_state = separable_conv_7x1_1.apply(right_state) + right_state = tf.nn.dropout(right_state, + 1 - hparams.layer_prepostprocess_dropout) + right_state = tf.pad( + right_state, + [[0, 0], [0, 0], [0, left_output_dim - right_output_dim]], + constant_values=0) + + hidden_state = left_state + right_state + + hidden_state = common_layers.layer_preprocess(hidden_state, hparams) + if nonpadding is not None: + # Mask padding from conv layers. + mask = tf.tile( + tf.expand_dims(nonpadding, 2), [1, 1, hparams.hidden_size * 2]) + hidden_state *= mask + + if layer_cache: + if decode_loop_step is None: + hidden_state = layer_cache[ + _CONV_BRANCHES_SECOND_LAYER_NAME] = tf.concat( + [ + layer_cache[_CONV_BRANCHES_SECOND_LAYER_NAME], + hidden_state + ], + axis=1)[:, -1 * _DECODER_FINAL_CONV_PADDING - 1:, :] + + else: + # Inplace update is required for inference on TPU. + # Inplace_ops only supports inplace_update on the first dimension. 
+ tmp = tf.transpose( + layer_cache[_CONV_BRANCHES_SECOND_LAYER_NAME], perm=[1, 0, 2]) + tmp = tf.expand_dims(tmp, axis=1) + tmp = inplace_ops.alias_inplace_update( + tmp, (decode_loop_step + _DECODER_FINAL_CONV_PADDING) * + tf.shape(hidden_state)[1], + tf.transpose(hidden_state, perm=[1, 0, 2])) + tmp = tf.squeeze(tmp, axis=1) + hidden_state = layer_cache[ + _CONV_BRANCHES_SECOND_LAYER_NAME] = tf.transpose( + tmp, perm=[1, 0, 2]) + + batch_size = hidden_state.shape.as_list()[0] + hidden_state = tf.slice(hidden_state, [0, decode_loop_step, 0], [ + batch_size, _DECODER_FINAL_CONV_PADDING + 1, + hparams.hidden_size * 2 + ]) + else: + hidden_state = tf.pad( + hidden_state, + paddings=[[0, 0], [_DECODER_FINAL_CONV_PADDING, 0], [0, 0]]) + + separable_conv_7x1_2 = tf.layers.SeparableConv1D( + hparams.hidden_size, + 7, + padding="VALID", + name="separable_conv_7x1_2") + hidden_state = separable_conv_7x1_2.apply(hidden_state) + + hidden_state = common_layers.layer_postprocess( + residual_state, hidden_state, hparams) + + with tf.variable_scope(_VANILLA_ATTENTION_NAME): + residual_state = hidden_state + hidden_state = common_layers.layer_preprocess(hidden_state, hparams) + + attention_cache = layer_cache[ + _VANILLA_ATTENTION_NAME] if layer_cache is not None else None + hidden_state = common_attention.multihead_attention( + hidden_state, + None, + decoder_self_attention_bias, + hparams.attention_key_channels or hparams.hidden_size, + hparams.attention_value_channels or hparams.hidden_size, + hparams.hidden_size, + hparams.num_heads, + hparams.attention_dropout, + attention_type=hparams.self_attention_type, + max_relative_position=hparams.max_relative_position, + heads_share_relative_embedding=( + hparams.heads_share_relative_embedding), + add_relative_to_values=hparams.add_relative_to_values, + save_weights_to=save_weights_to, + cache=attention_cache, + make_image_summary=make_image_summary, + dropout_broadcast_dims=attention_dropout_broadcast_dims, + 
max_length=hparams.get("max_length"), + decode_loop_step=decode_loop_step, + vars_3d=hparams.get("attention_variables_3d"), + activation_dtype=hparams.get("activation_dtype", "float32"), + weight_dtype=hparams.get("weight_dtype", "float32")) + hidden_state = common_layers.layer_postprocess( + residual_state, hidden_state, hparams) + + if encoder_output is not None: + with tf.variable_scope(_SECOND_ATTEND_TO_ENCODER_NAME): + residual_state = hidden_state + hidden_state = common_layers.layer_preprocess(hidden_state, hparams) + + attention_cache = ( + layer_cache[_SECOND_ATTEND_TO_ENCODER_NAME] + if layer_cache is not None else None) + hidden_state = common_attention.multihead_attention( + hidden_state, + encoder_output, + encoder_decoder_attention_bias, + hparams.attention_key_channels or hparams.hidden_size, + hparams.attention_value_channels or hparams.hidden_size, + hparams.hidden_size, + hparams.num_heads, + hparams.attention_dropout, + max_relative_position=hparams.max_relative_position, + heads_share_relative_embedding=( + hparams.heads_share_relative_embedding), + add_relative_to_values=hparams.add_relative_to_values, + save_weights_to=save_weights_to, + cache=attention_cache, + make_image_summary=make_image_summary, + dropout_broadcast_dims=attention_dropout_broadcast_dims, + max_length=hparams.get("max_length"), + vars_3d=hparams.get("attention_variables_3d"), + activation_dtype=hparams.get("activation_dtype", "float32"), + weight_dtype=hparams.get("weight_dtype", "float32")) + hidden_state = common_layers.layer_postprocess( + residual_state, hidden_state, hparams) + + with tf.variable_scope("dense_layers"): + residual_state = hidden_state + hidden_state = common_layers.layer_preprocess(hidden_state, hparams) + + hidden_state = tf.layers.dense( + hidden_state, + int(hparams.hidden_size * 4), + activation=tf.nn.swish) + hidden_state = tf.nn.dropout(hidden_state, + 1 - hparams.layer_prepostprocess_dropout) + + hidden_state = 
common_layers.layer_preprocess(hidden_state, hparams) + + hidden_state = tf.layers.dense(hidden_state, hparams.hidden_size) + hidden_state = common_layers.layer_postprocess( + residual_state, hidden_state, hparams) + + decoder_output = common_layers.layer_preprocess(hidden_state, hparams) + if num_trainable_top_decoder_layers == 0: + decoder_output = tf.stop_gradient(decoder_output) + return decoder_output + + +def _add_attend_to_encoder_cache(cache, attention_name, hparams, num_layers, + key_channels, value_channels, + vars_3d_num_heads, scope_prefix, + encoder_output): + """Add attend-to-encoder layers to cache.""" + for layer in range(num_layers): + layer_name = "layer_%d" % layer + with tf.variable_scope("%sdecoder/%s/%s/multihead_attention" % + (scope_prefix, layer_name, attention_name)): + k_encdec = common_attention.compute_attention_component( + encoder_output, + key_channels, + name="k", + vars_3d_num_heads=vars_3d_num_heads) + k_encdec = common_attention.split_heads(k_encdec, hparams.num_heads) + v_encdec = common_attention.compute_attention_component( + encoder_output, + value_channels, + name="v", + vars_3d_num_heads=vars_3d_num_heads) + v_encdec = common_attention.split_heads(v_encdec, hparams.num_heads) + cache[layer_name][attention_name] = { + "k_encdec": k_encdec, + "v_encdec": v_encdec + } + return cache + + +def init_evolved_transformer_cache(cache, hparams, batch_size, + attention_init_length, encoder_output, + encoder_decoder_attention_bias, + scope_prefix): + """Create the initial cache for Evolved Transformer fast decoding.""" + key_channels = hparams.attention_key_channels or hparams.hidden_size + value_channels = hparams.attention_value_channels or hparams.hidden_size + num_layers = hparams.num_decoder_layers or hparams.num_hidden_layers + vars_3d_num_heads = ( + hparams.num_heads if hparams.get("attention_variables_3d") else 0) + + # Add self-attentions. 
+ if cache is None: + cache = {} + cache.update({ + "layer_%d" % layer: { # pylint: disable=g-complex-comprehension + _SIXTEEN_HEAD_ATTENTION_NAME: { + "k": + common_attention.split_heads( + tf.zeros( + [batch_size, attention_init_length, key_channels]), + _capped_double_heads(hparams.num_heads)), + "v": + common_attention.split_heads( + tf.zeros( + [batch_size, attention_init_length, value_channels]), + _capped_double_heads(hparams.num_heads)), + }, + _VANILLA_ATTENTION_NAME: { + "k": + common_attention.split_heads( + tf.zeros( + [batch_size, attention_init_length, key_channels]), + hparams.num_heads), + "v": + common_attention.split_heads( + tf.zeros( + [batch_size, attention_init_length, value_channels]), + hparams.num_heads), + } + } for layer in range(num_layers) + }) + + # Add branched layers. Pad with additional zeros for causal convolution. + for layer in range(num_layers): + cache["layer_%d" % layer][_CONV_BRANCHES_FIRST_LAYER_NAME] = tf.zeros([ + batch_size, attention_init_length + _DECODER_LEFT_CONV_PADDING, + hparams.hidden_size + ]) + cache["layer_%d" % layer][_CONV_BRANCHES_SECOND_LAYER_NAME] = tf.zeros([ + batch_size, attention_init_length + _DECODER_FINAL_CONV_PADDING, + hparams.hidden_size * 2 + ]) + + # Add encoder embedding attentions. 
+ if encoder_output is not None: + cache = _add_attend_to_encoder_cache( + cache=cache, + attention_name=_FIRST_ATTEND_TO_ENCODER_NAME, + hparams=hparams, + num_layers=num_layers, + key_channels=key_channels, + value_channels=value_channels, + vars_3d_num_heads=vars_3d_num_heads, + scope_prefix=scope_prefix, + encoder_output=encoder_output) + cache = _add_attend_to_encoder_cache( + cache=cache, + attention_name=_SECOND_ATTEND_TO_ENCODER_NAME, + hparams=hparams, + num_layers=num_layers, + key_channels=key_channels, + value_channels=value_channels, + vars_3d_num_heads=vars_3d_num_heads, + scope_prefix=scope_prefix, + encoder_output=encoder_output) + + cache["encoder_output"] = encoder_output + cache["encoder_decoder_attention_bias"] = encoder_decoder_attention_bias + + return cache + + +# TODO(davidso): Update optimizer, learning rate, and decay to match paper. +def add_evolved_transformer_hparams(hparams): + """Add Evolved Transformer hparams. + + Note: These are for the Adam optimizer, not the Adafactor optimizer used in + the paper. + + Args: + hparams: Current hparams. + + Returns: + hparams updated with Evolved Transformer values. + """ + # Evolved Transformer "layers" are twice as deep as Transformer, so roughly + # halve the number that we use. These numbers are taken from + # arxiv.org/abs/1901.11117 . + hparams.num_encoder_layers = 3 + hparams.num_decoder_layers = 4 + + # Learning rate and decay scheme that mimics the transformer Adam config, + # but with cosine decay instead of rsqrt. 
+ hparams.learning_rate_constant /= hparams.learning_rate_warmup_steps ** 0.5 + hparams.learning_rate_schedule = ( + "constant*linear_warmup*single_cycle_cos_decay*rsqrt_hidden_size") + return hparams + + +@registry.register_hparams +def evolved_transformer_tiny(): + """Base parameters for Evolved Transformer model.""" + hparams = add_evolved_transformer_hparams(transformer.transformer_tiny()) + hparams.learning_rate_schedule = ( + "constant*single_cycle_cos_decay") + return hparams + + +@registry.register_hparams +def evolved_transformer_base(): + """Base parameters for Evolved Transformer model.""" + return add_evolved_transformer_hparams(transformer.transformer_base()) + + +@registry.register_hparams +def evolved_transformer_big(): + """Big parameters for Evolved Transformer model on WMT.""" + return add_evolved_transformer_hparams(transformer.transformer_big()) + + +@registry.register_hparams +def evolved_transformer_deep(): + """Deep parameters for Evolved Transformer model on WMT.""" + hparams = add_evolved_transformer_hparams(transformer.transformer_big()) + hparams.num_encoder_layers = 9 + hparams.num_decoder_layers = 10 + hparams.hidden_size = 640 + return hparams + + +@registry.register_hparams +def evolved_transformer_base_tpu(): + """Base parameters for Evolved Transformer model on TPU.""" + hparams = add_evolved_transformer_hparams(transformer.transformer_tpu()) + hparams.learning_rate_constant = 1 / hparams.learning_rate_warmup_steps ** 0.5 + hparams.learning_rate_schedule = ( + "constant*single_cycle_cos_decay") + return hparams + + +@registry.register_hparams +def evolved_transformer_big_tpu(): + """Big parameters for Evolved Transformer model on TPU.""" + hparams = add_evolved_transformer_hparams(transformer.transformer_big_tpu()) + hparams.learning_rate_constant = 1 / hparams.learning_rate_warmup_steps ** 0.5 + hparams.learning_rate_schedule = ( + "constant*single_cycle_cos_decay") + return hparams + + +@registry.register_hparams +def 
evolved_transformer_tpu_basic(): + """Basic Seq2Seq TPU hyper-parameters.""" + hparams = transformer.transformer_big_tpu() + hparams.add_hparam("print_vars", False) + hparams.batch_size = 8192 + hparams.max_length = 256 + + # N < 0 means all weights in the model are trainable. + # N >= 0 means all weights are frozen except N top decoder layers + + # (pre-)softmax matrix (that projects from hidden size to vocab size). + hparams.add_hparam("num_trainable_top_decoder_layers", -1) + + return hparams diff --git a/trax/models/evolved_transformer_test.py b/trax/models/evolved_transformer_test.py new file mode 100644 index 000000000..388769918 --- /dev/null +++ b/trax/models/evolved_transformer_test.py @@ -0,0 +1,756 @@ +# coding=utf-8 +# Copyright 2023 The Tensor2Tensor Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Tests for the Evolved Transformer.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +import numpy as np +from tensor2tensor.data_generators import problem_hparams +from tensor2tensor.models import evolved_transformer +from tensor2tensor.models import transformer + +import tensorflow.compat.v1 as tf +from tensorflow.compat.v1 import estimator as tf_estimator + +BATCH_SIZE = 3 +INPUT_LENGTH = 5 +TARGET_LENGTH = 7 +VOCAB_SIZE = 10 +DECODE_LENGTH = 3 + + +def print_vars(all_vars=None): + """Print info about a list of variables.""" + if not all_vars: + all_vars = tf.trainable_variables() + tf.logging.info("Format: , , <(soft) device placement>") + for var in all_vars: + tf.logging.info(" %s, %s, %s" % + (var.name, str(var.get_shape()), var.op.device)) + + +def get_var(name): + """Get trainable variable by name.""" + variables = [var for var in tf.trainable_variables() if var.name == name] + if len(variables) == 1: + return variables[0] + raise ValueError("`name` must match exactly one variable. 
'%s' matched %d" % + (name, len(variables))) + + +def get_vars(names): + """Get trainable variables by name.""" + return [get_var(name) for name in names] + + +def assert_with_message(assert_method, a, b, message): + try: + assert_method(a, b) + except AssertionError as e: + tf.logging.error(message) + raise e + + +def get_model(hparams, has_input=True, num_decoder_layers=1): + hparams.layer_prepostprocess_dropout = 0.0 + hparams.hidden_size = 4 + hparams.num_heads = 1 + hparams.num_encoder_layers = 1 + hparams.num_decoder_layers = num_decoder_layers + + p_hparams = problem_hparams.test_problem_hparams(VOCAB_SIZE, VOCAB_SIZE, + hparams) + if not has_input: + del p_hparams.modality["inputs"] + hparams.problem_hparams = p_hparams + + inputs = np.random.randint(VOCAB_SIZE, size=(BATCH_SIZE, INPUT_LENGTH, 1, 1)) + targets = np.random.randint( + VOCAB_SIZE, size=(BATCH_SIZE, TARGET_LENGTH, 1, 1)) + features = { + "targets": tf.constant(targets, dtype=tf.int32, name="targets"), + "target_space_id": tf.constant(1, dtype=tf.int32), + } + if has_input: + features["inputs"] = tf.constant(inputs, dtype=tf.int32, name="inputs") + + return (evolved_transformer.EvolvedTransformer(hparams, + tf_estimator.ModeKeys.TRAIN, + p_hparams), features) + + +class EvolvedTransformerTest(tf.test.TestCase): + + def testEvolvedTransformer(self): + model, features = get_model(hparams=transformer.transformer_tiny()) + logits, _ = model(features) + with self.test_session() as session: + session.run(tf.global_variables_initializer()) + res = session.run(logits) + self.assertEqual(res.shape, (BATCH_SIZE, TARGET_LENGTH, 1, 1, VOCAB_SIZE)) + + def testSlowVsFast(self): + tf.set_random_seed(1234) + model, features = get_model(transformer.transformer_tiny()) + + decode_length = DECODE_LENGTH + + out_logits, _ = model(features) + out_logits = tf.squeeze(out_logits, axis=[2, 3]) + loss = tf.nn.sparse_softmax_cross_entropy_with_logits( + logits=tf.reshape(out_logits, [-1, VOCAB_SIZE]), + 
labels=tf.reshape(features["targets"], [-1])) + loss = tf.reduce_mean(loss) + apply_grad = tf.train.AdamOptimizer(0.001).minimize(loss) + + with self.test_session(): + tf.global_variables_initializer().run() + for _ in range(10): + apply_grad.run() + + model.set_mode(tf_estimator.ModeKeys.PREDICT) + + with tf.variable_scope(tf.get_variable_scope(), reuse=True): + greedy_result = model._slow_greedy_infer(features, + decode_length)["outputs"] + greedy_result = tf.squeeze(greedy_result, axis=[2, 3]) + + fast_result = model._greedy_infer(features, decode_length)["outputs"] + + with self.test_session(): + greedy_res = greedy_result.eval() + fast_res = fast_result.eval() + + self.assertEqual(fast_res.shape, (BATCH_SIZE, INPUT_LENGTH + decode_length)) + self.assertAllClose(greedy_res, fast_res) + + def testSlowVsFastNoInput(self): + model, features = get_model(transformer.transformer_tiny(), has_input=False) + + decode_length = DECODE_LENGTH + + out_logits, _ = model(features) + out_logits = tf.squeeze(out_logits, axis=[2, 3]) + loss = tf.nn.sparse_softmax_cross_entropy_with_logits( + logits=tf.reshape(out_logits, [-1, VOCAB_SIZE]), + labels=tf.reshape(features["targets"], [-1])) + loss = tf.reduce_mean(loss) + apply_grad = tf.train.AdamOptimizer(0.001).minimize(loss) + + with self.test_session(): + tf.global_variables_initializer().run() + for _ in range(10): + apply_grad.run() + + model.set_mode(tf_estimator.ModeKeys.PREDICT) + + with tf.variable_scope(tf.get_variable_scope(), reuse=True): + slow_result = model._slow_greedy_infer(features, decode_length)["outputs"] + slow_result = tf.squeeze(slow_result, axis=[2, 3]) + + fast_result = model._greedy_infer(features, decode_length)["outputs"] + + with self.test_session(): + slow_res = slow_result.eval() + fast_res = fast_result.eval() + + self.assertEqual(slow_res.shape, (BATCH_SIZE, decode_length)) + self.assertAllClose(slow_res, fast_res) + + def testBeamVsFast(self): + model, features = 
get_model(transformer.transformer_tiny()) + + decode_length = DECODE_LENGTH + + out_logits, _ = model(features) + out_logits = tf.squeeze(out_logits, axis=[2, 3]) + loss = tf.nn.sparse_softmax_cross_entropy_with_logits( + logits=tf.reshape(out_logits, [-1, VOCAB_SIZE]), + labels=tf.reshape(features["targets"], [-1])) + loss = tf.reduce_mean(loss) + apply_grad = tf.train.AdamOptimizer(0.001).minimize(loss) + + with self.test_session(): + tf.global_variables_initializer().run() + for _ in range(10): + apply_grad.run() + + model.set_mode(tf_estimator.ModeKeys.PREDICT) + + with tf.variable_scope(tf.get_variable_scope(), reuse=True): + beam_result = model._beam_decode_slow( + features, decode_length, beam_size=4, top_beams=1, + alpha=1.0)["outputs"] + + fast_result = model._beam_decode( + features, decode_length, beam_size=4, top_beams=1, + alpha=1.0)["outputs"] + + with self.test_session(): + beam_res = beam_result.eval() + fast_res = fast_result.eval() + + self.assertAllClose(beam_res, fast_res) + + def _create_greedy_infer_model(self): + """Creates model for greedy inference testing. + + Returns: + model: A t2t model. + features: An map of string to tensor. 
+ """ + model, features = get_model(transformer.transformer_tiny()) + + out_logits, _ = model(features) + out_logits = tf.squeeze(out_logits, axis=[2, 3]) + loss = tf.nn.sparse_softmax_cross_entropy_with_logits( + logits=tf.reshape(out_logits, [-1, VOCAB_SIZE]), + labels=tf.reshape(features["targets"], [-1])) + loss = tf.reduce_mean(loss) + apply_grad = tf.train.AdamOptimizer(0.001).minimize(loss) + + with self.test_session(): + tf.global_variables_initializer().run() + for _ in range(10): + apply_grad.run() + + model.set_mode(tf_estimator.ModeKeys.PREDICT) + + return model, features + + def testGreedySlowTPUVsNonTPU(self): + decode_length = DECODE_LENGTH + + model, features = self._create_greedy_infer_model() + + with tf.variable_scope(tf.get_variable_scope(), reuse=True): + slow_result_non_tpu = model._slow_greedy_infer(features, + decode_length)["outputs"] + slow_result_non_tpu = tf.squeeze(slow_result_non_tpu, axis=[2, 3]) + + slow_result_tpu = model._slow_greedy_infer_tpu(features, + decode_length)["outputs"] + slow_result_tpu = tf.squeeze(slow_result_tpu, axis=[2, 3]) + + with self.test_session(): + slow_non_tpu_res = slow_result_non_tpu.eval() + slow_tpu_res = slow_result_tpu.eval() + + self.assertEqual(slow_tpu_res.shape, + (BATCH_SIZE, INPUT_LENGTH + decode_length)) + self.assertAllClose(slow_tpu_res, slow_non_tpu_res) + + def testGreedyFastTPUVsNonTPU(self): + tf.set_random_seed(1234) + decode_length = DECODE_LENGTH + + model, features = self._create_greedy_infer_model() + + with tf.variable_scope(tf.get_variable_scope(), reuse=True): + fast_result_non_tpu = model._greedy_infer( + features, decode_length, use_tpu=False)["outputs"] + + fast_result_tpu = model._greedy_infer( + features, decode_length, use_tpu=True)["outputs"] + + with self.test_session(): + fast_non_tpu_res = fast_result_non_tpu.eval() + fast_tpu_res = fast_result_tpu.eval() + + self.assertEqual(fast_tpu_res.shape, + (BATCH_SIZE, INPUT_LENGTH + decode_length)) + 
self.assertAllClose(fast_tpu_res, fast_non_tpu_res) + + def testGreedyTPUSlowVsFast(self): + tf.set_random_seed(1234) + decode_length = DECODE_LENGTH + + model, features = self._create_greedy_infer_model() + + with tf.variable_scope(tf.get_variable_scope(), reuse=True): + slow_result = model._slow_greedy_infer_tpu(features, + decode_length)["outputs"] + slow_result = tf.squeeze(slow_result, axis=[2, 3]) + + fast_result = model._greedy_infer( + features, decode_length, use_tpu=True)["outputs"] + + with self.test_session(): + slow_res = slow_result.eval() + fast_res = fast_result.eval() + + self.assertEqual(fast_res.shape, (BATCH_SIZE, INPUT_LENGTH + decode_length)) + self.assertAllClose(fast_res, slow_res) + + def testFrozenWeightsUnchangedByTraining(self): + # Arrange. + hparams = transformer.transformer_tiny() + hparams.add_hparam("num_trainable_top_decoder_layers", 1) + model, features = get_model(hparams, num_decoder_layers=3) + out_logits, _ = model(features) + out_logits = tf.squeeze(out_logits, axis=[2, 3]) + loss = tf.nn.sparse_softmax_cross_entropy_with_logits( + logits=tf.reshape(out_logits, [-1, VOCAB_SIZE]), + labels=tf.reshape(features["targets"], [-1])) + loss = tf.reduce_mean(loss) + apply_grad = tf.train.AdamOptimizer(0.001).minimize(loss) + frozen_names = [ + "evolved_transformer/symbol_modality_10_4/shared/weights_0:0", + "evolved_transformer/symbol_modality_10_4/shared/weights_1:0", + "evolved_transformer/symbol_modality_10_4/shared/weights_2:0", + "evolved_transformer/symbol_modality_10_4/shared/weights_3:0", + "evolved_transformer/symbol_modality_10_4/shared/weights_4:0", + "evolved_transformer/symbol_modality_10_4/shared/weights_5:0", + "evolved_transformer/symbol_modality_10_4/shared/weights_6:0", + "evolved_transformer/symbol_modality_10_4/shared/weights_7:0", + "evolved_transformer/symbol_modality_10_4/shared/weights_8:0", + "evolved_transformer/symbol_modality_10_4/shared/weights_9:0", + 
"evolved_transformer/body/target_space_embedding/kernel:0", + "evolved_transformer/body/encoder/layer_0/gated_linear_unit/layer_prepostprocess/layer_norm/layer_norm_scale:0", + "evolved_transformer/body/encoder/layer_0/gated_linear_unit/layer_prepostprocess/layer_norm/layer_norm_bias:0", + "evolved_transformer/body/encoder/layer_0/gated_linear_unit/dense/kernel:0", + "evolved_transformer/body/encoder/layer_0/gated_linear_unit/dense/bias:0", + "evolved_transformer/body/encoder/layer_0/gated_linear_unit/dense_1/kernel:0", + "evolved_transformer/body/encoder/layer_0/gated_linear_unit/dense_1/bias:0", + "evolved_transformer/body/encoder/layer_0/conv_branches/layer_prepostprocess/layer_norm/layer_norm_scale:0", + "evolved_transformer/body/encoder/layer_0/conv_branches/layer_prepostprocess/layer_norm/layer_norm_bias:0", + "evolved_transformer/body/encoder/layer_0/conv_branches/dense/kernel:0", + "evolved_transformer/body/encoder/layer_0/conv_branches/dense/bias:0", + "evolved_transformer/body/encoder/layer_0/conv_branches/standard_conv_3x1/kernel:0", + "evolved_transformer/body/encoder/layer_0/conv_branches/standard_conv_3x1/bias:0", + "evolved_transformer/body/encoder/layer_0/conv_branches/layer_prepostprocess_1/layer_norm/layer_norm_scale:0", + "evolved_transformer/body/encoder/layer_0/conv_branches/layer_prepostprocess_1/layer_norm/layer_norm_bias:0", + "evolved_transformer/body/encoder/layer_0/conv_branches/separable_conv_9x1/depthwise_kernel:0", + "evolved_transformer/body/encoder/layer_0/conv_branches/separable_conv_9x1/pointwise_kernel:0", + "evolved_transformer/body/encoder/layer_0/conv_branches/separable_conv_9x1/bias:0", + "evolved_transformer/body/encoder/layer_0/self_attention/layer_prepostprocess/layer_norm/layer_norm_scale:0", + "evolved_transformer/body/encoder/layer_0/self_attention/layer_prepostprocess/layer_norm/layer_norm_bias:0", + "evolved_transformer/body/encoder/layer_0/self_attention/multihead_attention/q/kernel:0", + 
"evolved_transformer/body/encoder/layer_0/self_attention/multihead_attention/k/kernel:0", + "evolved_transformer/body/encoder/layer_0/self_attention/multihead_attention/v/kernel:0", + "evolved_transformer/body/encoder/layer_0/self_attention/multihead_attention/output_transform/kernel:0", + "evolved_transformer/body/encoder/layer_0/dense_layers/layer_prepostprocess/layer_norm/layer_norm_scale:0", + "evolved_transformer/body/encoder/layer_0/dense_layers/layer_prepostprocess/layer_norm/layer_norm_bias:0", + "evolved_transformer/body/encoder/layer_0/dense_layers/dense/kernel:0", + "evolved_transformer/body/encoder/layer_0/dense_layers/dense/bias:0", + "evolved_transformer/body/encoder/layer_0/dense_layers/dense_1/kernel:0", + "evolved_transformer/body/encoder/layer_0/dense_layers/dense_1/bias:0", + "evolved_transformer/body/encoder/layer_prepostprocess/layer_norm/layer_norm_scale:0", + "evolved_transformer/body/encoder/layer_prepostprocess/layer_norm/layer_norm_bias:0", + "evolved_transformer/body/decoder/layer_0/16_head_self_attention/layer_prepostprocess/layer_norm/layer_norm_scale:0", + "evolved_transformer/body/decoder/layer_0/16_head_self_attention/layer_prepostprocess/layer_norm/layer_norm_bias:0", + "evolved_transformer/body/decoder/layer_0/16_head_self_attention/multihead_attention/q/kernel:0", + "evolved_transformer/body/decoder/layer_0/16_head_self_attention/multihead_attention/k/kernel:0", + "evolved_transformer/body/decoder/layer_0/16_head_self_attention/multihead_attention/v/kernel:0", + "evolved_transformer/body/decoder/layer_0/16_head_self_attention/multihead_attention/output_transform/kernel:0", + "evolved_transformer/body/decoder/layer_0/first_attend_to_encoder/multihead_attention/q/kernel:0", + "evolved_transformer/body/decoder/layer_0/first_attend_to_encoder/multihead_attention/k/kernel:0", + "evolved_transformer/body/decoder/layer_0/first_attend_to_encoder/multihead_attention/v/kernel:0", + 
"evolved_transformer/body/decoder/layer_0/first_attend_to_encoder/multihead_attention/output_transform/kernel:0", + "evolved_transformer/body/decoder/layer_0/conv_branches/layer_prepostprocess/layer_norm/layer_norm_scale:0", + "evolved_transformer/body/decoder/layer_0/conv_branches/layer_prepostprocess/layer_norm/layer_norm_bias:0", + "evolved_transformer/body/decoder/layer_0/conv_branches/separable_conv11x1/depthwise_kernel:0", + "evolved_transformer/body/decoder/layer_0/conv_branches/separable_conv11x1/pointwise_kernel:0", + "evolved_transformer/body/decoder/layer_0/conv_branches/separable_conv11x1/bias:0", + "evolved_transformer/body/decoder/layer_0/conv_branches/separable_conv_7x1_1/depthwise_kernel:0", + "evolved_transformer/body/decoder/layer_0/conv_branches/separable_conv_7x1_1/pointwise_kernel:0", + "evolved_transformer/body/decoder/layer_0/conv_branches/separable_conv_7x1_1/bias:0", + "evolved_transformer/body/decoder/layer_0/conv_branches/layer_prepostprocess_1/layer_norm/layer_norm_scale:0", + "evolved_transformer/body/decoder/layer_0/conv_branches/layer_prepostprocess_1/layer_norm/layer_norm_bias:0", + "evolved_transformer/body/decoder/layer_0/conv_branches/separable_conv_7x1_2/depthwise_kernel:0", + "evolved_transformer/body/decoder/layer_0/conv_branches/separable_conv_7x1_2/pointwise_kernel:0", + "evolved_transformer/body/decoder/layer_0/conv_branches/separable_conv_7x1_2/bias:0", + "evolved_transformer/body/decoder/layer_0/self_attention/layer_prepostprocess/layer_norm/layer_norm_scale:0", + "evolved_transformer/body/decoder/layer_0/self_attention/layer_prepostprocess/layer_norm/layer_norm_bias:0", + "evolved_transformer/body/decoder/layer_0/self_attention/multihead_attention/q/kernel:0", + "evolved_transformer/body/decoder/layer_0/self_attention/multihead_attention/k/kernel:0", + "evolved_transformer/body/decoder/layer_0/self_attention/multihead_attention/v/kernel:0", + 
"evolved_transformer/body/decoder/layer_0/self_attention/multihead_attention/output_transform/kernel:0", + "evolved_transformer/body/decoder/layer_0/second_attend_to_encoder/layer_prepostprocess/layer_norm/layer_norm_scale:0", + "evolved_transformer/body/decoder/layer_0/second_attend_to_encoder/layer_prepostprocess/layer_norm/layer_norm_bias:0", + "evolved_transformer/body/decoder/layer_0/second_attend_to_encoder/multihead_attention/q/kernel:0", + "evolved_transformer/body/decoder/layer_0/second_attend_to_encoder/multihead_attention/k/kernel:0", + "evolved_transformer/body/decoder/layer_0/second_attend_to_encoder/multihead_attention/v/kernel:0", + "evolved_transformer/body/decoder/layer_0/second_attend_to_encoder/multihead_attention/output_transform/kernel:0", + "evolved_transformer/body/decoder/layer_0/dense_layers/layer_prepostprocess/layer_norm/layer_norm_scale:0", + "evolved_transformer/body/decoder/layer_0/dense_layers/layer_prepostprocess/layer_norm/layer_norm_bias:0", + "evolved_transformer/body/decoder/layer_0/dense_layers/dense/kernel:0", + "evolved_transformer/body/decoder/layer_0/dense_layers/dense/bias:0", + "evolved_transformer/body/decoder/layer_0/dense_layers/layer_prepostprocess_1/layer_norm/layer_norm_scale:0", + "evolved_transformer/body/decoder/layer_0/dense_layers/layer_prepostprocess_1/layer_norm/layer_norm_bias:0", + "evolved_transformer/body/decoder/layer_0/dense_layers/dense_1/kernel:0", + "evolved_transformer/body/decoder/layer_0/dense_layers/dense_1/bias:0", + "evolved_transformer/body/decoder/layer_1/16_head_self_attention/layer_prepostprocess/layer_norm/layer_norm_scale:0", + "evolved_transformer/body/decoder/layer_1/16_head_self_attention/layer_prepostprocess/layer_norm/layer_norm_bias:0", + "evolved_transformer/body/decoder/layer_1/16_head_self_attention/multihead_attention/q/kernel:0", + "evolved_transformer/body/decoder/layer_1/16_head_self_attention/multihead_attention/k/kernel:0", + 
"evolved_transformer/body/decoder/layer_1/16_head_self_attention/multihead_attention/v/kernel:0", + "evolved_transformer/body/decoder/layer_1/16_head_self_attention/multihead_attention/output_transform/kernel:0", + "evolved_transformer/body/decoder/layer_1/first_attend_to_encoder/multihead_attention/q/kernel:0", + "evolved_transformer/body/decoder/layer_1/first_attend_to_encoder/multihead_attention/k/kernel:0", + "evolved_transformer/body/decoder/layer_1/first_attend_to_encoder/multihead_attention/v/kernel:0", + "evolved_transformer/body/decoder/layer_1/first_attend_to_encoder/multihead_attention/output_transform/kernel:0", + "evolved_transformer/body/decoder/layer_1/conv_branches/layer_prepostprocess/layer_norm/layer_norm_scale:0", + "evolved_transformer/body/decoder/layer_1/conv_branches/layer_prepostprocess/layer_norm/layer_norm_bias:0", + "evolved_transformer/body/decoder/layer_1/conv_branches/separable_conv11x1/depthwise_kernel:0", + "evolved_transformer/body/decoder/layer_1/conv_branches/separable_conv11x1/pointwise_kernel:0", + "evolved_transformer/body/decoder/layer_1/conv_branches/separable_conv11x1/bias:0", + "evolved_transformer/body/decoder/layer_1/conv_branches/separable_conv_7x1_1/depthwise_kernel:0", + "evolved_transformer/body/decoder/layer_1/conv_branches/separable_conv_7x1_1/pointwise_kernel:0", + "evolved_transformer/body/decoder/layer_1/conv_branches/separable_conv_7x1_1/bias:0", + "evolved_transformer/body/decoder/layer_1/conv_branches/layer_prepostprocess_1/layer_norm/layer_norm_scale:0", + "evolved_transformer/body/decoder/layer_1/conv_branches/layer_prepostprocess_1/layer_norm/layer_norm_bias:0", + "evolved_transformer/body/decoder/layer_1/conv_branches/separable_conv_7x1_2/depthwise_kernel:0", + "evolved_transformer/body/decoder/layer_1/conv_branches/separable_conv_7x1_2/pointwise_kernel:0", + "evolved_transformer/body/decoder/layer_1/conv_branches/separable_conv_7x1_2/bias:0", + 
"evolved_transformer/body/decoder/layer_1/self_attention/layer_prepostprocess/layer_norm/layer_norm_scale:0", + "evolved_transformer/body/decoder/layer_1/self_attention/layer_prepostprocess/layer_norm/layer_norm_bias:0", + "evolved_transformer/body/decoder/layer_1/self_attention/multihead_attention/q/kernel:0", + "evolved_transformer/body/decoder/layer_1/self_attention/multihead_attention/k/kernel:0", + "evolved_transformer/body/decoder/layer_1/self_attention/multihead_attention/v/kernel:0", + "evolved_transformer/body/decoder/layer_1/self_attention/multihead_attention/output_transform/kernel:0", + "evolved_transformer/body/decoder/layer_1/second_attend_to_encoder/layer_prepostprocess/layer_norm/layer_norm_scale:0", + "evolved_transformer/body/decoder/layer_1/second_attend_to_encoder/layer_prepostprocess/layer_norm/layer_norm_bias:0", + "evolved_transformer/body/decoder/layer_1/second_attend_to_encoder/multihead_attention/q/kernel:0", + "evolved_transformer/body/decoder/layer_1/second_attend_to_encoder/multihead_attention/k/kernel:0", + "evolved_transformer/body/decoder/layer_1/second_attend_to_encoder/multihead_attention/v/kernel:0", + "evolved_transformer/body/decoder/layer_1/second_attend_to_encoder/multihead_attention/output_transform/kernel:0", + "evolved_transformer/body/decoder/layer_1/dense_layers/layer_prepostprocess/layer_norm/layer_norm_scale:0", + "evolved_transformer/body/decoder/layer_1/dense_layers/layer_prepostprocess/layer_norm/layer_norm_bias:0", + "evolved_transformer/body/decoder/layer_1/dense_layers/dense/kernel:0", + "evolved_transformer/body/decoder/layer_1/dense_layers/dense/bias:0", + "evolved_transformer/body/decoder/layer_1/dense_layers/layer_prepostprocess_1/layer_norm/layer_norm_scale:0", + "evolved_transformer/body/decoder/layer_1/dense_layers/layer_prepostprocess_1/layer_norm/layer_norm_bias:0", + "evolved_transformer/body/decoder/layer_1/dense_layers/dense_1/kernel:0", + 
"evolved_transformer/body/decoder/layer_1/dense_layers/dense_1/bias:0", + ] + train_names = [ + "evolved_transformer/body/decoder/layer_2/16_head_self_attention/layer_prepostprocess/layer_norm/layer_norm_scale:0", + "evolved_transformer/body/decoder/layer_2/16_head_self_attention/layer_prepostprocess/layer_norm/layer_norm_bias:0", + "evolved_transformer/body/decoder/layer_2/16_head_self_attention/multihead_attention/q/kernel:0", + "evolved_transformer/body/decoder/layer_2/16_head_self_attention/multihead_attention/k/kernel:0", + "evolved_transformer/body/decoder/layer_2/16_head_self_attention/multihead_attention/v/kernel:0", + "evolved_transformer/body/decoder/layer_2/16_head_self_attention/multihead_attention/output_transform/kernel:0", + "evolved_transformer/body/decoder/layer_2/first_attend_to_encoder/multihead_attention/q/kernel:0", + "evolved_transformer/body/decoder/layer_2/first_attend_to_encoder/multihead_attention/k/kernel:0", + "evolved_transformer/body/decoder/layer_2/first_attend_to_encoder/multihead_attention/v/kernel:0", + "evolved_transformer/body/decoder/layer_2/first_attend_to_encoder/multihead_attention/output_transform/kernel:0", + "evolved_transformer/body/decoder/layer_2/conv_branches/layer_prepostprocess/layer_norm/layer_norm_scale:0", + "evolved_transformer/body/decoder/layer_2/conv_branches/layer_prepostprocess/layer_norm/layer_norm_bias:0", + "evolved_transformer/body/decoder/layer_2/conv_branches/separable_conv11x1/depthwise_kernel:0", + "evolved_transformer/body/decoder/layer_2/conv_branches/separable_conv11x1/pointwise_kernel:0", + "evolved_transformer/body/decoder/layer_2/conv_branches/separable_conv11x1/bias:0", + "evolved_transformer/body/decoder/layer_2/conv_branches/separable_conv_7x1_1/depthwise_kernel:0", + "evolved_transformer/body/decoder/layer_2/conv_branches/separable_conv_7x1_1/pointwise_kernel:0", + "evolved_transformer/body/decoder/layer_2/conv_branches/separable_conv_7x1_1/bias:0", + 
"evolved_transformer/body/decoder/layer_2/conv_branches/layer_prepostprocess_1/layer_norm/layer_norm_scale:0", + "evolved_transformer/body/decoder/layer_2/conv_branches/layer_prepostprocess_1/layer_norm/layer_norm_bias:0", + "evolved_transformer/body/decoder/layer_2/conv_branches/separable_conv_7x1_2/depthwise_kernel:0", + "evolved_transformer/body/decoder/layer_2/conv_branches/separable_conv_7x1_2/pointwise_kernel:0", + "evolved_transformer/body/decoder/layer_2/conv_branches/separable_conv_7x1_2/bias:0", + "evolved_transformer/body/decoder/layer_2/self_attention/layer_prepostprocess/layer_norm/layer_norm_scale:0", + "evolved_transformer/body/decoder/layer_2/self_attention/layer_prepostprocess/layer_norm/layer_norm_bias:0", + "evolved_transformer/body/decoder/layer_2/self_attention/multihead_attention/q/kernel:0", + "evolved_transformer/body/decoder/layer_2/self_attention/multihead_attention/k/kernel:0", + "evolved_transformer/body/decoder/layer_2/self_attention/multihead_attention/v/kernel:0", + "evolved_transformer/body/decoder/layer_2/self_attention/multihead_attention/output_transform/kernel:0", + "evolved_transformer/body/decoder/layer_2/second_attend_to_encoder/layer_prepostprocess/layer_norm/layer_norm_scale:0", + "evolved_transformer/body/decoder/layer_2/second_attend_to_encoder/layer_prepostprocess/layer_norm/layer_norm_bias:0", + "evolved_transformer/body/decoder/layer_2/second_attend_to_encoder/multihead_attention/q/kernel:0", + "evolved_transformer/body/decoder/layer_2/second_attend_to_encoder/multihead_attention/k/kernel:0", + "evolved_transformer/body/decoder/layer_2/second_attend_to_encoder/multihead_attention/v/kernel:0", + "evolved_transformer/body/decoder/layer_2/second_attend_to_encoder/multihead_attention/output_transform/kernel:0", + "evolved_transformer/body/decoder/layer_2/dense_layers/layer_prepostprocess/layer_norm/layer_norm_scale:0", + 
"evolved_transformer/body/decoder/layer_2/dense_layers/layer_prepostprocess/layer_norm/layer_norm_bias:0", + "evolved_transformer/body/decoder/layer_2/dense_layers/dense/kernel:0", + "evolved_transformer/body/decoder/layer_2/dense_layers/dense/bias:0", + "evolved_transformer/body/decoder/layer_2/dense_layers/layer_prepostprocess_1/layer_norm/layer_norm_scale:0", + "evolved_transformer/body/decoder/layer_2/dense_layers/layer_prepostprocess_1/layer_norm/layer_norm_bias:0", + "evolved_transformer/body/decoder/layer_2/dense_layers/dense_1/kernel:0", + "evolved_transformer/body/decoder/layer_2/dense_layers/dense_1/bias:0", + "evolved_transformer/body/decoder/layer_prepostprocess/layer_norm/layer_norm_scale:0", + "evolved_transformer/body/decoder/layer_prepostprocess/layer_norm/layer_norm_bias:0", + "evolved_transformer/symbol_modality_10_4/softmax/weights_1:0", + "evolved_transformer/symbol_modality_10_4/softmax/weights_2:0", + "evolved_transformer/symbol_modality_10_4/softmax/weights_3:0", + "evolved_transformer/symbol_modality_10_4/softmax/weights_4:0", + "evolved_transformer/symbol_modality_10_4/softmax/weights_5:0", + "evolved_transformer/symbol_modality_10_4/softmax/weights_6:0", + "evolved_transformer/symbol_modality_10_4/softmax/weights_7:0", + "evolved_transformer/symbol_modality_10_4/softmax/weights_8:0", + "evolved_transformer/symbol_modality_10_4/softmax/weights_9:0", + ] + frozen_vars = get_vars(frozen_names) + train_vars = get_vars(train_names) + print_vars() + + # Act. + with self.test_session() as session: + tf.global_variables_initializer().run() + frozen_values_before = session.run(frozen_vars) + train_values_before = session.run(train_vars) + for _ in range(10): # Arbitrary number of training steps. + apply_grad.run() + frozen_values_after = session.run(frozen_vars) + train_values_after = session.run(train_vars) + + # Assert. 
+ self.assertTrue( + model._original_hparams.shared_embedding_and_softmax_weights) + self.assertFalse(model.hparams.shared_embedding_and_softmax_weights) + self.assertTrue(model.hparams.shared_embedding) + for name, before, after in zip(frozen_names, frozen_values_before, + frozen_values_after): + assert_with_message( + self.assertAllClose, before, after, + "%s should be frozen, but changed after training." % name) + for name, before, after in zip(train_names, train_values_before, + train_values_after): + assert_with_message( + self.assertNotAllClose, before, after, + "%s should be trainable, but did not change after training." % name) + + def testAllWeightsTrainableByDefault(self): + # Arrange. + model, features = get_model( + transformer.transformer_tiny(), num_decoder_layers=3) + out_logits, _ = model(features) + out_logits = tf.squeeze(out_logits, axis=[2, 3]) + loss = tf.nn.sparse_softmax_cross_entropy_with_logits( + logits=tf.reshape(out_logits, [-1, VOCAB_SIZE]), + labels=tf.reshape(features["targets"], [-1])) + loss = tf.reduce_mean(loss) + apply_grad = tf.train.AdamOptimizer(0.001).minimize(loss) + var_names = [ + "evolved_transformer/symbol_modality_10_4/shared/weights_0:0", + "evolved_transformer/symbol_modality_10_4/shared/weights_1:0", + "evolved_transformer/symbol_modality_10_4/shared/weights_2:0", + "evolved_transformer/symbol_modality_10_4/shared/weights_3:0", + "evolved_transformer/symbol_modality_10_4/shared/weights_4:0", + "evolved_transformer/symbol_modality_10_4/shared/weights_5:0", + "evolved_transformer/symbol_modality_10_4/shared/weights_6:0", + "evolved_transformer/symbol_modality_10_4/shared/weights_7:0", + "evolved_transformer/symbol_modality_10_4/shared/weights_8:0", + "evolved_transformer/symbol_modality_10_4/shared/weights_9:0", + "evolved_transformer/symbol_modality_10_4/shared/weights_10:0", + "evolved_transformer/symbol_modality_10_4/shared/weights_11:0", + "evolved_transformer/symbol_modality_10_4/shared/weights_12:0", + 
"evolved_transformer/symbol_modality_10_4/shared/weights_13:0", + "evolved_transformer/symbol_modality_10_4/shared/weights_14:0", + "evolved_transformer/symbol_modality_10_4/shared/weights_15:0", + "evolved_transformer/body/target_space_embedding/kernel:0", + "evolved_transformer/body/encoder/layer_0/gated_linear_unit/layer_prepostprocess/layer_norm/layer_norm_scale:0", + "evolved_transformer/body/encoder/layer_0/gated_linear_unit/layer_prepostprocess/layer_norm/layer_norm_bias:0", + "evolved_transformer/body/encoder/layer_0/gated_linear_unit/dense/kernel:0", + "evolved_transformer/body/encoder/layer_0/gated_linear_unit/dense/bias:0", + "evolved_transformer/body/encoder/layer_0/gated_linear_unit/dense_1/kernel:0", + "evolved_transformer/body/encoder/layer_0/gated_linear_unit/dense_1/bias:0", + "evolved_transformer/body/encoder/layer_0/conv_branches/layer_prepostprocess/layer_norm/layer_norm_scale:0", + "evolved_transformer/body/encoder/layer_0/conv_branches/layer_prepostprocess/layer_norm/layer_norm_bias:0", + "evolved_transformer/body/encoder/layer_0/conv_branches/dense/kernel:0", + "evolved_transformer/body/encoder/layer_0/conv_branches/dense/bias:0", + "evolved_transformer/body/encoder/layer_0/conv_branches/standard_conv_3x1/kernel:0", + "evolved_transformer/body/encoder/layer_0/conv_branches/standard_conv_3x1/bias:0", + "evolved_transformer/body/encoder/layer_0/conv_branches/layer_prepostprocess_1/layer_norm/layer_norm_scale:0", + "evolved_transformer/body/encoder/layer_0/conv_branches/layer_prepostprocess_1/layer_norm/layer_norm_bias:0", + "evolved_transformer/body/encoder/layer_0/conv_branches/separable_conv_9x1/depthwise_kernel:0", + "evolved_transformer/body/encoder/layer_0/conv_branches/separable_conv_9x1/pointwise_kernel:0", + "evolved_transformer/body/encoder/layer_0/conv_branches/separable_conv_9x1/bias:0", + "evolved_transformer/body/encoder/layer_0/self_attention/layer_prepostprocess/layer_norm/layer_norm_scale:0", + 
"evolved_transformer/body/encoder/layer_0/self_attention/layer_prepostprocess/layer_norm/layer_norm_bias:0", + "evolved_transformer/body/encoder/layer_0/self_attention/multihead_attention/q/kernel:0", + "evolved_transformer/body/encoder/layer_0/self_attention/multihead_attention/k/kernel:0", + "evolved_transformer/body/encoder/layer_0/self_attention/multihead_attention/v/kernel:0", + "evolved_transformer/body/encoder/layer_0/self_attention/multihead_attention/output_transform/kernel:0", + "evolved_transformer/body/encoder/layer_0/dense_layers/layer_prepostprocess/layer_norm/layer_norm_scale:0", + "evolved_transformer/body/encoder/layer_0/dense_layers/layer_prepostprocess/layer_norm/layer_norm_bias:0", + "evolved_transformer/body/encoder/layer_0/dense_layers/dense/kernel:0", + "evolved_transformer/body/encoder/layer_0/dense_layers/dense/bias:0", + "evolved_transformer/body/encoder/layer_0/dense_layers/dense_1/kernel:0", + "evolved_transformer/body/encoder/layer_0/dense_layers/dense_1/bias:0", + "evolved_transformer/body/encoder/layer_prepostprocess/layer_norm/layer_norm_scale:0", + "evolved_transformer/body/encoder/layer_prepostprocess/layer_norm/layer_norm_bias:0", + "evolved_transformer/body/decoder/layer_0/16_head_self_attention/layer_prepostprocess/layer_norm/layer_norm_scale:0", + "evolved_transformer/body/decoder/layer_0/16_head_self_attention/layer_prepostprocess/layer_norm/layer_norm_bias:0", + "evolved_transformer/body/decoder/layer_0/16_head_self_attention/multihead_attention/q/kernel:0", + "evolved_transformer/body/decoder/layer_0/16_head_self_attention/multihead_attention/k/kernel:0", + "evolved_transformer/body/decoder/layer_0/16_head_self_attention/multihead_attention/v/kernel:0", + "evolved_transformer/body/decoder/layer_0/16_head_self_attention/multihead_attention/output_transform/kernel:0", + "evolved_transformer/body/decoder/layer_0/first_attend_to_encoder/multihead_attention/q/kernel:0", + 
"evolved_transformer/body/decoder/layer_0/first_attend_to_encoder/multihead_attention/k/kernel:0", + "evolved_transformer/body/decoder/layer_0/first_attend_to_encoder/multihead_attention/v/kernel:0", + "evolved_transformer/body/decoder/layer_0/first_attend_to_encoder/multihead_attention/output_transform/kernel:0", + "evolved_transformer/body/decoder/layer_0/conv_branches/layer_prepostprocess/layer_norm/layer_norm_scale:0", + "evolved_transformer/body/decoder/layer_0/conv_branches/layer_prepostprocess/layer_norm/layer_norm_bias:0", + "evolved_transformer/body/decoder/layer_0/conv_branches/separable_conv11x1/depthwise_kernel:0", + "evolved_transformer/body/decoder/layer_0/conv_branches/separable_conv11x1/pointwise_kernel:0", + "evolved_transformer/body/decoder/layer_0/conv_branches/separable_conv11x1/bias:0", + "evolved_transformer/body/decoder/layer_0/conv_branches/separable_conv_7x1_1/depthwise_kernel:0", + "evolved_transformer/body/decoder/layer_0/conv_branches/separable_conv_7x1_1/pointwise_kernel:0", + "evolved_transformer/body/decoder/layer_0/conv_branches/separable_conv_7x1_1/bias:0", + "evolved_transformer/body/decoder/layer_0/conv_branches/layer_prepostprocess_1/layer_norm/layer_norm_scale:0", + "evolved_transformer/body/decoder/layer_0/conv_branches/layer_prepostprocess_1/layer_norm/layer_norm_bias:0", + "evolved_transformer/body/decoder/layer_0/conv_branches/separable_conv_7x1_2/depthwise_kernel:0", + "evolved_transformer/body/decoder/layer_0/conv_branches/separable_conv_7x1_2/pointwise_kernel:0", + "evolved_transformer/body/decoder/layer_0/conv_branches/separable_conv_7x1_2/bias:0", + "evolved_transformer/body/decoder/layer_0/self_attention/layer_prepostprocess/layer_norm/layer_norm_scale:0", + "evolved_transformer/body/decoder/layer_0/self_attention/layer_prepostprocess/layer_norm/layer_norm_bias:0", + "evolved_transformer/body/decoder/layer_0/self_attention/multihead_attention/q/kernel:0", + 
"evolved_transformer/body/decoder/layer_0/self_attention/multihead_attention/k/kernel:0", + "evolved_transformer/body/decoder/layer_0/self_attention/multihead_attention/v/kernel:0", + "evolved_transformer/body/decoder/layer_0/self_attention/multihead_attention/output_transform/kernel:0", + "evolved_transformer/body/decoder/layer_0/second_attend_to_encoder/layer_prepostprocess/layer_norm/layer_norm_scale:0", + "evolved_transformer/body/decoder/layer_0/second_attend_to_encoder/layer_prepostprocess/layer_norm/layer_norm_bias:0", + "evolved_transformer/body/decoder/layer_0/second_attend_to_encoder/multihead_attention/q/kernel:0", + "evolved_transformer/body/decoder/layer_0/second_attend_to_encoder/multihead_attention/k/kernel:0", + "evolved_transformer/body/decoder/layer_0/second_attend_to_encoder/multihead_attention/v/kernel:0", + "evolved_transformer/body/decoder/layer_0/second_attend_to_encoder/multihead_attention/output_transform/kernel:0", + "evolved_transformer/body/decoder/layer_0/dense_layers/layer_prepostprocess/layer_norm/layer_norm_scale:0", + "evolved_transformer/body/decoder/layer_0/dense_layers/layer_prepostprocess/layer_norm/layer_norm_bias:0", + "evolved_transformer/body/decoder/layer_0/dense_layers/dense/kernel:0", + "evolved_transformer/body/decoder/layer_0/dense_layers/dense/bias:0", + "evolved_transformer/body/decoder/layer_0/dense_layers/layer_prepostprocess_1/layer_norm/layer_norm_scale:0", + "evolved_transformer/body/decoder/layer_0/dense_layers/layer_prepostprocess_1/layer_norm/layer_norm_bias:0", + "evolved_transformer/body/decoder/layer_0/dense_layers/dense_1/kernel:0", + "evolved_transformer/body/decoder/layer_0/dense_layers/dense_1/bias:0", + "evolved_transformer/body/decoder/layer_1/16_head_self_attention/layer_prepostprocess/layer_norm/layer_norm_scale:0", + "evolved_transformer/body/decoder/layer_1/16_head_self_attention/layer_prepostprocess/layer_norm/layer_norm_bias:0", + 
"evolved_transformer/body/decoder/layer_1/16_head_self_attention/multihead_attention/q/kernel:0", + "evolved_transformer/body/decoder/layer_1/16_head_self_attention/multihead_attention/k/kernel:0", + "evolved_transformer/body/decoder/layer_1/16_head_self_attention/multihead_attention/v/kernel:0", + "evolved_transformer/body/decoder/layer_1/16_head_self_attention/multihead_attention/output_transform/kernel:0", + "evolved_transformer/body/decoder/layer_1/first_attend_to_encoder/multihead_attention/q/kernel:0", + "evolved_transformer/body/decoder/layer_1/first_attend_to_encoder/multihead_attention/k/kernel:0", + "evolved_transformer/body/decoder/layer_1/first_attend_to_encoder/multihead_attention/v/kernel:0", + "evolved_transformer/body/decoder/layer_1/first_attend_to_encoder/multihead_attention/output_transform/kernel:0", + "evolved_transformer/body/decoder/layer_1/conv_branches/layer_prepostprocess/layer_norm/layer_norm_scale:0", + "evolved_transformer/body/decoder/layer_1/conv_branches/layer_prepostprocess/layer_norm/layer_norm_bias:0", + "evolved_transformer/body/decoder/layer_1/conv_branches/separable_conv11x1/depthwise_kernel:0", + "evolved_transformer/body/decoder/layer_1/conv_branches/separable_conv11x1/pointwise_kernel:0", + "evolved_transformer/body/decoder/layer_1/conv_branches/separable_conv11x1/bias:0", + "evolved_transformer/body/decoder/layer_1/conv_branches/separable_conv_7x1_1/depthwise_kernel:0", + "evolved_transformer/body/decoder/layer_1/conv_branches/separable_conv_7x1_1/pointwise_kernel:0", + "evolved_transformer/body/decoder/layer_1/conv_branches/separable_conv_7x1_1/bias:0", + "evolved_transformer/body/decoder/layer_1/conv_branches/layer_prepostprocess_1/layer_norm/layer_norm_scale:0", + "evolved_transformer/body/decoder/layer_1/conv_branches/layer_prepostprocess_1/layer_norm/layer_norm_bias:0", + "evolved_transformer/body/decoder/layer_1/conv_branches/separable_conv_7x1_2/depthwise_kernel:0", + 
"evolved_transformer/body/decoder/layer_1/conv_branches/separable_conv_7x1_2/pointwise_kernel:0", + "evolved_transformer/body/decoder/layer_1/conv_branches/separable_conv_7x1_2/bias:0", + "evolved_transformer/body/decoder/layer_1/self_attention/layer_prepostprocess/layer_norm/layer_norm_scale:0", + "evolved_transformer/body/decoder/layer_1/self_attention/layer_prepostprocess/layer_norm/layer_norm_bias:0", + "evolved_transformer/body/decoder/layer_1/self_attention/multihead_attention/q/kernel:0", + "evolved_transformer/body/decoder/layer_1/self_attention/multihead_attention/k/kernel:0", + "evolved_transformer/body/decoder/layer_1/self_attention/multihead_attention/v/kernel:0", + "evolved_transformer/body/decoder/layer_1/self_attention/multihead_attention/output_transform/kernel:0", + "evolved_transformer/body/decoder/layer_1/second_attend_to_encoder/layer_prepostprocess/layer_norm/layer_norm_scale:0", + "evolved_transformer/body/decoder/layer_1/second_attend_to_encoder/layer_prepostprocess/layer_norm/layer_norm_bias:0", + "evolved_transformer/body/decoder/layer_1/second_attend_to_encoder/multihead_attention/q/kernel:0", + "evolved_transformer/body/decoder/layer_1/second_attend_to_encoder/multihead_attention/k/kernel:0", + "evolved_transformer/body/decoder/layer_1/second_attend_to_encoder/multihead_attention/v/kernel:0", + "evolved_transformer/body/decoder/layer_1/second_attend_to_encoder/multihead_attention/output_transform/kernel:0", + "evolved_transformer/body/decoder/layer_1/dense_layers/layer_prepostprocess/layer_norm/layer_norm_scale:0", + "evolved_transformer/body/decoder/layer_1/dense_layers/layer_prepostprocess/layer_norm/layer_norm_bias:0", + "evolved_transformer/body/decoder/layer_1/dense_layers/dense/kernel:0", + "evolved_transformer/body/decoder/layer_1/dense_layers/dense/bias:0", + "evolved_transformer/body/decoder/layer_1/dense_layers/layer_prepostprocess_1/layer_norm/layer_norm_scale:0", + 
"evolved_transformer/body/decoder/layer_1/dense_layers/layer_prepostprocess_1/layer_norm/layer_norm_bias:0", + "evolved_transformer/body/decoder/layer_1/dense_layers/dense_1/kernel:0", + "evolved_transformer/body/decoder/layer_1/dense_layers/dense_1/bias:0", + "evolved_transformer/body/decoder/layer_2/16_head_self_attention/layer_prepostprocess/layer_norm/layer_norm_scale:0", + "evolved_transformer/body/decoder/layer_2/16_head_self_attention/layer_prepostprocess/layer_norm/layer_norm_bias:0", + "evolved_transformer/body/decoder/layer_2/16_head_self_attention/multihead_attention/q/kernel:0", + "evolved_transformer/body/decoder/layer_2/16_head_self_attention/multihead_attention/k/kernel:0", + "evolved_transformer/body/decoder/layer_2/16_head_self_attention/multihead_attention/v/kernel:0", + "evolved_transformer/body/decoder/layer_2/16_head_self_attention/multihead_attention/output_transform/kernel:0", + "evolved_transformer/body/decoder/layer_2/first_attend_to_encoder/multihead_attention/q/kernel:0", + "evolved_transformer/body/decoder/layer_2/first_attend_to_encoder/multihead_attention/k/kernel:0", + "evolved_transformer/body/decoder/layer_2/first_attend_to_encoder/multihead_attention/v/kernel:0", + "evolved_transformer/body/decoder/layer_2/first_attend_to_encoder/multihead_attention/output_transform/kernel:0", + "evolved_transformer/body/decoder/layer_2/conv_branches/layer_prepostprocess/layer_norm/layer_norm_scale:0", + "evolved_transformer/body/decoder/layer_2/conv_branches/layer_prepostprocess/layer_norm/layer_norm_bias:0", + "evolved_transformer/body/decoder/layer_2/conv_branches/separable_conv11x1/depthwise_kernel:0", + "evolved_transformer/body/decoder/layer_2/conv_branches/separable_conv11x1/pointwise_kernel:0", + "evolved_transformer/body/decoder/layer_2/conv_branches/separable_conv11x1/bias:0", + "evolved_transformer/body/decoder/layer_2/conv_branches/separable_conv_7x1_1/depthwise_kernel:0", + 
"evolved_transformer/body/decoder/layer_2/conv_branches/separable_conv_7x1_1/pointwise_kernel:0", + "evolved_transformer/body/decoder/layer_2/conv_branches/separable_conv_7x1_1/bias:0", + "evolved_transformer/body/decoder/layer_2/conv_branches/layer_prepostprocess_1/layer_norm/layer_norm_scale:0", + "evolved_transformer/body/decoder/layer_2/conv_branches/layer_prepostprocess_1/layer_norm/layer_norm_bias:0", + "evolved_transformer/body/decoder/layer_2/conv_branches/separable_conv_7x1_2/depthwise_kernel:0", + "evolved_transformer/body/decoder/layer_2/conv_branches/separable_conv_7x1_2/pointwise_kernel:0", + "evolved_transformer/body/decoder/layer_2/conv_branches/separable_conv_7x1_2/bias:0", + "evolved_transformer/body/decoder/layer_2/self_attention/layer_prepostprocess/layer_norm/layer_norm_scale:0", + "evolved_transformer/body/decoder/layer_2/self_attention/layer_prepostprocess/layer_norm/layer_norm_bias:0", + "evolved_transformer/body/decoder/layer_2/self_attention/multihead_attention/q/kernel:0", + "evolved_transformer/body/decoder/layer_2/self_attention/multihead_attention/k/kernel:0", + "evolved_transformer/body/decoder/layer_2/self_attention/multihead_attention/v/kernel:0", + "evolved_transformer/body/decoder/layer_2/self_attention/multihead_attention/output_transform/kernel:0", + "evolved_transformer/body/decoder/layer_2/second_attend_to_encoder/layer_prepostprocess/layer_norm/layer_norm_scale:0", + "evolved_transformer/body/decoder/layer_2/second_attend_to_encoder/layer_prepostprocess/layer_norm/layer_norm_bias:0", + "evolved_transformer/body/decoder/layer_2/second_attend_to_encoder/multihead_attention/q/kernel:0", + "evolved_transformer/body/decoder/layer_2/second_attend_to_encoder/multihead_attention/k/kernel:0", + "evolved_transformer/body/decoder/layer_2/second_attend_to_encoder/multihead_attention/v/kernel:0", + "evolved_transformer/body/decoder/layer_2/second_attend_to_encoder/multihead_attention/output_transform/kernel:0", + 
"evolved_transformer/body/decoder/layer_2/dense_layers/layer_prepostprocess/layer_norm/layer_norm_scale:0", + "evolved_transformer/body/decoder/layer_2/dense_layers/layer_prepostprocess/layer_norm/layer_norm_bias:0", + "evolved_transformer/body/decoder/layer_2/dense_layers/dense/kernel:0", + "evolved_transformer/body/decoder/layer_2/dense_layers/dense/bias:0", + "evolved_transformer/body/decoder/layer_2/dense_layers/layer_prepostprocess_1/layer_norm/layer_norm_scale:0", + "evolved_transformer/body/decoder/layer_2/dense_layers/layer_prepostprocess_1/layer_norm/layer_norm_bias:0", + "evolved_transformer/body/decoder/layer_2/dense_layers/dense_1/kernel:0", + "evolved_transformer/body/decoder/layer_2/dense_layers/dense_1/bias:0", + "evolved_transformer/body/decoder/layer_prepostprocess/layer_norm/layer_norm_scale:0", + "evolved_transformer/body/decoder/layer_prepostprocess/layer_norm/layer_norm_bias:0", + ] + variables = get_vars(var_names) + print_vars() + + # Act. + with self.test_session() as session: + tf.global_variables_initializer().run() + values_before = session.run(variables) + for _ in range(10): # Arbitrary number of training steps. + apply_grad.run() + values_after = session.run(variables) + + # Assert. 
+ self.assertTrue( + model._original_hparams.shared_embedding_and_softmax_weights) + self.assertTrue(model.hparams.shared_embedding_and_softmax_weights) + self.assertFalse(model.hparams.shared_embedding) + self.assertSameElements(var_names, + [var.name for var in tf.trainable_variables()]) + empty_vars = { + "evolved_transformer/symbol_modality_10_4/shared/weights_10:0", + "evolved_transformer/symbol_modality_10_4/shared/weights_11:0", + "evolved_transformer/symbol_modality_10_4/shared/weights_12:0", + "evolved_transformer/symbol_modality_10_4/shared/weights_13:0", + "evolved_transformer/symbol_modality_10_4/shared/weights_14:0", + "evolved_transformer/symbol_modality_10_4/shared/weights_15:0" + } + for name, before, after in zip(var_names, values_before, values_after): + if name in empty_vars: + self.assertEqual(before.size, after.size) + self.assertEqual(before.size, 0) + else: + assert_with_message( + self.assertNotAllClose, before, after, + "%s should be trainable, but did not change after training." % name) + + +if __name__ == "__main__": + tf.test.main() diff --git a/trax/models/image_transformer.py b/trax/models/image_transformer.py new file mode 100644 index 000000000..dd7c2d882 --- /dev/null +++ b/trax/models/image_transformer.py @@ -0,0 +1,1158 @@ +# coding=utf-8 +# Copyright 2023 The Tensor2Tensor Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""image generation with transformer (attention). 
+ +encoder: [Self-Attention, Feed-forward] x n +decoder: [Self-Attention, Source-Target-Attention, Feed-forward] x n + +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import copy +from tensor2tensor.layers import common_hparams +from tensor2tensor.layers import common_image_attention as cia +from tensor2tensor.layers import common_layers +from tensor2tensor.layers import modalities +from tensor2tensor.utils import registry +from tensor2tensor.utils import t2t_model + +import tensorflow.compat.v1 as tf +from tensorflow.compat.v1 import estimator as tf_estimator + + +@registry.register_model +class Imagetransformer(t2t_model.T2TModel): + """Conditional image generation with attention. See file docstring. + + The model admits either a Categorical or discretized mixture of logistic + distributions (DMOL) as the likelihood. When using DMOL for training, double + check that the evaluation metrics also use it. + """ + + def body(self, features): + hparams = copy.copy(self._hparams) + targets = features["targets"] + if (hparams.likelihood == cia.DistributionType.DMOL and + hparams.num_channels != 1): + raise ValueError("When using DMOL for the likelihood, bottom function " + " must be identity and num_channels must be 1.") + if (not tf.get_variable_scope().reuse and + hparams.mode != tf_estimator.ModeKeys.PREDICT): + tf.summary.image("targets", tf.to_float(targets), max_outputs=1) + + # Extra losses list if we want to use moe. + losses = [] + # Prepare decoder inputs and bias. + decoder_input, rows, cols = cia.prepare_decoder(targets, hparams) + # Add class label to decoder input. 
+ if not hparams.unconditional: + inputs = features["inputs"] + decoder_input += tf.reshape( + inputs, + [common_layers.shape_list(targets)[0], 1, 1, hparams.hidden_size]) + decoder_output = cia.transformer_decoder_layers( + decoder_input, + None, + hparams.num_decoder_layers or hparams.num_hidden_layers, + hparams, + attention_type=hparams.dec_attention_type, + losses=losses, + name="decoder") + output = cia.create_output(decoder_output, rows, cols, targets, hparams) + + if losses: + return output, {"extra_loss": tf.add_n(losses)} + else: + return output + + def loss(self, logits, features): + if self._hparams.likelihood == cia.DistributionType.DMOL: + return common_layers.dml_loss(logits, features["targets"]) + + return super(Imagetransformer, self).loss(logits, features) + + def sample(self, features): + """Run the model and extract samples. + + Args: + features: an map of string to `Tensor`. + + Returns: + samples: an integer `Tensor`. + logits: a list of `Tensor`s, one per datashard. + losses: a dictionary: {loss-name (string): floating point `Scalar`}. + """ + if self._hparams.likelihood == cia.DistributionType.DMOL: + logits, losses = self(features) # pylint: disable=not-callable + samples = common_layers.sample_from_discretized_mix_logistic( + logits, seed=None) + return samples, logits, losses + + return super(Imagetransformer, self).sample(features) + + def _slow_greedy_infer(self, features, decode_length): + """A slow greedy inference method. + + Quadratic time in decode_length. + + Args: + features: an map of string to `Tensor` + decode_length: an integer. How many additional timesteps to decode. + + Returns: + samples: an integer `Tensor`. + logits: `Tensor` of shape [batch_size, time, 1, 1, vocab_size]. 
+ losses: a dictionary: {loss-name (string): floating point `Scalar`} + """ + if self._hparams.likelihood == cia.DistributionType.DMOL: + raise NotImplementedError("Decoding is not currently available for DMOL.") + return super(Imagetransformer, self)._slow_greedy_infer(features, + decode_length) + + +@registry.register_model +class ImagetransformerMoe(t2t_model.T2TModel): + """Conditional image generation with attention and MoE.""" + + @staticmethod + def use_body_sharded(): + return True + + def body_sharded(self, sharded_features): + dp = self._data_parallelism + hparams = copy.copy(self._hparams) + inputs = sharded_features["inputs"] + targets = sharded_features["targets"] + + # Determine attention type and padding from hparams. + q_padding, kv_padding = "VALID", "VALID" + if hparams.q_filter_width > 1: + q_padding = "LEFT" + if hparams.kv_filter_width > 1: + kv_padding = "LEFT" + + # Prepare decoder inputs and bias. + decoder_input, rows, cols = dp(cia.prepare_decoder_inputs, + inputs, targets, hparams) + + # Run decoder. + # TODO(nikip): Use q_padding and kv_padding + del q_padding, kv_padding + decoder_output, extra_loss = cia.transformer_layers_sharded( + dp, + self._ps_devices, + decoder_input, + hparams.num_hidden_layers, + hparams, + self_attention_bias=None, + enc_output=None, + attention_type=hparams.dec_attention_type, + name="decoder") + + output = dp(cia.create_output, decoder_output, rows, cols, targets, hparams) + return output, extra_loss + + +@registry.register_hparams +def image_transformer_base(): + """Set of hyperparameters.""" + hparams = common_hparams.basic_params1() + hparams.hidden_size = 512 + hparams.batch_size = 4 + hparams.max_length = 3075 + hparams.dropout = 0.0 + hparams.clip_grad_norm = 0. # i.e. 
no gradient clipping + hparams.optimizer_adam_epsilon = 1e-9 + hparams.learning_rate_decay_scheme = "noam" + hparams.learning_rate = 0.1 + hparams.learning_rate_warmup_steps = 4000 + hparams.initializer_gain = 0.2 + hparams.num_hidden_layers = 6 + hparams.initializer = "uniform_unit_scaling" + hparams.weight_decay = 0.0 + hparams.optimizer_adam_beta1 = 0.9 + hparams.optimizer_adam_beta2 = 0.98 + hparams.label_smoothing = 0.0 + hparams.bottom["targets"] = modalities.image_channel_embeddings_bottom + hparams.top["targets"] = modalities.identity_top + hparams.norm_type = "layer" + hparams.layer_prepostprocess_dropout = 0.0 + hparams.add_hparam("filter_size", 512) # Add new ones like this. + + # attention-related flags + hparams.add_hparam("num_heads", 8) + hparams.add_hparam("attention_key_channels", 0) + hparams.add_hparam("attention_value_channels", 0) + hparams.add_hparam("ffn_layer", "conv_hidden_relu") + # All hyperparameters ending in "dropout" are automatically set to 0.0 + # when not in training mode. 
+ hparams.add_hparam("attention_dropout", 0.0) + hparams.add_hparam("relu_dropout", 0.0) + hparams.add_hparam("pos", "timing") # timing, none + hparams.add_hparam("nbr_decoder_problems", 1) + hparams.add_hparam("num_output_layers", 3) + hparams.add_hparam("block_size", 1) + + # dilated attention based flags + hparams.add_hparam("gap_sizes", [2, 4, 8, 16, 32, 64, 2, 4, 8, 16, 32, 64]) + + # image size related flags + # assuming that the image has same height and width + hparams.add_hparam("img_len", 32) + hparams.add_hparam("num_channels", 3) + # Local attention params + hparams.add_hparam("local_and_global_att", False) + hparams.add_hparam("block_length", 256) + hparams.add_hparam("block_width", 128) + hparams.add_hparam("num_encoder_layers", 4) + hparams.add_hparam("num_decoder_layers", 12) + hparams.add_hparam("dec_attention_type", cia.AttentionType.LOCAL_1D) + hparams.add_hparam("block_raster_scan", False) + + # multipos attention params + hparams.add_hparam("q_filter_width", 1) + hparams.add_hparam("kv_filter_width", 1) + + hparams.add_hparam("likelihood", cia.DistributionType.CAT) + hparams.add_hparam("unconditional", False) # unconditional generation + + # parameters of discretized mixture of logistics loss from pixel cnn++ + hparams.add_hparam("num_mixtures", 10) + + # These parameters are only used when ffn_layer=="local_moe_tpu" + hparams.add_hparam("moe_overhead_train", 1.0) + hparams.add_hparam("moe_overhead_eval", 2.0) + hparams.moe_num_experts = 8 + hparams.moe_loss_coef = 1e-3 + + # These parameters are for relative attention + hparams.add_hparam("shared_rel", False) # share relative embeddings + return hparams + + +@registry.register_hparams +def imagetransformer_base(): + hparams = image_transformer_base() + return hparams + + +@registry.register_hparams +def imagetransformer_cifar10_base(): + """Best config for 2.90 bits/dim on CIFAR10 using cross entropy.""" + hparams = image_transformer_base() + hparams.batch_size = 4 + hparams.num_heads = 4 + 
hparams.num_decoder_layers = 12 + hparams.block_length = 256 + hparams.hidden_size = 512 + hparams.filter_size = 2048 + hparams.learning_rate = 0.5 + hparams.learning_rate_warmup_steps = 4000 + hparams.layer_preprocess_sequence = "none" + hparams.layer_postprocess_sequence = "dan" + hparams.layer_prepostprocess_dropout = 0.3 + hparams.unconditional = True + return hparams + + +@registry.register_hparams +def imagetransformer_cifar10_base_dmol(): + """Best config for 2.90 bits/dim on CIFAR10 using DMOL.""" + hparams = image_transformer_base() + hparams.likelihood = cia.DistributionType.DMOL + hparams.num_channels = 1 + hparams.bottom["targets"] = modalities.image_channel_compress_targets_bottom + hparams.top["targets"] = modalities.identity_top + hparams.num_heads = 8 + hparams.batch_size = 8 + hparams.sampling_method = "random" + hparams.layer_preprocess_sequence = "n" + hparams.layer_postprocess_sequence = "da" + hparams.summarize_grads = True + hparams.hidden_size = 256 + hparams.filter_size = 512 + hparams.attention_key_channels = 512 + hparams.attention_value_channels = 512 + hparams.num_decoder_layers = 12 + hparams.layer_prepostprocess_dropout = 0.1 + hparams.learning_rate = 0.1 + hparams.layer_preprocess_sequence = "none" + hparams.layer_postprocess_sequence = "dan" + hparams.pos = "emb" + hparams.unconditional = True + return hparams + + +@registry.register_hparams +def imagetransformer_base_tpu(): + """Transformer base params for cifar-10.""" + hparams = imagetransformer_bas8l_8h_big_uncond_dr03_imgnet() + update_hparams_for_tpu(hparams) + hparams.batch_size = 4 + hparams.num_heads = 4 # heads are expensive on tpu + hparams.num_decoder_layers = 12 + hparams.block_length = 128 + hparams.hidden_size = 512 + hparams.filter_size = 2048 + hparams.learning_rate = 0.2 + hparams.learning_rate_warmup_steps = 6000 + hparams.layer_preprocess_sequence = "none" + hparams.layer_postprocess_sequence = "dan" + hparams.layer_prepostprocess_dropout = 0.3 + return hparams + 
+ +@registry.register_hparams +def imagetransformer_base_imagenet_tpu(): + """Transformer base params for cifar-10.""" + hparams = imagetransformer_base_tpu() + hparams.batch_size = 4 + hparams.num_heads = 4 # heads are expensive on tpu + hparams.num_decoder_layers = 12 + hparams.block_length = 128 + hparams.layer_preprocess_sequence = "none" + hparams.layer_postprocess_sequence = "dan" + hparams.layer_prepostprocess_dropout = 0.1 + return hparams + + +@registry.register_hparams +def imagetransformer_imagenet32_base(): + """Best config for ImageNet-32 with 3.77 bits/dim using cross entropy.""" + hparams = imagetransformer_cifar10_base() + hparams.batch_size = 4 + hparams.layer_prepostprocess_dropout = 0.1 + return hparams + + +@registry.register_hparams +def imagetransformer_base_rel(): + """Base with relative attention.""" + hparams = imagetransformer_base() + hparams.dec_attention_type = cia.AttentionType.RELATIVE_LOCAL_1D + return hparams + + +@registry.register_hparams +def imagetransformer_sep_channels(): + """separate rgb embeddings.""" + hparams = imagetransformer_base() + hparams.num_heads = 4 + hparams.attention_key_channels = hparams.attention_value_channels = 0 + hparams.hidden_size = 256 + hparams.filter_size = 512 + hparams.num_hidden_layers = 6 + return hparams + + +@registry.register_hparams +def imagetransformer_sep_channels_8l(): + """separate rgb embeddings.""" + hparams = imagetransformer_base() + hparams.num_heads = 4 + hparams.attention_key_channels = hparams.attention_value_channels = 0 + hparams.hidden_size = 256 + hparams.filter_size = 256 + hparams.num_hidden_layers = 8 + hparams.sampling_method = "random" + return hparams + + +@registry.register_hparams +def imagetransformer_sep_channels_8l_multipos3(): + """separate rgb embeddings.""" + hparams = imagetransformer_sep_channels_8l() + hparams.q_filter_width = 3 + hparams.kv_filter_width = 3 + return hparams + + +@registry.register_hparams +def 
imagetransformer_base_8l_8h_big_cond_dr03_dan(): + """big 1d model for conditional image generation.2.99 on cifar10.""" + hparams = imagetransformer_sep_channels_8l() + hparams.block_width = 256 + hparams.block_length = 256 + hparams.hidden_size = 512 + hparams.num_heads = 8 + hparams.filter_size = 2048 + hparams.batch_size = 4 + hparams.max_length = 3075 + hparams.layer_preprocess_sequence = "none" + hparams.layer_postprocess_sequence = "dan" + hparams.num_decoder_layers = 8 + hparams.layer_prepostprocess_dropout = 0.3 + return hparams + + +@registry.register_hparams +def imagetransformer_base_10l_8h_big_uncond_dr03_dan_64(): + """big 1d model for unconditional generation on imagenet.""" + hparams = imagetransformer_base_10l_8h_big_cond_dr03_dan() + hparams.unconditional = True + hparams.max_length = 14000 + hparams.batch_size = 1 + hparams.img_len = 64 + hparams.layer_prepostprocess_dropout = 0.1 + return hparams + + +@registry.register_hparams +def imagetransformerpp_sep_channels_8l_8h(): + """separate rgb embeddings.""" + hparams = imagetransformer_base() + hparams.likelihood = cia.DistributionType.DMOL + hparams.num_channels = 1 + hparams.bottom["targets"] = modalities.image_channel_compress_targets_bottom + hparams.top["targets"] = modalities.identity_top + hparams.num_heads = 8 + hparams.batch_size = 4 + hparams.attention_key_channels = hparams.attention_value_channels = 0 + hparams.hidden_size = 512 + hparams.filter_size = 512 + hparams.num_hidden_layers = 8 + hparams.sampling_method = "random" + hparams.layer_preprocess_sequence = "n" + hparams.layer_postprocess_sequence = "da" + hparams.summarize_grads = True + hparams.learning_rate = 0.1 + return hparams + + +@registry.register_hparams +def imagetransformerpp_base_8l_8h_big_cond_dr03_dan(): + """big 1d model for conditional image generation.2.99 on cifar10.""" + hparams = imagetransformerpp_sep_channels_8l_8h() + hparams.hidden_size = 512 + hparams.num_heads = 8 + hparams.filter_size = 2048 + 
hparams.batch_size = 4 + hparams.max_length = 3075 + hparams.layer_prepostprocess_dropout = 0.3 + hparams.layer_preprocess_sequence = "none" + hparams.layer_postprocess_sequence = "dan" + hparams.summarize_grads = True + hparams.learning_rate = 0.01 + return hparams + + +@registry.register_hparams +def imagetransformerpp_base_8l_8h_big_cond_dr03_dan_a(): + hparams = imagetransformerpp_base_8l_8h_big_cond_dr03_dan() + hparams.learning_rate = 0.1 + return hparams + + +@registry.register_hparams +def imagetransformerpp_base_10l_8h_big_uncond_dr03_dan(): + hparams = imagetransformerpp_base_8l_8h_big_cond_dr03_dan_a() + hparams.unconditional = True + hparams.num_decoder_layers = 10 + return hparams + + +@registry.register_hparams +def imagetransformerpp_base_10l_8h_big_uncond_dr03_dan_a(): + hparams = imagetransformerpp_base_10l_8h_big_uncond_dr03_dan() + hparams.learning_rate = 0.01 + return hparams + + +@registry.register_hparams +def imagetransformerpp_base_10l_8h_big_uncond_dr03_dan_b(): + hparams = imagetransformerpp_base_10l_8h_big_uncond_dr03_dan() + hparams.learning_rate = 0.1 + hparams.hidden_size = 256 + hparams.attention_key_channels = 512 + hparams.attention_value_channels = 512 + hparams.filter_size = 1024 + return hparams + + +@registry.register_hparams +def imagetransformerpp_base_10l_8h_big_uncond_dr03_dan_g(): + hparams = imagetransformerpp_base_10l_8h_big_uncond_dr03_dan_b() + hparams.filter_size = 512 + hparams.layer_prepostprocess_dropout = 0.1 + hparams.learning_rate = 0.1 + hparams.layer_preprocess_sequence = "none" + hparams.layer_postprocess_sequence = "dan" + hparams.pos = "emb" + return hparams + + +@registry.register_hparams +def imagetransformerpp_base_12l_8h_big_uncond_dr03_dan_k(): + hparams = imagetransformerpp_base_10l_8h_big_uncond_dr03_dan_g() + hparams.num_decoder_layers = 12 + return hparams + + +@registry.register_hparams +def imagetransformerpp_base_12l_8h_big_uncond_dr03_dan_l(): + hparams = 
imagetransformerpp_base_10l_8h_big_uncond_dr03_dan_g() + hparams.num_decoder_layers = 12 + hparams.clip_grad_norm = 40. + return hparams + + +@registry.register_hparams +def imagetransformerpp_base_12l_8h_big_uncond_dr03_dan_m(): + hparams = imagetransformerpp_base_12l_8h_big_uncond_dr03_dan_k() + hparams.batch_size = 8 + return hparams + + +@registry.register_hparams +def imagetransformerpp_base_12l_8h_big_uncond_dr03_dan_m_rel(): + hparams = imagetransformerpp_base_12l_8h_big_uncond_dr03_dan_k() + hparams.batch_size = 8 + hparams.dec_attention_type = cia.AttentionType.RELATIVE_LOCAL_1D + return hparams + + +@registry.register_hparams +def imagetransformerpp_base_12l_8h_big_uncond_dr03_dan_m_relsh(): + hparams = imagetransformerpp_base_12l_8h_big_uncond_dr03_dan_m_rel() + hparams.shared_rel = True + return hparams + + +@registry.register_hparams +def imagetransformerpp_base_14l_8h_big_uncond_dr03_dan_p(): + """Gets to 2.92 in just under 4 days on 8 p100s.""" + hparams = imagetransformerpp_base_12l_8h_big_uncond_dr03_dan_l() + hparams.num_decoder_layers = 14 + hparams.batch_size = 8 + hparams.layer_prepostprocess_dropout = 0.2 + return hparams + + +@registry.register_hparams +def imagetransformerpp_base_12l_8h_big_uncond_dr03_dan_m_bs1(): + """For 128x128.""" + # TODO(trandustin): why are these running? max_length and img_len not set + # 256x256 was also training without setting max_length + hparams = imagetransformerpp_base_12l_8h_big_uncond_dr03_dan_m() + hparams.batch_size = 1 + return hparams + + +@registry.register_hparams +def imagetransformerpp_base_14l_8h_big_uncond_dr03_dan_p_bs1(): + """For 128x128.""" + hparams = imagetransformerpp_base_14l_8h_big_uncond_dr03_dan_p() + hparams.batch_size = 1 + return hparams + + +@registry.register_hparams +def imagetransformerpp_base_5l_8h_big_uncond_dr00_dan_g_bs1(): + """For 256x256.""" + hparams = imagetransformerpp_base_10l_8h_big_uncond_dr03_dan_g() + # TODO(trandustin): I forgot to set this in the runs! 
Maybe it's not used in + # image transformer training implementation? + # hparams.img_len = 256 + hparams.max_length = 66000 # allow for 256x256 + hparams.batch_size = 1 + hparams.num_decoder_layers = 5 + hparams.hidden_size = 128 + hparams.filter_size = 128 + hparams.attention_key_channels = 64 + hparams.attention_value_channels = 64 + hparams.layer_prepostprocess_dropout = 0.0 + return hparams + + +@registry.register_hparams +def imagetransformerpp_base_5l_8h_dr00_dan_g_bs1_adafactor(): + """For 256x256.""" + hparams = imagetransformerpp_base_5l_8h_big_uncond_dr00_dan_g_bs1() + # Use Adafactor which uses less memory than Adam, and its recommendations. + hparams.optimizer = "Adafactor" + hparams.learning_rate_schedule = "rsqrt_decay" + return hparams + + +@registry.register_hparams +def imagetransformerpp_base_6l_8h_dr00_dan_g_bs1_adafactor(): + """For 256x256.""" + hparams = imagetransformerpp_base_5l_8h_dr00_dan_g_bs1_adafactor() + hparams.num_decoder_layers = 6 + return hparams + + +@registry.register_hparams +def imagetransformerpp_base_14l_8h_big_uncond_dr03_dan_eval(): + """Gets to 2.92 in just under 4 days on 8 p100s.""" + hparams = imagetransformerpp_base_12l_8h_big_uncond_dr03_dan_l() + hparams.num_decoder_layers = 14 + hparams.batch_size = 8 + # hparams.layer_prepostprocess_dropout = 0.2 + return hparams + + +@registry.register_hparams +def imagetransformer_base_8l_8h_big_cond_dr03_dan_128(): + hparams = imagetransformer_base_8l_8h_big_cond_dr03_dan() + hparams.block_width = 128 + hparams.block_length = 128 + return hparams + + +@registry.register_hparams +def imagetransformer_base_10l_8h_big_cond_dr03_dan(): + """Best conditional Cifar10 gen param.""" + hparams = imagetransformer_base_8l_8h_big_cond_dr03_dan() + hparams.num_decoder_layers = 10 + return hparams + + +@registry.register_hparams +def imagetransformer_base_10l_8h_big_uncond_dr03_dan(): + """Best unconditional Cifar10 gen param.""" + hparams = imagetransformer_base_10l_8h_big_cond_dr03_dan() 
+ hparams.num_decoder_layers = 10 + return hparams + + +@registry.register_hparams +def imagetransformer_base_8l_8h_big_cond_dr03_dan_dilated(): + """Dilated hparams.""" + hparams = imagetransformer_base_8l_8h_big_cond_dr03_dan() + hparams.gap_sizes = [0, 16, 64, 0, 16, 64, 128, 0] + hparams.dec_attention_type = cia.AttentionType.DILATED + hparams.block_length = 128 + hparams.block_width = 128 + hparams.add_hparam("num_memory_blocks", 1) + return hparams + + +@registry.register_hparams +def imagetransformer_base_8l_8h_big_cond_dr03_dan_dilated_b(): + """Dilated hparams.""" + hparams = imagetransformer_base_8l_8h_big_cond_dr03_dan_dilated() + hparams.block_width = 64 + hparams.num_memory_blocks = 2 + return hparams + + +@registry.register_hparams +def imagetransformer_base_8l_8h_big_cond_dr03_dan_dilated_c(): + """Dilated hparams.""" + hparams = imagetransformer_base_8l_8h_big_cond_dr03_dan_dilated() + hparams.block_width = 32 + hparams.num_memory_blocks = 4 + return hparams + + +@registry.register_hparams +def imagetransformer_base_8l_8h_big_cond_dr03_dan_dilated_d(): + """Dilated hparams.""" + hparams = imagetransformer_base_8l_8h_big_cond_dr03_dan_dilated() + hparams.gap_sizes = [0, 16, 64, 16, 64, 128, 256, 0] + return hparams + + +@registry.register_hparams +def imagetransformer_base_12l_8h_big(): + """big 1d model for conditional image generation.""" + hparams = imagetransformer_sep_channels_8l_8h() + hparams.filter_size = 1024 + hparams.num_decoder_layers = 12 + hparams.batch_size = 1 + hparams.hidden_size = 512 + hparams.learning_rate_warmup_steps = 4000 + hparams.sampling_method = "random" + hparams.beam_size = 1 + hparams.block_width = 256 + return hparams + + +@registry.register_hparams +def imagetransformer1d_base_8l_64by64(): + """hparams fo 12 layer big 1d model for imagenet 64x64.""" + hparams = image_transformer_base() + hparams.num_heads = 8 + hparams.hidden_size = 512 + hparams.filter_size = 2048 + hparams.num_decoder_layers = 8 + 
hparams.batch_size = 1 + hparams.block_length = 512 + hparams.block_width = 768 + hparams.layer_prepostprocess_dropout = 0.1 + hparams.max_length = 14000 + hparams.unconditional = int(False) + return hparams + + +@registry.register_hparams +def imagetransformer1d_base_12l_64by64(): + """hparams fo 12 layer big 1d model for imagenet 64x64.""" + hparams = image_transformer_base() + hparams.num_heads = 8 + hparams.hidden_size = 512 + hparams.filter_size = 2048 + hparams.num_decoder_layers = 12 + hparams.batch_size = 1 + hparams.block_length = 512 + hparams.block_width = 768 + hparams.layer_prepostprocess_dropout = 0.1 + hparams.max_length = 14000 + hparams.unconditional = int(False) + return hparams + + +@registry.register_hparams +def imagetransformer_base_14l_8h_big(): + """big 1d model for conditional image generation.""" + hparams = imagetransformer_base_12l_8h_big() + hparams.num_decoder_layers = 14 + return hparams + + +@registry.register_hparams +def imagetransformer_base_14l_8h_big_dr01(): + """big 1d model for conditional image generation.""" + hparams = imagetransformer_base_14l_8h_big() + hparams.layer_prepostprocess_dropout = 0.1 + return hparams + + +@registry.register_hparams +def imagetransformer_base_12l_8h_big_uncond(): + """big 1d model for conditional image generation.""" + hparams = imagetransformer_base_12l_8h_big() + hparams.unconditional = True + return hparams + + +@registry.register_hparams +def imagetransformer_base_14l_8h_big_uncond(): + """big 1d model for conditional image generation.""" + hparams = imagetransformer_base_12l_8h_big_uncond() + hparams.num_decoder_layers = 14 + return hparams + + +@registry.register_hparams +def imagetransformer_sep_channels_12l_16h_imagenet_large(): + """separate rgb embeddings.""" + hparams = imagetransformer_sep_channels_8l_8h() + hparams.num_hidden_layers = 12 + hparams.batch_size = 1 + hparams.filter_size = 2048 + hparams.num_heads = 16 + hparams.learning_rate_warmup_steps = 16000 + 
hparams.sampling_method = "random" + hparams.learning_rate = 0.1 + return hparams + + +@registry.register_hparams +def imagetransformer_sep_channels_16l_16h_imgnet_lrg_loc(): + """separate rgb embeddings.""" + hparams = imagetransformer_sep_channels_12l_16h_imagenet_large() + hparams.num_hidden_layers = 16 + hparams.local_attention = True + hparams.batch_size = 1 + hparams.block_length = 256 + return hparams + + +@registry.register_hparams +def imagetransformer_sep_channels_16l_16h_imgnet_lrg_loc_128(): + """separate rgb embeddings.""" + hparams = imagetransformer_sep_channels_12l_16h_imagenet_large() + hparams.num_hidden_layers = 16 + hparams.local_attention = True + hparams.batch_size = 1 + hparams.block_length = 128 + return hparams + + +@registry.register_hparams +def imagetransformer_sep_output_channels_8l_local_and_global_att(): + """separate rgb embeddings.""" + hparams = imagetransformer_sep_channels_8l() + hparams.sampling_method = "random" + hparams.local_and_global_att = True + return hparams + + +@registry.register_hparams +def imagetransformer_base_10l_16h_big_uncond_dr01_imgnet(): + """big 1d model for conditional image generation.""" + hparams = imagetransformer_base_14l_8h_big_dr01() + # num_hidden_layers + hparams.num_decoder_layers = 10 + hparams.num_heads = 16 + hparams.hidden_size = 1024 + hparams.filter_size = 4096 + hparams.batch_size = 1 + hparams.layer_prepostprocess_dropout = 0.1 + return hparams + + +@registry.register_hparams +def imagetransformer_base_10l_16h_big_dr01_imgnet(): + """big 1d model for conditional image generation.""" + hparams = imagetransformer_base_14l_8h_big_dr01() + # num_hidden_layers + hparams.num_decoder_layers = 10 + hparams.num_heads = 16 + hparams.hidden_size = 1024 + hparams.filter_size = 4096 + hparams.batch_size = 1 + hparams.unconditional = False + hparams.layer_prepostprocess_dropout = 0.1 + return hparams + + +@registry.register_hparams +def imagetransformer_sep_channels_8l_8h(): + """separate rgb 
embeddings.""" + hparams = imagetransformer_base() + hparams.num_heads = 8 + hparams.batch_size = 1 + hparams.attention_key_channels = hparams.attention_value_channels = 0 + hparams.hidden_size = 512 + hparams.filter_size = 512 + hparams.num_hidden_layers = 8 + hparams.sampling_method = "random" + return hparams + + +@registry.register_hparams +def imagetransformer_sep_channels_8l_8h_local_and_global_att(): + """separate rgb embeddings.""" + hparams = imagetransformer_sep_channels_8l_8h() + hparams.num_heads = 8 + hparams.batch_size = 1 + hparams.attention_key_channels = hparams.attention_value_channels = 0 + hparams.hidden_size = 256 + hparams.filter_size = 256 + hparams.num_hidden_layers = 4 + hparams.sampling_method = "random" + hparams.local_and_global_att = True + return hparams + + +@registry.register_hparams +def imagetransformer_bas8l_8h_big_uncond_dr03_imgnet(): + """big 1d model for conditional image generation.""" + hparams = imagetransformer_base_14l_8h_big_dr01() + # num_hidden_layers + hparams.num_decoder_layers = 8 + hparams.num_heads = 8 + hparams.hidden_size = 512 + hparams.filter_size = 2048 + hparams.layer_prepostprocess_dropout = 0.3 + return hparams + + +@registry.register_hparams +def imagetransformer_tiny(): + hparams = imagetransformer_base() + hparams.num_decoder_layers = 2 + hparams.hidden_size = 64 + hparams.batch_size = 1 + hparams.unconditional = True + hparams.max_length = 66000 # allow for 256x256 + return hparams + + +@registry.register_hparams +def imagetransformerpp_tiny(): + hparams = imagetransformer_tiny() + hparams.likelihood = cia.DistributionType.DMOL + hparams.num_channels = 1 + hparams.bottom["targets"] = modalities.image_channel_compress_targets_bottom + hparams.top["targets"] = modalities.identity_top + return hparams + + +@registry.register_hparams +def imagetransformer_tiny_tpu(): + hparams = imagetransformer_tiny() + update_hparams_for_tpu(hparams) + hparams.num_hidden_layers = 2 + hparams.hidden_size = 16 + 
@registry.register_hparams
def imagetransformer_base_10l_16h_big_dr01_moe_imgnet():
  """Mixture-of-experts variant of `imagetransformer_base_10l_16h_big_dr01_imgnet`.

  Switches to an orthogonal initializer, lengthens the warmup, enables
  scheduled sampling and configures the MoE layers.

  Returns:
    The modified hparams object.
  """
  moe = imagetransformer_base_10l_16h_big_dr01_imgnet()

  # Optimisation / initialisation.
  moe.initializer = "orthogonal"
  moe.learning_rate_warmup_steps = 16000

  # Scheduled sampling.
  moe.scheduled_sampling_prob = 0.1
  moe.scheduled_sampling_warmup_steps = 200000

  # Mixture-of-experts settings.
  # "moe_layers_decoder" is a new hparam: which decoder layers are MoE.
  moe.add_hparam("moe_layers_decoder", "2,7")
  # Hidden layer sizes (comma-separated).
  moe.moe_hidden_sizes = "4096"
  # Number of experts in each MoE layer.
  moe.moe_num_experts = 64
  # How many experts to use per batch element (try 2 or 4).
  moe.moe_k = 4
  # MoE loss coefficient (1e-2 is usually ok).
  moe.moe_loss_coef = 3e-2
  return moe
def update_hparams_for_tpu(hparams):
  """Overwrite optimizer, learning-rate schedule and batch size for TPU runs.

  The object is modified in place; nothing is returned.

  Args:
    hparams: hyperparameter object whose attributes are overwritten.
  """
  tpu_settings = (
      ("optimizer", "Adafactor"),
      ("learning_rate_schedule", "rsqrt_decay"),
      ("learning_rate_warmup_steps", 6000),
      ("batch_size", 4),
  )
  for field, value in tpu_settings:
    setattr(hparams, field, value)
@registry.register_hparams
def imagetransformer_b12l_4h_big_uncond_dr03_lr025_tpu():
  """`imagetransformer_b12l_4h_big_uncond_dr03_tpu` with lr 0.25, warmup 5000.

  The TPU defaults are re-applied first, so the warmup override below must
  come after `update_hparams_for_tpu` (which writes its own warmup value).

  Returns:
    The modified hparams object.
  """
  config = imagetransformer_b12l_4h_big_uncond_dr03_tpu()
  update_hparams_for_tpu(config)
  config.learning_rate_warmup_steps = 5000
  config.learning_rate = 0.25
  return config
@registry.register_hparams
def imagetransformer_b12l_4h_b128_h512_uncond_dr01_im():
  """TPU imagenet config: the b256 12-layer TPU model with dropout 0.1.

  Re-applies the TPU defaults, then restates batch size, optimizer, schedule
  and warmup explicitly before lowering the pre/post-process dropout.

  NOTE(review): the name says "b128" but the base is
  `imagetransformer_b12l_4h_b256_uncond_dr03_tpu` and block_length is not
  overridden here -- confirm which block length is intended.

  Returns:
    The modified hparams object.
  """
  config = imagetransformer_b12l_4h_b256_uncond_dr03_tpu()
  update_hparams_for_tpu(config)
  explicit_settings = (
      ("batch_size", 4),
      ("optimizer", "Adafactor"),
      ("learning_rate_schedule", "rsqrt_decay"),
      ("learning_rate_warmup_steps", 6000),
      ("layer_prepostprocess_dropout", 0.1),
  )
  for field, value in explicit_settings:
    setattr(config, field, value)
  return config
@registry.register_hparams
def imagetransformer_b10l_4h_big_uncond_dr01_tpu():
  """Big 1d TPU config: 10 decoder layers, 4 heads, hidden 1024, dropout 0.1.

  Derived from `imagetransformer_b12l_4h_big_uncond_dr03_tpu`; only the model
  size, batch size and pre/post-process dropout are changed here.

  NOTE(review): the previous docstring described this as a "conditional"
  model although the name says "uncond"; `unconditional` is not written by
  this function -- confirm the inherited value.

  Returns:
    The modified hparams object.
  """
  big = imagetransformer_b12l_4h_big_uncond_dr03_tpu()

  # Model size; depth is set via num_decoder_layers (not num_hidden_layers).
  big.num_decoder_layers = 10
  big.num_heads = 4
  big.hidden_size = 1024
  big.filter_size = 4096

  # Training settings.
  big.batch_size = 1
  big.layer_prepostprocess_dropout = 0.1
  return big
@registry.register_model
class Imagetransformer2d(t2t_model.T2TModel):
  """Decoder-only image generation model with attention (see file docstring).

  When `hparams.unconditional` is false, `features["inputs"]` is added to the
  decoder input as class conditioning.
  """

  def body(self, features):
    """Runs the decoder stack over the targets and builds the output tensor.

    Args:
      features: dict holding at least "inputs" and "targets".

    Returns:
      The result of `cia.create_output` for the decoder output.
    """
    hp = copy.copy(self._hparams)
    inputs = features["inputs"]
    targets = features["targets"]
    batch_dim = common_layers.shape_list(targets)[0]

    # Only log an image summary on the first (non-reused) graph build, and
    # never while predicting.
    if (not tf.get_variable_scope().reuse and
        hp.mode != tf_estimator.ModeKeys.PREDICT):
      tf.summary.image("targets", targets, max_outputs=1)

    decoder_input, rows, cols = cia.prepare_decoder(targets, hp)

    if not hp.unconditional:
      # Add class label to decoder input.
      class_conditioning = tf.reshape(
          inputs, [batch_dim, 1, 1, hp.hidden_size])
      decoder_input = decoder_input + class_conditioning

    # No encoder output: the second argument is None.
    decoder_output = cia.transformer_decoder_layers(
        decoder_input,
        None,
        hp.num_decoder_layers,
        hp,
        attention_type=hp.dec_attention_type,
        name="decoder")

    return cia.create_output(decoder_output, rows, cols, targets, hp)
  def top(self, body_output, features):
    """Projects the body output to 256-way logits.

    In TRAIN and EVAL mode a single block position is selected from
    `body_output` and its index is stored in `features["block_index"]`
    (a side effect that `loss` relies on). In other modes all block positions
    are kept and the result is reshaped per block.

    Args:
      body_output: tensor produced by `body`; in TRAIN/EVAL it is indexed with
        five indices, the fourth being the block position.
      features: dict with a rank-4 "targets" entry; "block_index" is written
        in TRAIN/EVAL mode.

    Returns:
      Logits reshaped to `targets_shape + [256]` in TRAIN/EVAL, otherwise a
      rank-5 tensor with `block_size` along axis 2, truncated along axis 1 to
      `targets_shape[1]`.
    """
    assert self._hparams.block_size > 0

    train_or_eval = (
        self._hparams.mode == tf_estimator.ModeKeys.TRAIN or
        self._hparams.mode == tf_estimator.ModeKeys.EVAL)

    if train_or_eval:
      if self._hparams.mode == tf_estimator.ModeKeys.TRAIN:
        # Training: pick a random block position in [0, block_size).
        features["block_index"] = tf.random_uniform(
            shape=[], minval=0, maxval=self._hparams.block_size, dtype=tf.int64)
      else:
        # Evaluation: always score the first block position.
        features["block_index"] = 0
      # Keep only the chosen block position (drops the block axis).
      body_output = body_output[:, :, :, features["block_index"], :]

    # 256 output units -- presumably one per pixel intensity value; confirm.
    decoded_image = tf.layers.dense(
        body_output, 256, use_bias=True, activation=None, name="output_conv")

    assert len(features["targets"].shape) == 4
    targets_shape = common_layers.shape_list(features["targets"])

    if train_or_eval:
      output = tf.reshape(decoded_image, targets_shape + [256])
    else:
      # Inference: keep every block position, then trim axis 1 back to the
      # length of the targets.
      output = tf.reshape(decoded_image, [
          targets_shape[0], -1, self._hparams.block_size, 1, 256])
      output = output[:, :targets_shape[1], :, :, :]

    return output
+ one_or_nan = tf.cond(tf.equal(k, i), lambda: 1.0, lambda: float("nan")) + tf.summary.scalar( + "block_index_%d" % i, one_or_nan * loss_val, family="losses") + + return loss + + def _greedy_infer(self, features, decode_length, use_tpu=False): + assert not use_tpu + return self._slow_greedy_infer_guess_and_check(features, decode_length) + + def _beam_decode(self, features, decode_length, beam_size, top_beams, alpha): + raise NotImplementedError + + def _slow_greedy_infer_guess_and_check(self, features, decode_length): + assert self._hparams.block_size > 0 + assert self._hparams.force_full_predict + assert self._hparams.sampling_method == "argmax" + assert self._decode_hparams.batch_size == 1 + assert self._decode_hparams.block_size > 0 + assert self._decode_hparams.block_size <= self._hparams.block_size + assert ( + (self._decode_hparams.guess_and_check_top_k > 0) + + (self._decode_hparams.guess_and_check_epsilon >= 0) == 1) + + inputs_old = features["inputs"] + assert "targets" not in features + + assert len(features["inputs"].shape) in [3, 4] + if len(features["inputs"].shape) < 4: + features["inputs"] = tf.expand_dims(features["inputs"], 2) + + block_size = self._decode_hparams.block_size + decode_length += tf.shape(features["inputs"])[1] + + def while_exit_cond(result, length): # pylint: disable=unused-argument + return length < decode_length + + def infer_step(result, length): + """Inference step.""" + + def print_info(samples, result, length, new_length): + tf.logging.info( + "length=%s new_length=%s length_diff=%s samples-result=%s", + length, + new_length, + new_length - length, + np.array_str( + samples[0, -block_size-1:-1, 0, 0] - + result[0, -block_size:, 0, 0] + ).replace("\n", ""), + ) + + features["targets"] = tf.pad(result, [[0, 0], [0, 1], [0, 0], [0, 0]]) + samples, logits, losses = self.sample(features) # pylint: disable=unused-variable + + _, top_k_indices = tf.nn.top_k( + logits[:, :-1, :1, :, :], + k=self._decode_hparams.guess_and_check_top_k) 
@registry.register_hparams
def image_transformer2d_base():
  """Base hyperparameter set for the 2d image transformer models.

  Starts from `common_hparams.basic_params1()`, overrides the generic training
  settings, installs the target bottom/top modalities and registers the
  model-specific hyperparameters (attention, image size, local attention
  geometry, layer counts) with `add_hparam`.

  Returns:
    The populated hparams object.
  """
  hparams = common_hparams.basic_params1()
  # Generic model / optimisation settings.
  hparams.hidden_size = 512
  hparams.batch_size = 1
  hparams.max_length = 256
  hparams.dropout = 0.0
  hparams.clip_grad_norm = 0.  # i.e. no gradient clipping
  hparams.optimizer_adam_epsilon = 1e-9
  hparams.learning_rate_decay_scheme = "noam"
  hparams.learning_rate = 0.1
  hparams.learning_rate_warmup_steps = 4000
  hparams.initializer_gain = 0.2
  hparams.initializer = "uniform_unit_scaling"
  hparams.weight_decay = 0.0
  hparams.optimizer_adam_beta1 = 0.9
  hparams.optimizer_adam_beta2 = 0.98
  hparams.label_smoothing = 0.0
  # Targets: channel-embedding bottom, identity top.
  hparams.bottom["targets"] = modalities.make_targets_bottom(
      modalities.image_channel_embeddings_bottom)
  hparams.top["targets"] = modalities.identity_top
  hparams.norm_type = "layer"
  hparams.layer_prepostprocess_dropout = 0.0
  hparams.add_hparam("filter_size", 512)  # Add new ones like this.

  # attention-related flags
  hparams.add_hparam("num_heads", 8)
  hparams.add_hparam("attention_key_channels", 0)
  hparams.add_hparam("attention_value_channels", 0)
  hparams.add_hparam("ffn_layer", "conv_hidden_relu")
  # All hyperparameters ending in "dropout" are automatically set to 0.0
  # when not in training mode.
  hparams.add_hparam("attention_dropout", 0.0)
  hparams.add_hparam("relu_dropout", 0.0)
  hparams.add_hparam("pos", "timing")  # timing, none
  hparams.add_hparam("nbr_decoder_problems", 1)
  hparams.add_hparam("num_output_layers", 3)
  # Number of positions predicted in parallel by the block-parallel model.
  hparams.add_hparam("block_size", 1)

  # image size related flags
  # assuming that the image has same height and width
  hparams.add_hparam("img_len", 32)
  hparams.add_hparam("num_channels", 3)
  # Local attention params
  hparams.add_hparam("local_and_global_att", False)
  hparams.add_hparam("block_length", 256)
  hparams.add_hparam("block_width", 128)
  # Local 2D attention params
  hparams.add_hparam("query_shape", (16, 16))
  hparams.add_hparam("memory_flange", (16, 32))
  hparams.add_hparam("num_encoder_layers", 4)
  hparams.add_hparam("num_decoder_layers", 8)
  # attention type related params
  hparams.add_hparam("enc_attention_type", cia.AttentionType.GLOBAL)
  hparams.add_hparam("dec_attention_type", cia.AttentionType.LOCAL_2D)
  hparams.add_hparam("block_raster_scan", False)

  # multipos attention params
  hparams.add_hparam("q_filter_width", 1)
  hparams.add_hparam("kv_filter_width", 1)

  hparams.add_hparam("unconditional", False)  # unconditional generation

  # relative embedding hparams
  hparams.add_hparam("shared_rel", False)
  return hparams
@registry.register_hparams
def imagetransformer_base_10l_8h_big_uncond_dr03_dan_64_2d():
  """Big local-2d model for unconditional generation on 64x64 images.

  Built on `image_transformer2d_base` with LOCAL_2D decoder attention,
  "none"/"dan" pre/post-processing and a maximum length of 14000.

  NOTE(review): the name advertises "10l" and "dr03", but this function does
  not set `num_decoder_layers` and uses a pre/post-process dropout of 0.1 --
  confirm which values are intended.

  Returns:
    The modified hparams object.
  """
  hparams = image_transformer2d_base()
  hparams.unconditional = True
  hparams.hidden_size = 512
  hparams.batch_size = 1
  hparams.img_len = 64
  hparams.num_heads = 8
  hparams.filter_size = 2048
  # A single assignment: the earlier 3075 was immediately overwritten.
  hparams.max_length = 14000
  hparams.layer_preprocess_sequence = "none"
  hparams.layer_postprocess_sequence = "dan"
  hparams.layer_prepostprocess_dropout = 0.1
  hparams.dec_attention_type = cia.AttentionType.LOCAL_2D
  hparams.query_shape = (16, 16)
  hparams.memory_flange = (8, 8)
  return hparams
@registry.register_hparams
def imagetransformer2d_base_12l_8_64_64by64():
  """Hparams for the 12 layer big 2d model for imagenet 64x64.

  Returns:
    `image_transformer2d_base` with 12 decoder layers, an (8, 64) query shape
    and a (4, 32) memory flange.
  """
  config = image_transformer2d_base()

  # Model size.
  config.num_decoder_layers = 12
  config.num_heads = 8
  config.hidden_size = 512
  config.filter_size = 2048

  # Local 2d attention geometry.
  config.query_shape = (8, 64)
  config.memory_flange = (4, 32)

  # Training / data settings.
  config.batch_size = 1
  config.max_length = 14000
  config.layer_prepostprocess_dropout = 0.1
  # Stored as the integer 0 rather than the bool False.
  config.unconditional = int(False)
  return config
@registry.register_hparams
def img2img_transformer2d_q2():
  """`img2img_transformer2d_q1` with a (16, 32) memory flange.

  The remaining assignments restate values that `img2img_transformer2d_q1`
  already sets.

  Returns:
    The modified hparams object.
  """
  config = img2img_transformer2d_q1()
  restated = (
      ("batch_size", 2),
      ("layer_preprocess_sequence", "none"),
      ("layer_postprocess_sequence", "dan"),
      ("query_shape", (16, 16)),
  )
  for field, value in restated:
    setattr(config, field, value)
  # The one setting that differs from q1.
  config.memory_flange = (16, 32)
  return config
@registry.register_hparams
def img2img_transformer_b3():
  """Current best hparams for local 1d.

  Returns:
    `img2img_transformer_base` with batch size 2, "none"/"dan" layer
    pre/post-processing, block length 128 and sampling temperature 0.9.
  """
  best = img2img_transformer_base()
  best.block_length = 128
  best.sampling_temp = 0.9
  best.batch_size = 2
  best.layer_postprocess_sequence = "dan"
  best.layer_preprocess_sequence = "none"
  return best
def update_hparams_for_tpu(hparams):
  """Overwrite pad-remover, optimizer and batch-size settings for TPU runs.

  The object is modified in place; nothing is returned.

  Args:
    hparams: hyperparameter object whose attributes are overwritten.
  """
  tpu_settings = (
      # The pad remover is disabled: "where" op not supported.
      ("use_pad_remover", False),
      ("optimizer", "true_adam"),
      ("batch_size", 4),
  )
  for field, value in tpu_settings:
    setattr(hparams, field, value)
hparams.num_decoder_layers = 8 + hparams.num_encoder_layers = 4 + hparams.shared_embedding_and_softmax_weights = False + return hparams + + +@registry.register_hparams +def img2img_transformer_tiny_tpu(): + hparams = img2img_transformer_base_tpu() + hparams.num_hidden_layers = 2 + hparams.hidden_size = 16 + hparams.batch_size = 2 + hparams.num_heads = 2 + return hparams + + +@registry.register_hparams +def img2img_transformer2d_n3(): + hparams = img2img_transformer2d_base() + hparams.batch_size = 1 + hparams.num_encoder_layers = 4 + hparams.num_decoder_layers = 12 + hparams.query_shape = (16, 32) + hparams.memory_flange = (16, 16) + hparams.layer_prepostprocess_dropout = 0.0 + return hparams + + +@registry.register_hparams +def img2img_transformer2d_n31(): + """Set of hyperparameters.""" + hparams = img2img_transformer2d_base() + hparams.batch_size = 1 + hparams.num_encoder_layers = 6 + hparams.num_decoder_layers = 12 + hparams.num_heads = 8 + hparams.query_shape = (16, 32) + hparams.memory_flange = (16, 32) + return hparams + + +@registry.register_hparams +def img2img_transformer2d_n24(): + """Set of hyperparameters.""" + hparams = img2img_transformer2d_base() + hparams.batch_size = 1 + hparams.hidden_size = 1024 + hparams.filter_size = 2048 + hparams.layer_prepostprocess_dropout = 0.2 + hparams.num_decoder_layers = 8 + hparams.query_shape = (8, 16) + hparams.memory_flange = (8, 32) + return hparams + + +@registry.register_hparams +def img2img_transformer2d_n44(): + hparams = img2img_transformer2d_base() + hparams.batch_size = 1 + hparams.num_decoder_layers = 8 + hparams.query_shape = (8, 16) + hparams.memory_flange = (8, 32) + hparams.layer_prepostprocess_dropout = 0.1 + return hparams + + +@registry.register_hparams +def img2img_transformer2d_n103(): + """Best config for img2img.""" + hparams = img2img_transformer2d_base() + hparams.batch_size = 1 + hparams.num_decoder_layers = 12 + hparams.num_encoder_layers = 6 + hparams.query_shape = (8, 32) + 
hparams.memory_flange = (8, 64) + hparams.layer_prepostprocess_dropout = 0.1 + return hparams + + +@registry.register_hparams +def img2img_transformer2d_tiny(): + """Tiny params.""" + hparams = img2img_transformer2d_base() + hparams.num_decoder_layers = 2 + hparams.hidden_size = 128 + hparams.batch_size = 4 + hparams.max_length = 128 + hparams.attention_key_channels = hparams.attention_value_channels = 0 + hparams.filter_size = 128 + hparams.num_heads = 4 + hparams.pos = "timing" + hparams.img_len = 32 + return hparams + + +@registry.register_hparams +def img2img_transformer_tiny(): + """Tiny params.""" + hparams = img2img_transformer2d_base() + hparams.num_hidden_layers = 2 + hparams.hidden_size = 128 + hparams.batch_size = 4 + hparams.max_length = 128 + hparams.attention_key_channels = hparams.attention_value_channels = 0 + hparams.filter_size = 128 + hparams.num_heads = 1 + hparams.pos = "timing" + return hparams diff --git a/trax/models/image_transformer_2d_test.py b/trax/models/image_transformer_2d_test.py new file mode 100644 index 000000000..de3e73837 --- /dev/null +++ b/trax/models/image_transformer_2d_test.py @@ -0,0 +1,88 @@ +# coding=utf-8 +# Copyright 2023 The Tensor2Tensor Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Tests for Transformer.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +import numpy as np + +from tensor2tensor.data_generators import celeba # pylint: disable=unused-import +from tensor2tensor.data_generators import problem_hparams +from tensor2tensor.models import image_transformer_2d +from tensor2tensor.utils import registry + +import tensorflow.compat.v1 as tf +from tensorflow.compat.v1 import estimator as tf_estimator + + +class Img2imgTransformerTest(tf.test.TestCase): + + def _test_img2img_transformer(self, net): + batch_size = 3 + hparams = image_transformer_2d.img2img_transformer2d_tiny() + hparams.data_dir = "" + p_hparams = registry.problem("image_celeba").get_hparams(hparams) + inputs = np.random.randint(256, size=(batch_size, 4, 4, 3)) + targets = np.random.randint(256, size=(batch_size, 8, 8, 3)) + with self.test_session() as session: + features = { + "inputs": tf.constant(inputs, dtype=tf.int32), + "targets": tf.constant(targets, dtype=tf.int32), + "target_space_id": tf.constant(1, dtype=tf.int32), + } + model = net(hparams, tf_estimator.ModeKeys.TRAIN, p_hparams) + logits, _ = model(features) + session.run(tf.global_variables_initializer()) + res = session.run(logits) + self.assertEqual(res.shape, (batch_size, 8, 8, 3, 256)) + + def testImg2imgTransformer(self): + self._test_img2img_transformer(image_transformer_2d.Img2imgTransformer) + + +class Imagetransformer2dTest(tf.test.TestCase): + + def _test_imagetransformer_2d(self, net): + batch_size = 3 + size = 7 + vocab_size = 256 + hparams = image_transformer_2d.imagetransformer2d_tiny() + p_hparams = problem_hparams.test_problem_hparams(vocab_size, + vocab_size, + hparams) + inputs = np.random.randint( + vocab_size, size=(batch_size, 1, 1, 1)) + targets = np.random.randint( + vocab_size, size=(batch_size, size, size, 3)) + with self.test_session() as session: + features = { + "inputs": tf.constant(inputs, dtype=tf.int32), + 
"targets": tf.constant(targets, dtype=tf.int32), + "target_space_id": tf.constant(1, dtype=tf.int32), + } + model = net(hparams, tf_estimator.ModeKeys.TRAIN, p_hparams) + logits, _ = model(features) + session.run(tf.global_variables_initializer()) + res = session.run(logits) + self.assertEqual(res.shape, (batch_size, size, size, 3, vocab_size)) + + def testImagetransformer2d(self): + self._test_imagetransformer_2d(image_transformer_2d.Imagetransformer2d) + + +if __name__ == "__main__": + tf.test.main() diff --git a/trax/models/image_transformer_test.py b/trax/models/image_transformer_test.py new file mode 100644 index 000000000..6dde81d5e --- /dev/null +++ b/trax/models/image_transformer_test.py @@ -0,0 +1,71 @@ +# coding=utf-8 +# Copyright 2023 The Tensor2Tensor Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Tests for Transformer.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from absl.testing import parameterized +import numpy as np + +from tensor2tensor.data_generators import problem_hparams +from tensor2tensor.layers import common_image_attention +from tensor2tensor.models import image_transformer + +import tensorflow.compat.v1 as tf +from tensorflow.compat.v1 import estimator as tf_estimator + + +class ImagetransformerTest(parameterized.TestCase, tf.test.TestCase): + + @parameterized.named_parameters( + ("ImageTransformerCat", + image_transformer.Imagetransformer, + image_transformer.imagetransformer_tiny()), + ("ImageTransformerDmol", + image_transformer.Imagetransformer, + image_transformer.imagetransformerpp_tiny()), + ) + def testImagetransformer(self, net, hparams): + batch_size = 3 + size = 7 + vocab_size = 256 + p_hparams = problem_hparams.test_problem_hparams(vocab_size, + vocab_size, + hparams) + inputs = np.random.randint( + vocab_size, size=(batch_size, 1, 1, 1)) + targets = np.random.randint( + vocab_size, size=(batch_size, size, size, 3)) + with self.test_session() as session: + features = { + "inputs": tf.constant(inputs, dtype=tf.int32), + "targets": tf.constant(targets, dtype=tf.int32), + "target_space_id": tf.constant(1, dtype=tf.int32), + } + model = net(hparams, tf_estimator.ModeKeys.TRAIN, p_hparams) + logits, _ = model(features) + session.run(tf.global_variables_initializer()) + res = session.run(logits) + if hparams.likelihood == common_image_attention.DistributionType.CAT: + expected = (batch_size, size, size, 3, vocab_size) + else: + expected = (batch_size, size, size, hparams.num_mixtures * 10) + self.assertEqual(res.shape, expected) + +if __name__ == "__main__": + tf.test.main() diff --git a/trax/models/lstm.py b/trax/models/lstm.py new file mode 100644 index 000000000..f59dabb19 --- /dev/null +++ b/trax/models/lstm.py @@ -0,0 +1,524 @@ +# coding=utf-8 +# 
Copyright 2023 The Tensor2Tensor Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""RNN LSTM models.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import copy +from tensor2tensor.layers import area_attention +from tensor2tensor.layers import common_hparams +from tensor2tensor.layers import common_layers +from tensor2tensor.utils import contrib +from tensor2tensor.utils import registry +from tensor2tensor.utils import t2t_model + +import tensorflow.compat.v1 as tf +from tensorflow.compat.v1 import estimator as tf_estimator + + +def _dropout_lstm_cell(hparams, train): + return tf.nn.rnn_cell.DropoutWrapper( + tf.nn.rnn_cell.LSTMCell(hparams.hidden_size), + input_keep_prob=1.0 - hparams.dropout * tf.to_float(train)) + + +def lstm(inputs, sequence_length, hparams, train, name, initial_state=None): + """Adds a stack of LSTM layers on top of input. + + Args: + inputs: The input `Tensor`, shaped `[batch_size, time_steps, hidden_size]`. + sequence_length: Lengths of the actual input sequence, excluding padding; a + `Tensor` shaped `[batch_size]`. + hparams: HParams; hyperparameters. + train: bool; `True` when constructing training graph to enable dropout. + name: string; Create variable names under this scope. + initial_state: tuple of `LSTMStateTuple`s; the initial state of each layer. 
+
+  Returns:
+    A tuple (outputs, states), where:
+      outputs: The output `Tensor`, shaped `[batch_size, time_steps,
+        hidden_size]`.
+      states: A tuple of `LSTMStateTuple`s; the final state of each layer.
+        Bidirectional LSTM returns a concatenation of last forward and backward
+        state, reduced to the original dimensionality.
+  """
+  layers = [_dropout_lstm_cell(hparams, train)
+            for _ in range(hparams.num_hidden_layers)]
+  with tf.variable_scope(name):
+    return tf.nn.dynamic_rnn(
+        tf.nn.rnn_cell.MultiRNNCell(layers),
+        inputs,
+        sequence_length,
+        initial_state=initial_state,
+        dtype=tf.float32,
+        time_major=False)
+
+
+def lstm_attention_decoder(inputs, hparams, train, name, initial_state,
+                           encoder_outputs, encoder_output_length,
+                           decoder_input_length):
+  """Run LSTM cell with attention on inputs of shape [batch x time x size].
+
+  Args:
+    inputs: The decoder input `Tensor`, shaped `[batch_size, decoder_steps,
+      hidden_size]`.
+    hparams: HParams; hyperparameters.
+    train: bool; `True` when constructing training graph to enable dropout.
+    name: string; Create variable names under this scope.
+    initial_state: Tuple of `LSTMStateTuple`s; the initial state of each layer.
+    encoder_outputs: Encoder outputs; a `Tensor` shaped `[batch_size,
+      encoder_steps, hidden_size]`.
+    encoder_output_length: Lengths of the actual encoder outputs, excluding
+      padding; a `Tensor` shaped `[batch_size]`.
+    decoder_input_length: Lengths of the actual decoder inputs, excluding
+      padding; a `Tensor` shaped `[batch_size]`.
+
+  Raises:
+    ValueError: If the hparams.attention_mechanism is anything other than
+      luong or bahdanau.
+
+  Returns:
+    The decoder output `Tensor`, shaped `[batch_size, decoder_steps,
+    hidden_size]`.
+  """
+  layers = [_dropout_lstm_cell(hparams, train)
+            for _ in range(hparams.num_hidden_layers)]
+  if hparams.attention_mechanism == "luong":
+    attention_mechanism_class = contrib.seq2seq().LuongAttention
+  elif hparams.attention_mechanism == "bahdanau":
+    attention_mechanism_class = contrib.seq2seq().BahdanauAttention
+  else:
+    raise ValueError("Unknown hparams.attention_mechanism = %s, must be "
+                     "luong or bahdanau." % hparams.attention_mechanism)
+  if hparams.get("max_area_width", 1) > 1:
+    def _area_key_value_fn(keys, values):
+      """Custom fn for computing area keys and values."""
+      tf.logging.info("max_area_width=%d, area_key_mode=%s, area_value_mode=%s",
+                      hparams.get("max_area_width", 1),
+                      hparams.get("area_key_mode", "none"),
+                      hparams.get("area_value_mode", "none"))
+      keys = area_attention.compute_area_key(
+          keys, max_area_width=hparams.get("max_area_width", 1),
+          mode=hparams.get("area_key_mode", "none"), name="decoder_encoder",
+          training=(hparams.mode == tf_estimator.ModeKeys.TRAIN))
+      if hparams.get("area_value_mode", "none") == "sum":
+        _, _, values, _, _ = area_attention.compute_area_features(
+            values, max_area_width=hparams.get("max_area_width", 1))
+      elif hparams.get("area_value_mode", "none") == "mean":
+        values, _, _, _, _ = area_attention.compute_area_features(
+            values, max_area_width=hparams.get("max_area_width", 1))
+      else:
+        raise ValueError(
+            "Unsupported area_value_mode: %s" % hparams.get(
+                "area_value_mode", "none"))
+      return keys, values
+    area_mask = area_attention.lengths_to_area_mask(
+        feature_length=encoder_output_length,
+        length=common_layers.shape_list(encoder_outputs)[1],
+        max_area_size=hparams.get("max_area_width", 1))  # int, not "1"
+    def _area_prob_fn(score):
+      alignments = tf.nn.softmax(score)
+      alignments = tf.where(area_mask, alignments, tf.zeros_like(alignments))
+      alignments = tf.div(alignments, tf.reduce_sum(
+          alignments, axis=-1, keepdims=True))
+      return alignments
+    attention_mechanism = attention_mechanism_class(
+        hparams.hidden_size, encoder_outputs,
+        memory_sequence_length=None,
+        probability_fn=_area_prob_fn,
+        custom_key_value_fn=_area_key_value_fn)
+  else:
+    attention_mechanism = attention_mechanism_class(hparams.hidden_size,
+                                                    encoder_outputs)
+  cell = contrib.seq2seq().AttentionWrapper(
+      tf.nn.rnn_cell.MultiRNNCell(layers),
+      [attention_mechanism] * hparams.num_heads,
+      attention_layer_size=[hparams.attention_layer_size] * hparams.num_heads,
+      output_attention=(hparams.output_attention == 1))
+
+  batch_size = common_layers.shape_list(inputs)[0]
+
+  initial_state = cell.zero_state(batch_size, tf.float32).clone(
+      cell_state=initial_state)
+
+  with tf.variable_scope(name):
+    output, _ = tf.nn.dynamic_rnn(
+        cell,
+        inputs,
+        decoder_input_length,
+        initial_state=initial_state,
+        dtype=tf.float32,
+        time_major=False)
+    # output is [batch_size, decoder_steps, attention_size], where
+    # attention_size is either hparams.hidden_size (when
+    # hparams.output_attention is 0) or hparams.attention_layer_size (when
+    # hparams.output_attention is 1) times the number of attention heads.
+    #
+    # For multi-head attention project output back to hidden size.
+    if hparams.output_attention == 1 and hparams.num_heads > 1:
+      output = tf.layers.dense(output, hparams.hidden_size)
+
+    return output
+
+
+def lstm_seq2seq_internal(inputs, targets, hparams, train):
+  """The basic LSTM seq2seq model, main step used for training."""
+  with tf.variable_scope("lstm_seq2seq"):
+    if inputs is not None:
+      inputs_length = common_layers.length_from_embedding(inputs)
+      # Flatten inputs.
+      inputs = common_layers.flatten4d3d(inputs)
+
+      # LSTM encoder.
+      inputs = tf.reverse_sequence(inputs, inputs_length, seq_axis=1)
+      _, final_encoder_state = lstm(inputs, inputs_length, hparams, train,
+                                    "encoder")
+    else:
+      final_encoder_state = None
+
+    # LSTM decoder.
+ shifted_targets = common_layers.shift_right(targets) + # Add 1 to account for the padding added to the left from shift_right + targets_length = common_layers.length_from_embedding(shifted_targets) + 1 + decoder_outputs, _ = lstm( + common_layers.flatten4d3d(shifted_targets), + targets_length, + hparams, + train, + "decoder", + initial_state=final_encoder_state) + return tf.expand_dims(decoder_outputs, axis=2) + + +def lstm_seq2seq_internal_attention(inputs, targets, hparams, train, + inputs_length, targets_length): + """LSTM seq2seq model with attention, main step used for training.""" + with tf.variable_scope("lstm_seq2seq_attention"): + # Flatten inputs. + inputs = common_layers.flatten4d3d(inputs) + + # LSTM encoder. + inputs = tf.reverse_sequence(inputs, inputs_length, seq_axis=1) + encoder_outputs, final_encoder_state = lstm( + inputs, inputs_length, hparams, train, "encoder") + + # LSTM decoder with attention. + shifted_targets = common_layers.shift_right(targets) + # Add 1 to account for the padding added to the left from shift_right + targets_length = targets_length + 1 + decoder_outputs = lstm_attention_decoder( + common_layers.flatten4d3d(shifted_targets), hparams, train, "decoder", + final_encoder_state, encoder_outputs, inputs_length, targets_length) + return tf.expand_dims(decoder_outputs, axis=2) + + +def lstm_bid_encoder(inputs, sequence_length, hparams, train, name): + """Bidirectional LSTM for encoding inputs that are [batch x time x size].""" + + with tf.variable_scope(name): + cell_fw = tf.nn.rnn_cell.MultiRNNCell( + [_dropout_lstm_cell(hparams, train) + for _ in range(hparams.num_hidden_layers)]) + + cell_bw = tf.nn.rnn_cell.MultiRNNCell( + [_dropout_lstm_cell(hparams, train) + for _ in range(hparams.num_hidden_layers)]) + + ((encoder_fw_outputs, encoder_bw_outputs), + (encoder_fw_state, encoder_bw_state)) = tf.nn.bidirectional_dynamic_rnn( + cell_fw, + cell_bw, + inputs, + sequence_length, + dtype=tf.float32, + time_major=False) + + 
encoder_outputs = tf.concat((encoder_fw_outputs, encoder_bw_outputs), 2) + encoder_states = [] + + for i in range(hparams.num_hidden_layers): + if isinstance(encoder_fw_state[i], tf.nn.rnn_cell.LSTMStateTuple): + encoder_state_c = tf.concat( + values=(encoder_fw_state[i].c, encoder_bw_state[i].c), + axis=1, + name="encoder_fw_state_c") + encoder_state_h = tf.concat( + values=(encoder_fw_state[i].h, encoder_bw_state[i].h), + axis=1, + name="encoder_fw_state_h") + encoder_state = tf.nn.rnn_cell.LSTMStateTuple( + c=encoder_state_c, h=encoder_state_h) + elif isinstance(encoder_fw_state[i], tf.Tensor): + encoder_state = tf.concat( + values=(encoder_fw_state[i], encoder_bw_state[i]), + axis=1, + name="bidirectional_concat") + + encoder_states.append(encoder_state) + + encoder_states = tuple(encoder_states) + return encoder_outputs, encoder_states + + +def lstm_seq2seq_internal_bid_encoder(inputs, targets, hparams, train): + """The basic LSTM seq2seq model with bidirectional encoder.""" + with tf.variable_scope("lstm_seq2seq_bid_encoder"): + if inputs is not None: + inputs_length = common_layers.length_from_embedding(inputs) + # Flatten inputs. + inputs = common_layers.flatten4d3d(inputs) + # LSTM encoder. + _, final_encoder_state = lstm_bid_encoder( + inputs, inputs_length, hparams, train, "encoder") + else: + inputs_length = None + final_encoder_state = None + # LSTM decoder. 
+ shifted_targets = common_layers.shift_right(targets) + # Add 1 to account for the padding added to the left from shift_right + targets_length = common_layers.length_from_embedding(shifted_targets) + 1 + hparams_decoder = copy.copy(hparams) + hparams_decoder.hidden_size = 2 * hparams.hidden_size + decoder_outputs, _ = lstm( + common_layers.flatten4d3d(shifted_targets), + targets_length, + hparams_decoder, + train, + "decoder", + initial_state=final_encoder_state) + return tf.expand_dims(decoder_outputs, axis=2) + + +def lstm_seq2seq_internal_attention_bid_encoder(inputs, targets, hparams, + train): + """LSTM seq2seq model with attention, main step used for training.""" + with tf.variable_scope("lstm_seq2seq_attention_bid_encoder"): + inputs_length = common_layers.length_from_embedding(inputs) + # Flatten inputs. + inputs = common_layers.flatten4d3d(inputs) + # LSTM encoder. + encoder_outputs, final_encoder_state = lstm_bid_encoder( + inputs, inputs_length, hparams, train, "encoder") + # LSTM decoder with attention + shifted_targets = common_layers.shift_right(targets) + # Add 1 to account for the padding added to the left from shift_right + targets_length = common_layers.length_from_embedding(shifted_targets) + 1 + hparams_decoder = copy.copy(hparams) + hparams_decoder.hidden_size = 2 * hparams.hidden_size + decoder_outputs = lstm_attention_decoder( + common_layers.flatten4d3d(shifted_targets), hparams_decoder, train, + "decoder", final_encoder_state, encoder_outputs, + inputs_length, targets_length) + return tf.expand_dims(decoder_outputs, axis=2) + + +@registry.register_model +class LSTMEncoder(t2t_model.T2TModel): + """LSTM encoder only.""" + + def body(self, features): + if self._hparams.initializer == "orthogonal": + raise ValueError("LSTM models fail with orthogonal initializer.") + train = self._hparams.mode == tf_estimator.ModeKeys.TRAIN + inputs = features.get("inputs") + inputs_length = common_layers.length_from_embedding(inputs) + # Flatten inputs. 
+ inputs = common_layers.flatten4d3d(inputs) + # LSTM encoder. + inputs = tf.reverse_sequence(inputs, inputs_length, seq_axis=1) + encoder_output, _ = lstm(inputs, inputs_length, self._hparams, train, + "encoder") + return tf.expand_dims(encoder_output, axis=2) + + +@registry.register_model +class LSTMSeq2seq(t2t_model.T2TModel): + + def body(self, features): + # TODO(lukaszkaiser): investigate this issue and repair. + if self._hparams.initializer == "orthogonal": + raise ValueError("LSTM models fail with orthogonal initializer.") + train = self._hparams.mode == tf_estimator.ModeKeys.TRAIN + return lstm_seq2seq_internal(features.get("inputs"), features["targets"], + self._hparams, train) + + +@registry.register_model +class LSTMSeq2seqAttention(t2t_model.T2TModel): + """Seq to seq LSTM with attention.""" + + def body(self, features): + # TODO(lukaszkaiser): investigate this issue and repair. + if self._hparams.initializer == "orthogonal": + raise ValueError("LSTM models fail with orthogonal initializer.") + train = self._hparams.mode == tf_estimator.ModeKeys.TRAIN + # This is a temporary fix for varying-length sequences within in a batch. + # A more complete fix should pass a length tensor from outside so that + # all the lstm variants can use it. 
+ input_shape = common_layers.shape_list(features["inputs_raw"]) + flat_input = tf.reshape(features["inputs_raw"], + [input_shape[0], input_shape[1]]) + inputs_length = tf.reduce_sum(tf.minimum(flat_input, 1), -1) + target_shape = common_layers.shape_list(features["targets_raw"]) + flat_target = tf.reshape(features["targets_raw"], + [target_shape[0], target_shape[1]]) + targets_length = tf.reduce_sum(tf.minimum(flat_target, 1), -1) + tf.logging.info(self._hparams) + return lstm_seq2seq_internal_attention( + features["inputs"], features["targets"], self._hparams, train, + inputs_length, targets_length) + + +@registry.register_model +class LSTMSeq2seqBidirectionalEncoder(t2t_model.T2TModel): + + def body(self, features): + # TODO(lukaszkaiser): investigate this issue and repair. + if self._hparams.initializer == "orthogonal": + raise ValueError("LSTM models fail with orthogonal initializer.") + train = self._hparams.mode == tf_estimator.ModeKeys.TRAIN + return lstm_seq2seq_internal_bid_encoder( + features.get("inputs"), features["targets"], self._hparams, train) + + +@registry.register_model +class LSTMSeq2seqAttentionBidirectionalEncoder(t2t_model.T2TModel): + + def body(self, features): + # TODO(lukaszkaiser): investigate this issue and repair. 
+ if self._hparams.initializer == "orthogonal": + raise ValueError("LSTM models fail with orthogonal initializer.") + train = self._hparams.mode == tf_estimator.ModeKeys.TRAIN + return lstm_seq2seq_internal_attention_bid_encoder( + features.get("inputs"), features["targets"], self._hparams, train) + + +@registry.register_hparams +def lstm_seq2seq(): + """hparams for LSTM.""" + hparams = common_hparams.basic_params1() + hparams.daisy_chain_variables = False + hparams.batch_size = 1024 + hparams.hidden_size = 128 + hparams.num_hidden_layers = 2 + hparams.initializer = "uniform_unit_scaling" + hparams.initializer_gain = 1.0 + hparams.weight_decay = 0.0 + return hparams + + +def lstm_attention_base(): + """Base attention params.""" + hparams = lstm_seq2seq() + hparams.add_hparam("attention_layer_size", hparams.hidden_size) + hparams.add_hparam("output_attention", True) + hparams.add_hparam("num_heads", 1) + return hparams + + +@registry.register_hparams +def lstm_bahdanau_attention(): + """Hparams for LSTM with bahdanau attention.""" + hparams = lstm_attention_base() + hparams.add_hparam("attention_mechanism", "bahdanau") + return hparams + + +@registry.register_hparams +def lstm_luong_attention(): + """Hparams for LSTM with luong attention.""" + hparams = lstm_attention_base() + hparams.add_hparam("attention_mechanism", "luong") + return hparams + + +@registry.register_hparams +def lstm_attention(): + """For backwards compatibility, defaults to bahdanau.""" + return lstm_bahdanau_attention() + + +@registry.register_hparams +def lstm_bahdanau_attention_multi(): + """Multi-head Bahdanau attention.""" + hparams = lstm_bahdanau_attention() + hparams.num_heads = 4 + return hparams + + +@registry.register_hparams +def lstm_luong_attention_multi(): + """Multi-head Luong attention.""" + hparams = lstm_luong_attention() + hparams.num_heads = 4 + return hparams + + +@registry.register_hparams +def lstm_asr_v1(): + """Basic LSTM Params.""" + hparams = lstm_bahdanau_attention() + 
hparams.num_hidden_layers = 2 + hparams.hidden_size = 256 + hparams.batch_size = 36 + hparams.max_input_seq_length = 600000 + hparams.max_target_seq_length = 350 + hparams.max_length = hparams.max_input_seq_length + hparams.min_length_bucket = hparams.max_input_seq_length // 2 + hparams.learning_rate = 0.05 + return hparams + + +@registry.register_hparams +def lstm_area_attention_base(): + """Hparams for LSTM with area attention.""" + hparams = lstm_luong_attention() + hparams.batch_size = 16384 + hparams.num_hidden_layers = 2 + hparams.hidden_size = 1024 + hparams.num_heads = 4 + hparams.dropout = 0.2 + hparams.learning_rate = 0.1 + hparams.max_area_width = 2 + hparams.area_key_mode = "mean" + hparams.area_value_mode = "sum" + return hparams + + +@registry.register_hparams +def lstm_area_attention_enfr(): + """Hparams for LSTM with area attention.""" + hparams = lstm_area_attention_base() + hparams.dropout = 0.1 + return hparams + + +@registry.register_hparams +def lstm_area_attention_char(): + """Hparams for LSTM with area attention.""" + hparams = lstm_area_attention_base() + hparams.batch_size = 20480 + return hparams + + +@registry.register_hparams +def lstm_area_attention_char_enfr(): + """Hparams for LSTM with area attention.""" + hparams = lstm_area_attention_char() + hparams.dropout = 0.1 + return hparams diff --git a/trax/models/lstm_test.py b/trax/models/lstm_test.py new file mode 100644 index 000000000..4723998db --- /dev/null +++ b/trax/models/lstm_test.py @@ -0,0 +1,120 @@ +# coding=utf-8 +# Copyright 2023 The Tensor2Tensor Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""LSTMSeq2Seq models tests.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +import numpy as np + +from tensor2tensor.data_generators import problem_hparams +from tensor2tensor.models import lstm + +import tensorflow.compat.v1 as tf +from tensorflow.compat.v1 import estimator as tf_estimator + + +class LSTMTest(tf.test.TestCase): + + def testLSTMSeq2Seq(self): + vocab_size = 9 + x = np.random.randint(1, high=vocab_size, size=(3, 5, 1, 1)) + y = np.random.randint(1, high=vocab_size, size=(3, 6, 1, 1)) + hparams = lstm.lstm_seq2seq() + p_hparams = problem_hparams.test_problem_hparams(vocab_size, + vocab_size, + hparams) + with self.test_session() as session: + features = { + "inputs": tf.constant(x, dtype=tf.int32), + "targets": tf.constant(y, dtype=tf.int32), + } + model = lstm.LSTMSeq2seq(hparams, tf_estimator.ModeKeys.TRAIN, + p_hparams) + logits, _ = model(features) + session.run(tf.global_variables_initializer()) + res = session.run(logits) + self.assertEqual(res.shape, (3, 6, 1, 1, vocab_size)) + + def testLSTMSeq2SeqAttention(self): + vocab_size = 9 + x = np.random.randint(1, high=vocab_size, size=(3, 5, 1, 1)) + y = np.random.randint(1, high=vocab_size, size=(3, 6, 1, 1)) + hparams = lstm.lstm_attention() + + p_hparams = problem_hparams.test_problem_hparams(vocab_size, + vocab_size, + hparams) + x = tf.constant(x, dtype=tf.int32) + x = tf.placeholder_with_default(x, shape=[None, None, 1, 1]) + + with self.test_session() as session: + features = { + "inputs": x, + "targets": 
tf.constant(y, dtype=tf.int32), + } + model = lstm.LSTMSeq2seqAttention( + hparams, tf_estimator.ModeKeys.TRAIN, p_hparams) + logits, _ = model(features) + session.run(tf.global_variables_initializer()) + res = session.run(logits) + self.assertEqual(res.shape, (3, 6, 1, 1, vocab_size)) + + def testLSTMSeq2seqBidirectionalEncoder(self): + vocab_size = 9 + x = np.random.randint(1, high=vocab_size, size=(3, 5, 1, 1)) + y = np.random.randint(1, high=vocab_size, size=(3, 6, 1, 1)) + hparams = lstm.lstm_seq2seq() + p_hparams = problem_hparams.test_problem_hparams(vocab_size, + vocab_size, + hparams) + with self.test_session() as session: + features = { + "inputs": tf.constant(x, dtype=tf.int32), + "targets": tf.constant(y, dtype=tf.int32), + } + model = lstm.LSTMSeq2seqBidirectionalEncoder( + hparams, tf_estimator.ModeKeys.TRAIN, p_hparams) + logits, _ = model(features) + session.run(tf.global_variables_initializer()) + res = session.run(logits) + self.assertEqual(res.shape, (3, 6, 1, 1, vocab_size)) + + def testLSTMSeq2seqAttentionBidirectionalEncoder(self): + vocab_size = 9 + x = np.random.randint(1, high=vocab_size, size=(3, 5, 1, 1)) + y = np.random.randint(1, high=vocab_size, size=(3, 6, 1, 1)) + hparams = lstm.lstm_attention() + + p_hparams = problem_hparams.test_problem_hparams(vocab_size, vocab_size) + x = tf.constant(x, dtype=tf.int32) + x = tf.placeholder_with_default(x, shape=[None, None, 1, 1]) + + with self.test_session() as session: + features = { + "inputs": x, + "targets": tf.constant(y, dtype=tf.int32), + } + model = lstm.LSTMSeq2seqAttentionBidirectionalEncoder( + hparams, tf_estimator.ModeKeys.TRAIN, p_hparams) + logits, _ = model(features) + session.run(tf.global_variables_initializer()) + res = session.run(logits) + self.assertEqual(res.shape, (3, 6, 1, 1, vocab_size)) + + +if __name__ == "__main__": + tf.test.main() diff --git a/trax/models/mtf_image_transformer.py b/trax/models/mtf_image_transformer.py new file mode 100644 index 
# coding=utf-8
# Copyright 2023 The Tensor2Tensor Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Image Transformer model with model and data parallelism using MTF.

Integration of Mesh tensorflow with Image Transformer to do model parallelism.
Currently, this supports unconditional image generation. Specify a particular
architecture layout in the hparams that specifies how different dimensions are
split or replicated along the mesh dimensions.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import copy
import mesh_tensorflow as mtf

from tensor2tensor.layers import common_hparams
from tensor2tensor.layers import common_layers
from tensor2tensor.utils import mtf_model
from tensor2tensor.utils import registry
import tensorflow.compat.v1 as tf
from tensorflow.compat.v1 import estimator as tf_estimator


@registry.register_model
class MtfImageTransformer(mtf_model.MtfModel):
  """Image Transformer in mesh_tensorflow."""

  @property
  def inputs_vocab_dim(self):
    """Vocabulary dimension of the conditioning inputs (num_classes wide)."""
    assert self.has_input
    return mtf.Dimension("inputs_vocab", self._hparams.num_classes)

  @property
  def targets_vocab_dim(self):
    """Target vocabulary, rounded up to a multiple of `vocab_divisor`."""
    vocab_size = self._problem_hparams.vocab_size["targets"]
    if hasattr(self._hparams, "vocab_divisor"):
      vocab_size += (-vocab_size) % self._hparams.vocab_divisor
    return mtf.Dimension("vocab", vocab_size)

  @property
  def outputs_vocab_dim(self):
    """Output vocabulary dimension, fixed at 256 entries."""
    return mtf.Dimension("output_vocab", 256)

  @property
  def pos_dim(self):
    """Index dimension of the learned positional embedding tables."""
    return mtf.Dimension("pos", self._hparams.img_len)

  @property
  def rows_dim(self):
    """Image rows."""
    return mtf.Dimension("rows", self._hparams.img_len)

  @property
  def cols_dim(self):
    """Image columns with the channels folded in (img_len * num_channels)."""
    return mtf.Dimension(
        "cols", self._hparams.img_len*self._hparams.num_channels)

  @property
  def orig_cols_dim(self):
    """Image columns before the channels are folded in."""
    return mtf.Dimension("orig_cols", self._hparams.img_len)

  @property
  def channels_dim(self):
    """Colour channels."""
    return mtf.Dimension("channels", self._hparams.num_channels)

  @property
  def model_dim(self):
    """Hidden (d_model) dimension."""
    return mtf.Dimension("d_model", self._hparams.hidden_size)

  @property
  def max_length_dim(self):
    """Flattened sequence length under the name `max_length`."""
    return mtf.Dimension(
        "max_length",
        self._hparams.img_len*self._hparams.img_len*self._hparams.num_channels)

  @property
  def length_dim(self):
    """Flattened sequence length: img_len * img_len * num_channels."""
    return mtf.Dimension(
        "length",
        self._hparams.img_len*self._hparams.img_len*self._hparams.num_channels)

  @property
  def heads_dim(self):
    """Attention heads."""
    return mtf.Dimension("heads", self._hparams.num_heads)

  @property
  def kv_dim(self):
    """Per-head key/value depth."""
    return mtf.Dimension("d_kv", self._hparams.d_kv)

  @property
  def feedforward_dim(self):
    """Hidden width of the feed-forward sub-layer."""
    return mtf.Dimension("d_ff", self._hparams.d_ff)

  @property
  def activation_type(self):
    """tf.DType named by `hparams.activation_dtype`.

    Raises:
      ValueError: if the name is not float32, float16 or bfloat16.
    """
    hparams = self._hparams
    if hparams.activation_dtype == "float32":
      activation_dtype = tf.float32
    elif hparams.activation_dtype == "float16":
      activation_dtype = tf.float16
    elif hparams.activation_dtype == "bfloat16":
      activation_dtype = tf.bfloat16
    else:
      raise ValueError(
          "unknown hparams.activation_dtype %s" % hparams.activation_dtype)
    return activation_dtype

  def create_positional_emb_2d(self, targets):
    """Learned 2d positional embedding for images.

    Args:
      targets: an mtf.Tensor; only its mesh is used.

    Returns:
      An mtf.Tensor of shape [rows, cols, d_model]: the sum of a row embedding
      and a column embedding, each broadcast over the other axis.
    """
    mesh = targets.mesh

    positional_emb_rows_var = mtf.get_variable(
        mesh, "positional_emb_rows",
        mtf.Shape([self.pos_dim, self.model_dim]),
        initializer=tf.random_normal_initializer(),
        activation_dtype=self.activation_type)
    positional_emb_cols_var = mtf.get_variable(
        mesh, "positional_emb_cols",
        mtf.Shape([self.pos_dim, self.model_dim]),
        initializer=tf.random_normal_initializer(),
        activation_dtype=self.activation_type)

    targets_position_x = mtf.range(mesh, self.rows_dim, dtype=tf.int32)
    # NOTE(review): cols_dim spans img_len * num_channels positions while the
    # table only has img_len (pos_dim) rows -- confirm this is intended when
    # num_channels > 1.
    targets_position_y = mtf.range(mesh, self.cols_dim, dtype=tf.int32)
    position_x = mtf.broadcast(
        mtf.gather(positional_emb_rows_var, targets_position_x,
                   self.pos_dim),
        mtf.Shape([self.rows_dim, self.cols_dim, self.model_dim]))

    position_y = mtf.broadcast(
        mtf.gather(positional_emb_cols_var, targets_position_y,
                   self.pos_dim),
        mtf.Shape([self.rows_dim, self.cols_dim, self.model_dim]))
    return position_x + position_y

  def mtf_model_fn(self, features, mesh):
    """Builds the image model on `mesh` and returns logits and loss.

    Args:
      features: dict of tf.Tensors. "targets" is always read; "inputs" is read
        only when the model has inputs and `hparams.unconditional` is False.
      mesh: the mtf.Mesh on which every tensor and variable is created.

    Returns:
      A (logits, loss) pair of mtf.Tensors. Logits have shape
      [batch, rows, orig_cols, channels, output_vocab]; loss is the mean
      softmax cross-entropy against the unshifted targets.

    Raises:
      ValueError: if `hparams.attention_type` is not one of
        "local1d_spatial", "local2d_spatial" or "local1d", or if the
        activation dtype is unknown.
    """
    features = copy.copy(features)
    tf.logging.info("features = %s", features)
    hparams = self._hparams
    activation_dtype = self.activation_type

    # We assume fixed vocab size for targets
    targets = tf.to_int32(features["targets"])

    # Image preprocessing, reshape into a 1D sequence and shift right.
    length = hparams.img_len*hparams.img_len*hparams.num_channels
    targets = tf.reshape(targets, [hparams.batch_size, length])
    shifted_targets = common_layers.shift_right_2d(targets)

    # Declare all the dimensions
    batch_dim = mtf.Dimension("batch", hparams.batch_size)

    def import_to_batch_by_length(x, name):
      return mtf.import_tf_tensor(
          mesh, x, mtf.Shape([batch_dim, self.length_dim]), name=name)

    targets = import_to_batch_by_length(targets, "targets")
    shifted_targets = import_to_batch_by_length(
        shifted_targets, "shifted_targets")

    # Create targets content and position embeddings.
    # Create embedding var for targets and positions and do a gather.
    targets_embedding_var = mtf.get_variable(
        mesh, "targets_embedding",
        mtf.Shape([self.targets_vocab_dim, self.model_dim]),
        initializer=tf.random_normal_initializer(),
        activation_dtype=activation_dtype)

    # The decoder consumes the right-shifted sequence, so position i only
    # sees content from positions before i.
    x = mtf.gather(targets_embedding_var,
                   shifted_targets, self.targets_vocab_dim)

    # Add positional embeddings
    x += mtf.reshape(self.create_positional_emb_2d(targets),
                     [self.length_dim, self.model_dim])

    # If conditional and input is given, add the input embedding to the target.
    # TODO(nikip): Verify conditional.
    if self.has_input and not hparams.unconditional:
      inputs = tf.squeeze(tf.to_int32(features["inputs"]), [2, 3])
      inputs = import_to_batch_by_length(inputs, "inputs")

      # Input embeddings
      # NOTE(review): the arguments below look like those of mtf.get_variable
      # (mesh, name, shape) rather than mtf.layers.embedding -- confirm against
      # the mesh_tensorflow API before enabling the conditional path.
      inputs_embedding_var = mtf.layers.embedding(
          mesh, "input_embedding",
          mtf.Shape([self.inputs_vocab_dim, self.model_dim]),
          activation_dtype=activation_dtype)
      inputs_emb = mtf.gather(
          inputs_embedding_var, inputs, self.inputs_vocab_dim)
      x += inputs_emb

    # Image Transformer Decoder
    # [ self attention - ffn - residual + dropout] x n
    if hparams.attention_type == "local1d_spatial":
      decoder_output = local_attention1d_spatial_decoder(
          x, self.kv_dim, self.heads_dim, self.feedforward_dim, hparams)
    elif hparams.attention_type == "local2d_spatial":
      decoder_output = local_attention2d_spatial_decoder(
          x, self.kv_dim, self.heads_dim, self.feedforward_dim, hparams)
    elif hparams.attention_type == "local1d":
      decoder_output = local_attention1d_masked_decoder(
          x, self.kv_dim, self.heads_dim, self.feedforward_dim, hparams)
    else:
      raise ValueError("Invalid attention type.")

    # Calculate the logits and loss.
    logits = mtf.layers.dense(
        decoder_output, self.outputs_vocab_dim, name="logits")
    # Need a reshape for logits
    logits = mtf.reshape(
        logits, mtf.Shape([batch_dim, self.length_dim, self.outputs_vocab_dim]))
    soft_targets = mtf.one_hot(
        targets, self.outputs_vocab_dim, dtype=activation_dtype)
    loss = mtf.layers.softmax_cross_entropy_with_logits(
        logits, soft_targets, self.outputs_vocab_dim)
    loss = mtf.reduce_mean(loss)

    # Reshape logits to original target shape.
    logits = mtf.reshape(
        logits,
        mtf.Shape([batch_dim, self.rows_dim, self.orig_cols_dim,
                   self.channels_dim, self.outputs_vocab_dim]))

    return logits, loss


def _is_training(hparams):
  """True when `hparams.mode` is TRAIN; TRAIN is assumed when mode is unset."""
  mode = getattr(hparams, "mode", tf_estimator.ModeKeys.TRAIN)
  return mode == tf_estimator.ModeKeys.TRAIN


def layer_prepostprocess_dropout(x, hparams):
  """Applies dropout at rate `hparams.layer_prepostprocess_dropout`.

  The noise shape is [first dim of x, last dim of x], so one dropout mask is
  shared across every dimension in between.

  Args:
    x: an mtf.Tensor.
    hparams: hyperparameters; reads `layer_prepostprocess_dropout` and `mode`.

  Returns:
    An mtf.Tensor with the same shape as `x`.
  """
  batch_dim = x.shape.dims[0]
  model_dim = x.shape.dims[-1]
  return mtf.dropout(
      x, _is_training(hparams),
      keep_prob=1.0 - hparams.layer_prepostprocess_dropout,
      noise_shape=mtf.Shape([batch_dim, model_dim]))


def local_attention1d_spatial_decoder(x, kv_dim, heads_dim,
                                      feedforward_dim, hparams):
  """Image Transformer decoder with local1D spatial layers.

  Args:
    x: an mtf.Tensor of shape [batch, length, d_model].
    kv_dim: mtf.Dimension, per-head key/value depth.
    heads_dim: mtf.Dimension, attention heads.
    feedforward_dim: mtf.Dimension, hidden width of the feed-forward layer.
    hparams: hyperparameters; reads `block_length`, `num_decoder_layers`,
      `dropout`, `layer_prepostprocess_dropout` and `mode`.

  Returns:
    An mtf.Tensor of shape [batch, num_wblocks, blocksw, d_model].
  """
  batch_dim, length_dim, model_dim = x.shape.dims
  blocks_w_dim = mtf.Dimension("blocksw", hparams.block_length)
  num_w_blocks_dim = mtf.Dimension("num_wblocks",
                                   length_dim.size // blocks_w_dim.size)
  x = mtf.reshape(
      x, mtf.Shape([batch_dim, num_w_blocks_dim, blocks_w_dim, model_dim]))
  # [ self attention - ffn - residual + dropout] x n
  is_training = _is_training(hparams)
  for layer in range(hparams.num_decoder_layers):
    layer_name = "decoder_layer_%d" % layer
    with tf.variable_scope(layer_name):
      # Self attention layer
      x += layer_prepostprocess_dropout(
          mtf.layers.local_self_attention_spatial_blocks(
              mtf.layers.layer_norm(x, model_dim, name="layer_norm_att"),
              kv_dim,
              heads_dim,
              is_training,
              memory_w_dim=blocks_w_dim,
              mask_right=True,
              name="self_att"), hparams)
      # ffn layer
      x += layer_prepostprocess_dropout(
          mtf.layers.dense_relu_dense(
              mtf.layers.layer_norm(x, model_dim, name="layer_norm_ffn"),
              feedforward_dim,
              is_training,
              hparams.dropout,
              dropout_broadcast_dims=[length_dim]), hparams)

  output = mtf.layers.layer_norm(x, model_dim, name="final_layer_norm")
  return output


def local_attention2d_spatial_decoder(x, kv_dim, heads_dim,
                                      feedforward_dim, hparams):
  """Image Transformer decoder with local2D spatial layers.

  Args:
    x: an mtf.Tensor of shape [batch, length, d_model].
    kv_dim: mtf.Dimension, per-head key/value depth.
    heads_dim: mtf.Dimension, attention heads.
    feedforward_dim: mtf.Dimension, hidden width of the feed-forward layer.
    hparams: hyperparameters; reads `block_height`, `block_width`, `img_len`,
      `num_channels`, `num_decoder_layers`, `dropout`,
      `layer_prepostprocess_dropout` and `mode`.

  Returns:
    An mtf.Tensor of shape
    [batch, num_h_blocks, num_w_blocks, blocksh, blocksw, d_model].
  """
  batch_dim, length_dim, model_dim = x.shape.dims
  blocks_h_dim = mtf.Dimension("blocksh", hparams.block_height)
  blocks_w_dim = mtf.Dimension("blocksw", hparams.block_width)
  num_h_blocks_dim = mtf.Dimension("num_h_blocks",
                                   hparams.img_len // hparams.block_height)
  num_w_blocks_dim = mtf.Dimension(
      "num_w_blocks",
      hparams.img_len * hparams.num_channels // hparams.block_width)
  # Split the flat sequence into a grid of blocks, then bring the two block
  # indices next to each other.
  x = mtf.transpose(
      mtf.reshape(
          x,
          mtf.Shape([
              batch_dim, num_h_blocks_dim, blocks_h_dim,
              num_w_blocks_dim, blocks_w_dim, model_dim
          ])),
      mtf.Shape([
          batch_dim, num_h_blocks_dim, num_w_blocks_dim,
          blocks_h_dim, blocks_w_dim, model_dim
      ]))
  is_training = _is_training(hparams)
  # Image Transformer Decoder
  # [ self attention - ffn - residual + dropout] x n
  for layer in range(hparams.num_decoder_layers):
    layer_name = "decoder_layer_%d" % layer
    with tf.variable_scope(layer_name):
      # Self attention layer
      x += layer_prepostprocess_dropout(
          mtf.layers.local_2d_self_attention_spatial_blocks(
              mtf.layers.layer_norm(x, model_dim, name="layer_norm_att"),
              kv_dim,
              heads_dim,
              is_training,
              memory_h_dim=num_h_blocks_dim,
              memory_w_dim=num_w_blocks_dim,
              name="self_att"), hparams)
      # ffn layer
      # is_training is passed explicitly, as in the 1D spatial decoder, so
      # that hparams.dropout is bound to the dropout-rate parameter.
      x += layer_prepostprocess_dropout(
          mtf.layers.dense_relu_dense(
              mtf.layers.layer_norm(x, model_dim, name="layer_norm_ffn"),
              feedforward_dim,
              is_training,
              hparams.dropout,
              dropout_broadcast_dims=[length_dim]), hparams)

  output = mtf.layers.layer_norm(x, model_dim, name="final_layer_norm")
  return output


def local_attention1d_masked_decoder(x, kv_dim, heads_dim,
                                     feedforward_dim, hparams):
  """Image Transformer decoder with local1D masked layers.

  Args:
    x: an mtf.Tensor of shape [batch, length, d_model].
    kv_dim: mtf.Dimension, per-head key/value depth.
    heads_dim: mtf.Dimension, attention heads.
    feedforward_dim: mtf.Dimension, hidden width of the feed-forward layer.
    hparams: hyperparameters; reads `layout`, `mesh_shape`, `block_length`,
      `num_decoder_layers`, `dropout`, `layer_prepostprocess_dropout` and
      `mode`.

  Returns:
    An mtf.Tensor with the same shape as `x`.
  """
  _, length_dim, model_dim = x.shape.dims
  is_training = _is_training(hparams)
  # Depends only on the layout, so it is computed once for all layers.
  length_per_split = mtf.tensor_dim_to_size_per_split(
      hparams.layout, hparams.mesh_shape, length_dim)
  for layer in range(hparams.num_decoder_layers):
    layer_name = "decoder_layer_%d" % layer
    with tf.variable_scope(layer_name):
      # Self attention layer
      x += layer_prepostprocess_dropout(
          mtf.layers.masked_local_attention_1d(
              mtf.layers.layer_norm(x, model_dim, name="layer_norm_att"),
              kv_dim,
              heads_dim,
              is_training,
              window_size=hparams.block_length,
              length_per_split=length_per_split,
              name="self_att"), hparams)
      # ffn layer
      # is_training is passed explicitly, as in the 1D spatial decoder, so
      # that hparams.dropout is bound to the dropout-rate parameter.
      x += layer_prepostprocess_dropout(
          mtf.layers.dense_relu_dense(
              mtf.layers.layer_norm(x, model_dim, name="layer_norm_ffn"),
              feedforward_dim,
              is_training,
              hparams.dropout,
              dropout_broadcast_dims=[length_dim]), hparams)

  output = mtf.layers.layer_norm(x, model_dim, name="final_layer_norm")
  return output


@registry.register_hparams
def mtf_image_transformer_base():
  """Base set of hyperparameters; the other sets in this file derive from it."""
  hparams = common_hparams.basic_params1()
  hparams.no_data_parallelism = True
  hparams.use_fixed_batch_size = True
  hparams.batch_size = 1
  hparams.max_length = 3072
  hparams.hidden_size = 256
  hparams.label_smoothing = 0.0
  # 8-way mesh with only the batch dimension split across it.
  hparams.add_hparam("mesh_shape", "batch:8")
  hparams.add_hparam("layout", "batch:batch")
  hparams.add_hparam("mtf_mode", True)
  hparams.add_hparam("num_heads", 8)
  hparams.add_hparam("filter_size", 1024)
  hparams.add_hparam("num_encoder_layers", 0)
  hparams.add_hparam("num_decoder_layers", 6)
  hparams.add_hparam("attention_key_size", 256)
  hparams.add_hparam("attention_value_size", 256)
  # Share weights between input and target embeddings
  hparams.shared_embedding = True

  # mixture of experts hparams
  hparams.add_hparam("ffn_layer", "dense_relu_dense")
  hparams.add_hparam("moe_overhead_train", 1.0)
  hparams.add_hparam("moe_overhead_eval", 2.0)
  hparams.moe_num_experts = 16
  hparams.moe_loss_coef = 1e-3

  hparams.shared_embedding_and_softmax_weights = True
  hparams.optimizer = "Adafactor"
  hparams.learning_rate_schedule = "rsqrt_decay"
  hparams.learning_rate_warmup_steps = 10000
  hparams.add_hparam("d_kv", 64)
  hparams.add_hparam("d_ff", 2048)

  # Image related hparams
  hparams.add_hparam("img_len", 32)
  hparams.add_hparam("num_channels", 3)
  hparams.add_hparam("unconditional", True)

  # Local Attention related params
  hparams.add_hparam("block_length", 128)
  hparams.add_hparam("block_height", 16)
  hparams.add_hparam("block_width", 16)
  hparams.add_hparam("attention_type", "local1d")
  return hparams


@registry.register_hparams
def mtf_image_transformer_tiny():
  """Catch bugs locally..."""
  hparams = mtf_image_transformer_base()
  hparams.hidden_size = 128
  hparams.d_ff = 256
  hparams.batch_size = 4
  hparams.num_encoder_layers = 1
  hparams.num_decoder_layers = 4
  hparams.num_heads = 4
  hparams.attention_key_size = 128
  hparams.attention_value_size = 128
  hparams.block_length = 32
  # 2-way mesh with only the batch dimension split across it.
  hparams.mesh_shape = "batch:2"
  hparams.layout = "batch:batch"
  return hparams


@registry.register_hparams
def mtf_image_transformer_single():
  """Tiny model with an empty mesh shape and layout (no splitting)."""
  hparams = mtf_image_transformer_tiny()
  hparams.mesh_shape = ""
  hparams.layout = ""
  hparams.hidden_size = 32
  hparams.filter_size = 32
  hparams.batch_size = 1
  hparams.num_encoder_layers = 1
  hparams.num_decoder_layers = 1
  hparams.num_heads = 2
  hparams.attention_key_size = 32
  hparams.attention_value_size = 32
  hparams.block_length = 16
  return hparams


@registry.register_hparams
def mtf_image_transformer_base_single():
  """Base model with an empty mesh shape and layout (no splitting)."""
  hparams = mtf_image_transformer_base()
  hparams.num_decoder_layers = 6
  hparams.filter_size = 256
  hparams.block_length = 128
  hparams.mesh_shape = ""
  hparams.layout = ""
  return hparams


@registry.register_hparams
def mtf_image_transformer_tiny_spatial1d():
  """Tiny model using local 1D spatial attention, with no splitting."""
  hparams = mtf_image_transformer_tiny()
  hparams.num_decoder_layers = 6
  hparams.filter_size = 128
  hparams.block_height = 8
  hparams.block_width = 8
  hparams.attention_type = "local1d_spatial"
  hparams.mesh_shape = ""
  hparams.layout = ""
  return hparams


@registry.register_hparams
def mtf_image_transformer_tiny_spatial2d():
  """Tiny model using local 2D spatial attention on a 2x2 mesh."""
  hparams = mtf_image_transformer_tiny()
  hparams.num_decoder_layers = 6
  hparams.filter_size = 128
  hparams.block_height = 8
  hparams.block_width = 8
  hparams.attention_type = "local2d_spatial"
  hparams.mesh_shape = "b1:2,b2:2"
  # NOTE(review): the 2D decoder names its dimension "num_w_blocks", while
  # this layout refers to "num_wblocks" (the 1D spatial decoder's name) --
  # confirm whether mesh axis b2 is meant to split the 2D width blocks.
  hparams.layout = "num_h_blocks:b1,num_wblocks:b2"
  return hparams


@registry.register_hparams
def mtf_image_transformer_base_cifar():
  """Data parallel CIFAR parameters."""
  hparams = mtf_image_transformer_base()
  hparams.mesh_shape = "batch:8"
  hparams.layout = "batch:batch"
  hparams.learning_rate_decay_steps = 13600  # one epoch
  hparams.batch_size = 32
  hparams.num_heads = 4
  hparams.num_decoder_layers = 12
  hparams.block_length = 256
  hparams.hidden_size = 512
  hparams.d_ff = 2048
  hparams.learning_rate = 0.5
  hparams.layer_preprocess_sequence = "none"
  hparams.layer_postprocess_sequence = "dan"
  hparams.layer_prepostprocess_dropout = 0.3
  hparams.unconditional = True
  return hparams


@registry.register_hparams
def mtf_image_transformer_cifar_4x():
  """Data parallel CIFAR parameters on a 32-way mesh with batch size 128."""
  hparams = mtf_image_transformer_base_cifar()
  hparams.mesh_shape = "batch:32"
  hparams.layout = "batch:batch"
  hparams.batch_size = 128
  return hparams


@registry.register_hparams
def mtf_image_transformer_cifar_mp_4x():
  """CIFAR parameters with a 4-way model axis and an 8-way batch axis."""
  hparams = mtf_image_transformer_base_cifar()
  hparams.mesh_shape = "model:4;batch:8"
  hparams.layout = "batch:batch;d_ff:model;heads:model"
  hparams.batch_size = 32
  hparams.num_heads = 8
  hparams.d_ff = 8192
  return hparams


@registry.register_hparams
def mtf_image_transformer_base_imagenet():
  """Data parallel ImageNet parameters."""
  hparams = mtf_image_transformer_base_cifar()
  hparams.mesh_shape = "batch:32"
  hparams.layout = "batch:batch"
  hparams.batch_size = 128
  hparams.d_ff = 2048
  hparams.hidden_size = 512
  hparams.num_decoder_layers = 12
  hparams.learning_rate = 0.5
  hparams.learning_rate_warmup_steps = 31250
  hparams.layer_preprocess_sequence = "none"
  hparams.layer_postprocess_sequence = "dan"
  hparams.layer_prepostprocess_dropout = 0.1
  hparams.unconditional = True
  return hparams


@registry.register_hparams
def mtf_image_transformer_base_imagenet_mp():
  """Model parallel ImageNet parameters."""
  hparams = mtf_image_transformer_base_imagenet()
  hparams.mesh_shape = "model:4;batch:8"
  hparams.layout = "batch:batch;d_ff:model;heads:model"
  hparams.batch_size = 32
  hparams.num_heads = 8
  hparams.d_ff = 8192
  hparams.learning_rate_warmup_steps = 31250
  hparams.unconditional = True
  return hparams


@registry.register_hparams
def mtf_image_transformer_base_imagenet_mp128():
  """Model parallel ImageNet parameters for img_len 128."""
  hparams = mtf_image_transformer_base_imagenet()
  hparams.mesh_shape = "model:8;batch:4"
  hparams.layout = "batch:batch;d_ff:model;heads:model"
  hparams.batch_size = 8
  hparams.img_len = 128
  hparams.block_length = 128
  hparams.num_heads = 8
  hparams.num_decoder_layers = 4
  hparams.d_ff = 4096
  hparams.learning_rate_warmup_steps = 31250
  hparams.unconditional = True
  # NOTE(review): max_length is sized for 256x256x3 although img_len is 128 --
  # confirm whether this is deliberate headroom.
  hparams.max_length = 256*256*3
  return hparams


@registry.register_hparams
def mtf_image_transformer_base_imagenet_mp_sp():
  """Model parallel ImageNet parameters with local 1D spatial attention."""
  hparams = mtf_image_transformer_base_imagenet_mp128()
  hparams.mesh_shape = "model:8;batch:4"
  hparams.layout = "batch:batch;d_ff:model;num_wblocks:model"
  hparams.batch_size = 8
  hparams.img_len = 128
  hparams.block_length = 128
  hparams.attention_type = "local1d_spatial"
  return hparams


@registry.register_hparams
def mtf_image_transformer_base_imagenet_mp64():
  """Model parallel ImageNet parameters for img_len 64."""
  hparams = mtf_image_transformer_base_imagenet()
  hparams.mesh_shape = "model:8;batch:4"
  hparams.layout = "batch:batch;d_ff:model;heads:model"
  hparams.batch_size = 8
  hparams.img_len = 64
  hparams.num_decoder_layers = 8
  return hparams


@registry.register_hparams
def mtf_image_transformer_tiny_8gpu():
  """Tiny model on an 8-way mesh splitting vocab, filter_size and heads."""
  hparams = mtf_image_transformer_tiny()
  hparams.mesh_shape = "all:8"
  hparams.layout = "vocab:all;filter_size:all;heads:all"
  return hparams


@registry.register_hparams
def mtf_image_transformer_length_sharded():
  """Tiny model on a 2-way mesh splitting the length dimension."""
  hparams = mtf_image_transformer_tiny()
  hparams.mesh_shape = "all:2"
  hparams.layout = "length:all"
  return hparams
"""Tests for Image Transformer on Mesh TensorFlow."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import mesh_tensorflow as mtf

import numpy as np
from tensor2tensor.data_generators import problem_hparams
from tensor2tensor.models import mtf_image_transformer

import tensorflow.compat.v1 as tf
from tensorflow.compat.v1 import estimator as tf_estimator

# Constants shared between all functions.
BATCH_SIZE = 8
IMG_LENGTH = 8
VOCAB_SIZE = 256


def get_model(hparams=None,
              mode=tf_estimator.ModeKeys.TRAIN,
              model_cls=mtf_image_transformer.MtfImageTransformer):
  """Builds a small single-channel model plus random target features.

  Args:
    hparams: model hyperparameters; defaults to
      `mtf_image_transformer_single()`. Mutated in place: the image size,
      batch size and channel count are overridden with the test constants.
    mode: estimator mode handed to the model constructor.
    model_cls: model class to instantiate.

  Returns:
    A (model, features, hparams) tuple; `features` holds only "targets".
  """
  if hparams is None:
    hparams = mtf_image_transformer.mtf_image_transformer_single()
  # Applied to caller-supplied hparams too: the targets built below have a
  # fixed shape and the model reshapes them using these values.
  hparams.max_length = IMG_LENGTH*IMG_LENGTH
  hparams.batch_size = BATCH_SIZE
  hparams.img_len = IMG_LENGTH
  hparams.num_channels = 1

  p_hparams = problem_hparams.test_problem_hparams(VOCAB_SIZE,
                                                   VOCAB_SIZE,
                                                   hparams)
  # Dropping the "inputs" modality leaves targets as the only feature.
  del p_hparams.modality["inputs"]
  hparams.problem_hparams = p_hparams

  targets = np.random.randint(
      VOCAB_SIZE, size=(BATCH_SIZE, IMG_LENGTH, IMG_LENGTH, 1, 1))
  features = {
      "targets": tf.constant(targets, dtype=tf.int32, name="targets"),
  }

  return model_cls(hparams, mode, p_hparams), features, hparams


def get_placement_mesh(hparams):
  """Creates a mesh and a placement mesh implementation from `hparams`.

  Args:
    hparams: hyperparameters providing `mesh_shape` and `layout`.

  Returns:
    A (mesh, mesh_impl) pair; every mesh device is the empty device string.
  """
  graph = mtf.Graph()
  mesh = mtf.Mesh(graph, "my_mesh")
  mesh_shape = mtf.convert_to_shape(hparams.mesh_shape)

  mesh_devices = [""] * mesh_shape.size
  mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
      mesh_shape, hparams.layout, mesh_devices)
  return mesh, mesh_impl


class MtfImageTransformerTest(tf.test.TestCase):
  """Checks the logits shape under several mesh layouts."""

  def _check_logits_shape(self, mesh_shape, layout):
    """Lowers the model under the given layout and checks the logits shape.

    Args:
      mesh_shape: string assigned to `hparams.mesh_shape`.
      layout: string assigned to `hparams.layout`.
    """
    hparams = mtf_image_transformer.mtf_image_transformer_single()

    # need to know layout ahead of time for local attention.
    hparams.mesh_shape = mesh_shape
    hparams.layout = layout
    model, features, hparams = get_model(hparams)
    mesh, mesh_impl = get_placement_mesh(hparams)

    logits, _ = model.mtf_model_fn(features, mesh)
    lowering = mtf.Lowering(mesh.graph, {mesh: mesh_impl})
    tf_group = lowering.copy_masters_to_slices()
    tf_logits = lowering.export_to_tf_tensor(logits)

    with self.test_session() as session:
      session.run(tf.global_variables_initializer())
      session.run(tf_group)
      res = session.run(tf_logits)
    self.assertEqual(
        res.shape,
        (BATCH_SIZE, IMG_LENGTH, IMG_LENGTH, hparams.num_channels, VOCAB_SIZE))

  def testMtfImageTransformer(self):
    self._check_logits_shape(mesh_shape="", layout="")

  def testMtfImageTransformerDataParallel(self):
    self._check_logits_shape(mesh_shape="all:2", layout="batch:all")

  def testMtfImageTransformerModelParallel(self):
    self._check_logits_shape(mesh_shape="all:2", layout="length:all")


if __name__ == "__main__":
  tf.test.main()
"""ResNet model with model and data parallelism using MTF.

Integration of Mesh tensorflow with ResNet to do model parallelism.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import copy
import mesh_tensorflow as mtf

from tensor2tensor.layers import common_hparams
from tensor2tensor.utils import mtf_model
from tensor2tensor.utils import registry
import tensorflow.compat.v1 as tf
from tensorflow.compat.v1 import estimator as tf_estimator


BATCH_NORM_DECAY = 0.9
BATCH_NORM_EPSILON = 1e-5


def batch_norm_relu(inputs, is_training, relu=True):
  """Block of batch norm and relu.

  Args:
    inputs: an mtf.Tensor.
    is_training: `bool` for whether the model is in training mode.
    relu: if True, apply a relu after the batch norm; if False, skip the relu
      and pass `init_zero=True` to the batch norm.

  Returns:
    The normalised (and optionally rectified) mtf.Tensor.
  """
  inputs = mtf.layers.batch_norm(
      inputs,
      is_training,
      BATCH_NORM_DECAY,
      epsilon=BATCH_NORM_EPSILON,
      init_zero=(not relu))
  if relu:
    inputs = mtf.relu(inputs)
  return inputs


def bottleneck_block(inputs,
                     filters,
                     is_training,
                     strides,
                     projection_shortcut=None,
                     row_blocks_dim=None,
                     col_blocks_dim=None):
  """Bottleneck block variant for residual networks with BN after convolutions.

  Args:
    inputs: a `mtf.Tensor` of shape
        `[batch_dim, row_blocks, col_blocks, rows, cols, in_channels]`.
    filters: `int` number of filters for the first (1x1) and third (1x1)
        convolutions. The second (3x3) convolution uses 4 times as many.
    is_training: `bool` for whether the model is in training mode.
    strides: list of ints, passed as `strides` to the third convolution.
    projection_shortcut: `function` to use for projection shortcuts (typically
        a 1x1 convolution to match the filter dimensions). It is called as
        `projection_shortcut(inputs, filters_dim)`. If None, no
        projection is used and the input is passed as unchanged through the
        shortcut connection.
    row_blocks_dim: a mtf.Dimension, row dimension which is
        spatially partitioned along mesh axis
    col_blocks_dim: a mtf.Dimension, column dimension which is
        spatially partitioned along mesh axis

  Returns:
    The output `Tensor` of the block.
  """
  shortcut = inputs

  if projection_shortcut is not None:
    filters_dim = mtf.Dimension("filtersp", filters)
    shortcut = projection_shortcut(inputs, filters_dim)

  # First conv block
  inputs = mtf.layers.conv2d_with_blocks(
      inputs,
      mtf.Dimension("filters1", filters),
      filter_size=[1, 1],
      strides=[1, 1],
      padding="SAME",
      h_blocks_dim=None, w_blocks_dim=col_blocks_dim,
      name="conv0")

  # TODO(nikip): Add Dropout?
  inputs = batch_norm_relu(inputs, is_training)

  # Second conv block
  inputs = mtf.layers.conv2d_with_blocks(
      inputs,
      mtf.Dimension("filters2", 4 * filters),
      filter_size=[3, 3],
      strides=[1, 1],
      padding="SAME",
      h_blocks_dim=row_blocks_dim, w_blocks_dim=col_blocks_dim,
      name="conv1")

  inputs = batch_norm_relu(inputs, is_training)

  # Third wide conv filter block
  inputs = mtf.layers.conv2d_with_blocks(
      inputs,
      mtf.Dimension("filters3", filters),
      filter_size=[1, 1],
      strides=strides,
      padding="SAME",
      h_blocks_dim=None, w_blocks_dim=col_blocks_dim,
      name="conv2")

  # TODO(nikip): Although the original resnet code has this batch norm, in our
  # setup this is causing no gradients to be passed. Investigate further.
  # inputs = batch_norm_relu(inputs, is_training, relu=True)

  # TODO(nikip): Maybe add residual with a projection?
  # The channel dimension of the conv output is renamed to the shortcut's
  # channel dimension name so that the two tensors can be added.
  return mtf.relu(
      shortcut + mtf.rename_dimension(
          inputs, inputs.shape.dims[-1].name, shortcut.shape.dims[-1].name))


def block_layer(inputs,
                filters,
                blocks,
                strides,
                is_training,
                name,
                row_blocks_dim=None,
                col_blocks_dim=None):
  """Creates one layer of blocks for the ResNet model.

  Args:
    inputs: a `mtf.Tensor`, forwarded unchanged to `bottleneck_block`.
    filters: `int` number of filters for the first convolution of the layer.
    blocks: `int` number of blocks contained in the layer.
    strides: list of ints, strides used by the first block of the layer and
        by its projection shortcut.
    is_training: `bool` for whether the model is training.
    name: `str` name of the variable scope wrapping the block layer.
    row_blocks_dim: a mtf.Dimension, row dimension which is
        spatially partitioned along mesh axis
    col_blocks_dim: a mtf.Dimension, column dimension which is
        spatially partitioned along mesh axis

  Returns:
    The output `Tensor` of the block layer.
  """
  with tf.variable_scope(name, default_name="block_layer"):
    # Only the first block per block_layer uses projection_shortcut and strides
    def projection_shortcut(inputs, output_dim):
      """Project identity branch."""
      inputs = mtf.layers.conv2d_with_blocks(
          inputs,
          output_dim,
          filter_size=[1, 1],
          strides=strides,
          padding="SAME",
          h_blocks_dim=None, w_blocks_dim=col_blocks_dim,
          name="shortcut0")
      return batch_norm_relu(
          inputs, is_training, relu=False)

    inputs = bottleneck_block(
        inputs,
        filters,
        is_training,
        strides=strides,
        projection_shortcut=projection_shortcut,
        row_blocks_dim=row_blocks_dim,
        col_blocks_dim=col_blocks_dim)

    for i in range(1, blocks):
      with tf.variable_scope("bottleneck_%d" % i):
        # NOTE(review): strides has four entries here but two everywhere
        # else in this file -- confirm which form conv2d_with_blocks expects.
        inputs = bottleneck_block(
            inputs,
            filters,
            is_training,
            strides=[1, 1, 1, 1],
            projection_shortcut=None,
            row_blocks_dim=row_blocks_dim,
            col_blocks_dim=col_blocks_dim)

    return inputs


@registry.register_model
class MtfResNet(mtf_model.MtfModel):
  """ResNet in mesh_tensorflow."""

  def set_activation_type(self):
    """Returns the tf.DType named by `hparams.activation_dtype`.

    Despite the name, nothing is stored on the instance.

    Raises:
      ValueError: if the name is not float32, float16 or bfloat16.
    """
    hparams = self._hparams
    if hparams.activation_dtype == "float32":
      activation_dtype = tf.float32
    elif hparams.activation_dtype == "float16":
      activation_dtype = tf.float16
    elif hparams.activation_dtype == "bfloat16":
      activation_dtype = tf.bfloat16
    else:
      raise ValueError(
          "unknown hparams.activation_dtype %s" % hparams.activation_dtype)
    return activation_dtype

  def mtf_model_fn(self, features, mesh):
    """Builds the ResNet classifier on `mesh` and returns logits and loss.

    Args:
      features: dict of tf.Tensors with "inputs" and "targets".
      mesh: the mtf.Mesh on which every tensor and variable is created.

    Returns:
      A (logits, loss) pair of mtf.Tensors. Logits have shape
      [batch, one_channel, classes]; loss is the mean softmax cross-entropy.

    Raises:
      ValueError: if `hparams.activation_dtype` is not recognised.
    """
    features = copy.copy(features)
    tf.logging.info("features = %s", features)
    hparams = self._hparams
    activation_dtype = self.set_activation_type()
    is_training = hparams.mode == tf_estimator.ModeKeys.TRAIN

    # Declare all the dimensions
    batch_dim = mtf.Dimension("batch", hparams.batch_size)
    hidden_dim = mtf.Dimension("hidden", hparams.hidden_size)
    filter_dim = mtf.Dimension("filters", hparams.filter_sizes[0])
    rows_dim = mtf.Dimension("rows_size", hparams.rows_size)
    cols_dim = mtf.Dimension("cols_size", hparams.cols_size)
    row_blocks_dim = mtf.Dimension("row_blocks", hparams.row_blocks)
    col_blocks_dim = mtf.Dimension("col_blocks", hparams.col_blocks)
    # NOTE(review): the class count (10) and channel count (3) are hard-coded
    # here, while the reshape below uses hparams.num_channels -- confirm they
    # are meant to agree.
    classes_dim = mtf.Dimension("classes", 10)
    channels_dim = mtf.Dimension("channels", 3)
    one_channel_dim = mtf.Dimension("one_channel", 1)

    inputs = features["inputs"]
    x = mtf.import_tf_tensor(
        mesh, tf.reshape(inputs, [
            hparams.batch_size,
            hparams.row_blocks,
            hparams.rows_size // hparams.row_blocks,
            hparams.col_blocks,
            hparams.num_channels*hparams.cols_size // hparams.col_blocks,
            hparams.num_channels]),
        mtf.Shape(
            [batch_dim, row_blocks_dim, rows_dim,
             col_blocks_dim, cols_dim, channels_dim]))
    # Bring the two block indices next to each other.
    x = mtf.transpose(x, [batch_dim, row_blocks_dim, col_blocks_dim,
                          rows_dim, cols_dim, channels_dim])

    x = mtf.to_float(x)
    x = mtf.layers.conv2d_with_blocks(
        x,
        filter_dim,
        filter_size=[3, 3],
        strides=[1, 1],
        padding="SAME",
        h_blocks_dim=None, w_blocks_dim=col_blocks_dim,
        name="initial_filter")

    x = batch_norm_relu(x, is_training)

    # Conv blocks
    # [block - strided block layer - strided block layer] x n
    for layer in range(hparams.num_layers):
      layer_name = "block_layer_%d" % layer
      with tf.variable_scope(layer_name):
        # Three residual block layers, named block_layer1..block_layer3, each
        # with its own entry of filter_sizes / layer_sizes.
        for stage in range(3):
          x = block_layer(
              inputs=x,
              filters=hparams.filter_sizes[stage],
              blocks=hparams.layer_sizes[stage],
              strides=[1, 1],
              is_training=is_training,
              name="block_layer%d" % (stage + 1),
              row_blocks_dim=None,
              col_blocks_dim=None)

    # Calculate the logits and loss.
    out = x
    # Every dimension except the leading (batch) one is reduced away.
    outputs = mtf.layers.dense(
        out, hidden_dim,
        reduced_dims=out.shape.dims[-5:],
        activation=mtf.relu, name="dense")

    # We assume fixed vocab size for targets
    labels = tf.squeeze(tf.to_int32(features["targets"]), [2, 3])
    labels = mtf.import_tf_tensor(
        mesh, tf.reshape(labels, [hparams.batch_size]), mtf.Shape([batch_dim]))

    logits = mtf.layers.dense(outputs, classes_dim, name="logits")
    soft_targets = mtf.one_hot(labels, classes_dim, dtype=activation_dtype)
    loss = mtf.layers.softmax_cross_entropy_with_logits(
        logits, soft_targets, classes_dim)

    # Reshape logits so it doesn't break inside t2t.
    logits = mtf.reshape(
        logits,
        mtf.Shape([batch_dim, one_channel_dim, classes_dim]))
    loss = mtf.reduce_mean(loss)
    return logits, loss


@registry.register_hparams
def mtf_resnet_base():
  """Set of hyperparameters."""
  hparams = common_hparams.basic_params1()
  hparams.no_data_parallelism = True
  hparams.use_fixed_batch_size = True
  hparams.batch_size = 32
  hparams.max_length = 3072
  hparams.hidden_size = 256
  hparams.label_smoothing = 0.0
  # 8-way model-parallelism
  hparams.add_hparam("mesh_shape", "batch:8")
  hparams.add_hparam("layout", "batch:batch")
  hparams.add_hparam("filter_size", 1024)

  hparams.add_hparam("num_layers", 6)
  # Share weights between input and target embeddings
  hparams.shared_embedding = True

  hparams.shared_embedding_and_softmax_weights = True
  hparams.optimizer = "Adafactor"
  hparams.learning_rate_schedule = "rsqrt_decay"
  hparams.learning_rate_warmup_steps = 10000
  hparams.add_hparam("d_kv", 32)

  # Image related hparams
@registry.register_hparams
def mtf_resnet_base_cifar():
  """Data parallel CIFAR parameters.

  Starts from `mtf_resnet_base()` and overrides the mesh shape, layout, batch
  size, depth, widths, learning-rate settings and layer pre/post-processing.

  Returns:
    The hparams object returned by `mtf_resnet_base()`, modified in place.
  """
  hparams = mtf_resnet_base()
  hparams.mesh_shape = "batch:32"
  # This attribute used to be misspelled `layoyt`, which created a stray
  # attribute instead of assigning the `layout` hparam.
  hparams.layout = "batch:batch"
  hparams.batch_size = 8
  hparams.num_layers = 12
  # NOTE(review): `block_length` and `unconditional` (below) are not added by
  # `mtf_resnet_base()`; presumably they are plain attributes that nothing in
  # this model reads -- confirm before relying on them.
  hparams.block_length = 256
  hparams.hidden_size = 512
  hparams.filter_size = 2048
  hparams.learning_rate = 0.5
  hparams.learning_rate_warmup_steps = 4000
  hparams.layer_preprocess_sequence = "none"
  hparams.layer_postprocess_sequence = "dan"
  hparams.layer_prepostprocess_dropout = 0.3
  hparams.unconditional = True
  return hparams
@property
def batch_dims(self):
  """List of mtf.Dimensions that make up the batch.

  Returns:
    `[batch]` of size `hparams.batch_size` when `hparams.outer_batch_size` is
    0; otherwise `[outer_batch, inner_batch]`, where `outer_batch` has size
    `hparams.outer_batch_size` and `inner_batch` has size
    `hparams.batch_size // hparams.outer_batch_size`.

  Raises:
    ValueError: if a nonzero `outer_batch_size` does not divide `batch_size`.
  """
  outer_size = self._hparams.outer_batch_size
  if outer_size == 0:
    return [mtf.Dimension("batch", self._hparams.batch_size)]

  total_size = self._hparams.batch_size
  if total_size % outer_size != 0:
    raise ValueError(
        "hparams.outer_batch_size must divide hparams.batch_size")

  outer_dim = mtf.Dimension("outer_batch", outer_size)
  inner_dim = mtf.Dimension("inner_batch", total_size // outer_size)
  return [outer_dim, inner_dim]
def _import_to_batch_by_length(self, x, name, mesh, hparams):
  """Imports `x` as a fully replicated mtf.Tensor of shape batch_dims + [length].

  Args:
    x: tensor that is reshaped to the integer sizes of
      `self.batch_dims + [self.length_dim]` before being imported.
    name: name given to the imported tensor.
    mesh: mesh passed to `mtf.import_fully_replicated`.
    hparams: unused; deleted immediately.

  Returns:
    The result of `mtf.import_fully_replicated`.
  """
  del hparams  # Unused.
  target_shape = mtf.Shape(self.batch_dims + [self.length_dim])
  reshaped = tf.reshape(x, target_shape.to_integer_list)
  return mtf.import_fully_replicated(mesh, reshaped, target_shape, name=name)
def _noisy_targets_from_spec(self, targets, noising_spec, losses=None):
  """Corrupts `targets` as described by `noising_spec`.

  Supported values of `noising_spec["type"]`:
    "mask": a token is kept where a uniform random draw is greater than
      `noising_spec["prob"]` and replaced with 0 elsewhere.
    "random_zipfian": a token is replaced where a uniform random draw is less
      than `noising_spec["prob"]`. Replacements are sampled from logits
      `log(1 / (id + 10))`, which favors lower token ids.
    "transformer": a second `MtfTransformer`, built in the "noiser" variable
      scope from a copy of the hparams overridden with
      `noising_spec["overrides"]`, is run on the original features and samples
      drawn from its logits are returned. Only available in TRAIN mode.

  Args:
    targets: an mtf.Tensor of target token ids.
    noising_spec: a dict with a "type" entry, plus "prob" or "overrides"
      depending on the type.
    losses: an optional list; for the "transformer" type the noiser's loss is
      appended to it when it is provided.

  Returns:
    The noised targets.

  Raises:
    NotImplementedError: for the "transformer" type outside of TRAIN mode.
    ValueError: if the noising type is not recognized.
  """
  if noising_spec["type"] == "mask":
    # Replace a randomly-chosen noising_spec["prob"] of input tokens with 0.
    return targets * mtf.cast(
        mtf.greater(mtf.random_uniform(targets.mesh, targets.shape),
                    noising_spec["prob"]), targets.dtype)
  elif noising_spec["type"] == "random_zipfian":
    # Replace a randomly-chosen noising_spec["prob"] of input tokens.
    # Rather than drawing the replacement tokens uniformly, we sample from
    # a distribution favoring lower token-ids, assuming that the ids have
    # been assigned in frequency order. The probability of choosing an
    # id is proportional to 1/(id+10)
    logits = mtf.log(1.0 / (mtf.range(
        targets.mesh, self.targets_vocab_dim, dtype=tf.float32) + 10.0))
    logits = mtf.broadcast(logits, new_shape=targets.shape + logits.shape)
    r = mtf.sample_with_temperature(logits, self.targets_vocab_dim)
    use_noise = mtf.less(
        mtf.random_uniform(targets.mesh, targets.shape), noising_spec["prob"])
    return mtf.where(use_noise, r, targets)
  elif noising_spec["type"] == "transformer":
    # Train a small transformer to fill in masked out values, then
    # sample from it.
    hparams = self._hparams
    if hparams.mode != tf_estimator.ModeKeys.TRAIN:
      raise NotImplementedError("Not implemented")
    noiser_hparams = copy.copy(self._hparams)
    noiser_hparams.del_hparam("mode")
    noiser_hparams.override_from_dict(noising_spec["overrides"])
    with tf.variable_scope("noiser"):
      noiser = MtfTransformer(
          noiser_hparams,
          mode=hparams.mode,
          problem_hparams=self._problem_hparams)
      logits, loss = noiser._mtf_model_fn(  # pylint: disable=protected-access
          self._original_features, targets.mesh)
      samples = mtf.sample_with_temperature(logits, self.targets_vocab_dim)
    # `losses` defaults to None and `_noisy_targets` calls this method without
    # a list for the eval spec, so only record the loss when a list was given
    # (same convention as `_feedforward_layer`).
    if losses is not None:
      losses.append(loss)
    return samples
  else:
    raise ValueError("unknown noising spec %s" % noising_spec)
+ + Args: + targets: a Tensor + losses: an optional list onto which to append traning losses + Returns: + a Tensor the same dtype and shape as Targets + """ + hparams = self._hparams + if hparams.mode == tf_estimator.ModeKeys.TRAIN: + nt_train = self._noisy_targets_from_spec( + targets, hparams.noising_spec_train, losses=losses) + if hparams.noising_use_eval_during_train > 0: + nt_eval = self._noisy_targets_from_spec( + targets, hparams.noising_spec_eval) + use_eval_noising = mtf.less( + mtf.random_uniform(targets.mesh, targets.shape - self.length_dim), + hparams.noising_use_eval_during_train) + nt_train = mtf.where(use_eval_noising, nt_eval, nt_train) + return nt_train + else: + return self._noisy_targets_from_spec(targets, hparams.noising_spec_eval) + + def _mtf_model_fn(self, features, mesh): + self._original_features = features + features = copy.copy(features) + hparams = self._hparams + extra_losses = [] + targets = tf.to_int32(features["targets"]) + mode = getattr(hparams, "mode", tf_estimator.ModeKeys.TRAIN) + is_training = mode == tf_estimator.ModeKeys.TRAIN + if len(targets.get_shape()) > 2: + tf.logging.info("targets = %s" % targets) + targets = tf.squeeze(targets, [2, 3]) + # pad targets to max_length + def pad_to_max_length(x): + extra_length = hparams.max_length - tf.shape(x)[1] + x = tf.pad(x, [[0, 0], [0, extra_length]]) + x = tf.reshape(x, [hparams.batch_size, hparams.max_length]) + return x + targets = pad_to_max_length(targets) + targets = self._import_to_batch_by_length(targets, "targets", mesh, hparams) + for key in ["targets_segmentation", "targets_position", + "inputs_segmentation", "inputs_position"]: + if key in features: + features[key] = pad_to_max_length(features[key]) + if hparams.decoder_type == "autoregressive": + shifted_targets = mtf.shift( + targets, offset=1, dim=self.length_dim, wrap=False) + elif hparams.decoder_type == "denoising": + shifted_targets = self._noisy_targets(targets, extra_losses) + else: + raise ValueError( + 
"unknown hparams.decoder_type = %s" % hparams.decoder_type) + + if "targets_segmentation" in features: + # "Packed" dataset - keep the examples from seeing each other. + targets_segmentation = self._import_to_batch_by_length( + features["targets_segmentation"], "targets_segmentation", + mesh, hparams) + targets_position = self._import_to_batch_by_length( + features["targets_position"], "targets_position", + mesh, hparams) + decoder_self_attention_mask = mtf.layers.attention_mask_same_segment( + targets_segmentation, dtype=self.activation_dtype) + if hparams.decoder_type == "autoregressive": + decoder_self_attention_mask += mtf.layers.attention_mask_autoregressive( + targets_position, dtype=self.activation_dtype) + else: + targets_position = mtf.range(mesh, self.length_dim, dtype=tf.int32) + if hparams.decoder_type == "autoregressive": + decoder_self_attention_mask = mtf.layers.attention_mask_autoregressive( + targets_position, dtype=self.activation_dtype) + else: + decoder_self_attention_mask = None + + def layer_prepostprocess_dropout(x): + return mtf.dropout( + x, is_training, keep_prob=1.0 - hparams.layer_prepostprocess_dropout, + noise_shape=mtf.Shape(self.batch_dims + [self.model_dim])) + + (inputs_embedding_var, + targets_embedding_var, + softmax_var, + positional_embedding_var) = self._embedding_and_softmax_vars(mesh) + if hparams.transformer_type == "decoder": + encoder_output = None + encoder_decoder_attention_mask = None + else: + inputs = tf.squeeze(tf.to_int32(features["inputs"]), [2, 3]) + inputs = pad_to_max_length(inputs) + inputs = self._import_to_batch_by_length(inputs, "inputs", mesh, hparams) + if "inputs_segmentation" in features: + # "Packed" dataset - keep the examples from seeing each other. 
+ inputs_segmentation = self._import_to_batch_by_length( + features["inputs_segmentation"], "inputs_segmentation", + mesh, hparams) + inputs_position = self._import_to_batch_by_length( + features["inputs_position"], "inputs_position", + mesh, hparams) + encoder_self_attention_mask = ( + mtf.layers.attention_mask_same_segment( + inputs_segmentation, dtype=self.activation_dtype)) + else: + inputs_position = mtf.range(mesh, self.length_dim, dtype=tf.int32) + encoder_self_attention_mask = ( + mtf.layers.attention_mask_ignore_padding( + inputs, dtype=self.activation_dtype)) + + x = (mtf.gather(inputs_embedding_var, inputs, self.inputs_vocab_dim) + + mtf.gather(positional_embedding_var, inputs_position, + self.max_length_dim)) + x = layer_prepostprocess_dropout(x) + with tf.variable_scope("encoder"): + x = self._layer_stack(x, + hparams.encoder_layers, + self_attention_mask=encoder_self_attention_mask, + losses=extra_losses) + + if hparams.transformer_type == "encdec": + if "inputs_segmentation" in features: + encoder_decoder_attention_mask = ( + mtf.layers.attention_mask_same_segment( + targets_segmentation, inputs_segmentation, + dtype=self.activation_dtype)) + else: + encoder_decoder_attention_mask = encoder_self_attention_mask + encoder_output = mtf.rename_dimension( + x, self.length_dim.name, self.memory_length_dim.name) + + if hparams.transformer_type != "encoder": + # DECODER + x = (mtf.gather( + targets_embedding_var, shifted_targets, self.targets_vocab_dim) + + mtf.gather( + positional_embedding_var, targets_position, self.max_length_dim)) + x = layer_prepostprocess_dropout(x) + with tf.variable_scope("decoder"): + x = self._layer_stack( + x, + hparams.decoder_layers, + encoder_output=encoder_output, + self_attention_mask=decoder_self_attention_mask, + encdec_attention_mask=encoder_decoder_attention_mask, + losses=extra_losses) + if (hparams.reshape_logits_hack and + hparams.mode == tf_estimator.ModeKeys.TRAIN): + # For some reason, the logits computation is 
@property
def _targets_vocab_size(self):
  """Targets vocab size rounded up to a multiple of `hparams.vocab_divisor`."""
  raw_size = self._problem_hparams.vocab_size["targets"]
  # (-n) % d is the amount needed to reach the next multiple of d.
  padding = (-raw_size) % self._hparams.vocab_divisor
  return raw_size + padding
def _feedforward_layer(self, x, layer_type, losses=None):
  """Feed-forward layer.

  Args:
    x: a mtf.Tensor with shape [..., length_dim, model_dim]
    layer_type: a string, one of:
      "drd": `mtf.layers.dense_relu_dense` with hidden dimension
        `self.feedforward_dim` and dropout `hparams.relu_dropout`.
      "none": returns `x` unchanged.
      "moe": `moe.transformer_moe_layer_v1`.
      "hmoe": `moe.transformer_moe_layer_v2`.
    losses: an optional list; the extra loss returned by the "moe" / "hmoe"
      layers is appended to it when it is provided.
  Returns:
    a mtf.Tensor with shape [..., length_dim, model_dim]
  Raises:
    ValueError: if `layer_type` is not recognized
  """
  hparams = self._hparams
  # `mode` may be absent from hparams, in which case TRAIN is assumed.
  mode = getattr(hparams, "mode", tf_estimator.ModeKeys.TRAIN)
  is_training = mode == tf_estimator.ModeKeys.TRAIN

  if layer_type == "none":
    return x

  if layer_type == "drd":
    return mtf.layers.dense_relu_dense(
        x, self.feedforward_dim, is_training, dropout=hparams.relu_dropout,
        dropout_broadcast_dims=[self.length_dim],
        master_dtype=self.master_dtype,
        slice_dtype=self.slice_dtype)

  if layer_type in ("moe", "hmoe"):
    moe_layer_fn = (
        moe.transformer_moe_layer_v1 if layer_type == "moe"
        else moe.transformer_moe_layer_v2)
    # Use `is_training` rather than reading `hparams.mode` directly, so a
    # missing `mode` falls back to TRAIN here as it does for "drd".
    output, loss = moe_layer_fn(
        x,
        self.model_dim,
        hparams,
        is_training,
        master_dtype=self.master_dtype,
        slice_dtype=self.slice_dtype)
    if losses is not None:
      losses.append(loss)
    return output

  raise ValueError("layer_type not recognized %s" % layer_type)
+ + Args: + x: a mtf.Tensor with shape [, length_dim, model_dim] + layers: an list of strings + encoder_output: an optional mtf.Tensor with shape + [, encoder_length_dim, model_dim] + self_attention_mask: an optional mtf.Tensor with shape + [batch, length_dim, memory_length_dim] containing values 0 or -inf. + encdec_attention_mask: an optional mtf.Tensor with shape + [batch, length_dim, encoder_length_dim] containing values 0 or -inf. + losses: a list to be appended-to + step_num: an optional mtf integer Scalar (used in incrmenental mode) + encdec_tensors: an optional list of num_layers tuples, each of the form + (q_var, o_var, k, v), (used in incremental mode) + states: an optional list of Tensors (used in incremental mode) + Returns: + a mtf.Tensor with shape [, length_dim, model_dim] + Raises: + ValueError: if hparams make no sense + """ + hparams = self._hparams + is_incremental = (step_num is not None) + mode = getattr(hparams, "mode", tf_estimator.ModeKeys.TRAIN) + is_training = mode == tf_estimator.ModeKeys.TRAIN + def layer_prepostprocess_dropout(x): + if is_incremental: + return x + return mtf.dropout( + x, is_training, keep_prob=1.0 - hparams.layer_prepostprocess_dropout, + noise_shape=mtf.Shape(self.batch_dims + [self.model_dim])) + num_layers = len(layers) + num_layer_norms = num_layers + 1 + layer_norms_dim = mtf.Dimension("layer_norms", num_layer_norms) + layer_norm_combined_var = mtf.get_variable( + x.mesh, + "layer_norm_scale", + mtf.Shape([layer_norms_dim, self.model_dim]), + initializer=tf.ones_initializer(), + activation_dtype=x.dtype) + layer_norm_vars = mtf.unstack(layer_norm_combined_var, layer_norms_dim) + def normalize(x): + scale = layer_norm_vars.pop(0) + variance = mtf.reduce_mean(mtf.square(x), reduced_dim=self.model_dim) + return x * mtf.rsqrt(variance + hparams.norm_epsilon) * scale + + if is_incremental: + states = list(states) + new_states = [] + tf.logging.info("states = %s" % (states,)) + + for lnum, layer_type in 
enumerate(layers): + with tf.variable_scope("%s_%d" % (layer_type, lnum)): + if layer_type == "att": + # Self attention layer + if is_incremental: + y, new_k, new_v = mtf.layers.multihead_self_attention_incremental( + normalize(x), + prev_k=states.pop(0), + prev_v=states.pop(0), + step_num=step_num, + master_dtype=self.master_dtype, + slice_dtype=self.slice_dtype, + name="att") + new_states.append(new_k) + new_states.append(new_v) + x += y + else: + x += layer_prepostprocess_dropout( + mtf.layers.multihead_attention( + normalize(x), None, + self_attention_mask, self.kv_dim, self.heads_dim, + is_training, + dropout=hparams.attention_dropout, + dropout_broadcast_dims=[self.length_dim], + master_dtype=self.master_dtype, + slice_dtype=self.slice_dtype, + name="att")) + elif layer_type == "enc_att": + # Encoder-Decoder attention layer + if is_incremental: + # Encoder-Decoder attention layer + q_var, o_var, k, v = encdec_tensors[lnum] + x += mtf.layers.multihead_encdec_attention_incremental( + normalize(x), + q_var, o_var, k, v, + encdec_attention_mask, + name="enc_att") + else: + x += layer_prepostprocess_dropout( + mtf.layers.multihead_attention( + normalize(x), encoder_output, + encdec_attention_mask, self.kv_dim, self.heads_dim, + is_training, + dropout=hparams.attention_dropout, + dropout_broadcast_dims=[self.length_dim], + master_dtype=self.master_dtype, + slice_dtype=self.slice_dtype, + name="enc_att")) + elif layer_type == "local_att": + if is_incremental: + y, new_k, new_v = mtf.layers.masked_local_attention_1d_incremental( + normalize(x), + prev_k=states.pop(0), + prev_v=states.pop(0), + step_num=step_num, + master_dtype=self.master_dtype, + slice_dtype=self.slice_dtype, + name="local_att") + new_states.append(new_k) + new_states.append(new_v) + x += y + else: + x += layer_prepostprocess_dropout( + mtf.layers.masked_local_attention_1d( + normalize(x), + self.kv_dim, self.heads_dim, is_training, + window_size=hparams.local_attention_window_size, + 
def sample(self, features, mesh):
  """Runs `_sample` inside the "transformer" variable scope.

  Args:
    features: passed through to `_sample`.
    mesh: passed through to `_sample`.

  Returns:
    Whatever `_sample(features, mesh)` returns.
  """
  with tf.variable_scope("transformer"):
    samples = self._sample(features, mesh)
  return samples
inputs = self._import_to_batch_by_length( + inputs, "inputs", mesh, hparams) + x = (mtf.gather(inputs_embedding_var, inputs, self.inputs_vocab_dim) + + mtf.reshape(positional_embedding_var, + mtf.Shape([self.length_dim, self.model_dim]))) + encoder_attention_mask = ( + mtf.layers.attention_mask_ignore_padding( + inputs, dtype=self.activation_dtype)) + with tf.variable_scope("encoder"): + x = self._layer_stack(x, + hparams.encoder_layers, + self_attention_mask=encoder_attention_mask) + encoder_output = mtf.rename_dimension( + x, self.length_dim.name, self.memory_length_dim.name) + encdec_tensors = [] + for layer_num, layer_type in enumerate(hparams.decoder_layers): + if layer_type == "enc_att": + with tf.variable_scope("decoder/enc_att_%d/enc_att" % layer_num): + q_var, k_var, v_var, o_var = mtf.layers.multihead_attention_vars( + mesh, self.heads_dim, self.model_dim, + self.kv_dim, self.master_dtype, self.slice_dtype, + self.activation_dtype) + k = mtf.einsum( + [encoder_output, k_var], + mtf.Shape( + self.batch_dims + [self.heads_dim, + self.memory_length_dim, self.kv_dim])) + v = mtf.einsum( + [encoder_output, v_var], + mtf.Shape( + self.batch_dims + [self.heads_dim, + self.memory_length_dim, self.kv_dim])) + encdec_tensors.append((q_var, o_var, k, v)) + else: + encdec_tensors.append(None) + partial_targets = None + elif hparams.transformer_type == "decoder": + encdec_tensors = None + encoder_output = None + encoder_attention_mask = None + # Prepare partial targets. + # In either features["inputs"] or features["targets"]. + # We force the outputs to begin with these sequences. 
+ partial_targets = features.get("inputs", None) + if partial_targets is None: + partial_targets = features.get("targets", None) + if partial_targets is not None: + partial_targets = common_layers.expand_squeeze_to_nd(partial_targets, 2) + partial_targets = tf.to_int32(partial_targets) + partial_targets_batch = tf.shape(partial_targets)[0] + partial_targets_length = tf.shape(partial_targets)[1] + partial_targets = tf.pad( + partial_targets, [[0, hparams.batch_size - partial_targets_batch], + [0, hparams.max_length - partial_targets_length]]) + partial_targets = self._import_to_batch_by_length( + partial_targets, "partial_targets", mesh, hparams) + else: + raise ValueError( + "hparams.model_type = %s not yet supported" + % hparams.transformer_type) + + local_attention_window = mtf.Dimension( + "local_attention_window", hparams.local_attention_window_size) + if hparams.beam_size == 1: + ids_shape = mtf.Shape(self.batch_dims + [self.length_dim]) + kv_shape = mtf.Shape(self.batch_dims + + [self.heads_dim, + self.memory_length_dim, self.kv_dim]) + local_kv_shape = mtf.Shape(self.batch_dims + + [self.heads_dim, + local_attention_window, self.kv_dim]) + else: + beam_dim = mtf.Dimension("beam", hparams.beam_size) + ids_shape = mtf.Shape(self.batch_dims + [beam_dim, self.length_dim]) + kv_shape = mtf.Shape(self.batch_dims + + [beam_dim, self.heads_dim, + self.memory_length_dim, self.kv_dim]) + local_kv_shape = mtf.Shape(self.batch_dims + + [beam_dim, self.heads_dim, + local_attention_window, self.kv_dim]) + + initial_ids = mtf.constant(mesh, 0, ids_shape, dtype=tf.int32) + initial_states = [] + for layer in hparams.decoder_layers: + if layer == "att": + initial_states.extend( + [mtf.zeros(mesh, kv_shape, dtype=self.activation_dtype)] * 2) + elif layer == "local_att": + initial_states.extend( + [mtf.zeros(mesh, local_kv_shape, dtype=self.activation_dtype)] * 2) + + def logits_fn(step_num, ids, states): + """Produce logits for this step, and new states.""" + ids_this_step = 
@registry.register_hparams
def mtf_transformer_base():
  """Base set of hyperparameters for the mesh-tensorflow Transformer.

  Starts from common_hparams.basic_params1(), overrides some of its values
  and registers the additional hyperparameters listed below. The other
  hparams functions in this file start from this set and override a few
  fields.

  Returns:
    a hparams object
  """
  hparams = common_hparams.basic_params1()
  hparams.no_data_parallelism = True
  hparams.use_fixed_batch_size = True
  hparams.add_hparam("mtf_mode", True)
  hparams.batch_size = 64
  hparams.max_length = 256
  hparams.add_hparam("d_model", 512)
  hparams.add_hparam("d_kv", 128)
  hparams.add_hparam("local_attention_window_size", 128)
  hparams.label_smoothing = 0.1
  # 8-way model-parallelism
  hparams.add_hparam("mesh_shape", "model:8")
  hparams.add_hparam("layout", "batch:batch;vocab:model;d_ff:model;heads:model")
  hparams.add_hparam("num_heads", 8)
  hparams.add_hparam("d_ff", 2048)
  hparams.add_hparam("encoder_replicate_factor", 1)
  hparams.add_hparam("decoder_replicate_factor", 1)
  hparams.add_hparam("encoder_layers", ["att", "drd"] * 6)
  hparams.add_hparam("decoder_layers", ["att", "enc_att", "drd"] * 6)
  hparams.add_hparam("attention_dropout", 0.1)
  hparams.add_hparam("relu_dropout", 0.1)
  hparams.layer_prepostprocess_dropout = 0.1

  # Describes what model architecture:
  #   "encdec": encoder + autoregressive decoder
  #   "decoder": single-stack autoregressive sequence model.
  #   "encoder": single-stack non-autoregressive model
  #      with equal-length inputs and outputs.
  hparams.add_hparam("transformer_type", "encdec")

  # What does the decoder do:
  #   "autoregressive": Decoder left to right
  #   "denoising": Fills in masked-out values simultaneously
  hparams.add_hparam("decoder_type", "autoregressive")

  # Parameters describing the noising algorithm for denoising decoders
  hparams.add_hparam("noising_spec_train", {"type": "mask", "prob": 0.15})
  hparams.add_hparam("noising_spec_eval", {"type": "mask", "prob": 0.15})
  # during training, we use the eval noiser with this probability
  hparams.add_hparam("noising_use_eval_during_train", 0.1)

  # round up vocab sizes to be a multiple of this value
  hparams.vocab_divisor = 128

  # options are dense_relu_dense, moe, hmoe
  hparams.add_hparam("feedforward_layer", "drd")

  # If True, then reuse targets_embedding_var * rsqrt(d_model) as softmax_var
  # If hparams.transformer_type == "encoder", then there is no targets embedding
  # so we reuse the inputs embedding instead.
  hparams.shared_embedding_and_softmax_weights = True
  # Reuse targets_embedding_var as inputs_embedding_var
  # relevant only if hparams.transformer_type == "encdec"
  hparams.shared_embedding = True
  hparams.optimizer = "Adafactor"
  hparams.learning_rate_schedule = "linear_warmup*rsqrt_decay*linear_decay"
  hparams.learning_rate_warmup_steps = 10000
  hparams.add_hparam("master_dtype", "bfloat16")
  hparams.add_hparam("slice_dtype", "float32")
  hparams.activation_dtype = "bfloat16"

  # These parameters make Transformer model compatible with MtfTransformer
  # Do not override these, as mtf_transformer does not support other options.
  hparams.clip_grad_norm = 0.  # i.e. no gradient clipping
  hparams.bottom = {
      "inputs": modalities.identity_bottom,
      "targets": modalities.identity_bottom,
  }
  hparams.top = {
      "targets": modalities.identity_top,
  }

  # Parameters for computing the maximum decode length in beam search.
  # Maximum decode length is:
  #    min(max_length,
  #        decode_length_multiplier * input_length + decode_length_constant)
  hparams.add_hparam("decode_length_multiplier", 1.5)
  hparams.add_hparam("decode_length_constant", 10.0)

  # If nonzero, we split the batch across two tensor-dimensions named
  # "outer_batch" and "inner_batch", allowing for splitting across two mesh
  # dimensions.  This is necessary for hierarchical mixture of experts.
  # The two tensor dimensions have sizes hparams.outer_batch_size and
  # hparams.batch_size // hparams.outer_batch_size.
  hparams.add_hparam("outer_batch_size", 0)

  # TODO(noam): file a bug
  hparams.add_hparam("reshape_logits_hack", False)
  hparams.add_hparam("compression_factor", 4)

  return hparams
@registry.register_hparams
def mtf_transformer_paper_lm_m1():
  """mtf_transformer_paper_lm(-1) with mesh_shape "batch:32"."""
  hparams = mtf_transformer_paper_lm(-1)
  hparams.mesh_shape = "batch:32"
  return hparams


@registry.register_hparams
def mtf_transformer_paper_lm_0():
  """mtf_transformer_paper_lm(0) with mesh_shape "batch:32"."""
  hparams = mtf_transformer_paper_lm(0)
  hparams.mesh_shape = "batch:32"
  return hparams


@registry.register_hparams
def mtf_transformer_paper_lm_1():
  """mtf_transformer_paper_lm(1) with mesh_shape "model:4;batch:8"."""
  hparams = mtf_transformer_paper_lm(1)
  hparams.mesh_shape = "model:4;batch:8"
  return hparams


@registry.register_hparams
def mtf_transformer_paper_lm_2():
  """mtf_transformer_paper_lm(2) with mesh_shape "model:4;batch:8"."""
  hparams = mtf_transformer_paper_lm(2)
  hparams.mesh_shape = "model:4;batch:8"
  return hparams


@registry.register_hparams
def mtf_transformer_paper_lm_3():
  """mtf_transformer_paper_lm(3) with mesh_shape "model:8;batch:16"."""
  hparams = mtf_transformer_paper_lm(3)
  hparams.mesh_shape = "model:8;batch:16"
  return hparams


@registry.register_hparams
def mtf_transformer_paper_lm_4():
  """mtf_transformer_paper_lm(4) with mesh_shape "batch:16;model:32"."""
  hparams = mtf_transformer_paper_lm(4)
  hparams.mesh_shape = "batch:16;model:32"
  return hparams
def mtf_transformer_paper_tr(size):
  """Build the hyperparameters for one size of the translation experiments.

  Train these on translate_enfr_wmt32k_packed for 154000 steps (3 epochs).

  The size parameter is an integer that controls the number of heads and the
  size of the feedforward hidden layers: both are multiplied by 2 ** size, so
  increasing size by 1 doubles each of them.

  Args:
    size: an integer

  Returns:
    a hparams object
  """
  scale = 2 ** size
  hparams = mtf_transformer_base()
  overrides = (
      ("label_smoothing", 0.1),
      ("batch_size", 128),
      ("d_model", 1024),
      ("d_ff", int(4096 * scale)),
      ("num_heads", int(8 * scale)),
      ("shared_embedding_and_softmax_weights", False),
      # one epoch for translate_enfr_wmt32k_packed = 51400 steps
      ("learning_rate_decay_steps", 51400),
  )
  for hparam_name, hparam_value in overrides:
    setattr(hparams, hparam_name, hparam_value)
  return hparams
@registry.register_hparams
def mtf_transformer_paper_tr_0_mesh_8_v2():
  """Size-0 translation config with a quarter of the batch on "batch:8"."""
  hparams = mtf_transformer_paper_tr(0)
  # Same arithmetic as the base config's batch size divided by four and
  # truncated to an integer.
  quarter_batch = int(hparams.batch_size / 4)
  hparams.mesh_shape = "batch:8"
  hparams.batch_size = quarter_batch
  return hparams
@registry.register_model
class MtfUnitransformer(mtf_model.MtfModel):
  """Single-stack Transformer (Transformer Decoder) in mesh_tensorflow.

  Can optionally be autoregressive (language generation) or non-autoregressive
  like BERT.
  """

  @property
  def batch_dims(self):
    """Returns the list of mtf.Dimension objects that make up the batch.

    A single "batch" dimension when hparams.outer_batch_size is 0, otherwise
    an "outer_batch" and an "inner_batch" dimension whose sizes multiply to
    hparams.batch_size.

    Raises:
      ValueError: if outer_batch_size does not divide batch_size.
    """
    hparams = self._hparams
    if hparams.outer_batch_size == 0:
      return [mtf.Dimension("batch", hparams.batch_size)]
    else:
      if hparams.batch_size % hparams.outer_batch_size != 0:
        raise ValueError(
            "hparams.outer_batch_size must divide hparams.batch_size")
      return [
          mtf.Dimension("outer_batch", hparams.outer_batch_size),
          mtf.Dimension("inner_batch",
                        hparams.batch_size // hparams.outer_batch_size)]

  def combine_batch_dims(self, x):
    """Merges the batch dimensions of x into one; no-op for a single dim."""
    if len(self.batch_dims) <= 1:
      return x
    return mtf.replace_dimensions(
        x, self.batch_dims, mtf.combined_dimension(self.batch_dims))

  @property
  def autoregressive(self):
    """The value of hparams.autoregressive."""
    return self._hparams.autoregressive

  @property
  def variable_dtype(self):
    """mtf.VariableDType built from the master/slice/activation hparams."""
    return mtf.VariableDType(
        tf.as_dtype(self._hparams.master_dtype),
        tf.as_dtype(self._hparams.slice_dtype),
        tf.as_dtype(self._hparams.activation_dtype))

  @property
  def length_dim(self):
    """The "length" dimension: hparams.length, or max_length if that is 0."""
    return mtf.Dimension(
        "length", self._hparams.length or self._hparams.max_length)

  def _import_to_batch_by_length(self, x, name, mesh):
    """Reshapes x to batch_dims + [length_dim] and imports it into mesh."""
    mtf_shape = mtf.Shape(self.batch_dims + [self.length_dim])
    x = tf.reshape(x, mtf_shape.to_integer_list)
    return mtf.import_fully_replicated(mesh, x, mtf_shape, name=name)

  def _import_feature(self, features, mesh, key):
    """Import a feature from the features dictionary into a mtf.Tensor.

    The feature is cropped to at most length_dim.size positions and then
    padded in both batch and length up to the full mtf shape.

    Args:
      features: a features dictionary
      mesh: a Mesh
      key: a string

    Returns:
      a mtf.Tensor with dtype int32 and shape self.batch_dims + self.length_dim,
      or None if key is not present in features.
    """
    if key not in features:
      return None
    x = tf.to_int32(features[key])
    x = common_layers.expand_squeeze_to_nd(x, 2)
    batch_size = mtf.Shape(self.batch_dims).size
    x = x[:, :self.length_dim.size]
    extra_length = self.length_dim.size - tf.shape(x)[1]
    extra_batch = batch_size - tf.shape(x)[0]
    x = tf.pad(x, [[0, extra_batch], [0, extra_length]])
    mtf_shape = mtf.Shape(self.batch_dims + [self.length_dim])
    x = tf.reshape(x, mtf_shape.to_integer_list)
    return mtf.import_fully_replicated(mesh, x, mtf_shape, name=key)

  def model(self):
    """Builds the transformer.Unitransformer described by the hparams.

    Raises:
      NotImplementedError: if hparams.label_smoothing is nonzero.
    """
    hparams = self._hparams
    if hparams.label_smoothing != 0:
      raise NotImplementedError(
          "Label smoothing not implemented in unitransformer."
          " Do you really want it?")
    layer_stack = layer_stack_from_hparams(hparams, "")
    # An autoregressive model consumes (shifted) targets as its input, so its
    # input vocabulary is the targets vocabulary.
    if self.autoregressive:
      input_vocab_size = self._targets_vocab_size
    else:
      input_vocab_size = self._inputs_vocab_size
    return transformer.Unitransformer(
        layer_stack=layer_stack,
        d_model=hparams.d_model,
        input_vocab_size=input_vocab_size,
        output_vocab_size=self._targets_vocab_size,
        autoregressive=self.autoregressive,
        max_length=hparams.max_length,
        shared_embedding_and_softmax_weights=(
            hparams.shared_embedding_and_softmax_weights),
        z_loss=hparams.z_loss,
        layout=hparams.layout,
        mesh_shape=hparams.mesh_shape)

  def _mtf_model_fn(self, features, mesh):
    """Imports the features and runs the model, returning (logits, loss)."""
    self._original_features = features
    hparams = self._hparams
    def import_feature(key):
      return self._import_feature(features, mesh, key)
    targets = import_feature("targets")
    sequence_id = import_feature("targets_segmentation")
    if hparams.use_global_position_in_packed_sequence:
      position = None
    else:
      position = import_feature("targets_position")
    if self.autoregressive:
      inputs = mtf.shift(
          targets, offset=1, dim=self.length_dim, wrap=False)
      # We should have a 0 at the beginning of each sequence rather than the
      # shifted EOS (1) from the previous sequence.
      inputs -= mtf.to_int32(mtf.equal(inputs, 1))
    else:
      inputs = import_feature("inputs")
      # TODO(noam): options for bert-style masking here?
    model = self.model()
    logits, loss = model.call_simple(
        inputs=inputs,
        targets=targets,
        compute_loss=True,
        mode=hparams.mode,
        variable_dtype=self.variable_dtype,
        sequence_id=sequence_id,
        position=position)
    return logits, loss

  def mtf_model_fn(self, features, mesh):
    """Runs _mtf_model_fn and merges the batch dimensions of the logits."""
    logits, loss = self._mtf_model_fn(features, mesh)
    # combine batch dims
    logits = self.combine_batch_dims(logits)
    return logits, loss

  @property
  def _targets_vocab_size(self):
    """Targets vocab size rounded up to a multiple of hparams.vocab_divisor."""
    targets_vocab_size = self._problem_hparams.vocab_size["targets"]
    targets_vocab_size += (-targets_vocab_size) % self._hparams.vocab_divisor
    return targets_vocab_size

  @property
  def _inputs_vocab_size(self):
    """Inputs vocab size rounded up to a multiple of hparams.vocab_divisor."""
    inputs_vocab_size = self._problem_hparams.vocab_size["inputs"]
    inputs_vocab_size += (-inputs_vocab_size) % self._hparams.vocab_divisor
    return inputs_vocab_size

  def sample(self, features, mesh):
    """Samples autoregressively, optionally forced to start with a prefix.

    Args:
      features: a features dictionary
      mesh: a Mesh

    Returns:
      the result of model.sample_autoregressive with batch dims combined.

    Raises:
      NotImplementedError: if hparams.beam_size > 1.
      ValueError: if the model is not autoregressive.
    """
    hparams = self._hparams
    model = self.model()
    def import_feature(key):
      return self._import_feature(features, mesh, key)

    if self.autoregressive:
      # Prepare partial targets.
      # In either features["inputs"] or features["targets"].
      # We force the outputs to begin with these sequences.
      partial_targets = import_feature("inputs")
      if partial_targets is None:
        partial_targets = import_feature("targets")
      # _import_feature signals a missing feature with None, so test for
      # None explicitly instead of relying on the truthiness of a tensor.
      if partial_targets is not None:
        # Zero out positions holding id 1.
        partial_targets *= mtf.cast(
            mtf.not_equal(partial_targets, 1), partial_targets.dtype)
      else:
        ids_shape = mtf.Shape(self.batch_dims + [self.length_dim])
        partial_targets = mtf.constant(mesh, 0, ids_shape, dtype=tf.int32)
      if hparams.beam_size > 1:
        raise NotImplementedError(
            "Beam search not implemented for unitransformer.")
      ret = model.sample_autoregressive(
          partial_targets,
          temperature=hparams.sampling_temp,
          variable_dtype=self.variable_dtype)
      return self.combine_batch_dims(ret)
    else:
      raise ValueError(
          "Don't know how to sample from non-autoregressive unitransformer")


@registry.register_model
class MtfBitransformer(MtfUnitransformer):
  """Encoder-Decoder Transformer in mesh_tensorflow."""

  def model(self):
    """Builds a transformer.Bitransformer from "encoder_"/"decoder_" hparams."""
    hparams = self._hparams
    encoder_layer_stack = layer_stack_from_hparams(hparams, "encoder_")
    decoder_layer_stack = layer_stack_from_hparams(hparams, "decoder_")
    encoder = transformer.Unitransformer(
        layer_stack=encoder_layer_stack,
        d_model=hparams.d_model,
        input_vocab_size=self._inputs_vocab_size,
        output_vocab_size=None,
        autoregressive=False,
        max_length=hparams.max_length,
        name="encoder",
        layout=hparams.layout,
        mesh_shape=hparams.mesh_shape,
    )
    decoder = transformer.Unitransformer(
        layer_stack=decoder_layer_stack,
        d_model=hparams.d_model,
        input_vocab_size=self._targets_vocab_size,
        output_vocab_size=self._targets_vocab_size,
        autoregressive=True,
        max_length=hparams.max_length,
        label_smoothing=hparams.label_smoothing,
        shared_embedding_and_softmax_weights=(
            hparams.shared_embedding_and_softmax_weights),
        z_loss=hparams.z_loss,
        name="decoder",
        layout=hparams.layout,
        mesh_shape=hparams.mesh_shape,
    )
    return transformer.Bitransformer(
        encoder, decoder, shared_embedding=hparams.shared_embedding)

  def _mtf_model_fn(self, features, mesh):
    """Imports the features and runs the model, returning (logits, loss).

    Raises:
      ValueError: if features has no "inputs" entry.
    """
    self._original_features = features
    hparams = self._hparams
    def import_feature(key):
      return self._import_feature(features, mesh, key)
    targets = import_feature("targets")
    inputs = import_feature("inputs")
    # import_feature returns None for an absent key; compare against None
    # explicitly rather than evaluating a tensor as a boolean.
    if inputs is None:
      raise ValueError("inputs feature is missing")
    encoder_sequence_id = import_feature("inputs_segmentation")
    if encoder_sequence_id is None:
      # No segmentation feature: derive ids from nonzero input positions.
      encoder_sequence_id = mtf.to_int32(mtf.not_equal(inputs, 0))
    decoder_sequence_id = import_feature("targets_segmentation")
    if decoder_sequence_id is None:
      decoder_sequence_id = mtf.to_int32(mtf.not_equal(targets, 0))
    if hparams.use_global_position_in_packed_sequence:
      encoder_position = None
      decoder_position = None
    else:
      encoder_position = import_feature("inputs_position")
      decoder_position = import_feature("targets_position")
    model = self.model()
    logits, loss = model.call_simple(
        inputs=inputs,
        targets=targets,
        compute_loss=True,
        mode=hparams.mode,
        variable_dtype=self.variable_dtype,
        encoder_sequence_id=encoder_sequence_id,
        decoder_sequence_id=decoder_sequence_id,
        encoder_position=encoder_position,
        decoder_position=decoder_position)
    return logits, loss

  def sample(self, features, mesh):
    """Decodes outputs for features["inputs"] with model.decode.

    The sampling temperature is hparams.sampling_temp when beam_size is 1 and
    0 otherwise.
    """
    hparams = self._hparams
    model = self.model()
    inputs = self._import_feature(features, mesh, "inputs")
    ret = model.decode(
        inputs,
        self.variable_dtype,
        beam_size=hparams.beam_size,
        alpha=hparams.alpha,
        temperature=hparams.sampling_temp if hparams.beam_size == 1 else 0,
        decode_length_multiplier=hparams.decode_length_multiplier,
        decode_length_constant=hparams.decode_length_constant)
    return self.combine_batch_dims(ret)
@layers_registry.register("local_self_att")
def local_self_attention_layer(hparams, prefix):
  """Create a local self-attention layer based on hyperparameters.

  Args:
    hparams: a hparams object
    prefix: a string prepended to the "num_heads", "num_memory_heads" and
      "shared_kv" hyperparameter names when they are looked up

  Returns:
    a transformer_layers.LocalSelfAttention
  """
  layer_kwargs = {
      "num_heads": hparams.get(prefix + "num_heads"),
      "num_memory_heads": hparams.get(prefix + "num_memory_heads"),
      # The radius and key/value size are not prefixed.
      "radius": hparams.local_attention_radius,
      "key_value_size": hparams.d_kv,
      "shared_kv": hparams.get(prefix + "shared_kv", False),
      "attention_kwargs": attention_kwargs_from_hparams(hparams),
  }
  return transformer_layers.LocalSelfAttention(**layer_kwargs)
def mtf_transformer2_base():
  """Hyperparameters common to both unitransformer and bitransformer.

  Starts from common_hparams.basic_params1(), removes "num_heads" and
  "num_hidden_layers", and adds the settings shared by the unitransformer
  and bitransformer base configurations, which add their own head-count
  and layer-list hyperparameters on top.

  Returns:
    a hparams object
  """
  hparams = common_hparams.basic_params1()

  hparams.add_hparam("d_model", 1024)
  hparams.batch_size = 4
  hparams.max_length = 1024
  hparams.label_smoothing = 0.0
  # a small positive value - this seems important for stability when training
  # with bfloat16 activations.
  hparams.add_hparam("z_loss", 1e-4)

  # hparams applying to both encoder and decoder layer stacks.
  hparams.add_hparam("d_ff", 2048)
  hparams.add_hparam("d_kv", 128)
  hparams.add_hparam("attention_dropout", 0.0)
  hparams.add_hparam("relu_dropout", 0.0)
  hparams.del_hparam("num_heads")
  hparams.del_hparam("num_hidden_layers")
  hparams.layer_prepostprocess_dropout = 0.0
  hparams.add_hparam("extra_logit", False)
  # number of experts for moe_1d
  hparams.moe_num_experts = 32
  # number of experts for moe_2d = moe_expert_x * moe_expert_y
  hparams.add_hparam("moe_expert_x", 8)
  hparams.add_hparam("moe_expert_y", 4)
  hparams.add_hparam("moe_hidden_size", 32768)

  # round up vocab sizes to be a multiple of this value
  hparams.vocab_divisor = 128

  hparams.optimizer = "Adafactor"
  hparams.learning_rate_schedule = "rsqrt_decay*linear_decay"
  hparams.learning_rate_warmup_steps = 10000
  hparams.add_hparam("master_dtype", "bfloat16")
  hparams.add_hparam("slice_dtype", "float32")
  hparams.activation_dtype = "bfloat16"

  # 8-way model-parallelism
  hparams.add_hparam("mesh_shape", "model:8")
  hparams.add_hparam("layout", "batch:batch;vocab:model;d_ff:model;heads:model")

  # If nonzero, we split the batch across two tensor-dimensions named
  # "outer_batch" and "inner_batch", allowing for splitting across two mesh
  # dimensions.  This is necessary for hierarchical mixture of experts.
  # The two tensor dimensions have sizes hparams.outer_batch_size and
  # hparams.batch_size // hparams.outer_batch_size.
  hparams.add_hparam("outer_batch_size", 0)

  hparams.shared_embedding_and_softmax_weights = False
  # length for training or decoding - defaults to max_length
  hparams.add_hparam("length", 0)

  # These parameters make Transformer model compatible with mtf
  # Do not override these.
  hparams.no_data_parallelism = True
  hparams.use_fixed_batch_size = True
  hparams.add_hparam("mtf_mode", True)
  hparams.clip_grad_norm = 0.  # i.e. no gradient clipping
  hparams.bottom = {
      "inputs": modalities.identity_bottom,
      "targets": modalities.identity_bottom,
  }
  hparams.top = {
      "targets": modalities.identity_top,
  }
  hparams.add_hparam("beam_size", 1)

  # If this is True, then in a packed dataset (where examples are concatenated
  # to form longer examples) we use the global position (within the concatenated
  # sequence) to compute the positional embedding, instead of the position
  # within the individual sequence.  This is counterintuitive, but for some
  # reason, it keeps the model from diverging.
  hparams.add_hparam("use_global_position_in_packed_sequence", True)

  return hparams
@registry.register_hparams
def mtf_bitransformer_tiny():
  """Small encoder-decoder model for testing.

  Returns:
    a hparams object
  """
  hparams = mtf_bitransformer_base()
  hparams.batch_size = 2
  hparams.mesh_shape = ""
  hparams.d_model = 128
  hparams.encoder_layers = ["self_att", "drd"] * 2
  hparams.decoder_layers = ["self_att", "enc_att", "drd"] * 2
  # The bitransformer attention layers look up "encoder_num_heads" and
  # "decoder_num_heads"; an unprefixed "num_heads" is not read by either
  # layer stack, so the head count has to be set per stack to take effect.
  hparams.encoder_num_heads = 4
  hparams.decoder_num_heads = 4
  hparams.d_ff = 512
  return hparams
@registry.register_hparams
def mtr_lm_dense_0_h1_16():
  """`mtr_lm_dense_0` with 16 query heads sharing one memory head.

  Returns:
    a hparams
  """
  hparams = mtr_lm_dense_0()
  # The language-model configs are built on mtf_unitransformer_base(), whose
  # registered names are "num_heads" / "num_memory_heads". The "decoder_*"
  # names belong to the bitransformer hparams, so assigning them here did
  # not change the attention configuration.
  hparams.num_heads = 16
  hparams.num_memory_heads = 1
  return hparams
def mtr_tr_dense(sz):
  """Build one member of the dense machine-translation model series.

  The feed-forward width and both head counts scale with `2 ** sz`; every
  other setting is fixed. `max_length` is 256 for all sizes.

  You can use the dataset translate_enfr_wmt32k_packed.
  154000 steps = 3 epochs.

  Args:
    sz: an integer

  Returns:
    a hparams
  """
  width_multiplier = 2 ** sz
  hparams = mtf_bitransformer_base()

  # Model size.
  hparams.d_model = 1024
  hparams.max_length = 256
  hparams.batch_size = 128
  hparams.d_ff = int(4096 * width_multiplier)
  hparams.d_kv = 128
  heads_per_stack = int(8 * width_multiplier)
  hparams.encoder_num_heads = heads_per_stack
  hparams.decoder_num_heads = heads_per_stack

  # one epoch for translate_enfr_wmt32k_packed = 51400 steps
  hparams.learning_rate_decay_steps = 51400

  # Mesh and layout.
  hparams.layout = "batch:batch;vocab:model;d_ff:model;heads:model"
  hparams.mesh_shape = "batch:32"

  # Regularization: label smoothing and every dropout are all set to 0.1.
  hparams.label_smoothing = 0.1
  for dropout_name in ("layer_prepostprocess_dropout",
                       "attention_dropout",
                       "relu_dropout"):
    setattr(hparams, dropout_name, 0.1)
  return hparams
def mtr_tr_dense_local(sz):
  """Variant of `mtr_tr_dense(sz)` with local self-attention in the decoder.

  Args:
    sz: an integer, forwarded unchanged to `mtr_tr_dense`.

  Returns:
    a hparams
  """
  hparams = mtr_tr_dense(sz)
  # Narrow the attention window and swap the decoder's self-attention layers
  # for their local counterpart (six repeats of the three-layer pattern).
  hparams.local_attention_radius = 32
  hparams.decoder_layers = ["local_self_att", "enc_att", "drd"] * 6
  return hparams
@registry.register_hparams
def mtr_tr_ende_v0():
  """Good parameters for wmt-en-de.

  Built on `mtr_tr_dense_local_0_h1_16()` with a shorter learning-rate decay
  schedule, tied embedding/softmax weights and a higher pre/post-process
  dropout.

  Returns:
    a hparams
  """
  hparams = mtr_tr_dense_local_0_h1_16()
  hparams.layer_prepostprocess_dropout = 0.2
  hparams.shared_embedding_and_softmax_weights = True
  hparams.learning_rate_decay_steps = 20000
  return hparams
def get_model(hparams=None, mode=tf_estimator.ModeKeys.TRAIN,
              has_input=True, model_cls=mtf_transformer.MtfTransformer):
  """Build a model plus random integer features for the tests in this file.

  Args:
    hparams: hparams to use; when None, `mtf_transformer_single()` is used.
      The object is modified in place (`max_length`, `batch_size` and
      `problem_hparams` are overwritten).
    mode: estimator mode passed to the model constructor.
    has_input: when False, the "inputs" modality is removed from the problem
      hparams and no "inputs" feature is produced.
    model_cls: model class, called as `model_cls(hparams, mode, p_hparams)`.

  Returns:
    A `(model, features, hparams)` tuple.
  """
  if hparams is None:
    hparams = mtf_transformer.mtf_transformer_single()
  hparams.max_length = INPUT_LENGTH
  hparams.batch_size = BATCH_SIZE

  p_hparams = problem_hparams.test_problem_hparams(
      VOCAB_SIZE, VOCAB_SIZE, hparams)
  if not has_input:
    del p_hparams.modality["inputs"]
  hparams.problem_hparams = p_hparams

  # Inputs are always drawn (and drawn first) so the random stream consumed
  # by this helper does not depend on `has_input`.
  input_ids = np.random.randint(
      VOCAB_SIZE, size=(BATCH_SIZE, INPUT_LENGTH, 1, 1))
  target_ids = np.random.randint(
      VOCAB_SIZE, size=(BATCH_SIZE, TARGET_LENGTH, 1, 1))

  features = {}
  features["targets"] = tf.constant(
      target_ids, dtype=tf.int32, name="targets")
  features["target_space_id"] = tf.constant(1, dtype=tf.int32)
  if has_input:
    features["inputs"] = tf.constant(
        input_ids, dtype=tf.int32, name="inputs")

  model = model_cls(hparams, mode, p_hparams)
  return model, features, hparams
class MtfTransformerTest(tf.test.TestCase):
  """Shape checks for the Mesh-TensorFlow Transformer under several layouts.

  Every test builds the model with `get_model`, lowers it onto a placement
  mesh and checks that the exported logits have shape
  (BATCH_SIZE, TARGET_LENGTH, VOCAB_SIZE).
  """

  def _assert_logits_shape(self, hparams, mesh_shape, layout):
    """Runs the model on the given mesh/layout and checks the logits shape.

    Args:
      hparams: hparams for the model under test.
      mesh_shape: string assigned to `hparams.mesh_shape`.
      layout: string assigned to `hparams.layout`.
    """
    model, features, hparams = get_model(hparams)
    # As in the original tests, the mesh settings are applied after the model
    # has been constructed; they are consumed by get_placement_mesh().
    hparams.mesh_shape = mesh_shape
    hparams.layout = layout
    mesh, mesh_impl = get_placement_mesh(hparams)

    logits, _ = model.mtf_model_fn(features, mesh)
    lowering = mtf.Lowering(mesh.graph, {mesh: mesh_impl})
    tf_group = lowering.copy_masters_to_slices()
    tf_logits = lowering.export_to_tf_tensor(logits)

    with self.test_session() as session:
      session.run(tf.global_variables_initializer())
      session.run(tf_group)
      res = session.run(tf_logits)
    self.assertEqual(res.shape, (BATCH_SIZE, TARGET_LENGTH, VOCAB_SIZE))

  def testMtfTransformer(self):
    self._assert_logits_shape(
        mtf_transformer.mtf_transformer_single(), mesh_shape="", layout="")

  def testMtfTransformerDataParallel(self):
    self._assert_logits_shape(
        mtf_transformer.mtf_transformer_single(),
        mesh_shape="all:2",
        layout="batch:all")

  def testMtfTransformerModelParallel(self):
    self._assert_logits_shape(
        mtf_transformer.mtf_transformer_single(),
        mesh_shape="all:2",
        layout="length:all")

  def testMtfTransformerDataModelParallel(self):
    self._assert_logits_shape(
        mtf_transformer.mtf_transformer_single(),
        mesh_shape="batch:2;model:2",
        layout="batch:batch;vocab:model;d_ff:model;heads:model")

  def testMtfTransformerEncoderDataModelParallel(self):
    self._assert_logits_shape(
        mtf_transformer.mtf_transformer_enc_single(),
        mesh_shape="batch:2;model:2",
        layout="batch:batch;vocab:model;d_ff:model;heads:model")
@registry.register_model
class NeuralAssistant(transformer.Transformer):
  """Transformer that scores knowledge-base triples and can attend to them.

  On top of the Transformer encoder/decoder, `body` encodes the triples found
  in the features, computes a triple-selection loss ("kb_loss") and a
  TransE-style loss ("transe_loss"), and, when `hparams.attend_kb` is set,
  prepends the encoded triples to the encoder output so the decoder can
  attend to them. `model_fn` mixes those losses with the language-model loss.
  """

  def __init__(self, *args, **kwargs):
    super(NeuralAssistant, self).__init__(*args, **kwargs)
    self.attention_weights = dict()  # For visualizing attention heads.

    # Loss scheduling.
    hparams = self._hparams
    # Number of triples per example assumed by every reshape in this class.
    self.triple_num = hparams.train_triple_num

  def model_fn(self, features):
    """Runs bottom, body and top, then combines the losses.

    Outside PREDICT and "attack" modes the returned loss dict is replaced by
    a single "training" entry:

      kb_train_weight * transe_loss
      + (1 - kb_train_weight) * (kb_loss_weight * kb_loss
                                 + (1 - kb_loss_weight) * lm_loss)

    Any other entries returned by `body` (for example "attention_loss") are
    not part of that sum.

    Args:
      features: map of feature name to Tensor.

    Returns:
      A `(logits, losses)` tuple, where `losses` is a dict.
    """
    with tf.variable_scope(tf.get_variable_scope(), use_resource=True) as vs:
      self._add_variable_scope("model_fn", vs)
      transformed_features = self.bottom(features)

      if self.hparams.activation_dtype == "bfloat16":
        for k, v in sorted(six.iteritems(transformed_features)):
          if v.dtype == tf.float32:
            transformed_features[k] = tf.cast(v, tf.bfloat16)

      with tf.variable_scope("body") as body_vs:
        self._add_variable_scope("body", body_vs)
        body_out = self.body(transformed_features)
      output, losses = self._normalize_body_output(body_out)

      if "training" in losses:
        tf.logging.info(
            "Skipping T2TModel top and loss because training loss returned from body"
        )
        logits = output
      else:
        tf.logging.warn("The loss will be computed in model_fn now.")
        logits = self.top(output, features)
        losses["training"] = 0.0
        cur_kb_loss = losses["kb_loss"]
        cur_knowledge_training_loss = losses["transe_loss"]
        cur_kb_loss_weight = self._hparams.kb_loss_weight
        kb_train_weight = self._hparams.kb_train_weight
        cur_lm_loss_weight = 1.0 - cur_kb_loss_weight
        # Finalize loss
        if (self._hparams.mode != tf_estimator.ModeKeys.PREDICT and
            self._hparams.mode != "attack"):
          lm_loss_num, lm_loss_denom = self.loss(logits, features)
          total_loss = (kb_train_weight) * cur_knowledge_training_loss + (
              1 - kb_train_weight) * (
                  cur_kb_loss * cur_kb_loss_weight +
                  (lm_loss_num / lm_loss_denom) * cur_lm_loss_weight)
          tf.summary.scalar("kb_loss", cur_kb_loss)
          tf.summary.scalar("transe_loss", cur_knowledge_training_loss)
          tf.summary.scalar("lm_loss", (lm_loss_num / lm_loss_denom))
          tf.summary.scalar("cur_kb_loss_weight",
                            tf.reshape(cur_kb_loss_weight, []))
          tf.logging.info("Loss computed " + str(total_loss))
          losses = {"training": total_loss}

      return logits, losses

  def encode_knowledge_bottom(self, features):
    """Reshapes triple embeddings and lengths to one row per triple.

    Args:
      features: map of features; reads "encoded_triples" and "triple_lens".

    Returns:
      re_fact_embedding: [batch_size*triple_num, max_triple_length, emb_dim]
      re_fact_lengths: [batch_size*triple_num, 1]
    """
    tf.logging.info("Encoding knowledge " + str(self.triple_num))
    # Make sure this is embeddings for triples
    # [batch_size, triple_num*max_triple_length, 1, emb_dim]
    fact_embedding = features["encoded_triples"]
    # [batch_size, triple_num*max_triple_length, emb_dim]
    fact_embedding = tf.squeeze(fact_embedding, 2)

    kb_shape = common_layers.shape_list(fact_embedding)
    batch_size = kb_shape[0]
    embed_dim = kb_shape[2]
    # [batch_size*triple_num, max_triple_length, emb_dim]
    re_fact_embedding = tf.reshape(
        fact_embedding, [batch_size * self.triple_num, -1, embed_dim],
        name="reshape_fact_embedding")

    # [batch_size, triple_num]
    input_fact_lengths = features["triple_lens"]
    # Stack the fact lengths.
    # [batch_size*triple_num, 1]
    re_fact_lengths = tf.reshape(
        input_fact_lengths, [batch_size * self.triple_num, 1],
        name="reshape_fact_lengths")

    return re_fact_embedding, re_fact_lengths

  def compute_knowledge_selection_and_loss(self, features, encoder_output,
                                           fact_embedding, fact_lengths, margin,
                                           num_negative_samples):
    """Compute knowledge selection and loss.

    Args:
      features: features. Reads "inputs", "triple_labels", "subject_mask",
        "predicate_mask" and "object_mask".
      encoder_output: [batch_size, input_length, hidden_dim]
      fact_embedding: [batch_size*triple_num, max_triple_length, emb_dim]
      fact_lengths: [batch_size*triple_num, 1]
      margin: margin added to (positive - negative) in the TransE loss.
      num_negative_samples: number of shuffled negative examples drawn for
        the TransE loss. Must be at least 1: the averaged negative loss is
        divided by `2 * num_negative_samples`.

    Returns:
      triple_logits: [batch_size, triple_num] selection scores.
      avg_triple_loss: mean weighted cross-entropy of the selection scores
        against "triple_labels"; 0.0 in PREDICT mode.
      original_knowledge_encoder_output: [batch_size, triple_num, hidden_dim]
        averaged triple embeddings.
      transe_loss: `margin + positive - negative`, clipped to [0, 100].

    Raises:
      ValueError: if `hparams.similarity_fuction` is neither "dot_product"
        nor "bilinear".
    """
    hparams = self._hparams
    encoder_output_shape = common_layers.shape_list(encoder_output)
    encoder_hidden_dim = encoder_output_shape[-1]
    inputs = features["inputs"]
    # [batch_size, input_length, emb_dim]
    inputs = tf.squeeze(inputs, 2)
    # [batch_size, input_length]
    context_padding = common_attention.embedding_to_padding(inputs)
    # [batch_size]
    context_lens = tf.to_float(
        common_attention.padding_to_length(context_padding))
    # [batch_size, 1]
    context_lens = tf.expand_dims(context_lens, -1)
    # Compute context vector summary.
    # [batch_size, hidden_dim]
    context_vector_summary = compute_summary_embedding(encoder_output,
                                                       context_lens, hparams)
    knowledge_encoder_output = compute_average_embedding(
        fact_embedding, fact_lengths)
    # [batch_size, triple_num, emb_dim]
    knowledge_encoder_output = tf.reshape(
        knowledge_encoder_output, [-1, self.triple_num, encoder_hidden_dim])
    original_knowledge_encoder_output = knowledge_encoder_output
    if hparams.similarity_fuction == "dot_product":
      triple_logits = tf.squeeze(
          tf.matmul(knowledge_encoder_output,
                    tf.expand_dims(context_vector_summary, 2)), -1)
    elif hparams.similarity_fuction == "bilinear":
      # Tile the context vector summary.
      # [batch_size, triple_num*hidden_dim]
      tiled_context_vector = tf.tile(context_vector_summary,
                                     [1, self.triple_num])
      # [batch_size, triple_num, hidden_dim]
      context_vector = tf.reshape(tiled_context_vector,
                                  [-1, self.triple_num, encoder_hidden_dim])
      # compute outer product
      context_vector = tf.expand_dims(context_vector, -1)
      knowledge_encoder_output = tf.expand_dims(knowledge_encoder_output, 2)
      # [batch_size, triple_num, hidden_dim, hidden_dim]
      outer_product = tf.matmul(context_vector, knowledge_encoder_output)
      outer_product = tf.reshape(
          outer_product,
          [-1, self.triple_num, encoder_hidden_dim * encoder_hidden_dim])
      triple_logits = tf.squeeze(
          tf.layers.dense(outer_product, 1, name="knolwedge_final_mlp"), -1)
    else:
      # Without this branch `triple_logits` stays unbound and the failure
      # only shows up further down as an UnboundLocalError.
      raise ValueError(
          "Unknown similarity_fuction: %r (expected 'dot_product' or "
          "'bilinear')" % (hparams.similarity_fuction,))

    avg_triple_loss = 0.0
    triple_labels = features["triple_labels"]

    # Each mask is brought to one row per triple:
    # [batch_size*triple_num, max_triple_length]
    subject_mask = tf.reshape(features["subject_mask"],
                              [-1, self.triple_num, hparams.max_triple_length])
    subject_mask = tf.reshape(subject_mask, [-1, hparams.max_triple_length])

    predicate_mask = tf.reshape(
        features["predicate_mask"],
        [-1, self.triple_num, hparams.max_triple_length])
    predicate_mask = tf.reshape(predicate_mask, [-1, hparams.max_triple_length])

    object_mask = tf.reshape(features["object_mask"],
                             [-1, self.triple_num, hparams.max_triple_length])
    object_mask = tf.reshape(object_mask, [-1, hparams.max_triple_length])

    # Number of masked-in positions per triple: [bs*tn, 1]
    subject_length = tf.cast(
        tf.expand_dims(tf.reduce_sum(subject_mask, -1), 1),
        tf.float32)
    object_length = tf.cast(
        tf.expand_dims(tf.reduce_sum(object_mask, -1), 1), tf.float32)
    predicate_length = tf.cast(
        tf.expand_dims(tf.reduce_sum(predicate_mask, -1), 1), tf.float32)

    # expand dimension 2 to be able to broadcast
    subject_mask = tf.cast(tf.expand_dims(subject_mask, 2), tf.float32)
    predicate_mask = tf.cast(tf.expand_dims(predicate_mask, 2), tf.float32)
    object_mask = tf.cast(tf.expand_dims(object_mask, 2), tf.float32)

    # Masked mean of the token embeddings of each triple part: [bs*tn, d].
    # The 1e-5 keeps the division finite when a mask is all zeros.
    subject_vect = tf.reduce_sum(tf.multiply(
        fact_embedding, subject_mask), 1) / (
            subject_length +
            tf.broadcast_to(tf.constant([1e-5]), tf.shape(subject_length)))
    object_vect = tf.reduce_sum(tf.multiply(fact_embedding, object_mask), 1) / (
        object_length +
        tf.broadcast_to(tf.constant([1e-5]), tf.shape(object_length)))
    predicate_vect = tf.reduce_sum(
        tf.multiply(fact_embedding, predicate_mask), 1) / (
            predicate_length +
            tf.broadcast_to(tf.constant([1e-5]), tf.shape(predicate_length)))

    # Shuffled rows to generate adversarial samples
    shuffled_subject_vect = []
    shuffled_object_vect = []

    for _ in range(num_negative_samples):
      shuffled_subject_vect += [
          tf.gather(subject_vect,
                    tf.random.shuffle(tf.range(tf.shape(subject_vect)[0])))
      ]  # [bs*tn,d]
      shuffled_object_vect += [
          tf.gather(object_vect,
                    tf.random.shuffle(tf.range(tf.shape(object_vect)[0])))
      ]  # [bs*tn,d]

    # KB pretraining loss

    positive_loss = tf.reduce_mean(
        tf.squared_difference(subject_vect + predicate_vect, object_vect))
    negative_loss = 0
    for n_adv in range(num_negative_samples):
      negative_loss += tf.reduce_mean(
          tf.squared_difference(shuffled_subject_vect[n_adv] + predicate_vect,
                                object_vect))
      negative_loss += tf.reduce_mean(
          tf.squared_difference(subject_vect + predicate_vect,
                                shuffled_object_vect[n_adv]))

    # TransE Loss

    # Two negative terms were added per sample, hence the factor of 2.
    negative_loss = negative_loss / (2 * num_negative_samples)

    transe_loss = tf.clip_by_value(
        margin + positive_loss - negative_loss,
        clip_value_min=0,
        clip_value_max=100)
    if hparams.mode != tf_estimator.ModeKeys.PREDICT:
      triple_losses = tf.nn.weighted_cross_entropy_with_logits(
          labels=triple_labels,
          logits=triple_logits,
          pos_weight=hparams.pos_weight)
      avg_triple_loss = tf.reduce_mean(triple_losses)
      tf.summary.scalar("triple_loss", avg_triple_loss)

    return (triple_logits, avg_triple_loss, original_knowledge_encoder_output,
            transe_loss)

  def body(self, features):
    """Transformer main model_fn.

    Args:
      features: Map of features to the model. Should contain the following:
          "inputs": Transformer inputs [batch_size, input_length, hidden_dim]
          "targets": Target decoder outputs. [batch_size, decoder_length,
            hidden_dim]
          "target_space_id": A scalar int from data_generators.problem.SpaceID.
        The knowledge features read by `encode_knowledge_bottom` and
        `compute_knowledge_selection_and_loss` are required as well.

    Returns:
      A `(output, losses)` tuple. `output` is the final decoder
      representation and `losses` always holds "kb_loss" and "transe_loss";
      "attention_loss" is added when `features` has "expected_attentions".
    """
    tf.logging.info("Using PgScratch BODY function.")
    hparams = self._hparams

    losses = {}
    inputs = features["inputs"]
    target_space = features["target_space_id"]
    # encoder_output: [batch_size, input_length, hidden_dim]
    # encoder_decoder_attention_bias: [batch_size, input_length]
    encoder_output, encoder_decoder_attention_bias = self.encode(
        inputs, target_space, hparams, features=features, losses=losses)

    with tf.variable_scope("knowledge"):
      with tf.name_scope("knowledge_encoding"):
        # Encode knowledge.
        # [batch_size*triple_num, max_triple_length, emb_dim]
        fact_embedding, fact_lengths = self.encode_knowledge_bottom(features)
        tf.logging.info("Encoded knowledge")

      with tf.name_scope("knowledge_selection_and_loss"):
        # Compute knowledge selection and loss.
        (triple_logits, avg_triple_selection_loss, knowledge_encoder_output,
         transe_loss) = self.compute_knowledge_selection_and_loss(
             features, encoder_output, fact_embedding, fact_lengths,
             hparams.margin, hparams.num_negative_samples)
        losses["kb_loss"] = avg_triple_selection_loss
        losses["transe_loss"] = transe_loss

    if hparams.attend_kb:
      tf.logging.info("ATTEND_KB is ACTIVE")
      with tf.name_scope("knowledge_attention"):
        # All-zero padding: every triple is visible to the decoder.
        knowledge_padding = tf.zeros_like(triple_logits, dtype=tf.float32)
        knowledge_attention_bias = common_attention.attention_bias_ignore_padding(
            knowledge_padding)
        # Triples go in front of the encoder positions, for both the memory
        # and its attention bias.
        encoder_output = tf.concat([knowledge_encoder_output, encoder_output],
                                   1)
        encoder_decoder_attention_bias = tf.concat(
            [knowledge_attention_bias, encoder_decoder_attention_bias], -1)

    else:
      tf.logging.info("ATTEND_KB is INACTIVE")

    targets = features["targets"]
    targets_shape = common_layers.shape_list(targets)
    targets = common_layers.flatten4d3d(targets)

    (decoder_input,
     decoder_self_attention_bias) = transformer.transformer_prepare_decoder(
         targets, hparams, features=features)

    decode_kwargs = {}
    decoder_output = self.decode(
        decoder_input,
        encoder_output,
        encoder_decoder_attention_bias,
        decoder_self_attention_bias,
        hparams,
        nonpadding=transformer.features_to_nonpadding(features, "targets"),
        losses=losses,
        **decode_kwargs)

    expected_attentions = features.get("expected_attentions")
    if expected_attentions is not None:
      attention_loss = common_attention.encoder_decoder_attention_loss(
          expected_attentions, self.attention_weights,
          hparams.expected_attention_loss_type,
          hparams.expected_attention_loss_multiplier)
      # Keep "kb_loss" / "transe_loss" in the returned dict: model_fn indexes
      # both, so returning only the attention loss raised a KeyError there.
      losses["attention_loss"] = attention_loss
      return decoder_output, losses

    ret = tf.reshape(decoder_output, targets_shape)
    if losses:
      return ret, losses
    else:
      return ret

  def _normalize_body_output(self, body_out):
    """Splits the value returned by `body` into `(output, losses_dict)`.

    Args:
      body_out: either a two-element tuple/list `(output, losses)` or the
        bare output.

    Returns:
      `(output, losses)` where `losses` is a dict. A non-dict loss is
      reduced to its mean under the key "extra"; a bare output gets
      `{"extra": 0.0}`.
    """
    # Check the container type before calling len(): a bare output is not
    # guaranteed to support it.
    if isinstance(body_out, (tuple, list)) and len(body_out) == 2:
      output, losses = body_out
      if not isinstance(losses, dict):
        losses = {"extra": tf.reduce_mean(losses)}
    else:
      output = body_out
      losses = {"extra": 0.0}

    return output, losses

  def _beam_decode(self,
                   features,
                   decode_length,
                   beam_size,
                   top_beams,
                   alpha,
                   use_tpu=False):
    """Beam search decoding.

    Skips `transformer.Transformer`'s own `_beam_decode` and calls
    `_beam_decode_slow` as resolved after `transformer.Transformer` in the
    method resolution order.

    Args:
      features: an map of string to `Tensor`
      decode_length: an integer.  How many additional timesteps to decode.
      beam_size: number of beams.
      top_beams: an integer. How many of the beams to return.
      alpha: Float that controls the length penalty. larger the alpha, stronger
        the preference for longer translations.
      use_tpu: A bool, whether to do beam decode on TPU.

    Returns:
      The result of `_beam_decode_slow`, returned unchanged.
    """
    return super(transformer.Transformer,
                 self)._beam_decode_slow(features, decode_length, beam_size,
                                         top_beams, alpha, use_tpu)

  def _greedy_infer(self, features, decode_length, use_tpu=False):
    """Greedy decoding that skips `transformer.Transformer`'s override.

    Delegates to `_greedy_infer` as resolved after `transformer.Transformer`
    in the method resolution order.

    Args:
      features: an map of string to `Tensor`
      decode_length: an integer.  How many additional timesteps to decode.
      use_tpu: A bool. Accepted for interface compatibility; it is not
        forwarded to the delegated call.

    Returns:
      The result of the delegated `_greedy_infer`, returned unchanged.
    """
    return super(transformer.Transformer,
                 self)._greedy_infer(features, decode_length)
def compute_max_pool_embedding(input_embeddings, input_lengths):
  """Computes max pool embedding over the non-padded positions.

  Args:
    input_embeddings: [bs, max_seq_len, emb_dim]
    input_lengths: [bs, 1]

  Returns:
    max_pool_embedding: [bs, emb_dim]
  """
  max_seq_len = tf.shape(input_embeddings)[1]
  # 1.0 at padded positions, 0.0 at real ones: [bs, 1, max_seq_len]
  mask = 1.0 - tf.sequence_mask(input_lengths, max_seq_len, dtype=tf.float32)
  # Push padded positions far below any real activation so they can never be
  # selected by the max. The previous scale of -1e-6 shifted them by a
  # negligible amount, so padding still took part in the pooling.
  # [bs, max_seq_len]
  mask = tf.squeeze(mask * (-1e9), 1)
  # [bs, max_seq_len, 1], broadcast over emb_dim
  mask = tf.expand_dims(mask, 2)
  # [bs, emb_dim]
  max_pool_embedding = tf.reduce_max(input_embeddings + mask, 1)
  # [bs, dim]
  return max_pool_embedding
def compute_summary_embedding(input_embeddings, input_lengths, hparams):
  """Convert list of embedding to single embedding.

  Dispatches on `hparams.pool_technique`: "average", "max_pool" or "last".

  Args:
    input_embeddings: [bs, max_seq_len, emb_dim]
    input_lengths: [bs, 1]
    hparams: model hparams

  Returns:
    embedding: [bs, emb_dim]

  Raises:
    ValueError: if `hparams.pool_technique` is not one of the supported
      values.
  """
  if hparams.pool_technique == "average":
    return compute_average_embedding(input_embeddings, input_lengths)
  elif hparams.pool_technique == "max_pool":
    return compute_max_pool_embedding(input_embeddings, input_lengths)
  elif hparams.pool_technique == "last":
    return compute_last_embedding(input_embeddings, input_lengths, hparams)
  # Falling through used to return None, which only failed later inside an
  # unrelated TensorFlow op; fail here with the offending value instead.
  raise ValueError(
      "Unknown pool_technique: %r (expected 'average', 'max_pool' or 'last')"
      % (hparams.pool_technique,))
0.0) # KB training max-margin loss + hparams.add_hparam( + "num_negative_samples", + 1) # Sampling number of different adversarial training examples + hparams.add_hparam("kb_train_weight", 0.0) + # KB_training loss weight which combines Language model and KB selection loss + return hparams + + +@registry.register_hparams +def neural_assistant_tiny(): + """HParams for tiny neural_assistant model.""" + hparams = transformer.transformer_tiny_tpu() + hparams.add_hparam("pos_weight", 1.0) # weight for positive triples + hparams.add_hparam("similarity_fuction", + "bilinear") # dot_product or bilinear + hparams.add_hparam("pool_technique", "average") # avg or max pool or last + hparams.add_hparam("last_k", 1) # number of last indices for averaging + hparams.add_hparam("max_triple_length", 30) # max length of every triple + hparams.add_hparam("train_triple_num", + 5000) # max number of triples during training + hparams.add_hparam("attend_kb", True) # if False, it's a transformer model + hparams.add_hparam("kb_loss_weight", 0.0) # weight for distant supervision + hparams.add_hparam("test_triple_num", + 28483) # max triples of KB + hparams.add_hparam("margin", 1.0) # KB training max-margin loss + hparams.add_hparam( + "num_negative_samples", + 1) # Sampling number of different adversarial training examples + hparams.add_hparam("kb_train_weight", 0.0) + # KB_training loss weight which combines Language model and KB selection loss + return hparams + + +@registry.register_hparams +def neural_assistant_tiny_ds(): + """HParams for tiny neural_assistant model with distant supervision loss.""" + hparams = neural_assistant_tiny() + hparams.kb_loss_weight = 0.2 + return hparams diff --git a/trax/models/neural_gpu.py b/trax/models/neural_gpu.py index 76abf7673..953855172 100644 --- a/trax/models/neural_gpu.py +++ b/trax/models/neural_gpu.py @@ -1,5 +1,5 @@ # coding=utf-8 -# Copyright 2022 The Trax Authors. +# Copyright 2023 The Tensor2Tensor Authors. 
# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,65 +13,109 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Implementation of the improved Neural GPU (NGPU).""" - -from trax import layers as tl -from trax.fastmath import numpy as jnp - - -# TODO(ddohan): Combinator to add saturation costs to loss -def SaturationCost(x, limit=0.9): - return jnp.minimum(0, jnp.abs(x) - limit) - - -def DiagonalGate(): - """Split channels in 3 parts. Shifts 1st and 3rd sections to left/right.""" - - def f(x): # pylint: disable=invalid-name - # x : [batch, 1, length, depth] - x = jnp.pad(x, [(0, 0), (0, 0), (1, 1), (0, 0)], - mode='constant', constant_values=0.0) - depth = x.shape[-1] // 3 - assert 3 * depth == x.shape[-1], ('Depth must be divisible by 3', depth, - x.shape) - xs = [ - x[:, :, :-2, :depth], x[:, :, 1:-1, depth:2 * depth], - x[:, :, 2:, 2 * depth:3 * depth] - ] - return jnp.concatenate(xs, axis=3) - return tl.Fn('DiagonalGate', f) - - -def ConvDiagonalGRU(units, kernel_size=(3, 3)): - """Build convolutional GRU with diagonal gating as in ImprovedNGPU.""" - - def BuildConv(): - return tl.Conv(filters=units, kernel_size=kernel_size, padding='SAME') - - return tl.GeneralGRUCell( - candidate_transform=BuildConv, - memory_transform_fn=DiagonalGate, - gate_nonlinearity=tl.HardSigmoid, - candidate_nonlinearity=tl.HardTanh) - - -def NeuralGPU(d_feature=96, steps=16, vocab_size=2, mode='train'): - """Implementation of Neural GPU: https://arxiv.org/abs/1702.08727. - - Args: - d_feature: Number of memory channels (dimensionality of feature embedding). - steps: Number of times depthwise recurrence steps. - vocab_size: Vocabulary size. - mode: Whether we are training or evaluating or doing inference. - - Returns: - A NeuralGPU Stax model. 
- """ - del mode - - core = ConvDiagonalGRU(units=d_feature) - return tl.Serial( - tl.Embedding(vocab_size=vocab_size, d_feature=d_feature), - [core] * steps, - tl.Dense(vocab_size), - ) +"""The Neural GPU model and its variants.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from six.moves import range # pylint: disable=redefined-builtin + +from tensor2tensor.layers import common_hparams +from tensor2tensor.layers import common_layers +from tensor2tensor.utils import registry +from tensor2tensor.utils import t2t_model + +import tensorflow.compat.v1 as tf + + +def neural_gpu_body(inputs, hparams, name=None): + """The core Neural GPU.""" + with tf.variable_scope(name, "neural_gpu"): + + def step(state, inp): # pylint: disable=missing-docstring + x = tf.nn.dropout(state, 1.0 - hparams.dropout) + for layer in range(hparams.num_hidden_layers): + x = common_layers.conv_gru( + x, (hparams.kernel_height, hparams.kernel_width), + hparams.hidden_size, + name="cgru_%d" % layer) + # Padding input is zeroed-out in the modality, we check this by summing. + padding_inp = tf.less(tf.reduce_sum(tf.abs(inp), axis=[1, 2]), 0.00001) + new_state = tf.where(padding_inp, state, x) # No-op where inp is padding. 
+ return new_state + + return tf.foldl( + step, + tf.transpose(inputs, [1, 0, 2, 3]), + initializer=inputs, + parallel_iterations=1, + swap_memory=True) + + +@registry.register_model +class NeuralGPU(t2t_model.T2TModel): + + def body(self, features): + return neural_gpu_body(features["inputs"], self._hparams) + + +def diagonal_neural_gpu(inputs, hparams, name=None): + """Improved Neural GPU as in https://arxiv.org/abs/1702.08727.""" + with tf.variable_scope(name, "diagonal_neural_gpu"): + + def step(state_tup, inp): + """Single step of the improved Neural GPU.""" + state, _ = state_tup + x = state + for layer in range(hparams.num_hidden_layers): + x, new_loss = common_layers.diagonal_conv_gru( + x, (hparams.kernel_height, hparams.kernel_width), + hparams.hidden_size, + dropout=hparams.dropout, + name="dcgru_%d" % layer) + # Padding input is zeroed-out in the modality, we check this by summing. + padding_inp = tf.less(tf.reduce_sum(tf.abs(inp), axis=[1, 2]), 0.00001) + new_state = tf.where(padding_inp, state, x) # No-op where inp is padding. 
+ return new_state, new_loss + + final_state, losses = tf.scan( + step, + tf.transpose(inputs, [1, 0, 2, 3]), + initializer=(inputs, tf.constant(0.0)), + parallel_iterations=1, + swap_memory=True) + return final_state[0, :, :, :, :], 2.0 * tf.reduce_mean(losses) + + +@registry.register_model +class DiagonalNeuralGPU(t2t_model.T2TModel): + + def body(self, features): + return diagonal_neural_gpu(features["inputs"], self._hparams) + + +@registry.register_hparams +def neural_gpu(): + """Set of hyperparameters.""" + hparams = common_hparams.basic_params1() + hparams.daisy_chain_variables = False + hparams.batch_size = 1024 + hparams.num_hidden_layers = 1 + hparams.hidden_size = 256 + hparams.dropout = 0.1 + hparams.label_smoothing = 0.0 + hparams.clip_grad_norm = 10.0 + hparams.num_hidden_layers = 1 + hparams.kernel_height = 3 + hparams.kernel_width = 1 + hparams.learning_rate_decay_scheme = "exp" + hparams.learning_rate = 0.02 + hparams.learning_rate_warmup_steps = 3000 + hparams.initializer_gain = 1.0 + hparams.weight_decay = 0.0 + hparams.num_sampled_classes = 0 + hparams.sampling_method = "argmax" + hparams.optimizer_adam_epsilon = 1e-6 + hparams.optimizer_adam_beta1 = 0.85 + hparams.optimizer_adam_beta2 = 0.997 + return hparams diff --git a/trax/models/neural_gpu_test.py b/trax/models/neural_gpu_test.py index 0eaa77dbf..57a4a1f36 100644 --- a/trax/models/neural_gpu_test.py +++ b/trax/models/neural_gpu_test.py @@ -1,5 +1,5 @@ # coding=utf-8 -# Copyright 2022 The Trax Authors. +# Copyright 2023 The Tensor2Tensor Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,24 +13,50 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-"""Tests for trax.models.neural_gpu.""" +"""Tests for Neural GPU.""" -from absl.testing import absltest +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function import numpy as np -from trax import shapes -from trax.models import neural_gpu - - -class NeuralGPUTest(absltest.TestCase): - - def test_ngpu(self): - model = neural_gpu.NeuralGPU(d_feature=30, steps=4, vocab_size=22) - x = np.ones((3, 5, 7)).astype(np.int32) - _, _ = model.init(shapes.signature(x)) - y = model(x) - self.assertEqual(y.shape, (3, 5, 7, 22)) - - -if __name__ == '__main__': - absltest.main() +from tensor2tensor.data_generators import problem_hparams +from tensor2tensor.layers import common_hparams +from tensor2tensor.models import neural_gpu + +import tensorflow.compat.v1 as tf +from tensorflow.compat.v1 import estimator as tf_estimator + + +class NeuralGPUTest(tf.test.TestCase): + + def testNeuralGPU(self): + hparams = common_hparams.basic_params1() + batch_size = 3 + input_length = 5 + target_length = input_length + input_vocab_size = 9 + target_vocab_size = 11 + p_hparams = problem_hparams.test_problem_hparams(input_vocab_size, + target_vocab_size, + hparams) + inputs = np.random.randint( + input_vocab_size, size=(batch_size, input_length, 1, 1)) + targets = np.random.randint( + target_vocab_size, size=(batch_size, target_length, 1, 1)) + with self.test_session() as session: + features = { + "inputs": tf.constant(inputs, dtype=tf.int32), + "targets": tf.constant(targets, dtype=tf.int32) + } + model = neural_gpu.NeuralGPU(hparams, tf_estimator.ModeKeys.TRAIN, + p_hparams) + logits, _ = model(features) + session.run(tf.global_variables_initializer()) + res = session.run(logits) + self.assertEqual(res.shape, (batch_size, target_length, 1, 1, + target_vocab_size)) + + +if __name__ == "__main__": + tf.test.main() diff --git a/trax/models/resnet.py b/trax/models/resnet.py index d7ea0f2f4..5eeb4792f 100644 --- a/trax/models/resnet.py +++ 
b/trax/models/resnet.py @@ -1,5 +1,5 @@ # coding=utf-8 -# Copyright 2022 The Trax Authors. +# Copyright 2023 The Tensor2Tensor Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,160 +13,849 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""ResNet.""" - -from trax import layers as tl - - -def ConvBlock(kernel_size, filters, strides, norm, non_linearity, - mode='train'): - """ResNet convolutional striding block.""" - ks = kernel_size - filters1, filters2, filters3 = filters - main = [ - tl.Conv(filters1, (1, 1), strides), - norm(mode=mode), - non_linearity(), - tl.Conv(filters2, (ks, ks), padding='SAME'), - norm(mode=mode), - non_linearity(), - tl.Conv(filters3, (1, 1)), - norm(mode=mode), - ] - shortcut = [ - tl.Conv(filters3, (1, 1), strides), - norm(mode=mode), - ] - return [ - tl.Residual(main, shortcut=shortcut), - non_linearity() - ] - - -def IdentityBlock(kernel_size, filters, norm, non_linearity, - mode='train'): - """ResNet identical size block.""" - ks = kernel_size - filters1, filters2, filters3 = filters - main = [ - tl.Conv(filters1, (1, 1)), - norm(mode=mode), - non_linearity(), - tl.Conv(filters2, (ks, ks), padding='SAME'), - norm(mode=mode), - non_linearity(), - tl.Conv(filters3, (1, 1)), - norm(mode=mode), - ] - return [ - tl.Residual(main), - non_linearity(), - ] - - -def Resnet50(d_hidden=64, n_output_classes=1001, mode='train', - norm=tl.BatchNorm, - non_linearity=tl.Relu): - """ResNet. 
+"""Resnets.""" +# Copied from cloud_tpu/models/resnet/resnet_model.py and modified + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensor2tensor.layers import common_hparams +from tensor2tensor.layers import common_layers +from tensor2tensor.utils import hparam +from tensor2tensor.utils import registry +from tensor2tensor.utils import t2t_model + +import tensorflow.compat.v1 as tf +from tensorflow.compat.v1 import estimator as tf_estimator + + +BATCH_NORM_DECAY = 0.9 +BATCH_NORM_EPSILON = 1e-5 + + +# TODO(lukaszkaiser): remove or simplify after V2 work is done. +def layers(): + return common_layers.layers() + + +def batch_norm_relu(inputs, + is_training, + relu=True, + init_zero=False, + data_format="channels_first"): + """Performs a batch normalization followed by a ReLU. + + Args: + inputs: `Tensor` of shape `[batch, channels, ...]`. + is_training: `bool` for whether the model is training. + relu: `bool` if False, omits the ReLU operation. + init_zero: `bool` if True, initializes scale parameter of batch + normalization with 0 instead of 1 (default). + data_format: `str` either "channels_first" for `[batch, channels, height, + width]` or "channels_last for `[batch, height, width, channels]`. + + Returns: + A normalized `Tensor` with the same `data_format`. + """ + if init_zero: + gamma_initializer = tf.zeros_initializer() + else: + gamma_initializer = tf.ones_initializer() + + if data_format == "channels_first": + axis = 1 + else: + axis = 3 + + inputs = layers().BatchNormalization( + axis=axis, + momentum=BATCH_NORM_DECAY, + epsilon=BATCH_NORM_EPSILON, + center=True, + scale=True, + fused=True, + gamma_initializer=gamma_initializer)(inputs, training=is_training) + + if relu: + inputs = tf.nn.relu(inputs) + return inputs + + +def fixed_padding(inputs, kernel_size, data_format="channels_first"): + """Pads the input along the spatial dimensions independently of input size. 
+ + Args: + inputs: `Tensor` of size `[batch, channels, height, width]` or + `[batch, height, width, channels]` depending on `data_format`. + kernel_size: `int` kernel size to be used for `conv2d` or `max_pool2d` + operations. Should be a positive integer. + data_format: `str` either "channels_first" for `[batch, channels, height, + width]` or "channels_last" for `[batch, height, width, channels]`. + + Returns: + A padded `Tensor` of the same `data_format` with size either intact + (if `kernel_size == 1`) or padded (if `kernel_size > 1`). + """ + pad_total = kernel_size - 1 + pad_beg = pad_total // 2 + pad_end = pad_total - pad_beg + if data_format == "channels_first": + padded_inputs = tf.pad( + inputs, [[0, 0], [0, 0], [pad_beg, pad_end], [pad_beg, pad_end]]) + else: + padded_inputs = tf.pad( + inputs, [[0, 0], [pad_beg, pad_end], [pad_beg, pad_end], [0, 0]]) + + return padded_inputs + + +def conv2d_fixed_padding(inputs, + filters, + kernel_size, + strides, + data_format="channels_first", + use_td=False, + targeting_rate=None, + keep_prob=None, + is_training=None): + """Strided 2-D convolution with explicit padding. + + The padding is consistent and is based only on `kernel_size`, not on the + dimensions of `inputs` (as opposed to using `tf.layers.conv2d` alone). + + Args: + inputs: `Tensor` of size `[batch, channels, height_in, width_in]`. + filters: `int` number of filters in the convolution. + kernel_size: `int` size of the kernel to be used in the convolution. + strides: `int` strides of the convolution. + data_format: `str` either "channels_first" for `[batch, channels, height, + width]` or "channels_last" for `[batch, height, width, channels]`. + use_td: `str` one of "weight" or "unit". Set to False or "" to disable + targeted dropout. + targeting_rate: `float` proportion of weights to target with targeted + dropout. + keep_prob: `float` keep probability for targeted dropout. + is_training: `bool` for whether the model is in training.
+ + Returns: + A `Tensor` of shape `[batch, filters, height_out, width_out]`. + + Raises: + Exception: if use_td is not valid. + """ + if strides > 1: + inputs = fixed_padding(inputs, kernel_size, data_format=data_format) + + if use_td: + inputs_shape = common_layers.shape_list(inputs) + if use_td == "weight": + if data_format == "channels_last": + size = kernel_size * kernel_size * inputs_shape[-1] + else: + size = kernel_size * kernel_size * inputs_shape[1] + targeting_count = targeting_rate * tf.to_float(size) + targeting_fn = common_layers.weight_targeting + elif use_td == "unit": + targeting_count = targeting_rate * filters + targeting_fn = common_layers.unit_targeting + else: + raise Exception("Unrecognized targeted dropout type: %s" % use_td) + + y = common_layers.td_conv( + inputs, + filters, + kernel_size, + targeting_count, + targeting_fn, + keep_prob, + is_training, + do_prune=True, + strides=strides, + padding=("SAME" if strides == 1 else "VALID"), + data_format=data_format, + use_bias=False, + kernel_initializer=tf.variance_scaling_initializer()) + else: + y = layers().Conv2D( + filters=filters, + kernel_size=kernel_size, + strides=strides, + padding=("SAME" if strides == 1 else "VALID"), + use_bias=False, + kernel_initializer=tf.variance_scaling_initializer(), + data_format=data_format)(inputs) + + return y + + +def residual_block(inputs, + filters, + is_training, + projection_shortcut, + strides, + final_block, + data_format="channels_first", + use_td=False, + targeting_rate=None, + keep_prob=None, + bottleneck_ratio=None): + """Standard building block for residual networks with BN before convolutions. Args: - d_hidden: Dimensionality of the first hidden layer (multiplied later). - n_output_classes: Number of distinct output classes. - mode: Whether we are training or evaluating or doing inference. - norm: `Layer` used for normalization, Ex: BatchNorm or - FilterResponseNorm. 
- non_linearity: `Layer` used as a non-linearity, Ex: If norm is - BatchNorm then this is a Relu, otherwise for FilterResponseNorm this - should be ThresholdedLinearUnit. + inputs: `Tensor` of size `[batch, channels, height, width]`. + filters: `int` number of filters for the first two convolutions. Note that + the third and final convolution will use 4 times as many filters. + is_training: `bool` for whether the model is in training. + projection_shortcut: `function` to use for projection shortcuts (typically + a 1x1 convolution to match the filter dimensions). If None, no + projection is used and the input is passed as unchanged through the + shortcut connection. + strides: `int` block stride. If greater than 1, this block will ultimately + downsample the input. + final_block: unused parameter to keep the same function signature as + `bottleneck_block`. + data_format: `str` either "channels_first" for `[batch, channels, height, + width]` or "channels_last for `[batch, height, width, channels]`. + use_td: `str` one of "weight" or "unit". Set to False or "" to disable + targeted dropout. + targeting_rate: `float` proportion of weights to target with targeted + dropout. + keep_prob: `float` keep probability for targeted dropout. + bottleneck_ratio: unused parameter to keep the same function signature as + `bottleneck_block`. Returns: - The list of layers comprising a ResNet model with the given parameters. + The output `Tensor` of the block. 
""" + del final_block + del bottleneck_ratio + shortcut = inputs + inputs = batch_norm_relu(inputs, is_training, data_format=data_format) + + if projection_shortcut is not None: + shortcut = projection_shortcut(inputs) + + inputs = conv2d_fixed_padding( + inputs=inputs, + filters=filters, + kernel_size=3, + strides=strides, + data_format=data_format, + use_td=use_td, + targeting_rate=targeting_rate, + keep_prob=keep_prob, + is_training=is_training) + + inputs = batch_norm_relu(inputs, is_training, data_format=data_format) + inputs = conv2d_fixed_padding( + inputs=inputs, + filters=filters, + kernel_size=3, + strides=1, + data_format=data_format, + use_td=use_td, + targeting_rate=targeting_rate, + keep_prob=keep_prob, + is_training=is_training) + + return inputs + shortcut + - # A ConvBlock configured with the given norm, non-linearity and mode. - def Resnet50ConvBlock(filter_multiplier=1, strides=(2, 2)): - filters = ( - [filter_multiplier * dim for dim in [d_hidden, d_hidden, 4 * d_hidden]]) - return ConvBlock(3, filters, strides, norm, non_linearity, mode) - - # Same as above for IdentityBlock. 
- def Resnet50IdentityBlock(filter_multiplier=1): - filters = ( - [filter_multiplier * dim for dim in [d_hidden, d_hidden, 4 * d_hidden]]) - return IdentityBlock(3, filters, norm, non_linearity, mode) - - return tl.Serial( - tl.ToFloat(), - tl.Conv(d_hidden, (7, 7), (2, 2), 'SAME'), - norm(mode=mode), - non_linearity(), - tl.MaxPool(pool_size=(3, 3), strides=(2, 2)), - Resnet50ConvBlock(strides=(1, 1)), - [Resnet50IdentityBlock() for _ in range(2)], - Resnet50ConvBlock(2), - [Resnet50IdentityBlock(2) for _ in range(3)], - Resnet50ConvBlock(4), - [Resnet50IdentityBlock(4) for _ in range(5)], - Resnet50ConvBlock(8), - [Resnet50IdentityBlock(8) for _ in range(2)], - tl.AvgPool(pool_size=(7, 7)), - tl.Flatten(), - tl.Dense(n_output_classes), - ) - - -def WideResnetBlock(channels, strides=(1, 1), bn_momentum=0.9, mode='train'): - """WideResnet convolutional block.""" - return [ - tl.BatchNorm(momentum=bn_momentum, mode=mode), - tl.Relu(), - tl.Conv(channels, (3, 3), strides, padding='SAME'), - tl.BatchNorm(momentum=bn_momentum, mode=mode), - tl.Relu(), - tl.Conv(channels, (3, 3), padding='SAME'), - ] - - -def WideResnetGroup(n, channels, strides=(1, 1), bn_momentum=0.9, mode='train'): - shortcut = [ - tl.Conv(channels, (3, 3), strides, padding='SAME'), - ] - return [ - tl.Residual(WideResnetBlock(channels, strides, bn_momentum=bn_momentum, - mode=mode), - shortcut=shortcut), - tl.Residual([WideResnetBlock(channels, (1, 1), bn_momentum=bn_momentum, - mode=mode) - for _ in range(n - 1)]), - ] - - -def WideResnet(n_blocks=3, widen_factor=1, n_output_classes=10, bn_momentum=0.9, - mode='train'): - """WideResnet from https://arxiv.org/pdf/1605.07146.pdf. +def bottleneck_block(inputs, + filters, + is_training, + projection_shortcut, + strides, + final_block, + data_format="channels_first", + use_td=False, + targeting_rate=None, + keep_prob=None, + bottleneck_ratio=4): + """Bottleneck block variant for residual networks with BN after convolutions. 
Args: - n_blocks: int, number of blocks in a group. total layers = 6n + 4. - widen_factor: int, widening factor of each group. k=1 is vanilla resnet. - n_output_classes: int, number of distinct output classes. - bn_momentum: float, momentum in BatchNorm. - mode: Whether we are training or evaluating or doing inference. + inputs: `Tensor` of size `[batch, channels, height, width]`. + filters: `int` number of filters for the first two convolutions. Note that + the third and final convolution will use 4 times as many filters. + is_training: `bool` for whether the model is in training. + projection_shortcut: `function` to use for projection shortcuts (typically + a 1x1 convolution to match the filter dimensions). If None, no + projection is used and the input is passed as unchanged through the + shortcut connection. + strides: `int` block stride. If greater than 1, this block will ultimately + downsample the input. + final_block: `bool` set to True if this is the final block in the group. + This changes the behavior of batch normalization initialization for + the final batch norm in a block. + data_format: `str` either "channels_first" for `[batch, channels, height, + width]` or "channels_last" for `[batch, height, width, channels]`. + use_td: `str` one of "weight" or "unit". Set to False or "" to disable + targeted dropout. + targeting_rate: `float` proportion of weights to target with targeted + dropout. + keep_prob: `float` keep probability for targeted dropout. + bottleneck_ratio: `int`, how much we scale up filters. + Returns: - The list of layers comprising a WideResnet model with the given parameters. + The output `Tensor` of the block.
""" - return tl.Serial( - tl.ToFloat(), - tl.Conv(16, (3, 3), padding='SAME'), - WideResnetGroup(n_blocks, 16 * widen_factor, bn_momentum=bn_momentum, - mode=mode), - WideResnetGroup(n_blocks, 32 * widen_factor, (2, 2), - bn_momentum=bn_momentum, mode=mode), - WideResnetGroup(n_blocks, 64 * widen_factor, (2, 2), - bn_momentum=bn_momentum, mode=mode), - tl.BatchNorm(momentum=bn_momentum, mode=mode), - tl.Relu(), - tl.AvgPool(pool_size=(8, 8)), - tl.Flatten(), - tl.Dense(n_output_classes), - ) + # TODO(chrisying): this block is technically the post-activation resnet-v1 + # bottleneck unit. Test with v2 (pre-activation) and replace if there is no + # difference for consistency. + shortcut = inputs + if projection_shortcut is not None: + shortcut = projection_shortcut(inputs) + + inputs = conv2d_fixed_padding( + inputs=inputs, + filters=filters, + kernel_size=1, + strides=1, + data_format=data_format, + use_td=use_td, + targeting_rate=targeting_rate, + keep_prob=keep_prob, + is_training=is_training) + + inputs = batch_norm_relu(inputs, is_training, data_format=data_format) + inputs = conv2d_fixed_padding( + inputs=inputs, + filters=filters, + kernel_size=3, + strides=strides, + data_format=data_format, + use_td=use_td, + targeting_rate=targeting_rate, + keep_prob=keep_prob, + is_training=is_training) + + inputs = batch_norm_relu(inputs, is_training, data_format=data_format) + inputs = conv2d_fixed_padding( + inputs=inputs, + filters=bottleneck_ratio * filters, + kernel_size=1, + strides=1, + data_format=data_format, + use_td=use_td, + targeting_rate=targeting_rate, + keep_prob=keep_prob, + is_training=is_training) + inputs = batch_norm_relu( + inputs, + is_training, + relu=False, + init_zero=final_block, + data_format=data_format) + + return tf.nn.relu(inputs + shortcut) + + +def block_layer(inputs, + filters, + block_fn, + blocks, + strides, + is_training, + name, + data_format="channels_first", + use_td=False, + targeting_rate=None, + keep_prob=None, + 
bottleneck_ratio=4): + """Creates one layer of blocks for the ResNet model. + + Args: + inputs: `Tensor` of size `[batch, channels, height, width]`. + filters: `int` number of filters for the first convolution of the layer. + block_fn: `function` for the block to use within the model. + blocks: `int` number of blocks contained in the layer. + strides: `int` stride to use for the first convolution of the layer. If + greater than 1, this layer will downsample the input. + is_training: `bool` for whether the model is training. + name: `str` name for the Tensor output of the block layer. + data_format: `str` either "channels_first" for `[batch, channels, height, + width]` or "channels_last" for `[batch, height, width, channels]`. + use_td: `str` one of "weight" or "unit". Set to False or "" to disable + targeted dropout. + targeting_rate: `float` proportion of weights to target with targeted + dropout. + keep_prob: `float` keep probability for targeted dropout. + bottleneck_ratio: `int`, how much we scale up filters in bottleneck block. + + Returns: + The output `Tensor` of the block layer.
+ """ + # Bottleneck blocks end with bottleneck_ratio x the number of filters + filters_out = filters + if block_fn is bottleneck_block: + filters_out = bottleneck_ratio * filters + + def projection_shortcut(inputs): + """Project identity branch.""" + inputs = conv2d_fixed_padding( + inputs=inputs, + filters=filters_out, + kernel_size=1, + strides=strides, + data_format=data_format, + use_td=use_td, + targeting_rate=targeting_rate, + keep_prob=keep_prob, + is_training=is_training) + return batch_norm_relu( + inputs, is_training, relu=False, data_format=data_format) + + # Only the first block per block_layer uses projection_shortcut and strides + inputs = block_fn( + inputs, + filters, + is_training, + projection_shortcut, + strides, + False, + data_format, + use_td=use_td, + targeting_rate=targeting_rate, + keep_prob=keep_prob, + bottleneck_ratio=bottleneck_ratio) + + for i in range(1, blocks): + inputs = block_fn( + inputs, + filters, + is_training, + None, + 1, (i + 1 == blocks), + data_format, + use_td=use_td, + targeting_rate=targeting_rate, + keep_prob=keep_prob, + bottleneck_ratio=bottleneck_ratio) + + return tf.identity(inputs, name) + + +def resnet_v2(inputs, + block_fn, + layer_blocks, + filters, + data_format="channels_first", + is_training=False, + is_cifar=False, + use_td=False, + targeting_rate=None, + keep_prob=None, + bottleneck_ratios=None): + """Resnet model. + + Args: + inputs: `Tensor` images. + block_fn: `function` for the block to use within the model. Either + `residual_block` or `bottleneck_block`. + layer_blocks: list of 3 or 4 `int`s denoting the number of blocks to include + in each of the 3 or 4 block groups. Each group consists of blocks that + take inputs of the same resolution. + filters: list of 4 or 5 `int`s denoting the number of filter to include in + block. + data_format: `str`, "channels_first" `[batch, channels, height, + width]` or "channels_last" `[batch, height, width, channels]`. 
+ is_training: bool, build in training mode or not. + is_cifar: bool, whether the data is CIFAR or not. + use_td: `str` one of "weight" or "unit". Set to False or "" to disable + targeted dropout. + targeting_rate: `float` proportion of weights to target with targeted + dropout. + keep_prob: `float` keep probability for targeted dropout. + bottleneck_ratios: list of `int`s, how much we scale up filters in + bottleneck blocks. + + Returns: + Pre-logit activations. + """ + inputs = block_layer( + inputs=inputs, + filters=filters[1], + block_fn=block_fn, + blocks=layer_blocks[0], + strides=1, + is_training=is_training, + name="block_layer1", + data_format=data_format, + use_td=use_td, + targeting_rate=targeting_rate, + keep_prob=keep_prob, + bottleneck_ratio=bottleneck_ratios[0]) + inputs = block_layer( + inputs=inputs, + filters=filters[2], + block_fn=block_fn, + blocks=layer_blocks[1], + strides=2, + is_training=is_training, + name="block_layer2", + data_format=data_format, + use_td=use_td, + targeting_rate=targeting_rate, + keep_prob=keep_prob, + bottleneck_ratio=bottleneck_ratios[1]) + inputs = block_layer( + inputs=inputs, + filters=filters[3], + block_fn=block_fn, + blocks=layer_blocks[2], + strides=2, + is_training=is_training, + name="block_layer3", + data_format=data_format, + use_td=use_td, + targeting_rate=targeting_rate, + keep_prob=keep_prob, + bottleneck_ratio=bottleneck_ratios[2]) + if not is_cifar: + inputs = block_layer( + inputs=inputs, + filters=filters[4], + block_fn=block_fn, + blocks=layer_blocks[3], + strides=2, + is_training=is_training, + name="block_layer4", + data_format=data_format, + use_td=use_td, + targeting_rate=targeting_rate, + keep_prob=keep_prob, + bottleneck_ratio=bottleneck_ratios[3]) + + return inputs + + +@registry.register_model +class Resnet(t2t_model.T2TModel): + """Residual Network.""" + + def body(self, features): + hp = self.hparams + block_fns = { + "residual": residual_block, + "bottleneck": bottleneck_block, + } + 
assert hp.block_fn in block_fns + is_training = hp.mode == tf_estimator.ModeKeys.TRAIN + if is_training: + targets = features["targets_raw"] + + inputs = features["inputs"] + + data_format = "channels_last" + if hp.use_nchw: + # Convert from channels_last (NHWC) to channels_first (NCHW). This + # provides a large performance boost on GPU. + inputs = tf.transpose(inputs, [0, 3, 1, 2]) + data_format = "channels_first" + + inputs = conv2d_fixed_padding( + inputs=inputs, + filters=hp.filter_sizes[0], + kernel_size=7, + strides=1 if hp.is_cifar else 2, + data_format=data_format) + inputs = tf.identity(inputs, "initial_conv") + inputs = batch_norm_relu(inputs, is_training, data_format=data_format) + + if not hp.is_cifar: + inputs = layers().MaxPooling2D( + pool_size=3, + strides=2, + padding="SAME", + data_format=data_format)(inputs) + inputs = tf.identity(inputs, "initial_max_pool") + + out = resnet_v2( + inputs, + block_fns[hp.block_fn], + hp.layer_sizes, + hp.filter_sizes, + data_format, + is_training=is_training, + is_cifar=hp.is_cifar, + use_td=hp.use_td, + targeting_rate=hp.targeting_rate, + keep_prob=hp.keep_prob, + bottleneck_ratios=hp.bottleneck_ratios) + + if hp.use_nchw: + out = tf.transpose(out, [0, 2, 3, 1]) + + if not hp.is_cifar: + return out + + out = tf.reduce_mean(out, [1, 2]) + num_classes = self._problem_hparams.vocab_size["targets"] + if hasattr(self._hparams, "vocab_divisor"): + num_classes += (-num_classes) % self._hparams.vocab_divisor + logits = layers().Dense(num_classes, name="logits")(out) + + losses = {"training": 0.0} + if is_training: + loss = tf.losses.sparse_softmax_cross_entropy( + labels=tf.squeeze(targets), logits=logits) + loss = tf.reduce_mean(loss) + + losses = {"training": loss} + + logits = tf.reshape(logits, [-1, 1, 1, 1, logits.shape[1]]) + + return logits, losses + + def infer(self, + features=None, + decode_length=50, + beam_size=1, + top_beams=1, + alpha=0.0, + use_tpu=False): + """Predict.""" + del decode_length, beam_size, 
top_beams, alpha, use_tpu + assert features is not None + logits, _ = self(features) # pylint: disable=not-callable + assert len(logits.get_shape()) == 5 + logits = tf.squeeze(logits, [1, 2, 3]) + log_probs = common_layers.log_prob_from_logits(logits) + predictions, scores = common_layers.argmax_with_score(log_probs) + return { + "outputs": predictions, + "scores": scores, + } + + +def resnet_base(): + """Set of hyperparameters.""" + # For imagenet on TPU: + # Set train_steps=120000 + # Set eval_steps=48 + + # Base + hparams = common_hparams.basic_params1() + + # Model-specific parameters + hparams.add_hparam("layer_sizes", [3, 4, 6, 3]) + hparams.add_hparam("bottleneck_ratios", [4, 4, 4, 4]) + hparams.add_hparam("filter_sizes", [64, 64, 128, 256, 512]) + hparams.add_hparam("block_fn", "bottleneck") + hparams.add_hparam("use_nchw", True) + hparams.add_hparam("is_cifar", False) + + # Targeted dropout + hparams.add_hparam("use_td", False) + hparams.add_hparam("targeting_rate", None) + hparams.add_hparam("keep_prob", None) + + # Variable init + hparams.initializer = "normal_unit_scaling" + hparams.initializer_gain = 2. + + # Optimization + hparams.optimizer = "Momentum" + hparams.optimizer_momentum_momentum = 0.9 + hparams.optimizer_momentum_nesterov = True + hparams.weight_decay = 1e-4 + hparams.clip_grad_norm = 0.0 + # (base_lr=0.1) * (batch_size=128*8 (on TPU, or 8 GPUs)=1024) / (256.) + hparams.learning_rate = 0.4 + hparams.learning_rate_decay_scheme = "cosine" + # For image_imagenet224, 120k training steps, which effectively makes this a + # cosine decay (i.e. no cycles). 
+ hparams.learning_rate_cosine_cycle_steps = 120000 + + hparams.batch_size = 128 + return hparams + + +@registry.register_hparams +def resnet_50(): + hp = resnet_base() + return hp + + +@registry.register_hparams +def resnet_18(): + hp = resnet_base() + hp.block_fn = "residual" + hp.layer_sizes = [2, 2, 2, 2] + return hp + + +@registry.register_hparams +def resnet_imagenet_34(): + """Set of hyperparameters.""" + hp = resnet_base() + hp.block_fn = "residual" + hp.layer_sizes = [2, 4, 8, 2] + + return hp + + +@registry.register_hparams +def resnet_imagenet_34_td_weight_05_05(): + """Set of hyperparameters.""" + hp = resnet_imagenet_34() + hp.use_td = "weight" + hp.targeting_rate = 0.5 + hp.keep_prob = 0.5 + + return hp + + +@registry.register_hparams +def resnet_imagenet_34_td_unit_05_05(): + """Set of hyperparameters.""" + hp = resnet_imagenet_34() + hp.use_td = "unit" + hp.targeting_rate = 0.5 + hp.keep_prob = 0.5 + + return hp + + +@registry.register_hparams +def resnet_imagenet_34_td_unit_no_drop(): + """Set of hyperparameters.""" + hp = resnet_imagenet_34() + hp.use_td = "unit" + hp.targeting_rate = 0.0 + hp.keep_prob = 1.0 + + return hp + + +@registry.register_hparams +def resnet_imagenet_102(): + hp = resnet_imagenet_34() + hp.layer_sizes = [3, 8, 36, 3] + return hp + + +@registry.register_hparams +def resnet_cifar_15(): + """Set of hyperparameters.""" + hp = resnet_base() + hp.block_fn = "residual" + hp.is_cifar = True + hp.layer_sizes = [2, 2, 2] + hp.filter_sizes = [16, 32, 64, 128] + + return hp + + +@registry.register_hparams +def resnet_cifar_32(): + hp = resnet_cifar_15() + hp.layer_sizes = [5, 5, 5] + return hp + + +@registry.register_hparams +def resnet_cifar_32_td_weight_05_05(): + hp = resnet_cifar_32() + hp.use_td = "weight" + hp.targeting_rate = 0.5 + hp.keep_prob = 0.5 + return hp + + +@registry.register_hparams +def resnet_cifar_32_td_unit_05_05(): + hp = resnet_cifar_32() + hp.use_td = "unit" + hp.targeting_rate = 0.5 + hp.keep_prob = 0.5 + 
return hp + + +@registry.register_hparams +def resnet_cifar_32_td_unit_no_drop(): + hp = resnet_cifar_32() + hp.use_td = "unit" + hp.targeting_rate = 0.0 + hp.keep_prob = 1.0 + return hp + + +@registry.register_hparams +def resnet_34(): + hp = resnet_base() + hp.block_fn = "residual" + return hp + + +@registry.register_hparams +def resnet_101(): + hp = resnet_base() + hp.layer_sizes = [3, 4, 23, 3] + return hp + + +@registry.register_hparams +def resnet_152(): + hp = resnet_base() + hp.layer_sizes = [3, 8, 36, 3] + return hp + + +@registry.register_hparams +def resnet_200(): + hp = resnet_base() + hp.layer_sizes = [3, 24, 36, 3] + return hp + + +# Pruning parameters +@registry.register_pruning_params +def resnet_weight(): + hp = hparam.HParams() + hp.add_hparam("strategy", "weight") + hp.add_hparam("black_list", ["logits", "bias"]) + hp.add_hparam("white_list", ["td_conv"]) + hp.add_hparam("sparsities", [0.1 * i for i in range(10)]) + return hp + + +@registry.register_pruning_params +def resnet_unit(): + hp = resnet_weight() + hp.strategy = "unit" + return hp + + +# Adversarial attack parameters +@registry.register_attack_params +def resnet_fgsm(): + aparams = hparam.HParams() + aparams.attack = "fgsm" + aparams.epsilon_name = "eps" + aparams.attack_epsilons = [i * 0.8 for i in range(20)] + aparams.add_hparam("clip_min", 0.0) + aparams.add_hparam("clip_max", 255.0) + return aparams + + +@registry.register_attack_params +def resnet_madry(): + aparams = resnet_fgsm() + aparams.attack = "madry" + aparams.add_hparam("nb_iter", 40) + aparams.add_hparam("eps_iter", 1.0) + return aparams + + +@registry.register_attack_params +def resnet_random(): + aparams = resnet_fgsm() + aparams.attack = "random" + aparams.epsilon_name = "eps" + aparams.add_hparam("num_samples", 10) + aparams.add_hparam("num_batches", 100) + return aparams diff --git a/trax/models/resnet_test.py b/trax/models/resnet_test.py index 3742d67ae..3b629fa48 100644 --- a/trax/models/resnet_test.py +++ 
b/trax/models/resnet_test.py @@ -1,5 +1,5 @@ # coding=utf-8 -# Copyright 2022 The Trax Authors. +# Copyright 2023 The Tensor2Tensor Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,33 +13,58 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Tests for Resnet models.""" +"""Resnet tests.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function -from absl.testing import absltest import numpy as np -from trax import fastmath -from trax import shapes -from trax.models import resnet +from tensor2tensor.data_generators import problem_hparams +from tensor2tensor.layers import modalities +from tensor2tensor.models import resnet + +import tensorflow.compat.v1 as tf +from tensorflow.compat.v1 import estimator as tf_estimator + +def resnet_tiny_cpu(): + hparams = resnet.resnet_base() + hparams.layer_sizes = [2, 2, 2, 2] + hparams.use_nchw = False + return hparams -class ResnetTest(absltest.TestCase): - def test_resnet(self): - model = resnet.Resnet50(d_hidden=8, n_output_classes=10) - x = np.ones((3, 256, 256, 3)).astype(np.float32) - _, _ = model.init(shapes.signature(x)) - y = model(x) - self.assertEqual(y.shape, (3, 10)) +class ResnetTest(tf.test.TestCase): - def test_wide_resnet(self): - model = resnet.WideResnet(n_blocks=1, n_output_classes=10) - x = np.ones((3, 32, 32, 3)).astype(np.float32) - _, _ = model.init(shapes.signature(x)) - y = model(x) - self.assertEqual(y.shape, (3, 10)) + def _test_resnet(self, img_size, output_size): + vocab_size = 9 + batch_size = 2 + x = np.random.randint( + 256, size=(batch_size, img_size, img_size, 3)) + y = np.random.randint( + 1, high=vocab_size, size=(batch_size, 1, 1, 1)) + hparams = resnet_tiny_cpu() + p_hparams = problem_hparams.test_problem_hparams(vocab_size, + vocab_size, + hparams) + 
p_hparams.modality["inputs"] = modalities.ModalityType.IMAGE + p_hparams.modality["targets"] = modalities.ModalityType.CLASS_LABEL + with self.test_session() as session: + features = { + "inputs": tf.constant(x, dtype=tf.int32), + "targets": tf.constant(y, dtype=tf.int32), + } + model = resnet.Resnet(hparams, tf_estimator.ModeKeys.TRAIN, p_hparams) + logits, _ = model(features) + session.run(tf.global_variables_initializer()) + res = session.run(logits) + self.assertEqual(res.shape, (batch_size,) + output_size + (1, vocab_size)) + def testResnetLarge(self): + self._test_resnet(img_size=224, output_size=(1, 1)) -if __name__ == '__main__': - absltest.main() +if __name__ == "__main__": + tf.test.main() diff --git a/trax/models/revnet.py b/trax/models/revnet.py new file mode 100644 index 000000000..e841652af --- /dev/null +++ b/trax/models/revnet.py @@ -0,0 +1,439 @@ +# coding=utf-8 +# Copyright 2023 The Tensor2Tensor Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Creates a RevNet with the bottleneck residual function. + +Implements the following equations described in the RevNet paper: +y1 = x1 + f(x2) +y2 = x2 + g(y1) + +However, in practice, the authors use the following equations to downsample +tensors inside a RevNet block: + +y1 = h(x1) + f(x2) +y2 = h(x2) + g(y1) + +In this case, h is the downsampling function used to change number of channels. 
+ +These modified equations are evident in the authors' code online: +https://github.com/renmengye/revnet-public + +For reference, the original paper can be found here: +https://arxiv.org/pdf/1707.04585.pdf +""" + +import functools +from tensor2tensor.layers import common_hparams +from tensor2tensor.utils import contrib +from tensor2tensor.utils import registry +from tensor2tensor.utils import t2t_model + +import tensorflow.compat.v1 as tf +from tensorflow.compat.v1 import estimator as tf_estimator + + +def wrapped_partial(fn, *args, **kwargs): + """Like functools.partial, but copies `fn`'s metadata onto the result.""" + partial = functools.partial(fn, *args, **kwargs) + wrapped = functools.update_wrapper(partial, fn) + return wrapped + + +conv_initializer = tf.initializers.variance_scaling( + scale=2.0, mode='fan_out') + +# Ops and axis indices for each supported dimensionality, keyed by `dim`. +CONFIG = {'2d': {'conv': wrapped_partial( + tf.layers.conv2d, kernel_initializer=conv_initializer), + 'max_pool': tf.layers.max_pooling2d, + 'avg_pool': tf.layers.average_pooling2d, + 'split_axis': 3, + 'reduction_dimensions': [1, 2] + }, + '3d': {'conv': wrapped_partial( + tf.layers.conv3d, kernel_initializer=conv_initializer), + 'max_pool': tf.layers.max_pooling3d, + 'avg_pool': tf.layers.average_pooling3d, + 'split_axis': 4, + 'reduction_dimensions': [1, 2, 3] + } + } + + +def f(x, depth1, depth2, dim='2d', first_batch_norm=True, stride=1, + training=True, bottleneck=True, padding='SAME'): + """Applies residual function for RevNet. + + Args: + x: input tensor + depth1: Number of output channels for the first and second conv layers. + depth2: Number of output channels for the third conv layer. + dim: '2d' if 2-dimensional, '3d' if 3-dimensional. + first_batch_norm: Whether to keep the first batch norm layer or not. + Typically used in the first RevNet block. + stride: Stride for the first conv filter. Note that this particular + RevNet architecture only varies the stride for the first conv + filter. The stride for the second conv filter is always set to 1. + training: True for train phase, False for eval phase. 
+ bottleneck: If true, apply bottleneck 1x1 down/up sampling. + padding: Padding for each conv layer. + + Returns: + Output tensor after applying residual function for RevNet. + """ + conv = CONFIG[dim]['conv'] + with tf.variable_scope('f', reuse=tf.AUTO_REUSE): + if first_batch_norm: + net = tf.layers.batch_normalization(x, training=training) + net = tf.nn.relu(net) + else: + net = x + + if bottleneck: + net = conv(net, depth1, 1, strides=stride, + padding=padding, activation=None) + + net = tf.layers.batch_normalization(net, training=training) + net = tf.nn.relu(net) + net = conv(net, depth1, 3, strides=1, + padding=padding, activation=None) + + net = tf.layers.batch_normalization(net, training=training) + net = tf.nn.relu(net) + net = conv(net, depth2, 1, strides=1, + padding=padding, activation=None) + else: + net = conv(net, depth2, 3, strides=stride, + padding=padding, activation=None) + # Normalize the output of the first conv. Normalizing `x` here would + # discard that conv entirely. + net = tf.layers.batch_normalization(net, training=training) + net = tf.nn.relu(net) + # Only the first conv is strided (see the `stride` arg); striding again + # would downsample twice and no longer match the shortcut branch. + net = conv(net, depth2, 3, strides=1, + padding=padding, activation=None) + + return net + + +def downsample_bottleneck(x, output_channels, dim='2d', stride=1, scope='h'): + """Downsamples 'x' by `stride` using a 1x1 convolution filter. + + Args: + x: input tensor of size [N, H, W, C] + output_channels: Desired number of output channels. + dim: '2d' if 2-dimensional, '3d' if 3-dimensional. + stride: What stride to use. Usually 1 or 2. + scope: Optional variable scope. + + Returns: + A downsampled tensor of size [N, H/2, W/2, output_channels] if stride + is 2, else returns a tensor of size [N, H, W, output_channels] if + stride is 1. + """ + conv = CONFIG[dim]['conv'] + with tf.variable_scope(scope): + x = conv(x, output_channels, 1, strides=stride, padding='SAME', + activation=None) + return x + + +def downsample_residual(x, output_channels, dim='2d', stride=1, scope='h'): + """Downsamples 'x' by `stride` using average pooling. 
+ + Args: + x: input tensor of size [N, H, W, C] + output_channels: Desired number of output channels. Channels missing from + `x` are added as zero padding, split between both ends of the channel + axis. + dim: '2d' if 2-dimensional, '3d' if 3-dimensional. + stride: What stride to use. Usually 1 or 2. + scope: Optional variable scope. + + Returns: + A downsampled tensor of size [N, H/2, W/2, output_channels] if stride + is 2, else returns a tensor of size [N, H, W, output_channels] if + stride is 1. + """ + with tf.variable_scope(scope): + if stride > 1: + avg_pool = CONFIG[dim]['avg_pool'] + # An integer pool size / stride is applied to every spatial dimension, + # so nothing here is tied to the number of spatial dimensions. + x = avg_pool(x, + pool_size=stride, + strides=stride, + padding='VALID') + + # Channels are the last axis; CONFIG records its index for each `dim`. + channel_axis = CONFIG[dim]['split_axis'] + input_channels = tf.shape(x)[channel_axis] + diff = output_channels - input_channels + # Pad only the channel axis. The trailing side receives the extra channel + # when `diff` is odd, so the result has exactly `output_channels` channels. + paddings = [[0, 0] for _ in range(channel_axis)] + paddings.append([diff // 2, diff - diff // 2]) + x = tf.pad(x, paddings) + return x + + +def init(images, num_channels, dim='2d', stride=2, + kernel_size=7, maxpool=True, training=True, scope='init'): + """Standard ResNet initial block used as first RevNet block. + + Args: + images: [N, H, W, 3] tensor of input images to the model. + num_channels: Output depth of convolutional layer in initial block. + dim: '2d' if 2-dimensional, '3d' if 3-dimensional. + stride: stride for the convolution and pool layer. + kernel_size: Size of the initial convolution filter + maxpool: If true, apply a maxpool after the convolution + training: True for train phase, False for eval phase. + scope: Optional scope for the init block. + + Returns: + Two [N, H, W, C] output activations from input images. 
+ """ + conv = CONFIG[dim]['conv'] + pool = CONFIG[dim]['max_pool'] + with tf.variable_scope(scope): + net = conv(images, num_channels, kernel_size, strides=stride, + padding='SAME', activation=None) + net = tf.layers.batch_normalization(net, training=training) + net = tf.nn.relu(net) + if maxpool: + net = pool(net, pool_size=3, strides=stride) + x1, x2 = tf.split(net, 2, axis=CONFIG[dim]['split_axis']) + return x1, x2 + + +def unit(x1, x2, block_num, depth, num_layers, dim='2d', + bottleneck=True, first_batch_norm=True, stride=1, training=True): + """Implements bottleneck RevNet unit from authors' RevNet architecture. + + Args: + x1: [N, H, W, C] tensor of network activations. + x2: [N, H, W, C] tensor of network activations. + block_num: integer ID of block + depth: First depth in bottleneck residual unit. + num_layers: Number of layers in the RevNet block. + dim: '2d' if 2-dimensional, '3d' if 3-dimensional. + bottleneck: Should a bottleneck layer be used. + first_batch_norm: Whether to keep the first batch norm layer or not. + Typically used in the first RevNet block. + stride: Stride for the residual function. + training: True for train phase, False for eval phase. + + Returns: + Two [N, H, W, C] output activation tensors. 
+ """ + scope_name = 'unit_%d' % block_num + if bottleneck: + depth1 = depth + depth2 = depth * 4 + else: + depth1 = depth2 = depth + + residual = wrapped_partial(f, + depth1=depth1, depth2=depth2, dim=dim, + training=training, bottleneck=bottleneck) + + with tf.variable_scope(scope_name): + downsample = downsample_bottleneck if bottleneck else downsample_residual + # Manual implementation of downsampling + with tf.variable_scope('downsampling'): + with tf.variable_scope('x1'): + hx1 = downsample(x1, depth2, dim=dim, stride=stride) + fx2 = residual(x2, stride=stride, first_batch_norm=first_batch_norm) + x1 = hx1 + fx2 + with tf.variable_scope('x2'): + hx2 = downsample(x2, depth2, dim=dim, stride=stride) + fx1 = residual(x1) + x2 = hx2 + fx1 + + # Full block using memory-efficient rev_block implementation. + with tf.variable_scope('full_block'): + x1, x2 = contrib.layers().rev_block( + x1, x2, residual, residual, num_layers=num_layers) + return x1, x2 + + +def final_block(x1, x2, dim='2d', training=True, scope='final_block'): + """Converts activations from last RevNet block to pre-logits. + + Args: + x1: [NxHxWxC] tensor of network activations. + x2: [NxHxWxC] tensor of network activations. + dim: '2d' if 2-dimensional, '3d' if 3-dimensional. + training: True for train phase, False for eval phase. + scope: Optional variable scope for the final block. + + Returns: + [N, hidden_dim] pre-logits tensor from activations x1 and x2. + """ + + # Final batch norm and relu + with tf.variable_scope(scope): + y = tf.concat([x1, x2], axis=CONFIG[dim]['split_axis']) + y = tf.layers.batch_normalization(y, training=training) + y = tf.nn.relu(y) + + # Global average pooling + net = tf.reduce_mean(y, CONFIG[dim]['reduction_dimensions'], + name='final_pool', keep_dims=True) + + return net + + +def revnet(inputs, hparams, reuse=None): + """Uses Tensor2Tensor memory optimized RevNet block to build a RevNet. + + Args: + inputs: [NxHxWx3] tensor of input images to the model. 
+ hparams: HParams object that contains the following parameters, + in addition to the parameters contained in the basic_params1() object in + the common_hparams module: + num_channels - A Python list where each element is the `depth` passed + to `unit` for a given block. + bottleneck - A boolean passed to `unit` for every block to select + bottleneck residual units. + num_layers_per_block - A Python list containing the number of RevNet + layers for each block. + first_batch_norm - A Python list containing booleans representing the + presence of a batch norm layer at the beginning of a given block. + strides - A Python list containing integers representing the stride of + the residual function for each block. + num_channels_init_block - An integer representing the number of channels + for the convolutional layer in the initial block. + init_stride, init_kernel_size, init_maxpool - The stride, kernel size + and max-pool switch passed to the initial block. + dim - A string (either "2d" or "3d") that decides if the RevNet is + 2-dimensional or 3-dimensional. + reuse: Whether to reuse the default variable scope. + + Returns: + [batch_size, 1, 1, hidden_dim] pre-logits tensor ([batch_size, 1, 1, 1, + hidden_dim] for "3d"), as returned by `final_block`. 
+ """ + training = hparams.mode == tf_estimator.ModeKeys.TRAIN + with tf.variable_scope('RevNet', reuse=reuse): + x1, x2 = init(inputs, + num_channels=hparams.num_channels_init_block, + dim=hparams.dim, + kernel_size=hparams.init_kernel_size, + maxpool=hparams.init_maxpool, + stride=hparams.init_stride, + training=training) + for block_num in range(len(hparams.num_layers_per_block)): + block = {'depth': hparams.num_channels[block_num], + 'num_layers': hparams.num_layers_per_block[block_num], + 'first_batch_norm': hparams.first_batch_norm[block_num], + 'stride': hparams.strides[block_num], + 'bottleneck': hparams.bottleneck} + x1, x2 = unit(x1, x2, block_num, dim=hparams.dim, training=training, + **block) + pre_logits = final_block(x1, x2, dim=hparams.dim, training=training) + return pre_logits + + +@registry.register_model +class Revnet(t2t_model.T2TModel): + + def body(self, features): + return revnet(features['inputs'], self.hparams) + + +def revnet_base(): + """Default hparams for Revnet.""" + hparams = common_hparams.basic_params1() + hparams.add_hparam('num_channels', [64, 128, 256, 416]) + hparams.add_hparam('num_layers_per_block', [1, 1, 10, 1]) + hparams.add_hparam('bottleneck', True) + hparams.add_hparam('first_batch_norm', [False, True, True, True]) + hparams.add_hparam('init_stride', 2) + hparams.add_hparam('init_kernel_size', 7) + hparams.add_hparam('init_maxpool', True) + hparams.add_hparam('strides', [1, 2, 2, 2]) + hparams.add_hparam('num_channels_init_block', 64) + hparams.add_hparam('dim', '2d') + + # Variable init + hparams.initializer = 'normal_unit_scaling' + hparams.initializer_gain = 2. + + # Optimization + hparams.optimizer = 'Momentum' + hparams.optimizer_momentum_momentum = 0.9 + hparams.optimizer_momentum_nesterov = True + hparams.weight_decay = 1e-4 + hparams.clip_grad_norm = 0.0 + # (base_lr=0.1) * (batch_size=128*8 (on TPU, or 8 GPUs)=1024) / (256.) 
+ hparams.learning_rate = 0.4 + hparams.learning_rate_decay_scheme = 'cosine' + # For image_imagenet224, 120k training steps, which effectively makes this a + # cosine decay (i.e. no cycles). + hparams.learning_rate_cosine_cycle_steps = 120000 + + # Can run with a batch size of 128 with Problem ImageImagenet224 + hparams.batch_size = 128 + return hparams + + +@registry.register_hparams +def revnet_104(): + return revnet_base() + + +def revnet_cifar_base(): + """Tiny hparams suitable for CIFAR/etc.""" + hparams = revnet_base() + hparams.num_channels_init_block = 32 + hparams.first_batch_norm = [False, True, True] + hparams.init_stride = 1 + hparams.init_kernel_size = 3 + hparams.init_maxpool = False + hparams.strides = [1, 2, 2] + hparams.batch_size = 128 + hparams.weight_decay = 1e-4 + + hparams.learning_rate = 0.1 + hparams.learning_rate_cosine_cycle_steps = 5000 + return hparams + + +@registry.register_hparams +def revnet_38_cifar(): + hparams = revnet_cifar_base() + hparams.bottleneck = False + hparams.num_channels = [16, 32, 56] + hparams.num_layers_per_block = [2, 2, 2] + hparams.initializer = 'normal_unit_scaling' + hparams.initializer_gain = 1.5 + return hparams + + +@registry.register_hparams +def revnet_110_cifar(): + """Tiny hparams suitable for CIFAR/etc.""" + hparams = revnet_cifar_base() + hparams.bottleneck = False + hparams.num_channels = [16, 32, 64] + hparams.num_layers_per_block = [8, 8, 8] + return hparams + + +@registry.register_hparams +def revnet_164_cifar(): + """Tiny hparams suitable for CIFAR/etc.""" + hparams = revnet_cifar_base() + hparams.bottleneck = True + hparams.num_channels = [16, 32, 64] + hparams.num_layers_per_block = [8, 8, 8] + return hparams + + +@registry.register_ranged_hparams +def revnet_range(rhp): + """Hyperparameters for tuning revnet.""" + rhp.set_float('learning_rate', 0.05, 0.2, scale=rhp.LOG_SCALE) + rhp.set_float('weight_decay', 1e-5, 1e-3, scale=rhp.LOG_SCALE) + rhp.set_discrete('num_channels_init_block', [64, 
128]) + return rhp diff --git a/trax/models/revnet_test.py b/trax/models/revnet_test.py new file mode 100644 index 000000000..234752514 --- /dev/null +++ b/trax/models/revnet_test.py @@ -0,0 +1,116 @@ +# coding=utf-8 +# Copyright 2023 The Tensor2Tensor Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for Revnet.""" + +from tensor2tensor.models import revnet +import tensorflow.compat.v1 as tf +from tensorflow.compat.v1 import estimator as tf_estimator + + +class RevnetTest(tf.test.TestCase): + + def testH(self): + rev_block_input = tf.random_uniform([1, 299, 299, 3]) + rev_block_output = revnet.downsample_bottleneck(rev_block_input, 256) + self.assertEqual(rev_block_output.get_shape().as_list(), [1, 299, 299, 256]) + + def testHStride(self): + rev_block_input = tf.random_uniform([2, 299, 299, 256]) + rev_block_output = revnet.downsample_bottleneck( + rev_block_input, 512, stride=2, scope='HStride') + self.assertEqual(rev_block_output.get_shape().as_list(), [2, 150, 150, 512]) + + def testInit(self): + images = tf.random_uniform([1, 299, 299, 3]) + x1, x2 = revnet.init(images, 32) + self.assertEqual(x1.get_shape().as_list(), [1, 74, 74, 16]) + self.assertEqual(x2.get_shape().as_list(), [1, 74, 74, 16]) + + def testInit3D(self): + images = tf.random_uniform([1, 299, 299, 299, 3]) + x1, x2 = revnet.init(images, 32, dim='3d', scope='init3d') + self.assertEqual(x1.get_shape().as_list(), [1, 74, 74, 74, 16]) + self.assertEqual(x2.get_shape().as_list(), [1, 
74, 74, 74, 16]) + + def testUnit1(self): + x1 = tf.random_uniform([4, 74, 74, 256]) + x2 = tf.random_uniform([4, 74, 74, 256]) + x1, x2 = revnet.unit(x1, x2, block_num=1, depth=64, + first_batch_norm=True, num_layers=1) + self.assertEqual(x1.get_shape().as_list(), [4, 74, 74, 256]) + self.assertEqual(x2.get_shape().as_list(), [4, 74, 74, 256]) + + def testUnit2(self): + x1 = tf.random_uniform([4, 74, 74, 256]) + x2 = tf.random_uniform([4, 74, 74, 256]) + x1, x2 = revnet.unit(x1, x2, block_num=2, depth=128, + num_layers=1, stride=2) + self.assertEqual(x1.get_shape().as_list(), [4, 37, 37, 512]) + self.assertEqual(x2.get_shape().as_list(), [4, 37, 37, 512]) + + def testUnit3(self): + x1 = tf.random_uniform([1, 37, 37, 512]) + x2 = tf.random_uniform([1, 37, 37, 512]) + x1, x2 = revnet.unit(x1, x2, block_num=3, depth=256, + num_layers=10, stride=2) + self.assertEqual(x1.get_shape().as_list(), [1, 19, 19, 1024]) + self.assertEqual(x2.get_shape().as_list(), [1, 19, 19, 1024]) + + def testUnit4(self): + x1 = tf.random_uniform([1, 19, 19, 1024]) + x2 = tf.random_uniform([1, 19, 19, 1024]) + x1, x2 = revnet.unit(x1, x2, block_num=4, depth=416, + num_layers=1, stride=2) + self.assertEqual(x1.get_shape().as_list(), [1, 10, 10, 1664]) + self.assertEqual(x2.get_shape().as_list(), [1, 10, 10, 1664]) + + def testUnit3D(self): + x1 = tf.random_uniform([4, 74, 74, 74, 256]) + x2 = tf.random_uniform([4, 74, 74, 74, 256]) + x1, x2 = revnet.unit(x1, x2, block_num=5, depth=128, + num_layers=1, dim='3d', stride=2) + self.assertEqual(x1.get_shape().as_list(), [4, 37, 37, 37, 512]) + self.assertEqual(x2.get_shape().as_list(), [4, 37, 37, 37, 512]) + + def testFinalBlock(self): + x1 = tf.random_uniform([5, 10, 10, 1024]) + x2 = tf.random_uniform([5, 10, 10, 1024]) + logits = revnet.final_block(x1, x2) + self.assertEqual(logits.shape, [5, 1, 1, 2048]) + + def testFinalBlock3D(self): + x1 = tf.random_uniform([5, 10, 10, 10, 1024]) + x2 = tf.random_uniform([5, 10, 10, 10, 1024]) + logits = 
revnet.final_block(x1, x2, dim='3d', scope='FinalBlock3D') + self.assertEqual(logits.shape, [5, 1, 1, 1, 2048]) + + def testEndToEnd(self): + images = tf.random_uniform([1, 299, 299, 3]) + hparams = revnet.revnet_base() + hparams.mode = tf_estimator.ModeKeys.TRAIN + logits = revnet.revnet(images, hparams) + self.assertEqual(logits.shape, [1, 1, 1, 3328]) + + def testEndToEnd3D(self): + images = tf.random_uniform([1, 299, 299, 299, 3]) + hparams = revnet.revnet_base() + hparams.dim = '3d' + hparams.mode = tf_estimator.ModeKeys.TRAIN + logits = revnet.revnet(images, hparams) + self.assertEqual(logits.shape, [1, 1, 1, 1, 3328]) + +if __name__ == '__main__': + tf.test.main() diff --git a/trax/models/shake_shake.py b/trax/models/shake_shake.py new file mode 100644 index 000000000..378f86c97 --- /dev/null +++ b/trax/models/shake_shake.py @@ -0,0 +1,224 @@ +# coding=utf-8 +# Copyright 2023 The Tensor2Tensor Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Shake-shake model for CIFAR.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from tensor2tensor.layers import common_hparams +from tensor2tensor.layers import common_layers +from tensor2tensor.utils import hparam +from tensor2tensor.utils import registry +from tensor2tensor.utils import t2t_model + +import tensorflow.compat.v1 as tf +from tensorflow.compat.v1 import estimator as tf_estimator + + +def shake_shake_skip_connection(x, output_filters, stride, is_training): + """Adds a residual connection to the filter x for the shake-shake model.""" + curr_filters = common_layers.shape_list(x)[-1] + if curr_filters == output_filters: + return x + stride_spec = [1, stride, stride, 1] + # Skip path 1. + path1 = tf.nn.avg_pool(x, [1, 1, 1, 1], stride_spec, "VALID") + path1 = tf.layers.conv2d( + path1, int(output_filters / 2), (1, 1), padding="SAME", name="path1_conv") + + # Skip path 2. + pad_arr = [[0, 0], [0, 1], [0, 1], [0, 0]] # First pad with 0's then crop. + path2 = tf.pad(x, pad_arr)[:, 1:, 1:, :] + path2 = tf.nn.avg_pool(path2, [1, 1, 1, 1], stride_spec, "VALID") + path2 = tf.layers.conv2d( + path2, int(output_filters / 2), (1, 1), padding="SAME", name="path2_conv") + + # Concat and apply BN. 
+ final_path = tf.concat(values=[path1, path2], axis=-1) + final_path = tf.layers.batch_normalization( + final_path, training=is_training, name="final_path_bn") + return final_path + + +def shake_shake_branch(x, output_filters, stride, rand_forward, rand_backward, + hparams): + """Building a 2 branching convnet.""" + is_training = hparams.mode == tf_estimator.ModeKeys.TRAIN + x = tf.nn.relu(x) + x = tf.layers.conv2d( + x, + output_filters, (3, 3), + strides=(stride, stride), + padding="SAME", + name="conv1") + x = tf.layers.batch_normalization(x, training=is_training, name="bn1") + x = tf.nn.relu(x) + x = tf.layers.conv2d(x, output_filters, (3, 3), padding="SAME", name="conv2") + x = tf.layers.batch_normalization(x, training=is_training, name="bn2") + if is_training: + x = x * rand_backward + tf.stop_gradient(x * rand_forward - + x * rand_backward) + else: + x *= 1.0 / hparams.shake_shake_num_branches + return x + + +def shake_shake_block(x, output_filters, stride, hparams): + """Builds a full shake-shake sub layer.""" + is_training = hparams.mode == tf_estimator.ModeKeys.TRAIN + batch_size = common_layers.shape_list(x)[0] + + # Generate random numbers for scaling the branches. + rand_forward = [ + tf.random_uniform( + [batch_size, 1, 1, 1], minval=0, maxval=1, dtype=tf.float32) + for _ in range(hparams.shake_shake_num_branches) + ] + rand_backward = [ + tf.random_uniform( + [batch_size, 1, 1, 1], minval=0, maxval=1, dtype=tf.float32) + for _ in range(hparams.shake_shake_num_branches) + ] + # Normalize so that all sum to 1. 
+ total_forward = tf.add_n(rand_forward) + total_backward = tf.add_n(rand_backward) + rand_forward = [samp / total_forward for samp in rand_forward] + rand_backward = [samp / total_backward for samp in rand_backward] + zipped_rand = zip(rand_forward, rand_backward) + + branches = [] + for branch, (r_forward, r_backward) in enumerate(zipped_rand): + with tf.variable_scope("branch_{}".format(branch)): + b = shake_shake_branch(x, output_filters, stride, r_forward, r_backward, + hparams) + b = tf.nn.dropout(b, 1.0 - hparams.layer_prepostprocess_dropout) + branches.append(b) + res = shake_shake_skip_connection(x, output_filters, stride, is_training) + if hparams.shake_shake_concat: + concat_values = [res] + branches + concat_output = tf.concat(values=concat_values, axis=-1) + concat_output = tf.nn.relu(concat_output) + concat_output = tf.layers.conv2d( + concat_output, output_filters, (1, 1), name="concat_1x1") + concat_output = tf.layers.batch_normalization( + concat_output, training=is_training, name="concat_bn") + return concat_output + else: + return res + tf.add_n(branches) + + +def shake_shake_layer(x, output_filters, num_blocks, stride, hparams): + """Builds many sub layers into one full layer.""" + for block_num in range(num_blocks): + curr_stride = stride if (block_num == 0) else 1 + with tf.variable_scope("layer_{}".format(block_num)): + x = shake_shake_block(x, output_filters, curr_stride, hparams) + return x + + +@registry.register_model +class ShakeShake(t2t_model.T2TModel): + """Implements the Shake-Shake architecture. + + From "Shake-Shake regularization" (https://arxiv.org/abs/1705.07485). + This is intended to match the CIFAR-10 version, and corresponds to + "Shake-Shake-Batch" in Table 1. 
+ """ + + def body(self, features): + hparams = self._hparams + is_training = hparams.mode == tf_estimator.ModeKeys.TRAIN + inputs = features["inputs"] + assert (hparams.num_hidden_layers - 2) % 6 == 0 + assert hparams.hidden_size % 16 == 0 + k = hparams.hidden_size // 16 + n = (hparams.num_hidden_layers - 2) // 6 + x = inputs + + x = tf.layers.conv2d(x, 16, (3, 3), padding="SAME", name="init_conv") + x = tf.layers.batch_normalization(x, training=is_training, name="init_bn") + with tf.variable_scope("L1"): + x = shake_shake_layer(x, 16 * k, n, 1, hparams) + with tf.variable_scope("L2"): + x = shake_shake_layer(x, 32 * k, n, 2, hparams) + with tf.variable_scope("L3"): + x = shake_shake_layer(x, 64 * k, n, 2, hparams) + x = tf.nn.relu(x) + + # Global avg on [1, 2] (we're nhwc) and dense to num_classes done by top. + return x + + +@registry.register_hparams +def shakeshake_small(): + """Parameters for CIFAR-10. Gets to about 96% accuracy@700K steps, 1 GPU.""" + hparams = common_hparams.basic_params1() + hparams.batch_size = 128 + hparams.hidden_size = 32 + hparams.layer_prepostprocess_dropout = 0.0 + hparams.dropout = 0 + hparams.label_smoothing = 0.0 + hparams.clip_grad_norm = 0.0 # No clipping for now, one can also try 2.0. + hparams.num_hidden_layers = 26 + hparams.learning_rate_decay_scheme = "cosine" + # Model should be run for 700000 steps with batch size 128 (~1800 epochs) + hparams.learning_rate_cosine_cycle_steps = 700000 + hparams.learning_rate = 0.2 + hparams.learning_rate_warmup_steps = 100 # That's basically unused. 
+ hparams.initializer = "uniform_unit_scaling" + hparams.initializer_gain = 1.0 + hparams.weight_decay = 1e-4 + hparams.optimizer = "Momentum" + hparams.optimizer_momentum_momentum = 0.9 + hparams.add_hparam("shake_shake_num_branches", 2) + hparams.add_hparam("shake_shake_concat", int(False)) + return hparams + + +@registry.register_hparams +def shake_shake_quick(): + hparams = shakeshake_small() + hparams.optimizer = "adam" + hparams.learning_rate_cosine_cycle_steps = 1000 + hparams.learning_rate = 0.5 + hparams.batch_size = 100 + return hparams + + +@registry.register_hparams +def shakeshake_big(): + hparams = shakeshake_small() + hparams.layer_prepostprocess_dropout = 0.0 + hparams.hidden_size = 96 + return hparams + + +@registry.register_hparams +def shakeshake_tpu(): + hparams = shakeshake_big() + hparams.learning_rate_cosine_cycle_steps = 180000 + hparams.learning_rate = 0.6 + return hparams + + +@registry.register_attack_params +def shake_shake_fgsm(): + aparams = hparam.HParams() + aparams.attack = "fgsm" + aparams.attack_epsilons = [(i+1) * 0.1 for i in range(12)] + aparams.add_hparam("clip_min", 0.0) + aparams.add_hparam("clip_max", 255.0) + return aparams diff --git a/trax/models/slicenet.py b/trax/models/slicenet.py new file mode 100644 index 000000000..e20786f31 --- /dev/null +++ b/trax/models/slicenet.py @@ -0,0 +1,372 @@ +# coding=utf-8 +# Copyright 2023 The Tensor2Tensor Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""SliceNet.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from six.moves import range # pylint: disable=redefined-builtin +from six.moves import zip # pylint: disable=redefined-builtin + +from tensor2tensor.layers import common_attention +from tensor2tensor.layers import common_hparams +from tensor2tensor.layers import common_layers +from tensor2tensor.layers import modalities +from tensor2tensor.utils import registry +from tensor2tensor.utils import t2t_model + +import tensorflow.compat.v1 as tf + + +# pylint: disable=unused-argument +def attention(targets_shifted, inputs_encoded, norm_fn, hparams, bias=None): + """Complete attention layer with preprocessing.""" + separabilities = [hparams.separability, hparams.separability] + if hparams.separability < 0: + separabilities = [hparams.separability - 1, hparams.separability] + targets_timed = common_layers.subseparable_conv_block( + common_layers.add_timing_signal(targets_shifted), + hparams.hidden_size, [((1, 1), (5, 1)), ((4, 1), (5, 1))], + normalizer_fn=norm_fn, + padding="LEFT", + separabilities=separabilities, + name="targets_time") + if hparams.attention_type == "transformer": + targets_timed = tf.squeeze(targets_timed, 2) + target_shape = tf.shape(targets_timed) + targets_segment = tf.zeros([target_shape[0], target_shape[1]]) + target_attention_bias = common_attention.attention_bias_lower_triangle( + target_shape[1]) + inputs_encoded = common_layers.flatten4d3d(inputs_encoded) + # TODO(jbaccash): use input bias parameter. This code seems to assume fixed + # size inputs. 
+ inputs_attention_bias = tf.zeros([ + tf.shape(inputs_encoded)[0], hparams.num_heads, + tf.shape(targets_segment)[1], + tf.shape(inputs_encoded)[1] + ]) + + qv = common_attention.multihead_attention( + targets_timed, + None, + target_attention_bias, + hparams.hidden_size, + hparams.hidden_size, + hparams.hidden_size, + hparams.num_heads, + hparams.attention_dropout, + name="self_attention") + qv = common_attention.multihead_attention( + qv, + inputs_encoded, + inputs_attention_bias, + hparams.hidden_size, + hparams.hidden_size, + hparams.hidden_size, + hparams.num_heads, + hparams.attention_dropout, + name="encdec_attention") + return tf.expand_dims(qv, 2) + else: + raise ValueError("Unsupported attention_type: %s" % hparams.attention_type) + + +def multi_conv_res(x, padding, name, layers, hparams, mask=None, source=None): + """A stack of separable convolution blocks with residual connections.""" + with tf.variable_scope(name): + padding_bias = None + if mask is not None: + padding_bias = (1.0 - mask) * -1e9 # Bias to not attend to padding. + if padding == "LEFT": # Do not mask anything when left-padding. 
+ mask = None + if (hparams.kernel_scheme in _KERNEL_SCHEMES and + hparams.dilation_scheme in _DILATION_SCHEMES): + kernels = _KERNEL_SCHEMES[hparams.kernel_scheme] + dilations = _DILATION_SCHEMES[hparams.dilation_scheme] + dilations_and_kernels = list(zip(dilations, kernels)) + dilations_and_kernels1 = dilations_and_kernels[:2] + dilations_and_kernels2 = dilations_and_kernels[2:] + else: + k = (hparams.kernel_height, hparams.kernel_width) + k2 = (hparams.large_kernel_size, 1) + dilations_and_kernels1 = [((1, 1), k), ((1, 1), k)] + dilations_and_kernels2 = [((1, 1), k2), ((4, 4), k2)] + separabilities1 = [hparams.separability, hparams.separability] + separabilities2 = [hparams.separability] * len(dilations_and_kernels2) + if hparams.separability < 0: + separabilities1 = [hparams.separability - 1, hparams.separability] + separabilities2 = [ + hparams.separability - i + for i in reversed(range(len(dilations_and_kernels2))) + ] + + def norm_fn(x, name): + with tf.variable_scope(name, default_name="norm"): + return common_layers.apply_norm( + x, hparams.norm_type, hparams.hidden_size, hparams.norm_epsilon) + + for layer in range(layers): + with tf.variable_scope("layer_%d" % layer): + y = common_layers.subseparable_conv_block( + x, + hparams.hidden_size, + dilations_and_kernels1, + normalizer_fn=norm_fn, + padding=padding, + mask=mask, + separabilities=separabilities1, + name="residual1") + x += common_layers.subseparable_conv_block( + x + y, + hparams.hidden_size, + dilations_and_kernels2, + normalizer_fn=norm_fn, + padding=padding, + mask=mask, + separabilities=separabilities2, + name="residual2") + y + if source is not None and hparams.attention_type != "none": + x += attention(x, source, norm_fn, hparams, bias=padding_bias) + if mask is not None: + x *= mask + return tf.nn.dropout(x, 1.0 - hparams.dropout) + + +def rank_loss(sentence_emb, image_emb, margin=0.2): + """Experimental rank loss, thanks to kkurach@ for the code.""" + with tf.name_scope("rank_loss"): + # 
Normalize first as this is assumed in cosine similarity later. + sentence_emb = tf.nn.l2_normalize(sentence_emb, 1) + image_emb = tf.nn.l2_normalize(image_emb, 1) + # Both sentence_emb and image_emb have size [batch, depth]. + scores = tf.matmul(image_emb, tf.transpose(sentence_emb)) # [batch, batch] + diagonal = tf.diag_part(scores) # [batch] + cost_s = tf.maximum(0.0, margin - diagonal + scores) # [batch, batch] + cost_im = tf.maximum( + 0.0, margin - tf.reshape(diagonal, [-1, 1]) + scores) # [batch, batch] + # Clear diagonals. + batch_size = tf.shape(sentence_emb)[0] + empty_diagonal_mat = tf.ones_like(cost_s) - tf.eye(batch_size) + cost_s *= empty_diagonal_mat + cost_im *= empty_diagonal_mat + return tf.reduce_mean(cost_s) + tf.reduce_mean(cost_im) + + +def similarity_cost(inputs_encoded, targets_encoded): + """Loss telling to be more similar to your own targets than to others.""" + # This is a first very simple version: handle variable-length by padding + # to same length and putting everything into batch. In need of a better way. + x, y = common_layers.pad_to_same_length(inputs_encoded, targets_encoded) + depth = tf.shape(inputs_encoded)[3] + x, y = tf.reshape(x, [-1, depth]), tf.reshape(y, [-1, depth]) + return rank_loss(x, y) + + +def slicenet_middle(inputs_encoded, targets, target_space_emb, mask, hparams): + """Middle part of slicenet, connecting encoder and decoder.""" + + def norm_fn(x, name): + with tf.variable_scope(name, default_name="norm"): + return common_layers.apply_norm(x, hparams.norm_type, hparams.hidden_size, + hparams.norm_epsilon) + + # Flatten targets and embed target_space_id. + targets_flat = tf.expand_dims(common_layers.flatten4d3d(targets), axis=2) + target_space_emb = tf.tile(target_space_emb, + [tf.shape(targets_flat)[0], 1, 1, 1]) + + # Use attention from each target to look at input and retrieve. 
+ targets_shifted = common_layers.shift_right( + targets_flat, pad_value=target_space_emb) + if hparams.attention_type == "none": + targets_with_attention = tf.zeros_like(targets_shifted) + else: + inputs_padding_bias = (1.0 - mask) * -1e9 # Bias to not attend to padding. + targets_with_attention = attention( + targets_shifted, + inputs_encoded, + norm_fn, + hparams, + bias=inputs_padding_bias) + + # Positional targets: merge attention and raw. + kernel = (hparams.kernel_height, hparams.kernel_width) + targets_merged = common_layers.subseparable_conv_block( + tf.concat([targets_with_attention, targets_shifted], axis=3), + hparams.hidden_size, [((1, 1), kernel)], + normalizer_fn=norm_fn, + padding="LEFT", + separability=4, + name="targets_merge") + + return targets_merged, 0.0 + + +def embed_target_space(target_space_id, hidden_size): + target_space_emb = common_layers.embedding( + target_space_id, 32, hidden_size, name="target_space_embedding") + return tf.reshape(target_space_emb, [1, 1, 1, -1]) + + +def embedding_to_padding(emb): + """Input embeddings -> is_padding.""" + emb_sum = tf.reduce_sum(tf.abs(emb), axis=-1, keep_dims=True) + return tf.to_float(tf.equal(emb_sum, 0.0)) + + +def slicenet_internal(inputs, targets, target_space, hparams, run_decoder=True): + """The slicenet model, main step used for training.""" + with tf.variable_scope("slicenet"): + # Project to hidden size if necessary + if inputs.get_shape().as_list()[-1] != hparams.hidden_size: + inputs = common_layers.conv_block( + inputs, + hparams.hidden_size, [((1, 1), (3, 3))], + first_relu=False, + padding="SAME", + force2d=True) + + # Flatten inputs and encode. + inputs = tf.expand_dims(common_layers.flatten4d3d(inputs), axis=2) + inputs_mask = 1.0 - embedding_to_padding(inputs) + inputs = common_layers.add_timing_signal(inputs) # Add position info. 
+ target_space_emb = embed_target_space(target_space, hparams.hidden_size) + extra_layers = int(hparams.num_hidden_layers * 1.5) + inputs_encoded = multi_conv_res( + inputs, "SAME", "encoder", extra_layers, hparams, mask=inputs_mask) + if not run_decoder: + return inputs_encoded + # Do the middle part. + decoder_start, similarity_loss = slicenet_middle( + inputs_encoded, targets, target_space_emb, inputs_mask, hparams) + # Decode. + decoder_final = multi_conv_res( + decoder_start, + "LEFT", + "decoder", + hparams.num_hidden_layers, + hparams, + mask=inputs_mask, + source=inputs_encoded) + return decoder_final, tf.reduce_mean(similarity_loss) + + +@registry.register_model +class SliceNet(t2t_model.T2TModel): + + def body(self, features): + target_modality = self._problem_hparams.modality["targets"] + # If we're just predicting a class, there is no use for a decoder. + run_decoder = target_modality != modalities.ModalityType.CLASS_LABEL + return slicenet_internal( + features["inputs"], + features["targets"], + features["target_space_id"], + self._hparams, + run_decoder=run_decoder) + + +_KERNEL_SCHEMES = { + "3.3.3.3": [(3, 1), (3, 1), (3, 1), (3, 1)], + "3.7.7.7": [(3, 1), (7, 1), (7, 1), (7, 1)], + "3.7.15.15": [(3, 1), (7, 1), (15, 1), (15, 1)], + "3.7.15.31": [(3, 1), (7, 1), (15, 1), (31, 1)], + "3.7.15.31.63": [(3, 1), (7, 1), (15, 1), (31, 1), (63, 1)], +} +_DILATION_SCHEMES = { + "1.1.1.1.1": [(1, 1), (1, 1), (1, 1), (1, 1), (1, 1)], + "1.1.1.1": [(1, 1), (1, 1), (1, 1), (1, 1)], + "1.1.1.2": [(1, 1), (1, 1), (1, 1), (2, 1)], + "1.1.2.4": [(1, 1), (1, 1), (2, 1), (4, 1)], + "1.2.4.8": [(1, 1), (2, 1), (4, 1), (8, 1)], +} + + +@registry.register_hparams("slicenet_1") +def slicenet_params1(): + """Set of hyperparameters.""" + hparams = common_hparams.basic_params1() + hparams.batch_size = 1024 + hparams.hidden_size = 768 + hparams.dropout = 0.5 + hparams.symbol_dropout = 0.2 + hparams.label_smoothing = 0.1 + hparams.clip_grad_norm = 2.0 + 
hparams.num_hidden_layers = 4 + hparams.kernel_height = 3 + hparams.kernel_width = 1 + hparams.norm_type = "layer" + hparams.learning_rate_decay_scheme = "exp" + hparams.learning_rate = 0.05 + hparams.learning_rate_warmup_steps = 3000 + hparams.initializer_gain = 1.0 + hparams.weight_decay = 3.0 + hparams.num_sampled_classes = 0 + hparams.sampling_method = "argmax" + hparams.optimizer_adam_epsilon = 1e-6 + hparams.optimizer_adam_beta1 = 0.85 + hparams.optimizer_adam_beta2 = 0.997 + hparams.add_hparam("large_kernel_size", 15) # New ones are added like this. + hparams.add_hparam("separability", -2) + # A dilation scheme, one of _DILATION_SCHEMES. + hparams.add_hparam("dilation_scheme", "1.1.1.1") + # A kernel scheme, one of _KERNEL_SCHEMES; overrides large_kernel_size. + hparams.add_hparam("kernel_scheme", "3.7.15.31") + hparams.add_hparam("audio_compression", 8) + # attention-related flags + hparams.add_hparam("attention_type", "transformer") + hparams.add_hparam("num_heads", 8) + hparams.add_hparam("attention_key_channels", 0) + hparams.add_hparam("attention_value_channels", 0) + hparams.add_hparam("sim_loss_mult", 0.0) # Try 10.0 for experiments. 
+ hparams.add_hparam("attention_dropout", 0.2) + hparams.shared_embedding_and_softmax_weights = True + return hparams + + +@registry.register_hparams("slicenet_1noam") +def slicenet_params1_noam(): + """Version with Noam's decay scheme.""" + hparams = slicenet_params1() + hparams.learning_rate_decay_scheme = "noam" + hparams.learning_rate = 1.0 + hparams.learning_rate_warmup_steps = 4000 + hparams.initializer = "uniform_unit_scaling" + hparams.optimizer_adam_epsilon = 1e-9 + hparams.optimizer_adam_beta1 = 0.9 + hparams.optimizer_adam_beta2 = 0.98 + return hparams + + +@registry.register_hparams("slicenet_1tiny") +def slicenet_params1_tiny(): + """Version for fast local runs.""" + hparams = slicenet_params1() + hparams.separability = 0 + hparams.hidden_size = 128 + hparams.num_hidden_layers = 2 + hparams.batch_size = 512 + hparams.learning_rate_warmup_steps = 200 + return hparams + + +@registry.register_ranged_hparams("slicenet1") +def slicenet_range1(ranged_hparams): + """Small range of hyperparameters.""" + rhp = ranged_hparams + rhp.set_float("clip_grad_norm", 1.0, 10.0, scale=rhp.LOG_SCALE) + rhp.set_float("learning_rate", 0.02, 1.0, scale=rhp.LOG_SCALE) + rhp.set_float("optimizer_adam_beta2", 0.995, 0.998) + rhp.set_float("weight_decay", 1.0, 5.0) diff --git a/trax/models/slicenet_test.py b/trax/models/slicenet_test.py new file mode 100644 index 000000000..944a78234 --- /dev/null +++ b/trax/models/slicenet_test.py @@ -0,0 +1,79 @@ +# coding=utf-8 +# Copyright 2023 The Tensor2Tensor Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for SliceNet.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +import numpy as np + +from tensor2tensor.data_generators import cifar # pylint: disable=unused-import +from tensor2tensor.data_generators import mscoco # pylint: disable=unused-import +from tensor2tensor.layers import modalities # pylint: disable=unused-import +from tensor2tensor.models import slicenet +from tensor2tensor.utils import registry + +import tensorflow.compat.v1 as tf +from tensorflow.compat.v1 import estimator as tf_estimator + + +class SliceNetTest(tf.test.TestCase): + + def testSliceNet(self): + x = np.random.randint(256, size=(3, 5, 5, 3)) + y = np.random.randint(10, size=(3, 5, 1, 1)) + hparams = slicenet.slicenet_params1_tiny() + hparams.add_hparam("data_dir", "") + problem = registry.problem("image_cifar10") + p_hparams = problem.get_hparams(hparams) + hparams.problem_hparams = p_hparams + with self.test_session() as session: + features = { + "inputs": tf.constant(x, dtype=tf.int32), + "targets": tf.constant(y, dtype=tf.int32), + "target_space_id": tf.constant(1, dtype=tf.int32), + } + model = slicenet.SliceNet(hparams, tf_estimator.ModeKeys.TRAIN, + p_hparams) + logits, _ = model(features) + session.run(tf.global_variables_initializer()) + res = session.run(logits) + self.assertEqual(res.shape, (3, 1, 1, 1, 10)) + + def testSliceNetImageToText(self): + x = np.random.randint(256, size=(3, 5, 5, 3)) + y = np.random.randint(10, size=(3, 5, 1, 1)) + hparams = slicenet.slicenet_params1_tiny() + hparams.add_hparam("data_dir", "") + problem = registry.problem("image_ms_coco_characters") + p_hparams = problem.get_hparams(hparams) + hparams.problem_hparams = p_hparams + with self.test_session() as session: + features = { + "inputs": tf.constant(x, dtype=tf.int32), + "targets": tf.constant(y, dtype=tf.int32), + 
"target_space_id": tf.constant(1, dtype=tf.int32), + } + model = slicenet.SliceNet(hparams, tf_estimator.ModeKeys.TRAIN, + p_hparams) + logits, _ = model(features) + session.run(tf.global_variables_initializer()) + res = session.run(logits) + self.assertEqual(res.shape, (3, 5, 1, 1, 258)) + + +if __name__ == "__main__": + tf.test.main() diff --git a/trax/models/text_cnn.py b/trax/models/text_cnn.py new file mode 100644 index 000000000..ee6434d3e --- /dev/null +++ b/trax/models/text_cnn.py @@ -0,0 +1,112 @@ +# coding=utf-8 +# Copyright 2023 The Tensor2Tensor Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""TextCNN (see Convolutional Neural Networks for Sentence Classification).""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensor2tensor.layers import common_hparams +from tensor2tensor.layers import common_layers +from tensor2tensor.utils import registry +from tensor2tensor.utils import t2t_model + +import tensorflow.compat.v1 as tf + + +@registry.register_model +class TextCNN(t2t_model.T2TModel): + """Text CNN.""" + + def body(self, features): + """TextCNN main model_fn. + + Args: + features: Map of features to the model. Should contain the following: + "inputs": Text inputs. + [batch_size, input_length, 1, hidden_dim]. + "targets": Target encoder outputs. + [batch_size, 1, 1, hidden_dim] + Returns: + Final encoder representation. 
[batch_size, 1, 1, hidden_dim] + """ + hparams = self._hparams + inputs = features["inputs"] + + xshape = common_layers.shape_list(inputs) + + vocab_size = xshape[3] + inputs = tf.reshape(inputs, [xshape[0], xshape[1], xshape[3], xshape[2]]) + + pooled_outputs = [] + for _, filter_size in enumerate(hparams.filter_sizes): + with tf.name_scope("conv-maxpool-%s" % filter_size): + filter_shape = [filter_size, vocab_size, 1, hparams.num_filters] + filter_var = tf.Variable( + tf.truncated_normal(filter_shape, stddev=0.1), name="W") + filter_bias = tf.Variable( + tf.constant(0.1, shape=[hparams.num_filters]), name="b") + conv = tf.nn.conv2d( + inputs, + filter_var, + strides=[1, 1, 1, 1], + padding="VALID", + name="conv") + conv_outputs = tf.nn.relu( + tf.nn.bias_add(conv, filter_bias), name="relu") + pooled = tf.math.reduce_max( + conv_outputs, axis=1, keepdims=True, name="max") + pooled_outputs.append(pooled) + + num_filters_total = hparams.num_filters * len(hparams.filter_sizes) + h_pool = tf.concat(pooled_outputs, 3) + h_pool_flat = tf.reshape(h_pool, [-1, num_filters_total]) + + # Add dropout + output = tf.nn.dropout(h_pool_flat, 1 - hparams.output_dropout) + output = tf.reshape(output, [-1, 1, 1, num_filters_total]) + + return output + + +@registry.register_hparams +def text_cnn_base(): + """Set of hyperparameters.""" + hparams = common_hparams.basic_params1() + hparams.batch_size = 4096 + hparams.max_length = 256 + hparams.clip_grad_norm = 0. # i.e. 
no gradient clipping + hparams.optimizer_adam_epsilon = 1e-9 + hparams.learning_rate_schedule = "legacy" + hparams.learning_rate_decay_scheme = "noam" + hparams.learning_rate = 0.1 + hparams.learning_rate_warmup_steps = 4000 + hparams.initializer_gain = 1.0 + hparams.num_hidden_layers = 6 + hparams.initializer = "uniform_unit_scaling" + hparams.weight_decay = 0.0 + hparams.optimizer_adam_beta1 = 0.9 + hparams.optimizer_adam_beta2 = 0.98 + hparams.num_sampled_classes = 0 + hparams.label_smoothing = 0.1 + hparams.shared_embedding_and_softmax_weights = True + hparams.symbol_modality_num_shards = 16 + + # Add new ones like this. + hparams.add_hparam("filter_sizes", [2, 3, 4, 5]) + hparams.add_hparam("num_filters", 128) + hparams.add_hparam("output_dropout", 0.4) + return hparams diff --git a/trax/models/transformer.py b/trax/models/transformer.py index ed9917baa..2bc8f33d1 100644 --- a/trax/models/transformer.py +++ b/trax/models/transformer.py @@ -1,5 +1,5 @@ # coding=utf-8 -# Copyright 2022 The Trax Authors. +# Copyright 2023 The Tensor2Tensor Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,604 +13,2966 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Transformer models: encoder, decoder, language model, and encoder-decoder. +"""Transformer model from "Attention Is All You Need". -The "Transformer" name and network architecture were introduced in the paper -[Attention Is All You Need](https://arxiv.org/abs/1706.03762). -""" - -from trax import layers as tl - - -# Defaults used across Transformer variants. 
-MODE = 'train' -D_MODEL = 512 -D_FF = 2048 -N_LAYERS = 6 -N_HEADS = 8 -MAX_SEQUENCE_LENGTH = 2048 -DROPOUT_RATE = .1 -DROPOUT_SHARED_AXES = None -FF_ACTIVATION_TYPE = tl.Relu - - -def TransformerEncoder(vocab_size, - n_classes=10, - d_model=D_MODEL, - d_ff=D_FF, - n_layers=N_LAYERS, - n_heads=N_HEADS, - max_len=MAX_SEQUENCE_LENGTH, - dropout=DROPOUT_RATE, - dropout_shared_axes=DROPOUT_SHARED_AXES, - mode=MODE, - ff_activation=FF_ACTIVATION_TYPE): - """Returns a Transformer encoder suitable for N-way classification. - - This model maps tokenized text to N-way (``n_classes``) activations: +The Transformer model consists of an encoder and a decoder. Both are stacks +of self-attention layers followed by feed-forward layers. This model yields +good results on a number of problems, especially in NLP and machine translation. - - input: Array representing a batch of text strings via token IDs plus - padding markers; shape is (batch_size, sequence_length), where - sequence_length <= ``max_len``. Array elements are integers in - ``range(vocab_size)``, and 0 values mark padding positions. +See "Attention Is All You Need" (https://arxiv.org/abs/1706.03762) for the full +description of the model and the results obtained with its early version. +""" - - output: Array representing a batch of raw (non-normalized) activations - over ``n_classes`` categories; shape is (batch_size, ``n_classes``). 
+from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from six.moves import range # pylint: disable=redefined-builtin + +from tensor2tensor.data_generators import librispeech +from tensor2tensor.layers import common_attention +from tensor2tensor.layers import common_hparams +from tensor2tensor.layers import common_layers +from tensor2tensor.layers import modalities +from tensor2tensor.layers import transformer_layers +from tensor2tensor.layers import transformer_memory +from tensor2tensor.utils import beam_search +from tensor2tensor.utils import expert_utils +from tensor2tensor.utils import mlperf_log +from tensor2tensor.utils import registry +from tensor2tensor.utils import t2t_model + +import tensorflow.compat.v1 as tf +from tensorflow.compat.v1 import estimator as tf_estimator + +# pylint: disable=g-direct-tensorflow-import +from tensorflow.python.ops import inplace_ops +from tensorflow.python.util import nest +# pylint: enable=g-direct-tensorflow-import + +# Alias some commonly reused layers, here and elsewhere. +transformer_prepare_encoder = transformer_layers.transformer_prepare_encoder +transformer_encoder = transformer_layers.transformer_encoder +transformer_ffn_layer = transformer_layers.transformer_ffn_layer + + +def transformer_encode(encoder_function, inputs, target_space, hparams, + attention_weights=None, features=None, losses=None, + prepare_encoder_fn=None, **kwargs): + """Encode transformer inputs. Args: - vocab_size: Input vocabulary size -- each element of the input array - should be an integer in ``range(vocab_size)``. These integers typically - represent token IDs from a vocabulary-based tokenizer. - n_classes: Last/innermost dimension of output arrays, suitable for N-way - classification. - d_model: Last/innermost dimension of activation arrays at most points in - the model, including the initial embedding output. 
- d_ff: Last/innermost dimension of special (typically wider) - :py:class:`Dense` layer in the feedforward part of each encoder block. - n_layers: Number of encoder blocks. Each block includes attention, dropout, - residual, layer-norm, feedforward (:py:class:`Dense`), and activation - layers. - n_heads: Number of attention heads. - max_len: Maximum symbol length for positional encoding. - dropout: Stochastic rate (probability) for dropping an activation value - when applying dropout within encoder blocks. The same rate is also - used for attention dropout in encoder blocks. - dropout_shared_axes: Tensor axes on which to share a dropout mask. - Sharing along batch and sequence axes (``dropout_shared_axes=(0,1)``) - is a useful way to save memory and apply consistent masks to activation - vectors at different sequence positions. - mode: If ``'train'``, each encoder block will include dropout; else, it - will pass all values through unaltered. - ff_activation: Type of activation function at the end of each encoder - block; must be an activation-type subclass of :py:class:`Layer`. + encoder_function: the encoder function + inputs: Transformer inputs [batch_size, input_length, 1, hidden_dim] which + will be flattened along the two spatial dimensions. + target_space: scalar, target space ID. + hparams: hyperparameters for model. + attention_weights: weight to store attention to. + features: optionally pass the entire features dictionary as well. This is + needed now for "packed" datasets. + losses: optional list onto which to append extra training losses + prepare_encoder_fn: optional, alternative to transformer_prepare_encoder. + **kwargs: additional arguments to pass to encoder_function Returns: - A Transformer model that maps strings (conveyed by token IDs) to - raw (non-normalized) activations over a range of output classes. + Tuple of: + encoder_output: Encoder representation. 
+ [batch_size, input_length, hidden_dim] + encoder_decoder_attention_bias: Bias and mask weights for + encoder-decoder attention. [batch_size, input_length] """ - def _Dropout(): - return tl.Dropout(rate=dropout, shared_axes=dropout_shared_axes, mode=mode) - - def _EncBlock(): - return _EncoderBlock(d_model, d_ff, n_heads, dropout, dropout_shared_axes, - mode, ff_activation) - - return tl.Serial( - tl.Branch([], tl.PaddingMask()), # Creates masks from copy of the tokens. - tl.Embedding(vocab_size, d_model), - _Dropout(), - tl.PositionalEncoding(max_len=max_len), - [_EncBlock() for _ in range(n_layers)], - tl.Select([0], n_in=2), # Drops the masks. - tl.LayerNorm(), - tl.Mean(axis=1), - tl.Dense(n_classes), - ) - - -def TransformerDecoder(vocab_size=None, - d_model=D_MODEL, - d_ff=D_FF, - n_layers=N_LAYERS, - n_heads=N_HEADS, - max_len=MAX_SEQUENCE_LENGTH, - dropout=DROPOUT_RATE, - dropout_shared_axes=DROPOUT_SHARED_AXES, - mode=MODE, - ff_activation=FF_ACTIVATION_TYPE): - """Returns a Transformer decoder. - - This model maps sequential inputs to sequential outputs: - - - input if ``vocab_size`` is specified: array representing a batch - of text strings via token IDs plus padding markers; shape is - (batch_size, sequence_length). The tensor elements are integers in - ``range(vocab_size)``, and 0 values mark padding positions. - - - input if ``vocab_size`` is ``None``: 3-D array representing a batch of - sequences of activation vectors; shape is (batch_size, sequence_length, - ``d_model``). - - - output: 3-D array with shape (batch_size, sequence_length, ``d_model``). - - The model uses causal attention and does *not* shift the input to the right. - Thus, the output for position `t` is based on inputs up to and including - position `t`. 
+ inputs = common_layers.flatten4d3d(inputs) + + if not prepare_encoder_fn: + prepare_encoder_fn = transformer_prepare_encoder + encoder_input, self_attention_bias, encoder_decoder_attention_bias = ( + prepare_encoder_fn( + inputs, target_space, hparams, features=features)) + + mlperf_log.transformer_print( + key=mlperf_log.MODEL_HP_LAYER_POSTPROCESS_DROPOUT, + value=hparams.layer_prepostprocess_dropout, + hparams=hparams) + + encoder_input = tf.nn.dropout(encoder_input, + 1.0 - hparams.layer_prepostprocess_dropout) + + attn_bias_for_padding = None + # Otherwise the encoder will just use encoder_self_attention_bias. + if hparams.unidirectional_encoder: + attn_bias_for_padding = encoder_decoder_attention_bias + + encoder_output = encoder_function( + encoder_input, + self_attention_bias, + hparams, + nonpadding=features_to_nonpadding(features, "inputs"), + save_weights_to=attention_weights, + make_image_summary=not common_layers.is_xla_compiled(), + losses=losses, + attn_bias_for_padding=attn_bias_for_padding, + **kwargs) + + return encoder_output, encoder_decoder_attention_bias + + +def transformer_decode(decoder_function, + decoder_input, + encoder_output, + encoder_decoder_attention_bias, + decoder_self_attention_bias, + hparams, + attention_weights=None, + cache=None, + decode_loop_step=None, + nonpadding=None, + losses=None, + **kwargs): + """Decode Transformer outputs from encoder representation. Args: - vocab_size: If specified, gives the input vocabulary size -- each element - of the input tensor should be an integer in ``range(vocab_size)``. - If ``None``, indicates that the model expects as input sequences of - floating point vectors, each with ``d_model`` components. - d_model: Last/innermost dimension of activation arrays at most points in - the model, including the initial embedding output. - d_ff: Last/innermost dimension of special (typically wider) - :py:class:`Dense` layer in the feedforward part of each encoder block. 
- n_layers: Number of decoder blocks. Each block includes attention, dropout, - residual, layer-norm, feedforward (:py:class:`Dense`), and activation - layers. - n_heads: Number of attention heads. - max_len: Maximum symbol length for positional encoding. - dropout: Stochastic rate (probability) for dropping an activation value - when applying dropout within decoder blocks. The same rate is also - used for attention dropout in decoder blocks. - dropout_shared_axes: Tensor axes on which to share a dropout mask. - Sharing along batch and sequence axes (``dropout_shared_axes=(0,1)``) - is a useful way to save memory and apply consistent masks to activation - vectors at different sequence positions. - mode: If ``'train'``, each encoder block will include dropout; else, it - will pass all values through unaltered. - ff_activation: Type of activation function at the end of each encoder - block; must be an activation-type subclass of :py:class:`Layer`. + decoder_function: the decoder function + decoder_input: inputs to bottom of the model. [batch_size, decoder_length, + hidden_dim] + encoder_output: Encoder representation. [batch_size, input_length, + hidden_dim] + encoder_decoder_attention_bias: Bias and mask weights for encoder-decoder + attention. [batch_size, input_length] + decoder_self_attention_bias: Bias and mask weights for decoder + self-attention. [batch_size, decoder_length] + hparams: hyperparameters for model. + attention_weights: weight to store attention to. + cache: dict, containing tensors which are the results of previous + attentions, used for fast decoding. + decode_loop_step: An integer, step number of the decoding loop. Only used + for inference on TPU. 
+ nonpadding: optional Tensor with shape [batch_size, decoder_length] + losses: optional list onto which to append extra training losses + **kwargs: additional arguments to pass to decoder_function Returns: - If ``vocab_size`` is defined: a Transformer model that maps strings - (conveyed by token IDs) to sequences of activation vectors. - - If ``vocab_size`` is ``None``: a Transformer model that maps sequences of - activation vectors to sequences of activation vectors. + Final decoder representation. [batch_size, decoder_length, hidden_dim] """ - def _EmbeddingOrDense(): - return (tl.Embedding(vocab_size, d_model) if vocab_size is not None - else tl.Dense(d_model)) - - def _Dropout(): - return tl.Dropout(rate=dropout, shared_axes=dropout_shared_axes, mode=mode) - - def _DecBlock(): - return _DecoderBlock(d_model, d_ff, n_heads, dropout, dropout_shared_axes, - mode, ff_activation) - - return tl.Serial( - _EmbeddingOrDense(), - _Dropout(), - tl.PositionalEncoding(max_len=max_len), - [_DecBlock() for _ in range(n_layers)], - tl.LayerNorm(), - ) - - -def TransformerLM(vocab_size, - d_model=D_MODEL, - d_ff=D_FF, - n_layers=N_LAYERS, - n_heads=N_HEADS, - max_len=MAX_SEQUENCE_LENGTH, - dropout=DROPOUT_RATE, - dropout_shared_axes=DROPOUT_SHARED_AXES, - mode=MODE, - ff_activation=FF_ACTIVATION_TYPE): - """Returns a Transformer language model. - - This model performs autoregressive language modeling: - - - input: Array representing a batch of text strings via token IDs - plus padding markers; shape is (batch_size, sequence_length). Array - elements are integers in ``range(vocab_size)``, and 0 values mark padding - positions. - - - output: 3-D array of raw activations with last/innermost dimension of - ``vocab_size``, suitable for decoding into a batch of token strings; - shape is (batch_size, sequence_length, ``vocab_size``). - - This model uses only the decoder part of the overall Transformer. 
+ mlperf_log.transformer_print( + key=mlperf_log.MODEL_HP_LAYER_POSTPROCESS_DROPOUT, + value=hparams.layer_prepostprocess_dropout, + hparams=hparams) + decoder_input = tf.nn.dropout(decoder_input, + 1.0 - hparams.layer_prepostprocess_dropout) + + decoder_output = decoder_function( + decoder_input, + encoder_output, + decoder_self_attention_bias, + encoder_decoder_attention_bias, + hparams, + cache=cache, + decode_loop_step=decode_loop_step, + nonpadding=nonpadding, + save_weights_to=attention_weights, + losses=losses, + **kwargs) + + if (common_layers.is_xla_compiled() and + hparams.mode == tf_estimator.ModeKeys.TRAIN): + # TPU does not react kindly to extra dimensions. + # TODO(noam): remove this once TPU is more forgiving of extra dims. + return decoder_output + else: + # Expand since t2t expects 4d tensors. + return tf.expand_dims(decoder_output, axis=2) + + +@registry.register_model +class Transformer(t2t_model.T2TModel): + """Attention net. See file docstring.""" + + def __init__(self, *args, **kwargs): + super(Transformer, self).__init__(*args, **kwargs) + self.attention_weights = {} # For visualizing attention heads. 
+ self.recurrent_memory_by_layer = None # Override to enable recurrent memory + self._encoder_function = transformer_encoder + self._decoder_function = transformer_decoder + self._init_cache_fn = _init_transformer_cache + self._prepare_encoder_fn = transformer_prepare_encoder + self._prepare_decoder_fn = transformer_prepare_decoder + + def encode(self, inputs, target_space, hparams, features=None, losses=None): + """Encode transformer inputs, see transformer_encode.""" + return transformer_encode( + self._encoder_function, inputs, target_space, hparams, + attention_weights=self.attention_weights, + features=features, losses=losses, + prepare_encoder_fn=self._prepare_encoder_fn) + + def decode(self, + decoder_input, + encoder_output, + encoder_decoder_attention_bias, + decoder_self_attention_bias, + hparams, + cache=None, + decode_loop_step=None, + nonpadding=None, + losses=None, + **kwargs): + """Decode Transformer outputs, see transformer_decode.""" + return transformer_decode( + self._decoder_function, decoder_input, encoder_output, + encoder_decoder_attention_bias, decoder_self_attention_bias, + hparams, attention_weights=self.attention_weights, cache=cache, + decode_loop_step=decode_loop_step, nonpadding=nonpadding, losses=losses, + **kwargs) + + def body(self, features): + """Transformer main model_fn. + + Args: + features: Map of features to the model. Should contain the following: + "inputs": Transformer inputs. [batch_size, input_length, 1, + hidden_dim]. + "targets": Target decoder outputs. [batch_size, decoder_length, 1, + hidden_dim] + "target_space_id": A scalar int from data_generators.problem.SpaceID. + + Returns: + Final decoder representation. 
[batch_size, decoder_length, hidden_dim] + """ + hparams = self._hparams + + losses = [] + + if self.has_input: + inputs = self._prepare_inputs_for_body(features) + target_space = features["target_space_id"] + encoder_output, encoder_decoder_attention_bias = self.encode( + inputs, target_space, hparams, features=features, losses=losses) + else: + encoder_output, encoder_decoder_attention_bias = (None, None) + + targets = features["targets"] + targets_shape = common_layers.shape_list(targets) + targets = common_layers.flatten4d3d(targets) + decoder_input, decoder_self_attention_bias = self._prepare_decoder_fn( + targets, hparams, features=features) + + # Not all subclasses of Transformer support keyword arguments related to + # recurrent memory, so only pass these arguments if memory is enabled. + decode_kwargs = {} + if self.recurrent_memory_by_layer is not None: + # TODO(kitaev): The chunk_number feature currently has the same shape as + # "targets", but this is only for the purposes of sharing sharding code. + # In fact every token within an example must have the same chunk number. 
+ chunk_number_each_token = tf.squeeze(features["chunk_number"], (-1, -2)) + chunk_number_each_example = chunk_number_each_token[:, 0] + # Uncomment the code below to verify that tokens within a batch share the + # same chunk number: + # with tf.control_dependencies([ + # tf.assert_equal(chunk_number_each_token, + # chunk_number_each_example[:, None]) + # ]): + # chunk_number_each_example = tf.identity(chunk_number_each_example) + decode_kwargs = dict( + recurrent_memory_by_layer=self.recurrent_memory_by_layer, + chunk_number=chunk_number_each_example, + ) + decoder_output = self.decode( + decoder_input, + encoder_output, + encoder_decoder_attention_bias, + decoder_self_attention_bias, + hparams, + nonpadding=features_to_nonpadding(features, "targets"), + losses=losses, + **decode_kwargs + ) + expected_attentions = features.get("expected_attentions") + if expected_attentions is not None: + attention_loss = common_attention.encoder_decoder_attention_loss( + expected_attentions, self.attention_weights, + hparams.expected_attention_loss_type, + hparams.expected_attention_loss_multiplier) + return decoder_output, {"attention_loss": attention_loss} + + ret = tf.reshape(decoder_output, targets_shape) + if losses: + return ret, {"extra_loss": tf.add_n(losses)} + else: + return ret + + def _prepare_inputs_for_body(self, features): + """Prepare inputs for body. + + Args: + features: Map of string to model features. Should contain + "inputs": Transformer inputs. [batch_size, input_length, 1, + hidden_dim]. + + Returns: + Inputs which will be passed to the model. [batch_size, input_length, 1, + hidden_dim] + """ + return features["inputs"] + + def _greedy_infer(self, features, decode_length, use_tpu=False): + """Fast version of greedy decoding. + + Args: + features: an map of string to `Tensor` + decode_length: an integer. How many additional timesteps to decode. + use_tpu: A bool. Whether to build the inference graph for TPU. 
+ + Returns: + A dict of decoding results { + "outputs": integer `Tensor` of decoded ids of shape + [batch_size, <= decode_length] if beam_size == 1 or + [batch_size, top_beams, <= decode_length] + "scores": decoding log probs from the beam search, + None if using greedy decoding (beam_size=1) + } + + Raises: + NotImplementedError: If there are multiple data shards. + """ + # For real-valued modalities use the slow decode path for now. + if (self._target_modality_is_real or + self._hparams.self_attention_type != "dot_product"): + return super(Transformer, self)._greedy_infer(features, decode_length) + with tf.variable_scope(self.name): + if use_tpu: + return self._fast_decode_tpu(features, decode_length) + return self._fast_decode(features, decode_length) + + def _beam_decode(self, + features, + decode_length, + beam_size, + top_beams, + alpha, + use_tpu=False): + """Beam search decoding. + + Args: + features: an map of string to `Tensor` + decode_length: an integer. How many additional timesteps to decode. + beam_size: number of beams. + top_beams: an integer. How many of the beams to return. + alpha: Float that controls the length penalty. larger the alpha, stronger + the preference for longer translations. + use_tpu: A bool, whether to do beam decode on TPU. + + Returns: + A dict of decoding results { + "outputs": integer `Tensor` of decoded ids of shape + [batch_size, <= decode_length] if beam_size == 1 or + [batch_size, top_beams, <= decode_length] + "scores": decoding log probs from the beam search, + None if using greedy decoding (beam_size=1) + } + """ + if (self._hparams.self_attention_type not in [ + "dot_product", "dot_product_relative" + ]): + # Caching is not guaranteed to work with attention types other than + # dot_product and dot_product_relative. 
+ return self._beam_decode_slow(features, decode_length, beam_size, + top_beams, alpha, use_tpu) + with tf.variable_scope(self.name): + if use_tpu: + return self._fast_decode_tpu(features, decode_length, beam_size, + top_beams, alpha) + return self._fast_decode(features, decode_length, beam_size, top_beams, + alpha) + + def _prepare_inputs_for_decode(self, features): + """Prepare inputs for decoding. + + Args: + features: A map of string to model features. + + Returns: + Inputs after fixing shape and applying modality. + """ + dp = self._data_parallelism + hparams = self._hparams + inputs = features["inputs"] + # TODO(llion): Clean up this reshaping logic. + inputs = tf.expand_dims(inputs, axis=1) + if len(inputs.shape) < 5: + inputs = tf.expand_dims(inputs, axis=4) + s = common_layers.shape_list(inputs) + inputs = tf.reshape(inputs, [s[0] * s[1], s[2], s[3], s[4]]) + # _shard_features called to ensure that the variable names match + inputs = self._shard_features({"inputs": inputs})["inputs"] + input_modality = self._problem_hparams.modality["inputs"] + input_vocab_size = self._problem_hparams.vocab_size["inputs"] + if input_vocab_size is not None and hasattr(hparams, "vocab_divisor"): + input_vocab_size += (-input_vocab_size) % hparams.vocab_divisor + modality_name = hparams.name.get("inputs", + modalities.get_name(input_modality))( + hparams, input_vocab_size) + with tf.variable_scope(modality_name): + bottom = hparams.bottom.get("inputs", + modalities.get_bottom(input_modality)) + inputs = dp(bottom, inputs, hparams, input_vocab_size) + return inputs + + def _fast_decode_tpu(self, + features, + decode_length, + beam_size=1, + top_beams=1, + alpha=1.0): + """Fast decoding. + + Implements both greedy and beam search decoding on TPU, uses beam search + iff beam_size > 1, otherwise beam search related arguments are ignored. + + Args: + features: A map of string to model features. + decode_length: An integer, how many additional timesteps to decode. 
+ beam_size: An integer, number of beams. + top_beams: An integer, how many of the beams to return. + alpha: A float that controls the length penalty. Larger the alpha, + stronger the preference for longer translations. + + Returns: + A dict of decoding results { + "outputs": integer `Tensor` of decoded ids of shape + [batch_size, <= decode_length] if beam_size == 1 or + [batch_size, top_beams, <= decode_length] + "scores": decoding log probs from the beam search, + None if using greedy decoding (beam_size=1) + }. + + Raises: + NotImplementedError: If there are multiple data shards. + """ + if self._num_datashards != 1: + raise NotImplementedError("Fast decoding only supports a single shard.") + if "targets_segmentation" in features: + raise NotImplementedError( + "Decoding not supported on packed datasets " + " If you want to decode from a dataset, use the non-packed version" + " of the dataset when decoding.") + dp = self._data_parallelism + hparams = self._hparams + target_modality = self._problem_hparams.modality["targets"] + target_vocab_size = self._problem_hparams.vocab_size["targets"] + if target_vocab_size is not None and hasattr(hparams, "vocab_divisor"): + target_vocab_size += (-target_vocab_size) % hparams.vocab_divisor + + if self.has_input: + inputs_shape = common_layers.shape_list(features["inputs"]) + if (target_modality == modalities.ModalityType.CLASS_LABEL or + self._problem_hparams.get("regression_targets")): + decode_length = 1 + else: + decode_length = ( + inputs_shape[1] + features.get("decode_length", decode_length)) + batch_size = inputs_shape[0] + inputs = self._prepare_inputs_for_decode(features) + with tf.variable_scope("body"): + encoder_output, encoder_decoder_attention_bias = dp( + self.encode, + inputs, + features["target_space_id"], + hparams, + features=features) + encoder_output = encoder_output[0] + encoder_decoder_attention_bias = encoder_decoder_attention_bias[0] + partial_targets = None + else: + # The problem has no inputs. 
+ encoder_output = None + encoder_decoder_attention_bias = None + + # Prepare partial targets. + # In either features["inputs"] or features["targets"]. + # We force the outputs to begin with these sequences. + partial_targets = features.get("inputs") + if partial_targets is None: + partial_targets = features["targets"] + assert partial_targets is not None + partial_targets = common_layers.expand_squeeze_to_nd(partial_targets, 2) + partial_targets = tf.to_int64(partial_targets) + partial_targets_shape = common_layers.shape_list(partial_targets) + partial_targets_length = partial_targets_shape[1] + decode_length = ( + partial_targets_length + features.get("decode_length", decode_length)) + batch_size = partial_targets_shape[0] + + if hparams.pos == "timing": + positional_encoding = common_attention.get_timing_signal_1d( + decode_length + 1, hparams.hidden_size) + elif hparams.pos == "timing_from_features": + positional_encoding = common_attention.add_timing_signals_from_features( + tf.zeros([1, decode_length + 1, hparams.hidden_size]), features, + hparams.position_features) + elif hparams.pos == "emb": + positional_encoding = common_attention.add_positional_embedding( + tf.zeros([1, decode_length + 1, hparams.hidden_size]), + hparams.max_length, "body/targets_positional_embedding", None) + else: + positional_encoding = None + + def preprocess_targets(targets, i): + """Performs preprocessing steps on the targets to prepare for the decoder. + + This includes: + - Embedding the ids. + - Flattening to 3D tensor. + - Optionally adding timing signals. + + Args: + targets: A tensor, inputs ids to the decoder. [batch_size, 1]. + i: An integer, Step number of the decoding loop. + + Returns: + A tensor, processed targets [batch_size, 1, hidden_dim]. 
+ """ + # _shard_features called to ensure that the variable names match + targets = self._shard_features({"targets": targets})["targets"] + modality_name = hparams.name.get( + "targets", + modalities.get_name(target_modality))(hparams, target_vocab_size) + with tf.variable_scope(modality_name): + bottom = hparams.bottom.get( + "targets", modalities.get_targets_bottom(target_modality)) + targets = dp(bottom, targets, hparams, target_vocab_size)[0] + targets = common_layers.flatten4d3d(targets) + + # GO embeddings are all zero, this is because transformer_prepare_decoder + # Shifts the targets along by one for the input which pads with zeros. + # If the modality already maps GO to the zero embeddings this is not + # needed. + targets = tf.cond( + tf.equal(i, 0), lambda: tf.zeros_like(targets), lambda: targets) + + if positional_encoding is not None: + positional_encoding_shape = positional_encoding.shape.as_list() + targets += tf.slice( + positional_encoding, [0, i, 0], + [positional_encoding_shape[0], 1, positional_encoding_shape[2]]) + return targets + + decoder_self_attention_bias = ( + common_attention.attention_bias_lower_triangle(decode_length)) + if hparams.proximity_bias: + decoder_self_attention_bias += common_attention.attention_bias_proximal( + decode_length) + + def symbols_to_logits_tpu_fn(ids, i, cache): + """Go from ids to logits for next symbol on TPU. + + Args: + ids: A tensor, symbol IDs. + i: An integer, step number of the decoding loop. Only used for inference + on TPU. + cache: A dict, containing tensors which are the results of previous + attentions, used for fast decoding. + + Returns: + ret: A tensor, computed logits. + cache: A dict, containing tensors which are the results of previous + attentions, used for fast decoding. 
+ """ + ids = ids[:, -1:] + targets = tf.expand_dims(tf.expand_dims(ids, axis=2), axis=3) + targets = preprocess_targets(targets, i) + + bias_shape = decoder_self_attention_bias.shape.as_list() + bias = tf.slice(decoder_self_attention_bias, [0, 0, i, 0], + [bias_shape[0], bias_shape[1], 1, bias_shape[3]]) + + with tf.variable_scope("body"): + body_outputs = dp( + self.decode, + targets, + cache.get("encoder_output"), + cache.get("encoder_decoder_attention_bias"), + bias, + hparams, + cache, + i, + nonpadding=features_to_nonpadding(features, "targets")) + modality_name = hparams.name.get( + "targets", + modalities.get_name(target_modality))(hparams, target_vocab_size) + with tf.variable_scope(modality_name): + top = hparams.top.get("targets", + modalities.get_top(target_modality)) + logits = dp(top, body_outputs, None, hparams, target_vocab_size)[0] + + ret = tf.squeeze(logits, axis=[1, 2, 3]) + if partial_targets is not None: + # If the position is within the given partial targets, we alter the + # logits to always return those values. + # A faster approach would be to process the partial targets in one + # iteration in order to fill the corresponding parts of the cache. + # This would require broader changes, though. 
+ vocab_size = tf.shape(ret)[1] + + def forced_logits(): + return tf.one_hot( + tf.tile( + tf.slice(partial_targets, [0, i], + [partial_targets.shape.as_list()[0], 1]), + [beam_size]), vocab_size, 0.0, -1e9) + + ret = tf.cond( + tf.less(i, partial_targets_length), forced_logits, lambda: ret) + return ret, cache + + eos_id = self.get_decode_end_id() or beam_search.EOS_ID + temperature = features.get("sampling_temp", + getattr(hparams, "sampling_temp", 0.0)) + top_k = features.get("sampling_keep_top_k", + getattr(hparams, "sampling_keep_top_k", -1)) + + ret = fast_decode_tpu( + encoder_output=encoder_output, + encoder_decoder_attention_bias=encoder_decoder_attention_bias, + symbols_to_logits_fn=symbols_to_logits_tpu_fn, + hparams=hparams, + decode_length=decode_length, + vocab_size=target_vocab_size, + init_cache_fn=self._init_cache_fn, + beam_size=beam_size, + top_beams=top_beams, + alpha=alpha, + batch_size=batch_size, + force_decode_length=self._decode_hparams.force_decode_length, + eos_id=eos_id, + sampling_temperature=temperature, + top_k=top_k) + if partial_targets is not None: + if beam_size <= 1 or top_beams <= 1: + ret["outputs"] = ret["outputs"][:, partial_targets_length:] + else: + ret["outputs"] = ret["outputs"][:, :, partial_targets_length:] + return ret + + def get_decode_start_id(self): + """Returns the id of the first decoder input symbol. + + The default case maps None to a vector of 0's for transformer. This method + can be overridden to return a different id by a model wanting to use a + different decoder start symbol. The id returned by this method is used to + index the embedding matrix, and retrieve the vector that will be used as the + first input to the decoder + """ + return None + + def get_decode_end_id(self): + """Returns the id of the output symbol that terminates decoding. + + This method can be overridden by a different model. The id returned by this + method is used to check if the generation is complete during decoding. 
+ """ + return None + + def _fast_decode(self, + features, + decode_length, + beam_size=1, + top_beams=1, + alpha=1.0, + preprocess_targets_method=None): + """Fast decoding. + + Implements both greedy and beam search decoding, uses beam search iff + beam_size > 1, otherwise beam search related arguments are ignored. + + Args: + features: a map of string to model features. + decode_length: an integer. How many additional timesteps to decode. + beam_size: number of beams. + top_beams: an integer. How many of the beams to return. + alpha: Float that controls the length penalty. larger the alpha, stronger + the preference for longer translations. + preprocess_targets_method: method used to preprocess targets. If None, + uses method "preprocess_targets" defined inside this method. + + Returns: + A dict of decoding results { + "outputs": integer `Tensor` of decoded ids of shape + [batch_size, <= decode_length] if beam_size == 1 or + [batch_size, top_beams, <= decode_length] + "scores": decoding log probs from the beam search, + None if using greedy decoding (beam_size=1) + } + + Raises: + NotImplementedError: If there are multiple data shards. 
+ """ + if self._num_datashards != 1: + raise NotImplementedError("Fast decoding only supports a single shard.") + dp = self._data_parallelism + hparams = self._hparams + target_modality = self._problem_hparams.modality["targets"] + target_vocab_size = self._problem_hparams.vocab_size["targets"] + if target_vocab_size is not None and hasattr(hparams, "vocab_divisor"): + target_vocab_size += (-target_vocab_size) % hparams.vocab_divisor + if "targets_segmentation" in features: + raise NotImplementedError( + "Decoding not supported on packed datasets " + " If you want to decode from a dataset, use the non-packed version" + " of the dataset when decoding.") + if self.has_input: + inputs_shape = common_layers.shape_list(features["inputs"]) + if (target_modality == modalities.ModalityType.CLASS_LABEL or + self._problem_hparams.get("regression_targets")): + decode_length = 1 + else: + decode_length = ( + inputs_shape[1] + features.get("decode_length", decode_length)) + batch_size = inputs_shape[0] + inputs = self._prepare_inputs_for_decode(features) + with tf.variable_scope("body"): + encoder_output, encoder_decoder_attention_bias = dp( + self.encode, + inputs, + features["target_space_id"], + hparams, + features=features) + encoder_output = encoder_output[0] + encoder_decoder_attention_bias = encoder_decoder_attention_bias[0] + partial_targets = features.get("partial_targets") + else: + # The problem has no inputs. + encoder_output = None + encoder_decoder_attention_bias = None + + # Prepare partial targets. + # In either features["inputs"] or features["targets"]. + # We force the outputs to begin with these sequences. 
+ partial_targets = features.get("inputs") + if partial_targets is None: + partial_targets = features["targets"] + assert partial_targets is not None + + if partial_targets is not None: + partial_targets = common_layers.expand_squeeze_to_nd(partial_targets, 2) + partial_targets = tf.to_int64(partial_targets) + partial_targets_shape = common_layers.shape_list(partial_targets) + partial_targets_length = partial_targets_shape[1] + decode_length = ( + partial_targets_length + features.get("decode_length", decode_length)) + batch_size = partial_targets_shape[0] + + if hparams.pos == "timing": + positional_encoding = common_attention.get_timing_signal_1d( + decode_length + 1, hparams.hidden_size) + elif hparams.pos == "timing_from_features": + positional_encoding = common_attention.add_timing_signals_from_features( + tf.zeros([1, decode_length, hparams.hidden_size]), features, + hparams.position_features) + elif hparams.pos == "emb": + positional_encoding = common_attention.add_positional_embedding( + tf.zeros([1, decode_length, hparams.hidden_size]), hparams.max_length, + "body/targets_positional_embedding", None) + else: + positional_encoding = None + + def preprocess_targets(targets, i): + """Performs preprocessing steps on the targets to prepare for the decoder. + + This includes: + - Embedding the ids. + - Flattening to 3D tensor. + - Optionally adding timing signals. + + Args: + targets: inputs ids to the decoder. [batch_size, 1] + i: scalar, Step number of the decoding loop. 
+ + Returns: + Processed targets [batch_size, 1, hidden_dim] + """ + # _shard_features called to ensure that the variable names match + targets = self._shard_features({"targets": targets})["targets"] + modality_name = hparams.name.get( + "targets", + modalities.get_name(target_modality))(hparams, target_vocab_size) + with tf.variable_scope(modality_name): + bottom = hparams.bottom.get( + "targets", modalities.get_targets_bottom(target_modality)) + targets = dp(bottom, targets, hparams, target_vocab_size)[0] + targets = common_layers.flatten4d3d(targets) + + # GO embeddings are all zero, this is because transformer_prepare_decoder + # Shifts the targets along by one for the input which pads with zeros. + # If the modality already maps GO to the zero embeddings this is not + # needed. + if not self.get_decode_start_id(): + targets = tf.cond( + tf.equal(i, 0), lambda: tf.zeros_like(targets), lambda: targets) + + if positional_encoding is not None: + targets += positional_encoding[:, i:i + 1] + return targets + + decoder_self_attention_bias = ( + common_attention.attention_bias_lower_triangle(decode_length)) + if hparams.proximity_bias: + decoder_self_attention_bias += common_attention.attention_bias_proximal( + decode_length) + + # Create tensors for encoder-decoder attention history + att_cache = {"attention_history": {}} + num_layers = hparams.num_decoder_layers or hparams.num_hidden_layers + if encoder_output is not None: + att_batch_size, enc_seq_length = common_layers.shape_list( + encoder_output)[0:2] + for layer in range(num_layers): + att_cache["attention_history"]["layer_%d" % layer] = tf.zeros( + [att_batch_size, hparams.num_heads, 0, enc_seq_length]) + + def update_decoder_attention_history(cache): + """Save attention weights in cache, e.g., for vizualization.""" + for k in [x for x in self.attention_weights + if "decoder" in x and "self" not in x and "logits" not in x]: + idx = k.find("layer_") + if idx < 0: + continue + # Get layer number from the string 
name. + layer_nbr = k[idx + 6:] + idx = 0 + while idx + 1 < len(layer_nbr) and layer_nbr[:idx + 1].isdigit(): + idx += 1 + layer_nbr = "layer_%d" % int(layer_nbr[:idx]) + if layer_nbr in cache["attention_history"]: + cache["attention_history"][layer_nbr] = tf.concat( + [cache["attention_history"][layer_nbr], + self.attention_weights[k]], + axis=2) + if not preprocess_targets_method: + preprocess_targets_method = preprocess_targets + + def symbols_to_logits_fn(ids, i, cache): + """Go from ids to logits for next symbol.""" + ids = ids[:, -1:] + targets = tf.expand_dims(tf.expand_dims(ids, axis=2), axis=3) + targets = preprocess_targets_method(targets, i) + + bias = decoder_self_attention_bias[:, :, i:i + 1, :i + 1] + with tf.variable_scope("body"): + body_outputs = dp( + self.decode, + targets, + cache.get("encoder_output"), + cache.get("encoder_decoder_attention_bias"), + bias, + hparams, + cache, + nonpadding=features_to_nonpadding(features, "targets")) + + update_decoder_attention_history(cache) + + modality_name = hparams.name.get( + "targets", + modalities.get_name(target_modality))(hparams, target_vocab_size) + with tf.variable_scope(modality_name): + top = hparams.top.get("targets", modalities.get_top(target_modality)) + logits = dp(top, body_outputs, None, hparams, target_vocab_size)[0] + + ret = tf.squeeze(logits, axis=[1, 2, 3]) + if partial_targets is not None: + # If the position is within the given partial targets, we alter the + # logits to always return those values. + # A faster approach would be to process the partial targets in one + # iteration in order to fill the corresponding parts of the cache. + # This would require broader changes, though. 
+ vocab_size = tf.shape(ret)[1] + + def forced_logits(): + return tf.one_hot( + tf.tile(partial_targets[:, i], [beam_size]), vocab_size, 0.0, + -1e9) + + ret = tf.cond( + tf.less(i, partial_targets_length), forced_logits, lambda: ret) + return ret, cache + + sos_id = self.get_decode_start_id() or 0 + eos_id = self.get_decode_end_id() or beam_search.EOS_ID + temperature = features.get("sampling_temp", + getattr(hparams, "sampling_temp", 0.0)) + top_k = features.get("sampling_keep_top_k", + getattr(hparams, "sampling_keep_top_k", -1)) + + ret = fast_decode( + encoder_output=encoder_output, + encoder_decoder_attention_bias=encoder_decoder_attention_bias, + symbols_to_logits_fn=symbols_to_logits_fn, + hparams=hparams, + decode_length=decode_length, + vocab_size=target_vocab_size, + init_cache_fn=self._init_cache_fn, + beam_size=beam_size, + top_beams=top_beams, + alpha=alpha, + batch_size=batch_size, + force_decode_length=self._decode_hparams.force_decode_length, + sos_id=sos_id, + eos_id=eos_id, + sampling_temperature=temperature, + top_k=top_k, + cache=att_cache) + if partial_targets is not None: + if beam_size <= 1 or top_beams <= 1: + ret["outputs"] = ret["outputs"][:, partial_targets_length:] + else: + ret["outputs"] = ret["outputs"][:, :, partial_targets_length:] + return ret + + +def _init_transformer_cache(cache, hparams, batch_size, attention_init_length, + encoder_output, encoder_decoder_attention_bias, + scope_prefix): + """Create the initial cache for Transformer fast decoding.""" + key_channels = hparams.attention_key_channels or hparams.hidden_size + value_channels = hparams.attention_value_channels or hparams.hidden_size + num_layers = hparams.num_decoder_layers or hparams.num_hidden_layers + vars_3d_num_heads = ( + hparams.num_heads if hparams.get("attention_variables_3d") else 0) + + if cache is None: + cache = {} + cache.update({ + "layer_%d" % layer: { # pylint: disable=g-complex-comprehension + "k": + common_attention.split_heads( + 
tf.zeros([batch_size, + attention_init_length, + key_channels]), hparams.num_heads), + "v": + common_attention.split_heads( + tf.zeros([batch_size, + attention_init_length, + value_channels]), hparams.num_heads), + } for layer in range(num_layers) + }) + + # If `ffn_layer` is in `["dense_relu_dense" or "conv_hidden_relu"]`, then the + # cache key "f" won't be used, which means that the` shape of cache["f"]` + # won't be changed to + # `[beamsize*batch_size, decode_length, hparams.hidden_size]` and may cause + # error when applying `nest.map reshape function` on it. + if hparams.ffn_layer not in ["dense_relu_dense", "conv_hidden_relu"]: + for layer in range(num_layers): + cache["layer_%d" % layer]["f"] = tf.zeros( + [batch_size, 0, hparams.hidden_size]) + + if encoder_output is not None: + for layer in range(num_layers): + layer_name = "layer_%d" % layer + with tf.variable_scope( + "%sdecoder/%s/encdec_attention/multihead_attention" % + (scope_prefix, layer_name)): + k_encdec = common_attention.compute_attention_component( + encoder_output, + key_channels, + name="k", + vars_3d_num_heads=vars_3d_num_heads) + k_encdec = common_attention.split_heads(k_encdec, hparams.num_heads) + v_encdec = common_attention.compute_attention_component( + encoder_output, + value_channels, + name="v", + vars_3d_num_heads=vars_3d_num_heads) + v_encdec = common_attention.split_heads(v_encdec, hparams.num_heads) + cache[layer_name]["k_encdec"] = k_encdec + cache[layer_name]["v_encdec"] = v_encdec + + cache["encoder_output"] = encoder_output + cache["encoder_decoder_attention_bias"] = encoder_decoder_attention_bias + return cache + + +def fast_decode_tpu(encoder_output, + encoder_decoder_attention_bias, + symbols_to_logits_fn, + hparams, + decode_length, + vocab_size, + init_cache_fn=_init_transformer_cache, + beam_size=1, + top_beams=1, + alpha=1.0, + sos_id=0, + eos_id=beam_search.EOS_ID, + batch_size=None, + force_decode_length=False, + scope_prefix="body/", + use_top_k_with_unique=True, 
+ sampling_temperature=0.0, + top_k=-1): + """Given encoder output and a symbols to logits function, does fast decoding. + + Implements both greedy and beam search decoding for TPU, uses beam search iff + beam_size > 1, otherwise beam search related arguments are ignored. Args: - vocab_size: Input vocabulary size -- each element of the input array - should be an integer in ``range(vocab_size)``. These integers typically - represent token IDs from a vocabulary-based tokenizer. - d_model: Last/innermost dimension of activation arrays at most points in - the model, including the initial embedding output. - d_ff: Last/innermost dimension of special (typically wider) - :py:class:`Dense` layer in the feedforward part of each encoder block. - n_layers: Number of decoder blocks. Each block includes attention, dropout, - residual, layer-norm, feedforward (:py:class:`Dense`), and activation - layers. - n_heads: Number of attention heads. - max_len: Maximum symbol length for positional encoding. - dropout: Stochastic rate (probability) for dropping an activation value - when applying dropout within decoder blocks. The same rate is also - used for attention dropout in decoder blocks. - dropout_shared_axes: Tensor axes on which to share a dropout mask. - Sharing along batch and sequence axes (``dropout_shared_axes=(0,1)``) - is a useful way to save memory and apply consistent masks to activation - vectors at different sequence positions. - mode: If ``'predict'``, use fast inference. If ``'train'``, each decoder - block will include dropout; else, it will pass all values through - unaltered. - ff_activation: Type of activation function at the end of each encoder - block; must be an activation-type subclass of :py:class:`Layer`. + encoder_output: A tensor, output from encoder. + encoder_decoder_attention_bias: A tensor, bias for use in encoder-decoder + attention. + symbols_to_logits_fn: Incremental decoding, function mapping triple `(ids, + step, cache)` to symbol logits. 
+ hparams: Run hyperparameters. + decode_length: An integer, how many additional timesteps to decode. + vocab_size: Output vocabulary size. + init_cache_fn: Function that returns the initial cache dict. + beam_size: An integer, number of beams. + top_beams: An integer, how many of the beams to return. + alpha: A float that controls the length penalty. Larger the alpha, stronger + the preference for longer translations. + sos_id: Start-of-sequence symbol. + eos_id: End-of-sequence symbol. + batch_size: An integer, must be passed if there is no input. + force_decode_length: A bool, whether to force the full decode length, or if + False, stop when all beams hit eos_id. + scope_prefix: str, prefix for decoder layer variable scopes. + use_top_k_with_unique: bool, whether to use a fast (but decreased precision) + top_k during beam search. + sampling_temperature: scalar, temperature with which to sample. + top_k: scalar, sample only top k. Returns: - A Transformer language model that maps strings (represented as token ID - sequences) to sequences of raw (non-normalized) activation vectors; each - vector in the sequence can be mapped (e.g., by `argmax`) to a token ID. + A dict of decoding results { + "outputs": integer `Tensor` of decoded ids of shape + [batch_size, <= decode_length] if top_beams == 1 or + [batch_size, top_beams, <= decode_length] otherwise + "scores": decoding log probs from the beam search, or summed + log probs of the chosen ids for greedy decoding (beam_size=1) + }. + + Raises: + NotImplementedError: If beam size > 1 with partial targets.
""" - def _Dropout(): - return tl.Dropout(rate=dropout, shared_axes=dropout_shared_axes, mode=mode) - - def _DecBlock(): - return _DecoderBlock(d_model, d_ff, n_heads, dropout, dropout_shared_axes, - mode, ff_activation) - - return tl.Serial( - tl.ShiftRight(mode=mode), - tl.Embedding(vocab_size, d_model), - _Dropout(), - tl.PositionalEncoding(max_len=max_len, mode=mode), - [_DecBlock() for _ in range(n_layers)], - tl.LayerNorm(), - tl.Dense(vocab_size), - ) - - -def Transformer(input_vocab_size, - output_vocab_size=None, - d_model=D_MODEL, - d_ff=D_FF, - n_encoder_layers=N_LAYERS, - n_decoder_layers=N_LAYERS, - n_heads=N_HEADS, - max_len=MAX_SEQUENCE_LENGTH, - dropout=DROPOUT_RATE, - dropout_shared_axes=DROPOUT_SHARED_AXES, - mode=MODE, - ff_activation=FF_ACTIVATION_TYPE): - """Returns a full Transformer model. - - This model is an encoder-decoder that performs tokenized string-to-string - ("source"-to-"target") transduction: - - - inputs (2): - - - source: Array representing a batch of text strings via token - IDs plus padding markers; shape is (batch_size, sequence_length), - where sequence_length <= ``max_len``. Array elements are integers in - ``range(input_vocab_size)``, and 0 values mark padding positions. - - - target: Array representing a batch of text strings via token - IDs plus padding markers; shape is (batch_size, sequence_length), - where sequence_length <= ``max_len``. Array elements are integers in - ``range(output_vocab_size)``, and 0 values mark padding positions. - - - output: 3-D array of raw activations with last/innermost dimension of - ``output_vocab_size``, suitable for decoding into a batch of token - strings; shape is (batch_size, sequence_length, ``vocab_size``). - - An example use would be to translate (tokenized) sentences from English to - German. 
+ if encoder_output is not None: + batch_size = common_layers.shape_list(encoder_output)[0] + + cache = init_cache_fn(None, hparams, batch_size, decode_length, + encoder_output, encoder_decoder_attention_bias, + scope_prefix) + + mlperf_log.transformer_print( + key=mlperf_log.MODEL_HP_SEQ_BEAM_SEARCH, + value={ + "vocab_size": vocab_size, + "batch_size": batch_size, + "beam_size": beam_size, + "alpha": alpha, + "max_decode_length": decode_length + }, + hparams=hparams) + if beam_size > 1: # Beam Search + initial_ids = sos_id * tf.ones([batch_size], dtype=tf.int32) + decoded_ids, scores, _ = beam_search.beam_search( + symbols_to_logits_fn, + initial_ids, + beam_size, + decode_length, + vocab_size, + alpha, + states=cache, + eos_id=eos_id, + stop_early=(top_beams == 1), + use_tpu=True, + use_top_k_with_unique=use_top_k_with_unique) + + if top_beams == 1: + decoded_ids = decoded_ids[:, 0, 1:] + scores = scores[:, 0] + else: + decoded_ids = decoded_ids[:, :top_beams, 1:] + scores = scores[:, :top_beams] + else: # Greedy + + def inner_loop(i, hit_eos, next_id, decoded_ids, cache, log_prob): + """One step of greedy decoding.""" + logits, cache = symbols_to_logits_fn(next_id, i, cache) + log_probs = common_layers.log_prob_from_logits(logits) + temperature = sampling_temperature + if hparams.sampling_method == "random_per_example": + next_id = common_layers.sample_temperature_per_example( + logits, temperature, top_k) + else: + if hparams.sampling_method == "argmax": + temperature = 0.0 + next_id = common_layers.sample_with_temperature(logits, temperature, + top_k) + + log_prob_indices = tf.stack([tf.range(tf.to_int64(batch_size)), next_id], + axis=1) + log_prob += tf.gather_nd( + log_probs, log_prob_indices) * (1 - tf.to_float(hit_eos)) + # Note(thangluong): we purposely update hit_eos after aggregating log_prob + # There is a subtle detail here that we want to include log_probs up to + # (and inclusive of) the first eos generated, but not subsequent tokens. 
+ hit_eos |= tf.equal(next_id, eos_id) + + next_id = tf.expand_dims(next_id, axis=1) + decoded_ids = tf.transpose(decoded_ids) + decoded_ids = inplace_ops.alias_inplace_update( + decoded_ids, i, tf.squeeze(next_id, axis=1)) + decoded_ids = tf.transpose(decoded_ids) + return i + 1, hit_eos, next_id, decoded_ids, cache, log_prob + + def is_not_finished(i, hit_eos, *_): + finished = i >= decode_length + if not force_decode_length: + finished |= tf.reduce_all(hit_eos) + return tf.logical_not(finished) + + decoded_ids = tf.zeros([batch_size, decode_length], dtype=tf.int64) + hit_eos = tf.fill([batch_size], False) + next_id = sos_id * tf.ones([batch_size, 1], dtype=tf.int64) + initial_log_prob = tf.zeros([batch_size], dtype=tf.float32) + + def compute_cache_shape_invariants(tensor): + return tf.TensorShape(tensor.shape.as_list()) + + _, _, _, decoded_ids, _, log_prob = tf.while_loop( + is_not_finished, + inner_loop, [ + tf.constant(0), hit_eos, next_id, decoded_ids, cache, + initial_log_prob + ], + shape_invariants=[ + tf.TensorShape([]), + tf.TensorShape([batch_size]), + tf.TensorShape([batch_size, 1]), + tf.TensorShape([batch_size, decode_length]), + nest.map_structure(compute_cache_shape_invariants, cache), + tf.TensorShape([batch_size]), + ]) + scores = log_prob + + return {"outputs": decoded_ids, "scores": scores} + + +def fast_decode(encoder_output, + encoder_decoder_attention_bias, + symbols_to_logits_fn, + hparams, + decode_length, + vocab_size, + init_cache_fn=_init_transformer_cache, + beam_size=1, + top_beams=1, + alpha=1.0, + sos_id=0, + eos_id=beam_search.EOS_ID, + batch_size=None, + force_decode_length=False, + scope_prefix="body/", + sampling_temperature=0.0, + top_k=-1, + cache=None): + """Given encoder output and a symbols to logits function, does fast decoding. + + Implements both greedy and beam search decoding, uses beam search iff + beam_size > 1, otherwise beam search related arguments are ignored. 
Args: - input_vocab_size: Input vocabulary size -- each element of the input tensor - should be an integer in ``range(vocab_size)``. These integers typically - represent token IDs from a vocabulary-based tokenizer. - output_vocab_size: If specified, gives the vocabulary size for the targets; - if ``None``, then input and target integers (token IDs) are assumed to - come from the same vocabulary. - d_model: Last/innermost dimension of activation arrays at most points in - the model, including the initial embedding output. - d_ff: Last/innermost dimension of special (typically wider) - :py:class:`Dense` layer in the feedforward part of each encoder block. - n_encoder_layers: Number of encoder blocks. - n_decoder_layers: Number of decoder blocks. - n_heads: Number of attention heads. - max_len: Maximum symbol length for positional encoding. - dropout: Stochastic rate (probability) for dropping an activation value - when applying dropout within encoder/decoder blocks. The same rate is - also used for attention dropout in encoder/decoder blocks. - dropout_shared_axes: Tensor axes on which to share a dropout mask. - Sharing along batch and sequence axes (``dropout_shared_axes=(0,1)``) - is a useful way to save memory and apply consistent masks to activation - vectors at different sequence positions. - mode: If ``'predict'``, use fast inference. If ``'train'``, each - encoder/decoder block will include dropout; else, it will pass all - values through unaltered. - ff_activation: Type of activation function at the end of each - encoder/decoder block; must be an activation-type subclass of - :py:class:`Layer`. + encoder_output: Output from encoder. + encoder_decoder_attention_bias: a bias tensor for use in encoder-decoder + attention + symbols_to_logits_fn: Incremental decoding; function mapping triple `(ids, + step, cache)` to symbol logits. + hparams: run hyperparameters + decode_length: an integer. How many additional timesteps to decode. 
+ vocab_size: Output vocabulary size. + init_cache_fn: Function that returns the initial cache dict. + beam_size: number of beams. + top_beams: an integer. How many of the beams to return. + alpha: Float that controls the length penalty. larger the alpha, stronger + the preference for longer translations. + sos_id: Start-of-sequence symbol in beam search. + eos_id: End-of-sequence symbol in beam search. + batch_size: an integer scalar - must be passed if there is no input + force_decode_length: bool, whether to force the full decode length, or if + False, stop when all beams hit eos_id. + scope_prefix: str, prefix for decoder layer variable scopes. + sampling_temperature: scalar, temperature with which to sample. + top_k: scalar, sample only top k. + cache: cache dictionary for additional predictions. Returns: - A Transformer model as a layer that maps from a source-target tokenized - text pair to activations over a vocab set. + A dict of decoding results { + "outputs": integer `Tensor` of decoded ids of shape + [batch_size, <= decode_length] if top_beams == 1 or + [batch_size, top_beams, <= decode_length] otherwise + "scores": decoding log probs from the beam search, or summed + log probs of the chosen ids for greedy decoding (beam_size=1) + } + """ + if encoder_output is not None: + batch_size = common_layers.shape_list(encoder_output)[0] + + cache = init_cache_fn( + cache=cache, + hparams=hparams, + batch_size=batch_size, + attention_init_length=0, + encoder_output=encoder_output, + encoder_decoder_attention_bias=encoder_decoder_attention_bias, + scope_prefix=scope_prefix) + + if beam_size > 1: # Beam Search + initial_ids = sos_id * tf.ones([batch_size], dtype=tf.int32) + decoded_ids, scores, cache = beam_search.beam_search( + symbols_to_logits_fn, + initial_ids, + beam_size, + decode_length, + vocab_size, + alpha, + states=cache, + eos_id=eos_id, + stop_early=(top_beams == 1)) + + if top_beams == 1: + decoded_ids = decoded_ids[:, 0, 1:] + scores = scores[:, 0] + else: + decoded_ids =
decoded_ids[:, :top_beams, 1:] + scores = scores[:, :top_beams] + else: # Greedy + + def inner_loop(i, hit_eos, next_id, decoded_ids, cache, log_prob): + """One step of greedy decoding.""" + logits, cache = symbols_to_logits_fn(next_id, i, cache) + log_probs = common_layers.log_prob_from_logits(logits) + temperature = sampling_temperature + if hparams.sampling_method == "random_per_example": + next_id = common_layers.sample_temperature_per_example( + logits, temperature, top_k) + else: + if hparams.sampling_method == "argmax": + temperature = 0.0 + next_id = common_layers.sample_with_temperature(logits, temperature, + top_k) + + log_prob_indices = tf.stack([tf.range(tf.to_int64(batch_size)), next_id], + axis=1) + log_prob += tf.gather_nd( + log_probs, log_prob_indices) * (1 - tf.to_float(hit_eos)) + # Note(thangluong): we purposely update hit_eos after aggregating log_prob + # There is a subtle detail here that we want to include log_probs up to + # (and inclusive of) the first eos generated, but not subsequent tokens. 
+ hit_eos |= tf.equal(next_id, eos_id) + + next_id = tf.expand_dims(next_id, axis=1) + decoded_ids = tf.concat([decoded_ids, next_id], axis=1) + + return i + 1, hit_eos, next_id, decoded_ids, cache, log_prob + + def is_not_finished(i, hit_eos, *_): + finished = i >= decode_length + if not force_decode_length: + finished |= tf.reduce_all(hit_eos) + return tf.logical_not(finished) + + decoded_ids = tf.zeros([batch_size, 0], dtype=tf.int64) + hit_eos = tf.fill([batch_size], False) + next_id = sos_id * tf.ones([batch_size, 1], dtype=tf.int64) + initial_log_prob = tf.zeros([batch_size], dtype=tf.float32) + _, _, _, decoded_ids, cache, log_prob = tf.while_loop( + is_not_finished, + inner_loop, [ + tf.constant(0), hit_eos, next_id, decoded_ids, cache, + initial_log_prob + ], + shape_invariants=[ + tf.TensorShape([]), + tf.TensorShape([None]), + tf.TensorShape([None, None]), + tf.TensorShape([None, None]), + nest.map_structure(beam_search.get_state_shape_invariants, cache), + tf.TensorShape([None]), + ]) + scores = log_prob + + return {"outputs": decoded_ids, "scores": scores, "cache": cache} + + +@registry.register_model +class TransformerScorer(Transformer): + """Transformer model, but only scores in PREDICT mode. + + Checkpoints between Transformer and TransformerScorer are interchangeable. """ - # Avoid 'predict' mode in encoder, since encoder doesn't run stepwise. - encoder_mode = 'eval' if mode == 'predict' else mode - - # Share embedding weights if no separate output vocab size. 
- in_embedder = tl.Embedding(input_vocab_size, d_model) - if output_vocab_size is None: - out_embedder = in_embedder - output_vocab_size = input_vocab_size - else: - out_embedder = tl.Embedding(output_vocab_size, d_model) - - def _Dropout(): - return tl.Dropout(rate=dropout, shared_axes=dropout_shared_axes, mode=mode) - - def _EncBlock(): - return _EncoderBlock(d_model, d_ff, n_heads, dropout, dropout_shared_axes, - mode, ff_activation) - - def _Encoder(): - encoder = tl.Serial( - in_embedder, - _Dropout(), - tl.PositionalEncoding(max_len=max_len, mode=encoder_mode), - [_EncBlock() for _ in range(n_encoder_layers)], - tl.LayerNorm(), - ) - return tl.Cache(encoder) if mode == 'predict' else encoder - - def _EncDecBlock(): - return _EncoderDecoderBlock(d_model, d_ff, n_heads, dropout, - dropout_shared_axes, mode, ff_activation) - - # Input to model is encoder-side tokens and decoder-side tokens: tok_d, tok_e - # Model output is decoder-side vectors and decoder-side tokens: vec_d tok_d - return tl.Serial( - tl.Select([0, 1, 1]), # Copies decoder tokens for use in loss. - - # Encode. - tl.Branch([], tl.PaddingMask()), # tok_e masks tok_d tok_d - _Encoder(), - - # Decode. - tl.Select([2, 1, 0]), # Re-orders inputs: tok_d masks vec_e ..... - tl.ShiftRight(mode=mode), - out_embedder, - _Dropout(), - tl.PositionalEncoding(max_len=max_len, mode=mode), - tl.Branch([], tl.EncoderDecoderMask()), # vec_d masks ..... ..... - [_EncDecBlock() for _ in range(n_decoder_layers)], - tl.LayerNorm(), - tl.Select([0], n_in=3), # Drops masks and encoding vectors. - - # Map vectors to match output vocab size. - tl.Dense(output_vocab_size), - ) - - -def _EncoderBlock(d_model, - d_ff, - n_heads, - dropout, - dropout_shared_axes, - mode, - ff_activation): - """Returns a list of layers that implements a Transformer encoder block. 
- - The input to the block is a pair (activations, mask) where the mask was - created from the original source tokens to prevent attending to the padding - part of the input. The block's outputs are the same type/shape as its inputs, - so that multiple blocks can be chained together. - Args: - d_model: Last/innermost dimension of activation arrays at most points in - the model, including the initial embedding output. - d_ff: Last/innermost dimension of special (typically wider) - :py:class:`Dense` layer in the feedforward part of each block. - n_heads: Number of attention heads. - dropout: Stochastic rate (probability) for dropping an activation value - when applying dropout within encoder blocks. The same rate is also used - for attention dropout in encoder blocks. - dropout_shared_axes: Tensor axes on which to share a dropout mask. - Sharing along batch and sequence axes (``dropout_shared_axes=(0,1)``) - is a useful way to save memory and apply consistent masks to activation - vectors at different sequence positions. - mode: If ``'train'``, each block will include dropout; else, it will - pass all values through unaltered. - ff_activation: Type of activation function at the end of each block; must - be an activation-type subclass of :py:class:`Layer`. + def __init__(self, *args, **kwargs): + super(TransformerScorer, self).__init__(*args, **kwargs) + self._name = "transformer" + self._base_name = "transformer" - Returns: - A list of layers that act in series as a (repeatable) encoder block. 
+ def infer(self, + features=None, + decode_length=50, + beam_size=1, + top_beams=1, + alpha=0.0, + use_tpu=False): + """Returns the targets and their log probabilities.""" + del decode_length, beam_size, top_beams, alpha, use_tpu + assert features is not None + + # Run the model + self.hparams.force_full_predict = True + with tf.variable_scope(self.name): + logits, _ = self.model_fn(features) + assert len(logits.shape) == 5 # [batch, time, 1, 1, vocab] + logits = tf.squeeze(logits, [2, 3]) + + # Compute the log probabilities + log_probs = common_layers.log_prob_from_logits(logits) + + targets = features["targets"] + assert len(targets.shape) == 4 # [batch, time, 1, 1] + targets = tf.squeeze(targets, [2, 3]) + + # Slice out the log_probs of the targets + log_probs = common_layers.index_last_dim_with_indices(log_probs, targets) + + # Sum over time to get the log_prob of the sequence + scores = tf.reduce_sum(log_probs, axis=1) + + return {"outputs": targets, "scores": scores} + + +@registry.register_model +class TransformerEncoder(t2t_model.T2TModel): + """Transformer, encoder only.""" + + def body(self, features): + hparams = self._hparams + inputs = features["inputs"] + target_space = features["target_space_id"] + + inputs = common_layers.flatten4d3d(inputs) + + (encoder_input, encoder_self_attention_bias, _) = ( + transformer_prepare_encoder(inputs, target_space, hparams)) + + encoder_input = tf.nn.dropout(encoder_input, + 1.0 - hparams.layer_prepostprocess_dropout) + encoder_output = transformer_encoder( + encoder_input, + encoder_self_attention_bias, + hparams, + nonpadding=features_to_nonpadding(features, "inputs")) + encoder_output = tf.expand_dims(encoder_output, 2) + + return encoder_output + + +@registry.register_model +class TransformerRegressor(TransformerEncoder): + """Transformer inheriting from Encoder, for the regression problem. + + Final result is a tensor that has a shape of (?, 1, 1, 1). 
""" - def _Attention(): - return tl.Attention(d_model, n_heads=n_heads, dropout=dropout, mode=mode) - - def _Dropout(): - return tl.Dropout(rate=dropout, shared_axes=dropout_shared_axes, mode=mode) - - def _FFBlock(): - return _FeedForwardBlock(d_model, d_ff, dropout, dropout_shared_axes, mode, - ff_activation) - - return [ - tl.Residual( - tl.LayerNorm(), - _Attention(), - _Dropout(), - ), - tl.Residual( - tl.LayerNorm(), - _FFBlock(), - _Dropout(), - ), - ] - - -def _DecoderBlock(d_model, - d_ff, - n_heads, - dropout, - dropout_shared_axes, - mode, - ff_activation): - """Returns a list of layers that implements a Transformer decoder block. - - The input to the block is a pair (activations, mask) where the mask encodes - causal connections, preventing attention to future positions in the sequence. - The block's outputs are the same type/shape as its inputs, so that multiple - blocks can be chained together. + + def top(self, body_output, features): + """Computes single scalar value from body_output.""" + + with tf.variable_scope("reg_top_ffn"): + x = body_output + x = tf.reduce_mean(x, axis=[1, 2], keepdims=True) + res = tf.layers.dense(x, 1, name="model_top") + return res + + +def features_to_nonpadding(features, inputs_or_targets="inputs"): + key = inputs_or_targets + "_segmentation" + if features and key in features: + return tf.minimum(tf.to_float(features[key]), 1.0) + return None + + +def transformer_prepare_decoder(targets, hparams, features=None, pad=None): + """Prepare one shard of the model for the decoder. Args: - d_model: Last/innermost dimension of activation arrays at most points in - the model, including the initial embedding output. - d_ff: Last/innermost dimension of special (typically wider) - :py:class:`Dense` layer in the feedforward part of each block. - n_heads: Number of attention heads. - dropout: Stochastic rate (probability) for dropping an activation value - when applying dropout within decoder blocks. 
The same rate is also used - for attention dropout in decoder blocks. - dropout_shared_axes: Tensor axes on which to share a dropout mask. - Sharing along batch and sequence axes (``dropout_shared_axes=(0,1)``) - is a useful way to save memory and apply consistent masks to activation - vectors at different sequence positions. - mode: If ``'train'``, each block will include dropout; else, it will - pass all values through unaltered. - ff_activation: Type of activation function at the end of each block; must - be an activation-type subclass of :py:class:`Layer`. + targets: a Tensor. + hparams: run hyperparameters + features: optionally pass the entire features dictionary as well. This is + needed now for "packed" datasets. + pad: vector to use for padding when shifting targets right Returns: - A list of layers that act in series as a (repeatable) decoder block. + decoder_input: a Tensor, bottom of decoder stack + decoder_self_attention_bias: a bias tensor for use in decoder self-attention """ - def _CausalAttention(): - return tl.CausalAttention(d_model, n_heads=n_heads, dropout=dropout, - mode=mode), - - def _FFBlock(): - return _FeedForwardBlock(d_model, d_ff, dropout, dropout_shared_axes, mode, - ff_activation) - - def _Dropout(): - return tl.Dropout(rate=dropout, shared_axes=dropout_shared_axes, mode=mode) - - return [ - tl.Residual( - tl.LayerNorm(), - _CausalAttention(), - _Dropout(), - ), - tl.Residual( - tl.LayerNorm(), - _FFBlock(), - _Dropout(), - ), - ] - - -def _EncoderDecoderBlock(d_model, - d_ff, - n_heads, - dropout, - dropout_shared_axes, - mode, - ff_activation): - """Returns a list of layers implementing a Transformer encoder-decoder block. - - The block input is a triple (decoder_activations, mask, encoder_activations) - where the mask was created from the original input token IDs to prevent - attending to padding positions for that input. + if hparams.causal_decoder_self_attention: + # Causal attention. 
+ if hparams.prepend_mode == "prepend_inputs_full_attention": + decoder_self_attention_bias = ( + common_attention.attention_bias_prepend_inputs_full_attention( + common_attention.embedding_to_padding(targets))) + else: + decoder_self_attention_bias = ( + common_attention.attention_bias_lower_triangle( + common_layers.shape_list(targets)[1])) + else: + # Full attention. + decoder_padding = common_attention.embedding_to_padding(targets) + decoder_self_attention_bias = ( + common_attention.attention_bias_ignore_padding(decoder_padding)) + + if features and "targets_segmentation" in features: + # "Packed" dataset - keep the examples from seeing each other. + targets_segmentation = features["targets_segmentation"] + targets_position = features["targets_position"] + decoder_self_attention_bias += common_attention.attention_bias_same_segment( + targets_segmentation, targets_segmentation) + else: + targets_position = None + if hparams.proximity_bias: + decoder_self_attention_bias += common_attention.attention_bias_proximal( + common_layers.shape_list(targets)[1]) + decoder_input = common_layers.shift_right_3d(targets, pad) + if hparams.pos == "timing": + if targets_position is not None: + decoder_input = common_attention.add_timing_signal_1d_given_position( + decoder_input, targets_position) + else: + decoder_input = common_attention.add_timing_signal_1d(decoder_input) + elif hparams.pos == "timing_from_features": + decoder_input = common_attention.add_timing_signals_from_features( + decoder_input, features, hparams.position_features) + elif hparams.pos == "emb": + decoder_input = common_attention.add_positional_embedding( + decoder_input, hparams.max_length, "targets_positional_embedding", + targets_position) + + if hparams.activation_dtype == "bfloat16": + decoder_self_attention_bias = tf.cast(decoder_self_attention_bias, + tf.bfloat16) + return (decoder_input, decoder_self_attention_bias) + + +def transformer_self_attention_layer(decoder_input, + 
decoder_self_attention_bias, + layer_idx, + hparams, + encoder_output=None, + encoder_decoder_attention_bias=None, + cache=None, + decode_loop_step=None, + save_weights_to=None, + make_image_summary=False, + layer_collection=None, + recurrent_memory_by_layer=None, + chunk_number=None): + """A single transformer self-attention layer.""" + x = decoder_input + layer = layer_idx + layer_name = "layer_%d" % layer + layer_cache = cache[layer_name] if cache is not None else None + + attention_dropout_broadcast_dims = ( + common_layers.comma_separated_string_to_integer_list( + getattr(hparams, "attention_dropout_broadcast_dims", ""))) + + if recurrent_memory_by_layer is not None: + recurrent_memory = recurrent_memory_by_layer[layer_name] + else: + recurrent_memory = None + + if layer < hparams.get("num_area_layers", 0): + max_area_width = hparams.get("max_area_width", 1) + max_area_height = hparams.get("max_area_height", 1) + memory_height = hparams.get("max_area_height", 1) + else: + max_area_width = 1 + max_area_height = 1 + memory_height = 1 + with tf.variable_scope(layer_name): + with tf.variable_scope("self_attention"): + y = common_attention.multihead_attention( + common_layers.layer_preprocess( + x, hparams, layer_collection=layer_collection), + None, + decoder_self_attention_bias, + hparams.attention_key_channels or hparams.hidden_size, + hparams.attention_value_channels or hparams.hidden_size, + hparams.hidden_size, + hparams.num_heads, + hparams.attention_dropout, + attention_type=hparams.self_attention_type, + max_relative_position=hparams.max_relative_position, + heads_share_relative_embedding=( + hparams.heads_share_relative_embedding), + add_relative_to_values=hparams.add_relative_to_values, + save_weights_to=save_weights_to, + cache=layer_cache, + make_image_summary=make_image_summary, + dropout_broadcast_dims=attention_dropout_broadcast_dims, + max_length=hparams.get("max_length"), + decode_loop_step=decode_loop_step, + 
vars_3d=hparams.get("attention_variables_3d"), + activation_dtype=hparams.get("activation_dtype", "float32"), + weight_dtype=hparams.get("weight_dtype", "float32"), + layer_collection=layer_collection, + recurrent_memory=recurrent_memory, + chunk_number=chunk_number, + hard_attention_k=hparams.get("hard_attention_k", 0), + gumbel_noise_weight=hparams.get("gumbel_noise_weight", 0.0), + max_area_width=max_area_width, + max_area_height=max_area_height, + memory_height=memory_height, + area_key_mode=hparams.get("area_key_mode", "none"), + area_value_mode=hparams.get("area_value_mode", "none"), + training=(hparams.get( + "mode", + tf_estimator.ModeKeys.TRAIN) == tf_estimator.ModeKeys.TRAIN)) + x = common_layers.layer_postprocess(x, y, hparams) + if encoder_output is not None: + if not isinstance(encoder_output, (list,)): + encoder_output = [encoder_output] + with tf.variable_scope("encdec_attention"): + for enc_output in encoder_output: + y = common_attention.multihead_attention( + common_layers.layer_preprocess( + x, hparams, layer_collection=layer_collection), + enc_output, + encoder_decoder_attention_bias, + hparams.attention_key_channels or hparams.hidden_size, + hparams.attention_value_channels or hparams.hidden_size, + hparams.hidden_size, + hparams.num_heads, + hparams.attention_dropout, + max_relative_position=hparams.max_relative_position, + heads_share_relative_embedding=( + hparams.heads_share_relative_embedding), + add_relative_to_values=hparams.add_relative_to_values, + save_weights_to=save_weights_to, + cache=layer_cache, + make_image_summary=make_image_summary, + dropout_broadcast_dims=attention_dropout_broadcast_dims, + max_length=hparams.get("max_length"), + vars_3d=hparams.get("attention_variables_3d"), + activation_dtype=hparams.get("activation_dtype", "float32"), + weight_dtype=hparams.get("weight_dtype", "float32"), + layer_collection=layer_collection, + hard_attention_k=hparams.get("hard_attention_k", 0), + 
gumbel_noise_weight=hparams.get("gumbel_noise_weight", 0.0), + max_area_width=max_area_width, + max_area_height=max_area_height, + memory_height=memory_height, + area_key_mode=hparams.get("area_key_mode", "none"), + area_value_mode=hparams.get("area_value_mode", "none"), + training=(hparams.get( + "mode", + tf_estimator.ModeKeys.TRAIN) == tf_estimator.ModeKeys.TRAIN)) + x = common_layers.layer_postprocess(x, y, hparams) + return x, layer_cache + + +def transformer_decoder_layer(decoder_input, + decoder_self_attention_bias, + layer_idx, + hparams, + encoder_output=None, + encoder_decoder_attention_bias=None, + cache=None, + decode_loop_step=None, + nonpadding=None, + save_weights_to=None, + make_image_summary=False, + losses=None, + layer_collection=None, + recurrent_memory_by_layer=None, + chunk_number=None): + """A single transformer decoder layer.""" + x, layer_cache = transformer_self_attention_layer( + decoder_input=decoder_input, + decoder_self_attention_bias=decoder_self_attention_bias, + layer_idx=layer_idx, + hparams=hparams, + encoder_output=encoder_output, + encoder_decoder_attention_bias=encoder_decoder_attention_bias, + cache=cache, + decode_loop_step=decode_loop_step, + save_weights_to=save_weights_to, + make_image_summary=make_image_summary, + layer_collection=layer_collection, + recurrent_memory_by_layer=recurrent_memory_by_layer, + chunk_number=chunk_number) + + layer = layer_idx + layer_name = "layer_%d" % layer + with tf.variable_scope(layer_name): + with tf.variable_scope("ffn"): + y = transformer_ffn_layer( + common_layers.layer_preprocess( + x, hparams, layer_collection=layer_collection), + hparams, + conv_padding="LEFT", + nonpadding_mask=nonpadding, + losses=losses, + cache=layer_cache, + decode_loop_step=decode_loop_step, + layer_collection=layer_collection) + x = common_layers.layer_postprocess(x, y, hparams) + return x + + +def transformer_decoder(decoder_input, + encoder_output, + decoder_self_attention_bias, + 
encoder_decoder_attention_bias, + hparams, + cache=None, + decode_loop_step=None, + name="decoder", + nonpadding=None, + save_weights_to=None, + make_image_summary=True, + losses=None, + layer_collection=None, + recurrent_memory_by_layer=None, + chunk_number=None): + """A stack of transformer layers. Args: - d_model: Last/innermost dimension of activation arrays at most points in - the model, including the initial embedding output. - d_ff: Last/innermost dimension of special (typically wider) - :py:class:`Dense` layer in the feedforward part of each block. - n_heads: Number of attention heads. - dropout: Stochastic rate (probability) for dropping an activation value - when applying dropout within encoder/decoder blocks. The same rate is - also used for attention dropout in encoder/decoder blocks. - dropout_shared_axes: Tensor axes on which to share a dropout mask. - Sharing along batch and sequence axes (``dropout_shared_axes=(0,1)``) - is a useful way to save memory and apply consistent masks to activation - vectors at different sequence positions. - mode: If ``'train'``, each block will include dropout; else, it will - pass all values through unaltered. - ff_activation: Type of activation function at the end of each block; must - be an activation-type subclass of :py:class:`Layer`. + decoder_input: a Tensor + encoder_output: a Tensor + decoder_self_attention_bias: bias Tensor for self-attention (see + common_attention.attention_bias()) + encoder_decoder_attention_bias: bias Tensor for encoder-decoder attention + (see common_attention.attention_bias()) + hparams: hyperparameters for model + cache: dict, containing tensors which are the results of previous + attentions, used for fast decoding. + decode_loop_step: An integer, step number of the decoding loop. Only used + for inference on TPU. + name: a string + nonpadding: optional Tensor with shape [batch_size, encoder_length] + indicating what positions are not padding. 
This is used to mask out + padding in convolutional layers. We generally only need this mask for + "packed" datasets, because for ordinary datasets, no padding is ever + followed by nonpadding. + save_weights_to: an optional dictionary to capture attention weights for + visualization; the weights tensor will be appended there under a string + key created from the variable scope (including name). + make_image_summary: Whether to make an attention image summary. + losses: optional list onto which to append extra training losses + layer_collection: A tensorflow_kfac.LayerCollection. Only used by the KFAC + optimizer. Default is None. + recurrent_memory_by_layer: Optional dict, mapping layer names to instances + of transformer_memory.RecurrentMemory. Default is None. + chunk_number: an optional integer Tensor with shape [batch] used to operate + the recurrent_memory. Returns: - A list of layers that act in series as a (repeatable) encoder-decoder - block. + y: a Tensors """ - def _Dropout(): - return tl.Dropout(rate=dropout, shared_axes=dropout_shared_axes, mode=mode) - - def _AttentionQKV(): - return tl.AttentionQKV(d_model, n_heads=n_heads, dropout=dropout, - mode=mode, cache_KV_in_predict=True) - - def _CausalAttention(): - return tl.CausalAttention(d_model, n_heads=n_heads, mode=mode) - - def _FFBlock(): - return _FeedForwardBlock(d_model, d_ff, dropout, dropout_shared_axes, mode, - ff_activation) - - return [ # vec_d masks vec_e - tl.Residual( - tl.LayerNorm(), - _CausalAttention(), - _Dropout(), - ), - tl.Residual( - tl.LayerNorm(), - tl.Select([0, 2, 2, 1, 2]), # vec_d vec_e vec_e masks vec_e - _AttentionQKV(), # vec_d masks vec_e - _Dropout(), - ), - tl.Residual( - tl.LayerNorm(), - _FFBlock(), - _Dropout(), - ), - ] - - -def _FeedForwardBlock(d_model, - d_ff, - dropout, - dropout_shared_axes, - mode, - activation): - """Returns a list of layers that implements a feedforward block. 
+ x = decoder_input + + mlperf_log.transformer_print( + key=mlperf_log.MODEL_HP_NUM_HIDDEN_LAYERS, + value=hparams.num_decoder_layers or hparams.num_hidden_layers, + hparams=hparams) + mlperf_log.transformer_print( + key=mlperf_log.MODEL_HP_ATTENTION_DROPOUT, + value=hparams.attention_dropout, + hparams=hparams) + mlperf_log.transformer_print( + key=mlperf_log.MODEL_HP_ATTENTION_DENSE, + value={ + "use_bias": "false", + "num_heads": hparams.num_heads, + "hidden_size": hparams.hidden_size + }, + hparams=hparams) + + with tf.variable_scope(name): + for layer_idx in range(hparams.num_decoder_layers or + hparams.num_hidden_layers): + x = transformer_decoder_layer( + x, + decoder_self_attention_bias, + layer_idx, + hparams, + encoder_decoder_attention_bias=encoder_decoder_attention_bias, + encoder_output=encoder_output, + cache=cache, + decode_loop_step=decode_loop_step, + nonpadding=nonpadding, + save_weights_to=save_weights_to, + make_image_summary=make_image_summary, + losses=losses, + layer_collection=layer_collection, + recurrent_memory_by_layer=recurrent_memory_by_layer, + chunk_number=chunk_number + ) + + # if normalization is done in layer_preprocess, then it should also be done + # on the output, since the output can grow very large, being the sum of + # a whole stack of unnormalized layer outputs. + mlperf_log.transformer_print( + key=mlperf_log.MODEL_HP_NORM, + value={"hidden_size": hparams.hidden_size}) + return common_layers.layer_preprocess( + x, hparams, layer_collection=layer_collection) + + +@registry.register_model +class TransformerMemory(Transformer): + """Transformer language model with memory across chunks.""" + + # TODO(kitaev): consider overriding set_mode to swap out recurrent memory when + # switching between training and evaluation. 
+ + def __init__(self, *args, **kwargs): + super(TransformerMemory, self).__init__(*args, **kwargs) + + hparams = self._hparams + self.recurrent_memory_by_layer = {} + for layer in range(hparams.num_decoder_layers or hparams.num_hidden_layers): + layer_name = "layer_%d" % layer + if hparams.memory_type == "neural_memory": + memory = transformer_memory.TransformerMemory( + batch_size=int(hparams.batch_size / hparams.max_length), + key_depth=hparams.hidden_size, + val_depth=hparams.hidden_size, + memory_size=hparams.split_targets_chunk_length, + sharpen_factor=1., + name=layer_name + "/recurrent_memory") + elif hparams.memory_type == "transformer_xl": + memory = transformer_memory.RecentTokensMemory( + layer_name + "/recurrent_memory", hparams) + else: + raise ValueError("Unsupported memory type: %s" % hparams.memory_type) + self.recurrent_memory_by_layer[layer_name] = memory + + @property + def has_input(self): + if hasattr(self._hparams, "unconditional") and self._hparams.unconditional: + return False + return super(TransformerMemory, self).has_input + + def _beam_decode(self, features, decode_length, beam_size, top_beams, alpha, + use_tpu=False): + """Overriding beam search because for now only the slow version works with + memory + """ + return self._beam_decode_slow(features, decode_length, beam_size, + top_beams, alpha, use_tpu) + + +@registry.register_hparams +def transformer_base_v1(): + """Set of hyperparameters.""" + hparams = common_hparams.basic_params1() + hparams.norm_type = "layer" + hparams.hidden_size = 512 + hparams.batch_size = 4096 + hparams.max_length = 256 + hparams.clip_grad_norm = 0. # i.e. 
no gradient clipping + hparams.optimizer_adam_epsilon = 1e-9 + hparams.learning_rate_schedule = "legacy" + hparams.learning_rate_decay_scheme = "noam" + hparams.learning_rate = 0.1 + hparams.learning_rate_warmup_steps = 4000 + hparams.initializer_gain = 1.0 + hparams.num_hidden_layers = 6 + hparams.initializer = "uniform_unit_scaling" + hparams.weight_decay = 0.0 + hparams.optimizer_adam_beta1 = 0.9 + hparams.optimizer_adam_beta2 = 0.98 + hparams.num_sampled_classes = 0 + hparams.label_smoothing = 0.1 + hparams.shared_embedding_and_softmax_weights = True + hparams.symbol_modality_num_shards = 16 + + # Add new ones like this. + hparams.add_hparam("filter_size", 2048) + # Layer-related flags. If zero, these fall back on hparams.num_hidden_layers. + hparams.add_hparam("num_encoder_layers", 0) + hparams.add_hparam("num_decoder_layers", 0) + # Attention-related flags. + hparams.add_hparam("num_heads", 8) + hparams.add_hparam("attention_key_channels", 0) + hparams.add_hparam("attention_value_channels", 0) + hparams.add_hparam("ffn_layer", "dense_relu_dense") + hparams.add_hparam("parameter_attention_key_channels", 0) + hparams.add_hparam("parameter_attention_value_channels", 0) + # All hyperparameters ending in "dropout" are automatically set to 0.0 + # when not in training mode. 
+ hparams.add_hparam("attention_dropout", 0.0) + hparams.add_hparam("attention_dropout_broadcast_dims", "") + hparams.add_hparam("relu_dropout", 0.0) + hparams.add_hparam("relu_dropout_broadcast_dims", "") + hparams.add_hparam("pos", "timing") # timing, none + hparams.add_hparam("position_features", "") + hparams.add_hparam("nbr_decoder_problems", 1) + hparams.add_hparam("proximity_bias", False) + hparams.add_hparam("causal_decoder_self_attention", True) + hparams.add_hparam("use_pad_remover", True) + hparams.add_hparam("self_attention_type", "dot_product") + hparams.add_hparam("conv_first_kernel", 3) + hparams.add_hparam("attention_variables_3d", False) + hparams.add_hparam("use_target_space_embedding", True) + # These parameters are only used when ffn_layer=="local_moe_tpu" + hparams.add_hparam("moe_overhead_train", 1.0) + hparams.add_hparam("moe_overhead_eval", 2.0) + hparams.moe_num_experts = 16 + hparams.moe_loss_coef = 1e-3 + # If specified, use this value instead of problem name in metrics.py. + # This is useful for programs that can automatically compare experiments side + # by side based on the same metric names. + hparams.add_hparam("overload_eval_metric_name", "") + # For making a transformer encoder unidirectional by using masked + # attention. + hparams.add_hparam("unidirectional_encoder", False) + # For hard attention. 
+ hparams.add_hparam("hard_attention_k", 0) + hparams.add_hparam("gumbel_noise_weight", 0.0) + return hparams + + +@registry.register_hparams +def transformer_base_v2(): + """Set of hyperparameters.""" + hparams = transformer_base_v1() + hparams.layer_preprocess_sequence = "n" + hparams.layer_postprocess_sequence = "da" + hparams.layer_prepostprocess_dropout = 0.1 + hparams.attention_dropout = 0.1 + hparams.relu_dropout = 0.1 + hparams.learning_rate_warmup_steps = 8000 + hparams.learning_rate = 0.2 + return hparams + + +@registry.register_hparams +def transformer_base_vq_ada_32ex_packed(): + """Set of hyperparameters for lm1b packed following tpu params.""" + hparams = transformer_base_v2() + expert_utils.update_hparams_for_vq_gating(hparams) + hparams.moe_num_experts = 32 + hparams.gating_type = "vq" + # this gives us a batch size of 16 because each seq is len 256 + hparams.batch_size = 5072 + hparams.ffn_layer = "local_moe" + hparams.shared_embedding_and_softmax_weights = False + hparams.learning_rate_warmup_steps = 10000 + # one epoch for languagemodel_lm1b32k_packed = 27200 steps w/ bsize 128 + hparams.learning_rate_decay_steps = 27200 + hparams.num_heads = 4 + hparams.num_blocks = 1 + hparams.moe_k = 1 + hparams.num_decoder_layers = 6 + hparams.label_smoothing = 0. 
+ hparams.layer_prepostprocess_dropout = 0.1 + hparams.layer_postprocess_sequence = "dan" + hparams.layer_preprocess_sequence = "none" + hparams.weight_decay = 1e-06 + hparams.attention_dropout = 0.1 + hparams.optimizer = "Adafactor" + hparams.learning_rate_schedule = "linear_warmup*rsqrt_decay*linear_decay" + hparams.activation_dtype = "float32" + hparams.learning_rate = 0.1 + hparams.learning_rate_constant = 1.0 + return hparams + + +@registry.register_hparams +def transformer_topk_16_packed(): + hparams = transformer_base_vq_ada_32ex_packed() + hparams.gating_type = "topk" + hparams.moe_num_experts = 16 + hparams.moe_k = 2 + return hparams + + +@registry.register_hparams +def transformer_base_vq1_16_nb1_packed_nda_b01_scales(): + """Set of hyperparameters.""" + hparams = transformer_base_vq_ada_32ex_packed() + hparams.use_scales = int(True) + hparams.moe_num_experts = 16 + hparams.moe_k = 1 + hparams.beta = 0.1 + hparams.layer_preprocess_sequence = "n" + hparams.layer_postprocess_sequence = "da" + hparams.ema = False + return hparams + + +@registry.register_hparams +def transformer_base_vq1_16_nb1_packed_dan_b01_scales(): + """Set of hyperparameters.""" + hparams = transformer_base_vq_ada_32ex_packed() + hparams.use_scales = int(True) + hparams.moe_num_experts = 16 + hparams.moe_k = 1 + hparams.beta = 0.1 + hparams.ema = False + return hparams + + +@registry.register_hparams +def transformer_base_vq1_16_nb1_packed_nda_b01_scales_dialog(): + """Set of hyperparameters.""" + hparams = transformer_base_vq1_16_nb1_packed_nda_b01_scales() + hparams.batch_size = 2048 + hparams.max_length = 1024 + hparams.filter_size = 3072 + return hparams + + +@registry.register_hparams +def transformer_ada_lmpackedbase(): + """Set of hyperparameters.""" + hparams = transformer_base_vq_ada_32ex_packed() + hparams.ffn_layer = "dense_relu_dense" + return hparams + + +@registry.register_hparams +def transformer_ada_lmpackedbase_dialog(): + """Set of hyperparameters.""" + hparams = 
transformer_base_vq_ada_32ex_packed() + hparams.max_length = 1024 + hparams.ffn_layer = "dense_relu_dense" + hparams.batch_size = 4096 + return hparams + + +@registry.register_hparams +def transformer_ada_lmpackedbase_relative(): + """Set of hyperparameters.""" + hparams = transformer_base_vq_ada_32ex_packed() + hparams.ffn_layer = "dense_relu_dense" + return hparams + + +@registry.register_hparams +def transformer_base_v3(): + """Base parameters for Transformer model.""" + # Update parameters here, then occasionally cut a versioned set, e.g. + # transformer_base_v2. + hparams = transformer_base_v2() + hparams.optimizer_adam_beta2 = 0.997 + # New way of specifying learning rate schedule. + # Equivalent to previous version. + hparams.learning_rate_schedule = ( + "constant*linear_warmup*rsqrt_decay*rsqrt_hidden_size") + hparams.learning_rate_constant = 2.0 + return hparams + + +@registry.register_hparams +def transformer_base(): + """Base parameters for Transformer model.""" + hparams = transformer_base_v3() + return hparams + + +@registry.register_hparams +def transformer_big(): + """HParams for transformer big model on WMT.""" + hparams = transformer_base() + hparams.hidden_size = 1024 + hparams.filter_size = 4096 + # Reduce batch size to 2048 from 4096 to be able to train the model on a GPU + # with 12 GB memory. For example, NVIDIA TITAN V GPU. 
+ hparams.batch_size = 2048 + hparams.num_heads = 16 + hparams.layer_prepostprocess_dropout = 0.3 + return hparams + + +@registry.register_hparams +def transformer_tall(): + """Hparams for transformer on LM for pretraining/finetuning/mixing.""" + hparams = transformer_base() + hparams.batch_size = 2048 + hparams.hidden_size = 768 + hparams.filter_size = 3072 + hparams.num_hidden_layers = 12 + hparams.num_heads = 12 + hparams.label_smoothing = 0.0 + hparams.max_length = 1024 + hparams.eval_drop_long_sequences = True + hparams.multiproblem_mixing_schedule = "pretrain" + hparams.multiproblem_vocab_size = 65536 + hparams.clip_grad_norm = 1.0 + return hparams + + +@registry.register_hparams +def transformer_tall_finetune_tied(): + """Tied means fine-tune CNN/DM summarization as LM.""" + hparams = transformer_tall() + hparams.multiproblem_max_input_length = 750 + hparams.multiproblem_max_target_length = 100 + hparams.multiproblem_schedule_max_examples = 0 + hparams.learning_rate_schedule = ("linear_warmup*constant*cosdecay") + hparams.learning_rate_constant = 5e-5 + hparams.learning_rate_warmup_steps = 100 + # Set train steps to learning_rate_decay_steps or less + hparams.learning_rate_decay_steps = 80000 + hparams.multiproblem_target_eval_only = True + hparams.multiproblem_reweight_label_loss = True + hparams.multiproblem_label_weight = 1.0 + hparams.optimizer = "true_adam" + return hparams + + +@registry.register_hparams +def transformer_tall_train_tied(): + """Tied means train CNN/DM summarization as LM.""" + hparams = transformer_tall() + hparams.multiproblem_max_input_length = 750 + hparams.multiproblem_max_target_length = 100 + hparams.multiproblem_schedule_max_examples = 0 + hparams.learning_rate_schedule = ("linear_warmup*constant*cosdecay") + hparams.learning_rate_constant = 2e-4 + hparams.learning_rate_warmup_steps = 8000 + # Set train steps to learning_rate_decay_steps or less + hparams.learning_rate_decay_steps = 150000 + hparams.multiproblem_target_eval_only 
= True + hparams.multiproblem_reweight_label_loss = True + hparams.multiproblem_label_weight = 1.0 + hparams.optimizer = "true_adam" + return hparams + + +@registry.register_hparams +def transformer_tall_finetune_uniencdec(): + """Fine-tune CNN/DM with a unidirectional encoder and decoder.""" + hparams = transformer_tall() + hparams.max_input_seq_length = 750 + hparams.max_target_seq_length = 100 + hparams.optimizer = "true_adam" + hparams.learning_rate_schedule = ("linear_warmup*constant*cosdecay") + hparams.learning_rate_decay_steps = 80000 + hparams.learning_rate_constant = 5e-5 + hparams.learning_rate_warmup_steps = 100 + hparams.unidirectional_encoder = True + return hparams + + +@registry.register_hparams +def transformer_tall_train_uniencdec(): + """Train CNN/DM with a unidirectional encoder and decoder.""" + hparams = transformer_tall() + hparams.max_input_seq_length = 750 + hparams.max_target_seq_length = 100 + hparams.optimizer = "true_adam" + hparams.learning_rate_schedule = ("linear_warmup*constant*cosdecay") + hparams.learning_rate_decay_steps = 150000 + hparams.learning_rate_constant = 2e-4 + hparams.unidirectional_encoder = True + return hparams + + +@registry.register_hparams +def transformer_tall_finetune_textclass(): + """Hparams for transformer on LM for finetuning on text class problems.""" + hparams = transformer_tall() + hparams.learning_rate_constant = 6.25e-5 + hparams.learning_rate_schedule = ("linear_warmup*constant*linear_decay") + hparams.multiproblem_schedule_max_examples = 0 + hparams.multiproblem_target_eval_only = True + hparams.learning_rate_warmup_steps = 50 + # Set train steps to learning_rate_decay_steps or less + hparams.learning_rate_decay_steps = 25000 + hparams.multiproblem_reweight_label_loss = True + hparams.multiproblem_label_weight = 0.95 + return hparams + + +@registry.register_hparams +def transformer_tall_pretrain_lm(): + """Hparams for transformer on LM pretraining (with 64k vocab).""" + hparams = transformer_tall() + 
hparams.learning_rate_constant = 2e-4 + hparams.learning_rate_schedule = ("linear_warmup*constant*cosdecay") + hparams.optimizer = "adam_w" + hparams.weight_decay = 0.01 * hparams.learning_rate_constant + hparams.optimizer_adam_beta1 = 0.9 + hparams.optimizer_adam_beta2 = 0.999 + hparams.optimizer_adam_epsilon = 1e-8 + # Set max examples to something big when pretraining only the LM, definitely + # something an order of magnitude bigger than number of train steps. + hparams.multiproblem_schedule_max_examples = 5e8 + # Set train steps to learning_rate_decay_steps or less + hparams.learning_rate_decay_steps = 5000000 + return hparams + + +@registry.register_hparams +def transformer_tall_pretrain_lm_tpu_adafactor(): + """Hparams for transformer on LM pretraining (with 64k vocab) on TPU.""" + hparams = transformer_tall_pretrain_lm() + update_hparams_for_tpu(hparams) + hparams.max_length = 1024 + # For multi-problem on TPU we need it in absolute examples. + hparams.batch_size = 8 + hparams.multiproblem_vocab_size = 2**16 + return hparams + + +@registry.register_hparams +def transformer_tall_pretrain_lm_tpu_adafactor_large(): + """Hparams for transformer on LM pretraining on TPU, large model.""" + hparams = transformer_tall_pretrain_lm_tpu_adafactor() + hparams.hidden_size = 1024 + hparams.num_heads = 16 + hparams.filter_size = 32768 # max fitting in 16G memory is 49152, batch 2 + hparams.batch_size = 4 + hparams.multiproblem_mixing_schedule = "constant" + # Task order: lm/en-de/en-fr/en-ro/de-en/fr-en/ro-en/cnndm/mnli/squad. + hparams.multiproblem_per_task_threshold = "320,80,160,1,80,160,2,20,10,5" + return hparams + + +@registry.register_hparams +def transformer_tall_pretrain_lm_tpu(): + """Hparams for transformer on LM pretraining on TPU with AdamW.""" + hparams = transformer_tall_pretrain_lm_tpu_adafactor() + # Optimizer gets reset in update_hparams_for_tpu so we set it again here. 
+ hparams.learning_rate_constant = 2e-4 + hparams.learning_rate_schedule = ("linear_warmup * constant * cosdecay") + hparams.optimizer = "adam_w" + hparams.weight_decay = 0.01 * hparams.learning_rate_constant + return hparams + + +@registry.register_hparams +def transformer_tall_big(): + """Hparams for transformer on LM+MNLI.""" + hparams = transformer_tall() + hparams.num_hidden_layers = 18 + return hparams + + +@registry.register_hparams +def transformer_big_single_gpu(): + """HParams for transformer big model for single GPU.""" + hparams = transformer_big() + hparams.layer_prepostprocess_dropout = 0.1 + hparams.learning_rate_warmup_steps = 16000 + return hparams + + +@registry.register_hparams +def transformer_base_single_gpu(): + """HParams for transformer base model for single GPU.""" + hparams = transformer_base() + hparams.batch_size = 1024 + hparams.learning_rate_schedule = "constant*linear_warmup*rsqrt_decay" + hparams.learning_rate_constant = 0.1 + hparams.learning_rate_warmup_steps = 16000 + return hparams + + +@registry.register_hparams +def transformer_base_multistep8(): + """HParams for simulating 8 GPUs with MultistepAdam optimizer.""" + hparams = transformer_base() + hparams.optimizer = "multistep_adam" + hparams.optimizer_multistep_accumulate_steps = 8 + return hparams + + +@registry.register_hparams +def transformer_cubbitt(): + """Transformer hyperparameters used in CUBBITT experiments.""" + hparams = transformer_big_single_gpu() + hparams.learning_rate_schedule = "rsqrt_decay" + hparams.batch_size = 2900 + hparams.learning_rate_warmup_steps = 8000 + hparams.max_length = 150 + hparams.layer_prepostprocess_dropout = 0 + hparams.optimizer = "Adafactor" + return hparams + + +@registry.register_hparams +def transformer_parsing_base(): + """HParams for parsing on WSJ only.""" + hparams = transformer_base() + hparams.attention_dropout = 0.2 + hparams.layer_prepostprocess_dropout = 0.2 + hparams.max_length = 512 + hparams.learning_rate_warmup_steps = 
16000 + hparams.hidden_size = 1024 + hparams.learning_rate = 0.05 + hparams.shared_embedding_and_softmax_weights = False + return hparams + + +@registry.register_hparams +def transformer_parsing_big(): + """HParams for parsing on WSJ semi-supervised.""" + hparams = transformer_big() + hparams.max_length = 512 + hparams.shared_source_target_embedding = False + hparams.learning_rate_warmup_steps = 4000 + hparams.layer_prepostprocess_dropout = 0.1 + hparams.batch_size = 2048 + hparams.learning_rate = 0.05 + return hparams + + +@registry.register_hparams +def transformer_parsing_ice(): + """HParams for parsing and tagging Icelandic text.""" + hparams = transformer_base_single_gpu() + hparams.batch_size = 4096 + hparams.shared_embedding_and_softmax_weights = False + return hparams + + +@registry.register_hparams +def transformer_tiny(): + hparams = transformer_base() + hparams.num_hidden_layers = 2 + hparams.hidden_size = 128 + hparams.filter_size = 512 + hparams.num_heads = 4 + return hparams + + +@registry.register_hparams +def transformer_test(): + hparams = transformer_base() + hparams.num_hidden_layers = 2 + hparams.hidden_size = 16 + hparams.filter_size = 8 + hparams.num_heads = 2 + return hparams + + +@registry.register_hparams +def transformer_small(): + hparams = transformer_base() + hparams.num_hidden_layers = 2 + hparams.hidden_size = 256 + hparams.filter_size = 1024 + hparams.num_heads = 4 + return hparams + + +@registry.register_hparams +def transformer_l2(): + hparams = transformer_base() + hparams.num_hidden_layers = 2 + return hparams + + +@registry.register_hparams +def transformer_l4(): + hparams = transformer_base() + hparams.num_hidden_layers = 4 + return hparams + + +@registry.register_hparams +def transformer_l8(): + hparams = transformer_base() + hparams.num_hidden_layers = 8 + return hparams + + +@registry.register_hparams +def transformer_l10(): + hparams = transformer_base() + hparams.num_hidden_layers = 10 + return hparams + + 
+@registry.register_hparams +def transformer_h1(): + hparams = transformer_base() + hparams.num_heads = 1 + return hparams + + +@registry.register_hparams +def transformer_h4(): + hparams = transformer_base() + hparams.num_heads = 4 + return hparams + + +@registry.register_hparams +def transformer_h16(): + hparams = transformer_base() + hparams.num_heads = 16 + return hparams + + +@registry.register_hparams +def transformer_h32(): + hparams = transformer_base() + hparams.num_heads = 32 + return hparams + + +@registry.register_hparams +def transformer_k128(): + hparams = transformer_base() + hparams.attention_key_channels = 128 + return hparams + + +@registry.register_hparams +def transformer_k256(): + hparams = transformer_base() + hparams.attention_key_channels = 256 + return hparams - Args: - d_model: Last/innermost dimension of activation arrays at most points in - the model, including the initial embedding output. - d_ff: Last/innermost dimension of special (typically wider) - :py:class:`Dense` layer in the feedforward part of each block. - dropout: Stochastic rate (probability) for dropping an activation value - when applying dropout within a block. - dropout_shared_axes: Tensor axes on which to share a dropout mask. - Sharing along batch and sequence axes (``dropout_shared_axes=(0,1)``) - is a useful way to save memory and apply consistent masks to activation - vectors at different sequence positions. - mode: If ``'train'``, each block will include dropout; else, it will - pass all values through unaltered. - activation: Type of activation function at the end of each block; must - be an activation-type subclass of :py:class:`Layer`. 
@registry.register_hparams
def transformer_ff1024():
  """Returns transformer_base() with `filter_size` set to 1024."""
  hp = transformer_base()
  hp.filter_size = 1024
  return hp


@registry.register_hparams
def transformer_ff4096():
  """Returns transformer_base() with `filter_size` set to 4096."""
  hp = transformer_base()
  hp.filter_size = 4096
  return hp


@registry.register_hparams
def transformer_dr0():
  """Returns transformer_base() with `layer_prepostprocess_dropout` = 0.0."""
  hp = transformer_base()
  hp.layer_prepostprocess_dropout = 0.0
  return hp


@registry.register_hparams
def transformer_dr2():
  """Returns transformer_base() with `layer_prepostprocess_dropout` = 0.2."""
  hp = transformer_base()
  hp.layer_prepostprocess_dropout = 0.2
  return hp


@registry.register_hparams
def transformer_ls0():
  """Returns transformer_base() with `label_smoothing` set to 0.0."""
  hp = transformer_base()
  hp.label_smoothing = 0.0
  return hp


@registry.register_hparams
def transformer_ls2():
  """Returns transformer_base() with `label_smoothing` set to 0.2."""
  hp = transformer_base()
  hp.label_smoothing = 0.2
  return hp


@registry.register_hparams
def transformer_hs256():
  """Returns transformer_base() with `hidden_size` set to 256."""
  hp = transformer_base()
  hp.hidden_size = 256
  return hp


@registry.register_hparams
def transformer_hs1024():
  """Returns transformer_base() with `hidden_size` set to 1024."""
  hp = transformer_base()
  hp.hidden_size = 1024
  return hp


@registry.register_hparams
def transformer_big_dr1():
  """Returns transformer_base() widened, with pre/post-process dropout 0.1.

  Overrides: hidden_size=1024, filter_size=4096, num_heads=16,
  layer_prepostprocess_dropout=0.1.
  """
  hp = transformer_base()
  hp.hidden_size = 1024
  hp.filter_size = 4096
  hp.num_heads = 16
  hp.layer_prepostprocess_dropout = 0.1
  return hp


@registry.register_hparams
def transformer_big_enfr():
  """Returns transformer_big_dr1() with untied weights and filter_size 8192.

  Overrides: shared_embedding_and_softmax_weights=False, filter_size=8192,
  layer_prepostprocess_dropout=0.1.
  """
  hp = transformer_big_dr1()
  hp.shared_embedding_and_softmax_weights = False
  hp.filter_size = 8192
  # Same value transformer_big_dr1() assigns; kept as an explicit assignment.
  hp.layer_prepostprocess_dropout = 0.1
  return hp


@registry.register_hparams
def transformer_big_enfr_tpu():
  """Returns transformer_big_enfr() with 8 heads, passed through the TPU update.

  `num_heads` is assigned before update_hparams_for_tpu() is applied.
  """
  hp = transformer_big_enfr()
  # For performance, use fewer heads so that matrix dimensions are at least 128
  hp.num_heads = 8
  update_hparams_for_tpu(hp)
  return hp


@registry.register_hparams
def transformer_big_dr2():
  """Returns transformer_big_dr1() with `layer_prepostprocess_dropout` = 0.2."""
  hp = transformer_big_dr1()
  hp.layer_prepostprocess_dropout = 0.2
  return hp
+ +@registry.register_hparams +def transformer_parameter_attention_a(): + hparams = transformer_base() + hparams.ffn_layer = "parameter_attention" + hparams.filter_size = 1536 + return hparams + + +@registry.register_hparams +def transformer_parameter_attention_b(): + hparams = transformer_base() + hparams.ffn_layer = "parameter_attention" + hparams.filter_size = 512 + hparams.parameter_attention_key_channels = 1024 + hparams.parameter_attention_value_channels = 1024 + hparams.num_heads = 16 + return hparams + + +@registry.register_hparams +def transformer_prepend_v2(): + hparams = transformer_base_v2() + hparams.prepend_mode = "prepend_inputs_masked_attention" + hparams.max_length = 0 + return hparams + + +@registry.register_hparams +def transformer_prepend_v1(): + hparams = transformer_base_v1() + hparams.prepend_mode = "prepend_inputs_masked_attention" + hparams.max_length = 0 + return hparams + + +@registry.register_hparams +def transformer_prepend(): + return transformer_prepend_v2() + + +@registry.register_ranged_hparams +def transformer_base_range(rhp): + """Small range of hyperparameters.""" + # After starting from base, set intervals for some parameters. 
+ rhp.set_float("learning_rate", 0.3, 3.0, scale=rhp.LOG_SCALE) + rhp.set_discrete("learning_rate_warmup_steps", + [1000, 2000, 4000, 8000, 16000]) + rhp.set_float("initializer_gain", 0.5, 2.0) + rhp.set_float("optimizer_adam_beta1", 0.85, 0.95) + rhp.set_float("optimizer_adam_beta2", 0.97, 0.99) + rhp.set_float("weight_decay", 0.0, 1e-4) + + +@registry.register_hparams +def transformer_relative(): + """Use relative position embeddings instead of absolute position encodings.""" + hparams = transformer_base() + hparams.pos = None + hparams.self_attention_type = "dot_product_relative" + hparams.max_relative_position = 20 + return hparams + + +@registry.register_hparams +def transformer_relative_tiny(): + hparams = transformer_relative() + hparams.num_hidden_layers = 2 + hparams.hidden_size = 128 + hparams.filter_size = 512 + hparams.num_heads = 4 + return hparams + + +@registry.register_hparams +def transformer_relative_big(): + hparams = transformer_big() + hparams.pos = None + hparams.self_attention_type = "dot_product_relative" + hparams.max_relative_position = 20 + return hparams + + +@registry.register_hparams +def transformer_timeseries(): + hparams = transformer_small() + hparams.batch_size = 256 + hparams.learning_rate_warmup_steps = 2000 + return hparams + + +@registry.register_hparams +def transformer_mlperf_tpu(): + """HParams for Transformer model on TPU for MLPerf on TPU 2x2.""" + hparams = transformer_base_v3() + hparams.mlperf_mode = True + hparams.symbol_modality_num_shards = 1 + hparams.max_length = 256 # ignored when using "_packed" problems + hparams.batch_size = 2048 # per-chip batch size matches the reference model + hparams.hidden_size = 1024 + hparams.filter_size = 4096 + hparams.num_heads = 16 + hparams.attention_dropout_broadcast_dims = "0,1" # batch, heads + hparams.relu_dropout_broadcast_dims = "1" # length + hparams.layer_prepostprocess_dropout_broadcast_dims = "1" # length + return hparams + + +def update_hparams_for_tpu(hparams): + 
"""Change hparams to be compatible with TPU training.""" + + # Adafactor uses less memory than Adam. + # switch to Adafactor with its recommended learning rate scheme. + hparams.optimizer = "Adafactor" + hparams.learning_rate_schedule = "rsqrt_decay" + hparams.learning_rate_warmup_steps = 10000 + + # Avoid an expensive concat on TPU. + # >1 shards helps with faster parameter distribution on multi-GPU machines + hparams.symbol_modality_num_shards = 1 + + # Adaptive batch sizes and sequence lengths are not supported on TPU. + # Instead, every batch has the same sequence length and the same batch size. + # Longer sequences are dropped and shorter ones are padded. + # + # It is therefore suggested to use a problem where examples have been combined + # to a longer length, e.g. the "_packed" problems. + # + # For problems with variable sequence lengths, this parameter controls the + # maximum sequence length. Longer sequences are dropped and shorter ones + # are padded. + # + # For problems with fixed sequence lengths - e.g. the "_packed" problems, + # this hyperparameter is ignored. + hparams.max_length = 64 + + # TPUs have less memory than GPUs, so decrease the batch size if it's too high + if hparams.batch_size > 2048: + hparams.batch_size = 2048 + + # Using noise broadcast in the dropout layers saves memory during training. 
+ hparams.attention_dropout_broadcast_dims = "0,1" # batch, heads + hparams.relu_dropout_broadcast_dims = "1" # length + hparams.layer_prepostprocess_dropout_broadcast_dims = "1" # length + return hparams + + +@registry.register_hparams +def transformer_tpu(): + """HParams for Transformer model on TPU.""" + hparams = transformer_base() + update_hparams_for_tpu(hparams) + return hparams + + +@registry.register_hparams +def transformer_timeseries_tpu(): + """HParams for running Transformer model on timeseries on TPU.""" + hparams = transformer_timeseries() + update_hparams_for_tpu(hparams) + hparams.batch_size = 256 # revert to value set in transformer_timeseries + return hparams + + +@registry.register_hparams +def transformer_tpu_bf16_activation(): + """HParams for Transformer model with BF16 activation on TPU.""" + hparams = transformer_tpu() + hparams.activation_dtype = "bfloat16" + return hparams + + +@registry.register_hparams +def transformer_fairseq_fp16_activation_big(): + """Hparams intended to mirror those used in arxiv.org/pdf/1806.00187.pdf.""" + hparams = transformer_big() + hparams.activation_dtype = "float16" + hparams.batch_size = 3584 + return hparams + + +@registry.register_hparams +def transformer_packed_tpu(): + """Deprecated alias for transformer_tpu().""" + return transformer_tpu() + + +@registry.register_hparams +def transformer_big_tpu(): + hparams = transformer_big() + update_hparams_for_tpu(hparams) + return hparams + + +@registry.register_hparams +def transformer_tiny_tpu(): + hparams = transformer_tiny() + update_hparams_for_tpu(hparams) + return hparams + + +@registry.register_ranged_hparams +def transformer_tiny_tpu_range(rhp): + """Small range of hyperparameters.""" + rhp.set_float("learning_rate", 0.3, 3.0, scale=rhp.LOG_SCALE) + rhp.set_float("weight_decay", 0.0, 2.0) + + +@registry.register_ranged_hparams +def transformer_tpu_range(rhp): + """Small range of hyperparameters.""" + # After starting from base, set intervals for some 
parameters. + rhp.set_float("learning_rate", 0.3, 3.0, scale=rhp.LOG_SCALE) + rhp.set_discrete("learning_rate_warmup_steps", + [1000, 2000, 4000, 8000, 16000]) + rhp.set_float("initializer_gain", 0.5, 2.0) + rhp.set_float("optimizer_adam_beta1", 0.85, 0.95) + rhp.set_float("optimizer_adam_beta2", 0.97, 0.99) + rhp.set_float("weight_decay", 0.0, 2.0) + + +@registry.register_hparams +def transformer_small_tpu(): + """TPU-friendly version of transformer_small. Returns: - A list of layers that maps vectors to vectors. + an hparams object. """ - def _Dropout(): - return tl.Dropout(rate=dropout, shared_axes=dropout_shared_axes, mode=mode) - - return [ - tl.Dense(d_ff), - activation(), - _Dropout(), - tl.Dense(d_model), - ] + hparams = transformer_small() + update_hparams_for_tpu(hparams) + return hparams + + +@registry.register_hparams +def transformer_clean(): + """No dropout, label smoothing, max_length.""" + hparams = transformer_base_v2() + hparams.label_smoothing = 0.0 + hparams.layer_prepostprocess_dropout = 0.0 + hparams.attention_dropout = 0.0 + hparams.relu_dropout = 0.0 + hparams.max_length = 0 + return hparams + + +@registry.register_hparams +def transformer_clean_big(): + hparams = transformer_clean() + hparams.hidden_size = 1024 + hparams.filter_size = 4096 + return hparams + + +@registry.register_hparams +def transformer_clean_big_tpu(): + hparams = transformer_clean_big() + update_hparams_for_tpu(hparams) + return hparams + + +@registry.register_hparams +def transformer_tpu_with_conv(): + """Cut down on the number of heads, and use convs instead.""" + hparams = transformer_tpu() + hparams.num_heads = 4 # Heads are expensive on TPUs. + hparams.ffn_layer = "conv_relu_conv" + return hparams + + +@registry.register_hparams +def transformer_lm_tpu_0(): + """HParams for training languagemodel_lm1b8k on tpu. 92M Params.""" + hparams = transformer_clean_big() + update_hparams_for_tpu(hparams) + hparams.num_heads = 4 # Heads are expensive on TPUs. 
+ hparams.batch_size = 4096 + hparams.shared_embedding_and_softmax_weights = False + hparams.layer_prepostprocess_dropout = 0.1 + return hparams + + +@registry.register_hparams +def transformer_lm_tpu_1(): + """HParams for training languagemodel_lm1b8k on tpu. 335M Params.""" + hparams = transformer_lm_tpu_0() + hparams.hidden_size = 2048 + hparams.filter_size = 8192 + return hparams + + +@registry.register_hparams +def transformer_librispeech_v1(): + """HParams for training ASR model on LibriSpeech V1.""" + hparams = transformer_base() + + hparams.num_heads = 4 + hparams.filter_size = 1024 + hparams.hidden_size = 256 + hparams.num_encoder_layers = 5 + hparams.num_decoder_layers = 3 + hparams.learning_rate = 0.15 + hparams.batch_size = 6000000 + + librispeech.set_librispeech_length_hparams(hparams) + return hparams + + +@registry.register_hparams +def transformer_librispeech_v2(): + """HParams for training ASR model on LibriSpeech V2.""" + hparams = transformer_base() + + hparams.max_length = 1240000 + hparams.max_input_seq_length = 1550 + hparams.max_target_seq_length = 350 + hparams.batch_size = 16 + hparams.num_decoder_layers = 4 + hparams.num_encoder_layers = 6 + hparams.hidden_size = 384 + hparams.learning_rate = 0.15 + hparams.daisy_chain_variables = False + hparams.filter_size = 1536 + hparams.num_heads = 2 + hparams.ffn_layer = "conv_relu_conv" + hparams.conv_first_kernel = 9 + hparams.weight_decay = 0 + hparams.layer_prepostprocess_dropout = 0.2 + hparams.relu_dropout = 0.2 + + return hparams + + +@registry.register_hparams +def transformer_librispeech_tpu_v1(): + """HParams for training ASR model on Librispeech on TPU v1.""" + hparams = transformer_librispeech_v1() + update_hparams_for_tpu(hparams) + + hparams.batch_size = 16 + librispeech.set_librispeech_length_hparams(hparams) + return hparams + + +@registry.register_hparams +def transformer_librispeech_tpu_v2(): + """HParams for training ASR model on Librispeech on TPU v2.""" + hparams = 
transformer_librispeech_v2() + update_hparams_for_tpu(hparams) + + hparams.batch_size = 16 + librispeech.set_librispeech_length_hparams(hparams) + return hparams + + +@registry.register_hparams +def transformer_librispeech_with_area_attention(): + """HParams for Librispeech ASR on TPU v2 with area attention.""" + hparams = transformer_librispeech_tpu_v2() + hparams.num_area_layers = 3 # area attn on first 3 encoder and decoder layers + hparams.max_area_width = 5 + hparams.area_key_mode = "concat" + hparams.area_value_mode = "sum" + return hparams + + +@registry.register_hparams +def transformer_librispeech(): + """HParams for training ASR model on Librispeech.""" + return transformer_librispeech_v2() + + +@registry.register_hparams +def transformer_librispeech_tpu(): + """HParams for training ASR model on Librispeech on TPU.""" + return transformer_librispeech_tpu_v2() + + +@registry.register_hparams +def transformer_common_voice(): + """HParams for training ASR model on Mozilla Common Voice.""" + return transformer_librispeech() + + +@registry.register_hparams +def transformer_common_voice_tpu(): + """HParams for training ASR model on Mozilla Common Voice on TPU.""" + hparams = transformer_librispeech_tpu() + hparams.batch_size = 8 + return hparams + + +@registry.register_hparams +def transformer_supervised_attention(): + """HParams for supervised attention problems.""" + hparams = transformer_base() + # Attention loss type (KL-divergence or MSE). + hparams.add_hparam("expected_attention_loss_type", "kl_divergence") + # Multiplier to the encoder-decoder expected attention loss. 
+ hparams.add_hparam("expected_attention_loss_multiplier", 1.0) + return hparams + + +@registry.register_hparams +def transformer_tpu_1b(): + """Hparams for machine translation with ~1.1B parameters.""" + hparams = transformer_tpu() + hparams.hidden_size = 2048 + hparams.filter_size = 8192 + hparams.num_hidden_layers = 8 + # smaller batch size to avoid OOM + hparams.batch_size = 1024 + hparams.activation_dtype = "bfloat16" + hparams.weight_dtype = "bfloat16" + # maximize number of parameters relative to computation by not sharing. + hparams.shared_embedding_and_softmax_weights = False + return hparams + + +@registry.register_hparams +def transformer_wikitext103_l4k_v0(): + """HParams for training languagemodel_wikitext103_l4k.""" + hparams = transformer_big() + + # Adafactor uses less memory than Adam. + # switch to Adafactor with its recommended learning rate scheme. + hparams.optimizer = "Adafactor" + hparams.learning_rate_schedule = "rsqrt_decay" + hparams.learning_rate_warmup_steps = 10000 + + hparams.num_heads = 4 + hparams.max_length = 4096 + hparams.batch_size = 4096 + hparams.shared_embedding_and_softmax_weights = False + + hparams.num_hidden_layers = 8 + hparams.attention_dropout = 0.1 + hparams.layer_prepostprocess_dropout = 0.2 + hparams.relu_dropout = 0.1 + hparams.label_smoothing = 0.0 + + # Using noise broadcast in the dropout layers saves memory during training. + hparams.attention_dropout_broadcast_dims = "0,1" # batch, heads + hparams.relu_dropout_broadcast_dims = "1" # length + hparams.layer_prepostprocess_dropout_broadcast_dims = "1" # length + + # Avoid an expensive concat on TPU. 
+ # >1 shards helps with faster parameter distribution on multi-GPU machines + hparams.symbol_modality_num_shards = 1 + + return hparams + + +@registry.register_hparams +def transformer_wikitext103_l4k_memory_v0(): + """HParams for training languagemodel_wikitext103_l4k with memory.""" + hparams = transformer_wikitext103_l4k_v0() + + hparams.split_targets_chunk_length = 64 + hparams.split_targets_max_chunks = 64 + hparams.split_targets_strided_training = True + hparams.add_hparam("memory_type", "transformer_xl") + + # The hparams specify batch size *before* chunking, but we want to have a + # consistent 4K batch size *after* chunking to fully utilize the hardware. + target_tokens_per_batch = 4096 + hparams.batch_size = int(target_tokens_per_batch * ( + hparams.max_length / hparams.split_targets_chunk_length)) # 262144 + + hparams.pos = None + hparams.self_attention_type = "dot_product_relative" + hparams.max_relative_position = 2 * hparams.split_targets_chunk_length + + hparams.add_hparam("unconditional", True) + hparams.add_hparam("recurrent_memory_batch_size", 0) # 0 = try to guess + # By default, cache one chunk only (like Transformer-XL) + hparams.add_hparam("num_memory_items", hparams.split_targets_chunk_length) + + return hparams + + +@registry.register_hparams +def transformer_wikitext103_l16k_memory_v0(): + """HParams for training languagemodel_wikitext103_l16k with memory.""" + hparams = transformer_wikitext103_l4k_memory_v0() + + hparams.max_length = 16384 + hparams.split_targets_chunk_length = 64 + hparams.split_targets_max_chunks = int( + hparams.max_length / hparams.split_targets_chunk_length) + + # The hparams specify batch size *before* chunking, but we want to have a + # consistent 4K batch size *after* chunking to fully utilize the hardware. 
+ target_tokens_per_batch = 4096 + hparams.batch_size = int(target_tokens_per_batch * ( + hparams.max_length / hparams.split_targets_chunk_length)) + + hparams.max_relative_position = 2 * hparams.split_targets_chunk_length + + return hparams + + +@registry.register_hparams +def transformer_cifar10_memory_v0(): + """HParams for training image_cifar10_plain_gen_flat_rev with memory.""" + hparams = transformer_wikitext103_l4k_memory_v0() + + hparams.num_hidden_layers = 6 + + hparams.max_length = 32 * 32 * 3 + hparams.split_targets_chunk_length = 64 * 3 + hparams.split_targets_max_chunks = int( + hparams.max_length / hparams.split_targets_chunk_length) + hparams.num_memory_items = 128 * 3 + + # Since this is an image problem, batch size refers to examples (not tokens) + target_images_per_batch = 4 + hparams.batch_size = int(target_images_per_batch * ( + hparams.max_length / hparams.split_targets_chunk_length)) + + # The recurrent memory needs to know the actual batch size (in sequences) + hparams.recurrent_memory_batch_size = hparams.batch_size + + hparams.max_relative_position = ( + hparams.num_memory_items + hparams.split_targets_chunk_length) + + return hparams + + +@registry.register_hparams +def transformer_imagenet64_memory_v0(): + """HParams for training image_imagenet64_gen_flat_rev with memory.""" + hparams = transformer_cifar10_memory_v0() + + hparams.max_length = 64 * 64 * 3 + hparams.split_targets_chunk_length = 64 * 3 + hparams.split_targets_max_chunks = int( + hparams.max_length / hparams.split_targets_chunk_length) + hparams.num_memory_items = 128 * 3 + + # Since this is an image problem, batch size refers to examples (not tokens) + target_images_per_batch = 2 + hparams.batch_size = int(target_images_per_batch * ( + hparams.max_length / hparams.split_targets_chunk_length)) + + # The recurrent memory needs to know the actual batch size (in sequences) + hparams.recurrent_memory_batch_size = hparams.batch_size + + hparams.max_relative_position = 3072 + + 
return hparams diff --git a/trax/models/transformer_test.py b/trax/models/transformer_test.py index 017b1d4e0..96cdae359 100644 --- a/trax/models/transformer_test.py +++ b/trax/models/transformer_test.py @@ -1,5 +1,5 @@ # coding=utf-8 -# Copyright 2022 The Trax Authors. +# Copyright 2023 The Tensor2Tensor Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,58 +13,418 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Tests for Transformer models.""" +"""Tests for Transformer.""" -import functools - -from absl.testing import absltest -from absl.testing import parameterized +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function import numpy as np -from trax import fastmath -from trax import shapes -from trax.layers import test_utils -from trax.models import transformer - - -class TransformerTest(parameterized.TestCase): - - def test_transformer_lm_forward_shape(self): - vocab_size = 16 - model = transformer.TransformerLM( - vocab_size, d_model=32, d_ff=64, n_layers=2, n_heads=2) - x = np.ones((3, 5)).astype(np.int32) - _, _ = model.init(shapes.signature(x)) - y = model(x) - self.assertEqual(y.shape, (3, 5, vocab_size)) - - def _test_transformer_forward_shape(self, input_vocab_size, - output_vocab_size): - model = transformer.Transformer( - input_vocab_size, output_vocab_size, d_model=32, d_ff=64, - n_encoder_layers=2, n_decoder_layers=2, n_heads=2) - xs = [np.ones((3, 5)).astype(np.int32), np.ones((3, 5)).astype(np.int32)] - _, _ = model.init(shapes.signature(xs)) - y, _ = model(xs) - - vocab_size = output_vocab_size or input_vocab_size - self.assertEqual(y.shape, (3, 5, vocab_size)) - - @parameterized.named_parameters( - ('same_vocab', 16, None), - ('same_size', 16, 16), - ('different_size', 16, 50)) - def test_transformer_forward_shape(self, 
input_vocab_size, output_vocab_size): - """Run the Transformer forward and check output shape.""" - self._test_transformer_forward_shape(input_vocab_size, output_vocab_size) - - - def test_dot_product_causal_attention_fast_inference(self): - model_fn = functools.partial( - transformer.TransformerLM, d_model=4, d_ff=8, n_layers=2, n_heads=2 - ) - test_utils.test_eval_equals_predict_discrete(model_fn) - - -if __name__ == '__main__': - absltest.main() +from tensor2tensor.data_generators import librispeech +from tensor2tensor.data_generators import problem_hparams +from tensor2tensor.models import transformer + +import tensorflow.compat.v1 as tf +from tensorflow.compat.v1 import estimator as tf_estimator + + +BATCH_SIZE = 3 +INPUT_LENGTH = 5 +TARGET_LENGTH = 7 +VOCAB_SIZE = 10 + + +def get_model(hparams=None, mode=tf_estimator.ModeKeys.TRAIN, + has_input=True, model_cls=transformer.Transformer): + if hparams is None: + hparams = transformer.transformer_tiny() + hparams.hidden_size = 8 + hparams.filter_size = 32 + hparams.num_heads = 1 + hparams.layer_prepostprocess_dropout = 0.0 + + if hparams.get("problem_hparams", None) is None: + p_hparams = problem_hparams.test_problem_hparams(VOCAB_SIZE, + VOCAB_SIZE, + hparams) + if not has_input: + del p_hparams.modality["inputs"] + hparams.problem_hparams = p_hparams + + inputs = np.random.randint( + VOCAB_SIZE, size=(BATCH_SIZE, INPUT_LENGTH, 1, 1)) + targets = np.random.randint( + VOCAB_SIZE, size=(BATCH_SIZE, TARGET_LENGTH, 1, 1)) + features = { + "targets": tf.constant(targets, dtype=tf.int32, name="targets"), + "target_space_id": tf.constant(1, dtype=tf.int32) + } + if has_input: + features["inputs"] = tf.constant(inputs, dtype=tf.int32, name="inputs") + + return model_cls(hparams, mode, p_hparams), features + + +def small_librispeech_model(param_overrides=None): + hparams = transformer.transformer_small() + hparams.hidden_size = 8 + hparams.filter_size = 32 + hparams.num_heads = 1 + hparams.layer_prepostprocess_dropout = 
0.0 + p_hparams = librispeech.Librispeech().get_hparams(hparams) + p_hparams.vocab_size["targets"] = VOCAB_SIZE + hparams.problem_hparams = p_hparams + model = transformer.Transformer(hparams, problem_hparams=p_hparams) + if param_overrides is not None: # Add or Set any provided HParams + assert isinstance(param_overrides, dict) + for param_name in param_overrides: + if hasattr(hparams, param_name): + hparams.set_hparam(param_name, param_overrides[param_name]) + else: + hparams.add_hparam(param_name, param_overrides[param_name]) + inputs = np.random.rand( + BATCH_SIZE, INPUT_LENGTH, 80, 3).astype("float32") # modify for speech + targets = np.random.randint( + VOCAB_SIZE, size=(BATCH_SIZE, TARGET_LENGTH, 1, 1)) + features = { + "inputs": tf.constant(inputs, dtype=tf.float32, name="inputs"), + "targets": tf.constant(targets, dtype=tf.int32, name="targets"), + "target_space_id": tf.constant(1, dtype=tf.int32) + } + return model, features + + +class TransformerTest(tf.test.TestCase): + + def testTransformer(self, get_model_fn=None, p=None): + if get_model_fn: + model, features = get_model_fn(param_overrides=p) + else: + model, features = get_model(transformer.transformer_small()) + logits, _ = model(features) + with self.test_session() as session: + session.run(tf.global_variables_initializer()) + res = session.run(logits) + self.assertEqual(res.shape, (BATCH_SIZE, TARGET_LENGTH, 1, 1, VOCAB_SIZE)) + + def testTransformerLibrispeech(self, params=None): + self.testTransformer(get_model_fn=small_librispeech_model, p=params) + + def testLibrispeechSlowVsFast(self, params=None): + self.testSlowVsFast(get_model_fn=small_librispeech_model, p=params) + + def testLibrispeechMultihead(self, params=None): + self.testTransformerLibrispeech({"num_heads": 2}) + + def testLibrispeechWithAreaAttention(self): + self.testTransformerLibrispeech({"max_area_width": 2, + "num_area_layers": 1, + "area_key_mode": "mean", + "area_value_mode": "sum"}) + + def testTransformerRelative(self): + 
model, features = get_model(transformer.transformer_relative_tiny()) + logits, _ = model(features) + with self.test_session() as session: + session.run(tf.global_variables_initializer()) + res = session.run(logits) + self.assertEqual(res.shape, (BATCH_SIZE, TARGET_LENGTH, 1, 1, VOCAB_SIZE)) + + def testSlowVsFast(self, get_model_fn=None, p=None): + if get_model_fn: + model, features = get_model_fn(param_overrides=p) + else: + model, features = get_model(transformer.transformer_small()) + + decode_length = 3 + + out_logits, _ = model(features) + out_logits = tf.squeeze(out_logits, axis=[2, 3]) + loss = tf.nn.sparse_softmax_cross_entropy_with_logits( + logits=tf.reshape(out_logits, [-1, VOCAB_SIZE]), + labels=tf.reshape(features["targets"], [-1])) + loss = tf.reduce_mean(loss) + apply_grad = tf.train.AdamOptimizer(0.001).minimize(loss) + + with self.test_session(): + tf.global_variables_initializer().run() + for _ in range(100): + apply_grad.run() + + model.set_mode(tf_estimator.ModeKeys.PREDICT) + + with tf.variable_scope(tf.get_variable_scope(), reuse=True): + greedy_result = model._slow_greedy_infer( + features, decode_length)["outputs"] + greedy_result = tf.squeeze(greedy_result, axis=[2, 3]) + + fast_result = model._greedy_infer(features, decode_length)["outputs"] + + with self.test_session(): + greedy_res = greedy_result.eval() + fast_res = fast_result.eval() + + self.assertEqual(fast_res.shape, (BATCH_SIZE, INPUT_LENGTH + decode_length)) + self.assertAllClose(greedy_res, fast_res) + + def testSlowVsFastNoInput(self): + model, features = get_model( + transformer.transformer_small(), has_input=False) + + decode_length = 3 + + out_logits, _ = model(features) + out_logits = tf.squeeze(out_logits, axis=[2, 3]) + loss = tf.nn.sparse_softmax_cross_entropy_with_logits( + logits=tf.reshape(out_logits, [-1, VOCAB_SIZE]), + labels=tf.reshape(features["targets"], [-1])) + loss = tf.reduce_mean(loss) + apply_grad = tf.train.AdamOptimizer(0.001).minimize(loss) + + with 
self.test_session(): + tf.global_variables_initializer().run() + for _ in range(100): + apply_grad.run() + + model.set_mode(tf_estimator.ModeKeys.PREDICT) + + with tf.variable_scope(tf.get_variable_scope(), reuse=True): + slow_result = model._slow_greedy_infer( + features, decode_length)["outputs"] + slow_result = tf.squeeze(slow_result, axis=[2, 3]) + + fast_result = model._greedy_infer(features, decode_length)["outputs"] + + with self.test_session(): + slow_res = slow_result.eval() + fast_res = fast_result.eval() + + self.assertEqual(slow_res.shape, (BATCH_SIZE, decode_length)) + self.assertAllClose(slow_res, fast_res) + + def testBeamDecodeWithRelativeAttention(self): + decode_length = 2 + model, features = get_model(transformer.transformer_relative_tiny()) + model.set_mode(tf_estimator.ModeKeys.PREDICT) + + beam_result = model._beam_decode( + features, decode_length, beam_size=4, top_beams=1, + alpha=1.0)["outputs"] + + with self.test_session(): + tf.global_variables_initializer().run() + beam_result.eval() + + # TODO(petershaw): This test is flaky because the decode may hit EOS before + # getting to the expected length. 
+ # self.assertEqual(beam_res.shape, + # (BATCH_SIZE, INPUT_LENGTH + decode_length)) + + def testBeamVsFast(self): + model, features = get_model(transformer.transformer_small()) + + decode_length = 2 + + out_logits, _ = model(features) + out_logits = tf.squeeze(out_logits, axis=[2, 3]) + loss = tf.nn.sparse_softmax_cross_entropy_with_logits( + logits=tf.reshape(out_logits, [-1, VOCAB_SIZE]), + labels=tf.reshape(features["targets"], [-1])) + loss = tf.reduce_mean(loss) + apply_grad = tf.train.AdamOptimizer(0.001).minimize(loss) + + with self.test_session(): + tf.global_variables_initializer().run() + for _ in range(100): + apply_grad.run() + + model.set_mode(tf_estimator.ModeKeys.PREDICT) + + with tf.variable_scope(tf.get_variable_scope(), reuse=True): + beam_result = model._beam_decode_slow( + features, + decode_length, + beam_size=4, + top_beams=1, + alpha=1.0)["outputs"] + + fast_result = model._beam_decode( + features, + decode_length, + beam_size=4, + top_beams=1, + alpha=1.0)["outputs"] + + with self.test_session(): + beam_res = beam_result.eval() + fast_res = fast_result.eval() + + self.assertAllClose(beam_res, fast_res) + + def testTransformerWithoutProblem(self): + hparams = transformer.transformer_test() + + embedded_inputs = np.random.random_sample( + (BATCH_SIZE, INPUT_LENGTH, 1, hparams.hidden_size)) + embedded_targets = np.random.random_sample( + (BATCH_SIZE, TARGET_LENGTH, 1, hparams.hidden_size)) + + transformed_features = { + "inputs": tf.constant(embedded_inputs, dtype=tf.float32), + "targets": tf.constant(embedded_targets, dtype=tf.float32) + } + + model = transformer.Transformer(hparams) + body_out, _ = model(transformed_features) + + self.assertAllEqual( + body_out.get_shape().as_list(), + [BATCH_SIZE, TARGET_LENGTH, 1, hparams.hidden_size]) + + def testTransformerWithEncoderDecoderAttentionLoss(self): + model, features = get_model( + transformer.transformer_supervised_attention()) + expected_attention_weights = np.random.random_sample( + 
size=(BATCH_SIZE, TARGET_LENGTH, INPUT_LENGTH)) + features["expected_attentions"] = tf.constant( + expected_attention_weights, dtype=tf.float32) + _, extra_loss = model(features) + with self.test_session() as session: + session.run(tf.global_variables_initializer()) + res = session.run(extra_loss["attention_loss"]) + self.assertEqual(res.shape, ()) + + def _create_greedy_infer_model(self): + """Creates model for greedy inference testing. + + Returns: + model: A t2t model. + features: A map of string to tensor. + """ + model, features = get_model(transformer.transformer_small()) + + out_logits, _ = model(features) + out_logits = tf.squeeze(out_logits, axis=[2, 3]) + loss = tf.nn.sparse_softmax_cross_entropy_with_logits( + logits=tf.reshape(out_logits, [-1, VOCAB_SIZE]), + labels=tf.reshape(features["targets"], [-1])) + loss = tf.reduce_mean(loss) + apply_grad = tf.train.AdamOptimizer(0.001).minimize(loss) + + with self.test_session(): + tf.global_variables_initializer().run() + for _ in range(100): + apply_grad.run() + + model.set_mode(tf_estimator.ModeKeys.PREDICT) + + return model, features + + def testGreedySlowTPUVsNonTPU(self): + decode_length = 3 + + model, features = self._create_greedy_infer_model() + + with tf.variable_scope(tf.get_variable_scope(), reuse=True): + slow_result_non_tpu = model._slow_greedy_infer( + features, decode_length)["outputs"] + slow_result_non_tpu = tf.squeeze(slow_result_non_tpu, axis=[2, 3]) + + slow_result_tpu = model._slow_greedy_infer_tpu( + features, decode_length)["outputs"] + slow_result_tpu = tf.squeeze(slow_result_tpu, axis=[2, 3]) + + with self.test_session(): + slow_non_tpu_res = slow_result_non_tpu.eval() + slow_tpu_res = slow_result_tpu.eval() + + self.assertEqual(slow_tpu_res.shape, + (BATCH_SIZE, INPUT_LENGTH + decode_length)) + self.assertAllClose(slow_tpu_res, slow_non_tpu_res) + + def testGreedyFastTPUVsNonTPU(self): + decode_length = 3 + + model, features = self._create_greedy_infer_model() + + with 
tf.variable_scope(tf.get_variable_scope(), reuse=True): + fast_result_non_tpu = model._greedy_infer( + features, decode_length, use_tpu=False)["outputs"] + + fast_result_tpu = model._greedy_infer( + features, decode_length, use_tpu=True)["outputs"] + + with self.test_session(): + fast_non_tpu_res = fast_result_non_tpu.eval() + fast_tpu_res = fast_result_tpu.eval() + + self.assertEqual(fast_tpu_res.shape, + (BATCH_SIZE, INPUT_LENGTH + decode_length)) + self.assertAllClose(fast_tpu_res, fast_non_tpu_res) + + def testGreedyTPUSlowVsFast(self): + decode_length = 3 + + model, features = self._create_greedy_infer_model() + + with tf.variable_scope(tf.get_variable_scope(), reuse=True): + slow_result = model._slow_greedy_infer_tpu( + features, decode_length)["outputs"] + slow_result = tf.squeeze(slow_result, axis=[2, 3]) + + fast_result = model._greedy_infer( + features, decode_length, use_tpu=True)["outputs"] + + with self.test_session(): + slow_res = slow_result.eval() + fast_res = fast_result.eval() + + self.assertEqual(fast_res.shape, + (BATCH_SIZE, INPUT_LENGTH + decode_length)) + self.assertAllClose(fast_res, slow_res) + + +class TransformerScorerTest(tf.test.TestCase): + + def testReturnsScores(self): + model, features = get_model( + mode=tf_estimator.ModeKeys.PREDICT, + model_cls=transformer.TransformerScorer) + infer_out = model.infer(features) + self.assertTrue("outputs" in infer_out) + self.assertTrue("scores" in infer_out) + + with self.test_session() as session: + session.run(tf.global_variables_initializer()) + infer_out = session.run(infer_out) + self.assertEqual((BATCH_SIZE,), infer_out["scores"].shape) + self.assertEqual((BATCH_SIZE, TARGET_LENGTH), infer_out["outputs"].shape) + + def testVarNames(self): + with tf.Graph().as_default(): + model, features = get_model( + mode=tf_estimator.ModeKeys.PREDICT, + model_cls=transformer.TransformerScorer) + _ = model.infer(features) + scorer_vars = [v.name for v in tf.global_variables()] + + with 
tf.Graph().as_default(): + model, features = get_model( + mode=tf_estimator.ModeKeys.EVAL, + model_cls=transformer.TransformerScorer) + _ = model(features) + scorer_eval_vars = [v.name for v in tf.global_variables()] + + with tf.Graph().as_default(): + model, features = get_model( + mode=tf_estimator.ModeKeys.EVAL, + model_cls=transformer.Transformer) + _ = model(features) + transformer_vars = [v.name for v in tf.global_variables()] + + self.assertEqual(sorted(scorer_vars), sorted(transformer_vars)) + self.assertEqual(sorted(scorer_eval_vars), sorted(transformer_vars)) + + +if __name__ == "__main__": + tf.test.main() diff --git a/trax/models/vanilla_gan.py b/trax/models/vanilla_gan.py new file mode 100644 index 000000000..a79a7575f --- /dev/null +++ b/trax/models/vanilla_gan.py @@ -0,0 +1,218 @@ +# coding=utf-8 +# Copyright 2023 The Tensor2Tensor Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Simple Generative Adversarial Model with two linear layers. + +Example of how to create a GAN in T2T. 
+""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensor2tensor.layers import common_hparams +from tensor2tensor.layers import common_layers +from tensor2tensor.utils import registry +from tensor2tensor.utils import t2t_model + +import tensorflow.compat.v1 as tf +from tensorflow.compat.v1 import estimator as tf_estimator + + +def lrelu(input_, leak=0.2, name="lrelu"): + return tf.maximum(input_, leak * input_, name=name) + + +def deconv2d( + input_, output_shape, k_h, k_w, d_h, d_w, stddev=0.02, name="deconv2d"): + """Deconvolution layer.""" + with tf.variable_scope(name): + w = tf.get_variable( + "w", [k_h, k_w, output_shape[-1], input_.get_shape()[-1]], + initializer=tf.random_normal_initializer(stddev=stddev)) + deconv = tf.nn.conv2d_transpose( + input_, w, output_shape=output_shape, strides=[1, d_h, d_w, 1]) + biases = tf.get_variable( + "biases", [output_shape[-1]], initializer=tf.constant_initializer(0.0)) + return tf.reshape(tf.nn.bias_add(deconv, biases), deconv.get_shape()) + + +def reverse_gradient(x): + return -x + tf.stop_gradient(2 * x) + + +class AbstractGAN(t2t_model.T2TModel): + """Base class for all GANs.""" + + def discriminator(self, x, is_training, reuse=False): + """Discriminator architecture based on InfoGAN. + + Args: + x: input images, shape [bs, h, w, channels] + is_training: boolean, are we in train or eval mode. + reuse: boolean, should params be re-used. + + Returns: + out_logit: the output logits (before sigmoid). 
+ """ + hparams = self.hparams + with tf.variable_scope( + "discriminator", reuse=reuse, + initializer=tf.random_normal_initializer(stddev=0.02)): + batch_size, height, width = common_layers.shape_list(x)[:3] + # Mapping x from [bs, h, w, c] to [bs, 1] + net = tf.layers.conv2d(x, 64, (4, 4), strides=(2, 2), + padding="SAME", name="d_conv1") + # [bs, h/2, w/2, 64] + net = lrelu(net) + net = tf.layers.conv2d(net, 128, (4, 4), strides=(2, 2), + padding="SAME", name="d_conv2") + # [bs, h/4, w/4, 128] + if hparams.discriminator_batchnorm: + net = tf.layers.batch_normalization(net, training=is_training, + momentum=0.999, name="d_bn2") + net = lrelu(net) + size = height * width + net = tf.reshape(net, [batch_size, size * 8]) # [bs, h * w * 8] + net = tf.layers.dense(net, 1024, name="d_fc3") # [bs, 1024] + if hparams.discriminator_batchnorm: + net = tf.layers.batch_normalization(net, training=is_training, + momentum=0.999, name="d_bn3") + net = lrelu(net) + return net + + def generator(self, z, is_training, out_shape): + """Generator outputting image in [0, 1].""" + hparams = self.hparams + height, width, c_dim = out_shape + batch_size = hparams.batch_size + with tf.variable_scope( + "generator", + initializer=tf.random_normal_initializer(stddev=0.02)): + net = tf.layers.dense(z, 1024, name="g_fc1") + net = tf.layers.batch_normalization(net, training=is_training, + momentum=0.999, name="g_bn1") + net = lrelu(net) + net = tf.layers.dense(net, 128 * (height // 4) * (width // 4), + name="g_fc2") + net = tf.layers.batch_normalization(net, training=is_training, + momentum=0.999, name="g_bn2") + net = lrelu(net) + net = tf.reshape(net, [batch_size, height // 4, width // 4, 128]) + net = deconv2d(net, [batch_size, height // 2, width // 2, 64], + 4, 4, 2, 2, name="g_dc3") + net = tf.layers.batch_normalization(net, training=is_training, + momentum=0.999, name="g_bn3") + net = lrelu(net) + net = deconv2d(net, [batch_size, height, width, c_dim], + 4, 4, 2, 2, name="g_dc4") + out = 
tf.nn.sigmoid(net) + return common_layers.convert_real_to_rgb(out) + + def losses(self, inputs, generated): + """Return the losses dictionary.""" + raise NotImplementedError + + def body(self, features): + """Body of the model. + + Args: + features: a dictionary with the tensors. + + Returns: + A pair (predictions, losses) where predictions is the generated image + and losses is a dictionary of losses (that get added for the final loss). + """ + features["targets"] = features["inputs"] + is_training = self.hparams.mode == tf_estimator.ModeKeys.TRAIN + + # Input images. + inputs = tf.to_float(features["targets_raw"]) + + # Noise vector. + z = tf.random_uniform([self.hparams.batch_size, + self.hparams.bottleneck_bits], + minval=-1, maxval=1, name="z") + + # Generator output: fake images. + out_shape = common_layers.shape_list(inputs)[1:4] + g = self.generator(z, is_training, out_shape) + + losses = self.losses(inputs, g) # pylint: disable=not-callable + + summary_g_image = tf.reshape( + g[0, :], [1] + common_layers.shape_list(inputs)[1:]) + tf.summary.image("generated", summary_g_image, max_outputs=1) + + if is_training: # Returns an dummy output and the losses dictionary. 
+ return tf.zeros_like(inputs), losses + return tf.reshape(g, tf.shape(inputs)), losses + + def top(self, body_output, features): + """Override the top function to not do anything.""" + return body_output + + +@registry.register_model +class SlicedGan(AbstractGAN): + """Sliced GAN for demonstration.""" + + def losses(self, inputs, generated): + """Losses in the sliced case.""" + is_training = self.hparams.mode == tf_estimator.ModeKeys.TRAIN + def discriminate(x): + return self.discriminator(x, is_training=is_training, reuse=False) + generator_loss = common_layers.sliced_gan_loss( + inputs, reverse_gradient(generated), discriminate, + self.hparams.num_sliced_vecs) + return {"training": - generator_loss} + + def infer(self, *args, **kwargs): # pylint: disable=arguments-differ + del args, kwargs + + try: + num_channels = self.hparams.problem.num_channels + except AttributeError: + num_channels = 1 + + with tf.variable_scope("body/vanilla_gan", reuse=tf.AUTO_REUSE): + hparams = self.hparams + z = tf.random_uniform([hparams.batch_size, hparams.bottleneck_bits], + minval=-1, maxval=1, name="z") + out_shape = (hparams.sample_height, hparams.sample_width, num_channels) + g_sample = self.generator(z, False, out_shape) + return g_sample + + +@registry.register_hparams +def sliced_gan(): + """Basic parameters for a vanilla_gan.""" + hparams = common_hparams.basic_params1() + hparams.optimizer = "adam" + hparams.learning_rate_constant = 0.0002 + hparams.learning_rate_warmup_steps = 500 + hparams.learning_rate_schedule = "constant * linear_warmup" + hparams.label_smoothing = 0.0 + hparams.batch_size = 128 + hparams.hidden_size = 128 + hparams.initializer = "uniform_unit_scaling" + hparams.initializer_gain = 1.0 + hparams.weight_decay = 1e-6 + hparams.kernel_height = 4 + hparams.kernel_width = 4 + hparams.bottleneck_bits = 128 + hparams.add_hparam("discriminator_batchnorm", True) + hparams.add_hparam("num_sliced_vecs", 4096) + return hparams diff --git 
a/trax/models/video/init.py b/trax/models/video/init.py new file mode 100644 index 000000000..ff174dd63 --- /dev/null +++ b/trax/models/video/init.py @@ -0,0 +1,15 @@ +# coding=utf-8 +# Copyright 2023 The Tensor2Tensor Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + diff --git a/trax/models/xception.py b/trax/models/xception.py new file mode 100644 index 000000000..83b6697c6 --- /dev/null +++ b/trax/models/xception.py @@ -0,0 +1,186 @@ +# coding=utf-8 +# Copyright 2023 The Tensor2Tensor Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Xception.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import math +from six.moves import range # pylint: disable=redefined-builtin + +from tensor2tensor.layers import common_hparams +from tensor2tensor.layers import common_layers +from tensor2tensor.utils import registry +from tensor2tensor.utils import t2t_model + +import tensorflow.compat.v1 as tf + + +def residual_block(x, hparams): + """A stack of convolution blocks with residual connection.""" + k = (hparams.kernel_height, hparams.kernel_width) + dilations_and_kernels = [((1, 1), k) for _ in range(3)] + y = common_layers.subseparable_conv_block( + x, + hparams.hidden_size, + dilations_and_kernels, + padding="SAME", + separability=0, + name="residual_block") + x = common_layers.layer_norm(x + y, hparams.hidden_size, name="lnorm") + return tf.nn.dropout(x, 1.0 - hparams.dropout) + + +def xception_internal(inputs, hparams): + """Xception body.""" + with tf.variable_scope("xception"): + cur = inputs + + if cur.get_shape().as_list()[1] > 200: + # Large image, Xception entry flow + cur = xception_entry(cur, hparams.hidden_size) + else: + # Small image, conv + cur = common_layers.conv_block( + cur, + hparams.hidden_size, [((1, 1), (3, 3))], + first_relu=False, + padding="SAME", + force2d=True, + name="small_image_conv") + + for i in range(hparams.num_hidden_layers): + with tf.variable_scope("layer_%d" % i): + cur = residual_block(cur, hparams) + + return xception_exit(cur) + + +def xception_entry(inputs, hidden_dim): + """Xception entry flow.""" + with tf.variable_scope("xception_entry"): + + def xnet_resblock(x, filters, res_relu, name): + """Resblock.""" + with tf.variable_scope(name): + y = common_layers.separable_conv_block( + x, + filters, [((1, 1), (3, 3)), ((1, 1), (3, 3))], + first_relu=True, + padding="SAME", + force2d=True, + name="sep_conv_block") + y = common_layers.pool(y, (3, 3), "MAX", "SAME", strides=(2, 2)) + return y + 
common_layers.conv_block( + x, + filters, [((1, 1), (1, 1))], + padding="SAME", + strides=(2, 2), + first_relu=res_relu, + force2d=True, + name="res_conv0") + + tf.summary.image("inputs", inputs, max_outputs=2) + x = common_layers.conv_block( + inputs, + 32, [((1, 1), (3, 3))], + first_relu=False, + padding="SAME", + strides=(2, 2), + force2d=True, + name="conv0") + x = common_layers.conv_block( + x, 64, [((1, 1), (3, 3))], padding="SAME", force2d=True, name="conv1") + x = xnet_resblock(x, min(128, hidden_dim), True, "block0") + x = xnet_resblock(x, min(256, hidden_dim), False, "block1") + return xnet_resblock(x, hidden_dim, False, "block2") + + +def xception_exit(inputs): + """Xception exit flow.""" + with tf.variable_scope("xception_exit"): + x = inputs + x_shape = x.get_shape().as_list() + if x_shape[1] is None or x_shape[2] is None: + length_float = tf.to_float(tf.shape(x)[1]) + length_float *= tf.to_float(tf.shape(x)[2]) + spatial_dim_float = tf.sqrt(length_float) + spatial_dim = tf.to_int32(spatial_dim_float) + x_depth = x_shape[3] + x = tf.reshape(x, [-1, spatial_dim, spatial_dim, x_depth]) + elif x_shape[1] != x_shape[2]: + spatial_dim = int(math.sqrt(float(x_shape[1] * x_shape[2]))) + if spatial_dim * spatial_dim != x_shape[1] * x_shape[2]: + raise ValueError("Assumed inputs were square-able but they were " + "not. 
Shape: %s" % x_shape) + x = tf.reshape(x, [-1, spatial_dim, spatial_dim, x_depth]) + + x = common_layers.conv_block_downsample(x, (3, 3), (2, 2), "SAME") + return tf.nn.relu(x) + + +@registry.register_model +class Xception(t2t_model.T2TModel): + + def body(self, features): + return xception_internal(features["inputs"], self._hparams) + + +@registry.register_hparams +def xception_base(): + """Set of hyperparameters.""" + hparams = common_hparams.basic_params1() + hparams.batch_size = 128 + hparams.hidden_size = 768 + hparams.dropout = 0.2 + hparams.symbol_dropout = 0.2 + hparams.label_smoothing = 0.1 + hparams.clip_grad_norm = 2.0 + hparams.num_hidden_layers = 8 + hparams.kernel_height = 3 + hparams.kernel_width = 3 + hparams.learning_rate_decay_scheme = "exp" + hparams.learning_rate = 0.05 + hparams.learning_rate_warmup_steps = 3000 + hparams.initializer_gain = 1.0 + hparams.weight_decay = 3.0 + hparams.num_sampled_classes = 0 + hparams.sampling_method = "argmax" + hparams.optimizer_adam_epsilon = 1e-6 + hparams.optimizer_adam_beta1 = 0.85 + hparams.optimizer_adam_beta2 = 0.997 + return hparams + + +@registry.register_hparams +def xception_tiny(): + hparams = xception_base() + hparams.batch_size = 2 + hparams.hidden_size = 64 + hparams.num_hidden_layers = 2 + hparams.learning_rate_decay_scheme = "none" + return hparams + + +@registry.register_hparams +def xception_tiny_tpu(): + hparams = xception_base() + hparams.batch_size = 2 + hparams.num_hidden_layers = 2 + hparams.hidden_size = 128 + hparams.optimizer = "true_adam" + return hparams diff --git a/trax/models/xception_test.py b/trax/models/xception_test.py new file mode 100644 index 000000000..36ca2d1be --- /dev/null +++ b/trax/models/xception_test.py @@ -0,0 +1,66 @@ +# coding=utf-8 +# Copyright 2023 The Tensor2Tensor Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Xception tests.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensor2tensor.data_generators import problem_hparams +from tensor2tensor.layers import modalities +from tensor2tensor.models import xception + +import tensorflow.compat.v1 as tf +from tensorflow.compat.v1 import estimator as tf_estimator + + +class XceptionTest(tf.test.TestCase): + + def _test_xception(self, img_size): + vocab_size = 9 + batch_size = 3 + x = np.random.randint( + 256, size=(batch_size, img_size, img_size, 3)) + y = np.random.randint( + 1, high=vocab_size, size=(batch_size, 1, 1, 1)) + hparams = xception.xception_tiny() + p_hparams = problem_hparams.test_problem_hparams(vocab_size, + vocab_size, + hparams) + p_hparams.modality["inputs"] = modalities.ModalityType.IMAGE + p_hparams.modality["targets"] = modalities.ModalityType.CLASS_LABEL + with self.test_session() as session: + features = { + "inputs": tf.constant(x, dtype=tf.int32), + "targets": tf.constant(y, dtype=tf.int32), + } + model = xception.Xception(hparams, tf_estimator.ModeKeys.TRAIN, p_hparams) + logits, _ = model(features) + session.run(tf.global_variables_initializer()) + res = session.run(logits) + self.assertEqual(res.shape, (batch_size, 1, 1, 1, vocab_size)) + + def testXceptionSmallImage(self): + self._test_xception(img_size=9) + + def testXceptionLargeImage(self): + self._test_xception(img_size=256) + + +if __name__ == "__main__": + tf.test.main()