Skip to content

Commit

Permalink
[ESIMD] Allow full autodeduction for scatter USM APIs accepting simd_…
Browse files Browse the repository at this point in the history
…view (intel#13941)
  • Loading branch information
fineg74 authored May 29, 2024
1 parent dec1146 commit 05d29f3
Show file tree
Hide file tree
Showing 2 changed files with 256 additions and 10 deletions.
229 changes: 221 additions & 8 deletions sycl/include/sycl/ext/intel/esimd/memory.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -911,10 +911,46 @@ scatter(T *p, simd<OffsetT, N / VS> byte_offsets, simd<T, N> vals,
}
}

// template <typename T, int N, int VS = 1, typename OffsetT,
// typename PropertyListT = empty_properties_t>
// void scatter(T *p, simd<OffsetT, N / VS> byte_offsets, simd<T, N> vals,
// PropertyListT props = {}); // (usm-sc-2)
/// template <int VS = 1, typename OffsetT, typename ValuesSimdViewT, typename
/// T, int N = ValuesSimdViewT::getSizeX() * ValuesSimdViewT::getSizeY(),
/// typename PropertyListT = empty_properties_t>
/// void scatter(T *p, simd<OffsetT, N / VS> byte_offsets, ValuesSimdViewT vals,
///              simd_mask<N / VS> mask, PropertyListT props = {});
///
/// Masked USM scatter whose values are supplied as a \c simd_view. \c T is
/// deduced from the pointer \p p and \c N defaults to the number of elements
/// in the view, so neither has to be spelled out at the call site.
/// Writes ("scatters") the elements of \p vals to different memory locations.
/// Each memory location is the base address plus the offset held in the
/// corresponding element of \p byte_offsets. Access to any element's memory
/// location can be disabled via \p mask.
/// @tparam VS Vector size. It can also be read as the number of writes per
/// each address. The parameter 'N' must be divisible by 'VS'. (VS > 1) is
/// supported only on DG2 and PVC and only for 4- and 8-byte element vectors.
/// @param p The base address.
/// @param byte_offsets The vector of 32-bit or 64-bit offsets in bytes.
/// For each i, ((byte*)p + byte_offsets[i]) must be element size aligned.
/// If the alignment property is not passed, then it is assumed that each
/// accessed address is aligned by element-size.
/// @param vals The values to scatter, represented as a 'simd_view' object.
/// @param mask The access mask.
/// @param props The optional compile-time properties. Only 'alignment'
/// and cache hint properties are used.
template <
    int VS = 1, typename OffsetT, typename ValuesSimdViewT, typename T,
    int N = ValuesSimdViewT::getSizeX() * ValuesSimdViewT::getSizeY(),
    typename PropertyListT = ext::oneapi::experimental::empty_properties_t>
__ESIMD_API std::enable_if_t<
    detail::is_simd_view_type_v<ValuesSimdViewT> &&
    ext::oneapi::experimental::is_property_list_v<PropertyListT>>
scatter(T *p, simd<OffsetT, N / VS> byte_offsets, ValuesSimdViewT vals,
        simd_mask<N / VS> mask, PropertyListT props = {}) {
  // Read the view into a plain vector first, then hand everything over to the
  // overload that takes simd<T, N> values.
  simd<T, N> Values = vals.read();
  scatter<T, N, VS>(p, byte_offsets, Values, mask, props);
}

/// template <typename T, int N, int VS = 1, typename OffsetT,
/// typename PropertyListT = empty_properties_t>
/// void scatter(T *p, simd<OffsetT, N / VS> byte_offsets, simd<T, N> vals,
/// PropertyListT props = {}); // (usm-sc-2)
///
/// Writes ("scatters") elements of the input vector to different memory
/// locations. Each memory location is base address plus an offset - a
Expand Down Expand Up @@ -943,10 +979,80 @@ scatter(T *p, simd<OffsetT, N / VS> byte_offsets, simd<T, N> vals,
scatter<T, N, VS>(p, byte_offsets, vals, Mask, props);
}

// template <typename T, int N, int VS = 1, typename OffsetSimdViewT,
// typename PropertyListT = empty_properties_t>
// void scatter(T *p, OffsetSimdViewT byte_offsets, simd<T, N> vals,
// simd_mask<N / VS> mask, PropertyListT props = {}); // (usm-sc-3)
/// template <int VS = 1, typename OffsetSimdViewT, typename ValuesSimdViewT,
/// typename T, int N = ValuesSimdViewT::getSizeX() *
/// ValuesSimdViewT::getSizeY(), typename PropertyListT = empty_properties_t>
/// void scatter(T *p, OffsetSimdViewT byte_offsets, ValuesSimdViewT vals,
///              simd_mask<N / VS> mask, PropertyListT props = {});
///
/// Variation of the API that allows using \c simd_view for both the offsets
/// and the values without specifying the \c T and \c N template parameters.
/// Writes ("scatters") elements of the input vector to different memory
/// locations. Each memory location is base address plus an offset - a
/// value of the corresponding element in the input offset vector. Access to
/// any element's memory location can be disabled via the input mask.
/// @tparam VS Vector size. It can also be read as the number of writes per each
/// address. The parameter 'N' must be divisible by 'VS'. (VS > 1) is supported
/// only on DG2 and PVC and only for 4- and 8-byte element vectors.
/// @param p The base address.
/// @param byte_offsets the vector of 32-bit or 64-bit offsets in bytes
/// represented as a 'simd_view' object. It must hold exactly N / VS elements;
/// this is verified at compile time.
/// For each i, ((byte*)p + byte_offsets[i]) must be element size aligned.
/// If the alignment property is not passed, then it is assumed that each
/// accessed address is aligned by element-size.
/// @param vals The vector to scatter, represented as a 'simd_view' object.
/// @param mask The access mask.
/// @param props The optional compile-time properties. Only 'alignment'
/// and cache hint properties are used.
template <
    int VS = 1, typename OffsetSimdViewT, typename ValuesSimdViewT, typename T,
    int N = ValuesSimdViewT::getSizeX() * ValuesSimdViewT::getSizeY(),
    typename PropertyListT = ext::oneapi::experimental::empty_properties_t>
__ESIMD_API std::enable_if_t<
    detail::is_simd_view_type_v<ValuesSimdViewT> &&
    detail::is_simd_view_type_v<OffsetSimdViewT> &&
    ext::oneapi::experimental::is_property_list_v<PropertyListT>>
scatter(T *p, OffsetSimdViewT byte_offsets, ValuesSimdViewT vals,
        simd_mask<N / VS> mask, PropertyListT props = {}) {
  // The forwarded call takes simd<OffsetT, N / VS> offsets. Reject a
  // mismatched view here with a readable message instead of letting the
  // nested call fail to find a matching overload.
  static_assert(N / VS == OffsetSimdViewT::getSizeX() *
                              OffsetSimdViewT::getSizeY(),
                "Size of byte_offsets must be equal to N / VS, where N is the "
                "size of vals.");
  scatter<T, N, VS>(p, byte_offsets.read(), vals.read(), mask, props);
}

/// template <int VS = 1, typename OffsetT, typename ValuesSimdViewT, typename
/// T, int N = ValuesSimdViewT::getSizeX() * ValuesSimdViewT::getSizeY(),
/// typename PropertyListT = empty_properties_t>
/// void scatter(T *p, simd<OffsetT, N / VS> byte_offsets, ValuesSimdViewT vals,
///              PropertyListT props = {});
///
/// USM scatter without a mask argument whose values are supplied as a
/// \c simd_view. \c T is deduced from the pointer \p p and \c N defaults to
/// the number of elements in the view, so neither has to be spelled out at
/// the call site.
/// Writes ("scatters") the elements of \p vals to different memory locations.
/// Each memory location is the base address plus the offset held in the
/// corresponding element of \p byte_offsets.
/// @tparam VS Vector size. It can also be read as the number of writes per
/// each address. The parameter 'N' must be divisible by 'VS'. (VS > 1) is
/// supported only on DG2 and PVC and only for 4- and 8-byte element vectors.
/// @param p The base address.
/// @param byte_offsets The vector of 32-bit or 64-bit offsets in bytes.
/// For each i, ((byte*)p + byte_offsets[i]) must be element size aligned.
/// If the alignment property is not passed, then it is assumed that each
/// accessed address is aligned by element-size.
/// @param vals The values to scatter, represented as a 'simd_view' object.
/// @param props The optional compile-time properties. Only 'alignment'
/// and cache hint properties are used.
template <
    int VS = 1, typename OffsetT, typename ValuesSimdViewT, typename T,
    int N = ValuesSimdViewT::getSizeX() * ValuesSimdViewT::getSizeY(),
    typename PropertyListT = ext::oneapi::experimental::empty_properties_t>
__ESIMD_API std::enable_if_t<
    detail::is_simd_view_type_v<ValuesSimdViewT> &&
    ext::oneapi::experimental::is_property_list_v<PropertyListT>>
scatter(T *p, simd<OffsetT, N / VS> byte_offsets, ValuesSimdViewT vals,
        PropertyListT props = {}) {
  // Read the view into a plain vector first, then hand everything over to the
  // overload that takes simd<T, N> values.
  simd<T, N> Values = vals.read();
  scatter<T, N, VS>(p, byte_offsets, Values, props);
}

/// template <typename T, int N, int VS = 1, typename OffsetSimdViewT,
/// typename PropertyListT = empty_properties_t>
/// void scatter(T *p, OffsetSimdViewT byte_offsets, simd<T, N> vals,
/// simd_mask<N / VS> mask, PropertyListT props = {}); // (usm-sc-3)
///
/// Writes ("scatters") elements of the input vector to different memory
/// locations. Each memory location is base address plus an offset - a
Expand Down Expand Up @@ -978,6 +1084,75 @@ scatter(T *p, OffsetSimdViewT byte_offsets, simd<T, N> vals,
scatter<T, N, VS>(p, byte_offsets.read(), vals, mask, props);
}

/// template <int VS, typename OffsetSimdViewT, typename T, int N, typename
/// PropertyListT = empty_properties_t>
/// void scatter(T *p, OffsetSimdViewT byte_offsets, simd<T,N> vals,
///              simd_mask<N / VS> mask, PropertyListT props = {});
///
/// Variation of the API that allows using \c simd_view for the offsets
/// without specifying the \c T and \c N template parameters; both are deduced
/// from \p p and \p vals. \c VS has no default value in this overload and
/// must be specified explicitly.
/// Writes ("scatters") elements of the input vector to different memory
/// locations. Each memory location is base address plus an offset - a
/// value of the corresponding element in the input offset vector. Access to
/// any element's memory location can be disabled via the input mask.
/// @tparam VS Vector size. It can also be read as the number of writes per each
/// address. The parameter 'N' must be divisible by 'VS'. (VS > 1) is supported
/// only on DG2 and PVC and only for 4- and 8-byte element vectors.
/// @param p The base address.
/// @param byte_offsets the vector of 32-bit or 64-bit offsets in bytes
/// represented as a 'simd_view' object. It must hold exactly N / VS elements;
/// this is verified at compile time.
/// For each i, ((byte*)p + byte_offsets[i]) must be element size aligned.
/// If the alignment property is not passed, then it is assumed that each
/// accessed address is aligned by element-size.
/// @param vals The vector to scatter.
/// @param mask The access mask.
/// @param props The optional compile-time properties. Only 'alignment'
/// and cache hint properties are used.
template <
    int VS, typename OffsetSimdViewT, typename T, int N,
    typename PropertyListT = ext::oneapi::experimental::empty_properties_t>
__ESIMD_API std::enable_if_t<
    detail::is_simd_view_type_v<OffsetSimdViewT> &&
    ext::oneapi::experimental::is_property_list_v<PropertyListT>>
scatter(T *p, OffsetSimdViewT byte_offsets, simd<T, N> vals,
        simd_mask<N / VS> mask, PropertyListT props = {}) {
  // The forwarded call takes simd<OffsetT, N / VS> offsets. Reject a
  // mismatched view here with a readable message instead of letting the
  // nested call fail to find a matching overload.
  static_assert(N / VS == OffsetSimdViewT::getSizeX() *
                              OffsetSimdViewT::getSizeY(),
                "Size of byte_offsets must be equal to N / VS, where N is the "
                "size of vals.");
  scatter<T, N, VS>(p, byte_offsets.read(), vals, mask, props);
}

/// template <int VS, typename OffsetSimdViewT, typename T, int N, typename
/// PropertyListT = empty_properties_t>
/// void scatter(T *p, OffsetSimdViewT byte_offsets, simd<T,N> vals,
///              PropertyListT props = {});
///
/// Variation of the API that allows using \c simd_view for the offsets
/// without specifying the \c T and \c N template parameters; both are deduced
/// from \p p and \p vals. \c VS has no default value in this overload and
/// must be specified explicitly.
/// Writes ("scatters") elements of the input vector to different memory
/// locations. Each memory location is base address plus an offset - a
/// value of the corresponding element in the input offset vector.
/// This overload takes no access mask.
/// @tparam VS Vector size. It can also be read as the number of writes per each
/// address. The parameter 'N' must be divisible by 'VS'. (VS > 1) is supported
/// only on DG2 and PVC and only for 4- and 8-byte element vectors.
/// @param p The base address.
/// @param byte_offsets the vector of 32-bit or 64-bit offsets in bytes
/// represented as a 'simd_view' object. It must hold exactly N / VS elements;
/// this is verified at compile time.
/// For each i, ((byte*)p + byte_offsets[i]) must be element size aligned.
/// If the alignment property is not passed, then it is assumed that each
/// accessed address is aligned by element-size.
/// @param vals The vector to scatter.
/// @param props The optional compile-time properties. Only 'alignment'
/// and cache hint properties are used.
template <
    int VS, typename OffsetSimdViewT, typename T, int N,
    typename PropertyListT = ext::oneapi::experimental::empty_properties_t>
__ESIMD_API std::enable_if_t<
    detail::is_simd_view_type_v<OffsetSimdViewT> &&
    ext::oneapi::experimental::is_property_list_v<PropertyListT>>
scatter(T *p, OffsetSimdViewT byte_offsets, simd<T, N> vals,
        PropertyListT props = {}) {
  // The forwarded call takes simd<OffsetT, N / VS> offsets. Reject a
  // mismatched view here with a readable message instead of letting the
  // nested call fail to find a matching overload.
  static_assert(N / VS == OffsetSimdViewT::getSizeX() *
                              OffsetSimdViewT::getSizeY(),
                "Size of byte_offsets must be equal to N / VS, where N is the "
                "size of vals.");
  scatter<T, N, VS>(p, byte_offsets.read(), vals, props);
}

/// template <typename T, int N, int VS = 1, typename OffsetSimdViewT,
/// typename PropertyListT = empty_properties_t>
/// void scatter(T *p, OffsetSimdViewT byte_offsets, simd<T, N> vals,
Expand Down Expand Up @@ -1012,6 +1187,44 @@ scatter(T *p, OffsetSimdViewT byte_offsets, simd<T, N> vals,
scatter<T, N, VS>(p, byte_offsets.read(), vals, Mask, props);
}

/// template <int VS = 1, typename OffsetSimdViewT, typename
/// ValuesSimdViewT, typename T, int N = ValuesSimdViewT::getSizeX() *
/// ValuesSimdViewT::getSizeY(), typename PropertyListT =
/// empty_properties_t>
/// void scatter(T *p, OffsetSimdViewT byte_offsets, ValuesSimdViewT vals,
///              PropertyListT props = {});
///
/// Variation of the API that allows using \c simd_view for both the offsets
/// and the values without specifying the \c T and \c N template parameters.
/// Writes ("scatters") elements of the input vector to different memory
/// locations. Each memory location is base address plus an offset - a
/// value of the corresponding element in the input offset vector.
/// This overload takes no access mask.
/// @tparam VS Vector size. It can also be read as the number of writes per each
/// address. The parameter 'N' must be divisible by 'VS'. (VS > 1) is supported
/// only on DG2 and PVC and only for 4- and 8-byte element vectors.
/// @param p The base address.
/// @param byte_offsets the vector of 32-bit or 64-bit offsets in bytes
/// represented as a 'simd_view' object. It must hold exactly N / VS elements;
/// this is verified at compile time.
/// For each i, ((byte*)p + byte_offsets[i]) must be element size aligned.
/// If the alignment property is not passed, then it is assumed that each
/// accessed address is aligned by element-size.
/// @param vals The vector to scatter, represented as a 'simd_view' object.
/// @param props The optional compile-time properties. Only 'alignment'
/// and cache hint properties are used.
template <
    int VS = 1, typename OffsetSimdViewT, typename ValuesSimdViewT, typename T,
    int N = ValuesSimdViewT::getSizeX() * ValuesSimdViewT::getSizeY(),
    typename PropertyListT = ext::oneapi::experimental::empty_properties_t>
__ESIMD_API std::enable_if_t<
    detail::is_simd_view_type_v<OffsetSimdViewT> &&
    detail::is_simd_view_type_v<ValuesSimdViewT> &&
    ext::oneapi::experimental::is_property_list_v<PropertyListT>>
scatter(T *p, OffsetSimdViewT byte_offsets, ValuesSimdViewT vals,
        PropertyListT props = {}) {
  // The forwarded call takes simd<OffsetT, N / VS> offsets. Reject a
  // mismatched view here with a readable message instead of letting the
  // nested call fail to find a matching overload.
  static_assert(N / VS == OffsetSimdViewT::getSizeX() *
                              OffsetSimdViewT::getSizeY(),
                "Size of byte_offsets must be equal to N / VS, where N is the "
                "size of vals.");
  // The caller supplied no mask; forward to the masked overload with a mask
  // initialized from 1.
  simd_mask<N / VS> Mask = 1;
  scatter<T, N, VS>(p, byte_offsets.read(), vals.read(), Mask, props);
}

/// A variation of \c scatter API with \c offsets represented as scalar.
///
/// @tparam Tx Element type, must be of size 4 or less.
Expand Down
37 changes: 35 additions & 2 deletions sycl/test/esimd/memory_properties_scatter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ test_scatter(AccType &acc, LocalAccType &local_acc, float *ptrf,

scatter(ptrf, ioffset_n32, usm, props_align4);

// CHECK-COUNT-14: call void @llvm.genx.lsc.store.stateless.v32i1.v32i64.v32i32(<32 x i1> {{[^)]+}}, i8 4, i8 1, i8 1, i16 1, i32 0, i8 3, i8 1, i8 1, i8 0, <32 x i64> {{[^)]+}}, <32 x i32> {{[^)]+}}, i32 0)
// CHECK-COUNT-22: call void @llvm.genx.lsc.store.stateless.v32i1.v32i64.v32i32(<32 x i1> {{[^)]+}}, i8 4, i8 1, i8 1, i16 1, i32 0, i8 3, i8 1, i8 1, i8 0, <32 x i64> {{[^)]+}}, <32 x i32> {{[^)]+}}, i32 0)
scatter(ptrf, ioffset_n32, usm, mask_n32, props_cache_load);
scatter(ptrf, ioffset_n32, usm, props_cache_load);

Expand All @@ -110,6 +110,12 @@ test_scatter(AccType &acc, LocalAccType &local_acc, float *ptrf,
props_cache_load);
scatter<float, 32>(ptrf, ioffset_n32_view, usm_view, props_cache_load);

scatter(ptrf, ioffset_n32, usm_view, mask_n32, props_cache_load);
scatter(ptrf, ioffset_n32, usm_view, props_cache_load);

scatter(ptrf, ioffset_n32_view, usm_view, mask_n32, props_cache_load);
scatter(ptrf, ioffset_n32_view, usm_view, props_cache_load);

scatter(ptrf, ioffset_n32_view.select<32, 1>(), usm, mask_n32,
props_cache_load);
scatter(ptrf, ioffset_n32_view.select<32, 1>(), usm, props_cache_load);
Expand All @@ -123,9 +129,17 @@ test_scatter(AccType &acc, LocalAccType &local_acc, float *ptrf,
usm_view.select<32, 1>(), mask_n32, props_cache_load);
scatter<float, 32>(ptrf, ioffset_n32_view.select<32, 1>(),
usm_view.select<32, 1>(), props_cache_load);
scatter(ptrf, ioffset_n32, usm_view.select<32, 1>(), mask_n32,
props_cache_load);
scatter(ptrf, ioffset_n32, usm_view.select<32, 1>(), props_cache_load);

scatter(ptrf, ioffset_n32_view.select<32, 1>(), usm_view.select<32, 1>(),
mask_n32, props_cache_load);
scatter(ptrf, ioffset_n32_view.select<32, 1>(), usm_view.select<32, 1>(),
props_cache_load);

// VS > 1
// CHECK-COUNT-14: call void @llvm.genx.lsc.store.stateless.v16i1.v16i64.v32i32(<16 x i1> {{[^)]+}}, i8 4, i8 1, i8 1, i16 1, i32 0, i8 3, i8 2, i8 1, i8 0, <16 x i64> {{[^)]+}}, <32 x i32> {{[^)]+}}, i32 0)
// CHECK-COUNT-24: call void @llvm.genx.lsc.store.stateless.v16i1.v16i64.v32i32(<16 x i1> {{[^)]+}}, i8 4, i8 1, i8 1, i16 1, i32 0, i8 3, i8 2, i8 1, i8 0, <16 x i64> {{[^)]+}}, <32 x i32> {{[^)]+}}, i32 0)
scatter<float, 32, 2>(ptrf, ioffset_n16, usm, mask_n16, props_cache_load);

scatter<float, 32, 2>(ptrf, ioffset_n16, usm, props_cache_load);
Expand All @@ -147,6 +161,16 @@ test_scatter(AccType &acc, LocalAccType &local_acc, float *ptrf,
scatter<float, 32, 2>(ptrf, ioffset_n16_view.select<16, 1>(), usm,
props_cache_load);

scatter<2>(ptrf, ioffset_n16, usm_view, mask_n16, props_cache_load);
scatter<2>(ptrf, ioffset_n16, usm_view, props_cache_load);

scatter<2>(ptrf, ioffset_n16_view, usm_view, mask_n16, props_cache_load);
scatter<2>(ptrf, ioffset_n16_view, usm_view, props_cache_load);

scatter<2>(ptrf, ioffset_n16_view.select<16, 1>(), usm, mask_n16,
props_cache_load);
scatter<2>(ptrf, ioffset_n16_view.select<16, 1>(), usm, props_cache_load);

scatter<float, 32, 2>(ptrf, ioffset_n16, usm_view.select<32, 1>(), mask_n16,
props_cache_load);
scatter<float, 32, 2>(ptrf, ioffset_n16, usm_view.select<32, 1>(),
Expand All @@ -157,6 +181,15 @@ test_scatter(AccType &acc, LocalAccType &local_acc, float *ptrf,
scatter<float, 32, 2>(ptrf, ioffset_n16_view.select<16, 1>(),
usm_view.select<32, 1>(), props_cache_load);

scatter<2>(ptrf, ioffset_n16, usm_view.select<32, 1>(), mask_n16,
props_cache_load);
scatter<2>(ptrf, ioffset_n16, usm_view.select<32, 1>(), props_cache_load);

scatter<2>(ptrf, ioffset_n16_view.select<16, 1>(), usm_view.select<32, 1>(),
mask_n16, props_cache_load);
scatter<2>(ptrf, ioffset_n16_view.select<16, 1>(), usm_view.select<32, 1>(),
props_cache_load);

// CHECK-COUNT-14: call void @llvm.genx.lsc.store.stateless.v16i1.v16i64.v32i32(<16 x i1> {{[^)]+}}, i8 4, i8 0, i8 0, i16 1, i32 0, i8 3, i8 2, i8 1, i8 0, <16 x i64> {{[^)]+}}, <32 x i32> {{[^)]+}}, i32 0)
scatter<float, 32, 2>(ptrf, ioffset_n16, usm, mask_n16);

Expand Down

0 comments on commit 05d29f3

Please sign in to comment.