diff --git a/.github/workflows/build-and-test.yml b/.github/workflows/build-and-test.yml index 236b88710..136c66492 100644 --- a/.github/workflows/build-and-test.yml +++ b/.github/workflows/build-and-test.yml @@ -16,7 +16,7 @@ jobs: # cancel outdated builds on pull requests. skip-check: continue-on-error: true - runs-on: ubuntu-20.04 + runs-on: ubuntu-22.04 # Map a step output to a job output outputs: should_skip: ${{ steps.skip_check.outputs.should_skip }} @@ -34,7 +34,7 @@ jobs: Build-and-Test: needs: skip-check if: ${{ needs.pre_job.outputs.should_skip != 'true' }} - runs-on: ubuntu-20.04 + runs-on: ubuntu-22.04 strategy: fail-fast: false matrix: diff --git a/.scripts/build_dpcpp.sh b/.scripts/build_dpcpp.sh index 8839e1a7b..b137bbb63 100644 --- a/.scripts/build_dpcpp.sh +++ b/.scripts/build_dpcpp.sh @@ -5,7 +5,7 @@ set -ev ########################### # Get DPCPP ########################### -wget --no-verbose https://github.com/intel/llvm/releases/download/sycl-nightly/20230727/dpcpp-compiler.tar.gz -O dpcpp-compiler.tar.gz +wget --no-verbose https://github.com/intel/llvm/releases/download/nightly-2023-12-06/sycl_linux.tar.gz -O dpcpp-compiler.tar.gz rm -rf /tmp/dpcpp && mkdir /tmp/dpcpp/ tar -xzf dpcpp-compiler.tar.gz -C /tmp/dpcpp --strip-components 1 ls -R /tmp/dpcpp/ diff --git a/Dockerfile b/Dockerfile index f60cba986..4e26a9db7 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,4 +1,4 @@ -FROM ubuntu:focal +FROM ubuntu:jammy # Default values for the build ARG command diff --git a/cmake/CmakeFunctionHelper.cmake b/cmake/CmakeFunctionHelper.cmake index f2d3244dc..c7804d0aa 100644 --- a/cmake/CmakeFunctionHelper.cmake +++ b/cmake/CmakeFunctionHelper.cmake @@ -616,19 +616,24 @@ else() # default cpu backend 64 8 8 8 8 1 1 1 1 1 1 1 1 1 float float "no_local" "naive" "none" 1 "strided" "false" "false") else() add_gemm_configuration( - "${data}" 64 "false" "false" "false" - 64 2 2 8 8 1 1 1 1 1 1 1 1 1 float float "no_local" "standard" "full" 2 "strided" "false" "false") + "${data}" 128 "false" "false" "false" + 64 2 2 2 2 1 1 1 1 1 1 1 1 1 float float "no_local" "standard" "full" 2 "strided" "false" "false") add_gemm_configuration( - "${data}" 64 "false" "false" "false" - 64 8 8 8 8 1 1 1 1 1 1 1 1 1 float float "no_local" "standard" "partial" 1 "strided" "false" "false") + "${data}" 128 "false" "false" "false" + 64 4 4 8 8 1 1 1 1 1 1 1 1 1 float float "no_local" "standard" "full" 1 "strided" "false" "false") + add_gemm_configuration( + "${data}" 128 "false" "false" "false" + 64 4 4 4 4 1 1 1 1 1 1 1 1 1 float float "no_local" "standard" "partial" 1 "strided" "false" "false") add_gemm_configuration( "${data}" 64 "false" "false" "false" 64 2 2 8 8 1 1 1 1 1 1 1 1 1 float float "local" "standard" "full" 2 "strided" "false" "false") + endif() add_gemm_configuration( "${data}" 64 "false" "false" "false" 64 2 2 4 4 1 1 1 1 4 4 1 1 1 float float "no_local" "standard" "full" 4 "interleaved" "false" "false") + endforeach() if(BLAS_ENABLE_COMPLEX) # Extract list of complex for each data in supported_types diff --git a/src/interface/blas3/backend/default_cpu.hpp b/src/interface/blas3/backend/default_cpu.hpp index 4fbf341e2..e0b519e61 100644 --- a/src/interface/blas3/backend/default_cpu.hpp +++ b/src/interface/blas3/backend/default_cpu.hpp @@ -69,10 +69,10 @@ _gemm(sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K, _b, _ldb, _strideb, _beta, _c, _ldc, _stridec, batch_size, _dependencies); #else - if (_M <= 128 && _N <= 128 && _K <= 128 && !s_a && !s_b) { + if (_M <= 128 && _N <= 128 && _K <= 256 && !s_a && !s_b) { return blas::Gemm_Launcher< - container_0_t, container_1_t, container_2_t, 64, false, false, false, - 64, Tile<2, 2, 8, 8>, _t_a, _t_b, s_a, s_b, + container_0_t, container_1_t, container_2_t, 128, false, false, false, + 64, Tile<2, 2, 2, 2>, _t_a, _t_b, s_a, s_b, static_cast(gemm_memory_t::no_local), static_cast(gemm_algorithm_t::standard), static_cast(gemm_vectorization_t::full), is_beta_zero, 2, @@ -80,10 +80,10 @@ _gemm(sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K, template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda, _stridea, _b, _ldb, _strideb, _beta, _c, _ldc, _stridec, batch_size, _dependencies); - } else if (!s_a && !s_b) { + } else if ((_M * _N) >= 524288 && !s_a && !s_b) { return blas::Gemm_Launcher< - container_0_t, container_1_t, container_2_t, 64, false, false, false, - 64, Tile<8, 8, 8, 8>, _t_a, _t_b, s_a, s_b, + container_0_t, container_1_t, container_2_t, 128, false, false, false, + 64, Tile<4, 4, 4, 4>, _t_a, _t_b, s_a, s_b, static_cast(gemm_memory_t::no_local), static_cast(gemm_algorithm_t::standard), static_cast(gemm_vectorization_t::partial), is_beta_zero, 1, @@ -91,6 +91,17 @@ _gemm(sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K, template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda, _stridea, _b, _ldb, _strideb, _beta, _c, _ldc, _stridec, batch_size, _dependencies); + } else if (!s_a && !s_b) { + return blas::Gemm_Launcher< + container_0_t, container_1_t, container_2_t, 128, false, false, false, + 64, Tile<4, 4, 8, 8>, _t_a, _t_b, s_a, s_b, + static_cast(gemm_memory_t::no_local), + static_cast(gemm_algorithm_t::standard), + static_cast(gemm_vectorization_t::full), is_beta_zero, 1, + static_cast(gemm_batch_type_t::strided)>:: + template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda, + _stridea, _b, _ldb, _strideb, _beta, _c, _ldc, + _stridec, batch_size, _dependencies); } else { return blas::Gemm_Launcher< container_0_t, container_1_t, container_2_t, 64, false, false, false,