Added joint_matrix for Intel PVC and ARC
muhammad-tanvir-1211 committed Mar 28, 2024
1 parent eff2458 commit 81b98a3
Showing 7 changed files with 1,085 additions and 5 deletions.
29 changes: 25 additions & 4 deletions benchmark/portblas/blas3/gemm.cpp
@@ -24,6 +24,9 @@
 **************************************************************************/
 
 #include "../utils.hpp"
+#include "../../../test/blas_test.hpp"
+
+using namespace cl::sycl::ext::oneapi;
 
 constexpr blas_benchmark::utils::Level3Op benchmark_op =
     blas_benchmark::utils::Level3Op::gemm;
@@ -61,19 +64,37 @@ void run(benchmark::State& state, blas::SB_Handle* sb_handle_ptr, int t1,
   }
 
   // Matrices
-  std::vector<scalar_t> a = blas_benchmark::utils::random_data<scalar_t>(m * k);
-  std::vector<scalar_t> b = blas_benchmark::utils::random_data<scalar_t>(k * n);
+  std::vector<scalar_t> a(m * k);
+  std::vector<scalar_t> b(n * k);
   std::vector<scalar_t> c =
       blas_benchmark::utils::const_data<scalar_t>(m * n, 0);
+
+  fill_random_with_range(a, scalar_t{1}, scalar_t{2});
+  fill_random_with_range(b, scalar_t{1}, scalar_t{2});
+
+  set_to_zero_last_nbits(a, 16);
+  set_to_zero_last_nbits(b, 16);
+  set_to_zero_last_nbits(c, 16);
+
+  std::vector<bfloat16> a_bf16(m * k * sizeof(scalar_t) / sizeof(bfloat16));
+  std::vector<bfloat16> b_bf16(n * k * sizeof(scalar_t) / sizeof(bfloat16));
+
+  for (int i = 0; i < b.size(); i++) {
+    b_bf16[i] = static_cast<bfloat16>(b[i]);
+  }
+
+  for (int i = 0; i < a.size(); i++) {
+    a_bf16[i] = static_cast<bfloat16>(a[i]);
+  }
 
   auto a_gpu = blas::helper::allocate<mem_alloc, scalar_t>(m * k, q);
   auto b_gpu = blas::helper::allocate<mem_alloc, scalar_t>(k * n, q);
   auto c_gpu = blas::helper::allocate<mem_alloc, scalar_t>(m * n, q);
 
   auto copy_a =
-      blas::helper::copy_to_device<scalar_t>(q, a.data(), a_gpu, m * k);
+      blas::helper::copy_to_device<scalar_t>(q, reinterpret_cast<scalar_t*>(a_bf16.data()), a_gpu, m * k);
   auto copy_b =
-      blas::helper::copy_to_device<scalar_t>(q, b.data(), b_gpu, n * k);
+      blas::helper::copy_to_device<scalar_t>(q, reinterpret_cast<scalar_t*>(b_bf16.data()), b_gpu, n * k);
   auto copy_c =
       blas::helper::copy_to_device<scalar_t>(q, c.data(), c_gpu, m * n);
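A note on the new benchmark data flow (illustrative gloss, not from the commit): the inputs are drawn from the range 1 to 2, then set_to_zero_last_nbits clears the low 16 bits of each float before the bfloat16 copies are made. Since bfloat16 is exactly the high 16 bits of an IEEE-754 binary32 value, a float whose low 16 bits are zero converts to bfloat16 and back without rounding error, keeping the device result comparable against a float reference. A minimal standalone sketch of that property (the helper name mirrors set_to_zero_last_nbits from test/blas_test.hpp; modelling bfloat16 as a uint16_t is an assumption for illustration):

    #include <cassert>
    #include <cstdint>
    #include <cstring>

    // Keep only the high 16 bits of a binary32 value, as the benchmark does
    // to its inputs before converting them to bfloat16.
    float zero_low_16_bits(float x) {
      std::uint32_t bits;
      std::memcpy(&bits, &x, sizeof(bits));
      bits &= 0xFFFF0000u;
      std::memcpy(&x, &bits, sizeof(x));
      return x;
    }

    int main() {
      float x = zero_low_16_bits(1.2345f);
      // Model bfloat16 as the top half of the float's bit pattern.
      std::uint32_t bits;
      std::memcpy(&bits, &x, sizeof(bits));
      std::uint16_t bf16 = static_cast<std::uint16_t>(bits >> 16);
      // Widening back pads the low half with zeros, reproducing x exactly.
      std::uint32_t widened = static_cast<std::uint32_t>(bf16) << 16;
      float y;
      std::memcpy(&y, &widened, sizeof(y));
      assert(std::memcmp(&x, &y, sizeof(float)) == 0);
      return 0;
    }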
25 changes: 25 additions & 0 deletions cmake/CmakeFunctionHelper.cmake
@@ -399,6 +399,31 @@ if(${TUNING_TARGET} STREQUAL "INTEL_GPU")
     "float"
     "double"
   )
+
+  # Joint Matrix specific GEMM configurations
+  if(${DPCPP_SYCL_TARGET} STREQUAL "intel_gpu_pvc")
+    add_gemm_configuration(
+      "float" 64 "false" "false" "false"
+      64 2 2 8 8 16 1 1 1 1 1 8 16 16 cl::sycl::ext::oneapi::bfloat16 float "no_local" "standard" "none" 1 "strided" "true")
+    add_gemm_configuration(
+      "float" 64 "false" "false" "false"
+      64 4 2 8 8 16 1 1 1 1 1 8 16 16 cl::sycl::ext::oneapi::bfloat16 float "no_local" "standard" "none" 1 "strided" "true")
+    add_gemm_configuration(
+      "float" 64 "false" "false" "false"
+      64 8 2 4 16 16 1 1 1 1 1 8 16 16 cl::sycl::ext::oneapi::bfloat16 float "no_local" "standard" "none" 1 "strided" "true")
+  endif()
+
+  if(${DPCPP_SYCL_TARGET} STREQUAL "intel_gpu_dg2_g12")
+    add_gemm_configuration(
+      "float" 64 "false" "false" "false"
+      64 2 2 8 8 8 1 1 1 1 1 8 8 16 cl::sycl::ext::oneapi::bfloat16 float "no_local" "standard" "none" 1 "strided" "true")
+    add_gemm_configuration(
+      "float" 64 "false" "false" "false"
+      64 4 2 8 8 8 1 1 1 1 1 8 8 16 cl::sycl::ext::oneapi::bfloat16 float "no_local" "standard" "none" 1 "strided" "true")
+    add_gemm_configuration(
+      "float" 64 "false" "false" "false"
+      64 8 2 4 16 8 1 1 1 1 1 8 8 16 cl::sycl::ext::oneapi::bfloat16 float "no_local" "standard" "none" 1 "strided" "true")
+  endif()
 foreach(data ${supported_types})
   add_gemm_configuration(
     "${data}" 64 "false" "false" "false"
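A reading of the numeric argument lists above (my inference, not stated in the commit): each add_gemm_configuration mirrors a Tile instantiation in src/interface/blas3/backend/intel_gpu.hpp below, and the trailing 8 16 16 (PVC) versus 8 8 16 (DG2/ARC) triples appear to be the joint_matrix fragment shape M, N, K that each GPU's matrix engines support for bfloat16 inputs with float accumulation. For example, the first PVC configuration corresponds to a tile along these lines (hypothetical alias; Tile and bfloat16 are the portBLAS and SYCL-extension types used below):

    // Illustrative only: parameter-for-parameter match with the cmake line
    // "float" 64 ... 64 2 2 8 8 16 1 1 1 1 1 8 16 16 bfloat16 float ...
    using pvc_small_tile_t =
        Tile<2, 2, 8, 8, 16, 1, 1, 1, 1, 1, /*jm M=*/8, /*jm N=*/16,
             /*jm K=*/16, cl::sycl::ext::oneapi::bfloat16, float>;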
7 changes: 7 additions & 0 deletions cmake/Modules/FindDPCPP.cmake
@@ -79,6 +79,13 @@ if (${DPCPP_SYCL_TARGET} STREQUAL "nvptx64-nvidia-cuda")
     list(APPEND DPCPP_FLAGS "-DSYCL_EXT_ONEAPI_MATRIX_VERSION=4")
     list(APPEND DPCPP_FLAGS "-DSB_ENABLE_JOINT_MATRIX=1")
   endif()
+elseif(${DPCPP_SYCL_TARGET} STREQUAL "intel_gpu_pvc" OR ${DPCPP_SYCL_TARGET} STREQUAL "intel_gpu_dg2_g12")
+  list(APPEND DPCPP_FLAGS "-DSYCL_EXT_INTEL_MATRIX=1")
+  if(${DPCPP_SYCL_TARGET} STREQUAL "intel_gpu_pvc")
+    list(APPEND DPCPP_FLAGS "-DSB_ENABLE_JOINT_MATRIX_PVC=1")
+  else()
+    list(APPEND DPCPP_FLAGS "-DSB_ENABLE_JOINT_MATRIX_ARC=1")
+  endif()
 endif()
 
 # add compiler directive to enable USM code
86 changes: 86 additions & 0 deletions src/interface/blas3/backend/intel_gpu.hpp
@@ -57,6 +57,92 @@ _gemm(sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K,
                   _stridea, _b, _ldb, _strideb, _beta, _c, _ldc,
                   _stridec, batch_size, _dependencies);
 }
+#if defined SB_ENABLE_JOINT_MATRIX_PVC || defined SB_ENABLE_JOINT_MATRIX_ARC
+  const char* en_joint_matrix = std::getenv("SB_ENABLE_JOINT_MATRIX");
+  // TODO: change this if condition to enable the code for bfloat16 input type
+  if (en_joint_matrix != NULL && *en_joint_matrix == '1' && !s_a && !s_b &&
+      std::is_same<typename ValueType<container_0_t>::type, float>::value &&
+      std::is_same<typename ValueType<container_1_t>::type, float>::value &&
+      std::is_same<typename ValueType<container_2_t>::type, float>::value) {
+#ifdef SB_ENABLE_JOINT_MATRIX_PVC
+    if (_M > 1024 && _N > 1024) {
+      return blas::Gemm_Launcher<
+          container_0_t, container_1_t, container_2_t, 64, false, false,
+          false, 64,
+          Tile<8, 2, 4, 16, 16, 1, 1, 1, 1, 1, 8, 16, 16,
+               cl::sycl::ext::oneapi::bfloat16, float>,
+          _t_b, _t_a, s_b, s_a, static_cast<int>(gemm_memory_t::no_local),
+          static_cast<int>(gemm_algorithm_t::standard),
+          static_cast<int>(gemm_vectorization_t::none), is_beta_zero, 1,
+          static_cast<int>(gemm_batch_type_t::strided),
+          true>::template _select_gemm(sb_handle, _N, _M, _K, _alpha, _b,
+                                       _ldb, _strideb, _a, _lda, _stridea,
+                                       _beta, _c, _ldc, _stridec, batch_size,
+                                       _dependencies);
+    } else if (_M > 64 && _N > 64) {
+      return blas::Gemm_Launcher<
+          container_0_t, container_1_t, container_2_t, 64, false, false,
+          false, 64,
+          Tile<4, 2, 8, 8, 16, 1, 1, 1, 1, 1, 8, 16, 16,
+               cl::sycl::ext::oneapi::bfloat16, float>,
+          _t_b, _t_a, s_b, s_a, static_cast<int>(gemm_memory_t::no_local),
+          static_cast<int>(gemm_algorithm_t::standard),
+          static_cast<int>(gemm_vectorization_t::none), is_beta_zero, 1,
+          static_cast<int>(gemm_batch_type_t::strided),
+          true>::template _select_gemm(sb_handle, _N, _M, _K, _alpha, _b,
+                                       _ldb, _strideb, _a, _lda, _stridea,
+                                       _beta, _c, _ldc, _stridec, batch_size,
+                                       _dependencies);
+
+    } else {
+      return blas::Gemm_Launcher<
+          container_0_t, container_1_t, container_2_t, 64, false, false,
+          false, 64,
+          Tile<2, 2, 8, 8, 16, 1, 1, 1, 1, 1, 8, 16, 16,
+               cl::sycl::ext::oneapi::bfloat16, float>,
+          _t_b, _t_a, s_b, s_a, static_cast<int>(gemm_memory_t::no_local),
+          static_cast<int>(gemm_algorithm_t::standard),
+          static_cast<int>(gemm_vectorization_t::none), is_beta_zero, 1,
+          static_cast<int>(gemm_batch_type_t::strided),
+          true>::template _select_gemm(sb_handle, _N, _M, _K, _alpha, _b,
+                                       _ldb, _strideb, _a, _lda, _stridea,
+                                       _beta, _c, _ldc, _stridec, batch_size,
+                                       _dependencies);
+    }
+#else  // SB_ENABLE_JOINT_MATRIX_ARC
+    if (_M > 64 && _N > 64) {
+      return blas::Gemm_Launcher<
+          container_0_t, container_1_t, container_2_t, 64, false, false,
+          false, 64,
+          Tile<4, 2, 8, 8, 8, 1, 1, 1, 1, 1, 8, 8, 16,
+               cl::sycl::ext::oneapi::bfloat16, float>,
+          _t_b, _t_a, s_b, s_a, static_cast<int>(gemm_memory_t::no_local),
+          static_cast<int>(gemm_algorithm_t::standard),
+          static_cast<int>(gemm_vectorization_t::none), is_beta_zero, 1,
+          static_cast<int>(gemm_batch_type_t::strided),
+          true>::template _select_gemm(sb_handle, _N, _M, _K, _alpha, _b,
+                                       _ldb, _strideb, _a, _lda, _stridea,
+                                       _beta, _c, _ldc, _stridec, batch_size,
+                                       _dependencies);
+
+    } else {
+      return blas::Gemm_Launcher<
+          container_0_t, container_1_t, container_2_t, 64, false, false,
+          false, 64,
+          Tile<2, 2, 8, 8, 8, 1, 1, 1, 1, 1, 8, 8, 16,
+               cl::sycl::ext::oneapi::bfloat16, float>,
+          _t_b, _t_a, s_b, s_a, static_cast<int>(gemm_memory_t::no_local),
+          static_cast<int>(gemm_algorithm_t::standard),
+          static_cast<int>(gemm_vectorization_t::none), is_beta_zero, 1,
+          static_cast<int>(gemm_batch_type_t::strided),
+          true>::template _select_gemm(sb_handle, _N, _M, _K, _alpha, _b,
+                                       _ldb, _strideb, _a, _lda, _stridea,
+                                       _beta, _c, _ldc, _stridec, batch_size,
+                                       _dependencies);
+    }
+#endif
+  }
+#endif  // SB_ENABLE_JOINT_MATRIX_PVC || defined SB_ENABLE_JOINT_MATRIX_ARC
 #ifdef GEMM_TALL_SKINNY_SUPPORT
   if (!s_a && !s_b) {
     /* Tall & Skinny matrices. */
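Two usage notes on the dispatch above (illustrative gloss, not from the commit). The joint_matrix path is gated twice: it is compiled in only under SB_ENABLE_JOINT_MATRIX_PVC or SB_ENABLE_JOINT_MATRIX_ARC, and taken at run time only when the SB_ENABLE_JOINT_MATRIX environment variable starts with '1'. The problem-size thresholds on _M and _N (above 1024, above 64, or smaller) then select progressively smaller tiles. A minimal sketch of the runtime gate, mirroring the getenv check in the diff (setenv is POSIX; the surrounding program is hypothetical):

    #include <cstdlib>

    int main() {
      // Opt in before the first GEMM call; '1' must be the first character.
      setenv("SB_ENABLE_JOINT_MATRIX", "1", /*overwrite=*/1);  // POSIX-only
      // Mirrors the guard inside _gemm above.
      const char* en_joint_matrix = std::getenv("SB_ENABLE_JOINT_MATRIX");
      bool joint_matrix_enabled =
          (en_joint_matrix != nullptr && *en_joint_matrix == '1');
      return joint_matrix_enabled ? 0 : 1;
    }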