Running cuda...
Cuda time: 11700.296us
Running torch...
Torch time: 66341.925us
================================= RESULT =================================
Maximum error value = tensor(1.1921e-07, device='cuda:0')
CUDA Speedup = 5.67x
Kernel test passed.
==========================================================================
Running cuda...
Cuda time: 16315.532us
Running torch...
Torch time: 74056.602us
================================= RESULT =================================
Maximum error value = tensor(0.0017, device='cuda:0')
CUDA Speedup = 4.54x
Kernel test passed.
==========================================================================
Running cuda...
Cuda time: 4012.561us
Running torch...
Torch time: 18420.815us
================================= RESULT =================================
Maximum error value = tensor(0.0023, device='cuda:0')
CUDA Speedup = 4.59x
Kernel test passed.
==========================================================================
The main change is the logical mask implemented in kernel/fused_grouped_tril_mask_softmax_kernel.cu, which replaces memory accesses with computation; the logic lives in GroupedTrilMaskLoad and GroupedTrilMaskStore.
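To illustrate the idea: the mask bit for each element can be derived purely from its coordinates and the batch's valid length, so no boolean mask tensor ever needs to be read from global memory. A minimal sketch of the predicate (the function name and signature below are illustrative, not the kernel's actual device code):

```python
def keep_element(row: int, col: int, seq_len: int) -> bool:
    # Logical causal mask: an element is kept iff it lies inside the valid
    # seq_len x seq_len region and on or below the main diagonal. Computing
    # this predicate replaces a load of a precomputed boolean mask.
    return row < seq_len and col < seq_len and row >= col
```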
# On the cutlass machine ([email protected])
cd /data1/jerryao/code/Attention-Mask-Optimization
conda activate masked_opt
# Correctness
python fused_grouped_tril_mask_softmax_test.py
# Performance
python fused_grouped_tril_mask_softmax_bench.py
fused_grouped_tril_mask_softmax_test.py computes the maximum error, which is around 1.1921e-07.
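As a rough sketch of what such a correctness check does (the function name is assumed for illustration; the actual test script may differ), both implementations run on the same input and the largest element-wise deviation is reported. A value around 1.19e-07 is on the order of fp32 machine epsilon, i.e. the two results agree to rounding error:

```python
import torch

def max_abs_error(out_cuda: torch.Tensor, out_ref: torch.Tensor) -> torch.Tensor:
    # Largest element-wise absolute difference between the two results.
    return (out_cuda - out_ref).abs().max()
```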
import torch

device = torch.device("cpu")
dtype = torch.float32
B = 3
S = 5  # padded sequence length
H = 2
M = 2
SEQ_LENS = [3, 4, 5]  # valid length; differs for each batch

attn = torch.randn(B, M, S, S, device=device, dtype=dtype)
# fused_grouped_tril_mask_softmax reference: per batch, build a causal mask
# restricted to the valid seq_len x seq_len region, then softmax each row.
for b in range(B):
    i, j = attn.shape[-2:]
    causal_mask = torch.ones((i, j), dtype=torch.bool, device=attn.device)
    seq_len = SEQ_LENS[b]
    for row in range(i):
        for col in range(j):
            invalid = row >= seq_len or col >= seq_len
            need_load = not invalid and row >= col
            if need_load:
                causal_mask[row][col] = False
    attn[b] = attn[b].masked_fill(causal_mask, -torch.finfo(attn.dtype).max)
    attn[b] = attn[b].softmax(-1)
    # print(causal_mask)
# Fully masked rows softmax to the uniform value 1/S; zero them out.
attn[attn == 1 / S] = 0.0
# print(attn)
# fused_grouped_tril_mask_softmax
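The double Python loop above can also be expressed with broadcasting. This is a sketch, not part of the original script, assuming the same convention (True = masked out):

```python
import torch

def grouped_tril_mask(S: int, seq_len: int) -> torch.Tensor:
    # True where an element must be masked out: outside the valid
    # seq_len x seq_len region, or strictly above the main diagonal.
    rows = torch.arange(S).unsqueeze(1)  # shape (S, 1)
    cols = torch.arange(S).unsqueeze(0)  # shape (1, S)
    keep = (rows < seq_len) & (cols < seq_len) & (rows >= cols)
    return ~keep
```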
With the print statements uncommented, the reference script prints:
tensor([[False,  True,  True,  True,  True],
        [False, False,  True,  True,  True],
        [False, False, False,  True,  True],
        [ True,  True,  True,  True,  True],
        [ True,  True,  True,  True,  True]])
tensor([[False,  True,  True,  True,  True],
        [False, False,  True,  True,  True],
        [False, False, False,  True,  True],
        [False, False, False, False,  True],
        [ True,  True,  True,  True,  True]])
tensor([[False,  True,  True,  True,  True],
        [False, False,  True,  True,  True],
        [False, False, False,  True,  True],
        [False, False, False, False,  True],
        [False, False, False, False, False]])
tensor([[[[1.0000, 0.0000, 0.0000, 0.0000, 0.0000],
          [0.5596, 0.4404, 0.0000, 0.0000, 0.0000],
          [0.0146, 0.9730, 0.0125, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000]],

         [[1.0000, 0.0000, 0.0000, 0.0000, 0.0000],
          [0.1611, 0.8389, 0.0000, 0.0000, 0.0000],
          [0.1570, 0.3879, 0.4552, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000]]],


        [[[1.0000, 0.0000, 0.0000, 0.0000, 0.0000],
          [0.2416, 0.7584, 0.0000, 0.0000, 0.0000],
          [0.8261, 0.1336, 0.0404, 0.0000, 0.0000],
          [0.2205, 0.0954, 0.2523, 0.4318, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000]],

         [[1.0000, 0.0000, 0.0000, 0.0000, 0.0000],
          [0.4338, 0.5662, 0.0000, 0.0000, 0.0000],
          [0.1283, 0.0794, 0.7923, 0.0000, 0.0000],
          [0.3714, 0.3939, 0.1283, 0.1064, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000]]],


        [[[1.0000, 0.0000, 0.0000, 0.0000, 0.0000],
          [0.5161, 0.4839, 0.0000, 0.0000, 0.0000],
          [0.6362, 0.1044, 0.2594, 0.0000, 0.0000],
          [0.2404, 0.1837, 0.3760, 0.1999, 0.0000],
          [0.0307, 0.0916, 0.5037, 0.1961, 0.1780]],

         [[1.0000, 0.0000, 0.0000, 0.0000, 0.0000],
          [0.1792, 0.8208, 0.0000, 0.0000, 0.0000],
          [0.6914, 0.2153, 0.0933, 0.0000, 0.0000],
          [0.3500, 0.2563, 0.2432, 0.1505, 0.0000],
          [0.4088, 0.3471, 0.0822, 0.0416, 0.1203]]]])