-
Notifications
You must be signed in to change notification settings - Fork 470
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[XLA:GPU] Introduce xla_gpu_experimental_enable_triton_i4_rewrites, t…
…hat enables the corresponding rewrites for the i4 tensors in triton mlir. The default value is false. The goal of the cl is to move the unpacking logic to the triton level rewrite. As a result the HLO to Triton emitter do not need to take into the account the unpacking logic, could keep using the shapes that match to the actual tensors. etc. The cl: a) adds the flag that enables the triton level rewrites. b) disables int4 support in the triton_fusion_emitter_legacy_matmul if the flag is true. c) changes the mapping from S4 hlo type to triton type. Emitter emits s4 instead of s8 if the flag is true. d) fixes the unpacking logic for the cases where the tensor packed along the minor dim. e) fixes the unpacking logic for the cases when the packed dim actually has only 1 element. f) covers the cases when s4 is the rhs parameter of the dot. PiperOrigin-RevId: 714049078
- Loading branch information
1 parent
5630f58
commit 092b8dd
Showing
12 changed files
with
822 additions
and
145 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
load("//xla:lit.bzl", "lit_test_suite") # @unused | ||
|
||
package( | ||
# copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], | ||
default_visibility = [":friends"], | ||
licenses = ["notice"], | ||
) | ||
|
||
package_group( | ||
name = "friends", | ||
includes = [ | ||
"//xla:friends", | ||
], | ||
) | ||
|
||
# copybara:uncomment_begin(triton-opt tool doesn't build in OSS) | ||
# lit_test_suite( | ||
# name = "mlir_lit_tests", | ||
# srcs = glob(["*.mlir"]), | ||
# cfg = "//xla:lit.cfg.py", | ||
# tools = [ | ||
# "@llvm-project//llvm:FileCheck", | ||
# "//xla/service/gpu/tests:xla-opt", | ||
# ], | ||
# ) | ||
# copybara:uncomment_end |
41 changes: 41 additions & 0 deletions
41
xla/service/gpu/fusions/triton/tests/int4_packed_dim_major_1d.mlir
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
// RUN: xla-opt --int4-to-packed-int4-rewrite --canonicalize -- %s | FileCheck --dump-input=never %s | ||
|
||
module { | ||
tt.func @major_1d(%arg0: !tt.ptr<i4> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<i8> {tt.divisibility = 16 : i32}) { | ||
%c128_i32 = arith.constant 128 : i32 | ||
%c128_i64 = arith.constant 128 : i64 | ||
%c0_i32 = arith.constant 0 : i32 | ||
%c1_i64 = arith.constant 1 : i64 | ||
%c64_i32 = arith.constant 64 : i32 | ||
%cst = arith.constant dense<0> : tensor<64x64xi8> | ||
|
||
%0 = tt.make_tensor_ptr %arg0, [%c1_i64, %c128_i64], [%c1_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>, packed_dim = 1 } : <tensor<64x64xi4>> | ||
// CHECK: %0 = tt.make_tensor_ptr %arg0, [%c1_i64, %c64_i64], [%c1_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<64x32xi8>> | ||
|
||
%1 = tt.advance %0, [%c64_i32, %c0_i32] : <tensor<64x64xi4>> | ||
// CHECK-NEXT: %1 = tt.advance %0, [%c64_i32, %c0_i32] : <tensor<64x32xi8>> | ||
|
||
%2:2 = scf.for %arg2 = %c0_i32 to %c128_i32 step %c64_i32 iter_args(%arg3 = %1, %arg4 = %cst) -> (!tt.ptr<tensor<64x64xi4>>, tensor<64x64xi8>) : i32 { | ||
// CHECK-NEXT: %2:2 = scf.for %arg2 = %c0_i32 to %c128_i32 step %c64_i32 iter_args(%arg3 = %1, %arg4 = %cst_0) -> (!tt.ptr<tensor<64x32xi8>>, tensor<64x64xi8>) : i32 { | ||
|
||
%4 = tt.load %arg3 {boundaryCheck = array<i32: 0>, padding = 1 : i32} : !tt.ptr<tensor<64x64xi4>> | ||
// CHECK-NEXT: %4 = tt.load %arg3 {boundaryCheck = array<i32: 0>, padding = 1 : i32} : !tt.ptr<tensor<64x32xi8>> | ||
|
||
%5 = tt.advance %arg3, [%c0_i32, %c64_i32] : <tensor<64x64xi4>> | ||
// CHECK-NEXT: %5 = tt.advance %arg3, [%c0_i32, %c32_i32] : <tensor<64x32xi8>> | ||
|
||
%6 = arith.extsi %4 : tensor<64x64xi4> to tensor<64x64xi8> | ||
// CHECK-NEXT: %6 = arith.shli %4, %cst : tensor<64x32xi8> | ||
// CHECK-NEXT: %7 = arith.shrsi %6, %cst : tensor<64x32xi8> | ||
// CHECK-NEXT: %8 = arith.shrsi %4, %cst : tensor<64x32xi8> | ||
// CHECK-NEXT: %9 = tt.join %8, %7 : tensor<64x32xi8> -> tensor<64x32x2xi8> | ||
// CHECK-NEXT: %10 = tt.reshape %9 : tensor<64x32x2xi8> -> tensor<64x64xi8> | ||
|
||
scf.yield %5, %6 : !tt.ptr<tensor<64x64xi4>>, tensor<64x64xi8> | ||
// CHECK-NEXT: scf.yield %5, %10 : !tt.ptr<tensor<64x32xi8>>, tensor<64x64xi8> | ||
} | ||
%3 = tt.make_tensor_ptr %arg1, [%c1_i64, %c128_i64], [%c1_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<64x64xi8>> | ||
tt.store %3, %2#1 {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<64x64xi8>> | ||
tt.return | ||
} | ||
} |
44 changes: 44 additions & 0 deletions
44
xla/service/gpu/fusions/triton/tests/int4_packed_dim_major_2d.mlir
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
// RUN: xla-opt --int4-to-packed-int4-rewrite --canonicalize -- %s | FileCheck --dump-input=never %s | ||
|
||
module { | ||
tt.func @major_2d(%arg0: !tt.ptr<i4> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<i8> {tt.divisibility = 16 : i32}) { | ||
%c128_i32 = arith.constant 128 : i32 | ||
%c128_i64 = arith.constant 128 : i64 | ||
%c16_i64 = arith.constant 16 : i64 | ||
%c0_i32 = arith.constant 0 : i32 | ||
%c1_i64 = arith.constant 1 : i64 | ||
%c64_i32 = arith.constant 64 : i32 | ||
%cst = arith.constant dense<0> : tensor<64x64xi8> | ||
|
||
%0 = tt.make_tensor_ptr %arg0, [%c16_i64, %c128_i64], [%c1_i64, %c16_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>, packed_dim = 1 } : <tensor<64x64xi4>> | ||
// CHECK: %0 = tt.make_tensor_ptr %arg0, [%c8_i64, %c128_i64], [%c1_i64, %c8_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<32x64xi8>> | ||
|
||
%1 = tt.advance %0, [%c64_i32, %c0_i32] : <tensor<64x64xi4>> | ||
// CHECK-NEXT: %1 = tt.advance %0, [%c32_i32, %c0_i32] : <tensor<32x64xi8>> | ||
|
||
%2:2 = scf.for %arg2 = %c0_i32 to %c128_i32 step %c64_i32 iter_args(%arg3 = %1, %arg4 = %cst) -> (!tt.ptr<tensor<64x64xi4>>, tensor<64x64xi8>) : i32 { | ||
// CHECK-NEXT: %2:2 = scf.for %arg2 = %c0_i32 to %c128_i32 step %c64_i32 iter_args(%arg3 = %1, %arg4 = %cst_0) -> (!tt.ptr<tensor<32x64xi8>>, tensor<64x64xi8>) : i32 { | ||
|
||
%4 = tt.load %arg3 {boundaryCheck = array<i32: 0>, padding = 1 : i32} : !tt.ptr<tensor<64x64xi4>> | ||
// CHECK-NEXT: %4 = tt.load %arg3 {boundaryCheck = array<i32: 0>, padding = 1 : i32} : !tt.ptr<tensor<32x64xi8>> | ||
|
||
%5 = tt.advance %arg3, [%c0_i32, %c64_i32] : <tensor<64x64xi4>> | ||
// CHECK-NEXT: %5 = tt.advance %arg3, [%c0_i32, %c64_i32] : <tensor<32x64xi8>> | ||
|
||
%6 = arith.extsi %4 : tensor<64x64xi4> to tensor<64x64xi8> | ||
// CHECK-NEXT: %6 = arith.shli %4, %cst : tensor<32x64xi8> | ||
// CHECK-NEXT: %7 = arith.shrsi %6, %cst : tensor<32x64xi8> | ||
// CHECK-NEXT: %8 = arith.shrsi %4, %cst : tensor<32x64xi8> | ||
// CHECK-NEXT: %9 = tt.join %8, %7 : tensor<32x64xi8> -> tensor<32x64x2xi8> | ||
// CHECK-NEXT: %10 = tt.trans %9 {order = array<i32: 0, 2, 1>} : tensor<32x64x2xi8> -> tensor<32x2x64xi8> | ||
// CHECK-NEXT: %11 = tt.reshape %10 : tensor<32x2x64xi8> -> tensor<64x64xi8> | ||
|
||
scf.yield %5, %6 : !tt.ptr<tensor<64x64xi4>>, tensor<64x64xi8> | ||
// CHECK-NEXT: scf.yield %5, %11 : !tt.ptr<tensor<32x64xi8>>, tensor<64x64xi8> | ||
} | ||
%3 = tt.make_tensor_ptr %arg1, [%c128_i64, %c1_i64], [%c1_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<64x64xi8>> | ||
tt.store %3, %2#1 {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<64x64xi8>> | ||
tt.return | ||
} | ||
} | ||
|
44 changes: 44 additions & 0 deletions
44
xla/service/gpu/fusions/triton/tests/int4_packed_dim_minor_1d.mlir
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
// RUN: xla-opt --int4-to-packed-int4-rewrite --canonicalize -- %s | FileCheck --dump-input=never %s | ||
|
||
module { | ||
tt.func @minor_1d(%arg0: !tt.ptr<i4> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<i8> {tt.divisibility = 16 : i32}) { | ||
%c128_i32 = arith.constant 128 : i32 | ||
%c1_i64 = arith.constant 1 : i64 | ||
%c0_i32 = arith.constant 0 : i32 | ||
%c128_i64 = arith.constant 128 : i64 | ||
%c64_i32 = arith.constant 64 : i32 | ||
%cst = arith.constant dense<0> : tensor<64x64xi8> | ||
|
||
%0 = tt.make_tensor_ptr %arg0, [%c128_i64, %c1_i64], [%c1_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>, packed_dim = 0 } : <tensor<64x64xi4>> | ||
// CHECK: %0 = tt.make_tensor_ptr %arg0, [%c64_i64, %c1_i64], [%c1_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<32x64xi8>> | ||
|
||
%1 = tt.advance %0, [%c0_i32, %c64_i32] : <tensor<64x64xi4>> | ||
// CHECK-NEXT: %1 = tt.advance %0, [%c0_i32, %c64_i32] : <tensor<32x64xi8>> | ||
|
||
%2:2 = scf.for %arg2 = %c0_i32 to %c128_i32 step %c64_i32 iter_args(%arg3 = %1, %arg4 = %cst) -> (!tt.ptr<tensor<64x64xi4>>, tensor<64x64xi8>) : i32 { | ||
// CHECK-NEXT: %2:2 = scf.for %arg2 = %c0_i32 to %c128_i32 step %c64_i32 iter_args(%arg3 = %1, %arg4 = %cst_0) -> (!tt.ptr<tensor<32x64xi8>>, tensor<64x64xi8>) : i32 { | ||
|
||
%4 = tt.load %arg3 {boundaryCheck = array<i32: 1>, padding = 1 : i32} : !tt.ptr<tensor<64x64xi4>> | ||
// CHECK-NEXT: %4 = tt.load %arg3 {boundaryCheck = array<i32: 1>, padding = 1 : i32} : !tt.ptr<tensor<32x64xi8>> | ||
|
||
%5 = tt.advance %arg3, [%c64_i32, %c0_i32] : <tensor<64x64xi4>> | ||
// CHECK-NEXT: %5 = tt.advance %arg3, [%c32_i32, %c0_i32] : <tensor<32x64xi8>> | ||
|
||
%6 = arith.extsi %4 : tensor<64x64xi4> to tensor<64x64xi8> | ||
// CHECK-NEXT: %6 = arith.shli %4, %cst : tensor<32x64xi8> | ||
// CHECK-NEXT: %7 = arith.shrsi %6, %cst : tensor<32x64xi8> | ||
// CHECK-NEXT: %8 = arith.shrsi %4, %cst : tensor<32x64xi8> | ||
// CHECK-NEXT: %9 = tt.join %8, %7 : tensor<32x64xi8> -> tensor<32x64x2xi8> | ||
// CHECK-NEXT: %10 = tt.trans %9 {order = array<i32: 0, 2, 1>} : tensor<32x64x2xi8> -> tensor<32x2x64xi8> | ||
// CHECK-NEXT: %11 = tt.reshape %10 : tensor<32x2x64xi8> -> tensor<64x64xi8> | ||
|
||
scf.yield %5, %6 : !tt.ptr<tensor<64x64xi4>>, tensor<64x64xi8> | ||
// CHECK-NEXT: scf.yield %5, %11 : !tt.ptr<tensor<32x64xi8>>, tensor<64x64xi8> | ||
} | ||
%3 = tt.make_tensor_ptr %arg1, [%c128_i64, %c1_i64], [%c1_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<64x64xi8>> | ||
tt.store %3, %2#1 {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<64x64xi8>> | ||
tt.return | ||
} | ||
} | ||
|
||
|
43 changes: 43 additions & 0 deletions
43
xla/service/gpu/fusions/triton/tests/int4_packed_dim_minor_2d.mlir
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
// RUN: xla-opt --int4-to-packed-int4-rewrite --canonicalize -- %s | FileCheck --dump-input=never %s | ||
|
||
module { | ||
tt.func @minor_2d(%arg0: !tt.ptr<i4> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<i8> {tt.divisibility = 16 : i32}) { | ||
%c128_i32 = arith.constant 128 : i32 | ||
%c1_i64 = arith.constant 1 : i64 | ||
%c0_i32 = arith.constant 0 : i32 | ||
%c128_i64 = arith.constant 128 : i64 | ||
%c64_i32 = arith.constant 64 : i32 | ||
%cst = arith.constant dense<0> : tensor<64x64xi8> | ||
|
||
%0 = tt.make_tensor_ptr %arg0, [%c1_i64, %c128_i64], [%c128_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>, packed_dim = 0 } : <tensor<64x64xi4>> | ||
// CHECK: %0 = tt.make_tensor_ptr %arg0, [%c1_i64, %c64_i64], [%c64_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<64x32xi8>> | ||
|
||
%1 = tt.advance %0, [%c64_i32, %c0_i32] : <tensor<64x64xi4>> | ||
// CHECK-NEXT: %1 = tt.advance %0, [%c64_i32, %c0_i32] : <tensor<64x32xi8>> | ||
|
||
%2:2 = scf.for %arg2 = %c0_i32 to %c128_i32 step %c64_i32 iter_args(%arg3 = %1, %arg4 = %cst) -> (!tt.ptr<tensor<64x64xi4>>, tensor<64x64xi8>) : i32 { | ||
// CHECK-NEXT: %2:2 = scf.for %arg2 = %c0_i32 to %c128_i32 step %c64_i32 iter_args(%arg3 = %1, %arg4 = %cst_0) -> (!tt.ptr<tensor<64x32xi8>>, tensor<64x64xi8>) : i32 { | ||
|
||
%4 = tt.load %arg3 {boundaryCheck = array<i32: 0>, padding = 1 : i32} : !tt.ptr<tensor<64x64xi4>> | ||
// CHECK-NEXT: %4 = tt.load %arg3 {boundaryCheck = array<i32: 0>, padding = 1 : i32} : !tt.ptr<tensor<64x32xi8>> | ||
|
||
%5 = tt.advance %arg3, [%c0_i32, %c64_i32] : <tensor<64x64xi4>> | ||
// CHECK-NEXT: %5 = tt.advance %arg3, [%c0_i32, %c32_i32] : <tensor<64x32xi8>> | ||
|
||
%6 = arith.extsi %4 : tensor<64x64xi4> to tensor<64x64xi8> | ||
// CHECK-NEXT: %6 = arith.shli %4, %cst : tensor<64x32xi8> | ||
// CHECK-NEXT: %7 = arith.shrsi %6, %cst : tensor<64x32xi8> | ||
// CHECK-NEXT: %8 = arith.shrsi %4, %cst : tensor<64x32xi8> | ||
// CHECK-NEXT: %9 = tt.join %8, %7 : tensor<64x32xi8> -> tensor<64x32x2xi8> | ||
// CHECK-NEXT: %10 = tt.reshape %9 : tensor<64x32x2xi8> -> tensor<64x64xi8> | ||
|
||
scf.yield %5, %6 : !tt.ptr<tensor<64x64xi4>>, tensor<64x64xi8> | ||
// CHECK-NEXT: scf.yield %5, %10 : !tt.ptr<tensor<64x32xi8>>, tensor<64x64xi8> | ||
} | ||
%3 = tt.make_tensor_ptr %arg1, [%c128_i64, %c1_i64], [%c1_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<64x64xi8>> | ||
tt.store %3, %2#1 {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<64x64xi8>> | ||
tt.return | ||
} | ||
} | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.