-
Notifications
You must be signed in to change notification settings - Fork 0
/
q2_5.py
135 lines (116 loc) · 5.86 KB
/
q2_5.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
import torch
import torch.nn as nn
from torch.autograd import Variable
# from helpers import *
# from model import *
chars = ",.0123456789?abcdefghijklmnopqrstuvwxyz-"
codes = """--..-- .-.-.- ----- .---- ..--- ...-- ....- ..... -.... --... ---..
----. ..--.. .- -... -.-. -.. . ..-. --. .... .. .--- -.- .-.. --
-. --- .--. --.- .-. ... - ..- ...- .-- -..- -.-- --.. -....-"""
keys = dict(zip(chars, codes.split()))
text = """.-.-....-.-...--.-...-....--...-.-...-.--.------..-...-..-.-.---...-..-..---..-......--..-.--.-...-.--......-.........-..-.----.-.....-....--.-.-.--.-..---..-......-...-..-.--.-.----......-.--.-----..-------.-.-..---.-.-.--..-.-...............--...--....--..-....-.-----.....-...-------.-......-.........-..-..--.-....-...--....-.--.-.....--..-.....--..-.---.--...-.-.-..-.-.....---.-.-.-.----....-..-.....--..----......-...-.--.-...--.....--.....-.......-....---..-..--...-------.--....---..---.....-.-.-....-.-...--..-....---..--.--...-.-.-..-.-.....---.-.-.-.----....-..-.....--..----."""
text2 = '...---.........--.-...-.-.----......-....-...-.--...-.......-.---.---.--..-.-...-....--..-...-...-...........-.-.---..-.-..-....-.---.-.-.........-----...--.----.--.........-...-....-.....-..-.-.--....-.....--.-.--.--.-.....-..-.--...-...--.....-......-.-......--..-....-.....-.-..---..-.--.--.-..-.--..--...-.....-.---....--.-......-..-..-----..-.-..-..-.....-----.-.-....-.....-.-.---..--....-...---..-.--..--..-.--...-.--..--.-...--...-.-.-.-..-.--.-.....-...-............--.....-......--..-...-..-..-.--.-.--......-..-..--...-.....-.--..-.....--...-.-.-.-.--.....-..--.-.-..-..---....-...-.--......-.'
text3 = '..-.-.-...-.......--.-..---------.-------..-----....-.-...-.'
temptext = text
def word_tensor(word):
wordid = word_to_id[word]
return torch.tensor([wordid])
#返回匹配结果,候选词
def match(text):
matched = []
for k,v in words_to_morse.items():
if text.startswith(v):
matched.append((k,len(v)))
return matched
def search(decoder,text,beam=1):
hidden = decoder.init_hidden()
matched = match(text)
#初始化beamlist,不考虑beam大小,而是把第一步所有可能的字符都考虑进来
beamlist = []
beamlist_cache = []
for pr,start in matched:
beamlist.append(((pr,),word_tensor(pr).unsqueeze(0),start,hidden,0))
# print(word_tensor(pr))
#list中存放的元组为 预测出的文本,前一个字符的向量,剩下的文本在电码中的位置,隐藏层,总概率
search_end = 0
with torch.no_grad():
while(search_end < beam):
search_end = 0
while True:
for i in range(len(beamlist)):
print(i,'>',end='')
print_result(beamlist[i][0],beamlist[i][4],maxlength = 10)
inp = input('>')
if inp == '':
break
elif inp=='x':
if(len(beamlist_cache)==0):
print('No cache.')
else:
beamlist = beamlist_cache.pop()
else:
if(len(beamlist_cache)>5):beamlist_cache.pop(0)
beamlist_cache.append(beamlist)
beamlist = [beamlist[int(inp)]]
break
temp_list = []
for pretext,lastword,start,hidden,score in beamlist:
if start<len(text) :
matched = match(text[start:])
if len(matched)==0:
continue
else :
temp_list.append((pretext,lastword,start,hidden,score))
search_end += 1
continue
output, hidden = decoder(lastword,hidden)
output_dist = output.data.view(-1).exp()
output_dist = output_dist.div(output_dist.sum())
output_dist = torch.log(output_dist)
for m_word,m_len in matched:
# if(m_word=='sister'):
# m_word = m_word
m_word_id = word_to_id[m_word]
temp_list.append(((*pretext,m_word),word_tensor(m_word).unsqueeze(0),start+m_len,hidden,score+float(output_dist[m_word_id])))
# temp_list.sort(key=lambda k:k[4]/(len(k[0])-1),reverse=True)
temp_list.sort(key=lambda k:k[4]/(len(k[0])-1)+k[2]/len(k[0]),reverse=True)
# temp_list.sort(key=lambda k:k[4]*(k[2]/len(text)),reverse=True)
# temp_list.sort(key=lambda k:k[4]*(1+0.5*k[2]/len(k[0])),reverse=True)
# temp_list.sort(key=lambda k:k[4]*(1+k[2]/len(text))+k[2]/len(k[0]),reverse=True)
if search_end == len(temp_list):
break
beamlist = []
if len(temp_list)==0:
print('Nothing matched! Please go back.')
# maxscore = temp_list[0][4]
for i in range(len(temp_list)):
if(i>=beam): break
# if(maxscore-temp_list[i][4]>0.2):break
beamlist.append(temp_list[i])
return [(a,e/(len(a)-1)) for a,b,c,d,e in beamlist]
def print_result(guess,score,maxlength = 0):
print(score,':',end=' ')
if maxlength > 0 and len(guess)>maxlength:
guess = guess[-maxlength:]
for w in guess:
print(w,end=' ')
print()
# decoder = torch.load('model.pt')
decoder = torch.load('lm_model.pt')
with open('word_to_id','r') as f:
word_to_id = eval(f.read())
with open('id_to_word','r') as f:
id_to_word = eval(f.read())
words_to_morse = {}
for word in word_to_id.keys():
morse = ""
# if not word.isalpha():
# continue
for c in word:
morse = morse + keys[c.lower()]
words_to_morse[word]=morse
# print(decoder.batch_size,' ',decoder.num_steps)
r = search(decoder,text,beam=30)
# print(r)
for s,sc in r:
print_result(s,sc)