diff --git a/src/operations/blas3/gemm_load_store_joint_matrix.hpp b/src/operations/blas3/gemm_load_store_joint_matrix.hpp index c8e28f864..876817158 100644 --- a/src/operations/blas3/gemm_load_store_joint_matrix.hpp +++ b/src/operations/blas3/gemm_load_store_joint_matrix.hpp @@ -77,8 +77,8 @@ struct PacketizeJointMatrix { cl::sycl::ext::oneapi::bfloat16, address_t::local_space>, DestPointerType>::value) { - using dtype = cl::sycl::ext::oneapi::bfloat16; - *dest = static_cast(val); + using namespace cl::sycl::ext::oneapi; + *dest = bfloat16(val); } else { using namespace cl::sycl::ext::oneapi::experimental::matrix; *dest = round_to_tf32(val); @@ -119,8 +119,8 @@ struct PacketizeJointMatrix { cl::sycl::ext::oneapi::bfloat16, address_t::local_space>, DestPointerType>::value) { - using dtype = cl::sycl::ext::oneapi::bfloat16; - *dest = static_cast(edge_in_range(i) ? *src : 0); + using namespace cl::sycl::ext::oneapi; + *dest = bfloat16(edge_in_range(i) ? *src : 0.f); } else { using namespace cl::sycl::ext::oneapi::experimental::matrix; *dest = edge_in_range(i) ? round_to_tf32(*src) : 0.f; @@ -150,14 +150,13 @@ struct PacketizeJointMatrix { cl::sycl::ext::oneapi::bfloat16, address_t::local_space>, DestPointerType>::value) { - using dtype = cl::sycl::ext::oneapi::bfloat16; - cl::sycl::vec new_vec; - for (index_t i = 0; i < packet_size; i++) { - reinterpret_cast(&new_vec)[i] = - static_cast(reinterpret_cast(&packet)[i]); + // sycl::vec doesn't accept bfloat16 as a valid input type + // so we need to write the packet elements individually to + // the shared memory. + using namespace cl::sycl::ext::oneapi; + for (index_t i = 0; i < packet_size; i++, dest++) { + *dest = bfloat16(reinterpret_cast(&packet)[i]); } - new_vec.template store( - 0, cl::sycl::multi_ptr(dest)); } else { using namespace cl::sycl::ext::oneapi::experimental::matrix; using dtype = float; diff --git a/test/blas_test.hpp b/test/blas_test.hpp index fef0d60cb..94c3689bf 100644 --- a/test/blas_test.hpp +++ b/test/blas_test.hpp @@ -229,7 +229,8 @@ static inline void fill_trsm_matrix(std::vector &A, size_t k, * @param val input/output float value. * @param nbits number of last bit set to zero. It is set by default to 13 since * this is the difference of the number of bits of the mantissa between floats - * (23) and FP16 / NVIDIA TF32 (10). + * (23) and FP16 / NVIDIA TF32 (10). For bfloat16, this value needs to be set to + * 16 to get correct result. */ static inline void set_to_zero_last_nbits(float &val, int32_t nbits = 13) { int32_t *int_pntr = reinterpret_cast(&val);