-
Notifications
You must be signed in to change notification settings - Fork 0
/
util.py
237 lines (154 loc) · 4.94 KB
/
util.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
import math
from beartype.typing import List, Union
from functools import wraps
import torch
import torch.nn.functional as F
from torch import nn
from torch.special import expm1
from einops import rearrange, repeat, reduce, pack, unpack
def exists(val):
    """Return True when `val` is anything other than None.

    Falsy values such as 0, '' or an empty list still count as existing.
    """
    return not (val is None)
def identity(t, *_args, **_kwargs):
    """Return `t` unchanged; any additional arguments are accepted and ignored."""
    result = t
    return result
def divisible_by(numer, denom):
    """Return whether `numer` leaves no remainder when divided by `denom`."""
    remainder = numer % denom
    return remainder == 0
def first(arr, d=None):
    """Return the first element of `arr`, or `d` when `arr` has length zero."""
    return arr[0] if len(arr) > 0 else d
def maybe(fn):
    """Wrap `fn` so that it is skipped when its first argument is None.

    The returned wrapper passes a None first argument straight through
    (returning None) and otherwise calls `fn`. Any further positional or
    keyword arguments are forwarded to `fn` untouched, so functions taking
    more than one argument can be wrapped as well.
    """
    @wraps(fn)
    def inner(x, *args, **kwargs):
        if x is None:
            return x
        return fn(x, *args, **kwargs)

    return inner
def once(fn):
    """Wrap `fn` so that only the first call is executed.

    The first call forwards all positional and keyword arguments to `fn`
    and returns its result. Every later call does nothing and returns None.
    The "already called" flag is set before `fn` runs, so a first call that
    raises still counts as the one allowed call.
    """
    called = False

    @wraps(fn)
    def inner(*args, **kwargs):
        nonlocal called
        if called:
            return None
        called = True
        return fn(*args, **kwargs)

    return inner


# prints on its first invocation only; later invocations are silent
print_once = once(print)
def default(val, d):
    """Return `val` unless it is None, otherwise fall back to `d`.

    When the fallback is callable it is invoked (with no arguments) and its
    result is returned, which lets callers defer building the default.
    """
    if val is not None:
        return val
    if callable(d):
        return d()
    return d
def cast_tuple(val, length=None):
    """Coerce `val` into a tuple, optionally enforcing its length.

    Lists are converted to tuples and tuples are kept as they are. Any other
    value is repeated `length` times (once when `length` is None). When
    `length` is given, the resulting tuple must have exactly that many items,
    otherwise an AssertionError is raised.
    """
    if isinstance(val, list):
        val = tuple(val)

    if isinstance(val, tuple):
        output = val
    else:
        repeats = 1 if length is None else length
        output = (val,) * repeats

    if length is not None:
        assert len(output) == length

    return output
def compact(input_dict):
    """Return a new dict holding only the entries whose value is not None."""
    return {
        name: item
        for name, item in input_dict.items()
        if item is not None
    }
def maybe_transform_dict_key(input_dict, key, fn):
    """Apply `fn` to the value stored under `key`, if that key is present.

    Returns a shallow copy of `input_dict` with the transformed value; the
    input dict itself is never modified. When `key` is absent the original
    dict object is returned as-is.
    """
    if key in input_dict:
        updated = input_dict.copy()
        updated[key] = fn(updated[key])
        return updated
    return input_dict
def cast_uint8_images_to_float(images):
    """Scale uint8 tensors by 1/255; return tensors of any other dtype untouched."""
    if images.dtype == torch.uint8:
        return images / 255
    return images
def module_device(module):
    """Return the device a module lives on.

    The device of the first parameter is used. Modules without parameters
    fall back to the device of their first buffer, and modules with neither
    fall back to CPU, instead of leaking a bare StopIteration from `next`.
    """
    first_param = next(module.parameters(), None)
    if first_param is not None:
        return first_param.device

    first_buffer = next(module.buffers(), None)
    if first_buffer is not None:
        return first_buffer.device

    # nothing to inspect: tensor-free modules are treated as CPU modules
    return torch.device('cpu')
def zero_init_(m):
    """Zero `m.weight` in place, and `m.bias` too when it is not None."""
    nn.init.zeros_(m.weight)
    bias = m.bias
    if bias is not None:
        nn.init.zeros_(bias)
def eval_decorator(fn):
    """Decorate `fn(model, ...)` so that it runs with the model in eval mode.

    The wrapper records `model.training`, calls `model.eval()`, runs `fn`
    and then restores the recorded mode through `model.train(...)`. The
    restore sits in a `finally` clause so an exception raised by `fn` can
    no longer leave the model stuck in eval mode. `wraps` keeps the name
    and docstring of `fn` on the wrapper.
    """
    @wraps(fn)
    def inner(model, *args, **kwargs):
        was_training = model.training
        model.eval()
        try:
            return fn(model, *args, **kwargs)
        finally:
            model.train(was_training)

    return inner
def pad_tuple_to_length(t, length, fillvalue=None):
    """Right-pad `t` with `fillvalue` until it holds `length` items.

    Inputs that are already at least `length` long are returned unchanged
    (the same object); otherwise a new tuple is built.
    """
    missing = length - len(t)
    if missing > 0:
        return tuple(t) + (fillvalue,) * missing
    return t
# helper classes
class Identity(nn.Module):
    """Module that returns its first input unchanged.

    Both the constructor and `forward` accept and ignore arbitrary extra
    positional and keyword arguments.
    """

    def __init__(self, *_args, **_kwargs):
        # constructor arguments are deliberately discarded
        super().__init__()

    def forward(self, x, *_args, **_kwargs):
        """Return `x` as-is, ignoring every other argument."""
        return x
# tensor helpers
def log(t, eps: float = 1e-12):
    """Natural log of `t` after clamping it from below at `eps`.

    The clamp keeps zero and negative entries from producing -inf or NaN.
    """
    clamped = t.clamp(min=eps)
    return torch.log(clamped)
def l2norm(t):
    """Normalize `t` to unit L2 norm along its last dimension."""
    return F.normalize(t, p=2.0, dim=-1)
def right_pad_dims_to(x, t):
    """Append trailing size-1 dimensions to `t` until it has as many dims as `x`.

    `t` is returned unchanged when it already has at least as many
    dimensions as `x`.
    """
    extra_dims = x.ndim - t.ndim
    if extra_dims > 0:
        trailing_ones = (1,) * extra_dims
        return t.view(*t.shape, *trailing_ones)
    return t
def masked_mean(t, *, dim, mask=None):
    """Mean of `t` over `dim`, counting only positions where `mask` is True.

    Without a mask this is a plain `t.mean(dim=dim)`. With a mask, masked-out
    positions are zeroed before summing and the sum is divided by the number
    of kept positions (clamped from below at 1e-5 so an all-False mask does
    not divide by zero). The mask must be 2-dimensional (the `'b n -> b n 1'`
    pattern enforces it); presumably `t` is (batch, n, dim) - TODO confirm
    against callers.
    """
    if mask is None:
        return t.mean(dim=dim)

    # count the kept positions before the mask gains its trailing axis
    kept_count = mask.sum(dim=dim, keepdim=True)
    expanded_mask = rearrange(mask, 'b n -> b n 1')
    summed = t.masked_fill(~expanded_mask, 0.).sum(dim=dim)
    return summed / kept_count.clamp(min=1e-5)
def resize_image_to(
    image,
    target_image_size,
    clamp_range=None,
    mode='nearest'
):
    """Resize `image` to `target_image_size` with `F.interpolate`.

    Only the last dimension of `image` is compared against the target; when
    it already matches, the input is returned untouched (and unclamped).
    Otherwise the interpolated result is returned, clamped to
    `clamp_range` (unpacked as the arguments of `Tensor.clamp`) when one
    is given.
    """
    if image.shape[-1] == target_image_size:
        return image

    resized = F.interpolate(image, target_image_size, mode=mode)
    if clamp_range is None:
        return resized
    return resized.clamp(*clamp_range)
def calc_all_frame_dims(
    downsample_factors: List[int],
    frames
):
    """Compute the frame count left after each downsample factor.

    Returns one 1-tuple `(frames // factor,)` per factor, collected in a
    list. When `frames` is None, a tuple of empty tuples (one per factor) is
    returned instead. Raises AssertionError if `frames` is not evenly
    divisible by one of the factors.
    """
    if frames is None:
        return (tuple(),) * len(downsample_factors)

    frame_dims_per_factor = []
    for factor in downsample_factors:
        assert (frames % factor) == 0
        frame_dims_per_factor.append((frames // factor,))

    return frame_dims_per_factor
def safe_get_tuple_index(tup, index, default=None):
    """Return `tup[index]`, or `default` when `index` is out of range.

    Both ends are checked: an index at or beyond `len(tup)` and a negative
    index below `-len(tup)` yield `default` rather than an IndexError.
    In-range negative indices keep their usual from-the-end meaning.
    """
    size = len(tup)
    if -size <= index < size:
        return tup[index]
    return default
# image normalization functions
# DDPMs expect images to be in the range [-1, 1]
def normalize_neg_one_to_one(img):
    """Apply `img * 2 - 1`, which maps values in [0, 1] onto [-1, 1]."""
    doubled = img * 2
    return doubled - 1
def unnormalize_zero_to_one(normed_img):
    """Apply `(normed_img + 1) * 0.5`, which maps values in [-1, 1] onto [0, 1]."""
    shifted = normed_img + 1
    return shifted * 0.5
# classifier free guidance functions
def prob_mask_like(shape, prob, device):
    """Build a boolean mask of `shape` whose entries are True with probability `prob`.

    `prob == 1` and `prob == 0` short-circuit to all-True / all-False masks
    without drawing random numbers; any other value compares uniform samples
    from [0, 1) against `prob`.
    """
    if prob == 1:
        return torch.ones(shape, device=device, dtype=torch.bool)
    if prob == 0:
        return torch.zeros(shape, device=device, dtype=torch.bool)

    uniform_noise = torch.zeros(shape, device=device).float().uniform_(0, 1)
    return uniform_noise < prob
# gaussian diffusion with continuous time helper functions and classes
# large part of this was thanks to @crowsonkb at https://github.com/crowsonkb/v-diffusion-jax/blob/master/diffusion/utils.py
@torch.jit.script
def beta_linear_log_snr(t):
    """Return -log(expm1(1e-4 + 10 * t^2)) elementwise for a tensor `t`.

    Going by the function name this is the log-SNR of a linear beta noise
    schedule, with `t` presumably a continuous time in [0, 1] - TODO confirm.
    """
    exponent = 1e-4 + 10 * (t ** 2)
    return -torch.log(expm1(exponent))
@torch.jit.script
def alpha_cosine_log_snr(t, s: float = 0.008):
    """Return -log(cos((t + s) / (1 + s) * pi / 2)^-2 - 1) elementwise.

    The argument of the log is clamped from below at 1e-5 before the log is
    taken. Going by the function name this is the log-SNR of a cosine alpha
    schedule with offset `s`; `t` is presumably a continuous time in [0, 1]
    - TODO confirm.
    """
    angle = (t + s) / (1 + s) * math.pi * 0.5
    inverse_cos_squared = torch.cos(angle) ** -2
    # not sure if this accounts for beta being clipped to 0.999 in discrete version
    # same clamp-then-log as the module-level `log` helper, written inline
    return -torch.log((inverse_cos_squared - 1).clamp(min=1e-5))
def log_snr_to_alpha_sigma(log_snr):
    """Convert a log-SNR tensor into the pair `(alpha, sigma)`.

    alpha = sqrt(sigmoid(log_snr)) and sigma = sqrt(sigmoid(-log_snr)), so
    alpha^2 + sigma^2 equals 1.
    """
    alpha = torch.sigmoid(log_snr).sqrt()
    sigma = torch.sigmoid(-log_snr).sqrt()
    return alpha, sigma