Skip to content

Commit

Permalink
Allow (somewhat) different input value types for merge (NVIDIA#2075)
Browse files Browse the repository at this point in the history
* Add a cuDF inspired test for merge_by_key
* Allow CUB MergePath to support iterators with different value types
* Allow different input value types for merge, as long as they are convertible to the value type of the first iterator. This weakens the publicly documented guarantees of equal value types to restore the old behavior of the thrust implementation replaced in NVIDIA#1817.
  • Loading branch information
bernhardmgruber authored and pciolkosz committed Aug 4, 2024
1 parent cdbc3e9 commit 72d5c40
Show file tree
Hide file tree
Showing 5 changed files with 81 additions and 26 deletions.
1 change: 1 addition & 0 deletions cub/cub/agent/agent_merge.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ struct agent_t
{
using policy = Policy;

// key and value type are taken from the first input sequence (consistent with old Thrust behavior)
using key_type = typename ::cuda::std::iterator_traits<KeysIt1>::value_type;
using item_type = typename ::cuda::std::iterator_traits<ItemsIt1>::value_type;

Expand Down
9 changes: 5 additions & 4 deletions cub/cub/agent/agent_merge_sort.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -382,19 +382,20 @@ gmem_to_reg(T (&output)[ITEMS_PER_THREAD], It1 input1, It2 input2, int count1, i
#pragma unroll
for (int item = 0; item < ITEMS_PER_THREAD; ++item)
{
int idx = BLOCK_THREADS * item + threadIdx.x;
output[item] = (idx < count1) ? input1[idx] : input2[idx - count1];
const int idx = BLOCK_THREADS * item + threadIdx.x;
// It1 and It2 could have different value types. Convert after load.
output[item] = (idx < count1) ? static_cast<T>(input1[idx]) : static_cast<T>(input2[idx - count1]);
}
}
else
{
#pragma unroll
for (int item = 0; item < ITEMS_PER_THREAD; ++item)
{
int idx = BLOCK_THREADS * item + threadIdx.x;
const int idx = BLOCK_THREADS * item + threadIdx.x;
if (idx < count1 + count2)
{
output[item] = (idx < count1) ? input1[idx] : input2[idx - count1];
output[item] = (idx < count1) ? static_cast<T>(input1[idx]) : static_cast<T>(input2[idx - count1]);
}
}
}
Expand Down
8 changes: 3 additions & 5 deletions cub/cub/block/block_merge_sort.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -57,17 +57,15 @@ template <typename KeyIt1, typename KeyIt2, typename OffsetT, typename BinaryPre
_CCCL_DEVICE _CCCL_FORCEINLINE OffsetT
MergePath(KeyIt1 keys1, KeyIt2 keys2, OffsetT keys1_count, OffsetT keys2_count, OffsetT diag, BinaryPred binary_pred)
{
using key_t = typename ::cuda::std::iterator_traits<KeyIt1>::value_type;
static_assert(::cuda::std::is_same<key_t, typename ::cuda::std::iterator_traits<KeyIt2>::value_type>::value, "");

OffsetT keys1_begin = diag < keys2_count ? 0 : diag - keys2_count;
OffsetT keys1_end = (cub::min)(diag, keys1_count);

while (keys1_begin < keys1_end)
{
const OffsetT mid = cub::MidPoint<OffsetT>(keys1_begin, keys1_end);
const key_t key1 = keys1[mid];
const key_t key2 = keys2[diag - 1 - mid];
// pull copies of the keys before calling binary_pred so proxy references are unwrapped
const detail::value_t<KeyIt1> key1 = keys1[mid];
const detail::value_t<KeyIt2> key2 = keys2[diag - 1 - mid];
if (binary_pred(key2, key1))
{
keys1_end = mid;
Expand Down
23 changes: 6 additions & 17 deletions cub/cub/device/dispatch/dispatch_merge.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -64,11 +64,6 @@ CUB_DETAIL_KERNEL_ATTRIBUTES void device_partition_merge_path_kernel(
Offset* merge_partitions,
CompareOp compare_op)
{
static_assert(
::cuda::std::is_convertible<typename ::cuda::std::__invoke_of<CompareOp, value_t<KeyIt1>, value_t<KeyIt1>>::type,
bool>::value,
"Comparison operator must be convertible to bool");

// items_per_tile must be the same as that of the merge kernel later, so we have to consider whether a fallback agent
// that changes the tile size will be selected for the merge agent
constexpr int items_per_tile =
Expand Down Expand Up @@ -121,9 +116,12 @@ __launch_bounds__(
Offset* merge_partitions,
vsmem_t global_temp_storage)
{
// the merge agent loads keys into a local array of KeyIt1::value_type, on which the comparisons are performed
using key_t = value_t<KeyIt1>;
static_assert(::cuda::std::__invokable<CompareOp, key_t, key_t>::value,
"Comparison operator cannot compare two keys");
static_assert(
::cuda::std::is_convertible<typename ::cuda::std::__invoke_of<CompareOp, value_t<KeyIt1>, value_t<KeyIt1>>::type,
bool>::value,
::cuda::std::is_convertible<typename ::cuda::std::__invoke_of<CompareOp, key_t, key_t>::type, bool>::value,
"Comparison operator must be convertible to bool");

using MergeAgent = typename choose_merge_agent<
Expand Down Expand Up @@ -218,15 +216,6 @@ template <typename KeyIt1,
typename PolicyHub = device_merge_policy_hub<value_t<KeyIt1>, value_t<ValueIt1>>>
struct dispatch_t
{
using key_t = cub::detail::value_t<KeyIt1>;
using value_t = cub::detail::value_t<ValueIt1>;

// Cannot check output iterators, since they could be discard iterators, which do not have the right value_type
static_assert(::cuda::std::is_same<cub::detail::value_t<KeyIt2>, key_t>::value, "");
static_assert(::cuda::std::is_same<cub::detail::value_t<ValueIt2>, value_t>::value, "");
static_assert(::cuda::std::__invokable<CompareOp, key_t, key_t>::value,
"Comparison operator cannot compare two keys");

void* d_temp_storage;
std::size_t& temp_storage_bytes;
KeyIt1 d_keys1;
Expand Down Expand Up @@ -351,7 +340,7 @@ struct dispatch_t
{
return error;
}
dispatch_t dispatch{std::forward<Args>(args)...};
dispatch_t dispatch{::cuda::std::forward<Args>(args)...};
error = CubDebug(PolicyHub::max_policy::Invoke(ptx_version, dispatch));
if (cudaSuccess != error)
{
Expand Down
66 changes: 66 additions & 0 deletions thrust/testing/merge_by_key.cu
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
#include <thrust/functional.h>
#include <thrust/iterator/constant_iterator.h>
#include <thrust/iterator/discard_iterator.h>
#include <thrust/iterator/retag.h>
#include <thrust/iterator/transform_iterator.h>
#include <thrust/merge.h>
#include <thrust/sort.h>
#include <thrust/unique.h>
Expand Down Expand Up @@ -253,3 +255,67 @@ void TestMergeByKeyDescending(size_t n)
TestMergeByKey<T, thrust::greater<T>>(n);
}
DECLARE_VARIABLE_UNITTEST(TestMergeByKeyDescending);

// Device functor that returns its int argument plus 10, converted to std::uint32_t.
struct def_level_fn
{
  // @param i input value
  // @return i + 10 as an unsigned 32-bit integer
  _CCCL_DEVICE std::uint32_t operator()(int i) const
  {
    // Use the std-qualified name, matching the return type: <cstdint> only guarantees std::uint32_t,
    // the global-namespace alias is optional.
    return static_cast<std::uint32_t>(i + 10);
  }
};

// Device functor that yields the successor of its int argument.
struct offset_transform
{
  // @param i input value
  // @return i + 1
  _CCCL_DEVICE int operator()(int i) const
  {
    const int shifted = i + 1;
    return shifted;
  }
};

// Tests the use of thrust::merge_by_key similar to cuDF in
// https://github.com/rapidsai/cudf/blob/branch-24.08/cpp/src/lists/dremel.cu#L413
void TestMergeByKeyFromCuDFDremel()
{
// TODO(bgruber): I have no idea what this code is actually computing, but I tried to replicate the types/iterators
constexpr std::ptrdiff_t empties_size = 123;
constexpr int max_vals_size = 225;
constexpr int level = 4;
constexpr int curr_rep_values_size = 0;

thrust::device_vector<int> empties(empties_size, 42);
thrust::device_vector<int> empties_idx(empties_size, 13);

thrust::device_vector<std::uint8_t> temp_rep_vals(max_vals_size);
thrust::device_vector<std::uint8_t> temp_def_vals(max_vals_size);
thrust::device_vector<std::uint8_t> rep_level(max_vals_size);
thrust::device_vector<std::uint8_t> def_level(max_vals_size);

auto offset_transformer = offset_transform{};
auto transformed_empties = thrust::make_transform_iterator(empties.begin(), offset_transformer);

auto input_parent_rep_it = thrust::make_constant_iterator(level);
auto input_parent_def_it = thrust::make_transform_iterator(empties_idx.begin(), def_level_fn{});
auto input_parent_zip_it = thrust::make_zip_iterator(input_parent_rep_it, input_parent_def_it);
auto input_child_zip_it = thrust::make_zip_iterator(temp_rep_vals.begin(), temp_def_vals.begin());
auto output_zip_it = thrust::make_zip_iterator(rep_level.begin(), def_level.begin());

thrust::merge_by_key(
transformed_empties,
transformed_empties + empties_size,
thrust::make_counting_iterator(0),
thrust::make_counting_iterator(curr_rep_values_size),
input_parent_zip_it,
input_child_zip_it,
thrust::make_discard_iterator(),
output_zip_it);

thrust::device_vector<std::uint8_t> reference_rep_level(max_vals_size);
thrust::fill(reference_rep_level.begin(), reference_rep_level.begin() + empties_size, level);

thrust::device_vector<std::uint8_t> reference_def_level(max_vals_size);
thrust::fill(reference_def_level.begin(), reference_def_level.begin() + empties_size, 13 + 10);

ASSERT_EQUAL(reference_rep_level, rep_level);
ASSERT_EQUAL(reference_def_level, def_level);
}
DECLARE_UNITTEST(TestMergeByKeyFromCuDFDremel);

0 comments on commit 72d5c40

Please sign in to comment.