Skip to content

Commit

Permalink
remove sorting for distributed training in memory
Browse files (browse the repository at this point in the history)
  • Loading branch information
Roger Waleffe authored and Roger Waleffe committed Nov 21, 2023
1 parent 8bbd9e6 commit 5674f3e
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 31 deletions.
22 changes: 11 additions & 11 deletions src/cpp/src/data/dataloader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -599,17 +599,17 @@ void DataLoader::nodeSample(shared_ptr<Batch> batch, int worker_id) {

if (graph_storage_->useInMemorySubGraph()) {
// batch->root_node_indices_ = graph_storage_->current_subgraph_state_->global_to_local_index_map_.index_select(0, batch->root_node_indices_);
auto root_node_accessor = batch->root_node_indices_.accessor<int64_t, 1>();
auto buffer_offsets_accessor = graph_storage_->current_subgraph_state_->buffer_offsets_.accessor<int64_t, 1>();
int64_t partition_size = graph_storage_->getPartitionSize();

#pragma omp parallel for
for (int i = 0; i < batch->root_node_indices_.size(0); i++) {
int64_t global_id = root_node_accessor[i];
int64_t partition = global_id / partition_size;

root_node_accessor[i] = buffer_offsets_accessor[partition] + global_id - (partition * partition_size);
}
// auto root_node_accessor = batch->root_node_indices_.accessor<int64_t, 1>();
// auto buffer_offsets_accessor = graph_storage_->current_subgraph_state_->buffer_offsets_.accessor<int64_t, 1>();
// int64_t partition_size = graph_storage_->getPartitionSize();
//
// #pragma omp parallel for
// for (int i = 0; i < batch->root_node_indices_.size(0); i++) {
// int64_t global_id = root_node_accessor[i];
// int64_t partition = global_id / partition_size;
//
// root_node_accessor[i] = buffer_offsets_accessor[partition] + global_id - (partition * partition_size);
// }

}

Expand Down
42 changes: 22 additions & 20 deletions src/cpp/src/storage/graph_storage.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -484,6 +484,7 @@ void GraphModelStorage::initializeInMemorySubGraph(torch::Tensor buffer_state, i
auto in_mem_edge_bucket_starts_accessor = in_mem_edge_bucket_starts.accessor<int64_t, 1>();

current_subgraph_state_->all_in_memory_edges_ = torch::empty({total_size, storage_ptrs_.edges->dim1_size_}, torch::kInt64);
torch::Tensor tmp_dst_sort = torch::empty({total_size, storage_ptrs_.edges->dim1_size_}, torch::kInt64);

#pragma omp parallel for
for (int i = 0; i < num_edge_buckets_in_mem; i++) {
Expand All @@ -493,6 +494,7 @@ void GraphModelStorage::initializeInMemorySubGraph(torch::Tensor buffer_state, i

current_subgraph_state_->all_in_memory_edges_.narrow(0, local_offset, edge_bucket_size) =
storage_ptrs_.edges->range(edge_bucket_start, edge_bucket_size);
tmp_dst_sort.narrow(0, local_offset, edge_bucket_size) = storage_ptrs_.train_edges_dst_sort->range(edge_bucket_start, edge_bucket_size);
}

int64_t partition_size = -1;
Expand All @@ -512,7 +514,7 @@ void GraphModelStorage::initializeInMemorySubGraph(torch::Tensor buffer_state, i

// torch::Tensor mapped_edges;
// torch::Tensor mapped_edges_dst_sort;
torch::Tensor mapped_edges = torch::empty({total_size, storage_ptrs_.edges->dim1_size_}, torch::kInt64);
torch::Tensor mapped_edges = current_subgraph_state_->all_in_memory_edges_; // torch::empty({total_size, storage_ptrs_.edges->dim1_size_}, torch::kInt64);
torch::Tensor mapped_edges_dst_sort;
// if (storage_ptrs_.edges->dim1_size_ == 3) {
// mapped_edges =
Expand All @@ -531,30 +533,30 @@ void GraphModelStorage::initializeInMemorySubGraph(torch::Tensor buffer_state, i
// std::runtime_error("Unexpected number of edge columns");
// }

int dst_index = storage_ptrs_.edges->dim1_size_ - 1;
auto all_in_memory_edges_accessor = current_subgraph_state_->all_in_memory_edges_.accessor<int64_t, 2>();
auto mapped_edges_accessor = mapped_edges.accessor<int64_t, 2>();
auto buffer_offsets_accessor = buffer_offsets.accessor<int64_t, 1>();

#pragma omp parallel for
for (int i = 0; i < total_size; i++) {
int64_t src_global_id = all_in_memory_edges_accessor[i][0];
int64_t dst_global_id = all_in_memory_edges_accessor[i][dst_index];

int64_t src_partition = src_global_id / partition_size;
int64_t dst_partition = dst_global_id / partition_size;

mapped_edges_accessor[i][0] = buffer_offsets_accessor[src_partition] + src_global_id - (src_partition * partition_size);
if (dst_index == 2) mapped_edges_accessor[i][1] = all_in_memory_edges_accessor[i][1];
mapped_edges_accessor[i][dst_index] = buffer_offsets_accessor[dst_partition] + dst_global_id - (dst_partition * partition_size);
}
// int dst_index = storage_ptrs_.edges->dim1_size_ - 1;
// auto all_in_memory_edges_accessor = current_subgraph_state_->all_in_memory_edges_.accessor<int64_t, 2>();
// auto mapped_edges_accessor = mapped_edges.accessor<int64_t, 2>();
// auto buffer_offsets_accessor = buffer_offsets.accessor<int64_t, 1>();
//
// #pragma omp parallel for
// for (int i = 0; i < total_size; i++) {
// int64_t src_global_id = all_in_memory_edges_accessor[i][0];
// int64_t dst_global_id = all_in_memory_edges_accessor[i][dst_index];
//
// int64_t src_partition = src_global_id / partition_size;
// int64_t dst_partition = dst_global_id / partition_size;
//
// mapped_edges_accessor[i][0] = buffer_offsets_accessor[src_partition] + src_global_id - (src_partition * partition_size);
// if (dst_index == 2) mapped_edges_accessor[i][1] = all_in_memory_edges_accessor[i][1];
// mapped_edges_accessor[i][dst_index] = buffer_offsets_accessor[dst_partition] + dst_global_id - (dst_partition * partition_size);
// }

// current_subgraph_state_->all_in_memory_edges_ = torch::Tensor();

current_subgraph_state_->all_in_memory_mapped_edges_ = mapped_edges;

mapped_edges = merge_sorted_edge_buckets(mapped_edges, in_mem_edge_bucket_starts, buffer_size, true);
mapped_edges_dst_sort = merge_sorted_edge_buckets(mapped_edges, in_mem_edge_bucket_starts, buffer_size, false);
mapped_edges = mapped_edges;//merge_sorted_edge_buckets(mapped_edges, in_mem_edge_bucket_starts, buffer_size, true);
mapped_edges_dst_sort = tmp_dst_sort;//merge_sorted_edge_buckets(mapped_edges, in_mem_edge_bucket_starts, buffer_size, false);

mapped_edges = mapped_edges.to(torch::kInt64);
mapped_edges_dst_sort = mapped_edges_dst_sort.to(torch::kInt64);
Expand Down
1 change: 1 addition & 0 deletions src/cpp/src/storage/io.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ std::tuple<shared_ptr<Storage>, shared_ptr<Storage>, shared_ptr<Storage>, shared
case StorageBackend::FLAT_FILE: {
if (num_train != -1) {
train_edge_storage = std::make_shared<FlatFile>(train_filename, num_train, num_columns, dtype);
train_edge_storage_dst_sort = std::make_shared<InMemory>(train_dst_sort_filename, num_train, num_columns, dtype, torch::kCPU); // ARMADA tmp
}
if (num_valid != -1) {
valid_edge_storage = std::make_shared<FlatFile>(valid_filename, num_valid, num_columns, dtype);
Expand Down

0 comments on commit 5674f3e

Please sign in to comment.