diff --git a/src/levanter/models/attention.py b/src/levanter/models/attention.py index 7044200ba..b054f1972 100644 --- a/src/levanter/models/attention.py +++ b/src/levanter/models/attention.py @@ -806,10 +806,10 @@ def _tpu_splash_attention( if bias is not None: raise NotImplementedError("Splash attention does not support bias") - if attention_dtype is not None and attention_dtype != jnp.float32: - warnings.warn("Splash attention only supports float32. Switching to float32.") + # if attention_dtype is not None and attention_dtype != jnp.float32: + # warnings.warn("Splash attention only supports float32. Switching to float32.") - attention_dtype = jnp.float32 + # attention_dtype = jnp.float32 q_class, k_class, v_class = _bin_and_group_axes_by_function(query, key, value, QPos, KPos, Key)