Skip to content

Commit

Permalink
Try harder to unwrap nested thrust::tuple_of_iterator_references (#…
Browse files Browse the repository at this point in the history
…1469)

* Ensure that we play nicely with std::iterators

We were defining our own set of iterator categories

That meant that an interator that is a `std::random_access_iterator` would not be a `cuda::std::random_access_iterator`

To ensure that we are playing nicely with iterators that use `std::` iterator categories just pull in the standard ones.

Fixes [BUG]: cuda::std::iterator_traits does not expose proper member types in old C++ dialects #1509

* Ensure that we provide our own `std::contiguous_iterator_tag` if needed

* Disable MSVC warning

* Try to appease MSVC2017

* Try harder to unwrap nested `thrust::tuple_of_iterator_references`

We tried to simply unpack the `tuple_of_iterator_references`, however, if it contained nested `tuple_of_iterator_references` then that would break down. Instead recursively apply the unwrapping when possible

* Make `tuple` constructible from `tuple_of_iterator_references`
  • Loading branch information
miscco authored Mar 7, 2024
1 parent 97f59d2 commit ca11fef
Show file tree
Hide file tree
Showing 3 changed files with 221 additions and 101 deletions.
13 changes: 13 additions & 0 deletions libcudacxx/include/cuda/std/detail/libcxx/include/tuple
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,10 @@ template <class... Types>

_LIBCUDACXX_BEGIN_NAMESPACE_STD

template<class>
struct __is_tuple_of_iterator_references : false_type
{};

// __tuple_leaf
struct __tuple_leaf_default_constructor_tag {};

Expand Down Expand Up @@ -851,6 +855,15 @@ public:
_Tp...>::template __tuple_like_constraints<_Tuple>,
__invalid_tuple_constraints>;

// Horrible hack to make tuple_of_iterator_references work
template <class _TupleOfIteratorReferences,
__enable_if_t<__is_tuple_of_iterator_references<_TupleOfIteratorReferences>::value, int> = 0,
__enable_if_t<(tuple_size<_TupleOfIteratorReferences>::value == sizeof...(_Tp)), int> = 0>
_LIBCUDACXX_INLINE_VISIBILITY _LIBCUDACXX_CONSTEXPR_AFTER_CXX11 tuple(_TupleOfIteratorReferences&& __t)
: tuple(_CUDA_VSTD::forward<_TupleOfIteratorReferences>(__t).template __to_tuple<_Tp...>(
__make_tuple_indices_t<sizeof...(_Tp)>()))
{}

template <
class _Tuple, class _Constraints = __tuple_like_constraints<_Tuple>,
__enable_if_t<!_PackExpandsToThisTuple<_Tuple>::value, int> = 0,
Expand Down
126 changes: 110 additions & 16 deletions thrust/testing/zip_function.cu
Original file line number Diff line number Diff line change
Expand Up @@ -2,29 +2,30 @@

#if _CCCL_STD_VER >= 2011 && !defined(THRUST_LEGACY_GCC)

#include <unittest/unittest.h>
#include <thrust/iterator/zip_iterator.h>
#include <thrust/transform.h>
#include <thrust/zip_function.h>
# include <thrust/device_vector.h>
# include <thrust/iterator/zip_iterator.h>
# include <thrust/remove.h>
# include <thrust/sort.h>
# include <thrust/transform.h>
# include <thrust/zip_function.h>

#include <iostream>
# include <iostream>

# include <unittest/unittest.h>

using namespace unittest;

struct SumThree
{
template <typename T1, typename T2, typename T3>
__host__ __device__
auto operator()(T1 x, T2 y, T3 z) const
THRUST_DECLTYPE_RETURNS(x + y + z)
__host__ __device__ auto operator()(T1 x, T2 y, T3 z) const THRUST_DECLTYPE_RETURNS(x + y + z)
}; // end SumThree

struct SumThreeTuple
{
template <typename Tuple>
__host__ __device__
auto operator()(Tuple x) const
THRUST_DECLTYPE_RETURNS(thrust::get<0>(x) + thrust::get<1>(x) + thrust::get<2>(x))
__host__ __device__ auto operator()(Tuple x) const
THRUST_DECLTYPE_RETURNS(thrust::get<0>(x) + thrust::get<1>(x) + thrust::get<2>(x))
}; // end SumThreeTuple

template <typename T>
Expand All @@ -42,22 +43,22 @@ struct TestZipFunctionTransform
device_vector<T> d_data1 = h_data1;
device_vector<T> d_data2 = h_data2;

host_vector<T> h_result_tuple(n);
host_vector<T> h_result_zip(n);
host_vector<T> h_result_tuple(n);
host_vector<T> h_result_zip(n);
device_vector<T> d_result_zip(n);

// Tuple base case
transform(make_zip_iterator(make_tuple(h_data0.begin(), h_data1.begin(), h_data2.begin())),
make_zip_iterator(make_tuple(h_data0.end(), h_data1.end(), h_data2.end())),
make_zip_iterator(make_tuple(h_data0.end(), h_data1.end(), h_data2.end())),
h_result_tuple.begin(),
SumThreeTuple{});
// Zip Function
transform(make_zip_iterator(make_tuple(h_data0.begin(), h_data1.begin(), h_data2.begin())),
make_zip_iterator(make_tuple(h_data0.end(), h_data1.end(), h_data2.end())),
make_zip_iterator(make_tuple(h_data0.end(), h_data1.end(), h_data2.end())),
h_result_zip.begin(),
make_zip_function(SumThree{}));
transform(make_zip_iterator(make_tuple(d_data0.begin(), d_data1.begin(), d_data2.begin())),
make_zip_iterator(make_tuple(d_data0.end(), d_data1.end(), d_data2.end())),
make_zip_iterator(make_tuple(d_data0.end(), d_data1.end(), d_data2.end())),
d_result_zip.begin(),
make_zip_function(SumThree{}));

Expand All @@ -67,4 +68,97 @@ struct TestZipFunctionTransform
};
VariableUnitTest<TestZipFunctionTransform, ThirtyTwoBitTypes> TestZipFunctionTransformInstance;

struct RemovePred
{
__host__ __device__ bool operator()(const thrust::tuple<uint32_t, uint32_t>& ele1, const float&)
{
return thrust::get<0>(ele1) == thrust::get<1>(ele1);
}
};
template <typename T>
struct TestZipFunctionMixed
{
void operator()()
{
thrust::device_vector<uint32_t> vecA{0, 0, 2, 0};
thrust::device_vector<uint32_t> vecB{0, 2, 2, 2};
thrust::device_vector<float> vecC{88.0f, 88.0f, 89.0f, 89.0f};
thrust::device_vector<float> expected{88.0f, 89.0f};

auto inputKeyItBegin =
thrust::make_zip_iterator(thrust::make_zip_iterator(vecA.begin(), vecB.begin()), vecC.begin());
auto endIt =
thrust::remove_if(inputKeyItBegin, inputKeyItBegin + vecA.size(), thrust::make_zip_function(RemovePred{}));
auto numEle = endIt - inputKeyItBegin;
vecA.resize(numEle);
vecB.resize(numEle);
vecC.resize(numEle);

ASSERT_EQUAL(numEle, 2);
ASSERT_EQUAL(vecC, expected);
}
};
SimpleUnitTest<TestZipFunctionMixed, type_list<int, float> > TestZipFunctionMixedInstance;

struct NestedFunctionCall
{
__host__ __device__ bool
operator()(const thrust::tuple<uint32_t, thrust::tuple<thrust::tuple<int, int>, thrust::tuple<int, int>>>& idAndPt)
{
thrust::tuple<thrust::tuple<int, int>, thrust::tuple<int, int>> ele1 = thrust::get<1>(idAndPt);
thrust::tuple<int, int> p1 = thrust::get<0>(ele1);
thrust::tuple<int, int> p2 = thrust::get<1>(ele1);
return thrust::get<0>(p1) == thrust::get<0>(p2) || thrust::get<1>(p1) == thrust::get<1>(p2);
}
};

template <typename T>
struct TestNestedZipFunction
{
void operator()()
{
thrust::device_vector<int> PX{0, 1, 2, 3};
thrust::device_vector<int> PY{0, 1, 2, 2};
thrust::device_vector<uint32_t> SS{0, 1, 2};
thrust::device_vector<uint32_t> ST{1, 2, 3};
thrust::device_vector<float> vecC{88.0f, 88.0f, 89.0f, 89.0f};

auto segIt = thrust::make_zip_iterator(
thrust::make_zip_iterator(thrust::make_permutation_iterator(PX.begin(), SS.begin()),
thrust::make_permutation_iterator(PY.begin(), SS.begin())),
thrust::make_zip_iterator(thrust::make_permutation_iterator(PX.begin(), ST.begin()),
thrust::make_permutation_iterator(PY.begin(), ST.begin())));
auto idAndSegIt = thrust::make_zip_iterator(thrust::make_counting_iterator(0u), segIt);

thrust::device_vector<bool> isMH{false, false, false};
thrust::device_vector<bool> expected{false, false, true};
thrust::transform(idAndSegIt, idAndSegIt + SS.size(), isMH.begin(), NestedFunctionCall{});
ASSERT_EQUAL(isMH, expected);
}
};
SimpleUnitTest<TestNestedZipFunction, type_list<int, float> > TestNestedZipFunctionInstance;

struct SortPred {
__device__ __forceinline__
bool operator()(const thrust::tuple<thrust::tuple<int, int>, int>& a,
const thrust::tuple<thrust::tuple<int, int>, int>& b) {
return thrust::get<1>(a) < thrust::get<1>(b);
}
};
template <typename T>
struct TestNestedZipFunction2
{
void operator()()
{
thrust::device_vector<int> A(5);
thrust::device_vector<int> B(5);
thrust::device_vector<int> C(5);
auto n = A.size();

auto tupleIt = thrust::make_zip_iterator(cuda::std::begin(A), cuda::std::begin(B));
auto nestedTupleIt = thrust::make_zip_iterator(tupleIt, cuda::std::begin(C));
thrust::sort(nestedTupleIt, nestedTupleIt + n, SortPred{});
}
};
SimpleUnitTest<TestNestedZipFunction2, type_list<int, float> > TestNestedZipFunctionInstance2;
#endif // _CCCL_STD_VER
Loading

0 comments on commit ca11fef

Please sign in to comment.