diff --git a/tutorials/transformer_tutorials/README.md b/tutorials/transformer_tutorials/README.md new file mode 100644 index 0000000..dbc448e --- /dev/null +++ b/tutorials/transformer_tutorials/README.md @@ -0,0 +1,10 @@ +# Transformer Tutorials + +Jupiter notebooks showing how to quantise and compress toy transformer encoder and encoder-decoder models. + +# Tutorials + +* ViT_PCQAT.ipynb - Shows PCQAT using TFMOT for transformer encoder models. +* ViT_2x4-PQAT.ipynb - Shows new 2x4 pruning and QAT with TFMOT for transformer encoder models. +* translation.ipynb - Shows how to do QAT with TFMOT for encoder-decoder models. +* translation_PQAT.ipynb - Shows how to apply pruning & QAT with TFMOT for encoder-decoder models. diff --git a/tutorials/transformer_tutorials/ViT_2x4-PQAT.ipynb b/tutorials/transformer_tutorials/ViT_2x4-PQAT.ipynb new file mode 100755 index 0000000..80d983a --- /dev/null +++ b/tutorials/transformer_tutorials/ViT_2x4-PQAT.ipynb @@ -0,0 +1,1005 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "49d3aa7b", + "metadata": {}, + "source": [ + "# 2:4 structured PQAT on Vision Transformer using TFMOT\n", + "\n", + "Example notebook to demonstrate how TFMOT can be used for applying 2:4 structured pruning and QAT on a ViT model." + ] + }, + { + "cell_type": "markdown", + "id": "69f4b77f", + "metadata": {}, + "source": [ + "## Background\n", + "\n", + "The [Vision Transformer (ViT)](https://arxiv.org/pdf/2010.11929.pdf) architecture uses stacked transformer encoder blocks to process images for certain tasks. The encoder blocks are architecturally similar to the popular [NLP transformers](https://arxiv.org/pdf/1706.03762.pdf). The inputs to the transformer encoders are embeddings of patches extracted from the image. For a classification task, an additional feed forward network is added to the end.\n", + "\n", + "\"Vision\n", + "\n", + "In this notebook:\n", + "1. Firstly a ViT model is created and trained from scratch on the MNIST dataset. In practice, pre-trained weights can also be loaded.\n", + "2. Afterwards, 2:4 structured weight pruning and quantisation aware training (QAT) techniques are applied sequentially using the collaborative optimisation features of the [TensorFlow Model Optimization Toolkit (TFMOT)](https://www.tensorflow.org/model_optimization).\n", + "3. Finally, an integer-only TFLite model is generated and tested.\n", + "\n", + "### 2:4 structured pruning\n", + "\n", + "In [2:4 structured pruning](https://arxiv.org/pdf/2104.08378.pdf), a constraint is set on the weights during training to ensure at least 2 parameters are zero in each block of 4. Pruning weights with a 2:4 structure leads to a more compact representation and enables hardware acceleration during inference.\n", + "\n", + "\"A" + ] + }, + { + "cell_type": "markdown", + "id": "e1c87fa4", + "metadata": {}, + "source": [ + "## TFMOT limitations\n", + "- Subclassed models are not supported. Only sequential and functional model definitions are supported. (Pruning, Clustering & QAT)\n", + "- Custom subclassed layers are not supported. (Clustering & QAT)\n", + " - Clustering will only work with subclassed layers if the weight variables you have to cluster are not nested within another layer (e.g. MHA).\n", + " - QAT works correctly if the subclassed layer performs only 1 operation.\n", + "- Low-level tensorflow operators such as `tf.linalg.matmul` are not supported. (Only for QAT)\n", + " - QAT expects all quantised layers to be a subclass of `tf.keras.layers.Layer`." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "405d90d4", + "metadata": {}, + "outputs": [], + "source": [ + "import math\n", + "import numpy as np\n", + "import tensorflow as tf\n", + "import tensorflow_model_optimization as tfmot\n", + "\n", + "tf.random.set_seed(0)\n", + "\n", + "print('TensorFlow version: {}'.format(tf.__version__))\n", + "print('TFMOT version: {}'.format(tfmot.__version__))" + ] + }, + { + "cell_type": "markdown", + "id": "7bee1a89", + "metadata": {}, + "source": [ + "## Model definition\n", + "\n", + "Due to the above-mentioned limitations, custom Keras layers must be defined for all of the low-level TensorFlow operators in order to perform QAT (each layer must only contain a single operation).\n", + "\n", + "Since none of these will have any prunable weights, first we create a base prunable layer class to extend, instead of `tf.keras.layers.Layer`. If any of the weights in the custom layers should be pruned, a list of the weights should be provided in the `get_prunable_weights` method. Refer to the TFMOT documentation for more details." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ef0735bd", + "metadata": {}, + "outputs": [], + "source": [ + "class PrunableLayer(tf.keras.layers.Layer,\n", + " tfmot.sparsity.keras.PrunableLayer):\n", + " def get_prunable_weights(self): return []" + ] + }, + { + "cell_type": "markdown", + "id": "3b617997", + "metadata": {}, + "source": [ + "### 1. Define each of the TensorFlow operations ViT uses as a Keras subclassed layer:\n", + "\n", + "Note that some of these layers have trainable weights defined using the `add_weight` method. These weights will not be pruned." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "89a90471", + "metadata": {}, + "outputs": [], + "source": [ + "class MatMul(PrunableLayer):\n", + " def __init__(self, transpose_b=False, **kwargs):\n", + " super().__init__(**kwargs)\n", + " self.transpose_b = transpose_b\n", + "\n", + " def call(self, inputs):\n", + " return tf.linalg.matmul(*inputs, transpose_b=self.transpose_b)\n", + "\n", + " def get_config(self):\n", + " config = super().get_config()\n", + " config.update({'transpose_b': self.transpose_b})\n", + " return config\n", + "\n", + "class Multiply(PrunableLayer):\n", + " def call(self, inputs):\n", + " return tf.multiply(*inputs)\n", + "\n", + "# Calling Multiply with a scalar input will lead to an error.\n", + "# Use the following ScalarMultiply class instead.\n", + "class ScalarMultiply(PrunableLayer):\n", + " def __init__(self, scalar, **kwargs):\n", + " super().__init__(**kwargs)\n", + " self.scalar = scalar\n", + "\n", + " def call(self, x):\n", + " return tf.math.multiply(x, self.scalar)\n", + "\n", + " def get_config(self):\n", + " config = super().get_config()\n", + " config.update({'scalar': self.scalar})\n", + " return config\n", + "\n", + "class Add(PrunableLayer):\n", + " def call(self, inputs):\n", + " return tf.math.add(*inputs)\n", + "\n", + "# Calling Add with a scalar input will lead to an error.\n", + "# Use the following ScalarAdd class instead.\n", + "class ScalarAdd(PrunableLayer):\n", + " def __init__(self, scalar, **kwargs):\n", + " super().__init__(**kwargs)\n", + " self.scalar = scalar\n", + "\n", + " def call(self, x):\n", + " return tf.math.add(x, self.scalar)\n", + "\n", + " def get_config(self):\n", + " config = super().get_config()\n", + " config.update({'scalar': self.scalar})\n", + " return config\n", + "\n", + "class Slice(PrunableLayer):\n", + " def __init__(self, seq_idx, **kwargs):\n", + " super().__init__(**kwargs)\n", + " self.seq_idx = seq_idx\n", + "\n", + " def call(self, x):\n", + " return x[:, self.seq_idx, ...]\n", + "\n", + " def get_config(self):\n", + " config = super().get_config()\n", + " config.update({'seq_idx': self.seq_idx})\n", + " return config\n", + "\n", + "class Mean(PrunableLayer):\n", + " def __init__(self, axes=None, keepdims=True, **kwargs):\n", + " super().__init__(**kwargs)\n", + " self.axes=axes\n", + " self.keepdims = keepdims\n", + "\n", + " def call(self, x):\n", + " return tf.math.reduce_mean(x, axis=self.axes, keepdims=self.keepdims)\n", + "\n", + " def get_config(self):\n", + " config = super().get_config()\n", + " config.update({'axes': self.axes,\n", + " 'keepdims': self.keepdims})\n", + " return config\n", + "\n", + "class Subtract(PrunableLayer):\n", + " def call(self, inputs):\n", + " return tf.math.subtract(*inputs)\n", + "\n", + "class StopGradient(PrunableLayer):\n", + " def call(self, x):\n", + " return tf.stop_gradient(x)\n", + "\n", + "class RSqrt(PrunableLayer):\n", + " def call(self, x):\n", + " return tf.math.rsqrt(x)\n", + "\n", + "class ClipMin(PrunableLayer):\n", + " def __init__(self, min_val=0, **kwargs):\n", + " super().__init__(**kwargs)\n", + " self.min_val = min_val\n", + "\n", + " def call(self, x):\n", + " return tf.math.maximum(x, self.min_val)\n", + "\n", + " def get_config(self):\n", + " config = super().get_config()\n", + " config.update({'min_val': self.min_val})\n", + " return config\n", + "\n", + "class BroadcastToken(PrunableLayer):\n", + " \"\"\"Layer to broadcast the class token\"\"\"\n", + " def __init__(self, embedding_dim, **kwargs):\n", + " super().__init__(**kwargs)\n", + " self.embedding_dim = embedding_dim\n", + "\n", + " def build(self, input_shape):\n", + " self.w = self.add_weight(shape=(1, 1, self.embedding_dim), initializer='zeros', \n", + " trainable=True, name='token')\n", + " super().build(input_shape)\n", + "\n", + " def call(self, x):\n", + " batch_size = tf.shape(x)[0]\n", + " return tf.broadcast_to(self.w, [batch_size, 1, self.embedding_dim])\n", + "\n", + " def get_config(self):\n", + " config = super().get_config()\n", + " config.update({'embedding_dim': self.embedding_dim})\n", + " return config\n", + "\n", + "class AddPositionalEmbedding(PrunableLayer):\n", + " \"\"\"Layer to add positional embeddings to the tokens\"\"\"\n", + " def __init__(self, seq_len, embedding_dim, **kwargs):\n", + " super().__init__(**kwargs)\n", + " self.embedding_dim = embedding_dim\n", + " self.seq_len = seq_len\n", + "\n", + " def build(self, input_shape):\n", + " self.w = self.add_weight(shape=(1, self.seq_len, self.embedding_dim), initializer=None,\n", + " trainable=True, name='pos_emb')\n", + " super().build(input_shape)\n", + "\n", + " def call(self, x):\n", + " return x + self.w\n", + "\n", + " def get_config(self):\n", + " config = super().get_config()\n", + " config.update({'embedding_dim': self.embedding_dim, 'seq_len': self.seq_len})\n", + " return config\n", + "\n", + "class Scale(PrunableLayer):\n", + " \"\"\"Multiply with gamma (LayerNorm)\"\"\"\n", + " def __init__(self, axes, **kwargs):\n", + " super().__init__(**kwargs)\n", + " self.axes = axes\n", + "\n", + " def build(self, input_shape):\n", + " param_shape = [input_shape[dim] for dim in self.axes]\n", + " self.w = self.add_weight(name='gamma', shape=param_shape,\n", + " trainable=True, initializer='ones')\n", + " super().build(input_shape)\n", + "\n", + " def call(self, x):\n", + " return tf.multiply(x, self.w)\n", + "\n", + " def get_config(self):\n", + " config = super().get_config()\n", + " config.update({'axes': self.axes})\n", + " return config\n", + "\n", + "class Centre(PrunableLayer):\n", + " \"\"\"Add beta (LayerNorm)\"\"\"\n", + " def __init__(self, axes, **kwargs):\n", + " super().__init__(**kwargs)\n", + " self.axes = axes\n", + "\n", + " def build(self, input_shape):\n", + " param_shape = [input_shape[dim] for dim in self.axes]\n", + " self.w = self.add_weight(name='beta', shape=param_shape,\n", + " trainable=True, initializer='zeros')\n", + " super().build(input_shape)\n", + "\n", + " def call(self, x):\n", + " return tf.math.add(x, self.w)\n", + "\n", + " def get_config(self):\n", + " config = super().get_config()\n", + " config.update({'axes': self.axes})\n", + " return config" + ] + }, + { + "cell_type": "markdown", + "id": "9c9a97c7", + "metadata": {}, + "source": [ + "### 2. Now that these low-level operators are defined as Keras layers, we can start writing ViT layers such as multi-head attention or layer normalisation functionally:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a79da268", + "metadata": {}, + "outputs": [], + "source": [ + "Tanh = tf.keras.layers.Activation('tanh')\n", + "\n", + "def patch_encoder(inp, patch_size, num_patches, embedding_dim):\n", + " \"\"\"\n", + " Patch encoder layer, extracts patches from the image, flattens them \n", + " and adds the class token and positional embedding vectors.\n", + " \"\"\"\n", + " x = tf.keras.layers.Conv2D(filters=embedding_dim, kernel_size=patch_size,\n", + " strides=patch_size, name='patch_encoder/conv2d')(inp)\n", + " x = tf.keras.layers.Reshape((num_patches, embedding_dim))(x)\n", + "\n", + " # add the class token\n", + " cls_token = BroadcastToken(embedding_dim=embedding_dim, name='patch_encoder/cls_token')(inp)\n", + " x = tf.keras.layers.Concatenate(axis=1)([cls_token, x])\n", + "\n", + " x = AddPositionalEmbedding(seq_len=(num_patches + 1), # +1 for the class token\n", + " embedding_dim=embedding_dim,\n", + " name='patch_encoder/add_pos_emb')(x)\n", + " return x\n", + "\n", + "def self_attention(x, n_heads, dim, name='mha'):\n", + " \"\"\"Multi-head attention layer\"\"\"\n", + " depth = dim // n_heads\n", + "\n", + " q = tf.keras.layers.Dense(units=dim, name=f'{name}/query')(x)\n", + " k = tf.keras.layers.Dense(units=dim, name=f'{name}/key')(x)\n", + " v = tf.keras.layers.Dense(units=dim, name=f'{name}/value')(x)\n", + "\n", + " q = tf.keras.layers.Reshape((-1, n_heads, depth))(q)\n", + " q = tf.keras.layers.Permute((2, 1, 3))(q)\n", + " k = tf.keras.layers.Reshape((-1, n_heads, depth))(k)\n", + " k = tf.keras.layers.Permute((2, 1, 3))(k)\n", + " v = tf.keras.layers.Reshape((-1, n_heads, depth))(v)\n", + " v = tf.keras.layers.Permute((2, 1, 3))(v)\n", + "\n", + " qk = ScalarMultiply(depth ** -0.5)(MatMul(transpose_b=True)([q, k]))\n", + " attn_weights = tf.keras.layers.Softmax(axis=-1)(qk)\n", + "\n", + " attn_out = MatMul()([attn_weights, v]) \n", + " attn_out = tf.keras.layers.Permute((2, 1, 3))(attn_out)\n", + " attn_out = tf.keras.layers.Reshape((-1, dim))(attn_out)\n", + " out = tf.keras.layers.Dense(dim, name=f'{name}/output_dense')(attn_out)\n", + "\n", + " return out\n", + "\n", + "def layer_norm(x, axes=2, epsilon=0.001, name='layer_norm', trainable=True):\n", + " \"\"\"LayerNormalization\"\"\"\n", + " if isinstance(axes, int): axes = [axes]\n", + "\n", + " mean = Mean(axes=axes)(x)\n", + " ## This block can be replaced with a squared_difference layer ##\n", + " diff = Subtract()([x, StopGradient()(mean)]) ##\n", + " sq_diff = Multiply()([diff, diff]) ##\n", + " ## ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ ##\n", + " variance = Mean(axes=axes, name=f'{name}/variance')(sq_diff)\n", + " if not trainable:\n", + " inv = RSqrt()(variance)\n", + " x = Multiply()([diff, inv])\n", + " else:\n", + " inv = RSqrt()(ClipMin(min_val=epsilon)(variance)) # ClipMin prevents division by 0.\n", + " x = Subtract(name=f'{name}/grad_subtract')([x, mean]) # This layer is removed for inference so it is named.\n", + " x = Multiply()([x, inv])\n", + "\n", + " x = Scale(axes=axes)(x)\n", + " x = Centre(axes=axes)(x)\n", + "\n", + " return x\n", + "\n", + "def gelu(x):\n", + " \"\"\"Functional definition of approximate GELU with Keras layers\"\"\"\n", + " res = Add()([x, ScalarMultiply(0.044715)(Multiply()([x, Multiply()([x, x])]))])\n", + " res = ScalarAdd(1.0)(Tanh(ScalarMultiply(math.sqrt(2 / math.pi))(res)))\n", + " res = ScalarMultiply(0.5)(res)\n", + " res = Multiply()([x, res])\n", + " return res\n", + "\n", + "def mlp(x, hidden_dim, out_dim):\n", + " \"\"\"Multi-layer perceptron block\"\"\"\n", + " x = tf.keras.layers.Dense(units=hidden_dim)(x)\n", + " x = gelu(x)\n", + " x = tf.keras.layers.Dense(units=out_dim)(x)\n", + " return x" + ] + }, + { + "cell_type": "markdown", + "id": "3c1db26f", + "metadata": {}, + "source": [ + "### 3. Full functional model definition:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "50e42892", + "metadata": {}, + "outputs": [], + "source": [ + "def get_vision_transformer(input_shape,\n", + " n_classes,\n", + " patch_size,\n", + " embedding_dim,\n", + " n_layers,\n", + " n_attention_heads,\n", + " mlp_hidden_dim,\n", + " trainable=True):\n", + " \"\"\"\n", + " Args:\n", + " input_shape (tuple): Shape of the inputs, including the batch size.\n", + " n_classes (int): Number of classes in the dataset.\n", + " patch_size (int / tuple of ints): Size of the patches to extract from the images.\n", + " embedding_dim (int): Size of the embedded patch vectors.\n", + " n_layers (int): Number of transformer encoder layers.\n", + " n_attention_heads (int): Number of attention heads.\n", + " mlp_hidden_dim (int): Hidden layer size for the intermediate MLPs.\n", + "\n", + " Returns:\n", + " model (tf.keras.Model): The Keras model.\n", + " \"\"\"\n", + "\n", + " if isinstance(patch_size, int): patch_size = (patch_size, patch_size)\n", + "\n", + " # Calculate the number of patches\n", + " num_patches = (input_shape[1] * input_shape[2]) // (patch_size[0] * patch_size[1])\n", + "\n", + " inp = tf.keras.layers.Input(shape=input_shape[1:], batch_size=input_shape[0], name='image')\n", + "\n", + " # Patch encoder layer\n", + " x = patch_encoder(inp, patch_size, num_patches, embedding_dim)\n", + "\n", + " for block in range(n_layers):\n", + " # Attention block\n", + " x1 = layer_norm(x, name=(f'layer_norm_{2 * block}' if block != 0 else 'layer_norm'), trainable=trainable)\n", + " x1 = self_attention(x1, n_attention_heads, embedding_dim, name=(f'mha_{block}' if block != 0 else 'mha'))\n", + " x1 = tf.keras.layers.Add()([x1, x])\n", + "\n", + " # MLP block\n", + " x2 = layer_norm(x1, name=f'layer_norm_{2 * block + 1}', trainable=trainable)\n", + " x2 = mlp(x2, mlp_hidden_dim, embedding_dim)\n", + " x = tf.keras.layers.Add()([x2, x1])\n", + "\n", + " x = layer_norm(x, name=f'layer_norm_{2 * block + 2}', trainable=trainable)\n", + "\n", + " ## ~ Classification head ~ ##\n", + " cls_head = Slice(0)(x)\n", + " out = tf.keras.layers.Dense(n_classes, kernel_initializer='zeros', name='cls_head')(cls_head)\n", + " ## ~~~~~~~~~~~~~~~~~~~~~~~ ##\n", + "\n", + " model = tf.keras.Model(inputs=inp, outputs=out)\n", + "\n", + " return model" + ] + }, + { + "cell_type": "markdown", + "id": "ed35c8ee", + "metadata": {}, + "source": [ + "## Training" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d3efed1b", + "metadata": {}, + "outputs": [], + "source": [ + "BATCH_SIZE = 32\n", + "\n", + "# Load the MNIST dataset\n", + "(X_train, y_train), (X_test, y_test) = tf.keras.datasets.mnist.load_data()\n", + "X_train, X_test = (X_train[..., np.newaxis] / 255.0), (X_test[..., np.newaxis] / 255.0)\n", + "train_ds = tf.data.Dataset.from_tensor_slices((X_train, y_train)).shuffle(1000) \\\n", + " .batch(BATCH_SIZE, drop_remainder=True) \\\n", + " .prefetch(tf.data.AUTOTUNE)\n", + "test_ds = tf.data.Dataset.from_tensor_slices((X_test, y_test)).batch(BATCH_SIZE, drop_remainder=True) \\\n", + " .prefetch(tf.data.AUTOTUNE)\n", + "\n", + "def compile_and_fit(model, **kwargs):\n", + " model.compile(optimizer=\"adam\",\n", + " loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n", + " metrics=[\"accuracy\"])\n", + " model.fit(train_ds, validation_data=test_ds, **kwargs)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "05a2cd31", + "metadata": {}, + "outputs": [], + "source": [ + "model = get_vision_transformer(input_shape=(BATCH_SIZE, 28, 28, 1), # (batch_size, height, width, channels)\n", + " n_classes=10,\n", + " patch_size=(4, 4),\n", + " embedding_dim=16,\n", + " n_layers=2,\n", + " n_attention_heads=2,\n", + " mlp_hidden_dim=16)\n", + "\n", + "compile_and_fit(model, epochs=3)" + ] + }, + { + "cell_type": "markdown", + "id": "85113104", + "metadata": {}, + "source": [ + "## 2:4 structured pruning & QAT" + ] + }, + { + "cell_type": "markdown", + "id": "cccafcb2", + "metadata": {}, + "source": [ + "### 1. Apply the pruning API" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "06982c22", + "metadata": {}, + "outputs": [], + "source": [ + "prune_low_magnitude = tfmot.sparsity.keras.prune_low_magnitude\n", + "strip_pruning = tfmot.sparsity.keras.strip_pruning\n", + "\n", + "N_EPOCHS = 2\n", + "pruning_params = {\n", + " 'sparsity_m_by_n': (2, 4)\n", + "}\n", + "pruned_model = prune_low_magnitude(model, **pruning_params)\n", + "# Fine-tune with pruning\n", + "compile_and_fit(pruned_model, epochs=N_EPOCHS, callbacks=[tfmot.sparsity.keras.UpdatePruningStep()])\n", + "stripped_pruned_model = strip_pruning(pruned_model)\n", + "print('Success')" + ] + }, + { + "cell_type": "markdown", + "id": "3d516f37", + "metadata": {}, + "source": [ + "**Warning: The original model is modified after calling [`prune_low_magnitude`](https://www.tensorflow.org/model_optimization/api_docs/python/tfmot/sparsity/keras/prune_low_magnitude).**" + ] + }, + { + "cell_type": "markdown", + "id": "3339a1f8", + "metadata": {}, + "source": [ + "#### 1.1. Check that the weights are 2:4 pruned" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "38f6ed82", + "metadata": {}, + "outputs": [], + "source": [ + "from tensorflow_model_optimization.python.core.sparsity.keras.pruning_utils import is_pruned_m_by_n\n", + "\n", + "def print_2x4_sparsity(model):\n", + " for w in model.weights:\n", + " if (w.shape.rank == 2 or w.shape.rank == 4) and is_pruned_m_by_n(w, m_by_n=(2, 4)):\n", + " print(' ', w.name)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "59c38ae8", + "metadata": {}, + "outputs": [], + "source": [ + "print('2:4 pruned weights:')\n", + "print_2x4_sparsity(stripped_pruned_model)" + ] + }, + { + "cell_type": "markdown", + "id": "2e8dff73", + "metadata": {}, + "source": [ + "### 2. Quantisation-aware training API\n", + "#### 2.1. To use the custom Keras layers we defined, we need to pass a [`QuantizeConfig`](https://www.tensorflow.org/model_optimization/api_docs/python/tfmot/quantization/keras/QuantizeConfig) for each of these layers.\n", + "\n", + "For Keras layers which are already supported in TFMOT, a default `QuantizeConfig` class is assigned to each one. However custom `QuantizeConfig` instances could also be created for these layers to give more control over how they are quantised." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "829dd7be", + "metadata": {}, + "outputs": [], + "source": [ + "from tensorflow_model_optimization.quantization.keras import QuantizeConfig, quantizers\n", + "\n", + "LastValueQuantizer = quantizers.LastValueQuantizer\n", + "MovingAverageQuantizer = quantizers.MovingAverageQuantizer\n", + "AllValuesQuantizer = quantizers.AllValuesQuantizer\n", + "\n", + "class NoOpQuantizeConfig(QuantizeConfig):\n", + " \"\"\"QuantizeConfig which does not quantize any part of the layer.\"\"\"\n", + "\n", + " def get_weights_and_quantizers(self, layer):\n", + " return []\n", + "\n", + " def get_activations_and_quantizers(self, layer):\n", + " return []\n", + "\n", + " def set_quantize_weights(self, layer, quantize_weights):\n", + " pass\n", + "\n", + " def set_quantize_activations(self, layer, quantize_activations):\n", + " pass\n", + "\n", + " def get_output_quantizers(self, layer):\n", + " return []\n", + "\n", + " def get_config(self):\n", + " return {}\n", + "\n", + "class OutputQuantizeConfig(QuantizeConfig):\n", + " \"\"\"QuantizeConfig which only quantizes the output of a layer.\"\"\"\n", + "\n", + " def get_weights_and_quantizers(self, layer):\n", + " return []\n", + "\n", + " def get_activations_and_quantizers(self, layer):\n", + " return []\n", + "\n", + " def set_quantize_weights(self, layer, quantize_weights):\n", + " pass\n", + "\n", + " def set_quantize_activations(self, layer, quantize_activations):\n", + " pass\n", + "\n", + " def get_output_quantizers(self, layer):\n", + " return [MovingAverageQuantizer(num_bits=8, per_axis=False, symmetric=False, narrow_range=False)]\n", + "\n", + " def get_config(self):\n", + " return {}\n", + "\n", + "class WeightQuantizeConfig(QuantizeConfig):\n", + " \"\"\"QuantizeConfig which quantizes the custom weights in the patch encoder and layer normalisation layers.\"\"\"\n", + "\n", + " def __init__(self):\n", + " self.weight_quantizer = LastValueQuantizer(num_bits=8, per_axis=False,\n", + " symmetric=True, narrow_range=True)\n", + " self.activation_quantizer = MovingAverageQuantizer(num_bits=8, per_axis=False,\n", + " symmetric=False, narrow_range=False)\n", + "\n", + " def get_weights_and_quantizers(self, layer):\n", + " return [(layer.w, self.weight_quantizer)]\n", + "\n", + " def get_activations_and_quantizers(self, layer):\n", + " return []\n", + "\n", + " def set_quantize_weights(self, layer, quantize_weights):\n", + " layer.w = quantize_weights[0]\n", + "\n", + " def set_quantize_activations(self, layer, quantize_activations):\n", + " pass\n", + "\n", + " def get_output_quantizers(self, layer):\n", + " return [self.activation_quantizer]\n", + "\n", + " def get_config(self):\n", + " return {}\n", + "\n", + "class VarianceQuantizeConfig(QuantizeConfig):\n", + " \"\"\"QuantizeConfig for the variance calculation in the layer normalisation layer.\"\"\"\n", + "\n", + " def get_weights_and_quantizers(self, layer):\n", + " return []\n", + "\n", + " def get_activations_and_quantizers(self, layer):\n", + " return []\n", + "\n", + " def set_quantize_weights(self, layer, quantize_weights):\n", + " pass\n", + "\n", + " def set_quantize_activations(self, layer, quantize_activations):\n", + " pass\n", + "\n", + " def get_output_quantizers(self, layer):\n", + " return [AllValuesQuantizer(num_bits=8, per_axis=False, symmetric=False, narrow_range=False)]\n", + "\n", + " def get_config(self):\n", + " return {}" + ] + }, + { + "cell_type": "markdown", + "id": "587cf644", + "metadata": {}, + "source": [ + "Since custom layers and `QuantizeConfig`s are used, the whole model cannot directly be wrapped with QAT wrappers.
\n", + "So first we write a function to wrap the individual layers with QAT wrappers:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8d5a7153", + "metadata": {}, + "outputs": [], + "source": [ + "def apply_wrapper(wrapper_function, layer_param_dict):\n", + "\n", + " def wrap_layer(layer):\n", + " if layer.name in layer_param_dict.keys():\n", + " return wrapper_function(layer, **layer_param_dict[layer.name])\n", + " return layer\n", + "\n", + " return wrap_layer\n", + "\n", + "def layer_wrapper(model, wrapper_function, layer_param_dict):\n", + " return tf.keras.models.clone_model(model, clone_function=apply_wrapper(wrapper_function, layer_param_dict))" + ] + }, + { + "cell_type": "markdown", + "id": "1a12706a", + "metadata": {}, + "source": [ + "The custom layers should be quantized with the following `QuantizeConfig` classes:\n", + "\n", + "| Custom Layer | QuantizeConfig |\n", + "| :- | :-: |\n", + "| ClipMin | NoOpQuantizeConfig |\n", + "| Slice | NoOpQuantizeConfig |\n", + "| StopGradient | NoOpQuantizeConfig |\n", + "| MatMul | OutputQuantizeConfig |\n", + "| Multiply | OutputQuantizeConfig |\n", + "| ScalarMultiply | OutputQuantizeConfig |\n", + "| Add | OutputQuantizeConfig |\n", + "| ScalarAdd | OutputQuantizeConfig |\n", + "| Subtract | OutputQuantizeConfig |\n", + "| RSqrt | OutputQuantizeConfig |\n", + "| Mean
Mean (variance) | OutputQuantizeConfig
VarianceQuantizeConfig |\n", + "| BroadcastToken | WeightQuantizeConfig |\n", + "| AddPositionalEmbedding | WeightQuantizeConfig |\n", + "| Scale | WeightQuantizeConfig |\n", + "| Centre | WeightQuantizeConfig |" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e4838d0e", + "metadata": {}, + "outputs": [], + "source": [ + "def get_quant_configs(model):\n", + " layer_param_dict = {} # stores {Layer_Name: QuantizeConfig} pairs\n", + " scope = {} # stores all custom objects\n", + "\n", + " for layer in model.layers:\n", + "\n", + " if any([x in layer.name for x in ['clip', 'slice', 'stop_gradient']]):\n", + " layer_param_dict[layer.name] = {'quantize_config': NoOpQuantizeConfig()}\n", + " scope[layer.__class__.__name__] = layer.__class__\n", + "\n", + " elif any([x in layer.name for x in ['mat_mul', 'multiply', 'scalar_multiply', 'add', \\\n", + " 'scalar_add', 'mean', 'subtract', 'r_sqrt']]):\n", + " layer_param_dict[layer.name] = {'quantize_config': OutputQuantizeConfig()}\n", + " scope[layer.__class__.__name__] = layer.__class__\n", + "\n", + " elif any([x in layer.name for x in ['patch_encoder/cls_token', 'patch_encoder/add_pos_emb', \\\n", + " 'scale', 'centre']]):\n", + " layer_param_dict[layer.name] = {'quantize_config': WeightQuantizeConfig()}\n", + " scope[layer.__class__.__name__] = layer.__class__\n", + "\n", + " elif 'variance' in layer.name:\n", + " layer_param_dict[layer.name] = {'quantize_config': VarianceQuantizeConfig()}\n", + " scope[layer.__class__.__name__] = layer.__class__\n", + "\n", + " scope['NoOpQuantizeConfig'] = NoOpQuantizeConfig\n", + " scope['OutputQuantizeConfig'] = OutputQuantizeConfig\n", + " scope['WeightQuantizeConfig'] = WeightQuantizeConfig\n", + " scope['VarianceQuantizeConfig'] = VarianceQuantizeConfig\n", + "\n", + " return layer_param_dict, scope" + ] + }, + { + "cell_type": "markdown", + "id": "87983caa", + "metadata": {}, + "source": [ + "#### 2.2 Load the necessary API classes/functions" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a1618c7d", + "metadata": {}, + "outputs": [], + "source": [ + "quantize_annotate_layer = tfmot.quantization.keras.quantize_annotate_layer\n", + "quantize_annotate_model = tfmot.quantization.keras.quantize_annotate_model\n", + "quantize_apply = tfmot.quantization.keras.quantize_apply\n", + "quantize_scope = tfmot.quantization.keras.quantize_scope\n", + "Default8BitPrunePreserveQuantizeScheme = tfmot.experimental.combine.Default8BitPrunePreserveQuantizeScheme" + ] + }, + { + "cell_type": "markdown", + "id": "33a6326b", + "metadata": {}, + "source": [ + "#### 2.3 Apply QAT\n", + "\n", + "When calling the `quantize_apply` function, if an unsupported layer is missing from `layer_param_dict` or the `scope`, TFMOT will throw an error." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1ddd2d94", + "metadata": {}, + "outputs": [], + "source": [ + "layer_param_dict, scope = get_quant_configs(stripped_pruned_model)\n", + "\n", + "# Wrap each custom layer with the corresponding QuantizeConfig:\n", + "pqat_model = layer_wrapper(stripped_pruned_model, quantize_annotate_layer, layer_param_dict)\n", + "# Quantize the rest of the model with the API defaults:\n", + "pqat_model = quantize_annotate_model(pqat_model)\n", + "\n", + "with quantize_scope(scope):\n", + " pqat_model = quantize_apply(pqat_model, scheme=Default8BitPrunePreserveQuantizeScheme())\n", + "\n", + "compile_and_fit(pqat_model, epochs=2)\n", + "\n", + "WEIGHTS_PATH = './ViT_2x4-PQAT.h5'\n", + "pqat_model.save_weights(WEIGHTS_PATH)\n", + "print('Success')" + ] + }, + { + "cell_type": "markdown", + "id": "d1344356", + "metadata": {}, + "source": [ + "#### 2.4 Check that the weights are still 2:4 pruned" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a5005701", + "metadata": {}, + "outputs": [], + "source": [ + "print('2:4 pruned weights:')\n", + "print_2x4_sparsity(pqat_model)" + ] + }, + { + "cell_type": "markdown", + "id": "ea69d3f0", + "metadata": {}, + "source": [ + "### 3. Generate an int8 TFLite file\n", + "\n", + "If we attempt to directly generate a TFLite file using the fine-tuned model above:\n", + "1. It will not have a correct batch size of 1.\n", + "2. It will have operators which are unnecessary during inference. Precisely, the extra `Subtract` operators and `ClipMin` operator in the layer normalisation blocks, which were used during training and fine-tuning, should be removed from the graph before creating the TFLite file.\n", + "\n", + "Therefore the network should be redefined with a batch size of 1 and with the redundant operators removed. The weights of the fine-tuned optimised model can then be loaded into this new model." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4465f71d", + "metadata": {}, + "outputs": [], + "source": [ + "tf.keras.backend.clear_session() # reset layer name counters\n", + "\n", + "net = get_vision_transformer(input_shape=(1, 28, 28, 1), # (batch_size, height, width, channels)\n", + " n_classes=10,\n", + " patch_size=(4, 4),\n", + " embedding_dim=16,\n", + " n_layers=2,\n", + " n_attention_heads=2,\n", + " mlp_hidden_dim=16,\n", + " trainable=False)\n", + "layer_param_dict, scope = get_quant_configs(net)\n", + "net = quantize_annotate_model(layer_wrapper(net, quantize_annotate_layer, layer_param_dict))\n", + "with quantize_scope(scope):\n", + " net = quantize_apply(net)\n", + "\n", + "net.load_weights(WEIGHTS_PATH, by_name=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "837dab4f", + "metadata": {}, + "outputs": [], + "source": [ + "MODEL_PATH = './ViT_2x4-PQAT_int8.tflite'\n", + "\n", + "converter = tf.lite.TFLiteConverter.from_keras_model(net)\n", + "converter.optimizations = [tf.lite.Optimize.DEFAULT]\n", + "converter.inference_input_type = tf.int8\n", + "converter.inference_output_type = tf.int8\n", + "\n", + "# Experimental flag which improves efficiency for some devices\n", + "converter._experimental_disable_batchmatmul_unfold = True\n", + "\n", + "tflite_model = converter.convert()\n", + "with open(MODEL_PATH, \"wb+\") as tflite_file:\n", + " tflite_file.write(tflite_model)" + ] + }, + { + "cell_type": "markdown", + "id": "eb024aec", + "metadata": {}, + "source": [ + "### 4. Evaluate the TFLite model" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e2c3d910", + "metadata": {}, + "outputs": [], + "source": [ + "interpreter = tf.lite.Interpreter(model_path=MODEL_PATH)\n", + "interpreter.allocate_tensors()\n", + "\n", + "input_details = interpreter.get_input_details()\n", + "output_details = interpreter.get_output_details()\n", + "input_scale, input_zero_point = input_details[0]['quantization']\n", + "output_scale, output_zero_point = output_details[0]['quantization']\n", + "\n", + "int8_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='int8_accuracy')\n", + "progbar = tf.keras.utils.Progbar(len(X_test), stateful_metrics=['accuracy'])\n", + "for step, (img, lbl) in enumerate(zip(X_test, y_test)):\n", + " # Set input tensor\n", + " img = img[np.newaxis, ...] / input_scale + input_zero_point\n", + " interpreter.set_tensor(input_details[0]['index'], tf.cast(img, input_details[0]['dtype']))\n", + " interpreter.invoke()\n", + "\n", + " # Get output tensor\n", + " output_data = interpreter.get_tensor(output_details[0]['index'])\n", + " output_data = output_scale * (output_data.astype(np.float32) - output_zero_point)\n", + "\n", + " # Update accuracy\n", + " int8_accuracy.update_state(lbl, output_data)\n", + " progbar.update(step + 1, values=[('accuracy', int8_accuracy.result().numpy())])\n", + "\n", + "print('Accuracy:', int8_accuracy.result().numpy())" + ] + } + ], + "metadata": { + "interpreter": { + "hash": "d7b4b41e591e77f7c9707eacea36c84f7efa2eea6b6b57ad0b82061db904f16d" + }, + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.8" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/tutorials/transformer_tutorials/ViT_PCQAT.ipynb b/tutorials/transformer_tutorials/ViT_PCQAT.ipynb new file mode 100755 index 0000000..01c91d9 --- /dev/null +++ b/tutorials/transformer_tutorials/ViT_PCQAT.ipynb @@ -0,0 +1,1072 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "820bff2b", + "metadata": {}, + "source": [ + "# Vision Transformer optimisation using TFMOT\n", + "\n", + "Example notebook to demonstrate how TFMOT can be used for optimising complex transformer models such as ViT." + ] + }, + { + "cell_type": "markdown", + "id": "b22944d9", + "metadata": {}, + "source": [ + "## Background\n", + "\n", + "The [Vision Transformer (ViT)](https://arxiv.org/pdf/2010.11929.pdf) architecture uses stacked transformer encoder blocks to process images for certain tasks. The encoder blocks are architecturally similar to the popular [NLP transformers](https://arxiv.org/pdf/1706.03762.pdf). The inputs to the transformer encoders are embeddings of patches extracted from the image. For a classification task, an additional feed forward network is added to the end.\n", + "\n", + "\"Vision\n", + "\n", + "In this notebook:\n", + "1. Firstly a ViT model is created and trained from scratch on the MNIST dataset. In practice, pre-trained weights can also be loaded.\n", + "2. Afterwards, unstructured weight pruning, clustering and quantisation aware training (QAT) techniques are applied sequentially using the collaborative optimisation features of the [TensorFlow Model Optimization Toolkit (TFMOT)](https://www.tensorflow.org/model_optimization).\n", + "3. Finally, an integer-only TFLite model is generated and tested." + ] + }, + { + "cell_type": "markdown", + "id": "ae5a58ae", + "metadata": {}, + "source": [ + "## TFMOT limitations\n", + "- Subclassed models are not supported. Only sequential and functional model definitions are supported. (Pruning, Clustering & QAT)\n", + "- Custom subclassed layers are not supported. (Clustering & QAT)\n", + " - Clustering will only work with subclassed layers if the weight variables you have to cluster are not nested within another layer (e.g. MHA).\n", + " - QAT works correctly if the subclassed layer performs only 1 operation.\n", + "- Low-level tensorflow operators such as `tf.linalg.matmul` are not supported. (Only for QAT)\n", + " - QAT expects all quantised layers to be a subclass of `tf.keras.layers.Layer`." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5f09a1f7", + "metadata": {}, + "outputs": [], + "source": [ + "import math\n", + "import numpy as np\n", + "import tensorflow as tf\n", + "import tensorflow_model_optimization as tfmot\n", + "\n", + "tf.random.set_seed(0)\n", + "\n", + "print('TensorFlow version: {}'.format(tf.__version__))\n", + "print('TFMOT version: {}'.format(tfmot.__version__))" + ] + }, + { + "cell_type": "markdown", + "id": "ea4a90ee", + "metadata": {}, + "source": [ + "## Model definition\n", + "\n", + "Due to the above-mentioned limitations, custom Keras layers must be defined for all of the low-level TensorFlow operators in order to perform QAT (each layer must only contain a single operation).\n", + "\n", + "Since none of these will have any prunable/clusterable weights, first we create a base prunable clusterable layer class to extend, instead of `tf.keras.layers.Layer`. If any of the weights in the custom layers should be pruned or clustered, a list of the weights should be provided in the respective method. Refer to the TFMOT documentation for more details." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b022e09f", + "metadata": {}, + "outputs": [], + "source": [ + "class PrunableClusterableLayer(tf.keras.layers.Layer,\n", + " tfmot.sparsity.keras.PrunableLayer,\n", + " tfmot.clustering.keras.ClusterableLayer):\n", + " def get_prunable_weights(self): return []\n", + " def get_clusterable_weights(self): return []" + ] + }, + { + "cell_type": "markdown", + "id": "bd2b1152", + "metadata": {}, + "source": [ + "### 1. Define each of the TensorFlow operations ViT uses as a Keras subclassed layer:\n", + "\n", + "Note that some of these layers have trainable weights defined using the `add_weight` method. These weights will not be pruned or clustered." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "51101c26", + "metadata": {}, + "outputs": [], + "source": [ + "class MatMul(PrunableClusterableLayer):\n", + " def __init__(self, transpose_b=False, **kwargs):\n", + " super().__init__(**kwargs)\n", + " self.transpose_b = transpose_b\n", + "\n", + " def call(self, inputs):\n", + " return tf.linalg.matmul(*inputs, transpose_b=self.transpose_b)\n", + "\n", + " def get_config(self):\n", + " config = super().get_config()\n", + " config.update({'transpose_b': self.transpose_b})\n", + " return config\n", + "\n", + "class Multiply(PrunableClusterableLayer):\n", + " def call(self, inputs):\n", + " return tf.multiply(*inputs)\n", + "\n", + "# Calling Multiply with a scalar input will lead to an error.\n", + "# Use the following ScalarMultiply class instead.\n", + "class ScalarMultiply(PrunableClusterableLayer):\n", + " def __init__(self, scalar, **kwargs):\n", + " super().__init__(**kwargs)\n", + " self.scalar = scalar\n", + "\n", + " def call(self, x):\n", + " return tf.math.multiply(x, self.scalar)\n", + "\n", + " def get_config(self):\n", + " config = super().get_config()\n", + " config.update({'scalar': self.scalar})\n", + " return config\n", + "\n", + "class Add(PrunableClusterableLayer):\n", + " def call(self, inputs):\n", + " return tf.math.add(*inputs)\n", + "\n", + "# Calling Add with a scalar input will lead to an error.\n", + "# Use the following ScalarAdd class instead.\n", + "class ScalarAdd(PrunableClusterableLayer):\n", + " def __init__(self, scalar, **kwargs):\n", + " super().__init__(**kwargs)\n", + " self.scalar = scalar\n", + "\n", + " def call(self, x):\n", + " return tf.math.add(x, self.scalar)\n", + "\n", + " def get_config(self):\n", + " config = super().get_config()\n", + " config.update({'scalar': self.scalar})\n", + " return config\n", + "\n", + "class Slice(PrunableClusterableLayer):\n", + " def __init__(self, seq_idx, **kwargs):\n", + " super().__init__(**kwargs)\n", + " self.seq_idx = seq_idx\n", + "\n", + " def call(self, x):\n", + " return x[:, self.seq_idx, ...]\n", + "\n", + " def get_config(self):\n", + " config = super().get_config()\n", + " config.update({'seq_idx': self.seq_idx})\n", + " return config\n", + "\n", + "class Mean(PrunableClusterableLayer):\n", + " def __init__(self, axes=None, keepdims=True, **kwargs):\n", + " super().__init__(**kwargs)\n", + " self.axes=axes\n", + " self.keepdims = keepdims\n", + "\n", + " def call(self, x):\n", + " return tf.math.reduce_mean(x, axis=self.axes, keepdims=self.keepdims)\n", + "\n", + " def get_config(self):\n", + " config = super().get_config()\n", + " config.update({'axes': self.axes,\n", + " 'keepdims': self.keepdims})\n", + " return config\n", + "\n", + "class Subtract(PrunableClusterableLayer):\n", + " def call(self, inputs):\n", + " return tf.math.subtract(*inputs)\n", + "\n", + "class StopGradient(PrunableClusterableLayer):\n", + " def call(self, x):\n", + " return tf.stop_gradient(x)\n", + "\n", + "class RSqrt(PrunableClusterableLayer):\n", + " def call(self, x):\n", + " return tf.math.rsqrt(x)\n", + "\n", + "class ClipMin(PrunableClusterableLayer):\n", + " def __init__(self, min_val=0, **kwargs):\n", + " super().__init__(**kwargs)\n", + " self.min_val = min_val\n", + "\n", + " def call(self, x):\n", + " return tf.math.maximum(x, self.min_val)\n", + "\n", + " def get_config(self):\n", + " config = super().get_config()\n", + " config.update({'min_val': self.min_val})\n", + " return config\n", + "\n", + "class BroadcastToken(PrunableClusterableLayer):\n", + " \"\"\"Layer to broadcast the class token\"\"\"\n", + " def __init__(self, embedding_dim, **kwargs):\n", + " super().__init__(**kwargs)\n", + " self.embedding_dim = embedding_dim\n", + "\n", + " def build(self, input_shape):\n", + " self.w = self.add_weight(shape=(1, 1, self.embedding_dim), initializer='zeros', \n", + " trainable=True, name='token')\n", + " super().build(input_shape)\n", + "\n", + " def call(self, x):\n", + " batch_size = tf.shape(x)[0]\n", + " return tf.broadcast_to(self.w, [batch_size, 1, self.embedding_dim])\n", + "\n", + " def get_config(self):\n", + " config = super().get_config()\n", + " config.update({'embedding_dim': self.embedding_dim})\n", + " return config\n", + "\n", + "class AddPositionalEmbedding(PrunableClusterableLayer):\n", + " \"\"\"Layer to add positional embeddings to the tokens\"\"\"\n", + " def __init__(self, seq_len, embedding_dim, **kwargs):\n", + " super().__init__(**kwargs)\n", + " self.embedding_dim = embedding_dim\n", + " self.seq_len = seq_len\n", + "\n", + " def build(self, input_shape):\n", + " self.w = self.add_weight(shape=(1, self.seq_len, self.embedding_dim), initializer=None,\n", + " trainable=True, name='pos_emb')\n", + " super().build(input_shape)\n", + "\n", + " def call(self, x):\n", + " return x + self.w\n", + "\n", + " def get_config(self):\n", + " config = super().get_config()\n", + " config.update({'embedding_dim': self.embedding_dim, 'seq_len': self.seq_len})\n", + " return config\n", + "\n", + "class Scale(PrunableClusterableLayer):\n", + " \"\"\"Multiply with gamma (LayerNorm)\"\"\"\n", + " def __init__(self, axes, **kwargs):\n", + " super().__init__(**kwargs)\n", + " self.axes = axes\n", + "\n", + " def build(self, input_shape):\n", + " param_shape = [input_shape[dim] for dim in self.axes]\n", + " self.w = self.add_weight(name='gamma', shape=param_shape,\n", + " trainable=True, initializer='ones')\n", + " super().build(input_shape)\n", + "\n", + " def call(self, x):\n", + " return tf.multiply(x, self.w)\n", + "\n", + " def get_config(self):\n", + " config = super().get_config()\n", + " config.update({'axes': self.axes})\n", + " return config\n", + "\n", + "class Centre(PrunableClusterableLayer):\n", + " \"\"\"Add beta (LayerNorm)\"\"\"\n", + " def __init__(self, axes, **kwargs):\n", + " super().__init__(**kwargs)\n", + " self.axes = axes\n", + "\n", + " def build(self, input_shape):\n", + " param_shape = [input_shape[dim] for dim in self.axes]\n", + " self.w = self.add_weight(name='beta', shape=param_shape,\n", + " trainable=True, initializer='zeros')\n", + " super().build(input_shape)\n", + "\n", + " def call(self, x):\n", + " return tf.math.add(x, self.w)\n", + "\n", + " def get_config(self):\n", + " config = super().get_config()\n", + " config.update({'axes': self.axes})\n", + " return config" + ] + }, + { + "cell_type": "markdown", + "id": "5c0abcb4", + "metadata": {}, + "source": [ + "### 2. Now that these low-level operators are defined as Keras layers, we can start writing ViT layers such as multi-head attention or layer normalisation functionally:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a2fd1ae8", + "metadata": {}, + "outputs": [], + "source": [ + "Tanh = tf.keras.layers.Activation('tanh')\n", + "\n", + "def patch_encoder(inp, patch_size, num_patches, embedding_dim):\n", + " \"\"\"\n", + " Patch encoder layer, extracts patches from the image, flattens them \n", + " and adds the class token and positional embedding vectors.\n", + " \"\"\"\n", + " x = tf.keras.layers.Conv2D(filters=embedding_dim, kernel_size=patch_size,\n", + " strides=patch_size, name='patch_encoder/conv2d')(inp)\n", + " x = tf.keras.layers.Reshape((num_patches, embedding_dim))(x)\n", + "\n", + " # add the class token\n", + " cls_token = BroadcastToken(embedding_dim=embedding_dim, name='patch_encoder/cls_token')(inp)\n", + " x = tf.keras.layers.Concatenate(axis=1)([cls_token, x])\n", + "\n", + " x = AddPositionalEmbedding(seq_len=(num_patches + 1), # +1 for the class token\n", + " embedding_dim=embedding_dim,\n", + " name='patch_encoder/add_pos_emb')(x)\n", + " return x\n", + "\n", + "def self_attention(x, n_heads, dim, name='mha'):\n", + " \"\"\"Multi-head attention layer\"\"\"\n", + " depth = dim // n_heads\n", + "\n", + " q = tf.keras.layers.Dense(units=dim, name=f'{name}/query')(x)\n", + " k = tf.keras.layers.Dense(units=dim, name=f'{name}/key')(x)\n", + " v = tf.keras.layers.Dense(units=dim, name=f'{name}/value')(x)\n", + "\n", + " q = tf.keras.layers.Reshape((-1, n_heads, depth))(q)\n", + " q = tf.keras.layers.Permute((2, 1, 3))(q)\n", + " k = tf.keras.layers.Reshape((-1, n_heads, depth))(k)\n", + " k = tf.keras.layers.Permute((2, 1, 3))(k)\n", + " v = tf.keras.layers.Reshape((-1, n_heads, depth))(v)\n", + " v = tf.keras.layers.Permute((2, 1, 3))(v)\n", + "\n", + " qk = ScalarMultiply(depth ** -0.5)(MatMul(transpose_b=True)([q, k]))\n", + " attn_weights = tf.keras.layers.Softmax(axis=-1)(qk)\n", + "\n", + " attn_out = MatMul()([attn_weights, v]) \n", + " attn_out = tf.keras.layers.Permute((2, 1, 3))(attn_out)\n", + " attn_out = tf.keras.layers.Reshape((-1, dim))(attn_out)\n", + " out = tf.keras.layers.Dense(dim, name=f'{name}/output_dense')(attn_out)\n", + "\n", + " return out\n", + "\n", + "def layer_norm(x, axes=2, epsilon=0.001, name='layer_norm', trainable=True):\n", + " \"\"\"LayerNormalization\"\"\"\n", + " if isinstance(axes, int): axes = [axes]\n", + "\n", + " mean = Mean(axes=axes)(x)\n", + " ## This block can be replaced with a squared_difference layer ##\n", + " diff = Subtract()([x, StopGradient()(mean)]) ##\n", + " sq_diff = Multiply()([diff, diff]) ##\n", + " ## ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ ##\n", + " variance = Mean(axes=axes, name=f'{name}/variance')(sq_diff)\n", + " if not trainable:\n", + " inv = RSqrt()(variance)\n", + " x = Multiply()([diff, inv])\n", + " else:\n", + " inv = RSqrt()(ClipMin(min_val=epsilon)(variance)) # ClipMin prevents division by 0.\n", + " x = Subtract(name=f'{name}/grad_subtract')([x, mean]) # This layer is removed for inference so it is named.\n", + " x = Multiply()([x, inv])\n", + "\n", + " x = Scale(axes=axes)(x)\n", + " x = Centre(axes=axes)(x)\n", + "\n", + " return x\n", + "\n", + "def gelu(x):\n", + " \"\"\"Functional definition of approximate GELU with Keras layers\"\"\"\n", + " res = Add()([x, ScalarMultiply(0.044715)(Multiply()([x, Multiply()([x, x])]))])\n", + " res = ScalarAdd(1.0)(Tanh(ScalarMultiply(math.sqrt(2 / math.pi))(res)))\n", + " res = ScalarMultiply(0.5)(res)\n", + " res = Multiply()([x, res])\n", + " return res\n", + "\n", + "def mlp(x, hidden_dim, out_dim):\n", + " \"\"\"Multi-layer perceptron block\"\"\"\n", + " x = tf.keras.layers.Dense(units=hidden_dim)(x)\n", + " x = gelu(x)\n", + " x = tf.keras.layers.Dense(units=out_dim)(x)\n", + " return x" + ] + }, + { + "cell_type": "markdown", + "id": "0b2464d0", + "metadata": {}, + "source": [ + "### 3. Full functional model definition:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e5ce1116", + "metadata": {}, + "outputs": [], + "source": [ + "def get_vision_transformer(input_shape,\n", + " n_classes,\n", + " patch_size,\n", + " embedding_dim,\n", + " n_layers,\n", + " n_attention_heads,\n", + " mlp_hidden_dim,\n", + " trainable=True):\n", + " \"\"\"\n", + " Args:\n", + " input_shape (tuple): Shape of the inputs, including the batch size.\n", + " n_classes (int): Number of classes in the dataset.\n", + " patch_size (int / tuple of ints): Size of the patches to extract from the images.\n", + " embedding_dim (int): Size of the embedded patch vectors.\n", + " n_layers (int): Number of transformer encoder layers.\n", + " n_attention_heads (int): Number of attention heads.\n", + " mlp_hidden_dim (int): Hidden layer size for the intermediate MLPs.\n", + "\n", + " Returns:\n", + " model (tf.keras.Model): The Keras model.\n", + " \"\"\"\n", + "\n", + " if isinstance(patch_size, int): patch_size = (patch_size, patch_size)\n", + "\n", + " # Calculate the number of patches\n", + " num_patches = (input_shape[1] * input_shape[2]) // (patch_size[0] * patch_size[1])\n", + "\n", + " inp = tf.keras.layers.Input(shape=input_shape[1:], batch_size=input_shape[0], name='image')\n", + "\n", + " # Patch encoder layer\n", + " x = patch_encoder(inp, patch_size, num_patches, embedding_dim)\n", + "\n", + " for block in range(n_layers):\n", + " # Attention block\n", + " x1 = layer_norm(x, name=(f'layer_norm_{2 * block}' if block != 0 else 'layer_norm'), trainable=trainable)\n", + " x1 = self_attention(x1, n_attention_heads, embedding_dim, name=(f'mha_{block}' if block != 0 else 'mha'))\n", + " x1 = tf.keras.layers.Add()([x1, x])\n", + "\n", + " # MLP block\n", + " x2 = layer_norm(x1, name=f'layer_norm_{2 * block + 1}', trainable=trainable)\n", + " x2 = mlp(x2, mlp_hidden_dim, embedding_dim)\n", + " x = tf.keras.layers.Add()([x2, x1])\n", + "\n", + " x = layer_norm(x, name=f'layer_norm_{2 * block + 2}', trainable=trainable)\n", + "\n", + " ## ~ Classification head ~ ##\n", + " cls_head = Slice(0)(x)\n", + " out = tf.keras.layers.Dense(n_classes, kernel_initializer='zeros', name='cls_head')(cls_head)\n", + " ## ~~~~~~~~~~~~~~~~~~~~~~~ ##\n", + "\n", + " model = tf.keras.Model(inputs=inp, outputs=out)\n", + "\n", + " return model" + ] + }, + { + "cell_type": "markdown", + "id": "9a9060af", + "metadata": {}, + "source": [ + "## Training" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "13f55e4d", + "metadata": {}, + "outputs": [], + "source": [ + "BATCH_SIZE = 32\n", + "\n", + "# Load the MNIST dataset\n", + "(X_train, y_train), (X_test, y_test) = tf.keras.datasets.mnist.load_data()\n", + "X_train, X_test = (X_train[..., np.newaxis] / 255.0), (X_test[..., np.newaxis] / 255.0)\n", + "train_ds = tf.data.Dataset.from_tensor_slices((X_train, y_train)).shuffle(1000) \\\n", + " .batch(BATCH_SIZE, drop_remainder=True) \\\n", + " .prefetch(tf.data.AUTOTUNE)\n", + "test_ds = tf.data.Dataset.from_tensor_slices((X_test, y_test)).batch(BATCH_SIZE, drop_remainder=True) \\\n", + " .prefetch(tf.data.AUTOTUNE)\n", + "\n", + "def compile_and_fit(model, **kwargs):\n", + " model.compile(optimizer=\"adam\",\n", + " loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n", + " metrics=[\"accuracy\"])\n", + " model.fit(train_ds, validation_data=test_ds, **kwargs)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "96da5e80", + "metadata": {}, + "outputs": [], + "source": [ + "model = get_vision_transformer(input_shape=(BATCH_SIZE, 28, 28, 1), # (batch_size, height, width, channels)\n", + " n_classes=10,\n", + " patch_size=(4, 4),\n", + " embedding_dim=16,\n", + " n_layers=2,\n", + " n_attention_heads=2,\n", + " mlp_hidden_dim=16)\n", + "\n", + "compile_and_fit(model, epochs=3)" + ] + }, + { + "cell_type": "markdown", + "id": "edf7f692", + "metadata": {}, + "source": [ + "## Pruning, Clustering & QAT" + ] + }, + { + "cell_type": "markdown", + "id": "9d49a31f", + "metadata": {}, + "source": [ + "### 1. Apply the pruning API" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5e3048bc", + "metadata": {}, + "outputs": [], + "source": [ + "prune_low_magnitude = tfmot.sparsity.keras.prune_low_magnitude\n", + "strip_pruning = tfmot.sparsity.keras.strip_pruning\n", + "\n", + "N_EPOCHS = 1\n", + "pruning_params = {\n", + " 'pruning_schedule': tfmot.sparsity.keras.PolynomialDecay(initial_sparsity=0.1, final_sparsity=0.5,\n", + " begin_step=0, end_step=int(len(train_ds)*N_EPOCHS*0.7))\n", + "}\n", + "pruned_model = prune_low_magnitude(model, **pruning_params)\n", + "# Fine-tune with pruning\n", + "compile_and_fit(pruned_model, epochs=N_EPOCHS, callbacks=[tfmot.sparsity.keras.UpdatePruningStep()])\n", + "stripped_pruned_model = strip_pruning(pruned_model)\n", + "print('Success')" + ] + }, + { + "cell_type": "markdown", + "id": "bcc2c089", + "metadata": {}, + "source": [ + "#### 1.1. Check that the weights are pruned" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "490e2923", + "metadata": {}, + "outputs": [], + "source": [ + "def print_sparsity(model):\n", + " for w in model.weights:\n", + " n_weights = w.numpy().size\n", + " n_zeros = np.count_nonzero(w == 0)\n", + " sparsity = n_zeros / n_weights * 100.0\n", + " if sparsity > 0:\n", + " print(' {} - {:.1f}% sparsity'.format(w.name, sparsity))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a6542138", + "metadata": {}, + "outputs": [], + "source": [ + "print('Sparse weights:')\n", + "print_sparsity(stripped_pruned_model)" + ] + }, + { + "cell_type": "markdown", + "id": "7a084e9d", + "metadata": {}, + "source": [ + "### 2. Apply the clustering API" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a24393d1", + "metadata": {}, + "outputs": [], + "source": [ + "from tensorflow_model_optimization.python.core.clustering.keras.experimental import cluster\n", + "\n", + "cluster_weights = cluster.cluster_weights\n", + "CentroidInitialization = tfmot.clustering.keras.CentroidInitialization\n", + "strip_clustering = tfmot.clustering.keras.strip_clustering\n", + "\n", + "# Add sparsity-preserving clustering wrappers\n", + "pruned_clustered_model = cluster_weights(stripped_pruned_model,\n", + " number_of_clusters=4,\n", + " cluster_centroids_init=CentroidInitialization.KMEANS_PLUS_PLUS,\n", + " preserve_sparsity=True)\n", + "# Fine-tune with clustering\n", + "compile_and_fit(pruned_clustered_model, epochs=1)\n", + "stripped_pruned_clustered_model = strip_clustering(pruned_clustered_model)\n", + "print('Success')" + ] + }, + { + "cell_type": "markdown", + "id": "0180d56c", + "metadata": {}, + "source": [ + "#### 2.1. Check that the weights are pruned and clustered" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6169783c", + "metadata": {}, + "outputs": [], + "source": [ + "def print_clusters(model):\n", + " for w in model.weights:\n", + " n_weights = w.numpy().size\n", + " n_unique = len(np.unique(w))\n", + " if n_unique < n_weights:\n", + " print(' {} - {} unique weights'.format(w.name, n_unique))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2b1d8480", + "metadata": {}, + "outputs": [], + "source": [ + "print('Sparse weights:')\n", + "print_sparsity(stripped_pruned_clustered_model)\n", + "print('Clustered weights:')\n", + "print_clusters(stripped_pruned_clustered_model)" + ] + }, + { + "cell_type": "markdown", + "id": "5ec4c7b4", + "metadata": {}, + "source": [ + "**Warning: The original model is modified after calling [`prune_low_magnitude`](https://www.tensorflow.org/model_optimization/api_docs/python/tfmot/sparsity/keras/prune_low_magnitude) or [`cluster_weights`](https://www.tensorflow.org/model_optimization/api_docs/python/tfmot/clustering/keras/cluster_weights).**" + ] + }, + { + "cell_type": "markdown", + "id": "209c2d0a", + "metadata": {}, + "source": [ + "### 3. Quantisation-aware training API\n", + "#### 3.1. To use the custom Keras layers we defined, we need to pass a [`QuantizeConfig`](https://www.tensorflow.org/model_optimization/api_docs/python/tfmot/quantization/keras/QuantizeConfig) for each of these layers.\n", + "\n", + "For Keras layers which are already supported in TFMOT, a default `QuantizeConfig` class is assigned to each one. However custom `QuantizeConfig` instances could also be created for these layers to give more control over how they are quantised." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "40c11c58", + "metadata": {}, + "outputs": [], + "source": [ + "from tensorflow_model_optimization.quantization.keras import QuantizeConfig, quantizers\n", + "\n", + "LastValueQuantizer = quantizers.LastValueQuantizer\n", + "MovingAverageQuantizer = quantizers.MovingAverageQuantizer\n", + "AllValuesQuantizer = quantizers.AllValuesQuantizer\n", + "\n", + "class NoOpQuantizeConfig(QuantizeConfig):\n", + " \"\"\"QuantizeConfig which does not quantize any part of the layer.\"\"\"\n", + "\n", + " def get_weights_and_quantizers(self, layer):\n", + " return []\n", + "\n", + " def get_activations_and_quantizers(self, layer):\n", + " return []\n", + "\n", + " def set_quantize_weights(self, layer, quantize_weights):\n", + " pass\n", + "\n", + " def set_quantize_activations(self, layer, quantize_activations):\n", + " pass\n", + "\n", + " def get_output_quantizers(self, layer):\n", + " return []\n", + "\n", + " def get_config(self):\n", + " return {}\n", + "\n", + "class OutputQuantizeConfig(QuantizeConfig):\n", + " \"\"\"QuantizeConfig which only quantizes the output of a layer.\"\"\"\n", + "\n", + " def get_weights_and_quantizers(self, layer):\n", + " return []\n", + "\n", + " def get_activations_and_quantizers(self, layer):\n", + " return []\n", + "\n", + " def set_quantize_weights(self, layer, quantize_weights):\n", + " pass\n", + "\n", + " def set_quantize_activations(self, layer, quantize_activations):\n", + " pass\n", + "\n", + " def get_output_quantizers(self, layer):\n", + " return [MovingAverageQuantizer(num_bits=8, per_axis=False, symmetric=False, narrow_range=False)]\n", + "\n", + " def get_config(self):\n", + " return {}\n", + "\n", + "class WeightQuantizeConfig(QuantizeConfig):\n", + " \"\"\"QuantizeConfig which quantizes the custom weights in the patch encoder and layer normalisation layers.\"\"\"\n", + "\n", + " def __init__(self):\n", + " self.weight_quantizer = LastValueQuantizer(num_bits=8, per_axis=False,\n", + " symmetric=True, narrow_range=True)\n", + " self.activation_quantizer = MovingAverageQuantizer(num_bits=8, per_axis=False,\n", + " symmetric=False, narrow_range=False)\n", + "\n", + " def get_weights_and_quantizers(self, layer):\n", + " return [(layer.w, self.weight_quantizer)]\n", + "\n", + " def get_activations_and_quantizers(self, layer):\n", + " return []\n", + "\n", + " def set_quantize_weights(self, layer, quantize_weights):\n", + " layer.w = quantize_weights[0]\n", + "\n", + " def set_quantize_activations(self, layer, quantize_activations):\n", + " pass\n", + "\n", + " def get_output_quantizers(self, layer):\n", + " return [self.activation_quantizer]\n", + "\n", + " def get_config(self):\n", + " return {}\n", + "\n", + "class VarianceQuantizeConfig(QuantizeConfig):\n", + " \"\"\"QuantizeConfig for the variance calculation in the layer normalisation layer.\"\"\"\n", + "\n", + " def get_weights_and_quantizers(self, layer):\n", + " return []\n", + "\n", + " def get_activations_and_quantizers(self, layer):\n", + " return []\n", + "\n", + " def set_quantize_weights(self, layer, quantize_weights):\n", + " pass\n", + "\n", + " def set_quantize_activations(self, layer, quantize_activations):\n", + " pass\n", + "\n", + " def get_output_quantizers(self, layer):\n", + " return [AllValuesQuantizer(num_bits=8, per_axis=False, symmetric=False, narrow_range=False)]\n", + "\n", + " def get_config(self):\n", + " return {}" + ] + }, + { + "cell_type": "markdown", + "id": "a928d97d", + "metadata": {}, + "source": [ + "Since custom layers and `QuantizeConfig`s are used, the whole model cannot directly be wrapped with QAT wrappers.
\n", + "So first we write a function to wrap the individual layers with QAT wrappers:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "43837c2e", + "metadata": {}, + "outputs": [], + "source": [ + "def apply_wrapper(wrapper_function, layer_param_dict):\n", + "\n", + " def wrap_layer(layer):\n", + " if layer.name in layer_param_dict.keys():\n", + " return wrapper_function(layer, **layer_param_dict[layer.name])\n", + " return layer\n", + "\n", + " return wrap_layer\n", + "\n", + "def layer_wrapper(model, wrapper_function, layer_param_dict):\n", + " return tf.keras.models.clone_model(model, clone_function=apply_wrapper(wrapper_function, layer_param_dict))" + ] + }, + { + "cell_type": "markdown", + "id": "26030a8c", + "metadata": {}, + "source": [ + "The custom layers should be quantized with the following `QuantizeConfig` classes:\n", + "\n", + "| Custom Layer | QuantizeConfig |\n", + "| :- | :-: |\n", + "| ClipMin | NoOpQuantizeConfig |\n", + "| Slice | NoOpQuantizeConfig |\n", + "| StopGradient | NoOpQuantizeConfig |\n", + "| MatMul | OutputQuantizeConfig |\n", + "| Multiply | OutputQuantizeConfig |\n", + "| ScalarMultiply | OutputQuantizeConfig |\n", + "| Add | OutputQuantizeConfig |\n", + "| ScalarAdd | OutputQuantizeConfig |\n", + "| Subtract | OutputQuantizeConfig |\n", + "| RSqrt | OutputQuantizeConfig |\n", + "| Mean
Mean (variance) | OutputQuantizeConfig
VarianceQuantizeConfig |\n", + "| BroadcastToken | WeightQuantizeConfig |\n", + "| AddPositionalEmbedding | WeightQuantizeConfig |\n", + "| Scale | WeightQuantizeConfig |\n", + "| Centre | WeightQuantizeConfig |" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8ed06321", + "metadata": {}, + "outputs": [], + "source": [ + "def get_quant_configs(model):\n", + " layer_param_dict = {} # stores {Layer_Name: QuantizeConfig} pairs\n", + " scope = {} # stores all custom objects\n", + "\n", + " for layer in model.layers:\n", + "\n", + " if any([x in layer.name for x in ['clip', 'slice', 'stop_gradient']]):\n", + " layer_param_dict[layer.name] = {'quantize_config': NoOpQuantizeConfig()}\n", + " scope[layer.__class__.__name__] = layer.__class__\n", + "\n", + " elif any([x in layer.name for x in ['mat_mul', 'multiply', 'scalar_multiply', 'add', \\\n", + " 'scalar_add', 'mean', 'subtract', 'r_sqrt']]):\n", + " layer_param_dict[layer.name] = {'quantize_config': OutputQuantizeConfig()}\n", + " scope[layer.__class__.__name__] = layer.__class__\n", + "\n", + " elif any([x in layer.name for x in ['patch_encoder/cls_token', 'patch_encoder/add_pos_emb', \\\n", + " 'scale', 'centre']]):\n", + " layer_param_dict[layer.name] = {'quantize_config': WeightQuantizeConfig()}\n", + " scope[layer.__class__.__name__] = layer.__class__\n", + "\n", + " elif 'variance' in layer.name:\n", + " layer_param_dict[layer.name] = {'quantize_config': VarianceQuantizeConfig()}\n", + " scope[layer.__class__.__name__] = layer.__class__\n", + "\n", + " scope['NoOpQuantizeConfig'] = NoOpQuantizeConfig\n", + " scope['OutputQuantizeConfig'] = OutputQuantizeConfig\n", + " scope['WeightQuantizeConfig'] = WeightQuantizeConfig\n", + " scope['VarianceQuantizeConfig'] = VarianceQuantizeConfig\n", + "\n", + " return layer_param_dict, scope" + ] + }, + { + "cell_type": "markdown", + "id": "55225a25", + "metadata": {}, + "source": [ + "#### 3.2 Load the necessary API classes/functions" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ff279a49", + "metadata": {}, + "outputs": [], + "source": [ + "quantize_annotate_layer = tfmot.quantization.keras.quantize_annotate_layer\n", + "quantize_annotate_model = tfmot.quantization.keras.quantize_annotate_model\n", + "quantize_apply = tfmot.quantization.keras.quantize_apply\n", + "quantize_scope = tfmot.quantization.keras.quantize_scope\n", + "Default8BitClusterPreserveQuantizeScheme = tfmot.experimental.combine.Default8BitClusterPreserveQuantizeScheme\n", + "strip_clustering_cqat = tfmot.experimental.combine.strip_clustering_cqat" + ] + }, + { + "cell_type": "markdown", + "id": "1f6598c0", + "metadata": {}, + "source": [ + "#### 3.3 Apply QAT\n", + "\n", + "When calling the `quantize_apply` function, if an unsupported layer is missing from `layer_param_dict` or the `scope`, TFMOT will throw an error." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "881a8ee9", + "metadata": {}, + "outputs": [], + "source": [ + "layer_param_dict, scope = get_quant_configs(stripped_pruned_clustered_model)\n", + "\n", + "# Wrap each custom layer with the corresponding QuantizeConfig:\n", + "pcqat_model = layer_wrapper(stripped_pruned_clustered_model, quantize_annotate_layer, layer_param_dict)\n", + "# Quantize the rest of the model with the API defaults:\n", + "pcqat_model = quantize_annotate_model(pcqat_model)\n", + "\n", + "with quantize_scope(scope):\n", + " pcqat_model = quantize_apply(pcqat_model, scheme=Default8BitClusterPreserveQuantizeScheme(preserve_sparsity=True))\n", + "\n", + "compile_and_fit(pcqat_model, epochs=2)\n", + "pcqat_model = strip_clustering_cqat(pcqat_model) # strip clustering variables\n", + "\n", + "WEIGHTS_PATH = './ViT_PCQAT.h5'\n", + "pcqat_model.save_weights(WEIGHTS_PATH)\n", + "print('Success')" + ] + }, + { + "cell_type": "markdown", + "id": "ce333f92", + "metadata": {}, + "source": [ + "#### 3.4. Check that the weights are still pruned and clustered" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0058a007", + "metadata": {}, + "outputs": [], + "source": [ + "print('Sparse weights:')\n", + "print_sparsity(pcqat_model)\n", + "print('Clustered weights:')\n", + "print_clusters(pcqat_model)" + ] + }, + { + "cell_type": "markdown", + "id": "26864702", + "metadata": {}, + "source": [ + "### 4. Generate an int8 TFLite file\n", + "\n", + "If we attempt to directly generate a TFLite file using the fine-tuned model above:\n", + "1. It will not have a correct batch size of 1.\n", + "2. It will have operators which are unnecessary during inference. Precisely, the extra `Subtract` operators and `ClipMin` operator in the layer normalisation blocks, which were used during training and fine-tuning, should be removed from the graph before creating the TFLite file.\n", + "\n", + "Therefore the network should be redefined with a batch size of 1 and with the redundant operators removed. The weights of the fine-tuned optimised model can then be loaded into this new model." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "86b93732", + "metadata": {}, + "outputs": [], + "source": [ + "tf.keras.backend.clear_session() # reset layer name counters\n", + "\n", + "net = get_vision_transformer(input_shape=(1, 28, 28, 1), # (batch_size, height, width, channels)\n", + " n_classes=10,\n", + " patch_size=(4, 4),\n", + " embedding_dim=16,\n", + " n_layers=2,\n", + " n_attention_heads=2,\n", + " mlp_hidden_dim=16,\n", + " trainable=False)\n", + "layer_param_dict, scope = get_quant_configs(net)\n", + "net = quantize_annotate_model(layer_wrapper(net, quantize_annotate_layer, layer_param_dict))\n", + "with quantize_scope(scope):\n", + " net = quantize_apply(net)\n", + "\n", + "net.load_weights(WEIGHTS_PATH, by_name=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d00bbff7", + "metadata": {}, + "outputs": [], + "source": [ + "MODEL_PATH = './ViT_PCQAT_int8.tflite'\n", + "\n", + "converter = tf.lite.TFLiteConverter.from_keras_model(net)\n", + "converter.optimizations = [tf.lite.Optimize.DEFAULT]\n", + "converter.inference_input_type = tf.int8\n", + "converter.inference_output_type = tf.int8\n", + "\n", + "# Experimental flag which improves efficiency for some devices\n", + "converter._experimental_disable_batchmatmul_unfold = True\n", + "\n", + "tflite_model = converter.convert()\n", + "with open(MODEL_PATH, \"wb+\") as tflite_file:\n", + " tflite_file.write(tflite_model)" + ] + }, + { + "cell_type": "markdown", + "id": "9c7df6b2", + "metadata": {}, + "source": [ + "### 5. Evaluate the TFLite model" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7a5d2a75", + "metadata": {}, + "outputs": [], + "source": [ + "interpreter = tf.lite.Interpreter(model_path=MODEL_PATH)\n", + "interpreter.allocate_tensors()\n", + "\n", + "input_details = interpreter.get_input_details()\n", + "output_details = interpreter.get_output_details()\n", + "input_scale, input_zero_point = input_details[0]['quantization']\n", + "output_scale, output_zero_point = output_details[0]['quantization']\n", + "\n", + "int8_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='int8_accuracy')\n", + "progbar = tf.keras.utils.Progbar(len(X_test), stateful_metrics=['accuracy'])\n", + "for step, (img, lbl) in enumerate(zip(X_test, y_test)):\n", + " # Set input tensor\n", + " img = img[np.newaxis, ...] / input_scale + input_zero_point\n", + " interpreter.set_tensor(input_details[0]['index'], tf.cast(img, input_details[0]['dtype']))\n", + " interpreter.invoke()\n", + "\n", + " # Get output tensor\n", + " output_data = interpreter.get_tensor(output_details[0]['index'])\n", + " output_data = output_scale * (output_data.astype(np.float32) - output_zero_point)\n", + "\n", + " # Update accuracy\n", + " int8_accuracy.update_state(lbl, output_data)\n", + " progbar.update(step + 1, values=[('accuracy', int8_accuracy.result().numpy())])\n", + "\n", + "print('Accuracy:', int8_accuracy.result().numpy())" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.8" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/tutorials/transformer_tutorials/install.sh b/tutorials/transformer_tutorials/install.sh new file mode 100755 index 0000000..8d0fbb7 --- /dev/null +++ b/tutorials/transformer_tutorials/install.sh @@ -0,0 +1,8 @@ +#!/usr/bin/env bash + +#Creates virtual environment and installs dependencies + +python3 -m venv ./venv +source ./venv/bin/activate +pip install --upgrade pip +pip install -r requirements.txt diff --git a/tutorials/transformer_tutorials/requirements.txt b/tutorials/transformer_tutorials/requirements.txt new file mode 100644 index 0000000..1a14560 --- /dev/null +++ b/tutorials/transformer_tutorials/requirements.txt @@ -0,0 +1,5 @@ +notebook +tensorflow>=2.6,<2.8 +tensorflow_model_optimization==0.7.0 +nltk==3.6.5 +flatbuffers==1.12.0 diff --git a/tutorials/transformer_tutorials/translation.ipynb b/tutorials/transformer_tutorials/translation.ipynb new file mode 100644 index 0000000..b424077 --- /dev/null +++ b/tutorials/transformer_tutorials/translation.ipynb @@ -0,0 +1,2155 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Sequence to Sequence Transformer model optimisation using TFMOT (Quantization Aware Training)\n", + "\n", + "Example notebook to demonstrate how TFMOT can be used for optimising complex sequence to sequence transformer models" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Background\n", + "\n", + "The sequence to sequence transformer is one of the initial transformer model architectures. The core idea behind the Transformer model is self-attention—the ability to attend to different positions of the input sequence to compute a representation of that sequence. The paper called [\"Attention Is All You Need\"](https://arxiv.org/pdf/1706.03762.pdf) might give a deeper insight into transformer model and their self-attention mechanism.\n", + "\n", + "\"Sequence\n", + "\n", + "[1] The above image was taken from [here](https://deepfrench.gitlab.io/deep-learning-project/)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### In this notebook:\n", + "\n", + "* The aim of this tutorial is to first train the Transformer model from [Keras tutorial](https://keras.io/examples/nlp/neural_machine_translation_with_transformer/)\n", + "* Re-write the above model as a funtional FP32 model\n", + "* Perform Quantized Aware Training (QAT) on the FP32 model\n", + "* Create and test the tflite model generated from the FP32 model after performing QAT on it. \n", + "\n", + "Note: This tutorial has re-used some code and explanation from the original [Keras tutorial](https://keras.io/examples/nlp/neural_machine_translation_with_transformer/)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### TFMOT limitations\n", + "- Subclassed models are not supported. Only sequential and functional model definitions are supported. (Pruning, Clustering & QAT)\n", + "- Custom subclassed layers are not supported. (Clustering & QAT)\n", + " - Clustering will only work with subclassed layers if the weight variables you have to cluster are not nested within another layer (e.g. MHA).\n", + " - QAT works correctly if the subclassed layer performs only 1 operation.\n", + "- Low-level tensorflow operators such as `tf.linalg.matmul` are not supported. (Only for QAT)\n", + " - QAT expects all quantised layers to be a subclass of `tf.keras.layers.Layer`." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 1. Setup" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import pathlib\n", + "import random\n", + "import tempfile\n", + "import zipfile\n", + "import re\n", + "import os\n", + "import nltk\n", + "import numpy as np\n", + "import tensorflow as tf\n", + "from tensorflow import keras\n", + "from tensorflow.keras import layers\n", + "from tensorflow.keras.layers import TextVectorization\n", + "import tensorflow_model_optimization as tfmot\n", + "from collections import defaultdict\n", + "\n", + "def reset_random_seeds():\n", + " os.environ['PYTHONHASHSEED']=str(2)\n", + " tf.random.set_seed(2)\n", + " np.random.seed(2)\n", + " random.seed(2)\n", + "\n", + "reset_random_seeds()\n", + "\n", + "print('TensorFlow version: {}'.format(tf.__version__))\n", + "print('TFMOT version: {}'.format(tfmot.__version__))\n", + "print(\"NLTK verison: {}\".format(nltk.__version__))\n", + "print(\"Numpy version: {}\".format(np.__version__))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 2. Downloading the data\n", + "\n", + "The dataset used here is English to Spanish translation dataset" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "text_file = keras.utils.get_file(\n", + " fname=\"spa-eng.zip\",\n", + " origin=\"http://storage.googleapis.com/download.tensorflow.org/data/spa-eng.zip\",\n", + " extract=True,\n", + ")\n", + "text_file = pathlib.Path(text_file).parent / \"spa-eng\" / \"spa.txt\"" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 3. Parsing the data\n", + "\n", + "Each target sentence (which is in Spanish) has `[start]` and `[end]` token prepended and appended, respectively, at this stage." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "with open(text_file) as f:\n", + " lines = f.read().split(\"\\n\")[:-1]\n", + "text_pairs = []\n", + "for line in lines:\n", + " eng, spa = line.split(\"\\t\")\n", + " spa = \"[start] \" + spa + \" [end]\"\n", + " text_pairs.append((eng, spa))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "for _ in range(5):\n", + " print(random.choice(text_pairs))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Split the dataset into train, test and validation set" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "random.shuffle(text_pairs)\n", + "num_val_samples = int(0.15 * len(text_pairs))\n", + "num_train_samples = len(text_pairs) - 2 * num_val_samples\n", + "train_pairs = text_pairs[:num_train_samples]\n", + "val_pairs = text_pairs[num_train_samples : num_train_samples + num_val_samples]\n", + "test_pairs = text_pairs[num_train_samples + num_val_samples :]\n", + "\n", + "print(f\"{len(text_pairs)} total pairs\")\n", + "print(f\"{len(train_pairs)} training pairs\")\n", + "print(f\"{len(val_pairs)} validation pairs\")\n", + "print(f\"{len(test_pairs)} test pairs\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 4. Vectorizing the text data" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Vectorization refers to the preprocessing step where text features are mapped to integer sequences where each integer represents the index of a word in a vocubulary. For this, [`tf.keras.layers.TextVecorization`](https://www.tensorflow.org/api_docs/python/tf/keras/layers/TextVectorization) layer is used.\n", + "\n", + "In our case, vectorization for English sequences is a little different from that of Spanish sequences:\n", + "\n", + "- For English string sequences, default standardization is used which strips all punctuation characters\n", + "- For Spanish string sequences, custom standardization is used which strips all characters which are not in `{` a-z.?!,¿[]`}`" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "vocab_size = 15000\n", + "seq_len = 20\n", + "batch_size = 64\n", + "embed_dim = 256\n", + "latent_dim = 2048\n", + "num_heads = 8\n", + "\n", + "\n", + "def custom_standardization(input_string):\n", + " lowercase = tf.strings.lower(input_string)\n", + " # The following regex replaces a character with \"\"\n", + " # which is not one of the following:\n", + " # 1. Lower case alphabet\n", + " # 2. Space\n", + " # 3. Is on of these characters: \".\", \"?\", \"!\", \",\", \"¿\", \"[\", \"]\"\n", + " return tf.strings.regex_replace(lowercase, '[^ a-z.?!,¿\\[\\]]', \"\")\n", + "\n", + "\n", + "eng_vectorization = TextVectorization(\n", + " max_tokens=vocab_size, output_mode=\"int\", output_sequence_length=seq_len,\n", + ")\n", + "spa_vectorization = TextVectorization(\n", + " max_tokens=vocab_size,\n", + " output_mode=\"int\",\n", + " output_sequence_length=seq_len + 1,\n", + " standardize=custom_standardization,\n", + ")\n", + "\n", + "train_eng_texts = [pair[0] for pair in train_pairs]\n", + "train_spa_texts = [pair[1] for pair in train_pairs]\n", + "eng_vectorization.adapt(train_eng_texts)\n", + "spa_vectorization.adapt(train_spa_texts)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "At each training step, the model will seek to predict target words N+1 (and beyond) using the source (or the input) sentence and the target words 0 to N. For this reason, we need (`inputs`, `targets`)\n", + "\n", + "- `inputs`:\n", + "\n", + " After vectorization, our dataset is formatted to include the following four in the `inputs` (`inputs` is essentially a list of four inputs):\n", + "\n", + " * encoder_inputs : which contains the vectorized english sentence data\n", + " * decoder_inputs : which contains the vectorized spanish (target) sentence data, i.e. target_sentence[:, :-1]. It is also the target sentence \"so far\", that is to say, the words 0 to N used to predict word N+1 (and beyond) in the target sentence. \n", + " * encoder_masks : which contains the corresponding mask data for encoder_inputs\n", + " * decoder_masks : which contains the corresponding mask data for decoder_inputs\n", + "\n", + " Please note that the two mask inputs are only required for the custom FP32 functional model as the original keras model is able to generate it's own mask. Therefore, the original model tends to ignore the two mask inputs (user doesn't need to worry about this).

\n", + " \n", + "- `targets`:\n", + "\n", + " After vectorization, our dataset is formatted to assign the target sentence offset by one (i.e. target_sentence[:, 1:]) as the `targets`. In other words this is what model will try to predict." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def format_dataset(eng, spa):\n", + " eng = eng_vectorization(eng)\n", + " spa = spa_vectorization(spa)\n", + "\n", + " # Create input masks\n", + " encoder_masks=tf.cast(tf.not_equal(np.int64(0),eng),tf.float32)\n", + " decoder_masks=tf.cast(tf.not_equal(np.int64(0),spa[:, :-1]),tf.float32)\n", + " \n", + " return ({\"encoder_inputs\": eng, \"encoder_masks\": encoder_masks, \"decoder_inputs\": spa[:, :-1], \"decoder_masks\": decoder_masks}, spa[:, 1:])\n", + "\n", + "def make_dataset(pairs, batch_size=64):\n", + " eng_texts, spa_texts = zip(*pairs)\n", + " eng_texts = list(eng_texts)\n", + " spa_texts = list(spa_texts)\n", + " dataset = tf.data.Dataset.from_tensor_slices((eng_texts, spa_texts))\n", + " dataset = dataset.batch(batch_size, drop_remainder=True)\n", + " dataset = dataset.map(format_dataset)\n", + " \n", + " return dataset.shuffle(2048).prefetch(16).cache()\n", + "\n", + "\n", + "train_ds = make_dataset(train_pairs, batch_size)\n", + "val_ds = make_dataset(val_pairs, batch_size)\n", + "test_ds = make_dataset(test_pairs, batch_size)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "for inputs, targets in train_ds.take(1):\n", + " print(f'inputs[\"encoder_inputs\"].shape: {inputs[\"encoder_inputs\"].shape}')\n", + " print(f'inputs[\"decoder_inputs\"].shape: {inputs[\"decoder_inputs\"].shape}')\n", + " print(f'inputs[\"encoder_masks\"].shape: {inputs[\"encoder_masks\"].shape}')\n", + " print(f'inputs[\"decoder_masks\"].shape: {inputs[\"decoder_masks\"].shape}')\n", + " print(f\"targets.shape: {targets.shape}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 5. Utility functions" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Typically, BLEU score is used to measure the quality of a translation." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def bleu_score(real_text, predicted_text):\n", + " '''Get BLEU score'''\n", + " return (nltk.translate.bleu_score.corpus_bleu(real_text,predicted_text))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "For decoding (or in other words translating a source sentence to a target sentenceg), we provide a vectorized source sentence as `encoder_inputs` and a vecotrized `[start]` token (ofcourse, padded to match the right sequence length) as the `decoder_inputs`, then we repeatedly generated the next token, until we hit the token `[end]`.\n", + "\n", + "A key thing to note is that in the custom FP32 functional model used in this notebook `encoder_masks` and `decoder_masks` are also fed into the model." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def get_text_result(model, num_samples_to_eval =200, no_input_masks=False):\n", + " '''\n", + " Function to calculate BLEU score on test set\n", + "\n", + " num_samples_to_eval: Represents the total number of test sentences to\n", + " consider during evaluation. If you want the entire \n", + " test set to be used for evaluation then set \n", + " max_sample = -1\n", + " \n", + " no_input_masks: Set as True for the original transformer model from\n", + " keras example\n", + " '''\n", + "\n", + " spa_vocab = spa_vectorization.get_vocabulary()\n", + " spa_index_lookup = dict(zip(range(len(spa_vocab)), spa_vocab))\n", + " max_decoded_sentence_length = 20\n", + "\n", + " def decode_sequence_func(input_sentence):\n", + "\n", + " tokenized_input_sentence = eng_vectorization([input_sentence])\n", + " encoder_mask = tf.cast(tf.not_equal(np.int64(0),tokenized_input_sentence), tf.float32)\n", + "\n", + " decoded_sentence = \"[start]\"\n", + " for i in range(max_decoded_sentence_length):\n", + " tokenized_target_sentence = spa_vectorization([decoded_sentence])[:, :-1]\n", + " decoder_mask=tf.cast(tf.not_equal(np.int64(0),tokenized_target_sentence), tf.float32)\n", + " if no_input_masks:\n", + " predictions = model([tokenized_input_sentence, tokenized_target_sentence])\n", + " else:\n", + " predictions = model([tokenized_input_sentence, encoder_mask, tokenized_target_sentence,decoder_mask])\n", + " sampled_token_index = np.argmax(predictions[0, i, :])\n", + " sampled_token = spa_index_lookup[sampled_token_index]\n", + " decoded_sentence += \" \" + sampled_token\n", + "\n", + " if sampled_token == \"[end]\":\n", + " break\n", + "\n", + " return decoded_sentence\n", + "\n", + "\n", + " hypothesis= []\n", + " references = []\n", + " test_sample_count = sum(1 for e in test_pairs) \n", + " progbar = tf.keras.utils.Progbar(test_sample_count if num_samples_to_eval == -1 else num_samples_to_eval)\n", + "\n", + " for step, (inp, target) in enumerate(test_pairs[:num_samples_to_eval]):\n", + " translated = decode_sequence_func(inp)\n", + " target=target.lower()\n", + " target=re.sub('[^ a-z.?!,¿\\[\\]]', \"\",target)\n", + " hypothesis.append(translated.split()[1:-1])\n", + " references.append([target.split()[1:-1]])\n", + " progbar.update(step + 1)\n", + "\n", + " print(str(\"Bleu Score: \") + str(bleu_score(references[:], hypothesis[:])))\n", + "\n", + " # Print first 10 actual and predicted spanish translation for sanity check\n", + " for i in range(10):\n", + " print(references[i][0])\n", + " print(hypothesis[i])\n", + " print(\"-----------------------/n\")\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Suggestion: While trying to run inference on a tflite file please make sure that the scale, zero_point and data type are correct for the inputs and outputs" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def get_text_result_tflite(model_path, input_type = 'int8/32', output_type = 'int8', num_samples_to_eval = 200):\n", + " '''\n", + " Function to calculate BLEU score for a given tflite file on the test set\n", + "\n", + " model_path: Path to the tflite file\n", + "\n", + " input_type: Could be float32 or int8/32. If the inputs in tflite graph\n", + " are float32 set this value to 'float32' but if inputs are\n", + " int8 (mask inputs) and int32 (non-maks inputs) set this\n", + " value to 'int8/32'.\n", + "\n", + " output_type: Could be float32 or int8. If the outputs in tflite graph\n", + " are float32 set this value to 'float32' but if output\n", + " are int8 set this value to 'int8'.\n", + " \n", + " num_samples_to_eval: Evaluation of entire test set will take a lot\n", + " of time therefore, only first 200 samples are \n", + " evaluated. To evaluate the entire test-set, \n", + " set the value below to a negative value, e.g.\n", + " -1\n", + " '''\n", + " assert(input_type in ['float32', 'int8/32']), \"input_type not supported\"\n", + " assert(output_type in ['float32', 'int8']), \"output_type not supported\"\n", + "\n", + " print('Performing BLEU evaluation for tflite file at {}'.format(model_path))\n", + "\n", + " spa_vocab = spa_vectorization.get_vocabulary()\n", + " spa_index_lookup = dict(zip(range(len(spa_vocab)), spa_vocab))\n", + " max_decoded_sentence_length = seq_len\n", + "\n", + " interpreter = tf.lite.Interpreter(model_path=model_path)\n", + "\n", + " input_details = interpreter.get_input_details()\n", + " output_details = interpreter.get_output_details()\n", + "\n", + " input_scale_1, input_zero_point_1 = input_details[0]['quantization']\n", + " input_scale_2, input_zero_point_2 = input_details[1]['quantization']\n", + " input_scale_3, input_zero_point_3 = input_details[2]['quantization']\n", + " input_scale_4, input_zero_point_4 = input_details[3]['quantization']\n", + " output_scale, output_zero_point = output_details[0]['quantization']\n", + "\n", + " interpreter.allocate_tensors()\n", + "\n", + " def decode_sequence_func(input_sentence):\n", + "\n", + " input_1 = eng_vectorization([input_sentence])\n", + " input_2 = tf.cast(tf.not_equal(np.int64(0),input_1), tf.float32)\n", + " if input_type == 'int8/32':\n", + " input_2 = input_2/ input_scale_2 + input_zero_point_2\n", + "\n", + " decoded_sentence = \"[start]\"\n", + "\n", + " for i in range(max_decoded_sentence_length):\n", + " input_3 = spa_vectorization([decoded_sentence])[:, :-1]\n", + " input_4=tf.cast(tf.not_equal(np.int64(0),input_3), tf.float32)\n", + "\n", + " # Set input tensor\n", + " interpreter.set_tensor(input_details[0]['index'], tf.cast(input_1, input_details[0]['dtype']))\n", + "\n", + " # Set input tensor\n", + " interpreter.set_tensor(input_details[1]['index'], tf.cast(input_2, input_details[1]['dtype']))\n", + "\n", + " # Set input tensor\n", + " interpreter.set_tensor(input_details[2]['index'], tf.cast(input_3, input_details[2]['dtype']))\n", + "\n", + " # Set input tensor\n", + " if input_type == 'int8/32':\n", + " input_4 = input_4/ input_scale_4 + input_zero_point_4\n", + " interpreter.set_tensor(input_details[3]['index'], tf.cast(input_4, input_details[3]['dtype']))\n", + "\n", + " interpreter.invoke()\n", + " \n", + " # Get output tensor\n", + " output_data = interpreter.get_tensor(output_details[0]['index'])\n", + " predictions = output_data.astype(np.float32)\n", + " if output_type == 'int8':\n", + " predictions = output_scale * (predictions - output_zero_point)\n", + " \n", + " sampled_token_index = np.argmax(predictions[0, i, :])\n", + " sampled_token = spa_index_lookup[sampled_token_index]\n", + " decoded_sentence += \" \" + sampled_token\n", + "\n", + " if sampled_token == \"[end]\":\n", + " break\n", + "\n", + " return decoded_sentence\n", + "\n", + "\n", + " hypothesis= []\n", + " references = []\n", + " test_sample_count = sum(1 for e in test_pairs) \n", + " progbar = tf.keras.utils.Progbar(test_sample_count if num_samples_to_eval == -1 else num_samples_to_eval)\n", + "\n", + " for step, (inp, target) in enumerate(test_pairs[:num_samples_to_eval]):\n", + " translated = decode_sequence_func(inp)\n", + " target=target.lower()\n", + " target=re.sub('[^ a-z.?!,¿\\[\\]]', \"\",target)\n", + " hypothesis.append(translated.split()[1:-1])\n", + " references.append([target.split()[1:-1]])\n", + " progbar.update(step + 1)\n", + "\n", + " print(str(\"Bleu Score: \") + str(bleu_score(references[:], hypothesis[:])))\n", + " for i in range(10):\n", + " print(references[i][0])\n", + " print(hypothesis[i])\n", + " print(\"-----------------------/n\")\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def get_tflite_accuracy(model_path, input_type = 'int8/32', output_type = 'int8', num_samples_to_eval = 200):\n", + " '''\n", + " Function to calculate accuracy of a given tflite file on the test set\n", + "\n", + " model_path: Path to the tflite file\n", + "\n", + " input_type: Could be float32 or int8/32. If the inputs in tflite graph\n", + " are float32 set this value to 'float32' but if inputs are\n", + " int8 (mask inputs) and int32 (non-maks inputs) set this\n", + " value to 'int8/64'.\n", + "\n", + " output_type: Could be float32 or int8. If the outputs in tflite graph\n", + " are float32 set this value to 'float32' but if output\n", + " are int8 set this value to 'int8'.\n", + " \n", + " num_samples_to_eval: Evaluation of entire test set will take a lot\n", + " of time therefore, only first 200 samples are \n", + " evaluated. To evaluate the entire test-set, \n", + " set the value below to a negative value, e.g.\n", + " -1\n", + " '''\n", + " assert(input_type in ['float32', 'int8/32']), \"input_type not supported\"\n", + " assert(output_type in ['float32', 'int8']), \"output_type not supported\"\n", + "\n", + " print('Performing accuracy evaluation for tflite file at {}'.format(model_path))\n", + "\n", + " interpreter = tf.lite.Interpreter(model_path=model_path)\n", + "\n", + " input_details = interpreter.get_input_details()\n", + " output_details = interpreter.get_output_details()\n", + "\n", + " input_scale_1, input_zero_point_1 = input_details[0]['quantization']\n", + " input_scale_2, input_zero_point_2 = input_details[1]['quantization']\n", + " input_scale_3, input_zero_point_3 = input_details[2]['quantization']\n", + " input_scale_4, input_zero_point_4 = input_details[3]['quantization']\n", + " output_scale, output_zero_point = output_details[0]['quantization']\n", + " interpreter.allocate_tensors()\n", + "\n", + " test_ds_tflite = make_dataset(test_pairs, 1)\n", + " accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='accuracy')\n", + " progbar = tf.keras.utils.Progbar(sum(1 for e in test_ds_tflite) if num_samples_to_eval == -1 else num_samples_to_eval, stateful_metrics=['accuracy'])\n", + "\n", + " for step, (input, target) in enumerate(test_ds_tflite):\n", + "\n", + " # Set input tensor\n", + " input_1 = input['encoder_inputs']\n", + " interpreter.set_tensor(input_details[0]['index'], tf.cast(input_1, input_details[0]['dtype']))\n", + "\n", + " # Set input tensorprogress bars for loopp python\n", + " input_2=input['encoder_masks']\n", + " if input_type == 'int8/32':\n", + " input_2 = tf.cast(input_2, tf.float32)\n", + " input_2 = input_2/ input_scale_2 + input_zero_point_2\n", + " interpreter.set_tensor(input_details[1]['index'], tf.cast(input_2, input_details[1]['dtype']))\n", + "\n", + " # Set input tensor\n", + " input_3 = input['decoder_inputs']\n", + " interpreter.set_tensor(input_details[2]['index'], tf.cast(input_3, input_details[2]['dtype']))\n", + "\n", + " # Set input tensor\n", + " input_4=input['decoder_masks']\n", + " if input_type == 'int8/32':\n", + " input_4 = tf.cast(input_4, tf.float32)\n", + " input_4 = input_4/ input_scale_4 + input_zero_point_4\n", + " interpreter.set_tensor(input_details[3]['index'], tf.cast(input_4, input_details[3]['dtype']))\n", + " interpreter.invoke()\n", + " \n", + " # Get output tensor\n", + " output_data = interpreter.get_tensor(output_details[0]['index'])\n", + " output_data = output_data.astype(np.float32)\n", + " if output_type == 'int8':\n", + " output_data = output_scale * (output_data - output_zero_point)\n", + " \n", + " # Update accuracy\n", + " mask = input['decoder_inputs']\n", + " accuracy.update_state(target, output_data, mask)\n", + " progbar.update(step + 1, values=[('accuracy', accuracy.result().numpy())])\n", + " \n", + " if step == num_samples_to_eval:\n", + " break" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Use the following function to get the size of the tflite file when zipped" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def get_gzipped_model_size(file):\n", + " '''Returns the size of a gzipped tflite file in kilobytes'''\n", + "\n", + " _, zipped_file = tempfile.mkstemp('.zip')\n", + " with zipfile.ZipFile(zipped_file, 'w', compression=zipfile.ZIP_DEFLATED) as f:\n", + " f.write(file)\n", + "\n", + " return os.path.getsize(zipped_file)/1000" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 6. Functions related to Training the model\n", + "\n", + "The loss function used is Masked Sparse Categorical Crossentropy loss (which uses the `tf.keras.losses.SparseCategoricalCrossentropy` but with masks).\n", + "The loss function needs masks to be propogated correctly through the model layers down to the loss function which, the custom FP32 model wasn't able to do correctly therefore, a custom training loop was needed to calculate the loss correctly." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "epochs = 11\n", + "\n", + "def evaluate(model_to_eval, training=False):\n", + "\n", + " val_loss = tf.keras.metrics.SparseCategoricalCrossentropy()\n", + " val_acc = tf.keras.metrics.SparseCategoricalAccuracy()\n", + "\n", + " @tf.function\n", + " def eval_step(inp, y_true):\n", + " preds = model_to_eval(inp, training=training)\n", + " # masked loss\n", + " val_loss.update_state(y_true, preds,tf.cast(tf.not_equal(np.int64(0),inp['decoder_inputs']),tf.float32)) \n", + " # masked accuracy\n", + " val_acc.update_state(y_true, preds,tf.cast(tf.not_equal(np.int64(0),inp['decoder_inputs']),tf.float32)) \n", + "\n", + " for step, (inp, y_true) in enumerate(val_ds):\n", + " eval_step(inp, y_true)\n", + "\n", + " return {'loss': val_loss.result().numpy(), 'accuracy': val_acc.result().numpy()}\n", + "\n", + "\n", + "def train(model_to_train, save_best_weights =True, model_type='original', lr=1e-3, epochs = epochs):\n", + "\n", + " if model_type == 'original':\n", + " ckpt_path = './eng_spa_transformer_qat_tutorial_model.h5'\n", + " elif model_type == 'fp32':\n", + " ckpt_path = './eng_spa_transformer_qat_tutorial_fp32_model.h5'\n", + " elif model_type == 'qat':\n", + " ckpt_path = './eng_spa_transformer_qat_tutorial_qat_model.h5'\n", + " else:\n", + " print('Please select the correct model type!!')\n", + " return None\n", + " \n", + " print('Training (save_best_weights={}, model_type={})'.format(save_best_weights, model_type))\n", + "\n", + " loss_fn = tf.keras.losses.SparseCategoricalCrossentropy()\n", + " optimiser = tf.keras.optimizers.Adam(learning_rate=lr)\n", + " train_acc = tf.keras.metrics.SparseCategoricalAccuracy()\n", + " model_to_train.optimizer = optimiser\n", + "\n", + " @tf.function\n", + " def train_step(inp, y_true):\n", + " mask =tf.cast(tf.not_equal(np.int64(0),inp['decoder_inputs']),tf.float32)\n", + " preds=None\n", + " loss=None\n", + " \n", + " with tf.GradientTape() as tape:\n", + " preds = model_to_train(inp, training=True)\n", + " # Masked loss\n", + " loss = loss_fn(y_true, preds, mask)\n", + " grads = tape.gradient(loss, model_to_train.trainable_weights)\n", + " optimiser.apply_gradients(zip(grads, model_to_train.trainable_weights))\n", + "\n", + " # Masked accuracy \n", + " train_acc.update_state(y_true, preds, mask)\n", + " return loss\n", + "\n", + " max_val = float('-inf')\n", + "\n", + " for epoch in range(epochs):\n", + " print('Epoch {}/{}'.format(epoch + 1, epochs), flush=True)\n", + " # Train\n", + " progbar = tf.keras.utils.Progbar(len(train_ds), interval=.5,\n", + " stateful_metrics=['acc']) \n", + "\n", + " for step, (inp, y_true) in enumerate(train_ds):\n", + " loss = train_step(inp, y_true)\n", + " progbar.update(step + 1, values=[('loss', loss),\n", + " ('acc', train_acc.result())])\n", + "\n", + " # Evaluate\n", + " val_results = evaluate(model_to_train)\n", + "\n", + " validation_accuracy = val_results['accuracy']\n", + " print('Validation accuracy: {}'.format(validation_accuracy))\n", + "\n", + " if save_best_weights and validation_accuracy > max_val:\n", + " \n", + " print('Best validation accuracy so far, saving weights')\n", + " model_to_train.save_weights(ckpt_path)\n", + " max_val = validation_accuracy\n", + "\n", + " train_acc.reset_states() \n", + "\n", + " if not save_best_weights:\n", + " model_to_train.save_weights(ckpt_path)\n", + " # Load weights\n", + " model_to_train.load_weights(ckpt_path)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 7. Building the original Transformer Keras model mentioned in the [Keras tutorial](https://keras.io/examples/nlp/neural_machine_translation_with_transformer/)\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "(a) Define the custom layers for the model" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "class TransformerEncoder(layers.Layer):\n", + " def __init__(self, embed_dim, dense_dim, num_heads, **kwargs):\n", + " super(TransformerEncoder, self).__init__(**kwargs)\n", + " self.embed_dim = embed_dim\n", + " self.dense_dim = dense_dim\n", + " self.num_heads = num_heads\n", + " self.attention = layers.MultiHeadAttention(\n", + " num_heads=num_heads, key_dim=embed_dim\n", + " )\n", + " self.dense_proj = keras.Sequential(\n", + " [layers.Dense(dense_dim, activation=\"relu\"), layers.Dense(embed_dim),]\n", + " )\n", + " self.layernorm_1 = layers.LayerNormalization()\n", + " self.layernorm_2 = layers.LayerNormalization()\n", + " self.supports_masking = True\n", + "\n", + " def call(self, inputs, mask=None):\n", + " if mask is not None:\n", + " padding_mask = tf.cast(mask[:, tf.newaxis, tf.newaxis, :], dtype=\"int32\")\n", + " attention_output = self.attention(\n", + " query=inputs, value=inputs, key=inputs, attention_mask=padding_mask\n", + " )\n", + " proj_input = self.layernorm_1(inputs + attention_output)\n", + " proj_output = self.dense_proj(proj_input)\n", + " return self.layernorm_2(proj_input + proj_output)\n", + " \n", + " def get_config(self):\n", + " config = super().get_config()\n", + " config.update({'embed_dim': self.embed_dim,\n", + " 'dense_dim': self.dense_dim,\n", + " 'num_heads': self.num_heads})\n", + " return config\n", + "\n", + "\n", + "class PositionalEmbedding(layers.Layer):\n", + " def __init__(self, seq_len, vocab_size, embed_dim, **kwargs):\n", + " super(PositionalEmbedding, self).__init__(**kwargs)\n", + " self.token_embeddings = layers.Embedding(\n", + " input_dim=vocab_size, output_dim=embed_dim\n", + " )\n", + " self.position_embeddings = layers.Embedding(\n", + " input_dim=seq_len, output_dim=embed_dim\n", + " )\n", + " self.seq_len = seq_len\n", + " self.vocab_size = vocab_size\n", + " self.embed_dim = embed_dim\n", + "\n", + " def call(self, inputs):\n", + " length = tf.shape(inputs)[-1]\n", + " positions = tf.range(start=0, limit=length, delta=1)\n", + " embedded_tokens = self.token_embeddings(inputs)\n", + " embedded_positions = self.position_embeddings(positions)\n", + " return embedded_tokens + embedded_positions\n", + "\n", + " def compute_mask(self, inputs, mask=None):\n", + " return tf.math.not_equal(inputs, 0)\n", + "\n", + " def get_config(self):\n", + " config = super().get_config()\n", + " config.update({'embed_dim': self.embed_dim,\n", + " 'vocab_size': self.vocab_size,\n", + " 'seq_len': self.seq_len})\n", + " return config\n", + "\n", + "\n", + "class TransformerDecoder(layers.Layer):\n", + " def __init__(self, embed_dim, latent_dim, num_heads, **kwargs):\n", + " super(TransformerDecoder, self).__init__(**kwargs)\n", + " self.embed_dim = embed_dim\n", + " self.latent_dim = latent_dim\n", + " self.num_heads = num_heads\n", + " self.attention_1 = layers.MultiHeadAttention(\n", + " num_heads=num_heads, key_dim=embed_dim\n", + " )\n", + " self.attention_2 = layers.MultiHeadAttention(\n", + " num_heads=num_heads, key_dim=embed_dim\n", + " )\n", + " self.dense_proj = keras.Sequential(\n", + " [layers.Dense(latent_dim, activation=\"relu\"), layers.Dense(embed_dim),]\n", + " )\n", + " self.layernorm_1 = layers.LayerNormalization()\n", + " self.layernorm_2 = layers.LayerNormalization()\n", + " self.layernorm_3 = layers.LayerNormalization()\n", + " self.supports_masking = True\n", + "\n", + " def call(self, inputs, encoder_outputs, mask=None):\n", + " causal_mask = self.get_causal_attention_mask(inputs)\n", + " if mask is not None:\n", + " padding_mask = tf.cast(mask[:, tf.newaxis, :], dtype=\"int32\")\n", + " padding_mask = tf.minimum(padding_mask, causal_mask)\n", + "\n", + " attention_output_1 = self.attention_1(\n", + " query=inputs, value=inputs, key=inputs, attention_mask=causal_mask\n", + " )\n", + " out_1 = self.layernorm_1(inputs + attention_output_1)\n", + "\n", + " attention_output_2 = self.attention_2(\n", + " query=out_1,\n", + " value=encoder_outputs,\n", + " key=encoder_outputs,\n", + " attention_mask=padding_mask,\n", + " )\n", + " out_2 = self.layernorm_2(out_1 + attention_output_2)\n", + "\n", + " proj_output = self.dense_proj(out_2)\n", + " return self.layernorm_3(out_2 + proj_output)\n", + "\n", + " def get_causal_attention_mask(self, inputs):\n", + " input_shape = tf.shape(inputs)\n", + " batch_size, seq_len = input_shape[0], input_shape[1]\n", + " i = tf.range(seq_len)[:, tf.newaxis]\n", + " j = tf.range(seq_len)\n", + " mask = tf.cast(i >= j, dtype=\"int32\")\n", + " mask = tf.reshape(mask, (1, input_shape[1], input_shape[1]))\n", + " mult = tf.concat(\n", + " [tf.expand_dims(batch_size, -1), tf.constant([1, 1], dtype=tf.int32)],\n", + " axis=0,\n", + " )\n", + " return tf.tile(mask, mult)\n", + " \n", + " def get_config(self):\n", + " config = super().get_config()\n", + " config.update({'embed_dim': self.embed_dim,\n", + " 'latent_dim': self.latent_dim,\n", + " 'num_heads': self.num_heads})\n", + " return config" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "(b) Build the end-to-end model" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def get_encoder_decoder_model():\n", + " encoder_inputs = keras.Input(shape=(20,), dtype=\"int64\", name=\"encoder_inputs\")\n", + " x = PositionalEmbedding(seq_len, vocab_size, embed_dim)(encoder_inputs)\n", + " encoder_outputs = TransformerEncoder(embed_dim, latent_dim, num_heads)(x)\n", + " decoder_inputs = keras.Input(shape=(20,), dtype=\"int64\", name=\"decoder_inputs\")\n", + " encoded_seq_inputs = encoder_outputs\n", + " x = PositionalEmbedding(seq_len, vocab_size, embed_dim)(decoder_inputs)\n", + " x = TransformerDecoder(embed_dim, latent_dim, num_heads)(x, encoded_seq_inputs)\n", + " x = layers.Dropout(0.5)(x)\n", + " decoder_outputs = layers.Dense(vocab_size, activation=\"softmax\")(x)\n", + " \n", + " transformer = keras.Model(\n", + " [encoder_inputs, decoder_inputs], decoder_outputs, name=\"transformer\"\n", + " )\n", + "\n", + " return transformer\n", + "\n", + "transformer = get_encoder_decoder_model()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "(c) Training the original Transformer model from Keras example" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "transformer.summary()\n", + "train(transformer, model_type='original')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "(d) Evaluate performance" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Get BLEU score on test set for original transformer model\n", + "get_text_result(transformer, no_input_masks=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Get accuracy on test set for the original transformer model from Keras example\n", + "evaluate(transformer)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 8. Create FP32 Function Model for the Transformer model\n", + "\n", + "Custom Keras layers must be defined for all of the low-level TensorFlow operators (each must only contain a single operation for QAT).\n", + "\n", + "Since none of these will have any prunable weights, first we create a base prunable layer class to extend, instead of `tf.keras.layers.Layer`." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "(a) Create a base prunable layer class " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "class PrunableLayer(tf.keras.layers.Layer, tfmot.sparsity.keras.PrunableLayer):\n", + " def get_prunable_weights(self): return []" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "(b) Define low level TensorFlow operations as Keras subclassed layers\n", + "\n", + "Note that some of these layers have trainable weights defined using the `add_weight` method. These weights will not be pruned or clustered." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "class Tanh(PrunableLayer):\n", + " def __init__(self, **kwargs):\n", + " super().__init__(**kwargs)\n", + "\n", + " def call(self, x):\n", + " return tf.math.tanh(x)\n", + "\n", + "\n", + "class Relu(PrunableLayer):\n", + " def __init__(self, **kwargs):\n", + " super().__init__(**kwargs)\n", + " \n", + " def call(self, x):\n", + " return tf.maximum(0., x)\n", + "\n", + " \n", + "class MatMul(PrunableLayer):\n", + " def __init__(self, transpose_b=False, **kwargs):\n", + " super().__init__(**kwargs)\n", + " self.transpose_b = transpose_b \n", + " \n", + " def call(self, inputs):\n", + " return tf.linalg.matmul(*inputs, transpose_b=self.transpose_b)\n", + " \n", + " def get_config(self):\n", + " config = super().get_config()\n", + " config.update({'transpose_b': self.transpose_b})\n", + " return config\n", + "\n", + "\n", + "class Multiply(PrunableLayer):\n", + " def __init__(self, **kwargs):\n", + " super().__init__(**kwargs)\n", + " \n", + " def call(self, inputs):\n", + " return tf.multiply(*inputs)\n", + "\n", + "\n", + "# Calling Multiply with a scalar input will lead to an error.\n", + "# Use the following ScalarMultiply class instead.\n", + "class ScalarMultiply(PrunableLayer):\n", + " def __init__(self, scalar, **kwargs):\n", + " super().__init__(**kwargs)\n", + " self.scalar = scalar \n", + " \n", + " def call(self, x):\n", + " return tf.math.multiply(x, self.scalar)\n", + " \n", + " def get_config(self):\n", + " config = super().get_config()\n", + " config.update({'scalar': self.scalar})\n", + " return config\n", + "\n", + "\n", + "class Add(PrunableLayer):\n", + " def __init__(self, **kwargs):\n", + " super().__init__(**kwargs)\n", + " \n", + " def call(self, inputs):\n", + " return tf.math.add(*inputs)\n", + "\n", + "\n", + "# Calling Add with a scalar input will lead to an error.\n", + "# Use the following ScalarAdd class instead.\n", + "class ScalarAdd(PrunableLayer):\n", + " def __init__(self, scalar, **kwargs):\n", + " super().__init__(**kwargs)\n", + " self.scalar = scalar \n", + " \n", + " def call(self, x):\n", + " return tf.math.add(x, self.scalar)\n", + " \n", + " def get_config(self):\n", + " config = super().get_config()\n", + " config.update({'scalar': self.scalar})\n", + " return config\n", + "\n", + "\n", + "class Slice(PrunableLayer):\n", + " def __init__(self, seq_idx, **kwargs):\n", + " super().__init__(**kwargs)\n", + " self.seq_idx = seq_idx \n", + " \n", + " def call(self, x):\n", + " return x[:, self.seq_idx, ...]\n", + " \n", + " def get_config(self):\n", + " config = super().get_config()\n", + " config.update({'seq_idx': self.seq_idx})\n", + " return config\n", + "\n", + "\n", + "class Mean(PrunableLayer):\n", + " def __init__(self, axes=None, keepdims=True, **kwargs):\n", + " super().__init__(**kwargs)\n", + " self.axes=axes\n", + " self.keepdims = keepdims \n", + " \n", + " def call(self, x):\n", + " return tf.math.reduce_mean(x, axis=self.axes, keepdims=self.keepdims)\n", + " \n", + " def get_config(self):\n", + " config = super().get_config()\n", + " config.update({'axes': self.axes,\n", + " 'keepdims': self.keepdims})\n", + " return config\n", + "\n", + "\n", + "class Subtract(PrunableLayer):\n", + " def __init__(self, **kwargs):\n", + " super().__init__(**kwargs) \n", + " \n", + " def call(self, inputs):\n", + " return tf.math.subtract(*inputs)\n", + "\n", + "\n", + "class ScalarSubtract(PrunableLayer):\n", + " def __init__(self, scalar, **kwargs):\n", + " super().__init__(**kwargs)\n", + " self.scalar = scalar \n", + " \n", + " def call(self, x):\n", + " return tf.math.subtract(self.scalar,x)\n", + " \n", + " def get_config(self):\n", + " config = super().get_config()\n", + " config.update({'scalar': self.scalar})\n", + " return config\n", + "\n", + "\n", + "class SquaredDiffrence(PrunableLayer):\n", + " def __init__(self, **kwargs):\n", + " super().__init__(**kwargs) \n", + " \n", + " def call(self,inputs):\n", + " return tf.math.squared_difference(*inputs)\n", + "\n", + "\n", + "class StopGradient(PrunableLayer):\n", + " def __init__(self, **kwargs):\n", + " super().__init__(**kwargs)\n", + " \n", + " def call(self, x):\n", + " return tf.stop_gradient(x)\n", + "\n", + "\n", + "class RSqrt(PrunableLayer):\n", + " def __init__(self, **kwargs):\n", + " super().__init__(**kwargs)\n", + " \n", + " def call(self, x):\n", + " return tf.math.rsqrt(x)\n", + "\n", + "\n", + "class Clip(PrunableLayer):\n", + " def __init__(self, **kwargs):\n", + " super().__init__(**kwargs)\n", + " \n", + " def call(self, x):\n", + " return tf.clip_by_value(x, 0.001, 255.0)\n", + "\n", + "\n", + "class BroadcastToken(PrunableLayer):\n", + " \"\"\"Layer to broadcast the class token\"\"\"\n", + " def __init__(self, embedding_dim, **kwargs):\n", + " super().__init__(**kwargs)\n", + " self.embedding_dim = embedding_dim\n", + "\n", + " def build(self, input_shape):\n", + " self.w = self.add_weight(shape=(1, 1, self.embedding_dim), initializer='zeros', \n", + " trainable=True, name='token')\n", + " super().build(input_shape)\n", + "\n", + " def call(self, x):\n", + " batch_size = tf.shape(x)[0]\n", + " return tf.broadcast_to(self.w, [batch_size, 1, self.embedding_dim])\n", + "\n", + " def get_config(self):\n", + " config = super().get_config()\n", + " config.update({'embedding_dim': self.embedding_dim})\n", + " return config\n", + "\n", + "\n", + "class AddPositionalEmbedding(PrunableLayer):\n", + " \"\"\"Layer to add positional embeddings to the tokens\"\"\"\n", + " def __init__(self, seq_len, embedding_dim, **kwargs):\n", + " super().__init__(**kwargs)\n", + " self.embedding_dim = embedding_dim\n", + " self.seq_len = seq_len\n", + "\n", + " def build(self, input_shape):\n", + " self.w = self.add_weight(shape=(self.seq_len, self.embedding_dim), initializer= 'uniform',\n", + " trainable=True, name='pos_emb')\n", + " super().build(input_shape)\n", + "\n", + " def call(self, x):\n", + " return x + self.w\n", + "\n", + " def get_config(self):\n", + " config = super().get_config()\n", + " config.update({'embedding_dim': self.embedding_dim, 'seq_len': self.seq_len})\n", + " return config\n", + "\n", + "\n", + "class AddTokenEmbedding(PrunableLayer): \n", + " \"\"\"Layer to add token embeddings to the tokens\"\"\"\n", + " def __init__(self, vocab_size, embedding_dim, train = True, **kwargs):\n", + " super().__init__(**kwargs)\n", + " self.embedding_dim = embedding_dim\n", + " self.vocab_size = vocab_size\n", + " self.train = train\n", + "\n", + " def build(self, input_shape):\n", + " self.w = self.add_weight(shape=(self.vocab_size, self.embedding_dim), initializer= 'uniform',\n", + " trainable=self.train, name='token_emb')\n", + " super().build(input_shape)\n", + "\n", + " def call(self, x):\n", + " return tf.gather(self.w,x)\n", + "\n", + " def get_config(self):\n", + " config = super().get_config()\n", + " config.update({'embedding_dim': self.embedding_dim, 'vocab_size': self.vocab_size, 'train': self.train})\n", + " return config\n", + "\n", + " def compute_output_shape(self, input_shape):\n", + " return(input_shape[-1], self.embedding_dim)\n", + "\n", + "\n", + "class Scale(PrunableLayer):\n", + " \"\"\"Multiply with gamma (LayerNorm)\"\"\"\n", + " def __init__(self, axes, **kwargs):\n", + " super().__init__(**kwargs)\n", + " self.axes = axes \n", + " \n", + " def build(self, input_shape):\n", + " param_shape = [input_shape[dim] for dim in self.axes]\n", + " self.w = self.add_weight(name='gamma', shape=param_shape,\n", + " trainable=True, initializer='ones')\n", + " super().build(input_shape)\n", + " \n", + " def call(self, x):\n", + " return tf.multiply(x, self.w)\n", + " \n", + " def get_config(self):\n", + " config = super().get_config()\n", + " config.update({'axes': self.axes})\n", + " return config\n", + "\n", + " \n", + "class Centre(PrunableLayer):\n", + " \"\"\"Add beta (LayerNorm)\"\"\"\n", + " def __init__(self, axes, **kwargs):\n", + " super().__init__(**kwargs)\n", + " self.axes = axes \n", + " \n", + " def build(self, input_shape):\n", + " param_shape = [input_shape[dim] for dim in self.axes]\n", + " self.w = self.add_weight(name='beta', shape=param_shape,\n", + " trainable=True, initializer='zeros')\n", + " super().build(input_shape)\n", + " \n", + " def call(self, x):\n", + " return tf.math.add(x, self.w)\n", + " \n", + " def get_config(self):\n", + " config = super().get_config()\n", + " config.update({'axes': self.axes})\n", + " return config\n", + "\n", + "\n", + "class Minimum(PrunableLayer):\n", + " def __init__(self, **kwargs):\n", + " super().__init__(**kwargs) \n", + " \n", + " def call(self,inputs):\n", + " return tf.minimum(*inputs)\n", + "\n", + "\n", + "class MinimumScalar(PrunableLayer):\n", + " def __init__(self, scalar, **kwargs):\n", + " super().__init__(**kwargs) \n", + " self.scalar=scalar\n", + "\n", + " def call(self,inputs):\n", + " return tf.minimum(inputs, self.scalar)\n", + " \n", + " def get_config(self):\n", + " config = super().get_config()\n", + " config.update({'scalar': self.scalar})\n", + " return config\n", + "\n", + "\n", + "class MaximumScalar(PrunableLayer):\n", + " def __init__(self, scalar, **kwargs):\n", + " super().__init__(**kwargs) \n", + " self.scalar=scalar\n", + "\n", + " def call(self,inputs):\n", + " return tf.maximum(inputs, self.scalar)\n", + " \n", + " def get_config(self):\n", + " config = super().get_config()\n", + " config.update({'scalar': self.scalar})\n", + " return config\n", + "\n", + "\n", + "class Cast(PrunableLayer):\n", + " def __init__(self, type = tf.int32, **kwargs):\n", + " super().__init__(**kwargs) \n", + " self.type=type \n", + "\n", + " def call(self,inputs):\n", + " return tf.cast(inputs, self.type)\n", + "\n", + " def get_config(self):\n", + " config = super().get_config()\n", + " config.update({'type': self.type})\n", + " return config" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "(c) Define Transormer layers like multiheaded-attention, layer-norm, etc. functionally" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def self_attention(query, key, value, n_heads, dim, mask=None, name='mha', block_name=None, out_dim=None):\n", + " \"\"\"Multi-head attention layer\"\"\"\n", + " depth = dim // n_heads\n", + " if out_dim is None: out_dim = query.shape[-1]\n", + " q = tf.keras.layers.Dense(units=dim, name=f'{name}/query')(query)\n", + " k = tf.keras.layers.Dense(units=dim, name=f'{name}/key')(key)\n", + " v = tf.keras.layers.Dense(units=dim, name=f'{name}/value')(value)\n", + "\n", + " q = tf.keras.layers.Reshape((-1, n_heads, depth))(q)\n", + " q = tf.keras.layers.Permute((2, 1, 3))(q)\n", + " k = tf.keras.layers.Reshape((-1, n_heads, depth))(k)\n", + " k = tf.keras.layers.Permute((2, 1, 3))(k)\n", + " v = tf.keras.layers.Reshape((-1, n_heads, depth))(v)\n", + " v = tf.keras.layers.Permute((2, 1, 3))(v)\n", + " qk = ScalarMultiply(depth ** -0.5)(MatMul(transpose_b=True)([q, k]))\n", + "\n", + " if mask is not None:\n", + " if isinstance(mask, tf.Tensor):\n", + " qk = ScalarMultiply(mask)(qk)\n", + " mask=1. - mask\n", + " mask = mask * -10\n", + " qk = ScalarAdd(mask)(qk)\n", + " \n", + " else:\n", + " qk = Multiply()([qk, mask])\n", + " mask = ScalarSubtract(1.)(mask)\n", + " mask = ScalarMultiply(-10)(mask)\n", + " qk = Add(name=f'add/{name}')([qk, (mask)])\n", + " \n", + " attn_weights = tf.keras.layers.Softmax(axis=-1)(qk)\n", + " attn_out = MatMul()([attn_weights, v]) \n", + " attn_out = tf.keras.layers.Permute((2, 1, 3))(attn_out)\n", + " attn_out = tf.keras.layers.Reshape((-1, dim))(attn_out)\n", + " out = tf.keras.layers.Dense(out_dim, name=f'{name}/output_dense', dtype=\"float32\")(attn_out)\n", + " \n", + " return out, attn_weights\n", + "\n", + "def AddPositionalEmbeddingForEncoderDecoder(inputs, seq_len, vocab_size, embed_dim, block_name, freeze):\n", + " x = AddTokenEmbedding(vocab_size, embed_dim, train = not freeze, name= ('token_embedding/' + block_name))(inputs)\n", + " x = AddPositionalEmbedding(seq_len, embed_dim, name= ('positional_embedding/' + block_name))(x)\n", + " return x\n", + " \n", + "def enc_padding_mask(inputs):\n", + " computed_mask=tf.keras.layers.Reshape((1, 1, -1))(inputs)\n", + " return computed_mask \n", + "\n", + "def causal_mask(inputs):\n", + " seq_len=inputs.shape[1]\n", + " causal_mask = tf.linalg.band_part(tf.ones((seq_len, seq_len)), -1, 0)\n", + " return causal_mask\n", + "\n", + "def dec_padding_mask(inputs, cau_mask):\n", + " padding_mask= enc_padding_mask(inputs)\n", + " padding_mask = MinimumScalar(scalar=cau_mask)(padding_mask)\n", + " return padding_mask\n", + "\n", + "def layer_norm(x, axes=2, epsilon=0.001, name='layer_norm', trainable = True):\n", + " \"\"\"LayerNormalization\"\"\"\n", + " if isinstance(axes, int): axes = [axes]\n", + " \n", + " mean = Mean(axes=axes, dtype=x.dtype)(x)\n", + " ## This block can be replaced with a squared_difference layer ##\n", + " diff = Subtract()([x, StopGradient()(mean)]) ##\n", + " sq_diff = Multiply()([diff, diff]) ##\n", + " ## ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ ##\n", + " variance = Mean(axes=axes,dtype=x.dtype ,name=f'{name}/variance')(sq_diff)\n", + " if not trainable:\n", + " inv = RSqrt()(variance)\n", + " x = Multiply()([diff, inv])\n", + " else:\n", + " # MaximumScalar prevents division by 0.\n", + " inv = RSqrt()(MaximumScalar(epsilon)(variance))\n", + " # This layer is removed for inference so it is named.\n", + " x = Subtract(name=f'{name}/grad_subtract')([x, mean]) \n", + " x = Multiply()([x, inv])\n", + "\n", + " x = Scale(axes=axes)(x)\n", + " x = Centre(axes=axes)(x)\n", + " \n", + " return x\n", + "\n", + "def mlp(x, hidden_dim, out_dim=None):\n", + " \"\"\"Multi-layer perceptron block\"\"\"\n", + " if out_dim is None: out_dim = x.shape[-1]\n", + "\n", + " x = tf.keras.layers.Dense(units=hidden_dim)(x)\n", + " x = Relu()(x)\n", + " x = tf.keras.layers.Dense(units=out_dim)(x)\n", + " return x" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "(d) Build end-to-end model" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "scrolled": true + }, + "outputs": [], + "source": [ + "from collections import defaultdict\n", + "def get_translation_model(input_shape, batch_size=batch_size, seq_len=seq_len, vocab_size=vocab_size, embed_dim=embed_dim, num_heads=num_heads, freeze= False, trainable=True):\n", + " \n", + " aux_output=defaultdict(list)\n", + " ## Encoder\n", + " \n", + " # Input to encoder\n", + " enc_inputs = tf.keras.Input(shape=input_shape, batch_size=batch_size, name=\"encoder_inputs\")\n", + " encoder_inputs=Cast()(enc_inputs)\n", + " \n", + " x = AddPositionalEmbeddingForEncoderDecoder(encoder_inputs, seq_len, vocab_size, embed_dim, 'encoder', freeze)\n", + " encoder_padding_mask_inputs = tf.keras.Input(shape=input_shape, batch_size=batch_size, name=\"encoder_masks\")\n", + " encoder_padding_mask = enc_padding_mask(encoder_padding_mask_inputs)\n", + "\n", + " # Encoder Attention block\n", + " attention_output, attention_weights = self_attention(x, x, x, num_heads, embed_dim*num_heads, mask=encoder_padding_mask, name=(f'mha'), block_name=(f'encoder'))\n", + " proj_input = tf.keras.layers.Add()([x, attention_output])\n", + " proj_input = layer_norm(proj_input, name=(f'layer_norm'), trainable=trainable)\n", + "\n", + " # MLP block\n", + " proj_output = mlp(proj_input, latent_dim, embed_dim)\n", + " x = tf.keras.layers.Add()([proj_input, proj_output])\n", + " encoder_outputs = layer_norm(x, name=(f'layer_norm_1'), trainable=trainable)\n", + " \n", + " ## Decoder\n", + " \n", + " # Input to decoder\n", + " dec_inputs = tf.keras.Input(shape=input_shape, batch_size=batch_size, name=\"decoder_inputs\")\n", + " decoder_inputs=Cast()(dec_inputs)\n", + "\n", + " x = AddPositionalEmbeddingForEncoderDecoder(decoder_inputs, seq_len, vocab_size, embed_dim, 'decoder', freeze)\n", + " decoder_causal_mask = causal_mask(decoder_inputs)\n", + " decoder_padding_mask_inputs = tf.keras.Input(shape=input_shape, batch_size=batch_size, name=\"decoder_masks\")\n", + " decoder_padding_mask = dec_padding_mask(decoder_padding_mask_inputs, decoder_causal_mask)\n", + " \n", + " \n", + " # Decoder Attention Block 1\n", + " attention_output_1, attention_weights_1 = self_attention(x, x, x, num_heads, embed_dim*num_heads, mask=decoder_causal_mask, name=(f'mha_1'), block_name=(f'decoder_1'))\n", + " x1 = tf.keras.layers.Add()([x, attention_output_1])\n", + " out_1 = layer_norm(x1, name=(f'layer_norm_2'), trainable=trainable)\n", + " \n", + " # Decoder Attention Block 2\n", + " attention_output_2, attention_weights_2 = self_attention(out_1, encoder_outputs, encoder_outputs, num_heads, embed_dim*num_heads, mask=decoder_padding_mask, name=(f'mha_2'), block_name=(f'decoder_2'))\n", + " x2 = tf.keras.layers.Add()([out_1, attention_output_2])\n", + " out_2 = layer_norm(x2, name=(f'layer_norm_3'), trainable=trainable)\n", + " \n", + " # MLP Block\n", + " proj_output = mlp(out_2, latent_dim, embed_dim)\n", + " x3 = tf.keras.layers.Add()([out_2, proj_output])\n", + " x3 = layer_norm(x3, name=(f'layer_norm_4'), trainable=trainable)\n", + " \n", + "\n", + " x3 = tf.keras.layers.Dropout(0.5)(x3)\n", + " x3 = tf.keras.layers.Dense(units=vocab_size, name=\"dense_last\", activation='softmax')(x3)\n", + "\n", + " transformer = keras.Model(\n", + " [enc_inputs,encoder_padding_mask_inputs, dec_inputs, decoder_padding_mask_inputs], x3, name=\"transformer\"\n", + " )\n", + " \n", + " return transformer\n", + "\n", + "tf.keras.backend.clear_session() # reset layer name counters\n", + "\n", + "transform = get_translation_model(input_shape = (seq_len,), batch_size = batch_size)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "(e) Train the FP32 model" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "transform.summary()\n", + "train(transform, model_type='fp32')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "(f) Evaluate Performance" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Get BLEU score on test set for FP32 transformer model\n", + "get_text_result(transform)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Get accuracy on test set for the FP32 transformer model from Keras example\n", + "evaluate(transform)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 9. Convert FP32 model to FP32 tflite model" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "(a) Generate a non-optimized tflite (float32 operations) file for FP32 model" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "i = tf.keras.Input(shape=(20,), batch_size=1)\n", + "j = tf.keras.Input(shape=(20,), batch_size=1)\n", + "k = tf.keras.Input(shape=(20,), batch_size=1)\n", + "l = tf.keras.Input(shape=(20,), batch_size=1)\n", + "net = tf.keras.Model(inputs=[i, j,k,l,], outputs=transform.call([i,j,k,l]))\n", + "\n", + "MODEL_PATH = './encoder_decoder_fp32.tflite'\n", + "\n", + "converter = tf.lite.TFLiteConverter.from_keras_model(net)\n", + "tflite_model = converter.convert()\n", + "with open(MODEL_PATH, \"wb+\") as tflite_file:\n", + " tflite_file.write(tflite_model)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "(b) Evaluate performance" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "get_tflite_accuracy(MODEL_PATH, input_type='float32', output_type='float32')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "get_text_result_tflite(MODEL_PATH, input_type='float32', output_type='float32')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "print(\"Model size: \", get_gzipped_model_size(MODEL_PATH), ' KB')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 10. Perform QAT on FP32 model with TFMOT" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "(a) To use the custom Keras layers we defined, we need to pass a [`QuantizeConfig`](https://www.tensorflow.org/model_optimization/api_docs/python/tfmot/quantization/keras/QuantizeConfig) for each of these layers.\n", + "\n", + "For Keras layers which are already supported in TFMOT, a default QuantizeConfig class is assigned to each one. However, custom QuantizeConfig instances could also be created for these layers to give more control over how they are quantised." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from tensorflow_model_optimization.quantization.keras import QuantizeConfig, quantizers\n", + "\n", + "LastValueQuantizer = quantizers.LastValueQuantizer\n", + "MovingAverageQuantizer = quantizers.MovingAverageQuantizer\n", + "AllValuesQuantizer = quantizers.AllValuesQuantizer\n", + "\n", + "class NoOpQuantizeConfig(QuantizeConfig):\n", + " \"\"\"QuantizeConfig which does not quantize any part of the layer.\"\"\"\n", + "\n", + " def get_weights_and_quantizers(self, layer):\n", + " return []\n", + "\n", + " def get_activations_and_quantizers(self, layer):\n", + " return []\n", + "\n", + " def set_quantize_weights(self, layer, quantize_weights):\n", + " pass\n", + "\n", + " def set_quantize_activations(self, layer, quantize_activations):\n", + " pass\n", + "\n", + " def get_output_quantizers(self, layer):\n", + " return []\n", + " \n", + " def get_config(self):\n", + " return {}\n", + "\n", + "\n", + "class TFOpQuantizeConfig(QuantizeConfig):\n", + " \"\"\"QuantizeConfig which only quantizes the output of a layer.\"\"\"\n", + "\n", + " def get_weights_and_quantizers(self, layer):\n", + " return []\n", + "\n", + " def get_activations_and_quantizers(self, layer):\n", + " return []\n", + "\n", + " def set_quantize_weights(self, layer, quantize_weights):\n", + " pass\n", + "\n", + " def set_quantize_activations(self, layer, quantize_activations):\n", + " pass\n", + "\n", + " def get_output_quantizers(self, layer):\n", + " return [MovingAverageQuantizer(num_bits=8, per_axis=False, symmetric=False, narrow_range=False)]\n", + "\n", + " def get_config(self):\n", + " return {}\n", + "\n", + "\n", + "class MaskOpQuantizeConfig(QuantizeConfig):\n", + " \"\"\"QuantizeConfig which only quantizes the output of a layer and is meant for the input masks.\"\"\"\n", + "\n", + " def get_weights_and_quantizers(self, layer):\n", + " return []\n", + "\n", + " def get_activations_and_quantizers(self, layer):\n", + " return []\n", + "\n", + " def set_quantize_weights(self, layer, quantize_weights):\n", + " pass\n", + "\n", + " def set_quantize_activations(self, layer, quantize_activations):\n", + " pass\n", + "\n", + " def get_output_quantizers(self, layer):\n", + " return [AllValuesQuantizer(num_bits=8, per_axis=False, symmetric=False, narrow_range=False)]\n", + "\n", + " def get_config(self):\n", + " return {}\n", + "\n", + " \n", + "class VarianceQuantizeConfig(QuantizeConfig):\n", + " \"\"\"QuantizeConfig for the variance calculation in the layer normalisation layer.\"\"\"\n", + "\n", + " def get_weights_and_quantizers(self, layer):\n", + " return []\n", + "\n", + " def get_activations_and_quantizers(self, layer):\n", + " return []\n", + "\n", + " def set_quantize_weights(self, layer, quantize_weights):\n", + " pass\n", + "\n", + " def set_quantize_activations(self, layer, quantize_activations):\n", + " pass\n", + "\n", + " def get_output_quantizers(self, layer):\n", + " return [AllValuesQuantizer(num_bits=8, per_axis=False, symmetric=False, narrow_range=False)]\n", + "\n", + " def get_config(self):\n", + " return {}\n", + " \n", + "\n", + "class WeightQuantizeConfig(QuantizeConfig):\n", + " \"\"\"QuantizeConfig which quantizes the custom weights in the patch encoder and layer normalisation layers.\"\"\"\n", + "\n", + " def __init__(self):\n", + " self.weight_quantizer = LastValueQuantizer(num_bits=8, per_axis=False,\n", + " symmetric=True, narrow_range=True)\n", + " self.activation_quantizer = MovingAverageQuantizer(num_bits=8, per_axis=False,\n", + " symmetric=False, narrow_range=False)\n", + "\n", + " def get_weights_and_quantizers(self, layer):\n", + " return [(layer.w, self.weight_quantizer)]\n", + "\n", + " def get_activations_and_quantizers(self, layer):\n", + " return []\n", + "\n", + " def set_quantize_weights(self, layer, quantize_weights):\n", + " layer.w = quantize_weights[0]\n", + "\n", + " def set_quantize_activations(self, layer, quantize_activations):\n", + " pass\n", + "\n", + " def get_output_quantizers(self, layer):\n", + " return [self.activation_quantizer]\n", + "\n", + " def get_config(self):\n", + " return {}" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "(b) Define wrapper function\n", + "\n", + "Since custom layers and QuantizeConfigs are used, the whole model cannot directly be wrapped with QAT wrappers.\n", + "So first we write a function to wrap the individual layers with QAT wrappers:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def apply_wrapper(wrapper_function, layer_param_dict):\n", + " \n", + " def wrap_layer(layer):\n", + " if layer.name in layer_param_dict.keys():\n", + " return wrapper_function(layer, **layer_param_dict[layer.name])\n", + " return layer\n", + "\n", + " return wrap_layer\n", + "\n", + "def layer_wrapper(model, wrapper_function, layer_param_dict):\n", + " return tf.keras.models.clone_model(model, clone_function=apply_wrapper(wrapper_function, layer_param_dict))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "(c) Assign QuantizeConfigs to custom layers" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def get_quantize_config(model):\n", + " layer_param_dict = {} # stores {Layer_Name: QuantizeConfig} pairs\n", + " scope = {} # stores all custom objects\n", + "\n", + " for layer in model.layers:\n", + " \n", + " if layer.name.startswith(('clip', 'minimum', 'minimum_scalar', 'maximum_scalar', 'cast', 'stop_gradient')):\n", + " layer_param_dict[layer.name] = {'quantize_config': NoOpQuantizeConfig()}\n", + " scope[layer.__class__.__name__] = layer.__class__\n", + " \n", + " elif 'grad_subtract' in layer.name or layer.name.startswith(('mat_mul', 'multiply', 'scalar_multiply', 'add',\n", + " 'scalar_add', 'slice', 'mean', 'subtract',\n", + " 'scalar_subtract', 'r_sqrt', 'relu')):\n", + " layer_param_dict[layer.name] = {'quantize_config': TFOpQuantizeConfig()}\n", + " scope[layer.__class__.__name__] = layer.__class__\n", + " \n", + " elif layer.name.startswith(( 'scale', 'centre', 'positional_embedding', 'token_embedding' )):\n", + " layer_param_dict[layer.name] = {'quantize_config': WeightQuantizeConfig()}\n", + " scope[layer.__class__.__name__] = layer.__class__\n", + "\n", + " # Make sure to quantise the encoder and decoder mask input layers so that they can be quantized to INT8\n", + " \n", + " elif layer.name.startswith(('encoder_masks', 'decoder_masks' )):\n", + " layer_param_dict[layer.name] = {'quantize_config': MaskOpQuantizeConfig()}\n", + "\n", + " elif 'variance' in layer.name:\n", + " layer_param_dict[layer.name] = {'quantize_config': VarianceQuantizeConfig()}\n", + " scope[layer.__class__.__name__] = layer.__class__\n", + " \n", + " scope['NoOpQuantizeConfig'] = NoOpQuantizeConfig\n", + " scope['TFOpQuantizeConfig'] = TFOpQuantizeConfig\n", + " scope['WeightQuantizeConfig'] = WeightQuantizeConfig\n", + " scope['VarianceQuantizeConfig'] = VarianceQuantizeConfig\n", + " scope['MaskOpQuantizeConfig'] = MaskOpQuantizeConfig\n", + "\n", + " return layer_param_dict, scope\n", + "\n", + "layer_param_dict, scope = get_quantize_config(transform)\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Few layers like `cast`, `encoder_inputs` and `decoder_inputs` musn't be annontated with any QuantizeConfig as this will result into a `quantize` node being added after the inputs in tflite graph, which would pass down an int8 value to the `tfl.gather` operation.
And since, the `tfl.gather` operation expects only int32 and int64 as the indices, an int8 value in the `tfl.gather` operation will result into error ([Please refer TF Lite Ops Page](https://www.tensorflow.org/mlir/tfl_ops#tflgather_mlirtflgatherop))." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def remove_unwanted_layers(model, layer_param_dict):\n", + " # All the layers that don't need quantization can be added along side 'cast', 'encoder_inputs' and 'decoder_inputs'\n", + " layers_to_not_quantise = [x.name for x in model.layers if not any([y in x.name for y in ['cast', 'encoder_inputs', 'decoder_inputs'\n", + " ]])]\n", + " layer_param_dict = {k: v for k, v in layer_param_dict.items() if k in layers_to_not_quantise}\n", + " for k in layers_to_not_quantise:\n", + " if k not in layer_param_dict:\n", + " layer_param_dict[k] = {'quantize_config': None}\n", + "\n", + " return layer_param_dict\n", + "\n", + "layer_param_dict = remove_unwanted_layers(transform, layer_param_dict) " + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "(d) Load the necessary API classes/functions" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "quantize_annotate_layer = tfmot.quantization.keras.quantize_annotate_layer\n", + "quantize_annotate_model = tfmot.quantization.keras.quantize_annotate_model\n", + "quantize_apply = tfmot.quantization.keras.quantize_apply\n", + "quantize_scope = tfmot.quantization.keras.quantize_scope" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "(e) Annotate individual layers\n", + "\n", + "When calling the quantize_apply function, if an unsupported layer is missing from the scope, TFMOT will throw an error." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Wrap each custom layer with the corresponding QuantizeConfig:\n", + "\n", + "qat_model = layer_wrapper(transform, quantize_annotate_layer, layer_param_dict)\n", + "\n", + "with quantize_scope(scope):\n", + " qat_model = quantize_apply(qat_model)\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "(f) Perform QAT" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "qat_model.summary()\n", + "train(qat_model, model_type='qat', epochs=3)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "(g) Evaluate Performance" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "get_text_result(qat_model)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "evaluate(qat_model)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 11. Create INT8 tflite file for QAT FP32 model" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "If we attempt to directly generate a TFLite file using the fine-tuned model above:\n", + "\n", + "- It will not have a correct batch size of 1.\n", + "- It will have operators which are unnecessary during inference. Precisely, the extra `Subtract` operators and `MaximumScalar` operator in the layer normalisation blocks, which were used during training and fine-tuning, should be removed from the graph before creating the TFLite file.\n", + "\n", + "Therefore the network should be redefined with a batch size of 1 and with the redundant operators removed. The weights of the fine-tuned optimised model can then be loaded into this new model." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "(a) Remove layers which are not required" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "tf.keras.backend.clear_session() # reset layer name counters\n", + "\n", + "new_qat_model = get_translation_model(input_shape = (seq_len,), batch_size = batch_size, trainable = False)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "(b) Annotate individual layers" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Get the QuantizeConfig and Scope which would be used to annotate the layers\n", + "layer_param_dict, scope = get_quantize_config(new_qat_model)\n", + "# Remove unwanted QuantizeConfigs\n", + "layer_param_dict = remove_unwanted_layers(new_qat_model, layer_param_dict) \n", + "\n", + "new_qat_model = layer_wrapper(new_qat_model, quantize_annotate_layer, layer_param_dict)\n", + "\n", + "with quantize_scope(scope):\n", + " new_qat_model = quantize_apply(new_qat_model)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "(c) Load weights into the model" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "new_qat_model.load_weights('./eng_spa_transformer_qat_tutorial_qat_model.h5', by_name=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Sanity check to see if weights are loaded correctly\n", + "evaluate(new_qat_model)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "(d) Create tflite file (int8 ops)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "i = tf.keras.Input(shape=(20,), batch_size=1, dtype = tf.int32)\n", + "j = tf.keras.Input(shape=(20,), batch_size=1)\n", + "k = tf.keras.Input(shape=(20,), batch_size=1, dtype = tf.int32)\n", + "l = tf.keras.Input(shape=(20,), batch_size=1)\n", + "\n", + "# The following is done to ensure that the batch size of input in\n", + "# tflite graph is 1\n", + "net = tf.keras.Model(inputs=[i, j,k,l,], outputs=new_qat_model.call([i,j,k,l]))\n", + "\n", + "MODEL_PATH = './encoder_decoder_qat.tflite'\n", + "\n", + "converter = tf.lite.TFLiteConverter.from_keras_model(net)\n", + "converter.optimizations = [tf.lite.Optimize.DEFAULT]\n", + "\n", + "# The following two lines ensure that the mask inputs\n", + "# and the output are int8\n", + "converter.inference_input_type = tf.int8\n", + "converter.inference_output_type = tf.int8\n", + "\n", + "# Toggle this option to fold/unfold batchmatmul\n", + "converter._experimental_disable_batchmatmul_unfold = True\n", + "\n", + "tflite_model = converter.convert()\n", + "with open(MODEL_PATH, \"wb+\") as tflite_file:\n", + " tflite_file.write(tflite_model)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "(e) Evaluate Performance\n", + "\n", + "NOTE: These steps are slow to execute therefore, the number of samples on which evaluation is performed is set to 200 by default (but definitely can be modified by the user)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "get_tflite_accuracy(MODEL_PATH)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "get_text_result_tflite(MODEL_PATH)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "print(\"Model size: \", get_gzipped_model_size(MODEL_PATH), ' KB')" + ] + } + ], + "metadata": { + "interpreter": { + "hash": "cb8a88e82453314166ba9bb471422eb5142b42682e25f5c554fbcb3447334d71" + }, + "kernelspec": { + "display_name": "Python 3.6.9 64-bit ('venv': venv)", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.6.9" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/tutorials/transformer_tutorials/translation_PQAT.ipynb b/tutorials/transformer_tutorials/translation_PQAT.ipynb new file mode 100644 index 0000000..8c6a472 --- /dev/null +++ b/tutorials/transformer_tutorials/translation_PQAT.ipynb @@ -0,0 +1,2397 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Sequence to Sequence Transformer model optimisation using TFMOT (Pruning preserving Quantization Aware Training)\n", + "\n", + "Example notebook to demonstrate how TFMOT can be used for optimising complex sequence to sequence transformer models" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Background\n", + "\n", + "The sequence to sequence transformer is one of the initial transformer model architectures. The core idea behind the Transformer model is self-attention—the ability to attend to different positions of the input sequence to compute a representation of that sequence. The paper called [\"Attention Is All You Need\"](https://arxiv.org/pdf/1706.03762.pdf) might give a deeper insight into transformer model and their self-attention mechanism.\n", + "\n", + "\"Sequence\n", + "\n", + "[1] The above image was taken from [here](https://deepfrench.gitlab.io/deep-learning-project/)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### In this notebook:\n", + "\n", + "* The aim of this tutorial is to first train the Transformer model from [Keras tutorial](https://keras.io/examples/nlp/neural_machine_translation_with_transformer/)\n", + "* Re-write the above model as a funtional FP32 model\n", + "* Perform Structured 2X4 pruning on the FP32 model\n", + "* Perform Quantised Aware Training (QAT) on the pruned FP32 model\n", + "* Create and test the tflite model generated from the FP32 model after performing PQAT on it. \n", + "\n", + "Note: This tutorial has re-used some code and explanation from the original [Keras tutorial](https://keras.io/examples/nlp/neural_machine_translation_with_transformer/)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### TFMOT limitations\n", + "- Subclassed models are not supported. Only sequential and functional model definitions are supported. (Pruning, Clustering & QAT)\n", + "- Custom subclassed layers are not supported. (Clustering & QAT)\n", + " - Clustering will only work with subclassed layers if the weight variables you have to cluster are not nested within another layer (e.g. MHA).\n", + " - QAT works correctly if the subclassed layer performs only 1 operation.\n", + "- Low-level tensorflow operators such as `tf.linalg.matmul` are not supported. (Only for QAT)\n", + " - QAT expects all quantised layers to be a subclass of `tf.keras.layers.Layer`." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 1. Setup" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import pathlib\n", + "import random\n", + "import string\n", + "import tempfile\n", + "import zipfile\n", + "import re\n", + "import os\n", + "import nltk\n", + "import numpy as np\n", + "import tensorflow as tf\n", + "from tensorflow import keras\n", + "from tensorflow.keras import layers\n", + "from tensorflow.keras.layers import TextVectorization\n", + "import tensorflow_model_optimization as tfmot\n", + "from collections import defaultdict\n", + "\n", + "def reset_random_seeds():\n", + " os.environ['PYTHONHASHSEED']=str(2)\n", + " tf.random.set_seed(2)\n", + " np.random.seed(2)\n", + " random.seed(2)\n", + "\n", + "reset_random_seeds()\n", + "\n", + "print('TensorFlow version: {}'.format(tf.__version__))\n", + "print('TFMOT version: {}'.format(tfmot.__version__))\n", + "print(\"NLTK verison: {}\".format(nltk.__version__))\n", + "print(\"Numpy version: {}\".format(np.__version__))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 2. Downloading the data\n", + "\n", + "The dataset used here is English to Spanish translation dataset" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "text_file = keras.utils.get_file(\n", + " fname=\"spa-eng.zip\",\n", + " origin=\"http://storage.googleapis.com/download.tensorflow.org/data/spa-eng.zip\",\n", + " extract=True,\n", + ")\n", + "text_file = pathlib.Path(text_file).parent / \"spa-eng\" / \"spa.txt\"" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 3. Parsing the data\n", + "\n", + "Each target sentence (which is in Spanish) has `[start]` and `[end]` token prepended and appended, respectively, at this stage." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "with open(text_file) as f:\n", + " lines = f.read().split(\"\\n\")[:-1]\n", + "text_pairs = []\n", + "for line in lines:\n", + " eng, spa = line.split(\"\\t\")\n", + " spa = \"[start] \" + spa + \" [end]\"\n", + " text_pairs.append((eng, spa))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "for _ in range(5):\n", + " print(random.choice(text_pairs))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Split the dataset into train, test and validation set" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "random.shuffle(text_pairs)\n", + "num_val_samples = int(0.15 * len(text_pairs))\n", + "num_train_samples = len(text_pairs) - 2 * num_val_samples\n", + "train_pairs = text_pairs[:num_train_samples]\n", + "val_pairs = text_pairs[num_train_samples : num_train_samples + num_val_samples]\n", + "test_pairs = text_pairs[num_train_samples + num_val_samples :]\n", + "\n", + "print(f\"{len(text_pairs)} total pairs\")\n", + "print(f\"{len(train_pairs)} training pairs\")\n", + "print(f\"{len(val_pairs)} validation pairs\")\n", + "print(f\"{len(test_pairs)} test pairs\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 4. Vectorizing the text data" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Vectorization refers to the preprocessing step where text features are mapped to integer sequences where each integer represents the index of a word in a vocubulary. For this, [`tf.keras.layers.TextVecorization`](https://www.tensorflow.org/api_docs/python/tf/keras/layers/TextVectorization) layer is used.\n", + "\n", + "In our case, vectorization for English sequences is a little different from that of Spanish sequences:\n", + "\n", + "- For English string sequences, default standardization is used which strips all punctuation characters\n", + "- For Spanish string sequences, custom standardization is used which strips all characters which are not in `{` a-z.?!,¿[]`}`" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "vocab_size = 15000\n", + "seq_len = 20\n", + "batch_size = 64\n", + "embed_dim = 256\n", + "latent_dim = 2048\n", + "num_heads = 8\n", + "\n", + "\n", + "def custom_standardization(input_string):\n", + " lowercase = tf.strings.lower(input_string)\n", + " # The following regex replaces a character with \"\"\n", + " # which is not one of the following:\n", + " # 1. Lower case alphabet\n", + " # 2. Space\n", + " # 3. Is on of these characters: \".\", \"?\", \"!\", \",\", \"¿\", \"[\", \"]\"\n", + " return tf.strings.regex_replace(lowercase, '[^ a-z.?!,¿\\[\\]]', \"\")\n", + "\n", + "\n", + "eng_vectorization = TextVectorization(\n", + " max_tokens=vocab_size, output_mode=\"int\", output_sequence_length=seq_len,\n", + ")\n", + "spa_vectorization = TextVectorization(\n", + " max_tokens=vocab_size,\n", + " output_mode=\"int\",\n", + " output_sequence_length=seq_len + 1,\n", + " standardize=custom_standardization,\n", + ")\n", + "\n", + "train_eng_texts = [pair[0] for pair in train_pairs]\n", + "train_spa_texts = [pair[1] for pair in train_pairs]\n", + "eng_vectorization.adapt(train_eng_texts)\n", + "spa_vectorization.adapt(train_spa_texts)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "At each training step, the model will seek to predict target words N+1 (and beyond) using the source (or the input) sentence and the target words 0 to N. For this reason, we need (`inputs`, `targets`)\n", + "\n", + "- `inputs`:\n", + "\n", + " After vectorization, our dataset is formatted to include the following four in the `inputs` (`inputs` is essentially a list of four inputs):\n", + "\n", + " * encoder_inputs : which contains the vectorized english sentence data\n", + " * decoder_inputs : which contains the vectorized spanish (target) sentence data, i.e. target_sentence[:, :-1]. It is also the target sentence \"so far\", that is to say, the words 0 to N used to predict word N+1 (and beyond) in the target sentence. \n", + " * encoder_masks : which contains the corresponding mask data for encoder_inputs\n", + " * decoder_masks : which contains the corresponding mask data for decoder_inputs\n", + "\n", + " Please note that the two mask inputs are only required for the custom FP32 functional model as the original keras model is able to generate it's own mask. Therefore, the original model tends to ignore the two mask inputs (user doesn't need to worry about this).

\n", + " \n", + "- `targets`:\n", + "\n", + " After vectorization, our dataset is formatted to assign the target sentence offset by one (i.e. target_sentence[:, 1:]) as the `targets`. In other words this is what model will try to predict." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def format_dataset(eng, spa):\n", + " eng = eng_vectorization(eng)\n", + " spa = spa_vectorization(spa)\n", + "\n", + " # Create input masks\n", + " encoder_masks=tf.cast(tf.not_equal(np.int64(0),eng),tf.float32)\n", + " decoder_masks=tf.cast(tf.not_equal(np.int64(0),spa[:, :-1]),tf.float32)\n", + " \n", + " return ({\"encoder_inputs\": eng, \"encoder_masks\": encoder_masks, \"decoder_inputs\": spa[:, :-1], \"decoder_masks\":decoder_masks}, spa[:, 1:])\n", + "\n", + "def make_dataset(pairs, batch_size=batch_size):\n", + " eng_texts, spa_texts = zip(*pairs)\n", + " eng_texts = list(eng_texts)\n", + " spa_texts = list(spa_texts)\n", + " dataset = tf.data.Dataset.from_tensor_slices((eng_texts, spa_texts))\n", + " dataset = dataset.batch(batch_size, drop_remainder=True)\n", + " dataset = dataset.map(format_dataset)\n", + " \n", + " return dataset.shuffle(2048).prefetch(16).cache()\n", + "\n", + "\n", + "train_ds = make_dataset(train_pairs)\n", + "val_ds = make_dataset(val_pairs)\n", + "test_ds = make_dataset(test_pairs)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "for inputs, targets in train_ds.take(1):\n", + " print(f'inputs[\"encoder_inputs\"].shape: {inputs[\"encoder_inputs\"].shape}')\n", + " print(f'inputs[\"decoder_inputs\"].shape: {inputs[\"decoder_inputs\"].shape}')\n", + " print(f'inputs[\"encoder_masks\"].shape: {inputs[\"encoder_masks\"].shape}')\n", + " print(f'inputs[\"decoder_masks\"].shape: {inputs[\"decoder_masks\"].shape}')\n", + " print(f\"targets.shape: {targets.shape}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 5. Utility functions" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Typically, BLEU score is used to measure the quality of a translation." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def bleu_score(real_text, predicted_text):\n", + " '''Get BLEU score'''\n", + " return (nltk.translate.bleu_score.corpus_bleu(real_text,predicted_text))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "For decoding (or in other words translating a source sentence to a target sentenceg), we provide a vectorized source sentence as `encoder_inputs` and a vecotrized `[start]` token (ofcourse, padded to match the right sequence length) as the `decoder_inputs`, then we repeatedly generated the next token, until we hit the token `[end]`.\n", + "\n", + "A key thing to note is that in the custom FP32 functional model used in this notebook `encoder_masks` and `decoder_masks` are also fed into the model." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def get_text_result(model, num_samples_to_eval =200, no_input_masks=False):\n", + " '''\n", + " Function to calculate BLEU score on test set\n", + "\n", + " num_samples_to_eval: Represents the total number of test sentences to\n", + " consider during evaluation. If you want the entire \n", + " test set to be used for evaluation then set \n", + " max_sample = -1\n", + " \n", + " no_input_masks: Set as True for the original transformer model from\n", + " keras example\n", + " '''\n", + "\n", + " spa_vocab = spa_vectorization.get_vocabulary()\n", + " spa_index_lookup = dict(zip(range(len(spa_vocab)), spa_vocab))\n", + " max_decoded_sentence_length = 20\n", + "\n", + " def decode_sequence_func(input_sentence):\n", + "\n", + " tokenized_input_sentence = eng_vectorization([input_sentence])\n", + " encoder_mask = tf.cast(tf.not_equal(np.int64(0),tokenized_input_sentence), tf.float32)\n", + "\n", + " decoded_sentence = \"[start]\"\n", + " for i in range(max_decoded_sentence_length):\n", + " tokenized_target_sentence = spa_vectorization([decoded_sentence])[:, :-1]\n", + " decoder_mask=tf.cast(tf.not_equal(np.int64(0),tokenized_target_sentence), tf.float32)\n", + " if no_input_masks:\n", + " predictions = model([tokenized_input_sentence, tokenized_target_sentence])\n", + " else:\n", + " predictions = model([tokenized_input_sentence, encoder_mask, tokenized_target_sentence,decoder_mask])\n", + " sampled_token_index = np.argmax(predictions[0, i, :])\n", + " sampled_token = spa_index_lookup[sampled_token_index]\n", + " decoded_sentence += \" \" + sampled_token\n", + "\n", + " if sampled_token == \"[end]\":\n", + " break\n", + "\n", + " return decoded_sentence\n", + "\n", + "\n", + " hypothesis= []\n", + " references = []\n", + " test_sample_count = sum(1 for e in test_pairs) \n", + " progbar = tf.keras.utils.Progbar(test_sample_count if num_samples_to_eval == -1 else num_samples_to_eval)\n", + "\n", + " for step, (inp, target) in enumerate(test_pairs[:num_samples_to_eval]):\n", + " translated = decode_sequence_func(inp)\n", + " target=target.lower()\n", + " target=re.sub('[^ a-z.?!,¿\\[\\]]', \"\",target)\n", + " hypothesis.append(translated.split()[1:-1])\n", + " references.append([target.split()[1:-1]])\n", + " progbar.update(step + 1)\n", + "\n", + " print(str(\"Bleu Score: \") + str(bleu_score(references[:], hypothesis[:])))\n", + "\n", + " # Print first 10 actual and predicted spanish translation for sanity check\n", + " for i in range(10):\n", + " print(references[i][0])\n", + " print(hypothesis[i])\n", + " print(\"-----------------------/n\")\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Suggestion: While trying to run inference on a tflite file please make sure that the scale, zero_point and data type are correct for the inputs and outputs" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def get_text_result_tflite(model_path, input_type = 'int8/32', output_type = 'int8', num_samples_to_eval = 200):\n", + " '''\n", + " Function to calculate BLEU score for a given tflite file on the test set\n", + "\n", + " model_path: Path to the tflite file\n", + "\n", + " input_type: Could be float32 or int8/32. If the inputs in tflite graph\n", + " are float32 set this value to 'float32' but if inputs are\n", + " int8 (mask inputs) and int32 (non-maks inputs) set this\n", + " value to 'int8/32'.\n", + "\n", + " output_type: Could be float32 or int8. If the outputs in tflite graph\n", + " are float32 set this value to 'float32' but if output\n", + " are int8 set this value to 'int8'.\n", + " \n", + " num_samples_to_eval: Evaluation of entire test set will take a lot\n", + " of time therefore, only first 200 samples are \n", + " evaluated. To evaluate the entire test-set, \n", + " set the value below to a negative value, e.g.\n", + " -1\n", + " '''\n", + " assert(input_type in ['float32', 'int8/32']), \"input_type not supported\"\n", + " assert(output_type in ['float32', 'int8']), \"output_type not supported\"\n", + "\n", + " print('Performing BLEU evaluation for tflite file at {}'.format(model_path))\n", + "\n", + " spa_vocab = spa_vectorization.get_vocabulary()\n", + " spa_index_lookup = dict(zip(range(len(spa_vocab)), spa_vocab))\n", + " max_decoded_sentence_length = seq_len\n", + "\n", + " interpreter = tf.lite.Interpreter(model_path=model_path)\n", + "\n", + " input_details = interpreter.get_input_details()\n", + " output_details = interpreter.get_output_details()\n", + "\n", + " input_scale_1, input_zero_point_1 = input_details[0]['quantization']\n", + " input_scale_2, input_zero_point_2 = input_details[1]['quantization']\n", + " input_scale_3, input_zero_point_3 = input_details[2]['quantization']\n", + " input_scale_4, input_zero_point_4 = input_details[3]['quantization']\n", + " output_scale, output_zero_point = output_details[0]['quantization']\n", + "\n", + " interpreter.allocate_tensors()\n", + "\n", + " def decode_sequence_func(input_sentence):\n", + "\n", + " input_1 = eng_vectorization([input_sentence])\n", + " input_2 = tf.cast(tf.not_equal(np.int64(0),input_1), tf.float32)\n", + " if input_type == 'int8/32':\n", + " input_2 = input_2/ input_scale_2 + input_zero_point_2\n", + "\n", + " decoded_sentence = \"[start]\"\n", + "\n", + " for i in range(max_decoded_sentence_length):\n", + " input_3 = spa_vectorization([decoded_sentence])[:, :-1]\n", + " input_4=tf.cast(tf.not_equal(np.int64(0),input_3), tf.float32)\n", + "\n", + " # Set input tensor\n", + " interpreter.set_tensor(input_details[0]['index'], tf.cast(input_1, input_details[0]['dtype']))\n", + "\n", + " # Set input tensor\n", + " interpreter.set_tensor(input_details[1]['index'], tf.cast(input_2, input_details[1]['dtype']))\n", + "\n", + " # Set input tensor\n", + " interpreter.set_tensor(input_details[2]['index'], tf.cast(input_3, input_details[2]['dtype']))\n", + "\n", + " # Set input tensor\n", + " if input_type == 'int8/32':\n", + " input_4 = input_4/ input_scale_4 + input_zero_point_4\n", + " interpreter.set_tensor(input_details[3]['index'], tf.cast(input_4, input_details[3]['dtype']))\n", + "\n", + " interpreter.invoke()\n", + " \n", + " # Get output tensor\n", + " output_data = interpreter.get_tensor(output_details[0]['index'])\n", + " predictions = output_data.astype(np.float32)\n", + " if output_type == 'int8':\n", + " predictions = output_scale * (predictions - output_zero_point)\n", + " \n", + " sampled_token_index = np.argmax(predictions[0, i, :])\n", + " sampled_token = spa_index_lookup[sampled_token_index]\n", + " decoded_sentence += \" \" + sampled_token\n", + "\n", + " if sampled_token == \"[end]\":\n", + " break\n", + "\n", + " return decoded_sentence\n", + "\n", + "\n", + " hypothesis= []\n", + " references = []\n", + " test_sample_count = sum(1 for e in test_pairs) \n", + " progbar = tf.keras.utils.Progbar(test_sample_count if num_samples_to_eval == -1 else num_samples_to_eval)\n", + " \n", + " for step, (inp, target) in enumerate(test_pairs[:num_samples_to_eval]):\n", + " translated = decode_sequence_func(inp)\n", + " target=target.lower()\n", + " target=re.sub('[^ a-z.?!,¿\\[\\]]', \"\",target)\n", + " hypothesis.append(translated.split()[1:-1])\n", + " references.append([target.split()[1:-1]])\n", + " progbar.update(step + 1)\n", + "\n", + " print(str(\"Bleu Score: \") + str(bleu_score(references[:], hypothesis[:])))\n", + " for i in range(10):\n", + " print(references[i][0])\n", + " print(hypothesis[i])\n", + " print(\"-----------------------/n\")\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def get_tflite_accuracy(model_path, input_type = 'int8/32', output_type = 'int8', num_samples_to_eval = 200):\n", + " '''\n", + " Function to calculate accuracy of a given tflite file on the test set\n", + "\n", + " model_path: Path to the tflite file\n", + "\n", + " input_type: Could be float32 or int8/32. If the inputs in tflite graph\n", + " are float32 set this value to 'float32' but if inputs are\n", + " int8 (mask inputs) and int32 (non-maks inputs) set this\n", + " value to 'int8/64'.\n", + "\n", + " output_type: Could be float32 or int8. If the outputs in tflite graph\n", + " are float32 set this value to 'float32' but if output\n", + " are int8 set this value to 'int8'.\n", + " \n", + " num_samples_to_eval: Evaluation of entire test set will take a lot\n", + " of time therefore, only first 200 samples are \n", + " evaluated. To evaluate the entire test-set, \n", + " set the value below to a negative value, e.g.\n", + " -1\n", + " '''\n", + " assert(input_type in ['float32', 'int8/32']), \"input_type not supported\"\n", + " assert(output_type in ['float32', 'int8']), \"output_type not supported\"\n", + "\n", + " print('Performing accuracy evaluation for tflite file at {}'.format(model_path))\n", + "\n", + " interpreter = tf.lite.Interpreter(model_path=model_path)\n", + "\n", + " input_details = interpreter.get_input_details()\n", + " output_details = interpreter.get_output_details()\n", + "\n", + " input_scale_1, input_zero_point_1 = input_details[0]['quantization']\n", + " input_scale_2, input_zero_point_2 = input_details[1]['quantization']\n", + " input_scale_3, input_zero_point_3 = input_details[2]['quantization']\n", + " input_scale_4, input_zero_point_4 = input_details[3]['quantization']\n", + " output_scale, output_zero_point = output_details[0]['quantization']\n", + " interpreter.allocate_tensors()\n", + "\n", + " test_ds_tflite = make_dataset(test_pairs, 1)\n", + " accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='accuracy')\n", + " progbar = tf.keras.utils.Progbar(sum(1 for e in test_ds_tflite) if num_samples_to_eval == -1 else num_samples_to_eval, stateful_metrics=['accuracy'])\n", + "\n", + " for step, (input, target) in enumerate(test_ds_tflite):\n", + "\n", + " # Set input tensor\n", + " input_1 = input['encoder_inputs']\n", + " interpreter.set_tensor(input_details[0]['index'], tf.cast(input_1, input_details[0]['dtype']))\n", + "\n", + " # Set input tensorprogress bars for loopp python\n", + " input_2=input['encoder_masks']\n", + " if input_type == 'int8/32':\n", + " input_2 = tf.cast(input_2, tf.float32)\n", + " input_2 = input_2/ input_scale_2 + input_zero_point_2\n", + " interpreter.set_tensor(input_details[1]['index'], tf.cast(input_2, input_details[1]['dtype']))\n", + "\n", + " # Set input tensor\n", + " input_3 = input['decoder_inputs']\n", + " interpreter.set_tensor(input_details[2]['index'], tf.cast(input_3, input_details[2]['dtype']))\n", + "\n", + " # Set input tensor\n", + " input_4=input['decoder_masks']\n", + " if input_type == 'int8/32':\n", + " input_4 = tf.cast(input_4, tf.float32)\n", + " input_4 = input_4/ input_scale_4 + input_zero_point_4\n", + " interpreter.set_tensor(input_details[3]['index'], tf.cast(input_4, input_details[3]['dtype']))\n", + " interpreter.invoke()\n", + " \n", + " # Get output tensor\n", + " output_data = interpreter.get_tensor(output_details[0]['index'])\n", + " output_data = output_data.astype(np.float32)\n", + " if output_type == 'int8':\n", + " output_data = output_scale * (output_data - output_zero_point)\n", + " \n", + " # Update accuracy\n", + " mask = input['decoder_inputs']\n", + " accuracy.update_state(target, output_data, mask)\n", + " progbar.update(step + 1, values=[('accuracy', accuracy.result().numpy())])\n", + " \n", + " if step == num_samples_to_eval:\n", + " break" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Use the following function to get the size of the tflite file when zipped" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def get_gzipped_model_size(file):\n", + " '''Returns the size of a gzipped tflite file in kilobytes'''\n", + "\n", + " _, zipped_file = tempfile.mkstemp('.zip')\n", + " with zipfile.ZipFile(zipped_file, 'w', compression=zipfile.ZIP_DEFLATED) as f:\n", + " f.write(file)\n", + "\n", + " return os.path.getsize(zipped_file)/1000" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Use the following function to get the model sparsity" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def print_model_weights_sparsity(model):\n", + "\n", + " for layer in model.layers:\n", + " if isinstance(layer, tf.keras.layers.Wrapper):\n", + " weights = layer.trainable_weights\n", + " else:\n", + " weights = layer.weights\n", + " for weight in weights:\n", + " # Ignore auxiliary quantization weights\n", + " if \"quantize_layer\" in weight.name:\n", + " continue\n", + " weight_size = weight.numpy().size\n", + " zero_num = np.count_nonzero(weight == 0)\n", + " if zero_num/weight_size > 0:\n", + " print(\n", + " f\"{weight.name}: {zero_num/weight_size:.2%} sparsity \",\n", + " f\"({zero_num}/{weight_size})\",\n", + " )" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 6. Functions related to Training the model" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The loss function used is Masked Sparse Categorical Crossentropy loss (which uses the `tf.keras.losses.SparseCategoricalCrossentropy` but with masks). The loss function needs masks to be propogated correctly through the model layers down to the loss function which, the custom FP32 model wasn't able to do correctly therefore, a custom training loop was needed to calculate the loss correctly." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "epochs = 11\n", + "\n", + "def evaluate(model_to_eval, training=False):\n", + "\n", + " val_loss = tf.keras.metrics.SparseCategoricalCrossentropy()\n", + " val_acc = tf.keras.metrics.SparseCategoricalAccuracy()\n", + "\n", + " @tf.function\n", + " def eval_step(inp, y_true):\n", + " preds = model_to_eval(inp, training=training)\n", + " # masked loss\n", + " val_loss.update_state(y_true, preds,tf.cast(tf.not_equal(np.int64(0),inp['decoder_inputs']),tf.float32))\n", + " # masked accuracy\n", + " val_acc.update_state(y_true, preds,tf.cast(tf.not_equal(np.int64(0),inp['decoder_inputs']),tf.float32)) \n", + "\n", + " for step, (inp, y_true) in enumerate(val_ds):\n", + " eval_step(inp, y_true)\n", + "\n", + " return {'loss': val_loss.result().numpy(), 'accuracy': val_acc.result().numpy()}\n", + "\n", + "\n", + "def train(model_to_train, save_best_weights =True, model_type='original', lr=1e-3, epochs = epochs):\n", + "\n", + " if model_type == 'original':\n", + " ckpt_path = './eng_spa_transformer_pqat_tutorial_original_model.h5'\n", + " elif model_type == 'fp32':\n", + " ckpt_path = './eng_spa_transformer_pqat_tutorial_fp32_model.h5'\n", + " elif model_type == 'qat':\n", + " ckpt_path = './eng_spa_transformer_pqat_tutorial_qat_model.h5'\n", + " elif model_type == 'prune':\n", + " ckpt_path = './eng_spa_transformer_pqat_tutorial_pruned_fp32.h5'\n", + " else:\n", + " print('Please select the right model_type!!!!')\n", + " return None\n", + " \n", + " print('Training (save_best_weights={}, model_type={})'.format(save_best_weights, model_type))\n", + "\n", + " loss_fn = tf.keras.losses.SparseCategoricalCrossentropy()\n", + " optimiser = tf.keras.optimizers.Adam(learning_rate=lr)\n", + " train_acc = tf.keras.metrics.SparseCategoricalAccuracy()\n", + " model_to_train.optimizer = optimiser\n", + "\n", + " @tf.function\n", + " def train_step(inp, y_true):\n", + " mask =tf.cast(tf.not_equal(np.int64(0),inp['decoder_inputs']),tf.float32)\n", + " preds=None\n", + " loss=None\n", + " \n", + " with tf.GradientTape() as tape:\n", + " preds = model_to_train(inp, training=True)\n", + " # masked loss\n", + " loss = loss_fn(y_true, preds, mask)\n", + " grads = tape.gradient(loss, model_to_train.trainable_weights)\n", + " optimiser.apply_gradients(zip(grads, model_to_train.trainable_weights))\n", + "\n", + " # masked accuracy \n", + " train_acc.update_state(y_true, preds, mask)\n", + " return loss\n", + "\n", + " max_val = float('-inf')\n", + " step_callback = None\n", + "\n", + " if model_type == 'prune':\n", + " # This callback is only necessary when trying to perform\n", + " # pruning\n", + " step_callback = tfmot.sparsity.keras.UpdatePruningStep()\n", + " step_callback.set_model(model_to_train)\n", + " step_callback.on_train_begin()\n", + "\n", + " for epoch in range(epochs):\n", + " print('Epoch {}/{}'.format(epoch + 1, epochs), flush=True)\n", + " # Train\n", + " progbar = tf.keras.utils.Progbar(len(train_ds), interval=.5,\n", + " stateful_metrics=['acc']) \n", + "\n", + " for step, (inp, y_true) in enumerate(train_ds):\n", + " loss = train_step(inp, y_true)\n", + " progbar.update(step + 1, values=[('loss', loss),\n", + " ('acc', train_acc.result())])\n", + " if model_type == 'prune':\n", + " step_callback.on_epoch_end(batch=-1)\n", + "\n", + " # Evaluate\n", + " val_results = evaluate(model_to_train)\n", + "\n", + " validation_accuracy = val_results['accuracy']\n", + " print('Validation accuracy: {}'.format(validation_accuracy))\n", + "\n", + " if save_best_weights and validation_accuracy > max_val:\n", + " \n", + " print('Best validation accuracy so far, saving weights')\n", + " model_to_train.save_weights(ckpt_path)\n", + " max_val = validation_accuracy\n", + "\n", + " train_acc.reset_states() \n", + "\n", + " if not save_best_weights:\n", + " model_to_train.save_weights(ckpt_path)\n", + " # Load weights\n", + " model_to_train.load_weights(ckpt_path)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 7. Building the original Transformer Keras model mentioned in the [Keras tutorial](https://keras.io/examples/nlp/neural_machine_translation_with_transformer/)\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "(a) Define the custom layers for the model" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "class TransformerEncoder(layers.Layer):\n", + " def __init__(self, embed_dim, dense_dim, num_heads, **kwargs):\n", + " super(TransformerEncoder, self).__init__(**kwargs)\n", + " self.embed_dim = embed_dim\n", + " self.dense_dim = dense_dim\n", + " self.num_heads = num_heads\n", + " self.attention = layers.MultiHeadAttention(\n", + " num_heads=num_heads, key_dim=embed_dim\n", + " )\n", + " self.dense_proj = keras.Sequential(\n", + " [layers.Dense(dense_dim, activation=\"relu\"), layers.Dense(embed_dim),]\n", + " )\n", + " self.layernorm_1 = layers.LayerNormalization()\n", + " self.layernorm_2 = layers.LayerNormalization()\n", + " self.supports_masking = True\n", + "\n", + " def call(self, inputs, mask=None):\n", + " if mask is not None:\n", + " padding_mask = tf.cast(mask[:, tf.newaxis, tf.newaxis, :], dtype=\"int32\")\n", + " attention_output = self.attention(\n", + " query=inputs, value=inputs, key=inputs, attention_mask=padding_mask\n", + " )\n", + " proj_input = self.layernorm_1(inputs + attention_output)\n", + " proj_output = self.dense_proj(proj_input)\n", + " return self.layernorm_2(proj_input + proj_output)\n", + " \n", + " def get_config(self):\n", + " config = super().get_config()\n", + " config.update({'embed_dim': self.embed_dim,\n", + " 'dense_dim': self.dense_dim,\n", + " 'num_heads': self.num_heads})\n", + " return config\n", + "\n", + "\n", + "class PositionalEmbedding(layers.Layer):\n", + " def __init__(self, seq_len, vocab_size, embed_dim, **kwargs):\n", + " super(PositionalEmbedding, self).__init__(**kwargs)\n", + " self.token_embeddings = layers.Embedding(\n", + " input_dim=vocab_size, output_dim=embed_dim\n", + " )\n", + " self.position_embeddings = layers.Embedding(\n", + " input_dim=seq_len, output_dim=embed_dim\n", + " )\n", + " self.seq_len = seq_len\n", + " self.vocab_size = vocab_size\n", + " self.embed_dim = embed_dim\n", + "\n", + " def call(self, inputs):\n", + " length = tf.shape(inputs)[-1]\n", + " positions = tf.range(start=0, limit=length, delta=1)\n", + " embedded_tokens = self.token_embeddings(inputs)\n", + " embedded_positions = self.position_embeddings(positions)\n", + " return embedded_tokens + embedded_positions\n", + "\n", + " def compute_mask(self, inputs, mask=None):\n", + " return tf.math.not_equal(inputs, 0)\n", + "\n", + " def get_config(self):\n", + " config = super().get_config()\n", + " config.update({'embed_dim': self.embed_dim,\n", + " 'vocab_size': self.vocab_size,\n", + " 'seq_len': self.seq_len})\n", + " return config\n", + "\n", + "\n", + "class TransformerDecoder(layers.Layer):\n", + " def __init__(self, embed_dim, latent_dim, num_heads, **kwargs):\n", + " super(TransformerDecoder, self).__init__(**kwargs)\n", + " self.embed_dim = embed_dim\n", + " self.latent_dim = latent_dim\n", + " self.num_heads = num_heads\n", + " self.attention_1 = layers.MultiHeadAttention(\n", + " num_heads=num_heads, key_dim=embed_dim\n", + " )\n", + " self.attention_2 = layers.MultiHeadAttention(\n", + " num_heads=num_heads, key_dim=embed_dim\n", + " )\n", + " self.dense_proj = keras.Sequential(\n", + " [layers.Dense(latent_dim, activation=\"relu\"), layers.Dense(embed_dim),]\n", + " )\n", + " self.layernorm_1 = layers.LayerNormalization()\n", + " self.layernorm_2 = layers.LayerNormalization()\n", + " self.layernorm_3 = layers.LayerNormalization()\n", + " self.supports_masking = True\n", + "\n", + " def call(self, inputs, encoder_outputs, mask=None):\n", + " causal_mask = self.get_causal_attention_mask(inputs)\n", + " if mask is not None:\n", + " padding_mask = tf.cast(mask[:, tf.newaxis, :], dtype=\"int32\")\n", + " padding_mask = tf.minimum(padding_mask, causal_mask)\n", + "\n", + " attention_output_1 = self.attention_1(\n", + " query=inputs, value=inputs, key=inputs, attention_mask=causal_mask\n", + " )\n", + " out_1 = self.layernorm_1(inputs + attention_output_1)\n", + "\n", + " attention_output_2 = self.attention_2(\n", + " query=out_1,\n", + " value=encoder_outputs,\n", + " key=encoder_outputs,\n", + " attention_mask=padding_mask,\n", + " )\n", + " out_2 = self.layernorm_2(out_1 + attention_output_2)\n", + "\n", + " proj_output = self.dense_proj(out_2)\n", + " return self.layernorm_3(out_2 + proj_output)\n", + "\n", + " def get_causal_attention_mask(self, inputs):\n", + " input_shape = tf.shape(inputs)\n", + " batch_size, sequence_length = input_shape[0], input_shape[1]\n", + " i = tf.range(sequence_length)[:, tf.newaxis]\n", + " j = tf.range(sequence_length)\n", + " mask = tf.cast(i >= j, dtype=\"int32\")\n", + " mask = tf.reshape(mask, (1, input_shape[1], input_shape[1]))\n", + " mult = tf.concat(\n", + " [tf.expand_dims(batch_size, -1), tf.constant([1, 1], dtype=tf.int32)],\n", + " axis=0,\n", + " )\n", + " return tf.tile(mask, mult)\n", + " \n", + " def get_config(self):\n", + " config = super().get_config()\n", + " config.update({'embed_dim': self.embed_dim,\n", + " 'latent_dim': self.latent_dim,\n", + " 'num_heads': self.num_heads})\n", + " return config" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "(b) Build the end-to-end model" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def get_encoder_decoder_model():\n", + " encoder_inputs = keras.Input(shape=(20,), dtype=\"int64\", name=\"encoder_inputs\")\n", + " x = PositionalEmbedding(seq_len, vocab_size, embed_dim)(encoder_inputs)\n", + " encoder_outputs = TransformerEncoder(embed_dim, latent_dim, num_heads)(x)\n", + " decoder_inputs = keras.Input(shape=(20,), dtype=\"int64\", name=\"decoder_inputs\")\n", + " encoded_seq_inputs = encoder_outputs\n", + " x = PositionalEmbedding(seq_len, vocab_size, embed_dim)(decoder_inputs)\n", + " x = TransformerDecoder(embed_dim, latent_dim, num_heads)(x, encoded_seq_inputs)\n", + " x = layers.Dropout(0.5)(x)\n", + " decoder_outputs = layers.Dense(vocab_size, activation=\"softmax\")(x)\n", + " \n", + " transformer = keras.Model(\n", + " [encoder_inputs, decoder_inputs], decoder_outputs, name=\"transformer\"\n", + " )\n", + "\n", + " return transformer\n", + "\n", + "transformer = get_encoder_decoder_model()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "(c) Training the original Transformer model from Keras example" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "transformer.summary()\n", + "train(transformer, model_type='original')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "(d) Evaluate performance" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Get BLEU score on test set for original transformer model\n", + "get_text_result(transformer, no_input_masks=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Get accuracy on test set for the original transformer model from Keras example\n", + "evaluate(transformer)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 8. Create FP32 Function Model for the Transformer model\n", + "\n", + "Custom Keras layers must be defined for all of the low-level TensorFlow operators (each must only contain a single operation for QAT).\n", + "\n", + "Since none of these will have any prunable weights, first we create a base prunable layer class to extend, instead of `tf.keras.layers.Layer`." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "(a) Create a base prunable layer class" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "class PrunableLayer(tf.keras.layers.Layer, tfmot.sparsity.keras.PrunableLayer):\n", + " def get_prunable_weights(self): return []" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "(b) Define low level TensorFlow operations as Keras subclassed layers\n", + "\n", + "Note that some of these layers have trainable weights defined using the `add_weight` method. These weights will not be pruned or clustered." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "class Tanh(PrunableLayer):\n", + " def __init__(self, **kwargs):\n", + " super().__init__(**kwargs)\n", + "\n", + " def call(self, x):\n", + " return tf.math.tanh(x)\n", + "\n", + "\n", + "class Relu(PrunableLayer):\n", + " def __init__(self, **kwargs):\n", + " super().__init__(**kwargs)\n", + " \n", + " def call(self, x):\n", + " return tf.maximum(0., x)\n", + "\n", + " \n", + "class MatMul(PrunableLayer):\n", + " def __init__(self, transpose_b=False, **kwargs):\n", + " super().__init__(**kwargs)\n", + " self.transpose_b = transpose_b \n", + " \n", + " def call(self, inputs):\n", + " return tf.linalg.matmul(*inputs, transpose_b=self.transpose_b)\n", + " \n", + " def get_config(self):\n", + " config = super().get_config()\n", + " config.update({'transpose_b': self.transpose_b})\n", + " return config\n", + "\n", + "\n", + "class Multiply(PrunableLayer):\n", + " def __init__(self, **kwargs):\n", + " super().__init__(**kwargs)\n", + " \n", + " def call(self, inputs):\n", + " return tf.multiply(*inputs)\n", + "\n", + "\n", + "# Calling Multiply with a scalar input will lead to an error.\n", + "# Use the following ScalarMultiply class instead.\n", + "class ScalarMultiply(PrunableLayer):\n", + " def __init__(self, scalar, **kwargs):\n", + " super().__init__(**kwargs)\n", + " self.scalar = scalar \n", + " \n", + " def call(self, x):\n", + " return tf.math.multiply(x, self.scalar)\n", + " \n", + " def get_config(self):\n", + " config = super().get_config()\n", + " config.update({'scalar': self.scalar})\n", + " return config\n", + "\n", + "\n", + "class Add(PrunableLayer):\n", + " def __init__(self, **kwargs):\n", + " super().__init__(**kwargs)\n", + " \n", + " def call(self, inputs):\n", + " return tf.math.add(*inputs)\n", + "\n", + "\n", + "# Calling Add with a scalar input will lead to an error.\n", + "# Use the following ScalarAdd class instead.\n", + "class ScalarAdd(PrunableLayer):\n", + " def __init__(self, scalar, **kwargs):\n", + " super().__init__(**kwargs)\n", + " self.scalar = scalar \n", + " \n", + " def call(self, x):\n", + " return tf.math.add(x, self.scalar)\n", + " \n", + " def get_config(self):\n", + " config = super().get_config()\n", + " config.update({'scalar': self.scalar})\n", + " return config\n", + "\n", + "\n", + "class Slice(PrunableLayer):\n", + " def __init__(self, seq_idx, **kwargs):\n", + " super().__init__(**kwargs)\n", + " self.seq_idx = seq_idx \n", + " \n", + " def call(self, x):\n", + " return x[:, self.seq_idx, ...]\n", + " \n", + " def get_config(self):\n", + " config = super().get_config()\n", + " config.update({'seq_idx': self.seq_idx})\n", + " return config\n", + "\n", + "\n", + "class Mean(PrunableLayer):\n", + " def __init__(self, axes=None, keepdims=True, **kwargs):\n", + " super().__init__(**kwargs)\n", + " self.axes=axes\n", + " self.keepdims = keepdims \n", + " \n", + " def call(self, x):\n", + " return tf.math.reduce_mean(x, axis=self.axes, keepdims=self.keepdims)\n", + " \n", + " def get_config(self):\n", + " config = super().get_config()\n", + " config.update({'axes': self.axes,\n", + " 'keepdims': self.keepdims})\n", + " return config\n", + "\n", + "\n", + "class Subtract(PrunableLayer):\n", + " def __init__(self, **kwargs):\n", + " super().__init__(**kwargs) \n", + " \n", + " def call(self, inputs):\n", + " return tf.math.subtract(*inputs)\n", + "\n", + "\n", + "class ScalarSubtract(PrunableLayer):\n", + " def __init__(self, scalar, **kwargs):\n", + " super().__init__(**kwargs)\n", + " self.scalar = scalar \n", + " \n", + " def call(self, x):\n", + " return tf.math.subtract(self.scalar,x)\n", + " \n", + " def get_config(self):\n", + " config = super().get_config()\n", + " config.update({'scalar': self.scalar})\n", + " return config\n", + "\n", + "\n", + "class SquaredDiffrence(PrunableLayer):\n", + " def __init__(self, **kwargs):\n", + " super().__init__(**kwargs) \n", + " \n", + " def call(self,inputs):\n", + " return tf.math.squared_difference(*inputs)\n", + "\n", + "\n", + "class StopGradient(PrunableLayer):\n", + " def __init__(self, **kwargs):\n", + " super().__init__(**kwargs)\n", + " \n", + " def call(self, x):\n", + " return tf.stop_gradient(x)\n", + "\n", + "\n", + "class RSqrt(PrunableLayer):\n", + " def __init__(self, **kwargs):\n", + " super().__init__(**kwargs)\n", + " \n", + " def call(self, x):\n", + " return tf.math.rsqrt(x)\n", + "\n", + "\n", + "class Clip(PrunableLayer):\n", + " def __init__(self, **kwargs):\n", + " super().__init__(**kwargs)\n", + " \n", + " def call(self, x):\n", + " return tf.clip_by_value(x, 0.001, 255.0)\n", + "\n", + "\n", + "class BroadcastToken(PrunableLayer):\n", + " \"\"\"Layer to broadcast the class token\"\"\"\n", + " def __init__(self, embedding_dim, **kwargs):\n", + " super().__init__(**kwargs)\n", + " self.embedding_dim = embedding_dim\n", + "\n", + " def build(self, input_shape):\n", + " self.w = self.add_weight(shape=(1, 1, self.embedding_dim), initializer='zeros', \n", + " trainable=True, name='token')\n", + " super().build(input_shape)\n", + "\n", + " def call(self, x):\n", + " batch_size = tf.shape(x)[0]\n", + " return tf.broadcast_to(self.w, [batch_size, 1, self.embedding_dim])\n", + "\n", + " def get_config(self):\n", + " config = super().get_config()\n", + " config.update({'embedding_dim': self.embedding_dim})\n", + " return config\n", + "\n", + "\n", + "class AddPositionalEmbedding(PrunableLayer):\n", + " \"\"\"Layer to add positional embeddings to the tokens\"\"\"\n", + " def __init__(self, seq_len, embedding_dim, **kwargs):\n", + " super().__init__(**kwargs)\n", + " self.embedding_dim = embedding_dim\n", + " self.seq_len = seq_len\n", + "\n", + " def build(self, input_shape):\n", + " self.w = self.add_weight(shape=(self.seq_len, self.embedding_dim), initializer= 'uniform',\n", + " trainable=True, name='pos_emb')\n", + " super().build(input_shape)\n", + "\n", + " def call(self, x):\n", + " return x + self.w\n", + "\n", + " def get_config(self):\n", + " config = super().get_config()\n", + " config.update({'embedding_dim': self.embedding_dim, 'seq_len': self.seq_len})\n", + " return config\n", + "\n", + "\n", + "class AddTokenEmbedding(PrunableLayer): \n", + " \"\"\"Layer to add token embeddings to the tokens\"\"\"\n", + " def __init__(self, vocab_size, embedding_dim, train = True, **kwargs):\n", + " super().__init__(**kwargs)\n", + " self.embedding_dim = embedding_dim\n", + " self.vocab_size = vocab_size\n", + " self.train = train\n", + "\n", + " def build(self, input_shape):\n", + " self.w = self.add_weight(shape=(self.vocab_size, self.embedding_dim), initializer= 'uniform',\n", + " trainable=self.train, name='token_emb')\n", + " super().build(input_shape)\n", + "\n", + " def call(self, x):\n", + " return tf.gather(self.w,x)\n", + "\n", + " def get_config(self):\n", + " config = super().get_config()\n", + " config.update({'embedding_dim': self.embedding_dim, 'vocab_size': self.vocab_size, 'train': self.train})\n", + " return config\n", + "\n", + " def compute_output_shape(self, input_shape):\n", + " return(input_shape[-1], self.embedding_dim)\n", + "\n", + "\n", + "class Scale(PrunableLayer):\n", + " \"\"\"Multiply with gamma (LayerNorm)\"\"\"\n", + " def __init__(self, axes, **kwargs):\n", + " super().__init__(**kwargs)\n", + " self.axes = axes \n", + " \n", + " def build(self, input_shape):\n", + " param_shape = [input_shape[dim] for dim in self.axes]\n", + " self.w = self.add_weight(name='gamma', shape=param_shape,\n", + " trainable=True, initializer='ones')\n", + " super().build(input_shape)\n", + " \n", + " def call(self, x):\n", + " return tf.multiply(x, self.w)\n", + " \n", + " def get_config(self):\n", + " config = super().get_config()\n", + " config.update({'axes': self.axes})\n", + " return config\n", + "\n", + " \n", + "class Centre(PrunableLayer):\n", + " \"\"\"Add beta (LayerNorm)\"\"\"\n", + " def __init__(self, axes, **kwargs):\n", + " super().__init__(**kwargs)\n", + " self.axes = axes \n", + " \n", + " def build(self, input_shape):\n", + " param_shape = [input_shape[dim] for dim in self.axes]\n", + " self.w = self.add_weight(name='beta', shape=param_shape,\n", + " trainable=True, initializer='zeros')\n", + " super().build(input_shape)\n", + " \n", + " def call(self, x):\n", + " return tf.math.add(x, self.w)\n", + " \n", + " def get_config(self):\n", + " config = super().get_config()\n", + " config.update({'axes': self.axes})\n", + " return config\n", + "\n", + "\n", + "class Minimum(PrunableLayer):\n", + " def __init__(self, **kwargs):\n", + " super().__init__(**kwargs) \n", + " \n", + " def call(self,inputs):\n", + " return tf.minimum(*inputs)\n", + "\n", + "\n", + "class MinimumScalar(PrunableLayer):\n", + " def __init__(self, scalar, **kwargs):\n", + " super().__init__(**kwargs) \n", + " self.scalar=scalar\n", + "\n", + " def call(self,inputs):\n", + " return tf.minimum(inputs, self.scalar)\n", + " \n", + " def get_config(self):\n", + " config = super().get_config()\n", + " config.update({'scalar': self.scalar})\n", + " return config\n", + "\n", + "\n", + "class MaximumScalar(PrunableLayer):\n", + " def __init__(self, scalar, **kwargs):\n", + " super().__init__(**kwargs) \n", + " self.scalar=scalar\n", + "\n", + " def call(self,inputs):\n", + " return tf.maximum(inputs, self.scalar)\n", + " \n", + " def get_config(self):\n", + " config = super().get_config()\n", + " config.update({'scalar': self.scalar})\n", + " return config\n", + "\n", + "\n", + "class Cast(PrunableLayer):\n", + " def __init__(self, type = tf.int32, **kwargs):\n", + " super().__init__(**kwargs) \n", + " self.type=type \n", + "\n", + " def call(self,inputs):\n", + " return tf.cast(inputs, self.type)\n", + "\n", + " def get_config(self):\n", + " config = super().get_config()\n", + " config.update({'type': self.type})\n", + " return config" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "(c) Define Transormer layers like multiheaded-attention, layer-norm, etc. functionally" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def self_attention(query, key, value, n_heads, dim, mask=None, name='mha', block_name=None, out_dim=None):\n", + " \"\"\"Multi-head attention layer\"\"\"\n", + " depth = dim // n_heads\n", + " if out_dim is None: out_dim = query.shape[-1]\n", + " q = tf.keras.layers.Dense(units=dim, name=f'{name}/query')(query)\n", + " k = tf.keras.layers.Dense(units=dim, name=f'{name}/key')(key)\n", + " v = tf.keras.layers.Dense(units=dim, name=f'{name}/value')(value)\n", + "\n", + " q = tf.keras.layers.Reshape((-1, n_heads, depth))(q)\n", + " q = tf.keras.layers.Permute((2, 1, 3))(q)\n", + " k = tf.keras.layers.Reshape((-1, n_heads, depth))(k)\n", + " k = tf.keras.layers.Permute((2, 1, 3))(k)\n", + " v = tf.keras.layers.Reshape((-1, n_heads, depth))(v)\n", + " v = tf.keras.layers.Permute((2, 1, 3))(v)\n", + " qk = ScalarMultiply(depth ** -0.5)(MatMul(transpose_b=True)([q, k]))\n", + "\n", + " if mask is not None:\n", + " if isinstance(mask, tf.Tensor):\n", + " qk = ScalarMultiply(mask)(qk)\n", + " mask=1. - mask\n", + " mask = mask * -10\n", + " qk = ScalarAdd(mask)(qk)\n", + " \n", + " else:\n", + " qk = Multiply()([qk, mask])\n", + " mask = ScalarSubtract(1.)(mask)\n", + " mask = ScalarMultiply(-10)(mask)\n", + " qk = Add(name=f'add/{name}')([qk, (mask)])\n", + " \n", + " attn_weights = tf.keras.layers.Softmax(axis=-1)(qk)\n", + " attn_out = MatMul()([attn_weights, v]) \n", + " attn_out = tf.keras.layers.Permute((2, 1, 3))(attn_out)\n", + " attn_out = tf.keras.layers.Reshape((-1, dim))(attn_out)\n", + " out = tf.keras.layers.Dense(out_dim, name=f'{name}/output_dense', dtype=\"float32\")(attn_out)\n", + " \n", + " return out, attn_weights\n", + "\n", + "def AddPositionalEmbeddingForEncoderDecoder(inputs, seq_len, vocab_size, embed_dim, block_name, freeze):\n", + " x = AddTokenEmbedding(vocab_size, embed_dim, train = not freeze, name= ('token_embedding/' + block_name))(inputs)\n", + " x = AddPositionalEmbedding(seq_len, embed_dim, name= ('positional_embedding/' + block_name))(x)\n", + " return x\n", + " \n", + "def enc_padding_mask(inputs):\n", + " computed_mask=tf.keras.layers.Reshape((1, 1, -1))(inputs)\n", + " return computed_mask \n", + "\n", + "def causal_mask(inputs):\n", + " seq_len=inputs.shape[1]\n", + " causal_mask = tf.linalg.band_part(tf.ones((seq_len, seq_len)), -1, 0)\n", + " return causal_mask\n", + "\n", + "def dec_padding_mask(inputs, cau_mask):\n", + " padding_mask= enc_padding_mask(inputs)\n", + " padding_mask = MinimumScalar(scalar=cau_mask)(padding_mask)\n", + " return padding_mask\n", + "\n", + "def layer_norm(x, axes=2, epsilon=0.001, name='layer_norm', trainable = True):\n", + " \"\"\"LayerNormalization\"\"\"\n", + " if isinstance(axes, int): axes = [axes]\n", + " \n", + " mean = Mean(axes=axes, dtype=x.dtype)(x)\n", + " ## This block can be replaced with a squared_difference layer ##\n", + " diff = Subtract()([x, StopGradient()(mean)]) ##\n", + " sq_diff = Multiply()([diff, diff]) ##\n", + " ## ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ ##\n", + " variance = Mean(axes=axes,dtype=x.dtype ,name=f'{name}/variance')(sq_diff)\n", + " if not trainable:\n", + " inv = RSqrt()(variance)\n", + " x = Multiply()([diff, inv])\n", + " else:\n", + " # MaximumScalar prevents division by 0.\n", + " inv = RSqrt()(MaximumScalar(epsilon)(variance))\n", + " # This layer is removed for inference so it is named.\n", + " x = Subtract(name=f'{name}/grad_subtract')([x, mean]) \n", + " x = Multiply()([x, inv])\n", + "\n", + " x = Scale(axes=axes)(x)\n", + " x = Centre(axes=axes)(x)\n", + " \n", + " return x\n", + "\n", + "def mlp(x, hidden_dim, out_dim=None):\n", + " \"\"\"Multi-layer perceptron block\"\"\"\n", + " if out_dim is None: out_dim = x.shape[-1]\n", + "\n", + " x = tf.keras.layers.Dense(units=hidden_dim)(x)\n", + " x = Relu()(x)\n", + " x = tf.keras.layers.Dense(units=out_dim)(x)\n", + " return x" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "(d) Build end-to-end model" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "scrolled": true + }, + "outputs": [], + "source": [ + "from collections import defaultdict\n", + "def get_translation_model(input_shape, batch_size, seq_len=seq_len, vocab_size=vocab_size, embed_dim=256, num_heads=8, freeze=False, trainable=True):\n", + " \n", + " aux_output=defaultdict(list)\n", + " ## Encoder\n", + " \n", + " # Input to encoder\n", + " enc_inputs = tf.keras.Input(shape=input_shape, batch_size=batch_size, name=\"encoder_inputs\")\n", + " encoder_inputs=Cast()(enc_inputs)\n", + " \n", + " x = AddPositionalEmbeddingForEncoderDecoder(encoder_inputs, seq_len, vocab_size, embed_dim, 'encoder', freeze)\n", + " encoder_padding_mask_inputs = tf.keras.Input(shape=input_shape, batch_size=batch_size, name=\"encoder_masks\")\n", + " encoder_padding_mask = enc_padding_mask(encoder_padding_mask_inputs)\n", + "\n", + " # Encoder Attention block\n", + " attention_output, attention_weights = self_attention(x, x, x, num_heads, embed_dim*num_heads, mask=encoder_padding_mask, name=(f'mha'), block_name=(f'encoder'))\n", + " proj_input = tf.keras.layers.Add()([x, attention_output])\n", + " proj_input = layer_norm(proj_input, name=(f'layer_norm'), trainable=trainable)\n", + "\n", + " # MLP block\n", + " proj_output = mlp(proj_input, latent_dim, embed_dim)\n", + " x = tf.keras.layers.Add()([proj_input, proj_output])\n", + " encoder_outputs = layer_norm(x, name=(f'layer_norm_1'), trainable=trainable)\n", + " \n", + " ## Decoder\n", + " \n", + " # Input to decoder\n", + " dec_inputs = tf.keras.Input(shape=input_shape, batch_size=batch_size, name=\"decoder_inputs\")\n", + " decoder_inputs=Cast()(dec_inputs)\n", + "\n", + " x = AddPositionalEmbeddingForEncoderDecoder(decoder_inputs, seq_len, vocab_size, embed_dim, 'decoder', freeze)\n", + " decoder_causal_mask = causal_mask(decoder_inputs)\n", + " decoder_padding_mask_inputs = tf.keras.Input(shape=input_shape, batch_size=batch_size, name=\"decoder_masks\")\n", + " decoder_padding_mask = dec_padding_mask(decoder_padding_mask_inputs, decoder_causal_mask)\n", + " \n", + " \n", + " # Decoder Attention Block 1\n", + " attention_output_1, attention_weights_1 = self_attention(x, x, x, num_heads, embed_dim*num_heads, mask=decoder_causal_mask, name=(f'mha_1'), block_name=(f'decoder_1'))\n", + " x1 = tf.keras.layers.Add()([x, attention_output_1])\n", + " out_1 = layer_norm(x1, name=(f'layer_norm_2'), trainable=trainable)\n", + " \n", + " # Decoder Attention Block 2\n", + " attention_output_2, attention_weights_2 = self_attention(out_1, encoder_outputs, encoder_outputs, num_heads, embed_dim*num_heads, mask=decoder_padding_mask, name=(f'mha_2'), block_name=(f'decoder_2'))\n", + " x2 = tf.keras.layers.Add()([out_1, attention_output_2])\n", + " out_2 = layer_norm(x2, name=(f'layer_norm_3'), trainable=trainable)\n", + " \n", + " # MLP Block\n", + " proj_output = mlp(out_2, latent_dim, embed_dim)\n", + " x3 = tf.keras.layers.Add()([out_2, proj_output])\n", + " x3 = layer_norm(x3, name=(f'layer_norm_4'), trainable=trainable)\n", + " \n", + "\n", + " x3 = tf.keras.layers.Dropout(0.5)(x3)\n", + " x3 = tf.keras.layers.Dense(units=vocab_size, name=\"dense_last\", activation='softmax')(x3)\n", + " \n", + " transformer = keras.Model(\n", + " [enc_inputs,encoder_padding_mask_inputs, dec_inputs, decoder_padding_mask_inputs], x3, name=\"transformer\"\n", + " )\n", + " \n", + " return transformer\n", + "\n", + "tf.keras.backend.clear_session() # reset layer name counters\n", + "\n", + "transform = get_translation_model(input_shape = (seq_len,), batch_size = batch_size)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "(e) Train the FP32 model" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "transform.summary()\n", + "train(transform, model_type='fp32')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "(f) Evaluate Performance" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Get BLEU score on test set for FP32 transformer model\n", + "get_text_result(transform)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Get accuracy on test set for the FP32 transformer model from Keras example\n", + "evaluate(transform)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 9. Perform Pruning on FP32 model with TFMOT" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "(a) Apply Pruning API\n", + "\n", + "Apply pruning wrapper for '2:4' pruning, i.e. for every 4 continuous weight elements/values 2 will be zero " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "prune_low_magnitude = tfmot.sparsity.keras.prune_low_magnitude\n", + "pruned_model = prune_low_magnitude(transform, sparsity_m_by_n=(2,4))\n", + "pruned_model.summary()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "(b) Perform Pruning" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Fine-tune with pruning\n", + "train(pruned_model, model_type='prune', epochs=3)\n", + "\n", + "# Remove pruning wrapper from the trained model\n", + "stripped_pruned_model = tfmot.sparsity.keras.strip_pruning(pruned_model)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "(c) Check sparsity" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "print_model_weights_sparsity(stripped_pruned_model)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "(d) Evaluate performance" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "evaluate(stripped_pruned_model)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "get_text_result(stripped_pruned_model)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 10. Perform QAT on FP32 model with TFMOT" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "(a) Create the Prune Preserve Quantizer\n", + "\n", + "Ideally, there shouldn't be a need to write this quantizer as this can be directly done by using the `tfmot.quantization.keras.collaborative_optimizations.Default8BitPrunePreserveQuantizeScheme`. But in case of our translation transformer model (as you'll notice later in the section 10.[d]) QAT annotations are not done to custom `Cast` layer or in other words a QuantizeConfig is not assigned to the `Cast` layer, this makes the `Default8BitPrunePreserveQuantizeScheme` fail as it expects that everly layer has a QuantizeConfig which certainly is not the case.\n", + "\n", + "To circumvent this issue, the quantizer class used in the `Default8BitPrunePreserveQuantizeScheme` api has been used here directly to avoid the limitation of the api." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from tensorflow_model_optimization.python.core.quantization.keras import quant_ops\n", + "from tensorflow_model_optimization.quantization.keras import QuantizeConfig, quantizers\n", + "\n", + "# This piece of code has been taken from the implementation of Default8BitPrunePreserveQuantizeScheme\n", + "\n", + "class PrunePreserveLastValueWeightsQuantizer(quantizers.LastValueQuantizer):\n", + " \"\"\"Quantize weights while preserve sparsity.\"\"\"\n", + "\n", + " def __init__(self, num_bits, per_axis, symmetric, narrow_range):\n", + " \"\"\"Initializes PrunePreserveDefaultWeightsQuantizer.\n", + " Args:\n", + " num_bits: Number of bits for quantization\n", + " per_axis: Whether to apply per_axis quantization. The last dimension is\n", + " used as the axis.\n", + " symmetric: If true, use symmetric quantization limits instead of training\n", + " the minimum and maximum of each quantization range separately.\n", + " narrow_range: In case of 8 bits, narrow_range nudges the quantized range\n", + " to be [-127, 127] instead of [-128, 127]. This ensures symmetric range\n", + " has 0 as the centre.\n", + " \"\"\"\n", + " quantizers.LastValueQuantizer.__init__(self, num_bits, per_axis, symmetric,\n", + " narrow_range)\n", + "\n", + " def _build_sparsity_mask(self, name, layer):\n", + " weights = getattr(layer.layer, name)\n", + " sparsity_mask = tf.math.divide_no_nan(weights, weights)\n", + " return {'sparsity_mask': sparsity_mask}\n", + "\n", + " def build(self, tensor_shape, name, layer):\n", + " \"\"\"Constructs mask to preserve weights sparsity.\n", + " Args:\n", + " tensor_shape: Shape of weights which needs to be quantized.\n", + " name: Name of weights in layer.\n", + " layer: quantization wrapped keras layer.\n", + " Returns:\n", + " Dictionary of constructed sparsity mask and\n", + " quantization params, the dictionary will be passed\n", + " to __call__ function.\n", + " \"\"\"\n", + " result = self._build_sparsity_mask(name, layer)\n", + " result.update(\n", + " super(PrunePreserveLastValueWeightsQuantizer,\n", + " self).build(tensor_shape, name, layer))\n", + " return result\n", + "\n", + " def __call__(self, inputs, training, weights, **kwargs):\n", + " \"\"\"Applies sparsity preserved quantization to the input tensor.\n", + " Args:\n", + " inputs: Input tensor (layer's weights) to be quantized.\n", + " training: Whether the graph is currently training.\n", + " weights: Dictionary of weights (params) the quantizer can use to\n", + " quantize the tensor (layer's weights). This contains the weights\n", + " created in the `build` function.\n", + " **kwargs: Additional variables which may be passed to the quantizer.\n", + " Returns:\n", + " quantized tensor.\n", + " \"\"\"\n", + "\n", + " prune_preserve_inputs = tf.multiply(inputs, weights['sparsity_mask'])\n", + "\n", + " return quant_ops.LastValueQuantize(\n", + " prune_preserve_inputs,\n", + " weights['min_var'],\n", + " weights['max_var'],\n", + " is_training=training,\n", + " num_bits=self.num_bits,\n", + " per_channel=self.per_axis,\n", + " symmetric=self.symmetric,\n", + " narrow_range=self.narrow_range,\n", + " )\n", + "\n", + "\n", + "class PrunePreserveWeightsQuantizer(\n", + " PrunePreserveLastValueWeightsQuantizer):\n", + " \"\"\"PrunePreserveWeightsQuantizer for default 8bit weights.\"\"\"\n", + "\n", + " def __init__(self):\n", + " super(PrunePreserveWeightsQuantizer,\n", + " self).__init__(num_bits=8,\n", + " per_axis=False,\n", + " symmetric=True,\n", + " narrow_range=True)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "(b) To use the custom Keras layers we defined, we need to pass a [`QuantizeConfig`](https://www.tensorflow.org/model_optimization/api_docs/python/tfmot/quantization/keras/QuantizeConfig) for each of these layers.\n", + "\n", + "For Keras layers which are already supported in TFMOT, a default QuantizeConfig class is assigned to each one. However, custom QuantizeConfig instances could also be created for these layers to give more control over how they are quantised." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "LastValueQuantizer = quantizers.LastValueQuantizer\n", + "MovingAverageQuantizer = quantizers.MovingAverageQuantizer\n", + "AllValuesQuantizer = quantizers.AllValuesQuantizer\n", + "PrunePreserveQuantizer = PrunePreserveWeightsQuantizer\n", + "\n", + "class NoOpQuantizeConfig(QuantizeConfig):\n", + " \"\"\"QuantizeConfig which does not quantize any part of the layer.\"\"\"\n", + "\n", + " def get_weights_and_quantizers(self, layer):\n", + " return []\n", + "\n", + " def get_activations_and_quantizers(self, layer):\n", + " return []\n", + "\n", + " def set_quantize_weights(self, layer, quantize_weights):\n", + " pass\n", + "\n", + " def set_quantize_activations(self, layer, quantize_activations):\n", + " pass\n", + "\n", + " def get_output_quantizers(self, layer):\n", + " return []\n", + " \n", + " def get_config(self):\n", + " return {}\n", + "\n", + "\n", + "class TFOpQuantizeConfig(QuantizeConfig):\n", + " \"\"\"QuantizeConfig which only quantizes the output of a layer.\"\"\"\n", + "\n", + " def get_weights_and_quantizers(self, layer):\n", + " return []\n", + "\n", + " def get_activations_and_quantizers(self, layer):\n", + " return []\n", + "\n", + " def set_quantize_weights(self, layer, quantize_weights):\n", + " pass\n", + "\n", + " def set_quantize_activations(self, layer, quantize_activations):\n", + " pass\n", + "\n", + " def get_output_quantizers(self, layer):\n", + " return [MovingAverageQuantizer(num_bits=8, per_axis=False, symmetric=False, narrow_range=False)]\n", + "\n", + " def get_config(self):\n", + " return {}\n", + "\n", + "\n", + "class MaskOpQuantizeConfig(QuantizeConfig):\n", + " \"\"\"QuantizeConfig which only quantizes the output of a layer.\"\"\"\n", + "\n", + " def get_weights_and_quantizers(self, layer):\n", + " return []\n", + "\n", + " def get_activations_and_quantizers(self, layer):\n", + " return []\n", + "\n", + " def set_quantize_weights(self, layer, quantize_weights):\n", + " pass\n", + "\n", + " def set_quantize_activations(self, layer, quantize_activations):\n", + " pass\n", + "\n", + " def get_output_quantizers(self, layer):\n", + " return [AllValuesQuantizer(num_bits=8, per_axis=False, symmetric=False, narrow_range=False)]\n", + "\n", + " def get_config(self):\n", + " return {}\n", + " \n", + "\n", + "class VarianceQuantizeConfig(QuantizeConfig):\n", + " \"\"\"QuantizeConfig for the variance calculation in the layer normalisation layer.\"\"\"\n", + "\n", + " def get_weights_and_quantizers(self, layer):\n", + " return []\n", + "\n", + " def get_activations_and_quantizers(self, layer):\n", + " return []\n", + "\n", + " def set_quantize_weights(self, layer, quantize_weights):\n", + " pass\n", + "\n", + " def set_quantize_activations(self, layer, quantize_activations):\n", + " pass\n", + "\n", + " def get_output_quantizers(self, layer):\n", + " return [AllValuesQuantizer(num_bits=8, per_axis=False, symmetric=False, narrow_range=False)]\n", + "\n", + " def get_config(self):\n", + " return {}\n", + "\n", + "\n", + "class WeightQuantizeConfig(QuantizeConfig):\n", + " \"\"\"QuantizeConfig which quantizes the custom weights in the patch encoder and layer normalisation layers.\"\"\"\n", + "\n", + " def __init__(self):\n", + " self.weight_quantizer = LastValueQuantizer(num_bits=8, per_axis=False,\n", + " symmetric=True, narrow_range=True)\n", + " self.activation_quantizer = MovingAverageQuantizer(num_bits=8, per_axis=False,\n", + " symmetric=False, narrow_range=False)\n", + "\n", + " def get_weights_and_quantizers(self, layer):\n", + " return [(layer.w, self.weight_quantizer)]\n", + "\n", + " def get_activations_and_quantizers(self, layer):\n", + " return []\n", + "\n", + " def set_quantize_weights(self, layer, quantize_weights):\n", + " layer.w = quantize_weights[0]\n", + "\n", + " def set_quantize_activations(self, layer, quantize_activations):\n", + " pass\n", + "\n", + " def get_output_quantizers(self, layer):\n", + " return [self.activation_quantizer]\n", + "\n", + " def get_config(self):\n", + " return {}\n", + "\n", + "\n", + "class DenseQuantizeConfig(QuantizeConfig):\n", + " \"\"\"QuantizeConfig for dense layers that need sparsity to be preserved while performing QAT\"\"\"\n", + "\n", + " def __init__(self):\n", + " self.weight_attrs = ['kernel']\n", + " self.activation_attrs = ['activation']\n", + " self.weight_quantizer = PrunePreserveWeightsQuantizer()\n", + " self.activation_quantizer = MovingAverageQuantizer(num_bits=8, per_axis=False,\n", + " symmetric=False, narrow_range=False)\n", + "\n", + " def get_weights_and_quantizers(self, layer):\n", + " return [(getattr(layer, weight_attr), self.weight_quantizer)\n", + " for weight_attr in self.weight_attrs]\n", + "\n", + " def get_activations_and_quantizers(self, layer):\n", + " return [(getattr(layer, activation_attr), self.activation_quantizer)\n", + " for activation_attr in self.activation_attrs]\n", + "\n", + " def set_quantize_weights(self, layer, quantize_weights):\n", + " if len(self.weight_attrs) != len(quantize_weights):\n", + " raise ValueError(\n", + " '`set_quantize_weights` called on layer {} with {} '\n", + " 'weight parameters, but layer expects {} values.'.format(\n", + " layer.name, len(quantize_weights), len(self.weight_attrs)))\n", + "\n", + " for weight_attr, weight in zip(self.weight_attrs, quantize_weights):\n", + " current_weight = getattr(layer, weight_attr)\n", + " if current_weight.shape != weight.shape:\n", + " raise ValueError('Existing layer weight shape {} is incompatible with'\n", + " 'provided weight shape {}'.format(\n", + " current_weight.shape, weight.shape))\n", + "\n", + " setattr(layer, weight_attr, weight)\n", + "\n", + " def set_quantize_activations(self, layer, quantize_activations):\n", + " if len(self.activation_attrs) != len(quantize_activations):\n", + " raise ValueError(\n", + " '`set_quantize_activations` called on layer {} with {} '\n", + " 'activation parameters, but layer expects {} values.'.format(\n", + " layer.name, len(quantize_activations),\n", + " len(self.activation_attrs)))\n", + "\n", + " for activation_attr, activation in \\\n", + " zip(self.activation_attrs, quantize_activations):\n", + " setattr(layer, activation_attr, activation)\n", + "\n", + " def get_output_quantizers(self, layer):\n", + " return []\n", + "\n", + " def get_config(self):\n", + " return {}" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "(c) Define wrapper function\n", + "\n", + "Since custom layers and QuantizeConfigs are used, the whole model cannot directly be wrapped with QAT wrappers.\n", + "So first we write a function to wrap the individual layers with QAT wrappers:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def apply_wrapper(wrapper_function, layer_param_dict):\n", + " \n", + " def wrap_layer(layer):\n", + " if layer.name in layer_param_dict.keys():\n", + " return wrapper_function(layer, **layer_param_dict[layer.name])\n", + " return layer\n", + "\n", + " return wrap_layer\n", + "\n", + "def layer_wrapper(model, wrapper_function, layer_param_dict):\n", + " return tf.keras.models.clone_model(model, clone_function=apply_wrapper(wrapper_function, layer_param_dict))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "(d) Assign QuantizeConfigs to custom layers" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def get_quantize_config(model):\n", + " layer_param_dict = {} # stores {Layer_Name: QuantizeConfig} pairs\n", + " scope = {} # stores all custom objects\n", + "\n", + " for layer in model.layers:\n", + " \n", + " if layer.name.startswith(('clip', 'minimum', 'minimum_scalar', 'maximum_scalar', 'cast', 'stop_gradient')):\n", + " layer_param_dict[layer.name] = {'quantize_config': NoOpQuantizeConfig()}\n", + " scope[layer.__class__.__name__] = layer.__class__\n", + " \n", + " elif 'grad_subtract' in layer.name or layer.name.startswith(('mat_mul', 'multiply', 'scalar_multiply', 'add',\n", + " 'scalar_add', 'slice', 'mean', 'subtract',\n", + " 'scalar_subtract', 'r_sqrt', 'relu')):\n", + " layer_param_dict[layer.name] = {'quantize_config': TFOpQuantizeConfig()}\n", + " scope[layer.__class__.__name__] = layer.__class__\n", + " \n", + " elif layer.name.startswith(( 'scale', 'centre', 'positional_embedding', 'token_embedding' )):\n", + " layer_param_dict[layer.name] = {'quantize_config': WeightQuantizeConfig()}\n", + " scope[layer.__class__.__name__] = layer.__class__\n", + "\n", + " # Make sure to quantise the encoder and decoder mask input layers so that they can be quantized to INT8\n", + " \n", + " elif layer.name.startswith(('encoder_masks', 'decoder_masks' )):\n", + " layer_param_dict[layer.name] = {'quantize_config': MaskOpQuantizeConfig()}\n", + " \n", + " elif isinstance(layer, tf.keras.layers.Dense):\n", + " layer_param_dict[layer.name] = {'quantize_config':DenseQuantizeConfig()}\n", + "\n", + " elif 'variance' in layer.name:\n", + " layer_param_dict[layer.name] = {'quantize_config': VarianceQuantizeConfig()}\n", + " scope[layer.__class__.__name__] = layer.__class__\n", + "\n", + "\n", + " scope['DenseQuantizeConfig'] = DenseQuantizeConfig\n", + " scope['NoOpQuantizeConfig'] = NoOpQuantizeConfig\n", + " scope['TFOpQuantizeConfig'] = TFOpQuantizeConfig\n", + " scope['WeightQuantizeConfig'] = WeightQuantizeConfig\n", + " scope['VarianceQuantizeConfig'] = VarianceQuantizeConfig\n", + " scope['MaskOpQuantizeConfig'] = MaskOpQuantizeConfig\n", + "\n", + " return layer_param_dict, scope\n", + "\n", + "layer_param_dict, scope = get_quantize_config(stripped_pruned_model)\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Few layers like `cast`, `encoder_inputs` and `decoder_inputs` musn't be annontated with any QuantizeConfig as this will result into a `quantize` node being added after the inputs in tflite graph, which would pass down an int8 value to the `tfl.gather` operation.
And since, the `tfl.gather` operation expects only int32 and int64 as the indices, an int8 value in the `tfl.gather` operation will result into error ([Please refer TF Lite Ops Page](https://www.tensorflow.org/mlir/tfl_ops#tflgather_mlirtflgatherop))." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def remove_unwanted_layers(model, layer_param_dict):\n", + " # All the custom layers that don't need quantization can be added along side 'cast'\n", + "\n", + " layers_to_not_quantise = [x.name for x in model.layers if not any([y in x.name for y in ['cast', 'encoder_inputs', 'decoder_inputs' \n", + " ]])]\n", + " layer_param_dict = {k: v for k, v in layer_param_dict.items() if k in layers_to_not_quantise}\n", + " for k in layers_to_not_quantise:\n", + " if k not in layer_param_dict:\n", + " layer_param_dict[k] = {'quantize_config': None}\n", + "\n", + " return layer_param_dict\n", + "\n", + "layer_param_dict = remove_unwanted_layers(stripped_pruned_model, layer_param_dict)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "(e) Load the necessary API classes/functions" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "quantize_annotate_layer = tfmot.quantization.keras.quantize_annotate_layer\n", + "quantize_apply = tfmot.quantization.keras.quantize_apply\n", + "quantize_scope = tfmot.quantization.keras.quantize_scope" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "(f) Annotate individual layers\n", + "\n", + "When calling the quantize_apply function, if an unsupported layer is missing from the scope, TFMOT will throw an error." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "qat_model = layer_wrapper(stripped_pruned_model, quantize_annotate_layer, layer_param_dict)\n", + "\n", + "with quantize_scope(scope):\n", + " qat_model = quantize_apply(qat_model)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "(g) Perform QAT" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "qat_model.summary()\n", + "train(qat_model, model_type='qat', epochs=3)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "(h) Make sure weights are still Pruned" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "print_model_weights_sparsity(qat_model)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "(i) Evaluate Performance" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "get_text_result(qat_model)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "evaluate(qat_model)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 11. Create INT8 tflite file for PQAT FP32 model" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "If we attempt to directly generate a TFLite file using the fine-tuned model above:\n", + "\n", + "- It will not have a correct batch size of 1.\n", + "- It will have operators which are unnecessary during inference. Precisely, the extra `Subtract` operators and `MaximumScalar` operator in the layer normalisation blocks, which were used during training and fine-tuning, should be removed from the graph before creating the TFLite file.\n", + "\n", + "Therefore the network should be redefined with a batch size of 1 and with the redundant operators removed. The weights of the fine-tuned optimised model can then be loaded into this new model." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "(a) Remove layers which are not required" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "tf.keras.backend.clear_session() # reset layer name counters\n", + "\n", + "new_qat_model = get_translation_model(input_shape = (seq_len,), batch_size = batch_size, trainable = False)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "(b) Annotate individual layers" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Get the QuantizeConfig and Scope which would be used to annotate the layers\n", + "layer_param_dict, scope = get_quantize_config(new_qat_model)\n", + "# Remove unwanted QuantizeConfigs\n", + "layer_param_dict = remove_unwanted_layers(new_qat_model, layer_param_dict) \n", + "\n", + "new_qat_model = layer_wrapper(new_qat_model, quantize_annotate_layer, layer_param_dict)\n", + "\n", + "with quantize_scope(scope):\n", + " new_qat_model = quantize_apply(new_qat_model)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "(c) Load weights into the model" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "new_qat_model.load_weights('./eng_spa_transformer_pqat_tutorial_qat_model.h5', by_name=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Sanity check to see if weights are loaded correctly\n", + "evaluate(new_qat_model)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "(a) Create tflite file (int8 ops)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "MODEL_PATH = './encoder_decoder_pqat.tflite'\n", + "\n", + "i = tf.keras.Input(shape=(20,), batch_size=1, dtype = tf.int32)\n", + "j = tf.keras.Input(shape=(20,), batch_size=1)\n", + "k = tf.keras.Input(shape=(20,), batch_size=1, dtype = tf.int32)\n", + "l = tf.keras.Input(shape=(20,), batch_size=1)\n", + "\n", + "# The following is done to ensure that the batch size of input in\n", + "# tflite graph is 1\n", + "net = tf.keras.Model(inputs=[i, j,k,l,], outputs=new_qat_model.call([i,j,k,l]))\n", + "\n", + "converter = tf.lite.TFLiteConverter.from_keras_model(net)\n", + "converter.optimizations = [tf.lite.Optimize.DEFAULT]\n", + "\n", + "# The following two lines ensure that the mask inputs\n", + "# and the output are int8\n", + "converter.inference_input_type = tf.int8\n", + "converter.inference_output_type = tf.int8\n", + "\n", + "# Toggle this option to fold/unfold batchmatmul\n", + "converter._experimental_disable_batchmatmul_unfold = True\n", + "\n", + "tflite_model = converter.convert()\n", + "with open(MODEL_PATH, \"wb+\") as tflite_file:\n", + " tflite_file.write(tflite_model)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "(b) Evaluate Performance\n", + "\n", + "NOTE: These steps are slow to execute therefore, the number of samples on which evaluation is performed is set to 200 by default (but definitely can be modified by the user)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "get_tflite_accuracy(MODEL_PATH)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "get_text_result_tflite(MODEL_PATH)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "print(\"Model size: \", get_gzipped_model_size(MODEL_PATH), ' KB')" + ] + } + ], + "metadata": { + "interpreter": { + "hash": "cb8a88e82453314166ba9bb471422eb5142b42682e25f5c554fbcb3447334d71" + }, + "kernelspec": { + "display_name": "Python 3.6.9 64-bit ('venv': venv)", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.6.9" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +}