#!/usr/bin/python
# Copyright 2014 Douglas Bagnall <[email protected]> LGPL
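#
# classify-train: command-line front end for training a "classify" net,
# a GStreamer-based recurrent neural network classifier.  The comments
# below summarise what the code does.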

import os, sys
import random
import argparse

from classify import gst_init
from classify import Trainer, lr_sqrt_exp, lr_inverse_time
from classify import add_args_from_classifier, range_arg
from classify import add_common_args, process_common_args

DEFAULT_LEARN_RATE = 3e-5
DEFAULT_LEARN_RATE_DECAY = 0
DEFAULT_LEARN_RATE_MIN = 1e-8
DEFAULT_MOMENTUM = 0.93
DEFAULT_LEARNING_STYLE = 0
DEFAULT_MOMENTUM_SOFT_START = 5000
DEFAULT_LEARN_RATE_TIME_OFFSET = 2000.0

def main():
    gst_init()
    parser = argparse.ArgumentParser()
    prop_names = add_common_args(parser)

    group = parser.add_argument_group('classify-train specific arguments')
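    # Each entry below names the option strings for one classifier
    # property; add_args_from_classifier generates the argparse entries.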
    add_args_from_classifier(group, (['-H', '--hidden-size'],
                                     ['-B', '--bottom-layer'],
                                     ['--pgm-dump'],
                                     ['-l', '--learn-rate'],
                                     ['--learning-style'],
                                     ['-m', '--momentum'],
                                     ['--momentum-soft-start'],
                                     ['--momentum-weight'],
                                     ['--top-learn-rate-scale'],
                                     ['--bottom-learn-rate-scale'],
                                     ['-r', '--random-alignment'],
                                     ['-E', '--error-weight'],
                                     ['--bptt-depth'],
                                     ['--mfccs'],
                                     ['--weight-noise'],
                                     ['--weight-init-scale'],
                                     ['--presynaptic-noise'],
                                     ['--activation'],
                                     ['--delta-features'],
                                     ['--intensity-feature'],
                                     ['-w', '--window-size'],
                                     ['--focus-frequency'],
                                     ['--min-frequency'],
                                     ['--max-frequency'],
                                     ['--knee-frequency'],
                                     ['--lag'],
                                     ['--balanced-training'],
                                     ))
    group.add_argument('--learn-rate-decay', type=range_arg(0, 1),
                       default=DEFAULT_LEARN_RATE_DECAY,
                       help="learning rate decay")
    group.add_argument('--learn-rate-time-offset', type=float,
                       help="learning rate time offset "
                            "(implies inverse-time schedule)")
    group.add_argument('--learn-rate-schedule', default='sqrt-exponential',
                       help='"sqrt-exponential", "inverse-time", or "flat"')
    group.add_argument('--learn-rate-min', type=range_arg(0, 10),
                       default=DEFAULT_LEARN_RATE_MIN,
                       help="learning rate decay stops here")
    group.add_argument('-N', '--no-save-net', action='store_true',
                       help="don't save the net, periodically or otherwise")
    group.add_argument('-C', '--channels', default=12, type=int,
                       help="how many channels to use")
    group.add_argument('--activity-bias', type=int, default=0,
                       help="train more on examples with changing classes")
    group.add_argument('-P', '--prioritise', action='store_true',
                       help="do not renice downwards")
    group.add_argument('--test-interval', type=int, default=0,
                       help="test after this many training cycles")
    group.add_argument('--log-file', default="auto",
                       help="log to this file (default: based on net basename)")
    group.add_argument('--random-seed', default=1,
                       help="use this random seed")

    advanced_group = parser.add_argument_group('advanced classify-train specific '
                                               'arguments')
    add_args_from_classifier(advanced_group, (['--weight-init-method'],
                                              ['--weight-fan-in-sum'],
                                              ['--weight-fan-in-kurtosis'],
                                              ['--lawn-mower'],
                                              ['--confirmation-lag'],
                                              ['--adagrad-ballast']))

    args = parser.parse_args()
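
    # Pick the learning rate schedule.  The "flat" and zero-decay cases
    # reuse lr_inverse_time with the floor set to the initial rate, so
    # the effective rate never changes.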
    if args.learn_rate_schedule == "flat":
        lr = lr_inverse_time(args.learn_rate, args.learn_rate)
    elif args.learn_rate_schedule == "inverse-time" or args.learn_rate_time_offset:
        lr = lr_inverse_time(args.learn_rate,
                             min(args.learn_rate_min, args.learn_rate),
                             offset=(args.learn_rate_time_offset or
                                     DEFAULT_LEARN_RATE_TIME_OFFSET))
    elif args.learn_rate_decay == 0.0:
        lr = lr_inverse_time(args.learn_rate, args.learn_rate)
    else:
        if args.learn_rate_schedule != "sqrt-exponential":
            print >> sys.stderr, "assuming sqrt-exponential schedule"
        lr = lr_sqrt_exp(args.learn_rate, args.learn_rate_decay,
                         min(args.learn_rate_min, args.learn_rate))
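
    # Unless -P/--prioritise was given, lower this process's priority
    # (os.nice is absent on some platforms, hence the hasattr check).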
    if not args.prioritise and hasattr(os, 'nice'):
        os.nice(10)
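
    # Build the trainer; maybe_setp presumably sets a classifier property
    # only when the corresponding option was given on the command line.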
    n_channels = args.channels
    c = Trainer(channels=n_channels, filetype=args.filetype)
    c.no_save_net = args.no_save_net
    c.maybe_setp('random-alignment', args.random_alignment)
    if args.bptt_depth:
        c.maybe_setp('bptt-depth', args.bptt_depth)
    if args.test_interval > 0:
        c.test_interval = args.test_interval
    c.maybe_setp('weight-noise', args.weight_noise)
    c.maybe_setp('weight-init-scale', args.weight_init_scale)
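
    # The first n_channels files are held out for validation; everything
    # else is shuffled and used for training.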
    timed_files = process_common_args(c, args, prop_names,
                                      random_seed=args.random_seed)
    validate_streams = [timed_files[:n_channels]]
    training_streams = [timed_files[n_channels:]]
    random.shuffle(training_streams[0])
    # --activity-bias=n adds n extra streams of tracks that actually have
    # class changes; each stream drops a different number of leading files
    # so the copies cycle out of sync with each other.
    for i in range(args.activity_bias):
        s = [x for x in training_streams[0] if len(x.timings) > 1]
        if len(s) == 0:  # perhaps this should be len(s) < n_channels?
            break
        random.shuffle(s)
        training_streams.append(s[i:])
    c.train(training_streams, validate_streams,
            iterations=args.iterations,
            learn_rate_fn=lr,
            log_file=args.log_file)


if __name__ == '__main__':
    main()