Merge pull request #36 from GavChap/fix-sdp-attention-mask
* Fix SDP attention mask by creating a diagonal block mask in the right shape.
city96 authored May 16, 2024
2 parents 1e5dc8e + 0f88894 commit 296411c
Showing 1 changed file with 12 additions and 6 deletions.
18 changes: 12 additions & 6 deletions PixArt/models/PixArt_blocks.py
@@ -70,14 +70,20 @@ def forward(self, x, cond, mask=None):
         q, k, v = map(lambda t: t.permute(0, 2, 1, 3),(q, k, v),)
         attn_mask = None
         if mask is not None and len(mask) > 1:
-            # This is most definitely wrong, especially for B>1
-            attn_mask = torch.zeros(
-                [1, q.shape[1], q.shape[2], v.shape[2]],
-                dtype=q.dtype,
+            # Create equivalent of xformer diagonal block mask, still only correct for square masks
+            # But depth doesn't matter as tensors can expand in that dimension
+            attn_mask_template = torch.ones(
+                [q.shape[2] // B, mask[0]],
+                dtype=torch.bool,
                 device=q.device
             )
-            attn_mask[:, :, (q.shape[2]//2):, mask[0]:] = True
-            attn_mask[:, :, :(q.shape[2]//2), :mask[1]] = True
+            attn_mask = torch.block_diag(attn_mask_template)
+
+            # create a mask on the diagonal for each mask in the batch
+            for n in range(B - 1):
+                attn_mask = torch.block_diag(attn_mask, attn_mask_template)
+
         x = torch.nn.functional.scaled_dot_product_attention(
             q, k, v,
             attn_mask=attn_mask,
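The change replaces the old rectangular mask (which, per its own comment, was wrong for batch size > 1) with a boolean block-diagonal mask built via torch.block_diag, mimicking the xformers diagonal block mask: each sample's queries can only attend to that sample's conditioning tokens. Below is a minimal standalone sketch of the same idea; shapes and names are illustrative toy values, not code from the repository.

    import torch
    import torch.nn.functional as F

    # Assumed toy sizes: B samples, each with q_len query tokens and kv_len key/value tokens.
    B, heads, q_len, kv_len, head_dim = 2, 4, 6, 3, 8

    q = torch.randn(1, heads, B * q_len, head_dim)   # queries for all samples, concatenated
    k = torch.randn(1, heads, B * kv_len, head_dim)  # keys for all samples, concatenated
    v = torch.randn(1, heads, B * kv_len, head_dim)

    # One (q_len x kv_len) block of True per sample; block_diag places the blocks
    # on the diagonal, so queries of sample i only see keys of sample i.
    template = torch.ones([q_len, kv_len], dtype=torch.bool)
    attn_mask = template
    for _ in range(B - 1):
        attn_mask = torch.block_diag(attn_mask, template)  # -> (B*q_len, B*kv_len)

    # A boolean attn_mask broadcasts over the leading batch/head dims of q, k, v.
    out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
    print(out.shape)  # torch.Size([1, 4, 12, 8])

Growing the mask one block per iteration mirrors the loop in the committed code; calling torch.block_diag(*[template] * B) once would produce the same mask.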
