Skip to content
This repository has been archived by the owner on Jan 13, 2025. It is now read-only.

Update gemm config for CPUs & Update complex header #484

Merged
merged 3 commits into from
Dec 19, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .github/workflows/build-and-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ jobs:
# cancel outdated builds on pull requests.
skip-check:
continue-on-error: true
runs-on: ubuntu-20.04
runs-on: ubuntu-22.04
# Map a step output to a job output
outputs:
should_skip: ${{ steps.skip_check.outputs.should_skip }}
Expand All @@ -34,7 +34,7 @@ jobs:
Build-and-Test:
needs: skip-check
if: ${{ needs.pre_job.outputs.should_skip != 'true' }}
runs-on: ubuntu-20.04
runs-on: ubuntu-22.04
strategy:
fail-fast: false
matrix:
Expand Down
2 changes: 1 addition & 1 deletion .scripts/build_dpcpp.sh
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ set -ev
###########################
# Get DPCPP
###########################
wget --no-verbose https://github.com/intel/llvm/releases/download/sycl-nightly/20230727/dpcpp-compiler.tar.gz -O dpcpp-compiler.tar.gz
wget --no-verbose https://github.com/intel/llvm/releases/download/nightly-2023-12-06/sycl_linux.tar.gz -O dpcpp-compiler.tar.gz
rm -rf /tmp/dpcpp && mkdir /tmp/dpcpp/
tar -xzf dpcpp-compiler.tar.gz -C /tmp/dpcpp --strip-components 1
ls -R /tmp/dpcpp/
2 changes: 1 addition & 1 deletion Dockerfile
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
FROM ubuntu:focal
FROM ubuntu:jammy

# Default values for the build
ARG command
Expand Down
13 changes: 9 additions & 4 deletions cmake/CmakeFunctionHelper.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -616,19 +616,24 @@ else() # default cpu backend
64 8 8 8 8 1 1 1 1 1 1 1 1 1 float float "no_local" "naive" "none" 1 "strided" "false" "false")
else()
add_gemm_configuration(
"${data}" 64 "false" "false" "false"
64 2 2 8 8 1 1 1 1 1 1 1 1 1 float float "no_local" "standard" "full" 2 "strided" "false" "false")
"${data}" 128 "false" "false" "false"
64 2 2 2 2 1 1 1 1 1 1 1 1 1 float float "no_local" "standard" "full" 2 "strided" "false" "false")
add_gemm_configuration(
"${data}" 64 "false" "false" "false"
64 8 8 8 8 1 1 1 1 1 1 1 1 1 float float "no_local" "standard" "partial" 1 "strided" "false" "false")
"${data}" 128 "false" "false" "false"
64 4 4 8 8 1 1 1 1 1 1 1 1 1 float float "no_local" "standard" "full" 1 "strided" "false" "false")
add_gemm_configuration(
"${data}" 128 "false" "false" "false"
64 4 4 4 4 1 1 1 1 1 1 1 1 1 float float "no_local" "standard" "partial" 1 "strided" "false" "false")
add_gemm_configuration(
"${data}" 64 "false" "false" "false"
64 2 2 8 8 1 1 1 1 1 1 1 1 1 float float "local" "standard" "full" 2 "strided" "false" "false")

endif()

add_gemm_configuration(
"${data}" 64 "false" "false" "false"
64 2 2 4 4 1 1 1 1 4 4 1 1 1 float float "no_local" "standard" "full" 4 "interleaved" "false" "false")

endforeach()
if(BLAS_ENABLE_COMPLEX)
# Extract list of complex<data> for each data in supported_types
Expand Down
23 changes: 17 additions & 6 deletions src/interface/blas3/backend/default_cpu.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -69,28 +69,39 @@ _gemm(sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K,
_b, _ldb, _strideb, _beta, _c, _ldc, _stridec,
batch_size, _dependencies);
#else
if (_M <= 128 && _N <= 128 && _K <= 128 && !s_a && !s_b) {
if (_M <= 128 && _N <= 128 && _K <= 256 && !s_a && !s_b) {
return blas::Gemm_Launcher<
container_0_t, container_1_t, container_2_t, 64, false, false, false,
64, Tile<2, 2, 8, 8>, _t_a, _t_b, s_a, s_b,
container_0_t, container_1_t, container_2_t, 128, false, false, false,
64, Tile<2, 2, 2, 2>, _t_a, _t_b, s_a, s_b,
static_cast<int>(gemm_memory_t::no_local),
static_cast<int>(gemm_algorithm_t::standard),
static_cast<int>(gemm_vectorization_t::full), is_beta_zero, 2,
static_cast<int>(gemm_batch_type_t::strided)>::
template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda,
_stridea, _b, _ldb, _strideb, _beta, _c, _ldc,
_stridec, batch_size, _dependencies);
} else if (!s_a && !s_b) {
} else if ((_M * _N) >= 524288 && !s_a && !s_b) {
return blas::Gemm_Launcher<
container_0_t, container_1_t, container_2_t, 64, false, false, false,
64, Tile<8, 8, 8, 8>, _t_a, _t_b, s_a, s_b,
container_0_t, container_1_t, container_2_t, 128, false, false, false,
64, Tile<4, 4, 4, 4>, _t_a, _t_b, s_a, s_b,
static_cast<int>(gemm_memory_t::no_local),
static_cast<int>(gemm_algorithm_t::standard),
static_cast<int>(gemm_vectorization_t::partial), is_beta_zero, 1,
static_cast<int>(gemm_batch_type_t::strided)>::
template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda,
_stridea, _b, _ldb, _strideb, _beta, _c, _ldc,
_stridec, batch_size, _dependencies);
} else if (!s_a && !s_b) {
return blas::Gemm_Launcher<
container_0_t, container_1_t, container_2_t, 128, false, false, false,
64, Tile<4, 4, 8, 8>, _t_a, _t_b, s_a, s_b,
static_cast<int>(gemm_memory_t::no_local),
static_cast<int>(gemm_algorithm_t::standard),
static_cast<int>(gemm_vectorization_t::full), is_beta_zero, 1,
static_cast<int>(gemm_batch_type_t::strided)>::
template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda,
_stridea, _b, _ldb, _strideb, _beta, _c, _ldc,
_stridec, batch_size, _dependencies);
} else {
return blas::Gemm_Launcher<
container_0_t, container_1_t, container_2_t, 64, false, false, false,
Expand Down