-
Notifications
You must be signed in to change notification settings - Fork 0
/
6_load_data.py
227 lines (186 loc) · 7.96 KB
/
6_load_data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
"""
Make Your Own Dataset
=====================
This tutorial assumes that you already know :doc:`the basics of training a
GNN for node classification <1_introduction>` and :doc:`how to
create, load, and store a DGL graph <2_dglgraph>`.
By the end of this tutorial, you will be able to
- Create your own graph dataset for node classification, link
prediction, or graph classification.
(Time estimate: 15 minutes)
"""
######################################################################
# ``DGLDataset`` Object Overview
# ------------------------------
#
# Your custom graph dataset should inherit the ``dgl.data.DGLDataset``
# class and implement the following methods:
#
# - ``__getitem__(self, i)``: retrieve the ``i``-th example of the
# dataset. An example often contains a single DGL graph, and
# occasionally its label.
# - ``__len__(self)``: the number of examples in the dataset.
# - ``process(self)``: load and process raw data from disk.
#
######################################################################
# Creating a Dataset for Node Classification or Link Prediction from CSV
# ----------------------------------------------------------------------
#
# A node classification dataset often consists of a single graph, as well
# as its node and edge features.
#
# This tutorial takes a small dataset based on `Zachary’s Karate Club
# network <https://en.wikipedia.org/wiki/Zachary%27s_karate_club>`__. It
# contains
#
# * A ``members.csv`` file containing the names of all club
# members, as well as their attributes.
#
# * An ``interactions.csv`` file
# containing the pair-wise interactions between two club members.
#
import urllib.request
import pandas as pd

# Fetch the two CSV files that make up the karate-club dataset.
for url, destination in [
    ('https://data.dgl.ai/tutorial/dataset/members.csv', './members.csv'),
    ('https://data.dgl.ai/tutorial/dataset/interactions.csv', './interactions.csv'),
]:
    urllib.request.urlretrieve(url, destination)

# Peek at both tables to see the available columns.
members = pd.read_csv('./members.csv')
members.head()
interactions = pd.read_csv('./interactions.csv')
interactions.head()
######################################################################
# This tutorial treats the members as nodes and interactions as edges. It
# takes age as a numeric feature of the nodes, affiliated club as the label
# of the nodes, and edge weight as a numeric feature of the edges.
#
# .. note::
#
# The original Zachary’s Karate Club network does not have
# member ages. The ages in this tutorial are generated synthetically
# for demonstrating how to add node features into the graph for dataset
# creation.
#
# .. note::
#
# In practice, taking age directly as a numeric feature may
# not work well in machine learning; strategies like binning or
# normalizing the feature would work better. This tutorial directly
# takes the values as-is for simplicity.
#
import dgl
from dgl.data import DGLDataset
import torch
import os
class KarateClubDataset(DGLDataset):
    """Karate-club social network as a single-graph node-classification dataset.

    Nodes are club members; edges are pairwise interactions. ``Age`` becomes
    the node feature, ``Club`` (category-encoded) the node label, and
    ``Weight`` the edge feature. Boolean train/val/test masks covering
    60% / 20% / 20% of the nodes (by index) are attached as node data.
    """

    def __init__(self):
        super().__init__(name='karate_club')

    def process(self):
        """Build the graph from ``members.csv`` and ``interactions.csv``."""
        members = pd.read_csv('./members.csv')
        links = pd.read_csv('./interactions.csv')

        src = torch.from_numpy(links['Src'].to_numpy())
        dst = torch.from_numpy(links['Dst'].to_numpy())
        num_members = members.shape[0]
        self.graph = dgl.graph((src, dst), num_nodes=num_members)

        # Node feature: age; node label: club affiliation as integer codes.
        self.graph.ndata['feat'] = torch.from_numpy(members['Age'].to_numpy())
        self.graph.ndata['label'] = torch.from_numpy(
            members['Club'].astype('category').cat.codes.to_numpy())
        # Edge feature: interaction weight.
        self.graph.edata['weight'] = torch.from_numpy(links['Weight'].to_numpy())

        # Node-classification datasets need masks marking which nodes belong
        # to the training, validation, and test sets.
        n_train = int(num_members * 0.6)
        n_val = int(num_members * 0.2)
        for mask_name, lo, hi in (('train_mask', 0, n_train),
                                  ('val_mask', n_train, n_train + n_val),
                                  ('test_mask', n_train + n_val, num_members)):
            mask = torch.zeros(num_members, dtype=torch.bool)
            mask[lo:hi] = True
            self.graph.ndata[mask_name] = mask

    def __getitem__(self, i):
        # The dataset holds exactly one graph, so every index maps to it.
        return self.graph

    def __len__(self):
        return 1
# Instantiating a DGLDataset triggers process(), building the graph.
dataset = KarateClubDataset()
# The single graph is available at index 0.
graph = dataset[0]
print(graph)
######################################################################
# Since a link prediction dataset only involves a single graph, preparing
# a link prediction dataset will have the same experience as preparing a
# node classification dataset.
#
######################################################################
# Creating a Dataset for Graph Classification from CSV
# ----------------------------------------------------
#
# Creating a graph classification dataset involves implementing
# ``__getitem__`` to return both the graph and its graph-level label.
#
# This tutorial demonstrates how to create a graph classification dataset
# with the following synthetic CSV data:
#
# - ``graph_edges.csv``: containing three columns:
#
# - ``graph_id``: the ID of the graph.
# - ``src``: the source node of an edge of the given graph.
# - ``dst``: the destination node of an edge of the given graph.
#
# - ``graph_properties.csv``: containing three columns:
#
# - ``graph_id``: the ID of the graph.
# - ``label``: the label of the graph.
# - ``num_nodes``: the number of nodes in the graph.
#
# Fetch the synthetic graph-classification CSVs.
for url, destination in [
    ('https://data.dgl.ai/tutorial/dataset/graph_edges.csv', './graph_edges.csv'),
    ('https://data.dgl.ai/tutorial/dataset/graph_properties.csv', './graph_properties.csv'),
]:
    urllib.request.urlretrieve(url, destination)

# Peek at both tables to see the available columns.
edges = pd.read_csv('./graph_edges.csv')
properties = pd.read_csv('./graph_properties.csv')
edges.head()
properties.head()
class SyntheticDataset(DGLDataset):
    """Synthetic graph-classification dataset loaded from two CSV tables.

    ``graph_edges.csv`` holds (graph_id, src, dst) edge rows;
    ``graph_properties.csv`` holds one (graph_id, label, num_nodes) row per
    graph. ``__getitem__`` returns a ``(graph, label)`` pair.
    """

    def __init__(self):
        super().__init__(name='synthetic')

    def process(self):
        """Build one DGLGraph per graph ID found in the edge table."""
        edge_table = pd.read_csv('./graph_edges.csv')
        prop_table = pd.read_csv('./graph_properties.csv')

        # Index the per-graph properties by graph ID so the loop below can
        # look up each graph's label and node count directly.
        label_of = {}
        size_of = {}
        for _, row in prop_table.iterrows():
            label_of[row['graph_id']] = row['label']
            size_of[row['graph_id']] = row['num_nodes']

        self.graphs = []
        self.labels = []
        # Group the edge rows by graph ID and materialize each graph.
        grouped = edge_table.groupby('graph_id')
        for gid in grouped.groups:
            rows = grouped.get_group(gid)
            g = dgl.graph(
                (rows['src'].to_numpy(), rows['dst'].to_numpy()),
                num_nodes=size_of[gid])
            self.graphs.append(g)
            self.labels.append(label_of[gid])

        # Convert the label list to a tensor for saving.
        self.labels = torch.LongTensor(self.labels)

    def __getitem__(self, i):
        return self.graphs[i], self.labels[i]

    def __len__(self):
        return len(self.graphs)
# Instantiating the dataset triggers process(); each item is (graph, label).
dataset = SyntheticDataset()
graph, label = dataset[0]
print(graph, label)
# Thumbnail Courtesy: (Un)common Use Cases for Graph Databases, Michal Bachman
# sphinx_gallery_thumbnail_path = '_static/blitz_6_load_data.png'