Commit

runtime fiddling
Roger Waleffe authored and committed Nov 21, 2023
1 parent cd55db7 commit 8bbd9e6
Showing 3 changed files with 26 additions and 7 deletions.
src/cpp/src/data/batch.cpp (4 changes: 2 additions & 2 deletions)
@@ -91,7 +91,7 @@ void Batch::remoteTo(shared_ptr<c10d::ProcessGroupGloo> pg, int worker_id, int t
}

if (sub_batches_.size() > 0) {
#pragma omp parallel for // TODO: need to look at whether this works or not (e.g., parallel sending)
// #pragma omp parallel for // TODO: need to look at whether this works or not (e.g., parallel sending)
for (int i = 0; i < sub_batches_.size(); i++) {
sub_batches_[i]->remoteTo(pg, worker_id, tag+i, false);
}
@@ -164,7 +164,7 @@ void Batch::remoteReceive(shared_ptr<c10d::ProcessGroupGloo> pg, int worker_id,
}

if (sub_batches_.size() > 0) {
#pragma omp parallel for // TODO: need to look at whether this works or not (e.g., parallel sending)
// #pragma omp parallel for // TODO: need to look at whether this works or not (e.g., parallel sending)
for (int i = 0; i < sub_batches_.size(); i++) {
sub_batches_[i]->remoteReceive(pg, worker_id, tag + i, false);
}
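Note on the change above: with the OpenMP pragma commented out, the per-sub-batch remoteTo / remoteReceive calls now run serially, each transfer keyed by its own tag (tag + i). The sketch below is a hypothetical, minimal illustration of that serial, tagged send pattern; send_sub_batches_serial and its parameters are invented for this example, and it uses the generic c10d ProcessGroup::send / Work::wait calls rather than Marius's send_tensor helper.

#include <memory>
#include <vector>
#include <torch/torch.h>
#include <torch/csrc/distributed/c10d/ProcessGroupGloo.hpp>

// Hypothetical helper (not Marius code): send one tensor per sub-batch, one after
// another, giving each transfer its own tag so the receiver can match messages to
// sub-batches, mirroring the tag + i scheme used by Batch::remoteTo above.
void send_sub_batches_serial(const std::shared_ptr<c10d::ProcessGroupGloo>& pg,
                             std::vector<torch::Tensor> sub_batch_tensors,
                             int dst_rank, int base_tag) {
    for (int i = 0; i < static_cast<int>(sub_batch_tensors.size()); i++) {
        // The point-to-point API takes a vector of tensors, a destination rank, and a tag.
        std::vector<torch::Tensor> payload = {sub_batch_tensors[i]};
        auto work = pg->send(payload, dst_rank, base_tag + i);
        work->wait();  // finish this transfer before starting the next one
    }
}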
src/cpp/src/data/graph.cpp (21 changes: 17 additions & 4 deletions)
@@ -294,12 +294,20 @@ void DENSEGraph::to(torch::Device device, CudaStream *compute_stream, CudaStream
}

void DENSEGraph::send(shared_ptr<c10d::ProcessGroupGloo> pg, int worker_id, int tag) {
send_tensor(node_ids_, pg, worker_id, tag);
send_tensor(hop_offsets_, pg, worker_id, tag);

send_tensor(node_ids_, pg, worker_id, tag);
send_tensor(out_offsets_, pg, worker_id, tag);

send_tensor(in_offsets_, pg, worker_id, tag);
// send_tensor(torch::cat({out_offsets_, in_offsets_}, 0), pg, worker_id, tag);
//
// torch::Tensor tmp = torch::cat({out_neighbors_vec_}, 0);
// out_neighbors_vec_ = {};
// out_neighbors_vec_.emplace_back(tmp);
//
// tmp = torch::cat({in_neighbors_vec_}, 0);
// in_neighbors_vec_ = {};
// in_neighbors_vec_.emplace_back(tmp);

int in_size = in_neighbors_vec_.size();
int out_size = out_neighbors_vec_.size();
@@ -326,12 +334,17 @@ void DENSEGraph::send(shared_ptr<c10d::ProcessGroupGloo> pg, int worker_id, int
}

void DENSEGraph::receive(shared_ptr<c10d::ProcessGroupGloo> pg, int worker_id, int tag) {
node_ids_ = receive_tensor(pg, worker_id, tag);
hop_offsets_ = receive_tensor(pg, worker_id, tag);

node_ids_ = receive_tensor(pg, worker_id, tag);
out_offsets_ = receive_tensor(pg, worker_id, tag);

in_offsets_ = receive_tensor(pg, worker_id, tag);
//
// torch::Tensor offsets = receive_tensor(pg, worker_id, tag);
// out_offsets_ = offsets.narrow(0, 0, offsets.size(0)/2);
// in_offsets_ = offsets.narrow(0, offsets.size(0)/2, offsets.size(0)/2);

// if (!torch::equal(out_offsets_, out_offsets1_)) throw MariusRuntimeException("");

torch::Tensor metadata = torch::tensor({-1, -1, -1, -1}, {torch::kInt64});
std::vector<torch::Tensor> transfer_vec;
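The commented-out lines in DENSEGraph::send and DENSEGraph::receive above sketch an alternative message layout: pack out_offsets_ and in_offsets_ into one tensor with torch::cat before sending, then split it back with narrow on the receiving side. The standalone snippet below reproduces just that pack/unpack round trip on made-up tensors, with no communication involved; it assumes, as the narrow calls in the diff do, that the two halves have equal length.

#include <torch/torch.h>

int main() {
    // Toy offsets standing in for out_offsets_ and in_offsets_ (values are made up).
    torch::Tensor out_offsets = torch::tensor({0, 4, 9}, torch::kInt64);
    torch::Tensor in_offsets = torch::tensor({0, 2, 7}, torch::kInt64);

    // Sender side: one packed tensor instead of two separate messages.
    torch::Tensor packed = torch::cat({out_offsets, in_offsets}, 0);

    // Receiver side: split the packed tensor back into its two halves.
    int64_t half = packed.size(0) / 2;
    torch::Tensor out_recovered = packed.narrow(0, 0, half);
    torch::Tensor in_recovered = packed.narrow(0, half, half);

    TORCH_CHECK(torch::equal(out_recovered, out_offsets));
    TORCH_CHECK(torch::equal(in_recovered, in_offsets));
    return 0;
}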
src/cpp/src/pipeline/pipeline_gpu.cpp (8 changes: 7 additions & 1 deletion)
@@ -151,14 +151,20 @@ void BatchToDeviceWorker::run() {
}

if (batch->sub_batches_.size() > 0) {
#pragma omp parallel for
// #pragma omp parallel for # TODO
for (int i = 0; i < batch->sub_batches_.size(); i++) {
if (!batch->sub_batches_[i]->node_features_.defined())
batch->sub_batches_[i]->node_features_ = pipeline_->dataloader_->graph_storage_->getNodeFeatures(batch->sub_batches_[i]->unique_node_indices_);
// batch->sub_batches_[i]->node_labels_ = pipeline_->dataloader_->graph_storage_->getNodeLabels(
// batch->sub_batches_[i]->dense_graph_.node_ids_.narrow(0, batch->sub_batches_[i]->dense_graph_.hop_offsets_[-2].item<int64_t>(),
// (batch->sub_batches_[i]->dense_graph_.node_ids_.size(0)-batch->sub_batches_[i]->dense_graph_.hop_offsets_[-2]).item<int64_t>())).flatten(0, 1);
}
} else {
if (!batch->node_features_.defined())
batch->node_features_ = pipeline_->dataloader_->graph_storage_->getNodeFeatures(batch->unique_node_indices_);
// batch->node_labels_ = pipeline_->dataloader_->graph_storage_->getNodeLabels(
// batch->dense_graph_.node_ids_.narrow(0, batch->dense_graph_.hop_offsets_[-2].item<int64_t>(),
// (batch->dense_graph_.node_ids_.size(0)-batch->dense_graph_.hop_offsets_[-2]).item<int64_t>())).flatten(0, 1);
}

batchToDevice(pipeline_, batch);
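The commented-out label lookup above narrows dense_graph_.node_ids_ from hop_offsets_[-2] to the end, i.e. it selects the ids of the final hop before handing them to getNodeLabels. The snippet below reproduces only that slicing step on toy tensors; the values and the convention that the final hop sits at the end of node_ids_ are assumptions for illustration, not taken from Marius.

#include <torch/torch.h>

int main() {
    // Toy layout: 10 node ids; hop_offsets marks where each hop starts, with the last
    // entry equal to the total count (assumed convention for this illustration).
    torch::Tensor node_ids = torch::arange(10, torch::kInt64);
    torch::Tensor hop_offsets = torch::tensor({0, 6, 10}, torch::kInt64);

    // hop_offsets[-2] is where the final hop begins; narrow keeps everything from there on.
    int64_t start = hop_offsets[-2].item<int64_t>();
    int64_t length = node_ids.size(0) - start;
    torch::Tensor last_hop_ids = node_ids.narrow(0, start, length);

    // These ids (6 through 9) are what the commented-out code would pass to getNodeLabels.
    TORCH_CHECK(torch::equal(last_hop_ids, torch::arange(6, 10, torch::kInt64)));
    return 0;
}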
