From 174b191cfa3c97e60b45874d94794ed4846a826c Mon Sep 17 00:00:00 2001
From: "LIAW,YONG-CHENG"
Date: Sun, 29 Dec 2024 11:53:50 +0800
Subject: [PATCH] Fix dtype mismatch in torch.addmm within
 ops/fused_linear_cross_entropy.py during AMP training (#502)

## Summary

This PR addresses a `dtype` mismatch error that I encountered while using PyTorch AMP to train a Llama3 model. After reviewing previous discussions (such as closed issue #305 and PR #318), running my own tests, and analyzing the problem, I found that a `dtype` mismatch can still occur when the bias is `None` during the FLCE computation. A detailed observation and analysis can be found in issue #501; a minimal repro sketch is also included after the test logs below.

This PR aims to:

1. Enhance the test cases to reproduce the mismatch error.
2. Resolve the bug by ensuring the correct `dtype` is used, without affecting behavior in other scenarios.

## Testing Done

- Hardware Type: RTX-4090-24G
- [x] run `make test` to ensure correctness
- [x] run `make checkstyle` to ensure code style
- [x] run `make test-convergence` to ensure convergence

```log
$ make test
python -m pytest --disable-warnings test/ --ignore=test/convergence
========================= test session starts =========================
platform linux -- Python 3.10.12, pytest-8.3.4, pluggy-1.5.0
rootdir: /mnt/sda1/latest_liaw/open-source/Liger-Kernel
configfile: pyproject.toml
plugins: xdist-3.6.1, rerunfailures-15.0
collected 965 items

test/transformers/test_transformers.py::test_import_from_root PASSED [ 99%]
test/triton/test_triton_monkey_patch.py::test_import_from_root PASSED [ 99%]
test/triton/test_triton_monkey_patch.py::test_import_custom_cache_manager PASSED [100%]
========================= 750 passed, 215 skipped, 41 warnings in 32.40s =========================

$ make test-convergence
HF_DATASETS_OFFLINE=1 python -m pytest --disable-warnings test/convergence/test_mini_models.py
==================================================== test session starts =====================================================
platform linux -- Python 3.10.12, pytest-8.3.4, pluggy-1.5.0
rootdir: /mnt/sda1/latest_liaw/open-source/Liger-Kernel
configfile: pyproject.toml
plugins: xdist-3.6.1, rerunfailures-15.0
collecting ...
---------------------------------------------------- live log collection -----------------------------------------------------
INFO     datasets:config.py:54 PyTorch version 2.5.1 available.
collected 17 items

test/convergence/test_mini_models_with_logits.py::test_mini_model[mini_gemma1.1-32-0.0001-dtype14-1e-08-0.0001-0.005-1e-05-0.005-1e-05] PASSED [ 88%]
test/convergence/test_mini_models_with_logits.py::test_mini_model[mini_gemma1.1-32-0.0001-dtype15-0.001-0.01-0.1-0.01-0.01-0.01] PASSED [ 94%]
test/convergence/test_mini_models_with_logits.py::test_mini_model[mini_gemma2-32-0.0001-dtype16-1e-08-0.0001-0.005-1e-05-0.005-1e-05] PASSED [100%]
========================= 17 passed, 1 warning in 60.39s (0:01:00) =========================

$ make checkstyle
ruff check . --fix; ruff_check_status=$?; \
ruff format .; ruff_format_status=$?; \
if [ $ruff_check_status -ne 0 ] || [ $ruff_format_status -ne 0 ]; then \
	exit 1; \
fi
All checks passed!
124 files left unchanged
```
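## Background: Why the Mismatch Occurs

For reference, here is a minimal sketch of the failure mode (not part of the patch; the shapes and variable values are illustrative). Under autocast, the matmul that produces `logits_chunk` returns the autocast dtype (e.g. `bfloat16`) even though `weight` and `_input_chunk` are `float32`. When a bias is present, adding the `float32` bias promotes `logits_chunk` back to `float32`, which masks the problem; without a bias it stays `bfloat16`. Because `torch.addmm` is called with an explicit `out=` tensor, it does not go through autocasting, so the mixed-dtype operands reach the kernel directly and raise a `RuntimeError`:

```python
import torch

# Illustrative standalone repro (requires a CUDA device; shapes are arbitrary).
device = "cuda"
H, V, N = 4, 8, 2  # hidden size, vocab size, rows per chunk
weight = torch.randn(V, H, dtype=torch.float32, device=device)
_input_chunk = torch.randn(N, H, dtype=torch.float32, device=device)
grad_weight = torch.zeros(V, H, dtype=torch.float32, device=device)

with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
    logits_chunk = _input_chunk @ weight.t()  # autocast matmul -> bfloat16
    # No bias add here, so logits_chunk is never promoted back to float32.
    try:
        torch.addmm(
            input=grad_weight,      # float32
            mat1=logits_chunk.t(),  # bfloat16 -> mismatch with mat2/out
            mat2=_input_chunk,      # float32
            out=grad_weight,        # explicit out= bypasses autocasting
        )
    except RuntimeError as err:
        print("reproduced:", err)
```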
---
 src/liger_kernel/ops/fused_linear_cross_entropy.py |  5 ++++-
 .../test_fused_linear_cross_entropy.py             | 14 ++++++++------
 2 files changed, 12 insertions(+), 7 deletions(-)

diff --git a/src/liger_kernel/ops/fused_linear_cross_entropy.py b/src/liger_kernel/ops/fused_linear_cross_entropy.py
index ba6fbef81..83fcb108a 100644
--- a/src/liger_kernel/ops/fused_linear_cross_entropy.py
+++ b/src/liger_kernel/ops/fused_linear_cross_entropy.py
@@ -59,6 +59,7 @@ def fused_linear_cross_entropy_forward(
         logits_chunk = _input_chunk @ weight.t()  # chunk_size x V
         if bias is not None:
             logits_chunk = logits_chunk + bias
+
         target_chunk = target[start_idx:end_idx]  # chunk_size,

         n_rows = logits_chunk.shape[0]
@@ -112,7 +113,9 @@ def fused_linear_cross_entropy_forward(
         if grad_weight is not None:
             torch.addmm(
                 input=grad_weight,
-                mat1=logits_chunk.t(),
+                mat1=logits_chunk.t().to(
+                    _input_chunk.dtype
+                ),  # In an autocast scenario without bias, differing logits_chunk data types will cause an addmm operation error.
                 mat2=_input_chunk,
                 out=grad_weight,
                 alpha=alpha,
diff --git a/test/transformers/test_fused_linear_cross_entropy.py b/test/transformers/test_fused_linear_cross_entropy.py
index 9f0d5d5e1..a4d6ba2ff 100644
--- a/test/transformers/test_fused_linear_cross_entropy.py
+++ b/test/transformers/test_fused_linear_cross_entropy.py
@@ -260,18 +260,20 @@ def test_correctness_functional(B, T, H, V, scalar, dtype, bias, atol, rtol):
     ],
 )
 @pytest.mark.parametrize(
-    "cast_dtype, atol, rtol",
+    "bias, cast_dtype, atol, rtol",
     [
-        (torch.bfloat16, 5e-3, 5e-2),
-        (torch.float16, 5e-3, 5e-2),
+        (True, torch.bfloat16, 5e-3, 5e-2),
+        (True, torch.float16, 5e-3, 5e-2),
+        (False, torch.bfloat16, 5e-3, 5e-2),
+        (False, torch.float16, 5e-3, 5e-2),
     ],
 )
-def test_amp(B, T, H, V, cast_dtype, atol, rtol):
+def test_amp(B, T, H, V, bias, cast_dtype, atol, rtol):
     dtype = torch.float32
     torch_lm_head_ce = TorchLMHeadCE(
         H=H,
         V=V,
-        bias=True,
+        bias=bias,
         label_smoothing=0.0,
         reduction="mean",
         dtype=dtype,
@@ -279,7 +281,7 @@ def test_amp(B, T, H, V, cast_dtype, atol, rtol):
     liger_lm_head_ce = LigerLMHeadCE(
         H=H,
         V=V,
-        bias=True,
+        bias=bias,
         label_smoothing=0.0,
         reduction="mean",
         dtype=dtype,
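For completeness, here is a self-contained sketch of the casting pattern the patch applies (names mirror the diff; shapes and values are illustrative, and the snippet runs on CPU). Casting `mat1` to `_input_chunk.dtype` makes every `addmm` operand consistent, and `.to()` is a no-op when the dtypes already match, so non-autocast paths behave exactly as before:

```python
import torch

# logits_chunk is created in bfloat16 to stand in for the output of a
# matmul under autocast; grad_weight and _input_chunk stay float32.
grad_weight = torch.zeros(8, 4, dtype=torch.float32)
_input_chunk = torch.randn(2, 4, dtype=torch.float32)
logits_chunk = torch.randn(2, 8, dtype=torch.bfloat16)

torch.addmm(
    input=grad_weight,
    mat1=logits_chunk.t().to(_input_chunk.dtype),  # the fix: bfloat16 -> float32
    mat2=_input_chunk,
    out=grad_weight,
    alpha=1.0,
    beta=1.0,
)
print(grad_weight.dtype)  # torch.float32
```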