Skip to content
This repository has been archived by the owner on Nov 25, 2024. It is now read-only.

Commit

Permalink
gather with sorted indices.
Browse files Browse the repository at this point in the history
  • Loading branch information
zhuofan1123 committed May 29, 2024
1 parent 7352f1c commit 2a183f4
Show file tree
Hide file tree
Showing 13 changed files with 401 additions and 31 deletions.
81 changes: 80 additions & 1 deletion cpp/src/wholememory_ops/functions/gather_func.cu
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2019-2023, NVIDIA CORPORATION.
* Copyright (c) 2019-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -24,6 +24,8 @@ wholememory_error_code_t gather_integer_int32_func(wholememory_gref_t embedding_
wholememory_matrix_description_t embedding_desc,
void* indices,
wholememory_array_description_t indices_desc,
bool gather_with_sorted_ids,
void* raw_indices,
void* output,
wholememory_matrix_description_t output_desc,
cudaStream_t stream,
Expand All @@ -32,6 +34,8 @@ wholememory_error_code_t gather_integer_int64_func(wholememory_gref_t embedding_
wholememory_matrix_description_t embedding_desc,
void* indices,
wholememory_array_description_t indices_desc,
bool gather_with_sorted_ids,
void* raw_indices,
void* output,
wholememory_matrix_description_t output_desc,
cudaStream_t stream,
Expand All @@ -40,6 +44,8 @@ wholememory_error_code_t gather_floating_int32_func(wholememory_gref_t embedding
wholememory_matrix_description_t embedding_desc,
void* indices,
wholememory_array_description_t indices_desc,
bool gather_with_sorted_ids,
void* raw_indices,
void* output,
wholememory_matrix_description_t output_desc,
cudaStream_t stream,
Expand All @@ -48,6 +54,8 @@ wholememory_error_code_t gather_floating_int64_func(wholememory_gref_t embedding
wholememory_matrix_description_t embedding_desc,
void* indices,
wholememory_array_description_t indices_desc,
bool gather_with_sorted_ids,
void* raw_indices,
void* output,
wholememory_matrix_description_t output_desc,
cudaStream_t stream,
Expand Down Expand Up @@ -76,6 +84,75 @@ wholememory_error_code_t gather_func(wholememory_gref_t embedding_gref,
wholememory_matrix_description_t,
void* indices,
wholememory_array_description_t,
bool,
void*,
void*,
wholememory_matrix_description_t,
cudaStream_t,
int) = nullptr;
if (embedding_is_float) {
if (indices_desc.dtype == WHOLEMEMORY_DT_INT) {
p_gather_func = gather_floating_int32_func;
} else {
p_gather_func = gather_floating_int64_func;
}
} else {
if (indices_desc.dtype == WHOLEMEMORY_DT_INT) {
p_gather_func = gather_integer_int32_func;
} else {
p_gather_func = gather_integer_int64_func;
}
}
return p_gather_func(embedding_gref,
embedding_desc,
indices,
indices_desc,
false,
nullptr,
output,
output_desc,
stream,
gather_sms);
} catch (const wholememory::cuda_error& rle) {
return WHOLEMEMORY_LOGIC_ERROR;
} catch (const wholememory::logic_error& le) {
return WHOLEMEMORY_LOGIC_ERROR;
} catch (...) {
return WHOLEMEMORY_LOGIC_ERROR;
}
return WHOLEMEMORY_SUCCESS;
}

wholememory_error_code_t gather_with_sorted_ids_func(
wholememory_gref_t embedding_gref,
wholememory_matrix_description_t embedding_desc,
void* indices,
wholememory_array_description_t indices_desc,
void* raw_indices,
wholememory_array_description_t raw_indices_desc,
void* output,
wholememory_matrix_description_t output_desc,
cudaStream_t stream,
int gather_sms)
{
try {
bool embedding_is_float = wholememory_dtype_is_floating_number(embedding_desc.dtype);
WHOLEMEMORY_CHECK(embedding_is_float ||
wholememory_dtype_is_integer_number(embedding_desc.dtype));
bool output_is_float = wholememory_dtype_is_floating_number(output_desc.dtype);
WHOLEMEMORY_CHECK(output_is_float || wholememory_dtype_is_integer_number(output_desc.dtype));
WHOLEMEMORY_EXPECTS(
embedding_is_float == output_is_float,
"embedding and output should be same number type, e.g. floating number or integer number.");
if (indices_desc.size == 0) { return WHOLEMEMORY_SUCCESS; }
WHOLEMEMORY_CHECK(indices_desc.size == raw_indices_desc.size);
WHOLEMEMORY_CHECK(indices_desc.dtype == raw_indices_desc.dtype);
wholememory_error_code_t (*p_gather_func)(wholememory_gref_t,
wholememory_matrix_description_t,
void* indices,
wholememory_array_description_t,
bool,
void*,
void*,
wholememory_matrix_description_t,
cudaStream_t,
Expand All @@ -97,6 +174,8 @@ wholememory_error_code_t gather_func(wholememory_gref_t embedding_gref,
embedding_desc,
indices,
indices_desc,
true,
raw_indices,
output,
output_desc,
stream,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2019-2023, NVIDIA CORPORATION.
* Copyright (c) 2019-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -27,13 +27,23 @@ void gather_floating_int32_temp_func(wholememory_gref_t embedding_gref,
wholememory_matrix_description_t embedding_desc,
void* indices,
int64_t indice_count,
bool gather_with_sorted_ids,
void* raw_indices,
void* output,
wholememory_matrix_description_t output_desc,
cudaStream_t stream,
int gather_sms)
{
gather_temp_func<EmbeddingT, int32_t, OutputT>(
embedding_gref, embedding_desc, indices, indice_count, output, output_desc, stream, gather_sms);
gather_temp_func<EmbeddingT, int32_t, OutputT>(embedding_gref,
embedding_desc,
indices,
indice_count,
gather_with_sorted_ids,
raw_indices,
output,
output_desc,
stream,
gather_sms);
}

REGISTER_DISPATCH_TWO_TYPES(GatherFuncFloatingInt32,
Expand All @@ -45,6 +55,8 @@ wholememory_error_code_t gather_floating_int32_func(wholememory_gref_t embedding
wholememory_matrix_description_t embedding_desc,
void* indices,
wholememory_array_description_t indices_desc,
bool gather_with_sorted_ids,
void* raw_indices,
void* output,
wholememory_matrix_description_t output_desc,
cudaStream_t stream,
Expand All @@ -63,6 +75,8 @@ wholememory_error_code_t gather_floating_int32_func(wholememory_gref_t embedding
static_cast<char*>(indices) +
indices_desc.storage_offset * wholememory_dtype_get_element_size(indices_desc.dtype),
indices_desc.size,
gather_with_sorted_ids,
raw_indices,
output,
output_desc,
stream,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2019-2023, NVIDIA CORPORATION.
* Copyright (c) 2019-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -27,13 +27,23 @@ void gather_floating_int64_temp_func(wholememory_gref_t embedding_gref,
wholememory_matrix_description_t embedding_desc,
void* indices,
int64_t indice_count,
bool gather_with_sorted_ids,
void* raw_indices,
void* output,
wholememory_matrix_description_t output_desc,
cudaStream_t stream,
int gather_sms)
{
gather_temp_func<EmbeddingT, int64_t, OutputT>(
embedding_gref, embedding_desc, indices, indice_count, output, output_desc, stream, gather_sms);
gather_temp_func<EmbeddingT, int64_t, OutputT>(embedding_gref,
embedding_desc,
indices,
indice_count,
gather_with_sorted_ids,
raw_indices,
output,
output_desc,
stream,
gather_sms);
}

REGISTER_DISPATCH_TWO_TYPES(GatherFuncFloatingInt64,
Expand All @@ -45,6 +55,8 @@ wholememory_error_code_t gather_floating_int64_func(wholememory_gref_t embedding
wholememory_matrix_description_t embedding_desc,
void* indices,
wholememory_array_description_t indices_desc,
bool gather_with_sorted_ids,
void* raw_indices,
void* output,
wholememory_matrix_description_t output_desc,
cudaStream_t stream,
Expand All @@ -63,6 +75,8 @@ wholememory_error_code_t gather_floating_int64_func(wholememory_gref_t embedding
static_cast<char*>(indices) +
indices_desc.storage_offset * wholememory_dtype_get_element_size(indices_desc.dtype),
indices_desc.size,
gather_with_sorted_ids,
raw_indices,
output,
output_desc,
stream,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2019-2023, NVIDIA CORPORATION.
* Copyright (c) 2019-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -27,13 +27,23 @@ void gather_integer_int32_temp_func(wholememory_gref_t embedding_gref,
wholememory_matrix_description_t embedding_desc,
void* indices,
int64_t indice_count,
bool gather_with_sorted_ids,
void* raw_indices,
void* output,
wholememory_matrix_description_t output_desc,
cudaStream_t stream,
int gather_sms)
{
gather_temp_func<EmbeddingT, int32_t, OutputT>(
embedding_gref, embedding_desc, indices, indice_count, output, output_desc, stream, gather_sms);
gather_temp_func<EmbeddingT, int32_t, OutputT>(embedding_gref,
embedding_desc,
indices,
indice_count,
gather_with_sorted_ids,
raw_indices,
output,
output_desc,
stream,
gather_sms);
}

REGISTER_DISPATCH_TWO_TYPES(GatherFuncIntegerInt32,
Expand All @@ -45,6 +55,8 @@ wholememory_error_code_t gather_integer_int32_func(wholememory_gref_t embedding_
wholememory_matrix_description_t embedding_desc,
void* indices,
wholememory_array_description_t indices_desc,
bool gather_with_sorted_ids,
void* raw_indices,
void* output,
wholememory_matrix_description_t output_desc,
cudaStream_t stream,
Expand All @@ -63,6 +75,8 @@ wholememory_error_code_t gather_integer_int32_func(wholememory_gref_t embedding_
static_cast<char*>(indices) +
indices_desc.storage_offset * wholememory_dtype_get_element_size(indices_desc.dtype),
indices_desc.size,
gather_with_sorted_ids,
raw_indices,
output,
output_desc,
stream,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2019-2023, NVIDIA CORPORATION.
* Copyright (c) 2019-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -27,13 +27,23 @@ void gather_integer_int64_temp_func(wholememory_gref_t embedding_gref,
wholememory_matrix_description_t embedding_desc,
void* indices,
int64_t indice_count,
bool gather_with_sorted_ids,
void* raw_indices,
void* output,
wholememory_matrix_description_t output_desc,
cudaStream_t stream,
int gather_sms)
{
gather_temp_func<EmbeddingT, int64_t, OutputT>(
embedding_gref, embedding_desc, indices, indice_count, output, output_desc, stream, gather_sms);
gather_temp_func<EmbeddingT, int64_t, OutputT>(embedding_gref,
embedding_desc,
indices,
indice_count,
gather_with_sorted_ids,
raw_indices,
output,
output_desc,
stream,
gather_sms);
}

REGISTER_DISPATCH_TWO_TYPES(GatherFuncIntegerInt64,
Expand All @@ -45,6 +55,8 @@ wholememory_error_code_t gather_integer_int64_func(wholememory_gref_t embedding_
wholememory_matrix_description_t embedding_desc,
void* indices,
wholememory_array_description_t indices_desc,
bool gather_with_sorted_ids,
void* raw_indices,
void* output,
wholememory_matrix_description_t output_desc,
cudaStream_t stream,
Expand All @@ -63,6 +75,8 @@ wholememory_error_code_t gather_integer_int64_func(wholememory_gref_t embedding_
static_cast<char*>(indices) +
indices_desc.storage_offset * wholememory_dtype_get_element_size(indices_desc.dtype),
indices_desc.size,
gather_with_sorted_ids,
raw_indices,
output,
output_desc,
stream,
Expand Down
Loading

0 comments on commit 2a183f4

Please sign in to comment.