optim.py
import logging

import mxnet as mx


def get_optimizer_params(optimizer=None, learning_rate=None, momentum=None,
                         weight_decay=None, lr_scheduler=None, ctx=None, logger=None):
    """
    Map an optimizer name to the (name, params) pair passed to the MXNet fit call.
    """
    if optimizer.lower() == 'rmsprop':
        opt = 'rmsprop'
        if logger:
            logger.info('you chose RMSProp, decreasing lr by a factor of 10')
        optimizer_params = {'learning_rate': learning_rate / 10.0,
                            'wd': weight_decay,
                            'lr_scheduler': lr_scheduler,
                            'clip_gradient': None,
                            'rescale_grad': 1.0 / len(ctx) if len(ctx) > 0 else 1.0}
    elif optimizer.lower() == 'sgd':
        opt = 'sgd'
        optimizer_params = {'learning_rate': learning_rate,
                            'momentum': momentum,
                            'wd': weight_decay,
                            'lr_scheduler': lr_scheduler,
                            'clip_gradient': None,
                            'rescale_grad': 1.0}
    elif optimizer.lower() == 'adadelta':
        opt = 'adadelta'
        # AdaDelta adapts its own per-parameter step sizes, so no parameters are passed.
        optimizer_params = {}
    elif optimizer.lower() == 'adam':
        opt = 'adam'
        optimizer_params = {'learning_rate': learning_rate,
                            'lr_scheduler': lr_scheduler,
                            'clip_gradient': None,
                            'rescale_grad': 1.0}
    else:
        raise ValueError('Unsupported optimizer: {}'.format(optimizer))
    return opt, optimizer_params
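

# A minimal usage sketch (illustrative only, not part of the original script); the
# values below and the `mod` / `train_iter` names stand in for an mx.mod.Module and
# a data iterator that the caller would already have:
#
#   opt_name, opt_params = get_optimizer_params(
#       optimizer='sgd', learning_rate=0.004, momentum=0.9,
#       weight_decay=0.0005, lr_scheduler=None, ctx=[mx.cpu()])
#   mod.fit(train_iter, optimizer=opt_name, optimizer_params=opt_params)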


def get_lr_scheduler(learning_rate, lr_refactor_step, lr_refactor_ratio,
                     num_example, batch_size, begin_epoch):
    """
    Compute the adjusted learning rate and a step-based learning rate scheduler.

    Parameters
    ----------
    learning_rate : float
        original learning rate
    lr_refactor_step : comma-separated str
        epochs at which to change the learning rate
    lr_refactor_ratio : float
        lr *= ratio at the given steps
    num_example : int
        number of training images, used to estimate the iterations per epoch
    batch_size : int
        training batch size
    begin_epoch : int
        starting epoch

    Returns
    -------
    (learning_rate, mx.lr_scheduler) as tuple
    """
    assert lr_refactor_ratio > 0
    iter_refactor = [int(r) for r in lr_refactor_step.split(',') if r.strip()]
    if lr_refactor_ratio >= 1:
        # A ratio >= 1 would never decay the learning rate, so no scheduler is created.
        return (learning_rate, None)
    else:
        lr = learning_rate
        epoch_size = num_example // batch_size
        # Apply any decays that would already have happened before the starting epoch.
        for s in iter_refactor:
            if begin_epoch >= s:
                lr *= lr_refactor_ratio
        if lr != learning_rate:
            logging.getLogger().info("Adjusted learning rate to {} for epoch {}".format(lr, begin_epoch))
        # Remaining decay points, expressed in iterations relative to the starting epoch.
        steps = [epoch_size * (x - begin_epoch) for x in iter_refactor if x > begin_epoch]
        if not steps:
            return (lr, None)
        lr_scheduler = mx.lr_scheduler.MultiFactorScheduler(step=steps, factor=lr_refactor_ratio)
        return (lr, lr_scheduler)
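

if __name__ == '__main__':
    # A minimal, self-contained sketch (not part of the original file): the learning
    # rate, decay epochs, dataset size, and batch size below are illustrative
    # assumptions, chosen only to exercise the two helpers together.
    logging.basicConfig(level=logging.INFO)

    # Decay the learning rate by 0.1 at epochs 80 and 160, starting from epoch 0.
    lr, scheduler = get_lr_scheduler(learning_rate=0.004, lr_refactor_step='80,160',
                                     lr_refactor_ratio=0.1, num_example=16000,
                                     batch_size=32, begin_epoch=0)

    opt_name, opt_params = get_optimizer_params(optimizer='sgd', learning_rate=lr,
                                                momentum=0.9, weight_decay=0.0005,
                                                lr_scheduler=scheduler, ctx=[mx.cpu()],
                                                logger=logging.getLogger())
    print(opt_name, opt_params)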