From 02d145e5a948283df3cd30289fa68fc9ceb602cb Mon Sep 17 00:00:00 2001 From: Han-Chung Wang Date: Wed, 8 Jan 2025 19:39:49 -0800 Subject: [PATCH] [Stream] Implement SpecializeEncodings pass (1/n) (#19502) There are three major changes in the revision: - Introduce the `AffinityAnalysisDialectInterface` Stream dialect interface. It is used to fetch attributes that are defined by other dialects. In the revision, HAL implements the dialect interface, and it can return whatever attributes are attached to the HAL::ExecutableTarget attributes. The main idea of the dialect interface is that Stream **does not** need to depend on HAL to get the layout information. - Add a `cloneWithLayouts` method to the EncodingAttr. It is used in the encoding specialization pass where it can resolve the layout requirements and add them to the `layouts` field. The other optional parameters are dropped because the layout is already resolved. It can be a new Encoding dialect attribute because it is just describing the layout. The stream tensor ops do not need to know the `op_type`, `element_types` and `operand_index` parameters. They only need the layout information, and the attribute should implement the interface method. - Partially implement the SpecializeEncodings pass. The responsibility of the pass is large, so I decided to implement it incrementally. This revision only implements the mechanism of updating stream tensor ops' encodings, and only the stream.tensor.sizeof op is supported. Support for the other stream tensor ops can be added later on. The executable duplication and the update of dispatch ops will be implemented in subsequent PRs. 
--------- Signed-off-by: hanhanW --- .../Dialect/Encoding/IR/EncodingAttrs.cpp | 10 ++ .../Dialect/Encoding/IR/EncodingAttrs.td | 4 + .../iree/compiler/Dialect/HAL/IR/BUILD.bazel | 2 + .../compiler/Dialect/HAL/IR/CMakeLists.txt | 2 + .../compiler/Dialect/HAL/IR/HALDialect.cpp | 27 +++ .../compiler/Dialect/Stream/IR/BUILD.bazel | 1 + .../compiler/Dialect/Stream/IR/CMakeLists.txt | 1 + .../Dialect/Stream/IR/StreamInterfaces.h | 36 ++++ .../Dialect/Stream/Transforms/BUILD.bazel | 1 + .../Dialect/Stream/Transforms/CMakeLists.txt | 1 + .../Dialect/Stream/Transforms/Passes.cpp | 13 ++ .../Dialect/Stream/Transforms/Passes.td | 10 ++ .../Stream/Transforms/SpecializeEncodings.cpp | 169 ++++++++++++++++++ .../Stream/Transforms/test/BUILD.bazel | 1 + .../Stream/Transforms/test/CMakeLists.txt | 1 + .../Transforms/test/specialize_encodings.mlir | 24 +++ 16 files changed, 303 insertions(+) create mode 100644 compiler/src/iree/compiler/Dialect/Stream/IR/StreamInterfaces.h create mode 100644 compiler/src/iree/compiler/Dialect/Stream/Transforms/SpecializeEncodings.cpp create mode 100644 compiler/src/iree/compiler/Dialect/Stream/Transforms/test/specialize_encodings.mlir diff --git a/compiler/src/iree/compiler/Dialect/Encoding/IR/EncodingAttrs.cpp b/compiler/src/iree/compiler/Dialect/Encoding/IR/EncodingAttrs.cpp index 593d9b8fc5c6..b388b9ceb9b5 100644 --- a/compiler/src/iree/compiler/Dialect/Encoding/IR/EncodingAttrs.cpp +++ b/compiler/src/iree/compiler/Dialect/Encoding/IR/EncodingAttrs.cpp @@ -15,6 +15,7 @@ #include "mlir/Dialect/Utils/StructuredOpsUtils.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/TypeUtilities.h" #include "mlir/Interfaces/InferTypeOpInterface.h" #include "mlir/Support/LLVM.h" @@ -113,6 +114,15 @@ EncodingAttr EncodingAttr::clone(AffineMap bcastMap) { AffineMapAttr::get(bcastMap), getRoundDimsTo(), getLayouts()); } +EncodingAttr EncodingAttr::cloneWithLayouts(ArrayRef layouts) { + 
MLIRContext *ctx = getContext(); + return get(ctx, getOperandIndex(), getOpType(), getElementTypes(), + /*user_indexing_maps=*/ArrayAttr(), + /*bcast_map=*/AffineMapAttr(), + /*round_dims_to=*/DenseI64ArrayAttr(), + ArrayAttr::get(ctx, layouts)); +} + /// Returns the bit-width of the scalar type. If the type is complex, it returns /// the type of individual elements * 2 (1 for real and 1 for complex). static unsigned getTypeBitWidth(Type type) { diff --git a/compiler/src/iree/compiler/Dialect/Encoding/IR/EncodingAttrs.td b/compiler/src/iree/compiler/Dialect/Encoding/IR/EncodingAttrs.td index 54829b68e2cf..434356a7e66c 100644 --- a/compiler/src/iree/compiler/Dialect/Encoding/IR/EncodingAttrs.td +++ b/compiler/src/iree/compiler/Dialect/Encoding/IR/EncodingAttrs.td @@ -113,6 +113,10 @@ def EncodingAttr : /// Clones an encoding with a new bcast_map EncodingAttr clone(AffineMap bcastMap); + + /// Clones an encoding with a new layout list and drops other optional + /// parameters (because they are resolved). 
+ EncodingAttr cloneWithLayouts(ArrayRef layouts); }]; let genVerifyDecl = 0; diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/BUILD.bazel b/compiler/src/iree/compiler/Dialect/HAL/IR/BUILD.bazel index 3f80245bfc8c..576e77dc14ea 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/IR/BUILD.bazel +++ b/compiler/src/iree/compiler/Dialect/HAL/IR/BUILD.bazel @@ -101,7 +101,9 @@ iree_compiler_cc_library( deps = [ ":IR", "//compiler/src/iree/compiler/Dialect/HAL:hal_imports", + "//compiler/src/iree/compiler/Dialect/HAL/Analysis", "//compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM", + "//compiler/src/iree/compiler/Dialect/Stream/IR", "//compiler/src/iree/compiler/Dialect/Util/IR", "//compiler/src/iree/compiler/Dialect/VM/Conversion", "@llvm-project//llvm:Support", diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/HAL/IR/CMakeLists.txt index 846bcf0d38a2..e0b68bdf56b9 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/IR/CMakeLists.txt +++ b/compiler/src/iree/compiler/Dialect/HAL/IR/CMakeLists.txt @@ -79,8 +79,10 @@ iree_cc_library( MLIRParser MLIRSCFDialect MLIRTransformUtils + iree::compiler::Dialect::HAL::Analysis iree::compiler::Dialect::HAL::Conversion::HALToVM iree::compiler::Dialect::HAL::hal_imports + iree::compiler::Dialect::Stream::IR iree::compiler::Dialect::Util::IR iree::compiler::Dialect::VM::Conversion PUBLIC diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/HALDialect.cpp b/compiler/src/iree/compiler/Dialect/HAL/IR/HALDialect.cpp index 00c2c6ebebc8..e28d08fb9a89 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/IR/HALDialect.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/IR/HALDialect.cpp @@ -6,13 +6,16 @@ #include "iree/compiler/Dialect/HAL/IR/HALDialect.h" +#include "iree/compiler/Dialect/HAL/Analysis/DeviceAnalysis.h" #include "iree/compiler/Dialect/HAL/Conversion/HALToVM/Patterns.h" #include "iree/compiler/Dialect/HAL/IR/HALOps.h" #include 
"iree/compiler/Dialect/HAL/IR/HALTypes.h" #include "iree/compiler/Dialect/HAL/hal.imports.h" +#include "iree/compiler/Dialect/Stream/IR/StreamInterfaces.h" #include "iree/compiler/Dialect/Util/IR/UtilDialect.h" #include "iree/compiler/Dialect/VM/Conversion/ConversionDialectInterface.h" #include "llvm/ADT/TypeSwitch.h" +#include "llvm/Support/LogicalResult.h" #include "llvm/Support/SourceMgr.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/ControlFlow/IR/ControlFlow.h" @@ -115,6 +118,29 @@ class HALToVMConversionInterface : public VMConversionDialectInterface { } }; +class HALAffinityAnalysisDialectInterface + : public IREE::Stream::AffinityAnalysisDialectInterface { +public: + using AffinityAnalysisDialectInterface::AffinityAnalysisDialectInterface; + IREE::Stream::ResolveLayoutAttrFn + makeLayoutAttrResolver(ModuleOp moduleOp) const { + return [=](IREE::Stream::AffinityAttr affinityAttr, Operation *op, + SetVector &layoutAttrs) -> LogicalResult { + // This needs to be in the lambda because the moduleOp could be modified.. + IREE::HAL::DeviceAnalysis deviceAnalysis(moduleOp); + if (failed(deviceAnalysis.run())) { + return op->emitError("failed to run DeviceAnalysis"); + } + SetVector resultSet; + deviceAnalysis.gatherRequiredExecutableTargets(affinityAttr, op, + resultSet); + // TODO(hanchung): Populate the EncodingLayoutAttr when it is ready. 
+ layoutAttrs.insert(resultSet.begin(), resultSet.end()); + return success(); + }; + }; +}; + } // namespace HALDialect::HALDialect(MLIRContext *context) @@ -131,6 +157,7 @@ HALDialect::HALDialect(MLIRContext *context) #include "iree/compiler/Dialect/HAL/IR/HALOps.cpp.inc" >(); addInterfaces(); } diff --git a/compiler/src/iree/compiler/Dialect/Stream/IR/BUILD.bazel b/compiler/src/iree/compiler/Dialect/Stream/IR/BUILD.bazel index 9959bd1ef3ff..2fa22edf5eb6 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/IR/BUILD.bazel +++ b/compiler/src/iree/compiler/Dialect/Stream/IR/BUILD.bazel @@ -50,6 +50,7 @@ iree_compiler_cc_library( hdrs = [ "StreamDialect.h", "StreamEnums.h.inc", + "StreamInterfaces.h", "StreamOpInterfaces.h.inc", "StreamOps.h", "StreamOps.h.inc", diff --git a/compiler/src/iree/compiler/Dialect/Stream/IR/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/Stream/IR/CMakeLists.txt index 286bb7152ed1..2f10910741ae 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/IR/CMakeLists.txt +++ b/compiler/src/iree/compiler/Dialect/Stream/IR/CMakeLists.txt @@ -16,6 +16,7 @@ iree_cc_library( HDRS "StreamDialect.h" "StreamEnums.h.inc" + "StreamInterfaces.h" "StreamOpInterfaces.h.inc" "StreamOps.h" "StreamOps.h.inc" diff --git a/compiler/src/iree/compiler/Dialect/Stream/IR/StreamInterfaces.h b/compiler/src/iree/compiler/Dialect/Stream/IR/StreamInterfaces.h new file mode 100644 index 000000000000..d18b7a54d72e --- /dev/null +++ b/compiler/src/iree/compiler/Dialect/Stream/IR/StreamInterfaces.h @@ -0,0 +1,36 @@ +// Copyright 2025 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#ifndef IREE_COMPILER_DIALECT_STREAM_IR_STREAMINTERACES_H_ +#define IREE_COMPILER_DIALECT_STREAM_IR_STREAMINTERACES_H_ + +#include "iree/compiler/Dialect/Stream/IR/StreamOps.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/DialectInterface.h" +#include "mlir/IR/Operation.h" +#include "mlir/Transforms/DialectConversion.h" + +namespace mlir::iree_compiler::IREE::Stream { + +using ResolveLayoutAttrFn = std::function &)>; + +class AffinityAnalysisDialectInterface + : public DialectInterface::Base { +public: + AffinityAnalysisDialectInterface(Dialect *dialect) : Base(dialect) {} + + /// The `moduleOp` must remain live and unmodified for as long as the returned + /// capture is. Otherwise, it will likely be incorrect or crash if the module + /// op is mutated, especially when module scope analysis is run. + virtual ResolveLayoutAttrFn + makeLayoutAttrResolver(ModuleOp moduleOp) const = 0; +}; + +} // namespace mlir::iree_compiler::IREE::Stream + +#endif // IREE_COMPILER_DIALECT_STREAM_IR_STREAM_INTERFACES_H_ diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/BUILD.bazel b/compiler/src/iree/compiler/Dialect/Stream/Transforms/BUILD.bazel index 90e22a640352..d7ef2a20b4ae 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/BUILD.bazel +++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/BUILD.bazel @@ -39,6 +39,7 @@ iree_compiler_cc_library( "ScheduleConcurrency.cpp", "ScheduleExecution.cpp", "SpecializeDispatches.cpp", + "SpecializeEncodings.cpp", "VerifyAffinities.cpp", "VerifyAsyncAccessRanges.cpp", "VerifyLowerings.cpp", diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/Stream/Transforms/CMakeLists.txt index 9e15b84e0763..b9050532cb06 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/CMakeLists.txt +++ 
b/compiler/src/iree/compiler/Dialect/Stream/Transforms/CMakeLists.txt @@ -40,6 +40,7 @@ iree_cc_library( "ScheduleConcurrency.cpp" "ScheduleExecution.cpp" "SpecializeDispatches.cpp" + "SpecializeEncodings.cpp" "VerifyAffinities.cpp" "VerifyAsyncAccessRanges.cpp" "VerifyLowerings.cpp" diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/Passes.cpp b/compiler/src/iree/compiler/Dialect/Stream/Transforms/Passes.cpp index 2234c62daa58..69e65fed56ff 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/Passes.cpp +++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/Passes.cpp @@ -22,6 +22,15 @@ static llvm::cl::opt clAnnotateInputAffinities( "the pipeline for debugging."), llvm::cl::init(false)); +// TODO(hanchung): Enable the pass by default once the implementation is done. +static llvm::cl::opt clSpecializeEncodings( + "iree-stream-experimental-specialize-encodings", + llvm::cl::desc( + "Enables SpecializeEncodingPass in Stream pass pipeline. This pass is " + "currently under development, so it is not enabled by default. It can " + "only handle limited cases at this moment."), + llvm::cl::init(false)); + namespace mlir::iree_compiler::IREE::Stream { using FunctionLikeNest = @@ -140,6 +149,10 @@ void buildStreamAsyncPassPipeline(OpPassManager &passManager, // Tensor lowering and resource management //---------------------------------------------------------------------------- + if (clSpecializeEncodings) { + passManager.addPass(IREE::Stream::createSpecializeEncodingsPass()); + } + // Lower stream.tensor.* ops to stream.async.* ops based on // affinity/configuration assigned during placement. 
FunctionLikeNest(passManager) diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/Passes.td b/compiler/src/iree/compiler/Dialect/Stream/Transforms/Passes.td index 3aec709938cf..3dcbbb5e236a 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/Passes.td +++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/Passes.td @@ -414,6 +414,16 @@ def SpecializeDispatchesPass : ]; } +def SpecializeEncodingsPass : + Pass<"iree-stream-specialize-encodings", "mlir::ModuleOp"> { + let summary = "Specializes data-tiling encodings based on device analysis."; + let description = [{ + Attaches layouts to encodings and duplicates executables based on device + analysis. + TODO: Unpack the context. The pass is not fully implemented yet. + }]; +} + def AnnotateDispatchArgumentsPass : Pass<"iree-stream-annotate-dispatch-arguments", "mlir::ModuleOp"> { let summary = "Annotates dispatch arguments with potential values derived from dispatch sites."; diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/SpecializeEncodings.cpp b/compiler/src/iree/compiler/Dialect/Stream/Transforms/SpecializeEncodings.cpp new file mode 100644 index 000000000000..b177bedb718d --- /dev/null +++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/SpecializeEncodings.cpp @@ -0,0 +1,169 @@ +// Copyright 2025 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "iree/compiler/Dialect/Encoding/IR/EncodingTypes.h" +#include "iree/compiler/Dialect/Stream/Analysis/Affinity.h" +#include "iree/compiler/Dialect/Stream/IR/StreamInterfaces.h" +#include "iree/compiler/Dialect/Stream/IR/StreamOps.h" +#include "iree/compiler/Dialect/Stream/IR/StreamTraits.h" +#include "iree/compiler/Dialect/Stream/IR/StreamTypes.h" +#include "iree/compiler/Dialect/Stream/Transforms/Passes.h" +#include "llvm/ADT/TypeSwitch.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/LogicalResult.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/Interfaces/FunctionInterfaces.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +namespace mlir::iree_compiler::IREE::Stream { + +#define DEBUG_TYPE "iree-stream-specialize-encodings" + +#define GEN_PASS_DEF_SPECIALIZEENCODINGSPASS +#include "iree/compiler/Dialect/Stream/Transforms/Passes.h.inc" + +namespace { +/// Returns a stably sorted list of dialect interfaces of T for all dialects +/// used within the given module. +template +SmallVector gatherUsedDialectInterfaces(mlir::ModuleOp moduleOp) { + SmallPtrSet resultSet; + for (auto dialect : moduleOp.getContext()->getLoadedDialects()) { + auto *dialectInterface = dialect->getRegisteredInterface(); + if (!dialectInterface) + continue; + resultSet.insert(dialectInterface); + } + + // NOTE: to ensure deterministic output we sort the result so that imports are + // always added in a consistent order. + SmallVector results = {resultSet.begin(), resultSet.end()}; + llvm::sort( + results, +[](const T *a, const T *b) { + return a->getDialect()->getNamespace().compare( + b->getDialect()->getNamespace()) < 0; + }); + return results; +} + +// TODO(hanchung): Add "cloneWithEncoding" method to RankedTensorType. 
+static RankedTensorType cloneWithEncoding(RankedTensorType type, + Attribute encodingAttr) { + return RankedTensorType::get(type.getShape(), type.getElementType(), + encodingAttr); +} + +static LogicalResult addLayoutsToTensorPhaseOps( + ModuleOp moduleOp, FunctionOpInterface funcOp, + IREE::Stream::ResolveLayoutAttrFn resolveLayoutAttr) { + SmallVector candidates; + funcOp.walk([&](IREE::Stream::AffinityOpInterface affinityOp) { + // Only need to update encoding types for ops that have TensorPhaseOp trait. + if (!affinityOp->hasTrait()) { + return; + } + + // Bail out if the operation does not have an affinity attribute. + auto affinityAttr = affinityOp.getAffinityAttr(); + if (!affinityAttr) { + return; + } + candidates.push_back(affinityOp); + }); + + if (candidates.empty()) { + return success(); + } + + IRRewriter rewriter(funcOp.getContext()); + for (auto affinityOp : candidates) { + auto affinityAttr = affinityOp.getAffinityAttr(); + SetVector layouts; + if (failed(resolveLayoutAttr(affinityAttr, moduleOp, layouts))) { + return affinityOp.emitError("failed on making layouts"); + } + + // Returns an updated encoding attribute if an encoding attribute is present + // in the type. Otherwise, returns std::nullopt. + auto getEncodingWithNewLayouts = + [=](Type type) -> std::optional { + auto rankedTensorType = dyn_cast(type); + if (!rankedTensorType) { + return std::nullopt; + } + auto encodingAttr = IREE::Encoding::getEncodingAttr(rankedTensorType); + if (!encodingAttr) { + return std::nullopt; + } + return encodingAttr.cloneWithLayouts(layouts.getArrayRef()); + }; + + // TODO(hanchung): Update other Stream operations. 
+ LogicalResult result = + TypeSwitch(affinityOp) + .Case([&](auto sizeOfOp) { + auto encodingType = + dyn_cast(sizeOfOp.getEncoding()); + if (!encodingType) { + return success(); + } + std::optional encodingAttr = + getEncodingWithNewLayouts(encodingType); + if (!encodingAttr) { + return success(); + } + rewriter.modifyOpInPlace(sizeOfOp, [&] { + sizeOfOp.setEncoding( + cloneWithEncoding(encodingType, encodingAttr.value())); + }); + return success(); + }) + .Default([](auto *op) { return failure(); }); + + if (failed(result)) { + return failure(); + } + } + return success(); +} +} // namespace + +struct SpecializeEncodingsPass + : public impl::SpecializeEncodingsPassBase { + void runOnOperation() override { + ModuleOp moduleOp = getOperation(); + auto usedDialects = gatherUsedDialectInterfaces< + IREE::Stream::AffinityAnalysisDialectInterface>(moduleOp); + if (usedDialects.size() != 1) { + moduleOp.emitError("expected only one dialect implementing " + "AffinityAnalysisDialectInterface"); + return signalPassFailure(); + } + + llvm::MapVector executableOps; + for (auto executableOp : moduleOp.getOps()) { + executableOps[executableOp.getName()] = executableOp; + } + + IREE::Stream::ResolveLayoutAttrFn resolveLayoutAttr = + usedDialects[0]->makeLayoutAttrResolver(moduleOp); + for (auto funcOp : moduleOp.getOps()) { + if (failed(addLayoutsToTensorPhaseOps(moduleOp, funcOp, + resolveLayoutAttr))) { + funcOp.emitError( + "failed on adding layouts to Stream::TensorPhaseOp with encodings"); + return signalPassFailure(); + } + + // TODO(hanchung): Duplicate executables and update dispatch ops. 
+ } + } +}; + +} // namespace mlir::iree_compiler::IREE::Stream diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/BUILD.bazel b/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/BUILD.bazel index 87d6bea60977..722ce780b148 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/BUILD.bazel +++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/BUILD.bazel @@ -46,6 +46,7 @@ iree_lit_test_suite( "schedule_concurrency.mlir", "schedule_execution.mlir", "specialize_dispatches.mlir", + "specialize_encodings.mlir", "verify_affinities.mlir", "verify_async_access_ranges.mlir", ], diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/CMakeLists.txt index 8c4ca85927c5..6eb964a1771e 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/CMakeLists.txt +++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/CMakeLists.txt @@ -44,6 +44,7 @@ iree_lit_test_suite( "schedule_concurrency.mlir" "schedule_execution.mlir" "specialize_dispatches.mlir" + "specialize_encodings.mlir" "verify_affinities.mlir" "verify_async_access_ranges.mlir" TOOLS diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/specialize_encodings.mlir b/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/specialize_encodings.mlir new file mode 100644 index 000000000000..1ae03e604763 --- /dev/null +++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/specialize_encodings.mlir @@ -0,0 +1,24 @@ +// RUN: iree-opt --split-input-file --iree-stream-specialize-encodings %s | FileCheck %s + +//------------------------------------------------------------------------------ +// Stream ops that have TensorPhaseOp trait. This test suite tests that the +// encoding is updated that carries resolved layouts. 
+//------------------------------------------------------------------------------ + +#executable_target_vmvx_bytecode_fb = #hal.executable.target<"vmvx", "vmvx-bytecode-fb", {encoding_layout = #iree_cpu.vmvx_encoding_layout<>}> +#device_target_local_0_ = #hal.device.target<"local", {ordinal = 0 : index}, [#executable_target_vmvx_bytecode_fb]> : !hal.device +#encoding = #iree_encoding.encoding +module { + util.global private @device_a = #device_target_local_0_ + + util.func public @tensor_sizeof(%d0: index, %d1: index) -> index { + %size = stream.tensor.sizeof on(#hal.device.affinity<@device_a>) tensor{%d0, %d1} : index + util.return %size : index + } +} +// CHECK: #[[EXECUTABLE:.+]] = #hal.executable.target<"vmvx", +// CHECK: #[[$ENCODING:.+]] = #iree_encoding.encoding +// CHECK-SAME: layouts = [#[[EXECUTABLE]]] +// CHECK-LABEL: util.func public @tensor_sizeof +// CHECK: %[[RES:.+]] = stream.tensor.sizeof {{.+}} tensor +// CHECK: return %[[RES]]