diff --git a/apex/contrib/test/fmha/test_fmha.py b/apex/contrib/test/fmha/test_fmha.py index 1bf1ad552..6e7360245 100644 --- a/apex/contrib/test/fmha/test_fmha.py +++ b/apex/contrib/test/fmha/test_fmha.py @@ -96,7 +96,7 @@ def run_test(self, s: int, b: int, zero_tensors: bool): ctx = ctx.view(b,s,h,d) ctx_ref = py_mha(qkv, amask, b,s,h,d) - self.assertTrue(torch.allclose(ctx_ref.float(), ctx.float(), atol=1e-3)) + torch.testing.assert_close(ctx_ref.float(), ctx.float(), atol=1e-3, rtol=1e-5) labels = torch.randn_like(ctx_ref) diff = ctx_ref - labels @@ -114,7 +114,7 @@ def run_test(self, s: int, b: int, zero_tensors: bool): dqkv2 = dqkv2.permute(0,2,1,3).view(b,s, h,3,d) - self.assertTrue(torch.allclose(qkv.grad.float(), dqkv2.float(), atol=1e-3)) + torch.testing.assert_close(qkv.grad.float(), dqkv2.float(), atol=1e-3, rtol=1e-5) def test_128(self): self.run_test(128, 32, False) diff --git a/apex/contrib/test/transducer/test_transducer_joint.py b/apex/contrib/test/transducer/test_transducer_joint.py index a20464884..a6029e818 100755 --- a/apex/contrib/test/transducer/test_transducer_joint.py +++ b/apex/contrib/test/transducer/test_transducer_joint.py @@ -1,167 +1,167 @@ -import unittest - -import torch - -SKIP_TEST = None -try: - from apex.contrib.transducer import TransducerJoint - from apex.contrib.transducer import _transducer_ref as transducer_ref -except ImportError as e: - SKIP_TEST = e - - -@unittest.skipIf(SKIP_TEST, f"{SKIP_TEST}") -class TransducerJointTest(unittest.TestCase): - def setUp(self, seed=1234): - torch.manual_seed(seed) - - def gen_input(self, for_vector_kernel): - self.B = 4 - T_min = 51 - T_max = 101 - U_min = 12 - U_max = 25 - if for_vector_kernel: - H = 512 - else: - H = 509 - dtype = torch.float16 - device = "cuda" - - self.f_tst = torch.randn((self.B, T_max, H), dtype=dtype, requires_grad=True, device=device) - self.g_tst = torch.randn((self.B, U_max, H), dtype=dtype, requires_grad=True, device=device) - self.h_grad = torch.randn(self.B, T_max, U_max, H, dtype=dtype, device=device) - self.f_len = torch.randint(T_min, T_max+1, (self.B,), dtype=torch.int, device=device) - self.g_len = torch.randint(U_min, U_max+1, (self.B,), dtype=torch.int, device=device) - self.f_len[torch.randint(0, self.B, (1,)).item()] = T_max - self.g_len[torch.randint(0, self.B, (1,)).item()] = U_max - self.dropout_prob = 0.5 - - # Make sure gradients from out-of-bound locations are zero. 
This should be guaranteed by - # the loss function - for b in range(self.B): - self.h_grad[b, self.f_len[b]:, :, :] = 0 - self.h_grad[b, :, self.g_len[b]:, :] = 0 - self.h_grad_packed = self._pack(self.h_grad, self.f_len, self.g_len) - - - def _pack(self, x, f_len, g_len): - B = x.size(0) - list_x = [] - for b in range(B): - list_x_row = [x[b, t, :g_len[b]] for t in range(f_len[b])] - x_row = torch.cat(list_x_row) - list_x.append(x_row) - x_packed = torch.cat(list_x).data.clone() - x_packed.requires_grad = True - batch_offset = torch.cumsum(f_len * g_len, dim=0) - return x_packed - - def _unpack(self, x, f_len, g_len): - batch_offset = torch.cumsum(f_len * g_len, dim=0) - x_unpacked = torch.zeros_like(self.h_grad, dtype=torch.uint8) - B = self.h_grad.size(0) - H = self.h_grad.size(-1) - for b in range(B): - my_batch_offset = 0 if b == 0 else batch_offset[b-1] - my_f_len = f_len[b] - my_g_len = g_len[b] - for t in range(my_f_len): - x_unpacked[b, t, :my_g_len] = x[my_batch_offset + t*my_g_len : - my_batch_offset + t*my_g_len + my_g_len] - return x_unpacked - - def run_transducer_joint(self, for_vector_kernel, pack_output, relu, dropout): - self.gen_input(for_vector_kernel=for_vector_kernel) - # Generate reference - f_ref = self.f_tst.data.clone() - g_ref = self.g_tst.data.clone() - f_ref.requires_grad = True - g_ref.requires_grad = True - - my_joint = TransducerJoint(pack_output=pack_output, relu=relu, dropout=dropout, - dropout_prob=self.dropout_prob, probe_mask=True) - if not pack_output: - h_tst = my_joint( f=self.f_tst, - g=self.g_tst, - f_len=self.f_len, - g_len=self.g_len) - h_tst.backward(self.h_grad) - if dropout: - mask = my_joint.mask_probe[0] - else: - batch_offset = torch.cumsum(self.f_len * self.g_len, dim=0) - h_tst = my_joint( f=self.f_tst, - g=self.g_tst, - f_len=self.f_len, - g_len=self.g_len, - batch_offset=batch_offset, - packed_batch=batch_offset[-1]) - h_tst.backward(self.h_grad_packed) - if dropout: - mask_packed = my_joint.mask_probe[0] - mask = self._unpack(mask_packed, self.f_len, self.g_len) - - # reference - h_ref, f_grad_ref, g_grad_ref \ - = transducer_ref.transducer_joint_reference(f=f_ref, - g=g_ref, - h_grad=self.h_grad, - f_len=self.f_len, - g_len=self.g_len, - pack_output=pack_output, - relu=relu, - dropout=dropout, - dropout_prob=self.dropout_prob, - mask=mask if dropout else None) - - f_grad_tst = self.f_tst.grad - g_grad_tst = self.g_tst.grad - - self.assertTrue(torch.allclose(h_ref, h_tst, atol=1e-5, rtol=1e-5)) - self.assertTrue(torch.allclose(f_grad_ref, f_grad_tst, atol=1e-5, rtol=1e-5)) - self.assertTrue(torch.allclose(g_grad_ref, g_grad_tst, atol=1e-4, rtol=1e-4)) - - def test_transducer_joint(self): - self.run_transducer_joint(for_vector_kernel=True, pack_output=True, relu=False, dropout=False) - - def test_transducer_joint_vec(self): - self.run_transducer_joint(for_vector_kernel=True, pack_output=False, relu=False, dropout=False) - - def test_transducer_joint_pack(self): - self.run_transducer_joint(for_vector_kernel=False, pack_output=True, relu=False, dropout=False) - - def test_transducer_joint_vec_pack(self): - self.run_transducer_joint(for_vector_kernel=True, pack_output=True, relu=False, dropout=False) - - def test_transducer_joint_relu(self): - self.run_transducer_joint(for_vector_kernel=True, pack_output=True, relu=True, dropout=False) - - def test_transducer_joint_vec_relu(self): - self.run_transducer_joint(for_vector_kernel=True, pack_output=False, relu=True, dropout=False) - - def test_transducer_joint_pack_relu(self): - 
self.run_transducer_joint(for_vector_kernel=False, pack_output=True, relu=True, dropout=False) - - def test_transducer_joint_vec_pack_relu(self): - self.run_transducer_joint(for_vector_kernel=True, pack_output=True, relu=True, dropout=False) - - @unittest.expectedFailure - def test_transducer_joint_relu_dropout(self): - self.run_transducer_joint(for_vector_kernel=True, pack_output=True, relu=True, dropout=True) - - @unittest.expectedFailure - def test_transducer_joint_vec_relu_dropout(self): - self.run_transducer_joint(for_vector_kernel=True, pack_output=False, relu=True, dropout=True) - - @unittest.expectedFailure - def test_transducer_joint_pack_relu_dropout(self): - self.run_transducer_joint(for_vector_kernel=False, pack_output=True, relu=True, dropout=True) - - @unittest.expectedFailure - def test_transducer_joint_vec_pack_relu_dropout(self): - self.run_transducer_joint(for_vector_kernel=True, pack_output=True, relu=True, dropout=True) - - -if __name__ == '__main__': - unittest.main() +import unittest + +import torch + +SKIP_TEST = None +try: + from apex.contrib.transducer import TransducerJoint + from apex.contrib.transducer import _transducer_ref as transducer_ref +except ImportError as e: + SKIP_TEST = e + + +@unittest.skipIf(SKIP_TEST, f"{SKIP_TEST}") +class TransducerJointTest(unittest.TestCase): + def setUp(self, seed=1234): + torch.manual_seed(seed) + + def gen_input(self, for_vector_kernel): + self.B = 4 + T_min = 51 + T_max = 101 + U_min = 12 + U_max = 25 + if for_vector_kernel: + H = 512 + else: + H = 509 + dtype = torch.float16 + device = "cuda" + + self.f_tst = torch.randn((self.B, T_max, H), dtype=dtype, requires_grad=True, device=device) + self.g_tst = torch.randn((self.B, U_max, H), dtype=dtype, requires_grad=True, device=device) + self.h_grad = torch.randn(self.B, T_max, U_max, H, dtype=dtype, device=device) + self.f_len = torch.randint(T_min, T_max+1, (self.B,), dtype=torch.int, device=device) + self.g_len = torch.randint(U_min, U_max+1, (self.B,), dtype=torch.int, device=device) + self.f_len[torch.randint(0, self.B, (1,)).item()] = T_max + self.g_len[torch.randint(0, self.B, (1,)).item()] = U_max + self.dropout_prob = 0.5 + + # Make sure gradients from out-of-bound locations are zero. 
This should be guaranteed by + # the loss function + for b in range(self.B): + self.h_grad[b, self.f_len[b]:, :, :] = 0 + self.h_grad[b, :, self.g_len[b]:, :] = 0 + self.h_grad_packed = self._pack(self.h_grad, self.f_len, self.g_len) + + + def _pack(self, x, f_len, g_len): + B = x.size(0) + list_x = [] + for b in range(B): + list_x_row = [x[b, t, :g_len[b]] for t in range(f_len[b])] + x_row = torch.cat(list_x_row) + list_x.append(x_row) + x_packed = torch.cat(list_x).data.clone() + x_packed.requires_grad = True + batch_offset = torch.cumsum(f_len * g_len, dim=0) + return x_packed + + def _unpack(self, x, f_len, g_len): + batch_offset = torch.cumsum(f_len * g_len, dim=0) + x_unpacked = torch.zeros_like(self.h_grad, dtype=torch.uint8) + B = self.h_grad.size(0) + H = self.h_grad.size(-1) + for b in range(B): + my_batch_offset = 0 if b == 0 else batch_offset[b-1] + my_f_len = f_len[b] + my_g_len = g_len[b] + for t in range(my_f_len): + x_unpacked[b, t, :my_g_len] = x[my_batch_offset + t*my_g_len : + my_batch_offset + t*my_g_len + my_g_len] + return x_unpacked + + def run_transducer_joint(self, for_vector_kernel, pack_output, relu, dropout): + self.gen_input(for_vector_kernel=for_vector_kernel) + # Generate reference + f_ref = self.f_tst.data.clone() + g_ref = self.g_tst.data.clone() + f_ref.requires_grad = True + g_ref.requires_grad = True + + my_joint = TransducerJoint(pack_output=pack_output, relu=relu, dropout=dropout, + dropout_prob=self.dropout_prob, probe_mask=True) + if not pack_output: + h_tst = my_joint( f=self.f_tst, + g=self.g_tst, + f_len=self.f_len, + g_len=self.g_len) + h_tst.backward(self.h_grad) + if dropout: + mask = my_joint.mask_probe[0] + else: + batch_offset = torch.cumsum(self.f_len * self.g_len, dim=0) + h_tst = my_joint( f=self.f_tst, + g=self.g_tst, + f_len=self.f_len, + g_len=self.g_len, + batch_offset=batch_offset, + packed_batch=batch_offset[-1]) + h_tst.backward(self.h_grad_packed) + if dropout: + mask_packed = my_joint.mask_probe[0] + mask = self._unpack(mask_packed, self.f_len, self.g_len) + + # reference + h_ref, f_grad_ref, g_grad_ref \ + = transducer_ref.transducer_joint_reference(f=f_ref, + g=g_ref, + h_grad=self.h_grad, + f_len=self.f_len, + g_len=self.g_len, + pack_output=pack_output, + relu=relu, + dropout=dropout, + dropout_prob=self.dropout_prob, + mask=mask if dropout else None) + + f_grad_tst = self.f_tst.grad + g_grad_tst = self.g_tst.grad + + torch.testing.assert_close(h_ref, h_tst, atol=1e-5, rtol=1e-5) + torch.testing.assert_close(f_grad_ref, f_grad_tst, atol=1e-5, rtol=1e-5) + torch.testing.assert_close(g_grad_ref, g_grad_tst, atol=1e-4, rtol=1e-4) + + def test_transducer_joint(self): + self.run_transducer_joint(for_vector_kernel=True, pack_output=True, relu=False, dropout=False) + + def test_transducer_joint_vec(self): + self.run_transducer_joint(for_vector_kernel=True, pack_output=False, relu=False, dropout=False) + + def test_transducer_joint_pack(self): + self.run_transducer_joint(for_vector_kernel=False, pack_output=True, relu=False, dropout=False) + + def test_transducer_joint_vec_pack(self): + self.run_transducer_joint(for_vector_kernel=True, pack_output=True, relu=False, dropout=False) + + def test_transducer_joint_relu(self): + self.run_transducer_joint(for_vector_kernel=True, pack_output=True, relu=True, dropout=False) + + def test_transducer_joint_vec_relu(self): + self.run_transducer_joint(for_vector_kernel=True, pack_output=False, relu=True, dropout=False) + + def test_transducer_joint_pack_relu(self): + 
self.run_transducer_joint(for_vector_kernel=False, pack_output=True, relu=True, dropout=False) + + def test_transducer_joint_vec_pack_relu(self): + self.run_transducer_joint(for_vector_kernel=True, pack_output=True, relu=True, dropout=False) + + @unittest.expectedFailure + def test_transducer_joint_relu_dropout(self): + self.run_transducer_joint(for_vector_kernel=True, pack_output=True, relu=True, dropout=True) + + @unittest.expectedFailure + def test_transducer_joint_vec_relu_dropout(self): + self.run_transducer_joint(for_vector_kernel=True, pack_output=False, relu=True, dropout=True) + + @unittest.expectedFailure + def test_transducer_joint_pack_relu_dropout(self): + self.run_transducer_joint(for_vector_kernel=False, pack_output=True, relu=True, dropout=True) + + @unittest.expectedFailure + def test_transducer_joint_vec_pack_relu_dropout(self): + self.run_transducer_joint(for_vector_kernel=True, pack_output=True, relu=True, dropout=True) + + +if __name__ == '__main__': + unittest.main() diff --git a/apex/contrib/test/transducer/test_transducer_loss.py b/apex/contrib/test/transducer/test_transducer_loss.py index 0e9117327..7bae4b815 100755 --- a/apex/contrib/test/transducer/test_transducer_loss.py +++ b/apex/contrib/test/transducer/test_transducer_loss.py @@ -1,139 +1,139 @@ -import unittest - -import torch - -SKIP_TEST = None -try: - from apex.contrib.transducer import TransducerLoss - from apex.contrib.transducer import _transducer_ref as transducer_ref -except ImportError as e: - SKIP_TEST = e - - -@unittest.skipIf(SKIP_TEST, f"{SKIP_TEST}") -class TransducerLossTest(unittest.TestCase): - def setUp(self, seed=1234): - torch.manual_seed(seed) - - def gen_input(self, scalar_t, for_vector_kernel): - self.B = 5 - T_min = 23 - T_max = 51 - U_min = 12 - U_max = 25 - V = 16 if for_vector_kernel else 14 - self.blank_idx = V - 1 - device = "cuda" - - self.x_tst = torch.randn((self.B, T_max, U_max, V), dtype=scalar_t, requires_grad=True, - device=device) - self.y = torch.randint(0, self.blank_idx, (self.B, U_max-1), dtype=torch.int, device=device) - self.f_len = torch.randint(T_min, T_max+1, (self.B,), dtype=torch.int, device=device) - self.y_len = torch.randint(U_min-1, U_max, (self.B,), dtype=torch.int, device=device) - self.f_len[torch.randint(0, self.B, (1,)).item()] = T_max - self.y_len[torch.randint(0, self.B, (1,)).item()] = U_max-1 - self.x_tst_packed, self.batch_offset = self._pack(self.x_tst) - # Generate reference - x_ref = self.x_tst.data.clone() - x_ref.requires_grad = True - loss_grad = torch.ones(x_ref.size(0), dtype=x_ref.dtype, device=x_ref.device)/x_ref.size(0) - _, _, self.grad_ref, self.loss_ref \ - = transducer_ref.transducer_loss_reference( x=x_ref, - label=self.y, - f_len=self.f_len, - y_len=self.y_len, - blank_idx=self.blank_idx, - loss_grad=loss_grad) - - def _pack(self, x): - list_x = [] - for b in range(self.B): - list_x_row = [x[b, t, : self.y_len[b]+1] for t in range(self.f_len[b])] - x_row = torch.cat(list_x_row) - list_x.append(x_row) - x_packed = torch.cat(list_x).data.clone() - x_packed.requires_grad = True - batch_offset = torch.cumsum(self.f_len * (self.y_len+1), dim=0) - return x_packed, batch_offset - - def _unpack(self, x): - x_unpacked = torch.zeros(self.B, self.f_len.max(), self.y_len.max()+1, x.size(-1), - dtype=x.dtype, device=x.device) - for b in range(self.B): - my_batch_offset = 0 if b == 0 else self.batch_offset[b-1] - my_f_len = self.f_len[b] - my_g_len = self.y_len[b] + 1 - for t in range(my_f_len): - for u in range(my_g_len): - x_unpacked[b, t, u] 
= x[my_batch_offset + t*my_g_len + u] - return x_unpacked - - def run_transducer_loss(self, scalar_t, fuse_softmax_backward, packed_input, for_vector_kernel): - self.gen_input(scalar_t, for_vector_kernel) - my_loss = TransducerLoss( fuse_softmax_backward=fuse_softmax_backward, - packed_input=packed_input) - if not packed_input: - loss_tst = my_loss( x=self.x_tst, - label=self.y, - f_len=self.f_len, - y_len=self.y_len, - blank_idx=self.blank_idx) - loss_tst.mean().backward() - grad_tst = self.x_tst.grad - else: - loss_tst = my_loss( x=self.x_tst_packed, - label=self.y, - f_len=self.f_len, - y_len=self.y_len, - blank_idx=self.blank_idx, - batch_offset=self.batch_offset, - max_f_len=max(self.f_len)) - loss_tst.mean().backward() - grad_tst_packed = self.x_tst_packed.grad - grad_tst = self._unpack(grad_tst_packed) - - return loss_tst, grad_tst - - def test_transducer_loss_fp32(self): - loss_tst, grad_tst = self.run_transducer_loss( scalar_t=torch.float32, - fuse_softmax_backward=False, - packed_input=False, - for_vector_kernel=False) - self.assertTrue(torch.allclose(self.loss_ref, loss_tst, atol=1e-5, rtol=1e-5)) - self.assertTrue(torch.allclose(self.grad_ref, grad_tst, atol=1e-5, rtol=1e-5)) - - def test_transducer_loss_fp16(self): - loss_tst, grad_tst = self.run_transducer_loss( scalar_t=torch.float16, - fuse_softmax_backward=False, - packed_input=False, - for_vector_kernel=False) - self.assertTrue(torch.allclose(self.loss_ref, loss_tst, atol=1e-5, rtol=1e-5)) - self.assertTrue(torch.allclose(self.grad_ref, grad_tst, atol=1e-4, rtol=1e-3)) - - def test_transducer_loss_fp16_backward_fusion(self): - loss_tst, grad_tst = self.run_transducer_loss( scalar_t=torch.float16, - fuse_softmax_backward=True, - packed_input=False, - for_vector_kernel=False) - self.assertTrue(torch.allclose(self.loss_ref, loss_tst, atol=1e-5, rtol=1e-5)) - self.assertTrue(torch.allclose(self.grad_ref, grad_tst, atol=1e-4, rtol=1e-3)) - - def test_transducer_loss_fp16_backward_fusion_packed(self): - loss_tst, grad_tst = self.run_transducer_loss( scalar_t=torch.float16, - fuse_softmax_backward=True, - packed_input=True, - for_vector_kernel=False) - self.assertTrue(torch.allclose(self.loss_ref, loss_tst, atol=1e-5, rtol=1e-5)) - self.assertTrue(torch.allclose(self.grad_ref, grad_tst, atol=1e-4, rtol=1e-3)) - - def test_transducer_loss_fp16_backward_fusion_packed_vec(self): - loss_tst, grad_tst = self.run_transducer_loss( scalar_t=torch.float16, - fuse_softmax_backward=True, - packed_input=True, - for_vector_kernel=True) - self.assertTrue(torch.allclose(self.loss_ref, loss_tst, atol=1e-5, rtol=1e-5)) - self.assertTrue(torch.allclose(self.grad_ref, grad_tst, atol=1e-4, rtol=1e-3)) - - -if __name__ == '__main__': - unittest.main() +import unittest + +import torch + +SKIP_TEST = None +try: + from apex.contrib.transducer import TransducerLoss + from apex.contrib.transducer import _transducer_ref as transducer_ref +except ImportError as e: + SKIP_TEST = e + + +@unittest.skipIf(SKIP_TEST, f"{SKIP_TEST}") +class TransducerLossTest(unittest.TestCase): + def setUp(self, seed=1234): + torch.manual_seed(seed) + + def gen_input(self, scalar_t, for_vector_kernel): + self.B = 5 + T_min = 23 + T_max = 51 + U_min = 12 + U_max = 25 + V = 16 if for_vector_kernel else 14 + self.blank_idx = V - 1 + device = "cuda" + + self.x_tst = torch.randn((self.B, T_max, U_max, V), dtype=scalar_t, requires_grad=True, + device=device) + self.y = torch.randint(0, self.blank_idx, (self.B, U_max-1), dtype=torch.int, device=device) + self.f_len = 
torch.randint(T_min, T_max+1, (self.B,), dtype=torch.int, device=device) + self.y_len = torch.randint(U_min-1, U_max, (self.B,), dtype=torch.int, device=device) + self.f_len[torch.randint(0, self.B, (1,)).item()] = T_max + self.y_len[torch.randint(0, self.B, (1,)).item()] = U_max-1 + self.x_tst_packed, self.batch_offset = self._pack(self.x_tst) + # Generate reference + x_ref = self.x_tst.data.clone() + x_ref.requires_grad = True + loss_grad = torch.ones(x_ref.size(0), dtype=x_ref.dtype, device=x_ref.device)/x_ref.size(0) + _, _, self.grad_ref, self.loss_ref \ + = transducer_ref.transducer_loss_reference( x=x_ref, + label=self.y, + f_len=self.f_len, + y_len=self.y_len, + blank_idx=self.blank_idx, + loss_grad=loss_grad) + + def _pack(self, x): + list_x = [] + for b in range(self.B): + list_x_row = [x[b, t, : self.y_len[b]+1] for t in range(self.f_len[b])] + x_row = torch.cat(list_x_row) + list_x.append(x_row) + x_packed = torch.cat(list_x).data.clone() + x_packed.requires_grad = True + batch_offset = torch.cumsum(self.f_len * (self.y_len+1), dim=0) + return x_packed, batch_offset + + def _unpack(self, x): + x_unpacked = torch.zeros(self.B, self.f_len.max(), self.y_len.max()+1, x.size(-1), + dtype=x.dtype, device=x.device) + for b in range(self.B): + my_batch_offset = 0 if b == 0 else self.batch_offset[b-1] + my_f_len = self.f_len[b] + my_g_len = self.y_len[b] + 1 + for t in range(my_f_len): + for u in range(my_g_len): + x_unpacked[b, t, u] = x[my_batch_offset + t*my_g_len + u] + return x_unpacked + + def run_transducer_loss(self, scalar_t, fuse_softmax_backward, packed_input, for_vector_kernel): + self.gen_input(scalar_t, for_vector_kernel) + my_loss = TransducerLoss( fuse_softmax_backward=fuse_softmax_backward, + packed_input=packed_input) + if not packed_input: + loss_tst = my_loss( x=self.x_tst, + label=self.y, + f_len=self.f_len, + y_len=self.y_len, + blank_idx=self.blank_idx) + loss_tst.mean().backward() + grad_tst = self.x_tst.grad + else: + loss_tst = my_loss( x=self.x_tst_packed, + label=self.y, + f_len=self.f_len, + y_len=self.y_len, + blank_idx=self.blank_idx, + batch_offset=self.batch_offset, + max_f_len=max(self.f_len)) + loss_tst.mean().backward() + grad_tst_packed = self.x_tst_packed.grad + grad_tst = self._unpack(grad_tst_packed) + + return loss_tst, grad_tst + + def test_transducer_loss_fp32(self): + loss_tst, grad_tst = self.run_transducer_loss( scalar_t=torch.float32, + fuse_softmax_backward=False, + packed_input=False, + for_vector_kernel=False) + torch.testing.assert_close(self.loss_ref, loss_tst, atol=1e-5, rtol=1e-5) + torch.testing.assert_close(self.grad_ref, grad_tst, atol=1e-5, rtol=1e-5) + + def test_transducer_loss_fp16(self): + loss_tst, grad_tst = self.run_transducer_loss( scalar_t=torch.float16, + fuse_softmax_backward=False, + packed_input=False, + for_vector_kernel=False) + torch.testing.assert_close(self.loss_ref, loss_tst, atol=1e-5, rtol=1e-5) + torch.testing.assert_close(self.grad_ref, grad_tst, atol=1e-4, rtol=1e-3) + + def test_transducer_loss_fp16_backward_fusion(self): + loss_tst, grad_tst = self.run_transducer_loss( scalar_t=torch.float16, + fuse_softmax_backward=True, + packed_input=False, + for_vector_kernel=False) + torch.testing.assert_close(self.loss_ref, loss_tst, atol=1e-5, rtol=1e-5) + torch.testing.assert_close(self.grad_ref, grad_tst, atol=1e-4, rtol=1e-3) + + def test_transducer_loss_fp16_backward_fusion_packed(self): + loss_tst, grad_tst = self.run_transducer_loss( scalar_t=torch.float16, + fuse_softmax_backward=True, + packed_input=True, 
+ for_vector_kernel=False) + torch.testing.assert_close(self.loss_ref, loss_tst, atol=1e-5, rtol=1e-5) + torch.testing.assert_close(self.grad_ref, grad_tst, atol=1e-4, rtol=1e-3) + + def test_transducer_loss_fp16_backward_fusion_packed_vec(self): + loss_tst, grad_tst = self.run_transducer_loss( scalar_t=torch.float16, + fuse_softmax_backward=True, + packed_input=True, + for_vector_kernel=True) + torch.testing.assert_close(self.loss_ref, loss_tst, atol=1e-5, rtol=1e-5) + torch.testing.assert_close(self.grad_ref, grad_tst, atol=1e-4, rtol=1e-3) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/L0/run_amp/test_add_param_group.py b/tests/L0/run_amp/test_add_param_group.py index d3e90c433..6aa1cbbe9 100644 --- a/tests/L0/run_amp/test_add_param_group.py +++ b/tests/L0/run_amp/test_add_param_group.py @@ -139,8 +139,8 @@ def test_add_param_group(self): [param.data.clone() for param in model1.parameters()] for reference, final in zip(reference_params, final_params): - self.assertTrue(torch.allclose(reference.to(final.dtype), final), - "opt_level = {}, how_to_zero = {}, zero_before_add = {}".format( + torch.testing.assert_close(reference.to(final.dtype), final, + msg="opt_level = {}, how_to_zero = {}, zero_before_add = {}".format( opt_level, how_to_zero, zero_before_add)) diff --git a/tests/L0/run_amp/test_cache.py b/tests/L0/run_amp/test_cache.py index b58d2665f..53bd8802d 100644 --- a/tests/L0/run_amp/test_cache.py +++ b/tests/L0/run_amp/test_cache.py @@ -92,9 +92,9 @@ def training_step(): # Currently there's no difference in the allclose calls, so no need for branching, # but I'm keeping this in case we want different tolerances for fp16 and fp32 checks. if model.weight.grad.type() == "torch.cuda.HalfTensor": - self.assertTrue(torch.allclose(model.weight.grad.float(), reference_grad)) + torch.testing.assert_close(model.weight.grad.float(), reference_grad) elif model.weight.grad.type() == "torch.cuda.FloatTensor": - self.assertTrue(torch.allclose(model.weight.grad.float(), reference_grad)) + torch.testing.assert_close(model.weight.grad.float(), reference_grad) else: raise RuntimeError("model.weight.grad.type = {}".format(model.weight.grad.type())) diff --git a/tests/L0/run_amp/test_fused_sgd.py b/tests/L0/run_amp/test_fused_sgd.py index 7f592128d..3940271b1 100644 --- a/tests/L0/run_amp/test_fused_sgd.py +++ b/tests/L0/run_amp/test_fused_sgd.py @@ -168,8 +168,8 @@ def test_2models2losses1optimizer(self): if opt_level == "O2" and not materialize_master_grads: continue else: - self.assertTrue(torch.allclose(param.grad.float(), reference_grad.float()), - "opt_level {} i {} inject_inf {} which_backward {} inject_inf_loc {} use_multiple_loss_scalers {}".format(opt_level, i, inject_inf, which_backward, inject_inf_loc, use_multiple_loss_scalers)) + torch.testing.assert_close(param.grad.float(), reference_grad.float(), + msg="opt_level {} i {} inject_inf {} which_backward {} inject_inf_loc {} use_multiple_loss_scalers {}".format(opt_level, i, inject_inf, which_backward, inject_inf_loc, use_multiple_loss_scalers)) unskipped += 1 optimizer.step() @@ -178,8 +178,8 @@ def test_2models2losses1optimizer(self): model_params, amp.master_params(optimizer), final_params): - self.assertTrue(torch.allclose(model, reference)) - self.assertTrue(torch.allclose(model, master.to(model.dtype))) + torch.testing.assert_close(model, reference) + torch.testing.assert_close(model, master.to(model.dtype)) if opt_level == "O1": _amp_state.handle._deactivate() @@ -326,8 +326,8 @@ def 
test_3models2losses1optimizer(self): if opt_level == "O2" and not materialize_master_grads: continue else: - self.assertTrue(torch.allclose(param.grad.float(), reference_grad.float()), - "opt_level {} i {} inject_inf {} which_backward {} inject_inf_loc {} which_model {} use_multiple_loss_scalers {}".format(opt_level, i, inject_inf, which_backward, inject_inf_loc, which_model, use_multiple_loss_scalers)) + torch.testing.assert_close(param.grad.float(), reference_grad.float(), + msg="opt_level {} i {} inject_inf {} which_backward {} inject_inf_loc {} which_model {} use_multiple_loss_scalers {}".format(opt_level, i, inject_inf, which_backward, inject_inf_loc, which_model, use_multiple_loss_scalers)) unskipped += 1 optimizer.step() @@ -339,8 +339,8 @@ def test_3models2losses1optimizer(self): model_params, amp.master_params(optimizer), final_params): - self.assertTrue(torch.allclose(model, reference)) - self.assertTrue(torch.allclose(model, master.to(model.dtype))) + torch.testing.assert_close(model, reference) + torch.testing.assert_close(model, master.to(model.dtype)) if opt_level == "O1": _amp_state.handle._deactivate() @@ -521,7 +521,7 @@ def what_got_skipped(which_iter, which_backward): if opt_level == "O2" and not materialize_master_grads: continue else: - self.assertTrue(torch.allclose(param.grad.float(), reference_grad.float())) + torch.testing.assert_close(param.grad.float(), reference_grad.float()) unskipped += 1 optimizer0.step() @@ -534,8 +534,8 @@ def what_got_skipped(which_iter, which_backward): model_params, master_params, final_params[what_got_skipped(inject_inf, which_backward)]): - self.assertTrue(torch.allclose(model, reference)) - self.assertTrue(torch.allclose(model, master.to(model.dtype))) + torch.testing.assert_close(model, reference) + torch.testing.assert_close(model, master.to(model.dtype)) if opt_level == "O1": _amp_state.handle._deactivate() @@ -766,7 +766,7 @@ def what_got_skipped(which_iter, which_backward, which_model): if opt_level == "O2" and not materialize_master_grads: continue else: - self.assertTrue(torch.allclose(param.grad.float(), reference_grad.float())) + torch.testing.assert_close(param.grad.float(), reference_grad.float()) unskipped += 1 optimizer0.step() @@ -784,8 +784,8 @@ def what_got_skipped(which_iter, which_backward, which_model): model_params, master_params, final_params[what_got_skipped(inject_inf, which_backward, which_model)]): - self.assertTrue(torch.allclose(model, reference)) - self.assertTrue(torch.allclose(model, master.to(model.dtype))) + torch.testing.assert_close(model, reference) + torch.testing.assert_close(model, master.to(model.dtype)) if opt_level == "O1": _amp_state.handle._deactivate() diff --git a/tests/L0/run_amp/test_multi_tensor_l2norm.py b/tests/L0/run_amp/test_multi_tensor_l2norm.py index ed3cbd195..f92c4ab03 100644 --- a/tests/L0/run_amp/test_multi_tensor_l2norm.py +++ b/tests/L0/run_amp/test_multi_tensor_l2norm.py @@ -50,9 +50,9 @@ def l2norm(self, sizea, sizeb, applier, repeat_tensors, in_type, per_tensor): reference = torch.cuda.FloatTensor((sizea + sizeb)*repeat_tensors).fill_(self.val).norm() - self.assertTrue(torch.allclose(norm, reference)) + torch.testing.assert_close(norm, reference.broadcast_to(norm.shape)) if per_tensor: - self.assertTrue(torch.allclose(norm_per_tensor, normab)) + torch.testing.assert_close(norm_per_tensor, normab.broadcast_to(norm_per_tensor.shape)) self.assertTrue(self.overflow_buf.item() == 0) @unittest.skipIf(disabled, "amp_C is unavailable") diff --git 
a/tests/L0/run_amp/test_multi_tensor_unscale_l2norm.py b/tests/L0/run_amp/test_multi_tensor_unscale_l2norm.py index 3b32af7fc..4a6c200d2 100644 --- a/tests/L0/run_amp/test_multi_tensor_unscale_l2norm.py +++ b/tests/L0/run_amp/test_multi_tensor_unscale_l2norm.py @@ -52,9 +52,9 @@ def unscale_l2norm(self, sizea, sizeb, applier, repeat_tensors, in_type, per_ten reference = torch.full([(sizea + sizeb)*repeat_tensors], self.val * self.inv_scale, dtype=torch.float32, device='cuda').norm() - self.assertTrue(torch.allclose(norm, reference)) + torch.testing.assert_close(norm, reference) if per_tensor: - self.assertTrue(torch.allclose(norm_per_tensor, normab)) + torch.testing.assert_close(norm_per_tensor, normab) self.assertTrue(self.overflow_buf.item() == 0) @unittest.skipIf(disabled, "amp_C is unavailable") diff --git a/tests/L0/run_amp/test_multiple_models_optimizers_losses.py b/tests/L0/run_amp/test_multiple_models_optimizers_losses.py index 068c84537..e8a767f9a 100644 --- a/tests/L0/run_amp/test_multiple_models_optimizers_losses.py +++ b/tests/L0/run_amp/test_multiple_models_optimizers_losses.py @@ -152,7 +152,7 @@ def test_2models2losses1optimizer(self): if i != inject_inf: for param, reference_grad in zip(amp.master_params(optimizer), reference_grads[unskipped]): - self.assertTrue(torch.allclose(param.grad.float(), reference_grad.float())) + torch.testing.assert_close(param.grad.float(), reference_grad.float()) unskipped += 1 optimizer.step() @@ -161,8 +161,8 @@ def test_2models2losses1optimizer(self): model_params, amp.master_params(optimizer), final_params): - self.assertTrue(torch.allclose(model, reference)) - self.assertTrue(torch.allclose(model, master.to(model.dtype))) + torch.testing.assert_close(model, reference) + torch.testing.assert_close(model, master.to(model.dtype)) if opt_level == "O1": _amp_state.handle._deactivate() @@ -305,7 +305,7 @@ def test_3models2losses1optimizer(self): if i != inject_inf: for param, reference_grad in zip(amp.master_params(optimizer), reference_grads[unskipped]): - self.assertTrue(torch.allclose(param.grad.float(), reference_grad.float())) + torch.testing.assert_close(param.grad.float(), reference_grad.float()) unskipped += 1 optimizer.step() @@ -317,8 +317,8 @@ def test_3models2losses1optimizer(self): model_params, amp.master_params(optimizer), final_params): - self.assertTrue(torch.allclose(model, reference)) - self.assertTrue(torch.allclose(model, master.to(model.dtype))) + torch.testing.assert_close(model, reference) + torch.testing.assert_close(model, master.to(model.dtype)) if opt_level == "O1": _amp_state.handle._deactivate() @@ -494,7 +494,7 @@ def what_got_skipped(which_iter, which_backward): list(amp.master_params(optimizer1)) for param, reference_grad in zip(master_params, reference_grads[what_got_skipped(inject_inf, which_backward)][unskipped]): - self.assertTrue(torch.allclose(param.grad.float(), reference_grad.float())) + torch.testing.assert_close(param.grad.float(), reference_grad.float()) unskipped += 1 optimizer0.step() @@ -507,8 +507,8 @@ def what_got_skipped(which_iter, which_backward): model_params, master_params, final_params[what_got_skipped(inject_inf, which_backward)]): - self.assertTrue(torch.allclose(model, reference)) - self.assertTrue(torch.allclose(model, master.to(model.dtype))) + torch.testing.assert_close(model, reference) + torch.testing.assert_close(model, master.to(model.dtype)) if opt_level == "O1": _amp_state.handle._deactivate() @@ -734,7 +734,7 @@ def what_got_skipped(which_iter, which_backward, which_model): for 
param, reference_grad in zip(master_params, reference_grads[what_got_skipped(inject_inf, which_backward, which_model)][unskipped]): - self.assertTrue(torch.allclose(param.grad.float(), reference_grad.float())) + torch.testing.assert_close(param.grad.float(), reference_grad.float()) unskipped += 1 optimizer0.step() @@ -752,8 +752,8 @@ def what_got_skipped(which_iter, which_backward, which_model): model_params, master_params, final_params[what_got_skipped(inject_inf, which_backward, which_model)]): - self.assertTrue(torch.allclose(model, reference)) - self.assertTrue(torch.allclose(model, master.to(model.dtype))) + torch.testing.assert_close(model, reference) + torch.testing.assert_close(model, master.to(model.dtype)) if opt_level == "O1": _amp_state.handle._deactivate()
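Note on the conversion pattern above: torch.testing.assert_close raises a descriptive AssertionError (reporting the largest absolute and relative mismatches) instead of returning a bool like torch.allclose, it requires rtol and atol to be passed as a pair whenever either is given, it does not broadcast its inputs (hence the broadcast_to calls in the multi_tensor_l2norm test), and it accepts a msg argument where the old assertTrue calls passed a failure message. A minimal CPU-only sketch of these differences; the tensor values, tolerances, and the msg string below are made up for illustration:

import torch

ref = torch.tensor([1.0000, 2.0000, 3.0000])
tst = torch.tensor([1.0004, 2.0004, 3.0004])

# torch.allclose returns a bare bool (defaults rtol=1e-5, atol=1e-8), so a failing
# assertTrue(...) says nothing about which elements differ or by how much.
assert torch.allclose(ref, tst, atol=1e-3)

# torch.testing.assert_close raises on mismatch with a detailed report; rtol and atol
# must be specified together (or both left to their dtype-based defaults).
torch.testing.assert_close(ref, tst, atol=1e-3, rtol=1e-5)

# Unlike allclose, assert_close does not broadcast: shapes must match exactly,
# so a 0-dim reference has to be expanded to the shape of the result it checks.
vec = torch.tensor([3.0])
torch.testing.assert_close(vec, torch.tensor(3.0).broadcast_to(vec.shape))

# msg= attaches the same context the old assertTrue messages carried.
torch.testing.assert_close(ref, tst, atol=1e-3, rtol=1e-5,
                           msg="opt_level = O1, how_to_zero = grad")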