-
Notifications
You must be signed in to change notification settings - Fork 0
/
data_loader.py
executable file
·101 lines (77 loc) · 2.79 KB
/
data_loader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
from helper import *
from torch.utils.data import Dataset
class TrainDataset(Dataset):
    """
    Training dataset pairing each triple with a multi-hot target vector.

    Parameters
    ----------
    triples : sequence of dict
        Each item must provide a ``'triple'`` entry (a sequence of ints,
        returned as a LongTensor) and a ``'label'`` entry (the entity
        indices that are positive targets for that triple). Other keys
        are not read.
    params : object
        Experiment configuration. Must expose ``num_ent`` (number of
        entities) and ``lbl_smooth`` (label-smoothing factor). ``neg_num``
        is additionally required by :meth:`get_neg_ent`.

    Notes
    -----
    Batches are meant to be built with :meth:`collate_fn`, which stacks
    only the first two elements of every item returned by ``__getitem__``.
    """

    def __init__(self, triples, params):
        self.triples = triples
        self.p = params
        # Candidate pool for negative sampling in get_neg_ent.
        self.entities = np.arange(self.p.num_ent, dtype=np.int32)

    def __len__(self):
        return len(self.triples)

    def __getitem__(self, idx):
        """
        Return ``(triple, trp_label, None, None)`` for item ``idx``.

        ``triple`` is a LongTensor and ``trp_label`` a float32 tensor of
        length ``num_ent``. The two trailing ``None`` placeholders are
        dropped by :meth:`collate_fn`. They are kept so that any caller
        unpacking four values keeps working (NOTE(review): confirm whether
        anything still relies on them).
        """
        ele = self.triples[idx]
        triple = torch.tensor(ele['triple'], dtype=torch.long)
        label = np.asarray(ele['label'], dtype=np.int32)
        trp_label = self.get_label(label)

        if self.p.lbl_smooth != 0.0:
            # NOTE(review): this adds 1/num_ent, not lbl_smooth/num_ent as in
            # textbook label smoothing. The offset is kept unchanged so that
            # existing training results stay reproducible; confirm intent.
            trp_label = (1.0 - self.p.lbl_smooth) * trp_label + (1.0 / self.p.num_ent)

        return triple, trp_label, None, None

    @staticmethod
    def collate_fn(data):
        """Stack the triples and target vectors of a list of items into two batch tensors."""
        triple = torch.stack([item[0] for item in data], dim=0)
        trp_label = torch.stack([item[1] for item in data], dim=0)
        return triple, trp_label

    def get_neg_ent(self, triple, label):
        """
        Build a candidate list of positive entities followed by sampled negatives.

        Parameters
        ----------
        triple : any
            Accepted for interface compatibility; not used.
        label : array-like of int
            Positive entity indices.

        Returns
        -------
        numpy.ndarray
            1-D int32 array. It holds the entries of ``label`` followed by
            ``neg_num - len(label)`` distinct entities. Those entities are
            drawn without replacement, using NumPy's global RNG, from the
            entities not in ``label``.

        Raises
        ------
        ValueError
            Raised by ``np.random.choice`` when ``len(label) > neg_num``, or
            when fewer non-positive entities remain than need to be drawn.
        """
        pos = np.asarray(label, dtype=np.int32).reshape(-1)
        # Exclude the positives from the sampling pool. The builtin `bool`
        # replaces `np.bool`, which was removed in NumPy 1.24.
        mask = np.ones(self.p.num_ent, dtype=bool)
        mask[pos] = False
        neg = np.random.choice(self.entities[mask], self.p.neg_num - len(pos), replace=False)
        return np.concatenate((pos, neg.astype(np.int32, copy=False).reshape(-1)))

    def get_label(self, label):
        """Return a float32 tensor of length ``num_ent`` with 1.0 at each index in ``label``."""
        y = np.zeros(self.p.num_ent, dtype=np.float32)
        # One vectorized assignment instead of a Python loop; duplicate
        # indices simply set the same slot twice.
        y[np.asarray(label, dtype=np.intp).reshape(-1)] = 1.0
        return torch.from_numpy(y)
class TestDataset(Dataset):
    """
    Evaluation dataset pairing each triple with a multi-hot target vector.

    Parameters
    ----------
    triples : sequence of dict
        Each item must provide a ``'triple'`` entry (a sequence of ints,
        returned as a LongTensor) and a ``'label'`` entry (the entity
        indices marked 1.0 in the target vector).
    params : object
        Experiment configuration. Must expose ``num_ent`` (number of
        entities).

    Notes
    -----
    Batches are meant to be built with :meth:`collate_fn`.
    """

    def __init__(self, triples, params):
        self.triples = triples
        self.p = params

    def __len__(self):
        return len(self.triples)

    def __getitem__(self, idx):
        """Return ``(triple, label)``: a LongTensor and a float32 tensor of length ``num_ent``."""
        ele = self.triples[idx]
        triple = torch.tensor(ele['triple'], dtype=torch.long)
        label = self.get_label(np.asarray(ele['label'], dtype=np.int32))
        return triple, label

    @staticmethod
    def collate_fn(data):
        """Stack the triples and target vectors of a list of items into two batch tensors."""
        triple = torch.stack([item[0] for item in data], dim=0)
        label = torch.stack([item[1] for item in data], dim=0)
        return triple, label

    def get_label(self, label):
        """Return a float32 tensor of length ``num_ent`` with 1.0 at each index in ``label``."""
        y = np.zeros(self.p.num_ent, dtype=np.float32)
        # One vectorized assignment instead of a Python loop over `label`.
        y[np.asarray(label, dtype=np.intp).reshape(-1)] = 1.0
        return torch.from_numpy(y)