Skip to content

Commit

Permalink
adopted changes
Browse files Browse the repository at this point in the history
  • Loading branch information
i-chaochen committed Jul 15, 2024
1 parent f753b6f commit 51dc9a5
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 9 deletions.
10 changes: 5 additions & 5 deletions xla/service/gpu/buffer_comparator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ using ComparisonKernelT =
float, uint64_t, se::DeviceMemory<uint64_t>>;

struct ComparisonParams {
float relative_tol = 0.1f;
double relative_tol = 0.1f;
bool verbose = true;
const Shape *shape = nullptr;
se::Stream* stream = nullptr;
Expand Down Expand Up @@ -93,7 +93,8 @@ static absl::StatusOr<bool> DeviceCompare(
se::DeviceMemory<uint64_t> as_uint64(out.memory());
TF_RETURN_IF_ERROR(params.stream->ThenLaunch(
dim.thread_counts_per_block(), dim.block_counts(), comparison_kernel,
current_typed, expected_typed, params.relative_tol, buffer_size, as_uint64));
current_typed, expected_typed, static_cast<float>(params.relative_tol),
buffer_size, as_uint64));

uint64_t result = -1;
CHECK_EQ(out.memory().size(), sizeof(result));
Expand Down Expand Up @@ -150,9 +151,8 @@ static absl::StatusOr<bool> HostCompare(const ComparisonParams& params) {
!(std::abs(current_value_canonical - expected_value_canonical) /
(std::max(std::abs(current_value_canonical),
std::abs(expected_value_canonical)) +
1) <
static_cast<double>(params.relative_tol))) {
if(!params.verbose) return false; // Return immediately if not verbose.
1) < params.relative_tol)) {
if (!params.verbose) return false; // Return immediately if not verbose.
++differences_seen;
LOG(ERROR) << "Difference at " << i << ": " << current_value
<< ", expected " << expected_value;
Expand Down
14 changes: 10 additions & 4 deletions xla/service/gpu/buffer_comparator_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,9 @@ namespace xla {
namespace gpu {
namespace {

constexpr double kDefaultTolerance = 0.1;


class BufferComparatorTest : public testing::Test {
protected:
BufferComparatorTest()
Expand All @@ -53,7 +56,8 @@ class BufferComparatorTest : public testing::Test {
// Take floats only for convenience. Still uses ElementType internally.
template <typename ElementType>
bool CompareEqualBuffers(const std::vector<ElementType>& current,
const std::vector<ElementType>& expected) {
const std::vector<ElementType>& expected,
double tolerance) {
auto stream = stream_exec_->CreateStream().value();

se::DeviceMemoryHandle current_buffer(
Expand Down Expand Up @@ -82,16 +86,18 @@ class BufferComparatorTest : public testing::Test {
// Take floats only for convenience. Still uses ElementType internally.
template <typename ElementType>
bool CompareEqualFloatBuffers(const std::vector<float>& lhs_float,
const std::vector<float>& rhs_float) {
const std::vector<float>& rhs_float,
double tolerance = kDefaultTolerance) {
std::vector<ElementType> lhs(lhs_float.begin(), lhs_float.end());
std::vector<ElementType> rhs(rhs_float.begin(), rhs_float.end());
return CompareEqualBuffers(lhs, rhs);
return CompareEqualBuffers(lhs, rhs, tolerance);
}

template <typename ElementType>
bool CompareEqualComplex(const std::vector<std::complex<ElementType>>& lhs,
const std::vector<std::complex<ElementType>>& rhs) {
return CompareEqualBuffers<std::complex<ElementType>>(lhs, rhs);
return CompareEqualBuffers<std::complex<ElementType>>(lhs, rhs,
kDefaultTolerance);
}

se::Platform* platform_;
Expand Down

0 comments on commit 51dc9a5

Please sign in to comment.