Fix Dtype Mismatch in torch.addmm within ops/fused_linear_cross_entropy.py in AMP training. (#502)

## Summary

This PR addresses a `dtype` mismatch error that I encountered while using PyTorch AMP to train a Llama3 model. After reviewing previous discussions such as closed issue #305 and PR #318, running my own tests, and analyzing the problem in full, I found that a `dtype` mismatch can still occur during FLCE computation when the bias is `None` (a minimal sketch of the failure mode is included after the test logs below). The detailed observation and analysis can be found in issue #501.

This PR aims to:

1. Extend the test cases to reproduce the mismatch error.
2. Fix the bug by ensuring the correct `dtype` is used, without affecting the behavior in other scenarios.

## Testing Done

- Hardware Type: RTX-4090-24G
- [x] run `make test` to ensure correctness
- [x] run `make checkstyle` to ensure code style
- [x] run `make test-convergence` to ensure convergence

```log
$ make test
python -m pytest --disable-warnings test/ --ignore=test/convergence
========================= test session starts =========================
platform linux -- Python 3.10.12, pytest-8.3.4, pluggy-1.5.0
rootdir: /mnt/sda1/latest_liaw/open-source/Liger-Kernel
configfile: pyproject.toml
plugins: xdist-3.6.1, rerunfailures-15.0
collected 965 items
test/transformers/test_transformers.py::test_import_from_root PASSED [ 99%]
test/triton/test_triton_monkey_patch.py::test_import_from_root PASSED [ 99%]
test/triton/test_triton_monkey_patch.py::test_import_custom_cache_manager PASSED [100%]
========================= 750 passed, 215 skipped, 41 warnings in 32.40s =========================

$ make test-convergence
HF_DATASETS_OFFLINE=1 python -m pytest --disable-warnings test/convergence/test_mini_models.py
==================================================== test session starts =====================================================
platform linux -- Python 3.10.12, pytest-8.3.4, pluggy-1.5.0
rootdir: /mnt/sda1/latest_liaw/open-source/Liger-Kernel
configfile: pyproject.toml
plugins: xdist-3.6.1, rerunfailures-15.0
collecting ...
---------------------------------------------------- live log collection -----------------------------------------------------
INFO     datasets:config.py:54 PyTorch version 2.5.1 available.
collected 17 items
test/convergence/test_mini_models_with_logits.py::test_mini_model[mini_gemma1.1-32-0.0001-dtype14-1e-08-0.0001-0.005-1e-05-0.005-1e-05] PASSED [ 88%]
test/convergence/test_mini_models_with_logits.py::test_mini_model[mini_gemma1.1-32-0.0001-dtype15-0.001-0.01-0.1-0.01-0.01-0.01] PASSED [ 94%]
test/convergence/test_mini_models_with_logits.py::test_mini_model[mini_gemma2-32-0.0001-dtype16-1e-08-0.0001-0.005-1e-05-0.005-1e-05] PASSED [100%]
========================= 17 passed, 1 warning in 60.39s (0:01:00) =========================

$ make checkstyle
ruff check . --fix; ruff_check_status=$?; \
ruff format .; ruff_format_status=$?; \
if [ $ruff_check_status -ne 0 ] || [ $ruff_format_status -ne 0 ]; then \
    exit 1; \
fi
All checks passed!
124 files left unchanged
```
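For reference, the failure mode is easy to reproduce outside the kernel. The snippet below is a minimal sketch, not the actual code in `ops/fused_linear_cross_entropy.py`: the tensor names and shapes are illustrative, and the explicit `.to()` cast is one generic way to align dtypes before `torch.addmm`, not necessarily the exact change made in this PR.

```python
# Minimal sketch (illustrative names/shapes, not the Liger-Kernel code):
# under AMP, chunked activations and their gradients can be half precision
# while the accumulated grad_weight buffer stays fp32, and torch.addmm
# refuses to mix dtypes across its operands.
import torch

hidden, vocab, chunk = 16, 8, 4

grad_weight = torch.zeros(vocab, hidden, dtype=torch.float32)         # fp32 accumulator
grad_logits_chunk = torch.randn(chunk, vocab, dtype=torch.bfloat16)   # AMP-cast gradient
input_chunk = torch.randn(chunk, hidden, dtype=torch.bfloat16)        # AMP-cast activations

try:
    # Mixed-dtype addmm: raises the dtype mismatch error seen in AMP training.
    torch.addmm(grad_weight, grad_logits_chunk.t(), input_chunk, out=grad_weight)
except RuntimeError as err:
    print("dtype mismatch:", err)

# Aligning dtypes explicitly before accumulating avoids the error.
torch.addmm(
    grad_weight,
    grad_logits_chunk.t().to(grad_weight.dtype),
    input_chunk.to(grad_weight.dtype),
    out=grad_weight,
)
print(grad_weight.dtype)  # torch.float32
```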