Support max_unpool2d lowering #3733

Draft: wants to merge 3 commits into main
10 changes: 6 additions & 4 deletions include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -7190,22 +7190,24 @@ def Torch_AtenMaxUnpool2dOp : Torch_Op<"aten.max_unpool2d", [
   HasValueSemantics,
   ReadOnly
 ]> {
-  let summary = "Generated op for `aten::max_unpool2d : (Tensor, Tensor, int[]) -> (Tensor)`";
+  let summary = "Generated op for `aten::max_unpool2d : (Tensor, Tensor, int[], int[], int[]) -> (Tensor)`";
   let arguments = (ins
     AnyTorchTensorType:$self,
     AnyTorchTensorType:$indices,
-    AnyTorchListOfTorchIntType:$output_size
+    AnyTorchListOfTorchIntType:$output_size,
+    AnyTorchListOfTorchIntType:$stride,
+    AnyTorchListOfTorchIntType:$padding
   );
   let results = (outs
     AnyTorchOptionalTensorType:$result
   );
   let hasCustomAssemblyFormat = 1;
   let extraClassDefinition = [{
     ParseResult AtenMaxUnpool2dOp::parse(OpAsmParser &parser, OperationState &result) {
-      return parseDefaultTorchOp(parser, result, 3, 1);
+      return parseDefaultTorchOp(parser, result, 5, 1);
     }
     void AtenMaxUnpool2dOp::print(OpAsmPrinter &printer) {
-      printDefaultTorchOp(printer, *this, 3, 1);
+      printDefaultTorchOp(printer, *this, 5, 1);
     }
   }];
 }
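The hunk above extends the generated `aten.max_unpool2d` op with `stride` and `padding` operands, matching the existing `aten.max_unpool3d` registration. As a reminder of the semantics those operands describe, here is a minimal plain-PyTorch sketch (standard `torch.nn` API, not part of this patch):

```python
import torch

# max_unpool2d scatters each pooled maximum back to the position recorded in
# `indices`; kernel/stride/padding determine the geometry of the restored map.
pool = torch.nn.MaxPool2d(kernel_size=2, stride=2, padding=0, return_indices=True)
unpool = torch.nn.MaxUnpool2d(kernel_size=2, stride=2, padding=0)

x = torch.arange(16.0).reshape(1, 1, 4, 4)
pooled, indices = pool(x)           # 1x1x2x2 maxima and their flat indices
restored = unpool(pooled, indices)  # 1x1x4x4, zero everywhere except the maxima
```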
11 changes: 6 additions & 5 deletions lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp
@@ -3456,11 +3456,6 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
         SmallVector<int64_t> resultShape(resultType.getSizes());
         Value resultShapeList =
             createConstantIntList(binder, rewriter, resultShape);
-        if (rank == 4) {
-          rewriter.replaceOpWithNewOp<Torch::AtenMaxUnpool2dOp>(
-              binder.op, resultType, data, indices, resultShapeList);
-          return success();
-        }
 
         SmallVector<int64_t> padding, strides;
         if (binder.s64IntegerArrayAttr(padding, "pads", {}))
@@ -3495,6 +3490,12 @@ Value paddingList = createConstantIntList(binder, rewriter, padding);
         Value paddingList = createConstantIntList(binder, rewriter, padding);
         Value stridesList = createConstantIntList(binder, rewriter, strides);
 
+        if (rank == 4) {
+          rewriter.replaceOpWithNewOp<Torch::AtenMaxUnpool2dOp>(
+              binder.op, resultType, data, indices, resultShapeList,
+              stridesList, paddingList);
+          return success();
+        }
         rewriter.replaceOpWithNewOp<Torch::AtenMaxUnpool3dOp>(
             binder.op, resultType, data, indices, resultShapeList, stridesList,
             paddingList);
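With this reordering, the rank-4 case is emitted only after `pads` and `strides` have been parsed, so both lists can be forwarded to the new operands. For reference, when ONNX `MaxUnpool` is not given an explicit `output_shape`, each spatial extent follows the inverse of the pooling shape formula; a sketch of that relationship (illustrative helper, assuming symmetric per-dimension padding):

```python
def max_unpool_out_dim(in_dim: int, kernel: int, stride: int, pad: int) -> int:
    # Inverse of floor((out + 2*pad - kernel) / stride) + 1 == in_dim,
    # taking the exact (no remainder) case.
    return (in_dim - 1) * stride + kernel - 2 * pad

# e.g. a 2x2 feature map unpooled with kernel 2, stride 2, no padding -> 4x4
assert max_unpool_out_dim(2, kernel=2, stride=2, pad=0) == 4
```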
62 changes: 44 additions & 18 deletions lib/Conversion/TorchToLinalg/Pooling.cpp
@@ -596,36 +596,51 @@ namespace {
 // input_size=2, output_size=5 and stride=2, kernel_size can be either 2 or 3).
 // What worse, without knowing kernel size we cannot even reliably detect such
 // cases and this conversion will just return invalid values.
-class ConvertAtenMaxUnpool3dOp final
-    : public OpConversionPattern<AtenMaxUnpool3dOp> {
-public:
-  using OpConversionPattern::OpConversionPattern;
-  LogicalResult
-  matchAndRewrite(AtenMaxUnpool3dOp op, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
+
+template <> struct DimensionTraits<AtenMaxUnpool2dOp> {
+  static constexpr int64_t Dim = 2;
+  // unused const variable warning suppression:
+  static_assert(Dim == Dim);
+};
+
+template <> struct DimensionTraits<AtenMaxUnpool3dOp> {
+  static constexpr int64_t Dim = 3;
+  // unused const variable warning suppression:
+  static_assert(Dim == Dim);
+};
+
+template <typename OpTy>
+class ConvertAtenMaxUnpoolOp : public OpConversionPattern<OpTy> {
+  using OpConversionPattern<OpTy>::OpConversionPattern;
+
+private:
+  static const int64_t Dim = DimensionTraits<OpTy>::Dim;
+
+  LogicalResult createUnpoolOp(OpTy &op, typename OpTy::Adaptor adaptor,
+                               ConversionPatternRewriter &rewriter) const {
     if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
       return failure();
 
     Location loc = op->getLoc();
-    const TypeConverter *typeConverter = getTypeConverter();
+    const TypeConverter *typeConverter = this->getTypeConverter();
     Value self = adaptor.getSelf();
     auto selfType = cast<RankedTensorType>(self.getType());
 
-    ArrayRef<int64_t> inputSize = selfType.getShape().take_back(3);
+    ArrayRef<int64_t> inputSize = selfType.getShape().take_back(Dim);
     if (ShapedType::isDynamicShape(inputSize))
       return rewriter.notifyMatchFailure(op,
                                          "input type must be of static shape");
 
     Value indices = adaptor.getIndices();
     auto indicesType = cast<RankedTensorType>(indices.getType());
-    if (inputSize != indicesType.getShape().take_back(3))
+    if (inputSize != indicesType.getShape().take_back(Dim))
       return rewriter.notifyMatchFailure(op, "input/indices shape mismatch");
 
     auto resType = typeConverter->convertType<RankedTensorType>(op.getType());
     if (!resType)
       return rewriter.notifyMatchFailure(op, "invalid result type");
 
-    ArrayRef<int64_t> inferredOutSize = resType.getShape().take_back(3);
+    ArrayRef<int64_t> inferredOutSize = resType.getShape().take_back(Dim);
     if (ShapedType::isDynamicShape(inferredOutSize))
       return rewriter.notifyMatchFailure(op,
                                          "output type must be of static shape");
@@ -636,7 +651,7 @@ class ConvertAtenMaxUnpool3dOp final
       return rewriter.notifyMatchFailure(op,
                                          "only support constant int output");
 
-    if (inferredOutSize != ArrayRef(output))
+    if (inferredOutSize != ArrayRef(output).take_back(Dim))
       return rewriter.notifyMatchFailure(op, "Invalid output size");
   }
   SmallVector<int64_t> stride;
@@ -652,12 +667,12 @@ class ConvertAtenMaxUnpool3dOp final
 
     // TODO: add support for asymmetric padding coming from "onnx.MaxUnpool"
     // (padding.size() == 6).
-    if (stride.size() != 3 || padding.size() != 3)
+    if (stride.size() != Dim || padding.size() != Dim)
       return rewriter.notifyMatchFailure(
-          op, "stride and padding must be of size 3");
+          op, "stride and padding must be of size Dim");
 
     int64_t outRank = resType.getRank();
-    int64_t NC = outRank - 3;
+    int64_t NC = outRank - Dim;
 
     for (auto &&[inDim, outDim, str, pad] :
          llvm::zip_equal(inputSize, inferredOutSize, stride, padding)) {
@@ -694,7 +709,7 @@ class ConvertAtenMaxUnpool3dOp final
     // (e.g. pooling_input_size=5, kernel_size=2, stride=2, output_size=2)
     // pad self and indices tensors to avoid out of bounds access.
     SmallVector<int64_t> expectedInputShape =
-        llvm::to_vector(resType.getShape().drop_back(3));
+        llvm::to_vector(resType.getShape().drop_back(Dim));
     for (auto &&[str, pad, resSize] :
          llvm::zip_equal(stride, padding, inferredOutSize))
       expectedInputShape.emplace_back(ceilDiv(resSize, str) + pad * 2);
@@ -707,7 +722,7 @@ class ConvertAtenMaxUnpool3dOp final
     SmallVector<int64_t> low(outRank, 0);
     SmallVector<int64_t> high(NC, 0);
     for (auto &&[inpSize, outSize] : llvm::zip_equal(
-             inputSize, ArrayRef(expectedInputShape).take_back(3))) {
+             inputSize, ArrayRef(expectedInputShape).take_back(Dim))) {
       high.emplace_back(outSize - inpSize);
     }
 
@@ -826,6 +841,13 @@ class ConvertAtenMaxUnpool3dOp final
     rewriter.replaceOp(op, result);
     return success();
   }
+
+public:
+  LogicalResult
+  matchAndRewrite(OpTy op, typename OpTy::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    return createUnpoolOp(op, adaptor, rewriter);
+  }
 };
 } // namespace
 
@@ -1526,8 +1548,12 @@ void mlir::torch::torch_to_linalg::populatePoolingPatternsAndLegality(
   patterns.add<ConvertAtenMaxPoolOp<AtenMaxPool3dWithIndicesOp>>(typeConverter,
                                                                  context);
 
+  target.addIllegalOp<AtenMaxUnpool2dOp>();
   target.addIllegalOp<AtenMaxUnpool3dOp>();
-  patterns.add<ConvertAtenMaxUnpool3dOp>(typeConverter, context);
+  patterns.add<ConvertAtenMaxUnpoolOp<AtenMaxUnpool2dOp>>(typeConverter,
+                                                          context);
+  patterns.add<ConvertAtenMaxUnpoolOp<AtenMaxUnpool3dOp>>(typeConverter,
+                                                          context);
 
   target.addIllegalOp<AtenAvgPool1dOp, AtenAvgPool2dOp, AtenAvgPool3dOp>();
   patterns
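The kernel-size ambiguity described in the comment at the top of this file's first hunk is easy to reproduce in plain PyTorch: both kernel sizes below pool a length-5 input down to length 2 with stride 2, so the unpool lowering cannot recover the kernel from shapes alone (illustrative check, not part of the patch):

```python
import torch
import torch.nn.functional as F

# floor((5 - 2)/2) + 1 == floor((5 - 3)/2) + 1 == 2, so kernel 2 and 3
# are indistinguishable given only input/output sizes and stride.
x = torch.arange(5.0).reshape(1, 1, 5)
for k in (2, 3):
    pooled, _ = F.max_pool1d(x, kernel_size=k, stride=2, return_indices=True)
    print(k, pooled.shape)  # both print torch.Size([1, 1, 2])
```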
65 changes: 65 additions & 0 deletions lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
@@ -8183,6 +8183,67 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    %1 = torch.prim.TupleConstruct %0, %0 : !torch.list<int>, !torch.list<int> -> !torch.tuple<list<int>, list<int>>\n"
 "    return %1 : !torch.tuple<list<int>, list<int>>\n"
 "  }\n"
+"  func.func @\"__torch_mlir_shape_fn.aten.max_unpool2d\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.list<int>, %arg4: !torch.list<int>) -> !torch.list<int> {\n"
+"    %str = torch.constant.str \"AssertionError: Input and indices must be of the same rank\"\n"
+"    %str_0 = torch.constant.str \"AssertionError: output_size must have 2 elements\"\n"
+"    %none = torch.constant.none\n"
+"    %str_1 = torch.constant.str \"AssertionError: Input be of rank 3 or 4\"\n"
+"    %true = torch.constant.bool true\n"
+"    %int4 = torch.constant.int 4\n"
+"    %int3 = torch.constant.int 3\n"
+"    %int2 = torch.constant.int 2\n"
+"    %int0 = torch.constant.int 0\n"
+"    %int1 = torch.constant.int 1\n"
+"    %0 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
+"    %1 = torch.aten.eq.int %0, %int4 : !torch.int, !torch.int -> !torch.bool\n"
+"    %2 = torch.prim.If %1 -> (!torch.bool) {\n"
+"      torch.prim.If.yield %true : !torch.bool\n"
+"    } else {\n"
+"      %11 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
+"      %12 = torch.aten.eq.int %11, %int3 : !torch.int, !torch.int -> !torch.bool\n"
+"      torch.prim.If.yield %12 : !torch.bool\n"
+"    }\n"
+"    torch.prim.If %2 -> () {\n"
+"      torch.prim.If.yield\n"
+"    } else {\n"
+"      torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none\n"
+"      torch.prim.If.yield\n"
+"    }\n"
+"    %3 = torch.aten.len.t %arg2 : !torch.list<int> -> !torch.int\n"
+"    %4 = torch.aten.eq.int %3, %int2 : !torch.int, !torch.int -> !torch.bool\n"
+"    torch.prim.If %4 -> () {\n"
+"      torch.prim.If.yield\n"
+"    } else {\n"
+"      torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n"
+"      torch.prim.If.yield\n"
+"    }\n"
+"    %5 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
+"    %6 = torch.aten.len.t %arg1 : !torch.list<int> -> !torch.int\n"
+"    %7 = torch.aten.eq.int %5, %6 : !torch.int, !torch.int -> !torch.bool\n"
+"    torch.prim.If %7 -> () {\n"
+"      torch.prim.If.yield\n"
+"    } else {\n"
+"      torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
+"      torch.prim.If.yield\n"
+"    }\n"
+"    %8 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
+"    %9 = torch.aten.eq.int %8, %int4 : !torch.int, !torch.int -> !torch.bool\n"
+"    %10 = torch.prim.If %9 -> (!torch.list<int>) {\n"
+"      %11 = torch.aten.__getitem__.t %arg0, %int0 : !torch.list<int>, !torch.int -> !torch.int\n"
+"      %12 = torch.aten.__getitem__.t %arg0, %int1 : !torch.list<int>, !torch.int -> !torch.int\n"
+"      %13 = torch.aten.__getitem__.t %arg2, %int0 : !torch.list<int>, !torch.int -> !torch.int\n"
+"      %14 = torch.aten.__getitem__.t %arg2, %int1 : !torch.list<int>, !torch.int -> !torch.int\n"
+"      %15 = torch.prim.ListConstruct %11, %12, %13, %14 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>\n"
+"      torch.prim.If.yield %15 : !torch.list<int>\n"
+"    } else {\n"
+"      %11 = torch.aten.__getitem__.t %arg0, %int0 : !torch.list<int>, !torch.int -> !torch.int\n"
+"      %12 = torch.aten.__getitem__.t %arg2, %int0 : !torch.list<int>, !torch.int -> !torch.int\n"
+"      %13 = torch.aten.__getitem__.t %arg2, %int1 : !torch.list<int>, !torch.int -> !torch.int\n"
+"      %14 = torch.prim.ListConstruct %11, %12, %13 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>\n"
+"      torch.prim.If.yield %14 : !torch.list<int>\n"
+"    }\n"
+"    return %10 : !torch.list<int>\n"
+"  }\n"
 "  func.func @\"__torch_mlir_shape_fn.aten.max_unpool3d\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.list<int>, %arg4: !torch.list<int>) -> !torch.list<int> {\n"
 "    %str = torch.constant.str \"AssertionError: Input and indices must be of the same rank\"\n"
 "    %str_0 = torch.constant.str \"AssertionError: output_size must have 3 elements\"\n"
@@ -12133,6 +12194,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    %1 = torch.prim.TupleConstruct %0#1, %int4 : !torch.int, !torch.int -> !torch.tuple<int, int>\n"
 "    return %1 : !torch.tuple<int, int>\n"
 "  }\n"
+"  func.func @\"__torch_mlir_dtype_fn.aten.max_unpool2d\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.list<int>, %arg3: !torch.list<int>, %arg4: !torch.list<int>) -> !torch.int {\n"
+"    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
+"    return %0#1 : !torch.int\n"
+"  }\n"
 "  func.func @\"__torch_mlir_dtype_fn.aten.max_unpool3d\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.list<int>, %arg3: !torch.list<int>, %arg4: !torch.list<int>) -> !torch.int {\n"
 "    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
 "    return %0#1 : !torch.int\n"
abstract_interp_lib_gen.py
@@ -1059,6 +1059,15 @@ def aten〇max_pool3d_with_indices〡shape(self: List[int], kernel_size: List[in
     maxpool3d = indices = _max_pool3d(self, kernel_size, stride, padding, dilation, ceil_mode)
     return maxpool3d, indices
 
+def aten〇max_unpool2d〡shape(self: List[int], indices: List[int], output_size: List[int], stride: List[int], padding: List[int]) -> List[int]:
+    assert (len(self) == 4 or len(self) == 3), "Input be of rank 3 or 4"
+    assert (len(output_size) == 2), "output_size must have 2 elements"
+    assert (len(self) == len(indices)), "Input and indices must be of the same rank"
+    if len(self) == 4:
+        return [self[0], self[1], output_size[0], output_size[1]]
+    else:
+        return [self[0], output_size[0], output_size[1]]
+
 def aten〇max_unpool3d〡shape(self: List[int], indices: List[int], output_size: List[int], stride: List[int], padding: List[int]) -> List[int]:
     assert (len(self) == 5 or len(self) == 4), "Input be of rank 4 or 5"
     assert (len(output_size) == 3), "output_size must have 3 elements"
@@ -3205,6 +3214,10 @@ def aten〇max_pool3d_with_indices〡dtype(self_rank_dtype: Tuple[int, int], ker
     self_rank, self_dtype = self_rank_dtype
     return self_dtype, torch.int64
 
+def aten〇max_unpool2d〡dtype(self_rank_dtype: Tuple[int, int], indices_rank_dtype: Tuple[int, int], output_size: List[int], stride: List[int], padding: List[int]) -> int:
+    self_rank, self_dtype = self_rank_dtype
+    return self_dtype
+
 def aten〇max_unpool3d〡dtype(self_rank_dtype: Tuple[int, int], indices_rank_dtype: Tuple[int, int], output_size: List[int], stride: List[int], padding: List[int]) -> int:
     self_rank, self_dtype = self_rank_dtype
     return self_dtype
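The `AbstractInterpLibrary.cpp` hunk above is generated from these Python definitions. A quick sanity check of the new shape function's two branches, in the style of this file (illustrative only, not part of the patch):

```python
# NCHW input: batch and channel pass through, spatial dims come from output_size.
assert aten〇max_unpool2d〡shape(
    self=[2, 2, 2, 2], indices=[2, 2, 2, 2],
    output_size=[4, 4], stride=[2, 2], padding=[0, 0],
) == [2, 2, 4, 4]

# CHW (rank-3) input: only the channel dim passes through.
assert aten〇max_unpool2d〡shape(
    self=[2, 2, 2], indices=[2, 2, 2],
    output_size=[4, 4], stride=[2, 2], padding=[0, 0],
) == [2, 4, 4]
```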
torch_ods_gen.py
@@ -625,7 +625,7 @@ def emit_with_mutating_variants(key, **kwargs):
         "aten::max_pool1d_with_indices : (Tensor, int[], int[], int[], int[], bool) -> (Tensor, Tensor)"
     )
     emit("aten::max_pool2d : (Tensor, int[], int[], int[], int[], bool) -> (Tensor)")
-    emit("aten::max_unpool2d : (Tensor, Tensor, int[]) -> (Tensor)")
+    emit("aten::max_unpool2d : (Tensor, Tensor, int[], int[], int[]) -> (Tensor)")
     emit(
         "aten::max_pool2d_with_indices : (Tensor, int[], int[], int[], int[], bool) -> (Tensor, Tensor)",
         has_canonicalizer=True,
30 changes: 30 additions & 0 deletions projects/pt1/python/torch_mlir_e2e_test/test_suite/pooling.py
@@ -1988,6 +1988,36 @@ def AdaptiveMaxPool3dStaticWithIndices_basic(module, tu: TestUtils):
 # ==============================================================================
 
 
+class MaxUnpool2dModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([-1, -1, 2, 2], torch.float32, True),
+            ([-1, -1, 2, 2], torch.int64, True),
+        ]
+    )
+    def forward(self, x, indices):
+        return torch.ops.aten.max_unpool2d(x, indices, (4, 4), (2, 2), (0, 0))
+
+
+@register_test_case(module_factory=lambda: MaxUnpool2dModule())
+def MaxUnpool2dModule_basic(module, tu: TestUtils):
+    input = tu.rand(2, 2, 4, 4)
+    pool = torch.nn.MaxPool2d(
+        kernel_size=(2, 2), stride=(2, 2), padding=(0, 0), return_indices=True
+    )
+    output, indices = pool(input)
+
+    module.forward(output, indices)
+
+
+# ==============================================================================
+
+
 class MaxUnpool3dModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
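What the new e2e test exercises, reduced to plain PyTorch (the values below follow directly from the scatter semantics):

```python
import torch

x = torch.tensor([[[[1.0, 2.0], [3.0, 4.0]]]])  # 1x1x2x2
pool = torch.nn.MaxPool2d(kernel_size=2, stride=2, return_indices=True)
pooled, indices = pool(x)  # pooled == [[[[4.]]]], recorded flat index 3
unpool = torch.nn.MaxUnpool2d(kernel_size=2, stride=2)
print(unpool(pooled, indices))
# tensor([[[[0., 0.],
#           [0., 4.]]]])
```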