diff --git a/.gitattributes b/.gitattributes
index a587138..abd8569 100644
--- a/.gitattributes
+++ b/.gitattributes
@@ -3,3 +3,4 @@
*.data* filter=lfs diff=lfs merge=lfs -text
*.index filter=lfs diff=lfs merge=lfs -text
*.npy filter=lfs diff=lfs merge=lfs -text
+*.h5 filter=lfs diff=lfs merge=lfs -text
diff --git a/README.md b/README.md
index c19ccf6..46802cd 100644
--- a/README.md
+++ b/README.md
@@ -269,7 +269,7 @@
Ethos U |
- Wav2letter Pruned INT8 |
+ Wav2letter Pruned INT8 * |
INT8 |
TensorFlow Lite |
:heavy_check_mark: |
diff --git a/models/speech_recognition/wav2letter/tflite_pruned_int8/README.md b/models/speech_recognition/wav2letter/tflite_pruned_int8/README.md
index 5bef923..0cd3255 100644
--- a/models/speech_recognition/wav2letter/tflite_pruned_int8/README.md
+++ b/models/speech_recognition/wav2letter/tflite_pruned_int8/README.md
@@ -10,12 +10,15 @@ Wav2letter is a convolutional speech recognition neural network. This implementa
### Class Labels
The class labels associated with this model can be downloaded by running the script `get_class_labels.sh`.
+### Model Recreation Code
+Code to recreate this model can be found [here](recreate_model/).
+
## Network Information
| Network Information | Value |
|---------------------|----------------|
| Framework | TensorFlow Lite |
-| SHA-1 Hash | e389797705f5f8a7973c3280954dd5cdf54284a1 |
-| Size (Bytes) | 23815520 |
+| SHA-1 Hash | 1771d122ba1ed9354188491e6efbcbd31cc8ba69 |
+| Size (Bytes) | 23766192 |
| Provenance | https://github.com/ARM-software/ML-zoo/tree/master/models/speech_recognition/wav2letter/tflite_pruned_int8 |
| Paper | https://arxiv.org/abs/1609.03193 |
@@ -36,7 +39,7 @@ Dataset: LibriSpeech
| Metric | Value |
|--------|-------|
-| LER | 0.07981431 |
+| LER | 0.07831 |
## Optimizations
| Optimization | Value |
@@ -52,7 +55,7 @@ Dataset: LibriSpeech
Description |
- input_2_int8 |
+ input_4 |
(1, 296, 39) |
Speech converted to MFCCs and quantized to INT8 |
@@ -66,7 +69,7 @@ Dataset: LibriSpeech
Description |
- Identity_int8 |
+ Identity |
(1, 1, 148, 29) |
A tensor of (batch, time, class probabilities) that represents the probability of each class at each timestep. Should be passed to a decoder e.g. ctc_beam_search_decoder. |
diff --git a/models/speech_recognition/wav2letter/tflite_pruned_int8/definition.yaml b/models/speech_recognition/wav2letter/tflite_pruned_int8/definition.yaml
index c2abcde..a2f0ce0 100644
--- a/models/speech_recognition/wav2letter/tflite_pruned_int8/definition.yaml
+++ b/models/speech_recognition/wav2letter/tflite_pruned_int8/definition.yaml
@@ -1,25 +1,25 @@
benchmark:
LibriSpeech:
- LER: 0.07981431
+ LER: 0.07831443101167679
description: Wav2letter is a convolutional speech recognition neural network. This
implementation was created by Arm, pruned to 50% sparsity, fine-tuned and quantized
using the TensorFlow Model Optimization Toolkit.
license:
- Apache-2.0
network:
- file_size_bytes: 23815520
+ file_size_bytes: 23766192
filename: wav2letter_pruned_int8.tflite
framework: TensorFlow Lite
hash:
algorithm: sha1
- value: e389797705f5f8a7973c3280954dd5cdf54284a1
+ value: 1771d122ba1ed9354188491e6efbcbd31cc8ba69
provenance: https://github.com/ARM-software/ML-zoo/tree/master/models/speech_recognition/wav2letter/tflite_pruned_int8
network_parameters:
input_nodes:
- description: Speech converted to MFCCs and quantized to INT8
example_input:
- path: models/speech_recognition/wav2letter/tflite_pruned_int8/testing_input/input_2_int8
- name: input_2_int8
+ path: models/speech_recognition/wav2letter/tflite_pruned_int8/testing_input/input_4
+ name: input_4
shape:
- 1
- 296
@@ -29,13 +29,13 @@ network_parameters:
- description: A tensor of (batch, time, class probabilities) that represents the
probability of each class at each timestep. Should be passed to a decoder e.g.
ctc_beam_search_decoder.
- name: Identity_int8
+ name: Identity
shape:
- 1
- 1
- 148
- 29
- test_output_path: models/speech_recognition/wav2letter/tflite_pruned_int8/testing_output/Identity_int8
+ test_output_path: models/speech_recognition/wav2letter/tflite_pruned_int8/testing_output/Identity
operators:
TensorFlow Lite:
- CONV_2D
diff --git a/models/speech_recognition/wav2letter/tflite_pruned_int8/recreate_model/README.md b/models/speech_recognition/wav2letter/tflite_pruned_int8/recreate_model/README.md
new file mode 100644
index 0000000..50e0d64
--- /dev/null
+++ b/models/speech_recognition/wav2letter/tflite_pruned_int8/recreate_model/README.md
@@ -0,0 +1,11 @@
+# Wav2letter Pruned INT8 Model Re-Creation
+This folder contains a script that allows the model to be re-created from scratch.
+
+## Requirements
+The script in this folder requires the following to be installed:
+- Python 3.6
+- CUDA 11.2
+- SoX
+
+## Running The Script
+To run the script, run the following in a terminal: `./recreate_model.sh`
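+
+The shell script creates a Python virtual environment, installs the pinned requirements, imports LibriSpeech using Mozilla DeepSpeech's `import_librivox.py`, and then runs `prune_quantize_model.py`. Once the dataset and requirements are in place, the Python script can also be run on its own; besides the required `--data_dir` it accepts optional `--batch_size`, `--finetuning_epochs`, `--no_stride_count` and `--sparsity` arguments (the values shown below are its defaults):
+
+```bash
+python prune_quantize_model.py --data_dir ${HOME}/librispeech --batch_size 32 --finetuning_epochs 1 --sparsity 0.5
+```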
\ No newline at end of file
diff --git a/models/speech_recognition/wav2letter/tflite_pruned_int8/recreate_model/librispeech_mfcc.py b/models/speech_recognition/wav2letter/tflite_pruned_int8/recreate_model/librispeech_mfcc.py
new file mode 100644
index 0000000..406c2ce
--- /dev/null
+++ b/models/speech_recognition/wav2letter/tflite_pruned_int8/recreate_model/librispeech_mfcc.py
@@ -0,0 +1,217 @@
+# Copyright (C) 2021 Arm Limited or its affiliates. All rights reserved.
+#
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the License); you may
+# not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an AS IS BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+import pandas
+import os
+import librosa
+import tensorflow as tf
+
+
+def normalize(values):
+ """ Normalize values to mean 0 and std 1. """
+ return (values - np.mean(values)) / np.std(values)
+
+
+def overlap(batch_x, n_context=296, n_input=39):
+ """
+    Due to the requirement of static shapes (see fix_batch_size()), we need
+    to stack the dynamic data to form a static input shape, using an
+    n_context of 296 (1 second of mfcc).
+ """
+ window_width = n_context
+ num_channels = n_input
+
+ batch_x = tf.expand_dims(batch_x, axis=0)
+ # Create a constant convolution filter using an identity matrix, so that the
+ # convolution returns patches of the input tensor as is, and we can create
+ # overlapping windows over the MFCCs.
+ eye_filter = tf.constant(
+ np.eye(window_width * num_channels).reshape(
+ window_width, num_channels, window_width * num_channels
+ ),
+ tf.float32,
+ )
+ # Create overlapping windows
+ batch_x = tf.nn.conv1d(batch_x, eye_filter, stride=1, padding="SAME")
+ # Remove dummy depth dimension and reshape into
+ # [n_windows, n_input]
+ batch_x = tf.reshape(batch_x, [-1, num_channels])
+
+ return batch_x
+
+
+def label_from_string(str_to_label: dict, string: str) -> int:
+ try:
+ return str_to_label[string]
+ except KeyError as e:
+ raise KeyError(
+            f"ERROR: String '{string}' from the transcripts does not occur in the alphabet!"
+ ).with_traceback(e.__traceback__)
+
+
+def text_to_int_array_wrapper(alphabet_dict: dict):
+ def text_to_int_array(original: str):
+ r"""
+ Given a Python string ``original``, map characters
+ to integers and return a numpy array representing the processed string.
+ """
+
+ return np.asarray([label_from_string(alphabet_dict, c) for c in original])
+ return text_to_int_array
+
+
+class LibriSpeechMfcc:
+ def __init__(self, data_dir: str):
+ """
+ Args:
+ data_dir: Absolute path to librispeech data folder
+ """
+ self.overlap = False
+ self.data_dir = data_dir
+ self.seed = 0
+ self.train = False
+ self.batch_size = 32
+ self.num_samples = 0
+ self.input_files = []
+
+ def training_set(self, overlap=False, batch_size=32):
+ """
+ Args:
+ overlap: boolean to create overlapping windows
+ batch_size: batch size required for the set
+ """
+ self.input_files = [
+ "librivox-train-clean-100.csv",
+ "librivox-train-clean-360.csv",
+ "librivox-train-other-500.csv",
+ ]
+
+ self.train = True
+ self.overlap = overlap
+ self.batch_size = batch_size
+ self.num_samples = 281241
+ return self.create_dataset()
+
+ def evaluation_set(self, overlap=False, batch_size=32):
+ """
+ Args:
+ overlap: boolean to create overlapping windows
+ batch_size: batch size required for the set
+ """
+ self.input_files = ["librivox-test-clean.csv"]
+
+ self.train = False
+ self.overlap = overlap
+ self.batch_size = batch_size
+ self.num_samples = 2620
+ return self.create_dataset()
+
+ def validation_set(self, overlap=False, batch_size=32):
+ """
+ Args:
+ overlap: boolean to create overlapping windows
+ batch_size: batch size required for the set
+ """
+ self.input_files = ["librivox-dev-clean.csv"]
+
+ self.train = False
+ self.overlap = overlap
+ self.batch_size = batch_size
+ self.num_samples = 2703
+ return self.create_dataset()
+
+ def create_dataset(self) -> tf.data.Dataset:
+ """Create dataset generator for use in fit and evaluation functions."""
+ df = self.read_input_files()
+ df.sort_values(by="wav_filesize", inplace=True)
+
+ # Convert to character index arrays
+ alphabet = "abcdefghijklmnopqrstuvwxyz' @"
+ alphabet_dict = {c: ind for (ind, c) in enumerate(alphabet)}
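+        # The 29-character alphabet maps a-z, apostrophe and space to indices
+        # 0-27; the final '@' (index 28) doubles as the padding value and the
+        # CTC blank class (the network's last of its 29 output classes).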
+ df["transcript"] = df["transcript"].apply(text_to_int_array_wrapper(alphabet_dict))
+
+ def generate_values():
+ for _, row in df.iterrows():
+ yield row.wav_filename, row.transcript
+
+ dataset = tf.data.Dataset.from_generator(
+ generate_values, output_types=(tf.string, tf.int32)
+ )
+        # librosa.feature.mfcc is expensive to run, so shuffle the lightweight
+        # (filename, transcript) pairs before applying the mapping function
+ if self.train:
+ dataset = dataset.shuffle(
+ buffer_size=max(self.batch_size * 2, 1024), seed=self.seed
+ )
+
+ dataset = dataset.map(
+ lambda filename, transcript: tf.py_function(
+ self.load_data_mfcc,
+ inp=[filename, transcript],
+ Tout=[tf.float32, tf.int32],
+ )
+ )
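+        # Pad MFCC frames with 0.0 and label sequences with 28 (the padding/
+        # blank class) so every element in a batch has the same shape.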
+ dataset = dataset.padded_batch(
+ self.batch_size,
+ padded_shapes=(tf.TensorShape([None, 39]), tf.TensorShape([None])),
+ padding_values=(0.0, 28), drop_remainder=True
+ )
+        # Note: shuffling (when enabled) was applied before the mapping function
+ return dataset
+
+ def read_input_files(self) -> pandas.DataFrame:
+ """Read the input files required for a particular set."""
+ source_data = None
+ for csv in self.input_files:
+ file = pandas.read_csv(os.path.join(self.data_dir, csv), encoding="utf-8", na_filter=False)
+ csv_dir = os.path.dirname(os.path.abspath(csv))
+ file["wav_filename"] = file["wav_filename"].str.replace(
+ r"(^[^/])", lambda m: os.path.join(csv_dir, m.group(1))
+ ) # pylint: disable=cell-var-from-loop
+ if source_data is None:
+ source_data = file
+ else:
+ source_data = source_data.append(file)
+ return source_data
+
+ def num_steps(self, batch):
+ """
+ Get the number of steps based on the given batch size and the number
+ of samples.
+ """
+ return int(np.math.ceil(self.num_samples / batch))
+
+ def load_data_mfcc(self, filename, transcript):
+ """ Calculate mfcc from the given raw audio data for Wav2Letter. """
+ audio_data, samplerate = librosa.load(filename.numpy(), sr=16000)
+ mfcc = librosa.feature.mfcc(
+ audio_data, sr=samplerate, n_mfcc=13, n_fft=512, hop_length=160
+ )
+ mfcc_delta = librosa.feature.delta(mfcc)
+ mfcc_delta2 = librosa.feature.delta(mfcc, order=2)
+ mfcc = np.concatenate(
+ (normalize(mfcc), normalize(mfcc_delta), normalize(mfcc_delta2)), axis=0
+ )
+
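+        # 13 MFCCs plus their first- and second-order deltas give the 39
+        # features per frame the model expects; seq_length is halved because
+        # the network's first convolution strides by 2 in time.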
+ seq_length = mfcc.shape[1] // 2
+ sequences = np.concatenate([[seq_length], transcript]).astype(np.int32)
+ mfcc_out = (
+ overlap(mfcc.T.astype(np.float32))
+ if self.overlap
+ else mfcc.T.astype(np.float32)
+ )
+ return mfcc_out, sequences
diff --git a/models/speech_recognition/wav2letter/tflite_pruned_int8/recreate_model/prune_quantize_model.py b/models/speech_recognition/wav2letter/tflite_pruned_int8/recreate_model/prune_quantize_model.py
new file mode 100644
index 0000000..6504202
--- /dev/null
+++ b/models/speech_recognition/wav2letter/tflite_pruned_int8/recreate_model/prune_quantize_model.py
@@ -0,0 +1,334 @@
+# Copyright (C) 2021 Arm Limited or its affiliates. All rights reserved.
+#
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the License); you may
+# not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an AS IS BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Wav2letter optimization and evaluation script"""
+import argparse
+import datetime
+import multiprocessing
+import os
+import pathlib
+
+import tensorflow as tf
+import tensorflow_model_optimization as tfmot
+from tqdm import tqdm
+import numpy as np
+
+from wav2letter import create_wav2letter, get_metrics
+from librispeech_mfcc import LibriSpeechMfcc
+
+
+def log(std):
+ """Log the given string to the standard output."""
+ print("******* {}".format(std), flush=True)
+
+
+def create_directories(paths):
+ """Directory creation"""
+ for path in paths:
+ os.makedirs(path, exist_ok=True)
+
+
+def get_data(dataset_type, overlap=False):
+ """Returns particular training, validation and evaluation dataset."""
+ dataset = LibriSpeechMfcc(args.data_dir)
+
+ return {"train": [dataset.training_set(batch_size=args.batch_size, overlap=overlap),
+ dataset.num_steps(batch=args.batch_size)],
+ "val": [dataset.validation_set(batch_size=args.batch_size, overlap=overlap),
+ dataset.num_steps(batch=args.batch_size)],
+ "eval": [dataset.evaluation_set(batch_size=1, overlap=overlap),
+ dataset.num_steps(batch=1)]
+ }[dataset_type]
+
+
+def setup_callbacks(checkpoint_path, log_dir):
+ """Returns callbacks for baseline training and optimization fine-tuning."""
+ callbacks = [
+ tf.keras.callbacks.TerminateOnNaN(),
+ tf.keras.callbacks.ModelCheckpoint(
+ filepath=checkpoint_path,
+ verbose=1,
+ save_weights_only=True,
+ period=1, # save every epoch
+ ),
+ tf.keras.callbacks.TensorBoard(
+ log_dir=log_dir,
+ histogram_freq=1, # update every epoch
+            update_freq=100, # update every 100 batches
+ ),
+ ]
+ return callbacks
+
+
+def get_lr_schedule(steps_per_epoch, learning_rate=1e-5, lr_schedule_config=[[1.0, 0.1, 0.01, 0.001]]):
+ """Returns learn rate schedule for baseline training and optimization fine-tuning."""
+ initial_learning_rate = learning_rate
+ lr_schedule = tf.keras.optimizers.schedules.PiecewiseConstantDecay(
+ boundaries=list(p[1] * steps_per_epoch for p in lr_schedule_config),
+ values=[initial_learning_rate] + list(p[0] * initial_learning_rate for p in lr_schedule_config))
+ return lr_schedule
+
+
+def prune_model(model):
+ """Performs pruning, fine-tuning and returns stripped pruned model"""
+ log("Pruning model to {} sparsity".format(args.sparsity))
+ (training_data, training_num_steps) = get_data("train")
+ (validation_data, validation_num_steps) = get_data("val")
+ (evaluation_data, eval_num_steps) = get_data("eval")
+ log_dir = "logs/pruned" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
+ checkpoint_path = "checkpoints_pruned"
+ export_checkpoint_path = "checkpoints_export"
+
+ create_directories([log_dir, checkpoint_path, export_checkpoint_path])
+
+ callbacks = setup_callbacks(os.path.join(checkpoint_path, "pruned-{epoch:04d}.h5"), log_dir)
+
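+    # ConstantSparsity keeps the target sparsity fixed from step 0 and
+    # re-applies the pruning mask every 10 steps until roughly 70% of the
+    # first fine-tuning epoch has elapsed.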
+ pruning_params = {
+ "pruning_schedule": tfmot.sparsity.keras.ConstantSparsity(
+ args.sparsity, begin_step=0, end_step=int(training_num_steps * 0.7), frequency=10
+ )
+ }
+
+ callbacks.append(tfmot.sparsity.keras.UpdatePruningStep())
+ pruned_model = tfmot.sparsity.keras.prune_low_magnitude(model, **pruning_params)
+
+ opt = tf.keras.optimizers.Adam(learning_rate=get_lr_schedule(steps_per_epoch=training_num_steps))
+ pruned_model.compile(
+ loss=get_metrics("loss"), metrics=[get_metrics("ler")], optimizer=opt)
+
+ pruned_model.fit(
+ training_data,
+ epochs=args.finetuning_epochs,
+ steps_per_epoch=training_num_steps,
+ verbose=1,
+ callbacks=callbacks,
+ validation_data=validation_data,
+ validation_steps=validation_num_steps,
+ )
+
+ log("Evaluating {}".format(model.name))
+ pruned_model.evaluate(x=evaluation_data, steps=eval_num_steps)
+
+ stripped_model = tfmot.sparsity.keras.strip_pruning(model)
+
+ stripped_model.save_weights(os.path.join(export_checkpoint_path, "pruned-{}.h5".format(str(args.finetuning_epochs))))
+
+ return stripped_model
+
+
+def prepare_model_for_inference(model):
+
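+    # Rebuild the network around a fixed (1, 296, 39) input so the TFLite
+    # converter sees a static input shape; the training model accepts
+    # variable-length sequences.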
+ layer_input = tf.keras.layers.Input((296, 39), batch_size=1)
+ static_shaped_model = tf.keras.models.Model(
+ inputs=[layer_input], outputs=[model.call(layer_input)]
+ )
+ return static_shaped_model
+
+
+def tflite_conversion(model, tflite_path, conversion_type="fp32"):
+ """Performs tflite conversion (fp32, int8)."""
+ # Prepare model for inference
+ model = prepare_model_for_inference(model)
+
+ create_directories([os.path.dirname(tflite_path)])
+ converter = tf.lite.TFLiteConverter.from_keras_model(model)
+
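+    # Calibration data for quantization: overlapped training MFCCs are sliced
+    # into chunks matching the model's 296-frame input window.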
+ def representative_dataset_gen(input_dim):
+ calib_data = []
+ for data in tqdm(training_data.take(1000), desc="model calibration"):
+ input_data = data[0]
+ for i in range(input_data.shape[1] // input_dim):
+ input_chunks = [
+ input_data[:, i * input_dim: (i + 1) * input_dim, :, ]
+ ]
+ for chunk in input_chunks:
+ calib_data.append([chunk])
+
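+        # Note: 'yield' inside a list comprehension is a SyntaxError on
+        # Python >= 3.8; this construct relies on the Python 3.6 environment
+        # listed in the requirements.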
+ return lambda: [
+ (yield data) for data in tqdm(calib_data, desc="model calibration")
+ ]
+
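+    # The int8 path performs post-training full-integer quantization, with
+    # int8 inputs/outputs and the representative dataset used for calibration.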
+ if conversion_type == "int8":
+ log("Quantizing Model")
+ (training_data, training_num_steps) = get_data("train", overlap=True)
+ converter.optimizations = [tf.lite.Optimize.DEFAULT]
+ converter.inference_input_type = tf.int8
+ converter.inference_output_type = tf.int8
+ converter.representative_dataset = representative_dataset_gen(model.input_shape[1])
+
+ tflite_model = converter.convert()
+ open(tflite_path, "wb").write(tflite_model)
+
+
+def evaluate_tflite(tflite_path):
+ """Evaluates tflite (fp32, int8)."""
+ results = []
+ (evaluation_data, eval_num_steps) = get_data("eval")
+
+ log("Setting number of used threads to {}".format(multiprocessing.cpu_count()))
+ interpreter = tf.lite.Interpreter(
+ model_path=tflite_path, num_threads=multiprocessing.cpu_count()
+ )
+ interpreter.allocate_tensors()
+ input_chunk = interpreter.get_input_details()[0]
+ output_details = interpreter.get_output_details()[0]
+
+ input_shape = input_chunk["shape"]
+ log("eval_model() - input_shape: {}".format(input_shape))
+ input_dtype = input_chunk["dtype"]
+ output_dtype = output_details["dtype"]
+
+ # Check if the input/output type is quantized,
+ # set scale and zero-point accordingly
+ if input_dtype != tf.float32:
+ input_scale, input_zero_point = input_chunk["quantization"]
+ else:
+ input_scale, input_zero_point = 1, 0
+
+ if output_dtype != tf.float32:
+ output_scale, output_zero_point = output_details["quantization"]
+ else:
+ output_scale, output_zero_point = 1, 0
+
+ log("Running {} iterations".format(eval_num_steps))
+ for i_iter, (data, label) in enumerate(
+ tqdm(evaluation_data, total=eval_num_steps)
+ ):
+ data = data / input_scale + input_zero_point
+ # Round the data up if dtype is int8, uint8 or int16
+ if input_dtype is not np.float32:
+ data = np.round(data)
+
+ while data.shape[1] < 296:
+ data = np.append(data, data[:, -2:-1, :], axis=1)
+ # Zero-pad any odd-length inputs
+ if data.shape[1] % 2 == 1:
+ log('Input length is odd, zero-padding to even (first layer has stride 2)')
+ data = np.concatenate([data, np.zeros((1, 1, data.shape[2]), dtype=input_dtype)], axis=1)
+
+ context = 24 + 2 * (7 * 3 + 16) # = 98 - theoretical max receptive field on each side
+ size = input_chunk['shape'][1]
+ inner = size - 2 * context
+ data_end = data.shape[1]
+
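+        # The model takes a fixed 296-frame input, so longer utterances are
+        # processed as overlapping windows and only the non-context region of
+        # each window's output is kept before concatenation.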
+ # Initialize variables for the sliding window loop
+ data_pos = 0
+ outputs = []
+
+ while data_pos < data_end:
+ if data_pos == 0:
+                # Align inputs from the first window to the start of the data and include the initial context in the output
+ start = data_pos
+ end = start + size
+ y_start = 0
+ y_end = y_start + (size - context) // 2
+ data_pos = end - context
+ elif data_pos + inner + context >= data_end:
+ # Shift left to align final window to the end of the data and include the final context in the output
+ shift = (data_pos + inner + context) - data_end
+ start = data_pos - context - shift
+ end = start + size
+ assert start >= 0
+ y_start = (shift + context) // 2 # Will be even because we assert it above
+ y_end = size // 2
+ data_pos = data_end
+ else:
+ # Capture only the inner region from mid-input inferences, excluding output from both context regions
+ start = data_pos - context
+ end = start + size
+ y_start = context // 2
+ y_end = y_start + inner // 2
+ data_pos = end - context
+
+ interpreter.set_tensor(
+ input_chunk["index"], tf.cast(data[:, start:end, :], input_dtype))
+ interpreter.invoke()
+ cur_output_data = interpreter.get_tensor(output_details["index"])[:, :, y_start:y_end, :]
+ cur_output_data = output_scale * (
+ cur_output_data.astype(np.float32) - output_zero_point
+ )
+ outputs.append(cur_output_data)
+
+ complete = np.concatenate(outputs, axis=2)
+ results.append(get_metrics("ler")(label, complete))
+
+    log("Avg LER: {}%".format(np.mean(results) * 100))
+
+
+def main(args):
+ """Main execution function"""
+ # Model creation
+ model = create_wav2letter(
+ batch_size=args.batch_size, no_stride_count=args.no_stride_count
+ )
+ # load baseline wav2letter weights
+ model.load_weights("weights/wav2letter.h5")
+
+ # Prune model
+ pruned_model = prune_model(model)
+
+ # Pruned int8 wav2letter
+ output_directory = pathlib.Path(os.path.dirname(os.path.abspath(__file__))).parent.as_posix()
+ wav2letter_pruned_int8 = os.path.join(output_directory, "wav2letter_pruned_int8.tflite")
+ tflite_conversion(pruned_model, wav2letter_pruned_int8, "int8")
+ evaluate_tflite(wav2letter_pruned_int8)
+
+
+if __name__ == "__main__":
+
+ parser = argparse.ArgumentParser(allow_abbrev=False)
+ parser.add_argument(
+ "--batch_size",
+ dest="batch_size",
+ type=int,
+ required=False,
+ default=32,
+        help="Batch size to use when creating and training the model",
+ )
+ parser.add_argument(
+ "--finetuning_epochs",
+ dest="finetuning_epochs",
+ type=int,
+ required=False,
+ default=1,
+        help="Number of epochs of pruning fine-tuning",
+ )
+ parser.add_argument(
+ "--no_stride_count",
+ dest="no_stride_count",
+ type=int,
+ required=False,
+ default=7,
+ help="Number of Convolution2D layers without striding",
+ )
+ parser.add_argument(
+ "--sparsity",
+ dest="sparsity",
+ type=float,
+ required=False,
+ default=0.5,
+ help="Level of sparsity required",
+ )
+ parser.add_argument(
+ "--data_dir",
+ dest="data_dir",
+ type=str,
+ required=True,
+ help="Path to dataset directory",
+ )
+ args = parser.parse_args()
+ main(args)
diff --git a/models/speech_recognition/wav2letter/tflite_pruned_int8/recreate_model/recreate_model.sh b/models/speech_recognition/wav2letter/tflite_pruned_int8/recreate_model/recreate_model.sh
new file mode 100755
index 0000000..2fa61d9
--- /dev/null
+++ b/models/speech_recognition/wav2letter/tflite_pruned_int8/recreate_model/recreate_model.sh
@@ -0,0 +1,32 @@
+#!/usr/bin/env bash
+# Copyright (C) 2021 Arm Limited or its affiliates. All rights reserved.
+#
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the License); you may
+# not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an AS IS BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+python3.6 -m venv python_env
+
+source python_env/bin/activate
+
+pip install --upgrade pip
+pip install -r requirements.txt
+
+# Download and build dataset
+if [ ! -d "${HOME}/DeepSpeech" ] ; then
+ git clone https://github.com/mozilla/DeepSpeech.git ${HOME}/DeepSpeech
+fi
+PYTHONPATH=${HOME}/DeepSpeech/training python ${HOME}/DeepSpeech/bin/import_librivox.py ${HOME}/librispeech
+
+python prune_quantize_model.py --data_dir ${HOME}/librispeech
diff --git a/models/speech_recognition/wav2letter/tflite_pruned_int8/recreate_model/requirements.txt b/models/speech_recognition/wav2letter/tflite_pruned_int8/recreate_model/requirements.txt
new file mode 100644
index 0000000..abb0bdc
--- /dev/null
+++ b/models/speech_recognition/wav2letter/tflite_pruned_int8/recreate_model/requirements.txt
@@ -0,0 +1,7 @@
+tensorflow==2.5
+pandas==1.0.5
+librosa
+tensorflow_model_optimization
+tqdm
+progressbar2
+sox==1.3.7
diff --git a/models/speech_recognition/wav2letter/tflite_pruned_int8/recreate_model/wav2letter.py b/models/speech_recognition/wav2letter/tflite_pruned_int8/recreate_model/wav2letter.py
new file mode 100644
index 0000000..4d200ba
--- /dev/null
+++ b/models/speech_recognition/wav2letter/tflite_pruned_int8/recreate_model/wav2letter.py
@@ -0,0 +1,123 @@
+# Copyright (C) 2021 Arm Limited or its affiliates. All rights reserved.
+#
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the License); you may
+# not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an AS IS BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Model definition for Wav2Letter."""
+import tensorflow as tf
+from tensorflow.python.ops import ctc_ops
+
+
+def get_metrics(metric):
+ """Get metrics needed to compile wav2letter."""
+ def ctc_preparation(tensor, y_predict):
+ if len(y_predict.shape) == 4:
+ y_predict = tf.squeeze(y_predict, axis=1)
+ y_predict = tf.transpose(y_predict, (1, 0, 2))
+ sequence_lengths, labels = tensor[:, 0], tensor[:, 1:]
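+        # Labels equal to 28 are padding (see padded_batch in the dataset
+        # pipeline) and are dropped when building the sparse label tensor.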
+ idx = tf.where(tf.not_equal(labels, 28))
+ sparse_labels = tf.SparseTensor(
+ idx, tf.gather_nd(labels, idx), tf.shape(labels, out_type=tf.int64)
+ )
+ return sparse_labels, sequence_lengths, y_predict
+
+ def get_loss():
+ """Calculate CTC loss."""
+ def ctc_loss(y_true, y_predict):
+ sparse_labels, logit_length, y_predict = ctc_preparation(y_true, y_predict)
+ return tf.reduce_mean(
+ ctc_ops.ctc_loss_v2(
+ labels=sparse_labels,
+ logits=y_predict,
+ label_length=None,
+ logit_length=logit_length,
+ blank_index=-1,
+ )
+ )
+ return ctc_loss
+
+ def get_ler():
+ """Calculate CTC LER (Letter Error Rate)."""
+ def ctc_ler(y_true, y_predict):
+ sparse_labels, logit_length, y_predict = ctc_preparation(y_true, y_predict)
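+            # Greedy CTC decoding followed by normalized edit distance against
+            # the reference labels gives the letter error rate (LER).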
+ decoded, log_probabilities = tf.nn.ctc_greedy_decoder(
+ y_predict, tf.cast(logit_length, tf.int32), merge_repeated=True
+ )
+ return tf.reduce_mean(
+ tf.edit_distance(
+ tf.cast(decoded[0], tf.int32), tf.cast(sparse_labels, tf.int32)
+ )
+ )
+ return ctc_ler
+ return {"loss": get_loss(), "ler": get_ler()}[metric]
+
+
+def create_wav2letter(batch_size=1, no_stride_count=7) -> tf.keras.models.Model:
+ """Create and return Wav2Letter model"""
+ layer = tf.keras.layers
+ leaky_relu = layer.LeakyReLU([0.20000000298023224])
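+    # The same LeakyReLU instance (alpha ~0.2) is reused as the activation
+    # after every convolution below.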
+ input = layer.Input(shape=[None, 39], batch_size=batch_size)
+ # Reshape to prepare input for first layer
+ x = layer.Reshape([1, -1, 39])(input)
+ # One striding layer of output size [batch_size, max_time / 2, 250]
+ x = layer.Conv2D(
+ filters=250,
+ kernel_size=[1, 48],
+ padding="same",
+ activation=None,
+ strides=[1, 2],
+ )(x)
+ # Add non-linearity
+ x = leaky_relu(x)
+ # layers without striding of output size [batch_size, max_time / 2, 250]
+ for i in range(0, no_stride_count):
+ x = layer.Conv2D(
+ filters=250,
+ kernel_size=[1, 7],
+ padding="same",
+ activation=None,
+ strides=[1, 1],
+ )(x)
+ # Add non-linearity
+ x = leaky_relu(x)
+ # 1 layer with high kernel width and output size [batch_size, max_time / 2, 2000]
+ x = layer.Conv2D(
+ filters=2000,
+ kernel_size=[1, 32],
+ padding="same",
+ activation=None,
+ strides=[1, 1],
+ )(x)
+ # Add non-linearity
+ x = leaky_relu(x)
+ # 1 layer of output size [batch_size, max_time / 2, 2000]
+ x = layer.Conv2D(
+ filters=2000,
+ kernel_size=[1, 1],
+ padding="same",
+ activation=None,
+ strides=[1, 1],
+ )(x)
+ # Add non-linearity
+ x = leaky_relu(x)
+ # 1 layer of output size [batch_size, max_time / 2, num_classes]
+    # We must not apply a non-linearity in this last layer
+ x = layer.Conv2D(
+ filters=29,
+ kernel_size=[1, 1],
+ padding="same",
+ activation=None,
+ strides=[1, 1],
+ )(x)
+ return tf.keras.models.Model(inputs=[input], outputs=[x])
diff --git a/models/speech_recognition/wav2letter/tflite_pruned_int8/recreate_model/weights/wav2letter.h5 b/models/speech_recognition/wav2letter/tflite_pruned_int8/recreate_model/weights/wav2letter.h5
new file mode 100644
index 0000000..c54cb4a
--- /dev/null
+++ b/models/speech_recognition/wav2letter/tflite_pruned_int8/recreate_model/weights/wav2letter.h5
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a2cfc6ed8c9a1b97b2b9c52c08e0ffad5eb811d0548216276d1d45c169e81262
+size 94420016
diff --git a/models/speech_recognition/wav2letter/tflite_pruned_int8/testing_input/input_2_int8/0.npy b/models/speech_recognition/wav2letter/tflite_pruned_int8/testing_input/input_2_int8/0.npy
deleted file mode 100644
index 529966a..0000000
--- a/models/speech_recognition/wav2letter/tflite_pruned_int8/testing_input/input_2_int8/0.npy
+++ /dev/null
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:84e23e7199b96820ed3f69c08a1832aa353976506ae0ab3333c975eb916e84ad
-size 11672
diff --git a/models/speech_recognition/wav2letter/tflite_pruned_int8/testing_input/input_4/0.npy b/models/speech_recognition/wav2letter/tflite_pruned_int8/testing_input/input_4/0.npy
new file mode 100644
index 0000000..352f85a
--- /dev/null
+++ b/models/speech_recognition/wav2letter/tflite_pruned_int8/testing_input/input_4/0.npy
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:61ad67cf696943bf6779d520979a657a785fed2322863ba3f693f616dc49917d
+size 11672
diff --git a/models/speech_recognition/wav2letter/tflite_pruned_int8/testing_output/Identity/0.npy b/models/speech_recognition/wav2letter/tflite_pruned_int8/testing_output/Identity/0.npy
new file mode 100644
index 0000000..a030327
--- /dev/null
+++ b/models/speech_recognition/wav2letter/tflite_pruned_int8/testing_output/Identity/0.npy
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:42a1ed1500282b9cc1ebb88f0658af5a5e3e6189985c97c2ad7521ed9789a458
+size 4420
diff --git a/models/speech_recognition/wav2letter/tflite_pruned_int8/testing_output/Identity_int8/0.npy b/models/speech_recognition/wav2letter/tflite_pruned_int8/testing_output/Identity_int8/0.npy
deleted file mode 100644
index 9f4d9e4..0000000
--- a/models/speech_recognition/wav2letter/tflite_pruned_int8/testing_output/Identity_int8/0.npy
+++ /dev/null
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:f813c6c406b355a10bb51974c660622ea2dfec0b5e6993c6ff02f30889ae1a9e
-size 4420
diff --git a/models/speech_recognition/wav2letter/tflite_pruned_int8/wav2letter_pruned_int8.tflite b/models/speech_recognition/wav2letter/tflite_pruned_int8/wav2letter_pruned_int8.tflite
index 0f045ac..0821808 100644
--- a/models/speech_recognition/wav2letter/tflite_pruned_int8/wav2letter_pruned_int8.tflite
+++ b/models/speech_recognition/wav2letter/tflite_pruned_int8/wav2letter_pruned_int8.tflite
@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
-oid sha256:e0814a0e586f2f2881f7fe96e6bc0b8e4cfecc8d64ced56d8f7ca85b9b7ab257
-size 23815520
+oid sha256:1acd36de9271fbb1ca5b1c3c5c36379da8a77563923a4d6dee77a7dac6ffc2a8
+size 23766192