-
Notifications
You must be signed in to change notification settings - Fork 1
/
run_visualization.py
210 lines (166 loc) · 9.19 KB
/
run_visualization.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
import clip
import torch
from torch import nn as nn
from torch.nn import functional as F
import torch.optim as optim
import torchvision
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
import torchvision.utils
import numpy as np
import random
import pdb
import collections
from datetime import datetime
import datetime
import time
from typing import Any
import argparse
from argparse import Namespace
import os
import sys
from clip.model import QuickGELU
import warnings
warnings.filterwarnings("ignore")
# Custom imports
from image_net import TotalVariation, CrossEntropyLoss, MatchBatchNorm, BaseFakeBN, LayerActivationNorm
from image_net import ActivationNorm, NormalVariation, ColorVariation, fix_random_seed
from image_net import NetworkPass
from image_net import LossArray, TotalVariation
from image_net import ViTFeatHook, ViTEnsFeatHook
from regularizers import TotalVariation as BaseTotalVariation, FakeColorDistribution as AbstractColorDistribution
from regularizers import FakeBatchNorm as BaseFakeBN, NormalVariation as BaseNormalVariation
from regularizers import ColorVariation as BaseColorVariation
from hooks import ViTAttHookHolder, ViTGeLUHook, ClipGeLUHook, SpecialSaliencyClipGeLUHook
from prepost import Clip, Tile, Jitter, RepeatBatch, ColorJitter, fix_random_seed
from prepost import GaussianNoise
from util import ClipWrapper
from util import new_init, save_intermediate_step, save_image, fix_random_seed
_nums = '0123456789'
steps_folder = 'steps'
os.makedirs(steps_folder, exist_ok=True)
steps_folder = 'finals'
os.makedirs(steps_folder, exist_ok=True)
class ImageNetVisualizer:
def __init__(self, loss_array: LossArray, pre_aug: nn.Module = None,
post_aug: nn.Module = None, steps: int = 2000, lr: float = 0.1, save_every: int = 200, saver: bool = True,
print_every: int = 5, **_):
self.loss = loss_array
self.saver = saver #None #saver
print(self.saver)
self.pre_aug = pre_aug
self.post_aug = post_aug
self.save_every = save_every
self.print_every = print_every
self.steps = steps
self.lr = lr
def __call__(self, img: torch.tensor = None, optimizer: optim.Optimizer = None, layer: int = None, feature: int = None, clipname: str = None):
if not img.is_cuda or img.device != torch.device('cuda:0'):
img = img.to('cuda:0')
if not img.requires_grad:
img.requires_grad_()
# Mess with the optimizer here:
# ['ASGD', 'Adadelta', 'Adagrad', 'Adam', 'AdamW', 'Adamax', 'LBFGS', 'NAdam', 'RAdam', 'RMSprop', 'Rprop', 'SGD', 'SparseAdam']
# Default:
# optimizer = optimizer if optimizer is not None else optim.Adam([img], lr=self.lr, betas=(0.5, 0.99), eps=1e-8)
optimizer = optimizer if optimizer is not None else optim.Adamax([img], lr=self.lr, betas=(0.5, 0.99), eps=1e-8)
lr_scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, self.steps, 0.)
print(f'#i\t{self.loss.header()}', flush=True)
for i in range(self.steps + 1):
optimizer.zero_grad()
augmented = self.pre_aug(img) if self.pre_aug is not None else img
loss = self.loss(augmented)
if i % self.print_every == 0:
print(f'{i}\t{self.loss}', flush=True)
if i % self.save_every == 0 and self.saver is True:
save_intermediate_step(img, i, layer, feature, clipname)
loss.backward()
optimizer.step()
lr_scheduler.step()
img.data = (self.post_aug(img) if self.post_aug is not None else img).data
self.loss.reset()
optimizer.state = collections.defaultdict(dict)
return img
def get_clip_dimensions(clipmodel):
model, preprocess = clip.load(clipmodel)
model = model.eval()
for transform in preprocess.transforms:
if isinstance(transform, Resize):
input_dims = transform.size
break
num_layers = None
num_features = None
if hasattr(model, 'visual') and hasattr(model.visual, 'transformer'):
num_layers = len(model.visual.transformer.resblocks)
last_block = model.visual.transformer.resblocks[-1]
if hasattr(last_block, 'mlp'):
c_proj_layer = last_block.mlp.c_proj
num_features = c_proj_layer.in_features
return input_dims, num_layers, num_features
def load_clip_model(device: str = 'cuda') -> torch.nn.Module:
model, _ = clip.load(clipmodel, device=device)
model = ClipWrapper(model).to(device)
return model
def parse_range(range_str):
if '-' in range_str:
start, end = map(int, range_str.split('-'))
return list(range(start, end + 1))
else:
return list(map(int, range_str.split(',')))
def generate_visualizations(model, clipname, layer_range_str, feature_range_str, image_size, tv, lr, steps, print_every, save_every, saver, coefficient, use_fixed_random_seed):
layer_range = parse_range(layer_range_str)
feature_range = parse_range(feature_range_str)
for layer in layer_range:
for feature in feature_range:
print(f"Generating visualization for Layer {layer}, Feature {feature}...")
loss = LossArray()
loss += ViTEnsFeatHook(ClipGeLUHook(model, sl=slice(layer, layer + 1)), key='high', feat=feature, coefficient=1)
loss += TotalVariation(2, image_size, coefficient * tv)
pre, post = torch.nn.Sequential(RepeatBatch(8), ColorJitter(8, shuffle_every=True),
GaussianNoise(8, True, 0.5, 400), Tile(image_size // image_size), Jitter()), Clip()
image = new_init(image_size, 1, use_fixed_random_seed=use_fixed_random_seed)
visualizer = ImageNetVisualizer(loss_array=loss, pre_aug=pre, post_aug=post, print_every=print_every, lr=lr, steps=steps, save_every=save_every, saver=saver, coefficient=coefficient, use_fixed_random_seed=use_fixed_random_seed)
image.data = visualizer(image, layer=layer, feature=feature, clipname=clipname)
save_image(image, f'finals/{clipname}_L{layer}_F{feature}.png')
def generate_single(model, clipname, layer, feature, image_size, tv, lr, steps, print_every, save_every, saver, coefficient, use_fixed_random_seed):
loss = LossArray()
loss += ViTEnsFeatHook(ClipGeLUHook(model, sl=slice(layer, layer + 1)), key='high', feat=feature, coefficient=1)
loss += TotalVariation(2, image_size, coefficient * tv)
pre, post = torch.nn.Sequential(RepeatBatch(8), ColorJitter(8, shuffle_every=True),
GaussianNoise(8, True, 0.5, 400), Tile(image_size // image_size), Jitter()), Clip()
image = new_init(image_size, 1, use_fixed_random_seed=use_fixed_random_seed)
visualizer = ImageNetVisualizer(loss_array=loss, pre_aug=pre, post_aug=post, print_every=print_every, lr=lr, steps=steps, save_every=save_every, saver=saver, coefficient=coefficient, use_fixed_random_seed=use_fixed_random_seed)
image.data = visualizer(image, layer=layer, feature=feature, clipname=clipname)
save_image(image, f'finals/{clipname}_L{layer}_F{feature}.png')
# Choose a CLIP ViT model here: ['ViT-B/32', 'ViT-B/16', 'ViT-L/14', 'ViT-L/14@336px']
clipmodel = "ViT-L/14@336px"
clipname = clipmodel.replace("/", "-").replace("@", "-")
input_dims, num_layers, num_features = get_clip_dimensions(clipmodel)
print(f"Selected input dimension for {clipmodel}: {input_dims}")
# This is the valid range for the selected model, alas pay attention to this output when you run the code:
print(f"Number of Layers: {num_layers} with {num_features} Features / Layer\n")
def main():
model = load_clip_model()
image_size = input_dims
# "True" for all in layer / feature RANGE set below. "False" to generate single layer, feature.
generate_multi = False
layer, feature = 20, 3169 # Single layer and feature
layer_range_str = "17-19" # continuous range ("5-10") or discrete value ("5,6,8")
feature_range_str = "301,302,2742" # continuous range ("50-90") or discrete value ("500,1000,1555")
# coefficient=0.0005 -> sharp and noisy features; coefficient=0.005 -> balanced; coefficient=0.05 -> soft, blurry, muddy
tv = 1.0
coefficient=0.005
lr = 1.0
steps = 400
print_every = 10
save_every = 10
# Set to "False" to disable saving intermediate steps:
saver = True
# Set to "True" to make the results deterministic via torch.backends.cudnn.deterministic; False to get randomly different image of same feature.
# CURRENTLY DOESNT WORK! LEAVE AT FALSE!
use_fixed_random_seed = False
if generate_multi:
generate_visualizations(model, clipname, layer_range_str, feature_range_str, image_size, tv, lr, steps, print_every, save_every, saver, coefficient, use_fixed_random_seed)
else:
generate_single(model, clipname, layer, feature, image_size, tv, lr, steps, print_every, save_every, saver, coefficient, use_fixed_random_seed)
if __name__ == '__main__':
main()