Merge pull request #205 from UNAOUN/add-glnn
Add model glnn
Showing 10 changed files with 672 additions and 4 deletions.
# Graph-less Neural Networks (GLNN)

- Paper link: [https://arxiv.org/pdf/2110.08727](https://arxiv.org/pdf/2110.08727)
- Author's code repo: [https://github.com/snap-research/graphless-neural-networks](https://github.com/snap-research/graphless-neural-networks)
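GLNN distills a trained GNN teacher into a plain MLP student. The objective below is reconstructed from `cal_mlp_loss` in the training script further down; λ corresponds to the `--lamb` flag (default 0, i.e. pure distillation):

```math
\mathcal{L} = \lambda \,\mathrm{CE}\!\left(z^{\text{MLP}},\, y\right) + (1-\lambda)\,\mathrm{KL}\!\left(\mathrm{softmax}(z^{\text{GNN}}) \,\middle\|\, \mathrm{softmax}(z^{\text{MLP}})\right)
```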
# Dataset Statistics

| Dataset   | # Nodes | # Edges | # Classes |
| --------- | ------- | ------- | --------- |
| Cora      | 2,708   | 10,556  | 7         |
| Citeseer  | 3,327   | 9,228   | 6         |
| Pubmed    | 19,717  | 88,651  | 3         |
| Computers | 13,752  | 491,722 | 10        |
| Photo     | 7,650   | 238,162 | 8         |

Refer to [Planetoid](https://gammagl.readthedocs.io/en/latest/api/gammagl.datasets.html#gammagl.datasets.Planetoid) and [Amazon](https://gammagl.readthedocs.io/en/latest/generated/gammagl.datasets.Amazon.html#gammagl.datasets.Amazon) for the dataset loaders.
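For orientation, a minimal loading sketch using these loaders, mirroring what `train_student.py` below does (the Amazon split ratios shown are the ones hard-coded in that script):

```python
from gammagl.datasets import Planetoid, Amazon

# Planetoid ships cora / citeseer / pubmed with their standard public splits.
dataset = Planetoid('./data', 'cora')

# Amazon has no canonical split, so explicit ratios are passed,
# e.g. 200 of the 13,752 Computers nodes for training:
# dataset = Amazon('./data', 'computers',
#                  train_ratio=200 / 13752, val_ratio=(200 / 13752) * 1.5)

graph = dataset[0]  # single-graph dataset
print(dataset.num_node_features, dataset.num_classes)
```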
# Results

- Available datasets: "cora", "citeseer", "pubmed", "computers", "photo"
- Available teachers: "SAGE", "GCN", "GAT", "APPNP", "MLP"

Run `train_teacher.py` first: it saves the teacher's logits to `<dataset>_<teacher>_logits.npy` in the working directory, which `train_student.py` then loads for distillation.
```bash
TL_BACKEND="tensorflow" python train_teacher.py --dataset cora --teacher SAGE
TL_BACKEND="tensorflow" python train_student.py --dataset cora --teacher SAGE
TL_BACKEND="torch" python train_teacher.py --dataset cora --teacher SAGE
TL_BACKEND="torch" python train_student.py --dataset cora --teacher SAGE
TL_BACKEND="paddle" python train_teacher.py --dataset cora --teacher SAGE
TL_BACKEND="paddle" python train_student.py --dataset cora --teacher SAGE
TL_BACKEND="mindspore" python train_teacher.py --dataset cora --teacher SAGE
TL_BACKEND="mindspore" python train_student.py --dataset cora --teacher SAGE
```

| Dataset   | Paper      | Our(tf)    | Our(th)    | Our(pd)    | Our(ms)    |
| --------- | ---------- | ---------- | ---------- | ---------- | ---------- |
| Cora      | 80.54±1.35 | 80.94±0.31 | 80.84±0.30 | 80.90±0.21 | 81.04±0.30 |
| Citeseer  | 71.77±2.01 | 70.74±0.87 | 71.34±0.55 | 71.18±1.20 | 70.58±1.14 |
| Pubmed    | 75.42±2.31 | 77.90±0.07 | 77.88±0.23 | 77.78±0.19 | 77.78±0.13 |
| Computers | 83.03±1.87 | 83.45±0.61 | 82.78±0.47 | 83.03±0.14 | 83.40±0.45 |
| Photo     | 92.11±1.08 | 91.93±0.16 | 91.91±0.24 | 91.89±0.27 | 91.88±0.21 |

- Reported accuracy is averaged over 5 runs (mean ± standard deviation).
`train.conf.yaml` (the default `--model_config_path`, holding per-dataset hyperparameters for each teacher):
```yaml
global:
  num_layers: 2
  hidden_dim: 128
  learning_rate: 0.01

cora:
  SAGE:
    fan_out: 5,5
    learning_rate: 0.01
    dropout_ratio: 0
    weight_decay: 0.0005

  GCN:
    hidden_dim: 64
    dropout_ratio: 0.8
    weight_decay: 0.001

  MLP:
    learning_rate: 0.01
    weight_decay: 0.005
    dropout_ratio: 0.6

  GAT:
    dropout_ratio: 0.6
    weight_decay: 0.01
    num_heads: 8
    attn_dropout_ratio: 0.3

  APPNP:
    dropout_ratio: 0.5
    weight_decay: 0.01

citeseer:
  SAGE:
    fan_out: 5,5
    learning_rate: 0.01
    dropout_ratio: 0
    weight_decay: 0.0005

  GCN:
    hidden_dim: 64
    dropout_ratio: 0.8
    weight_decay: 0.001

  MLP:
    learning_rate: 0.01
    weight_decay: 0.001
    dropout_ratio: 0.1

  GAT:
    dropout_ratio: 0.6
    weight_decay: 0.01
    num_heads: 8
    attn_dropout_ratio: 0.3

  APPNP:
    dropout_ratio: 0.5
    weight_decay: 0.01

pubmed:
  SAGE:
    fan_out: 5,5
    learning_rate: 0.01
    dropout_ratio: 0
    weight_decay: 0.0005

  GCN:
    hidden_dim: 64
    dropout_ratio: 0.8
    weight_decay: 0.001

  MLP:
    learning_rate: 0.005
    weight_decay: 0
    dropout_ratio: 0.4

  GAT:
    dropout_ratio: 0.6
    weight_decay: 0.01
    num_heads: 8
    attn_dropout_ratio: 0.3

  APPNP:
    dropout_ratio: 0.5
    weight_decay: 0.01

computers:
  SAGE:
    fan_out: 5,5
    learning_rate: 0.01
    dropout_ratio: 0
    weight_decay: 0.0005

  GCN:
    hidden_dim: 64
    dropout_ratio: 0.8
    weight_decay: 0.001

  MLP:
    learning_rate: 0.001
    weight_decay: 0.002
    dropout_ratio: 0.3

  GAT:
    dropout_ratio: 0.6
    weight_decay: 0.01
    num_heads: 8
    attn_dropout_ratio: 0.3

  APPNP:
    dropout_ratio: 0.5
    weight_decay: 0.01

photo:
  SAGE:
    fan_out: 5,5
    learning_rate: 0.01
    dropout_ratio: 0
    weight_decay: 0.0005

  GCN:
    hidden_dim: 64
    dropout_ratio: 0.8
    weight_decay: 0.001

  MLP:
    learning_rate: 0.005
    weight_decay: 0.002
    dropout_ratio: 0.3

  GAT:
    dropout_ratio: 0.6
    weight_decay: 0.01
    num_heads: 8
    attn_dropout_ratio: 0.3

  APPNP:
    dropout_ratio: 0.5
    weight_decay: 0.01
```
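Per-model entries override the `global` defaults; `get_training_config` in the script below merges the two with a plain dict union. A standalone sketch of that lookup (the `cora`/`GCN` pair is just an example):

```python
import yaml

with open("./train.conf.yaml") as f:
    full_config = yaml.load(f, Loader=yaml.FullLoader)

# Model-specific keys win over the global defaults (hidden_dim: 128 -> 64 here);
# `or {}` guards sections that are left empty in the YAML.
merged = dict(full_config["global"], **(full_config["cora"]["GCN"] or {}))
print(merged)
```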
`train_student.py` (distills the saved teacher logits into an MLP student):
```python
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

import os
# os.environ['CUDA_VISIBLE_DEVICES'] = '0'
# os.environ['TL_BACKEND'] = 'torch'
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
# 0: output all; 1: filter out INFO; 2: filter out INFO and WARNING; 3: filter out INFO, WARNING, and ERROR

import yaml
import argparse
import tensorlayerx as tlx
from gammagl.datasets import Planetoid, Amazon
from gammagl.models import MLP
from gammagl.utils import mask_to_index
from tensorlayerx.model import TrainOneStep, WithLoss


class SemiSpvzLoss(WithLoss):
    def __init__(self, net, loss_fn):
        super(SemiSpvzLoss, self).__init__(backbone=net, loss_fn=loss_fn)

    def forward(self, data, teacher_logits):
        student_logits = self.backbone_network(data['x'])
        train_y = tlx.gather(data['y'], data['t_idx'])
        train_teacher_logits = tlx.gather(teacher_logits, data['t_idx'])
        train_student_logits = tlx.gather(student_logits, data['t_idx'])
        # args.lamb is read from the module-level args parsed in __main__
        loss = self._loss_fn(train_y, train_student_logits, train_teacher_logits, args.lamb)
        return loss


def get_training_config(config_path, model_name, dataset):
    with open(config_path, "r") as conf:
        full_config = yaml.load(conf, Loader=yaml.FullLoader)
    dataset_specific_config = full_config["global"]
    model_specific_config = full_config[dataset][model_name]

    # model-specific keys override the global defaults
    if model_specific_config is not None:
        specific_config = dict(dataset_specific_config, **model_specific_config)
    else:
        specific_config = dataset_specific_config

    specific_config["model_name"] = model_name
    return specific_config


def calculate_acc(logits, y, metrics):
    metrics.update(logits, y)
    rst = metrics.result()
    metrics.reset()
    return rst


def kl_divergence(teacher_logits, student_logits):
    # convert logits to probabilities
    teacher_probs = tlx.softmax(teacher_logits)
    student_probs = tlx.softmax(student_logits)
    # compute KL(teacher || student); the epsilon guards against log(0)
    kl_div = tlx.reduce_sum(teacher_probs * (tlx.log(teacher_probs + 1e-10) - tlx.log(student_probs + 1e-10)), axis=-1)
    return tlx.reduce_mean(kl_div)


def cal_mlp_loss(labels, student_logits, teacher_logits, lamb):
    # lamb weights the hard-label loss; (1 - lamb) weights the distillation term
    loss_l = tlx.losses.softmax_cross_entropy_with_logits(student_logits, labels)
    loss_t = kl_divergence(teacher_logits, student_logits)
    return lamb * loss_l + (1 - lamb) * loss_t


def train_student(args):
    # load datasets (normalize the name once so 'Cora' and 'cora' both work)
    args.dataset = args.dataset.lower()
    if args.dataset not in ['cora', 'citeseer', 'pubmed', 'computers', 'photo']:
        raise ValueError('Unknown dataset: {}'.format(args.dataset))
    if args.dataset in ['cora', 'pubmed', 'citeseer']:
        dataset = Planetoid(args.dataset_path, args.dataset)
    elif args.dataset == 'computers':
        dataset = Amazon(args.dataset_path, args.dataset, train_ratio=200 / 13752, val_ratio=(200 / 13752) * 1.5)
    elif args.dataset == 'photo':
        dataset = Amazon(args.dataset_path, args.dataset, train_ratio=160 / 7650, val_ratio=(160 / 7650) * 1.5)
    graph = dataset[0]

    # load teacher logits saved by train_teacher.py as a .npy file
    teacher_logits = tlx.files.load_npy_to_any(path=r'./', name=f'{args.dataset}_{args.teacher}_logits.npy')
    teacher_logits = tlx.ops.convert_to_tensor(teacher_logits)

    # for mindspore, node indices (not boolean masks) must be passed in
    train_idx = mask_to_index(graph.train_mask)
    test_idx = mask_to_index(graph.test_mask)
    val_idx = mask_to_index(graph.val_mask)
    t_idx = tlx.concat([train_idx, test_idx, val_idx], axis=0)

    # conf is the module-level dict built in __main__ (YAML config merged over CLI args)
    net = MLP(in_channels=dataset.num_node_features,
              hidden_channels=conf["hidden_dim"],
              out_channels=dataset.num_classes,
              num_layers=conf["num_layers"],
              act=tlx.nn.ReLU(),
              norm=None,
              dropout=float(conf["dropout_ratio"]))

    optimizer = tlx.optimizers.Adam(lr=conf["learning_rate"], weight_decay=conf["weight_decay"])
    metrics = tlx.metrics.Accuracy()
    train_weights = net.trainable_weights

    loss_func = SemiSpvzLoss(net, cal_mlp_loss)
    train_one_step = TrainOneStep(loss_func, optimizer, train_weights)

    data = {
        "x": graph.x,
        "y": graph.y,
        "train_idx": train_idx,
        "test_idx": test_idx,
        "val_idx": val_idx,
        "t_idx": t_idx
    }

    best_val_acc = 0
    for epoch in range(args.n_epoch):
        net.set_train()
        train_loss = train_one_step(data, teacher_logits)
        net.set_eval()
        logits = net(data['x'])
        val_logits = tlx.gather(logits, data['val_idx'])
        val_y = tlx.gather(data['y'], data['val_idx'])
        val_acc = calculate_acc(val_logits, val_y, metrics)

        print("Epoch [{:0>3d}] ".format(epoch + 1)
              + " train loss: {:.4f}".format(train_loss.item())
              + " val acc: {:.4f}".format(val_acc))

        # save the model that performs best on the validation set
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            net.save_weights(args.best_model_path + args.dataset + "_" + args.teacher + "_MLP.npz", format='npz_dict')

    net.load_weights(args.best_model_path + args.dataset + "_" + args.teacher + "_MLP.npz", format='npz_dict')
    net.set_eval()
    logits = net(data['x'])
    test_logits = tlx.gather(logits, data['test_idx'])
    test_y = tlx.gather(data['y'], data['test_idx'])
    test_acc = calculate_acc(test_logits, test_y, metrics)
    print("Test acc: {:.4f}".format(test_acc))


if __name__ == '__main__':
    # parameter settings
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_config_path", type=str, default="./train.conf.yaml", help="path to model configuration")
    parser.add_argument("--teacher", type=str, default="SAGE", help="teacher model")
    parser.add_argument("--lamb", type=float, default=0, help="balances the loss between hard labels and teacher outputs")
    parser.add_argument("--n_epoch", type=int, default=200, help="number of epochs")
    parser.add_argument('--dataset', type=str, default="cora", help="dataset")
    parser.add_argument("--dataset_path", type=str, default=r'./data', help="path to save dataset")
    parser.add_argument("--best_model_path", type=str, default=r'./', help="path to save best model")
    parser.add_argument("--gpu", type=int, default=0)

    args = parser.parse_args()

    # merge CLI args with the YAML config; YAML values take precedence
    conf = {}
    if args.model_config_path is not None:
        conf = get_training_config(args.model_config_path, args.teacher, args.dataset)
    conf = dict(args.__dict__, **conf)

    if args.gpu >= 0:
        tlx.set_device("GPU", args.gpu)
    else:
        tlx.set_device("CPU")

    train_student(args)
```
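Since `--lamb` defaults to 0, the student trains purely on the teacher's soft predictions; passing a nonzero value mixes the hard-label cross-entropy back in. An illustrative run (the 0.1 weight is an example value, not one from the repo):

```bash
# weight hard labels at 0.1, the KL distillation term at 0.9
TL_BACKEND="torch" python train_student.py --dataset cora --teacher SAGE --lamb 0.1
```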