From ef2d09b101bcdbabc994bcd526be58541c3e1242 Mon Sep 17 00:00:00 2001 From: Michael Schellenberger Costa Date: Fri, 15 Dec 2023 16:27:38 -0800 Subject: [PATCH] Add missing overloads for thrust::pow (#1222) Also add proper type checks for all of those overloads so that we can ensure that we are --- thrust/testing/complex.cu | 34 ++++++++++++++++++++++++++++++++-- thrust/thrust/complex.h | 21 ++++++++++++++++----- 2 files changed, 48 insertions(+), 7 deletions(-) diff --git a/thrust/testing/complex.cu b/thrust/testing/complex.cu index 2dfee5500cf..7c43c57ab1b 100644 --- a/thrust/testing/complex.cu +++ b/thrust/testing/complex.cu @@ -449,17 +449,18 @@ struct TestComplexBasicArithmetic // Test the basic arithmetic functions against std ASSERT_ALMOST_EQUAL(thrust::abs(a), std::abs(b)); - ASSERT_ALMOST_EQUAL(thrust::arg(a), std::arg(b)); - ASSERT_ALMOST_EQUAL(thrust::norm(a), std::norm(b)); ASSERT_EQUAL(thrust::conj(a), std::conj(b)); + static_assert(cuda::std::is_same, decltype(thrust::conj(a))>::value, ""); ASSERT_ALMOST_EQUAL(thrust::polar(data[0], data[1]), std::polar(data[0], data[1])); + static_assert(cuda::std::is_same, decltype(thrust::polar(data[0], data[1]))>::value, ""); // random_samples does not seem to produce infinities so proj(z) == z ASSERT_EQUAL(thrust::proj(a), a); + static_assert(cuda::std::is_same, decltype(thrust::proj(a))>::value, ""); } }; SimpleUnitTest TestComplexBasicArithmeticInstance; @@ -556,6 +557,9 @@ struct TestComplexExponentialFunctions ASSERT_ALMOST_EQUAL(thrust::exp(a), std::exp(b)); ASSERT_ALMOST_EQUAL(thrust::log(a), std::log(b)); ASSERT_ALMOST_EQUAL(thrust::log10(a), std::log10(b)); + static_assert(cuda::std::is_same, decltype(thrust::exp(a))>::value, ""); + static_assert(cuda::std::is_same, decltype(thrust::log(a))>::value, ""); + static_assert(cuda::std::is_same, decltype(thrust::log10(a))>::value, ""); } }; SimpleUnitTest @@ -575,16 +579,24 @@ struct TestComplexPowerFunctions const std::complex b_std(b_thrust); ASSERT_ALMOST_EQUAL(thrust::pow(a_thrust, b_thrust), std::pow(a_std, b_std)); + static_assert(cuda::std::is_same, decltype(thrust::pow(a_thrust, b_thrust))>::value, ""); ASSERT_ALMOST_EQUAL(thrust::pow(a_thrust, b_thrust.real()), std::pow(a_std, b_std.real())); + static_assert(cuda::std::is_same, decltype(thrust::pow(a_thrust, b_thrust.real()))>::value, ""); ASSERT_ALMOST_EQUAL(thrust::pow(a_thrust.real(), b_thrust), std::pow(a_std.real(), b_std)); + static_assert(cuda::std::is_same, decltype(thrust::pow(a_thrust.real(), b_thrust))>::value, ""); + + ASSERT_ALMOST_EQUAL(thrust::pow(a_thrust, 4), std::pow(a_std, 4)); + static_assert(cuda::std::is_same, decltype(thrust::pow(a_thrust, 4))>::value, ""); ASSERT_ALMOST_EQUAL(thrust::sqrt(a_thrust), std::sqrt(a_std)); + static_assert(cuda::std::is_same, decltype(thrust::sqrt(a_thrust))>::value, ""); } // Test power functions with promoted types. { using T0 = T; using T1 = other_floating_point_type_t; + using promoted = typename thrust::detail::promoted_numerical_type::type; thrust::host_vector data = unittest::random_samples(4); @@ -594,11 +606,17 @@ struct TestComplexPowerFunctions const std::complex b_std(data[2], data[3]); ASSERT_ALMOST_EQUAL(thrust::pow(a_thrust, b_thrust), std::pow(a_std, b_std)); + static_assert(cuda::std::is_same, decltype(thrust::pow(a_thrust, b_thrust))>::value, ""); ASSERT_ALMOST_EQUAL(thrust::pow(b_thrust, a_thrust), std::pow(b_std, a_std)); + static_assert(cuda::std::is_same, decltype(thrust::pow(b_thrust, a_thrust))>::value, ""); ASSERT_ALMOST_EQUAL(thrust::pow(a_thrust, b_thrust.real()), std::pow(a_std, b_std.real())); + static_assert(cuda::std::is_same, decltype(thrust::pow(a_thrust, b_thrust.real()))>::value, ""); ASSERT_ALMOST_EQUAL(thrust::pow(b_thrust, a_thrust.real()), std::pow(b_std, a_std.real())); + static_assert(cuda::std::is_same, decltype(thrust::pow(b_thrust, a_thrust.real()))>::value, ""); ASSERT_ALMOST_EQUAL(thrust::pow(a_thrust.real(), b_thrust), std::pow(a_std.real(), b_std)); + static_assert(cuda::std::is_same, decltype(thrust::pow(a_thrust.real(), b_thrust))>::value, ""); ASSERT_ALMOST_EQUAL(thrust::pow(b_thrust.real(), a_thrust), std::pow(b_std.real(), a_std)); + static_assert(cuda::std::is_same, decltype(thrust::pow(b_thrust.real(), a_thrust))>::value, ""); } } }; @@ -617,20 +635,32 @@ struct TestComplexTrigonometricFunctions ASSERT_ALMOST_EQUAL(thrust::cos(a), std::cos(c)); ASSERT_ALMOST_EQUAL(thrust::sin(a), std::sin(c)); ASSERT_ALMOST_EQUAL(thrust::tan(a), std::tan(c)); + static_assert(cuda::std::is_same, decltype(thrust::cos(a))>::value, ""); + static_assert(cuda::std::is_same, decltype(thrust::sin(a))>::value, ""); + static_assert(cuda::std::is_same, decltype(thrust::tan(a))>::value, ""); ASSERT_ALMOST_EQUAL(thrust::cosh(a), std::cosh(c)); ASSERT_ALMOST_EQUAL(thrust::sinh(a), std::sinh(c)); ASSERT_ALMOST_EQUAL(thrust::tanh(a), std::tanh(c)); + static_assert(cuda::std::is_same, decltype(thrust::cosh(a))>::value, ""); + static_assert(cuda::std::is_same, decltype(thrust::sinh(a))>::value, ""); + static_assert(cuda::std::is_same, decltype(thrust::tanh(a))>::value, ""); #if _CCCL_STD_VER >= 2011 ASSERT_ALMOST_EQUAL(thrust::acos(a), std::acos(c)); ASSERT_ALMOST_EQUAL(thrust::asin(a), std::asin(c)); ASSERT_ALMOST_EQUAL(thrust::atan(a), std::atan(c)); + static_assert(cuda::std::is_same, decltype(thrust::acos(a))>::value, ""); + static_assert(cuda::std::is_same, decltype(thrust::asin(a))>::value, ""); + static_assert(cuda::std::is_same, decltype(thrust::atan(a))>::value, ""); ASSERT_ALMOST_EQUAL(thrust::acosh(a), std::acosh(c)); ASSERT_ALMOST_EQUAL(thrust::asinh(a), std::asinh(c)); ASSERT_ALMOST_EQUAL(thrust::atanh(a), std::atanh(c)); + static_assert(cuda::std::is_same, decltype(thrust::acosh(a))>::value, ""); + static_assert(cuda::std::is_same, decltype(thrust::asinh(a))>::value, ""); + static_assert(cuda::std::is_same, decltype(thrust::atanh(a))>::value, ""); #endif } diff --git a/thrust/thrust/complex.h b/thrust/thrust/complex.h index a7e7f909742..08f3e74a1d8 100644 --- a/thrust/thrust/complex.h +++ b/thrust/thrust/complex.h @@ -474,7 +474,8 @@ using ::cuda::std::proj; using ::cuda::std::exp; using ::cuda::std::log; using ::cuda::std::log10; -using ::cuda::std::pow; +// pow always returns a complex. +// using ::cuda::std::pow; using ::cuda::std::sqrt; using ::cuda::std::acos; @@ -516,15 +517,25 @@ template _CCCL_HOST_DEVICE complex log10(const complex& c) { return static_cast>(::cuda::std::log10(c)); } -template -_CCCL_HOST_DEVICE complex pow(const complex& c) { - return static_cast>(::cuda::std::pow(c)); +template +_CCCL_HOST_DEVICE complex::type> +pow(const complex& x, const complex& y) { + return static_cast::type>>(::cuda::std::pow(x, y)); +} +template::value, int> = 0> +_CCCL_HOST_DEVICE complex::type> +pow(const complex& x, const T1& y) { + return static_cast::type>>(::cuda::std::pow(x, y)); +} +template::value, int> = 0> +_CCCL_HOST_DEVICE complex::type> +pow(const T0& x, const complex& y) { + return static_cast::type>>(::cuda::std::pow(x, y)); } template _CCCL_HOST_DEVICE complex sqrt(const complex& c) { return static_cast>(::cuda::std::sqrt(c)); } - template _CCCL_HOST_DEVICE complex acos(const complex& c) { return static_cast>(::cuda::std::acos(c));