Use DisposableElementsAttr for ZHigh constant propagation #3013

Open · wants to merge 13 commits into base: main
1 change: 1 addition & 0 deletions src/Accelerators/NNPA/Compiler/CMakeLists.txt
@@ -19,6 +19,7 @@ add_onnx_mlir_library(OMNNPACompilerOptions

add_onnx_mlir_library(OMNNPACompilerUtils
NNPACompilerUtils.cpp
ZHighDisposableGarbageCollector.cpp

EXCLUDE_FROM_OM_LIBS

14 changes: 8 additions & 6 deletions src/Accelerators/NNPA/Compiler/NNPACompilerUtils.cpp
@@ -32,6 +32,7 @@

#include "src/Accelerators/NNPA/Compiler/NNPACompilerOptions.hpp"
#include "src/Accelerators/NNPA/Compiler/NNPACompilerUtils.hpp"
#include "src/Accelerators/NNPA/Compiler/ZHighDisposableGarbageCollector.hpp"
#include "src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps.hpp"
#include "src/Accelerators/NNPA/Dialect/ZLow/ZLowOps.hpp"
#include "src/Accelerators/NNPA/Pass/NNPAPasses.hpp"
@@ -120,10 +121,6 @@ void addONNXToZHighPasses(mlir::PassManager &pm) {
pm.addNestedPass<func::FuncOp>(onnx_mlir::createShapeInferencePass());
}

// Replace every DisposableElementsAttr with DenseElementsAttr.
// ZHighConstPropagation currently assumes that DenseElementsAttr is used.
pm.addPass(createScrubDisposablePass());

// Experimental feature: Decompose stick/unstick into two phases: layout
// transform and data conversion. Do some optimizations after decomposing.
// Then, recompose again layout and data conversion if they are not optimized.
@@ -146,15 +143,17 @@ void addONNXToZHighPasses(mlir::PassManager &pm) {
// Only support BE machines.
bool isBE = llvm::endianness::native == llvm::endianness::big;
if (isBE)
pm.addNestedPass<func::FuncOp>(
onnx_mlir::zhigh::createZHighConstPropagationPass());
pm.addPass(onnx_mlir::zhigh::createZHighConstPropagationPass());

// Remove common sub-expressions.
pm.addPass(mlir::createCSEPass());

// Clean dead code.
pm.addPass(mlir::createSymbolDCEPass());

// Replace every DisposableElementsAttr with DenseElementsAttr.
pm.addPass(onnx_mlir::zhigh::createZHighScrubDisposablePass());

// Insert an instrumentation after lowering onnx to zhigh to get profiling
// for onnx and zhigh ops.
// Keep this pass at the end of this function.
@@ -195,6 +194,9 @@ void addPassesNNPA(mlir::OwningOpRef<mlir::ModuleOp> &module,

// LLVM_DEBUG(llvm::dbgs() << "Adding NNPA passes" << std::endl;);
if (emissionTarget >= EmitONNXIR) {
pm.addInstrumentation(
std::make_unique<onnx_mlir::zhigh::ZHighDisposableGarbageCollector>(
pm.getContext()));
addONNXToMLIRPasses(pm, /*target CPU*/ maccel.empty(),
/*donotScrubDisposableElementsAttr*/ true);
pm.addPass(onnx_mlir::createDevicePlacementPass(nnpaLoadDevicePlacementFile,
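Note (illustrative, not part of the diff): with the scrub pass removed from the middle of the pipeline, ZHigh constant propagation now consumes DisposableElementsAttr directly and the conversion to DenseElementsAttr happens only once, at the end. A minimal sketch of the resulting tail of addONNXToZHighPasses(), using only the pass-creation helpers named in the hunks above:

    bool isBE = llvm::endianness::native == llvm::endianness::big;
    if (isBE)
      pm.addPass(onnx_mlir::zhigh::createZHighConstPropagationPass());
    pm.addPass(mlir::createCSEPass());       // remove common sub-expressions
    pm.addPass(mlir::createSymbolDCEPass()); // clean dead code
    // Scrub last: replace any remaining DisposableElementsAttr with
    // DenseElementsAttr once no ZHigh pass needs the disposable form.
    pm.addPass(onnx_mlir::zhigh::createZHighScrubDisposablePass());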
43 changes: 43 additions & 0 deletions src/Accelerators/NNPA/Compiler/ZHighDisposableGarbageCollector.cpp
@@ -0,0 +1,43 @@
/*
* SPDX-License-Identifier: Apache-2.0
*/

//===---------------- ZHighDisposableGarbageCollector.cpp -----------------===//
//
// Garbage collects DisposableElementsAttr attributes.
//
//===----------------------------------------------------------------------===//

#include "src/Accelerators/NNPA/Compiler/ZHighDisposableGarbageCollector.hpp"
#include "src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps.hpp"
#include "src/Dialect/ONNX/ElementsAttr/DisposablePool.hpp"
#include "src/Dialect/ONNX/ONNXDialect.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"

#include "mlir/IR/BuiltinOps.h"

using namespace mlir;

namespace onnx_mlir {
namespace zhigh {

ZHighDisposableGarbageCollector::ZHighDisposableGarbageCollector(
MLIRContext *context)
: disposablePool(*DisposablePool::get<ONNXDialect>(context)) {}

ZHighDisposableGarbageCollector::~ZHighDisposableGarbageCollector() {}

void ZHighDisposableGarbageCollector::runAfterPass(Pass *pass, Operation *op) {
if (!disposablePool.isActive())
return;
ModuleOp moduleOp = mlir::dyn_cast<ModuleOp>(op);
if (!moduleOp)
return;
disposablePool.garbageCollectUnreachable(
moduleOp, {{ONNXConstantOp::getOperationName(), "value"},
{ONNXConstantOfShapeOp::getOperationName(), "value"},
{ZHighStickifiedConstantOp::getOperationName(), "value"}});
}

} // namespace zhigh
} // namespace onnx_mlir
37 changes: 37 additions & 0 deletions src/Accelerators/NNPA/Compiler/ZHighDisposableGarbageCollector.hpp
@@ -0,0 +1,37 @@
/*
* SPDX-License-Identifier: Apache-2.0
*/

//===---------------- ZHighDisposableGarbageCollector.hpp -----------------===//
//
// Garbage collects DisposableElementsAttr attributes.
//
//===----------------------------------------------------------------------===//

#ifndef ONNX_MLIR_ZHIGH_GARBAGE_COLLECTOR_H
#define ONNX_MLIR_ZHIGH_GARBAGE_COLLECTOR_H

#include "mlir/Pass/PassInstrumentation.h"

namespace mlir {
class MLIRContext;
}

namespace onnx_mlir {
class DisposablePool;

namespace zhigh {

struct ZHighDisposableGarbageCollector : public mlir::PassInstrumentation {
ZHighDisposableGarbageCollector(mlir::MLIRContext *context);
~ZHighDisposableGarbageCollector() override;

void runAfterPass(mlir::Pass *pass, mlir::Operation *op) override;

private:
DisposablePool &disposablePool;
};

} // namespace zhigh
} // namespace onnx_mlir
#endif
@@ -83,8 +83,7 @@ bool isF32ScalarConstantTensor(Value v) {
FloatAttr getScalarF32AttrFromConstant(Value v) {
if (!isF32ScalarConstantTensor(v))
return nullptr;
DenseElementsAttr constElements = ElementsAttrBuilder::toDenseElementsAttr(
getElementAttributeFromONNXValue(v));
ElementsAttr constElements = getElementAttributeFromONNXValue(v);
return constElements.getSplatValue<FloatAttr>();
}
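A small, illustrative call site for the simplified helper above; `v` and the surrounding pattern are assumed. No round trip through ElementsAttrBuilder::toDenseElementsAttr() is needed any more:

    if (isF32ScalarConstantTensor(v)) {
      FloatAttr scalar = getScalarF32AttrFromConstant(v);
      float alpha = static_cast<float>(scalar.getValueAsDouble());
      // ... use `alpha` while matching or rewriting the pattern ...
    }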

@@ -2,8 +2,7 @@
* SPDX-License-Identifier: Apache-2.0
*/

//===---------- ONNXToZHighCommon.hpp - Common functions in ONNXToZHigh
//---------===//
//===---- ONNXToZHighCommon.hpp - Common functions in ONNXToZHigh ---------===//
//
// Copyright 2019-2024 The IBM Research Authors.
//
@@ -117,4 +116,4 @@ mlir::Value getDynShape(
mlir::Location loc, mlir::PatternRewriter &rewriter, mlir::Value x);

} // namespace onnx_mlir
#endif
#endif
159 changes: 71 additions & 88 deletions src/Accelerators/NNPA/Conversion/ZHighToZLow/ZHighToZLow.cpp
@@ -168,6 +168,25 @@ static Value insertAllocForWorkAreaForRNNOps(IndexExprBuilderForKrnl &createIE,
return create.mem.alignedAlloc(resultType, dims, gAlignment);
}

/// Get a dense resource attribute to store stickified data filled with a given i8 value.
/// Attribute type: tensor<sizeInBytes x i8>
DenseResourceElementsAttr getDenseResourceElementsAttrOfValue(
PatternRewriter &rewriter, ZHighStickifiedConstantOp stickifiedConstant,
int8_t val, int64_t sizeInBytes) {
char *rawData = static_cast<char *>(malloc(sizeInBytes));
assert(rawData && "failed to allocate memory for stickified data");
memset(rawData, val, sizeInBytes);
DenseResourceElementsAttr valueAttr = DenseUI8ResourceElementsAttr::get(
RankedTensorType::get({sizeInBytes}, rewriter.getI8Type()),
stickifiedConstant.getOperation()
->getDialect()
->getNamespace(), // use the dialect as the blob "hint"
HeapAsmResourceBlob::allocateAndCopyWithAlign(
llvm::ArrayRef(rawData, sizeInBytes), alignof(char)));
free(rawData);
return valueAttr;
}

/// This function emits a buffer of zero elements for the given dimensions and
/// layout. If the given dimensions are static, then a stickified constant is
/// returned.
@@ -190,48 +209,18 @@ Value insertAllocOrEmitZeroConstant(ArrayRef<IndexExpr> dims,
affine::normalizeMemRefType(mlir::cast<MemRefType>(zMemRefType.value));

// Create a ZHighStickifiedConstantOp.

// Keep previous implementation about generating stickified data at
// ZHighConstPropagationPass. To use this, comment in and set directive "
// NNPA_ZHIGH_STICKIFIEDCONST_GEN"
//
// #ifdef NNPA_ZHIGH_STICKIFIEDCONST_GEN
// // Set zero in value attribute as DenseResourceElementsAttribute.
// ZHighStickifiedConstantOp stickifiedConstant =
// rewriter.create<ZHighStickifiedConstantOp>(loc, resType,
// /*stickified=*/rewriter.getBoolAttr(true),
// /*value=*/nullptr,
// /*alignment=*/rewriter.getI64IntegerAttr(4096));
//
// // Use an dense resource attribute to store stickified data.
// // Attribute type: tensor<sizeInBytes x i8>
// int64_t sizeInBytes =
// affine::getIntOrFloatMemRefSizeInBytes(resType).value();
// char *rawData = static_cast<char *>(malloc(sizeInBytes));
// assert(rawData && "failed to allocate memory for stickified data");
// memset(rawData, 0, sizeInBytes);
// DenseResourceElementsAttr valueAttr =
// DenseUI8ResourceElementsAttr::get(
// RankedTensorType::get({sizeInBytes}, rewriter.getI8Type()),
// stickifiedConstant.getOperation()
// ->getDialect()
// ->getNamespace(), // use the dialect as the blob "hint"
// HeapAsmResourceBlob::allocateAndCopyWithAlign(
// llvm::ArrayRef(rawData, sizeInBytes), alignof(char)));
// stickifiedConstant.setValueAttr(valueAttr);
// free(rawData);
// #else

// Set zero in value attribute as SplatElementsAttr.
FloatAttr floatZero = rewriter.getFloatAttr(resType.getElementType(), 0.0);
ZHighStickifiedConstantOp stickifiedConstant = rewriter.create<
ZHighStickifiedConstantOp>(loc, resType,
/*stickified=*/rewriter.getBoolAttr(true),
/*value=*/SplatElementsAttr::get(cast<ShapedType>(resType), floatZero),
/*alignment=*/rewriter.getI64IntegerAttr(4096));

// #endif // NNPA_ZHIGH_STICKIFIEDCONST_GEN

ZHighStickifiedConstantOp stickifiedConstant =
rewriter.create<ZHighStickifiedConstantOp>(loc, resType,
/*value=*/nullptr,
/*alignment=*/rewriter.getI64IntegerAttr(4096));

// Use a dense resource attribute to store stickified data.
// Attribute type: tensor<sizeInBytes x i8>
int64_t sizeInBytes =
affine::getIntOrFloatMemRefSizeInBytes(resType).value();
DenseResourceElementsAttr valueAttr = getDenseResourceElementsAttrOfValue(
rewriter, stickifiedConstant, 0, sizeInBytes);
stickifiedConstant.setValueAttr(valueAttr);
res = stickifiedConstant.getResult();
} else {
MultiDialectBuilder<KrnlBuilder, MathBuilder> create(rewriter, loc);
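The helper introduced at the top of this file is what materializes the all-zero buffer in the static-shape branch above; a condensed sketch of that path, with names taken from the hunks in this file:

    // Create the constant with no value first, then attach a
    // tensor<sizeInBytes x i8> dense resource filled with zeros.
    ZHighStickifiedConstantOp zeroConst =
        rewriter.create<ZHighStickifiedConstantOp>(loc, resType,
            /*value=*/nullptr,
            /*alignment=*/rewriter.getI64IntegerAttr(4096));
    int64_t sizeInBytes =
        affine::getIntOrFloatMemRefSizeInBytes(resType).value();
    zeroConst.setValueAttr(getDenseResourceElementsAttrOfValue(
        rewriter, zeroConst, /*val=*/0, sizeInBytes));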
@@ -706,7 +695,7 @@ struct ZHighToZLowUnstickOpLowering : public ConversionPattern {
};

//===----------------------------------------------------------------------===//
// Lower ZHigh Stickified Constant to ZLow Stickified Constant
// Lower ZHigh Stickified Constant to KrnlGlobal
//===----------------------------------------------------------------------===//

struct ZHighToZLowStickifiedConstantOpLowering : public ConversionPattern {
@@ -719,7 +708,7 @@ struct ZHighToZLowStickifiedConstantOpLowering : public ConversionPattern {
LogicalResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
Location loc = op->getLoc();
ZHighStickifiedConstantOp zhighStickifiedConstOp =
ZHighStickifiedConstantOp stickifiedConstOp =
llvm::dyn_cast<ZHighStickifiedConstantOp>(op);

// Convert ZTensor type to MemRefType.
@@ -733,59 +722,53 @@ struct ZHighToZLowStickifiedConstantOpLowering : public ConversionPattern {
affine::normalizeMemRefType(mlir::cast<MemRefType>(zMemRefType.value));
ArrayRef<int64_t> normalizedShape = normalizedType.getShape();

// Create ZLowStickifiedConstantOp.
StringAttr layout =
getZTensorLayoutAttr(rewriter, *op->result_type_begin());
// Validate the stickified tensor.
Attribute valueAttr = stickifiedConstOp.getValueAttr();
int64_t sizeInBytes = getMemRefEltSizeInBytes(normalizedType);
sizeInBytes *= normalizedType.getNumElements();
if (auto denseAttr = mlir::dyn_cast_or_null<DenseElementsAttr>(valueAttr)) {
ArrayRef<char> data = denseAttr.getRawData();
if (denseAttr.isSplat()) {
// Constant ztensor's buffer is tensor<sizeInBytes x i8>.
int8_t v = denseAttr.getSplatValue<int8_t>();
// NNPA does not work with a splat buffer.
// Expand the memory buffer for NNPA by using DenseResourceElementsAttr.
valueAttr = getDenseResourceElementsAttrOfValue(
rewriter, stickifiedConstOp, v, sizeInBytes);
} else {
assert(
(data.size() == static_cast<uint64_t>(sizeInBytes)) &&
"The stickified tensor's buffer size and MemRef's size mismatched");
}
} else if (auto resourceAttr =
mlir::dyn_cast_or_null<DenseResourceElementsAttr>(
valueAttr)) {
auto blob = resourceAttr.getRawHandle().getBlob();
assert(blob && "Expecting dense resource with a valid blob");
ArrayRef<char> data = blob->getData();
assert(
(data.size() == static_cast<uint64_t>(sizeInBytes)) &&
"The stickified tensor's buffer size and MemRef's size mismatched");
} else {
llvm_unreachable("Unsupported ElementsAttr");
}

// Keep previous implementation about generating stickified data at
// ZHighConstPropagationPass. To use this, comment in and set directive "
// NNPA_ZHIGH_STICKIFIEDCONST_GEN"
//
// #ifdef NNPA_ZHIGH_STICKIFIEDCONST_GEN
// // Lower to KrnlGlobalOp
// // Get dense resource attribute.
// auto blob = mlir::cast<DenseResourceElementsAttr>(
// zhighStickifiedConstOp.getValue().value())
// .getRawHandle()
// .getBlob();
// assert(blob && "Expecting dense resource with a valid blob");
// ArrayRef<char> data = blob->getData();
// // Validate the stickified tensor.
// int64_t memRefSizeInBytes = getMemRefEltSizeInBytes(normalizedType);
// memRefSizeInBytes *= normalizedType.getNumElements();
// assert((data.size() == static_cast<uint64_t>(memRefSizeInBytes)) &&
// "The stickified tensor's buffer size and MemRef's size
// mismatched");
// // Create a KrnlGlobalOp.
// KrnlGlobalOp constantOp =
// rewriter.create<KrnlGlobalOp>(loc, zMemRefType.value,
// /*shape=*/
// rewriter.getI64ArrayAttr(normalizedShape),
// /*name=*/
// rewriter.getStringAttr(
// "constant_stickify_" + std::to_string(constantID)),
// /*value=*/zhighStickifiedConstOp.getValueAttr(),
// /*offset=*/nullptr,
// /*alignment=*/zhighStickifiedConstOp.getAlignmentAttr());
// #else
ZLowStickifiedConstantOp constantOp =
rewriter.create<ZLowStickifiedConstantOp>(loc,
mlir::cast<MemRefType>(zMemRefType.value),
// Create a KrnlGlobalOp.
KrnlGlobalOp constantGlobal =
rewriter.create<KrnlGlobalOp>(loc, zMemRefType.value,
/*shape=*/
rewriter.getI64ArrayAttr(normalizedShape),
/*name=*/
rewriter.getStringAttr(
"constant_stickify_" + std::to_string(constantID)),
/*stickified=*/zhighStickifiedConstOp.getStickifiedAttr(),
/*value=*/zhighStickifiedConstOp.getValueAttr(),
/*layout=*/layout,
/*offset=*/rewriter.getI64IntegerAttr(0),
/*alignment=*/zhighStickifiedConstOp.getAlignmentAttr());
// #endif // NNPA_ZHIGH_STICKIFIEDCONST_GEN
/*value=*/valueAttr,
/*offset=*/nullptr,
/*alignment=*/stickifiedConstOp.getAlignmentAttr());

// Increment constant ID:
constantID++;

rewriter.replaceOp(op, constantOp.getResult());
rewriter.replaceOp(op, constantGlobal.getResult());
return success();
}
};
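In short, the new lowering accepts either a DenseElementsAttr or a DenseResourceElementsAttr on the stickified constant, expands a splat into a full sizeInBytes buffer (NNPA cannot consume a splat), size-checks the data, and only then emits the KrnlGlobalOp. A compressed sketch of the splat branch, condensing the hunk above:

    int64_t sizeInBytes =
        getMemRefEltSizeInBytes(normalizedType) * normalizedType.getNumElements();
    // Expand a splat into a real byte buffer; NNPA cannot consume a splat.
    if (auto dense = mlir::dyn_cast_or_null<DenseElementsAttr>(valueAttr))
      if (dense.isSplat())
        valueAttr = getDenseResourceElementsAttrOfValue(
            rewriter, stickifiedConstOp, dense.getSplatValue<int8_t>(),
            sizeInBytes);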
1 change: 0 additions & 1 deletion src/Accelerators/NNPA/Dialect/ZHigh/CMakeLists.txt
@@ -47,7 +47,6 @@ add_onnx_mlir_library(OMZHighOps
OMONNXOps # Use ONNXShapeHelper
OMLayoutHelper
OMShapeHelperOpInterface
OMStickify
OMNNPACompilerOptions
MLIRIR

5 changes: 1 addition & 4 deletions src/Accelerators/NNPA/Dialect/ZHigh/ZHigh.td
@@ -862,14 +862,11 @@ def ZHighStickifiedConstantOp:ZHigh_Op<"StickifiedConstant", [Pure]> {
let summary = "ZHigh Stickified Constant operation";
let description = [{
This operator produces a constant tensor to store stickified data.
`value` attribute has original constant or stickified constant.
`stickified` attribute indicates the `value` is already stickified or not.
Stickified data is opaque and must be 4K-aligned. Whoever produces
the stickified data must make sure its size in bytes is consistent with
the output tensor's size.
}];
let arguments = (ins BoolAttr:$stickified,
OptionalAttr<AnyAttr>:$value,
let arguments = (ins OptionalAttr<AnyAttr>:$value,
DefaultValuedAttr<I64Attr, "4096">:$alignment);
let results = (outs AnyZTensor:$output);
}
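With the `stickified` flag gone, building the op takes only the optional `value` payload and the `alignment`; an illustrative builder call, matching the ZHighToZLow.cpp hunk earlier in this PR:

    ZHighStickifiedConstantOp constOp =
        rewriter.create<ZHighStickifiedConstantOp>(loc, resType,
            /*value=*/valueAttr, // stickified payload, or nullptr to set later
            /*alignment=*/rewriter.getI64IntegerAttr(4096));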