Skip to content
This repository has been archived by the owner on Jan 13, 2025. It is now read-only.

Commit

Permalink
Update tuning parameters to get better performance
Browse files (browse the repository at this point in the history)
  • Loading branch information
s-Nick committed Nov 10, 2023
1 parent daca9e8 commit 2ddf10a
Show file tree
Hide file tree
Showing 4 changed files with 26 additions and 19 deletions.
2 changes: 1 addition & 1 deletion src/interface/extension/backend/amd_gpu.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ typename sb_handle_t::event_t _axpy_batch(
const auto nWG = (_N + local_size - 1) / local_size;
// the limit for _N*batch_size is taken empirically from test on AMDW6800
const index_t global_size =
- (_N * _batch_size >= 327680)
+ (_N * _batch_size >= 163840)
? (_N > (1 << 19)) ? (local_size * nWG) / 4 : local_size * nWG
: local_size * nWG * _batch_size;
return blas::internal::_axpy_batch_impl<256, 32>(
Expand Down
9 changes: 4 additions & 5 deletions src/interface/extension/backend/default_cpu.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -124,11 +124,10 @@ typename sb_handle_t::event_t _axpy_batch(
// local_size taken empirically
constexpr index_t local_size = static_cast<index_t>(256);
const auto nWG = (_N + local_size - 1) / local_size;
- // the limit for _N*batch_size is taken empirically from test on AMDW6800
- const index_t global_size =
- (_N * _batch_size >= 327680)
- ? (_N > (1 << 19)) ? (local_size * nWG) / 4 : local_size * nWG
- : local_size * nWG * _batch_size;
+ // the limit for _N*batch_size is taken empirically from test on i9 CPU
+ const index_t global_size = (_N * _batch_size >= 163840)
+ ? local_size * nWG
+ : local_size * nWG * _batch_size;
return blas::internal::_axpy_batch_impl<256, 32>(
sb_handle, _N, _alpha, _vx, _incx, _stride_x, _vy, _incy, _stride_y,
_batch_size, _dependencies, global_size);
Expand Down
9 changes: 4 additions & 5 deletions src/interface/extension/backend/intel_gpu.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -146,11 +146,10 @@ typename sb_handle_t::event_t _axpy_batch(
constexpr index_t local_size = static_cast<index_t>(256);
const auto nWG = (_N + local_size - 1) / local_size;
// the limit for _N*batch_size is taken empirically from test on intelGPU
- const index_t global_size =
- (_N * _batch_size >= 327680)
- ? (_N > (1 << 19)) ? (local_size * nWG) / 4 : local_size * nWG
- : local_size * nWG * _batch_size;
- return blas::internal::_axpy_batch_impl<256, 32>(
+ const index_t global_size = (_N * _batch_size > 327680)
+ ? local_size * nWG
+ : local_size * nWG * _batch_size;
+ return blas::internal::_axpy_batch_impl<256, 128>(
sb_handle, _N, _alpha, _vx, _incx, _stride_x, _vy, _incy, _stride_y,
_batch_size, _dependencies, global_size);
}
Expand Down
25 changes: 17 additions & 8 deletions src/interface/extension/backend/nvidia_gpu.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -151,14 +151,23 @@ typename sb_handle_t::event_t _axpy_batch(
// local_size taken empirically
constexpr index_t local_size = static_cast<index_t>(256);
const auto nWG = (_N + local_size - 1) / local_size;
- // the limit for _N*batch_size is taken empirically from test on AMDW6800
- const index_t global_size =
- (_N * _batch_size >= 327680)
- ? (_N > (1 << 19)) ? (local_size * nWG) / 4 : local_size * nWG
- : local_size * nWG * _batch_size;
- return blas::internal::_axpy_batch_impl<256, 32>(
- sb_handle, _N, _alpha, _vx, _incx, _stride_x, _vy, _incy, _stride_y,
- _batch_size, _dependencies, global_size);
+ // the limit for _N*batch_size is taken empirically from test on A100
+ if (_N * _batch_size <= 81920) {
+ const index_t global_size = local_size * nWG * _batch_size;
+ return blas::internal::_axpy_batch_impl<256, 32>(
+ sb_handle, _N, _alpha, _vx, _incx, _stride_x, _vy, _incy, _stride_y,
+ _batch_size, _dependencies, global_size);
+ } else if (_N <= (1 << 19)) {
+ const index_t global_size = local_size * nWG;
+ return blas::internal::_axpy_batch_impl<256, 64>(
+ sb_handle, _N, _alpha, _vx, _incx, _stride_x, _vy, _incy, _stride_y,
+ _batch_size, _dependencies, global_size);
+ } else {
+ const index_t global_size = (local_size * nWG);
+ return blas::internal::_axpy_batch_impl<256, 128>(
+ sb_handle, _N, _alpha, _vx, _incx, _stride_x, _vy, _incy, _stride_y,
+ _batch_size, _dependencies, global_size);
+ }
}
} // namespace backend
} // namespace axpy_batch
Expand Down

0 comments on commit 2ddf10a

Please sign in to comment.