Skip to content

Commit

Permalink
unittest.main
Browse files Browse the repository at this point in the history
Signed-off-by: Masaki Kozuki <[email protected]>
  • Loading branch information
crcrpar committed Apr 19, 2023
1 parent 768406d commit ecca2f7
Showing 1 changed file with 7 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import apex
from apex.normalization import InstanceNorm3dNVFuser


class TestInstanceNormNVFuser(unittest.TestCase):
dtype = torch.float
track_running_stats = False
Expand All @@ -21,7 +22,7 @@ def init_modules(self):
self.reference_m = torch.nn.InstanceNorm3d(self.channel_size, affine=self.affine, track_running_stats=self.track_running_stats, device='cuda', dtype=self.dtype)

def check_same_output(self):
torch.manual_seed(42)
torch.manual_seed(42)
for i in range(2): # exercise JIT + caching
inp = torch.randint(0, 2, (self.batch_size, self.channel_size, self.spatial_size, self.spatial_size, self.spatial_size), device='cuda', requires_grad=True, dtype=self.dtype)
inp2 = inp.detach().clone()
Expand Down Expand Up @@ -78,7 +79,7 @@ def test_sweep(self):
self.channels_last = channels_last
self.affine = affine
self.init_modules()
self.check_same_output()
self.check_same_output()

@unittest.skipIf(torch.cuda.device_count() < 2, "more than 1 GPU required")
def test_multigpu(self):
Expand All @@ -100,3 +101,7 @@ def forward(self, x):
pred = model(x)
loss = nn.functional.mse_loss(pred, y.float())
loss.backward()


if __name__ == "__main__":
unittest.main()

0 comments on commit ecca2f7

Please sign in to comment.