Skip to content

Commit

Permalink
Added joint_matrix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
muhammad-tanvir-1211 committed Feb 6, 2024
1 parent ec2884c commit 9ddcee5
Show file tree
Hide file tree
Showing 17 changed files with 1,660 additions and 0 deletions.
1 change: 1 addition & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,7 @@ export(EXPORT portblas

option(BLAS_ENABLE_TESTING "Whether to enable testing" ON)
option(ENABLE_EXPRESSION_TESTS "Whether to build expression tree fusion tests" OFF)
option(ENABLE_JOINTMATRIX_TESTS "Whether to build joint_matrix GEMM tests" OFF)
if (INSTALL_HEADER_ONLY AND BLAS_ENABLE_TESTING)
message(STATUS "Tests are disabled when installing portBLAS in header only mode")
set(BLAS_ENABLE_TESTING OFF)
Expand Down
5 changes: 5 additions & 0 deletions src/operations/blas3/gemm_local_joint_matrix.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,11 @@ class Gemm<input_t, output_t, DoubleBuffer, NbcA, NbcB, ClSize, TileType,
static_assert(VectorSize == 1,
"Vectorization not supported for joint_matrix.");

static_assert(
(frags_per_sg > 1 && jm_row_frags == num_sub_groups) ||
(frags_per_sg == 1 && num_jm_frags == num_sub_groups),
"Joint Matrix Row Fragments needs to map 1:1 with total sub_groups.");

//! @brief leading dimension of block of A in local
static constexpr index_t ldsa =
(trans_a ? cl_elems : block_rows) +
Expand Down
18 changes: 18 additions & 0 deletions test/unittest/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -143,3 +143,21 @@ foreach(blas_test ${SYCL_UNITTEST_SRCS})
COMPONENT tests
)
endforeach()

if(${ENABLE_JOINTMATRIX_TESTS})
if (${DPCPP_SYCL_TARGET} STREQUAL "nvptx64-nvidia-cuda")
string(FIND ${DPCPP_SYCL_ARCH} "_" start_idx)
if(start_idx)
MATH(EXPR start_idx "${start_idx} + 1")
string(SUBSTRING ${DPCPP_SYCL_ARCH} ${start_idx} "2" sm_val)
endif()

if (${start_idx} AND ${sm_val} GREATER_EQUAL "80")
add_subdirectory(joint_matrix)
else()
message(FATAL_ERROR "Joint Matrix Tests only supported for NVIDIA GPUs with sm_80 arch and above.")
endif()
else()
message(FATAL_ERROR "Joint Matrix Tests only supported for NVIDIA GPUs.")
endif()
endif()
74 changes: 74 additions & 0 deletions test/unittest/joint_matrix/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
#/***************************************************************************
# *
# * @license
# * Copyright (C) Codeplay Software Limited
# * Licensed under the Apache License, Version 2.0 (the "License");
# * you may not use this file except in compliance with the License.
# * You may obtain a copy of the License at
# *
# * http://www.apache.org/licenses/LICENSE-2.0
# *
# * For your convenience, a copy of the License has been included in this
# * repository.
# *
# * Unless required by applicable law or agreed to in writing, software
# * distributed under the License is distributed on an "AS IS" BASIS,
# * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# * See the License for the specific language governing permissions and
# * limitations under the License.
# *
# * portBLAS: BLAS implementation using SYCL
# *
# * @filename CMakeLists.txt
# *
# **************************************************************************/

set(PORTBLAS_INCLUDE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/../../../include)
set(PORTBLAS_SRC_DIR ${CMAKE_CURRENT_SOURCE_DIR}/../../../src)
list(APPEND CMAKE_MODULE_PATH ${CMAKE_CURRENT_SOURCE_DIR}/cmake)
list(APPEND CMAKE_MODULE_PATH ${CMAKE_CURRENT_SOURCE_DIR}/../../../cmake/Modules)
list(APPEND CMAKE_PREFIX_PATH ${CMAKE_CURRENT_SOURCE_DIR}/../../..)
include(ConfigurePORTBLAS)
include(SYCL)
find_package(PORTBLAS REQUIRED)

set(PORTBLAS_JOINTMATRIX_TEST ${CMAKE_CURRENT_SOURCE_DIR})

include_directories(${PORTBLAS_TEST} ${BLAS_INCLUDE_DIRS})

# compiling tests
set(SYCL_UNITTEST_SRCS
${PORTBLAS_JOINTMATRIX_TEST}/half_half_16_16_16.cpp
${PORTBLAS_JOINTMATRIX_TEST}/half_half_32_8_16.cpp
${PORTBLAS_JOINTMATRIX_TEST}/half_half_8_32_16.cpp
${PORTBLAS_JOINTMATRIX_TEST}/half_float_16_16_16.cpp
${PORTBLAS_JOINTMATRIX_TEST}/half_float_32_8_16.cpp
${PORTBLAS_JOINTMATRIX_TEST}/half_float_8_32_16.cpp
${PORTBLAS_JOINTMATRIX_TEST}/bfloat16_float_16_16_16.cpp
${PORTBLAS_JOINTMATRIX_TEST}/bfloat16_float_32_8_16.cpp
${PORTBLAS_JOINTMATRIX_TEST}/bfloat16_float_8_32_16.cpp
${PORTBLAS_JOINTMATRIX_TEST}/tf32_float_16_16_8.cpp
)

foreach(blas_test ${SYCL_UNITTEST_SRCS})
get_filename_component(test_exec ${blas_test} NAME_WE)
add_executable(joint_matrix_${test_exec}_test ../main.cpp ${blas_test})
target_compile_definitions(joint_matrix_${test_exec}_test PRIVATE -DBLAS_INDEX_T=${BLAS_TEST_INDEX_TYPE})
target_link_libraries(joint_matrix_${test_exec}_test PRIVATE gtest_main Clara::Clara blas::blas PORTBLAS::PORTBLAS)
target_include_directories(joint_matrix_${test_exec}_test PRIVATE ${SYCL_INCLUDE_DIRS})
target_include_directories(joint_matrix_${test_exec}_test PRIVATE ${CBLAS_INCLUDE} ${PORTBLAS_COMMON_INCLUDE_DIR})
target_compile_options(joint_matrix_${test_exec}_test PRIVATE ${DPCPP_FLAGS})
target_link_options(joint_matrix_${test_exec}_test PRIVATE ${DPCPP_FLAGS})

if(TEST_DEVICE)
add_test(NAME joint_matrix_${test_exec}_test COMMAND ${CMAKE_CURRENT_BINARY_DIR}/joint_matrix_${test_exec}_test --device ${TEST_DEVICE} --gtest_output=xml:output/)
else()
add_test(NAME joint_matrix_${test_exec}_test COMMAND ${CMAKE_CURRENT_BINARY_DIR}/joint_matrix_${test_exec}_test --gtest_output=xml:output/)
endif()
message(STATUS "Created google test joint_matrix_${test_exec}_test")
install(TARGETS joint_matrix_${test_exec}_test
RUNTIME
DESTINATION ${CMAKE_INSTALL_BINDIR}
COMPONENT tests
)
endforeach()
99 changes: 99 additions & 0 deletions test/unittest/joint_matrix/bfloat16_float_16_16_16.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
/***************************************************************************
*
* @license
* Copyright (C) Codeplay Software Limited
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* For your convenience, a copy of the License has been included in this
* repository.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
* portBLAS: BLAS implementation using SYCL
*
* @filename bfloat16_float_16_16_16.cpp
*
**************************************************************************/

#include "blas_test.hpp"
#include "joint_matrix_common.hpp"

template <typename scalar_t>
const auto SmallMatricesBfloat16Floatm16n16k16 = ::testing::Combine(
::testing::Values("bfloat16"), // input type
::testing::Values("float"), // output type
::testing::Values(16), // jm_m
::testing::Values(16), // jm_n
::testing::Values(16), // jm_k
::testing::Values("usm", "buf"), // allocation type
::testing::Values(0, 33), // offset
::testing::Values(1), // batch
::testing::Values(11, 16, 32, 63), // m
::testing::Values(11, 16, 32, 63), // n
::testing::Values(17, 33, 64), // k
::testing::Values('n', 't'), // transa
::testing::Values('n', 't'), // transb
::testing::Values<scalar_t>(1.5), // alpha
::testing::Values<scalar_t>(0, 1.5), // beta
::testing::Values(1), // lda_mul
::testing::Values(1), // ldb_mul
::testing::Values(1), // ldc_mul
::testing::Values(gemm_batch_type_t::strided) // batch_type
);
GENERATE_JOINTMATRIX_TEST(JointMatrix, SmallMatricesBfloat16Floatm16n16k16);

template <typename scalar_t>
const auto MediumMatricesBfloat16Floatm16n16k16 = ::testing::Combine(
::testing::Values("bfloat16"), // input type
::testing::Values("float"), // output type
::testing::Values(16), // jm_m
::testing::Values(16), // jm_n
::testing::Values(16), // jm_k
::testing::Values("usm", "buf"), // allocation type
::testing::Values(0, 33), // offset
::testing::Values(1), // batch
::testing::Values(65, 127, 234, 511, 768), // m
::testing::Values(65, 127, 234, 511, 768), // n
::testing::Values(65, 127), // k
::testing::Values('n', 't'), // transa
::testing::Values('n', 't'), // transb
::testing::Values<scalar_t>(1.5), // alpha
::testing::Values<scalar_t>(0, 1.5), // beta
::testing::Values(1), // lda_mul
::testing::Values(1), // ldb_mul
::testing::Values(1), // ldc_mul
::testing::Values(gemm_batch_type_t::strided) // batch_type
);
GENERATE_JOINTMATRIX_TEST(JointMatrix, MediumMatricesBfloat16Floatm16n16k16);

template <typename scalar_t>
const auto LargeMatricesBfloat16Floatm16n16k16 = ::testing::Combine(
::testing::Values("bfloat16"), // input type
::testing::Values("float"), // output type
::testing::Values(16), // jm_m
::testing::Values(16), // jm_n
::testing::Values(16), // jm_k
::testing::Values("usm", "buf"), // allocation type
::testing::Values(0, 33), // offset
::testing::Values(1), // batch
::testing::Values(1024, 1535, 2024), // m
::testing::Values(1024, 1535, 2024), // n
::testing::Values(1536, 2049), // k
::testing::Values('n', 't'), // transa
::testing::Values('n', 't'), // transb
::testing::Values<scalar_t>(1.5), // alpha
::testing::Values<scalar_t>(0, 1.5), // beta
::testing::Values(1), // lda_mul
::testing::Values(1), // ldb_mul
::testing::Values(1), // ldc_mul
::testing::Values(gemm_batch_type_t::strided) // batch_type
);
GENERATE_JOINTMATRIX_TEST(JointMatrix, LargeMatricesBfloat16Floatm16n16k16);
99 changes: 99 additions & 0 deletions test/unittest/joint_matrix/bfloat16_float_32_8_16.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
/***************************************************************************
*
* @license
* Copyright (C) Codeplay Software Limited
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* For your convenience, a copy of the License has been included in this
* repository.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
* portBLAS: BLAS implementation using SYCL
*
* @filename bfloat16_float_32_8_16.cpp
*
**************************************************************************/

#include "blas_test.hpp"
#include "joint_matrix_common.hpp"

template <typename scalar_t>
const auto SmallMatricesBfloat16Floatm32n8k16 = ::testing::Combine(
::testing::Values("bfloat16"), // input type
::testing::Values("float"), // output type
::testing::Values(32), // jm_m
::testing::Values(8), // jm_n
::testing::Values(16), // jm_k
::testing::Values("usm", "buf"), // allocation type
::testing::Values(0, 33), // offset
::testing::Values(1), // batch
::testing::Values(11, 16, 32, 63), // m
::testing::Values(11, 16, 32, 63), // n
::testing::Values(17, 33, 64), // k
::testing::Values('n', 't'), // transa
::testing::Values('n', 't'), // transb
::testing::Values<scalar_t>(1.5), // alpha
::testing::Values<scalar_t>(0, 1.5), // beta
::testing::Values(1), // lda_mul
::testing::Values(1), // ldb_mul
::testing::Values(1), // ldc_mul
::testing::Values(gemm_batch_type_t::strided) // batch_type
);
GENERATE_JOINTMATRIX_TEST(JointMatrix, SmallMatricesBfloat16Floatm32n8k16);

template <typename scalar_t>
const auto MediumMatricesBfloat16Floatm32n8k16 = ::testing::Combine(
::testing::Values("bfloat16"), // input type
::testing::Values("float"), // output type
::testing::Values(32), // jm_m
::testing::Values(8), // jm_n
::testing::Values(16), // jm_k
::testing::Values("usm", "buf"), // allocation type
::testing::Values(0, 33), // offset
::testing::Values(1), // batch
::testing::Values(65, 127, 234, 511, 768), // m
::testing::Values(65, 127, 234, 511, 768), // n
::testing::Values(65, 127), // k
::testing::Values('n', 't'), // transa
::testing::Values('n', 't'), // transb
::testing::Values<scalar_t>(1.5), // alpha
::testing::Values<scalar_t>(0, 1.5), // beta
::testing::Values(1), // lda_mul
::testing::Values(1), // ldb_mul
::testing::Values(1), // ldc_mul
::testing::Values(gemm_batch_type_t::strided) // batch_type
);
GENERATE_JOINTMATRIX_TEST(JointMatrix, MediumMatricesBfloat16Floatm32n8k16);

template <typename scalar_t>
const auto LargeMatricesBfloat16Floatm32n8k16 = ::testing::Combine(
::testing::Values("bfloat16"), // input type
::testing::Values("float"), // output type
::testing::Values(32), // jm_m
::testing::Values(8), // jm_n
::testing::Values(16), // jm_k
::testing::Values("usm", "buf"), // allocation type
::testing::Values(0, 33), // offset
::testing::Values(1), // batch
::testing::Values(1024, 1535, 2024), // m
::testing::Values(1024, 1535, 2024), // n
::testing::Values(1536, 2049), // k
::testing::Values('n', 't'), // transa
::testing::Values('n', 't'), // transb
::testing::Values<scalar_t>(1.5), // alpha
::testing::Values<scalar_t>(0, 1.5), // beta
::testing::Values(1), // lda_mul
::testing::Values(1), // ldb_mul
::testing::Values(1), // ldc_mul
::testing::Values(gemm_batch_type_t::strided) // batch_type
);
GENERATE_JOINTMATRIX_TEST(JointMatrix, LargeMatricesBfloat16Floatm32n8k16);
Loading

0 comments on commit 9ddcee5

Please sign in to comment.