diff --git a/src/interface/extension/backend/amd_gpu.hpp b/src/interface/extension/backend/amd_gpu.hpp index 63ae8725d..f969f77a1 100644 --- a/src/interface/extension/backend/amd_gpu.hpp +++ b/src/interface/extension/backend/amd_gpu.hpp @@ -153,7 +153,7 @@ typename sb_handle_t::event_t _axpy_batch( const auto nWG = (_N + local_size - 1) / local_size; // the limit for _N*batch_size is taken empirically from test on AMDW6800 const index_t global_size = - (_N * _batch_size >= 327680) + (_N * _batch_size >= 163840) ? (_N > (1 << 19)) ? (local_size * nWG) / 4 : local_size * nWG : local_size * nWG * _batch_size; return blas::internal::_axpy_batch_impl<256, 32>( diff --git a/src/interface/extension/backend/default_cpu.hpp b/src/interface/extension/backend/default_cpu.hpp index 3e11f4587..d8a2f6c24 100644 --- a/src/interface/extension/backend/default_cpu.hpp +++ b/src/interface/extension/backend/default_cpu.hpp @@ -124,11 +124,10 @@ typename sb_handle_t::event_t _axpy_batch( // local_size taken empirically constexpr index_t local_size = static_cast(256); const auto nWG = (_N + local_size - 1) / local_size; - // the limit for _N*batch_size is taken empirically from test on AMDW6800 - const index_t global_size = - (_N * _batch_size >= 327680) - ? (_N > (1 << 19)) ? (local_size * nWG) / 4 : local_size * nWG - : local_size * nWG * _batch_size; + // the limit for _N*batch_size is taken empirically from test on i9 CPU + const index_t global_size = (_N * _batch_size >= 163840) + ? local_size * nWG + : local_size * nWG * _batch_size; return blas::internal::_axpy_batch_impl<256, 32>( sb_handle, _N, _alpha, _vx, _incx, _stride_x, _vy, _incy, _stride_y, _batch_size, _dependencies, global_size); diff --git a/src/interface/extension/backend/intel_gpu.hpp b/src/interface/extension/backend/intel_gpu.hpp index b2d9a7b8e..90ec53746 100644 --- a/src/interface/extension/backend/intel_gpu.hpp +++ b/src/interface/extension/backend/intel_gpu.hpp @@ -146,11 +146,10 @@ typename sb_handle_t::event_t _axpy_batch( constexpr index_t local_size = static_cast(256); const auto nWG = (_N + local_size - 1) / local_size; // the limit for _N*batch_size is taken empirically from test on intelGPU - const index_t global_size = - (_N * _batch_size >= 327680) - ? (_N > (1 << 19)) ? (local_size * nWG) / 4 : local_size * nWG - : local_size * nWG * _batch_size; - return blas::internal::_axpy_batch_impl<256, 32>( + const index_t global_size = (_N * _batch_size > 327680) + ? local_size * nWG + : local_size * nWG * _batch_size; + return blas::internal::_axpy_batch_impl<256, 128>( sb_handle, _N, _alpha, _vx, _incx, _stride_x, _vy, _incy, _stride_y, _batch_size, _dependencies, global_size); } diff --git a/src/interface/extension/backend/nvidia_gpu.hpp b/src/interface/extension/backend/nvidia_gpu.hpp index 2cc56941d..d378abccf 100644 --- a/src/interface/extension/backend/nvidia_gpu.hpp +++ b/src/interface/extension/backend/nvidia_gpu.hpp @@ -151,14 +151,23 @@ typename sb_handle_t::event_t _axpy_batch( // local_size taken empirically constexpr index_t local_size = static_cast(256); const auto nWG = (_N + local_size - 1) / local_size; - // the limit for _N*batch_size is taken empirically from test on AMDW6800 - const index_t global_size = - (_N * _batch_size >= 327680) - ? (_N > (1 << 19)) ? (local_size * nWG) / 4 : local_size * nWG - : local_size * nWG * _batch_size; - return blas::internal::_axpy_batch_impl<256, 32>( - sb_handle, _N, _alpha, _vx, _incx, _stride_x, _vy, _incy, _stride_y, - _batch_size, _dependencies, global_size); + // the limit for _N*batch_size is taken empirically from test on A100 + if (_N * _batch_size <= 81920) { + const index_t global_size = local_size * nWG * _batch_size; + return blas::internal::_axpy_batch_impl<256, 32>( + sb_handle, _N, _alpha, _vx, _incx, _stride_x, _vy, _incy, _stride_y, + _batch_size, _dependencies, global_size); + } else if (_N <= (1 << 19)) { + const index_t global_size = local_size * nWG; + return blas::internal::_axpy_batch_impl<256, 64>( + sb_handle, _N, _alpha, _vx, _incx, _stride_x, _vy, _incy, _stride_y, + _batch_size, _dependencies, global_size); + } else { + const index_t global_size = (local_size * nWG); + return blas::internal::_axpy_batch_impl<256, 128>( + sb_handle, _N, _alpha, _vx, _incx, _stride_x, _vy, _incy, _stride_y, + _batch_size, _dependencies, global_size); + } } } // namespace backend } // namespace axpy_batch