Skip to content

Commit

Permalink
solve the problem of repeated definition of dropout in forward of MLP
Browse files Browse the repository at this point in the history
  • Loading branch information
gyzhou2000 committed Jun 2, 2024
1 parent 6ee25a3 commit 2403822
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 4 deletions.
3 changes: 2 additions & 1 deletion gammagl/models/gin.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,14 +49,15 @@ def __init__(self, in_channels,

self.mlp = MLP([hidden_channels, hidden_channels, out_channels],
norm=None, dropout=0.5)
self.relu = tlx.ReLU()

def forward(self, x, edge_index, batch):
if x is None:
# x = tlx.ones((batch.shape[0], 1), dtype=tlx.float32)
x = tlx.random_normal((batch.shape[0], 1), dtype=tlx.float32)

for conv in self.convs:
x = tlx.relu(conv(x, edge_index))
x = self.relu(conv(x, edge_index))

x = global_sum_pool(x, batch)
return self.mlp(x)
9 changes: 7 additions & 2 deletions gammagl/models/mlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,9 @@ def __init__(self,
dropout[-1] = 0.
assert len(dropout) == len(channel_list) - 1
self.dropout = dropout
self.dropouts = tlx.nn.ModuleList()
for i in range(len(dropout)):
self.dropouts.append(tlx.nn.Dropout(p=dropout[i]))

if isinstance(bias, bool):
bias = [bias] * (len(channel_list) - 1)
Expand Down Expand Up @@ -89,12 +92,14 @@ def forward(self, x, return_emb=None):
if self.act is not None and not self.act_first:
x = self.act(x)

x = tlx.nn.Dropout(p=self.dropout[i])(x)
# x = tlx.nn.Dropout(p=self.dropout[i])(x)
x = self.dropouts[i](x)
emb = x

if self.plain_last:
x = self.lins[-1](x)
x = tlx.nn.Dropout(p=self.dropout[-1])(x)
# x = tlx.nn.Dropout(p=self.dropout[-1])(x)
x = self.dropouts[-1](x)

return (x, emb) if isinstance(return_emb, bool) else x

Expand Down
2 changes: 1 addition & 1 deletion gammagl/mpops/torch_ext/cpu/segment_max_cpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ std::tuple<torch::Tensor, torch::Tensor> segment_max_cpu_forward(
auto index_data = index.data_ptr<int64_t>();
auto arg_out_data = arg_out.data_ptr<int64_t>();

AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, x.scalar_type(), "segment_mean_cpu_forward", [&]() {
AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, x.scalar_type(), "segment_max_cpu_forward", [&]() {
out.fill_(std::numeric_limits<scalar_t>::lowest());
auto x_data = x.data_ptr<scalar_t>();
auto out_data = out.data_ptr<scalar_t>();
Expand Down

0 comments on commit 2403822

Please sign in to comment.