# train.py (forked from JusperLee/Apollo)
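"""Training entry point for Apollo.

Builds the datamodule, generator, discriminator, losses, metrics and the
Lightning system from a Hydra/OmegaConf config, trains with PyTorch Lightning
under DDP, then dumps the best-k checkpoint scores and a serialized copy of
the best audio model into the experiment directory.
"""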
import argparse
import json
import os
import warnings
from typing import List, Optional

import hydra
import pytorch_lightning as pl
import torch
from omegaconf import DictConfig, OmegaConf
from pytorch_lightning import Callback, LightningDataModule, LightningModule, Trainer
from pytorch_lightning.strategies.ddp import DDPStrategy

from look2hear.utils import print_only

def train(cfg: DictConfig, model_path: Optional[str] = None) -> None:
    # Instantiate the datamodule.
    print_only(f"Instantiating datamodule <{cfg.datas._target_}>")
    datamodule: LightningDataModule = hydra.utils.instantiate(
        cfg.datas, expdir=os.path.join(cfg.exp.dir, cfg.exp.name)
    )

    # Instantiate the audio model (generator).
    print_only(f"Instantiating AudioNet <{cfg.model._target_}>")
    model: torch.nn.Module = hydra.utils.instantiate(cfg.model)

    # When resuming, load the generator weights. The system checkpoint stores
    # them under an "audio_model." prefix, so strip it before loading.
    if model_path:
        print_only(f"Loading model state dict from <{model_path}>")
        ckpt = torch.load(model_path, weights_only=False)
        state_dict = ckpt["state_dict"]
        for k in list(state_dict.keys()):
            if "audio_model" in k:
                new_k = k.replace("audio_model.", "")
                state_dict[new_k] = state_dict[k]
                del state_dict[k]
        model.load_state_dict(state_dict)
print_only(f"Instantiating Discriminator <{cfg.discriminator._target_}>")
discriminator: torch.nn.Module = hydra.utils.instantiate(cfg.discriminator)
# instantiate optimizer
print_only(f"Instantiating optimizer <{cfg.optimizer_g._target_}>")
optimizer_g: torch.optim = hydra.utils.instantiate(cfg.optimizer_g, params=model.parameters())
optimizer_d: torch.optim = hydra.utils.instantiate(cfg.optimizer_d, params=discriminator.parameters())
# instantiate scheduler
print_only(f"Instantiating scheduler <{cfg.scheduler_g._target_}>")
scheduler_g: torch.optim.lr_scheduler = hydra.utils.instantiate(cfg.scheduler_g, optimizer=optimizer_g)
scheduler_d: torch.optim.lr_scheduler = hydra.utils.instantiate(cfg.scheduler_d, optimizer=optimizer_d)
# instantiate loss
print_only(f"Instantiating loss <{cfg.loss_g._target_}>")
loss_g: torch.nn.Module = hydra.utils.instantiate(cfg.loss_g)
loss_d: torch.nn.Module = hydra.utils.instantiate(cfg.loss_d)
losses = {"g": loss_g,"d": loss_d}
# instantiate metrics
print_only(f"Instantiating metrics <{cfg.metrics._target_}>")
metrics: torch.nn.Module = hydra.utils.instantiate(cfg.metrics)
    # Instantiate the LightningModule that wires the pieces together.
    print_only(f"Instantiating system <{cfg.system._target_}>")
    system: LightningModule = hydra.utils.instantiate(
        cfg.system,
        model=model,
        discriminator=discriminator,
        loss_func=losses,
        metrics=metrics,
        optimizer=[optimizer_g, optimizer_d],
        scheduler=[scheduler_g, scheduler_d],
    )

    # When resuming, also restore the full system state (including the discriminator).
    if model_path:
        print_only(f"Loading system state dict from <{model_path}>")
        ckpt = torch.load(model_path, weights_only=False)
        system.load_state_dict(ckpt["state_dict"])
    # Instantiate optional callbacks from the config.
    callbacks: List[Callback] = []
    if cfg.get("early_stopping"):
        print_only(f"Instantiating early_stopping <{cfg.early_stopping._target_}>")
        callbacks.append(hydra.utils.instantiate(cfg.early_stopping))
    if cfg.get("checkpoint"):
        print_only(f"Instantiating checkpoint <{cfg.checkpoint._target_}>")
        checkpoint: pl.callbacks.ModelCheckpoint = hydra.utils.instantiate(cfg.checkpoint)
        callbacks.append(checkpoint)

    # Instantiate the logger.
    print_only(f"Instantiating logger <{cfg.logger._target_}>")
    os.makedirs(os.path.join(cfg.exp.dir, cfg.exp.name, "logs"), exist_ok=True)
    logger = hydra.utils.instantiate(cfg.logger)
    # Instantiate the trainer and run training under DDP.
    print_only(f"Instantiating trainer <{cfg.trainer._target_}>")
    trainer: Trainer = hydra.utils.instantiate(
        config=cfg.trainer,
        callbacks=callbacks,
        logger=logger,
        strategy=DDPStrategy(find_unused_parameters=True),
    )
    trainer.fit(system, datamodule=datamodule)
    print_only("Training finished!")

    # Export results. This assumes the checkpoint callback above was configured:
    # dump the best-k checkpoint scores, then reload the best checkpoint and
    # serialize the bare audio model for inference.
    best_k = {k: v.item() for k, v in checkpoint.best_k_models.items()}
    with open(os.path.join(cfg.exp.dir, cfg.exp.name, "best_k_models.json"), "w") as f:
        json.dump(best_k, f, indent=0)
    state_dict = torch.load(checkpoint.best_model_path)
    system.load_state_dict(state_dict=state_dict["state_dict"])
    system.cpu()
    to_save = system.audio_model.serialize()
    torch.save(to_save, os.path.join(cfg.exp.dir, cfg.exp.name, "best_model.pth"))

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("-c", "--config", type=str, default="configs/apollo.yaml", help="Path to the training config")
    parser.add_argument("-m", "--model", type=str, default=None, help="Path to a checkpoint for resuming training")
    args = parser.parse_args()

    warnings.filterwarnings("ignore")
    torch.set_float32_matmul_precision("high")

    # Load the config and keep a copy in the experiment directory.
    cfg = OmegaConf.load(args.config)
    os.makedirs(os.path.join(cfg.exp.dir, cfg.exp.name), exist_ok=True)
    OmegaConf.save(cfg, os.path.join(cfg.exp.dir, cfg.exp.name, "config.yaml"))

    if cfg.get("seed"):
        pl.seed_everything(cfg.seed, workers=True)

    # Initialize the distributed process group manually before Lightning starts:
    # gloo on Windows (which lacks NCCL), nccl everywhere else.
    devices = cfg.trainer.get("devices", [0])
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "8001"
    os.environ["RANK"] = "0"
    os.environ["WORLD_SIZE"] = str(len(devices))
    torch.distributed.init_process_group(backend="gloo" if os.name == "nt" else "nccl", init_method="env://")

    train(cfg, model_path=args.model)
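
# Example invocations (the config path is the repo default above; the
# checkpoint path below is hypothetical, adjust it to your experiment dir):
#   python train.py --config configs/apollo.yaml
#   python train.py --config configs/apollo.yaml --model exps/apollo/last.ckpt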