# Fix Dtype Mismatch in torch.addmm within ops/fused_linear_cross_entropy.py in AMP training. (#502)

## Summary

This PR addresses a `dtype` mismatch error that I encountered while
using PyTorch AMP to train a Llama3 model. After reviewing previous
discussions, such as closed issue #305 and PR #318, running my own
tests, and analyzing the problem in full, I found that a `dtype`
mismatch can still occur during the FLCE (fused linear cross entropy)
computation when the bias is `None`. The detailed observations and
analysis can be found in issue #501.
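
For orientation, here is a minimal, self-contained sketch of the failure mode (the shapes are illustrative placeholders, not the kernel's real chunk sizes): under AMP the matmul produces a bf16 `logits_chunk` while `grad_weight` and `_input_chunk` stay fp32, and `torch.addmm` rejects the mixed dtypes unless `logits_chunk` is cast back, which is what this PR does.

```python
import torch

# Illustrative sizes only (V=8, H=16, chunk_size=4), not the kernel's real shapes.
grad_weight = torch.zeros(8, 16, dtype=torch.float32)   # V x H, accumulated in fp32
_input_chunk = torch.randn(4, 16, dtype=torch.float32)  # chunk_size x H, fp32
# Under autocast, `_input_chunk @ weight.t()` runs in bf16; with bias=None nothing
# promotes the result back to fp32, so logits_chunk stays bf16.
logits_chunk = torch.randn(4, 8, dtype=torch.bfloat16)  # chunk_size x V, bf16

try:
    torch.addmm(input=grad_weight, mat1=logits_chunk.t(), mat2=_input_chunk, out=grad_weight)
except RuntimeError as err:
    print(f"dtype mismatch: {err}")

# The fix: cast logits_chunk to the input dtype before accumulating into grad_weight.
torch.addmm(
    input=grad_weight,
    mat1=logits_chunk.t().to(_input_chunk.dtype),
    mat2=_input_chunk,
    out=grad_weight,
)
```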

This PR aims to:

1. Enhance the test cases to reproduce the mismatch error.
2. Resolve the bug by ensuring the correct `dtype` is used, without
affecting the behavior in other scenarios.

## Testing Done
- Hardware Type: RTX-4090-24G
- [x] run `make test` to ensure correctness
- [x] run `make checkstyle` to ensure code style
- [x] run `make test-convergence` to ensure convergence

```log
$ make test
python -m pytest --disable-warnings test/ --ignore=test/convergence
========================= test session starts =========================
platform linux -- Python 3.10.12, pytest-8.3.4, pluggy-1.5.0
rootdir: /mnt/sda1/latest_liaw/open-source/Liger-Kernel
configfile: pyproject.toml
plugins: xdist-3.6.1, rerunfailures-15.0
collected 965 items

test/transformers/test_transformers.py::test_import_from_root PASSED [ 99%]
test/triton/test_triton_monkey_patch.py::test_import_from_root PASSED [ 99%]
test/triton/test_triton_monkey_patch.py::test_import_custom_cache_manager PASSED [100%]

========================= 750 passed, 215 skipped, 41 warnings in 32.40s =========================

$ make test-convergence
HF_DATASETS_OFFLINE=1 python -m pytest --disable-warnings test/convergence/test_mini_models.py
==================================================== test session starts =====================================================
platform linux -- Python 3.10.12, pytest-8.3.4, pluggy-1.5.0
rootdir: /mnt/sda1/latest_liaw/open-source/Liger-Kernel
configfile: pyproject.toml
plugins: xdist-3.6.1, rerunfailures-15.0
collecting ... 
---------------------------------------------------- live log collection -----------------------------------------------------
INFO     datasets:config.py:54 PyTorch version 2.5.1 available.
collected 17 items

test/convergence/test_mini_models_with_logits.py::test_mini_model[mini_gemma1.1-32-0.0001-dtype14-1e-08-0.0001-0.005-1e-05-0.005-1e-05] PASSED [ 88%]
test/convergence/test_mini_models_with_logits.py::test_mini_model[mini_gemma1.1-32-0.0001-dtype15-0.001-0.01-0.1-0.01-0.01-0.01] PASSED [ 94%]
test/convergence/test_mini_models_with_logits.py::test_mini_model[mini_gemma2-32-0.0001-dtype16-1e-08-0.0001-0.005-1e-05-0.005-1e-05] PASSED [100%]

========================= 17 passed, 1 warning in 60.39s (0:01:00) =========================


$ make checkstyle
ruff check . --fix; ruff_check_status=$?; \
ruff format .; ruff_format_status=$?; \
if [ $ruff_check_status -ne 0 ] || [ $ruff_format_status -ne 0 ]; then \
        exit 1; \
fi
All checks passed!
124 files left unchanged
```
DandinPower authored Dec 29, 2024
1 parent 9875488 commit 174b191
Showing 2 changed files with 12 additions and 7 deletions.
src/liger_kernel/ops/fused_linear_cross_entropy.py (4 additions, 1 deletion)
```diff
@@ -59,6 +59,7 @@ def fused_linear_cross_entropy_forward(
         logits_chunk = _input_chunk @ weight.t()  # chunk_size x V
         if bias is not None:
             logits_chunk = logits_chunk + bias
+
         target_chunk = target[start_idx:end_idx]  # chunk_size,
 
         n_rows = logits_chunk.shape[0]
@@ -112,7 +113,9 @@ def fused_linear_cross_entropy_forward(
         if grad_weight is not None:
             torch.addmm(
                 input=grad_weight,
-                mat1=logits_chunk.t(),
+                mat1=logits_chunk.t().to(
+                    _input_chunk.dtype
+                ),  # In an autocast scenario without bias, differing logits_chunk data types will cause an addmm operation error.
                 mat2=_input_chunk,
                 out=grad_weight,
                 alpha=alpha,
```
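
A quick, hedged illustration of why only the bias-less path is affected (placeholder shapes; CPU autocast is used just to keep the snippet runnable anywhere): an fp32 bias promotes the bf16 matmul result back to fp32 through normal type promotion, whereas without a bias `logits_chunk` keeps the low-precision dtype and later clashes with the fp32 tensors inside `torch.addmm`.

```python
import torch

x = torch.randn(4, 16)  # placeholder activations (fp32)
w = torch.randn(8, 16)  # placeholder lm_head weight (fp32)
b = torch.zeros(8)      # placeholder bias (fp32)

with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    logits = x @ w.t()
    print(logits.dtype)        # torch.bfloat16: autocast runs the matmul in bf16
    print((logits + b).dtype)  # torch.float32: the fp32 bias promotes the result
```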
test/transformers/test_fused_linear_cross_entropy.py (8 additions, 6 deletions)
```diff
@@ -260,26 +260,28 @@ def test_correctness_functional(B, T, H, V, scalar, dtype, bias, atol, rtol):
     ],
 )
 @pytest.mark.parametrize(
-    "cast_dtype, atol, rtol",
+    "bias, cast_dtype, atol, rtol",
     [
-        (torch.bfloat16, 5e-3, 5e-2),
-        (torch.float16, 5e-3, 5e-2),
+        (True, torch.bfloat16, 5e-3, 5e-2),
+        (True, torch.float16, 5e-3, 5e-2),
+        (False, torch.bfloat16, 5e-3, 5e-2),
+        (False, torch.float16, 5e-3, 5e-2),
     ],
 )
-def test_amp(B, T, H, V, cast_dtype, atol, rtol):
+def test_amp(B, T, H, V, bias, cast_dtype, atol, rtol):
     dtype = torch.float32
     torch_lm_head_ce = TorchLMHeadCE(
         H=H,
         V=V,
-        bias=True,
+        bias=bias,
         label_smoothing=0.0,
         reduction="mean",
         dtype=dtype,
     ).to(device)
     liger_lm_head_ce = LigerLMHeadCE(
         H=H,
         V=V,
-        bias=True,
+        bias=bias,
         label_smoothing=0.0,
         reduction="mean",
         dtype=dtype,
```
