-
Notifications
You must be signed in to change notification settings - Fork 38
/
Copy pathmain.py
206 lines (165 loc) · 7.7 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
import argparse
import os
import torch
import numpy as np
import torch.nn.functional as F
from torch_geometric.data import DataLoader
from psbody.mesh import Mesh, MeshViewers
import mesh_operations
from config_parser import read_config
from data import ComaDataset
from model import Coma
from transform import Normalize
def scipy_to_torch_sparse(scp_matrix):
    """Convert a SciPy sparse matrix to a torch sparse COO float tensor.

    Any SciPy sparse format (COO, CSR, CSC, ...) is accepted. It is converted
    to COO first because only COO exposes the ``row``/``col`` index arrays
    used to build the torch tensor.

    Args:
        scp_matrix: a ``scipy.sparse`` matrix.

    Returns:
        A CPU sparse COO tensor with int64 indices, float32 values and the
        same shape as ``scp_matrix``. Callers move it to a device themselves.
    """
    coo = scp_matrix.tocoo()
    # torch.tensor always copies, so the result never aliases SciPy's buffers.
    indices = torch.tensor(np.vstack((coo.row, coo.col)), dtype=torch.long)
    values = torch.tensor(coo.data, dtype=torch.float32)
    # torch.sparse.FloatTensor is a deprecated legacy constructor;
    # sparse_coo_tensor builds the same float32 COO tensor.
    return torch.sparse_coo_tensor(indices, values, torch.Size(coo.shape))
def adjust_learning_rate(optimizer, lr_decay):
    """Multiply the learning rate of every parameter group by ``lr_decay``.

    The optimizer is updated in place; nothing is returned.
    """
    for group in optimizer.param_groups:
        current = group['lr']
        group['lr'] = current * lr_decay
def save_model(coma, optimizer, epoch, train_loss, val_loss, checkpoint_dir):
    """Write a training checkpoint to ``checkpoint_dir/checkpoint_<epoch>.pt``.

    The saved dict has the keys ``state_dict`` (model weights), ``optimizer``
    (optimizer state), ``epoch_num``, ``train_loss`` and ``val_loss``.
    ``checkpoint_dir`` must already exist.
    """
    filename = 'checkpoint_' + str(epoch) + '.pt'
    checkpoint = {
        'state_dict': coma.state_dict(),
        'optimizer': optimizer.state_dict(),
        'epoch_num': epoch,
        'train_loss': train_loss,
        'val_loss': val_loss,
    }
    torch.save(checkpoint, os.path.join(checkpoint_dir, filename))
def _build_optimizer(coma, opt, lr, weight_decay):
    """Create the optimizer named by the config's ``optimizer`` entry.

    Supports ``'adam'`` and ``'sgd'`` (SGD with momentum 0.9).

    Raises:
        ValueError: if ``opt`` is neither ``'adam'`` nor ``'sgd'``.
    """
    if opt == 'adam':
        return torch.optim.Adam(coma.parameters(), lr=lr, weight_decay=weight_decay)
    if opt == 'sgd':
        return torch.optim.SGD(coma.parameters(), lr=lr, weight_decay=weight_decay, momentum=0.9)
    # ValueError subclasses Exception, so existing `except Exception` handlers still catch it.
    raise ValueError('Unknown optimizer {!r}; expected "adam" or "sgd"'.format(opt))


def _load_checkpoint(checkpoint_file, coma, optimizer, device):
    """Restore model and optimizer state from a file written by ``save_model``.

    ``coma`` must already be on ``device`` so that
    ``optimizer.load_state_dict`` places the restored state next to the
    parameters.

    Returns:
        The epoch to resume training from: the epoch after the one stored in
        the checkpoint.
    """
    # map_location lets a checkpoint written on a GPU machine load on a CPU-only one.
    checkpoint = torch.load(checkpoint_file, map_location=device)
    coma.load_state_dict(checkpoint['state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    # save_model is called after epoch `epoch_num` has finished training,
    # so resume with the next epoch instead of repeating it.
    return checkpoint['epoch_num'] + 1


def main(args):
    """Train or evaluate a CoMA mesh autoencoder.

    ``args`` is the argparse namespace built at the bottom of this script. It
    provides ``conf`` (config file path), ``split``/``split_term`` (dataset
    split), and optional ``data_dir``/``checkpoint_dir`` overrides for the
    matching config entries.

    If the config's ``eval`` flag is set, the model (optionally restored from
    ``checkpoint_file``) is evaluated once on the test split and the function
    returns. Otherwise the model trains through epoch ``epoch``, and a
    checkpoint is saved each time the validation loss improves.

    Raises:
        FileNotFoundError: if the config file does not exist.
        ValueError: if the configured optimizer is not 'adam' or 'sgd'.
    """
    if not os.path.exists(args.conf):
        # Fail fast instead of printing and continuing into read_config.
        raise FileNotFoundError('Config not found: ' + args.conf)
    config = read_config(args.conf)

    print('Initializing parameters')
    template_file_path = config['template_fname']
    template_mesh = Mesh(filename=template_file_path)

    checkpoint_dir = args.checkpoint_dir or config['checkpoint_dir']
    os.makedirs(checkpoint_dir, exist_ok=True)

    visualize = config['visualize']
    output_dir = config['visual_output_dir']
    if visualize and not output_dir:
        print('No visual output directory is provided. Checkpoint directory will be used to store the visual results')
        output_dir = checkpoint_dir
    # output_dir may legitimately be empty when visualization is off;
    # os.makedirs('') / os.makedirs(None) would raise.
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    eval_flag = config['eval']
    lr = config['learning_rate']
    lr_decay = config['learning_rate_decay']
    weight_decay = config['weight_decay']
    total_epochs = config['epoch']
    workers_thread = config['workers_thread']
    opt = config['optimizer']
    batch_size = config['batch_size']
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    print('Generating transforms')
    M, A, D, U = mesh_operations.generate_transform_matrices(template_mesh, config['downsampling_factors'])
    D_t = [scipy_to_torch_sparse(d).to(device) for d in D]
    U_t = [scipy_to_torch_sparse(u).to(device) for u in U]
    A_t = [scipy_to_torch_sparse(a).to(device) for a in A]
    num_nodes = [len(mesh.v) for mesh in M]

    print('Loading Dataset')
    data_dir = args.data_dir or config['data_dir']
    normalize_transform = Normalize()
    dataset = ComaDataset(data_dir, dtype='train', split=args.split, split_term=args.split_term, pre_transform=normalize_transform)
    dataset_test = ComaDataset(data_dir, dtype='test', split=args.split, split_term=args.split_term, pre_transform=normalize_transform)
    train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=workers_thread)
    test_loader = DataLoader(dataset_test, batch_size=1, shuffle=False, num_workers=workers_thread)

    print('Loading model')
    start_epoch = 1
    coma = Coma(dataset, config, D_t, U_t, A_t, num_nodes)
    # Move the model before constructing the optimizer, as the torch.optim
    # docs recommend. optimizer.load_state_dict() then puts restored state on
    # the parameters' device, so no manual state-moving loop is needed.
    coma.to(device)
    optimizer = _build_optimizer(coma, opt, lr, weight_decay)

    checkpoint_file = config['checkpoint_file']
    print(checkpoint_file)
    if checkpoint_file:
        start_epoch = _load_checkpoint(checkpoint_file, coma, optimizer, device)

    if eval_flag:
        val_loss = evaluate(coma, output_dir, test_loader, dataset_test, template_mesh, device, visualize)
        print('val loss', val_loss)
        return

    best_val_loss = float('inf')
    for epoch in range(start_epoch, total_epochs + 1):
        print("Training for epoch ", epoch)
        train_loss = train(coma, train_loader, len(dataset), optimizer, device)
        val_loss = evaluate(coma, output_dir, test_loader, dataset_test, template_mesh, device, visualize=visualize)
        print('epoch ', epoch, ' Train loss ', train_loss, ' Val loss ', val_loss)
        # Only checkpoint when validation loss improves on the best seen so far.
        if val_loss < best_val_loss:
            save_model(coma, optimizer, epoch, train_loss, val_loss, checkpoint_dir)
            best_val_loss = val_loss
        # Per-epoch learning-rate decay is applied only for SGD.
        if opt == 'sgd':
            adjust_learning_rate(optimizer, lr_decay)
        if torch.cuda.is_available():
            torch.cuda.synchronize()
def train(coma, train_loader, len_dataset, optimizer, device):
    """Run one training epoch and return the average per-graph L1 loss.

    Each batch's mean L1 loss is weighted by ``data.num_graphs`` before
    accumulation. The epoch total is divided by ``len_dataset``.
    """
    coma.train()
    running = 0
    for batch in train_loader:
        batch = batch.to(device)
        optimizer.zero_grad()
        prediction = coma(batch)
        batch_loss = F.l1_loss(prediction, batch.y)
        running += batch.num_graphs * batch_loss.item()
        batch_loss.backward()
        optimizer.step()
    return running / len_dataset
def evaluate(coma, output_dir, test_loader, dataset, template_mesh, device, visualize=False):
    """Return the average per-graph L1 loss of ``coma`` over ``test_loader``.

    When ``visualize`` is truthy, every 100th batch is rendered. Its
    prediction and target are de-normalized with ``dataset.std`` and
    ``dataset.mean``, shown side by side in a psbody ``MeshViewers`` grid,
    and a snapshot is saved to ``output_dir/file<i>.png``.

    Args:
        coma: model to evaluate; it is switched to eval mode.
        output_dir: directory for snapshots (only used when visualizing).
        test_loader: iterable of batches exposing ``to``, ``y`` and ``num_graphs``.
        dataset: the evaluated dataset. ``len(dataset)`` normalizes the loss,
            and ``std``/``mean`` are read only when visualizing.
        template_mesh: mesh whose faces ``f`` are reused for rendering.
        device: device the batches are moved to.
        visualize: whether to render and save snapshots.

    Returns:
        Sum over batches of ``num_graphs * batch_loss``, divided by ``len(dataset)``.
    """
    coma.eval()
    total_loss = 0
    meshviewer = None
    if visualize:
        # Only create the viewer when something is rendered. It was previously
        # constructed on every call; NOTE(review): MeshViewers presumably opens
        # a GUI window/process, which is wasted (or fails) when not visualizing.
        meshviewer = MeshViewers(shape=(1, 2))
        std = dataset.std.numpy()
        mean = dataset.mean.numpy()
    for i, data in enumerate(test_loader):
        data = data.to(device)
        with torch.no_grad():
            out = coma(data)
            loss = F.l1_loss(out, data.y)
        total_loss += data.num_graphs * loss.item()

        if visualize and i % 100 == 0:
            save_out = out.detach().cpu().numpy() * std + mean
            expected_out = data.y.detach().cpu().numpy() * std + mean
            result_mesh = Mesh(v=save_out, f=template_mesh.f)
            expected_mesh = Mesh(v=expected_out, f=template_mesh.f)
            meshviewer[0][0].set_dynamic_meshes([result_mesh])
            meshviewer[0][1].set_dynamic_meshes([expected_mesh])
            meshviewer[0][0].save_snapshot(os.path.join(output_dir, 'file' + str(i) + '.png'), blocking=False)
    return total_loss / len(dataset)
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Pytorch Trainer for Convolutional Mesh Autoencoders')
    parser.add_argument('-c', '--conf', help='path of config file')
    parser.add_argument('-s', '--split', default='sliced', help='split can be sliced, expression or identity ')
    parser.add_argument('-st', '--split_term', default='sliced', help='split term can be sliced, expression name '
                                                                      'or identity name')
    parser.add_argument('-d', '--data_dir', help='path where the downloaded data is stored')
    parser.add_argument('-cp', '--checkpoint_dir', help='path where checkpoints file need to be stored')
    args = parser.parse_args()

    if args.conf is None:
        # Fall back to default.cfg next to this script. This is the script's
        # directory, not the current working directory, so the message says so.
        args.conf = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'default.cfg')
        print('configuration file not specified, trying to load '
              'it from the script directory', args.conf)

    main(args)