Skip to content

Commit

Permalink
Call contiguous() after permute() in pytorch model
Browse files Browse the repository at this point in the history
This avoids an error in some versions (e.g., 1.4) of PyTorch, and also brings sizable speed improvement.
  • Loading branch information
hqucms authored Apr 12, 2020
1 parent 2f31857 commit 98b744a
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion pytorch/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def get_graph_feature(x, k=20, idx=None):
feature = feature.view(batch_size, num_points, k, num_dims)
x = x.view(batch_size, num_points, 1, num_dims).repeat(1, 1, k, 1)

feature = torch.cat((feature-x, x), dim=3).permute(0, 3, 1, 2)
feature = torch.cat((feature-x, x), dim=3).permute(0, 3, 1, 2).contiguous()

return feature

Expand Down

0 comments on commit 98b744a

Please sign in to comment.