-
Notifications
You must be signed in to change notification settings - Fork 8
/
Copy pathmain.py
114 lines (96 loc) · 5.09 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
"""
@Origin : main.py by Yue Wang
@Contact: [email protected]
@Time: 2018/10/13 10:39 PM
modified by {Sanghyeok Lee, Sihyeon Kim}
@Contact: {cat0626, sh_bs15}@korea.ac.kr
@File: main.py
@Time: 2021.09.30
"""
from __future__ import print_function
import os
import argparse
import torch
from util import IOStream
from train import train_vanilla, train_AugTune, test
def _init_():
    """Prepare the checkpoint directory for the current experiment.

    Creates ``checkpoints/<exp_name>/models`` (with any missing parents) and
    stores a ``<file>.backup`` snapshot of the project's source files in
    ``checkpoints/<exp_name>/``.  Source files are looked up relative to the
    current working directory.

    Reads the module-level ``args`` namespace (``args.exp_name``), so it must
    be called after the command line has been parsed.

    The snapshot is best-effort: a file that cannot be copied is reported on
    stderr and skipped, it never aborts the run.
    """
    # Function-scope imports keep this block self-contained.
    import shutil
    import sys

    # Built by concatenation (not os.path.join) so the directory always
    # matches the 'checkpoints/' + exp_name path the run log is written to.
    exp_dir = 'checkpoints/' + args.exp_name

    # exist_ok=True creates the whole chain in one call and avoids the
    # check-then-create race of a separate os.path.exists() test.
    os.makedirs(exp_dir + '/' + 'models', exist_ok=True)

    # Copy in-process instead of shelling out to `cp`: exp_name comes from the
    # command line and must not be interpolated into a shell command, and
    # `cp` does not exist on every platform.
    source_files = ('main.py', 'model.py', 'util.py', 'data.py',
                    'PointWOLF.py', 'train.py')
    for src in source_files:
        try:
            shutil.copy(src, exp_dir + '/' + src + '.backup')
        except OSError as err:
            print('Could not back up %s: %s' % (src, err), file=sys.stderr)
def str2bool(value):
    """Convert a command-line token into a bool (for argparse ``type=``).

    ``type=bool`` is not usable with argparse because ``bool('False')`` is
    ``True``: every non-empty string is truthy.  This converter interprets the
    text instead.

    Args:
        value: A bool (returned unchanged) or a string such as ``'True'``,
            ``'false'``, ``'1'``, ``'0'``, ``'yes'``, ``'no'``.  Matching is
            case-insensitive and ignores surrounding whitespace.  The empty
            string maps to ``False``.

    Returns:
        The corresponding bool.

    Raises:
        argparse.ArgumentTypeError: If the token is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    token = str(value).strip().lower()
    if token in ('true', 't', 'yes', 'y', '1'):
        return True
    if token in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError('Boolean value expected, got %r' % (value,))


if __name__ == "__main__":
    # Training settings
    parser = argparse.ArgumentParser(description='Point Cloud Recognition')
    parser.add_argument('--exp_name', type=str, default='exp_PointWOLF', metavar='N',
                        help='Name of the experiment')
    parser.add_argument('--model', type=str, default='dgcnn', metavar='N',
                        choices=['pointnet', 'dgcnn'],
                        help='Model to use, [pointnet, dgcnn]')
    parser.add_argument('--dataset', type=str, default='modelnet40', metavar='N',
                        choices=['modelnet40'])
    parser.add_argument('--batch_size', type=int, default=32, metavar='batch_size',
                        help='Size of batch)')
    parser.add_argument('--test_batch_size', type=int, default=16, metavar='batch_size',
                        help='Size of batch)')
    parser.add_argument('--epochs', type=int, default=250, metavar='N',
                        help='number of episode to train ')
    # Boolean options go through str2bool so that "--use_sgd False" really
    # yields False.  nargs='?' with const=True additionally accepts the bare
    # flag form ("--eval") as True.
    parser.add_argument('--use_sgd', type=str2bool, nargs='?', const=True, default=True,
                        help='Use SGD')
    parser.add_argument('--lr', type=float, default=0.001, metavar='LR',
                        help='learning rate (default: 0.001, 0.1 if using sgd)')
    parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
                        help='SGD momentum (default: 0.9)')
    parser.add_argument('--no_cuda', type=str2bool, nargs='?', const=True, default=False,
                        help='disables CUDA training')
    parser.add_argument('--seed', type=int, default=1, metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument('--eval', type=str2bool, nargs='?', const=True, default=False,
                        help='evaluate the model')
    parser.add_argument('--num_points', type=int, default=1024,
                        help='num of points to use')
    parser.add_argument('--dropout', type=float, default=0.5,
                        help='dropout rate')
    parser.add_argument('--emb_dims', type=int, default=1024, metavar='N',
                        help='Dimension of embeddings')
    parser.add_argument('--k', type=int, default=20, metavar='N',
                        help='Num of nearest neighbors to use')
    parser.add_argument('--model_path', type=str, default='', metavar='N',
                        help='Pretrained model path')

    # PointWOLF settings
    parser.add_argument('--PointWOLF', action='store_true', help='Use PointWOLF')
    parser.add_argument('--w_num_anchor', type=int, default=4, help='Num of anchor point')
    parser.add_argument('--w_sample_type', type=str, default='fps',
                        help='Sampling method for anchor point, option : (fps, random)')
    parser.add_argument('--w_sigma', type=float, default=0.5, help='Kernel bandwidth')
    parser.add_argument('--w_R_range', type=float, default=10,
                        help='Maximum rotation range of local transformation')
    parser.add_argument('--w_S_range', type=float, default=3,
                        help='Maximum scailing range of local transformation')
    parser.add_argument('--w_T_range', type=float, default=0.25,
                        help='Maximum translation range of local transformation')

    # AugTune settings
    parser.add_argument('--AugTune', action='store_true', help='Use AugTune')
    parser.add_argument('--l', type=float, default=0.1, help='Difficulty parameter lambda')

    # `args` must stay a module-level name: _init_() reads it as a global.
    args = parser.parse_args()

    # Create the checkpoint directory before opening the run log inside it.
    _init_()

    io = IOStream('checkpoints/' + args.exp_name + '/run.log')
    io.cprint(str(args))

    # CUDA is used only when it is both requested and actually available.
    args.cuda = not args.no_cuda and torch.cuda.is_available()
    torch.manual_seed(args.seed)
    if args.cuda:
        io.cprint(
            'Using GPU : ' + str(torch.cuda.current_device()) + ' from ' + str(torch.cuda.device_count()) + ' devices')
        torch.cuda.manual_seed(args.seed)
    else:
        io.cprint('Using CPU')

    # Dispatch: evaluation, AugTune training, or plain training.
    if not args.eval:
        if args.AugTune:
            train_AugTune(args, io)
        else:
            train_vanilla(args, io)
    else:
        test(args, io)