Skip to content

Commit

Permalink
Revert "[SYCL] Refactor sycl::vec's operators implementation (#16529)"
Browse files Browse the repository at this point in the history
This reverts commit 16c2c21.
  • Loading branch information
steffenlarsen authored Jan 8, 2025
1 parent 16c2c21 commit 063e01a
Show file tree
Hide file tree
Showing 3 changed files with 217 additions and 214 deletions.
12 changes: 1 addition & 11 deletions sycl/include/sycl/detail/type_traits/vec_marray_traits.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,24 +31,14 @@ template <typename T> constexpr bool is_vec_v = is_vec<T>::value;

template <typename T, typename = void>
struct is_ext_vector : std::false_type {};
template <typename T, typename = void>
struct is_valid_type_for_ext_vector : std::false_type {};
#if defined(__has_extension)
#if __has_extension(attribute_ext_vector_type)
template <typename T, int N>
using ext_vector = T __attribute__((ext_vector_type(N)));
template <typename T, int N>
struct is_ext_vector<ext_vector<T, N>> : std::true_type {};
template <typename T>
struct is_valid_type_for_ext_vector<T, std::void_t<ext_vector<T, 2>>>
: std::true_type {};
struct is_ext_vector<T __attribute__((ext_vector_type(N)))> : std::true_type {};
#endif
#endif
template <typename T>
inline constexpr bool is_ext_vector_v = is_ext_vector<T>::value;
template <typename T>
inline constexpr bool is_valid_type_for_ext_vector_v =
is_valid_type_for_ext_vector<T>::value;

template <typename> struct is_swizzle : std::false_type {};
template <typename VecT, typename OperationLeftT, typename OperationRightT,
Expand Down
135 changes: 68 additions & 67 deletions sycl/include/sycl/detail/vector_arith.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,14 @@

#pragma once

#include <sycl/aliases.hpp>
#include <sycl/detail/generic_type_traits.hpp>
#include <sycl/detail/type_traits.hpp>
#include <sycl/detail/type_traits/vec_marray_traits.hpp>
#include <sycl/ext/oneapi/bfloat16.hpp>
#include <sycl/aliases.hpp> // for half, cl_char, cl_int
#include <sycl/detail/generic_type_traits.hpp> // for is_sigeninteger, is_s...
#include <sycl/detail/type_traits.hpp> // for is_floating_point

#include <sycl/ext/oneapi/bfloat16.hpp> // bfloat16

#include <cstddef>
#include <type_traits> // for enable_if_t, is_same

namespace sycl {
inline namespace _V1 {
Expand Down Expand Up @@ -47,7 +50,13 @@ struct UnaryPlus {
};

struct VecOperators {
template <typename OpTy, typename... ArgTys>
#ifdef __SYCL_DEVICE_ONLY__
static constexpr bool is_host = false;
#else
static constexpr bool is_host = true;
#endif

template <typename BinOp, typename... ArgTys>
static constexpr auto apply(const ArgTys &...Args) {
using Self = nth_type_t<0, ArgTys...>;
static_assert(is_vec_v<Self>);
Expand All @@ -56,96 +65,88 @@ struct VecOperators {
using element_type = typename Self::element_type;
constexpr int N = Self::size();
constexpr bool is_logical = check_type_in_v<
OpTy, std::equal_to<void>, std::not_equal_to<void>, std::less<void>,
BinOp, std::equal_to<void>, std::not_equal_to<void>, std::less<void>,
std::greater<void>, std::less_equal<void>, std::greater_equal<void>,
std::logical_and<void>, std::logical_or<void>, std::logical_not<void>>;

using result_t = std::conditional_t<
is_logical, vec<fixed_width_signed<sizeof(element_type)>, N>, Self>;

OpTy Op{};
#ifdef __has_extension
#if __has_extension(attribute_ext_vector_type)
// ext_vector_type's bool vectors are mapped onto <N x i1> and have
// different memory layout than sycl::vec<bool ,N> (which has 1 byte per
// element). As such we perform operation on int8_t and then need to
// create bit pattern that can be bit-casted back to the original
// sycl::vec<bool, N>. This is a hack actually, but we've been doing
// that for a long time using sycl::vec::vector_t type.
using vec_elem_ty =
typename detail::map_type<element_type, //
bool, /*->*/ std::int8_t,
#if (!defined(_HAS_STD_BYTE) || _HAS_STD_BYTE != 0)
std::byte, /*->*/ std::uint8_t,
#endif
#ifdef __SYCL_DEVICE_ONLY__
half, /*->*/ _Float16,
#endif
element_type, /*->*/ element_type>::type;
if constexpr (N != 1 &&
detail::is_valid_type_for_ext_vector_v<vec_elem_ty>) {
using vec_t = ext_vector<vec_elem_ty, N>;
auto tmp = [&](auto... xs) {
BinOp Op{};
if constexpr (is_host || N == 1 ||
std::is_same_v<element_type, ext::oneapi::bfloat16>) {
result_t res{};
for (size_t i = 0; i < N; ++i)
if constexpr (is_logical)
res[i] = Op(Args[i]...) ? -1 : 0;
else
res[i] = Op(Args[i]...);
return res;
} else {
using vector_t = typename Self::vector_t;

auto res = [&](auto... xs) {
// Workaround for https://github.com/llvm/llvm-project/issues/119617.
if constexpr (sizeof...(Args) == 2) {
return [&](auto x, auto y) {
if constexpr (std::is_same_v<OpTy, std::equal_to<void>>)
if constexpr (std::is_same_v<BinOp, std::equal_to<void>>)
return x == y;
else if constexpr (std::is_same_v<OpTy, std::not_equal_to<void>>)
else if constexpr (std::is_same_v<BinOp, std::not_equal_to<void>>)
return x != y;
else if constexpr (std::is_same_v<OpTy, std::less<void>>)
else if constexpr (std::is_same_v<BinOp, std::less<void>>)
return x < y;
else if constexpr (std::is_same_v<OpTy, std::less_equal<void>>)
else if constexpr (std::is_same_v<BinOp, std::less_equal<void>>)
return x <= y;
else if constexpr (std::is_same_v<OpTy, std::greater<void>>)
else if constexpr (std::is_same_v<BinOp, std::greater<void>>)
return x > y;
else if constexpr (std::is_same_v<OpTy, std::greater_equal<void>>)
else if constexpr (std::is_same_v<BinOp, std::greater_equal<void>>)
return x >= y;
else
return Op(x, y);
}(xs...);
} else {
return Op(xs...);
}
}(bit_cast<vec_t>(Args)...);
}(bit_cast<vector_t>(Args)...);

if constexpr (std::is_same_v<element_type, bool>) {
// Some operations are known to produce the required bit patterns and
// the following post-processing isn't necessary for them:
// vec(vector_t) ctor does a simple bit_cast and the way "bool" is
// stored is that only one bit matters. vector_t, however, is a char
// type and it can have non-zero value with lowest bit unset. E.g.,
// consider this:
//
// auto x = true + true; // int x = 2
// bool y = true + true; // bool y = true
//
// and the vec<bool, N> has to behave in a similar way. As such, current
// implementation needs to do some extra processing for operators that
// can result in this scenario.
//
if constexpr (!is_logical &&
!check_type_in_v<OpTy, std::multiplies<void>,
!check_type_in_v<BinOp, std::multiplies<void>,
std::divides<void>, std::bit_or<void>,
std::bit_and<void>, std::bit_xor<void>,
ShiftRight, UnaryPlus>) {
// Extra cast is needed because:
static_assert(std::is_same_v<int8_t, signed char>);
static_assert(!std::is_same_v<
decltype(std::declval<ext_vector<int8_t, 2>>() != 0),
ext_vector<int8_t, 2>>);
static_assert(std::is_same_v<
decltype(std::declval<ext_vector<int8_t, 2>>() != 0),
ext_vector<char, 2>>);

// `... * -1` is needed because ext_vector_type's comparison follows
// OpenCL binary representation for "true" (-1).
// `std::array<bool, N>` is different and LLVM annotates its
// elements with [0, 2) range metadata when loaded, so we need to
// ensure we generate 0/1 only (and not 2/-1/etc.).
static_assert((ext_vector<int8_t, 2>{1, 0} == 0)[1] == -1);

tmp = reinterpret_cast<decltype(tmp)>((tmp != 0) * -1);
// TODO: Not sure why the following doesn't work
// (test-e2e/Basic/vector/bool.cpp fails).
//
// res = (decltype(res))(res != 0);
for (size_t i = 0; i < N; ++i)
res[i] = bit_cast<int8_t>(res[i]) != 0;
}
}
return bit_cast<result_t>(tmp);
// The following is true:
//
// using char2 = char __attribute__((ext_vector_type(2)));
// using uchar2 = unsigned char __attribute__((ext_vector_type(2)));
// static_assert(std::is_same_v<decltype(std::declval<uchar2>() ==
// std::declval<uchar2>()),
// char2>);
//
// so we need some extra casts. Also, static_cast<uchar2>(char2{})
// isn't allowed either.
return result_t{(typename result_t::vector_t)res};
}
#endif
#endif
result_t res{};
for (size_t i = 0; i < N; ++i)
if constexpr (is_logical)
res[i] = Op(Args[i]...) ? -1 : 0;
else
res[i] = Op(Args[i]...);
return res;
}
};

Expand Down
Loading

0 comments on commit 063e01a

Please sign in to comment.