Generalize MaxUnpool lowering including 2d case

jinchen62 committed Oct 9, 2024
1 parent dba2946 commit 80c2426

Showing 5 changed files with 90 additions and 24 deletions.
27 changes: 27 additions & 0 deletions include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -7159,6 +7159,33 @@ def Torch_AtenMaxPool2dOp : Torch_Op<"aten.max_pool2d", [
}];
}

def Torch_AtenMaxUnpool2dOp : Torch_Op<"aten.max_unpool2d", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::max_unpool2d : (Tensor, Tensor, int[], int[], int[]) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
AnyTorchTensorType:$indices,
AnyTorchListOfTorchIntType:$output_size,
AnyTorchListOfTorchIntType:$stride,
AnyTorchListOfTorchIntType:$padding
);
let results = (outs
AnyTorchOptionalTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenMaxUnpool2dOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 5, 1);
}
void AtenMaxUnpool2dOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 5, 1);
}
}];
}
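For intuition, the semantics this op models — scatter each pooled value back to the position its index records, zero everywhere else — can be sketched in eager PyTorch. Note that the functional API takes kernel_size and derives the output size, while the op above carries output_size, stride, and padding explicitly; this is a minimal illustrative sketch, not the lowering itself:

import torch
import torch.nn.functional as F

x = torch.arange(16.0).reshape(1, 1, 4, 4)
# Keep the argmax indices so the unpool can scatter values back.
pooled, indices = F.max_pool2d(x, kernel_size=2, stride=2, return_indices=True)
# Each pooled value returns to the position recorded in `indices`;
# every other element of the (1, 1, 4, 4) output is zero.
restored = F.max_unpool2d(pooled, indices, kernel_size=2, stride=2)
assert restored.shape == x.shape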

def Torch_AtenMaxPool2dWithIndicesOp : Torch_Op<"aten.max_pool2d_with_indices", [
AllowsTypeRefinement,
HasValueSemantics,
6 changes: 6 additions & 0 deletions lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp
@@ -3464,6 +3464,12 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
Value paddingList = createConstantIntList(binder, rewriter, padding);
Value stridesList = createConstantIntList(binder, rewriter, strides);

if (rank == 4) {
rewriter.replaceOpWithNewOp<Torch::AtenMaxUnpool2dOp>(
binder.op, resultType, data, indices, resultShapeList,
stridesList, paddingList);
return success();
}
rewriter.replaceOpWithNewOp<Torch::AtenMaxUnpool3dOp>(
binder.op, resultType, data, indices, resultShapeList, stridesList,
paddingList);
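With the 2d op available, the importer picks the target op by input rank. A trivial sketch of the dispatch rule (illustrative Python, not the importer itself):

def unpool_target_op(rank: int) -> str:
    # ONNX MaxUnpool on rank-4 (N, C, H, W) data lowers to aten.max_unpool2d;
    # rank-5 (N, C, D, H, W) data lowers to aten.max_unpool3d.
    return "aten.max_unpool2d" if rank == 4 else "aten.max_unpool3d"

assert unpool_target_op(4) == "aten.max_unpool2d"
assert unpool_target_op(5) == "aten.max_unpool3d"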
63 changes: 44 additions & 19 deletions lib/Conversion/TorchToLinalg/Pooling.cpp
@@ -596,37 +596,51 @@ namespace {
// input_size=2, output_size=5 and stride=2, kernel_size can be either 2 or 3).
// Worse, without knowing the kernel size we cannot even reliably detect such
// cases, and this conversion will just return invalid values.
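The ambiguity is easy to reproduce in eager PyTorch (1-d case for brevity): two different kernel sizes pool a length-5 input down to the same length-2 result at stride 2, so the pooled shape alone cannot determine the unpooled size — a small demonstration:

import torch
import torch.nn.functional as F

x = torch.arange(5.0).reshape(1, 1, 5)
# kernel_size 2 and 3 both yield a length-2 result at stride 2 ...
p2, i2 = F.max_pool1d(x, kernel_size=2, stride=2, return_indices=True)
p3, i3 = F.max_pool1d(x, kernel_size=3, stride=2, return_indices=True)
assert p2.shape == p3.shape == (1, 1, 2)
# ... so from input_size=2 and stride=2 alone the unpooled length could
# be 4 or 5; only an explicit output_size disambiguates.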
-class ConvertAtenMaxUnpool3dOp final
-    : public OpConversionPattern<AtenMaxUnpool3dOp> {
-public:
-  using OpConversionPattern::OpConversionPattern;
-  LogicalResult
-  matchAndRewrite(AtenMaxUnpool3dOp op, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
+template <> struct DimensionTraits<AtenMaxUnpool2dOp> {
+  static constexpr int64_t Dim = 2;
+  // unused const variable warning suppression:
+  static_assert(Dim == Dim);
+};
+
+template <> struct DimensionTraits<AtenMaxUnpool3dOp> {
+  static constexpr int64_t Dim = 3;
+  // unused const variable warning suppression:
+  static_assert(Dim == Dim);
+};
+
+template <typename OpTy>
+class ConvertAtenMaxUnpoolOp : public OpConversionPattern<OpTy> {
+  using OpConversionPattern<OpTy>::OpConversionPattern;
+
+private:
+  static const int64_t Dim = DimensionTraits<OpTy>::Dim;
+
+  LogicalResult createUnpoolOp(OpTy &op, typename OpTy::Adaptor adaptor,
+                               ConversionPatternRewriter &rewriter) const {
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
return failure();

Location loc = op->getLoc();
-    const TypeConverter *typeConverter = getTypeConverter();
+    const TypeConverter *typeConverter = this->getTypeConverter();
Value self = adaptor.getSelf();
auto selfType = cast<RankedTensorType>(self.getType());

-    size_t spatial = selfType.getRank() - 2;
-    ArrayRef<int64_t> inputSize = selfType.getShape().take_back(spatial);
+    ArrayRef<int64_t> inputSize = selfType.getShape().take_back(Dim);
if (ShapedType::isDynamicShape(inputSize))
return rewriter.notifyMatchFailure(op,
"input type must be of static shape");

Value indices = adaptor.getIndices();
auto indicesType = cast<RankedTensorType>(indices.getType());
-    if (inputSize != indicesType.getShape().take_back(spatial))
+    if (inputSize != indicesType.getShape().take_back(Dim))
return rewriter.notifyMatchFailure(op, "input/indices shape mismatch");

auto resType = typeConverter->convertType<RankedTensorType>(op.getType());
if (!resType)
return rewriter.notifyMatchFailure(op, "invalid result type");

-    ArrayRef<int64_t> inferredOutSize = resType.getShape().take_back(spatial);
+    ArrayRef<int64_t> inferredOutSize = resType.getShape().take_back(Dim);
if (ShapedType::isDynamicShape(inferredOutSize))
return rewriter.notifyMatchFailure(op,
"output type must be of static shape");
@@ -637,7 +651,7 @@ class ConvertAtenMaxUnpool3dOp final
return rewriter.notifyMatchFailure(op,
"only support constant int output");

-      if (inferredOutSize != ArrayRef(output).take_back(spatial))
+      if (inferredOutSize != ArrayRef(output).take_back(Dim))
return rewriter.notifyMatchFailure(op, "Invalid output size");
}
SmallVector<int64_t> stride;
@@ -653,12 +667,12 @@ class ConvertAtenMaxUnpool3dOp final

// TODO: add support for asymmetric padding coming from "onnx.MaxUnpool"
// (padding.size() == 6).
-    if (stride.size() != spatial || padding.size() != spatial)
+    if (stride.size() != Dim || padding.size() != Dim)
       return rewriter.notifyMatchFailure(
-          op, "stride and padding must be of size 3");
+          op, "stride and padding must be of size Dim");

int64_t outRank = resType.getRank();
-    int64_t NC = outRank - spatial;
+    int64_t NC = outRank - Dim;

for (auto &&[inDim, outDim, str, pad] :
llvm::zip_equal(inputSize, inferredOutSize, stride, padding)) {
@@ -695,7 +709,7 @@ class ConvertAtenMaxUnpool3dOp final
// (e.g. pooling_input_size=5, kernel_size=2, stride=2, output_size=2)
// pad self and indices tensors to avoid out of bounds access.
SmallVector<int64_t> expectedInputShape =
-        llvm::to_vector(resType.getShape().drop_back(spatial));
+        llvm::to_vector(resType.getShape().drop_back(Dim));
for (auto &&[str, pad, resSize] :
llvm::zip_equal(stride, padding, inferredOutSize))
expectedInputShape.emplace_back(ceilDiv(resSize, str) + pad * 2);
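In other words, each spatial dimension of self/indices is padded up to ceil(outSize / stride) + 2 * padding, an upper bound on the pooled size that could legally map onto this output. A quick numeric reading of that bound, under the same assumptions as the comment above (helper name is illustrative):

import math

def expected_input_dim(out_size: int, stride: int, padding: int) -> int:
    # Upper bound on the pooled dimension that can scatter into `out_size`
    # positions at this stride, plus symmetric padding on both ends.
    return math.ceil(out_size / stride) + 2 * padding

# pooling_input_size=5, kernel_size=2, stride=2 gives a pooled size of 2,
# but ceil(5/2) = 3, so self/indices are padded from 2 up to 3 and every
# scattered write stays in bounds.
assert expected_input_dim(5, 2, 0) == 3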
@@ -708,7 +722,7 @@ class ConvertAtenMaxUnpool3dOp final
SmallVector<int64_t> low(outRank, 0);
SmallVector<int64_t> high(NC, 0);
for (auto &&[inpSize, outSize] : llvm::zip_equal(
-                     inputSize, ArrayRef(expectedInputShape).take_back(spatial))) {
+                     inputSize, ArrayRef(expectedInputShape).take_back(Dim))) {
high.emplace_back(outSize - inpSize);
}

@@ -827,6 +841,13 @@ class ConvertAtenMaxUnpool3dOp final
rewriter.replaceOp(op, result);
return success();
}

+public:
+  LogicalResult
+  matchAndRewrite(OpTy op, typename OpTy::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    return createUnpoolOp(op, adaptor, rewriter);
+  }
};
} // namespace

@@ -1527,8 +1548,12 @@ void mlir::torch::torch_to_linalg::populatePoolingPatternsAndLegality(
patterns.add<ConvertAtenMaxPoolOp<AtenMaxPool3dWithIndicesOp>>(typeConverter,
context);

+  target.addIllegalOp<AtenMaxUnpool2dOp>();
   target.addIllegalOp<AtenMaxUnpool3dOp>();
-  patterns.add<ConvertAtenMaxUnpool3dOp>(typeConverter, context);
+  patterns.add<ConvertAtenMaxUnpoolOp<AtenMaxUnpool2dOp>>(typeConverter,
+                                                          context);
+  patterns.add<ConvertAtenMaxUnpoolOp<AtenMaxUnpool3dOp>>(typeConverter,
+                                                          context);

target.addIllegalOp<AtenAvgPool1dOp, AtenAvgPool2dOp, AtenAvgPool3dOp>();
patterns
17 changes: 12 additions & 5 deletions projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py
@@ -1056,14 +1056,17 @@ def aten〇max_pool3d_with_indices〡shape(self: List[int], kernel_size: List[in
maxpool3d = indices = _max_pool3d(self, kernel_size, stride, padding, dilation, ceil_mode)
return maxpool3d, indices

+def aten〇max_unpool2d〡shape(self: List[int], indices: List[int], output_size: List[int], stride: List[int], padding: List[int]) -> List[int]:
+    assert (len(self) == 4), "Input must be of rank 4"
+    assert (len(output_size) == 2), "output_size must have 2 elements"
+    assert (len(self) == len(indices)), "Input and indices must be of the same rank"
+    return [self[0], self[1], output_size[0], output_size[1]]
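A worked instance of the rank-4 rule, with hypothetical sizes: an input of shape (N=2, C=3, H=4, W=4) unpooled to output_size=[8, 8] keeps the batch and channel dims and takes the spatial dims verbatim:

self_shape = [2, 3, 4, 4]
output_size = [8, 8]
# [N, C] pass through; H and W come from output_size.
result = [self_shape[0], self_shape[1], output_size[0], output_size[1]]
assert result == [2, 3, 8, 8]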

def aten〇max_unpool3d〡shape(self: List[int], indices: List[int], output_size: List[int], stride: List[int], padding: List[int]) -> List[int]:
-    assert (len(self) == 5 or len(self) == 4), "Input must be of rank 4 or 5"
+    assert (len(self) == 5), "Input must be of rank 5"
     assert (len(output_size) == 3), "output_size must have 3 elements"
     assert (len(self) == len(indices)), "Input and indices must be of the same rank"
-    if len(self) == 5:
-        return [self[0], self[1], output_size[0], output_size[1], output_size[2]]
-    else:
-        return [self[0], output_size[0], output_size[1], output_size[2]]
+    return [self[0], self[1], output_size[0], output_size[1], output_size[2]]

def aten〇upsample_nearest2d_backward〡shape(grad_output: List[int], output_size: List[int], input_size: List[int], scales_h: Optional[float] = None, scales_w: Optional[float] = None) -> List[int]:
return input_size
@@ -3202,6 +3205,10 @@ def aten〇max_pool3d_with_indices〡dtype(self_rank_dtype: Tuple[int, int], ker
self_rank, self_dtype = self_rank_dtype
return self_dtype, torch.int64

+def aten〇max_unpool2d〡dtype(self_rank_dtype: Tuple[int, int], indices_rank_dtype: Tuple[int, int], output_size: List[int], stride: List[int], padding: List[int]) -> int:
+    self_rank, self_dtype = self_rank_dtype
+    return self_dtype

def aten〇max_unpool3d〡dtype(self_rank_dtype: Tuple[int, int], indices_rank_dtype: Tuple[int, int], output_size: List[int], stride: List[int], padding: List[int]) -> int:
self_rank, self_dtype = self_rank_dtype
return self_dtype
1 change: 1 addition & 0 deletions projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
@@ -622,6 +622,7 @@ def emit_with_mutating_variants(key, **kwargs):
"aten::max_pool1d_with_indices : (Tensor, int[], int[], int[], int[], bool) -> (Tensor, Tensor)"
)
emit("aten::max_pool2d : (Tensor, int[], int[], int[], int[], bool) -> (Tensor)")
emit("aten::max_unpool2d : (Tensor, Tensor, int[], int[], int[]) -> (Tensor)")
emit(
"aten::max_pool2d_with_indices : (Tensor, int[], int[], int[], int[], bool) -> (Tensor, Tensor)",
has_canonicalizer=True,
