diff --git a/sycl/include/sycl/ext/oneapi/bfloat16.hpp b/sycl/include/sycl/ext/oneapi/bfloat16.hpp index 9bb2b659e69b9..55aa1ad1e262a 100644 --- a/sycl/include/sycl/ext/oneapi/bfloat16.hpp +++ b/sycl/include/sycl/ext/oneapi/bfloat16.hpp @@ -76,30 +76,6 @@ template void BF16VecToFloatVec(const bfloat16 src[N], float dst[N]) { #endif } -template void FloatVecToBF16Vec(float src[N], bfloat16 dst[N]) { -#if defined(__SYCL_DEVICE_ONLY__) && (defined(__SPIR__) || defined(__SPIRV__)) - uint16_t *dst_i16 = sycl::bit_cast(dst); - if constexpr (N == 1) - __devicelib_ConvertFToBF16INTELVec1(src, dst_i16); - else if constexpr (N == 2) - __devicelib_ConvertFToBF16INTELVec2(src, dst_i16); - else if constexpr (N == 3) - __devicelib_ConvertFToBF16INTELVec3(src, dst_i16); - else if constexpr (N == 4) - __devicelib_ConvertFToBF16INTELVec4(src, dst_i16); - else if constexpr (N == 8) - __devicelib_ConvertFToBF16INTELVec8(src, dst_i16); - else if constexpr (N == 16) - __devicelib_ConvertFToBF16INTELVec16(src, dst_i16); -#else - for (int i = 0; i < N; ++i) { - // No need to cast as bfloat16 has a assignment op overload that takes - // a float. 
- dst[i] = src[i]; - } -#endif -} - // sycl::vec support namespace bf16 { #ifndef __INTEL_PREVIEW_BREAKING_CHANGES @@ -309,6 +285,30 @@ class bfloat16 { namespace detail { +template void FloatVecToBF16Vec(float src[N], bfloat16 dst[N]) { +#if defined(__SYCL_DEVICE_ONLY__) && (defined(__SPIR__) || defined(__SPIRV__)) + uint16_t *dst_i16 = sycl::bit_cast(dst); + if constexpr (N == 1) + __devicelib_ConvertFToBF16INTELVec1(src, dst_i16); + else if constexpr (N == 2) + __devicelib_ConvertFToBF16INTELVec2(src, dst_i16); + else if constexpr (N == 3) + __devicelib_ConvertFToBF16INTELVec3(src, dst_i16); + else if constexpr (N == 4) + __devicelib_ConvertFToBF16INTELVec4(src, dst_i16); + else if constexpr (N == 8) + __devicelib_ConvertFToBF16INTELVec8(src, dst_i16); + else if constexpr (N == 16) + __devicelib_ConvertFToBF16INTELVec16(src, dst_i16); +#else + for (int i = 0; i < N; ++i) { + // No need to cast as bfloat16 has an assignment op overload that takes + // a float. + dst[i] = src[i]; + } +#endif +} + // Helper function for getting the internal representation of a bfloat16. inline Bfloat16StorageT bfloat16ToBits(const bfloat16 &Value) { return Value.value;