Skip to content

Commit

Permalink
Fix joint group functions for when input and output types differ
Browse files: browse the repository at this point in the history
  • Loading branch information
fknorr committed Jan 10, 2024
1 parent 07acbc5 commit 35a3fd5
Show file tree
Hide file tree
Showing 5 changed files with 39 additions and 61 deletions.
56 changes: 19 additions & 37 deletions include/simsycl/detail/group_operation_impl.hh
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ enum class group_operation_id {
none_of,
shift_left,
shift_right,
permute,
permute_by_xor,
select,
joint_reduce,
reduce,
Expand Down Expand Up @@ -84,12 +84,12 @@ struct group_broadcast_data : group_per_operation_data {
struct group_barrier_data : group_per_operation_data {
sycl::memory_scope fence_scope;
};
template<typename T>
template<typename Ptr>
struct group_joint_bool_op_data : group_per_operation_data {
T *first;
T *last;
Ptr first;
Ptr last;
bool result;
group_joint_bool_op_data(T *first, T *last, bool result) : first(first), last(last), result(result) {}
group_joint_bool_op_data(Ptr first, Ptr last, bool result) : first(first), last(last), result(result) {}
};
struct group_bool_data : group_per_operation_data {
std::vector<bool> values;
Expand All @@ -112,13 +112,13 @@ struct group_select_data : group_per_operation_data {
std::vector<T> values;
group_select_data(size_t num_work_items) : values(num_work_items) {}
};
template<typename T>
template<typename Ptr, typename T>
struct group_joint_reduce_data : group_per_operation_data {
T *first;
T *last;
Ptr first;
Ptr last;
std::optional<T> init;
T result;
group_joint_reduce_data(T *first, T *last, std::optional<T> init, T result)
group_joint_reduce_data(Ptr first, Ptr last, std::optional<T> init, T result)
: first(first), last(last), init(init), result(result) {}
};
template<typename T>
Expand All @@ -127,13 +127,13 @@ struct group_reduce_data : group_per_operation_data {
std::vector<T> values;
group_reduce_data(size_t num_work_items, std::optional<T> init) : init(init), values(num_work_items) {}
};
template<typename T>
template<typename Ptr, typename T>
struct group_joint_scan_data : group_per_operation_data {
T *first;
T *last;
Ptr first;
Ptr last;
std::optional<T> init;
std::vector<T> results;
group_joint_scan_data(T *first, T *last, std::optional<T> init, const std::vector<T> &results)
group_joint_scan_data(Ptr first, Ptr last, std::optional<T> init, const std::vector<T> &results)
: first(first), last(last), init(init), results(results) {}
};
template<typename T>
Expand Down Expand Up @@ -280,15 +280,9 @@ template<sycl::Group G, sycl::Pointer Ptr, sycl::Fundamental T>
void joint_reduce_impl(G g, Ptr first, Ptr last, std::optional<T> init, T result) {
perform_group_operation(g, group_operation_id::joint_reduce,
group_operation_spec{//
.init =
[&] {
// TODO there was a strange error when using make_unique here - investigate if necessary after
// group_operations CTS test is fixed
return std::unique_ptr<group_joint_reduce_data<T>>(
new group_joint_reduce_data<T>(first, last, init, result));
},
.init = [&] { return std::make_unique<group_joint_reduce_data<Ptr, T>>(first, last, init, result); },
.reached =
[&](group_joint_reduce_data<T> &per_op) {
[&](group_joint_reduce_data<Ptr, T> &per_op) {
SIMSYCL_CHECK(per_op.first == first);
SIMSYCL_CHECK(per_op.last == last);
SIMSYCL_CHECK(per_op.init == init);
Expand All @@ -302,10 +296,7 @@ T group_reduce_impl(G g, T x, std::optional<T> init, Op op) {
group_operation_spec{//
.init =
[&] {
// TODO there was a strange error when using make_unique here - investigate if necessary after
// group_operations CTS test is fixed
auto per_op = std::unique_ptr<group_reduce_data<T>>(
new group_reduce_data<T>(g.get_local_range().size(), init));
auto per_op = std::make_unique<group_reduce_data<T>>(g.get_local_range().size(), init);
per_op->values[g.get_local_linear_id()] = x;
return per_op;
},
Expand Down Expand Up @@ -333,15 +324,9 @@ void joint_scan_impl(
G g, group_operation_id op_id, Ptr first, Ptr last, std::optional<T> init, const std::vector<T> &results) {
perform_group_operation(g, op_id,
group_operation_spec{//
.init =
[&] {
// TODO there was a strange error when using make_unique here - investigate if necessary after
// group_operations CTS test is fixed
return std::unique_ptr<group_joint_scan_data<T>>(
new group_joint_scan_data<T>(first, last, init, results));
},
.init = [&] { return std::make_unique<group_joint_scan_data<Ptr, T>>(first, last, init, results); },
.reached =
[&](group_joint_scan_data<T> &per_op) {
[&](group_joint_scan_data<Ptr, T> &per_op) {
SIMSYCL_CHECK(per_op.first == first);
SIMSYCL_CHECK(per_op.last == last);
SIMSYCL_CHECK(per_op.init == init);
Expand All @@ -355,10 +340,7 @@ T group_scan_impl(G g, group_operation_id op_id, T x, std::optional<T> init, Op
group_operation_spec{//
.init =
[&] {
// TODO there was a strange error when using make_unique here - investigate if necessary after
// group_operations CTS test is fixed
auto per_op
= std::unique_ptr<group_scan_data<T>>(new group_scan_data<T>(g.get_local_range().size(), init));
auto per_op = std::make_unique<group_scan_data<T>>(g.get_local_range().size(), init);
per_op->values[g.get_local_linear_id()] = x;
return per_op;
},
Expand Down
32 changes: 14 additions & 18 deletions include/simsycl/sycl/group_algorithms.hh
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@ namespace simsycl::sycl {

// any_of

template<Group G, Pointer Ptr, typename Predicate, typename T = std::remove_pointer_t<Ptr>>
requires std::predicate<Predicate, T>
template<Group G, Pointer Ptr, typename Predicate>
requires std::predicate<Predicate, std::remove_pointer_t<Ptr>>
bool joint_any_of(G g, Ptr first, Ptr last, Predicate pred) {
// approach: perform the operation sequentially on all work items, confirm that they compute the same result
// (this is the closest we can easily get to verifying the standard requirement that
Expand All @@ -20,9 +20,9 @@ bool joint_any_of(G g, Ptr first, Ptr last, Predicate pred) {

detail::perform_group_operation(g, detail::group_operation_id::joint_any_of,
detail::group_operation_spec{//
.init = [&]() { return std::make_unique<detail::group_joint_bool_op_data<T>>(first, last, result); },
.init = [&]() { return std::make_unique<detail::group_joint_bool_op_data<Ptr>>(first, last, result); },
.reached =
[&](detail::group_joint_bool_op_data<T> &per_op) {
[&](detail::group_joint_bool_op_data<Ptr> &per_op) {
SIMSYCL_CHECK(per_op.first == first);
SIMSYCL_CHECK(per_op.last == last);
SIMSYCL_CHECK(per_op.result == result);
Expand Down Expand Up @@ -56,16 +56,16 @@ bool any_of_group(G g, bool pred) {

// all_of

template<Group G, Pointer Ptr, typename Predicate, typename T = std::remove_pointer_t<Ptr>>
requires std::predicate<Predicate, T>
template<Group G, Pointer Ptr, typename Predicate>
requires std::predicate<Predicate, std::remove_pointer_t<Ptr>>
bool joint_all_of(G g, Ptr first, Ptr last, Predicate pred) {
bool result = true;
for(auto start = first; result && start != last; ++start) { result = pred(*start); }
detail::perform_group_operation(g, detail::group_operation_id::joint_all_of,
detail::group_operation_spec{//
.init = [&]() { return std::make_unique<detail::group_joint_bool_op_data<T>>(first, last, result); },
.init = [&]() { return std::make_unique<detail::group_joint_bool_op_data<Ptr>>(first, last, result); },
.reached =
[&](detail::group_joint_bool_op_data<T> &per_op) {
[&](detail::group_joint_bool_op_data<Ptr> &per_op) {
SIMSYCL_CHECK(per_op.first == first);
SIMSYCL_CHECK(per_op.last == last);
SIMSYCL_CHECK(per_op.result == result);
Expand Down Expand Up @@ -98,16 +98,16 @@ bool all_of_group(G g, bool pred) {

// none_of

template<Group G, Pointer Ptr, typename Predicate, typename T = std::remove_pointer_t<Ptr>>
requires std::predicate<Predicate, T>
template<Group G, Pointer Ptr, typename Predicate>
requires std::predicate<Predicate, std::remove_pointer_t<Ptr>>
bool joint_none_of(G g, Ptr first, Ptr last, Predicate pred) {
bool result = true;
for(auto start = first; result && start != last; ++start) { result = !pred(*start); }
detail::perform_group_operation(g, detail::group_operation_id::joint_none_of,
detail::group_operation_spec{//
.init = [&]() { return std::make_unique<detail::group_joint_bool_op_data<T>>(first, last, result); },
.init = [&]() { return std::make_unique<detail::group_joint_bool_op_data<Ptr>>(first, last, result); },
.reached =
[&](detail::group_joint_bool_op_data<T> &per_op) {
[&](detail::group_joint_bool_op_data<Ptr> &per_op) {
SIMSYCL_CHECK(per_op.first == first);
SIMSYCL_CHECK(per_op.last == last);
SIMSYCL_CHECK(per_op.result == result);
Expand Down Expand Up @@ -194,8 +194,8 @@ T shift_group_right(G g, T x, typename G::linear_id_type delta = 1) {
// permute

template<SubGroup G, TriviallyCopyable T>
T permute_group(G g, T x, typename G::linear_id_type mask) {
return detail::perform_group_operation(g, detail::group_operation_id::permute,
T permute_group_by_xor(G g, T x, typename G::linear_id_type mask) {
return detail::perform_group_operation(g, detail::group_operation_id::permute_by_xor,
detail::group_operation_spec{//
.init =
[&]() {
Expand All @@ -220,10 +220,6 @@ T permute_group(G g, T x, typename G::linear_id_type mask) {
}});
}

template<typename Group, typename T>
T permute_group_by_xor(Group g, T x, typename Group::linear_id_type mask); // TODO


// select

template<SubGroup G, TriviallyCopyable T>
Expand Down
2 changes: 1 addition & 1 deletion include/simsycl/sycl/group_functions.hh
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ T group_broadcast(G g, T x, typename G::linear_id_type local_linear_id) {
template<Group G, TriviallyCopyable T>
T group_broadcast(G g, T x, typename G::id_type local_id) {
SIMSYCL_CHECK(all_true(local_id < id(g.get_local_range())));
group_broadcast(g, x, detail::get_linear_index(g.get_local_range(), local_id));
return group_broadcast(g, x, detail::get_linear_index(g.get_local_range(), local_id));
}

template<Group G>
Expand Down
2 changes: 1 addition & 1 deletion src/simsycl/group_operation_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ const char *group_operation_id_to_string(group_operation_id id) {
case group_operation_id::none_of: return "none_of";
case group_operation_id::shift_left: return "shift_left";
case group_operation_id::shift_right: return "shift_right";
case group_operation_id::permute: return "permute";
case group_operation_id::permute_by_xor: return "permute";
case group_operation_id::select: return "select";
case group_operation_id::joint_reduce: return "joint_reduce";
case group_operation_id::reduce: return "reduce";
Expand Down
8 changes: 4 additions & 4 deletions test/group_op_tests.cc
Original file line number Diff line number Diff line change
Expand Up @@ -399,7 +399,7 @@ TEST_CASE("Group shift operation behave as expected", "[group_op][shift]") {
}
}

TEST_CASE("Group permute behaves as expected", "[group_op][permute]") {
TEST_CASE("Group permute by XOR behaves as expected", "[group_op][permute]") {
REPEAT_FOR_ALL_SCHEDULES

int inputs[4] = {1, 2, 3, 4};
Expand All @@ -408,14 +408,14 @@ TEST_CASE("Group permute behaves as expected", "[group_op][permute]") {
sycl::queue().submit([&inputs](sycl::handler &cgh) {
cgh.parallel_for(sycl::nd_range<1>{8, 8}, [&inputs](sycl::nd_item<1> it) {
auto id = it.get_sub_group().get_local_linear_id();
auto val = sycl::permute_group(it.get_sub_group(), inputs[id], 0b0101u);
auto val = sycl::permute_group_by_xor(it.get_sub_group(), inputs[id], 0b0101u);
auto target = id ^ 0b0101u;
if(target < 4) {
CHECK(val == inputs[target]);
} else {
CHECK(val == detail::unspecified<int>());
}
check_group_op_sequence(it.get_sub_group(), {detail::group_operation_id::permute});
check_group_op_sequence(it.get_sub_group(), {detail::group_operation_id::permute_by_xor});
});
});
}
Expand Down Expand Up @@ -670,7 +670,7 @@ TEST_CASE("Mismatched parameters for group ops are reported", "[check][group_op]
Catch::Matchers::ContainsSubstring("group shift delta mismatch"));
REQUIRE_THROWS_WITH(sycl::queue{}.submit([&](sycl::handler &cgh) {
cgh.parallel_for(sycl::nd_range<1>{2, 2},
[](sycl::nd_item<1> it) { sycl::permute_group(it.get_sub_group(), 0, it.get_local_linear_id()); });
[](sycl::nd_item<1> it) { sycl::permute_group_by_xor(it.get_sub_group(), 0, it.get_local_linear_id()); });
}),
Catch::Matchers::ContainsSubstring("group permute mask mismatch"));
}

0 comments on commit 35a3fd5

Please sign in to comment.