From 17fa3bbc3d08e5f5837e602d23392d8dfaa7f142 Mon Sep 17 00:00:00 2001 From: Bernhard Manfred Gruber Date: Fri, 26 Jul 2024 16:29:14 +0200 Subject: [PATCH] Allow (somewhat) different input value types for merge (#2075) * Add a cuDF inspired test for merge_by_key * Allow CUB MergePath to support iterators with different value types * Allow different input value types for merge, as long as they are convertible to the value type of the first iterator. This weakens the publicly documented guarantees of equal value types to restore the old behavior of the thrust implementation replaced in #1817. --- cub/cub/agent/agent_merge.cuh | 1 + cub/cub/agent/agent_merge_sort.cuh | 9 +-- cub/cub/block/block_merge_sort.cuh | 8 +-- cub/cub/device/dispatch/dispatch_merge.cuh | 23 ++------ thrust/testing/merge_by_key.cu | 66 ++++++++++++++++++++++ 5 files changed, 81 insertions(+), 26 deletions(-) diff --git a/cub/cub/agent/agent_merge.cuh b/cub/cub/agent/agent_merge.cuh index f6403c4e93f..adf75535172 100644 --- a/cub/cub/agent/agent_merge.cuh +++ b/cub/cub/agent/agent_merge.cuh @@ -59,6 +59,7 @@ struct agent_t { using policy = Policy; + // key and value type are taken from the first input sequence (consistent with old Thrust behavior) using key_type = typename ::cuda::std::iterator_traits::value_type; using item_type = typename ::cuda::std::iterator_traits::value_type; diff --git a/cub/cub/agent/agent_merge_sort.cuh b/cub/cub/agent/agent_merge_sort.cuh index 380a6ca6f08..123abb2b986 100644 --- a/cub/cub/agent/agent_merge_sort.cuh +++ b/cub/cub/agent/agent_merge_sort.cuh @@ -382,8 +382,9 @@ gmem_to_reg(T (&output)[ITEMS_PER_THREAD], It1 input1, It2 input2, int count1, i #pragma unroll for (int item = 0; item < ITEMS_PER_THREAD; ++item) { - int idx = BLOCK_THREADS * item + threadIdx.x; - output[item] = (idx < count1) ? input1[idx] : input2[idx - count1]; + const int idx = BLOCK_THREADS * item + threadIdx.x; + // It1 and It2 could have different value types. Convert after load. + output[item] = (idx < count1) ? static_cast(input1[idx]) : static_cast(input2[idx - count1]); } } else @@ -391,10 +392,10 @@ gmem_to_reg(T (&output)[ITEMS_PER_THREAD], It1 input1, It2 input2, int count1, i #pragma unroll for (int item = 0; item < ITEMS_PER_THREAD; ++item) { - int idx = BLOCK_THREADS * item + threadIdx.x; + const int idx = BLOCK_THREADS * item + threadIdx.x; if (idx < count1 + count2) { - output[item] = (idx < count1) ? input1[idx] : input2[idx - count1]; + output[item] = (idx < count1) ? static_cast(input1[idx]) : static_cast(input2[idx - count1]); } } } diff --git a/cub/cub/block/block_merge_sort.cuh b/cub/cub/block/block_merge_sort.cuh index 1d3637a095a..5ca9500550f 100644 --- a/cub/cub/block/block_merge_sort.cuh +++ b/cub/cub/block/block_merge_sort.cuh @@ -57,17 +57,15 @@ template ::value_type; - static_assert(::cuda::std::is_same::value_type>::value, ""); - OffsetT keys1_begin = diag < keys2_count ? 0 : diag - keys2_count; OffsetT keys1_end = (cub::min)(diag, keys1_count); while (keys1_begin < keys1_end) { const OffsetT mid = cub::MidPoint(keys1_begin, keys1_end); - const key_t key1 = keys1[mid]; - const key_t key2 = keys2[diag - 1 - mid]; + // pull copies of the keys before calling binary_pred so proxy references are unwrapped + const detail::value_t key1 = keys1[mid]; + const detail::value_t key2 = keys2[diag - 1 - mid]; if (binary_pred(key2, key1)) { keys1_end = mid; diff --git a/cub/cub/device/dispatch/dispatch_merge.cuh b/cub/cub/device/dispatch/dispatch_merge.cuh index 7a82eb61098..2c16d851448 100644 --- a/cub/cub/device/dispatch/dispatch_merge.cuh +++ b/cub/cub/device/dispatch/dispatch_merge.cuh @@ -64,11 +64,6 @@ CUB_DETAIL_KERNEL_ATTRIBUTES void device_partition_merge_path_kernel( Offset* merge_partitions, CompareOp compare_op) { - static_assert( - ::cuda::std::is_convertible, value_t>::type, - bool>::value, - "Comparison operator must be convertible to bool"); - // items_per_tile must be the same of the merge kernel later, so we have to consider whether a fallback agent will be // selected for the merge agent that changes the tile size constexpr int items_per_tile = @@ -121,9 +116,12 @@ __launch_bounds__( Offset* merge_partitions, vsmem_t global_temp_storage) { + // the merge agent loads keys into a local array of KeyIt1::value_type, on which the comparisons are performed + using key_t = value_t; + static_assert(::cuda::std::__invokable::value, + "Comparison operator cannot compare two keys"); static_assert( - ::cuda::std::is_convertible, value_t>::type, - bool>::value, + ::cuda::std::is_convertible::type, bool>::value, "Comparison operator must be convertible to bool"); using MergeAgent = typename choose_merge_agent< @@ -218,15 +216,6 @@ template , value_t>> struct dispatch_t { - using key_t = cub::detail::value_t; - using value_t = cub::detail::value_t; - - // Cannot check output iterators, since they could be discard iterators, which do not have the right value_type - static_assert(::cuda::std::is_same, key_t>::value, ""); - static_assert(::cuda::std::is_same, value_t>::value, ""); - static_assert(::cuda::std::__invokable::value, - "Comparison operator cannot compare two keys"); - void* d_temp_storage; std::size_t& temp_storage_bytes; KeyIt1 d_keys1; @@ -351,7 +340,7 @@ struct dispatch_t { return error; } - dispatch_t dispatch{std::forward(args)...}; + dispatch_t dispatch{::cuda::std::forward(args)...}; error = CubDebug(PolicyHub::max_policy::Invoke(ptx_version, dispatch)); if (cudaSuccess != error) { diff --git a/thrust/testing/merge_by_key.cu b/thrust/testing/merge_by_key.cu index 411db0794f3..d7e34676956 100644 --- a/thrust/testing/merge_by_key.cu +++ b/thrust/testing/merge_by_key.cu @@ -1,6 +1,8 @@ #include +#include #include #include +#include #include #include #include @@ -253,3 +255,67 @@ void TestMergeByKeyDescending(size_t n) TestMergeByKey>(n); } DECLARE_VARIABLE_UNITTEST(TestMergeByKeyDescending); + +struct def_level_fn +{ + _CCCL_DEVICE std::uint32_t operator()(int i) const + { + return static_cast(i + 10); + } +}; + +struct offset_transform +{ + _CCCL_DEVICE int operator()(int i) const + { + return i + 1; + } +}; + +// Tests the use of thrust::merge_by_key similar to cuDF in +// https://github.com/rapidsai/cudf/blob/branch-24.08/cpp/src/lists/dremel.cu#L413 +void TestMergeByKeyFromCuDFDremel() +{ + // TODO(bgruber): I have no idea what this code is actually computing, but I tried to replicate the types/iterators + constexpr std::ptrdiff_t empties_size = 123; + constexpr int max_vals_size = 225; + constexpr int level = 4; + constexpr int curr_rep_values_size = 0; + + thrust::device_vector empties(empties_size, 42); + thrust::device_vector empties_idx(empties_size, 13); + + thrust::device_vector temp_rep_vals(max_vals_size); + thrust::device_vector temp_def_vals(max_vals_size); + thrust::device_vector rep_level(max_vals_size); + thrust::device_vector def_level(max_vals_size); + + auto offset_transformer = offset_transform{}; + auto transformed_empties = thrust::make_transform_iterator(empties.begin(), offset_transformer); + + auto input_parent_rep_it = thrust::make_constant_iterator(level); + auto input_parent_def_it = thrust::make_transform_iterator(empties_idx.begin(), def_level_fn{}); + auto input_parent_zip_it = thrust::make_zip_iterator(input_parent_rep_it, input_parent_def_it); + auto input_child_zip_it = thrust::make_zip_iterator(temp_rep_vals.begin(), temp_def_vals.begin()); + auto output_zip_it = thrust::make_zip_iterator(rep_level.begin(), def_level.begin()); + + thrust::merge_by_key( + transformed_empties, + transformed_empties + empties_size, + thrust::make_counting_iterator(0), + thrust::make_counting_iterator(curr_rep_values_size), + input_parent_zip_it, + input_child_zip_it, + thrust::make_discard_iterator(), + output_zip_it); + + thrust::device_vector reference_rep_level(max_vals_size); + thrust::fill(reference_rep_level.begin(), reference_rep_level.begin() + empties_size, level); + + thrust::device_vector reference_def_level(max_vals_size); + thrust::fill(reference_def_level.begin(), reference_def_level.begin() + empties_size, 13 + 10); + + ASSERT_EQUAL(reference_rep_level, rep_level); + ASSERT_EQUAL(reference_def_level, def_level); +} +DECLARE_UNITTEST(TestMergeByKeyFromCuDFDremel);