diff --git a/tests/test_mem_eff_attention.py b/tests/test_mem_eff_attention.py
index 53dc7d54dd..f095bf7de7 100644
--- a/tests/test_mem_eff_attention.py
+++ b/tests/test_mem_eff_attention.py
@@ -1116,12 +1116,12 @@ def test_custom_scale(opBW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv):
         device,
         dtype,
         _,
-        _,
+        B,
         q_len,
         kv_len,
-        _,
+        H,
         k,
-        _,
+        Kv,
     ) = opBW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv
     torch.manual_seed(q_len + kv_len + k)
     if device != "cuda":
@@ -1134,7 +1134,7 @@ def test_custom_scale(opBW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv):
         query=query, key=key, value=value, attn_bias=attn_bias, scale=scale
     )
     op_fw = sample_random_supported_fw(inputs, seed=q_len * k + kv_len * k)
-    grad_out = torch.ones_like(query)
+    grad_out = query.new_ones(B * H, q_len, Kv)
     query.requires_grad_(True)
     key.requires_grad_(True)
     value.requires_grad_(True)