This repository has been archived by the owner on Nov 25, 2024. It is now read-only.

Merge pull request #178 from rapidsai/branch-24.06
Forward-merge branch-24.06 into branch-24.08
GPUtester authored May 29, 2024
2 parents c5bb685 + ae3748a commit 00365ee
Showing 8 changed files with 566 additions and 13 deletions.
39 changes: 38 additions & 1 deletion cpp/src/wholegraph_ops/sample_comm.cuh
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2019-2023, NVIDIA CORPORATION.
* Copyright (c) 2019-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -57,4 +57,41 @@ __global__ void sample_all_kernel(wholememory_gref_t wm_csr_row_ptr,
}
}
}

__device__ __forceinline__ int log2_up_device(int x)
{
if (x <= 2) return x - 1;
return 32 - __clz(x - 1);
}
template <typename IdType>
struct ExpandWithOffsetFunc {
const IdType* indptr;
IdType* indptr_shift;
int length;
__host__ __device__ auto operator()(int64_t tIdx)
{
indptr_shift[tIdx] = indptr[tIdx % length] + tIdx / length;
}
};

template <typename WMIdType, typename DegreeType>
struct ReduceForDegrees {
WMIdType* rowoffsets;
DegreeType* in_degree_ptr;
int length;
__host__ __device__ auto operator()(int64_t tIdx)
{
in_degree_ptr[tIdx] = rowoffsets[tIdx + length] - rowoffsets[tIdx];
}
};

template <typename DegreeType>
struct MinInDegreeFanout {
int max_sample_count;
__host__ __device__ auto operator()(DegreeType degree)
{
return min(static_cast<int>(degree), max_sample_count);
}
};

} // namespace wholegraph_ops
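
The helpers above are new shared building blocks in sample_comm.cuh: log2_up_device(x) returns ceil(log2(x)) for x >= 1 (e.g. log2_up_device(5) == 3) using __clz on x - 1, and the three functors are index-driven steps for expanding row offsets and deriving fanout-clamped in-degrees. A minimal sketch of how they might compose, assuming a Thrust pipeline over counting iterators and a 2 * length shifted-offset buffer (the function name and buffer layout here are illustrative assumptions, not taken from the commit):

#include <cstdint>
#include <thrust/execution_policy.h>
#include <thrust/for_each.h>
#include <thrust/iterator/counting_iterator.h>
#include <thrust/transform.h>

// Hypothetical sketch, not part of this commit.
template <typename IdType, typename DegreeType>
void compute_clamped_degrees(const IdType* indptr,   // length entries
                             IdType* indptr_shift,   // 2 * length entries
                             DegreeType* in_degree,  // length entries
                             int length,
                             int max_sample_count,
                             cudaStream_t stream)
{
  auto policy = thrust::cuda::par.on(stream);
  // indptr_shift[t] = indptr[t % length] + t / length: the first half copies
  // indptr, the second half copies it shifted up by one.
  thrust::for_each(policy,
                   thrust::counting_iterator<int64_t>(0),
                   thrust::counting_iterator<int64_t>(int64_t{2} * length),
                   ExpandWithOffsetFunc<IdType>{indptr, indptr_shift, length});
  // in_degree[t] = indptr_shift[t + length] - indptr_shift[t].
  thrust::for_each(policy,
                   thrust::counting_iterator<int64_t>(0),
                   thrust::counting_iterator<int64_t>(length),
                   ReduceForDegrees<IdType, DegreeType>{indptr_shift, in_degree, length});
  // Clamp every degree to the requested fanout.
  thrust::transform(policy,
                    in_degree, in_degree + length, in_degree,
                    MinInDegreeFanout<DegreeType>{max_sample_count});
}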
42 changes: 39 additions & 3 deletions cpp/src/wholegraph_ops/unweighted_sample_without_replacement.cpp
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2019-2023, NVIDIA CORPORATION.
* Copyright (c) 2019-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -41,7 +41,8 @@ wholememory_error_code_t wholegraph_csr_unweighted_sample_without_replacement(
}
WHOLEMEMORY_EXPECTS_NOTHROW(!csr_row_ptr_has_handle ||
csr_row_ptr_memory_type == WHOLEMEMORY_MT_CHUNKED ||
csr_row_ptr_memory_type == WHOLEMEMORY_MT_CONTINUOUS,
csr_row_ptr_memory_type == WHOLEMEMORY_MT_CONTINUOUS ||
csr_row_ptr_memory_type == WHOLEMEMORY_MT_DISTRIBUTED,
"Memory type not supported.");
bool const csr_col_ptr_has_handle = wholememory_tensor_has_handle(wm_csr_col_ptr_tensor);
wholememory_memory_type_t csr_col_ptr_memory_type = WHOLEMEMORY_MT_NONE;
@@ -51,7 +52,8 @@ wholememory_error_code_t wholegraph_csr_unweighted_sample_without_replacement(
}
WHOLEMEMORY_EXPECTS_NOTHROW(!csr_col_ptr_has_handle ||
csr_col_ptr_memory_type == WHOLEMEMORY_MT_CHUNKED ||
csr_col_ptr_memory_type == WHOLEMEMORY_MT_CONTINUOUS,
csr_col_ptr_memory_type == WHOLEMEMORY_MT_CONTINUOUS ||
csr_col_ptr_memory_type == WHOLEMEMORY_MT_DISTRIBUTED,
"Memory type not supported.");

auto csr_row_ptr_tensor_description =
@@ -108,6 +110,40 @@ wholememory_error_code_t wholegraph_csr_unweighted_sample_without_replacement(
void* center_nodes = wholememory_tensor_get_data_pointer(center_nodes_tensor);
void* output_sample_offset = wholememory_tensor_get_data_pointer(output_sample_offset_tensor);

if (csr_col_ptr_memory_type == WHOLEMEMORY_MT_DISTRIBUTED &&
csr_row_ptr_memory_type == WHOLEMEMORY_MT_DISTRIBUTED) {
wholememory_distributed_backend_t distributed_backend_row = wholememory_get_distributed_backend(
wholememory_tensor_get_memory_handle(wm_csr_row_ptr_tensor));
wholememory_distributed_backend_t distributed_backend_col = wholememory_get_distributed_backend(
wholememory_tensor_get_memory_handle(wm_csr_col_ptr_tensor));
if (distributed_backend_col == WHOLEMEMORY_DB_NCCL &&
distributed_backend_row == WHOLEMEMORY_DB_NCCL) {
wholememory_handle_t wm_csr_row_ptr_handle =
wholememory_tensor_get_memory_handle(wm_csr_row_ptr_tensor);
wholememory_handle_t wm_csr_col_ptr_handle =
wholememory_tensor_get_memory_handle(wm_csr_col_ptr_tensor);
return wholegraph_ops::wholegraph_csr_unweighted_sample_without_replacement_nccl(
wm_csr_row_ptr_handle,
wm_csr_col_ptr_handle,
csr_row_ptr_tensor_description,
csr_col_ptr_tensor_description,
center_nodes,
center_nodes_desc,
max_sample_count,
output_sample_offset,
output_sample_offset_desc,
output_dest_memory_context,
output_center_localid_memory_context,
output_edge_gid_memory_context,
random_seed,
p_env_fns,
static_cast<cudaStream_t>(stream));
} else {
WHOLEMEMORY_ERROR("Only NCCL communication backend is supported for sampling.");
return WHOLEMEMORY_INVALID_INPUT;
}
}

wholememory_gref_t wm_csr_row_ptr_gref, wm_csr_col_ptr_gref;
WHOLEMEMORY_RETURN_ON_FAIL(
wholememory_tensor_get_global_reference(wm_csr_row_ptr_tensor, &wm_csr_row_ptr_gref));
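
The new branch above routes sampling through a dedicated NCCL path when both CSR tensors live in WHOLEMEMORY_MT_DISTRIBUTED memory and both memory handles report the WHOLEMEMORY_DB_NCCL backend; any other distributed backend falls through to the WHOLEMEMORY_INVALID_INPUT error. A hypothetical helper (not part of the commit; include paths assumed) that factors out that guard using only the calls the branch itself makes:

#include <wholememory/wholememory.h>
#include <wholememory/wholememory_tensor.h>

// Returns true when the NCCL sampling path applies to this tensor pair.
static bool use_nccl_sampling_path(wholememory_tensor_t row_ptr_tensor,
                                   wholememory_tensor_t col_ptr_tensor,
                                   wholememory_memory_type_t row_type,
                                   wholememory_memory_type_t col_type)
{
  // Both CSR arrays must be distributed...
  if (row_type != WHOLEMEMORY_MT_DISTRIBUTED || col_type != WHOLEMEMORY_MT_DISTRIBUTED) {
    return false;
  }
  // ...and both handles must use the NCCL distributed backend.
  return wholememory_get_distributed_backend(
           wholememory_tensor_get_memory_handle(row_ptr_tensor)) == WHOLEMEMORY_DB_NCCL &&
         wholememory_get_distributed_backend(
           wholememory_tensor_get_memory_handle(col_ptr_tensor)) == WHOLEMEMORY_DB_NCCL;
}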
@@ -123,12 +123,6 @@ __global__ void large_sample_kernel(
}
}

__device__ __forceinline__ int log2_up_device(int x)
{
if (x <= 2) return x - 1;
return 32 - __clz(x - 1);
}

template <typename IdType,
typename LocalIdType,
typename WMIdType,
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2019-2023, NVIDIA CORPORATION.
* Copyright (c) 2019-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -37,4 +37,21 @@ wholememory_error_code_t wholegraph_csr_unweighted_sample_without_replacement_ma
unsigned long long random_seed,
wholememory_env_func_t* p_env_fns,
cudaStream_t stream);

wholememory_error_code_t wholegraph_csr_unweighted_sample_without_replacement_nccl(
wholememory_handle_t csr_row_wholememory_handle,
wholememory_handle_t csr_col_wholememory_handle,
wholememory_tensor_description_t wm_csr_row_ptr_desc,
wholememory_tensor_description_t wm_csr_col_ptr_desc,
void* center_nodes,
wholememory_array_description_t center_nodes_desc,
int max_sample_count,
void* output_sample_offset,
wholememory_array_description_t output_sample_offset_desc,
void* output_dest_memory_context,
void* output_center_localid_memory_context,
void* output_edge_gid_memory_context,
unsigned long long random_seed,
wholememory_env_func_t* p_env_fns,
cudaStream_t stream);
} // namespace wholegraph_ops
@@ -0,0 +1,79 @@
/*
* Copyright (c) 2019-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <cuda_runtime_api.h>

#include <wholememory/env_func_ptrs.h>
#include <wholememory/wholememory.h>

#include "unweighted_sample_without_replacement_nccl_func.cuh"
#include "wholememory_ops/register.hpp"

namespace wholegraph_ops {

REGISTER_DISPATCH_TWO_TYPES(UnweightedSampleWithoutReplacementCSRNCCL,
wholegraph_csr_unweighted_sample_without_replacement_nccl_func,
SINT3264,
SINT3264)

wholememory_error_code_t wholegraph_csr_unweighted_sample_without_replacement_nccl(
wholememory_handle_t csr_row_wholememory_handle,
wholememory_handle_t csr_col_wholememory_handle,
wholememory_tensor_description_t wm_csr_row_ptr_desc,
wholememory_tensor_description_t wm_csr_col_ptr_desc,
void* center_nodes,
wholememory_array_description_t center_nodes_desc,
int max_sample_count,
void* output_sample_offset,
wholememory_array_description_t output_sample_offset_desc,
void* output_dest_memory_context,
void* output_center_localid_memory_context,
void* output_edge_gid_memory_context,
unsigned long long random_seed,
wholememory_env_func_t* p_env_fns,
cudaStream_t stream)
{
try {
DISPATCH_TWO_TYPES(center_nodes_desc.dtype,
wm_csr_col_ptr_desc.dtype,
UnweightedSampleWithoutReplacementCSRNCCL,
csr_row_wholememory_handle,
csr_col_wholememory_handle,
wm_csr_row_ptr_desc,
wm_csr_col_ptr_desc,
center_nodes,
center_nodes_desc,
max_sample_count,
output_sample_offset,
output_sample_offset_desc,
output_dest_memory_context,
output_center_localid_memory_context,
output_edge_gid_memory_context,
random_seed,
p_env_fns,
stream);

} catch (const wholememory::cuda_error& rle) {
// CUDA failures are reported as WHOLEMEMORY_LOGIC_ERROR rather than aborting;
// the hard-failure alternative below is left commented out.
// WHOLEMEMORY_FAIL_NOTHROW("%s", rle.what());
return WHOLEMEMORY_LOGIC_ERROR;
} catch (const wholememory::logic_error& le) {
return WHOLEMEMORY_LOGIC_ERROR;
} catch (...) {
return WHOLEMEMORY_LOGIC_ERROR;
}
return WHOLEMEMORY_SUCCESS;
}

} // namespace wholegraph_ops
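
REGISTER_DISPATCH_TWO_TYPES above instantiates wholegraph_csr_unweighted_sample_without_replacement_nccl_func for every SINT3264 x SINT3264 combination (signed 32/64-bit center-node and column dtypes), and DISPATCH_TWO_TYPES then picks the matching instantiation from the two runtime tensor dtypes. A standalone sketch of that dispatch shape (illustration only; the real macros live in wholememory_ops/register.hpp and differ in detail):

#include <cstdint>
#include <cstdio>
#include <stdexcept>

enum class dtype { sint32, sint64 };

// Stand-in for the templated sampling launcher being dispatched to.
template <typename NodeT, typename EdgeT>
void sample_func()
{
  std::printf("instantiated for %zu/%zu-byte ids\n", sizeof(NodeT), sizeof(EdgeT));
}

template <typename NodeT>
void dispatch_second(dtype d)
{
  switch (d) {
    case dtype::sint32: return sample_func<NodeT, int32_t>();
    case dtype::sint64: return sample_func<NodeT, int64_t>();
  }
  throw std::invalid_argument("unsupported dtype");
}

// Two nested switches cover the 2 x 2 instantiation grid.
void dispatch_two_types(dtype first, dtype second)
{
  switch (first) {
    case dtype::sint32: return dispatch_second<int32_t>(second);
    case dtype::sint64: return dispatch_second<int64_t>(second);
  }
  throw std::invalid_argument("unsupported dtype");
}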
