-
Notifications
You must be signed in to change notification settings - Fork 0
/
alphabet.py
112 lines (89 loc) · 3.78 KB
/
alphabet.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
"""
Alphabet maps objects to integer ids. It provides a two-way mapping between objects and their indices.
Copyright (C) 2017 Pierpaolo Basile, Pierluigi Cassotti
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>.
"""
import json
import operator
import os

import utils
class Alphabet:
    """Two-way mapping between hashable instances and positive integer ids.

    Index 0 (``default_index``) is reserved: ``get_index`` returns it for
    unknown instances once the alphabet is closed, and ``get_instance(0)``
    returns ``'PAD'``. Real instances are numbered from 1 in insertion order,
    so the instance with id ``i`` is stored at ``instances[i - 1]``.
    """

    def __init__(self, name, keep_growing=True):
        """Create an empty alphabet.

        :param name: Default base name of the JSON file used by save/load.
        :param keep_growing: When True, ``get_index`` registers unseen
            instances; when False it returns ``default_index`` for them.
        """
        self.__name = name

        self.instance2index = {}
        self.instances = []
        self.keep_growing = keep_growing

        # Index 0 is occupied by default, all else following.
        self.default_index = 0
        self.next_index = 1

        self.logger = utils.get_logger('Alphabet')

    def add(self, instance):
        """Register ``instance`` under the next free id; no-op if present."""
        if instance not in self.instance2index:
            self.instances.append(instance)
            self.instance2index[instance] = self.next_index
            self.next_index += 1

    def get_index(self, instance):
        """Return the id of ``instance``.

        An unseen instance is added and given a fresh id while the alphabet
        is growing; otherwise ``default_index`` is returned.
        """
        try:
            return self.instance2index[instance]
        except KeyError:
            if self.keep_growing:
                index = self.next_index
                self.add(instance)
                return index
            return self.default_index

    def get_instance(self, index):
        """Return the instance stored under ``index``.

        Index 0 yields ``'PAD'``. An index past the end logs a warning and
        falls back to the first registered instance.
        """
        if index == 0:
            # First index is occupied by the wildcard element.
            return 'PAD'
        try:
            return self.instances[index - 1]
        except IndexError:
            # Logger.warn is a deprecated alias (removed in Python 3.13).
            self.logger.warning('unknown instance, return the first label.')
            return self.instances[0]

    def size(self):
        """Return the number of ids in use, counting the reserved index 0."""
        return len(self.instances) + 1

    def iteritems(self):
        """Return an iterator over (instance, index) pairs ordered by index."""
        return iter(sorted(self.instance2index.items(),
                           key=operator.itemgetter(1)))

    def enumerate_items(self, start=1):
        """Return (index, instance) pairs for ids from ``start`` onwards.

        :raises IndexError: if ``start`` is outside ``[1, size())``.
        """
        if start < 1 or start >= self.size():
            raise IndexError("Enumerate is allowed between [1 : size of the alphabet)")
        return zip(range(start, len(self.instances) + 1),
                   self.instances[start - 1:])

    def close(self):
        """Stop registering unseen instances in ``get_index``."""
        self.keep_growing = False

    def open(self):
        """Resume registering unseen instances in ``get_index``."""
        self.keep_growing = True

    def get_content(self):
        """Return the serializable state (references to the live containers)."""
        return {'instance2index': self.instance2index,
                'instances': self.instances}

    def from_json(self, data):
        """Restore the alphabet state from a dict shaped like ``get_content()``.

        The containers are copied so the alphabet does not alias ``data``,
        and ``next_index`` is re-derived so that instances added afterwards
        do not reuse ids that are already taken.
        """
        self.instances = list(data["instances"])
        self.instance2index = dict(data["instance2index"])
        # Ids are handed out as 1..len(instances); continue right after them.
        self.next_index = len(self.instances) + 1

    def save(self, output_directory, name=None):
        """
        Save both alphabet records to the given directory.
        :param output_directory: Directory to save model and weights.
        :param name: The alphabet saving name, optional.
        :return:
        """
        saving_name = name if name else self.__name
        try:
            path = os.path.join(output_directory, saving_name + ".json")
            # Context manager guarantees the file is flushed and closed.
            with open(path, 'w') as handle:
                json.dump(self.get_content(), handle)
        except Exception as e:
            # Saving is best-effort: report the failure instead of raising.
            self.logger.warning("Alphabet is not saved: %s", repr(e))

    def load(self, input_directory, name=None):
        """
        Load the alphabet records from the given directory. This allows old
        models to be used even if the structure changes.
        :param input_directory: Directory the alphabet was saved to.
        :param name: The alphabet loading name, optional.
        :return:
        """
        loading_name = name if name else self.__name
        path = os.path.join(input_directory, loading_name + ".json")
        with open(path) as handle:
            self.from_json(json.load(handle))