-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathutils.py
80 lines (63 loc) · 2.12 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
import logging
import torch
from torch.autograd import Variable
from models.cbhg import CBHGNet
from models.mgru import MinimalGRUNet
from run import Runner
from trainers.timit import TIMITTrainer
def get_logger(name):
    """Return a logger that emits INFO-level messages to the console.

    :param name: logger name (typically the calling module's name)
    :return: the configured ``logging.Logger`` instance
    """
    logger = logging.getLogger(name)
    logger.setLevel(logging.INFO)
    # logging.getLogger returns the SAME object for the same name, so without
    # this guard every call would attach another StreamHandler and each log
    # line would be printed once per call made so far.
    if not logger.handlers:
        ch = logging.StreamHandler()
        ch.setLevel(logging.INFO)
        formatter = logging.Formatter('[%(asctime)s] [%(name)s] [%(levelname)s] %(message)s')
        ch.setFormatter(formatter)
        logger.addHandler(ch)
    return logger
def get_loadable_checkpoint(checkpoint):
    """
    If the model was saved with DataParallel, every checkpoint key is prefixed
    with 'module.'; strip that prefix so the state dict loads into a bare model.

    Only a *leading* 'module.' is removed. The previous implementation used
    ``key.replace('module.', '')``, which also mangled keys that merely
    contain the substring (e.g. 'submodule.weight' -> 'sub.weight').

    :param checkpoint: state dict, possibly saved from a DataParallel model
    :return: new state dict with the DataParallel prefix removed
    """
    prefix = 'module.'
    new_checkpoint = {}
    for key, val in checkpoint.items():
        new_key = key[len(prefix):] if key.startswith(prefix) else key
        new_checkpoint[new_key] = val
    return new_checkpoint
def to_variable(tensor, is_cuda=True):
    """Wrap *tensor* in a no-grad ``Variable``, optionally moving it to GPU.

    :param tensor: input torch tensor
    :param is_cuda: when True, move the result onto the default CUDA device
    :return: a ``Variable`` with ``requires_grad=False``
    """
    wrapped = Variable(tensor, requires_grad=False)
    return wrapped.cuda() if is_cuda else wrapped
def get_trainer(name='cbhg'):
    """Look up the trainer class registered for the given model name.

    :param name: model identifier; must appear in ``Runner.IMPLEMENTED_MODELS``
    :return: the trainer class for the model, or ``None`` if no trainer is
        mapped for an otherwise-implemented model
    :raises NotImplementedError: if *name* is not an implemented model
    """
    if name not in Runner.IMPLEMENTED_MODELS:
        raise NotImplementedError('Trainer for %s is not implemented !! ' % name)
    return TIMITTrainer if name == 'cbhg' else None
def get_networks(name='cbhg', checkpoint_path='', is_cuda=True, is_multi_gpu=True):
    """
    Build (and optionally restore from a checkpoint) the network for *name*.

    :param name: the name of network ('cbhg' or 'mgru')
    :param checkpoint_path: checkpoint path if you want to load checkpoint
    :param is_cuda: usage of cuda
    :param is_multi_gpu: wrap the network in ``DataParallel`` when True
    :return: the network (possibly moved to CUDA and/or DataParallel-wrapped)
    :raises NotImplementedError: for an unknown network name
    """
    if name == 'cbhg':
        network = CBHGNet()
    elif name == 'mgru':
        network = MinimalGRUNet()
    else:
        raise NotImplementedError('Network %s is not implemented !! ' % name)
    if checkpoint_path:
        # map_location keeps checkpoints saved on a CUDA device loadable on
        # CPU-only hosts; without it torch.load would fail when is_cuda=False.
        map_location = None if is_cuda else 'cpu'
        checkpoint = torch.load(checkpoint_path, map_location=map_location)
        network.load_state_dict(get_loadable_checkpoint(checkpoint['net']))
    if is_cuda:
        network = network.cuda()
    if is_multi_gpu:
        network = torch.nn.DataParallel(network)
    return network