-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathinfer_gan.py
executable file
·65 lines (47 loc) · 1.57 KB
/
infer_gan.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
import os
import logging
import hydra
from omegaconf import OmegaConf
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
def run(args):
    """Restore a trained generator from a checkpoint and run enhancement.

    Builds the model described by ``args.model``, loads the generator
    weights stored in the checkpoint at ``args.load_from``, moves the model
    to ``args.experiment['device']`` and passes it to ``enhance`` together
    with ``args.out_dir``.

    Args:
        args: Run configuration. This function itself reads ``args.model``,
            ``args.load_from``, ``args.experiment['device']`` and
            ``args.out_dir``; the whole object is also forwarded to
            ``enhance`` (see the note at the end of the body for what that
            call is said to require).

    Raises:
        KeyError: If the checkpoint has no ``model -> generator -> state``
            entry (assuming the checkpoint is a nested dict -- TODO confirm).
    """
    # Imports stay function-local, as in the original script.
    import torch

    from body import ModelSelector
    from body.solver.enhance_audio import enhance

    # TODO: scheduler
    torch.manual_seed(1000000007)  # 1e9 + 7

    model = ModelSelector(args.model)()

    # Load the checkpoint onto the CPU first; the model is moved to the
    # target device afterwards.
    # NOTE(review): torch.load unpickles the file, so only load checkpoints
    # from trusted sources.
    package = torch.load(args.load_from, map_location='cpu')
    model.load_state_dict(package['model']['generator']['state'])
    model.to(args.experiment['device'])

    # Now run inference.
    enhance(args, model, args.out_dir)
    # Required in args (per the original author's note for `enhance`):
    # - either args.data.noisy_json or args.data.noisy_dir,
    #   args.noisy_dir or args.noisy_json
    # - args.experiment.num_workers
    # - args.experiment.device
    # - args.experiment.dry
def _main(args):
    """Prepare the module state, log the run location and call ``run``.

    Rebinds the module-level ``__file__`` to the value returned by
    ``hydra.utils.to_absolute_path``, logs the current working directory at
    INFO level and the full configuration at DEBUG level, then delegates to
    ``run(args)``.

    Args:
        args: Run configuration; logged and forwarded unchanged to ``run``.
    """
    # NOTE(review): presumably needed because Hydra may change the working
    # directory, which would invalidate a relative __file__ -- confirm.
    globals()["__file__"] = hydra.utils.to_absolute_path(__file__)

    working_dir = os.getcwd()
    logger.info("For logs, checkpoints and samples check %s", working_dir)
    logger.debug(args)

    run(args)
@hydra.main(version_base=None, config_path="conf")
def main(args):
    """Hydra entry point: create the output directory and run inference.

    Creates ``args.experiment.output_path`` (no error if it already exists)
    and calls ``_main(args)``. Any ``Exception`` raised along the way is
    logged with its traceback and the process is terminated with exit
    status 1.

    Args:
        args: Configuration object supplied by the ``hydra.main`` decorator.
    """
    try:
        output_dir = f"{args.experiment.output_path}"
        os.makedirs(output_dir, exist_ok=True)
        _main(args)
    except Exception:
        logger.exception("Some error happened")
        # os._exit ends the process immediately, skipping normal interpreter
        # cleanup.
        # NOTE(review): presumably deliberate so the process cannot hang on
        # shutdown -- confirm.
        os._exit(1)
# Script entry point: invoke the Hydra-decorated main() only when this file
# is executed directly, not when it is imported.
if __name__ == "__main__":
    main()