train.py
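"""Training entry point for the glimpse_mae experiments.

Builds the data module and model via utils.prepare.experiment_from_args, then
trains with PyTorch Lightning. CLI flags defined below toggle FP16 mixed
precision, Weights & Biases / TensorBoard logging, and the DDP strategy.
"""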
import argparse
import platform
import random
import sys
import time

import torch
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint, RichProgressBar, RichModelSummary
from pytorch_lightning.loggers import WandbLogger, TensorBoardLogger
from pytorch_lightning.plugins import NativeMixedPrecisionPlugin
from pytorch_lightning.strategies import DDPStrategy
from torch.cuda.amp import GradScaler

from utils.prepare import experiment_from_args

# Fix random seeds for reproducibility.
random.seed(1)
torch.manual_seed(1)


def define_args(parent_parser):
    parser = parent_parser.add_argument_group('train.py')
    parser.add_argument('--use-fp16',
                        help='set model precision to FP16 (default is FP32)',
                        action='store_true',
                        default=False)
    parser.add_argument('--load-model-path',
                        help='load model weights from a checkpoint path',
                        type=str,
                        default=None)
    parser.add_argument('--wandb',
                        help='log to Weights & Biases (disable with --no-wandb)',
                        default=True,
                        action=argparse.BooleanOptionalAction)
    parser.add_argument('--tensorboard',
                        help='log to TensorBoard',
                        default=False,
                        action=argparse.BooleanOptionalAction)
    parser.add_argument('--ddp',
                        help='use the DDP acceleration strategy',
                        default=False,
                        action=argparse.BooleanOptionalAction)
    parser.add_argument('--name',
                        help='experiment name',
                        type=str,
                        default=None)
    return parent_parser


def main():
    data_module, model, args = experiment_from_args(sys.argv, add_argparse_args_fn=define_args)

    # Optional FP16 mixed-precision training.
    plugins = []
    if args.use_fp16:
        grad_scaler = GradScaler()
        plugins += [NativeMixedPrecisionPlugin(precision=16, device='cuda', scaler=grad_scaler)]

    # Derive a run name from the timestamp and host name unless one was given.
    run_name = args.name
    if run_name is None:
        run_name = f'{time.strftime("%Y-%m-%d_%H:%M:%S")}-{platform.node()}'
    print('Run name:', run_name)

    loggers = []
    if args.tensorboard:
        loggers.append(TensorBoardLogger(save_dir='logs/', name=run_name))
    if args.wandb:
        loggers.append(WandbLogger(project='glimpse_mae', entity="ideas_cv", name=run_name))

    # Keep the checkpoint with the best validation loss.
    checkpoint_callback = ModelCheckpoint(dirpath=f"checkpoints/{run_name}", monitor="val/loss")
    trainer = Trainer(plugins=plugins, max_epochs=args.epochs, accelerator='auto', logger=loggers,
                      callbacks=[checkpoint_callback, RichProgressBar(leave=True), RichModelSummary(max_depth=3)],
                      enable_model_summary=False,
                      strategy=DDPStrategy(find_unused_parameters=False) if args.ddp else None)
    trainer.fit(model=model, datamodule=data_module, ckpt_path=args.load_model_path)

    # Evaluate the best checkpoint if the data module provides a test split.
    if data_module.has_test_data:
        trainer.test(ckpt_path='best', datamodule=data_module)


if __name__ == "__main__":
    main()
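
# Example invocation (a sketch, not from the repository docs): the flags defined in
# define_args above can be combined with whatever experiment_from_args registers,
# e.g. --epochs, which the Trainer reads but which is not defined in this file.
#
#   python train.py --use-fp16 --no-wandb --tensorboard --ddp --name debug-run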