dmn_loader.py
''' This is the dataset loader file for the bAbI dataset. '''
import re
import numpy as np
from glob import glob
from torch.utils.data import DataLoader
from torch.utils.data.dataset import Dataset
from torch.utils.data.dataloader import default_collate
class adict(dict):
    ''' Dictionary subclass whose keys are also accessible as attributes. '''
    def __init__(self, *arg, **karg):
        dict.__init__(self, *arg, **karg)
        self.__dict__ = self

def pad_collate(batch):
    ''' Collate function that pads the contexts and questions of a batch to a common size. '''
    max_len_ques = float('-inf')
    max_sen_len_context = float('-inf')
    max_len_context = float('-inf')
    for item in batch:
        contexts, ques, _ = item
        if len(contexts) > max_len_context:
            max_len_context = len(contexts)
        if len(ques) > max_len_ques:
            max_len_ques = len(ques)
        for sen in contexts:
            if len(sen) > max_sen_len_context:
                max_sen_len_context = len(sen)
    max_len_context = min(max_len_context, 70)
    for idx, item in enumerate(batch):  # Go through each example in the batch; each has its own context, question and answer.
        context_i, question, answer = item
        context_i = context_i[-max_len_context:]  # keep only the most recent max_len_context sentences
        context = np.zeros((max_len_context, max_sen_len_context))
        for i, sen in enumerate(context_i):  # pad every sentence of this context to max_sen_len_context
            context[i] = np.pad(sen, (0, max_sen_len_context - len(sen)), 'constant', constant_values=0)
        question = np.pad(question, (0, max_len_ques - len(question)), 'constant', constant_values=0)
        batch[idx] = (context, question, answer)
    return default_collate(batch)
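
# A minimal sketch of what a DataLoader built with collate_fn=pad_collate yields,
# assuming default_collate stacks the padded NumPy arrays into torch tensors:
# for a batch of B examples, contexts has shape (B, max_len_context, max_sen_len_context),
# questions has shape (B, max_len_ques), and answers is a length-B tensor of vocabulary indices.
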
class BabiDataSet(Dataset):
    def __init__(self, task_id, mode='train'):
        self.mode = mode
        self.vocab_path = 'dataset/babi{}_vocab.pkl'.format(task_id)
        train_data, test_data = get_train_test(task_id)  # Raw train and test splits of the bAbI task
        self.QA = adict()
        self.QA.VOCAB = {'<PAD>': 0, '<EOS>': 1}
        self.QA.IVOCAB = {0: '<PAD>', 1: '<EOS>'}
        self.train = self.get_processed_data(train_data)
        self.valid = [self.train[i][-int(len(self.train[i]) / 10):] for i in range(3)]  # last 10% of the training data becomes the validation set
        self.train = [self.train[i][:int(9 * len(self.train[i]) / 10)] for i in range(3)]  # first 90% stays as the training set
        self.test = self.get_processed_data(test_data)

    def set_mode(self, mode):
        self.mode = mode

    def __len__(self):
        if self.mode == 'train':
            return len(self.train[0])
        elif self.mode == 'valid':
            return len(self.valid[0])
        elif self.mode == 'test':
            return len(self.test[0])
        else:
            raise ValueError("Invalid mode: {}".format(self.mode))

    def __getitem__(self, index):
        if self.mode == 'train':
            contexts, questions, answers = self.train
        elif self.mode == 'valid':
            contexts, questions, answers = self.valid
        elif self.mode == 'test':
            contexts, questions, answers = self.test
        else:
            raise ValueError("Invalid mode: {}".format(self.mode))
        return contexts[index], questions[index], answers[index]
    def get_processed_data(self, raw_data):
        unindexed = get_unprocessed_data(raw_data)
        questions = []
        contexts = []
        answers = []
        for qa in unindexed:
            context = [c.lower().split() + ['<EOS>'] for c in qa['C']]
            for con in context:
                for token in con:
                    self.build_vocab(token)
            context = [[self.QA.VOCAB[token] for token in sentence] for sentence in context]
            question = qa['Q'].lower().split() + ['<EOS>']
            for token in question:
                self.build_vocab(token)
            question = [self.QA.VOCAB[token] for token in question]
            self.build_vocab(qa['A'].lower())
            answer = self.QA.VOCAB[qa['A'].lower()]
            contexts.append(context)
            questions.append(question)
            answers.append(answer)
        return (contexts, questions, answers)

    def build_vocab(self, token):
        if token not in self.QA.VOCAB:
            next_index = len(self.QA.VOCAB)
            self.QA.VOCAB[token] = next_index
            self.QA.IVOCAB[next_index] = token

def get_train_test(task_id):
    ''' Read the raw train and test files for the given bAbI task id. '''
    train, test = None, None
    paths = glob('data/en-10k/qa{}_*'.format(task_id))
    for path in paths:
        if 'train' in path:
            with open(path, 'r') as f:
                train = f.read()
        elif 'test' in path:
            with open(path, 'r') as f:
                test = f.read()
    return train, test

def build_vocab(raw_data):
    lowered = raw_data.lower()
    tokens = re.findall('[a-zA-Z]+', lowered)
    types = set(tokens)
    return types

def get_unprocessed_data(raw_data):
    tasks = []
    task = None
    data = raw_data.strip().split('\n')
    for i, line in enumerate(data):
        id = int(line[0:line.find(' ')])
        if id == 1:  # a new story starts whenever the line id resets to 1
            task = {"C": "", "Q": "", "A": "", "S": ""}
            counter = 0
            id_map = {}
        line = line.strip()
        line = line.replace('.', ' . ')
        line = line[line.find(' ') + 1:]
        # if not a question
        if line.find('?') == -1:
            task["C"] += line + '<line>'
            id_map[id] = counter
            counter += 1
        else:
            idx = line.find('?')
            tmp = line[idx + 1:].split('\t')
            task["Q"] = line[:idx]
            task["A"] = tmp[1].strip()
            task["S"] = []  # Supporting facts
            for num in tmp[2].split():
                task["S"].append(id_map[int(num.strip())])
            tc = task.copy()
            tc['C'] = tc['C'].split('<line>')[:-1]
            tasks.append(tc)
    return tasks

if __name__ == '__main__':
    dataset_train = BabiDataSet(20)  # Load bAbI task 20; the dataset starts in 'train' mode by default
    train_loader = DataLoader(dataset_train, batch_size=2, shuffle=True, collate_fn=pad_collate)
    for batch_idx, data in enumerate(train_loader):
        contexts, questions, answers = data
        break
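
    # A minimal sketch of reusing the same dataset for validation: set_mode()
    # switches to the 90/10 validation split built in __init__, and the split
    # gets its own DataLoader with the same pad_collate collate function.
    dataset_train.set_mode('valid')
    valid_loader = DataLoader(dataset_train, batch_size=2, shuffle=False, collate_fn=pad_collate)
    for batch_idx, data in enumerate(valid_loader):
        contexts, questions, answers = data
        break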