Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix joint_matrix implementation to match latest api #491

Merged
Show file tree
Hide file tree
Changes from 19 commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
cdc3486
Fix joint_matrix_mad api call
muhammad-tanvir-1211 Dec 20, 2023
548321c
Fixed output for matrices with no corner cases
muhammad-tanvir-1211 Dec 26, 2023
65d720b
Added support for corner cases
muhammad-tanvir-1211 Dec 26, 2023
1f1b41c
Fixed the transpose implementation
muhammad-tanvir-1211 Dec 27, 2023
95b11bd
Fixed corner case output
muhammad-tanvir-1211 Dec 28, 2023
e85a2fa
Code working for gemm and gemm_batched
muhammad-tanvir-1211 Jan 1, 2024
6d659af
Unrolled the Global Memory write loops
muhammad-tanvir-1211 Jan 5, 2024
277d787
Reorder computation loops
muhammad-tanvir-1211 Jan 5, 2024
eca4634
Updated comments and removed redundant code
muhammad-tanvir-1211 Jan 8, 2024
37d26e1
Fixed build with release compiler
muhammad-tanvir-1211 Jan 10, 2024
9d2f540
Fix loop bounds
muhammad-tanvir-1211 Jan 16, 2024
5e8d8c6
Add functions for checking lower precision
pgorlani Jan 18, 2024
60fa20a
Fixed synchronization for double buffering
muhammad-tanvir-1211 Jan 25, 2024
ba4cd87
Restriced VectorSize to 1 for joint_matrix
muhammad-tanvir-1211 Jan 31, 2024
78af1da
Fixed race condition
muhammad-tanvir-1211 Jan 31, 2024
9767329
Fixed compilation error with bfloat16 type
muhammad-tanvir-1211 Jan 31, 2024
d1f348b
Increase error margins for trsm tests when joint_matrix is used
muhammad-tanvir-1211 Feb 1, 2024
0c0adaa
Address feedback
muhammad-tanvir-1211 Feb 1, 2024
23dff65
Added joint_matrix tests
muhammad-tanvir-1211 Feb 6, 2024
fe55302
Address test commit feedback
muhammad-tanvir-1211 Feb 22, 2024
50c150b
Reduced the initializer range
muhammad-tanvir-1211 Feb 22, 2024
6b87355
Merge branch 'master' into joint_matrix_fix
muhammad-tanvir-1211 Feb 27, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,7 @@ export(EXPORT portblas

option(BLAS_ENABLE_TESTING "Whether to enable testing" ON)
option(ENABLE_EXPRESSION_TESTS "Whether to build expression tree fusion tests" OFF)
option(ENABLE_JOINTMATRIX_TESTS "Whether to build joint_matrix GEMM tests" OFF)
muhammad-tanvir-1211 marked this conversation as resolved.
Show resolved Hide resolved
if (INSTALL_HEADER_ONLY AND BLAS_ENABLE_TESTING)
message(STATUS "Tests are disabled when installing portBLAS in header only mode")
set(BLAS_ENABLE_TESTING OFF)
Expand Down
19 changes: 13 additions & 6 deletions benchmark/portblas/blas3/trsm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,13 @@ void run(benchmark::State& state, blas::SB_Handle* sb_handle_ptr, char side,
}

std::ostringstream err_stream;
if (!utils::compare_vectors(b_temp, x_ref, err_stream, "")) {
const char* en_joint_matrix = std::getenv("SB_ENABLE_JOINT_MATRIX");
if (!utils::compare_vectors(b_temp, x_ref, err_stream, "",
(en_joint_matrix != NULL) &&
(std::is_same<scalar_t, float>::value) &&
(*en_joint_matrix == '1')
? 2
: 1)) {
const std::string& err_str = err_stream.str();
state.SkipWithError(err_str.c_str());
*success = false;
Expand Down Expand Up @@ -181,8 +187,8 @@ void register_benchmark(blas::SB_Handle* sb_handle_ptr, bool* success,
};
benchmark::RegisterBenchmark(
blas_benchmark::utils::get_name<benchmark_op, scalar_t>(
side, uplo, trans, diag, m, n,
mem_type).c_str(),
side, uplo, trans, diag, m, n, mem_type)
.c_str(),
BM_lambda, sb_handle_ptr, side, uplo, trans, diag, m, n, alpha, success)
->UseRealTime();
}
Expand All @@ -193,16 +199,17 @@ void register_benchmark(blas_benchmark::Args& args,
blas::SB_Handle* sb_handle_ptr, bool* success) {
auto trsm_params = blas_benchmark::utils::get_trsm_params<scalar_t>(args);
register_benchmark<scalar_t, blas::helper::AllocType::buffer>(
sb_handle_ptr, success, blas_benchmark::utils::MEM_TYPE_BUFFER, trsm_params);
sb_handle_ptr, success, blas_benchmark::utils::MEM_TYPE_BUFFER,
trsm_params);
#ifdef SB_ENABLE_USM
register_benchmark<scalar_t, blas::helper::AllocType::usm>(
sb_handle_ptr, success, blas_benchmark::utils::MEM_TYPE_USM, trsm_params);
#endif
}

namespace blas_benchmark {
void create_benchmark(blas_benchmark::Args& args, blas::SB_Handle* sb_handle_ptr,
bool* success) {
void create_benchmark(blas_benchmark::Args& args,
blas::SB_Handle* sb_handle_ptr, bool* success) {
BLAS_REGISTER_BENCHMARK(args, sb_handle_ptr, success);
}
} // namespace blas_benchmark
38 changes: 24 additions & 14 deletions common/include/common/float_comparison.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -120,17 +120,20 @@ scalar_t clamp_to_limits(scalar_t v) {
* Indicates the tolerated margin for relative differences
*/
template <typename scalar_t>
inline scalar_t getRelativeErrorMargin() {
inline scalar_t getRelativeErrorMargin(const int32_t margin_multiplier = 1) {
/* Measured empirically with gemm. The dimensions of the matrices (even k)
* don't seem to have an impact on the observed relative differences
* In the cases where the relative error is relevant (non close to zero),
* relative differences of up to 0.002 were observed for float
*/
return static_cast<scalar_t>(0.005);
scalar_t margin = 0.005;
// increase error margin for mixed precision calculation
// for trsm operator.
return margin * margin_multiplier;
}

template <>
inline double getRelativeErrorMargin<double>() {
inline double getRelativeErrorMargin<double>(const int32_t) {
/* Measured empirically with gemm. The dimensions of the matrices (even k)
* don't seem to have an impact on the observed relative differences
* In the cases where the relative error is relevant (non close to zero),
Expand All @@ -142,7 +145,7 @@ inline double getRelativeErrorMargin<double>() {
#ifdef BLAS_DATA_TYPE_HALF

template <>
inline cl::sycl::half getRelativeErrorMargin<cl::sycl::half>() {
inline cl::sycl::half getRelativeErrorMargin<cl::sycl::half>(const int32_t) {
// Measured empirically with gemm
return 0.05f;
}
Expand All @@ -152,16 +155,19 @@ inline cl::sycl::half getRelativeErrorMargin<cl::sycl::half>() {
* scalars are close to 0)
*/
template <typename scalar_t>
inline scalar_t getAbsoluteErrorMargin() {
inline scalar_t getAbsoluteErrorMargin(const int32_t margin_multiplier = 1) {
/* Measured empirically with gemm.
* In the cases where the relative error is irrelevant (close to zero),
* absolute differences of up to 0.0006 were observed for float
*/
return 0.001f;
scalar_t margin = 0.001f;
// increase error margin for mixed precision calculation
// for trsm operator.
return margin * margin_multiplier;
}

template <>
inline double getAbsoluteErrorMargin<double>() {
inline double getAbsoluteErrorMargin<double>(const int32_t) {
/* Measured empirically with gemm.
* In the cases where the relative error is irrelevant (close to zero),
* absolute differences of up to 10^-12 were observed for double
Expand All @@ -171,7 +177,7 @@ inline double getAbsoluteErrorMargin<double>() {
#ifdef BLAS_DATA_TYPE_HALF

template <>
inline cl::sycl::half getAbsoluteErrorMargin<cl::sycl::half>() {
inline cl::sycl::half getAbsoluteErrorMargin<cl::sycl::half>(const int32_t) {
// Measured empirically with gemm.
return 1.0f;
}
Expand All @@ -181,7 +187,8 @@ inline cl::sycl::half getAbsoluteErrorMargin<cl::sycl::half>() {
* Compare two scalars and returns false if the difference is not acceptable.
*/
template <typename scalar_t, typename epsilon_t = scalar_t>
inline bool almost_equal(scalar_t const& scalar1, scalar_t const& scalar2) {
inline bool almost_equal(scalar_t const& scalar1, scalar_t const& scalar2,
const int32_t margin_multiplier = 1) {
// Shortcut, also handles case where both are zero
if (scalar1 == scalar2) {
return true;
Expand All @@ -196,12 +203,14 @@ inline bool almost_equal(scalar_t const& scalar1, scalar_t const& scalar2) {

// Close to zero, the relative error doesn't work, use absolute error
if (scalar1 == scalar_t{0} || scalar2 == scalar_t{0} ||
absolute_diff < getAbsoluteErrorMargin<epsilon_t>()) {
return (absolute_diff < getAbsoluteErrorMargin<epsilon_t>());
absolute_diff < getAbsoluteErrorMargin<epsilon_t>(margin_multiplier)) {
return (absolute_diff <
getAbsoluteErrorMargin<epsilon_t>(margin_multiplier));
}
// Use relative error
const auto absolute_sum = utils::abs(scalar1) + utils::abs(scalar2);
return (absolute_diff / absolute_sum) < getRelativeErrorMargin<epsilon_t>();
return (absolute_diff / absolute_sum) <
getRelativeErrorMargin<epsilon_t>(margin_multiplier);
}

/**
Expand All @@ -215,15 +224,16 @@ template <typename scalar_t, typename epsilon_t = scalar_t>
inline bool compare_vectors(std::vector<scalar_t> const& vec,
std::vector<scalar_t> const& ref,
std::ostream& err_stream = std::cerr,
std::string end_line = "\n") {
std::string end_line = "\n",
const int32_t margin_multiplier = 1) {
if (vec.size() != ref.size()) {
err_stream << "Error: tried to compare vectors of different sizes"
<< std::endl;
return false;
}

for (int i = 0; i < vec.size(); ++i) {
if (!almost_equal<scalar_t, epsilon_t>(vec[i], ref[i])) {
if (!almost_equal<scalar_t, epsilon_t>(vec[i], ref[i], margin_multiplier)) {
err_stream << "Value mismatch at index " << i << ": " << vec[i]
<< "; expected " << ref[i] << end_line;
return false;
Expand Down
113 changes: 57 additions & 56 deletions src/operations/blas3/gemm_load_store_joint_matrix.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,18 +57,16 @@ struct PacketizeJointMatrix {

/*! @brief Performs a coalesced non-vectorized load when the current block is
* not internal.
* @tparam trans Whether the source matrix is transposed or not.
* @tparam internal True if the current block is internal and no bounds
* checking is required.
* @tparam ld The leading dimension of the destination memory.
*/

template <bool trans, bool internal, int ld, typename SrcPointerType,
typename DestPointerType, typename EdgePredicate>
template <bool internal, typename SrcPointerType, typename DestPointerType,
typename EdgePredicate>
static PORTBLAS_INLINE typename std::enable_if<!internal>::type load(
const bool in_range, SrcPointerType src, DestPointerType dest,
EdgePredicate) {
value_t val = in_range ? *(src) : value_t{0};
value_t val = in_range ? *src : value_t{0};
using address_t = cl::sycl::access::address_space;
if constexpr (std::is_same<cl::sycl::multi_ptr<cl::sycl::half,
address_t::local_space>,
Expand All @@ -79,93 +77,96 @@ struct PacketizeJointMatrix {
cl::sycl::ext::oneapi::bfloat16,
address_t::local_space>,
DestPointerType>::value) {
using dtype = cl::sycl::ext::oneapi::bfloat16;
*dest = static_cast<dtype>(val);
using namespace cl::sycl::ext::oneapi;
*dest = bfloat16(val);
} else {
using namespace cl::sycl::ext::oneapi::experimental::matrix;
*dest = round_to_tf32(val);
}
}

/*! @brief Performs a vectorised load using sycl::vec::load when the current
* block is internal. In the case where k < the
* number of elements being loaded then edge loads will be element wise with
* additional bounds checking.
* @tparam trans Whether the source matrix is transposed or not.
* @tparam internal True if the current block is internal and no bounds
* checking is required.
* @tparam ld The leading dimension of the destination memory. */
template <bool trans, bool internal, index_t ld, typename SrcPointerType,
typename DestPointerType, typename EdgePredicate>
*/
template <bool internal, typename SrcPointerType, typename DestPointerType,
typename EdgePredicate>
static PORTBLAS_INLINE typename std::enable_if<internal>::type load(
const bool in_range, SrcPointerType src, DestPointerType dest,
EdgePredicate edge_in_range) {
PacketType packet{};

using address_t = cl::sycl::access::address_space;
if (in_range) {
using address_t = cl::sycl::access::address_space;
packet.template load<address_t::global_space>(
0, cl::sycl::multi_ptr<const value_t, address_t::global_space>(src));
store(packet, dest);
} else {
// avoid writing to variable, instead directly write to
// shared local memory to avoid race condition experienced
// with release compiler.
#pragma unroll
for (index_t i = 0; i < packet_size; i++) {
reinterpret_cast<value_t *>(&packet)[i] =
edge_in_range(i) ? *(src + i) : value_t{0};
}
}
store<trans, ld>(packet, dest);
}
/*! @brief Store a vector packet into local memory when the source is
* transposed. This will untranspose the elements individually when storing so
* the data in local memory is always consistent.
* @tparam trans Whether the source matrix is transposed or not.
* @tparam ld The leading dimension of the destination memory.*/
template <bool trans, index_t ld, typename DestPointerType>
static PORTBLAS_INLINE typename std::enable_if<trans>::type store(
PacketType &packet, DestPointerType dest) {
using address_t = cl::sycl::access::address_space;
#pragma unroll
for (index_t i = 0; i < packet_size; i++) {
value_t val = reinterpret_cast<value_t *>(&packet)[i];
if constexpr (std::is_same<cl::sycl::multi_ptr<cl::sycl::half,
address_t::local_space>,
DestPointerType>::value) {
using dtype = cl::sycl::half;
*(dest + ld * i) = static_cast<dtype>(val);
} else if constexpr (std::is_same<cl::sycl::multi_ptr<
cl::sycl::ext::oneapi::bfloat16,
address_t::local_space>,
DestPointerType>::value) {
using dtype = cl::sycl::ext::oneapi::bfloat16;
*(dest + ld * i) = static_cast<dtype>(val);
} else {
using namespace cl::sycl::ext::oneapi::experimental::matrix;
*(dest + ld * i) = round_to_tf32(val);
for (index_t i = 0; i < packet_size; i++, dest++, src++) {
if constexpr (std::is_same<cl::sycl::multi_ptr<cl::sycl::half,
address_t::local_space>,
DestPointerType>::value) {
using dtype = cl::sycl::half;
*dest = static_cast<dtype>(edge_in_range(i) ? *src : 0);
} else if constexpr (std::is_same<cl::sycl::multi_ptr<
cl::sycl::ext::oneapi::bfloat16,
address_t::local_space>,
DestPointerType>::value) {
using namespace cl::sycl::ext::oneapi;
*dest = bfloat16(edge_in_range(i) ? *src : 0.f);
} else {
using namespace cl::sycl::ext::oneapi::experimental::matrix;
*dest = edge_in_range(i) ? round_to_tf32(*src) : 0.f;
}
}
}
}

/*! @brief Store a vector packet into local memory when the source is not
* transposed. This will use sycl::vec::store function.
* @tparam trans Whether the source matrix is transposed or not.
* @tparam ld The leading dimension of the destination memory.*/
template <bool trans, int ld, typename DestPointerType>
static PORTBLAS_INLINE typename std::enable_if<!trans>::type store(
PacketType &packet, DestPointerType dest) {
/*! @brief Store a vector packet into local memory. This will use
* sycl::vec::store function.
*/
template <typename DestPointerType>
static PORTBLAS_INLINE void store(PacketType &packet, DestPointerType dest) {
using address_t = cl::sycl::access::address_space;
if constexpr (std::is_same<cl::sycl::multi_ptr<cl::sycl::half,
address_t::local_space>,
DestPointerType>::value) {
using dtype = cl::sycl::half;
*dest = static_cast<dtype>(packet[0]);
cl::sycl::vec<dtype, vector_size> new_vec{};
for (index_t i = 0; i < packet_size; i++) {
reinterpret_cast<dtype *>(&new_vec)[i] =
static_cast<dtype>(reinterpret_cast<value_t *>(&packet)[i]);
}
new_vec.template store<address_t::local_space>(
0, cl::sycl::multi_ptr<dtype, address_t::local_space>(dest));
} else if constexpr (std::is_same<cl::sycl::multi_ptr<
cl::sycl::ext::oneapi::bfloat16,
address_t::local_space>,
DestPointerType>::value) {
using dtype = cl::sycl::ext::oneapi::bfloat16;
*dest = static_cast<dtype>(packet[0]);
// sycl::vec doesn't accept bfloat16 as a valid input type
// so we need to write the packet elements individually to
// the shared memory.
using namespace cl::sycl::ext::oneapi;
for (index_t i = 0; i < packet_size; i++, dest++) {
*dest = bfloat16(reinterpret_cast<value_t *>(&packet)[i]);
}
} else {
using namespace cl::sycl::ext::oneapi::experimental::matrix;
*dest = round_to_tf32(packet[0]);
using dtype = float;
cl::sycl::vec<dtype, vector_size> new_vec;
for (index_t i = 0; i < packet_size; i++) {
reinterpret_cast<dtype *>(&new_vec)[i] =
round_to_tf32(reinterpret_cast<value_t *>(&packet)[i]);
}
new_vec.template store<address_t::local_space>(
0, cl::sycl::multi_ptr<dtype, address_t::local_space>(dest));
}
}
};
Expand Down
Loading
Loading