Skip to content

Commit

Permalink
Ensure that {cr}begin works with types that pull in namespace std via…
Browse files Browse the repository at this point in the history
… ADL

Those types must be host only so we can safely SFINAE away our own free functions that would be ambiguous
  • Loading branch information
miscco committed May 1, 2024
1 parent 465d7f3 commit 3206dce
Show file tree
Hide file tree
Showing 3 changed files with 180 additions and 49 deletions.
66 changes: 60 additions & 6 deletions libcudacxx/include/cuda/std/__iterator/access.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,42 +21,96 @@
# pragma system_header
#endif // no system header

#include <cuda/std/__type_traits/enable_if.h>
#include <cuda/std/__type_traits/is_same.h>
#include <cuda/std/__utility/declval.h>
#include <cuda/std/cstddef>
#include <cuda/std/initializer_list>

_LIBCUDACXX_BEGIN_NAMESPACE_STD

// We need to detect whether there is already a free function begin / end that would end up being ambiguous.
// This can happen when a type pulls in both namespace std and namespace cuda::std via ADL.
// In that case we are always safe to just not do anything because that type must be host only.
namespace __detect_ambiguous_access
{
struct __not_ambiguous
{};
template <class _Tp, size_t _Np>
_LIBCUDACXX_INLINE_VISIBILITY auto __begin(_Tp (&__array)[_Np]) -> decltype(begin(__array));
template <class _Cont>
_LIBCUDACXX_INLINE_VISIBILITY auto __begin(_Cont& __c) -> decltype(begin(__c));
template <class _Cont>
_LIBCUDACXX_INLINE_VISIBILITY auto __begin(const _Cont& __c) -> decltype(begin(__c));
_LIBCUDACXX_INLINE_VISIBILITY auto __begin(...) -> __not_ambiguous;
template <class _Cont>
struct __begin_not_ambiguous
: is_same<decltype(__detect_ambiguous_access::__begin(_CUDA_VSTD::declval<_Cont>())), __not_ambiguous>
{};
template <class _Tp>
struct __begin_not_ambiguous<initializer_list<_Tp>&> : true_type
{};
template <class _Tp>
struct __begin_not_ambiguous<const initializer_list<_Tp>&> : true_type
{};

template <class _Tp, size_t _Np>
_LIBCUDACXX_INLINE_VISIBILITY auto __end(_Tp (&__array)[_Np]) -> decltype(end(__array));
template <class _Cont>
_LIBCUDACXX_INLINE_VISIBILITY auto __end(_Cont& __c) -> decltype(end(__c));
template <class _Cont>
_LIBCUDACXX_INLINE_VISIBILITY auto __end(const _Cont& __c) -> decltype(end(__c));
_LIBCUDACXX_INLINE_VISIBILITY auto __end(...) -> __not_ambiguous;
template <class _Cont>
struct __end_not_ambiguous
: is_same<decltype(__detect_ambiguous_access::__end(_CUDA_VSTD::declval<_Cont>())), __not_ambiguous>
{};
template <class _Tp>
struct __end_not_ambiguous<initializer_list<_Tp>&> : true_type
{};
template <class _Tp>
struct __end_not_ambiguous<const initializer_list<_Tp>&> : true_type
{};
} // namespace __detect_ambiguous_access

template <class _Tp, size_t _Np>
_LIBCUDACXX_INLINE_VISIBILITY _CCCL_CONSTEXPR_CXX14 _Tp* begin(_Tp (&__array)[_Np])
_LIBCUDACXX_INLINE_VISIBILITY _CCCL_CONSTEXPR_CXX14 auto begin(_Tp (&__array)[_Np])
-> __enable_if_t<__detect_ambiguous_access::__begin_not_ambiguous<_Tp (&)[_Np]>::value, _Tp*>
{
return __array;
}

template <class _Tp, size_t _Np>
_LIBCUDACXX_INLINE_VISIBILITY _CCCL_CONSTEXPR_CXX14 _Tp* end(_Tp (&__array)[_Np])
_LIBCUDACXX_INLINE_VISIBILITY _CCCL_CONSTEXPR_CXX14 auto end(_Tp (&__array)[_Np])
-> __enable_if_t<__detect_ambiguous_access::__end_not_ambiguous<_Tp (&)[_Np]>::value, _Tp*>
{
return __array + _Np;
}

template <class _Cp>
_LIBCUDACXX_INLINE_VISIBILITY _CCCL_CONSTEXPR_CXX14 auto begin(_Cp& __c) -> decltype(__c.begin())
_LIBCUDACXX_INLINE_VISIBILITY _CCCL_CONSTEXPR_CXX14 auto begin(_Cp& __c)
-> __enable_if_t<__detect_ambiguous_access::__begin_not_ambiguous<_Cp&>::value, decltype(__c.begin())>
{
return __c.begin();
}

template <class _Cp>
_LIBCUDACXX_INLINE_VISIBILITY _CCCL_CONSTEXPR_CXX14 auto begin(const _Cp& __c) -> decltype(__c.begin())
_LIBCUDACXX_INLINE_VISIBILITY _CCCL_CONSTEXPR_CXX14 auto begin(const _Cp& __c)
-> __enable_if_t<__detect_ambiguous_access::__begin_not_ambiguous<const _Cp&>::value, decltype(__c.begin())>
{
return __c.begin();
}

template <class _Cp>
_LIBCUDACXX_INLINE_VISIBILITY _CCCL_CONSTEXPR_CXX14 auto end(_Cp& __c) -> decltype(__c.end())
_LIBCUDACXX_INLINE_VISIBILITY _CCCL_CONSTEXPR_CXX14 auto end(_Cp& __c)
-> __enable_if_t<__detect_ambiguous_access::__end_not_ambiguous<_Cp&>::value, decltype(__c.end())>
{
return __c.end();
}

template <class _Cp>
_LIBCUDACXX_INLINE_VISIBILITY _CCCL_CONSTEXPR_CXX14 auto end(const _Cp& __c) -> decltype(__c.end())
_LIBCUDACXX_INLINE_VISIBILITY _CCCL_CONSTEXPR_CXX14 auto end(const _Cp& __c)
-> __enable_if_t<__detect_ambiguous_access::__end_not_ambiguous<const _Cp&>::value, decltype(__c.end())>
{
return __c.end();
}
Expand Down
55 changes: 47 additions & 8 deletions libcudacxx/include/cuda/std/__iterator/reverse_access.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,16 +27,51 @@

_LIBCUDACXX_BEGIN_NAMESPACE_STD

#if _CCCL_STD_VER > 2011
#if _CCCL_STD_VER >= 2014

// We need to detect whether there is already a free function begin / end that would end up being ambiguous.
// This can happen when a type pulls in both namespace std and namespace cuda::std via ADL.
// In that case we are always safe to just not do anything because that type must be host only.
// The exception is initializer list, which must be defined namespace std, so we need to special case that
namespace __detect_ambiguous_raccess
{
struct __not_ambiguous
{};
template <class _Tp, size_t _Np>
_LIBCUDACXX_INLINE_VISIBILITY auto __rbegin(_Tp (&__array)[_Np]) -> decltype(rbegin(__array));
template <class _Cont>
_LIBCUDACXX_INLINE_VISIBILITY auto __rbegin(_Cont& __c) -> decltype(rbegin(__c));
template <class _Cont>
_LIBCUDACXX_INLINE_VISIBILITY auto __rbegin(const _Cont& __c) -> decltype(rbegin(__c));
_LIBCUDACXX_INLINE_VISIBILITY auto __rbegin(...) -> __not_ambiguous;
template <class _Cont>
struct __rbegin_not_ambiguous
: is_same<decltype(__detect_ambiguous_raccess::__rbegin(_CUDA_VSTD::declval<_Cont>())), __not_ambiguous>
{};

template <class _Tp, size_t _Np>
_LIBCUDACXX_INLINE_VISIBILITY auto __rend(_Tp (&__array)[_Np]) -> decltype(rend(__array));
template <class _Cont>
_LIBCUDACXX_INLINE_VISIBILITY auto __rend(_Cont& __c) -> decltype(rend(__c));
template <class _Cont>
_LIBCUDACXX_INLINE_VISIBILITY auto __rend(const _Cont& __c) -> decltype(rend(__c));
_LIBCUDACXX_INLINE_VISIBILITY auto __rend(...) -> __not_ambiguous;
template <class _Cont>
struct __rend_not_ambiguous
: is_same<decltype(__detect_ambiguous_raccess::__rend(_CUDA_VSTD::declval<_Cont>())), __not_ambiguous>
{};
} // namespace __detect_ambiguous_raccess

template <class _Tp, size_t _Np>
_LIBCUDACXX_INLINE_VISIBILITY _CCCL_CONSTEXPR_CXX17 reverse_iterator<_Tp*> rbegin(_Tp (&__array)[_Np])
_LIBCUDACXX_INLINE_VISIBILITY _CCCL_CONSTEXPR_CXX17 auto rbegin(_Tp (&__array)[_Np])
-> __enable_if_t<__detect_ambiguous_raccess::__rbegin_not_ambiguous<_Tp (&)[_Np]>::value, reverse_iterator<_Tp*>>
{
return reverse_iterator<_Tp*>(__array + _Np);
}

template <class _Tp, size_t _Np>
_LIBCUDACXX_INLINE_VISIBILITY _CCCL_CONSTEXPR_CXX17 reverse_iterator<_Tp*> rend(_Tp (&__array)[_Np])
_LIBCUDACXX_INLINE_VISIBILITY _CCCL_CONSTEXPR_CXX17 auto rend(_Tp (&__array)[_Np])
-> __enable_if_t<__detect_ambiguous_raccess::__rend_not_ambiguous<_Tp (&)[_Np]>::value, reverse_iterator<_Tp*>>
{
return reverse_iterator<_Tp*>(__array);
}
Expand All @@ -54,25 +89,29 @@ _LIBCUDACXX_INLINE_VISIBILITY _CCCL_CONSTEXPR_CXX17 reverse_iterator<const _Ep*>
}

template <class _Cp>
_LIBCUDACXX_INLINE_VISIBILITY _CCCL_CONSTEXPR_CXX17 auto rbegin(_Cp& __c) -> decltype(__c.rbegin())
_LIBCUDACXX_INLINE_VISIBILITY _CCCL_CONSTEXPR_CXX17 auto rbegin(_Cp& __c)
-> __enable_if_t<__detect_ambiguous_raccess::__rbegin_not_ambiguous<_Cp&>::value, decltype(__c.rbegin())>
{
return __c.rbegin();
}

template <class _Cp>
_LIBCUDACXX_INLINE_VISIBILITY _CCCL_CONSTEXPR_CXX17 auto rbegin(const _Cp& __c) -> decltype(__c.rbegin())
_LIBCUDACXX_INLINE_VISIBILITY _CCCL_CONSTEXPR_CXX17 auto rbegin(const _Cp& __c)
-> __enable_if_t<__detect_ambiguous_raccess::__rbegin_not_ambiguous<const _Cp&>::value, decltype(__c.rbegin())>
{
return __c.rbegin();
}

template <class _Cp>
_LIBCUDACXX_INLINE_VISIBILITY _CCCL_CONSTEXPR_CXX17 auto rend(_Cp& __c) -> decltype(__c.rend())
_LIBCUDACXX_INLINE_VISIBILITY _CCCL_CONSTEXPR_CXX17 auto rend(_Cp& __c)
-> __enable_if_t<__detect_ambiguous_raccess::__rend_not_ambiguous<_Cp&>::value, decltype(__c.rend())>
{
return __c.rend();
}

template <class _Cp>
_LIBCUDACXX_INLINE_VISIBILITY _CCCL_CONSTEXPR_CXX17 auto rend(const _Cp& __c) -> decltype(__c.rend())
_LIBCUDACXX_INLINE_VISIBILITY _CCCL_CONSTEXPR_CXX17 auto rend(const _Cp& __c)
-> __enable_if_t<__detect_ambiguous_raccess::__rend_not_ambiguous<const _Cp&>::value, decltype(__c.rend())>
{
return __c.rend();
}
Expand All @@ -89,7 +128,7 @@ _LIBCUDACXX_INLINE_VISIBILITY _CCCL_CONSTEXPR_CXX17 auto crend(const _Cp& __c) -
return _CUDA_VSTD::rend(__c);
}

#endif // _CCCL_STD_VER > 2011
#endif // _CCCL_STD_VER >= 2014

_LIBCUDACXX_END_NAMESPACE_STD

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,11 @@
#endif
#include <cuda/std/initializer_list>

#if !defined(TEST_COMPILER_NVRTC)
# include <iterator>
# include <utility>
#endif // !TEST_COMPILER_NVRTC

#include "test_macros.h"

// cuda::std::array is explicitly allowed to be initialized with A a = { init-list };.
Expand All @@ -54,7 +59,7 @@ __host__ __device__ void test_const_container(const C& c, typename C::value_type
assert(*cuda::std::begin(c) == val);
assert(cuda::std::begin(c) != c.end());
assert(cuda::std::end(c) == c.end());
#if TEST_STD_VER > 2011
#if TEST_STD_VER >= 2014
assert(cuda::std::cbegin(c) == c.cbegin());
assert(cuda::std::cbegin(c) != c.cend());
assert(cuda::std::cend(c) == c.cend());
Expand All @@ -64,7 +69,7 @@ __host__ __device__ void test_const_container(const C& c, typename C::value_type
assert(cuda::std::crbegin(c) == c.crbegin());
assert(cuda::std::crbegin(c) != c.crend());
assert(cuda::std::crend(c) == c.crend());
#endif
#endif // TEST_STD_VER >= 2014
}

template <typename T>
Expand All @@ -74,19 +79,11 @@ __host__ __device__ void test_const_container(const cuda::std::initializer_list<
assert(*cuda::std::begin(c) == val);
assert(cuda::std::begin(c) != c.end());
assert(cuda::std::end(c) == c.end());
#if TEST_STD_VER > 2011
// initializer_list doesn't have cbegin/cend/rbegin/rend
// but cuda::std::cbegin(),etc work (b/c they're general fn templates)
// assert ( cuda::std::cbegin(c) == c.cbegin());
// assert ( cuda::std::cbegin(c) != c.cend());
// assert ( cuda::std::cend(c) == c.cend());
// assert ( cuda::std::rbegin(c) == c.rbegin());
// assert ( cuda::std::rbegin(c) != c.rend());
// assert ( cuda::std::rend(c) == c.rend());
// assert ( cuda::std::crbegin(c) == c.crbegin());
// assert ( cuda::std::crbegin(c) != c.crend());
// assert ( cuda::std::crend(c) == c.crend());
#endif
#if TEST_STD_VER >= 2014
assert(cuda::std::cbegin(c) != cuda::std::cend(c));
assert(cuda::std::rbegin(c) != cuda::std::rend(c));
assert(cuda::std::crbegin(c) != cuda::std::crend(c));
#endif // TEST_STD_VER >= 2014
}

template <typename C>
Expand All @@ -96,7 +93,7 @@ __host__ __device__ void test_container(C& c, typename C::value_type val)
assert(*cuda::std::begin(c) == val);
assert(cuda::std::begin(c) != c.end());
assert(cuda::std::end(c) == c.end());
#if TEST_STD_VER > 2011
#if TEST_STD_VER >= 2014
assert(cuda::std::cbegin(c) == c.cbegin());
assert(cuda::std::cbegin(c) != c.cend());
assert(cuda::std::cend(c) == c.cend());
Expand All @@ -106,7 +103,7 @@ __host__ __device__ void test_container(C& c, typename C::value_type val)
assert(cuda::std::crbegin(c) == c.crbegin());
assert(cuda::std::crbegin(c) != c.crend());
assert(cuda::std::crend(c) == c.crend());
#endif
#endif // TEST_STD_VER >= 2014
}

template <typename T>
Expand All @@ -116,18 +113,11 @@ __host__ __device__ void test_container(cuda::std::initializer_list<T>& c, T val
assert(*cuda::std::begin(c) == val);
assert(cuda::std::begin(c) != c.end());
assert(cuda::std::end(c) == c.end());
#if TEST_STD_VER > 2011
// initializer_list doesn't have cbegin/cend/rbegin/rend
// assert ( cuda::std::cbegin(c) == c.cbegin());
// assert ( cuda::std::cbegin(c) != c.cend());
// assert ( cuda::std::cend(c) == c.cend());
// assert ( cuda::std::rbegin(c) == c.rbegin());
// assert ( cuda::std::rbegin(c) != c.rend());
// assert ( cuda::std::rend(c) == c.rend());
// assert ( cuda::std::crbegin(c) == c.crbegin());
// assert ( cuda::std::crbegin(c) != c.crend());
// assert ( cuda::std::crend(c) == c.crend());
#endif
#if TEST_STD_VER >= 2014
assert(cuda::std::cbegin(c) != cuda::std::cend(c));
assert(cuda::std::rbegin(c) != cuda::std::rend(c));
assert(cuda::std::crbegin(c) != cuda::std::crend(c));
#endif // TEST_STD_VER >= 2014
}

template <typename T, size_t Sz>
Expand All @@ -137,12 +127,12 @@ __host__ __device__ void test_const_array(const T (&array)[Sz])
assert(*cuda::std::begin(array) == array[0]);
assert(cuda::std::begin(array) != cuda::std::end(array));
assert(cuda::std::end(array) == array + Sz);
#if TEST_STD_VER > 2011
#if TEST_STD_VER >= 2014
assert(cuda::std::cbegin(array) == array);
assert(*cuda::std::cbegin(array) == array[0]);
assert(cuda::std::cbegin(array) != cuda::std::cend(array));
assert(cuda::std::cend(array) == array + Sz);
#endif
#endif // TEST_STD_VER >= 2014
}

STATIC_TEST_GLOBAL_VAR TEST_CONSTEXPR_GLOBAL int global_array[]{1, 2, 3};
Expand All @@ -152,6 +142,52 @@ STATIC_TEST_GLOBAL_VAR TEST_CONSTEXPR_GLOBAL int global_const_array[] = {0, 1, 2
# endif // nvcc > 11.2
#endif // TEST_STD_VER > 2014

__host__ __device__ void test_ambiguous_std()
{
#if !defined(TEST_COMPILER_NVRTC)
// clang-format off
NV_IF_TARGET(NV_IS_HOST, (
{
cuda::std::array<::std::pair<int, int>, 10> c = {};
assert(begin(c) == c.begin());
assert(begin(c) != c.end());
assert(end(c) == c.end());
}

{
cuda::std::initializer_list<::std::pair<int, int>> init = {{1, 2}};
assert(begin(init) == init.begin());
assert(begin(init) != init.end());
assert(end(init) == init.end());
}
))
#if TEST_STD_VER >= 2014
NV_IF_TARGET(NV_IS_HOST, (
{
cuda::std::array<::std::pair<int, int>, 10> c = {};
assert(cbegin(c) == c.cbegin());
assert(cbegin(c) != c.cend());
assert(cend(c) == c.cend());
assert(rbegin(c) == c.rbegin());
assert(rbegin(c) != c.rend());
assert(rend(c) == c.rend());
assert(crbegin(c) == c.crbegin());
assert(crbegin(c) != c.crend());
assert(crend(c) == c.crend());
}

{
cuda::std::initializer_list<::std::pair<int, int>> init = {{1, 2}};
assert(cbegin(init) != cend(init));
assert(rbegin(init) != rend(init));
assert(crbegin(init) != crend(init));
}
))
#endif // TEST_STD_VER >= 2014
// clang-format on
#endif // !TEST_COMPILER_NVRTC
}

int main(int, char**)
{
#if defined(_LIBCUDACXX_HAS_VECTOR)
Expand Down Expand Up @@ -185,13 +221,13 @@ int main(int, char**)
test_const_container(il, 4);

test_const_array(global_array);
#if TEST_STD_VER > 2011
#if TEST_STD_VER >= 2014
constexpr const int* b = cuda::std::cbegin(global_array);
constexpr const int* e = cuda::std::cend(global_array);
static_assert(e - b == 3, "");
#endif
#endif // TEST_STD_VER >= 2014

#if TEST_STD_VER > 2014
#if TEST_STD_VER >= 2017
{
typedef cuda::std::array<int, 5> C;
constexpr const C local_const_array{0, 1, 2, 3, 4};
Expand Down Expand Up @@ -228,7 +264,9 @@ int main(int, char**)
static_assert(*cuda::std::crbegin(global_const_array) == 4, "");
}
# endif // nvcc > 11.2
#endif // TEST_STD_VER > 2014
#endif // TEST_STD_VER >= 2017

test_ambiguous_std();

return 0;
}

0 comments on commit 3206dce

Please sign in to comment.