Enable FusedRMSNorm #78

Merged · 9 commits · Aug 5, 2022
Changes from 8 commits
2 changes: 1 addition & 1 deletion apex/normalization/__init__.py
@@ -1 +1 @@
from .fused_layer_norm import FusedLayerNorm, MixedFusedLayerNorm
from .fused_layer_norm import FusedLayerNorm, MixedFusedLayerNorm, FusedRMSNorm, MixedFusedRMSNorm
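
For reference, after this change both RMS norm modules can be imported the same way as the existing layer norms. A minimal usage sketch (shape and device handling are illustrative, and it assumes apex is built with the fused_layer_norm_cuda extension):

import torch
from apex.normalization import FusedRMSNorm

rms_norm = FusedRMSNorm(1024).cuda()          # normalize over a last dimension of size 1024
x = torch.randn(8, 128, 1024, device="cuda")
y = rms_norm(x)                               # fused CUDA kernel; CPU inputs fall back to a manual reference path
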
219 changes: 219 additions & 0 deletions apex/normalization/fused_layer_norm.py
@@ -12,6 +12,23 @@
fused_layer_norm_cuda = None


# Reference implementation from Huggingface
def manual_rms_norm(input, normalized_shape, weight, eps):
    # the RMS statistics should always be calculated in float32
dims = tuple(i for i in range(-1, -len(normalized_shape)-1, -1))
variance = input.to(torch.float32).pow(2).mean(dims, keepdim=True)
input = input * torch.rsqrt(variance + eps)

if weight is None:
return input

# convert into half-precision if necessary
if weight.dtype in [torch.float16, torch.bfloat16]:
        input = input.to(weight.dtype)

return weight * input
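
A quick CPU-side sanity check of this reference path (shapes and tolerance are illustrative, not part of the PR):

import torch
from apex.normalization.fused_layer_norm import manual_rms_norm

x = torch.randn(2, 3, 8)
w = torch.ones(8)
y = manual_rms_norm(x, (8,), w, eps=1e-5)
# After normalization the root mean square over the last dimension is ~1.
assert torch.allclose(y.pow(2).mean(dim=-1).sqrt(), torch.ones(2, 3), atol=1e-3)
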


class FusedLayerNormAffineFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, input, weight, bias, normalized_shape, eps):
@@ -39,6 +56,31 @@ def backward(ctx, grad_output):
return grad_input, grad_weight, grad_bias, None, None


class FusedRMSNormAffineFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, input, weight, normalized_shape, eps):
global fused_layer_norm_cuda
if fused_layer_norm_cuda is None:
fused_layer_norm_cuda = importlib.import_module("fused_layer_norm_cuda")
ctx.normalized_shape = normalized_shape
ctx.eps = eps
input_ = input.contiguous()
weight_ = weight.contiguous()
output, invvar = fused_layer_norm_cuda.rms_forward_affine(
input_, ctx.normalized_shape, weight_, ctx.eps)
ctx.save_for_backward(input_, weight_, invvar)
return output

@staticmethod
def backward(ctx, grad_output):
input_, weight_, invvar = ctx.saved_tensors
grad_input = grad_weight = None
grad_input, grad_weight = fused_layer_norm_cuda.rms_backward_affine(
grad_output.contiguous(), invvar, input_, ctx.normalized_shape, weight_, ctx.eps
)
return grad_input, grad_weight, None, None
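
This function is normally reached through the wrapper functions added further down, but it can also be exercised directly; a sketch assuming a CUDA device and a build of the fused_layer_norm_cuda extension (shapes illustrative):

import torch
from apex.normalization.fused_layer_norm import FusedRMSNormAffineFunction

x = torch.randn(4, 16, 512, device="cuda", requires_grad=True)
w = torch.randn(512, device="cuda", requires_grad=True)
out = FusedRMSNormAffineFunction.apply(x, w, (512,), 1e-6)
out.sum().backward()   # runs rms_backward_affine; gradients land in x.grad and w.grad
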


class FusedLayerNormAffineMixedDtypesFunction(FusedLayerNormAffineFunction):

@staticmethod
@@ -58,6 +100,25 @@ def forward(ctx, input, weight, bias, normalized_shape, eps):
return output


class FusedRMSNormAffineMixedDtypesFunction(FusedRMSNormAffineFunction):

@staticmethod
def forward(ctx, input, weight, normalized_shape, eps):
global fused_layer_norm_cuda
if fused_layer_norm_cuda is None:
fused_layer_norm_cuda = importlib.import_module("fused_layer_norm_cuda")
ctx.normalized_shape = normalized_shape
ctx.eps = eps
input_ = input.contiguous()
weight_ = weight.contiguous()
output, invvar = fused_layer_norm_cuda.rms_forward_affine_mixed_dtypes(
input_, ctx.normalized_shape, weight_, ctx.eps
)

ctx.save_for_backward(input_, weight_, invvar)
return output
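
Only forward is overridden here, so the backward of FusedRMSNormAffineFunction above is reused. A sketch of the mixed-dtype call this enables (dtypes illustrative):

import torch
from apex.normalization.fused_layer_norm import FusedRMSNormAffineMixedDtypesFunction

x = torch.randn(4, 512, device="cuda", dtype=torch.float16)
w = torch.randn(512, device="cuda", dtype=torch.float32)
# Per the Mixed* note near the bottom of this file, the output follows the parameter's dtype.
out = FusedRMSNormAffineMixedDtypesFunction.apply(x, w, (512,), 1e-5)
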


class FusedLayerNormFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, input, normalized_shape, eps):
@@ -81,6 +142,29 @@ def backward(ctx, grad_output):
return grad_input, None, None


class FusedRMSNormFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, input, normalized_shape, eps):
global fused_layer_norm_cuda
if fused_layer_norm_cuda is None:
fused_layer_norm_cuda = importlib.import_module("fused_layer_norm_cuda")
ctx.normalized_shape = normalized_shape
ctx.eps = eps
input_ = input.contiguous()
output, invvar = fused_layer_norm_cuda.rms_forward(input_, ctx.normalized_shape, ctx.eps)
ctx.save_for_backward(input_, invvar)
return output

@staticmethod
def backward(ctx, grad_output):
input_, invvar = ctx.saved_tensors
grad_input = None
grad_input = fused_layer_norm_cuda.rms_backward(
grad_output.contiguous(), invvar, input_, ctx.normalized_shape, ctx.eps
)
return grad_input, None, None
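
The non-affine variant takes no weight; a minimal sketch, again assuming a CUDA build (shape illustrative):

import torch
from apex.normalization.fused_layer_norm import FusedRMSNormFunction

x = torch.randn(4, 512, device="cuda", requires_grad=True)
out = FusedRMSNormFunction.apply(x, (512,), 1e-5)
out.mean().backward()   # rms_backward produces grad_input only; there is no weight gradient
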


def fused_layer_norm_affine(input, weight, bias, normalized_shape, eps=1e-6):
args = _cast_if_autocast_enabled(input, weight, bias, normalized_shape, eps)
with torch.cuda.amp.autocast(enabled=False):
@@ -99,6 +183,24 @@ def mixed_dtype_fused_layer_norm_affine(input, weight, bias, normalized_shape, e
return FusedLayerNormAffineMixedDtypesFunction.apply(*args)


def fused_rms_norm_affine(input, weight, normalized_shape, eps=1e-6):
args = _cast_if_autocast_enabled(input, weight, normalized_shape, eps)
with torch.cuda.amp.autocast(enabled=False):
return FusedRMSNormAffineFunction.apply(*args)


def fused_rms_norm(input, normalized_shape, eps=1e-6):
args = _cast_if_autocast_enabled(input, normalized_shape, eps)
with torch.cuda.amp.autocast(enabled=False):
return FusedRMSNormFunction.apply(*args)


def mixed_dtype_fused_rms_norm_affine(input, weight, normalized_shape, eps=1e-6):
args = _cast_if_autocast_enabled(input, weight, normalized_shape, eps)
with torch.cuda.amp.autocast(enabled=False):
return FusedRMSNormAffineMixedDtypesFunction.apply(*args)
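
Because these wrappers re-cast their arguments with _cast_if_autocast_enabled and then disable autocast, they can be called from inside an autocast region; a usage sketch (shapes illustrative):

import torch
from apex.normalization.fused_layer_norm import fused_rms_norm_affine

x = torch.randn(8, 1024, device="cuda")
w = torch.ones(1024, device="cuda")
with torch.cuda.amp.autocast():
    y = fused_rms_norm_affine(x, w, (1024,), eps=1e-6)   # arguments are cast to the ambient autocast dtype before the fused kernel runs
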


class FusedLayerNorm(torch.nn.Module):
r"""Applies Layer Normalization over a mini-batch of inputs as described in
the paper `Layer Normalization`_ .
@@ -195,6 +297,100 @@ def extra_repr(self):
return "{normalized_shape}, eps={eps}, " "elementwise_affine={elementwise_affine}".format(**self.__dict__)


class FusedRMSNorm(torch.nn.Module):
r"""Applies RMS Normalization over a mini-batch of inputs

Currently only runs on cuda() tensors.

.. math::
y = \frac{x}{\mathrm{RMS}[x]} * \gamma

    The root mean square is calculated over the last dimensions, which must
    match the shape specified by :attr:`normalized_shape`.
    :math:`\gamma` is a learnable affine transform parameter of shape
    :attr:`normalized_shape` if :attr:`elementwise_affine` is ``True``.
    `epsilon` is added to the mean square before the square root is taken,
    for numerical stability.

.. note::
        Unlike Batch Normalization and Instance Normalization, which apply a
        scalar scale and bias to each entire channel/plane via the
        :attr:`affine` option, RMS Normalization applies a per-element scale
        with :attr:`elementwise_affine`.

This layer uses statistics computed from input data in both training and
evaluation modes.

Args:
normalized_shape (int or list or torch.Size): input shape from an expected input
of size

.. math::
[* \times \text{normalized}\_\text{shape}[0] \times \text{normalized}\_\text{shape}[1]
\times \ldots \times \text{normalized}\_\text{shape}[-1]]

If a single integer is used, it is treated as a singleton list, and this module will
normalize over the last dimension which is expected to be of that specific size.
eps: a value added to the denominator for numerical stability. Default: 1e-5
        elementwise_affine: a boolean value that, when set to ``True``, gives this
            module a learnable per-element scale parameter :math:`\gamma` initialized
            to ones (RMS normalization has no bias term). Default: ``True``.

Shape:
- Input: :math:`(N, *)`
- Output: :math:`(N, *)` (same shape as input)

Examples::

>>> input = torch.randn(20, 5, 10, 10)
>>> # With Learnable Parameters
>>> m = apex.normalization.FusedRMSNorm(input.size()[1:])
>>> # Without Learnable Parameters
>>> m = apex.normalization.FusedRMSNorm(input.size()[1:], elementwise_affine=False)
>>> # Normalize over last two dimensions
>>> m = apex.normalization.FusedRMSNorm([10, 10])
>>> # Normalize over last dimension of size 10
>>> m = apex.normalization.FusedRMSNorm(10)
>>> # Activating the module
>>> output = m(input)

.. _`Root Mean Square Layer Normalization`: https://arxiv.org/pdf/1910.07467.pdf
"""

def __init__(self, normalized_shape, eps=1e-5, elementwise_affine=True):
super().__init__()

global fused_layer_norm_cuda
fused_layer_norm_cuda = importlib.import_module("fused_layer_norm_cuda")

if isinstance(normalized_shape, numbers.Integral):
normalized_shape = (normalized_shape,)
self.normalized_shape = torch.Size(normalized_shape)
self.eps = eps
self.elementwise_affine = elementwise_affine
if self.elementwise_affine:
self.weight = Parameter(torch.Tensor(*normalized_shape))
else:
self.register_parameter("weight", None)
self.reset_parameters()

def reset_parameters(self):
if self.elementwise_affine:
init.ones_(self.weight)

def forward(self, input):
if not input.is_cuda:
return manual_rms_norm(input, self.normalized_shape, self.weight, self.eps)

if self.elementwise_affine:
return fused_rms_norm_affine(input, self.weight, self.normalized_shape, self.eps)
else:
return fused_rms_norm(input, self.normalized_shape, self.eps)

def extra_repr(self):
return "{normalized_shape}, eps={eps}, " "elementwise_affine={elementwise_affine}".format(**self.__dict__)


# NOTE (mkozuki): Why "mixed"?
# MixedFusedLayerNorm differs from FusedLayerNorm in that this layer norm uses parameter's dtype
# as output tensor's dtype while FusedLayerNorm uses input tensor's dtype for output tensor's dtype.
@@ -216,3 +412,26 @@ def forward(self, input: torch.Tensor):
if not input.is_cuda:
return F.layer_norm(input, self.normalized_shape, self.weight, self.bias, self.eps)
return mixed_dtype_fused_layer_norm_affine(input, self.weight, self.bias, self.normalized_shape, self.eps)


# MixedFusedRMSNorm differs from FusedRMSNorm in that this norm uses the parameter's dtype
# as the output tensor's dtype, while FusedRMSNorm uses the input tensor's dtype for the output tensor's dtype.
# See: `layer_norm_affine` and `layer_norm_affine_mixed_dtypes` in "csrc/layer_norm_cuda.cpp" for the analogous layer norm functions.
class MixedFusedRMSNorm(FusedRMSNorm):

def __init__(self, normalized_shape, eps=1e-5, **kwargs):
if "elementwise_affine" in kwargs:
import warnings
warnings.warn("MixedFusedRMSNorm does not support `elementwise_affine` argument")
elementwise_affine = kwargs.pop("elementwise_affine")
if not elementwise_affine:
raise RuntimeError("MixedFusedRMSNorm does not support `elementwise_affine = False`")

super().__init__(normalized_shape=normalized_shape, eps=eps, elementwise_affine=True)

def forward(self, input: torch.Tensor):
        # NOTE (mkozuki): the CPU path is here mainly for the unit tests; it falls back
        # to the manual_rms_norm reference implementation above.
if not input.is_cuda:
return manual_rms_norm(input, self.normalized_shape, self.weight, self.eps)
return mixed_dtype_fused_rms_norm_affine(input, self.weight, self.normalized_shape, self.eps)
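
A sketch of the dtype behaviour this class exists for, following the note above (values illustrative; the float32 output assumes the kernel does use the parameter's dtype as stated):

import torch
from apex.normalization import MixedFusedRMSNorm

m = MixedFusedRMSNorm(1024).cuda()                      # weight parameter stays in float32
x = torch.randn(2, 1024, device="cuda", dtype=torch.float16)
y = m(x)
print(y.dtype)                                          # expected: torch.float32, the parameter's dtype
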