Skip to content

Commit

Permalink
[MLIR][LLVM][Mem2Reg] Extends support for partial stores (#89740)
Browse files Browse the repository at this point in the history
This commit enhances the LLVM dialect's Mem2Reg interfaces to support
partial stores to memory slots. To achieve this support, the `getStored`
interface method has to be extended with a parameter of the reaching
definition, which is now necessary to produce the resulting value after
this store.
  • Loading branch information
Dinistro authored Apr 24, 2024
1 parent 9f2a068 commit 6e9ea6e
Show file tree
Hide file tree
Showing 5 changed files with 305 additions and 87 deletions.
1 change: 1 addition & 0 deletions mlir/include/mlir/Interfaces/MemorySlotInterfaces.td
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@ def PromotableMemOpInterface : OpInterface<"PromotableMemOpInterface"> {
"::mlir::Value", "getStored",
(ins "const ::mlir::MemorySlot &":$slot,
"::mlir::RewriterBase &":$rewriter,
"::mlir::Value":$reachingDef,
"const ::mlir::DataLayout &":$dataLayout)
>,
InterfaceMethod<[{
Expand Down
214 changes: 159 additions & 55 deletions mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ bool LLVM::LoadOp::loadsFrom(const MemorySlot &slot) {
/// Always returns false: this op reports that it does not store to `slot`.
bool LLVM::LoadOp::storesTo(const MemorySlot &slot) { return false; }

/// Must never be invoked on a load: reaching this function aborts through
/// `llvm_unreachable`. All parameters are unused and exist only so that the
/// signature matches the `getStored` interface method, which takes the
/// reaching definition in addition to the slot, rewriter and data layout.
Value LLVM::LoadOp::getStored(const MemorySlot &slot, RewriterBase &rewriter,
                              Value reachingDef, const DataLayout &dataLayout) {
  llvm_unreachable("getStored should not be called on LoadOp");
}

Expand Down Expand Up @@ -142,23 +142,29 @@ static bool isSupportedTypeForConversion(Type type) {
}

/// Checks that `rhs` can be converted to `lhs` by a sequence of casts and
/// truncations.
/// truncations. Checks for narrowing or widening conversion compatibility
/// depending on `narrowingConversion`.
static bool areConversionCompatible(const DataLayout &layout, Type targetType,
Type srcType) {
Type srcType, bool narrowingConversion) {
if (targetType == srcType)
return true;

if (!isSupportedTypeForConversion(targetType) ||
!isSupportedTypeForConversion(srcType))
return false;

uint64_t targetSize = layout.getTypeSize(targetType);
uint64_t srcSize = layout.getTypeSize(srcType);

// Pointer casts will only be sane when the bitsize of both pointer types is
// the same.
if (isa<LLVM::LLVMPointerType>(targetType) &&
isa<LLVM::LLVMPointerType>(srcType))
return layout.getTypeSize(targetType) == layout.getTypeSize(srcType);
return targetSize == srcSize;

return layout.getTypeSize(targetType) <= layout.getTypeSize(srcType);
if (narrowingConversion)
return targetSize <= srcSize;
return targetSize >= srcSize;
}

/// Checks if `dataLayout` describes a big endian layout.
Expand All @@ -167,22 +173,49 @@ static bool isBigEndian(const DataLayout &dataLayout) {
return endiannessStr && endiannessStr == "big";
}

/// The size of a byte in bits.
constexpr const static uint64_t kBitsInByte = 8;
/// Reinterprets `val` as an integer whose bitwidth equals the bit size that
/// `dataLayout` reports for the value's type.
///
/// Integer values are returned unchanged. Pointer values are converted with
/// `LLVM::PtrToIntOp`, every other supported type with `LLVM::BitcastOp`.
/// The value's type must be supported for conversion (asserted).
static Value castToSameSizedInt(RewriterBase &rewriter, Location loc, Value val,
                                const DataLayout &dataLayout) {
  Type valType = val.getType();
  assert(isSupportedTypeForConversion(valType) &&
         "expected value to have a convertible type");

  // Integers already have the requested form.
  if (isa<IntegerType>(valType))
    return val;

  uint64_t bitWidth = dataLayout.getTypeSizeInBits(valType);
  IntegerType sameSizedInt = rewriter.getIntegerType(bitWidth);

  if (!isa<LLVM::LLVMPointerType>(valType))
    return rewriter.createOrFold<LLVM::BitcastOp>(loc, sameSizedInt, val);
  return rewriter.createOrFold<LLVM::PtrToIntOp>(loc, sameSizedInt, val);
}

/// Turns the integer-typed `val` into a value of `targetType`.
///
/// Returns `val` itself when it already has `targetType`. A pointer target is
/// produced with `LLVM::IntToPtrOp`, any other supported target with
/// `LLVM::BitcastOp`. Both the integer-ness of `val` and the convertibility of
/// `targetType` are asserted.
static Value castIntValueToSameSizedType(RewriterBase &rewriter, Location loc,
                                         Value val, Type targetType) {
  Type intType = val.getType();
  assert(isa<IntegerType>(intType) &&
         "expected value to have an integer type");
  assert(isSupportedTypeForConversion(targetType) &&
         "expected the target type to be supported for conversions");

  // No cast is needed when the types already agree.
  if (intType == targetType)
    return val;

  if (!isa<LLVM::LLVMPointerType>(targetType))
    return rewriter.createOrFold<LLVM::BitcastOp>(loc, targetType, val);
  return rewriter.createOrFold<LLVM::IntToPtrOp>(loc, targetType, val);
}

/// Constructs operations that convert `inputValue` into a new value of type
/// `targetType`. Assumes that this conversion is possible.
static Value createConversionSequence(RewriterBase &rewriter, Location loc,
Value srcValue, Type targetType,
const DataLayout &dataLayout) {
// Get the types of the source and target values.
/// Constructs operations that convert `srcValue` into a new value of type
/// `targetType`. Assumes the types have the same bitsize.
static Value castSameSizedTypes(RewriterBase &rewriter, Location loc,
Value srcValue, Type targetType,
const DataLayout &dataLayout) {
Type srcType = srcValue.getType();
assert(areConversionCompatible(dataLayout, targetType, srcType) &&
assert(areConversionCompatible(dataLayout, targetType, srcType,
/*narrowingConversion=*/true) &&
"expected that the compatibility was checked before");

uint64_t srcTypeSize = dataLayout.getTypeSize(srcType);
uint64_t targetTypeSize = dataLayout.getTypeSize(targetType);

// Nothing has to be done if the types are already the same.
if (srcType == targetType)
return srcValue;
Expand All @@ -196,48 +229,117 @@ static Value createConversionSequence(RewriterBase &rewriter, Location loc,
return rewriter.createOrFold<LLVM::AddrSpaceCastOp>(loc, targetType,
srcValue);

IntegerType valueSizeInteger =
rewriter.getIntegerType(srcTypeSize * kBitsInByte);
Value replacement = srcValue;
// For all other castable types, casting through integers is necessary.
Value replacement = castToSameSizedInt(rewriter, loc, srcValue, dataLayout);
return castIntValueToSameSizedType(rewriter, loc, replacement, targetType);
}

/// Constructs operations that convert `srcValue` into a new value of type
/// `targetType`. Performs bit-level extraction if the source type is larger
/// than the target type. Assumes that this conversion is possible, i.e. that
/// `areConversionCompatible(..., /*narrowingConversion=*/true)` holds.
///
/// Same-sized types are delegated to `castSameSizedTypes`. Otherwise the
/// source is cast to a same-sized integer, shifted right on big endian
/// layouts so that the most significant bits end up in the low bits,
/// truncated to the target bit size, and finally cast to `targetType`.
static Value createExtractAndCast(RewriterBase &rewriter, Location loc,
                                  Value srcValue, Type targetType,
                                  const DataLayout &dataLayout) {
  // Get the types of the source and target values.
  Type srcType = srcValue.getType();
  assert(areConversionCompatible(dataLayout, targetType, srcType,
                                 /*narrowingConversion=*/true) &&
         "expected that the compatibility was checked before");

  uint64_t srcTypeSize = dataLayout.getTypeSizeInBits(srcType);
  uint64_t targetTypeSize = dataLayout.getTypeSizeInBits(targetType);
  if (srcTypeSize == targetTypeSize)
    return castSameSizedTypes(rewriter, loc, srcValue, targetType, dataLayout);

  // First, cast the value to a same-sized integer type.
  Value replacement = castToSameSizedInt(rewriter, loc, srcValue, dataLayout);

  // From here on the target is strictly smaller than the source, so the
  // integer has to be truncated. On big endian layouts the relevant bits are
  // the most significant ones and must be shifted down first.
  if (isBigEndian(dataLayout)) {
    uint64_t shiftAmount = srcTypeSize - targetTypeSize;
    // Build the constant and the shift on the integer-cast value: the
    // original source may be a pointer or another non-integer type, for which
    // neither an integer attribute nor a logical shift is valid.
    auto shiftConstant = rewriter.create<LLVM::ConstantOp>(
        loc, rewriter.getIntegerAttr(replacement.getType(), shiftAmount));
    replacement =
        rewriter.createOrFold<LLVM::LShrOp>(loc, replacement, shiftConstant);
  }

  replacement = rewriter.create<LLVM::TruncOp>(
      loc, rewriter.getIntegerType(targetTypeSize), replacement);

  // Now cast the integer to the actual target type if required.
  return castIntValueToSameSizedType(rewriter, loc, replacement, targetType);
}

/// Constructs operations that insert the bits of `srcValue` into the
/// "beginning" of `reachingDef` (beginning is endianness dependent) and
/// returns the combined value, which has the type of `reachingDef`.
/// Assumes that this conversion is possible, i.e. that
/// `areConversionCompatible(..., /*narrowingConversion=*/false)` holds.
///
/// Same-sized types are delegated to `castSameSizedTypes`; in that case the
/// store overwrites the whole reaching definition.
static Value createInsertAndCast(RewriterBase &rewriter, Location loc,
                                 Value srcValue, Value reachingDef,
                                 const DataLayout &dataLayout) {
  Type slotType = reachingDef.getType();
  assert(areConversionCompatible(dataLayout, slotType, srcValue.getType(),
                                 /*narrowingConversion=*/false) &&
         "expected that the compatibility was checked before");
  uint64_t valueTypeSize = dataLayout.getTypeSizeInBits(srcValue.getType());
  uint64_t slotTypeSize = dataLayout.getTypeSizeInBits(slotType);
  if (slotTypeSize == valueTypeSize)
    return castSameSizedTypes(rewriter, loc, srcValue, slotType, dataLayout);

  // In the case where the store only overwrites parts of the memory,
  // bit fiddling is required to construct the new value.

  // First convert both values to integers of the same size.
  Value defAsInt = castToSameSizedInt(rewriter, loc, reachingDef, dataLayout);
  Value valueAsInt = castToSameSizedInt(rewriter, loc, srcValue, dataLayout);
  // Extend the value to the size of the reaching definition.
  valueAsInt =
      rewriter.createOrFold<LLVM::ZExtOp>(loc, defAsInt.getType(), valueAsInt);
  uint64_t sizeDifference = slotTypeSize - valueTypeSize;

  // `maskValue` keeps the bits of the reaching definition that the store does
  // not overwrite.
  APInt maskValue;
  if (isBigEndian(dataLayout)) {
    // On big endian systems, a store to the base pointer overwrites the most
    // significant bits. To accommodate for this, the stored value needs to be
    // shifted into the according position.
    Value bigEndianShift = rewriter.create<LLVM::ConstantOp>(
        loc, rewriter.getIntegerAttr(defAsInt.getType(), sizeDifference));
    valueAsInt =
        rewriter.createOrFold<LLVM::ShlOp>(loc, valueAsInt, bigEndianShift);

    // Build a mask that has the most significant bits set to zero.
    // Note: This is the same as 2^sizeDifference - 1
    maskValue = APInt::getAllOnes(sizeDifference).zext(slotTypeSize);
  } else {
    // Build a mask that has the least significant bits set to zero.
    // Note: This is the same as -(2^valueTypeSize)
    maskValue = APInt::getAllOnes(valueTypeSize).zext(slotTypeSize);
    maskValue.flipAllBits();
  }

  // Mask out the affected bits ...
  Value mask = rewriter.create<LLVM::ConstantOp>(
      loc, rewriter.getIntegerAttr(defAsInt.getType(), maskValue));
  Value masked = rewriter.createOrFold<LLVM::AndOp>(loc, defAsInt, mask);

  // ... and combine the result with the new value.
  Value combined = rewriter.createOrFold<LLVM::OrOp>(loc, masked, valueAsInt);

  return castIntValueToSameSizedType(rewriter, loc, combined, slotType);
}

/// Returns the value the slot holds after this store, given the value
/// `reachingDef` it held before.
///
/// `reachingDef` must be non-null and have the slot's element type
/// (asserted). The stored value is inserted into `reachingDef` by
/// `createInsertAndCast`, so the result also covers stores that overwrite
/// only part of the slot.
Value LLVM::StoreOp::getStored(const MemorySlot &slot, RewriterBase &rewriter,
                               Value reachingDef,
                               const DataLayout &dataLayout) {
  assert(reachingDef && reachingDef.getType() == slot.elemType &&
         "expected the reaching definition's type to match the slot's type");
  return createInsertAndCast(rewriter, getLoc(), getValue(), reachingDef,
                             dataLayout);
}

bool LLVM::LoadOp::canUsesBeRemoved(
Expand All @@ -249,11 +351,10 @@ bool LLVM::LoadOp::canUsesBeRemoved(
Value blockingUse = (*blockingUses.begin())->get();
// If the blocking use is the slot ptr itself, there will be enough
// context to reconstruct the result of the load at removal time, so it can
// be removed (provided it loads the exact stored value and is not
// volatile).
// be removed (provided it is not volatile).
return blockingUse == slot.ptr && getAddr() == slot.ptr &&
areConversionCompatible(dataLayout, getResult().getType(),
slot.elemType) &&
slot.elemType, /*narrowingConversion=*/true) &&
!getVolatile_();
}

Expand All @@ -263,9 +364,8 @@ DeletionKind LLVM::LoadOp::removeBlockingUses(
const DataLayout &dataLayout) {
// `canUsesBeRemoved` checked this blocking use must be the loaded slot
// pointer.
Value newResult =
createConversionSequence(rewriter, getLoc(), reachingDefinition,
getResult().getType(), dataLayout);
Value newResult = createExtractAndCast(rewriter, getLoc(), reachingDefinition,
getResult().getType(), dataLayout);
rewriter.replaceAllUsesWith(getResult(), newResult);
return DeletionKind::Delete;
}
Expand All @@ -283,7 +383,8 @@ bool LLVM::StoreOp::canUsesBeRemoved(
return blockingUse == slot.ptr && getAddr() == slot.ptr &&
getValue() != slot.ptr &&
areConversionCompatible(dataLayout, slot.elemType,
getValue().getType()) &&
getValue().getType(),
/*narrowingConversion=*/false) &&
!getVolatile_();
}

Expand Down Expand Up @@ -838,6 +939,7 @@ bool LLVM::MemsetOp::storesTo(const MemorySlot &slot) {
}

Value LLVM::MemsetOp::getStored(const MemorySlot &slot, RewriterBase &rewriter,
Value reachingDef,
const DataLayout &dataLayout) {
// TODO: Support non-integer types.
return TypeSwitch<Type, Value>(slot.elemType)
Expand Down Expand Up @@ -1149,6 +1251,7 @@ bool LLVM::MemcpyOp::storesTo(const MemorySlot &slot) {
}

/// Returns the value stored to `slot` by this memcpy, as computed by the
/// shared `memcpyGetStored` helper. `reachingDef` and `dataLayout` are part
/// of the interface signature but are not used here.
Value LLVM::MemcpyOp::getStored(const MemorySlot &slot, RewriterBase &rewriter,
Value reachingDef,
const DataLayout &dataLayout) {
return memcpyGetStored(*this, slot, rewriter);
}
Expand Down Expand Up @@ -1199,7 +1302,7 @@ bool LLVM::MemcpyInlineOp::storesTo(const MemorySlot &slot) {
}

/// Returns the value stored to `slot` by this inline memcpy, as computed by
/// the shared `memcpyGetStored` helper. `reachingDef` and `dataLayout` are
/// part of the interface signature but are not used here.
Value LLVM::MemcpyInlineOp::getStored(const MemorySlot &slot,
                                      RewriterBase &rewriter, Value reachingDef,
                                      const DataLayout &dataLayout) {
  return memcpyGetStored(*this, slot, rewriter);
}
Expand Down Expand Up @@ -1252,6 +1355,7 @@ bool LLVM::MemmoveOp::storesTo(const MemorySlot &slot) {
}

/// Returns the value stored to `slot` by this memmove, as computed by the
/// shared `memcpyGetStored` helper. `reachingDef` and `dataLayout` are part
/// of the interface signature but are not used here.
Value LLVM::MemmoveOp::getStored(const MemorySlot &slot, RewriterBase &rewriter,
Value reachingDef,
const DataLayout &dataLayout) {
return memcpyGetStored(*this, slot, rewriter);
}
Expand Down
2 changes: 2 additions & 0 deletions mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,7 @@ bool memref::LoadOp::loadsFrom(const MemorySlot &slot) {
bool memref::LoadOp::storesTo(const MemorySlot &slot) { return false; }

/// Must never be invoked on a load: reaching this function aborts through
/// `llvm_unreachable`. The parameters are unused and exist only to match the
/// `getStored` interface signature.
Value memref::LoadOp::getStored(const MemorySlot &slot, RewriterBase &rewriter,
Value reachingDef,
const DataLayout &dataLayout) {
llvm_unreachable("getStored should not be called on LoadOp");
}
Expand Down Expand Up @@ -242,6 +243,7 @@ bool memref::StoreOp::storesTo(const MemorySlot &slot) {
}

/// Returns the stored value unchanged. `slot`, `rewriter`, `reachingDef` and
/// `dataLayout` are part of the interface signature but are not used here.
Value memref::StoreOp::getStored(const MemorySlot &slot, RewriterBase &rewriter,
Value reachingDef,
const DataLayout &dataLayout) {
return getValue();
}
Expand Down
Loading

0 comments on commit 6e9ea6e

Please sign in to comment.