Have code, need some canonicalizers, definitely haven't updated tests
krzysz00 committed Nov 18, 2024
1 parent 5328767 commit 829c3d5
Showing 10 changed files with 232 additions and 117 deletions.
@@ -32,42 +32,6 @@ namespace mlir::iree_compiler {
using namespace mlir::iree_compiler::IREE::VectorExt;
using VectorValue = TypedValue<VectorType>;

/// Helper to linearize the given |ids| with maximum values given as |sizes|.
/// The resulting linear index is scaled by |elementCount| and the element
/// |offset| is added to produce the final element index. For example,
///
/// IDs = [d0, d1, d2, d3]
/// sizes = [s0, s1, s2, s3]
/// linear_index = d0 * (s1 * s2 * s3)
/// + d1 * (s2 * s3)
/// + d2 * (s3)
/// + d3
/// return element_index = linear_index * |elementCount| + |offset|;
static Value linearizeIndex(OpBuilder &builder, Value offset,
ArrayRef<OpFoldResult> ids, ArrayRef<int64_t> sizes,
int64_t elementCount) {
SmallVector<AffineExpr> exprs(ids.size() + 1);
bindSymbolsList(builder.getContext(), MutableArrayRef{exprs});
AffineExpr idExpr = builder.getAffineConstantExpr(0);

for (int i = 0, e = ids.size(); i < e; ++i) {
if (sizes[i] > 1) {
// Multiply by the residual threads along this dimension (which must be
// faster changing than all previous dimensions) and add the id for this
// dimension.
idExpr = idExpr * builder.getAffineConstantExpr(sizes[i]) + exprs[i];
}
}
idExpr = idExpr * builder.getAffineConstantExpr(elementCount);
idExpr = idExpr + exprs.back();
SmallVector<OpFoldResult> mapArgs(ids);
mapArgs.push_back(offset);
return affine::makeComposedAffineApply(
builder, offset.getLoc(),
AffineMap::get(0, mapArgs.size(), idExpr), mapArgs)
.getResult();
}
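For reference, the arithmetic this removed helper implemented can be restated as a scalar sketch (a hypothetical illustration, not code from this commit; the original's skipping of size-1 dimensions is an equivalent simplification, since an id bounded by a size of 1 must be 0):

#include <cstdint>
#include <vector>

// Horner-form restatement of the removed linearizeIndex (illustration only).
int64_t linearizeIndexScalar(const std::vector<int64_t> &ids,
                             const std::vector<int64_t> &sizes,
                             int64_t elementCount, int64_t offset) {
  int64_t linear = 0;
  for (size_t i = 0; i < ids.size(); ++i)
    linear = linear * sizes[i] + ids[i];
  // E.g. ids = [1, 0, 2], sizes = [2, 4, 8]:
  // linear = 1 * (4 * 8) + 0 * 8 + 2 = 34.
  return linear * elementCount + offset;
}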

/// Given a set of base transfer |indices|, |offsets| for the batch/outer
/// dimensions, and distributed warp and thread indices, computes the indices
/// of the distributed transfer operation based on the |vectorLayout|.
@@ -94,16 +58,24 @@ static SmallVector<Value> getTransferIndicesFromNestedLayout(
continue;
}
unsigned pos = cast<AffineDimExpr>(dim).getPosition();
SmallVector<OpFoldResult> ids = {
warpIndices[i], b.getIndexAttr(batchOffsets[i]),
b.getIndexAttr(outerVectorOffsets[i]), threadIndices[i]};
Value offset = indices[pos];
int64_t elementCount = vectorLayout.getElementTile()[i];
Location loc = offset.getLoc();
SmallVector<Value> ids = {
warpIndices[i], b.create<arith::ConstantIndexOp>(loc, batchOffsets[i]),
b.create<arith::ConstantIndexOp>(loc, outerVectorOffsets[i]),
threadIndices[i], offset};
// The order in which a vector dimension is "tiled" is
// subgroups -> batches -> outer vectors -> threads -> elements
SmallVector<int64_t> sizes = {
vectorLayout.getSubgroupTile()[i], vectorLayout.getBatchTile()[i],
vectorLayout.getOuterTile()[i], vectorLayout.getThreadTile()[i]};
slicedIndices[pos] = linearizeIndex(b, indices[pos], ids, sizes,
vectorLayout.getElementTile()[i]);
vectorLayout.getOuterTile()[i], vectorLayout.getThreadTile()[i],
elementCount};
// The offset is often not an offset within the thread ID. Fixing this
// to allow linearizing everywhere might be nice to do in the future.
// For now, mark this as not disjoint so we don't misoptimize.
slicedIndices[pos] = b.create<affine::AffineLinearizeIndexOp>(
loc, ids, sizes, /*disjoint=*/false);
}
return slicedIndices;
}
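The affine.linearize_index op built above computes the same Horner-form sum as the removed helper, now over a single five-level basis. As a scalar model of the value it produces (an illustrative sketch, not the op's definition):

#include <cstdint>

// index = (((warp * B + batch) * O + outer) * T + thread) * E + offset,
// where [B, O, T, E] = [batchTile, outerTile, threadTile, elementTile].
// With disjoint = false, nothing asserts that each id is smaller than its
// basis entry, so out-of-range ids legitimately carry into higher slots.
int64_t transferIndexScalar(int64_t warp, int64_t batch, int64_t outer,
                            int64_t thread, int64_t offset, int64_t B,
                            int64_t O, int64_t T, int64_t E) {
  return (((warp * B + batch) * O + outer) * T + thread) * E + offset;
}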
@@ -123,19 +95,21 @@ getElementVectorTileShape(NestedLayoutAttr vectorLayout)

/// Computes the warp and thread indices for the given vector layout from a
/// single linearized thread ID.
static void populateWarpAndThreadIndices(RewriterBase &rewriter, Value threadId,
int64_t subgroupSize,
NestedLayoutAttr vectorLayout,
SmallVector<Value> &warpIndices,
SmallVector<Value> &threadIndices) {
static LogicalResult populateWarpAndThreadIndices(
RewriterBase &rewriter, Value threadId, int64_t subgroupSize,
NestedLayoutAttr vectorLayout, SmallVector<Value> &warpIndices,
SmallVector<Value> &threadIndices) {
// The delinearized thread IDs are returned from outermost to innermost,
// i.e. before applying the dimension ordering described by the layout.
int64_t rank = vectorLayout.getRank();
SmallVector<Value> threadIds =
vectorLayout.computeThreadIds(threadId, subgroupSize, rewriter);
if (threadIds.empty() && rank != 0)
return failure();
warpIndices = SmallVector<Value>(threadIds.begin(), threadIds.begin() + rank);
threadIndices = SmallVector<Value>(threadIds.begin() + rank,
threadIds.begin() + 2 * rank);
return success();
}

namespace {
@@ -189,8 +163,12 @@ struct DistributeTransferRead final
VectorValue acc = cast<VectorValue>(zero);

SmallVector<Value> warpIndices, threadIndices;
populateWarpAndThreadIndices(rewriter, threadId, subgroupSize, vectorLayout,
warpIndices, threadIndices);
if (failed(populateWarpAndThreadIndices(rewriter, threadId, subgroupSize,
vectorLayout, warpIndices,
threadIndices))) {
return rewriter.notifyMatchFailure(
readOp, "warp or thread tiles have overlapping strides");
}

ValueRange indices = readOp.getIndices();
SmallVector<int64_t> strides(rank, 1);
Expand Down Expand Up @@ -259,8 +237,12 @@ struct DistributeTransferWrite final
int64_t rank = vectorLayout.getRank();

SmallVector<Value> warpIndices, threadIndices;
populateWarpAndThreadIndices(rewriter, threadId, subgroupSize, vectorLayout,
warpIndices, threadIndices);
if (failed(populateWarpAndThreadIndices(rewriter, threadId, subgroupSize,
vectorLayout, warpIndices,
threadIndices))) {
return rewriter.notifyMatchFailure(
writeOp, "warp or thread tiles have overlapping strides");
}

Value distributedVector =
getDistributed(rewriter, writeOp.getVector(), vectorLayout);
Expand Down Expand Up @@ -1089,8 +1071,12 @@ struct DistributeStep final : OpDistributionPattern<vector::StepOp> {
stepOp, "missing nested layout for step op result");
}
SmallVector<Value> subgroupIndices, threadIndices;
populateWarpAndThreadIndices(rewriter, threadId, subgroupSize, resultLayout,
subgroupIndices, threadIndices);
if (failed(populateWarpAndThreadIndices(rewriter, threadId, subgroupSize,
resultLayout, subgroupIndices,
threadIndices))) {
return rewriter.notifyMatchFailure(
stepOp, "warp or thread tiles have overlapping strides");
}
ArrayRef<int64_t> subgroupStrides = resultLayout.getSubgroupStrides();
ArrayRef<int64_t> subgroupLengths = resultLayout.getSubgroupTile();
ArrayRef<int64_t> threadStrides = resultLayout.getThreadStrides();
38 changes: 23 additions & 15 deletions compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp
@@ -13,6 +13,7 @@
#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUEnums.h"
#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUInterfaces.h"
#include "iree/compiler/Codegen/Utils/VectorOpUtils.h"
#include "iree/compiler/Utils/Indexing.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/STLForwardCompat.h"
#include "llvm/ADT/SmallVector.h"
@@ -469,28 +470,35 @@ static LogicalResult populateCanonicalOffsetsSizesAndStrides(

OpFoldResult zero = builder.getIndexAttr(0);
OpFoldResult one = builder.getIndexAttr(1);
Value cZero = builder.createOrFold<arith::ConstantIndexOp>(loc, 0);
canonicalStrides.append(rankReducedShape.size(), one);

SmallVector<Value> vtids;
SmallVector<int64_t> vtidBasis;
SmallVector<size_t> dimToVtid;
if (failed(basisFromSizesStrides(subgroupLayout.thread,
subgroupLayout.tstrides, vtidBasis,
dimToVtid))) {
return failure();
}
auto splitLaneId = builder.create<affine::AffineDelinearizeIndexOp>(
loc, laneId, vtidBasis, /*hasOuterBound=*/false);

// Each thread grabs `element` contiguous elements, so the vtid needs to
// be multiplied by `element` to get the start of its chunk of data.
// vtid: virtual thread id
// tid: lane id
// vtid = ((tid floordiv stride_i) mod size_i) * element_i.
SmallVector<OpFoldResult> vtids;
for (auto [dimSize, dimStride, element] :
llvm::zip_equal(subgroupLayout.thread, subgroupLayout.tstrides,
subgroupLayout.element)) {
if (dimSize == 1) {
vtids.push_back(zero);
continue;
}

// ((tid floordiv stride) mod size) * element.
AffineExpr tidExpr = builder.getAffineDimExpr(0);
AffineMap vtidMap = AffineMap::get(
/*dims=*/1, /*syms=*/0,
(tidExpr.floorDiv(dimStride) % dimSize) * element);
Value vtid = builder.create<affine::AffineApplyOp>(loc, vtidMap, laneId);
//
// Instead of computing those maps, we use one big `delinearize` expression
// in order to prevent unwanted "simplifications" on affine maps that
// worsen the generated code quality.
for (auto [splitResultIdx, element] :
llvm::zip_equal(dimToVtid, subgroupLayout.element)) {
Value vtid = splitLaneId.getResult(splitResultIdx);
if (element != 1)
vtid = builder.create<affine::AffineLinearizeIndexOp>(
loc, ValueRange{vtid, cZero}, ArrayRef<int64_t>{element});
vtids.push_back(vtid);
}
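A scalar model of what the delinearize/linearize pair above computes per dimension (a sketch under the assumption of one (size, stride) pair per distributed dimension):

#include <cstdint>

// The delinearize result for a dimension is (laneId floordiv stride) mod
// size, and linearizing [vtid, 0] by (element) scales it by `element`,
// reproducing the vtid formula from the comment above.
int64_t vtidScalar(int64_t laneId, int64_t stride, int64_t size,
                   int64_t element) {
  return ((laneId / stride) % size) * element;
}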

@@ -8,6 +8,7 @@

#include "iree/compiler/Codegen/Dialect/VectorExt/IR/VectorExtDialect.h"
#include "iree/compiler/Codegen/Dialect/VectorExt/IR/VectorExtOps.h"
#include "iree/compiler/Utils/Indexing.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
@@ -436,51 +437,28 @@ NestedLayoutAttr::computeThreadIds(Value threadId, int64_t subgroupSize,

Location loc = threadId.getLoc();

AffineExpr tidExpr, size, stride;
bindDims(rewriter.getContext(), tidExpr);
bindSymbols(rewriter.getContext(), size, stride);

// (tid floordiv stride) mod size
AffineMap threadTidMap =
AffineMap::get(/*dims=*/1, /*syms=*/2, tidExpr.floorDiv(stride) % size);

// (tid floordiv (stride * subgroup_size)) mod size
AffineMap subgroupTidMap = AffineMap::get(
/*dims=*/1, /*syms=*/2, tidExpr.floorDiv(stride * subgroupSize) % size);

for (auto [dimSize, dimStride] :
llvm::zip_equal(getSubgroupTile(), getSubgroupStrides())) {
// Dimension is not distributed.
if (dimStride == 0) {
virtualTids.push_back(rewriter.create<arith::ConstantOp>(
loc, rewriter.getIndexAttr(dimStride)));
continue;
}
SmallVector<int64_t> subgroupBasis, threadBasis;
SmallVector<size_t> subgroupDimToResult, threadDimToResult;

auto sizeVal =
rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(dimSize));
auto strideVal = rewriter.create<arith::ConstantOp>(
loc, rewriter.getIndexAttr(dimStride));
virtualTids.push_back(rewriter.create<affine::AffineApplyOp>(
loc, subgroupTidMap, ValueRange{threadId, sizeVal, strideVal}));
}
if (failed(basisFromSizesStrides(getSubgroupTile(), getSubgroupStrides(),
subgroupBasis, subgroupDimToResult)))
return {};
if (failed(basisFromSizesStrides(getThreadTile(), getThreadStrides(),
threadBasis, threadDimToResult)))
return {};

for (auto [dimSize, dimStride] :
llvm::zip_equal(getThreadTile(), getThreadStrides())) {
// Dimension is not distributed.
if (dimStride == 0) {
virtualTids.push_back(rewriter.create<arith::ConstantOp>(
loc, rewriter.getIndexAttr(dimStride)));
continue;
}
// Add the subgroup_size to the end of the subgroup delinearization basis.
subgroupBasis.push_back(subgroupSize);

auto sizeVal =
rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(dimSize));
auto strideVal = rewriter.create<arith::ConstantOp>(
loc, rewriter.getIndexAttr(dimStride));
virtualTids.push_back(rewriter.create<affine::AffineApplyOp>(
loc, threadTidMap, ValueRange{threadId, sizeVal, strideVal}));
}
auto subgroupSplit = rewriter.create<affine::AffineDelinearizeIndexOp>(
loc, threadId, subgroupBasis, /*hasOuterBound=*/false);
auto threadSplit = rewriter.create<affine::AffineDelinearizeIndexOp>(
loc, threadId, threadBasis, /*hasOuterBound=*/false);

llvm::transform(subgroupDimToResult, std::back_inserter(virtualTids),
[&](size_t idx) { return subgroupSplit.getResult(idx); });
llvm::transform(threadDimToResult, std::back_inserter(virtualTids),
[&](size_t idx) { return threadSplit.getResult(idx); });

return virtualTids;
}
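basisFromSizesStrides lives in the iree/compiler/Utils/Indexing.h header this commit adds, and its body is not part of this diff. Judging from the call sites above (a hedged reconstruction, not the actual implementation), it turns per-dimension (size, stride) pairs into a delinearization basis plus a map from each dimension to a result of an affine.delinearize_index built with hasOuterBound = false. A scalar model of that op's semantics, with a worked example of the assumed outputs:

#include <cstdint>
#include <vector>

// Delinearization of `tid` over `basis` with no outer bound: result 0 is
// the unbounded leading chunk, followed by one result per basis entry.
std::vector<int64_t>
delinearizeNoOuterBound(int64_t tid, const std::vector<int64_t> &basis) {
  std::vector<int64_t> results(basis.size() + 1);
  for (size_t i = basis.size(); i-- > 0;) {
    results[i + 1] = tid % basis[i];
    tid /= basis[i];
  }
  results[0] = tid;
  return results;
}
// For sizes = [2, 4], strides = [4, 1], the assumed helper outputs would
// be basis = [2, 4] and dimToResult = [1, 2]:
// delinearizeNoOuterBound(tid, {2, 4}) = {tid / 8, (tid / 4) % 2, tid % 4},
// whose entries 1 and 2 are exactly ((tid / stride_i) % size_i) for the
// two dimensions. Overlapping (size, stride) pairs admit no such basis,
// which is presumably when the helper fails.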
@@ -140,7 +140,8 @@ def NestedLayoutAttr : IREEVectorExt_Attr<"NestedLayout",
The subgroups are placed contiguously with their shape and ordering
determined by:
- `subgroup_tile`: Sizes of this level of tiling
- `subgroup_order`: Ordering of dimensions, from outermost to innermost
- `subgroup_strides`: Strides of this level of tiling. 0 means the
  dimension is not distributed. Tiling levels must not overlap.

For example, subgroup_tile=[4, 2], subgroup_order=[1, 0] will
arrange the subgroups in the order:
@@ -196,7 +197,9 @@ def NestedLayoutAttr : IREEVectorExt_Attr<"NestedLayout",
distribution is represented by:

- thread_tile: Sizes of this level of tiling
- thread_order: Ordering of dimensions, from outermost to innermost
- thread_strides: Strides of this level of tiling. 0 means this level is not
distributed.
Tiling levels must not overlap.

Examples of thread distribution over a 8x4 shape:

@@ -290,7 +293,7 @@ def NestedLayoutAttr : IREEVectorExt_Attr<"NestedLayout",

let extraClassDeclaration = [{
// Returns the subgroup/lane ids delinearized from a single linearized
// thread ID.
// thread ID. Returns the empty vector on failure.
SmallVector<Value> computeThreadIds(Value threadId, int64_t subgroupSize, RewriterBase &rewriter) const;
}];

9 changes: 5 additions & 4 deletions compiler/src/iree/compiler/Codegen/Utils/GPUUtils.cpp
@@ -82,10 +82,11 @@ getSubgroupIdsAndCounts(mlir::OpBuilder &builder, mlir::Location loc,
mlir::Value subgroupId =
builder.create<mlir::gpu::ThreadIdOp>(loc, indexType, dimAttr[i]);
if (i == 0) {
mlir::AffineExpr d0 = builder.getAffineDimExpr(0);
subgroupId = mlir::affine::makeComposedAffineApply(
builder, loc, d0.floorDiv(builder.getAffineConstantExpr(warpSize)),
{subgroupId});
subgroupId =
builder
.create<affine::AffineDelinearizeIndexOp>(
loc, subgroupId, ArrayRef<int64_t>{numSubgroups[i], warpSize})
.getResult(0);
}
procInfo[numDims - 1 - i] = {
subgroupId,
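Result 0 of the two-element delinearize is the subgroup id and matches the d0 floordiv warpSize affine map it replaces, while additionally recording the numSubgroups bound on the result. As a scalar sketch:

#include <cstdint>

// For 0 <= tid < numSubgroups * warpSize, result 0 of
// delinearize(tid, [numSubgroups, warpSize]) is tid / warpSize.
int64_t subgroupIdScalar(int64_t tid, int64_t warpSize) {
  return tid / warpSize;
}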
1 change: 1 addition & 0 deletions compiler/src/iree/compiler/Utils/BUILD.bazel
@@ -21,6 +21,7 @@ iree_compiler_cc_library(
"ElementPackingUtils.cpp",
"EquivalenceUtils.cpp",
"FlatbufferUtils.cpp",
"Indexing.cpp",
"ModuleUtils.cpp",
"OptionUtils.cpp",
"PassUtils.cpp",
1 change: 1 addition & 0 deletions compiler/src/iree/compiler/Utils/CMakeLists.txt
@@ -36,6 +36,7 @@ iree_cc_library(
"ElementPackingUtils.cpp"
"EquivalenceUtils.cpp"
"FlatbufferUtils.cpp"
"Indexing.cpp"
"ModuleUtils.cpp"
"OptionUtils.cpp"
"PassUtils.cpp"