-
Notifications
You must be signed in to change notification settings - Fork 35
/
utils.py
39 lines (33 loc) · 1.5 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
# Borrowed from https://github.com/EvelynFan/FaceFormer/blob/main/faceformer.py
import torch
import math
# Temporal Bias
def init_biased_mask(n_head, max_seq_len, period):
    """Build a periodic ALiBi-style causal attention bias (FaceFormer).

    For head ``h``, a query at position ``i`` attending to a key at position
    ``j <= i`` receives the additive bias ``-slope[h] * ((i - j) // period)``,
    so keys within the same ``period``-sized window behind the query share
    one penalty. Keys in the future (``j > i``) receive ``-inf``.

    Args:
        n_head: Number of attention heads; one slope is generated per head.
        max_seq_len: Side length of the square bias matrix.
        period: Number of consecutive positions sharing one distance penalty.
            Must be a positive integer.

    Returns:
        Float tensor of shape ``(n_head, max_seq_len, max_seq_len)``:
        ``-inf`` strictly above the diagonal and
        ``-slope[h] * ((i - j) // period)`` on and below it.

    Raises:
        ValueError: If ``period`` is less than 1.
    """
    if period < 1:
        raise ValueError(f"period must be a positive integer, got {period}")

    def get_slopes(n):
        # Per-head slopes as in the ALiBi reference implementation: a
        # geometric sequence for power-of-two head counts; otherwise the
        # nearest lower power of two, topped up with every other slope of
        # the next power of two.
        def get_slopes_power_of_2(n):
            start = 2 ** (-2 ** -(math.log2(n) - 3))
            ratio = start
            return [start * ratio ** i for i in range(n)]

        if math.log2(n).is_integer():
            return get_slopes_power_of_2(n)
        closest_power_of_2 = 2 ** math.floor(math.log2(n))
        return (get_slopes_power_of_2(closest_power_of_2)
                + get_slopes(2 * closest_power_of_2)[0::2][:n - closest_power_of_2])

    slopes = torch.Tensor(get_slopes(n_head))

    # distance[i, j] = i - j: how many steps key j lies behind query i.
    positions = torch.arange(max_seq_len)
    distance = positions.unsqueeze(1) - positions.unsqueeze(0)
    # Vectorised form of a per-row fill loop. Clamping leaves future entries
    # (j > i) at 0; the causal mask below adds -inf to them.
    alibi = (-(distance.clamp(min=0) // period)).to(slopes.dtype)
    alibi = slopes.view(-1, 1, 1) * alibi.unsqueeze(0)

    # Causal mask: 0 on and below the diagonal, -inf strictly above it.
    causal = torch.full((max_seq_len, max_seq_len), float('-inf')).triu(diagonal=1)
    return causal.unsqueeze(0) + alibi
# Alignment Bias
def enc_dec_mask(device, T, S):
    """Build a diagonal alignment mask for decoder-to-encoder attention.

    Only the aligned pairs ``(i, i)`` are left unmasked.

    Args:
        device: Device to create the mask on (anything ``torch`` accepts,
            e.g. ``"cpu"`` or a ``torch.device``).
        T: Number of query (decoder) steps.
        S: Number of key (encoder) steps. Must be at least ``T``.

    Returns:
        Bool tensor of shape ``(T, S)`` on ``device``: ``False`` at ``(i, i)``
        and ``True`` everywhere else. NOTE(review): assumes the consumer
        treats ``True`` as "may not attend" (PyTorch's boolean attn-mask
        convention) — confirm against the caller.

    Raises:
        IndexError: If ``T > S``. Rows ``i >= S`` would have no aligned key.
    """
    if T > S:
        raise IndexError(f"enc_dec_mask requires T <= S, got T={T}, S={S}")
    # One op directly on the target device, instead of a Python loop of
    # per-element writes followed by a host-to-device copy.
    return torch.eye(T, S, device=device) == 0