Enable FusedRMSNorm (#78)
* FusedRMSNorm/"T5LayerNorm" based on FusedLayerNorm (NVIDIA#1274)

* FusedRMSNorm based on FusedLayerNorm

* refactor duplicated kernels

* delete comments

* delete comments

* cleanup

* cleanup

* cleanup, fixed clobbering forward_affine_mixed_dtypes

* fix pybind naming and add MixedFused test

* undo skipping

* check elementwise_affine

* Update tests/L0/run_fused_layer_norm/test_fused_layer_norm.py

Oof, nice catch, thanks

Co-authored-by: Masaki Kozuki <[email protected]>


* fix and generate docs for FusedRMSNorm (NVIDIA#1285)

* [FusedRMSNorm doc] document where epsilon is added (NVIDIA#1295)

* [FusedRMSNorm doc] add epsilon to formula

* correct

* better wording

* Fix some bugs

* Optimize HostRMSNormGradient and HostApplyRMSNorm for AMD GPUs

* Fix NaN issues in FusedRMSNorm

* Update test_fused_layer_norm.py

* Skip test_fused_layer_norm.TestAutocastFusedRMSNorm on ROCm

* Use at::cuda::warp_size() instead of at::cuda::getCurrentDeviceProperties()->warpSize

Co-authored-by: eqy <[email protected]>
Co-authored-by: Masaki Kozuki <[email protected]>
Co-authored-by: Stas Bekman <[email protected]>
4 people authored Aug 5, 2022
1 parent cf77e9b commit c97ebfa
Showing 6 changed files with 1,074 additions and 136 deletions.
2 changes: 1 addition & 1 deletion apex/normalization/__init__.py
@@ -1 +1 @@
-from .fused_layer_norm import FusedLayerNorm, MixedFusedLayerNorm
+from .fused_layer_norm import FusedLayerNorm, MixedFusedLayerNorm, FusedRMSNorm, MixedFusedRMSNorm
219 changes: 219 additions & 0 deletions apex/normalization/fused_layer_norm.py
@@ -12,6 +12,23 @@
fused_layer_norm_cuda = None


# Reference implementation from Huggingface
def manual_rms_norm(input, normalized_shape, weight, eps):
# RMS norm should always be calculated in float32
dims = tuple(i for i in range(-1, -len(normalized_shape)-1, -1))
variance = input.to(torch.float32).pow(2).mean(dims, keepdim=True)
input = input * torch.rsqrt(variance + eps)

if weight is None:
return input

# convert into half-precision if necessary
if weight.dtype in [torch.float16, torch.bfloat16]:
input = input.to(weight.dtype)

return weight * input
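
A minimal sketch (illustrative, not part of this diff) of what the reference path computes, assuming apex is installed so that manual_rms_norm is importable from apex.normalization.fused_layer_norm; the tensor names, shapes, and eps value are arbitrary:

import torch
from apex.normalization.fused_layer_norm import manual_rms_norm

x = torch.randn(4, 16)   # float32 input, normalized over the last dimension
w = torch.ones(16)       # affine weight
eps = 1e-5
y_ref = manual_rms_norm(x, (16,), w, eps)
y_explicit = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) * w
assert torch.allclose(y_ref, y_explicit, atol=1e-6)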


class FusedLayerNormAffineFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, input, weight, bias, normalized_shape, eps):
@@ -39,6 +56,31 @@ def backward(ctx, grad_output):
return grad_input, grad_weight, grad_bias, None, None


class FusedRMSNormAffineFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, input, weight, normalized_shape, eps):
global fused_layer_norm_cuda
if fused_layer_norm_cuda is None:
fused_layer_norm_cuda = importlib.import_module("fused_layer_norm_cuda")
ctx.normalized_shape = normalized_shape
ctx.eps = eps
input_ = input.contiguous()
weight_ = weight.contiguous()
output, invvar = fused_layer_norm_cuda.rms_forward_affine(
input_, ctx.normalized_shape, weight_, ctx.eps)
ctx.save_for_backward(input_, weight_, invvar)
return output

@staticmethod
def backward(ctx, grad_output):
input_, weight_, invvar = ctx.saved_tensors
grad_input = grad_weight = None
grad_input, grad_weight = fused_layer_norm_cuda.rms_backward_affine(
grad_output.contiguous(), invvar, input_, ctx.normalized_shape, weight_, ctx.eps
)
return grad_input, grad_weight, None, None


class FusedLayerNormAffineMixedDtypesFunction(FusedLayerNormAffineFunction):

@staticmethod
@@ -58,6 +100,25 @@ def forward(ctx, input, weight, bias, normalized_shape, eps):
return output


class FusedRMSNormAffineMixedDtypesFunction(FusedRMSNormAffineFunction):

@staticmethod
def forward(ctx, input, weight, normalized_shape, eps):
global fused_layer_norm_cuda
if fused_layer_norm_cuda is None:
fused_layer_norm_cuda = importlib.import_module("fused_layer_norm_cuda")
ctx.normalized_shape = normalized_shape
ctx.eps = eps
input_ = input.contiguous()
weight_ = weight.contiguous()
output, invvar = fused_layer_norm_cuda.rms_forward_affine_mixed_dtypes(
input_, ctx.normalized_shape, weight_, ctx.eps
)

ctx.save_for_backward(input_, weight_, invvar)
return output


class FusedLayerNormFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, input, normalized_shape, eps):
@@ -81,6 +142,29 @@ def backward(ctx, grad_output):
return grad_input, None, None


class FusedRMSNormFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, input, normalized_shape, eps):
global fused_layer_norm_cuda
if fused_layer_norm_cuda is None:
fused_layer_norm_cuda = importlib.import_module("fused_layer_norm_cuda")
ctx.normalized_shape = normalized_shape
ctx.eps = eps
input_ = input.contiguous()
output, invvar = fused_layer_norm_cuda.rms_forward(input_, ctx.normalized_shape, ctx.eps)
ctx.save_for_backward(input_, invvar)
return output

@staticmethod
def backward(ctx, grad_output):
input_, invvar = ctx.saved_tensors
grad_input = None
grad_input = fused_layer_norm_cuda.rms_backward(
grad_output.contiguous(), invvar, input_, ctx.normalized_shape, ctx.eps
)
return grad_input, None, None


def fused_layer_norm_affine(input, weight, bias, normalized_shape, eps=1e-6):
args = _cast_if_autocast_enabled(input, weight, bias, normalized_shape, eps)
with torch.cuda.amp.autocast(enabled=False):
@@ -99,6 +183,24 @@ def mixed_dtype_fused_layer_norm_affine(input, weight, bias, normalized_shape, e
return FusedLayerNormAffineMixedDtypesFunction.apply(*args)


def fused_rms_norm_affine(input, weight, normalized_shape, eps=1e-6):
args = _cast_if_autocast_enabled(input, weight, normalized_shape, eps)
with torch.cuda.amp.autocast(enabled=False):
return FusedRMSNormAffineFunction.apply(*args)


def fused_rms_norm(input, normalized_shape, eps=1e-6):
args = _cast_if_autocast_enabled(input, normalized_shape, eps)
with torch.cuda.amp.autocast(enabled=False):
return FusedRMSNormFunction.apply(*args)


def mixed_dtype_fused_rms_norm_affine(input, weight, normalized_shape, eps=1e-6):
args = _cast_if_autocast_enabled(input, weight, normalized_shape, eps)
with torch.cuda.amp.autocast(enabled=False):
return FusedRMSNormAffineMixedDtypesFunction.apply(*args)
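
A minimal usage sketch (illustrative, not part of this diff) of the functional wrappers above, assuming apex was built with the fused_layer_norm_cuda extension and a CUDA device is available; shapes and eps are arbitrary:

import torch
from apex.normalization.fused_layer_norm import fused_rms_norm, fused_rms_norm_affine

x = torch.randn(8, 512, device="cuda", requires_grad=True)
w = torch.ones(512, device="cuda", requires_grad=True)
y = fused_rms_norm_affine(x, w, normalized_shape=(512,), eps=1e-5)  # with learnable weight
y_plain = fused_rms_norm(x, normalized_shape=(512,), eps=1e-5)      # no affine transform
y.sum().backward()  # gradients flow back through FusedRMSNormAffineFunction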


class FusedLayerNorm(torch.nn.Module):
r"""Applies Layer Normalization over a mini-batch of inputs as described in
the paper `Layer Normalization`_ .
@@ -195,6 +297,100 @@ def extra_repr(self):
return "{normalized_shape}, eps={eps}, " "elementwise_affine={elementwise_affine}".format(**self.__dict__)


class FusedRMSNorm(torch.nn.Module):
r"""Applies RMS Normalization over a mini-batch of inputs
The fused kernel currently only runs on cuda() tensors; CPU inputs fall back to the reference implementation.
.. math::
y = \frac{x}{\mathrm{RMS}[x]} * \gamma
The root-mean-square is calculated separately over the last
certain number of dimensions which have to be of the shape specified by
:attr:`normalized_shape`.
:math:`\gamma` is a learnable affine transform parameter of
:attr:`normalized_shape` if :attr:`elementwise_affine` is ``True``.
`epsilon` is added to the mean-square before the square root is taken, so the denominator is :math:`\sqrt{\mathrm{mean}[x^2] + \epsilon}`.
.. note::
Unlike Batch Normalization and Instance Normalization, which applies
scalar scale and bias for each entire channel/plane with the
:attr:`affine` option, RMS Normalization applies per-element scale
with :attr:`elementwise_affine`.
This layer uses statistics computed from input data in both training and
evaluation modes.
Args:
normalized_shape (int or list or torch.Size): input shape from an expected input
of size
.. math::
[* \times \text{normalized}\_\text{shape}[0] \times \text{normalized}\_\text{shape}[1]
\times \ldots \times \text{normalized}\_\text{shape}[-1]]
If a single integer is used, it is treated as a singleton list, and this module will
normalize over the last dimension which is expected to be of that specific size.
eps: a value added to the denominator for numerical stability. Default: 1e-5
elementwise_affine: a boolean value that when set to ``True``, this module
has a learnable per-element affine weight initialized to ones (unlike layer
normalization, RMS normalization has no bias term). Default: ``True``.
Shape:
- Input: :math:`(N, *)`
- Output: :math:`(N, *)` (same shape as input)
Examples::
>>> input = torch.randn(20, 5, 10, 10)
>>> # With Learnable Parameters
>>> m = apex.normalization.FusedRMSNorm(input.size()[1:])
>>> # Without Learnable Parameters
>>> m = apex.normalization.FusedRMSNorm(input.size()[1:], elementwise_affine=False)
>>> # Normalize over last two dimensions
>>> m = apex.normalization.FusedRMSNorm([10, 10])
>>> # Normalize over last dimension of size 10
>>> m = apex.normalization.FusedRMSNorm(10)
>>> # Activating the module
>>> output = m(input)
.. _`Root Mean Square Layer Normalization`: https://arxiv.org/pdf/1910.07467.pdf
"""

def __init__(self, normalized_shape, eps=1e-5, elementwise_affine=True):
super().__init__()

global fused_layer_norm_cuda
fused_layer_norm_cuda = importlib.import_module("fused_layer_norm_cuda")

if isinstance(normalized_shape, numbers.Integral):
normalized_shape = (normalized_shape,)
self.normalized_shape = torch.Size(normalized_shape)
self.eps = eps
self.elementwise_affine = elementwise_affine
if self.elementwise_affine:
self.weight = Parameter(torch.Tensor(*normalized_shape))
else:
self.register_parameter("weight", None)
self.reset_parameters()

def reset_parameters(self):
if self.elementwise_affine:
init.ones_(self.weight)

def forward(self, input):
if not input.is_cuda:
return manual_rms_norm(input, self.normalized_shape, self.weight, self.eps)

if self.elementwise_affine:
return fused_rms_norm_affine(input, self.weight, self.normalized_shape, self.eps)
else:
return fused_rms_norm(input, self.normalized_shape, self.eps)

def extra_repr(self):
return "{normalized_shape}, eps={eps}, " "elementwise_affine={elementwise_affine}".format(**self.__dict__)


# NOTE (mkozuki): Why "mixed"?
# MixedFusedLayerNorm differs from FusedLayerNorm in that this layer norm uses parameter's dtype
# as output tensor's dtype while FusedLayerNorm uses input tensor's dtype for output tensor's dtype.
@@ -216,3 +412,26 @@ def forward(self, input: torch.Tensor):
if not input.is_cuda:
return F.layer_norm(input, self.normalized_shape, self.weight, self.bias, self.eps)
return mixed_dtype_fused_layer_norm_affine(input, self.weight, self.bias, self.normalized_shape, self.eps)


# MixedFusedRMSNorm differs from FusedRMSNorm in that this rms norm uses the parameter's dtype
# as the output tensor's dtype while FusedRMSNorm uses the input tensor's dtype for the output tensor's dtype.
# See the mixed-dtype entry points in "csrc/layer_norm_cuda.cpp"
class MixedFusedRMSNorm(FusedRMSNorm):

def __init__(self, normalized_shape, eps=1e-5, **kwargs):
if "elementwise_affine" in kwargs:
import warnings
warnings.warn("MixedFusedRMSNorm does not support `elementwise_affine` argument")
elementwise_affine = kwargs.pop("elementwise_affine")
if not elementwise_affine:
raise RuntimeError("MixedFusedRMSNorm does not support `elementwise_affine = False`")

super().__init__(normalized_shape=normalized_shape, eps=eps, elementwise_affine=True)

def forward(self, input: torch.Tensor):
# NOTE (mkozuki): CPU path is here mainly for unittest sake.
if not input.is_cuda:
return manual_rms_norm(input, self.normalized_shape, self.weight, self.eps)
return mixed_dtype_fused_rms_norm_affine(input, self.weight, self.normalized_shape, self.eps)
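
A minimal sketch (illustrative, not part of this diff) of the "mixed dtype" behavior described in the comment above, assuming a CUDA build of apex: with a float16 input and the default float32 parameter, the output dtype should follow the parameter's dtype.

import torch
from apex.normalization import MixedFusedRMSNorm

m = MixedFusedRMSNorm(64).cuda()                            # weight stays float32
x = torch.randn(4, 64, device="cuda", dtype=torch.float16)  # half-precision input
y = m(x)
print(y.dtype)  # expected: torch.float32, per the note on parameter-dtype outputs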