-
Notifications
You must be signed in to change notification settings - Fork 7
/
dataset.py
39 lines (27 loc) · 1.05 KB
/
dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
from itertools import repeat
import torch
from torch.utils.data import Dataset
class KnowledgeGraphDataset(Dataset):
    """Map-style dataset pairing (subject, relation) queries with their objects.

    Item ``k`` unpacks ``x[k]`` as a ``(subject, relation)`` pair and looks up
    every object in ``y[k]``.  Entities and relations are translated to indices
    through the supplied lookup mappings.

    Args:
        x: Sequence of ``(subject, relation)`` pairs.
        y: Sequence with the same length as ``x``; ``y[k]`` is an iterable of
            the objects belonging to query ``x[k]``.
        e_to_index: Mapping from entity to its index.
        r_to_index: Mapping from relation to its index.

    Raises:
        ValueError: If ``x`` and ``y`` have different lengths.
    """

    def __init__(self, x, y, e_to_index, r_to_index):
        # Validate before storing anything; an ``assert`` would be stripped
        # when Python runs with ``-O`` and let mismatched data through.
        if len(x) != len(y):
            raise ValueError(
                f"x and y must have the same length, got {len(x)} and {len(y)}"
            )
        self.x = x
        self.y = y
        self.e_to_index = e_to_index
        self.r_to_index = r_to_index

    def __len__(self):
        """Return the number of (subject, relation) queries."""
        return len(self.x)

    def __getitem__(self, item):
        """Return ``(subject_index, relation_index, object_indices)`` for *item*.

        ``object_indices`` is a new list built on every call, so callers may
        mutate it without affecting the dataset.  An entity or relation that
        is missing from the lookup mappings propagates whatever the mapping
        raises (``KeyError`` for a plain ``dict``).
        """
        subject, relation = self.x[item]
        objects = self.y[item]
        object_indices = [self.e_to_index[obj] for obj in objects]
        return self.e_to_index[subject], self.r_to_index[relation], object_indices
def collate_train(batch):
    """Collate training samples into padded int64 tensors.

    Args:
        batch: Sequence of ``(subject_index, relation_index, object_indices)``
            triples; each ``object_indices`` must be non-empty.

    Returns:
        Tuple ``(subjects, relations, objects)``.  ``subjects`` and
        ``relations`` are 1-D int64 tensors of length ``len(batch)``.
        ``objects`` is a 2-D int64 tensor of shape ``(len(batch), max_len)``.
        Each object-index row must have the same length (to use
        ``torch.scatter_``), so shorter rows are padded by repeating their own
        first index.  This keeps the set of distinct indices in each row
        unchanged.  An empty batch yields three empty tensors.  The input
        samples are not modified.

    Raises:
        ValueError: If any sample has no object indices to pad with.
    """
    subjects = [subject for subject, _, _ in batch]
    relations = [relation for _, relation, _ in batch]
    max_len = max((len(indices) for _, _, indices in batch), default=0)

    padded = []
    for _, _, indices in batch:
        if len(indices) == 0:
            raise ValueError(
                "every training sample needs at least one object index to pad with"
            )
        # Build a padded copy instead of extending in place, so the caller's
        # lists are left untouched.
        padded.append([*indices, *repeat(indices[0], max_len - len(indices))])

    return (
        torch.tensor(subjects, dtype=torch.long),
        torch.tensor(relations, dtype=torch.long),
        torch.tensor(padded, dtype=torch.long),
    )
def collate_valid(batch):
    """Collate validation samples without padding the object indices.

    Args:
        batch: Sequence of ``(subject_index, relation_index, object_indices)``
            triples.

    Returns:
        Tuple ``(subjects, relations, objects)``.  ``subjects`` and
        ``relations`` are 1-D int64 tensors.  ``objects`` is a plain list
        holding each sample's ``object_indices`` unchanged, so rows may have
        different lengths.
    """
    subjects, relations, object_lists = zip(*batch)
    subject_tensor = torch.tensor(subjects, dtype=torch.long)
    relation_tensor = torch.tensor(relations, dtype=torch.long)
    return subject_tensor, relation_tensor, list(object_lists)