From a5abd3e504b8f1b66b50fbe62191e5b692caea96 Mon Sep 17 00:00:00 2001 From: Michael Schellenberger Costa Date: Wed, 24 Jul 2024 09:44:21 +0200 Subject: [PATCH 1/2] Avoid issues with ternary operator and proxy references Our beloved `tuple_of_iterator_references` does not like ternary operators because that does not understand the underlying value types of the proxies. --- cub/cub/agent/agent_merge_sort.cuh | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cub/cub/agent/agent_merge_sort.cuh b/cub/cub/agent/agent_merge_sort.cuh index 380a6ca6f08..3dfecbacc06 100644 --- a/cub/cub/agent/agent_merge_sort.cuh +++ b/cub/cub/agent/agent_merge_sort.cuh @@ -383,7 +383,7 @@ gmem_to_reg(T (&output)[ITEMS_PER_THREAD], It1 input1, It2 input2, int count1, i for (int item = 0; item < ITEMS_PER_THREAD; ++item) { int idx = BLOCK_THREADS * item + threadIdx.x; - output[item] = (idx < count1) ? input1[idx] : input2[idx - count1]; + output[item] = (idx < count1) ? static_cast(input1[idx]) : static_cast(input2[idx - count1]); } } else @@ -394,7 +394,7 @@ gmem_to_reg(T (&output)[ITEMS_PER_THREAD], It1 input1, It2 input2, int count1, i int idx = BLOCK_THREADS * item + threadIdx.x; if (idx < count1 + count2) { - output[item] = (idx < count1) ? input1[idx] : input2[idx - count1]; + output[item] = (idx < count1) ? 
static_cast(input1[idx]) : static_cast(input2[idx - count1]); } } } From 5629abfbee67aae5bbcfa06063931619937612dd Mon Sep 17 00:00:00 2001 From: Michael Schellenberger Costa Date: Wed, 24 Jul 2024 09:50:33 +0200 Subject: [PATCH 2/2] Do not require merge to have identical value types for both inputs --- cub/cub/device/dispatch/dispatch_merge.cuh | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/cub/cub/device/dispatch/dispatch_merge.cuh b/cub/cub/device/dispatch/dispatch_merge.cuh index 7a82eb61098..7b9ec6b4dd9 100644 --- a/cub/cub/device/dispatch/dispatch_merge.cuh +++ b/cub/cub/device/dispatch/dispatch_merge.cuh @@ -218,13 +218,8 @@ template , value_t>> struct dispatch_t { - using key_t = cub::detail::value_t; - using value_t = cub::detail::value_t; - // Cannot check output iterators, since they could be discard iterators, which do not have the right value_type - static_assert(::cuda::std::is_same, key_t>::value, ""); - static_assert(::cuda::std::is_same, value_t>::value, ""); - static_assert(::cuda::std::__invokable::value, + static_assert(::cuda::std::__invokable, cub::detail::value_t>::value, "Comparison operator cannot compare two keys"); void* d_temp_storage; @@ -351,7 +346,7 @@ struct dispatch_t { return error; } - dispatch_t dispatch{std::forward(args)...}; + dispatch_t dispatch{::cuda::std::forward(args)...}; error = CubDebug(PolicyHub::max_policy::Invoke(ptx_version, dispatch)); if (cudaSuccess != error) {