-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathtrain.py
29 lines (23 loc) · 1.07 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
from LightingModel import LitModel
import pytorch_lightning as pl
from maskData import MaskDataModule
import config as cfg
from pytorch_lightning.callbacks import ModelCheckpoint
pl.seed_everything(cfg.random_seed)
if __name__ == '__main__':
model = LitModel()
data_module = MaskDataModule(train_dir=cfg.train_dir,
val_dir=cfg.val_dir,
batch_size=cfg.batch_size,
num_workers=cfg.num_workers)
checkpoint_callback = ModelCheckpoint(monitor='val_sum',
save_top_k=10,
mode='max')
trainer = pl.Trainer(max_epochs=cfg.max_epochs,
gpus=1,
num_sanity_val_steps=2,
callbacks=[checkpoint_callback],
limit_train_batches=cfg.limit_train_batches,
limit_val_batches=cfg.limit_val_batches,
gradient_clip_val=10)
trainer.fit(model, datamodule=data_module)