Skip to content

Commit

Permalink
Merge pull request #21 from ROCm/rocm-jaxlib-v0.4.28-qa_fusedconv_rev…
Browse files Browse the repository at this point in the history
…ert_mm

Fused conv fix and revert memory management
  • Loading branch information
i-chaochen authored Jun 21, 2024
2 parents e82f7c1 + 087578f commit 6a1f142
Show file tree
Hide file tree
Showing 16 changed files with 1,315 additions and 633 deletions.
14 changes: 8 additions & 6 deletions xla/service/gpu/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -3990,6 +3990,7 @@ cc_library(
":conv_algorithm_picker",
":cublas_pad_for_gemms",
":cublas_padding_requirements",
":cudnn_fused_conv_rewriter",
":cusolver_rewriter",
":gemm_algorithm_picker",
":gemm_rewriter",
Expand Down Expand Up @@ -4598,7 +4599,6 @@ cc_library(
name = "cudnn_fused_conv_rewriter",
srcs = ["cudnn_fused_conv_rewriter.cc"],
hdrs = ["cudnn_fused_conv_rewriter.h"],
local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]),
deps = [
":backend_configs_cc",
":cublas_cudnn",
Expand Down Expand Up @@ -4627,10 +4627,7 @@ cc_library(
"@tsl//tsl/platform:errors",
"@tsl//tsl/platform:ml_dtypes",
"@tsl//tsl/platform:statusor",
] + if_cuda_is_configured([
"@local_config_cuda//cuda:cuda_headers",
"@local_config_cuda//cuda:cudnn_header",
]),
],
)

xla_test(
Expand All @@ -4646,13 +4643,15 @@ xla_test(
backends = [
"gpu_a100",
] + if_oss(["gpu"]),
local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]),
local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]) +
if_rocm_is_configured(["TENSORFLOW_USE_ROCM=1"]),
shard_count = 10,
deps = [
":backend_configs_cc",
":cublas_cudnn",
":cudnn_fused_conv_rewriter",
":gpu_conv_rewriter",
":stream_executor_util",
"//xla:comparison_util",
"//xla:error_spec",
"//xla/hlo/ir:hlo",
Expand All @@ -4667,6 +4666,7 @@ xla_test(
"//xla/service:reshape_mover",
"//xla/service/gpu/tests:gpu_codegen_test",
"//xla/stream_executor:device_description",
"//xla/stream_executor:stream_executor_headers",
"//xla/tests:filecheck",
"//xla/tests:hlo_test_base",
"//xla/tests:verified_hlo_module",
Expand All @@ -4681,6 +4681,8 @@ xla_test(
] + if_cuda_is_configured([
"@local_config_cuda//cuda:cuda_headers",
"@local_config_cuda//cuda:cudnn_header",
]) + if_rocm_is_configured([
"@local_config_rocm//rocm:rocm_headers"
]),
)

Expand Down
4 changes: 4 additions & 0 deletions xla/service/gpu/amdgpu_compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ limitations under the License.
#include "xla/service/float_normalization.h"
#include "xla/service/gpu/autotuner_util.h"
#include "xla/service/gpu/conv_algorithm_picker.h"
#include "xla/service/gpu/cudnn_fused_conv_rewriter.h"
#include "xla/service/gpu/cublas_pad_for_gemms.h"
#include "xla/service/gpu/cublas_padding_requirements.h"
#include "xla/service/gpu/cusolver_rewriter.h"
Expand Down Expand Up @@ -109,6 +110,9 @@ absl::Status AMDGPUCompiler::OptimizeHloConvolutionCanonicalization(
pipeline.AddPass<GpusolverRewriter>();
pipeline.AddPass<GpuConvRewriter>();
pipeline.AddPass<GpuConvPaddingLegalization>();
auto rcc = std::get<se::RocmComputeCapability>(gpu_version);
pipeline.AddPass<CudnnFusedConvRewriter>(rcc, dnn_version,
0);

// The conv padding/vectorization passes which we need to get rid of. They
// also leave behind unnecessary tuple/get-tuple-element pairs that
Expand Down
66 changes: 41 additions & 25 deletions xla/service/gpu/cudnn_fused_conv_rewriter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -39,28 +39,23 @@ limitations under the License.
#include "absl/strings/string_view.h"
#include "xla/comparison_util.h"
#include "xla/debug_options_flags.h"
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/hlo/ir/hlo_opcode.h"
#include "xla/literal.h"
#include "xla/shape.h"
#include "xla/shape_util.h"
#include "xla/stream_executor/device_description.h"
#include "xla/util.h"
#include "tsl/platform/ml_dtypes.h"

#if GOOGLE_CUDA
#include "third_party/gpus/cuda/include/cuda.h"
#include "third_party/gpus/cudnn/cudnn.h"
#endif

#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/primitive_util.h"
#include "xla/service/gpu/backend_configs.pb.h"
#include "xla/service/gpu/cublas_cudnn.h"
#include "xla/service/hlo_creation_utils.h"
#include "xla/service/pattern_matcher.h"
#include "xla/shape.h"
#include "xla/shape_util.h"
#include "xla/stream_executor/device_description.h"
#include "xla/stream_executor/dnn.h"
#include "xla/stream_executor/stream_executor.h"
#include "xla/util.h"
#include "xla/xla_data.pb.h"
#include "tsl/platform/errors.h"
#include "tsl/platform/ml_dtypes.h"
#include "tsl/platform/statusor.h"

namespace xla {
Expand Down Expand Up @@ -96,6 +91,10 @@ bool IsNonDepthwiseConvCustomCall(const HloInstruction* instr) {
return IsConvCustomCall(instr) && !IsConvDepthwise(instr);
}

bool IsROCm(se::GpuComputeCapability cc) {
return std::holds_alternative<se::RocmComputeCapability>(cc);
}

// elu, relu6, and leaky-relu activations are supported in cudnn via the
// "runtime fusion" engine, which JIT compiles C++ code. This can be slow to
// compile, so we guard it with a debug option.
Expand All @@ -106,8 +105,12 @@ bool IsNonDepthwiseConvCustomCall(const HloInstruction* instr) {
// Note that as of writing, xla_gpu_use_runtime_fusion is disabled by default
// due to apparent bugs in cudnn 8.9.0. See debug_options_flags.cc for details.
bool ShouldUseCudnnRuntimeFusion(const DebugOptions& debug_opts,
se::CudaComputeCapability cc) {
return debug_opts.xla_gpu_use_runtime_fusion() && cc.IsAtLeast(7, 5);
se::GpuComputeCapability cc) {
const auto* cuda_cc = std::get_if<se::CudaComputeCapability>(&cc);
if(cuda_cc != nullptr)
return debug_opts.xla_gpu_use_runtime_fusion() && cuda_cc->IsAtLeast(7, 5);
else
return true;
}

bool IsSuitableForCudnnRuntimeFusion(HloInstruction* conv) {
Expand Down Expand Up @@ -658,10 +661,17 @@ CaptureConvGraph(HloInstruction* instr, HloInstruction* convolution,
// 5. Optionally calculate the maximum of the absolute of the result.
// 6. Optionally cast the output back to FP8.
absl::StatusOr<bool> F8GraphConv(HloComputation* comp,
se::CudaComputeCapability cc) {
se::CudaComputeCapability cc,
se::dnn::VersionInfo dnn_version,
int32_t toolkit_version) {
bool changed = false;

#if CUDA_VERSION >= 12000 && CUDNN_VERSION >= 8900
if (dnn_version < se::dnn::VersionInfo(8, 9, 0)) {
return false;
}
if (toolkit_version < 12000) {
return false;
}
if (!cc.IsAtLeast(se::CudaComputeCapability::HOPPER)) {
return false;
}
Expand Down Expand Up @@ -759,7 +769,6 @@ absl::StatusOr<bool> F8GraphConv(HloComputation* comp,
changed = true;
}
}
#endif // CUDA_VERSION >= 12000 && CUDNN_VERSION >= 8900
return changed;
}

Expand Down Expand Up @@ -984,7 +993,7 @@ absl::StatusOr<bool> FuseSideInputAlpha(HloComputation* comp) {
}

absl::StatusOr<bool> FuseElu(HloComputation* comp,
se::CudaComputeCapability cc) {
se::GpuComputeCapability cc) {
if (!ShouldUseCudnnRuntimeFusion(comp->parent()->config().debug_options(),
cc)) {
return false;
Expand Down Expand Up @@ -1085,7 +1094,7 @@ absl::StatusOr<bool> FuseRelu(HloComputation* comp) {
}

absl::StatusOr<bool> FuseRelu6(HloComputation* comp,
se::CudaComputeCapability cc) {
se::GpuComputeCapability cc) {
if (!ShouldUseCudnnRuntimeFusion(comp->parent()->config().debug_options(),
cc)) {
return false;
Expand Down Expand Up @@ -1134,7 +1143,7 @@ absl::StatusOr<bool> FuseRelu6(HloComputation* comp,
}

absl::StatusOr<bool> FuseLeakyRelu(HloComputation* comp,
se::CudaComputeCapability cc) {
se::GpuComputeCapability cc) {
if (!ShouldUseCudnnRuntimeFusion(comp->parent()->config().debug_options(),
cc)) {
return false;
Expand Down Expand Up @@ -1254,7 +1263,10 @@ absl::StatusOr<bool> FuseConvertToF16(HloComputation* comp) {
return changed;
}

absl::StatusOr<bool> FuseConvertToS8(HloComputation* comp) {
absl::StatusOr<bool> FuseConvertToS8(HloComputation* comp,
se::GpuComputeCapability cc) {
if(IsROCm(cc))
return false;
bool changed = false;
for (HloInstruction* instr : comp->MakeInstructionPostOrder()) {
HloInstruction* gte = nullptr;
Expand Down Expand Up @@ -1480,9 +1492,13 @@ absl::StatusOr<bool> CudnnFusedConvRewriter::Run(
bool changed = false;
// Rewrite FP8 convolutions and supported adjacent pointwise ops into a
// ForwardGraph Custom Call.
TF_ASSIGN_OR_RETURN(changed, F8GraphConv(comp, compute_capability_));
if (changed) {
return changed;
if(!IsROCm(compute_capability_)) {
auto cc = std::get<se::CudaComputeCapability>(compute_capability_);
TF_ASSIGN_OR_RETURN(
changed, F8GraphConv(comp, cc, dnn_version_, toolkit_version_));
if (changed) {
return changed;
}
}
// Fuse "inside out" starting with the operations closest to the conv.
TF_ASSIGN_OR_RETURN(changed, FuseRemoveConvertInConv(comp));
Expand Down Expand Up @@ -1516,7 +1532,7 @@ absl::StatusOr<bool> CudnnFusedConvRewriter::Run(
TF_ASSIGN_OR_RETURN(changed, FuseConvertToF16(comp));
any_changed |= changed;

TF_ASSIGN_OR_RETURN(changed, FuseConvertToS8(comp));
TF_ASSIGN_OR_RETURN(changed, FuseConvertToS8(comp, compute_capability_));
any_changed |= changed;

// f16 convs' bias+side-input can appear before or after conversion to f16.
Expand Down
21 changes: 18 additions & 3 deletions xla/service/gpu/cudnn_fused_conv_rewriter.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,15 @@ limitations under the License.
#ifndef XLA_SERVICE_GPU_CUDNN_FUSED_CONV_REWRITER_H_
#define XLA_SERVICE_GPU_CUDNN_FUSED_CONV_REWRITER_H_

#include <cstdint>

#include "absl/container/flat_hash_set.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "xla/hlo/ir/hlo_module.h"
#include "xla/service/hlo_pass_interface.h"
#include "xla/stream_executor/device_description.h"
#include "xla/stream_executor/dnn.h"

namespace xla {
namespace gpu {
Expand Down Expand Up @@ -98,8 +101,18 @@ namespace gpu {
// pass returns an error -- cudnn will not be able to run it.
class CudnnFusedConvRewriter : public HloModulePass {
public:
explicit CudnnFusedConvRewriter(se::CudaComputeCapability cc)
: compute_capability_(cc) {}
CudnnFusedConvRewriter(se::CudaComputeCapability cc,
se::dnn::VersionInfo dnn_version,
int32_t toolkit_version)
: compute_capability_(cc),
dnn_version_(dnn_version),
toolkit_version_(toolkit_version) {}
CudnnFusedConvRewriter(se::RocmComputeCapability cc,
se::dnn::VersionInfo dnn_version,
int32_t toolkit_version)
: compute_capability_(cc),
dnn_version_(dnn_version),
toolkit_version_(toolkit_version) {}

absl::string_view name() const override {
return "cudnn-fused-convolution-rewriter";
Expand All @@ -111,7 +124,9 @@ class CudnnFusedConvRewriter : public HloModulePass {
const absl::flat_hash_set<absl::string_view>& execution_threads) override;

private:
const se::CudaComputeCapability compute_capability_;
const se::GpuComputeCapability compute_capability_;
const se::dnn::VersionInfo dnn_version_;
const int32_t toolkit_version_;
};

} // namespace gpu
Expand Down
Loading

0 comments on commit 6a1f142

Please sign in to comment.