Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

let's try upcasting rmsnorm #606

Merged
merged 2 commits into from
May 31, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 18 additions & 10 deletions src/levanter/models/llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,17 +305,20 @@ def update_state_dict(self, state_dict: StateDict, prefix: Optional[str] = None)
return state_dict


class LlamaRMSNorm(hnn.LayerNorm):
"""It is a modified version of LayerNorm.
The main changes are:
1. The variance is defined as the average of square, versus the original
definition as the average of the squared deviations from the mean.
2. The output is defined as x * inv, without minusing the mean.
3. The default value of eps is set to 1e-6 and use_bias to False.
class LlamaRMSNorm(eqx.Module):
"""
Similar to LayerNorm, but uses the RMS of the input along the specified axis (or axes) instead of variance.
"""

axis: AxisSpec = eqx.static_field()
weight: Optional[NamedArray]
bias: Optional[NamedArray]

eps: float = eqx.static_field(default=1e-5)
dtype: Optional[jnp.dtype] = eqx.static_field(default=jnp.float32)

@staticmethod
def init(axis: AxisSpec, eps: float = 1e-6, use_weight: bool = True, use_bias: bool = False):
def init(axis: AxisSpec, eps: float = 1e-6, use_weight: bool = True, use_bias: bool = True, dtype=jnp.float32):
if use_weight:
weight = hax.ones(axis)
else:
Expand All @@ -325,20 +328,25 @@ def init(axis: AxisSpec, eps: float = 1e-6, use_weight: bool = True, use_bias: b
else:
bias = None

return LlamaRMSNorm(axis, weight, bias, eps)
return LlamaRMSNorm(axis, weight, bias, eps, dtype)

def __call__(self, x: NamedArray) -> NamedArray:
# This gives a different result than jnp.var(), which is
# defined as the average of the squared deviations from the mean
in_dtype = x.dtype
x = x.astype(self.dtype)
var = hax.mean(hax.square(x), axis=self.axis)
inv = hax.rsqrt(var + self.eps)
out = x * inv
out = out.astype(in_dtype)

if self.weight is not None:
out = self.weight * out
if self.bias is not None:
out = out + self.bias
return out

# second cast in case params are in float32
return out.astype(in_dtype)


class LlamaDecoderLayer(StateDictSerializationMixin, eqx.Module):
Expand Down
Loading