diff --git a/README.md b/README.md
index 17a09cf2..719c2dc5 100644
--- a/README.md
+++ b/README.md
@@ -436,7 +436,9 @@ Now, GammaGL supports about 60 models, we welcome everyone to use or contribute
 | [GGD [NeurIPS 2022]](./examples/ggd)                |                    | :heavy_check_mark: |                    | :heavy_check_mark: |
 | [LTD [WSDM 2022]](./examples/ltd)                   |                    | :heavy_check_mark: |                    | :heavy_check_mark: |
 | [Graphormer [NeurIPS 2021]](./examples/graphormer)  |                    | :heavy_check_mark: |                    | :heavy_check_mark: |
-| [HiD-Net [AAAI 2023]](./examples/hid_net)           | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
+| [HiD-Net [AAAI 2023]](./examples/hid_net)           | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
+| [FusedGAT [MLSys 2022]](./examples/fusedgat)        |                    | :heavy_check_mark: |                    |                    |
+| [GLNN [ICLR 2022]](./examples/glnn)                 | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
 
 | Contrastive Learning | TensorFlow | PyTorch | Paddle | MindSpore |
diff --git a/docs/source/api/gammagl.utils.rst b/docs/source/api/gammagl.utils.rst
index ffae0afe..9f988f14 100644
--- a/docs/source/api/gammagl.utils.rst
+++ b/docs/source/api/gammagl.utils.rst
@@ -26,4 +26,5 @@ gammagl.utils
     gammagl.utils.negative_sampling
     gammagl.utils.to_scipy_sparse_matrix
     gammagl.utils.read_embeddings
-    gammagl.utils.homophily
\ No newline at end of file
+    gammagl.utils.homophily
+    gammagl.utils.get_train_val_test_split
\ No newline at end of file
diff --git a/examples/glnn/readme.md b/examples/glnn/readme.md
new file mode 100644
index 00000000..ecd9670f
--- /dev/null
+++ b/examples/glnn/readme.md
@@ -0,0 +1,41 @@
+# Graph-less Neural Networks (GLNN)
+
+- Paper link: [https://arxiv.org/pdf/2110.08727](https://arxiv.org/pdf/2110.08727)
+- Author's code repo: [https://github.com/snap-research/graphless-neural-networks](https://github.com/snap-research/graphless-neural-networks)
+
+# Dataset Statistics
+
+| Dataset   | # Nodes | # Edges | # Classes |
+| --------- | ------- | ------- | --------- |
+| Cora      | 2,708   | 10,556  | 7         |
+| Citeseer  | 3,327   | 9,228   | 6         |
+| Pubmed    | 19,717  | 88,651  | 3         |
+| Computers | 13,752  | 491,722 | 10        |
+| Photo     | 7,650   | 238,162 | 8         |
+
+Refer to [Planetoid](https://gammagl.readthedocs.io/en/latest/api/gammagl.datasets.html#gammagl.datasets.Planetoid), [Amazon](https://gammagl.readthedocs.io/en/latest/generated/gammagl.datasets.Amazon.html#gammagl.datasets.Amazon).
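+
+# Distillation Objective
+
+GLNN distills a trained GNN teacher into a plain MLP student: the student is trained on the teacher's soft predictions together with the ground-truth labels. The sketch below mirrors `cal_mlp_loss` and `kl_divergence` in `train_student.py`; `lamb` weights the hard-label term against the KL term (with the default `lamb = 0` the student learns from the teacher alone).
+
+```python
+import tensorlayerx as tlx
+
+def distillation_loss(labels, student_logits, teacher_logits, lamb):
+    # hard-label supervision on the student
+    loss_label = tlx.losses.softmax_cross_entropy_with_logits(student_logits, labels)
+    # soft-label supervision: KL(teacher || student) over the predicted class distributions
+    teacher_probs = tlx.softmax(teacher_logits)
+    student_probs = tlx.softmax(student_logits)
+    kl = tlx.reduce_mean(tlx.reduce_sum(
+        teacher_probs * (tlx.log(teacher_probs + 1e-10) - tlx.log(student_probs + 1e-10)),
+        axis=-1))
+    return lamb * loss_label + (1 - lamb) * kl
+```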
+
+# Results
+
+- Available datasets: "cora", "citeseer", "pubmed", "computers", "photo"
+- Available teachers: "SAGE", "GCN", "GAT", "APPNP", "MLP"
+
+```bash
+TL_BACKEND="tensorflow" python train_teacher.py --dataset cora --teacher SAGE
+TL_BACKEND="tensorflow" python train_student.py --dataset cora --teacher SAGE
+TL_BACKEND="torch" python train_teacher.py --dataset cora --teacher SAGE
+TL_BACKEND="torch" python train_student.py --dataset cora --teacher SAGE
+TL_BACKEND="paddle" python train_teacher.py --dataset cora --teacher SAGE
+TL_BACKEND="paddle" python train_student.py --dataset cora --teacher SAGE
+TL_BACKEND="mindspore" python train_teacher.py --dataset cora --teacher SAGE
+TL_BACKEND="mindspore" python train_student.py --dataset cora --teacher SAGE
+```
+
+| Dataset   | Paper      | Our(tf)    | Our(th)    | Our(pd)    | Our(ms)    |
+| --------- | ---------- | ---------- | ---------- | ---------- | ---------- |
+| Cora      | 80.54±1.35 | 80.94±0.31 | 80.84±0.30 | 80.90±0.21 | 81.04±0.30 |
+| Citeseer  | 71.77±2.01 | 70.74±0.87 | 71.34±0.55 | 71.18±1.20 | 70.58±1.14 |
+| Pubmed    | 75.42±2.31 | 77.90±0.07 | 77.88±0.23 | 77.78±0.19 | 77.78±0.13 |
+| Computers | 83.03±1.87 | 83.45±0.61 | 82.78±0.47 | 83.03±0.14 | 83.40±0.45 |
+| Photo     | 92.11±1.08 | 91.93±0.16 | 91.91±0.24 | 91.89±0.27 | 91.88±0.21 |
+
+- Reported accuracy is the average over 5 runs
\ No newline at end of file
diff --git a/examples/glnn/train.conf.yaml b/examples/glnn/train.conf.yaml
new file mode 100644
index 00000000..8700cad5
--- /dev/null
+++ b/examples/glnn/train.conf.yaml
@@ -0,0 +1,140 @@
+global:
+  num_layers: 2
+  hidden_dim: 128
+  learning_rate: 0.01
+
+cora:
+  SAGE:
+    fan_out: 5,5
+    learning_rate: 0.01
+    dropout_ratio: 0
+    weight_decay: 0.0005
+
+  GCN:
+    hidden_dim: 64
+    dropout_ratio: 0.8
+    weight_decay: 0.001
+
+  MLP:
+    learning_rate: 0.01
+    weight_decay: 0.005
+    dropout_ratio: 0.6
+
+  GAT:
+    dropout_ratio: 0.6
+    weight_decay: 0.01
+    num_heads: 8
+    attn_dropout_ratio: 0.3
+
+  APPNP:
+    dropout_ratio: 0.5
+    weight_decay: 0.01
+
+
+citeseer:
+  SAGE:
+    fan_out: 5,5
+    learning_rate: 0.01
+    dropout_ratio: 0
+    weight_decay: 0.0005
+
+  GCN:
+    hidden_dim: 64
+    dropout_ratio: 0.8
+    weight_decay: 0.001
+
+  MLP:
+    learning_rate: 0.01
+    weight_decay: 0.001
+    dropout_ratio: 0.1
+
+  GAT:
+    dropout_ratio: 0.6
+    weight_decay: 0.01
+    num_heads: 8
+    attn_dropout_ratio: 0.3
+
+  APPNP:
+    dropout_ratio: 0.5
+    weight_decay: 0.01
+
+pubmed:
+  SAGE:
+    fan_out: 5,5
+    learning_rate: 0.01
+    dropout_ratio: 0
+    weight_decay: 0.0005
+
+  GCN:
+    hidden_dim: 64
+    dropout_ratio: 0.8
+    weight_decay: 0.001
+
+  MLP:
+    learning_rate: 0.005
+    weight_decay: 0
+    dropout_ratio: 0.4
+
+  GAT:
+    dropout_ratio: 0.6
+    weight_decay: 0.01
+    num_heads: 8
+    attn_dropout_ratio: 0.3
+
+  APPNP:
+    dropout_ratio: 0.5
+    weight_decay: 0.01
+
+computers:
+  SAGE:
+    fan_out: 5,5
+    learning_rate: 0.01
+    dropout_ratio: 0
+    weight_decay: 0.0005
+
+  GCN:
+    hidden_dim: 64
+    dropout_ratio: 0.8
+    weight_decay: 0.001
+
+  MLP:
+    learning_rate: 0.001
+    weight_decay: 0.002
+    dropout_ratio: 0.3
+
+  GAT:
+    dropout_ratio: 0.6
+    weight_decay: 0.01
+    num_heads: 8
+    attn_dropout_ratio: 0.3
+
+  APPNP:
+    dropout_ratio: 0.5
+    weight_decay: 0.01
+
+photo:
+  SAGE:
+    fan_out: 5,5
+    learning_rate: 0.01
+    dropout_ratio: 0
+    weight_decay: 0.0005
+
+  GCN:
+    hidden_dim: 64
+    dropout_ratio: 0.8
+    weight_decay: 0.001
+
+  MLP:
+    learning_rate: 0.005
+    weight_decay: 0.002
+    dropout_ratio: 0.3
+
+  GAT:
+    dropout_ratio: 0.6
+    weight_decay: 0.01
+    num_heads: 8
+    attn_dropout_ratio: 0.3
+
+  APPNP:
+    dropout_ratio: 0.5
+    weight_decay: 0.01
\ No newline at end of file
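Both example scripts resolve hyperparameters by overlaying the matching `<dataset>.<teacher>` block above onto the `global` defaults; this is what `get_training_config` in `train_teacher.py`/`train_student.py` does. A condensed, runnable sketch of that merge:

```python
import yaml

def get_training_config(config_path, model_name, dataset):
    with open(config_path, "r") as conf:
        full_config = yaml.load(conf, Loader=yaml.FullLoader)
    # model-specific keys override the global defaults
    specific = dict(full_config["global"], **(full_config[dataset][model_name] or {}))
    specific["model_name"] = model_name
    return specific

# e.g. cora/GCN inherits learning_rate from global but overrides hidden_dim
print(get_training_config("./train.conf.yaml", "GCN", "cora"))
# {'num_layers': 2, 'hidden_dim': 64, 'learning_rate': 0.01,
#  'dropout_ratio': 0.8, 'weight_decay': 0.001, 'model_name': 'GCN'}
```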
diff --git a/examples/glnn/train_student.py b/examples/glnn/train_student.py
new file mode 100644
index 00000000..53ad1214
--- /dev/null
+++ b/examples/glnn/train_student.py
@@ -0,0 +1,168 @@
+# !/usr/bin/env python
+# -*- encoding: utf-8 -*-
+
+import os
+# os.environ['CUDA_VISIBLE_DEVICES'] = '0'
+# os.environ['TL_BACKEND'] = 'torch'
+os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
+# 0: Output all; 1: Filter out INFO; 2: Filter out INFO and WARNING; 3: Filter out INFO, WARNING, and ERROR
+
+import yaml
+import argparse
+import tensorlayerx as tlx
+from gammagl.datasets import Planetoid, Amazon
+from gammagl.models import MLP
+from gammagl.utils import mask_to_index
+from tensorlayerx.model import TrainOneStep, WithLoss
+
+
+class SemiSpvzLoss(WithLoss):
+    def __init__(self, net, loss_fn):
+        super(SemiSpvzLoss, self).__init__(backbone=net, loss_fn=loss_fn)
+
+    def forward(self, data, teacher_logits):
+        student_logits = self.backbone_network(data['x'])
+        train_y = tlx.gather(data['y'], data['t_idx'])
+        train_teacher_logits = tlx.gather(teacher_logits, data['t_idx'])
+        train_student_logits = tlx.gather(student_logits, data['t_idx'])
+        loss = self._loss_fn(train_y, train_student_logits, train_teacher_logits, args.lamb)
+        return loss
+
+
+def get_training_config(config_path, model_name, dataset):
+    with open(config_path, "r") as conf:
+        full_config = yaml.load(conf, Loader=yaml.FullLoader)
+    dataset_specific_config = full_config["global"]
+    model_specific_config = full_config[dataset][model_name]
+
+    if model_specific_config is not None:
+        specific_config = dict(dataset_specific_config, **model_specific_config)
+    else:
+        specific_config = dataset_specific_config
+
+    specific_config["model_name"] = model_name
+    return specific_config
+
+
+def calculate_acc(logits, y, metrics):
+    metrics.update(logits, y)
+    rst = metrics.result()
+    metrics.reset()
+    return rst
+
+
+def kl_divergence(teacher_logits, student_logits):
+    # convert logits to probabilities
+    teacher_probs = tlx.softmax(teacher_logits)
+    student_probs = tlx.softmax(student_logits)
+    # compute KL divergence
+    kl_div = tlx.reduce_sum(teacher_probs * (tlx.log(teacher_probs + 1e-10) - tlx.log(student_probs + 1e-10)), axis=-1)
+    return tlx.reduce_mean(kl_div)
+
+
+def cal_mlp_loss(labels, student_logits, teacher_logits, lamb):
+    loss_l = tlx.losses.softmax_cross_entropy_with_logits(student_logits, labels)
+    loss_t = kl_divergence(teacher_logits, student_logits)
+    return lamb * loss_l + (1 - lamb) * loss_t
+
+
+def train_student(args):
+    # load datasets
+    if str.lower(args.dataset) not in ['cora','pubmed','citeseer','computers','photo']:
+        raise ValueError('Unknown dataset: {}'.format(args.dataset))
+    if args.dataset in ['cora', 'pubmed', 'citeseer']:
+        dataset = Planetoid(args.dataset_path, args.dataset)
+    elif args.dataset == 'computers':
+        dataset = Amazon(args.dataset_path, args.dataset, train_ratio=200/13752, val_ratio=(200/13752)*1.5)
+    elif args.dataset == 'photo':
+        dataset = Amazon(args.dataset_path, args.dataset, train_ratio=160/7650, val_ratio=(160/7650)*1.5)
+    graph = dataset[0]
+
+    # load the teacher logits saved by train_teacher.py
+    teacher_logits = tlx.files.load_npy_to_any(path=r'./', name=f'{args.dataset}_{args.teacher}_logits.npy')
+    teacher_logits = tlx.ops.convert_to_tensor(teacher_logits)
+
+    # for mindspore, it should be passed into node indices
+    train_idx = mask_to_index(graph.train_mask)
+    test_idx = mask_to_index(graph.test_mask)
+    val_idx = mask_to_index(graph.val_mask)
+    t_idx = tlx.concat([train_idx, test_idx, val_idx], axis=0)  # the student is distilled on all nodes
+
+    net = MLP(in_channels=dataset.num_node_features,
+              hidden_channels=conf["hidden_dim"],
+              out_channels=dataset.num_classes,
+              num_layers=conf["num_layers"],
+              act=tlx.nn.ReLU(),
+              norm=None,
+              dropout=float(conf["dropout_ratio"]))
+
+    optimizer = tlx.optimizers.Adam(lr=conf["learning_rate"], weight_decay=conf["weight_decay"])
+    metrics = tlx.metrics.Accuracy()
+    train_weights = net.trainable_weights
+
+    loss_func = SemiSpvzLoss(net, cal_mlp_loss)
+    train_one_step = TrainOneStep(loss_func, optimizer, train_weights)
+
+    data = {
+        "x": graph.x,
+        "y": graph.y,
+        "train_idx": train_idx,
+        "test_idx": test_idx,
+        "val_idx": val_idx,
+        "t_idx": t_idx
+    }
+
+    best_val_acc = 0
+    for epoch in range(args.n_epoch):
+        net.set_train()
+        train_loss = train_one_step(data, teacher_logits)
+        net.set_eval()
+        logits = net(data['x'])
+        val_logits = tlx.gather(logits, data['val_idx'])
+        val_y = tlx.gather(data['y'], data['val_idx'])
+        val_acc = calculate_acc(val_logits, val_y, metrics)
+
+        print("Epoch [{:0>3d}] ".format(epoch+1)\
+              + "  train loss: {:.4f}".format(train_loss.item())\
+              + "  val acc: {:.4f}".format(val_acc))
+
+        # save the best model on the validation set
+        if val_acc > best_val_acc:
+            best_val_acc = val_acc
+            net.save_weights(args.best_model_path+args.dataset+"_"+args.teacher+"_MLP.npz", format='npz_dict')
+
+    net.load_weights(args.best_model_path+args.dataset+"_"+args.teacher+"_MLP.npz", format='npz_dict')
+    net.set_eval()
+    logits = net(data['x'])
+    test_logits = tlx.gather(logits, data['test_idx'])
+    test_y = tlx.gather(data['y'], data['test_idx'])
+    test_acc = calculate_acc(test_logits, test_y, metrics)
+    print("Test acc:  {:.4f}".format(test_acc))
+
+
+if __name__ == '__main__':
+    # parameters setting
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--model_config_path", type=str, default="./train.conf.yaml", help="path to the model configuration file")
+    parser.add_argument("--teacher", type=str, default="SAGE", help="teacher model")
+    parser.add_argument("--lamb", type=float, default=0, help="weight balancing the hard-label loss and the distillation loss")
+    parser.add_argument("--n_epoch", type=int, default=200, help="number of epochs")
+    parser.add_argument('--dataset', type=str, default="cora", help="dataset")
+    parser.add_argument("--dataset_path", type=str, default=r'./data', help="path to save dataset")
+    parser.add_argument("--best_model_path", type=str, default=r'./', help="path to save the best model")
+    parser.add_argument("--gpu", type=int, default=0)
+
+    args = parser.parse_args()
+
+    conf = {}
+    if args.model_config_path is not None:
+        conf = get_training_config(args.model_config_path, args.teacher, args.dataset)
+    conf = dict(args.__dict__, **conf)
+
+    if args.gpu >= 0:
+        tlx.set_device("GPU", args.gpu)
+    else:
+        tlx.set_device("CPU")
+
+    train_student(args)
diff --git a/examples/glnn/train_teacher.py b/examples/glnn/train_teacher.py
new file mode 100644
index 00000000..2c54dbdd
--- /dev/null
+++ b/examples/glnn/train_teacher.py
@@ -0,0 +1,211 @@
+# !/usr/bin/env python
+# -*- encoding: utf-8 -*-
+
+import os
+# os.environ['CUDA_VISIBLE_DEVICES'] = '0'
+# os.environ['TL_BACKEND'] = 'torch'
+os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
+# 0: Output all; 1: Filter out INFO; 2: Filter out INFO and WARNING; 3: Filter out INFO, WARNING, and ERROR
+
+import yaml
+import argparse
+import tensorlayerx as tlx
+from gammagl.datasets import Planetoid, Amazon
+from gammagl.models import GCNModel, GraphSAGE_Full_Model, GATModel, APPNPModel, MLP
+from
gammagl.utils import mask_to_index, calc_gcn_norm +from tensorlayerx.model import TrainOneStep, WithLoss + + +class SemiSpvzLoss(WithLoss): + def __init__(self, net, loss_fn): + super(SemiSpvzLoss, self).__init__(backbone=net, loss_fn=loss_fn) + + def forward(self, data, y): + if args.teacher == "GCN": + logits = self.backbone_network(data['x'], data['edge_index'], None, data['num_nodes']) + elif args.teacher == "SAGE": + logits = self.backbone_network(data['x'], data['edge_index']) + elif args.teacher == "GAT": + logits = self.backbone_network(data['x'], data['edge_index'], data['num_nodes']) + elif args.teacher == "APPNP": + logits = self.backbone_network(data['x'], data['edge_index'], data['edge_weight'], data['num_nodes']) + elif args.teacher == "MLP": + logits = self.backbone_network(data['x']) + train_logits = tlx.gather(logits, data['train_idx']) + train_y = tlx.gather(data['y'], data['train_idx']) + loss = self._loss_fn(train_logits, train_y) + return loss + + +def get_training_config(config_path, model_name, dataset): + with open(config_path, "r") as conf: + full_config = yaml.load(conf, Loader=yaml.FullLoader) + dataset_specific_config = full_config["global"] + model_specific_config = full_config[dataset][model_name] + + if model_specific_config is not None: + specific_config = dict(dataset_specific_config, **model_specific_config) + else: + specific_config = dataset_specific_config + + specific_config["model_name"] = model_name + return specific_config + + +def calculate_acc(logits, y, metrics): + metrics.update(logits, y) + rst = metrics.result() + metrics.reset() + return rst + + +def train_teacher(args): + # load datasets + if str.lower(args.dataset) not in ['cora','pubmed','citeseer','computers','photo']: + raise ValueError('Unknown dataset: {}'.format(args.dataset)) + if args.dataset in ['cora', 'pubmed', 'citeseer']: + dataset = Planetoid(args.dataset_path, args.dataset) + elif args.dataset == 'computers': + dataset = Amazon(args.dataset_path, args.dataset, train_ratio=200/13752, val_ratio=(200/13752)*1.5) + elif args.dataset == 'photo': + dataset = Amazon(args.dataset_path, args.dataset, train_ratio=160/7650, val_ratio=(160/7650)*1.5) + graph = dataset[0] + edge_index = graph.edge_index + edge_weight = tlx.convert_to_tensor(calc_gcn_norm(edge_index, graph.num_nodes)) + + # for mindspore, it should be passed into node indices + train_idx = mask_to_index(graph.train_mask) + test_idx = mask_to_index(graph.test_mask) + val_idx = mask_to_index(graph.val_mask) + + if args.teacher == "GCN": + net = GCNModel(feature_dim=dataset.num_node_features, + hidden_dim=conf["hidden_dim"], + num_class=dataset.num_classes, + drop_rate=conf["dropout_ratio"], + num_layers=conf["num_layers"]) + + elif args.teacher == "SAGE": + net = GraphSAGE_Full_Model(in_feats=dataset.num_node_features, + n_hidden=conf["hidden_dim"], + n_classes=dataset.num_classes, + n_layers=conf["num_layers"], + activation=tlx.nn.ReLU(), + dropout=conf["dropout_ratio"], + aggregator_type='gcn') + + elif args.teacher == "GAT": + net = GATModel(feature_dim=dataset.num_node_features, + hidden_dim=conf["hidden_dim"], + num_class=dataset.num_classes, + heads=conf["num_heads"], + drop_rate=conf["dropout_ratio"], + num_layers=conf["num_layers"]) + + elif args.teacher == "APPNP": + net = APPNPModel(feature_dim=dataset.num_node_features, + num_class=dataset.num_classes, + iter_K=10, + alpha=0.1, + drop_rate=conf["dropout_ratio"]) + + elif args.teacher == "MLP": + net = MLP(in_channels=dataset.num_node_features, + 
hidden_channels=conf["hidden_dim"],
+                  out_channels=dataset.num_classes,
+                  num_layers=conf["num_layers"],
+                  act=tlx.nn.ReLU(),
+                  dropout=conf["dropout_ratio"])
+
+    optimizer = tlx.optimizers.Adam(lr=conf["learning_rate"], weight_decay=conf["weight_decay"])
+    metrics = tlx.metrics.Accuracy()
+    train_weights = net.trainable_weights
+
+    loss_func = SemiSpvzLoss(net, tlx.losses.softmax_cross_entropy_with_logits)
+    train_one_step = TrainOneStep(loss_func, optimizer, train_weights)
+
+    data = {
+        "x": graph.x,
+        "y": graph.y,
+        "edge_index": edge_index,
+        "edge_weight": edge_weight,
+        "train_idx": train_idx,
+        "test_idx": test_idx,
+        "val_idx": val_idx,
+        "num_nodes": graph.num_nodes
+    }
+
+    best_val_acc = 0
+    for epoch in range(args.n_epoch):
+        net.set_train()
+        train_loss = train_one_step(data, graph.y)
+        net.set_eval()
+        if args.teacher == "GCN":
+            logits = net(data['x'], data['edge_index'], None, data['num_nodes'])
+        elif args.teacher == "SAGE":
+            logits = net(data['x'], data['edge_index'])
+        elif args.teacher == "GAT":
+            logits = net(data['x'], data['edge_index'], data['num_nodes'])
+        elif args.teacher == "APPNP":
+            logits = net(data['x'], data['edge_index'], data['edge_weight'], data['num_nodes'])
+        elif args.teacher == "MLP":
+            logits = net(data['x'])
+        val_logits = tlx.gather(logits, data['val_idx'])
+        val_y = tlx.gather(data['y'], data['val_idx'])
+        val_acc = calculate_acc(val_logits, val_y, metrics)
+
+        print("Epoch [{:0>3d}] ".format(epoch+1)\
+              + "  train loss: {:.4f}".format(train_loss.item())\
+              + "  val acc: {:.4f}".format(val_acc))
+
+        # save the best model on the validation set, together with its logits for distillation
+        if val_acc > best_val_acc:
+            best_val_acc = val_acc
+            net.save_weights(args.best_model_path+args.dataset+'_'+args.teacher+".npz", format='npz_dict')
+            # tlx.convert_to_numpy works on every backend, so a single call covers torch and the rest
+            tlx.files.save_any_to_npy(tlx.convert_to_numpy(logits), args.best_model_path+args.dataset+'_'+args.teacher+'_logits.npy')
+
+    net.load_weights(args.best_model_path+args.dataset+'_'+args.teacher+".npz", format='npz_dict')
+    net.set_eval()
+    if args.teacher == "GCN":
+        logits = net(data['x'], data['edge_index'], None, data['num_nodes'])
+    elif args.teacher == "SAGE":
+        logits = net(data['x'], data['edge_index'])
+    elif args.teacher == "GAT":
+        logits = net(data['x'], data['edge_index'], data['num_nodes'])
+    elif args.teacher == "APPNP":
+        logits = net(data['x'], data['edge_index'], data['edge_weight'], data['num_nodes'])
+    elif args.teacher == "MLP":
+        logits = net(data['x'])
+    test_logits = tlx.gather(logits, data['test_idx'])
+    test_y = tlx.gather(data['y'], data['test_idx'])
+    test_acc = calculate_acc(test_logits, test_y, metrics)
+    print("Test acc:  {:.4f}".format(test_acc))
+
+
+if __name__ == '__main__':
+    # parameters setting
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--model_config_path", type=str, default="./train.conf.yaml", help="path to the model configuration file")
+    parser.add_argument("--teacher", type=str, default="SAGE", help="teacher model")
+    parser.add_argument("--n_epoch", type=int, default=200, help="number of epochs")
+    parser.add_argument('--dataset', type=str, default="cora", help="dataset")
+    parser.add_argument("--dataset_path", type=str, default=r'./data', help="path to save dataset")
+    parser.add_argument("--best_model_path", type=str, default=r'./', help="path to save the best model")
+    parser.add_argument("--gpu", type=int, default=0)
+
+    args = parser.parse_args()
+
+    conf = {}
+
if args.model_config_path is not None:
+        conf = get_training_config(args.model_config_path, args.teacher, args.dataset)
+    conf = dict(args.__dict__, **conf)
+
+    if args.gpu >= 0:
+        tlx.set_device("GPU", args.gpu)
+    else:
+        tlx.set_device("CPU")
+
+    train_teacher(args)
diff --git a/gammagl/datasets/amazon.py b/gammagl/datasets/amazon.py
index bb0e4907..c365b804 100644
--- a/gammagl/datasets/amazon.py
+++ b/gammagl/datasets/amazon.py
@@ -3,6 +3,7 @@
 import tensorlayerx as tlx
 from gammagl.data import InMemoryDataset, download_url
 from gammagl.io.npz import read_npz
+from gammagl.utils import get_train_val_test_split
 
 
 class Amazon(InMemoryDataset):
@@ -31,8 +32,15 @@ class Amazon(InMemoryDataset):
         an :obj:`gammagl.data.Graph` object and returns a transformed
         version. The data object will be transformed before being saved to disk.
         (default: :obj:`None`)
-    force_reload (bool, optional): Whether to re-process the dataset.
+    force_reload : bool, optional
+        Whether to re-process the dataset.
         (default: :obj:`False`)
+    train_ratio : float, optional
+        Ratio of training samples.
+        (default: :obj:`0.1`)
+    val_ratio : float, optional
+        Ratio of validation samples.
+        (default: :obj:`0.15`)
 
     Stats:
         .. list-table::
@@ -61,12 +69,18 @@ class Amazon(InMemoryDataset):
     def __init__(self, root: str = None, name: str = 'computers',
                  transform: Optional[Callable] = None,
                  pre_transform: Optional[Callable] = None,
-                 force_reload: bool = False):
+                 force_reload: bool = False,
+                 train_ratio: float = 0.1,
+                 val_ratio: float = 0.15):
         self.name = name.lower()
         assert self.name in ['computers', 'photo']
         super().__init__(root, transform, pre_transform, force_reload = force_reload)
         self.data, self.slices = self.load_data(self.processed_paths[0])
 
+        data = self.get(0)
+        data.train_mask, data.val_mask, data.test_mask = get_train_val_test_split(self.data, train_ratio, val_ratio)
+        self.data, self.slices = self.collate([data])
+
     @property
     def raw_dir(self) -> str:
         return osp.join(self.root, self.name.capitalize(), 'raw')
diff --git a/gammagl/utils/__init__.py b/gammagl/utils/__init__.py
index 99a6f022..f57d8d09 100644
--- a/gammagl/utils/__init__.py
+++ b/gammagl/utils/__init__.py
@@ -18,6 +18,7 @@
 from .to_dense_adj import to_dense_adj
 from .smiles import from_smiles
 from .shortest_path import shortest_path_distance, batched_shortest_path_distance
+from .get_split import get_train_val_test_split
 from .get_laplacian import get_laplacian
 
 __all__ = [
@@ -44,6 +45,7 @@
     'from_smiles',
     'shortest_path_distance',
     'batched_shortest_path_distance',
+    'get_train_val_test_split',
     'get_laplacian'
 ]
diff --git a/gammagl/utils/get_split.py b/gammagl/utils/get_split.py
new file mode 100644
index 00000000..fc90494d
--- /dev/null
+++ b/gammagl/utils/get_split.py
@@ -0,0 +1,57 @@
+import tensorlayerx as tlx
+import numpy as np
+from sklearn.model_selection import train_test_split
+
+
+def get_train_val_test_split(graph, train_ratio, val_ratio):
+    """
+    Split the dataset into train, validation, and test sets.
+
+    Parameters
+    ----------
+    graph : gammagl.data.Graph
+        The graph to split; only its ``num_nodes`` attribute is used.
+    train_ratio : float
+        The proportion of the dataset to include in the train split.
+    val_ratio : float
+        The proportion of the dataset to include in the validation split.
+
+    Returns
+    -------
+    tuple of tensor
+        Boolean masks ``(train_mask, val_mask, test_mask)``, each of length ``num_nodes``.
+    """
+
+    random_state = np.random.RandomState(0)
+    num_samples = graph.num_nodes
+    all_indices = np.arange(num_samples)
+
+    # split into train and (val + test)
+    train_indices, val_test_indices = train_test_split(
+        all_indices, train_size=train_ratio, random_state=random_state
+    )
+
+    # calculate the ratio of validation and test splits in the remaining data
+    test_ratio = 1.0 - train_ratio - val_ratio
+    val_size_ratio = val_ratio / (val_ratio + test_ratio)
+
+    # split val + test into validation and test sets
+    val_indices, test_indices = train_test_split(
+        val_test_indices, train_size=val_size_ratio, random_state=random_state
+    )
+
+    return generate_masks(num_samples, train_indices, val_indices, test_indices)
+
+
+def generate_masks(num_nodes, train_indices, val_indices, test_indices):
+    np_train_mask = np.zeros(num_nodes, dtype=bool)
+    np_train_mask[train_indices] = True
+    np_val_mask = np.zeros(num_nodes, dtype=bool)
+    np_val_mask[val_indices] = True
+    np_test_mask = np.zeros(num_nodes, dtype=bool)
+    np_test_mask[test_indices] = True
+
+    train_mask = tlx.ops.convert_to_tensor(np_train_mask, dtype=tlx.bool)
+    val_mask = tlx.ops.convert_to_tensor(np_val_mask, dtype=tlx.bool)
+    test_mask = tlx.ops.convert_to_tensor(np_test_mask, dtype=tlx.bool)
+
+    return train_mask, val_mask, test_mask
\ No newline at end of file
diff --git a/tests/utils/test_get_split.py b/tests/utils/test_get_split.py
new file mode 100644
index 00000000..b60ca7d5
--- /dev/null
+++ b/tests/utils/test_get_split.py
@@ -0,0 +1,32 @@
+import tensorlayerx as tlx
+from gammagl.utils import get_train_val_test_split
+import numpy as np
+
+
+class Graph:
+    def __init__(self, num_nodes):
+        self.num_nodes = num_nodes
+
+
+def test_get_split():
+    num_nodes = 1000
+    graph = Graph(num_nodes)
+
+    train_ratio = 0.6
+    val_ratio = 0.2
+
+    train_mask, val_mask, test_mask = get_train_val_test_split(graph, train_ratio, val_ratio)
+
+    assert tlx.ops.is_tensor(train_mask)
+    assert tlx.ops.is_tensor(val_mask)
+    assert tlx.ops.is_tensor(test_mask)
+
+    train_mask = tlx.convert_to_numpy(train_mask)
+    val_mask = tlx.convert_to_numpy(val_mask)
+    test_mask = tlx.convert_to_numpy(test_mask)
+
+    assert np.sum(train_mask) == int(num_nodes * train_ratio)
+    assert np.sum(val_mask) == int(num_nodes * val_ratio)
+    assert np.sum(test_mask) == num_nodes - int(num_nodes * train_ratio) - int(num_nodes * val_ratio)
+
+    # cast to int before summing: "+" on boolean arrays is a logical OR,
+    # which would fail to detect overlapping masks
+    masks_sum = train_mask.astype(int) + val_mask.astype(int) + test_mask.astype(int)
+    assert np.all(masks_sum == 1)
\ No newline at end of file
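
For reference, a minimal usage sketch of the new split utility (the `./data` root is illustrative; the ratios mirror the ones hard-coded for the "computers" dataset in the GLNN example scripts):

```python
from gammagl.datasets import Amazon
from gammagl.utils import get_train_val_test_split, mask_to_index

# Amazon now attaches train/val/test masks itself via get_train_val_test_split;
# train_ratio=200/13752 reproduces the 200-node training split used by train_teacher.py
dataset = Amazon('./data', 'computers',
                 train_ratio=200 / 13752, val_ratio=(200 / 13752) * 1.5)
graph = dataset[0]

# the masks can also be rebuilt directly for any object exposing num_nodes
train_mask, val_mask, test_mask = get_train_val_test_split(
    graph, train_ratio=200 / 13752, val_ratio=(200 / 13752) * 1.5)
train_idx = mask_to_index(train_mask)  # index form, as the MindSpore backend expects
```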