Skip to content

Commit

Permalink
[HAL] Use util.assume.int for memref alignments (#19691)
Browse files Browse the repository at this point in the history
When bufferizing, use util.assume.int to construct
memref.assume_alignment, since we can use the divisibility information in
those assumptions to constrain the subspan offset.
  • Loading branch information
krzysz00 authored Jan 17, 2025
1 parent 4d3f06a commit b08d152
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -76,23 +76,25 @@ func.func @matmul() {

// -----

#pipeline_layout = #hal.pipeline.layout<constants = 4, bindings = [
#pipeline_layout = #hal.pipeline.layout<constants = 5, bindings = [
#hal.pipeline.binding<storage_buffer>,
#hal.pipeline.binding<storage_buffer>,
#hal.pipeline.binding<storage_buffer>
]>
func.func @matmul_fill() {
%cst = arith.constant 0.0 : f32
%c0 = arith.constant 0 : index
%c1024 = arith.constant 1024 : index
%m = hal.interface.constant.load layout(#pipeline_layout) ordinal(0) : index
%n = hal.interface.constant.load layout(#pipeline_layout) ordinal(1) : index
%k = hal.interface.constant.load layout(#pipeline_layout) ordinal(2) : index
%base_offset_i32 = hal.interface.constant.load layout(#pipeline_layout) ordinal(3) alignment(8) : i32
%base_offset = arith.index_castui %base_offset_i32 : i32 to index
%res_offset_i32 = hal.interface.constant.load layout(#pipeline_layout) ordinal(4) : i32
%res_offset_index = arith.index_castui %res_offset_i32 : i32 to index
%res_offset = util.assume.int %res_offset_index[<umin = 0, umax = 0>, <umin = 128, umax = 128, udiv = 128>, <umin = 1024, umax = 1024, udiv = 1024>] : index
%lhs = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) alignment(32) : !flow.dispatch.tensor<readonly:tensor<?x?xf32>>{%m, %k}
%rhs = hal.interface.binding.subspan layout(#pipeline_layout) binding(1) alignment(64) offset(%base_offset) : !flow.dispatch.tensor<readonly:tensor<?x?xf32>>{%k, %n}
%result = hal.interface.binding.subspan layout(#pipeline_layout) binding(2) alignment(64) offset(%c1024) : !flow.dispatch.tensor<readwrite:tensor<?x?xf32>>{%m, %n}
%result = hal.interface.binding.subspan layout(#pipeline_layout) binding(2) alignment(64) offset(%res_offset) : !flow.dispatch.tensor<readwrite:tensor<?x?xf32>>{%m, %n}
%wg_id_y = hal.interface.workgroup.id[1] : index
%wg_count_y = hal.interface.workgroup.count[1] : index
%wg_size_y = hal.interface.workgroup.size[1] : index
Expand Down Expand Up @@ -127,11 +129,14 @@ func.func @matmul_fill() {
// CHECK-DAG: %[[K:.+]] = hal.interface.constant.load layout({{.+}}) ordinal(2)
// CHECK-DAG: %[[BASE_OFFSET_I32:.+]] = hal.interface.constant.load layout({{.+}}) ordinal(3)
// CHECK-DAG: %[[BASE_OFFSET:.+]] = arith.index_castui %[[BASE_OFFSET_I32]]
// CHECK-DAG: %[[RES_OFFSET_I32:.+]] = hal.interface.constant.load layout({{.+}}) ordinal(4)
// CHECK-DAG: %[[RES_OFFSET_INDEX:.+]] = arith.index_castui %[[RES_OFFSET_I32]]
// CHECK-DAG: %[[RES_OFFSET:.+]] = util.assume.int %[[RES_OFFSET_INDEX]]
// CHECK-DAG: %[[LHS:.+]] = hal.interface.binding.subspan layout({{.+}}) binding(0) alignment(32)
// CHECK-DAG: memref.assume_alignment %[[LHS]], 32
// CHECK-DAG: %[[RHS:.+]] = hal.interface.binding.subspan layout({{.+}}) binding(1) alignment(64) offset(%[[BASE_OFFSET]])
// CHECK-DAG: memref.assume_alignment %[[RHS]], 8
// CHECK-DAG: %[[RESULT:.+]] = hal.interface.binding.subspan layout({{.+}}) binding(2) alignment(64) offset(%c1024)
// CHECK-DAG: %[[RESULT:.+]] = hal.interface.binding.subspan layout({{.+}}) binding(2) alignment(64) offset(%[[RES_OFFSET]])
// CHECK-DAG: memref.assume_alignment %[[RESULT]], 64
// CHECK-DAG: %[[WG_ID_Y:.+]] = hal.interface.workgroup.id[1]
// CHECK-DAG: %[[WG_COUNT_Y:.+]] = hal.interface.workgroup.count[1]
Expand Down
3 changes: 3 additions & 0 deletions compiler/src/iree/compiler/Dialect/HAL/IR/HALTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,9 @@ std::optional<uint64_t> lookupOffsetOrAlignment(Value value) {
}
} else if (auto castOp = dyn_cast<arith::IndexCastUIOp>(op)) {
return lookupOffsetOrAlignment(castOp.getOperand());
} else if (auto assumeOp = dyn_cast<IREE::Util::AssumeIntOp>(op)) {
return assumeOp.getUnionedUnsignedDivisor(
cast<OpResult>(value).getResultNumber());
}

// TODO(benvanik): more searching using util.align and other ops.
Expand Down

0 comments on commit b08d152

Please sign in to comment.