Skip to content

Commit

Permalink
[TOSA] Add Tosa_Shape type and ConstShapeOp
Browse files Browse the repository at this point in the history
Adds:
 1. tosa shape type to Tosa dialect e.g., !tosa.shape<4> is a type for
    rank-4 shape values (size-4 array of index values)
 2. const_shape operator
 3. trait TosaShapeOperator, added to tosa shape operators, and a
    verifier that all operands and results of operator are tosa shapes
 4. trait TosaResolvableShapeOperands, added to all tosa operators, and
    a verifier that every tosa shape operand is produced by a tosa shape
    operator (indicated by trait TosaShapeOperator)
 5. trait TosaShapeOperatorWithSameRanks, added to
    Tosa_ElementwiseShapeOp and a verifier that all operands and result
    shapes have same ranks
 5. changed TileOp's multiples from attribute to input, of !tosa.shape
    type.
 6. add folder for tosa ConstShape operator

Signed-off-by: Jerry Ge <[email protected]>
Signed-off-by: Tai Ly <[email protected]>

Change-Id: I0213f99f5816b648f732b01fe8bd196956f1dfc8
  • Loading branch information
Jerry-Ge committed Jan 10, 2025
1 parent 749bdc8 commit 39d8e44
Show file tree
Hide file tree
Showing 18 changed files with 441 additions and 32 deletions.
3 changes: 2 additions & 1 deletion mlir/include/mlir/Dialect/Tosa/IR/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,12 @@ add_mlir_doc(TosaOps TosaOps Dialects/ -gen-op-doc)
add_mlir_interface(TosaInterfaces)

set(LLVM_TARGET_DEFINITIONS TosaOps.td)
mlir_tablegen(TosaOpsTypes.h.inc -gen-typedef-decls -typedefs-dialect=tosa)
mlir_tablegen(TosaOpsTypes.cpp.inc -gen-typedef-defs -typedefs-dialect=tosa)
mlir_tablegen(TosaAttributes.h.inc -gen-attrdef-decls -attrdefs-dialect=tosa)
mlir_tablegen(TosaAttributes.cpp.inc -gen-attrdef-defs -attrdefs-dialect=tosa)
add_public_tablegen_target(MLIRTosaAttributesIncGen)

set(LLVM_TARGET_DEFINITIONS TosaDialectBytecode.td)
mlir_tablegen(TosaDialectBytecode.cpp.inc -gen-bytecode -bytecode-dialect="Tosa")
add_public_tablegen_target(MLIRTosaDialectBytecodeIncGen)

1 change: 1 addition & 0 deletions mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ def Tosa_Dialect : Dialect {
let cppNamespace = "mlir::tosa";
let hasConstantMaterializer = 1;
let useDefaultAttributePrinterParser = 1;
let useDefaultTypePrinterParser = 1;
}

//===----------------------------------------------------------------------===//
Expand Down
41 changes: 41 additions & 0 deletions mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h
Original file line number Diff line number Diff line change
Expand Up @@ -90,14 +90,55 @@ template <typename ConcreteType>
class TosaElementwiseOperator
: public TraitBase<ConcreteType, TosaElementwiseOperator> {};

LogicalResult verifyTosaResolvableShapeOperands(Operation *op);
/// This class verifies that tosa shape operands are compile time resolvable
template <typename ConcreteType>
class TosaResolvableShapeOperands
: public TraitBase<ConcreteType, TosaResolvableShapeOperands> {
public:
static LogicalResult verifyTrait(Operation *op) {
return verifyTosaResolvableShapeOperands(op);
}
};

LogicalResult verifyTosaShapeOperator(Operation *op);
/// This class indicates that op operates on tosa shape types
template <typename ConcreteType>
class TosaShapeOperator : public TraitBase<ConcreteType, TosaShapeOperator> {
public:
static LogicalResult verifyTrait(Operation *op) {
return verifyTosaShapeOperator(op);
}
};

LogicalResult verifyTosaShapeOperatorWithSameRanks(Operation *op);
/// This class indicates that op operates on tosa shape types
template <typename ConcreteType>
class TosaShapeOperatorWithSameRanks
: public TraitBase<ConcreteType, TosaShapeOperatorWithSameRanks> {
public:
static LogicalResult verifyTrait(Operation *op) {
return verifyTosaShapeOperatorWithSameRanks(op);
}
};

} // namespace tosa
} // namespace OpTrait

namespace tosa {

bool isa_tosa_shape_type(mlir::Type t);

} // namespace tosa

} // namespace mlir

#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/Tosa/IR/TosaAttributes.h.inc"

#define GET_TYPEDEF_CLASSES
#include "mlir/Dialect/Tosa/IR/TosaOpsTypes.h.inc"

#define GET_OP_CLASSES
#include "mlir/Dialect/Tosa/IR/TosaOps.h.inc"

Expand Down
9 changes: 8 additions & 1 deletion mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ include "mlir/Dialect/Tosa/IR/TosaInterfaces.td"

include "mlir/Dialect/Tosa/IR/TosaTypesBase.td"
include "mlir/Dialect/Tosa/IR/TosaOpBase.td"
include "mlir/Dialect/Tosa/IR/TosaTypes.td"

//===----------------------------------------------------------------------===//
// TOSA Spec Section 2.2
Expand Down Expand Up @@ -1689,12 +1690,16 @@ def Tosa_TileOp : Tosa_InferShapedTypeOp<"tile"> {

let arguments = (ins
Tosa_Tensor:$input1,
DenseI64ArrayAttr:$multiples);
Tosa_Shape:$multiples);

let results = (outs
Tosa_Tensor:$output
);

let extraClassDeclaration = [{
LogicalResult getConstantMultiples(llvm::SmallVector<int64_t> &multiples);
}];

let hasFolder = 1;
let hasVerifier = 1;
}
Expand Down Expand Up @@ -2106,4 +2111,6 @@ def Tosa_WhileOp : Tosa_Op<"while_loop", [

include "mlir/Dialect/Tosa/IR/TosaUtilOps.td"

include "mlir/Dialect/Tosa/IR/TosaShapeOps.td"

#endif // TOSA_OPS
79 changes: 79 additions & 0 deletions mlir/include/mlir/Dialect/Tosa/IR/TosaShapeOps.td
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
//===-- TosaShapeOps.td - TOSA dialect utility operations --*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines shape operators for the TOSA dialect.
//
//===----------------------------------------------------------------------===//

#ifndef TOSA_SHAPE_OPS
#define TOSA_SHAPE_OPS

include "mlir/IR/OpBase.td"

include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/LoopLikeInterface.td"
include "mlir/Dialect/Tosa/IR/TosaInterfaces.td"

include "mlir/Dialect/Tosa/IR/TosaTypesBase.td"
include "mlir/Dialect/Tosa/IR/TosaOpBase.td"
include "mlir/Dialect/Tosa/IR/TosaTypes.td"

// Op trait: operator has operands and results with TOSA shape type
def TosaShapeOperator : NativeOpTrait<"TosaShapeOperator"> {
let cppNamespace = "mlir::OpTrait::tosa";
}

class Tosa_ShapeOp<string mnemonic, list<Trait> traits = []>
: Tosa_Op<mnemonic, !listconcat(traits, [TosaShapeOperator, Pure])> {

let assemblyFormat =
"operands attr-dict `:` functional-type(operands, results)";

let hasFolder = 1;
}

// op trait: shape operator has same ranks for operands and results
def TosaShapeOperatorWithSameRanks : NativeOpTrait<"TosaShapeOperatorWithSameRanks"> {
let cppNamespace = "mlir::OpTrait::tosa";
}

class Tosa_ElementwiseShapeOp<string mnemonic, list<Trait> traits = []>
: Tosa_ShapeOp<mnemonic, !listconcat(traits, [TosaShapeOperatorWithSameRanks])> {
}

//===----------------------------------------------------------------------===//
// Operator: ConstShape
//===----------------------------------------------------------------------===//
def Tosa_ConstShapeOp : Tosa_ShapeOp<"const_shape", [ConstantLike, Pure]> {
let summary = "Constant Shape op.";

let description = [{
A node containing constant data for use as the input to an shape operation. May
hold data only in index data type.

Example:

```mlir
// Generic form
%out = "tosa.const_shape"() {value = dense<0> : tensor<4xindex>} : () -> !tosa.shape<4>
```
}];

let arguments = (ins
IndexElementsAttr:$value
);

let results = (outs
Tosa_Shape:$output
);

let hasVerifier = 1;
}

#endif // TOSA_SHAPE_OPS
87 changes: 87 additions & 0 deletions mlir/include/mlir/Dialect/Tosa/IR/TosaTypes.td
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
//===-- TosaTypes.td - TOSA type definitions ---------------*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines the type definitions for the TOSA dialect.
//
//===----------------------------------------------------------------------===//

#ifndef TOSA_TYPES
#define TOSA_TYPES

include "mlir/IR/AttrTypeBase.td"
include "mlir/IR/OpBase.td"

include "mlir/Dialect/Tosa/IR/TosaOpBase.td"

//===----------------------------------------------------------------------===//
// Tosa Type Definitions.
//===----------------------------------------------------------------------===//

// The base class for Tosa dialect types.
class Tosa_Type<string name, string typeMnemonic, list<Trait> traits = []>
: TypeDef<Tosa_Dialect, name, traits> {
let mnemonic = typeMnemonic;
}

//===----------------------------------------------------------------------===//
// ShapeType
//===----------------------------------------------------------------------===//
def Tosa_Shape : Tosa_Type<"shape", "shape"> {
let summary = "Shape with static rank and Index element type";
let description = [{
Syntax:

```
shape-type ::= `shape` `<` rank `>`
```
Values with shape type represents a shape with a fixed rank and a list of dimensions.
Rank must be zero or a positive integer.
Each dimension is represented by the builtin Index type.

Examples:

```mlir
// Shape with rank of four, for example, [1, 1, 8, 16]:
!tosa.shape<4>

// Shape with rank of one, for example, [16]:
!tosa.shape<1>

// Shape with rank zero, for example, [] (i.e., shape of scalar values):
!tosa.shape<0>
```
}];
let parameters = (ins
"int":$rank
);
let builders = [
TypeBuilder<(ins "int":$rank)>
];
let assemblyFormat = "`<` $rank `>`";

let genVerifyDecl = 1;
}

def IsTosaShapeType : CPred<"mlir::tosa::isa_tosa_shape_type($_self)">;

// Whether a Tosa Shape type has a rank equal to the specified rank.
class IsTosaShapeOfRankPred<int rank> : And<[
IsTosaShapeType,
CPred<[{::llvm::cast<::mlir::tosa::shapeType>($_self).getRank() == }] # rank>
]>;

class TosaShapeOfRank<int rank> :
Type<IsTosaShapeOfRankPred<rank>,
"Tosa shape type of rank " # rank
>;

def Rank1TosaShape : TosaShapeOfRank<1>;
def Rank2TosaShape : TosaShapeOfRank<2>;
def Rank4TosaShape : TosaShapeOfRank<4>;

#endif // TOSA_TYPES
4 changes: 3 additions & 1 deletion mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1886,7 +1886,9 @@ struct TileConverter : public OpConversionPattern<tosa::TileOp> {
auto elementTy = inputTy.getElementType();
int64_t rank = inputTy.getRank();

ArrayRef<int64_t> multiples = op.getMultiples();
SmallVector<int64_t> multiples;
if (failed(op.getConstantMultiples(multiples)))
return failure();

// Broadcast the newly added dimensions to their appropriate multiple.
SmallVector<int64_t, 2> genericShape;
Expand Down
1 change: 1 addition & 0 deletions mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ struct TosaToLinalg : public impl::TosaToLinalgBase<TosaToLinalg> {
target.addLegalOp<tosa::ApplyScaleOp>();
target.addLegalOp<tosa::IfOp>();
target.addLegalOp<tosa::ConstOp>();
target.addLegalOp<tosa::ConstShapeOp>();
target.addLegalOp<tosa::WhileOp>();
target.addLegalOp<tosa::ConcatOp>();
target.addLegalOp<tosa::SliceOp>();
Expand Down
19 changes: 16 additions & 3 deletions mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -808,6 +808,8 @@ OpFoldResult CastOp::fold(FoldAdaptor adaptor) {

OpFoldResult ConstOp::fold(FoldAdaptor adaptor) { return getValueAttr(); }

OpFoldResult ConstShapeOp::fold(FoldAdaptor adaptor) { return getValueAttr(); }

#define REDUCE_FOLDER(OP) \
OpFoldResult OP::fold(FoldAdaptor adaptor) { \
ShapedType inputTy = llvm::cast<ShapedType>(getInput().getType()); \
Expand Down Expand Up @@ -985,9 +987,20 @@ OpFoldResult tosa::SelectOp::fold(FoldAdaptor adaptor) {
}

OpFoldResult TileOp::fold(FoldAdaptor adaptor) {
bool allOnes = llvm::all_of(getMultiples(), [](int64_t v) { return v == 1; });
if (allOnes && getInput1().getType() == getType())
return getInput1();
if (getInput1().getType() == getType()) {
if (auto multiples = llvm::dyn_cast_if_present<DenseElementsAttr>(
adaptor.getMultiples())) {
if (multiples.isSplat() &&
multiples.getSplatValue<APInt>().getSExtValue() == 1)
return getInput1();
if (auto int_array_attr =
llvm::dyn_cast<DenseIntElementsAttr>(multiples)) {
if (llvm::all_of(int_array_attr.getValues<APInt>(),
[](APInt v) { return v.getSExtValue() == 1; }))
return getInput1();
}
}
}
return {};
}

Expand Down
Loading

0 comments on commit 39d8e44

Please sign in to comment.