Skip to content

Commit

Permalink
Increase error margins for trsm tests when joint_matrix is used
Browse files Browse the repository at this point in the history
  • Loading branch information
muhammad-tanvir-1211 committed Feb 1, 2024
1 parent e7f1a43 commit ba3ca46
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 14 deletions.
51 changes: 38 additions & 13 deletions common/include/common/float_comparison.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -120,17 +120,27 @@ scalar_t clamp_to_limits(scalar_t v) {
* Indicates the tolerated margin for relative differences
*/
template <typename scalar_t>
inline scalar_t getRelativeErrorMargin() {
inline scalar_t getRelativeErrorMargin(const bool is_trsm) {
/* Measured empirically with gemm. The dimensions of the matrices (even k)
* don't seem to have an impact on the observed relative differences
* In the cases where the relative error is relevant (non close to zero),
* relative differences of up to 0.002 were observed for float
*/
return static_cast<scalar_t>(0.005);
scalar_t margin = 0.005;
if (is_trsm) {
const char* en_joint_matrix = std::getenv("SB_ENABLE_JOINT_MATRIX");
if (en_joint_matrix != NULL && std::is_same<scalar_t, float>::value &&
*en_joint_matrix == '1') {
// increase error margin for mixed precision calculation
// for trsm operator.
margin = 0.009f;
}
}
return margin;
}

template <>
inline double getRelativeErrorMargin<double>() {
inline double getRelativeErrorMargin<double>(const bool) {
/* Measured empirically with gemm. The dimensions of the matrices (even k)
* don't seem to have an impact on the observed relative differences
* In the cases where the relative error is relevant (non close to zero),
Expand All @@ -142,7 +152,7 @@ inline double getRelativeErrorMargin<double>() {
#ifdef BLAS_DATA_TYPE_HALF

template <>
inline cl::sycl::half getRelativeErrorMargin<cl::sycl::half>() {
inline cl::sycl::half getRelativeErrorMargin<cl::sycl::half>(const bool) {
// Measured empirically with gemm
return 0.05f;
}
Expand All @@ -152,16 +162,27 @@ inline cl::sycl::half getRelativeErrorMargin<cl::sycl::half>() {
* scalars are close to 0)
*/
template <typename scalar_t>
inline scalar_t getAbsoluteErrorMargin() {
inline scalar_t getAbsoluteErrorMargin(const bool is_trsm) {
/* Measured empirically with gemm.
* In the cases where the relative error is irrelevant (close to zero),
* absolute differences of up to 0.0006 were observed for float
*/
return 0.001f;
scalar_t margin = 0.001f;
if (is_trsm) {
const char* en_joint_matrix = std::getenv("SB_ENABLE_JOINT_MATRIX");
if (en_joint_matrix != NULL && std::is_same<scalar_t, float>::value &&
*en_joint_matrix == '1') {
// increase error margin for mixed precision calculation
// for trsm operator.
margin = 0.009f;
}
}

return margin;
}

template <>
inline double getAbsoluteErrorMargin<double>() {
inline double getAbsoluteErrorMargin<double>(const bool) {
/* Measured empirically with gemm.
* In the cases where the relative error is irrelevant (close to zero),
* absolute differences of up to 10^-12 were observed for double
Expand All @@ -171,7 +192,7 @@ inline double getAbsoluteErrorMargin<double>() {
#ifdef BLAS_DATA_TYPE_HALF

template <>
inline cl::sycl::half getAbsoluteErrorMargin<cl::sycl::half>() {
inline cl::sycl::half getAbsoluteErrorMargin<cl::sycl::half>(const bool) {
// Measured empirically with gemm.
return 1.0f;
}
Expand All @@ -181,7 +202,8 @@ inline cl::sycl::half getAbsoluteErrorMargin<cl::sycl::half>() {
* Compare two scalars and returns false if the difference is not acceptable.
*/
template <typename scalar_t, typename epsilon_t = scalar_t>
inline bool almost_equal(scalar_t const& scalar1, scalar_t const& scalar2) {
inline bool almost_equal(scalar_t const& scalar1, scalar_t const& scalar2,
const bool is_trsm = false) {
// Shortcut, also handles case where both are zero
if (scalar1 == scalar2) {
return true;
Expand All @@ -196,12 +218,13 @@ inline bool almost_equal(scalar_t const& scalar1, scalar_t const& scalar2) {

// Close to zero, the relative error doesn't work, use absolute error
if (scalar1 == scalar_t{0} || scalar2 == scalar_t{0} ||
absolute_diff < getAbsoluteErrorMargin<epsilon_t>()) {
return (absolute_diff < getAbsoluteErrorMargin<epsilon_t>());
absolute_diff < getAbsoluteErrorMargin<epsilon_t>(is_trsm)) {
return (absolute_diff < getAbsoluteErrorMargin<epsilon_t>(is_trsm));
}
// Use relative error
const auto absolute_sum = utils::abs(scalar1) + utils::abs(scalar2);
return (absolute_diff / absolute_sum) < getRelativeErrorMargin<epsilon_t>();
return (absolute_diff / absolute_sum) <
getRelativeErrorMargin<epsilon_t>(is_trsm);
}

/**
Expand All @@ -214,6 +237,7 @@ inline bool almost_equal(scalar_t const& scalar1, scalar_t const& scalar2) {
template <typename scalar_t, typename epsilon_t = scalar_t>
inline bool compare_vectors(std::vector<scalar_t> const& vec,
std::vector<scalar_t> const& ref,
const bool is_trsm = false,
std::ostream& err_stream = std::cerr,
std::string end_line = "\n") {
if (vec.size() != ref.size()) {
Expand All @@ -223,7 +247,7 @@ inline bool compare_vectors(std::vector<scalar_t> const& vec,
}

for (int i = 0; i < vec.size(); ++i) {
if (!almost_equal<scalar_t, epsilon_t>(vec[i], ref[i])) {
if (!almost_equal<scalar_t, epsilon_t>(vec[i], ref[i], is_trsm)) {
err_stream << "Value mismatch at index " << i << ": " << vec[i]
<< "; expected " << ref[i] << end_line;
return false;
Expand All @@ -243,6 +267,7 @@ inline bool compare_vectors(std::vector<scalar_t> const& vec,
template <typename scalar_t, typename epsilon_t = scalar_t>
inline bool compare_vectors(std::vector<std::complex<scalar_t>> const& vec,
std::vector<std::complex<scalar_t>> const& ref,
bool is_trsm = false,
std::ostream& err_stream = std::cerr,
std::string end_line = "\n") {
if (vec.size() != ref.size()) {
Expand Down
2 changes: 1 addition & 1 deletion test/unittest/blas3/blas3_trsm_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ void run_test(const combination_t<scalar_t> combi) {
blas::helper::copy_to_host<scalar_t>(q, b_gpu, B.data(), B.size());
sb_handle.wait(event);

bool isAlmostEqual = utils::compare_vectors(cpu_B, B);
bool isAlmostEqual = utils::compare_vectors(cpu_B, B, true);

ASSERT_TRUE(isAlmostEqual);

Expand Down

0 comments on commit ba3ca46

Please sign in to comment.