Skip to content

Commit

Permalink
Use DisposableElementsAttr for ZHigh Constant Propagation
Browse files Browse the repository at this point in the history
Signed-off-by: Tung D. Le <[email protected]>
  • Loading branch information
tungld committed Nov 18, 2024
1 parent c99e50d commit 48cf039
Show file tree
Hide file tree
Showing 26 changed files with 489 additions and 271 deletions.
3 changes: 3 additions & 0 deletions docs/AddCustomAccelerators.md
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,9 @@ virtual void conversionTargetONNXToKrnl(
virtual void rewritePatternONNXToKrnl(mlir::RewritePatternSet &patterns,
mlir::TypeConverter &typeConverter, mlir::MLIRContext *ctx) const = 0;

/// Set up the PassManager for onnx-mlir-opt.
virtual void setupPassManager(mlir::PassManager &pm) const = 0;

//===--------------------------------------------------------------------===//
// Hooks for krnl-to-llvm pass
//===--------------------------------------------------------------------===//
Expand Down
2 changes: 2 additions & 0 deletions src/Accelerators/Accelerator.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,8 @@ class Accelerator {
/// command line options.
virtual void registerPasses(int optLevel) const = 0;

/// Set up the PassManager for onnx-mlir-opt.
virtual void setupPassManager(mlir::PassManager &pm) const = 0;
//===--------------------------------------------------------------------===//
// Hooks for onnx-to-krnl pass
//===--------------------------------------------------------------------===//
Expand Down
1 change: 1 addition & 0 deletions src/Accelerators/NNPA/Compiler/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ add_onnx_mlir_library(OMNNPACompilerOptions

add_onnx_mlir_library(OMNNPACompilerUtils
NNPACompilerUtils.cpp
ZHighDisposableGarbageCollector.cpp

EXCLUDE_FROM_OM_LIBS

Expand Down
11 changes: 7 additions & 4 deletions src/Accelerators/NNPA/Compiler/NNPACompilerUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@

#include "src/Accelerators/NNPA/Compiler/NNPACompilerOptions.hpp"
#include "src/Accelerators/NNPA/Compiler/NNPACompilerUtils.hpp"
#include "src/Accelerators/NNPA/Compiler/ZHighDisposableGarbageCollector.hpp"
#include "src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps.hpp"
#include "src/Accelerators/NNPA/Dialect/ZLow/ZLowOps.hpp"
#include "src/Accelerators/NNPA/Pass/NNPAPasses.hpp"
Expand Down Expand Up @@ -120,10 +121,6 @@ void addONNXToZHighPasses(mlir::PassManager &pm) {
pm.addNestedPass<func::FuncOp>(onnx_mlir::createShapeInferencePass());
}

// Replace every DisposableElementsAttr with DenseElementsAttr.
// ZHighConstPropagation currently assumes that DenseElementsAttr is used.
pm.addPass(createScrubDisposablePass());

// Experimental feature: Decompose stick/unstick into two phases: layout
// transform and data conversion. Do some optimizations after decomposing.
// Then, recompose again layout and data conversion if they are not optimized.
Expand Down Expand Up @@ -155,6 +152,9 @@ void addONNXToZHighPasses(mlir::PassManager &pm) {
// Clean dead code.
pm.addPass(mlir::createSymbolDCEPass());

// Replace every DisposableElementsAttr with DenseElementsAttr.
pm.addPass(onnx_mlir::zhigh::createZHighScrubDisposablePass());

// Insert an instrumentation after lowering onnx to zhigh to get profiling
// for onnx and zhigh ops.
// Keep this pass at the end of this function.
Expand Down Expand Up @@ -195,6 +195,9 @@ void addPassesNNPA(mlir::OwningOpRef<mlir::ModuleOp> &module,

// LLVM_DEBUG(llvm::dbgs() << "Adding NNPA passes" << std::endl;);
if (emissionTarget >= EmitONNXIR) {
pm.addInstrumentation(
std::make_unique<onnx_mlir::zhigh::ZHighDisposableGarbageCollector>(
pm.getContext()));
addONNXToMLIRPasses(pm, /*target CPU*/ maccel.empty(),
/*donotScrubDisposableElementsAttr*/ true);
pm.addPass(onnx_mlir::createDevicePlacementPass(nnpaLoadDevicePlacementFile,
Expand Down
43 changes: 43 additions & 0 deletions src/Accelerators/NNPA/Compiler/ZHighDisposableGarbageCollector.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
/*
* SPDX-License-Identifier: Apache-2.0
*/

//===---------------- ZHighDisposableGarbageCollector.cpp -----------------===//
//
// Garbage collects DisposableElementsAttr attributes.
//
//===----------------------------------------------------------------------===//

#include "src/Accelerators/NNPA/Compiler/ZHighDisposableGarbageCollector.hpp"
#include "src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps.hpp"
#include "src/Dialect/ONNX/ElementsAttr/DisposablePool.hpp"
#include "src/Dialect/ONNX/ONNXDialect.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"

#include "mlir/IR/BuiltinOps.h"

using namespace mlir;

namespace onnx_mlir {
namespace zhigh {

ZHighDisposableGarbageCollector::ZHighDisposableGarbageCollector(
MLIRContext *context)
: disposablePool(*DisposablePool::get<ONNXDialect>(context)) {}

/// Nothing to release: the collector only holds a reference to the pool.
ZHighDisposableGarbageCollector::~ZHighDisposableGarbageCollector() = default;

/// Instrumentation hook run after each pass: asks the pool to garbage collect
/// unreachable attributes, but only while the pool is active and only when
/// the pass ran on a ModuleOp. `pass` itself is not inspected.
void ZHighDisposableGarbageCollector::runAfterPass(Pass *pass, Operation *op) {
  if (disposablePool.isActive()) {
    // Passes anchored on anything other than the module are ignored.
    if (ModuleOp topModule = mlir::dyn_cast<ModuleOp>(op)) {
      // Each entry pairs an operation name with the name of the attribute
      // ("value") handed to the pool for that operation.
      disposablePool.garbageCollectUnreachable(
          topModule, {{ONNXConstantOp::getOperationName(), "value"},
                         {ONNXConstantOfShapeOp::getOperationName(), "value"},
                         {ZHighStickifiedConstantOp::getOperationName(),
                             "value"}});
    }
  }
}

} // namespace zhigh
} // namespace onnx_mlir
37 changes: 37 additions & 0 deletions src/Accelerators/NNPA/Compiler/ZHighDisposableGarbageCollector.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
/*
* SPDX-License-Identifier: Apache-2.0
*/

//===---------------- ZHighDisposableGarbageCollector.hpp -----------------===//
//
// Garbage collects DisposableElementsAttr attributes.
//
//===----------------------------------------------------------------------===//

#ifndef ONNX_MLIR_ZHIGH_GARBAGE_COLLECTOR_H
#define ONNX_MLIR_ZHIGH_GARBAGE_COLLECTOR_H

#include "mlir/Pass/PassInstrumentation.h"

namespace mlir {
class MLIRContext;
}

namespace onnx_mlir {
class DisposablePool;

namespace zhigh {

struct ZHighDisposableGarbageCollector : public mlir::PassInstrumentation {
ZHighDisposableGarbageCollector(mlir::MLIRContext *context);
~ZHighDisposableGarbageCollector() override;

void runAfterPass(mlir::Pass *pass, mlir::Operation *op) override;

private:
DisposablePool &disposablePool;
};

} // namespace zhigh
} // namespace onnx_mlir
#endif
Original file line number Diff line number Diff line change
Expand Up @@ -83,8 +83,7 @@ bool isF32ScalarConstantTensor(Value v) {
/// Return the splat value of the constant `v` as a FloatAttr.
/// Returns a null attribute when isF32ScalarConstantTensor(v) is false.
FloatAttr getScalarF32AttrFromConstant(Value v) {
  if (!isF32ScalarConstantTensor(v))
    return nullptr;
  // Query the value through the generic ElementsAttr interface instead of
  // first converting to DenseElementsAttr, so the constant's attribute is
  // read as-is (presumably this is what lets DisposableElementsAttr be used
  // here without materializing a dense copy -- confirm against
  // getElementAttributeFromONNXValue).
  ElementsAttr constElements = getElementAttributeFromONNXValue(v);
  return constElements.getSplatValue<FloatAttr>();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,7 @@
* SPDX-License-Identifier: Apache-2.0
*/

//===---------- ONNXToZHighCommon.hpp - Common functions in ONNXToZHigh
//---------===//
//===---- ONNXToZHighCommon.hpp - Common functions in ONNXToZHigh ---------===//
//
// Copyright 2019-2024 The IBM Research Authors.
//
Expand Down Expand Up @@ -117,4 +116,4 @@ mlir::Value getDynShape(
mlir::Location loc, mlir::PatternRewriter &rewriter, mlir::Value x);

} // namespace onnx_mlir
#endif
#endif
77 changes: 52 additions & 25 deletions src/Accelerators/NNPA/Conversion/ZHighToZLow/ZHighToZLow.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,25 @@ static Value insertAllocForWorkAreaForRNNOps(IndexExprBuilderForKrnl &createIE,
return create.mem.alignedAlloc(resultType, dims, gAlignment);
}

/// Get a dense resource attribute to store stickified data of zeros.
/// Attribute type: tensor<sizeInBytes x i8>
///
/// \param rewriter used to build the i8 element type of the attribute.
/// \param stickifiedConstant op whose dialect namespace names the blob.
/// \param sizeInBytes number of zero bytes the blob must hold.
/// \return a resource attribute owning a zero-filled blob of that size.
DenseResourceElementsAttr getDenseResourceElementsAttrOfZero(
    PatternRewriter &rewriter, ZHighStickifiedConstantOp stickifiedConstant,
    int64_t sizeInBytes) {
  assert(sizeInBytes >= 0 && "stickified data size must be non-negative");
  // Allocate the blob's own storage and zero it in place. This avoids a
  // temporary malloc'ed buffer (whose failure was only checked by an assert,
  // i.e. not at all in release builds) and a second full-size copy.
  AsmResourceBlob blob = HeapAsmResourceBlob::allocate(
      static_cast<size_t>(sizeInBytes), alignof(char));
  if (sizeInBytes > 0)
    memset(blob.getMutableData().data(), 0, static_cast<size_t>(sizeInBytes));
  return DenseUI8ResourceElementsAttr::get(
      RankedTensorType::get({sizeInBytes}, rewriter.getI8Type()),
      stickifiedConstant.getOperation()
          ->getDialect()
          ->getNamespace(), // use the dialect as the blob "hint"
      std::move(blob));
}

/// This function emits a buffer of zero elements for the given dimensions and
/// layout. If the given dimensions are static, then a stickified constant is
/// returned.
Expand Down Expand Up @@ -199,19 +218,9 @@ Value insertAllocOrEmitZeroConstant(ArrayRef<IndexExpr> dims,
// Attribute type: tensor<sizeInBytes x i8>
int64_t sizeInBytes =
affine::getIntOrFloatMemRefSizeInBytes(resType).value();
char *rawData = static_cast<char *>(malloc(sizeInBytes));
assert(rawData && "failed to allocate memory for stickified data");
memset(rawData, 0, sizeInBytes);
DenseResourceElementsAttr valueAttr = DenseUI8ResourceElementsAttr::get(
RankedTensorType::get({sizeInBytes}, rewriter.getI8Type()),
stickifiedConstant.getOperation()
->getDialect()
->getNamespace(), // use the dialect as the blob "hint"
HeapAsmResourceBlob::allocateAndCopyWithAlign(
llvm::ArrayRef(rawData, sizeInBytes), alignof(char)));
DenseResourceElementsAttr valueAttr = getDenseResourceElementsAttrOfZero(
rewriter, stickifiedConstant, sizeInBytes);
stickifiedConstant.setValueAttr(valueAttr);
free(rawData);

res = stickifiedConstant.getResult();
} else {
MultiDialectBuilder<KrnlBuilder, MathBuilder> create(rewriter, loc);
Expand Down Expand Up @@ -713,19 +722,37 @@ struct ZHighToZLowStickifiedConstantOpLowering : public ConversionPattern {
affine::normalizeMemRefType(mlir::cast<MemRefType>(zMemRefType.value));
ArrayRef<int64_t> normalizedShape = normalizedType.getShape();

// Get dense resource attribute.
auto blob = mlir::cast<DenseResourceElementsAttr>(
stickifiedConstOp.getValue().value())
.getRawHandle()
.getBlob();
assert(blob && "Expecting dense resource with a valid blob");
ArrayRef<char> data = blob->getData();

// Validate the stickified tensor.
int64_t memRefSizeInBytes = getMemRefEltSizeInBytes(normalizedType);
memRefSizeInBytes *= normalizedType.getNumElements();
assert((data.size() == static_cast<uint64_t>(memRefSizeInBytes)) &&
"The stickified tensor's buffer size and MemRef's size mismatched");
Attribute valueAttr = stickifiedConstOp.getValueAttr();
int64_t sizeInBytes = getMemRefEltSizeInBytes(normalizedType);
sizeInBytes *= normalizedType.getNumElements();
if (auto denseAttr = mlir::dyn_cast_or_null<DenseElementsAttr>(valueAttr)) {
ArrayRef<char> data = denseAttr.getRawData();
if (denseAttr.isSplat()) {
// Constant ztensor's buffer is tensor<sizeInBytes x i8>.
int8_t v = denseAttr.getSplatValue<int8_t>();
assert(v == 0 && "Cannot be a non-zero splat value");
// NNPA does not work with a splat buffer.
// Expand the memory buffer for NNPA by using DenseResourceElementsAttr.
valueAttr = getDenseResourceElementsAttrOfZero(
rewriter, stickifiedConstOp, sizeInBytes);
} else {
assert(
(data.size() == static_cast<uint64_t>(sizeInBytes)) &&
"The stickified tensor's buffer size and MemRef's size mismatched");
}
} else if (auto resourceAttr =
mlir::dyn_cast_or_null<DenseResourceElementsAttr>(
valueAttr)) {
auto blob = resourceAttr.getRawHandle().getBlob();
assert(blob && "Expecting dense resource with a valid blob");
ArrayRef<char> data = blob->getData();
assert(
(data.size() == static_cast<uint64_t>(sizeInBytes)) &&
"The stickified tensor's buffer size and MemRef's size mismatched");
} else {
llvm_unreachable("Unsupported ElementsAttr");
}

// Create a KrnlGlobalOp.
KrnlGlobalOp constantGlobal =
Expand All @@ -735,7 +762,7 @@ struct ZHighToZLowStickifiedConstantOpLowering : public ConversionPattern {
/*name=*/
rewriter.getStringAttr(
"constant_stickify_" + std::to_string(constantID)),
/*value=*/stickifiedConstOp.getValueAttr(),
/*value=*/valueAttr,
/*offset=*/nullptr,
/*alignment=*/stickifiedConstOp.getAlignmentAttr());

Expand Down
1 change: 1 addition & 0 deletions src/Accelerators/NNPA/Dialect/ZHigh/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ add_onnx_mlir_library(OMZHighOps
ZHighOps/Stick/Stick.cpp
ZHighOps/StickForGRU/StickForGRU.cpp
ZHighOps/StickForLSTM/StickForLSTM.cpp
ZHighOps/StickifiedConstant/StickifiedConstant.cpp
ZHighOps/StickifiedConstantOfShape/StickifiedConstantOfShape.cpp
ZHighOps/Unstick/Unstick.cpp

Expand Down
1 change: 1 addition & 0 deletions src/Accelerators/NNPA/Dialect/ZHigh/ZHigh.td
Original file line number Diff line number Diff line change
Expand Up @@ -869,6 +869,7 @@ def ZHighStickifiedConstantOp:ZHigh_Op<"StickifiedConstant", [Pure]> {
let arguments = (ins OptionalAttr<AnyAttr>:$value,
DefaultValuedAttr<I64Attr, "4096">:$alignment);
let results = (outs AnyZTensor:$output);
let hasCustomAssemblyFormat = 1;
}

def ZHighStickifiedConstantOfShapeOp:ZHigh_Op<"StickifiedConstantOfShape", [Pure,
Expand Down
Loading

0 comments on commit 48cf039

Please sign in to comment.