Skip to content

Commit

Permalink
further cleaning & simplifications
Browse files Browse the repository at this point in the history
  • Loading branch information
OuadiElfarouki committed Feb 15, 2024
1 parent e535773 commit a09aba5
Show file tree
Hide file tree
Showing 8 changed files with 5 additions and 35 deletions.
2 changes: 0 additions & 2 deletions benchmark/cublas/blas3/gemm_batched.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -100,8 +100,6 @@ void run(benchmark::State& state, cublasHandle_t* cuda_handle_ptr, index_t t1,
cublasOperation_t c_t_a = (*t_a == 'n') ? CUBLAS_OP_N : CUBLAS_OP_T;
cublasOperation_t c_t_b = (*t_b == 'n') ? CUBLAS_OP_N : CUBLAS_OP_T;

constexpr const bool is_half = std::is_same_v<scalar_t, cl::sycl::half>;

cuda_scalar_t alpha_cuda = *reinterpret_cast<cuda_scalar_t*>(&alpha);
cuda_scalar_t beta_cuda = *reinterpret_cast<cuda_scalar_t*>(&beta);

Expand Down
2 changes: 0 additions & 2 deletions benchmark/cublas/blas3/gemm_batched_strided.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -119,8 +119,6 @@ void run(benchmark::State& state, cublasHandle_t* cuda_handle_ptr, int t1,
cublasOperation_t c_t_a = trA ? CUBLAS_OP_N : CUBLAS_OP_T;
cublasOperation_t c_t_b = trB ? CUBLAS_OP_N : CUBLAS_OP_T;

constexpr const bool is_half = std::is_same_v<scalar_t, cl::sycl::half>;

cuda_scalar_t alpha_cuda = *reinterpret_cast<cuda_scalar_t*>(&alpha);
cuda_scalar_t beta_cuda = *reinterpret_cast<cuda_scalar_t*>(&beta);

Expand Down
3 changes: 1 addition & 2 deletions include/blas_meta.h
Original file line number Diff line number Diff line change
Expand Up @@ -198,11 +198,10 @@ struct is_sycl_scalar<float *> : std::false_type {};

template <>
struct is_sycl_scalar<double *> : std::false_type {};
#ifdef BLAS_ENABLE_HALF

// Type trait reporting whether `type` is exactly cl::sycl::half.
// It derives from std::integral_constant<bool, ...>, so the answer is exposed
// through the usual static `value` member; std::is_same_v performs an exact
// type comparison, so any other type yields false.
template <class type>
struct is_half
: std::integral_constant<bool, std::is_same_v<type, cl::sycl::half>> {};
#endif

#ifdef BLAS_ENABLE_COMPLEX
// SYCL Complex type alias
Expand Down
5 changes: 0 additions & 5 deletions src/interface/blas3/backend/amd_gpu.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,14 +33,9 @@ namespace backend {
template <bool _t_a, bool _t_b, bool s_a, bool s_b, bool is_beta_zero,
typename sb_handle_t, typename container_0_t, typename container_1_t,
typename container_2_t, typename element_t, typename index_t>
#ifndef BLAS_ENABLE_HALF
typename std::enable_if<is_sycl_scalar<element_t>::value,
typename sb_handle_t::event_t>::type
#else
typename std::enable_if<is_half<element_t>::value ||
is_sycl_scalar<element_t>::value,
typename sb_handle_t::event_t>::type
#endif
_gemm(sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K,
element_t _alpha, container_0_t _a, index_t _lda, index_t _stridea,
container_1_t _b, index_t _ldb, index_t _strideb, element_t _beta,
Expand Down
2 changes: 0 additions & 2 deletions src/interface/blas3/backend/intel_gpu.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,6 @@ _gemm(sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K,
}

// Half Configurations
#ifdef BLAS_ENABLE_HALF
template <bool _t_a, bool _t_b, bool s_a, bool s_b, bool is_beta_zero,
typename sb_handle_t, typename container_0_t, typename container_1_t,
typename container_2_t, typename element_t, typename index_t>
Expand Down Expand Up @@ -269,7 +268,6 @@ _gemm(sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K,
}
}
}
#endif

// Complex Configurations
#ifdef BLAS_ENABLE_COMPLEX
Expand Down
2 changes: 0 additions & 2 deletions src/interface/blas3/backend/nvidia_gpu.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,6 @@ _gemm(sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K,
}

// Half Configurations
#ifdef BLAS_ENABLE_HALF
template <bool _t_a, bool _t_b, bool s_a, bool s_b, bool is_beta_zero,
typename sb_handle_t, typename container_0_t, typename container_1_t,
typename container_2_t, typename element_t, typename index_t>
Expand Down Expand Up @@ -232,7 +231,6 @@ _gemm(sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K,
}
}
}
#endif // BLAS_ENABLE_HALF

// Complex Configurations
#ifdef BLAS_ENABLE_COMPLEX
Expand Down
15 changes: 4 additions & 11 deletions src/interface/gemm_interface.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,21 +48,14 @@ namespace blas {
*/
namespace internal {

// Check whether value is zero (complex & float/double)
// Check whether value is zero (complex & half/float/double)
template <typename T>
inline typename std::enable_if<is_sycl_scalar<T>::value, bool>::type isZero(
const T& value) {
inline typename std::enable_if<is_sycl_scalar<T>::value || is_half<T>::value,
bool>::type
isZero(const T& value) {
return (value == static_cast<T>(0));
}

#ifdef BLAS_ENABLE_HALF
// isZero overload for half precision: std::enable_if keeps it in the overload
// set only when is_half<T>::value is true.
// Returns true when `value` compares equal to 0 converted to T.
template <typename T>
inline typename std::enable_if<is_half<T>::value, bool>::type isZero(
const T& value) {
return (value == static_cast<T>(0));
}
#endif

#ifdef BLAS_ENABLE_COMPLEX
template <typename T>
inline typename std::enable_if<is_complex_sycl<T>::value, bool>::type isZero(
Expand Down
9 changes: 0 additions & 9 deletions src/operations/blas3/gemm_common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,22 +42,13 @@ mul_add(T a, T b, T c,
}
#endif

#ifndef BLAS_ENABLE_HALF
// mul_add overload for types for which is_sycl_scalar<T>::value is true.
// The trailing unnamed pointer parameter is never read; it defaults to 0 and
// exists only to carry the std::enable_if constraint that removes this
// overload for other types.
// Returns the result of cl::sycl::mad(a, b, c).
template <typename T>
static PORTBLAS_INLINE T
mul_add(T a, T b, T c,
typename std::enable_if<is_sycl_scalar<T>::value>::type * = 0) {
return (cl::sycl::mad(a, b, c));
}
#else
// mul_add overload enabled when either is_half<T>::value or
// is_sycl_scalar<T>::value is true.
// The trailing unnamed pointer parameter is never read; it defaults to 0 and
// exists only to carry the std::enable_if constraint that removes this
// overload for other types.
// Returns the result of cl::sycl::mad(a, b, c).
template <typename T>
static PORTBLAS_INLINE T mul_add(
T a, T b, T c,
typename std::enable_if<is_half<T>::value || is_sycl_scalar<T>::value>::type
* = 0) {
return (cl::sycl::mad(a, b, c));
}
#endif

template <typename T>
struct type_string {
Expand Down

0 comments on commit a09aba5

Please sign in to comment.