forked from google-deepmind/deepmind-research
-
Notifications
You must be signed in to change notification settings - Fork 0
/
experiment.py
389 lines (356 loc) · 14.7 KB
/
experiment.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r"""Basic Jaxline ImageNet experiment."""
import importlib
import sys
from absl import flags
from absl import logging
import haiku as hk
import jax
import jax.numpy as jnp
from jaxline import base_config
from jaxline import experiment
from jaxline import platform
from jaxline import utils as jl_utils
from ml_collections import config_dict
import numpy as np
from nfnets import dataset
from nfnets import optim
from nfnets import utils
# pylint: disable=logging-format-interpolation
# Module-level handle on absl's global flag registry. It is not read elsewhere
# in this file; the `config` flag is marked as required in the __main__ block.
FLAGS = flags.FLAGS
# The launch config lives next to the experiment so everything needed to run it
# is in one file; the config and/or sweep functions could be moved elsewhere if
# this file grows.
def get_config():
  """Builds the jaxline launch config for ResNet ImageNet training.

  Starts from jaxline's base config, then sets the run length and seed, the
  experiment hyperparameters (under `experiment_kwargs.config`), and the
  logging/checkpointing intervals.

  Returns:
    The populated config object.
  """
  config = base_config.get_base_config()

  # Run length: `num_epochs` passes over ImageNet at a global batch size of
  # `train_batch_size`. `steps_per_epoch` is deliberately left unrounded.
  train_batch_size = 1024  # Global batch size.
  images_per_epoch = 1281167
  num_epochs = 90
  warmup_epochs = 5
  steps_per_epoch = images_per_epoch / train_batch_size

  config.training_steps = ((images_per_epoch * num_epochs) // train_batch_size)
  config.random_seed = 0

  # Learning-rate schedule: linear warmup followed by cosine decay to 0.
  lr_schedule = dict(
      name='WarmupCosineDecay',
      kwargs=dict(
          num_steps=config.training_steps,
          start_val=0,
          min_val=0,
          warmup_steps=warmup_epochs * steps_per_epoch),
  )

  # Nesterov SGD with weight decay.
  optimizer = dict(
      name='SGD',
      kwargs={
          'momentum': 0.9,
          'nesterov': True,
          'weight_decay': 1e-4,
      },
  )

  model_kwargs = dict(
      width=4,
      which_norm='BatchNorm',
      norm_kwargs=dict(
          create_scale=True,
          create_offset=True,
          decay_rate=0.9,
      ),  # Optionally also pass cross_replica_axis='i' here.
      variant='ResNet50',
      activation='relu',
      drop_rate=0.0,
  )

  experiment_config = dict(
      lr=0.1,
      num_epochs=num_epochs,
      label_smoothing=0.1,
      model='ResNet',
      image_size=224,
      # EMA settings: use_ema, ema_decay, ema_start, which_ema.
      use_ema=False,
      ema_decay=0.9999,
      ema_start=0,
      which_ema='tf1_ema',
      augment_name=None,  # Alternative: 'mixup_cutmix'.
      augment_before_mix=True,
      eval_preproc='crop_resize',
      train_batch_size=train_batch_size,
      eval_batch_size=50,
      eval_subset='test',
      num_classes=1000,
      which_dataset='imagenet',
      fake_data=False,
      which_loss='softmax_cross_entropy',  # For now, must be softmax.
      # If True, the input pipeline yields HWCN images, which the forward pass
      # transposes back to NHWC.
      transpose=True,
      bfloat16=False,
      lr_schedule=lr_schedule,
      lr_scale_by_bs=True,
      optimizer=optimizer,
      model_kwargs=model_kwargs,
  )
  config.experiment_kwargs = config_dict.ConfigDict(
      dict(config=experiment_config))

  # Training loop config: log and checkpoint every minute.
  interval = 60
  config.log_train_data_interval = interval
  config.log_tensors_interval = interval
  config.save_checkpoint_interval = interval
  config.eval_specific_checkpoint_dir = ''

  return config
class Experiment(experiment.AbstractExperiment):
  """ImageNet experiment: training step, evaluation and input pipelines."""

  # Experiment attribute name -> name it is stored under in checkpoints.
  CHECKPOINT_ATTRS = {
      '_params': 'params',
      '_state': 'state',
      '_ema_params': 'ema_params',
      '_ema_state': 'ema_state',
      '_opt_state': 'opt_state',
  }

  def __init__(self, mode, config, init_rng):
    """Sets up the model transform and the pmapped train/eval functions.

    Args:
      mode: jaxline run mode. `_forward_fn` only casts inputs to bfloat16
        when this equals 'train'.
      config: the experiment ConfigDict (the `config` entry built by
        `get_config`).
      init_rng: PRNG key that `_initialize_train` uses to initialise the net.
    """
    super().__init__(mode=mode)
    self.mode = mode
    self.config = config
    self.init_rng = init_rng

    # Checkpointed experiment state (see CHECKPOINT_ATTRS). It starts as None
    # and is filled in by `_initialize_train` unless it has already been set
    # (presumably by jaxline restoring a checkpoint).
    self._params = None
    self._state = None
    self._ema_params = None
    self._ema_state = None
    self._opt_state = None

    # Input pipelines. The training pipeline is built on the first `step`.
    # `_eval_input` is not read in this file; eval rebuilds its pipeline on
    # every call.
    self._train_input = None
    self._eval_input = None

    # The model class lives in module `nfnets.<config.model lowercased>`.
    module_prefix = 'nfnets.'
    self.model_module = importlib.import_module(
        (module_prefix + self.config.model.lower()))
    self.net = hk.transform_with_state(self._forward_fn)

    # Image sizes: taken directly from the config when `override_imsize` is
    # set, otherwise from the model variant's defaults with the config values
    # as fallback.
    if self.config.get('override_imsize', False):
      self.train_imsize = self.config.image_size
      self.test_imsize = self.config.get('eval_image_size', self.train_imsize)
    else:
      variant_dict = getattr(self.model_module, self.config.model).variant_dict
      variant_dict = variant_dict[self.config.model_kwargs.variant]
      self.train_imsize = variant_dict.get('train_imsize',
                                           self.config.image_size)
      # Test imsize defaults to the model-specific value, then to the config.
      test_imsize = self.config.get('eval_image_size', self.config.image_size)
      self.test_imsize = variant_dict.get('test_imsize', test_imsize)

    # Positions in `_train_fn`'s signature: params (0), states (1),
    # opt_states (2) and, only when EMA is enabled, ema_params (6) and
    # ema_states (7). `jax.pmap` applies `donate_argnums` to positional
    # arguments only, which is why `step` calls `train_fn` positionally.
    donate_argnums = (0, 1, 2, 6, 7) if self.config.use_ema else (0, 1, 2)
    self.train_fn = jax.pmap(
        self._train_fn, axis_name='i', donate_argnums=donate_argnums)
    self.eval_fn = jax.pmap(self._eval_fn, axis_name='i')

  def _initialize_train(self):
    """Builds the train pipeline, initialises params if needed, makes the opt."""
    self._train_input = self._build_train_input()

    # Initialise the net (and an EMA copy of it) only if no params exist yet.
    if self._params is None:
      inputs = next(self._train_input)
      init_net = jax.pmap(
          lambda *a: self.net.init(*a, is_training=True), axis_name='i')
      init_rng = jl_utils.bcast_local_devices(self.init_rng)
      self._params, self._state = init_net(init_rng, inputs)
      if self.config.use_ema:
        # Same rng and inputs as above, so the EMA copy starts equal to the
        # params.
        self._ema_params, self._ema_state = init_net(init_rng, inputs)
      num_params = hk.data_structures.tree_size(self._params)
      # Params are replicated across local devices, so divide out the copies.
      logging.info(f'Net parameters: {num_params / jax.local_device_count()}')
    self._make_opt()

  def _make_opt(self):
    """Creates `self.opt`, then either initialises or re-plugs its state."""

    # Split params into gains/biases and everything else (the weights).
    def pred(mod, name, val):  # pylint:disable=unused-argument
      return (name in ['scale', 'offset', 'b'] or 'gain' in name or
              'bias' in name)

    gains_biases, weights = hk.data_structures.partition(pred, self._params)

    # Linear LR scaling: `config.lr` is the rate for a batch of 256.
    if self.config.lr_scale_by_bs:
      max_lr = (self.config.lr * self.config.train_batch_size) / 256
    else:
      max_lr = self.config.lr
    lr_sched_fn = getattr(optim, self.config.lr_schedule.name)
    lr_schedule = lr_sched_fn(max_val=max_lr, **self.config.lr_schedule.kwargs)

    # Optimizer; no need to broadcast!
    opt_kwargs = dict(self.config.optimizer.kwargs.items())
    opt_kwargs['lr'] = lr_schedule
    opt_module = getattr(optim, self.config.optimizer.name)
    # Two param groups: gains/biases override weight_decay with None, and the
    # weights use the optimizer-wide kwargs.
    self.opt = opt_module([{
        'params': gains_biases,
        'weight_decay': None
    }, {
        'params': weights
    }], **opt_kwargs)
    if self._opt_state is None:
      self._opt_state = self.opt.states()
    else:
      self.opt.plugin(self._opt_state)

  def _forward_fn(self, inputs, is_training):
    """Haiku forward pass. Returns the network's 'logits' output."""
    net_kwargs = {
        'num_classes': self.config.num_classes,
        **self.config.model_kwargs
    }
    net = getattr(self.model_module, self.config.model)(**net_kwargs)
    if self.config.get('transpose', False):
      images = jnp.transpose(inputs['images'], (3, 0, 1, 2))  # HWCN -> NHWC
    else:
      images = inputs['images']
    if self.config.bfloat16 and self.mode == 'train':
      images = utils.to_bf16(images)
    return net(images, is_training=is_training)['logits']

  def _one_hot(self, value):
    """One-hot encoding potentially over a sequence of labels."""
    y = jax.nn.one_hot(value, self.config.num_classes)
    return y

  def _loss_fn(self, params, state, inputs, rng):
    """Computes the training loss on one per-device batch.

    Returns:
      `(scaled_loss, (metrics, state))`. `scaled_loss` is the mean loss
      divided by the global device count, so psum-ing the gradients across
      devices yields the gradient of the global mean. `metrics` holds the
      unscaled loss and batch-averaged top-k accuracies. `state` is the
      updated network state.
    """
    logits, state = self.net.apply(params, state, rng, inputs, is_training=True)
    y = self._one_hot(inputs['labels'])
    if 'mix_labels' in inputs:  # Handle cutmix/mixup label mixing.
      # Python side effect under pmap: logged at trace time, not every step.
      logging.info('Using mixup or cutmix!')
      y1 = self._one_hot(inputs['mix_labels'])
      y = inputs['ratio'][:, None] * y + (1. - inputs['ratio'][:, None]) * y1
    if self.config.label_smoothing > 0:
      # Label smoothing: shrink the targets and spread the removed probability
      # mass uniformly over all classes.
      spositives = 1. - self.config.label_smoothing
      snegatives = self.config.label_smoothing / self.config.num_classes
      y = spositives * y + snegatives
    if self.config.bfloat16:  # Cast logits to float32 for the loss.
      logits = logits.astype(jnp.float32)
    which_loss = getattr(utils, self.config.which_loss)
    loss = which_loss(logits, y, reduction='mean')
    metrics = utils.topk_correct(logits, inputs['labels'], prefix='train_')
    # Average top-1 and top-5 correct labels over the batch.
    metrics = jax.tree_map(jnp.mean, metrics)
    metrics['train_loss'] = loss  # Metrics will be pmeaned so don't divide here
    scaled_loss = loss / jax.device_count()  # Grads get psummed so do divide
    return scaled_loss, (metrics, state)

  def _train_fn(self, params, states, opt_states, inputs, rng, global_step,
                ema_params, ema_states):
    """Runs one batch forward + backward and a single optimizer step.

    Runs under `jax.pmap` with axis name 'i'. Returns a dict with the updated
    'params', 'states', 'opt_states', 'ema_params', 'ema_states' (the EMA
    entries stay None when EMA is disabled) and the device-averaged 'metrics'.
    """
    grad_fn = jax.grad(self._loss_fn, argnums=0, has_aux=True)
    if self.config.bfloat16:
      in_params, states = jax.tree_map(utils.to_bf16, (params, states))
    else:
      in_params = params
    grads, (metrics, states) = grad_fn(in_params, states, inputs, rng)
    if self.config.bfloat16:
      states, metrics, grads = jax.tree_map(utils.from_bf16,
                                            (states, metrics, grads))

    # Sum gradients and average metrics across devices.
    grads = jax.lax.psum(grads, 'i')
    metrics = jax.lax.pmean(metrics, 'i')

    # Compute updates and update parameters.
    metrics['learning_rate'] = self.opt._hyperparameters['lr'](global_step)  # pylint: disable=protected-access
    params, opt_states = self.opt.step(params, grads, opt_states, global_step)
    if ema_params is not None:
      ema_fn = getattr(utils, self.config.get('which_ema', 'tf1_ema'))
      ema = lambda x, y: ema_fn(x, y, self.config.ema_decay, global_step)
      ema_params = jax.tree_map(ema, ema_params, params)
      ema_states = jax.tree_map(ema, ema_states, states)
    return {
        'params': params,
        'states': states,
        'opt_states': opt_states,
        'ema_params': ema_params,
        'ema_states': ema_states,
        'metrics': metrics
    }

  # ---------------------------------------------------------------------------
  # Training.
  # ---------------------------------------------------------------------------

  def step(self, global_step, rng, *unused_args, **unused_kwargs):
    """Runs one training step and returns the first device's copy of metrics."""
    if self._train_input is None:
      self._initialize_train()
    inputs = next(self._train_input)
    # Pass positionally: `jax.pmap` applies `donate_argnums` to positional
    # arguments only, so a keyword call would silently disable the buffer
    # donation configured in `__init__`. The donated buffers are replaced by
    # the outputs immediately below.
    out = self.train_fn(
        self._params,      # 0: donated
        self._state,       # 1: donated
        self._opt_state,   # 2: donated
        inputs,            # 3
        rng,               # 4
        global_step,       # 5
        self._ema_params,  # 6: donated when use_ema
        self._ema_state)   # 7: donated when use_ema
    self._params, self._state = out['params'], out['states']
    self._opt_state = out['opt_states']
    self._ema_params, self._ema_state = out['ema_params'], out['ema_states']
    self.opt.plugin(self._opt_state)
    return jl_utils.get_first(out['metrics'])

  def _build_train_input(self):
    """Builds the training input pipeline over local devices.

    Raises:
      ValueError: if the global batch size does not divide evenly across all
        devices.
    """
    num_devices = jax.device_count()
    global_batch_size = self.config.train_batch_size
    bs_per_device, ragged = divmod(global_batch_size, num_devices)
    if ragged:
      raise ValueError(
          f'Global batch size {global_batch_size} must be divisible by '
          f'num devices {num_devices}')
    return dataset.load(
        dataset.Split.TRAIN_AND_VALID,
        is_training=True,
        batch_dims=[jax.local_device_count(), bs_per_device],
        transpose=self.config.get('transpose', False),
        image_size=(self.train_imsize,) * 2,
        augment_name=self.config.augment_name,
        augment_before_mix=self.config.get('augment_before_mix', True),
        name=self.config.which_dataset,
        fake_data=self.config.get('fake_data', False))

  # ---------------------------------------------------------------------------
  # Evaluation.
  # ---------------------------------------------------------------------------

  def evaluate(self, global_step, **unused_args):
    """Evaluates the params (and EMA params, prefixed 'ema_', when enabled)."""
    metrics = self._eval_epoch(self._params, self._state)
    if self.config.use_ema:
      ema_metrics = self._eval_epoch(self._ema_params, self._ema_state)
      metrics.update({f'ema_{key}': val for key, val in ema_metrics.items()})
    logging.info(f'[Step {global_step}] Eval scalars: {metrics}')
    return metrics

  def _eval_epoch(self, params, state):
    """Evaluates one pass over the eval split and returns per-sample means.

    Raises:
      ValueError: if the eval input pipeline yields no batches, in which case
        no mean can be computed.
    """
    num_samples = 0.
    summed_metrics = None
    for inputs in self._build_eval_input():
      # Leading two dims are [local devices, per-device batch].
      num_samples += np.prod(inputs['labels'].shape[:2])
      metrics = self.eval_fn(params, state, inputs)
      # `eval_fn` already psums across devices, so take device 0's copy and
      # sum over its batch dimension.
      metrics = jax.tree_map(lambda x: jnp.sum(x[0], axis=0), metrics)
      if summed_metrics is None:
        summed_metrics = metrics
      else:
        summed_metrics = jax.tree_map(jnp.add, summed_metrics, metrics)
    if summed_metrics is None:
      raise ValueError(
          'Evaluation input pipeline produced no batches; cannot compute '
          'eval metrics.')
    # NOTE(review): `num_samples` counts local devices only while the psum
    # spans all devices; confirm eval runs on a single host.
    mean_metrics = jax.tree_map(lambda x: x / num_samples, summed_metrics)
    return jax.device_get(mean_metrics)

  def _eval_fn(self, params, state, inputs):
    """Evaluates a single batch and returns psummed per-example loss/top-k."""
    logits, _ = self.net.apply(params, state, None, inputs, is_training=False)
    y = self._one_hot(inputs['labels'])
    which_loss = getattr(utils, self.config.which_loss)
    loss = which_loss(logits, y, reduction=None)
    metrics = utils.topk_correct(logits, inputs['labels'], prefix='eval_')
    metrics['eval_loss'] = loss
    return jax.lax.psum(metrics, 'i')

  def _build_eval_input(self):
    """Builds the evaluation input pipeline over local devices.

    Raises:
      ValueError: if `eval_batch_size` is smaller than the number of local
        devices, which would leave zero examples per device.
    """
    num_local_devices = jax.local_device_count()
    bs_per_device, ragged = divmod(self.config.eval_batch_size,
                                   num_local_devices)
    if bs_per_device < 1:
      raise ValueError(
          f'Eval batch size {self.config.eval_batch_size} is smaller than '
          f'the number of local devices {num_local_devices}')
    if ragged:
      # Keep the floor-division behaviour, but make the truncation visible.
      logging.warning(
          f'Eval batch size {self.config.eval_batch_size} is not divisible by '
          f'{num_local_devices} local devices; using {bs_per_device} per '
          f'device ({bs_per_device * num_local_devices} per step).')
    split = dataset.Split.from_string(self.config.eval_subset)
    eval_preproc = self.config.get('eval_preproc', 'crop_resize')
    return dataset.load(
        split,
        is_training=False,
        batch_dims=[num_local_devices, bs_per_device],
        transpose=self.config.get('transpose', False),
        image_size=(self.test_imsize,) * 2,
        name=self.config.which_dataset,
        eval_preproc=eval_preproc,
        fake_data=self.config.get('fake_data', False))
if __name__ == '__main__':
  # Refuse to launch without a `config` flag. The flag itself is presumably
  # registered by jaxline's platform module.
  flags.mark_flag_as_required('config')
  command_line_args = sys.argv[1:]
  platform.main(Experiment, command_line_args)