[TOSA] Change PadOp padding to tosa.shape
This patch changes PadOp's padding input to type !tosa.shape<2 * rank>
(where rank is the rank of the PadOp's input), instead of a rank-1 tensor
of 2 * rank integer elements.
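
For illustration, a minimal before/after sketch of the migration, mirroring
the PadOp documentation examples updated in this patch:

```mlir
// Before: padding supplied as a rank-1 integer tensor.
%0 = arith.constant dense<[1, 2, 3, 4]> : tensor<4xi32>
tosa.pad %arg0, %0 : (tensor<1x2xf32>, tensor<4xi32>) -> (tensor<4x9xf32>)

// After: padding supplied as a !tosa.shape value.
%0 = tosa.const_shape { value = dense<[1, 2, 3, 4]> : tensor<4xindex> } : () -> !tosa.shape<4>
tosa.pad %arg0, %0 : (tensor<1x2xf32>, !tosa.shape<4>) -> (tensor<4x9xf32>)
```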

Signed-off-by: Tai Ly <[email protected]>
Change-Id: I08526a699d6b8ebbaf9ee092cd37580e5d78f919
Tai78641 authored and Jerry-Ge committed Jan 15, 2025
1 parent 7aec7ca commit f5d6242
Showing 17 changed files with 197 additions and 159 deletions.
2 changes: 2 additions & 0 deletions mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h
@@ -39,6 +39,8 @@ ParseResult parseTypeOrAttr(OpAsmParser &parser, TypeAttr &typeAttr,
 void printTypeOrAttr(OpAsmPrinter &p, Operation *op, TypeAttr type,
                      Attribute attr);
 
+bool collectShapeValue(Operation *op, llvm::SmallVector<int64_t> &newShape);
+
 #include "mlir/Dialect/Tosa/IR/TosaInterfaces.h.inc"
 
 } // namespace tosa
10 changes: 5 additions & 5 deletions mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -1557,21 +1557,21 @@ def Tosa_PadOp : Tosa_InferShapedTypeOp<"pad"> {
     Example:
 
     ```mlir
-    %0 = arith.constant dense<[1, 2, 3, 4]> : tensor<4xi32>
-    tosa.pad %arg0, %0 : (tensor<1x2xf32>, tensor<4xi32>) -> (tensor<4x9xf32>)
+    %0 = tosa.const_shape { value = dense<[1, 2, 3, 4]> : tensor<4xindex> } : () -> !tosa.shape<4>
+    tosa.pad %arg0, %0 : (tensor<1x2xf32>, !tosa.shape<4>) -> (tensor<4x9xf32>)
     ```
 
     Example 2:
 
    ```mlir
-    %0 = arith.constant dense<[-1, 2, 3, 4]> : tensor<4xi32>
-    tosa.pad %arg0, %0 : (tensor<1x2xf32>, tensor<4xi32>) -> (tensor<?x9xf32>)
+    %0 = tosa.const_shape { value = dense<[-1, 2, 3, 4]> : tensor<4xindex> } : () -> !tosa.shape<4>
+    tosa.pad %arg0, %0 : (tensor<1x2xf32>, !tosa.shape<4>) -> (tensor<?x9xf32>)
    ```
   }];
 
   let arguments = (ins
     Tosa_RankedTensor:$input1,
-    TosaTensorRankOf<[Tosa_Int32Or64], [1]>:$padding,
+    Tosa_Shape:$padding,
     Optional<Tosa_ScalarTensor>:$pad_const,
     OptionalAttr<Tosa_PadOpQuantizationAttr>:$quantization_info
   );
5 changes: 5 additions & 0 deletions mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h
@@ -229,6 +229,11 @@ SmallVector<T> applyTOSAPermutation(ArrayRef<T> input,
   return permuted;
 }
 
+// Computes shape value using tosa const_shape op.
+Value getTosaConstShape(PatternRewriter &rewriter, Location loc,
+                        llvm::ArrayRef<int64_t> shape);
+SmallVector<int64_t> convertFromMlirShape(ArrayRef<int64_t> shape);
+
 } // namespace tosa
 } // namespace mlir
 
27 changes: 14 additions & 13 deletions mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
@@ -306,7 +306,16 @@ class PadConverter : public OpConversionPattern<tosa::PadOp> {
                   ConversionPatternRewriter &rewriter) const final {
     auto loc = padOp.getLoc();
     auto input = padOp.getInput1();
-    auto padding = padOp.getPadding();
+
+    ElementsAttr paddingElems;
+    if (!matchPattern(padOp.getPadding(), m_Constant(&paddingElems))) {
+      return rewriter.notifyMatchFailure(
+          padOp, "padding must be a static shape value");
+    }
+    llvm::SmallVector<int64_t> paddingVals;
+    for (auto idx : paddingElems.getValues<IntegerAttr>()) {
+      paddingVals.push_back(static_cast<int64_t>(idx.getInt()));
+    }
 
     ShapedType inputTy = cast<ShapedType>(input.getType());
     Type elementTy = inputTy.getElementType();
@@ -345,18 +354,10 @@ class PadConverter : public OpConversionPattern<tosa::PadOp> {
     highValues.reserve(rank);
 
     for (int i = 0; i < rank; i++) {
-      Value lowIndex = rewriter.create<arith::ConstantIndexOp>(loc, 2 * i);
-      Value highIndex = rewriter.create<arith::ConstantIndexOp>(loc, 2 * i + 1);
-      Value lowVal = rewriter.createOrFold<tensor::ExtractOp>(
-          loc, padding, ValueRange({lowIndex}));
-      Value highVal = rewriter.createOrFold<tensor::ExtractOp>(
-          loc, padding, ValueRange({highIndex}));
-
-      lowVal = rewriter.createOrFold<arith::IndexCastOp>(
-          loc, rewriter.getIndexType(), lowVal);
-      highVal = rewriter.createOrFold<arith::IndexCastOp>(
-          loc, rewriter.getIndexType(), highVal);
-
+      Value lowVal = rewriter.create<arith::ConstantOp>(
+          loc, rewriter.getIndexAttr(paddingVals[2 * i]));
+      Value highVal = rewriter.create<arith::ConstantOp>(
+          loc, rewriter.getIndexAttr(paddingVals[2 * i + 1]));
       lowValues.push_back(lowVal);
       highValues.push_back(highVal);
     }
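
With the padding folded to constants before conversion, the lowering can
emit index attributes directly. A sketch of the expected tensor.pad output
for the documentation example above (assuming a zero pad constant; SSA
names hypothetical):

```mlir
%cst = arith.constant 0.000000e+00 : f32
// Padding [1, 2, 3, 4] splits into low [1, 3] and high [2, 4]:
// dims (1+1+2, 2+3+4) = (4, 9).
%padded = tensor.pad %arg0 low[1, 3] high[2, 4] {
^bb0(%i: index, %j: index):
  tensor.yield %cst : f32
} : tensor<1x2xf32> to tensor<4x9xf32>
```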
70 changes: 40 additions & 30 deletions mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -210,6 +210,27 @@ void mlir::tosa::printTypeOrAttr(OpAsmPrinter &p, Operation *op, TypeAttr type,
   }
 }
 
+//===----------------------------------------------------------------------===//
+// TOSA shape inference helper
+//===----------------------------------------------------------------------===//
+bool mlir::tosa::collectShapeValue(Operation *op,
+                                   llvm::SmallVector<int64_t> &newShape) {
+  if (!op) {
+    return false;
+  }
+  if (auto constOp = mlir::dyn_cast<tosa::ConstShapeOp>(op)) {
+    Attribute constOpAttr = constOp->getAttr("value");
+    DenseElementsAttr elementsAttr = cast<DenseElementsAttr>(constOpAttr);
+    for (int i = 0; i < elementsAttr.size(); i++) {
+      int64_t val = elementsAttr.getValues<int64_t>()[i];
+      newShape.push_back(val);
+    }
+    return true;
+  }
+  // for undefined op, return false.
+  return false;
+}
+
 //===----------------------------------------------------------------------===//
 // TOSA Operator Verifiers.
 //===----------------------------------------------------------------------===//
@@ -823,51 +844,42 @@ LogicalResult tosa::PadOp::inferReturnTypeComponents(
     PadOp::Adaptor adaptor,
     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
   ShapeAdaptor inputShape(adaptor.getInput1().getType());
-  ShapeAdaptor paddingShape(adaptor.getPadding().getType());
+  auto paddingRank =
+      cast<tosa::shapeType>(adaptor.getPadding().getType()).getRank();
   SmallVector<int64_t> outputShape;
 
-  // If both inputs have unknown shape, we cannot determine the shape of the
-  // output.
-  if (!inputShape.hasRank() && !paddingShape.hasRank()) {
-    inferredReturnShapes.push_back(ShapedTypeComponents());
-    return success();
-  }
-
-  // If the input rank is unknown we can info the output rank using the
-  // padding shape's first dim.
+  // If the input rank is unknown, we can infer the output rank using the
+  // padding shape's rank divided by 2.
   if (!inputShape.hasRank()) {
-    if (paddingShape.isDynamicDim(0)) {
-      inferredReturnShapes.push_back(ShapedTypeComponents());
-      return success();
-    }
-
-    outputShape.resize(paddingShape.getDimSize(0) / 2, ShapedType::kDynamic);
+    outputShape.resize(paddingRank / 2, ShapedType::kDynamic);
     inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
     return success();
   }
 
-  DenseIntElementsAttr paddings;
+  SmallVector<int64_t> paddingValues;
   // If the paddings value is not a constant, all dimensions must be dynamic.
-  if (!matchPattern(adaptor.getPadding(), m_Constant(&paddings))) {
+  if (!tosa::collectShapeValue(adaptor.getPadding().getDefiningOp(),
+                               paddingValues)) {
     outputShape.resize(inputShape.getRank(), ShapedType::kDynamic);
     inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
     return success();
   }
 
-  SmallVector<int64_t> paddingValues;
-  for (auto val : paddings) {
-    paddingValues.push_back(val.getSExtValue());
-  }
-
   outputShape.reserve(inputShape.getRank());
   for (int i = 0, s = inputShape.getRank(); i < s; i++) {
     if (inputShape.isDynamicDim(i)) {
       outputShape.push_back(ShapedType::kDynamic);
       continue;
     }
+    auto padFront = paddingValues[i * 2];
+    auto padBack = paddingValues[i * 2 + 1];
+    if (padFront < 0 || padBack < 0) {
+      // if either padding for dim i is -1, output dim is unknown
+      outputShape.push_back(ShapedType::kDynamic);
+      continue;
+    }
 
-    outputShape.push_back(inputShape.getDimSize(i) + paddingValues[i * 2] +
-                          paddingValues[i * 2 + 1]);
+    outputShape.push_back(inputShape.getDimSize(i) + padFront + padBack);
   }
 
   inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
@@ -877,17 +889,15 @@ LogicalResult tosa::PadOp::inferReturnTypeComponents(
 LogicalResult tosa::PadOp::verify() {
   RankedTensorType inputType = getInput1().getType();
   RankedTensorType outputType = getOutput().getType();
-  RankedTensorType paddingType = getPadding().getType();
+  auto paddingRank = cast<tosa::shapeType>(getPadding().getType()).getRank();
 
   if (inputType.getRank() != outputType.getRank())
     return emitOpError() << "expect same input and output tensor rank.";
 
-  if (!paddingType.isDynamicDim(0) &&
-      paddingType.getDimSize(0) != inputType.getRank() * 2)
+  if (paddingRank != inputType.getRank() * 2)
     return emitOpError() << "expected padding tensor dim 0 to have size "
                          << inputType.getRank() * 2
-                         << " (2*rank(shape1)) but got size "
-                         << paddingType.getDimSize(0);
+                         << " (2*rank(shape1)) but got size " << paddingRank;
 
   return success();
 }
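
With a tosa.const_shape operand, collectShapeValue lets the inference fold
the padding values into a static result type. A sketch (shapes chosen for
the example):

```mlir
%padding = tosa.const_shape { value = dense<[1, 1, 2, 2]> : tensor<4xindex> } : () -> !tosa.shape<4>
// Inferred dims: 2 + (1 + 1) = 4 and 3 + (2 + 2) = 7.
%0 = tosa.pad %arg0, %padding : (tensor<2x3xf32>, !tosa.shape<4>) -> (tensor<4x7xf32>)
```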
6 changes: 1 addition & 5 deletions mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeConv2D.cpp
@@ -81,11 +81,7 @@ struct Conv2DIsFullyConnected : public OpRewritePattern<tosa::Conv2DOp> {
       }
     }
 
-    auto padSizeTy = RankedTensorType::get({8}, rewriter.getI64Type());
-    auto padSize =
-        DenseIntElementsAttr::get(padSizeTy, ArrayRef<int64_t>(pad));
-    Value padSizeVal =
-        rewriter.create<tosa::ConstOp>(op->getLoc(), padSizeTy, padSize);
+    Value padSizeVal = getTosaConstShape(rewriter, op->getLoc(), pad);
 
     auto padTy = RankedTensorType::get({}, inputETy);
     auto padAttr = DenseElementsAttr::get(padTy, zeroAttr);
6 changes: 1 addition & 5 deletions mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp
@@ -108,11 +108,7 @@ struct DepthwiseConv2DIsMul : public OpRewritePattern<tosa::DepthwiseConv2DOp> {
       }
     }
 
-    auto padSizeTy = RankedTensorType::get({10}, rewriter.getI64Type());
-    auto padSize =
-        DenseIntElementsAttr::get(padSizeTy, ArrayRef<int64_t>(pad));
-    Value padSizeVal =
-        rewriter.create<tosa::ConstOp>(op->getLoc(), padSizeTy, padSize);
+    Value padSizeVal = getTosaConstShape(rewriter, op->getLoc(), pad);
 
     auto padTy = RankedTensorType::get({}, inputETy);
     auto padAttr = DenseElementsAttr::get(padTy, zeroAttr);
29 changes: 11 additions & 18 deletions mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp
@@ -135,15 +135,14 @@ class TransposeConvStridedConverter
     int64_t inputChannels = weightTy.getDimSize(3);
 
     // Pad the weight so that it is modulo of the striding.
-    llvm::SmallVector<int32_t, 8> weightPadding = {0, 0, 0, 0, 0, 0, 0, 0};
+    llvm::SmallVector<int64_t, 8> weightPadding = {0, 0, 0, 0, 0, 0, 0, 0};
     weightPadding[3] =
         (weightHeight % stride[0]) ? (stride[0] - weightHeight % stride[0]) : 0;
     weightPadding[5] =
-        (weightWidth % stride[1]) ? (stride[1] - weightWidth % stride[1]) : 0;
-    DenseElementsAttr weightPaddingAttr = DenseIntElementsAttr::get(
-        RankedTensorType::get({8}, rewriter.getI32Type()), weightPadding);
-    Value weightPaddingVal = CreateOpAndInferShape<tosa::ConstOp>(
-        rewriter, loc, weightPaddingAttr.getType(), weightPaddingAttr);
+        weightWidth % stride[1] ? stride[1] - weightWidth % stride[1] : 0;
+
+    Value weightPaddingVal =
+        getTosaConstShape(rewriter, op->getLoc(), weightPadding);
 
     if (op.getQuantizationInfo().has_value()) {
       auto quantInfo = op.getQuantizationInfo().value();
@@ -197,17 +196,14 @@
         /* axis = */ rewriter.getI32IntegerAttr(2));
 
     // We need to pad the input far enough that we can pull all values.
-    llvm::SmallVector<int32_t, 8> inputPadding = {0, 0, 0, 0, 0, 0, 0, 0};
+    llvm::SmallVector<int64_t, 8> inputPadding = {0, 0, 0, 0, 0, 0, 0, 0};
     inputPadding[2] += restridedWeightTy.getDimSize(1) - 1;
     inputPadding[3] += restridedWeightTy.getDimSize(1) - 1;
     inputPadding[4] += restridedWeightTy.getDimSize(2) - 1;
     inputPadding[5] += restridedWeightTy.getDimSize(2) - 1;
 
-    DenseElementsAttr inputPaddingAttr = DenseIntElementsAttr::get(
-        RankedTensorType::get({8}, rewriter.getI32Type()), inputPadding);
-
-    Value inputPaddingVal = CreateOpAndInferShape<tosa::ConstOp>(
-        rewriter, loc, inputPaddingAttr.getType(), inputPaddingAttr);
+    Value inputPaddingVal =
+        getTosaConstShape(rewriter, op->getLoc(), inputPadding);
 
     if (op.getQuantizationInfo().has_value()) {
       auto quantInfo = op.getQuantizationInfo().value();
@@ -310,17 +306,14 @@
             rewriter.getDenseI64ArrayAttr(sliceSize))
             .getResult();
 
-    llvm::SmallVector<int32_t, 8> resultPadding = {0, 0, 0, 0, 0, 0, 0, 0};
+    llvm::SmallVector<int64_t, 8> resultPadding = {0, 0, 0, 0, 0, 0, 0, 0};
     resultPadding[2] = resultPadTop;
     resultPadding[3] = resultTy.getDimSize(1) - resultPadTop - sliceSize[1];
     resultPadding[4] = resultPadLeft;
     resultPadding[5] = resultTy.getDimSize(2) - resultPadLeft - sliceSize[2];
 
-    DenseElementsAttr resultPaddingAttr = DenseIntElementsAttr::get(
-        RankedTensorType::get({8}, rewriter.getI32Type()), resultPadding);
-
-    Value resultPaddingVal = CreateOpAndInferShape<tosa::ConstOp>(
-        rewriter, loc, resultPaddingAttr.getType(), resultPaddingAttr);
+    Value resultPaddingVal =
+        getTosaConstShape(rewriter, op->getLoc(), resultPadding);
 
     Value resultPad = CreateOpAndInferShape<tosa::PadOp>(
         rewriter, loc, UnrankedTensorType::get(resultETy), slice,
15 changes: 15 additions & 0 deletions mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp
@@ -160,3 +160,18 @@ LogicalResult mlir::tosa::EqualizeRanks(ImplicitLocOpBuilder &builder,
 
   return success();
 }
+
+Value mlir::tosa::getTosaConstShape(PatternRewriter &rewriter, Location loc,
+                                    llvm::ArrayRef<int64_t> shape) {
+  auto attr = rewriter.getIndexTensorAttr(shape);
+  auto type = mlir::tosa::shapeType::get(rewriter.getContext(), shape.size());
+  mlir::Operation *mlir_op =
+      rewriter.create<tosa::ConstShapeOp>(loc, type, attr);
+  return mlir_op->getResult(0);
+}
+
+SmallVector<int64_t> mlir::tosa::convertFromMlirShape(ArrayRef<int64_t> shape) {
+  return to_vector(llvm::map_range(shape, [](int64_t dim) {
+    return ShapedType::isDynamic(dim) ? -1 : dim;
+  }));
+}
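
For reference, getTosaConstShape materializes IR of the following form for
a pad vector such as {0, 0, 1, 1, 1, 1, 0, 0} (a sketch; the SSA name is
hypothetical):

```mlir
%0 = tosa.const_shape { value = dense<[0, 0, 1, 1, 1, 1, 0, 0]> : tensor<8xindex> } : () -> !tosa.shape<8>
```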