diff --git a/gammagl/models/gin.py b/gammagl/models/gin.py
index 2f9f9470..6ca89fe9 100644
--- a/gammagl/models/gin.py
+++ b/gammagl/models/gin.py
@@ -49,6 +49,7 @@ def __init__(self, in_channels,
 
         self.mlp = MLP([hidden_channels, hidden_channels, out_channels],
                        norm=None, dropout=0.5)
+        self.relu = tlx.ReLU()
 
     def forward(self, x, edge_index, batch):
         if x is None:
@@ -56,7 +57,7 @@ def forward(self, x, edge_index, batch):
             x = tlx.random_normal((batch.shape[0], 1), dtype=tlx.float32)
 
         for conv in self.convs:
-            x = tlx.relu(conv(x, edge_index))
+            x = self.relu(conv(x, edge_index))
         x = global_sum_pool(x, batch)
         return self.mlp(x)
 
diff --git a/gammagl/models/mlp.py b/gammagl/models/mlp.py
index 3c67ff3f..0cb263c7 100644
--- a/gammagl/models/mlp.py
+++ b/gammagl/models/mlp.py
@@ -46,6 +46,9 @@ def __init__(self,
             dropout[-1] = 0.
         assert len(dropout) == len(channel_list) - 1
         self.dropout = dropout
+        self.dropouts = tlx.nn.ModuleList()
+        for i in range(len(dropout)):
+            self.dropouts.append(tlx.nn.Dropout(p=dropout[i]))
 
         if isinstance(bias, bool):
             bias = [bias] * (len(channel_list) - 1)
@@ -89,12 +92,14 @@ def forward(self, x, return_emb=None):
 
             if self.act is not None and not self.act_first:
                 x = self.act(x)
-            x = tlx.nn.Dropout(p=self.dropout[i])(x)
+            # x = tlx.nn.Dropout(p=self.dropout[i])(x)
+            x = self.dropouts[i](x)
             emb = x
 
         if self.plain_last:
             x = self.lins[-1](x)
-            x = tlx.nn.Dropout(p=self.dropout[-1])(x)
+            # x = tlx.nn.Dropout(p=self.dropout[-1])(x)
+            x = self.dropouts[-1](x)
 
         return (x, emb) if isinstance(return_emb, bool) else x
 
diff --git a/gammagl/mpops/torch_ext/cpu/segment_max_cpu.cpp b/gammagl/mpops/torch_ext/cpu/segment_max_cpu.cpp
index db265bde..fade765b 100644
--- a/gammagl/mpops/torch_ext/cpu/segment_max_cpu.cpp
+++ b/gammagl/mpops/torch_ext/cpu/segment_max_cpu.cpp
@@ -34,7 +34,7 @@ std::tuple<torch::Tensor, torch::Tensor> segment_max_cpu_forward(
   auto index_data = index.data_ptr<int64_t>();
   auto arg_out_data = arg_out.data_ptr<int64_t>();
 
-  AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, x.scalar_type(), "segment_mean_cpu_forward", [&]() {
+  AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, x.scalar_type(), "segment_max_cpu_forward", [&]() {
     out.fill_(std::numeric_limits<scalar_t>::lowest());
     auto x_data = x.data_ptr<scalar_t>();
     auto out_data = out.data_ptr<scalar_t>();