train.py
import logging
from typing import Callable
import torch
import os.path as osp
import signal
import sys
import traceback
from argparse import Namespace

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint

import log
from callbacks import (
    CheckpointCustomFilename,
    SaveOnKeyboardInterrupt,
    CheckpointEveryNSteps,
)
from models import find_model_using_name
from options.test_options import TestOptions
from options.train_options import TrainOptions
from util import str2num

logger = log.setup_custom_logger("logger")

# DDP requires setting the manual seed
# https://pytorch-lightning.readthedocs.io/en/latest/multi_gpu.html#distributed-data-parallel
torch.manual_seed(420)


def main(train=True):
    """ Runs train or test """
    options_obj = TrainOptions() if train else TestOptions()
    opt = options_obj.parse()
    logger.setLevel(getattr(logging, opt.loglevel.upper()))

    model_class = find_model_using_name(opt.model)
    if opt.checkpoint:
        model = model_class.load_from_checkpoint(
            # TODO: we have to manually override all TestOptions for hparams in
            # __init__, because they're not present in the checkpoint's train options.
            # We should find a better solution
            opt.checkpoint
        )
        logger.info(
            f"RESUMED {model_class.__name__} from checkpoint: {opt.checkpoint}"
        )
    else:
        model = model_class(opt)
        logger.info(f"INITIALIZED new {model_class.__name__}")
    model.override_hparams(opt)

    trainer = Trainer(
        resume_from_checkpoint=opt.checkpoint if opt.checkpoint else None,
        **get_hardware_kwargs(opt),
        **get_train_kwargs(opt),
        profiler=True,
    )

    if train:
        save_on_interrupt = make_save_on_interrupt(trainer)
        try:
            trainer.fit(model)
        except Exception as e:
            logger.warning(f"Caught a {type(e)}!")
            logger.error(traceback.format_exc())
            save_on_interrupt(name=e.__class__.__name__)
    else:
        print("Testing........")
        print(opt)
        trainer.test(model)

    logger.info(f"Finished {opt.model}, named {opt.name}!")


def get_hardware_kwargs(opt):
    """ Hardware kwargs for the Trainer """
    hardware_kwargs = vars(
        Namespace(
            gpus=opt.gpu_ids,
            distributed_backend=opt.distributed_backend,
            precision=opt.precision,
        )
    )
    return hardware_kwargs
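
# Note: vars(Namespace(...)) above is just a readable way to build a plain dict; it
# is equivalent to writing the dict literal directly. With hypothetical flag values
# such as --gpu_ids 0,1 --distributed_backend ddp --precision 16, the returned
# mapping would look roughly like:
#   {"gpus": "0,1", "distributed_backend": "ddp", "precision": 16}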


def get_train_kwargs(opt):
    """
    Return Trainer kwargs specific to training if opt.is_train is True.
    Otherwise return an empty dict.
    """
    if not opt.is_train:
        return {}
    train_kwargs = vars(
        Namespace(
            # Checkpointing
            checkpoint_callback=ModelCheckpoint(save_top_k=5, verbose=True),
            callbacks=[
                # CheckpointCustomFilename(),
                CheckpointEveryNSteps(opt.save_count, prefix=opt.model, verbose=True),
            ],
            default_root_dir=osp.join(opt.experiments_dir, opt.name),
            log_save_interval=opt.display_count,
            # Training and data
            accumulate_grad_batches=opt.accumulated_batches,
            max_epochs=opt.keep_epochs + opt.decay_epochs,
            val_check_interval=str2num(opt.val_check_interval),
            # see https://pytorch-lightning.readthedocs.io/en/latest/trainer.html#replace-sampler-ddp
            replace_sampler_ddp=False,
            limit_train_batches=str2num(opt.limit_train_batches),
            limit_val_batches=str2num(opt.limit_val_batches),
            # Debug
            fast_dev_run=opt.fast_dev_run,
        )
    )
    return train_kwargs
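
# Note: Lightning interprets val_check_interval and limit_{train,val}_batches as a
# fraction of the data when given a float (e.g. 0.25) and as an absolute number of
# batches when given an int (e.g. 200). str2num (from util) presumably converts the
# raw CLI string into whichever numeric type was passed, so both forms work here.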


def make_save_on_interrupt(trainer: Trainer) -> Callable:
    """ On interrupt, will save checkpoint """

    def save_on_interrupt(*args, name=""):
        name = f"interrupted_by_{name}" if name else "interrupted_by_Ctrl-C"
        try:
            ckpt_path = osp.join(trainer.checkpoint_callback.dirpath, f"{name}.ckpt")
            logger.warning(
                "Training stopped prematurely. "
                f"Saving Trainer checkpoint to: {ckpt_path}"
            )
            trainer.save_checkpoint(ckpt_path)
        except:
            logger.warning(
                "No checkpoint to save. Either training didn't start, or I'm a "
                "child process."
            )
        exit()

    signal.signal(signal.SIGINT, save_on_interrupt)
    return save_on_interrupt
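
# Note: the handler is registered for SIGINT, so a Ctrl-C during trainer.fit() saves
# an "interrupted_by_Ctrl-C.ckpt" before exiting; main() also calls the returned
# function directly when any other exception escapes trainer.fit().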


if __name__ == "__main__":
    main(train=True)
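
# Example invocation (flag names here are assumptions for illustration; the actual
# CLI flags are defined by options/train_options.py and options/test_options.py):
#   python train.py --name my_experiment --model <model_name> --gpu_ids 0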