# supernet_load_init_weight.py

import math
import time
from pathlib import Path

import torch
import torch.nn as nn

from model.supernet_transformer import Vision_TransformerSuper

def load_weight(img_size=240, patch_size=20, embed_dim=256, depth=14, num_heads=4,
                mlp_ratio=4, qkv_bias=True, drop_rate=0.0, drop_path_rate=0.1, attn_drop_rate=0.,
                gp=True, num_classes=1000, max_relative_position=14,
                relative_position=True, change_qkv=True, abs_pos=True,
                rank_ratio=0.9, pretrained_output_dir=None, output_path=None):
    """Build the low-rank supernet, initialise it from the pretrained
    'supernet-tiny.pth' checkpoint (factorising the attention and MLP weights
    with truncated SVD), save the resulting state dict under `output_path`,
    and return the path of the saved weight file."""
    # create the modified (low-rank) supernet model
    model_modify = Vision_TransformerSuper(img_size=img_size, patch_size=patch_size, embed_dim=embed_dim,
                                           depth=depth, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias,
                                           drop_rate=drop_rate, drop_path_rate=drop_path_rate, attn_drop_rate=attn_drop_rate,
                                           gp=gp, num_classes=num_classes, max_relative_position=max_relative_position,
                                           relative_position=relative_position, change_qkv=change_qkv,
                                           abs_pos=abs_pos, rank_ratio=rank_ratio)
    # print(model_modify)
    # model_modify = deit_tiny_patch16_224(pretrained=False)
    # torch.save(model_modify, "./models/supernet/model_modify.pth")

    # load the pretrained supernet checkpoint
    output_dir_path = Path(pretrained_output_dir)
    pretrained_model_path = output_dir_path / 'supernet-tiny.pth'
    pretext_model = torch.load(pretrained_model_path, map_location='cpu')
    pretext_model_dict = pretext_model["model"]
    # pretext_model_dict = pretext_model.state_dict()
    # print(pretext_model_dict.keys())

    # keep the pretrained layers whose names already match the modified model
    model_modified_dict = model_modify.state_dict()
    matchable_layer = {k: v for k, v in pretext_model_dict.items()
                       if k in model_modified_dict
                       and k != "patch_embed_super.proj.weight"
                       and k != "patch_embed_super.proj.bias"
                       and k != "pos_embed"}
    # collect the pretrained layers with no direct counterpart; these need to be converted
    unmatchable_layer = {k: v for k, v in pretext_model_dict.items() if k not in model_modified_dict}
    modified_layer = {}

    # re-initialise the patch embedding and position embedding: their shapes depend on
    # img_size / patch_size and therefore do not match the pretrained checkpoint
    token_num = int(round(img_size / patch_size) * round(img_size / patch_size)) + 1
    patch_embed_super_proj_weight_value = nn.Parameter(torch.Tensor(embed_dim, 3, patch_size, patch_size))
    pos_embed_value = nn.Parameter(torch.Tensor(1, token_num, embed_dim))
    patch_embed_super_proj_bias_value = nn.Parameter(torch.Tensor(embed_dim))
    nn.init.kaiming_uniform_(patch_embed_super_proj_weight_value, a=math.sqrt(5))
    nn.init.kaiming_uniform_(pos_embed_value, a=math.sqrt(5))
    nn.init.uniform_(patch_embed_super_proj_bias_value, 0, 5)
    # nn.init.kaiming_uniform_(patch_embed_super_proj_bias_value, a=math.sqrt(5))
    modified_layer["patch_embed_super.proj.weight"] = patch_embed_super_proj_weight_value
    modified_layer["pos_embed"] = pos_embed_value
    modified_layer["patch_embed_super.proj.bias"] = patch_embed_super_proj_bias_value
    # print(unmatchable_layer.keys())

    # factorise the fully connected layers inside the attention blocks
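    # Note on the factorisation below: torch.svd_lowrank(W, q=r) returns U, S, V with
    # W ≈ U @ diag(S) @ V.T, so each dense linear weight is split into three factors that
    # the low-rank supernet presumably consumes as stacked sub-layers: V.T -> qkv1/proj1,
    # diag(S) -> qkv2/proj2, U -> qkv3/proj3, with the original bias attached to the last
    # factor.  The fractional counter `i` advances by 0.25 per handled key, so int(i),
    # int(i - 0.25), int(i - 0.5) and int(i - 0.75) all resolve to the current block index
    # while that block's four keys are processed in order.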
    i = 0
    for k, v in unmatchable_layer.items():
        if k == "blocks." + str(int(i)) + ".attn.qkv.weight":
            v_qkv1, v_qkv2_diag, v_qkv3 = torch.svd_lowrank(v, q=int(rank_ratio * embed_dim))
            k_qkv1 = "blocks." + str(int(i)) + ".attn.qkv1.weight"
            modified_layer.update({k_qkv1: v_qkv3.t()})
            k_qkv2 = "blocks." + str(int(i)) + ".attn.qkv2.weight"
            modified_layer.update({k_qkv2: torch.diag(v_qkv2_diag)})
            k_qkv3 = "blocks." + str(int(i)) + ".attn.qkv3.weight"
            modified_layer.update({k_qkv3: v_qkv1})
        elif k == "blocks." + str(int(i - 0.25)) + ".attn.qkv.bias":
            k_bias = "blocks." + str(int(i - 0.25)) + ".attn.qkv3.bias"
            modified_layer.update({k_bias: v})
        elif k == "blocks." + str(int(i - 0.5)) + ".attn.proj.weight":
            v_proj1, v_proj2_diag, v_proj3 = torch.svd_lowrank(v, q=int(rank_ratio * embed_dim))
            k_proj1 = "blocks." + str(int(i - 0.5)) + ".attn.proj1.weight"
            modified_layer.update({k_proj1: v_proj3.t()})
            k_proj2 = "blocks." + str(int(i - 0.5)) + ".attn.proj2.weight"
            modified_layer.update({k_proj2: torch.diag(v_proj2_diag)})
            k_proj3 = "blocks." + str(int(i - 0.5)) + ".attn.proj3.weight"
            modified_layer.update({k_proj3: v_proj1})
        elif k == "blocks." + str(int(i - 0.75)) + ".attn.proj.bias":
            k_bias = "blocks." + str(int(i - 0.75)) + ".attn.proj3.bias"
            modified_layer.update({k_bias: v})
        else:
            continue
        i += 0.25

    # factorise the fc layers in the MLP, following the same three-factor pattern
    i = 0
    for k, v in unmatchable_layer.items():
        if k == "blocks." + str(int(i)) + ".fc1.weight":
            v_fc11, v_fc12_diag, v_fc13 = torch.svd_lowrank(v, q=int(rank_ratio * embed_dim))
            k_fc11 = "blocks." + str(int(i)) + ".fc11.weight"
            modified_layer.update({k_fc11: v_fc13.t()})
            k_fc12 = "blocks." + str(int(i)) + ".fc12.weight"
            modified_layer.update({k_fc12: torch.diag(v_fc12_diag)})
            k_fc13 = "blocks." + str(int(i)) + ".fc13.weight"
            modified_layer.update({k_fc13: v_fc11})
        elif k == "blocks." + str(int(i - 0.25)) + ".fc1.bias":
            k_bias = "blocks." + str(int(i - 0.25)) + ".fc13.bias"
            modified_layer.update({k_bias: v})
        elif k == "blocks." + str(int(i - 0.5)) + ".fc2.weight":
            v_fc21, v_fc22_diag, v_fc23 = torch.svd_lowrank(v, q=int(rank_ratio * embed_dim))
            k_fc21 = "blocks." + str(int(i - 0.5)) + ".fc21.weight"
            modified_layer.update({k_fc21: v_fc23.t()})
            k_fc22 = "blocks." + str(int(i - 0.5)) + ".fc22.weight"
            modified_layer.update({k_fc22: torch.diag(v_fc22_diag)})
            k_fc23 = "blocks." + str(int(i - 0.5)) + ".fc23.weight"
            modified_layer.update({k_fc23: v_fc21})
        elif k == "blocks." + str(int(i - 0.75)) + ".fc2.bias":
            k_bias = "blocks." + str(int(i - 0.75)) + ".fc23.bias"
            modified_layer.update({k_bias: v})
        else:
            continue
        i += 0.25
# print("####",modified_layer.keys())
# # update weight
model_modified_dict.update(matchable_layer)
model_modified_dict.update(modified_layer)
model_modify.load_state_dict(model_modified_dict)
    # # check that every key was covered by either matchable_layer or modified_layer
    # a = []
    # for key in model_modified_dict.keys():
    #     a.append(key)
    # for key in matchable_layer.keys():
    #     a.remove(key)
    # for key in modified_layer.keys():
    #     a.remove(key)
    # print(a)

    # save the initial supernet weights and return the file path
    model_dict = model_modify.state_dict()
    initial_model_path = Path(output_path) / 'supernet_tiny_weight.pth'
    torch.save(model_dict, initial_model_path)
    print("Successfully generated the pre-trained initial weights at", initial_model_path)
    time.sleep(3)
    return initial_model_path
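

# A minimal, self-contained sketch (not used by load_weight): it checks that the three
# factors produced by torch.svd_lowrank reconstruct the original dense weight in the
# order used above (W ≈ W3 @ W2 @ W1).  The helper name `_check_lowrank_split` is
# illustrative only and does not exist elsewhere in this repository.
def _check_lowrank_split(weight, rank):
    """Return the relative reconstruction error of the three-factor split."""
    u, s, v = torch.svd_lowrank(weight, q=rank)
    w1 = v.t()          # plays the role of e.g. "attn.qkv1.weight"
    w2 = torch.diag(s)  # plays the role of e.g. "attn.qkv2.weight"
    w3 = u              # plays the role of e.g. "attn.qkv3.weight"
    return ((weight - w3 @ w2 @ w1).norm() / weight.norm()).item()
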
# Example call (kept commented out): `pretrained_output_dir` must contain
# 'supernet-tiny.pth' and `output_path` is the directory for the converted weights.
# load_weight(img_size=240, patch_size=20, embed_dim=256, depth=14, num_heads=4,
#             mlp_ratio=4, qkv_bias=True, drop_rate=0.0, drop_path_rate=0.1, attn_drop_rate=0.,
#             gp=True, num_classes=1000, max_relative_position=14,
#             relative_position=True, change_qkv=True, abs_pos=True,
#             rank_ratio=0.9, pretrained_output_dir="./result", output_path=Path("./result"))
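

# A hedged usage sketch: running this file directly exercises the factorisation check on a
# random matrix, so no checkpoint is required.  The shape and rank below are arbitrary examples.
if __name__ == "__main__":
    w = torch.randn(256, 256)
    err = _check_lowrank_split(w, rank=int(0.9 * 256))
    print("relative reconstruction error of the three-factor split:", err)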