Skip to content

Commit

Permalink
Support 2024.1 and later only; Fix exception handling
Browse files Browse the repository at this point in the history
  • Loading branch information
hjabird committed Jun 18, 2024
1 parent 9ae3a61 commit 13e2838
Showing 1 changed file with 28 additions and 24 deletions.
52 changes: 28 additions & 24 deletions src/dft/backends/mklgpu/commit.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,9 @@

// MKL 2024.1 deprecates input/output strides.
#include "mkl_version.h"
namespace oneapi::mkl::dft::mklgpu::detail {
constexpr bool mklgpu_use_forward_backward_strides_api = INTEL_MKL_VERSION >= 20240001;
}
#if INTEL_MKL_VERSION < 20240001
#error MKLGPU requires oneMKL 2024.1 or later
#endif

/**
Note that in this file, the Intel oneMKL closed-source library's interface mirrors the interface
Expand Down Expand Up @@ -89,30 +89,42 @@ class mklgpu_commit final : public dft::detail::commit_impl<prec, dom> {
oneapi::mkl::dft::detail::external_workspace_helper<prec, dom>(
config_values.workspace_placement ==
oneapi::mkl::dft::detail::config_value::WORKSPACE_EXTERNAL);

// A separate descriptor for each direction may not be required.
bool one_descriptor = config_values.input_strides == config_values.output_strides;
bool forward_good = true;
// Make sure that second is always pointing to something new if this is a recommit.
handle.second = handle.first;

// Generate forward DFT descriptor.
set_value(*handle.first, config_values, true);
try {
handle.first->commit(this->get_queue());
}
catch (const std::exception& mkl_exception) {
// Catching the real Intel oneMKL exception causes headaches with naming.
throw mkl::exception("dft/backends/mklgpu", "commit", mkl_exception.what());
// Catching the real Intel oneMKL exception causes headaches with naming
forward_good = false;
if (one_descriptor) {
throw mkl::exception("dft/backends/mklgpu"
"commit",
mkl_exception.what());
}
}

// Generate backward DFT descriptor only if required.
if (config_values.input_strides == config_values.output_strides) {
// Required if second != first before a recommit.
handle.second = handle.first;
}
else {
if (!one_descriptor) {
handle.second = std::make_shared<mklgpu_descriptor_t>(config_values.dimensions);
set_value(*handle.second, config_values, false);
try {
handle.second->commit(this->get_queue());
}
catch (const std::exception& mkl_exception) {
// Catching the real Intel oneMKL exception causes headaches with naming.
throw mkl::exception("dft/backends/mklgpu", "commit", mkl_exception.what());
if (!forward_good) {
throw mkl::exception("dft/backends/mklgpu"
"commit",
mkl_exception.what());
}
}
}
}
Expand Down Expand Up @@ -173,21 +185,13 @@ class mklgpu_commit final : public dft::detail::commit_impl<prec, dom> {
throw mkl::unimplemented("dft/backends/mklgpu", "commit",
"MKLGPU does not support nonzero offsets.");
}
if constexpr (mklgpu_use_forward_backward_strides_api) {
// Support for Intel oneMKL 2024.1 or newer using FWD/BWD stride API.
if (assume_fwd_dft) {
desc.set_value(backend_param::FWD_STRIDES, config.input_strides.data());
desc.set_value(backend_param::BWD_STRIDES, config.output_strides.data());
}
else {
desc.set_value(backend_param::FWD_STRIDES, config.output_strides.data());
desc.set_value(backend_param::BWD_STRIDES, config.input_strides.data());
}
if (assume_fwd_dft) {
desc.set_value(backend_param::FWD_STRIDES, config.input_strides.data());
desc.set_value(backend_param::BWD_STRIDES, config.output_strides.data());
}
else {
// Support for Intel oneMKL older than 2024.1
desc.set_value(backend_param::INPUT_STRIDES, config.input_strides.data());
desc.set_value(backend_param::OUTPUT_STRIDES, config.output_strides.data());
desc.set_value(backend_param::FWD_STRIDES, config.output_strides.data());
desc.set_value(backend_param::BWD_STRIDES, config.input_strides.data());
}
desc.set_value(backend_param::FWD_DISTANCE, config.fwd_dist);
desc.set_value(backend_param::BWD_DISTANCE, config.bwd_dist);
Expand Down

0 comments on commit 13e2838

Please sign in to comment.