-
Notifications
You must be signed in to change notification settings - Fork 1
/
translate.py
171 lines (137 loc) · 6.63 KB
/
translate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
# Copyright (c) 2019-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#
# Translate sentences from the input stream.
# The model will be faster if sentences are sorted by length.
# Input sentences must have the same tokenization and BPE codes as the ones used in the model.
#
# Usage:
# cat source_sentences.bpe | \
# python translate.py --exp_name translate \
# --src_lang en --tgt_lang fr \
# --model_path trained_model.pth --output_path output
#
import os
import io
import sys
import argparse
import torch
from src.utils import AttrDict
from src.utils import bool_flag, initialize_exp
from src.data.dictionary import Dictionary
from src.model.transformer import TransformerModel
def get_parser():
    """
    Generate a parameters parser.

    Returns:
        argparse.ArgumentParser: parser for the translation script's options.
    """
    # parse parameters
    parser = argparse.ArgumentParser(description="Translate sentences")

    # main parameters
    parser.add_argument("--dump_path", type=str, default="./dumped/", help="Experiment dump path")
    parser.add_argument("--exp_name", type=str, default="", help="Experiment name")
    parser.add_argument("--exp_id", type=str, default="", help="Experiment ID")
    parser.add_argument("--batch_size", type=int, default=32, help="Number of sentences per batch")
    # `type=bool` would turn ANY non-empty string (including "False" and "0") into True,
    # so the value is converted with bool_flag (imported from src.utils) instead.
    parser.add_argument("-c", "--continue_translate", type=bool_flag, default=False,
                        help="whether continue to translate")

    # model / output paths
    parser.add_argument("--model_path", type=str, default="", help="Model path")
    parser.add_argument("--output_path", type=str, default="", help="Output path")

    # source language / target language
    parser.add_argument("--src_lang", type=str, default="", help="Source language")
    parser.add_argument("--tgt_lang", type=str, default="", help="Target language")

    return parser
def _reload_model(params, logger):
    """
    Load the checkpoint at params.model_path and rebuild the dictionary and models.

    Side effects: copies the vocabulary size and special-token indices stored in the
    checkpoint's params onto `params`, and sets params.src_id / params.tgt_id from
    the checkpoint's lang2id mapping (raises KeyError for an unsupported language).

    Returns:
        (dico, encoder, decoder): the Dictionary and the two TransformerModel
        instances, moved to CUDA and switched to eval mode.
    """
    # NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
    reloaded = torch.load(params.model_path)
    model_params = AttrDict(reloaded['params'])
    logger.info("Supported languages: %s" % ", ".join(model_params.lang2id.keys()))

    # update dictionary parameters
    for name in ['n_words', 'bos_index', 'eos_index', 'pad_index', 'unk_index', 'mask_index']:
        setattr(params, name, getattr(model_params, name))

    # build dictionary / build encoder / build decoder / reload weights
    dico = Dictionary(reloaded['dico_id2word'], reloaded['dico_word2id'], reloaded['dico_counts'])
    encoder = TransformerModel(model_params, dico, is_encoder=True, with_output=True).cuda().eval()
    decoder = TransformerModel(model_params, dico, is_encoder=False, with_output=True).cuda().eval()

    # When every encoder AND decoder key carries a 'module.' prefix (presumably the
    # checkpoint was saved from a DataParallel-style wrapper), strip it before loading.
    prefix = 'module.'
    state_dicts = (reloaded['encoder'], reloaded['decoder'])
    if all(k.startswith(prefix) for sd in state_dicts for k in sd):
        # NOTE(review): strict=False is applied to the encoder on this path only -- confirm intent.
        encoder.load_state_dict({k[len(prefix):]: v for k, v in reloaded['encoder'].items()}, strict=False)
        decoder.load_state_dict({k[len(prefix):]: v for k, v in reloaded['decoder'].items()})
    else:
        encoder.load_state_dict(reloaded['encoder'])
        decoder.load_state_dict(reloaded['decoder'])

    params.src_id = model_params.lang2id[params.src_lang]
    params.tgt_id = model_params.lang2id[params.tgt_lang]
    return dico, encoder, decoder


def _read_source_sentences(stream):
    """
    Return the lines of `stream`, replacing whitespace-only lines by '<UNK>'.

    Blank lines are kept (as a placeholder) rather than dropped, presumably so that
    output line i still corresponds to input line i.
    """
    return [line if line.strip() else '<UNK>' for line in stream]


def _open_output(params, logger):
    """
    Open params.output_path for writing and return (file, num_translated).

    Without continue_translate the file is truncated and num_translated is 0.
    With continue_translate, num_translated is the number of lines already in the
    file (0 if it does not exist yet) and the file is opened for appending.
    """
    if not params.continue_translate:
        out = io.open(params.output_path, 'w', encoding='utf-8')
        logger.info("nct")
        return out, 0

    if os.path.isfile(params.output_path):
        with io.open(params.output_path, 'r', encoding='utf-8') as existing:
            num_translated = sum(1 for _ in existing)
    else:
        logger.info(f"Creating {params.output_path}")
        num_translated = 0
    # append mode creates the file when it is missing
    out = io.open(params.output_path, 'a', encoding='utf-8')
    logger.info("ct")
    return out, num_translated


def _translate_batch(sentences, dico, encoder, decoder, params):
    """
    Translate one batch of source lines and return the target strings, in order.

    Each line is split on whitespace and mapped through dico.index; the decoded
    token ids are mapped back through dico and joined with single spaces.
    """
    # prepare batch
    word_ids = [torch.LongTensor([dico.index(w) for w in s.strip().split()])
                for s in sentences]
    # +2 for the eos_index delimiter written before and after every sentence
    lengths = torch.LongTensor([len(s) + 2 for s in word_ids])
    # batch is (max_len, batch_size), filled with pad_index
    batch = torch.LongTensor(lengths.max().item(), lengths.size(0)).fill_(params.pad_index)
    batch[0] = params.eos_index
    for j, s in enumerate(word_ids):
        if lengths[j] > 2:  # if sentence not empty
            batch[1:lengths[j] - 1, j].copy_(s)
        batch[lengths[j] - 1, j] = params.eos_index
    langs = batch.clone().fill_(params.src_id)

    # encode source batch and translate it
    encoded = encoder('fwd', x=batch.cuda(), lengths=lengths.cuda(), langs=langs.cuda(), causal=False)
    encoded = encoded.transpose(0, 1)
    decoded, _ = decoder.generate(encoded, lengths.cuda(), params.tgt_id,
                                  max_len=int(1.5 * lengths.max().item() + 10))

    # convert sentences to words (decoded is indexed as [position, sentence])
    targets = []
    for j in range(decoded.size(1)):
        # remove delimiters: the sequence starts with eos_index and the next
        # eos_index (if any) marks its end
        sent = decoded[:, j]
        delimiters = (sent == params.eos_index).nonzero().view(-1)
        assert len(delimiters) >= 1 and delimiters[0].item() == 0
        sent = sent[1:] if len(delimiters) == 1 else sent[1:delimiters[1]]
        targets.append(" ".join(dico[idx] for idx in sent.tolist()))
    return targets


def main(params):
    """
    Translate sentences read from stdin and write the translations to params.output_path.

    One output line is written per input line (blank input lines are translated from
    a '<UNK>' placeholder). With params.continue_translate, the lines already present
    in the output file are counted and translation resumes after them. Models and
    batches are moved to CUDA, so a GPU is required.

    Args:
        params: the parsed and validated command-line namespace (see get_parser());
            it is used as given and not re-parsed from sys.argv.
    """
    # initialize the experiment
    logger = initialize_exp(params)

    dico, encoder, decoder = _reload_model(params, logger)

    # read sentences from stdin
    src_sent = _read_source_sentences(sys.stdin)
    logger.info("Read %i sentences from stdin. Translating ..." % len(src_sent))

    out, num_translated = _open_output(params, logger)
    logger.info(f"Continue translating from line:{num_translated+1}")

    # `with` guarantees the output file is closed even if translation fails
    with out:
        for i in range(num_translated, len(src_sent), params.batch_size):
            # count progress from the resume point so it is still logged when
            # num_translated is not a multiple of the batch size
            if (i - num_translated) % (100 * params.batch_size) == 0:
                logger.info(i)
            targets = _translate_batch(src_sent[i:i + params.batch_size], dico, encoder, decoder, params)
            out.writelines(t + "\n" for t in targets)
            # flush every batch so the line count used by --continue_translate
            # covers all completed batches if the process dies mid-run
            out.flush()
if __name__ == '__main__':

    # generate parser / parse parameters
    parser = get_parser()
    params = parser.parse_args()

    # check parameters -- parser.error (unlike assert) is not stripped under
    # `python -O` and prints a usage message instead of a bare traceback
    if not os.path.isfile(params.model_path):
        parser.error(f"--model_path {params.model_path!r} is not an existing file")
    if params.src_lang == '' or params.tgt_lang == '':
        parser.error("--src_lang and --tgt_lang are required")
    if not params.output_path:
        parser.error("--output_path is required")
    # refuse to overwrite an existing output file unless resuming a previous run
    if not params.continue_translate and os.path.isfile(params.output_path):
        parser.error(f"--output_path {params.output_path!r} already exists "
                     "(pass --continue_translate true to resume)")

    # translate
    with torch.no_grad():
        main(params)