Commit ec134fc: Clean up

Tcc0403 committed Dec 22, 2024
Parent: cbaf88f
Showing 2 changed files with 28 additions and 27 deletions.
25 changes: 13 additions & 12 deletions src/liger_kernel/ops/cross_entropy.py
@@ -33,7 +33,7 @@ def liger_cross_entropy_kernel(
     loss_stride,
     n_cols,
     n_non_ignore,
-    n_sum_non_ignore_weight,
+    sum_non_ignore_weight,
     weight_sum,
     ignore_index,
     lse_square_scale: tl.constexpr,
@@ -55,19 +55,19 @@ def liger_cross_entropy_kernel(
         Y_ptr: Pointer to target tensor.
         Y_stride (int): The stride of the target tensor.
         weight_ptr: Pointer to weight tensor.
+        weight_stride (int): The stride of the weight tensor.
         loss_ptr: Pointer to tensor to store the loss.
         z_loss_ptr: Pointer to tensor to store the z loss. No operation if RETURN_Z_LOSS is 0.
         loss_stride (int): The stride of the loss tensor.
         n_cols (int): The number of columns in the input tensor.
-        n_non_ignore (flaot): The number of non-ignored elements or the sum of non-ignored target's weights in the batch
-        weight_sum (float): The sum of weigh tensor
+        n_non_ignore (float): The number of non-ignored elements in the batch.
+        sum_non_ignore_weight (float): The sum of the non-ignored targets' weights in the batch.
+        weight_sum (float): The sum of the weight tensor.
         ignore_index (int): The index to ignore in the target.
         label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing.
         lse_square_scale (float): The scaler of (logsumexp(_input)) ^ 2 adding to the loss for the stability of training.
-        RETURN_Z_LOSS (int): The boolean value to decide whether storing z loss to z_loss_ptr or not. It must be 0 or 1.
         reduction (str): The string for the reduction to apply
         softcap (float): The upper threshold for scaling logits to the range (-softcap, +softcap).
+        RETURN_Z_LOSS (int): The boolean value to decide whether storing z loss to z_loss_ptr or not. It must be 0 or 1.
         BLOCK_SIZE (int): The block size for Triton operations.
         HAS_WEIGHT (bool): The boolean value to determine whether assigning weight to each of the classes.
         HAS_SOFTCAPPING (bool): The boolean value to determine whether applying soft-capping or not.
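The three bookkeeping scalars documented above are easy to conflate, so here is a minimal sketch of how they differ, using hypothetical values (illustrative only, not part of this commit):

```python
import torch

target = torch.tensor([1, 2, -100, 2])  # -100 plays the role of ignore_index
weight = torch.tensor([0.5, 1.0, 2.0])  # one weight per class, so V = 3

mask = target != -100
n_non_ignore = mask.sum().item()                           # 3: count of kept targets
sum_non_ignore_weight = weight[target[mask]].sum().item()  # 1.0 + 2.0 + 2.0 = 5.0
weight_sum = weight.sum().item()                           # 0.5 + 1.0 + 2.0 = 3.5
```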
@@ -200,8 +200,8 @@ def liger_cross_entropy_kernel(
             dz_loss = 2 * lse_square_scale * lse * softmax_X
         # reduction scale
         if reduction == "mean":
-            dloss_ori = dloss_ori / n_sum_non_ignore_weight
-            dloss_smooth = dloss_smooth / n_sum_non_ignore_weight
+            dloss_ori = dloss_ori / sum_non_ignore_weight
+            dloss_smooth = dloss_smooth / sum_non_ignore_weight
             dz_loss = dz_loss / n_non_ignore
         # derivative of total_loss
         X_block = dloss_ori + dloss_smooth + dz_loss
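This scaling matches the convention of torch's weighted cross entropy with reduction="mean": each row's (softmax - one-hot) gradient is multiplied by the target's class weight and divided by the sum of the selected weights, while the z-loss term (a Liger addition with no torch counterpart) is averaged over the plain count. A sketch checking the weighted part against autograd, assuming no ignored targets, label smoothing, or softcap:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(3, 5, requires_grad=True)
target = torch.tensor([1, 4, 2])
weight = torch.rand(5)

F.cross_entropy(logits, target, weight=weight, reduction="mean").backward()

# Analytic gradient: (softmax - one_hot) per row, scaled by the target's class
# weight, then divided by the sum of the selected weights, i.e. the same
# denominator the kernel calls sum_non_ignore_weight.
probs = torch.softmax(logits.detach(), dim=-1)
one_hot = F.one_hot(target, num_classes=5).float()
expected = (probs - one_hot) * weight[target].unsqueeze(1) / weight[target].sum()

assert torch.allclose(logits.grad, expected, atol=1e-6)
```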
@@ -249,7 +249,7 @@ def liger_cross_entropy_kernel(
     # Normalize the loss by the number of non-ignored elements if reduction is "mean"
     if reduction == "mean":
         if HAS_WEIGHT:
-            loss = loss / n_sum_non_ignore_weight
+            loss = loss / sum_non_ignore_weight
         else:
             loss = loss / n_non_ignore
         z_loss = z_loss / n_non_ignore
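The HAS_WEIGHT branch mirrors torch.nn.functional.cross_entropy, where reduction="mean" with a class weight divides the weighted loss sum by the sum of the non-ignored targets' weights, not by the token count. A small sketch of that identity (illustrative, not from the diff):

```python
import torch
import torch.nn.functional as F

logits = torch.randn(4, 10)
target = torch.tensor([1, 3, -100, 7])  # -100 is the ignore_index
weight = torch.rand(10)

# With reduction="none", each per-token loss is already multiplied by
# weight[target], and ignored positions come back as 0.
per_token = F.cross_entropy(
    logits, target, weight=weight, ignore_index=-100, reduction="none"
)

mask = target != -100
sum_non_ignore_weight = weight[target[mask]].sum()

mean_loss = F.cross_entropy(
    logits, target, weight=weight, ignore_index=-100, reduction="mean"
)
assert torch.allclose(mean_loss, per_token.sum() / sum_non_ignore_weight)
```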
@@ -307,20 +307,21 @@ def cross_entropy_forward(

     target_mask = target != ignore_index
     n_non_ignore = target_mask.sum().item()
-    n_sum_non_ignore_weight = n_non_ignore
-    weight_sum = weight.sum().item() if weight is not None else 0.0
+    sum_non_ignore_weight = n_non_ignore
+    weight_sum = 0.0
     if weight is not None:
         assert (
             weight.shape[0] == V
         ), f"If given, weight has to be a Tensor of size V. Got: {weight.shape}"
         assert torch.is_floating_point(
             weight
         ), f"If given, weight has to be a Tensor of floating point dtype. Got: {weight.dtype}"
-        n_sum_non_ignore_weight = (
+        sum_non_ignore_weight = (
             torch.gather(weight, dim=0, index=target.masked_select(target_mask))
             .sum()
             .item()
         )
+        weight_sum = weight.sum().item()
         if weight.stride(-1) != 1:
             weight = weight.contiguous()
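The gather over the masked targets is simply fancy indexing written with explicit dim/index arguments: it picks weight[t] for every non-ignored target t (counting repeats) and sums. A quick equivalence check with hypothetical shapes:

```python
import torch

V, ignore_index = 8, -100
weight = torch.rand(V)
target = torch.tensor([2, 5, ignore_index, 2])

target_mask = target != ignore_index
# The wrapper's formulation: gather weight[t] for every kept target t ...
a = torch.gather(weight, dim=0, index=target.masked_select(target_mask)).sum()
# ... which equals plain indexing; repeated labels count once per occurrence.
b = weight[target[target_mask]].sum()
assert torch.allclose(a, b)
```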

@@ -342,7 +342,7 @@ def cross_entropy_forward(
         loss_stride=loss_1d.stride(-1),  # always 1
         n_cols=V,
         n_non_ignore=n_non_ignore,
-        n_sum_non_ignore_weight=n_sum_non_ignore_weight,
+        sum_non_ignore_weight=sum_non_ignore_weight,
         ignore_index=ignore_index,
         weight_sum=weight_sum,
         lse_square_scale=lse_square_scale,
30 changes: 15 additions & 15 deletions test/transformers/test_cross_entropy.py
@@ -808,31 +808,31 @@ def test_correctness_with_weight_once(B, T, V, reduction, scalar, dtype, atol, r
 @pytest.mark.parametrize(
     "B, T, V",
     [
-        (2, 4096, 3200),  # llama2, mistral
+        (2, 4096, 32000),  # llama2, mistral
         # # weird shapes
-        (3, 423, 3200),
+        (3, 423, 32000),
     ],
 )
 @pytest.mark.parametrize("reduction", ["sum", "mean", "none"])
 @pytest.mark.parametrize(
     "ignore_index, lse_square_scale, label_smoothing, softcap",
     [
-        (-100, 0, 0.1, 30.0),
-        # (42, 1e-5, 0.2, 40.0),
+        (-100, 1e-4, 0.1, 30.0),
+        (42, 1e-5, 0.2, 40.0),
     ],
 )
 @pytest.mark.parametrize(
     "scalar, dtype, atol, rtol",
     [
-        # pytest.param(
-        #     1.0,
-        #     torch.bfloat16,
-        #     1e-8,
-        #     5e-2,
-        #     marks=pytest.mark.skipif(
-        #         not supports_bfloat16(), reason="bfloat16 not supported on this GPU"
-        #     ),
-        # ),
+        pytest.param(
+            1.0,
+            torch.bfloat16,
+            1e-8,
+            5e-2,
+            marks=pytest.mark.skipif(
+                not supports_bfloat16(), reason="bfloat16 not supported on this GPU"
+            ),
+        ),
         (1.0, torch.float32, 1e-8, 1e-6),
     ],
 )
@@ -942,7 +942,7 @@ def test_float32_internal():
         loss_stride=loss_bf16.stride(-1),
         n_cols=n_cols,
         n_non_ignore=n_non_ignore,
-        n_sum_non_ignore_weight=n_non_ignore,  # not used
+        sum_non_ignore_weight=n_non_ignore,  # not used
         weight_sum=0.0,  # not used
         ignore_index=ignore_index,
         lse_square_scale=lse_square_scale,
@@ -970,7 +970,7 @@ def test_float32_internal():
         loss_stride=loss_fp32.stride(-1),
         n_cols=n_cols,
         n_non_ignore=n_non_ignore,
-        n_sum_non_ignore_weight=n_non_ignore,  # not used
+        sum_non_ignore_weight=n_non_ignore,  # not used
         weight_sum=n_non_ignore,  # not used
         ignore_index=ignore_index,
         lse_square_scale=lse_square_scale,
