diff --git a/src/liger_kernel/ops/cross_entropy.py b/src/liger_kernel/ops/cross_entropy.py
index b09d1ddbc..4eafd1736 100644
--- a/src/liger_kernel/ops/cross_entropy.py
+++ b/src/liger_kernel/ops/cross_entropy.py
@@ -13,6 +13,8 @@ def liger_cross_entropy_kernel(
     Y_stride,
     loss_ptr,
     loss_stride,
+    dX_ptr,
+    dX_stride,
     n_cols,
     n_non_ignore,
     ignore_index,
@@ -49,12 +51,13 @@ def liger_cross_entropy_kernel(
 
     # 2. locate the start index
     X_ptr += program_id * X_stride
+    dX_ptr += program_id * dX_stride
 
     if y == ignore_index:
         # set all X_ptr as 0
         for i in range(0, n_cols, BLOCK_SIZE):
-            X_offsets = i + tl.arange(0, BLOCK_SIZE)
-            tl.store(X_ptr + X_offsets, 0.0, mask=X_offsets < n_cols)
+            dX_offsets = i + tl.arange(0, BLOCK_SIZE)
+            tl.store(dX_ptr + dX_offsets, 0.0, mask=dX_offsets < n_cols)
         return
 
     loss_ptr += program_id * loss_stride
@@ -106,15 +109,15 @@ def liger_cross_entropy_kernel(
 
     for i in range(0, n_cols, BLOCK_SIZE):
         X_offsets = i + tl.arange(0, BLOCK_SIZE)
-        X_block = tl.load(
+        dX_block = tl.load(
             X_ptr + X_offsets, mask=X_offsets < n_cols, other=float("-inf")
         )
         if reduction == "mean":
-            X_block = (tl.exp(X_block - m) / d - eps) / (n_non_ignore)
+            dX_block = (tl.exp(dX_block - m) / d - eps) / (n_non_ignore)
         else:
-            X_block = tl.exp(X_block - m) / d - eps
+            dX_block = tl.exp(dX_block - m) / d - eps
 
-        tl.store(X_ptr + X_offsets, X_block, mask=X_offsets < n_cols)
+        tl.store(dX_ptr + X_offsets, dX_block, mask=X_offsets < n_cols)
 
     # We need tl.debug_barrier() to ensure the new result of X_ptr is written as mentioned in
     # https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/ops/cross_entropy.py#L34
@@ -145,14 +148,14 @@ def liger_cross_entropy_kernel(
         loss = loss / n_non_ignore
 
     # 6. Specially handle the i==y case where `dx_y = (softmax(x_y) - (1 - label_smoothing) / N`
-    X_y = tl.load(X_ptr + y)
+    dX_y = tl.load(dX_ptr + y)
     if reduction == "mean":
-        X_y += -(1 - label_smoothing) / (n_non_ignore)
+        dX_y += -(1 - label_smoothing) / (n_non_ignore)
     else:
-        X_y += -(1 - label_smoothing)
+        dX_y += -(1 - label_smoothing)
 
     tl.store(loss_ptr, loss)
-    tl.store(X_ptr + y, X_y)
+    tl.store(dX_ptr + y, dX_y)
 
 
 # The hard limit of TRITON_MAX_TENSOR_NUMEL is 1048576 https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/language/core.py#L19
@@ -161,7 +164,9 @@
 MAX_FUSED_SIZE = 65536 // 2  # the best size we found by manually tuning
 
 
-def cross_entropy_forward(_input, target, ignore_index, label_smoothing, reduction):
+def cross_entropy_forward(
+    _input, target, ignore_index, label_smoothing, reduction, inplace
+):
     BT, V = _input.shape
     n_rows = BT
 
@@ -178,7 +183,8 @@ def cross_entropy_forward(_input, target, ignore_index, label_smoothing, reducti
     if target.stride(-1) != 1:
         target = target.contiguous()
 
-    # Here we use a trick to store X_ptr gradient in X_ptr so we can save memory
+    dX = _input if inplace else torch.empty_like(_input)  # reuse _input's storage for the gradient when inplace to save memory
+
     liger_cross_entropy_kernel[(n_rows,)](
         X_ptr=_input,
         X_stride=_input.stride(-2),
@@ -186,6 +192,8 @@ def cross_entropy_forward(_input, target, ignore_index, label_smoothing, reducti
         Y_stride=target.stride(-1),  # always 1
         loss_ptr=loss_1d,
         loss_stride=loss_1d.stride(-1),  # always 1
+        dX_ptr=dX,
+        dX_stride=dX.stride(-2),
         n_cols=V,
         n_non_ignore=n_non_ignore,
         ignore_index=ignore_index,
@@ -198,7 +206,7 @@ def cross_entropy_forward(_input, target, ignore_index, label_smoothing, reducti
     )
 
     loss = torch.sum(loss_1d)
-    return loss, _input
+    return loss, dX
 
 
 def cross_entropy_backward(_input, grad_output):
@@ -233,7 +241,13 @@ class LigerCrossEntropyFunction(torch.autograd.Function):
 
     @staticmethod
     def forward(
-        ctx, _input, target, ignore_index=-100, label_smoothing=0.0, reduction="mean"
+        ctx,
+        _input,
+        target,
+        ignore_index=-100,
+        label_smoothing=0.0,
+        reduction="mean",
+        inplace=True,
     ):
         """
         The forward pass of the Liger Cross Entropy loss.
@@ -250,16 +264,20 @@ def forward(
             tensor: The computed loss.
         """
         loss, _input = cross_entropy_forward(
-            _input, target, ignore_index, label_smoothing, reduction
+            _input, target, ignore_index, label_smoothing, reduction, inplace
         )
         # TODO: investigation
         # If we don't detach the _input tensor, the memory will double
         # Not sure why but seems that there will be a time both grad and value exist but in different location
         ctx.save_for_backward(_input.detach())
-        return loss
+
+        if inplace:
+            ctx.mark_dirty(_input)
+            ctx.mark_non_differentiable(_input)
+        return loss, _input
 
     @staticmethod
-    def backward(ctx, grad_output):
+    def backward(ctx, grad_output, grad_output2):
         """
         The backward pass of the Liger Cross Entropy loss.
 
@@ -270,6 +288,7 @@ def backward(ctx, grad_output):
         Returns:
             tuple: A tuple with the gradients with respect to the inputs. The elements are tensors or None.
         """
+        del grad_output2
         (_input,) = ctx.saved_tensors
         _input = cross_entropy_backward(_input, grad_output)
         return (
@@ -278,4 +297,5 @@ def backward(ctx, grad_output):
             None,
             None,
             None,
+            None,
         )
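Note on the kernel change above (editorial, not part of the patch): dX_ptr now receives the per-row gradient that was previously written back over X_ptr, and the buffer is either the logits themselves (inplace) or a fresh tensor. A rough eager-PyTorch sketch of what ends up in dX follows; reference_ce_grad is a hypothetical helper, and it assumes the kernel's eps equals label_smoothing / V.

import torch

def reference_ce_grad(logits, target, ignore_index=-100, label_smoothing=0.0, reduction="mean"):
    # logits: (BT, V), target: (BT,) -- mirrors the per-row math the Triton kernel writes into dX_ptr.
    V = logits.shape[-1]
    eps = label_smoothing / V
    non_ignore = target != ignore_index
    n_non_ignore = non_ignore.sum()

    dX = torch.softmax(logits.float(), dim=-1) - eps            # softmax(x_i) - eps for every column
    rows = torch.arange(logits.shape[0], device=logits.device)
    dX[rows, target.clamp(min=0)] -= 1.0 - label_smoothing      # extra term at i == y
    dX[~non_ignore] = 0.0                                        # rows with y == ignore_index get zero gradient
    if reduction == "mean":
        dX = dX / n_non_ignore
    return dX.to(logits.dtype)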
diff --git a/src/liger_kernel/ops/fused_linear_cross_entropy.py b/src/liger_kernel/ops/fused_linear_cross_entropy.py
index ac11fd173..7b68efe15 100644
--- a/src/liger_kernel/ops/fused_linear_cross_entropy.py
+++ b/src/liger_kernel/ops/fused_linear_cross_entropy.py
@@ -87,6 +87,8 @@ def fused_linear_cross_entropy_forward(
             Y_stride=target_chunk.stride(-1),  # always 1
             loss_ptr=loss_1d_slice,
             loss_stride=loss_1d_slice.stride(-1),  # always 1
+            dX_ptr=logits_chunk,
+            dX_stride=logits_chunk.stride(-2),
             n_cols=V,
             n_non_ignore=n_non_ignore,
             ignore_index=ignore_index,
diff --git a/src/liger_kernel/transformers/cross_entropy.py b/src/liger_kernel/transformers/cross_entropy.py
index b2457481b..52d9d465d 100644
--- a/src/liger_kernel/transformers/cross_entropy.py
+++ b/src/liger_kernel/transformers/cross_entropy.py
@@ -1,21 +1,30 @@
-from torch.nn import CrossEntropyLoss
+import torch
 
 from liger_kernel.ops.cross_entropy import LigerCrossEntropyFunction
 
 
-class LigerCrossEntropyLoss(CrossEntropyLoss):
-    def __init__(self, *args, **kwargs):
-        super(LigerCrossEntropyLoss, self).__init__(*args, **kwargs)
-        assert (self.label_smoothing >= 0) and (
-            self.label_smoothing <= 1
-        ), f"label_smoothing must be between 0.0 and 1.0. Got: {self.label_smoothing}"
-        assert self.reduction in {
+class LigerCrossEntropyLoss(torch.nn.Module):
+    def __init__(self, ignore_index=-100, label_smoothing=0.0, reduction="mean"):
+        super().__init__()
+        assert (label_smoothing >= 0) and (
+            label_smoothing <= 1
+        ), f"label_smoothing must be between 0.0 and 1.0. Got: {label_smoothing}"
+        assert reduction in {
             "mean",
             "sum",
             "none",
-        }, f"reduction must be one of 'mean', 'sum', or 'none'. Got: {self.reduction}"
+        }, f"reduction must be one of 'mean', 'sum', or 'none'. Got: {reduction}"
+        self.ignore_index = ignore_index
+        self.label_smoothing = label_smoothing
+        self.reduction = reduction
 
-    def forward(self, _input, target):
-        return LigerCrossEntropyFunction.apply(
-            _input, target, self.ignore_index, self.label_smoothing, self.reduction
+    def forward(self, _input, target, inplace=True):
+        loss, _ = LigerCrossEntropyFunction.apply(
+            _input,
+            target,
+            self.ignore_index,
+            self.label_smoothing,
+            self.reduction,
+            inplace,
         )
+        return loss
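Note on the module change above (editorial, not part of the patch): a minimal usage sketch, assuming the diff is applied. With inplace=True the logits buffer is overwritten with the gradient during the forward pass, so pass inplace=False whenever the logits must stay intact, for example when they are a leaf tensor with requires_grad=True, which autograd does not allow to be modified in place.

import torch
from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss

loss_fn = LigerCrossEntropyLoss(ignore_index=-100, label_smoothing=0.1, reduction="mean")

logits = torch.randn(4 * 128, 32000, device="cuda", requires_grad=True)
target = torch.randint(0, 32000, (4 * 128,), device="cuda")

loss = loss_fn(logits, target, inplace=False)  # gradient goes to a separate buffer; logits stay usable
loss.backward()
print(logits.grad.shape)  # torch.Size([512, 32000])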
diff --git a/test/transformers/test_cross_entropy.py b/test/transformers/test_cross_entropy.py
index 1a970573e..58119de75 100644
--- a/test/transformers/test_cross_entropy.py
+++ b/test/transformers/test_cross_entropy.py
@@ -149,8 +149,8 @@ def _test_correctness_functional(B, T, V, scalar, dtype, atol, rtol):
 
     target = torch.randint(0, V, (B * T,), device="cuda", dtype=torch.long)
 
-    y1 = liger_cross_entropy(x1, target, 0)
-    y2 = LigerCrossEntropyFunction.apply(x2, target, 0)
+    y1, _ = liger_cross_entropy(x1, target, 0)
+    y2, _ = LigerCrossEntropyFunction.apply(x2, target, 0)
 
     assert torch.allclose(y1, y2, atol=atol, rtol=rtol)
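Note (editorial, not part of the patch): in the same spirit as the functional test above, a hypothetical extra check that exercises the new non-inplace path against torch.nn.functional.cross_entropy; the test name and shapes are illustrative only.

import torch
from liger_kernel.ops.cross_entropy import LigerCrossEntropyFunction

def _test_correctness_non_inplace(B=2, T=64, V=512, atol=1e-5, rtol=1e-5):
    x = torch.randn(B * T, V, device="cuda", requires_grad=True)
    target = torch.randint(0, V, (B * T,), device="cuda")

    # Positional args follow forward(ctx, _input, target, ignore_index, label_smoothing, reduction, inplace).
    loss, _ = LigerCrossEntropyFunction.apply(x, target, -100, 0.0, "mean", False)
    ref = torch.nn.functional.cross_entropy(x, target)
    assert torch.allclose(loss, ref, atol=atol, rtol=rtol)

    loss.backward()
    x_ref = x.detach().clone().requires_grad_(True)
    torch.nn.functional.cross_entropy(x_ref, target).backward()
    assert torch.allclose(x.grad, x_ref.grad, atol=atol, rtol=rtol)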