train.py
# %%
import pytorch_lightning as pl
from nemo.collections.asr.models import EncDecCTCModel
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger
# from pytorch_lightning.loggers.wandb import WandbLogger
from utils.data import DataLoader, QASRDataset
from utils import get_config
# %%
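# Load training hyper-parameters and file paths from the YAML config.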
config = get_config('./configs/train.yaml')
print(config)
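# The attribute accesses below imply a config shaped roughly like the
# following sketch (keys inferred from this script; values are placeholders,
# not taken from the repo):
#
#   start_from: ./pretrained/quartznet.nemo
#   training_labels: ./data/train.tsv
#   validation_labels: ./data/dev.tsv
#   batch_size: 32
#   val_batch_size: 32
#   logs_dir: ./logs
#   checkpoint_dir: ./checkpoints
#   n_save_ckpt: 1000
#   max_epochs: 100
#   log_every_n_steps: 50
#   val_check_interval: 1000
#   resume_from: null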
# %%
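# Restore the pretrained CTC checkpoint on CPU, move the model to the GPU,
# and report CER instead of WER (via NeMo's private _wer metric).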
model = EncDecCTCModel.restore_from(config.start_from, map_location='cpu')
model.cuda()
model._wer.use_cer = True
# %%
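# Build the QASR training/validation datasets and wrap them in DataLoaders
# using each dataset's own collate function.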
ds_train = QASRDataset(ds_fpath=config.training_labels, voc=True)
ds_val = QASRDataset(ds_fpath=config.validation_labels, voc=True)
dl_train = DataLoader(ds_train, batch_size=config.batch_size,
                      shuffle=False,
                      collate_fn=ds_train._collate_fn)
dl_val = DataLoader(ds_val, batch_size=config.val_batch_size,
                    collate_fn=ds_val._collate_fn)
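# Attach the dataloaders to the model through its (private) NeMo attributes.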
model._train_dl = dl_train
model._validation_dl = dl_val
# %%
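# Log training metrics to TensorBoard (a Weights & Biases logger is left
# commented out as an alternative).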
tb_logger = TensorBoardLogger(config.logs_dir, name=None, version='')
# wb_logger = WandbLogger(project='stt-quartznet-ar')
# %%
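# Checkpoint callbacks: clb_last saves only the latest weights every
# n_save_ckpt training steps; clb_valid keeps the three best checkpoints
# ranked by validation loss.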
clb_last = ModelCheckpoint(config.checkpoint_dir,
                           every_n_train_steps=config.n_save_ckpt,
                           save_last=True, save_top_k=0)
clb_valid = ModelCheckpoint(config.checkpoint_dir,
                            filename="states_{val_loss:.5f}",
                            save_top_k=3, monitor='val_loss', mode='min')
# %%
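# Configure the Lightning trainer and register it with the model so that
# NeMo logs through it.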
trainer = pl.Trainer(
    max_epochs=config.max_epochs,
    log_every_n_steps=config.log_every_n_steps,
    val_check_interval=config.val_check_interval,
    logger=tb_logger,
    # logger=wb_logger,
    default_root_dir=config.checkpoint_dir,
    callbacks=[clb_valid, clb_last]
)
model.set_trainer(trainer)
# %%
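# Run training; ckpt_path lets a previous Lightning checkpoint be resumed
# (None starts from scratch).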
trainer.fit(model, ckpt_path=config.resume_from)
# %%