diff --git a/test/test.py b/test/test.py index ec24b7d..52ffeed 100644 --- a/test/test.py +++ b/test/test.py @@ -74,11 +74,12 @@ def main(): inputs = inputs.cuda(gpu_id, non_blocking=True) # inference - out_x = model(inputs) - inputs_t = torch.transpose(inputs, 2, 3) - out_y_t = model(inputs_t) - out_y = torch.transpose(out_y_t, 2, 3) - outputs = torch.cat((out_x, out_y), 1) + with torch.no_grad(): + out_x = model(inputs) + inputs_t = torch.transpose(inputs, 2, 3) + out_y_t = model(inputs_t) + out_y = torch.transpose(out_y_t, 2, 3) + outputs = torch.cat((out_x, out_y), 1) # compute superpixels affinity = outputs[0].data.cpu().numpy()