Skip to content

Commit

Permalink
parallel load
Browse files Browse the repository at this point in the history
  • Loading branch information
Roger Waleffe authored and Roger Waleffe committed Nov 21, 2023
1 parent 8f17391 commit 6c55243
Showing 1 changed file with 10 additions and 6 deletions.
16 changes: 10 additions & 6 deletions src/cpp/src/data/dataloader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -691,17 +691,21 @@ void DataLoader::loadCPUParameters(shared_ptr<Batch> batch, int id, bool load) {

if (load) {
std::cout<<"load\n";
torch::Tensor unique_features = graph_storage_->getNodeFeatures(unique_indices);
// torch::Tensor unique_features = graph_storage_->getNodeFeatures(unique_indices);

int count = 0;
int split_size = (int) ceil((float) unique_features.size(0) / batch->sub_batches_.size());

#pragma omp parallel for
for (int i = 0; i < batch->sub_batches_.size(); i++) {
int size = split_size;
if (count + split_size > unique_features.size(0))
size = unique_features.size(0) - count;
int start = i*split_size;

if (start + size > unique_indices.size(0)) {
size = unique_indices.size(0) - start;
}

batch->sub_batches_[i]->node_features_ = unique_features.narrow(0, count, size);
count += size;
// batch->sub_batches_[i]->node_features_ = unique_features.narrow(0, count, size);
batch->sub_batches_[i]->node_features_ = graph_storage_->getNodeFeatures(unique_indices.narrow(0, count, size));
}
}

Expand Down

0 comments on commit 6c55243

Please sign in to comment.