# From the EsViT repo: https://github.com/microsoft/esvit
# The full EsViT repo needs to be downloaded and importable for these imports to resolve.
from esvit.models import build_model
from esvit.config import config, update_config, save_config

import torch
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms

def load_encoder_esVIT(args, device):
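    """Build an EsViT backbone (Swin / ViL / CvT variants), load the pretrained teacher
    weights from args.checkpoint, and return (model, num_features_linear, depths), where
    num_features_linear is the summed dimension of the last args.n_last_blocks block
    outputs (e.g. to size a linear classifier on the concatenated features)."""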
# ============ building network ... ============
num_features = []
    # if the network is a 4-stage vision transformer (e.g. Swin)
    if 'swin' in args.arch:
update_config(config, args)
model = build_model(config, is_teacher=True)
swin_spec = config.MODEL.SPEC
        embed_dim = swin_spec['DIM_EMBED']
        depths = swin_spec['DEPTHS']
        num_heads = swin_spec['NUM_HEADS']
        # Each stage i stacks depths[i] blocks; block features start at embed_dim
        # and the dimension doubles at every stage, so stage i outputs embed_dim * 2**i.
for i, d in enumerate(depths):
num_features += [int(embed_dim * 2 ** i)] * d
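        # Sanity check with an assumed Swin-T spec (embed_dim=96, depths=(2, 2, 6, 2)):
        # num_features == [96, 96, 192, 192, 384, 384, 384, 384, 384, 384, 768, 768],
        # so with --n_last_blocks 4 the linear input size is 384 + 384 + 768 + 768 = 2304.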
    # if the network is a 4-stage vision transformer with longformer attention (e.g. ViL)
    elif 'vil' in args.arch:
update_config(config, args)
model = build_model(config, is_teacher=True)
msvit_spec = config.MODEL.SPEC
arch = msvit_spec.MSVIT.ARCH
layer_cfgs = model.layer_cfgs
num_stages = len(model.layer_cfgs)
depths = [cfg['n'] for cfg in model.layer_cfgs]
dims = [cfg['d'] for cfg in model.layer_cfgs]
out_planes = model.layer_cfgs[-1]['d']
Nglos = [cfg['g'] for cfg in model.layer_cfgs]
print(dims)
for i, d in enumerate(depths):
num_features += [ dims[i] ] * d
    # if the network is a 4-stage vision transformer (e.g. CvT)
    elif 'cvt' in args.arch:
update_config(config, args)
model = build_model(config, is_teacher=True)
cvt_spec = config.MODEL.SPEC
        embed_dim = cvt_spec['DIM_EMBED']
        depths = cvt_spec['DEPTH']
        num_heads = cvt_spec['NUM_HEADS']
print(f'embed_dim {embed_dim} depths {depths}')
for i, d in enumerate(depths):
num_features += [int(embed_dim[i])] * int(d)
    # vanilla vision transformers (deit_tiny, deit_small, vit_base) are not handled by this loader
else:
raise ValueError(f'{args.arch} not supported yet.')
model.to(device)
# load weights to evaluate
state_dict = torch.load(args.checkpoint, map_location=device)
    # We could also load the student weights, but with knowledge distillation it is more
    # common to keep the teacher, and the DINO paper shows that the teacher learns better.
    state_dict = state_dict['teacher']
    # The line below was initially in the code but is not needed in our case (swin-t):
    # state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
    # The trained checkpoint contains the dense DINO head while the freshly built model has a
    # regular head, so those keys will not match and load_state_dict reports roughly:
    #   IncompatibleKeys(missing_keys=['head.weight', 'head.bias'],
    #                    unexpected_keys=[... 'head_dense.mlp.*', 'head_dense.last_layer.*',
    #                                     'head.mlp.*', 'head.last_layer.*' ...])
    # In any case we do not use the heads, only the output features of each stage.
msg = model.load_state_dict(state_dict, strict=False)
print(msg)
model.eval()
    print(f"Model {args.arch} {args.patch_size}x{args.patch_size} built with pretrained weights from {args.checkpoint}.")
    # n_last_blocks is a choice: e.g. 4 takes the last 4 block outputs, which for Swin-T
    # spans the last two stages. When a stage has more than one block, its block features
    # are simply stacked in num_features.
    # The paper says: for all transformer architectures, we use the concatenation of
    # view-level features in the last layers (results are similar to using 3 or 5 layers
    # in our initial experiments).
num_features_linear = sum(num_features[-args.n_last_blocks:])
print(f'num_features_linear {num_features_linear}')
return model, num_features_linear, depths
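
# A minimal sketch (not from the original repo) of the argparse-style namespace that
# load_encoder_esVIT expects. The exact fields read by update_config() depend on the EsViT
# config code, so the field names and paths below are assumptions to adapt to your setup.
# Usage: model, num_features_linear, depths = load_encoder_esVIT(_example_esvit_args(), device)
def _example_esvit_args():
    from types import SimpleNamespace
    return SimpleNamespace(
        arch='swin_tiny',                        # architecture string matched above ('swin' / 'vil' / 'cvt')
        cfg='path/to/swin_tiny_config.yaml',     # hypothetical EsViT config file
        checkpoint='path/to/esvit_teacher.pth',  # hypothetical pretrained checkpoint
        patch_size=4,
        n_last_blocks=4,
        opts=[],                                 # extra config overrides, if update_config expects them
    )
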
# Regular ResNet encoder: uses ImageNet weights or a self-supervised checkpoint whose keys
# are prefixed with 'module.encoder_q.' (MoCo-style).
def load_encoder_resnet(backbone, checkpoint_file, use_imagenet_weights, device):
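    """Build a torchvision ResNet with the final FC layer removed and load its weights
    either from ImageNet (use_imagenet_weights=True) or from a checkpoint whose keys are
    prefixed with 'module.encoder_q.'. Returns the model on `device`, in eval mode."""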
import torch.nn as nn
import torchvision.models as models
class DecapitatedResnet(nn.Module):
def __init__(self, base_encoder, pretrained):
super(DecapitatedResnet, self).__init__()
self.encoder = base_encoder(pretrained=pretrained)
def forward(self, x):
# Same forward pass function as used in the torchvision 'stock' ResNet code
# but with the final FC layer removed.
x = self.encoder.conv1(x)
x = self.encoder.bn1(x)
x = self.encoder.relu(x)
x = self.encoder.maxpool(x)
x = self.encoder.layer1(x)
x = self.encoder.layer2(x)
x = self.encoder.layer3(x)
x = self.encoder.layer4(x)
x = self.encoder.avgpool(x)
x = torch.flatten(x, 1)
return x
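    # For reference: the flattened avgpool output is 512-dim for resnet18/34 and
    # 2048-dim for resnet50/101/152.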
model = DecapitatedResnet(models.__dict__[backbone], use_imagenet_weights)
if use_imagenet_weights:
if checkpoint_file is not None:
raise Exception(
"Either provide a weights checkpoint or the --imagenet flag, not both."
)
        print("Created encoder with ImageNet weights")
else:
checkpoint = torch.load(checkpoint_file, map_location="cpu")
state_dict = checkpoint["state_dict"]
for k in list(state_dict.keys()):
# retain only encoder_q up to before the embedding layer
if k.startswith("module.encoder_q") and not k.startswith(
"module.encoder_q.fc"
):
# remove prefix from key names
state_dict[k[len("module.encoder_q.") :]] = state_dict[k]
# delete renamed or unused k
del state_dict[k]
# Verify that the checkpoint did not contain data for the final FC layer
msg = model.encoder.load_state_dict(state_dict, strict=False)
assert set(msg.missing_keys) == {"fc.weight", "fc.bias"}
print(f"Loaded checkpoint {checkpoint_file}")
model = model.to(device)
if torch.cuda.device_count() > 1:
model = torch.nn.DataParallel(model)
model.eval()
return model
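
if __name__ == "__main__":
    # Usage sketch (not part of the original file): build a decapitated ResNet-50 with
    # ImageNet weights and extract features for a random batch; the flattened resnet50
    # output is 2048-dimensional, so `features` has shape (8, 2048).
    device = "cuda" if torch.cuda.is_available() else "cpu"
    encoder = load_encoder_resnet("resnet50", checkpoint_file=None,
                                  use_imagenet_weights=True, device=device)
    with torch.no_grad():
        features = encoder(torch.randn(8, 3, 224, 224).to(device))
    print(features.shape)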