diff --git a/src/cpp/src/data/samplers/neighbor.cpp b/src/cpp/src/data/samplers/neighbor.cpp index bce72230..7ed6bf5b 100644 --- a/src/cpp/src/data/samplers/neighbor.cpp +++ b/src/cpp/src/data/samplers/neighbor.cpp @@ -496,7 +496,7 @@ DENSEGraph LayeredNeighborSampler::getNeighbors(torch::Tensor node_ids, shared_p } if (outgoing_offsets.defined()) { - if (delta_outgoing_edges.size(0) > 0) { + if (delta_outgoing_offsets.size(0) > 0) { outgoing_offsets = outgoing_offsets + delta_outgoing_edges.size(0); outgoing_offsets = torch::cat({delta_outgoing_offsets, outgoing_offsets}, 0); } @@ -632,7 +632,7 @@ torch::Tensor LayeredNeighborSampler::computeDeltaIdsHelperMethod1(torch::Tensor auto device_options = torch::TensorOptions().dtype(torch::kInt64).device(node_ids.device()); std::vector sub_deltas = std::vector(num_threads); - int64_t upper_bound = (int64_t)(delta_incoming_edges.size(0) + delta_outgoing_edges.size(0)) / num_threads; + int64_t upper_bound = (int64_t)(delta_incoming_edges.size(0) + delta_outgoing_edges.size(0)) / num_threads + 1; std::vector sub_counts = std::vector(num_threads, 0); std::vector sub_offsets = std::vector(num_threads, 0); diff --git a/src/cpp/src/nn/layers/gnn/graph_sage_layer.cpp b/src/cpp/src/nn/layers/gnn/graph_sage_layer.cpp index 8809c59e..08417e1e 100644 --- a/src/cpp/src/nn/layers/gnn/graph_sage_layer.cpp +++ b/src/cpp/src/nn/layers/gnn/graph_sage_layer.cpp @@ -81,9 +81,14 @@ torch::Tensor GraphSageLayer::forward(torch::Tensor inputs, DENSEGraph dense_gra a_i = a_i / (total_num_neighbors + 1).unsqueeze(-1); outputs = torch::matmul(w1_, a_i.transpose(0, -1)).transpose(0, -1); } else if (options_->aggregator == GraphSageAggregator::MEAN) { - torch::Tensor denominator = torch::where(torch::not_equal(total_num_neighbors, 0), total_num_neighbors, 1).to(a_i.dtype()).unsqueeze(-1); - a_i = a_i / denominator; - outputs = (torch::matmul(w1_, self_embs.transpose(0, -1)) + torch::matmul(w2_, a_i.transpose(0, -1))).transpose(0, -1); + if (total_num_neighbors.defined()) { + torch::Tensor denominator = torch::where(torch::not_equal(total_num_neighbors, 0), total_num_neighbors, 1).to(a_i.dtype()).unsqueeze(-1); + a_i = a_i / denominator; + outputs = (torch::matmul(w1_, self_embs.transpose(0, -1)) + torch::matmul(w2_, a_i.transpose(0, -1))).transpose(0, -1); + } else { + outputs = torch::matmul(w1_, self_embs.transpose(0, -1)).transpose(0, -1); + } + } else { throw std::runtime_error("Unrecognized aggregator"); }