Skip to content

Commit

Permalink
further simplifications
Browse files Browse the repository at this point in the history
  • Loading branch information
OuadiElfarouki committed Feb 16, 2024
1 parent b8a5be3 commit af93a62
Show file tree
Hide file tree
Showing 8 changed files with 14 additions and 14 deletions.
2 changes: 1 addition & 1 deletion cmake/CmakeFunctionHelper.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -315,7 +315,7 @@ function(add_gemm_configuration
cpp_type(cpp_data ${data})
foreach(symm_a ${boolean_list})
foreach(symm_b ${boolean_list})
if ((${data} MATCHES "complex") AND (symm_a OR symm_b))
if ((${data} MATCHES "half") AND (symm_a OR symm_b))
continue()
endif()
if (symm_a AND symm_b)
Expand Down
2 changes: 1 addition & 1 deletion include/blas_meta.h
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ struct is_sycl_scalar
std::false_type>::type {};

template <>
struct is_sycl_scalar<cl::sycl::half> : std::false_type {};
struct is_sycl_scalar<cl::sycl::half> : std::true_type {};

template <>
struct is_sycl_scalar<float *> : std::false_type {};
Expand Down
3 changes: 1 addition & 2 deletions src/interface/blas3/backend/amd_gpu.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,7 @@ namespace backend {
template <bool _t_a, bool _t_b, bool s_a, bool s_b, bool is_beta_zero,
typename sb_handle_t, typename container_0_t, typename container_1_t,
typename container_2_t, typename element_t, typename index_t>
typename std::enable_if<is_half<element_t>::value ||
is_sycl_scalar<element_t>::value,
typename std::enable_if<is_sycl_scalar<element_t>::value,
typename sb_handle_t::event_t>::type
_gemm(sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K,
element_t _alpha, container_0_t _a, index_t _lda, index_t _stridea,
Expand Down
3 changes: 2 additions & 1 deletion src/interface/blas3/backend/default_cpu.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,8 @@ namespace backend {
template <bool _t_a, bool _t_b, bool s_a, bool s_b, bool is_beta_zero,
typename sb_handle_t, typename container_0_t, typename container_1_t,
typename container_2_t, typename element_t, typename index_t>
typename std::enable_if<is_sycl_scalar<element_t>::value,
typename std::enable_if<is_sycl_scalar<element_t>::value &&
!is_half<element_t>::value,
typename sb_handle_t::event_t>::type
_gemm(sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K,
element_t _alpha, container_0_t _a, index_t _lda, index_t _stridea,
Expand Down
3 changes: 2 additions & 1 deletion src/interface/blas3/backend/intel_gpu.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@ namespace backend {
template <bool _t_a, bool _t_b, bool s_a, bool s_b, bool is_beta_zero,
typename sb_handle_t, typename container_0_t, typename container_1_t,
typename container_2_t, typename element_t, typename index_t>
typename std::enable_if<is_sycl_scalar<element_t>::value,
typename std::enable_if<is_sycl_scalar<element_t>::value &&
!is_half<element_t>::value,
typename sb_handle_t::event_t>::type
_gemm(sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K,
element_t _alpha, container_0_t _a, index_t _lda, index_t _stridea,
Expand Down
3 changes: 2 additions & 1 deletion src/interface/blas3/backend/nvidia_gpu.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,8 @@ namespace backend {
template <bool _t_a, bool _t_b, bool s_a, bool s_b, bool is_beta_zero,
typename sb_handle_t, typename container_0_t, typename container_1_t,
typename container_2_t, typename element_t, typename index_t>
typename std::enable_if<is_sycl_scalar<element_t>::value,
typename std::enable_if<is_sycl_scalar<element_t>::value &&
!is_half<element_t>::value,
typename sb_handle_t::event_t>::type
_gemm(sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K,
element_t _alpha, container_0_t _a, index_t _lda, index_t _stridea,
Expand Down
5 changes: 2 additions & 3 deletions src/interface/gemm_interface.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,8 @@ namespace internal {

// Check whether value is zero (complex & half/float/double)
template <typename T>
inline typename std::enable_if<is_sycl_scalar<T>::value || is_half<T>::value,
bool>::type
isZero(const T& value) {
inline typename std::enable_if<is_sycl_scalar<T>::value, bool>::type isZero(
const T& value) {
return (value == static_cast<T>(0));
}

Expand Down
7 changes: 3 additions & 4 deletions src/operations/blas3/gemm_common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,10 +43,9 @@ mul_add(T a, T b, T c,
#endif

template <typename T>
static PORTBLAS_INLINE T mul_add(
T a, T b, T c,
typename std::enable_if<is_half<T>::value || is_sycl_scalar<T>::value>::type
* = 0) {
static PORTBLAS_INLINE T
mul_add(T a, T b, T c,
typename std::enable_if<is_sycl_scalar<T>::value>::type * = 0) {
return (cl::sycl::mad(a, b, c));
}

Expand Down

0 comments on commit af93a62

Please sign in to comment.