-
Notifications
You must be signed in to change notification settings - Fork 6
/
Copy pathget_submodel.py
48 lines (43 loc) · 2.15 KB
/
get_submodel.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
import argparse
import torch
# Number of convolutional layers in each of the five VGG blocks whose
# parameters are copied into the (sub)model; layer names follow the pattern
# conv_<block>_<layer>, with blocks and layers numbered from 1.
CONV_LAYERS_PER_BLOCK = (2, 2, 3, 3, 3)


def build_conv_key_map(layers_per_block=CONV_LAYERS_PER_BLOCK, old_prefix='features.'):
    """Build the old-to-new state dict key map for the convolutional layers.

    Each new key is ``conv_<block>_<layer>.<weight|bias>`` and the matching
    old key is the same string with ``old_prefix`` prepended.

    Args:
        layers_per_block: Iterable giving the number of conv layers in each
            block, in block order.
        old_prefix: Prefix carried by the keys of the input state dict.

    Returns:
        A dict mapping old keys to new keys, ordered block by block, layer by
        layer, with ``weight`` before ``bias``.
    """
    key_map = {}
    for block_idx, num_layers in enumerate(layers_per_block, start=1):
        for layer_idx in range(1, num_layers + 1):
            for param in ('weight', 'bias'):
                new_key = 'conv_{}_{}.{}'.format(block_idx, layer_idx, param)
                key_map[old_prefix + new_key] = new_key
    return key_map


def extract_submodel_state_dict(model_state_dict, key_map=None):
    """Return a new state dict holding only the entries named in ``key_map``.

    Args:
        model_state_dict: Mapping from parameter name to value (the loaded
            model state dict).
        key_map: Mapping from old key to new key; defaults to the map
            returned by ``build_conv_key_map()``.

    Returns:
        A dict keyed by the new names, in ``key_map`` order. Values are the
        same objects as in ``model_state_dict`` (they are not copied).

    Raises:
        KeyError: If any old key is absent from ``model_state_dict``. The
            message lists every missing key, not just the first one.
    """
    if key_map is None:
        key_map = build_conv_key_map()

    # Check everything up front so the user sees all missing keys at once.
    missing = [old_key for old_key in key_map if old_key not in model_state_dict]
    if missing:
        raise KeyError(
            "input state dict is missing expected key(s): {}".format(', '.join(missing))
        )

    return {new_key: model_state_dict[old_key] for old_key, new_key in key_map.items()}


def main():
    """Load a VGGFace state dict, keep its conv parameters, and save them."""
    # The text must be passed as `description`; the first positional parameter
    # of ArgumentParser is `prog`, which would put it in the usage line.
    parser = argparse.ArgumentParser(description="Get (sub)model of VGGFace model")
    parser.add_argument('--model', type=str, default='models/vggface.pth', help="input VGGFace model file")
    parser.add_argument('--output', type=str, default='models/vggface_conv.pth', help="output VGGFace (sub)model file")
    args = parser.parse_args()

    # Load model state dict; the map_location callable returns each storage
    # unchanged, which keeps the tensors on the CPU.
    # NOTE(review): torch.load unpickles the file, so only load checkpoints
    # from a trusted source (consider weights_only=True where the installed
    # PyTorch version supports it).
    model_state_dict = torch.load(args.model, map_location=lambda storage, loc: storage)

    # Rename the convolutional parameters and drop everything else
    new_model_state_dict = extract_submodel_state_dict(model_state_dict)

    # Save output model state dict
    torch.save(new_model_state_dict, args.output)


if __name__ == '__main__':
    main()