Commit

Fix LigerCrossEntropyLoss Reduction Behavior for "None" Mode (linkedin#435)

## Summary
Closes linkedin#421 

This pull request addresses an issue in the `cross_entropy_forward`
function where the `reduction="none"` mode did not behave as expected.

Previously, the function always returned a single scalar value, even
when `reduction="none"` was specified. This update ensures that when
`reduction="none"` is used, the function returns the unreduced loss
tensor (`loss_1d`) directly instead of summing it.
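
For reference, this matches the semantics of PyTorch's built-in cross entropy; a minimal sketch (not part of this PR) of the two reduction modes:

```python
import torch
import torch.nn.functional as F

logits = torch.randn(4, 10)
target = torch.randint(0, 10, (4,))

# reduction="sum" collapses the per-example losses into a single scalar.
print(F.cross_entropy(logits, target, reduction="sum").shape)   # torch.Size([])
# reduction="none" keeps one loss value per example.
print(F.cross_entropy(logits, target, reduction="none").shape)  # torch.Size([4])
```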

### Changes Made:

- Added a condition to handle `reduction="none"`, so the function
returns `loss_1d` directly.
- Updated the computation of `z_loss` to respect `reduction="none"`.
- Added a test for the `reduction="none"` case.


### Why do we pass `gradient` to `output.backward()`?

#### Background on Gradients in PyTorch

- **Scalar Outputs**: When a tensor is a scalar (a single number),
PyTorch can compute gradients automatically by assuming the scalar has
an implicit gradient of 1.0.
- **Non-Scalar Outputs**: For tensors that are not scalars, gradients
must be provided explicitly because PyTorch cannot infer the shape or
distribution of gradients. Without an explicit gradient, it raises the
error: "grad can be implicitly created only for scalar outputs." (See
the sketch below.)
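
A minimal, standalone sketch (not from this PR) of the two cases:

```python
import torch

x = torch.randn(4, requires_grad=True)

# Scalar output: backward() works without arguments; PyTorch assumes a gradient of 1.0.
scalar_loss = (x * 2).sum()
scalar_loss.backward()

# Non-scalar output: backward() needs an explicit gradient of the same shape.
x.grad = None
vector_loss = x * 2
# vector_loss.backward()  # RuntimeError: grad can be implicitly created only for scalar outputs
vector_loss.backward(gradient=torch.ones_like(vector_loss))
print(x.grad)  # tensor([2., 2., 2., 2.])
```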

#### Why `reduction="none"` Needs Explicit Gradients

When `reduction="none"`, the loss function does not reduce the
per-example loss values into a single scalar. Instead, it outputs a
vector of losses, with one value per example in the batch. This means
that the loss tensor has multiple values, and PyTorch cannot assume what
the gradient for each of these values should be unless explicitly
provided.

#### The Fix
By passing `gradient=torch.ones_like(loss)` to `backward()`:

- **Gradient Tensor**: The `torch.ones_like(loss)` serves as the
gradient tensor. It specifies that each element in the loss tensor
contributes equally to the gradients during backpropagation.
- **Shape Match**: The gradient tensor's shape matches the loss tensor's
shape, fulfilling PyTorch's requirement for non-scalar outputs during
`backward()`. (See the sketch below.)
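
Putting it together, the pattern the updated tests rely on looks roughly like this (a sketch; the import path and `reduction` argument follow the Liger API as used in the diff below, and a CUDA device is assumed since the kernel is Triton-based):

```python
import torch
from liger_kernel.transformers import LigerCrossEntropyLoss

B_T, V = 8, 32000
logits = torch.randn(B_T, V, device="cuda", requires_grad=True)
target = torch.randint(0, V, (B_T,), device="cuda")

# With reduction="none" the loss is a vector with one entry per token...
loss = LigerCrossEntropyLoss(reduction="none")(logits, target)
# ...so backward() needs an explicit gradient of matching shape.
loss.backward(gradient=torch.ones_like(loss))
```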

## Testing Done
make test

`pytest /home/jobuser/Liger-Kernel/test/transformers/test_cross_entropy.py` shows:

```
=================================== 93 passed, 1 warning in 13.18s ===================================
```


- Hardware Type: NVIDIA A100-SXM4-80GB 
- [x] run `make test` to ensure correctness
- [x] run `make checkstyle` to ensure code style
- [x] run `make test-convergence` to ensure convergence
hebiao064 authored Dec 10, 2024
1 parent fdba493 commit d790b64
Showing 2 changed files with 17 additions and 16 deletions.
9 changes: 5 additions & 4 deletions src/liger_kernel/ops/cross_entropy.py

```diff
@@ -285,11 +285,12 @@ def cross_entropy_forward(
         num_warps=32 if not is_hip() else 16,
     )
 
-    loss = torch.sum(loss_1d)
-    if return_z_loss == _TRUE.value:
-        z_loss = torch.sum(z_loss_1d)
+    if reduction == "none":
+        loss = loss_1d
+        z_loss = z_loss_1d if return_z_loss == _TRUE.value else None
     else:
-        z_loss = None
+        loss = torch.sum(loss_1d)
+        z_loss = torch.sum(z_loss_1d) if return_z_loss == _TRUE.value else None
 
     return loss, z_loss, _input
```
24 changes: 12 additions & 12 deletions test/transformers/test_cross_entropy.py

```diff
@@ -87,8 +87,8 @@ def _test_correctness_once(target_ce, B, T, V, reduction, scalar, dtype, atol, r
     output2 = target_ce(_input2, target)
     assert torch.allclose(output, output2, atol=atol, rtol=rtol)
 
-    output.backward()
-    output2.backward()
+    output.backward(gradient=torch.ones_like(output))
+    output2.backward(gradient=torch.ones_like(output))
     assert torch.allclose(_input.grad, _input2.grad, atol=atol, rtol=rtol)
 
 
@@ -118,8 +118,8 @@ def _test_correctness_with_ignore_index_once(
 
     assert torch.allclose(output, output2, atol=atol, rtol=rtol)
 
-    output.backward()
-    output2.backward()
+    output.backward(gradient=torch.ones_like(output))
+    output2.backward(gradient=torch.ones_like(output))
     assert torch.allclose(_input.grad, _input2.grad, atol=atol, rtol=rtol)
 
 
@@ -199,8 +199,8 @@ def _test_correctness_with_softcap_once(
 
     assert torch.allclose(output, output2, atol=atol, rtol=rtol)
 
-    output.backward()
-    output2.backward()
+    output.backward(gradient=torch.ones_like(output))
+    output2.backward(gradient=torch.ones_like(output))
 
     assert torch.allclose(_input.grad, _input2.grad, atol=atol, rtol=rtol)
 
@@ -325,8 +325,8 @@ def _test_correctness_not_last_layer_once(
     loss1 = output * 3
     loss2 = output2 * 3
 
-    loss1.backward()
-    loss2.backward()
+    loss1.backward(gradient=torch.ones_like(output))
+    loss2.backward(gradient=torch.ones_like(output))
     assert torch.allclose(_input.grad, _input2.grad, atol=atol, rtol=rtol)
 
 
@@ -384,7 +384,7 @@ def _test_correctness_functional(
         (3, 423, 32000),  # weird shapes
     ],
 )
-@pytest.mark.parametrize("reduction", ["sum", "mean"])
+@pytest.mark.parametrize("reduction", ["sum", "mean", "none"])
 @pytest.mark.parametrize(
     "scalar, dtype, atol, rtol",
     [
@@ -432,7 +432,7 @@ def test_correctness_functional(B, T, V, scalar, dtype, atol, rtol):
         (3, 423, 32000, -123),
     ],
 )
-@pytest.mark.parametrize("reduction", ["sum", "mean"])
+@pytest.mark.parametrize("reduction", ["sum", "mean", "none"])
 @pytest.mark.parametrize(
     "scalar, dtype, atol, rtol",
     [
@@ -532,7 +532,7 @@ def test_correctness_with_label_smoothing_with_ignore_index_once(
         (3, 423, 32000, 30.0),
     ],
 )
-@pytest.mark.parametrize("reduction", ["sum", "mean"])
+@pytest.mark.parametrize("reduction", ["sum", "mean", "none"])
 @pytest.mark.parametrize(
     "scalar, dtype, atol, rtol",
     [
@@ -700,7 +700,7 @@ def test_correctness_with_z_loss_with_other_params_once(
         (3, 423, 32000),
     ],
 )
-@pytest.mark.parametrize("reduction", ["sum", "mean"])
+@pytest.mark.parametrize("reduction", ["sum", "mean", "none"])
 @pytest.mark.parametrize(
     "scalar, dtype, atol, rtol",
     [
```
