From 05d29f38d9b35300a29eac066c61b2a8e698ebb8 Mon Sep 17 00:00:00 2001 From: fineg74 <61437305+fineg74@users.noreply.github.com> Date: Wed, 29 May 2024 08:17:54 -0700 Subject: [PATCH] [ESIMD] Allow full autodeduction for scatter USM APIs accepting simd_view (#13941) --- sycl/include/sycl/ext/intel/esimd/memory.hpp | 229 +++++++++++++++++- sycl/test/esimd/memory_properties_scatter.cpp | 37 ++- 2 files changed, 256 insertions(+), 10 deletions(-) diff --git a/sycl/include/sycl/ext/intel/esimd/memory.hpp b/sycl/include/sycl/ext/intel/esimd/memory.hpp index ab1749e0eec4c..0cd77ee545135 100644 --- a/sycl/include/sycl/ext/intel/esimd/memory.hpp +++ b/sycl/include/sycl/ext/intel/esimd/memory.hpp @@ -911,10 +911,46 @@ scatter(T *p, simd byte_offsets, simd vals, } } -// template -// void scatter(T *p, simd byte_offsets, simd vals, -// PropertyListT props = {}); // (usm-sc-2) +/// template +/// void scatter(T *p, simd byte_offsets, ValuesSimdViewT vals, +/// simd_mask mask, PropertyListT props = {}); +/// +/// Variation of the API that allows to use \c simd_view without specifying \c T +/// and \c N template parameters. +/// Writes ("scatters") elements of the input vector to different memory +/// locations. Each memory location is base address plus an offset - a +/// value of the corresponding element in the input offset vector. Access to +/// any element's memory location can be disabled via the input mask. +/// @tparam VS Vector size. It can also be read as the number of writes per each +/// address. The parameter 'N' must be divisible by 'VS'. (VS > 1) is supported +/// only on DG2 and PVC and only for 4- and 8-byte element vectors. +/// @param p The base address. +/// @param byte_offsets the vector of 32-bit or 64-bit offsets in bytes. +/// For each i, ((byte*)p + byte_offsets[i]) must be element size aligned. +/// If the alignment property is not passed, then it is assumed that each +/// accessed address is aligned by element-size. 
+/// @param vals The vector to scatter. +/// @param mask The access mask. +/// @param props The optional compile-time properties. Only 'alignment' +/// and cache hint properties are used. +template < + int VS = 1, typename OffsetT, typename ValuesSimdViewT, typename T, + int N = ValuesSimdViewT::getSizeX() * ValuesSimdViewT::getSizeY(), + typename PropertyListT = ext::oneapi::experimental::empty_properties_t> +__ESIMD_API std::enable_if_t< + detail::is_simd_view_type_v && + ext::oneapi::experimental::is_property_list_v> +scatter(T *p, simd byte_offsets, ValuesSimdViewT vals, + simd_mask mask, PropertyListT props = {}) { + scatter(p, byte_offsets, vals.read(), mask, props); +} + +/// template +/// void scatter(T *p, simd byte_offsets, simd vals, +/// PropertyListT props = {}); // (usm-sc-2) /// /// Writes ("scatters") elements of the input vector to different memory /// locations. Each memory location is base address plus an offset - a @@ -943,10 +979,80 @@ scatter(T *p, simd byte_offsets, simd vals, scatter(p, byte_offsets, vals, Mask, props); } -// template -// void scatter(T *p, OffsetSimdViewT byte_offsets, simd vals, -// simd_mask mask, PropertyListT props = {}); // (usm-sc-3) +/// template +/// void scatter(T *p, OffsetSimdViewT byte_offsets, ValuesSimdViewT vals, +/// simd_mask mask, PropertyListT props = {}); +/// +/// Variation of the API that allows to use \c simd_view without specifying \c T +/// and \c N template parameters. +/// Writes ("scatters") elements of the input vector to different memory +/// locations. Each memory location is base address plus an offset - a +/// value of the corresponding element in the input offset vector. +/// @tparam VS Vector size. It can also be read as the number of writes per each +/// address. The parameter 'N' must be divisible by 'VS'. (VS > 1) is supported +/// only on DG2 and PVC and only for 4- and 8-byte element vectors. +/// @param p The base address. 
+/// @param byte_offsets the vector of 32-bit or 64-bit offsets in bytes. +/// For each i, ((byte*)p + byte_offsets[i]) must be element size aligned. +/// If the alignment property is not passed, then it is assumed that each +/// accessed address is aligned by element-size. +/// @param vals The vector to scatter. +/// @param mask The access mask. +/// @param props The optional compile-time properties. Only 'alignment' +/// and cache hint properties are used. +template < + int VS = 1, typename OffsetSimdViewT, typename ValuesSimdViewT, typename T, + int N = ValuesSimdViewT::getSizeX() * ValuesSimdViewT::getSizeY(), + typename PropertyListT = ext::oneapi::experimental::empty_properties_t> +__ESIMD_API std::enable_if_t< + detail::is_simd_view_type_v && + detail::is_simd_view_type_v && + ext::oneapi::experimental::is_property_list_v> +scatter(T *p, OffsetSimdViewT byte_offsets, ValuesSimdViewT vals, + simd_mask mask, PropertyListT props = {}) { + scatter(p, byte_offsets.read(), vals.read(), mask, props); +} + +/// template +/// void scatter(T *p, simd byte_offsets, ValuesSimdViewT vals, +/// PropertyListT props = {}); +/// +/// Variation of the API that allows to use \c simd_view without specifying \c T +/// and \c N template parameters. +/// Writes ("scatters") elements of the input vector to different memory +/// locations. Each memory location is base address plus an offset - a +/// value of the corresponding element in the input offset vector. +/// @tparam VS Vector size. It can also be read as the number of writes per each +/// address. The parameter 'N' must be divisible by 'VS'. (VS > 1) is supported +/// only on DG2 and PVC and only for 4- and 8-byte element vectors. +/// @param p The base address. +/// @param byte_offsets the vector of 32-bit or 64-bit offsets in bytes. +/// For each i, ((byte*)p + byte_offsets[i]) must be element size aligned. 
+/// If the alignment property is not passed, then it is assumed that each +/// accessed address is aligned by element-size. +/// @param vals The vector to scatter. +/// @param props The optional compile-time properties. Only 'alignment' +/// and cache hint properties are used. +template < + int VS = 1, typename OffsetT, typename ValuesSimdViewT, typename T, + int N = ValuesSimdViewT::getSizeX() * ValuesSimdViewT::getSizeY(), + typename PropertyListT = ext::oneapi::experimental::empty_properties_t> +__ESIMD_API std::enable_if_t< + detail::is_simd_view_type_v && + ext::oneapi::experimental::is_property_list_v> +scatter(T *p, simd byte_offsets, ValuesSimdViewT vals, + PropertyListT props = {}) { + scatter(p, byte_offsets, vals.read(), props); +} + +/// template +/// void scatter(T *p, OffsetSimdViewT byte_offsets, simd vals, +/// simd_mask mask, PropertyListT props = {}); // (usm-sc-3) /// /// Writes ("scatters") elements of the input vector to different memory /// locations. Each memory location is base address plus an offset - a @@ -978,6 +1084,75 @@ scatter(T *p, OffsetSimdViewT byte_offsets, simd vals, scatter(p, byte_offsets.read(), vals, mask, props); } +/// template +/// void scatter(T *p, OffsetSimdViewT byte_offsets, simd vals, +/// simd_mask mask, PropertyListT props = {}); +/// +/// Variation of the API that allows to use \c simd_view without specifying \c T +/// and \c N template parameters. +/// Writes ("scatters") elements of the input vector to different memory +/// locations. Each memory location is base address plus an offset - a +/// value of the corresponding element in the input offset vector. Access to +/// any element's memory location can be disabled via the input mask. +/// @tparam VS Vector size. It can also be read as the number of writes per each +/// address. The parameter 'N' must be divisible by 'VS'. (VS > 1) is supported +/// only on DG2 and PVC and only for 4- and 8-byte element vectors. +/// @param p The base address. 
+/// @param byte_offsets the vector of 32-bit or 64-bit offsets in bytes +/// represented as a 'simd_view' object. +/// For each i, ((byte*)p + byte_offsets[i]) must be element size aligned. +/// If the alignment property is not passed, then it is assumed that each +/// accessed address is aligned by element-size. +/// @param vals The vector to scatter. +/// @param mask The access mask. +/// @param props The optional compile-time properties. Only 'alignment' +/// and cache hint properties are used. +template < + int VS, typename OffsetSimdViewT, typename T, int N, + typename PropertyListT = ext::oneapi::experimental::empty_properties_t> +__ESIMD_API std::enable_if_t< + detail::is_simd_view_type_v && + ext::oneapi::experimental::is_property_list_v> +scatter(T *p, OffsetSimdViewT byte_offsets, simd vals, + simd_mask mask, PropertyListT props = {}) { + scatter(p, byte_offsets.read(), vals, mask, props); +} + +/// template +/// void scatter(T *p, OffsetSimdViewT byte_offsets, simd vals, +/// PropertyListT props = {}); +/// +/// Variation of the API that allows to use \c simd_view without specifying \c T +/// and \c N template parameters. +/// Writes ("scatters") elements of the input vector to different memory +/// locations. Each memory location is base address plus an offset - a +/// value of the corresponding element in the input offset vector. This +/// overload takes no mask argument. +/// @tparam VS Vector size. It can also be read as the number of writes per each +/// address. The parameter 'N' must be divisible by 'VS'. (VS > 1) is supported +/// only on DG2 and PVC and only for 4- and 8-byte element vectors. +/// @param p The base address. +/// @param byte_offsets the vector of 32-bit or 64-bit offsets in bytes +/// represented as a 'simd_view' object. +/// For each i, ((byte*)p + byte_offsets[i]) must be element size aligned. 
+/// If the alignment property is not passed, then it is assumed that each +/// accessed address is aligned by element-size. +/// @param vals The vector to scatter. +/// @param props The optional compile-time properties. Only 'alignment' +/// and cache hint properties are used. +template < + int VS, typename OffsetSimdViewT, typename T, int N, + typename PropertyListT = ext::oneapi::experimental::empty_properties_t> +__ESIMD_API std::enable_if_t< + detail::is_simd_view_type_v && + ext::oneapi::experimental::is_property_list_v> +scatter(T *p, OffsetSimdViewT byte_offsets, simd vals, + PropertyListT props = {}) { + scatter(p, byte_offsets.read(), vals, props); +} + /// template /// void scatter(T *p, OffsetSimdViewT byte_offsets, simd vals, @@ -1012,6 +1187,44 @@ scatter(T *p, OffsetSimdViewT byte_offsets, simd vals, scatter(p, byte_offsets.read(), vals, Mask, props); } +/// template +/// void scatter(T *p, OffsetSimdViewT byte_offsets, ValuesSimdViewT vals, +/// PropertyListT props = {}); +/// +/// Variation of the API that allows to use \c simd_view without specifying \c T +/// and \c N template parameters. +/// Writes ("scatters") elements of the input vector to different memory +/// locations. Each memory location is base address plus an offset - a +/// value of the corresponding element in the input offset vector. +/// @tparam VS Vector size. It can also be read as the number of writes per each +/// address. The parameter 'N' must be divisible by 'VS'. (VS > 1) is supported +/// only on DG2 and PVC and only for 4- and 8-byte element vectors. +/// @param p The base address. +/// @param byte_offsets the vector of 32-bit or 64-bit offsets in bytes +/// represented as a 'simd_view' object. +/// For each i, ((byte*)p + byte_offsets[i]) must be element size aligned. +/// If the alignment property is not passed, then it is assumed that each +/// accessed address is aligned by element-size. +/// @param vals The vector to scatter. 
+/// @param props The optional compile-time properties. Only 'alignment' +/// and cache hint properties are used. +template < + int VS = 1, typename OffsetSimdViewT, typename ValuesSimdViewT, typename T, + int N = ValuesSimdViewT::getSizeX() * ValuesSimdViewT::getSizeY(), + typename PropertyListT = ext::oneapi::experimental::empty_properties_t> +__ESIMD_API std::enable_if_t< + detail::is_simd_view_type_v && + detail::is_simd_view_type_v && + ext::oneapi::experimental::is_property_list_v> +scatter(T *p, OffsetSimdViewT byte_offsets, ValuesSimdViewT vals, + PropertyListT props = {}) { + simd_mask Mask = 1; + scatter(p, byte_offsets.read(), vals.read(), Mask, props); +} + /// A variation of \c scatter API with \c offsets represented as scalar. /// /// @tparam Tx Element type, must be of size 4 or less. diff --git a/sycl/test/esimd/memory_properties_scatter.cpp b/sycl/test/esimd/memory_properties_scatter.cpp index 685cc52bbeead..85970c6ad8a7c 100644 --- a/sycl/test/esimd/memory_properties_scatter.cpp +++ b/sycl/test/esimd/memory_properties_scatter.cpp @@ -96,7 +96,7 @@ test_scatter(AccType &acc, LocalAccType &local_acc, float *ptrf, scatter(ptrf, ioffset_n32, usm, props_align4); - // CHECK-COUNT-14: call void @llvm.genx.lsc.store.stateless.v32i1.v32i64.v32i32(<32 x i1> {{[^)]+}}, i8 4, i8 1, i8 1, i16 1, i32 0, i8 3, i8 1, i8 1, i8 0, <32 x i64> {{[^)]+}}, <32 x i32> {{[^)]+}}, i32 0) + // CHECK-COUNT-22: call void @llvm.genx.lsc.store.stateless.v32i1.v32i64.v32i32(<32 x i1> {{[^)]+}}, i8 4, i8 1, i8 1, i16 1, i32 0, i8 3, i8 1, i8 1, i8 0, <32 x i64> {{[^)]+}}, <32 x i32> {{[^)]+}}, i32 0) scatter(ptrf, ioffset_n32, usm, mask_n32, props_cache_load); scatter(ptrf, ioffset_n32, usm, props_cache_load); @@ -110,6 +110,12 @@ test_scatter(AccType &acc, LocalAccType &local_acc, float *ptrf, props_cache_load); scatter(ptrf, ioffset_n32_view, usm_view, props_cache_load); + scatter(ptrf, ioffset_n32, usm_view, mask_n32, props_cache_load); + scatter(ptrf, ioffset_n32, usm_view, 
props_cache_load); + + scatter(ptrf, ioffset_n32_view, usm_view, mask_n32, props_cache_load); + scatter(ptrf, ioffset_n32_view, usm_view, props_cache_load); + scatter(ptrf, ioffset_n32_view.select<32, 1>(), usm, mask_n32, props_cache_load); scatter(ptrf, ioffset_n32_view.select<32, 1>(), usm, props_cache_load); @@ -123,9 +129,17 @@ test_scatter(AccType &acc, LocalAccType &local_acc, float *ptrf, usm_view.select<32, 1>(), mask_n32, props_cache_load); scatter(ptrf, ioffset_n32_view.select<32, 1>(), usm_view.select<32, 1>(), props_cache_load); + scatter(ptrf, ioffset_n32, usm_view.select<32, 1>(), mask_n32, + props_cache_load); + scatter(ptrf, ioffset_n32, usm_view.select<32, 1>(), props_cache_load); + + scatter(ptrf, ioffset_n32_view.select<32, 1>(), usm_view.select<32, 1>(), + mask_n32, props_cache_load); + scatter(ptrf, ioffset_n32_view.select<32, 1>(), usm_view.select<32, 1>(), + props_cache_load); // VS > 1 - // CHECK-COUNT-14: call void @llvm.genx.lsc.store.stateless.v16i1.v16i64.v32i32(<16 x i1> {{[^)]+}}, i8 4, i8 1, i8 1, i16 1, i32 0, i8 3, i8 2, i8 1, i8 0, <16 x i64> {{[^)]+}}, <32 x i32> {{[^)]+}}, i32 0) + // CHECK-COUNT-24: call void @llvm.genx.lsc.store.stateless.v16i1.v16i64.v32i32(<16 x i1> {{[^)]+}}, i8 4, i8 1, i8 1, i16 1, i32 0, i8 3, i8 2, i8 1, i8 0, <16 x i64> {{[^)]+}}, <32 x i32> {{[^)]+}}, i32 0) scatter(ptrf, ioffset_n16, usm, mask_n16, props_cache_load); scatter(ptrf, ioffset_n16, usm, props_cache_load); @@ -147,6 +161,16 @@ test_scatter(AccType &acc, LocalAccType &local_acc, float *ptrf, scatter(ptrf, ioffset_n16_view.select<16, 1>(), usm, props_cache_load); + scatter<2>(ptrf, ioffset_n16, usm_view, mask_n16, props_cache_load); + scatter<2>(ptrf, ioffset_n16, usm_view, props_cache_load); + + scatter<2>(ptrf, ioffset_n16_view, usm_view, mask_n16, props_cache_load); + scatter<2>(ptrf, ioffset_n16_view, usm_view, props_cache_load); + + scatter<2>(ptrf, ioffset_n16_view.select<16, 1>(), usm, mask_n16, + props_cache_load); + scatter<2>(ptrf, 
ioffset_n16_view.select<16, 1>(), usm, props_cache_load); + scatter(ptrf, ioffset_n16, usm_view.select<32, 1>(), mask_n16, props_cache_load); scatter(ptrf, ioffset_n16, usm_view.select<32, 1>(), @@ -157,6 +181,15 @@ test_scatter(AccType &acc, LocalAccType &local_acc, float *ptrf, scatter(ptrf, ioffset_n16_view.select<16, 1>(), usm_view.select<32, 1>(), props_cache_load); + scatter<2>(ptrf, ioffset_n16, usm_view.select<32, 1>(), mask_n16, + props_cache_load); + scatter<2>(ptrf, ioffset_n16, usm_view.select<32, 1>(), props_cache_load); + + scatter<2>(ptrf, ioffset_n16_view.select<16, 1>(), usm_view.select<32, 1>(), + mask_n16, props_cache_load); + scatter<2>(ptrf, ioffset_n16_view.select<16, 1>(), usm_view.select<32, 1>(), + props_cache_load); + // CHECK-COUNT-14: call void @llvm.genx.lsc.store.stateless.v16i1.v16i64.v32i32(<16 x i1> {{[^)]+}}, i8 4, i8 0, i8 0, i16 1, i32 0, i8 3, i8 2, i8 1, i8 0, <16 x i64> {{[^)]+}}, <32 x i32> {{[^)]+}}, i32 0) scatter(ptrf, ioffset_n16, usm, mask_n16);