COPYBARA SYNC:
  - 2fd35eb [tf-frontend] infer shape after convert where to static s...
  - bed7740 [torch-frontend] fix pipeline for uint8 (#376)
  - d6f21e2 [dynamo] support enable_tf32 in byteir_backend (#375)
  - e1c6a15 [compiler] add enable-tf32 option (#374)
  - 3ce31e0 [runtime] support gemm/bmm compute on tf32 (#373)
  - d255b60 [onnx] fix squeezed value (#371)
  - 199c508 [compiler] add forall-collapsing pass (#366)
  - c6a965c [*] bump version to 1.8.5.0 (#369)
  - dd35694 [torch-frontend] update torch-mlir (#368)
  - d5be436 [compiler] fix bug for clang compilation (#365)
  (And 2 more changes)

GitOrigin-RevId: 2fd35eb
Vremold committed Jun 27, 2024
1 parent a3888d7 commit 389fcc6
Showing 43 changed files with 733 additions and 296 deletions.
@@ -31,10 +31,12 @@ class FuncOp;
 
 void populateHloToByreTensorPattern(
     RewritePatternSet &patterns,
-    const llvm::StringMap<llvm::StringRef> &supportMap, bool appendArgTypes);
+    const llvm::StringMap<llvm::StringRef> &supportMap, bool appendArgTypes,
+    bool enableTF32);
 
 std::unique_ptr<OperationPass<func::FuncOp>>
-createConvertHloToByreTensorPass(bool appendArgTypes = false);
+createConvertHloToByreTensorPass(bool appendArgTypes = false,
+                                 bool enableTF32 = false);
 
 } // namespace mlir

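Below, a minimal C++ sketch of driving the widened entry point, assuming a standard MLIR PassManager. The wrapper function is illustrative, and the ByteIR header providing these declarations is omitted because its filename is truncated in this diff.

// Sketch (not from this commit): lower HLO to Byre tensor ops with TF32 on.
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Pass/PassManager.h"

void buildHloLowering(mlir::PassManager &pm) {
  // New second parameter added by this commit; it is defaulted, so
  // existing call sites keep their old (TF32-off) behavior.
  pm.addNestedPass<mlir::func::FuncOp>(
      mlir::createConvertHloToByreTensorPass(/*appendArgTypes=*/false,
                                             /*enableTF32=*/true));
}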
2 changes: 2 additions & 0 deletions compiler/include/byteir/Conversion/Passes.td
@@ -298,6 +298,8 @@ def ConvertHloToByreTensor : Pass<"hlo-to-byre-tensor", "func::FuncOp"> {
   let options = [
     Option<"appendArgTypes", "append-arg-types", "bool", /*default=*/"false",
            "append arg types to Byre">,
+    Option<"enableTF32", "enable-tf32", "bool", /*default=*/"false",
+           "enable 1xTF32 on fp32 gemm/bmm">,
   ];
 }

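A small usage sketch, assuming standard MLIR driver code: the tablegen entry above makes the flag available through the textual pipeline syntax. Only the pass name "hlo-to-byre-tensor" and the option names come from this diff.

// Sketch (assumed driver code, not from this commit).
#include "mlir/Pass/PassManager.h"
#include "mlir/Pass/PassRegistry.h"

mlir::LogicalResult buildFromString(mlir::PassManager &pm) {
  // Parse a textual pipeline that turns the new option on.
  return mlir::parsePassPipeline(
      "func.func(hlo-to-byre-tensor{append-arg-types=true enable-tf32=true})",
      pm);
}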
1 change: 1 addition & 0 deletions compiler/include/byteir/Dialect/SCF/Passes.h
@@ -18,6 +18,7 @@
 #ifndef BYTEIR_DIALECT_SCF_PASSES_H
 #define BYTEIR_DIALECT_SCF_PASSES_H
 
+#include "byteir/Dialect/SCF/Transforms/ForallCollapsing.h"
 #include "byteir/Dialect/SCF/Transforms/FuseNestedForall.h"
 #include "byteir/Dialect/SCF/Transforms/InsertTrivialSCFLoop.h"

17 changes: 17 additions & 0 deletions compiler/include/byteir/Dialect/SCF/Passes.td
@@ -55,4 +55,21 @@ def FuseNestedForall : Pass<"fuse-nested-forall", "mlir::func::FuncOp"> {
   ];
 }
 
+//===----------------------------------------------------------------------===//
+// ForallCollapsing
+//===----------------------------------------------------------------------===//
+
+def ForallCollapsing : Pass<"forall-collapsing", "mlir::func::FuncOp"> {
+  let summary = "collapse forall";
+  let constructor = "mlir::createForallCollapsingPass()";
+  let dependentDialects = [
+    "scf::SCFDialect"
+  ];
+  let options = [
+    Option<"anchorTag", "anchor-tag", "std::string",
+           /*default=*/"",
+           "Optional unitAttr anchored tag to apply this pass">
+  ];
+}
+
 #endif // BYTEIR_DIALECT_SCF_PASSES
34 changes: 34 additions & 0 deletions compiler/include/byteir/Dialect/SCF/Transforms/ForallCollapsing.h
@@ -0,0 +1,34 @@
+//===- ForallCollapsing.h ------------------------------------- C++ --===//
+//
+// Copyright 2024 ByteDance Ltd. and/or its affiliates. All rights reserved.
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//    http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef BYTEIR_DIALECT_SCF_TRANSFORMS_FORALLCOLLAPSING_H
+#define BYTEIR_DIALECT_SCF_TRANSFORMS_FORALLCOLLAPSING_H
+
+#include "mlir/Pass/Pass.h"
+#include <memory>
+
+namespace mlir {
+namespace func {
+class FuncOp;
+} // namespace func
+
+std::unique_ptr<OperationPass<func::FuncOp>>
+createForallCollapsingPass(llvm::StringRef anchorTag = "");
+
+} // namespace mlir
+
+#endif // BYTEIR_DIALECT_SCF_TRANSFORMS_FORALLCOLLAPSING_H
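A usage sketch for the new pass, assuming a standard pass-manager setup; the tag value "__collapse__" is invented for illustration.

// Sketch (not from this commit): schedule the collapsing pass, optionally
// gated on a unit attribute. An empty anchorTag (the default) applies the
// pass unconditionally.
#include "byteir/Dialect/SCF/Passes.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Pass/PassManager.h"

void addForallCollapsing(mlir::PassManager &pm) {
  pm.addNestedPass<mlir::func::FuncOp>(
      mlir::createForallCollapsingPass(/*anchorTag=*/"__collapse__"));
}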
4 changes: 4 additions & 0 deletions compiler/include/byteir/Pipelines/ByreTensorOpt.h
@@ -34,6 +34,10 @@ struct ByreTensorOptPipelineOptions
       *this, "append-arg-types",
       llvm::cl::desc("whether to append arg types to Byre"),
       llvm::cl::init(false)};
+  Option<bool> enableTF32{
+      *this, "enable-tf32",
+      llvm::cl::desc("whether to enable 1xTF32 on f32 gemm/bmm"),
+      llvm::cl::init(false)};
 };
 
 void createByreTensorOptPipeline(OpPassManager &pm,
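As a sketch of how such a PassPipelineOptions struct is typically consumed: the pipeline name and the exact createByreTensorOptPipeline signature below are assumptions, since the declaration is truncated in this diff and the real registration code is not shown here.

// Hypothetical registration sketch (not part of this commit). The
// PassPipelineOptions base class parses "enable-tf32=true" from the textual
// options string into the enableTF32 field declared above.
#include "byteir/Pipelines/ByreTensorOpt.h"
#include "mlir/Pass/PassRegistry.h"

using namespace mlir; // namespace of the options struct is assumed

static PassPipelineRegistration<ByreTensorOptPipelineOptions>
    byreTensorOpt("byre-tensor-opt", // assumed pipeline name
                  "ByteIR Byre tensor optimization pipeline",
                  [](OpPassManager &pm,
                     const ByreTensorOptPipelineOptions &options) {
                    // Assumed to forward the options, including enableTF32.
                    createByreTensorOptPipeline(pm, options);
                  });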
69 changes: 47 additions & 22 deletions compiler/lib/Conversion/HloToByreTensor/HloToByreTensor.cpp
@@ -447,28 +447,42 @@ class ConvertDotOpToByrePattern : public OpConversionPattern<mhlo::DotOp> {
 class ConvertDotGeneralOpToByrePattern
     : public OpConversionPattern<mhlo::DotGeneralOp> {
 public:
-  ConvertDotGeneralOpToByrePattern(MLIRContext *ctx, bool appendTypes)
+  ConvertDotGeneralOpToByrePattern(MLIRContext *ctx, bool appendTypes,
+                                   bool enableTF32)
       : OpConversionPattern<mhlo::DotGeneralOp>(ctx),
-        appendArgTypes(appendTypes) {}
+        appendArgTypes(appendTypes), enableTF32(enableTF32) {}
 
   LogicalResult
   matchAndRewrite(mlir::mhlo::DotGeneralOp op,
                   mlir::mhlo::DotGeneralOp::Adaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
 
     auto dotDimensionNumbers = adaptor.getDotDimensionNumbers();
-    assert(dotDimensionNumbers.getLhsContractingDimensions().size() == 1);
-    assert(dotDimensionNumbers.getRhsContractingDimensions().size() == 1);
-    if (dotDimensionNumbers.getLhsBatchingDimensions().size() == 0) {
+    if (dotDimensionNumbers.getLhsContractingDimensions().size() != 1) {
+      return failure();
+    }
+    if (dotDimensionNumbers.getRhsContractingDimensions().size() != 1) {
+      return failure();
+    }
+    auto lhsBatchs = dotDimensionNumbers.getLhsBatchingDimensions();
+    auto rhsBatchs = dotDimensionNumbers.getRhsBatchingDimensions();
+    size_t lhsRank = cast<ShapedType>(op.getLhs().getType()).getRank();
+    size_t rhsRank = cast<ShapedType>(op.getRhs().getType()).getRank();
+    if (lhsRank != rhsRank) {
+      return failure();
+    }
+    if (lhsRank != lhsBatchs.size() + 2 || rhsRank != rhsBatchs.size() + 2) {
+      return failure();
+    }
+
+    if (dotDimensionNumbers.getLhsBatchingDimensions().size() == 0 &&
+        dotDimensionNumbers.getRhsBatchingDimensions().size() == 0) {
       // convert to MatmulOp
       auto failureOrComputeOnTensorOp = replaceMhloOpWithByreComputeOnTensorOp(
           rewriter, op, "MatmulOp", adaptor.getOperands(), appendArgTypes);
       if (failed(failureOrComputeOnTensorOp)) {
         return failure();
       }
       auto computeOnTensorOp = *failureOrComputeOnTensorOp;
       // append attribute 'lhsContractingDimension' and
       // 'rhsContractingDimension'
       int64_t lhsContractingDimension =
           dotDimensionNumbers.getLhsContractingDimensions()[0];
       int64_t rhsContractingDimension =
@@ -479,14 +493,14 @@ class ConvertDotGeneralOpToByrePattern
       computeOnTensorOp->setAttr(
           "rhs_contracting_dimension",
           rewriter.getI64IntegerAttr(rhsContractingDimension));
+      if (this->enableTF32) {
+        computeOnTensorOp->setAttr("compute_type",
+                                   TypeAttr::get(rewriter.getTF32Type()));
+      }
     } else {
       // convert to BatchMatmulOp
-      SmallVector<int64_t> batchingDimensions;
-      for (int64_t i = 0,
-                   e = cast<ShapedType>(op->getResult(0).getType()).getRank();
-           i < e - 2; i++) {
-        batchingDimensions.push_back(i);
-      }
+      SmallVector<int64_t> batchingDimensions =
+          to_vector(llvm::seq<int64_t>(0, lhsRank - 2));
       if (!dotDimensionNumbers.getLhsBatchingDimensions().equals(
               batchingDimensions) ||
           !dotDimensionNumbers.getRhsBatchingDimensions().equals(
@@ -522,12 +536,17 @@ class ConvertDotGeneralOpToByrePattern
       computeOnTensorOp->setAttr(
           "rhs_batching_dimensions",
           rewriter.getI64ArrayAttr(rhsBatchingDimensions));
+      if (this->enableTF32) {
+        computeOnTensorOp->setAttr("compute_type",
+                                   TypeAttr::get(rewriter.getTF32Type()));
+      }
     }
     return success();
   }
 
 private:
   bool appendArgTypes;
+  bool enableTF32;
 };
 
 class ConvertConvOpToByrePattern
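To make the tightened guards concrete, an illustrative summary (not committed code) of which dot_general shapes now map to each Byre op:

// Illustration only, inferred from the pattern above:
//   lhs tensor<8x16xf32>,   rhs tensor<16x32xf32>,   no batching dims,
//   contracting dims [1]/[0]            -> lowered to byre MatmulOp
//   lhs tensor<4x8x16xf32>, rhs tensor<4x16x32xf32>, batching dims [0]/[0],
//   contracting dims [2]/[1]            -> lowered to byre BatchMatmulOp
//   mismatched lhs/rhs ranks, or rank != #batching dims + 2
//                                       -> pattern now returns failure()
//   With enable-tf32, the lowered op additionally carries
//   compute_type = tf32 (read by the runtime's TF32 gemm/bmm support,
//   per commit 3ce31e0 in the sync list above).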
@@ -757,9 +776,10 @@ class ConvertSelectAndScatterOpToByrePattern
 struct ConvertHloToByreTensorPass
     : public ConvertHloToByreTensorBase<ConvertHloToByreTensorPass> {
 public:
-  ConvertHloToByreTensorPass(bool appendArgTypes)
+  ConvertHloToByreTensorPass(bool appendArgTypes, bool enableTF32)
       : ConvertHloToByreTensorBase() {
     this->appendArgTypes = appendArgTypes;
+    this->enableTF32 = enableTF32;
 
     supportMap.insert({"mhlo.transpose", "TransposeOp"});
   }
@@ -770,7 +790,8 @@ struct ConvertHloToByreTensorPass
     ConversionTarget target(ctx);
     auto funcOp = getOperation();
 
-    populateHloToByreTensorPattern(patterns, supportMap, appendArgTypes);
+    populateHloToByreTensorPattern(patterns, supportMap, appendArgTypes,
+                                   enableTF32);
     target.addIllegalDialect<mhlo::MhloDialect>();
     target.addLegalDialect<tensor::TensorDialect, byre::ByreDialect,
                            shape::ShapeDialect, arith::ArithDialect>();
@@ -788,19 +809,22 @@ struct ConvertHloToByreTensorPass
 
 void mlir::populateHloToByreTensorPattern(
     RewritePatternSet &patterns,
-    const llvm::StringMap<llvm::StringRef> &supportMap, bool appendArgTypes) {
+    const llvm::StringMap<llvm::StringRef> &supportMap, bool appendArgTypes,
+    bool enableTF32) {
 
   patterns.add<ConvertToByrePattern<mhlo::AddOp>,
                ConvertToByrePattern<mhlo::ConvertOp>,
                ConvertToByrePattern<mhlo::TransposeOp, /*keepAttrs*/ true>>(
       patterns.getContext(), supportMap, appendArgTypes);
 
+  patterns.add<ConvertDotGeneralOpToByrePattern>(patterns.getContext(),
+                                                 appendArgTypes, enableTF32);
+
   patterns.add<ConvertCustomCallOpToByrePattern<mhlo::CustomCallOp>,
                ConvertCustomCallOpToByrePattern<ace::CustomCallOp>,
                ConvertGatherOpToByrePattern, ConvertScatterOpToByrePattern,
-               ConvertDotOpToByrePattern, ConvertDotGeneralOpToByrePattern,
-               ConvertConvOpToByrePattern, ConvertReduceOpToByrePattern,
-               ConvertReduceWindowOpToByrePattern,
+               ConvertDotOpToByrePattern, ConvertConvOpToByrePattern,
+               ConvertReduceOpToByrePattern, ConvertReduceWindowOpToByrePattern,
                ConvertSelectAndScatterOpToByrePattern>(patterns.getContext(),
                                                        appendArgTypes);
 
@@ -811,6 +835,7 @@
 }
 
 std::unique_ptr<OperationPass<func::FuncOp>>
-mlir::createConvertHloToByreTensorPass(bool appendArgTypes) {
-  return std::make_unique<ConvertHloToByreTensorPass>(appendArgTypes);
+mlir::createConvertHloToByreTensorPass(bool appendArgTypes, bool enableTF32) {
+  return std::make_unique<ConvertHloToByreTensorPass>(appendArgTypes,
+                                                      enableTF32);
 }
1 change: 1 addition & 0 deletions compiler/lib/Dialect/SCF/Transforms/CMakeLists.txt
@@ -1,4 +1,5 @@
 add_mlir_dialect_library(ByteIRSCFPasses
+  ForallCollapsing.cpp
   FuseNestedForall.cpp
   InsertTrivialSCFLoop.cpp
   TilingInterfaceToSCFFor.cpp