diff --git a/CMakeLists.txt b/CMakeLists.txt
index cf33ccac1400..89a31cf1fdc0 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -163,13 +163,6 @@ if(MLIR_ENABLE_BINDINGS_PYTHON)
   mlir_configure_python_dev_packages()
 endif()
 
-add_subdirectory(include)
-add_subdirectory(lib)
-add_subdirectory(tools)
-
-add_custom_target(check-torch-mlir-all)
-add_dependencies(check-torch-mlir-all check-torch-mlir)
-
 if(MLIR_ENABLE_BINDINGS_PYTHON)
   # If parent projects want to configure where to place the python packages,
   # respect that.
@@ -178,6 +171,13 @@ if(MLIR_ENABLE_BINDINGS_PYTHON)
   endif()
 endif()
 
+add_subdirectory(include)
+add_subdirectory(lib)
+add_subdirectory(tools)
+
+add_custom_target(check-torch-mlir-all)
+add_dependencies(check-torch-mlir-all check-torch-mlir)
+
 add_subdirectory(test)
 
 if (NOT LLVM_INSTALL_TOOLCHAIN_ONLY)
diff --git a/conda-dev-env.yml b/conda-dev-env.yml
new file mode 100644
index 000000000000..ef8f10f42939
--- /dev/null
+++ b/conda-dev-env.yml
@@ -0,0 +1,21 @@
+# To create the environment, execute:
+#   conda env create -f conda-dev-env.yml
+#
+# To activate it:
+#   conda activate mlir-dev
+#
+name: mlir-dev
+channels:
+  - conda-forge
+
+dependencies:
+  - gxx_linux-64 12.*
+  - gcc_linux-64 12.*
+  - ccache
+  - sysroot_linux-64 >=2.14
+  - llvmdev 16.*
+  - cmake
+  - make
+  - llvm
+  - python ==3.11
+  - mkl
diff --git a/include/torch-mlir/Conversion/FuseLinalg/FuseLinalg.h b/include/torch-mlir/Conversion/FuseLinalg/FuseLinalg.h
new file mode 100644
index 000000000000..917d0c6542d5
--- /dev/null
+++ b/include/torch-mlir/Conversion/FuseLinalg/FuseLinalg.h
@@ -0,0 +1,16 @@
+#ifndef TORCHMLIR_FUSE_LINALG_H
+#define TORCHMLIR_FUSE_LINALG_H
+
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/Pass/Pass.h"
+#include <memory>
+
+namespace mlir {
+namespace torch {
+std::unique_ptr<OperationPass<ModuleOp>> createFuseLinalgOpsPass();
+} // namespace torch
+} // namespace mlir
+
+#endif // TORCHMLIR_FUSE_LINALG_H
diff --git a/include/torch-mlir/Conversion/LinalgToKernelCalls/LinalgToKernelCalls.h b/include/torch-mlir/Conversion/LinalgToKernelCalls/LinalgToKernelCalls.h
new file mode 100644
index 000000000000..3d0a7c1ba475
--- /dev/null
+++ b/include/torch-mlir/Conversion/LinalgToKernelCalls/LinalgToKernelCalls.h
@@ -0,0 +1,25 @@
+//===------------------------------------------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+// Also available under a BSD-style license. See LICENSE.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef TORCHMLIR_CONVERSION_LINALGTOKERNELCALLS_LINALGTOKERNELCALLS_H
+#define TORCHMLIR_CONVERSION_LINALGTOKERNELCALLS_LINALGTOKERNELCALLS_H
+
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Pass/Pass.h"
+#include <memory>
+
+namespace mlir {
+namespace torch {
+std::unique_ptr<OperationPass<ModuleOp>>
+createConvertLinalgOpsToKernelCallsPass();
+} // namespace torch
+} // namespace mlir
+
+#endif // TORCHMLIR_CONVERSION_LINALGTOKERNELCALLS_LINALGTOKERNELCALLS_H
diff --git a/include/torch-mlir/Conversion/Passes.td b/include/torch-mlir/Conversion/Passes.td
index 3a130f472b3b..714c66522c1c 100644
--- a/include/torch-mlir/Conversion/Passes.td
+++ b/include/torch-mlir/Conversion/Passes.td
@@ -154,4 +154,26 @@ def ConvertTorchToStablehlo : Pass<"convert-torch-to-stablehlo", "func::FuncOp">
 }
 #endif
 
+def ConvertLinalgOpsToKernelCalls
+    : Pass<"convert-linalg-ops-to-kernel-calls", "ModuleOp"> {
+  let summary = "Lower supported linalg matmul ops to runtime kernel calls.";
+  let constructor = [{
+    mlir::torch::createConvertLinalgOpsToKernelCallsPass()
+  }];
+  let description = [{
+    This pass replaces bufferized linalg.matmul, linalg.matmul_transpose_a and
+    linalg.matmul_transpose_b operations with calls to the TorchMLIRKernels
+    runtime library, passing the operands as unranked memrefs.
+  }];
+}
+
+def FuseLinalgOps
+    : Pass<"fuse-linalg-ops", "ModuleOp"> {
+  let summary = "Fuse linalg operations.";
+  let constructor = [{
+    mlir::torch::createFuseLinalgOpsPass()
+  }];
+  let description = [{
+    This pass fuses a linalg.transpose feeding a linalg.matmul operand into
+    the matmul itself, producing linalg.matmul_transpose_a or
+    linalg.matmul_transpose_b.
+  }];
+}
+
 #endif // TORCHMLIR_CONVERSION_PASSES
diff --git a/lib/CMakeLists.txt b/lib/CMakeLists.txt
index 8956066b8769..6095ecb5c8f5 100644
--- a/lib/CMakeLists.txt
+++ b/lib/CMakeLists.txt
@@ -1,6 +1,7 @@
 add_subdirectory(CAPI)
 add_subdirectory(Conversion)
 add_subdirectory(Dialect)
+add_subdirectory(Runtime)
 
 set(LinkedLibs
   MLIRComplexDialect
diff --git a/lib/Conversion/CMakeLists.txt b/lib/Conversion/CMakeLists.txt
index f26b4d6e895e..ed9df2bcfdfa 100644
--- a/lib/Conversion/CMakeLists.txt
+++ b/lib/Conversion/CMakeLists.txt
@@ -1,7 +1,9 @@
+add_subdirectory(LinalgToKernelCalls)
 add_subdirectory(TorchToLinalg)
 add_subdirectory(TorchToSCF)
 add_subdirectory(TorchToArith)
 add_subdirectory(TorchToTosa)
+add_subdirectory(FuseLinalg)
 if(TORCH_MLIR_ENABLE_STABLEHLO)
   add_subdirectory(TorchToStablehlo)
 endif()
@@ -10,7 +12,9 @@ add_subdirectory(TorchConversionToMLProgram)
 add_subdirectory(Utils)
 
 # TODO: Automate this with add_torch_mlir_conversion_library.
-set(linked_libs TorchMLIRTorchToLinalg
+set(linked_libs TorchMLIRLinalgToKernelCalls
+                TorchMLIRTorchToLinalg
+                FuseLinalg
                 TorchMLIRTorchToSCF
                 TorchMLIRTorchToArith
                 TorchMLIRTorchToTosa
diff --git a/lib/Conversion/FuseLinalg/CMakeLists.txt b/lib/Conversion/FuseLinalg/CMakeLists.txt
new file mode 100644
index 000000000000..94377969ff1b
--- /dev/null
+++ b/lib/Conversion/FuseLinalg/CMakeLists.txt
@@ -0,0 +1,16 @@
+add_mlir_conversion_library(FuseLinalg
+  FuseLinalg.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${PROJECT_SOURCE_DIR}/include/torch-mlir/Conversion/FuseLinalg
+
+  DEPENDS
+  TorchMLIRConversionPassIncGen
+
+  LINK_LIBS PUBLIC
+  MLIRIR
+  MLIRPass
+  MLIRLinalgDialect
+)
+
+torch_mlir_target_includes(FuseLinalg)
diff --git a/lib/Conversion/FuseLinalg/FuseLinalg.cpp b/lib/Conversion/FuseLinalg/FuseLinalg.cpp
new file mode 100644
index 000000000000..91649d8a07c7
--- /dev/null
+++ b/lib/Conversion/FuseLinalg/FuseLinalg.cpp
@@ -0,0 +1,128 @@
+#include "torch-mlir/Conversion/FuseLinalg/FuseLinalg.h"
+
+#include "../PassDetail.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/MLProgram/IR/MLProgram.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include <algorithm>
+#include <iterator>
+#include <vector>
+
+namespace {
+
+// Rewrites linalg.matmul(linalg.transpose(a), b) into
+// linalg.matmul_transpose_a(a, b), and the symmetric case for the second
+// operand, so the transpose can later be folded into the kernel call.
+class MatmulTranspose
+    : public mlir::OpRewritePattern<mlir::linalg::MatmulOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  mlir::LogicalResult
+  matchAndRewrite(mlir::linalg::MatmulOp op,
+                  mlir::PatternRewriter &rewriter) const override {
+
+    auto inputs = op.getInputs();
+    std::vector<int> transposed_input_nums;
+    transposed_input_nums.reserve(inputs.size());
+
+    mlir::linalg::TransposeOp transposeOp = nullptr;
+    std::fill_n(std::back_inserter(transposed_input_nums), inputs.size(), -1);
+    if (transposed_input_nums.size() != 2) {
+      return rewriter.notifyMatchFailure(
+          op, "expected exactly two matmul inputs");
+    }
+
+    for (size_t num_input = 0; num_input < inputs.size(); num_input++) {
+      mlir::Value input = inputs[num_input];
+      transposeOp =
+          llvm::dyn_cast<mlir::linalg::TransposeOp>(input.getDefiningOp());
+      if (!transposeOp) {
+        continue;
+      }
+      transposed_input_nums[num_input] = num_input;
+    }
+
+    const int default_inps = std::count(transposed_input_nums.cbegin(),
+                                        transposed_input_nums.cend(), -1);
+    if (default_inps != 1) {
+      return rewriter.notifyMatchFailure(
+          op, "expected exactly one transposed input");
+    }
+
+    if (!transposeOp) {
+      return rewriter.notifyMatchFailure(op, "input is not a TransposeOp");
+    }
+    /* Maybe we need this
+    if (isListPotentiallyMutated(transposeOp.getResult()))
+      return rewriter.notifyMatchFailure(
+          op, "TransposeOp result is potentially mutated");
+    */
+
+    mlir::Location loc = op.getLoc();
+
+    mlir::SmallVector<mlir::Value> fusedInputOperands, fusedOutputOperands;
+    mlir::SmallVector<mlir::Type> fusedResultTypes;
+    for (mlir::OpOperand &opOperand : op.getOutputsMutable()) {
+      fusedOutputOperands.push_back(opOperand.get());
+      mlir::Type resultType = opOperand.get().getType();
+      if (!mlir::isa<mlir::MemRefType>(resultType))
+        fusedResultTypes.push_back(resultType);
+    }
+
+    mlir::Value matmul;
+    if (transposed_input_nums[0] != -1) {
+      fusedInputOperands.push_back(transposeOp.getInputMutable().get());
+      fusedInputOperands.push_back(op.getInputsMutable()[1].get());
+
+      auto mm_a = rewriter.create<mlir::linalg::MatmulTransposeAOp>(
+          loc, fusedResultTypes, fusedInputOperands, fusedOutputOperands);
+      matmul = mm_a.getResult(0);
+    } else if (transposed_input_nums[1] != -1) {
+      fusedInputOperands.push_back(op.getInputsMutable()[0].get());
+      fusedInputOperands.push_back(transposeOp.getInputMutable().get());
+
+      auto mm_b = rewriter.create<mlir::linalg::MatmulTransposeBOp>(
+          loc, fusedResultTypes, fusedInputOperands, fusedOutputOperands);
+
+      matmul = mm_b.getResult(0);
+    }
+
+    rewriter.replaceUsesWithIf(
+        op.getResult(0), matmul, [&](mlir::OpOperand &use) {
+          // Only replace consumer uses.
+          return use.get().getDefiningOp() != transposeOp;
+        });
+    rewriter.eraseOp(op);
+
+    if (transposeOp.getResult().use_empty()) {
+      rewriter.eraseOp(transposeOp);
+    }
+    return mlir::success();
+  }
+};
+
+class FuseLinalgOps
+    : public mlir::torch::FuseLinalgOpsBase<FuseLinalgOps> {
+public:
+  void runOnOperation() override {
+    mlir::MLIRContext *context = &getContext();
+    mlir::RewritePatternSet patterns(context);
+
+    patterns.add<MatmulTranspose>(context);
+
+    mlir::GreedyRewriteConfig config;
+    config.useTopDownTraversal = true;
+
+    if (failed(applyPatternsAndFoldGreedily(getOperation(),
+                                            std::move(patterns), config))) {
+      mlir::emitError(getOperation()->getLoc(), "failure in Linalg fusion");
+      signalPassFailure();
+    }
+  }
+};
+} // namespace
+
+std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
+mlir::torch::createFuseLinalgOpsPass() {
+  return std::make_unique<FuseLinalgOps>();
+}
diff --git a/lib/Conversion/LinalgToKernelCalls/CMakeLists.txt b/lib/Conversion/LinalgToKernelCalls/CMakeLists.txt
new file mode 100644
index 000000000000..c54cdf10a0dd
--- /dev/null
+++ b/lib/Conversion/LinalgToKernelCalls/CMakeLists.txt
@@ -0,0 +1,18 @@
+add_mlir_conversion_library(TorchMLIRLinalgToKernelCalls
+  LinalgToKernelCalls.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${PROJECT_SOURCE_DIR}/include/torch-mlir/Conversion/LinalgToKernelCalls
+
+  DEPENDS
+  TorchMLIRConversionPassIncGen
+
+  LINK_LIBS PUBLIC
+  MLIRIR
+  MLIRPass
+  MLIRLinalgDialect
+  MLIRMathDialect
+  TorchMLIRTorchDialect
+)
+
+torch_mlir_target_includes(TorchMLIRLinalgToKernelCalls)
diff --git a/lib/Conversion/LinalgToKernelCalls/LinalgToKernelCalls.cpp b/lib/Conversion/LinalgToKernelCalls/LinalgToKernelCalls.cpp
new file mode 100644
index 000000000000..834cb7f424b9
--- /dev/null
+++ b/lib/Conversion/LinalgToKernelCalls/LinalgToKernelCalls.cpp
@@ -0,0 +1,143 @@
+//===----------------------------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+// Also available under a BSD-style license. See LICENSE.
+//
+//===----------------------------------------------------------------------===//
+
+#include "torch-mlir/Conversion/LinalgToKernelCalls/LinalgToKernelCalls.h"
+
+#include "../PassDetail.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Complex/IR/Complex.h"
+#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Math/IR/Math.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
+#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
+#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionDialect.h"
+#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h"
+#include "torch-mlir/Dialect/TorchConversion/Transforms/BackendTypeConversion.h"
+
+#include <map>
+
+using namespace mlir;
+using namespace mlir::torch;
+using namespace mlir::torch::Torch;
+
+namespace {
+LogicalResult
+convertLinalgOpsInFunc(func::FuncOp func,
+                       std::map<std::string, SmallVector<Type>> &usedKernels) {
+  OpBuilder builder(func.getBody());
+  SmallVector<Operation *> replacedOps;
+  func.walk([&](linalg::LinalgOp op) {
+    mlir::Operation *valid_op;
+    std::string fn_name;
+    if (isa<linalg::MatmulOp>(op)) {
+      valid_op = op;
+      fn_name = "matmul_kernel_";
+    } else if (isa<linalg::MatmulTransposeAOp>(op)) {
+      valid_op = op;
+      fn_name = "matmul_transpose_a_kernel_";
+    } else if (isa<linalg::MatmulTransposeBOp>(op)) {
+      valid_op = op;
+      fn_name = "matmul_transpose_b_kernel_";
+    } else {
+      return;
+    }
+    auto types = valid_op->getOperandTypes();
+    auto lhs_type = types[0];
+    auto rhs_type = types[1];
+    auto res_type = types[2];
+    auto lhs_elem_type = lhs_type.cast<MemRefType>().getElementType();
+    auto rhs_elem_type = rhs_type.cast<MemRefType>().getElementType();
+    auto res_elem_type = res_type.cast<MemRefType>().getElementType();
+
+    if (lhs_elem_type != rhs_elem_type || lhs_elem_type != res_elem_type)
+      return;
+
+    if (lhs_type.cast<MemRefType>().getMemorySpace() !=
+            rhs_type.cast<MemRefType>().getMemorySpace() ||
+        lhs_type.cast<MemRefType>().getMemorySpace() !=
+            res_type.cast<MemRefType>().getMemorySpace())
+      return;
+
+    if (!lhs_elem_type.isF16() && !lhs_elem_type.isF32() &&
+        !lhs_elem_type.isF64())
+      return;
+
+    builder.setInsertionPoint(valid_op);
+
+    auto unranked_type = UnrankedMemRefType::get(
+        lhs_elem_type, lhs_type.cast<MemRefType>().getMemorySpace());
+
+    // The kernel name is suffixed with the element type, e.g. matmul_kernel_f32.
+    llvm::raw_string_ostream rss(fn_name);
+    lhs_elem_type.print(rss);
+    if (!usedKernels.count(fn_name)) {
+      usedKernels.emplace(
+          fn_name,
+          SmallVector<Type>({unranked_type, unranked_type, unranked_type}));
+    }
+
+    SmallVector<Value> unranked_ops;
+    for (OpOperand &operand : valid_op->getOpOperands()) {
+      unranked_ops.push_back(builder.create<memref::CastOp>(
+          operand.get().getLoc(), unranked_type, operand.get()));
+    }
+    builder.create<func::CallOp>(valid_op->getLoc(), fn_name,
+                                 valid_op->getResultTypes(), unranked_ops);
+    replacedOps.push_back(valid_op);
+  });
+
+  for (Operation *op : replacedOps) {
+    op->erase();
+  }
+
+  return success();
+}
+} // namespace
+
+namespace {
+class ConvertLinalgOpsToKernelCalls
+    : public ConvertLinalgOpsToKernelCallsBase<ConvertLinalgOpsToKernelCalls> {
+public:
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<arith::ArithDialect>();
+    registry.insert<complex::ComplexDialect>();
+    registry.insert<cf::ControlFlowDialect>();
+    registry.insert<func::FuncDialect>();
+    registry.insert<linalg::LinalgDialect>();
+    registry.insert<math::MathDialect>();
+    registry.insert<memref::MemRefDialect>();
+  }
+
+  void runOnOperation() override {
+    auto module = getOperation();
+    OpBuilder b(module.getBodyRegion());
+    std::map<std::string, SmallVector<Type>> usedKernels;
+    for (auto func : module.getOps<func::FuncOp>()) {
+      if (failed(convertLinalgOpsInFunc(func, usedKernels)))
+        return signalPassFailure();
+    }
+
+    // Create FuncOp for each used kernel function.
+    for (auto &p : usedKernels) {
+      auto kernelFunc = b.create<func::FuncOp>(
+          module.getLoc(), p.first,
+          FunctionType::get(module.getContext(), p.second, {}));
+      kernelFunc.setPrivate();
+    }
+  }
+};
+} // namespace
+
+std::unique_ptr<OperationPass<ModuleOp>>
+mlir::torch::createConvertLinalgOpsToKernelCallsPass() {
+  return std::make_unique<ConvertLinalgOpsToKernelCalls>();
+}
diff --git a/lib/Conversion/Passes.cpp b/lib/Conversion/Passes.cpp
index 0dae24678a4b..3e9cf211424f 100644
--- a/lib/Conversion/Passes.cpp
+++ b/lib/Conversion/Passes.cpp
@@ -13,12 +13,14 @@
 #include "torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h"
 #endif // TORCH_MLIR_ENABLE_STABLEHLO
 
+#include "torch-mlir/Conversion/LinalgToKernelCalls/LinalgToKernelCalls.h"
 #include "torch-mlir/Conversion/TorchToLinalg/TorchToLinalg.h"
 #include "torch-mlir/Conversion/TorchToSCF/TorchToSCF.h"
 #include "torch-mlir/Conversion/TorchToArith/TorchToArith.h"
 #include "torch-mlir/Conversion/TorchToTosa/TorchToTosa.h"
 #include "torch-mlir/Conversion/TorchToTMTensor/TorchToTMTensor.h"
 #include "torch-mlir/Conversion/TorchConversionToMLProgram/TorchConversionToMLProgram.h"
+#include "torch-mlir/Conversion/FuseLinalg/FuseLinalg.h"
 
 //===----------------------------------------------------------------------===//
 // Pass registration
diff --git a/lib/Conversion/TorchToLinalg/DataMovement.cpp b/lib/Conversion/TorchToLinalg/DataMovement.cpp
index 26c7e1d01cf1..37d74ba6dc8f 100644
--- a/lib/Conversion/TorchToLinalg/DataMovement.cpp
+++ b/lib/Conversion/TorchToLinalg/DataMovement.cpp
@@ -21,6 +21,7 @@
 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/IR/Matchers.h"
 #include "torch-mlir/Conversion/Utils/Utils.h"
 #include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
@@ -362,7 +363,7 @@ class ConvertAtenViewOp : public OpConversionPattern<AtenViewOp> {
     auto [inputShape, outputShape] =
         getInputAndOutputShape(op.getSelf(), outputSizeTorchInt);
-
+
     // Currently, we only handle the cases where each dimension is either
     // being expanded or collapsed. We do not handle cases where it's neither
     // collapsing nor expanding like view of [2,3] for 3x2 tensor.
@@ -380,8 +381,8 @@ class ConvertAtenViewOp : public OpConversionPattern<AtenViewOp> {
     bool inputHasOneDynDim = llvm::count(inputShape, kUnknownSize) == 1;
     bool outputHasOneDynDim = llvm::count(outputShape, kUnknownSize) == 1;
     bool singleDynDimsAreEqual =
-        inputHasOneDynDim && outputHasOneDynDim &&
-        productReduce(inputShape) == productReduce(outputShape);
+        inputHasOneDynDim && outputHasOneDynDim &&
+        productReduce(inputShape) == productReduce(outputShape);
 
     SmallVector<SmallVector<int64_t>> unchangedDims;
     for (auto [outputDim, outputDimSize] :
          llvm::enumerate(outputSizeTorchInt)) {
@@ -857,39 +858,21 @@ class ConvertAtenTransposeIntOp
     auto loc = op.getLoc();
 
     SmallVector<Value> outputDims;
-    for (auto i = 0; i < inputRank; i++)
+    for (auto i = 0; i < inputRank; i++) {
       outputDims.push_back(getDimOp(rewriter, loc, adaptor.getSelf(), i));
+    }
     std::swap(outputDims[dim0], outputDims[dim1]);
 
     Value outVector = rewriter.create<tensor::EmptyOp>(
         loc, getAsOpFoldResult(outputDims), elementType);
-    SmallVector<AffineExpr> idExprs;
-    SmallVector<AffineExpr> swapExprs;
-    for (auto i = 0; i < inputRank; i++)
-      idExprs.push_back(getAffineDimExpr(i, rewriter.getContext()));
-    for (auto i = 0; i < inputRank; i++) {
-      if (i == dim0)
-        swapExprs.push_back(idExprs[dim1]);
-      else if (i == dim1)
-        swapExprs.push_back(idExprs[dim0]);
-      else
-        swapExprs.push_back(idExprs[i]);
-    }
-    SmallVector<AffineMap> indexingMaps = {
-        AffineMap::get(inputRank, 0, idExprs, op.getContext()),
-        AffineMap::get(inputRank, 0, swapExprs, op.getContext())};
-    SmallVector<utils::IteratorType> iteratorTypes(
-        inputRank, utils::IteratorType::parallel);
-    auto transpose = rewriter
-                         .create<linalg::GenericOp>(
-                             loc, outVector.getType(), inVector, outVector,
-                             indexingMaps, iteratorTypes,
-                             [](OpBuilder &b, Location loc, ValueRange args) {
-                               b.create<linalg::YieldOp>(loc, args[0]);
-                             })
-                         .getResult(0);
-    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, outType, transpose);
+    // Build the full permutation: identity with dim0 and dim1 swapped.
+    SmallVector<int64_t> permutation;
+    for (auto i = 0; i < inputRank; i++)
+      permutation.push_back(i);
+    std::swap(permutation[dim0], permutation[dim1]);
+
+    auto transpose = rewriter.create<linalg::TransposeOp>(
+        loc, inVector, outVector, permutation);
+
+    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, outType,
+                                                transpose.getResult()[0]);
     return success();
   }
 };
@@ -1154,7 +1137,8 @@ class ConvertAtenContiguousOp : public OpConversionPattern<AtenContiguousOp> {
       return failure();
 
     Type resultType = getTypeConverter()->convertType(op.getType());
-    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, adaptor.getSelf());
+    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType,
+                                                adaptor.getSelf());
     return success();
   }
 };
@@ -1407,7 +1391,8 @@ class ConvertAtenViewAsRealOp : public OpConversionPattern<AtenViewAsRealOp> {
 
     SmallVector<AffineMap> indexingMaps{inputMap, outputMap};
 
-    SmallVector<utils::IteratorType> iteratorTypes(resultType.getRank(), utils::IteratorType::parallel);
+    SmallVector<utils::IteratorType> iteratorTypes(
+        resultType.getRank(), utils::IteratorType::parallel);
 
     Value constantZero =
         getConstant(rewriter, loc, 0, mlir::IndexType::get(context));
@@ -1417,7 +1402,6 @@ class ConvertAtenViewAsRealOp : public OpConversionPattern<AtenViewAsRealOp> {
         loc, outTensor.getType(), input, outTensor, indexingMaps,
         iteratorTypes,
         [&](OpBuilder &b, Location loc, ValueRange args) {
-
           Value realVal =
               b.create<complex::ReOp>(loc, elementType, args[0]);
           Value imagVal =
diff --git a/lib/Runtime/CMakeLists.txt b/lib/Runtime/CMakeLists.txt
new file mode 100644
index 000000000000..171bcbcd880c
--- /dev/null
+++ b/lib/Runtime/CMakeLists.txt
@@ -0,0 +1,11 @@
+set(MKL_LINK "static")
+find_package(MKL CONFIG REQUIRED)
+
+add_library(TorchMLIRKernels SHARED Kernels.cpp)
+
+target_compile_options(TorchMLIRKernels PUBLIC $<TARGET_PROPERTY:MKL::MKL,INTERFACE_COMPILE_OPTIONS>)
+target_include_directories(TorchMLIRKernels PUBLIC $<TARGET_PROPERTY:MKL::MKL,INTERFACE_INCLUDE_DIRECTORIES>)
+target_link_libraries(TorchMLIRKernels PRIVATE $<LINK_ONLY:MKL::MKL>)
+
+set_target_properties(TorchMLIRKernels PROPERTIES
"${TORCH_MLIR_PYTHON_PACKAGES_DIR}/torch_mlir/torch_mlir/_mlir_libs") diff --git a/lib/Runtime/Kernels.cpp b/lib/Runtime/Kernels.cpp new file mode 100644 index 000000000000..26b5a7595079 --- /dev/null +++ b/lib/Runtime/Kernels.cpp @@ -0,0 +1,120 @@ +#include +#include + +#include "mkl.h" + +template struct tensor { + T *buffer; // Allocated memory, used for deallocation only + T *data; // Aligned data. + int64_t offset; // Offset (in elems) to the first element + int64_t rank[2]; // Shape in elems + int64_t stride[2]; // Strides in elems + + T &ref(int64_t i, int64_t j) { + return *(data + i * stride[0] + j * stride[1]); + } +}; + +namespace { +template +void trace_matmul_call(int64_t lhs_rank, tensor *lhs, int64_t rhs_rank, + tensor *rhs, int64_t res_rank, tensor *res) { + std::cout << "test_matmul" << std::endl; + std::cout << " lhs_rank=" << lhs_rank << std::endl + << " rhs_rank=" << rhs_rank << std::endl + << " res_rank=" << res_rank << std::endl; + std::cout << " lhs={" << lhs->buffer << ", " << lhs->data << ", " + << lhs->offset << ", [" << lhs->rank[0] << ", " << lhs->rank[1] + << "], [" << lhs->stride[0] << ", " << lhs->stride[1] << "]}" + << std::endl; + std::cout << " rhs={" << rhs->buffer << ", " << rhs->data << ", " + << rhs->offset << ", [" << rhs->rank[0] << ", " << rhs->rank[1] + << "], [" << rhs->stride[0] << ", " << rhs->stride[1] << "]}" + << std::endl; + std::cout << " res={" << res->buffer << ", " << res->data << ", " + << res->offset << ", [" << res->rank[0] << ", " << res->rank[1] + << "], [" << res->stride[0] << ", " << res->stride[1] << "]}" + << std::endl; +} + +template +void test_matmul(int64_t lhs_rank, tensor *lhs, int64_t rhs_rank, + tensor *rhs, int64_t res_rank, tensor *res) { + trace_matmul_call(lhs_rank, lhs, rhs_rank, rhs, res_rank, res); + std::cout << "Using naive implementation." << std::endl; + for (int64_t i = 0; i < lhs->rank[0]; ++i) { + for (int64_t j = 0; j < rhs->rank[1]; ++j) { + for (int64_t k = 0; k < rhs->rank[0]; ++k) { + res->ref(i, j) += lhs->ref(i, k) * rhs->ref(k, j); + } + } + } +} +} // namespace + +// extern "C" void test_matmul_f16(int64_t lhs_rank, tensor<_Float16> *lhs, +// int64_t rhs_rank, tensor<_Float16> *rhs, +// int64_t res_rank, tensor<_Float16> *res) { +// test_matmul(lhs_rank, lhs, rhs_rank, rhs, res_rank, res); +// } + +extern "C" void matmul_kernel_f32(int64_t lhs_rank, tensor *lhs, + int64_t rhs_rank, tensor *rhs, + int64_t res_rank, tensor *res) { + // trace_matmul_call(lhs_rank, lhs, rhs_rank, rhs, res_rank, res); + // std::cout << "Using MKL implementation." << std::endl; + float alpha = 1.0; + float beta = 1.0; + auto m = lhs->rank[0]; + auto k = lhs->rank[1]; + auto n = rhs->rank[1]; + cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, m, n, k, alpha, + lhs->data, k, rhs->data, n, beta, res->data, n); +} + +extern "C" void +matmul_transpose_b_kernel_f32(int64_t lhs_rank, tensor *lhs, + int64_t rhs_rank, tensor *rhs, + int64_t res_rank, tensor *res) { + // trace_matmul_call(lhs_rank, lhs, rhs_rank, rhs, res_rank, res); + float alpha = 1.0; + float beta = 1.0; + auto m = lhs->rank[0]; + auto k = lhs->rank[1]; + auto n = rhs->rank[0]; + // std::cout << "Using MKL implementation. 
+  //           << ", k(cols op(A)): " << k << ", n(cols op(B) = cols C): " << n
+  //           << "]" << std::endl;
+  cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, m, n, k, alpha,
+              lhs->data, k, rhs->data, k, beta, res->data, n);
+}
+
+extern "C" void
+matmul_transpose_a_kernel_f32(int64_t lhs_rank, tensor<float> *lhs,
+                              int64_t rhs_rank, tensor<float> *rhs,
+                              int64_t res_rank, tensor<float> *res) {
+  // trace_matmul_call(lhs_rank, lhs, rhs_rank, rhs, res_rank, res);
+  std::cout << "[Matmul transpose a] Untested MKL implementation."
+            << std::endl;
+  float alpha = 1.0;
+  float beta = 1.0;
+  auto m = lhs->rank[1];
+  auto k = lhs->rank[0];
+  auto n = rhs->rank[1];
+  // Row-major leading dimensions: lhs is stored k x m (lda = m), rhs is
+  // stored k x n (ldb = n), matching the convention in the kernels above.
+  cblas_sgemm(CblasRowMajor, CblasTrans, CblasNoTrans, m, n, k, alpha,
+              lhs->data, m, rhs->data, n, beta, res->data, n);
+}
+
+extern "C" void matmul_kernel_f64(int64_t lhs_rank, tensor<double> *lhs,
+                                  int64_t rhs_rank, tensor<double> *rhs,
+                                  int64_t res_rank, tensor<double> *res) {
+  double alpha = 1.0;
+  double beta = 1.0;
+  auto m = lhs->rank[0];
+  auto k = lhs->rank[1];
+  auto n = rhs->rank[1];
+  cblas_dgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, m, n, k, alpha,
+              lhs->data, k, rhs->data, n, beta, res->data, n);
+}
diff --git a/projects/pt1/e2e_testing/main.py b/projects/pt1/e2e_testing/main.py
index 885f344778f5..c37772d6a531 100644
--- a/projects/pt1/e2e_testing/main.py
+++ b/projects/pt1/e2e_testing/main.py
@@ -7,7 +7,7 @@
 import re
 import sys
 
-from torch_mlir_e2e_test.framework import run_tests
+from torch_mlir_e2e_test.framework import run_tests, TestOptions
 from torch_mlir_e2e_test.reporting import report_results
 from torch_mlir_e2e_test.registry import GLOBAL_TEST_REGISTRY
 
@@ -24,6 +24,7 @@
 )
 
 from torch_mlir_e2e_test.linalg_on_tensors_backends.refbackend import RefBackendLinalgOnTensorsBackend
+from torch_mlir_e2e_test.linalg_on_tensors_backends.cpuprotobackend import CpuProtoLinalgOnTensorsBackend
 from torch_mlir_e2e_test.tosa_backends.linalg_on_tensors import LinalgOnTensorsTosaBackend
 
 from .xfail_sets import (
@@ -43,7 +44,7 @@
 register_all_tests()
 
 def _get_argparse():
-    config_choices = ["native_torch", "torchscript", "linalg", "make_fx_tosa", "tosa", "lazy_tensor_core", "torchdynamo"]
+    config_choices = ["native_torch", "torchscript", "linalg", "make_fx_tosa", "tosa", "lazy_tensor_core", "torchdynamo", "cpuproto"]
     parser = argparse.ArgumentParser(description="Run torchscript e2e tests.")
     parser.add_argument("-c", "--config",
         choices=config_choices,
@@ -60,6 +61,10 @@ def _get_argparse():
    parser.add_argument("-f", "--filter", default=".*", help="""
Regular expression specifying which tests to include in this run.
""") + parser.add_argument("-l", "--list_tests", + default=False, + action="store_true", + help="List all available tests and exit.") parser.add_argument("-v", "--verbose", default=False, action="store_true", @@ -77,10 +82,34 @@ def _get_argparse(): default=False, action="store_true", help="return exit code 0 even if the test fails to unblock pipeline") + parser.add_argument("--dump", + choices=TestOptions.dump_choices, + default=[], + action="append", + help=f""" +Available options: +"all": enable all dumps +"fx-graph": dump input FX Graph +"torch-mlir": dump generated Torch MLIR module +"linalg-mlir": dump module lowered to linalg dialect +"llvm-mlir": dump module lowered to LLVM dialect +"torch-mlir-lowering": dump after-pass results in Torch to Linalg pipeline +"linalg-mlir-lowering": dump after-pass results in Linalg to LLVM pipeline +"obj": dump compiled code to object file +""") + parser.add_argument("--use-kernels", + default=False, + action="store_true", + help="Enable linalg ops replacement with runtime library kernel calls.") + parser.add_argument("--enable-timer", + default=False, + action="store_true", + help="Enable debug timings collection.") return parser def main(): args = _get_argparse().parse_args() + opts = TestOptions(dumps=args.dump, use_kernels=args.use_kernels, debug_timer=args.enable_timer) all_test_unique_names = set( test.unique_name for test in GLOBAL_TEST_REGISTRY) @@ -111,12 +140,22 @@ def main(): xfail_set = LTC_XFAIL_SET crashing_set = LTC_CRASHING_SET elif args.config == "torchdynamo": - config = TorchDynamoTestConfig(RefBackendLinalgOnTensorsBackend()) + config = TorchDynamoTestConfig(RefBackendLinalgOnTensorsBackend(), opts=opts) + xfail_set = TORCHDYNAMO_XFAIL_SET + crashing_set = TORCHDYNAMO_CRASHING_SET + elif args.config == "cpuproto": + config = TorchDynamoTestConfig(CpuProtoLinalgOnTensorsBackend(opts), opts=opts) xfail_set = TORCHDYNAMO_XFAIL_SET crashing_set = TORCHDYNAMO_CRASHING_SET do_not_attempt = set(args.crashing_tests_to_not_attempt_to_run_and_a_bug_is_filed or []).union(crashing_set) available_tests = [test for test in GLOBAL_TEST_REGISTRY if test.unique_name not in do_not_attempt] + + if args.list_tests is True: + for test in available_tests: + print(test.unique_name) + sys.exit(0) + if args.crashing_tests_to_not_attempt_to_run_and_a_bug_is_filed is not None: for arg in args.crashing_tests_to_not_attempt_to_run_and_a_bug_is_filed: if arg not in all_test_unique_names: diff --git a/projects/pt1/examples/mlp.py b/projects/pt1/examples/mlp.py new file mode 100644 index 000000000000..e8c5c2c33cb6 --- /dev/null +++ b/projects/pt1/examples/mlp.py @@ -0,0 +1,40 @@ +import torch +import torch_mlir +from torch_mlir_e2e_test.framework import TraceItem +from torch_mlir_e2e_test.configs.torchdynamo import TorchDynamoTestConfig +from torch_mlir_e2e_test.linalg_on_tensors_backends.refbackend import ( + RefBackendLinalgOnTensorsBackend, +) + + +class MLP(torch.nn.Module): + def __init__(self, input_dim, output_dim): + super().__init__() + self.flatten = torch.nn.Flatten() + self.linear1 = torch.nn.Linear(input_dim, input_dim // 2) + self.relu = torch.nn.ReLU() + self.linear2 = torch.nn.Linear(input_dim // 2, output_dim) + + def forward(self, x): + x = self.flatten(x) + x = self.linear1(x) + x = self.relu(x) + x = self.linear2(x) + return x + + +def model_factory(): + return MLP(16, 2) + + +torch.set_default_dtype(torch.float16) +model = model_factory() +test_input = torch.rand(2, 4, 4) + +ref_res = [TraceItem(symbol="mlp", inputs=[test_input], 
output=model(test_input))] + +config = TorchDynamoTestConfig(RefBackendLinalgOnTensorsBackend()) +exp_res = config.run(model, ref_res) + +print(ref_res[0].output) +print(exp_res[0].output) diff --git a/projects/pt1/python/torch_mlir/__init__.py b/projects/pt1/python/torch_mlir/__init__.py index 8de6cc1a14fe..932c812dd4df 100644 --- a/projects/pt1/python/torch_mlir/__init__.py +++ b/projects/pt1/python/torch_mlir/__init__.py @@ -268,7 +268,7 @@ def _canon_extra_library(extra_library): return extra_library_file_name -def _lower_mlir_module(verbose, output_type, module): +def _lower_mlir_module(verbose, output_type, module, ir_dump_file = None): if verbose: print("\n====================") print("Torch Backend IR") @@ -280,7 +280,7 @@ def _lower_mlir_module(verbose, output_type, module): if output_type == OutputType.TOSA: run_pipeline_with_repro_report( module, "builtin.module(torch-backend-to-tosa-backend-pipeline)", - "Lowering Torch Backend IR -> TOSA Backend IR") + "Lowering Torch Backend IR -> TOSA Backend IR", ir_dump_file) if verbose: print("\n====================") print("TOSA Backend IR") @@ -291,7 +291,7 @@ def _lower_mlir_module(verbose, output_type, module): run_pipeline_with_repro_report( module, "builtin.module(torch-backend-to-linalg-on-tensors-backend-pipeline)", - "Lowering Torch Backend IR -> Linalg-on-Tensors Backend IR") + "Lowering Torch Backend IR -> Linalg-on-Tensors Backend IR", ir_dump_file) if verbose: print("\n====================") print("LINALG Backend IR") @@ -302,7 +302,7 @@ def _lower_mlir_module(verbose, output_type, module): run_pipeline_with_repro_report( module, "builtin.module(torch-backend-to-stablehlo-backend-pipeline)", - "Lowering Torch Backend IR -> StableHLO Backend IR") + "Lowering Torch Backend IR -> StableHLO Backend IR", ir_dump_file) if verbose: print("\n====================") print("StableHLO Backend IR") diff --git a/projects/pt1/python/torch_mlir/compiler_utils.py b/projects/pt1/python/torch_mlir/compiler_utils.py index 56e250e16802..3ace9da99237 100644 --- a/projects/pt1/python/torch_mlir/compiler_utils.py +++ b/projects/pt1/python/torch_mlir/compiler_utils.py @@ -25,9 +25,23 @@ def get_module_name_for_debug_dump(module): class TorchMlirCompilerError(Exception): pass +class StderrToFile: + def __init__(self, file: str): + self._file_name = file + + def __enter__(self): + self._fd = os.open(self._file_name, os.O_WRONLY | os.O_CREAT) + self._old_stderr_fd = os.dup(2) + os.dup2(self._fd, 2) + + def __exit__(self, *args): + os.dup2(self._old_stderr_fd, 2) + os.close(self._fd) + def run_pipeline_with_repro_report(module, pipeline: str, - description: str): + description: str, + ir_dump_file: str = None): """Runs `pipeline` on `module`, with a nice repro report if it fails.""" module_name = get_module_name_for_debug_dump(module) try: @@ -38,7 +52,14 @@ def run_pipeline_with_repro_report(module, # Lower module in place to make it ready for compiler backends. with module.context: pm = PassManager.parse(pipeline) - pm.run(module.operation) + if ir_dump_file is not None: + module.context.enable_multithreading(False) + pm.enable_ir_printing() + with StderrToFile(ir_dump_file): + pm.run(module.operation) + module.context.enable_multithreading(True) + else: + pm.run(module.operation) except Exception as e: # TODO: More robust. # - don't arbitrarily clutter up /tmp. 
When a test suite has many diff --git a/projects/pt1/python/torch_mlir_e2e_test/configs/torchdynamo.py b/projects/pt1/python/torch_mlir_e2e_test/configs/torchdynamo.py index c53227acf36a..78c41d40db9c 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/configs/torchdynamo.py +++ b/projects/pt1/python/torch_mlir_e2e_test/configs/torchdynamo.py @@ -29,8 +29,16 @@ recursively_convert_to_numpy, recursively_convert_from_numpy, ) -from torch_mlir_e2e_test.framework import TestConfig, Trace, TraceItem +from torch_mlir_e2e_test.framework import TestConfig, TestOptions, Trace, TraceItem, DebugTimer +DUMPS_ENABLED = True + +def _dump_repr_to_file(representation, filename: str): + if not DUMPS_ENABLED: + return + + with open(filename, 'w') as f: + f.write(str(representation)) def refine_result_type(_result): if isinstance(_result, tuple): @@ -57,6 +65,8 @@ def _returns_empty_tuple(fx_graph: torch.fx.GraphModule) -> bool: def jit( model: torch.nn.Module, example_args: _example_args, + symbol: str, + opts: TestOptions, output_type: Union[str, "OutputType"] = OutputType.TORCH, backend_legal_ops: Optional[Sequence[str]] = None, extra_library=None, @@ -87,9 +97,18 @@ def my_aot_autograd_backend(gm: torch.fx.GraphModule, # way of differentiating between the two. assert not _returns_empty_tuple(gm), "encountered graph that does not return anything" + if opts.is_dump_enabled("fx-graph"): + with open(f"{model._get_name()}.{symbol}-fx-graph.txt", "w") as f: + print(gm.graph, file=f) + nonlocal mlir_module *_, model_name, nth_graph = get_aot_compilation_context() mlir_module = import_fx_graph_as_func(gm.graph, model_name) + + if opts.is_dump_enabled("torch-mlir"): + with open(f"{model._get_name()}.{symbol}-torch.mlir", "w") as f: + print(mlir_module, file=f) + return gm my_backend = aot_autograd(fw_compiler=my_aot_autograd_backend, @@ -105,6 +124,7 @@ def my_aot_autograd_backend(gm: torch.fx.GraphModule, option_string = ("{backend-legal-ops=" + ",".join(backend_legal_ops) + " extra-library=" + extra_library_file_name + "}") assert mlir_module is not None + _dump_repr_to_file(mlir_module, 'forward.mlir') run_pipeline_with_repro_report( mlir_module, # f"builtin.module(torch-function-to-torch-backend-pipeline{option_string})", @@ -112,41 +132,70 @@ def my_aot_autograd_backend(gm: torch.fx.GraphModule, "Lowering TorchFX IR -> Torch Backend IR", ) - return _lower_mlir_module(verbose, output_type, mlir_module) + ir_file = f"{model._get_name()}.{symbol}-torch-to-linanlg.txt" if opts.is_dump_enabled( + "torch-mlir-lowering") else None + return _lower_mlir_module(verbose, output_type, mlir_module, ir_file) class TorchDynamoTestConfig(TestConfig): """TestConfig that runs the torch.nn.Module with TorchDynamo""" - def __init__(self, backend): + def __init__(self, backend, opts=TestOptions()): super().__init__() self.backend = backend + self.opts = opts def compile(self, program: torch.nn.Module) -> torch.nn.Module: return program def run(self, artifact: torch.nn.Module, trace: Trace) -> Trace: result: Trace = [] - for item in trace: - module = jit(artifact, - item.inputs, - output_type="linalg-on-tensors") - module = self.backend.compile(module) - backend_module = self.backend.load(module) - params = { - **dict(artifact.named_parameters(remove_duplicate=False)), - **dict(artifact.named_buffers(remove_duplicate=False)), - } - params_flat, params_spec = pytree.tree_flatten(params) - params_flat = list(params_flat) - with torch.no_grad(): - numpy_inputs = recursively_convert_to_numpy(params_flat + - item.inputs) - outputs = 
getattr(backend_module, - artifact.__class__.__name__)(*numpy_inputs) - output = refine_result_type(outputs) - result.append( - TraceItem(symbol=item.symbol, - inputs=item.inputs, - output=output)) + timing_logger = print if self.opts.is_debug_timer_enabled() else None + with DebugTimer("TorchDynamoTestConfig.run()", logger=timing_logger): + for item in trace: + with DebugTimer("JIT", logger=timing_logger): + module = jit(artifact, + item.inputs, + item.symbol, + self.opts, + output_type="linalg-on-tensors") + + if self.opts.is_dump_enabled("linalg-mlir"): + with open(f"{artifact._get_name()}.{item.symbol}-linalg.mlir", "w") as f: + print(module, file=f) + + ir_file = f"{artifact._get_name()}.{item.symbol}-linalg-to-llvm.txt" if self.opts.is_dump_enabled( + "linalg-mlir-lowering") else None + with DebugTimer("Backend.compile()", logger=timing_logger): + module = self.backend.compile(module, ir_file) + + if self.opts.is_dump_enabled("llvm-mlir"): + with open(f"{artifact._get_name()}.{item.symbol}-llvm.mlir", "w") as f: + print(module, file=f) + + with DebugTimer("Backend.load()", logger=timing_logger): + backend_module = self.backend.load(module) + params = { + **dict(artifact.named_parameters(remove_duplicate=False)), + **dict(artifact.named_buffers(remove_duplicate=False)), + } + params_flat, params_spec = pytree.tree_flatten(params) + params_flat = list(params_flat) + with torch.no_grad(): + with DebugTimer("recursively_convert_to_numpy", logger=timing_logger): + numpy_inputs = recursively_convert_to_numpy(params_flat + + item.inputs) + outputs = getattr(backend_module, + artifact.__class__.__name__)(*numpy_inputs) + + with DebugTimer("refine_result_type", logger=timing_logger): + output = refine_result_type(outputs) + + if self.opts.is_dump_enabled("obj"): + backend_module.ee.dump_to_object_file(f"{artifact._get_name()}.{item.symbol}.o") + + result.append( + TraceItem(symbol=item.symbol, + inputs=item.inputs, + output=output)) return result diff --git a/projects/pt1/python/torch_mlir_e2e_test/framework.py b/projects/pt1/python/torch_mlir_e2e_test/framework.py index f1fbad2ec914..5d2e49468b0d 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/framework.py +++ b/projects/pt1/python/torch_mlir_e2e_test/framework.py @@ -27,6 +27,9 @@ import sys import traceback +import time +import functools + import torch import multiprocess as mp @@ -95,6 +98,80 @@ def clone_trace(trace: Trace) -> Trace: # this type. CompiledArtifact = TypeVar('CompiledArtifact') +class DebugTimer: + """Basic debug timer + Usage examples: + 1. + t = DebugTimer('MyName', logger=print) + t.start() + doStuff(...) + t.stop() + + 2. + @DebugTimer('run') + def run(...): + doStuff(...) + + 3. + with DebugTimer('withSmth'): + doStuff(...) 
+    """
+    def __init__(self, name=None, logger=print) -> None:
+        self.begin = None
+        self.elapsed = None
+        self.logger = logger
+        self.name = name
+
+    def start(self) -> None:
+        if self.begin is not None:
+            raise RuntimeError("Attempt to start a running timer.")
+        self.begin = time.perf_counter_ns()
+
+    def stop(self):
+        if self.begin is None:
+            raise RuntimeError("Attempt to stop a non-running timer.")
+        self.elapsed = time.perf_counter_ns() - self.begin
+        self.begin = None
+
+        self._report()
+        return self.elapsed
+
+    def _report(self):
+        if self.logger:
+            # perf_counter_ns() reports nanoseconds; divide by 1e6 for milliseconds.
+            rep_line = "elapsed " + "{:.4f}".format(float(self.elapsed) / 1e6) + " ms"
+            rep_line = self.name + ": " + rep_line if self.name is not None else rep_line
+            self.logger(rep_line)
+
+    def __enter__(self):
+        self.start()
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        self.stop()
+
+    def __call__(self, func):
+        @functools.wraps(func)
+        def wrapper_debug_timer(*args, **kwargs):
+            with self:
+                return func(*args, **kwargs)
+        return wrapper_debug_timer
+
+class TestOptions:
+    """Test run options."""
+
+    dump_choices = ["all", "fx-graph", "torch-mlir", "linalg-mlir", "llvm-mlir", "torch-mlir-lowering", "linalg-mlir-lowering", "obj"]
+
+    def __init__(self, *, dumps: List[str] = [], use_kernels=False, debug_timer=False):
+        self.dumps = set(dumps)
+        self.use_kernels = use_kernels
+        self.debug_timer = debug_timer
+
+    def is_dump_enabled(self, dump: str):
+        return dump in self.dumps or "all" in self.dumps
+
+    def is_debug_timer_enabled(self):
+        return self.debug_timer
+
 class TestConfig(abc.ABC):
     """The interface implemented by backends to run tests.
diff --git a/projects/pt1/python/torch_mlir_e2e_test/linalg_on_tensors_backends/cpuprotobackend.py b/projects/pt1/python/torch_mlir_e2e_test/linalg_on_tensors_backends/cpuprotobackend.py
new file mode 100644
index 000000000000..ecc355cf024c
--- /dev/null
+++ b/projects/pt1/python/torch_mlir_e2e_test/linalg_on_tensors_backends/cpuprotobackend.py
@@ -0,0 +1,140 @@
+# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+# Also available under a BSD-style license. See LICENSE.
+
+import os
+
+from torch_mlir.ir import *
+from torch_mlir.passmanager import *
+from torch_mlir.execution_engine import *
+from torch_mlir.runtime import *
+from torch_mlir.compiler_utils import run_pipeline_with_repro_report
+from torch_mlir_e2e_test.framework import TestOptions, DebugTimer
+
+from .abc import LinalgOnTensorsBackend
+from .refbackend import RefBackendInvoker
+
+__all__ = [
+    "CpuProtoLinalgOnTensorsBackend",
+]
+
+
+def _build_lowering_pipeline(opts: TestOptions):
+    passes = [
+        "fuse-linalg-ops",
+        "func.func(refback-generalize-tensor-pad)",
+        # Apply some optimizations. It would be great if MLIR had more useful
+        # optimizations that worked out of the box here.
+        # Note: When measured, this doesn't seem to actually help that much
+        # for the linalg-on-tensors backend.
+        # This is likely because if things are naturally fusable we usually already
+        # emit things in that form from the high level (e.g. single linalg-generic).
+        # Other backends are likely to benefit more.
+        "func.func(linalg-fuse-elementwise-ops)",
+        "convert-shape-to-std",
+        # Bufferize.
+ "one-shot-bufferize", + "func.func(scf-bufferize)", + "func.func(tm-tensor-bufferize)", + "func.func(empty-tensor-to-alloc-tensor)", + "func.func(linalg-bufferize)", + "func-bufferize", + "arith-bufferize", + "refback-mlprogram-bufferize", + "func.func(tensor-bufferize)", + "func.func(finalizing-bufferize)", + #"func.func(buffer-deallocation)", + "ownership-based-buffer-deallocation", + "canonicalize", + "buffer-deallocation-simplification", + "bufferization-lower-deallocations", + "cse", + "canonicalize", + # Munge to make it ExecutionEngine compatible. + # Specifically, we rewrite calling convention boundaries to be in terms + # of unranked memref, and we rewrite the return to actually be a + # callback that consumes the return (the final munged function always + # returns void at the C level -- we get the return value by providing the + # callback). + "refback-munge-calling-conventions" + ] + if opts.use_kernels: + # Introduce kernel calls for operations we want to execute using library + # kernels. + passes.append("convert-linalg-ops-to-kernel-calls") + passes.extend([ + # Insert global variable and instruction sequence for getting the next + # global seed used in stateful rng. + # Lower to LLVM + "func.func(tm-tensor-to-loops)", + "func.func(refback-munge-memref-copy)", + "func.func(convert-linalg-to-loops)", + "func.func(lower-affine)", + "convert-scf-to-cf", + "func.func(refback-expand-ops-for-llvm)", + "func.func(arith-expand)", + "func.func(convert-math-to-llvm)", + # Handle some complex mlir::math ops (e.g. atan2) + "convert-math-to-libm", + "expand-strided-metadata", + "finalize-memref-to-llvm", + "lower-affine", + "convert-bufferization-to-memref", + "finalize-memref-to-llvm", + "func.func(convert-arith-to-llvm)", + "convert-func-to-llvm", + "convert-cf-to-llvm", + "convert-complex-to-llvm", + "reconcile-unrealized-casts" + ]) + return "builtin.module(" + ",".join(passes) + ")" + + +def _find_shared_lib(name): + this_file_dir = os.path.dirname(os.path.abspath(__file__)) + lib_file_path = f"{this_file_dir}/../../torch_mlir/_mlir_libs/{name}" + if not os.path.isfile(lib_file_path): + raise RuntimeError(f"Cannot find runtime library: {lib_file_path}") + return lib_file_path + + +def _collect_shared_libs(opts: TestOptions): + shared_libs = [] + if opts.use_kernels: + shared_libs.append(_find_shared_lib("libTorchMLIRKernels.so")) + return shared_libs + + +class CpuProtoLinalgOnTensorsBackend(LinalgOnTensorsBackend): + """Main entry-point for the reference backend.""" + def __init__(self, opts: TestOptions = TestOptions()): + super().__init__() + self._opts = opts + + def compile(self, imported_module: Module, ir_file: str = None): + """Compiles an imported module, with a flat list of functions. + The module is expected to be in linalg-on-tensors + scalar code form. + TODO: More clearly define the backend contract. Generally this will + extend to support globals, lists, and other stuff. + + Args: + imported_module: The MLIR module consisting of funcs in the torch + dialect. + ir_file: If specified, use it as output file for MLIR passes dumps + Returns: + An opaque, backend specific compiled artifact object that can be + passed to `load`. 
+ """ + with DebugTimer('CpuProtoLinalgOnTensorsBackend.compile()', logger=print if self._opts.debug_timer else None): + run_pipeline_with_repro_report( + imported_module, _build_lowering_pipeline(self._opts), + "Lowering Linalg-on-Tensors IR to LLVM with RefBackend", ir_file) + return imported_module + + def load(self, module) -> RefBackendInvoker: + """Loads a compiled artifact into the runtime.""" + with DebugTimer('CpuProtoLinalgOnTensorsBackend.load()', logger=print if self._opts.debug_timer else None): + invoker = RefBackendInvoker(module, + shared_libs=_collect_shared_libs(self._opts), logger=print if self._opts.debug_timer else None) + return invoker diff --git a/projects/pt1/python/torch_mlir_e2e_test/linalg_on_tensors_backends/refbackend.py b/projects/pt1/python/torch_mlir_e2e_test/linalg_on_tensors_backends/refbackend.py index 1b9dbb0d2c51..d4168e358640 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/linalg_on_tensors_backends/refbackend.py +++ b/projects/pt1/python/torch_mlir_e2e_test/linalg_on_tensors_backends/refbackend.py @@ -15,6 +15,8 @@ from .abc import LinalgOnTensorsBackend +from torch_mlir_e2e_test.framework import DebugTimer + __all__ = [ "RefBackendLinalgOnTensorsBackend", ] @@ -80,9 +82,10 @@ def get_ctype_func(func_name): class RefBackendInvoker: - def __init__(self, module): - self.ee = ExecutionEngine(module) + def __init__(self, module, shared_libs=None, logger=None): + self.ee = ExecutionEngine(module, shared_libs=shared_libs) self.result = None + self.logger = logger return_funcs = get_return_funcs(module) @@ -105,14 +108,15 @@ def consume_return_funcs(*args): def __getattr__(self, function_name: str): def invoke(*args): - ffi_args = [] - for arg in args: - assert_arg_type_is_supported(arg.dtype) - ffi_args.append( - ctypes.pointer( - ctypes.pointer(get_unranked_memref_descriptor(arg)))) - - self.ee.invoke(function_name, *ffi_args) + with DebugTimer('refbackend.invoke() args conversion', logger=self.logger): + ffi_args = [] + for arg in args: + assert_arg_type_is_supported(arg.dtype) + ffi_args.append( + ctypes.pointer( + ctypes.pointer(get_unranked_memref_descriptor(arg)))) + with DebugTimer('ExecutionEngine.invoke()', logger=self.logger): + self.ee.invoke(function_name, *ffi_args) result = self.result assert result is not None, "Invocation didn't produce a result" self.result = None @@ -182,7 +186,7 @@ class RefBackendLinalgOnTensorsBackend(LinalgOnTensorsBackend): def __init__(self): super().__init__() - def compile(self, imported_module: Module): + def compile(self, imported_module: Module, ir_file: str = None): """Compiles an imported module, with a flat list of functions. The module is expected to be in linalg-on-tensors + scalar code form. TODO: More clearly define the backend contract. 
Generally this will @@ -198,9 +202,9 @@ def compile(self, imported_module: Module): run_pipeline_with_repro_report( imported_module, LOWERING_PIPELINE, - "Lowering Linalg-on-Tensors IR to LLVM with RefBackend") + "Lowering Linalg-on-Tensors IR to LLVM with RefBackend", ir_file) return imported_module def load(self, module) -> RefBackendInvoker: """Loads a compiled artifact into the runtime.""" - return RefBackendInvoker(module) + return RefBackendInvoker(module, shared_libs=[]) diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/mlp.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/mlp.py index faebcadf30f6..b568bd4e28e0 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/mlp.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/mlp.py @@ -99,3 +99,44 @@ def forward(self, x): @register_test_case(module_factory=lambda: BatchMlpLayerModule()) def BatchMlpLayerModule_basic(module, tu: TestUtils): module.forward(tu.rand(7, 5, 3)) + + +class MLP(torch.nn.Module): + def __init__(self, input_dim, output_dim): + super().__init__() + self.flatten = torch.nn.Flatten() + self.linear1 = torch.nn.Linear(input_dim, input_dim // 2) + # self.relu = torch.nn.ReLU() + # self.linear2 = torch.nn.Linear(input_dim // 2, output_dim) + + @export + @annotate_args([ + None, + ([-1, -1, -1], torch.float32, True), + ]) + def forward(self, x): + x = self.flatten(x) + x = self.linear1(x) + # x = self.relu(x) + # x = self.linear2(x) + return x + +model = MLP(128 * 128, 0) + +def model_factory(): + return model + + +test_input = torch.rand(1, 128, 128) +w = model.linear1.weight.detach().numpy() +b = model.linear1.bias.detach().numpy() +print("in shape: ", test_input.shape) +print(" w shape: ", w.shape) +print(" b shape: ", b.shape) + + +@register_test_case(module_factory=model_factory) +def MLP_basic(module, tu: TestUtils): + out = module.forward(test_input) + print("[test body] out shape: ", out.size()) +