Skip to content

Commit

Permalink
turn python values into scalar tensors
Browse files Browse the repository at this point in the history
  • Loading branch information
christian-rauch committed Oct 11, 2021
1 parent a09fffc commit c78cf64
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 3 deletions.
4 changes: 2 additions & 2 deletions core/raft.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def upsample_flow(self, flow, mask):
return up_flow.reshape(N, 2, 8*H, 8*W)


def forward(self, image1, image2, iters=12, flow_init=None, upsample=True, test_mode=False):
def forward(self, image1, image2, iters=torch.tensor(12), flow_init=torch.tensor([]), upsample=torch.tensor(True), test_mode=torch.tensor(False)):
""" Estimate optical flow between pair of frames """

image1 = 2 * (image1 / 255.0) - 1.0
Expand Down Expand Up @@ -115,7 +115,7 @@ def forward(self, image1, image2, iters=12, flow_init=None, upsample=True, test_

coords0, coords1 = self.initialize_flow(image1)

if flow_init is not None:
if flow_init is not None and flow_init.numel()>0:
coords1 = coords1 + flow_init

flow_predictions = []
Expand Down
2 changes: 1 addition & 1 deletion demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def demo(args):
padder = InputPadder(image1.shape)
image1, image2 = padder.pad(image1, image2)

flow_low, flow_up = model(image1, image2, iters=20, test_mode=True)
flow_low, flow_up = model(image1, image2, iters=torch.tensor(20), test_mode=torch.tensor(True))
viz(image1, flow_up)


Expand Down

0 comments on commit c78cf64

Please sign in to comment.