diff --git a/diffusers/scripts/run_dino_drc.py b/diffusers/scripts/run_dino_drc.py index 984806b..80f2957 100644 --- a/diffusers/scripts/run_dino_drc.py +++ b/diffusers/scripts/run_dino_drc.py @@ -181,20 +181,6 @@ def display_instances(image, mask, fname="test", figsize=(5, 5), blur=False, con # we keep only the output patch attention attentions = attentions[0, :, 0, 1:].reshape(nh, -1) - - if args.threshold is not None: - # we keep only a certain percentage of the mass - val, idx = torch.sort(attentions) - val /= torch.sum(val, dim=1, keepdim=True) - cumval = torch.cumsum(val, dim=1) - th_attn = cumval > (1 - args.threshold) - idx2 = torch.argsort(idx) - for head in range(nh): - th_attn[head] = th_attn[head][idx2[head]] - th_attn = th_attn.reshape(nh, w_featmap, h_featmap).float() - # interpolate - th_attn = nn.functional.interpolate(th_attn.unsqueeze(0), scale_factor=args.patch_size, mode="nearest")[0].cpu().numpy() - attentions = attentions.reshape(nh, w_featmap, h_featmap) mask = nn.functional.interpolate(attentions.unsqueeze(0), size=(args.mask_size), mode="nearest")[0].cpu().numpy() mask = mask[-1]