-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathmodels.py
132 lines (104 loc) · 4.8 KB
/
models.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
import matplotlib.pyplot as plt
import networkx as nx
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from torch_geometric.nn import global_mean_pool as gap, global_max_pool as gmp
from layers import GCN, MVPool
class GraphClassificationModel(torch.nn.Module):
    """Graph-level classifier: three (conv -> MVPool) stages plus an MLP head.

    After every stage the pooled node features are summarised per graph by
    concatenating a global max readout and a global mean readout. The three
    readouts are ReLU'd and summed, then fed through
    Linear -> ReLU -> dropout -> Linear -> ReLU -> dropout -> Linear,
    followed by ``log_softmax`` over the last dimension.

    Args:
        args: configuration object. Reads ``num_features``, ``nhid``,
            ``num_classes``, ``pooling_ratio``, ``dropout_ratio``,
            ``sample_neighbor``, ``sparse_attention``, ``structure_learning``
            and ``lamb``. It is also forwarded unchanged to every ``MVPool``.
    """

    def __init__(self, args):
        super().__init__()
        self.args = args
        self.num_features = args.num_features
        self.nhid = args.nhid
        self.num_classes = args.num_classes
        self.pooling_ratio = args.pooling_ratio
        self.dropout_ratio = args.dropout_ratio
        # The four values below are not read anywhere in this class (MVPool
        # receives ``args`` itself); presumably kept for external callers.
        self.sample = args.sample_neighbor
        self.sparse = args.sparse_attention
        self.sl = args.structure_learning
        self.lamb = args.lamb

        hidden = self.nhid
        # Keep this construction order: it fixes parameter initialisation
        # under a given seed and the state_dict key order.
        self.conv1 = GCNConv(self.num_features, hidden)
        self.conv2 = GCN(hidden, hidden)
        self.conv3 = GCN(hidden, hidden)
        self.pool1 = MVPool(hidden, self.pooling_ratio, args)
        self.pool2 = MVPool(hidden, self.pooling_ratio, args)
        self.pool3 = MVPool(hidden, self.pooling_ratio, args)
        # Each readout is [max || mean], hence 2 * hidden input features.
        self.lin1 = torch.nn.Linear(2 * hidden, hidden)
        self.lin2 = torch.nn.Linear(hidden, hidden // 2)
        self.lin3 = torch.nn.Linear(hidden // 2, self.num_classes)

    def forward(self, data):
        """Return log-probabilities (``log_softmax`` over the last dim).

        ``data`` must expose ``x``, ``edge_index`` and ``batch``. Edge
        attributes start as ``None``; after each pool they are whatever
        that pool returns.
        """
        x = data.x
        edge_index = data.edge_index
        batch = data.batch
        edge_attr = None

        stages = (
            (self.conv1, self.pool1),
            (self.conv2, self.pool2),
            (self.conv3, self.pool3),
        )
        readouts = []
        for conv, pool in stages:
            x = F.relu(conv(x, edge_index, edge_attr))
            x, edge_index, edge_attr, batch, _ = pool(x, edge_index, edge_attr, batch)
            readouts.append(torch.cat([gmp(x, batch), gap(x, batch)], dim=1))

        first, second, third = readouts
        x = F.relu(first) + F.relu(second) + F.relu(third)

        for lin in (self.lin1, self.lin2):
            x = F.dropout(F.relu(lin(x)), p=self.dropout_ratio, training=self.training)
        return F.log_softmax(self.lin3(x), dim=-1)
class NodeClassificationModel(torch.nn.Module):
    """U-Net-style node classifier built from GCN layers and MVPool pooling.

    Encoder: dropout on the input features, then ``GCNConv``, then ``depth``
    rounds of (MVPool -> GCN).

    Decoder: walks back up the same levels. At each level, the coarse
    features are scattered into a zero tensor at the indices (``perm``)
    returned by that level's pool. The result is merged with the saved skip
    features (summed if ``sum_res``, else concatenated) and passed through a
    GCN. A final GCN maps to ``args.num_classes`` outputs with no activation
    applied.

    Args:
        args: configuration object. Reads ``depth``, ``num_features``,
            ``nhid``, ``num_classes`` and ``pool1`` ... ``pool5`` (one pooling
            ratio per level). It is also forwarded unchanged to every
            ``MVPool``.
        sum_res: merge skip connections by addition instead of concatenation.
        act: activation applied after every conv except the final one.
        input_dropout: dropout probability applied to the input features in
            ``forward``. Defaults to 0.92, the value previously hard-coded.

    Raises:
        ValueError: if ``args.depth`` is not between 1 and the number of
            configured pooling ratios (5).
    """

    def __init__(self, args, sum_res=False, act=F.relu, input_dropout=0.92):
        super(NodeClassificationModel, self).__init__()
        self.in_channels = args.num_features
        self.hidden_channels = args.nhid
        self.out_channels = args.num_classes
        self.depth = args.depth
        self.pool_ratios = [args.pool1, args.pool2, args.pool3, args.pool4, args.pool5]
        # Explicit check rather than ``assert`` (stripped under ``python -O``).
        # It also covers depth > 5, which used to fail as a bare IndexError
        # while indexing ``self.pool_ratios`` below.
        if not 1 <= self.depth <= len(self.pool_ratios):
            raise ValueError(
                "args.depth must be between 1 and %d (one pooling ratio per "
                "level), got %r" % (len(self.pool_ratios), self.depth)
            )
        self.act = act
        self.sum_res = sum_res
        self.input_dropout = input_dropout

        channels = self.hidden_channels
        # Construction order (down_convs[0], then interleaved pool/conv pairs,
        # then up_convs) is kept: it fixes seeded initialisation and the
        # state_dict layout.
        self.down_convs = torch.nn.ModuleList()
        self.pools = torch.nn.ModuleList()
        self.down_convs.append(GCNConv(self.in_channels, channels))
        for i in range(self.depth):
            self.pools.append(MVPool(channels, self.pool_ratios[i], args))
            self.down_convs.append(GCN(channels, channels))

        # Concatenated skip connections double the decoder input width.
        in_channels = channels if sum_res else 2 * channels
        self.up_convs = torch.nn.ModuleList()
        for i in range(self.depth):
            self.up_convs.append(GCN(in_channels, channels))
        self.up_convs.append(GCN(channels, self.out_channels))

    def forward(self, x, edge_index, batch=None):
        """Return per-node outputs of width ``out_channels`` (no activation).

        Args:
            x: node feature matrix (rows indexed by node).
            edge_index: edge index tensor; its ``size(1)`` is the edge count.
            batch: graph-assignment vector passed to the pools. Defaults to
                all zeros, i.e. a single graph.
        """
        if batch is None:
            batch = edge_index.new_zeros(x.size(0))
        # All edges start with unit weight; each pool returns its own weights.
        edge_weight = x.new_ones(edge_index.size(1))

        x = F.dropout(x, p=self.input_dropout, training=self.training)
        x = self.act(self.down_convs[0](x, edge_index, edge_weight))

        # Per-level state saved for the decoder. The deepest level is not
        # stored because it is ``x`` itself when decoding starts.
        xs = [x]
        edge_indices = [edge_index]
        edge_weights = [edge_weight]
        perms = []
        for i in range(1, self.depth + 1):
            x, edge_index, edge_weight, batch, perm = self.pools[i - 1](
                x, edge_index, edge_weight, batch
            )
            x = self.act(self.down_convs[i](x, edge_index, edge_weight))
            if i < self.depth:
                xs.append(x)
                edge_indices.append(edge_index)
                edge_weights.append(edge_weight)
            perms.append(perm)

        for i in range(self.depth):
            j = self.depth - 1 - i
            res = xs[j]
            edge_index = edge_indices[j]
            edge_weight = edge_weights[j]
            # Unpool: write the coarse features back at the rows selected by
            # pool j (assumes len(perms[j]) == x.size(0); TODO confirm in MVPool).
            up = torch.zeros_like(res)
            up[perms[j]] = x
            x = res + up if self.sum_res else torch.cat((res, up), dim=-1)
            x = self.act(self.up_convs[i](x, edge_index, edge_weight))

        return self.up_convs[-1](x, edge_index, edge_weight)