beam_search.py
import torch


@torch.no_grad()  # inference only: no gradients are needed during decoding
def beam_search(model, src, bos, eos, k=10, max_len=100, filter_token=None, gpu=True):
    """
    Beam search decoding for a single source sequence (batch size must be 1).
    Returns the k best (log-probability, token sequence, hidden state) tuples,
    sorted from most to least likely.
    """
    model.eval()
    if gpu:
        src = src.cuda()
    if filter_token is None:
        filter_token = []
    # run encoders - batch_size = 1
    encoded_src_t = model.encode_source_translate(src)
    encoded_src_i = model.encode_source_infer(src)
    # initialize the beam with the start-of-sentence token, log-probability 0,
    # and no decoder hidden state yet
    init_prob = 0.0
    best_options = [(init_prob, [bos], None)]  # beam entries: (lprob, tokens, hidden)
    for _ in range(max_len):  # expand up to max_len target tokens
        options = []  # candidate expansions for this step
        for lprob, sentence, hidden in best_options:
            last_word = sentence[-1]
            # keep finished sentences (ending in '</s>') as candidates unchanged
            if last_word == eos:
                options.append((lprob, sentence, hidden))
            else:
                last_word_input = torch.tensor([[last_word]], dtype=torch.long)
                if gpu:
                    last_word_input = last_word_input.cuda()
                # decode one step
                lprobs, new_hidden = model.generate(last_word_input, src, encoded_src_t, encoded_src_i, hidden)
                lprobs = lprobs.squeeze()
                # filter out unwanted tokens such as <pad>
                lprobs[filter_token] = float("-inf")
                # add the top-k expansions of this hypothesis to the candidate list
                for index in torch.topk(lprobs, k)[1]:
                    option = (lprob + lprobs[index].item(), sentence + [index.item()], new_hidden)
                    options.append(option)
        options.sort(key=lambda x: x[0], reverse=True)  # sort by cumulative log-probability
        best_options = options[:k]  # keep the top k candidates in the beam
        # stop early once every hypothesis in the beam has ended in '</s>'
        if all(sentence[-1] == eos for _, sentence, _ in best_options):
            break
    return best_options
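

# Example usage (a minimal sketch, kept as comments because the concrete model
# class, vocabulary indices, and source tensor shape are not shown in this file
# and are assumptions for illustration only):
#
#     src = torch.tensor([[4, 17, 9, 3]]).t()           # hypothetical source ids, shape (src_len, 1)
#     beam = beam_search(model, src, bos=1, eos=2, k=5,  # bos/eos ids assumed
#                        max_len=50, filter_token=[0],   # e.g. suppress <pad> (assumed id 0)
#                        gpu=torch.cuda.is_available())
#     best_lprob, best_tokens, _ = beam[0]               # highest-scoring hypothesis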