Skip to content

Commit

Permalink
fix run_drc
Browse files (browse the repository at this point in the history)
  • Loading branch information
caradryanl committed May 14, 2024
1 parent: e200c8c; commit: b4fa5ad
Show file tree
Hide file tree
Showing 3 changed files with 497 additions and 17 deletions.
5 changes: 2 additions & 3 deletions diffusers/scripts/exp_run_dino_drc.sh
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,4 @@
# python scripts/run_dino_drc.py --image_path datasets/laion-aesthetic-2-5k/images/ --output_dir datasets/laion-aesthetic-2-5k/masks/
# python scripts/run_dino_drc.py --image_path datasets/coco2017-val-2-5k/images/ --output_dir datasets/coco2017-val-2-5k/masks/

python scripts/run_dino_drc.py --mask_size 512 512 --image_path datasets/laion-aesthetic-2-5k/images/ --output_dir datasets/laion-aesthetic-2-5k/masks/
python scripts/run_dino_drc.py --mask_size 512 512 --image_path datasets/coco2017-val-2-5k/images/ --output_dir datasets/coco2017-val-2-5k/masks/
python scripts/run_dino_drc.py --mask_size 256 256 --image_path datasets/ffhq-2-5k/images/ --output_dir datasets/ffhq-2-5k/masks/
python scripts/run_dino_drc.py --mask_size 256 256 --image_path datasets/celeba-hq-2-5k/images/ --output_dir datasets/celeba-hq-2-5k/masks/
27 changes: 13 additions & 14 deletions diffusers/scripts/run_dino_drc.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -106,11 +106,11 @@ def display_instances(image, mask, fname="test", figsize=(5, 5), blur=False, con
parser.add_argument("--checkpoint_key", default="teacher", type=str,
help='Key to use in the checkpoint (example: "teacher")')
# parser.add_argument("--image_path", default='datasets/coco2017-val-2-5k/images/', type=str, help="Path of the image to load.")
parser.add_argument("--image_path", default='datasets/laion-aesthetic-2-5k/images/', type=str, help="Path of the image to load.")
parser.add_argument("--image_path", default='datasets/drc_mask_examples/images', type=str, help="Path of the image to load.")
parser.add_argument("--image_size", default=(480, 480), type=int, nargs="+", help="Resize image.")
parser.add_argument("--mask_size", default=(512, 512), type=int, nargs="+", help="Resize image.")
parser.add_argument("--mask_size", default=(256, 256), type=int, nargs="+", help="Resize image.")
# parser.add_argument('--output_dir', default='datasets/coco2017-val-2-5k/masks/', help='Path where to save visualizations.')
parser.add_argument('--output_dir', default='datasets/laion-aesthetic-2-5k/masks/', help='Path where to save visualizations.')
parser.add_argument('--output_dir', default='datasets/drc_mask_examples/masks', help='Path where to save visualizations.')
parser.add_argument("--threshold", type=float, default=80, help="""We visualize masks
obtained by thresholding the self-attention maps to keep xx% of the mass.""")
args = parser.parse_args()
Expand Down Expand Up @@ -197,31 +197,30 @@ def display_instances(image, mask, fname="test", figsize=(5, 5), blur=False, con

attentions = attentions.reshape(nh, w_featmap, h_featmap)
mask = nn.functional.interpolate(attentions.unsqueeze(0), size=(args.mask_size), mode="nearest")[0].cpu().numpy()
mask = sum(
mask[i] * 1 / mask.shape[0]
for i in range(mask.shape[0])
)
mask = mask[-1]
threshold = np.percentile(mask, args.threshold)
mask = np.where(mask >= threshold, 1, 0)
# print(mask.shape)
print(mask.size, mask.sum())
attentions = nn.functional.interpolate(attentions.unsqueeze(0), scale_factor=args.patch_size, mode="nearest")[0].cpu().numpy()

# save attentions heatmaps
os.makedirs(args.output_dir, exist_ok=True)
# torchvision.utils.save_image(torchvision.utils.make_grid(img, normalize=True, scale_each=True), os.path.join(args.output_dir, "img.png"))
# for j in range(nh):
# fname = os.path.join(args.output_dir, "attn-head" + str(j) + ".png")
# plt.imsave(fname=fname, arr=attentions[j], format='png')
# print(f"{fname} saved.")
# if j == nh - 1:
# fname = os.path.join(args.output_dir, "attn-head" + str(j) + img_name[:-4] + ".png")
# plt.imsave(fname=fname, arr=attentions[j], format='png')
# print(f"{fname} saved.")

fname = os.path.join(args.output_dir, img_name[:-4]+'.npy')
np.save(fname, mask)
# plt.imsave(
# fname=fname,
# fname=fname[:-4]+'.jpg',
# arr=mask,
# cmap="plasma",
# cmap='Greys',
# format="jpg",
# )
np.save(fname, mask)


# if args.threshold is not None:
# image = skimage.io.imread(os.path.join(args.output_dir, "img.png"))
Expand Down
Loading

0 comments on commit b4fa5ad

Please sign in to comment.