-
Notifications
You must be signed in to change notification settings - Fork 2
/
vocabulary.py
308 lines (232 loc) · 11.5 KB
/
vocabulary.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
from bidict import bidict
import pickle
import logging
# Module-level logger, named after this module via __name__; used by Vocabulary below.
logger = logging.getLogger(__name__)
class Vocabulary:
    """This class maps strings to integers and keeps them in independent namespaces.

    Each namespace is a ``bidict`` mapping token -> index (``.inv`` gives
    index -> token). A freshly created namespace receives DEFAULT_PAD_TOKEN and
    DEFAULT_UNK_TOKEN unless it is listed in ``no_pad_namespace`` /
    ``no_unk_namespace`` or names its own special token through
    ``contain_pad_namespace`` / ``contain_unk_namespace``.
    """
    DEFAULT_PAD_TOKEN = '*@PAD@*'
    DEFAULT_UNK_TOKEN = '*@UNK@*'

    def __init__(self,
                 counters=None,
                 min_count=None,
                 pretrained_vocab=None,
                 intersection_namespace=None,
                 no_pad_namespace=None,
                 no_unk_namespace=None,
                 contain_pad_namespace=None,
                 contain_unk_namespace=None):
        """initialize vocabulary

        Counters are processed before the pretrained vocabulary, so an
        ``intersection_namespace`` entry may refer to a namespace built from a counter.

        Keyword Arguments:
            counters {dict} -- namespace -> token counter; kept tokens are indexed in iteration order (default: {None})
            min_count {dict} -- namespace -> minimum count a token needs to be kept; unlisted namespaces use 1 (default: {None})
            pretrained_vocab {dict} -- namespace -> {token: index}; indices are copied unchanged (default: {None})
            intersection_namespace {dict} -- pretrained namespace -> existing namespace; only pretrained tokens also present there are kept, in case of too large pretrained vocabulary (default: {None})
            no_pad_namespace {list} -- namespaces without a padding token (default: {None})
            no_unk_namespace {list} -- namespaces without an unknown token (default: {None})
            contain_pad_namespace {dict} -- namespace -> its own padding token, used instead of DEFAULT_PAD_TOKEN (default: {None})
            contain_unk_namespace {dict} -- namespace -> its own unknown token, used instead of DEFAULT_UNK_TOKEN (default: {None})
        """
        # Copy every argument so the caller's containers are never mutated.
        self.min_count = dict(min_count or {})
        self.intersection_namespace = dict(intersection_namespace or {})
        self.no_pad_namespace = set(no_pad_namespace or ())
        self.no_unk_namespace = set(no_unk_namespace or ())
        self.contain_pad_namespace = dict(contain_pad_namespace or {})
        self.contain_unk_namespace = dict(contain_unk_namespace or {})
        # namespace -> bidict(token -> index)
        self.vocab = dict()
        self.extend_from_counter(counters or {}, self.min_count, self.no_pad_namespace,
                                 self.no_unk_namespace)
        self.extend_from_pretrained_vocab(pretrained_vocab or {}, self.intersection_namespace,
                                          self.no_pad_namespace, self.no_unk_namespace)
        logger.info("Initialize vocabulary successfully.")

    def _update_special_token_options(self, no_pad_namespace, no_unk_namespace,
                                      contain_pad_namespace, contain_unk_namespace):
        """Merge per-namespace padding/unknown-token settings into the stored ones (None adds nothing)."""
        self.no_pad_namespace.update(set(no_pad_namespace or ()))
        self.no_unk_namespace.update(set(no_unk_namespace or ()))
        self.contain_pad_namespace.update(dict(contain_pad_namespace or {}))
        self.contain_unk_namespace.update(dict(contain_unk_namespace or {}))

    @staticmethod
    def _next_free_index(token_map):
        """Return the index a newly appended token should receive.

        This is ``len(token_map)`` whenever that index is unused (the common,
        contiguous case). If it is already taken, which can happen after a
        pretrained or intersected namespace left gaps, fall back to one past the
        largest index so the bidict never receives a duplicate value.
        Assumes indices are integers.
        """
        candidate = len(token_map)
        if candidate in token_map.inv:
            candidate = max(token_map.inv) + 1
        return candidate

    def extend_from_pretrained_vocab(self,
                                     pretrained_vocab,
                                     intersection_namespace=None,
                                     no_pad_namespace=None,
                                     no_unk_namespace=None,
                                     contain_pad_namespace=None,
                                     contain_unk_namespace=None):
        """extend vocabulary from pretrained vocab

        Every namespace in ``pretrained_vocab`` is (re)created from scratch, so
        tokens it previously held are discarded. Pretrained indices are copied
        unchanged and therefore need not be contiguous.

        Arguments:
            pretrained_vocab {dict} -- namespace -> {token: index}

        Keyword Arguments:
            intersection_namespace {dict} -- pretrained namespace -> existing namespace whose tokens filter the pretrained ones (default: {None})
            no_pad_namespace {list} -- namespaces without a padding token (default: {None})
            no_unk_namespace {list} -- namespaces without an unknown token (default: {None})
            contain_pad_namespace {dict} -- namespace -> its own padding token (default: {None})
            contain_unk_namespace {dict} -- namespace -> its own unknown token (default: {None})

        Raises:
            KeyError: an intersection namespace does not exist in the vocabulary
        """
        self.intersection_namespace.update(dict(intersection_namespace or {}))
        self._update_special_token_options(no_pad_namespace, no_unk_namespace,
                                           contain_pad_namespace, contain_unk_namespace)
        for namespace, vocab in (pretrained_vocab or {}).items():
            self.__namespace_init(namespace)
            is_intersection = namespace in self.intersection_namespace
            intersection_vocab = self.vocab[
                self.intersection_namespace[namespace]] if is_intersection else []
            token_map = self.vocab[namespace]
            for key, value in vocab.items():
                if not is_intersection or key in intersection_vocab:
                    token_map[key] = value
            logger.info(
                "Vocabulay {} (size: {}) was constructed successfully from pretrained_vocab.".
                format(namespace, len(token_map)))

    def extend_from_counter(self,
                            counters,
                            min_count=None,
                            no_pad_namespace=None,
                            no_unk_namespace=None,
                            contain_pad_namespace=None,
                            contain_unk_namespace=None):
        """extend vocabulary from counter

        A namespace that does not exist yet is created first. Tokens already in
        a namespace keep their index; new tokens are appended.

        Arguments:
            counters {dict} -- namespace -> token counter

        Keyword Arguments:
            min_count {dict} -- namespace -> minimum count, merged with thresholds from earlier calls; unlisted namespaces use 1 (default: {None})
            no_pad_namespace {list} -- namespaces without a padding token (default: {None})
            no_unk_namespace {list} -- namespaces without an unknown token (default: {None})
            contain_pad_namespace {dict} -- namespace -> its own padding token (default: {None})
            contain_unk_namespace {dict} -- namespace -> its own unknown token (default: {None})
        """
        self._update_special_token_options(no_pad_namespace, no_unk_namespace,
                                           contain_pad_namespace, contain_unk_namespace)
        self.min_count.update(dict(min_count or {}))
        for namespace, counter in (counters or {}).items():
            # Only create missing namespaces: re-initialising would wipe earlier tokens.
            if namespace not in self.vocab:
                self.__namespace_init(namespace)
            token_map = self.vocab[namespace]
            # Use the merged table so thresholds registered earlier (e.g. in __init__)
            # still apply when this call omits min_count.
            minc = self.min_count.get(namespace, 1)
            for key in counter:
                # Skipping known tokens also prevents re-indexing PAD/UNK if they appear in the counter.
                if counter[key] >= minc and key not in token_map:
                    token_map[key] = self._next_free_index(token_map)
            logger.info("Vocabulay {} (size: {}) was constructed successfully from counter.".format(
                namespace, len(token_map)))

    def add_tokens_to_namespace(self, tokens, namespace):
        """This function adds tokens to one namespace for extending vocabulary

        The namespace is created if missing. Tokens already present keep their index.

        Arguments:
            tokens {list} -- token list
            namespace {str} -- namespace name
        """
        if namespace not in self.vocab:
            self.__namespace_init(namespace)
            # Creating a namespace is routine, not an error condition.
            logger.info('Add Namespace {} into vocabulary.'.format(namespace))
        token_map = self.vocab[namespace]
        for token in tokens:
            if token not in token_map:
                token_map[token] = self._next_free_index(token_map)

    def get_token_index(self, token, namespace):
        """This function gets token index in one namespace of vocabulary

        Unknown tokens map to the namespace's unknown-token index when it has one.

        Arguments:
            token {str} -- token
            namespace {str} -- namespace name

        Raises:
            KeyError: namespace does not exist
            RuntimeError: token not found in a namespace without an unknown token

        Returns:
            int -- token index
        """
        token_map = self.vocab[namespace]
        if token in token_map:
            return token_map[token]
        if namespace not in self.no_unk_namespace:
            return self.get_unknown_index(namespace)
        message = "Can not find the index of {} from a no unknown token namespace {}.".format(
            token, namespace)
        logger.error(message)
        raise RuntimeError(message)

    def get_token_from_index(self, index, namespace):
        """This function gets token using index in vocabulary

        Indices are looked up directly, so non-contiguous (pretrained) indices work.

        Arguments:
            index {int} -- index
            namespace {str} -- namespace name

        Raises:
            KeyError: namespace does not exist
            RuntimeError: index not present in the namespace (including negative indices)

        Returns:
            str -- token
        """
        tokens_by_index = self.vocab[namespace].inv
        if index in tokens_by_index:
            return tokens_by_index[index]
        message = "The index {} is out of vocabulary {} range.".format(index, namespace)
        logger.error(message)
        raise RuntimeError(message)

    def get_vocab_size(self, namespace):
        """This function gets the size of one namespace in vocabulary

        Arguments:
            namespace {str} -- namespace name

        Raises:
            KeyError: namespace does not exist

        Returns:
            int -- number of tokens in the namespace (special tokens included)
        """
        return len(self.vocab[namespace])

    def get_all_namespaces(self):
        """This function gets all namespaces

        Returns:
            set -- names of all namespaces the vocabulary contains (a fresh copy)
        """
        return set(self.vocab)

    def get_padding_index(self, namespace):
        """This function gets padding token index in one namespace of vocabulary

        Uses the namespace's own padding token from contain_pad_namespace when set,
        otherwise DEFAULT_PAD_TOKEN.

        Arguments:
            namespace {str} -- namespace name

        Raises:
            RuntimeError: namespace does not exist, or it has no padding token
            KeyError: the padding token is not present in the namespace

        Returns:
            int -- padding index
        """
        if namespace not in self.vocab:
            raise RuntimeError("Namespace {} doesn't exist.".format(namespace))
        if namespace in self.no_pad_namespace:
            message = "Namespace {} doesn't has paddding token.".format(namespace)
            logger.error(message)
            raise RuntimeError(message)
        pad_token = self.contain_pad_namespace.get(namespace, Vocabulary.DEFAULT_PAD_TOKEN)
        return self.vocab[namespace][pad_token]

    def get_unknown_index(self, namespace):
        """This function gets unknown token index in one namespace of vocabulary

        Uses the namespace's own unknown token from contain_unk_namespace when set,
        otherwise DEFAULT_UNK_TOKEN.

        Arguments:
            namespace {str} -- namespace name

        Raises:
            RuntimeError: namespace does not exist, or it has no unknown token
            KeyError: the unknown token is not present in the namespace

        Returns:
            int -- unknown index
        """
        if namespace not in self.vocab:
            raise RuntimeError("Namespace {} doesn't exist.".format(namespace))
        if namespace in self.no_unk_namespace:
            message = "Namespace {} doesn't has unknown token.".format(namespace)
            logger.error(message)
            raise RuntimeError(message)
        unk_token = self.contain_unk_namespace.get(namespace, Vocabulary.DEFAULT_UNK_TOKEN)
        return self.vocab[namespace][unk_token]

    def get_namespace_tokens(self, namesapce):
        """This function returns all tokens in one namespace

        The parameter keeps its historical spelling ``namesapce`` so keyword callers still work.

        Arguments:
            namesapce {str} -- namespace name

        Returns:
            bidict -- the namespace's live token -> index mapping (not a copy;
            mutating it mutates the vocabulary)
        """
        return self.vocab[namesapce]

    def save(self, file_path):
        """This function saves vocabulary into file (pickled)

        Arguments:
            file_path {str} -- file path
        """
        with open(file_path, 'wb') as file:
            pickle.dump(self, file)

    @classmethod
    def load(cls, file_path):
        """This function loads vocabulary from file

        Arguments:
            file_path {str} -- file path

        Returns:
            Vocabulary -- vocabulary
        """
        # NOTE: unpickling can execute arbitrary code; only load files from trusted sources.
        with open(file_path, 'rb') as file:
            return pickle.load(file, encoding='utf-8')

    def __namespace_init(self, namespace):
        """This function (re)creates a namespace as an empty bidict and adds default special tokens.

        Any existing contents of the namespace are discarded. DEFAULT_PAD_TOKEN is
        added first (index 0), unless the namespace is in no_pad_namespace or
        contain_pad_namespace. DEFAULT_UNK_TOKEN is then added at the next index,
        unless the namespace is in no_unk_namespace or contain_unk_namespace.

        Arguments:
            namespace {str} -- namespace
        """
        self.vocab[namespace] = bidict()
        if namespace not in self.no_pad_namespace and namespace not in self.contain_pad_namespace:
            self.vocab[namespace][Vocabulary.DEFAULT_PAD_TOKEN] = len(self.vocab[namespace])
        if namespace not in self.no_unk_namespace and namespace not in self.contain_unk_namespace:
            self.vocab[namespace][Vocabulary.DEFAULT_UNK_TOKEN] = len(self.vocab[namespace])