Skip to content

Commit

Permalink
Revert "[NNPA] Memory reduction of stickified constant by stickifying…
Browse files Browse the repository at this point in the history
… at file writing (onnx#2917)"

This reverts commit 33b466e.

Signed-off-by: Tung D. Le <[email protected]>
  • Loading branch information
tungld committed Nov 13, 2024
1 parent fa91033 commit f5a25af
Show file tree
Hide file tree
Showing 37 changed files with 336 additions and 707 deletions.
133 changes: 45 additions & 88 deletions src/Accelerators/NNPA/Conversion/ZHighToZLow/ZHighToZLow.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -190,47 +190,27 @@ Value insertAllocOrEmitZeroConstant(ArrayRef<IndexExpr> dims,
affine::normalizeMemRefType(mlir::cast<MemRefType>(zMemRefType.value));

// Create a ZHighStickifiedConstantOp.

// Keep the previous implementation, which generates stickified data in
// ZHighConstPropagationPass. To use it, uncomment this block and define the
// directive "NNPA_ZHIGH_STICKIFIEDCONST_GEN".
//
// #ifdef NNPA_ZHIGH_STICKIFIEDCONST_GEN
// // Set zero in value attribute as DenseResourceElementsAttribute.
// ZHighStickifiedConstantOp stickifiedConstant =
// rewriter.create<ZHighStickifiedConstantOp>(loc, resType,
// /*stickified=*/rewriter.getBoolAttr(true),
// /*value=*/nullptr,
// /*alignment=*/rewriter.getI64IntegerAttr(4096));
//
// // Use a dense resource attribute to store stickified data.
// // Attribute type: tensor<sizeInBytes x i8>
// int64_t sizeInBytes =
// affine::getIntOrFloatMemRefSizeInBytes(resType).value();
// char *rawData = static_cast<char *>(malloc(sizeInBytes));
// assert(rawData && "failed to allocate memory for stickified data");
// memset(rawData, 0, sizeInBytes);
// DenseResourceElementsAttr valueAttr =
// DenseUI8ResourceElementsAttr::get(
// RankedTensorType::get({sizeInBytes}, rewriter.getI8Type()),
// stickifiedConstant.getOperation()
// ->getDialect()
// ->getNamespace(), // use the dialect as the blob "hint"
// HeapAsmResourceBlob::allocateAndCopyWithAlign(
// llvm::ArrayRef(rawData, sizeInBytes), alignof(char)));
// stickifiedConstant.setValueAttr(valueAttr);
// free(rawData);
// #else

// Set zero in value attribute as SplatElementsAttr.
FloatAttr floatZero = rewriter.getFloatAttr(resType.getElementType(), 0.0);
ZHighStickifiedConstantOp stickifiedConstant = rewriter.create<
ZHighStickifiedConstantOp>(loc, resType,
/*stickified=*/rewriter.getBoolAttr(true),
/*value=*/SplatElementsAttr::get(cast<ShapedType>(resType), floatZero),
/*alignment=*/rewriter.getI64IntegerAttr(4096));

// #endif // NNPA_ZHIGH_STICKIFIEDCONST_GEN
ZHighStickifiedConstantOp stickifiedConstant =
rewriter.create<ZHighStickifiedConstantOp>(loc, resType,
/*value=*/nullptr,
/*alignment=*/rewriter.getI64IntegerAttr(4096));

// Use a dense resource attribute to store stickified data.
// Attribute type: tensor<sizeInBytes x i8>
int64_t sizeInBytes =
affine::getIntOrFloatMemRefSizeInBytes(resType).value();
char *rawData = static_cast<char *>(malloc(sizeInBytes));
assert(rawData && "failed to allocate memory for stickified data");
memset(rawData, 0, sizeInBytes);
DenseResourceElementsAttr valueAttr = DenseUI8ResourceElementsAttr::get(
RankedTensorType::get({sizeInBytes}, rewriter.getI8Type()),
stickifiedConstant.getOperation()
->getDialect()
->getNamespace(), // use the dialect as the blob "hint"
HeapAsmResourceBlob::allocateAndCopyWithAlign(
llvm::ArrayRef(rawData, sizeInBytes), alignof(char)));
stickifiedConstant.setValueAttr(valueAttr);
free(rawData);

res = stickifiedConstant.getResult();
} else {
Expand Down Expand Up @@ -706,7 +686,7 @@ struct ZHighToZLowUnstickOpLowering : public ConversionPattern {
};

//===----------------------------------------------------------------------===//
// Lower ZHigh Stickified Constant to ZLow Stickified Constant
// Lower ZHigh Stickified Constant to KrnlGlobal
//===----------------------------------------------------------------------===//

struct ZHighToZLowStickifiedConstantOpLowering : public ConversionPattern {
Expand All @@ -719,7 +699,7 @@ struct ZHighToZLowStickifiedConstantOpLowering : public ConversionPattern {
LogicalResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
Location loc = op->getLoc();
ZHighStickifiedConstantOp zhighStickifiedConstOp =
ZHighStickifiedConstantOp stickifiedConstOp =
llvm::dyn_cast<ZHighStickifiedConstantOp>(op);

// Convert ZTensor type to MemRefType.
Expand All @@ -733,59 +713,36 @@ struct ZHighToZLowStickifiedConstantOpLowering : public ConversionPattern {
affine::normalizeMemRefType(mlir::cast<MemRefType>(zMemRefType.value));
ArrayRef<int64_t> normalizedShape = normalizedType.getShape();

// Create ZLowStickifiedConstantOp.
StringAttr layout =
getZTensorLayoutAttr(rewriter, *op->result_type_begin());

// Keep the previous implementation, which generates stickified data in
// ZHighConstPropagationPass. To use it, uncomment this block and define the
// directive "NNPA_ZHIGH_STICKIFIEDCONST_GEN".
//
// #ifdef NNPA_ZHIGH_STICKIFIEDCONST_GEN
// // Lower to KrnlGlobalOp
// // Get dense resource attribute.
// auto blob = mlir::cast<DenseResourceElementsAttr>(
// zhighStickifiedConstOp.getValue().value())
// .getRawHandle()
// .getBlob();
// assert(blob && "Expecting dense resource with a valid blob");
// ArrayRef<char> data = blob->getData();
// // Validate the stickified tensor.
// int64_t memRefSizeInBytes = getMemRefEltSizeInBytes(normalizedType);
// memRefSizeInBytes *= normalizedType.getNumElements();
// assert((data.size() == static_cast<uint64_t>(memRefSizeInBytes)) &&
// "The stickified tensor's buffer size and MemRef's size
// mismatched");
// // Create a KrnlGlobalOp.
// KrnlGlobalOp constantOp =
// rewriter.create<KrnlGlobalOp>(loc, zMemRefType.value,
// /*shape=*/
// rewriter.getI64ArrayAttr(normalizedShape),
// /*name=*/
// rewriter.getStringAttr(
// "constant_stickify_" + std::to_string(constantID)),
// /*value=*/zhighStickifiedConstOp.getValueAttr(),
// /*offset=*/nullptr,
// /*alignment=*/zhighStickifiedConstOp.getAlignmentAttr());
// #else
ZLowStickifiedConstantOp constantOp =
rewriter.create<ZLowStickifiedConstantOp>(loc,
mlir::cast<MemRefType>(zMemRefType.value),
// Get dense resource attribute.
auto blob = mlir::cast<DenseResourceElementsAttr>(
stickifiedConstOp.getValue().value())
.getRawHandle()
.getBlob();
assert(blob && "Expecting dense resource with a valid blob");
ArrayRef<char> data = blob->getData();

// Validate the stickified tensor.
int64_t memRefSizeInBytes = getMemRefEltSizeInBytes(normalizedType);
memRefSizeInBytes *= normalizedType.getNumElements();
assert((data.size() == static_cast<uint64_t>(memRefSizeInBytes)) &&
"The stickified tensor's buffer size and MemRef's size mismatched");

// Create a KrnlGlobalOp.
KrnlGlobalOp constantGlobal =
rewriter.create<KrnlGlobalOp>(loc, zMemRefType.value,
/*shape=*/
rewriter.getI64ArrayAttr(normalizedShape),
/*name=*/
rewriter.getStringAttr(
"constant_stickify_" + std::to_string(constantID)),
/*stickified=*/zhighStickifiedConstOp.getStickifiedAttr(),
/*value=*/zhighStickifiedConstOp.getValueAttr(),
/*layout=*/layout,
/*offset=*/rewriter.getI64IntegerAttr(0),
/*alignment=*/zhighStickifiedConstOp.getAlignmentAttr());
// #endif // NNPA_ZHIGH_STICKIFIEDCONST_GEN
/*value=*/stickifiedConstOp.getValueAttr(),
/*offset=*/nullptr,
/*alignment=*/stickifiedConstOp.getAlignmentAttr());

// Increment constant ID:
constantID++;

rewriter.replaceOp(op, constantOp.getResult());
rewriter.replaceOp(op, constantGlobal.getResult());
return success();
}
};
Expand Down
1 change: 0 additions & 1 deletion src/Accelerators/NNPA/Dialect/ZHigh/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@ add_onnx_mlir_library(OMZHighOps
OMONNXOps # Use ONNXShapeHelper
OMLayoutHelper
OMShapeHelperOpInterface
OMStickify
OMNNPACompilerOptions
MLIRIR

Expand Down
5 changes: 1 addition & 4 deletions src/Accelerators/NNPA/Dialect/ZHigh/ZHigh.td
Original file line number Diff line number Diff line change
Expand Up @@ -862,14 +862,11 @@ def ZHighStickifiedConstantOp:ZHigh_Op<"StickifiedConstant", [Pure]> {
let summary = "ZHigh Stickified Constant operation";
let description = [{
This operator produces a constant tensor to store stickified data.
`value` attribute has original constant or stickified constant.
`stickified` attribute indicates the `value` is already stickified or not.
Stickified data is opaque and must be 4K-aligned. One who produces
the stickified data must make sure its size in bytes consistent with
the output tensor's size.
}];
let arguments = (ins BoolAttr:$stickified,
OptionalAttr<AnyAttr>:$value,
let arguments = (ins OptionalAttr<AnyAttr>:$value,
DefaultValuedAttr<I64Attr, "4096">:$alignment);
let results = (outs AnyZTensor:$output);
}
Expand Down
51 changes: 1 addition & 50 deletions src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/OpHelper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

#include "src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/OpHelper.hpp"
#include "src/Accelerators/NNPA/Compiler/NNPACompilerOptions.hpp"
#include "src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps.hpp"
#include "src/Accelerators/NNPA/Support/LayoutHelper.hpp"

#include "src/Dialect/ONNX/DialectBuilder.hpp"
Expand Down Expand Up @@ -481,55 +482,5 @@ IntegerAttr getDefaultSaturation(PatternRewriter &rewriter) {
return IntegerAttr();
}

/// Map an MLIR element type to the corresponding zDNN data type.
/// A 16-bit float maps to FP16 and a 32-bit float maps to FP32; every other
/// type is reported through llvm_unreachable.
zdnn_data_types mlirTypeToZDNNType(Type elementType) {
  // Anything that is not a float type is rejected up front.
  if (!mlir::isa<FloatType>(elementType))
    llvm_unreachable("Unsupported data type.");

  const unsigned bitWidth = mlir::cast<FloatType>(elementType).getWidth();
  switch (bitWidth) {
  case 16:
    return FP16;
  case 32:
    return FP32;
  default:
    llvm_unreachable("Unsupported data type.");
  }
}

/// Get stickified data from a DenseElementsAttr.
///
/// Reads the raw data of `denseAttr`, builds the zDNN pre-transformed and
/// transformed descriptors from its shape, element type and `layout`, runs the
/// software stickify, and copies the resulting ztensor buffer into a freshly
/// malloc'ed buffer.
///
/// \param denseAttr attribute holding the data to stickify.
/// \param layout    layout attribute used to select the zDNN data layout.
/// \return a non-owning view over the malloc'ed copy of the stickified data.
///         Nothing in this function frees that buffer, so the caller is
///         responsible for releasing it with free().
ArrayRef<char> getStickifiedDataOfDenseElemAttr(
    DenseElementsAttr denseAttr, StringAttr layout) {
  ArrayRef<int64_t> shape = denseAttr.getType().getShape();
  Type elementType = denseAttr.getType().getElementType();
  int rank = shape.size();
  // Read the attribute's raw data.
  std::vector<char> attrData;
  getRawData(denseAttr, attrData);
  // Call stickify.
  zdnn_tensor_desc pre_tfrmd_desc, tfrmd_desc;
  // pre-transformed desc.
  zdnn_data_layouts zDNNLayout =
      convertLayoutAttrToZDNNDataLayout(rank, layout);
  // If zDNNLayout is NHWC, we stickify directly from NCHW.
  if (zDNNLayout == ZDNN_NHWC)
    zDNNLayout = ZDNN_NCHW;
  zdnn_data_types zDNNType = onnx_mlir::zhigh::mlirTypeToZDNNType(elementType);
  set_info_pre_transformed_desc(&pre_tfrmd_desc, zDNNLayout, zDNNType, shape);
  // transformed desc.
  zdnn_status status = generate_transformed_desc(&pre_tfrmd_desc, &tfrmd_desc);
  assert(status == ZDNN_OK && "failed to generate transformed descriptor");
  // Stick data using the software stickify.
  zdnn_ztensor ztensor;
  init_ztensor(&pre_tfrmd_desc, &tfrmd_desc, &ztensor);
  status = allochelper_ztensor_alloc(&ztensor);
  assert(status == ZDNN_OK && "failed to allocate ztensor buffer");
  status = stickify(&ztensor, attrData.data());
  assert(status == ZDNN_OK && "failed to stickify data");
  // `status` is only read by the asserts above; keep builds without
  // assertions free of unused-variable warnings.
  (void)status;
  // Copy the stickified buffer out before the ztensor is released.
  int64_t sizeInBytes = ztensor.buffer_size;
  char *rawData = static_cast<char *>(malloc(sizeInBytes));
  assert(rawData && "failed to allocate memory for stickified data");
  memcpy(rawData, ztensor.buffer, sizeInBytes);
  allochelper_ztensor_free(&ztensor);
  return llvm::ArrayRef(rawData, sizeInBytes);
}

} // namespace zhigh
} // namespace onnx_mlir
8 changes: 0 additions & 8 deletions src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/OpHelper.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps.hpp"
#include "src/Accelerators/NNPA/Support/Stickify/Stickify.hpp"

namespace onnx_mlir {
namespace zhigh {
Expand Down Expand Up @@ -89,13 +88,6 @@ bool hasNNPAUse(mlir::Value v);
/// Get saturation settings.
mlir::IntegerAttr getDefaultSaturation(mlir::PatternRewriter &rewriter);

/// MLIR type to zDNN type.
zdnn_data_types mlirTypeToZDNNType(mlir::Type elementType);

/// Get stickified data from denseElementAttribute
mlir::ArrayRef<char> getStickifiedDataOfDenseElemAttr(
mlir::DenseElementsAttr denseAttr, mlir::StringAttr layout);

} // namespace zhigh
} // namespace onnx_mlir
#endif
5 changes: 0 additions & 5 deletions src/Accelerators/NNPA/Dialect/ZLow/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,8 @@ add_onnx_mlir_library(OMZLowOps
DEPENDS
OMZLowIncGen
OMONNXZLowCombineIncGen
OMKrnlGlobalOpInterface

LINK_LIBS PUBLIC
MLIRIR
OMMlirDialects
OMZHighOps

ACCEL_INCLUDE_DIRS PRIVATE
${NNPA_INCLUDE_PATH}
)
17 changes: 0 additions & 17 deletions src/Accelerators/NNPA/Dialect/ZLow/ZLow.td
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@ def ZMemRef : MemRefOf<[DLF16]>;
//===----------------------------------------------------------------------===//

include "mlir/Interfaces/SideEffectInterfaces.td"
include "src/Interface/KrnlGlobalOpInterface.td"

def ZLowAddOp:ZLow_Op<"add", [MemRefsNormalizable,
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
Expand Down Expand Up @@ -548,20 +547,4 @@ def ZLowConvertF32ToDLF16VectorOp:ZLow_Op<"vec_f32_to_dlf16", [Pure]> {
];
}

// ZLow operation representing a stickified constant.
// It carries the MemRefsNormalizable trait, declares the methods of
// KrnlGlobalOpInterface, and produces a single ZMemRef result.
// Attributes: `shape`, `name`, `stickified`, optional `value`, optional
// `layout`, optional `offset`, and `alignment` (defaults to 4096).
// NOTE(review): `stickified` presumably tells whether `value` already holds
// stickified data -- confirm against the lowering that creates this op.
// NOTE(review): the `description` field below is empty and should be filled in.
def ZLowStickifiedConstantOp:ZLow_Op<"stickifiedConstant", [MemRefsNormalizable,
DeclareOpInterfaceMethods<KrnlGlobalOpInterface>]> {
let summary = "ZLow Stickified Constant operation.";
let description = [{

}];
let arguments = (ins AnyAttr:$shape,
StrAttr:$name,
BoolAttr:$stickified,
OptionalAttr<AnyAttr>:$value,
OptionalAttr<StrAttr>:$layout,
OptionalAttr<I64Attr>:$offset,
DefaultValuedAttr<I64Attr, "4096">:$alignment);
let results = (outs ZMemRef:$output);
}

#endif // ZLOW_OPS
Loading

0 comments on commit f5a25af

Please sign in to comment.