Merge pull request #205 from UNAOUN/add-glnn
Add model glnn
gyzhou2000 authored Jun 4, 2024
2 parents 2403822 + 9c9f644 commit c8d05e2
Showing 10 changed files with 672 additions and 4 deletions.
4 changes: 3 additions & 1 deletion README.md
@@ -436,7 +436,9 @@ Now, GammaGL supports about 60 models; we welcome everyone to use or contribute
| [GGD [NeurIPS 2022]](./examples/ggd) | | :heavy_check_mark: | | :heavy_check_mark: |
| [LTD [WSDM 2022]](./examples/ltd) | | :heavy_check_mark: | | :heavy_check_mark: |
| [Graphormer [NeurIPS 2021]](./examples/graphormer) | | :heavy_check_mark: | | :heavy_check_mark: |
| [HiD-Net [AAAI 2023]](./examples/hid_net) | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
| [FusedGAT [MLSys 2022]](./examples/fusedgat) | | :heavy_check_mark: | | |
| [GLNN [ICLR 2022]](./examples/glnn) | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |


| Contrastive Learning | TensorFlow | PyTorch | Paddle | MindSpore |
3 changes: 2 additions & 1 deletion docs/source/api/gammagl.utils.rst
@@ -26,4 +26,5 @@ gammagl.utils
gammagl.utils.negative_sampling
gammagl.utils.to_scipy_sparse_matrix
gammagl.utils.read_embeddings
gammagl.utils.homophily
gammagl.utils.get_train_val_test_split
41 changes: 41 additions & 0 deletions examples/glnn/readme.md
@@ -0,0 +1,41 @@
# Graph-less Neural Networks (GLNN)

- Paper link: [https://arxiv.org/pdf/2110.08727](https://arxiv.org/pdf/2110.08727)
- Author's code repo: [https://github.com/snap-research/graphless-neural-networks](https://github.com/snap-research/graphless-neural-networks)

# Dataset Statistics
| Dataset | # Nodes | # Edges | # Classes |
| -------- | ------- | ------- | --------- |
| Cora | 2,708 | 10,556 | 7 |
| Citeseer | 3,327 | 9,228 | 6 |
| Pubmed | 19,717 | 88,651 | 3 |
| Computers| 13,752 | 491,722 | 10 |
| Photo | 7,650 | 238,162 | 8 |

Refer to [Planetoid](https://gammagl.readthedocs.io/en/latest/api/gammagl.datasets.html#gammagl.datasets.Planetoid) and [Amazon](https://gammagl.readthedocs.io/en/latest/generated/gammagl.datasets.Amazon.html#gammagl.datasets.Amazon) for dataset details.

# Results

- Available datasets: "cora", "citeseer", "pubmed", "computers", "photo"
- Available teachers: "SAGE", "GCN", "GAT", "APPNP", "MLP"

```bash
TL_BACKEND="tensorflow" python train_teacher.py --dataset cora --teacher SAGE
TL_BACKEND="tensorflow" python train_student.py --dataset cora --teacher SAGE
TL_BACKEND="torch" python train_teacher.py --dataset cora --teacher SAGE
TL_BACKEND="torch" python train_student.py --dataset cora --teacher SAGE
TL_BACKEND="paddle" python train_teacher.py --dataset cora --teacher SAGE
TL_BACKEND="paddle" python train_student.py --dataset cora --teacher SAGE
TL_BACKEND="mindspore" python train_teacher.py --dataset cora --teacher SAGE
TL_BACKEND="mindspore" python train_student.py --dataset cora --teacher SAGE
```
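
Note: for a given dataset/teacher pair, `train_teacher.py` must be run first; it saves the teacher's soft labels to `{dataset}_{teacher}_logits.npy`, which `train_student.py` then loads (see `train_student.py` below).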

| Dataset | Paper | Our(tf) | Our(th) | Our(pd) | Our(ms) |
| --------- | ---------- | ---------- | ---------- | ---------- | ---------- |
| Cora | 80.54±1.35 | 80.94±0.31 | 80.84±0.30 | 80.90±0.21 | 81.04±0.30 |
| Citeseer | 71.77±2.01 | 70.74±0.87 | 71.34±0.55 | 71.18±1.20 | 70.58±1.14 |
| Pubmed | 75.42±2.31 | 77.90±0.07 | 77.88±0.23 | 77.78±0.19 | 77.78±0.13 |
| Computers | 83.03±1.87 | 83.45±0.61 | 82.78±0.47 | 83.03±0.14 | 83.40±0.45 |
| Photo | 92.11±1.08 | 91.93±0.16 | 91.91±0.24 | 91.89±0.27 | 91.88±0.21 |

- Each reported result is the mean ± standard deviation over 5 runs.
140 changes: 140 additions & 0 deletions examples/glnn/train.conf.yaml
@@ -0,0 +1,140 @@
global:
  num_layers: 2
  hidden_dim: 128
  learning_rate: 0.01

cora:
  SAGE:
    fan_out: 5,5
    learning_rate: 0.01
    dropout_ratio: 0
    weight_decay: 0.0005

  GCN:
    hidden_dim: 64
    dropout_ratio: 0.8
    weight_decay: 0.001

  MLP:
    learning_rate: 0.01
    weight_decay: 0.005
    dropout_ratio: 0.6

  GAT:
    dropout_ratio: 0.6
    weight_decay: 0.01
    num_heads: 8
    attn_dropout_ratio: 0.3

  APPNP:
    dropout_ratio: 0.5
    weight_decay: 0.01


citeseer:
  SAGE:
    fan_out: 5,5
    learning_rate: 0.01
    dropout_ratio: 0
    weight_decay: 0.0005

  GCN:
    hidden_dim: 64
    dropout_ratio: 0.8
    weight_decay: 0.001

  MLP:
    learning_rate: 0.01
    weight_decay: 0.001
    dropout_ratio: 0.1

  GAT:
    dropout_ratio: 0.6
    weight_decay: 0.01
    num_heads: 8
    attn_dropout_ratio: 0.3

  APPNP:
    dropout_ratio: 0.5
    weight_decay: 0.01

pubmed:
  SAGE:
    fan_out: 5,5
    learning_rate: 0.01
    dropout_ratio: 0
    weight_decay: 0.0005

  GCN:
    hidden_dim: 64
    dropout_ratio: 0.8
    weight_decay: 0.001

  MLP:
    learning_rate: 0.005
    weight_decay: 0
    dropout_ratio: 0.4

  GAT:
    dropout_ratio: 0.6
    weight_decay: 0.01
    num_heads: 8
    attn_dropout_ratio: 0.3

  APPNP:
    dropout_ratio: 0.5
    weight_decay: 0.01

computers:
  SAGE:
    fan_out: 5,5
    learning_rate: 0.01
    dropout_ratio: 0
    weight_decay: 0.0005

  GCN:
    hidden_dim: 64
    dropout_ratio: 0.8
    weight_decay: 0.001

  MLP:
    learning_rate: 0.001
    weight_decay: 0.002
    dropout_ratio: 0.3

  GAT:
    dropout_ratio: 0.6
    weight_decay: 0.01
    num_heads: 8
    attn_dropout_ratio: 0.3

  APPNP:
    dropout_ratio: 0.5
    weight_decay: 0.01

photo:
  SAGE:
    fan_out: 5,5
    learning_rate: 0.01
    dropout_ratio: 0
    weight_decay: 0.0005

  GCN:
    hidden_dim: 64
    dropout_ratio: 0.8
    weight_decay: 0.001

  MLP:
    learning_rate: 0.005
    weight_decay: 0.002
    dropout_ratio: 0.3

  GAT:
    dropout_ratio: 0.6
    weight_decay: 0.01
    num_heads: 8
    attn_dropout_ratio: 0.3

  APPNP:
    dropout_ratio: 0.5
    weight_decay: 0.01
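
Per-model blocks override the `global` defaults. A minimal sketch of the dict merge that `get_training_config` in `train_student.py` performs, using the `cora`/`GCN` values above:

```python
# Model-specific keys override the global defaults (cora/GCN example).
global_conf = {"num_layers": 2, "hidden_dim": 128, "learning_rate": 0.01}
gcn_conf = {"hidden_dim": 64, "dropout_ratio": 0.8, "weight_decay": 0.001}
conf = dict(global_conf, **gcn_conf)
# conf == {'num_layers': 2, 'hidden_dim': 64, 'learning_rate': 0.01,
#          'dropout_ratio': 0.8, 'weight_decay': 0.001}
```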
168 changes: 168 additions & 0 deletions examples/glnn/train_student.py
@@ -0,0 +1,168 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

import os
# os.environ['CUDA_VISIBLE_DEVICES'] = '0'
# os.environ['TL_BACKEND'] = 'torch'
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
# 0:Output all; 1:Filter out INFO; 2:Filter out INFO and WARNING; 3:Filter out INFO, WARNING, and ERROR

import yaml
import argparse
import tensorlayerx as tlx
from gammagl.datasets import Planetoid, Amazon
from gammagl.models import MLP
from gammagl.utils import mask_to_index
from tensorlayerx.model import TrainOneStep, WithLoss


class SemiSpvzLoss(WithLoss):
    def __init__(self, net, loss_fn):
        super(SemiSpvzLoss, self).__init__(backbone=net, loss_fn=loss_fn)

    def forward(self, data, teacher_logits):
        student_logits = self.backbone_network(data['x'])
        train_y = tlx.gather(data['y'], data['t_idx'])
        train_teacher_logits = tlx.gather(teacher_logits, data['t_idx'])
        train_student_logits = tlx.gather(student_logits, data['t_idx'])
        loss = self._loss_fn(train_y, train_student_logits, train_teacher_logits, args.lamb)
        return loss
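
# Note: the loss is computed over data['t_idx'], which concatenates the train,
# test, and val indices (see train_student below), so the student is distilled
# against the teacher's soft predictions on all nodes, matching the
# transductive setting of the GLNN paper.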


def get_training_config(config_path, model_name, dataset):
    with open(config_path, "r") as conf:
        full_config = yaml.load(conf, Loader=yaml.FullLoader)
    dataset_specific_config = full_config["global"]
    model_specific_config = full_config[dataset][model_name]

    if model_specific_config is not None:
        specific_config = dict(dataset_specific_config, **model_specific_config)
    else:
        specific_config = dataset_specific_config

    specific_config["model_name"] = model_name
    return specific_config


def calculate_acc(logits, y, metrics):
    metrics.update(logits, y)
    rst = metrics.result()
    metrics.reset()
    return rst


def kl_divergence(teacher_logits, student_logits):
    # convert logits to probabilities
    teacher_probs = tlx.softmax(teacher_logits)
    student_probs = tlx.softmax(student_logits)
    # compute KL divergence
    kl_div = tlx.reduce_sum(teacher_probs * (tlx.log(teacher_probs + 1e-10) - tlx.log(student_probs + 1e-10)), axis=-1)
    return tlx.reduce_mean(kl_div)
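
# kl_divergence above computes KL(p_teacher || p_student)
#   = mean over nodes of sum_i p_t[i] * (log p_t[i] - log p_s[i]);
# the 1e-10 terms guard against log(0).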


def cal_mlp_loss(labels, student_logits, teacher_logits, lamb):
    loss_l = tlx.losses.softmax_cross_entropy_with_logits(student_logits, labels)
    loss_t = kl_divergence(teacher_logits, student_logits)
    return lamb * loss_l + (1 - lamb) * loss_t
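
# With the default lamb = 0 (see the argparse defaults below), the objective is
# pure distillation; lamb = 1 recovers plain cross-entropy on the hard labels.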


def train_student(args):
    # load dataset
    if args.dataset.lower() not in ['cora', 'pubmed', 'citeseer', 'computers', 'photo']:
        raise ValueError('Unknown dataset: {}'.format(args.dataset))
    if args.dataset in ['cora', 'pubmed', 'citeseer']:
        dataset = Planetoid(args.dataset_path, args.dataset)
    elif args.dataset == 'computers':
        dataset = Amazon(args.dataset_path, args.dataset, train_ratio=200/13752, val_ratio=(200/13752)*1.5)
    elif args.dataset == 'photo':
        dataset = Amazon(args.dataset_path, args.dataset, train_ratio=160/7650, val_ratio=(160/7650)*1.5)
    graph = dataset[0]

    # load teacher logits from the .npy file produced by train_teacher.py
    teacher_logits = tlx.files.load_npy_to_any(path=r'./', name=f'{args.dataset}_{args.teacher}_logits.npy')
    teacher_logits = tlx.ops.convert_to_tensor(teacher_logits)

    # for mindspore, node indices (not boolean masks) must be passed
    train_idx = mask_to_index(graph.train_mask)
    test_idx = mask_to_index(graph.test_mask)
    val_idx = mask_to_index(graph.val_mask)
    t_idx = tlx.concat([train_idx, test_idx, val_idx], axis=0)

    net = MLP(in_channels=dataset.num_node_features,
              hidden_channels=conf["hidden_dim"],
              out_channels=dataset.num_classes,
              num_layers=conf["num_layers"],
              act=tlx.nn.ReLU(),
              norm=None,
              dropout=float(conf["dropout_ratio"]))

    optimizer = tlx.optimizers.Adam(lr=conf["learning_rate"], weight_decay=conf["weight_decay"])
    metrics = tlx.metrics.Accuracy()
    train_weights = net.trainable_weights

    loss_func = SemiSpvzLoss(net, cal_mlp_loss)
    train_one_step = TrainOneStep(loss_func, optimizer, train_weights)

    data = {
        "x": graph.x,
        "y": graph.y,
        "train_idx": train_idx,
        "test_idx": test_idx,
        "val_idx": val_idx,
        "t_idx": t_idx
    }

    best_val_acc = 0
    for epoch in range(args.n_epoch):
        net.set_train()
        train_loss = train_one_step(data, teacher_logits)
        net.set_eval()
        logits = net(data['x'])
        val_logits = tlx.gather(logits, data['val_idx'])
        val_y = tlx.gather(data['y'], data['val_idx'])
        val_acc = calculate_acc(val_logits, val_y, metrics)

        print("Epoch [{:0>3d}] ".format(epoch + 1)
              + " train loss: {:.4f}".format(train_loss.item())
              + " val acc: {:.4f}".format(val_acc))

        # save the best model on the validation set
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            net.save_weights(args.best_model_path + args.dataset + "_" + args.teacher + "_MLP.npz", format='npz_dict')

    net.load_weights(args.best_model_path + args.dataset + "_" + args.teacher + "_MLP.npz", format='npz_dict')
    net.set_eval()
    logits = net(data['x'])
    test_logits = tlx.gather(logits, data['test_idx'])
    test_y = tlx.gather(data['y'], data['test_idx'])
    test_acc = calculate_acc(test_logits, test_y, metrics)
    print("Test acc: {:.4f}".format(test_acc))


if __name__ == '__main__':
    # parameter settings
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_config_path", type=str, default="./train.conf.yaml", help="path to model configuration file")
    parser.add_argument("--teacher", type=str, default="SAGE", help="teacher model")
    parser.add_argument("--lamb", type=float, default=0, help="balances the loss from hard labels and teacher outputs")
    parser.add_argument("--n_epoch", type=int, default=200, help="number of epochs")
    parser.add_argument('--dataset', type=str, default="cora", help="dataset")
    parser.add_argument("--dataset_path", type=str, default=r'./data', help="path to save dataset")
    parser.add_argument("--best_model_path", type=str, default=r'./', help="path to save the best model")
    parser.add_argument("--gpu", type=int, default=0)

    args = parser.parse_args()

    conf = {}
    if args.model_config_path is not None:
        conf = get_training_config(args.model_config_path, args.teacher, args.dataset)
    conf = dict(args.__dict__, **conf)

    if args.gpu >= 0:
        tlx.set_device("GPU", args.gpu)
    else:
        tlx.set_device("CPU")

    train_student(args)
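
`train_teacher.py` is among the 10 changed files but is not shown here. As a rough, hypothetical sketch of the export step it would need, assuming `tlx.files.save_any_to_npy` as the counterpart of the `tlx.files.load_npy_to_any` call above (the teacher's forward signature differs per model):

```python
import tensorlayerx as tlx

def save_teacher_logits(net, data, dataset_name, teacher_name):
    # Hypothetical helper: export the trained teacher's logits so that
    # train_student.py can reload them with tlx.files.load_npy_to_any.
    net.set_eval()
    logits = net(data['x'], data['edge_index'], None, data['num_nodes'])  # signature varies by model
    tlx.files.save_any_to_npy(save_dict=tlx.convert_to_numpy(logits),
                              name=f'{dataset_name}_{teacher_name}_logits.npy')
```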