Skip to content

Commit

Permalink
use buffers inside blocks
Browse files Browse the repository at this point in the history
  • Loading branch information
KornevNikita committed Feb 3, 2025
1 parent 7a76db8 commit 335d86e
Showing 1 changed file with 46 additions and 46 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,6 @@ static void test_launch_reduce_impl() {
constexpr size_t N = get_N<Dims>(Size);
sycl::queue q;
int sumResult = 0;
sycl::buffer<int> sumBuf{&sumResult, 1};
int* sumPtr = sycl::malloc_shared<int>(1, q);
sumPtr[0] = 0;

Expand All @@ -110,13 +109,14 @@ static void test_launch_reduce_impl() {
sycl_cts::util::get_cts_object::range<Dims>::get(Size, Size, Size);
int expected_res = (N - 1) * N / 2;
{
sycl::buffer<int> sumBuf{&sumResult, 1};
q.submit([&](sycl::handler& h) {
sycl::khr::launch_reduce(h, r, task,
sycl::reduction(sumBuf, h, sycl::plus<>()));
});
q.wait();
CHECK(sumBuf.get_host_access()[0] == expected_res);
}
CHECK(sumResult == expected_res);

{
sycl::khr::launch_reduce(q, r, task,
sycl::reduction(sumPtr, sycl::plus<>()));
Expand Down Expand Up @@ -182,7 +182,6 @@ static void test_launch_grouped_reduce_impl() {
constexpr size_t N = get_N<Dims>(Size);
sycl::queue q;
int sumResult = 0;
sycl::buffer<int> sumBuf{&sumResult, 1};
int* sumPtr = sycl::malloc_shared<int>(1, q);
sumPtr[0] = 0;

Expand All @@ -196,14 +195,14 @@ static void test_launch_grouped_reduce_impl() {
Size / 2, Size / 2, Size / 2);
int expected_res = (N - 1) * N / 2;
{
sycl::buffer<int> sumBuf{&sumResult, 1};
q.submit([&](sycl::handler& h) {
sycl::khr::launch_grouped_reduce(
h, r_glob, r_loc, task, sycl::reduction(sumBuf, h, sycl::plus<>()));
});
q.wait();
CHECK(sumBuf.get_host_access()[0] == expected_res);
sumBuf.get_host_access()[0] = 0;
}
CHECK(sumResult == expected_res);

{
sycl::khr::launch_grouped_reduce(q, r_glob, r_loc, task,
sycl::reduction(sumPtr, sycl::plus<>()));
Expand Down Expand Up @@ -327,7 +326,6 @@ static void test_copy_accessors_host_to_device_impl() {
accT acc(buf);
sycl::khr::copy(q, src, acc);
}
q.wait();
}

for (size_t i = 0; i < N; ++i) CHECK(src[i] == dst[i]);
Expand Down Expand Up @@ -361,18 +359,19 @@ static void test_copy_accessors_device_to_host_impl() {
const auto test = [&](auto& dst, bool use_handler) {
T src[N] = {0};
std::iota(&src[0], &src[0] + N, 0);
sycl::buffer<T, 1> buf(src, sycl::range<1>(N));
{
sycl::buffer<T, 1> buf(src, sycl::range<1>(N));

if (use_handler) {
q.submit([&](sycl::handler& h) {
accT acc(buf, h, sycl::range<1>(N));
sycl::khr::copy(h, acc, dst);
});
} else {
accT acc(buf);
sycl::khr::copy(q, acc, dst);
if (use_handler) {
q.submit([&](sycl::handler& h) {
accT acc(buf, h, sycl::range<1>(N));
sycl::khr::copy(h, acc, dst);
});
} else {
accT acc(buf);
sycl::khr::copy(q, acc, dst);
}
}
q.wait();

for (size_t i = 0; i < N; ++i) CHECK(src[i] == dst[i]);
};
Expand Down Expand Up @@ -405,24 +404,25 @@ static void test_copy_accessors_device_to_device_impl() {
sycl::queue q;
auto test_copy = [&](bool use_handler) {
T dst[N] = {0};
sycl::buffer<T, 1> buf_src(src, sycl::range<1>(N));
sycl::buffer<T, 1> buf_dst(dst, sycl::range<1>(N));
{
sycl::buffer<T, 1> buf_src(src, sycl::range<1>(N));
sycl::buffer<T, 1> buf_dst(dst, sycl::range<1>(N));

if (use_handler) {
q.submit([&](sycl::handler& h) {
acc_src_T acc_src(buf_src, h, sycl::range<1>(N));
acc_dst_T acc_dst(buf_dst, h, sycl::range<1>(N));
sycl::khr::copy(h, acc_src, acc_dst);
});
} else {
acc_src_T acc_src(buf_src, sycl::range<1>(N));
acc_dst_T acc_dst(buf_dst, sycl::range<1>(N));
sycl::khr::copy(q, acc_src, acc_dst);
if (use_handler) {
q.submit([&](sycl::handler& h) {
acc_src_T acc_src(buf_src, h, sycl::range<1>(N));
acc_dst_T acc_dst(buf_dst, h, sycl::range<1>(N));
sycl::khr::copy(h, acc_src, acc_dst);
});
} else {
acc_src_T acc_src(buf_src, sycl::range<1>(N));
acc_dst_T acc_dst(buf_dst, sycl::range<1>(N));
sycl::khr::copy(q, acc_src, acc_dst);
}
}
q.wait();

for (size_t i = 0; i < N; ++i) {
// CHECK(src[i] == dst[i]);
CHECK(src[i] == dst[i]);
}
};

Expand Down Expand Up @@ -492,7 +492,6 @@ static void test_fill_impl() {
accT acc(buf, sycl::range<1>(N));
sycl::khr::fill(q, acc, val);
}
q.wait();
}

for (int i = 0; i < N; ++i) CHECK(dst[i] == val);
Expand All @@ -519,24 +518,25 @@ static void test_update_host_impl() {

auto test_buffer = [&](bool use_handler) {
T data[N] = {0};
sycl::buffer<T, 1> buf(data, sycl::range<1>(N));

q.submit([&](sycl::handler& h) {
accT acc(buf, h, sycl::range<1>(N));
h.parallel_for(sycl::range<1>{N},
[=](sycl::id<1> idx) { acc[idx] = idx; });
});
{
sycl::buffer<T, 1> buf(data, sycl::range<1>(N));

if (use_handler) {
q.submit([&](sycl::handler& h) {
accT acc(buf, h, sycl::range<1>(N));
sycl::khr::update_host(h, acc);
h.parallel_for(sycl::range<1>{N},
[=](sycl::id<1> idx) { acc[idx] = idx; });
});
} else {
accT acc(buf, sycl::range<1>(N));
sycl::khr::update_host(q, acc);

if (use_handler) {
q.submit([&](sycl::handler& h) {
accT acc(buf, h, sycl::range<1>(N));
sycl::khr::update_host(h, acc);
});
} else {
accT acc(buf, sycl::range<1>(N));
sycl::khr::update_host(q, acc);
}
}
q.wait();

for (size_t i = 0; i < N; ++i) CHECK(data[i] == i);
};
Expand Down

0 comments on commit 335d86e

Please sign in to comment.