Skip to content

Commit

Permalink
removed unnecessary half data branching/casting in rocBLAS bench
Browse files Browse the repository at this point in the history
  • Loading branch information
OuadiElfarouki committed Feb 15, 2024
1 parent e1804e0 commit e535773
Show file tree
Hide file tree
Showing 8 changed files with 46 additions and 201 deletions.
2 changes: 0 additions & 2 deletions benchmark/cublas/blas3/gemm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -96,8 +96,6 @@ void run(benchmark::State& state, cublasHandle_t* cuda_handle_ptr, int t1,
cublasOperation_t c_t_a = (*t_a == 'n') ? CUBLAS_OP_N : CUBLAS_OP_T;
cublasOperation_t c_t_b = (*t_b == 'n') ? CUBLAS_OP_N : CUBLAS_OP_T;

constexpr const bool is_half = std::is_same_v<scalar_t, cl::sycl::half>;

cuda_scalar_t alpha_cuda = *reinterpret_cast<cuda_scalar_t*>(&alpha);
cuda_scalar_t beta_cuda = *reinterpret_cast<cuda_scalar_t*>(&beta);

Expand Down
5 changes: 1 addition & 4 deletions benchmark/cublas/blas3/gemm_batched.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,9 @@ static inline void cublas_routine(args_t&&... args) {
CUBLAS_CHECK(cublasSgemmBatched(std::forward<args_t>(args)...));
} else if constexpr (std::is_same_v<scalar_t, double>) {
CUBLAS_CHECK(cublasDgemmBatched(std::forward<args_t>(args)...));
}
#ifdef BLAS_ENABLE_HALF
else if constexpr (std::is_same_v<scalar_t, cl::sycl::half>) {
} else if constexpr (std::is_same_v<scalar_t, cl::sycl::half>) {
CUBLAS_CHECK(cublasHgemmBatched(std::forward<args_t>(args)...));
}
#endif
return;
}

Expand Down
5 changes: 1 addition & 4 deletions benchmark/cublas/blas3/gemm_batched_strided.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,9 @@ static inline void cublas_routine(args_t&&... args) {
CUBLAS_CHECK(cublasSgemmStridedBatched(std::forward<args_t>(args)...));
} else if constexpr (std::is_same_v<scalar_t, double>) {
CUBLAS_CHECK(cublasDgemmStridedBatched(std::forward<args_t>(args)...));
}
#ifdef BLAS_ENABLE_HALF
else if constexpr (std::is_same_v<scalar_t, cl::sycl::half>) {
} else if constexpr (std::is_same_v<scalar_t, cl::sycl::half>) {
CUBLAS_CHECK(cublasHgemmStridedBatched(std::forward<args_t>(args)...));
}
#endif
return;
}

Expand Down
61 changes: 10 additions & 51 deletions benchmark/rocblas/blas3/gemm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,9 @@ static inline void rocblas_gemm_f(args_t&&... args) {
CHECK_ROCBLAS_STATUS(rocblas_sgemm(std::forward<args_t>(args)...));
} else if constexpr (std::is_same_v<scalar_t, double>) {
CHECK_ROCBLAS_STATUS(rocblas_dgemm(std::forward<args_t>(args)...));
}
#ifdef BLAS_ENABLE_HALF
else if constexpr (std::is_same_v<scalar_t, cl::sycl::half>) {
} else if constexpr (std::is_same_v<scalar_t, cl::sycl::half>) {
CHECK_ROCBLAS_STATUS(rocblas_hgemm(std::forward<args_t>(args)...));
}
#endif
return;
}

Expand All @@ -59,9 +56,6 @@ template <typename scalar_t>
void run(benchmark::State& state, rocblas_handle& rb_handle, int t_a_i,
int t_b_i, index_t m, index_t k, index_t n, scalar_t alpha,
scalar_t beta, bool* success) {
// scalar_t if scalar_t!=sycl::half, float otherwise
using ref_scalar_t =
typename blas_benchmark::utils::ReferenceType<scalar_t>::type;
// scalar_t if scalar_t!=sycl::half, rocblas_half otherwise
using rocm_scalar_t =
typename blas_benchmark::utils::RocblasType<scalar_t>::type;
Expand Down Expand Up @@ -111,55 +105,20 @@ void run(benchmark::State& state, rocblas_handle& rb_handle, int t_a_i,
blas_benchmark::utils::HIPVector<rocm_scalar_t> c_gpu(
c_size, reinterpret_cast<rocm_scalar_t*>(c.data()));

constexpr const bool is_half = std::is_same_v<scalar_t, cl::sycl::half>;

rocm_scalar_t alpha_rocm, beta_rocm;

if constexpr (is_half) {
#ifdef BLAS_ENABLE_HALF
// sycl::half to rocblas_half
alpha_rocm = *reinterpret_cast<rocm_scalar_t*>(&alpha);
beta_rocm = *reinterpret_cast<rocm_scalar_t*>(&beta);
} else {
#endif
alpha_rocm = alpha;
beta_rocm = beta;
}
rocm_scalar_t alpha_rocm = *reinterpret_cast<rocm_scalar_t*>(&alpha);
rocm_scalar_t beta_rocm = *reinterpret_cast<rocm_scalar_t*>(&beta);

#ifdef BLAS_VERIFY_BENCHMARK
// Reference gemm
std::vector<ref_scalar_t> c_ref(n * m, 0);
std::vector<scalar_t> c_temp(n * m, 0);

if constexpr (is_half) {
// Float-type variables for reference ops
ref_scalar_t alpha_f = alpha;
ref_scalar_t beta_f = beta;
std::vector<ref_scalar_t> a_f(m * k);
std::vector<ref_scalar_t> b_f(k * n);

// sycl::half to float reference type
std::transform(a.begin(), a.end(), a_f.begin(),
[](scalar_t x) { return (static_cast<ref_scalar_t>(x)); });
std::transform(b.begin(), b.end(), b_f.begin(),
[](scalar_t x) { return (static_cast<ref_scalar_t>(x)); });

reference_blas::gemm(t_a_str, t_b_str, m, n, k, alpha_f, a_f.data(), lda,
b_f.data(), ldb, beta_f, c_ref.data(), ldc);
std::vector<scalar_t> c_ref = c;
reference_blas::gemm(t_a_str, t_b_str, m, n, k, alpha, a.data(), lda,
b.data(), ldb, beta, c_ref.data(), ldc);

// Rocblas verification gemm
std::vector<scalar_t> c_temp = c;
{
blas_benchmark::utils::HIPVector<rocm_scalar_t, true> c_temp_gpu(
m * n, reinterpret_cast<rocm_scalar_t*>(c_temp.data()));

rocblas_gemm_f<scalar_t>(rb_handle, trans_a_rb, trans_b_rb, m, n, k,
&alpha_rocm, a_gpu, lda, b_gpu, ldb, &beta_rocm,
c_temp_gpu, ldc);

} else {
reference_blas::gemm(t_a_str, t_b_str, m, n, k, alpha, a.data(), lda,
b.data(), ldb, beta, c_ref.data(), ldc);

blas_benchmark::utils::HIPVector<scalar_t, true> c_temp_gpu(m * n,
c_temp.data());
c_size, reinterpret_cast<rocm_scalar_t*>(c_temp.data()));
rocblas_gemm_f<scalar_t>(rb_handle, trans_a_rb, trans_b_rb, m, n, k,
&alpha_rocm, a_gpu, lda, b_gpu, ldb, &beta_rocm,
c_temp_gpu, ldc);
Expand Down
77 changes: 15 additions & 62 deletions benchmark/rocblas/blas3/gemm_batched.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,9 @@ static inline void rocblas_gemm_batched_f(args_t&&... args) {
CHECK_ROCBLAS_STATUS(rocblas_sgemm_batched(std::forward<args_t>(args)...));
} else if constexpr (std::is_same_v<scalar_t, double>) {
CHECK_ROCBLAS_STATUS(rocblas_dgemm_batched(std::forward<args_t>(args)...));
}
#ifdef BLAS_ENABLE_HALF
else if constexpr (std::is_same_v<scalar_t, cl::sycl::half>) {
} else if constexpr (std::is_same_v<scalar_t, cl::sycl::half>) {
CHECK_ROCBLAS_STATUS(rocblas_hgemm_batched(std::forward<args_t>(args)...));
}
#endif
return;
}

Expand All @@ -59,9 +56,6 @@ template <typename scalar_t>
void run(benchmark::State& state, rocblas_handle& rb_handle, index_t t_a_i,
index_t t_b_i, index_t m, index_t k, index_t n, scalar_t alpha,
scalar_t beta, index_t batch_size, int batch_type_i, bool* success) {
// scalar_t if scalar_t!=sycl::half, float otherwise
using ref_scalar_t =
typename blas_benchmark::utils::ReferenceType<scalar_t>::type;
// scalar_t if scalar_t!=sycl::half, rocblas_half otherwise
using rocm_scalar_t =
typename blas_benchmark::utils::RocblasType<scalar_t>::type;
Expand Down Expand Up @@ -116,68 +110,27 @@ void run(benchmark::State& state, rocblas_handle& rb_handle, index_t t_a_i,
blas_benchmark::utils::HIPVectorBatched<rocm_scalar_t> c_batched_gpu(
c_size, batch_size);

constexpr const bool is_half = std::is_same_v<scalar_t, cl::sycl::half>;

rocm_scalar_t alpha_rocm, beta_rocm;

if constexpr (is_half) {
#ifdef BLAS_ENABLE_HALF
// sycl::half to rocblas_half
alpha_rocm = *reinterpret_cast<rocm_scalar_t*>(&alpha);
beta_rocm = *reinterpret_cast<rocm_scalar_t*>(&beta);
} else {
#endif
alpha_rocm = alpha;
beta_rocm = beta;
}

rocm_scalar_t alpha_rocm = *reinterpret_cast<rocm_scalar_t*>(&alpha);
rocm_scalar_t beta_rocm = *reinterpret_cast<rocm_scalar_t*>(&beta);
#ifdef BLAS_VERIFY_BENCHMARK
std::vector<ref_scalar_t> c_ref(c_size * batch_size, 0);
std::vector<scalar_t> c_temp(c_size * batch_size, 0);

if constexpr (is_half) {
// Float-type variables for reference ops
ref_scalar_t alpha_f = alpha;
ref_scalar_t beta_f = beta;
std::vector<ref_scalar_t> a_f(a_size * batch_size);
std::vector<ref_scalar_t> b_f(b_size * batch_size);

// sycl::half to float reference type
std::transform(a.begin(), a.end(), a_f.begin(),
[](scalar_t x) { return (static_cast<ref_scalar_t>(x)); });
std::transform(b.begin(), b.end(), b_f.begin(),
[](scalar_t x) { return (static_cast<ref_scalar_t>(x)); });

// Reference batched gemm
for (int batch = 0; batch < batch_size; batch++) {
reference_blas::gemm(t_a_str, t_b_str, m, n, k, alpha_f,
a_f.data() + batch * a_size, lda,
b_f.data() + batch * b_size, ldb, beta_f,
c_ref.data() + batch * c_size, ldc);
}
// Reference batched gemm
std::vector<scalar_t> c_ref = c;
for (int batch = 0; batch < batch_size; batch++) {
reference_blas::gemm(t_a_str, t_b_str, m, n, k, alpha,
a.data() + batch * a_size, lda,
b.data() + batch * b_size, ldb, beta,
c_ref.data() + batch * c_size, ldc);
}

// Rocblas verification gemm_batched
// Rocblas verification
// gemm_batched
std::vector<scalar_t> c_temp = c;
{
blas_benchmark::utils::HIPVectorBatched<rocm_scalar_t, true> c_temp_gpu(
c_size, batch_size, reinterpret_cast<rocm_scalar_t*>(c_temp.data()));
rocblas_gemm_batched_f<scalar_t>(
rb_handle, trans_a_rb, trans_b_rb, m, n, k, &alpha_rocm, a_batched_gpu,
lda, b_batched_gpu, ldb, &beta_rocm, c_temp_gpu, ldc, batch_size);

} else {
// Reference batched gemm
for (int batch = 0; batch < batch_size; batch++) {
reference_blas::gemm(t_a_str, t_b_str, m, n, k, alpha,
a.data() + batch * a_size, lda,
b.data() + batch * b_size, ldb, beta,
c_ref.data() + batch * c_size, ldc);
}

// Rocblas verification gemm_batched
blas_benchmark::utils::HIPVectorBatched<scalar_t, true> c_temp_gpu(
c_size, batch_size, c_temp.data());
rocblas_gemm_batched_f<scalar_t>(
rb_handle, trans_a_rb, trans_b_rb, m, n, k, &alpha_rocm, a_batched_gpu,
lda, b_batched_gpu, ldb, &beta_rocm, c_temp_gpu, ldc, batch_size);
}

std::ostringstream err_stream;
Expand Down
77 changes: 14 additions & 63 deletions benchmark/rocblas/blas3/gemm_batched_strided.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,13 +36,10 @@ static inline void rocblas_gemm_strided_batched(args_t&&... args) {
} else if constexpr (std::is_same_v<scalar_t, double>) {
CHECK_ROCBLAS_STATUS(
rocblas_dgemm_strided_batched(std::forward<args_t>(args)...));
}
#ifdef BLAS_ENABLE_HALF
else if constexpr (std::is_same_v<scalar_t, cl::sycl::half>) {
} else if constexpr (std::is_same_v<scalar_t, cl::sycl::half>) {
CHECK_ROCBLAS_STATUS(
rocblas_hgemm_strided_batched(std::forward<args_t>(args)...));
}
#endif
return;
}

Expand All @@ -65,9 +62,6 @@ void run(benchmark::State& state, rocblas_handle& rb_handle, int t_a_i,
int t_b_i, index_t m, index_t k, index_t n, scalar_t alpha,
scalar_t beta, index_t batch_size, index_t stride_a_mul,
index_t stride_b_mul, index_t stride_c_mul, bool* success) {
// scalar_t if scalar_t!=sycl::half, float otherwise
using ref_scalar_t =
typename blas_benchmark::utils::ReferenceType<scalar_t>::type;
// scalar_t if scalar_t!=sycl::half, rocblas_half otherwise
using rocm_scalar_t =
typename blas_benchmark::utils::RocblasType<scalar_t>::type;
Expand Down Expand Up @@ -131,71 +125,28 @@ void run(benchmark::State& state, rocblas_handle& rb_handle, int t_a_i,
blas_benchmark::utils::HIPVectorBatchedStrided<rocm_scalar_t> c_batched_gpu(
c_size, batch_size, stride_c, reinterpret_cast<rocm_scalar_t*>(c.data()));

constexpr const bool is_half = std::is_same_v<scalar_t, cl::sycl::half>;

rocm_scalar_t alpha_rocm, beta_rocm;

if constexpr (is_half) {
#ifdef BLAS_ENABLE_HALF
// sycl::half to rocblas_half
alpha_rocm = *reinterpret_cast<rocm_scalar_t*>(&alpha);
beta_rocm = *reinterpret_cast<rocm_scalar_t*>(&beta);
} else {
#endif
alpha_rocm = alpha;
beta_rocm = beta;
}

rocm_scalar_t alpha_rocm = *reinterpret_cast<rocm_scalar_t*>(&alpha);
rocm_scalar_t beta_rocm = *reinterpret_cast<rocm_scalar_t*>(&beta);
#ifdef BLAS_VERIFY_BENCHMARK
std::vector<ref_scalar_t> c_ref(size_c_batch, 0);
std::vector<scalar_t> c_temp(size_c_batch, 0);

if constexpr (is_half) {
// Float-type variables for reference ops
ref_scalar_t alpha_f = alpha;
ref_scalar_t beta_f = beta;
std::vector<ref_scalar_t> a_f(size_a_batch);
std::vector<ref_scalar_t> b_f(size_b_batch);

// sycl::half to float reference type
std::transform(a.begin(), a.end(), a_f.begin(),
[](scalar_t x) { return (static_cast<ref_scalar_t>(x)); });
std::transform(b.begin(), b.end(), b_f.begin(),
[](scalar_t x) { return (static_cast<ref_scalar_t>(x)); });

// Reference batched gemm
for (int batch = 0; batch < batch_size; batch++) {
reference_blas::gemm(t_a_str, t_b_str, m, n, k, alpha_f,
a_f.data() + batch * stride_a, lda,
b_f.data() + batch * stride_b, ldb, beta_f,
c_ref.data() + batch * stride_c, ldc);
}
// Reference gemm batched strided (strided loop of gemm)
std::vector<scalar_t> c_ref = c;
for (int batch = 0; batch < batch_size; batch++) {
reference_blas::gemm(t_a_str, t_b_str, m, n, k, alpha,
a.data() + batch * stride_a, lda,
b.data() + batch * stride_b, ldb, beta,
c_ref.data() + batch * stride_c, ldc);
}

// Rocblas verification gemm_batched_strided
// Rocblas verification gemm_batched_strided
std::vector<scalar_t> c_temp = c;
{
blas_benchmark::utils::HIPVectorBatchedStrided<rocm_scalar_t, true>
c_temp_gpu(c_size, batch_size, stride_c,
reinterpret_cast<rocm_scalar_t*>(c_temp.data()));
rocblas_gemm_strided_batched<scalar_t>(
rb_handle, trans_a_rb, trans_b_rb, m, n, k, &alpha_rocm, a_batched_gpu,
lda, stride_a, b_batched_gpu, ldb, stride_b, &beta_rocm, c_temp_gpu,
ldc, stride_c, batch_size);

} else {
// Reference batched gemm
for (int batch = 0; batch < batch_size; batch++) {
reference_blas::gemm(t_a_str, t_b_str, m, n, k, alpha,
a.data() + batch * stride_a, lda,
b.data() + batch * stride_b, ldb, beta,
c_ref.data() + batch * stride_c, ldc);
}

// Rocblas verification gemm_batched_strided
blas_benchmark::utils::HIPVectorBatchedStrided<scalar_t, true> c_temp_gpu(
c_size, batch_size, stride_c, c_temp.data());
rocblas_gemm_strided_batched<scalar_t>(
rb_handle, trans_a_rb, trans_b_rb, m, n, k, &alpha_rocm, a_batched_gpu,
lda, stride_a, b_batched_gpu, ldb, stride_b, &beta_rocm, c_temp_gpu,
ldc, stride_c, batch_size);
}

std::ostringstream err_stream;
Expand Down
17 changes: 3 additions & 14 deletions benchmark/rocblas/utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -375,31 +375,20 @@ static inline std::tuple<double, double> timef_hip(function_t func,
}

/**
* Reference type of the underlying tests data aimed to match the reference
* library in tests/benchmarks and random number generator APIs.
* Reference type of the underlying benchmark data aimed to match the
* rocm/rocBLAS scalar types.
*/
template <typename T, typename Enable = void>
struct RocblasType {
using type = T;
};

template <typename T, typename Enable = void>
struct ReferenceType {
using type = T;
};

#ifdef BLAS_ENABLE_HALF
// When T is sycl::half, use float as type for reference BLAS implementations.
// When T is sycl::half, use rocBLAS's rocblas_half as type.
template <typename T>
struct RocblasType<T, std::enable_if_t<std::is_same_v<T, cl::sycl::half>>> {
using type = rocblas_half;
};

template <typename T>
struct ReferenceType<T, std::enable_if_t<std::is_same_v<T, cl::sycl::half>>> {
using type = float;
};

#endif
} // namespace utils
} // namespace blas_benchmark
Expand Down
3 changes: 2 additions & 1 deletion test/blas_test.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,8 @@ static inline scalar_t random_scalar(scalar_t rangeMin, scalar_t rangeMax) {
static std::random_device rd;
static std::default_random_engine gen(rd());
using random_scalar_t =
std::conditional_t<std::is_same_v<scalar_t, cl::sycl::half>, float, scalar_t>;
std::conditional_t<std::is_same_v<scalar_t, cl::sycl::half>, float,
scalar_t>;
std::uniform_real_distribution<random_scalar_t> dis(rangeMin, rangeMax);
return dis(gen);
}
Expand Down

0 comments on commit e535773

Please sign in to comment.