diff --git a/src/liger_kernel/ops/cross_entropy.py b/src/liger_kernel/ops/cross_entropy.py index ca6f9f7db..3fa12c1ef 100644 --- a/src/liger_kernel/ops/cross_entropy.py +++ b/src/liger_kernel/ops/cross_entropy.py @@ -30,11 +30,14 @@ def liger_cross_entropy_kernel( X_stride, Y_ptr, Y_stride, + weight_ptr, loss_ptr, z_loss_ptr, loss_stride, n_cols, n_non_ignore, + sum_non_ignore_weight, + weight_sum, ignore_index, lse_square_scale: tl.constexpr, label_smoothing: tl.constexpr, @@ -42,6 +45,7 @@ def liger_cross_entropy_kernel( softcap, RETURN_Z_LOSS: tl.constexpr, BLOCK_SIZE: tl.constexpr, + HAS_WEIGHT: tl.constexpr, HAS_SOFTCAPPING: tl.constexpr, ): """ @@ -53,18 +57,22 @@ def liger_cross_entropy_kernel( X_stride (int): The stride of the input tensor. Y_ptr: Pointer to target tensor. Y_stride (int): The stride of the target tensor. + weight_ptr: Pointer to weight tensor. loss_ptr: Pointer to tensor to store the loss. z_loss_ptr: Pointer to tensor to store the z loss. No operation if RETURN_Z_LOSS is 0. loss_stride (int): The stride of the loss tensor. n_cols (int): The number of columns in the input tensor. - n_non_ignore (int): The number of non-ignored elements in the batch. + n_non_ignore (flaot): The number of non-ignored elements in the batch. + sum_non_ignore_weight (float): The sum of non-ignored target's weights in the batch. + weight_sum (float): The sum of weight tensor. ignore_index (int): The index to ignore in the target. label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing. lse_square_scale (float): The scaler of (logsumexp(_input)) ^ 2 adding to the loss for the stability of training. - RETURN_Z_LOSS (int): The boolean value to decide whether storing z loss to z_loss_ptr or not. It must be 0 or 1. reduction (str): The string for the reduction to apply softcap (float): The upper threshold for scaling logits to the range (-softcap, +softcap). + RETURN_Z_LOSS (int): The boolean value to decide whether storing z loss to z_loss_ptr or not. It must be 0 or 1. BLOCK_SIZE (int): The block size for Triton operations. + HAS_WEIGHT (bool): The boolean value to determine whether assigning weight to each of the classes. HAS_SOFTCAPPING (bool): The boolean value to determine whether applying soft-capping or not. """ @@ -89,6 +97,9 @@ def liger_cross_entropy_kernel( loss_ptr += program_id * loss_stride z_loss_ptr += program_id * loss_stride + if HAS_WEIGHT: + weight_y = tl.load(weight_ptr + y).cast(tl.float32) + # Online softmax: 2 loads + 1 store (compared with 3 loads + 1 store for the safe softmax) # Refer to Algorithm 3 in the paper: https://arxiv.org/pdf/1805.02867 @@ -117,7 +128,11 @@ def liger_cross_entropy_kernel( block_max = tl.max(X_block) if label_smoothing > 0: # scale X beforehand to avoid overflow - scaled_x_sum += tl.sum(tl.where(X_offsets < n_cols, -eps * X_block, 0.0)) + if HAS_WEIGHT: + weight_block = tl.load(weight_ptr + X_offsets, mask=X_offsets < n_cols) + scaled_x_sum += tl.sum(tl.where(X_offsets < n_cols, -eps * X_block * weight_block, 0.0)) + else: + scaled_x_sum += tl.sum(tl.where(X_offsets < n_cols, -eps * X_block, 0.0)) m_new = tl.maximum(m, block_max) d = d * tl.exp(m - m_new) + tl.sum(tl.exp(X_block - m_new)) m = m_new @@ -153,18 +168,41 @@ def liger_cross_entropy_kernel( if HAS_SOFTCAPPING: intermediate = tanh(X_block / softcap) X_block = softcap * intermediate - # softmax(x_i) - X_block = tl.exp(X_block - m) / d - # derivative of z-loss: 2 * lse_square_scale * lse * softmax(x_i) - X_block += 2 * lse_square_scale * lse * X_block - # smoothing term - X_block += -eps - # special handle dx_y - X_block = tl.where(X_offsets != y, X_block, X_block - (1 - label_smoothing)) - # reduction scale - if reduction == "mean": - X_block = X_block / (n_non_ignore) - # chain rule + + if not HAS_WEIGHT: + # softmax(x_i) + X_block = tl.exp(X_block - m) / d + # derivative of z-loss: 2 * lse_square_scale * lse * softmax(x_i) + X_block += 2 * lse_square_scale * lse * X_block + # smoothing term + X_block += -eps + # special handle dx_y + X_block = tl.where(X_offsets != y, X_block, X_block - (1 - label_smoothing)) + # reduction scale + if reduction == "mean": + X_block = X_block / n_non_ignore + else: + weight_block = tl.load(weight_ptr + X_offsets, mask=X_offsets < n_cols) + softmax_X = tl.exp(X_block - m) / d + # derivative of original_loss + dloss_ori = (1 - label_smoothing) * softmax_X + # specially handle dx_y + dloss_ori = tl.where(X_offsets != y, dloss_ori, dloss_ori - (1 - label_smoothing)) + dloss_ori = dloss_ori * weight_y + # derivative of smooth_loss + dloss_smooth = eps * (-weight_block + softmax_X * weight_sum) + # derivative of z-loss + dz_loss = 2 * lse_square_scale * lse * softmax_X + # reduction scale + if reduction == "mean": + dloss_ori = dloss_ori / sum_non_ignore_weight + dloss_smooth = dloss_smooth / sum_non_ignore_weight + # TODO: Implement weighted z_loss. Currently, z_loss is not scaled by weight. + dz_loss = dz_loss / n_non_ignore + # derivative of total_loss + X_block = dloss_ori + dloss_smooth + dz_loss + + # chain rule softcapping # d(softcap * tanh(x / softcap)) = (1 - tanh^2(x / softcap)) if HAS_SOFTCAPPING: X_block = X_block * (1 - intermediate * intermediate) @@ -183,6 +221,8 @@ def liger_cross_entropy_kernel( # sum(e ^ (X - max(X))) must >= 1 because the max term is e ^ 0 = 1 # So we can safely calculate log (softmax(X_y)) without overflow loss = lse - ori_X_y + if HAS_WEIGHT: + loss = weight_y * loss # Original loss = H(q, p), with label smoothing regularization = H(q', p) and (label_smoothing / V) = eps # H(q', p) = (1 - label_smoothing) * H(q, p) + label_smoothing * H(u, p) @@ -193,17 +233,24 @@ def liger_cross_entropy_kernel( # pytorch: https://github.com/pytorch/pytorch/blob/2981534f54d49fa3a9755c9b0855e7929c2527f0/aten/src/ATen/native/LossNLL.cpp#L516 # See full derivation at https://github.com/linkedin/Liger-Kernel/pull/198#issuecomment-2333753087 if label_smoothing > 0: - smooth_loss = scaled_x_sum + label_smoothing * lse + if HAS_WEIGHT: + smooth_loss = scaled_x_sum + eps * lse * weight_sum + else: + smooth_loss = scaled_x_sum + label_smoothing * lse loss = loss * (1 - label_smoothing) + smooth_loss # An auxiliary loss, z_loss # Refer to Page14 Loss function section in the paper PaLM: https://www.jmlr.org/papers/v24/22-1144.html z_loss = lse_square_scale * lse * lse - loss += z_loss # Normalize the loss by the number of non-ignored elements if reduction is "mean" if reduction == "mean": + if HAS_WEIGHT: + loss = loss / sum_non_ignore_weight + else: + loss = loss / n_non_ignore + # TODO: Implement weighted z_loss. Currently, z_loss is not scaled by weight. z_loss = z_loss / n_non_ignore - loss = loss / n_non_ignore + loss += z_loss tl.store(loss_ptr, loss) if RETURN_Z_LOSS == _TRUE: @@ -225,6 +272,7 @@ def liger_cross_entropy_kernel( def cross_entropy_forward( _input, target, + weight, ignore_index, lse_square_scale, label_smoothing, @@ -250,7 +298,20 @@ def cross_entropy_forward( else: z_loss_1d = loss_1d # dummy ptr when return_z_loss == False - n_non_ignore = (target != ignore_index).sum().item() + target_mask = target != ignore_index + n_non_ignore = target_mask.sum().item() + sum_non_ignore_weight = n_non_ignore + weight_sum = 0.0 + if weight is not None: + assert weight.shape[0] == V, f"If given, weight has to be a Tensor of size V. Got: {weight.shape}" + assert torch.is_floating_point( + weight + ), f"If given, weight has to be a Tensor of floating point dtype. Got: {weight.dtype}" + sum_non_ignore_weight = torch.gather(weight, dim=0, index=target.masked_select(target_mask)).sum().item() + weight_sum = weight.sum().item() + # ensure weight is contiguous + if weight.stride(-1) != 1: + weight = weight.contiguous() # ensure _input and target are contiguous in the last dimension if _input.stride(-1) != 1: @@ -264,18 +325,22 @@ def cross_entropy_forward( X_stride=_input.stride(-2), Y_ptr=target, Y_stride=target.stride(-1), # always 1 + weight_ptr=weight if weight is not None else _input, # dummy if None loss_ptr=loss_1d, z_loss_ptr=z_loss_1d, loss_stride=loss_1d.stride(-1), # always 1 n_cols=V, n_non_ignore=n_non_ignore, + sum_non_ignore_weight=sum_non_ignore_weight, ignore_index=ignore_index, + weight_sum=weight_sum, lse_square_scale=lse_square_scale, label_smoothing=label_smoothing, reduction=reduction, softcap=softcap if softcap is not None else 0.0, RETURN_Z_LOSS=return_z_loss, BLOCK_SIZE=BLOCK_SIZE, + HAS_WEIGHT=True if weight is not None else False, HAS_SOFTCAPPING=True if softcap is not None else False, # TODO: 32 seems to give the best performance # Performance is quite sensitive to num_warps @@ -327,6 +392,7 @@ def forward( ctx, _input: torch.Tensor, target: torch.Tensor, + weight: Optional[torch.FloatTensor], ignore_index: int = -100, lse_square_scale: float = 0.0, label_smoothing: float = 0.0, @@ -341,6 +407,7 @@ def forward( ctx : The context object. _input (tensor): The input tensor of shape (BT, V) where B is batch size, T is sequence length, V is vocab size. target (tensor): The target tensor of shape (BT) where each value is in [0, V-1]. + weight(Tensor, optional): a manual rescaling weight given to each class. If given, has to be a Tensor of size V and floating point dtype ignore_index (int): The index to ignore in the target. lse_square_scale (float): The scaler of (logsumexp(_input)) ^ 2 adding to the loss for the stability of training. label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing. @@ -354,6 +421,7 @@ def forward( loss, z_loss, _input = cross_entropy_forward( _input, target, + weight, ignore_index, lse_square_scale, label_smoothing, @@ -395,4 +463,5 @@ def backward(ctx, grad_output, grad_ouput2): None, None, None, + None, ) diff --git a/src/liger_kernel/ops/fused_linear_cross_entropy.py b/src/liger_kernel/ops/fused_linear_cross_entropy.py index 83fcb108a..4df484135 100644 --- a/src/liger_kernel/ops/fused_linear_cross_entropy.py +++ b/src/liger_kernel/ops/fused_linear_cross_entropy.py @@ -17,6 +17,7 @@ def fused_linear_cross_entropy_forward( _input, weight, target, + ce_weight=None, bias=None, ignore_index=-100, lse_square_scale=0.0, @@ -47,8 +48,22 @@ def fused_linear_cross_entropy_forward( # we use fp32 for loss accumulator loss_1d = torch.zeros(BT, dtype=torch.float32, device=device) - # NOTE: skip .item() here to avoid CUDA synchronization - total_n_non_ignore = (target != ignore_index).sum() + # TODO: evaluate how CUDA synchronization caused by .item() affects the speed + target_mask = target != ignore_index + total_n_non_ignore = target_mask.sum().item() + total_sum_non_ignore_ce_weight = total_n_non_ignore + ce_weight_sum = 0.0 + if ce_weight is not None: + assert ce_weight.shape[0] == V, f"If given, weight has to be a Tensor of size V. Got: {ce_weight.shape}" + assert torch.is_floating_point( + ce_weight + ), f"If given, weight has to be a Tensor of floating point dtype. Got: {ce_weight.dtype}" + total_sum_non_ignore_ce_weight = ( + torch.gather(ce_weight, dim=0, index=target.masked_select(target_mask)).sum().item() + ) + ce_weight_sum = ce_weight.sum().item() + if ce_weight.stride(-1) != 1: + ce_weight = ce_weight.contiguous() for chunk_id in range(num_chunks): start_idx = chunk_id * chunk_size @@ -66,7 +81,6 @@ def fused_linear_cross_entropy_forward( # unreduced loss loss_1d_slice = loss_1d[start_idx:end_idx] # chunk_size, - n_non_ignore = (target_chunk != ignore_index).sum().item() # ensure _input and target are contiguous logits_chunk = logits_chunk.contiguous() @@ -78,35 +92,28 @@ def fused_linear_cross_entropy_forward( X_stride=logits_chunk.stride(-2), Y_ptr=target_chunk, Y_stride=target_chunk.stride(-1), # always 1 + weight_ptr=ce_weight if ce_weight is not None else _input, # dummy if None loss_ptr=loss_1d_slice, z_loss_ptr=loss_1d_slice, # dummy ptr, not used loss_stride=loss_1d_slice.stride(-1), # always 1 n_cols=V, - n_non_ignore=n_non_ignore, + n_non_ignore=total_n_non_ignore, + sum_non_ignore_weight=total_sum_non_ignore_ce_weight, + weight_sum=ce_weight_sum, ignore_index=ignore_index, lse_square_scale=lse_square_scale, label_smoothing=label_smoothing, reduction=reduction, softcap=softcap if softcap is not None else 0.0, RETURN_Z_LOSS=0, # False + HAS_WEIGHT=True if ce_weight is not None else False, HAS_SOFTCAPPING=True if softcap is not None else False, BLOCK_SIZE=BLOCK_SIZE, num_warps=32 if not is_hip() else 16, ) - # gradient of logits_chunk is computed in-place by the above triton kernel and is of shape: chunk_size x V - # thus grad_input[start_idx: end_idx] should be of shape: chunk_size x H - # additionally, since we are chunking the inputs, observe that the loss and gradients are calculated only - # on `n_non_ignore` tokens. However, the gradient of the input should be calculated for all tokens. - # Thus, we need an additional scaling factor of (n_non_ignore/total_n_non_ignore) to scale the gradients. - - if reduction == "mean": - alpha = n_non_ignore / total_n_non_ignore if total_n_non_ignore > 0 else 0.0 - else: - alpha = 1.0 - - loss_1d[start_idx:end_idx] = loss_1d_slice * alpha - grad_logits_chunk = logits_chunk * alpha # chunk_size x V + loss_1d[start_idx:end_idx] = loss_1d_slice + grad_logits_chunk = logits_chunk # chunk_size x V grad_input[start_idx:end_idx] = grad_logits_chunk @ weight @@ -118,7 +125,7 @@ def fused_linear_cross_entropy_forward( ), # In an autocast scenario without bias, differing logits_chunk data types will cause an addmm operation error. mat2=_input_chunk, out=grad_weight, - alpha=alpha, + alpha=1.0, beta=1.0, ) @@ -127,7 +134,7 @@ def fused_linear_cross_entropy_forward( input=grad_bias, other=logits_chunk.sum(dim=0), out=grad_bias, - alpha=alpha, + alpha=1.0, ) if reduction == "none": @@ -193,6 +200,7 @@ def forward( weight, target, bias=None, + ce_weight=None, ignore_index=-100, lse_square_scale=0.0, label_smoothing=0.0, @@ -212,21 +220,23 @@ def forward( target: (B*T) where each value is in [0, V-1] weight: (V, H) where V is the number of classes bias: (V) where V is the number of classes + ce_weight: a manual rescaling weight given to each class. If given, has to be a Tensor of size V and floating point dtype ignore_index: the index to ignore in the target label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing. reduction: reduction to apply """ loss, grad_input, grad_weight, grad_bias = fused_linear_cross_entropy_forward( - _input, - weight, - target, - bias, - ignore_index, - lse_square_scale, - label_smoothing, - reduction, - softcap, + _input=_input, + weight=weight, + target=target, + bias=bias, + ce_weight=ce_weight, + ignore_index=ignore_index, + lse_square_scale=lse_square_scale, + label_smoothing=label_smoothing, + reduction=reduction, + softcap=softcap, ) # downcast to dtype and store for backward ctx.save_for_backward( @@ -243,4 +253,15 @@ def backward(ctx, grad_output): grad_input, grad_weight, grad_bias = fused_linear_cross_entropy_backward( grad_output, grad_input, grad_weight, grad_bias ) - return (grad_input, grad_weight, None, grad_bias, None, None, None, None, None) + return ( + grad_input, + grad_weight, + None, + grad_bias, + None, + None, + None, + None, + None, + None, + ) diff --git a/src/liger_kernel/transformers/cross_entropy.py b/src/liger_kernel/transformers/cross_entropy.py index 9131c2e9e..d72fc3b00 100644 --- a/src/liger_kernel/transformers/cross_entropy.py +++ b/src/liger_kernel/transformers/cross_entropy.py @@ -8,6 +8,7 @@ class LigerCrossEntropyLoss(torch.nn.Module): def __init__( self, + weight: Optional[torch.FloatTensor] = None, ignore_index: int = -100, lse_square_scale: float = 0.0, label_smoothing: float = 0.0, @@ -28,6 +29,7 @@ def __init__( "none", }, f"reduction must be one of 'mean', 'sum', or 'none'. Got: {reduction}" assert softcap is None or softcap > 0, f"softcap must greater than 0.0 or None. Got: {softcap}" + self.weight = weight self.ignore_index = ignore_index self.lse_square_scale = lse_square_scale self.label_smoothing = label_smoothing @@ -39,6 +41,7 @@ def forward(self, _input: torch.Tensor, target: torch.Tensor): loss, z_loss = LigerCrossEntropyFunction.apply( _input, target, + self.weight, self.ignore_index, self.lse_square_scale, self.label_smoothing, diff --git a/src/liger_kernel/transformers/functional.py b/src/liger_kernel/transformers/functional.py index 32c0e3298..dd34fafb1 100644 --- a/src/liger_kernel/transformers/functional.py +++ b/src/liger_kernel/transformers/functional.py @@ -32,6 +32,7 @@ def liger_cross_entropy( loss, z_loss = LigerCrossEntropyFunction.apply( input, target, + weight, ignore_index, lse_square_scale, label_smoothing, @@ -49,6 +50,7 @@ def liger_fused_linear_cross_entropy( weight, target, bias=None, + ce_weight=None, ignore_index: int = -100, lse_square_scale: float = 0.0, label_smoothing: float = 0.0, @@ -60,6 +62,7 @@ def liger_fused_linear_cross_entropy( weight, target, bias, + ce_weight, ignore_index, lse_square_scale, label_smoothing, diff --git a/src/liger_kernel/transformers/fused_linear_cross_entropy.py b/src/liger_kernel/transformers/fused_linear_cross_entropy.py index 406be7737..1de352e6c 100644 --- a/src/liger_kernel/transformers/fused_linear_cross_entropy.py +++ b/src/liger_kernel/transformers/fused_linear_cross_entropy.py @@ -8,6 +8,7 @@ class LigerFusedLinearCrossEntropyLoss(torch.nn.Module): def __init__( self, + ce_weight: Optional[torch.FloatTensor] = None, ignore_index: int = -100, lse_square_scale: float = 0.0, label_smoothing: float = 0.0, @@ -24,6 +25,7 @@ def __init__( "none", }, f"reduction must be one of 'mean', 'sum', or 'none'. Got: {reduction}" assert softcap is None or softcap > 0, f"softcap must greater than 0.0 or None. Got: {softcap}" + self.ce_weight = ce_weight self.ignore_index = ignore_index self.lse_square_scale = lse_square_scale self.label_smoothing = label_smoothing @@ -36,6 +38,7 @@ def forward(self, lin_weight, _input, target, bias=None): lin_weight, target, bias, + self.ce_weight, self.ignore_index, self.lse_square_scale, self.label_smoothing, diff --git a/test/transformers/test_cross_entropy.py b/test/transformers/test_cross_entropy.py index 534637a87..b88033f2a 100644 --- a/test/transformers/test_cross_entropy.py +++ b/test/transformers/test_cross_entropy.py @@ -21,6 +21,7 @@ class CrossEntropyWithZLoss(torch.nn.Module): def __init__( self, + weight=None, lse_square_scale=0.0, reduction="mean", ignore_index=-100, @@ -29,6 +30,7 @@ def __init__( dtype=torch.float32, ): super().__init__() + self.weight = weight self.lse_square_scale = lse_square_scale self.reduction = reduction self.ignore_index = ignore_index @@ -39,10 +41,14 @@ def __init__( def forward(self, logits, targets): # Loss calculations are all in float32 logits = logits.to(torch.float32) + + target_mask = targets != self.ignore_index + # Standard cross entropy loss ce_loss = F.cross_entropy( logits, targets, + weight=self.weight, reduction=self.reduction, label_smoothing=self.label_smoothing, ignore_index=self.ignore_index, @@ -53,9 +59,9 @@ def forward(self, logits, targets): # Z-loss term z_loss = torch.where(targets != self.ignore_index, self.lse_square_scale * (lse**2), 0.0) - z_loss = z_loss.to(logits.dtype) + if self.reduction == "mean": - z_loss = z_loss.sum() / (targets != self.ignore_index).sum() + z_loss = z_loss.sum() / target_mask.sum() elif self.reduction == "sum": z_loss = z_loss.sum() else: @@ -284,6 +290,74 @@ def _test_correctness_with_z_loss_with_other_params_once( assert_verbose_allclose(_input.grad, _input2.grad, atol=atol, rtol=rtol) +def _test_correctness_with_weight_once(target_ce, B, T, V, reduction, weight, scalar, dtype, atol, rtol): + torch.manual_seed(0) + torch_ce = CrossEntropyLoss(weight=weight, reduction=reduction) + + _tensor = torch.randn(B * T, V, device=device, dtype=dtype) * scalar + _input = _tensor.detach().clone().requires_grad_(True) + _input2 = _tensor.detach().clone().requires_grad_(True) + + target = torch.randint(0, V, (B * T,), device=device, dtype=torch.long) + + output = torch_ce(_input, target) + output2 = target_ce(_input2, target) + assert torch.allclose(output, output2, atol=atol, rtol=rtol) + + output.backward(gradient=torch.ones_like(output)) + output2.backward(gradient=torch.ones_like(output)) + assert torch.allclose(_input.grad, _input2.grad, atol=atol, rtol=rtol) + + +def _test_correctness_with_weight_with_other_params_once( + target_ce, + B, + T, + V, + reduction, + weight, + lse_square_scale, + ignore_index, + label_smoothing, + softcap, + scalar, + dtype, + atol, + rtol, +): + torch.manual_seed(0) + torch_ce = CrossEntropyWithZLoss( + weight=weight, + lse_square_scale=lse_square_scale, + ignore_index=ignore_index, + reduction=reduction, + label_smoothing=label_smoothing, + dtype=dtype, + ) + + _tensor = torch.randn(B * T, V, device=device, dtype=dtype) * scalar + # upcasting to match liger's casting strategy + _input = _tensor.detach().clone().requires_grad_(True) + _input2 = _tensor.detach().clone().requires_grad_(True) + + target = torch.randint(0, V, (B * T,), device=device, dtype=torch.long) + + # Assign some random number of elements as ignore_index + num_elements_to_assign = torch.randint( + 1, B * T // 2, (1,) + ).item() # Random number of elements to set to ignore_index + indices_to_assign = torch.randperm(B * T)[:num_elements_to_assign] # Randomly select indices + target[indices_to_assign] = ignore_index + + output = torch_ce(softcap * torch.tanh(_input.to(torch.float32) / softcap), target).to(dtype) + output2 = target_ce(_input2, target) + assert_verbose_allclose(output, output2, atol=atol, rtol=rtol) + + output.backward(gradient=torch.ones_like(output)) + output2.backward(gradient=torch.ones_like(output)) + assert_verbose_allclose(_input.grad, _input2.grad, atol=atol, rtol=rtol) + + def _test_correctness_not_last_layer_once(target_ce, B, T, V, reduction, scalar, dtype, atol, rtol): torch_ce = CrossEntropyLoss(reduction=reduction) @@ -324,6 +398,7 @@ def _test_correctness_functional( y1, y1_z = liger_cross_entropy( x1, target, + None, ignore_index=0, lse_square_scale=1e-4, label_smoothing=0.1, @@ -331,7 +406,7 @@ def _test_correctness_functional( softcap=30.0, return_z_loss=True, ) - y2, y2_z = LigerCrossEntropyFunction.apply(x2, target, 0, 1e-4, 0.1, "mean", 30.0, True) + y2, y2_z = LigerCrossEntropyFunction.apply(x2, target, None, 0, 1e-4, 0.1, "mean", 30.0, True) assert torch.allclose(y1, y2, atol=atol, rtol=rtol) assert torch.allclose(y1_z, y2_z, atol=atol, rtol=rtol) @@ -660,6 +735,104 @@ def test_correctness_with_z_loss_with_other_params_once( (1.0, torch.float32, 1e-8, 1e-6), ], ) +def test_correctness_with_weight_once(B, T, V, reduction, scalar, dtype, atol, rtol): + weight = torch.rand(V, device=device, dtype=dtype) + test_ce = LigerCrossEntropyLoss(weight=weight, reduction=reduction) + _test_correctness_with_weight_once(test_ce, B, T, V, reduction, weight, scalar, dtype, atol, rtol) + + +@pytest.mark.parametrize( + "B, T, V", + [ + (2, 4096, 32000), # llama2, mistral + # # weird shapes + (3, 423, 32000), + ], +) +@pytest.mark.parametrize("reduction", ["sum", "mean", "none"]) +@pytest.mark.parametrize( + "ignore_index, lse_square_scale, label_smoothing, softcap", + [ + (-100, 1e-4, 0.1, 30.0), + (42, 1e-5, 0.2, 40.0), + ], +) +@pytest.mark.parametrize( + "scalar, dtype, atol, rtol", + [ + pytest.param( + 1.0, + torch.bfloat16, + 1e-8, + 5e-2, + marks=pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"), + ), + (1.0, torch.float32, 1e-8, 1e-6), + ], +) +def test_correctness_with_weight_with_other_params_once( + B, + T, + V, + reduction, + lse_square_scale, + ignore_index, + label_smoothing, + softcap, + scalar, + dtype, + atol, + rtol, +): + weight = torch.rand(V, device=device, dtype=torch.float32) # match softcap casting + test_ce = LigerCrossEntropyLoss( + weight=weight, + lse_square_scale=lse_square_scale, + reduction=reduction, + ignore_index=ignore_index, + label_smoothing=label_smoothing, + softcap=softcap, + ) + _test_correctness_with_weight_with_other_params_once( + test_ce, + B, + T, + V, + reduction, + weight, + lse_square_scale, + ignore_index, + label_smoothing, + softcap, + scalar, + dtype, + atol, + rtol, + ) + + +@pytest.mark.parametrize( + "B, T, V", + [ + (2, 4096, 32000), # llama2, mistral + # # weird shapes + (3, 423, 32000), + ], +) +@pytest.mark.parametrize("reduction", ["sum", "mean"]) +@pytest.mark.parametrize( + "scalar, dtype, atol, rtol", + [ + pytest.param( + 1.0, + torch.bfloat16, + 1e-8, + 5e-2, + marks=pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"), + ), + (1.0, torch.float32, 1e-8, 1e-6), + ], +) def test_correctness_not_last_layer(B, T, V, reduction, scalar, dtype, atol, rtol): liger_ce = LigerCrossEntropyLoss(reduction=reduction) _test_correctness_not_last_layer_once(liger_ce, B, T, V, reduction, scalar, dtype, atol, rtol) @@ -693,17 +866,21 @@ def test_float32_internal(): X_stride=X_bf16.stride(-2), Y_ptr=Y, Y_stride=Y.stride(-1), + weight_ptr=X_bf16, # dummy ptr, not used z_loss_ptr=loss_bf16, # dummy ptr, not used loss_ptr=loss_bf16, loss_stride=loss_bf16.stride(-1), n_cols=n_cols, n_non_ignore=n_non_ignore, + sum_non_ignore_weight=n_non_ignore, # not used + weight_sum=0.0, # not used ignore_index=ignore_index, lse_square_scale=lse_square_scale, label_smoothing=label_smoothing, reduction=reduction, softcap=softcap, RETURN_Z_LOSS=0, # False + HAS_WEIGHT=False, HAS_SOFTCAPPING=False, BLOCK_SIZE=BLOCK_SIZE, num_warps=32 if not is_hip() else 16, @@ -717,17 +894,21 @@ def test_float32_internal(): X_stride=X_fp32.stride(-2), Y_ptr=Y, Y_stride=Y.stride(-1), + weight_ptr=X_fp32, # dummy ptr, not used loss_ptr=loss_fp32, z_loss_ptr=loss_fp32, # dummy ptr, not used loss_stride=loss_fp32.stride(-1), n_cols=n_cols, n_non_ignore=n_non_ignore, + sum_non_ignore_weight=n_non_ignore, # not used + weight_sum=n_non_ignore, # not used ignore_index=ignore_index, lse_square_scale=lse_square_scale, label_smoothing=label_smoothing, reduction=reduction, softcap=softcap, RETURN_Z_LOSS=0, # False + HAS_WEIGHT=False, HAS_SOFTCAPPING=False, BLOCK_SIZE=BLOCK_SIZE, num_warps=32 if not is_hip() else 16, diff --git a/test/transformers/test_fused_linear_cross_entropy.py b/test/transformers/test_fused_linear_cross_entropy.py index a4d6ba2ff..d4e919811 100644 --- a/test/transformers/test_fused_linear_cross_entropy.py +++ b/test/transformers/test_fused_linear_cross_entropy.py @@ -39,6 +39,7 @@ def __init__( V: int, dtype: torch.dtype, bias: bool = False, + ce_weight: Optional[torch.FloatTensor] = None, ignore_index: int = -100, lse_square_scale: float = 0.0, label_smoothing: float = 0.0, @@ -48,6 +49,7 @@ def __init__( super().__init__() self.lin = torch.nn.Linear(in_features=H, out_features=V, bias=bias, dtype=dtype) self.ce_loss = CrossEntropyWithZLoss( + weight=ce_weight, ignore_index=ignore_index, lse_square_scale=lse_square_scale, label_smoothing=label_smoothing, @@ -68,6 +70,7 @@ def __init__( H: int, V: int, dtype: torch.dtype, + ce_weight: Optional[torch.FloatTensor] = None, bias: bool = False, ignore_index: int = -100, lse_square_scale: float = 0.0, @@ -78,6 +81,7 @@ def __init__( super().__init__() self.lin = torch.nn.Linear(in_features=H, out_features=V, bias=bias, dtype=dtype) self.ce_loss = LigerFusedLinearCrossEntropyLoss( + ce_weight=ce_weight, ignore_index=ignore_index, lse_square_scale=lse_square_scale, label_smoothing=label_smoothing, @@ -114,15 +118,11 @@ def forward(self, x, y): ) @pytest.mark.parametrize("bias", [True, False]) @pytest.mark.parametrize( - "label_smoothing, ignore_index, lse_square_scale, softcap", + "has_ce_weight, label_smoothing, ignore_index, lse_square_scale, softcap", [ - (0, -100, 0, None), - ( - 0.1, - 42, - 1e-4, - 30.0, - ), # Pass non-default values once to ensure all params work along + (False, 0, -100, 0, None), + # Pass non-default values once to ensure all params work along + (True, 0.1, 42, 1e-4, 30.0), ], ) def test_correctness( @@ -133,6 +133,7 @@ def test_correctness( scalar, dtype, bias, + has_ce_weight, lse_square_scale, label_smoothing, ignore_index, @@ -141,10 +142,15 @@ def test_correctness( atol, rtol, ): + if has_ce_weight: + ce_weight = torch.rand(V, device=device, dtype=torch.float32) + else: + ce_weight = None torch_lm_head_ce = TorchLMHeadCE( H=H, V=V, bias=bias, + ce_weight=ce_weight, lse_square_scale=lse_square_scale, label_smoothing=label_smoothing, ignore_index=ignore_index, @@ -156,6 +162,7 @@ def test_correctness( H=H, V=V, bias=bias, + ce_weight=ce_weight, lse_square_scale=lse_square_scale, label_smoothing=label_smoothing, ignore_index=ignore_index,