Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SYCL] Refactor sycl::vec's operators implementation #16529

Merged
merged 3 commits into from
Jan 8, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion sycl/include/sycl/detail/type_traits/vec_marray_traits.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,14 +31,24 @@ template <typename T> constexpr bool is_vec_v = is_vec<T>::value;

// Type trait reporting whether T is a Clang ext_vector_type vector.
// This primary template is the fallback and yields false; a partial
// specialization yielding true is provided only when the compiler reports
// the attribute_ext_vector_type extension.
template <typename T, typename = void>
struct is_ext_vector : std::false_type {};
// Type trait reporting whether T can be used as the element type of a Clang
// ext_vector_type vector. This primary template is the fallback and yields
// false; the second (defaulted) template parameter exists so that a
// std::void_t-based partial specialization can be selected instead.
template <typename T, typename = void>
struct is_valid_type_for_ext_vector : std::false_type {};
#if defined(__has_extension)
#if __has_extension(attribute_ext_vector_type)
template <typename T, int N>
struct is_ext_vector<T __attribute__((ext_vector_type(N)))> : std::true_type {};
using ext_vector = T __attribute__((ext_vector_type(N)));
// Matches any instantiation of the ext_vector alias (element type T, N
// elements) and reports true for it.
template <typename T, int N>
struct is_ext_vector<ext_vector<T, N>> : std::true_type {};
// Selected (and yields true) only when ext_vector<T, 2> is a well-formed
// type; otherwise substitution into std::void_t fails and the primary
// template (false) is used. The size 2 serves only as a probe for the
// element type T.
template <typename T>
struct is_valid_type_for_ext_vector<T, std::void_t<ext_vector<T, 2>>>
: std::true_type {};
#endif
#endif
// Variable-template shorthand for is_ext_vector<T>::value.
template <typename T>
inline constexpr bool is_ext_vector_v = is_ext_vector<T>::value;
// Variable-template shorthand for is_valid_type_for_ext_vector<T>::value.
template <typename T>
inline constexpr bool is_valid_type_for_ext_vector_v =
is_valid_type_for_ext_vector<T>::value;

// Primary template of the is_swizzle trait: yields false for any type that
// is not matched by a specialization.
template <typename> struct is_swizzle : std::false_type {};
template <typename VecT, typename OperationLeftT, typename OperationRightT,
Expand Down
135 changes: 67 additions & 68 deletions sycl/include/sycl/detail/vector_arith.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,11 @@

#pragma once

#include <sycl/aliases.hpp> // for half, cl_char, cl_int
#include <sycl/detail/generic_type_traits.hpp> // for is_sigeninteger, is_s...
#include <sycl/detail/type_traits.hpp> // for is_floating_point

#include <sycl/ext/oneapi/bfloat16.hpp> // bfloat16

#include <cstddef>
#include <type_traits> // for enable_if_t, is_same
#include <sycl/aliases.hpp>
#include <sycl/detail/generic_type_traits.hpp>
#include <sycl/detail/type_traits.hpp>
#include <sycl/detail/type_traits/vec_marray_traits.hpp>
#include <sycl/ext/oneapi/bfloat16.hpp>

namespace sycl {
inline namespace _V1 {
Expand Down Expand Up @@ -50,13 +47,7 @@ struct UnaryPlus {
};

struct VecOperators {
#ifdef __SYCL_DEVICE_ONLY__
static constexpr bool is_host = false;
#else
static constexpr bool is_host = true;
#endif

template <typename BinOp, typename... ArgTys>
template <typename OpTy, typename... ArgTys>
static constexpr auto apply(const ArgTys &...Args) {
using Self = nth_type_t<0, ArgTys...>;
static_assert(is_vec_v<Self>);
Expand All @@ -65,88 +56,96 @@ struct VecOperators {
using element_type = typename Self::element_type;
constexpr int N = Self::size();
constexpr bool is_logical = check_type_in_v<
BinOp, std::equal_to<void>, std::not_equal_to<void>, std::less<void>,
OpTy, std::equal_to<void>, std::not_equal_to<void>, std::less<void>,
std::greater<void>, std::less_equal<void>, std::greater_equal<void>,
std::logical_and<void>, std::logical_or<void>, std::logical_not<void>>;

using result_t = std::conditional_t<
is_logical, vec<fixed_width_signed<sizeof(element_type)>, N>, Self>;

BinOp Op{};
if constexpr (is_host || N == 1 ||
std::is_same_v<element_type, ext::oneapi::bfloat16>) {
result_t res{};
for (size_t i = 0; i < N; ++i)
if constexpr (is_logical)
res[i] = Op(Args[i]...) ? -1 : 0;
else
res[i] = Op(Args[i]...);
return res;
} else {
using vector_t = typename Self::vector_t;

auto res = [&](auto... xs) {
OpTy Op{};
#ifdef __has_extension
#if __has_extension(attribute_ext_vector_type)
// ext_vector_type's bool vectors are mapped onto <N x i1> and have a
// different memory layout than sycl::vec<bool, N> (which has 1 byte per
// element). As such, we perform the operation on int8_t and then need to
// create a bit pattern that can be bit-casted back to the original
// sycl::vec<bool, N>. This is actually a hack, but we've been doing
// that for a long time using the sycl::vec::vector_t type.
using vec_elem_ty =
typename detail::map_type<element_type, //
bool, /*->*/ std::int8_t,
#if (!defined(_HAS_STD_BYTE) || _HAS_STD_BYTE != 0)
std::byte, /*->*/ std::uint8_t,
#endif
#ifdef __SYCL_DEVICE_ONLY__
half, /*->*/ _Float16,
#endif
element_type, /*->*/ element_type>::type;
if constexpr (N != 1 &&
detail::is_valid_type_for_ext_vector_v<vec_elem_ty>) {
using vec_t = ext_vector<vec_elem_ty, N>;
auto tmp = [&](auto... xs) {
// Workaround for https://github.com/llvm/llvm-project/issues/119617.
if constexpr (sizeof...(Args) == 2) {
aelovikov-intel marked this conversation as resolved.
Show resolved Hide resolved
return [&](auto x, auto y) {
steffenlarsen marked this conversation as resolved.
Show resolved Hide resolved
if constexpr (std::is_same_v<BinOp, std::equal_to<void>>)
if constexpr (std::is_same_v<OpTy, std::equal_to<void>>)
return x == y;
else if constexpr (std::is_same_v<BinOp, std::not_equal_to<void>>)
else if constexpr (std::is_same_v<OpTy, std::not_equal_to<void>>)
return x != y;
else if constexpr (std::is_same_v<BinOp, std::less<void>>)
else if constexpr (std::is_same_v<OpTy, std::less<void>>)
return x < y;
else if constexpr (std::is_same_v<BinOp, std::less_equal<void>>)
else if constexpr (std::is_same_v<OpTy, std::less_equal<void>>)
return x <= y;
else if constexpr (std::is_same_v<BinOp, std::greater<void>>)
else if constexpr (std::is_same_v<OpTy, std::greater<void>>)
return x > y;
else if constexpr (std::is_same_v<BinOp, std::greater_equal<void>>)
else if constexpr (std::is_same_v<OpTy, std::greater_equal<void>>)
return x >= y;
else
return Op(x, y);
}(xs...);
} else {
return Op(xs...);
}
}(bit_cast<vector_t>(Args)...);

}(bit_cast<vec_t>(Args)...);
if constexpr (std::is_same_v<element_type, bool>) {
// The vec(vector_t) ctor does a simple bit_cast, and the way "bool" is
// stored is that only one bit matters. vector_t, however, is a char
// type and it can have a non-zero value with the lowest bit unset. E.g.,
// consider this:
//
// auto x = true + true; // int x = 2
// bool y = true + true; // bool y = true
//
// and vec<bool, N> has to behave in a similar way. As such, the current
// implementation needs to do some extra processing for operators that
// can result in this scenario.
//
// Some operations are known to produce the required bit patterns and
// the following post-processing isn't necessary for them:
if constexpr (!is_logical &&
!check_type_in_v<BinOp, std::multiplies<void>,
!check_type_in_v<OpTy, std::multiplies<void>,
std::divides<void>, std::bit_or<void>,
std::bit_and<void>, std::bit_xor<void>,
ShiftRight, UnaryPlus>) {
// TODO: Not sure why the following doesn't work
// (test-e2e/Basic/vector/bool.cpp fails).
//
// res = (decltype(res))(res != 0);
for (size_t i = 0; i < N; ++i)
res[i] = bit_cast<int8_t>(res[i]) != 0;
// Extra cast is needed because:
static_assert(std::is_same_v<int8_t, signed char>);
static_assert(!std::is_same_v<
decltype(std::declval<ext_vector<int8_t, 2>>() != 0),
ext_vector<int8_t, 2>>);
static_assert(std::is_same_v<
decltype(std::declval<ext_vector<int8_t, 2>>() != 0),
ext_vector<char, 2>>);

// `... * -1` is needed because ext_vector_type's comparison follows the
// OpenCL binary representation for "true" (-1).
// `std::array<bool, N>` is different: LLVM annotates its
// elements with [0, 2) range metadata when loaded, so we need to
// ensure we generate 0/1 only (and not 2/-1/etc.).
static_assert((ext_vector<int8_t, 2>{1, 0} == 0)[1] == -1);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This does not seem to be a constant expression on Windows and macOS.


tmp = reinterpret_cast<decltype(tmp)>((tmp != 0) * -1);
}
}
// The following is true:
//
// using char2 = char __attribute__((ext_vector_type(2)));
// using uchar2 = unsigned char __attribute__((ext_vector_type(2)));
// static_assert(std::is_same_v<decltype(std::declval<uchar2>() ==
// std::declval<uchar2>()),
// char2>);
//
// so we need some extra casts. Also, static_cast<uchar2>(char2{})
// isn't allowed either.
return result_t{(typename result_t::vector_t)res};
return bit_cast<result_t>(tmp);
}
#endif
#endif
result_t res{};
for (size_t i = 0; i < N; ++i)
if constexpr (is_logical)
res[i] = Op(Args[i]...) ? -1 : 0;
else
res[i] = Op(Args[i]...);
return res;
}
};

Expand Down
Loading
Loading