-
Notifications
You must be signed in to change notification settings - Fork 78
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Model & Dataset] facebook & sp2gcl #201
Conversation
gammagl/datasets/facebook.py
Outdated
x = tlx.convert_to_tensor(data['features'], dtype=tlx.float32) | ||
y = tlx.convert_to_tensor(data['target'], dtype=tlx.int64) | ||
edge_index = tlx.convert_to_tensor(data['edges'], dtype=tlx.int64) | ||
edge_index = edge_index.T |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Have you tried if this can work in the other backend like 'mindspore'?
gammagl/datasets/facebook.py
Outdated
|
||
def __init__( | ||
self, | ||
root: str, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Currently, this argument can be optional, as we have a cached mechanism compared to PyG
.
examples/sp2_gcl/node_main.py
Outdated
loss = 0.5 * tlx.losses.softmax_cross_entropy_with_logits(logits, labels) + 0.5 * tlx.losses.softmax_cross_entropy_with_logits(logits.transpose(-2, -1), labels) | ||
return loss | ||
def main(args): | ||
global edge, e, u, test_idx |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What is this line doing?
examples/sp2_gcl/node_main.py
Outdated
val_idx = tlx.where(data.val_mask)[0] | ||
test_idx = tlx.where(data.test_mask)[0] | ||
else: | ||
train_idx, val_idx, test_idx = split(y) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I do not think this is useful. Usually, the train, valid, test split should be done in the dataset. You may directly use data.train_mask
.etc to get the idx instead of add a new function in the util
.
gammagl/models/sp2gcl.py
Outdated
import tensorlayerx as tlx | ||
import tensorlayerx.nn as nn | ||
from gammagl.layers.conv import GCNConv | ||
class Encoder(nn.Module): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I remember there are many Encoder
in the gammagl. I do not recommend you to place this function here.
gammagl/models/sp2gcl.py
Outdated
return x | ||
|
||
|
||
class MLP(nn.Module): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ditto.
|
||
|
||
|
||
class EigenMLP(nn.Module): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ditto.
gammagl/utils/split.py
Outdated
import tensorlayerx as tlx | ||
import numpy as np | ||
|
||
def split(y): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This util is useless, remove it.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
from sklearn.model_selection import train_test_split
examples/sp2_gcl/readme.md
Outdated
@ -0,0 +1,40 @@ | ||
# Graph Contrastive Learning with Stable and Scalable | ||
|
||
- Paper link: [https://proceedings.neurips.cc/paper_files/paper/2023/file/8e9a6582caa59fda0302349702965171-Paper-Conference.pdf](https://arxiv.org/abs/2201.11349) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
链接不对
gammagl/models/sp2gcl.py
Outdated
period_e = e.unsqueeze(1) * tlx.pow(2, period_term) | ||
# period_e = period_e.to(u.device) | ||
fourier_e = tlx.concat([tlx.sin(period_e), tlx.cos(period_e)], axis=-1) | ||
h = u @ fourier_e |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
tlx.unsqueeze,矩阵乘用tlx.matmul
examples/sp2_gcl/readme.md
Outdated
| PubMed | 82.3±0.3 | OOM | | ||
| Wiki-CS | 79.42±0.19 | 76.79 ± 0.61 | | ||
| Facebook | 90.43±0.13 | 85.35±0.26 | |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
调参
gammagl/models/sp2gcl.py
Outdated
self.phi = nn.Sequential(nn.Linear(1, 16), nn.ReLU(), nn.Linear(16, 16)) | ||
self.psi = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 1)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
delete
tests/datasets/test_facebook.py
Outdated
dataset = FacebookPagePage(root='data/facebook') | ||
g = dataset[0] | ||
pass |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
节点数量,边数量,节点特征维度,类别数量都判断一下
examples/sp2_gcl/sp2gcl_trainer.py
Outdated
parser.add_argument('--seed', type=int, default=0) | ||
parser.add_argument('--cuda', type=int, default=3) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
seed去点,设置device参考其他trainer写法
examples/sp2_gcl/sp2gcl_trainer.py
Outdated
def compute_laplacian(data): | ||
|
||
edge_index = data.edge_index | ||
num_nodes = data.num_nodes | ||
row, col = edge_index | ||
data_adj = csr_matrix((np.ones(len(row)), (row, col)), shape=(num_nodes, num_nodes)) | ||
degree = np.array(data_adj.sum(axis=1)).flatten() | ||
deg_inv_sqrt = 1.0 / np.sqrt(degree) | ||
deg_inv_sqrt[np.isinf(deg_inv_sqrt)] = 0 | ||
I = csr_matrix(np.eye(num_nodes)) | ||
D_inv_sqrt = csr_matrix((deg_inv_sqrt, (np.arange(num_nodes), np.arange(num_nodes)))) | ||
L = I - D_inv_sqrt.dot(data_adj).dot(D_inv_sqrt) | ||
e, u = scipy.sparse.linalg.eigsh(L, k=100, which='SM', tol=1e-3) | ||
data.e = tlx.convert_to_tensor(e, dtype=tlx.float32) | ||
data.u = tlx.convert_to_tensor(u, dtype=tlx.float32) | ||
|
||
return data |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
试试用 get_laplacian
接口替换
Description
Checklist
Please feel free to remove inapplicable items for your PR.
or have been fixed to be compatible with this change
Changes