Skip to content
This repository has been archived by the owner on Jan 13, 2025. It is now read-only.

Extended Gemm interface to support mixed precision operations #500

Merged
merged 56 commits into from
May 15, 2024
Merged
Show file tree
Hide file tree
Changes from 55 commits
Commits
Show all changes
56 commits
Select commit Hold shift + click to select a range
9fb1f6d
Updated half data support approach & initial enablement for some blas1…
OuadiElfarouki Feb 5, 2024
b60ba9e
Enabled testing of half-data supporting operators & added relevant to…
OuadiElfarouki Feb 5, 2024
b60f791
disabled HALF precision when targeting default CPU
OuadiElfarouki Feb 5, 2024
9744e12
minor fixes
OuadiElfarouki Feb 5, 2024
da20e97
Merge branch 'master' into half_data_revisited
OuadiElfarouki Feb 6, 2024
c16db81
enabled benchmarks for fp16 supporting operators
OuadiElfarouki Feb 6, 2024
42fa2d0
minor updates to mul_add for half/complex data
OuadiElfarouki Feb 6, 2024
dfe2796
Enabled fp16 benchmarking in cuBLAS
OuadiElfarouki Feb 8, 2024
33ed262
added half gemm config on nvidia gpu for improved perf
OuadiElfarouki Feb 8, 2024
13dce4e
minor fixes
OuadiElfarouki Feb 8, 2024
ae2312e
enabled half data gemm benchmark on AMD GPU
OuadiElfarouki Feb 9, 2024
4663f32
Minor fixes to cublas & rocblas half-data benchmarks
OuadiElfarouki Feb 9, 2024
6935da0
minor fix
OuadiElfarouki Feb 9, 2024
97a3310
Removed unnecessary custom cast function to half
OuadiElfarouki Feb 14, 2024
293d690
Moved half type casting within reference blas (Review addressed)
OuadiElfarouki Feb 14, 2024
a98091c
Removed unnecessary half data branching in portblas benchmarks
OuadiElfarouki Feb 14, 2024
72934b6
removed unnecessary half data branching/casting in cublas bench
OuadiElfarouki Feb 14, 2024
52b3e35
Cleaning half-related redundant changes
OuadiElfarouki Feb 14, 2024
23c8efd
Update test/blas_test.hpp
OuadiElfarouki Feb 15, 2024
feaf0fe
Update benchmark/cublas/blas3/gemm.cpp
OuadiElfarouki Feb 15, 2024
e1804e0
Remove extra bracket.
pgorlani Feb 15, 2024
e535773
removed unnecessary half data branching/casting in rocBLAS bench
OuadiElfarouki Feb 15, 2024
a09aba5
further cleaning & simplifications
OuadiElfarouki Feb 15, 2024
d30d252
Update benchmark/cublas/utils.hpp
OuadiElfarouki Feb 16, 2024
919bc19
Update benchmark/rocblas/utils.hpp
OuadiElfarouki Feb 16, 2024
b8a5be3
Update common/include/common/common_utils.hpp
OuadiElfarouki Feb 16, 2024
af93a62
further simplifications
OuadiElfarouki Feb 16, 2024
d29372e
minor update to benchmark rand gen
OuadiElfarouki Feb 16, 2024
a8282b2
updated readme, disabled complex support by default and removed extra…
OuadiElfarouki Feb 19, 2024
25d3d23
minor cmake fix for header only use-case (oneMKL in particular)
OuadiElfarouki Feb 19, 2024
6eb6a73
Enabled generation of mixed-precision gemm kernels
OuadiElfarouki Feb 21, 2024
005694a
extended gemm kernels mixed-type handling
OuadiElfarouki Feb 21, 2024
ad149a5
fixes to gemm backends instantiation templates
OuadiElfarouki Feb 21, 2024
3f8cb15
Extended testing support to mixed precision
OuadiElfarouki Feb 23, 2024
bfc8c3a
disable sycl::mad vector calls for mixed precision calls
OuadiElfarouki Feb 28, 2024
7185756
Merge branch 'master' into mixed_precision_gemm
OuadiElfarouki Feb 28, 2024
ac95e7e
Merge branch 'master' into mixed_precision_gemm
OuadiElfarouki Feb 28, 2024
be898a9
updated gemm template for joint matrix case
OuadiElfarouki Feb 29, 2024
8a68d95
Minor fixes & improvements
OuadiElfarouki Feb 29, 2024
48545a0
Merge branch 'master' into mixed_precision_gemm
OuadiElfarouki Feb 29, 2024
155781d
Fixed issues in unit-tests
OuadiElfarouki Feb 29, 2024
38fd71e
Typo fixes
OuadiElfarouki Mar 4, 2024
db67f2b
Merge branch 'master' into mixed_precision_gemm
OuadiElfarouki Mar 4, 2024
852722e
minor fix for default cpu gemm config
OuadiElfarouki Mar 4, 2024
26970c0
re-enabled half support for default CPUs and fixed gemm non-local ker…
OuadiElfarouki Mar 4, 2024
a2f489f
Separated half gemm config for default CPUs
OuadiElfarouki Mar 5, 2024
3c150cf
cast half to float within mul_add
OuadiElfarouki Mar 15, 2024
620ea94
Apply typo suggestions from code review
OuadiElfarouki Mar 27, 2024
aff966c
removed non-necessary gemm template parameter element in/out
OuadiElfarouki Mar 28, 2024
c7aedc9
addressed PR comments
OuadiElfarouki Mar 28, 2024
d3a962a
Merge branch 'master' into mixed_precision_gemm
OuadiElfarouki Apr 11, 2024
95f77c6
removed const specifier causing errors when const data is enabled (gemm)
OuadiElfarouki Apr 18, 2024
8e2930b
Merge branch 'master' into mixed_precision_gemm
OuadiElfarouki Apr 18, 2024
d0a64b9
Merge branch 'master' into mixed_precision_gemm
OuadiElfarouki May 7, 2024
c7ce902
Merge remote-tracking branch 'upstream/master' into mixed_precision_gemm
s-Nick May 8, 2024
ae8a874
Merge branch 'master' of github.com:codeplaysoftware/portBLAS into mi…
s-Nick May 13, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
233 changes: 126 additions & 107 deletions cmake/CmakeFunctionHelper.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -152,32 +152,42 @@ function(generate_blas_objects blas_level func)
list(APPEND data_list_c "half")
endif()
endif()
foreach(data ${data_list_c})
cpp_type(cpp_data ${data})
foreach(index ${index_list})
foreach(increment ${index_list})
sanitize_file_name(file_name
"${func}_${data}_${index}_${data}_${increment}.cpp")
add_custom_command(OUTPUT "${LOCATION}/${file_name}"
COMMAND ${PYTHON_EXECUTABLE} ${PORTBLAS_SRC_GENERATOR}/py_gen_blas_ops.py
${PROJECT_SOURCE_DIR}/external/
${PORTBLAS_SRC_GENERATOR}/gen
${blas_level}
${func}
${PORTBLAS_SRC}/interface/${blas_level}/${func}.cpp.in
${cpp_data}
${index}
${increment}
${file_name}
MAIN_DEPENDENCY ${PORTBLAS_SRC}/interface/${blas_level}/${func}.cpp.in
DEPENDS ${PORTBLAS_SRC_GENERATOR}/py_gen_blas_ops.py
WORKING_DIRECTORY ${PROJECT_BINARY_DIR}
VERBATIM
)
list(APPEND FUNC_SRC "${LOCATION}/${file_name}")
endforeach(increment)
endforeach(index)
endforeach(data)
foreach(data_in ${data_list_c})
set(data_list_out ${data_in})
# When using half with Gemm target, generate a mixed-precision
# Gemm kernel (half-float) alongside the fully half based kernel.
if((data_in STREQUAL "half") AND (${func} STREQUAL "gemm"))
list(APPEND data_list_out "float")
endif()
cpp_type(cpp_data_in ${data_in})
foreach(data_out ${data_list_out})
cpp_type(cpp_data_out ${data_out})
foreach(index ${index_list})
foreach(increment ${index_list})
sanitize_file_name(file_name
"${func}_${data_in}_${index}_${data_out}_${increment}.cpp")
add_custom_command(OUTPUT "${LOCATION}/${file_name}"
COMMAND ${PYTHON_EXECUTABLE} ${PORTBLAS_SRC_GENERATOR}/py_gen_blas_ops.py
${PROJECT_SOURCE_DIR}/external/
${PORTBLAS_SRC_GENERATOR}/gen
${blas_level}
${func}
${PORTBLAS_SRC}/interface/${blas_level}/${func}.cpp.in
${cpp_data_in}
${cpp_data_out}
${index}
${increment}
${file_name}
MAIN_DEPENDENCY ${PORTBLAS_SRC}/interface/${blas_level}/${func}.cpp.in
DEPENDS ${PORTBLAS_SRC_GENERATOR}/py_gen_blas_ops.py
WORKING_DIRECTORY ${PROJECT_BINARY_DIR}
VERBATIM
)
list(APPEND FUNC_SRC "${LOCATION}/${file_name}")
endforeach(increment)
endforeach(index)
endforeach(data_out ${data_list_out})
endforeach(data_in)
add_library(${func} OBJECT ${FUNC_SRC})
set_target_compile_def(${func})
target_include_directories(${func} PRIVATE ${PORTBLAS_SRC} ${PORTBLAS_INCLUDE}
Expand Down Expand Up @@ -312,87 +322,97 @@ function(add_gemm_configuration
if(const_pos)
string(REPLACE "_const" "" actualfunc ${func})
endif()
# When using half data type, generate a mixed-precision Gemm
# configuration (half-float) alongside the fully half based one.
set(data_list_out ${data})
if(data STREQUAL "half")
list(APPEND data_list_out "float")
endif()
cpp_type(cpp_data ${data})
foreach(symm_a ${boolean_list})
foreach(symm_b ${boolean_list})
if ((${data} MATCHES "half") AND (symm_a OR symm_b))
continue()
endif()
if (symm_a AND symm_b)
continue()
endif()
foreach(trans_a ${boolean_list})
foreach(trans_b ${boolean_list})
if ((symm_a AND trans_b) OR (symm_b AND trans_a))
continue()
endif()
foreach(is_beta_zero ${boolean_list})
foreach(index ${index_list})
set(file_name "${func}_${double_buffer}_${conflict_a}_"
"${conflict_b}_${trans_a}_${trans_b}_"
"${is_beta_zero}_${gemm_memory_type}_"
"${gemm_shape_type}_${gemm_vectorize_type}_"
"${vector_size}_${batch_type}_${use_joint_matrix}_"
"${data}_${index}_${tir}_${tic}_${twr}_"
"${twc}_${tsr}_${tsc}_${tlr}_${tlc}_"
"${item_batch}_${wg_batch}_${symm_a}_${symm_b}_"
"${jm_m}_${jm_n}_${jm_k}_${jm_in_type}_${jm_out_type}_"
"${wg_size}_${cache_line_size}_${data}.cpp")
sanitize_file_name(file_name "${file_name}")
add_custom_command(OUTPUT "${LOCATION}/${file_name}"
COMMAND ${PYTHON_EXECUTABLE} ${PORTBLAS_SRC_GENERATOR}/py_gen_blas_gemm_launcher.py
${PROJECT_SOURCE_DIR}/external/
${PORTBLAS_SRC_GENERATOR}/gen
${blas_level}
${func}
${PORTBLAS_SRC}/interface/${blas_level}/${func}.cpp.in
${cpp_data}
${index}
${double_buffer}
${conflict_a}
${conflict_b}
${trans_a}
${trans_b}
${is_beta_zero}
${gemm_memory_type}
${gemm_shape_type}
${tir}
${tic}
${twr}
${twc}
${tsr}
${tsc}
${tlr}
${tlc}
${item_batch}
${wg_batch}
${jm_m}
${jm_n}
${jm_k}
${jm_in_type}
${jm_out_type}
${wg_size}
${cache_line_size}
${file_name}
${gemm_vectorize_type}
${vector_size}
${batch_type}
${use_joint_matrix}
${symm_a}
${symm_b}
MAIN_DEPENDENCY ${PORTBLAS_SRC}/interface/${blas_level}/${func}.cpp.in
DEPENDS ${PORTBLAS_SRC_GENERATOR}/py_gen_blas_gemm_launcher.py
WORKING_DIRECTORY ${PROJECT_BINARY_DIR}
VERBATIM
)
list(APPEND gemm_sources "${LOCATION}/${file_name}")
set(gemm_sources "${gemm_sources}" PARENT_SCOPE)
endforeach(index)
endforeach(is_beta_zero)
endforeach(trans_b)
endforeach(trans_a)
endforeach(symm_b)
endforeach(symm_a)
foreach(data_out ${data_list_out})
cpp_type(cpp_data_out ${data_out})
foreach(symm_a ${boolean_list})
foreach(symm_b ${boolean_list})
if ((${data} MATCHES "half") AND (symm_a OR symm_b))
continue()
endif()
if (symm_a AND symm_b)
continue()
endif()
foreach(trans_a ${boolean_list})
foreach(trans_b ${boolean_list})
if ((symm_a AND trans_b) OR (symm_b AND trans_a))
continue()
endif()
foreach(is_beta_zero ${boolean_list})
foreach(index ${index_list})
set(file_name "${func}_${double_buffer}_${conflict_a}_"
"${conflict_b}_${trans_a}_${trans_b}_"
"${is_beta_zero}_${gemm_memory_type}_"
"${gemm_shape_type}_${gemm_vectorize_type}_"
"${vector_size}_${batch_type}_${use_joint_matrix}_"
"${index}_${tir}_${tic}_${twr}_"
"${twc}_${tsr}_${tsc}_${tlr}_${tlc}_"
"${item_batch}_${wg_batch}_${symm_a}_${symm_b}_"
"${jm_m}_${jm_n}_${jm_k}_${jm_in_type}_${jm_out_type}_"
"${wg_size}_${cache_line_size}_${data}_${data_out}.cpp")
sanitize_file_name(file_name "${file_name}")
add_custom_command(OUTPUT "${LOCATION}/${file_name}"
COMMAND ${PYTHON_EXECUTABLE} ${PORTBLAS_SRC_GENERATOR}/py_gen_blas_gemm_launcher.py
${PROJECT_SOURCE_DIR}/external/
${PORTBLAS_SRC_GENERATOR}/gen
${blas_level}
${func}
${PORTBLAS_SRC}/interface/${blas_level}/${func}.cpp.in
${cpp_data}
${index}
${double_buffer}
${conflict_a}
${conflict_b}
${trans_a}
${trans_b}
${is_beta_zero}
${gemm_memory_type}
${gemm_shape_type}
${tir}
${tic}
${twr}
${twc}
${tsr}
${tsc}
${tlr}
${tlc}
${item_batch}
${wg_batch}
${jm_m}
${jm_n}
${jm_k}
${jm_in_type}
${jm_out_type}
${wg_size}
${cache_line_size}
${file_name}
${gemm_vectorize_type}
${vector_size}
${batch_type}
${use_joint_matrix}
${symm_a}
${symm_b}
${cpp_data_out}
MAIN_DEPENDENCY ${PORTBLAS_SRC}/interface/${blas_level}/${func}.cpp.in
DEPENDS ${PORTBLAS_SRC_GENERATOR}/py_gen_blas_gemm_launcher.py
WORKING_DIRECTORY ${PROJECT_BINARY_DIR}
VERBATIM
)
list(APPEND gemm_sources "${LOCATION}/${file_name}")
set(gemm_sources "${gemm_sources}" PARENT_SCOPE)
endforeach(index)
endforeach(is_beta_zero)
endforeach(trans_b)
endforeach(trans_a)
endforeach(symm_b)
endforeach(symm_a)
endforeach(data_out)
endfunction()
if(${TUNING_TARGET} STREQUAL "INTEL_GPU")
set(supported_types
Expand Down Expand Up @@ -702,7 +722,6 @@ else() # default cpu backend
add_gemm_configuration(
"${data}" 64 "false" "false" "false"
64 2 2 4 4 1 1 1 1 4 4 1 1 1 float float "no_local" "standard" "full" 4 "interleaved" "false" "false")

if(BLAS_ENABLE_HALF)
add_gemm_configuration(
"half" 128 "false" "false" "false"
Expand Down
4 changes: 2 additions & 2 deletions include/operations/blas3_trees.h
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ struct Tile {
* @tparam TransB iff true, matrix B will be transposed on the fly
* @tparam SymmA whether the matrix A is a symmetric triangular matrix
* @tparam SymmB whether the matrix B is a symmetric triangular matrix
* @tparam element_t type of matrix elements
* @tparam element_t type of scalar alpha & beta
* @tparam UseJointMatrix boolean parameter to decide whether to use
* joint_matrix or not
* @param a_ the lhs_t matrix
Expand All @@ -195,7 +195,7 @@ template <typename input_t, typename output_t, bool DoubleBuffer, bool NbcA,
int VectorSize, int BatchType, bool UseJointMatrix = false>
class Gemm {
public:
using value_t = element_t;
using value_t = typename input_t::value_t;
using index_t = typename std::make_signed<typename input_t::index_t>::type;
static constexpr int wg_size = tile_type::wg_rows * tile_type::wg_cols;
static constexpr bool trans_a = TransA;
Expand Down
12 changes: 9 additions & 3 deletions python_generator/py_gen_blas_gemm_launcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
blas_level_name = sys.argv[3]
blas_function_name = sys.argv[4]
blas_template_impl = sys.argv[5]
data = sys.argv[6]
data_in = sys.argv[6]
index = sys.argv[7]
double_buffer = sys.argv[8]
conflict_a = sys.argv[9]
Expand Down Expand Up @@ -72,6 +72,7 @@
use_joint_matrix = sys.argv[37]
symm_a = sys.argv[38]
symm_b = sys.argv[39]
data_out = sys.argv[40] # Different from data_in for mixed-precision cases
source = 'generated_src/' + blas_level_name + '/' + blas_function_name + '/'
try:
os.makedirs(source)
Expand Down Expand Up @@ -208,8 +209,13 @@
itermode=Itermode.combinations,
iter_modifier=1),
Iterable(
key='DATA_TYPE',
vals=[data],
key='DATA_TYPE_IN',
vals=[data_in],
itermode=Itermode.combinations,
iter_modifier=1),
Iterable(
key='DATA_TYPE_OUT',
vals=[data_out],
itermode=Itermode.combinations,
iter_modifier=1),
Iterable(
Expand Down
26 changes: 21 additions & 5 deletions python_generator/py_gen_blas_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,11 @@
blas_level_name = sys.argv[3]
blas_function_name = sys.argv[4]
blas_template_impl = sys.argv[5]
data = sys.argv[6]
index = sys.argv[7]
increment = sys.argv[8]
file_name = sys.argv[9]
data_in = sys.argv[6]
data_out = sys.argv[7]
index = sys.argv[8]
increment = sys.argv[9]
file_name = sys.argv[10]
source = 'generated_src/' + blas_level_name + '/' + blas_function_name + '/'

try:
Expand All @@ -58,7 +59,7 @@
iterables = [
Iterable(
key='DATA_TYPE',
vals=[data],
vals=[data_in],
itermode=Itermode.combinations,
iter_modifier=1),
Iterable(
Expand All @@ -72,6 +73,21 @@
itermode=Itermode.combinations,
iter_modifier=1)
]

# Gemm supports mixed-precision inputs/outputs/arithmetics
is_gemm: bool = blas_function_name == "gemm"
if is_gemm:
iterables.append(Iterable(
key='DATA_TYPE_IN',
vals=[data_in],
itermode=Itermode.combinations,
iter_modifier=1))
iterables.append(Iterable(
key='DATA_TYPE_OUT',
vals=[data_out],
itermode=Itermode.combinations,
iter_modifier=1))

iter_groups = [IterGroup('@ip1@', template, iterables, combine_iters=True)]
generate_file(
input_template,
Expand Down
6 changes: 4 additions & 2 deletions src/interface/blas3/backend/amd_gpu.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ _gemm(sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K,
container_2_t _c, index_t _ldc, index_t _stridec, index_t batch_size,
gemm_batch_type_t batch_type,
const typename sb_handle_t::event_t& _dependencies) {
using element_in_t = typename ValueType<container_0_t>::type;
// Unused configuration cases
if constexpr (s_a && s_b || ((s_a && _t_b) || (s_b && _t_a))) {
return _dependencies;
Expand All @@ -53,7 +54,7 @@ _gemm(sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K,
const auto n_elem_access = (_M * _K + _K * _N + _M * _N);
const auto arith_ratio = n_fma / n_elem_access;
static constexpr int ClSize = 64;
static constexpr int tileWgSize = ClSize / sizeof(element_t);
static constexpr int tileWgSize = ClSize / sizeof(element_in_t);
if (batch_type == gemm_batch_type_t::interleaved) {
return blas::Gemm_Launcher<
container_0_t, container_1_t, container_2_t, 64, false, false, false,
Expand Down Expand Up @@ -242,8 +243,9 @@ _gemm(sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K,
container_2_t _c, index_t _ldc, index_t _stridec, index_t batch_size,
gemm_batch_type_t batch_type,
const typename sb_handle_t::event_t& _dependencies) {
using element_in_t = typename ValueType<container_0_t>::type;
static constexpr int ClSize = 64;
static constexpr int tileWgSize = ClSize / sizeof(element_t);
static constexpr int tileWgSize = ClSize / sizeof(element_in_t);
/* Tall & Skinny matrices. */
#ifdef GEMM_TALL_SKINNY_SUPPORT
if (batch_size == 1 && (_M / _N > 8 || _N / _M > 8)) {
Expand Down
Loading
Loading