diff --git a/tests/test_x_transformers.py b/tests/test_x_transformers.py
index cd1be33d..18c03b48 100644
--- a/tests/test_x_transformers.py
+++ b/tests/test_x_transformers.py
@@ -388,7 +388,8 @@ def test_neo_mlp():
     out = mlp(x)
     assert out.shape == (3, 7)
 
-def test_custom_alibi():
+@pytest.mark.parametrize('flash', (True, False))
+def test_custom_alibi(flash: bool):
 
     model = TransformerWrapper(
         num_tokens = 20_000,
@@ -397,7 +398,8 @@ def test_custom_alibi():
             dim = 512,
             depth = 2,
             heads = 8,
-            alibi_pos_bias = True
+            alibi_pos_bias = True,
+            attn_flash = flash
         )
     )
 
@@ -407,7 +409,8 @@ def test_custom_alibi():
 
     logits = model(x, pos = pos)
 
-def test_custom_alibi_across_heads():
+@pytest.mark.parametrize('flash', (True, False))
+def test_custom_alibi_across_heads(flash: bool):
 
     model = Decoder(
         dim = 512,
@@ -417,6 +420,7 @@ def test_custom_alibi_across_heads():
         rel_pos_kwargs = dict(
             slopes = [1, 1]
         ),
+        attn_flash = flash
     )
 
     x = torch.randn(2, 4, 512)
diff --git a/x_transformers/attend.py b/x_transformers/attend.py
index d354f915..c2bad988 100644
--- a/x_transformers/attend.py
+++ b/x_transformers/attend.py
@@ -370,7 +370,7 @@ def flash_attn(
         # convert from bool to float
 
        if exists(attn_bias):
-            attn_bias = rearrange(attn_bias, 'h i j -> 1 h i j').expand(batch, heads, -1, -1)
+            attn_bias = attn_bias.expand(batch, heads, -1, -1)
 
         # if mask given, the mask would already contain the causal mask from above logic
         # otherwise, if no mask given but still causal, mask out alibi positional bias to a large negative number
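
Note on the attend.py change (not part of the patch): dropping the rearrange presumably works because torch.Tensor.expand can prepend new leading dimensions, so both a per-head (heads, i, j) ALiBi bias and a bias that already carries a batch dimension (batch, heads, i, j), as can arise when custom pos tensors are passed, broadcast to the 4-D shape flash attention expects. A minimal sketch of that expand behaviour in plain PyTorch; the tensor names below are illustrative, not taken from the library:

import torch

batch, heads, i, j = 2, 8, 4, 4

# per-head bias with no batch dimension (the usual ALiBi layout)
bias_3d = torch.randn(heads, i, j)
# bias that already has a batch dimension (e.g. built from custom positions)
bias_4d = torch.randn(batch, heads, i, j)

# expand prepends missing leading dims and leaves -1 dims unchanged,
# so one call handles both layouts without an explicit rearrange
assert bias_3d.expand(batch, heads, -1, -1).shape == (batch, heads, i, j)
assert bias_4d.expand(batch, heads, -1, -1).shape == (batch, heads, i, j)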