Skip to content

Commit

Permalink
Fix joint_matrix implementation to match latest api (#491)
Browse files Browse the repository at this point in the history
* Added tests for joint_matrix implementation as well

---------

Co-authored-by: pgorlani <[email protected]>
  • Loading branch information
muhammad-tanvir-1211 and pgorlani authored Feb 28, 2024
1 parent 2f149cb commit 861b310
Show file tree
Hide file tree
Showing 24 changed files with 2,051 additions and 306 deletions.
1 change: 1 addition & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,7 @@ export(EXPORT portblas

option(BLAS_ENABLE_TESTING "Whether to enable testing" ON)
option(ENABLE_EXPRESSION_TESTS "Whether to build expression tree fusion tests" OFF)
option(ENABLE_JOINTMATRIX_TESTS "Whether to build joint_matrix GEMM tests" OFF)
if (INSTALL_HEADER_ONLY AND BLAS_ENABLE_TESTING)
message(STATUS "Tests are disabled when installing portBLAS in header only mode")
set(BLAS_ENABLE_TESTING OFF)
Expand Down
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -458,6 +458,7 @@ Some of the supported options are:
| `CMAKE_INSTALL_PREFIX` | path | Specify the install location, used when invoking `ninja install` |
| `BUILD_SHARED_LIBS` | `ON`/`OFF` | Build as shared library (`ON` by default) |
| `ENABLE_EXPRESSION_TESTS` | `ON`/`OFF` | Build additional tests that use the header-only framework (e.g to test expression trees); `OFF` by default |
| `ENABLE_JOINTMATRIX_TESTS` | `ON`/`OFF` | Build additional tests that use joint_matrix extension; `OFF` by default |
| `BLAS_VERIFY_BENCHMARK` | `ON`/`OFF` | Verify the results of the benchmarks instead of only measuring the performance. See the documentation of the benchmarks for more details. `ON` by default |
| `BLAS_MEMPOOL_BENCHMARK` | `ON`/`OFF` | Determines whether to enable the scratchpad memory pool for benchmark execution. `OFF` by default |
| `BLAS_ENABLE_CONST_INPUT` | `ON`/`OFF` | Determines whether to enable kernel instantiation with const input buffer (`ON` by default) |
Expand Down
19 changes: 13 additions & 6 deletions benchmark/portblas/blas3/trsm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,13 @@ void run(benchmark::State& state, blas::SB_Handle* sb_handle_ptr, char side,
}

std::ostringstream err_stream;
if (!utils::compare_vectors(b_temp, x_ref, err_stream, "")) {
const char* en_joint_matrix = std::getenv("SB_ENABLE_JOINT_MATRIX");
if (!utils::compare_vectors(b_temp, x_ref, err_stream, "",
(en_joint_matrix != NULL) &&
(std::is_same<scalar_t, float>::value) &&
(*en_joint_matrix == '1')
? 2
: 1)) {
const std::string& err_str = err_stream.str();
state.SkipWithError(err_str.c_str());
*success = false;
Expand Down Expand Up @@ -181,8 +187,8 @@ void register_benchmark(blas::SB_Handle* sb_handle_ptr, bool* success,
};
benchmark::RegisterBenchmark(
blas_benchmark::utils::get_name<benchmark_op, scalar_t>(
side, uplo, trans, diag, m, n,
mem_type).c_str(),
side, uplo, trans, diag, m, n, mem_type)
.c_str(),
BM_lambda, sb_handle_ptr, side, uplo, trans, diag, m, n, alpha, success)
->UseRealTime();
}
Expand All @@ -193,16 +199,17 @@ void register_benchmark(blas_benchmark::Args& args,
blas::SB_Handle* sb_handle_ptr, bool* success) {
auto trsm_params = blas_benchmark::utils::get_trsm_params<scalar_t>(args);
register_benchmark<scalar_t, blas::helper::AllocType::buffer>(
sb_handle_ptr, success, blas_benchmark::utils::MEM_TYPE_BUFFER, trsm_params);
sb_handle_ptr, success, blas_benchmark::utils::MEM_TYPE_BUFFER,
trsm_params);
#ifdef SB_ENABLE_USM
register_benchmark<scalar_t, blas::helper::AllocType::usm>(
sb_handle_ptr, success, blas_benchmark::utils::MEM_TYPE_USM, trsm_params);
#endif
}

namespace blas_benchmark {
void create_benchmark(blas_benchmark::Args& args, blas::SB_Handle* sb_handle_ptr,
bool* success) {
void create_benchmark(blas_benchmark::Args& args,
blas::SB_Handle* sb_handle_ptr, bool* success) {
BLAS_REGISTER_BENCHMARK(args, sb_handle_ptr, success);
}
} // namespace blas_benchmark
38 changes: 24 additions & 14 deletions common/include/common/float_comparison.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -115,17 +115,20 @@ scalar_t clamp_to_limits(scalar_t v) {
* Indicates the tolerated margin for relative differences
*/
template <typename scalar_t>
inline scalar_t getRelativeErrorMargin() {
inline scalar_t getRelativeErrorMargin(const int32_t margin_multiplier = 1) {
/* Measured empirically with gemm. The dimensions of the matrices (even k)
* don't seem to have an impact on the observed relative differences
* In the cases where the relative error is relevant (non close to zero),
* relative differences of up to 0.002 were observed for float
*/
return static_cast<scalar_t>(0.005);
scalar_t margin = 0.005;
// increase error margin for mixed precision calculation
// for trsm operator.
return margin * margin_multiplier;
}

template <>
inline double getRelativeErrorMargin<double>() {
inline double getRelativeErrorMargin<double>(const int32_t) {
/* Measured empirically with gemm. The dimensions of the matrices (even k)
* don't seem to have an impact on the observed relative differences
* In the cases where the relative error is relevant (non close to zero),
Expand All @@ -135,7 +138,7 @@ inline double getRelativeErrorMargin<double>() {
}

template <>
inline cl::sycl::half getRelativeErrorMargin<cl::sycl::half>() {
inline cl::sycl::half getRelativeErrorMargin<cl::sycl::half>(const int32_t) {
// Measured empirically with gemm
return 0.05f;
}
Expand All @@ -145,16 +148,19 @@ inline cl::sycl::half getRelativeErrorMargin<cl::sycl::half>() {
* scalars are close to 0)
*/
template <typename scalar_t>
inline scalar_t getAbsoluteErrorMargin() {
inline scalar_t getAbsoluteErrorMargin(const int32_t margin_multiplier = 1) {
/* Measured empirically with gemm.
* In the cases where the relative error is irrelevant (close to zero),
* absolute differences of up to 0.0006 were observed for float
*/
return 0.001f;
scalar_t margin = 0.001f;
// increase error margin for mixed precision calculation
// for trsm operator.
return margin * margin_multiplier;
}

template <>
inline double getAbsoluteErrorMargin<double>() {
inline double getAbsoluteErrorMargin<double>(const int32_t) {
/* Measured empirically with gemm.
* In the cases where the relative error is irrelevant (close to zero),
* absolute differences of up to 10^-12 were observed for double
Expand All @@ -163,7 +169,7 @@ inline double getAbsoluteErrorMargin<double>() {
}

template <>
inline cl::sycl::half getAbsoluteErrorMargin<cl::sycl::half>() {
inline cl::sycl::half getAbsoluteErrorMargin<cl::sycl::half>(const int32_t) {
// Measured empirically with gemm.
return 1.0f;
}
Expand All @@ -172,7 +178,8 @@ inline cl::sycl::half getAbsoluteErrorMargin<cl::sycl::half>() {
* Compare two scalars and returns false if the difference is not acceptable.
*/
template <typename scalar_t, typename epsilon_t = scalar_t>
inline bool almost_equal(scalar_t const& scalar1, scalar_t const& scalar2) {
inline bool almost_equal(scalar_t const& scalar1, scalar_t const& scalar2,
const int32_t margin_multiplier = 1) {
// Shortcut, also handles case where both are zero
if (scalar1 == scalar2) {
return true;
Expand All @@ -187,12 +194,14 @@ inline bool almost_equal(scalar_t const& scalar1, scalar_t const& scalar2) {

// Close to zero, the relative error doesn't work, use absolute error
if (scalar1 == scalar_t{0} || scalar2 == scalar_t{0} ||
absolute_diff < getAbsoluteErrorMargin<epsilon_t>()) {
return (absolute_diff < getAbsoluteErrorMargin<epsilon_t>());
absolute_diff < getAbsoluteErrorMargin<epsilon_t>(margin_multiplier)) {
return (absolute_diff <
getAbsoluteErrorMargin<epsilon_t>(margin_multiplier));
}
// Use relative error
const auto absolute_sum = utils::abs(scalar1) + utils::abs(scalar2);
return (absolute_diff / absolute_sum) < getRelativeErrorMargin<epsilon_t>();
return (absolute_diff / absolute_sum) <
getRelativeErrorMargin<epsilon_t>(margin_multiplier);
}

/**
Expand All @@ -206,15 +215,16 @@ template <typename scalar_t, typename epsilon_t = scalar_t>
inline bool compare_vectors(std::vector<scalar_t> const& vec,
std::vector<scalar_t> const& ref,
std::ostream& err_stream = std::cerr,
std::string end_line = "\n") {
std::string end_line = "\n",
const int32_t margin_multiplier = 1) {
if (vec.size() != ref.size()) {
err_stream << "Error: tried to compare vectors of different sizes"
<< std::endl;
return false;
}

for (int i = 0; i < vec.size(); ++i) {
if (!almost_equal<scalar_t, epsilon_t>(vec[i], ref[i])) {
if (!almost_equal<scalar_t, epsilon_t>(vec[i], ref[i], margin_multiplier)) {
err_stream << "Value mismatch at index " << i << ": " << vec[i]
<< "; expected " << ref[i] << end_line;
return false;
Expand Down
113 changes: 57 additions & 56 deletions src/operations/blas3/gemm_load_store_joint_matrix.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,18 +57,16 @@ struct PacketizeJointMatrix {

/*! @brief Performs a coalesced non-vectorized load when the current block is
* not internal.
* @tparam trans Whether the source matrix is transposed or not.
* @tparam internal True if the current block is internal and no bounds
* checking is required.
* @tparam ld The leading dimension of the destination memory.
*/

template <bool trans, bool internal, int ld, typename SrcPointerType,
typename DestPointerType, typename EdgePredicate>
template <bool internal, typename SrcPointerType, typename DestPointerType,
typename EdgePredicate>
static PORTBLAS_INLINE typename std::enable_if<!internal>::type load(
const bool in_range, SrcPointerType src, DestPointerType dest,
EdgePredicate) {
value_t val = in_range ? *(src) : value_t{0};
value_t val = in_range ? *src : value_t{0};
using address_t = cl::sycl::access::address_space;
if constexpr (std::is_same<cl::sycl::multi_ptr<cl::sycl::half,
address_t::local_space>,
Expand All @@ -79,93 +77,96 @@ struct PacketizeJointMatrix {
cl::sycl::ext::oneapi::bfloat16,
address_t::local_space>,
DestPointerType>::value) {
using dtype = cl::sycl::ext::oneapi::bfloat16;
*dest = static_cast<dtype>(val);
using namespace cl::sycl::ext::oneapi;
*dest = bfloat16(val);
} else {
using namespace cl::sycl::ext::oneapi::experimental::matrix;
*dest = round_to_tf32(val);
}
}

/*! @brief Performs a vectorised load using sycl::vec::load when the current
* block is internal. In the case where k < the
* number of elements being loaded then edge loads will be element wise with
* additional bounds checking.
* @tparam trans Whether the source matrix is transposed or not.
* @tparam internal True if the current block is internal and no bounds
* checking is required.
* @tparam ld The leading dimension of the destination memory. */
template <bool trans, bool internal, index_t ld, typename SrcPointerType,
typename DestPointerType, typename EdgePredicate>
*/
template <bool internal, typename SrcPointerType, typename DestPointerType,
typename EdgePredicate>
static PORTBLAS_INLINE typename std::enable_if<internal>::type load(
const bool in_range, SrcPointerType src, DestPointerType dest,
EdgePredicate edge_in_range) {
PacketType packet{};

using address_t = cl::sycl::access::address_space;
if (in_range) {
using address_t = cl::sycl::access::address_space;
packet.template load<address_t::global_space>(
0, cl::sycl::multi_ptr<const value_t, address_t::global_space>(src));
store(packet, dest);
} else {
// avoid writing to variable, instead directly write to
// shared local memory to avoid race condition experienced
// with release compiler.
#pragma unroll
for (index_t i = 0; i < packet_size; i++) {
reinterpret_cast<value_t *>(&packet)[i] =
edge_in_range(i) ? *(src + i) : value_t{0};
}
}
store<trans, ld>(packet, dest);
}
/*! @brief Store a vector packet into local memory when the source is
* transposed. This will untranspose the elements individually when storing so
* the data in local memory is always consistent.
* @tparam trans Whether the source matrix is transposed or not.
* @tparam ld The leading dimension of the destination memory.*/
template <bool trans, index_t ld, typename DestPointerType>
static PORTBLAS_INLINE typename std::enable_if<trans>::type store(
PacketType &packet, DestPointerType dest) {
using address_t = cl::sycl::access::address_space;
#pragma unroll
for (index_t i = 0; i < packet_size; i++) {
value_t val = reinterpret_cast<value_t *>(&packet)[i];
if constexpr (std::is_same<cl::sycl::multi_ptr<cl::sycl::half,
address_t::local_space>,
DestPointerType>::value) {
using dtype = cl::sycl::half;
*(dest + ld * i) = static_cast<dtype>(val);
} else if constexpr (std::is_same<cl::sycl::multi_ptr<
cl::sycl::ext::oneapi::bfloat16,
address_t::local_space>,
DestPointerType>::value) {
using dtype = cl::sycl::ext::oneapi::bfloat16;
*(dest + ld * i) = static_cast<dtype>(val);
} else {
using namespace cl::sycl::ext::oneapi::experimental::matrix;
*(dest + ld * i) = round_to_tf32(val);
for (index_t i = 0; i < packet_size; i++, dest++, src++) {
if constexpr (std::is_same<cl::sycl::multi_ptr<cl::sycl::half,
address_t::local_space>,
DestPointerType>::value) {
using dtype = cl::sycl::half;
*dest = static_cast<dtype>(edge_in_range(i) ? *src : 0);
} else if constexpr (std::is_same<cl::sycl::multi_ptr<
cl::sycl::ext::oneapi::bfloat16,
address_t::local_space>,
DestPointerType>::value) {
using namespace cl::sycl::ext::oneapi;
*dest = bfloat16(edge_in_range(i) ? *src : 0.f);
} else {
using namespace cl::sycl::ext::oneapi::experimental::matrix;
*dest = edge_in_range(i) ? round_to_tf32(*src) : 0.f;
}
}
}
}

/*! @brief Store a vector packet into local memory when the source is not
* transposed. This will use sycl::vec::store function.
* @tparam trans Whether the source matrix is transposed or not.
* @tparam ld The leading dimension of the destination memory.*/
template <bool trans, int ld, typename DestPointerType>
static PORTBLAS_INLINE typename std::enable_if<!trans>::type store(
PacketType &packet, DestPointerType dest) {
/*! @brief Store a vector packet into local memory. This will use
* sycl::vec::store function.
*/
template <typename DestPointerType>
static PORTBLAS_INLINE void store(PacketType &packet, DestPointerType dest) {
using address_t = cl::sycl::access::address_space;
if constexpr (std::is_same<cl::sycl::multi_ptr<cl::sycl::half,
address_t::local_space>,
DestPointerType>::value) {
using dtype = cl::sycl::half;
*dest = static_cast<dtype>(packet[0]);
cl::sycl::vec<dtype, vector_size> new_vec{};
for (index_t i = 0; i < packet_size; i++) {
reinterpret_cast<dtype *>(&new_vec)[i] =
static_cast<dtype>(reinterpret_cast<value_t *>(&packet)[i]);
}
new_vec.template store<address_t::local_space>(
0, cl::sycl::multi_ptr<dtype, address_t::local_space>(dest));
} else if constexpr (std::is_same<cl::sycl::multi_ptr<
cl::sycl::ext::oneapi::bfloat16,
address_t::local_space>,
DestPointerType>::value) {
using dtype = cl::sycl::ext::oneapi::bfloat16;
*dest = static_cast<dtype>(packet[0]);
// sycl::vec doesn't accept bfloat16 as a valid input type
// so we need to write the packet elements individually to
// the shared memory.
using namespace cl::sycl::ext::oneapi;
for (index_t i = 0; i < packet_size; i++, dest++) {
*dest = bfloat16(reinterpret_cast<value_t *>(&packet)[i]);
}
} else {
using namespace cl::sycl::ext::oneapi::experimental::matrix;
*dest = round_to_tf32(packet[0]);
using dtype = float;
cl::sycl::vec<dtype, vector_size> new_vec;
for (index_t i = 0; i < packet_size; i++) {
reinterpret_cast<dtype *>(&new_vec)[i] =
round_to_tf32(reinterpret_cast<value_t *>(&packet)[i]);
}
new_vec.template store<address_t::local_space>(
0, cl::sycl::multi_ptr<dtype, address_t::local_space>(dest));
}
}
};
Expand Down
Loading

0 comments on commit 861b310

Please sign in to comment.