Each dataset in the data/ directory is pre-processed and prepared in its own folder, for example data/molecules/ for the ZINC dataset. The graph dataset is prepared and saved in the DGL format (compatible with PyTorch); see e.g. prepare_molecules.ipynb, where the graph dataset is saved in the ZINC.pkl file.
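As a minimal sketch of this save/load convention (the file name NEW_DATA.pkl and the placeholder split lists are illustrative, not from the repository; in practice the three objects are the train/val/test dataset instances):

```python
import pickle

# Hypothetical split objects; in the repository these would be
# the prepared train/val/test DGL datasets.
train, val, test = [0, 1, 2], [3], [4]

# Save the three splits in a single .pkl file, mirroring ZINC.pkl.
with open('NEW_DATA.pkl', 'wb') as f:
    pickle.dump([train, val, test], f)

# Reload: the list order (train, val, test) must match the save order.
with open('NEW_DATA.pkl', 'rb') as f:
    splits = pickle.load(f)

print(splits[0])  # the train split: [0, 1, 2]
```

The loading side of this convention is what the dataset class below relies on when it unpickles the file and unpacks the three splits in order.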
File data.py contains the function LoadData(), which loads any dataset by calling a dataset-specific function, for instance MoleculeDataset() defined in molecules.py.
The user will first prepare, for each graph, the adjacency matrix, the node features and the edge features (if any). See prepare_molecules.ipynb, which calls the class MoleculeDatasetDGL() defined in file molecules.py. Then, the user will convert the graph into the DGL format; see class MoleculeDGL() in file molecules.py. The user will have to complete the _prepare() method for the new dataset. A standard code is:
```python
class NewDatasetDGL(torch.utils.data.Dataset):
    def __init__(self, name, **kwargs):
        # other useful parameters, if needed
        self.graph_labels = []
        self.graph_lists = []
        self.n_samples = 0  # set in _prepare()
        self._prepare()

    def _prepare(self):
        # write here the code for preparation
        # of the new graph classification data
        # Steps
        # S1: initialize a dgl graph g = dgl.DGLGraph()
        # S2: add nodes using g.add_nodes()
        # S3: add edges using g.add_edges()
        # S4: add node feat by assigning a torch tensor to g.ndata['feat']
        # S5: add edge feat by assigning a torch tensor to g.edata['feat']
        # S6: append the dgl graph to self.graph_lists
        # Repeat Steps S1 to S6 for 'n_samples' number of times
        # See data/molecules.py for an example, or the following link in the dgl docs:
        # https://docs.dgl.ai/en/latest/_modules/dgl/data/minigc.html#MiniGCDataset
        pass  # to be completed for the new dataset

    def __len__(self):
        """Return the number of graphs in the dataset."""
        return self.n_samples

    def __getitem__(self, idx):
        """
        Get the idx-th sample.

        Parameters
        ----------
        idx : int
            The sample index.

        Returns
        -------
        (dgl.DGLGraph, int)
            DGLGraph with node features stored in the `feat` field,
            and its label.
        """
        return self.graph_lists[idx], self.graph_labels[idx]
```
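Once _prepare() has filled the lists, __len__ and __getitem__ behave like any map-style PyTorch dataset. A pure-Python stub (placeholder strings stand in for dgl graphs; no torch or dgl needed) shows the indexing contract:

```python
class ToyDGL:
    """Stub mimicking NewDatasetDGL; 'g0', 'g1' stand in for dgl graphs."""
    def __init__(self):
        self.graph_lists = ['g0', 'g1']
        self.graph_labels = [0, 1]
        self.n_samples = len(self.graph_lists)

    def __len__(self):
        return self.n_samples

    def __getitem__(self, idx):
        # Each sample is a (graph, label) pair.
        return self.graph_lists[idx], self.graph_labels[idx]

ds = ToyDGL()
print(len(ds))  # 2
print(ds[1])    # ('g1', 1)
```

It is exactly these (graph, label) pairs that the collate functions below receive as a list when the DataLoader assembles a mini-batch.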
At the next step, the user will define a class NewDataset() that loads the DGL dataset and defines a collate() function to create mini-batches of graphs. Note that the collate() function is for the MP-GCNs, which use batches of sparse graphs, whereas collate_dense_gnn() is for the WL-GNNs, which use dense graphs with no batching of multiple graphs in one tensor.
```python
class NewDataset(torch.utils.data.Dataset):
    def __init__(self, name):
        with open(name + '.pkl', 'rb') as f:
            data = pickle.load(f)
        self.train = data[0]
        self.val = data[1]
        self.test = data[2]

    def collate(self, samples):
        # samples is a list of (graph, label) pairs
        graphs, labels = map(list, zip(*samples))
        batched_graph = dgl.batch(graphs)
        return batched_graph, labels

    def collate_dense_gnn(self, samples):
        """
        Given the adjacency matrix in R^{n x n}, the node features in R^{d_n}
        and the edge features in R^{d_e}, build a zero-initialized tensor,
        say X, in R^{(1 + d_n + d_e) x n x n}. X[0, :, :] is the adjacency
        matrix. The diagonal entries X[1:1+d_n, i, i], i = 0 to n-1, store
        the node features of node i. The off-diagonal entries X[1+d_n:, i, j]
        store the edge features of edge (i, j).
        """
        graphs, labels = map(list, zip(*samples))
        # prepare one dense tensor using the above instruction
        # and store it as x_with_all_info
        return x_with_all_info, labels
```
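The dense layout described in the docstring can be checked with a tiny pure-Python example (nested lists instead of torch tensors; n = 2 nodes and one node- and edge-feature dimension are made-up numbers for illustration):

```python
n, d_n, d_e = 2, 1, 1
C = 1 + d_n + d_e  # channel 0: adjacency; 1..1+d_n: node feats; rest: edge feats

# Zero-initialized tensor X of shape (C, n, n), as nested lists.
X = [[[0.0] * n for _ in range(n)] for _ in range(C)]

adj = [[0.0, 1.0], [1.0, 0.0]]               # one undirected edge: (0, 1)
node_feat = [[5.0], [7.0]]                   # one feature per node
edge_feat = {(0, 1): [3.0], (1, 0): [3.0]}   # one feature per directed edge

X[0] = adj                                   # channel 0 holds the adjacency matrix
for i in range(n):
    for k in range(d_n):
        X[1 + k][i][i] = node_feat[i][k]     # node feats on the diagonal
for (i, j), feat in edge_feat.items():
    for k in range(d_e):
        X[1 + d_n + k][i][j] = feat[k]       # edge feats off the diagonal
```

With torch, the same construction would fill a single tensor of shape (1 + d_n + d_e, n, n) per graph, which is why WL-GNNs process one dense graph tensor at a time rather than a batched sparse graph.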
The user will then upgrade LoadData(DATASET_NAME) in data.py with the name of the new dataset, returning the dataset class NewDataset():
```python
def LoadData(DATASET_NAME):
    if DATASET_NAME == 'NEW_DATA':
        return NewDataset(DATASET_NAME)
```
Eventually, the user will call the function LoadData(DATASET_NAME) to load the dataset and DataLoader() to create mini-batches of graphs. For example, this code loads the ZINC dataset and prepares mini-batches of 128 train graphs:
```python
from data.data import LoadData
from data.molecules import MoleculeDataset
from torch.utils.data import DataLoader

DATASET_NAME = 'ZINC'
dataset = LoadData(DATASET_NAME)
train_loader = DataLoader(dataset.train, batch_size=128, shuffle=True, collate_fn=MoleculeDataset.collate)
```
Note that the batching approach for MP-GCNs is not applicable to WL-GNNs, which operate on dense tensors. Therefore, we simply have the following code for WL-GNNs:

```python
train_loader = DataLoader(dataset.train, shuffle=True, collate_fn=MoleculeDataset.collate_dense_gnn)
```
We prepared data splits for the TU datasets that preserve the class distribution across the train, validation and test sets; the splits are stored in the TUs/ folder. Similarly, the split indices for CSL are stored in the CSL/ folder, and the split for the ZINC dataset in the molecules/ folder. For COLLAB, the dataset splits are automatically fetched from the OGB library.
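The idea behind a class-preserving (stratified) split can be sketched in pure Python; the labels and split fraction below are made up for illustration, and the repository ships precomputed index files rather than computing splits this way at runtime:

```python
import random
from collections import defaultdict

def stratified_split(labels, train_frac=0.8, seed=0):
    """Split sample indices so that each class keeps roughly the same
    train/test proportion. Returns (train_idx, test_idx)."""
    rng = random.Random(seed)
    by_class = defaultdict(list)
    for idx, y in enumerate(labels):
        by_class[y].append(idx)
    train_idx, test_idx = [], []
    for y, idxs in by_class.items():
        rng.shuffle(idxs)                          # shuffle within each class
        cut = int(round(train_frac * len(idxs)))   # per-class cut point
        train_idx.extend(idxs[:cut])
        test_idx.extend(idxs[cut:])
    return sorted(train_idx), sorted(test_idx)

# Two balanced classes of 10 samples each: each split stays balanced.
labels = [0] * 10 + [1] * 10
train_idx, test_idx = stratified_split(labels)
```

Splitting per class rather than globally is what guarantees that a rare class is not accidentally concentrated in one split, which matters for small graph-classification datasets like the TU collections.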