# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r"""ImageNet experiment with NFNets."""
import sys
from absl import flags
import haiku as hk
from jaxline import platform
from ml_collections import config_dict
from nfnets import experiment
from nfnets import optim
FLAGS = flags.FLAGS
def get_config():
"""Return config object for training."""
config = experiment.get_config()
# Experiment config.
train_batch_size = 4096 # Global batch size.
images_per_epoch = 1281167
num_epochs = 360
steps_per_epoch = images_per_epoch / train_batch_size
config.training_steps = ((images_per_epoch * num_epochs) // train_batch_size)
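  # With the defaults above this works out to roughly 312.8 steps per epoch
  # (1,281,167 / 4,096) and 112,602 total training steps
  # (1,281,167 * 360 // 4,096); the 5-epoch warmup below is then ~1,564 steps.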
  config.random_seed = 0
  config.experiment_kwargs = config_dict.ConfigDict(
      dict(
          config=dict(
              lr=0.1,
              num_epochs=num_epochs,
              label_smoothing=0.1,
              model='NFNet',
              image_size=224,
              use_ema=True,
              ema_decay=0.99999,
              ema_start=0,
              augment_name=None,
              augment_before_mix=False,
              eval_preproc='resize_crop_32',
              train_batch_size=train_batch_size,
              eval_batch_size=50,
              eval_subset='test',
              num_classes=1000,
              which_dataset='imagenet',
              which_loss='softmax_cross_entropy',  # One of softmax or sigmoid.
              bfloat16=True,
              lr_schedule=dict(
                  name='WarmupCosineDecay',
                  kwargs=dict(
                      num_steps=config.training_steps,
                      start_val=0,
                      min_val=0.0,
                      warmup_steps=5 * steps_per_epoch),
              ),
              lr_scale_by_bs=True,
              optimizer=dict(
                  name='SGD_AGC',
                  kwargs={
                      'momentum': 0.9, 'nesterov': True,
                      'weight_decay': 2e-5,
                      'clipping': 0.01, 'eps': 1e-3,
                  },
              ),
              model_kwargs=dict(
                  variant='F0',
                  width=1.0,
                  se_ratio=0.5,
                  alpha=0.2,
                  stochdepth_rate=0.25,
                  drop_rate=None,  # Use the variant's native drop rate.
                  activation='gelu',
                  final_conv_mult=2,
                  final_conv_ch=None,
                  use_two_convs=True,
              ),
          )))
  # Unlike NF-RegNets, use the same weight decay for all variants, but vary
  # the RandAugment (RA) level per variant.
  variant = config.experiment_kwargs.config.model_kwargs.variant
  # RandAugment levels are encoded as '<num_layers><magnitude>', e.g. '405'
  # means 4 layers at magnitude 5 and '415' means 4 layers at magnitude 15.
  augment = {'F0': '405', 'F1': '410', 'F2': '410', 'F3': '415',
             'F4': '415', 'F5': '415', 'F6': '415', 'F7': '415'}[variant]
  aug_base_name = 'cutmix_mixup_randaugment'
  config.experiment_kwargs.config.augment_name = f'{aug_base_name}_{augment}'

  return config
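# Illustrative usage (not executed here): the ConfigDict returned above can be
# inspected or overridden programmatically using the field names defined in
# get_config(), e.g.
#
#   cfg = get_config()
#   cfg.experiment_kwargs.config.image_size = 256
#
# Note that derived values such as `training_steps` and the warmup length are
# computed inside get_config() from the batch size, so they need recomputing
# if the batch size is overridden.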
class Experiment(experiment.Experiment):
  """Experiment with correct parameter filtering for applying AGC."""

  def _make_opt(self):
    # Separate conv params and gains/biases.
    def pred_gb(mod, name, val):
      del mod, val
      return (name in ['scale', 'offset', 'b']
              or 'gain' in name or 'bias' in name)
    gains_biases, weights = hk.data_structures.partition(pred_gb, self._params)

    def pred_fc(mod, name, val):
      del name, val
      return 'linear' in mod and 'squeeze_excite' not in mod
    fc_weights, weights = hk.data_structures.partition(pred_fc, weights)
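    # hk.data_structures.partition(pred, params) returns (matching, rest):
    # `gains_biases` now holds the scale/offset/gain/bias parameters,
    # `fc_weights` holds weights of 'linear' modules outside squeeze-excite
    # blocks (i.e. the classifier head), and `weights` keeps everything else
    # (the conv kernels) for the parameter group passed to the optimizer below.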
    # LR schedule, with batch-size-based LR scaling.
    if self.config.lr_scale_by_bs:
      max_lr = (self.config.lr * self.config.train_batch_size) / 256
    else:
      max_lr = self.config.lr
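    # With the defaults above (lr=0.1, train_batch_size=4096), the scaled
    # peak learning rate is 0.1 * 4096 / 256 = 1.6.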
    lr_sched_fn = getattr(optim, self.config.lr_schedule.name)
    lr_schedule = lr_sched_fn(max_val=max_lr, **self.config.lr_schedule.kwargs)
    # Optimizer; no need to broadcast!
    opt_kwargs = {key: val for key, val in self.config.optimizer.kwargs.items()}
    opt_kwargs['lr'] = lr_schedule
    opt_module = getattr(optim, self.config.optimizer.name)
    self.opt = opt_module([{'params': gains_biases, 'weight_decay': None},
                           {'params': fc_weights, 'clipping': None},
                           {'params': weights}], **opt_kwargs)
    if self._opt_state is None:
      self._opt_state = self.opt.states()
    else:
      self.opt.plugin(self._opt_state)
if __name__ == '__main__':
  flags.mark_flag_as_required('config')
  platform.main(Experiment, sys.argv[1:])
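# A typical jaxline launch points the required --config flag at this file,
# e.g. (the module and config paths are assumptions; adjust to your checkout):
#
#   python -m nfnets.experiment_nfnets --config=nfnets/experiment_nfnets.py
#
# Individual fields can usually also be overridden on the command line via
# ml_collections dotted config flags, e.g.
# --config.experiment_kwargs.config.train_batch_size=1024.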