diff --git a/src/liger_kernel/ops/fused_linear_cross_entropy.py b/src/liger_kernel/ops/fused_linear_cross_entropy.py index a3d0406f1..341ed3199 100644 --- a/src/liger_kernel/ops/fused_linear_cross_entropy.py +++ b/src/liger_kernel/ops/fused_linear_cross_entropy.py @@ -19,6 +19,7 @@ def fused_linear_cross_entropy_forward( _input, weight, target, + ce_weight=None, bias=None, ignore_index=-100, lse_square_scale=0.0, @@ -54,7 +55,25 @@ def fused_linear_cross_entropy_forward( loss_1d = torch.zeros(BT, dtype=torch.float32, device=device) # NOTE: skip .item() here to avoid CUDA synchronization - total_n_non_ignore = (target != ignore_index).sum() + target_mask = target != ignore_index + total_n_non_ignore = target_mask.sum().item() + total_sum_non_ignore_ce_weight = total_n_non_ignore + ce_weight_sum = 0.0 + if ce_weight is not None: + assert ( + ce_weight.shape[0] == V + ), f"If given, weight has to be a Tensor of size V. Got: {ce_weight.shape}" + assert torch.is_floating_point( + ce_weight + ), f"If given, weight has to be a Tensor of floating point dtype. 
Got: {ce_weight.dtype}" + total_sum_non_ignore_ce_weight = ( + torch.gather(ce_weight, dim=0, index=target.masked_select(target_mask)) + .sum() + .item() + ) + ce_weight_sum = ce_weight.sum().item() + if ce_weight.stride(-1) != 1: + ce_weight = ce_weight.contiguous() for chunk_id in range(num_chunks): start_idx = chunk_id * chunk_size @@ -71,7 +90,6 @@ def fused_linear_cross_entropy_forward( # unreduced loss loss_1d_slice = loss_1d[start_idx:end_idx] # chunk_size, - n_non_ignore = (target_chunk != ignore_index).sum().item() # ensure _input and target are contiguous logits_chunk = logits_chunk.contiguous() @@ -83,14 +101,14 @@ def fused_linear_cross_entropy_forward( X_stride=logits_chunk.stride(-2), Y_ptr=target_chunk, Y_stride=target_chunk.stride(-1), # always 1 - weight_ptr=_input, # dummy ptr, not used - weight_stride=0, + weight_ptr=ce_weight, # per-class CE weight tensor; None when ce_weight is not given loss_ptr=loss_1d_slice, z_loss_ptr=loss_1d_slice, # dummy ptr, not used loss_stride=loss_1d_slice.stride(-1), # always 1 n_cols=V, - n_non_ignore=n_non_ignore, - sum_of_non_ignore_weight=n_non_ignore, + n_non_ignore=total_n_non_ignore, + sum_non_ignore_weight=total_sum_non_ignore_ce_weight, + weight_sum=ce_weight_sum, ignore_index=ignore_index, lse_square_scale=lse_square_scale, label_smoothing=label_smoothing, @@ -103,19 +121,8 @@ def fused_linear_cross_entropy_forward( num_warps=32 if not is_hip() else 16, ) - # gradient of logits_chunk is computed in-place by the above triton kernel and is of shape: chunk_size x V - # thus grad_input[start_idx: end_idx] should be of shape: chunk_size x H - # additionally, since we are chunking the inputs, observe that the loss and gradients are calculated only - # on `n_non_ignore` tokens. However, the gradient of the input should be calculated for all tokens. - # Thus, we need an additional scaling factor of (n_non_ignore/total_n_non_ignore) to scale the gradients.
- - if reduction == "mean": - alpha = n_non_ignore / total_n_non_ignore if total_n_non_ignore > 0 else 0.0 - else: - alpha = 1.0 - - loss_1d[start_idx:end_idx] = loss_1d_slice * alpha - grad_logits_chunk = logits_chunk * alpha # chunk_size x V + loss_1d[start_idx:end_idx] = loss_1d_slice + grad_logits_chunk = logits_chunk # chunk_size x V grad_input[start_idx:end_idx] = grad_logits_chunk @ weight @@ -125,7 +132,7 @@ def fused_linear_cross_entropy_forward( mat1=logits_chunk.t(), mat2=_input_chunk, out=grad_weight, - alpha=alpha, + alpha=1.0, beta=1.0, ) @@ -134,7 +141,7 @@ def fused_linear_cross_entropy_forward( input=grad_bias, other=logits_chunk.sum(dim=0), out=grad_bias, - alpha=alpha, + alpha=1.0, ) loss = torch.sum(loss_1d) @@ -199,6 +206,7 @@ def forward( weight, target, bias=None, + ce_weight=None, ignore_index=-100, lse_square_scale=0.0, label_smoothing=0.0, @@ -218,21 +226,23 @@ def forward( target: (B*T) where each value is in [0, V-1] weight: (V, H) where V is the number of classes bias: (V) where V is the number of classes + ce_weight: a manual rescaling weight given to each class. If given, has to be a Tensor of size V and floating point dtype ignore_index: the index to ignore in the target label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing. 
reduction: reduction to apply """ loss, grad_input, grad_weight, grad_bias = fused_linear_cross_entropy_forward( - _input, - weight, - target, - bias, - ignore_index, - lse_square_scale, - label_smoothing, - reduction, - softcap, + _input=_input, + weight=weight, + target=target, + bias=bias, + ce_weight=ce_weight, + ignore_index=ignore_index, + lse_square_scale=lse_square_scale, + label_smoothing=label_smoothing, + reduction=reduction, + softcap=softcap, ) # downcast to dtype and store for backward ctx.save_for_backward( @@ -249,4 +259,15 @@ def backward(ctx, grad_output): grad_input, grad_weight, grad_bias = fused_linear_cross_entropy_backward( grad_output, grad_input, grad_weight, grad_bias ) - return (grad_input, grad_weight, None, grad_bias, None, None, None, None, None) + return ( + grad_input, + grad_weight, + None, + grad_bias, + None, + None, + None, + None, + None, + None, + ) diff --git a/src/liger_kernel/transformers/functional.py b/src/liger_kernel/transformers/functional.py index 5d6086caa..60d472129 100644 --- a/src/liger_kernel/transformers/functional.py +++ b/src/liger_kernel/transformers/functional.py @@ -52,6 +52,7 @@ def liger_fused_linear_cross_entropy( weight, target, bias=None, + ce_weight=None, ignore_index: int = -100, lse_square_scale: float = 0.0, label_smoothing: float = 0.0, @@ -63,6 +64,7 @@ def liger_fused_linear_cross_entropy( weight, target, bias, + ce_weight, ignore_index, lse_square_scale, label_smoothing, diff --git a/src/liger_kernel/transformers/fused_linear_cross_entropy.py b/src/liger_kernel/transformers/fused_linear_cross_entropy.py index 7df79d309..c13148f91 100644 --- a/src/liger_kernel/transformers/fused_linear_cross_entropy.py +++ b/src/liger_kernel/transformers/fused_linear_cross_entropy.py @@ -10,6 +10,7 @@ class LigerFusedLinearCrossEntropyLoss(torch.nn.Module): def __init__( self, + ce_weight: Optional[torch.FloatTensor] = None, ignore_index: int = -100, lse_square_scale: float = 0.0, label_smoothing: float = 0.0, @@ 
-28,6 +29,7 @@ def __init__( assert ( softcap is None or softcap > 0 ), f"softcap must greater than 0.0 or None. Got: {softcap}" + self.ce_weight = ce_weight self.ignore_index = ignore_index self.lse_square_scale = lse_square_scale self.label_smoothing = label_smoothing @@ -40,6 +42,7 @@ def forward(self, lin_weight, _input, target, bias=None): lin_weight, target, bias, + self.ce_weight, self.ignore_index, self.lse_square_scale, self.label_smoothing, diff --git a/test/transformers/test_fused_linear_cross_entropy.py b/test/transformers/test_fused_linear_cross_entropy.py index a6bcd4d8b..8909d9337 100644 --- a/test/transformers/test_fused_linear_cross_entropy.py +++ b/test/transformers/test_fused_linear_cross_entropy.py @@ -41,6 +41,7 @@ def __init__( V: int, dtype: torch.dtype, bias: bool = False, + ce_weight: Optional[torch.FloatTensor] = None, ignore_index: int = -100, lse_square_scale: float = 0.0, label_smoothing: float = 0.0, @@ -52,6 +53,7 @@ def __init__( in_features=H, out_features=V, bias=bias, dtype=dtype ) self.ce_loss = CrossEntropyWithZLoss( + weight=ce_weight, ignore_index=ignore_index, lse_square_scale=lse_square_scale, label_smoothing=label_smoothing, @@ -72,6 +74,7 @@ def __init__( H: int, V: int, dtype: torch.dtype, + ce_weight: Optional[torch.FloatTensor] = None, bias: bool = False, ignore_index: int = -100, lse_square_scale: float = 0.0, @@ -84,6 +87,7 @@ def __init__( in_features=H, out_features=V, bias=bias, dtype=dtype ) self.ce_loss = LigerFusedLinearCrossEntropyLoss( + ce_weight=ce_weight, ignore_index=ignore_index, lse_square_scale=lse_square_scale, label_smoothing=label_smoothing, @@ -118,15 +122,11 @@ def forward(self, x, y): ) @pytest.mark.parametrize("bias", [True, False]) @pytest.mark.parametrize( - "label_smoothing, ignore_index, lse_square_scale, softcap", + "has_ce_weight, label_smoothing, ignore_index, lse_square_scale, softcap", [ - (0, -100, 0, None), - ( - 0.1, - 42, - 1e-4, - 30.0, - ), # Pass non-default values once to ensure 
all params work along + (False, 0, -100, 0, None), + # Pass non-default values once to ensure all params work along + (True, 0.1, 42, 1e-4, 30.0), ], ) def test_correctness( @@ -137,6 +137,7 @@ def test_correctness( scalar, dtype, bias, + has_ce_weight, lse_square_scale, label_smoothing, ignore_index, @@ -145,10 +146,15 @@ def test_correctness( atol, rtol, ): + if has_ce_weight: + ce_weight = torch.rand(V, device=device, dtype=torch.float32) + else: + ce_weight = None torch_lm_head_ce = TorchLMHeadCE( H=H, V=V, bias=bias, + ce_weight=ce_weight, lse_square_scale=lse_square_scale, label_smoothing=label_smoothing, ignore_index=ignore_index, @@ -160,6 +166,7 @@ def test_correctness( H=H, V=V, bias=bias, + ce_weight=ce_weight, lse_square_scale=lse_square_scale, label_smoothing=label_smoothing, ignore_index=ignore_index,