# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utils."""
import dill
import jax
import jax.numpy as jnp
import tree
def reduce_fn(x, mode):
"""Reduce fn for various losses."""
if mode == 'none' or mode is None:
return jnp.asarray(x)
elif mode == 'sum':
return jnp.sum(x)
elif mode == 'mean':
return jnp.mean(x)
else:
raise ValueError('Unsupported reduction option.')
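

# Example usage for `reduce_fn` (illustrative; not part of the original file):
#   reduce_fn(jnp.array([1., 2., 3.]), 'sum')   # -> Array(6., dtype=float32)
#   reduce_fn(jnp.array([1., 2., 3.]), 'mean')  # -> Array(2., dtype=float32)
#   reduce_fn([1., 2., 3.], None)               # -> Array([1., 2., 3.], dtype=float32)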


def softmax_cross_entropy(logits, labels, reduction='sum'):
  """Computes softmax cross entropy given logits and one-hot class labels.

  Args:
    logits: Logit output values.
    labels: Ground truth one-hot-encoded labels.
    reduction: Type of reduction to apply to loss.

  Returns:
    Loss value. If `reduction` is `none`, this has the same shape as `labels`;
    otherwise, it is scalar.

  Raises:
    ValueError: If the type of `reduction` is unsupported.
  """
  loss = -jnp.sum(labels * jax.nn.log_softmax(logits), axis=-1)
  return reduce_fn(loss, reduction)
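

# Example for `softmax_cross_entropy` (an illustrative sketch): with uniform
# logits over 3 classes, each example contributes log(3) to the loss:
#   labels = jax.nn.one_hot(jnp.array([0, 1]), 3)
#   softmax_cross_entropy(jnp.zeros((2, 3)), labels)  # -> ~2.197 (= 2 * log 3)
#   softmax_cross_entropy(jnp.zeros((2, 3)), labels, reduction='none')  # shape (2,)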


def topk_correct(logits, labels, mask=None, prefix='', topk=(1, 5)):
  """Calculates top-k accuracy indicators for multiple values of k."""
  metrics = {}
  argsorted_logits = jnp.argsort(logits)
  for k in topk:
    pred_labels = argsorted_logits[..., -k:]
    # Mark examples where the true label is among the top-k predictions.
    correct = any_in(pred_labels, labels).any(axis=-1).astype(jnp.float32)
    if mask is not None:
      correct *= mask
    metrics[f'{prefix}top_{k}_acc'] = correct
  return metrics
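

# Example for `topk_correct` (illustrative): `labels` are integer class ids,
# and each metric is a per-example 0/1 indicator (average it for accuracy):
#   logits = jnp.array([[0.1, 0.9, 0.0]])
#   topk_correct(logits, jnp.array([1]), topk=(1,))
#   # -> {'top_1_acc': Array([1.], dtype=float32)}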


@jax.vmap
def any_in(prediction, target):
  """For each row, checks which elements of `prediction` are in `target`."""
  return jnp.isin(prediction, target)


def tf1_ema(ema_value, current_value, decay, step):
  """Implements EMA with TF1-style decay warmup."""
  decay = jnp.minimum(decay, (1.0 + step) / (10.0 + step))
  return ema_value * decay + current_value * (1 - decay)
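

# `tf1_ema` mirrors the warmup of TF1's `tf.train.ExponentialMovingAverage`,
# where the effective decay is min(decay, (1 + step) / (10 + step)), so early
# steps weight the current value heavily. Sketch of the effect (illustrative):
#   tf1_ema(0.0, 1.0, 0.999, step=0)      # decay ramped down to 0.1 -> 0.9
#   tf1_ema(0.0, 1.0, 0.999, step=10**6)  # decay ~0.999 -> ~0.001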


def ema(ema_value, current_value, decay, step):
  """Implements EMA without any warmup."""
  del step
  return ema_value * decay + current_value * (1 - decay)


to_bf16 = lambda x: x.astype(jnp.bfloat16) if x.dtype == jnp.float32 else x
from_bf16 = lambda x: x.astype(jnp.float32) if x.dtype == jnp.bfloat16 else x
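

# These casts are typically mapped over a whole parameter tree for mixed
# precision, e.g. (illustrative; `params` is any pytree of arrays):
#   params_bf16 = jax.tree_util.tree_map(to_bf16, params)
#   params_f32 = jax.tree_util.tree_map(from_bf16, params_bf16)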


def _replicate(x, devices=None):
  """Replicate an object on each device."""
  x = jnp.array(x)
  if devices is None:
    devices = jax.local_devices()
  return jax.device_put_sharded(len(devices) * [x], devices)


def broadcast(obj):
  """Broadcasts an object to all devices."""
  if obj is not None and not isinstance(obj, bool):
    return _replicate(obj)
  else:
    return obj
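

# Example for `broadcast` (illustrative): replicating an array across all local
# devices adds a leading device axis; booleans and None pass through unchanged:
#   x = broadcast(jnp.zeros(3))  # shape (jax.local_device_count(), 3)
#   broadcast(True)              # -> True (not replicated)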


def split_tree(tuple_tree, base_tree, n):
  """Splits tuple_tree with n-tuple leaves into n trees."""
  return [tree.map_structure_up_to(base_tree, lambda x: x[i], tuple_tree)  # pylint: disable=cell-var-from-loop
          for i in range(n)]
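

# Example for `split_tree` (illustrative): leaves of `tuple_tree` are n-tuples,
# and `base_tree` fixes the structure at which the split happens:
#   split_tree({'a': (1, 2)}, {'a': None}, 2)  # -> [{'a': 1}, {'a': 2}]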


def load_haiku_file(filename):
  """Loads a haiku parameter tree, using dill."""
  with open(filename, 'rb') as in_file:
    output = dill.load(in_file)
  return output


def flatten_haiku_tree(haiku_dict):
  """Flattens a haiku parameter tree into a flat dictionary."""
  out = {}
  for module in haiku_dict:
    out_module = module.replace('/~/', '.').replace('/', '.')
    for key in haiku_dict[module]:
      out_key = f'{out_module}.{key}'
      out[out_key] = haiku_dict[module][key]
  return out
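

# Example for `flatten_haiku_tree` (illustrative; module/parameter names are
# hypothetical): Haiku scopes parameters as a two-level dict
# {module: {name: array}}, and flattening joins the names with dots:
#   flatten_haiku_tree({'nf_net/~/stem_conv': {'w': w}})
#   # -> {'nf_net.stem_conv.w': w}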