-- Train.lua (forked from SeanNaren/deepspeech.torch)
--[[Trains the CTC model on the LibriSpeech corpus (LMDB paths configured below).]]
local Network = require 'Network'
-- Training parameters
require 'torch'
require 'cutorch' -- needed for cutorch.manualSeedAll below
torch.setdefaulttensortype('torch.FloatTensor')
local seed = 10
torch.manualSeed(seed)
cutorch.manualSeedAll(seed)
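-- Seeding note: torch.manualSeed seeds the CPU RNG and cutorch.manualSeedAll
-- seeds the RNG on every visible GPU, so repeated runs are reproducible up to
-- non-deterministic cudnn kernels.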
local cmd = torch.CmdLine()
cmd:text()
cmd:text('Train a Deep Speech Model')
cmd:text()
cmd:text('Options')
-- generals
cmd:option('-nGPU', 4, 'number of GPUs to use')
cmd:option('-backend', 'cudnn', 'use nn or cudnn')
cmd:option('-saveModel', true, 'save the model?')
cmd:option('-loadModel', false, 'load a previously saved model?')
-- paths: data and logs
cmd:option('-trainingSetLMDBPath', '/data1/zhirongw/LibriSpeech/train/', 'training lmdb')
cmd:option('-validationSetLMDBPath', '/data1/zhirongw/LibriSpeech/test/', 'validation lmdb')
cmd:option('-logsTrainPath', './logs/TrainingLoss/', 'loss log directory')
cmd:option('-logsValidationPath', './logs/ValidationScores/', 'scores log directory')
cmd:option('-modelTrainingPath', './models/', 'model snapshot directory')
cmd:option('-fileName', 'CTCNetwork.t7', 'model snapshot filename')
cmd:option('-dictionaryPath', './dictionary', 'dictionary path')
-- Model
cmd:option('-feature', 'spect', 'input feature of the sound wave')
cmd:option('-dataHeight', 161, 'feature dimension')
cmd:option('-dictSize', 29, 'language dictionary size')
cmd:option('-modelName', 'DeepSpeechModelSpect', 'model architecture')
cmd:option('-hidden_size', 400, 'hidden memory size')
cmd:option('-num_layers', 5, 'number of rnn layers')
cmd:option('-rnn_type', 'RNN_RELU', 'RNN_RELU, GRU, or LSTM?')
-- configs
cmd:option('-batchSize', 200, 'training batch size')
cmd:option('-epochs', 70, 'training epochs')
cmd:option('-validationBatchSize', 24, 'validation batch size')
cmd:option('-saveModelIterations', 10, 'save a model snapshot every this many epochs')
-- optim
cmd:option('-optim','sgd','optimization algorithm')
cmd:option('-learning_rate',1e-1,'learning rate')
cmd:option('-learning_rate_decay',1e-9,'learning rate decay')
cmd:option('-learning_rate_decay_after', 20, 'epoch at which to start decaying the learning rate')
cmd:option('-learning_rate_decay_every', 20, 'decay the learning rate every this many epochs thereafter')
cmd:option('-beta1',0.8,'beta1 for adam')
cmd:option('-beta2',0.95,'beta2 for adam')
cmd:option('-alpha',0.8,'alpha for rmsprop')
cmd:option('-weight_decay',0,'weight decay')
cmd:option('-decay_rate',0.95,'decay rate for rmsprop')
cmd:option('-grad_clip',1,'clip gradients at this value') -- not used
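-- Note: -grad_clip is parsed but never applied in this script. A hypothetical
-- way to wire it into the optimizer closure (assuming a flattened gradient
-- tensor called gradParameters) would be:
--   gradParameters:clamp(-opts.grad_clip, opts.grad_clip)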
cmd:text()
local opts = cmd:parse(arg)
print(opts)
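-- cmd:parse strips the leading dashes, so the options above are available as
-- e.g. opts.nGPU, opts.batchSize and opts.rnn_type (presumably consumed inside
-- Network:init).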
--Create and train the network based on the parameters and training data.
Network:init(opts)
Network:trainNetwork()
--Creates the loss plot.
Network:createLossGraph()
print("finished")