-
Notifications
You must be signed in to change notification settings - Fork 87
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Use time-domain stretching in KWS20 data loader; add higher performan…
…ce KWS20 v3 network (#53) Use PyTSMod for time-scale stretching for augmentation of the KWS dataset. This replaces the librosa.effects-based function which brings an undesired echo effect due to phase mismatch in the Fourier domain.
- Loading branch information
1 parent
c7a1a4e
commit ac57996
Showing
6 changed files
with
122 additions
and
5 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,96 @@ | ||
################################################################################################### | ||
# | ||
# Copyright (C) 2020 Maxim Integrated Products, Inc. All Rights Reserved. | ||
# | ||
# Maxim Integrated Products, Inc. Default Copyright Notice: | ||
# https://www.maximintegrated.com/en/aboutus/legal/copyrights.html | ||
# | ||
################################################################################################### | ||
""" | ||
Keyword spotting network for AI85/AI86 | ||
""" | ||
import torch.nn as nn | ||
|
||
import ai8x | ||
|
||
|
||
class AI85KWS20Netv3(nn.Module): | ||
""" | ||
Compound KWS20 v3 Audio net, all with Conv1Ds | ||
""" | ||
|
||
# num_classes = n keywords + 1 unknown | ||
def __init__( | ||
self, | ||
num_classes=21, | ||
num_channels=128, | ||
dimensions=(128, 1), # pylint: disable=unused-argument | ||
bias=False, | ||
**kwargs | ||
|
||
): | ||
super().__init__() | ||
self.drop = nn.Dropout(p=0.2) | ||
# Time: 128 Feature :128 | ||
self.voice_conv1 = ai8x.FusedConv1dReLU(num_channels, 100, 1, stride=1, padding=0, | ||
bias=bias, **kwargs) | ||
# T: 128 F: 100 | ||
self.voice_conv2 = ai8x.FusedConv1dReLU(100, 96, 3, stride=1, padding=0, | ||
bias=bias, **kwargs) | ||
# T: 126 F : 96 | ||
self.voice_conv3 = ai8x.FusedMaxPoolConv1dReLU(96, 64, 3, stride=1, padding=1, | ||
bias=bias, **kwargs) | ||
# T: 62 F : 64 | ||
self.voice_conv4 = ai8x.FusedConv1dReLU(64, 48, 3, stride=1, padding=0, | ||
bias=bias, **kwargs) | ||
# T : 60 F : 48 | ||
self.kws_conv1 = ai8x.FusedMaxPoolConv1dReLU(48, 64, 3, stride=1, padding=1, | ||
bias=bias, **kwargs) | ||
# T: 30 F : 64 | ||
self.kws_conv2 = ai8x.FusedConv1dReLU(64, 96, 3, stride=1, padding=0, | ||
bias=bias, **kwargs) | ||
# T: 28 F : 96 | ||
self.kws_conv3 = ai8x.FusedAvgPoolConv1dReLU(96, 100, 3, stride=1, padding=1, | ||
bias=bias, **kwargs) | ||
# T : 14 F: 100 | ||
self.kws_conv4 = ai8x.FusedMaxPoolConv1dReLU(100, 64, 6, stride=1, padding=1, | ||
bias=bias, **kwargs) | ||
# T : 2 F: 128 | ||
self.fc = ai8x.Linear(256, num_classes, bias=bias, wide=True, **kwargs) | ||
|
||
def forward(self, x): # pylint: disable=arguments-differ | ||
"""Forward prop""" | ||
# Run CNN | ||
x = self.voice_conv1(x) | ||
x = self.voice_conv2(x) | ||
x = self.drop(x) | ||
x = self.voice_conv3(x) | ||
x = self.voice_conv4(x) | ||
x = self.drop(x) | ||
x = self.kws_conv1(x) | ||
x = self.kws_conv2(x) | ||
x = self.drop(x) | ||
x = self.kws_conv3(x) | ||
x = self.kws_conv4(x) | ||
x = x.view(x.size(0), -1) | ||
x = self.fc(x) | ||
|
||
return x | ||
|
||
|
||
def ai85kws20netv3(pretrained=False, **kwargs): | ||
""" | ||
Constructs a AI85KWS20Net model. | ||
rn AI85KWS20Net(**kwargs) | ||
""" | ||
assert not pretrained | ||
return AI85KWS20Netv3(**kwargs) | ||
|
||
|
||
models = [ | ||
{ | ||
'name': 'ai85kws20netv3', | ||
'min_input': 1, | ||
'dim': 1, | ||
}, | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -14,4 +14,5 @@ torchvision==0.8.2 | |
tensorboard==2.4.0 | ||
numba<0.50.0 | ||
opencv-python>=4.4.0 | ||
pytsmod>=0.3.3 | ||
-e distiller |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
#!/bin/sh | ||
./train.py --model ai85kws20netv3 --dataset KWS_20 --confusion --evaluate --exp-load-weights-from ../ai8x-synthesis/trained/ai85-kws20_v3-qat8-q.pth.tar -8 --device MAX78000 "$@" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
#!/bin/sh | ||
./train.py --epochs 200 --optimizer Adam --lr 0.001 --deterministic --compress schedule_kws20.yaml --model ai85kws20netv3 --dataset KWS_20 --confusion --device MAX78000 "$@" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters