diff --git a/src/levanter/models/llama.py b/src/levanter/models/llama.py index c0e1ca45a..32b434865 100644 --- a/src/levanter/models/llama.py +++ b/src/levanter/models/llama.py @@ -305,17 +305,20 @@ def update_state_dict(self, state_dict: StateDict, prefix: Optional[str] = None) return state_dict -class LlamaRMSNorm(hnn.LayerNorm): - """It is a modified version of LayerNorm. - The main changes are: - 1. The variance is defined as the average of square, versus the original - definition as the average of the squared deviations from the mean. - 2. The output is defined as x * inv, without minusing the mean. - 3. The default value of eps is set to 1e-6 and use_bias to False. +class LlamaRMSNorm(eqx.Module): """ + Similar to LayerNorm, but uses the RMS of the input along the specified axis (or axes) instead of variance. + """ + + axis: AxisSpec = eqx.static_field() + weight: Optional[NamedArray] + bias: Optional[NamedArray] + + eps: float = eqx.static_field(default=1e-5) + dtype: Optional[jnp.dtype] = eqx.static_field(default=jnp.float32) @staticmethod - def init(axis: AxisSpec, eps: float = 1e-6, use_weight: bool = True, use_bias: bool = False): + def init(axis: AxisSpec, eps: float = 1e-6, use_weight: bool = True, use_bias: bool = True, dtype=jnp.float32): if use_weight: weight = hax.ones(axis) else: @@ -325,20 +328,25 @@ def init(axis: AxisSpec, eps: float = 1e-6, use_weight: bool = True, use_bias: b else: bias = None - return LlamaRMSNorm(axis, weight, bias, eps) + return LlamaRMSNorm(axis, weight, bias, eps, dtype) def __call__(self, x: NamedArray) -> NamedArray: # This gives a different result than jnp.var(), which is # defined as the average of the squared deviations from the mean + in_dtype = x.dtype + x = x.astype(self.dtype) var = hax.mean(hax.square(x), axis=self.axis) inv = hax.rsqrt(var + self.eps) out = x * inv + out = out.astype(in_dtype) if self.weight is not None: out = self.weight * out if self.bias is not None: out = out + self.bias - return out + + # second cast in case params are in float32 + return out.astype(in_dtype) class LlamaDecoderLayer(StateDictSerializationMixin, eqx.Module):