Custom pos alibi flash attn fix (#300)
* handle custom pos for alibi in flash attn

* test for custom pos alibi+flash attn
Aceticia authored Nov 29, 2024
1 parent 57efd77 commit 4c3e62a
Showing 2 changed files with 8 additions and 4 deletions.
tests/test_x_transformers.py (7 additions, 3 deletions)
@@ -388,7 +388,8 @@ def test_neo_mlp():
     out = mlp(x)
     assert out.shape == (3, 7)
 
-def test_custom_alibi():
+@pytest.mark.parametrize('flash', (True, False))
+def test_custom_alibi(flash: bool):
 
     model = TransformerWrapper(
         num_tokens = 20_000,
@@ -397,7 +398,8 @@ def test_custom_alibi():
             dim = 512,
             depth = 2,
             heads = 8,
-            alibi_pos_bias = True
+            alibi_pos_bias = True,
+            attn_flash = flash
         )
     )
 
@@ -407,7 +409,8 @@ def test_custom_alibi():
 
     logits = model(x, pos = pos)
 
-def test_custom_alibi_across_heads():
+@pytest.mark.parametrize('flash', (True, False))
+def test_custom_alibi_across_heads(flash: bool):
 
     model = Decoder(
         dim = 512,
@@ -417,6 +420,7 @@ def test_custom_alibi_across_heads():
         rel_pos_kwargs = dict(
             slopes = [1, 1]
         ),
+        attn_flash = flash
     )
 
     x = torch.randn(2, 4, 512)
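For context, a minimal usage sketch mirroring the updated test: custom absolute positions go in through the pos keyword while ALiBi and flash attention are enabled together. The position values below are illustrative, not taken from the test.

import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 20_000,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 2,
        heads = 8,
        alibi_pos_bias = True,
        attn_flash = True   # exercises the flash attention path fixed in this commit
    )
)

x = torch.randint(0, 20_000, (2, 4))
pos = torch.tensor([[0, 1, 2, 4], [1, 3, 5, 7]])  # illustrative per-sample custom positions

logits = model(x, pos = pos)  # shape (2, 4, 20_000)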
x_transformers/attend.py (1 addition, 1 deletion)
@@ -370,7 +370,7 @@ def flash_attn(
         # convert from bool to float
 
         if exists(attn_bias):
-            attn_bias = rearrange(attn_bias, 'h i j -> 1 h i j').expand(batch, heads, -1, -1)
+            attn_bias = attn_bias.expand(batch, heads, -1, -1)
 
         # if mask given, the mask would already contain the causal mask from above logic
         # otherwise, if no mask given but still causal, mask out alibi positional bias to a large negative number
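The single changed line above is the fix itself: with default positions the ALiBi bias arrives shaped (heads, i, j), but with a custom per-sample pos the bias presumably already carries a batch dimension, (batch, heads, i, j), which the hard-coded rearrange(attn_bias, 'h i j -> 1 h i j') rejects. A plain .expand covers both layouts, since it adds missing leading dimensions and leaves existing ones untouched. A minimal standalone sketch of that behaviour, assuming those two bias shapes (not the library code itself):

import torch

batch, heads, i, j = 2, 8, 4, 4

bias_3d = torch.randn(heads, i, j)         # default ALiBi bias, no batch dim
bias_4d = torch.randn(batch, heads, i, j)  # bias built from per-sample custom positions

# expand prepends missing leading dims and keeps existing ones (-1),
# so one call handles both layouts, unlike the 3-dim-only rearrange pattern
assert bias_3d.expand(batch, heads, -1, -1).shape == (batch, heads, i, j)
assert bias_4d.expand(batch, heads, -1, -1).shape == (batch, heads, i, j)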
