diff --git a/utils/drag_utils.py b/utils/drag_utils.py index 90ca70c..342e644 100755 --- a/utils/drag_utils.py +++ b/utils/drag_utils.py @@ -104,6 +104,7 @@ def drag_diffusion_update(model, # prepare for point tracking and background regularization handle_points_init = copy.deepcopy(handle_points) interp_mask = F.interpolate(mask, (init_code.shape[2],init_code.shape[3]), mode='nearest') + using_mask = interp_mask.sum() != 0.0 # prepare amp scaler for mixed-precision training scaler = torch.cuda.amp.GradScaler() @@ -145,7 +146,8 @@ def drag_diffusion_update(model, loss += ((2*args.r_m+1)**2)*F.l1_loss(f0_patch, f1_patch) # masked region must stay unchanged - loss += args.lam * ((x_prev_updated-x_prev_0)*(1.0-interp_mask)).abs().sum() + if using_mask: + loss += args.lam * ((x_prev_updated-x_prev_0)*(1.0-interp_mask)).abs().sum() # loss += args.lam * ((init_code_orig-init_code)*(1.0-interp_mask)).abs().sum() print('loss total=%f'%(loss.item())) @@ -210,6 +212,7 @@ def drag_diffusion_update_gen(model, # prepare for point tracking and background regularization handle_points_init = copy.deepcopy(handle_points) interp_mask = F.interpolate(mask, (init_code.shape[2],init_code.shape[3]), mode='nearest') + using_mask = interp_mask.sum() != 0.0 # prepare amp scaler for mixed-precision training scaler = torch.cuda.amp.GradScaler() @@ -267,7 +270,8 @@ def drag_diffusion_update_gen(model, loss += ((2*args.r_m+1)**2)*F.l1_loss(f0_patch, f1_patch) # masked region must stay unchanged - loss += args.lam * ((x_prev_updated-x_prev_0)*(1.0-interp_mask)).abs().sum() + if using_mask: + loss += args.lam * ((x_prev_updated-x_prev_0)*(1.0-interp_mask)).abs().sum() # loss += args.lam * ((init_code_orig - init_code)*(1.0-interp_mask)).abs().sum() print('loss total=%f'%(loss.item()))