Skip to content

Commit

Permalink
let's try upcasting rmsnorm (#606)
Browse files Browse the repository at this point in the history
* let's try upcasting rmsnorm

* Update src/levanter/models/llama.py
  • Loading branch information
dlwh authored May 31, 2024
1 parent d9eb6cb commit dae12f6
Showing 1 changed file with 18 additions and 10 deletions.
28 changes: 18 additions & 10 deletions src/levanter/models/llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,17 +305,20 @@ def update_state_dict(self, state_dict: StateDict, prefix: Optional[str] = None)
return state_dict


class LlamaRMSNorm(hnn.LayerNorm):
"""It is a modified version of LayerNorm.
The main changes are:
1. The variance is defined as the average of square, versus the original
definition as the average of the squared deviations from the mean.
2. The output is defined as x * inv, without subtracting the mean.
3. The default value of eps is set to 1e-6 and use_bias to False.
class LlamaRMSNorm(eqx.Module):
"""
Similar to LayerNorm, but uses the RMS of the input along the specified axis (or axes) instead of variance.
"""

axis: AxisSpec = eqx.static_field()
weight: Optional[NamedArray]
bias: Optional[NamedArray]

eps: float = eqx.static_field(default=1e-5)
dtype: Optional[jnp.dtype] = eqx.static_field(default=jnp.float32)

@staticmethod
def init(axis: AxisSpec, eps: float = 1e-6, use_weight: bool = True, use_bias: bool = False):
def init(axis: AxisSpec, eps: float = 1e-6, use_weight: bool = True, use_bias: bool = True, dtype=jnp.float32):
if use_weight:
weight = hax.ones(axis)
else:
Expand All @@ -325,20 +328,25 @@ def init(axis: AxisSpec, eps: float = 1e-6, use_weight: bool = True, use_bias: b
else:
bias = None

return LlamaRMSNorm(axis, weight, bias, eps)
return LlamaRMSNorm(axis, weight, bias, eps, dtype)

def __call__(self, x: NamedArray) -> NamedArray:
    """Scale ``x`` by the reciprocal root mean square taken over ``self.axis``.

    The statistics are computed after casting ``x`` to ``self.dtype``. The
    normalized value is cast back to the dtype ``x`` arrived with before the
    optional weight and bias are applied, and once more on return.

    Args:
        x: Input array to normalize.

    Returns:
        The normalized array, multiplied by ``self.weight`` and offset by
        ``self.bias`` when those are not ``None``, in the dtype of the input.
    """
    in_dtype = x.dtype
    # NOTE(review): the dtype field is declared Optional; what astype(None)
    # does is not visible from here -- confirm before relying on dtype=None.
    x = x.astype(self.dtype)

    # This gives a different result than jnp.var(), which is
    # defined as the average of the squared deviations from the mean
    var = hax.mean(hax.square(x), axis=self.axis)
    inv = hax.rsqrt(var + self.eps)
    out = (x * inv).astype(in_dtype)

    if self.weight is not None:
        out = self.weight * out
    if self.bias is not None:
        out = out + self.bias

    # second cast in case params are in float32
    return out.astype(in_dtype)


class LlamaDecoderLayer(StateDictSerializationMixin, eqx.Module):
Expand Down

0 comments on commit dae12f6

Please sign in to comment.