Skip to content

Commit

Permalink
addressed PR comments
Browse files: browse the repository at this point in the history
  • Loading branch information
OuadiElfarouki committed Mar 28, 2024
1 parent aff966c commit c7aedc9
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 44 deletions.
20 changes: 4 additions & 16 deletions src/operations/blas3/gemm_interleaved.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -496,22 +496,11 @@ class Gemm<input_t, output_t, /* DoubleBuffer = */ false, /* NbcA = */ false,
for (int j = 0; j < item_rows; ++j) {
#pragma unroll
for (int b = 0; b < item_batchs / VectorSize; ++b) {
if constexpr (std::is_same_v<value_t, element_t>
#ifdef __ADAPTIVECPP__
if constexpr (is_half<value_t>::value ||
!std::is_same_v<value_t, element_t>) {
#pragma unroll
for (int v = 0; v < VectorSize; ++v) {
(*reg_res)[v] = reg_a[j * (item_batchs / VectorSize) + b][v] *
reg_b[i * (item_batchs / VectorSize) + b][v] +
(*reg_res)[v];
}
} else {
*reg_res = cl::sycl::mad(reg_a[j * (item_batchs / VectorSize) + b],
reg_b[i * (item_batchs / VectorSize) + b],
*reg_res);
}
#else
if constexpr (std::is_same_v<value_t, element_t>) {
&& !is_half<value_t>::value
#endif // __ADAPTIVECPP__
) {
*reg_res = cl::sycl::mad(reg_a[j * (item_batchs / VectorSize) + b],
reg_b[i * (item_batchs / VectorSize) + b],
*reg_res);
Expand All @@ -523,7 +512,6 @@ class Gemm<input_t, output_t, /* DoubleBuffer = */ false, /* NbcA = */ false,
(*reg_res)[v];
}
}
#endif // __ADAPTIVECPP__
++reg_res;
}
}
Expand Down
3 changes: 1 addition & 2 deletions test/blas_test_macros.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -204,8 +204,7 @@
name_generator);
#endif // BLAS_ENABLE_COMPLEX

/** Registers test for all supported data
* types
/** Registers test for all supported data types
* @see BLAS_REGISTER_TEST_CUSTOM_NAME
*/
#define BLAS_REGISTER_TEST_ALL(class_name, combination_t, combination, \
Expand Down
58 changes: 32 additions & 26 deletions test/unittest/blas3/blas3_gemm_common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -90,11 +90,12 @@ inline std::vector<scalar_t> interleaved_to_strided(
* @brief verify gemm correctness against reference BLAS.
*
* @tparam scalar_in_t type of input matrices elements (A, B)
* @tparam scalar_t type of output matrix elements (C) and scalars
* @tparam scalar_out_t type of output matrix elements (C) and of the scalars
* alpha and beta (gemm_arguments)
*/
template <typename scalar_in_t, typename scalar_t, helper::AllocType mem_alloc>
inline void verify_gemm(const gemm_arguments_t<scalar_t> arguments) {
template <typename scalar_in_t, typename scalar_out_t,
helper::AllocType mem_alloc>
inline void verify_gemm(const gemm_arguments_t<scalar_out_t> arguments) {
std::string alloc;
index_t offset;
index_t batch;
Expand All @@ -103,8 +104,8 @@ inline void verify_gemm(const gemm_arguments_t<scalar_t> arguments) {
index_t k;
char transa;
char transb;
scalar_t alpha;
scalar_t beta;
scalar_out_t alpha;
scalar_out_t beta;
index_t lda_mul;
index_t ldb_mul;
index_t ldc_mul;
Expand Down Expand Up @@ -137,7 +138,7 @@ inline void verify_gemm(const gemm_arguments_t<scalar_t> arguments) {

std::vector<scalar_in_t> a_m(buffer_size_a);
std::vector<scalar_in_t> b_m(buffer_size_b);
std::vector<scalar_t> c_m_gpu(buffer_size_c);
std::vector<scalar_out_t> c_m_gpu(buffer_size_c);

fill_random(a_m);
fill_random(b_m);
Expand Down Expand Up @@ -179,7 +180,8 @@ inline void verify_gemm(const gemm_arguments_t<scalar_t> arguments) {
blas::helper::allocate<mem_alloc, scalar_in_t>(buffer_size_a, q);
auto m_b_gpu =
blas::helper::allocate<mem_alloc, scalar_in_t>(buffer_size_b, q);
auto m_c_gpu = blas::helper::allocate<mem_alloc, scalar_t>(buffer_size_c, q);
auto m_c_gpu =
blas::helper::allocate<mem_alloc, scalar_out_t>(buffer_size_c, q);

auto copy_a =
blas::helper::copy_to_device(q, a_m.data(), m_a_gpu, buffer_size_a);
Expand Down Expand Up @@ -227,8 +229,8 @@ inline void verify_gemm(const gemm_arguments_t<scalar_t> arguments) {
helper::deallocate<mem_alloc>(m_c_gpu, q);
}

template <typename scalar_in_t, typename scalar_t = scalar_in_t>
inline void verify_gemm(const gemm_arguments_t<scalar_t> arguments) {
template <typename scalar_in_t, typename scalar_out_t = scalar_in_t>
inline void verify_gemm(const gemm_arguments_t<scalar_out_t> arguments) {
std::string alloc;
index_t offset;
index_t batch;
Expand All @@ -237,8 +239,8 @@ inline void verify_gemm(const gemm_arguments_t<scalar_t> arguments) {
index_t k;
char transa;
char transb;
scalar_t alpha;
scalar_t beta;
scalar_out_t alpha;
scalar_out_t beta;
index_t lda_mul;
index_t ldb_mul;
index_t ldc_mul;
Expand All @@ -248,12 +250,13 @@ inline void verify_gemm(const gemm_arguments_t<scalar_t> arguments) {

if (alloc == "usm") {
#ifdef SB_ENABLE_USM
verify_gemm<scalar_in_t, scalar_t, helper::AllocType::usm>(arguments);
verify_gemm<scalar_in_t, scalar_out_t, helper::AllocType::usm>(arguments);
#else
GTEST_SKIP();
#endif
} else {
verify_gemm<scalar_in_t, scalar_t, helper::AllocType::buffer>(arguments);
verify_gemm<scalar_in_t, scalar_out_t, helper::AllocType::buffer>(
arguments);
}
}

Expand All @@ -279,12 +282,13 @@ static std::string generate_name(
* @brief verify gemm batched-strided correctness against reference BLAS.
*
* @tparam scalar_in_t type of input matrices elements (A, B)
* @tparam scalar_t type of output matrix elements (C) and scalars
* @tparam scalar_out_t type of output matrix elements (C) and of the scalars
* alpha and beta (gemm_batched_strided_arguments)
*/
template <typename scalar_in_t, typename scalar_t, helper::AllocType mem_alloc>
template <typename scalar_in_t, typename scalar_out_t,
helper::AllocType mem_alloc>
inline void verify_gemm(
const gemm_batched_strided_arguments_t<scalar_t> arguments) {
const gemm_batched_strided_arguments_t<scalar_out_t> arguments) {
std::string alloc;
index_t offset;
index_t batch;
Expand All @@ -293,8 +297,8 @@ inline void verify_gemm(
index_t k;
char transa;
char transb;
scalar_t alpha;
scalar_t beta;
scalar_out_t alpha;
scalar_out_t beta;
index_t lda_mul;
index_t ldb_mul;
index_t ldc_mul;
Expand Down Expand Up @@ -335,7 +339,7 @@ inline void verify_gemm(

std::vector<scalar_in_t> a_m(buffer_size_a);
std::vector<scalar_in_t> b_m(buffer_size_b);
std::vector<scalar_t> c_m_gpu(buffer_size_c);
std::vector<scalar_out_t> c_m_gpu(buffer_size_c);

fill_random(a_m);
fill_random(b_m);
Expand Down Expand Up @@ -369,7 +373,8 @@ inline void verify_gemm(
blas::helper::allocate<mem_alloc, scalar_in_t>(buffer_size_a, q);
auto m_b_gpu =
blas::helper::allocate<mem_alloc, scalar_in_t>(buffer_size_b, q);
auto m_c_gpu = blas::helper::allocate<mem_alloc, scalar_t>(buffer_size_c, q);
auto m_c_gpu =
blas::helper::allocate<mem_alloc, scalar_out_t>(buffer_size_c, q);

auto copy_a =
blas::helper::copy_to_device(q, a_m.data(), m_a_gpu, buffer_size_a);
Expand Down Expand Up @@ -403,9 +408,9 @@ inline void verify_gemm(
helper::deallocate<mem_alloc>(m_c_gpu, q);
}

template <typename scalar_in_t, typename scalar_t = scalar_in_t>
template <typename scalar_in_t, typename scalar_out_t = scalar_in_t>
inline void verify_gemm(
const gemm_batched_strided_arguments_t<scalar_t> arguments) {
const gemm_batched_strided_arguments_t<scalar_out_t> arguments) {
std::string alloc;
index_t offset;
index_t batch;
Expand All @@ -414,8 +419,8 @@ inline void verify_gemm(
index_t k;
char transa;
char transb;
scalar_t alpha;
scalar_t beta;
scalar_out_t alpha;
scalar_out_t beta;
index_t lda_mul;
index_t ldb_mul;
index_t ldc_mul;
Expand All @@ -428,10 +433,11 @@ inline void verify_gemm(

if (alloc == "usm") {
#ifdef SB_ENABLE_USM
verify_gemm<scalar_in_t, scalar_t, helper::AllocType::usm>(arguments);
verify_gemm<scalar_in_t, scalar_out_t, helper::AllocType::usm>(arguments);
#endif
} else {
verify_gemm<scalar_in_t, scalar_t, helper::AllocType::buffer>(arguments);
verify_gemm<scalar_in_t, scalar_out_t, helper::AllocType::buffer>(
arguments);
}
}

Expand Down

0 comments on commit c7aedc9

Please sign in to comment.