Skip to content
This repository has been archived by the owner on Jan 13, 2025. It is now read-only.

Commit

Permalink
Separated half gemm config for default CPUs
Browse files Browse the repository at this point in the history
  • Loading branch information
OuadiElfarouki committed Mar 5, 2024
1 parent 26970c0 commit a2f489f
Show file tree
Hide file tree
Showing 4 changed files with 58 additions and 4 deletions.
12 changes: 11 additions & 1 deletion cmake/CmakeFunctionHelper.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -722,7 +722,17 @@ else() # default cpu backend
add_gemm_configuration(
"${data}" 64 "false" "false" "false"
64 2 2 4 4 1 1 1 1 4 4 1 1 1 float float "no_local" "standard" "full" 4 "interleaved" "false" "false")
endforeach()

if(BLAS_ENABLE_HALF)
add_gemm_configuration(
"half" 128 "false" "false" "false"
64 4 4 8 8 1 1 1 1 1 1 1 1 1 float float "no_local" "standard" "full" 1 "strided" "false" "false")
add_gemm_configuration(
"half" 64 "false" "false" "false"
64 2 2 4 4 1 1 1 1 4 4 1 1 1 float float "no_local" "standard" "full" 4 "interleaved" "false" "false")
endif()

endforeach()

if(BLAS_ENABLE_COMPLEX)
# Extract list of complex<data> for each data in supported_types
Expand Down
46 changes: 45 additions & 1 deletion src/interface/blas3/backend/default_cpu.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,8 @@ template <bool _t_a, bool _t_b, bool s_a, bool s_b, bool is_beta_zero,
typename element_in_t, typename element_out_t, typename sb_handle_t,
typename container_0_t, typename container_1_t,
typename container_2_t, typename index_t>
typename std::enable_if<is_sycl_scalar<element_in_t>::value,
typename std::enable_if<is_sycl_scalar<element_in_t>::value &&
!is_half<element_in_t>::value,
typename sb_handle_t::event_t>::type
_gemm(sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K,
element_out_t _alpha, container_0_t _a, index_t _lda, index_t _stridea,
Expand Down Expand Up @@ -120,6 +121,49 @@ _gemm(sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K,
}
}

// Half Configurations
template <bool _t_a, bool _t_b, bool s_a, bool s_b, bool is_beta_zero,
typename element_in_t, typename element_out_t, typename sb_handle_t,
typename container_0_t, typename container_1_t,
typename container_2_t, typename index_t>
typename std::enable_if<is_half<element_in_t>::value,
typename sb_handle_t::event_t>::type
_gemm(sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K,
element_out_t _alpha, container_0_t _a, index_t _lda, index_t _stridea,
container_1_t _b, index_t _ldb, index_t _strideb, element_out_t _beta,
container_2_t _c, index_t _ldc, index_t _stridec, index_t batch_size,
gemm_batch_type_t batch_type,
const typename sb_handle_t::event_t& _dependencies) {
// Unused configuration cases
if constexpr (s_a || s_b) {
return _dependencies;
} else {
if (batch_type == gemm_batch_type_t::interleaved) {
return blas::Gemm_Launcher<
container_0_t, container_1_t, container_2_t, 64, false, false, false,
64, Tile<2, 2, 4, 4, 1, 1, 1, 1, 4, 4>, _t_a, _t_b, s_a, s_b,
static_cast<int>(gemm_memory_t::no_local),
static_cast<int>(gemm_algorithm_t::standard),
static_cast<int>(gemm_vectorization_t::full), is_beta_zero, 4,
static_cast<int>(gemm_batch_type_t::interleaved)>::
template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda,
_stridea, _b, _ldb, _strideb, _beta, _c, _ldc,
_stridec, batch_size, _dependencies);
}

return blas::Gemm_Launcher<
container_0_t, container_1_t, container_2_t, 128, false, false, false,
64, Tile<4, 4, 8, 8>, _t_a, _t_b, s_a, s_b,
static_cast<int>(gemm_memory_t::no_local),
static_cast<int>(gemm_algorithm_t::standard),
static_cast<int>(gemm_vectorization_t::full), is_beta_zero, 1,
static_cast<int>(gemm_batch_type_t::strided)>::
template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda, _stridea,
_b, _ldb, _strideb, _beta, _c, _ldc, _stridec,
batch_size, _dependencies);
}
}

// Complex Configurations
#ifdef BLAS_ENABLE_COMPLEX
template <bool _t_a, bool _t_b, bool s_a, bool s_b, bool is_beta_zero,
Expand Down
2 changes: 1 addition & 1 deletion src/interface/blas3/backend/intel_gpu.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,7 @@ _gemm(sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K,
gemm_batch_type_t batch_type,
const typename sb_handle_t::event_t& _dependencies) {
// Unused configuration cases
if constexpr (s_a && s_b || ((s_a && _t_b) || (s_b && _t_a))) {
if constexpr (s_a || s_b) {
return _dependencies;
} else {
if (batch_type == gemm_batch_type_t::interleaved) {
Expand Down
2 changes: 1 addition & 1 deletion src/interface/blas3/backend/nvidia_gpu.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ _gemm(sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K,
gemm_batch_type_t batch_type,
const typename sb_handle_t::event_t& _dependencies) {
// Unused configuration cases
if constexpr (s_a && s_b || ((s_a && _t_b) || (s_b && _t_a))) {
if constexpr (s_a || s_b) {
return _dependencies;
} else {
if (batch_type == gemm_batch_type_t::interleaved) {
Expand Down

0 comments on commit a2f489f

Please sign in to comment.