Commit

Update flce
Tcc0403 committed Dec 22, 2024
1 parent ec134fc commit 7ed4dd9
Showing 4 changed files with 72 additions and 39 deletions.
83 changes: 52 additions & 31 deletions src/liger_kernel/ops/fused_linear_cross_entropy.py
@@ -19,6 +19,7 @@ def fused_linear_cross_entropy_forward(
     _input,
     weight,
     target,
+    ce_weight=None,
     bias=None,
     ignore_index=-100,
     lse_square_scale=0.0,
@@ -54,7 +55,25 @@ def fused_linear_cross_entropy_forward(
     loss_1d = torch.zeros(BT, dtype=torch.float32, device=device)

     # NOTE: skip .item() here to avoid CUDA synchronization
-    total_n_non_ignore = (target != ignore_index).sum()
+    target_mask = target != ignore_index
+    total_n_non_ignore = target_mask.sum().item()
+    total_sum_non_ignore_ce_weight = total_n_non_ignore
+    ce_weight_sum = 0.0
+    if ce_weight is not None:
+        assert (
+            ce_weight.shape[0] == V
+        ), f"If given, weight has to be a Tensor of size V. Got: {ce_weight.shape}"
+        assert torch.is_floating_point(
+            ce_weight
+        ), f"If given, weight has to be a Tensor of floating point dtype. Got: {ce_weight.dtype}"
+        total_sum_non_ignore_ce_weight = (
+            torch.gather(ce_weight, dim=0, index=target.masked_select(target_mask))
+            .sum()
+            .item()
+        )
+        ce_weight_sum = ce_weight.sum().item()
+        if ce_weight.stride(-1) != 1:
+            ce_weight = ce_weight.contiguous()

     for chunk_id in range(num_chunks):
         start_idx = chunk_id * chunk_size
@@ -71,7 +90,6 @@

         # unreduced loss
         loss_1d_slice = loss_1d[start_idx:end_idx]  # chunk_size,
-        n_non_ignore = (target_chunk != ignore_index).sum().item()

         # ensure _input and target are contiguous
         logits_chunk = logits_chunk.contiguous()
@@ -83,14 +101,14 @@
             X_stride=logits_chunk.stride(-2),
             Y_ptr=target_chunk,
             Y_stride=target_chunk.stride(-1),  # always 1
-            weight_ptr=_input,  # dummy ptr, not used
-            weight_stride=0,
+            weight_ptr=ce_weight,  # dummy ptr, not used
             loss_ptr=loss_1d_slice,
             z_loss_ptr=loss_1d_slice,  # dummy ptr, not used
             loss_stride=loss_1d_slice.stride(-1),  # always 1
             n_cols=V,
-            n_non_ignore=n_non_ignore,
-            sum_of_non_ignore_weight=n_non_ignore,
+            n_non_ignore=total_n_non_ignore,
+            sum_non_ignore_weight=total_sum_non_ignore_ce_weight,
+            weight_sum=ce_weight_sum,
             ignore_index=ignore_index,
             lse_square_scale=lse_square_scale,
             label_smoothing=label_smoothing,
@@ -103,19 +121,8 @@
             num_warps=32 if not is_hip() else 16,
         )

-        # gradient of logits_chunk is computed in-place by the above triton kernel and is of shape: chunk_size x V
-        # thus grad_input[start_idx: end_idx] should be of shape: chunk_size x H
-        # additionally, since we are chunking the inputs, observe that the loss and gradients are calculated only
-        # on `n_non_ignore` tokens. However, the gradient of the input should be calculated for all tokens.
-        # Thus, we need an additional scaling factor of (n_non_ignore/total_n_non_ignore) to scale the gradients.
-
-        if reduction == "mean":
-            alpha = n_non_ignore / total_n_non_ignore if total_n_non_ignore > 0 else 0.0
-        else:
-            alpha = 1.0
-
-        loss_1d[start_idx:end_idx] = loss_1d_slice * alpha
-        grad_logits_chunk = logits_chunk * alpha  # chunk_size x V
+        loss_1d[start_idx:end_idx] = loss_1d_slice
+        grad_logits_chunk = logits_chunk  # chunk_size x V

         grad_input[start_idx:end_idx] = grad_logits_chunk @ weight

@@ -125,7 +132,7 @@
                 mat1=logits_chunk.t(),
                 mat2=_input_chunk,
                 out=grad_weight,
-                alpha=alpha,
+                alpha=1.0,
                 beta=1.0,
             )

@@ -134,7 +141,7 @@
                 input=grad_bias,
                 other=logits_chunk.sum(dim=0),
                 out=grad_bias,
-                alpha=alpha,
+                alpha=1.0,
             )

     loss = torch.sum(loss_1d)
@@ -199,6 +206,7 @@ def forward(
         weight,
         target,
         bias=None,
+        ce_weight=None,
         ignore_index=-100,
         lse_square_scale=0.0,
         label_smoothing=0.0,
@@ -218,21 +226,23 @@
         target: (B*T) where each value is in [0, V-1]
         weight: (V, H) where V is the number of classes
         bias: (V) where V is the number of classes
+        ce_weight: a manual rescaling weight given to each class. If given, has to be a Tensor of size V and floating point dtype
         ignore_index: the index to ignore in the target
         label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing.
         reduction: reduction to apply
         """

         loss, grad_input, grad_weight, grad_bias = fused_linear_cross_entropy_forward(
-            _input,
-            weight,
-            target,
-            bias,
-            ignore_index,
-            lse_square_scale,
-            label_smoothing,
-            reduction,
-            softcap,
+            _input=_input,
+            weight=weight,
+            target=target,
+            bias=bias,
+            ce_weight=ce_weight,
+            ignore_index=ignore_index,
+            lse_square_scale=lse_square_scale,
+            label_smoothing=label_smoothing,
+            reduction=reduction,
+            softcap=softcap,
         )
         # downcast to dtype and store for backward
         ctx.save_for_backward(
@@ -249,4 +259,15 @@ def backward(ctx, grad_output):
         grad_input, grad_weight, grad_bias = fused_linear_cross_entropy_backward(
             grad_output, grad_input, grad_weight, grad_bias
         )
-        return (grad_input, grad_weight, None, grad_bias, None, None, None, None, None)
+        return (
+            grad_input,
+            grad_weight,
+            None,
+            grad_bias,
+            None,
+            None,
+            None,
+            None,
+            None,
+            None,
+        )
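
Why the per-chunk alpha scaling could be dropped: the kernel call now receives the global totals (n_non_ignore=total_n_non_ignore, sum_non_ignore_weight=total_sum_non_ignore_ce_weight, weight_sum=ce_weight_sum), so each chunk's loss and gradients already come out normalized against the whole batch, and the old n_non_ignore / total_n_non_ignore rescaling is no longer needed. Below is a minimal eager-mode sketch of the reduction those arguments describe for reduction="mean" (plain PyTorch written for illustration, not the Triton kernel; reference_weighted_ce is an invented name):

import torch
import torch.nn.functional as F

def reference_weighted_ce(logits, target, ce_weight=None, ignore_index=-100):
    # Per-token losses; with ce_weight this is already w[target_i] * CE_i,
    # and ignored positions contribute 0.
    per_token = F.cross_entropy(
        logits,
        target,
        weight=ce_weight,
        ignore_index=ignore_index,
        reduction="none",
    )
    mask = target != ignore_index
    if ce_weight is None:
        denom = mask.sum()                     # total_n_non_ignore
    else:
        # same quantity as torch.gather(ce_weight, 0, target.masked_select(mask)).sum()
        denom = ce_weight[target[mask]].sum()  # total_sum_non_ignore_ce_weight
    return per_token.sum() / denom
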
2 changes: 2 additions & 0 deletions src/liger_kernel/transformers/functional.py
@@ -52,6 +52,7 @@ def liger_fused_linear_cross_entropy(
     weight,
     target,
     bias=None,
+    ce_weight=None,
     ignore_index: int = -100,
     lse_square_scale: float = 0.0,
     label_smoothing: float = 0.0,
@@ -63,6 +64,7 @@
         weight,
         target,
         bias,
+        ce_weight,
         ignore_index,
         lse_square_scale,
         label_smoothing,
3 changes: 3 additions & 0 deletions src/liger_kernel/transformers/fused_linear_cross_entropy.py
@@ -10,6 +10,7 @@
 class LigerFusedLinearCrossEntropyLoss(torch.nn.Module):
     def __init__(
         self,
+        ce_weight: Optional[torch.FloatTensor] = None,
         ignore_index: int = -100,
         lse_square_scale: float = 0.0,
         label_smoothing: float = 0.0,
@@ -28,6 +29,7 @@ def __init__(
         assert (
             softcap is None or softcap > 0
         ), f"softcap must greater than 0.0 or None. Got: {softcap}"
+        self.ce_weight = ce_weight
         self.ignore_index = ignore_index
         self.lse_square_scale = lse_square_scale
         self.label_smoothing = label_smoothing
@@ -40,6 +42,7 @@ def forward(self, lin_weight, _input, target, bias=None):
             lin_weight,
             target,
             bias,
+            self.ce_weight,
             self.ignore_index,
             self.lse_square_scale,
             self.label_smoothing,
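
A minimal usage sketch of the new constructor argument (sizes, dtype, and the random weights are illustrative; assumes a CUDA device since the loss is backed by a Triton kernel). Note that the module's forward takes the lm_head weight first, i.e. (lin_weight, _input, target, bias):

import torch
from liger_kernel.transformers.fused_linear_cross_entropy import (
    LigerFusedLinearCrossEntropyLoss,
)

B_T, H, V = 8, 64, 128  # illustrative sizes
device, dtype = "cuda", torch.bfloat16

lm_head = torch.nn.Linear(H, V, bias=False, device=device, dtype=dtype)
_input = torch.randn(B_T, H, device=device, dtype=dtype, requires_grad=True)
target = torch.randint(0, V, (B_T,), device=device)

# Per-class rescaling weights, analogous to torch.nn.CrossEntropyLoss(weight=...).
ce_weight = torch.rand(V, device=device, dtype=torch.float32)

loss_fn = LigerFusedLinearCrossEntropyLoss(ce_weight=ce_weight)
loss = loss_fn(lm_head.weight, _input, target)
loss.backward()
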
23 changes: 15 additions & 8 deletions test/transformers/test_fused_linear_cross_entropy.py
@@ -41,6 +41,7 @@ def __init__(
         V: int,
         dtype: torch.dtype,
         bias: bool = False,
+        ce_weight: Optional[torch.FloatTensor] = None,
         ignore_index: int = -100,
         lse_square_scale: float = 0.0,
         label_smoothing: float = 0.0,
@@ -52,6 +53,7 @@ def __init__(
             in_features=H, out_features=V, bias=bias, dtype=dtype
         )
         self.ce_loss = CrossEntropyWithZLoss(
+            weight=ce_weight,
             ignore_index=ignore_index,
             lse_square_scale=lse_square_scale,
             label_smoothing=label_smoothing,
@@ -72,6 +74,7 @@ def __init__(
         H: int,
         V: int,
         dtype: torch.dtype,
+        ce_weight: Optional[torch.FloatTensor] = None,
         bias: bool = False,
         ignore_index: int = -100,
         lse_square_scale: float = 0.0,
@@ -84,6 +87,7 @@ def __init__(
             in_features=H, out_features=V, bias=bias, dtype=dtype
         )
         self.ce_loss = LigerFusedLinearCrossEntropyLoss(
+            ce_weight=ce_weight,
             ignore_index=ignore_index,
             lse_square_scale=lse_square_scale,
             label_smoothing=label_smoothing,
@@ -118,15 +122,11 @@ def forward(self, x, y):
 )
 @pytest.mark.parametrize("bias", [True, False])
 @pytest.mark.parametrize(
-    "label_smoothing, ignore_index, lse_square_scale, softcap",
+    "has_ce_weight, label_smoothing, ignore_index, lse_square_scale, softcap",
     [
-        (0, -100, 0, None),
-        (
-            0.1,
-            42,
-            1e-4,
-            30.0,
-        ),  # Pass non-default values once to ensure all params work along
+        (False, 0, -100, 0, None),
+        # Pass non-default values once to ensure all params work along
+        (True, 0.1, 42, 1e-4, 30.0),
     ],
 )
 def test_correctness(
@@ -137,6 +137,7 @@ def test_correctness(
     scalar,
     dtype,
     bias,
+    has_ce_weight,
     lse_square_scale,
     label_smoothing,
     ignore_index,
@@ -145,10 +146,15 @@
     atol,
     rtol,
 ):
+    if has_ce_weight:
+        ce_weight = torch.rand(V, device=device, dtype=torch.float32)
+    else:
+        ce_weight = None
     torch_lm_head_ce = TorchLMHeadCE(
         H=H,
         V=V,
         bias=bias,
+        ce_weight=ce_weight,
         lse_square_scale=lse_square_scale,
         label_smoothing=label_smoothing,
         ignore_index=ignore_index,
@@ -160,6 +166,7 @@
         H=H,
         V=V,
         bias=bias,
+        ce_weight=ce_weight,
         lse_square_scale=lse_square_scale,
         label_smoothing=label_smoothing,
         ignore_index=ignore_index,
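
For a quick standalone check outside the parametrized harness, the weighted path can also be compared against plain PyTorch cross entropy when z-loss and softcap are left at their defaults (lse_square_scale=0.0, softcap=None); a sketch under those assumptions, with illustrative sizes and tolerances:

import torch
from liger_kernel.transformers.fused_linear_cross_entropy import (
    LigerFusedLinearCrossEntropyLoss,
)

torch.manual_seed(0)
B_T, H, V = 32, 256, 512
device = "cuda"

x = torch.randn(B_T, H, device=device)
w = torch.randn(V, H, device=device)   # lm_head weight
target = torch.randint(0, V, (B_T,), device=device)
target[:4] = -100                      # exercise ignore_index
ce_weight = torch.rand(V, device=device)

# Reference: materialize the logits and use PyTorch's weighted cross entropy.
ref = torch.nn.functional.cross_entropy(
    x @ w.T, target, weight=ce_weight, ignore_index=-100, reduction="mean"
)
liger = LigerFusedLinearCrossEntropyLoss(ce_weight=ce_weight)(w, x, target)
assert torch.allclose(ref, liger, rtol=1e-4, atol=1e-4)
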
