line 77: add torch.no_grad()
Add torch.no_grad() to reduce memory consumption during inference
wctu authored Apr 4, 2019
1 parent 17584dc commit 66317a9
Showing 1 changed file with 6 additions and 5 deletions.
11 changes: 6 additions & 5 deletions test/test.py
@@ -74,11 +74,12 @@ def main():
         inputs = inputs.cuda(gpu_id, non_blocking=True)

         # inference
-        out_x = model(inputs)
-        inputs_t = torch.transpose(inputs, 2, 3)
-        out_y_t = model(inputs_t)
-        out_y = torch.transpose(out_y_t, 2, 3)
-        outputs = torch.cat((out_x, out_y), 1)
+        with torch.no_grad():
+            out_x = model(inputs)
+            inputs_t = torch.transpose(inputs, 2, 3)
+            out_y_t = model(inputs_t)
+            out_y = torch.transpose(out_y_t, 2, 3)
+            outputs = torch.cat((out_x, out_y), 1)

         # compute superpixels
         affinity = outputs[0].data.cpu().numpy()
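
For context, a minimal standalone sketch of the same idea (not the repository's actual test script; the model and input shapes below are placeholders): wrapping the forward passes in torch.no_grad() stops autograd from recording the computation graph, so activations kept only for backpropagation are freed immediately and inference memory drops.

    # Minimal sketch, assuming a placeholder model and input batch.
    import torch
    import torch.nn as nn

    model = nn.Conv2d(3, 8, kernel_size=3, padding=1)  # placeholder network
    model.eval()                                       # inference mode for dropout/batchnorm

    inputs = torch.randn(1, 3, 64, 64)                 # placeholder input batch

    with torch.no_grad():
        # same pattern as the diff: run the original and transposed input,
        # transpose the second output back, and concatenate along channels
        out_x = model(inputs)
        out_y = torch.transpose(model(torch.transpose(inputs, 2, 3)), 2, 3)
        outputs = torch.cat((out_x, out_y), 1)

    print(outputs.shape, outputs.requires_grad)        # requires_grad is False

Note that model.eval() alone does not disable gradient tracking; torch.no_grad() is what prevents the graph from being built. In more recent PyTorch releases, torch.inference_mode() serves the same purpose with slightly lower overhead.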