Skip to content

Commit

Permalink
[SYCL] Refactor sycl::vec's operators implementation
Browse files Browse the repository at this point in the history
* Don't use `sycl::vec::vector_t`, as it is planned to be removed from
  the SYCL 2020 specification (KhronosGroup/SYCL-Docs#676).
  Note that this implementation is NOT required to use it, so this PR
  can be merged before the specification change.
* Use the `ext_vector_type`-based optimized implementation whenever it's
  available, not only on device.
  • Loading branch information
aelovikov-intel committed Jan 6, 2025
1 parent 62ce674 commit 0b224e9
Show file tree
Hide file tree
Showing 2 changed files with 174 additions and 183 deletions.
81 changes: 36 additions & 45 deletions sycl/include/sycl/detail/vector_arith.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -73,19 +73,27 @@ struct VecOperators {
is_logical, vec<fixed_width_signed<sizeof(element_type)>, N>, Self>;

BinOp Op{};
if constexpr (is_host || N == 1 ||
std::is_same_v<element_type, ext::oneapi::bfloat16>) {
result_t res{};
for (size_t i = 0; i < N; ++i)
if constexpr (is_logical)
res[i] = Op(Args[i]...) ? -1 : 0;
else
res[i] = Op(Args[i]...);
return res;
} else {
using vector_t = typename Self::vector_t;

auto res = [&](auto... xs) {
#ifdef __has_extension
#if __has_extension(attribute_ext_vector_type)
// ext_vector_type's bool vectors are mapped onto <N x i1> and have
// different memory layout than sycl::vec<bool ,N> (which has 1 byte per
// element). As such we perform operation on int8_t and then need to
// create bit pattern that can be bit-casted back to the original
// sycl::vec<bool, N>. This is a hack actually, but we've been doing
// that for a long time using sycl::vec::vector_t type.
using vec_elem_ty =
typename detail::map_type<element_type, bool, /*->*/ std::uint8_t,
#if (!defined(_HAS_STD_BYTE) || _HAS_STD_BYTE != 0)
std::byte, /*->*/ std::int8_t,
#endif
#ifdef __SYCL_DEVICE_ONLY__
half, /*->*/ _Float16,
#endif
element_type, /*->*/ element_type>::type;
if constexpr (N != 1 && !check_type_in_v<vec_elem_ty, bool, half,
ext::oneapi::bfloat16>) {
using vec_t = vec_elem_ty __attribute__((ext_vector_type(N)));
auto tmp = [&](auto... xs) {
// Workaround for https://github.com/llvm/llvm-project/issues/119617.
if constexpr (sizeof...(Args) == 2) {
return [&](auto x, auto y) {
Expand All @@ -107,46 +115,29 @@ struct VecOperators {
} else {
return Op(xs...);
}
}(bit_cast<vector_t>(Args)...);

}(bit_cast<vec_t>(Args)...);
if constexpr (std::is_same_v<element_type, bool>) {
// vec(vector_t) ctor does a simple bit_cast and the way "bool" is
// stored is that only one bit matters. vector_t, however, is a char
// type and it can have non-zero value with lowest bit unset. E.g.,
// consider this:
//
// auto x = true + true; // int x = 2
// bool y = true + true; // bool y = true
//
// and the vec<bool, N> has to behave in a similar way. As such, current
// implementation needs to do some extra processing for operators that
// can result in this scenario.
//
// Some operations are known to produce the required bit patterns and
// the following post-processing isn't necessary for them:
if constexpr (!is_logical &&
!check_type_in_v<BinOp, std::multiplies<void>,
std::divides<void>, std::bit_or<void>,
std::bit_and<void>, std::bit_xor<void>,
ShiftRight, UnaryPlus>) {
// TODO: Not sure why the following doesn't work
// (test-e2e/Basic/vector/bool.cpp fails).
//
// res = (decltype(res))(res != 0);
ShiftRight, UnaryPlus>)
for (size_t i = 0; i < N; ++i)
res[i] = bit_cast<int8_t>(res[i]) != 0;
}
tmp[i] = (tmp[i] != 0);
}
// The following is true:
//
// using char2 = char __attribute__((ext_vector_type(2)));
// using uchar2 = unsigned char __attribute__((ext_vector_type(2)));
// static_assert(std::is_same_v<decltype(std::declval<uchar2>() ==
// std::declval<uchar2>()),
// char2>);
//
// so we need some extra casts. Also, static_cast<uchar2>(char2{})
// isn't allowed either.
return result_t{(typename result_t::vector_t)res};
return bit_cast<result_t>(tmp);
}
#endif
#endif
result_t res{};
for (size_t i = 0; i < N; ++i)
if constexpr (is_logical)
res[i] = Op(Args[i]...) ? -1 : 0;
else
res[i] = Op(Args[i]...);
return res;
}
};

Expand Down
Loading

0 comments on commit 0b224e9

Please sign in to comment.