diff --git a/python/perf-kernels/flash-attention.py b/python/perf-kernels/flash-attention.py
index 1d87700c8632..d36caaf61952 100644
--- a/python/perf-kernels/flash-attention.py
+++ b/python/perf-kernels/flash-attention.py
@@ -1095,7 +1095,6 @@ def varlen_input_helper(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype, equal_seqlen
 @pytest.mark.parametrize('layout', ['bshd', 'bhsd'])
 def test_op_fwd(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, causal, use_alibi, layout, dtype=torch.float16):
     torch.manual_seed(20)
-    # TODO: Adapt test for bshd
     q, k, v, input_metadata = input_helper(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype, layout)
     if causal:
         input_metadata.need_causal()