faster uniques
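Replaces the torch::_unique2 call in DataLoader::loadCPUParameters with a custom parallel DataLoader::computeUniques that marks incoming node IDs in a boolean bitmap and compacts the set bits per thread. Because the returned unique IDs come back sorted, ComputeWorkerGPU::run can remap each sub-batch's indices with torch::searchsorted instead of carrying the inverse mapping from _unique2.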
Roger Waleffe authored and Roger Waleffe committed Nov 21, 2023
1 parent 0edaaa0 commit edb983f
Showing 3 changed files with 114 additions and 7 deletions.
2 changes: 2 additions & 0 deletions src/cpp/include/data/dataloader.h
@@ -182,6 +182,8 @@ class DataLoader {
*/
void unloadStorage(bool write = false) { graph_storage_->unload(write); }

/**
 * Computes the sorted unique IDs contained in node_ids using a boolean bitmap
 * over the nodes currently held in memory.
 * @param node_ids Node IDs to deduplicate; each must lie in [0, num_nodes_in_memory)
 * @param num_nodes_in_memory Number of nodes currently held in memory
 * @return Unique node IDs in ascending order
 */
torch::Tensor computeUniques(torch::Tensor node_ids, int64_t num_nodes_in_memory);

/**
* Gets the number of edges from the graph storage.
* @return Number of edges in the graph
117 changes: 110 additions & 7 deletions src/cpp/src/data/dataloader.cpp
@@ -656,7 +656,7 @@ void DataLoader::loadCPUParameters(shared_ptr<Batch> batch) {
std::vector<torch::Tensor> all_unique_nodes_vec(batch->sub_batches_.size());
// int total_unique_nodes = 0;

// #pragma omp parallel for # TODO
// #pragma omp parallel for
for (int i = 0; i < batch->sub_batches_.size(); i++) {
all_unique_nodes_vec[i] = batch->sub_batches_[i]->unique_node_indices_;
// total_unique_nodes += batch->sub_batches_[i]->unique_node_indices_.size(0);
@@ -670,13 +670,17 @@ void DataLoader::loadCPUParameters(shared_ptr<Batch> batch) {
// std::cout << "cat\n";
torch::Tensor all_unique_nodes = torch::cat({all_unique_nodes_vec}, 0);
// std::cout << all_unique_nodes.sizes() << "\n";
auto unique_nodes = torch::_unique2(all_unique_nodes, true, true, false);
torch::Tensor unique_indices = std::get<0>(unique_nodes);
torch::Tensor inverse = std::get<1>(unique_nodes);
torch::Tensor unique_features = graph_storage_->getNodeFeatures(unique_indices);
// auto unique_nodes = torch::_unique2(all_unique_nodes, true, true, false);
// torch::Tensor unique_indices = std::get<0>(unique_nodes);
// torch::Tensor inverse = std::get<1>(unique_nodes);
// torch::Tensor unique_features = graph_storage_->getNodeFeatures(unique_indices);
// std::cout << unique_indices.sizes() << "\n";
// std::cout << inverse.sizes() << " " << inverse.device() << "\n";
// std::cout << unique_features.sizes() << " " << unique_features.device() << "\n";

// Deduplicate the concatenated node IDs with the new bitmap-based routine instead of torch::_unique2.
torch::Tensor unique_indices = computeUniques(all_unique_nodes, graph_storage_->getNumNodesInMemory());
torch::Tensor unique_features = graph_storage_->getNodeFeatures(unique_indices);

std::cout<<unique_indices.size(0)<<" vs " <<all_unique_nodes.size(0)<<"\n";
t.stop();
std::cout<<"uniques: "<<t.getDuration()<<"\n";
@@ -687,9 +691,10 @@ void DataLoader::loadCPUParameters(shared_ptr<Batch> batch) {
int split_size = (int) ceil((float) unique_features.size(0) / batch->sub_batches_.size());
for (int i = 0; i < batch->sub_batches_.size(); i++) {
if (!batch->sub_batches_[i]->node_features_.defined()) {
batch->sub_batches_[i]->unique_node_indices_ = inverse.narrow(0, count, batch->sub_batches_[i]->unique_node_indices_.size(0));
count += batch->sub_batches_[i]->unique_node_indices_.size(0);
// batch->sub_batches_[i]->unique_node_indices_ = inverse.narrow(0, count, batch->sub_batches_[i]->unique_node_indices_.size(0));
// count += batch->sub_batches_[i]->unique_node_indices_.size(0);
// std::cout << batch->sub_batches_[i]->unique_node_indices_.sizes() << "\n";
// Every sub-batch shares the same sorted unique-ID list; ComputeWorkerGPU::run remaps into it via torch::searchsorted.
batch->sub_batches_[i]->root_node_indices_ = unique_indices;

int size = split_size;
if (count1 + split_size > unique_features.size(0)) size = unique_features.size(0) - count1;
@@ -796,4 +801,102 @@ void DataLoader::loadStorage() {
}
}
}
}

torch::Tensor DataLoader::computeUniques(torch::Tensor node_ids, int64_t num_nodes_in_memory) {
unsigned int num_threads = 1;

#ifdef MARIUS_OMP
#pragma omp parallel
{
#pragma omp single
num_threads = omp_get_num_threads();
}
#endif

int64_t chunk_size = ceil((double)num_nodes_in_memory / num_threads);

// One boolean slot per node held in memory; marks which IDs appear in node_ids.
auto bool_device_options = torch::TensorOptions().dtype(torch::kBool).device(node_ids.device());
torch::Tensor hash_map = torch::zeros({num_nodes_in_memory}, bool_device_options);

auto hash_map_accessor = hash_map.accessor<bool, 1>();
auto nodes_accessor = node_ids.accessor<int64_t, 1>();

// Mark phase: every concurrent write stores the same value (true), which is why the unsynchronized marking is tolerated here.
#pragma omp parallel default(none) shared(hash_map_accessor, hash_map, node_ids, nodes_accessor) num_threads(num_threads)
{

#pragma omp for
for (int64_t j = 0; j < node_ids.size(0); j++) {
if (!hash_map_accessor[nodes_accessor[j]]) {
hash_map_accessor[nodes_accessor[j]] = 1;
}
}
}

auto device_options = torch::TensorOptions().dtype(torch::kInt64).device(node_ids.device());
std::vector<torch::Tensor> sub_deltas = std::vector<torch::Tensor>(num_threads);
int64_t upper_bound = (int64_t)(node_ids.size(0)) / num_threads + 1;

std::vector<int> sub_counts = std::vector<int>(num_threads, 0);
std::vector<int> sub_offsets = std::vector<int>(num_threads, 0);

// Compaction phase: each thread scans its chunk of the bitmap and writes the set
// indices (the unique IDs) into its own buffer, growing the buffer whenever it fills.
#pragma omp parallel num_threads(num_threads)
{
#ifdef MARIUS_OMP
int tid = omp_get_thread_num();
#else
int tid = 0;
#endif

// if (tid == 30)
// std::cout<<"omp: "<<tid<<"\n";

sub_deltas[tid] = torch::empty({upper_bound}, device_options);
auto delta_ids_accessor = sub_deltas[tid].accessor<int64_t, 1>();

int64_t start = chunk_size * tid;
int64_t end = start + chunk_size;

if (end > num_nodes_in_memory) {
end = num_nodes_in_memory;
}

int private_count = 0;
int grow_count = 0;

#pragma unroll
for (int64_t j = start; j < end; j++) {
if (hash_map_accessor[j]) {
delta_ids_accessor[private_count++] = j;
// hash_map_accessor[j] = 0;
grow_count++;

if (grow_count == upper_bound) {
sub_deltas[tid] = torch::cat({sub_deltas[tid], torch::empty({upper_bound}, device_options)}, 0);
delta_ids_accessor = sub_deltas[tid].accessor<int64_t, 1>();
grow_count = 0;
}
}
}
sub_counts[tid] = private_count;
}

int count = 0;
for (auto c : sub_counts) {
count += c;
}

// Exclusive prefix sum over per-thread counts: each thread's offset into the final output.
for (int k = 0; k < num_threads - 1; k++) {
sub_offsets[k + 1] = sub_offsets[k] + sub_counts[k];
}

torch::Tensor delta_ids = torch::empty({count}, device_options);

// Copy each thread's compacted IDs into its slice of the output. Thread k covered
// IDs in [k * chunk_size, (k + 1) * chunk_size), so the result is sorted ascending.
#pragma omp parallel for num_threads(num_threads)
for (int k = 0; k < num_threads; k++) {
if (sub_deltas[k].size(0) > 0)
delta_ids.narrow(0, sub_offsets[k], sub_counts[k]) = sub_deltas[k].narrow(0, 0, sub_counts[k]);
}

return delta_ids;
}
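For reference, here is a minimal single-threaded sketch of the bitmap technique above (bitmap_uniques is an illustrative name, not part of the commit; it assumes every ID lies in [0, num_nodes_in_memory)):

#include <torch/torch.h>

// Mark each seen ID in a boolean table, then scan the table in index order.
// Scanning in order is what makes the result come back sorted ascending, the
// property the torch::searchsorted remap in pipeline_gpu.cpp relies on.
torch::Tensor bitmap_uniques(torch::Tensor node_ids, int64_t num_nodes_in_memory) {
    torch::Tensor seen = torch::zeros({num_nodes_in_memory}, torch::kBool);
    auto seen_a = seen.accessor<bool, 1>();
    auto ids_a = node_ids.accessor<int64_t, 1>();
    for (int64_t j = 0; j < node_ids.size(0); j++) {
        seen_a[ids_a[j]] = true;
    }

    // At most node_ids.size(0) entries can be unique, so this never overflows.
    torch::Tensor out = torch::empty({node_ids.size(0)}, torch::kInt64);
    auto out_a = out.accessor<int64_t, 1>();
    int64_t count = 0;
    for (int64_t j = 0; j < num_nodes_in_memory; j++) {
        if (seen_a[j]) out_a[count++] = j;
    }
    return out.narrow(0, 0, count);
}

On in-range CPU inputs this should agree with std::get<0>(torch::_unique2(node_ids, /*sorted=*/true, /*return_inverse=*/false, /*return_counts=*/false)), the call this commit replaces.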
2 changes: 2 additions & 0 deletions src/cpp/src/pipeline/pipeline_gpu.cpp
@@ -239,6 +239,8 @@ void ComputeWorkerGPU::run() {
CudaStreamGuard stream_guard(*(pipeline_->dataloader_->compute_streams_[i]));
auto device_options = torch::TensorOptions().dtype(torch::kFloat16).device(batch->sub_batches_[i]->node_features_.device());

// Remap this sub-batch's global unique IDs to their positions within root_node_indices_;
// valid because computeUniques returns IDs in ascending order, so the root list is sorted.
batch->sub_batches_[i]->unique_node_indices_ = torch::searchsorted(batch->sub_batches_[i]->root_node_indices_, batch->sub_batches_[i]->unique_node_indices_);

batch->sub_batches_[i]->node_features_ = torch::zeros({batch->sub_batches_[i]->unique_node_indices_.size(0), feat_dim}, device_options);
torch::index_select_out(batch->sub_batches_[i]->node_features_, unique_features_per_gpu[i], 0, batch->sub_batches_[i]->unique_node_indices_);
// std::cout<<batch->sub_batches_[i]->node_features_.sizes()<<"\n";
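As a toy illustration of the remap above (values are made up): torch::searchsorted returns, for each global ID, its position within the sorted root list.

torch::Tensor root = torch::tensor({2, 5, 9, 14}, torch::kInt64); // sorted unique IDs, as produced by computeUniques
torch::Tensor ids = torch::tensor({9, 2, 14}, torch::kInt64);     // one sub-batch's global unique IDs
torch::Tensor pos = torch::searchsorted(root, ids);               // -> [2, 0, 3]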
