Skip to content

Commit

Permalink
Increase error margins for trsm tests when joint_matrix is used
Browse files Browse the repository at this point in the history
  • Loading branch information
muhammad-tanvir-1211 committed Feb 1, 2024
1 parent e7f1a43 commit 7985cbd
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 22 deletions.
13 changes: 7 additions & 6 deletions benchmark/portblas/blas3/trsm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ void run(benchmark::State& state, blas::SB_Handle* sb_handle_ptr, char side,
}

std::ostringstream err_stream;
if (!utils::compare_vectors(b_temp, x_ref, err_stream, "")) {
if (!utils::compare_vectors(b_temp, x_ref, err_stream, "", true)) {
const std::string& err_str = err_stream.str();
state.SkipWithError(err_str.c_str());
*success = false;
Expand Down Expand Up @@ -181,8 +181,8 @@ void register_benchmark(blas::SB_Handle* sb_handle_ptr, bool* success,
};
benchmark::RegisterBenchmark(
blas_benchmark::utils::get_name<benchmark_op, scalar_t>(
side, uplo, trans, diag, m, n,
mem_type).c_str(),
side, uplo, trans, diag, m, n, mem_type)
.c_str(),
BM_lambda, sb_handle_ptr, side, uplo, trans, diag, m, n, alpha, success)
->UseRealTime();
}
Expand All @@ -193,16 +193,17 @@ void register_benchmark(blas_benchmark::Args& args,
blas::SB_Handle* sb_handle_ptr, bool* success) {
auto trsm_params = blas_benchmark::utils::get_trsm_params<scalar_t>(args);
register_benchmark<scalar_t, blas::helper::AllocType::buffer>(
sb_handle_ptr, success, blas_benchmark::utils::MEM_TYPE_BUFFER, trsm_params);
sb_handle_ptr, success, blas_benchmark::utils::MEM_TYPE_BUFFER,
trsm_params);
#ifdef SB_ENABLE_USM
register_benchmark<scalar_t, blas::helper::AllocType::usm>(
sb_handle_ptr, success, blas_benchmark::utils::MEM_TYPE_USM, trsm_params);
#endif
}

namespace blas_benchmark {
/**
 * Entry point used to create the benchmarks of this file.
 *
 * All three arguments are forwarded unchanged to the BLAS_REGISTER_BENCHMARK
 * macro (defined outside this file).
 *
 * @param args Benchmark arguments
 * @param sb_handle_ptr Pointer to the SB_Handle handed to the benchmarks
 * @param success Pointer to the success flag handed to the benchmarks
 */
void create_benchmark(blas_benchmark::Args& args,
                      blas::SB_Handle* sb_handle_ptr, bool* success) {
  BLAS_REGISTER_BENCHMARK(args, sb_handle_ptr, success);
}
}  // namespace blas_benchmark
54 changes: 39 additions & 15 deletions common/include/common/float_comparison.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -120,17 +120,27 @@ scalar_t clamp_to_limits(scalar_t v) {
* Indicates the tolerated margin for relative differences
*/
template <typename scalar_t>
inline scalar_t getRelativeErrorMargin(const bool is_trsm) {
  /* Measured empirically with gemm. The dimensions of the matrices (even k)
   * don't seem to have an impact on the observed relative differences.
   * In the cases where the relative error is relevant (values not close to
   * zero), relative differences of up to 0.002 were observed for float.
   *
   * is_trsm: when true, scalar_t is float and the environment variable
   * SB_ENABLE_JOINT_MATRIX starts with '1', a wider margin is returned.
   */
  scalar_t margin = static_cast<scalar_t>(0.005);
  // The type check is a compile-time constant, so test it first and only
  // query the environment when the wider margin can actually apply.
  if (is_trsm && std::is_same<scalar_t, float>::value) {
    const char* en_joint_matrix = std::getenv("SB_ENABLE_JOINT_MATRIX");
    if (en_joint_matrix != nullptr && *en_joint_matrix == '1') {
      // increase error margin for mixed precision calculation
      // for trsm operator.
      margin = static_cast<scalar_t>(0.009f);
    }
  }
  return margin;
}

template <>
inline double getRelativeErrorMargin<double>() {
inline double getRelativeErrorMargin<double>(const bool) {
/* Measured empirically with gemm. The dimensions of the matrices (even k)
* don't seem to have an impact on the observed relative differences
* In the cases where the relative error is relevant (non close to zero),
Expand All @@ -142,7 +152,7 @@ inline double getRelativeErrorMargin<double>() {
#ifdef BLAS_DATA_TYPE_HALF

template <>
inline cl::sycl::half getRelativeErrorMargin<cl::sycl::half>(const bool) {
  // Measured empirically with gemm.
  // The bool argument is unnamed and ignored: the same margin is returned
  // regardless of its value.
  return 0.05f;
}
Expand All @@ -152,16 +162,27 @@ inline cl::sycl::half getRelativeErrorMargin<cl::sycl::half>() {
* scalars are close to 0)
*/
template <typename scalar_t>
inline scalar_t getAbsoluteErrorMargin(const bool is_trsm) {
  /* Measured empirically with gemm.
   * In the cases where the relative error is irrelevant (close to zero),
   * absolute differences of up to 0.0006 were observed for float.
   *
   * is_trsm: when true, scalar_t is float and the environment variable
   * SB_ENABLE_JOINT_MATRIX starts with '1', a wider margin is returned.
   */
  scalar_t margin = static_cast<scalar_t>(0.001f);
  // The type check is a compile-time constant, so test it first and only
  // query the environment when the wider margin can actually apply.
  if (is_trsm && std::is_same<scalar_t, float>::value) {
    const char* en_joint_matrix = std::getenv("SB_ENABLE_JOINT_MATRIX");
    if (en_joint_matrix != nullptr && *en_joint_matrix == '1') {
      // increase error margin for mixed precision calculation
      // for trsm operator.
      margin = static_cast<scalar_t>(0.009f);
    }
  }
  return margin;
}

template <>
inline double getAbsoluteErrorMargin<double>() {
inline double getAbsoluteErrorMargin<double>(const bool) {
/* Measured empirically with gemm.
* In the cases where the relative error is irrelevant (close to zero),
* absolute differences of up to 10^-12 were observed for double
Expand All @@ -171,7 +192,7 @@ inline double getAbsoluteErrorMargin<double>() {
#ifdef BLAS_DATA_TYPE_HALF

template <>
inline cl::sycl::half getAbsoluteErrorMargin<cl::sycl::half>(const bool) {
  // Measured empirically with gemm.
  // The bool argument is unnamed and ignored: the same margin is returned
  // regardless of its value.
  return 1.0f;
}
Expand All @@ -181,7 +202,8 @@ inline cl::sycl::half getAbsoluteErrorMargin<cl::sycl::half>() {
* Compare two scalars and returns false if the difference is not acceptable.
*/
template <typename scalar_t, typename epsilon_t = scalar_t>
inline bool almost_equal(scalar_t const& scalar1, scalar_t const& scalar2) {
inline bool almost_equal(scalar_t const& scalar1, scalar_t const& scalar2,
const bool is_trsm = false) {
// Shortcut, also handles case where both are zero
if (scalar1 == scalar2) {
return true;
Expand All @@ -196,12 +218,13 @@ inline bool almost_equal(scalar_t const& scalar1, scalar_t const& scalar2) {

// Close to zero, the relative error doesn't work, use absolute error
if (scalar1 == scalar_t{0} || scalar2 == scalar_t{0} ||
absolute_diff < getAbsoluteErrorMargin<epsilon_t>()) {
return (absolute_diff < getAbsoluteErrorMargin<epsilon_t>());
absolute_diff < getAbsoluteErrorMargin<epsilon_t>(is_trsm)) {
return (absolute_diff < getAbsoluteErrorMargin<epsilon_t>(is_trsm));
}
// Use relative error
const auto absolute_sum = utils::abs(scalar1) + utils::abs(scalar2);
return (absolute_diff / absolute_sum) < getRelativeErrorMargin<epsilon_t>();
return (absolute_diff / absolute_sum) <
getRelativeErrorMargin<epsilon_t>(is_trsm);
}

/**
Expand All @@ -215,15 +238,16 @@ template <typename scalar_t, typename epsilon_t = scalar_t>
inline bool compare_vectors(std::vector<scalar_t> const& vec,
std::vector<scalar_t> const& ref,
std::ostream& err_stream = std::cerr,
std::string end_line = "\n") {
std::string end_line = "\n",
const bool is_trsm = false) {
if (vec.size() != ref.size()) {
err_stream << "Error: tried to compare vectors of different sizes"
<< std::endl;
return false;
}

for (int i = 0; i < vec.size(); ++i) {
if (!almost_equal<scalar_t, epsilon_t>(vec[i], ref[i])) {
if (!almost_equal<scalar_t, epsilon_t>(vec[i], ref[i], is_trsm)) {
err_stream << "Value mismatch at index " << i << ": " << vec[i]
<< "; expected " << ref[i] << end_line;
return false;
Expand All @@ -244,7 +268,7 @@ template <typename scalar_t, typename epsilon_t = scalar_t>
inline bool compare_vectors(std::vector<std::complex<scalar_t>> const& vec,
std::vector<std::complex<scalar_t>> const& ref,
std::ostream& err_stream = std::cerr,
std::string end_line = "\n") {
std::string end_line = "\n", bool is_trsm = false) {
if (vec.size() != ref.size()) {
err_stream << "Error: tried to compare vectors of different sizes"
<< std::endl;
Expand Down
2 changes: 1 addition & 1 deletion test/unittest/blas3/blas3_trsm_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ void run_test(const combination_t<scalar_t> combi) {
blas::helper::copy_to_host<scalar_t>(q, b_gpu, B.data(), B.size());
sb_handle.wait(event);

bool isAlmostEqual = utils::compare_vectors(cpu_B, B);
bool isAlmostEqual = utils::compare_vectors(cpu_B, B, std::cerr, "", true);

ASSERT_TRUE(isAlmostEqual);

Expand Down

0 comments on commit 7985cbd

Please sign in to comment.