From 058ee955e985d5c9d4941dc265b6a4c0435dd21b Mon Sep 17 00:00:00 2001 From: JackAKirk Date: Mon, 14 Oct 2024 09:47:28 +0100 Subject: [PATCH] [rocfft][cufft] DFT update host task to use native command (#578) Signed-off-by: JackAKirk Co-authored-by: Hugh Bird Co-authored-by: Rafal Bielski Co-authored-by: Romain Biessy --- src/dft/backends/cufft/backward.cpp | 9 ++-- src/dft/backends/cufft/execute_helper.hpp | 12 +++-- src/dft/backends/cufft/forward.cpp | 9 ++-- src/dft/backends/rocfft/backward.cpp | 42 ++++++++--------- src/dft/backends/rocfft/execute_helper.hpp | 24 ++++++---- src/dft/backends/rocfft/forward.cpp | 41 +++++++---------- src/dft/execute_helper_generic.hpp | 53 ++++++++++++++++++++++ 7 files changed, 121 insertions(+), 69 deletions(-) create mode 100644 src/dft/execute_helper_generic.hpp diff --git a/src/dft/backends/cufft/backward.cpp b/src/dft/backends/cufft/backward.cpp index aea9f232f..80e475991 100644 --- a/src/dft/backends/cufft/backward.cpp +++ b/src/dft/backends/cufft/backward.cpp @@ -30,6 +30,7 @@ #include "oneapi/mkl/dft/types.hpp" #include "execute_helper.hpp" +#include "../../execute_helper_generic.hpp" #include @@ -71,7 +72,7 @@ ONEMKL_EXPORT void compute_backward(descriptor_type &desc, auto inout_acc = inout.template get_access(cgh); commit->add_buffer_workspace_dependency_if_rqd("compute_backward", cgh); - cgh.host_task([=](sycl::interop_handle ih) { + dft::detail::fft_enqueue_task(cgh, [=](sycl::interop_handle ih) { auto stream = detail::setup_stream(func_name, ih, plan); auto inout_native = reinterpret_cast *>( @@ -117,7 +118,7 @@ ONEMKL_EXPORT void compute_backward(descriptor_type &desc, auto out_acc = out.template get_access(cgh); commit->add_buffer_workspace_dependency_if_rqd("compute_backward", cgh); - cgh.host_task([=](sycl::interop_handle ih) { + dft::detail::fft_enqueue_task(cgh, [=](sycl::interop_handle ih) { auto stream = detail::setup_stream(func_name, ih, plan); auto in_native = reinterpret_cast( @@ -171,7 +172,7 @@ ONEMKL_EXPORT sycl::event compute_backward(descriptor_type &desc, fwddepend_on_last_usm_workspace_event_if_rqd(cgh); - cgh.host_task([=](sycl::interop_handle ih) { + dft::detail::fft_enqueue_task(cgh, [=](sycl::interop_handle ih) { auto stream = detail::setup_stream(func_name, ih, plan); detail::cufft_execute>( @@ -217,7 +218,7 @@ ONEMKL_EXPORT sycl::event compute_backward(descriptor_type &desc, bwddepend_on_last_usm_workspace_event_if_rqd(cgh); - cgh.host_task([=](sycl::interop_handle ih) { + dft::detail::fft_enqueue_task(cgh, [=](sycl::interop_handle ih) { auto stream = detail::setup_stream(func_name, ih, plan); detail::cufft_execute>( diff --git a/src/dft/backends/cufft/execute_helper.hpp b/src/dft/backends/cufft/execute_helper.hpp index 776f0f254..7b7d946db 100644 --- a/src/dft/backends/cufft/execute_helper.hpp +++ b/src/dft/backends/cufft/execute_helper.hpp @@ -17,8 +17,8 @@ * SPDX-License-Identifier: Apache-2.0 *******************************************************************************/ -#ifndef _ONEMKL_DFT_SRC_CUFFT_EXECUTE_HPP_ -#define _ONEMKL_DFT_SRC_CUFFT_EXECUTE_HPP_ +#ifndef _ONEMKL_DFT_SRC_EXECUTE_HELPER_CUFFT_HPP_ +#define _ONEMKL_DFT_SRC_EXECUTE_HELPER_CUFFT_HPP_ #if __has_include() #include @@ -125,12 +125,16 @@ void cufft_execute(const std::string &func, CUstream stream, cufftHandle plan, v } } } - +#ifndef SYCL_EXT_ONEAPI_ENQUEUE_NATIVE_COMMAND + // If not using the enqueue native extension, the host task must wait on the + // asynchronous operation to complete. Otherwise it report the operation + // as complete early. auto result = cuStreamSynchronize(stream); if (result != CUDA_SUCCESS) { throw oneapi::mkl::exception("dft/backends/cufft", func, "cuStreamSynchronize returned " + std::to_string(result)); } +#endif } inline CUstream setup_stream(const std::string &func, sycl::interop_handle ih, cufftHandle plan) { @@ -145,4 +149,4 @@ inline CUstream setup_stream(const std::string &func, sycl::interop_handle ih, c } // namespace oneapi::mkl::dft::cufft::detail -#endif +#endif // _ONEMKL_DFT_SRC_EXECUTE_HELPER_CUFFT_HPP_ diff --git a/src/dft/backends/cufft/forward.cpp b/src/dft/backends/cufft/forward.cpp index fb323c085..7cf73976d 100644 --- a/src/dft/backends/cufft/forward.cpp +++ b/src/dft/backends/cufft/forward.cpp @@ -31,6 +31,7 @@ #include "oneapi/mkl/dft/types.hpp" #include "execute_helper.hpp" +#include "../../execute_helper_generic.hpp" #include @@ -74,7 +75,7 @@ ONEMKL_EXPORT void compute_forward(descriptor_type &desc, auto inout_acc = inout.template get_access(cgh); commit->add_buffer_workspace_dependency_if_rqd("compute_forward", cgh); - cgh.host_task([=](sycl::interop_handle ih) { + dft::detail::fft_enqueue_task(cgh, [=](sycl::interop_handle ih) { auto stream = detail::setup_stream(func_name, ih, plan); auto inout_native = reinterpret_cast *>( @@ -119,7 +120,7 @@ ONEMKL_EXPORT void compute_forward(descriptor_type &desc, sycl::buffer(cgh); commit->add_buffer_workspace_dependency_if_rqd("compute_forward", cgh); - cgh.host_task([=](sycl::interop_handle ih) { + dft::detail::fft_enqueue_task(cgh, [=](sycl::interop_handle ih) { auto stream = detail::setup_stream(func_name, ih, plan); auto in_native = reinterpret_cast( @@ -173,7 +174,7 @@ ONEMKL_EXPORT sycl::event compute_forward(descriptor_type &desc, fwddepend_on_last_usm_workspace_event_if_rqd(cgh); - cgh.host_task([=](sycl::interop_handle ih) { + dft::detail::fft_enqueue_task(cgh, [=](sycl::interop_handle ih) { auto stream = detail::setup_stream(func_name, ih, plan); detail::cufft_execute>( @@ -219,7 +220,7 @@ ONEMKL_EXPORT sycl::event compute_forward(descriptor_type &desc, fwddepend_on_last_usm_workspace_event_if_rqd(cgh); - cgh.host_task([=](sycl::interop_handle ih) { + dft::detail::fft_enqueue_task(cgh, [=](sycl::interop_handle ih) { auto stream = detail::setup_stream(func_name, ih, plan); detail::cufft_execute>( diff --git a/src/dft/backends/rocfft/backward.cpp b/src/dft/backends/rocfft/backward.cpp index 5ff0e2a1f..e76437ee2 100644 --- a/src/dft/backends/rocfft/backward.cpp +++ b/src/dft/backends/rocfft/backward.cpp @@ -29,6 +29,7 @@ #include "oneapi/mkl/dft/descriptor.hpp" #include "execute_helper.hpp" +#include "../../execute_helper_generic.hpp" #include "rocfft_handle.hpp" #include @@ -78,14 +79,13 @@ ONEMKL_EXPORT void compute_backward(descriptor_type &desc, auto inout_acc = inout.template get_access(cgh); commit->add_buffer_workspace_dependency_if_rqd("compute_backward", cgh); - cgh.host_task([=](sycl::interop_handle ih) { + dft::detail::fft_enqueue_task(cgh, [=](sycl::interop_handle ih) { auto stream = detail::setup_stream(func_name, ih, info); auto inout_native = reinterpret_cast( reinterpret_cast *>(detail::native_mem(ih, inout_acc)) + offsets[0]); - detail::execute_checked(func_name, plan, &inout_native, nullptr, info); - detail::sync_checked(func_name, stream); + detail::execute_checked(func_name, stream, plan, &inout_native, nullptr, info); }); }); } @@ -113,7 +113,7 @@ ONEMKL_EXPORT void compute_backward(descriptor_type &desc, auto inout_im_acc = inout_im.template get_access(cgh); commit->add_buffer_workspace_dependency_if_rqd("compute_backward", cgh); - cgh.host_task([=](sycl::interop_handle ih) { + dft::detail::fft_enqueue_task(cgh, [=](sycl::interop_handle ih) { auto stream = detail::setup_stream(func_name, ih, info); std::array inout_native{ @@ -124,8 +124,7 @@ ONEMKL_EXPORT void compute_backward(descriptor_type &desc, detail::native_mem(ih, inout_im_acc)) + offsets[0]) }; - detail::execute_checked(func_name, plan, inout_native.data(), nullptr, info); - detail::sync_checked(func_name, stream); + detail::execute_checked(func_name, stream, plan, inout_native.data(), nullptr, info); }); }); } @@ -148,7 +147,7 @@ ONEMKL_EXPORT void compute_backward(descriptor_type &desc, auto out_acc = out.template get_access(cgh); commit->add_buffer_workspace_dependency_if_rqd("compute_backward", cgh); - cgh.host_task([=](sycl::interop_handle ih) { + dft::detail::fft_enqueue_task(cgh, [=](sycl::interop_handle ih) { const std::string func_name = "compute_backward(desc, in, out)"; auto stream = detail::setup_stream(func_name, ih, info); @@ -158,8 +157,7 @@ ONEMKL_EXPORT void compute_backward(descriptor_type &desc, auto out_native = reinterpret_cast( reinterpret_cast *>(detail::native_mem(ih, out_acc)) + offsets[1]); - detail::execute_checked(func_name, plan, &in_native, &out_native, info); - detail::sync_checked(func_name, stream); + detail::execute_checked(func_name, stream, plan, &in_native, &out_native, info); }); }); } @@ -184,7 +182,7 @@ ONEMKL_EXPORT void compute_backward(descriptor_type &desc, auto out_im_acc = out_im.template get_access(cgh); commit->add_buffer_workspace_dependency_if_rqd("compute_backward", cgh); - cgh.host_task([=](sycl::interop_handle ih) { + dft::detail::fft_enqueue_task(cgh, [=](sycl::interop_handle ih) { const std::string func_name = "compute_backward(desc, in_re, in_im, out_re, out_im)"; auto stream = detail::setup_stream(func_name, ih, info); @@ -204,8 +202,7 @@ ONEMKL_EXPORT void compute_backward(descriptor_type &desc, detail::native_mem(ih, out_im_acc)) + offsets[1]) }; - detail::execute_checked(func_name, plan, in_native.data(), out_native.data(), info); - detail::sync_checked(func_name, stream); + detail::execute_checked(func_name, stream, plan, in_native.data(), out_native.data(), info); }); }); } @@ -239,12 +236,11 @@ ONEMKL_EXPORT sycl::event compute_backward(descriptor_type &desc, fwddepend_on_last_usm_workspace_event_if_rqd(cgh); - cgh.host_task([=](sycl::interop_handle ih) { + dft::detail::fft_enqueue_task(cgh, [=](sycl::interop_handle ih) { auto stream = detail::setup_stream(func_name, ih, info); void *inout_ptr = inout; - detail::execute_checked(func_name, plan, &inout_ptr, nullptr, info); - detail::sync_checked(func_name, stream); + detail::execute_checked(func_name, stream, plan, &inout_ptr, nullptr, info); }); }); commit->set_last_usm_workspace_event_if_rqd(sycl_event); @@ -273,12 +269,12 @@ ONEMKL_EXPORT sycl::event compute_backward(descriptor_type &desc, scalardepend_on_last_usm_workspace_event_if_rqd(cgh); - cgh.host_task([=](sycl::interop_handle ih) { + dft::detail::fft_enqueue_task(cgh, [=](sycl::interop_handle ih) { auto stream = detail::setup_stream(func_name, ih, info); std::array inout_native{ inout_re + offsets[0], inout_im + offsets[0] }; - detail::execute_checked(func_name, plan, inout_native.data(), nullptr, info); - detail::sync_checked(func_name, stream); + detail::execute_checked(func_name, stream, plan, inout_native.data(), nullptr, info); + }); }); commit->set_last_usm_workspace_event_if_rqd(sycl_event); @@ -305,14 +301,13 @@ ONEMKL_EXPORT sycl::event compute_backward(descriptor_type &desc, bwddepend_on_last_usm_workspace_event_if_rqd(cgh); - cgh.host_task([=](sycl::interop_handle ih) { + dft::detail::fft_enqueue_task(cgh, [=](sycl::interop_handle ih) { const std::string func_name = "compute_backward(desc, in, out, deps)"; auto stream = detail::setup_stream(func_name, ih, info); void *in_ptr = in; void *out_ptr = out; - detail::execute_checked(func_name, plan, &in_ptr, &out_ptr, info); - detail::sync_checked(func_name, stream); + detail::execute_checked(func_name, stream, plan, &in_ptr, &out_ptr, info); }); }); commit->set_last_usm_workspace_event_if_rqd(sycl_event); @@ -336,15 +331,14 @@ ONEMKL_EXPORT sycl::event compute_backward(descriptor_type &desc, scalardepend_on_last_usm_workspace_event_if_rqd(cgh); - cgh.host_task([=](sycl::interop_handle ih) { + dft::detail::fft_enqueue_task(cgh, [=](sycl::interop_handle ih) { const std::string func_name = "compute_backward(desc, in_re, in_im, out_re, out_im, deps)"; auto stream = detail::setup_stream(func_name, ih, info); std::array in_native{ in_re + offsets[0], in_im + offsets[0] }; std::array out_native{ out_re + offsets[1], out_im + offsets[1] }; - detail::execute_checked(func_name, plan, in_native.data(), out_native.data(), info); - detail::sync_checked(func_name, stream); + detail::execute_checked(func_name, stream, plan, in_native.data(), out_native.data(), info); }); }); commit->set_last_usm_workspace_event_if_rqd(sycl_event); diff --git a/src/dft/backends/rocfft/execute_helper.hpp b/src/dft/backends/rocfft/execute_helper.hpp index 4dff6831d..78663a090 100644 --- a/src/dft/backends/rocfft/execute_helper.hpp +++ b/src/dft/backends/rocfft/execute_helper.hpp @@ -17,8 +17,8 @@ * SPDX-License-Identifier: Apache-2.0 *******************************************************************************/ -#ifndef _ONEMKL_DFT_SRC_ROCFFT_EXECUTE_HELPER_HPP_ -#define _ONEMKL_DFT_SRC_ROCFFT_EXECUTE_HELPER_HPP_ +#ifndef _ONEMKL_DFT_SRC_EXECUTE_HELPER_ROCFFT_HPP_ +#define _ONEMKL_DFT_SRC_EXECUTE_HELPER_ROCFFT_HPP_ #if __has_include() #include @@ -76,22 +76,28 @@ inline hipStream_t setup_stream(const std::string &func, sycl::interop_handle &i } inline void sync_checked(const std::string &func, hipStream_t stream) { - auto result = hipStreamSynchronize(stream); - if (result != hipSuccess) { - throw oneapi::mkl::exception("dft/backends/rocfft", func, - "hipStreamSynchronize returned " + std::to_string(result)); - } + auto result = hipStreamSynchronize(stream); + if (result != hipSuccess) { + throw oneapi::mkl::exception("dft/backends/rocfft", func, + "hipStreamSynchronize returned " + std::to_string(result)); + } } -inline void execute_checked(const std::string &func, const rocfft_plan plan, void *in_buffer[], +inline void execute_checked(const std::string &func, hipStream_t stream, const rocfft_plan plan, void *in_buffer[], void *out_buffer[], rocfft_execution_info info) { auto result = rocfft_execute(plan, in_buffer, out_buffer, info); if (result != rocfft_status_success) { throw oneapi::mkl::exception("dft/backends/rocfft", func, "rocfft_execute returned " + std::to_string(result)); } +#ifndef SYCL_EXT_ONEAPI_ENQUEUE_NATIVE_COMMAND + // If not using equeue native extension, the host task must wait on the + // asynchronous operation to complete. Otherwise it report the operation + // as complete early. + sync_checked(func, stream); +#endif } } // namespace oneapi::mkl::dft::rocfft::detail -#endif +#endif // _ONEMKL_DFT_SRC_EXECUTE_HELPER_ROCFFT_HPP_ diff --git a/src/dft/backends/rocfft/forward.cpp b/src/dft/backends/rocfft/forward.cpp index 70d3d0f97..d9a576720 100644 --- a/src/dft/backends/rocfft/forward.cpp +++ b/src/dft/backends/rocfft/forward.cpp @@ -30,6 +30,7 @@ #include "oneapi/mkl/dft/descriptor.hpp" #include "execute_helper.hpp" +#include "../../execute_helper_generic.hpp" #include "rocfft_handle.hpp" #include @@ -81,14 +82,13 @@ ONEMKL_EXPORT void compute_forward(descriptor_type &desc, auto inout_acc = inout.template get_access(cgh); commit->add_buffer_workspace_dependency_if_rqd("compute_forward", cgh); - cgh.host_task([=](sycl::interop_handle ih) { + dft::detail::fft_enqueue_task(cgh, [=](sycl::interop_handle ih) { auto stream = detail::setup_stream(func_name, ih, info); auto inout_native = reinterpret_cast( reinterpret_cast *>(detail::native_mem(ih, inout_acc)) + offsets[0]); - detail::execute_checked(func_name, plan, &inout_native, nullptr, info); - detail::sync_checked(func_name, stream); + detail::execute_checked(func_name, stream, plan, &inout_native, nullptr, info); }); }); } @@ -116,7 +116,7 @@ ONEMKL_EXPORT void compute_forward(descriptor_type &desc, auto inout_im_acc = inout_im.template get_access(cgh); commit->add_buffer_workspace_dependency_if_rqd("compute_forward", cgh); - cgh.host_task([=](sycl::interop_handle ih) { + dft::detail::fft_enqueue_task(cgh, [=](sycl::interop_handle ih) { auto stream = detail::setup_stream(func_name, ih, info); std::array inout_native{ @@ -127,8 +127,7 @@ ONEMKL_EXPORT void compute_forward(descriptor_type &desc, detail::native_mem(ih, inout_im_acc)) + offsets[0]) }; - detail::execute_checked(func_name, plan, inout_native.data(), nullptr, info); - detail::sync_checked(func_name, stream); + detail::execute_checked(func_name, stream, plan, inout_native.data(), nullptr, info); }); }); } @@ -150,7 +149,7 @@ ONEMKL_EXPORT void compute_forward(descriptor_type &desc, sycl::buffer(cgh); commit->add_buffer_workspace_dependency_if_rqd("compute_forward", cgh); - cgh.host_task([=](sycl::interop_handle ih) { + dft::detail::fft_enqueue_task(cgh, [=](sycl::interop_handle ih) { const std::string func_name = "compute_forward(desc, in, out)"; auto stream = detail::setup_stream(func_name, ih, info); @@ -160,8 +159,7 @@ ONEMKL_EXPORT void compute_forward(descriptor_type &desc, sycl::buffer( reinterpret_cast *>(detail::native_mem(ih, out_acc)) + offsets[1]); - detail::execute_checked(func_name, plan, &in_native, &out_native, info); - detail::sync_checked(func_name, stream); + detail::execute_checked(func_name, stream, plan, &in_native, &out_native, info); }); }); } @@ -186,7 +184,7 @@ ONEMKL_EXPORT void compute_forward(descriptor_type &desc, auto out_im_acc = out_im.template get_access(cgh); commit->add_buffer_workspace_dependency_if_rqd("compute_forward", cgh); - cgh.host_task([=](sycl::interop_handle ih) { + dft::detail::fft_enqueue_task(cgh, [=](sycl::interop_handle ih) { const std::string func_name = "compute_forward(desc, in_re, in_im, out_re, out_im)"; auto stream = detail::setup_stream(func_name, ih, info); @@ -206,8 +204,7 @@ ONEMKL_EXPORT void compute_forward(descriptor_type &desc, detail::native_mem(ih, out_im_acc)) + offsets[1]) }; - detail::execute_checked(func_name, plan, in_native.data(), out_native.data(), info); - detail::sync_checked(func_name, stream); + detail::execute_checked(func_name, stream, plan, in_native.data(), out_native.data(), info); }); }); } @@ -241,12 +238,11 @@ ONEMKL_EXPORT sycl::event compute_forward(descriptor_type &desc, fwddepend_on_last_usm_workspace_event_if_rqd(cgh); - cgh.host_task([=](sycl::interop_handle ih) { + dft::detail::fft_enqueue_task(cgh, [=](sycl::interop_handle ih) { auto stream = detail::setup_stream(func_name, ih, info); void *inout_ptr = inout; - detail::execute_checked(func_name, plan, &inout_ptr, nullptr, info); - detail::sync_checked(func_name, stream); + detail::execute_checked(func_name, stream, plan, &inout_ptr, nullptr, info); }); }); commit->set_last_usm_workspace_event_if_rqd(sycl_event); @@ -274,12 +270,11 @@ ONEMKL_EXPORT sycl::event compute_forward(descriptor_type &desc, scalardepend_on_last_usm_workspace_event_if_rqd(cgh); - cgh.host_task([=](sycl::interop_handle ih) { + dft::detail::fft_enqueue_task(cgh, [=](sycl::interop_handle ih) { auto stream = detail::setup_stream(func_name, ih, info); std::array inout_native{ inout_re + offsets[0], inout_im + offsets[0] }; - detail::execute_checked(func_name, plan, inout_native.data(), nullptr, info); - detail::sync_checked(func_name, stream); + detail::execute_checked(func_name, stream, plan, inout_native.data(), nullptr, info); }); }); commit->set_last_usm_workspace_event_if_rqd(sycl_event); @@ -306,14 +301,13 @@ ONEMKL_EXPORT sycl::event compute_forward(descriptor_type &desc, fwddepend_on_last_usm_workspace_event_if_rqd(cgh); - cgh.host_task([=](sycl::interop_handle ih) { + dft::detail::fft_enqueue_task(cgh, [=](sycl::interop_handle ih) { const std::string func_name = "compute_forward(desc, in, out, deps)"; auto stream = detail::setup_stream(func_name, ih, info); void *in_ptr = in; void *out_ptr = out; - detail::execute_checked(func_name, plan, &in_ptr, &out_ptr, info); - detail::sync_checked(func_name, stream); + detail::execute_checked(func_name, stream, plan, &in_ptr, &out_ptr, info); }); }); commit->set_last_usm_workspace_event_if_rqd(sycl_event); @@ -337,15 +331,14 @@ ONEMKL_EXPORT sycl::event compute_forward(descriptor_type &desc, scalardepend_on_last_usm_workspace_event_if_rqd(cgh); - cgh.host_task([=](sycl::interop_handle ih) { + dft::detail::fft_enqueue_task(cgh, [=](sycl::interop_handle ih) { const std::string func_name = "compute_forward(desc, in_re, in_im, out_re, out_im, deps)"; auto stream = detail::setup_stream(func_name, ih, info); std::array in_native{ in_re + offsets[0], in_im + offsets[0] }; std::array out_native{ out_re + offsets[1], out_im + offsets[1] }; - detail::execute_checked(func_name, plan, in_native.data(), out_native.data(), info); - detail::sync_checked(func_name, stream); + detail::execute_checked(func_name, stream, plan, in_native.data(), out_native.data(), info); }); }); commit->set_last_usm_workspace_event_if_rqd(sycl_event); diff --git a/src/dft/execute_helper_generic.hpp b/src/dft/execute_helper_generic.hpp new file mode 100644 index 000000000..22fe0cb33 --- /dev/null +++ b/src/dft/execute_helper_generic.hpp @@ -0,0 +1,53 @@ +/******************************************************************************* +* Copyright Codeplay Software Ltd. +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, +* software distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions +* and limitations under the License. +* +* +* SPDX-License-Identifier: Apache-2.0 +*******************************************************************************/ + +#ifndef _ONEMKL_DFT_SRC_EXECUTE_HELPER_GENERIC_HPP_ +#define _ONEMKL_DFT_SRC_EXECUTE_HELPER_GENERIC_HPP_ + +#if __has_include() +#include +#else +#include +#endif + +namespace oneapi::mkl::dft::detail { + +/** Wrap interop API to launch interop host task. + * + * @tparam HandlerT The command group handler type + * @tparam FnT The body of the enqueued task + * + * Either uses host task interop API, or enqueue native command extension. + * This extension avoids host synchronization after + * the native call is complete. + */ +template +static inline void fft_enqueue_task(HandlerT&& cgh, FnT&& f) { +#ifdef SYCL_EXT_ONEAPI_ENQUEUE_NATIVE_COMMAND + cgh.ext_codeplay_enqueue_native_command([=](sycl::interop_handle ih){ +#else + cgh.host_task([=](sycl::interop_handle ih){ +#endif + f(std::move(ih)); + }); +} + +} // namespace oneapi::mkl::dft::detail + +#endif // _ONEMKL_DFT_SRC_EXECUTE_HELPER_GENERIC_HPP_