From 7ba611d343c5806c21c22cb93efc9977b5dfa786 Mon Sep 17 00:00:00 2001 From: Pavel Emeliyanenko Date: Wed, 5 Jun 2024 12:55:06 +0000 Subject: [PATCH 1/3] added precision settings for autotuner and buffer_comparator small refactoring, added verbose flag --- xla/service/gpu/BUILD | 1 + xla/service/gpu/buffer_comparator.cc | 117 ++++++------ xla/service/gpu/buffer_comparator.h | 27 +-- xla/service/gpu/buffer_comparator_test.cc | 15 +- xla/service/gpu/conv_algorithm_picker.cc | 7 +- xla/service/gpu/gemm_algorithm_picker.cc | 88 ++++++---- xla/service/gpu/gemm_algorithm_picker.h | 7 + xla/service/gpu/gemm_algorithm_picker_test.cc | 166 ++++++++++++------ xla/service/gpu/gemm_fusion_autotuner.cc | 2 +- .../gpu/triton_fusion_numerics_verifier.cc | 3 +- 10 files changed, 254 insertions(+), 179 deletions(-) diff --git a/xla/service/gpu/BUILD b/xla/service/gpu/BUILD index 49969912ed7c0..b347ba3557124 100644 --- a/xla/service/gpu/BUILD +++ b/xla/service/gpu/BUILD @@ -1694,6 +1694,7 @@ xla_test( ":backend_configs_cc", ":gemm_algorithm_picker", ":gemm_rewriter", + ":variant_visitor", "//xla/hlo/ir:hlo", "//xla/service:pattern_matcher", "//xla/service:pattern_matcher_gmock", diff --git a/xla/service/gpu/buffer_comparator.cc b/xla/service/gpu/buffer_comparator.cc index 488c2abd242b3..290e7a7df5ab7 100644 --- a/xla/service/gpu/buffer_comparator.cc +++ b/xla/service/gpu/buffer_comparator.cc @@ -45,26 +45,35 @@ using ComparisonKernelT = se::TypedKernel, se::DeviceMemory, float, uint64_t, se::DeviceMemory>; +struct ComparisonParams { + float relative_tol = 0.1f; + bool verbose = true; + const Shape *shape = nullptr; + se::Stream* stream = nullptr; + se::DeviceMemoryBase current{}; + se::DeviceMemoryBase expected{}; +}; + // Compares two buffers on the GPU. // // Returns `true` if two buffers are equal, `false` otherwise. template -absl::StatusOr BufferComparator::DeviceCompare( - std::string_view kernel_name, void* kernel_symbol) { - se::StreamExecutor* executor = stream_->parent(); +static absl::StatusOr DeviceCompare( + std::string_view kernel_name, void* kernel_symbol, + const ComparisonParams& params) { + se::StreamExecutor* executor = params.stream->parent(); - se::DeviceMemoryHandle out_param(executor, + se::DeviceMemoryHandle out(executor, executor->AllocateScalar()); - TF_RETURN_IF_ERROR( - stream_->MemZero(out_param.memory_ptr(), sizeof(uint64_t))); - if (current_.size() != expected_.size()) { + TF_RETURN_IF_ERROR(params.stream->MemZero(out.memory_ptr(), sizeof(uint64_t))); + if (params.current.size() != params.expected.size()) { return Internal("Mismatched buffer size: %d bytes vs. 
%d bytes", - current_.size(), expected_.size()); + params.current.size(), params.expected.size()); } - se::DeviceMemory current_typed(current_); - se::DeviceMemory expected_typed(expected_); + se::DeviceMemory current_typed(params.current); + se::DeviceMemory expected_typed(params.expected); uint64_t buffer_size = current_typed.ElementCount(); TF_ASSIGN_OR_RETURN( @@ -78,34 +87,32 @@ absl::StatusOr BufferComparator::DeviceCompare( const se::DeviceDescription& gpu_device_info = executor->GetDeviceDescription(); - LaunchDimensions dim = CalculateLaunchDimensions(shape_, gpu_device_info); + LaunchDimensions dim = + CalculateLaunchDimensions(*params.shape, gpu_device_info); - se::DeviceMemory as_uint64(out_param.memory()); - TF_RETURN_IF_ERROR(stream_->ThenLaunch( + se::DeviceMemory as_uint64(out.memory()); + TF_RETURN_IF_ERROR(params.stream->ThenLaunch( dim.thread_counts_per_block(), dim.block_counts(), comparison_kernel, - current_typed, expected_typed, rtol_tolerance_, buffer_size, as_uint64)); + current_typed, expected_typed, params.relative_tol, buffer_size, as_uint64)); uint64_t result = -1; - CHECK_EQ(out_param.memory().size(), sizeof(result)); + CHECK_EQ(out.memory().size(), sizeof(result)); TF_RETURN_IF_ERROR( - stream_->Memcpy(&result, out_param.memory(), sizeof(result))); - TF_RETURN_IF_ERROR(stream_->BlockHostUntilDone()); - return result == 0; -} - -// Host side comparison code that does the same thing, but reports some of the + params.stream->Memcpy(&result, out.memory(), sizeof(result))); + TF_RETURN_IF_ERROR(params.stream->BlockHostUntilDone()); // differences as well. It only print logs for debugging. // -// Returns true if no differences were seen, false otherwise. template -absl::StatusOr BufferComparator::HostCompare() { - int64_t n = current_.size() / sizeof(ElementType); +static absl::StatusOr HostCompare(const ComparisonParams& params) { + int64_t n = params.current.size() / sizeof(ElementType); std::vector host_current(n), host_expected(n); TF_RETURN_IF_ERROR( - stream_->Memcpy(host_current.data(), current_, current_.size())); + params.stream->Memcpy(host_current.data(), params.current, + params.current.size())); TF_RETURN_IF_ERROR( - stream_->Memcpy(host_expected.data(), expected_, expected_.size())); - TF_RETURN_IF_ERROR(stream_->BlockHostUntilDone()); + params.stream->Memcpy(host_expected.data(), params.expected, + params.expected.size())); + TF_RETURN_IF_ERROR(params.stream->BlockHostUntilDone()); const auto canonicalize = [](ComparisonType a) -> ComparisonType { if (std::is_same::value && a) { @@ -139,28 +146,30 @@ absl::StatusOr BufferComparator::HostCompare() { (std::max(std::abs(current_value_canonical), std::abs(expected_value_canonical)) + 1) < - (double)rtol_tolerance_)) { - if (!verbose_) return false; // Return immediately if not verbose. + static_cast(params.relative_tol))) { + if(!params.verbose) return false; // Return immediately if not verbose. 
++differences_seen; LOG(ERROR) << "Difference at " << i << ": " << current_value - << ", expected " << expected_value; + << ", expected " << expected_value; } } return differences_seen == 0; } template -absl::StatusOr BufferComparator::CompareEqualParameterized( - std::string_view kernel_name, void* kernel_symbol) { +static absl::StatusOr CompareEqualParameterized( + std::string_view kernel_name, void* kernel_symbol, + const ComparisonParams& params) { XLA_SCOPED_LOGGING_TIMER("BufferComparator::CompareEqual"); - TF_ASSIGN_OR_RETURN(bool result, - DeviceCompare(kernel_name, kernel_symbol)); + TF_ASSIGN_OR_RETURN( + bool result, DeviceCompare(kernel_name, kernel_symbol, params)); if (result) { return true; } - TF_ASSIGN_OR_RETURN(bool host_return, (HostCompare())); + TF_ASSIGN_OR_RETURN(bool host_return, + (HostCompare(params))); CHECK_EQ(host_return, result) << "Host comparison succeeded even though GPU comparison failed."; return false; @@ -168,58 +177,58 @@ absl::StatusOr BufferComparator::CompareEqualParameterized( absl::StatusOr BufferComparator::CompareEqual( se::Stream* stream, se::DeviceMemoryBase current, - se::DeviceMemoryBase expected) { - stream_ = stream; - current_ = current; - expected_ = expected; + se::DeviceMemoryBase expected) const { + + ComparisonParams params{ + relative_tol_, verbose_, &shape_, stream, current, expected}; switch (shape_.element_type()) { #if GOOGLE_CUDA case xla::F8E4M3FN: return CompareEqualParameterized( - "fp8_e4m3fn_comparison", buffer_comparator::fp8_e4m3fn_comparison()); + "fp8_e4m3fn_comparison", buffer_comparator::fp8_e4m3fn_comparison(), + params); case xla::F8E5M2: return CompareEqualParameterized( - "fp8_e5m2_comparison", buffer_comparator::fp8_e5m2_comparison()); + "fp8_e5m2_comparison", buffer_comparator::fp8_e5m2_comparison(), + params); #endif // GOOGLE_CUDA #if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION >= 60200 case xla::F8E4M3FNUZ: return CompareEqualParameterized( "fp8_e4m3fnuz_comparison", - buffer_comparator::fp8_e4m3fnuz_comparison()); + buffer_comparator::fp8_e4m3fnuz_comparison(), params); case xla::F8E5M2FNUZ: return CompareEqualParameterized( "fp8_e5m2fnuz_comparison", - buffer_comparator::fp8_e5m2fnuz_comparison()); + buffer_comparator::fp8_e5m2fnuz_comparison(), params); #endif // TENSORFLOW_USE_ROCM && TF_ROCM_VERSION >= 60200 case xla::F16: return CompareEqualParameterized( - "fp16_comparison", buffer_comparator::fp16_comparison()); + "fp16_comparison", buffer_comparator::fp16_comparison(), params); case xla::BF16: return CompareEqualParameterized( - "bf16_comparison", buffer_comparator::bf16_comparison()); + "bf16_comparison", buffer_comparator::bf16_comparison(), params); case xla::F32: return CompareEqualParameterized( - "fp32_comparison", buffer_comparator::fp32_comparison()); + "fp32_comparison", buffer_comparator::fp32_comparison(), params); case xla::F64: return CompareEqualParameterized( - "fp64_comparison", buffer_comparator::fp64_comparison()); + "fp64_comparison", buffer_comparator::fp64_comparison(), params); case xla::S8: return CompareEqualParameterized( - "int8_comparison", buffer_comparator::int8_comparison()); + "int8_comparison", buffer_comparator::int8_comparison(), params); case xla::S32: return CompareEqualParameterized( - "int32_comparison", buffer_comparator::int32_comparison()); + "int32_comparison", buffer_comparator::int32_comparison(), params); default: return Unimplemented("Unimplemented element type"); } } -BufferComparator::BufferComparator(const Shape& shape, - const HloModuleConfig& config, 
bool verbose) - : shape_(shape), - rtol_tolerance_(config.debug_options().xla_gpu_autotune_gemm_rtol()), - verbose_(verbose) { +BufferComparator::BufferComparator(const Shape& shape, double tolerance, + bool verbose) : + shape_(shape), relative_tol_(tolerance), verbose_(verbose) { // Normalize complex shapes: since we treat the passed array as a contiguous // storage it does not matter which dimension are we doubling. auto double_dim_size = [&]() { diff --git a/xla/service/gpu/buffer_comparator.h b/xla/service/gpu/buffer_comparator.h index e702467a28e30..b56fc21073e56 100644 --- a/xla/service/gpu/buffer_comparator.h +++ b/xla/service/gpu/buffer_comparator.h @@ -34,7 +34,7 @@ class BufferComparator { BufferComparator(const BufferComparator&) = delete; BufferComparator(BufferComparator&&) = default; - BufferComparator(const Shape& shape, const HloModuleConfig& config, + explicit BufferComparator(const Shape& shape, double tolerance = 0.1, bool verbose = true); // Returns true if the two buffers compare equal. The definition of "equal" @@ -48,32 +48,11 @@ class BufferComparator { // See the implementation for the tolerance value. absl::StatusOr CompareEqual(se::Stream* stream, se::DeviceMemoryBase current, - se::DeviceMemoryBase expected); - - private: - // Returns `true` if two buffers are equal, `false` otherwise. - template - absl::StatusOr DeviceCompare(std::string_view kernel_name, - void* kernel_symbol); - - // Host side comparison code that does the same thing, but reports some of the - // differences as well. It only print logs for debugging. - // - // Returns true if no differences were seen, false otherwise. - template - absl::StatusOr HostCompare(); - - template - absl::StatusOr CompareEqualParameterized(std::string_view kernel_name, - void* kernel_symbol); - + se::DeviceMemoryBase expected) const; private: Shape shape_; - float rtol_tolerance_; // relative tolerance for comparison + float relative_tol_; // relative tolerance for comparison bool verbose_; // whether to print out error message on mismatch - se::Stream* stream_ = nullptr; - se::DeviceMemoryBase current_; - se::DeviceMemoryBase expected_; }; namespace buffer_comparator { diff --git a/xla/service/gpu/buffer_comparator_test.cc b/xla/service/gpu/buffer_comparator_test.cc index 9faa3156a8496..20aee05740c05 100644 --- a/xla/service/gpu/buffer_comparator_test.cc +++ b/xla/service/gpu/buffer_comparator_test.cc @@ -72,7 +72,7 @@ class BufferComparatorTest : public testing::Test { ShapeUtil::MakeShape( primitive_util::NativeToPrimitiveType(), {static_cast(current.size())}), - HloModuleConfig()); + tolerance); return comparator .CompareEqual(stream.get(), current_buffer.memory(), expected_buffer.memory()) @@ -238,6 +238,16 @@ TEST_F(BufferComparatorTest, TestNumbers) { EXPECT_TRUE(CompareEqualFloatBuffers({11}, {12})); EXPECT_TRUE(CompareEqualFloatBuffers({12}, {11})); #endif // GOOGLE_CUDA + + // Rerunning tests with increased relative tolerance + const double tol = 0.001; + EXPECT_FALSE(CompareEqualFloatBuffers({0.9}, {1}, tol)); + EXPECT_TRUE(CompareEqualFloatBuffers({0.9}, {0.901}, tol)); + EXPECT_FALSE(CompareEqualFloatBuffers({10}, {10.1}, tol)); + EXPECT_TRUE(CompareEqualFloatBuffers({10}, {10.01}, tol)); + EXPECT_FALSE(CompareEqualFloatBuffers({100}, {101}, tol)); + EXPECT_FALSE(CompareEqualFloatBuffers({20}, {20.1}, tol)); + EXPECT_TRUE(CompareEqualFloatBuffers({20}, {20.01}, tol)); } TEST_F(BufferComparatorTest, TestMultiple) { @@ -361,8 +371,7 @@ TEST_F(BufferComparatorTest, BF16) { 
stream_exec_->AllocateArray(element_count)); InitializeBuffer(stream.get(), BF16, &rng_state, rhs.memory()); - BufferComparator comparator(ShapeUtil::MakeShape(BF16, {element_count}), - HloModuleConfig()); + BufferComparator comparator(ShapeUtil::MakeShape(BF16, {element_count})); EXPECT_FALSE(comparator.CompareEqual(stream.get(), lhs.memory(), rhs.memory()) .value()); } diff --git a/xla/service/gpu/conv_algorithm_picker.cc b/xla/service/gpu/conv_algorithm_picker.cc index ea4a961e753b0..ea0e830e0a374 100644 --- a/xla/service/gpu/conv_algorithm_picker.cc +++ b/xla/service/gpu/conv_algorithm_picker.cc @@ -681,8 +681,11 @@ absl::StatusOr GpuConvAlgorithmPicker::AutotuneOneConvRunner( if (reference_result->has_value()) { XLA_SCOPED_LOGGING_TIMER_LEVEL("BufferComparator::CompareEqual", 2); + + const DebugOptions& debug_options = + runtime_arguments.hlo_module_config.debug_options(); BufferComparator comparator(runtime_arguments.rz_buffers.output_shape(), - runtime_arguments.hlo_module_config); + debug_options.xla_gpu_autotune_gemm_rtol()); for (int i = 0; i < result_buffers.size(); ++i) { absl::StatusOr compare_result = comparator.CompareEqual( stream, (*reference_result)->buffers[i], result_buffers[i]); @@ -696,8 +699,6 @@ absl::StatusOr GpuConvAlgorithmPicker::AutotuneOneConvRunner( // Possibly OOM. Propagate the error. return compare_result.status(); } - const DebugOptions& debug_options = - runtime_arguments.hlo_module_config.debug_options(); CHECK(!debug_options.xla_gpu_crash_on_verification_failures()); } else if (!compare_result.value()) { LOG(ERROR) diff --git a/xla/service/gpu/gemm_algorithm_picker.cc b/xla/service/gpu/gemm_algorithm_picker.cc index 710acf4240533..4911b91af5569 100644 --- a/xla/service/gpu/gemm_algorithm_picker.cc +++ b/xla/service/gpu/gemm_algorithm_picker.cc @@ -97,13 +97,18 @@ class GemmAutotuner { se::Stream* stream_ = nullptr; bool deterministic_ops_ = false; size_t solutions_limit_ = 0; + size_t num_algorithms_left_ = 0; public: explicit GemmAutotuner(const AutotuneConfig& autotune_config) : autotune_config_(autotune_config) {} + size_t num_algorithms_left() const { return num_algorithms_left_; } + absl::StatusOr operator()(const HloInstruction* gemm, const AutotuneCacheKey& key) { + + num_algorithms_left_ = 0; if (autotune_config_.IsDeviceless()) { // Return empty result, will tune at runtime. return AutotuneResult{}; @@ -201,7 +206,7 @@ class GemmAutotuner { }; return GetBestAlgorithm( - gemm, algorithms, gemm_config.beta, tuned_func); + gemm, algorithms, gemm_config.beta, true, tuned_func); } absl::StatusOr TuneGpuBlas(const HloInstruction* gemm, @@ -221,7 +226,6 @@ class GemmAutotuner { &gemm_config.alpha, &gemm_config.beta, &algorithms); - AutotuneResult best_algorithm; auto tuned_func = [&](const se::blas::AlgorithmType& algorithm) -> absl::StatusOr { // Do a warm-up run first, without a profile result. RunGemm swallows @@ -246,21 +250,15 @@ class GemmAutotuner { return std::move(profile_result); }; - TF_ASSIGN_OR_RETURN(best_algorithm, - GetBestAlgorithm( - gemm, algorithms, gemm_config.beta, tuned_func)); - if (best_algorithm.has_gemm()) { - int alg_idx = best_algorithm.gemm().algorithm(); - best_algorithm.mutable_gemm()->set_algorithm(algorithms[alg_idx]); - } - return best_algorithm; + return GetBestAlgorithm( + gemm, algorithms, gemm_config.beta, false, tuned_func); } // Returns the index (into `algorithms`) of the fastest algorithm. 
template absl::StatusOr GetBestAlgorithm( const HloInstruction* gemm, absl::Span algorithms, - double beta, TunedFunc&& run_benchmark) { + double beta, bool return_algo_index, TunedFunc&& run_benchmark) { static_assert(std::is_invocable_r_v, TunedFunc, const AlgoT&>, "Tuned function has incorrect prototype!"); @@ -284,7 +282,8 @@ class GemmAutotuner { } // Do not print error messages if should_skip_wrong_results() is ON. - BufferComparator comparator(output_shape, hlo_module_config, + BufferComparator comparator(output_shape, + hlo_module_config.debug_options().xla_gpu_autotune_gemm_rtol(), /* verbose */!autotune_config_.should_skip_wrong_results() ); std::vector results; @@ -319,6 +318,7 @@ class GemmAutotuner { absl::Milliseconds(profile_result.elapsed_time_in_ms())); if (!autotune_config_.should_check_correctness()) { + num_algorithms_left_++; continue; } TF_ASSIGN_OR_RETURN( @@ -334,38 +334,45 @@ class GemmAutotuner { continue; } + num_algorithms_left_++; if (!reference_algorithm) { TF_RETURN_IF_ERROR(stream_->Memcpy(&reference_buffer, OutputBuffer(), OutputBuffer().size())); reference_algorithm = profile_result.algorithm(); - } else { - // Perform the comparison. - TF_ASSIGN_OR_RETURN( - bool outputs_match, - comparator.CompareEqual(stream_, /*current=*/OutputBuffer(), + continue; + } + // Perform the comparison versus the reference algorithm. + TF_ASSIGN_OR_RETURN( + bool outputs_match, + comparator.CompareEqual(stream_, /*current=*/OutputBuffer(), /*expected=*/reference_buffer)); - if (!outputs_match) { - LOG(ERROR) << "Results mismatch between different GEMM algorithms. " - << "This is likely a bug/unexpected loss of precision."; - CHECK(!autotune_config_.should_crash_on_check_failure()); - - // By default, autotuner does NOT really skip wrong results, but - // merely prints out the above error message: this may lead to a - // great confusion. When should_skip_wrong_results() is set to true, - // solutions with accuracy problems will be disqualified. - result.mutable_failure()->set_kind( - autotune_config_.should_skip_wrong_results() ? - AutotuneResult::DISQUALIFIED : - AutotuneResult::WRONG_RESULT); - result.mutable_failure()->mutable_reference_gemm()->set_algorithm( - *reference_algorithm); + if (!outputs_match) { + LOG(ERROR) << "Results mismatch between different GEMM algorithms. " + << "This is likely a bug/unexpected loss of precision."; + CHECK(!autotune_config_.should_crash_on_check_failure()); + + // By default, autotuner does NOT really skip wrong results, but + // merely prints out the above error message: this may lead to a + // great confusion. When should_skip_wrong_results() is set to true, + // solutions with accuracy problems will be disqualified. + auto kind = AutotuneResult::WRONG_RESULT; + if (autotune_config_.should_skip_wrong_results()) { + kind = AutotuneResult::DISQUALIFIED; + num_algorithms_left_--; // Decrement again since we disqualified it. } + result.mutable_failure()->set_kind(kind); + result.mutable_failure()->mutable_reference_gemm()->set_algorithm( + *reference_algorithm); } } // for algorithms absl::StatusOr best = PickBestResult(results, gemm->ToString(), hlo_module_config); if (best.ok()) { + // Return a real algorithm ID if return_algo_index is false: + // e.g., in case of legacy cublas tuning. + if (!return_algo_index) return best; + // Otherwise, map a real algorithm ID to its index among the results. 
for (size_t i = 0; i < results.size(); ++i) { if (best->gemm().algorithm() == results[i].gemm().algorithm()) { best->mutable_gemm()->set_algorithm(i); @@ -386,13 +393,15 @@ class GemmAutotuner { // Do Gemm Autotune without stream executor. Use results from autotune cache // only. absl::StatusOr RunOnInstruction(HloInstruction* gemm, - const AutotuneConfig& config) { + const AutotuneConfig& config, + size_t *num_algorithms_left) { VLOG(3) << "Loading the autotune result of GemmThunk " << gemm->ToString(); GpuBackendConfig gpu_config = gemm->backend_config().value(); GemmBackendConfig& backend_config = *gpu_config.mutable_gemm_backend_config(); + *num_algorithms_left = 0; // Degenerate gemms replaced with memzero operation, no need to auto tune it. if (backend_config.alpha_real() == 0.0 && backend_config.alpha_imag() == 0.0 && backend_config.beta() == 0.0) { @@ -406,6 +415,7 @@ absl::StatusOr RunOnInstruction(HloInstruction* gemm, AutotunerUtil::Autotune( gemm, config, [&] { return autotuner(gemm, key); })); + *num_algorithms_left = autotuner.num_algorithms_left(); auto old_algorithm = backend_config.selected_algorithm(); bool update_algorithm = IsCublasLtMatmulF8(*gemm) || @@ -447,11 +457,15 @@ absl::StatusOr RunOnInstruction(HloInstruction* gemm, } absl::StatusOr RunOnComputation(HloComputation* computation, - AutotuneConfig config) { + AutotuneConfig config, size_t *num_algorithms_left) { bool changed = false; + for (HloInstruction* instr : computation->instructions()) { if (IsCublasGemm(*instr)) { - TF_ASSIGN_OR_RETURN(bool result, RunOnInstruction(instr, config)); + size_t num_left; + TF_ASSIGN_OR_RETURN(bool result, RunOnInstruction(instr, config, &num_left)); + // Gathering statistics on the algorithms left after tuning (for testing) + *num_algorithms_left = std::max(*num_algorithms_left, num_left); changed |= result; } } @@ -466,6 +480,7 @@ absl::StatusOr GemmAlgorithmPicker::Run( XLA_SCOPED_LOGGING_TIMER( absl::StrCat("GemmAlgorithmPicker for ", module->name())); + num_algorithms_left_ = 0; if (module->config().debug_options().xla_gpu_autotune_level() == 0) { VLOG(2) << "GEMM auto-tuning disabled, GemmAlgorithmPicker returning early"; return false; @@ -474,7 +489,8 @@ absl::StatusOr GemmAlgorithmPicker::Run( bool changed = false; for (HloComputation* computation : module->MakeNonfusionComputations(execution_threads)) { - TF_ASSIGN_OR_RETURN(bool result, RunOnComputation(computation, config_)); + TF_ASSIGN_OR_RETURN(bool result, + RunOnComputation(computation, config_, &num_algorithms_left_)); changed |= result; } return changed; diff --git a/xla/service/gpu/gemm_algorithm_picker.h b/xla/service/gpu/gemm_algorithm_picker.h index 6610a95279b6b..2b5366fa70151 100644 --- a/xla/service/gpu/gemm_algorithm_picker.h +++ b/xla/service/gpu/gemm_algorithm_picker.h @@ -53,6 +53,10 @@ class GemmAlgorithmPicker : public HloModulePass { absl::string_view name() const override { return "gemm-algorithm-picker"; } + size_t num_algorithms_left() const { + return num_algorithms_left_; + } + using HloPassInterface::Run; absl::StatusOr Run( HloModule* module, @@ -60,6 +64,9 @@ class GemmAlgorithmPicker : public HloModulePass { private: AutotuneConfig config_; + // The number of valid algorithms used for autotuning (from the last call), + // to be used for testing purposes. 
+ size_t num_algorithms_left_ = 0; }; } // namespace gpu diff --git a/xla/service/gpu/gemm_algorithm_picker_test.cc b/xla/service/gpu/gemm_algorithm_picker_test.cc index 4da7b2d179244..da47d87ea8fbe 100644 --- a/xla/service/gpu/gemm_algorithm_picker_test.cc +++ b/xla/service/gpu/gemm_algorithm_picker_test.cc @@ -24,6 +24,7 @@ limitations under the License. #include "xla/service/gpu/autotuner_util.h" #include "xla/service/gpu/backend_configs.pb.h" #include "xla/service/gpu/gemm_rewriter.h" +#include "xla/service/gpu/variant_visitor.h" #include "xla/service/pattern_matcher.h" #include "xla/service/pattern_matcher_gmock.h" #include "xla/service/platform_util.h" @@ -53,32 +54,100 @@ class GemmAlgorithmPickerTest : public HloTestBase, return debug_options; } - void SetUp() override { - const auto& gpu_cc = backend() - .default_stream_executor() - ->GetDeviceDescription() - .gpu_compute_capability(); - - if (auto* procm = std::get_if(&gpu_cc)) { - if (GetDebugOptionsForTest().xla_gpu_enable_cublaslt() && - !procm->has_hipblaslt()) { - GTEST_SKIP() << "No gpublas-lt support on this architecture!"; - } - } + se::StreamExecutor *stream_exec() { + return backend().default_stream_executor(); + } + const se::DeviceDescription& gpu_device_desc() { + return stream_exec()->GetDeviceDescription(); + } + const se::GpuComputeCapability& gpu_comp() { + return gpu_device_desc().gpu_compute_capability(); } -}; -TEST_P(GemmAlgorithmPickerTest, SetAlgorithm) { - auto comp = backend() - .default_stream_executor() - ->GetDeviceDescription() - .cuda_compute_capability(); - if (comp.IsAtLeast(se::CudaComputeCapability::AMPERE)) { - GTEST_SKIP() << "Skipping this test for Ampere+ as it is supported and " + void SetUp() override { + std::visit(VariantVisitor{ + [](const se::CudaComputeCapability& cc) { + if(cc.IsAtLeastAmpere()) { + GTEST_SKIP() << "Skipping this test for Ampere+ as it is supported and " "recommended with " "the Nvidia Volta+ GPUs."; + } + }, + [this](const se::RocmComputeCapability& cc) { + if(GetDebugOptionsForTest().xla_gpu_enable_cublaslt() && + !cc.has_hipblaslt()) { + GTEST_SKIP() << "No gpublas-lt support on this architecture!"; + } + }}, + gpu_comp()); + } +}; + +TEST_P(GemmAlgorithmPickerTest, SkipAlgorithmsWithAccuracyCheck) { + constexpr absl::string_view kHlo = R"( +HloModule module + +ENTRY main { + %arg0 = f32[100,100]{1,0} parameter(0) + %arg1 = f32[100,100]{1,0} parameter(1) + ROOT %dot = f32[100,100]{1,0} dot(arg0, arg1), lhs_contracting_dims={1}, rhs_contracting_dims={0} +})"; + + auto module_cfg = GetModuleConfigForTest(); + auto debug_opts = module_cfg.debug_options(); + size_t num_left1 = 0, num_left2 = 0; + +TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(kHlo, module_cfg)); + + { + // Run first with default settings (autotune level = 4), keep the number of + // algorithms left after autotuning + TF_ASSERT_OK_AND_ASSIGN( + bool changed, + RunHloPass( + GemmRewriter(gpu_comp(), /*toolkit_version=*/12040), + module.get())); + + AutotuneConfig cfg{DeviceConfig{stream_exec(), nullptr}, debug_opts}; + GemmAlgorithmPicker gpicker(cfg); + // Note that, we do not care if the algorithm index has been changed: + // the thing matters is the # of algorithms left after sorting out. + TF_ASSERT_OK_AND_ASSIGN(changed, RunHloPass(gpicker, module.get())); + num_left1 = gpicker.num_algorithms_left(); + if(num_left1 < 2) { + GTEST_SKIP() << "Too few algorithms left after the first step"; + } } + // Clear cache before the second run! 
+ AutotunerUtil::ClearAutotuneResults(); + { + // Run once again but now with autotune level 5 and embarassingly tight + // rtol which shall disqualify most of the algorithms. + + // Note that, we have "two sources of truth" for GemmAlgorithmPicker: i.e., + // debug_options are used to initialize both 'HloModuleConfig' and also + // 'AutotuneConfig'. + debug_opts.set_xla_gpu_autotune_gemm_rtol(1e-12); + debug_opts.set_xla_gpu_autotune_level(5); + module->mutable_config().set_debug_options(debug_opts); + TF_ASSERT_OK_AND_ASSIGN( + bool changed, + RunHloPass( + GemmRewriter(gpu_comp(), /*toolkit_version=*/12040), + module.get())); + + AutotuneConfig cfg{DeviceConfig{stream_exec(), nullptr}, debug_opts}; + GemmAlgorithmPicker gpicker(cfg); + TF_ASSERT_OK_AND_ASSIGN(changed, RunHloPass(gpicker, module.get())); + num_left2 = gpicker.num_algorithms_left(); + } + // Assert that we have fewer algorithms left after the second run. + ASSERT_TRUE(num_left1 > num_left2); +} + +TEST_P(GemmAlgorithmPickerTest, SetAlgorithm) { constexpr absl::string_view kHlo = R"( HloModule module @@ -92,19 +161,15 @@ ENTRY main { TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kHlo, module_cfg)); - se::Platform* platform = PlatformUtil::GetDefaultPlatform().value(); - TF_ASSERT_OK_AND_ASSIGN(std::vector executors, - PlatformUtil::GetStreamExecutors(platform)); - ASSERT_GT(executors.size(), 0); - se::StreamExecutor* stream_exec = executors[0]; bool changed = false; TF_ASSERT_OK_AND_ASSIGN( - changed, RunHloPass(GemmRewriter(stream_exec->GetDeviceDescription() - .gpu_compute_capability()), - m.get())); + changed, + RunHloPass( + GemmRewriter(gpu_comp(), /*toolkit_version=*/12040), + m.get())); changed = false; DebugOptions opts; - AutotuneConfig cfg{DeviceConfig{stream_exec, nullptr}, opts}; + AutotuneConfig cfg{DeviceConfig{stream_exec(), nullptr}, opts}; TF_ASSERT_OK_AND_ASSIGN(changed, RunHloPass(GemmAlgorithmPicker(cfg), m.get())); ASSERT_TRUE(changed); @@ -125,9 +190,10 @@ ENTRY main { TF_ASSERT_OK_AND_ASSIGN(m, ParseAndReturnVerifiedModule(kHlo, module_cfg)); changed = false; TF_ASSERT_OK_AND_ASSIGN( - changed, RunHloPass(GemmRewriter(stream_exec->GetDeviceDescription() - .gpu_compute_capability()), - m.get())); + changed, + RunHloPass( + GemmRewriter(gpu_comp(), /*toolkit_version=*/12040), + m.get())); changed = false; TF_ASSERT_OK_AND_ASSIGN(changed, RunHloPass(GemmAlgorithmPicker(cfg), m.get())); @@ -145,16 +211,6 @@ ENTRY main { } TEST_P(GemmAlgorithmPickerTest, GetAlgorithmWithoutDevice) { - auto comp = backend() - .default_stream_executor() - ->GetDeviceDescription() - .cuda_compute_capability(); - if (comp.IsAtLeast(se::CudaComputeCapability::AMPERE)) { - GTEST_SKIP() << "Skipping this test for Ampere+ as it is supported and " - "recommended with " - "the Nvidia Volta+ GPUs."; - } - constexpr absl::string_view kHlo = R"( HloModule module @@ -166,21 +222,16 @@ ENTRY main { TF_ASSERT_OK_AND_ASSIGN( auto m, ParseAndReturnVerifiedModule(kHlo, GetModuleConfigForTest())); - se::Platform* platform = PlatformUtil::GetDefaultPlatform().value(); - TF_ASSERT_OK_AND_ASSIGN(std::vector executors, - PlatformUtil::GetStreamExecutors(platform)); - ASSERT_GT(executors.size(), 0); - se::StreamExecutor* stream_exec = executors[0]; - bool changed = false; TF_ASSERT_OK_AND_ASSIGN( - changed, RunHloPass(GemmRewriter(stream_exec->GetDeviceDescription() - .gpu_compute_capability()), - m.get())); + changed, + RunHloPass( + GemmRewriter(gpu_comp(), /*toolkit_version=*/12040), + m.get())); changed = false; DebugOptions 
opts; - AutotuneConfig cfg{DeviceConfig{stream_exec, nullptr}, opts}; + AutotuneConfig cfg{DeviceConfig{stream_exec(), nullptr}, opts}; TF_ASSERT_OK_AND_ASSIGN(changed, RunHloPass(GemmAlgorithmPicker(cfg), m.get())); @@ -204,13 +255,14 @@ ENTRY main { changed = false; DevicelessConfig deviceless_config{ - stream_exec->GetDeviceDescription().model_str(), - stream_exec->GetDeviceDescription().cuda_compute_capability()}; + gpu_device_desc().model_str(), gpu_comp()}; AutotuneConfig deviceless_cfg{deviceless_config, opts}; TF_ASSERT_OK_AND_ASSIGN( - changed, RunHloPass(GemmRewriter(stream_exec->GetDeviceDescription() - .gpu_compute_capability()), - m.get())); + changed, + RunHloPass( + GemmRewriter(gpu_comp(), + /*toolkit_version=*/12040), + m.get())); changed = false; TF_ASSERT_OK_AND_ASSIGN( changed, RunHloPass(GemmAlgorithmPicker(deviceless_cfg), m.get())) diff --git a/xla/service/gpu/gemm_fusion_autotuner.cc b/xla/service/gpu/gemm_fusion_autotuner.cc index dc7936abdb416..d2f48798ad1cd 100644 --- a/xla/service/gpu/gemm_fusion_autotuner.cc +++ b/xla/service/gpu/gemm_fusion_autotuner.cc @@ -823,7 +823,7 @@ absl::StatusOr> GemmFusionAutotunerImpl::Profile( const HloInstruction& root = *fusion_computation->root_instruction(); BufferComparator comparator(root.shape(), - fusion_computation->parent()->config()); + debug_options_.xla_gpu_autotune_gemm_rtol()); TF_ASSIGN_OR_RETURN(auto rz_buffers, RedzoneBuffers::FromInstruction( diff --git a/xla/service/gpu/triton_fusion_numerics_verifier.cc b/xla/service/gpu/triton_fusion_numerics_verifier.cc index c3e92082e3c04..8d539fd115ece 100644 --- a/xla/service/gpu/triton_fusion_numerics_verifier.cc +++ b/xla/service/gpu/triton_fusion_numerics_verifier.cc @@ -116,7 +116,8 @@ absl::Status CompareBuffers(const ScopedShapedBuffer& current, const ScopedShapedBuffer& expected, const Shape& shape, const HloModuleConfig& config, se::Stream* stream) { - BufferComparator comparator(shape, config); + BufferComparator comparator(shape, + config.debug_options().xla_gpu_autotune_gemm_rtol()); TF_ASSIGN_OR_RETURN(bool outputs_match, comparator.CompareEqual(stream, current.root_buffer(), expected.root_buffer())); From f753b6f419cc2a583c405eb943ec500f1d2ad938 Mon Sep 17 00:00:00 2001 From: Chao Chen Date: Thu, 4 Jul 2024 12:49:10 +0000 Subject: [PATCH 2/3] small changes --- xla/debug_options_flags.cc | 2 +- xla/service/gpu/buffer_comparator.cc | 5 +++++ xla/service/gpu/buffer_comparator.h | 2 +- xla/service/gpu/gemm_algorithm_picker.cc | 2 +- 4 files changed, 8 insertions(+), 3 deletions(-) diff --git a/xla/debug_options_flags.cc b/xla/debug_options_flags.cc index d860a82ced335..5598c808576d5 100644 --- a/xla/debug_options_flags.cc +++ b/xla/debug_options_flags.cc @@ -49,7 +49,7 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() { opts.set_xla_llvm_enable_invariant_load_metadata(true); opts.set_xla_llvm_disable_expensive_passes(false); opts.set_xla_backend_optimization_level(3); - opts.set_xla_gpu_autotune_level(4); + opts.set_xla_gpu_autotune_level(5); opts.set_xla_gpu_autotune_max_solutions(0); opts.set_xla_cpu_multi_thread_eigen(true); opts.set_xla_gpu_cuda_data_dir("./cuda_sdk_lib"); diff --git a/xla/service/gpu/buffer_comparator.cc b/xla/service/gpu/buffer_comparator.cc index 290e7a7df5ab7..396826f2f9106 100644 --- a/xla/service/gpu/buffer_comparator.cc +++ b/xla/service/gpu/buffer_comparator.cc @@ -100,8 +100,13 @@ static absl::StatusOr DeviceCompare( TF_RETURN_IF_ERROR( params.stream->Memcpy(&result, out.memory(), sizeof(result))); 
TF_RETURN_IF_ERROR(params.stream->BlockHostUntilDone()); + return result == 0; +} + +// Host side comparison code that does the same thing, but reports some of the // differences as well. It only print logs for debugging. // +// Returns true if no differences were seen, false otherwise. template static absl::StatusOr HostCompare(const ComparisonParams& params) { int64_t n = params.current.size() / sizeof(ElementType); diff --git a/xla/service/gpu/buffer_comparator.h b/xla/service/gpu/buffer_comparator.h index b56fc21073e56..107585c2ba901 100644 --- a/xla/service/gpu/buffer_comparator.h +++ b/xla/service/gpu/buffer_comparator.h @@ -51,7 +51,7 @@ class BufferComparator { se::DeviceMemoryBase expected) const; private: Shape shape_; - float relative_tol_; // relative tolerance for comparison + double relative_tol_; // relative tolerance for comparison bool verbose_; // whether to print out error message on mismatch }; diff --git a/xla/service/gpu/gemm_algorithm_picker.cc b/xla/service/gpu/gemm_algorithm_picker.cc index 4911b91af5569..c01a5954a4dca 100644 --- a/xla/service/gpu/gemm_algorithm_picker.cc +++ b/xla/service/gpu/gemm_algorithm_picker.cc @@ -347,7 +347,7 @@ class GemmAutotuner { comparator.CompareEqual(stream_, /*current=*/OutputBuffer(), /*expected=*/reference_buffer)); if (!outputs_match) { - LOG(ERROR) << "Results mismatch between different GEMM algorithms. " + LOG(WARNING) << "Results mismatch between different GEMM algorithms. " << "This is likely a bug/unexpected loss of precision."; CHECK(!autotune_config_.should_crash_on_check_failure()); From 51dc9a5b0cfcba3f23fac0d77dac9e5c70f6bc06 Mon Sep 17 00:00:00 2001 From: Chao Chen Date: Wed, 10 Jul 2024 09:48:46 +0000 Subject: [PATCH 3/3] adopted changes --- xla/service/gpu/buffer_comparator.cc | 10 +++++----- xla/service/gpu/buffer_comparator_test.cc | 14 ++++++++++---- 2 files changed, 15 insertions(+), 9 deletions(-) diff --git a/xla/service/gpu/buffer_comparator.cc b/xla/service/gpu/buffer_comparator.cc index 396826f2f9106..272cfa6831187 100644 --- a/xla/service/gpu/buffer_comparator.cc +++ b/xla/service/gpu/buffer_comparator.cc @@ -46,7 +46,7 @@ using ComparisonKernelT = float, uint64_t, se::DeviceMemory>; struct ComparisonParams { - float relative_tol = 0.1f; + double relative_tol = 0.1f; bool verbose = true; const Shape *shape = nullptr; se::Stream* stream = nullptr; @@ -93,7 +93,8 @@ static absl::StatusOr DeviceCompare( se::DeviceMemory as_uint64(out.memory()); TF_RETURN_IF_ERROR(params.stream->ThenLaunch( dim.thread_counts_per_block(), dim.block_counts(), comparison_kernel, - current_typed, expected_typed, params.relative_tol, buffer_size, as_uint64)); + current_typed, expected_typed, static_cast(params.relative_tol), + buffer_size, as_uint64)); uint64_t result = -1; CHECK_EQ(out.memory().size(), sizeof(result)); @@ -150,9 +151,8 @@ static absl::StatusOr HostCompare(const ComparisonParams& params) { !(std::abs(current_value_canonical - expected_value_canonical) / (std::max(std::abs(current_value_canonical), std::abs(expected_value_canonical)) + - 1) < - static_cast(params.relative_tol))) { - if(!params.verbose) return false; // Return immediately if not verbose. + 1) < params.relative_tol)) { + if (!params.verbose) return false; // Return immediately if not verbose. 
++differences_seen; LOG(ERROR) << "Difference at " << i << ": " << current_value << ", expected " << expected_value; diff --git a/xla/service/gpu/buffer_comparator_test.cc b/xla/service/gpu/buffer_comparator_test.cc index 20aee05740c05..9481473efbe26 100644 --- a/xla/service/gpu/buffer_comparator_test.cc +++ b/xla/service/gpu/buffer_comparator_test.cc @@ -39,6 +39,9 @@ namespace xla { namespace gpu { namespace { +constexpr double kDefaultTolerance = 0.1; + + class BufferComparatorTest : public testing::Test { protected: BufferComparatorTest() @@ -53,7 +56,8 @@ class BufferComparatorTest : public testing::Test { // Take floats only for convenience. Still uses ElementType internally. template bool CompareEqualBuffers(const std::vector& current, - const std::vector& expected) { + const std::vector& expected, + double tolerance) { auto stream = stream_exec_->CreateStream().value(); se::DeviceMemoryHandle current_buffer( @@ -82,16 +86,18 @@ class BufferComparatorTest : public testing::Test { // Take floats only for convenience. Still uses ElementType internally. template bool CompareEqualFloatBuffers(const std::vector& lhs_float, - const std::vector& rhs_float) { + const std::vector& rhs_float, + double tolerance = kDefaultTolerance) { std::vector lhs(lhs_float.begin(), lhs_float.end()); std::vector rhs(rhs_float.begin(), rhs_float.end()); - return CompareEqualBuffers(lhs, rhs); + return CompareEqualBuffers(lhs, rhs, tolerance); } template bool CompareEqualComplex(const std::vector>& lhs, const std::vector>& rhs) { - return CompareEqualBuffers>(lhs, rhs); + return CompareEqualBuffers>(lhs, rhs, + kDefaultTolerance); } se::Platform* platform_;
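
A minimal usage sketch of the BufferComparator interface after this series (illustration only, not part of the diffs above): the comparator now takes a numeric relative tolerance instead of an HloModuleConfig, and CompareEqual is const. It assumes `shape`, `config`, `stream`, `current`, and `expected` are already in scope, as at the updated call sites in conv_algorithm_picker.cc and gemm_algorithm_picker.cc:

  // Before this series the comparator was constructed from the module config:
  //   BufferComparator comparator(shape, config);
  // Now the relative tolerance (and, optionally, verbosity) is passed directly;
  // the updated callers read it from the debug options.
  BufferComparator comparator(
      shape, config.debug_options().xla_gpu_autotune_gemm_rtol(),
      /*verbose=*/true);

  // CompareEqual is now const and takes the stream plus the two device buffers.
  TF_ASSIGN_OR_RETURN(
      bool outputs_match,
      comparator.CompareEqual(stream, /*current=*/current,
                              /*expected=*/expected));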