Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rocm jaxlib v0.4.28 qa gemm rtol #26

Merged
merged 3 commits into from
Jul 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion xla/debug_options_flags.cc
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() {
opts.set_xla_llvm_enable_invariant_load_metadata(true);
opts.set_xla_llvm_disable_expensive_passes(false);
opts.set_xla_backend_optimization_level(3);
opts.set_xla_gpu_autotune_level(4);
opts.set_xla_gpu_autotune_level(5);
opts.set_xla_gpu_autotune_max_solutions(0);
opts.set_xla_cpu_multi_thread_eigen(true);
opts.set_xla_gpu_cuda_data_dir("./cuda_sdk_lib");
Expand Down
1 change: 1 addition & 0 deletions xla/service/gpu/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -1694,6 +1694,7 @@ xla_test(
":backend_configs_cc",
":gemm_algorithm_picker",
":gemm_rewriter",
":variant_visitor",
"//xla/hlo/ir:hlo",
"//xla/service:pattern_matcher",
"//xla/service:pattern_matcher_gmock",
Expand Down
114 changes: 64 additions & 50 deletions xla/service/gpu/buffer_comparator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -45,26 +45,35 @@ using ComparisonKernelT =
se::TypedKernel<se::DeviceMemory<ElementT>, se::DeviceMemory<ElementT>,
float, uint64_t, se::DeviceMemory<uint64_t>>;

struct ComparisonParams {
double relative_tol = 0.1f;
bool verbose = true;
const Shape *shape = nullptr;
se::Stream* stream = nullptr;
se::DeviceMemoryBase current{};
se::DeviceMemoryBase expected{};
};

// Compares two buffers on the GPU.
//
// Returns `true` if two buffers are equal, `false` otherwise.
template <typename ElementT>
absl::StatusOr<bool> BufferComparator::DeviceCompare(
std::string_view kernel_name, void* kernel_symbol) {
se::StreamExecutor* executor = stream_->parent();
static absl::StatusOr<bool> DeviceCompare(
std::string_view kernel_name, void* kernel_symbol,
const ComparisonParams& params) {
se::StreamExecutor* executor = params.stream->parent();

se::DeviceMemoryHandle out_param(executor,
se::DeviceMemoryHandle out(executor,
executor->AllocateScalar<uint64_t>());

TF_RETURN_IF_ERROR(
stream_->MemZero(out_param.memory_ptr(), sizeof(uint64_t)));
if (current_.size() != expected_.size()) {
TF_RETURN_IF_ERROR(params.stream->MemZero(out.memory_ptr(), sizeof(uint64_t)));
if (params.current.size() != params.expected.size()) {
return Internal("Mismatched buffer size: %d bytes vs. %d bytes",
current_.size(), expected_.size());
params.current.size(), params.expected.size());
}

se::DeviceMemory<ElementT> current_typed(current_);
se::DeviceMemory<ElementT> expected_typed(expected_);
se::DeviceMemory<ElementT> current_typed(params.current);
se::DeviceMemory<ElementT> expected_typed(params.expected);
uint64_t buffer_size = current_typed.ElementCount();

TF_ASSIGN_OR_RETURN(
Expand All @@ -78,18 +87,20 @@ absl::StatusOr<bool> BufferComparator::DeviceCompare(
const se::DeviceDescription& gpu_device_info =
executor->GetDeviceDescription();

LaunchDimensions dim = CalculateLaunchDimensions(shape_, gpu_device_info);
LaunchDimensions dim =
CalculateLaunchDimensions(*params.shape, gpu_device_info);

se::DeviceMemory<uint64_t> as_uint64(out_param.memory());
TF_RETURN_IF_ERROR(stream_->ThenLaunch(
se::DeviceMemory<uint64_t> as_uint64(out.memory());
TF_RETURN_IF_ERROR(params.stream->ThenLaunch(
dim.thread_counts_per_block(), dim.block_counts(), comparison_kernel,
current_typed, expected_typed, rtol_tolerance_, buffer_size, as_uint64));
current_typed, expected_typed, static_cast<float>(params.relative_tol),
buffer_size, as_uint64));

uint64_t result = -1;
CHECK_EQ(out_param.memory().size(), sizeof(result));
CHECK_EQ(out.memory().size(), sizeof(result));
TF_RETURN_IF_ERROR(
stream_->Memcpy(&result, out_param.memory(), sizeof(result)));
TF_RETURN_IF_ERROR(stream_->BlockHostUntilDone());
params.stream->Memcpy(&result, out.memory(), sizeof(result)));
TF_RETURN_IF_ERROR(params.stream->BlockHostUntilDone());
return result == 0;
}

Expand All @@ -98,14 +109,16 @@ absl::StatusOr<bool> BufferComparator::DeviceCompare(
//
// Returns true if no differences were seen, false otherwise.
template <typename ElementType, typename ComparisonType>
absl::StatusOr<bool> BufferComparator::HostCompare() {
int64_t n = current_.size() / sizeof(ElementType);
static absl::StatusOr<bool> HostCompare(const ComparisonParams& params) {
int64_t n = params.current.size() / sizeof(ElementType);
std::vector<ElementType> host_current(n), host_expected(n);
TF_RETURN_IF_ERROR(
stream_->Memcpy(host_current.data(), current_, current_.size()));
params.stream->Memcpy(host_current.data(), params.current,
params.current.size()));
TF_RETURN_IF_ERROR(
stream_->Memcpy(host_expected.data(), expected_, expected_.size()));
TF_RETURN_IF_ERROR(stream_->BlockHostUntilDone());
params.stream->Memcpy(host_expected.data(), params.expected,
params.expected.size()));
TF_RETURN_IF_ERROR(params.stream->BlockHostUntilDone());

const auto canonicalize = [](ComparisonType a) -> ComparisonType {
if (std::is_same<ElementType, Eigen::half>::value && a) {
Expand Down Expand Up @@ -138,88 +151,89 @@ absl::StatusOr<bool> BufferComparator::HostCompare() {
!(std::abs(current_value_canonical - expected_value_canonical) /
(std::max(std::abs(current_value_canonical),
std::abs(expected_value_canonical)) +
1) <
(double)rtol_tolerance_)) {
if (!verbose_) return false; // Return immediately if not verbose.
1) < params.relative_tol)) {
if (!params.verbose) return false; // Return immediately if not verbose.
++differences_seen;
LOG(ERROR) << "Difference at " << i << ": " << current_value
<< ", expected " << expected_value;
<< ", expected " << expected_value;
}
}
return differences_seen == 0;
}

template <typename ElementT, typename ComparisonT>
absl::StatusOr<bool> BufferComparator::CompareEqualParameterized(
std::string_view kernel_name, void* kernel_symbol) {
static absl::StatusOr<bool> CompareEqualParameterized(
std::string_view kernel_name, void* kernel_symbol,
const ComparisonParams& params) {
XLA_SCOPED_LOGGING_TIMER("BufferComparator::CompareEqual");
TF_ASSIGN_OR_RETURN(bool result,
DeviceCompare<ElementT>(kernel_name, kernel_symbol));
TF_ASSIGN_OR_RETURN(
bool result, DeviceCompare<ElementT>(kernel_name, kernel_symbol, params));

if (result) {
return true;
}

TF_ASSIGN_OR_RETURN(bool host_return, (HostCompare<ElementT, ComparisonT>()));
TF_ASSIGN_OR_RETURN(bool host_return,
(HostCompare<ElementT, ComparisonT>(params)));
CHECK_EQ(host_return, result)
<< "Host comparison succeeded even though GPU comparison failed.";
return false;
}

absl::StatusOr<bool> BufferComparator::CompareEqual(
se::Stream* stream, se::DeviceMemoryBase current,
se::DeviceMemoryBase expected) {
stream_ = stream;
current_ = current;
expected_ = expected;
se::DeviceMemoryBase expected) const {

ComparisonParams params{
relative_tol_, verbose_, &shape_, stream, current, expected};

switch (shape_.element_type()) {
#if GOOGLE_CUDA
case xla::F8E4M3FN:
return CompareEqualParameterized<tsl::float8_e4m3fn, float>(
"fp8_e4m3fn_comparison", buffer_comparator::fp8_e4m3fn_comparison());
"fp8_e4m3fn_comparison", buffer_comparator::fp8_e4m3fn_comparison(),
params);
case xla::F8E5M2:
return CompareEqualParameterized<tsl::float8_e5m2, float>(
"fp8_e5m2_comparison", buffer_comparator::fp8_e5m2_comparison());
"fp8_e5m2_comparison", buffer_comparator::fp8_e5m2_comparison(),
params);
#endif // GOOGLE_CUDA
#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION >= 60200
case xla::F8E4M3FNUZ:
return CompareEqualParameterized<tsl::float8_e4m3fnuz, float>(
"fp8_e4m3fnuz_comparison",
buffer_comparator::fp8_e4m3fnuz_comparison());
buffer_comparator::fp8_e4m3fnuz_comparison(), params);
case xla::F8E5M2FNUZ:
return CompareEqualParameterized<tsl::float8_e5m2fnuz, float>(
"fp8_e5m2fnuz_comparison",
buffer_comparator::fp8_e5m2fnuz_comparison());
buffer_comparator::fp8_e5m2fnuz_comparison(), params);
#endif // TENSORFLOW_USE_ROCM && TF_ROCM_VERSION >= 60200
case xla::F16:
return CompareEqualParameterized<Eigen::half, float>(
"fp16_comparison", buffer_comparator::fp16_comparison());
"fp16_comparison", buffer_comparator::fp16_comparison(), params);
case xla::BF16:
return CompareEqualParameterized<Eigen::bfloat16, float>(
"bf16_comparison", buffer_comparator::bf16_comparison());
"bf16_comparison", buffer_comparator::bf16_comparison(), params);
case xla::F32:
return CompareEqualParameterized<float, float>(
"fp32_comparison", buffer_comparator::fp32_comparison());
"fp32_comparison", buffer_comparator::fp32_comparison(), params);
case xla::F64:
return CompareEqualParameterized<double, double>(
"fp64_comparison", buffer_comparator::fp64_comparison());
"fp64_comparison", buffer_comparator::fp64_comparison(), params);
case xla::S8:
return CompareEqualParameterized<int8_t, float>(
"int8_comparison", buffer_comparator::int8_comparison());
"int8_comparison", buffer_comparator::int8_comparison(), params);
case xla::S32:
return CompareEqualParameterized<int32_t, float>(
"int32_comparison", buffer_comparator::int32_comparison());
"int32_comparison", buffer_comparator::int32_comparison(), params);
default:
return Unimplemented("Unimplemented element type");
}
}

BufferComparator::BufferComparator(const Shape& shape,
const HloModuleConfig& config, bool verbose)
: shape_(shape),
rtol_tolerance_(config.debug_options().xla_gpu_autotune_gemm_rtol()),
verbose_(verbose) {
BufferComparator::BufferComparator(const Shape& shape, double tolerance,
bool verbose) :
shape_(shape), relative_tol_(tolerance), verbose_(verbose) {
// Normalize complex shapes: since we treat the passed array as a contiguous
// storage it does not matter which dimension are we doubling.
auto double_dim_size = [&]() {
Expand Down
27 changes: 3 additions & 24 deletions xla/service/gpu/buffer_comparator.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ class BufferComparator {
BufferComparator(const BufferComparator&) = delete;
BufferComparator(BufferComparator&&) = default;

BufferComparator(const Shape& shape, const HloModuleConfig& config,
explicit BufferComparator(const Shape& shape, double tolerance = 0.1,
bool verbose = true);

// Returns true if the two buffers compare equal. The definition of "equal"
Expand All @@ -48,32 +48,11 @@ class BufferComparator {
// See the implementation for the tolerance value.
absl::StatusOr<bool> CompareEqual(se::Stream* stream,
se::DeviceMemoryBase current,
se::DeviceMemoryBase expected);

private:
// Returns `true` if two buffers are equal, `false` otherwise.
template <typename ElementT>
absl::StatusOr<bool> DeviceCompare(std::string_view kernel_name,
void* kernel_symbol);

// Host side comparison code that does the same thing, but reports some of the
// differences as well. It only print logs for debugging.
//
// Returns true if no differences were seen, false otherwise.
template <typename ElementType, typename ComparisonType>
absl::StatusOr<bool> HostCompare();

template <typename ElementT, typename ComparisonT>
absl::StatusOr<bool> CompareEqualParameterized(std::string_view kernel_name,
void* kernel_symbol);

se::DeviceMemoryBase expected) const;
private:
Shape shape_;
float rtol_tolerance_; // relative tolerance for comparison
double relative_tol_; // relative tolerance for comparison
bool verbose_; // whether to print out error message on mismatch
se::Stream* stream_ = nullptr;
se::DeviceMemoryBase current_;
se::DeviceMemoryBase expected_;
};

namespace buffer_comparator {
Expand Down
29 changes: 22 additions & 7 deletions xla/service/gpu/buffer_comparator_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,9 @@ namespace xla {
namespace gpu {
namespace {

constexpr double kDefaultTolerance = 0.1;


class BufferComparatorTest : public testing::Test {
protected:
BufferComparatorTest()
Expand All @@ -53,7 +56,8 @@ class BufferComparatorTest : public testing::Test {
// Take floats only for convenience. Still uses ElementType internally.
template <typename ElementType>
bool CompareEqualBuffers(const std::vector<ElementType>& current,
const std::vector<ElementType>& expected) {
const std::vector<ElementType>& expected,
double tolerance) {
auto stream = stream_exec_->CreateStream().value();

se::DeviceMemoryHandle current_buffer(
Expand All @@ -72,7 +76,7 @@ class BufferComparatorTest : public testing::Test {
ShapeUtil::MakeShape(
primitive_util::NativeToPrimitiveType<ElementType>(),
{static_cast<int64_t>(current.size())}),
HloModuleConfig());
tolerance);
return comparator
.CompareEqual(stream.get(), current_buffer.memory(),
expected_buffer.memory())
Expand All @@ -82,16 +86,18 @@ class BufferComparatorTest : public testing::Test {
// Take floats only for convenience. Still uses ElementType internally.
template <typename ElementType>
bool CompareEqualFloatBuffers(const std::vector<float>& lhs_float,
const std::vector<float>& rhs_float) {
const std::vector<float>& rhs_float,
double tolerance = kDefaultTolerance) {
std::vector<ElementType> lhs(lhs_float.begin(), lhs_float.end());
std::vector<ElementType> rhs(rhs_float.begin(), rhs_float.end());
return CompareEqualBuffers(lhs, rhs);
return CompareEqualBuffers(lhs, rhs, tolerance);
}

template <typename ElementType>
bool CompareEqualComplex(const std::vector<std::complex<ElementType>>& lhs,
const std::vector<std::complex<ElementType>>& rhs) {
return CompareEqualBuffers<std::complex<ElementType>>(lhs, rhs);
return CompareEqualBuffers<std::complex<ElementType>>(lhs, rhs,
kDefaultTolerance);
}

se::Platform* platform_;
Expand Down Expand Up @@ -238,6 +244,16 @@ TEST_F(BufferComparatorTest, TestNumbers) {
EXPECT_TRUE(CompareEqualFloatBuffers<tsl::float8_e5m2>({11}, {12}));
EXPECT_TRUE(CompareEqualFloatBuffers<tsl::float8_e5m2>({12}, {11}));
#endif // GOOGLE_CUDA

// Rerunning tests with increased relative tolerance
const double tol = 0.001;
EXPECT_FALSE(CompareEqualFloatBuffers<Eigen::half>({0.9}, {1}, tol));
EXPECT_TRUE(CompareEqualFloatBuffers<Eigen::half>({0.9}, {0.901}, tol));
EXPECT_FALSE(CompareEqualFloatBuffers<float>({10}, {10.1}, tol));
EXPECT_TRUE(CompareEqualFloatBuffers<float>({10}, {10.01}, tol));
EXPECT_FALSE(CompareEqualFloatBuffers<int8_t>({100}, {101}, tol));
EXPECT_FALSE(CompareEqualFloatBuffers<double>({20}, {20.1}, tol));
EXPECT_TRUE(CompareEqualFloatBuffers<double>({20}, {20.01}, tol));
}

TEST_F(BufferComparatorTest, TestMultiple) {
Expand Down Expand Up @@ -361,8 +377,7 @@ TEST_F(BufferComparatorTest, BF16) {
stream_exec_->AllocateArray<Eigen::bfloat16>(element_count));
InitializeBuffer(stream.get(), BF16, &rng_state, rhs.memory());

BufferComparator comparator(ShapeUtil::MakeShape(BF16, {element_count}),
HloModuleConfig());
BufferComparator comparator(ShapeUtil::MakeShape(BF16, {element_count}));
EXPECT_FALSE(comparator.CompareEqual(stream.get(), lhs.memory(), rhs.memory())
.value());
}
Expand Down
7 changes: 4 additions & 3 deletions xla/service/gpu/conv_algorithm_picker.cc
Original file line number Diff line number Diff line change
Expand Up @@ -681,8 +681,11 @@ absl::StatusOr<AutotuneResult> GpuConvAlgorithmPicker::AutotuneOneConvRunner(

if (reference_result->has_value()) {
XLA_SCOPED_LOGGING_TIMER_LEVEL("BufferComparator::CompareEqual", 2);

const DebugOptions& debug_options =
runtime_arguments.hlo_module_config.debug_options();
BufferComparator comparator(runtime_arguments.rz_buffers.output_shape(),
runtime_arguments.hlo_module_config);
debug_options.xla_gpu_autotune_gemm_rtol());
for (int i = 0; i < result_buffers.size(); ++i) {
absl::StatusOr<bool> compare_result = comparator.CompareEqual(
stream, (*reference_result)->buffers[i], result_buffers[i]);
Expand All @@ -696,8 +699,6 @@ absl::StatusOr<AutotuneResult> GpuConvAlgorithmPicker::AutotuneOneConvRunner(
// Possibly OOM. Propagate the error.
return compare_result.status();
}
const DebugOptions& debug_options =
runtime_arguments.hlo_module_config.debug_options();
CHECK(!debug_options.xla_gpu_crash_on_verification_failures());
} else if (!compare_result.value()) {
LOG(ERROR)
Expand Down
Loading
Loading