diff --git a/sycl/doc/syclcompat/README.md b/sycl/doc/syclcompat/README.md index c1c03e05dfd0d..bb55cd96fc875 100644 --- a/sycl/doc/syclcompat/README.md +++ b/sycl/doc/syclcompat/README.md @@ -53,6 +53,7 @@ Specifically, this library depends on the following SYCL extensions: If available, the following extensions extend SYCLcompat functionality: * [sycl_ext_intel_device_info](https://github.com/intel/llvm/blob/sycl/sycl/doc/extensions/supported/sycl_ext_intel_device_info.md) \[Optional\] +* [sycl_ext_oneapi_bfloat16_math_functions](../extensions/experimental/sycl_ext_oneapi_bfloat16_math_functions.asciidoc) \[Optional\] ## Usage @@ -1275,6 +1276,10 @@ static kernel_function_info get_kernel_function_info(const void *function); length. `syclcompat::length` provides a templated version that wraps over `sycl::length`. +`compare`, `unordered_compare`, `compare_both`, `unordered_compare_both`, +`compare_mask`, and `unordered_compare_mask` handle both ordered and unordered +comparisons. + `vectorized_max` and `vectorized_min` are binary operations returning the max/min of two arguments, where each argument is treated as a `sycl::vec` type. 
`vectorized_isgreater` performs elementwise `isgreater`, treating each argument @@ -1292,6 +1297,45 @@ inline float fast_length(const float *a, int len); template inline ValueT length(const ValueT *a, const int len); +// The following definition is enabled when BinaryOperation(ValueT, ValueT) returns bool +// std::enable_if_t, bool>, bool> +template +inline bool +compare(const ValueT a, const ValueT b, const BinaryOperation binary_op); +template +inline std::enable_if_t +compare(const ValueT a, const ValueT b, const BinaryOperation binary_op); + +// The following definition is enabled when BinaryOperation(ValueT, ValueT) returns bool +// std::enable_if_t, bool>, bool> +template +inline bool +unordered_compare(const ValueT a, const ValueT b, + const BinaryOperation binary_op); +template +inline std::enable_if_t +unordered_compare(const ValueT a, const ValueT b, + const BinaryOperation binary_op); + +template +inline std::enable_if_t +compare_both(const ValueT a, const ValueT b, const BinaryOperation binary_op); + +template +inline std::enable_if_t +unordered_compare_both(const ValueT a, const ValueT b, + const BinaryOperation binary_op); + +template +inline unsigned compare_mask(const sycl::vec a, + const sycl::vec b, + const BinaryOperation binary_op); + +template +inline unsigned unordered_compare_mask(const sycl::vec a, + const sycl::vec b, + const BinaryOperation binary_op); + template inline T vectorized_max(T a, T b); template inline T vectorized_min(T a, T b); diff --git a/sycl/include/syclcompat/math.hpp b/sycl/include/syclcompat/math.hpp index 2e317d185a744..21cddf580eb2c 100644 --- a/sycl/include/syclcompat/math.hpp +++ b/sycl/include/syclcompat/math.hpp @@ -118,6 +118,15 @@ inline constexpr RetT extend_binary(AT a, BT b, CT c, return second_op(extend_temp, extend_c); } +template inline bool isnan(const ValueT a) { + return sycl::isnan(a); +} +#ifdef SYCL_EXT_ONEAPI_BFLOAT16_MATH_FUNCTIONS +inline bool isnan(const sycl::ext::oneapi::bfloat16 a) { + return 
sycl::ext::oneapi::experimental::isnan(a); +} +#endif + } // namespace detail /// Compute fast_length for variable-length array @@ -167,6 +176,121 @@ inline ValueT length(const ValueT *a, const int len) { } } +/// Performs comparison; the std::not_equal_to<> overload is false on any NaN. +/// \param [in] a The first value +/// \param [in] b The second value +/// \param [in] binary_op functor that implements the binary operation +/// \returns the comparison result +template +inline std::enable_if_t< + std::is_same_v, bool>, + bool> +compare(const ValueT a, const ValueT b, const BinaryOperation binary_op) { + return binary_op(a, b); +} +template +inline std::enable_if_t< + std::is_same_v, ValueT, ValueT>, + bool>, + bool> +compare(const ValueT a, const ValueT b, const std::not_equal_to<> binary_op) { + return !detail::isnan(a) && !detail::isnan(b) && binary_op(a, b); +} + +/// Performs 2 element comparison. +/// \param [in] a The first value +/// \param [in] b The second value +/// \param [in] binary_op functor that implements the binary operation +/// \returns the per-element comparison results, in element order +template +inline std::enable_if_t +compare(const ValueT a, const ValueT b, const BinaryOperation binary_op) { + return {compare(a[0], b[0], binary_op), compare(a[1], b[1], binary_op)}; +} + +/// Performs unordered comparison (true if a or b is NaN, or binary_op holds). +/// \param [in] a The first value +/// \param [in] b The second value +/// \param [in] binary_op functor that implements the binary operation +/// \returns the comparison result +template +inline std::enable_if_t< + std::is_same_v, bool>, + bool> +unordered_compare(const ValueT a, const ValueT b, + const BinaryOperation binary_op) { + return detail::isnan(a) || detail::isnan(b) || binary_op(a, b); +} + +/// Performs 2 element unordered comparison. 
+/// \param [in] a The first value +/// \param [in] b The second value +/// \param [in] binary_op functor that implements the binary operation +/// \returns the per-element comparison results, in element order +template +inline std::enable_if_t +unordered_compare(const ValueT a, const ValueT b, + const BinaryOperation binary_op) { + return {unordered_compare(a[0], b[0], binary_op), + unordered_compare(a[1], b[1], binary_op)}; +} + +/// Performs 2 element comparison and returns true if both results are true. +/// \param [in] a The first value +/// \param [in] b The second value +/// \param [in] binary_op functor that implements the binary operation +/// \returns true only if both element comparisons are true +template +inline std::enable_if_t +compare_both(const ValueT a, const ValueT b, const BinaryOperation binary_op) { + return compare(a[0], b[0], binary_op) && compare(a[1], b[1], binary_op); +} + +/// Performs 2 element unordered comparison and returns true if both results are +/// true. +/// \param [in] a The first value +/// \param [in] b The second value +/// \param [in] binary_op functor that implements the binary operation +/// \returns true only if both element comparisons are true +template +inline std::enable_if_t +unordered_compare_both(const ValueT a, const ValueT b, + const BinaryOperation binary_op) { + return unordered_compare(a[0], b[0], binary_op) && + unordered_compare(a[1], b[1], binary_op); +} + +/// Performs 2 element comparison. Each element's result is 0 (false) or 0xffff +/// (true); element 0 fills the upper 16 bits of the returned unsigned int and +/// element 1 the lower 16 bits. 
+/// \param [in] a The first value +/// \param [in] b The second value +/// \param [in] binary_op functor that implements the binary operation +/// \returns the composed mask: element 0 in bits 31-16, element 1 in bits 15-0 +template +inline unsigned compare_mask(const sycl::vec a, + const sycl::vec b, + const BinaryOperation binary_op) { + // Unsigned half-masks avoid left-shifting a negative int (UB before C++20). + return (compare(a[0], b[0], binary_op) ? 0xFFFF0000u : 0u) | + (compare(a[1], b[1], binary_op) ? 0x0000FFFFu : 0u); +} + +/// Performs 2 element unordered comparison. Each element's result is 0 (false) +/// or 0xffff (true); the two results are composed into the returned unsigned +/// int. +/// \param [in] a The first value +/// \param [in] b The second value +/// \param [in] binary_op functor that implements the binary operation +/// \returns the composed mask: element 0 in bits 31-16, element 1 in bits 15-0 +template +inline unsigned unordered_compare_mask(const sycl::vec a, + const sycl::vec b, + const BinaryOperation binary_op) { + return (unordered_compare(a[0], b[0], binary_op) ? 0xFFFF0000u : 0u) | + (unordered_compare(a[1], b[1], binary_op) ? 0x0000FFFFu : 0u); +} + /// Compute vectorized max for two values, with each value treated as a vector /// type \p S /// \param [in] S The type of the vector diff --git a/sycl/test-e2e/syclcompat/common.hpp b/sycl/test-e2e/syclcompat/common.hpp index aace097740ce2..231ebd4655358 100644 --- a/sycl/test-e2e/syclcompat/common.hpp +++ b/sycl/test-e2e/syclcompat/common.hpp @@ -43,3 +43,5 @@ void instantiate_all_types(Functor &&f) { using value_type_list = std::tuple; + +using fp_type_list = std::tuple; diff --git a/sycl/test-e2e/syclcompat/math/math_compare.cpp b/sycl/test-e2e/syclcompat/math/math_compare.cpp new file mode 100644 index 0000000000000..34b6bcc4d75ff --- /dev/null +++ b/sycl/test-e2e/syclcompat/math/math_compare.cpp @@ -0,0 +1,371 @@ +/*************************************************************************** + * + * Copyright (C) Codeplay Software Ltd. 
+ * + * Part of the LLVM Project, under the Apache License v2.0 with LLVM + * Exceptions. See https://llvm.org/LICENSE.txt for license information. + * SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SYCLcompat API + * + * math_compare.cpp + * + * Description: + * math helpers tests + **************************************************************************/ + +// The original source was under the license below: +// ===------------------- math.cpp ---------- -*- C++ -* ------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +// +// ===---------------------------------------------------------------------===// + +// REQUIRES: aspect-fp16 + +// RUN: %clangxx -std=c++17 -fsycl -fsycl-targets=%{sycl_triple} %s -o %t.out +// RUN: %{run} %t.out + +#include +#include + +#include "../common.hpp" +#include "math_fixt.hpp" + +template +void compare_equal_kernel(ValueT *a, ValueT *b, bool *r) { + *r = syclcompat::compare(*a, *b, std::equal_to<>()); +} + +template +void compare_not_equal_kernel(ValueT *a, ValueT *b, bool *r) { + *r = syclcompat::compare(*a, *b, std::not_equal_to<>()); +} + +template void test_compare() { + std::cout << __PRETTY_FUNCTION__ << std::endl; + + constexpr syclcompat::dim3 grid{1}; + constexpr syclcompat::dim3 threads{1}; + constexpr ValueT op1 = static_cast(1.0); + ValueT op2 = sycl::nan(static_cast(0)); + + // 1.0 == 1.0 -> True + BinaryOpTestLauncher(grid, threads) + .template launch_test>(op1, op1, true); + // NaN == 1.0 -> False 
+ BinaryOpTestLauncher(grid, threads) + .template launch_test>(op2, op1, false); + // 1.0 == NaN -> False + BinaryOpTestLauncher(grid, threads) + .template launch_test>(op1, op2, false); + // NaN == NaN -> False + BinaryOpTestLauncher(grid, threads) + .template launch_test>(op2, op2, false); + + // 1.0 != 1.0 -> False + BinaryOpTestLauncher(grid, threads) + .template launch_test>(op1, op1, false); + // NaN != 1.0 -> False + BinaryOpTestLauncher(grid, threads) + .template launch_test>(op2, op1, false); + // 1.0 != NaN -> False + BinaryOpTestLauncher(grid, threads) + .template launch_test>(op1, op2, false); + // NaN != NaN -> False + BinaryOpTestLauncher(grid, threads) + .template launch_test>(op2, op2, false); +} + +template +void compare_equal_vec_kernel(Container *a, Container *b, Container *r) { + *r = syclcompat::compare(*a, *b, std::equal_to<>()); +} + +template +void compare_not_equal_vec_kernel(Container *a, Container *b, Container *r) { + *r = syclcompat::compare(*a, *b, std::not_equal_to<>()); +} + +template void test_compare_vec() { + std::cout << __PRETTY_FUNCTION__ << std::endl; + using Container = sycl::vec; + + constexpr syclcompat::dim3 grid{1}; + constexpr syclcompat::dim3 threads{1}; + constexpr Container op1 = {static_cast(1.0), + static_cast(2.0)}; + Container op2 = {static_cast(1.0), + sycl::nan(static_cast(0))}; + + // bool2 does not exist, 1.0 and 0.0 floats are used for true + // and false instead. 
+ // 1.0 == 1.0, 2.0 == NaN -> {true, false} + constexpr Container res1 = {1.0, 0.0}; + BinaryOpTestLauncher(grid, threads) + .template launch_test>(op1, op2, + res1); + // 1.0 != 1.0, 2.0 != NaN -> {false, false} + constexpr Container res2 = {0.0, 0.0}; + BinaryOpTestLauncher(grid, threads) + .template launch_test>(op1, op2, + res2); +} + +template +void unordered_compare_equal_kernel(ValueT *a, ValueT *b, bool *r) { + *r = syclcompat::unordered_compare(*a, *b, std::equal_to<>()); +} + +template +void unordered_compare_not_equal_kernel(ValueT *a, ValueT *b, bool *r) { + *r = syclcompat::unordered_compare(*a, *b, std::not_equal_to<>()); +} + +template +void test_unordered_compare() { + std::cout << __PRETTY_FUNCTION__ << std::endl; + + constexpr syclcompat::dim3 grid{1}; + constexpr syclcompat::dim3 threads{1}; + constexpr ValueT op1 = static_cast(1.0); + ValueT op2 = sycl::nan(static_cast(0)); + + // Unordered comparison checks if either operand is NaN, or the binaryop holds + // true + // 1.0 == 1.0 -> True + BinaryOpTestLauncher(grid, threads) + .template launch_test>(op1, op1, + true); + // NaN == 1.0 -> True + BinaryOpTestLauncher(grid, threads) + .template launch_test>(op2, op1, + true); + // 1.0 == NaN -> True + BinaryOpTestLauncher(grid, threads) + .template launch_test>(op1, op2, + true); + // NaN == NaN -> True + BinaryOpTestLauncher(grid, threads) + .template launch_test>(op2, op2, + true); + // 1.0 != 1.0 -> False + BinaryOpTestLauncher(grid, threads) + .template launch_test>( + op1, op1, false); + // No need to check again if either operand is NaN +} + +template +void unordered_compare_equal_vec_kernel(Container *a, Container *b, + Container *r) { + *r = syclcompat::unordered_compare(*a, *b, std::equal_to<>()); +} + +template +void unordered_compare_not_equal_vec_kernel(Container *a, Container *b, + Container *r) { + *r = syclcompat::unordered_compare(*a, *b, std::not_equal_to<>()); +} + +template void test_unordered_compare_vec() { + std::cout << 
__PRETTY_FUNCTION__ << std::endl; + using Container = sycl::vec; + + constexpr syclcompat::dim3 grid{1}; + constexpr syclcompat::dim3 threads{1}; + constexpr Container op1 = {static_cast(1.0), + static_cast(2.0)}; + Container op2 = {static_cast(1.0), + sycl::nan(static_cast(0))}; + + // bool2 does not exist, 1.0 and 0.0 floats are used for true + // and false instead. + // 1.0 == 1.0, 2.0 == NaN -> {true, true} + constexpr Container res1 = {1.0, 1.0}; + BinaryOpTestLauncher(grid, threads) + .template launch_test>( + op1, op2, res1); + // 1.0 != 1.0, 2.0 != NaN -> {false, true} + constexpr Container res2 = {0.0, 1.0}; + BinaryOpTestLauncher(grid, threads) + .template launch_test>( + op1, op2, res2); +} + +template +void compare_both_kernel(Container *a, Container *b, bool *r) { + *r = syclcompat::compare_both(*a, *b, std::equal_to<>()); +} + +template void test_compare_both() { + std::cout << __PRETTY_FUNCTION__ << std::endl; + using Container = sycl::vec; + + constexpr syclcompat::dim3 grid{1}; + constexpr syclcompat::dim3 threads{1}; + constexpr Container op1 = {static_cast(1.0), + static_cast(2.0)}; + Container op2 = {static_cast(1.0), + sycl::nan(static_cast(0))}; + + // 1.0 == 1.0, 2.0 == NaN -> {true, false} -> false + BinaryOpTestLauncher(grid, threads) + .template launch_test>(op1, op2, false); + + // 1.0 == 1.0, 2.0 == 2.0 -> {true, true} -> true + BinaryOpTestLauncher(grid, threads) + .template launch_test>(op1, op1, true); + + // 1.0 == 1.0, NaN == NaN -> {true, false} -> false + BinaryOpTestLauncher(grid, threads) + .template launch_test>(op2, op2, false); +} + +template +void unordered_compare_both_kernel(Container *a, Container *b, bool *r) { + *r = syclcompat::unordered_compare_both(*a, *b, std::equal_to<>()); +} + +template void test_unordered_compare_both() { + std::cout << __PRETTY_FUNCTION__ << std::endl; + using Container = sycl::vec; + + constexpr syclcompat::dim3 grid{1}; + constexpr syclcompat::dim3 threads{1}; + constexpr Container op1 = 
{static_cast(1.0), + static_cast(2.0)}; + Container op2 = {static_cast(1.0), + sycl::nan(static_cast(0))}; + + // 1.0 == 1.0, 2.0 == NaN -> {true, true} -> true + BinaryOpTestLauncher(grid, threads) + .template launch_test>(op1, op2, + true); + // 1.0 == 1.0, 2.0 == 2.0 -> {true, true} -> true + BinaryOpTestLauncher(grid, threads) + .template launch_test>(op1, op1, + true); + // 1.0 == 1.0, NaN == NaN -> {true, true} -> true + BinaryOpTestLauncher(grid, threads) + .template launch_test>(op2, op2, + true); +} + +template +void compare_mask_kernel(Container *a, Container *b, unsigned *r) { + *r = syclcompat::compare_mask(*a, *b, std::equal_to<>()); +} + +template void test_compare_mask() { + std::cout << __PRETTY_FUNCTION__ << std::endl; + using Container = sycl::vec; + + constexpr syclcompat::dim3 grid{1}; + constexpr syclcompat::dim3 threads{1}; + constexpr Container op1 = {static_cast(1.0), + static_cast(2.0)}; + constexpr Container op2 = {static_cast(2.0), + static_cast(1.0)}; + constexpr Container op3 = {static_cast(1.0), + static_cast(3.0)}; + constexpr Container op4 = {static_cast(3.0), + static_cast(2.0)}; + Container op5 = {sycl::nan(static_cast(0)), + sycl::nan(static_cast(0))}; + + // 1.0 == 1.0, 2.0 == 2.0 -> 0xffffffff + BinaryOpTestLauncher(grid, threads) + .template launch_test>(op1, op1, + 0xffffffff); + + // 1.0 == 2.0, 2.0 == 1.0 -> 0x00000000 + BinaryOpTestLauncher(grid, threads) + .template launch_test>(op1, op2, + 0x00000000); + + // 1.0 == 1.0, 2.0 == 3.0 -> 0xffff0000 + BinaryOpTestLauncher(grid, threads) + .template launch_test>(op1, op3, + 0xffff0000); + + // 1.0 == 3.0, 2.0 == 2.0 -> 0x0000ffff + BinaryOpTestLauncher(grid, threads) + .template launch_test>(op1, op4, + 0x0000ffff); + + // 1.0 == NaN, 2.0 == NaN -> 0x00000000 + BinaryOpTestLauncher(grid, threads) + .template launch_test>(op1, op5, + 0x00000000); +} + +template +void unordered_compare_mask_kernel(Container *a, Container *b, unsigned *r) { + *r = 
syclcompat::unordered_compare_mask(*a, *b, std::equal_to<>()); +} + +template void test_unordered_compare_mask() { + std::cout << __PRETTY_FUNCTION__ << std::endl; + using Container = sycl::vec; + + constexpr syclcompat::dim3 grid{1}; + constexpr syclcompat::dim3 threads{1}; + constexpr Container op1 = {static_cast(1.0), + static_cast(2.0)}; + constexpr Container op2 = {static_cast(2.0), + static_cast(1.0)}; + constexpr Container op3 = {static_cast(1.0), + static_cast(3.0)}; + constexpr Container op4 = {static_cast(3.0), + static_cast(2.0)}; + Container op5 = {sycl::nan(static_cast(0)), + sycl::nan(static_cast(0))}; + + // 1.0 == 1.0, 2.0 == 2.0 -> 0xffffffff + BinaryOpTestLauncher(grid, threads) + .template launch_test>( + op1, op1, 0xffffffff); + + // 1.0 == 2.0, 2.0 == 1.0 -> 0x00000000 + BinaryOpTestLauncher(grid, threads) + .template launch_test>( + op1, op2, 0x00000000); + + // 1.0 == 1.0, 2.0 == 3.0 -> 0xffff0000 + BinaryOpTestLauncher(grid, threads) + .template launch_test>( + op1, op3, 0xffff0000); + + // 1.0 == 3.0, 2.0 == 2.0 -> 0x0000ffff + BinaryOpTestLauncher(grid, threads) + .template launch_test>( + op1, op4, 0x0000ffff); + + // 1.0 == NaN, 2.0 == NaN -> 0xffffffff + BinaryOpTestLauncher(grid, threads) + .template launch_test>( + op1, op5, 0xffffffff); +} + +int main() { + INSTANTIATE_ALL_TYPES(fp_type_list, test_compare); + INSTANTIATE_ALL_TYPES(fp_type_list, test_unordered_compare); + INSTANTIATE_ALL_TYPES(fp_type_list, test_compare_vec); + INSTANTIATE_ALL_TYPES(fp_type_list, test_unordered_compare_vec); + INSTANTIATE_ALL_TYPES(fp_type_list, test_compare_both); + INSTANTIATE_ALL_TYPES(fp_type_list, test_unordered_compare_both); + INSTANTIATE_ALL_TYPES(fp_type_list, test_compare_mask); + INSTANTIATE_ALL_TYPES(fp_type_list, test_unordered_compare_mask); + + return 0; +}