diff --git a/CMakeLists.txt b/CMakeLists.txt index 21340430c..b3fd172d4 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -203,6 +203,7 @@ export(EXPORT portblas option(BLAS_ENABLE_TESTING "Whether to enable testing" ON) option(ENABLE_EXPRESSION_TESTS "Whether to build expression tree fusion tests" OFF) +option(ENABLE_JOINTMATRIX_TESTS "Whether to build joint_matrix GEMM tests" OFF) if (INSTALL_HEADER_ONLY AND BLAS_ENABLE_TESTING) message(STATUS "Tests are disabled when installing portBLAS in header only mode") set(BLAS_ENABLE_TESTING OFF) diff --git a/src/operations/blas3/gemm_local_joint_matrix.hpp b/src/operations/blas3/gemm_local_joint_matrix.hpp index 3298203f2..440229fef 100644 --- a/src/operations/blas3/gemm_local_joint_matrix.hpp +++ b/src/operations/blas3/gemm_local_joint_matrix.hpp @@ -158,6 +158,11 @@ class Gemm 1 && jm_row_frags == num_sub_groups) || + (frags_per_sg == 1 && num_jm_frags == num_sub_groups), + "Joint Matrix Row Fragments needs to map 1:1 with total sub_groups."); + //! @brief leading dimension of block of A in local static constexpr index_t ldsa = (trans_a ? cl_elems : block_rows) + diff --git a/test/unittest/CMakeLists.txt b/test/unittest/CMakeLists.txt index 4cdcf7197..14e852af8 100644 --- a/test/unittest/CMakeLists.txt +++ b/test/unittest/CMakeLists.txt @@ -143,3 +143,21 @@ foreach(blas_test ${SYCL_UNITTEST_SRCS}) COMPONENT tests ) endforeach() + +if(${ENABLE_JOINTMATRIX_TESTS}) + if (${DPCPP_SYCL_TARGET} STREQUAL "nvptx64-nvidia-cuda") + string(FIND ${DPCPP_SYCL_ARCH} "_" start_idx) + if(start_idx) + MATH(EXPR start_idx "${start_idx} + 1") + string(SUBSTRING ${DPCPP_SYCL_ARCH} ${start_idx} "2" sm_val) + endif() + + if (${start_idx} AND ${sm_val} GREATER_EQUAL "80") + add_subdirectory(joint_matrix) + else() + message(FATAL_ERROR "Joint Matrix Tests only supported for NVIDIA GPUs with sm_80 arch and above.") + endif() + else() + message(FATAL_ERROR "Joint Matrix Tests only supported for NVIDIA GPUs.") + endif() +endif() diff --git a/test/unittest/joint_matrix/CMakeLists.txt b/test/unittest/joint_matrix/CMakeLists.txt new file mode 100644 index 000000000..efba7944c --- /dev/null +++ b/test/unittest/joint_matrix/CMakeLists.txt @@ -0,0 +1,74 @@ +#/*************************************************************************** +# * +# * @license +# * Copyright (C) Codeplay Software Limited +# * Licensed under the Apache License, Version 2.0 (the "License"); +# * you may not use this file except in compliance with the License. +# * You may obtain a copy of the License at +# * +# * http://www.apache.org/licenses/LICENSE-2.0 +# * +# * For your convenience, a copy of the License has been included in this +# * repository. +# * +# * Unless required by applicable law or agreed to in writing, software +# * distributed under the License is distributed on an "AS IS" BASIS, +# * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# * See the License for the specific language governing permissions and +# * limitations under the License. +# * +# * portBLAS: BLAS implementation using SYCL +# * +# * @filename CMakeLists.txt +# * +# **************************************************************************/ + +set(PORTBLAS_INCLUDE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/../../../include) +set(PORTBLAS_SRC_DIR ${CMAKE_CURRENT_SOURCE_DIR}/../../../src) +list(APPEND CMAKE_MODULE_PATH ${CMAKE_CURRENT_SOURCE_DIR}/cmake) +list(APPEND CMAKE_MODULE_PATH ${CMAKE_CURRENT_SOURCE_DIR}/../../../cmake/Modules) +list(APPEND CMAKE_PREFIX_PATH ${CMAKE_CURRENT_SOURCE_DIR}/../../..) +include(ConfigurePORTBLAS) +include(SYCL) +find_package(PORTBLAS REQUIRED) + +set(PORTBLAS_JOINTMATRIX_TEST ${CMAKE_CURRENT_SOURCE_DIR}) + +include_directories(${PORTBLAS_TEST} ${BLAS_INCLUDE_DIRS}) + +# compiling tests +set(SYCL_UNITTEST_SRCS + ${PORTBLAS_JOINTMATRIX_TEST}/half_half_16_16_16.cpp + ${PORTBLAS_JOINTMATRIX_TEST}/half_half_32_8_16.cpp + ${PORTBLAS_JOINTMATRIX_TEST}/half_half_8_32_16.cpp + ${PORTBLAS_JOINTMATRIX_TEST}/half_float_16_16_16.cpp + ${PORTBLAS_JOINTMATRIX_TEST}/half_float_32_8_16.cpp + ${PORTBLAS_JOINTMATRIX_TEST}/half_float_8_32_16.cpp + ${PORTBLAS_JOINTMATRIX_TEST}/bfloat16_float_16_16_16.cpp + ${PORTBLAS_JOINTMATRIX_TEST}/bfloat16_float_32_8_16.cpp + ${PORTBLAS_JOINTMATRIX_TEST}/bfloat16_float_8_32_16.cpp + ${PORTBLAS_JOINTMATRIX_TEST}/tf32_float_16_16_8.cpp +) + +foreach(blas_test ${SYCL_UNITTEST_SRCS}) + get_filename_component(test_exec ${blas_test} NAME_WE) + add_executable(joint_matrix_${test_exec}_test ../main.cpp ${blas_test}) + target_compile_definitions(joint_matrix_${test_exec}_test PRIVATE -DBLAS_INDEX_T=${BLAS_TEST_INDEX_TYPE}) + target_link_libraries(joint_matrix_${test_exec}_test PRIVATE gtest_main Clara::Clara blas::blas PORTBLAS::PORTBLAS) + target_include_directories(joint_matrix_${test_exec}_test PRIVATE ${SYCL_INCLUDE_DIRS}) + target_include_directories(joint_matrix_${test_exec}_test PRIVATE ${CBLAS_INCLUDE} ${PORTBLAS_COMMON_INCLUDE_DIR}) + target_compile_options(joint_matrix_${test_exec}_test PRIVATE ${DPCPP_FLAGS}) + target_link_options(joint_matrix_${test_exec}_test PRIVATE ${DPCPP_FLAGS}) + + if(TEST_DEVICE) + add_test(NAME joint_matrix_${test_exec}_test COMMAND ${CMAKE_CURRENT_BINARY_DIR}/joint_matrix_${test_exec}_test --device ${TEST_DEVICE} --gtest_output=xml:output/) + else() + add_test(NAME joint_matrix_${test_exec}_test COMMAND ${CMAKE_CURRENT_BINARY_DIR}/joint_matrix_${test_exec}_test --gtest_output=xml:output/) + endif() + message(STATUS "Created google test joint_matrix_${test_exec}_test") + install(TARGETS joint_matrix_${test_exec}_test + RUNTIME + DESTINATION ${CMAKE_INSTALL_BINDIR} + COMPONENT tests + ) +endforeach() diff --git a/test/unittest/joint_matrix/bfloat16_float_16_16_16.cpp b/test/unittest/joint_matrix/bfloat16_float_16_16_16.cpp new file mode 100644 index 000000000..af6af66f0 --- /dev/null +++ b/test/unittest/joint_matrix/bfloat16_float_16_16_16.cpp @@ -0,0 +1,99 @@ +/*************************************************************************** + * + * @license + * Copyright (C) Codeplay Software Limited + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * For your convenience, a copy of the License has been included in this + * repository. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * portBLAS: BLAS implementation using SYCL + * + * @filename bfloat16_float_16_16_16.cpp + * + **************************************************************************/ + +#include "blas_test.hpp" +#include "joint_matrix_common.hpp" + +template +const auto SmallMatricesBfloat16Floatm16n16k16 = ::testing::Combine( + ::testing::Values("bfloat16"), // input type + ::testing::Values("float"), // output type + ::testing::Values(16), // jm_m + ::testing::Values(16), // jm_n + ::testing::Values(16), // jm_k + ::testing::Values("usm", "buf"), // allocation type + ::testing::Values(0, 33), // offset + ::testing::Values(1), // batch + ::testing::Values(11, 16, 32, 63), // m + ::testing::Values(11, 16, 32, 63), // n + ::testing::Values(17, 33, 64), // k + ::testing::Values('n', 't'), // transa + ::testing::Values('n', 't'), // transb + ::testing::Values(1.5), // alpha + ::testing::Values(0, 1.5), // beta + ::testing::Values(1), // lda_mul + ::testing::Values(1), // ldb_mul + ::testing::Values(1), // ldc_mul + ::testing::Values(gemm_batch_type_t::strided) // batch_type +); +GENERATE_JOINTMATRIX_TEST(JointMatrix, SmallMatricesBfloat16Floatm16n16k16); + +template +const auto MediumMatricesBfloat16Floatm16n16k16 = ::testing::Combine( + ::testing::Values("bfloat16"), // input type + ::testing::Values("float"), // output type + ::testing::Values(16), // jm_m + ::testing::Values(16), // jm_n + ::testing::Values(16), // jm_k + ::testing::Values("usm", "buf"), // allocation type + ::testing::Values(0, 33), // offset + ::testing::Values(1), // batch + ::testing::Values(65, 127, 234, 511, 768), // m + ::testing::Values(65, 127, 234, 511, 768), // n + ::testing::Values(65, 127), // k + ::testing::Values('n', 't'), // transa + ::testing::Values('n', 't'), // transb + ::testing::Values(1.5), // alpha + ::testing::Values(0, 1.5), // beta + ::testing::Values(1), // lda_mul + ::testing::Values(1), // ldb_mul + ::testing::Values(1), // ldc_mul + ::testing::Values(gemm_batch_type_t::strided) // batch_type +); +GENERATE_JOINTMATRIX_TEST(JointMatrix, MediumMatricesBfloat16Floatm16n16k16); + +template +const auto LargeMatricesBfloat16Floatm16n16k16 = ::testing::Combine( + ::testing::Values("bfloat16"), // input type + ::testing::Values("float"), // output type + ::testing::Values(16), // jm_m + ::testing::Values(16), // jm_n + ::testing::Values(16), // jm_k + ::testing::Values("usm", "buf"), // allocation type + ::testing::Values(0, 33), // offset + ::testing::Values(1), // batch + ::testing::Values(1024, 1535, 2024), // m + ::testing::Values(1024, 1535, 2024), // n + ::testing::Values(1536, 2049), // k + ::testing::Values('n', 't'), // transa + ::testing::Values('n', 't'), // transb + ::testing::Values(1.5), // alpha + ::testing::Values(0, 1.5), // beta + ::testing::Values(1), // lda_mul + ::testing::Values(1), // ldb_mul + ::testing::Values(1), // ldc_mul + ::testing::Values(gemm_batch_type_t::strided) // batch_type +); +GENERATE_JOINTMATRIX_TEST(JointMatrix, LargeMatricesBfloat16Floatm16n16k16); diff --git a/test/unittest/joint_matrix/bfloat16_float_32_8_16.cpp b/test/unittest/joint_matrix/bfloat16_float_32_8_16.cpp new file mode 100644 index 000000000..d53941431 --- /dev/null +++ b/test/unittest/joint_matrix/bfloat16_float_32_8_16.cpp @@ -0,0 +1,99 @@ +/*************************************************************************** + * + * @license + * Copyright (C) Codeplay Software Limited + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * For your convenience, a copy of the License has been included in this + * repository. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * portBLAS: BLAS implementation using SYCL + * + * @filename bfloat16_float_32_8_16.cpp + * + **************************************************************************/ + +#include "blas_test.hpp" +#include "joint_matrix_common.hpp" + +template +const auto SmallMatricesBfloat16Floatm32n8k16 = ::testing::Combine( + ::testing::Values("bfloat16"), // input type + ::testing::Values("float"), // output type + ::testing::Values(32), // jm_m + ::testing::Values(8), // jm_n + ::testing::Values(16), // jm_k + ::testing::Values("usm", "buf"), // allocation type + ::testing::Values(0, 33), // offset + ::testing::Values(1), // batch + ::testing::Values(11, 16, 32, 63), // m + ::testing::Values(11, 16, 32, 63), // n + ::testing::Values(17, 33, 64), // k + ::testing::Values('n', 't'), // transa + ::testing::Values('n', 't'), // transb + ::testing::Values(1.5), // alpha + ::testing::Values(0, 1.5), // beta + ::testing::Values(1), // lda_mul + ::testing::Values(1), // ldb_mul + ::testing::Values(1), // ldc_mul + ::testing::Values(gemm_batch_type_t::strided) // batch_type +); +GENERATE_JOINTMATRIX_TEST(JointMatrix, SmallMatricesBfloat16Floatm32n8k16); + +template +const auto MediumMatricesBfloat16Floatm32n8k16 = ::testing::Combine( + ::testing::Values("bfloat16"), // input type + ::testing::Values("float"), // output type + ::testing::Values(32), // jm_m + ::testing::Values(8), // jm_n + ::testing::Values(16), // jm_k + ::testing::Values("usm", "buf"), // allocation type + ::testing::Values(0, 33), // offset + ::testing::Values(1), // batch + ::testing::Values(65, 127, 234, 511, 768), // m + ::testing::Values(65, 127, 234, 511, 768), // n + ::testing::Values(65, 127), // k + ::testing::Values('n', 't'), // transa + ::testing::Values('n', 't'), // transb + ::testing::Values(1.5), // alpha + ::testing::Values(0, 1.5), // beta + ::testing::Values(1), // lda_mul + ::testing::Values(1), // ldb_mul + ::testing::Values(1), // ldc_mul + ::testing::Values(gemm_batch_type_t::strided) // batch_type +); +GENERATE_JOINTMATRIX_TEST(JointMatrix, MediumMatricesBfloat16Floatm32n8k16); + +template +const auto LargeMatricesBfloat16Floatm32n8k16 = ::testing::Combine( + ::testing::Values("bfloat16"), // input type + ::testing::Values("float"), // output type + ::testing::Values(32), // jm_m + ::testing::Values(8), // jm_n + ::testing::Values(16), // jm_k + ::testing::Values("usm", "buf"), // allocation type + ::testing::Values(0, 33), // offset + ::testing::Values(1), // batch + ::testing::Values(1024, 1535, 2024), // m + ::testing::Values(1024, 1535, 2024), // n + ::testing::Values(1536, 2049), // k + ::testing::Values('n', 't'), // transa + ::testing::Values('n', 't'), // transb + ::testing::Values(1.5), // alpha + ::testing::Values(0, 1.5), // beta + ::testing::Values(1), // lda_mul + ::testing::Values(1), // ldb_mul + ::testing::Values(1), // ldc_mul + ::testing::Values(gemm_batch_type_t::strided) // batch_type +); +GENERATE_JOINTMATRIX_TEST(JointMatrix, LargeMatricesBfloat16Floatm32n8k16); diff --git a/test/unittest/joint_matrix/bfloat16_float_8_32_16.cpp b/test/unittest/joint_matrix/bfloat16_float_8_32_16.cpp new file mode 100644 index 000000000..8e086f858 --- /dev/null +++ b/test/unittest/joint_matrix/bfloat16_float_8_32_16.cpp @@ -0,0 +1,99 @@ +/*************************************************************************** + * + * @license + * Copyright (C) Codeplay Software Limited + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * For your convenience, a copy of the License has been included in this + * repository. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * portBLAS: BLAS implementation using SYCL + * + * @filename bfloat16_float_8_32_16.cpp + * + **************************************************************************/ + +#include "blas_test.hpp" +#include "joint_matrix_common.hpp" + +template +const auto SmallMatricesBfloat16Floatm8n32k16 = ::testing::Combine( + ::testing::Values("bfloat16"), // input type + ::testing::Values("float"), // output type + ::testing::Values(8), // jm_m + ::testing::Values(32), // jm_n + ::testing::Values(16), // jm_k + ::testing::Values("usm", "buf"), // allocation type + ::testing::Values(0, 33), // offset + ::testing::Values(1), // batch + ::testing::Values(11, 16, 32, 63), // m + ::testing::Values(11, 16, 32, 63), // n + ::testing::Values(17, 33, 64), // k + ::testing::Values('n', 't'), // transa + ::testing::Values('n', 't'), // transb + ::testing::Values(1.5), // alpha + ::testing::Values(0, 1.5), // beta + ::testing::Values(1), // lda_mul + ::testing::Values(1), // ldb_mul + ::testing::Values(1), // ldc_mul + ::testing::Values(gemm_batch_type_t::strided) // batch_type +); +GENERATE_JOINTMATRIX_TEST(JointMatrix, SmallMatricesBfloat16Floatm8n32k16); + +template +const auto MediumMatricesBfloat16Floatm8n32k16 = ::testing::Combine( + ::testing::Values("bfloat16"), // input type + ::testing::Values("float"), // output type + ::testing::Values(8), // jm_m + ::testing::Values(32), // jm_n + ::testing::Values(16), // jm_k + ::testing::Values("usm", "buf"), // allocation type + ::testing::Values(0, 33), // offset + ::testing::Values(1), // batch + ::testing::Values(65, 127, 234, 511, 768), // m + ::testing::Values(65, 127, 234, 511, 768), // n + ::testing::Values(65, 127), // k + ::testing::Values('n', 't'), // transa + ::testing::Values('n', 't'), // transb + ::testing::Values(1.5), // alpha + ::testing::Values(0, 1.5), // beta + ::testing::Values(1), // lda_mul + ::testing::Values(1), // ldb_mul + ::testing::Values(1), // ldc_mul + ::testing::Values(gemm_batch_type_t::strided) // batch_type +); +GENERATE_JOINTMATRIX_TEST(JointMatrix, MediumMatricesBfloat16Floatm8n32k16); + +template +const auto LargeMatricesBfloat16Floatm8n32k16 = ::testing::Combine( + ::testing::Values("bfloat16"), // input type + ::testing::Values("float"), // output type + ::testing::Values(8), // jm_m + ::testing::Values(32), // jm_n + ::testing::Values(16), // jm_k + ::testing::Values("usm", "buf"), // allocation type + ::testing::Values(0, 33), // offset + ::testing::Values(1), // batch + ::testing::Values(1024, 1535, 2024), // m + ::testing::Values(1024, 1535, 2024), // n + ::testing::Values(1536, 2049), // k + ::testing::Values('n', 't'), // transa + ::testing::Values('n', 't'), // transb + ::testing::Values(1.5), // alpha + ::testing::Values(0, 1.5), // beta + ::testing::Values(1), // lda_mul + ::testing::Values(1), // ldb_mul + ::testing::Values(1), // ldc_mul + ::testing::Values(gemm_batch_type_t::strided) // batch_type +); +GENERATE_JOINTMATRIX_TEST(JointMatrix, LargeMatricesBfloat16Floatm8n32k16); diff --git a/test/unittest/joint_matrix/cmake/FindPORTBLAS.cmake b/test/unittest/joint_matrix/cmake/FindPORTBLAS.cmake new file mode 100644 index 000000000..f102b0247 --- /dev/null +++ b/test/unittest/joint_matrix/cmake/FindPORTBLAS.cmake @@ -0,0 +1,65 @@ +#/*************************************************************************** +# * +# * @license +# * Copyright (C) Codeplay Software Limited +# * Licensed under the Apache License, Version 2.0 (the "License"); +# * you may not use this file except in compliance with the License. +# * You may obtain a copy of the License at +# * +# * http://www.apache.org/licenses/LICENSE-2.0 +# * +# * For your convenience, a copy of the License has been included in this +# * repository. +# * +# * Unless required by applicable law or agreed to in writing, software +# * distributed under the License is distributed on an "AS IS" BASIS, +# * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# * See the License for the specific language governing permissions and +# * limitations under the License. +# * +# * portBLAS: BLAS implementation using SYCL +# * +# * @filename FindPORTBLAS.cmake +# * +# **************************************************************************/ + +find_path(PORTBLAS_INCLUDE_DIR + NAMES portblas.h + PATH_SUFFIXES include + HINTS ${PORTBLAS_DIR} + DOC "The PORTBLAS include directory" +) + +find_path(PORTBLAS_SRC_DIR + NAMES portblas.hpp + PATH_SUFFIXES src + HINTS ${PORTBLAS_DIR} + DOC "The PORTBLAS source directory" +) + + +include(FindPackageHandleStandardArgs) +find_package_handle_standard_args(PORTBLAS + FOUND_VAR PORTBLAS_FOUND + REQUIRED_VARS PORTBLAS_INCLUDE_DIR + PORTBLAS_SRC_DIR +) + +mark_as_advanced(PORTBLAS_FOUND + PORTBLAS_SRC_DIR + PORTBLAS_INCLUDE_DIR +) + +if(PORTBLAS_FOUND) + set(PORTBLAS_INCLUDE_DIRS + ${PORTBLAS_INCLUDE_DIR} + ${PORTBLAS_SRC_DIR} + ) +endif() + +if(PORTBLAS_FOUND AND NOT TARGET PORTBLAS::PORTBLAS) + add_library(PORTBLAS::PORTBLAS INTERFACE IMPORTED) + set_target_properties(PORTBLAS::PORTBLAS PROPERTIES + INTERFACE_INCLUDE_DIRECTORIES "${PORTBLAS_INCLUDE_DIRS}" + ) +endif() diff --git a/test/unittest/joint_matrix/half_float_16_16_16.cpp b/test/unittest/joint_matrix/half_float_16_16_16.cpp new file mode 100644 index 000000000..d45fcfbaf --- /dev/null +++ b/test/unittest/joint_matrix/half_float_16_16_16.cpp @@ -0,0 +1,99 @@ +/*************************************************************************** + * + * @license + * Copyright (C) Codeplay Software Limited + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * For your convenience, a copy of the License has been included in this + * repository. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * portBLAS: BLAS implementation using SYCL + * + * @filename half_float_16_16_16.cpp + * + **************************************************************************/ + +#include "blas_test.hpp" +#include "joint_matrix_common.hpp" + +template +const auto SmallMatricesHalfFloat161616 = ::testing::Combine( + ::testing::Values("half"), // input type + ::testing::Values("float"), // output type + ::testing::Values(16), // jm_m + ::testing::Values(16), // jm_n + ::testing::Values(16), // jm_n + ::testing::Values("usm", "buf"), // allocation type + ::testing::Values(0, 33), // offset + ::testing::Values(1), // batch + ::testing::Values(11, 16, 32, 63), // m + ::testing::Values(11, 16, 32, 63), // n + ::testing::Values(17, 33, 64), // k + ::testing::Values('n', 't'), // transa + ::testing::Values('n', 't'), // transb + ::testing::Values(1.5), // alpha + ::testing::Values(0, 1.5), // beta + ::testing::Values(1), // lda_mul + ::testing::Values(1), // ldb_mul + ::testing::Values(1), // ldc_mul + ::testing::Values(gemm_batch_type_t::strided) // batch_type +); +GENERATE_JOINTMATRIX_TEST(JointMatrix, SmallMatricesHalfFloat161616); + +template +const auto MediumMatricesHalfFloat161616 = ::testing::Combine( + ::testing::Values("half"), // input type + ::testing::Values("float"), // output type + ::testing::Values(16), // jm_m + ::testing::Values(16), // jm_n + ::testing::Values(16), // jm_n + ::testing::Values("usm", "buf"), // allocation type + ::testing::Values(0, 33), // offset + ::testing::Values(1), // batch + ::testing::Values(65, 127, 234, 511, 768), // m + ::testing::Values(65, 127, 234, 511, 768), // n + ::testing::Values(65, 127), // k + ::testing::Values('n', 't'), // transa + ::testing::Values('n', 't'), // transb + ::testing::Values(1.5), // alpha + ::testing::Values(0, 1.5), // beta + ::testing::Values(1), // lda_mul + ::testing::Values(1), // ldb_mul + ::testing::Values(1), // ldc_mul + ::testing::Values(gemm_batch_type_t::strided) // batch_type +); +GENERATE_JOINTMATRIX_TEST(JointMatrix, MediumMatricesHalfFloat161616); + +template +const auto LargeMatricesHalfFloat161616 = ::testing::Combine( + ::testing::Values("half"), // input type + ::testing::Values("float"), // output type + ::testing::Values(16), // jm_m + ::testing::Values(16), // jm_n + ::testing::Values(16), // jm_n + ::testing::Values("usm", "buf"), // allocation type + ::testing::Values(0, 33), // offset + ::testing::Values(1), // batch + ::testing::Values(1024, 1535, 2024), // m + ::testing::Values(1024, 1535, 2024), // n + ::testing::Values(1536, 2049), // k + ::testing::Values('n', 't'), // transa + ::testing::Values('n', 't'), // transb + ::testing::Values(1.5), // alpha + ::testing::Values(0, 1.5), // beta + ::testing::Values(1), // lda_mul + ::testing::Values(1), // ldb_mul + ::testing::Values(1), // ldc_mul + ::testing::Values(gemm_batch_type_t::strided) // batch_type +); +GENERATE_JOINTMATRIX_TEST(JointMatrix, LargeMatricesHalfFloat161616); diff --git a/test/unittest/joint_matrix/half_float_32_8_16.cpp b/test/unittest/joint_matrix/half_float_32_8_16.cpp new file mode 100644 index 000000000..db39802c5 --- /dev/null +++ b/test/unittest/joint_matrix/half_float_32_8_16.cpp @@ -0,0 +1,99 @@ +/*************************************************************************** + * + * @license + * Copyright (C) Codeplay Software Limited + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * For your convenience, a copy of the License has been included in this + * repository. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * portBLAS: BLAS implementation using SYCL + * + * @filename half_float_32_8_16.cpp + * + **************************************************************************/ + +#include "blas_test.hpp" +#include "joint_matrix_common.hpp" + +template +const auto SmallMatricesHalfFloatm32n8k16 = ::testing::Combine( + ::testing::Values("half"), // input type + ::testing::Values("float"), // output type + ::testing::Values(32), // jm_m + ::testing::Values(8), // jm_n + ::testing::Values(16), // jm_n + ::testing::Values("usm", "buf"), // allocation type + ::testing::Values(0, 33), // offset + ::testing::Values(1), // batch + ::testing::Values(11, 16, 32, 63), // m + ::testing::Values(11, 16, 32, 63), // n + ::testing::Values(17, 33, 64), // k + ::testing::Values('n', 't'), // transa + ::testing::Values('n', 't'), // transb + ::testing::Values(1.5), // alpha + ::testing::Values(0, 1.5), // beta + ::testing::Values(1), // lda_mul + ::testing::Values(1), // ldb_mul + ::testing::Values(1), // ldc_mul + ::testing::Values(gemm_batch_type_t::strided) // batch_type +); +GENERATE_JOINTMATRIX_TEST(JointMatrix, SmallMatricesHalfFloatm32n8k16); + +template +const auto MediumMatricesHalfFloatm32n8k16 = ::testing::Combine( + ::testing::Values("half"), // input type + ::testing::Values("float"), // output type + ::testing::Values(32), // jm_m + ::testing::Values(8), // jm_n + ::testing::Values(16), // jm_n + ::testing::Values("usm", "buf"), // allocation type + ::testing::Values(0, 33), // offset + ::testing::Values(1), // batch + ::testing::Values(65, 127, 234, 511, 768), // m + ::testing::Values(65, 127, 234, 511, 768), // n + ::testing::Values(65, 127), // k + ::testing::Values('n', 't'), // transa + ::testing::Values('n', 't'), // transb + ::testing::Values(1.5), // alpha + ::testing::Values(0, 1.5), // beta + ::testing::Values(1), // lda_mul + ::testing::Values(1), // ldb_mul + ::testing::Values(1), // ldc_mul + ::testing::Values(gemm_batch_type_t::strided) // batch_type +); +GENERATE_JOINTMATRIX_TEST(JointMatrix, MediumMatricesHalfFloatm32n8k16); + +template +const auto LargeMatricesHalfFloatm32n8k16 = ::testing::Combine( + ::testing::Values("half"), // input type + ::testing::Values("float"), // output type + ::testing::Values(32), // jm_m + ::testing::Values(8), // jm_n + ::testing::Values(16), // jm_n + ::testing::Values("usm", "buf"), // allocation type + ::testing::Values(0, 33), // offset + ::testing::Values(1), // batch + ::testing::Values(1024, 1535, 2024), // m + ::testing::Values(1024, 1535, 2024), // n + ::testing::Values(1536, 2049), // k + ::testing::Values('n', 't'), // transa + ::testing::Values('n', 't'), // transb + ::testing::Values(1.5), // alpha + ::testing::Values(0, 1.5), // beta + ::testing::Values(1), // lda_mul + ::testing::Values(1), // ldb_mul + ::testing::Values(1), // ldc_mul + ::testing::Values(gemm_batch_type_t::strided) // batch_type +); +GENERATE_JOINTMATRIX_TEST(JointMatrix, LargeMatricesHalfFloatm32n8k16); diff --git a/test/unittest/joint_matrix/half_float_8_32_16.cpp b/test/unittest/joint_matrix/half_float_8_32_16.cpp new file mode 100644 index 000000000..61048be18 --- /dev/null +++ b/test/unittest/joint_matrix/half_float_8_32_16.cpp @@ -0,0 +1,99 @@ +/*************************************************************************** + * + * @license + * Copyright (C) Codeplay Software Limited + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * For your convenience, a copy of the License has been included in this + * repository. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * portBLAS: BLAS implementation using SYCL + * + * @filename half_float_8_32_16.cpp + * + **************************************************************************/ + +#include "blas_test.hpp" +#include "joint_matrix_common.hpp" + +template +const auto SmallMatricesHalfFloatm8n32k16 = ::testing::Combine( + ::testing::Values("half"), // input type + ::testing::Values("float"), // output type + ::testing::Values(8), // jm_m + ::testing::Values(32), // jm_n + ::testing::Values(16), // jm_n + ::testing::Values("usm", "buf"), // allocation type + ::testing::Values(0, 33), // offset + ::testing::Values(1), // batch + ::testing::Values(11, 16, 32, 63), // m + ::testing::Values(11, 16, 32, 63), // n + ::testing::Values(17, 33, 64), // k + ::testing::Values('n', 't'), // transa + ::testing::Values('n', 't'), // transb + ::testing::Values(1.5), // alpha + ::testing::Values(0, 1.5), // beta + ::testing::Values(1), // lda_mul + ::testing::Values(1), // ldb_mul + ::testing::Values(1), // ldc_mul + ::testing::Values(gemm_batch_type_t::strided) // batch_type +); +GENERATE_JOINTMATRIX_TEST(JointMatrix, SmallMatricesHalfFloatm8n32k16); + +template +const auto MediumMatricesHalfFloatm8n32k16 = ::testing::Combine( + ::testing::Values("half"), // input type + ::testing::Values("float"), // output type + ::testing::Values(8), // jm_m + ::testing::Values(32), // jm_n + ::testing::Values(16), // jm_n + ::testing::Values("usm", "buf"), // allocation type + ::testing::Values(0, 33), // offset + ::testing::Values(1), // batch + ::testing::Values(65, 127, 234, 511, 768), // m + ::testing::Values(65, 127, 234, 511, 768), // n + ::testing::Values(65, 127), // k + ::testing::Values('n', 't'), // transa + ::testing::Values('n', 't'), // transb + ::testing::Values(1.5), // alpha + ::testing::Values(0, 1.5), // beta + ::testing::Values(1), // lda_mul + ::testing::Values(1), // ldb_mul + ::testing::Values(1), // ldc_mul + ::testing::Values(gemm_batch_type_t::strided) // batch_type +); +GENERATE_JOINTMATRIX_TEST(JointMatrix, MediumMatricesHalfFloatm8n32k16); + +template +const auto LargeMatricesHalfFloatm8n32k16 = ::testing::Combine( + ::testing::Values("half"), // input type + ::testing::Values("float"), // output type + ::testing::Values(8), // jm_m + ::testing::Values(32), // jm_n + ::testing::Values(16), // jm_n + ::testing::Values("usm", "buf"), // allocation type + ::testing::Values(0, 33), // offset + ::testing::Values(1), // batch + ::testing::Values(1024, 1535, 2024), // m + ::testing::Values(1024, 1535, 2024), // n + ::testing::Values(1536, 2049), // k + ::testing::Values('n', 't'), // transa + ::testing::Values('n', 't'), // transb + ::testing::Values(1.5), // alpha + ::testing::Values(0, 1.5), // beta + ::testing::Values(1), // lda_mul + ::testing::Values(1), // ldb_mul + ::testing::Values(1), // ldc_mul + ::testing::Values(gemm_batch_type_t::strided) // batch_type +); +GENERATE_JOINTMATRIX_TEST(JointMatrix, LargeMatricesHalfFloatm8n32k16); diff --git a/test/unittest/joint_matrix/half_half_16_16_16.cpp b/test/unittest/joint_matrix/half_half_16_16_16.cpp new file mode 100644 index 000000000..91de6a527 --- /dev/null +++ b/test/unittest/joint_matrix/half_half_16_16_16.cpp @@ -0,0 +1,99 @@ +/*************************************************************************** + * + * @license + * Copyright (C) Codeplay Software Limited + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * For your convenience, a copy of the License has been included in this + * repository. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * portBLAS: BLAS implementation using SYCL + * + * @filename half_half_16_16_16.cpp + * + **************************************************************************/ + +#include "blas_test.hpp" +#include "joint_matrix_common.hpp" + +template +const auto SmallMatricesHalfHalfm16n16k16 = ::testing::Combine( + ::testing::Values("half"), // input type + ::testing::Values("half"), // output type + ::testing::Values(16), // jm_m + ::testing::Values(16), // jm_n + ::testing::Values(16), // jm_n + ::testing::Values("usm", "buf"), // allocation type + ::testing::Values(0, 33), // offset + ::testing::Values(1), // batch + ::testing::Values(11, 16, 32, 63), // m + ::testing::Values(11, 16, 32, 63), // n + ::testing::Values(17, 33, 64), // k + ::testing::Values('n', 't'), // transa + ::testing::Values('n', 't'), // transb + ::testing::Values(1.5), // alpha + ::testing::Values(0, 1.5), // beta + ::testing::Values(1), // lda_mul + ::testing::Values(1), // ldb_mul + ::testing::Values(1), // ldc_mul + ::testing::Values(gemm_batch_type_t::strided) // batch_type +); +GENERATE_JOINTMATRIX_TEST(JointMatrix, SmallMatricesHalfHalfm16n16k16); + +template +const auto MediumMatricesHalfHalfm16n16k16 = ::testing::Combine( + ::testing::Values("half"), // input type + ::testing::Values("half"), // output type + ::testing::Values(16), // jm_m + ::testing::Values(16), // jm_n + ::testing::Values(16), // jm_n + ::testing::Values("usm", "buf"), // allocation type + ::testing::Values(0, 33), // offset + ::testing::Values(1), // batch + ::testing::Values(65, 127, 234, 511, 768), // m + ::testing::Values(65, 127, 234, 511, 768), // n + ::testing::Values(65, 127), // k + ::testing::Values('n', 't'), // transa + ::testing::Values('n', 't'), // transb + ::testing::Values(1.5), // alpha + ::testing::Values(0, 1.5), // beta + ::testing::Values(1), // lda_mul + ::testing::Values(1), // ldb_mul + ::testing::Values(1), // ldc_mul + ::testing::Values(gemm_batch_type_t::strided) // batch_type +); +GENERATE_JOINTMATRIX_TEST(JointMatrix, MediumMatricesHalfHalfm16n16k16); + +template +const auto LargeMatricesHalfHalfm16n16k16 = ::testing::Combine( + ::testing::Values("half"), // input type + ::testing::Values("half"), // output type + ::testing::Values(16), // jm_m + ::testing::Values(16), // jm_n + ::testing::Values(16), // jm_n + ::testing::Values("usm", "buf"), // allocation type + ::testing::Values(0, 33), // offset + ::testing::Values(1), // batch + ::testing::Values(1024, 1535, 2024), // m + ::testing::Values(1024, 1535, 2024), // n + ::testing::Values(1536, 2049), // k + ::testing::Values('n', 't'), // transa + ::testing::Values('n', 't'), // transb + ::testing::Values(1.5), // alpha + ::testing::Values(0, 1.5), // beta + ::testing::Values(1), // lda_mul + ::testing::Values(1), // ldb_mul + ::testing::Values(1), // ldc_mul + ::testing::Values(gemm_batch_type_t::strided) // batch_type +); +GENERATE_JOINTMATRIX_TEST(JointMatrix, LargeMatricesHalfHalfm16n16k16); diff --git a/test/unittest/joint_matrix/half_half_32_8_16.cpp b/test/unittest/joint_matrix/half_half_32_8_16.cpp new file mode 100644 index 000000000..ad719d8e9 --- /dev/null +++ b/test/unittest/joint_matrix/half_half_32_8_16.cpp @@ -0,0 +1,99 @@ +/*************************************************************************** + * + * @license + * Copyright (C) Codeplay Software Limited + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * For your convenience, a copy of the License has been included in this + * repository. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * portBLAS: BLAS implementation using SYCL + * + * @filename half_half_32_8_16.cpp + * + **************************************************************************/ + +#include "blas_test.hpp" +#include "joint_matrix_common.hpp" + +template +const auto SmallMatricesHalfHalfm32n8k16 = ::testing::Combine( + ::testing::Values("half"), // input type + ::testing::Values("half"), // output type + ::testing::Values(32), // jm_m + ::testing::Values(8), // jm_n + ::testing::Values(16), // jm_n + ::testing::Values("usm", "buf"), // allocation type + ::testing::Values(0, 33), // offset + ::testing::Values(1), // batch + ::testing::Values(11, 16, 32, 63), // m + ::testing::Values(11, 16, 32, 63), // n + ::testing::Values(17, 33, 64), // k + ::testing::Values('n', 't'), // transa + ::testing::Values('n', 't'), // transb + ::testing::Values(1.5), // alpha + ::testing::Values(0, 1.5), // beta + ::testing::Values(1), // lda_mul + ::testing::Values(1), // ldb_mul + ::testing::Values(1), // ldc_mul + ::testing::Values(gemm_batch_type_t::strided) // batch_type +); +GENERATE_JOINTMATRIX_TEST(JointMatrix, SmallMatricesHalfHalfm32n8k16); + +template +const auto MediumMatricesHalfHalfm32n8k16 = ::testing::Combine( + ::testing::Values("half"), // input type + ::testing::Values("half"), // output type + ::testing::Values(32), // jm_m + ::testing::Values(8), // jm_n + ::testing::Values(16), // jm_n + ::testing::Values("usm", "buf"), // allocation type + ::testing::Values(0, 33), // offset + ::testing::Values(1), // batch + ::testing::Values(65, 127, 234, 511, 768), // m + ::testing::Values(65, 127, 234, 511, 768), // n + ::testing::Values(65, 127), // k + ::testing::Values('n', 't'), // transa + ::testing::Values('n', 't'), // transb + ::testing::Values(1.5), // alpha + ::testing::Values(0, 1.5), // beta + ::testing::Values(1), // lda_mul + ::testing::Values(1), // ldb_mul + ::testing::Values(1), // ldc_mul + ::testing::Values(gemm_batch_type_t::strided) // batch_type +); +GENERATE_JOINTMATRIX_TEST(JointMatrix, MediumMatricesHalfHalfm32n8k16); + +template +const auto LargeMatricesHalfHalfm32n8k16 = ::testing::Combine( + ::testing::Values("half"), // input type + ::testing::Values("half"), // output type + ::testing::Values(32), // jm_m + ::testing::Values(8), // jm_n + ::testing::Values(16), // jm_n + ::testing::Values("usm", "buf"), // allocation type + ::testing::Values(0, 33), // offset + ::testing::Values(1), // batch + ::testing::Values(1024, 1535, 2024), // m + ::testing::Values(1024, 1535, 2024), // n + ::testing::Values(1536, 2049), // k + ::testing::Values('n', 't'), // transa + ::testing::Values('n', 't'), // transb + ::testing::Values(1.5), // alpha + ::testing::Values(0, 1.5), // beta + ::testing::Values(1), // lda_mul + ::testing::Values(1), // ldb_mul + ::testing::Values(1), // ldc_mul + ::testing::Values(gemm_batch_type_t::strided) // batch_type +); +GENERATE_JOINTMATRIX_TEST(JointMatrix, LargeMatricesHalfHalfm32n8k16); diff --git a/test/unittest/joint_matrix/half_half_8_32_16.cpp b/test/unittest/joint_matrix/half_half_8_32_16.cpp new file mode 100644 index 000000000..090eab1ee --- /dev/null +++ b/test/unittest/joint_matrix/half_half_8_32_16.cpp @@ -0,0 +1,99 @@ +/*************************************************************************** + * + * @license + * Copyright (C) Codeplay Software Limited + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * For your convenience, a copy of the License has been included in this + * repository. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * portBLAS: BLAS implementation using SYCL + * + * @filename half_half_8_32_16.cpp + * + **************************************************************************/ + +#include "blas_test.hpp" +#include "joint_matrix_common.hpp" + +template +const auto SmallMatricesHalfHalfm8n32k16 = ::testing::Combine( + ::testing::Values("half"), // input type + ::testing::Values("half"), // output type + ::testing::Values(8), // jm_m + ::testing::Values(32), // jm_n + ::testing::Values(16), // jm_n + ::testing::Values("usm", "buf"), // allocation type + ::testing::Values(0, 33), // offset + ::testing::Values(1), // batch + ::testing::Values(11, 16, 32, 63), // m + ::testing::Values(11, 16, 32, 63), // n + ::testing::Values(17, 33, 64), // k + ::testing::Values('n', 't'), // transa + ::testing::Values('n', 't'), // transb + ::testing::Values(1.5), // alpha + ::testing::Values(0, 1.5), // beta + ::testing::Values(1), // lda_mul + ::testing::Values(1), // ldb_mul + ::testing::Values(1), // ldc_mul + ::testing::Values(gemm_batch_type_t::strided) // batch_type +); +GENERATE_JOINTMATRIX_TEST(JointMatrix, SmallMatricesHalfHalfm8n32k16); + +template +const auto MediumMatricesHalfHalfm8n32k16 = ::testing::Combine( + ::testing::Values("half"), // input type + ::testing::Values("half"), // output type + ::testing::Values(8), // jm_m + ::testing::Values(32), // jm_n + ::testing::Values(16), // jm_n + ::testing::Values("usm", "buf"), // allocation type + ::testing::Values(0, 33), // offset + ::testing::Values(1), // batch + ::testing::Values(65, 127, 234, 511, 768), // m + ::testing::Values(65, 127, 234, 511, 768), // n + ::testing::Values(65, 127), // k + ::testing::Values('n', 't'), // transa + ::testing::Values('n', 't'), // transb + ::testing::Values(1.5), // alpha + ::testing::Values(0, 1.5), // beta + ::testing::Values(1), // lda_mul + ::testing::Values(1), // ldb_mul + ::testing::Values(1), // ldc_mul + ::testing::Values(gemm_batch_type_t::strided) // batch_type +); +GENERATE_JOINTMATRIX_TEST(JointMatrix, MediumMatricesHalfHalfm8n32k16); + +template +const auto LargeMatricesHalfHalfm8n32k16 = ::testing::Combine( + ::testing::Values("half"), // input type + ::testing::Values("half"), // output type + ::testing::Values(8), // jm_m + ::testing::Values(32), // jm_n + ::testing::Values(16), // jm_n + ::testing::Values("usm", "buf"), // allocation type + ::testing::Values(0, 33), // offset + ::testing::Values(1), // batch + ::testing::Values(1024, 1535, 2024), // m + ::testing::Values(1024, 1535, 2024), // n + ::testing::Values(1536, 2049), // k + ::testing::Values('n', 't'), // transa + ::testing::Values('n', 't'), // transb + ::testing::Values(1.5), // alpha + ::testing::Values(0, 1.5), // beta + ::testing::Values(1), // lda_mul + ::testing::Values(1), // ldb_mul + ::testing::Values(1), // ldc_mul + ::testing::Values(gemm_batch_type_t::strided) // batch_type +); +GENERATE_JOINTMATRIX_TEST(JointMatrix, LargeMatricesHalfHalfm8n32k16); diff --git a/test/unittest/joint_matrix/joint_matrix_common.hpp b/test/unittest/joint_matrix/joint_matrix_common.hpp new file mode 100644 index 000000000..7998f9dc0 --- /dev/null +++ b/test/unittest/joint_matrix/joint_matrix_common.hpp @@ -0,0 +1,260 @@ +/*************************************************************************** + * + * @license + * Copyright (C) Codeplay Software Limited + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * For your convenience, a copy of the License has been included in this + * repository. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * portBLAS: BLAS implementation using SYCL + * + * @filename joint_matrix_common.hpp + * + **************************************************************************/ + +#include "launch_gemm.hpp" + +template +using joint_matrix_arguments_t = + std::tuple; + +template +inline void verify_gemm(const joint_matrix_arguments_t arguments) { + std::string jm_inType, jm_outType; + index_t jm_m, jm_n, jm_k; + std::string alloc; + index_t offset; + index_t batch; + index_t m; + index_t n; + index_t k; + char transa; + char transb; + scalar_t alpha; + scalar_t beta; + index_t lda_mul; + index_t ldb_mul; + index_t ldc_mul; + gemm_batch_type_t batch_type; + std::tie(jm_inType, jm_outType, jm_m, jm_n, jm_k, alloc, offset, batch, m, n, + k, transa, transb, alpha, beta, lda_mul, ldb_mul, ldc_mul, + batch_type) = arguments; + + assert(batch_type == gemm_batch_type_t::strided); + + const char ta_str[2] = {transa, '\0'}; + const char tb_str[2] = {transb, '\0'}; + + auto q = make_queue(); + blas::SB_Handle sb_handle(q); + + const index_t lda = ((transa != 'n') ? k : m) * lda_mul; + const index_t ldb = ((transb != 'n') ? n : k) * ldb_mul; + const index_t ldc = m * ldc_mul; + + const index_t size_a = m * k * lda_mul; + const index_t size_b = k * n * ldb_mul; + const index_t size_c = m * n * ldc_mul; + + const index_t buffer_size_a = batch * size_a + offset; + const index_t buffer_size_b = batch * size_b + offset; + const index_t buffer_size_c = batch * size_c + offset; + + std::vector a_m(buffer_size_a); + std::vector b_m(buffer_size_b); + std::vector c_m_gpu(buffer_size_c); + + fill_random(a_m); + fill_random(b_m); + fill_random(c_m_gpu); + + index_t nbits = 13; + if (jm_inType == "bfloat16") { + nbits = 16; + } + set_to_zero_last_nbits(a_m, nbits); + set_to_zero_last_nbits(b_m, nbits); + set_to_zero_last_nbits(c_m_gpu, nbits); + set_to_zero_last_nbits(alpha, nbits); + set_to_zero_last_nbits(beta, nbits); + + std::vector c_m_cpu = c_m_gpu; + + // Use system blas to create a reference output + for (int i = 0; i < batch; ++i) { + reference_blas::gemm(ta_str, tb_str, m, n, k, alpha, + a_m.data() + i * size_a + offset, lda, + b_m.data() + i * size_b + offset, ldb, beta, + c_m_cpu.data() + i * size_c + offset, ldc); + } + + auto m_a_gpu = blas::helper::allocate(buffer_size_a, q); + auto m_b_gpu = blas::helper::allocate(buffer_size_b, q); + auto m_c_gpu = blas::helper::allocate(buffer_size_c, q); + + auto copy_a = + blas::helper::copy_to_device(q, a_m.data(), m_a_gpu, buffer_size_a); + auto copy_b = + blas::helper::copy_to_device(q, b_m.data(), m_b_gpu, buffer_size_b); + auto copy_c = + blas::helper::copy_to_device(q, c_m_gpu.data(), m_c_gpu, buffer_size_c); + + // portBLAS GEMM implementation + typename blas::SB_Handle::event_t gemm_event; + if (jm_inType == "half" && jm_outType == "float") { + if (jm_m == 16 && jm_n == 16) { + gemm_event = launch_gemm_with_beta<16, 16, 16, cl::sycl::half, float>( + sb_handle, transa, transb, m, n, k, alpha, m_a_gpu + offset, lda, + size_a, m_b_gpu + offset, ldb, size_b, beta, m_c_gpu + offset, ldc, + size_c, batch, batch_type, {copy_a, copy_b, copy_c}); + } else if (jm_m == 32 && jm_n == 8) { + gemm_event = launch_gemm_with_beta<32, 8, 16, cl::sycl::half, float>( + sb_handle, transa, transb, m, n, k, alpha, m_a_gpu + offset, lda, + size_a, m_b_gpu + offset, ldb, size_b, beta, m_c_gpu + offset, ldc, + size_c, batch, batch_type, {copy_a, copy_b, copy_c}); + } else if (jm_n == 32 && jm_m == 8) { + gemm_event = launch_gemm_with_beta<8, 32, 16, cl::sycl::half, float>( + sb_handle, transa, transb, m, n, k, alpha, m_a_gpu + offset, lda, + size_a, m_b_gpu + offset, ldb, size_b, beta, m_c_gpu + offset, ldc, + size_c, batch, batch_type, {copy_a, copy_b, copy_c}); + } + } else if (jm_inType == "half" && jm_outType == "half") { + if (jm_m == 16 && jm_n == 16) { + gemm_event = + launch_gemm_with_beta<16, 16, 16, cl::sycl::half, cl::sycl::half>( + sb_handle, transa, transb, m, n, k, alpha, m_a_gpu + offset, lda, + size_a, m_b_gpu + offset, ldb, size_b, beta, m_c_gpu + offset, + ldc, size_c, batch, batch_type, {copy_a, copy_b, copy_c}); + } else if (jm_m == 32 && jm_n == 8) { + gemm_event = + launch_gemm_with_beta<32, 8, 16, cl::sycl::half, cl::sycl::half>( + sb_handle, transa, transb, m, n, k, alpha, m_a_gpu + offset, lda, + size_a, m_b_gpu + offset, ldb, size_b, beta, m_c_gpu + offset, + ldc, size_c, batch, batch_type, {copy_a, copy_b, copy_c}); + } else if (jm_n == 32 && jm_m == 8) { + gemm_event = + launch_gemm_with_beta<8, 32, 16, cl::sycl::half, cl::sycl::half>( + sb_handle, transa, transb, m, n, k, alpha, m_a_gpu + offset, lda, + size_a, m_b_gpu + offset, ldb, size_b, beta, m_c_gpu + offset, + ldc, size_c, batch, batch_type, {copy_a, copy_b, copy_c}); + } + } else if (jm_inType == "bfloat16" && jm_outType == "float") { + if (jm_m == 16 && jm_n == 16) { + gemm_event = + launch_gemm_with_beta<16, 16, 16, sycl::ext::oneapi::bfloat16, float>( + sb_handle, transa, transb, m, n, k, alpha, m_a_gpu + offset, lda, + size_a, m_b_gpu + offset, ldb, size_b, beta, m_c_gpu + offset, + ldc, size_c, batch, batch_type, {copy_a, copy_b, copy_c}); + } else if (jm_m == 32 && jm_n == 8) { + gemm_event = + launch_gemm_with_beta<32, 8, 16, sycl::ext::oneapi::bfloat16, float>( + sb_handle, transa, transb, m, n, k, alpha, m_a_gpu + offset, lda, + size_a, m_b_gpu + offset, ldb, size_b, beta, m_c_gpu + offset, + ldc, size_c, batch, batch_type, {copy_a, copy_b, copy_c}); + } else if (jm_n == 32 && jm_m == 8) { + gemm_event = + launch_gemm_with_beta<8, 32, 16, sycl::ext::oneapi::bfloat16, float>( + sb_handle, transa, transb, m, n, k, alpha, m_a_gpu + offset, lda, + size_a, m_b_gpu + offset, ldb, size_b, beta, m_c_gpu + offset, + ldc, size_c, batch, batch_type, {copy_a, copy_b, copy_c}); + } + } else if (jm_inType == "tf32" && jm_outType == "float") { + using namespace sycl::ext::oneapi::experimental::matrix; + gemm_event = launch_gemm_with_beta<16, 16, 8, precision::tf32, float>( + sb_handle, transa, transb, m, n, k, alpha, m_a_gpu + offset, lda, + size_a, m_b_gpu + offset, ldb, size_b, beta, m_c_gpu + offset, ldc, + size_c, batch, batch_type, {copy_a, copy_b, copy_c}); + } + sb_handle.wait(gemm_event); + + auto event = + blas::helper::copy_to_host(q, m_c_gpu, c_m_gpu.data(), buffer_size_c); + sb_handle.wait(event); + + const bool isAlmostEqual = utils::compare_vectors( + c_m_gpu, c_m_cpu, std::cerr, "", + jm_inType == "half" && jm_outType == "half" ? 3 : 1); + ASSERT_TRUE(isAlmostEqual); + + helper::deallocate(m_a_gpu, q); + helper::deallocate(m_b_gpu, q); + helper::deallocate(m_c_gpu, q); +} + +template +inline void verify_gemm(const joint_matrix_arguments_t arguments) { + std::string jm_inType, jm_OutType; + index_t jm_m, jm_n, jm_k; + std::string alloc; + index_t offset; + index_t batch; + index_t m; + index_t n; + index_t k; + char transa; + char transb; + scalar_t alpha; + scalar_t beta; + index_t lda_mul; + index_t ldb_mul; + index_t ldc_mul; + gemm_batch_type_t batch_type; + std::tie(jm_inType, jm_OutType, jm_m, jm_n, jm_k, alloc, offset, batch, m, n, + k, transa, transb, alpha, beta, lda_mul, ldb_mul, ldc_mul, + batch_type) = arguments; + + if (alloc == "usm") { +#ifdef SB_ENABLE_USM + verify_gemm(arguments); +#else + GTEST_SKIP(); +#endif + } else { + verify_gemm(arguments); + } +} + +template <> +inline void dump_arg(std::ostream& ss, + gemm_batch_type_t batch_type) { + ss << (int)batch_type; +} + +template +static std::string generate_name( + const ::testing::TestParamInfo>& info) { + std::string jm_inType, jm_OutType; + int jm_m, jm_n, jm_k; + std::string alloc; + int offset, batch, m, n, k, ldaMul, ldbMul, ldcMul; + char transa, transb; + T alpha, beta; + gemm_batch_type_t batchType; + BLAS_GENERATE_NAME(info.param, jm_inType, jm_OutType, jm_m, jm_n, jm_k, alloc, + offset, batch, m, n, k, transa, transb, alpha, beta, + ldaMul, ldbMul, ldcMul, batchType); +} + +/** Registers Joint Matrix test for all supported data types (only float for + * now) + * @param test_suite Name of the test suite + * @param combination Combinations object + * @see BLAS_REGISTER_TEST_CUSTOM_NAME + */ +#define GENERATE_JOINTMATRIX_TEST(test_suite, combination) \ + BLAS_REGISTER_TEST_FLOAT_CUSTOM_NAME(test_suite, test_suite##combination, \ + verify_gemm, joint_matrix_arguments_t, \ + combination, generate_name); diff --git a/test/unittest/joint_matrix/launch_gemm.hpp b/test/unittest/joint_matrix/launch_gemm.hpp new file mode 100644 index 000000000..afab6a0b3 --- /dev/null +++ b/test/unittest/joint_matrix/launch_gemm.hpp @@ -0,0 +1,247 @@ +/*************************************************************************** + * + * @license + * Copyright (C) Codeplay Software Limited + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * For your convenience, a copy of the License has been included in this + * repository. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * portBLAS: BLAS implementation using SYCL + * + * @filename launch_gemm.hpp + * + **************************************************************************/ + +#include "blas_test.hpp" +#include "interface/gemm_launcher.hpp" +#include "portblas.hpp" +#include + +template +PORTBLAS_ALWAYS_INLINE + typename std::enable_if::type + launch_gemm(sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K, + element_t _alpha, container_0_t _a, index_t _lda, + index_t _stridea, container_1_t _b, index_t _ldb, + index_t _strideb, element_t _beta, container_2_t _c, + index_t _ldc, index_t _stridec, index_t batch_size, + gemm_batch_type_t batch_type, + const typename sb_handle_t::event_t& _dependencies) { + if (_M > 1024 && _N > 1024) { + return blas::Gemm_Launcher< + container_0_t, container_1_t, container_2_t, 256, false, true, true, + 128, + Tile<8, 8, 16, 16, 16, 2, 1, 1, 1, 1, jm_M, jm_N, jm_K, inp_jmT, + out_jmT>, + _t_a, _t_b, s_a, s_b, static_cast(gemm_memory_t::local), + static_cast(gemm_algorithm_t::standard), + static_cast(gemm_vectorization_t::none), is_beta_zero, 1, + static_cast(gemm_batch_type_t::strided), + true>::template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda, + _stridea, _b, _ldb, _strideb, _beta, _c, + _ldc, _stridec, batch_size, _dependencies); + } else if (_M > 64 && _N > 64) { + return blas::Gemm_Launcher< + container_0_t, container_1_t, container_2_t, 128, false, true, true, + 128, + Tile<4, 8, 16, 8, 16, 2, 1, 1, 1, 1, jm_M, jm_N, jm_K, inp_jmT, + out_jmT>, + _t_a, _t_b, s_a, s_b, static_cast(gemm_memory_t::local), + static_cast(gemm_algorithm_t::standard), + static_cast(gemm_vectorization_t::none), is_beta_zero, 1, + static_cast(gemm_batch_type_t::strided), + true>::template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda, + _stridea, _b, _ldb, _strideb, _beta, _c, + _ldc, _stridec, batch_size, _dependencies); + + } else { + return blas::Gemm_Launcher< + container_0_t, container_1_t, container_2_t, 128, false, true, true, + 128, + Tile<2, 4, 16, 8, 16, 2, 1, 1, 1, 1, jm_M, jm_N, jm_K, inp_jmT, + out_jmT>, + _t_a, _t_b, s_a, s_b, static_cast(gemm_memory_t::local), + static_cast(gemm_algorithm_t::standard), + static_cast(gemm_vectorization_t::none), is_beta_zero, 1, + static_cast(gemm_batch_type_t::strided), + true>::template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda, + _stridea, _b, _ldb, _strideb, _beta, _c, + _ldc, _stridec, batch_size, _dependencies); + } +} + +template +PORTBLAS_ALWAYS_INLINE + typename std::enable_if::type + launch_gemm(sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K, + element_t _alpha, container_0_t _a, index_t _lda, + index_t _stridea, container_1_t _b, index_t _ldb, + index_t _strideb, element_t _beta, container_2_t _c, + index_t _ldc, index_t _stridec, index_t batch_size, + gemm_batch_type_t batch_type, + const typename sb_handle_t::event_t& _dependencies) { + if (_M > 1024 && _N > 1024) { + return blas::Gemm_Launcher< + container_0_t, container_1_t, container_2_t, 128, false, true, true, + 128, + Tile<8, 16, 16, 8, 16, 2, 1, 1, 1, 1, jm_M, jm_N, jm_K, inp_jmT, + out_jmT>, + _t_a, _t_b, s_a, s_b, static_cast(gemm_memory_t::local), + static_cast(gemm_algorithm_t::standard), + static_cast(gemm_vectorization_t::none), is_beta_zero, 1, + static_cast(gemm_batch_type_t::strided), + true>::template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda, + _stridea, _b, _ldb, _strideb, _beta, _c, + _ldc, _stridec, batch_size, _dependencies); + } else if (_M > 64 && _N > 64) { + return blas::Gemm_Launcher< + container_0_t, container_1_t, container_2_t, 64, false, true, true, 128, + Tile<8, 8, 8, 8, 16, 2, 1, 1, 1, 1, jm_M, jm_N, jm_K, inp_jmT, out_jmT>, + _t_a, _t_b, s_a, s_b, static_cast(gemm_memory_t::local), + static_cast(gemm_algorithm_t::standard), + static_cast(gemm_vectorization_t::none), is_beta_zero, 1, + static_cast(gemm_batch_type_t::strided), + true>::template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda, + _stridea, _b, _ldb, _strideb, _beta, _c, + _ldc, _stridec, batch_size, _dependencies); + + } else { + return blas::Gemm_Launcher< + container_0_t, container_1_t, container_2_t, 128, false, true, true, + 128, + Tile<2, 4, 16, 8, 16, 2, 1, 1, 1, 1, jm_M, jm_N, jm_K, inp_jmT, + out_jmT>, + _t_a, _t_b, s_a, s_b, static_cast(gemm_memory_t::local), + static_cast(gemm_algorithm_t::standard), + static_cast(gemm_vectorization_t::none), is_beta_zero, 1, + static_cast(gemm_batch_type_t::strided), + true>::template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda, + _stridea, _b, _ldb, _strideb, _beta, _c, + _ldc, _stridec, batch_size, _dependencies); + } +} + +template +PORTBLAS_ALWAYS_INLINE + typename std::enable_if::type + launch_gemm(sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K, + element_t _alpha, container_0_t _a, index_t _lda, + index_t _stridea, container_1_t _b, index_t _ldb, + index_t _strideb, element_t _beta, container_2_t _c, + index_t _ldc, index_t _stridec, index_t batch_size, + gemm_batch_type_t batch_type, + const typename sb_handle_t::event_t& _dependencies) { + if (_M > 1024 && _N > 1024) { + return blas::Gemm_Launcher< + container_0_t, container_1_t, container_2_t, 256, false, true, true, + 128, + Tile<4, 4, 16, 16, 16, 2, 1, 1, 1, 1, jm_M, jm_N, jm_K, inp_jmT, + out_jmT>, + _t_a, _t_b, s_a, s_b, static_cast(gemm_memory_t::local), + static_cast(gemm_algorithm_t::standard), + static_cast(gemm_vectorization_t::none), is_beta_zero, 1, + static_cast(gemm_batch_type_t::strided), + true>::template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda, + _stridea, _b, _ldb, _strideb, _beta, _c, + _ldc, _stridec, batch_size, _dependencies); + } else { + return blas::Gemm_Launcher< + container_0_t, container_1_t, container_2_t, 128, false, true, true, + 128, + Tile<2, 4, 16, 8, 16, 2, 1, 1, 1, 1, jm_M, jm_N, jm_K, inp_jmT, + out_jmT>, + _t_a, _t_b, s_a, s_b, static_cast(gemm_memory_t::local), + static_cast(gemm_algorithm_t::standard), + static_cast(gemm_vectorization_t::none), is_beta_zero, 1, + static_cast(gemm_batch_type_t::strided), + true>::template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda, + _stridea, _b, _ldb, _strideb, _beta, _c, + _ldc, _stridec, batch_size, _dependencies); + } +} + +template +PORTBLAS_ALWAYS_INLINE typename sb_handle_t::event_t launch_gemm_with_transpose( + sb_handle_t& sb_handle, char _trans_a, char _trans_b, index_t _M, + index_t _N, index_t _K, element_t _alpha, container_0_t _a, index_t _lda, + index_t _stridea, container_1_t _b, index_t _ldb, index_t _strideb, + element_t _beta, container_2_t _c, index_t _ldc, index_t _stridec, + index_t batch_size, gemm_batch_type_t batch_type, + const typename sb_handle_t::event_t& _dependencies) { + typename sb_handle_t::event_t gemm_event; + if (_trans_a == 't' && _trans_b == 't') { + gemm_event = launch_gemm( + sb_handle, _M, _N, _K, _alpha, _a, _lda, _stridea, _b, _ldb, _strideb, + _beta, _c, _ldc, _stridec, batch_size, batch_type, _dependencies); + } else if (_trans_a == 'n' && _trans_b == 'n') { + gemm_event = launch_gemm( + sb_handle, _M, _N, _K, _alpha, _a, _lda, _stridea, _b, _ldb, _strideb, + _beta, _c, _ldc, _stridec, batch_size, batch_type, _dependencies); + } else if (_trans_a == 't' && _trans_b == 'n') { + gemm_event = launch_gemm( + sb_handle, _M, _N, _K, _alpha, _a, _lda, _stridea, _b, _ldb, _strideb, + _beta, _c, _ldc, _stridec, batch_size, batch_type, _dependencies); + } else if (_trans_a == 'n' && _trans_b == 't') { + gemm_event = launch_gemm( + sb_handle, _M, _N, _K, _alpha, _a, _lda, _stridea, _b, _ldb, _strideb, + _beta, _c, _ldc, _stridec, batch_size, batch_type, _dependencies); + } + return gemm_event; +} + +template +PORTBLAS_ALWAYS_INLINE typename sb_handle_t::event_t launch_gemm_with_beta( + sb_handle_t& sb_handle, char _trans_a, char _trans_b, index_t _M, + index_t _N, index_t _K, element_t _alpha, container_0_t _a, index_t _lda, + index_t _stridea, container_1_t _b, index_t _ldb, index_t _strideb, + element_t _beta, container_2_t _c, index_t _ldc, index_t _stridec, + index_t batch_size, gemm_batch_type_t batch_type, + const typename sb_handle_t::event_t& _dependencies) { + typename sb_handle_t::event_t gemm_event; + if (_beta == (element_t)0) { + gemm_event = + launch_gemm_with_transpose( + sb_handle, _trans_a, _trans_b, _M, _N, _K, _alpha, _a, _lda, + _stridea, _b, _ldb, _strideb, _beta, _c, _ldc, _stridec, batch_size, + batch_type, _dependencies); + } else { + gemm_event = + launch_gemm_with_transpose( + sb_handle, _trans_a, _trans_b, _M, _N, _K, _alpha, _a, _lda, + _stridea, _b, _ldb, _strideb, _beta, _c, _ldc, _stridec, batch_size, + batch_type, _dependencies); + } + return gemm_event; +} diff --git a/test/unittest/joint_matrix/tf32_float_16_16_8.cpp b/test/unittest/joint_matrix/tf32_float_16_16_8.cpp new file mode 100644 index 000000000..249c88c34 --- /dev/null +++ b/test/unittest/joint_matrix/tf32_float_16_16_8.cpp @@ -0,0 +1,99 @@ +/*************************************************************************** + * + * @license + * Copyright (C) Codeplay Software Limited + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * For your convenience, a copy of the License has been included in this + * repository. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * portBLAS: BLAS implementation using SYCL + * + * @filename tf32_float_16_16_8.cpp + * + **************************************************************************/ + +#include "blas_test.hpp" +#include "joint_matrix_common.hpp" + +template +const auto SmallMatricesTF32Floatm16n16k8 = ::testing::Combine( + ::testing::Values("tf32"), // input type + ::testing::Values("float"), // output type + ::testing::Values(16), // jm_m + ::testing::Values(16), // jm_n + ::testing::Values(8), // jm_k + ::testing::Values("usm", "buf"), // allocation type + ::testing::Values(0, 33), // offset + ::testing::Values(1), // batch + ::testing::Values(11, 16, 32, 63), // m + ::testing::Values(11, 16, 32, 63), // n + ::testing::Values(17, 33, 64), // k + ::testing::Values('n', 't'), // transa + ::testing::Values('n', 't'), // transb + ::testing::Values(1.5), // alpha + ::testing::Values(0, 1.5), // beta + ::testing::Values(1), // lda_mul + ::testing::Values(1), // ldb_mul + ::testing::Values(1), // ldc_mul + ::testing::Values(gemm_batch_type_t::strided) // batch_type +); +GENERATE_JOINTMATRIX_TEST(JointMatrix, SmallMatricesTF32Floatm16n16k8); + +template +const auto MediumMatricesTF32Floatm16n16k8 = ::testing::Combine( + ::testing::Values("tf32"), // input type + ::testing::Values("float"), // output type + ::testing::Values(16), // jm_m + ::testing::Values(16), // jm_n + ::testing::Values(8), // jm_k + ::testing::Values("usm", "buf"), // allocation type + ::testing::Values(0, 33), // offset + ::testing::Values(1), // batch + ::testing::Values(65, 127, 234, 511, 768), // m + ::testing::Values(65, 127, 234, 511, 768), // n + ::testing::Values(65, 127), // k + ::testing::Values('n', 't'), // transa + ::testing::Values('n', 't'), // transb + ::testing::Values(1.5), // alpha + ::testing::Values(0, 1.5), // beta + ::testing::Values(1), // lda_mul + ::testing::Values(1), // ldb_mul + ::testing::Values(1), // ldc_mul + ::testing::Values(gemm_batch_type_t::strided) // batch_type +); +GENERATE_JOINTMATRIX_TEST(JointMatrix, MediumMatricesTF32Floatm16n16k8); + +template +const auto LargeMatricesTF32Floatm16n16k8 = ::testing::Combine( + ::testing::Values("tf32"), // input type + ::testing::Values("float"), // output type + ::testing::Values(16), // jm_m + ::testing::Values(16), // jm_n + ::testing::Values(8), // jm_k + ::testing::Values("usm", "buf"), // allocation type + ::testing::Values(0, 33), // offset + ::testing::Values(1), // batch + ::testing::Values(1024, 1535, 2024), // m + ::testing::Values(1024, 1535, 2024), // n + ::testing::Values(1536, 2049), // k + ::testing::Values('n', 't'), // transa + ::testing::Values('n', 't'), // transb + ::testing::Values(1.5), // alpha + ::testing::Values(0, 1.5), // beta + ::testing::Values(1), // lda_mul + ::testing::Values(1), // ldb_mul + ::testing::Values(1), // ldc_mul + ::testing::Values(gemm_batch_type_t::strided) // batch_type +); +GENERATE_JOINTMATRIX_TEST(JointMatrix, LargeMatricesTF32Floatm16n16k8);