-
Notifications
You must be signed in to change notification settings - Fork 2
/
convert_to_pytorch.py
145 lines (105 loc) · 4.68 KB
/
convert_to_pytorch.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
import os
from pprint import pprint
import tensorflow as tf
from inception import Inception3
import torch
import numpy as np
from tensorflow.python import pywrap_tensorflow
from tensorflow.python.tools.inspect_checkpoint import print_tensors_in_checkpoint_file
# Freshly constructed PyTorch model that receives the converted weights.
model = Inception3()

# CUB-200 model trained in TensorFlow (path to the TF checkpoint prefix).
tf_path = os.path.abspath('/home/m.bharti/svn/FineGrainedClassification/cvpr18-inaturalist-transfer/checkpoints/cub_200/auxlogits_aug_resize/model.ckpt-2810')
# Alternative iNaturalist checkpoint; noted as not working with this script:
# tf_path = os.path.abspath('./checkpoints/inception/inception_v3_iNat_299.ckpt')

# (name, shape) pairs for every variable stored in the checkpoint.
init_vars = tf.train.list_variables(tf_path)

# Open the checkpoint once and read every tensor through the same reader;
# tf.train.load_variable builds a new checkpoint reader on every call.
reader = tf.train.load_checkpoint(tf_path)
tf_vars = []  # list of (tf_variable_name, numpy_array)
for key, _shape in init_vars:
    print(key)
    tf_vars.append((key, reader.get_tensor(key)))
print("Total vars {}".format(len(tf_vars)))
# --- Copy every TensorFlow variable into the matching PyTorch tensor -------
# TF names look like 'InceptionV3/<scope>/.../<leaf>'. The scope prefix is
# stripped, the intermediate components are followed as attribute names on
# the PyTorch model, and the final component selects the target tensor.
_TF_SCOPE_PREFIX = 'InceptionV3/'

count = 0                   # number of tensors copied into the model
total_aux_variables = 0     # ...of which sit under the AuxLogits scope
total_logits_variables = 0  # ...of which sit under the Logits scope


def _copy_into(target, array, tf_name):
    """Copy numpy ``array`` into the PyTorch tensor ``target`` in place.

    Args:
        target: parameter or buffer on the PyTorch model; may be None when the
            module was built without it (e.g. a conv created with bias=False).
        array: numpy array read from the TF checkpoint.
        tf_name: full TF variable name, used only in error messages.

    Raises:
        ValueError: if ``target`` is missing or its shape differs from
            ``array``'s. An explicit raise (rather than ``assert``) keeps the
            check active under ``python -O``.
    """
    if target is None:
        raise ValueError('{}: no matching tensor on the PyTorch model'.format(tf_name))
    if tuple(target.shape) != tuple(array.shape):
        raise ValueError('{}: shape mismatch, PyTorch {} vs TensorFlow {}'.format(
            tf_name, tuple(target.shape), tuple(array.shape)))
    # Copy into the existing storage so the tensor keeps its own dtype, device
    # and contiguous layout; rebinding ``.data`` would instead alias the numpy
    # array, which is a non-contiguous view after the kernel transpose.
    target.data.copy_(torch.from_numpy(array))


def _load_leaf(module, leaf, array, tf_name):
    """Copy the TF variable ``leaf`` of ``module`` into its PyTorch counterpart.

    Raises:
        ValueError: for a leaf name this converter does not know how to map,
            or on a shape mismatch (see ``_copy_into``).
    """
    if leaf == 'weights':
        # Conv kernel: TF stores (H, W, in, out), PyTorch expects (out, in, H, W).
        _copy_into(module.conv.weight, np.transpose(array, (3, 2, 0, 1)), tf_name)
    elif leaf == 'biases':
        # Conv bias (not batch norm).
        _copy_into(module.conv.bias, array, tf_name)
    elif leaf == 'beta':
        # Batch-norm shift.
        _copy_into(module.bias, array, tf_name)
    elif leaf == 'gamma':
        # Batch-norm scale; noted as not present for this model.
        _copy_into(module.weight, array, tf_name)
    elif leaf == 'moving_mean':
        _copy_into(module.running_mean, array, tf_name)
    elif leaf == 'moving_variance':
        _copy_into(module.running_var, array, tf_name)
    else:
        raise ValueError('{}: unsupported variable type {!r}'.format(tf_name, leaf))


for full_name, array in tf_vars:
    if full_name == 'global_step':
        continue
    # Optimizer slot variables ('<variable>/Momentum') are training state,
    # not model weights, so none of them are transferred.
    if full_name.endswith('/Momentum'):
        continue
    print(full_name)
    # Drop the 'InceptionV3/' scope and split the rest into path components.
    name = full_name[len(_TF_SCOPE_PREFIX):].split('/')
    if name[0] == 'AuxLogits':
        total_aux_variables += 1
    elif name[0] == 'Logits':
        total_logits_variables += 1
    *scopes, leaf = name
    # Walk from the model root down to the submodule that owns the leaf.
    pointer = model
    try:
        for scope in scopes:
            pointer = getattr(pointer, scope)
    except AttributeError as err:
        raise AttributeError(
            '{}: PyTorch model has no submodule for this scope'.format(full_name)) from err
    _load_leaf(pointer, leaf, array, full_name)
    count += 1

print("Updated {} parameters".format(count))
print("AuxLogits variables: {}, Logits variables: {}".format(
    total_aux_variables, total_logits_variables))
# Write the converted weights (the model's state_dict) to disk.
torch.save(obj=model.state_dict(), f='./cub_inceptionv3.pth')
# When converting the iNaturalist checkpoint, save under this name instead:
# torch.save(obj=model.state_dict(), f='./iNat_inceptionv3.pth')