-
Notifications
You must be signed in to change notification settings - Fork 5
/
eval.py
79 lines (65 loc) · 3.11 KB
/
eval.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
import argparse
import tensorflow as tf
import os
import utils
from model_helper import las_model_fn
def parse_args(argv=None):
    """Parse the command-line options of the evaluation script.

    Args:
        argv: Optional list of argument strings to parse. When None (the
            default) argparse reads the arguments from ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with one attribute per option defined below.

    Raises:
        SystemExit: Raised by argparse when a required option (``--data``,
            ``--model_dir``) is missing or a value cannot be converted.
    """
    parser = argparse.ArgumentParser(
        description='Run model evaluation.')

    # Required: main() derives the evaluation name from this path and would
    # otherwise fail with a TypeError on None.
    parser.add_argument('--data', type=str, required=True,
                        help='data in TFRecord format')
    parser.add_argument('--vocab', type=str,
                        help='vocabulary table, listing vocabulary line by line')
    parser.add_argument('--norm', type=str, default=None,
                        help='normalization params')
    # This script has no --train option; the dataset directory is given
    # through --data.
    parser.add_argument('--t2t_format', action='store_true',
                        help='Use dataset in the format of ASR problems of Tensor2Tensor framework. --data param should be directory')
    parser.add_argument('--t2t_problem_name', type=str,
                        help='Problem name for data in T2T format.')
    parser.add_argument('--mapping', type=str,
                        help='additional mapping when evaluation')
    parser.add_argument('--model_dir', type=str, required=True,
                        help='path of saving model')
    parser.add_argument('--batch_size', type=int, default=8,
                        help='batch size')
    parser.add_argument('--num_channels', type=int, default=39,
                        help='number of input channels')
    parser.add_argument('--binf_map', type=str, default='misc/binf_map.csv',
                        help='Path to CSV with phonemes to binary features map')
    parser.add_argument('--t2t_features_hparams_override', type=str, default='',
                        help='String with overrided parameters used by Tensor2Tensor problem.')

    return parser.parse_args(argv)
def main(args):
    """Run one evaluation pass of the model in ``args.model_dir``.

    Wraps ``las_model_fn`` in a ``tf.estimator.Estimator`` and calls its
    ``evaluate`` method on the dataset given by ``args.data``.

    Args:
        args: Parsed command-line namespace as returned by ``parse_args``.
    """
    # The run is named after the last path component, cut at the first dot.
    # normpath strips a trailing separator first: basename('data/dev/') is ''
    # which would leave the run unnamed when --data is a directory.
    eval_name = os.path.basename(os.path.normpath(args.data)).split('.')[0]

    config = tf.estimator.RunConfig(model_dir=args.model_dir)
    hparams = utils.create_hparams(args)

    # In T2T format the vocabulary is read from vocab.txt inside the data
    # directory instead of from --vocab.
    vocab_name = args.vocab if not args.t2t_format else os.path.join(args.data, 'vocab.txt')
    vocab_list = utils.load_vocab(vocab_name)

    # Only loaded when the decoder emits binary features; otherwise
    # las_model_fn receives None.
    binf2phone_np = None
    if hparams.decoder.binary_outputs:
        binf2phone_np = utils.load_binf2phone(args.binf_map, vocab_list).values

    def model_fn(features, labels, mode, config, params):
        # Adapter that binds the extra keyword arguments of las_model_fn to
        # the five-argument signature used by the Estimator.
        return las_model_fn(features, labels, mode, config, params,
                            binf2phone=binf2phone_np, run_name=eval_name)

    model = tf.estimator.Estimator(
        model_fn=model_fn,
        config=config,
        params=hparams)

    tf.logging.info('Evaluating on {}'.format(eval_name))

    def input_fn():
        # Called by the Estimator to build the evaluation input pipeline.
        if args.t2t_format:
            return utils.input_fn_t2t(
                args.data, tf.estimator.ModeKeys.EVAL, hparams,
                args.t2t_problem_name, batch_size=args.batch_size,
                features_hparams_override=args.t2t_features_hparams_override)
        return utils.input_fn(
            args.data, args.vocab, args.norm, num_channels=args.num_channels,
            batch_size=args.batch_size)

    model.evaluate(input_fn, name=eval_name)
if __name__ == '__main__':
    # Script entry point: enable INFO-level TensorFlow logging, then
    # evaluate with the options given on the command line.
    tf.logging.set_verbosity(tf.logging.INFO)
    main(parse_args())