Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Stream] Add layouts to encodings for all stream tensor AffinityOp. #19726

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,22 @@ updateTensorSizeOfOp(RewriterBase &rewriter,
return success();
}

/// Updates the target encoding of `op` with resolved layouts.
static LogicalResult
updateTensorFillOp(RewriterBase &rewriter, IREE::Stream::TensorFillOp op,
                   const SetVector<Attribute> &layoutResolvers) {
  auto targetType = dyn_cast<RankedTensorType>(op.getTargetEncoding());
  // getEncodingWithNewLayouts returns std::nullopt when there is no encoding
  // to update; that is not an error, so report success and leave `op` alone.
  std::optional<IREE::Encoding::EncodingAttr> updatedEncoding =
      getEncodingWithNewLayouts(targetType, layoutResolvers);
  if (updatedEncoding.has_value()) {
    rewriter.modifyOpInPlace(op, [&] {
      op.setTargetEncoding(cloneWithEncoding(targetType, *updatedEncoding));
    });
  }
  return success();
}

/// Returns failure if `op` has encoding. The EncodingAttr has padding
/// semantic, a constant op with such encoding can not be resolved at this
/// moment.
Expand All @@ -123,7 +139,70 @@ updateTensorConstantOp(RewriterBase &rewriter,
return success();
}

/// Updates the result_encoding for `op`. The op have to define a
/// Returns a failure if there are encodings in target encoding type or update
/// encoding type.
static LogicalResult updateTensorUpdateOp(RewriterBase &rewriter,
                                          IREE::Stream::TensorUpdateOp op) {
  // Reject the op when either side of the update carries an encoding; such
  // cases are not handled by this pass.
  for (Type type : {op.getTargetEncoding(), op.getUpdateEncoding()}) {
    auto rankedType = dyn_cast<RankedTensorType>(type);
    if (rankedType && rankedType.getEncoding()) {
      return failure();
    }
  }
  return success();
}

/// Returns a failure if there are encodings in source encoding type or result
/// encoding type.
static LogicalResult updateTensorCloneOp(RewriterBase &rewriter,
                                         IREE::Stream::TensorCloneOp op) {
  // True when `type` is a ranked tensor that carries an encoding attribute.
  const auto hasEncoding = [](Type type) -> bool {
    auto rankedType = dyn_cast<RankedTensorType>(type);
    return rankedType && rankedType.getEncoding();
  };
  // Clone ops with encodings on either side are not supported by this pass.
  if (hasEncoding(op.getSourceEncoding()) ||
      hasEncoding(op.getResultEncoding())) {
    return failure();
  }
  return success();
}

/// Returns a failure if there are encodings in source encoding type or result
/// encoding type.
static LogicalResult updateTensorSliceOp(RewriterBase &rewriter,
                                         IREE::Stream::TensorSliceOp op) {
  auto sourceType = dyn_cast<RankedTensorType>(op.getSourceEncoding());
  auto resultType = dyn_cast<RankedTensorType>(op.getResultEncoding());
  const bool sourceHasEncoding = sourceType && sourceType.getEncoding();
  const bool resultHasEncoding = resultType && resultType.getEncoding();
  // Slice ops with encodings on either side are not supported by this pass.
  return success(!sourceHasEncoding && !resultHasEncoding);
}

/// Updates the source_encoding for `op`. The op has to define a
/// `source_encoding` parameter.
template <typename OpTy>
static LogicalResult
updateSourceEncoding(RewriterBase &rewriter, OpTy op,
                     const SetVector<Attribute> &layoutResolvers) {
  auto sourceType = dyn_cast<RankedTensorType>(op.getSourceEncoding());
  // std::nullopt means there is no encoding to update; succeed without
  // touching the op.
  std::optional<IREE::Encoding::EncodingAttr> updatedEncoding =
      getEncodingWithNewLayouts(sourceType, layoutResolvers);
  if (updatedEncoding.has_value()) {
    rewriter.modifyOpInPlace(op, [&] {
      op.setSourceEncoding(cloneWithEncoding(sourceType, *updatedEncoding));
    });
  }
  return success();
}

/// Updates the result_encoding for `op`. The op has to define a
/// `result_encoding` parameter.
template <typename OpTy>
static LogicalResult
Expand All @@ -141,6 +220,16 @@ updateResultEncoding(RewriterBase &rewriter, OpTy op,
return success();
}

/// Adds the resolved layouts to all tensor types on stream tensor ops, if
/// encodings are present. Most of stream tensor ops implement
/// AffinityOpInterface, where a stream affinity indicates the kind of
/// environment the ops are expected to run in. When an encoding is present in the
/// tensor type, the method resolves the layouts, strips outdated information,
/// and adds the resolved layouts to the encodings. The updated encodings should
/// have enough information for other lowering transformations.
/// TODO(hanchung): Add support for stream.tensor.load ops and
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

there's some TBD work to make load/store better - it'll likely require the same behavior as with the current implementation and is something we'll want to do earlier on in flow as otherwise we'll need a stream builtin that can adjust to different data types and perform the conversion (for multiple elements) or something like a switch that translates the load/store indices (for single element) on the host.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there an issue ID that I can link with? Otherwise, I'd just put my handle to the TODO. :)

/// stream.tensor.store ops. They are not affinity ops, so additional analysis
/// will be needed in the work.
static LogicalResult addLayoutsToTensorPhaseOps(
ModuleOp moduleOp, FunctionOpInterface funcOp,
IREE::Stream::ResolveLayoutAttrFn resolveLayoutAttr) {
Expand Down Expand Up @@ -171,7 +260,6 @@ static LogicalResult addLayoutsToTensorPhaseOps(
return affinityOp.emitError("failed on making layout resolvers");
}

// TODO(hanchung): Update other Stream operations.
LogicalResult result =
TypeSwitch<Operation *, LogicalResult>(affinityOp)
.Case<IREE::Stream::TensorSizeOfOp>([&](auto op) {
Expand All @@ -184,6 +272,15 @@ static LogicalResult addLayoutsToTensorPhaseOps(
.Case<IREE::Stream::TensorConstantOp>([&](auto op) {
return updateTensorConstantOp(rewriter, op, layoutResolvers);
})
.Case<IREE::Stream::TensorFillOp>([&](auto op) {
return updateTensorFillOp(rewriter, op, layoutResolvers);
})
.Case<IREE::Stream::TensorCloneOp>(
[&](auto op) { return updateTensorCloneOp(rewriter, op); })
.Case<IREE::Stream::TensorSliceOp>(
[&](auto op) { return updateTensorSliceOp(rewriter, op); })
.Case<IREE::Stream::TensorUpdateOp>(
[&](auto op) { return updateTensorUpdateOp(rewriter, op); })
.Default([](auto *op) { return failure(); });

if (failed(result)) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,35 @@ module {

// -----

// Checks that the stream.tensor.fill op gets resolved layouts added to the
// encoding of its target tensor type.

#map0 = affine_map<(m, n, k) -> (m, k)>
#map1 = affine_map<(m, n, k) -> (k, n)>
#map2 = affine_map<(m, n, k) -> (m, n)>
#executable_target_vmvx_bytecode_fb = #hal.executable.target<"vmvx", "vmvx-bytecode-fb", {encoding = #iree_cpu.vmvx_encoding_layout<>}>
#device_target_local_0_ = #hal.device.target<"local", {ordinal = 0 : index}, [#executable_target_vmvx_bytecode_fb]> : !hal.device
#encoding = #iree_encoding.encoding<operand_index = 0 : index, op_type = matmul, element_types = [f32, f32, f32], user_indexing_maps = [#map0, #map1, #map2]>
module {
util.global private @device_a = #device_target_local_0_

util.func public @tensor_fill_op(%arg0: f32, %arg1: !stream.resource<*>, %arg2: index, %arg3: index) {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%0 = stream.tensor.fill on(#hal.device.affinity<@device_a>)
%arg0, %arg1[%c0, %c0 for %c1, %c1] : f32
-> tensor<?x4xf32, #encoding>{%arg2} in %arg1 as !stream.resource<*>{%arg3}
util.return
}
}
// CHECK: #[[$ENCODING:.+]] = #iree_encoding.encoding
// CHECK-SAME: #iree_cpu.vmvx_encoding_layout
// CHECK-SAME: encoding_info = {innerDimsPos = [{{.+}}], innerTileSizes = [{{.+}}], outerDimsPerm = [{{.+}}]}
// CHECK: #[[TARGET:.+]] = #hal.device.target
// CHECK: util.global private @[[$DEVICE:.+]] = #[[TARGET]]
// CHECK-LABEL: util.func public @tensor_fill_op
// CHECK: stream.tensor.fill on(#hal.device.affinity<@[[$DEVICE]]>)
// CHECK-SAME: f32 -> tensor<?x4xf32, #[[$ENCODING]]>

// -----

// Checks that the stream.tensor.constant op with encoding is not supported.

#map0 = affine_map<(m, n, k) -> (m, k)>
Expand All @@ -82,3 +111,73 @@ module {
util.return
}
}

// -----

// Checks that the stream.tensor.clone op with encoding is not supported.
// The pass is expected to fail with the error below because clone ops that
// carry encodings are rejected rather than updated.

#map0 = affine_map<(m, n, k) -> (m, k)>
#map1 = affine_map<(m, n, k) -> (k, n)>
#map2 = affine_map<(m, n, k) -> (m, n)>
#executable_target_vmvx_bytecode_fb = #hal.executable.target<"vmvx", "vmvx-bytecode-fb", {encoding = #iree_cpu.vmvx_encoding_layout<>}>
#device_target_local_0_ = #hal.device.target<"local", {ordinal = 0 : index}, [#executable_target_vmvx_bytecode_fb]> : !hal.device
#encoding = #iree_encoding.encoding<operand_index = 0 : index, op_type = matmul, element_types = [f32, f32, f32], user_indexing_maps = [#map0, #map1, #map2]>
module {
util.global private @device_a = #device_target_local_0_

// expected-error @+1 {{failed on adding layouts to Stream::TensorPhaseOp with encodings}}
util.func public @tensor_clone_op(%arg0: !stream.resource<*>, %arg1: index, %arg2: index, %arg3: index, %arg4: index) {
%0 = stream.tensor.clone on(#hal.device.affinity<@device_a>)
%arg0 : tensor<?x4xf32, #encoding>{%arg1} in !stream.resource<*>{%arg2}
-> tensor<?x4xf32, #encoding>{%arg1} in !stream.resource<*>{%arg2}
util.return
}
}

// -----

// Checks that the stream.tensor.slice op with encoding is not supported.
// The pass is expected to fail with the error below because slice ops that
// carry encodings are rejected rather than updated.

#map0 = affine_map<(m, n, k) -> (m, k)>
#map1 = affine_map<(m, n, k) -> (k, n)>
#map2 = affine_map<(m, n, k) -> (m, n)>
#executable_target_vmvx_bytecode_fb = #hal.executable.target<"vmvx", "vmvx-bytecode-fb", {encoding = #iree_cpu.vmvx_encoding_layout<>}>
#device_target_local_0_ = #hal.device.target<"local", {ordinal = 0 : index}, [#executable_target_vmvx_bytecode_fb]> : !hal.device
#encoding = #iree_encoding.encoding<operand_index = 0 : index, op_type = matmul, element_types = [f32, f32, f32], user_indexing_maps = [#map0, #map1, #map2]>
module {
util.global private @device_a = #device_target_local_0_

// expected-error @+1 {{failed on adding layouts to Stream::TensorPhaseOp with encodings}}
util.func public @tensor_slice_op_with_encoding(%arg0: !stream.resource<*>, %arg1: index, %arg2: index, %arg3: index, %arg4: index) {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%1 = stream.tensor.slice on(#hal.device.affinity<@device_a>)
%arg0[%c0, %c1 for %arg3, %c1] : tensor<?x4xf32, #encoding>{%arg1} in !stream.resource<*>{%arg2}
-> tensor<?x1xf32, #encoding>{%arg3} in !stream.resource<*>{%arg4}
util.return
}
}

// -----

// Checks that the stream.tensor.update op with encoding is not supported.
// The pass is expected to fail with the error below because update ops that
// carry encodings are rejected rather than updated.

#map0 = affine_map<(m, n, k) -> (m, k)>
#map1 = affine_map<(m, n, k) -> (k, n)>
#map2 = affine_map<(m, n, k) -> (m, n)>
#executable_target_vmvx_bytecode_fb = #hal.executable.target<"vmvx", "vmvx-bytecode-fb", {encoding = #iree_cpu.vmvx_encoding_layout<>}>
#device_target_local_0_ = #hal.device.target<"local", {ordinal = 0 : index}, [#executable_target_vmvx_bytecode_fb]> : !hal.device
#encoding = #iree_encoding.encoding<operand_index = 0 : index, op_type = matmul, element_types = [f32, f32, f32], user_indexing_maps = [#map0, #map1, #map2]>
module {
util.global private @device_a = #device_target_local_0_

// expected-error @+1 {{failed on adding layouts to Stream::TensorPhaseOp with encodings}}
util.func public @tensor_update_op(%arg0: !stream.resource<*>, %arg1: index, %arg2: !stream.resource<*>, %arg3: index, %arg4: index) {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%0 = stream.tensor.update on(#hal.device.affinity<@device_a>)
%arg0, %arg2[%c0, %c0] : tensor<2x2xf32, #encoding> in !stream.resource<*>{%arg1}
-> tensor<?x4xf32, #encoding>{%arg3} in %arg2 as !stream.resource<*>{%arg4}
util.return
}
}
Loading