Add weight support for LigerCrossEntropy (#420)
## Summary
Resolves #404.
Note: the current implementation does not apply the class weights to the z loss.

Reference: [PyTorch's
CrossEntropyLoss](https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html)
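
For a quick sanity check, here is a minimal PyTorch reference snippet (shapes and values are made up for illustration) of the behavior the weighted kernel is expected to match; with `reduction="mean"`, PyTorch normalizes by the summed weights of the non-ignored targets rather than by the target count:

```python
import torch
import torch.nn.functional as F

B_T, V = 8, 32                                   # hypothetical flattened batch size and vocab size
logits = torch.randn(B_T, V, requires_grad=True)
target = torch.randint(0, V, (B_T,))
weight = torch.rand(V)                           # one positive rescaling weight per class

# Reference result the weighted Liger kernel should reproduce (weighted mean reduction).
ref = F.cross_entropy(logits, target, weight=weight, reduction="mean")
ref.backward()                                   # reference gradients for comparison
```
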
## Testing Done
It hasn't been fully tested with other parameter combinations.

- Hardware Type: <BLANK>
- [x] run `make test` to ensure correctness
- [x] run `make checkstyle` to ensure code style
- [x] run `make test-convergence` to ensure convergence
Tcc0403 authored Dec 29, 2024
1 parent 174b191 commit 42ff02a
Showing 7 changed files with 346 additions and 59 deletions.
107 changes: 88 additions & 19 deletions src/liger_kernel/ops/cross_entropy.py
@@ -30,18 +30,22 @@ def liger_cross_entropy_kernel(
X_stride,
Y_ptr,
Y_stride,
weight_ptr,
loss_ptr,
z_loss_ptr,
loss_stride,
n_cols,
n_non_ignore,
sum_non_ignore_weight,
weight_sum,
ignore_index,
lse_square_scale: tl.constexpr,
label_smoothing: tl.constexpr,
reduction: tl.constexpr, # set it as constexpr since reduction is always known at compile time
softcap,
RETURN_Z_LOSS: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
HAS_WEIGHT: tl.constexpr,
HAS_SOFTCAPPING: tl.constexpr,
):
"""
@@ -53,18 +57,22 @@ def liger_cross_entropy_kernel(
X_stride (int): The stride of the input tensor.
Y_ptr: Pointer to target tensor.
Y_stride (int): The stride of the target tensor.
weight_ptr: Pointer to weight tensor.
loss_ptr: Pointer to tensor to store the loss.
z_loss_ptr: Pointer to tensor to store the z loss. No operation if RETURN_Z_LOSS is 0.
loss_stride (int): The stride of the loss tensor.
n_cols (int): The number of columns in the input tensor.
n_non_ignore (int): The number of non-ignored elements in the batch.
n_non_ignore (float): The number of non-ignored elements in the batch.
sum_non_ignore_weight (float): The sum of non-ignored target's weights in the batch.
weight_sum (float): The sum of weight tensor.
ignore_index (int): The index to ignore in the target.
label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing.
lse_square_scale (float): The scale factor applied to (logsumexp(_input)) ^ 2, added to the loss for training stability.
RETURN_Z_LOSS (int): The boolean value to decide whether storing z loss to z_loss_ptr or not. It must be 0 or 1.
reduction (str): The string for the reduction to apply
softcap (float): The upper threshold for scaling logits to the range (-softcap, +softcap).
RETURN_Z_LOSS (int): Whether to store the z loss to z_loss_ptr. It must be 0 or 1.
BLOCK_SIZE (int): The block size for Triton operations.
HAS_WEIGHT (bool): Whether a per-class weight tensor is provided.
HAS_SOFTCAPPING (bool): Whether soft-capping is applied to the logits.
"""

@@ -89,6 +97,9 @@ def liger_cross_entropy_kernel(
loss_ptr += program_id * loss_stride
z_loss_ptr += program_id * loss_stride

if HAS_WEIGHT:
weight_y = tl.load(weight_ptr + y).cast(tl.float32)

# Online softmax: 2 loads + 1 store (compared with 3 loads + 1 store for the safe softmax)
# Refer to Algorithm 3 in the paper: https://arxiv.org/pdf/1805.02867

@@ -117,7 +128,11 @@ def liger_cross_entropy_kernel(
block_max = tl.max(X_block)
if label_smoothing > 0:
# scale X beforehand to avoid overflow
scaled_x_sum += tl.sum(tl.where(X_offsets < n_cols, -eps * X_block, 0.0))
if HAS_WEIGHT:
weight_block = tl.load(weight_ptr + X_offsets, mask=X_offsets < n_cols)
scaled_x_sum += tl.sum(tl.where(X_offsets < n_cols, -eps * X_block * weight_block, 0.0))
else:
scaled_x_sum += tl.sum(tl.where(X_offsets < n_cols, -eps * X_block, 0.0))
m_new = tl.maximum(m, block_max)
d = d * tl.exp(m - m_new) + tl.sum(tl.exp(X_block - m_new))
m = m_new
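
The loop above is the online-softmax recurrence (Algorithm 3 of the referenced paper): per block it keeps a running maximum m and a running denominator d, roughly

$$m_{\text{new}} = \max\big(m,\ \max_j x_j\big), \qquad d_{\text{new}} = d\,e^{\,m - m_{\text{new}}} + \sum_j e^{\,x_j - m_{\text{new}}},$$

which lets the log-sum-exp be recovered afterwards as $m + \log d$ without ever exponentiating a positive number.
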
@@ -153,18 +168,41 @@ def liger_cross_entropy_kernel(
if HAS_SOFTCAPPING:
intermediate = tanh(X_block / softcap)
X_block = softcap * intermediate
# softmax(x_i)
X_block = tl.exp(X_block - m) / d
# derivative of z-loss: 2 * lse_square_scale * lse * softmax(x_i)
X_block += 2 * lse_square_scale * lse * X_block
# smoothing term
X_block += -eps
# special handle dx_y
X_block = tl.where(X_offsets != y, X_block, X_block - (1 - label_smoothing))
# reduction scale
if reduction == "mean":
X_block = X_block / (n_non_ignore)
# chain rule

if not HAS_WEIGHT:
# softmax(x_i)
X_block = tl.exp(X_block - m) / d
# derivative of z-loss: 2 * lse_square_scale * lse * softmax(x_i)
X_block += 2 * lse_square_scale * lse * X_block
# smoothing term
X_block += -eps
# special handle dx_y
X_block = tl.where(X_offsets != y, X_block, X_block - (1 - label_smoothing))
# reduction scale
if reduction == "mean":
X_block = X_block / n_non_ignore
else:
weight_block = tl.load(weight_ptr + X_offsets, mask=X_offsets < n_cols)
softmax_X = tl.exp(X_block - m) / d
# derivative of original_loss
dloss_ori = (1 - label_smoothing) * softmax_X
# specially handle dx_y
dloss_ori = tl.where(X_offsets != y, dloss_ori, dloss_ori - (1 - label_smoothing))
dloss_ori = dloss_ori * weight_y
# derivative of smooth_loss
dloss_smooth = eps * (-weight_block + softmax_X * weight_sum)
# derivative of z-loss
dz_loss = 2 * lse_square_scale * lse * softmax_X
# reduction scale
if reduction == "mean":
dloss_ori = dloss_ori / sum_non_ignore_weight
dloss_smooth = dloss_smooth / sum_non_ignore_weight
# TODO: Implement weighted z_loss. Currently, z_loss is not scaled by weight.
dz_loss = dz_loss / n_non_ignore
# derivative of total_loss
X_block = dloss_ori + dloss_smooth + dz_loss

# chain rule softcapping
# d(softcap * tanh(x / softcap)) = (1 - tanh^2(x / softcap))
if HAS_SOFTCAPPING:
X_block = X_block * (1 - intermediate * intermediate)
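
One way to read the weighted branch above: with $\varepsilon = \text{label\_smoothing}/V$, $\alpha = \text{label\_smoothing}$, softmax probabilities $p$, class weights $w$, $W = \sum_j w_j$ (`weight_sum`), and $\lambda$ = `lse_square_scale`, the per-row gradient it assembles is roughly

$$\frac{\partial \ell}{\partial x_i} \;=\; \underbrace{(1-\alpha)\,w_y\,\big(p_i - \mathbb{1}[i = y]\big)}_{\texttt{dloss\_ori}} \;+\; \underbrace{\varepsilon\,\big(W p_i - w_i\big)}_{\texttt{dloss\_smooth}} \;+\; \underbrace{2\lambda\,\mathrm{lse}\;p_i}_{\texttt{dz\_loss}},$$

with the first two terms divided by `sum_non_ignore_weight` and the z-loss term by `n_non_ignore` under mean reduction (consistent with the TODO: z-loss is not weighted yet).
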
@@ -183,6 +221,8 @@ def liger_cross_entropy_kernel(
# sum(e ^ (X - max(X))) must >= 1 because the max term is e ^ 0 = 1
# So we can safely calculate log (softmax(X_y)) without overflow
loss = lse - ori_X_y
if HAS_WEIGHT:
loss = weight_y * loss

# Original loss = H(q, p), with label smoothing regularization = H(q', p) and (label_smoothing / V) = eps
# H(q', p) = (1 - label_smoothing) * H(q, p) + label_smoothing * H(u, p)
@@ -193,17 +233,24 @@ def liger_cross_entropy_kernel(
# pytorch: https://github.com/pytorch/pytorch/blob/2981534f54d49fa3a9755c9b0855e7929c2527f0/aten/src/ATen/native/LossNLL.cpp#L516
# See full derivation at https://github.com/linkedin/Liger-Kernel/pull/198#issuecomment-2333753087
if label_smoothing > 0:
smooth_loss = scaled_x_sum + label_smoothing * lse
if HAS_WEIGHT:
smooth_loss = scaled_x_sum + eps * lse * weight_sum
else:
smooth_loss = scaled_x_sum + label_smoothing * lse
loss = loss * (1 - label_smoothing) + smooth_loss

# An auxiliary loss, z_loss
# Refer to Page14 Loss function section in the paper PaLM: https://www.jmlr.org/papers/v24/22-1144.html
z_loss = lse_square_scale * lse * lse
loss += z_loss
# Normalize the loss by the number of non-ignored elements if reduction is "mean"
if reduction == "mean":
if HAS_WEIGHT:
loss = loss / sum_non_ignore_weight
else:
loss = loss / n_non_ignore
# TODO: Implement weighted z_loss. Currently, z_loss is not scaled by weight.
z_loss = z_loss / n_non_ignore
loss = loss / n_non_ignore
loss += z_loss

tl.store(loss_ptr, loss)
if RETURN_Z_LOSS == _TRUE:
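
For the `reduction == "mean"` path above (taking `label_smoothing == 0` for brevity), each non-ignored row $i$ contributes roughly

$$\frac{w_{y_i}\,\big(\mathrm{lse}_i - x_{i,y_i}\big)}{\sum_{j:\,y_j \neq \text{ignore\_index}} w_{y_j}} \;+\; \frac{\lambda\,\mathrm{lse}_i^{\,2}}{n_{\text{non\_ignore}}},$$

so the cross-entropy part follows PyTorch's weighted mean, while the z-loss is still averaged over the target count (hence the TODO).
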
@@ -225,6 +272,7 @@ def liger_cross_entropy_kernel(
def cross_entropy_forward(
_input,
target,
weight,
ignore_index,
lse_square_scale,
label_smoothing,
@@ -250,7 +298,20 @@ def cross_entropy_forward(
else:
z_loss_1d = loss_1d # dummy ptr when return_z_loss == False

n_non_ignore = (target != ignore_index).sum().item()
target_mask = target != ignore_index
n_non_ignore = target_mask.sum().item()
sum_non_ignore_weight = n_non_ignore
weight_sum = 0.0
if weight is not None:
assert weight.shape[0] == V, f"If given, weight has to be a Tensor of size V. Got: {weight.shape}"
assert torch.is_floating_point(
weight
), f"If given, weight has to be a Tensor of floating point dtype. Got: {weight.dtype}"
sum_non_ignore_weight = torch.gather(weight, dim=0, index=target.masked_select(target_mask)).sum().item()
weight_sum = weight.sum().item()
# ensure weight is contiguous
if weight.stride(-1) != 1:
weight = weight.contiguous()

# ensure _input and target are contiguous in the last dimension
if _input.stride(-1) != 1:
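
A toy, made-up example of the bookkeeping above (plain indexing here is equivalent to the `gather`/`masked_select` combination in the diff):

```python
import torch

V = 4
ignore_index = -100
target = torch.tensor([0, 2, ignore_index, 2])
weight = torch.tensor([0.5, 1.0, 2.0, 1.5])   # one rescaling weight per class

target_mask = target != ignore_index
n_non_ignore = target_mask.sum().item()                            # 3
sum_non_ignore_weight = weight[target[target_mask]].sum().item()   # 0.5 + 2.0 + 2.0 = 4.5
weight_sum = weight.sum().item()                                    # 5.0
```
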
@@ -264,18 +325,22 @@ def cross_entropy_forward(
X_stride=_input.stride(-2),
Y_ptr=target,
Y_stride=target.stride(-1), # always 1
weight_ptr=weight if weight is not None else _input, # dummy if None
loss_ptr=loss_1d,
z_loss_ptr=z_loss_1d,
loss_stride=loss_1d.stride(-1), # always 1
n_cols=V,
n_non_ignore=n_non_ignore,
sum_non_ignore_weight=sum_non_ignore_weight,
ignore_index=ignore_index,
weight_sum=weight_sum,
lse_square_scale=lse_square_scale,
label_smoothing=label_smoothing,
reduction=reduction,
softcap=softcap if softcap is not None else 0.0,
RETURN_Z_LOSS=return_z_loss,
BLOCK_SIZE=BLOCK_SIZE,
HAS_WEIGHT=True if weight is not None else False,
HAS_SOFTCAPPING=True if softcap is not None else False,
# TODO: 32 seems to give the best performance
# Performance is quite sensitive to num_warps
@@ -327,6 +392,7 @@ def forward(
ctx,
_input: torch.Tensor,
target: torch.Tensor,
weight: Optional[torch.FloatTensor],
ignore_index: int = -100,
lse_square_scale: float = 0.0,
label_smoothing: float = 0.0,
@@ -341,6 +407,7 @@ def forward(
ctx : The context object.
_input (tensor): The input tensor of shape (BT, V) where B is batch size, T is sequence length, V is vocab size.
target (tensor): The target tensor of shape (BT) where each value is in [0, V-1].
weight (Tensor, optional): A manual rescaling weight given to each class. If given, it has to be a Tensor of size V with a floating-point dtype.
ignore_index (int): The index to ignore in the target.
lse_square_scale (float): The scale factor applied to (logsumexp(_input)) ^ 2, added to the loss for training stability.
label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing.
@@ -354,6 +421,7 @@ def forward(
loss, z_loss, _input = cross_entropy_forward(
_input,
target,
weight,
ignore_index,
lse_square_scale,
label_smoothing,
@@ -395,4 +463,5 @@ def backward(ctx, grad_output, grad_ouput2):
None,
None,
None,
None,
)