Skip to content

Commit

Permalink
no mask for default case
Browse files Browse the repository at this point in the history
  • Loading branch information
Yujun-Shi committed Dec 11, 2023
1 parent 1714ba1 commit e86e49a
Showing 1 changed file with 6 additions and 2 deletions.
8 changes: 6 additions & 2 deletions utils/drag_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ def drag_diffusion_update(model,
# prepare for point tracking and background regularization
handle_points_init = copy.deepcopy(handle_points)
interp_mask = F.interpolate(mask, (init_code.shape[2],init_code.shape[3]), mode='nearest')
using_mask = interp_mask.sum() != 0.0

# prepare amp scaler for mixed-precision training
scaler = torch.cuda.amp.GradScaler()
Expand Down Expand Up @@ -145,7 +146,8 @@ def drag_diffusion_update(model,
loss += ((2*args.r_m+1)**2)*F.l1_loss(f0_patch, f1_patch)

# masked region must stay unchanged
loss += args.lam * ((x_prev_updated-x_prev_0)*(1.0-interp_mask)).abs().sum()
if using_mask:
loss += args.lam * ((x_prev_updated-x_prev_0)*(1.0-interp_mask)).abs().sum()
# loss += args.lam * ((init_code_orig-init_code)*(1.0-interp_mask)).abs().sum()
print('loss total=%f'%(loss.item()))

Expand Down Expand Up @@ -210,6 +212,7 @@ def drag_diffusion_update_gen(model,
# prepare for point tracking and background regularization
handle_points_init = copy.deepcopy(handle_points)
interp_mask = F.interpolate(mask, (init_code.shape[2],init_code.shape[3]), mode='nearest')
using_mask = interp_mask.sum() != 0.0

# prepare amp scaler for mixed-precision training
scaler = torch.cuda.amp.GradScaler()
Expand Down Expand Up @@ -267,7 +270,8 @@ def drag_diffusion_update_gen(model,
loss += ((2*args.r_m+1)**2)*F.l1_loss(f0_patch, f1_patch)

# masked region must stay unchanged
loss += args.lam * ((x_prev_updated-x_prev_0)*(1.0-interp_mask)).abs().sum()
if using_mask:
loss += args.lam * ((x_prev_updated-x_prev_0)*(1.0-interp_mask)).abs().sum()
# loss += args.lam * ((init_code_orig - init_code)*(1.0-interp_mask)).abs().sum()
print('loss total=%f'%(loss.item()))

Expand Down

0 comments on commit e86e49a

Please sign in to comment.