Make linter happy
vgokhale committed May 16, 2024
1 parent 61109c4 commit bf51b69
Showing 1 changed file with 6 additions and 7 deletions.

python/perf-kernels/flash-attention.py (13 changes: 6 additions & 7 deletions)
@@ -332,7 +332,7 @@ def attn_fwd(Q, K, V, bias, sm_scale, L, Out, stride_qz, stride_qh, stride_qm, s
              dropout_p, philox_seed, philox_offset_base, encoded_softmax, alibi_slopes, HQ: tl.constexpr,
              HK: tl.constexpr, ACTUAL_BLOCK_DMODEL: tl.constexpr, MAX_SEQLENS_Q: tl.constexpr,
              MAX_SEQLENS_K: tl.constexpr, VARLEN: tl.constexpr, IS_CAUSAL: tl.constexpr, BLOCK_M: tl.constexpr,
-             BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, PRE_LOAD_V: tl.constexpr, BIAS_TYPE: tl.constexpr,
+             BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, PRE_LOAD_V: tl.constexpr, USE_BIAS: tl.constexpr,
              ENABLE_DROPOUT: tl.constexpr, RETURN_ENCODED_SOFTMAX: tl.constexpr, USE_ALIBI: tl.constexpr):
     start_m = tl.program_id(0)
     off_h_q = tl.program_id(1)
@@ -880,8 +880,7 @@ def forward(ctx, q, k, v, o, metadata):
     metadata.check_args(q, k, v, o)

     batch, nheads_q, nheads_k, head_size = get_shape_from_layout(q, k, metadata)
-    q_strides, k_strides, v_strides, o_strides = get_strides_from_layout(
-        q, k, v, o, metadata)
+    q_strides, k_strides, v_strides, o_strides = get_strides_from_layout(q, k, v, o, metadata)

     # Get closest power of 2 over or equal to 32.
     padded_d_model = 1 << (head_size - 1).bit_length()
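
As an aside on the context line above: 1 << (n - 1).bit_length() is a standard bit trick that rounds n up to the nearest power of two (the minimum-of-32 part of the comment is presumably handled outside the lines shown). A minimal standalone sketch, not taken from this file:

    # Sketch: smallest power of two >= n, via int.bit_length().
    def next_pow2(n: int) -> int:
        return 1 << (n - 1).bit_length()

    assert next_pow2(32) == 32   # already a power of two
    assert next_pow2(40) == 64   # rounded up
    assert next_pow2(1) == 1     # (0).bit_length() == 0, so 1 << 0 == 1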
@@ -1115,9 +1114,9 @@ def test_op_fwd(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, causal, use_alibi, layout,

     # Transpose here if layout is bshd so we have same reference code for all layouts
     if layout == 'bshd':
-        q = q.transpose(1,2).clone()
-        k = k.transpose(1,2).clone()
-        v = v.transpose(1,2).clone()
+        q = q.transpose(1, 2).clone()
+        k = k.transpose(1, 2).clone()
+        v = v.transpose(1, 2).clone()
     # Replicate K and V if using MQA/GQA
     if HQ != HK:
         k = k.view(k.shape[0], k.shape[1], -1, k.shape[2],
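
The truncated k.view(...) context line above begins the K/V replication that lets MQA/GQA inputs be checked against a plain multi-head reference. A hedged sketch of the usual expand-and-fold pattern (the helper name and the [batch, heads, seqlen, dim] shape are assumptions, not code from this file):

    import torch

    def replicate_kv(k: torch.Tensor, hq: int) -> torch.Tensor:
        # Assumed shape after the bshd -> bhsd transpose: [batch, nheads_k, seqlen, head_dim].
        b, hk, s, d = k.shape
        group = hq // hk  # query heads served by each KV head
        # Insert a group axis, broadcast it, then fold it into the head axis.
        return k[:, :, None, :, :].expand(b, hk, group, s, d).reshape(b, hq, s, d)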
@@ -1142,7 +1141,7 @@ def test_op_fwd(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, causal, use_alibi, layout,
     ref_out = torch.einsum('bhqk,bhkd->bhqd', p.half(), v)
     # compare
     if layout == 'bshd':
-        ref_out = ref_out.transpose(1,2).clone()
+        ref_out = ref_out.transpose(1, 2).clone()
     torch.testing.assert_close(ref_out, tri_out, atol=2e-2, rtol=2e-2)
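
For readers unfamiliar with the einsum in the context above, 'bhqk,bhkd->bhqd' is simply a batched matmul of the attention probabilities with V; a small illustrative equivalence check (not part of this file):

    import torch

    b, h, q_len, k_len, d = 2, 4, 8, 8, 16
    p = torch.rand(b, h, q_len, k_len)  # attention probabilities
    v = torch.rand(b, h, k_len, d)      # values

    out_einsum = torch.einsum('bhqk,bhkd->bhqd', p, v)
    out_matmul = p @ v  # same contraction over the k axis
    torch.testing.assert_close(out_einsum, out_matmul)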

