Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[AMD] [ROCm] Pick
num_warps
based on platform (#326)
## Summary <!--- This is a required section; please describe the main purpose of this proposed code change. ---> This is a PR to enable the kernel to run on AMD GPUs through the initial changes to the `num_warps`. This change is proposed by @Edenzzzz and @DocShotgun in this issue #266 ## Details <!--- This is an optional section; is there anything specific that reviewers should be aware of? ---> I have updated the `transformers` version from `4.44.0` to `4.46.0` requirement and all unit tests passed on A100 and MI300X. ## Testing Done <!--- This is a required section; please describe how this change was tested. ---> <!-- Replace BLANK with your device type. For example, A100-80G-PCIe Complete the following tasks before sending your PR, and replace `[ ]` with `[x]` to indicate you have done them. --> - Hardware Type: AMD Instinct MI300X - [x] run `make test` to ensure correctness - There are some test failed due to numerical precision issue. Passed by relaxing the condition by 1 order of magnitude (following the advice in the Liger-Kernel technical report https://arxiv.org/pdf/[2410.10989](https://arxiv.org/pdf/2410.10989) **Footnote 12:** _Note that in practice, the tolerance may need further relaxation in some cases by one or two orders of magnitude, even for exact kernels. We use convergence tests to ensure exactness in cases where the tolerance for correctness needs to be loose._ ) - The test that the tolerance are relaxed involves `kl_div` and `jsd` in `float32` tests - The relax conditions are described by the following code snippet ``` _DTYPE_PARAMS = ( "dtype, atol, rtol", [ pytest.param( torch.bfloat16, 1e-8, 5e-2, marks=pytest.mark.skipif( not supports_bfloat16(), reason="bfloat16 not supported on this GPU" ), ), (torch.float32, 1e-8 if not is_hip() else 1e-7, 1e-6), (torch.float16, 1e-3, 1e-3), ], ) ``` - To pass the test, the triton must not be installed from source, it must be installed through pypi `pip install triton==3.0.0`. This issue will be tracked with an issue at triton triton-lang/triton#5013 . - ~~Something is weird as well, if I just run the failed test `test/transformers/test_cross_entropy.py::test_correctness_with_ignore_index[10.0-dtype5-1e-08-1e-06-sum-2-4096-32000--100]`, the test passed. By running `pytest test/transformers/test_cross_entropy.py::test_correctness_with_ignore_index[10.0-dtype5-1e-08-1e-06-sum-2-4096-32000--100]`. However it will failed if there are other tests running before this test.~~ - [x] run `make checkstyle` to ensure code style - [x] run `make test-convergence` to ensure convergence <details> <summary> <s>Failure Test Logs (Click to expand/collapse) </s> </summary> ```bash ============================================================= FAILURES ============================================================= ________________________ test_correctness_with_ignore_index[10.0-dtype5-1e-08-1e-06-sum-2-4096-32000--100] _________________________ B = 2, T = 4096, V = 32000, ignore_index = -100, reduction = 'sum', scalar = 10.0, dtype = torch.float32, atol = 1e-08, rtol = 1e-06 @pytest.mark.parametrize( "B, T, V, ignore_index", [ (2, 4096, 32000, -100), # llama2, mistral (2, 4096, 32000, 2), # llama2, mistral (1, 4096, 128256, -300), # llama3 # weird shapes (3, 423, 32000, -123), ], ) @pytest.mark.parametrize("reduction", ["sum", "mean"]) @pytest.mark.parametrize( "scalar, dtype, atol, rtol", [ pytest.param( 0.1, torch.bfloat16, 1e-8, 5e-2, marks=pytest.mark.skipif( not supports_bfloat16(), reason="bfloat16 not supported on this GPU" ), ), pytest.param( 1.0, torch.bfloat16, 1e-8, 5e-2, marks=pytest.mark.skipif( not supports_bfloat16(), reason="bfloat16 not supported on this GPU" ), ), pytest.param( 10.0, torch.bfloat16, 1e-8, 5e-2, marks=pytest.mark.skipif( not supports_bfloat16(), reason="bfloat16 not supported on this GPU" ), ), (0.1, torch.float32, 1e-8, 1e-6), (1.0, torch.float32, 1e-8, 1e-6), (10.0, torch.float32, 1e-8, 1e-6), ], ) @pytest.mark.skipif( torch.cuda.get_device_properties(0).total_memory < 16 * 1000 * 1000 * 1000, reason="Needs 16GB+ GPU memory.", ) def test_correctness_with_ignore_index( B, T, V, ignore_index, reduction, scalar, dtype, atol, rtol ): liger_ce = LigerCrossEntropyLoss(ignore_index=ignore_index, reduction=reduction) > _test_correctness_with_ignore_index_once( liger_ce, B, T, V, ignore_index, reduction, scalar, dtype, atol, rtol ) test/transformers/test_cross_entropy.py:302: _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ target_ce = LigerCrossEntropyLoss(), B = 2, T = 4096, V = 32000, ignore_index = -100, reduction = 'sum', scalar = 10.0 dtype = torch.float32, atol = 1e-08, rtol = 1e-06 def _test_correctness_with_ignore_index_once( target_ce, B, T, V, ignore_index, reduction, scalar, dtype, atol, rtol ): torch_ce = CrossEntropyLoss(ignore_index=ignore_index, reduction=reduction) _tensor = torch.randn(B * T, V, device="cuda", dtype=dtype) * scalar _input = _tensor.detach().clone().requires_grad_(True) _input2 = _tensor.detach().clone().requires_grad_(True) target = torch.randint(0, V, (B * T,), device="cuda", dtype=torch.long) # Assign some random number of elements as ignore_index num_elements_to_assign = torch.randint( 1, B * T // 2, (1,) ).item() # Random number of elements to set to ignore_index indices_to_assign = torch.randperm(B * T)[ :num_elements_to_assign ] # Randomly select indices target[indices_to_assign] = ignore_index output = torch_ce(_input, target) output2 = target_ce(_input2, target) assert torch.allclose(output, output2, atol=atol, rtol=rtol) output.backward() output2.backward() > assert torch.allclose(_input.grad, _input2.grad, atol=atol, rtol=rtol) E AssertionError: assert False E + where False = <built-in method allclose of type object at 0x7035c99e82c0>(tensor([[4.0225e-16, 3.9353e-17, 6.6616e-19, ..., 1.3759e-13, 7.6381e-10,\n 4.4185e-23],\n [2.9569e-12, 3.8580e-19, 5.3756e-16, ..., 6.0166e-23, 1.4681e-17,\n 5.1994e-20],\n [4.7900e-26, 1.0599e-04, 7.0237e-19, ..., 1.1461e-20, 1.0415e-10,\n 1.0237e-19],\n ...,\n [6.9540e-17, 3.4471e-22, 2.7309e-14, ..., 2.5999e-26, 2.5635e-19,\n 7.0793e-16],\n [6.3721e-23, 1.2054e-13, 1.8638e-20, ..., 1.2807e-23, 5.5705e-16,\n 2.3085e-13],\n [1.9623e-20, 2.4720e-11, 1.8808e-15, ..., 3.5100e-20, 3.6195e-15,\n 1.5356e-23]], device='cuda:0'), tensor([[4.0225e-16, 3.9353e-17, 6.6616e-19, ..., 1.3759e-13, 7.6381e-10,\n 4.4185e-23],\n [2.9569e-12, 3.8580e-19, 5.3756e-16, ..., 6.0166e-23, 1.4681e-17,\n 5.1994e-20],\n [4.7900e-26, 1.0599e-04, 7.0237e-19, ..., 1.1461e-20, 1.0415e-10,\n 1.0237e-19],\n ...,\n [6.9540e-17, 3.4471e-22, 2.7309e-14, ..., 2.5999e-26, 2.5635e-19,\n 7.0793e-16],\n [6.3722e-23, 1.2054e-13, 1.8638e-20, ..., 1.2807e-23, 5.5705e-16,\n 2.3085e-13],\n [1.9623e-20, 2.4720e-11, 1.8808e-15, ..., 3.5100e-20, 3.6195e-15,\n 1.5356e-23]], device='cuda:0'), atol=1e-08, rtol=1e-06) E + where <built-in method allclose of type object at 0x7035c99e82c0> = torch.allclose E + and tensor([[4.0225e-16, 3.9353e-17, 6.6616e-19, ..., 1.3759e-13, 7.6381e-10,\n 4.4185e-23],\n [2.9569e-12, 3.8580e-19, 5.3756e-16, ..., 6.0166e-23, 1.4681e-17,\n 5.1994e-20],\n [4.7900e-26, 1.0599e-04, 7.0237e-19, ..., 1.1461e-20, 1.0415e-10,\n 1.0237e-19],\n ...,\n [6.9540e-17, 3.4471e-22, 2.7309e-14, ..., 2.5999e-26, 2.5635e-19,\n 7.0793e-16],\n [6.3721e-23, 1.2054e-13, 1.8638e-20, ..., 1.2807e-23, 5.5705e-16,\n 2.3085e-13],\n [1.9623e-20, 2.4720e-11, 1.8808e-15, ..., 3.5100e-20, 3.6195e-15,\n 1.5356e-23]], device='cuda:0') = tensor([[ 6.0503, 3.7258, -0.3530, ..., 11.8853, 20.5071, -9.9739],\n [ 15.2597, -0.5924, 6.6471, ..., -9.3584, 3.0466, -2.5966],\n [-17.9122, 31.2363, -1.4114, ..., -5.5268, 17.4033, -3.3372],\n ...,\n [ 4.3242, -7.8904, 10.2973, ..., -17.3829, -1.2789, 6.6447],\n [-10.9055, 10.4553, -5.2270, ..., -12.5100, 5.0782, 11.1050],\n [ -5.8922, 15.0620, 5.5783, ..., -5.3107, 6.2329, -13.0452]],\n device='cuda:0', requires_grad=True).grad E + and tensor([[4.0225e-16, 3.9353e-17, 6.6616e-19, ..., 1.3759e-13, 7.6381e-10,\n 4.4185e-23],\n [2.9569e-12, 3.8580e-19, 5.3756e-16, ..., 6.0166e-23, 1.4681e-17,\n 5.1994e-20],\n [4.7900e-26, 1.0599e-04, 7.0237e-19, ..., 1.1461e-20, 1.0415e-10,\n 1.0237e-19],\n ...,\n [6.9540e-17, 3.4471e-22, 2.7309e-14, ..., 2.5999e-26, 2.5635e-19,\n 7.0793e-16],\n [6.3722e-23, 1.2054e-13, 1.8638e-20, ..., 1.2807e-23, 5.5705e-16,\n 2.3085e-13],\n [1.9623e-20, 2.4720e-11, 1.8808e-15, ..., 3.5100e-20, 3.6195e-15,\n 1.5356e-23]], device='cuda:0') = tensor([[4.0225e-16, 3.9353e-17, 6.6616e-19, ..., 1.3759e-13, 7.6381e-10,\n 4.4185e-23],\n [2.9569e-12, 3.8580e-19, 5.3756e-16, ..., 6.0166e-23, 1.4681e-17,\n 5.1994e-20],\n [4.7900e-26, 1.0599e-04, 7.0237e-19, ..., 1.1461e-20, 1.0415e-10,\n 1.0237e-19],\n ...,\n [6.9540e-17, 3.4471e-22, 2.7309e-14, ..., 2.5999e-26, 2.5635e-19,\n 7.0793e-16],\n [6.3722e-23, 1.2054e-13, 1.8638e-20, ..., 1.2807e-23, 5.5705e-16,\n 2.3085e-13],\n [1.9623e-20, 2.4720e-11, 1.8808e-15, ..., 3.5100e-20, 3.6195e-15,\n 1.5356e-23]], device='cuda:0', requires_grad=True).grad test/transformers/test_cross_entropy.py:61: AssertionError _________________________________ test_correctness_with_beta[0.1-dtype1-1e-08-1e-06-1-4096-128256] _________________________________ B = 1, T = 4096, V = 128256, beta = 0.1, dtype = torch.float32, atol = 1e-08, rtol = 1e-06 @pytest.mark.parametrize(*_SHAPE_PARAMS) @pytest.mark.parametrize(*_DTYPE_PARAMS) @pytest.mark.parametrize("beta", [0.1, 0.5, 0.9]) def test_correctness_with_beta(B, T, V, beta, dtype, atol, rtol): liger_jsd = LigerJSD(beta=beta) > _test_correctness_with_beta_once(liger_jsd, beta, B, T, V, dtype, atol, rtol) test/transformers/test_jsd.py:269: _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ test/transformers/test_jsd.py:157: in _test_correctness_with_beta_once assert_verbose_allclose(output, output2, atol=atol, rtol=rtol) _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ tensor1 = tensor(0.0805, device='cuda:0', grad_fn=<SumBackward0>) tensor2 = tensor(0.0805, device='cuda:0', grad_fn=<LigerJSDFunctionBackward>), rtol = 1e-06, atol = 1e-08, max_print = 5 def assert_verbose_allclose(tensor1, tensor2, rtol=1e-05, atol=1e-08, max_print=5): """ Assert that two tensors are element-wise equal within a tolerance, providing detailed information about mismatches. Parameters: tensor1 (torch.Tensor): First tensor to compare. tensor2 (torch.Tensor): Second tensor to compare. rtol (float): Relative tolerance. atol (float): Absolute tolerance. max_print (int): Maximum number of mismatched elements to print. Raises: AssertionError: If the tensors are not all close within the given tolerance. """ # Check if the shapes of the tensors match if tensor1.shape != tensor2.shape: raise AssertionError("Input tensors must have the same shape.") # Calculate the difference between the tensors diff = torch.abs(tensor1 - tensor2) # Determine the tolerance tolerance = atol + rtol * torch.abs(tensor2) # Find tolerance mismatched elements tol_mismatched = diff > tolerance # Find nan mismatched elements nan_mismatched = torch.logical_xor(torch.isnan(tensor1), torch.isnan(tensor2)) # Find +inf mismatched elements posinf_mismatched = torch.logical_xor( torch.isposinf(tensor1), torch.isposinf(tensor2) ) # Find -inf mismatched elements neginf_mismatched = torch.logical_xor( torch.isneginf(tensor1), torch.isneginf(tensor2) ) # Find all mismatched elements mismatched = torch.logical_or( torch.logical_or(tol_mismatched, nan_mismatched), torch.logical_or(posinf_mismatched, neginf_mismatched), ) mismatched_indices = torch.nonzero(mismatched) # Count the number of mismatched elements num_mismatched = mismatched.sum().item() # Check if all elements are close all_close = num_mismatched == 0 # Raise AssertionError with detailed information if there are mismatches if not all_close and num_mismatched >= 1: mismatch_details = [f"Number of mismatched elements: {num_mismatched}"] print_count = min(max_print, num_mismatched) for index in mismatched_indices[:print_count]: i = tuple(index.tolist()) mismatch_details.append( f"Mismatch at index {i}: tensor1[{i}] = {tensor1[i]}, tensor2[{i}] = {tensor2[i]}" ) if num_mismatched > max_print: mismatch_details.append( f"... and {num_mismatched - max_print} more mismatched elements." ) > raise AssertionError("\n".join(mismatch_details)) E AssertionError: Number of mismatched elements: 1 E Mismatch at index (): tensor1[()] = 0.08054989576339722, tensor2[()] = 0.08054977655410767 test/utils.py:106: AssertionError _________________________________ test_correctness_with_beta[0.9-dtype1-1e-08-1e-06-1-4096-128256] _________________________________ B = 1, T = 4096, V = 128256, beta = 0.9, dtype = torch.float32, atol = 1e-08, rtol = 1e-06 @pytest.mark.parametrize(*_SHAPE_PARAMS) @pytest.mark.parametrize(*_DTYPE_PARAMS) @pytest.mark.parametrize("beta", [0.1, 0.5, 0.9]) def test_correctness_with_beta(B, T, V, beta, dtype, atol, rtol): liger_jsd = LigerJSD(beta=beta) > _test_correctness_with_beta_once(liger_jsd, beta, B, T, V, dtype, atol, rtol) test/transformers/test_jsd.py:269: _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ test/transformers/test_jsd.py:157: in _test_correctness_with_beta_once assert_verbose_allclose(output, output2, atol=atol, rtol=rtol) _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ tensor1 = tensor(0.0805, device='cuda:0', grad_fn=<SumBackward0>) tensor2 = tensor(0.0805, device='cuda:0', grad_fn=<LigerJSDFunctionBackward>), rtol = 1e-06, atol = 1e-08, max_print = 5 def assert_verbose_allclose(tensor1, tensor2, rtol=1e-05, atol=1e-08, max_print=5): """ Assert that two tensors are element-wise equal within a tolerance, providing detailed information about mismatches. Parameters: tensor1 (torch.Tensor): First tensor to compare. tensor2 (torch.Tensor): Second tensor to compare. rtol (float): Relative tolerance. atol (float): Absolute tolerance. max_print (int): Maximum number of mismatched elements to print. Raises: AssertionError: If the tensors are not all close within the given tolerance. """ # Check if the shapes of the tensors match if tensor1.shape != tensor2.shape: raise AssertionError("Input tensors must have the same shape.") # Calculate the difference between the tensors diff = torch.abs(tensor1 - tensor2) # Determine the tolerance tolerance = atol + rtol * torch.abs(tensor2) # Find tolerance mismatched elements tol_mismatched = diff > tolerance # Find nan mismatched elements nan_mismatched = torch.logical_xor(torch.isnan(tensor1), torch.isnan(tensor2)) # Find +inf mismatched elements posinf_mismatched = torch.logical_xor( torch.isposinf(tensor1), torch.isposinf(tensor2) ) # Find -inf mismatched elements neginf_mismatched = torch.logical_xor( torch.isneginf(tensor1), torch.isneginf(tensor2) ) # Find all mismatched elements mismatched = torch.logical_or( torch.logical_or(tol_mismatched, nan_mismatched), torch.logical_or(posinf_mismatched, neginf_mismatched), ) mismatched_indices = torch.nonzero(mismatched) # Count the number of mismatched elements num_mismatched = mismatched.sum().item() # Check if all elements are close all_close = num_mismatched == 0 # Raise AssertionError with detailed information if there are mismatches if not all_close and num_mismatched >= 1: mismatch_details = [f"Number of mismatched elements: {num_mismatched}"] print_count = min(max_print, num_mismatched) for index in mismatched_indices[:print_count]: i = tuple(index.tolist()) mismatch_details.append( f"Mismatch at index {i}: tensor1[{i}] = {tensor1[i]}, tensor2[{i}] = {tensor2[i]}" ) if num_mismatched > max_print: mismatch_details.append( f"... and {num_mismatched - max_print} more mismatched elements." ) > raise AssertionError("\n".join(mismatch_details)) E AssertionError: Number of mismatched elements: 1 E Mismatch at index (): tensor1[()] = 0.08054172992706299, tensor2[()] = 0.08054161071777344 test/utils.py:106: AssertionError ___________________________________ test_correctness[dtype1-1e-08-1e-06-none-False-32-4096-1024] ___________________________________ B = 32, T = 4096, V = 1024, log_target = False, reduction = 'none', dtype = torch.float32, atol = 1e-08, rtol = 1e-06 @pytest.mark.parametrize(*_SHAPE_PARAMS) @pytest.mark.parametrize("log_target", [True, False]) @pytest.mark.parametrize("reduction", ["batchmean", "sum", "mean", "none"]) @pytest.mark.parametrize(*_DTYPE_PARAMS) def test_correctness(B, T, V, log_target, reduction, dtype, atol, rtol): liger_kldiv = LigerKLDIVLoss(reduction=reduction, log_target=log_target) > _test_correctness_once( liger_kldiv, B, T, V, dtype, atol, rtol, reduction, log_target ) test/transformers/test_kl_div.py:97: _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ target_kldiv = LigerKLDIVLoss(), B = 32, T = 4096, V = 1024, dtype = torch.float32, atol = 1e-08, rtol = 1e-06, reduction = 'none' log_target = False, is_last_layer = True, device = 'cuda' def _test_correctness_once( target_kldiv, B, T, V, dtype, atol, rtol, reduction, log_target, is_last_layer=True, device="cuda", ): torch.manual_seed(0) torch_kldiv = KLDivLoss(reduction=reduction, log_target=log_target) input = torch.randn( B * T, V, device=device, dtype=dtype, requires_grad=True ).log_softmax(dim=-1) x1 = input.detach().clone().requires_grad_(True) x2 = input.detach().clone().requires_grad_(True) with torch.no_grad(): target = torch.randn(B * T, V, device=device).softmax(dim=-1) output = torch_kldiv(x1, target) output2 = target_kldiv(x2, target) > assert torch.allclose(output, output2, atol=atol, rtol=rtol) E AssertionError: assert False E + where False = <built-in method allclose of type object at 0x7035c99e82c0>(tensor([[ 3.8871e-04, 1.5342e-03, 9.7731e-04, ..., 1.5857e-04,\n 2.0651e-05, -2.0225e-04],\n [ 3.0436e-04, 1.4040e-03, -1.4338e-04, ..., -9.6487e-04,\n 3.6957e-04, -1.7970e-04],\n [ 1.3870e-02, 1.8989e-03, -2.3409e-04, ..., -9.2741e-05,\n -2.1325e-03, -3.6861e-04],\n ...,\n [ 1.6965e-04, 7.5081e-04, 1.7243e-03, ..., -3.3345e-04,\n 2.9291e-04, 4.6570e-03],\n [-8.5313e-04, 5.1247e-04, 2.9434e-03, ..., -1.6669e-04,\n 6.3304e-04, 8.2082e-04],\n [-1.0297e-03, -5.9040e-05, -4.5201e-04, ..., 1.1601e-03,\n 1.0437e-03, 2.4179e-04]], device='cuda:0', grad_fn=<SubBackward0>), tensor([[ 3.8871e-04, 1.5342e-03, 9.7731e-04, ..., 1.5857e-04,\n 2.0651e-05, -2.0225e-04],\n [ 3.0436e-04, 1.4040e-03, -1.4338e-04, ..., -9.6487e-04,\n 3.6957e-04, -1.7970e-04],\n [ 1.3870e-02, 1.8989e-03, -2.3409e-04, ..., -9.2741e-05,\n -2.1325e-03, -3.6861e-04],\n ...,\n [ 1.6965e-04, 7.5081e-04, 1.7243e-03, ..., -3.3345e-04,\n 2.9291e-04, 4.6570e-03],\n [-8.5313e-04, 5.1247e-04, 2.9434e-03, ..., -1.6669e-04,\n 6.3304e-04, 8.2082e-04],\n [-1.0297e-03, -5.9040e-05, -4.5201e-04, ..., 1.1601e-03,\n 1.0437e-03, 2.4179e-04]], device='cuda:0',\n grad_fn=<LigerKLDivLossFunctionBackward>), atol=1e-08, rtol=1e-06) E + where <built-in method allclose of type object at 0x7035c99e82c0> = torch.allclose test/transformers/test_kl_div.py:75: AssertionError ______________________________ test_correctness_not_last[dtype1-1e-08-1e-06-none-False-32-4096-1024] _______________________________ B = 32, T = 4096, V = 1024, log_target = False, reduction = 'none', dtype = torch.float32, atol = 1e-08, rtol = 1e-06 @pytest.mark.parametrize(*_SHAPE_PARAMS) @pytest.mark.parametrize("log_target", [True, False]) @pytest.mark.parametrize("reduction", ["batchmean", "sum", "mean", "none"]) @pytest.mark.parametrize(*_DTYPE_PARAMS) def test_correctness_not_last(B, T, V, log_target, reduction, dtype, atol, rtol): liger_kldiv = LigerKLDIVLoss(reduction=reduction, log_target=log_target) > _test_correctness_once( liger_kldiv, B, T, V, dtype, atol, rtol, reduction, log_target, is_last_layer=False, ) test/transformers/test_kl_div.py:108: _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ target_kldiv = LigerKLDIVLoss(), B = 32, T = 4096, V = 1024, dtype = torch.float32, atol = 1e-08, rtol = 1e-06, reduction = 'none' log_target = False, is_last_layer = False, device = 'cuda' def _test_correctness_once( target_kldiv, B, T, V, dtype, atol, rtol, reduction, log_target, is_last_layer=True, device="cuda", ): torch.manual_seed(0) torch_kldiv = KLDivLoss(reduction=reduction, log_target=log_target) input = torch.randn( B * T, V, device=device, dtype=dtype, requires_grad=True ).log_softmax(dim=-1) x1 = input.detach().clone().requires_grad_(True) x2 = input.detach().clone().requires_grad_(True) with torch.no_grad(): target = torch.randn(B * T, V, device=device).softmax(dim=-1) output = torch_kldiv(x1, target) output2 = target_kldiv(x2, target) > assert torch.allclose(output, output2, atol=atol, rtol=rtol) E AssertionError: assert False E + where False = <built-in method allclose of type object at 0x7035c99e82c0>(tensor([[ 3.8871e-04, 1.5342e-03, 9.7731e-04, ..., 1.5857e-04,\n 2.0651e-05, -2.0225e-04],\n [ 3.0436e-04, 1.4040e-03, -1.4338e-04, ..., -9.6487e-04,\n 3.6957e-04, -1.7970e-04],\n [ 1.3870e-02, 1.8989e-03, -2.3409e-04, ..., -9.2741e-05,\n -2.1325e-03, -3.6861e-04],\n ...,\n [ 1.6965e-04, 7.5081e-04, 1.7243e-03, ..., -3.3345e-04,\n 2.9291e-04, 4.6570e-03],\n [-8.5313e-04, 5.1247e-04, 2.9434e-03, ..., -1.6669e-04,\n 6.3304e-04, 8.2082e-04],\n [-1.0297e-03, -5.9040e-05, -4.5201e-04, ..., 1.1601e-03,\n 1.0437e-03, 2.4179e-04]], device='cuda:0', grad_fn=<SubBackward0>), tensor([[ 3.8871e-04, 1.5342e-03, 9.7731e-04, ..., 1.5857e-04,\n 2.0651e-05, -2.0225e-04],\n [ 3.0436e-04, 1.4040e-03, -1.4338e-04, ..., -9.6487e-04,\n 3.6957e-04, -1.7970e-04],\n [ 1.3870e-02, 1.8989e-03, -2.3409e-04, ..., -9.2741e-05,\n -2.1325e-03, -3.6861e-04],\n ...,\n [ 1.6965e-04, 7.5081e-04, 1.7243e-03, ..., -3.3345e-04,\n 2.9291e-04, 4.6570e-03],\n [-8.5313e-04, 5.1247e-04, 2.9434e-03, ..., -1.6669e-04,\n 6.3304e-04, 8.2082e-04],\n [-1.0297e-03, -5.9040e-05, -4.5201e-04, ..., 1.1601e-03,\n 1.0437e-03, 2.4179e-04]], device='cuda:0',\n grad_fn=<LigerKLDivLossFunctionBackward>), atol=1e-08, rtol=1e-06) E + where <built-in method allclose of type object at 0x7035c99e82c0> = torch.allclose test/transformers/test_kl_div.py:75: AssertionError _________________________________________________ test_import_custom_cache_manager _________________________________________________ def test_import_custom_cache_manager(): from triton.runtime.cache import get_cache_manager from liger_kernel.triton import apply_liger_triton_cache_manager apply_liger_triton_cache_manager() > cache_manager = get_cache_manager(key="test_hash") test/triton/test_triton_monkey_patch.py:17: _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ /opt/conda/envs/py_3.9/lib/python3.9/site-packages/triton/runtime/cache.py:277: in get_cache_manager return __cache_cls(_base64(key)) _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ key = 'test_hash' def _base64(key): # Assume key is a hex string. > return base64.urlsafe_b64encode(bytes.fromhex(key)).decode("utf-8").rstrip("=") E ValueError: non-hexadecimal number found in fromhex() arg at position 0 /opt/conda/envs/py_3.9/lib/python3.9/site-packages/triton/runtime/cache.py:261: ValueError ===================================================== short test summary info ====================================================== FAILED test/transformers/test_cross_entropy.py::test_correctness_with_ignore_index[10.0-dtype5-1e-08-1e-06-sum-2-4096-32000--100] - AssertionError: assert False + where False = <built-in method allclose of type object at 0x7035c99e82c0>(tensor([[4.0225e-16, 3.9353e-17, 6.6616e-19, ..., 1.3759e-13, 7.6381e-10,\n 4.4185e-23],\n [2.9569e-12, 3.8580e-19, 5.3756e-16, ..., 6.0166e-23, 1.4681e-17,\n 5.1994e-20],\n [4.7900e-26, 1.0599e-04, 7.0237e-19, ..., 1.1461e-20, 1.0415e-10,\n 1.0237e-19],\n ...,\n [6.9540e-17, 3.4471e-22, 2.7309e-14, ..., 2.5999e-26, 2.5635e-19,\n 7.0793e-16],\n [6.3721e-23, 1.2054e-13, 1.8638e-20, ..., 1.2807e-23, 5.5705e-16,\n 2.3085e-13],\n [1.9623e-20, 2.4720e-11, 1.8808e-15, ..., 3.5100e-20, 3.6195e-15,\n 1.5356e-23]], device='cuda:0'), tensor([[4.0225e-16, 3.9353e-17, 6.6616e-19, ..., 1.3759e-13, 7.6381e-10,\n 4.4185e-23],\n [2.9569e-12, 3.8580e-19, 5.3756e-16, ..., 6.0166e-23, 1.4681e-17,\n 5.1994e-20],\n [4.7900e-26, 1.0599e-04, 7.0237e-19, ..., 1.1461e-20, 1.0415e-10,\n 1.0237e-19],\n ...,\n [6.9540e-17, 3.4471e-22, 2.7309e-14, ..., 2.5999e-26, 2.5635e-19,\n 7.0793e-16],\n [6.3722e-23, 1.2054e-13, 1.8638e-20, ..., 1.2807e-23, 5.5705e-16,\n 2.3085e-13],\n [1.9623e-20, 2.4720e-11, 1.8808e-15, ..., 3.5100e-20, 3.6195e-15,\n 1.5356e-23]], device='cuda:0'), atol=1e-08, rtol=1e-06) + where <built-in method allclose of type object at 0x7035c99e82c0> = torch.allclose + and tensor([[4.0225e-16, 3.9353e-17, 6.6616e-19, ..., 1.3759e-13, 7.6381e-10,\n 4.4185e-23],\n [2.9569e-12, 3.8580e-19, 5.3756e-16, ..., 6.0166e-23, 1.4681e-17,\n 5.1994e-20],\n [4.7900e-26, 1.0599e-04, 7.0237e-19, ..., 1.1461e-20, 1.0415e-10,\n 1.0237e-19],\n ...,\n [6.9540e-17, 3.4471e-22, 2.7309e-14, ..., 2.5999e-26, 2.5635e-19,\n 7.0793e-16],\n [6.3721e-23, 1.2054e-13, 1.8638e-20, ..., 1.2807e-23, 5.5705e-16,\n 2.3085e-13],\n [1.9623e-20, 2.4720e-11, 1.8808e-15, ..., 3.5100e-20, 3.6195e-15,\n 1.5356e-23]], device='cuda:0') = tensor([[ 6.0503, 3.7258, -0.3530, ..., 11.8853, 20.5071, -9.9739],\n [ 15.2597, -0.5924, 6.6471, ..., -9.3584, 3.0466, -2.5966],\n [-17.9122, 31.2363, -1.4114, ..., -5.5268, 17.4033, -3.3372],\n ...,\n [ 4.3242, -7.8904, 10.2973, ..., -17.3829, -1.2789, 6.6447],\n [-10.9055, 10.4553, -5.2270, ..., -12.5100, 5.0782, 11.1050],\n [ -5.8922, 15.0620, 5.5783, ..., -5.3107, 6.2329, -13.0452]],\n device='cuda:0', requires_grad=True).grad + and tensor([[4.0225e-16, 3.9353e-17, 6.6616e-19, ..., 1.3759e-13, 7.6381e-10,\n 4.4185e-23],\n [2.9569e-12, 3.8580e-19, 5.3756e-16, ..., 6.0166e-23, 1.4681e-17,\n 5.1994e-20],\n [4.7900e-26, 1.0599e-04, 7.0237e-19, ..., 1.1461e-20, 1.0415e-10,\n 1.0237e-19],\n ...,\n [6.9540e-17, 3.4471e-22, 2.7309e-14, ..., 2.5999e-26, 2.5635e-19,\n 7.0793e-16],\n [6.3722e-23, 1.2054e-13, 1.8638e-20, ..., 1.2807e-23, 5.5705e-16,\n 2.3085e-13],\n [1.9623e-20, 2.4720e-11, 1.8808e-15, ..., 3.5100e-20, 3.6195e-15,\n 1.5356e-23]], device='cuda:0') = tensor([[4.0225e-16, 3.9353e-17, 6.6616e-19, ..., 1.3759e-13, 7.6381e-10,\n 4.4185e-23],\n [2.9569e-12, 3.8580e-19, 5.3756e-16, ..., 6.0166e-23, 1.4681e-17,\n 5.1994e-20],\n [4.7900e-26, 1.0599e-04, 7.0237e-19, ..., 1.1461e-20, 1.0415e-10,\n 1.0237e-19],\n ...,\n [6.9540e-17, 3.4471e-22, 2.7309e-14, ..., 2.5999e-26, 2.5635e-19,\n 7.0793e-16],\n [6.3722e-23, 1.2054e-13, 1.8638e-20, ..., 1.2807e-23, 5.5705e-16,\n 2.3085e-13],\n [1.9623e-20, 2.4720e-11, 1.8808e-15, ..., 3.5100e-20, 3.6195e-15,\n 1.5356e-23]], device='cuda:0', requires_grad=True).grad FAILED test/transformers/test_jsd.py::test_correctness_with_beta[0.1-dtype1-1e-08-1e-06-1-4096-128256] - AssertionError: Number of mismatched elements: 1 Mismatch at index (): tensor1[()] = 0.08054989576339722, tensor2[()] = 0.08054977655410767 FAILED test/transformers/test_jsd.py::test_correctness_with_beta[0.9-dtype1-1e-08-1e-06-1-4096-128256] - AssertionError: Number of mismatched elements: 1 Mismatch at index (): tensor1[()] = 0.08054172992706299, tensor2[()] = 0.08054161071777344 FAILED test/transformers/test_kl_div.py::test_correctness[dtype1-1e-08-1e-06-none-False-32-4096-1024] - AssertionError: assert False + where False = <built-in method allclose of type object at 0x7035c99e82c0>(tensor([[ 3.8871e-04, 1.5342e-03, 9.7731e-04, ..., 1.5857e-04,\n 2.0651e-05, -2.0225e-04],\n [ 3.0436e-04, 1.4040e-03, -1.4338e-04, ..., -9.6487e-04,\n 3.6957e-04, -1.7970e-04],\n [ 1.3870e-02, 1.8989e-03, -2.3409e-04, ..., -9.2741e-05,\n -2.1325e-03, -3.6861e-04],\n ...,\n [ 1.6965e-04, 7.5081e-04, 1.7243e-03, ..., -3.3345e-04,\n 2.9291e-04, 4.6570e-03],\n [-8.5313e-04, 5.1247e-04, 2.9434e-03, ..., -1.6669e-04,\n 6.3304e-04, 8.2082e-04],\n [-1.0297e-03, -5.9040e-05, -4.5201e-04, ..., 1.1601e-03,\n 1.0437e-03, 2.4179e-04]], device='cuda:0', grad_fn=<SubBackward0>), tensor([[ 3.8871e-04, 1.5342e-03, 9.7731e-04, ..., 1.5857e-04,\n 2.0651e-05, -2.0225e-04],\n [ 3.0436e-04, 1.4040e-03, -1.4338e-04, ..., -9.6487e-04,\n 3.6957e-04, -1.7970e-04],\n [ 1.3870e-02, 1.8989e-03, -2.3409e-04, ..., -9.2741e-05,\n -2.1325e-03, -3.6861e-04],\n ...,\n [ 1.6965e-04, 7.5081e-04, 1.7243e-03, ..., -3.3345e-04,\n 2.9291e-04, 4.6570e-03],\n [-8.5313e-04, 5.1247e-04, 2.9434e-03, ..., -1.6669e-04,\n 6.3304e-04, 8.2082e-04],\n [-1.0297e-03, -5.9040e-05, -4.5201e-04, ..., 1.1601e-03,\n 1.0437e-03, 2.4179e-04]], device='cuda:0',\n grad_fn=<LigerKLDivLossFunctionBackward>), atol=1e-08, rtol=1e-06) + where <built-in method allclose of type object at 0x7035c99e82c0> = torch.allclose FAILED test/transformers/test_kl_div.py::test_correctness_not_last[dtype1-1e-08-1e-06-none-False-32-4096-1024] - AssertionError: assert False + where False = <built-in method allclose of type object at 0x7035c99e82c0>(tensor([[ 3.8871e-04, 1.5342e-03, 9.7731e-04, ..., 1.5857e-04,\n 2.0651e-05, -2.0225e-04],\n [ 3.0436e-04, 1.4040e-03, -1.4338e-04, ..., -9.6487e-04,\n 3.6957e-04, -1.7970e-04],\n [ 1.3870e-02, 1.8989e-03, -2.3409e-04, ..., -9.2741e-05,\n -2.1325e-03, -3.6861e-04],\n ...,\n [ 1.6965e-04, 7.5081e-04, 1.7243e-03, ..., -3.3345e-04,\n 2.9291e-04, 4.6570e-03],\n [-8.5313e-04, 5.1247e-04, 2.9434e-03, ..., -1.6669e-04,\n 6.3304e-04, 8.2082e-04],\n [-1.0297e-03, -5.9040e-05, -4.5201e-04, ..., 1.1601e-03,\n 1.0437e-03, 2.4179e-04]], device='cuda:0', grad_fn=<SubBackward0>), tensor([[ 3.8871e-04, 1.5342e-03, 9.7731e-04, ..., 1.5857e-04,\n 2.0651e-05, -2.0225e-04],\n [ 3.0436e-04, 1.4040e-03, -1.4338e-04, ..., -9.6487e-04,\n 3.6957e-04, -1.7970e-04],\n [ 1.3870e-02, 1.8989e-03, -2.3409e-04, ..., -9.2741e-05,\n -2.1325e-03, -3.6861e-04],\n ...,\n [ 1.6965e-04, 7.5081e-04, 1.7243e-03, ..., -3.3345e-04,\n 2.9291e-04, 4.6570e-03],\n [-8.5313e-04, 5.1247e-04, 2.9434e-03, ..., -1.6669e-04,\n 6.3304e-04, 8.2082e-04],\n [-1.0297e-03, -5.9040e-05, -4.5201e-04, ..., 1.1601e-03,\n 1.0437e-03, 2.4179e-04]], device='cuda:0',\n grad_fn=<LigerKLDivLossFunctionBackward>), atol=1e-08, rtol=1e-06) + where <built-in method allclose of type object at 0x7035c99e82c0> = torch.allclose FAILED test/triton/test_triton_monkey_patch.py::test_import_custom_cache_manager - ValueError: non-hexadecimal number found in fromhex() arg at position 0 ================================ 6 failed, 1012 passed, 8 skipped, 72 warnings in 630.02s (0:10:30) ================================ make: *** [Makefile:8: test] Error 1 ``` </details> --------- Co-authored-by: tjtanaa <[email protected]> Co-authored-by: root <tjtanaa>
- Loading branch information