Skip to content

Commit

Permalink
updated gemm template for joint matrix case
Browse files Browse the repository at this point in the history
  • Loading branch information
OuadiElfarouki committed Feb 29, 2024
1 parent ac95e7e commit be898a9
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 5 deletions.
4 changes: 2 additions & 2 deletions src/operations/blas3/gemm_local.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ class Gemm<input_t, output_t, DoubleBuffer, NbcA, NbcB, ClSize, TileType,
static constexpr bool symm_a = SymmA;
static constexpr bool symm_b = SymmB;

//! @brief Number of elements which fit within a cache line.
//! @brief Number of input elements which fit within a cache line.
static constexpr index_t cl_elems = ClSize / sizeof(element_in_t);
//! @brief Number of work items within a work group
static constexpr index_t wg_size = wg_rows * wg_cols;
Expand All @@ -128,7 +128,7 @@ class Gemm<input_t, output_t, DoubleBuffer, NbcA, NbcB, ClSize, TileType,

static_assert(wg_size % cl_elems == 0,
"Work group size should be a multiple "
"of elements in a cache line\n"
"of input elements in a cache line\n"
" --- this is ensured iff:"
" cl_size | sizeof(element_in_t) * wg_rows * wg_cols");

Expand Down
7 changes: 4 additions & 3 deletions src/operations/blas3/gemm_local_joint_matrix.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ namespace blas {
* @tparam TransA iff true, matrix A will be transposed on the fly
* @tparam TransB iff true, matrix B will be transposed on the fly
* @tparam element_t type of matrix elements
* @tparam element_out_t type of output matrix elements (C) and scalars
* @tparam is_beta_zero True if beta == 0.
* @tparam VectorSize The packet size to be used for vectorization.
* @tparam batch_type the type of batch strided/interleaved
Expand All @@ -70,10 +71,10 @@ namespace blas {
*/
template <typename input_t, typename output_t, bool DoubleBuffer, bool NbcA,
bool NbcB, int ClSize, typename TileType, bool TransA, bool TransB,
bool SymmA, bool SymmB, typename element_t, bool is_beta_zero,
int VectorSize>
bool SymmA, bool SymmB, typename element_t, typename element_out_t,
bool is_beta_zero, int VectorSize>
class Gemm<input_t, output_t, DoubleBuffer, NbcA, NbcB, ClSize, TileType,
TransA, TransB, SymmA, SymmB, element_t, is_beta_zero,
TransA, TransB, SymmA, SymmB, element_t, element_out_t, is_beta_zero,
static_cast<int>(gemm_memory_t::local),
static_cast<int>(gemm_algorithm_t::standard),
static_cast<int>(gemm_vectorization_t::none), VectorSize,
Expand Down

0 comments on commit be898a9

Please sign in to comment.