Skip to content

Commit

Permalink
[contrib] Use torch.testing.assert_close in test_index_mul_2d.py (#…
Browse files Browse the repository at this point in the history
…1693)

Signed-off-by: Masaki Kozuki <[email protected]>
  • Loading branch information
crcrpar authored Jul 19, 2023
1 parent 7b2e71b commit f03c6fb
Showing 1 changed file with 8 additions and 8 deletions.
16 changes: 8 additions & 8 deletions apex/contrib/test/index_mul_2d/test_index_mul_2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,10 +67,10 @@ def test_index_mul_float(self):
loss = (out_.float()**2).sum() / out_.numel() + (force_.float()**2).sum()
loss.backward()

self.assertTrue(torch.allclose(self.input1_float, self.input1_float_, atol=1e-3, rtol=1e-3, equal_nan=True))
self.assertTrue(torch.allclose(self.input2_float, self.input2_float_, atol=1e-3, rtol=1e-3, equal_nan=True))
self.assertTrue(torch.allclose(self.input1_float.grad, self.input1_float_.grad, atol=1e-3, rtol=1e-3, equal_nan=True))
self.assertTrue(torch.allclose(self.input2_float.grad, self.input2_float_.grad, atol=1e-3, rtol=1e-3, equal_nan=True))
torch.testing.assert_close(self.input1_float, self.input1_float_, atol=1e-3, rtol=1e-3, equal_nan=True)
torch.testing.assert_close(self.input2_float, self.input2_float_, atol=1e-3, rtol=1e-3, equal_nan=True)
torch.testing.assert_close(self.input1_float.grad, self.input1_float_.grad, atol=1e-3, rtol=1e-3, equal_nan=True)
torch.testing.assert_close(self.input2_float.grad, self.input2_float_.grad, atol=1e-3, rtol=1e-3, equal_nan=True)

def test_index_mul_half(self):
out = index_mul_2d(self.input1_half, self.input2_half, self.index1)
Expand All @@ -95,10 +95,10 @@ def test_index_mul_half(self):
loss = (out_.float()**2).sum() / out_.numel() + (force_.float()**2).sum()
loss.backward()

self.assertTrue(torch.allclose(self.input1_half, self.input1_half_, atol=1e-3, rtol=1e-3, equal_nan=True))
self.assertTrue(torch.allclose(self.input2_half, self.input2_half_, atol=1e-3, rtol=1e-3, equal_nan=True))
self.assertTrue(torch.allclose(self.input1_half.grad, self.input1_half_.grad, atol=1e-3, rtol=1e-3, equal_nan=True))
self.assertTrue(torch.allclose(self.input2_half.grad, self.input2_half_.grad, atol=1e-3, rtol=1e-3, equal_nan=True))
torch.testing.assert_close(self.input1_half, self.input1_half_, atol=1e-3, rtol=1e-3, equal_nan=True)
torch.testing.assert_close(self.input2_half, self.input2_half_, atol=1e-3, rtol=1e-3, equal_nan=True)
torch.testing.assert_close(self.input1_half.grad, self.input1_half_.grad, atol=1e-3, rtol=1e-3, equal_nan=True)
torch.testing.assert_close(self.input2_half.grad, self.input2_half_.grad, atol=1e-3, rtol=1e-3, equal_nan=True)

if __name__ == '__main__':
unittest.main()

0 comments on commit f03c6fb

Please sign in to comment.