From 1f8b6c3f5c8c5553bea517ded515b962d8170ec5 Mon Sep 17 00:00:00 2001
From: yzhangcs
Date: Fri, 20 Sep 2024 11:46:57 +0000
Subject: [PATCH] Clean the code

---
 fla/models/mamba2/modeling_mamba2.py |  4 ++--
 fla/modules/layernorm_gated.py       | 28 +++++++++++++++++++---------
 2 files changed, 21 insertions(+), 11 deletions(-)

diff --git a/fla/models/mamba2/modeling_mamba2.py b/fla/models/mamba2/modeling_mamba2.py
index 3d22d7414..63bb7efa2 100644
--- a/fla/models/mamba2/modeling_mamba2.py
+++ b/fla/models/mamba2/modeling_mamba2.py
@@ -412,7 +412,7 @@ def cuda_kernels_forward(
                 seq_idx=None,
                 return_final_states=True,
                 dt_bias=self.dt_bias,
-                dt_softplus=True, 
+                dt_softplus=True,
                 **dt_limit_kwargs,
             )
             if ssm_state is not None and cache_params is not None:
@@ -1075,4 +1075,4 @@ def forward(
             logits=logits,
             cache_params=mamba2_outputs.cache_params,
             hidden_states=mamba2_outputs.hidden_states,
-        )
\ No newline at end of file
+        )
diff --git a/fla/modules/layernorm_gated.py b/fla/modules/layernorm_gated.py
index ac11736e5..74e1997b7 100644
--- a/fla/modules/layernorm_gated.py
+++ b/fla/modules/layernorm_gated.py
@@ -8,16 +8,13 @@
 
 import torch
 import torch.nn.functional as F
-
 import triton
 import triton.language as tl
-
 from einops import rearrange
 
 
 def rms_norm_ref(x, weight, bias, z=None, eps=1e-6, group_size=None, norm_before_gate=True, upcast=True):
     dtype = x.dtype
-    N = x.shape[-1]
     weight = weight.float()
     bias = bias.float() if bias is not None else None
     if upcast:
@@ -147,7 +144,6 @@ def _layer_norm_fwd(x, weight, bias, eps, z=None, out=None, group_size=None, nor
     return out, mean, rstd
 
 
-
 @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
 @triton.heuristics({"HAS_Z": lambda args: args["Z"] is not None})
 @triton.heuristics({"RECOMPUTE_OUTPUT": lambda args: args["Y"] is not None})
@@ -356,7 +352,8 @@ def forward(ctx, x, weight, bias, z=None, eps=1e-6, group_size=None, norm_before
         weight = weight.contiguous()
         if bias is not None:
             bias = bias.contiguous()
-        y, mean, rstd = _layer_norm_fwd(x, weight, bias, eps, z=z, group_size=group_size, norm_before_gate=norm_before_gate, is_rms_norm=is_rms_norm)
+        y, mean, rstd = _layer_norm_fwd(x, weight, bias, eps, z=z, group_size=group_size,
+                                        norm_before_gate=norm_before_gate, is_rms_norm=is_rms_norm)
         ctx.save_for_backward(x, weight, bias, mean, rstd, z)
         ctx.x_shape_og = x_shape_og
         ctx.eps = eps
@@ -372,9 +369,22 @@ def backward(ctx, dy):
         if dy.stride(-1) != 1:
             dy = dy.contiguous()
         assert dy.shape == x.shape
-        dx, dw, db, dz = _layer_norm_bwd(dy, x, weight, bias, ctx.eps, mean, rstd, z, ctx.group_size,
-                                         ctx.norm_before_gate, ctx.is_rms_norm)
-        return dx.reshape(ctx.x_shape_og), dw, db, dz.reshape(ctx.x_shape_og) if dz is not None else None, None, None, None, None
+        dx, dw, db, dz = _layer_norm_bwd(
+            dy,
+            x,
+            weight,
+            bias,
+            ctx.eps,
+            mean,
+            rstd,
+            z,
+            ctx.group_size,
+            ctx.norm_before_gate,
+            ctx.is_rms_norm
+        )
+        dx = dx.reshape(ctx.x_shape_og)
+        dz = dz.reshape(ctx.x_shape_og) if dz is not None else None
+        return dx, dw, db, dz, None, None, None, None
 
 
 def layernorm_fn(x, weight, bias, z=None, eps=1e-6, group_size=None, norm_before_gate=True, is_rms_norm=False):
@@ -434,4 +444,4 @@ def forward(self, x, z=None):
         """If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))
         """
         return rmsnorm_fn(x, self.weight, self.bias, z=z, eps=self.eps, group_size=self.group_size,
-                          norm_before_gate=self.norm_before_gate)
\ No newline at end of file
+                          norm_before_gate=self.norm_before_gate)