Commit

Fix bug in bias test
vgokhale committed May 20, 2024
1 parent bf51b69 commit 4c06dc8
Showing 1 changed file with 1 addition and 5 deletions.
python/perf-kernels/flash-attention.py (1 addition & 5 deletions)

@@ -1170,18 +1170,14 @@ def test_op_fwd_bias(Z, H, N_CTX_Q, N_CTX_K, D_HEAD, causal, use_bias, dtype=tor
     torch.manual_seed(20)
     sm_scale = D_HEAD**-0.5
     input_metadata = MetaData(sm_scale=sm_scale)
-    input_metadata.max_seqlens_q = N_CTX_Q
-    input_metadata.max_seqlens_k = N_CTX_K
+    q, k, v, input_metadata = input_helper(Z, H, H, N_CTX_Q, N_CTX_K, D_HEAD, dtype, layout='bhsd')
     if causal:
         input_metadata.need_causal()
     if use_bias:
         bias = torch.randn((1, H, N_CTX_Q, N_CTX_K), dtype=torch.float32, device="cuda")
         input_metadata.need_bias(bias, Z, H, N_CTX_Q, N_CTX_K)
     else:
         bias = None
-    q = torch.randn((Z, H, N_CTX_Q, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_()
-    k = torch.randn((Z, H, N_CTX_K, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_()
-    v = torch.randn((Z, H, N_CTX_K, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_()
     o = torch.empty_like(q)

     # triton implementation
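For readers skimming the diff: the single added line replaces both the manual max_seqlens_q / max_seqlens_k bookkeeping and the hand-rolled q/k/v allocation in the test. Below is a minimal sketch of what a helper like input_helper is assumed to provide; the stub MetaData class, the tensor initialization, and the parameter names are illustrative assumptions, not the repository's actual implementation.

import torch
from dataclasses import dataclass

@dataclass
class _MetaDataStub:
    # Stand-in for the MetaData class defined in flash-attention.py (assumption).
    sm_scale: float
    max_seqlens_q: int = 0
    max_seqlens_k: int = 0

def input_helper_sketch(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype, layout='bhsd'):
    # Hypothetical helper: allocate q/k/v in 'bhsd' (batch, heads, seqlen, head_dim)
    # layout and fill in the sequence-length metadata, so the test no longer does it by hand.
    assert layout == 'bhsd'
    q = torch.randn((Z, HQ, N_CTX_Q, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
    k = torch.randn((Z, HK, N_CTX_K, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
    v = torch.randn((Z, HK, N_CTX_K, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
    meta = _MetaDataStub(sm_scale=D_HEAD**-0.5, max_seqlens_q=N_CTX_Q, max_seqlens_k=N_CTX_K)
    return q, k, v, meta

# Example usage mirroring the new test line (requires a CUDA device; values are arbitrary):
# q, k, v, input_metadata = input_helper_sketch(4, 16, 16, 128, 128, 64, torch.float16, layout='bhsd')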
