Skip to content

Commit

Permalink
Use time-domain stretching in KWS20 data loader; add higher performan…
Browse files Browse the repository at this point in the history
…ce KWS20 v3 network (#53)

Use PyTSMod for time-scale stretching for augmentation of the KWS dataset. This replaces the librosa.effects-based function which brings an undesired echo effect due to phase mismatch in the Fourier domain.
  • Loading branch information
batuhan-gundogdu authored Dec 30, 2020
1 parent c7a1a4e commit ac57996
Show file tree
Hide file tree
Showing 6 changed files with 122 additions and 5 deletions.
23 changes: 18 additions & 5 deletions datasets/kws20.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
from torchvision import transforms

import librosa
import pytsmod as tsm
from six.moves import urllib

import ai8x
Expand Down Expand Up @@ -118,7 +119,7 @@ def __parse_quantization(self, quantization_scheme):
if 'compand' not in self.quantization:
self.quantization['compand'] = False
elif 'mu' not in self.quantization:
self.quantization['mu'] = 10
self.quantization['mu'] = 255
else:
print('No define quantization schema!, ',
'Number of bits set to 8.')
Expand Down Expand Up @@ -338,7 +339,7 @@ def augment(self, audio, fs, verbose=False):
random_strech_coeff = np.random.uniform(self.augmentation['strech']['min'],
self.augmentation['strech']['max'])

aug_audio = self.stretch(audio, random_strech_coeff)
aug_audio = tsm.wsola(audio, random_strech_coeff)
aug_audio = self.shift(aug_audio, random_shift_time, fs)
aug_audio = self.add_white_noise(aug_audio, random_noise_var_coeff)
if verbose:
Expand All @@ -361,8 +362,14 @@ def compand(data, mu=255):
return data

@staticmethod
def quantize_audio(data, num_bits=8, compand=False, mu=10):
"""quantize audio
def expand(data, mu=255):
"""Undo the companding"""
data = np.sign(data) * (1 / mu) * (np.power((1 + mu), np.abs(data)) - 1)
return data

@staticmethod
def quantize_audio(data, num_bits=8, compand=False, mu=255):
"""Quantize audio
"""
if compand:
data = KWS.compand(data, mu)
Expand All @@ -371,6 +378,12 @@ def quantize_audio(data, num_bits=8, compand=False, mu=10):
max_val = 2**(num_bits) - 1
q_data = np.round((data - (-1.0)) / step_size)
q_data = np.clip(q_data, 0, max_val)

if compand:
data_ex = (q_data - 2**(num_bits - 1)) / 2**(num_bits - 1)
data_ex = KWS.expand(data_ex)
q_data = np.round((data_ex - (-1.0)) / step_size)
q_data = np.clip(q_data, 0, max_val)
return np.uint8(q_data)

def __gen_datasets(self, exp_len=16384, row_len=128, overlap_ratio=0):
Expand Down Expand Up @@ -543,7 +556,7 @@ def KWS_20_get_datasets(data, load_train=True, load_test=True):
The dataset is loaded from the archive file, so the file is required for this version.
The dataset originally includes 30 keywords. A dataset is formed with 21 classes which includes
The dataset originally includes 35 keywords. A dataset is formed with 21 classes which includes
20 of the original keywords and the rest of the dataset is used to form the last class, i.e.,
class of the others.
The dataset is split into training+validation and test sets. 90:10 training+validation:test
Expand Down
96 changes: 96 additions & 0 deletions models/ai85net-kws20-v3.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
###################################################################################################
#
# Copyright (C) 2020 Maxim Integrated Products, Inc. All Rights Reserved.
#
# Maxim Integrated Products, Inc. Default Copyright Notice:
# https://www.maximintegrated.com/en/aboutus/legal/copyrights.html
#
###################################################################################################
"""
Keyword spotting network for AI85/AI86
"""
import torch.nn as nn

import ai8x


class AI85KWS20Netv3(nn.Module):
"""
Compound KWS20 v3 Audio net, all with Conv1Ds
"""

# num_classes = n keywords + 1 unknown
def __init__(
self,
num_classes=21,
num_channels=128,
dimensions=(128, 1), # pylint: disable=unused-argument
bias=False,
**kwargs

):
super().__init__()
self.drop = nn.Dropout(p=0.2)
# Time: 128 Feature :128
self.voice_conv1 = ai8x.FusedConv1dReLU(num_channels, 100, 1, stride=1, padding=0,
bias=bias, **kwargs)
# T: 128 F: 100
self.voice_conv2 = ai8x.FusedConv1dReLU(100, 96, 3, stride=1, padding=0,
bias=bias, **kwargs)
# T: 126 F : 96
self.voice_conv3 = ai8x.FusedMaxPoolConv1dReLU(96, 64, 3, stride=1, padding=1,
bias=bias, **kwargs)
# T: 62 F : 64
self.voice_conv4 = ai8x.FusedConv1dReLU(64, 48, 3, stride=1, padding=0,
bias=bias, **kwargs)
# T : 60 F : 48
self.kws_conv1 = ai8x.FusedMaxPoolConv1dReLU(48, 64, 3, stride=1, padding=1,
bias=bias, **kwargs)
# T: 30 F : 64
self.kws_conv2 = ai8x.FusedConv1dReLU(64, 96, 3, stride=1, padding=0,
bias=bias, **kwargs)
# T: 28 F : 96
self.kws_conv3 = ai8x.FusedAvgPoolConv1dReLU(96, 100, 3, stride=1, padding=1,
bias=bias, **kwargs)
# T : 14 F: 100
self.kws_conv4 = ai8x.FusedMaxPoolConv1dReLU(100, 64, 6, stride=1, padding=1,
bias=bias, **kwargs)
# T : 2 F: 128
self.fc = ai8x.Linear(256, num_classes, bias=bias, wide=True, **kwargs)

def forward(self, x): # pylint: disable=arguments-differ
"""Forward prop"""
# Run CNN
x = self.voice_conv1(x)
x = self.voice_conv2(x)
x = self.drop(x)
x = self.voice_conv3(x)
x = self.voice_conv4(x)
x = self.drop(x)
x = self.kws_conv1(x)
x = self.kws_conv2(x)
x = self.drop(x)
x = self.kws_conv3(x)
x = self.kws_conv4(x)
x = x.view(x.size(0), -1)
x = self.fc(x)

return x


def ai85kws20netv3(pretrained=False, **kwargs):
"""
Constructs a AI85KWS20Net model.
rn AI85KWS20Net(**kwargs)
"""
assert not pretrained
return AI85KWS20Netv3(**kwargs)


models = [
{
'name': 'ai85kws20netv3',
'min_input': 1,
'dim': 1,
},
]
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,5 @@ torchvision==0.8.2
tensorboard==2.4.0
numba<0.50.0
opencv-python>=4.4.0
pytsmod>=0.3.3
-e distiller
2 changes: 2 additions & 0 deletions scripts/evaluate_kws20_v3.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
#!/bin/sh
./train.py --model ai85kws20netv3 --dataset KWS_20 --confusion --evaluate --exp-load-weights-from ../ai8x-synthesis/trained/ai85-kws20_v3-qat8-q.pth.tar -8 --device MAX78000 "$@"
2 changes: 2 additions & 0 deletions scripts/train_kws20_v3.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
#!/bin/sh
./train.py --epochs 200 --optimizer Adam --lr 0.001 --deterministic --compress schedule_kws20.yaml --model ai85kws20netv3 --dataset KWS_20 --confusion --device MAX78000 "$@"
3 changes: 3 additions & 0 deletions train_all_models.sh
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,9 @@ echo "-----------------------------"
echo "Training kws20_v2 model"
scripts/train_kws20_v2.sh
echo "-----------------------------"
echo "Training kws20_v3 model"
scripts/train_kws20_v3.sh
echo "-----------------------------"
echo "Training faceid model"
scripts/train_faceid.sh

0 comments on commit ac57996

Please sign in to comment.