diff --git a/xla/service/gpu/buffer_comparator.cc b/xla/service/gpu/buffer_comparator.cc index 396826f2f9106..272cfa6831187 100644 --- a/xla/service/gpu/buffer_comparator.cc +++ b/xla/service/gpu/buffer_comparator.cc @@ -46,7 +46,7 @@ using ComparisonKernelT = float, uint64_t, se::DeviceMemory>; struct ComparisonParams { - float relative_tol = 0.1f; + double relative_tol = 0.1; bool verbose = true; const Shape *shape = nullptr; se::Stream* stream = nullptr; @@ -93,7 +93,8 @@ static absl::StatusOr DeviceCompare( se::DeviceMemory as_uint64(out.memory()); TF_RETURN_IF_ERROR(params.stream->ThenLaunch( dim.thread_counts_per_block(), dim.block_counts(), comparison_kernel, - current_typed, expected_typed, params.relative_tol, buffer_size, as_uint64)); + current_typed, expected_typed, static_cast<float>(params.relative_tol), + buffer_size, as_uint64)); uint64_t result = -1; CHECK_EQ(out.memory().size(), sizeof(result)); @@ -150,9 +151,8 @@ static absl::StatusOr HostCompare(const ComparisonParams& params) { !(std::abs(current_value_canonical - expected_value_canonical) / (std::max(std::abs(current_value_canonical), std::abs(expected_value_canonical)) + - 1) < - static_cast(params.relative_tol))) { - if(!params.verbose) return false; // Return immediately if not verbose. + 1) < params.relative_tol)) { + if (!params.verbose) return false; // Return immediately if not verbose. 
++differences_seen; LOG(ERROR) << "Difference at " << i << ": " << current_value << ", expected " << expected_value; diff --git a/xla/service/gpu/buffer_comparator_test.cc b/xla/service/gpu/buffer_comparator_test.cc index 20aee05740c05..9481473efbe26 100644 --- a/xla/service/gpu/buffer_comparator_test.cc +++ b/xla/service/gpu/buffer_comparator_test.cc @@ -39,6 +39,9 @@ namespace xla { namespace gpu { namespace { +constexpr double kDefaultTolerance = 0.1; + + class BufferComparatorTest : public testing::Test { protected: BufferComparatorTest() @@ -53,7 +56,8 @@ class BufferComparatorTest : public testing::Test { // Take floats only for convenience. Still uses ElementType internally. template bool CompareEqualBuffers(const std::vector& current, - const std::vector& expected) { + const std::vector& expected, + double tolerance) { auto stream = stream_exec_->CreateStream().value(); se::DeviceMemoryHandle current_buffer( @@ -82,16 +86,18 @@ class BufferComparatorTest : public testing::Test { // Take floats only for convenience. Still uses ElementType internally. template bool CompareEqualFloatBuffers(const std::vector& lhs_float, - const std::vector& rhs_float) { + const std::vector& rhs_float, + double tolerance = kDefaultTolerance) { std::vector lhs(lhs_float.begin(), lhs_float.end()); std::vector rhs(rhs_float.begin(), rhs_float.end()); - return CompareEqualBuffers(lhs, rhs); + return CompareEqualBuffers(lhs, rhs, tolerance); } template bool CompareEqualComplex(const std::vector>& lhs, const std::vector>& rhs) { - return CompareEqualBuffers>(lhs, rhs); + return CompareEqualBuffers>(lhs, rhs, + kDefaultTolerance); } se::Platform* platform_;