Skip to content

Commit

Permalink
update according to comments
Browse files Browse the repository at this point in the history
  • Loading branch information
lialan committed Dec 19, 2024
1 parent 0572ec3 commit 3099d0b
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 32 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,9 @@ struct ConvertTensorImportOp
RankedTensorType tensorType,
ValueRange dynamicDims,
OpBuilder &builder) {
// If the encoding attr is about packed storage then we don't need all this
// If the encoding attr is about packed storage then we don't need
// assertion, because packed storage attribute is about memory layout and it
// doesn't affect the tensor shape.
if (IREE::Encoding::hasPackedStorageAttr(tensorType)) {
return success();
}
Expand Down
45 changes: 24 additions & 21 deletions compiler/src/iree/compiler/Utils/ElementPackingUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,8 @@ llvm::cl::opt<bool> clEnableI1Support(

namespace mlir::iree_compiler {

bool needToPackSubByteElementBitWidth(unsigned bitWidth) {
return needToPackSubByteElementBitWidth(
bitWidth, /*isPackedStorage=*/clEnableI1Support);
}

bool needToPackSubByteElementBitWidth(unsigned bitWidth, bool isPackedStorage) {
static bool needToPackSubByteElementBitWidthImpl(unsigned bitWidth,
bool isPackedStorage) {
// Enable i1 support if requested.
if (isPackedStorage && bitWidth == 1) {
return true;
Expand All @@ -46,30 +42,31 @@ bool needToPackSubByteElementBitWidth(unsigned bitWidth, bool isPackedStorage) {
return bitWidth < 8 && llvm::isPowerOf2_32(bitWidth) && bitWidth != 1;
}

bool needToPackSubByteElementBitWidth(unsigned bitWidth) {
return needToPackSubByteElementBitWidthImpl(
bitWidth, /*isPackedStorage=*/clEnableI1Support);
}

bool needToPackSubByteElements(RankedTensorType shapedType) {
unsigned bitWidth = IREE::Util::getTypeBitWidth(shapedType.getElementType());
// Two paths to enable packed storage for i1 tensors: the attribute or cl
// option. The cl option will be dropped once frontend supports emitting
// tensors with attributes.
bool isPackedStorage =
IREE::Encoding::hasPackedStorageAttr(shapedType) || clEnableI1Support;
return needToPackSubByteElementBitWidth(bitWidth, isPackedStorage);
return needToPackSubByteElementBitWidthImpl(bitWidth, isPackedStorage);
}

Type legalizeStorageElementType(Type elementType) {
return legalizeStorageElementType(elementType,
/*isPackedStorage=*/clEnableI1Support);
}

Type legalizeStorageElementType(Type elementType, bool isPackedStorage) {
static Type legalizeStorageElementTypeImpl(Type elementType,
bool isPackedStorage) {
// Only handle integers; floats in MLIR all have aligned widths (today).
auto intType = dyn_cast<IntegerType>(elementType);
if (!intType)
return elementType;

// For sub-byte elements, default to pack them into bytes.
unsigned bitWidth = intType.getWidth();
if (needToPackSubByteElementBitWidth(bitWidth, isPackedStorage))
if (needToPackSubByteElementBitWidthImpl(bitWidth, isPackedStorage))
return elementType;

// Otherwise, extend them to the next power-of-two bit width.
Expand All @@ -81,6 +78,12 @@ Type legalizeStorageElementType(Type elementType, bool isPackedStorage) {
intType.getSignedness());
}

Type legalizeStorageElementType(Type elementType) {
// Consider packed storage for i1 tensors if cl opt is set.
return legalizeStorageElementTypeImpl(elementType,
/*isPackedStorage=*/clEnableI1Support);
}

Value calculateStorageElementCountInBytes(Location loc,
RankedTensorType shapedType,
ValueRange dynamicDims,
Expand All @@ -93,16 +96,16 @@ Value calculateStorageElementCountInBytes(Location loc,
loc, builder, shapedType, dynamicDims);
}

// TODO: remove cl options once frontend can emit packed i1 tensors.
// TODO(lialan): remove cl options once frontend can emit packed i1 tensors.
bool isPackedStorage =
IREE::Encoding::hasPackedStorageAttr(shapedType) || clEnableI1Support;
Type alignedElementType =
legalizeStorageElementType(shapedType.getElementType(), isPackedStorage);
Type alignedElementType = legalizeStorageElementTypeImpl(
shapedType.getElementType(), isPackedStorage);
unsigned elementBits = IREE::Util::getTypeBitWidth(alignedElementType);

// Calculate all static dims first, if any.
int64_t staticCount = 1;
if (!needToPackSubByteElementBitWidth(elementBits, isPackedStorage)) {
if (!needToPackSubByteElementBitWidthImpl(elementBits, isPackedStorage)) {
staticCount *= IREE::Util::getRoundedElementByteWidth(alignedElementType);
}

Expand All @@ -117,7 +120,7 @@ Value calculateStorageElementCountInBytes(Location loc,
value = builder.createOrFold<arith::MulIOp>(loc, value, dim);
}
// Sub-byte packing requires putting multiple elements in the same byte.
if (needToPackSubByteElementBitWidth(elementBits, isPackedStorage)) {
if (needToPackSubByteElementBitWidthImpl(elementBits, isPackedStorage)) {
assert(8 % elementBits == 0);
unsigned byteElements = 8 / elementBits;
// TODO(antiagainst): We may want to emit runtime check to make sure this is
Expand All @@ -140,12 +143,12 @@ Value calculateStorageElementOffsetInBytes(Location loc,
// TODO: remove cl options once frontend can emit packed i1 tensors.
bool isPackedStorage =
IREE::Encoding::hasPackedStorageAttr(originalType) || clEnableI1Support;
Type alignedElementType = legalizeStorageElementType(
Type alignedElementType = legalizeStorageElementTypeImpl(
originalType.getElementType(), isPackedStorage);
unsigned elementBits = IREE::Util::getTypeBitWidth(alignedElementType);

// Sub-byte packing requires putting multiple elements in the same byte.
if (needToPackSubByteElementBitWidth(elementBits, isPackedStorage)) {
if (needToPackSubByteElementBitWidthImpl(elementBits, isPackedStorage)) {
Value byteElements =
builder.create<arith::ConstantIndexOp>(loc, 8 / elementBits);
// TODO(antiagainst): We may want to emit runtime check to make sure this is
Expand Down
10 changes: 0 additions & 10 deletions compiler/src/iree/compiler/Utils/ElementPackingUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,6 @@ namespace mlir::iree_compiler {

/// Returns true if the given |bitWidth|, if appearing at runtime-kernel
/// interface, is less than a byte that should be tightly packed together.
bool needToPackSubByteElementBitWidth(unsigned bitWidth, bool isPackedStorage);

/// Temporary wrapper for the above function. `isPackedStorage` will be
/// determined by the cl option. This allows enabling packed storage for i1
/// in both attribute and cl option ways.
bool needToPackSubByteElementBitWidth(unsigned bitWidth);

/// Returns true if the given |shapedType|, if appearing at runtime-kernel
Expand All @@ -33,11 +28,6 @@ bool needToPackSubByteElements(RankedTensorType shapedType);
/// runtime and kernel. For such cases, we perform tight packing for supported
/// sub-byte elements, and expand to the next power-of-two bit width for other
/// cases.
Type legalizeStorageElementType(Type elementType, bool isPackedStorage);

/// Temporary wrapper for the above function. `isPackedStorage` will be
/// determined by the cl option. This allows enabling packed storage for i1
/// in both attribute and cl option ways.
Type legalizeStorageElementType(Type elementType);

/// Emits IR with the given |builder| to calculate the total number of bytes
Expand Down

0 comments on commit 3099d0b

Please sign in to comment.