-
Notifications
You must be signed in to change notification settings - Fork 147
/
oag_dataset.py
81 lines (70 loc) · 2.64 KB
/
oag_dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
import os
import dgl
import torch as th
from dgl.data import extract_archive, download
from . import BaseDataset, register_dataset
@register_dataset("oag_dataset")
class OAGDataset(BaseDataset):
    """Dataset wrapper around the serialized ``oag_cs`` DGL graph.

    On construction the graph file is looked up on disk; if it is missing,
    the ``oag_cs.tgz`` archive is downloaded (unless already present) and
    extracted, and the graph is then loaded with ``dgl.load_graphs``.

    All paths are relative to the current working directory, so the dataset
    is expected to be instantiated from the repository root.
    """

    def __init__(self, *args, **kwargs):
        """Initialise bookkeeping attributes and load the graph from disk.

        All positional and keyword arguments are forwarded unchanged to
        ``BaseDataset.__init__``.
        """
        super().__init__(*args, **kwargs)
        self.name = "oag_cs"

        # Split indices; populated by get_split().
        self.valid_idx = None
        self.test_idx = None
        self.train_idx = None

        # Filled in by load_graph_from_disk().
        self.dims = None
        self.g = None

        # Not assigned anywhere else in this class.
        self.category = None
        self.num_classes = None
        self.has_feature = True

        # Locations of the downloaded archive, the extraction directory and
        # the serialized graph inside the extracted archive.
        self.data_path = "./openhgnn/dataset/data/oag_cs.tgz"
        self.raw_dir = "./openhgnn/dataset/data"
        self.g_path = "./openhgnn/dataset/data/oag_cs/oag_cs.bin"
        self.url = "https://s3.cn-north-1.amazonaws.com.cn/dgl-data/dataset/oag_cs.tgz"

        if not self.has_cache():
            self.download()
        self.load_graph_from_disk(self.g_path)

        # Named meta-paths, each a list of (src_type, edge_type, dst_type)
        # triples to be traversed in order.
        self.meta_paths_dict = {
            "P-A": [("paper", "paper-author", "author")],
            "A-P": [("author", "author-paper", "paper")],
            "V-A": [
                ("venue", "venue-paper", "paper"),
                ("paper", "paper-author", "author"),
            ],
            "A-V": [
                ("author", "author-paper", "paper"),
                ("paper", "paper-venue", "venue"),
            ],
        }

    def load_graph_from_disk(self, file_path):
        """Load the first graph stored in ``file_path`` into ``self.g``.

        The second value returned by ``dgl.load_graphs`` is kept in
        ``self.dims``.
        """
        graphs, dims = dgl.load_graphs(file_path)
        self.g = graphs[0]
        self.dims = dims

    def get_labels(self, task_type, node_type):
        """Return the ``task_type`` labels stored for ``node_type`` nodes.

        Parameters
        ----------
        task_type : str
            Name of the label field; must be ``"L1"`` or ``"L2"``.
        node_type : str
            Node type whose labels are requested.

        Raises
        ------
        ValueError
            If ``task_type`` is not one of the supported label fields.
        """
        # Explicit raise rather than ``assert`` so the check survives
        # ``python -O``.
        if task_type not in ("L1", "L2"):
            raise ValueError(
                f"task_type must be 'L1' or 'L2', got {task_type!r}"
            )
        # NOTE(review): presumably ndata[task_type] is a per-node-type
        # mapping built on access, so pop() only removes the entry from that
        # temporary mapping -- confirm against the DGL version in use.
        return self.g.ndata[task_type].pop(node_type)

    def get_split(self, node_type, device="cpu"):
        """Compute train/valid/test node indices for ``node_type``.

        The indices are derived from the ``train_mask`` and ``test_mask``
        node fields. No separate validation mask is read: the validation
        indices are the training indices.

        Parameters
        ----------
        node_type : str
            Node type whose masks are read.
        device : str or torch.device, optional
            Device the index tensors are moved to. Defaults to ``"cpu"``.

        Returns
        -------
        tuple
            ``(train_idx, valid_idx, test_idx)``; the same tensors are also
            stored on the instance.
        """
        node_data = self.g.nodes[node_type].data
        train_mask = node_data["train_mask"]
        test_mask = node_data["test_mask"]

        # squeeze(-1) drops only the trailing coordinate axis produced by
        # nonzero(); a bare squeeze() would also collapse a split containing
        # exactly one node into a 0-d tensor.
        train_idx = th.nonzero(train_mask, as_tuple=False).squeeze(-1)
        test_idx = th.nonzero(test_mask, as_tuple=False).squeeze(-1)
        valid_idx = train_idx

        self.train_idx = train_idx.to(device)
        self.test_idx = test_idx.to(device)
        self.valid_idx = valid_idx.to(device)
        return self.train_idx, self.valid_idx, self.test_idx

    def get_feature(self):
        """Remove the ``"feat"`` node field from the graph and return it.

        The field is popped, not copied, so it is no longer available on
        ``self.g`` after this call.
        """
        return self.g.ndata.pop("feat")

    def to(self, device):
        """Move the graph to ``device`` and return this dataset."""
        self.g = self.g.to(device)
        return self

    def download(self):
        """Fetch the dataset archive if it is missing, then extract it."""
        if not os.path.exists(self.data_path):
            # Create the target directory first so the archive is saved
            # inside it (and ends up at self.data_path).
            os.makedirs(self.raw_dir, exist_ok=True)
            # Inside a method body the bare name resolves to the
            # module-level ``download`` helper imported from dgl.data, not
            # to this method.
            download(self.url, path=self.raw_dir)
        extract_archive(self.data_path, self.raw_dir)

    def has_cache(self):
        """Return True if the serialized graph file already exists on disk."""
        return os.path.exists(self.g_path)