-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathpointnet_baseline.py
71 lines (64 loc) · 3.21 KB
/
pointnet_baseline.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import warnings
warnings.filterwarnings("ignore")
from models.pointnet import PointNetLoading, PointNetPersistenceLoading, Pointnet_plus, PointNetSpaceGMLoading
from models.gnn import GCN, GIN, GAT, SAGE
from argparse import ArgumentParser
import torch
from torchinfo import summary
from tqdm import tqdm
import numpy as np
# Command-line interface for this script. Defaults and types are unchanged;
# only help texts were corrected.
# NOTE(review): within this file only raw_dir, label_name, batch_size, lr,
# num_epochs and gpu are read; the remaining options are accepted but ignored.
parser = ArgumentParser(description="PointNet baseline")
parser.add_argument('--raw_dir', type=str, default='dfci',
                    help="Directory where the raw data is stored")
parser.add_argument('--label_name', type=str, default='pTR_label',
                    help="Name of the target label passed to the data loader")
parser.add_argument('--full', action='store_true',
                    help="Currently unused by this script")
parser.add_argument('--task', type=str, default='treatment', help="Task on PDO data")
parser.add_argument('--model', type=str, default='GCN',
                    help="Model name (currently unused: this script always trains Pointnet_plus)")
parser.add_argument('--hidden_dim', type=int, default=150, help="Hidden dim for the MLP")
parser.add_argument('--num_layers', type=int, default=3, help="Number of MLP layers")
parser.add_argument('--batch_size', type=int, default=32, help="Batch size")
parser.add_argument('--num_neighbors', type=int, default=5,
                    help="Number of neighbors for KNN graph")
parser.add_argument('--lr', type=float, default=1e-3, help="Learning rate")
parser.add_argument('--wd', type=float, default=3e-4, help="Weight decay")
parser.add_argument('--num_epochs', type=int, default=200, help="Number of epochs")
# A negative index (e.g. -1) requests the CPU.
parser.add_argument('--gpu', type=int, default=0, help="GPU index")
def train(model, epochs):
    """Train ``model`` and return the test accuracy at the best validation epoch.

    After every training epoch the model is evaluated (under ``torch.no_grad``)
    on the validation and test loaders. The returned value is the test accuracy
    of the most recent epoch whose validation accuracy was >= the best seen.

    Reads the module-level ``train_loader``, ``val_loader`` and ``test_loader``
    (bound in this script's ``__main__`` block). Every item they yield must
    expose a ``.batch`` attribute (presumably PyG batches — confirm).

    Args:
        model: object providing ``configure_optimizers()`` (assumed to return a
            single optimizer), ``training_step(data, batch)`` returning a loss
            tensor, ``validation_step``/``test_step``,
            ``on_validation_epoch_end()``/``on_test_epoch_end()`` returning
            accuracies, and ``train()``/``eval()``.
        epochs: number of passes over ``train_loader``.

    Returns:
        The selected test accuracy as a Python number; ``0.0`` if no epoch
        ever updated the selection (e.g. ``epochs == 0``).
    """
    best_acc = 0
    best_val_acc = 0
    opt = model.configure_optimizers()
    with tqdm(range(epochs)) as tq:
        for _ in tq:
            # The evaluation phase below calls model.eval(); switch back every
            # epoch, otherwise all epochs after the first would train in eval
            # mode (affects dropout / batch-norm, if the model uses them).
            model.train()
            t_loss = 0  # sum of per-batch training losses for this epoch
            for data in train_loader:
                opt.zero_grad()
                loss = model.training_step(data, data.batch)
                loss.backward()
                opt.step()
                t_loss += loss.item()
            model.eval()
            with torch.no_grad():
                for data in val_loader:
                    model.validation_step(data, data.batch)
                val_acc = model.on_validation_epoch_end()
                for data in test_loader:
                    model.test_step(data, data.batch)
                test_acc = model.on_test_epoch_end()
            # ">=" so that ties in validation accuracy favour the later epoch.
            if val_acc >= best_val_acc:
                best_val_acc = val_acc
                best_acc = test_acc
            tq.set_description(
                "Train loss = %.4f, Val acc = %.4f, Best val acc = %.4f, Best acc = %.4f"
                % (t_loss, val_acc, best_val_acc, best_acc)
            )
    # best_acc is still the int 0 if it was never updated; only tensors have .item().
    return best_acc.item() if torch.is_tensor(best_acc) else float(best_acc)
# Parse the command line at module level; the __main__ block reads the
# resulting ``args``.
# NOTE(review): this runs on import too, so importing this module parses sys.argv.
args = parser.parse_args()
# Use the requested CUDA device only for a non-negative --gpu index when CUDA
# is available; any negative index (conventionally -1) selects the CPU instead
# of producing an invalid device string such as 'cuda:-2'.
if args.gpu >= 0 and torch.cuda.is_available():
    args.device = 'cuda:{}'.format(args.gpu)
else:
    args.device = 'cpu'
if __name__ == '__main__':
    # Repeat the full load/build/train cycle five times and report the mean
    # and (population) standard deviation of the selected test accuracies.
    print(args)
    run_scores = []
    for _run in range(5):
        # Bound at module scope on purpose: train() reads train_loader,
        # val_loader and test_loader as globals.
        (train_loader, val_loader, test_loader,
         input_dim, num_classes) = PointNetSpaceGMLoading(
            args.raw_dir, args.label_name, args.batch_size, args.device
        )
        model = Pointnet_plus(1, input_dim, num_classes, args.lr)
        model = model.to(args.device)
        model.train()
        run_scores.append(train(model, args.num_epochs))
    run_scores = np.array(run_scores)
    print(f"Average performance: {run_scores.mean()}, std: {run_scores.std()}")