Support custom positions for rotary positional embedding (#301)
custom positions across heads for rotary
Aceticia authored Nov 30, 2024
1 parent 8c993f0 commit c493eb4
Showing 2 changed files with 28 additions and 4 deletions.
tests/test_x_transformers.py: 22 additions & 1 deletion
@@ -409,9 +409,30 @@ def test_custom_alibi(flash: bool):
 
     logits = model(x, pos = pos)
 
+def test_custom_rotary_pos_emb():
+    from einops import repeat
+
+    model = TransformerWrapper(
+        num_tokens = 20_000,
+        max_seq_len = 1024,
+        attn_layers = Decoder(
+            dim = 512,
+            depth = 2,
+            heads = 8,
+            rotary_pos_emb = True
+        )
+    )
+
+    x = torch.randint(0, 20000, (4, 4))
+
+    pos = repeat(torch.arange(0, 4), "n -> b n", b=4)
+
+    logits1 = model(x, pos = pos)
+    logits2 = model(x)
+    assert torch.allclose(logits1, logits2)
+
 @pytest.mark.parametrize('flash', (True, False))
 def test_custom_alibi_across_heads(flash: bool):
 
     model = Decoder(
         dim = 512,
         depth = 2,
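The new test only checks that explicitly passing the default positions (0..n-1) reproduces the default output. A minimal sketch, not part of the commit, of the actual use case, using the TransformerWrapper/Decoder API shown above; the offset of 100 and the sequence length of 16 are illustrative only:

import torch
from einops import repeat
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 20_000,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 2,
        heads = 8,
        rotary_pos_emb = True
    )
)

x = torch.randint(0, 20_000, (4, 16))

# custom positions, one row per batch element, e.g. continuing after a
# 100-token cached prefix (hypothetical scenario, not from the commit)
pos = repeat(torch.arange(100, 116), 'n -> b n', b = 4)

logits = model(x, pos = pos)   # rotary angles are derived from `pos`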
x_transformers/x_transformers.py: 6 additions & 3 deletions
@@ -655,7 +655,10 @@ def forward_from_seq_len(self, seq_len):
     def forward(self, t):
         max_pos = t.max() + 1
 
-        freqs = torch.einsum('i , j -> i j', t.type_as(self.inv_freq), self.inv_freq) / self.interpolation_factor
+        if t.ndim == 1:
+            t = rearrange(t, 'n -> 1 n')
+
+        freqs = torch.einsum('b i , j -> b i j', t.type_as(self.inv_freq), self.inv_freq) / self.interpolation_factor
         freqs = torch.stack((freqs, freqs), dim = -1)
         freqs = rearrange(freqs, '... d r -> ... (d r)')
 
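With the guard above, a 1-d position tensor is promoted to shape (1, n), so the einsum now always produces per-batch frequencies of shape (b, n, dim_head). A standalone shape sketch, not from the commit, assuming a head dimension of 64, the usual rotary inverse-frequency definition, and an interpolation factor of 1:

import torch
from einops import rearrange

dim_head = 64
inv_freq = 1. / (10000 ** (torch.arange(0, dim_head, 2).float() / dim_head))   # (32,)

pos = torch.arange(8).float()         # a 1-d position tensor, shape (n,)
if pos.ndim == 1:
    pos = rearrange(pos, 'n -> 1 n')  # promoted to (1, n), as in the patch

freqs = torch.einsum('b i , j -> b i j', pos, inv_freq)   # (1, n, 32)
freqs = torch.stack((freqs, freqs), dim = -1)             # (1, n, 32, 2)
freqs = rearrange(freqs, '... d r -> ... (d r)')          # (1, n, 64), pairs interleaved

print(freqs.shape)   # torch.Size([1, 8, 64])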
@@ -679,8 +682,8 @@ def rotate_half(x):
 def apply_rotary_pos_emb(t, freqs, scale = 1):
     rot_dim, seq_len, orig_dtype = freqs.shape[-1], t.shape[-2], t.dtype
 
-    freqs = freqs[-seq_len:, :]
-    scale = scale[-seq_len:, :] if isinstance(scale, torch.Tensor) else scale
+    freqs = freqs[:, -seq_len:, :]
+    scale = scale[:, -seq_len:, :] if isinstance(scale, torch.Tensor) else scale
 
     if t.ndim == 4 and freqs.ndim == 3:
         freqs = rearrange(freqs, 'b n d -> b 1 n d')
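Since freqs now carries a leading batch dimension, the slice moves to the sequence axis, and a (b, n, d) frequency tensor is broadcast across the heads of a (b, h, n, d) query/key tensor. A shape-level sketch, not from the commit, assuming the whole head dimension is rotated (rot_dim == dim_head) and scale = 1; rotate_half is restated here only to keep the snippet self-contained:

import torch
from einops import rearrange

def rotate_half(x):
    # pair-wise (x1, x2) -> (-x2, x1) on the interleaved layout built above
    x = rearrange(x, '... (d r) -> ... d r', r = 2)
    x1, x2 = x.unbind(dim = -1)
    x = torch.stack((-x2, x1), dim = -1)
    return rearrange(x, '... d r -> ... (d r)')

b, h, n, d = 4, 8, 16, 64
t = torch.randn(b, h, n, d)     # queries or keys: (batch, heads, seq, dim_head)
freqs = torch.randn(b, n, d)    # per-batch rotary angles: (batch, seq, dim_head)

freqs = freqs[:, -n:, :]                       # slice the sequence axis, as in the patch
freqs = rearrange(freqs, 'b n d -> b 1 n d')   # insert a heads axis of size 1

rotated = t * freqs.cos() + rotate_half(t) * freqs.sin()   # broadcasts over heads
print(rotated.shape)   # torch.Size([4, 8, 16, 64])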
