-
Notifications
You must be signed in to change notification settings - Fork 147
/
EdgeClassificationDataset.py
144 lines (129 loc) · 6.5 KB
/
EdgeClassificationDataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
import torch as th
from . import BaseDataset, register_dataset
from . import Mg2vecDataSet
@register_dataset('edge_classification')
class EdgeClassificationDataset(BaseDataset):
    r"""
    Base class for datasets that can be used in the task *edge classification*.

    Subclasses are expected to fill in the attributes below (graph, category,
    num_classes, ...). The default :meth:`get_labels` and :meth:`get_split`
    read everything from ``g.edges[category].data``.

    Attributes
    -------------
    g : dgl.DGLHeteroGraph
        The heterogeneous graph.
    category : str
        The target edge type to be predicted. It is used as the key in
        ``g.edges[category]``.
    num_classes : int
        The target edges will be classified into num_classes categories.
    has_feature : bool
        Whether the dataset has feature. Default ``False``.
    multi_label : bool
        Whether a target has multiple labels. Default ``False``.
        When ``True`` the labels are returned as float tensors.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.g = None
        self.category = None
        self.num_classes = None
        self.has_feature = False
        self.multi_label = False

    def get_labels(self):
        r"""
        Return the labels of the target edges.

        The labels are looked up under the key ``'labels'`` and then ``'label'``
        in ``g.edges[category].data``.

        Notes
        ------
        In general, the labels are th.LongTensor.
        But for multi-label dataset, they should be th.FloatTensor. Or it will raise
        RuntimeError: Expected object of scalar type Long but got scalar type Float for argument #2 target' in call to _thnn_nll_loss_forward

        The labels are *popped* from the graph, so a second call on the same
        graph raises ``ValueError``.

        return
        -------
        labels : torch.Tensor

        Raises
        ------
        ValueError
            If neither ``'labels'`` nor ``'label'`` is stored on the target edges.
        """
        edge_data = self.g.edges[self.category].data
        if 'labels' in edge_data:
            labels = edge_data.pop('labels').long()
        elif 'label' in edge_data:
            labels = edge_data.pop('label').long()
        else:
            raise ValueError('Labels of nodes are not in the hg.edges[category].data.')
        return labels.float() if self.multi_label else labels

    def get_split(self, validation=True):
        r"""
        Return the indices of the train, validation and test edges.

        If the target edges carry a ``'train_mask'``, the split is read from the
        stored masks (``'train_mask'``, ``'test_mask'`` and, when present,
        ``'val_mask'`` or ``'valid_mask'``); the masks are popped from the graph.
        Otherwise the target edges are split randomly into train/test with the
        ratio 8:2.

        Parameters
        ----------
        validation : bool
            Whether to split dataset. Default ``True``. If it is False, val_idx will be same with train_idx.
            When ``True`` and no validation mask is stored, one fifth of the
            train indices is drawn at random as the validation set.

        return
        -------
        train_idx, val_idx, test_idx : torch.Tensor, torch.Tensor, torch.Tensor
        """
        edge_data = self.g.edges[self.category].data
        if 'train_mask' not in edge_data:
            self.logger.dataset_info("The dataset has no train mask. "
                                     "So split the category nodes randomly. And the ratio of train/test is 8:2.")
            # `category` names an edge type, so the population to split is its edges.
            num_edges = self.g.number_of_edges(self.category)
            n_test = int(num_edges * 0.2)
            n_train = num_edges - n_test
            train, test = th.utils.data.random_split(range(num_edges), [n_train, n_test])
            # Explicit dtype keeps empty subsets usable as index tensors.
            train_idx = th.tensor(train.indices, dtype=th.long)
            test_idx = th.tensor(test.indices, dtype=th.long)
            if validation:
                train_idx, valid_idx = self._split_train_valid(train_idx)
            else:
                self.logger.dataset_info("Set valid set with train set.")
                valid_idx = train_idx
        else:
            train_idx = self._mask_to_idx(edge_data.pop('train_mask'))
            test_idx = self._mask_to_idx(edge_data.pop('test_mask'))
            if validation:
                if 'val_mask' in edge_data:
                    valid_idx = self._mask_to_idx(edge_data.pop('val_mask'))
                elif 'valid_mask' in edge_data:
                    valid_idx = self._mask_to_idx(edge_data.pop('valid_mask'))
                else:
                    train_idx, valid_idx = self._split_train_valid(train_idx)
            else:
                self.logger.dataset_info("Set valid set with train set.")
                valid_idx = train_idx
        self.train_idx = train_idx
        self.valid_idx = valid_idx
        self.test_idx = test_idx
        return self.train_idx, self.valid_idx, self.test_idx

    def _split_train_valid(self, train_idx):
        r"""Randomly move one fifth of ``train_idx`` into a validation set; return (train_idx, valid_idx)."""
        self.logger.dataset_info("Split train into train/valid with the ratio of 8:2 ")
        random_int = th.randperm(len(train_idx))
        n_valid = len(train_idx) // 5
        valid_idx = train_idx[random_int[:n_valid]]
        train_idx = train_idx[random_int[n_valid:]]
        return train_idx, valid_idx

    @staticmethod
    def _mask_to_idx(mask):
        r"""Convert a mask tensor into a 1-D tensor holding the positions of its non-zero entries."""
        # Flatten first so (N,) and (N, 1) masks behave the same, and squeeze only
        # the coordinate axis so a mask with a single hit still yields shape (1,)
        # rather than a 0-dim tensor (which has no len()).
        return th.nonzero(mask.reshape(-1), as_tuple=False).squeeze(1)
@register_dataset('hin_edge_classification')
class HIN_EdgeClassification(EdgeClassificationDataset):
    r"""
    Edge classification datasets built from preprocessed HIN graphs.

    The HIN dataset are all used in different papers. So we preprocess them and store them as form of dgl.DGLHeteroGraph.
    A dataset name is the raw dataset name joined with the paper name by ``4``
    (read as "for"), e.g. ``dblp4Mg2vec``.

    Dataset Name :
    dblp4Mg2vec_4 / dblp4Mg2vec_5
    """

    # Dataset names accepted by load_HIN; all of them are loaded through Mg2vecDataSet.
    _MG2VEC_DATASETS = ('dblp4Mg2vec_4', 'dblp4Mg2vec_5')

    def __init__(self, dataset_name, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.g, self.category, self.num_classes = self.load_HIN(dataset_name)

    def load_HIN(self, name_dataset):
        r"""
        Load the graph and task description of a supported dataset.

        Parameters
        ----------
        name_dataset : str
            One of ``'dblp4Mg2vec_4'`` (used in MG2VEC with size=4) or
            ``'dblp4Mg2vec_5'`` (used in MG2VEC with size=5).

        return
        -------
        g, category, num_classes
            The first graph of the dataset converted with ``.long()``, the
            target edge type ``'relation'`` and the number of classes ``3``.

        Raises
        ------
        ValueError
            If ``name_dataset`` is not a supported dataset name. Without this
            check the method would return ``None`` and the caller would fail
            with an unrelated unpacking error.
        """
        if name_dataset not in self._MG2VEC_DATASETS:
            raise ValueError(
                'Unsupported dataset name {!r} for HIN_EdgeClassification. '
                'Supported names: {}.'.format(name_dataset, ', '.join(self._MG2VEC_DATASETS))
            )
        dataset = Mg2vecDataSet(name=name_dataset, raw_dir='')
        g = dataset[0].long()
        category = 'relation'
        num_classes = 3
        return g, category, num_classes