-
Notifications
You must be signed in to change notification settings - Fork 4
/
train.py
executable file
·98 lines (79 loc) · 2.61 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
#!/usr/bin/env python
""" PointAR training script
"""
import os
import fire
import torch
import pytorch_lightning as pl
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from model.pointconv import PointAR
from trainer.utils import train_valid_test_split
from datasets.pointar.loader import PointARTestDataset
from datasets.pointar.loader import PointARTrainDataset
# from datasets.pointar.loader import PointARTrainD10Dataset
class ModelSavingCallback(pl.Callback):
    """Lightning callback that saves a full trainer checkpoint every epoch.

    Checkpoints are written to ``./dist/model_dumps/<epoch>.ckpt``.
    """

    def __init__(self, sample_input):
        # sample_input is stored but not used by the checkpointing logic
        # itself; presumably kept for model export/tracing — TODO confirm.
        super().__init__()
        self.sample_input = sample_input

    def on_epoch_end(self, trainer, pl_module):
        dump_path = './dist/model_dumps'
        # os.makedirs replaces the previous `os.system('mkdir -p ...')`
        # shell call: portable, no subprocess spawn, no shell interpolation.
        os.makedirs(dump_path, exist_ok=True)
        trainer.save_checkpoint(f'{dump_path}/{trainer.current_epoch}.ckpt')
def train(debug=False,
          use_hdr=True,
          normalize=False,
          n_points=1280,
          num_workers=16,
          batch_size=32):
    """Train the PointAR model.

    Parameters
    ----------
    debug : bool
        Set debugging flag; when True the test dataset is reused for
        training so iterations are fast
    use_hdr : bool
        Use HDR SH coefficients data for training
    normalize : bool
        Normalize SH coefficients with the fitted scaler
    n_points : int
        Number of model input points, default 1280
    num_workers : int
        Number of workers for loading data, default 16
    batch_size : int
        Training batch size, default 32
    """
    # Pick dataset classes: debug mode trains on the (smaller) test split
    test_cls = PointARTestDataset
    train_cls = test_cls if debug else PointARTrainDataset

    # Build the train/valid/test loaders plus the (optional) SH scaler
    ds_kwargs = {'use_hdr': use_hdr}
    loaders, scaler = train_valid_test_split(
        train_cls, ds_kwargs,
        test_cls, ds_kwargs,
        normalize=normalize, num_workers=num_workers, batch_size=batch_size)
    train_loader, valid_loader, test_loader = loaders

    # min/scale let the model undo normalization; identity when disabled
    if normalize:
        shc_min = torch.from_numpy(scaler.min_)
        shc_scale = torch.from_numpy(scaler.scale_)
    else:
        shc_min = torch.zeros((27))
        shc_scale = torch.ones((27))

    model = PointAR(hparams={
        'n_shc': 27,
        'n_points': n_points,
        'min': shc_min,
        'scale': shc_scale
    })

    # Dummy input pair handed to the checkpoint callback (GPU resident,
    # matching the single-GPU trainer below)
    sample_input = tuple(
        torch.zeros((1, 3, n_points)).float().cuda() for _ in range(2))

    trainer = pl.Trainer(
        gpus=1,
        check_val_every_n_epoch=1,
        callbacks=[
            ModelSavingCallback(sample_input=sample_input),
            EarlyStopping(monitor='valid_shc_mse')
        ])

    # Validation runs against both the valid and the held-out test split
    trainer.fit(
        model,
        train_dataloader=train_loader,
        val_dataloaders=[valid_loader, test_loader])
# CLI entry point: Fire exposes train()'s keyword arguments as command-line
# flags, e.g. `python train.py --debug=True --batch_size=64`.
if __name__ == '__main__':
    fire.Fire(train)