Skip to content

Commit

Permalink
fixed bugs that appear with small batch sizes or small neighborhoods (#…
Browse files Browse the repository at this point in the history
…147)

Co-authored-by: Roger Waleffe <[email protected]>
  • Loading branch information
rogerwaleffe and Roger Waleffe authored Oct 28, 2023
1 parent e10cade commit 841145c
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 5 deletions.
4 changes: 2 additions & 2 deletions src/cpp/src/data/samplers/neighbor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -496,7 +496,7 @@ DENSEGraph LayeredNeighborSampler::getNeighbors(torch::Tensor node_ids, shared_p
}

if (outgoing_offsets.defined()) {
if (delta_outgoing_edges.size(0) > 0) {
if (delta_outgoing_offsets.size(0) > 0) {
outgoing_offsets = outgoing_offsets + delta_outgoing_edges.size(0);
outgoing_offsets = torch::cat({delta_outgoing_offsets, outgoing_offsets}, 0);
}
Expand Down Expand Up @@ -632,7 +632,7 @@ torch::Tensor LayeredNeighborSampler::computeDeltaIdsHelperMethod1(torch::Tensor

auto device_options = torch::TensorOptions().dtype(torch::kInt64).device(node_ids.device());
std::vector<torch::Tensor> sub_deltas = std::vector<torch::Tensor>(num_threads);
int64_t upper_bound = (int64_t)(delta_incoming_edges.size(0) + delta_outgoing_edges.size(0)) / num_threads;
int64_t upper_bound = (int64_t)(delta_incoming_edges.size(0) + delta_outgoing_edges.size(0)) / num_threads + 1;

std::vector<int> sub_counts = std::vector<int>(num_threads, 0);
std::vector<int> sub_offsets = std::vector<int>(num_threads, 0);
Expand Down
11 changes: 8 additions & 3 deletions src/cpp/src/nn/layers/gnn/graph_sage_layer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -81,9 +81,14 @@ torch::Tensor GraphSageLayer::forward(torch::Tensor inputs, DENSEGraph dense_gra
a_i = a_i / (total_num_neighbors + 1).unsqueeze(-1);
outputs = torch::matmul(w1_, a_i.transpose(0, -1)).transpose(0, -1);
} else if (options_->aggregator == GraphSageAggregator::MEAN) {
torch::Tensor denominator = torch::where(torch::not_equal(total_num_neighbors, 0), total_num_neighbors, 1).to(a_i.dtype()).unsqueeze(-1);
a_i = a_i / denominator;
outputs = (torch::matmul(w1_, self_embs.transpose(0, -1)) + torch::matmul(w2_, a_i.transpose(0, -1))).transpose(0, -1);
if (total_num_neighbors.defined()) {
torch::Tensor denominator = torch::where(torch::not_equal(total_num_neighbors, 0), total_num_neighbors, 1).to(a_i.dtype()).unsqueeze(-1);
a_i = a_i / denominator;
outputs = (torch::matmul(w1_, self_embs.transpose(0, -1)) + torch::matmul(w2_, a_i.transpose(0, -1))).transpose(0, -1);
} else {
outputs = torch::matmul(w1_, self_embs.transpose(0, -1)).transpose(0, -1);
}

} else {
throw std::runtime_error("Unrecognized aggregator");
}
Expand Down

0 comments on commit 841145c

Please sign in to comment.