Skip to content

Commit

Permalink
[XLA:GPU] Introduce xla_gpu_experimental_enable_triton_i4_rewrites, t…
Browse files Browse the repository at this point in the history
…hat enables the corresponding rewrites for the i4 tensors in Triton MLIR. The default value is false.

The goal of this CL is to move the unpacking logic to a Triton-level rewrite.
As a result, the HLO-to-Triton emitter does not need to take the unpacking logic into account and can keep using shapes that match the actual tensors.

The cl:
a) adds the flag that enables the triton level rewrites.
b) disables int4 support in the triton_fusion_emitter_legacy_matmul if the flag is true.
c) changes the mapping from the S4 HLO type to the Triton type. The emitter emits s4 instead of s8 if the flag is true.
d) fixes the unpacking logic for the cases where the tensor is packed along the minor dim.
e) fixes the unpacking logic for the cases when the packed dim actually has only 1 element.
f) covers the cases when s4 is the rhs parameter of the dot.

PiperOrigin-RevId: 714049078
  • Loading branch information
loislo authored and Google-ML-Automation committed Jan 10, 2025
1 parent 5630f58 commit 092b8dd
Show file tree
Hide file tree
Showing 12 changed files with 822 additions and 145 deletions.
9 changes: 9 additions & 0 deletions xla/debug_options_flags.cc
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,7 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() {
opts.set_xla_gpu_exhaustive_tiling_search(false);

opts.set_xla_gpu_experimental_enable_triton_heroless_priority_fusion(false);
opts.set_xla_gpu_experimental_enable_triton_i4_rewrites(false);

opts.set_xla_gpu_auto_spmd_partitioning_memory_budget_gb(0);
opts.set_xla_gpu_auto_spmd_partitioning_memory_budget_ratio(1.1);
Expand Down Expand Up @@ -2097,6 +2098,14 @@ void MakeDebugOptionsFlags(std::vector<tsl::Flag>* flag_list,
flag_list->push_back(tsl::Flag("xla_gpu_enable_triton_gemm_int4",
noop_flag_setter<bool>, true,
"[Deprecated, do not use]"));
flag_list->push_back(tsl::Flag(
"xla_gpu_experimental_enable_triton_i4_rewrites",
bool_setter_for(
&DebugOptions::set_xla_gpu_experimental_enable_triton_i4_rewrites),
debug_options->xla_gpu_experimental_enable_triton_i4_rewrites(),
"When enabled, the Triton emitter for dot will use int4 as native type "
"and later the Triton IR will be rewritten by Triton IR rewriting pass "
"to use int4 packed into int8."));
flag_list->push_back(
tsl::Flag("xla_gpu_async_dot",
bool_setter_for(&DebugOptions::set_xla_gpu_async_dot),
Expand Down
2 changes: 1 addition & 1 deletion xla/service/gpu/fusions/triton/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -605,7 +605,6 @@ xla_test(
"gpu_b100",
"gpu_amd_any",
],
shard_count = 20,
tags = [
"no_mac",
],
Expand All @@ -620,6 +619,7 @@ xla_test(
"//xla/stream_executor:device_description",
"//xla/tests:xla_internal_test_main", # fixdeps: keep
"//xla/tsl/platform:statusor",
"@com_google_absl//absl/log",
"@com_google_absl//absl/strings",
"@com_google_googletest//:gtest",
"@tsl//tsl/platform:path",
Expand Down
26 changes: 26 additions & 0 deletions xla/service/gpu/fusions/triton/tests/BUILD
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
load("//xla:lit.bzl", "lit_test_suite") # @unused

package(
# copybara:uncomment default_applicable_licenses = ["//tensorflow:license"],
default_visibility = [":friends"],
licenses = ["notice"],
)

package_group(
name = "friends",
includes = [
"//xla:friends",
],
)

# copybara:uncomment_begin(triton-opt tool doesn't build in OSS)
# lit_test_suite(
# name = "mlir_lit_tests",
# srcs = glob(["*.mlir"]),
# cfg = "//xla:lit.cfg.py",
# tools = [
# "@llvm-project//llvm:FileCheck",
# "//xla/service/gpu/tests:xla-opt",
# ],
# )
# copybara:uncomment_end
41 changes: 41 additions & 0 deletions xla/service/gpu/fusions/triton/tests/int4_packed_dim_major_1d.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
// RUN: xla-opt --int4-to-packed-int4-rewrite --canonicalize -- %s | FileCheck --dump-input=never %s

// Exercises --int4-to-packed-int4-rewrite on a tensor ptr with
// packed_dim = 1 and unit strides in both dims ("major 1d" case).
// The CHECK lines expect the pass to:
//   - drop the packed_dim attribute and shrink the pointee type from
//     tensor<64x64xi4> to tensor<64x32xi8> (two i4 values per i8 byte),
//   - halve the shape bound and the in-loop advance offset along the
//     packed dimension (128 -> 64, advance step 64 -> 32),
//   - replace arith.extsi with a shift-based nibble unpack (shli/shrsi
//     for one nibble, shrsi for the other) followed by tt.join and a
//     tt.reshape back to the logical tensor<64x64xi8> shape.
module {
tt.func @major_1d(%arg0: !tt.ptr<i4> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<i8> {tt.divisibility = 16 : i32}) {
%c128_i32 = arith.constant 128 : i32
%c128_i64 = arith.constant 128 : i64
%c0_i32 = arith.constant 0 : i32
%c1_i64 = arith.constant 1 : i64
%c64_i32 = arith.constant 64 : i32
%cst = arith.constant dense<0> : tensor<64x64xi8>

// i4 tensor ptr: shape [1, 128], strides [1, 1], packed along dim 1.
%0 = tt.make_tensor_ptr %arg0, [%c1_i64, %c128_i64], [%c1_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>, packed_dim = 1 } : <tensor<64x64xi4>>
// CHECK: %0 = tt.make_tensor_ptr %arg0, [%c1_i64, %c64_i64], [%c1_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<64x32xi8>>

// Advance along dim 0 (not the packed dim): offsets stay unchanged.
%1 = tt.advance %0, [%c64_i32, %c0_i32] : <tensor<64x64xi4>>
// CHECK-NEXT: %1 = tt.advance %0, [%c64_i32, %c0_i32] : <tensor<64x32xi8>>

%2:2 = scf.for %arg2 = %c0_i32 to %c128_i32 step %c64_i32 iter_args(%arg3 = %1, %arg4 = %cst) -> (!tt.ptr<tensor<64x64xi4>>, tensor<64x64xi8>) : i32 {
// CHECK-NEXT: %2:2 = scf.for %arg2 = %c0_i32 to %c128_i32 step %c64_i32 iter_args(%arg3 = %1, %arg4 = %cst_0) -> (!tt.ptr<tensor<64x32xi8>>, tensor<64x64xi8>) : i32 {

%4 = tt.load %arg3 {boundaryCheck = array<i32: 0>, padding = 1 : i32} : !tt.ptr<tensor<64x64xi4>>
// CHECK-NEXT: %4 = tt.load %arg3 {boundaryCheck = array<i32: 0>, padding = 1 : i32} : !tt.ptr<tensor<64x32xi8>>

// Advance along the packed dim: the offset is halved (64 -> 32).
%5 = tt.advance %arg3, [%c0_i32, %c64_i32] : <tensor<64x64xi4>>
// CHECK-NEXT: %5 = tt.advance %arg3, [%c0_i32, %c32_i32] : <tensor<64x32xi8>>

// The sign-extension of the i4 tensor becomes the unpack sequence.
%6 = arith.extsi %4 : tensor<64x64xi4> to tensor<64x64xi8>
// CHECK-NEXT: %6 = arith.shli %4, %cst : tensor<64x32xi8>
// CHECK-NEXT: %7 = arith.shrsi %6, %cst : tensor<64x32xi8>
// CHECK-NEXT: %8 = arith.shrsi %4, %cst : tensor<64x32xi8>
// CHECK-NEXT: %9 = tt.join %8, %7 : tensor<64x32xi8> -> tensor<64x32x2xi8>
// CHECK-NEXT: %10 = tt.reshape %9 : tensor<64x32x2xi8> -> tensor<64x64xi8>

scf.yield %5, %6 : !tt.ptr<tensor<64x64xi4>>, tensor<64x64xi8>
// CHECK-NEXT: scf.yield %5, %10 : !tt.ptr<tensor<64x32xi8>>, tensor<64x64xi8>
}
%3 = tt.make_tensor_ptr %arg1, [%c1_i64, %c128_i64], [%c1_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<64x64xi8>>
tt.store %3, %2#1 {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<64x64xi8>>
tt.return
}
}
44 changes: 44 additions & 0 deletions xla/service/gpu/fusions/triton/tests/int4_packed_dim_major_2d.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
// RUN: xla-opt --int4-to-packed-int4-rewrite --canonicalize -- %s | FileCheck --dump-input=never %s

// Exercises --int4-to-packed-int4-rewrite on a 2D tensor ptr with
// packed_dim = 1 and strides [1, 16] ("major 2d" case).
// The CHECK lines expect the pass to:
//   - drop the packed_dim attribute and shrink the pointee type from
//     tensor<64x64xi4> to tensor<32x64xi8> (two i4 values per i8 byte),
//   - halve the affected shape bound, stride, and advance offset
//     (16 -> 8, advance step 64 -> 32 on dim 0),
//   - replace arith.extsi with a shift-based nibble unpack followed by
//     tt.join, a tt.trans interleaving the joined nibbles, and a
//     tt.reshape back to the logical tensor<64x64xi8> shape.
module {
tt.func @major_2d(%arg0: !tt.ptr<i4> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<i8> {tt.divisibility = 16 : i32}) {
%c128_i32 = arith.constant 128 : i32
%c128_i64 = arith.constant 128 : i64
%c16_i64 = arith.constant 16 : i64
%c0_i32 = arith.constant 0 : i32
%c1_i64 = arith.constant 1 : i64
%c64_i32 = arith.constant 64 : i32
%cst = arith.constant dense<0> : tensor<64x64xi8>

// i4 tensor ptr: shape [16, 128], strides [1, 16], packed along dim 1.
%0 = tt.make_tensor_ptr %arg0, [%c16_i64, %c128_i64], [%c1_i64, %c16_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>, packed_dim = 1 } : <tensor<64x64xi4>>
// CHECK: %0 = tt.make_tensor_ptr %arg0, [%c8_i64, %c128_i64], [%c1_i64, %c8_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<32x64xi8>>

// Advance along the halved dim 0: the offset is halved (64 -> 32).
%1 = tt.advance %0, [%c64_i32, %c0_i32] : <tensor<64x64xi4>>
// CHECK-NEXT: %1 = tt.advance %0, [%c32_i32, %c0_i32] : <tensor<32x64xi8>>

%2:2 = scf.for %arg2 = %c0_i32 to %c128_i32 step %c64_i32 iter_args(%arg3 = %1, %arg4 = %cst) -> (!tt.ptr<tensor<64x64xi4>>, tensor<64x64xi8>) : i32 {
// CHECK-NEXT: %2:2 = scf.for %arg2 = %c0_i32 to %c128_i32 step %c64_i32 iter_args(%arg3 = %1, %arg4 = %cst_0) -> (!tt.ptr<tensor<32x64xi8>>, tensor<64x64xi8>) : i32 {

%4 = tt.load %arg3 {boundaryCheck = array<i32: 0>, padding = 1 : i32} : !tt.ptr<tensor<64x64xi4>>
// CHECK-NEXT: %4 = tt.load %arg3 {boundaryCheck = array<i32: 0>, padding = 1 : i32} : !tt.ptr<tensor<32x64xi8>>

// Advance along dim 1 (not halved): offsets stay unchanged.
%5 = tt.advance %arg3, [%c0_i32, %c64_i32] : <tensor<64x64xi4>>
// CHECK-NEXT: %5 = tt.advance %arg3, [%c0_i32, %c64_i32] : <tensor<32x64xi8>>

// The sign-extension of the i4 tensor becomes the unpack sequence,
// including a tt.trans that is absent in the "major 1d" case.
%6 = arith.extsi %4 : tensor<64x64xi4> to tensor<64x64xi8>
// CHECK-NEXT: %6 = arith.shli %4, %cst : tensor<32x64xi8>
// CHECK-NEXT: %7 = arith.shrsi %6, %cst : tensor<32x64xi8>
// CHECK-NEXT: %8 = arith.shrsi %4, %cst : tensor<32x64xi8>
// CHECK-NEXT: %9 = tt.join %8, %7 : tensor<32x64xi8> -> tensor<32x64x2xi8>
// CHECK-NEXT: %10 = tt.trans %9 {order = array<i32: 0, 2, 1>} : tensor<32x64x2xi8> -> tensor<32x2x64xi8>
// CHECK-NEXT: %11 = tt.reshape %10 : tensor<32x2x64xi8> -> tensor<64x64xi8>

scf.yield %5, %6 : !tt.ptr<tensor<64x64xi4>>, tensor<64x64xi8>
// CHECK-NEXT: scf.yield %5, %11 : !tt.ptr<tensor<32x64xi8>>, tensor<64x64xi8>
}
%3 = tt.make_tensor_ptr %arg1, [%c128_i64, %c1_i64], [%c1_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<64x64xi8>>
tt.store %3, %2#1 {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<64x64xi8>>
tt.return
}
}

44 changes: 44 additions & 0 deletions xla/service/gpu/fusions/triton/tests/int4_packed_dim_minor_1d.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
// RUN: xla-opt --int4-to-packed-int4-rewrite --canonicalize -- %s | FileCheck --dump-input=never %s

// Exercises --int4-to-packed-int4-rewrite on a tensor ptr with
// packed_dim = 0 and unit strides in both dims ("minor 1d" case).
// The CHECK lines expect the pass to:
//   - drop the packed_dim attribute and shrink the pointee type from
//     tensor<64x64xi4> to tensor<32x64xi8> (two i4 values per i8 byte),
//   - halve the shape bound and the in-loop advance offset along the
//     packed dimension (128 -> 64, advance step 64 -> 32 on dim 0),
//   - replace arith.extsi with a shift-based nibble unpack followed by
//     tt.join, tt.trans, and a tt.reshape back to the logical
//     tensor<64x64xi8> shape.
module {
tt.func @minor_1d(%arg0: !tt.ptr<i4> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<i8> {tt.divisibility = 16 : i32}) {
%c128_i32 = arith.constant 128 : i32
%c1_i64 = arith.constant 1 : i64
%c0_i32 = arith.constant 0 : i32
%c128_i64 = arith.constant 128 : i64
%c64_i32 = arith.constant 64 : i32
%cst = arith.constant dense<0> : tensor<64x64xi8>

// i4 tensor ptr: shape [128, 1], strides [1, 1], packed along dim 0.
%0 = tt.make_tensor_ptr %arg0, [%c128_i64, %c1_i64], [%c1_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>, packed_dim = 0 } : <tensor<64x64xi4>>
// CHECK: %0 = tt.make_tensor_ptr %arg0, [%c64_i64, %c1_i64], [%c1_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<32x64xi8>>

// Advance along dim 1 (not the packed dim): offsets stay unchanged.
%1 = tt.advance %0, [%c0_i32, %c64_i32] : <tensor<64x64xi4>>
// CHECK-NEXT: %1 = tt.advance %0, [%c0_i32, %c64_i32] : <tensor<32x64xi8>>

%2:2 = scf.for %arg2 = %c0_i32 to %c128_i32 step %c64_i32 iter_args(%arg3 = %1, %arg4 = %cst) -> (!tt.ptr<tensor<64x64xi4>>, tensor<64x64xi8>) : i32 {
// CHECK-NEXT: %2:2 = scf.for %arg2 = %c0_i32 to %c128_i32 step %c64_i32 iter_args(%arg3 = %1, %arg4 = %cst_0) -> (!tt.ptr<tensor<32x64xi8>>, tensor<64x64xi8>) : i32 {

%4 = tt.load %arg3 {boundaryCheck = array<i32: 1>, padding = 1 : i32} : !tt.ptr<tensor<64x64xi4>>
// CHECK-NEXT: %4 = tt.load %arg3 {boundaryCheck = array<i32: 1>, padding = 1 : i32} : !tt.ptr<tensor<32x64xi8>>

// Advance along the packed dim: the offset is halved (64 -> 32).
%5 = tt.advance %arg3, [%c64_i32, %c0_i32] : <tensor<64x64xi4>>
// CHECK-NEXT: %5 = tt.advance %arg3, [%c32_i32, %c0_i32] : <tensor<32x64xi8>>

// The sign-extension of the i4 tensor becomes the unpack sequence,
// including a tt.trans that is absent in the "minor 2d" case.
%6 = arith.extsi %4 : tensor<64x64xi4> to tensor<64x64xi8>
// CHECK-NEXT: %6 = arith.shli %4, %cst : tensor<32x64xi8>
// CHECK-NEXT: %7 = arith.shrsi %6, %cst : tensor<32x64xi8>
// CHECK-NEXT: %8 = arith.shrsi %4, %cst : tensor<32x64xi8>
// CHECK-NEXT: %9 = tt.join %8, %7 : tensor<32x64xi8> -> tensor<32x64x2xi8>
// CHECK-NEXT: %10 = tt.trans %9 {order = array<i32: 0, 2, 1>} : tensor<32x64x2xi8> -> tensor<32x2x64xi8>
// CHECK-NEXT: %11 = tt.reshape %10 : tensor<32x2x64xi8> -> tensor<64x64xi8>

scf.yield %5, %6 : !tt.ptr<tensor<64x64xi4>>, tensor<64x64xi8>
// CHECK-NEXT: scf.yield %5, %11 : !tt.ptr<tensor<32x64xi8>>, tensor<64x64xi8>
}
%3 = tt.make_tensor_ptr %arg1, [%c128_i64, %c1_i64], [%c1_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<64x64xi8>>
tt.store %3, %2#1 {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<64x64xi8>>
tt.return
}
}


43 changes: 43 additions & 0 deletions xla/service/gpu/fusions/triton/tests/int4_packed_dim_minor_2d.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
// RUN: xla-opt --int4-to-packed-int4-rewrite --canonicalize -- %s | FileCheck --dump-input=never %s

// Exercises --int4-to-packed-int4-rewrite on a 2D tensor ptr with
// packed_dim = 0 and strides [128, 1] ("minor 2d" case).
// The CHECK lines expect the pass to:
//   - drop the packed_dim attribute and shrink the pointee type from
//     tensor<64x64xi4> to tensor<64x32xi8> (two i4 values per i8 byte),
//   - halve the affected shape bound, stride, and advance offset
//     (128 -> 64, stride 128 -> 64, advance step 64 -> 32 on dim 1),
//   - replace arith.extsi with a shift-based nibble unpack followed by
//     tt.join and a tt.reshape back to the logical tensor<64x64xi8>
//     shape (no tt.trans in this case).
module {
tt.func @minor_2d(%arg0: !tt.ptr<i4> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<i8> {tt.divisibility = 16 : i32}) {
%c128_i32 = arith.constant 128 : i32
%c1_i64 = arith.constant 1 : i64
%c0_i32 = arith.constant 0 : i32
%c128_i64 = arith.constant 128 : i64
%c64_i32 = arith.constant 64 : i32
%cst = arith.constant dense<0> : tensor<64x64xi8>

// i4 tensor ptr: shape [1, 128], strides [128, 1], packed along dim 0.
%0 = tt.make_tensor_ptr %arg0, [%c1_i64, %c128_i64], [%c128_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>, packed_dim = 0 } : <tensor<64x64xi4>>
// CHECK: %0 = tt.make_tensor_ptr %arg0, [%c1_i64, %c64_i64], [%c64_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<64x32xi8>>

// Advance along dim 0 (not halved): offsets stay unchanged.
%1 = tt.advance %0, [%c64_i32, %c0_i32] : <tensor<64x64xi4>>
// CHECK-NEXT: %1 = tt.advance %0, [%c64_i32, %c0_i32] : <tensor<64x32xi8>>

%2:2 = scf.for %arg2 = %c0_i32 to %c128_i32 step %c64_i32 iter_args(%arg3 = %1, %arg4 = %cst) -> (!tt.ptr<tensor<64x64xi4>>, tensor<64x64xi8>) : i32 {
// CHECK-NEXT: %2:2 = scf.for %arg2 = %c0_i32 to %c128_i32 step %c64_i32 iter_args(%arg3 = %1, %arg4 = %cst_0) -> (!tt.ptr<tensor<64x32xi8>>, tensor<64x64xi8>) : i32 {

%4 = tt.load %arg3 {boundaryCheck = array<i32: 0>, padding = 1 : i32} : !tt.ptr<tensor<64x64xi4>>
// CHECK-NEXT: %4 = tt.load %arg3 {boundaryCheck = array<i32: 0>, padding = 1 : i32} : !tt.ptr<tensor<64x32xi8>>

// Advance along the halved dim 1: the offset is halved (64 -> 32).
%5 = tt.advance %arg3, [%c0_i32, %c64_i32] : <tensor<64x64xi4>>
// CHECK-NEXT: %5 = tt.advance %arg3, [%c0_i32, %c32_i32] : <tensor<64x32xi8>>

// The sign-extension of the i4 tensor becomes the unpack sequence.
%6 = arith.extsi %4 : tensor<64x64xi4> to tensor<64x64xi8>
// CHECK-NEXT: %6 = arith.shli %4, %cst : tensor<64x32xi8>
// CHECK-NEXT: %7 = arith.shrsi %6, %cst : tensor<64x32xi8>
// CHECK-NEXT: %8 = arith.shrsi %4, %cst : tensor<64x32xi8>
// CHECK-NEXT: %9 = tt.join %8, %7 : tensor<64x32xi8> -> tensor<64x32x2xi8>
// CHECK-NEXT: %10 = tt.reshape %9 : tensor<64x32x2xi8> -> tensor<64x64xi8>

scf.yield %5, %6 : !tt.ptr<tensor<64x64xi4>>, tensor<64x64xi8>
// CHECK-NEXT: scf.yield %5, %10 : !tt.ptr<tensor<64x32xi8>>, tensor<64x64xi8>
}
%3 = tt.make_tensor_ptr %arg1, [%c128_i64, %c1_i64], [%c1_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<64x64xi8>>
tt.store %3, %2#1 {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<64x64xi8>>
tt.return
}
}


6 changes: 5 additions & 1 deletion xla/service/gpu/fusions/triton/triton_fusion_emitter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1098,7 +1098,11 @@ absl::StatusOr<mlir::OwningOpRef<mlir::ModuleOp>> CreateTritonModule(
if (type == U16) {
ir_type = b.getI16Type();
} else if (type == S4) {
ir_type = b.getI8Type();
if (debug_options.xla_gpu_experimental_enable_triton_i4_rewrites()) {
ir_type = b.getI4Type();
} else {
ir_type = b.getI8Type();
}
} else {
TF_ASSIGN_OR_RETURN(ir_type, TritonType(b, type));
}
Expand Down
Loading

0 comments on commit 092b8dd

Please sign in to comment.