diff --git a/trax/examples/Terraformer_from_scratch.ipynb b/trax/examples/Terraformer_from_scratch.ipynb new file mode 100644 index 000000000..9e3eaea6c --- /dev/null +++ b/trax/examples/Terraformer_from_scratch.ipynb @@ -0,0 +1,2587 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "Vzsxj2EV3lfL" + }, + "source": [ + "# Scaling Transformers - Sparse Is Enough\n", + "\n", + "Licensed under the Apache License, Version 2.0", + "\n", + "This colab contains all relevant code for the paper \"Sparse is Enough in Scaling Transformers\". We depend on the Trax library and the experiments in the paper were not run with the colab but in a distributed setup with the attached config files -- but with the code below." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "SMmztiOqenFD" + }, + "outputs": [], + "source": [ + "# Imports.\n", + "!pip install --upgrade -q trax==1.3.9\n", + "\n", + "import functools\n", + "import os\n", + "import random\n", + "import time\n", + "import numpy as np\n", + "\n", + "import jax\n", + "import trax\n", + "from trax import layers as tl\n", + "from trax import fastmath\n", + "from trax.fastmath import numpy as jnp\n", + "from trax.supervised import training" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "fi6zzlt15l-d" + }, + "source": [ + "## Main sparse layers\n", + "\n", + "This cell contains the implementation of our main sparse layers:\n", + "* sparse QKV layers\n", + "* sparse feed-forward blocks" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "kbTJBQ_fBz8d" + }, + "outputs": [], + "source": [ + "def SplitLastAxis(num_splits):\n", + " return tl.Fn(f'SplitLastAxis_{num_splits}',\n", + " lambda x: jnp.reshape(x, tuple(x.shape)[:-1] + (num_splits, -1)))\n", + "\n", + "\n", + "def MergeLastTwoAxes():\n", + " return tl.Fn('MergeLastTwoAxes',\n", + " lambda x: jnp.reshape(x, tuple(x.shape)[:-2] + (-1,)))\n", + "\n", + "\n", + "def LocallyConnectedDense(n_modules, n_units, kernel_size=1,\n", + " kernel_initializer=tl.GlorotUniformInitializer(),\n", + " bias_initializer=tl.RandomNormalInitializer(1e-6),\n", + " use_bias=True):\n", + " \"\"\"Layer using LocallyConnected1d for approximation of Dense layer.\n", + "\n", + " The layer splits the last axis of a tensor into `n_modules`, then runs\n", + " LocallyConnected1d (grouped convolution) on all those modules, and\n", + " concatenates their results. It is essentially a locally-sensitive\n", + " approximation of Dense layer, with number of parameters smaller by the factor\n", + " of `n_modules / kernel_size`.\n", + "\n", + " Args:\n", + " n_modules: Indicates how many modules (pixels) should be input and output\n", + " split into for processing.\n", + " n_units: how many outputs (filters) should each module generate.\n", + " kernel_size: The size of the kernel to be used.\n", + " kernel_initializer: Function that creates a matrix of (random) initial\n", + " connection weights `W` for the layer.\n", + " bias_initializer: Function that creates a vector of (random) initial\n", + " bias weights `b` for the layer.\n", + " use_bias: If `True`, compute an affine map `y = Wx + b`; else compute\n", + " a linear map `y = Wx`.\n", + "\n", + " Returns:\n", + " LocallyConnectedDense tl.Layer.\n", + " \"\"\"\n", + " if n_modules == 1:\n", + " return tl.Dense(n_units, kernel_initializer=kernel_initializer,\n", + " bias_initializer=bias_initializer, use_bias=use_bias)\n", + " return tl.Serial(\n", + " SplitLastAxis(n_modules),\n", + " tl.LocallyConnected1d(\n", + " n_units, kernel_size, kernel_initializer=kernel_initializer,\n", + " bias_initializer=bias_initializer, use_bias=use_bias, padding='WRAP'),\n", + " MergeLastTwoAxes())\n", + "\n", + "\n", + "class _RememberPad(tl.Layer):\n", + " \"\"\"Layer which remembers last N elements in predict mode.\"\"\"\n", + "\n", + " def __init__(self, n_items_to_remember, mode):\n", + " \"\"\"Returns a layer which remembers last N elements in predict mode.\n", + "\n", + " For predict mode, the layer remembers last N elements and pads with them.\n", + " For other modes, it pads with zeros. The layer pads/remembers elements from\n", + " the second axis.\n", + "\n", + " Args:\n", + " n_items_to_remember: Number of items to remember/pad with.\n", + " mode: One of `'train'`, `'eval'`, or `'predict'`.\n", + " \"\"\"\n", + " super().__init__(name='_RememberPad')\n", + " self._n_items_to_remember = n_items_to_remember\n", + " self._mode = mode\n", + " self._portal_mask = self.monkey_patched_mask() # pylint: disable=assignment-from-none\n", + "\n", + " def monkey_patched_mask(self):\n", + " # This is necessary for Terraformer model. See comments there.\n", + " # The mask will only be used in Terraformer in predict mode.\n", + " return None\n", + "\n", + " def forward(self, x):\n", + " if self._n_items_to_remember == 0:\n", + " return x\n", + " if self._mode == 'predict':\n", + " x = jnp.concatenate([self.state[0], x], axis=1)\n", + " if self._portal_mask is not None and 'init' in self.state[1]:\n", + " assert x.shape[0] == 1\n", + " mask = self._portal_mask.get_value()\n", + " count_padding = jnp.sum(mask == 0, dtype=jnp.int32)\n", + " self.state = (fastmath.dynamic_slice_in_dim(\n", + " x, x.shape[1] - (self._n_items_to_remember + count_padding),\n", + " self._n_items_to_remember, axis=1), {'forward': ()})\n", + " else:\n", + " self.state = (x[:, -self._n_items_to_remember:, ...], {'forward': ()})\n", + " else:\n", + " pad_widths = [[0, 0] for _ in range(len(x.shape))]\n", + " pad_widths[1][0] = self._n_items_to_remember\n", + " x = jnp.pad(x, pad_width=pad_widths, mode='constant')\n", + " return x\n", + "\n", + " def init_weights_and_state(self, input_signature):\n", + " \"\"\"Initializes this layer's weights.\"\"\"\n", + " if isinstance(input_signature, (list, tuple)):\n", + " input_signature = input_signature[0]\n", + " self.weights = ()\n", + " if self._mode == 'predict':\n", + " shape = list(input_signature.shape)\n", + " shape[1] = self._n_items_to_remember\n", + " self.state = (jnp.zeros(shape, dtype=jnp.float32), {'init': ()})\n", + " else:\n", + " self.state = ()\n", + "\n", + "\n", + "def LocallyConvDense(n_modules, n_units, mode, kernel_size=1,\n", + " length_kernel_size=1):\n", + " \"\"\"Layer using local convolutions for approximation of Dense layer.\n", + "\n", + " The layer splits the last axis of a tensor into `n_modules`, then runs\n", + " a convolution on all those modules, and concatenates their results.\n", + " It is similar to LocallyConnectedDense above, but shares weights.\n", + "\n", + " Args:\n", + " n_modules: Indicates how many modules (pixels) should be input and output\n", + " split into for processing.\n", + " n_units: how many outputs (filters) should each module generate.\n", + " mode: One of `'train'`, `'eval'`, or `'predict'`.\n", + " kernel_size: The size of the kernel to be used.\n", + " length_kernel_size: If \u003e 1, also do causal convolution on the previous axis,\n", + " which is often the sentence length in sequence models.\n", + "\n", + " Returns:\n", + " LocallyConvDense tl.Layer.\n", + " \"\"\"\n", + " if n_modules == 1:\n", + " return tl.Dense(n_units)\n", + " if kernel_size % 2 != 1:\n", + " raise ValueError('Currently we only handle odd kernel sizes.')\n", + " half = (kernel_size - 1) // 2\n", + " pad_widths = [[0, 0], [0, 0], [half, half], [0, 0]]\n", + " return tl.Serial(\n", + " SplitLastAxis(n_modules),\n", + " tl.Fn('Pad', lambda x: jnp.pad(x, pad_width=pad_widths, mode='constant')),\n", + " _RememberPad(length_kernel_size-1, mode=mode),\n", + " tl.Conv(n_units, kernel_size=(length_kernel_size, kernel_size)),\n", + " MergeLastTwoAxes()\n", + " )\n", + "\n", + "\n", + "def RandomLayer(layer_a, layer_b, prob_a):\n", + " \"\"\"Runs `layer_a` with probability `prob_a`, otherwise runs `layer_b`.\"\"\"\n", + " condition = tl.Serial(\n", + " tl.RandomUniform(),\n", + " tl.Fn('SmallerThan', lambda x: x \u003c prob_a)\n", + " )\n", + " return tl.Cond(condition, layer_a, layer_b)\n", + "\n", + "\n", + "def SparseDenseWithOptions(n_units, d_input=None, sparsity_type=None,\n", + " sparsity=0, d_lowrank=None, prob_sparse=None,\n", + " mode=None, use_bias=True, use_bfloat16=False):\n", + " \"\"\"Configurable sparse version of Dense layer.\"\"\"\n", + " if prob_sparse is not None:\n", + " if mode is not None and mode != 'train':\n", + " # For non-training modes, we want to use a sparse variant.\n", + " # This is different than simply prob_sparse being None, as the weights of\n", + " # the model are different.\n", + " prob_sparse = 1.0\n", + " return RandomLayer(\n", + " SparseDenseWithOptions(n_units, d_input, sparsity_type, sparsity,\n", + " d_lowrank, use_bias=use_bias,\n", + " use_bfloat16=use_bfloat16),\n", + " tl.Dense(n_units, use_bias=use_bias, use_bfloat16=use_bfloat16),\n", + " prob_sparse)\n", + "\n", + " if sparsity_type is None or sparsity_type == 'None' or sparsity == 0:\n", + " return tl.Dense(n_units, use_bias=use_bias, use_bfloat16=use_bfloat16)\n", + " if sparsity_type == 'mult':\n", + " return FactoredDense(sparsity, d_input, n_units, use_bias=use_bias,\n", + " use_bfloat16=use_bfloat16)\n", + "\n", + " assert not use_bfloat16 # use_bfloat16 is unsupported for other variants\n", + " if sparsity_type == 'local':\n", + " assert use_bias # use_bias = False is unsupported\n", + " assert n_units % sparsity == 0\n", + " return LocallyConnectedDense(sparsity, n_units/sparsity)\n", + " if sparsity_type == 'local3':\n", + " assert use_bias # use_bias = False is unsupported\n", + " assert n_units % sparsity == 0\n", + " return LocallyConnectedDense(sparsity, n_units/sparsity, kernel_size=3)\n", + "\n", + " raise ValueError('Unknown sparsity type: {}'.format(sparsity_type))\n", + "\n", + "\n", + "def FactoredDense(n_modules, d_in, d_out, use_bias=True, use_bfloat16=False):\n", + " r\"\"\"Returns a Dense-like layer, internally factored to use fewer parameters.\n", + "\n", + " This layer treats an activation vector as if divided into :math:`M`\n", + " subvectors (``n_modules`` 'modules'). It uses this factored view to compute\n", + " a :py:class:`Dense`-like mapping with high mixing/connectivity, but using\n", + " approximately :math:`1/M` the number of weights of a similarly dimensioned\n", + " :py:class:`Dense` layer.\n", + "\n", + " More specifically, each activation vector of dimensionality ``n_in`` is\n", + " multiplied element-wise (a generalized form of gating) with ``n_modules``\n", + " vectors also of dimensionality ``n_in``. The resulting vectors are projected\n", + " to the subvector/module dimensionality ``d_out / n_modules`` via a matrix\n", + " multiply, and finally reshaped back to a single vector of dimensionality\n", + " ``d_out``. Optionally, a bias vector of dimensionality ``d_out`` is added at\n", + " the end. All the above-mentioned non-input objects -- gating vectors,\n", + " projection matrix, and optional bias -- are trainable weights.\n", + "\n", + " Args:\n", + " n_modules: Number by which an activation vector is divided into subvectors\n", + " (modules) for the factored computation.\n", + " d_in: Last/innermost dimension of input array.\n", + " d_out: Last/innermost dimension of output array.\n", + " use_bias: If True, add bias vectors at the end of the layer; else end the\n", + " layer with the matrix multiply.\n", + " use_bfloat16: If True, use bfloat16 weights; else use float32 weights.\n", + " \"\"\"\n", + " if d_out % n_modules != 0:\n", + " raise ValueError(f'Value d_out ({d_out}) must be a multiple of arg '\n", + " f'n_modules ({n_modules}).')\n", + " d_module = d_out // n_modules\n", + "\n", + " def GatingVectors():\n", + " return tl.Weights(tl.RandomNormalInitializer(stddev=0.5),\n", + " shape=[n_modules, d_in],\n", + " use_bfloat16=use_bfloat16)\n", + "\n", + " def ProjectionMatrix():\n", + " return tl.Weights(tl.GlorotUniformInitializer(),\n", + " shape=[d_in, d_module],\n", + " use_bfloat16=use_bfloat16),\n", + "\n", + " def Bias():\n", + " return tl.Weights(tl.RandomNormalInitializer(1e-6),\n", + " shape=[d_out],\n", + " use_bfloat16=use_bfloat16),\n", + "\n", + " layers = [\n", + " GatingVectors(),\n", + " ProjectionMatrix(),\n", + " _GateAndProject(),\n", + " MergeLastTwoAxes(),\n", + " ]\n", + " if use_bias:\n", + " layers += [Bias(), tl.Add()]\n", + "\n", + " return tl.Serial(layers)\n", + "\n", + "\n", + "def _GateAndProject():\n", + " \"\"\"Returns a combined gating+projection layer that saves on memory.\"\"\"\n", + "\n", + " def f(projection, gating, x):\n", + " # Args arrive in reverse order because of how they were put on the stack.\n", + " # Einsum indices: d (d_in), n (n_modules), m (d_module = d_out/n_modules)\n", + " return jnp.einsum('...d,nd,dm-\u003e...nm', x, gating, projection)\n", + "\n", + " return tl.Fn('_GateAndProject', f)\n", + "\n", + "\n", + "def MultiplicativeConvCausalAttention(\n", + " d_feature, n_heads=1, sparsity=None, length_kernel_size=3, dropout=0.0,\n", + " force_no_dropout=False, max_inference_length=2048, share_qk=False,\n", + " output_layer_type='none', v_concat_type='none', mode='train'):\n", + " \"\"\"Returns a layer that maps activations to activations, with causal masking.\n", + "\n", + " Like `CausalAttention`, this layer type represents one pass of multi-head\n", + " self-attention with causal masking rather than padding-based masking. However,\n", + " for computing Q/K/V instead of a Dense layer it combines\n", + " FactoredDense layer with LocallyConvLayer.\n", + "\n", + " Args:\n", + " d_feature: Depth/dimensionality of feature embedding.\n", + " n_heads: Number of attention heads.\n", + " sparsity: The sparsity of the layer; usually it should be equal to n_heads.\n", + " length_kernel_size: Size of convolution kernel on the length dimension.\n", + " dropout: Probababilistic rate for internal dropout applied to attention\n", + " activations (based on query-key pairs) before dotting them with values.\n", + " force_no_dropout: If True, force dropout to be 0.0 independent of the above\n", + " value; used to override some configurations.\n", + " max_inference_length: maximum length for inference.\n", + " share_qk: if True, average Q and K embeddings and share for both Q and K.\n", + " output_layer_type: Which sparse layers to use for processing output from the\n", + " attention mechanism. One of `'none'`, `'mult'`, `'conv'`,\n", + " or `'multconv'`.\n", + " v_concat_type: What kind of concatenation to use when computing V tensor.\n", + " One of `'original'`, `'fixed'`, or `'none'`. `'none'` means using just\n", + " output from mutliplicative layer shared by Q, K, V. `'fixed'` means\n", + " using output from multiplicative layer concatenated, for each module,\n", + " with the layer input. `'original'` means using concatenation without\n", + " properly taking modules into account; this method was used in\n", + " experiments previously, so it is included for backwards-compatibility.\n", + " mode: One of `'train'`, `'eval'`, or `'predict'`.\n", + " \"\"\"\n", + " assert output_layer_type in ['none', 'mult', 'conv', 'multconv']\n", + " assert v_concat_type in ['original', 'fixed', 'none']\n", + "\n", + " dropout = 0.0 if force_no_dropout else dropout\n", + " sparsity = n_heads if sparsity is None else sparsity\n", + " d_module = d_feature // sparsity\n", + "\n", + " output_layers = []\n", + " if 'mult' in output_layer_type:\n", + " output_layers.append(FactoredDense(\n", + " sparsity, d_feature, d_feature))\n", + " if 'conv' in output_layer_type:\n", + " output_layers.append(LocallyConvDense(\n", + " sparsity, d_module, mode=mode, kernel_size=3,\n", + " length_kernel_size=length_kernel_size))\n", + "\n", + " if v_concat_type == 'original':\n", + " # 'original'` uses concatenation without properly taking modules into\n", + " # account; this method was used in experiments previously, so it is included\n", + " # for backwards-compatibility.\n", + " concat_layers = [tl.Concatenate()] # use permuted and original for v\n", + " elif v_concat_type == 'fixed':\n", + " # `'fixed'` uses the output from multiplicative layer concatenated, for each\n", + " # module, with the layer input. This means that every module in Conv layer\n", + " # has access both to parts of embeddings which were used to compute Q/K of\n", + " # this particular module, and it ha access to parts of the embedding which\n", + " # will be modified by this module.\n", + " concat_layers = [\n", + " tl.Parallel(\n", + " tl.Fn('Reshape1', lambda x: jnp.reshape( # pylint: disable=g-long-lambda\n", + " x, (x.shape[0], x.shape[1], sparsity, d_module))),\n", + " tl.Fn('Reshape2', lambda x: jnp.reshape( # pylint: disable=g-long-lambda\n", + " x, (x.shape[0], x.shape[1], sparsity, d_module)))),\n", + " tl.Concatenate(),\n", + " tl.Fn('Reshape3',\n", + " lambda x: jnp.reshape(x, (x.shape[0], x.shape[1], 2*d_feature))),\n", + " ]\n", + " elif v_concat_type == 'none':\n", + " # `'none'` doesn't use concatenation: we throw away the original layer\n", + " # input and pass to Conv only output of shared Multiplicative layer.\n", + " concat_layers = [tl.Select([0], n_in=2)]\n", + "\n", + " if share_qk:\n", + " return tl.Serial(\n", + " tl.Select([0, 0]), # pre-qkv, pre-v-for-concat\n", + " FactoredDense(sparsity, d_feature, d_feature), # shared q k\n", + " tl.Select([0, 0]), # pre-qk, pre-v, pre-v-for-concat\n", + " LocallyConvDense(sparsity, d_module, mode=mode, kernel_size=3,\n", + " length_kernel_size=length_kernel_size),\n", + " tl.SplitIntoHeads(n_heads),\n", + " tl.Select([0, 0]), # use for q and k\n", + " tl.Parallel(\n", + " [],\n", + " [],\n", + " [concat_layers,\n", + " LocallyConvDense(sparsity, d_module, mode=mode, kernel_size=1,\n", + " length_kernel_size=length_kernel_size),\n", + " tl.SplitIntoHeads(n_heads)],\n", + " ),\n", + " tl.DotProductCausalAttention(\n", + " dropout=dropout, max_inference_length=max_inference_length,\n", + " mode=mode),\n", + " tl.MergeHeads(n_heads),\n", + " output_layers,\n", + " )\n", + " return tl.Serial(\n", + " tl.Select([0, 0]), # duplicate activations\n", + " FactoredDense(sparsity, d_feature, d_feature), # shared q, k\n", + " tl.Select([0, 0, 0]), # use for q, k, v\n", + " tl.Parallel(\n", + " [LocallyConvDense(sparsity, d_module, mode=mode, kernel_size=3,\n", + " length_kernel_size=length_kernel_size),\n", + " tl.SplitIntoHeads(n_heads)],\n", + " [LocallyConvDense(sparsity, d_module, mode=mode, kernel_size=3,\n", + " length_kernel_size=length_kernel_size),\n", + " tl.SplitIntoHeads(n_heads)],\n", + " [concat_layers,\n", + " LocallyConvDense(sparsity, d_module, mode=mode, kernel_size=1,\n", + " length_kernel_size=length_kernel_size),\n", + " tl.SplitIntoHeads(n_heads)],\n", + " ),\n", + " tl.DotProductCausalAttention(\n", + " dropout=dropout, max_inference_length=max_inference_length,\n", + " mode=mode),\n", + " tl.MergeHeads(n_heads),\n", + " output_layers,\n", + " )\n", + "\n", + "\n", + "class DotProductCausalAttention(tl.Layer):\n", + " \"\"\"Layer that computes attention strengths by masking out the \"future\".\n", + "\n", + " Causal attention uses masking to prevent a given sequence position from\n", + " attending to positions greater than / following it. This is used, for\n", + " example, when training autoregressive sequence models, or when decoding a\n", + " sequence symbol by symbol.\n", + "\n", + " This layer performs the core per-head attention calculation. The layer\n", + " assumes that any splitting into attention heads precedes it, and that any\n", + " merging of attention heads will follow it.\n", + " \"\"\"\n", + "\n", + " def __init__(self, dropout=0.0, max_inference_length=2048, mode='train'):\n", + " \"\"\"Creates a :py:class:`DotProductCausalAttention` instance.\n", + "\n", + " Args:\n", + " dropout: Probababilistic rate for attention dropout, which overrides\n", + " (sets to zero) some attention strengths derived from query-key\n", + " matching. As a result, on a given forward pass, some value vectors\n", + " don't contribute to the output, analogous to how regular dropout can\n", + " cause some node activations to be ignored. Applies only if layer is\n", + " created in ``'train'`` mode.\n", + " max_inference_length: Maximum sequence length allowed in non-training\n", + " modes.\n", + " mode: One of ``'train'``, ``'eval'``, or ``'predict'``.\n", + " \"\"\"\n", + " super().__init__(n_in=3, n_out=1)\n", + " self._dropout = dropout\n", + " self._mode = mode\n", + " self._max_len = max_inference_length\n", + " self._portal_mask = self.monkey_patched_mask() # pylint: disable=assignment-from-none\n", + "\n", + " def monkey_patched_mask(self):\n", + " # This is necessary for Terraformer model. See comments there.\n", + " # The mask will only be used in Terraformer in predict mode.\n", + " return None\n", + "\n", + " def forward(self, inputs):\n", + " \"\"\"Returns attention-computed activations.\n", + "\n", + " Args:\n", + " inputs: A (queries, keys, values) tuple.\n", + " \"\"\"\n", + " q, k, v = inputs\n", + "\n", + " if self._portal_mask is not None:\n", + " mask_for_predict = self._portal_mask.get_value()\n", + " else:\n", + " mask_for_predict = None\n", + "\n", + " if self._mode == 'predict':\n", + " self.state, mask = _fast_inference_update_state(\n", + " inputs, self.state,\n", + " mask_for_predict=mask_for_predict)\n", + " if self._portal_mask is not None:\n", + " (_, k, v, _) = self.state\n", + " else:\n", + " (k, v, _) = self.state\n", + " else:\n", + " sequence_length = q.shape[-2]\n", + " mask = _causal_mask(sequence_length)\n", + "\n", + " activations, attn_strengths = _per_head_attention(\n", + " q, k, v, mask, dropout=self._dropout, mode=self._mode, rng=self.rng)\n", + " if self._mode == 'viz':\n", + " self.state = attn_strengths\n", + " return activations\n", + "\n", + " def init_weights_and_state(self, input_signature):\n", + " \"\"\"Initializes this layer for fast inference, if in ``'predict'`` mode.\"\"\"\n", + " if self._mode == 'predict':\n", + " self.state = _fast_inference_init_state(\n", + " input_signature, self._max_len,\n", + " predict_mask=self._portal_mask)\n", + " \n", + "def _fast_inference_init_state(input_signature, buffer_length,\n", + " predict_mask=None):\n", + " \"\"\"Returns an initial state for causal attention layer fast inference.\"\"\"\n", + " def zeros_for(batch_size, shape_dtype):\n", + " shape, dtype = shape_dtype.as_tuple()\n", + " d_feature = shape[-1]\n", + " return jnp.zeros((batch_size, buffer_length, d_feature), dtype=dtype)\n", + "\n", + " batch_size = input_signature[0].shape[0]\n", + " k = zeros_for(batch_size, input_signature[1])\n", + " v = zeros_for(batch_size, input_signature[2])\n", + " if predict_mask is not None:\n", + " mask_for_predict = jnp.zeros((buffer_length,)) != 0\n", + " return (mask_for_predict, k, v, jnp.array(0))\n", + " else:\n", + " return (k, v, jnp.array(0))\n", + "\n", + "\n", + "def _fast_inference_update_state(inputs, state, mask_for_predict=None):\n", + " \"\"\"Updates state of a causal attention layer for fast inference.\n", + "\n", + " The layer state stores arrays with cached values of keys and values,\n", + " as well as an index. To make shapes static, keys and values in the state are\n", + " long, and the index indicates where the new keys and values from inputs need\n", + " to be appended.\n", + "\n", + " During update, we append new_keys and new_values to keys and values at\n", + " position given by index. And we increment index by length of new keys.\n", + " We also create a mask to be 1 at appropriate positions (causal mask).\n", + "\n", + " Args:\n", + " inputs: a triple (new_queries, new_keys, new_values)\n", + " state: layer state with (keys, values, index)\n", + " mask_for_predict: mask used for predict mode. This is used only in\n", + " Terraformer.\n", + "\n", + " Returns:\n", + " Updated state and mask to be used.\n", + " \"\"\"\n", + " # Fast inference: run step-by-step, storing the sequence\n", + " # of keys and values calculated so far in state.\n", + " (_, new_k, new_v) = inputs\n", + " if mask_for_predict is not None:\n", + " (state_mask_for_predict, ks, vs, idx) = state\n", + " else:\n", + " (ks, vs, idx) = state\n", + " length = new_k.shape[1]\n", + " ks = fastmath.dynamic_update_slice_in_dim(ks, new_k, idx, axis=1)\n", + " vs = fastmath.dynamic_update_slice_in_dim(vs, new_v, idx, axis=1)\n", + " k_length = ks.shape[1]\n", + "\n", + " # Mask is of shape [1, q_length, k_length].\n", + " # Mask should be true for every pair of (query_token, key_token) such that\n", + " # index of query_token is equal or larger to index of key_token.\n", + " mask = (jnp.reshape(jnp.arange(k_length), (1, 1, k_length))\n", + " \u003c= jnp.reshape(jnp.arange(length) + idx, (1, length, 1)))\n", + " if mask_for_predict is None:\n", + " return (ks, vs, idx + length), mask\n", + " else:\n", + " state_mask_for_predict = fastmath.dynamic_update_slice_in_dim(\n", + " state_mask_for_predict != 0, mask_for_predict.reshape((-1)) != 0, 0,\n", + " axis=0)\n", + "\n", + " state_mask_for_predict = fastmath.dynamic_update_slice_in_dim(\n", + " state_mask_for_predict != 0, jnp.ones((1,)) != 0,\n", + " jnp.sum(mask_for_predict, dtype=jnp.int32), axis=0)\n", + "\n", + " state_mask_for_predict = fastmath.dynamic_update_slice_in_dim(\n", + " state_mask_for_predict != 0, jnp.ones((1,)) != 0, idx, axis=0)\n", + " placeholder = jnp.reshape(state_mask_for_predict != 0,\n", + " (1, 1, mask.shape[2],))\n", + " mask = mask * placeholder\n", + "\n", + " return (state_mask_for_predict, ks, vs, idx + length), mask\n", + "\n", + "\n", + "def _causal_mask(length):\n", + " # Not all backends define jnp.tril. However, using np.tril is inefficient\n", + " # in that it creates a large global constant.\n", + " if fastmath.is_backend(fastmath.Backend.JAX):\n", + " return jnp.tril(jnp.ones((1, length, length), dtype=np.bool_), k=0)\n", + " else:\n", + " return np.tril(np.ones((1, length, length), dtype=np.bool_), k=0)\n", + "\n", + "\n", + "def _per_head_attention(queries, keys, values, mask, dropout, mode, rng):\n", + " \"\"\"Computes new per-head activations via scaled dot-product attention.\n", + "\n", + " This function is the core of the attention mechanism. Given per-head\n", + " ``queries`` (Q), ``keys`` (K), ``values`` (V), and ``mask``, it:\n", + "\n", + " - computes the scaled dot product of each Q-K pair;\n", + " - applies ``mask`` to screen out positions that come from padding tokens\n", + " (indicated by 0 value);\n", + " - [in ``'train'`` mode] applies dropout to Q-K dot products;\n", + " - computes Q-K attention strengths using a per-query softmax of the Q-K dot\n", + " products; and\n", + " - for each query position, combines V vectors according to the Q-K\n", + " attention strengths.\n", + "\n", + " Args:\n", + " queries: Per-head activations representing attention queries.\n", + " keys: Per-head activations representing attention keys.\n", + " values: Per-head activations to be combined by computed attention strengths.\n", + " mask: Mask that distinguishes positions with real content vs. padding.\n", + " dropout: Probababilistic rate for attention dropout, which overrides\n", + " (sets to zero) some attention strengths derived from query-key\n", + " matching. As a result, on a given forward pass, some value vectors\n", + " don't contribute to the output, analogous to how regular dropout can\n", + " cause some node activations to be ignored. Applies only in ``'train'``\n", + " mode.\n", + " mode: One of ``'train'``, ``'eval'``, or ``'predict'``.\n", + " rng: Single-use random number generator (JAX PRNG key).\n", + "\n", + " Returns:\n", + " Tuple of (activations, attn_strengths), where activations are new per-head\n", + " activation vectors and attn_strengths is a matrix of per-head attention\n", + " strengths.\n", + " \"\"\"\n", + " if dropout \u003e= 1.0:\n", + " raise ValueError(f'Dropout rate ({dropout}) must be lower than 1.')\n", + "\n", + " d_feature = queries.shape[-1]\n", + "\n", + " dots = jnp.matmul(queries, jnp.swapaxes(keys, -1, -2)) / jnp.sqrt(d_feature)\n", + " if mask is not None:\n", + " dots = jnp.where(mask,\n", + " dots,\n", + " jnp.full_like(dots, -1e9))\n", + " attn_strengths = (\n", + " jnp.exp(dots - fastmath.logsumexp(dots, axis=-1, keepdims=True)))\n", + " if dropout is not None and dropout \u003e 0.0 and mode == 'train':\n", + " keep = fastmath.random.bernoulli(rng, 1.0 - dropout, attn_strengths.shape)\n", + " attn_strengths = jnp.where(keep,\n", + " attn_strengths / (1.0 - dropout),\n", + " jnp.zeros_like(attn_strengths))\n", + " activations = jnp.matmul(attn_strengths, values).astype(jnp.float32)\n", + " attn_strengths = attn_strengths.astype(jnp.float32)\n", + " return activations, attn_strengths\n", + "\n", + "\n", + "class _RememberInReverse(tl.Layer):\n", + " \"\"\"Layer remembering the input in forward pass. For reversible models.\"\"\"\n", + "\n", + " def __init__(self, output=True):\n", + " \"\"\"Layer remembering the input in forward pass. For reversible models.\n", + "\n", + " During the first pass through the model this layer saves the input as\n", + " state, and returns the input unmodified. During the second pass through the\n", + " model the layer outputs the input from the first pass. This is used to\n", + " combat numerical stability problems in Terraformer. It doesn't do anything\n", + " in non-reversible models.\n", + "\n", + " Args:\n", + " output: Whether to pass the input or not.\n", + " \"\"\"\n", + " n_out = 1 if output else 0\n", + " self._output = output\n", + " super().__init__(name='_RememberInReverse', n_out=n_out)\n", + "\n", + " def forward(self, x):\n", + " if 'running_second_time_yes' in self.state[1]:\n", + " result = self.state[0]\n", + " else:\n", + " result = x\n", + " self.state = (x, {'running_second_time': ()})\n", + "\n", + " if self._output:\n", + " return result\n", + " else:\n", + " return tuple()\n", + "\n", + " def init_weights_and_state(self, input_signature):\n", + " \"\"\"Initializes this layer's weights.\"\"\"\n", + " if isinstance(input_signature, (list, tuple)):\n", + " input_signature = input_signature[0]\n", + " self.weights = ()\n", + " self.state = (jnp.zeros(input_signature.shape, dtype=jnp.int32),\n", + " {'running_second_time': ()})\n", + "\n", + "\n", + "class _RecallQuantMaskInReverse(tl.Layer):\n", + " \"\"\"Layer recalling quant mask from specific _RememberInReverse.\n", + "\n", + " This layer is needed for memory-efficient training of reversible model with\n", + " ff chunking. During forward pass it simply returns minus ones, which are\n", + " ignored in the controller. During reverse_and_grad it returns a quant_mask\n", + " which was memorized (saved to state) by a RememberInReverse layer.\n", + "\n", + " This enable us to save quant_mask right after chunking, and load it again\n", + " (when reversing) right before chunking.\n", + " \"\"\"\n", + "\n", + " def __init__(self, remember_layer, elements):\n", + " self._remember_layer = remember_layer\n", + " self._elements = elements\n", + " super().__init__(name='_RecallQuantMaskInReverse', n_in=1, n_out=2)\n", + "\n", + " def forward(self, x):\n", + " if (self._remember_layer.state and\n", + " 'running_second_time_yes' in self._remember_layer.state[1]):\n", + " # It's reverse_and_grad, so we pull the quant_mask from remembering layer.\n", + " result = self._remember_layer.state[0]\n", + " else:\n", + " result = -jnp.ones((x.shape[0], self._elements), dtype=jnp.int32)\n", + " return (x, result)\n", + "\n", + "\n", + "class _SparseFFController(tl.Layer):\n", + " \"\"\"The controller part of Sparse Feed-Forward layer.\"\"\"\n", + "\n", + " def __init__(self, d_ff, n_elements_in_block, d_lowrank, temperature,\n", + " use_bfloat16, mode, kernel_initializer, bias_initializer,\n", + " also_return_nondiscrete_output):\n", + " \"\"\"Returns a sparse feed-forward block.\"\"\"\n", + " n_out = 2 if also_return_nondiscrete_output else 1\n", + " super().__init__(name=f'_SparseFFController_{d_ff}', n_in=2, n_out=n_out)\n", + " self._use_bfloat16 = use_bfloat16\n", + " self._d_ff = d_ff\n", + " self._d_lowrank = d_lowrank\n", + " # Q: what temperature is actually most useful in training?\n", + " self._temperature = temperature if mode == 'train' else 0.0\n", + " self._mode = mode\n", + " self._n_elements_in_block = n_elements_in_block\n", + " self._kernel_initializer = kernel_initializer\n", + " self._bias_initializer = bias_initializer\n", + " # Helper numbers as d_ff will be divided by n_elements_in_block.\n", + " assert self._d_ff % self._n_elements_in_block == 0\n", + " self._d1 = self._d_ff // self._n_elements_in_block\n", + " self._d2 = self._n_elements_in_block\n", + " self._also_return_nondiscrete_output = also_return_nondiscrete_output\n", + "\n", + " def forward(self, x):\n", + " \"\"\"Executes this layer as part of a forward pass through the model.\n", + "\n", + " Args:\n", + " x: Tensor of same shape and dtype as the input signature used to\n", + " initialize this layer.\n", + "\n", + " Returns:\n", + " Tensor of same shape and dtype as the input.\n", + " \"\"\"\n", + " x, recalled_quant_mask = x\n", + " m1, m2, mb = self.weights\n", + "\n", + " x_shape = x.shape\n", + " x = jnp.reshape(x, [-1, x_shape[-1]]) # Easier to operate on flattened x.\n", + "\n", + " # Q: should we add bias and/or put relu after the low-rank m1 dot?\n", + " # Replacing multiplication and reshape by this einsum brings training speed\n", + " # improvement (see also reshape in initialization).\n", + " mask_logits = jnp.einsum('bd,dl,lxy-\u003ebxy', x, m1, m2) + mb\n", + "\n", + " if self._also_return_nondiscrete_output:\n", + " # Softmax.\n", + " mask_logsumexp = fastmath.logsumexp(mask_logits, axis=-1, keepdims=True)\n", + " log_mask = mask_logits - mask_logsumexp\n", + " mask = jnp.exp(log_mask)\n", + " # Gumbel-softmax with straight-through discretization.\n", + " if self._temperature == 0.0:\n", + " quant_mask = jnp.argmax(log_mask, axis=-1)\n", + " else:\n", + " u = fastmath.random.uniform(self.rng, mask.shape, jnp.float32, 1e-6,\n", + " 1.0 - 1e-6)\n", + " g = -jnp.log(-jnp.log(u))\n", + " quant_mask = jnp.argmax(log_mask + g * self._temperature, axis=-1)\n", + " else:\n", + " quant_mask = jnp.argmax(mask_logits, axis=-1)\n", + "\n", + " if self._mode == 'train':\n", + " # We use recalled_quant_mask if it's different than -1; otherwise\n", + " # we use a quant_mask which we have just computed.\n", + " quant_mask = jnp.where(recalled_quant_mask == -1,\n", + " quant_mask, recalled_quant_mask)\n", + "\n", + " if self._also_return_nondiscrete_output:\n", + " return quant_mask, mask\n", + " else:\n", + " return quant_mask\n", + "\n", + " def init_weights_and_state(self, input_signature):\n", + " \"\"\"Randomly initializes this layer's weights.\"\"\"\n", + " x_input_signature = input_signature[0]\n", + " d_model = x_input_signature.shape[-1]\n", + " shape_m1 = (d_model, self._d_lowrank)\n", + " shape_m2 = (self._d_lowrank, self._d_ff)\n", + " shape_mb = (self._d_ff,)\n", + "\n", + " rng_m1, rng_m2, rng_mb = fastmath.random.split(self.rng, 3)\n", + " m1 = self._kernel_initializer(shape_m1, rng_m1)\n", + " m2 = self._kernel_initializer(shape_m2, rng_m2)\n", + " mb = self._bias_initializer(shape_mb, rng_mb)\n", + " if self._use_bfloat16:\n", + " m1 = m1.astype(jnp.bfloat16)\n", + " m2 = m2.astype(jnp.bfloat16)\n", + " mb = mb.astype(jnp.bfloat16)\n", + "\n", + " # Reshapes below, with einsum in feedforward, improve the training speed.\n", + " m2 = jnp.reshape(m2, [self._d_lowrank, self._d1, self._d2])\n", + " mb = jnp.reshape(mb, [self._d1, self._d2])\n", + "\n", + " self.weights = (m1, m2, mb)\n", + "\n", + "\n", + "class _SparseFFMain(tl.Layer):\n", + " \"\"\"The main (non-controller) part of Sparse Feed-Forward layer.\"\"\"\n", + "\n", + " def __init__(self, d_ff, n_elements_in_block, d_lowrank, quant_prob,\n", + " use_bfloat16, big_weights_in_bfloat16, mode, kernel_initializer,\n", + " bias_initializer, multiply_by_controller_output, kernel_scaling):\n", + " \"\"\"Returns a sparse feed-forward block.\"\"\"\n", + " n_in = 3 if mode == 'train' or multiply_by_controller_output else 2\n", + " super().__init__(name=f'_SparseFFMain_{d_ff}', n_in=n_in, n_out=2)\n", + " self._mode = mode\n", + " self._use_bfloat16 = use_bfloat16\n", + " self._big_weights_in_bfloat16 = big_weights_in_bfloat16\n", + " self._d_ff = d_ff\n", + " self._d_lowrank = d_lowrank\n", + " self._quant_prob = quant_prob\n", + " self._n_elements_in_block = n_elements_in_block\n", + " self._kernel_initializer = kernel_initializer\n", + " self._bias_initializer = bias_initializer\n", + " # Helper numbers as d_ff will be divided by n_elements_in_block.\n", + " assert self._d_ff % self._n_elements_in_block == 0\n", + " self._d1 = self._d_ff // self._n_elements_in_block\n", + " self._d2 = self._n_elements_in_block\n", + " self._multiply_by_controller_output = multiply_by_controller_output\n", + " self._kernel_scaling = kernel_scaling\n", + "\n", + " def forward(self, x):\n", + " \"\"\"Executes this layer as part of a forward pass through the model.\n", + "\n", + " Args:\n", + " x: Tensor of same shape and dtype as the input signature used to\n", + " initialize this layer.\n", + "\n", + " Returns:\n", + " Tensor of same shape and dtype as the input.\n", + " \"\"\"\n", + " if self._mode == 'train' or self._multiply_by_controller_output:\n", + " quant_mask, mask, x = x\n", + " else:\n", + " quant_mask, x = x\n", + " original_quant_mask = quant_mask\n", + "\n", + " w1, w2, b2 = self.weights\n", + "\n", + " if self._mode == 'predict':\n", + " w1 = jnp.transpose(w1, (1, 2, 0)) # dm, d1, d2 -\u003e d1, d2, dm\n", + " w2 = jnp.transpose(w2, (1, 0, 2)) # d2, d1, dm -\u003e d1, d2, dm\n", + " x_shape = x.shape\n", + " x = jnp.reshape(x, [-1, x_shape[-1]]) # Easier to operate on flattened x.\n", + "\n", + " if self._mode == 'train':\n", + " # Tricks from Section 2.1 in https://arxiv.org/abs/1801.09797\n", + " quant_mask = tl.one_hot(quant_mask, self._n_elements_in_block)\n", + " quant_mask = fastmath.stop_gradient(quant_mask)\n", + " quant_mask += mask - fastmath.stop_gradient(mask) # straight-through\n", + " # We will sometimes (quant_prob of the batches) use the soft-mask instead\n", + " # of the quantized mask to improve training stability (see paper above).\n", + " select = fastmath.random.uniform(self.rng, (), jnp.float32, 0.0, 1.0)\n", + " quant_mask = jnp.where(select \u003c self._quant_prob, quant_mask, mask)\n", + "\n", + " # In training, run full matmul to get benefits from the above tricks.\n", + " mid = jnp.einsum('bd,dxy-\u003ebxy', x, w1) * quant_mask\n", + " relu = jnp.where(mid \u003c= 0, jnp.zeros_like(mid), mid)\n", + " if self._multiply_by_controller_output:\n", + " # We multiply only for quantized decisions, since for non-quantized\n", + " # decisions we've already multiplied the output.\n", + " mask_mult = jnp.where(select \u003c self._quant_prob,\n", + " mask, jnp.ones_like(mask))\n", + " # Stop-gradient is here, because we already have a pass-through gradient\n", + " # (for quantized decisions).\n", + " mask_mult = fastmath.stop_gradient(mask_mult)\n", + " relu = relu * mask_mult\n", + " res = jnp.einsum('bxy,yxd-\u003ebd', relu, w2) + b2\n", + " elif self._mode == 'predict':\n", + " # This implementation mimicks inference. It's not efficient for large\n", + " # size of joint_batch, but at inference that will be 1 most of the time.\n", + " # Shapes:\n", + " # quant_mask is [joint_batch, self._d1]\n", + " # w1 is [d_model, self._d1, self._d2]\n", + " # we'll index w1 with advanced numpy indexing, first range over\n", + " # self._d1 times the batch size, second range being quant_mask\n", + " batch_size = quant_mask.shape[0]\n", + " idx1 = jnp.array([jnp.arange(self._d1)] * batch_size)\n", + " # flatten indices and select from w1\n", + " idx1 = jnp.reshape(idx1, [-1])\n", + " idx2 = jnp.reshape(quant_mask, [-1])\n", + " w = w1[idx1, idx2, :] # now we have per-element weights with batch dim\n", + " w = jnp.reshape(w, [batch_size, self._d1, -1])\n", + " mid = jnp.einsum('ai,aji-\u003eaj', x, w)\n", + " relu = jnp.where(mid \u003c= 0, jnp.zeros_like(mid), mid)\n", + " if self._multiply_by_controller_output:\n", + " mask_mult = jnp.take_along_axis(mask, quant_mask[..., None], -1)[..., 0]\n", + " relu = relu * mask_mult\n", + " # w2 is [self._d1, self._d2, d_model]\n", + " v = w2[idx1, idx2, :]\n", + " v = jnp.reshape(v, [batch_size, self._d1, -1])\n", + " res = jnp.einsum('ai,aij-\u003eaj', relu, v) + b2\n", + " else:\n", + " quant_mask = tl.one_hot(quant_mask, self._n_elements_in_block)\n", + " mid = jnp.einsum('bd,dxy-\u003ebxy', x, w1) * quant_mask\n", + " relu = jnp.where(mid \u003c= 0, jnp.zeros_like(mid), mid)\n", + " if self._multiply_by_controller_output:\n", + " relu = relu * mask\n", + " res = jnp.einsum('bxy,yxd-\u003ebd', relu, w2) + b2\n", + "\n", + " return original_quant_mask, jnp.reshape(res, x_shape)\n", + "\n", + " def init_weights_and_state(self, input_signature):\n", + " \"\"\"Randomly initializes this layer's weights.\"\"\"\n", + " d_model = input_signature[-1].shape[-1]\n", + " shape_w1 = (d_model, self._d_ff)\n", + " shape_w2 = (self._d_ff, d_model)\n", + " shape_b2 = (d_model,)\n", + "\n", + " rng_w1, rng_w2, rng_b2 = fastmath.random.split(self.rng, 3)\n", + " if tl.N_WEIGHTS_SHARDS \u003e 1:\n", + " # In sharded-weights mode, put the weights on CPU on init\n", + " # as they will be sharded later.\n", + " w1 = tl.on_cpu(self._kernel_initializer(shape_w1, rng_w1))\n", + " w2 = tl.on_cpu(self._kernel_initializer(shape_w2, rng_w2))\n", + " else:\n", + " w1 = self._kernel_initializer(shape_w1, rng_w1)\n", + " w2 = self._kernel_initializer(shape_w2, rng_w2)\n", + "\n", + " b2 = self._bias_initializer(shape_b2, rng_b2)\n", + " if self._use_bfloat16:\n", + " b2 = b2.astype(jnp.bfloat16)\n", + " if self._use_bfloat16 or self._big_weights_in_bfloat16:\n", + " w1 = w1.astype(jnp.bfloat16)\n", + " w2 = w2.astype(jnp.bfloat16)\n", + "\n", + " w1 = jnp.reshape(w1, (-1, self._d1, self._d2))\n", + " w2 = jnp.reshape(w2, (self._d2, self._d1, -1))\n", + "\n", + " if self._kernel_scaling:\n", + " # This keeps expected variance of the output regardless of N.\n", + " w2 = w2 * (self._n_elements_in_block ** 0.5)\n", + "\n", + " self.weights = (w1, w2, b2)\n", + "\n", + "\n", + "def SparseFF(\n", + " d_ff, n_elements_in_block=32, d_lowrank=64, temperature=0.1, quant_prob=0.3,\n", + " use_bfloat16=False, big_weights_in_bfloat16=False, mode='train',\n", + " kernel_initializer=tl.GlorotUniformInitializer(),\n", + " bias_initializer=tl.RandomNormalInitializer(1e-6),\n", + " dropout_rate=0.0, dropout_shared_axes=None, ff_chunk_size=0,\n", + " multiply_by_controller_output=False, kernel_scaling=False):\n", + " \"\"\"Returns Feed-forward block with sparsity.\n", + "\n", + " The original (non-sparse) FF block is a triple Dense(d_ff)-Relu-Dense\n", + " that takes an input, makes it of size d_ff (usually larger than it was) and\n", + " then brings it back to the original size after Relu. It is commonly used in\n", + " Transformer models where it often accounts for most of the trainable weights.\n", + "\n", + " The original block can be slow in decoding due to the need to fetch a lot of\n", + " weights from memory. This sparse block only allows one non-zero element\n", + " in a block of a specified size. This is trained with straight-through Gumbel\n", + " softmax trick.\n", + "\n", + " Args:\n", + " d_ff: Depth/dimensionality of FeedForward layer.\n", + " n_elements_in_block: The sparsity level. The layer is divided into blocks of\n", + " this size, and each block has only a single element active.\n", + " d_lowrank: The dimensionality of low-rank controller.\n", + " temperature: The temperature of the controller during training.\n", + " quant_prob: During training this proportion of blocks will have quantized\n", + " mask (i.e. a single element active). The rest will use a soft mask.\n", + " use_bfloat16: Whether to use bfloat16 for weights.\n", + " big_weights_in_bfloat16: : Whether to use bfloat16 for main weights of the\n", + " FeedForward layer.\n", + " mode: One of `'train'`, `'eval'`, or `'predict'`.\n", + " kernel_initializer: Function that creates a matrix of (random) initial\n", + " connection weights `W` for the layer.\n", + " bias_initializer: Function that creates a vector of (random) initial\n", + " bias weights `b` for the layer.\n", + " dropout_rate: Probability for dropping an activation value.\n", + " dropout_shared_axes: Tensor axes on which to share a dropout mask. Sharing\n", + " along batch and sequence axes (`dropout_shared_axes=(0,1)`) is a useful\n", + " way to save memory and apply consistent masks to activation vectors at\n", + " different sequence positions.\n", + " ff_chunk_size: int; if \u003e 0, chunk feed-forward into this-sized chunks.\n", + " multiply_by_controller_output: whether to multiply the middle activation\n", + " layer of FF by controller output (i.e. softmax).\n", + " kernel_scaling: Whether to scale the kernel matrix (during init) to keep the\n", + " variance of the layer output regardless of n_elements_in_block.\n", + " \"\"\"\n", + "\n", + " if mode == 'train' or multiply_by_controller_output:\n", + " also_return_nondiscrete_output = True\n", + " else:\n", + " also_return_nondiscrete_output = False\n", + " controller = _SparseFFController(\n", + " d_ff=d_ff, n_elements_in_block=n_elements_in_block,\n", + " d_lowrank=d_lowrank, temperature=temperature,\n", + " use_bfloat16=use_bfloat16, mode=mode,\n", + " kernel_initializer=kernel_initializer,\n", + " bias_initializer=bias_initializer,\n", + " also_return_nondiscrete_output=also_return_nondiscrete_output)\n", + "\n", + " main = [\n", + " _SparseFFMain(\n", + " d_ff=d_ff, n_elements_in_block=n_elements_in_block,\n", + " d_lowrank=d_lowrank, quant_prob=quant_prob, use_bfloat16=use_bfloat16,\n", + " big_weights_in_bfloat16=big_weights_in_bfloat16, mode=mode,\n", + " kernel_initializer=kernel_initializer,\n", + " bias_initializer=bias_initializer,\n", + " multiply_by_controller_output=multiply_by_controller_output,\n", + " kernel_scaling=kernel_scaling),\n", + " # quant_mask, emb\n", + " tl.Select([1, 0]),\n", + " # emb, quant_mask\n", + " tl.Dropout(rate=dropout_rate, shared_axes=dropout_shared_axes, mode=mode),\n", + " tl.Select([1, 0]),\n", + " # quant_mask, emb\n", + " ]\n", + "\n", + " # We will \"remember\" quant_mask _after_ chunking, and \"recall\" this same\n", + " # quant_mask during reverse_and_grad _before_ chunking.\n", + " remembering = _RememberInReverse(output=False)\n", + " recalling = _RecallQuantMaskInReverse(\n", + " remember_layer=remembering, elements=d_ff//n_elements_in_block)\n", + "\n", + " return tl.BatchLeadingAxes(tl.Serial(\n", + " recalling, # emb, quant_mask\n", + " tl.Chunk(chunk_size=ff_chunk_size, layer=tl.Serial(\n", + " # emb, quant_mask\n", + " tl.Select((0, 1, 0)), # emb, quant_mask, emb\n", + " controller, # quant_mask, mask, emb\n", + " main, # quant_mask, emb/output\n", + " )),\n", + " remembering, # emb/output\n", + " ))\n", + "\n", + "\n", + "class BlockSparseFF(tl.Layer):\n", + " \"\"\"Feed-forward block with block sparsity.\n", + "\n", + " The original (non-sparse) FF block is a triple Dense(d_ff)-Relu-Dense\n", + " that takes an input, makes it of size d_ff (usually larger than it was) and\n", + " then brings it back to the original size after Relu. It is commonly used in\n", + " Transformer models where it often accounts for most of the trainable weights.\n", + "\n", + " This block sparse layer mimics mixture of experts architecture.\n", + " It divides the dimension of d_ff in each weight matrix to # of blocks equal to\n", + " n_experts and activates only one non-zero block from the weights matrix.\n", + " This is trained with straight-through Gumbel softmax trick.\n", + " \"\"\"\n", + "\n", + " def __init__(self,\n", + " d_ff,\n", + " n_experts=64,\n", + " temperature=0.7,\n", + " mode='train',\n", + " kernel_initializer=tl.GlorotUniformInitializer(),\n", + " bias_initializer=tl.RandomNormalInitializer(1e-6)):\n", + " \"\"\"Returns a block sparse feed-forward block.\"\"\"\n", + " super().__init__(name=f'BlockSparseFF_{d_ff}')\n", + " self._mode = mode\n", + " self._d_ff = d_ff\n", + " self._n_experts = n_experts\n", + " self._temperature = temperature if mode == 'train' else 0.0\n", + " self._n_elements_in_block = d_ff // n_experts\n", + " self._kernel_initializer = kernel_initializer\n", + " self._bias_initializer = bias_initializer\n", + " assert self._d_ff % self._n_experts == 0\n", + "\n", + " def forward(self, x):\n", + " \"\"\"Executes this layer as part of a forward pass through the model.\n", + "\n", + " Args:\n", + " x: Tensor of same shape and dtype as the input signature used to\n", + " initialize this layer.\n", + "\n", + " Returns:\n", + " Tensor of same shape and dtype as the input.\n", + " \"\"\"\n", + " m1, w1, w2, b2 = self.weights\n", + " x_shape = x.shape\n", + " x = jnp.reshape(x, [-1, x_shape[-1]]) # Easier to operate on flattened x.\n", + "\n", + " # Q: check if we need bias and/or put relu after the m1 dot?\n", + " mask_logits = jnp.dot(x, m1)\n", + " # Softmax.\n", + " mask_logsumexp = fastmath.logsumexp(mask_logits, axis=-1, keepdims=True)\n", + " log_mask = mask_logits - mask_logsumexp\n", + " mask = jnp.exp(log_mask)\n", + " # Gumbel-softmax with straight-through discretization.\n", + " rng1, rng2 = fastmath.random.split(self.rng, 2)\n", + " u = fastmath.random.uniform(rng1, mask.shape, jnp.float32, 1e-6, 1.0 - 1e-6)\n", + " g = -jnp.log(-jnp.log(u))\n", + " selected_experts = jnp.argmax(log_mask + g * self._temperature, axis=-1)\n", + " if self._mode == 'train':\n", + " # Tricks from Section 2.1 in https://arxiv.org/abs/1801.09797\n", + " quant_mask = tl.one_hot(selected_experts, self._n_experts)\n", + " quant_mask = fastmath.stop_gradient(quant_mask)\n", + " quant_mask += mask - fastmath.stop_gradient(mask) # straight-through\n", + " # We will sometimes (50% of the batches) use the soft-mask instead of\n", + " # the quantized mask to improve training stability (see the paper above).\n", + " # Q: is selecting 50% of batches the best? Other %? Mixed in-batch?\n", + " select = fastmath.random.uniform(rng2, (), jnp.float32, -1.0, 1.0)\n", + " quant_mask = jnp.where(select \u003e 0.0, quant_mask, mask)\n", + " else:\n", + " quant_mask = tl.one_hot(selected_experts, self._n_experts)\n", + " quant_mask = jnp.reshape(quant_mask, [-1, self._n_experts, 1])\n", + " batch_size = quant_mask.shape[0]\n", + "\n", + " if self._mode == 'predict' and batch_size == 1:\n", + " # This implementation mimicks inference for batch_size 1.\n", + " start_idx = selected_experts[0] * self._n_elements_in_block\n", + " # w1 is [d_model, d_ff], w is [d_model, n_elements_in_block]\n", + " w = fastmath.dynamic_slice(w1, [0, start_idx],\n", + " [w1.shape[0], self._n_elements_in_block])\n", + " mid = jnp.dot(x, w)\n", + " relu = jnp.where(mid \u003c= 0, jnp.zeros_like(mid), mid)\n", + " # w2 is [d_ff, d_model], v is [n_elements_in_block, d_model]\n", + " v = fastmath.dynamic_slice(w2, [start_idx, 0],\n", + " [self._n_elements_in_block, w2.shape[-1]])\n", + " v = jnp.reshape(v, [self._n_elements_in_block, -1])\n", + " res = jnp.dot(relu, v) + b2\n", + " else:\n", + " expanded_mask = jnp.broadcast_to(\n", + " quant_mask,\n", + " (quant_mask.shape[0], quant_mask.shape[1], self._n_elements_in_block))\n", + " expanded_mask = jnp.reshape(expanded_mask, (-1, self._d_ff))\n", + " mid = jnp.dot(x, w1) * expanded_mask # [joint_batch, d_ff]\n", + " relu = jnp.where(mid \u003c= 0, jnp.zeros_like(mid), mid)\n", + " res = jnp.dot(relu, w2) + b2\n", + "\n", + " return jnp.reshape(res, x_shape) # un-flatten if needed\n", + "\n", + " def init_weights_and_state(self, input_signature):\n", + " \"\"\"Randomly initializes this layer's weights.\"\"\"\n", + " d_model = input_signature.shape[-1]\n", + " shape_m1 = (d_model, self._n_experts)\n", + " shape_w1 = (d_model, self._d_ff)\n", + " shape_w2 = (self._d_ff, d_model)\n", + " shape_b2 = (d_model,)\n", + "\n", + " rng_m1, rng_w1, rng_w2, rng_b2 = fastmath.random.split(self.rng, 4)\n", + " m1 = self._kernel_initializer(shape_m1, rng_m1)\n", + " w1 = self._kernel_initializer(shape_w1, rng_w1)\n", + " w2 = self._kernel_initializer(shape_w2, rng_w2)\n", + " b2 = self._bias_initializer(shape_b2, rng_b2)\n", + "\n", + " self.weights = (m1, w1, w2, b2)\n", + "\n", + "\n", + "class SwitchSparseFF(tl.Layer):\n", + " \"\"\"Feed-forward block with switch-style block sparsity.\n", + "\n", + " The original (non-sparse) FF block is a triple Dense(d_ff)-Relu-Dense\n", + " that takes an input, makes it of size d_ff (usually larger than it was) and\n", + " then brings it back to the original size after Relu. It is commonly used in\n", + " Transformer models where it often accounts for most of the trainable weights.\n", + "\n", + " This block sparse layer mimics mixture of experts architecture.\n", + " It divides the dimension of d_ff in each weight matrix to # of blocks equal to\n", + " n_experts and activates only one non-zero block from the weights matrix.\n", + " This is trained with methods following the Switch Transformer.\n", + " \"\"\"\n", + "\n", + " def __init__(self,\n", + " d_ff,\n", + " n_experts=64,\n", + " temperature=0.1,\n", + " mode='train',\n", + " kernel_initializer=tl.GlorotUniformInitializer(),\n", + " bias_initializer=tl.RandomNormalInitializer(1e-6)):\n", + " \"\"\"Returns a switch-style training block sparse feed-forward block.\"\"\"\n", + " super().__init__(name=f'SwitchSparseFF_{d_ff}')\n", + " self._mode = mode\n", + " self._d_ff = d_ff\n", + " self._n_experts = n_experts\n", + " self._temperature = temperature if mode == 'train' else 0.0\n", + " self._n_elements_in_block = d_ff // n_experts\n", + " self._kernel_initializer = kernel_initializer\n", + " self._bias_initializer = bias_initializer\n", + " assert self._d_ff % self._n_experts == 0\n", + "\n", + " def forward(self, x):\n", + " \"\"\"Executes this layer as part of a forward pass through the model.\n", + "\n", + " Args:\n", + " x: Tensor of same shape and dtype as the input signature used to\n", + " initialize this layer.\n", + "\n", + " Returns:\n", + " Tensor of same shape and dtype as the input.\n", + " \"\"\"\n", + " m1, w1, w2, b2 = self.weights\n", + " x_shape = x.shape\n", + " x = jnp.reshape(x, [-1, x_shape[-1]]) # Easier to operate on flattened x.\n", + "\n", + " # Q: check if we need bias and/or put relu after the m1 dot?\n", + " mask_logits = jnp.dot(x, m1)\n", + " # Softmax.\n", + " mask_logsumexp = fastmath.logsumexp(mask_logits, axis=-1, keepdims=True)\n", + " log_mask = mask_logits - mask_logsumexp\n", + " mask = jnp.exp(log_mask)\n", + " # Gumbel noise to allow sampling from the softmax.\n", + " rng1, _ = fastmath.random.split(self.rng, 2)\n", + " u = fastmath.random.uniform(rng1, mask.shape, jnp.float32, 1e-6, 1.0 - 1e-6)\n", + " g = -jnp.log(-jnp.log(u))\n", + " selected_experts = jnp.argmax(log_mask + g * self._temperature, axis=-1)\n", + " quant_mask = tl.one_hot(selected_experts, self._n_experts)\n", + " quant_mask = fastmath.stop_gradient(quant_mask)\n", + " quant_mask *= mask # go to just the selected expert\n", + " quant_mask = jnp.reshape(quant_mask, [-1, self._n_experts, 1])\n", + " batch_size = quant_mask.shape[0]\n", + "\n", + " if self._mode == 'predict' and batch_size == 1:\n", + " mask_flat = jnp.reshape(mask, [-1, self._n_experts])\n", + " selected_flat = jnp.reshape(selected_experts, [-1])\n", + " selected_mask_flat = mask_flat[np.arange(selected_flat.size),\n", + " selected_flat]\n", + " # This implementation mimicks inference for batch_size 1.\n", + " start_idx = selected_experts[0] * self._n_elements_in_block\n", + " # w1 is [d_model, d_ff], w is [d_model, n_elements_in_block]\n", + " w = fastmath.dynamic_slice(w1, [0, start_idx],\n", + " [w1.shape[0], self._n_elements_in_block])\n", + " mid = jnp.dot(x, w)\n", + " mid *= jnp.reshape(selected_mask_flat, mid.shape[:-1])[..., None]\n", + " relu = jnp.where(mid \u003c= 0, jnp.zeros_like(mid), mid)\n", + " # w2 is [d_ff, d_model], v is [n_elements_in_block, d_model]\n", + " v = fastmath.dynamic_slice(w2, [start_idx, 0],\n", + " [self._n_elements_in_block, w2.shape[-1]])\n", + " v = jnp.reshape(v, [self._n_elements_in_block, -1])\n", + " res = jnp.dot(relu, v) + b2\n", + " else:\n", + " expanded_mask = jnp.broadcast_to(\n", + " quant_mask,\n", + " (quant_mask.shape[0], quant_mask.shape[1], self._n_elements_in_block))\n", + " expanded_mask = jnp.reshape(expanded_mask, (-1, self._d_ff))\n", + " mid = jnp.dot(x, w1) * expanded_mask # [joint_batch, d_ff]\n", + " relu = jnp.where(mid \u003c= 0, jnp.zeros_like(mid), mid)\n", + " res = jnp.dot(relu, w2) + b2\n", + "\n", + " return jnp.reshape(res, x_shape) # un-flatten if needed\n", + "\n", + " def init_weights_and_state(self, input_signature):\n", + " \"\"\"Randomly initializes this layer's weights.\"\"\"\n", + " d_model = input_signature.shape[-1]\n", + " shape_m1 = (d_model, self._n_experts)\n", + " shape_w1 = (d_model, self._d_ff)\n", + " shape_w2 = (self._d_ff, d_model)\n", + " shape_b2 = (d_model,)\n", + "\n", + " rng_m1, rng_w1, rng_w2, rng_b2 = fastmath.random.split(self.rng, 4)\n", + " m1 = self._kernel_initializer(shape_m1, rng_m1)\n", + " w1 = self._kernel_initializer(shape_w1, rng_w1)\n", + " w2 = self._kernel_initializer(shape_w2, rng_w2)\n", + " b2 = self._bias_initializer(shape_b2, rng_b2)\n", + "\n", + " self.weights = (m1, w1, w2, b2)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "4-3_EPyP4c7K" + }, + "outputs": [], + "source": [ + "# SRU needs to be changed in order for concatenated encoder-decoder attention\n", + "# to work in predict mode.\n", + "\n", + "def MakeZeroState(depth_multiplier=1):\n", + " \"\"\"Makes zeros of shape like x but removing the length (axis 1).\"\"\"\n", + " def f(x): # pylint: disable=invalid-name\n", + " if len(x.shape) != 3:\n", + " raise ValueError(f'Layer input should be a rank 3 tensor representing'\n", + " f' (batch_size, sequence_length, feature_depth); '\n", + " f'instead got shape {x.shape}.')\n", + " return jnp.zeros((x.shape[0], depth_multiplier * x.shape[-1]),\n", + " dtype=jnp.float32)\n", + " return tl.Fn('MakeZeroState', f)\n", + "\n", + "def InnerSRUCell():\n", + " \"\"\"The inner (non-parallel) computation of an SRU.\"\"\"\n", + " def f(cur_x_times_one_minus_f, cur_f, cur_state): # pylint: disable=invalid-name\n", + " res = cur_f * cur_state + cur_x_times_one_minus_f\n", + " return res, res\n", + " return tl.Fn('InnerSRUCell', f, n_out=2)\n", + "\n", + "\n", + "def ScanSRUCell(mode, monkey_patched_mask=None):\n", + " \"\"\"The inner (non-parallel) computation of an SRU.\"\"\"\n", + " if monkey_patched_mask is None:\n", + " return tl.Scan(InnerSRUCell(), axis=1, mode=mode)\n", + "\n", + " # This is necessary for Terraformer model. See comments there.\n", + " # The mask will only be used in Terraformer in predict mode.\n", + " assert mode == 'predict'\n", + "\n", + " def update_mask(mask, x_times_one_minus_f): # pylint: disable=invalid-name\n", + " initial = jnp.ones(x_times_one_minus_f.shape[:2], dtype=jnp.float32)\n", + " if initial.shape[1] \u003e 1:\n", + " updated_mask = fastmath.dynamic_update_slice_in_dim(\n", + " initial != 0, mask != 0, 1, axis=1)\n", + " else:\n", + " updated_mask = initial\n", + " return updated_mask, x_times_one_minus_f\n", + "\n", + " def masked_inner_sru_cell(cur_mask, cur_x_times_one_minus_f, cur_f, # pylint: disable=invalid-name\n", + " cur_state):\n", + " res = ((cur_f * cur_state + cur_x_times_one_minus_f) * cur_mask\n", + " + (1 - cur_mask) * cur_state)\n", + " return res, res\n", + "\n", + " return tl.Serial(\n", + " monkey_patched_mask.get_layer(),\n", + " tl.Fn('update_mask', update_mask, n_out=2),\n", + " tl.Scan(tl.Fn('MaskedInnerSRUCell', masked_inner_sru_cell, n_out=2),\n", + " axis=1, mode=mode),\n", + " )\n", + "\n", + "\n", + "def SRU(n_units, activation=None, mode='train'):\n", + " r\"\"\"SRU (Simple Recurrent Unit) layer as in https://arxiv.org/abs/1709.02755.\n", + "\n", + " As defined in the paper:\n", + "\n", + " .. math::\n", + " y_t \u0026= W x_t + B \\quad \\hbox{(include $B$ optionally)} \\\\\n", + " f_t \u0026= \\sigma(Wf x_t + bf) \\\\\n", + " r_t \u0026= \\sigma(Wr x_t + br) \\\\\n", + " c_t \u0026= f_t \\times c_{t-1} + (1 - f_t) \\times y_t \\\\\n", + " h_t \u0026= r_t \\times \\hbox{activation}(c_t) + (1 - r_t) \\times x_t\n", + "\n", + " We assume the input is of shape [batch, length, depth] and recurrence\n", + " happens on the length dimension. This returns a single layer. It's best\n", + " to use at least 2, they say in the paper, except inside a Transformer.\n", + "\n", + " Args:\n", + " n_units: output depth of the SRU layer.\n", + " activation: Optional activation function.\n", + " mode: if 'predict' then we save the previous state for one-by-one inference\n", + "\n", + " Returns:\n", + " The SRU layer.\n", + " \"\"\"\n", + " sigmoid_activation = tl.Sigmoid()\n", + " return tl.Serial( # x\n", + " tl.Branch(tl.Dense(3 * n_units), []), # r_f_y, x\n", + " tl.Split(n_items=3), # r, f, y, x\n", + " tl.Parallel(sigmoid_activation, sigmoid_activation), # r, f, y, x\n", + " tl.Fn('',\n", + " lambda r, f, y: (y * (1.0 - f), f, r), # y * (1 - f), f, r, x\n", + " n_out=3),\n", + " tl.Parallel([], [], tl.Branch(MakeZeroState(), [])),\n", + " ScanSRUCell(mode=mode),\n", + " tl.Select([0], n_in=2), # act(c), r, x\n", + " activation if activation is not None else [],\n", + " tl.Fn('FinalSRUGate', lambda c, r, x: c * r + x * (1 - r) * (3**0.5)),\n", + " # Set the name to SRU and don't print sublayers.\n", + " name=f'SRU_{n_units}', sublayers_to_print=[]\n", + " )" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "cyf_7nTr55gU" + }, + "source": [ + "## Terraformer\n", + "\n", + "The cells below contain the implementation of the Terraformer architecture:\n", + "* feed-forward and positional encoding blocks\n", + "* encoder and decoder blocks\n", + "* concatenation and stripping to combine the encoder and decoder\n", + "* the final Terraformer model" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "3eEe0xnOvG_X" + }, + "outputs": [], + "source": [ + "def _FeedForward(d_model, d_ff, dropout, activation, act_dropout,\n", + " use_bfloat16, mode):\n", + " \"\"\"Feed-forward block with layer normalization at start.\"\"\"\n", + " if act_dropout is None:\n", + " act_dropout = dropout\n", + " return [\n", + " tl.Dense(d_ff, use_bfloat16=use_bfloat16),\n", + " tl.Dropout(rate=act_dropout, shared_axes=[-2], mode=mode),\n", + " activation(),\n", + " tl.Dense(d_model, use_bfloat16=use_bfloat16),\n", + " ]\n", + "\n", + "\n", + "def FeedForwardWithOptions(d_model,\n", + " d_ff,\n", + " dropout,\n", + " dropout_shared_axes,\n", + " ff_activation,\n", + " ff_dropout,\n", + " ff_chunk_size,\n", + " ff_use_sru,\n", + " ff_sparsity,\n", + " center_layernorm,\n", + " mode,\n", + " use_bfloat16=False,\n", + " ff_sparsity_type='1inN'):\n", + " \"\"\"Feed-Forward block with all the options.\n", + "\n", + " Args:\n", + " d_model: Final dimension of tensors at most points in the model, including\n", + " the initial embedding output.\n", + " d_ff: Size of special dense layer in the feed-forward part of each block.\n", + " dropout: Stochastic rate (probability) for dropping an activation value when\n", + " applying dropout within a block.\n", + " dropout_shared_axes: Tensor axes on which to share a dropout mask. Sharing\n", + " along batch and sequence axes (`dropout_shared_axes=(0,1)`) is a useful\n", + " way to save memory and apply consistent masks to activation vectors at\n", + " different sequence positions.\n", + " ff_activation: Type of activation function at the end of each block; must be\n", + " an activation-type subclass of `Layer`.\n", + " ff_dropout: Stochastic rate (probability) for dropping an activation value\n", + " when applying dropout after the FF dense layer.\n", + " ff_chunk_size: int; if \u003e 0, chunk feed-forward into this-sized chunks\n", + " ff_use_sru: int or pair of ints; if \u003e 0, we use this many SRU layers\n", + " in addition to the feed-forward block (second int specifies sru size)\n", + " ff_sparsity: int, tuple or string; if not 0, use sparse feed-forward block\n", + " with this sparsity\n", + " center_layernorm: whether to use centering in LayerNorm (default) or if\n", + " to skip it, which is known as RMS normalization.\n", + " mode: If `'train'`, each block will include dropout; else, it will pass all\n", + " values through unaltered.\n", + " use_bfloat16: whether to use bfloat16 for weights (default: False).\n", + " ff_sparsity_type: string, if ff_sparsity \u003e0,\n", + " use SparseFF if ff_sparsity_type=`'1inN'` and\n", + " use BlockSparseFF if ff_sparsity_type=`'Block'`\n", + " use SwitchSparseFF if ff_sparsity_type=`'Switch'`\n", + "\n", + " Returns:\n", + " A list of layers which maps vectors to vectors.\n", + " \"\"\"\n", + " if ff_sparsity and ff_sparsity_type == '1inN':\n", + " temperature, quant_prob = 0.1, 0.3\n", + " if isinstance(ff_sparsity, str):\n", + " # This is hacky but used to pass ff_sparsity in yaml sweep files.\n", + " ff_sparsity = [(float(x) if '.' in x else int(x))\n", + " for x in ff_sparsity.split()]\n", + " if isinstance(ff_sparsity, (list, tuple)):\n", + " if len(ff_sparsity) == 2:\n", + " n_elements_in_block, d_lowrank = ff_sparsity\n", + " else:\n", + " n_elements_in_block, d_lowrank, temperature, quant_prob = ff_sparsity\n", + " else:\n", + " assert isinstance(ff_sparsity, int)\n", + " n_elements_in_block, d_lowrank = ff_sparsity, d_ff // ff_sparsity\n", + " ff = SparseFF(\n", + " d_ff,\n", + " n_elements_in_block=n_elements_in_block,\n", + " d_lowrank=d_lowrank,\n", + " temperature=temperature,\n", + " quant_prob=quant_prob,\n", + " use_bfloat16=use_bfloat16,\n", + " mode=mode,\n", + " dropout_rate=dropout,\n", + " dropout_shared_axes=dropout_shared_axes,\n", + " ff_chunk_size=ff_chunk_size)\n", + " elif ff_sparsity and ff_sparsity_type == 'Block':\n", + " ff = BlockSparseFF(d_ff, n_experts=ff_sparsity, mode=mode)\n", + " elif ff_sparsity and ff_sparsity_type == 'Switch':\n", + " ff = SwitchSparseFF(d_ff, n_experts=ff_sparsity, mode=mode)\n", + " else:\n", + " ff = _FeedForward(d_model, d_ff, dropout, ff_activation, ff_dropout,\n", + " use_bfloat16, mode)\n", + " res = [tl.LayerNorm(center=center_layernorm), ff]\n", + " if ff_sparsity_type != '1inN' or ff_sparsity == 0:\n", + " # SparseFF has Dropout and BatchLeadingAxes built-in.\n", + " res.append(tl.Dropout(rate=dropout, shared_axes=dropout_shared_axes,\n", + " mode=mode))\n", + " if ff_chunk_size \u003e 0:\n", + " res = tl.BatchLeadingAxes(tl.Chunk(tl.Serial(res), ff_chunk_size))\n", + " if ff_use_sru:\n", + " if isinstance(ff_use_sru, (list, tuple)):\n", + " sru_n_layers, sru_n_units = ff_use_sru\n", + " else:\n", + " sru_n_layers, sru_n_units = ff_use_sru, 32\n", + " sru = [SRU(sru_n_units, mode=mode) for _ in range(sru_n_layers)]\n", + " block = [tl.LayerNorm(center=center_layernorm), tl.Dense(sru_n_units)\n", + " ] + sru + [tl.Dense(d_model)]\n", + " res = tl.Residual(block, shortcut=res)\n", + " return [res]\n", + "\n", + "\n", + "def ApplyAttentionLayer(attention_type, d_model, n_heads, d_qk, d_v, causal,\n", + " masked, attention_dropout, output_dropout,\n", + " attention_chunk_size, mode):\n", + " \"\"\"Runs the supplied attention layer.\"\"\"\n", + " try:\n", + " attention = attention_type(\n", + " n_heads=n_heads,\n", + " d_qk=d_qk,\n", + " d_v=d_v,\n", + " causal=causal,\n", + " masked=masked,\n", + " output_dropout=output_dropout,\n", + " attention_dropout=attention_dropout,\n", + " mode=mode)\n", + " except TypeError: # No d_qk arguments in less advanced layers.\n", + " attention = attention_type(\n", + " d_model, n_heads=n_heads, dropout=attention_dropout, mode=mode)\n", + " return tl.Chunk(attention, attention_chunk_size)\n", + "\n", + "\n", + "def PositionalEncoder(mode,\n", + " dropout=None,\n", + " max_len=None,\n", + " pos_type=None,\n", + " pos_axial_shape=None,\n", + " pos_d_axial_embs=None,\n", + " pos_start_from_zero_prob=1.0,\n", + " pos_max_offset_to_add=0,\n", + " use_bfloat16=False):\n", + " \"\"\"Returns the positional encoding layer depending on the arguments.\n", + "\n", + " Args:\n", + " mode: If `'predict'`, use fast inference. If `'train'`, each encoder/decoder\n", + " block will include dropout; else, it will pass all values through\n", + " unaltered.\n", + " dropout: Stochastic rate (probability) for dropping an activation\n", + " value when applying dropout after the embedding block.\n", + " max_len: Maximum symbol length for positional encoding.\n", + " pos_type: string, the type of positional embeddings to use.\n", + " pos_axial_shape: tuple of ints: input shape to use for the axial position\n", + " encoding. If unset, axial position encoding is disabled.\n", + " pos_d_axial_embs: tuple of ints: depth of position embedding for each axis.\n", + " Tuple length must match pos_axial_shape, and values must sum to d_model.\n", + " pos_start_from_zero_prob: how often to start from 0 during training,\n", + " (if 1.0, we always start from position 0, if less, we randomize).\n", + " pos_max_offset_to_add: maximum offset to add to positions during training\n", + " when randomizing; this offset plus input length must still be less than\n", + " max_len for all training examples.\n", + " use_bfloat16: If `True`, use bfloat16 weights instead of the default\n", + " float32; this can save memory but may (rarely) lead to numerical issues.\n", + "\n", + " Returns:\n", + " A layer that will do the positional encoding.\n", + " \"\"\"\n", + " if not pos_type:\n", + " positional_encoding = tl.PositionalEncoding(\n", + " max_len=max_len, dropout=dropout, use_bfloat16=use_bfloat16,\n", + " start_from_zero_prob=pos_start_from_zero_prob,\n", + " max_offset_to_add=pos_max_offset_to_add, mode=mode)\n", + " elif pos_type == 'sin-cos':\n", + " positional_encoding = tl.SinCosPositionalEncoding(mode=mode)\n", + " elif pos_type == 'fixed-base':\n", + " positional_encoding = tl.FixedBasePositionalEncoding(mode=mode)\n", + " elif pos_type == 'infinite':\n", + " positional_encoding = tl.InfinitePositionalEncoding(affine=False)\n", + " elif pos_type == 'infinite-affine':\n", + " positional_encoding = tl.InfinitePositionalEncoding()\n", + " elif pos_type == 'time-bin':\n", + " positional_encoding = tl.TimeBinPositionalEncoding()\n", + " else:\n", + " assert pos_d_axial_embs is not None\n", + " positional_encoding = tl.AxialPositionalEncoding(\n", + " shape=pos_axial_shape, d_embs=pos_d_axial_embs,\n", + " dropout_broadcast_dims=tuple(range(1, len(pos_axial_shape) + 1)),\n", + " dropout=dropout, mode=mode)\n", + "\n", + " return positional_encoding\n", + "\n", + "\n", + "def EmbeddingAndPositionalEncodings(input_vocab_size,\n", + " d_model,\n", + " mode,\n", + " embedding_dropout,\n", + " dropout_shared_axes,\n", + " max_len,\n", + " output_vocab_size=None,\n", + " pos_type=None,\n", + " pos_axial_shape=None,\n", + " pos_d_axial_embs=None,\n", + " pos_start_from_zero_prob=1.0,\n", + " pos_max_offset_to_add=0,\n", + " use_bfloat16=False):\n", + " \"\"\"Returns the embedder and positional encoder.\n", + "\n", + " Args:\n", + " input_vocab_size: Input vocabulary size -- each element of the input tensor\n", + " should be an integer in `range(vocab_size)`. These integers typically\n", + " represent token IDs from a vocabulary-based tokenizer.\n", + " d_model: Final dimension of tensors at most points in the model, including\n", + " the initial embedding output.\n", + " mode: If `'predict'`, use fast inference. If `'train'`, each encoder/decoder\n", + " block will include dropout; else, it will pass all values through\n", + " unaltered.\n", + " embedding_dropout: Stochastic rate (probability) for dropping an activation\n", + " value when applying dropout after the embedding block.\n", + " dropout_shared_axes: Tensor axes on which to share a dropout mask. Sharing\n", + " along batch and sequence axes (`dropout_shared_axes=(0,1)`) is a useful\n", + " way to save memory and apply consistent masks to activation vectors at\n", + " different sequence positions.\n", + " max_len: Maximum symbol length for positional encoding.\n", + " output_vocab_size: If specified, gives the vocabulary size for the targets;\n", + " if None, then input and target integers (token IDs) are assumed to come\n", + " from the same vocabulary.\n", + " pos_type: string, the type of positional embeddings to use.\n", + " pos_axial_shape: tuple of ints: input shape to use for the axial position\n", + " encoding. If unset, axial position encoding is disabled.\n", + " pos_d_axial_embs: tuple of ints: depth of position embedding for each axis.\n", + " Tuple length must match pos_axial_shape, and values must sum to d_model.\n", + " pos_start_from_zero_prob: how often to start from 0 during training,\n", + " (if 1.0, we always start from position 0, if less, we randomize).\n", + " pos_max_offset_to_add: maximum offset to add to positions during training\n", + " when randomizing; this offset plus input length must still be less than\n", + " max_len for all training examples.\n", + " use_bfloat16: If `True`, use bfloat16 weights instead of the default\n", + " float32; this can save memory but may (rarely) lead to numerical issues.\n", + "\n", + " Returns:\n", + " A tuple of (input encoder, output encoder, output vocab size used).\n", + " \"\"\"\n", + " # tokens --\u003e vectors\n", + " def Embedder(vocab_size, embedding_mode):\n", + " if vocab_size is not None:\n", + " embedding = tl.Embedding(vocab_size, d_model, use_bfloat16=use_bfloat16)\n", + " else:\n", + " embedding = tl.Dense(d_model, use_bfloat16=use_bfloat16)\n", + " return [\n", + " embedding,\n", + " tl.Dropout(rate=embedding_dropout,\n", + " shared_axes=dropout_shared_axes,\n", + " mode=embedding_mode),\n", + " ]\n", + "\n", + " # NOTE: Positional encodings are not shared between encoder and decoder.\n", + "\n", + " # Since encoder doesn't run stepwise, we do not use predict mode there.\n", + " encoder_mode = 'eval' if mode == 'predict' else mode\n", + " in_embedder = Embedder(input_vocab_size, encoder_mode)\n", + " in_encoder = in_embedder + [\n", + " PositionalEncoder(encoder_mode,\n", + " dropout=embedding_dropout,\n", + " max_len=max_len,\n", + " pos_type=pos_type,\n", + " pos_axial_shape=pos_axial_shape,\n", + " pos_d_axial_embs=pos_d_axial_embs,\n", + " pos_start_from_zero_prob=pos_start_from_zero_prob,\n", + " pos_max_offset_to_add=pos_max_offset_to_add,\n", + " use_bfloat16=use_bfloat16)\n", + " ]\n", + "\n", + " # If output_vocab_size is None, we reuse the same embedding matrix, otherwise\n", + " # we initialize one.\n", + " assert input_vocab_size or output_vocab_size\n", + " if output_vocab_size is None:\n", + " out_embedder = in_embedder\n", + " else:\n", + " out_embedder = Embedder(output_vocab_size, mode)\n", + "\n", + " out_encoder = out_embedder + [\n", + " PositionalEncoder(mode,\n", + " dropout=embedding_dropout,\n", + " max_len=max_len,\n", + " pos_type=pos_type,\n", + " pos_axial_shape=pos_axial_shape,\n", + " pos_d_axial_embs=pos_d_axial_embs,\n", + " pos_start_from_zero_prob=pos_start_from_zero_prob,\n", + " pos_max_offset_to_add=pos_max_offset_to_add,\n", + " use_bfloat16=use_bfloat16)\n", + " ]\n", + "\n", + " # Set this to the value actually used.\n", + " if output_vocab_size is None:\n", + " output_vocab_size = input_vocab_size\n", + "\n", + " if input_vocab_size is None:\n", + " in_encoder = tl.AssertFunction('...a-\u003e...b', in_encoder)\n", + " else:\n", + " in_encoder = tl.AssertFunction('...-\u003e...d', in_encoder)\n", + " out_encoder = tl.AssertFunction('...-\u003e...d', out_encoder)\n", + "\n", + " return in_encoder, out_encoder, output_vocab_size" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "2D3dQi9Q2bO7" + }, + "outputs": [], + "source": [ + "def DecoderBlock(d_model, d_ff, d_attention_key, d_attention_value,\n", + " n_heads, attention_type, dropout, ff_activation,\n", + " ff_dropout, ff_use_sru, ff_chunk_size, ff_sparsity,\n", + " attention_chunk_size, n_attention_layers=1,\n", + " n_feedforward_layers=1, center_layernorm=True,\n", + " use_bfloat16=False, mode='train'):\n", + " \"\"\"Reversible transformer decoder layer.\n", + "\n", + " Args:\n", + " d_model: int: depth of embedding\n", + " d_ff: int: depth of feed-forward layer\n", + " d_attention_key: int: depth of key vector for each attention head\n", + " d_attention_value: int: depth of value vector for each attention head\n", + " n_heads: int: number of attention heads\n", + " attention_type: subclass of tl.BaseCausalAttention: attention class to use\n", + " dropout: float: dropout rate (how much to drop out)\n", + " ff_activation: the non-linearity in feed-forward layer\n", + " ff_dropout: the dropout rate in feed-forward layer\n", + " ff_use_sru: int; if \u003e 0, we use this many SRU layers instead of feed-forward\n", + " ff_chunk_size: int; if \u003e 0, chunk feed-forward into this-sized chunks\n", + " ff_sparsity: int, if \u003e 0 use sparse feed-forward block with this sparsity\n", + " attention_chunk_size: int, if \u003e 0 run attention chunked at this size\n", + " n_attention_layers: how many residual causal attention layers should we\n", + " have before the feed-forward block (default: 1, the standard block)\n", + " n_feedforward_layers: how many FFNN layers should we have (default 1).\n", + " center_layernorm: whether to use centering in LayerNorm (default) or if\n", + " to skip it, which is known as RMS normalization.\n", + " use_bfloat16: whether to use bfloat16 for weights (default: False).\n", + " mode: str: 'train' or 'eval'\n", + "\n", + "\n", + " Returns:\n", + " the layer.\n", + " \"\"\"\n", + " # pylint: disable=g-complex-comprehension\n", + " def _Attn():\n", + " return ApplyAttentionLayer(\n", + " attention_type, d_model, n_heads, d_attention_key,\n", + " d_attention_value, True, False, dropout, dropout,\n", + " attention_chunk_size, mode)\n", + "\n", + " def _FF():\n", + " return FeedForwardWithOptions(\n", + " d_model, d_ff, dropout, [-2], ff_activation, ff_dropout,\n", + " ff_chunk_size, ff_use_sru, ff_sparsity, center_layernorm,\n", + " mode, use_bfloat16)\n", + "\n", + " def _attention_half_residual():\n", + " return [\n", + " tl.ReversibleHalfResidual(tl.LayerNorm(center=center_layernorm),\n", + " attention_layer=_Attn(),\n", + " name='ReversibleHalfResidualDecoderAttn'),\n", + " tl.ReversibleSwap()\n", + " ]\n", + "\n", + " def _feed_forward():\n", + " return [\n", + " tl.ReversibleHalfResidual(_FF(),\n", + " name='ReversibleHalfResidualDecoderFF'),\n", + " tl.ReversibleSwap()\n", + " ]\n", + "\n", + " return ([_attention_half_residual() for _ in range(n_attention_layers)]\n", + " + [_feed_forward() for _ in range(n_feedforward_layers)])\n", + "\n", + "\n", + "def EncoderBlock(d_model, d_ff, n_heads, attention_type, dropout, ff_activation,\n", + " ff_dropout, ff_use_sru=0, ff_chunk_size=0, ff_sparsity=0,\n", + " attention_chunk_size=0, center_layernorm=True,\n", + " use_bfloat16=False, use_two_swaps_per_block=True,\n", + " mode='train'):\n", + " \"\"\"Returns a list of layers that implements a Terraformer encoder block.\n", + "\n", + " The input to the layer is a pair, (activations, mask), where the mask was\n", + " created from the original source tokens to prevent attending to the padding\n", + " part of the input.\n", + "\n", + " Args:\n", + " d_model: int: depth of embedding\n", + " d_ff: int: depth of feed-forward layer\n", + " n_heads: int: number of attention heads\n", + " attention_type: subclass of tl.BaseCausalAttention: attention class to use\n", + " dropout: float: dropout rate (how much to drop out)\n", + " ff_activation: the non-linearity in feed-forward layer\n", + " ff_dropout: the dropout rate in feed-forward layer\n", + " ff_use_sru: int; if \u003e 0, we use this many SRU layers instead of feed-forward\n", + " ff_chunk_size: int; if \u003e 0, chunk feed-forward into this-sized chunks\n", + " ff_sparsity: int, if \u003e 0 use sparse feed-forward block with this sparsity\n", + " attention_chunk_size: int, if \u003e 0 run attention chunked at this size\n", + " center_layernorm: whether to use centering in LayerNorm (default) or if\n", + " to skip it, which is known as RMS normalization.\n", + " use_bfloat16: whether to use bfloat16 for weights (default: False)\n", + " use_two_swaps_per_block: bool, if True use two reversible swaps in Encoder\n", + " block, otherwise use only one swap.\n", + " mode: str: 'train' or 'eval'\n", + "\n", + " Returns:\n", + " A list of layers that maps (activations, mask) to (activations, mask).\n", + " \"\"\"\n", + " if mode == 'predict':\n", + " # Mode 'predict' means that the decoder should be run one token at a time.\n", + " # The encoder only ever runs over full sequences, which is why it's switched\n", + " # to 'eval' mode instead.\n", + " mode = 'eval'\n", + "\n", + " def _Attn():\n", + " return ApplyAttentionLayer(\n", + " attention_type=attention_type, d_model=d_model, n_heads=n_heads,\n", + " d_qk=d_model//n_heads, d_v=d_model//n_heads, masked=True, causal=False,\n", + " attention_dropout=dropout, output_dropout=dropout,\n", + " attention_chunk_size=attention_chunk_size, mode=mode)\n", + "\n", + " def _FF():\n", + " return FeedForwardWithOptions(\n", + " d_model, d_ff, dropout, [-2], ff_activation, ff_dropout,\n", + " ff_chunk_size, ff_use_sru, ff_sparsity, center_layernorm,\n", + " mode, use_bfloat16)\n", + "\n", + " attention = _Attn()\n", + " if attention.n_out == 2:\n", + " attention = tl.Serial(\n", + " tl.Parallel([], _InsertAxes12()),\n", + " attention,\n", + " tl.Select([0], n_in=2)\n", + " )\n", + "\n", + " def _attention_half_residual():\n", + " return [\n", + " tl.ReversibleHalfResidual(tl.LayerNorm(center=center_layernorm),\n", + " attention_layer=attention,\n", + " name='ReversibleHalfResidualEncoderAttn'),\n", + " tl.ReversibleSwap()\n", + " ]\n", + "\n", + " def _feed_forward():\n", + " layers = [\n", + " tl.ReversibleHalfResidual(_FF(),\n", + " name='ReversibleHalfResidualEncoderFF')\n", + " ]\n", + " if use_two_swaps_per_block:\n", + " layers.append(tl.ReversibleSwap())\n", + " return layers\n", + "\n", + " return _attention_half_residual() + _feed_forward()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "ITiWrbEnAZyb" + }, + "outputs": [], + "source": [ + "# Arg shapes: (B, L1, H), (B, L2, H), (B, L1).\n", + "def _ConcatWithPadding(vec_e, vec_d, mask_e):\n", + " \"\"\"Concatenate with padding: see the ConcatWithPadding layer for details.\"\"\"\n", + " # pylint: disable=invalid-name\n", + " B, L1, H = vec_e.shape\n", + " L2 = vec_d.shape[1]\n", + " # pylint: enable=invalid-name\n", + "\n", + " if vec_d.shape != (B, L2, H):\n", + " raise ValueError(f'Shape of decoder vector, {vec_d.shape}, does not'\n", + " f' equal {(B, L2, H)}.')\n", + " if mask_e.shape != (B, L1):\n", + " raise ValueError(f'Shape of encoder mask, {mask_e.shape}, does not'\n", + " f' equal {(B, L1)}.')\n", + "\n", + " def _UpdateRow(x):\n", + " # row_e - (L1, H), row_d - (L2, H), row_mask_e - (L1,)\n", + " row_e, row_d, row_mask_e = x\n", + " # final_row - (L1+L2, H)\n", + " final_row = jnp.concatenate([row_e, jnp.zeros_like(row_d)], axis=0)\n", + " # Find the last real token/vector of the encoder.\n", + " e_idx = jnp.sum(row_mask_e, dtype=jnp.int32)\n", + " # Starting after that index, update with the decoder row.\n", + " zero = jnp.array(0, dtype=e_idx.dtype) # avoid int32/int64 mismatch\n", + " return fastmath.dynamic_update_slice(final_row, row_d, (e_idx, zero))\n", + "\n", + " return fastmath.map(_UpdateRow, [vec_e, vec_d, mask_e])\n", + "\n", + "\n", + "def _StripFromConcatenateWithPadding(vec_ed, tok_e, tok_d):\n", + " \"\"\"Strip concatenate with padding: see the layer below for details.\"\"\"\n", + " # pylint: disable=invalid-name\n", + " B, L, H = vec_ed.shape\n", + " L1 = tok_e.shape[1]\n", + " L2 = tok_d.shape[1]\n", + " # pylint: enable=invalid-name\n", + " if L != L1 + L2:\n", + " raise ValueError(f'Length from encoder-decoder vectors ({L}) does not'\n", + " f' equal sum of lengths from encoder ({L1}) and decoder'\n", + " f' ({L2}).')\n", + " if tok_e.shape != (B, L1):\n", + " raise ValueError(f'Shape of encoder tokens, {tok_e.shape}, does not'\n", + " f' equal {(B, L1)}.')\n", + " if tok_d.shape != (B, L2):\n", + " raise ValueError(f'Shape of decoder tokens, {tok_d.shape}, does not'\n", + " f' equal {(B, L2)}.')\n", + "\n", + " def _UpdateRow(x):\n", + " # (L, H), (L1, H) \u0026 (L2, H)\n", + " row_ed, row_e, _ = x\n", + " mask_e = row_e != 0\n", + " len_e = jnp.sum(mask_e, dtype=jnp.int32)\n", + " # In `row_ed` start where encoder tokens/vecs end, i.e. are index `len_e`\n", + " # and pick up (L2, H) tensor slice from there.\n", + " zero = jnp.array(0, dtype=len_e.dtype) # avoid int32/int64 mismatch\n", + " return fastmath.dynamic_slice(row_ed, (len_e, zero), (L2, H))\n", + "\n", + " return fastmath.map(_UpdateRow, [vec_ed, tok_e, tok_d])\n", + "\n", + "\n", + "class StripFromConcatenateWithPadding(tl.Layer):\n", + " \"\"\"Strips out the leading encoder tokens from the concatenated array.\"\"\"\n", + "\n", + " def __init__(self, mode='train'):\n", + " super().__init__(n_in=3, n_out=1)\n", + " self._mode = mode\n", + "\n", + " def init_weights_and_state(self, input_signature):\n", + " \"\"\"Sets layer-specific internal state.\"\"\"\n", + " del input_signature\n", + " self.state = jnp.array(0, dtype=jnp.int32)\n", + "\n", + " def forward(self, inputs):\n", + " vec_ed, tok_e, tok_d = inputs\n", + "\n", + " # In training/eval mode or at the first step predict mode i.e. when\n", + " # state.shape is (), i.e. at first step, we do the actual compuration\n", + " if self._mode != 'predict' or not self.state.shape:\n", + " # Now state.shape will not evaluate to false.\n", + " self.state = self.state.reshape((1,))\n", + " return _StripFromConcatenateWithPadding(vec_ed, tok_e, tok_d)\n", + "\n", + " # In predict mode and on subsequent steps (i.e. after the first step) vec_ed\n", + " # is actually vec_d, since no concatenation happened at all.\n", + " return vec_ed\n", + "\n", + "\n", + "class ConcatWithPadding(tl.ReversibleLayer):\n", + " \"\"\"Concatenates two length padded (B, L, H) arrays (of different lenghts).\"\"\"\n", + "\n", + " def __init__(self, mode='train'):\n", + " super().__init__(n_in=5, n_out=3)\n", + " self._mode = mode\n", + "\n", + " def init_weights_and_state(self, input_signature):\n", + " \"\"\"Sets layer-specific internal state.\"\"\"\n", + " del input_signature\n", + " self.state = jnp.array(0, dtype=jnp.int32)\n", + "\n", + " def forward(self, inputs):\n", + " vec_e, vec_d, mask_e, tok_e, tok_d = inputs\n", + "\n", + " # In training/eval mode or at the first step predict mode i.e. when\n", + " # state.shape is (), i.e. at first step, we return the concatenated output.\n", + " if self._mode != 'predict' or not self.state.shape:\n", + " # Now state.shape will not evaluate to false.\n", + " self.state = self.state.reshape((1,))\n", + " return _ConcatWithPadding(vec_e, vec_d, mask_e), tok_e, tok_d\n", + "\n", + " # In predict mode and on subsequent steps (i.e. after the first step) we\n", + " # don't concatenate anymore, but just return the decoder vector.\n", + " return vec_d, tok_e, tok_d\n", + "\n", + " def reverse(self, output, weights=(), state=(), new_state=(), rng=None):\n", + " del state, new_state, rng, weights\n", + " assert self._mode != 'predict', 'cannot reverse in predict mode'\n", + " vecs_ed, toks_e, toks_d = output\n", + " vecs_d = _StripFromConcatenateWithPadding(vecs_ed, toks_e, toks_d)\n", + " mask_e = (toks_e != 0)\n", + " mask_e_float = mask_e.astype(jnp.float32)\n", + " vecs_e = vecs_ed[:, :toks_e.shape[1], :] * mask_e_float[:, :, None]\n", + " return vecs_e, vecs_d, mask_e, toks_e, toks_d\n", + "\n", + "\n", + "class ConcatWithPadding2(tl.ReversibleLayer):\n", + " \"\"\"Concatenate with padding operating on pairs to combine with rev-nets.\"\"\"\n", + "\n", + " def __init__(self, mode='train'):\n", + " super().__init__(n_in=6, n_out=4)\n", + " self._mode = mode\n", + "\n", + " def init_weights_and_state(self, input_signature):\n", + " \"\"\"Sets layer-specific internal state.\"\"\"\n", + " del input_signature\n", + " self.state = jnp.array(0, dtype=jnp.int32)\n", + "\n", + " def forward(self, inputs):\n", + " vecs_e1, vecs_e2, vecs_d, mask_e, toks_e, toks_d = inputs\n", + "\n", + " # In training/eval mode or at the first step predict mode i.e. when\n", + " # state.shape is (), i.e. at first step, we return the concatenated output.\n", + " if self._mode != 'predict' or not self.state.shape:\n", + " # Now state.shape will not evaluate to false.\n", + " self.state = self.state.reshape((1,))\n", + " # Calculate mask and concat_with_padding on the pairs.\n", + " vecs_ed1 = _ConcatWithPadding(vecs_e1, vecs_d, mask_e)\n", + " vecs_ed2 = _ConcatWithPadding(vecs_e2, vecs_d, mask_e)\n", + " return vecs_ed1, vecs_ed2, toks_e, toks_d\n", + "\n", + " # In predict mode and on subsequent steps (i.e. after the first step) we\n", + " # don't concatenate anymore, but just return the decoder vector.\n", + " return vecs_d, vecs_d, toks_e, toks_d\n", + "\n", + " def reverse(self, output, weights=(), state=(), new_state=(), rng=None):\n", + " del state, new_state, rng, weights\n", + " assert self._mode != 'predict', 'cannot reverse in predict mode'\n", + " vecs_ed1, vecs_ed2, toks_e, toks_d = output\n", + " vecs_d = _StripFromConcatenateWithPadding(vecs_ed1, toks_e, toks_d)\n", + " mask_e = (toks_e != 0)\n", + " mask_e_float = mask_e.astype(jnp.float32)\n", + " vecs_e1 = vecs_ed1[:, :toks_e.shape[1], :] * mask_e_float[:, :, None]\n", + " vecs_e2 = vecs_ed2[:, :toks_e.shape[1], :] * mask_e_float[:, :, None]\n", + " return vecs_e1, vecs_e2, vecs_d, mask_e, toks_e, toks_d" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "4FPVnsq8Ersd" + }, + "outputs": [], + "source": [ + "def Terraformer(input_vocab_size,\n", + " output_vocab_size=None,\n", + " d_model=512,\n", + " d_ff=2048,\n", + " d_attention_key=None,\n", + " d_attention_value=None,\n", + " n_encoder_layers=6,\n", + " n_decoder_layers=6,\n", + " n_heads=8,\n", + " dropout=0.1,\n", + " max_len=2048,\n", + " encoder_attention_type=tl.SelfAttention,\n", + " encoder_decoder_attention_type=tl.SelfAttention,\n", + " pos_type='fixed-base',\n", + " pos_axial_shape=(),\n", + " pos_d_axial_embs=None,\n", + " pos_start_from_zero_prob=1.0,\n", + " pos_max_offset_to_add=0,\n", + " ff_activation=tl.Relu,\n", + " ff_use_sru=(1, 32),\n", + " ff_chunk_size=0,\n", + " ff_dropout=None,\n", + " ff_sparsity=32,\n", + " loss_sparsity_type='mult',\n", + " loss_sparsity=0,\n", + " loss_d_lowrank=0,\n", + " loss_sparsity_prob=None,\n", + " attention_chunk_size=0,\n", + " n_layers_forget=0,\n", + " forget_dense=True,\n", + " n_decoder_attention_layers=2,\n", + " use_bfloat16=False,\n", + " reversible_encoder=False,\n", + " use_two_swaps_per_encoder_block=True,\n", + " center_layernorm=True,\n", + " half_before_layer=None,\n", + " double_after_layer=None,\n", + " mode='train'):\n", + " \"\"\"Returns a highly configurable Terraformer encoder-decoder model.\n", + "\n", + " This model maps paired text sequences (source and target) to float-valued\n", + " losses. If ``input_vocab_size`` is not ``None``, the layer takes\n", + " two input sequences:\n", + "\n", + " - inputs (2):\n", + "\n", + " - source: 2-D int array representing a batch of text strings via token\n", + " IDs plus padding markers; shape is `(batch_size, sequence_length)`,\n", + " where sequence_length \u003c= ``max_len``. Array elements are in\n", + " ``range(input_vocab_size)``, and 0 values mark padding positions.\n", + "\n", + " - target: 2-D int array representing a batch of text strings via token\n", + " IDs plus padding markers; shape is `(batch_size, sequence_length)`,\n", + " where sequence_length \u003c= ``max_len``. Array elements are in\n", + " ``range(output_vocab_size)``, and 0 values mark padding positions.\n", + "\n", + " - output: 1-D float array of losses; shape is `(batch_size)`.\n", + "\n", + " If ``input_vocab_size`` is ``None``, the layer takes three input sequences:\n", + "\n", + " - inputs (3):\n", + "\n", + " - source: 3-D float array representing a batch of already-embedded text\n", + " strings; shape is `(batch_size, sequence_length, d_model)`, where\n", + " sequence_length \u003c= ``max_len``.\n", + "\n", + " - mask: 2-D int array representing active versus masked positions; 0\n", + " values mark masked (padding) positions.\n", + "\n", + " - target: 2-D int array representing a batch of text strings via token\n", + " IDs plus padding markers; shape is `(batch_size, sequence_length)`,\n", + " where sequence_length \u003c= ``max_len``. Array elements are in\n", + " ``range(output_vocab_size)``, and 0 values mark padding positions.\n", + "\n", + " - output: 1-D float array of losses; shape is `(batch_size)`.\n", + "\n", + " Args:\n", + " input_vocab_size: Input vocabulary size -- each element of the input tensor\n", + " should be an integer in ``range(vocab_size)``. These integers typically\n", + " represent token IDs from a vocabulary-based tokenizer.\n", + " output_vocab_size: If specified, gives the vocabulary size for the targets;\n", + " if ``None``, then input and target integers (token IDs) are assumed to\n", + " come from the same vocabulary.\n", + " d_model: Last/innermost dimension of activation arrays at most points in\n", + " the model, including the initial embedding output.\n", + " d_ff: Last/innermost dimension of special (typically wider)\n", + " :py:class:`Dense` layer in the feedforward part of each encoder block.\n", + " d_attention_key: Depth of key vectors in each attention head.\n", + " d_attention_value: Depth of value vectors in each attention head.\n", + " n_encoder_layers: Number of encoder blocks.\n", + " n_decoder_layers: Number of decoder blocks.\n", + " n_heads: Number of attention heads.\n", + " dropout: Stochastic rate (probability) for dropping an activation value\n", + " when applying dropout within encoder/decoder blocks. The same rate is\n", + " also used for attention dropout in encoder/decoder blocks.\n", + " max_len: Maximum symbol length for positional encoding.\n", + " encoder_attention_type: Type of attention to use in the encoder; must be\n", + " an attention-type subclass of :py:class:`trax.layers.Layer`.\n", + " encoder_decoder_attention_type: Type of attention to use in the decoder;\n", + " must be an attention-type subclass of :py:class:`trax.layers.Layer`.\n", + " pos_type: String indicating the type of positional embeddings to use.\n", + " pos_axial_shape: Shape (tuple of ints) to use for the axial position\n", + " encoding. If unset, axial position encoding is disabled.\n", + " pos_d_axial_embs: Tuple of ints specifying the depth of position embedding\n", + " for each axis. Tuple length must match ``pos_axial_shape``, and values\n", + " must sum to ``d_model``.\n", + " pos_start_from_zero_prob: Stochastic rate (probability) for starting\n", + " positional encoding at position 0 during training. If 1.0, always start\n", + " from position 0; if \u003c 1.0, the non-zero starts will be uniformly\n", + " distributed up to ``pos_max_offset_to_add``.\n", + " pos_max_offset_to_add: Maximum offset to add to positions during training\n", + " when randomizing. This offset plus input length must be less than\n", + " ``max_len`` for all training examples.\n", + " ff_activation: Type of activation function at the end of each block; must\n", + " be an activation-type subclass of :py:class:`trax.layers.Layer`.\n", + " ff_use_sru: If \u003e 0, use this number of SRU layers in place of feedforward\n", + " layers.\n", + " ff_chunk_size: If \u003e 0, chunk each feedforward layer into chunks of this\n", + " size.\n", + " ff_dropout: Stochastic rate (probability) for dropping an activation value\n", + " at feedforward nonlinearities.\n", + " ff_sparsity: If \u003e 0, use sparse feedforward blocks with this level of\n", + " sparsity.\n", + " loss_sparsity_type: String indicating the type of sparsity to used in loss\n", + " layer; see :py:class:`SparseDenseWithOptions` for options. If ``None``,\n", + " use no sparsity.\n", + " loss_sparsity: If \u003e 0, use this level of sparsity in the loss layer.\n", + " loss_d_lowrank: If \u003e 0, use a (low-rank) intermediate layer, with this\n", + " dimension, in the loss.\n", + " loss_sparsity_prob: Stochastic rate (probability) for using the sparse\n", + " version of the loss. If ``None``, use the sparse version exclusively.\n", + " attention_chunk_size: If \u003e 0, compute attention using chunks of this size.\n", + " n_layers_forget: How often to have a forgetting block between layers.\n", + " forget_dense: If True, use :py:class:`Dense` instances as forget layers;\n", + " else use no-ops.\n", + " n_decoder_attention_layers: Number of attention layers in a decoder block.\n", + " use_bfloat16: If True, use bfloat16 for weights; else use float32.\n", + " reversible_encoder: If True, make the encoder be reversible.\n", + " use_two_swaps_per_encoder_block: If True, ensure that there is a an even\n", + " number of swaps across the encoder.\n", + " center_layernorm: If True, use centering in :py:class:`LayerNorm` (the\n", + " default); else omit centering (which is known as RMS normalization).\n", + " half_before_layer: If not None, specifies an n'th layer such that all\n", + " layers before the n'th use half the normal values for ``d_model`` and\n", + " ``d_ff``.\n", + " double_after_layer: If not None, specifies an n'th layer such that all\n", + " layers after the n'th use double the normal values for ``d_model`` and\n", + " ``d_ff``.\n", + " mode: If ``'train'``, include dropout in each encoder/decoder block; else\n", + " dropout layers have no effect.\n", + "\n", + " Returns:\n", + " A Terraformer encoder-decoder as a layer that maps from target and source\n", + " text sequences to a scalar loss.\n", + " \"\"\"\n", + " if mode == 'predict':\n", + " portal_mask = _PortalInput()\n", + " else:\n", + " portal_mask = None\n", + "\n", + " # Set default dimensions for attention head key and value sizes.\n", + " if (d_model / 2) % n_heads != 0:\n", + " raise ValueError(f'n_heads ({n_heads}) must divide d_model/2 ({d_model/2})')\n", + " if d_attention_key is None:\n", + " d_attention_key = d_model // n_heads\n", + " if d_attention_value is None:\n", + " d_attention_value = d_model // n_heads\n", + "\n", + " # Set values of d_model, d_ff and d_qkv for the first stage.\n", + " d_model1, d_ff1 = d_model, d_ff\n", + " d_attention_key1, d_attention_value1 = d_attention_key, d_attention_value\n", + " if half_before_layer:\n", + " d_model1, d_ff1 = d_model / 2, d_ff / 2\n", + " d_attention_key1 = d_attention_key / 2\n", + " d_attention_value1 = d_attention_value / 2\n", + "\n", + " # Set values of d_model, d_ff and d_qkv for the final stage.\n", + " d_model2, d_ff2 = d_model, d_ff\n", + " d_attention_key2, d_attention_value2 = d_attention_key, d_attention_value\n", + " if double_after_layer:\n", + " d_model2, d_ff2 = d_model * 2, d_ff * 2\n", + " d_attention_key2 = d_attention_key * 2\n", + " d_attention_value2 = d_attention_value * 2\n", + "\n", + " # Vector embeddings.\n", + " in_encoder, out_encoder, output_vocab_size = (\n", + " EmbeddingAndPositionalEncodings(\n", + " input_vocab_size,\n", + " d_model1,\n", + " mode,\n", + " dropout,\n", + " [-2], # dropout_shared_axes\n", + " max_len,\n", + " output_vocab_size=output_vocab_size,\n", + " pos_type=pos_type,\n", + " pos_axial_shape=pos_axial_shape,\n", + " pos_d_axial_embs=pos_d_axial_embs,\n", + " pos_start_from_zero_prob=pos_start_from_zero_prob,\n", + " pos_max_offset_to_add=pos_max_offset_to_add,\n", + " use_bfloat16=use_bfloat16)\n", + " )\n", + "\n", + " def _EncoderBlock():\n", + " return EncoderBlock(\n", + " d_model1,\n", + " d_ff1,\n", + " n_heads,\n", + " encoder_attention_type,\n", + " dropout=dropout,\n", + " ff_activation=ff_activation,\n", + " ff_dropout=ff_dropout,\n", + " ff_use_sru=ff_use_sru,\n", + " ff_chunk_size=ff_chunk_size,\n", + " ff_sparsity=ff_sparsity,\n", + " attention_chunk_size=attention_chunk_size,\n", + " center_layernorm=center_layernorm,\n", + " use_bfloat16=use_bfloat16,\n", + " use_two_swaps_per_block=use_two_swaps_per_encoder_block,\n", + " mode=mode)\n", + "\n", + " def _Encoder(): # vec_e mask_e tok_e tok_d tok_d\n", + " layers = [\n", + " tl.ReversibleSelect([0, 0]),\n", + " _ReversibleSerialForget(\n", + " [_EncoderBlock() for _ in range(n_encoder_layers)],\n", + " d_model1,\n", + " n_layers_forget,\n", + " forget_dense)\n", + " ]\n", + " if not reversible_encoder:\n", + " layers += [\n", + " _XYAvg(),\n", + " tl.Dense(d_model1, use_bfloat16=use_bfloat16),\n", + " tl.LayerNorm(),\n", + " ]\n", + " if mode == 'predict':\n", + " return tl.Cache(tl.Serial(layers))\n", + " else:\n", + " return tl.Serial(layers)\n", + "\n", + " if mode == 'predict':\n", + " global DotProductCausalAttention\n", + " DotProductCausalAttention.monkey_patched_mask = (\n", + " lambda x: portal_mask)\n", + " global _RememberPad\n", + " _RememberPad.monkey_patched_mask = ( # pylint: disable=protected-access\n", + " lambda x: portal_mask)\n", + " global ScanSRUCell\n", + " originalScanSRUCell = ScanSRUCell\n", + " ScanSRUCell = functools.partial(ScanSRUCell,\n", + " monkey_patched_mask=portal_mask)\n", + "\n", + " decoder_blocks = []\n", + "\n", + " if isinstance(encoder_decoder_attention_type, (tuple, list)):\n", + " assert n_decoder_layers % len(encoder_decoder_attention_type) == 0\n", + " else:\n", + " encoder_decoder_attention_type = [encoder_decoder_attention_type]\n", + " for layer_idx in range(n_decoder_layers):\n", + " layer_attention_type = encoder_decoder_attention_type[\n", + " layer_idx % len(encoder_decoder_attention_type)]\n", + " # Grow d_model, d_ff, and d_qkv if requested.\n", + " d_m, d_f, d_k, d_v = d_model1, d_ff1, d_attention_key1, d_attention_value1\n", + " if half_before_layer and layer_idx \u003e= half_before_layer:\n", + " d_m, d_f, d_k, d_v = d_model, d_ff, d_attention_key, d_attention_value\n", + " if double_after_layer and layer_idx \u003e double_after_layer:\n", + " d_m, d_f, d_k, d_v = d_model2, d_ff2, d_attention_key2, d_attention_value2\n", + " decoder_block = DecoderBlock(\n", + " d_m, d_f, d_k, d_v, n_heads,\n", + " attention_type=layer_attention_type,\n", + " dropout=dropout,\n", + " ff_activation=ff_activation,\n", + " ff_dropout=ff_dropout,\n", + " ff_use_sru=ff_use_sru,\n", + " ff_chunk_size=ff_chunk_size,\n", + " ff_sparsity=ff_sparsity,\n", + " attention_chunk_size=attention_chunk_size,\n", + " n_attention_layers=n_decoder_attention_layers,\n", + " center_layernorm=center_layernorm,\n", + " use_bfloat16=use_bfloat16,\n", + " mode=mode)\n", + " decoder_blocks.append(decoder_block)\n", + " if half_before_layer and layer_idx == half_before_layer - 1:\n", + " decoder_blocks.append(tl.ReversibleConcatenatePair())\n", + " if double_after_layer and layer_idx == double_after_layer:\n", + " decoder_blocks.append(tl.ReversibleConcatenatePair())\n", + "\n", + " if mode == 'predict':\n", + " # After initializing the decoder we can revert to original state of\n", + " # previously monkey-patched classes/functions.\n", + " DotProductCausalAttention.monkey_patched_mask = (\n", + " lambda x: None)\n", + " _RememberPad.monkey_patched_mask = (lambda x: None) # pylint: disable=protected-access\n", + " ScanSRUCell = originalScanSRUCell\n", + "\n", + " def _Loss():\n", + " return SparseDenseWithOptions(\n", + " output_vocab_size,\n", + " d_input=d_model2,\n", + " sparsity_type=loss_sparsity_type,\n", + " sparsity=loss_sparsity,\n", + " d_lowrank=loss_d_lowrank,\n", + " prob_sparse=loss_sparsity_prob,\n", + " use_bfloat16=use_bfloat16,\n", + " mode=mode)\n", + "\n", + " def _enc_dec_concat():\n", + " \"\"\"Layers to merge encoder and decoder.\"\"\"\n", + " if reversible_encoder:\n", + " return [\n", + " tl.ReversibleSelect([0, 1, 4, 2, 3]), # v_e v_d mask_e tok_e tok_d\n", + " ConcatWithPadding2(mode=mode), # v_ed v_ed tok_e tok_d\n", + " ]\n", + " else:\n", + " return [\n", + " tl.ReversibleSelect([0, 3, 1, 2]), # v_e v_d mask_e tok_e tok_d\n", + " ConcatWithPadding(mode=mode), # v_ed tok_e tok_d\n", + " tl.ReversibleSelect([0, 0]), # v_ed v_ed tok_e tok_d\n", + " ]\n", + "\n", + " def _inp_layers():\n", + " if input_vocab_size is not None:\n", + " return tl.AssertFunction(\n", + " 'bl,br-\u003ebld,bl,bl,br', # b: batch, l/r: enc/dec length, d: vec depth\n", + " tl.Serial( # tok_e tok_d\n", + " tl.Select([0, 0, 0, 1]),\n", + " tl.Parallel(in_encoder, [tl.PaddingMask(),\n", + " _RemoveAxes12()])\n", + " )) # vec_e mask_e tok_e tok_d\n", + " else:\n", + " # Input in this case is vec_e, mask_e, tok_d. Where all downstream\n", + " # operations expect tok_e, we give it instead mask_e, expecting that\n", + " # downstream ops only are looking for padding/not padding.\n", + " return tl.AssertFunction(\n", + " 'blf,bl,br-\u003ebld,bl,bl,br', # f: in-feature depth, d: out-vector depth\n", + " tl.Serial( # vec_e mask_e tok_d\n", + " tl.Select([0, 1, 1, 2]),\n", + " tl.Parallel(in_encoder, [], _AsTokenIDs())\n", + " )) # vec_e mask_e tok_e tok_d\n", + "\n", + " # Assemble and return the model.\n", + " return tl.Serial(\n", + " _inp_layers(), # vec_e mask_e tok_e tok_d\n", + " tl.Parallel([], portal_mask),\n", + "\n", + " tl.Select([0, 1, 2, 3, 3]), # Copy decoder tokens for use in loss.\n", + "\n", + " # Embed in and out tokens; done together as weights may be shared.\n", + " tl.Parallel([], [], [], [tl.ShiftRight(mode=mode),\n", + " out_encoder]), # vec_e mask_e tok_e vec_d tok_d\n", + "\n", + " # Encode; then concat encoder and decoder, given encoder mask.\n", + " _Encoder(), # vec_e mask_e tok_e vec_d tok_d\n", + " _enc_dec_concat(),\n", + "\n", + " # Run decoder blocks.\n", + " _ReversibleSerialForget(decoder_blocks, d_model2, n_layers_forget,\n", + " forget_dense), # vec_ed1 vec_ed2 tok_e tok_d\n", + " _XYAvg(), # vec_ed tok_e tok_d\n", + " tl.LayerNorm(), # vec_ed tok_e tok_d\n", + "\n", + " # Separate out the encoder part from the concatenated vector,\n", + " # then compute loss.\n", + " tl.Select([0, 1, 2, 2]), # vec_ed tok_e tok_d tok_d\n", + " StripFromConcatenateWithPadding(mode=mode), # vec_d tok_d\n", + " _Loss(), # vec_d tok_d\n", + " )\n", + "\n", + "\n", + "def _InsertAxes12():\n", + " \"\"\"Returns a layer that inserts two internal size-1 axes into an array.\"\"\"\n", + " return tl.Fn('InsertAxes12',\n", + " lambda x: jnp.reshape(x, (x.shape[0], 1, 1, x.shape[1])))\n", + "\n", + "\n", + "def _RemoveAxes12():\n", + " \"\"\"Returns a layer that removes two internal size-1 axes from an array.\"\"\"\n", + " return tl.Fn('RemoveAxes12', lambda x: jnp.squeeze(x, (1, 2)))\n", + "\n", + "\n", + "def _AsTokenIDs():\n", + " \"\"\"Returns a layer that makes mask values look like token ID ints.\"\"\"\n", + " return tl.Fn('AsTokenIDs', lambda x: x.astype(jnp.int32))\n", + "\n", + "\n", + "def _XYAvg():\n", + " \"\"\"Returns a layer that computes the element-wise average of two arrays.\"\"\"\n", + " return tl.Fn('XYAvg', lambda x, y: (x + y) / 2.0)\n", + "\n", + "\n", + "def _ReversibleSerialForget(layers, d_model, n_layers, forget_dense=True):\n", + " \"\"\"ReversibleSerial but with a forgetting block every n_layers.\"\"\"\n", + " if not n_layers or len(layers) \u003c= n_layers + 1:\n", + " return tl.ReversibleSerial(layers)\n", + " layers1, layers2 = layers[:n_layers], layers[n_layers:]\n", + "\n", + " if forget_dense:\n", + " forgetting_layer = tl.Serial(\n", + " _XYAvg(),\n", + " tl.Dense(d_model),\n", + " tl.Dup(),\n", + " )\n", + " else:\n", + " forgetting_layer = tl.Select([0, 1])\n", + "\n", + " return tl.Serial(\n", + " tl.ReversibleSerial(layers1),\n", + " forgetting_layer,\n", + " _ReversibleSerialForget(layers2, d_model, n_layers, forget_dense)\n", + " )\n", + "\n", + "\n", + "def _ConvertToNaNsOnAnyZero():\n", + " def _convert_to_nans(x, y):\n", + " # if all values in y are non-zeros, return x; otherwise return 0s\n", + " return jnp.where(jnp.all(y, keepdims=False), x, x/0.), y\n", + " return tl.Fn('ConvertToNaNsOnAnyZero', _convert_to_nans, n_out=2)\n", + "\n", + "\n", + "class _PortalInput(tl.Layer):\n", + " \"\"\"Portal input for monkey-patching of mask in predict mode.\"\"\"\n", + "\n", + " def __init__(self):\n", + " super().__init__(name='_PortalInput', n_out=1, n_in=1)\n", + " self._portal_output = _PortalOutput(self)\n", + "\n", + " def forward(self, x):\n", + " if isinstance(x, (list, tuple)):\n", + " x = x[0]\n", + " self.state = (x,)\n", + " return x\n", + "\n", + " def init_weights_and_state(self, input_signature):\n", + " \"\"\"Initializes this layer's weights.\"\"\"\n", + " if isinstance(input_signature, (list, tuple)):\n", + " input_signature = input_signature[0]\n", + " self.state = (jnp.zeros(input_signature.shape),)\n", + "\n", + " def get_value(self):\n", + " return self.state[0]\n", + "\n", + " def get_layer(self):\n", + " return self._portal_output\n", + "\n", + "\n", + "class _PortalOutput(tl.Layer):\n", + " \"\"\"Portal input for monkey-patching of mask in predict mode.\"\"\"\n", + "\n", + " def __init__(self, portal_input):\n", + " super().__init__(name='_PortalOutput', n_out=1, n_in=0)\n", + " self._portal_input = portal_input\n", + "\n", + " def forward(self, x):\n", + " return self._portal_input.get_value()\n", + "\n", + " def get_value(self):\n", + " return self._portal_input.get_value()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "E0Rq71ML6XZu" + }, + "source": [ + "## Example training\n", + "\n", + "Here we show how the Terraformer can be trained on example inputs. The results for the paper were obtained with identical training but for different configurations of inputs and models, which are specified in the attached config files." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "oI5XQcltJmeE" + }, + "outputs": [], + "source": [ + "model = Terraformer(\n", + " input_vocab_size=12,\n", + " # small model for testing\n", + " d_model=128,\n", + " d_ff=512,\n", + " n_encoder_layers=2,\n", + " n_decoder_layers=2,\n", + " # setting sparsity\n", + " ff_use_sru=(1, 32),\n", + " ff_sparsity=32,\n", + " loss_sparsity=4,\n", + " encoder_decoder_attention_type=functools.partial(\n", + " MultiplicativeConvCausalAttention, sparsity=16, length_kernel_size=3),\n", + " )\n", + "\n", + "copy_inputs = trax.data.inputs.simple_sequence_copy_inputs(\n", + " vocab_size=10, batch_size=32, train_length=32,\n", + " eval_min_length=16, eval_max_length=32)\n", + "\n", + "# Training task.\n", + "train_task = training.TrainTask(\n", + " labeled_data=copy_inputs.train_stream(1),\n", + " loss_layer=tl.WeightedCategoryCrossEntropy(),\n", + " optimizer=trax.optimizers.Adam(0.0001),\n", + " n_steps_per_checkpoint=5,\n", + ")\n", + "\n", + "# Evaluaton task.\n", + "eval_task = training.EvalTask(\n", + " labeled_data=copy_inputs.eval_stream(1),\n", + " metrics=[tl.WeightedCategoryCrossEntropy(), tl.WeightedCategoryAccuracy()],\n", + " n_eval_batches=2 # For less variance in eval numbers.\n", + ")\n", + "\n", + "# Training loop saves checkpoints to output_dir.\n", + "output_dir = os.path.expanduser('~/output_dir/')\n", + "!rm -rf {output_dir}\n", + "training_loop = training.Loop(model,\n", + " train_task,\n", + " eval_tasks=[eval_task],\n", + " output_dir=output_dir)\n", + "\n", + "# Run 2000 steps (batches).\n", + "training_loop.run(20)" + ] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "collapsed_sections": [], + "name": "Terraformer from scratch.ipynb", + "private_outputs": true, + "provenance": [ + { + "file_id": "1mdBTceBJGE_yff5FvRAByrisUsc88Nw7", + "timestamp": 1635190861529 + } + ], + "toc_visible": true + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +}