Skip to content

Commit

Permalink
Use getRawBytes to make the data reading faster
Browse files: browse the repository at this point in the history
Signed-off-by: Tung D. Le <[email protected]>
  • Loading branch information
tungld committed Nov 20, 2024
1 parent 7dea632 commit 6d044be
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 7 deletions.
15 changes: 13 additions & 2 deletions src/Accelerators/NNPA/Transform/ZHigh/ZHighConstPropagation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,8 @@ static void getRawData(ElementsAttr attr_, std::vector<char> &data) {
"Must be DenseElementsAttr or DisposableElementsAttr");

if (disposalAttr) {
data.resize(numElements * getEltSizeInBytes(elemTy));
disposalAttr.readRawBytes(data);
ArrayBuffer<char> dstBytes = disposalAttr.getRawBytes();
data = dstBytes.get();
return;
}

Expand Down Expand Up @@ -94,6 +94,17 @@ ZHighStickifiedConstantOp emitZHighStickifiedConstant(PatternRewriter &rewriter,

// Attribute type: tensor<sizeInBytes x i8>
int64_t sizeInBytes = ztensor->buffer_size;

// DenseResourceElementsAttr valueAttr = DenseUI8ResourceElementsAttr::get(
// RankedTensorType::get({sizeInBytes}, rewriter.getI8Type()),
// stickifiedConstant.getOperation()
// ->getDialect()
// ->getNamespace(), // use the dialect as the blob "hint"
// HeapAsmResourceBlob::allocateAndCopyWithAlign(
// llvm::ArrayRef((char *)ztensor->buffer, sizeInBytes),
// alignof(char)));
// allochelper_ztensor_free(ztensor);

RankedTensorType dataType =
RankedTensorType::get({sizeInBytes}, rewriter.getI8Type());
std::unique_ptr<llvm::MemoryBuffer> memBuf = llvm::MemoryBuffer::getMemBuffer(
Expand Down
10 changes: 5 additions & 5 deletions src/Dialect/ONNX/ElementsAttr/DisposableElementsAttr.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -259,10 +259,11 @@ class DisposableElementsAttr
template <typename X>
void readArray(MutableArrayRef<X> dst) const;

// Copies out the elements in a flat array in row-major order.
// Returns a pointer to the underlying data as a flat byte array, if
// everything aligns, otherwise makes and returns a copy.
// If the element type is bool, the data holds one byte (with value 0 or 1)
// per bool (unlike DenseElementsAttr::getRawData(), which bit-packs bools).
void readRawBytes(MutableArrayRef<char> dst) const;
onnx_mlir::ArrayBuffer<char> getRawBytes() const;

// Returns a pointer to the underlying data as a flat WideNum array, if
// everything aligns, otherwise makes and returns a copy.
Expand Down Expand Up @@ -313,11 +314,10 @@ class DisposableElementsAttr
// Warning: This is inefficient because it calls unflattenIndex on flatIndex.
size_t flatIndexToBufferPos(size_t flatIndex) const;

// Returns a pointer to the underlying data as a flat byte array, if
// everything aligns, otherwise makes and returns a copy.
// Copies out the elements in a flat array in row-major order.
// If the element type is bool, the data holds one byte (with value 0 or 1)
// per bool (unlike DenseElementsAttr::getRawData(), which bit-packs bools).
onnx_mlir::ArrayBuffer<char> getRawBytes() const;
void readRawBytes(MutableArrayRef<char> dst) const;

}; // class DisposableElementsAttr

Expand Down

0 comments on commit 6d044be

Please sign in to comment.