-
Notifications
You must be signed in to change notification settings - Fork 12
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
This consolidates implementations and provides access to constraints for the device-enabled SPMM. DeviceTensor and device_gemm() have been outlined into a header file. Signed-off-by: Joseph Schuchart <[email protected]>
- Loading branch information
Showing
5 changed files
with
409 additions
and
1,992 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,81 @@ | ||
|
||
#if defined(TTG_ENABLE_LEVEL_ZERO) | ||
#include <oneapi/mkl.hpp> | ||
#include <sys/time.h> | ||
#endif | ||
|
||
#include "../devblas_helper.h" | ||
|
||
|
||
template <typename Blk> | ||
inline void device_gemm(Blk &C, const Blk &A, const Blk &B) { | ||
using blk_t = Blk; | ||
using T = typename blk_t::value_type; | ||
static_assert(std::is_same_v<T,double> || std::is_same_v<T,float>); | ||
static const T alpha = 1.0; | ||
static const T beta = 1.0; | ||
// make sure all memory is on the device | ||
// TODO: A and B are read-only so the owner device will be 0. How to fix? | ||
//assert(A.b.get_current_device() != 0); | ||
//assert(B.b.get_current_device() != 0); | ||
auto device = ttg::device::current_device(); | ||
assert(device.is_device()); | ||
#if defined(TTG_ENABLE_CUDA) | ||
if constexpr (std::is_same_v<T,double>) { | ||
cublasDgemm(cublas_handle(), CUBLAS_OP_N, CUBLAS_OP_N, C.extent(0), C.extent(1), A.extent(1), | ||
&alpha, A.b.current_device_ptr(), A.extent(0), B.b.current_device_ptr(), B.extent(0), &beta, | ||
C.b.current_device_ptr(), C.extent(0)); | ||
} | ||
else if constexpr (std::is_same_v<T,float>) { | ||
cublasSgemm(cublas_handle(), CUBLAS_OP_N, CUBLAS_OP_N, C.extent(0), C.extent(1), A.extent(1), | ||
&alpha, A.b.current_device_ptr(), A.extent(0), B.b.current_device_ptr(), B.extent(0), &beta, | ||
C.b.current_device_ptr(), C.extent(0)); | ||
} | ||
#elif defined(TTG_ENABLE_HIP) | ||
if constexpr (std::is_same_v<T,double>) { | ||
hipblasDgemm(hipblas_handle(), | ||
HIPBLAS_OP_N, HIPBLAS_OP_N, | ||
C.extent(0), C.extent(1), A.extent(1), &alpha, | ||
A.b.current_device_ptr(), A.extent(0), | ||
B.b.current_device_ptr(), B.extent(0), &beta, | ||
C.b.current_device_ptr(), C.extent(0)); | ||
} else if constexpr (std::is_same_v<T,float>) { | ||
hipblasSgemm(hipblas_handle(), | ||
HIPBLAS_OP_N, HIPBLAS_OP_N, | ||
C.extent(0), C.extent(1), A.extent(1), &alpha, | ||
A.b.current_device_ptr(), A.extent(0), | ||
B.b.current_device_ptr(), B.extent(0), &beta, | ||
C.b.current_device_ptr(), C.extent(0)); | ||
} | ||
#elif defined(TTG_ENABLE_LEVEL_ZERO) | ||
|
||
#if defined(DEBUG_SYNCHRONOUS) | ||
try { | ||
#endif /* DEBUG_SYNCHRONOUS */ | ||
cl::sycl::event gemm_event; | ||
gemm_event = oneapi::mkl::blas::gemm(ttg::device::current_stream(), | ||
oneapi::mkl::transpose::N, oneapi::mkl::transpose::N, | ||
C.extent(0), C.extent(1), A.extent(1), | ||
alpha, A.b.current_device_ptr(), A.extent(0), | ||
B.b.current_device_ptr(), B.extent(0), | ||
beta, C.b.current_device_ptr(), C.extent(0)); | ||
#if defined(DEBUG_SYNCHRONOUS) | ||
gemm_event.wait(); | ||
} catch (const oneapi::mkl::invalid_argument &e) { | ||
std::cerr << "OneAPI MKL BLAS GEMM throws invalid argument exception" << std::endl; | ||
} catch (const oneapi::mkl::unsupported_device &e) { | ||
std::cerr << "OneAPI MKL BLAS GEMM throws unsuported device exception" << std::endl; | ||
} catch (const oneapi::mkl::host_bad_alloc &e) { | ||
std::cerr << "OneAPI MKL BLAS GEMM throws host bad allocation exception" << std::endl; | ||
} catch (const oneapi::mkl::device_bad_alloc &e) { | ||
std::cerr << "OneAPI MKL BLAS GEMM throws device bad allocation exception" << std::endl; | ||
} catch (const oneapi::mkl::unimplemented &e) { | ||
std::cerr << "OneAPI MKL BLAS GEMM throws unimplemented exception" << std::endl; | ||
} catch (const std::exception& e) { | ||
std::cerr << "OneAPI MKL BLAS GEMM throws unexpected exception" << std::endl; | ||
} catch (...) { | ||
std::cerr << "OneAPI MKL BLAS GEMM throws unexpected exception that is also badly formatted..." << std::endl; | ||
} | ||
#endif /* DEBUG_SYNCHRONOUS */ | ||
#endif | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,218 @@ | ||
#ifndef HAVE_DEVICETENSOR_H | ||
#define HAVE_DEVICETENSOR_H | ||
|
||
#include <ttg.h> | ||
|
||
#if __has_include(<btas/features.h>) | ||
#pragma message("C Preprocessor got here!") | ||
#include <btas/features.h> | ||
#ifdef BTAS_IS_USABLE | ||
#include <btas/btas.h> | ||
#include <btas/optimize/contract.h> | ||
#include <btas/util/mohndle.h> | ||
#include <TiledArray/device/allocators.h> | ||
#include "../devblas_helper.h" | ||
#include <madness/world/parsec.h> // need to initialize MADNESS purely for the purposes of TA allocators | ||
#else | ||
#warning "found btas/features.h but Boost.Iterators is missing, hence BTAS is unusable ... add -I/path/to/boost" | ||
#endif | ||
#endif | ||
|
||
#if defined(BTAS_IS_USABLE) | ||
|
||
/** | ||
* Derives from btas::Tensor and wraps a ttg::Buffer | ||
* to enable device support in SPMM. The ttg::Buffer | ||
* does not own the host memory but mananages the device | ||
* memory. | ||
*/ | ||
template <typename _T, class _Range, class _Storage> | ||
struct DeviceTensor : public ttg::TTValue<DeviceTensor<_T, _Range, _Storage>> | ||
, public btas::Tensor<_T, _Range, _Storage> { | ||
using tensor_type = typename btas::Tensor<_T, _Range, _Storage>; | ||
using ttvalue_type = typename ttg::TTValue<DeviceTensor<_T, _Range, _Storage>>; | ||
ttg::Buffer<_T> b; // does not own the host buffer | ||
|
||
using value_type = typename tensor_type::value_type; | ||
using size_type = typename tensor_type::size_type; | ||
using storage_type = typename tensor_type::storage_type; | ||
using range_type = typename tensor_type::range_type; | ||
|
||
|
||
public: | ||
DeviceTensor() = default; | ||
~DeviceTensor() = default; | ||
|
||
/// constructor with index extent | ||
template <typename... _args> | ||
explicit DeviceTensor(const size_type& first, const _args&... rest) | ||
: ttvalue_type() | ||
, tensor_type(first, rest...) | ||
, b(this->size() ? this->data() : nullptr, this->size()) | ||
{ } | ||
|
||
/// construct from \c range, allocate data, but not initialized | ||
template <typename Range> | ||
explicit DeviceTensor(const Range& range, typename std::enable_if<btas::is_boxrange<Range>::value>::type* = 0) | ||
: ttvalue_type() | ||
, tensor_type(range) | ||
, b(this->size() ? this->data() : nullptr, this->size()) | ||
{ } | ||
|
||
/// construct from \c range object, set all elements to \c v | ||
template <typename Range> | ||
DeviceTensor(const Range& range, value_type v, typename std::enable_if<btas::is_boxrange<Range>::value>::type* = 0) | ||
: ttvalue_type() | ||
, tensor_type(range) | ||
, b(this->size() ? this->data() : nullptr, this->size()) | ||
{ } | ||
|
||
/// construct from \c range object, copy elements from \c vec | ||
template <typename Range, typename U> | ||
DeviceTensor(const Range& range, U* vec, typename std::enable_if<btas::is_boxrange<Range>::value>::type* = 0) | ||
: ttvalue_type() | ||
, tensor_type(range, vec) | ||
, b(this->size() ? this->data() : nullptr, this->size()) | ||
{ } | ||
|
||
/// construct from \c range and \c storage | ||
template <typename Range, typename Storage> | ||
DeviceTensor(const Range& range, const Storage& storage, | ||
typename std::enable_if<btas::is_boxrange<Range>::value & not std::is_same<Range, range_type>::value & | ||
not std::is_same<Storage, storage_type>::value>::type* = 0) | ||
: ttvalue_type() | ||
, tensor_type(range, storage) | ||
, b(this->size() ? this->data() : nullptr, this->size()) | ||
{ } | ||
|
||
/// copy-copy-construct from \c range and \c storage | ||
DeviceTensor(const range_type& range, const storage_type& storage) | ||
: ttvalue_type() | ||
, tensor_type(range, storage) | ||
, b(this->size() ? this->data() : nullptr, this->size()) | ||
{ } | ||
|
||
/// copy-move-construct from \c range and \c storage | ||
DeviceTensor(const range_type& range, storage_type&& storage) | ||
: ttvalue_type() | ||
, tensor_type(range, std::forward<storage_type>(storage)) | ||
, b(this->size() ? this->data() : nullptr, this->size()) | ||
{ } | ||
|
||
/// move-construct from \c range and \c storage | ||
DeviceTensor(range_type&& range, storage_type&& storage) | ||
: ttvalue_type() | ||
, tensor_type(std::forward<range_type>(range), std::forward<storage_type>(storage)) | ||
, b(this->size() ? this->data() : nullptr, this->size()) | ||
{ } | ||
|
||
/// Construct an evaluated tensor | ||
|
||
/// This constructor will allocate memory for \c range.area() elements. Each element | ||
/// will be initialized as: | ||
/// \code | ||
/// for(auto&& idx: range) | ||
/// (*this)[idx] = op(*(it++)); | ||
/// \endcode | ||
/// \tparam Range An input Range type. | ||
/// \tparam InIter An input iterator type. | ||
/// \tparam Op A unary operation type | ||
/// \param range the input range type | ||
/// \param first An input iterator for the argument | ||
/// \param op The unary operation to be applied to the argument data | ||
template <typename Range, typename InIter, typename Op> | ||
DeviceTensor(const Range& range, InIter it, const Op& op, | ||
typename std::enable_if<btas::is_boxrange<Range>::value>::type* = 0) | ||
: ttvalue_type() | ||
, tensor_type(range, it, op) | ||
, b(this->size() ? this->data() : nullptr, this->size()) | ||
{ } | ||
|
||
/// copy constructor | ||
/// It will accept Tensors and TensorViews | ||
template <class _Tensor, class = typename std::enable_if<btas::is_boxtensor<_Tensor>::value>::type> | ||
DeviceTensor(const _Tensor& x) noexcept | ||
: ttvalue_type() | ||
, tensor_type(x.clone()) | ||
, b(this->size() ? this->data() : nullptr, this->size()) | ||
{ | ||
//std::cout << "DeviceTensor tensor_type copy ctor" << std::endl; | ||
} | ||
|
||
/// copy constructor: devicebuf cannot be copied, so deleted | ||
DeviceTensor(const DeviceTensor& x) noexcept | ||
: ttvalue_type(x) | ||
, tensor_type(x.clone()) | ||
, b(this->size() ? this->data() : nullptr, this->size()) | ||
{ | ||
//std::cout << "DeviceTensor copy ctor" << std::endl; | ||
} | ||
|
||
/// move constructor | ||
DeviceTensor(tensor_type&& x) noexcept | ||
: ttvalue_type() | ||
, tensor_type(std::move(x)) | ||
, b(this->size() ? this->data() : nullptr, this->size()) | ||
{ | ||
//std::cout << "DeviceTensor tensor_type move ctor" << std::endl; | ||
} | ||
|
||
DeviceTensor(DeviceTensor&& x) noexcept | ||
: ttvalue_type(std::move(x)) | ||
, tensor_type(static_cast<tensor_type&&>(x)) | ||
, b(std::move(x.b)) | ||
{ | ||
assert(this->data() == b.host_ptr()); | ||
//std::cout << "DeviceTensor move ctor" << std::endl; | ||
} | ||
|
||
/// copy assignment operator | ||
template <class _Tensor, class = typename std::enable_if< | ||
btas::is_boxtensor<_Tensor>::value && | ||
not std::is_same<typename _Tensor::storage_type, storage_type>::value>::type> | ||
DeviceTensor& operator=(const _Tensor& x) noexcept { | ||
tensor_type::operator=(x.clone()); | ||
b.reset(this->size() ? this->data() : nullptr, this->size()); | ||
//std::cout << "DeviceTensor tensor_type copy operator" << std::endl; | ||
return *this; | ||
} | ||
|
||
/// copy assignment operator | ||
template <class _Tensor, class = typename std::enable_if<btas::is_boxtensor<_Tensor>::value>::type, | ||
class = typename std::enable_if< | ||
std::is_same<typename _Tensor::storage_type, storage_type>::value>::type> | ||
DeviceTensor& operator=(const _Tensor& x) noexcept { | ||
tensor_type::operator=(x.clone()); | ||
b.reset(this->size() ? this->data() : nullptr, this->size()); | ||
//std::cout << "DeviceTensor tensor_type copy operator" << std::endl; | ||
return *this; | ||
} | ||
|
||
/// copy assignment: devicebuf cannot be copied, deleted | ||
DeviceTensor& operator=(const DeviceTensor& x) noexcept { | ||
ttvalue_type::operator=(x); | ||
tensor_type::operator=(x.clone()); | ||
b.reset(this->size() ? this->data() : nullptr, this->size()); | ||
//std::cout << "DeviceTensor copy operator" << std::endl; | ||
return *this; | ||
} | ||
|
||
/// move assignment operator | ||
DeviceTensor& operator=(DeviceTensor&& x) noexcept { | ||
ttvalue_type::operator=(std::move(x)); | ||
tensor_type::operator=(static_cast<tensor_type&&>(x)); | ||
b = std::move(x.b); | ||
//std::cout << "DeviceTensor move ctor" << std::endl; | ||
return *this; | ||
} | ||
|
||
using tensor_type::begin; | ||
using tensor_type::cbegin; | ||
using tensor_type::end; | ||
using tensor_type::cend; | ||
|
||
}; | ||
|
||
#endif // defined(BTAS_IS_USABLE) | ||
|
||
#endif // HAVE_DEVICETENSOR_H |
Oops, something went wrong.