diff --git a/src/cpp/src/data/dataloader.cpp b/src/cpp/src/data/dataloader.cpp
index a15e9447..f52f1ee5 100644
--- a/src/cpp/src/data/dataloader.cpp
+++ b/src/cpp/src/data/dataloader.cpp
@@ -411,7 +411,7 @@ void DataLoader::getBatchHelper(shared_ptr<Batch> batch, int worker_id) {
         batch->dense_graph_.partition_size_ = graph_storage_->getPartitionSize();
     }
 
-    loadCPUParameters(batch);
+//    loadCPUParameters(batch);
 }
 
 shared_ptr<Batch> DataLoader::getBatch(at::optional<torch::Device> device, bool perform_map, int worker_id) {
@@ -480,6 +480,8 @@ shared_ptr<Batch> DataLoader::getBatch(at::optional<torch::Device> device, bool
 
     batch->sub_batches_ = sub_batches;
 
+//    loadCPUParameters(batch);
+
     return batch;
 }
 
@@ -645,7 +647,61 @@ void DataLoader::loadCPUParameters(shared_ptr<Batch> batch) {
         if (only_root_features_) {
            batch->node_features_ = graph_storage_->getNodeFeatures(batch->root_node_indices_);
        } else {
-            batch->node_features_ = graph_storage_->getNodeFeatures(batch->unique_node_indices_);
+//            batch->node_features_ = graph_storage_->getNodeFeatures(batch->unique_node_indices_);
+
+            if (batch->sub_batches_.size() > 0) {
+                // Collect the per-sub-batch node indices so they can be
+                // deduplicated across all sub-batches at once.
+                std::vector<torch::Tensor> all_unique_nodes_vec(batch->sub_batches_.size());
+
+//                #pragma omp parallel for  // TODO
+                for (int i = 0; i < batch->sub_batches_.size(); i++) {
+                    all_unique_nodes_vec[i] = batch->sub_batches_[i]->unique_node_indices_;
+                }
+
+                torch::Tensor all_unique_nodes = torch::cat(all_unique_nodes_vec, 0);
+                auto unique_nodes = torch::_unique2(all_unique_nodes, true, true, false);
+                torch::Tensor unique_indices = std::get<0>(unique_nodes);
+                torch::Tensor inverse = std::get<1>(unique_nodes);
+
+                // One storage read for the union of the sub-batch indices:
+                // each feature is fetched exactly once.
+                torch::Tensor unique_features = graph_storage_->getNodeFeatures(unique_indices);
+
+                int count = 0;
+                int count1 = 0;
+                int split_size = ceil((float) unique_features.size(0) / batch->sub_batches_.size());
+                for (int i = 0; i < batch->sub_batches_.size(); i++) {
+                    if (!batch->sub_batches_[i]->node_features_.defined()) {
+                        // Remap this sub-batch's indices into positions in the
+                        // deduplicated feature table.
+                        batch->sub_batches_[i]->unique_node_indices_ = inverse.narrow(0, count, batch->sub_batches_[i]->unique_node_indices_.size(0));
+                        count += batch->sub_batches_[i]->unique_node_indices_.size(0);
+
+                        // Hand each sub-batch an even chunk of the table; the
+                        // chunks are exchanged and re-assembled on the GPUs.
+                        int size = split_size;
+                        if (count1 + split_size > unique_features.size(0)) size = unique_features.size(0) - count1;
+                        batch->sub_batches_[i]->node_features_ = unique_features.narrow(0, count1, size);
+                        count1 += size;
+                    }
+                }
+            } else {
+                batch->node_features_ = graph_storage_->getNodeFeatures(batch->unique_node_indices_);
+            }
         }
     }
 }
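The dataloader change above implements a dedup-and-chunk scheme: the per-sub-batch node indices are unioned with `torch::_unique2`, features for the union are read from storage once, each sub-batch's `unique_node_indices_` is remapped through the inverse mapping into positions in the deduplicated table, and the table is split into even chunks across sub-batches. Below is a minimal standalone sketch of that scheme using plain LibTorch tensors, with a random `feature_table` standing in for `GraphModelStorage::getNodeFeatures`; all names in it are illustrative, not part of the patch.

```cpp
#include <torch/torch.h>
#include <algorithm>
#include <iostream>
#include <vector>

int main() {
    // Per-sub-batch node indices; note the overlap across sub-batches.
    std::vector<torch::Tensor> sub_batch_indices = {torch::tensor({0, 3, 5, 3}),
                                                    torch::tensor({5, 1, 0})};
    torch::Tensor feature_table = torch::rand({8, 4});  // stand-in for storage

    // 1) Union of all indices, plus an inverse map back to each position.
    torch::Tensor all_indices = torch::cat(sub_batch_indices, 0);
    auto unique_result = torch::_unique2(all_indices, /*sorted=*/true,
                                         /*return_inverse=*/true,
                                         /*return_counts=*/false);
    torch::Tensor unique_indices = std::get<0>(unique_result);
    torch::Tensor inverse = std::get<1>(unique_result);

    // 2) One gather for the union: each feature is read exactly once.
    torch::Tensor unique_features = feature_table.index_select(0, unique_indices);

    // 3) Each sub-batch keeps its slice of the inverse map and an even chunk
    //    of the unique features, mirroring the narrow() logic in the patch.
    int64_t num_sub_batches = sub_batch_indices.size();
    int64_t split_size =
        (unique_features.size(0) + num_sub_batches - 1) / num_sub_batches;
    int64_t inv_off = 0, feat_off = 0;
    for (int64_t i = 0; i < num_sub_batches; i++) {
        int64_t n = sub_batch_indices[i].size(0);
        torch::Tensor positions = inverse.narrow(0, inv_off, n);
        inv_off += n;
        int64_t chunk = std::min(split_size, unique_features.size(0) - feat_off);
        torch::Tensor feature_chunk = unique_features.narrow(0, feat_off, chunk);
        feat_off += chunk;
        std::cout << "sub-batch " << i << ": positions " << positions.sizes()
                  << ", chunk " << feature_chunk.sizes() << "\n";
    }
    return 0;
}
```

With two sub-batches sharing nodes 0 and 5, the union has 5 entries instead of 7, so two of seven feature reads are saved. Because `positions` indexes into the full deduplicated table while each sub-batch only holds a chunk of it, the chunks must be re-assembled on every GPU before the lookup, which is what the pipeline_gpu.cpp change below does.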
diff --git a/src/cpp/src/pipeline/pipeline_gpu.cpp b/src/cpp/src/pipeline/pipeline_gpu.cpp
index 817d4049..58cd2bfe 100644
--- a/src/cpp/src/pipeline/pipeline_gpu.cpp
+++ b/src/cpp/src/pipeline/pipeline_gpu.cpp
@@ -151,17 +151,13 @@ void BatchToDeviceWorker::run() {
             }
 
             if (batch->sub_batches_.size() > 0) {
-//                #pragma omp parallel for  // TODO
-                for (int i = 0; i < batch->sub_batches_.size(); i++) {
-                    if (!batch->sub_batches_[i]->node_features_.defined())
-                        batch->sub_batches_[i]->node_features_ = pipeline_->dataloader_->graph_storage_->getNodeFeatures(batch->sub_batches_[i]->unique_node_indices_);
-//                    batch->sub_batches_[i]->node_labels_ = pipeline_->dataloader_->graph_storage_->getNodeLabels(
-//                        batch->sub_batches_[i]->dense_graph_.node_ids_.narrow(0, batch->sub_batches_[i]->dense_graph_.hop_offsets_[-2].item<int64_t>(),
-//                        (batch->sub_batches_[i]->dense_graph_.node_ids_.size(0) - batch->sub_batches_[i]->dense_graph_.hop_offsets_[-2]).item<int64_t>())).flatten(0, 1);
-                }
+                if (!batch->sub_batches_[0]->node_features_.defined()) {
+                    pipeline_->dataloader_->loadCPUParameters(batch);
+                }
             } else {
                 if (!batch->node_features_.defined())
-                    batch->node_features_ = pipeline_->dataloader_->graph_storage_->getNodeFeatures(batch->unique_node_indices_);
+                    pipeline_->dataloader_->loadCPUParameters(batch);
+//                    batch->node_features_ = pipeline_->dataloader_->graph_storage_->getNodeFeatures(batch->unique_node_indices_);
 //                batch->node_labels_ = pipeline_->dataloader_->graph_storage_->getNodeLabels(
 //                    batch->dense_graph_.node_ids_.narrow(0, batch->dense_graph_.hop_offsets_[-2].item<int64_t>(),
 //                    (batch->dense_graph_.node_ids_.size(0) - batch->dense_graph_.hop_offsets_[-2]).item<int64_t>())).flatten(0, 1);
@@ -205,11 +201,48 @@ void ComputeWorkerGPU::run() {
                     streams_for_multi_guard.emplace_back(*(pipeline_->dataloader_->compute_streams_[i]));
                 }
 
+                // Each GPU holds one chunk of the deduplicated feature table;
+                // compute its total size so the full table can be rebuilt.
+                int unique_size = 0;
+                int feat_dim = batch->sub_batches_[0]->node_features_.size(1);
+                for (int i = 0; i < batch->sub_batches_.size(); i++) {
+                    unique_size += batch->sub_batches_[i]->node_features_.size(0);
+                }
+                std::vector<torch::Tensor> unique_features_per_gpu(batch->sub_batches_.size());
+
                 #pragma omp parallel for
                 for (int i = 0; i < batch->sub_batches_.size(); i++) {
                     CudaStreamGuard stream_guard(*(pipeline_->dataloader_->compute_streams_[i]));
+                    auto device_options = torch::TensorOptions().dtype(torch::kFloat16).device(batch->sub_batches_[i]->node_features_.device());
+
+                    // Materialize the full table on this GPU by copying in the
+                    // chunks held by every sub-batch.
+                    torch::Tensor unique_node_features = torch::zeros({unique_size, feat_dim}, device_options);
+                    int count = 0;
+                    for (int j = 0; j < batch->sub_batches_.size(); j++) {
+                        unique_node_features.narrow(0, count, batch->sub_batches_[j]->node_features_.size(0)).copy_(batch->sub_batches_[j]->node_features_);
+                        count += batch->sub_batches_[j]->node_features_.size(0);
+                    }
+                    unique_features_per_gpu[i] = unique_node_features;
+                }
+
+                #pragma omp parallel for
+                for (int i = 0; i < batch->sub_batches_.size(); i++) {
+                    CudaStreamGuard stream_guard(*(pipeline_->dataloader_->compute_streams_[i]));
+                    auto device_options = torch::TensorOptions().dtype(torch::kFloat16).device(batch->sub_batches_[i]->node_features_.device());
+
+                    // Expand the table into this sub-batch's features using the
+                    // remapped unique_node_indices_.
+                    batch->sub_batches_[i]->node_features_ = torch::zeros({batch->sub_batches_[i]->unique_node_indices_.size(0), feat_dim}, device_options);
+                    torch::index_select_out(batch->sub_batches_[i]->node_features_, unique_features_per_gpu[i], 0, batch->sub_batches_[i]->unique_node_indices_);
+
                     pipeline_->model_->device_models_[i]->clear_grad();
                     pipeline_->model_->device_models_[i]->train_batch(batch->sub_batches_[i], false);
                 }
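On the compute side, the patch rebuilds the full deduplicated table on every GPU (each GPU contributes its chunk) and then expands it into per-sub-batch features. Here is a CPU-only sketch of that re-assembly, with stream guards, devices, and half precision omitted; `chunks` and `positions` are illustrative stand-ins for the sub-batches' `node_features_` and remapped `unique_node_indices_`.

```cpp
#include <torch/torch.h>
#include <iostream>
#include <vector>

int main() {
    const int64_t feat_dim = 4;
    // Chunks of the deduplicated feature table, one per sub-batch/GPU.
    std::vector<torch::Tensor> chunks = {torch::rand({3, feat_dim}),
                                         torch::rand({2, feat_dim})};
    // Remapped indices: positions into the full deduplicated table.
    std::vector<torch::Tensor> positions = {torch::tensor({0, 4, 2, 4}),
                                            torch::tensor({1, 3, 0})};

    int64_t unique_size = 0;
    for (const auto& chunk : chunks) unique_size += chunk.size(0);

    // 1) Materialize the full table by stacking the chunks; on a real GPU this
    //    is a cross-device copy_ into a preallocated tensor, as in the patch.
    torch::Tensor table = torch::zeros({unique_size, feat_dim});
    int64_t offset = 0;
    for (const auto& chunk : chunks) {
        table.narrow(0, offset, chunk.size(0)).copy_(chunk);
        offset += chunk.size(0);
    }

    // 2) Expand the table into per-sub-batch features, writing into a
    //    preallocated output exactly like the index_select_out call above.
    for (size_t i = 0; i < positions.size(); i++) {
        torch::Tensor features = torch::zeros({positions[i].size(0), feat_dim});
        torch::index_select_out(features, table, 0, positions[i]);
        std::cout << "sub-batch " << i << ": " << features.sizes() << "\n";
    }
    return 0;
}
```

Note the trade-off this encodes: every GPU transiently holds the whole deduplicated table, spending peak GPU memory and cross-GPU `copy_` traffic to get a single storage read per feature instead of one redundant host read per sub-batch.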