Skip to content

Commit

Permalink
Add tuple protocol to cuda::std::complex from C++26 (#2882)
Browse files Browse the repository at this point in the history
  • Loading branch information
davebayer authored Nov 22, 2024
1 parent 4ae70bb commit 83d180f
Show file tree
Hide file tree
Showing 12 changed files with 500 additions and 0 deletions.
32 changes: 32 additions & 0 deletions libcudacxx/include/cuda/std/__complex/nvbf16.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ _CCCL_DIAG_POP

# include <cuda/std/__complex/vector_support.h>
# include <cuda/std/__cuda/cmath_nvbf16.h>
# include <cuda/std/__fwd/get.h>
# include <cuda/std/__type_traits/enable_if.h>
# include <cuda/std/__type_traits/integral_constant.h>
# include <cuda/std/__type_traits/is_constructible.h>
Expand Down Expand Up @@ -112,6 +113,9 @@ class _CCCL_TYPE_VISIBILITY_DEFAULT _CCCL_ALIGNAS(alignof(__nv_bfloat162)) compl
template <class _Up>
friend class complex;

template <class _Up>
friend struct __get_complex_impl;

public:
using value_type = __nv_bfloat16;

Expand Down Expand Up @@ -295,6 +299,34 @@ _LIBCUDACXX_HIDE_FROM_ABI complex<__nv_bfloat16> acos(const complex<__nv_bfloat1
return complex<__nv_bfloat16>{_CUDA_VSTD::acos(complex<float>{__x})};
}

template <>
struct __get_complex_impl<__nv_bfloat16>
{
template <size_t _Index>
static _LIBCUDACXX_HIDE_FROM_ABI constexpr __nv_bfloat16& get(complex<__nv_bfloat16>& __z) noexcept
{
return (_Index == 0) ? __z.__repr_.x : __z.__repr_.y;
}

template <size_t _Index>
static _LIBCUDACXX_HIDE_FROM_ABI constexpr __nv_bfloat16&& get(complex<__nv_bfloat16>&& __z) noexcept
{
return _CUDA_VSTD::move((_Index == 0) ? __z.__repr_.x : __z.__repr_.y);
}

template <size_t _Index>
static _LIBCUDACXX_HIDE_FROM_ABI constexpr const __nv_bfloat16& get(const complex<__nv_bfloat16>& __z) noexcept
{
return (_Index == 0) ? __z.__repr_.x : __z.__repr_.y;
}

template <size_t _Index>
static _LIBCUDACXX_HIDE_FROM_ABI constexpr const __nv_bfloat16&& get(const complex<__nv_bfloat16>&& __z) noexcept
{
return _CUDA_VSTD::move((_Index == 0) ? __z.__repr_.x : __z.__repr_.y);
}
};

# if !_CCCL_COMPILER(NVRTC)
template <class _CharT, class _Traits>
::std::basic_istream<_CharT, _Traits>&
Expand Down
32 changes: 32 additions & 0 deletions libcudacxx/include/cuda/std/__complex/nvfp16.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@

# include <cuda/std/__complex/vector_support.h>
# include <cuda/std/__cuda/cmath_nvfp16.h>
# include <cuda/std/__fwd/get.h>
# include <cuda/std/__type_traits/enable_if.h>
# include <cuda/std/__type_traits/integral_constant.h>
# include <cuda/std/__type_traits/is_constructible.h>
Expand Down Expand Up @@ -109,6 +110,9 @@ class _CCCL_TYPE_VISIBILITY_DEFAULT _CCCL_ALIGNAS(alignof(__half2)) complex<__ha
template <class _Up>
friend class complex;

template <class _Up>
friend struct __get_complex_impl;

public:
using value_type = __half;

Expand Down Expand Up @@ -292,6 +296,34 @@ _LIBCUDACXX_HIDE_FROM_ABI complex<__half> acos(const complex<__half>& __x)
return complex<__half>{_CUDA_VSTD::acos(complex<float>{__x})};
}

template <>
struct __get_complex_impl<__half>
{
template <size_t _Index>
static _LIBCUDACXX_HIDE_FROM_ABI constexpr __half& get(complex<__half>& __z) noexcept
{
return (_Index == 0) ? __z.__repr_.x : __z.__repr_.y;
}

template <size_t _Index>
static _LIBCUDACXX_HIDE_FROM_ABI constexpr __half&& get(complex<__half>&& __z) noexcept
{
return _CUDA_VSTD::move((_Index == 0) ? __z.__repr_.x : __z.__repr_.y);
}

template <size_t _Index>
static _LIBCUDACXX_HIDE_FROM_ABI constexpr const __half& get(const complex<__half>& __z) noexcept
{
return (_Index == 0) ? __z.__repr_.x : __z.__repr_.y;
}

template <size_t _Index>
static _LIBCUDACXX_HIDE_FROM_ABI constexpr const __half&& get(const complex<__half>&& __z) noexcept
{
return _CUDA_VSTD::move((_Index == 0) ? __z.__repr_.x : __z.__repr_.y);
}
};

# if !defined(_LIBCUDACXX_HAS_NO_LOCALIZATION) && !_CCCL_COMPILER(NVRTC)
template <class _CharT, class _Traits>
::std::basic_istream<_CharT, _Traits>& operator>>(::std::basic_istream<_CharT, _Traits>& __is, complex<__half>& __x)
Expand Down
30 changes: 30 additions & 0 deletions libcudacxx/include/cuda/std/__fwd/complex.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// SPDX-FileCopyrightText: Copyright (c) 2023-24 NVIDIA CORPORATION & AFFILIATES.
//
//===----------------------------------------------------------------------===//

#ifndef _LIBCUDACXX___FWD_COMPLEX_H
#define _LIBCUDACXX___FWD_COMPLEX_H

#include <cuda/std/detail/__config>

#if defined(_CCCL_IMPLICIT_SYSTEM_HEADER_GCC)
# pragma GCC system_header
#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_CLANG)
# pragma clang system_header
#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_MSVC)
# pragma system_header
#endif // no system header

_LIBCUDACXX_BEGIN_NAMESPACE_STD

template <class _Tp>
class _CCCL_TYPE_VISIBILITY_DEFAULT complex;

_LIBCUDACXX_END_NAMESPACE_STD

#endif // _LIBCUDACXX___FWD_COMPLEX_H
13 changes: 13 additions & 0 deletions libcudacxx/include/cuda/std/__fwd/get.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

#include <cuda/std/__concepts/copyable.h>
#include <cuda/std/__fwd/array.h>
#include <cuda/std/__fwd/complex.h>
#include <cuda/std/__fwd/pair.h>
#include <cuda/std/__fwd/subrange.h>
#include <cuda/std/__fwd/tuple.h>
Expand Down Expand Up @@ -70,6 +71,18 @@ _LIBCUDACXX_HIDE_FROM_ABI _CCCL_CONSTEXPR_CXX14 _Tp&& get(array<_Tp, _Size>&&) n
template <size_t _Ip, class _Tp, size_t _Size>
_LIBCUDACXX_HIDE_FROM_ABI _CCCL_CONSTEXPR_CXX14 const _Tp&& get(const array<_Tp, _Size>&&) noexcept;

template <size_t _Ip, class _Tp>
_LIBCUDACXX_HIDE_FROM_ABI constexpr _Tp& get(complex<_Tp>&) noexcept;

template <size_t _Ip, class _Tp>
_LIBCUDACXX_HIDE_FROM_ABI constexpr _Tp&& get(complex<_Tp>&&) noexcept;

template <size_t _Ip, class _Tp>
_LIBCUDACXX_HIDE_FROM_ABI constexpr const _Tp& get(const complex<_Tp>&) noexcept;

template <size_t _Ip, class _Tp>
_LIBCUDACXX_HIDE_FROM_ABI constexpr const _Tp&& get(const complex<_Tp>&&) noexcept;

_LIBCUDACXX_END_NAMESPACE_STD

#if _CCCL_STD_VER >= 2017 && !_CCCL_COMPILER(MSVC2017)
Expand Down
9 changes: 9 additions & 0 deletions libcudacxx/include/cuda/std/__tuple_dir/structured_bindings.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ _CCCL_DIAG_SUPPRESS_CLANG("-Wmismatched-tags")
#endif // !_CCCL_COMPILER(NVRTC)

#include <cuda/std/__fwd/array.h>
#include <cuda/std/__fwd/complex.h>
#include <cuda/std/__fwd/pair.h>
#include <cuda/std/__fwd/subrange.h>
#include <cuda/std/__fwd/tuple.h>
Expand Down Expand Up @@ -87,6 +88,14 @@ struct tuple_element<_Ip, const volatile _CUDA_VSTD::array<_Tp, _Size>>
: _CUDA_VSTD::tuple_element<_Ip, const volatile _CUDA_VSTD::array<_Tp, _Size>>
{};

template <class _Tp>
struct tuple_size<_CUDA_VSTD::complex<_Tp>> : _CUDA_VSTD::tuple_size<_CUDA_VSTD::complex<_Tp>>
{};

template <size_t _Ip, class _Tp>
struct tuple_element<_Ip, _CUDA_VSTD::complex<_Tp>> : _CUDA_VSTD::tuple_element<_Ip, _CUDA_VSTD::complex<_Tp>>
{};

template <class _Tp, class _Up>
struct tuple_size<_CUDA_VSTD::pair<_Tp, _Up>> : _CUDA_VSTD::tuple_size<_CUDA_VSTD::pair<_Tp, _Up>>
{};
Expand Down
5 changes: 5 additions & 0 deletions libcudacxx/include/cuda/std/__tuple_dir/tuple_like.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#endif // no system header

#include <cuda/std/__fwd/array.h>
#include <cuda/std/__fwd/complex.h>
#include <cuda/std/__fwd/pair.h>
#include <cuda/std/__fwd/subrange.h>
#include <cuda/std/__fwd/tuple.h>
Expand Down Expand Up @@ -56,6 +57,10 @@ template <class _Tp, size_t _Size>
struct __tuple_like<array<_Tp, _Size>> : true_type
{};

template <class _Tp>
struct __tuple_like<complex<_Tp>> : true_type
{};

#if _CCCL_STD_VER >= 2017 && !_CCCL_COMPILER(MSVC2017)
template <class _Ip, class _Sp, _CUDA_VRANGES::subrange_kind _Kp>
struct __tuple_like<_CUDA_VRANGES::subrange<_Ip, _Sp, _Kp>> : true_type
Expand Down
5 changes: 5 additions & 0 deletions libcudacxx/include/cuda/std/__tuple_dir/tuple_like_ext.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#endif // no system header

#include <cuda/std/__fwd/array.h>
#include <cuda/std/__fwd/complex.h>
#include <cuda/std/__fwd/pair.h>
#include <cuda/std/__fwd/tuple.h>
#include <cuda/std/__tuple_dir/tuple_types.h>
Expand Down Expand Up @@ -55,6 +56,10 @@ template <class _Tp, size_t _Size>
struct __tuple_like_ext<array<_Tp, _Size>> : true_type
{};

template <class _Tp>
struct __tuple_like_ext<complex<_Tp>> : true_type
{};

template <class... _Tp>
struct __tuple_like_ext<__tuple_types<_Tp...>> : true_type
{};
Expand Down
74 changes: 74 additions & 0 deletions libcudacxx/include/cuda/std/detail/libcxx/include/complex
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,9 @@ template<class T> complex<T> tanh (const complex<T>&);
#endif // no system header

#include <cuda/std/__complex/vector_support.h>
#include <cuda/std/__fwd/get.h>
#include <cuda/std/__tuple_dir/tuple_element.h>
#include <cuda/std/__tuple_dir/tuple_size.h>
#include <cuda/std/__type_traits/enable_if.h>
#include <cuda/std/__type_traits/is_constructible.h>
#include <cuda/std/__type_traits/is_floating_point.h>
Expand Down Expand Up @@ -286,6 +289,9 @@ class _CCCL_TYPE_VISIBILITY_DEFAULT _LIBCUDACXX_COMPLEX_ALIGNAS complex
template <class _Up>
friend class complex;

template <class _Up>
friend struct __get_complex_impl;

public:
using value_type = _Tp;

Expand Down Expand Up @@ -1418,6 +1424,74 @@ _LIBCUDACXX_HIDE_FROM_ABI complex<_Tp> tan(const complex<_Tp>& __x)
return complex<_Tp>(__z.imag(), -__z.real());
}

template <class _Tp>
struct tuple_size<complex<_Tp>> : _CUDA_VSTD::integral_constant<size_t, 2>
{};

template <size_t _Index, class _Tp>
struct tuple_element<_Index, complex<_Tp>> : _CUDA_VSTD::enable_if < _Index<2, _Tp>
{};

template <class _Tp>
struct __get_complex_impl
{
template <size_t _Index>
static _LIBCUDACXX_HIDE_FROM_ABI constexpr _Tp& get(complex<_Tp>& __z) noexcept
{
return (_Index == 0) ? __z.__re_ : __z.__im_;
}

template <size_t _Index>
static _LIBCUDACXX_HIDE_FROM_ABI constexpr _Tp&& get(complex<_Tp>&& __z) noexcept
{
return _CUDA_VSTD::move((_Index == 0) ? __z.__re_ : __z.__im_);
}

template <size_t _Index>
static _LIBCUDACXX_HIDE_FROM_ABI constexpr const _Tp& get(const complex<_Tp>& __z) noexcept
{
return (_Index == 0) ? __z.__re_ : __z.__im_;
}

template <size_t _Index>
static _LIBCUDACXX_HIDE_FROM_ABI constexpr const _Tp&& get(const complex<_Tp>&& __z) noexcept
{
return _CUDA_VSTD::move((_Index == 0) ? __z.__re_ : __z.__im_);
}
};

template <size_t _Index, class _Tp>
_LIBCUDACXX_HIDE_FROM_ABI constexpr _Tp& get(complex<_Tp>& __z) noexcept
{
static_assert(_Index < 2, "Index value is out of range");

return __get_complex_impl<_Tp>::template get<_Index>(__z);
}

template <size_t _Index, class _Tp>
_LIBCUDACXX_HIDE_FROM_ABI constexpr _Tp&& get(complex<_Tp>&& __z) noexcept
{
static_assert(_Index < 2, "Index value is out of range");

return __get_complex_impl<_Tp>::template get<_Index>(_CUDA_VSTD::move(__z));
}

template <size_t _Index, class _Tp>
_LIBCUDACXX_HIDE_FROM_ABI constexpr const _Tp& get(const complex<_Tp>& __z) noexcept
{
static_assert(_Index < 2, "Index value is out of range");

return __get_complex_impl<_Tp>::template get<_Index>(__z);
}

template <size_t _Index, class _Tp>
_LIBCUDACXX_HIDE_FROM_ABI constexpr const _Tp&& get(const complex<_Tp>&& __z) noexcept
{
static_assert(_Index < 2, "Index value is out of range");

return __get_complex_impl<_Tp>::template get<_Index>(_CUDA_VSTD::move(__z));
}

#if !_CCCL_COMPILER(NVRTC)
template <class _Tp, class _CharT, class _Traits>
::std::basic_istream<_CharT, _Traits>& operator>>(::std::basic_istream<_CharT, _Traits>& __is, complex<_Tp>& __x)
Expand Down
1 change: 1 addition & 0 deletions libcudacxx/include/cuda/std/version
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
#endif // !_CCCL_COMPILER(NVRTC)

#define __cccl_lib_to_underlying 202102L
// #define __cpp_lib_tuple_like 202311L // P2819R2 is implemented, but P2165R4 is not yet

#if _CCCL_STD_VER >= 2014
# define __cccl_lib_bit_cast 201806L
Expand Down
Loading

0 comments on commit 83d180f

Please sign in to comment.