Skip to content

Commit

Permalink
tritonsrc/attn_torch_function: fix block size settings
Browse files: browse the repository at this point in the history
  • Loading branch information
xinyazhang committed Apr 21, 2024
1 parent 205d5a0 commit d6c6dcc
Showing 1 changed file with 14 additions and 20 deletions.
34 changes: 14 additions & 20 deletions tritonsrc/attn_torch_function.py
Original file line number Diff line number Diff line change
@@ -274,21 +274,19 @@ def forward(ctx, q, k, v, b, causal, sm_scale, dropout_p, return_encoded_softmax
 else:
 BIAS_TYPE = 1

-if is_meff:
-BLOCK_M = 1
-BLOCK_N = 1
+use_small_block = dropout_p > 0.0 or BIAS_TYPE != 0 or return_encoded_softmax
+use_medium_block = False # reserved
+if use_small_block:
+BLOCK_M = 64
+BLOCK_N = 32
+elif use_medium_block:
+BLOCK_M = 64
+BLOCK_N = 64
 else:
-use_small_block = dropout_p > 0.0 or BIAS_TYPE != 0
-use_medium_block = False # reserved
-if use_small_block:
-BLOCK_M = 64
-BLOCK_N = 32
-elif use_medium_block:
-BLOCK_M = 64
-BLOCK_N = 64
-else:
-BLOCK_M = 128
-BLOCK_N = 64
+BLOCK_M = 128
+BLOCK_N = 64
+if dtype == torch.float32:
+BLOCK_M //= 2

 if autotune:
 # assert False, "No time to test autotune for now"
@@ -314,12 +312,6 @@ def forward(ctx, q, k, v, b, causal, sm_scale, dropout_p, return_encoded_softmax
 BIAS_TYPE=BIAS_TYPE,
 )
 else:
-# BLOCK_M = min(MAX_BLOCK_M, q.shape[2], k.shape[2])
-# BLOCK_N = min(MAX_BLOCK_N, q.shape[2], k.shape[2])
-BLOCK_M = MAX_BLOCK_M
-BLOCK_N = MAX_BLOCK_N
-# BLOCK_M = 32
-# BLOCK_N = 32
 RETURN_ENCODED_SOFTMAX=encoded_softmax is not None
 print(f'{BLOCK_M=} {BLOCK_N=} {RETURN_ENCODED_SOFTMAX=} seqlen_q={q.shape[2]} seqlen_k={k.shape[2]}',
 flush=True)
@@ -474,6 +466,8 @@ def backward(ctx, do, _, fwd_tuning_result):
 else:
 BLOCK_M = 128
 BLOCK_N = 64
+if q.dtype == torch.float32:
+BLOCK_M //= 2
 # debug_mask = torch.zeros((q.shape[0], q.shape[1], max_seqlens_q, max_seqlens_k), device=q.device, dtype=ctx.encoded_softmax.dtype)
 grid_dk_dv = lambda META: (
 triton.cdiv(max_seqlens_k, META['BLOCK_N']),
Expand Down

0 comments on commit d6c6dcc

Please sign in to comment.