From 6961e105320d8127e0e1ddb30d0ecb88953271cc Mon Sep 17 00:00:00 2001
From: David Hall
Date: Thu, 30 May 2024 21:48:18 -0700
Subject: [PATCH 1/2] let's try upcasting rmsnorm

---
 src/levanter/models/llama.py | 28 ++++++++++++++++++----------
 1 file changed, 18 insertions(+), 10 deletions(-)

diff --git a/src/levanter/models/llama.py b/src/levanter/models/llama.py
index c0e1ca45a..60d5d7eaf 100644
--- a/src/levanter/models/llama.py
+++ b/src/levanter/models/llama.py
@@ -305,17 +305,20 @@ def update_state_dict(self, state_dict: StateDict, prefix: Optional[str] = None)
         return state_dict
 
 
-class LlamaRMSNorm(hnn.LayerNorm):
-    """It is a modified version of LayerNorm.
-    The main changes are:
-    1. The variance is defined as the average of square, versus the original
-    definition as the average of the squared deviations from the mean.
-    2. The output is defined as x * inv, without minusing the mean.
-    3. The default value of eps is set to 1e-6 and use_bias to False.
+class LlamaRMSNorm(eqx.Module):
     """
+    Similar to LayerNorm, but uses the RMS of the input along the specified axis (or axes) instead of variance.
+    """
+
+    axis: AxisSpec = eqx.static_field()
+    weight: Optional[NamedArray]
+    bias: Optional[NamedArray]
+
+    eps: float = eqx.static_field(default=1e-5)
+    dtype: Optional[jnp.dtype] = eqx.static_field(default=jnp.float32)
 
     @staticmethod
-    def init(axis: AxisSpec, eps: float = 1e-6, use_weight: bool = True, use_bias: bool = False):
+    def init(axis: AxisSpec, eps: float = 1e-5, use_weight: bool = True, use_bias: bool = True, dtype=jnp.float32):
         if use_weight:
             weight = hax.ones(axis)
         else:
@@ -325,20 +328,25 @@ def init(axis: AxisSpec, eps: float = 1e-6, use_weight: bool = True, use_bias: b
         else:
             bias = None
 
-        return LlamaRMSNorm(axis, weight, bias, eps)
+        return LlamaRMSNorm(axis, weight, bias, eps, dtype)
 
     def __call__(self, x: NamedArray) -> NamedArray:
         # This gives a different result than jnp.var(), which is
         # defined as the average of the squared deviations from the mean
+        in_dtype = x.dtype
+        x = x.astype(self.dtype)
         var = hax.mean(hax.square(x), axis=self.axis)
         inv = hax.rsqrt(var + self.eps)
         out = x * inv
+        out = out.astype(in_dtype)
 
         if self.weight is not None:
             out = self.weight * out
         if self.bias is not None:
             out = out + self.bias
-        return out
+
+        # second cast in case params are in float32
+        return out.astype(in_dtype)
 
 
 class LlamaDecoderLayer(StateDictSerializationMixin, eqx.Module):

From b6e6cef97eca3abfb5e0d91d1e3bb8cefaf5b6f5 Mon Sep 17 00:00:00 2001
From: David Hall
Date: Fri, 31 May 2024 15:05:28 -0700
Subject: [PATCH 2/2] Update src/levanter/models/llama.py

---
 src/levanter/models/llama.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/levanter/models/llama.py b/src/levanter/models/llama.py
index 60d5d7eaf..32b434865 100644
--- a/src/levanter/models/llama.py
+++ b/src/levanter/models/llama.py
@@ -318,7 +318,7 @@ class LlamaRMSNorm(eqx.Module):
     dtype: Optional[jnp.dtype] = eqx.static_field(default=jnp.float32)
 
     @staticmethod
-    def init(axis: AxisSpec, eps: float = 1e-5, use_weight: bool = True, use_bias: bool = True, dtype=jnp.float32):
+    def init(axis: AxisSpec, eps: float = 1e-6, use_weight: bool = True, use_bias: bool = True, dtype=jnp.float32):
         if use_weight:
             weight = hax.ones(axis)
         else:
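
The pattern these patches apply is to compute the RMS statistics in float32 and cast back to the activation dtype before (and again after) applying the parameters. Below is a minimal, self-contained sketch of that upcast-then-downcast idea in plain JAX, without haliax or equinox; it is not Levanter's implementation, and the names rms_norm and compute_dtype are illustrative only.

    import jax
    import jax.numpy as jnp

    def rms_norm(x, weight=None, bias=None, eps=1e-6, compute_dtype=jnp.float32):
        """RMS normalization with the statistics computed in compute_dtype."""
        in_dtype = x.dtype
        # Upcast so the mean-of-squares and rsqrt run in float32, avoiding
        # precision loss when activations are bf16/fp16.
        xf = x.astype(compute_dtype)
        var = jnp.mean(jnp.square(xf), axis=-1, keepdims=True)
        out = xf * jax.lax.rsqrt(var + eps)
        # Cast back to the activation dtype before applying the parameters.
        out = out.astype(in_dtype)
        if weight is not None:
            out = weight * out
        if bias is not None:
            out = out + bias
        # Second cast in case weight/bias are float32 and promoted the result.
        return out.astype(in_dtype)

    x = jax.random.normal(jax.random.PRNGKey(0), (2, 8), dtype=jnp.bfloat16)
    w = jnp.ones((8,), dtype=jnp.float32)
    print(rms_norm(x, weight=w).dtype)  # bfloat16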