From fbb9b4b63b16c757b37db64d57f12b8a56ee52b6 Mon Sep 17 00:00:00 2001
From: Qiwei Ye
Date: Tue, 26 Jul 2022 17:24:37 +0800
Subject: [PATCH 1/2] align with paper

---
 flash_pytorch/flash_pytorch.py | 17 +++++++++--------
 1 file changed, 9 insertions(+), 8 deletions(-)

diff --git a/flash_pytorch/flash_pytorch.py b/flash_pytorch/flash_pytorch.py
index 197499f..975d86e 100644
--- a/flash_pytorch/flash_pytorch.py
+++ b/flash_pytorch/flash_pytorch.py
@@ -264,7 +264,8 @@ def forward(
         j - sequence dimension (target)
         """
 
-        b, n, device, g = x.shape[0], x.shape[-2], x.device, self.group_size
+        b, n, device, c = x.shape[0], x.shape[-2], x.device, self.group_size
+        g = math.ceil(n / c)
 
         # prenorm
 
@@ -299,24 +300,23 @@ def forward(
 
         # padding for groups
 
-        padding = padding_to_multiple_of(n, g)
+        padding = padding_to_multiple_of(n, c)
 
         if padding > 0:
             quad_q, quad_k, lin_q, lin_k, v = map(lambda t: F.pad(t, (0, 0, 0, padding), value = 0.), (quad_q, quad_k, lin_q, lin_k, v))
 
-            mask = default(mask, torch.ones((b, n), device = device, dtype = torch.bool))
             mask = F.pad(mask, (0, padding), value = False)
 
         # group along sequence
 
-        quad_q, quad_k, lin_q, lin_k, v = map(lambda t: rearrange(t, 'b (g n) d -> b g n d', n = self.group_size), (quad_q, quad_k, lin_q, lin_k, v))
+        quad_q, quad_k, lin_q, lin_k, v = map(lambda t: rearrange(t, 'b (g c) d -> b g c d', c = c), (quad_q, quad_k, lin_q, lin_k, v))
 
         if exists(mask):
-            mask = rearrange(mask, 'b (g j) -> b g 1 j', j = g)
+            mask = rearrange(mask, 'b (g c) -> b g 1 c', c = c)
 
         # calculate quadratic attention output
 
-        sim = einsum('... i d, ... j d -> ... i j', quad_q, quad_k) / g
+        sim = einsum('... i d, ... j d -> ... i j', quad_q, quad_k) / c
 
         sim = sim + self.rel_pos_bias(sim)
 
@@ -327,7 +327,7 @@ def forward(
             attn = attn.masked_fill(~mask, 0.)
 
         if self.causal:
-            causal_mask = torch.ones((g, g), dtype = torch.bool, device = device).triu(1)
+            causal_mask = torch.ones((c,c), dtype = torch.bool, device = device).triu(1)
             attn = attn.masked_fill(causal_mask, 0.)
 
         quad_out = einsum('... i j, ... j d -> ... i d', attn, v)
@@ -335,7 +335,7 @@ def forward(
         # calculate linear attention output
 
         if self.causal:
-            lin_kv = einsum('b g n d, b g n e -> b g d e', lin_k, v) / g
+            lin_kv = einsum('b g n d, b g n e -> b g d e', lin_k, v) / c
 
             # exclusive cumulative sum along group dimension
 
@@ -350,6 +350,7 @@ def forward(
 
         # fold back groups into full sequence, and excise out padding
 
+        # import pdb; pdb.set_trace()
         quad_attn_out, lin_attn_out = map(lambda t: rearrange(t, 'b g n d -> b (g n) d')[:, :n], (quad_out, lin_out))
 
         # gate

From 6ecc99f9e96deb3be33666b54c6d0cf96c8d9c57 Mon Sep 17 00:00:00 2001
From: Qiwei Ye
Date: Tue, 26 Jul 2022 22:38:11 +0800
Subject: [PATCH 2/2] delete unnecessary change

---
 flash_pytorch/flash_pytorch.py | 2 --
 1 file changed, 2 deletions(-)

diff --git a/flash_pytorch/flash_pytorch.py b/flash_pytorch/flash_pytorch.py
index 975d86e..8359345 100644
--- a/flash_pytorch/flash_pytorch.py
+++ b/flash_pytorch/flash_pytorch.py
@@ -265,7 +265,6 @@ def forward(
         """
 
         b, n, device, c = x.shape[0], x.shape[-2], x.device, self.group_size
-        g = math.ceil(n / c)
 
         # prenorm
 
@@ -350,7 +349,6 @@ def forward(
 
         # fold back groups into full sequence, and excise out padding
 
-        # import pdb; pdb.set_trace()
         quad_attn_out, lin_attn_out = map(lambda t: rearrange(t, 'b g n d -> b (g n) d')[:, :n], (quad_out, lin_out))
 
         # gate
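
Note on the series: the two patches net out to a rename. Since g already held self.group_size before the change, the divisions by g and by c hit the same quantity; the point is to reserve c for the fixed group/chunk size, as in the paper's notation, and let g stand only for the einops group axis. Below is a minimal, standalone sketch (not flash_pytorch's own forward pass) of the grouping and 1/c scaling the renamed variables refer to. The toy shapes, the re-derived padding_to_multiple_of helper, and the relu-squared attention stand-in are assumptions made for illustration.

# Minimal sketch, assuming toy shapes and a re-derived padding helper; this is
# only an illustration of grouping the sequence and scaling by the group size c.
import torch
import torch.nn.functional as F
from torch import einsum
from einops import rearrange

def padding_to_multiple_of(n, mult):
    # amount of right-padding needed so n becomes a multiple of mult
    remainder = n % mult
    return 0 if remainder == 0 else mult - remainder

b, n, d, c = 2, 300, 64, 128          # batch, sequence length, head dim, group size c
q, k, v = (torch.randn(b, n, d) for _ in range(3))

# pad the sequence so it splits evenly into groups of size c
padding = padding_to_multiple_of(n, c)
q, k, v = map(lambda t: F.pad(t, (0, 0, 0, padding), value = 0.), (q, k, v))

# group along the sequence: the g axis (number of groups) is implied by c
q, k, v = map(lambda t: rearrange(t, 'b (g c) d -> b g c d', c = c), (q, k, v))

# quadratic attention inside each group, scaled by the group size c
sim = einsum('... i d, ... j d -> ... i j', q, k) / c
attn = F.relu(sim) ** 2               # relu^2 attention, an assumed stand-in for the layer's attn_fn
out = einsum('... i j, ... j d -> ... i d', attn, v)

# fold the groups back into the full sequence and excise the padding
out = rearrange(out, 'b g c d -> b (g c) d')[:, :n]
print(out.shape)                      # torch.Size([2, 300, 64])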