Skip to content

Commit

Permalink
[oneDPL][ranges][merge] + clang format
Browse files Browse the repository at this point in the history
  • Loading branch information
MikeDvorskiy committed Jan 23, 2025
1 parent 7750787 commit 350387a
Show file tree
Hide file tree
Showing 9 changed files with 129 additions and 117 deletions.
150 changes: 74 additions & 76 deletions include/oneapi/dpl/pstl/algorithm_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -2949,44 +2949,44 @@ __pattern_remove_if(__parallel_tag<_IsVector> __tag, _ExecutionPolicy&& __exec,
// merge
//------------------------------------------------------------------------

template<typename It1, typename It2, typename ItOut, typename _Comp>
template <typename It1, typename It2, typename ItOut, typename _Comp>
std::pair<It1, It2>
__brick_merge_2(It1 __it_1, It1 __it_1_e, It2 __it_2, It2 __it_2_e, ItOut __it_out, ItOut __it_out_e, _Comp __comp,
/* __is_vector = */ std::false_type)
/* __is_vector = */ std::false_type)
{
while(__it_1 != __it_1_e && __it_2 != __it_2_e)
while (__it_1 != __it_1_e && __it_2 != __it_2_e)
{
if (__comp(*__it_1, *__it_2))
{
*__it_out = *__it_1;
++__it_out, ++__it_1;
}
else
{
*__it_out = *__it_2;
++__it_out, ++__it_2;
}
if(__it_out == __it_out_e)
if (__comp(*__it_1, *__it_2))
{
*__it_out = *__it_1;
++__it_out, ++__it_1;
}
else
{
*__it_out = *__it_2;
++__it_out, ++__it_2;
}
if (__it_out == __it_out_e)
return {__it_1, __it_2};
}

if(__it_1 == __it_1_e)
if (__it_1 == __it_1_e)
{
for(; __it_2 != __it_2_e && __it_out != __it_out_e; ++__it_2, ++__it_out)
for (; __it_2 != __it_2_e && __it_out != __it_out_e; ++__it_2, ++__it_out)
*__it_out = *__it_2;
}
else
{
for(; __it_1 != __it_1_e && __it_out != __it_out_e; ++__it_1, ++__it_out)
for (; __it_1 != __it_1_e && __it_out != __it_out_e; ++__it_1, ++__it_out)
*__it_out = *__it_1;
}
return {__it_1, __it_2};
}

template<typename It1, typename It2, typename ItOut, typename _Comp>
template <typename It1, typename It2, typename ItOut, typename _Comp>
std::pair<It1, It2>
__brick_merge_2(It1 __it_1, It1 __it_1_e, It2 __it_2, It2 __it_2_e, ItOut __it_out, ItOut __it_out_e, _Comp __comp,
/* __is_vector = */ std::true_type)
/* __is_vector = */ std::true_type)
{
return __unseq_backend::__simd_merge(__it_1, __it_1_e, __it_2, __it_2_e, __it_out, __it_out_e, __comp);
}
Expand Down Expand Up @@ -3023,78 +3023,76 @@ __pattern_merge(_Tag, _ExecutionPolicy&&, _ForwardIterator1 __first1, _ForwardIt
typename _Tag::__is_vector{});
}

template<class _Tag, typename _ExecutionPolicy, typename _It1, typename _Index1, typename _It2,
typename _Index2, typename _OutIt, typename _Index3, typename _Comp>
template <class _Tag, typename _ExecutionPolicy, typename _It1, typename _Index1, typename _It2, typename _Index2,
typename _OutIt, typename _Index3, typename _Comp>
std::pair<_It1, _It2>
__pattern_merge_2(_Tag, _ExecutionPolicy&& __exec, _It1 __it_1, _Index1 __n_1, _It2 __it_2,
_Index2 __n_2, _OutIt __it_out, _Index3 __n_out, _Comp __comp)
__pattern_merge_2(_Tag, _ExecutionPolicy&& __exec, _It1 __it_1, _Index1 __n_1, _It2 __it_2, _Index2 __n_2,
_OutIt __it_out, _Index3 __n_out, _Comp __comp)
{
return __brick_merge_2(__it_1, __it_1 + __n_1, __it_2, __it_2 + __n_2, __it_out, __it_out + __n_out, __comp,
typename _Tag::__is_vector{});
}

template<typename _IsVector, typename _ExecutionPolicy, typename _It1, typename _Index1, typename _It2,
typename _Index2, typename _OutIt, typename _Index3, typename _Comp>
template <typename _IsVector, typename _ExecutionPolicy, typename _It1, typename _Index1, typename _It2,
typename _Index2, typename _OutIt, typename _Index3, typename _Comp>
std::pair<_It1, _It2>
__pattern_merge_2(__parallel_tag<_IsVector>, _ExecutionPolicy&& __exec, _It1 __it_1, _Index1 __n_1, _It2 __it_2,
_Index2 __n_2, _OutIt __it_out, _Index3 __n_out, _Comp __comp)
_Index2 __n_2, _OutIt __it_out, _Index3 __n_out, _Comp __comp)
{
using __backend_tag = typename __parallel_tag<_IsVector>::__backend_tag;

_It1 __it_res_1;
_It2 __it_res_2;

__internal::__except_handler([&]() {
__par_backend::__parallel_for(__backend_tag{}, std::forward<_ExecutionPolicy>(__exec), _Index3(0), __n_out,
[=, &__it_res_1, &__it_res_2](_Index3 __i, _Index3 __j)
{
//a start merging point on the merge path; for each thread
_Index1 __r = 0; //row index
_Index2 __c = 0; //column index

if(__i > 0)
{
//calc merge path intersection:
const _Index3 __d_size =
std::abs(std::max<_Index2>(0, __i - __n_2) - (std::min<_Index1>(__i, __n_1) - 1)) + 1;

auto __get_row = [__i, __n_1](auto __d)
{ return std::min<_Index1>(__i, __n_1) - __d - 1; };
auto __get_column = [__i, __n_1](auto __d)
{ return std::max<_Index1>(0, __i - __n_1 - 1) + __d + (__i / (__n_1 + 1) > 0 ? 1 : 0); };

oneapi::dpl::counting_iterator<_Index3> __it_d(0);

auto __res_d = *std::lower_bound(__it_d, __it_d + __d_size, 1,
[&](auto __d, auto __val) {
auto __r = __get_row(__d);
auto __c = __get_column(__d);

oneapi::dpl::__internal::__compare<_Comp, oneapi::dpl::identity>
__cmp{__comp, oneapi::dpl::identity{}};
const auto __res = __cmp(__it_1[__r], __it_2[__c]) ? 1 : 0;

return __res < __val;
}
);

//intersection point
__r = __get_row(__res_d);
__c = __get_column(__res_d);
++__r; //to get a merge matrix ceil, lying on the current diagonal
}

//serial merge n elements, starting from input x and y, to [i, j) output range
const auto __res = __brick_merge_2(__it_1 + __r, __it_1 + __n_1,
__it_2 + __c, __it_2 + __n_2,
__it_out + __i, __it_out + __j, __comp, _IsVector{});

if(__j == __n_out)
{
__it_res_1 = __res.first;
__it_res_2 = __res.second;
}
}, oneapi::dpl::__utils::__merge_algo_cut_off); //grainsize
__par_backend::__parallel_for(
__backend_tag{}, std::forward<_ExecutionPolicy>(__exec), _Index3(0), __n_out,
[=, &__it_res_1, &__it_res_2](_Index3 __i, _Index3 __j) {
//a start merging point on the merge path; for each thread
_Index1 __r = 0; //row index
_Index2 __c = 0; //column index

if (__i > 0)
{
//calc merge path intersection:
const _Index3 __d_size =
std::abs(std::max<_Index2>(0, __i - __n_2) - (std::min<_Index1>(__i, __n_1) - 1)) + 1;

auto __get_row = [__i, __n_1](auto __d) { return std::min<_Index1>(__i, __n_1) - __d - 1; };
auto __get_column = [__i, __n_1](auto __d) {
return std::max<_Index1>(0, __i - __n_1 - 1) + __d + (__i / (__n_1 + 1) > 0 ? 1 : 0);
};

oneapi::dpl::counting_iterator<_Index3> __it_d(0);

auto __res_d = *std::lower_bound(__it_d, __it_d + __d_size, 1, [&](auto __d, auto __val) {
auto __r = __get_row(__d);
auto __c = __get_column(__d);

oneapi::dpl::__internal::__compare<_Comp, oneapi::dpl::identity> __cmp{__comp,
oneapi::dpl::identity{}};
const auto __res = __cmp(__it_1[__r], __it_2[__c]) ? 1 : 0;

return __res < __val;
});

//intersection point
__r = __get_row(__res_d);
__c = __get_column(__res_d);
++__r; //to get a merge matrix ceil, lying on the current diagonal
}

//serial merge n elements, starting from input x and y, to [i, j) output range
const auto __res = __brick_merge_2(__it_1 + __r, __it_1 + __n_1, __it_2 + __c, __it_2 + __n_2,
__it_out + __i, __it_out + __j, __comp, _IsVector{});

if (__j == __n_out)
{
__it_res_1 = __res.first;
__it_res_2 = __res.second;
}
},
oneapi::dpl::__utils::__merge_algo_cut_off); //grainsize
});

return {__it_res_1, __it_res_2};
Expand Down
12 changes: 7 additions & 5 deletions include/oneapi/dpl/pstl/algorithm_ranges_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -448,9 +448,10 @@ auto
__pattern_merge(_Tag __tag, _ExecutionPolicy&& __exec, _R1&& __r1, _R2&& __r2, _OutRange&& __out_r, _Comp __comp,
_Proj1 __proj1, _Proj2 __proj2)
{
using __return_type = std::ranges::merge_result<std::ranges::borrowed_iterator_t<_R1>, std::ranges::borrowed_iterator_t<_R2>,
std::ranges::borrowed_iterator_t<_OutRange>>;

using __return_type =
std::ranges::merge_result<std::ranges::borrowed_iterator_t<_R1>, std::ranges::borrowed_iterator_t<_R2>,
std::ranges::borrowed_iterator_t<_OutRange>>;

auto __comp_2 = [__comp, __proj1, __proj2](auto&& __val1, auto&& __val2) { return std::invoke(__comp,
std::invoke(__proj1, std::forward<decltype(__val1)>(__val1)), std::invoke(__proj2,
std::forward<decltype(__val2)>(__val2)));};
Expand All @@ -467,10 +468,11 @@ __pattern_merge(_Tag __tag, _ExecutionPolicy&& __exec, _R1&& __r1, _R2&& __r2, _
auto __it_2 = std::ranges::begin(__r2);
auto __it_out = std::ranges::begin(__out_r);

if(__n_out == 0)
if (__n_out == 0)
return __return_type{__it_1, __it_2, __it_out};

auto __res = __pattern_merge_2(__tag, std::forward<_ExecutionPolicy>(__exec), __it_2, __n_2, __it_1, __n_1, __it_out, __n_out, __comp_2);
auto __res = __pattern_merge_2(__tag, std::forward<_ExecutionPolicy>(__exec), __it_2, __n_2, __it_1, __n_1,
__it_out, __n_out, __comp_2);

return __return_type{__res.second, __res.first, __it_out + __n_out};
}
Expand Down
2 changes: 1 addition & 1 deletion include/oneapi/dpl/pstl/glue_algorithm_ranges_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -1177,7 +1177,7 @@ merge(_ExecutionPolicy&& __exec, _Range1&& __rng1, _Range2&& __rng2, _Range3&& _
oneapi::dpl::__internal::__ranges::__pattern_merge(
__dispatch_tag, ::std::forward<_ExecutionPolicy>(__exec), views::all_read(::std::forward<_Range1>(__rng1)),
views::all_read(::std::forward<_Range2>(__rng2)), __view_res, __comp);

return __view_res.size();
}

Expand Down
10 changes: 5 additions & 5 deletions include/oneapi/dpl/pstl/hetero/algorithm_ranges_impl_hetero.h
Original file line number Diff line number Diff line change
Expand Up @@ -700,7 +700,7 @@ __pattern_merge(__hetero_tag<_BackendTag> __tag, _ExecutionPolicy&& __exec, _Ran
::std::forward<_ExecutionPolicy>(__exec)),
oneapi::dpl::__internal::__brick_copy<__hetero_tag<_BackendTag>, _ExecutionPolicy>{},
::std::forward<_Range2>(__rng2), ::std::forward<_Range3>(__rng3));
return {0, __res};
return {0, __res};
}

if (__n2 == 0)
Expand All @@ -714,9 +714,9 @@ __pattern_merge(__hetero_tag<_BackendTag> __tag, _ExecutionPolicy&& __exec, _Ran
return {__res, 0};
}

auto __res = __par_backend_hetero::__parallel_merge(_BackendTag{}, ::std::forward<_ExecutionPolicy>(__exec),
::std::forward<_Range1>(__rng1), ::std::forward<_Range2>(__rng2),
::std::forward<_Range3>(__rng3), __comp);
auto __res = __par_backend_hetero::__parallel_merge(
_BackendTag{}, ::std::forward<_ExecutionPolicy>(__exec), ::std::forward<_Range1>(__rng1),
::std::forward<_Range2>(__rng2), ::std::forward<_Range3>(__rng3), __comp);

auto __val = __res.get();
return {__val.first, __val.second};
Expand Down Expand Up @@ -748,7 +748,7 @@ __pattern_merge(__hetero_tag<_BackendTag> __tag, _ExecutionPolicy&& __exec, _R1&
using __return_t = std::ranges::merge_result<std::ranges::borrowed_iterator_t<_R1>, std::ranges::borrowed_iterator_t<_R2>,
std::ranges::borrowed_iterator_t<_OutRange>>;

return __return_t{std::ranges::begin(__r1) + __res.first, std::ranges::begin(__r2) + __res.second,
return __return_t{std::ranges::begin(__r1) + __res.first, std::ranges::begin(__r2) + __res.second,
std::ranges::begin(__out_r) + __n_out};
}
#endif //_ONEDPL_CPP20_RANGES_PRESENT
Expand Down
35 changes: 20 additions & 15 deletions include/oneapi/dpl/pstl/hetero/dpcpp/parallel_backend_sycl_merge.h
Original file line number Diff line number Diff line change
Expand Up @@ -193,24 +193,27 @@ struct __parallel_merge_submitter<_IdType, __internal::__optional_kernel_name<_N
using __val_t = _split_point_t<_IdType>;
using __result_and_scratch_storage_t = __result_and_scratch_storage<_ExecutionPolicy, __val_t>;
auto __p_res_storage = new __result_and_scratch_storage_t(__exec, 1, 0);

// Save the raw pointer into a shared_ptr to return it in __future and extend the lifetime of the storage.
std::shared_ptr<__result_and_scratch_storage_base> __p_result_base(__p_res_storage);

auto __event = __exec.queue().submit(
[&__rng1, &__rng2, &__rng3, __p_res_storage, __comp, __chunk, __steps, __n, __n1, __n2](sycl::handler& __cgh) {
auto __event = __exec.queue().submit([&__rng1, &__rng2, &__rng3, __p_res_storage, __comp, __chunk, __steps, __n,
__n1, __n2](sycl::handler& __cgh) {
oneapi::dpl::__ranges::__require_access(__cgh, __rng1, __rng2, __rng3);
auto __result_acc = __p_res_storage->template __get_result_acc<sycl::access_mode::write>(__cgh, __dpl_sycl::__no_init{});
auto __result_acc =
__p_res_storage->template __get_result_acc<sycl::access_mode::write>(__cgh, __dpl_sycl::__no_init{});

__cgh.parallel_for<_Name...>(sycl::range</*dim=*/1>(__steps), [=](sycl::item</*dim=*/1> __item_id) {
auto __id = __item_id.get_linear_id();
const _IdType __i_elem = __id * __chunk;

const auto __n_merge = std::min<_IdType>(__chunk, __n - __i_elem);
const auto __start = __find_start_point(__rng1, _IdType{0}, __n1, __rng2, _IdType{0}, __n2, __i_elem, __comp);
auto __ends = __serial_merge(__rng1, __rng2, __rng3, __start.first, __start.second, __i_elem, __n_merge, __n1, __n2, __comp, __n);
const auto __start =
__find_start_point(__rng1, _IdType{0}, __n1, __rng2, _IdType{0}, __n2, __i_elem, __comp);
auto __ends = __serial_merge(__rng1, __rng2, __rng3, __start.first, __start.second, __i_elem, __n_merge,
__n1, __n2, __comp, __n);

if(__id == __steps - 1) //the last WI does additional work
if (__id == __steps - 1) //the last WI does additional work
{
auto __res_ptr = __result_and_scratch_storage_t::__get_usm_or_buffer_accessor_ptr(__result_acc);
*__res_ptr = __ends;
Expand Down Expand Up @@ -243,7 +246,8 @@ struct __parallel_merge_submitter_large<_IdType, _CustomName,
// Calculate nd-range parameters
template <typename _ExecutionPolicy, typename _Range1, typename _Range2>
nd_range_params
eval_nd_range_params(_ExecutionPolicy&& __exec, const _Range1& __rng1, const _Range2& __rng2, const std::size_t __n) const
eval_nd_range_params(_ExecutionPolicy&& __exec, const _Range1& __rng1, const _Range2& __rng2,
const std::size_t __n) const
{
// Empirical number of values to process per work-item
const std::uint8_t __chunk = __exec.queue().get_device().is_cpu() ? 128 : 4;
Expand All @@ -260,8 +264,8 @@ struct __parallel_merge_submitter_large<_IdType, _CustomName,
// Calculation of split points on each base diagonal
template <typename _ExecutionPolicy, typename _Range1, typename _Range2, typename _Compare, typename _Storage>
sycl::event
eval_split_points_for_groups(_ExecutionPolicy&& __exec, _Range1&& __rng1, _Range2&& __rng2, _IdType __n, _Compare __comp,
const nd_range_params& __nd_range_params,
eval_split_points_for_groups(_ExecutionPolicy&& __exec, _Range1&& __rng1, _Range2&& __rng2, _IdType __n,
_Compare __comp, const nd_range_params& __nd_range_params,
_Storage& __base_diagonals_sp_global_storage) const
{
const _IdType __n1 = __rng1.size();
Expand Down Expand Up @@ -311,7 +315,8 @@ struct __parallel_merge_submitter_large<_IdType, _CustomName,
auto __base_diagonals_sp_global_acc =
__base_diagonals_sp_global_storage.template __get_scratch_acc<sycl::access_mode::read>(__cgh);

auto __result_acc = __base_diagonals_sp_global_storage.template __get_result_acc<sycl::access_mode::write>(__cgh, __dpl_sycl::__no_init{});
auto __result_acc = __base_diagonals_sp_global_storage.template __get_result_acc<sycl::access_mode::write>(
__cgh, __dpl_sycl::__no_init{});

__cgh.depends_on(__event);

Expand Down Expand Up @@ -339,8 +344,8 @@ struct __parallel_merge_submitter_large<_IdType, _CustomName,
}

const auto __ends = __serial_merge(__rng1, __rng2, __rng3, __start.first, __start.second, __i_elem,
__nd_range_params.chunk, __n1, __n2, __comp, __n);
if(__global_idx == __nd_range_params.steps - 1)
__nd_range_params.chunk, __n1, __n2, __comp, __n);
if (__global_idx == __nd_range_params.steps - 1)
{
auto __res_ptr = _Storage::__get_usm_or_buffer_accessor_ptr(__result_acc);
*__res_ptr = __ends;
Expand All @@ -367,8 +372,8 @@ struct __parallel_merge_submitter_large<_IdType, _CustomName,

// Create storage to save split-points on each base diagonal + 1 (for the right base diagonal in the last work-group)
using __val_t = _split_point_t<_IdType>;
auto __p_base_diagonals_sp_global_storage = new __result_and_scratch_storage<_ExecutionPolicy, __val_t>(__exec,
1, __nd_range_params.base_diag_count + 1);
auto __p_base_diagonals_sp_global_storage = new __result_and_scratch_storage<_ExecutionPolicy, __val_t>(
__exec, 1, __nd_range_params.base_diag_count + 1);

// Save the raw pointer into a shared_ptr to return it in __future and extend the lifetime of the storage.
std::shared_ptr<__result_and_scratch_storage_base> __p_result_and_scratch_storage_base(
Expand Down
Loading

0 comments on commit 350387a

Please sign in to comment.