
Clean the code
yzhangcs committed Sep 20, 2024
1 parent 7be5c3d commit 1f8b6c3
Showing 2 changed files with 21 additions and 11 deletions.
4 changes: 2 additions & 2 deletions fla/models/mamba2/modeling_mamba2.py
@@ -412,7 +412,7 @@ def cuda_kernels_forward(
seq_idx=None,
return_final_states=True,
dt_bias=self.dt_bias,
dt_softplus=True,
dt_softplus=True,
**dt_limit_kwargs,
)
if ssm_state is not None and cache_params is not None:
@@ -1075,4 +1075,4 @@ def forward(
logits=logits,
cache_params=mamba2_outputs.cache_params,
hidden_states=mamba2_outputs.hidden_states,
)
)
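As a side note on the hunk above: the kernel call passes dt_bias together with dt_softplus=True and **dt_limit_kwargs. Below is a minimal sketch of the dt handling this combination usually implies for Mamba2-style scans; the function name and the dt_limit default are illustrative assumptions, not the kernel's actual internals.

import torch
import torch.nn.functional as F

def discretize_dt(dt, dt_bias, dt_limit=(0.0, float("inf"))):
    # Illustrative only: add the learned bias, apply softplus so dt stays
    # positive (what dt_softplus=True requests), then clamp to the
    # configured limits (what dt_limit_kwargs would typically carry).
    dt = F.softplus(dt + dt_bias)
    return dt.clamp(min=dt_limit[0], max=dt_limit[1])
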
28 changes: 19 additions & 9 deletions fla/modules/layernorm_gated.py
@@ -8,16 +8,13 @@

import torch
import torch.nn.functional as F

import triton
import triton.language as tl

from einops import rearrange


def rms_norm_ref(x, weight, bias, z=None, eps=1e-6, group_size=None, norm_before_gate=True, upcast=True):
dtype = x.dtype
N = x.shape[-1]
weight = weight.float()
bias = bias.float() if bias is not None else None
if upcast:
@@ -147,7 +144,6 @@ def _layer_norm_fwd(x, weight, bias, eps, z=None, out=None, group_size=None, nor
return out, mean, rstd



@triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
@triton.heuristics({"HAS_Z": lambda args: args["Z"] is not None})
@triton.heuristics({"RECOMPUTE_OUTPUT": lambda args: args["Y"] is not None})
@@ -356,7 +352,8 @@ def forward(ctx, x, weight, bias, z=None, eps=1e-6, group_size=None, norm_before
weight = weight.contiguous()
if bias is not None:
bias = bias.contiguous()
y, mean, rstd = _layer_norm_fwd(x, weight, bias, eps, z=z, group_size=group_size, norm_before_gate=norm_before_gate, is_rms_norm=is_rms_norm)
y, mean, rstd = _layer_norm_fwd(x, weight, bias, eps, z=z, group_size=group_size,
norm_before_gate=norm_before_gate, is_rms_norm=is_rms_norm)
ctx.save_for_backward(x, weight, bias, mean, rstd, z)
ctx.x_shape_og = x_shape_og
ctx.eps = eps
@@ -372,9 +369,22 @@ def backward(ctx, dy):
if dy.stride(-1) != 1:
dy = dy.contiguous()
assert dy.shape == x.shape
dx, dw, db, dz = _layer_norm_bwd(dy, x, weight, bias, ctx.eps, mean, rstd, z, ctx.group_size,
ctx.norm_before_gate, ctx.is_rms_norm)
return dx.reshape(ctx.x_shape_og), dw, db, dz.reshape(ctx.x_shape_og) if dz is not None else None, None, None, None, None
dx, dw, db, dz = _layer_norm_bwd(
dy,
x,
weight,
bias,
ctx.eps,
mean,
rstd,
z,
ctx.group_size,
ctx.norm_before_gate,
ctx.is_rms_norm
)
dx = dx.reshape(ctx.x_shape_og)
dz = dz.reshape(ctx.x_shape_og) if dz is not None else None
return dx, dw, db, dz, None, None, None, None


def layernorm_fn(x, weight, bias, z=None, eps=1e-6, group_size=None, norm_before_gate=True, is_rms_norm=False):
@@ -434,4 +444,4 @@ def forward(self, x, z=None):
"""If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))
"""
return rmsnorm_fn(x, self.weight, self.bias, z=z, eps=self.eps, group_size=self.group_size,
norm_before_gate=self.norm_before_gate)
norm_before_gate=self.norm_before_gate)
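
The docstring in the last hunk summarizes the gating order this module handles. A minimal pure-PyTorch sketch of that behavior, assuming the RMSNorm case with a weight only (bias, group_size, and the explicit upcasting from rms_norm_ref are omitted; the function name is illustrative):

import torch
import torch.nn.functional as F

def gated_rms_norm_ref(x, weight, z=None, eps=1e-6, norm_before_gate=True):
    # RMS-normalize x, then gate with silu(z) either after the norm
    # (norm_before_gate=True) or before it (norm_before_gate=False).
    def _rms_norm(t):
        t = t.float()
        t = t * torch.rsqrt(t.pow(2).mean(dim=-1, keepdim=True) + eps)
        return (t * weight.float()).to(x.dtype)

    if z is None:
        return _rms_norm(x)
    if norm_before_gate:
        return _rms_norm(x) * F.silu(z)
    return _rms_norm(x * F.silu(z))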
