cliprnex.py
# -*- coding: utf-8 -*-
"""CLIP GradCAM Visualization
Automatically generated by Colaboratory.
Original file is located at
https://colab.research.google.com/github/kevinzakka/clip_playground/blob/main/CLIP_GradCAM_Visualization.ipynb
# CLIP GradCAM Colab
This Colab notebook uses [GradCAM](https://arxiv.org/abs/1610.02391) on OpenAI's [CLIP](https://openai.com/blog/clip/) model to produce a heatmap highlighting which regions in an image activate the most to a given caption.
**Note:** Currently only works with the ResNet variants of CLIP. ViT support coming soon.
"""
import argparse
import os
import warnings

import numpy as np
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
import clip
from PIL import Image
from scipy.ndimage import filters
from torch import nn

warnings.filterwarnings('ignore')
heatmap_folder = 'clipapp'
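# Added convenience (not part of the original logic): make sure the output
# folders exist, since the heatmap is written to clipapp/ and the binary mask
# to clipapp/tmp/ further below.
os.makedirs(os.path.join(heatmap_folder, 'tmp'), exist_ok=True)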
# Parse command line arguments
parser = argparse.ArgumentParser(description='Process an image and its corresponding token.')
parser.add_argument('img_name', type=str, help='The name of the image file (with extension)')
parser.add_argument('token_path', type=str, help='The path to the text token file')
parser.add_argument("clipmodel", type=str, default='ViT-B/32', help="CLIP model to use")
parser.add_argument('--roi_x', type=int, default=0, help='X coordinate of the ROI')
parser.add_argument('--roi_y', type=int, default=0, help='Y coordinate of the ROI')
parser.add_argument('--roi_width', type=int, default=100, help='Width of the ROI')
parser.add_argument('--roi_height', type=int, default=100, help='Height of the ROI')
args = parser.parse_args()
# Use the arguments
img_file = args.img_name      # image file name from the command line
token_file = args.token_path  # token file path from the command line
clipmodel = args.clipmodel

# Read the token file once; keep both the whitespace-split tokens and the
# stripped full string (the latter is used in the output file names below).
with open(token_file, 'r') as f:
    token_text = f.read()
tokens = token_text.split()
token = token_text.strip()
print(token)
def normalize(x: np.ndarray) -> np.ndarray:
# Normalize to [0, 1].
x = x - x.min()
if x.max() > 0:
x = x / x.max()
return x
# Modified from: https://github.com/salesforce/ALBEF/blob/main/visualization.ipynb
def getAttMap(img, attn_map, blur=True):
if blur:
attn_map = filters.gaussian_filter(attn_map, 0.02*max(img.shape[:2]))
attn_map = normalize(attn_map)
cmap = plt.get_cmap('jet')
attn_map_c = np.delete(cmap(attn_map), 3, 2)
attn_map = 1*(1-attn_map**0.7).reshape(attn_map.shape + (1,))*img + \
(attn_map**0.7).reshape(attn_map.shape+(1,)) * attn_map_c
return attn_map
def viz_attn(img, attn_map, blur=True):
_, axes = plt.subplots(1, 2, figsize=(10, 5))
axes[0].imshow(img)
axes[1].imshow(getAttMap(img, attn_map, blur))
for ax in axes:
ax.axis("off")
#plt.show()
image_path = args.img_name
img = Image.open(image_path)
#img = img.convert('RGB')
def load_image(pil_img, resize=None):
    """Convert a PIL image to an RGB float array in [0, 1], optionally resized to (resize, resize)."""
    image = pil_img.convert("RGB")
    if resize is not None:
        image = image.resize((resize, resize))
    return np.asarray(image).astype(np.float32) / 255.
class Hook:
"""Attaches to a module and records its activations and gradients."""
def __init__(self, module: nn.Module):
self.data = None
self.hook = module.register_forward_hook(self.save_grad)
def save_grad(self, module, input, output):
self.data = output
output.requires_grad_(True)
output.retain_grad()
def __enter__(self):
return self
def __exit__(self, exc_type, exc_value, exc_traceback):
self.hook.remove()
@property
def activation(self) -> torch.Tensor:
return self.data
@property
def gradient(self) -> torch.Tensor:
return self.data.grad
def gradCAM(
model: nn.Module,
input: torch.Tensor,
target: torch.Tensor,
layer: nn.Module
) -> torch.Tensor:
# Zero out any gradients at the input.
if input.grad is not None:
input.grad.data.zero_()
# Disable gradient settings.
requires_grad = {}
for name, param in model.named_parameters():
requires_grad[name] = param.requires_grad
param.requires_grad_(False)
# Attach a hook to the model at the desired layer.
assert isinstance(layer, nn.Module)
with Hook(layer) as hook:
# Do a forward and backward pass.
output = model(input)
output.backward(target)
grad = hook.gradient.float()
act = hook.activation.float()
# Global average pool gradient across spatial dimension
# to obtain importance weights.
alpha = grad.mean(dim=(2, 3), keepdim=True)
# Weighted combination of activation maps over channel
# dimension.
gradcam = torch.sum(act * alpha, dim=1, keepdim=True)
# We only want neurons with positive influence so we
# clamp any negative ones.
gradcam = torch.clamp(gradcam, min=0)
# Resize gradcam to input resolution.
gradcam = F.interpolate(
gradcam,
input.shape[2:],
mode='bicubic',
align_corners=False)
# Restore gradient settings.
for name, param in model.named_parameters():
param.requires_grad_(requires_grad[name])
return gradcam
#@title Run
#@markdown #### Image & Caption settings
image_caption = token  #@param {type:"string"}
print(image_caption)
#@markdown ---
#@markdown #### CLIP model settings
clip_model = args.clipmodel
saliency_layer = "layer4" #@param ["layer4", "layer3", "layer2", "layer1"]
#@markdown ---
#@markdown #### Visualization settings
blur = True #@param {type:"boolean"}
device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load(clip_model, device=device, jit=False)
# Preprocess the locally loaded image for CLIP.
image_input = preprocess(img).unsqueeze(0).to(device)
image_np = load_image(img, model.visual.input_resolution)
text_input = clip.tokenize([image_caption]).to(device)
attn_map = gradCAM(
model.visual,
image_input,
model.encode_text(text_input).float(),
getattr(model.visual, saliency_layer)
)
attn_map = attn_map.squeeze().detach().cpu().numpy()
viz_attn(image_np, attn_map, blur)
def save_attn_map(img, attn_map, blur=True, file_path='heatmap.png'):
# Generate the attention map visualization
attn_visualization = getAttMap(img, attn_map, blur)
# Convert the visualization to an image (PIL) for resizing
attn_visualization_image = Image.fromarray((attn_visualization * 255).astype(np.uint8))
# Resize the image to 224x224 pixels
attn_visualization_resized = attn_visualization_image.resize((224, 224), Image.LANCZOS)
# Save the resized visualization to a file
attn_visualization_resized.save(file_path)
# Assuming attn_map is your GradCAM output and image_np is the original image data
save_attn_map(image_np, attn_map, blur, file_path=f'{heatmap_folder}/{os.path.splitext(os.path.basename(img_file))[0]}_{token}.png')
def normalize_and_threshold(attn_map, threshold=0.4):
    # Ensure attn_map is in float format for processing
    attn_map = attn_map.astype(np.float32)
    # Normalize the attention map to [0, 1]; guard against a constant map.
    span = attn_map.max() - attn_map.min()
    if span > 0:
        attn_map_norm = (attn_map - attn_map.min()) / span
    else:
        attn_map_norm = np.zeros_like(attn_map)
    # Apply threshold to create a binary mask
    binary_mask = attn_map_norm >= threshold
    return binary_mask
# Assuming attn_map is your GradCAM attention map
binary_mask = normalize_and_threshold(attn_map, threshold=0.4)
# Convert the binary mask to an image format (PIL Image) for resizing
mask_image = Image.fromarray(np.uint8(binary_mask * 255), 'L')
# Resize the binary mask image to 224x224 pixels
mask_image_resized = mask_image.resize((224, 224), Image.LANCZOS)
# Save the resized binary mask image
binary_mask_filename = f"{heatmap_folder}/tmp/binary_mask_{os.path.splitext(os.path.basename(args.img_name))[0]}.png"
mask_image_resized.save(binary_mask_filename)