From f9bd9e61182dd5e0503a304ad66ecbd90605bc08 Mon Sep 17 00:00:00 2001
From: Scenic Authors
Date: Sun, 29 Sep 2024 15:48:34 -0700
Subject: [PATCH] Replace SelfAttention with MultiHeadDotProductAttention to
 silence warnings.

PiperOrigin-RevId: 680333531
---
 scenic/projects/baselines/clip/layers.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/scenic/projects/baselines/clip/layers.py b/scenic/projects/baselines/clip/layers.py
index 6865e2423..3f8a05448 100644
--- a/scenic/projects/baselines/clip/layers.py
+++ b/scenic/projects/baselines/clip/layers.py
@@ -251,7 +251,7 @@ class ResidualAttentionBlock(nn.Module):
   @nn.compact
   def __call__(self, x: jnp.ndarray, attn_mask=None) -> jnp.ndarray:
     xn = LayerNorm(name='ln_1')(x)
-    x = x + nn.SelfAttention(
+    x = x + nn.MultiHeadDotProductAttention(
         self.num_heads, name='attn', deterministic=True)(xn, attn_mask)
     xn = LayerNorm(name='ln_2')(x)
     x = x + MLP(name='mlp')(xn)