From a09aba5107ecee6d4e77397153b1354c89bdd34f Mon Sep 17 00:00:00 2001 From: Ouadie EL FAROUKI Date: Thu, 15 Feb 2024 18:28:22 +0000 Subject: [PATCH] further cleaning & simplifications --- benchmark/cublas/blas3/gemm_batched.cpp | 2 -- benchmark/cublas/blas3/gemm_batched_strided.cpp | 2 -- include/blas_meta.h | 3 +-- src/interface/blas3/backend/amd_gpu.hpp | 5 ----- src/interface/blas3/backend/intel_gpu.hpp | 2 -- src/interface/blas3/backend/nvidia_gpu.hpp | 2 -- src/interface/gemm_interface.hpp | 15 ++++----------- src/operations/blas3/gemm_common.hpp | 9 --------- 8 files changed, 5 insertions(+), 35 deletions(-) diff --git a/benchmark/cublas/blas3/gemm_batched.cpp b/benchmark/cublas/blas3/gemm_batched.cpp index e84f1956b..d1a4e3ae2 100644 --- a/benchmark/cublas/blas3/gemm_batched.cpp +++ b/benchmark/cublas/blas3/gemm_batched.cpp @@ -100,8 +100,6 @@ void run(benchmark::State& state, cublasHandle_t* cuda_handle_ptr, index_t t1, cublasOperation_t c_t_a = (*t_a == 'n') ? CUBLAS_OP_N : CUBLAS_OP_T; cublasOperation_t c_t_b = (*t_b == 'n') ? CUBLAS_OP_N : CUBLAS_OP_T; - constexpr const bool is_half = std::is_same_v; - cuda_scalar_t alpha_cuda = *reinterpret_cast(&alpha); cuda_scalar_t beta_cuda = *reinterpret_cast(&beta); diff --git a/benchmark/cublas/blas3/gemm_batched_strided.cpp b/benchmark/cublas/blas3/gemm_batched_strided.cpp index 484f10e72..846fd7806 100644 --- a/benchmark/cublas/blas3/gemm_batched_strided.cpp +++ b/benchmark/cublas/blas3/gemm_batched_strided.cpp @@ -119,8 +119,6 @@ void run(benchmark::State& state, cublasHandle_t* cuda_handle_ptr, int t1, cublasOperation_t c_t_a = trA ? CUBLAS_OP_N : CUBLAS_OP_T; cublasOperation_t c_t_b = trB ? CUBLAS_OP_N : CUBLAS_OP_T; - constexpr const bool is_half = std::is_same_v; - cuda_scalar_t alpha_cuda = *reinterpret_cast(&alpha); cuda_scalar_t beta_cuda = *reinterpret_cast(&beta); diff --git a/include/blas_meta.h b/include/blas_meta.h index 31339d36f..2b95e4c6a 100644 --- a/include/blas_meta.h +++ b/include/blas_meta.h @@ -198,11 +198,10 @@ struct is_sycl_scalar : std::false_type {}; template <> struct is_sycl_scalar : std::false_type {}; -#ifdef BLAS_ENABLE_HALF + template struct is_half : std::integral_constant> {}; -#endif #ifdef BLAS_ENABLE_COMPLEX // SYCL Complex type alias diff --git a/src/interface/blas3/backend/amd_gpu.hpp b/src/interface/blas3/backend/amd_gpu.hpp index f2dd4dfba..7b6b802f3 100644 --- a/src/interface/blas3/backend/amd_gpu.hpp +++ b/src/interface/blas3/backend/amd_gpu.hpp @@ -33,14 +33,9 @@ namespace backend { template -#ifndef BLAS_ENABLE_HALF -typename std::enable_if::value, - typename sb_handle_t::event_t>::type -#else typename std::enable_if::value || is_sycl_scalar::value, typename sb_handle_t::event_t>::type -#endif _gemm(sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K, element_t _alpha, container_0_t _a, index_t _lda, index_t _stridea, container_1_t _b, index_t _ldb, index_t _strideb, element_t _beta, diff --git a/src/interface/blas3/backend/intel_gpu.hpp b/src/interface/blas3/backend/intel_gpu.hpp index 4cb06d807..40a3f8e9d 100644 --- a/src/interface/blas3/backend/intel_gpu.hpp +++ b/src/interface/blas3/backend/intel_gpu.hpp @@ -210,7 +210,6 @@ _gemm(sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K, } // Half Configurations -#ifdef BLAS_ENABLE_HALF template @@ -269,7 +268,6 @@ _gemm(sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K, } } } -#endif // Complex Configurations #ifdef BLAS_ENABLE_COMPLEX diff --git a/src/interface/blas3/backend/nvidia_gpu.hpp b/src/interface/blas3/backend/nvidia_gpu.hpp index 9ae246831..1a66c3df5 100644 --- a/src/interface/blas3/backend/nvidia_gpu.hpp +++ b/src/interface/blas3/backend/nvidia_gpu.hpp @@ -176,7 +176,6 @@ _gemm(sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K, } // Half Configurations -#ifdef BLAS_ENABLE_HALF template @@ -232,7 +231,6 @@ _gemm(sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K, } } } -#endif // BLAS_ENABLE_HALF // Complex Configurations #ifdef BLAS_ENABLE_COMPLEX diff --git a/src/interface/gemm_interface.hpp b/src/interface/gemm_interface.hpp index 90b22d320..4d9b54c0a 100644 --- a/src/interface/gemm_interface.hpp +++ b/src/interface/gemm_interface.hpp @@ -48,21 +48,14 @@ namespace blas { */ namespace internal { -// Check whether value is zero (complex & float/double) +// Check whether value is zero (complex & half/float/double) template -inline typename std::enable_if::value, bool>::type isZero( - const T& value) { +inline typename std::enable_if::value || is_half::value, + bool>::type +isZero(const T& value) { return (value == static_cast(0)); } -#ifdef BLAS_ENABLE_HALF -template -inline typename std::enable_if::value, bool>::type isZero( - const T& value) { - return (value == static_cast(0)); -} -#endif - #ifdef BLAS_ENABLE_COMPLEX template inline typename std::enable_if::value, bool>::type isZero( diff --git a/src/operations/blas3/gemm_common.hpp b/src/operations/blas3/gemm_common.hpp index 0c0842e5c..a94b2d054 100644 --- a/src/operations/blas3/gemm_common.hpp +++ b/src/operations/blas3/gemm_common.hpp @@ -42,14 +42,6 @@ mul_add(T a, T b, T c, } #endif -#ifndef BLAS_ENABLE_HALF -template -static PORTBLAS_INLINE T -mul_add(T a, T b, T c, - typename std::enable_if::value>::type * = 0) { - return (cl::sycl::mad(a, b, c)); -} -#else template static PORTBLAS_INLINE T mul_add( T a, T b, T c, @@ -57,7 +49,6 @@ static PORTBLAS_INLINE T mul_add( * = 0) { return (cl::sycl::mad(a, b, c)); } -#endif template struct type_string {