From 13e283854101e7159c0fa9c9c0652a4d4f509dc2 Mon Sep 17 00:00:00 2001 From: Hugh Bird Date: Tue, 18 Jun 2024 20:05:33 +0100 Subject: [PATCH] Support 2024.1 and later only; Fix exception handling --- src/dft/backends/mklgpu/commit.cpp | 52 ++++++++++++++++-------------- 1 file changed, 28 insertions(+), 24 deletions(-) diff --git a/src/dft/backends/mklgpu/commit.cpp b/src/dft/backends/mklgpu/commit.cpp index 942a0debd..10379a98a 100644 --- a/src/dft/backends/mklgpu/commit.cpp +++ b/src/dft/backends/mklgpu/commit.cpp @@ -39,9 +39,9 @@ // MKL 2024.1 deprecates input/output strides. #include "mkl_version.h" -namespace oneapi::mkl::dft::mklgpu::detail { -constexpr bool mklgpu_use_forward_backward_strides_api = INTEL_MKL_VERSION >= 20240001; -} +#if INTEL_MKL_VERSION < 20240001 +#error MKLGPU requires oneMKL 2024.1 or later +#endif /** Note that in this file, the Intel oneMKL closed-source library's interface mirrors the interface @@ -89,22 +89,30 @@ class mklgpu_commit final : public dft::detail::commit_impl { oneapi::mkl::dft::detail::external_workspace_helper( config_values.workspace_placement == oneapi::mkl::dft::detail::config_value::WORKSPACE_EXTERNAL); + + // A separate descriptor for each direction may not be required. + bool one_descriptor = config_values.input_strides == config_values.output_strides; + bool forward_good = true; + // Make sure that second is always pointing to something new if this is a recommit. + handle.second = handle.first; + // Generate forward DFT descriptor. set_value(*handle.first, config_values, true); try { handle.first->commit(this->get_queue()); } catch (const std::exception& mkl_exception) { - // Catching the real Intel oneMKL exception causes headaches with naming. - throw mkl::exception("dft/backends/mklgpu", "commit", mkl_exception.what()); + // Catching the real Intel oneMKL exception causes headaches with naming + forward_good = false; + if (one_descriptor) { + throw mkl::exception("dft/backends/mklgpu" + "commit", + mkl_exception.what()); + } } // Generate backward DFT descriptor only if required. - if (config_values.input_strides == config_values.output_strides) { - // Required if second != first before a recommit. - handle.second = handle.first; - } - else { + if (!one_descriptor) { handle.second = std::make_shared(config_values.dimensions); set_value(*handle.second, config_values, false); try { @@ -112,7 +120,11 @@ class mklgpu_commit final : public dft::detail::commit_impl { } catch (const std::exception& mkl_exception) { // Catching the real Intel oneMKL exception causes headaches with naming. - throw mkl::exception("dft/backends/mklgpu", "commit", mkl_exception.what()); + if (!forward_good) { + throw mkl::exception("dft/backends/mklgpu" + "commit", + mkl_exception.what()); + } } } } @@ -173,21 +185,13 @@ class mklgpu_commit final : public dft::detail::commit_impl { throw mkl::unimplemented("dft/backends/mklgpu", "commit", "MKLGPU does not support nonzero offsets."); } - if constexpr (mklgpu_use_forward_backward_strides_api) { - // Support for Intel oneMKL 2024.1 or newer using FWD/BWD stride API. - if (assume_fwd_dft) { - desc.set_value(backend_param::FWD_STRIDES, config.input_strides.data()); - desc.set_value(backend_param::BWD_STRIDES, config.output_strides.data()); - } - else { - desc.set_value(backend_param::FWD_STRIDES, config.output_strides.data()); - desc.set_value(backend_param::BWD_STRIDES, config.input_strides.data()); - } + if (assume_fwd_dft) { + desc.set_value(backend_param::FWD_STRIDES, config.input_strides.data()); + desc.set_value(backend_param::BWD_STRIDES, config.output_strides.data()); } else { - // Support for Intel oneMKL older than 2024.1 - desc.set_value(backend_param::INPUT_STRIDES, config.input_strides.data()); - desc.set_value(backend_param::OUTPUT_STRIDES, config.output_strides.data()); + desc.set_value(backend_param::FWD_STRIDES, config.output_strides.data()); + desc.set_value(backend_param::BWD_STRIDES, config.input_strides.data()); } desc.set_value(backend_param::FWD_DISTANCE, config.fwd_dist); desc.set_value(backend_param::BWD_DISTANCE, config.bwd_dist);