Replace SelfAttention with MultiHeadDotProductAttention to silence warnings.

PiperOrigin-RevId: 680333531
Scenic Authors committed Sep 29, 2024
1 parent b712494 commit f9bd9e6
Showing 1 changed file with 1 addition and 1 deletion.
scenic/projects/baselines/clip/layers.py
```diff
@@ -251,7 +251,7 @@ class ResidualAttentionBlock(nn.Module):
   @nn.compact
   def __call__(self, x: jnp.ndarray, attn_mask=None) -> jnp.ndarray:
     xn = LayerNorm(name='ln_1')(x)
-    x = x + nn.SelfAttention(
+    x = x + nn.MultiHeadDotProductAttention(
         self.num_heads, name='attn', deterministic=True)(xn, attn_mask)
     xn = LayerNorm(name='ln_2')(x)
     x = x + MLP(name='mlp')(xn)
```
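For context, a minimal sketch of why the swap is behavior-preserving (an illustrative example, not part of this commit, assuming a recent Flax release): `nn.SelfAttention` is a thin subclass of `nn.MultiHeadDotProductAttention` that reuses the query input as keys and values, and calling it now emits a deprecation warning. Calling `nn.MultiHeadDotProductAttention` with a single input computes the same self-attention without the warning. The shapes and `num_heads` value below are arbitrary.

```python
# Illustrative sketch, assuming a recent flax.linen API in which
# MultiHeadDotProductAttention derives keys and values from the query
# input when only one input is passed.
import jax
import jax.numpy as jnp
import flax.linen as nn

num_heads = 4
x = jnp.ones((1, 16, 32))  # (batch, sequence length, features); arbitrary

# Old form: emits a DeprecationWarning in recent Flax releases.
old_attn = nn.SelfAttention(num_heads=num_heads, deterministic=True)

# New form: the general attention module used as self-attention.
new_attn = nn.MultiHeadDotProductAttention(
    num_heads=num_heads, deterministic=True)

# SelfAttention defines no parameters of its own, so the two modules
# share the same parameter structure and the same params apply to both.
params = old_attn.init(jax.random.PRNGKey(0), x)
y_old = old_attn.apply(params, x)
y_new = new_attn.apply(params, x)  # queries, keys, values all come from x
assert jnp.allclose(y_old, y_new)
```

Because `SelfAttention` contributes no parameters of its own, existing checkpoints keyed under `name='attn'` should load unchanged; only the warning goes away.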
