From 247702e706f6019b204f2ea0b50c8843a7dc4cdd Mon Sep 17 00:00:00 2001 From: Gabriel Mitterrutzner Date: Mon, 27 Jan 2025 18:49:11 +0100 Subject: [PATCH] Change return type of 1D-element named swizzles --- include/simsycl/sycl/vec.hh | 12 ++++++------ test/vec_tests.cc | 15 +++++++++++++-- 2 files changed, 19 insertions(+), 8 deletions(-) diff --git a/include/simsycl/sycl/vec.hh b/include/simsycl/sycl/vec.hh index fcd2b6c..6ee8853 100644 --- a/include/simsycl/sycl/vec.hh +++ b/include/simsycl/sycl/vec.hh @@ -302,10 +302,10 @@ class swizzled_vec { } #define SIMSYCL_DETAIL_DEFINE_1D_SWIZZLE(req, comp) \ - auto comp() const \ + ReferenceDataT &comp() const \ requires(req && num_elements > sycl::elem::comp) \ { \ - return detail::swizzled_vec(m_elems); \ + return m_elems[indices[sycl::elem::comp]]; \ } #define SIMSYCL_DETAIL_DEFINE_2D_SWIZZLE(req, comp1, comp2) \ @@ -612,15 +612,15 @@ class alignas(detail::vec_alignment_v) vec { } #define SIMSYCL_DETAIL_DEFINE_1D_SWIZZLE(req, comp) \ - auto comp() \ + DataT &comp() \ requires(req && num_elements > elem::comp) \ { \ - return detail::swizzled_vec(m_elems); \ + return m_elems[elem::comp]; \ } \ - auto comp() const \ + const DataT &comp() const \ requires(req && num_elements > elem::comp) \ { \ - return detail::swizzled_vec(m_elems); \ + return m_elems[elem::comp]; \ } #define SIMSYCL_DETAIL_DEFINE_2D_SWIZZLE(req, comp1, comp2) \ diff --git a/test/vec_tests.cc b/test/vec_tests.cc index 1a2f281..a2b4bfb 100644 --- a/test/vec_tests.cc +++ b/test/vec_tests.cc @@ -158,6 +158,17 @@ TEST_CASE("Vector swizzled access is available", "[vec][swizzle]") { vi.even() = {11, 12}; CHECK(check_bool_vec(vi == sycl::vec{11, 9, 12, 10})); } + + SECTION("1D swizzled vec return type") { + sycl::vec v4{255}; + CHECK(std::is_same_v); + CHECK(std::is_same_v); + CHECK(std::is_same_v); + CHECK(std::is_same_v); + + int i = (v4.x() + 1); + CHECK(i == 256); + } } TEST_CASE("Operations on swizzled vectors work as expected", "[vec][swizzle]") { @@ -208,7 +219,7 @@ TEST_CASE("Operations on swizzled vectors work as expected", "[vec][swizzle]") { CHECK(check_bool_vec(sycl::vec{0, 0} < vi.xy())); CHECK(check_bool_vec(vi.zw() < 10)); - CHECK(check_bool_vec(vi.z() == 3)); - CHECK(check_bool_vec(4 == vi.a())); + CHECK((vi.z() == 3)); + CHECK((4 == vi.a())); } }