Skip to content

Commit

Permalink
removed non-necessary gemm template parameter element in/out
Browse files Browse the repository at this point in the history
  • Loading branch information
OuadiElfarouki committed Mar 28, 2024
1 parent 620ea94 commit aff966c
Show file tree
Hide file tree
Showing 19 changed files with 407 additions and 475 deletions.
5 changes: 5 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,11 @@ option(BLAS_ENABLE_EXTENSIONS "Whether to enable portBLAS extensions" ON)
option(BLAS_ENABLE_COMPLEX "Whether to enable complex data type for GEMM" OFF)
option(BLAS_ENABLE_HALF "Whether to enable sycl::half data type for supported operators" OFF)

if(((NOT INSTALL_HEADER_ONLY) AND (TUNING_TARGET STREQUAL "DEFAULT_CPU"))
OR (INSTALL_HEADER_ONLY AND (NOT TUNING_TARGET)))
set(BLAS_ENABLE_HALF OFF)
message(STATUS "FP16 operations are not supported for CPU targets. BLAS_ENABLE_HALF is disabled")
endif()

if (SYCL_COMPILER MATCHES "adaptivecpp")
if(BLAS_ENABLE_COMPLEX)
Expand Down
35 changes: 16 additions & 19 deletions include/operations/blas3_trees.h
Original file line number Diff line number Diff line change
Expand Up @@ -169,8 +169,7 @@ struct Tile {
* @tparam TransB iff true, matrix B will be transposed on the fly
* @tparam SymmA whether the matrix A is a symmetric triangular matrix
* @tparam SymmB whether the matrix B is a symmetric triangular matrix
* @tparam element_in_t type of input matrix elements (A & B)
* @tparam element_t type of output matrix elements (C) and scaling parameters
* @tparam element_t type of scalar alpha & beta
* @tparam UseJointMatrix boolean parameter to decide whether to use
* joint_matrix or not
* @param a_ the lhs_t matrix
Expand All @@ -191,13 +190,12 @@ struct Tile {
*/
template <typename input_t, typename output_t, bool DoubleBuffer, bool NbcA,
bool NbcB, int ClSize, typename tile_type, bool TransA, bool TransB,
bool SymmA, bool SymmB, typename element_in_t, typename element_t,
bool is_beta_zero, int GemmMemoryType, int GemmAlgorithm,
int GemmVectorization, int VectorSize, int BatchType,
bool UseJointMatrix = false>
bool SymmA, bool SymmB, typename element_t, bool is_beta_zero,
int GemmMemoryType, int GemmAlgorithm, int GemmVectorization,
int VectorSize, int BatchType, bool UseJointMatrix = false>
class Gemm {
public:
using value_t = element_t;
using value_t = typename input_t::value_t;
using index_t = typename std::make_signed<typename input_t::index_t>::type;
static constexpr int wg_size = tile_type::wg_rows * tile_type::wg_cols;
static constexpr bool trans_a = TransA;
Expand Down Expand Up @@ -251,8 +249,7 @@ class Gemm {
*/
template <typename input_t, typename output_t, bool DoubleBuffer, bool NbcA,
bool NbcB, int ClSize, typename TileType, bool TransA, bool TransB,
bool IsFinal, bool IsBetaZero, typename element_in_t,
typename element_t, int GemmMemoryType>
bool IsFinal, bool IsBetaZero, typename element_t, int GemmMemoryType>
class GemmPartial {};

/*
Expand All @@ -263,21 +260,21 @@ template <bool DoubleBuffer, bool ConflictA, bool ConflictB, int ClSize,
typename TileType, bool TransA, bool TransB, bool SymmA, bool SymmB,
int GemmMemoryType, int GemmAlgorithm, int GemmVectorization,
bool is_beta_zero, int VectorSize, int BatchType, bool UseJointMatrix,
typename element_in_t, typename element_t, typename input_t,
typename output_t, typename index_t>
typename input_t, typename output_t, typename element_t,
typename index_t>
inline Gemm<input_t, output_t, DoubleBuffer, ConflictA, ConflictB, ClSize,
TileType, TransA, TransB, SymmA, SymmB, element_in_t, element_t,
is_beta_zero, GemmMemoryType, GemmAlgorithm, GemmVectorization,
VectorSize, BatchType, UseJointMatrix>
TileType, TransA, TransB, SymmA, SymmB, element_t, is_beta_zero,
GemmMemoryType, GemmAlgorithm, GemmVectorization, VectorSize,
BatchType, UseJointMatrix>
make_gemm(input_t buffer_a, input_t buffer_b, output_t buffer_c,
element_t alpha, element_t beta, index_t batch_size, index_t _stridea,
index_t _strideb, index_t _stridec) {
return Gemm<input_t, output_t, DoubleBuffer, ConflictA, ConflictB, ClSize,
TileType, TransA, TransB, SymmA, SymmB, element_in_t, element_t,
is_beta_zero, GemmMemoryType, GemmAlgorithm, GemmVectorization,
VectorSize, BatchType, UseJointMatrix>(
buffer_a, buffer_b, buffer_c, alpha, beta, batch_size, _stridea, _strideb,
_stridec);
TileType, TransA, TransB, SymmA, SymmB, element_t, is_beta_zero,
GemmMemoryType, GemmAlgorithm, GemmVectorization, VectorSize,
BatchType, UseJointMatrix>(buffer_a, buffer_b, buffer_c, alpha,
beta, batch_size, _stridea, _strideb,
_stridec);
}

/**
Expand Down
55 changes: 26 additions & 29 deletions include/sb_handle/portblas_handle.h
Original file line number Diff line number Diff line change
Expand Up @@ -123,43 +123,40 @@ class SB_Handle {

template <typename input_t, typename output_t, bool DoubleBuffer, bool NbcA,
bool NbcB, int ClSize, typename tile_type, bool TransA, bool TransB,
bool SymmA, bool SymmB, typename element_in_t,
typename element_out_t, bool is_beta_zero, int GemmMemoryType,
int GemmAlgorithm, int GemmVectorization, int VectorSize,
int BatchType, bool UseJointMatrix>
event_t execute(
Gemm<input_t, output_t, DoubleBuffer, NbcA, NbcB, ClSize, tile_type,
TransA, TransB, SymmA, SymmB, element_in_t, element_out_t,
is_beta_zero, GemmMemoryType, GemmAlgorithm, GemmVectorization,
VectorSize, BatchType, UseJointMatrix>
gemm_tree,
const event_t& dependencies = {});
bool SymmA, bool SymmB, typename element_t, bool is_beta_zero,
int GemmMemoryType, int GemmAlgorithm, int GemmVectorization,
int VectorSize, int BatchType, bool UseJointMatrix>
event_t execute(Gemm<input_t, output_t, DoubleBuffer, NbcA, NbcB, ClSize,
tile_type, TransA, TransB, SymmA, SymmB, element_t,
is_beta_zero, GemmMemoryType, GemmAlgorithm,
GemmVectorization, VectorSize, BatchType, UseJointMatrix>
gemm_tree,
const event_t& dependencies = {});

// Tall and skinny Gemm specialization
template <typename input_t, typename output_t, bool DoubleBuffer, bool NbcA,
bool NbcB, int ClSize, typename tile_type, bool TransA, bool TransB,
bool SymmA, bool SymmB, typename element_in_t,
typename element_out_t, bool is_beta_zero, int GemmMemoryType,
int GemmVectorization, int VectorSize, int BatchType>
event_t execute(Gemm<input_t, output_t, DoubleBuffer, NbcA, NbcB, ClSize,
tile_type, TransA, TransB, SymmA, SymmB, element_in_t,
element_out_t, is_beta_zero, GemmMemoryType,
static_cast<int>(gemm_algorithm_t::tall_skinny),
GemmVectorization, VectorSize, BatchType>
gemm_wrapper,
const event_t& dependencies = {});
bool SymmA, bool SymmB, typename element_t, bool is_beta_zero,
int GemmMemoryType, int GemmVectorization, int VectorSize,
int BatchType>
event_t execute(
Gemm<input_t, output_t, DoubleBuffer, NbcA, NbcB, ClSize, tile_type,
TransA, TransB, SymmA, SymmB, element_t, is_beta_zero,
GemmMemoryType, static_cast<int>(gemm_algorithm_t::tall_skinny),
GemmVectorization, VectorSize, BatchType>
gemm_wrapper,
const event_t& dependencies = {});

// GemmPartial specialization
template <typename input_t, typename output_t, bool DoubleBuffer, bool NbcA,
bool NbcB, int ClSize, typename tile_type, bool TransA, bool TransB,
bool IsFinal, bool IsBetaZero, typename element_in_t,
typename element_out_t, int GemmMemoryType>
event_t execute(
GemmPartial<input_t, output_t, DoubleBuffer, NbcA, NbcB, ClSize,
tile_type, TransA, TransB, IsFinal, IsBetaZero, element_in_t,
element_out_t, GemmMemoryType>
gemm_partial,
const event_t& dependencies = {});
bool IsFinal, bool IsBetaZero, typename element_t,
int GemmMemoryType>
event_t execute(GemmPartial<input_t, output_t, DoubleBuffer, NbcA, NbcB,
ClSize, tile_type, TransA, TransB, IsFinal,
IsBetaZero, element_t, GemmMemoryType>
gemm_partial,
const event_t& dependencies = {});

// Reduction specialization (inner or outer dimension)
template <typename operator_t, typename params_t, typename input_t,
Expand Down
24 changes: 12 additions & 12 deletions src/interface/blas3/backend/amd_gpu.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,17 +31,17 @@ namespace gemm {

namespace backend {
template <bool _t_a, bool _t_b, bool s_a, bool s_b, bool is_beta_zero,
typename element_in_t, typename element_out_t, typename sb_handle_t,
typename container_0_t, typename container_1_t,
typename container_2_t, typename index_t>
typename std::enable_if<is_sycl_scalar<element_in_t>::value,
typename sb_handle_t, typename container_0_t, typename container_1_t,
typename container_2_t, typename element_t, typename index_t>
typename std::enable_if<is_sycl_scalar<element_t>::value,
typename sb_handle_t::event_t>::type
_gemm(sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K,
element_out_t _alpha, container_0_t _a, index_t _lda, index_t _stridea,
container_1_t _b, index_t _ldb, index_t _strideb, element_out_t _beta,
element_t _alpha, container_0_t _a, index_t _lda, index_t _stridea,
container_1_t _b, index_t _ldb, index_t _strideb, element_t _beta,
container_2_t _c, index_t _ldc, index_t _stridec, index_t batch_size,
gemm_batch_type_t batch_type,
const typename sb_handle_t::event_t& _dependencies) {
using element_in_t = typename ValueType<container_0_t>::type;
// Unused configuration cases
if constexpr (s_a && s_b || ((s_a && _t_b) || (s_b && _t_a))) {
return _dependencies;
Expand Down Expand Up @@ -233,17 +233,17 @@ _gemm(sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K,
// Complex Configurations
#ifdef BLAS_ENABLE_COMPLEX
template <bool _t_a, bool _t_b, bool s_a, bool s_b, bool is_beta_zero,
typename element_in_t, typename element_out_t, typename sb_handle_t,
typename container_0_t, typename container_1_t,
typename container_2_t, typename index_t>
typename std::enable_if<is_complex_sycl<element_in_t>::value,
typename sb_handle_t, typename container_0_t, typename container_1_t,
typename container_2_t, typename element_t, typename index_t>
typename std::enable_if<is_complex_sycl<element_t>::value,
typename sb_handle_t::event_t>::type
_gemm(sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K,
element_out_t _alpha, container_0_t _a, index_t _lda, index_t _stridea,
container_1_t _b, index_t _ldb, index_t _strideb, element_out_t _beta,
element_t _alpha, container_0_t _a, index_t _lda, index_t _stridea,
container_1_t _b, index_t _ldb, index_t _strideb, element_t _beta,
container_2_t _c, index_t _ldc, index_t _stridec, index_t batch_size,
gemm_batch_type_t batch_type,
const typename sb_handle_t::event_t& _dependencies) {
using element_in_t = typename ValueType<container_0_t>::type;
static constexpr int ClSize = 64;
static constexpr int tileWgSize = ClSize / sizeof(element_in_t);
/* Tall & Skinny matrices. */
Expand Down
38 changes: 18 additions & 20 deletions src/interface/blas3/backend/default_cpu.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,15 +31,15 @@ namespace gemm {
namespace backend {

template <bool _t_a, bool _t_b, bool s_a, bool s_b, bool is_beta_zero,
typename element_in_t, typename element_out_t, typename sb_handle_t,
typename container_0_t, typename container_1_t,
typename container_2_t, typename index_t>
typename std::enable_if<is_sycl_scalar<element_in_t>::value &&
!is_half<element_in_t>::value,
typename sb_handle_t::event_t>::type
typename sb_handle_t, typename container_0_t, typename container_1_t,
typename container_2_t, typename element_t, typename index_t>
typename std::enable_if<
is_sycl_scalar<element_t>::value &&
!is_half<typename ValueType<container_0_t>::type>::value,
typename sb_handle_t::event_t>::type
_gemm(sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K,
element_out_t _alpha, container_0_t _a, index_t _lda, index_t _stridea,
container_1_t _b, index_t _ldb, index_t _strideb, element_out_t _beta,
element_t _alpha, container_0_t _a, index_t _lda, index_t _stridea,
container_1_t _b, index_t _ldb, index_t _strideb, element_t _beta,
container_2_t _c, index_t _ldc, index_t _stridec, index_t batch_size,
gemm_batch_type_t batch_type,
const typename sb_handle_t::event_t& _dependencies) {
Expand Down Expand Up @@ -123,14 +123,13 @@ _gemm(sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K,

// Half Configurations
template <bool _t_a, bool _t_b, bool s_a, bool s_b, bool is_beta_zero,
typename element_in_t, typename element_out_t, typename sb_handle_t,
typename container_0_t, typename container_1_t,
typename container_2_t, typename index_t>
typename std::enable_if<is_half<element_in_t>::value,
typename sb_handle_t, typename container_0_t, typename container_1_t,
typename container_2_t, typename element_t, typename index_t>
typename std::enable_if<is_half<typename ValueType<container_0_t>::type>::value,
typename sb_handle_t::event_t>::type
_gemm(sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K,
element_out_t _alpha, container_0_t _a, index_t _lda, index_t _stridea,
container_1_t _b, index_t _ldb, index_t _strideb, element_out_t _beta,
element_t _alpha, container_0_t _a, index_t _lda, index_t _stridea,
container_1_t _b, index_t _ldb, index_t _strideb, element_t _beta,
container_2_t _c, index_t _ldc, index_t _stridec, index_t batch_size,
gemm_batch_type_t batch_type,
const typename sb_handle_t::event_t& _dependencies) {
Expand Down Expand Up @@ -167,14 +166,13 @@ _gemm(sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K,
// Complex Configurations
#ifdef BLAS_ENABLE_COMPLEX
template <bool _t_a, bool _t_b, bool s_a, bool s_b, bool is_beta_zero,
typename element_in_t, typename element_out_t, typename sb_handle_t,
typename container_0_t, typename container_1_t,
typename container_2_t, typename index_t>
typename std::enable_if<is_complex_sycl<element_in_t>::value,
typename sb_handle_t, typename container_0_t, typename container_1_t,
typename container_2_t, typename element_t, typename index_t>
typename std::enable_if<is_complex_sycl<element_t>::value,
typename sb_handle_t::event_t>::type
_gemm(sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K,
element_out_t _alpha, container_0_t _a, index_t _lda, index_t _stridea,
container_1_t _b, index_t _ldb, index_t _strideb, element_out_t _beta,
element_t _alpha, container_0_t _a, index_t _lda, index_t _stridea,
container_1_t _b, index_t _ldb, index_t _strideb, element_t _beta,
container_2_t _c, index_t _ldc, index_t _stridec, index_t batch_size,
gemm_batch_type_t batch_type,
const typename sb_handle_t::event_t& _dependencies) {
Expand Down
41 changes: 20 additions & 21 deletions src/interface/blas3/backend/intel_gpu.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,18 +30,19 @@ namespace blas {
namespace gemm {
namespace backend {
template <bool _t_a, bool _t_b, bool s_a, bool s_b, bool is_beta_zero,
typename element_in_t, typename element_out_t, typename sb_handle_t,
typename container_0_t, typename container_1_t,
typename container_2_t, typename index_t>
typename std::enable_if<is_sycl_scalar<element_in_t>::value &&
!is_half<element_in_t>::value,
typename sb_handle_t::event_t>::type
typename sb_handle_t, typename container_0_t, typename container_1_t,
typename container_2_t, typename element_t, typename index_t>
typename std::enable_if<
is_sycl_scalar<element_t>::value &&
!is_half<typename ValueType<container_0_t>::type>::value,
typename sb_handle_t::event_t>::type
_gemm(sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K,
element_out_t _alpha, container_0_t _a, index_t _lda, index_t _stridea,
container_1_t _b, index_t _ldb, index_t _strideb, element_out_t _beta,
element_t _alpha, container_0_t _a, index_t _lda, index_t _stridea,
container_1_t _b, index_t _ldb, index_t _strideb, element_t _beta,
container_2_t _c, index_t _ldc, index_t _stridec, index_t batch_size,
gemm_batch_type_t batch_type,
const typename sb_handle_t::event_t& _dependencies) {
using element_in_t = typename ValueType<container_0_t>::type;
// Unused configuration cases
if constexpr (s_a && s_b || ((s_a && _t_b) || (s_b && _t_a))) {
return _dependencies;
Expand Down Expand Up @@ -213,14 +214,13 @@ _gemm(sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K,

// Half Configurations
template <bool _t_a, bool _t_b, bool s_a, bool s_b, bool is_beta_zero,
typename element_in_t, typename element_out_t, typename sb_handle_t,
typename container_0_t, typename container_1_t,
typename container_2_t, typename index_t>
typename std::enable_if<is_half<element_in_t>::value,
typename sb_handle_t, typename container_0_t, typename container_1_t,
typename container_2_t, typename element_t, typename index_t>
typename std::enable_if<is_half<typename ValueType<container_0_t>::type>::value,
typename sb_handle_t::event_t>::type
_gemm(sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K,
element_out_t _alpha, container_0_t _a, index_t _lda, index_t _stridea,
container_1_t _b, index_t _ldb, index_t _strideb, element_out_t _beta,
element_t _alpha, container_0_t _a, index_t _lda, index_t _stridea,
container_1_t _b, index_t _ldb, index_t _strideb, element_t _beta,
container_2_t _c, index_t _ldc, index_t _stridec, index_t batch_size,
gemm_batch_type_t batch_type,
const typename sb_handle_t::event_t& _dependencies) {
Expand Down Expand Up @@ -275,20 +275,19 @@ _gemm(sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K,
// Complex Configurations
#ifdef BLAS_ENABLE_COMPLEX
template <bool _t_a, bool _t_b, bool s_a, bool s_b, bool is_beta_zero,
typename element_in_t, typename element_out_t, typename sb_handle_t,
typename container_0_t, typename container_1_t,
typename container_2_t, typename index_t>
typename std::enable_if<is_complex_sycl<element_in_t>::value,
typename sb_handle_t, typename container_0_t, typename container_1_t,
typename container_2_t, typename element_t, typename index_t>
typename std::enable_if<is_complex_sycl<element_t>::value,
typename sb_handle_t::event_t>::type
_gemm(sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K,
element_out_t _alpha, container_0_t _a, index_t _lda, index_t _stridea,
container_1_t _b, index_t _ldb, index_t _strideb, element_out_t _beta,
element_t _alpha, container_0_t _a, index_t _lda, index_t _stridea,
container_1_t _b, index_t _ldb, index_t _strideb, element_t _beta,
container_2_t _c, index_t _ldc, index_t _stridec, index_t batch_size,
gemm_batch_type_t batch_type,
const typename sb_handle_t::event_t& _dependencies) {
#ifdef GEMM_TALL_SKINNY_SUPPORT
if (batch_size == 1) {
constexpr int wg_size = sizeof(element_in_t) == 16 ? 4 : 8;
constexpr int wg_size = sizeof(element_t) == 16 ? 4 : 8;
return blas::Gemm_Launcher<
container_0_t, container_1_t, container_2_t, 64, true, true, true, 64,
Tile<4, 4, wg_size, wg_size>, _t_a, _t_b, false, false,
Expand Down
Loading

0 comments on commit aff966c

Please sign in to comment.