Do not require merge to have identical value types for both inputs #2054

Closed
wants to merge 2 commits into from
4 changes: 2 additions & 2 deletions cub/cub/agent/agent_merge_sort.cuh
@@ -383,7 +383,7 @@ gmem_to_reg(T (&output)[ITEMS_PER_THREAD], It1 input1, It2 input2, int count1, i
for (int item = 0; item < ITEMS_PER_THREAD; ++item)
{
int idx = BLOCK_THREADS * item + threadIdx.x;
- output[item] = (idx < count1) ? input1[idx] : input2[idx - count1];
+ output[item] = (idx < count1) ? static_cast<T>(input1[idx]) : static_cast<T>(input2[idx - count1]);
}
}
else
@@ -394,7 +394,7 @@ gmem_to_reg(T (&output)[ITEMS_PER_THREAD], It1 input1, It2 input2, int count1, i
int idx = BLOCK_THREADS * item + threadIdx.x;
if (idx < count1 + count2)
{
- output[item] = (idx < count1) ? input1[idx] : input2[idx - count1];
+ output[item] = (idx < count1) ? static_cast<T>(input1[idx]) : static_cast<T>(input2[idx - count1]);
}
}
}
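
The effect of the two `static_cast<T>` additions can be reproduced on the host. The following is a minimal sketch, not the CUB implementation: `wrapped_key`, `load_concatenated`, and `load_concatenated_demo` are hypothetical names, and a plain loop stands in for the per-thread strided indexing. It assumes two inputs whose element types differ and where one converts to `T` only explicitly, so the uncast conditional expression should be rejected by the compiler, while casting each branch to `T` first gives both operands the same type.

```cpp
#include <cassert>

// Hypothetical key type that converts to long only explicitly.
struct wrapped_key
{
  long value;
  explicit operator long() const { return value; }
};

// Host-side stand-in for the bounds-checked branch of gmem_to_reg: a single
// loop replaces the per-thread strided indexing used on the device.
template <typename T, int Items, typename It1, typename It2>
void load_concatenated(T (&output)[Items], It1 input1, It2 input2, int count1, int count2)
{
  for (int idx = 0; idx < Items; ++idx)
  {
    if (idx < count1 + count2)
    {
      // Both branches are converted to T before the conditional operator
      // has to find a common type for them.
      output[idx] = (idx < count1) ? static_cast<T>(input1[idx]) : static_cast<T>(input2[idx - count1]);
    }
  }
}

// Small self-check: two inputs with different element types, one output type.
inline void load_concatenated_demo()
{
  const int first[]          = {1, 2};
  const wrapped_key second[] = {{30}, {40}};
  long output[5]             = {-1, -1, -1, -1, -1};

  load_concatenated(output, first, second, 2, 2);

  assert(output[0] == 1 && output[1] == 2);
  assert(output[2] == 30 && output[3] == 40);
  assert(output[4] == -1); // index past count1 + count2 is left untouched
}
```

Calling `load_concatenated_demo()` from a test or `main` runs the checks. Whether the inputs seen in practice fail for this exact reason is an assumption; the sketch only illustrates why casting each branch separately removes the need for a common type.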
9 changes: 2 additions & 7 deletions cub/cub/device/dispatch/dispatch_merge.cuh
@@ -218,13 +218,8 @@ template <typename KeyIt1,
typename PolicyHub = device_merge_policy_hub<value_t<KeyIt1>, value_t<ValueIt1>>>
struct dispatch_t
{
- using key_t = cub::detail::value_t<KeyIt1>;
- using value_t = cub::detail::value_t<ValueIt1>;

- // Cannot check output iterators, since they could be discard iterators, which do not have the right value_type
- static_assert(::cuda::std::is_same<cub::detail::value_t<KeyIt2>, key_t>::value, "");
- static_assert(::cuda::std::is_same<cub::detail::value_t<ValueIt2>, value_t>::value, "");
Collaborator

suggestion: @miscco we require this in our docs:

InputIterator1 and InputIterator2 have the same value_type

If we don't want to enforce that, we should relax requirements in the docs (both Thrust and CUB) and add tests. I think it's fine to merge this PR without doing that, but we should at least file an issue.

Collaborator Author

People are already doing this, as we have seen with cuDF.

I believe we should think about something in the direction of is_assignable<value_t<OutIt>&, value_t<InIt1>> with the obvious caveat that it needs to work for discard iterators.

Contributor

I would argue it should be is_assignable<reference_t<OutIt>, value_t<InIt1>>. That would handle proxy references.
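
The difference between the two proposed checks can be sketched with standard-library traits. This is only an illustration under stated assumptions: `std::is_assignable` stands in for the `::cuda::std` trait, `value_t` and `reference_t` are redefined locally, and `discard_proxy`, `discard_iterator`, and `can_write_v` are hypothetical names rather than Thrust or CUB types. Checking against the output iterator's reference type keeps working when that reference is a proxy and when `value_type` is `void`, in which case `value_t<OutIt>&` could not even be formed.

```cpp
#include <cstddef>
#include <iterator>
#include <type_traits>
#include <utility>
#include <vector>

// Local stand-ins for the value_t / reference_t aliases named in the discussion.
template <typename It>
using value_t = typename std::iterator_traits<It>::value_type;
template <typename It>
using reference_t = decltype(*std::declval<It&>());

// Hypothetical discard-style output iterator: dereferencing yields a proxy
// that accepts and ignores any assignment, and value_type is void.
struct discard_proxy
{
  template <typename U>
  const discard_proxy& operator=(U&&) const { return *this; }
};

struct discard_iterator
{
  using iterator_category = std::output_iterator_tag;
  using value_type        = void;
  using difference_type   = std::ptrdiff_t;
  using pointer           = void;
  using reference         = discard_proxy;

  discard_proxy operator*() const { return {}; }
  discard_iterator& operator++() { return *this; }
};

// "Can a value read from InIt be stored through OutIt?"
template <typename OutIt, typename InIt>
constexpr bool can_write_v = std::is_assignable<reference_t<OutIt>, value_t<InIt>>::value;

static_assert(can_write_v<float*, const int*>, "int values can be stored through float*");
static_assert(can_write_v<std::vector<bool>::iterator, const bool*>, "a proxy reference accepts bool");
static_assert(can_write_v<discard_iterator, const int*>, "the discard proxy accepts any value");
static_assert(!can_write_v<const int*, const int*>, "cannot write through a pointer to const");
```

The check is phrased on types only, so it never dereferences an iterator and costs nothing at run time.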

- static_assert(::cuda::std::__invokable<CompareOp, key_t, key_t>::value,
+ static_assert(::cuda::std::__invokable<CompareOp, cub::detail::value_t<KeyIt1>, cub::detail::value_t<KeyIt1>>::value,
Collaborator

@bernhardmgruber as discussed in #1817 (comment), this is a breaking change; the check should go into device code.

Collaborator

You seem to have added the static assert in the device code but haven't removed it from the dispatch.

Contributor

We discussed this static assertion about the return type:

::cuda::std::is_convertible<typename ::cuda::std::__invoke_of<CompareOp, key_t, key_t>::type, bool>::value

which I have removed. The offending static assert here was not commented on in PR #1817.

I thought that querying whether a comparison operator is invocable with a given set of arguments would be fine in host code, and that only asking for the return type is problematic. My reasoning was that it is much more user-friendly to get a compile error here if the user passes a wrong operator than many layers deeper in the kernel code.

Collaborator

"Only asking for the return type is problematic."

Any form of introspection of the signature is problematic, including:

Introspecting the parameter type of operator() is only supported in device code.

Let's move the check to device code.
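
For reference, a host-only sketch of what such an invocability check evaluates. It uses `std::is_invocable` (C++17) as a stand-in for `::cuda::std::__invokable`, and `can_compare_keys_v` and `string_length_less` are hypothetical names. It does not model the concern raised in this thread, namely that introspecting a callable's signature in host code is a problem for some device callables; it only shows the result the trait produces for ordinary function objects.

```cpp
#include <functional>
#include <string>
#include <type_traits>

// Hypothetical trait mirroring the check in the diff: can CompareOp be
// called with two keys of type Key?
template <typename CompareOp, typename Key>
constexpr bool can_compare_keys_v = std::is_invocable<CompareOp, Key, Key>::value;

// A comparator that only accepts strings.
struct string_length_less
{
  bool operator()(const std::string& a, const std::string& b) const { return a.size() < b.size(); }
};

static_assert(can_compare_keys_v<std::less<int>, int>, "std::less<int> compares two ints");
static_assert(can_compare_keys_v<string_length_less, std::string>, "comparator matches the key type");
static_assert(!can_compare_keys_v<string_length_less, int>, "an int does not convert to std::string");
```

The trait only asks whether the call expression is well-formed; it says nothing about the return type, which is the separate check that was already removed.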

Collaborator

@gevtushenko, Jul 24, 2024

critical: see #1817 (comment)

Suggested change
static_assert(::cuda::std::__invokable<CompareOp, cub::detail::value_t<KeyIt1>, cub::detail::value_t<KeyIt1>>::value,

Contributor

Ahh, you just marked the second static_assert on the review, but your textual comment applied to both. I understand now.

"Comparison operator cannot compare two keys");

void* d_temp_storage;
@@ -351,7 +346,7 @@ struct dispatch_t
{
return error;
}
- dispatch_t dispatch{std::forward<Args>(args)...};
+ dispatch_t dispatch{::cuda::std::forward<Args>(args)...};
error = CubDebug(PolicyHub::max_policy::Invoke(ptx_version, dispatch));
if (cudaSuccess != error)
{