[Bufferization] Enable OneShot #10

Open · wants to merge 8 commits into cpu-proto
14 changes: 7 additions & 7 deletions CMakeLists.txt
@@ -163,13 +163,6 @@ if(MLIR_ENABLE_BINDINGS_PYTHON)
mlir_configure_python_dev_packages()
endif()

add_subdirectory(include)
add_subdirectory(lib)
add_subdirectory(tools)

add_custom_target(check-torch-mlir-all)
add_dependencies(check-torch-mlir-all check-torch-mlir)

if(MLIR_ENABLE_BINDINGS_PYTHON)
# If parent projects want to configure where to place the python packages,
# respect that.
@@ -178,6 +171,13 @@ if(MLIR_ENABLE_BINDINGS_PYTHON)
endif()
endif()

add_subdirectory(include)
add_subdirectory(lib)
add_subdirectory(tools)

add_custom_target(check-torch-mlir-all)
add_dependencies(check-torch-mlir-all check-torch-mlir)

add_subdirectory(test)

if (NOT LLVM_INSTALL_TOOLCHAIN_ONLY)
21 changes: 21 additions & 0 deletions conda-dev-env.yml
@@ -0,0 +1,21 @@
# To create environment, execute:
# conda env create -f conda-dev-env.yml
#
# To activate:
# conda activate mlir-dev
#
name: mlir-dev
channels:
- conda-forge

dependencies:
- gxx_linux-64 12.*
- gcc_linux-64 12.*
- ccache
- sysroot_linux-64 >=2.14
- llvmdev 16.*
- cmake
- make
- llvm
- python ==3.11
- mkl
16 changes: 16 additions & 0 deletions include/torch-mlir/Conversion/FuseLinalg/FuseLinalg.h
@@ -0,0 +1,16 @@
#ifndef TORCHMLIR_FUSE_LINALG_H
#define TORCHMLIR_FUSE_LINALG_H

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Operation.h"
#include "mlir/Pass/Pass.h"
#include <memory>

namespace mlir {
namespace torch {
std::unique_ptr<mlir::OperationPass<ModuleOp>> createFuseLinalgOpsPass();
} // namespace torch
} // namespace mlir

#endif // TORCHMLIR_FUSE_LINALG_H
25 changes: 25 additions & 0 deletions include/torch-mlir/Conversion/LinalgToKernelCalls/LinalgToKernelCalls.h
@@ -0,0 +1,25 @@
//===------------------------------------------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//

#ifndef TORCHMLIR_CONVERSION_LINALGTOKERNELCALLS_LINALGTOKERNELCALLS_H
#define TORCHMLIR_CONVERSION_LINALGTOKERNELCALLS_LINALGTOKERNELCALLS_H

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Pass/Pass.h"
#include <memory>

namespace mlir {
namespace torch {
std::unique_ptr<OperationPass<ModuleOp>>
createConvertLinalgOpsToKernelCallsPass();
} // namespace torch
} // namespace mlir

#endif // TORCHMLIR_CONVERSION_LINALGTOKERNELCALLS_LINALGTOKERNELCALLS_H
22 changes: 22 additions & 0 deletions include/torch-mlir/Conversion/Passes.td
@@ -154,4 +154,26 @@ def ConvertTorchToStablehlo : Pass<"convert-torch-to-stablehlo", "func::FuncOp">
}
#endif

def ConvertLinalgOpsToKernelCalls
: Pass<"convert-linalg-ops-to-kernel-calls", "ModuleOp"> {
let summary = "Lower linalg.matmul operation to kernel call.";
let constructor = [{
mlir::torch::createConvertLinalgOpsToKernelCallsPass()
}];
let description = [{
This pass replaces linalg.matmul operations with calls to the runtime library.
}];
}

def FuseLinalgOps
: Pass<"fuse-linalg-ops", "ModuleOp"> {
let summary = "Fuse linalg operations.";
let constructor = [{
mlir::torch::createFuseLinalgOpsPass()
}];
let description = [{
This pass fuses a linalg.transpose producer into its linalg.matmul consumer,
rewriting the pair as linalg.matmul_transpose_a or linalg.matmul_transpose_b.
}];
}

#endif // TORCHMLIR_CONVERSION_PASSES
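
For context on how the two pass definitions above fit together, here is a minimal sketch (not part of this diff) of appending them to a module-level pass pipeline. The helper name and the pass ordering are illustrative assumptions, and the LinalgToKernelCalls.h path is inferred from its include guard.

#include "mlir/Pass/PassManager.h"
#include "torch-mlir/Conversion/FuseLinalg/FuseLinalg.h"
#include "torch-mlir/Conversion/LinalgToKernelCalls/LinalgToKernelCalls.h"

// Hypothetical pipeline helper; both passes run on the enclosing ModuleOp.
void addLinalgKernelLoweringPasses(mlir::PassManager &pm) {
  // Fuse linalg.transpose producers into linalg.matmul (FuseLinalgOps).
  pm.addPass(mlir::torch::createFuseLinalgOpsPass());
  // Replace linalg.matmul ops with calls into the runtime library
  // (ConvertLinalgOpsToKernelCalls).
  pm.addPass(mlir::torch::createConvertLinalgOpsToKernelCallsPass());
}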
1 change: 1 addition & 0 deletions lib/CMakeLists.txt
@@ -1,6 +1,7 @@
add_subdirectory(CAPI)
add_subdirectory(Conversion)
add_subdirectory(Dialect)
add_subdirectory(Runtime)

set(LinkedLibs
MLIRComplexDialect
6 changes: 5 additions & 1 deletion lib/Conversion/CMakeLists.txt
@@ -1,7 +1,9 @@
add_subdirectory(LinalgToKernelCalls)
add_subdirectory(TorchToLinalg)
add_subdirectory(TorchToSCF)
add_subdirectory(TorchToArith)
add_subdirectory(TorchToTosa)
add_subdirectory(FuseLinalg)
if(TORCH_MLIR_ENABLE_STABLEHLO)
add_subdirectory(TorchToStablehlo)
endif()
@@ -10,7 +12,9 @@ add_subdirectory(TorchConversionToMLProgram)
add_subdirectory(Utils)

# TODO: Automate this with add_torch_mlir_conversion_library.
set(linked_libs TorchMLIRTorchToLinalg
set(linked_libs TorchMLIRLinalgToKernelCalls
TorchMLIRTorchToLinalg
FuseLinalg
TorchMLIRTorchToSCF
TorchMLIRTorchToArith
TorchMLIRTorchToTosa
16 changes: 16 additions & 0 deletions lib/Conversion/FuseLinalg/CMakeLists.txt
@@ -0,0 +1,16 @@
add_mlir_conversion_library(FuseLinalg
FuseLinalg.cpp

ADDITIONAL_HEADER_DIRS
${PROJECT_SOURCE_DIR}/include/torch-mlir/Conversion/FuseLinalg

DEPENDS
TorchMLIRConversionPassIncGen

LINK_LIBS PUBLIC
MLIRIR
MLIRPass
MLIRLinalgDialect
)

torch_mlir_target_includes(FuseLinalg)
128 changes: 128 additions & 0 deletions lib/Conversion/FuseLinalg/FuseLinalg.cpp
@@ -0,0 +1,128 @@
#include "torch-mlir/Conversion/FuseLinalg/FuseLinalg.h"

#include "../PassDetail.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/MLProgram/IR/MLProgram.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include <algorithm>
#include <vector>

namespace {

class MatmulTranspose : public mlir::OpRewritePattern<mlir::linalg::MatmulOp> {
public:
using OpRewritePattern::OpRewritePattern;
mlir::LogicalResult
matchAndRewrite(mlir::linalg::MatmulOp op,
mlir::PatternRewriter &rewriter) const override {

auto inputs = op.getInputs();
if (inputs.size() != 2) {
return rewriter.notifyMatchFailure(
op, "matmul is expected to have exactly two inputs");
}

std::vector<int32_t> transposed_input_nums(inputs.size(), -1);
mlir::linalg::TransposeOp transposeOp = nullptr;

for (size_t num_input = 0; num_input < inputs.size(); num_input++) {
mlir::Value input = inputs[num_input];
// getDefiningOp<OpTy>() is null-safe for block arguments.
auto candidate = input.getDefiningOp<mlir::linalg::TransposeOp>();
if (!candidate) {
continue;
}
transposeOp = candidate;
transposed_input_nums[num_input] = num_input;
}

const int default_inps = std::count(transposed_input_nums.cbegin(),
transposed_input_nums.cend(), -1);
if (default_inps != 1) {
return rewriter.notifyMatchFailure(
op, "expected exactly one transposed input");
}

if (!transposeOp) {
return rewriter.notifyMatchFailure(op, "input is not a TransposeOp");
}
/* Maybe we need this
if (isListPotentiallyMutated(transposeOp.getResult()))
return rewriter.notifyMatchFailure(
op, "TransposeOp result is potentially mutated");
*/

mlir::Location loc = op.getLoc();

mlir::SmallVector<mlir::Value> fusedInputOperands, fusedOutputOperands;
mlir::SmallVector<mlir::Type> fusedResultTypes;
for (mlir::OpOperand &opOperand : op.getOutputsMutable()) {
fusedOutputOperands.push_back(opOperand.get());
mlir::Type resultType = opOperand.get().getType();
if (!mlir::isa<mlir::MemRefType>(resultType))
fusedResultTypes.push_back(resultType);
}

mlir::Value matmul;
if (transposed_input_nums[0] != -1) {
fusedInputOperands.push_back(transposeOp.getInputMutable().get());
fusedInputOperands.push_back(op.getInputsMutable()[1].get());

auto mm_a = rewriter.create<mlir::linalg::MatmulTransposeAOp>(
loc, fusedResultTypes, fusedInputOperands, fusedOutputOperands);
matmul = mm_a.getResult(0);
} else if (transposed_input_nums[1] != -1) {
fusedInputOperands.push_back(op.getInputsMutable()[0].get());
fusedInputOperands.push_back(transposeOp.getInputMutable().get());

auto mm_b = rewriter.create<mlir::linalg::MatmulTransposeBOp>(
loc, fusedResultTypes, fusedInputOperands, fusedOutputOperands);

matmul = mm_b.getResult(0);
}

rewriter.replaceUsesWithIf(
op.getResult(0), matmul, [&](mlir::OpOperand &use) {
// Only replace consumer uses.
return use.get().getDefiningOp() != transposeOp;
});
rewriter.eraseOp(op);

if (transposeOp.getResult().use_empty()) {
rewriter.eraseOp(transposeOp);
}
return mlir::success();
}
};

class FuseLinalgOps : public mlir::torch::FuseLinalgOpsBase<FuseLinalgOps> {
public:
void runOnOperation() override {
mlir::MLIRContext *context = &getContext();
mlir::RewritePatternSet patterns(context);

// Rewrite transpose->matmul chains into linalg.matmul_transpose_{a,b}.
patterns.add<MatmulTranspose>(context);

mlir::GreedyRewriteConfig config;
config.useTopDownTraversal = true;

if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns),
config))) {
mlir::emitError(getOperation()->getLoc(), "failure in Linalg fusion");
signalPassFailure();
}
}
};
} // namespace

std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
mlir::torch::createFuseLinalgOpsPass() {
return std::make_unique<FuseLinalgOps>();
}
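
As a usage illustration for the pattern above, the following sketch (not part of this diff) parses a small module in which a linalg.transpose feeds the first matmul operand and runs the new pass over it. The driver function, the test IR, and the registered dialect set are assumptions made for the example.

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/DialectRegistry.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/Parser/Parser.h"
#include "mlir/Pass/PassManager.h"
#include "torch-mlir/Conversion/FuseLinalg/FuseLinalg.h"

static const char *kInput = R"mlir(
  func.func @f(%a: tensor<4x8xf32>, %b: tensor<4x16xf32>,
               %init: tensor<8x4xf32>, %out: tensor<8x16xf32>) -> tensor<8x16xf32> {
    %t = linalg.transpose ins(%a : tensor<4x8xf32>)
                          outs(%init : tensor<8x4xf32>) permutation = [1, 0]
    %r = linalg.matmul ins(%t, %b : tensor<8x4xf32>, tensor<4x16xf32>)
                       outs(%out : tensor<8x16xf32>) -> tensor<8x16xf32>
    return %r : tensor<8x16xf32>
  }
)mlir";

int runFuseLinalgExample() {
  // Register the dialects used by the test IR; linalg pulls in its own
  // dependent dialects (arith, etc.) when it is loaded.
  mlir::DialectRegistry registry;
  registry.insert<mlir::func::FuncDialect, mlir::linalg::LinalgDialect,
                  mlir::tensor::TensorDialect>();
  mlir::MLIRContext context(registry);

  mlir::OwningOpRef<mlir::ModuleOp> module =
      mlir::parseSourceString<mlir::ModuleOp>(kInput, &context);
  if (!module)
    return 1;

  mlir::PassManager pm(&context);
  pm.addPass(mlir::torch::createFuseLinalgOpsPass());
  if (mlir::failed(pm.run(*module)))
    return 1;

  // Expected effect of MatmulTranspose: the transpose is folded away and the
  // matmul is rewritten to linalg.matmul_transpose_a, since the transposed
  // value was the first input.
  module->dump();
  return 0;
}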
18 changes: 18 additions & 0 deletions lib/Conversion/LinalgToKernelCalls/CMakeLists.txt
@@ -0,0 +1,18 @@
add_mlir_conversion_library(TorchMLIRLinalgToKernelCalls
LinalgToKernelCalls.cpp

ADDITIONAL_HEADER_DIRS
${PROJECT_SOURCE_DIR}/include/torch-mlir/Conversion/LinalgToKernelCalls

DEPENDS
TorchMLIRConversionPassIncGen

LINK_LIBS PUBLIC
MLIRIR
MLIRPass
MLIRLinalgDialect
MLIRMathDialect
TorchMLIRTorchDialect
)

torch_mlir_target_includes(TorchMLIRLinalgToKernelCalls)