From dcc247f1af596ab78928bcf20ba578146044a2be Mon Sep 17 00:00:00 2001 From: MocusEZ Date: Thu, 25 Jul 2024 18:10:04 +0800 Subject: [PATCH] first commit --- CMakeLists.txt | 22 + Ch1/CMakeLists.txt | 13 + Ch1/include/toy/AST.h | 246 ++ Ch1/include/toy/Lexer.h | 232 ++ Ch1/include/toy/Parser.h | 489 +++ Ch1/parser/AST.cpp | 237 ++ Ch1/toyc.cpp | 71 + Ch2/CMakeLists.txt | 23 + Ch2/include/CMakeLists.txt | 1 + Ch2/include/toy/AST.h | 246 ++ Ch2/include/toy/CMakeLists.txt | 1 + Ch2/include/toy/Dialect.cpp.inc | 23 + Ch2/include/toy/Dialect.h | 33 + Ch2/include/toy/Dialect.h.inc | 26 + Ch2/include/toy/Lexer.h | 232 ++ Ch2/include/toy/MLIRGen.h | 35 + Ch2/include/toy/Ops.cpp.inc | 2049 ++++++++++++ Ch2/include/toy/Ops.h.inc | 1240 +++++++ Ch2/include/toy/Ops.td | 335 ++ Ch2/include/toy/Parser.h | 489 +++ Ch2/include/toy/run.sh | 4 + Ch2/mlir/Dialect.cpp | 323 ++ Ch2/mlir/MLIRGen.cpp | 457 +++ Ch2/parser/AST.cpp | 237 ++ Ch2/toyc.cpp | 145 + Ch3/CMakeLists.txt | 36 + Ch3/include/CMakeLists.txt | 1 + Ch3/include/toy/AST.h | 246 ++ Ch3/include/toy/CMakeLists.txt | 6 + Ch3/include/toy/Dialect.cpp.inc | 23 + Ch3/include/toy/Dialect.h | 33 + Ch3/include/toy/Dialect.h.inc | 26 + Ch3/include/toy/Lexer.h | 232 ++ Ch3/include/toy/MLIRGen.h | 35 + Ch3/include/toy/Ops.cpp.inc | 2061 ++++++++++++ Ch3/include/toy/Ops.h.inc | 1247 +++++++ Ch3/include/toy/Ops.td | 339 ++ Ch3/include/toy/Parser.h | 489 +++ Ch3/include/toy/run.sh | 4 + Ch3/mlir/Dialect.cpp | 323 ++ Ch3/mlir/MLIRGen.cpp | 457 +++ Ch3/mlir/ToyCombine.cpp | 69 + Ch3/mlir/ToyCombine.inc | 176 + Ch3/mlir/ToyCombine.td | 64 + Ch3/mlir/run.sh | 2 + Ch3/parser/AST.cpp | 237 ++ Ch3/toyc.cpp | 170 + Ch4/CMakeLists.txt | 40 + Ch4/include/CMakeLists.txt | 1 + Ch4/include/run.sh | 7 + Ch4/include/toy/AST.h | 246 ++ Ch4/include/toy/CMakeLists.txt | 13 + Ch4/include/toy/Dialect.cpp.inc | 23 + Ch4/include/toy/Dialect.h | 36 + Ch4/include/toy/Dialect.h.inc | 26 + Ch4/include/toy/Lexer.h | 232 ++ Ch4/include/toy/MLIRGen.h | 35 + Ch4/include/toy/Ops.cpp.inc | 2242 +++++++++++++ Ch4/include/toy/Ops.h.inc | 1360 ++++++++ Ch4/include/toy/Ops.td | 372 +++ Ch4/include/toy/Parser.h | 489 +++ Ch4/include/toy/Passes.h | 26 + Ch4/include/toy/ShapeInferenceInterface.h | 28 + Ch4/include/toy/ShapeInferenceInterface.td | 30 + .../toy/ShapeInferenceOpInterfaces.cpp.inc | 12 + .../toy/ShapeInferenceOpInterfaces.h.inc | 61 + Ch4/include/toy/run.sh | 7 + Ch4/mlir/Dialect.cpp | 444 +++ Ch4/mlir/MLIRGen.cpp | 461 +++ Ch4/mlir/ShapeInferencePass.cpp | 122 + Ch4/mlir/ToyCombine.cpp | 69 + Ch4/mlir/ToyCombine.inc | 176 + Ch4/mlir/ToyCombine.td | 63 + Ch4/mlir/run.sh | 2 + Ch4/parser/AST.cpp | 237 ++ Ch4/toyc.cpp | 179 + Ch5/CMakeLists.txt | 45 + Ch5/include/CMakeLists.txt | 1 + Ch5/include/run.sh | 7 + Ch5/include/toy/AST.h | 246 ++ Ch5/include/toy/CMakeLists.txt | 13 + Ch5/include/toy/Dialect.cpp.inc | 23 + Ch5/include/toy/Dialect.h | 36 + Ch5/include/toy/Dialect.h.inc | 26 + Ch5/include/toy/Lexer.h | 232 ++ Ch5/include/toy/MLIRGen.h | 35 + Ch5/include/toy/Ops.cpp.inc | 2252 +++++++++++++ Ch5/include/toy/Ops.h.inc | 1361 ++++++++ Ch5/include/toy/Ops.td | 372 +++ Ch5/include/toy/Parser.h | 489 +++ Ch5/include/toy/Passes.h | 31 + Ch5/include/toy/ShapeInferenceInterface.h | 28 + Ch5/include/toy/ShapeInferenceInterface.td | 30 + .../toy/ShapeInferenceOpInterfaces.cpp.inc | 12 + .../toy/ShapeInferenceOpInterfaces.h.inc | 61 + Ch5/mlir/Dialect.cpp | 444 +++ Ch5/mlir/LowerToAffineLoops.cpp | 385 +++ Ch5/mlir/MLIRGen.cpp | 461 +++ Ch5/mlir/ShapeInferencePass.cpp | 122 + Ch5/mlir/ToyCombine.cpp | 69 + Ch5/mlir/ToyCombine.inc | 176 + Ch5/mlir/ToyCombine.td | 63 + Ch5/mlir/run.sh | 2 + Ch5/parser/AST.cpp | 237 ++ Ch5/toyc.cpp | 207 ++ Ch6/CMakeLists.txt | 65 + Ch6/include/CMakeLists.txt | 1 + Ch6/include/run.sh | 7 + Ch6/include/toy/AST.h | 246 ++ Ch6/include/toy/CMakeLists.txt | 13 + Ch6/include/toy/Dialect.cpp.inc | 23 + Ch6/include/toy/Dialect.h | 36 + Ch6/include/toy/Dialect.h.inc | 26 + Ch6/include/toy/Lexer.h | 232 ++ Ch6/include/toy/MLIRGen.h | 35 + Ch6/include/toy/Ops.cpp.inc | 2252 +++++++++++++ Ch6/include/toy/Ops.h.inc | 1361 ++++++++ Ch6/include/toy/Ops.td | 372 +++ Ch6/include/toy/Parser.h | 489 +++ Ch6/include/toy/Passes.h | 35 + Ch6/include/toy/ShapeInferenceInterface.h | 28 + Ch6/include/toy/ShapeInferenceInterface.td | 30 + .../toy/ShapeInferenceOpInterfaces.cpp.inc | 12 + .../toy/ShapeInferenceOpInterfaces.h.inc | 61 + Ch6/mlir/Dialect.cpp | 444 +++ Ch6/mlir/LowerToAffineLoops.cpp | 385 +++ Ch6/mlir/LowerToLLVM.cpp | 241 ++ Ch6/mlir/MLIRGen.cpp | 461 +++ Ch6/mlir/ShapeInferencePass.cpp | 122 + Ch6/mlir/ToyCombine.cpp | 69 + Ch6/mlir/ToyCombine.inc | 176 + Ch6/mlir/ToyCombine.td | 63 + Ch6/mlir/run.sh | 2 + Ch6/parser/AST.cpp | 237 ++ Ch6/toyc.cpp | 329 ++ Ch7/CMakeLists.txt | 62 + Ch7/include/CMakeLists.txt | 1 + Ch7/include/run.sh | 7 + Ch7/include/toy/AST.h | 313 ++ Ch7/include/toy/CMakeLists.txt | 13 + Ch7/include/toy/Dialect.cpp.inc | 23 + Ch7/include/toy/Dialect.h | 82 + Ch7/include/toy/Dialect.h.inc | 40 + Ch7/include/toy/Lexer.h | 235 ++ Ch7/include/toy/MLIRGen.h | 35 + Ch7/include/toy/Ops.cpp.inc | 2907 +++++++++++++++++ Ch7/include/toy/Ops.h.inc | 1698 ++++++++++ Ch7/include/toy/Ops.td | 453 +++ Ch7/include/toy/Parser.h | 683 ++++ Ch7/include/toy/Passes.h | 35 + Ch7/include/toy/ShapeInferenceInterface.h | 28 + Ch7/include/toy/ShapeInferenceInterface.td | 30 + .../toy/ShapeInferenceOpInterfaces.cpp.inc | 12 + .../toy/ShapeInferenceOpInterfaces.h.inc | 61 + Ch7/mlir/Dialect.cpp | 665 ++++ Ch7/mlir/LowerToAffineLoops.cpp | 385 +++ Ch7/mlir/LowerToLLVM.cpp | 241 ++ Ch7/mlir/MLIRGen.cpp | 692 ++++ Ch7/mlir/ShapeInferencePass.cpp | 122 + Ch7/mlir/ToyCombine.cpp | 90 + Ch7/mlir/ToyCombine.inc | 176 + Ch7/mlir/ToyCombine.td | 63 + Ch7/mlir/run.sh | 2 + Ch7/parser/AST.cpp | 274 ++ Ch7/toyc.cpp | 330 ++ Examples/Toy/Ch1/ast.toy | 74 + Examples/Toy/Ch1/empty.toy | 3 + Examples/Toy/Ch2/ast.toy | 76 + Examples/Toy/Ch2/codegen.toy | 31 + Examples/Toy/Ch2/empty.toy | 3 + Examples/Toy/Ch2/invalid.mlir | 9 + Examples/Toy/Ch2/scalar.toy | 14 + Examples/Toy/Ch3/ast.toy | 76 + Examples/Toy/Ch3/codegen.toy | 31 + Examples/Toy/Ch3/empty.toy | 3 + Examples/Toy/Ch3/invalid.mlir | 9 + Examples/Toy/Ch3/scalar.toy | 14 + Examples/Toy/Ch3/transpose_transpose.toy | 22 + Examples/Toy/Ch3/trivial_reshape.toy | 16 + Examples/Toy/Ch4/ast.toy | 76 + Examples/Toy/Ch4/codegen.toy | 31 + Examples/Toy/Ch4/empty.toy | 3 + Examples/Toy/Ch4/invalid.mlir | 9 + Examples/Toy/Ch4/scalar.toy | 14 + Examples/Toy/Ch4/shape_inference.mlir | 30 + Examples/Toy/Ch4/transpose_transpose.toy | 17 + Examples/Toy/Ch4/trivial_reshape.toy | 16 + Examples/Toy/Ch5/affine-lowering.mlir | 64 + Examples/Toy/Ch5/ast.toy | 76 + Examples/Toy/Ch5/codegen.toy | 31 + Examples/Toy/Ch5/empty.toy | 3 + Examples/Toy/Ch5/invalid.mlir | 9 + Examples/Toy/Ch5/scalar.toy | 14 + Examples/Toy/Ch5/shape_inference.mlir | 30 + Examples/Toy/Ch5/transpose_transpose.toy | 17 + Examples/Toy/Ch5/trivial_reshape.toy | 16 + Examples/Toy/Ch6/affine-lowering.mlir | 64 + Examples/Toy/Ch6/ast.toy | 76 + Examples/Toy/Ch6/codegen.toy | 31 + Examples/Toy/Ch6/empty.toy | 3 + Examples/Toy/Ch6/invalid.mlir | 9 + Examples/Toy/Ch6/jit.toy | 6 + Examples/Toy/Ch6/lit.local.cfg | 3 + Examples/Toy/Ch6/llvm-lowering.mlir | 23 + Examples/Toy/Ch6/scalar.toy | 14 + Examples/Toy/Ch6/shape_inference.mlir | 30 + Examples/Toy/Ch6/transpose_transpose.toy | 17 + Examples/Toy/Ch6/trivial_reshape.toy | 16 + Examples/Toy/Ch7/affine-lowering.mlir | 64 + Examples/Toy/Ch7/ast.toy | 76 + Examples/Toy/Ch7/codegen.toy | 31 + Examples/Toy/Ch7/empty.toy | 4 + Examples/Toy/Ch7/invalid.mlir | 9 + Examples/Toy/Ch7/jit.toy | 6 + Examples/Toy/Ch7/lit.local.cfg | 3 + Examples/Toy/Ch7/llvm-lowering.mlir | 23 + Examples/Toy/Ch7/scalar.toy | 14 + Examples/Toy/Ch7/shape_inference.mlir | 30 + Examples/Toy/Ch7/struct-ast.toy | 61 + Examples/Toy/Ch7/struct-codegen.toy | 44 + Examples/Toy/Ch7/struct-opt.mlir | 15 + Examples/Toy/Ch7/transpose_transpose.toy | 17 + Examples/Toy/Ch7/trivial_reshape.toy | 16 + README.md | 34 + 224 files changed, 47177 insertions(+) create mode 100644 CMakeLists.txt create mode 100644 Ch1/CMakeLists.txt create mode 100644 Ch1/include/toy/AST.h create mode 100644 Ch1/include/toy/Lexer.h create mode 100644 Ch1/include/toy/Parser.h create mode 100644 Ch1/parser/AST.cpp create mode 100644 Ch1/toyc.cpp create mode 100644 Ch2/CMakeLists.txt create mode 100644 Ch2/include/CMakeLists.txt create mode 100644 Ch2/include/toy/AST.h create mode 100644 Ch2/include/toy/CMakeLists.txt create mode 100644 Ch2/include/toy/Dialect.cpp.inc create mode 100644 Ch2/include/toy/Dialect.h create mode 100644 Ch2/include/toy/Dialect.h.inc create mode 100644 Ch2/include/toy/Lexer.h create mode 100644 Ch2/include/toy/MLIRGen.h create mode 100644 Ch2/include/toy/Ops.cpp.inc create mode 100644 Ch2/include/toy/Ops.h.inc create mode 100644 Ch2/include/toy/Ops.td create mode 100644 Ch2/include/toy/Parser.h create mode 100644 Ch2/include/toy/run.sh create mode 100644 Ch2/mlir/Dialect.cpp create mode 100644 Ch2/mlir/MLIRGen.cpp create mode 100644 Ch2/parser/AST.cpp create mode 100644 Ch2/toyc.cpp create mode 100644 Ch3/CMakeLists.txt create mode 100644 Ch3/include/CMakeLists.txt create mode 100644 Ch3/include/toy/AST.h create mode 100644 Ch3/include/toy/CMakeLists.txt create mode 100644 Ch3/include/toy/Dialect.cpp.inc create mode 100644 Ch3/include/toy/Dialect.h create mode 100644 Ch3/include/toy/Dialect.h.inc create mode 100644 Ch3/include/toy/Lexer.h create mode 100644 Ch3/include/toy/MLIRGen.h create mode 100644 Ch3/include/toy/Ops.cpp.inc create mode 100644 Ch3/include/toy/Ops.h.inc create mode 100644 Ch3/include/toy/Ops.td create mode 100644 Ch3/include/toy/Parser.h create mode 100644 Ch3/include/toy/run.sh create mode 100644 Ch3/mlir/Dialect.cpp create mode 100644 Ch3/mlir/MLIRGen.cpp create mode 100644 Ch3/mlir/ToyCombine.cpp create mode 100644 Ch3/mlir/ToyCombine.inc create mode 100644 Ch3/mlir/ToyCombine.td create mode 100644 Ch3/mlir/run.sh create mode 100644 Ch3/parser/AST.cpp create mode 100644 Ch3/toyc.cpp create mode 100644 Ch4/CMakeLists.txt create mode 100644 Ch4/include/CMakeLists.txt create mode 100644 Ch4/include/run.sh create mode 100644 Ch4/include/toy/AST.h create mode 100644 Ch4/include/toy/CMakeLists.txt create mode 100644 Ch4/include/toy/Dialect.cpp.inc create mode 100644 Ch4/include/toy/Dialect.h create mode 100644 Ch4/include/toy/Dialect.h.inc create mode 100644 Ch4/include/toy/Lexer.h create mode 100644 Ch4/include/toy/MLIRGen.h create mode 100644 Ch4/include/toy/Ops.cpp.inc create mode 100644 Ch4/include/toy/Ops.h.inc create mode 100644 Ch4/include/toy/Ops.td create mode 100644 Ch4/include/toy/Parser.h create mode 100644 Ch4/include/toy/Passes.h create mode 100644 Ch4/include/toy/ShapeInferenceInterface.h create mode 100644 Ch4/include/toy/ShapeInferenceInterface.td create mode 100644 Ch4/include/toy/ShapeInferenceOpInterfaces.cpp.inc create mode 100644 Ch4/include/toy/ShapeInferenceOpInterfaces.h.inc create mode 100644 Ch4/include/toy/run.sh create mode 100644 Ch4/mlir/Dialect.cpp create mode 100644 Ch4/mlir/MLIRGen.cpp create mode 100644 Ch4/mlir/ShapeInferencePass.cpp create mode 100644 Ch4/mlir/ToyCombine.cpp create mode 100644 Ch4/mlir/ToyCombine.inc create mode 100644 Ch4/mlir/ToyCombine.td create mode 100644 Ch4/mlir/run.sh create mode 100644 Ch4/parser/AST.cpp create mode 100644 Ch4/toyc.cpp create mode 100644 Ch5/CMakeLists.txt create mode 100644 Ch5/include/CMakeLists.txt create mode 100644 Ch5/include/run.sh create mode 100644 Ch5/include/toy/AST.h create mode 100644 Ch5/include/toy/CMakeLists.txt create mode 100644 Ch5/include/toy/Dialect.cpp.inc create mode 100644 Ch5/include/toy/Dialect.h create mode 100644 Ch5/include/toy/Dialect.h.inc create mode 100644 Ch5/include/toy/Lexer.h create mode 100644 Ch5/include/toy/MLIRGen.h create mode 100644 Ch5/include/toy/Ops.cpp.inc create mode 100644 Ch5/include/toy/Ops.h.inc create mode 100644 Ch5/include/toy/Ops.td create mode 100644 Ch5/include/toy/Parser.h create mode 100644 Ch5/include/toy/Passes.h create mode 100644 Ch5/include/toy/ShapeInferenceInterface.h create mode 100644 Ch5/include/toy/ShapeInferenceInterface.td create mode 100644 Ch5/include/toy/ShapeInferenceOpInterfaces.cpp.inc create mode 100644 Ch5/include/toy/ShapeInferenceOpInterfaces.h.inc create mode 100644 Ch5/mlir/Dialect.cpp create mode 100644 Ch5/mlir/LowerToAffineLoops.cpp create mode 100644 Ch5/mlir/MLIRGen.cpp create mode 100644 Ch5/mlir/ShapeInferencePass.cpp create mode 100644 Ch5/mlir/ToyCombine.cpp create mode 100644 Ch5/mlir/ToyCombine.inc create mode 100644 Ch5/mlir/ToyCombine.td create mode 100644 Ch5/mlir/run.sh create mode 100644 Ch5/parser/AST.cpp create mode 100644 Ch5/toyc.cpp create mode 100644 Ch6/CMakeLists.txt create mode 100644 Ch6/include/CMakeLists.txt create mode 100644 Ch6/include/run.sh create mode 100644 Ch6/include/toy/AST.h create mode 100644 Ch6/include/toy/CMakeLists.txt create mode 100644 Ch6/include/toy/Dialect.cpp.inc create mode 100644 Ch6/include/toy/Dialect.h create mode 100644 Ch6/include/toy/Dialect.h.inc create mode 100644 Ch6/include/toy/Lexer.h create mode 100644 Ch6/include/toy/MLIRGen.h create mode 100644 Ch6/include/toy/Ops.cpp.inc create mode 100644 Ch6/include/toy/Ops.h.inc create mode 100644 Ch6/include/toy/Ops.td create mode 100644 Ch6/include/toy/Parser.h create mode 100644 Ch6/include/toy/Passes.h create mode 100644 Ch6/include/toy/ShapeInferenceInterface.h create mode 100644 Ch6/include/toy/ShapeInferenceInterface.td create mode 100644 Ch6/include/toy/ShapeInferenceOpInterfaces.cpp.inc create mode 100644 Ch6/include/toy/ShapeInferenceOpInterfaces.h.inc create mode 100644 Ch6/mlir/Dialect.cpp create mode 100644 Ch6/mlir/LowerToAffineLoops.cpp create mode 100644 Ch6/mlir/LowerToLLVM.cpp create mode 100644 Ch6/mlir/MLIRGen.cpp create mode 100644 Ch6/mlir/ShapeInferencePass.cpp create mode 100644 Ch6/mlir/ToyCombine.cpp create mode 100644 Ch6/mlir/ToyCombine.inc create mode 100644 Ch6/mlir/ToyCombine.td create mode 100644 Ch6/mlir/run.sh create mode 100644 Ch6/parser/AST.cpp create mode 100644 Ch6/toyc.cpp create mode 100644 Ch7/CMakeLists.txt create mode 100644 Ch7/include/CMakeLists.txt create mode 100644 Ch7/include/run.sh create mode 100644 Ch7/include/toy/AST.h create mode 100644 Ch7/include/toy/CMakeLists.txt create mode 100644 Ch7/include/toy/Dialect.cpp.inc create mode 100644 Ch7/include/toy/Dialect.h create mode 100644 Ch7/include/toy/Dialect.h.inc create mode 100644 Ch7/include/toy/Lexer.h create mode 100644 Ch7/include/toy/MLIRGen.h create mode 100644 Ch7/include/toy/Ops.cpp.inc create mode 100644 Ch7/include/toy/Ops.h.inc create mode 100644 Ch7/include/toy/Ops.td create mode 100644 Ch7/include/toy/Parser.h create mode 100644 Ch7/include/toy/Passes.h create mode 100644 Ch7/include/toy/ShapeInferenceInterface.h create mode 100644 Ch7/include/toy/ShapeInferenceInterface.td create mode 100644 Ch7/include/toy/ShapeInferenceOpInterfaces.cpp.inc create mode 100644 Ch7/include/toy/ShapeInferenceOpInterfaces.h.inc create mode 100644 Ch7/mlir/Dialect.cpp create mode 100644 Ch7/mlir/LowerToAffineLoops.cpp create mode 100644 Ch7/mlir/LowerToLLVM.cpp create mode 100644 Ch7/mlir/MLIRGen.cpp create mode 100644 Ch7/mlir/ShapeInferencePass.cpp create mode 100644 Ch7/mlir/ToyCombine.cpp create mode 100644 Ch7/mlir/ToyCombine.inc create mode 100644 Ch7/mlir/ToyCombine.td create mode 100644 Ch7/mlir/run.sh create mode 100644 Ch7/parser/AST.cpp create mode 100644 Ch7/toyc.cpp create mode 100644 Examples/Toy/Ch1/ast.toy create mode 100644 Examples/Toy/Ch1/empty.toy create mode 100644 Examples/Toy/Ch2/ast.toy create mode 100644 Examples/Toy/Ch2/codegen.toy create mode 100644 Examples/Toy/Ch2/empty.toy create mode 100644 Examples/Toy/Ch2/invalid.mlir create mode 100644 Examples/Toy/Ch2/scalar.toy create mode 100644 Examples/Toy/Ch3/ast.toy create mode 100644 Examples/Toy/Ch3/codegen.toy create mode 100644 Examples/Toy/Ch3/empty.toy create mode 100644 Examples/Toy/Ch3/invalid.mlir create mode 100644 Examples/Toy/Ch3/scalar.toy create mode 100644 Examples/Toy/Ch3/transpose_transpose.toy create mode 100644 Examples/Toy/Ch3/trivial_reshape.toy create mode 100644 Examples/Toy/Ch4/ast.toy create mode 100644 Examples/Toy/Ch4/codegen.toy create mode 100644 Examples/Toy/Ch4/empty.toy create mode 100644 Examples/Toy/Ch4/invalid.mlir create mode 100644 Examples/Toy/Ch4/scalar.toy create mode 100644 Examples/Toy/Ch4/shape_inference.mlir create mode 100644 Examples/Toy/Ch4/transpose_transpose.toy create mode 100644 Examples/Toy/Ch4/trivial_reshape.toy create mode 100644 Examples/Toy/Ch5/affine-lowering.mlir create mode 100644 Examples/Toy/Ch5/ast.toy create mode 100644 Examples/Toy/Ch5/codegen.toy create mode 100644 Examples/Toy/Ch5/empty.toy create mode 100644 Examples/Toy/Ch5/invalid.mlir create mode 100644 Examples/Toy/Ch5/scalar.toy create mode 100644 Examples/Toy/Ch5/shape_inference.mlir create mode 100644 Examples/Toy/Ch5/transpose_transpose.toy create mode 100644 Examples/Toy/Ch5/trivial_reshape.toy create mode 100644 Examples/Toy/Ch6/affine-lowering.mlir create mode 100644 Examples/Toy/Ch6/ast.toy create mode 100644 Examples/Toy/Ch6/codegen.toy create mode 100644 Examples/Toy/Ch6/empty.toy create mode 100644 Examples/Toy/Ch6/invalid.mlir create mode 100644 Examples/Toy/Ch6/jit.toy create mode 100644 Examples/Toy/Ch6/lit.local.cfg create mode 100644 Examples/Toy/Ch6/llvm-lowering.mlir create mode 100644 Examples/Toy/Ch6/scalar.toy create mode 100644 Examples/Toy/Ch6/shape_inference.mlir create mode 100644 Examples/Toy/Ch6/transpose_transpose.toy create mode 100644 Examples/Toy/Ch6/trivial_reshape.toy create mode 100644 Examples/Toy/Ch7/affine-lowering.mlir create mode 100644 Examples/Toy/Ch7/ast.toy create mode 100644 Examples/Toy/Ch7/codegen.toy create mode 100644 Examples/Toy/Ch7/empty.toy create mode 100644 Examples/Toy/Ch7/invalid.mlir create mode 100644 Examples/Toy/Ch7/jit.toy create mode 100644 Examples/Toy/Ch7/lit.local.cfg create mode 100644 Examples/Toy/Ch7/llvm-lowering.mlir create mode 100644 Examples/Toy/Ch7/scalar.toy create mode 100644 Examples/Toy/Ch7/shape_inference.mlir create mode 100644 Examples/Toy/Ch7/struct-ast.toy create mode 100644 Examples/Toy/Ch7/struct-codegen.toy create mode 100644 Examples/Toy/Ch7/struct-opt.mlir create mode 100644 Examples/Toy/Ch7/transpose_transpose.toy create mode 100644 Examples/Toy/Ch7/trivial_reshape.toy create mode 100644 README.md diff --git a/CMakeLists.txt b/CMakeLists.txt new file mode 100644 index 0000000..6a44100 --- /dev/null +++ b/CMakeLists.txt @@ -0,0 +1,22 @@ +cmake_minimum_required(VERSION 3.10.0) +project(toy VERSION 0.1.0 LANGUAGES CXX C) + +set(LLVM_DIR /usr/lib/llvm-18/lib/cmake/llvm/) +find_package(LLVM REQUIRED CONFIG) +set(MLIR_DIR /usr/lib/llvm-18/lib/cmake/mlir/) +find_package(MLIR REQUIRED CONFIG) + +include_directories(${LLVM_INCLUDE_DIRS}) +include_directories(${MLIR_INCLUDE_DIRS}) + +set(CMAKE_CXX_STANDARD 17) +set(CMAKE_CXX_STANDARD_REQUIRED ON) +set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O3") + +add_subdirectory(Ch1) +add_subdirectory(Ch2) +add_subdirectory(Ch3) +add_subdirectory(Ch4) +add_subdirectory(Ch5) +add_subdirectory(Ch6) +add_subdirectory(Ch7) diff --git a/Ch1/CMakeLists.txt b/Ch1/CMakeLists.txt new file mode 100644 index 0000000..2e8ba91 --- /dev/null +++ b/Ch1/CMakeLists.txt @@ -0,0 +1,13 @@ +# For a better template to copy, see examples/standalone +set(LLVM_LINK_COMPONENTS + Support + ) + +add_executable(toyc-ch1 + toyc.cpp + parser/AST.cpp + ) +include_directories(include/) +target_link_libraries(toyc-ch1 + PRIVATE + MLIRSupport) diff --git a/Ch1/include/toy/AST.h b/Ch1/include/toy/AST.h new file mode 100644 index 0000000..d2ba101 --- /dev/null +++ b/Ch1/include/toy/AST.h @@ -0,0 +1,246 @@ +//===- AST.h - Node definition for the Toy AST ----------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements the AST for the Toy language. It is optimized for +// simplicity, not efficiency. The AST forms a tree structure where each node +// references its children using std::unique_ptr<>. +// +//===----------------------------------------------------------------------===// + +#ifndef TOY_AST_H +#define TOY_AST_H + +#include "toy/Lexer.h" + +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/Casting.h" +#include +#include +#include + +namespace toy { + +/// A variable type with shape information. +struct VarType { + std::vector shape; +}; + +/// Base class for all expression nodes. +class ExprAST { +public: + enum ExprASTKind { + Expr_VarDecl, + Expr_Return, + Expr_Num, + Expr_Literal, + Expr_Var, + Expr_BinOp, + Expr_Call, + Expr_Print, + }; + + ExprAST(ExprASTKind kind, Location location) + : kind(kind), location(std::move(location)) {} + virtual ~ExprAST() = default; + + ExprASTKind getKind() const { return kind; } + + const Location &loc() { return location; } + +private: + const ExprASTKind kind; + Location location; +}; + +/// A block-list of expressions. +using ExprASTList = std::vector>; + +/// Expression class for numeric literals like "1.0". +class NumberExprAST : public ExprAST { + double val; + +public: + NumberExprAST(Location loc, double val) + : ExprAST(Expr_Num, std::move(loc)), val(val) {} + + double getValue() { return val; } + + /// LLVM style RTTI + static bool classof(const ExprAST *c) { return c->getKind() == Expr_Num; } +}; + +/// Expression class for a literal value. +class LiteralExprAST : public ExprAST { + std::vector> values; + std::vector dims; + +public: + LiteralExprAST(Location loc, std::vector> values, + std::vector dims) + : ExprAST(Expr_Literal, std::move(loc)), values(std::move(values)), + dims(std::move(dims)) {} + + llvm::ArrayRef> getValues() { return values; } + llvm::ArrayRef getDims() { return dims; } + + /// LLVM style RTTI + static bool classof(const ExprAST *c) { return c->getKind() == Expr_Literal; } +}; + +/// Expression class for referencing a variable, like "a". +class VariableExprAST : public ExprAST { + std::string name; + +public: + VariableExprAST(Location loc, llvm::StringRef name) + : ExprAST(Expr_Var, std::move(loc)), name(name) {} + + llvm::StringRef getName() { return name; } + + /// LLVM style RTTI + static bool classof(const ExprAST *c) { return c->getKind() == Expr_Var; } +}; + +/// Expression class for defining a variable. +class VarDeclExprAST : public ExprAST { + std::string name; + VarType type; + std::unique_ptr initVal; + +public: + VarDeclExprAST(Location loc, llvm::StringRef name, VarType type, + std::unique_ptr initVal) + : ExprAST(Expr_VarDecl, std::move(loc)), name(name), + type(std::move(type)), initVal(std::move(initVal)) {} + + llvm::StringRef getName() { return name; } + ExprAST *getInitVal() { return initVal.get(); } + const VarType &getType() { return type; } + + /// LLVM style RTTI + static bool classof(const ExprAST *c) { return c->getKind() == Expr_VarDecl; } +}; + +/// Expression class for a return operator. +class ReturnExprAST : public ExprAST { + std::optional> expr; + +public: + ReturnExprAST(Location loc, std::optional> expr) + : ExprAST(Expr_Return, std::move(loc)), expr(std::move(expr)) {} + + std::optional getExpr() { + if (expr.has_value()) + return expr->get(); + return std::nullopt; + } + + /// LLVM style RTTI + static bool classof(const ExprAST *c) { return c->getKind() == Expr_Return; } +}; + +/// Expression class for a binary operator. +class BinaryExprAST : public ExprAST { + char op; + std::unique_ptr lhs, rhs; + +public: + char getOp() { return op; } + ExprAST *getLHS() { return lhs.get(); } + ExprAST *getRHS() { return rhs.get(); } + + BinaryExprAST(Location loc, char op, std::unique_ptr lhs, + std::unique_ptr rhs) + : ExprAST(Expr_BinOp, std::move(loc)), op(op), lhs(std::move(lhs)), + rhs(std::move(rhs)) {} + + /// LLVM style RTTI + static bool classof(const ExprAST *c) { return c->getKind() == Expr_BinOp; } +}; + +/// Expression class for function calls. +class CallExprAST : public ExprAST { + std::string callee; + std::vector> args; + +public: + CallExprAST(Location loc, const std::string &callee, + std::vector> args) + : ExprAST(Expr_Call, std::move(loc)), callee(callee), + args(std::move(args)) {} + + llvm::StringRef getCallee() { return callee; } + llvm::ArrayRef> getArgs() { return args; } + + /// LLVM style RTTI + static bool classof(const ExprAST *c) { return c->getKind() == Expr_Call; } +}; + +/// Expression class for builtin print calls. +class PrintExprAST : public ExprAST { + std::unique_ptr arg; + +public: + PrintExprAST(Location loc, std::unique_ptr arg) + : ExprAST(Expr_Print, std::move(loc)), arg(std::move(arg)) {} + + ExprAST *getArg() { return arg.get(); } + + /// LLVM style RTTI + static bool classof(const ExprAST *c) { return c->getKind() == Expr_Print; } +}; + +/// This class represents the "prototype" for a function, which captures its +/// name, and its argument names (thus implicitly the number of arguments the +/// function takes). +class PrototypeAST { + Location location; + std::string name; + std::vector> args; + +public: + PrototypeAST(Location location, const std::string &name, + std::vector> args) + : location(std::move(location)), name(name), args(std::move(args)) {} + + const Location &loc() { return location; } + llvm::StringRef getName() const { return name; } + llvm::ArrayRef> getArgs() { return args; } +}; + +/// This class represents a function definition itself. +class FunctionAST { + std::unique_ptr proto; + std::unique_ptr body; + +public: + FunctionAST(std::unique_ptr proto, + std::unique_ptr body) + : proto(std::move(proto)), body(std::move(body)) {} + PrototypeAST *getProto() { return proto.get(); } + ExprASTList *getBody() { return body.get(); } +}; + +/// This class represents a list of functions to be processed together +class ModuleAST { + std::vector functions; + +public: + ModuleAST(std::vector functions) + : functions(std::move(functions)) {} + + auto begin() { return functions.begin(); } + auto end() { return functions.end(); } +}; + +void dump(ModuleAST &); + +} // namespace toy + +#endif // TOY_AST_H diff --git a/Ch1/include/toy/Lexer.h b/Ch1/include/toy/Lexer.h new file mode 100644 index 0000000..ecbb3b4 --- /dev/null +++ b/Ch1/include/toy/Lexer.h @@ -0,0 +1,232 @@ +//===- Lexer.h - Lexer for the Toy language -------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements a simple Lexer for the Toy language. +// +//===----------------------------------------------------------------------===// + +#ifndef TOY_LEXER_H +#define TOY_LEXER_H + +#include "llvm/ADT/StringRef.h" + +#include +#include + +namespace toy { + +/// Structure definition a location in a file. +struct Location { + std::shared_ptr file; ///< filename. + int line; ///< line number. + int col; ///< column number. +}; + +// List of Token returned by the lexer. +enum Token : int { + tok_semicolon = ';', + tok_parenthese_open = '(', + tok_parenthese_close = ')', + tok_bracket_open = '{', + tok_bracket_close = '}', + tok_sbracket_open = '[', + tok_sbracket_close = ']', + + tok_eof = -1, + + // commands + tok_return = -2, + tok_var = -3, + tok_def = -4, + + // primary + tok_identifier = -5, + tok_number = -6, +}; + +/// The Lexer is an abstract base class providing all the facilities that the +/// Parser expects. It goes through the stream one token at a time and keeps +/// track of the location in the file for debugging purposes. +/// It relies on a subclass to provide a `readNextLine()` method. The subclass +/// can proceed by reading the next line from the standard input or from a +/// memory mapped file. +class Lexer { +public: + /// Create a lexer for the given filename. The filename is kept only for + /// debugging purposes (attaching a location to a Token). + Lexer(std::string filename) + : lastLocation( + {std::make_shared(std::move(filename)), 0, 0}) {} + virtual ~Lexer() = default; + + /// Look at the current token in the stream. + Token getCurToken() { return curTok; } + + /// Move to the next token in the stream and return it. + Token getNextToken() { return curTok = getTok(); } + + /// Move to the next token in the stream, asserting on the current token + /// matching the expectation. + void consume(Token tok) { + assert(tok == curTok && "consume Token mismatch expectation"); + getNextToken(); + } + + /// Return the current identifier (prereq: getCurToken() == tok_identifier) + llvm::StringRef getId() { + assert(curTok == tok_identifier); + return identifierStr; + } + + /// Return the current number (prereq: getCurToken() == tok_number) + double getValue() { + assert(curTok == tok_number); + return numVal; + } + + /// Return the location for the beginning of the current token. + Location getLastLocation() { return lastLocation; } + + // Return the current line in the file. + int getLine() { return curLineNum; } + + // Return the current column in the file. + int getCol() { return curCol; } + +private: + /// Delegate to a derived class fetching the next line. Returns an empty + /// string to signal end of file (EOF). Lines are expected to always finish + /// with "\n" + virtual llvm::StringRef readNextLine() = 0; + + /// Return the next character from the stream. This manages the buffer for the + /// current line and request the next line buffer to the derived class as + /// needed. + int getNextChar() { + // The current line buffer should not be empty unless it is the end of file. + if (curLineBuffer.empty()) + return EOF; + ++curCol; + auto nextchar = curLineBuffer.front(); + curLineBuffer = curLineBuffer.drop_front(); + if (curLineBuffer.empty()) + curLineBuffer = readNextLine(); + if (nextchar == '\n') { + ++curLineNum; + curCol = 0; + } + return nextchar; + } + + /// Return the next token from standard input. + Token getTok() { + // Skip any whitespace. + while (isspace(lastChar)) + lastChar = Token(getNextChar()); + + // Save the current location before reading the token characters. + lastLocation.line = curLineNum; + lastLocation.col = curCol; + + // Identifier: [a-zA-Z][a-zA-Z0-9_]* + if (isalpha(lastChar)) { + identifierStr = (char)lastChar; + while (isalnum((lastChar = Token(getNextChar()))) || lastChar == '_') + identifierStr += (char)lastChar; + + if (identifierStr == "return") + return tok_return; + if (identifierStr == "def") + return tok_def; + if (identifierStr == "var") + return tok_var; + return tok_identifier; + } + + // Number: [0-9.]+ + if (isdigit(lastChar) || lastChar == '.') { + std::string numStr; + do { + numStr += lastChar; + lastChar = Token(getNextChar()); + } while (isdigit(lastChar) || lastChar == '.'); + + numVal = strtod(numStr.c_str(), nullptr); + return tok_number; + } + + if (lastChar == '#') { + // Comment until end of line. + do { + lastChar = Token(getNextChar()); + } while (lastChar != EOF && lastChar != '\n' && lastChar != '\r'); + + if (lastChar != EOF) + return getTok(); + } + + // Check for end of file. Don't eat the EOF. + if (lastChar == EOF) + return tok_eof; + + // Otherwise, just return the character as its ascii value. + Token thisChar = Token(lastChar); + lastChar = Token(getNextChar()); + return thisChar; + } + + /// The last token read from the input. + Token curTok = tok_eof; + + /// Location for `curTok`. + Location lastLocation; + + /// If the current Token is an identifier, this string contains the value. + std::string identifierStr; + + /// If the current Token is a number, this contains the value. + double numVal = 0; + + /// The last value returned by getNextChar(). We need to keep it around as we + /// always need to read ahead one character to decide when to end a token and + /// we can't put it back in the stream after reading from it. + Token lastChar = Token(' '); + + /// Keep track of the current line number in the input stream + int curLineNum = 0; + + /// Keep track of the current column number in the input stream + int curCol = 0; + + /// Buffer supplied by the derived class on calls to `readNextLine()` + llvm::StringRef curLineBuffer = "\n"; +}; + +/// A lexer implementation operating on a buffer in memory. +class LexerBuffer final : public Lexer { +public: + LexerBuffer(const char *begin, const char *end, std::string filename) + : Lexer(std::move(filename)), current(begin), end(end) {} + +private: + /// Provide one line at a time to the Lexer, return an empty string when + /// reaching the end of the buffer. + llvm::StringRef readNextLine() override { + auto *begin = current; + while (current <= end && *current && *current != '\n') + ++current; + if (current <= end && *current) + ++current; + llvm::StringRef result{begin, static_cast(current - begin)}; + return result; + } + const char *current, *end; +}; +} // namespace toy + +#endif // TOY_LEXER_H diff --git a/Ch1/include/toy/Parser.h b/Ch1/include/toy/Parser.h new file mode 100644 index 0000000..1f20616 --- /dev/null +++ b/Ch1/include/toy/Parser.h @@ -0,0 +1,489 @@ +//===- Parser.h - Toy Language Parser -------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements the parser for the Toy language. It processes the Token +// provided by the Lexer and returns an AST. +// +//===----------------------------------------------------------------------===// + +#ifndef TOY_PARSER_H +#define TOY_PARSER_H + +#include "toy/AST.h" +#include "toy/Lexer.h" + +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/StringExtras.h" +#include "llvm/Support/raw_ostream.h" + +#include +#include +#include +#include + +namespace toy { + +/// This is a simple recursive parser for the Toy language. It produces a well +/// formed AST from a stream of Token supplied by the Lexer. No semantic checks +/// or symbol resolution is performed. For example, variables are referenced by +/// string and the code could reference an undeclared variable and the parsing +/// succeeds. +class Parser { +public: + /// Create a Parser for the supplied lexer. + Parser(Lexer &lexer) : lexer(lexer) {} + + /// Parse a full Module. A module is a list of function definitions. + std::unique_ptr parseModule() { + lexer.getNextToken(); // prime the lexer + + // Parse functions one at a time and accumulate in this vector. + std::vector functions; + while (auto f = parseDefinition()) { + functions.push_back(std::move(*f)); + if (lexer.getCurToken() == tok_eof) + break; + } + // If we didn't reach EOF, there was an error during parsing + if (lexer.getCurToken() != tok_eof) + return parseError("nothing", "at end of module"); + + return std::make_unique(std::move(functions)); + } + +private: + Lexer &lexer; + + /// Parse a return statement. + /// return :== return ; | return expr ; + std::unique_ptr parseReturn() { + auto loc = lexer.getLastLocation(); + lexer.consume(tok_return); + + // return takes an optional argument + std::optional> expr; + if (lexer.getCurToken() != ';') { + expr = parseExpression(); + if (!expr) + return nullptr; + } + return std::make_unique(std::move(loc), std::move(expr)); + } + + /// Parse a literal number. + /// numberexpr ::= number + std::unique_ptr parseNumberExpr() { + auto loc = lexer.getLastLocation(); + auto result = + std::make_unique(std::move(loc), lexer.getValue()); + lexer.consume(tok_number); + return std::move(result); + } + + /// Parse a literal array expression. + /// tensorLiteral ::= [ literalList ] | number + /// literalList ::= tensorLiteral | tensorLiteral, literalList + std::unique_ptr parseTensorLiteralExpr() { + auto loc = lexer.getLastLocation(); + lexer.consume(Token('[')); + + // Hold the list of values at this nesting level. + std::vector> values; + // Hold the dimensions for all the nesting inside this level. + std::vector dims; + do { + // We can have either another nested array or a number literal. + if (lexer.getCurToken() == '[') { + values.push_back(parseTensorLiteralExpr()); + if (!values.back()) + return nullptr; // parse error in the nested array. + } else { + if (lexer.getCurToken() != tok_number) + return parseError(" or [", "in literal expression"); + values.push_back(parseNumberExpr()); + } + + // End of this list on ']' + if (lexer.getCurToken() == ']') + break; + + // Elements are separated by a comma. + if (lexer.getCurToken() != ',') + return parseError("] or ,", "in literal expression"); + + lexer.getNextToken(); // eat , + } while (true); + if (values.empty()) + return parseError("", "to fill literal expression"); + lexer.getNextToken(); // eat ] + + /// Fill in the dimensions now. First the current nesting level: + dims.push_back(values.size()); + + /// If there is any nested array, process all of them and ensure that + /// dimensions are uniform. + if (llvm::any_of(values, [](std::unique_ptr &expr) { + return llvm::isa(expr.get()); + })) { + auto *firstLiteral = llvm::dyn_cast(values.front().get()); + if (!firstLiteral) + return parseError("uniform well-nested dimensions", + "inside literal expression"); + + // Append the nested dimensions to the current level + auto firstDims = firstLiteral->getDims(); + dims.insert(dims.end(), firstDims.begin(), firstDims.end()); + + // Sanity check that shape is uniform across all elements of the list. + for (auto &expr : values) { + auto *exprLiteral = llvm::cast(expr.get()); + if (!exprLiteral) + return parseError("uniform well-nested dimensions", + "inside literal expression"); + if (exprLiteral->getDims() != firstDims) + return parseError("uniform well-nested dimensions", + "inside literal expression"); + } + } + return std::make_unique(std::move(loc), std::move(values), + std::move(dims)); + } + + /// parenexpr ::= '(' expression ')' + std::unique_ptr parseParenExpr() { + lexer.getNextToken(); // eat (. + auto v = parseExpression(); + if (!v) + return nullptr; + + if (lexer.getCurToken() != ')') + return parseError(")", "to close expression with parentheses"); + lexer.consume(Token(')')); + return v; + } + + /// identifierexpr + /// ::= identifier + /// ::= identifier '(' expression ')' + std::unique_ptr parseIdentifierExpr() { + std::string name(lexer.getId()); + + auto loc = lexer.getLastLocation(); + lexer.getNextToken(); // eat identifier. + + if (lexer.getCurToken() != '(') // Simple variable ref. + return std::make_unique(std::move(loc), name); + + // This is a function call. + lexer.consume(Token('(')); + std::vector> args; + if (lexer.getCurToken() != ')') { + while (true) { + if (auto arg = parseExpression()) + args.push_back(std::move(arg)); + else + return nullptr; + + if (lexer.getCurToken() == ')') + break; + + if (lexer.getCurToken() != ',') + return parseError(", or )", "in argument list"); + lexer.getNextToken(); + } + } + lexer.consume(Token(')')); + + // It can be a builtin call to print + if (name == "print") { + if (args.size() != 1) + return parseError("", "as argument to print()"); + + return std::make_unique(std::move(loc), std::move(args[0])); + } + + // Call to a user-defined function + return std::make_unique(std::move(loc), name, std::move(args)); + } + + /// primary + /// ::= identifierexpr + /// ::= numberexpr + /// ::= parenexpr + /// ::= tensorliteral + std::unique_ptr parsePrimary() { + switch (lexer.getCurToken()) { + default: + llvm::errs() << "unknown token '" << lexer.getCurToken() + << "' when expecting an expression\n"; + return nullptr; + case tok_identifier: + return parseIdentifierExpr(); + case tok_number: + return parseNumberExpr(); + case '(': + return parseParenExpr(); + case '[': + return parseTensorLiteralExpr(); + case ';': + return nullptr; + case '}': + return nullptr; + } + } + + /// Recursively parse the right hand side of a binary expression, the ExprPrec + /// argument indicates the precedence of the current binary operator. + /// + /// binoprhs ::= ('+' primary)* + std::unique_ptr parseBinOpRHS(int exprPrec, + std::unique_ptr lhs) { + // If this is a binop, find its precedence. + while (true) { + int tokPrec = getTokPrecedence(); + + // If this is a binop that binds at least as tightly as the current binop, + // consume it, otherwise we are done. + if (tokPrec < exprPrec) + return lhs; + + // Okay, we know this is a binop. + int binOp = lexer.getCurToken(); + lexer.consume(Token(binOp)); + auto loc = lexer.getLastLocation(); + + // Parse the primary expression after the binary operator. + auto rhs = parsePrimary(); + if (!rhs) + return parseError("expression", "to complete binary operator"); + + // If BinOp binds less tightly with rhs than the operator after rhs, let + // the pending operator take rhs as its lhs. + int nextPrec = getTokPrecedence(); + if (tokPrec < nextPrec) { + rhs = parseBinOpRHS(tokPrec + 1, std::move(rhs)); + if (!rhs) + return nullptr; + } + + // Merge lhs/RHS. + lhs = std::make_unique(std::move(loc), binOp, + std::move(lhs), std::move(rhs)); + } + } + + /// expression::= primary binop rhs + std::unique_ptr parseExpression() { + auto lhs = parsePrimary(); + if (!lhs) + return nullptr; + + return parseBinOpRHS(0, std::move(lhs)); + } + + /// type ::= < shape_list > + /// shape_list ::= num | num , shape_list + std::unique_ptr parseType() { + if (lexer.getCurToken() != '<') + return parseError("<", "to begin type"); + lexer.getNextToken(); // eat < + + auto type = std::make_unique(); + + while (lexer.getCurToken() == tok_number) { + type->shape.push_back(lexer.getValue()); + lexer.getNextToken(); + if (lexer.getCurToken() == ',') + lexer.getNextToken(); + } + + if (lexer.getCurToken() != '>') + return parseError(">", "to end type"); + lexer.getNextToken(); // eat > + return type; + } + + /// Parse a variable declaration, it starts with a `var` keyword followed by + /// and identifier and an optional type (shape specification) before the + /// initializer. + /// decl ::= var identifier [ type ] = expr + std::unique_ptr parseDeclaration() { + if (lexer.getCurToken() != tok_var) + return parseError("var", "to begin declaration"); + auto loc = lexer.getLastLocation(); + lexer.getNextToken(); // eat var + + if (lexer.getCurToken() != tok_identifier) + return parseError("identified", + "after 'var' declaration"); + std::string id(lexer.getId()); + lexer.getNextToken(); // eat id + + std::unique_ptr type; // Type is optional, it can be inferred + if (lexer.getCurToken() == '<') { + type = parseType(); + if (!type) + return nullptr; + } + + if (!type) + type = std::make_unique(); + lexer.consume(Token('=')); + auto expr = parseExpression(); + return std::make_unique(std::move(loc), std::move(id), + std::move(*type), std::move(expr)); + } + + /// Parse a block: a list of expression separated by semicolons and wrapped in + /// curly braces. + /// + /// block ::= { expression_list } + /// expression_list ::= block_expr ; expression_list + /// block_expr ::= decl | "return" | expr + std::unique_ptr parseBlock() { + if (lexer.getCurToken() != '{') + return parseError("{", "to begin block"); + lexer.consume(Token('{')); + + auto exprList = std::make_unique(); + + // Ignore empty expressions: swallow sequences of semicolons. + while (lexer.getCurToken() == ';') + lexer.consume(Token(';')); + + while (lexer.getCurToken() != '}' && lexer.getCurToken() != tok_eof) { + if (lexer.getCurToken() == tok_var) { + // Variable declaration + auto varDecl = parseDeclaration(); + if (!varDecl) + return nullptr; + exprList->push_back(std::move(varDecl)); + } else if (lexer.getCurToken() == tok_return) { + // Return statement + auto ret = parseReturn(); + if (!ret) + return nullptr; + exprList->push_back(std::move(ret)); + } else { + // General expression + auto expr = parseExpression(); + if (!expr) + return nullptr; + exprList->push_back(std::move(expr)); + } + // Ensure that elements are separated by a semicolon. + if (lexer.getCurToken() != ';') + return parseError(";", "after expression"); + + // Ignore empty expressions: swallow sequences of semicolons. + while (lexer.getCurToken() == ';') + lexer.consume(Token(';')); + } + + if (lexer.getCurToken() != '}') + return parseError("}", "to close block"); + + lexer.consume(Token('}')); + return exprList; + } + + /// prototype ::= def id '(' decl_list ')' + /// decl_list ::= identifier | identifier, decl_list + std::unique_ptr parsePrototype() { + auto loc = lexer.getLastLocation(); + + if (lexer.getCurToken() != tok_def) + return parseError("def", "in prototype"); + lexer.consume(tok_def); + + if (lexer.getCurToken() != tok_identifier) + return parseError("function name", "in prototype"); + + std::string fnName(lexer.getId()); + lexer.consume(tok_identifier); + + if (lexer.getCurToken() != '(') + return parseError("(", "in prototype"); + lexer.consume(Token('(')); + + std::vector> args; + if (lexer.getCurToken() != ')') { + do { + std::string name(lexer.getId()); + auto loc = lexer.getLastLocation(); + lexer.consume(tok_identifier); + auto decl = std::make_unique(std::move(loc), name); + args.push_back(std::move(decl)); + if (lexer.getCurToken() != ',') + break; + lexer.consume(Token(',')); + if (lexer.getCurToken() != tok_identifier) + return parseError( + "identifier", "after ',' in function parameter list"); + } while (true); + } + if (lexer.getCurToken() != ')') + return parseError(")", "to end function prototype"); + + // success. + lexer.consume(Token(')')); + return std::make_unique(std::move(loc), fnName, + std::move(args)); + } + + /// Parse a function definition, we expect a prototype initiated with the + /// `def` keyword, followed by a block containing a list of expressions. + /// + /// definition ::= prototype block + std::unique_ptr parseDefinition() { + auto proto = parsePrototype(); + if (!proto) + return nullptr; + + if (auto block = parseBlock()) + return std::make_unique(std::move(proto), std::move(block)); + return nullptr; + } + + /// Get the precedence of the pending binary operator token. + int getTokPrecedence() { + if (!isascii(lexer.getCurToken())) + return -1; + + // 1 is lowest precedence. + switch (static_cast(lexer.getCurToken())) { + case '-': + return 20; + case '+': + return 20; + case '*': + return 40; + default: + return -1; + } + } + + /// Helper function to signal errors while parsing, it takes an argument + /// indicating the expected token and another argument giving more context. + /// Location is retrieved from the lexer to enrich the error message. + template + std::unique_ptr parseError(T &&expected, U &&context = "") { + auto curToken = lexer.getCurToken(); + llvm::errs() << "Parse error (" << lexer.getLastLocation().line << ", " + << lexer.getLastLocation().col << "): expected '" << expected + << "' " << context << " but has Token " << curToken; + if (isprint(curToken)) + llvm::errs() << " '" << (char)curToken << "'"; + llvm::errs() << "\n"; + return nullptr; + } +}; + +} // namespace toy + +#endif // TOY_PARSER_H diff --git a/Ch1/parser/AST.cpp b/Ch1/parser/AST.cpp new file mode 100644 index 0000000..2546f2a --- /dev/null +++ b/Ch1/parser/AST.cpp @@ -0,0 +1,237 @@ +//===- AST.cpp - Helper for printing out the Toy AST ----------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements the AST dump for the Toy language. +// +//===----------------------------------------------------------------------===// + +#include "toy/AST.h" + +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/Twine.h" +#include "llvm/ADT/TypeSwitch.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/raw_ostream.h" +#include + +using namespace toy; + +namespace { + +// RAII helper to manage increasing/decreasing the indentation as we traverse +// the AST +struct Indent { + Indent(int &level) : level(level) { ++level; } + ~Indent() { --level; } + int &level; +}; + +/// Helper class that implement the AST tree traversal and print the nodes along +/// the way. The only data member is the current indentation level. +class ASTDumper { +public: + void dump(ModuleAST *node); + +private: + void dump(const VarType &type); + void dump(VarDeclExprAST *varDecl); + void dump(ExprAST *expr); + void dump(ExprASTList *exprList); + void dump(NumberExprAST *num); + void dump(LiteralExprAST *node); + void dump(VariableExprAST *node); + void dump(ReturnExprAST *node); + void dump(BinaryExprAST *node); + void dump(CallExprAST *node); + void dump(PrintExprAST *node); + void dump(PrototypeAST *node); + void dump(FunctionAST *node); + + // Actually print spaces matching the current indentation level + void indent() { + for (int i = 0; i < curIndent; i++) + llvm::errs() << " "; + } + int curIndent = 0; +}; + +} // namespace + +/// Return a formatted string for the location of any node +template +static std::string loc(T *node) { + const auto &loc = node->loc(); + return (llvm::Twine("@") + *loc.file + ":" + llvm::Twine(loc.line) + ":" + + llvm::Twine(loc.col)) + .str(); +} + +// Helper Macro to bump the indentation level and print the leading spaces for +// the current indentations +#define INDENT() \ + Indent level_(curIndent); \ + indent(); + +/// Dispatch to a generic expressions to the appropriate subclass using RTTI +void ASTDumper::dump(ExprAST *expr) { + llvm::TypeSwitch(expr) + .Case( + [&](auto *node) { this->dump(node); }) + .Default([&](ExprAST *) { + // No match, fallback to a generic message + INDENT(); + llvm::errs() << "getKind() << ">\n"; + }); +} + +/// A variable declaration is printing the variable name, the type, and then +/// recurse in the initializer value. +void ASTDumper::dump(VarDeclExprAST *varDecl) { + INDENT(); + llvm::errs() << "VarDecl " << varDecl->getName(); + dump(varDecl->getType()); + llvm::errs() << " " << loc(varDecl) << "\n"; + dump(varDecl->getInitVal()); +} + +/// A "block", or a list of expression +void ASTDumper::dump(ExprASTList *exprList) { + INDENT(); + llvm::errs() << "Block {\n"; + for (auto &expr : *exprList) + dump(expr.get()); + indent(); + llvm::errs() << "} // Block\n"; +} + +/// A literal number, just print the value. +void ASTDumper::dump(NumberExprAST *num) { + INDENT(); + llvm::errs() << num->getValue() << " " << loc(num) << "\n"; +} + +/// Helper to print recursively a literal. This handles nested array like: +/// [ [ 1, 2 ], [ 3, 4 ] ] +/// We print out such array with the dimensions spelled out at every level: +/// <2,2>[<2>[ 1, 2 ], <2>[ 3, 4 ] ] +void printLitHelper(ExprAST *litOrNum) { + // Inside a literal expression we can have either a number or another literal + if (auto *num = llvm::dyn_cast(litOrNum)) { + llvm::errs() << num->getValue(); + return; + } + auto *literal = llvm::cast(litOrNum); + + // Print the dimension for this literal first + llvm::errs() << "<"; + llvm::interleaveComma(literal->getDims(), llvm::errs()); + llvm::errs() << ">"; + + // Now print the content, recursing on every element of the list + llvm::errs() << "[ "; + llvm::interleaveComma(literal->getValues(), llvm::errs(), + [&](auto &elt) { printLitHelper(elt.get()); }); + llvm::errs() << "]"; +} + +/// Print a literal, see the recursive helper above for the implementation. +void ASTDumper::dump(LiteralExprAST *node) { + INDENT(); + llvm::errs() << "Literal: "; + printLitHelper(node); + llvm::errs() << " " << loc(node) << "\n"; +} + +/// Print a variable reference (just a name). +void ASTDumper::dump(VariableExprAST *node) { + INDENT(); + llvm::errs() << "var: " << node->getName() << " " << loc(node) << "\n"; +} + +/// Return statement print the return and its (optional) argument. +void ASTDumper::dump(ReturnExprAST *node) { + INDENT(); + llvm::errs() << "Return\n"; + if (node->getExpr().has_value()) + return dump(*node->getExpr()); + { + INDENT(); + llvm::errs() << "(void)\n"; + } +} + +/// Print a binary operation, first the operator, then recurse into LHS and RHS. +void ASTDumper::dump(BinaryExprAST *node) { + INDENT(); + llvm::errs() << "BinOp: " << node->getOp() << " " << loc(node) << "\n"; + dump(node->getLHS()); + dump(node->getRHS()); +} + +/// Print a call expression, first the callee name and the list of args by +/// recursing into each individual argument. +void ASTDumper::dump(CallExprAST *node) { + INDENT(); + llvm::errs() << "Call '" << node->getCallee() << "' [ " << loc(node) << "\n"; + for (auto &arg : node->getArgs()) + dump(arg.get()); + indent(); + llvm::errs() << "]\n"; +} + +/// Print a builtin print call, first the builtin name and then the argument. +void ASTDumper::dump(PrintExprAST *node) { + INDENT(); + llvm::errs() << "Print [ " << loc(node) << "\n"; + dump(node->getArg()); + indent(); + llvm::errs() << "]\n"; +} + +/// Print type: only the shape is printed in between '<' and '>' +void ASTDumper::dump(const VarType &type) { + llvm::errs() << "<"; + llvm::interleaveComma(type.shape, llvm::errs()); + llvm::errs() << ">"; +} + +/// Print a function prototype, first the function name, and then the list of +/// parameters names. +void ASTDumper::dump(PrototypeAST *node) { + INDENT(); + llvm::errs() << "Proto '" << node->getName() << "' " << loc(node) << "\n"; + indent(); + llvm::errs() << "Params: ["; + llvm::interleaveComma(node->getArgs(), llvm::errs(), + [](auto &arg) { llvm::errs() << arg->getName(); }); + llvm::errs() << "]\n"; +} + +/// Print a function, first the prototype and then the body. +void ASTDumper::dump(FunctionAST *node) { + INDENT(); + llvm::errs() << "Function \n"; + dump(node->getProto()); + dump(node->getBody()); +} + +/// Print a module, actually loop over the functions and print them in sequence. +void ASTDumper::dump(ModuleAST *node) { + INDENT(); + llvm::errs() << "Module:\n"; + for (auto &f : *node) + dump(&f); +} + +namespace toy { + +// Public API +void dump(ModuleAST &module) { ASTDumper().dump(&module); } + +} // namespace toy diff --git a/Ch1/toyc.cpp b/Ch1/toyc.cpp new file mode 100644 index 0000000..fb7b484 --- /dev/null +++ b/Ch1/toyc.cpp @@ -0,0 +1,71 @@ +//===- toyc.cpp - The Toy Compiler ----------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements the entry point for the Toy compiler. +// +//===----------------------------------------------------------------------===// + +#include "toy/AST.h" +#include "toy/Lexer.h" +#include "toy/Parser.h" + +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/ErrorOr.h" +#include "llvm/Support/MemoryBuffer.h" +#include "llvm/Support/raw_ostream.h" +#include +#include +#include + +using namespace toy; +namespace cl = llvm::cl; + +static cl::opt inputFilename(cl::Positional, + cl::desc(""), + cl::init("-"), + cl::value_desc("filename")); +namespace { +enum Action { None, DumpAST }; +} // namespace + +static cl::opt + emitAction("emit", cl::desc("Select the kind of output desired"), + cl::values(clEnumValN(DumpAST, "ast", "output the AST dump"))); + +/// Returns a Toy AST resulting from parsing the file or a nullptr on error. +std::unique_ptr parseInputFile(llvm::StringRef filename) { + llvm::ErrorOr> fileOrErr = + llvm::MemoryBuffer::getFileOrSTDIN(filename); + if (std::error_code ec = fileOrErr.getError()) { + llvm::errs() << "Could not open input file: " << ec.message() << "\n"; + return nullptr; + } + auto buffer = fileOrErr.get()->getBuffer(); + LexerBuffer lexer(buffer.begin(), buffer.end(), std::string(filename)); + Parser parser(lexer); + return parser.parseModule(); +} + +int main(int argc, char **argv) { + cl::ParseCommandLineOptions(argc, argv, "toy compiler\n"); + + auto moduleAST = parseInputFile(inputFilename); + if (!moduleAST) + return 1; + + switch (emitAction) { + case Action::DumpAST: + dump(*moduleAST); + return 0; + default: + llvm::errs() << "No action specified (parsing only?), use -emit=\n"; + } + + return 0; +} diff --git a/Ch2/CMakeLists.txt b/Ch2/CMakeLists.txt new file mode 100644 index 0000000..1334f9a --- /dev/null +++ b/Ch2/CMakeLists.txt @@ -0,0 +1,23 @@ +# For a better template to copy, see examples/standalone +add_subdirectory(include) + +set(LLVM_LINK_COMPONENTS + Support + ) + +add_executable(toyc-ch2 + toyc.cpp + parser/AST.cpp + mlir/MLIRGen.cpp + mlir/Dialect.cpp + ) +include_directories(include/) +include_directories(${CMAKE_CURRENT_BINARY_DIR}/include/) +target_link_libraries(toyc-ch2 + PRIVATE + MLIRAnalysis + MLIRFunctionInterfaces + MLIRIR + MLIRParser + MLIRSideEffectInterfaces + MLIRTransforms) diff --git a/Ch2/include/CMakeLists.txt b/Ch2/include/CMakeLists.txt new file mode 100644 index 0000000..37c89d0 --- /dev/null +++ b/Ch2/include/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(toy) diff --git a/Ch2/include/toy/AST.h b/Ch2/include/toy/AST.h new file mode 100644 index 0000000..d2ba101 --- /dev/null +++ b/Ch2/include/toy/AST.h @@ -0,0 +1,246 @@ +//===- AST.h - Node definition for the Toy AST ----------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements the AST for the Toy language. It is optimized for +// simplicity, not efficiency. The AST forms a tree structure where each node +// references its children using std::unique_ptr<>. +// +//===----------------------------------------------------------------------===// + +#ifndef TOY_AST_H +#define TOY_AST_H + +#include "toy/Lexer.h" + +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/Casting.h" +#include +#include +#include + +namespace toy { + +/// A variable type with shape information. +struct VarType { + std::vector shape; +}; + +/// Base class for all expression nodes. +class ExprAST { +public: + enum ExprASTKind { + Expr_VarDecl, + Expr_Return, + Expr_Num, + Expr_Literal, + Expr_Var, + Expr_BinOp, + Expr_Call, + Expr_Print, + }; + + ExprAST(ExprASTKind kind, Location location) + : kind(kind), location(std::move(location)) {} + virtual ~ExprAST() = default; + + ExprASTKind getKind() const { return kind; } + + const Location &loc() { return location; } + +private: + const ExprASTKind kind; + Location location; +}; + +/// A block-list of expressions. +using ExprASTList = std::vector>; + +/// Expression class for numeric literals like "1.0". +class NumberExprAST : public ExprAST { + double val; + +public: + NumberExprAST(Location loc, double val) + : ExprAST(Expr_Num, std::move(loc)), val(val) {} + + double getValue() { return val; } + + /// LLVM style RTTI + static bool classof(const ExprAST *c) { return c->getKind() == Expr_Num; } +}; + +/// Expression class for a literal value. +class LiteralExprAST : public ExprAST { + std::vector> values; + std::vector dims; + +public: + LiteralExprAST(Location loc, std::vector> values, + std::vector dims) + : ExprAST(Expr_Literal, std::move(loc)), values(std::move(values)), + dims(std::move(dims)) {} + + llvm::ArrayRef> getValues() { return values; } + llvm::ArrayRef getDims() { return dims; } + + /// LLVM style RTTI + static bool classof(const ExprAST *c) { return c->getKind() == Expr_Literal; } +}; + +/// Expression class for referencing a variable, like "a". +class VariableExprAST : public ExprAST { + std::string name; + +public: + VariableExprAST(Location loc, llvm::StringRef name) + : ExprAST(Expr_Var, std::move(loc)), name(name) {} + + llvm::StringRef getName() { return name; } + + /// LLVM style RTTI + static bool classof(const ExprAST *c) { return c->getKind() == Expr_Var; } +}; + +/// Expression class for defining a variable. +class VarDeclExprAST : public ExprAST { + std::string name; + VarType type; + std::unique_ptr initVal; + +public: + VarDeclExprAST(Location loc, llvm::StringRef name, VarType type, + std::unique_ptr initVal) + : ExprAST(Expr_VarDecl, std::move(loc)), name(name), + type(std::move(type)), initVal(std::move(initVal)) {} + + llvm::StringRef getName() { return name; } + ExprAST *getInitVal() { return initVal.get(); } + const VarType &getType() { return type; } + + /// LLVM style RTTI + static bool classof(const ExprAST *c) { return c->getKind() == Expr_VarDecl; } +}; + +/// Expression class for a return operator. +class ReturnExprAST : public ExprAST { + std::optional> expr; + +public: + ReturnExprAST(Location loc, std::optional> expr) + : ExprAST(Expr_Return, std::move(loc)), expr(std::move(expr)) {} + + std::optional getExpr() { + if (expr.has_value()) + return expr->get(); + return std::nullopt; + } + + /// LLVM style RTTI + static bool classof(const ExprAST *c) { return c->getKind() == Expr_Return; } +}; + +/// Expression class for a binary operator. +class BinaryExprAST : public ExprAST { + char op; + std::unique_ptr lhs, rhs; + +public: + char getOp() { return op; } + ExprAST *getLHS() { return lhs.get(); } + ExprAST *getRHS() { return rhs.get(); } + + BinaryExprAST(Location loc, char op, std::unique_ptr lhs, + std::unique_ptr rhs) + : ExprAST(Expr_BinOp, std::move(loc)), op(op), lhs(std::move(lhs)), + rhs(std::move(rhs)) {} + + /// LLVM style RTTI + static bool classof(const ExprAST *c) { return c->getKind() == Expr_BinOp; } +}; + +/// Expression class for function calls. +class CallExprAST : public ExprAST { + std::string callee; + std::vector> args; + +public: + CallExprAST(Location loc, const std::string &callee, + std::vector> args) + : ExprAST(Expr_Call, std::move(loc)), callee(callee), + args(std::move(args)) {} + + llvm::StringRef getCallee() { return callee; } + llvm::ArrayRef> getArgs() { return args; } + + /// LLVM style RTTI + static bool classof(const ExprAST *c) { return c->getKind() == Expr_Call; } +}; + +/// Expression class for builtin print calls. +class PrintExprAST : public ExprAST { + std::unique_ptr arg; + +public: + PrintExprAST(Location loc, std::unique_ptr arg) + : ExprAST(Expr_Print, std::move(loc)), arg(std::move(arg)) {} + + ExprAST *getArg() { return arg.get(); } + + /// LLVM style RTTI + static bool classof(const ExprAST *c) { return c->getKind() == Expr_Print; } +}; + +/// This class represents the "prototype" for a function, which captures its +/// name, and its argument names (thus implicitly the number of arguments the +/// function takes). +class PrototypeAST { + Location location; + std::string name; + std::vector> args; + +public: + PrototypeAST(Location location, const std::string &name, + std::vector> args) + : location(std::move(location)), name(name), args(std::move(args)) {} + + const Location &loc() { return location; } + llvm::StringRef getName() const { return name; } + llvm::ArrayRef> getArgs() { return args; } +}; + +/// This class represents a function definition itself. +class FunctionAST { + std::unique_ptr proto; + std::unique_ptr body; + +public: + FunctionAST(std::unique_ptr proto, + std::unique_ptr body) + : proto(std::move(proto)), body(std::move(body)) {} + PrototypeAST *getProto() { return proto.get(); } + ExprASTList *getBody() { return body.get(); } +}; + +/// This class represents a list of functions to be processed together +class ModuleAST { + std::vector functions; + +public: + ModuleAST(std::vector functions) + : functions(std::move(functions)) {} + + auto begin() { return functions.begin(); } + auto end() { return functions.end(); } +}; + +void dump(ModuleAST &); + +} // namespace toy + +#endif // TOY_AST_H diff --git a/Ch2/include/toy/CMakeLists.txt b/Ch2/include/toy/CMakeLists.txt new file mode 100644 index 0000000..8b13789 --- /dev/null +++ b/Ch2/include/toy/CMakeLists.txt @@ -0,0 +1 @@ + diff --git a/Ch2/include/toy/Dialect.cpp.inc b/Ch2/include/toy/Dialect.cpp.inc new file mode 100644 index 0000000..8cbc772 --- /dev/null +++ b/Ch2/include/toy/Dialect.cpp.inc @@ -0,0 +1,23 @@ +/*===- TableGen'erated file -------------------------------------*- C++ -*-===*\ +|* *| +|* Dialect Definitions *| +|* *| +|* Automatically generated file, do not edit! *| +|* From: Ops.td *| +|* *| +\*===----------------------------------------------------------------------===*/ + +MLIR_DEFINE_EXPLICIT_TYPE_ID(::mlir::toy::ToyDialect) +namespace mlir { +namespace toy { + +ToyDialect::ToyDialect(::mlir::MLIRContext *context) + : ::mlir::Dialect(getDialectNamespace(), context, ::mlir::TypeID::get()) { + + initialize(); +} + +ToyDialect::~ToyDialect() = default; + +} // namespace toy +} // namespace mlir diff --git a/Ch2/include/toy/Dialect.h b/Ch2/include/toy/Dialect.h new file mode 100644 index 0000000..292f50f --- /dev/null +++ b/Ch2/include/toy/Dialect.h @@ -0,0 +1,33 @@ +//===- Dialect.h - Dialect definition for the Toy IR ----------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements the IR Dialect for the Toy language. +// See docs/Tutorials/Toy/Ch-2.md for more information. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_TUTORIAL_TOY_DIALECT_H_ +#define MLIR_TUTORIAL_TOY_DIALECT_H_ + +#include "mlir/Bytecode/BytecodeOpInterface.h" +#include "mlir/IR/Dialect.h" +#include "mlir/IR/SymbolTable.h" +#include "mlir/Interfaces/CallInterfaces.h" +#include "mlir/Interfaces/FunctionInterfaces.h" +#include "mlir/Interfaces/SideEffectInterfaces.h" + +/// Include the auto-generated header file containing the declaration of the toy +/// dialect. +#include "toy/Dialect.h.inc" + +/// Include the auto-generated header file containing the declarations of the +/// toy operations. +#define GET_OP_CLASSES +#include "toy/Ops.h.inc" + +#endif // MLIR_TUTORIAL_TOY_DIALECT_H_ diff --git a/Ch2/include/toy/Dialect.h.inc b/Ch2/include/toy/Dialect.h.inc new file mode 100644 index 0000000..f19d867 --- /dev/null +++ b/Ch2/include/toy/Dialect.h.inc @@ -0,0 +1,26 @@ +/*===- TableGen'erated file -------------------------------------*- C++ -*-===*\ +|* *| +|* Dialect Declarations *| +|* *| +|* Automatically generated file, do not edit! *| +|* From: Ops.td *| +|* *| +\*===----------------------------------------------------------------------===*/ + +namespace mlir { +namespace toy { + +class ToyDialect : public ::mlir::Dialect { + explicit ToyDialect(::mlir::MLIRContext *context); + + void initialize(); + friend class ::mlir::MLIRContext; +public: + ~ToyDialect() override; + static constexpr ::llvm::StringLiteral getDialectNamespace() { + return ::llvm::StringLiteral("toy"); + } +}; +} // namespace toy +} // namespace mlir +MLIR_DECLARE_EXPLICIT_TYPE_ID(::mlir::toy::ToyDialect) diff --git a/Ch2/include/toy/Lexer.h b/Ch2/include/toy/Lexer.h new file mode 100644 index 0000000..3c59cd9 --- /dev/null +++ b/Ch2/include/toy/Lexer.h @@ -0,0 +1,232 @@ +//===- Lexer.h - Lexer for the Toy language -------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements a simple Lexer for the Toy language. +// +//===----------------------------------------------------------------------===// + +#ifndef TOY_LEXER_H +#define TOY_LEXER_H + +#include "llvm/ADT/StringRef.h" + +#include +#include + +namespace toy { + +/// Structure definition a location in a file. +struct Location { + std::shared_ptr file; ///< filename. + int line; ///< line number. + int col; ///< column number. +}; + +// List of Token returned by the lexer. +enum Token : int { + tok_semicolon = ';', + tok_parenthese_open = '(', + tok_parenthese_close = ')', + tok_bracket_open = '{', + tok_bracket_close = '}', + tok_sbracket_open = '[', + tok_sbracket_close = ']', + + tok_eof = -1, + + // commands + tok_return = -2, + tok_var = -3, + tok_def = -4, + + // primary + tok_identifier = -5, + tok_number = -6, +}; + +/// The Lexer is an abstract base class providing all the facilities that the +/// Parser expects. It goes through the stream one token at a time and keeps +/// track of the location in the file for debugging purpose. +/// It relies on a subclass to provide a `readNextLine()` method. The subclass +/// can proceed by reading the next line from the standard input or from a +/// memory mapped file. +class Lexer { +public: + /// Create a lexer for the given filename. The filename is kept only for + /// debugging purpose (attaching a location to a Token). + Lexer(std::string filename) + : lastLocation( + {std::make_shared(std::move(filename)), 0, 0}) {} + virtual ~Lexer() = default; + + /// Look at the current token in the stream. + Token getCurToken() { return curTok; } + + /// Move to the next token in the stream and return it. + Token getNextToken() { return curTok = getTok(); } + + /// Move to the next token in the stream, asserting on the current token + /// matching the expectation. + void consume(Token tok) { + assert(tok == curTok && "consume Token mismatch expectation"); + getNextToken(); + } + + /// Return the current identifier (prereq: getCurToken() == tok_identifier) + llvm::StringRef getId() { + assert(curTok == tok_identifier); + return identifierStr; + } + + /// Return the current number (prereq: getCurToken() == tok_number) + double getValue() { + assert(curTok == tok_number); + return numVal; + } + + /// Return the location for the beginning of the current token. + Location getLastLocation() { return lastLocation; } + + // Return the current line in the file. + int getLine() { return curLineNum; } + + // Return the current column in the file. + int getCol() { return curCol; } + +private: + /// Delegate to a derived class fetching the next line. Returns an empty + /// string to signal end of file (EOF). Lines are expected to always finish + /// with "\n" + virtual llvm::StringRef readNextLine() = 0; + + /// Return the next character from the stream. This manages the buffer for the + /// current line and request the next line buffer to the derived class as + /// needed. + int getNextChar() { + // The current line buffer should not be empty unless it is the end of file. + if (curLineBuffer.empty()) + return EOF; + ++curCol; + auto nextchar = curLineBuffer.front(); + curLineBuffer = curLineBuffer.drop_front(); + if (curLineBuffer.empty()) + curLineBuffer = readNextLine(); + if (nextchar == '\n') { + ++curLineNum; + curCol = 0; + } + return nextchar; + } + + /// Return the next token from standard input. + Token getTok() { + // Skip any whitespace. + while (isspace(lastChar)) + lastChar = Token(getNextChar()); + + // Save the current location before reading the token characters. + lastLocation.line = curLineNum; + lastLocation.col = curCol; + + // Identifier: [a-zA-Z][a-zA-Z0-9_]* + if (isalpha(lastChar)) { + identifierStr = (char)lastChar; + while (isalnum((lastChar = Token(getNextChar()))) || lastChar == '_') + identifierStr += (char)lastChar; + + if (identifierStr == "return") + return tok_return; + if (identifierStr == "def") + return tok_def; + if (identifierStr == "var") + return tok_var; + return tok_identifier; + } + + // Number: [0-9.]+ + if (isdigit(lastChar) || lastChar == '.') { + std::string numStr; + do { + numStr += lastChar; + lastChar = Token(getNextChar()); + } while (isdigit(lastChar) || lastChar == '.'); + + numVal = strtod(numStr.c_str(), nullptr); + return tok_number; + } + + if (lastChar == '#') { + // Comment until end of line. + do { + lastChar = Token(getNextChar()); + } while (lastChar != EOF && lastChar != '\n' && lastChar != '\r'); + + if (lastChar != EOF) + return getTok(); + } + + // Check for end of file. Don't eat the EOF. + if (lastChar == EOF) + return tok_eof; + + // Otherwise, just return the character as its ascii value. + Token thisChar = Token(lastChar); + lastChar = Token(getNextChar()); + return thisChar; + } + + /// The last token read from the input. + Token curTok = tok_eof; + + /// Location for `curTok`. + Location lastLocation; + + /// If the current Token is an identifier, this string contains the value. + std::string identifierStr; + + /// If the current Token is a number, this contains the value. + double numVal = 0; + + /// The last value returned by getNextChar(). We need to keep it around as we + /// always need to read ahead one character to decide when to end a token and + /// we can't put it back in the stream after reading from it. + Token lastChar = Token(' '); + + /// Keep track of the current line number in the input stream + int curLineNum = 0; + + /// Keep track of the current column number in the input stream + int curCol = 0; + + /// Buffer supplied by the derived class on calls to `readNextLine()` + llvm::StringRef curLineBuffer = "\n"; +}; + +/// A lexer implementation operating on a buffer in memory. +class LexerBuffer final : public Lexer { +public: + LexerBuffer(const char *begin, const char *end, std::string filename) + : Lexer(std::move(filename)), current(begin), end(end) {} + +private: + /// Provide one line at a time to the Lexer, return an empty string when + /// reaching the end of the buffer. + llvm::StringRef readNextLine() override { + auto *begin = current; + while (current <= end && *current && *current != '\n') + ++current; + if (current <= end && *current) + ++current; + llvm::StringRef result{begin, static_cast(current - begin)}; + return result; + } + const char *current, *end; +}; +} // namespace toy + +#endif // TOY_LEXER_H diff --git a/Ch2/include/toy/MLIRGen.h b/Ch2/include/toy/MLIRGen.h new file mode 100644 index 0000000..fe9dbe5 --- /dev/null +++ b/Ch2/include/toy/MLIRGen.h @@ -0,0 +1,35 @@ +//===- MLIRGen.h - MLIR Generation from a Toy AST -------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file declares a simple interface to perform IR generation targeting MLIR +// from a Module AST for the Toy language. +// +//===----------------------------------------------------------------------===// + +#ifndef TOY_MLIRGEN_H +#define TOY_MLIRGEN_H + +#include + +namespace mlir { +class MLIRContext; +template +class OwningOpRef; +class ModuleOp; +} // namespace mlir + +namespace toy { +class ModuleAST; + +/// Emit IR for the given Toy moduleAST, returns a newly created MLIR module +/// or nullptr on failure. +mlir::OwningOpRef mlirGen(mlir::MLIRContext &context, + ModuleAST &moduleAST); +} // namespace toy + +#endif // TOY_MLIRGEN_H diff --git a/Ch2/include/toy/Ops.cpp.inc b/Ch2/include/toy/Ops.cpp.inc new file mode 100644 index 0000000..a3eabd3 --- /dev/null +++ b/Ch2/include/toy/Ops.cpp.inc @@ -0,0 +1,2049 @@ +/*===- TableGen'erated file -------------------------------------*- C++ -*-===*\ +|* *| +|* Op Definitions *| +|* *| +|* Automatically generated file, do not edit! *| +|* From: Ops.td *| +|* *| +\*===----------------------------------------------------------------------===*/ + +#ifdef GET_OP_LIST +#undef GET_OP_LIST + +::mlir::toy::AddOp, +::mlir::toy::ConstantOp, +::mlir::toy::FuncOp, +::mlir::toy::GenericCallOp, +::mlir::toy::MulOp, +::mlir::toy::PrintOp, +::mlir::toy::ReshapeOp, +::mlir::toy::ReturnOp, +::mlir::toy::TransposeOp +#endif // GET_OP_LIST + +#ifdef GET_OP_CLASSES +#undef GET_OP_CLASSES + + +//===----------------------------------------------------------------------===// +// Local Utility Method Definitions +//===----------------------------------------------------------------------===// + +namespace mlir { +namespace toy { + +static ::mlir::LogicalResult __mlir_ods_local_type_constraint_Ops0( + ::mlir::Operation *op, ::mlir::Type type, ::llvm::StringRef valueKind, + unsigned valueIndex) { + if (!(((::llvm::isa<::mlir::TensorType>(type))) && ([](::mlir::Type elementType) { return (elementType.isF64()); }(::llvm::cast<::mlir::ShapedType>(type).getElementType())))) { + return op->emitOpError(valueKind) << " #" << valueIndex + << " must be tensor of 64-bit float values, but got " << type; + } + return ::mlir::success(); +} + +static ::mlir::LogicalResult __mlir_ods_local_type_constraint_Ops1( + ::mlir::Operation *op, ::mlir::Type type, ::llvm::StringRef valueKind, + unsigned valueIndex) { + if (!(((::llvm::isa<::mlir::TensorType>(type))) && ([](::mlir::Type elementType) { return (elementType.isF64()); }(::llvm::cast<::mlir::ShapedType>(type).getElementType())))) { + return op->emitOpError(valueKind) << " #" << valueIndex + << " must be variadic of tensor of 64-bit float values, but got " << type; + } + return ::mlir::success(); +} + +static ::mlir::LogicalResult __mlir_ods_local_type_constraint_Ops2( + ::mlir::Operation *op, ::mlir::Type type, ::llvm::StringRef valueKind, + unsigned valueIndex) { + if (!((((::llvm::isa<::mlir::RankedTensorType>(type))) && ((::llvm::cast<::mlir::ShapedType>(type).hasStaticShape()))) && ([](::mlir::Type elementType) { return (elementType.isF64()); }(::llvm::cast<::mlir::ShapedType>(type).getElementType())))) { + return op->emitOpError(valueKind) << " #" << valueIndex + << " must be statically shaped tensor of 64-bit float values, but got " << type; + } + return ::mlir::success(); +} + +static ::mlir::LogicalResult __mlir_ods_local_attr_constraint_Ops0( + ::mlir::Attribute attr, ::llvm::StringRef attrName, llvm::function_ref<::mlir::InFlightDiagnostic()> emitError) { + if (attr && !((::llvm::isa<::mlir::DenseFPElementsAttr>(attr) &&::llvm::cast<::mlir::DenseElementsAttr>(attr).getType().getElementType().isF64()))) + return emitError() << "attribute '" << attrName + << "' failed to satisfy constraint: 64-bit float elements attribute"; + return ::mlir::success(); +} +static ::mlir::LogicalResult __mlir_ods_local_attr_constraint_Ops0( + ::mlir::Operation *op, ::mlir::Attribute attr, ::llvm::StringRef attrName) { + return __mlir_ods_local_attr_constraint_Ops0(attr, attrName, [op]() { + return op->emitOpError(); + }); +} + +static ::mlir::LogicalResult __mlir_ods_local_attr_constraint_Ops1( + ::mlir::Attribute attr, ::llvm::StringRef attrName, llvm::function_ref<::mlir::InFlightDiagnostic()> emitError) { + if (attr && !((::llvm::isa<::mlir::StringAttr>(attr)))) + return emitError() << "attribute '" << attrName + << "' failed to satisfy constraint: string attribute"; + return ::mlir::success(); +} +static ::mlir::LogicalResult __mlir_ods_local_attr_constraint_Ops1( + ::mlir::Operation *op, ::mlir::Attribute attr, ::llvm::StringRef attrName) { + return __mlir_ods_local_attr_constraint_Ops1(attr, attrName, [op]() { + return op->emitOpError(); + }); +} + +static ::mlir::LogicalResult __mlir_ods_local_attr_constraint_Ops2( + ::mlir::Attribute attr, ::llvm::StringRef attrName, llvm::function_ref<::mlir::InFlightDiagnostic()> emitError) { + if (attr && !(((::llvm::isa<::mlir::TypeAttr>(attr))) && ((::llvm::isa<::mlir::FunctionType>(::llvm::cast<::mlir::TypeAttr>(attr).getValue()))) && ((::llvm::isa<::mlir::FunctionType>(::llvm::cast<::mlir::TypeAttr>(attr).getValue()))))) + return emitError() << "attribute '" << attrName + << "' failed to satisfy constraint: type attribute of function type"; + return ::mlir::success(); +} +static ::mlir::LogicalResult __mlir_ods_local_attr_constraint_Ops2( + ::mlir::Operation *op, ::mlir::Attribute attr, ::llvm::StringRef attrName) { + return __mlir_ods_local_attr_constraint_Ops2(attr, attrName, [op]() { + return op->emitOpError(); + }); +} + +static ::mlir::LogicalResult __mlir_ods_local_attr_constraint_Ops3( + ::mlir::Attribute attr, ::llvm::StringRef attrName, llvm::function_ref<::mlir::InFlightDiagnostic()> emitError) { + if (attr && !(((::llvm::isa<::mlir::ArrayAttr>(attr))) && (::llvm::all_of(::llvm::cast<::mlir::ArrayAttr>(attr), [&](::mlir::Attribute attr) { return attr && ((::llvm::isa<::mlir::DictionaryAttr>(attr))); })))) + return emitError() << "attribute '" << attrName + << "' failed to satisfy constraint: Array of dictionary attributes"; + return ::mlir::success(); +} +static ::mlir::LogicalResult __mlir_ods_local_attr_constraint_Ops3( + ::mlir::Operation *op, ::mlir::Attribute attr, ::llvm::StringRef attrName) { + return __mlir_ods_local_attr_constraint_Ops3(attr, attrName, [op]() { + return op->emitOpError(); + }); +} + +static ::mlir::LogicalResult __mlir_ods_local_attr_constraint_Ops4( + ::mlir::Attribute attr, ::llvm::StringRef attrName, llvm::function_ref<::mlir::InFlightDiagnostic()> emitError) { + if (attr && !((::llvm::isa<::mlir::FlatSymbolRefAttr>(attr)))) + return emitError() << "attribute '" << attrName + << "' failed to satisfy constraint: flat symbol reference attribute"; + return ::mlir::success(); +} +static ::mlir::LogicalResult __mlir_ods_local_attr_constraint_Ops4( + ::mlir::Operation *op, ::mlir::Attribute attr, ::llvm::StringRef attrName) { + return __mlir_ods_local_attr_constraint_Ops4(attr, attrName, [op]() { + return op->emitOpError(); + }); +} + +static ::mlir::LogicalResult __mlir_ods_local_region_constraint_Ops0( + ::mlir::Operation *op, ::mlir::Region ®ion, ::llvm::StringRef regionName, + unsigned regionIndex) { + if (!((true))) { + return op->emitOpError("region #") << regionIndex + << (regionName.empty() ? " " : " ('" + regionName + "') ") + << "failed to verify constraint: any region"; + } + return ::mlir::success(); +} +} // namespace toy +} // namespace mlir +namespace mlir { +namespace toy { + +//===----------------------------------------------------------------------===// +// ::mlir::toy::AddOp definitions +//===----------------------------------------------------------------------===// + +namespace detail { +AddOpGenericAdaptorBase::AddOpGenericAdaptorBase(::mlir::DictionaryAttr attrs, const ::mlir::EmptyProperties &properties, ::mlir::RegionRange regions) : odsAttrs(attrs), odsRegions(regions) { if (odsAttrs) + odsOpName.emplace("toy.add", odsAttrs.getContext()); +} + +AddOpGenericAdaptorBase::AddOpGenericAdaptorBase(AddOp op) : AddOpGenericAdaptorBase(op->getAttrDictionary(), op.getProperties(), op->getRegions()) {} + +std::pair AddOpGenericAdaptorBase::getODSOperandIndexAndLength(unsigned index, unsigned odsOperandsSize) { + return {index, 1}; +} + +::mlir::DictionaryAttr AddOpGenericAdaptorBase::getAttributes() { + return odsAttrs; +} + +} // namespace detail +AddOpAdaptor::AddOpAdaptor(AddOp op) : AddOpGenericAdaptor(op->getOperands(), op) {} + +::mlir::LogicalResult AddOpAdaptor::verify(::mlir::Location loc) { + return ::mlir::success(); +} + +std::pair AddOp::getODSOperandIndexAndLength(unsigned index) { + return {index, 1}; +} + +::mlir::Operation::operand_range AddOp::getODSOperands(unsigned index) { + auto valueRange = getODSOperandIndexAndLength(index); + return {std::next(getOperation()->operand_begin(), valueRange.first), + std::next(getOperation()->operand_begin(), valueRange.first + valueRange.second)}; +} + +::mlir::TypedValue<::mlir::TensorType> AddOp::getLhs() { + return ::llvm::cast<::mlir::TypedValue<::mlir::TensorType>>(*getODSOperands(0).begin()); +} + +::mlir::TypedValue<::mlir::TensorType> AddOp::getRhs() { + return ::llvm::cast<::mlir::TypedValue<::mlir::TensorType>>(*getODSOperands(1).begin()); +} + +::mlir::OpOperand &AddOp::getLhsMutable() { + auto range = getODSOperandIndexAndLength(0); + return getOperation()->getOpOperand(range.first); +} + +::mlir::OpOperand &AddOp::getRhsMutable() { + auto range = getODSOperandIndexAndLength(1); + return getOperation()->getOpOperand(range.first); +} + +std::pair AddOp::getODSResultIndexAndLength(unsigned index) { + return {index, 1}; +} + +::mlir::Operation::result_range AddOp::getODSResults(unsigned index) { + auto valueRange = getODSResultIndexAndLength(index); + return {std::next(getOperation()->result_begin(), valueRange.first), + std::next(getOperation()->result_begin(), valueRange.first + valueRange.second)}; +} + +void AddOp::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type resultType0, ::mlir::Value lhs, ::mlir::Value rhs) { + odsState.addOperands(lhs); + odsState.addOperands(rhs); + odsState.addTypes(resultType0); +} + +void AddOp::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::Value lhs, ::mlir::Value rhs) { + odsState.addOperands(lhs); + odsState.addOperands(rhs); + assert(resultTypes.size() == 1u && "mismatched number of results"); + odsState.addTypes(resultTypes); +} + +void AddOp::build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes) { + assert(operands.size() == 2u && "mismatched number of parameters"); + odsState.addOperands(operands); + odsState.addAttributes(attributes); + assert(resultTypes.size() == 1u && "mismatched number of return types"); + odsState.addTypes(resultTypes); +} + +::mlir::LogicalResult AddOp::verifyInvariantsImpl() { + { + unsigned index = 0; (void)index; + auto valueGroup0 = getODSOperands(0); + + for (auto v : valueGroup0) { + if (::mlir::failed(__mlir_ods_local_type_constraint_Ops0(*this, v.getType(), "operand", index++))) + return ::mlir::failure(); + } + auto valueGroup1 = getODSOperands(1); + + for (auto v : valueGroup1) { + if (::mlir::failed(__mlir_ods_local_type_constraint_Ops0(*this, v.getType(), "operand", index++))) + return ::mlir::failure(); + } + } + { + unsigned index = 0; (void)index; + auto valueGroup0 = getODSResults(0); + + for (auto v : valueGroup0) { + if (::mlir::failed(__mlir_ods_local_type_constraint_Ops0(*this, v.getType(), "result", index++))) + return ::mlir::failure(); + } + } + return ::mlir::success(); +} + +::mlir::LogicalResult AddOp::verifyInvariants() { + return verifyInvariantsImpl(); +} + +} // namespace toy +} // namespace mlir +MLIR_DEFINE_EXPLICIT_TYPE_ID(::mlir::toy::AddOp) + +namespace mlir { +namespace toy { + +//===----------------------------------------------------------------------===// +// ::mlir::toy::ConstantOp definitions +//===----------------------------------------------------------------------===// + +namespace detail { +ConstantOpGenericAdaptorBase::ConstantOpGenericAdaptorBase(::mlir::DictionaryAttr attrs, const Properties &properties, ::mlir::RegionRange regions) : odsAttrs(attrs), properties(properties), odsRegions(regions) { if (odsAttrs) + odsOpName.emplace("toy.constant", odsAttrs.getContext()); +} + +ConstantOpGenericAdaptorBase::ConstantOpGenericAdaptorBase(ConstantOp op) : ConstantOpGenericAdaptorBase(op->getDiscardableAttrDictionary(), op.getProperties(), op->getRegions()) {} + +std::pair ConstantOpGenericAdaptorBase::getODSOperandIndexAndLength(unsigned index, unsigned odsOperandsSize) { + return {index, 1}; +} + +::mlir::DictionaryAttr ConstantOpGenericAdaptorBase::getAttributes() { + return odsAttrs; +} + +::mlir::DenseElementsAttr ConstantOpGenericAdaptorBase::getValueAttr() { + auto attr = ::llvm::cast<::mlir::DenseElementsAttr>(getProperties().value); + return attr; +} + +::mlir::DenseElementsAttr ConstantOpGenericAdaptorBase::getValue() { + auto attr = getValueAttr(); + return attr; +} + +} // namespace detail +ConstantOpAdaptor::ConstantOpAdaptor(ConstantOp op) : ConstantOpGenericAdaptor(op->getOperands(), op) {} + +::mlir::LogicalResult ConstantOpAdaptor::verify(::mlir::Location loc) { + auto tblgen_value = getProperties().value; (void)tblgen_value; + if (!tblgen_value) return emitError(loc, "'toy.constant' op ""requires attribute 'value'"); + + if (tblgen_value && !((::llvm::isa<::mlir::DenseFPElementsAttr>(tblgen_value) &&::llvm::cast<::mlir::DenseElementsAttr>(tblgen_value).getType().getElementType().isF64()))) + return emitError(loc, "'toy.constant' op ""attribute 'value' failed to satisfy constraint: 64-bit float elements attribute"); + return ::mlir::success(); +} + +std::pair ConstantOp::getODSOperandIndexAndLength(unsigned index) { + return {index, 1}; +} + +::mlir::Operation::operand_range ConstantOp::getODSOperands(unsigned index) { + auto valueRange = getODSOperandIndexAndLength(index); + return {std::next(getOperation()->operand_begin(), valueRange.first), + std::next(getOperation()->operand_begin(), valueRange.first + valueRange.second)}; +} + +std::pair ConstantOp::getODSResultIndexAndLength(unsigned index) { + return {index, 1}; +} + +::mlir::Operation::result_range ConstantOp::getODSResults(unsigned index) { + auto valueRange = getODSResultIndexAndLength(index); + return {std::next(getOperation()->result_begin(), valueRange.first), + std::next(getOperation()->result_begin(), valueRange.first + valueRange.second)}; +} + +::mlir::LogicalResult ConstantOp::setPropertiesFromAttr(Properties &prop, ::mlir::Attribute attr, ::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError) { + ::mlir::DictionaryAttr dict = ::llvm::dyn_cast<::mlir::DictionaryAttr>(attr); + if (!dict) { + emitError() << "expected DictionaryAttr to set properties"; + return ::mlir::failure(); + } + + { + auto &propStorage = prop.value; + auto attr = dict.get("value"); + if (attr || /*isRequired=*/true) { + if (!attr) { + emitError() << "expected key entry for value in DictionaryAttr to set " + "Properties."; + return ::mlir::failure(); + } + auto convertedAttr = ::llvm::dyn_cast>(attr); + if (convertedAttr) { + propStorage = convertedAttr; + } else { + emitError() << "Invalid attribute `value` in property conversion: " << attr; + return ::mlir::failure(); + } + } + } + return ::mlir::success(); +} + +::mlir::Attribute ConstantOp::getPropertiesAsAttr(::mlir::MLIRContext *ctx, const Properties &prop) { + ::mlir::SmallVector<::mlir::NamedAttribute> attrs; + ::mlir::Builder odsBuilder{ctx}; + + { + const auto &propStorage = prop.value; + if (propStorage) + attrs.push_back(odsBuilder.getNamedAttr("value", + propStorage)); + } + + if (!attrs.empty()) + return odsBuilder.getDictionaryAttr(attrs); + return {}; +} + +llvm::hash_code ConstantOp::computePropertiesHash(const Properties &prop) { + return llvm::hash_combine( + llvm::hash_value(prop.value.getAsOpaquePointer())); +} + +std::optional ConstantOp::getInherentAttr(::mlir::MLIRContext *ctx, const Properties &prop, llvm::StringRef name) { + if (name == "value") + return prop.value; + return std::nullopt; +} + +void ConstantOp::setInherentAttr(Properties &prop, llvm::StringRef name, mlir::Attribute value) { + if (name == "value") { + prop.value = ::llvm::dyn_cast_or_null>(value); + return; + } +} + +void ConstantOp::populateInherentAttrs(::mlir::MLIRContext *ctx, const Properties &prop, ::mlir::NamedAttrList &attrs) { + if (prop.value) attrs.append("value", prop.value); +} + +::mlir::LogicalResult ConstantOp::verifyInherentAttrs(::mlir::OperationName opName, ::mlir::NamedAttrList &attrs, llvm::function_ref<::mlir::InFlightDiagnostic()> emitError) { + { + ::mlir::Attribute attr = attrs.get(getValueAttrName(opName)); + if (attr && ::mlir::failed(__mlir_ods_local_attr_constraint_Ops0(attr, "value", emitError))) + return ::mlir::failure(); + } + return ::mlir::success(); +} + +::mlir::LogicalResult ConstantOp::readProperties(::mlir::DialectBytecodeReader &reader, ::mlir::OperationState &state) { + auto &prop = state.getOrAddProperties(); (void)prop; + if (::mlir::failed(reader.readAttribute(prop.value))) + return ::mlir::failure(); + return ::mlir::success(); +} + +void ConstantOp::writeProperties(::mlir::DialectBytecodeWriter &writer) { + auto &prop = getProperties(); (void)prop; + writer.writeAttribute(prop.value); +} + +::mlir::DenseElementsAttr ConstantOp::getValueAttr() { + return ::llvm::cast<::mlir::DenseElementsAttr>(getProperties().value); +} + +::mlir::DenseElementsAttr ConstantOp::getValue() { + auto attr = getValueAttr(); + return attr; +} + +void ConstantOp::setValueAttr(::mlir::DenseElementsAttr attr) { + (*this)->setAttr(getValueAttrName(), attr); +} + +void ConstantOp::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, DenseElementsAttr value) { + build(odsBuilder, odsState, value.getType(), value); + +} + +void ConstantOp::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type resultType0, ::mlir::DenseElementsAttr value) { + odsState.getOrAddProperties().value = value; + odsState.addTypes(resultType0); +} + +void ConstantOp::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::DenseElementsAttr value) { + odsState.getOrAddProperties().value = value; + assert(resultTypes.size() == 1u && "mismatched number of results"); + odsState.addTypes(resultTypes); +} + +void ConstantOp::build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes) { + assert(operands.size() == 0u && "mismatched number of parameters"); + odsState.addOperands(operands); + odsState.addAttributes(attributes); + assert(resultTypes.size() == 1u && "mismatched number of return types"); + odsState.addTypes(resultTypes); +} + +::mlir::LogicalResult ConstantOp::verifyInvariantsImpl() { + auto tblgen_value = getProperties().value; (void)tblgen_value; + if (!tblgen_value) return emitOpError("requires attribute 'value'"); + + if (::mlir::failed(__mlir_ods_local_attr_constraint_Ops0(*this, tblgen_value, "value"))) + return ::mlir::failure(); + { + unsigned index = 0; (void)index; + auto valueGroup0 = getODSResults(0); + + for (auto v : valueGroup0) { + if (::mlir::failed(__mlir_ods_local_type_constraint_Ops0(*this, v.getType(), "result", index++))) + return ::mlir::failure(); + } + } + return ::mlir::success(); +} + +::mlir::LogicalResult ConstantOp::verifyInvariants() { + if(::mlir::succeeded(verifyInvariantsImpl()) && ::mlir::succeeded(verify())) + return ::mlir::success(); + return ::mlir::failure(); +} + +void ConstantOp::getEffects(::llvm::SmallVectorImpl<::mlir::SideEffects::EffectInstance<::mlir::MemoryEffects::Effect>> &effects) { +} + +} // namespace toy +} // namespace mlir +MLIR_DEFINE_EXPLICIT_TYPE_ID(::mlir::toy::ConstantOp) + +namespace mlir { +namespace toy { + +//===----------------------------------------------------------------------===// +// ::mlir::toy::FuncOp definitions +//===----------------------------------------------------------------------===// + +namespace detail { +FuncOpGenericAdaptorBase::FuncOpGenericAdaptorBase(::mlir::DictionaryAttr attrs, const Properties &properties, ::mlir::RegionRange regions) : odsAttrs(attrs), properties(properties), odsRegions(regions) { if (odsAttrs) + odsOpName.emplace("toy.func", odsAttrs.getContext()); +} + +FuncOpGenericAdaptorBase::FuncOpGenericAdaptorBase(FuncOp op) : FuncOpGenericAdaptorBase(op->getDiscardableAttrDictionary(), op.getProperties(), op->getRegions()) {} + +std::pair FuncOpGenericAdaptorBase::getODSOperandIndexAndLength(unsigned index, unsigned odsOperandsSize) { + return {index, 1}; +} + +::mlir::DictionaryAttr FuncOpGenericAdaptorBase::getAttributes() { + return odsAttrs; +} + +::mlir::StringAttr FuncOpGenericAdaptorBase::getSymNameAttr() { + auto attr = ::llvm::cast<::mlir::StringAttr>(getProperties().sym_name); + return attr; +} + +::llvm::StringRef FuncOpGenericAdaptorBase::getSymName() { + auto attr = getSymNameAttr(); + return attr.getValue(); +} + +::mlir::TypeAttr FuncOpGenericAdaptorBase::getFunctionTypeAttr() { + auto attr = ::llvm::cast<::mlir::TypeAttr>(getProperties().function_type); + return attr; +} + +::mlir::FunctionType FuncOpGenericAdaptorBase::getFunctionType() { + auto attr = getFunctionTypeAttr(); + return ::llvm::cast<::mlir::FunctionType>(attr.getValue()); +} + +::mlir::ArrayAttr FuncOpGenericAdaptorBase::getArgAttrsAttr() { + auto attr = ::llvm::dyn_cast_or_null<::mlir::ArrayAttr>(getProperties().arg_attrs); + return attr; +} + +::std::optional< ::mlir::ArrayAttr > FuncOpGenericAdaptorBase::getArgAttrs() { + auto attr = getArgAttrsAttr(); + return attr ? ::std::optional< ::mlir::ArrayAttr >(attr) : (::std::nullopt); +} + +::mlir::ArrayAttr FuncOpGenericAdaptorBase::getResAttrsAttr() { + auto attr = ::llvm::dyn_cast_or_null<::mlir::ArrayAttr>(getProperties().res_attrs); + return attr; +} + +::std::optional< ::mlir::ArrayAttr > FuncOpGenericAdaptorBase::getResAttrs() { + auto attr = getResAttrsAttr(); + return attr ? ::std::optional< ::mlir::ArrayAttr >(attr) : (::std::nullopt); +} + +::mlir::Region &FuncOpGenericAdaptorBase::getBody() { + return *odsRegions[0]; +} + +::mlir::RegionRange FuncOpGenericAdaptorBase::getRegions() { + return odsRegions; +} + +} // namespace detail +FuncOpAdaptor::FuncOpAdaptor(FuncOp op) : FuncOpGenericAdaptor(op->getOperands(), op) {} + +::mlir::LogicalResult FuncOpAdaptor::verify(::mlir::Location loc) { + auto tblgen_arg_attrs = getProperties().arg_attrs; (void)tblgen_arg_attrs; + auto tblgen_function_type = getProperties().function_type; (void)tblgen_function_type; + if (!tblgen_function_type) return emitError(loc, "'toy.func' op ""requires attribute 'function_type'"); + auto tblgen_res_attrs = getProperties().res_attrs; (void)tblgen_res_attrs; + auto tblgen_sym_name = getProperties().sym_name; (void)tblgen_sym_name; + if (!tblgen_sym_name) return emitError(loc, "'toy.func' op ""requires attribute 'sym_name'"); + + if (tblgen_sym_name && !((::llvm::isa<::mlir::StringAttr>(tblgen_sym_name)))) + return emitError(loc, "'toy.func' op ""attribute 'sym_name' failed to satisfy constraint: string attribute"); + + if (tblgen_function_type && !(((::llvm::isa<::mlir::TypeAttr>(tblgen_function_type))) && ((::llvm::isa<::mlir::FunctionType>(::llvm::cast<::mlir::TypeAttr>(tblgen_function_type).getValue()))) && ((::llvm::isa<::mlir::FunctionType>(::llvm::cast<::mlir::TypeAttr>(tblgen_function_type).getValue()))))) + return emitError(loc, "'toy.func' op ""attribute 'function_type' failed to satisfy constraint: type attribute of function type"); + + if (tblgen_arg_attrs && !(((::llvm::isa<::mlir::ArrayAttr>(tblgen_arg_attrs))) && (::llvm::all_of(::llvm::cast<::mlir::ArrayAttr>(tblgen_arg_attrs), [&](::mlir::Attribute attr) { return attr && ((::llvm::isa<::mlir::DictionaryAttr>(attr))); })))) + return emitError(loc, "'toy.func' op ""attribute 'arg_attrs' failed to satisfy constraint: Array of dictionary attributes"); + + if (tblgen_res_attrs && !(((::llvm::isa<::mlir::ArrayAttr>(tblgen_res_attrs))) && (::llvm::all_of(::llvm::cast<::mlir::ArrayAttr>(tblgen_res_attrs), [&](::mlir::Attribute attr) { return attr && ((::llvm::isa<::mlir::DictionaryAttr>(attr))); })))) + return emitError(loc, "'toy.func' op ""attribute 'res_attrs' failed to satisfy constraint: Array of dictionary attributes"); + return ::mlir::success(); +} + +std::pair FuncOp::getODSOperandIndexAndLength(unsigned index) { + return {index, 1}; +} + +::mlir::Operation::operand_range FuncOp::getODSOperands(unsigned index) { + auto valueRange = getODSOperandIndexAndLength(index); + return {std::next(getOperation()->operand_begin(), valueRange.first), + std::next(getOperation()->operand_begin(), valueRange.first + valueRange.second)}; +} + +std::pair FuncOp::getODSResultIndexAndLength(unsigned index) { + return {index, 1}; +} + +::mlir::Operation::result_range FuncOp::getODSResults(unsigned index) { + auto valueRange = getODSResultIndexAndLength(index); + return {std::next(getOperation()->result_begin(), valueRange.first), + std::next(getOperation()->result_begin(), valueRange.first + valueRange.second)}; +} + +::mlir::Region &FuncOp::getBody() { + return (*this)->getRegion(0); +} + +::mlir::LogicalResult FuncOp::setPropertiesFromAttr(Properties &prop, ::mlir::Attribute attr, ::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError) { + ::mlir::DictionaryAttr dict = ::llvm::dyn_cast<::mlir::DictionaryAttr>(attr); + if (!dict) { + emitError() << "expected DictionaryAttr to set properties"; + return ::mlir::failure(); + } + + { + auto &propStorage = prop.arg_attrs; + auto attr = dict.get("arg_attrs"); + if (attr || /*isRequired=*/false) { + if (!attr) { + emitError() << "expected key entry for arg_attrs in DictionaryAttr to set " + "Properties."; + return ::mlir::failure(); + } + auto convertedAttr = ::llvm::dyn_cast>(attr); + if (convertedAttr) { + propStorage = convertedAttr; + } else { + emitError() << "Invalid attribute `arg_attrs` in property conversion: " << attr; + return ::mlir::failure(); + } + } + } + + { + auto &propStorage = prop.function_type; + auto attr = dict.get("function_type"); + if (attr || /*isRequired=*/true) { + if (!attr) { + emitError() << "expected key entry for function_type in DictionaryAttr to set " + "Properties."; + return ::mlir::failure(); + } + auto convertedAttr = ::llvm::dyn_cast>(attr); + if (convertedAttr) { + propStorage = convertedAttr; + } else { + emitError() << "Invalid attribute `function_type` in property conversion: " << attr; + return ::mlir::failure(); + } + } + } + + { + auto &propStorage = prop.res_attrs; + auto attr = dict.get("res_attrs"); + if (attr || /*isRequired=*/false) { + if (!attr) { + emitError() << "expected key entry for res_attrs in DictionaryAttr to set " + "Properties."; + return ::mlir::failure(); + } + auto convertedAttr = ::llvm::dyn_cast>(attr); + if (convertedAttr) { + propStorage = convertedAttr; + } else { + emitError() << "Invalid attribute `res_attrs` in property conversion: " << attr; + return ::mlir::failure(); + } + } + } + + { + auto &propStorage = prop.sym_name; + auto attr = dict.get("sym_name"); + if (attr || /*isRequired=*/true) { + if (!attr) { + emitError() << "expected key entry for sym_name in DictionaryAttr to set " + "Properties."; + return ::mlir::failure(); + } + auto convertedAttr = ::llvm::dyn_cast>(attr); + if (convertedAttr) { + propStorage = convertedAttr; + } else { + emitError() << "Invalid attribute `sym_name` in property conversion: " << attr; + return ::mlir::failure(); + } + } + } + return ::mlir::success(); +} + +::mlir::Attribute FuncOp::getPropertiesAsAttr(::mlir::MLIRContext *ctx, const Properties &prop) { + ::mlir::SmallVector<::mlir::NamedAttribute> attrs; + ::mlir::Builder odsBuilder{ctx}; + + { + const auto &propStorage = prop.arg_attrs; + if (propStorage) + attrs.push_back(odsBuilder.getNamedAttr("arg_attrs", + propStorage)); + } + + { + const auto &propStorage = prop.function_type; + if (propStorage) + attrs.push_back(odsBuilder.getNamedAttr("function_type", + propStorage)); + } + + { + const auto &propStorage = prop.res_attrs; + if (propStorage) + attrs.push_back(odsBuilder.getNamedAttr("res_attrs", + propStorage)); + } + + { + const auto &propStorage = prop.sym_name; + if (propStorage) + attrs.push_back(odsBuilder.getNamedAttr("sym_name", + propStorage)); + } + + if (!attrs.empty()) + return odsBuilder.getDictionaryAttr(attrs); + return {}; +} + +llvm::hash_code FuncOp::computePropertiesHash(const Properties &prop) { + return llvm::hash_combine( + llvm::hash_value(prop.arg_attrs.getAsOpaquePointer()), + llvm::hash_value(prop.function_type.getAsOpaquePointer()), + llvm::hash_value(prop.res_attrs.getAsOpaquePointer()), + llvm::hash_value(prop.sym_name.getAsOpaquePointer())); +} + +std::optional FuncOp::getInherentAttr(::mlir::MLIRContext *ctx, const Properties &prop, llvm::StringRef name) { + if (name == "arg_attrs") + return prop.arg_attrs; + + if (name == "function_type") + return prop.function_type; + + if (name == "res_attrs") + return prop.res_attrs; + + if (name == "sym_name") + return prop.sym_name; + return std::nullopt; +} + +void FuncOp::setInherentAttr(Properties &prop, llvm::StringRef name, mlir::Attribute value) { + if (name == "arg_attrs") { + prop.arg_attrs = ::llvm::dyn_cast_or_null>(value); + return; + } + + if (name == "function_type") { + prop.function_type = ::llvm::dyn_cast_or_null>(value); + return; + } + + if (name == "res_attrs") { + prop.res_attrs = ::llvm::dyn_cast_or_null>(value); + return; + } + + if (name == "sym_name") { + prop.sym_name = ::llvm::dyn_cast_or_null>(value); + return; + } +} + +void FuncOp::populateInherentAttrs(::mlir::MLIRContext *ctx, const Properties &prop, ::mlir::NamedAttrList &attrs) { + if (prop.arg_attrs) attrs.append("arg_attrs", prop.arg_attrs); + + if (prop.function_type) attrs.append("function_type", prop.function_type); + + if (prop.res_attrs) attrs.append("res_attrs", prop.res_attrs); + + if (prop.sym_name) attrs.append("sym_name", prop.sym_name); +} + +::mlir::LogicalResult FuncOp::verifyInherentAttrs(::mlir::OperationName opName, ::mlir::NamedAttrList &attrs, llvm::function_ref<::mlir::InFlightDiagnostic()> emitError) { + { + ::mlir::Attribute attr = attrs.get(getArgAttrsAttrName(opName)); + if (attr && ::mlir::failed(__mlir_ods_local_attr_constraint_Ops3(attr, "arg_attrs", emitError))) + return ::mlir::failure(); + } + + { + ::mlir::Attribute attr = attrs.get(getFunctionTypeAttrName(opName)); + if (attr && ::mlir::failed(__mlir_ods_local_attr_constraint_Ops2(attr, "function_type", emitError))) + return ::mlir::failure(); + } + + { + ::mlir::Attribute attr = attrs.get(getResAttrsAttrName(opName)); + if (attr && ::mlir::failed(__mlir_ods_local_attr_constraint_Ops3(attr, "res_attrs", emitError))) + return ::mlir::failure(); + } + + { + ::mlir::Attribute attr = attrs.get(getSymNameAttrName(opName)); + if (attr && ::mlir::failed(__mlir_ods_local_attr_constraint_Ops1(attr, "sym_name", emitError))) + return ::mlir::failure(); + } + return ::mlir::success(); +} + +::mlir::LogicalResult FuncOp::readProperties(::mlir::DialectBytecodeReader &reader, ::mlir::OperationState &state) { + auto &prop = state.getOrAddProperties(); (void)prop; + if (::mlir::failed(reader.readOptionalAttribute(prop.arg_attrs))) + return ::mlir::failure(); + + if (::mlir::failed(reader.readAttribute(prop.function_type))) + return ::mlir::failure(); + + if (::mlir::failed(reader.readOptionalAttribute(prop.res_attrs))) + return ::mlir::failure(); + + if (::mlir::failed(reader.readAttribute(prop.sym_name))) + return ::mlir::failure(); + return ::mlir::success(); +} + +void FuncOp::writeProperties(::mlir::DialectBytecodeWriter &writer) { + auto &prop = getProperties(); (void)prop; + + writer.writeOptionalAttribute(prop.arg_attrs); + writer.writeAttribute(prop.function_type); + + writer.writeOptionalAttribute(prop.res_attrs); + writer.writeAttribute(prop.sym_name); +} + +::mlir::StringAttr FuncOp::getSymNameAttr() { + return ::llvm::cast<::mlir::StringAttr>(getProperties().sym_name); +} + +::llvm::StringRef FuncOp::getSymName() { + auto attr = getSymNameAttr(); + return attr.getValue(); +} + +::mlir::TypeAttr FuncOp::getFunctionTypeAttr() { + return ::llvm::cast<::mlir::TypeAttr>(getProperties().function_type); +} + +::mlir::FunctionType FuncOp::getFunctionType() { + auto attr = getFunctionTypeAttr(); + return ::llvm::cast<::mlir::FunctionType>(attr.getValue()); +} + +::mlir::ArrayAttr FuncOp::getArgAttrsAttr() { + return ::llvm::dyn_cast_or_null<::mlir::ArrayAttr>(getProperties().arg_attrs); +} + +::std::optional< ::mlir::ArrayAttr > FuncOp::getArgAttrs() { + auto attr = getArgAttrsAttr(); + return attr ? ::std::optional< ::mlir::ArrayAttr >(attr) : (::std::nullopt); +} + +::mlir::ArrayAttr FuncOp::getResAttrsAttr() { + return ::llvm::dyn_cast_or_null<::mlir::ArrayAttr>(getProperties().res_attrs); +} + +::std::optional< ::mlir::ArrayAttr > FuncOp::getResAttrs() { + auto attr = getResAttrsAttr(); + return attr ? ::std::optional< ::mlir::ArrayAttr >(attr) : (::std::nullopt); +} + +void FuncOp::setSymNameAttr(::mlir::StringAttr attr) { + (*this)->setAttr(getSymNameAttrName(), attr); +} + +void FuncOp::setSymName(::llvm::StringRef attrValue) { + (*this)->setAttr(getSymNameAttrName(), ::mlir::Builder((*this)->getContext()).getStringAttr(attrValue)); +} + +void FuncOp::setFunctionTypeAttr(::mlir::TypeAttr attr) { + (*this)->setAttr(getFunctionTypeAttrName(), attr); +} + +void FuncOp::setFunctionType(::mlir::FunctionType attrValue) { + (*this)->setAttr(getFunctionTypeAttrName(), ::mlir::TypeAttr::get(attrValue)); +} + +void FuncOp::setArgAttrsAttr(::mlir::ArrayAttr attr) { + (*this)->setAttr(getArgAttrsAttrName(), attr); +} + +void FuncOp::setResAttrsAttr(::mlir::ArrayAttr attr) { + (*this)->setAttr(getResAttrsAttrName(), attr); +} + +::mlir::Attribute FuncOp::removeArgAttrsAttr() { + auto &attr = getProperties().arg_attrs; + attr = {}; + return attr; +} + +::mlir::Attribute FuncOp::removeResAttrsAttr() { + auto &attr = getProperties().res_attrs; + attr = {}; + return attr; +} + +::mlir::LogicalResult FuncOp::verifyInvariantsImpl() { + auto tblgen_arg_attrs = getProperties().arg_attrs; (void)tblgen_arg_attrs; + auto tblgen_function_type = getProperties().function_type; (void)tblgen_function_type; + if (!tblgen_function_type) return emitOpError("requires attribute 'function_type'"); + auto tblgen_res_attrs = getProperties().res_attrs; (void)tblgen_res_attrs; + auto tblgen_sym_name = getProperties().sym_name; (void)tblgen_sym_name; + if (!tblgen_sym_name) return emitOpError("requires attribute 'sym_name'"); + + if (::mlir::failed(__mlir_ods_local_attr_constraint_Ops1(*this, tblgen_sym_name, "sym_name"))) + return ::mlir::failure(); + + if (::mlir::failed(__mlir_ods_local_attr_constraint_Ops2(*this, tblgen_function_type, "function_type"))) + return ::mlir::failure(); + + if (::mlir::failed(__mlir_ods_local_attr_constraint_Ops3(*this, tblgen_arg_attrs, "arg_attrs"))) + return ::mlir::failure(); + + if (::mlir::failed(__mlir_ods_local_attr_constraint_Ops3(*this, tblgen_res_attrs, "res_attrs"))) + return ::mlir::failure(); + { + unsigned index = 0; (void)index; + + for (auto ®ion : ::llvm::MutableArrayRef((*this)->getRegion(0))) + if (::mlir::failed(__mlir_ods_local_region_constraint_Ops0(*this, region, "body", index++))) + return ::mlir::failure(); + } + return ::mlir::success(); +} + +::mlir::LogicalResult FuncOp::verifyInvariants() { + return verifyInvariantsImpl(); +} + +} // namespace toy +} // namespace mlir +MLIR_DEFINE_EXPLICIT_TYPE_ID(::mlir::toy::FuncOp) + +namespace mlir { +namespace toy { + +//===----------------------------------------------------------------------===// +// ::mlir::toy::GenericCallOp definitions +//===----------------------------------------------------------------------===// + +namespace detail { +GenericCallOpGenericAdaptorBase::GenericCallOpGenericAdaptorBase(::mlir::DictionaryAttr attrs, const Properties &properties, ::mlir::RegionRange regions) : odsAttrs(attrs), properties(properties), odsRegions(regions) { if (odsAttrs) + odsOpName.emplace("toy.generic_call", odsAttrs.getContext()); +} + +GenericCallOpGenericAdaptorBase::GenericCallOpGenericAdaptorBase(GenericCallOp op) : GenericCallOpGenericAdaptorBase(op->getDiscardableAttrDictionary(), op.getProperties(), op->getRegions()) {} + +std::pair GenericCallOpGenericAdaptorBase::getODSOperandIndexAndLength(unsigned index, unsigned odsOperandsSize) { + bool isVariadic[] = {true}; + int prevVariadicCount = 0; + for (unsigned i = 0; i < index; ++i) + if (isVariadic[i]) ++prevVariadicCount; + + // Calculate how many dynamic values a static variadic operand corresponds to. + // This assumes all static variadic operands have the same dynamic value count. + int variadicSize = (odsOperandsSize - 0) / 1; + // `index` passed in as the parameter is the static index which counts each + // operand (variadic or not) as size 1. So here for each previous static variadic + // operand, we need to offset by (variadicSize - 1) to get where the dynamic + // value pack for this static operand starts. + int start = index + (variadicSize - 1) * prevVariadicCount; + int size = isVariadic[index] ? variadicSize : 1; + return {start, size}; +} + +::mlir::DictionaryAttr GenericCallOpGenericAdaptorBase::getAttributes() { + return odsAttrs; +} + +::mlir::FlatSymbolRefAttr GenericCallOpGenericAdaptorBase::getCalleeAttr() { + auto attr = ::llvm::cast<::mlir::FlatSymbolRefAttr>(getProperties().callee); + return attr; +} + +::llvm::StringRef GenericCallOpGenericAdaptorBase::getCallee() { + auto attr = getCalleeAttr(); + return attr.getValue(); +} + +} // namespace detail +GenericCallOpAdaptor::GenericCallOpAdaptor(GenericCallOp op) : GenericCallOpGenericAdaptor(op->getOperands(), op) {} + +::mlir::LogicalResult GenericCallOpAdaptor::verify(::mlir::Location loc) { + auto tblgen_callee = getProperties().callee; (void)tblgen_callee; + if (!tblgen_callee) return emitError(loc, "'toy.generic_call' op ""requires attribute 'callee'"); + + if (tblgen_callee && !((::llvm::isa<::mlir::FlatSymbolRefAttr>(tblgen_callee)))) + return emitError(loc, "'toy.generic_call' op ""attribute 'callee' failed to satisfy constraint: flat symbol reference attribute"); + return ::mlir::success(); +} + +std::pair GenericCallOp::getODSOperandIndexAndLength(unsigned index) { + bool isVariadic[] = {true}; + int prevVariadicCount = 0; + for (unsigned i = 0; i < index; ++i) + if (isVariadic[i]) ++prevVariadicCount; + + // Calculate how many dynamic values a static variadic operand corresponds to. + // This assumes all static variadic operands have the same dynamic value count. + int variadicSize = (getOperation()->getNumOperands() - 0) / 1; + // `index` passed in as the parameter is the static index which counts each + // operand (variadic or not) as size 1. So here for each previous static variadic + // operand, we need to offset by (variadicSize - 1) to get where the dynamic + // value pack for this static operand starts. + int start = index + (variadicSize - 1) * prevVariadicCount; + int size = isVariadic[index] ? variadicSize : 1; + return {start, size}; +} + +::mlir::Operation::operand_range GenericCallOp::getODSOperands(unsigned index) { + auto valueRange = getODSOperandIndexAndLength(index); + return {std::next(getOperation()->operand_begin(), valueRange.first), + std::next(getOperation()->operand_begin(), valueRange.first + valueRange.second)}; +} + +::mlir::Operation::operand_range GenericCallOp::getInputs() { + return getODSOperands(0); +} + +::mlir::MutableOperandRange GenericCallOp::getInputsMutable() { + auto range = getODSOperandIndexAndLength(0); + auto mutableRange = ::mlir::MutableOperandRange(getOperation(), range.first, range.second); + return mutableRange; +} + +std::pair GenericCallOp::getODSResultIndexAndLength(unsigned index) { + return {index, 1}; +} + +::mlir::Operation::result_range GenericCallOp::getODSResults(unsigned index) { + auto valueRange = getODSResultIndexAndLength(index); + return {std::next(getOperation()->result_begin(), valueRange.first), + std::next(getOperation()->result_begin(), valueRange.first + valueRange.second)}; +} + +::mlir::LogicalResult GenericCallOp::setPropertiesFromAttr(Properties &prop, ::mlir::Attribute attr, ::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError) { + ::mlir::DictionaryAttr dict = ::llvm::dyn_cast<::mlir::DictionaryAttr>(attr); + if (!dict) { + emitError() << "expected DictionaryAttr to set properties"; + return ::mlir::failure(); + } + + { + auto &propStorage = prop.callee; + auto attr = dict.get("callee"); + if (attr || /*isRequired=*/true) { + if (!attr) { + emitError() << "expected key entry for callee in DictionaryAttr to set " + "Properties."; + return ::mlir::failure(); + } + auto convertedAttr = ::llvm::dyn_cast>(attr); + if (convertedAttr) { + propStorage = convertedAttr; + } else { + emitError() << "Invalid attribute `callee` in property conversion: " << attr; + return ::mlir::failure(); + } + } + } + return ::mlir::success(); +} + +::mlir::Attribute GenericCallOp::getPropertiesAsAttr(::mlir::MLIRContext *ctx, const Properties &prop) { + ::mlir::SmallVector<::mlir::NamedAttribute> attrs; + ::mlir::Builder odsBuilder{ctx}; + + { + const auto &propStorage = prop.callee; + if (propStorage) + attrs.push_back(odsBuilder.getNamedAttr("callee", + propStorage)); + } + + if (!attrs.empty()) + return odsBuilder.getDictionaryAttr(attrs); + return {}; +} + +llvm::hash_code GenericCallOp::computePropertiesHash(const Properties &prop) { + return llvm::hash_combine( + llvm::hash_value(prop.callee.getAsOpaquePointer())); +} + +std::optional GenericCallOp::getInherentAttr(::mlir::MLIRContext *ctx, const Properties &prop, llvm::StringRef name) { + if (name == "callee") + return prop.callee; + return std::nullopt; +} + +void GenericCallOp::setInherentAttr(Properties &prop, llvm::StringRef name, mlir::Attribute value) { + if (name == "callee") { + prop.callee = ::llvm::dyn_cast_or_null>(value); + return; + } +} + +void GenericCallOp::populateInherentAttrs(::mlir::MLIRContext *ctx, const Properties &prop, ::mlir::NamedAttrList &attrs) { + if (prop.callee) attrs.append("callee", prop.callee); +} + +::mlir::LogicalResult GenericCallOp::verifyInherentAttrs(::mlir::OperationName opName, ::mlir::NamedAttrList &attrs, llvm::function_ref<::mlir::InFlightDiagnostic()> emitError) { + { + ::mlir::Attribute attr = attrs.get(getCalleeAttrName(opName)); + if (attr && ::mlir::failed(__mlir_ods_local_attr_constraint_Ops4(attr, "callee", emitError))) + return ::mlir::failure(); + } + return ::mlir::success(); +} + +::mlir::LogicalResult GenericCallOp::readProperties(::mlir::DialectBytecodeReader &reader, ::mlir::OperationState &state) { + auto &prop = state.getOrAddProperties(); (void)prop; + if (::mlir::failed(reader.readAttribute(prop.callee))) + return ::mlir::failure(); + return ::mlir::success(); +} + +void GenericCallOp::writeProperties(::mlir::DialectBytecodeWriter &writer) { + auto &prop = getProperties(); (void)prop; + writer.writeAttribute(prop.callee); +} + +::mlir::FlatSymbolRefAttr GenericCallOp::getCalleeAttr() { + return ::llvm::cast<::mlir::FlatSymbolRefAttr>(getProperties().callee); +} + +::llvm::StringRef GenericCallOp::getCallee() { + auto attr = getCalleeAttr(); + return attr.getValue(); +} + +void GenericCallOp::setCalleeAttr(::mlir::FlatSymbolRefAttr attr) { + (*this)->setAttr(getCalleeAttrName(), attr); +} + +void GenericCallOp::setCallee(::llvm::StringRef attrValue) { + (*this)->setAttr(getCalleeAttrName(), ::mlir::SymbolRefAttr::get(::mlir::Builder((*this)->getContext()).getContext(), attrValue)); +} + +void GenericCallOp::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type resultType0, ::mlir::FlatSymbolRefAttr callee, ::mlir::ValueRange inputs) { + odsState.addOperands(inputs); + odsState.getOrAddProperties().callee = callee; + odsState.addTypes(resultType0); +} + +void GenericCallOp::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::FlatSymbolRefAttr callee, ::mlir::ValueRange inputs) { + odsState.addOperands(inputs); + odsState.getOrAddProperties().callee = callee; + assert(resultTypes.size() == 1u && "mismatched number of results"); + odsState.addTypes(resultTypes); +} + +void GenericCallOp::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type resultType0, ::llvm::StringRef callee, ::mlir::ValueRange inputs) { + odsState.addOperands(inputs); + odsState.getOrAddProperties().callee = ::mlir::SymbolRefAttr::get(odsBuilder.getContext(), callee); + odsState.addTypes(resultType0); +} + +void GenericCallOp::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::llvm::StringRef callee, ::mlir::ValueRange inputs) { + odsState.addOperands(inputs); + odsState.getOrAddProperties().callee = ::mlir::SymbolRefAttr::get(odsBuilder.getContext(), callee); + assert(resultTypes.size() == 1u && "mismatched number of results"); + odsState.addTypes(resultTypes); +} + +void GenericCallOp::build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes) { + odsState.addOperands(operands); + odsState.addAttributes(attributes); + assert(resultTypes.size() == 1u && "mismatched number of return types"); + odsState.addTypes(resultTypes); +} + +::mlir::LogicalResult GenericCallOp::verifyInvariantsImpl() { + auto tblgen_callee = getProperties().callee; (void)tblgen_callee; + if (!tblgen_callee) return emitOpError("requires attribute 'callee'"); + + if (::mlir::failed(__mlir_ods_local_attr_constraint_Ops4(*this, tblgen_callee, "callee"))) + return ::mlir::failure(); + { + unsigned index = 0; (void)index; + auto valueGroup0 = getODSOperands(0); + + for (auto v : valueGroup0) { + if (::mlir::failed(__mlir_ods_local_type_constraint_Ops1(*this, v.getType(), "operand", index++))) + return ::mlir::failure(); + } + } + { + unsigned index = 0; (void)index; + auto valueGroup0 = getODSResults(0); + + for (auto v : valueGroup0) { + if (::mlir::failed(__mlir_ods_local_type_constraint_Ops0(*this, v.getType(), "result", index++))) + return ::mlir::failure(); + } + } + return ::mlir::success(); +} + +::mlir::LogicalResult GenericCallOp::verifyInvariants() { + return verifyInvariantsImpl(); +} + +::mlir::ParseResult GenericCallOp::parse(::mlir::OpAsmParser &parser, ::mlir::OperationState &result) { + ::mlir::FlatSymbolRefAttr calleeAttr; + ::llvm::SmallVector<::mlir::OpAsmParser::UnresolvedOperand, 4> inputsOperands; + ::llvm::SMLoc inputsOperandsLoc; + (void)inputsOperandsLoc; + ::llvm::ArrayRef<::mlir::Type> inputsTypes; + ::llvm::ArrayRef<::mlir::Type> allResultTypes; + + if (parser.parseCustomAttributeWithFallback(calleeAttr, parser.getBuilder().getType<::mlir::NoneType>())) { + return ::mlir::failure(); + } + if (calleeAttr) result.getOrAddProperties().callee = calleeAttr; + if (parser.parseLParen()) + return ::mlir::failure(); + + inputsOperandsLoc = parser.getCurrentLocation(); + if (parser.parseOperandList(inputsOperands)) + return ::mlir::failure(); + if (parser.parseRParen()) + return ::mlir::failure(); + { + auto loc = parser.getCurrentLocation();(void)loc; + if (parser.parseOptionalAttrDict(result.attributes)) + return ::mlir::failure(); + if (failed(verifyInherentAttrs(result.name, result.attributes, [&]() { + return parser.emitError(loc) << "'" << result.name.getStringRef() << "' op "; + }))) + return ::mlir::failure(); + } + if (parser.parseColon()) + return ::mlir::failure(); + + ::mlir::FunctionType inputs__allResult_functionType; + if (parser.parseType(inputs__allResult_functionType)) + return ::mlir::failure(); + inputsTypes = inputs__allResult_functionType.getInputs(); + allResultTypes = inputs__allResult_functionType.getResults(); + result.addTypes(allResultTypes); + if (parser.resolveOperands(inputsOperands, inputsTypes, inputsOperandsLoc, result.operands)) + return ::mlir::failure(); + return ::mlir::success(); +} + +void GenericCallOp::print(::mlir::OpAsmPrinter &_odsPrinter) { + _odsPrinter << ' '; + _odsPrinter.printAttributeWithoutType(getCalleeAttr()); + _odsPrinter << "("; + _odsPrinter << getInputs(); + _odsPrinter << ")"; + ::llvm::SmallVector<::llvm::StringRef, 2> elidedAttrs; + elidedAttrs.push_back("callee"); + _odsPrinter.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs); + _odsPrinter << ' ' << ":"; + _odsPrinter << ' '; + _odsPrinter.printFunctionalType(getInputs().getTypes(), getOperation()->getResultTypes()); +} + +} // namespace toy +} // namespace mlir +MLIR_DEFINE_EXPLICIT_TYPE_ID(::mlir::toy::GenericCallOp) + +namespace mlir { +namespace toy { + +//===----------------------------------------------------------------------===// +// ::mlir::toy::MulOp definitions +//===----------------------------------------------------------------------===// + +namespace detail { +MulOpGenericAdaptorBase::MulOpGenericAdaptorBase(::mlir::DictionaryAttr attrs, const ::mlir::EmptyProperties &properties, ::mlir::RegionRange regions) : odsAttrs(attrs), odsRegions(regions) { if (odsAttrs) + odsOpName.emplace("toy.mul", odsAttrs.getContext()); +} + +MulOpGenericAdaptorBase::MulOpGenericAdaptorBase(MulOp op) : MulOpGenericAdaptorBase(op->getAttrDictionary(), op.getProperties(), op->getRegions()) {} + +std::pair MulOpGenericAdaptorBase::getODSOperandIndexAndLength(unsigned index, unsigned odsOperandsSize) { + return {index, 1}; +} + +::mlir::DictionaryAttr MulOpGenericAdaptorBase::getAttributes() { + return odsAttrs; +} + +} // namespace detail +MulOpAdaptor::MulOpAdaptor(MulOp op) : MulOpGenericAdaptor(op->getOperands(), op) {} + +::mlir::LogicalResult MulOpAdaptor::verify(::mlir::Location loc) { + return ::mlir::success(); +} + +std::pair MulOp::getODSOperandIndexAndLength(unsigned index) { + return {index, 1}; +} + +::mlir::Operation::operand_range MulOp::getODSOperands(unsigned index) { + auto valueRange = getODSOperandIndexAndLength(index); + return {std::next(getOperation()->operand_begin(), valueRange.first), + std::next(getOperation()->operand_begin(), valueRange.first + valueRange.second)}; +} + +::mlir::TypedValue<::mlir::TensorType> MulOp::getLhs() { + return ::llvm::cast<::mlir::TypedValue<::mlir::TensorType>>(*getODSOperands(0).begin()); +} + +::mlir::TypedValue<::mlir::TensorType> MulOp::getRhs() { + return ::llvm::cast<::mlir::TypedValue<::mlir::TensorType>>(*getODSOperands(1).begin()); +} + +::mlir::OpOperand &MulOp::getLhsMutable() { + auto range = getODSOperandIndexAndLength(0); + return getOperation()->getOpOperand(range.first); +} + +::mlir::OpOperand &MulOp::getRhsMutable() { + auto range = getODSOperandIndexAndLength(1); + return getOperation()->getOpOperand(range.first); +} + +std::pair MulOp::getODSResultIndexAndLength(unsigned index) { + return {index, 1}; +} + +::mlir::Operation::result_range MulOp::getODSResults(unsigned index) { + auto valueRange = getODSResultIndexAndLength(index); + return {std::next(getOperation()->result_begin(), valueRange.first), + std::next(getOperation()->result_begin(), valueRange.first + valueRange.second)}; +} + +void MulOp::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type resultType0, ::mlir::Value lhs, ::mlir::Value rhs) { + odsState.addOperands(lhs); + odsState.addOperands(rhs); + odsState.addTypes(resultType0); +} + +void MulOp::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::Value lhs, ::mlir::Value rhs) { + odsState.addOperands(lhs); + odsState.addOperands(rhs); + assert(resultTypes.size() == 1u && "mismatched number of results"); + odsState.addTypes(resultTypes); +} + +void MulOp::build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes) { + assert(operands.size() == 2u && "mismatched number of parameters"); + odsState.addOperands(operands); + odsState.addAttributes(attributes); + assert(resultTypes.size() == 1u && "mismatched number of return types"); + odsState.addTypes(resultTypes); +} + +::mlir::LogicalResult MulOp::verifyInvariantsImpl() { + { + unsigned index = 0; (void)index; + auto valueGroup0 = getODSOperands(0); + + for (auto v : valueGroup0) { + if (::mlir::failed(__mlir_ods_local_type_constraint_Ops0(*this, v.getType(), "operand", index++))) + return ::mlir::failure(); + } + auto valueGroup1 = getODSOperands(1); + + for (auto v : valueGroup1) { + if (::mlir::failed(__mlir_ods_local_type_constraint_Ops0(*this, v.getType(), "operand", index++))) + return ::mlir::failure(); + } + } + { + unsigned index = 0; (void)index; + auto valueGroup0 = getODSResults(0); + + for (auto v : valueGroup0) { + if (::mlir::failed(__mlir_ods_local_type_constraint_Ops0(*this, v.getType(), "result", index++))) + return ::mlir::failure(); + } + } + return ::mlir::success(); +} + +::mlir::LogicalResult MulOp::verifyInvariants() { + return verifyInvariantsImpl(); +} + +} // namespace toy +} // namespace mlir +MLIR_DEFINE_EXPLICIT_TYPE_ID(::mlir::toy::MulOp) + +namespace mlir { +namespace toy { + +//===----------------------------------------------------------------------===// +// ::mlir::toy::PrintOp definitions +//===----------------------------------------------------------------------===// + +namespace detail { +PrintOpGenericAdaptorBase::PrintOpGenericAdaptorBase(::mlir::DictionaryAttr attrs, const ::mlir::EmptyProperties &properties, ::mlir::RegionRange regions) : odsAttrs(attrs), odsRegions(regions) { if (odsAttrs) + odsOpName.emplace("toy.print", odsAttrs.getContext()); +} + +PrintOpGenericAdaptorBase::PrintOpGenericAdaptorBase(PrintOp op) : PrintOpGenericAdaptorBase(op->getAttrDictionary(), op.getProperties(), op->getRegions()) {} + +std::pair PrintOpGenericAdaptorBase::getODSOperandIndexAndLength(unsigned index, unsigned odsOperandsSize) { + return {index, 1}; +} + +::mlir::DictionaryAttr PrintOpGenericAdaptorBase::getAttributes() { + return odsAttrs; +} + +} // namespace detail +PrintOpAdaptor::PrintOpAdaptor(PrintOp op) : PrintOpGenericAdaptor(op->getOperands(), op) {} + +::mlir::LogicalResult PrintOpAdaptor::verify(::mlir::Location loc) { + return ::mlir::success(); +} + +std::pair PrintOp::getODSOperandIndexAndLength(unsigned index) { + return {index, 1}; +} + +::mlir::Operation::operand_range PrintOp::getODSOperands(unsigned index) { + auto valueRange = getODSOperandIndexAndLength(index); + return {std::next(getOperation()->operand_begin(), valueRange.first), + std::next(getOperation()->operand_begin(), valueRange.first + valueRange.second)}; +} + +::mlir::TypedValue<::mlir::TensorType> PrintOp::getInput() { + return ::llvm::cast<::mlir::TypedValue<::mlir::TensorType>>(*getODSOperands(0).begin()); +} + +::mlir::OpOperand &PrintOp::getInputMutable() { + auto range = getODSOperandIndexAndLength(0); + return getOperation()->getOpOperand(range.first); +} + +std::pair PrintOp::getODSResultIndexAndLength(unsigned index) { + return {index, 1}; +} + +::mlir::Operation::result_range PrintOp::getODSResults(unsigned index) { + auto valueRange = getODSResultIndexAndLength(index); + return {std::next(getOperation()->result_begin(), valueRange.first), + std::next(getOperation()->result_begin(), valueRange.first + valueRange.second)}; +} + +void PrintOp::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Value input) { + odsState.addOperands(input); +} + +void PrintOp::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::Value input) { + odsState.addOperands(input); + assert(resultTypes.size() == 0u && "mismatched number of results"); + odsState.addTypes(resultTypes); +} + +void PrintOp::build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes) { + assert(operands.size() == 1u && "mismatched number of parameters"); + odsState.addOperands(operands); + odsState.addAttributes(attributes); + assert(resultTypes.size() == 0u && "mismatched number of return types"); + odsState.addTypes(resultTypes); +} + +::mlir::LogicalResult PrintOp::verifyInvariantsImpl() { + { + unsigned index = 0; (void)index; + auto valueGroup0 = getODSOperands(0); + + for (auto v : valueGroup0) { + if (::mlir::failed(__mlir_ods_local_type_constraint_Ops0(*this, v.getType(), "operand", index++))) + return ::mlir::failure(); + } + } + return ::mlir::success(); +} + +::mlir::LogicalResult PrintOp::verifyInvariants() { + return verifyInvariantsImpl(); +} + +::mlir::ParseResult PrintOp::parse(::mlir::OpAsmParser &parser, ::mlir::OperationState &result) { + ::mlir::OpAsmParser::UnresolvedOperand inputRawOperands[1]; + ::llvm::ArrayRef<::mlir::OpAsmParser::UnresolvedOperand> inputOperands(inputRawOperands); ::llvm::SMLoc inputOperandsLoc; + (void)inputOperandsLoc; + ::mlir::Type inputRawTypes[1]; + ::llvm::ArrayRef<::mlir::Type> inputTypes(inputRawTypes); + + inputOperandsLoc = parser.getCurrentLocation(); + if (parser.parseOperand(inputRawOperands[0])) + return ::mlir::failure(); + { + auto loc = parser.getCurrentLocation();(void)loc; + if (parser.parseOptionalAttrDict(result.attributes)) + return ::mlir::failure(); + } + if (parser.parseColon()) + return ::mlir::failure(); + + { + ::mlir::TensorType type; + if (parser.parseCustomTypeWithFallback(type)) + return ::mlir::failure(); + inputRawTypes[0] = type; + } + if (parser.resolveOperands(inputOperands, inputTypes, inputOperandsLoc, result.operands)) + return ::mlir::failure(); + return ::mlir::success(); +} + +void PrintOp::print(::mlir::OpAsmPrinter &_odsPrinter) { + _odsPrinter << ' '; + _odsPrinter << getInput(); + ::llvm::SmallVector<::llvm::StringRef, 2> elidedAttrs; + _odsPrinter.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs); + _odsPrinter << ' ' << ":"; + _odsPrinter << ' '; + { + auto type = getInput().getType(); + if (auto validType = ::llvm::dyn_cast<::mlir::TensorType>(type)) + _odsPrinter.printStrippedAttrOrType(validType); + else + _odsPrinter << type; + } +} + +} // namespace toy +} // namespace mlir +MLIR_DEFINE_EXPLICIT_TYPE_ID(::mlir::toy::PrintOp) + +namespace mlir { +namespace toy { + +//===----------------------------------------------------------------------===// +// ::mlir::toy::ReshapeOp definitions +//===----------------------------------------------------------------------===// + +namespace detail { +ReshapeOpGenericAdaptorBase::ReshapeOpGenericAdaptorBase(::mlir::DictionaryAttr attrs, const ::mlir::EmptyProperties &properties, ::mlir::RegionRange regions) : odsAttrs(attrs), odsRegions(regions) { if (odsAttrs) + odsOpName.emplace("toy.reshape", odsAttrs.getContext()); +} + +ReshapeOpGenericAdaptorBase::ReshapeOpGenericAdaptorBase(ReshapeOp op) : ReshapeOpGenericAdaptorBase(op->getAttrDictionary(), op.getProperties(), op->getRegions()) {} + +std::pair ReshapeOpGenericAdaptorBase::getODSOperandIndexAndLength(unsigned index, unsigned odsOperandsSize) { + return {index, 1}; +} + +::mlir::DictionaryAttr ReshapeOpGenericAdaptorBase::getAttributes() { + return odsAttrs; +} + +} // namespace detail +ReshapeOpAdaptor::ReshapeOpAdaptor(ReshapeOp op) : ReshapeOpGenericAdaptor(op->getOperands(), op) {} + +::mlir::LogicalResult ReshapeOpAdaptor::verify(::mlir::Location loc) { + return ::mlir::success(); +} + +std::pair ReshapeOp::getODSOperandIndexAndLength(unsigned index) { + return {index, 1}; +} + +::mlir::Operation::operand_range ReshapeOp::getODSOperands(unsigned index) { + auto valueRange = getODSOperandIndexAndLength(index); + return {std::next(getOperation()->operand_begin(), valueRange.first), + std::next(getOperation()->operand_begin(), valueRange.first + valueRange.second)}; +} + +::mlir::TypedValue<::mlir::TensorType> ReshapeOp::getInput() { + return ::llvm::cast<::mlir::TypedValue<::mlir::TensorType>>(*getODSOperands(0).begin()); +} + +::mlir::OpOperand &ReshapeOp::getInputMutable() { + auto range = getODSOperandIndexAndLength(0); + return getOperation()->getOpOperand(range.first); +} + +std::pair ReshapeOp::getODSResultIndexAndLength(unsigned index) { + return {index, 1}; +} + +::mlir::Operation::result_range ReshapeOp::getODSResults(unsigned index) { + auto valueRange = getODSResultIndexAndLength(index); + return {std::next(getOperation()->result_begin(), valueRange.first), + std::next(getOperation()->result_begin(), valueRange.first + valueRange.second)}; +} + +void ReshapeOp::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type resultType0, ::mlir::Value input) { + odsState.addOperands(input); + odsState.addTypes(resultType0); +} + +void ReshapeOp::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::Value input) { + odsState.addOperands(input); + assert(resultTypes.size() == 1u && "mismatched number of results"); + odsState.addTypes(resultTypes); +} + +void ReshapeOp::build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes) { + assert(operands.size() == 1u && "mismatched number of parameters"); + odsState.addOperands(operands); + odsState.addAttributes(attributes); + assert(resultTypes.size() == 1u && "mismatched number of return types"); + odsState.addTypes(resultTypes); +} + +::mlir::LogicalResult ReshapeOp::verifyInvariantsImpl() { + { + unsigned index = 0; (void)index; + auto valueGroup0 = getODSOperands(0); + + for (auto v : valueGroup0) { + if (::mlir::failed(__mlir_ods_local_type_constraint_Ops0(*this, v.getType(), "operand", index++))) + return ::mlir::failure(); + } + } + { + unsigned index = 0; (void)index; + auto valueGroup0 = getODSResults(0); + + for (auto v : valueGroup0) { + if (::mlir::failed(__mlir_ods_local_type_constraint_Ops2(*this, v.getType(), "result", index++))) + return ::mlir::failure(); + } + } + return ::mlir::success(); +} + +::mlir::LogicalResult ReshapeOp::verifyInvariants() { + return verifyInvariantsImpl(); +} + +::mlir::ParseResult ReshapeOp::parse(::mlir::OpAsmParser &parser, ::mlir::OperationState &result) { + ::mlir::OpAsmParser::UnresolvedOperand inputRawOperands[1]; + ::llvm::ArrayRef<::mlir::OpAsmParser::UnresolvedOperand> inputOperands(inputRawOperands); ::llvm::SMLoc inputOperandsLoc; + (void)inputOperandsLoc; + ::mlir::Type inputRawTypes[1]; + ::llvm::ArrayRef<::mlir::Type> inputTypes(inputRawTypes); + ::llvm::SmallVector<::mlir::Type, 1> allResultTypes; + if (parser.parseLParen()) + return ::mlir::failure(); + + inputOperandsLoc = parser.getCurrentLocation(); + if (parser.parseOperand(inputRawOperands[0])) + return ::mlir::failure(); + if (parser.parseColon()) + return ::mlir::failure(); + + { + ::mlir::TensorType type; + if (parser.parseCustomTypeWithFallback(type)) + return ::mlir::failure(); + inputRawTypes[0] = type; + } + if (parser.parseRParen()) + return ::mlir::failure(); + { + auto loc = parser.getCurrentLocation();(void)loc; + if (parser.parseOptionalAttrDict(result.attributes)) + return ::mlir::failure(); + } + if (parser.parseKeyword("to")) + return ::mlir::failure(); + + if (parser.parseTypeList(allResultTypes)) + return ::mlir::failure(); + result.addTypes(allResultTypes); + if (parser.resolveOperands(inputOperands, inputTypes, inputOperandsLoc, result.operands)) + return ::mlir::failure(); + return ::mlir::success(); +} + +void ReshapeOp::print(::mlir::OpAsmPrinter &_odsPrinter) { + _odsPrinter << "("; + _odsPrinter << getInput(); + _odsPrinter << ' ' << ":"; + _odsPrinter << ' '; + { + auto type = getInput().getType(); + if (auto validType = ::llvm::dyn_cast<::mlir::TensorType>(type)) + _odsPrinter.printStrippedAttrOrType(validType); + else + _odsPrinter << type; + } + _odsPrinter << ")"; + ::llvm::SmallVector<::llvm::StringRef, 2> elidedAttrs; + _odsPrinter.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs); + _odsPrinter << ' ' << "to"; + _odsPrinter << ' '; + _odsPrinter << getOperation()->getResultTypes(); +} + +} // namespace toy +} // namespace mlir +MLIR_DEFINE_EXPLICIT_TYPE_ID(::mlir::toy::ReshapeOp) + +namespace mlir { +namespace toy { + +//===----------------------------------------------------------------------===// +// ::mlir::toy::ReturnOp definitions +//===----------------------------------------------------------------------===// + +namespace detail { +ReturnOpGenericAdaptorBase::ReturnOpGenericAdaptorBase(::mlir::DictionaryAttr attrs, const ::mlir::EmptyProperties &properties, ::mlir::RegionRange regions) : odsAttrs(attrs), odsRegions(regions) { if (odsAttrs) + odsOpName.emplace("toy.return", odsAttrs.getContext()); +} + +ReturnOpGenericAdaptorBase::ReturnOpGenericAdaptorBase(ReturnOp op) : ReturnOpGenericAdaptorBase(op->getAttrDictionary(), op.getProperties(), op->getRegions()) {} + +std::pair ReturnOpGenericAdaptorBase::getODSOperandIndexAndLength(unsigned index, unsigned odsOperandsSize) { + bool isVariadic[] = {true}; + int prevVariadicCount = 0; + for (unsigned i = 0; i < index; ++i) + if (isVariadic[i]) ++prevVariadicCount; + + // Calculate how many dynamic values a static variadic operand corresponds to. + // This assumes all static variadic operands have the same dynamic value count. + int variadicSize = (odsOperandsSize - 0) / 1; + // `index` passed in as the parameter is the static index which counts each + // operand (variadic or not) as size 1. So here for each previous static variadic + // operand, we need to offset by (variadicSize - 1) to get where the dynamic + // value pack for this static operand starts. + int start = index + (variadicSize - 1) * prevVariadicCount; + int size = isVariadic[index] ? variadicSize : 1; + return {start, size}; +} + +::mlir::DictionaryAttr ReturnOpGenericAdaptorBase::getAttributes() { + return odsAttrs; +} + +} // namespace detail +ReturnOpAdaptor::ReturnOpAdaptor(ReturnOp op) : ReturnOpGenericAdaptor(op->getOperands(), op) {} + +::mlir::LogicalResult ReturnOpAdaptor::verify(::mlir::Location loc) { + return ::mlir::success(); +} + +std::pair ReturnOp::getODSOperandIndexAndLength(unsigned index) { + bool isVariadic[] = {true}; + int prevVariadicCount = 0; + for (unsigned i = 0; i < index; ++i) + if (isVariadic[i]) ++prevVariadicCount; + + // Calculate how many dynamic values a static variadic operand corresponds to. + // This assumes all static variadic operands have the same dynamic value count. + int variadicSize = (getOperation()->getNumOperands() - 0) / 1; + // `index` passed in as the parameter is the static index which counts each + // operand (variadic or not) as size 1. So here for each previous static variadic + // operand, we need to offset by (variadicSize - 1) to get where the dynamic + // value pack for this static operand starts. + int start = index + (variadicSize - 1) * prevVariadicCount; + int size = isVariadic[index] ? variadicSize : 1; + return {start, size}; +} + +::mlir::Operation::operand_range ReturnOp::getODSOperands(unsigned index) { + auto valueRange = getODSOperandIndexAndLength(index); + return {std::next(getOperation()->operand_begin(), valueRange.first), + std::next(getOperation()->operand_begin(), valueRange.first + valueRange.second)}; +} + +::mlir::Operation::operand_range ReturnOp::getInput() { + return getODSOperands(0); +} + +::mlir::MutableOperandRange ReturnOp::getInputMutable() { + auto range = getODSOperandIndexAndLength(0); + auto mutableRange = ::mlir::MutableOperandRange(getOperation(), range.first, range.second); + return mutableRange; +} + +std::pair ReturnOp::getODSResultIndexAndLength(unsigned index) { + return {index, 1}; +} + +::mlir::Operation::result_range ReturnOp::getODSResults(unsigned index) { + auto valueRange = getODSResultIndexAndLength(index); + return {std::next(getOperation()->result_begin(), valueRange.first), + std::next(getOperation()->result_begin(), valueRange.first + valueRange.second)}; +} + +void ReturnOp::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState) { + build(odsBuilder, odsState, std::nullopt); +} + +void ReturnOp::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::ValueRange input) { + odsState.addOperands(input); +} + +void ReturnOp::build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes) { + odsState.addOperands(operands); + odsState.addAttributes(attributes); + assert(resultTypes.size() == 0u && "mismatched number of return types"); + odsState.addTypes(resultTypes); +} + +::mlir::LogicalResult ReturnOp::verifyInvariantsImpl() { + { + unsigned index = 0; (void)index; + auto valueGroup0 = getODSOperands(0); + + for (auto v : valueGroup0) { + if (::mlir::failed(__mlir_ods_local_type_constraint_Ops1(*this, v.getType(), "operand", index++))) + return ::mlir::failure(); + } + } + return ::mlir::success(); +} + +::mlir::LogicalResult ReturnOp::verifyInvariants() { + if(::mlir::succeeded(verifyInvariantsImpl()) && ::mlir::succeeded(verify())) + return ::mlir::success(); + return ::mlir::failure(); +} + +::mlir::ParseResult ReturnOp::parse(::mlir::OpAsmParser &parser, ::mlir::OperationState &result) { + ::llvm::SmallVector<::mlir::OpAsmParser::UnresolvedOperand, 4> inputOperands; + ::llvm::SMLoc inputOperandsLoc; + (void)inputOperandsLoc; + ::llvm::SmallVector<::mlir::Type, 1> inputTypes; + + inputOperandsLoc = parser.getCurrentLocation(); + if (parser.parseOperandList(inputOperands)) + return ::mlir::failure(); + if (!inputOperands.empty()) { + if (parser.parseColon()) + return ::mlir::failure(); + + if (parser.parseTypeList(inputTypes)) + return ::mlir::failure(); + } + { + auto loc = parser.getCurrentLocation();(void)loc; + if (parser.parseOptionalAttrDict(result.attributes)) + return ::mlir::failure(); + } + if (parser.resolveOperands(inputOperands, inputTypes, inputOperandsLoc, result.operands)) + return ::mlir::failure(); + return ::mlir::success(); +} + +void ReturnOp::print(::mlir::OpAsmPrinter &_odsPrinter) { + if (!getInput().empty()) { + _odsPrinter << ' '; + _odsPrinter << getInput(); + _odsPrinter << ' ' << ":"; + _odsPrinter << ' '; + _odsPrinter << getInput().getTypes(); + } + ::llvm::SmallVector<::llvm::StringRef, 2> elidedAttrs; + _odsPrinter.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs); +} + +void ReturnOp::getEffects(::llvm::SmallVectorImpl<::mlir::SideEffects::EffectInstance<::mlir::MemoryEffects::Effect>> &effects) { +} + +} // namespace toy +} // namespace mlir +MLIR_DEFINE_EXPLICIT_TYPE_ID(::mlir::toy::ReturnOp) + +namespace mlir { +namespace toy { + +//===----------------------------------------------------------------------===// +// ::mlir::toy::TransposeOp definitions +//===----------------------------------------------------------------------===// + +namespace detail { +TransposeOpGenericAdaptorBase::TransposeOpGenericAdaptorBase(::mlir::DictionaryAttr attrs, const ::mlir::EmptyProperties &properties, ::mlir::RegionRange regions) : odsAttrs(attrs), odsRegions(regions) { if (odsAttrs) + odsOpName.emplace("toy.transpose", odsAttrs.getContext()); +} + +TransposeOpGenericAdaptorBase::TransposeOpGenericAdaptorBase(TransposeOp op) : TransposeOpGenericAdaptorBase(op->getAttrDictionary(), op.getProperties(), op->getRegions()) {} + +std::pair TransposeOpGenericAdaptorBase::getODSOperandIndexAndLength(unsigned index, unsigned odsOperandsSize) { + return {index, 1}; +} + +::mlir::DictionaryAttr TransposeOpGenericAdaptorBase::getAttributes() { + return odsAttrs; +} + +} // namespace detail +TransposeOpAdaptor::TransposeOpAdaptor(TransposeOp op) : TransposeOpGenericAdaptor(op->getOperands(), op) {} + +::mlir::LogicalResult TransposeOpAdaptor::verify(::mlir::Location loc) { + return ::mlir::success(); +} + +std::pair TransposeOp::getODSOperandIndexAndLength(unsigned index) { + return {index, 1}; +} + +::mlir::Operation::operand_range TransposeOp::getODSOperands(unsigned index) { + auto valueRange = getODSOperandIndexAndLength(index); + return {std::next(getOperation()->operand_begin(), valueRange.first), + std::next(getOperation()->operand_begin(), valueRange.first + valueRange.second)}; +} + +::mlir::TypedValue<::mlir::TensorType> TransposeOp::getInput() { + return ::llvm::cast<::mlir::TypedValue<::mlir::TensorType>>(*getODSOperands(0).begin()); +} + +::mlir::OpOperand &TransposeOp::getInputMutable() { + auto range = getODSOperandIndexAndLength(0); + return getOperation()->getOpOperand(range.first); +} + +std::pair TransposeOp::getODSResultIndexAndLength(unsigned index) { + return {index, 1}; +} + +::mlir::Operation::result_range TransposeOp::getODSResults(unsigned index) { + auto valueRange = getODSResultIndexAndLength(index); + return {std::next(getOperation()->result_begin(), valueRange.first), + std::next(getOperation()->result_begin(), valueRange.first + valueRange.second)}; +} + +void TransposeOp::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type resultType0, ::mlir::Value input) { + odsState.addOperands(input); + odsState.addTypes(resultType0); +} + +void TransposeOp::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::Value input) { + odsState.addOperands(input); + assert(resultTypes.size() == 1u && "mismatched number of results"); + odsState.addTypes(resultTypes); +} + +void TransposeOp::build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes) { + assert(operands.size() == 1u && "mismatched number of parameters"); + odsState.addOperands(operands); + odsState.addAttributes(attributes); + assert(resultTypes.size() == 1u && "mismatched number of return types"); + odsState.addTypes(resultTypes); +} + +::mlir::LogicalResult TransposeOp::verifyInvariantsImpl() { + { + unsigned index = 0; (void)index; + auto valueGroup0 = getODSOperands(0); + + for (auto v : valueGroup0) { + if (::mlir::failed(__mlir_ods_local_type_constraint_Ops0(*this, v.getType(), "operand", index++))) + return ::mlir::failure(); + } + } + { + unsigned index = 0; (void)index; + auto valueGroup0 = getODSResults(0); + + for (auto v : valueGroup0) { + if (::mlir::failed(__mlir_ods_local_type_constraint_Ops0(*this, v.getType(), "result", index++))) + return ::mlir::failure(); + } + } + return ::mlir::success(); +} + +::mlir::LogicalResult TransposeOp::verifyInvariants() { + if(::mlir::succeeded(verifyInvariantsImpl()) && ::mlir::succeeded(verify())) + return ::mlir::success(); + return ::mlir::failure(); +} + +::mlir::ParseResult TransposeOp::parse(::mlir::OpAsmParser &parser, ::mlir::OperationState &result) { + ::mlir::OpAsmParser::UnresolvedOperand inputRawOperands[1]; + ::llvm::ArrayRef<::mlir::OpAsmParser::UnresolvedOperand> inputOperands(inputRawOperands); ::llvm::SMLoc inputOperandsLoc; + (void)inputOperandsLoc; + ::mlir::Type inputRawTypes[1]; + ::llvm::ArrayRef<::mlir::Type> inputTypes(inputRawTypes); + ::llvm::SmallVector<::mlir::Type, 1> allResultTypes; + if (parser.parseLParen()) + return ::mlir::failure(); + + inputOperandsLoc = parser.getCurrentLocation(); + if (parser.parseOperand(inputRawOperands[0])) + return ::mlir::failure(); + if (parser.parseColon()) + return ::mlir::failure(); + + { + ::mlir::TensorType type; + if (parser.parseCustomTypeWithFallback(type)) + return ::mlir::failure(); + inputRawTypes[0] = type; + } + if (parser.parseRParen()) + return ::mlir::failure(); + { + auto loc = parser.getCurrentLocation();(void)loc; + if (parser.parseOptionalAttrDict(result.attributes)) + return ::mlir::failure(); + } + if (parser.parseKeyword("to")) + return ::mlir::failure(); + + if (parser.parseTypeList(allResultTypes)) + return ::mlir::failure(); + result.addTypes(allResultTypes); + if (parser.resolveOperands(inputOperands, inputTypes, inputOperandsLoc, result.operands)) + return ::mlir::failure(); + return ::mlir::success(); +} + +void TransposeOp::print(::mlir::OpAsmPrinter &_odsPrinter) { + _odsPrinter << "("; + _odsPrinter << getInput(); + _odsPrinter << ' ' << ":"; + _odsPrinter << ' '; + { + auto type = getInput().getType(); + if (auto validType = ::llvm::dyn_cast<::mlir::TensorType>(type)) + _odsPrinter.printStrippedAttrOrType(validType); + else + _odsPrinter << type; + } + _odsPrinter << ")"; + ::llvm::SmallVector<::llvm::StringRef, 2> elidedAttrs; + _odsPrinter.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs); + _odsPrinter << ' ' << "to"; + _odsPrinter << ' '; + _odsPrinter << getOperation()->getResultTypes(); +} + +} // namespace toy +} // namespace mlir +MLIR_DEFINE_EXPLICIT_TYPE_ID(::mlir::toy::TransposeOp) + + +#endif // GET_OP_CLASSES + diff --git a/Ch2/include/toy/Ops.h.inc b/Ch2/include/toy/Ops.h.inc new file mode 100644 index 0000000..ef289c7 --- /dev/null +++ b/Ch2/include/toy/Ops.h.inc @@ -0,0 +1,1240 @@ +/*===- TableGen'erated file -------------------------------------*- C++ -*-===*\ +|* *| +|* Op Declarations *| +|* *| +|* Automatically generated file, do not edit! *| +|* From: Ops.td *| +|* *| +\*===----------------------------------------------------------------------===*/ + +#if defined(GET_OP_CLASSES) || defined(GET_OP_FWD_DEFINES) +#undef GET_OP_FWD_DEFINES +namespace mlir { +namespace toy { +class AddOp; +} // namespace toy +} // namespace mlir +namespace mlir { +namespace toy { +class ConstantOp; +} // namespace toy +} // namespace mlir +namespace mlir { +namespace toy { +class FuncOp; +} // namespace toy +} // namespace mlir +namespace mlir { +namespace toy { +class GenericCallOp; +} // namespace toy +} // namespace mlir +namespace mlir { +namespace toy { +class MulOp; +} // namespace toy +} // namespace mlir +namespace mlir { +namespace toy { +class PrintOp; +} // namespace toy +} // namespace mlir +namespace mlir { +namespace toy { +class ReshapeOp; +} // namespace toy +} // namespace mlir +namespace mlir { +namespace toy { +class ReturnOp; +} // namespace toy +} // namespace mlir +namespace mlir { +namespace toy { +class TransposeOp; +} // namespace toy +} // namespace mlir +#endif + +#ifdef GET_OP_CLASSES +#undef GET_OP_CLASSES + + +//===----------------------------------------------------------------------===// +// Local Utility Method Definitions +//===----------------------------------------------------------------------===// + +namespace mlir { +namespace toy { + +//===----------------------------------------------------------------------===// +// ::mlir::toy::AddOp declarations +//===----------------------------------------------------------------------===// + +namespace detail { +class AddOpGenericAdaptorBase { +public: +protected: + ::mlir::DictionaryAttr odsAttrs; + ::std::optional<::mlir::OperationName> odsOpName; + ::mlir::RegionRange odsRegions; +public: + AddOpGenericAdaptorBase(::mlir::DictionaryAttr attrs = nullptr, const ::mlir::EmptyProperties &properties = {}, ::mlir::RegionRange regions = {}); + + AddOpGenericAdaptorBase(AddOp op); + + std::pair getODSOperandIndexAndLength(unsigned index, unsigned odsOperandsSize); + ::mlir::DictionaryAttr getAttributes(); +}; +} // namespace detail +template +class AddOpGenericAdaptor : public detail::AddOpGenericAdaptorBase { + using ValueT = ::llvm::detail::ValueOfRange; + using Base = detail::AddOpGenericAdaptorBase; +public: + AddOpGenericAdaptor(RangeT values, ::mlir::DictionaryAttr attrs = nullptr, const ::mlir::EmptyProperties &properties = {}, ::mlir::RegionRange regions = {}) : Base(attrs, properties, regions), odsOperands(values) {} + + AddOpGenericAdaptor(RangeT values, ::mlir::DictionaryAttr attrs, ::mlir::OpaqueProperties properties, ::mlir::RegionRange regions = {}) : AddOpGenericAdaptor(values, attrs, (properties ? *properties.as<::mlir::EmptyProperties *>() : ::mlir::EmptyProperties{}), regions) {} + + template >> + AddOpGenericAdaptor(RangeT values, LateInst op) : Base(op), odsOperands(values) {} + + std::pair getODSOperandIndexAndLength(unsigned index) { + return Base::getODSOperandIndexAndLength(index, odsOperands.size()); + } + + RangeT getODSOperands(unsigned index) { + auto valueRange = getODSOperandIndexAndLength(index); + return {std::next(odsOperands.begin(), valueRange.first), + std::next(odsOperands.begin(), valueRange.first + valueRange.second)}; + } + + ValueT getLhs() { + return (*getODSOperands(0).begin()); + } + + ValueT getRhs() { + return (*getODSOperands(1).begin()); + } + + RangeT getOperands() { + return odsOperands; + } + +private: + RangeT odsOperands; +}; +class AddOpAdaptor : public AddOpGenericAdaptor<::mlir::ValueRange> { +public: + using AddOpGenericAdaptor::AddOpGenericAdaptor; + AddOpAdaptor(AddOp op); + + ::mlir::LogicalResult verify(::mlir::Location loc); +}; +class AddOp : public ::mlir::Op::Impl, ::mlir::OpTrait::ZeroSuccessors, ::mlir::OpTrait::NOperands<2>::Impl, ::mlir::OpTrait::OpInvariants> { +public: + using Op::Op; + using Op::print; + using Adaptor = AddOpAdaptor; + template + using GenericAdaptor = AddOpGenericAdaptor; + using FoldAdaptor = GenericAdaptor<::llvm::ArrayRef<::mlir::Attribute>>; + static ::llvm::ArrayRef<::llvm::StringRef> getAttributeNames() { + return {}; + } + + static constexpr ::llvm::StringLiteral getOperationName() { + return ::llvm::StringLiteral("toy.add"); + } + + std::pair getODSOperandIndexAndLength(unsigned index); + ::mlir::Operation::operand_range getODSOperands(unsigned index); + ::mlir::TypedValue<::mlir::TensorType> getLhs(); + ::mlir::TypedValue<::mlir::TensorType> getRhs(); + ::mlir::OpOperand &getLhsMutable(); + ::mlir::OpOperand &getRhsMutable(); + std::pair getODSResultIndexAndLength(unsigned index); + ::mlir::Operation::result_range getODSResults(unsigned index); + static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, Value lhs, Value rhs); + static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type resultType0, ::mlir::Value lhs, ::mlir::Value rhs); + static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::Value lhs, ::mlir::Value rhs); + static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {}); + static ::mlir::ParseResult parse(::mlir::OpAsmParser &parser, ::mlir::OperationState &result); + void print(::mlir::OpAsmPrinter &p); + ::mlir::LogicalResult verifyInvariantsImpl(); + ::mlir::LogicalResult verifyInvariants(); +public: +}; +} // namespace toy +} // namespace mlir +MLIR_DECLARE_EXPLICIT_TYPE_ID(::mlir::toy::AddOp) + +namespace mlir { +namespace toy { + +//===----------------------------------------------------------------------===// +// ::mlir::toy::ConstantOp declarations +//===----------------------------------------------------------------------===// + +namespace detail { +class ConstantOpGenericAdaptorBase { +public: + struct Properties { + using valueTy = ::mlir::DenseElementsAttr; + valueTy value; + + auto getValue() { + auto &propStorage = this->value; + return ::llvm::cast<::mlir::DenseElementsAttr>(propStorage); + } + void setValue(const ::mlir::DenseElementsAttr &propValue) { + this->value = propValue; + } + bool operator==(const Properties &rhs) const { + return + rhs.value == this->value && + true; + } + bool operator!=(const Properties &rhs) const { + return !(*this == rhs); + } + }; +protected: + ::mlir::DictionaryAttr odsAttrs; + ::std::optional<::mlir::OperationName> odsOpName; + Properties properties; + ::mlir::RegionRange odsRegions; +public: + ConstantOpGenericAdaptorBase(::mlir::DictionaryAttr attrs = nullptr, const Properties &properties = {}, ::mlir::RegionRange regions = {}); + + ConstantOpGenericAdaptorBase(ConstantOp op); + + std::pair getODSOperandIndexAndLength(unsigned index, unsigned odsOperandsSize); + const Properties &getProperties() { + return properties; + } + + ::mlir::DictionaryAttr getAttributes(); + ::mlir::DenseElementsAttr getValueAttr(); + ::mlir::DenseElementsAttr getValue(); +}; +} // namespace detail +template +class ConstantOpGenericAdaptor : public detail::ConstantOpGenericAdaptorBase { + using ValueT = ::llvm::detail::ValueOfRange; + using Base = detail::ConstantOpGenericAdaptorBase; +public: + ConstantOpGenericAdaptor(RangeT values, ::mlir::DictionaryAttr attrs = nullptr, const Properties &properties = {}, ::mlir::RegionRange regions = {}) : Base(attrs, properties, regions), odsOperands(values) {} + + ConstantOpGenericAdaptor(RangeT values, ::mlir::DictionaryAttr attrs, ::mlir::OpaqueProperties properties, ::mlir::RegionRange regions = {}) : ConstantOpGenericAdaptor(values, attrs, (properties ? *properties.as() : Properties{}), regions) {} + + template >> + ConstantOpGenericAdaptor(RangeT values, LateInst op) : Base(op), odsOperands(values) {} + + std::pair getODSOperandIndexAndLength(unsigned index) { + return Base::getODSOperandIndexAndLength(index, odsOperands.size()); + } + + RangeT getODSOperands(unsigned index) { + auto valueRange = getODSOperandIndexAndLength(index); + return {std::next(odsOperands.begin(), valueRange.first), + std::next(odsOperands.begin(), valueRange.first + valueRange.second)}; + } + + RangeT getOperands() { + return odsOperands; + } + +private: + RangeT odsOperands; +}; +class ConstantOpAdaptor : public ConstantOpGenericAdaptor<::mlir::ValueRange> { +public: + using ConstantOpGenericAdaptor::ConstantOpGenericAdaptor; + ConstantOpAdaptor(ConstantOp op); + + ::mlir::LogicalResult verify(::mlir::Location loc); +}; +class ConstantOp : public ::mlir::Op::Impl, ::mlir::OpTrait::ZeroSuccessors, ::mlir::OpTrait::ZeroOperands, ::mlir::OpTrait::OpInvariants, ::mlir::BytecodeOpInterface::Trait, ::mlir::ConditionallySpeculatable::Trait, ::mlir::OpTrait::AlwaysSpeculatableImplTrait, ::mlir::MemoryEffectOpInterface::Trait> { +public: + using Op::Op; + using Op::print; + using Adaptor = ConstantOpAdaptor; + template + using GenericAdaptor = ConstantOpGenericAdaptor; + using FoldAdaptor = GenericAdaptor<::llvm::ArrayRef<::mlir::Attribute>>; + using Properties = FoldAdaptor::Properties; + static ::llvm::ArrayRef<::llvm::StringRef> getAttributeNames() { + static ::llvm::StringRef attrNames[] = {::llvm::StringRef("value")}; + return ::llvm::ArrayRef(attrNames); + } + + ::mlir::StringAttr getValueAttrName() { + return getAttributeNameForIndex(0); + } + + static ::mlir::StringAttr getValueAttrName(::mlir::OperationName name) { + return getAttributeNameForIndex(name, 0); + } + + static constexpr ::llvm::StringLiteral getOperationName() { + return ::llvm::StringLiteral("toy.constant"); + } + + std::pair getODSOperandIndexAndLength(unsigned index); + ::mlir::Operation::operand_range getODSOperands(unsigned index); + std::pair getODSResultIndexAndLength(unsigned index); + ::mlir::Operation::result_range getODSResults(unsigned index); + static ::mlir::LogicalResult setPropertiesFromAttr(Properties &prop, ::mlir::Attribute attr, ::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError); + static ::mlir::Attribute getPropertiesAsAttr(::mlir::MLIRContext *ctx, const Properties &prop); + static llvm::hash_code computePropertiesHash(const Properties &prop); + static std::optional getInherentAttr(::mlir::MLIRContext *ctx, const Properties &prop, llvm::StringRef name); + static void setInherentAttr(Properties &prop, llvm::StringRef name, mlir::Attribute value); + static void populateInherentAttrs(::mlir::MLIRContext *ctx, const Properties &prop, ::mlir::NamedAttrList &attrs); + static ::mlir::LogicalResult verifyInherentAttrs(::mlir::OperationName opName, ::mlir::NamedAttrList &attrs, llvm::function_ref<::mlir::InFlightDiagnostic()> emitError); + static ::mlir::LogicalResult readProperties(::mlir::DialectBytecodeReader &reader, ::mlir::OperationState &state); + void writeProperties(::mlir::DialectBytecodeWriter &writer); + ::mlir::DenseElementsAttr getValueAttr(); + ::mlir::DenseElementsAttr getValue(); + void setValueAttr(::mlir::DenseElementsAttr attr); + static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, DenseElementsAttr value); + static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, double value); + static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type resultType0, ::mlir::DenseElementsAttr value); + static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::DenseElementsAttr value); + static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {}); + static ::mlir::ParseResult parse(::mlir::OpAsmParser &parser, ::mlir::OperationState &result); + void print(::mlir::OpAsmPrinter &p); + ::mlir::LogicalResult verifyInvariantsImpl(); + ::mlir::LogicalResult verifyInvariants(); + ::mlir::LogicalResult verify(); + void getEffects(::llvm::SmallVectorImpl<::mlir::SideEffects::EffectInstance<::mlir::MemoryEffects::Effect>> &effects); +private: + ::mlir::StringAttr getAttributeNameForIndex(unsigned index) { + return getAttributeNameForIndex((*this)->getName(), index); + } + + static ::mlir::StringAttr getAttributeNameForIndex(::mlir::OperationName name, unsigned index) { + assert(index < 1 && "invalid attribute index"); + assert(name.getStringRef() == getOperationName() && "invalid operation name"); + assert(name.isRegistered() && "Operation isn't registered, missing a " + "dependent dialect loading?"); + return name.getAttributeNames()[index]; + } + +public: +}; +} // namespace toy +} // namespace mlir +MLIR_DECLARE_EXPLICIT_TYPE_ID(::mlir::toy::ConstantOp) + +namespace mlir { +namespace toy { + +//===----------------------------------------------------------------------===// +// ::mlir::toy::FuncOp declarations +//===----------------------------------------------------------------------===// + +namespace detail { +class FuncOpGenericAdaptorBase { +public: + struct Properties { + using arg_attrsTy = ::mlir::ArrayAttr; + arg_attrsTy arg_attrs; + + auto getArgAttrs() { + auto &propStorage = this->arg_attrs; + return ::llvm::dyn_cast_or_null<::mlir::ArrayAttr>(propStorage); + } + void setArgAttrs(const ::mlir::ArrayAttr &propValue) { + this->arg_attrs = propValue; + } + using function_typeTy = ::mlir::TypeAttr; + function_typeTy function_type; + + auto getFunctionType() { + auto &propStorage = this->function_type; + return ::llvm::cast<::mlir::TypeAttr>(propStorage); + } + void setFunctionType(const ::mlir::TypeAttr &propValue) { + this->function_type = propValue; + } + using res_attrsTy = ::mlir::ArrayAttr; + res_attrsTy res_attrs; + + auto getResAttrs() { + auto &propStorage = this->res_attrs; + return ::llvm::dyn_cast_or_null<::mlir::ArrayAttr>(propStorage); + } + void setResAttrs(const ::mlir::ArrayAttr &propValue) { + this->res_attrs = propValue; + } + using sym_nameTy = ::mlir::StringAttr; + sym_nameTy sym_name; + + auto getSymName() { + auto &propStorage = this->sym_name; + return ::llvm::cast<::mlir::StringAttr>(propStorage); + } + void setSymName(const ::mlir::StringAttr &propValue) { + this->sym_name = propValue; + } + bool operator==(const Properties &rhs) const { + return + rhs.arg_attrs == this->arg_attrs && + rhs.function_type == this->function_type && + rhs.res_attrs == this->res_attrs && + rhs.sym_name == this->sym_name && + true; + } + bool operator!=(const Properties &rhs) const { + return !(*this == rhs); + } + }; +protected: + ::mlir::DictionaryAttr odsAttrs; + ::std::optional<::mlir::OperationName> odsOpName; + Properties properties; + ::mlir::RegionRange odsRegions; +public: + FuncOpGenericAdaptorBase(::mlir::DictionaryAttr attrs = nullptr, const Properties &properties = {}, ::mlir::RegionRange regions = {}); + + FuncOpGenericAdaptorBase(FuncOp op); + + std::pair getODSOperandIndexAndLength(unsigned index, unsigned odsOperandsSize); + const Properties &getProperties() { + return properties; + } + + ::mlir::DictionaryAttr getAttributes(); + ::mlir::StringAttr getSymNameAttr(); + ::llvm::StringRef getSymName(); + ::mlir::TypeAttr getFunctionTypeAttr(); + ::mlir::FunctionType getFunctionType(); + ::mlir::ArrayAttr getArgAttrsAttr(); + ::std::optional< ::mlir::ArrayAttr > getArgAttrs(); + ::mlir::ArrayAttr getResAttrsAttr(); + ::std::optional< ::mlir::ArrayAttr > getResAttrs(); + ::mlir::Region &getBody(); + ::mlir::RegionRange getRegions(); +}; +} // namespace detail +template +class FuncOpGenericAdaptor : public detail::FuncOpGenericAdaptorBase { + using ValueT = ::llvm::detail::ValueOfRange; + using Base = detail::FuncOpGenericAdaptorBase; +public: + FuncOpGenericAdaptor(RangeT values, ::mlir::DictionaryAttr attrs = nullptr, const Properties &properties = {}, ::mlir::RegionRange regions = {}) : Base(attrs, properties, regions), odsOperands(values) {} + + FuncOpGenericAdaptor(RangeT values, ::mlir::DictionaryAttr attrs, ::mlir::OpaqueProperties properties, ::mlir::RegionRange regions = {}) : FuncOpGenericAdaptor(values, attrs, (properties ? *properties.as() : Properties{}), regions) {} + + template >> + FuncOpGenericAdaptor(RangeT values, LateInst op) : Base(op), odsOperands(values) {} + + std::pair getODSOperandIndexAndLength(unsigned index) { + return Base::getODSOperandIndexAndLength(index, odsOperands.size()); + } + + RangeT getODSOperands(unsigned index) { + auto valueRange = getODSOperandIndexAndLength(index); + return {std::next(odsOperands.begin(), valueRange.first), + std::next(odsOperands.begin(), valueRange.first + valueRange.second)}; + } + + RangeT getOperands() { + return odsOperands; + } + +private: + RangeT odsOperands; +}; +class FuncOpAdaptor : public FuncOpGenericAdaptor<::mlir::ValueRange> { +public: + using FuncOpGenericAdaptor::FuncOpGenericAdaptor; + FuncOpAdaptor(FuncOp op); + + ::mlir::LogicalResult verify(::mlir::Location loc); +}; +class FuncOp : public ::mlir::Op { +public: + using Op::Op; + using Op::print; + using Adaptor = FuncOpAdaptor; + template + using GenericAdaptor = FuncOpGenericAdaptor; + using FoldAdaptor = GenericAdaptor<::llvm::ArrayRef<::mlir::Attribute>>; + using Properties = FoldAdaptor::Properties; + static ::llvm::ArrayRef<::llvm::StringRef> getAttributeNames() { + static ::llvm::StringRef attrNames[] = {::llvm::StringRef("arg_attrs"), ::llvm::StringRef("function_type"), ::llvm::StringRef("res_attrs"), ::llvm::StringRef("sym_name")}; + return ::llvm::ArrayRef(attrNames); + } + + ::mlir::StringAttr getArgAttrsAttrName() { + return getAttributeNameForIndex(0); + } + + static ::mlir::StringAttr getArgAttrsAttrName(::mlir::OperationName name) { + return getAttributeNameForIndex(name, 0); + } + + ::mlir::StringAttr getFunctionTypeAttrName() { + return getAttributeNameForIndex(1); + } + + static ::mlir::StringAttr getFunctionTypeAttrName(::mlir::OperationName name) { + return getAttributeNameForIndex(name, 1); + } + + ::mlir::StringAttr getResAttrsAttrName() { + return getAttributeNameForIndex(2); + } + + static ::mlir::StringAttr getResAttrsAttrName(::mlir::OperationName name) { + return getAttributeNameForIndex(name, 2); + } + + ::mlir::StringAttr getSymNameAttrName() { + return getAttributeNameForIndex(3); + } + + static ::mlir::StringAttr getSymNameAttrName(::mlir::OperationName name) { + return getAttributeNameForIndex(name, 3); + } + + static constexpr ::llvm::StringLiteral getOperationName() { + return ::llvm::StringLiteral("toy.func"); + } + + std::pair getODSOperandIndexAndLength(unsigned index); + ::mlir::Operation::operand_range getODSOperands(unsigned index); + std::pair getODSResultIndexAndLength(unsigned index); + ::mlir::Operation::result_range getODSResults(unsigned index); + ::mlir::Region &getBody(); + static ::mlir::LogicalResult setPropertiesFromAttr(Properties &prop, ::mlir::Attribute attr, ::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError); + static ::mlir::Attribute getPropertiesAsAttr(::mlir::MLIRContext *ctx, const Properties &prop); + static llvm::hash_code computePropertiesHash(const Properties &prop); + static std::optional getInherentAttr(::mlir::MLIRContext *ctx, const Properties &prop, llvm::StringRef name); + static void setInherentAttr(Properties &prop, llvm::StringRef name, mlir::Attribute value); + static void populateInherentAttrs(::mlir::MLIRContext *ctx, const Properties &prop, ::mlir::NamedAttrList &attrs); + static ::mlir::LogicalResult verifyInherentAttrs(::mlir::OperationName opName, ::mlir::NamedAttrList &attrs, llvm::function_ref<::mlir::InFlightDiagnostic()> emitError); + static ::mlir::LogicalResult readProperties(::mlir::DialectBytecodeReader &reader, ::mlir::OperationState &state); + void writeProperties(::mlir::DialectBytecodeWriter &writer); + ::mlir::StringAttr getSymNameAttr(); + ::llvm::StringRef getSymName(); + ::mlir::TypeAttr getFunctionTypeAttr(); + ::mlir::FunctionType getFunctionType(); + ::mlir::ArrayAttr getArgAttrsAttr(); + ::std::optional< ::mlir::ArrayAttr > getArgAttrs(); + ::mlir::ArrayAttr getResAttrsAttr(); + ::std::optional< ::mlir::ArrayAttr > getResAttrs(); + void setSymNameAttr(::mlir::StringAttr attr); + void setSymName(::llvm::StringRef attrValue); + void setFunctionTypeAttr(::mlir::TypeAttr attr); + void setFunctionType(::mlir::FunctionType attrValue); + void setArgAttrsAttr(::mlir::ArrayAttr attr); + void setResAttrsAttr(::mlir::ArrayAttr attr); + ::mlir::Attribute removeArgAttrsAttr(); + ::mlir::Attribute removeResAttrsAttr(); + static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, StringRef name, FunctionType type, ArrayRef attrs = {}); + static ::mlir::ParseResult parse(::mlir::OpAsmParser &parser, ::mlir::OperationState &result); + void print(::mlir::OpAsmPrinter &p); + ::mlir::LogicalResult verifyInvariantsImpl(); + ::mlir::LogicalResult verifyInvariants(); +private: + ::mlir::StringAttr getAttributeNameForIndex(unsigned index) { + return getAttributeNameForIndex((*this)->getName(), index); + } + + static ::mlir::StringAttr getAttributeNameForIndex(::mlir::OperationName name, unsigned index) { + assert(index < 4 && "invalid attribute index"); + assert(name.getStringRef() == getOperationName() && "invalid operation name"); + assert(name.isRegistered() && "Operation isn't registered, missing a " + "dependent dialect loading?"); + return name.getAttributeNames()[index]; + } + +public: + //===------------------------------------------------------------------===// + // FunctionOpInterface Methods + //===------------------------------------------------------------------===// + + /// Returns the argument types of this function. + ArrayRef getArgumentTypes() { return getFunctionType().getInputs(); } + + /// Returns the result types of this function. + ArrayRef getResultTypes() { return getFunctionType().getResults(); } + + Region *getCallableRegion() { return &getBody(); } +}; +} // namespace toy +} // namespace mlir +MLIR_DECLARE_EXPLICIT_TYPE_ID(::mlir::toy::FuncOp) + +namespace mlir { +namespace toy { + +//===----------------------------------------------------------------------===// +// ::mlir::toy::GenericCallOp declarations +//===----------------------------------------------------------------------===// + +namespace detail { +class GenericCallOpGenericAdaptorBase { +public: + struct Properties { + using calleeTy = ::mlir::FlatSymbolRefAttr; + calleeTy callee; + + auto getCallee() { + auto &propStorage = this->callee; + return ::llvm::cast<::mlir::FlatSymbolRefAttr>(propStorage); + } + void setCallee(const ::mlir::FlatSymbolRefAttr &propValue) { + this->callee = propValue; + } + bool operator==(const Properties &rhs) const { + return + rhs.callee == this->callee && + true; + } + bool operator!=(const Properties &rhs) const { + return !(*this == rhs); + } + }; +protected: + ::mlir::DictionaryAttr odsAttrs; + ::std::optional<::mlir::OperationName> odsOpName; + Properties properties; + ::mlir::RegionRange odsRegions; +public: + GenericCallOpGenericAdaptorBase(::mlir::DictionaryAttr attrs = nullptr, const Properties &properties = {}, ::mlir::RegionRange regions = {}); + + GenericCallOpGenericAdaptorBase(GenericCallOp op); + + std::pair getODSOperandIndexAndLength(unsigned index, unsigned odsOperandsSize); + const Properties &getProperties() { + return properties; + } + + ::mlir::DictionaryAttr getAttributes(); + ::mlir::FlatSymbolRefAttr getCalleeAttr(); + ::llvm::StringRef getCallee(); +}; +} // namespace detail +template +class GenericCallOpGenericAdaptor : public detail::GenericCallOpGenericAdaptorBase { + using ValueT = ::llvm::detail::ValueOfRange; + using Base = detail::GenericCallOpGenericAdaptorBase; +public: + GenericCallOpGenericAdaptor(RangeT values, ::mlir::DictionaryAttr attrs = nullptr, const Properties &properties = {}, ::mlir::RegionRange regions = {}) : Base(attrs, properties, regions), odsOperands(values) {} + + GenericCallOpGenericAdaptor(RangeT values, ::mlir::DictionaryAttr attrs, ::mlir::OpaqueProperties properties, ::mlir::RegionRange regions = {}) : GenericCallOpGenericAdaptor(values, attrs, (properties ? *properties.as() : Properties{}), regions) {} + + template >> + GenericCallOpGenericAdaptor(RangeT values, LateInst op) : Base(op), odsOperands(values) {} + + std::pair getODSOperandIndexAndLength(unsigned index) { + return Base::getODSOperandIndexAndLength(index, odsOperands.size()); + } + + RangeT getODSOperands(unsigned index) { + auto valueRange = getODSOperandIndexAndLength(index); + return {std::next(odsOperands.begin(), valueRange.first), + std::next(odsOperands.begin(), valueRange.first + valueRange.second)}; + } + + RangeT getInputs() { + return getODSOperands(0); + } + + RangeT getOperands() { + return odsOperands; + } + +private: + RangeT odsOperands; +}; +class GenericCallOpAdaptor : public GenericCallOpGenericAdaptor<::mlir::ValueRange> { +public: + using GenericCallOpGenericAdaptor::GenericCallOpGenericAdaptor; + GenericCallOpAdaptor(GenericCallOp op); + + ::mlir::LogicalResult verify(::mlir::Location loc); +}; +class GenericCallOp : public ::mlir::Op::Impl, ::mlir::OpTrait::ZeroSuccessors, ::mlir::OpTrait::VariadicOperands, ::mlir::OpTrait::OpInvariants, ::mlir::BytecodeOpInterface::Trait> { +public: + using Op::Op; + using Op::print; + using Adaptor = GenericCallOpAdaptor; + template + using GenericAdaptor = GenericCallOpGenericAdaptor; + using FoldAdaptor = GenericAdaptor<::llvm::ArrayRef<::mlir::Attribute>>; + using Properties = FoldAdaptor::Properties; + static ::llvm::ArrayRef<::llvm::StringRef> getAttributeNames() { + static ::llvm::StringRef attrNames[] = {::llvm::StringRef("callee")}; + return ::llvm::ArrayRef(attrNames); + } + + ::mlir::StringAttr getCalleeAttrName() { + return getAttributeNameForIndex(0); + } + + static ::mlir::StringAttr getCalleeAttrName(::mlir::OperationName name) { + return getAttributeNameForIndex(name, 0); + } + + static constexpr ::llvm::StringLiteral getOperationName() { + return ::llvm::StringLiteral("toy.generic_call"); + } + + std::pair getODSOperandIndexAndLength(unsigned index); + ::mlir::Operation::operand_range getODSOperands(unsigned index); + ::mlir::Operation::operand_range getInputs(); + ::mlir::MutableOperandRange getInputsMutable(); + std::pair getODSResultIndexAndLength(unsigned index); + ::mlir::Operation::result_range getODSResults(unsigned index); + static ::mlir::LogicalResult setPropertiesFromAttr(Properties &prop, ::mlir::Attribute attr, ::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError); + static ::mlir::Attribute getPropertiesAsAttr(::mlir::MLIRContext *ctx, const Properties &prop); + static llvm::hash_code computePropertiesHash(const Properties &prop); + static std::optional getInherentAttr(::mlir::MLIRContext *ctx, const Properties &prop, llvm::StringRef name); + static void setInherentAttr(Properties &prop, llvm::StringRef name, mlir::Attribute value); + static void populateInherentAttrs(::mlir::MLIRContext *ctx, const Properties &prop, ::mlir::NamedAttrList &attrs); + static ::mlir::LogicalResult verifyInherentAttrs(::mlir::OperationName opName, ::mlir::NamedAttrList &attrs, llvm::function_ref<::mlir::InFlightDiagnostic()> emitError); + static ::mlir::LogicalResult readProperties(::mlir::DialectBytecodeReader &reader, ::mlir::OperationState &state); + void writeProperties(::mlir::DialectBytecodeWriter &writer); + ::mlir::FlatSymbolRefAttr getCalleeAttr(); + ::llvm::StringRef getCallee(); + void setCalleeAttr(::mlir::FlatSymbolRefAttr attr); + void setCallee(::llvm::StringRef attrValue); + static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, StringRef callee, ArrayRef arguments); + static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type resultType0, ::mlir::FlatSymbolRefAttr callee, ::mlir::ValueRange inputs); + static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::FlatSymbolRefAttr callee, ::mlir::ValueRange inputs); + static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type resultType0, ::llvm::StringRef callee, ::mlir::ValueRange inputs); + static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::llvm::StringRef callee, ::mlir::ValueRange inputs); + static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {}); + ::mlir::LogicalResult verifyInvariantsImpl(); + ::mlir::LogicalResult verifyInvariants(); + static ::mlir::ParseResult parse(::mlir::OpAsmParser &parser, ::mlir::OperationState &result); + void print(::mlir::OpAsmPrinter &_odsPrinter); +private: + ::mlir::StringAttr getAttributeNameForIndex(unsigned index) { + return getAttributeNameForIndex((*this)->getName(), index); + } + + static ::mlir::StringAttr getAttributeNameForIndex(::mlir::OperationName name, unsigned index) { + assert(index < 1 && "invalid attribute index"); + assert(name.getStringRef() == getOperationName() && "invalid operation name"); + assert(name.isRegistered() && "Operation isn't registered, missing a " + "dependent dialect loading?"); + return name.getAttributeNames()[index]; + } + +public: +}; +} // namespace toy +} // namespace mlir +MLIR_DECLARE_EXPLICIT_TYPE_ID(::mlir::toy::GenericCallOp) + +namespace mlir { +namespace toy { + +//===----------------------------------------------------------------------===// +// ::mlir::toy::MulOp declarations +//===----------------------------------------------------------------------===// + +namespace detail { +class MulOpGenericAdaptorBase { +public: +protected: + ::mlir::DictionaryAttr odsAttrs; + ::std::optional<::mlir::OperationName> odsOpName; + ::mlir::RegionRange odsRegions; +public: + MulOpGenericAdaptorBase(::mlir::DictionaryAttr attrs = nullptr, const ::mlir::EmptyProperties &properties = {}, ::mlir::RegionRange regions = {}); + + MulOpGenericAdaptorBase(MulOp op); + + std::pair getODSOperandIndexAndLength(unsigned index, unsigned odsOperandsSize); + ::mlir::DictionaryAttr getAttributes(); +}; +} // namespace detail +template +class MulOpGenericAdaptor : public detail::MulOpGenericAdaptorBase { + using ValueT = ::llvm::detail::ValueOfRange; + using Base = detail::MulOpGenericAdaptorBase; +public: + MulOpGenericAdaptor(RangeT values, ::mlir::DictionaryAttr attrs = nullptr, const ::mlir::EmptyProperties &properties = {}, ::mlir::RegionRange regions = {}) : Base(attrs, properties, regions), odsOperands(values) {} + + MulOpGenericAdaptor(RangeT values, ::mlir::DictionaryAttr attrs, ::mlir::OpaqueProperties properties, ::mlir::RegionRange regions = {}) : MulOpGenericAdaptor(values, attrs, (properties ? *properties.as<::mlir::EmptyProperties *>() : ::mlir::EmptyProperties{}), regions) {} + + template >> + MulOpGenericAdaptor(RangeT values, LateInst op) : Base(op), odsOperands(values) {} + + std::pair getODSOperandIndexAndLength(unsigned index) { + return Base::getODSOperandIndexAndLength(index, odsOperands.size()); + } + + RangeT getODSOperands(unsigned index) { + auto valueRange = getODSOperandIndexAndLength(index); + return {std::next(odsOperands.begin(), valueRange.first), + std::next(odsOperands.begin(), valueRange.first + valueRange.second)}; + } + + ValueT getLhs() { + return (*getODSOperands(0).begin()); + } + + ValueT getRhs() { + return (*getODSOperands(1).begin()); + } + + RangeT getOperands() { + return odsOperands; + } + +private: + RangeT odsOperands; +}; +class MulOpAdaptor : public MulOpGenericAdaptor<::mlir::ValueRange> { +public: + using MulOpGenericAdaptor::MulOpGenericAdaptor; + MulOpAdaptor(MulOp op); + + ::mlir::LogicalResult verify(::mlir::Location loc); +}; +class MulOp : public ::mlir::Op::Impl, ::mlir::OpTrait::ZeroSuccessors, ::mlir::OpTrait::NOperands<2>::Impl, ::mlir::OpTrait::OpInvariants> { +public: + using Op::Op; + using Op::print; + using Adaptor = MulOpAdaptor; + template + using GenericAdaptor = MulOpGenericAdaptor; + using FoldAdaptor = GenericAdaptor<::llvm::ArrayRef<::mlir::Attribute>>; + static ::llvm::ArrayRef<::llvm::StringRef> getAttributeNames() { + return {}; + } + + static constexpr ::llvm::StringLiteral getOperationName() { + return ::llvm::StringLiteral("toy.mul"); + } + + std::pair getODSOperandIndexAndLength(unsigned index); + ::mlir::Operation::operand_range getODSOperands(unsigned index); + ::mlir::TypedValue<::mlir::TensorType> getLhs(); + ::mlir::TypedValue<::mlir::TensorType> getRhs(); + ::mlir::OpOperand &getLhsMutable(); + ::mlir::OpOperand &getRhsMutable(); + std::pair getODSResultIndexAndLength(unsigned index); + ::mlir::Operation::result_range getODSResults(unsigned index); + static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, Value lhs, Value rhs); + static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type resultType0, ::mlir::Value lhs, ::mlir::Value rhs); + static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::Value lhs, ::mlir::Value rhs); + static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {}); + static ::mlir::ParseResult parse(::mlir::OpAsmParser &parser, ::mlir::OperationState &result); + void print(::mlir::OpAsmPrinter &p); + ::mlir::LogicalResult verifyInvariantsImpl(); + ::mlir::LogicalResult verifyInvariants(); +public: +}; +} // namespace toy +} // namespace mlir +MLIR_DECLARE_EXPLICIT_TYPE_ID(::mlir::toy::MulOp) + +namespace mlir { +namespace toy { + +//===----------------------------------------------------------------------===// +// ::mlir::toy::PrintOp declarations +//===----------------------------------------------------------------------===// + +namespace detail { +class PrintOpGenericAdaptorBase { +public: +protected: + ::mlir::DictionaryAttr odsAttrs; + ::std::optional<::mlir::OperationName> odsOpName; + ::mlir::RegionRange odsRegions; +public: + PrintOpGenericAdaptorBase(::mlir::DictionaryAttr attrs = nullptr, const ::mlir::EmptyProperties &properties = {}, ::mlir::RegionRange regions = {}); + + PrintOpGenericAdaptorBase(PrintOp op); + + std::pair getODSOperandIndexAndLength(unsigned index, unsigned odsOperandsSize); + ::mlir::DictionaryAttr getAttributes(); +}; +} // namespace detail +template +class PrintOpGenericAdaptor : public detail::PrintOpGenericAdaptorBase { + using ValueT = ::llvm::detail::ValueOfRange; + using Base = detail::PrintOpGenericAdaptorBase; +public: + PrintOpGenericAdaptor(RangeT values, ::mlir::DictionaryAttr attrs = nullptr, const ::mlir::EmptyProperties &properties = {}, ::mlir::RegionRange regions = {}) : Base(attrs, properties, regions), odsOperands(values) {} + + PrintOpGenericAdaptor(RangeT values, ::mlir::DictionaryAttr attrs, ::mlir::OpaqueProperties properties, ::mlir::RegionRange regions = {}) : PrintOpGenericAdaptor(values, attrs, (properties ? *properties.as<::mlir::EmptyProperties *>() : ::mlir::EmptyProperties{}), regions) {} + + template >> + PrintOpGenericAdaptor(RangeT values, LateInst op) : Base(op), odsOperands(values) {} + + std::pair getODSOperandIndexAndLength(unsigned index) { + return Base::getODSOperandIndexAndLength(index, odsOperands.size()); + } + + RangeT getODSOperands(unsigned index) { + auto valueRange = getODSOperandIndexAndLength(index); + return {std::next(odsOperands.begin(), valueRange.first), + std::next(odsOperands.begin(), valueRange.first + valueRange.second)}; + } + + ValueT getInput() { + return (*getODSOperands(0).begin()); + } + + RangeT getOperands() { + return odsOperands; + } + +private: + RangeT odsOperands; +}; +class PrintOpAdaptor : public PrintOpGenericAdaptor<::mlir::ValueRange> { +public: + using PrintOpGenericAdaptor::PrintOpGenericAdaptor; + PrintOpAdaptor(PrintOp op); + + ::mlir::LogicalResult verify(::mlir::Location loc); +}; +class PrintOp : public ::mlir::Op { +public: + using Op::Op; + using Op::print; + using Adaptor = PrintOpAdaptor; + template + using GenericAdaptor = PrintOpGenericAdaptor; + using FoldAdaptor = GenericAdaptor<::llvm::ArrayRef<::mlir::Attribute>>; + static ::llvm::ArrayRef<::llvm::StringRef> getAttributeNames() { + return {}; + } + + static constexpr ::llvm::StringLiteral getOperationName() { + return ::llvm::StringLiteral("toy.print"); + } + + std::pair getODSOperandIndexAndLength(unsigned index); + ::mlir::Operation::operand_range getODSOperands(unsigned index); + ::mlir::TypedValue<::mlir::TensorType> getInput(); + ::mlir::OpOperand &getInputMutable(); + std::pair getODSResultIndexAndLength(unsigned index); + ::mlir::Operation::result_range getODSResults(unsigned index); + static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Value input); + static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::Value input); + static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {}); + ::mlir::LogicalResult verifyInvariantsImpl(); + ::mlir::LogicalResult verifyInvariants(); + static ::mlir::ParseResult parse(::mlir::OpAsmParser &parser, ::mlir::OperationState &result); + void print(::mlir::OpAsmPrinter &_odsPrinter); +public: +}; +} // namespace toy +} // namespace mlir +MLIR_DECLARE_EXPLICIT_TYPE_ID(::mlir::toy::PrintOp) + +namespace mlir { +namespace toy { + +//===----------------------------------------------------------------------===// +// ::mlir::toy::ReshapeOp declarations +//===----------------------------------------------------------------------===// + +namespace detail { +class ReshapeOpGenericAdaptorBase { +public: +protected: + ::mlir::DictionaryAttr odsAttrs; + ::std::optional<::mlir::OperationName> odsOpName; + ::mlir::RegionRange odsRegions; +public: + ReshapeOpGenericAdaptorBase(::mlir::DictionaryAttr attrs = nullptr, const ::mlir::EmptyProperties &properties = {}, ::mlir::RegionRange regions = {}); + + ReshapeOpGenericAdaptorBase(ReshapeOp op); + + std::pair getODSOperandIndexAndLength(unsigned index, unsigned odsOperandsSize); + ::mlir::DictionaryAttr getAttributes(); +}; +} // namespace detail +template +class ReshapeOpGenericAdaptor : public detail::ReshapeOpGenericAdaptorBase { + using ValueT = ::llvm::detail::ValueOfRange; + using Base = detail::ReshapeOpGenericAdaptorBase; +public: + ReshapeOpGenericAdaptor(RangeT values, ::mlir::DictionaryAttr attrs = nullptr, const ::mlir::EmptyProperties &properties = {}, ::mlir::RegionRange regions = {}) : Base(attrs, properties, regions), odsOperands(values) {} + + ReshapeOpGenericAdaptor(RangeT values, ::mlir::DictionaryAttr attrs, ::mlir::OpaqueProperties properties, ::mlir::RegionRange regions = {}) : ReshapeOpGenericAdaptor(values, attrs, (properties ? *properties.as<::mlir::EmptyProperties *>() : ::mlir::EmptyProperties{}), regions) {} + + template >> + ReshapeOpGenericAdaptor(RangeT values, LateInst op) : Base(op), odsOperands(values) {} + + std::pair getODSOperandIndexAndLength(unsigned index) { + return Base::getODSOperandIndexAndLength(index, odsOperands.size()); + } + + RangeT getODSOperands(unsigned index) { + auto valueRange = getODSOperandIndexAndLength(index); + return {std::next(odsOperands.begin(), valueRange.first), + std::next(odsOperands.begin(), valueRange.first + valueRange.second)}; + } + + ValueT getInput() { + return (*getODSOperands(0).begin()); + } + + RangeT getOperands() { + return odsOperands; + } + +private: + RangeT odsOperands; +}; +class ReshapeOpAdaptor : public ReshapeOpGenericAdaptor<::mlir::ValueRange> { +public: + using ReshapeOpGenericAdaptor::ReshapeOpGenericAdaptor; + ReshapeOpAdaptor(ReshapeOp op); + + ::mlir::LogicalResult verify(::mlir::Location loc); +}; +class ReshapeOp : public ::mlir::Op::Impl, ::mlir::OpTrait::ZeroSuccessors, ::mlir::OpTrait::OneOperand, ::mlir::OpTrait::OpInvariants> { +public: + using Op::Op; + using Op::print; + using Adaptor = ReshapeOpAdaptor; + template + using GenericAdaptor = ReshapeOpGenericAdaptor; + using FoldAdaptor = GenericAdaptor<::llvm::ArrayRef<::mlir::Attribute>>; + static ::llvm::ArrayRef<::llvm::StringRef> getAttributeNames() { + return {}; + } + + static constexpr ::llvm::StringLiteral getOperationName() { + return ::llvm::StringLiteral("toy.reshape"); + } + + std::pair getODSOperandIndexAndLength(unsigned index); + ::mlir::Operation::operand_range getODSOperands(unsigned index); + ::mlir::TypedValue<::mlir::TensorType> getInput(); + ::mlir::OpOperand &getInputMutable(); + std::pair getODSResultIndexAndLength(unsigned index); + ::mlir::Operation::result_range getODSResults(unsigned index); + static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type resultType0, ::mlir::Value input); + static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::Value input); + static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {}); + ::mlir::LogicalResult verifyInvariantsImpl(); + ::mlir::LogicalResult verifyInvariants(); + static ::mlir::ParseResult parse(::mlir::OpAsmParser &parser, ::mlir::OperationState &result); + void print(::mlir::OpAsmPrinter &_odsPrinter); +public: +}; +} // namespace toy +} // namespace mlir +MLIR_DECLARE_EXPLICIT_TYPE_ID(::mlir::toy::ReshapeOp) + +namespace mlir { +namespace toy { + +//===----------------------------------------------------------------------===// +// ::mlir::toy::ReturnOp declarations +//===----------------------------------------------------------------------===// + +namespace detail { +class ReturnOpGenericAdaptorBase { +public: +protected: + ::mlir::DictionaryAttr odsAttrs; + ::std::optional<::mlir::OperationName> odsOpName; + ::mlir::RegionRange odsRegions; +public: + ReturnOpGenericAdaptorBase(::mlir::DictionaryAttr attrs = nullptr, const ::mlir::EmptyProperties &properties = {}, ::mlir::RegionRange regions = {}); + + ReturnOpGenericAdaptorBase(ReturnOp op); + + std::pair getODSOperandIndexAndLength(unsigned index, unsigned odsOperandsSize); + ::mlir::DictionaryAttr getAttributes(); +}; +} // namespace detail +template +class ReturnOpGenericAdaptor : public detail::ReturnOpGenericAdaptorBase { + using ValueT = ::llvm::detail::ValueOfRange; + using Base = detail::ReturnOpGenericAdaptorBase; +public: + ReturnOpGenericAdaptor(RangeT values, ::mlir::DictionaryAttr attrs = nullptr, const ::mlir::EmptyProperties &properties = {}, ::mlir::RegionRange regions = {}) : Base(attrs, properties, regions), odsOperands(values) {} + + ReturnOpGenericAdaptor(RangeT values, ::mlir::DictionaryAttr attrs, ::mlir::OpaqueProperties properties, ::mlir::RegionRange regions = {}) : ReturnOpGenericAdaptor(values, attrs, (properties ? *properties.as<::mlir::EmptyProperties *>() : ::mlir::EmptyProperties{}), regions) {} + + template >> + ReturnOpGenericAdaptor(RangeT values, LateInst op) : Base(op), odsOperands(values) {} + + std::pair getODSOperandIndexAndLength(unsigned index) { + return Base::getODSOperandIndexAndLength(index, odsOperands.size()); + } + + RangeT getODSOperands(unsigned index) { + auto valueRange = getODSOperandIndexAndLength(index); + return {std::next(odsOperands.begin(), valueRange.first), + std::next(odsOperands.begin(), valueRange.first + valueRange.second)}; + } + + RangeT getInput() { + return getODSOperands(0); + } + + RangeT getOperands() { + return odsOperands; + } + +private: + RangeT odsOperands; +}; +class ReturnOpAdaptor : public ReturnOpGenericAdaptor<::mlir::ValueRange> { +public: + using ReturnOpGenericAdaptor::ReturnOpGenericAdaptor; + ReturnOpAdaptor(ReturnOp op); + + ::mlir::LogicalResult verify(::mlir::Location loc); +}; +class ReturnOp : public ::mlir::Op::Impl, ::mlir::OpTrait::OpInvariants, ::mlir::ConditionallySpeculatable::Trait, ::mlir::OpTrait::AlwaysSpeculatableImplTrait, ::mlir::MemoryEffectOpInterface::Trait, ::mlir::OpTrait::IsTerminator> { +public: + using Op::Op; + using Op::print; + using Adaptor = ReturnOpAdaptor; + template + using GenericAdaptor = ReturnOpGenericAdaptor; + using FoldAdaptor = GenericAdaptor<::llvm::ArrayRef<::mlir::Attribute>>; + static ::llvm::ArrayRef<::llvm::StringRef> getAttributeNames() { + return {}; + } + + static constexpr ::llvm::StringLiteral getOperationName() { + return ::llvm::StringLiteral("toy.return"); + } + + std::pair getODSOperandIndexAndLength(unsigned index); + ::mlir::Operation::operand_range getODSOperands(unsigned index); + ::mlir::Operation::operand_range getInput(); + ::mlir::MutableOperandRange getInputMutable(); + std::pair getODSResultIndexAndLength(unsigned index); + ::mlir::Operation::result_range getODSResults(unsigned index); + static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState); + static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::ValueRange input); + static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {}); + ::mlir::LogicalResult verifyInvariantsImpl(); + ::mlir::LogicalResult verifyInvariants(); + ::mlir::LogicalResult verify(); + static ::mlir::ParseResult parse(::mlir::OpAsmParser &parser, ::mlir::OperationState &result); + void print(::mlir::OpAsmPrinter &_odsPrinter); + void getEffects(::llvm::SmallVectorImpl<::mlir::SideEffects::EffectInstance<::mlir::MemoryEffects::Effect>> &effects); +public: + bool hasOperand() { return getNumOperands() != 0; } +}; +} // namespace toy +} // namespace mlir +MLIR_DECLARE_EXPLICIT_TYPE_ID(::mlir::toy::ReturnOp) + +namespace mlir { +namespace toy { + +//===----------------------------------------------------------------------===// +// ::mlir::toy::TransposeOp declarations +//===----------------------------------------------------------------------===// + +namespace detail { +class TransposeOpGenericAdaptorBase { +public: +protected: + ::mlir::DictionaryAttr odsAttrs; + ::std::optional<::mlir::OperationName> odsOpName; + ::mlir::RegionRange odsRegions; +public: + TransposeOpGenericAdaptorBase(::mlir::DictionaryAttr attrs = nullptr, const ::mlir::EmptyProperties &properties = {}, ::mlir::RegionRange regions = {}); + + TransposeOpGenericAdaptorBase(TransposeOp op); + + std::pair getODSOperandIndexAndLength(unsigned index, unsigned odsOperandsSize); + ::mlir::DictionaryAttr getAttributes(); +}; +} // namespace detail +template +class TransposeOpGenericAdaptor : public detail::TransposeOpGenericAdaptorBase { + using ValueT = ::llvm::detail::ValueOfRange; + using Base = detail::TransposeOpGenericAdaptorBase; +public: + TransposeOpGenericAdaptor(RangeT values, ::mlir::DictionaryAttr attrs = nullptr, const ::mlir::EmptyProperties &properties = {}, ::mlir::RegionRange regions = {}) : Base(attrs, properties, regions), odsOperands(values) {} + + TransposeOpGenericAdaptor(RangeT values, ::mlir::DictionaryAttr attrs, ::mlir::OpaqueProperties properties, ::mlir::RegionRange regions = {}) : TransposeOpGenericAdaptor(values, attrs, (properties ? *properties.as<::mlir::EmptyProperties *>() : ::mlir::EmptyProperties{}), regions) {} + + template >> + TransposeOpGenericAdaptor(RangeT values, LateInst op) : Base(op), odsOperands(values) {} + + std::pair getODSOperandIndexAndLength(unsigned index) { + return Base::getODSOperandIndexAndLength(index, odsOperands.size()); + } + + RangeT getODSOperands(unsigned index) { + auto valueRange = getODSOperandIndexAndLength(index); + return {std::next(odsOperands.begin(), valueRange.first), + std::next(odsOperands.begin(), valueRange.first + valueRange.second)}; + } + + ValueT getInput() { + return (*getODSOperands(0).begin()); + } + + RangeT getOperands() { + return odsOperands; + } + +private: + RangeT odsOperands; +}; +class TransposeOpAdaptor : public TransposeOpGenericAdaptor<::mlir::ValueRange> { +public: + using TransposeOpGenericAdaptor::TransposeOpGenericAdaptor; + TransposeOpAdaptor(TransposeOp op); + + ::mlir::LogicalResult verify(::mlir::Location loc); +}; +class TransposeOp : public ::mlir::Op::Impl, ::mlir::OpTrait::ZeroSuccessors, ::mlir::OpTrait::OneOperand, ::mlir::OpTrait::OpInvariants> { +public: + using Op::Op; + using Op::print; + using Adaptor = TransposeOpAdaptor; + template + using GenericAdaptor = TransposeOpGenericAdaptor; + using FoldAdaptor = GenericAdaptor<::llvm::ArrayRef<::mlir::Attribute>>; + static ::llvm::ArrayRef<::llvm::StringRef> getAttributeNames() { + return {}; + } + + static constexpr ::llvm::StringLiteral getOperationName() { + return ::llvm::StringLiteral("toy.transpose"); + } + + std::pair getODSOperandIndexAndLength(unsigned index); + ::mlir::Operation::operand_range getODSOperands(unsigned index); + ::mlir::TypedValue<::mlir::TensorType> getInput(); + ::mlir::OpOperand &getInputMutable(); + std::pair getODSResultIndexAndLength(unsigned index); + ::mlir::Operation::result_range getODSResults(unsigned index); + static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, Value input); + static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type resultType0, ::mlir::Value input); + static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::Value input); + static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {}); + ::mlir::LogicalResult verifyInvariantsImpl(); + ::mlir::LogicalResult verifyInvariants(); + ::mlir::LogicalResult verify(); + static ::mlir::ParseResult parse(::mlir::OpAsmParser &parser, ::mlir::OperationState &result); + void print(::mlir::OpAsmPrinter &_odsPrinter); +public: +}; +} // namespace toy +} // namespace mlir +MLIR_DECLARE_EXPLICIT_TYPE_ID(::mlir::toy::TransposeOp) + + +#endif // GET_OP_CLASSES + diff --git a/Ch2/include/toy/Ops.td b/Ch2/include/toy/Ops.td new file mode 100644 index 0000000..1a1b136 --- /dev/null +++ b/Ch2/include/toy/Ops.td @@ -0,0 +1,335 @@ +//===- Ops.td - Toy dialect operation definitions ----------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Defines the operations of the Toy dialect. +// +//===----------------------------------------------------------------------===// + +#ifndef TOY_OPS +#define TOY_OPS + +include "mlir/IR/OpBase.td" +include "mlir/Interfaces/FunctionInterfaces.td" +include "mlir/IR/SymbolInterfaces.td" +include "mlir/Interfaces/SideEffectInterfaces.td" + +// Provide a definition of the 'toy' dialect in the ODS framework so that we +// can define our operations. +def Toy_Dialect : Dialect { + let name = "toy"; + let cppNamespace = "::mlir::toy"; +} + +// Base class for toy dialect operations. This operation inherits from the base +// `Op` class in OpBase.td, and provides: +// * The parent dialect of the operation. +// * The mnemonic for the operation, or the name without the dialect prefix. +// * A list of traits for the operation. +class Toy_Op traits = []> : + Op; + +//===----------------------------------------------------------------------===// +// Toy Operations +//===----------------------------------------------------------------------===// + +//===----------------------------------------------------------------------===// +// ConstantOp +//===----------------------------------------------------------------------===// + +// We define a toy operation by inheriting from our base 'Toy_Op' class above. +// Here we provide the mnemonic and a list of traits for the operation. The +// constant operation is marked as 'Pure' as it is a pure operation +// and may be removed if dead. +def ConstantOp : Toy_Op<"constant", [Pure]> { + // Provide a summary and description for this operation. This can be used to + // auto-generate documentation of the operations within our dialect. + let summary = "constant"; + let description = [{ + Constant operation turns a literal into an SSA value. The data is attached + to the operation as an attribute. For example: + + ```mlir + %0 = toy.constant dense<[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]> + : tensor<2x3xf64> + ``` + }]; + + // The constant operation takes an attribute as the only input. + let arguments = (ins F64ElementsAttr:$value); + + // The constant operation returns a single value of TensorType. + let results = (outs F64Tensor); + + // Indicate that the operation has a custom parser and printer method. + let hasCustomAssemblyFormat = 1; + + // Add custom build methods for the constant operation. These method populates + // the `state` that MLIR uses to create operations, i.e. these are used when + // using `builder.create(...)`. + let builders = [ + // Build a constant with a given constant tensor value. + OpBuilder<(ins "DenseElementsAttr":$value), [{ + build($_builder, $_state, value.getType(), value); + }]>, + + // Build a constant with a given constant floating-point value. + OpBuilder<(ins "double":$value)> + ]; + + // Indicate that additional verification for this operation is necessary. + let hasVerifier = 1; +} + +//===----------------------------------------------------------------------===// +// AddOp +//===----------------------------------------------------------------------===// + +def AddOp : Toy_Op<"add"> { + let summary = "element-wise addition operation"; + let description = [{ + The "add" operation performs element-wise addition between two tensors. + The shapes of the tensor operands are expected to match. + }]; + + let arguments = (ins F64Tensor:$lhs, F64Tensor:$rhs); + let results = (outs F64Tensor); + + // Indicate that the operation has a custom parser and printer method. + let hasCustomAssemblyFormat = 1; + + // Allow building an AddOp with from the two input operands. + let builders = [ + OpBuilder<(ins "Value":$lhs, "Value":$rhs)> + ]; +} + +//===----------------------------------------------------------------------===// +// FuncOp +//===----------------------------------------------------------------------===// + +def FuncOp : Toy_Op<"func", [ + FunctionOpInterface, IsolatedFromAbove + ]> { + let summary = "user defined function operation"; + let description = [{ + The "toy.func" operation represents a user defined function. These are + callable SSA-region operations that contain toy computations. + + Example: + + ```mlir + toy.func @main() { + %0 = toy.constant dense<5.500000e+00> : tensor + %1 = toy.reshape(%0 : tensor) to tensor<2x2xf64> + toy.print %1 : tensor<2x2xf64> + toy.return + } + ``` + }]; + + let arguments = (ins + SymbolNameAttr:$sym_name, + TypeAttrOf:$function_type, + OptionalAttr:$arg_attrs, + OptionalAttr:$res_attrs + ); + let regions = (region AnyRegion:$body); + + let builders = [OpBuilder<(ins + "StringRef":$name, "FunctionType":$type, + CArg<"ArrayRef", "{}">:$attrs) + >]; + + let extraClassDeclaration = [{ + //===------------------------------------------------------------------===// + // FunctionOpInterface Methods + //===------------------------------------------------------------------===// + + /// Returns the argument types of this function. + ArrayRef getArgumentTypes() { return getFunctionType().getInputs(); } + + /// Returns the result types of this function. + ArrayRef getResultTypes() { return getFunctionType().getResults(); } + + Region *getCallableRegion() { return &getBody(); } + }]; + + let hasCustomAssemblyFormat = 1; + let skipDefaultBuilders = 1; +} + +//===----------------------------------------------------------------------===// +// GenericCallOp +//===----------------------------------------------------------------------===// + +def GenericCallOp : Toy_Op<"generic_call"> { + let summary = "generic call operation"; + let description = [{ + Generic calls represent calls to a user defined function that needs to + be specialized for the shape of its arguments. The callee name is attached + as a symbol reference via an attribute. The arguments list must match the + arguments expected by the callee. For example: + + ```mlir + %4 = toy.generic_call @my_func(%1, %3) + : (tensor<2x3xf64>, tensor<2x3xf64>) -> tensor<*xf64> + ``` + + This is only valid if a function named "my_func" exists and takes two + arguments. + }]; + + // The generic call operation takes a symbol reference attribute as the + // callee, and inputs for the call. + let arguments = (ins FlatSymbolRefAttr:$callee, Variadic:$inputs); + + // The generic call operation returns a single value of TensorType. + let results = (outs F64Tensor); + + // Specialize assembly printing and parsing using a declarative format. + let assemblyFormat = [{ + $callee `(` $inputs `)` attr-dict `:` functional-type($inputs, results) + }]; + + // Add custom build methods for the generic call operation. + let builders = [ + OpBuilder<(ins "StringRef":$callee, "ArrayRef":$arguments)> + ]; +} + +//===----------------------------------------------------------------------===// +// MulOp +//===----------------------------------------------------------------------===// + +def MulOp : Toy_Op<"mul"> { + let summary = "element-wise multiplication operation"; + let description = [{ + The "mul" operation performs element-wise multiplication between two + tensors. The shapes of the tensor operands are expected to match. + }]; + + let arguments = (ins F64Tensor:$lhs, F64Tensor:$rhs); + let results = (outs F64Tensor); + + // Indicate that the operation has a custom parser and printer method. + let hasCustomAssemblyFormat = 1; + + // Allow building a MulOp with from the two input operands. + let builders = [ + OpBuilder<(ins "Value":$lhs, "Value":$rhs)> + ]; +} + +//===----------------------------------------------------------------------===// +// PrintOp +//===----------------------------------------------------------------------===// + +def PrintOp : Toy_Op<"print"> { + let summary = "print operation"; + let description = [{ + The "print" builtin operation prints a given input tensor, and produces + no results. + }]; + + // The print operation takes an input tensor to print. + let arguments = (ins F64Tensor:$input); + + let assemblyFormat = "$input attr-dict `:` type($input)"; +} + +//===----------------------------------------------------------------------===// +// ReshapeOp +//===----------------------------------------------------------------------===// + +def ReshapeOp : Toy_Op<"reshape"> { + let summary = "tensor reshape operation"; + let description = [{ + Reshape operation is transforming its input tensor into a new tensor with + the same number of elements but different shapes. For example: + + ```mlir + %0 = toy.reshape (%arg1 : tensor<10xf64>) to tensor<5x2xf64> + ``` + }]; + + let arguments = (ins F64Tensor:$input); + + // We expect that the reshape operation returns a statically shaped tensor. + let results = (outs StaticShapeTensorOf<[F64]>); + + let assemblyFormat = [{ + `(` $input `:` type($input) `)` attr-dict `to` type(results) + }]; +} + +//===----------------------------------------------------------------------===// +// ReturnOp +//===----------------------------------------------------------------------===// + +def ReturnOp : Toy_Op<"return", [Pure, HasParent<"FuncOp">, + Terminator]> { + let summary = "return operation"; + let description = [{ + The "return" operation represents a return operation within a function. + The operation takes an optional tensor operand and produces no results. + The operand type must match the signature of the function that contains + the operation. For example: + + ```mlir + toy.func @foo() -> tensor<2xf64> { + ... + toy.return %0 : tensor<2xf64> + } + ``` + }]; + + // The return operation takes an optional input operand to return. This + // value must match the return type of the enclosing function. + let arguments = (ins Variadic:$input); + + // The return operation only emits the input in the format if it is present. + let assemblyFormat = "($input^ `:` type($input))? attr-dict "; + + // Allow building a ReturnOp with no return operand. + let builders = [ + OpBuilder<(ins), [{ build($_builder, $_state, std::nullopt); }]> + ]; + + // Provide extra utility definitions on the c++ operation class definition. + let extraClassDeclaration = [{ + bool hasOperand() { return getNumOperands() != 0; } + }]; + + // Invoke a static verify method to verify this return operation. + let hasVerifier = 1; +} + +//===----------------------------------------------------------------------===// +// TransposeOp +//===----------------------------------------------------------------------===// + +def TransposeOp : Toy_Op<"transpose"> { + let summary = "transpose operation"; + + let arguments = (ins F64Tensor:$input); + let results = (outs F64Tensor); + + let assemblyFormat = [{ + `(` $input `:` type($input) `)` attr-dict `to` type(results) + }]; + + // Allow building a TransposeOp with from the input operand. + let builders = [ + OpBuilder<(ins "Value":$input)> + ]; + + // Invoke a static verify method to verify this transpose operation. + let hasVerifier = 1; +} + +#endif // TOY_OPS diff --git a/Ch2/include/toy/Parser.h b/Ch2/include/toy/Parser.h new file mode 100644 index 0000000..1f20616 --- /dev/null +++ b/Ch2/include/toy/Parser.h @@ -0,0 +1,489 @@ +//===- Parser.h - Toy Language Parser -------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements the parser for the Toy language. It processes the Token +// provided by the Lexer and returns an AST. +// +//===----------------------------------------------------------------------===// + +#ifndef TOY_PARSER_H +#define TOY_PARSER_H + +#include "toy/AST.h" +#include "toy/Lexer.h" + +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/StringExtras.h" +#include "llvm/Support/raw_ostream.h" + +#include +#include +#include +#include + +namespace toy { + +/// This is a simple recursive parser for the Toy language. It produces a well +/// formed AST from a stream of Token supplied by the Lexer. No semantic checks +/// or symbol resolution is performed. For example, variables are referenced by +/// string and the code could reference an undeclared variable and the parsing +/// succeeds. +class Parser { +public: + /// Create a Parser for the supplied lexer. + Parser(Lexer &lexer) : lexer(lexer) {} + + /// Parse a full Module. A module is a list of function definitions. + std::unique_ptr parseModule() { + lexer.getNextToken(); // prime the lexer + + // Parse functions one at a time and accumulate in this vector. + std::vector functions; + while (auto f = parseDefinition()) { + functions.push_back(std::move(*f)); + if (lexer.getCurToken() == tok_eof) + break; + } + // If we didn't reach EOF, there was an error during parsing + if (lexer.getCurToken() != tok_eof) + return parseError("nothing", "at end of module"); + + return std::make_unique(std::move(functions)); + } + +private: + Lexer &lexer; + + /// Parse a return statement. + /// return :== return ; | return expr ; + std::unique_ptr parseReturn() { + auto loc = lexer.getLastLocation(); + lexer.consume(tok_return); + + // return takes an optional argument + std::optional> expr; + if (lexer.getCurToken() != ';') { + expr = parseExpression(); + if (!expr) + return nullptr; + } + return std::make_unique(std::move(loc), std::move(expr)); + } + + /// Parse a literal number. + /// numberexpr ::= number + std::unique_ptr parseNumberExpr() { + auto loc = lexer.getLastLocation(); + auto result = + std::make_unique(std::move(loc), lexer.getValue()); + lexer.consume(tok_number); + return std::move(result); + } + + /// Parse a literal array expression. + /// tensorLiteral ::= [ literalList ] | number + /// literalList ::= tensorLiteral | tensorLiteral, literalList + std::unique_ptr parseTensorLiteralExpr() { + auto loc = lexer.getLastLocation(); + lexer.consume(Token('[')); + + // Hold the list of values at this nesting level. + std::vector> values; + // Hold the dimensions for all the nesting inside this level. + std::vector dims; + do { + // We can have either another nested array or a number literal. + if (lexer.getCurToken() == '[') { + values.push_back(parseTensorLiteralExpr()); + if (!values.back()) + return nullptr; // parse error in the nested array. + } else { + if (lexer.getCurToken() != tok_number) + return parseError(" or [", "in literal expression"); + values.push_back(parseNumberExpr()); + } + + // End of this list on ']' + if (lexer.getCurToken() == ']') + break; + + // Elements are separated by a comma. + if (lexer.getCurToken() != ',') + return parseError("] or ,", "in literal expression"); + + lexer.getNextToken(); // eat , + } while (true); + if (values.empty()) + return parseError("", "to fill literal expression"); + lexer.getNextToken(); // eat ] + + /// Fill in the dimensions now. First the current nesting level: + dims.push_back(values.size()); + + /// If there is any nested array, process all of them and ensure that + /// dimensions are uniform. + if (llvm::any_of(values, [](std::unique_ptr &expr) { + return llvm::isa(expr.get()); + })) { + auto *firstLiteral = llvm::dyn_cast(values.front().get()); + if (!firstLiteral) + return parseError("uniform well-nested dimensions", + "inside literal expression"); + + // Append the nested dimensions to the current level + auto firstDims = firstLiteral->getDims(); + dims.insert(dims.end(), firstDims.begin(), firstDims.end()); + + // Sanity check that shape is uniform across all elements of the list. + for (auto &expr : values) { + auto *exprLiteral = llvm::cast(expr.get()); + if (!exprLiteral) + return parseError("uniform well-nested dimensions", + "inside literal expression"); + if (exprLiteral->getDims() != firstDims) + return parseError("uniform well-nested dimensions", + "inside literal expression"); + } + } + return std::make_unique(std::move(loc), std::move(values), + std::move(dims)); + } + + /// parenexpr ::= '(' expression ')' + std::unique_ptr parseParenExpr() { + lexer.getNextToken(); // eat (. + auto v = parseExpression(); + if (!v) + return nullptr; + + if (lexer.getCurToken() != ')') + return parseError(")", "to close expression with parentheses"); + lexer.consume(Token(')')); + return v; + } + + /// identifierexpr + /// ::= identifier + /// ::= identifier '(' expression ')' + std::unique_ptr parseIdentifierExpr() { + std::string name(lexer.getId()); + + auto loc = lexer.getLastLocation(); + lexer.getNextToken(); // eat identifier. + + if (lexer.getCurToken() != '(') // Simple variable ref. + return std::make_unique(std::move(loc), name); + + // This is a function call. + lexer.consume(Token('(')); + std::vector> args; + if (lexer.getCurToken() != ')') { + while (true) { + if (auto arg = parseExpression()) + args.push_back(std::move(arg)); + else + return nullptr; + + if (lexer.getCurToken() == ')') + break; + + if (lexer.getCurToken() != ',') + return parseError(", or )", "in argument list"); + lexer.getNextToken(); + } + } + lexer.consume(Token(')')); + + // It can be a builtin call to print + if (name == "print") { + if (args.size() != 1) + return parseError("", "as argument to print()"); + + return std::make_unique(std::move(loc), std::move(args[0])); + } + + // Call to a user-defined function + return std::make_unique(std::move(loc), name, std::move(args)); + } + + /// primary + /// ::= identifierexpr + /// ::= numberexpr + /// ::= parenexpr + /// ::= tensorliteral + std::unique_ptr parsePrimary() { + switch (lexer.getCurToken()) { + default: + llvm::errs() << "unknown token '" << lexer.getCurToken() + << "' when expecting an expression\n"; + return nullptr; + case tok_identifier: + return parseIdentifierExpr(); + case tok_number: + return parseNumberExpr(); + case '(': + return parseParenExpr(); + case '[': + return parseTensorLiteralExpr(); + case ';': + return nullptr; + case '}': + return nullptr; + } + } + + /// Recursively parse the right hand side of a binary expression, the ExprPrec + /// argument indicates the precedence of the current binary operator. + /// + /// binoprhs ::= ('+' primary)* + std::unique_ptr parseBinOpRHS(int exprPrec, + std::unique_ptr lhs) { + // If this is a binop, find its precedence. + while (true) { + int tokPrec = getTokPrecedence(); + + // If this is a binop that binds at least as tightly as the current binop, + // consume it, otherwise we are done. + if (tokPrec < exprPrec) + return lhs; + + // Okay, we know this is a binop. + int binOp = lexer.getCurToken(); + lexer.consume(Token(binOp)); + auto loc = lexer.getLastLocation(); + + // Parse the primary expression after the binary operator. + auto rhs = parsePrimary(); + if (!rhs) + return parseError("expression", "to complete binary operator"); + + // If BinOp binds less tightly with rhs than the operator after rhs, let + // the pending operator take rhs as its lhs. + int nextPrec = getTokPrecedence(); + if (tokPrec < nextPrec) { + rhs = parseBinOpRHS(tokPrec + 1, std::move(rhs)); + if (!rhs) + return nullptr; + } + + // Merge lhs/RHS. + lhs = std::make_unique(std::move(loc), binOp, + std::move(lhs), std::move(rhs)); + } + } + + /// expression::= primary binop rhs + std::unique_ptr parseExpression() { + auto lhs = parsePrimary(); + if (!lhs) + return nullptr; + + return parseBinOpRHS(0, std::move(lhs)); + } + + /// type ::= < shape_list > + /// shape_list ::= num | num , shape_list + std::unique_ptr parseType() { + if (lexer.getCurToken() != '<') + return parseError("<", "to begin type"); + lexer.getNextToken(); // eat < + + auto type = std::make_unique(); + + while (lexer.getCurToken() == tok_number) { + type->shape.push_back(lexer.getValue()); + lexer.getNextToken(); + if (lexer.getCurToken() == ',') + lexer.getNextToken(); + } + + if (lexer.getCurToken() != '>') + return parseError(">", "to end type"); + lexer.getNextToken(); // eat > + return type; + } + + /// Parse a variable declaration, it starts with a `var` keyword followed by + /// and identifier and an optional type (shape specification) before the + /// initializer. + /// decl ::= var identifier [ type ] = expr + std::unique_ptr parseDeclaration() { + if (lexer.getCurToken() != tok_var) + return parseError("var", "to begin declaration"); + auto loc = lexer.getLastLocation(); + lexer.getNextToken(); // eat var + + if (lexer.getCurToken() != tok_identifier) + return parseError("identified", + "after 'var' declaration"); + std::string id(lexer.getId()); + lexer.getNextToken(); // eat id + + std::unique_ptr type; // Type is optional, it can be inferred + if (lexer.getCurToken() == '<') { + type = parseType(); + if (!type) + return nullptr; + } + + if (!type) + type = std::make_unique(); + lexer.consume(Token('=')); + auto expr = parseExpression(); + return std::make_unique(std::move(loc), std::move(id), + std::move(*type), std::move(expr)); + } + + /// Parse a block: a list of expression separated by semicolons and wrapped in + /// curly braces. + /// + /// block ::= { expression_list } + /// expression_list ::= block_expr ; expression_list + /// block_expr ::= decl | "return" | expr + std::unique_ptr parseBlock() { + if (lexer.getCurToken() != '{') + return parseError("{", "to begin block"); + lexer.consume(Token('{')); + + auto exprList = std::make_unique(); + + // Ignore empty expressions: swallow sequences of semicolons. + while (lexer.getCurToken() == ';') + lexer.consume(Token(';')); + + while (lexer.getCurToken() != '}' && lexer.getCurToken() != tok_eof) { + if (lexer.getCurToken() == tok_var) { + // Variable declaration + auto varDecl = parseDeclaration(); + if (!varDecl) + return nullptr; + exprList->push_back(std::move(varDecl)); + } else if (lexer.getCurToken() == tok_return) { + // Return statement + auto ret = parseReturn(); + if (!ret) + return nullptr; + exprList->push_back(std::move(ret)); + } else { + // General expression + auto expr = parseExpression(); + if (!expr) + return nullptr; + exprList->push_back(std::move(expr)); + } + // Ensure that elements are separated by a semicolon. + if (lexer.getCurToken() != ';') + return parseError(";", "after expression"); + + // Ignore empty expressions: swallow sequences of semicolons. + while (lexer.getCurToken() == ';') + lexer.consume(Token(';')); + } + + if (lexer.getCurToken() != '}') + return parseError("}", "to close block"); + + lexer.consume(Token('}')); + return exprList; + } + + /// prototype ::= def id '(' decl_list ')' + /// decl_list ::= identifier | identifier, decl_list + std::unique_ptr parsePrototype() { + auto loc = lexer.getLastLocation(); + + if (lexer.getCurToken() != tok_def) + return parseError("def", "in prototype"); + lexer.consume(tok_def); + + if (lexer.getCurToken() != tok_identifier) + return parseError("function name", "in prototype"); + + std::string fnName(lexer.getId()); + lexer.consume(tok_identifier); + + if (lexer.getCurToken() != '(') + return parseError("(", "in prototype"); + lexer.consume(Token('(')); + + std::vector> args; + if (lexer.getCurToken() != ')') { + do { + std::string name(lexer.getId()); + auto loc = lexer.getLastLocation(); + lexer.consume(tok_identifier); + auto decl = std::make_unique(std::move(loc), name); + args.push_back(std::move(decl)); + if (lexer.getCurToken() != ',') + break; + lexer.consume(Token(',')); + if (lexer.getCurToken() != tok_identifier) + return parseError( + "identifier", "after ',' in function parameter list"); + } while (true); + } + if (lexer.getCurToken() != ')') + return parseError(")", "to end function prototype"); + + // success. + lexer.consume(Token(')')); + return std::make_unique(std::move(loc), fnName, + std::move(args)); + } + + /// Parse a function definition, we expect a prototype initiated with the + /// `def` keyword, followed by a block containing a list of expressions. + /// + /// definition ::= prototype block + std::unique_ptr parseDefinition() { + auto proto = parsePrototype(); + if (!proto) + return nullptr; + + if (auto block = parseBlock()) + return std::make_unique(std::move(proto), std::move(block)); + return nullptr; + } + + /// Get the precedence of the pending binary operator token. + int getTokPrecedence() { + if (!isascii(lexer.getCurToken())) + return -1; + + // 1 is lowest precedence. + switch (static_cast(lexer.getCurToken())) { + case '-': + return 20; + case '+': + return 20; + case '*': + return 40; + default: + return -1; + } + } + + /// Helper function to signal errors while parsing, it takes an argument + /// indicating the expected token and another argument giving more context. + /// Location is retrieved from the lexer to enrich the error message. + template + std::unique_ptr parseError(T &&expected, U &&context = "") { + auto curToken = lexer.getCurToken(); + llvm::errs() << "Parse error (" << lexer.getLastLocation().line << ", " + << lexer.getLastLocation().col << "): expected '" << expected + << "' " << context << " but has Token " << curToken; + if (isprint(curToken)) + llvm::errs() << " '" << (char)curToken << "'"; + llvm::errs() << "\n"; + return nullptr; + } +}; + +} // namespace toy + +#endif // TOY_PARSER_H diff --git a/Ch2/include/toy/run.sh b/Ch2/include/toy/run.sh new file mode 100644 index 0000000..b123ce1 --- /dev/null +++ b/Ch2/include/toy/run.sh @@ -0,0 +1,4 @@ +mlir-tblgen-18 -gen-op-decls -I /usr/lib/llvm-18/include Ops.td > Ops.h.inc +mlir-tblgen-18 -gen-op-defs -I /usr/lib/llvm-18/include Ops.td > Ops.cpp.inc +mlir-tblgen-18 -gen-dialect-decls -I /usr/lib/llvm-18/include Ops.td > Dialect.h.inc +mlir-tblgen-18 -gen-dialect-defs -I /usr/lib/llvm-18/include Ops.td > Dialect.cpp.inc \ No newline at end of file diff --git a/Ch2/mlir/Dialect.cpp b/Ch2/mlir/Dialect.cpp new file mode 100644 index 0000000..d35bd9f --- /dev/null +++ b/Ch2/mlir/Dialect.cpp @@ -0,0 +1,323 @@ +//===- Dialect.cpp - Toy IR Dialect registration in MLIR ------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements the dialect for the Toy IR: custom type parsing and +// operation verification. +// +//===----------------------------------------------------------------------===// + +#include "toy/Dialect.h" + +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/OpImplementation.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/OperationSupport.h" +#include "mlir/IR/Value.h" +#include "mlir/Interfaces/FunctionImplementation.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/LogicalResult.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/Casting.h" +#include +#include + +using namespace mlir; +using namespace mlir::toy; + +#include "toy/Dialect.cpp.inc" + +//===----------------------------------------------------------------------===// +// ToyDialect +//===----------------------------------------------------------------------===// + +/// Dialect initialization, the instance will be owned by the context. This is +/// the point of registration of types and operations for the dialect. +void ToyDialect::initialize() { + addOperations< +#define GET_OP_LIST +#include "toy/Ops.cpp.inc" + >(); +} + +//===----------------------------------------------------------------------===// +// Toy Operations +//===----------------------------------------------------------------------===// + +/// A generalized parser for binary operations. This parses the different forms +/// of 'printBinaryOp' below. +static mlir::ParseResult parseBinaryOp(mlir::OpAsmParser &parser, + mlir::OperationState &result) { + SmallVector operands; + SMLoc operandsLoc = parser.getCurrentLocation(); + Type type; + if (parser.parseOperandList(operands, /*requiredOperandCount=*/2) || + parser.parseOptionalAttrDict(result.attributes) || + parser.parseColonType(type)) + return mlir::failure(); + + // If the type is a function type, it contains the input and result types of + // this operation. + if (FunctionType funcType = llvm::dyn_cast(type)) { + if (parser.resolveOperands(operands, funcType.getInputs(), operandsLoc, + result.operands)) + return mlir::failure(); + result.addTypes(funcType.getResults()); + return mlir::success(); + } + + // Otherwise, the parsed type is the type of both operands and results. + if (parser.resolveOperands(operands, type, result.operands)) + return mlir::failure(); + result.addTypes(type); + return mlir::success(); +} + +/// A generalized printer for binary operations. It prints in two different +/// forms depending on if all of the types match. +static void printBinaryOp(mlir::OpAsmPrinter &printer, mlir::Operation *op) { + printer << " " << op->getOperands(); + printer.printOptionalAttrDict(op->getAttrs()); + printer << " : "; + + // If all of the types are the same, print the type directly. + Type resultType = *op->result_type_begin(); + if (llvm::all_of(op->getOperandTypes(), + [=](Type type) { return type == resultType; })) { + printer << resultType; + return; + } + + // Otherwise, print a functional type. + printer.printFunctionalType(op->getOperandTypes(), op->getResultTypes()); +} + +//===----------------------------------------------------------------------===// +// ConstantOp +//===----------------------------------------------------------------------===// + +/// Build a constant operation. +/// The builder is passed as an argument, so is the state that this method is +/// expected to fill in order to build the operation. +void ConstantOp::build(mlir::OpBuilder &builder, mlir::OperationState &state, + double value) { + auto dataType = RankedTensorType::get({}, builder.getF64Type()); + auto dataAttribute = DenseElementsAttr::get(dataType, value); + ConstantOp::build(builder, state, dataType, dataAttribute); +} + +/// The 'OpAsmParser' class provides a collection of methods for parsing +/// various punctuation, as well as attributes, operands, types, etc. Each of +/// these methods returns a `ParseResult`. This class is a wrapper around +/// `LogicalResult` that can be converted to a boolean `true` value on failure, +/// or `false` on success. This allows for easily chaining together a set of +/// parser rules. These rules are used to populate an `mlir::OperationState` +/// similarly to the `build` methods described above. +mlir::ParseResult ConstantOp::parse(mlir::OpAsmParser &parser, + mlir::OperationState &result) { + mlir::DenseElementsAttr value; + if (parser.parseOptionalAttrDict(result.attributes) || + parser.parseAttribute(value, "value", result.attributes)) + return failure(); + + result.addTypes(value.getType()); + return success(); +} + +/// The 'OpAsmPrinter' class is a stream that allows for formatting +/// strings, attributes, operands, types, etc. +void ConstantOp::print(mlir::OpAsmPrinter &printer) { + printer << " "; + printer.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{"value"}); + printer << getValue(); +} + +/// Verifier for the constant operation. This corresponds to the +/// `let hasVerifier = 1` in the op definition. +mlir::LogicalResult ConstantOp::verify() { + // If the return type of the constant is not an unranked tensor, the shape + // must match the shape of the attribute holding the data. + auto resultType = llvm::dyn_cast(getResult().getType()); + if (!resultType) + return success(); + + // Check that the rank of the attribute type matches the rank of the constant + // result type. + auto attrType = llvm::cast(getValue().getType()); + if (attrType.getRank() != resultType.getRank()) { + return emitOpError("return type must match the one of the attached value " + "attribute: ") + << attrType.getRank() << " != " << resultType.getRank(); + } + + // Check that each of the dimensions match between the two types. + for (int dim = 0, dimE = attrType.getRank(); dim < dimE; ++dim) { + if (attrType.getShape()[dim] != resultType.getShape()[dim]) { + return emitOpError( + "return type shape mismatches its attribute at dimension ") + << dim << ": " << attrType.getShape()[dim] + << " != " << resultType.getShape()[dim]; + } + } + return mlir::success(); +} + +//===----------------------------------------------------------------------===// +// AddOp +//===----------------------------------------------------------------------===// + +void AddOp::build(mlir::OpBuilder &builder, mlir::OperationState &state, + mlir::Value lhs, mlir::Value rhs) { + state.addTypes(UnrankedTensorType::get(builder.getF64Type())); + state.addOperands({lhs, rhs}); +} + +mlir::ParseResult AddOp::parse(mlir::OpAsmParser &parser, + mlir::OperationState &result) { + return parseBinaryOp(parser, result); +} + +void AddOp::print(mlir::OpAsmPrinter &p) { printBinaryOp(p, *this); } + +//===----------------------------------------------------------------------===// +// GenericCallOp +//===----------------------------------------------------------------------===// + +void GenericCallOp::build(mlir::OpBuilder &builder, mlir::OperationState &state, + StringRef callee, ArrayRef arguments) { + // Generic call always returns an unranked Tensor initially. + state.addTypes(UnrankedTensorType::get(builder.getF64Type())); + state.addOperands(arguments); + state.addAttribute("callee", + mlir::SymbolRefAttr::get(builder.getContext(), callee)); +} + +//===----------------------------------------------------------------------===// +// FuncOp +//===----------------------------------------------------------------------===// + +void FuncOp::build(mlir::OpBuilder &builder, mlir::OperationState &state, + llvm::StringRef name, mlir::FunctionType type, + llvm::ArrayRef attrs) { + // FunctionOpInterface provides a convenient `build` method that will populate + // the state of our FuncOp, and create an entry block. + buildWithEntryBlock(builder, state, name, type, attrs, type.getInputs()); +} + +mlir::ParseResult FuncOp::parse(mlir::OpAsmParser &parser, + mlir::OperationState &result) { + // Dispatch to the FunctionOpInterface provided utility method that parses the + // function operation. + auto buildFuncType = + [](mlir::Builder &builder, llvm::ArrayRef argTypes, + llvm::ArrayRef results, + mlir::function_interface_impl::VariadicFlag, + std::string &) { return builder.getFunctionType(argTypes, results); }; + + return mlir::function_interface_impl::parseFunctionOp( + parser, result, /*allowVariadic=*/false, + getFunctionTypeAttrName(result.name), buildFuncType, + getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name)); +} + +void FuncOp::print(mlir::OpAsmPrinter &p) { + // Dispatch to the FunctionOpInterface provided utility method that prints the + // function operation. + mlir::function_interface_impl::printFunctionOp( + p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(), + getArgAttrsAttrName(), getResAttrsAttrName()); +} + +//===----------------------------------------------------------------------===// +// MulOp +//===----------------------------------------------------------------------===// + +void MulOp::build(mlir::OpBuilder &builder, mlir::OperationState &state, + mlir::Value lhs, mlir::Value rhs) { + state.addTypes(UnrankedTensorType::get(builder.getF64Type())); + state.addOperands({lhs, rhs}); +} + +mlir::ParseResult MulOp::parse(mlir::OpAsmParser &parser, + mlir::OperationState &result) { + return parseBinaryOp(parser, result); +} + +void MulOp::print(mlir::OpAsmPrinter &p) { printBinaryOp(p, *this); } + +//===----------------------------------------------------------------------===// +// ReturnOp +//===----------------------------------------------------------------------===// + +mlir::LogicalResult ReturnOp::verify() { + // We know that the parent operation is a function, because of the 'HasParent' + // trait attached to the operation definition. + auto function = cast((*this)->getParentOp()); + + /// ReturnOps can only have a single optional operand. + if (getNumOperands() > 1) + return emitOpError() << "expects at most 1 return operand"; + + // The operand number and types must match the function signature. + const auto &results = function.getFunctionType().getResults(); + if (getNumOperands() != results.size()) + return emitOpError() << "does not return the same number of values (" + << getNumOperands() << ") as the enclosing function (" + << results.size() << ")"; + + // If the operation does not have an input, we are done. + if (!hasOperand()) + return mlir::success(); + + auto inputType = *operand_type_begin(); + auto resultType = results.front(); + + // Check that the result type of the function matches the operand type. + if (inputType == resultType || llvm::isa(inputType) || + llvm::isa(resultType)) + return mlir::success(); + + return emitError() << "type of return operand (" << inputType + << ") doesn't match function result type (" << resultType + << ")"; +} + +//===----------------------------------------------------------------------===// +// TransposeOp +//===----------------------------------------------------------------------===// + +void TransposeOp::build(mlir::OpBuilder &builder, mlir::OperationState &state, + mlir::Value value) { + state.addTypes(UnrankedTensorType::get(builder.getF64Type())); + state.addOperands(value); +} + +mlir::LogicalResult TransposeOp::verify() { + auto inputType = llvm::dyn_cast(getOperand().getType()); + auto resultType = llvm::dyn_cast(getType()); + if (!inputType || !resultType) + return mlir::success(); + + auto inputShape = inputType.getShape(); + if (!std::equal(inputShape.begin(), inputShape.end(), + resultType.getShape().rbegin())) { + return emitError() + << "expected result shape to be a transpose of the input"; + } + return mlir::success(); +} + +//===----------------------------------------------------------------------===// +// TableGen'd op method definitions +//===----------------------------------------------------------------------===// + +#define GET_OP_CLASSES +#include "toy/Ops.cpp.inc" diff --git a/Ch2/mlir/MLIRGen.cpp b/Ch2/mlir/MLIRGen.cpp new file mode 100644 index 0000000..2f0a88f --- /dev/null +++ b/Ch2/mlir/MLIRGen.cpp @@ -0,0 +1,457 @@ +//===- MLIRGen.cpp - MLIR Generation from a Toy AST -----------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements a simple IR generation targeting MLIR from a Module AST +// for the Toy language. +// +//===----------------------------------------------------------------------===// + +#include "toy/MLIRGen.h" +#include "mlir/IR/Block.h" +#include "mlir/IR/Diagnostics.h" +#include "mlir/IR/Value.h" +#include "mlir/Support/LogicalResult.h" +#include "toy/AST.h" +#include "toy/Dialect.h" + +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/Verifier.h" +#include "toy/Lexer.h" + +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/ScopedHashTable.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/ADT/Twine.h" +#include +#include +#include +#include +#include +#include + +using namespace mlir::toy; +using namespace toy; + +using llvm::ArrayRef; +using llvm::cast; +using llvm::dyn_cast; +using llvm::isa; +using llvm::ScopedHashTableScope; +using llvm::SmallVector; +using llvm::StringRef; +using llvm::Twine; + +namespace { + +/// Implementation of a simple MLIR emission from the Toy AST. +/// +/// This will emit operations that are specific to the Toy language, preserving +/// the semantics of the language and (hopefully) allow to perform accurate +/// analysis and transformation based on these high level semantics. +class MLIRGenImpl { +public: + MLIRGenImpl(mlir::MLIRContext &context) : builder(&context) {} + + /// Public API: convert the AST for a Toy module (source file) to an MLIR + /// Module operation. + mlir::ModuleOp mlirGen(ModuleAST &moduleAST) { + // We create an empty MLIR module and codegen functions one at a time and + // add them to the module. + theModule = mlir::ModuleOp::create(builder.getUnknownLoc()); + + for (FunctionAST &f : moduleAST) + mlirGen(f); + + // Verify the module after we have finished constructing it, this will check + // the structural properties of the IR and invoke any specific verifiers we + // have on the Toy operations. + if (failed(mlir::verify(theModule))) { + theModule.emitError("module verification error"); + return nullptr; + } + + return theModule; + } + +private: + /// A "module" matches a Toy source file: containing a list of functions. + mlir::ModuleOp theModule; + + /// The builder is a helper class to create IR inside a function. The builder + /// is stateful, in particular it keeps an "insertion point": this is where + /// the next operations will be introduced. + mlir::OpBuilder builder; + + /// The symbol table maps a variable name to a value in the current scope. + /// Entering a function creates a new scope, and the function arguments are + /// added to the mapping. When the processing of a function is terminated, the + /// scope is destroyed and the mappings created in this scope are dropped. + llvm::ScopedHashTable symbolTable; + + /// Helper conversion for a Toy AST location to an MLIR location. + mlir::Location loc(const Location &loc) { + return mlir::FileLineColLoc::get(builder.getStringAttr(*loc.file), loc.line, + loc.col); + } + + /// Declare a variable in the current scope, return success if the variable + /// wasn't declared yet. + mlir::LogicalResult declare(llvm::StringRef var, mlir::Value value) { + if (symbolTable.count(var)) + return mlir::failure(); + symbolTable.insert(var, value); + return mlir::success(); + } + + /// Create the prototype for an MLIR function with as many arguments as the + /// provided Toy AST prototype. + mlir::toy::FuncOp mlirGen(PrototypeAST &proto) { + auto location = loc(proto.loc()); + + // This is a generic function, the return type will be inferred later. + // Arguments type are uniformly unranked tensors. + llvm::SmallVector argTypes(proto.getArgs().size(), + getType(VarType{})); + auto funcType = builder.getFunctionType(argTypes, std::nullopt); + return builder.create(location, proto.getName(), + funcType); + } + + /// Emit a new function and add it to the MLIR module. + mlir::toy::FuncOp mlirGen(FunctionAST &funcAST) { + // Create a scope in the symbol table to hold variable declarations. + ScopedHashTableScope varScope(symbolTable); + + // Create an MLIR function for the given prototype. + builder.setInsertionPointToEnd(theModule.getBody()); + mlir::toy::FuncOp function = mlirGen(*funcAST.getProto()); + if (!function) + return nullptr; + + // Let's start the body of the function now! + mlir::Block &entryBlock = function.front(); + auto protoArgs = funcAST.getProto()->getArgs(); + + // Declare all the function arguments in the symbol table. + for (const auto nameValue : + llvm::zip(protoArgs, entryBlock.getArguments())) { + if (failed(declare(std::get<0>(nameValue)->getName(), + std::get<1>(nameValue)))) + return nullptr; + } + + // Set the insertion point in the builder to the beginning of the function + // body, it will be used throughout the codegen to create operations in this + // function. + builder.setInsertionPointToStart(&entryBlock); + + // Emit the body of the function. + if (mlir::failed(mlirGen(*funcAST.getBody()))) { + function.erase(); + return nullptr; + } + + // Implicitly return void if no return statement was emitted. + // FIXME: we may fix the parser instead to always return the last expression + // (this would possibly help the REPL case later) + ReturnOp returnOp; + if (!entryBlock.empty()) + returnOp = dyn_cast(entryBlock.back()); + if (!returnOp) { + builder.create(loc(funcAST.getProto()->loc())); + } else if (returnOp.hasOperand()) { + // Otherwise, if this return operation has an operand then add a result to + // the function. + function.setType(builder.getFunctionType( + function.getFunctionType().getInputs(), getType(VarType{}))); + } + + return function; + } + + /// Emit a binary operation + mlir::Value mlirGen(BinaryExprAST &binop) { + // First emit the operations for each side of the operation before emitting + // the operation itself. For example if the expression is `a + foo(a)` + // 1) First it will visiting the LHS, which will return a reference to the + // value holding `a`. This value should have been emitted at declaration + // time and registered in the symbol table, so nothing would be + // codegen'd. If the value is not in the symbol table, an error has been + // emitted and nullptr is returned. + // 2) Then the RHS is visited (recursively) and a call to `foo` is emitted + // and the result value is returned. If an error occurs we get a nullptr + // and propagate. + // + mlir::Value lhs = mlirGen(*binop.getLHS()); + if (!lhs) + return nullptr; + mlir::Value rhs = mlirGen(*binop.getRHS()); + if (!rhs) + return nullptr; + auto location = loc(binop.loc()); + + // Derive the operation name from the binary operator. At the moment we only + // support '+' and '*'. + switch (binop.getOp()) { + case '+': + return builder.create(location, lhs, rhs); + case '*': + return builder.create(location, lhs, rhs); + } + + emitError(location, "invalid binary operator '") << binop.getOp() << "'"; + return nullptr; + } + + /// This is a reference to a variable in an expression. The variable is + /// expected to have been declared and so should have a value in the symbol + /// table, otherwise emit an error and return nullptr. + mlir::Value mlirGen(VariableExprAST &expr) { + if (auto variable = symbolTable.lookup(expr.getName())) + return variable; + + emitError(loc(expr.loc()), "error: unknown variable '") + << expr.getName() << "'"; + return nullptr; + } + + /// Emit a return operation. This will return failure if any generation fails. + mlir::LogicalResult mlirGen(ReturnExprAST &ret) { + auto location = loc(ret.loc()); + + // 'return' takes an optional expression, handle that case here. + mlir::Value expr = nullptr; + if (ret.getExpr().has_value()) { + if (!(expr = mlirGen(**ret.getExpr()))) + return mlir::failure(); + } + + // Otherwise, this return operation has zero operands. + builder.create(location, + expr ? ArrayRef(expr) : ArrayRef()); + return mlir::success(); + } + + /// Emit a literal/constant array. It will be emitted as a flattened array of + /// data in an Attribute attached to a `toy.constant` operation. + /// See documentation on [Attributes](LangRef.md#attributes) for more details. + /// Here is an excerpt: + /// + /// Attributes are the mechanism for specifying constant data in MLIR in + /// places where a variable is never allowed [...]. They consist of a name + /// and a concrete attribute value. The set of expected attributes, their + /// structure, and their interpretation are all contextually dependent on + /// what they are attached to. + /// + /// Example, the source level statement: + /// var a<2, 3> = [[1, 2, 3], [4, 5, 6]]; + /// will be converted to: + /// %0 = "toy.constant"() {value: dense, + /// [[1.000000e+00, 2.000000e+00, 3.000000e+00], + /// [4.000000e+00, 5.000000e+00, 6.000000e+00]]>} : () -> tensor<2x3xf64> + /// + mlir::Value mlirGen(LiteralExprAST &lit) { + auto type = getType(lit.getDims()); + + // The attribute is a vector with a floating point value per element + // (number) in the array, see `collectData()` below for more details. + std::vector data; + data.reserve(std::accumulate(lit.getDims().begin(), lit.getDims().end(), 1, + std::multiplies())); + collectData(lit, data); + + // The type of this attribute is tensor of 64-bit floating-point with the + // shape of the literal. + mlir::Type elementType = builder.getF64Type(); + auto dataType = mlir::RankedTensorType::get(lit.getDims(), elementType); + + // This is the actual attribute that holds the list of values for this + // tensor literal. + auto dataAttribute = + mlir::DenseElementsAttr::get(dataType, llvm::ArrayRef(data)); + + // Build the MLIR op `toy.constant`. This invokes the `ConstantOp::build` + // method. + return builder.create(loc(lit.loc()), type, dataAttribute); + } + + /// Recursive helper function to accumulate the data that compose an array + /// literal. It flattens the nested structure in the supplied vector. For + /// example with this array: + /// [[1, 2], [3, 4]] + /// we will generate: + /// [ 1, 2, 3, 4 ] + /// Individual numbers are represented as doubles. + /// Attributes are the way MLIR attaches constant to operations. + void collectData(ExprAST &expr, std::vector &data) { + if (auto *lit = dyn_cast(&expr)) { + for (auto &value : lit->getValues()) + collectData(*value, data); + return; + } + + assert(isa(expr) && "expected literal or number expr"); + data.push_back(cast(expr).getValue()); + } + + /// Emit a call expression. It emits specific operations for the `transpose` + /// builtin. Other identifiers are assumed to be user-defined functions. + mlir::Value mlirGen(CallExprAST &call) { + llvm::StringRef callee = call.getCallee(); + auto location = loc(call.loc()); + + // Codegen the operands first. + SmallVector operands; + for (auto &expr : call.getArgs()) { + auto arg = mlirGen(*expr); + if (!arg) + return nullptr; + operands.push_back(arg); + } + + // Builtin calls have their custom operation, meaning this is a + // straightforward emission. + if (callee == "transpose") { + if (call.getArgs().size() != 1) { + emitError(location, "MLIR codegen encountered an error: toy.transpose " + "does not accept multiple arguments"); + return nullptr; + } + return builder.create(location, operands[0]); + } + + // Otherwise this is a call to a user-defined function. Calls to + // user-defined functions are mapped to a custom call that takes the callee + // name as an attribute. + return builder.create(location, callee, operands); + } + + /// Emit a print expression. It emits specific operations for two builtins: + /// transpose(x) and print(x). + mlir::LogicalResult mlirGen(PrintExprAST &call) { + auto arg = mlirGen(*call.getArg()); + if (!arg) + return mlir::failure(); + + builder.create(loc(call.loc()), arg); + return mlir::success(); + } + + /// Emit a constant for a single number (FIXME: semantic? broadcast?) + mlir::Value mlirGen(NumberExprAST &num) { + return builder.create(loc(num.loc()), num.getValue()); + } + + /// Dispatch codegen for the right expression subclass using RTTI. + mlir::Value mlirGen(ExprAST &expr) { + switch (expr.getKind()) { + case toy::ExprAST::Expr_BinOp: + return mlirGen(cast(expr)); + case toy::ExprAST::Expr_Var: + return mlirGen(cast(expr)); + case toy::ExprAST::Expr_Literal: + return mlirGen(cast(expr)); + case toy::ExprAST::Expr_Call: + return mlirGen(cast(expr)); + case toy::ExprAST::Expr_Num: + return mlirGen(cast(expr)); + default: + emitError(loc(expr.loc())) + << "MLIR codegen encountered an unhandled expr kind '" + << Twine(expr.getKind()) << "'"; + return nullptr; + } + } + + /// Handle a variable declaration, we'll codegen the expression that forms the + /// initializer and record the value in the symbol table before returning it. + /// Future expressions will be able to reference this variable through symbol + /// table lookup. + mlir::Value mlirGen(VarDeclExprAST &vardecl) { + auto *init = vardecl.getInitVal(); + if (!init) { + emitError(loc(vardecl.loc()), + "missing initializer in variable declaration"); + return nullptr; + } + + mlir::Value value = mlirGen(*init); + if (!value) + return nullptr; + + // We have the initializer value, but in case the variable was declared + // with specific shape, we emit a "reshape" operation. It will get + // optimized out later as needed. + if (!vardecl.getType().shape.empty()) { + value = builder.create(loc(vardecl.loc()), + getType(vardecl.getType()), value); + } + + // Register the value in the symbol table. + if (failed(declare(vardecl.getName(), value))) + return nullptr; + return value; + } + + /// Codegen a list of expression, return failure if one of them hit an error. + mlir::LogicalResult mlirGen(ExprASTList &blockAST) { + ScopedHashTableScope varScope(symbolTable); + for (auto &expr : blockAST) { + // Specific handling for variable declarations, return statement, and + // print. These can only appear in block list and not in nested + // expressions. + if (auto *vardecl = dyn_cast(expr.get())) { + if (!mlirGen(*vardecl)) + return mlir::failure(); + continue; + } + if (auto *ret = dyn_cast(expr.get())) + return mlirGen(*ret); + if (auto *print = dyn_cast(expr.get())) { + if (mlir::failed(mlirGen(*print))) + return mlir::success(); + continue; + } + + // Generic expression dispatch codegen. + if (!mlirGen(*expr)) + return mlir::failure(); + } + return mlir::success(); + } + + /// Build a tensor type from a list of shape dimensions. + mlir::Type getType(ArrayRef shape) { + // If the shape is empty, then this type is unranked. + if (shape.empty()) + return mlir::UnrankedTensorType::get(builder.getF64Type()); + + // Otherwise, we use the given shape. + return mlir::RankedTensorType::get(shape, builder.getF64Type()); + } + + /// Build an MLIR type from a Toy AST variable type (forward to the generic + /// getType above). + mlir::Type getType(const VarType &type) { return getType(type.shape); } +}; + +} // namespace + +namespace toy { + +// The public API for codegen. +mlir::OwningOpRef mlirGen(mlir::MLIRContext &context, + ModuleAST &moduleAST) { + return MLIRGenImpl(context).mlirGen(moduleAST); +} + +} // namespace toy diff --git a/Ch2/parser/AST.cpp b/Ch2/parser/AST.cpp new file mode 100644 index 0000000..2546f2a --- /dev/null +++ b/Ch2/parser/AST.cpp @@ -0,0 +1,237 @@ +//===- AST.cpp - Helper for printing out the Toy AST ----------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements the AST dump for the Toy language. +// +//===----------------------------------------------------------------------===// + +#include "toy/AST.h" + +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/Twine.h" +#include "llvm/ADT/TypeSwitch.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/raw_ostream.h" +#include + +using namespace toy; + +namespace { + +// RAII helper to manage increasing/decreasing the indentation as we traverse +// the AST +struct Indent { + Indent(int &level) : level(level) { ++level; } + ~Indent() { --level; } + int &level; +}; + +/// Helper class that implement the AST tree traversal and print the nodes along +/// the way. The only data member is the current indentation level. +class ASTDumper { +public: + void dump(ModuleAST *node); + +private: + void dump(const VarType &type); + void dump(VarDeclExprAST *varDecl); + void dump(ExprAST *expr); + void dump(ExprASTList *exprList); + void dump(NumberExprAST *num); + void dump(LiteralExprAST *node); + void dump(VariableExprAST *node); + void dump(ReturnExprAST *node); + void dump(BinaryExprAST *node); + void dump(CallExprAST *node); + void dump(PrintExprAST *node); + void dump(PrototypeAST *node); + void dump(FunctionAST *node); + + // Actually print spaces matching the current indentation level + void indent() { + for (int i = 0; i < curIndent; i++) + llvm::errs() << " "; + } + int curIndent = 0; +}; + +} // namespace + +/// Return a formatted string for the location of any node +template +static std::string loc(T *node) { + const auto &loc = node->loc(); + return (llvm::Twine("@") + *loc.file + ":" + llvm::Twine(loc.line) + ":" + + llvm::Twine(loc.col)) + .str(); +} + +// Helper Macro to bump the indentation level and print the leading spaces for +// the current indentations +#define INDENT() \ + Indent level_(curIndent); \ + indent(); + +/// Dispatch to a generic expressions to the appropriate subclass using RTTI +void ASTDumper::dump(ExprAST *expr) { + llvm::TypeSwitch(expr) + .Case( + [&](auto *node) { this->dump(node); }) + .Default([&](ExprAST *) { + // No match, fallback to a generic message + INDENT(); + llvm::errs() << "getKind() << ">\n"; + }); +} + +/// A variable declaration is printing the variable name, the type, and then +/// recurse in the initializer value. +void ASTDumper::dump(VarDeclExprAST *varDecl) { + INDENT(); + llvm::errs() << "VarDecl " << varDecl->getName(); + dump(varDecl->getType()); + llvm::errs() << " " << loc(varDecl) << "\n"; + dump(varDecl->getInitVal()); +} + +/// A "block", or a list of expression +void ASTDumper::dump(ExprASTList *exprList) { + INDENT(); + llvm::errs() << "Block {\n"; + for (auto &expr : *exprList) + dump(expr.get()); + indent(); + llvm::errs() << "} // Block\n"; +} + +/// A literal number, just print the value. +void ASTDumper::dump(NumberExprAST *num) { + INDENT(); + llvm::errs() << num->getValue() << " " << loc(num) << "\n"; +} + +/// Helper to print recursively a literal. This handles nested array like: +/// [ [ 1, 2 ], [ 3, 4 ] ] +/// We print out such array with the dimensions spelled out at every level: +/// <2,2>[<2>[ 1, 2 ], <2>[ 3, 4 ] ] +void printLitHelper(ExprAST *litOrNum) { + // Inside a literal expression we can have either a number or another literal + if (auto *num = llvm::dyn_cast(litOrNum)) { + llvm::errs() << num->getValue(); + return; + } + auto *literal = llvm::cast(litOrNum); + + // Print the dimension for this literal first + llvm::errs() << "<"; + llvm::interleaveComma(literal->getDims(), llvm::errs()); + llvm::errs() << ">"; + + // Now print the content, recursing on every element of the list + llvm::errs() << "[ "; + llvm::interleaveComma(literal->getValues(), llvm::errs(), + [&](auto &elt) { printLitHelper(elt.get()); }); + llvm::errs() << "]"; +} + +/// Print a literal, see the recursive helper above for the implementation. +void ASTDumper::dump(LiteralExprAST *node) { + INDENT(); + llvm::errs() << "Literal: "; + printLitHelper(node); + llvm::errs() << " " << loc(node) << "\n"; +} + +/// Print a variable reference (just a name). +void ASTDumper::dump(VariableExprAST *node) { + INDENT(); + llvm::errs() << "var: " << node->getName() << " " << loc(node) << "\n"; +} + +/// Return statement print the return and its (optional) argument. +void ASTDumper::dump(ReturnExprAST *node) { + INDENT(); + llvm::errs() << "Return\n"; + if (node->getExpr().has_value()) + return dump(*node->getExpr()); + { + INDENT(); + llvm::errs() << "(void)\n"; + } +} + +/// Print a binary operation, first the operator, then recurse into LHS and RHS. +void ASTDumper::dump(BinaryExprAST *node) { + INDENT(); + llvm::errs() << "BinOp: " << node->getOp() << " " << loc(node) << "\n"; + dump(node->getLHS()); + dump(node->getRHS()); +} + +/// Print a call expression, first the callee name and the list of args by +/// recursing into each individual argument. +void ASTDumper::dump(CallExprAST *node) { + INDENT(); + llvm::errs() << "Call '" << node->getCallee() << "' [ " << loc(node) << "\n"; + for (auto &arg : node->getArgs()) + dump(arg.get()); + indent(); + llvm::errs() << "]\n"; +} + +/// Print a builtin print call, first the builtin name and then the argument. +void ASTDumper::dump(PrintExprAST *node) { + INDENT(); + llvm::errs() << "Print [ " << loc(node) << "\n"; + dump(node->getArg()); + indent(); + llvm::errs() << "]\n"; +} + +/// Print type: only the shape is printed in between '<' and '>' +void ASTDumper::dump(const VarType &type) { + llvm::errs() << "<"; + llvm::interleaveComma(type.shape, llvm::errs()); + llvm::errs() << ">"; +} + +/// Print a function prototype, first the function name, and then the list of +/// parameters names. +void ASTDumper::dump(PrototypeAST *node) { + INDENT(); + llvm::errs() << "Proto '" << node->getName() << "' " << loc(node) << "\n"; + indent(); + llvm::errs() << "Params: ["; + llvm::interleaveComma(node->getArgs(), llvm::errs(), + [](auto &arg) { llvm::errs() << arg->getName(); }); + llvm::errs() << "]\n"; +} + +/// Print a function, first the prototype and then the body. +void ASTDumper::dump(FunctionAST *node) { + INDENT(); + llvm::errs() << "Function \n"; + dump(node->getProto()); + dump(node->getBody()); +} + +/// Print a module, actually loop over the functions and print them in sequence. +void ASTDumper::dump(ModuleAST *node) { + INDENT(); + llvm::errs() << "Module:\n"; + for (auto &f : *node) + dump(&f); +} + +namespace toy { + +// Public API +void dump(ModuleAST &module) { ASTDumper().dump(&module); } + +} // namespace toy diff --git a/Ch2/toyc.cpp b/Ch2/toyc.cpp new file mode 100644 index 0000000..e33b49b --- /dev/null +++ b/Ch2/toyc.cpp @@ -0,0 +1,145 @@ +//===- toyc.cpp - The Toy Compiler ----------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements the entry point for the Toy compiler. +// +//===----------------------------------------------------------------------===// + +#include "toy/AST.h" +#include "toy/Dialect.h" +#include "toy/Lexer.h" +#include "toy/MLIRGen.h" +#include "toy/Parser.h" +#include +#include +#include +#include + +#include "mlir/IR/AsmState.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/Parser/Parser.h" + +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/ErrorOr.h" +#include "llvm/Support/MemoryBuffer.h" +#include "llvm/Support/SourceMgr.h" +#include "llvm/Support/raw_ostream.h" + +using namespace toy; +namespace cl = llvm::cl; + +static cl::opt inputFilename(cl::Positional, + cl::desc(""), + cl::init("-"), + cl::value_desc("filename")); + +namespace { +enum InputType { Toy, MLIR }; +} // namespace +static cl::opt inputType( + "x", cl::init(Toy), cl::desc("Decided the kind of output desired"), + cl::values(clEnumValN(Toy, "toy", "load the input file as a Toy source.")), + cl::values(clEnumValN(MLIR, "mlir", + "load the input file as an MLIR file"))); + +namespace { +enum Action { None, DumpAST, DumpMLIR }; +} // namespace +static cl::opt emitAction( + "emit", cl::desc("Select the kind of output desired"), + cl::values(clEnumValN(DumpAST, "ast", "output the AST dump")), + cl::values(clEnumValN(DumpMLIR, "mlir", "output the MLIR dump"))); + +/// Returns a Toy AST resulting from parsing the file or a nullptr on error. +std::unique_ptr parseInputFile(llvm::StringRef filename) { + llvm::ErrorOr> fileOrErr = + llvm::MemoryBuffer::getFileOrSTDIN(filename); + if (std::error_code ec = fileOrErr.getError()) { + llvm::errs() << "Could not open input file: " << ec.message() << "\n"; + return nullptr; + } + auto buffer = fileOrErr.get()->getBuffer(); + LexerBuffer lexer(buffer.begin(), buffer.end(), std::string(filename)); + Parser parser(lexer); + return parser.parseModule(); +} + +int dumpMLIR() { + mlir::MLIRContext context; + // Load our Dialect in this MLIR Context. + context.getOrLoadDialect(); + + // Handle '.toy' input to the compiler. + if (inputType != InputType::MLIR && + !llvm::StringRef(inputFilename).ends_with(".mlir")) { + auto moduleAST = parseInputFile(inputFilename); + if (!moduleAST) + return 6; + mlir::OwningOpRef module = mlirGen(context, *moduleAST); + if (!module) + return 1; + + module->dump(); + return 0; + } + + // Otherwise, the input is '.mlir'. + llvm::ErrorOr> fileOrErr = + llvm::MemoryBuffer::getFileOrSTDIN(inputFilename); + if (std::error_code ec = fileOrErr.getError()) { + llvm::errs() << "Could not open input file: " << ec.message() << "\n"; + return -1; + } + + // Parse the input mlir. + llvm::SourceMgr sourceMgr; + sourceMgr.AddNewSourceBuffer(std::move(*fileOrErr), llvm::SMLoc()); + mlir::OwningOpRef module = + mlir::parseSourceFile(sourceMgr, &context); + if (!module) { + llvm::errs() << "Error can't load file " << inputFilename << "\n"; + return 3; + } + + module->dump(); + return 0; +} + +int dumpAST() { + if (inputType == InputType::MLIR) { + llvm::errs() << "Can't dump a Toy AST when the input is MLIR\n"; + return 5; + } + + auto moduleAST = parseInputFile(inputFilename); + if (!moduleAST) + return 1; + + dump(*moduleAST); + return 0; +} + +int main(int argc, char **argv) { + // Register any command line options. + mlir::registerAsmPrinterCLOptions(); + mlir::registerMLIRContextCLOptions(); + cl::ParseCommandLineOptions(argc, argv, "toy compiler\n"); + + switch (emitAction) { + case Action::DumpAST: + return dumpAST(); + case Action::DumpMLIR: + return dumpMLIR(); + default: + llvm::errs() << "No action specified (parsing only?), use -emit=\n"; + } + + return 0; +} diff --git a/Ch3/CMakeLists.txt b/Ch3/CMakeLists.txt new file mode 100644 index 0000000..d8b6bf9 --- /dev/null +++ b/Ch3/CMakeLists.txt @@ -0,0 +1,36 @@ +# For a better template to copy, see examples/standalone +include_directories(include) +add_subdirectory(include) + +set(LLVM_LINK_COMPONENTS + Support + ) + +# set(LLVM_TARGET_DEFINITIONS mlir/ToyCombine.td) +# mlir_tablegen(ToyCombine.inc -gen-rewriters) +# add_public_tablegen_target(ToyCh3CombineIncGen) + +add_executable(toyc-ch3 + toyc.cpp + parser/AST.cpp + mlir/MLIRGen.cpp + mlir/Dialect.cpp + mlir/ToyCombine.cpp + + # DEPENDS + # ToyCh3OpsIncGen + # ToyCh3CombineIncGen + ) + +include_directories(${CMAKE_CURRENT_BINARY_DIR}) +include_directories(${CMAKE_CURRENT_BINARY_DIR}/include/) +target_link_libraries(toyc-ch3 + PRIVATE + MLIRAnalysis + MLIRFunctionInterfaces + MLIRIR + MLIRParser + MLIRPass + MLIRSideEffectInterfaces + MLIRTransforms) + diff --git a/Ch3/include/CMakeLists.txt b/Ch3/include/CMakeLists.txt new file mode 100644 index 0000000..37c89d0 --- /dev/null +++ b/Ch3/include/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(toy) diff --git a/Ch3/include/toy/AST.h b/Ch3/include/toy/AST.h new file mode 100644 index 0000000..d2ba101 --- /dev/null +++ b/Ch3/include/toy/AST.h @@ -0,0 +1,246 @@ +//===- AST.h - Node definition for the Toy AST ----------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements the AST for the Toy language. It is optimized for +// simplicity, not efficiency. The AST forms a tree structure where each node +// references its children using std::unique_ptr<>. +// +//===----------------------------------------------------------------------===// + +#ifndef TOY_AST_H +#define TOY_AST_H + +#include "toy/Lexer.h" + +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/Casting.h" +#include +#include +#include + +namespace toy { + +/// A variable type with shape information. +struct VarType { + std::vector shape; +}; + +/// Base class for all expression nodes. +class ExprAST { +public: + enum ExprASTKind { + Expr_VarDecl, + Expr_Return, + Expr_Num, + Expr_Literal, + Expr_Var, + Expr_BinOp, + Expr_Call, + Expr_Print, + }; + + ExprAST(ExprASTKind kind, Location location) + : kind(kind), location(std::move(location)) {} + virtual ~ExprAST() = default; + + ExprASTKind getKind() const { return kind; } + + const Location &loc() { return location; } + +private: + const ExprASTKind kind; + Location location; +}; + +/// A block-list of expressions. +using ExprASTList = std::vector>; + +/// Expression class for numeric literals like "1.0". +class NumberExprAST : public ExprAST { + double val; + +public: + NumberExprAST(Location loc, double val) + : ExprAST(Expr_Num, std::move(loc)), val(val) {} + + double getValue() { return val; } + + /// LLVM style RTTI + static bool classof(const ExprAST *c) { return c->getKind() == Expr_Num; } +}; + +/// Expression class for a literal value. +class LiteralExprAST : public ExprAST { + std::vector> values; + std::vector dims; + +public: + LiteralExprAST(Location loc, std::vector> values, + std::vector dims) + : ExprAST(Expr_Literal, std::move(loc)), values(std::move(values)), + dims(std::move(dims)) {} + + llvm::ArrayRef> getValues() { return values; } + llvm::ArrayRef getDims() { return dims; } + + /// LLVM style RTTI + static bool classof(const ExprAST *c) { return c->getKind() == Expr_Literal; } +}; + +/// Expression class for referencing a variable, like "a". +class VariableExprAST : public ExprAST { + std::string name; + +public: + VariableExprAST(Location loc, llvm::StringRef name) + : ExprAST(Expr_Var, std::move(loc)), name(name) {} + + llvm::StringRef getName() { return name; } + + /// LLVM style RTTI + static bool classof(const ExprAST *c) { return c->getKind() == Expr_Var; } +}; + +/// Expression class for defining a variable. +class VarDeclExprAST : public ExprAST { + std::string name; + VarType type; + std::unique_ptr initVal; + +public: + VarDeclExprAST(Location loc, llvm::StringRef name, VarType type, + std::unique_ptr initVal) + : ExprAST(Expr_VarDecl, std::move(loc)), name(name), + type(std::move(type)), initVal(std::move(initVal)) {} + + llvm::StringRef getName() { return name; } + ExprAST *getInitVal() { return initVal.get(); } + const VarType &getType() { return type; } + + /// LLVM style RTTI + static bool classof(const ExprAST *c) { return c->getKind() == Expr_VarDecl; } +}; + +/// Expression class for a return operator. +class ReturnExprAST : public ExprAST { + std::optional> expr; + +public: + ReturnExprAST(Location loc, std::optional> expr) + : ExprAST(Expr_Return, std::move(loc)), expr(std::move(expr)) {} + + std::optional getExpr() { + if (expr.has_value()) + return expr->get(); + return std::nullopt; + } + + /// LLVM style RTTI + static bool classof(const ExprAST *c) { return c->getKind() == Expr_Return; } +}; + +/// Expression class for a binary operator. +class BinaryExprAST : public ExprAST { + char op; + std::unique_ptr lhs, rhs; + +public: + char getOp() { return op; } + ExprAST *getLHS() { return lhs.get(); } + ExprAST *getRHS() { return rhs.get(); } + + BinaryExprAST(Location loc, char op, std::unique_ptr lhs, + std::unique_ptr rhs) + : ExprAST(Expr_BinOp, std::move(loc)), op(op), lhs(std::move(lhs)), + rhs(std::move(rhs)) {} + + /// LLVM style RTTI + static bool classof(const ExprAST *c) { return c->getKind() == Expr_BinOp; } +}; + +/// Expression class for function calls. +class CallExprAST : public ExprAST { + std::string callee; + std::vector> args; + +public: + CallExprAST(Location loc, const std::string &callee, + std::vector> args) + : ExprAST(Expr_Call, std::move(loc)), callee(callee), + args(std::move(args)) {} + + llvm::StringRef getCallee() { return callee; } + llvm::ArrayRef> getArgs() { return args; } + + /// LLVM style RTTI + static bool classof(const ExprAST *c) { return c->getKind() == Expr_Call; } +}; + +/// Expression class for builtin print calls. +class PrintExprAST : public ExprAST { + std::unique_ptr arg; + +public: + PrintExprAST(Location loc, std::unique_ptr arg) + : ExprAST(Expr_Print, std::move(loc)), arg(std::move(arg)) {} + + ExprAST *getArg() { return arg.get(); } + + /// LLVM style RTTI + static bool classof(const ExprAST *c) { return c->getKind() == Expr_Print; } +}; + +/// This class represents the "prototype" for a function, which captures its +/// name, and its argument names (thus implicitly the number of arguments the +/// function takes). +class PrototypeAST { + Location location; + std::string name; + std::vector> args; + +public: + PrototypeAST(Location location, const std::string &name, + std::vector> args) + : location(std::move(location)), name(name), args(std::move(args)) {} + + const Location &loc() { return location; } + llvm::StringRef getName() const { return name; } + llvm::ArrayRef> getArgs() { return args; } +}; + +/// This class represents a function definition itself. +class FunctionAST { + std::unique_ptr proto; + std::unique_ptr body; + +public: + FunctionAST(std::unique_ptr proto, + std::unique_ptr body) + : proto(std::move(proto)), body(std::move(body)) {} + PrototypeAST *getProto() { return proto.get(); } + ExprASTList *getBody() { return body.get(); } +}; + +/// This class represents a list of functions to be processed together +class ModuleAST { + std::vector functions; + +public: + ModuleAST(std::vector functions) + : functions(std::move(functions)) {} + + auto begin() { return functions.begin(); } + auto end() { return functions.end(); } +}; + +void dump(ModuleAST &); + +} // namespace toy + +#endif // TOY_AST_H diff --git a/Ch3/include/toy/CMakeLists.txt b/Ch3/include/toy/CMakeLists.txt new file mode 100644 index 0000000..b8baa27 --- /dev/null +++ b/Ch3/include/toy/CMakeLists.txt @@ -0,0 +1,6 @@ +# set(LLVM_TARGET_DEFINITIONS Ops.td) +# mlir_tablegen(Ops.h.inc -gen-op-decls) +# mlir_tablegen(Ops.cpp.inc -gen-op-defs) +# mlir_tablegen(Dialect.h.inc -gen-dialect-decls) +# mlir_tablegen(Dialect.cpp.inc -gen-dialect-defs) +# add_public_tablegen_target(ToyCh3OpsIncGen) diff --git a/Ch3/include/toy/Dialect.cpp.inc b/Ch3/include/toy/Dialect.cpp.inc new file mode 100644 index 0000000..8cbc772 --- /dev/null +++ b/Ch3/include/toy/Dialect.cpp.inc @@ -0,0 +1,23 @@ +/*===- TableGen'erated file -------------------------------------*- C++ -*-===*\ +|* *| +|* Dialect Definitions *| +|* *| +|* Automatically generated file, do not edit! *| +|* From: Ops.td *| +|* *| +\*===----------------------------------------------------------------------===*/ + +MLIR_DEFINE_EXPLICIT_TYPE_ID(::mlir::toy::ToyDialect) +namespace mlir { +namespace toy { + +ToyDialect::ToyDialect(::mlir::MLIRContext *context) + : ::mlir::Dialect(getDialectNamespace(), context, ::mlir::TypeID::get()) { + + initialize(); +} + +ToyDialect::~ToyDialect() = default; + +} // namespace toy +} // namespace mlir diff --git a/Ch3/include/toy/Dialect.h b/Ch3/include/toy/Dialect.h new file mode 100644 index 0000000..292f50f --- /dev/null +++ b/Ch3/include/toy/Dialect.h @@ -0,0 +1,33 @@ +//===- Dialect.h - Dialect definition for the Toy IR ----------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements the IR Dialect for the Toy language. +// See docs/Tutorials/Toy/Ch-2.md for more information. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_TUTORIAL_TOY_DIALECT_H_ +#define MLIR_TUTORIAL_TOY_DIALECT_H_ + +#include "mlir/Bytecode/BytecodeOpInterface.h" +#include "mlir/IR/Dialect.h" +#include "mlir/IR/SymbolTable.h" +#include "mlir/Interfaces/CallInterfaces.h" +#include "mlir/Interfaces/FunctionInterfaces.h" +#include "mlir/Interfaces/SideEffectInterfaces.h" + +/// Include the auto-generated header file containing the declaration of the toy +/// dialect. +#include "toy/Dialect.h.inc" + +/// Include the auto-generated header file containing the declarations of the +/// toy operations. +#define GET_OP_CLASSES +#include "toy/Ops.h.inc" + +#endif // MLIR_TUTORIAL_TOY_DIALECT_H_ diff --git a/Ch3/include/toy/Dialect.h.inc b/Ch3/include/toy/Dialect.h.inc new file mode 100644 index 0000000..f19d867 --- /dev/null +++ b/Ch3/include/toy/Dialect.h.inc @@ -0,0 +1,26 @@ +/*===- TableGen'erated file -------------------------------------*- C++ -*-===*\ +|* *| +|* Dialect Declarations *| +|* *| +|* Automatically generated file, do not edit! *| +|* From: Ops.td *| +|* *| +\*===----------------------------------------------------------------------===*/ + +namespace mlir { +namespace toy { + +class ToyDialect : public ::mlir::Dialect { + explicit ToyDialect(::mlir::MLIRContext *context); + + void initialize(); + friend class ::mlir::MLIRContext; +public: + ~ToyDialect() override; + static constexpr ::llvm::StringLiteral getDialectNamespace() { + return ::llvm::StringLiteral("toy"); + } +}; +} // namespace toy +} // namespace mlir +MLIR_DECLARE_EXPLICIT_TYPE_ID(::mlir::toy::ToyDialect) diff --git a/Ch3/include/toy/Lexer.h b/Ch3/include/toy/Lexer.h new file mode 100644 index 0000000..3c59cd9 --- /dev/null +++ b/Ch3/include/toy/Lexer.h @@ -0,0 +1,232 @@ +//===- Lexer.h - Lexer for the Toy language -------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements a simple Lexer for the Toy language. +// +//===----------------------------------------------------------------------===// + +#ifndef TOY_LEXER_H +#define TOY_LEXER_H + +#include "llvm/ADT/StringRef.h" + +#include +#include + +namespace toy { + +/// Structure definition a location in a file. +struct Location { + std::shared_ptr file; ///< filename. + int line; ///< line number. + int col; ///< column number. +}; + +// List of Token returned by the lexer. +enum Token : int { + tok_semicolon = ';', + tok_parenthese_open = '(', + tok_parenthese_close = ')', + tok_bracket_open = '{', + tok_bracket_close = '}', + tok_sbracket_open = '[', + tok_sbracket_close = ']', + + tok_eof = -1, + + // commands + tok_return = -2, + tok_var = -3, + tok_def = -4, + + // primary + tok_identifier = -5, + tok_number = -6, +}; + +/// The Lexer is an abstract base class providing all the facilities that the +/// Parser expects. It goes through the stream one token at a time and keeps +/// track of the location in the file for debugging purpose. +/// It relies on a subclass to provide a `readNextLine()` method. The subclass +/// can proceed by reading the next line from the standard input or from a +/// memory mapped file. +class Lexer { +public: + /// Create a lexer for the given filename. The filename is kept only for + /// debugging purpose (attaching a location to a Token). + Lexer(std::string filename) + : lastLocation( + {std::make_shared(std::move(filename)), 0, 0}) {} + virtual ~Lexer() = default; + + /// Look at the current token in the stream. + Token getCurToken() { return curTok; } + + /// Move to the next token in the stream and return it. + Token getNextToken() { return curTok = getTok(); } + + /// Move to the next token in the stream, asserting on the current token + /// matching the expectation. + void consume(Token tok) { + assert(tok == curTok && "consume Token mismatch expectation"); + getNextToken(); + } + + /// Return the current identifier (prereq: getCurToken() == tok_identifier) + llvm::StringRef getId() { + assert(curTok == tok_identifier); + return identifierStr; + } + + /// Return the current number (prereq: getCurToken() == tok_number) + double getValue() { + assert(curTok == tok_number); + return numVal; + } + + /// Return the location for the beginning of the current token. + Location getLastLocation() { return lastLocation; } + + // Return the current line in the file. + int getLine() { return curLineNum; } + + // Return the current column in the file. + int getCol() { return curCol; } + +private: + /// Delegate to a derived class fetching the next line. Returns an empty + /// string to signal end of file (EOF). Lines are expected to always finish + /// with "\n" + virtual llvm::StringRef readNextLine() = 0; + + /// Return the next character from the stream. This manages the buffer for the + /// current line and request the next line buffer to the derived class as + /// needed. + int getNextChar() { + // The current line buffer should not be empty unless it is the end of file. + if (curLineBuffer.empty()) + return EOF; + ++curCol; + auto nextchar = curLineBuffer.front(); + curLineBuffer = curLineBuffer.drop_front(); + if (curLineBuffer.empty()) + curLineBuffer = readNextLine(); + if (nextchar == '\n') { + ++curLineNum; + curCol = 0; + } + return nextchar; + } + + /// Return the next token from standard input. + Token getTok() { + // Skip any whitespace. + while (isspace(lastChar)) + lastChar = Token(getNextChar()); + + // Save the current location before reading the token characters. + lastLocation.line = curLineNum; + lastLocation.col = curCol; + + // Identifier: [a-zA-Z][a-zA-Z0-9_]* + if (isalpha(lastChar)) { + identifierStr = (char)lastChar; + while (isalnum((lastChar = Token(getNextChar()))) || lastChar == '_') + identifierStr += (char)lastChar; + + if (identifierStr == "return") + return tok_return; + if (identifierStr == "def") + return tok_def; + if (identifierStr == "var") + return tok_var; + return tok_identifier; + } + + // Number: [0-9.]+ + if (isdigit(lastChar) || lastChar == '.') { + std::string numStr; + do { + numStr += lastChar; + lastChar = Token(getNextChar()); + } while (isdigit(lastChar) || lastChar == '.'); + + numVal = strtod(numStr.c_str(), nullptr); + return tok_number; + } + + if (lastChar == '#') { + // Comment until end of line. + do { + lastChar = Token(getNextChar()); + } while (lastChar != EOF && lastChar != '\n' && lastChar != '\r'); + + if (lastChar != EOF) + return getTok(); + } + + // Check for end of file. Don't eat the EOF. + if (lastChar == EOF) + return tok_eof; + + // Otherwise, just return the character as its ascii value. + Token thisChar = Token(lastChar); + lastChar = Token(getNextChar()); + return thisChar; + } + + /// The last token read from the input. + Token curTok = tok_eof; + + /// Location for `curTok`. + Location lastLocation; + + /// If the current Token is an identifier, this string contains the value. + std::string identifierStr; + + /// If the current Token is a number, this contains the value. + double numVal = 0; + + /// The last value returned by getNextChar(). We need to keep it around as we + /// always need to read ahead one character to decide when to end a token and + /// we can't put it back in the stream after reading from it. + Token lastChar = Token(' '); + + /// Keep track of the current line number in the input stream + int curLineNum = 0; + + /// Keep track of the current column number in the input stream + int curCol = 0; + + /// Buffer supplied by the derived class on calls to `readNextLine()` + llvm::StringRef curLineBuffer = "\n"; +}; + +/// A lexer implementation operating on a buffer in memory. +class LexerBuffer final : public Lexer { +public: + LexerBuffer(const char *begin, const char *end, std::string filename) + : Lexer(std::move(filename)), current(begin), end(end) {} + +private: + /// Provide one line at a time to the Lexer, return an empty string when + /// reaching the end of the buffer. + llvm::StringRef readNextLine() override { + auto *begin = current; + while (current <= end && *current && *current != '\n') + ++current; + if (current <= end && *current) + ++current; + llvm::StringRef result{begin, static_cast(current - begin)}; + return result; + } + const char *current, *end; +}; +} // namespace toy + +#endif // TOY_LEXER_H diff --git a/Ch3/include/toy/MLIRGen.h b/Ch3/include/toy/MLIRGen.h new file mode 100644 index 0000000..fe9dbe5 --- /dev/null +++ b/Ch3/include/toy/MLIRGen.h @@ -0,0 +1,35 @@ +//===- MLIRGen.h - MLIR Generation from a Toy AST -------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file declares a simple interface to perform IR generation targeting MLIR +// from a Module AST for the Toy language. +// +//===----------------------------------------------------------------------===// + +#ifndef TOY_MLIRGEN_H +#define TOY_MLIRGEN_H + +#include + +namespace mlir { +class MLIRContext; +template +class OwningOpRef; +class ModuleOp; +} // namespace mlir + +namespace toy { +class ModuleAST; + +/// Emit IR for the given Toy moduleAST, returns a newly created MLIR module +/// or nullptr on failure. +mlir::OwningOpRef mlirGen(mlir::MLIRContext &context, + ModuleAST &moduleAST); +} // namespace toy + +#endif // TOY_MLIRGEN_H diff --git a/Ch3/include/toy/Ops.cpp.inc b/Ch3/include/toy/Ops.cpp.inc new file mode 100644 index 0000000..555fb7a --- /dev/null +++ b/Ch3/include/toy/Ops.cpp.inc @@ -0,0 +1,2061 @@ +/*===- TableGen'erated file -------------------------------------*- C++ -*-===*\ +|* *| +|* Op Definitions *| +|* *| +|* Automatically generated file, do not edit! *| +|* From: Ops.td *| +|* *| +\*===----------------------------------------------------------------------===*/ + +#ifdef GET_OP_LIST +#undef GET_OP_LIST + +::mlir::toy::AddOp, +::mlir::toy::ConstantOp, +::mlir::toy::FuncOp, +::mlir::toy::GenericCallOp, +::mlir::toy::MulOp, +::mlir::toy::PrintOp, +::mlir::toy::ReshapeOp, +::mlir::toy::ReturnOp, +::mlir::toy::TransposeOp +#endif // GET_OP_LIST + +#ifdef GET_OP_CLASSES +#undef GET_OP_CLASSES + + +//===----------------------------------------------------------------------===// +// Local Utility Method Definitions +//===----------------------------------------------------------------------===// + +namespace mlir { +namespace toy { + +static ::mlir::LogicalResult __mlir_ods_local_type_constraint_Ops0( + ::mlir::Operation *op, ::mlir::Type type, ::llvm::StringRef valueKind, + unsigned valueIndex) { + if (!(((::llvm::isa<::mlir::TensorType>(type))) && ([](::mlir::Type elementType) { return (elementType.isF64()); }(::llvm::cast<::mlir::ShapedType>(type).getElementType())))) { + return op->emitOpError(valueKind) << " #" << valueIndex + << " must be tensor of 64-bit float values, but got " << type; + } + return ::mlir::success(); +} + +static ::mlir::LogicalResult __mlir_ods_local_type_constraint_Ops1( + ::mlir::Operation *op, ::mlir::Type type, ::llvm::StringRef valueKind, + unsigned valueIndex) { + if (!(((::llvm::isa<::mlir::TensorType>(type))) && ([](::mlir::Type elementType) { return (elementType.isF64()); }(::llvm::cast<::mlir::ShapedType>(type).getElementType())))) { + return op->emitOpError(valueKind) << " #" << valueIndex + << " must be variadic of tensor of 64-bit float values, but got " << type; + } + return ::mlir::success(); +} + +static ::mlir::LogicalResult __mlir_ods_local_type_constraint_Ops2( + ::mlir::Operation *op, ::mlir::Type type, ::llvm::StringRef valueKind, + unsigned valueIndex) { + if (!((((::llvm::isa<::mlir::RankedTensorType>(type))) && ((::llvm::cast<::mlir::ShapedType>(type).hasStaticShape()))) && ([](::mlir::Type elementType) { return (elementType.isF64()); }(::llvm::cast<::mlir::ShapedType>(type).getElementType())))) { + return op->emitOpError(valueKind) << " #" << valueIndex + << " must be statically shaped tensor of 64-bit float values, but got " << type; + } + return ::mlir::success(); +} + +static ::mlir::LogicalResult __mlir_ods_local_attr_constraint_Ops0( + ::mlir::Attribute attr, ::llvm::StringRef attrName, llvm::function_ref<::mlir::InFlightDiagnostic()> emitError) { + if (attr && !((::llvm::isa<::mlir::DenseFPElementsAttr>(attr) &&::llvm::cast<::mlir::DenseElementsAttr>(attr).getType().getElementType().isF64()))) + return emitError() << "attribute '" << attrName + << "' failed to satisfy constraint: 64-bit float elements attribute"; + return ::mlir::success(); +} +static ::mlir::LogicalResult __mlir_ods_local_attr_constraint_Ops0( + ::mlir::Operation *op, ::mlir::Attribute attr, ::llvm::StringRef attrName) { + return __mlir_ods_local_attr_constraint_Ops0(attr, attrName, [op]() { + return op->emitOpError(); + }); +} + +static ::mlir::LogicalResult __mlir_ods_local_attr_constraint_Ops1( + ::mlir::Attribute attr, ::llvm::StringRef attrName, llvm::function_ref<::mlir::InFlightDiagnostic()> emitError) { + if (attr && !((::llvm::isa<::mlir::StringAttr>(attr)))) + return emitError() << "attribute '" << attrName + << "' failed to satisfy constraint: string attribute"; + return ::mlir::success(); +} +static ::mlir::LogicalResult __mlir_ods_local_attr_constraint_Ops1( + ::mlir::Operation *op, ::mlir::Attribute attr, ::llvm::StringRef attrName) { + return __mlir_ods_local_attr_constraint_Ops1(attr, attrName, [op]() { + return op->emitOpError(); + }); +} + +static ::mlir::LogicalResult __mlir_ods_local_attr_constraint_Ops2( + ::mlir::Attribute attr, ::llvm::StringRef attrName, llvm::function_ref<::mlir::InFlightDiagnostic()> emitError) { + if (attr && !(((::llvm::isa<::mlir::TypeAttr>(attr))) && ((::llvm::isa<::mlir::FunctionType>(::llvm::cast<::mlir::TypeAttr>(attr).getValue()))) && ((::llvm::isa<::mlir::FunctionType>(::llvm::cast<::mlir::TypeAttr>(attr).getValue()))))) + return emitError() << "attribute '" << attrName + << "' failed to satisfy constraint: type attribute of function type"; + return ::mlir::success(); +} +static ::mlir::LogicalResult __mlir_ods_local_attr_constraint_Ops2( + ::mlir::Operation *op, ::mlir::Attribute attr, ::llvm::StringRef attrName) { + return __mlir_ods_local_attr_constraint_Ops2(attr, attrName, [op]() { + return op->emitOpError(); + }); +} + +static ::mlir::LogicalResult __mlir_ods_local_attr_constraint_Ops3( + ::mlir::Attribute attr, ::llvm::StringRef attrName, llvm::function_ref<::mlir::InFlightDiagnostic()> emitError) { + if (attr && !(((::llvm::isa<::mlir::ArrayAttr>(attr))) && (::llvm::all_of(::llvm::cast<::mlir::ArrayAttr>(attr), [&](::mlir::Attribute attr) { return attr && ((::llvm::isa<::mlir::DictionaryAttr>(attr))); })))) + return emitError() << "attribute '" << attrName + << "' failed to satisfy constraint: Array of dictionary attributes"; + return ::mlir::success(); +} +static ::mlir::LogicalResult __mlir_ods_local_attr_constraint_Ops3( + ::mlir::Operation *op, ::mlir::Attribute attr, ::llvm::StringRef attrName) { + return __mlir_ods_local_attr_constraint_Ops3(attr, attrName, [op]() { + return op->emitOpError(); + }); +} + +static ::mlir::LogicalResult __mlir_ods_local_attr_constraint_Ops4( + ::mlir::Attribute attr, ::llvm::StringRef attrName, llvm::function_ref<::mlir::InFlightDiagnostic()> emitError) { + if (attr && !((::llvm::isa<::mlir::FlatSymbolRefAttr>(attr)))) + return emitError() << "attribute '" << attrName + << "' failed to satisfy constraint: flat symbol reference attribute"; + return ::mlir::success(); +} +static ::mlir::LogicalResult __mlir_ods_local_attr_constraint_Ops4( + ::mlir::Operation *op, ::mlir::Attribute attr, ::llvm::StringRef attrName) { + return __mlir_ods_local_attr_constraint_Ops4(attr, attrName, [op]() { + return op->emitOpError(); + }); +} + +static ::mlir::LogicalResult __mlir_ods_local_region_constraint_Ops0( + ::mlir::Operation *op, ::mlir::Region ®ion, ::llvm::StringRef regionName, + unsigned regionIndex) { + if (!((true))) { + return op->emitOpError("region #") << regionIndex + << (regionName.empty() ? " " : " ('" + regionName + "') ") + << "failed to verify constraint: any region"; + } + return ::mlir::success(); +} +} // namespace toy +} // namespace mlir +namespace mlir { +namespace toy { + +//===----------------------------------------------------------------------===// +// ::mlir::toy::AddOp definitions +//===----------------------------------------------------------------------===// + +namespace detail { +AddOpGenericAdaptorBase::AddOpGenericAdaptorBase(::mlir::DictionaryAttr attrs, const ::mlir::EmptyProperties &properties, ::mlir::RegionRange regions) : odsAttrs(attrs), odsRegions(regions) { if (odsAttrs) + odsOpName.emplace("toy.add", odsAttrs.getContext()); +} + +AddOpGenericAdaptorBase::AddOpGenericAdaptorBase(AddOp op) : AddOpGenericAdaptorBase(op->getAttrDictionary(), op.getProperties(), op->getRegions()) {} + +std::pair AddOpGenericAdaptorBase::getODSOperandIndexAndLength(unsigned index, unsigned odsOperandsSize) { + return {index, 1}; +} + +::mlir::DictionaryAttr AddOpGenericAdaptorBase::getAttributes() { + return odsAttrs; +} + +} // namespace detail +AddOpAdaptor::AddOpAdaptor(AddOp op) : AddOpGenericAdaptor(op->getOperands(), op) {} + +::mlir::LogicalResult AddOpAdaptor::verify(::mlir::Location loc) { + return ::mlir::success(); +} + +std::pair AddOp::getODSOperandIndexAndLength(unsigned index) { + return {index, 1}; +} + +::mlir::Operation::operand_range AddOp::getODSOperands(unsigned index) { + auto valueRange = getODSOperandIndexAndLength(index); + return {std::next(getOperation()->operand_begin(), valueRange.first), + std::next(getOperation()->operand_begin(), valueRange.first + valueRange.second)}; +} + +::mlir::TypedValue<::mlir::TensorType> AddOp::getLhs() { + return ::llvm::cast<::mlir::TypedValue<::mlir::TensorType>>(*getODSOperands(0).begin()); +} + +::mlir::TypedValue<::mlir::TensorType> AddOp::getRhs() { + return ::llvm::cast<::mlir::TypedValue<::mlir::TensorType>>(*getODSOperands(1).begin()); +} + +::mlir::OpOperand &AddOp::getLhsMutable() { + auto range = getODSOperandIndexAndLength(0); + return getOperation()->getOpOperand(range.first); +} + +::mlir::OpOperand &AddOp::getRhsMutable() { + auto range = getODSOperandIndexAndLength(1); + return getOperation()->getOpOperand(range.first); +} + +std::pair AddOp::getODSResultIndexAndLength(unsigned index) { + return {index, 1}; +} + +::mlir::Operation::result_range AddOp::getODSResults(unsigned index) { + auto valueRange = getODSResultIndexAndLength(index); + return {std::next(getOperation()->result_begin(), valueRange.first), + std::next(getOperation()->result_begin(), valueRange.first + valueRange.second)}; +} + +void AddOp::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type resultType0, ::mlir::Value lhs, ::mlir::Value rhs) { + odsState.addOperands(lhs); + odsState.addOperands(rhs); + odsState.addTypes(resultType0); +} + +void AddOp::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::Value lhs, ::mlir::Value rhs) { + odsState.addOperands(lhs); + odsState.addOperands(rhs); + assert(resultTypes.size() == 1u && "mismatched number of results"); + odsState.addTypes(resultTypes); +} + +void AddOp::build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes) { + assert(operands.size() == 2u && "mismatched number of parameters"); + odsState.addOperands(operands); + odsState.addAttributes(attributes); + assert(resultTypes.size() == 1u && "mismatched number of return types"); + odsState.addTypes(resultTypes); +} + +::mlir::LogicalResult AddOp::verifyInvariantsImpl() { + { + unsigned index = 0; (void)index; + auto valueGroup0 = getODSOperands(0); + + for (auto v : valueGroup0) { + if (::mlir::failed(__mlir_ods_local_type_constraint_Ops0(*this, v.getType(), "operand", index++))) + return ::mlir::failure(); + } + auto valueGroup1 = getODSOperands(1); + + for (auto v : valueGroup1) { + if (::mlir::failed(__mlir_ods_local_type_constraint_Ops0(*this, v.getType(), "operand", index++))) + return ::mlir::failure(); + } + } + { + unsigned index = 0; (void)index; + auto valueGroup0 = getODSResults(0); + + for (auto v : valueGroup0) { + if (::mlir::failed(__mlir_ods_local_type_constraint_Ops0(*this, v.getType(), "result", index++))) + return ::mlir::failure(); + } + } + return ::mlir::success(); +} + +::mlir::LogicalResult AddOp::verifyInvariants() { + return verifyInvariantsImpl(); +} + +void AddOp::getEffects(::llvm::SmallVectorImpl<::mlir::SideEffects::EffectInstance<::mlir::MemoryEffects::Effect>> &effects) { +} + +} // namespace toy +} // namespace mlir +MLIR_DEFINE_EXPLICIT_TYPE_ID(::mlir::toy::AddOp) + +namespace mlir { +namespace toy { + +//===----------------------------------------------------------------------===// +// ::mlir::toy::ConstantOp definitions +//===----------------------------------------------------------------------===// + +namespace detail { +ConstantOpGenericAdaptorBase::ConstantOpGenericAdaptorBase(::mlir::DictionaryAttr attrs, const Properties &properties, ::mlir::RegionRange regions) : odsAttrs(attrs), properties(properties), odsRegions(regions) { if (odsAttrs) + odsOpName.emplace("toy.constant", odsAttrs.getContext()); +} + +ConstantOpGenericAdaptorBase::ConstantOpGenericAdaptorBase(ConstantOp op) : ConstantOpGenericAdaptorBase(op->getDiscardableAttrDictionary(), op.getProperties(), op->getRegions()) {} + +std::pair ConstantOpGenericAdaptorBase::getODSOperandIndexAndLength(unsigned index, unsigned odsOperandsSize) { + return {index, 1}; +} + +::mlir::DictionaryAttr ConstantOpGenericAdaptorBase::getAttributes() { + return odsAttrs; +} + +::mlir::DenseElementsAttr ConstantOpGenericAdaptorBase::getValueAttr() { + auto attr = ::llvm::cast<::mlir::DenseElementsAttr>(getProperties().value); + return attr; +} + +::mlir::DenseElementsAttr ConstantOpGenericAdaptorBase::getValue() { + auto attr = getValueAttr(); + return attr; +} + +} // namespace detail +ConstantOpAdaptor::ConstantOpAdaptor(ConstantOp op) : ConstantOpGenericAdaptor(op->getOperands(), op) {} + +::mlir::LogicalResult ConstantOpAdaptor::verify(::mlir::Location loc) { + auto tblgen_value = getProperties().value; (void)tblgen_value; + if (!tblgen_value) return emitError(loc, "'toy.constant' op ""requires attribute 'value'"); + + if (tblgen_value && !((::llvm::isa<::mlir::DenseFPElementsAttr>(tblgen_value) &&::llvm::cast<::mlir::DenseElementsAttr>(tblgen_value).getType().getElementType().isF64()))) + return emitError(loc, "'toy.constant' op ""attribute 'value' failed to satisfy constraint: 64-bit float elements attribute"); + return ::mlir::success(); +} + +std::pair ConstantOp::getODSOperandIndexAndLength(unsigned index) { + return {index, 1}; +} + +::mlir::Operation::operand_range ConstantOp::getODSOperands(unsigned index) { + auto valueRange = getODSOperandIndexAndLength(index); + return {std::next(getOperation()->operand_begin(), valueRange.first), + std::next(getOperation()->operand_begin(), valueRange.first + valueRange.second)}; +} + +std::pair ConstantOp::getODSResultIndexAndLength(unsigned index) { + return {index, 1}; +} + +::mlir::Operation::result_range ConstantOp::getODSResults(unsigned index) { + auto valueRange = getODSResultIndexAndLength(index); + return {std::next(getOperation()->result_begin(), valueRange.first), + std::next(getOperation()->result_begin(), valueRange.first + valueRange.second)}; +} + +::mlir::LogicalResult ConstantOp::setPropertiesFromAttr(Properties &prop, ::mlir::Attribute attr, ::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError) { + ::mlir::DictionaryAttr dict = ::llvm::dyn_cast<::mlir::DictionaryAttr>(attr); + if (!dict) { + emitError() << "expected DictionaryAttr to set properties"; + return ::mlir::failure(); + } + + { + auto &propStorage = prop.value; + auto attr = dict.get("value"); + if (attr || /*isRequired=*/true) { + if (!attr) { + emitError() << "expected key entry for value in DictionaryAttr to set " + "Properties."; + return ::mlir::failure(); + } + auto convertedAttr = ::llvm::dyn_cast>(attr); + if (convertedAttr) { + propStorage = convertedAttr; + } else { + emitError() << "Invalid attribute `value` in property conversion: " << attr; + return ::mlir::failure(); + } + } + } + return ::mlir::success(); +} + +::mlir::Attribute ConstantOp::getPropertiesAsAttr(::mlir::MLIRContext *ctx, const Properties &prop) { + ::mlir::SmallVector<::mlir::NamedAttribute> attrs; + ::mlir::Builder odsBuilder{ctx}; + + { + const auto &propStorage = prop.value; + if (propStorage) + attrs.push_back(odsBuilder.getNamedAttr("value", + propStorage)); + } + + if (!attrs.empty()) + return odsBuilder.getDictionaryAttr(attrs); + return {}; +} + +llvm::hash_code ConstantOp::computePropertiesHash(const Properties &prop) { + return llvm::hash_combine( + llvm::hash_value(prop.value.getAsOpaquePointer())); +} + +std::optional ConstantOp::getInherentAttr(::mlir::MLIRContext *ctx, const Properties &prop, llvm::StringRef name) { + if (name == "value") + return prop.value; + return std::nullopt; +} + +void ConstantOp::setInherentAttr(Properties &prop, llvm::StringRef name, mlir::Attribute value) { + if (name == "value") { + prop.value = ::llvm::dyn_cast_or_null>(value); + return; + } +} + +void ConstantOp::populateInherentAttrs(::mlir::MLIRContext *ctx, const Properties &prop, ::mlir::NamedAttrList &attrs) { + if (prop.value) attrs.append("value", prop.value); +} + +::mlir::LogicalResult ConstantOp::verifyInherentAttrs(::mlir::OperationName opName, ::mlir::NamedAttrList &attrs, llvm::function_ref<::mlir::InFlightDiagnostic()> emitError) { + { + ::mlir::Attribute attr = attrs.get(getValueAttrName(opName)); + if (attr && ::mlir::failed(__mlir_ods_local_attr_constraint_Ops0(attr, "value", emitError))) + return ::mlir::failure(); + } + return ::mlir::success(); +} + +::mlir::LogicalResult ConstantOp::readProperties(::mlir::DialectBytecodeReader &reader, ::mlir::OperationState &state) { + auto &prop = state.getOrAddProperties(); (void)prop; + if (::mlir::failed(reader.readAttribute(prop.value))) + return ::mlir::failure(); + return ::mlir::success(); +} + +void ConstantOp::writeProperties(::mlir::DialectBytecodeWriter &writer) { + auto &prop = getProperties(); (void)prop; + writer.writeAttribute(prop.value); +} + +::mlir::DenseElementsAttr ConstantOp::getValueAttr() { + return ::llvm::cast<::mlir::DenseElementsAttr>(getProperties().value); +} + +::mlir::DenseElementsAttr ConstantOp::getValue() { + auto attr = getValueAttr(); + return attr; +} + +void ConstantOp::setValueAttr(::mlir::DenseElementsAttr attr) { + (*this)->setAttr(getValueAttrName(), attr); +} + +void ConstantOp::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, DenseElementsAttr value) { + build(odsBuilder, odsState, value.getType(), value); + +} + +void ConstantOp::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type resultType0, ::mlir::DenseElementsAttr value) { + odsState.getOrAddProperties().value = value; + odsState.addTypes(resultType0); +} + +void ConstantOp::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::DenseElementsAttr value) { + odsState.getOrAddProperties().value = value; + assert(resultTypes.size() == 1u && "mismatched number of results"); + odsState.addTypes(resultTypes); +} + +void ConstantOp::build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes) { + assert(operands.size() == 0u && "mismatched number of parameters"); + odsState.addOperands(operands); + odsState.addAttributes(attributes); + assert(resultTypes.size() == 1u && "mismatched number of return types"); + odsState.addTypes(resultTypes); +} + +::mlir::LogicalResult ConstantOp::verifyInvariantsImpl() { + auto tblgen_value = getProperties().value; (void)tblgen_value; + if (!tblgen_value) return emitOpError("requires attribute 'value'"); + + if (::mlir::failed(__mlir_ods_local_attr_constraint_Ops0(*this, tblgen_value, "value"))) + return ::mlir::failure(); + { + unsigned index = 0; (void)index; + auto valueGroup0 = getODSResults(0); + + for (auto v : valueGroup0) { + if (::mlir::failed(__mlir_ods_local_type_constraint_Ops0(*this, v.getType(), "result", index++))) + return ::mlir::failure(); + } + } + return ::mlir::success(); +} + +::mlir::LogicalResult ConstantOp::verifyInvariants() { + if(::mlir::succeeded(verifyInvariantsImpl()) && ::mlir::succeeded(verify())) + return ::mlir::success(); + return ::mlir::failure(); +} + +void ConstantOp::getEffects(::llvm::SmallVectorImpl<::mlir::SideEffects::EffectInstance<::mlir::MemoryEffects::Effect>> &effects) { +} + +} // namespace toy +} // namespace mlir +MLIR_DEFINE_EXPLICIT_TYPE_ID(::mlir::toy::ConstantOp) + +namespace mlir { +namespace toy { + +//===----------------------------------------------------------------------===// +// ::mlir::toy::FuncOp definitions +//===----------------------------------------------------------------------===// + +namespace detail { +FuncOpGenericAdaptorBase::FuncOpGenericAdaptorBase(::mlir::DictionaryAttr attrs, const Properties &properties, ::mlir::RegionRange regions) : odsAttrs(attrs), properties(properties), odsRegions(regions) { if (odsAttrs) + odsOpName.emplace("toy.func", odsAttrs.getContext()); +} + +FuncOpGenericAdaptorBase::FuncOpGenericAdaptorBase(FuncOp op) : FuncOpGenericAdaptorBase(op->getDiscardableAttrDictionary(), op.getProperties(), op->getRegions()) {} + +std::pair FuncOpGenericAdaptorBase::getODSOperandIndexAndLength(unsigned index, unsigned odsOperandsSize) { + return {index, 1}; +} + +::mlir::DictionaryAttr FuncOpGenericAdaptorBase::getAttributes() { + return odsAttrs; +} + +::mlir::StringAttr FuncOpGenericAdaptorBase::getSymNameAttr() { + auto attr = ::llvm::cast<::mlir::StringAttr>(getProperties().sym_name); + return attr; +} + +::llvm::StringRef FuncOpGenericAdaptorBase::getSymName() { + auto attr = getSymNameAttr(); + return attr.getValue(); +} + +::mlir::TypeAttr FuncOpGenericAdaptorBase::getFunctionTypeAttr() { + auto attr = ::llvm::cast<::mlir::TypeAttr>(getProperties().function_type); + return attr; +} + +::mlir::FunctionType FuncOpGenericAdaptorBase::getFunctionType() { + auto attr = getFunctionTypeAttr(); + return ::llvm::cast<::mlir::FunctionType>(attr.getValue()); +} + +::mlir::ArrayAttr FuncOpGenericAdaptorBase::getArgAttrsAttr() { + auto attr = ::llvm::dyn_cast_or_null<::mlir::ArrayAttr>(getProperties().arg_attrs); + return attr; +} + +::std::optional< ::mlir::ArrayAttr > FuncOpGenericAdaptorBase::getArgAttrs() { + auto attr = getArgAttrsAttr(); + return attr ? ::std::optional< ::mlir::ArrayAttr >(attr) : (::std::nullopt); +} + +::mlir::ArrayAttr FuncOpGenericAdaptorBase::getResAttrsAttr() { + auto attr = ::llvm::dyn_cast_or_null<::mlir::ArrayAttr>(getProperties().res_attrs); + return attr; +} + +::std::optional< ::mlir::ArrayAttr > FuncOpGenericAdaptorBase::getResAttrs() { + auto attr = getResAttrsAttr(); + return attr ? ::std::optional< ::mlir::ArrayAttr >(attr) : (::std::nullopt); +} + +::mlir::Region &FuncOpGenericAdaptorBase::getBody() { + return *odsRegions[0]; +} + +::mlir::RegionRange FuncOpGenericAdaptorBase::getRegions() { + return odsRegions; +} + +} // namespace detail +FuncOpAdaptor::FuncOpAdaptor(FuncOp op) : FuncOpGenericAdaptor(op->getOperands(), op) {} + +::mlir::LogicalResult FuncOpAdaptor::verify(::mlir::Location loc) { + auto tblgen_arg_attrs = getProperties().arg_attrs; (void)tblgen_arg_attrs; + auto tblgen_function_type = getProperties().function_type; (void)tblgen_function_type; + if (!tblgen_function_type) return emitError(loc, "'toy.func' op ""requires attribute 'function_type'"); + auto tblgen_res_attrs = getProperties().res_attrs; (void)tblgen_res_attrs; + auto tblgen_sym_name = getProperties().sym_name; (void)tblgen_sym_name; + if (!tblgen_sym_name) return emitError(loc, "'toy.func' op ""requires attribute 'sym_name'"); + + if (tblgen_sym_name && !((::llvm::isa<::mlir::StringAttr>(tblgen_sym_name)))) + return emitError(loc, "'toy.func' op ""attribute 'sym_name' failed to satisfy constraint: string attribute"); + + if (tblgen_function_type && !(((::llvm::isa<::mlir::TypeAttr>(tblgen_function_type))) && ((::llvm::isa<::mlir::FunctionType>(::llvm::cast<::mlir::TypeAttr>(tblgen_function_type).getValue()))) && ((::llvm::isa<::mlir::FunctionType>(::llvm::cast<::mlir::TypeAttr>(tblgen_function_type).getValue()))))) + return emitError(loc, "'toy.func' op ""attribute 'function_type' failed to satisfy constraint: type attribute of function type"); + + if (tblgen_arg_attrs && !(((::llvm::isa<::mlir::ArrayAttr>(tblgen_arg_attrs))) && (::llvm::all_of(::llvm::cast<::mlir::ArrayAttr>(tblgen_arg_attrs), [&](::mlir::Attribute attr) { return attr && ((::llvm::isa<::mlir::DictionaryAttr>(attr))); })))) + return emitError(loc, "'toy.func' op ""attribute 'arg_attrs' failed to satisfy constraint: Array of dictionary attributes"); + + if (tblgen_res_attrs && !(((::llvm::isa<::mlir::ArrayAttr>(tblgen_res_attrs))) && (::llvm::all_of(::llvm::cast<::mlir::ArrayAttr>(tblgen_res_attrs), [&](::mlir::Attribute attr) { return attr && ((::llvm::isa<::mlir::DictionaryAttr>(attr))); })))) + return emitError(loc, "'toy.func' op ""attribute 'res_attrs' failed to satisfy constraint: Array of dictionary attributes"); + return ::mlir::success(); +} + +std::pair FuncOp::getODSOperandIndexAndLength(unsigned index) { + return {index, 1}; +} + +::mlir::Operation::operand_range FuncOp::getODSOperands(unsigned index) { + auto valueRange = getODSOperandIndexAndLength(index); + return {std::next(getOperation()->operand_begin(), valueRange.first), + std::next(getOperation()->operand_begin(), valueRange.first + valueRange.second)}; +} + +std::pair FuncOp::getODSResultIndexAndLength(unsigned index) { + return {index, 1}; +} + +::mlir::Operation::result_range FuncOp::getODSResults(unsigned index) { + auto valueRange = getODSResultIndexAndLength(index); + return {std::next(getOperation()->result_begin(), valueRange.first), + std::next(getOperation()->result_begin(), valueRange.first + valueRange.second)}; +} + +::mlir::Region &FuncOp::getBody() { + return (*this)->getRegion(0); +} + +::mlir::LogicalResult FuncOp::setPropertiesFromAttr(Properties &prop, ::mlir::Attribute attr, ::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError) { + ::mlir::DictionaryAttr dict = ::llvm::dyn_cast<::mlir::DictionaryAttr>(attr); + if (!dict) { + emitError() << "expected DictionaryAttr to set properties"; + return ::mlir::failure(); + } + + { + auto &propStorage = prop.arg_attrs; + auto attr = dict.get("arg_attrs"); + if (attr || /*isRequired=*/false) { + if (!attr) { + emitError() << "expected key entry for arg_attrs in DictionaryAttr to set " + "Properties."; + return ::mlir::failure(); + } + auto convertedAttr = ::llvm::dyn_cast>(attr); + if (convertedAttr) { + propStorage = convertedAttr; + } else { + emitError() << "Invalid attribute `arg_attrs` in property conversion: " << attr; + return ::mlir::failure(); + } + } + } + + { + auto &propStorage = prop.function_type; + auto attr = dict.get("function_type"); + if (attr || /*isRequired=*/true) { + if (!attr) { + emitError() << "expected key entry for function_type in DictionaryAttr to set " + "Properties."; + return ::mlir::failure(); + } + auto convertedAttr = ::llvm::dyn_cast>(attr); + if (convertedAttr) { + propStorage = convertedAttr; + } else { + emitError() << "Invalid attribute `function_type` in property conversion: " << attr; + return ::mlir::failure(); + } + } + } + + { + auto &propStorage = prop.res_attrs; + auto attr = dict.get("res_attrs"); + if (attr || /*isRequired=*/false) { + if (!attr) { + emitError() << "expected key entry for res_attrs in DictionaryAttr to set " + "Properties."; + return ::mlir::failure(); + } + auto convertedAttr = ::llvm::dyn_cast>(attr); + if (convertedAttr) { + propStorage = convertedAttr; + } else { + emitError() << "Invalid attribute `res_attrs` in property conversion: " << attr; + return ::mlir::failure(); + } + } + } + + { + auto &propStorage = prop.sym_name; + auto attr = dict.get("sym_name"); + if (attr || /*isRequired=*/true) { + if (!attr) { + emitError() << "expected key entry for sym_name in DictionaryAttr to set " + "Properties."; + return ::mlir::failure(); + } + auto convertedAttr = ::llvm::dyn_cast>(attr); + if (convertedAttr) { + propStorage = convertedAttr; + } else { + emitError() << "Invalid attribute `sym_name` in property conversion: " << attr; + return ::mlir::failure(); + } + } + } + return ::mlir::success(); +} + +::mlir::Attribute FuncOp::getPropertiesAsAttr(::mlir::MLIRContext *ctx, const Properties &prop) { + ::mlir::SmallVector<::mlir::NamedAttribute> attrs; + ::mlir::Builder odsBuilder{ctx}; + + { + const auto &propStorage = prop.arg_attrs; + if (propStorage) + attrs.push_back(odsBuilder.getNamedAttr("arg_attrs", + propStorage)); + } + + { + const auto &propStorage = prop.function_type; + if (propStorage) + attrs.push_back(odsBuilder.getNamedAttr("function_type", + propStorage)); + } + + { + const auto &propStorage = prop.res_attrs; + if (propStorage) + attrs.push_back(odsBuilder.getNamedAttr("res_attrs", + propStorage)); + } + + { + const auto &propStorage = prop.sym_name; + if (propStorage) + attrs.push_back(odsBuilder.getNamedAttr("sym_name", + propStorage)); + } + + if (!attrs.empty()) + return odsBuilder.getDictionaryAttr(attrs); + return {}; +} + +llvm::hash_code FuncOp::computePropertiesHash(const Properties &prop) { + return llvm::hash_combine( + llvm::hash_value(prop.arg_attrs.getAsOpaquePointer()), + llvm::hash_value(prop.function_type.getAsOpaquePointer()), + llvm::hash_value(prop.res_attrs.getAsOpaquePointer()), + llvm::hash_value(prop.sym_name.getAsOpaquePointer())); +} + +std::optional FuncOp::getInherentAttr(::mlir::MLIRContext *ctx, const Properties &prop, llvm::StringRef name) { + if (name == "arg_attrs") + return prop.arg_attrs; + + if (name == "function_type") + return prop.function_type; + + if (name == "res_attrs") + return prop.res_attrs; + + if (name == "sym_name") + return prop.sym_name; + return std::nullopt; +} + +void FuncOp::setInherentAttr(Properties &prop, llvm::StringRef name, mlir::Attribute value) { + if (name == "arg_attrs") { + prop.arg_attrs = ::llvm::dyn_cast_or_null>(value); + return; + } + + if (name == "function_type") { + prop.function_type = ::llvm::dyn_cast_or_null>(value); + return; + } + + if (name == "res_attrs") { + prop.res_attrs = ::llvm::dyn_cast_or_null>(value); + return; + } + + if (name == "sym_name") { + prop.sym_name = ::llvm::dyn_cast_or_null>(value); + return; + } +} + +void FuncOp::populateInherentAttrs(::mlir::MLIRContext *ctx, const Properties &prop, ::mlir::NamedAttrList &attrs) { + if (prop.arg_attrs) attrs.append("arg_attrs", prop.arg_attrs); + + if (prop.function_type) attrs.append("function_type", prop.function_type); + + if (prop.res_attrs) attrs.append("res_attrs", prop.res_attrs); + + if (prop.sym_name) attrs.append("sym_name", prop.sym_name); +} + +::mlir::LogicalResult FuncOp::verifyInherentAttrs(::mlir::OperationName opName, ::mlir::NamedAttrList &attrs, llvm::function_ref<::mlir::InFlightDiagnostic()> emitError) { + { + ::mlir::Attribute attr = attrs.get(getArgAttrsAttrName(opName)); + if (attr && ::mlir::failed(__mlir_ods_local_attr_constraint_Ops3(attr, "arg_attrs", emitError))) + return ::mlir::failure(); + } + + { + ::mlir::Attribute attr = attrs.get(getFunctionTypeAttrName(opName)); + if (attr && ::mlir::failed(__mlir_ods_local_attr_constraint_Ops2(attr, "function_type", emitError))) + return ::mlir::failure(); + } + + { + ::mlir::Attribute attr = attrs.get(getResAttrsAttrName(opName)); + if (attr && ::mlir::failed(__mlir_ods_local_attr_constraint_Ops3(attr, "res_attrs", emitError))) + return ::mlir::failure(); + } + + { + ::mlir::Attribute attr = attrs.get(getSymNameAttrName(opName)); + if (attr && ::mlir::failed(__mlir_ods_local_attr_constraint_Ops1(attr, "sym_name", emitError))) + return ::mlir::failure(); + } + return ::mlir::success(); +} + +::mlir::LogicalResult FuncOp::readProperties(::mlir::DialectBytecodeReader &reader, ::mlir::OperationState &state) { + auto &prop = state.getOrAddProperties(); (void)prop; + if (::mlir::failed(reader.readOptionalAttribute(prop.arg_attrs))) + return ::mlir::failure(); + + if (::mlir::failed(reader.readAttribute(prop.function_type))) + return ::mlir::failure(); + + if (::mlir::failed(reader.readOptionalAttribute(prop.res_attrs))) + return ::mlir::failure(); + + if (::mlir::failed(reader.readAttribute(prop.sym_name))) + return ::mlir::failure(); + return ::mlir::success(); +} + +void FuncOp::writeProperties(::mlir::DialectBytecodeWriter &writer) { + auto &prop = getProperties(); (void)prop; + + writer.writeOptionalAttribute(prop.arg_attrs); + writer.writeAttribute(prop.function_type); + + writer.writeOptionalAttribute(prop.res_attrs); + writer.writeAttribute(prop.sym_name); +} + +::mlir::StringAttr FuncOp::getSymNameAttr() { + return ::llvm::cast<::mlir::StringAttr>(getProperties().sym_name); +} + +::llvm::StringRef FuncOp::getSymName() { + auto attr = getSymNameAttr(); + return attr.getValue(); +} + +::mlir::TypeAttr FuncOp::getFunctionTypeAttr() { + return ::llvm::cast<::mlir::TypeAttr>(getProperties().function_type); +} + +::mlir::FunctionType FuncOp::getFunctionType() { + auto attr = getFunctionTypeAttr(); + return ::llvm::cast<::mlir::FunctionType>(attr.getValue()); +} + +::mlir::ArrayAttr FuncOp::getArgAttrsAttr() { + return ::llvm::dyn_cast_or_null<::mlir::ArrayAttr>(getProperties().arg_attrs); +} + +::std::optional< ::mlir::ArrayAttr > FuncOp::getArgAttrs() { + auto attr = getArgAttrsAttr(); + return attr ? ::std::optional< ::mlir::ArrayAttr >(attr) : (::std::nullopt); +} + +::mlir::ArrayAttr FuncOp::getResAttrsAttr() { + return ::llvm::dyn_cast_or_null<::mlir::ArrayAttr>(getProperties().res_attrs); +} + +::std::optional< ::mlir::ArrayAttr > FuncOp::getResAttrs() { + auto attr = getResAttrsAttr(); + return attr ? ::std::optional< ::mlir::ArrayAttr >(attr) : (::std::nullopt); +} + +void FuncOp::setSymNameAttr(::mlir::StringAttr attr) { + (*this)->setAttr(getSymNameAttrName(), attr); +} + +void FuncOp::setSymName(::llvm::StringRef attrValue) { + (*this)->setAttr(getSymNameAttrName(), ::mlir::Builder((*this)->getContext()).getStringAttr(attrValue)); +} + +void FuncOp::setFunctionTypeAttr(::mlir::TypeAttr attr) { + (*this)->setAttr(getFunctionTypeAttrName(), attr); +} + +void FuncOp::setFunctionType(::mlir::FunctionType attrValue) { + (*this)->setAttr(getFunctionTypeAttrName(), ::mlir::TypeAttr::get(attrValue)); +} + +void FuncOp::setArgAttrsAttr(::mlir::ArrayAttr attr) { + (*this)->setAttr(getArgAttrsAttrName(), attr); +} + +void FuncOp::setResAttrsAttr(::mlir::ArrayAttr attr) { + (*this)->setAttr(getResAttrsAttrName(), attr); +} + +::mlir::Attribute FuncOp::removeArgAttrsAttr() { + auto &attr = getProperties().arg_attrs; + attr = {}; + return attr; +} + +::mlir::Attribute FuncOp::removeResAttrsAttr() { + auto &attr = getProperties().res_attrs; + attr = {}; + return attr; +} + +::mlir::LogicalResult FuncOp::verifyInvariantsImpl() { + auto tblgen_arg_attrs = getProperties().arg_attrs; (void)tblgen_arg_attrs; + auto tblgen_function_type = getProperties().function_type; (void)tblgen_function_type; + if (!tblgen_function_type) return emitOpError("requires attribute 'function_type'"); + auto tblgen_res_attrs = getProperties().res_attrs; (void)tblgen_res_attrs; + auto tblgen_sym_name = getProperties().sym_name; (void)tblgen_sym_name; + if (!tblgen_sym_name) return emitOpError("requires attribute 'sym_name'"); + + if (::mlir::failed(__mlir_ods_local_attr_constraint_Ops1(*this, tblgen_sym_name, "sym_name"))) + return ::mlir::failure(); + + if (::mlir::failed(__mlir_ods_local_attr_constraint_Ops2(*this, tblgen_function_type, "function_type"))) + return ::mlir::failure(); + + if (::mlir::failed(__mlir_ods_local_attr_constraint_Ops3(*this, tblgen_arg_attrs, "arg_attrs"))) + return ::mlir::failure(); + + if (::mlir::failed(__mlir_ods_local_attr_constraint_Ops3(*this, tblgen_res_attrs, "res_attrs"))) + return ::mlir::failure(); + { + unsigned index = 0; (void)index; + + for (auto ®ion : ::llvm::MutableArrayRef((*this)->getRegion(0))) + if (::mlir::failed(__mlir_ods_local_region_constraint_Ops0(*this, region, "body", index++))) + return ::mlir::failure(); + } + return ::mlir::success(); +} + +::mlir::LogicalResult FuncOp::verifyInvariants() { + return verifyInvariantsImpl(); +} + +} // namespace toy +} // namespace mlir +MLIR_DEFINE_EXPLICIT_TYPE_ID(::mlir::toy::FuncOp) + +namespace mlir { +namespace toy { + +//===----------------------------------------------------------------------===// +// ::mlir::toy::GenericCallOp definitions +//===----------------------------------------------------------------------===// + +namespace detail { +GenericCallOpGenericAdaptorBase::GenericCallOpGenericAdaptorBase(::mlir::DictionaryAttr attrs, const Properties &properties, ::mlir::RegionRange regions) : odsAttrs(attrs), properties(properties), odsRegions(regions) { if (odsAttrs) + odsOpName.emplace("toy.generic_call", odsAttrs.getContext()); +} + +GenericCallOpGenericAdaptorBase::GenericCallOpGenericAdaptorBase(GenericCallOp op) : GenericCallOpGenericAdaptorBase(op->getDiscardableAttrDictionary(), op.getProperties(), op->getRegions()) {} + +std::pair GenericCallOpGenericAdaptorBase::getODSOperandIndexAndLength(unsigned index, unsigned odsOperandsSize) { + bool isVariadic[] = {true}; + int prevVariadicCount = 0; + for (unsigned i = 0; i < index; ++i) + if (isVariadic[i]) ++prevVariadicCount; + + // Calculate how many dynamic values a static variadic operand corresponds to. + // This assumes all static variadic operands have the same dynamic value count. + int variadicSize = (odsOperandsSize - 0) / 1; + // `index` passed in as the parameter is the static index which counts each + // operand (variadic or not) as size 1. So here for each previous static variadic + // operand, we need to offset by (variadicSize - 1) to get where the dynamic + // value pack for this static operand starts. + int start = index + (variadicSize - 1) * prevVariadicCount; + int size = isVariadic[index] ? variadicSize : 1; + return {start, size}; +} + +::mlir::DictionaryAttr GenericCallOpGenericAdaptorBase::getAttributes() { + return odsAttrs; +} + +::mlir::FlatSymbolRefAttr GenericCallOpGenericAdaptorBase::getCalleeAttr() { + auto attr = ::llvm::cast<::mlir::FlatSymbolRefAttr>(getProperties().callee); + return attr; +} + +::llvm::StringRef GenericCallOpGenericAdaptorBase::getCallee() { + auto attr = getCalleeAttr(); + return attr.getValue(); +} + +} // namespace detail +GenericCallOpAdaptor::GenericCallOpAdaptor(GenericCallOp op) : GenericCallOpGenericAdaptor(op->getOperands(), op) {} + +::mlir::LogicalResult GenericCallOpAdaptor::verify(::mlir::Location loc) { + auto tblgen_callee = getProperties().callee; (void)tblgen_callee; + if (!tblgen_callee) return emitError(loc, "'toy.generic_call' op ""requires attribute 'callee'"); + + if (tblgen_callee && !((::llvm::isa<::mlir::FlatSymbolRefAttr>(tblgen_callee)))) + return emitError(loc, "'toy.generic_call' op ""attribute 'callee' failed to satisfy constraint: flat symbol reference attribute"); + return ::mlir::success(); +} + +std::pair GenericCallOp::getODSOperandIndexAndLength(unsigned index) { + bool isVariadic[] = {true}; + int prevVariadicCount = 0; + for (unsigned i = 0; i < index; ++i) + if (isVariadic[i]) ++prevVariadicCount; + + // Calculate how many dynamic values a static variadic operand corresponds to. + // This assumes all static variadic operands have the same dynamic value count. + int variadicSize = (getOperation()->getNumOperands() - 0) / 1; + // `index` passed in as the parameter is the static index which counts each + // operand (variadic or not) as size 1. So here for each previous static variadic + // operand, we need to offset by (variadicSize - 1) to get where the dynamic + // value pack for this static operand starts. + int start = index + (variadicSize - 1) * prevVariadicCount; + int size = isVariadic[index] ? variadicSize : 1; + return {start, size}; +} + +::mlir::Operation::operand_range GenericCallOp::getODSOperands(unsigned index) { + auto valueRange = getODSOperandIndexAndLength(index); + return {std::next(getOperation()->operand_begin(), valueRange.first), + std::next(getOperation()->operand_begin(), valueRange.first + valueRange.second)}; +} + +::mlir::Operation::operand_range GenericCallOp::getInputs() { + return getODSOperands(0); +} + +::mlir::MutableOperandRange GenericCallOp::getInputsMutable() { + auto range = getODSOperandIndexAndLength(0); + auto mutableRange = ::mlir::MutableOperandRange(getOperation(), range.first, range.second); + return mutableRange; +} + +std::pair GenericCallOp::getODSResultIndexAndLength(unsigned index) { + return {index, 1}; +} + +::mlir::Operation::result_range GenericCallOp::getODSResults(unsigned index) { + auto valueRange = getODSResultIndexAndLength(index); + return {std::next(getOperation()->result_begin(), valueRange.first), + std::next(getOperation()->result_begin(), valueRange.first + valueRange.second)}; +} + +::mlir::LogicalResult GenericCallOp::setPropertiesFromAttr(Properties &prop, ::mlir::Attribute attr, ::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError) { + ::mlir::DictionaryAttr dict = ::llvm::dyn_cast<::mlir::DictionaryAttr>(attr); + if (!dict) { + emitError() << "expected DictionaryAttr to set properties"; + return ::mlir::failure(); + } + + { + auto &propStorage = prop.callee; + auto attr = dict.get("callee"); + if (attr || /*isRequired=*/true) { + if (!attr) { + emitError() << "expected key entry for callee in DictionaryAttr to set " + "Properties."; + return ::mlir::failure(); + } + auto convertedAttr = ::llvm::dyn_cast>(attr); + if (convertedAttr) { + propStorage = convertedAttr; + } else { + emitError() << "Invalid attribute `callee` in property conversion: " << attr; + return ::mlir::failure(); + } + } + } + return ::mlir::success(); +} + +::mlir::Attribute GenericCallOp::getPropertiesAsAttr(::mlir::MLIRContext *ctx, const Properties &prop) { + ::mlir::SmallVector<::mlir::NamedAttribute> attrs; + ::mlir::Builder odsBuilder{ctx}; + + { + const auto &propStorage = prop.callee; + if (propStorage) + attrs.push_back(odsBuilder.getNamedAttr("callee", + propStorage)); + } + + if (!attrs.empty()) + return odsBuilder.getDictionaryAttr(attrs); + return {}; +} + +llvm::hash_code GenericCallOp::computePropertiesHash(const Properties &prop) { + return llvm::hash_combine( + llvm::hash_value(prop.callee.getAsOpaquePointer())); +} + +std::optional GenericCallOp::getInherentAttr(::mlir::MLIRContext *ctx, const Properties &prop, llvm::StringRef name) { + if (name == "callee") + return prop.callee; + return std::nullopt; +} + +void GenericCallOp::setInherentAttr(Properties &prop, llvm::StringRef name, mlir::Attribute value) { + if (name == "callee") { + prop.callee = ::llvm::dyn_cast_or_null>(value); + return; + } +} + +void GenericCallOp::populateInherentAttrs(::mlir::MLIRContext *ctx, const Properties &prop, ::mlir::NamedAttrList &attrs) { + if (prop.callee) attrs.append("callee", prop.callee); +} + +::mlir::LogicalResult GenericCallOp::verifyInherentAttrs(::mlir::OperationName opName, ::mlir::NamedAttrList &attrs, llvm::function_ref<::mlir::InFlightDiagnostic()> emitError) { + { + ::mlir::Attribute attr = attrs.get(getCalleeAttrName(opName)); + if (attr && ::mlir::failed(__mlir_ods_local_attr_constraint_Ops4(attr, "callee", emitError))) + return ::mlir::failure(); + } + return ::mlir::success(); +} + +::mlir::LogicalResult GenericCallOp::readProperties(::mlir::DialectBytecodeReader &reader, ::mlir::OperationState &state) { + auto &prop = state.getOrAddProperties(); (void)prop; + if (::mlir::failed(reader.readAttribute(prop.callee))) + return ::mlir::failure(); + return ::mlir::success(); +} + +void GenericCallOp::writeProperties(::mlir::DialectBytecodeWriter &writer) { + auto &prop = getProperties(); (void)prop; + writer.writeAttribute(prop.callee); +} + +::mlir::FlatSymbolRefAttr GenericCallOp::getCalleeAttr() { + return ::llvm::cast<::mlir::FlatSymbolRefAttr>(getProperties().callee); +} + +::llvm::StringRef GenericCallOp::getCallee() { + auto attr = getCalleeAttr(); + return attr.getValue(); +} + +void GenericCallOp::setCalleeAttr(::mlir::FlatSymbolRefAttr attr) { + (*this)->setAttr(getCalleeAttrName(), attr); +} + +void GenericCallOp::setCallee(::llvm::StringRef attrValue) { + (*this)->setAttr(getCalleeAttrName(), ::mlir::SymbolRefAttr::get(::mlir::Builder((*this)->getContext()).getContext(), attrValue)); +} + +void GenericCallOp::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type resultType0, ::mlir::FlatSymbolRefAttr callee, ::mlir::ValueRange inputs) { + odsState.addOperands(inputs); + odsState.getOrAddProperties().callee = callee; + odsState.addTypes(resultType0); +} + +void GenericCallOp::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::FlatSymbolRefAttr callee, ::mlir::ValueRange inputs) { + odsState.addOperands(inputs); + odsState.getOrAddProperties().callee = callee; + assert(resultTypes.size() == 1u && "mismatched number of results"); + odsState.addTypes(resultTypes); +} + +void GenericCallOp::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type resultType0, ::llvm::StringRef callee, ::mlir::ValueRange inputs) { + odsState.addOperands(inputs); + odsState.getOrAddProperties().callee = ::mlir::SymbolRefAttr::get(odsBuilder.getContext(), callee); + odsState.addTypes(resultType0); +} + +void GenericCallOp::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::llvm::StringRef callee, ::mlir::ValueRange inputs) { + odsState.addOperands(inputs); + odsState.getOrAddProperties().callee = ::mlir::SymbolRefAttr::get(odsBuilder.getContext(), callee); + assert(resultTypes.size() == 1u && "mismatched number of results"); + odsState.addTypes(resultTypes); +} + +void GenericCallOp::build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes) { + odsState.addOperands(operands); + odsState.addAttributes(attributes); + assert(resultTypes.size() == 1u && "mismatched number of return types"); + odsState.addTypes(resultTypes); +} + +::mlir::LogicalResult GenericCallOp::verifyInvariantsImpl() { + auto tblgen_callee = getProperties().callee; (void)tblgen_callee; + if (!tblgen_callee) return emitOpError("requires attribute 'callee'"); + + if (::mlir::failed(__mlir_ods_local_attr_constraint_Ops4(*this, tblgen_callee, "callee"))) + return ::mlir::failure(); + { + unsigned index = 0; (void)index; + auto valueGroup0 = getODSOperands(0); + + for (auto v : valueGroup0) { + if (::mlir::failed(__mlir_ods_local_type_constraint_Ops1(*this, v.getType(), "operand", index++))) + return ::mlir::failure(); + } + } + { + unsigned index = 0; (void)index; + auto valueGroup0 = getODSResults(0); + + for (auto v : valueGroup0) { + if (::mlir::failed(__mlir_ods_local_type_constraint_Ops0(*this, v.getType(), "result", index++))) + return ::mlir::failure(); + } + } + return ::mlir::success(); +} + +::mlir::LogicalResult GenericCallOp::verifyInvariants() { + return verifyInvariantsImpl(); +} + +::mlir::ParseResult GenericCallOp::parse(::mlir::OpAsmParser &parser, ::mlir::OperationState &result) { + ::mlir::FlatSymbolRefAttr calleeAttr; + ::llvm::SmallVector<::mlir::OpAsmParser::UnresolvedOperand, 4> inputsOperands; + ::llvm::SMLoc inputsOperandsLoc; + (void)inputsOperandsLoc; + ::llvm::ArrayRef<::mlir::Type> inputsTypes; + ::llvm::ArrayRef<::mlir::Type> allResultTypes; + + if (parser.parseCustomAttributeWithFallback(calleeAttr, parser.getBuilder().getType<::mlir::NoneType>())) { + return ::mlir::failure(); + } + if (calleeAttr) result.getOrAddProperties().callee = calleeAttr; + if (parser.parseLParen()) + return ::mlir::failure(); + + inputsOperandsLoc = parser.getCurrentLocation(); + if (parser.parseOperandList(inputsOperands)) + return ::mlir::failure(); + if (parser.parseRParen()) + return ::mlir::failure(); + { + auto loc = parser.getCurrentLocation();(void)loc; + if (parser.parseOptionalAttrDict(result.attributes)) + return ::mlir::failure(); + if (failed(verifyInherentAttrs(result.name, result.attributes, [&]() { + return parser.emitError(loc) << "'" << result.name.getStringRef() << "' op "; + }))) + return ::mlir::failure(); + } + if (parser.parseColon()) + return ::mlir::failure(); + + ::mlir::FunctionType inputs__allResult_functionType; + if (parser.parseType(inputs__allResult_functionType)) + return ::mlir::failure(); + inputsTypes = inputs__allResult_functionType.getInputs(); + allResultTypes = inputs__allResult_functionType.getResults(); + result.addTypes(allResultTypes); + if (parser.resolveOperands(inputsOperands, inputsTypes, inputsOperandsLoc, result.operands)) + return ::mlir::failure(); + return ::mlir::success(); +} + +void GenericCallOp::print(::mlir::OpAsmPrinter &_odsPrinter) { + _odsPrinter << ' '; + _odsPrinter.printAttributeWithoutType(getCalleeAttr()); + _odsPrinter << "("; + _odsPrinter << getInputs(); + _odsPrinter << ")"; + ::llvm::SmallVector<::llvm::StringRef, 2> elidedAttrs; + elidedAttrs.push_back("callee"); + _odsPrinter.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs); + _odsPrinter << ' ' << ":"; + _odsPrinter << ' '; + _odsPrinter.printFunctionalType(getInputs().getTypes(), getOperation()->getResultTypes()); +} + +} // namespace toy +} // namespace mlir +MLIR_DEFINE_EXPLICIT_TYPE_ID(::mlir::toy::GenericCallOp) + +namespace mlir { +namespace toy { + +//===----------------------------------------------------------------------===// +// ::mlir::toy::MulOp definitions +//===----------------------------------------------------------------------===// + +namespace detail { +MulOpGenericAdaptorBase::MulOpGenericAdaptorBase(::mlir::DictionaryAttr attrs, const ::mlir::EmptyProperties &properties, ::mlir::RegionRange regions) : odsAttrs(attrs), odsRegions(regions) { if (odsAttrs) + odsOpName.emplace("toy.mul", odsAttrs.getContext()); +} + +MulOpGenericAdaptorBase::MulOpGenericAdaptorBase(MulOp op) : MulOpGenericAdaptorBase(op->getAttrDictionary(), op.getProperties(), op->getRegions()) {} + +std::pair MulOpGenericAdaptorBase::getODSOperandIndexAndLength(unsigned index, unsigned odsOperandsSize) { + return {index, 1}; +} + +::mlir::DictionaryAttr MulOpGenericAdaptorBase::getAttributes() { + return odsAttrs; +} + +} // namespace detail +MulOpAdaptor::MulOpAdaptor(MulOp op) : MulOpGenericAdaptor(op->getOperands(), op) {} + +::mlir::LogicalResult MulOpAdaptor::verify(::mlir::Location loc) { + return ::mlir::success(); +} + +std::pair MulOp::getODSOperandIndexAndLength(unsigned index) { + return {index, 1}; +} + +::mlir::Operation::operand_range MulOp::getODSOperands(unsigned index) { + auto valueRange = getODSOperandIndexAndLength(index); + return {std::next(getOperation()->operand_begin(), valueRange.first), + std::next(getOperation()->operand_begin(), valueRange.first + valueRange.second)}; +} + +::mlir::TypedValue<::mlir::TensorType> MulOp::getLhs() { + return ::llvm::cast<::mlir::TypedValue<::mlir::TensorType>>(*getODSOperands(0).begin()); +} + +::mlir::TypedValue<::mlir::TensorType> MulOp::getRhs() { + return ::llvm::cast<::mlir::TypedValue<::mlir::TensorType>>(*getODSOperands(1).begin()); +} + +::mlir::OpOperand &MulOp::getLhsMutable() { + auto range = getODSOperandIndexAndLength(0); + return getOperation()->getOpOperand(range.first); +} + +::mlir::OpOperand &MulOp::getRhsMutable() { + auto range = getODSOperandIndexAndLength(1); + return getOperation()->getOpOperand(range.first); +} + +std::pair MulOp::getODSResultIndexAndLength(unsigned index) { + return {index, 1}; +} + +::mlir::Operation::result_range MulOp::getODSResults(unsigned index) { + auto valueRange = getODSResultIndexAndLength(index); + return {std::next(getOperation()->result_begin(), valueRange.first), + std::next(getOperation()->result_begin(), valueRange.first + valueRange.second)}; +} + +void MulOp::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type resultType0, ::mlir::Value lhs, ::mlir::Value rhs) { + odsState.addOperands(lhs); + odsState.addOperands(rhs); + odsState.addTypes(resultType0); +} + +void MulOp::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::Value lhs, ::mlir::Value rhs) { + odsState.addOperands(lhs); + odsState.addOperands(rhs); + assert(resultTypes.size() == 1u && "mismatched number of results"); + odsState.addTypes(resultTypes); +} + +void MulOp::build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes) { + assert(operands.size() == 2u && "mismatched number of parameters"); + odsState.addOperands(operands); + odsState.addAttributes(attributes); + assert(resultTypes.size() == 1u && "mismatched number of return types"); + odsState.addTypes(resultTypes); +} + +::mlir::LogicalResult MulOp::verifyInvariantsImpl() { + { + unsigned index = 0; (void)index; + auto valueGroup0 = getODSOperands(0); + + for (auto v : valueGroup0) { + if (::mlir::failed(__mlir_ods_local_type_constraint_Ops0(*this, v.getType(), "operand", index++))) + return ::mlir::failure(); + } + auto valueGroup1 = getODSOperands(1); + + for (auto v : valueGroup1) { + if (::mlir::failed(__mlir_ods_local_type_constraint_Ops0(*this, v.getType(), "operand", index++))) + return ::mlir::failure(); + } + } + { + unsigned index = 0; (void)index; + auto valueGroup0 = getODSResults(0); + + for (auto v : valueGroup0) { + if (::mlir::failed(__mlir_ods_local_type_constraint_Ops0(*this, v.getType(), "result", index++))) + return ::mlir::failure(); + } + } + return ::mlir::success(); +} + +::mlir::LogicalResult MulOp::verifyInvariants() { + return verifyInvariantsImpl(); +} + +void MulOp::getEffects(::llvm::SmallVectorImpl<::mlir::SideEffects::EffectInstance<::mlir::MemoryEffects::Effect>> &effects) { +} + +} // namespace toy +} // namespace mlir +MLIR_DEFINE_EXPLICIT_TYPE_ID(::mlir::toy::MulOp) + +namespace mlir { +namespace toy { + +//===----------------------------------------------------------------------===// +// ::mlir::toy::PrintOp definitions +//===----------------------------------------------------------------------===// + +namespace detail { +PrintOpGenericAdaptorBase::PrintOpGenericAdaptorBase(::mlir::DictionaryAttr attrs, const ::mlir::EmptyProperties &properties, ::mlir::RegionRange regions) : odsAttrs(attrs), odsRegions(regions) { if (odsAttrs) + odsOpName.emplace("toy.print", odsAttrs.getContext()); +} + +PrintOpGenericAdaptorBase::PrintOpGenericAdaptorBase(PrintOp op) : PrintOpGenericAdaptorBase(op->getAttrDictionary(), op.getProperties(), op->getRegions()) {} + +std::pair PrintOpGenericAdaptorBase::getODSOperandIndexAndLength(unsigned index, unsigned odsOperandsSize) { + return {index, 1}; +} + +::mlir::DictionaryAttr PrintOpGenericAdaptorBase::getAttributes() { + return odsAttrs; +} + +} // namespace detail +PrintOpAdaptor::PrintOpAdaptor(PrintOp op) : PrintOpGenericAdaptor(op->getOperands(), op) {} + +::mlir::LogicalResult PrintOpAdaptor::verify(::mlir::Location loc) { + return ::mlir::success(); +} + +std::pair PrintOp::getODSOperandIndexAndLength(unsigned index) { + return {index, 1}; +} + +::mlir::Operation::operand_range PrintOp::getODSOperands(unsigned index) { + auto valueRange = getODSOperandIndexAndLength(index); + return {std::next(getOperation()->operand_begin(), valueRange.first), + std::next(getOperation()->operand_begin(), valueRange.first + valueRange.second)}; +} + +::mlir::TypedValue<::mlir::TensorType> PrintOp::getInput() { + return ::llvm::cast<::mlir::TypedValue<::mlir::TensorType>>(*getODSOperands(0).begin()); +} + +::mlir::OpOperand &PrintOp::getInputMutable() { + auto range = getODSOperandIndexAndLength(0); + return getOperation()->getOpOperand(range.first); +} + +std::pair PrintOp::getODSResultIndexAndLength(unsigned index) { + return {index, 1}; +} + +::mlir::Operation::result_range PrintOp::getODSResults(unsigned index) { + auto valueRange = getODSResultIndexAndLength(index); + return {std::next(getOperation()->result_begin(), valueRange.first), + std::next(getOperation()->result_begin(), valueRange.first + valueRange.second)}; +} + +void PrintOp::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Value input) { + odsState.addOperands(input); +} + +void PrintOp::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::Value input) { + odsState.addOperands(input); + assert(resultTypes.size() == 0u && "mismatched number of results"); + odsState.addTypes(resultTypes); +} + +void PrintOp::build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes) { + assert(operands.size() == 1u && "mismatched number of parameters"); + odsState.addOperands(operands); + odsState.addAttributes(attributes); + assert(resultTypes.size() == 0u && "mismatched number of return types"); + odsState.addTypes(resultTypes); +} + +::mlir::LogicalResult PrintOp::verifyInvariantsImpl() { + { + unsigned index = 0; (void)index; + auto valueGroup0 = getODSOperands(0); + + for (auto v : valueGroup0) { + if (::mlir::failed(__mlir_ods_local_type_constraint_Ops0(*this, v.getType(), "operand", index++))) + return ::mlir::failure(); + } + } + return ::mlir::success(); +} + +::mlir::LogicalResult PrintOp::verifyInvariants() { + return verifyInvariantsImpl(); +} + +::mlir::ParseResult PrintOp::parse(::mlir::OpAsmParser &parser, ::mlir::OperationState &result) { + ::mlir::OpAsmParser::UnresolvedOperand inputRawOperands[1]; + ::llvm::ArrayRef<::mlir::OpAsmParser::UnresolvedOperand> inputOperands(inputRawOperands); ::llvm::SMLoc inputOperandsLoc; + (void)inputOperandsLoc; + ::mlir::Type inputRawTypes[1]; + ::llvm::ArrayRef<::mlir::Type> inputTypes(inputRawTypes); + + inputOperandsLoc = parser.getCurrentLocation(); + if (parser.parseOperand(inputRawOperands[0])) + return ::mlir::failure(); + { + auto loc = parser.getCurrentLocation();(void)loc; + if (parser.parseOptionalAttrDict(result.attributes)) + return ::mlir::failure(); + } + if (parser.parseColon()) + return ::mlir::failure(); + + { + ::mlir::TensorType type; + if (parser.parseCustomTypeWithFallback(type)) + return ::mlir::failure(); + inputRawTypes[0] = type; + } + if (parser.resolveOperands(inputOperands, inputTypes, inputOperandsLoc, result.operands)) + return ::mlir::failure(); + return ::mlir::success(); +} + +void PrintOp::print(::mlir::OpAsmPrinter &_odsPrinter) { + _odsPrinter << ' '; + _odsPrinter << getInput(); + ::llvm::SmallVector<::llvm::StringRef, 2> elidedAttrs; + _odsPrinter.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs); + _odsPrinter << ' ' << ":"; + _odsPrinter << ' '; + { + auto type = getInput().getType(); + if (auto validType = ::llvm::dyn_cast<::mlir::TensorType>(type)) + _odsPrinter.printStrippedAttrOrType(validType); + else + _odsPrinter << type; + } +} + +} // namespace toy +} // namespace mlir +MLIR_DEFINE_EXPLICIT_TYPE_ID(::mlir::toy::PrintOp) + +namespace mlir { +namespace toy { + +//===----------------------------------------------------------------------===// +// ::mlir::toy::ReshapeOp definitions +//===----------------------------------------------------------------------===// + +namespace detail { +ReshapeOpGenericAdaptorBase::ReshapeOpGenericAdaptorBase(::mlir::DictionaryAttr attrs, const ::mlir::EmptyProperties &properties, ::mlir::RegionRange regions) : odsAttrs(attrs), odsRegions(regions) { if (odsAttrs) + odsOpName.emplace("toy.reshape", odsAttrs.getContext()); +} + +ReshapeOpGenericAdaptorBase::ReshapeOpGenericAdaptorBase(ReshapeOp op) : ReshapeOpGenericAdaptorBase(op->getAttrDictionary(), op.getProperties(), op->getRegions()) {} + +std::pair ReshapeOpGenericAdaptorBase::getODSOperandIndexAndLength(unsigned index, unsigned odsOperandsSize) { + return {index, 1}; +} + +::mlir::DictionaryAttr ReshapeOpGenericAdaptorBase::getAttributes() { + return odsAttrs; +} + +} // namespace detail +ReshapeOpAdaptor::ReshapeOpAdaptor(ReshapeOp op) : ReshapeOpGenericAdaptor(op->getOperands(), op) {} + +::mlir::LogicalResult ReshapeOpAdaptor::verify(::mlir::Location loc) { + return ::mlir::success(); +} + +std::pair ReshapeOp::getODSOperandIndexAndLength(unsigned index) { + return {index, 1}; +} + +::mlir::Operation::operand_range ReshapeOp::getODSOperands(unsigned index) { + auto valueRange = getODSOperandIndexAndLength(index); + return {std::next(getOperation()->operand_begin(), valueRange.first), + std::next(getOperation()->operand_begin(), valueRange.first + valueRange.second)}; +} + +::mlir::TypedValue<::mlir::TensorType> ReshapeOp::getInput() { + return ::llvm::cast<::mlir::TypedValue<::mlir::TensorType>>(*getODSOperands(0).begin()); +} + +::mlir::OpOperand &ReshapeOp::getInputMutable() { + auto range = getODSOperandIndexAndLength(0); + return getOperation()->getOpOperand(range.first); +} + +std::pair ReshapeOp::getODSResultIndexAndLength(unsigned index) { + return {index, 1}; +} + +::mlir::Operation::result_range ReshapeOp::getODSResults(unsigned index) { + auto valueRange = getODSResultIndexAndLength(index); + return {std::next(getOperation()->result_begin(), valueRange.first), + std::next(getOperation()->result_begin(), valueRange.first + valueRange.second)}; +} + +void ReshapeOp::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type resultType0, ::mlir::Value input) { + odsState.addOperands(input); + odsState.addTypes(resultType0); +} + +void ReshapeOp::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::Value input) { + odsState.addOperands(input); + assert(resultTypes.size() == 1u && "mismatched number of results"); + odsState.addTypes(resultTypes); +} + +void ReshapeOp::build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes) { + assert(operands.size() == 1u && "mismatched number of parameters"); + odsState.addOperands(operands); + odsState.addAttributes(attributes); + assert(resultTypes.size() == 1u && "mismatched number of return types"); + odsState.addTypes(resultTypes); +} + +::mlir::LogicalResult ReshapeOp::verifyInvariantsImpl() { + { + unsigned index = 0; (void)index; + auto valueGroup0 = getODSOperands(0); + + for (auto v : valueGroup0) { + if (::mlir::failed(__mlir_ods_local_type_constraint_Ops0(*this, v.getType(), "operand", index++))) + return ::mlir::failure(); + } + } + { + unsigned index = 0; (void)index; + auto valueGroup0 = getODSResults(0); + + for (auto v : valueGroup0) { + if (::mlir::failed(__mlir_ods_local_type_constraint_Ops2(*this, v.getType(), "result", index++))) + return ::mlir::failure(); + } + } + return ::mlir::success(); +} + +::mlir::LogicalResult ReshapeOp::verifyInvariants() { + return verifyInvariantsImpl(); +} + +::mlir::ParseResult ReshapeOp::parse(::mlir::OpAsmParser &parser, ::mlir::OperationState &result) { + ::mlir::OpAsmParser::UnresolvedOperand inputRawOperands[1]; + ::llvm::ArrayRef<::mlir::OpAsmParser::UnresolvedOperand> inputOperands(inputRawOperands); ::llvm::SMLoc inputOperandsLoc; + (void)inputOperandsLoc; + ::mlir::Type inputRawTypes[1]; + ::llvm::ArrayRef<::mlir::Type> inputTypes(inputRawTypes); + ::llvm::SmallVector<::mlir::Type, 1> allResultTypes; + if (parser.parseLParen()) + return ::mlir::failure(); + + inputOperandsLoc = parser.getCurrentLocation(); + if (parser.parseOperand(inputRawOperands[0])) + return ::mlir::failure(); + if (parser.parseColon()) + return ::mlir::failure(); + + { + ::mlir::TensorType type; + if (parser.parseCustomTypeWithFallback(type)) + return ::mlir::failure(); + inputRawTypes[0] = type; + } + if (parser.parseRParen()) + return ::mlir::failure(); + { + auto loc = parser.getCurrentLocation();(void)loc; + if (parser.parseOptionalAttrDict(result.attributes)) + return ::mlir::failure(); + } + if (parser.parseKeyword("to")) + return ::mlir::failure(); + + if (parser.parseTypeList(allResultTypes)) + return ::mlir::failure(); + result.addTypes(allResultTypes); + if (parser.resolveOperands(inputOperands, inputTypes, inputOperandsLoc, result.operands)) + return ::mlir::failure(); + return ::mlir::success(); +} + +void ReshapeOp::print(::mlir::OpAsmPrinter &_odsPrinter) { + _odsPrinter << "("; + _odsPrinter << getInput(); + _odsPrinter << ' ' << ":"; + _odsPrinter << ' '; + { + auto type = getInput().getType(); + if (auto validType = ::llvm::dyn_cast<::mlir::TensorType>(type)) + _odsPrinter.printStrippedAttrOrType(validType); + else + _odsPrinter << type; + } + _odsPrinter << ")"; + ::llvm::SmallVector<::llvm::StringRef, 2> elidedAttrs; + _odsPrinter.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs); + _odsPrinter << ' ' << "to"; + _odsPrinter << ' '; + _odsPrinter << getOperation()->getResultTypes(); +} + +void ReshapeOp::getEffects(::llvm::SmallVectorImpl<::mlir::SideEffects::EffectInstance<::mlir::MemoryEffects::Effect>> &effects) { +} + +} // namespace toy +} // namespace mlir +MLIR_DEFINE_EXPLICIT_TYPE_ID(::mlir::toy::ReshapeOp) + +namespace mlir { +namespace toy { + +//===----------------------------------------------------------------------===// +// ::mlir::toy::ReturnOp definitions +//===----------------------------------------------------------------------===// + +namespace detail { +ReturnOpGenericAdaptorBase::ReturnOpGenericAdaptorBase(::mlir::DictionaryAttr attrs, const ::mlir::EmptyProperties &properties, ::mlir::RegionRange regions) : odsAttrs(attrs), odsRegions(regions) { if (odsAttrs) + odsOpName.emplace("toy.return", odsAttrs.getContext()); +} + +ReturnOpGenericAdaptorBase::ReturnOpGenericAdaptorBase(ReturnOp op) : ReturnOpGenericAdaptorBase(op->getAttrDictionary(), op.getProperties(), op->getRegions()) {} + +std::pair ReturnOpGenericAdaptorBase::getODSOperandIndexAndLength(unsigned index, unsigned odsOperandsSize) { + bool isVariadic[] = {true}; + int prevVariadicCount = 0; + for (unsigned i = 0; i < index; ++i) + if (isVariadic[i]) ++prevVariadicCount; + + // Calculate how many dynamic values a static variadic operand corresponds to. + // This assumes all static variadic operands have the same dynamic value count. + int variadicSize = (odsOperandsSize - 0) / 1; + // `index` passed in as the parameter is the static index which counts each + // operand (variadic or not) as size 1. So here for each previous static variadic + // operand, we need to offset by (variadicSize - 1) to get where the dynamic + // value pack for this static operand starts. + int start = index + (variadicSize - 1) * prevVariadicCount; + int size = isVariadic[index] ? variadicSize : 1; + return {start, size}; +} + +::mlir::DictionaryAttr ReturnOpGenericAdaptorBase::getAttributes() { + return odsAttrs; +} + +} // namespace detail +ReturnOpAdaptor::ReturnOpAdaptor(ReturnOp op) : ReturnOpGenericAdaptor(op->getOperands(), op) {} + +::mlir::LogicalResult ReturnOpAdaptor::verify(::mlir::Location loc) { + return ::mlir::success(); +} + +std::pair ReturnOp::getODSOperandIndexAndLength(unsigned index) { + bool isVariadic[] = {true}; + int prevVariadicCount = 0; + for (unsigned i = 0; i < index; ++i) + if (isVariadic[i]) ++prevVariadicCount; + + // Calculate how many dynamic values a static variadic operand corresponds to. + // This assumes all static variadic operands have the same dynamic value count. + int variadicSize = (getOperation()->getNumOperands() - 0) / 1; + // `index` passed in as the parameter is the static index which counts each + // operand (variadic or not) as size 1. So here for each previous static variadic + // operand, we need to offset by (variadicSize - 1) to get where the dynamic + // value pack for this static operand starts. + int start = index + (variadicSize - 1) * prevVariadicCount; + int size = isVariadic[index] ? variadicSize : 1; + return {start, size}; +} + +::mlir::Operation::operand_range ReturnOp::getODSOperands(unsigned index) { + auto valueRange = getODSOperandIndexAndLength(index); + return {std::next(getOperation()->operand_begin(), valueRange.first), + std::next(getOperation()->operand_begin(), valueRange.first + valueRange.second)}; +} + +::mlir::Operation::operand_range ReturnOp::getInput() { + return getODSOperands(0); +} + +::mlir::MutableOperandRange ReturnOp::getInputMutable() { + auto range = getODSOperandIndexAndLength(0); + auto mutableRange = ::mlir::MutableOperandRange(getOperation(), range.first, range.second); + return mutableRange; +} + +std::pair ReturnOp::getODSResultIndexAndLength(unsigned index) { + return {index, 1}; +} + +::mlir::Operation::result_range ReturnOp::getODSResults(unsigned index) { + auto valueRange = getODSResultIndexAndLength(index); + return {std::next(getOperation()->result_begin(), valueRange.first), + std::next(getOperation()->result_begin(), valueRange.first + valueRange.second)}; +} + +void ReturnOp::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState) { + build(odsBuilder, odsState, std::nullopt); +} + +void ReturnOp::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::ValueRange input) { + odsState.addOperands(input); +} + +void ReturnOp::build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes) { + odsState.addOperands(operands); + odsState.addAttributes(attributes); + assert(resultTypes.size() == 0u && "mismatched number of return types"); + odsState.addTypes(resultTypes); +} + +::mlir::LogicalResult ReturnOp::verifyInvariantsImpl() { + { + unsigned index = 0; (void)index; + auto valueGroup0 = getODSOperands(0); + + for (auto v : valueGroup0) { + if (::mlir::failed(__mlir_ods_local_type_constraint_Ops1(*this, v.getType(), "operand", index++))) + return ::mlir::failure(); + } + } + return ::mlir::success(); +} + +::mlir::LogicalResult ReturnOp::verifyInvariants() { + if(::mlir::succeeded(verifyInvariantsImpl()) && ::mlir::succeeded(verify())) + return ::mlir::success(); + return ::mlir::failure(); +} + +::mlir::ParseResult ReturnOp::parse(::mlir::OpAsmParser &parser, ::mlir::OperationState &result) { + ::llvm::SmallVector<::mlir::OpAsmParser::UnresolvedOperand, 4> inputOperands; + ::llvm::SMLoc inputOperandsLoc; + (void)inputOperandsLoc; + ::llvm::SmallVector<::mlir::Type, 1> inputTypes; + + inputOperandsLoc = parser.getCurrentLocation(); + if (parser.parseOperandList(inputOperands)) + return ::mlir::failure(); + if (!inputOperands.empty()) { + if (parser.parseColon()) + return ::mlir::failure(); + + if (parser.parseTypeList(inputTypes)) + return ::mlir::failure(); + } + { + auto loc = parser.getCurrentLocation();(void)loc; + if (parser.parseOptionalAttrDict(result.attributes)) + return ::mlir::failure(); + } + if (parser.resolveOperands(inputOperands, inputTypes, inputOperandsLoc, result.operands)) + return ::mlir::failure(); + return ::mlir::success(); +} + +void ReturnOp::print(::mlir::OpAsmPrinter &_odsPrinter) { + if (!getInput().empty()) { + _odsPrinter << ' '; + _odsPrinter << getInput(); + _odsPrinter << ' ' << ":"; + _odsPrinter << ' '; + _odsPrinter << getInput().getTypes(); + } + ::llvm::SmallVector<::llvm::StringRef, 2> elidedAttrs; + _odsPrinter.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs); +} + +void ReturnOp::getEffects(::llvm::SmallVectorImpl<::mlir::SideEffects::EffectInstance<::mlir::MemoryEffects::Effect>> &effects) { +} + +} // namespace toy +} // namespace mlir +MLIR_DEFINE_EXPLICIT_TYPE_ID(::mlir::toy::ReturnOp) + +namespace mlir { +namespace toy { + +//===----------------------------------------------------------------------===// +// ::mlir::toy::TransposeOp definitions +//===----------------------------------------------------------------------===// + +namespace detail { +TransposeOpGenericAdaptorBase::TransposeOpGenericAdaptorBase(::mlir::DictionaryAttr attrs, const ::mlir::EmptyProperties &properties, ::mlir::RegionRange regions) : odsAttrs(attrs), odsRegions(regions) { if (odsAttrs) + odsOpName.emplace("toy.transpose", odsAttrs.getContext()); +} + +TransposeOpGenericAdaptorBase::TransposeOpGenericAdaptorBase(TransposeOp op) : TransposeOpGenericAdaptorBase(op->getAttrDictionary(), op.getProperties(), op->getRegions()) {} + +std::pair TransposeOpGenericAdaptorBase::getODSOperandIndexAndLength(unsigned index, unsigned odsOperandsSize) { + return {index, 1}; +} + +::mlir::DictionaryAttr TransposeOpGenericAdaptorBase::getAttributes() { + return odsAttrs; +} + +} // namespace detail +TransposeOpAdaptor::TransposeOpAdaptor(TransposeOp op) : TransposeOpGenericAdaptor(op->getOperands(), op) {} + +::mlir::LogicalResult TransposeOpAdaptor::verify(::mlir::Location loc) { + return ::mlir::success(); +} + +std::pair TransposeOp::getODSOperandIndexAndLength(unsigned index) { + return {index, 1}; +} + +::mlir::Operation::operand_range TransposeOp::getODSOperands(unsigned index) { + auto valueRange = getODSOperandIndexAndLength(index); + return {std::next(getOperation()->operand_begin(), valueRange.first), + std::next(getOperation()->operand_begin(), valueRange.first + valueRange.second)}; +} + +::mlir::TypedValue<::mlir::TensorType> TransposeOp::getInput() { + return ::llvm::cast<::mlir::TypedValue<::mlir::TensorType>>(*getODSOperands(0).begin()); +} + +::mlir::OpOperand &TransposeOp::getInputMutable() { + auto range = getODSOperandIndexAndLength(0); + return getOperation()->getOpOperand(range.first); +} + +std::pair TransposeOp::getODSResultIndexAndLength(unsigned index) { + return {index, 1}; +} + +::mlir::Operation::result_range TransposeOp::getODSResults(unsigned index) { + auto valueRange = getODSResultIndexAndLength(index); + return {std::next(getOperation()->result_begin(), valueRange.first), + std::next(getOperation()->result_begin(), valueRange.first + valueRange.second)}; +} + +void TransposeOp::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type resultType0, ::mlir::Value input) { + odsState.addOperands(input); + odsState.addTypes(resultType0); +} + +void TransposeOp::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::Value input) { + odsState.addOperands(input); + assert(resultTypes.size() == 1u && "mismatched number of results"); + odsState.addTypes(resultTypes); +} + +void TransposeOp::build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes) { + assert(operands.size() == 1u && "mismatched number of parameters"); + odsState.addOperands(operands); + odsState.addAttributes(attributes); + assert(resultTypes.size() == 1u && "mismatched number of return types"); + odsState.addTypes(resultTypes); +} + +::mlir::LogicalResult TransposeOp::verifyInvariantsImpl() { + { + unsigned index = 0; (void)index; + auto valueGroup0 = getODSOperands(0); + + for (auto v : valueGroup0) { + if (::mlir::failed(__mlir_ods_local_type_constraint_Ops0(*this, v.getType(), "operand", index++))) + return ::mlir::failure(); + } + } + { + unsigned index = 0; (void)index; + auto valueGroup0 = getODSResults(0); + + for (auto v : valueGroup0) { + if (::mlir::failed(__mlir_ods_local_type_constraint_Ops0(*this, v.getType(), "result", index++))) + return ::mlir::failure(); + } + } + return ::mlir::success(); +} + +::mlir::LogicalResult TransposeOp::verifyInvariants() { + if(::mlir::succeeded(verifyInvariantsImpl()) && ::mlir::succeeded(verify())) + return ::mlir::success(); + return ::mlir::failure(); +} + +::mlir::ParseResult TransposeOp::parse(::mlir::OpAsmParser &parser, ::mlir::OperationState &result) { + ::mlir::OpAsmParser::UnresolvedOperand inputRawOperands[1]; + ::llvm::ArrayRef<::mlir::OpAsmParser::UnresolvedOperand> inputOperands(inputRawOperands); ::llvm::SMLoc inputOperandsLoc; + (void)inputOperandsLoc; + ::mlir::Type inputRawTypes[1]; + ::llvm::ArrayRef<::mlir::Type> inputTypes(inputRawTypes); + ::llvm::SmallVector<::mlir::Type, 1> allResultTypes; + if (parser.parseLParen()) + return ::mlir::failure(); + + inputOperandsLoc = parser.getCurrentLocation(); + if (parser.parseOperand(inputRawOperands[0])) + return ::mlir::failure(); + if (parser.parseColon()) + return ::mlir::failure(); + + { + ::mlir::TensorType type; + if (parser.parseCustomTypeWithFallback(type)) + return ::mlir::failure(); + inputRawTypes[0] = type; + } + if (parser.parseRParen()) + return ::mlir::failure(); + { + auto loc = parser.getCurrentLocation();(void)loc; + if (parser.parseOptionalAttrDict(result.attributes)) + return ::mlir::failure(); + } + if (parser.parseKeyword("to")) + return ::mlir::failure(); + + if (parser.parseTypeList(allResultTypes)) + return ::mlir::failure(); + result.addTypes(allResultTypes); + if (parser.resolveOperands(inputOperands, inputTypes, inputOperandsLoc, result.operands)) + return ::mlir::failure(); + return ::mlir::success(); +} + +void TransposeOp::print(::mlir::OpAsmPrinter &_odsPrinter) { + _odsPrinter << "("; + _odsPrinter << getInput(); + _odsPrinter << ' ' << ":"; + _odsPrinter << ' '; + { + auto type = getInput().getType(); + if (auto validType = ::llvm::dyn_cast<::mlir::TensorType>(type)) + _odsPrinter.printStrippedAttrOrType(validType); + else + _odsPrinter << type; + } + _odsPrinter << ")"; + ::llvm::SmallVector<::llvm::StringRef, 2> elidedAttrs; + _odsPrinter.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs); + _odsPrinter << ' ' << "to"; + _odsPrinter << ' '; + _odsPrinter << getOperation()->getResultTypes(); +} + +void TransposeOp::getEffects(::llvm::SmallVectorImpl<::mlir::SideEffects::EffectInstance<::mlir::MemoryEffects::Effect>> &effects) { +} + +} // namespace toy +} // namespace mlir +MLIR_DEFINE_EXPLICIT_TYPE_ID(::mlir::toy::TransposeOp) + + +#endif // GET_OP_CLASSES + diff --git a/Ch3/include/toy/Ops.h.inc b/Ch3/include/toy/Ops.h.inc new file mode 100644 index 0000000..65c4be8 --- /dev/null +++ b/Ch3/include/toy/Ops.h.inc @@ -0,0 +1,1247 @@ +/*===- TableGen'erated file -------------------------------------*- C++ -*-===*\ +|* *| +|* Op Declarations *| +|* *| +|* Automatically generated file, do not edit! *| +|* From: Ops.td *| +|* *| +\*===----------------------------------------------------------------------===*/ + +#if defined(GET_OP_CLASSES) || defined(GET_OP_FWD_DEFINES) +#undef GET_OP_FWD_DEFINES +namespace mlir { +namespace toy { +class AddOp; +} // namespace toy +} // namespace mlir +namespace mlir { +namespace toy { +class ConstantOp; +} // namespace toy +} // namespace mlir +namespace mlir { +namespace toy { +class FuncOp; +} // namespace toy +} // namespace mlir +namespace mlir { +namespace toy { +class GenericCallOp; +} // namespace toy +} // namespace mlir +namespace mlir { +namespace toy { +class MulOp; +} // namespace toy +} // namespace mlir +namespace mlir { +namespace toy { +class PrintOp; +} // namespace toy +} // namespace mlir +namespace mlir { +namespace toy { +class ReshapeOp; +} // namespace toy +} // namespace mlir +namespace mlir { +namespace toy { +class ReturnOp; +} // namespace toy +} // namespace mlir +namespace mlir { +namespace toy { +class TransposeOp; +} // namespace toy +} // namespace mlir +#endif + +#ifdef GET_OP_CLASSES +#undef GET_OP_CLASSES + + +//===----------------------------------------------------------------------===// +// Local Utility Method Definitions +//===----------------------------------------------------------------------===// + +namespace mlir { +namespace toy { + +//===----------------------------------------------------------------------===// +// ::mlir::toy::AddOp declarations +//===----------------------------------------------------------------------===// + +namespace detail { +class AddOpGenericAdaptorBase { +public: +protected: + ::mlir::DictionaryAttr odsAttrs; + ::std::optional<::mlir::OperationName> odsOpName; + ::mlir::RegionRange odsRegions; +public: + AddOpGenericAdaptorBase(::mlir::DictionaryAttr attrs = nullptr, const ::mlir::EmptyProperties &properties = {}, ::mlir::RegionRange regions = {}); + + AddOpGenericAdaptorBase(AddOp op); + + std::pair getODSOperandIndexAndLength(unsigned index, unsigned odsOperandsSize); + ::mlir::DictionaryAttr getAttributes(); +}; +} // namespace detail +template +class AddOpGenericAdaptor : public detail::AddOpGenericAdaptorBase { + using ValueT = ::llvm::detail::ValueOfRange; + using Base = detail::AddOpGenericAdaptorBase; +public: + AddOpGenericAdaptor(RangeT values, ::mlir::DictionaryAttr attrs = nullptr, const ::mlir::EmptyProperties &properties = {}, ::mlir::RegionRange regions = {}) : Base(attrs, properties, regions), odsOperands(values) {} + + AddOpGenericAdaptor(RangeT values, ::mlir::DictionaryAttr attrs, ::mlir::OpaqueProperties properties, ::mlir::RegionRange regions = {}) : AddOpGenericAdaptor(values, attrs, (properties ? *properties.as<::mlir::EmptyProperties *>() : ::mlir::EmptyProperties{}), regions) {} + + template >> + AddOpGenericAdaptor(RangeT values, LateInst op) : Base(op), odsOperands(values) {} + + std::pair getODSOperandIndexAndLength(unsigned index) { + return Base::getODSOperandIndexAndLength(index, odsOperands.size()); + } + + RangeT getODSOperands(unsigned index) { + auto valueRange = getODSOperandIndexAndLength(index); + return {std::next(odsOperands.begin(), valueRange.first), + std::next(odsOperands.begin(), valueRange.first + valueRange.second)}; + } + + ValueT getLhs() { + return (*getODSOperands(0).begin()); + } + + ValueT getRhs() { + return (*getODSOperands(1).begin()); + } + + RangeT getOperands() { + return odsOperands; + } + +private: + RangeT odsOperands; +}; +class AddOpAdaptor : public AddOpGenericAdaptor<::mlir::ValueRange> { +public: + using AddOpGenericAdaptor::AddOpGenericAdaptor; + AddOpAdaptor(AddOp op); + + ::mlir::LogicalResult verify(::mlir::Location loc); +}; +class AddOp : public ::mlir::Op::Impl, ::mlir::OpTrait::ZeroSuccessors, ::mlir::OpTrait::NOperands<2>::Impl, ::mlir::OpTrait::OpInvariants, ::mlir::ConditionallySpeculatable::Trait, ::mlir::OpTrait::AlwaysSpeculatableImplTrait, ::mlir::MemoryEffectOpInterface::Trait> { +public: + using Op::Op; + using Op::print; + using Adaptor = AddOpAdaptor; + template + using GenericAdaptor = AddOpGenericAdaptor; + using FoldAdaptor = GenericAdaptor<::llvm::ArrayRef<::mlir::Attribute>>; + static ::llvm::ArrayRef<::llvm::StringRef> getAttributeNames() { + return {}; + } + + static constexpr ::llvm::StringLiteral getOperationName() { + return ::llvm::StringLiteral("toy.add"); + } + + std::pair getODSOperandIndexAndLength(unsigned index); + ::mlir::Operation::operand_range getODSOperands(unsigned index); + ::mlir::TypedValue<::mlir::TensorType> getLhs(); + ::mlir::TypedValue<::mlir::TensorType> getRhs(); + ::mlir::OpOperand &getLhsMutable(); + ::mlir::OpOperand &getRhsMutable(); + std::pair getODSResultIndexAndLength(unsigned index); + ::mlir::Operation::result_range getODSResults(unsigned index); + static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, Value lhs, Value rhs); + static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type resultType0, ::mlir::Value lhs, ::mlir::Value rhs); + static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::Value lhs, ::mlir::Value rhs); + static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {}); + static ::mlir::ParseResult parse(::mlir::OpAsmParser &parser, ::mlir::OperationState &result); + void print(::mlir::OpAsmPrinter &p); + ::mlir::LogicalResult verifyInvariantsImpl(); + ::mlir::LogicalResult verifyInvariants(); + void getEffects(::llvm::SmallVectorImpl<::mlir::SideEffects::EffectInstance<::mlir::MemoryEffects::Effect>> &effects); +public: +}; +} // namespace toy +} // namespace mlir +MLIR_DECLARE_EXPLICIT_TYPE_ID(::mlir::toy::AddOp) + +namespace mlir { +namespace toy { + +//===----------------------------------------------------------------------===// +// ::mlir::toy::ConstantOp declarations +//===----------------------------------------------------------------------===// + +namespace detail { +class ConstantOpGenericAdaptorBase { +public: + struct Properties { + using valueTy = ::mlir::DenseElementsAttr; + valueTy value; + + auto getValue() { + auto &propStorage = this->value; + return ::llvm::cast<::mlir::DenseElementsAttr>(propStorage); + } + void setValue(const ::mlir::DenseElementsAttr &propValue) { + this->value = propValue; + } + bool operator==(const Properties &rhs) const { + return + rhs.value == this->value && + true; + } + bool operator!=(const Properties &rhs) const { + return !(*this == rhs); + } + }; +protected: + ::mlir::DictionaryAttr odsAttrs; + ::std::optional<::mlir::OperationName> odsOpName; + Properties properties; + ::mlir::RegionRange odsRegions; +public: + ConstantOpGenericAdaptorBase(::mlir::DictionaryAttr attrs = nullptr, const Properties &properties = {}, ::mlir::RegionRange regions = {}); + + ConstantOpGenericAdaptorBase(ConstantOp op); + + std::pair getODSOperandIndexAndLength(unsigned index, unsigned odsOperandsSize); + const Properties &getProperties() { + return properties; + } + + ::mlir::DictionaryAttr getAttributes(); + ::mlir::DenseElementsAttr getValueAttr(); + ::mlir::DenseElementsAttr getValue(); +}; +} // namespace detail +template +class ConstantOpGenericAdaptor : public detail::ConstantOpGenericAdaptorBase { + using ValueT = ::llvm::detail::ValueOfRange; + using Base = detail::ConstantOpGenericAdaptorBase; +public: + ConstantOpGenericAdaptor(RangeT values, ::mlir::DictionaryAttr attrs = nullptr, const Properties &properties = {}, ::mlir::RegionRange regions = {}) : Base(attrs, properties, regions), odsOperands(values) {} + + ConstantOpGenericAdaptor(RangeT values, ::mlir::DictionaryAttr attrs, ::mlir::OpaqueProperties properties, ::mlir::RegionRange regions = {}) : ConstantOpGenericAdaptor(values, attrs, (properties ? *properties.as() : Properties{}), regions) {} + + template >> + ConstantOpGenericAdaptor(RangeT values, LateInst op) : Base(op), odsOperands(values) {} + + std::pair getODSOperandIndexAndLength(unsigned index) { + return Base::getODSOperandIndexAndLength(index, odsOperands.size()); + } + + RangeT getODSOperands(unsigned index) { + auto valueRange = getODSOperandIndexAndLength(index); + return {std::next(odsOperands.begin(), valueRange.first), + std::next(odsOperands.begin(), valueRange.first + valueRange.second)}; + } + + RangeT getOperands() { + return odsOperands; + } + +private: + RangeT odsOperands; +}; +class ConstantOpAdaptor : public ConstantOpGenericAdaptor<::mlir::ValueRange> { +public: + using ConstantOpGenericAdaptor::ConstantOpGenericAdaptor; + ConstantOpAdaptor(ConstantOp op); + + ::mlir::LogicalResult verify(::mlir::Location loc); +}; +class ConstantOp : public ::mlir::Op::Impl, ::mlir::OpTrait::ZeroSuccessors, ::mlir::OpTrait::ZeroOperands, ::mlir::OpTrait::OpInvariants, ::mlir::BytecodeOpInterface::Trait, ::mlir::ConditionallySpeculatable::Trait, ::mlir::OpTrait::AlwaysSpeculatableImplTrait, ::mlir::MemoryEffectOpInterface::Trait> { +public: + using Op::Op; + using Op::print; + using Adaptor = ConstantOpAdaptor; + template + using GenericAdaptor = ConstantOpGenericAdaptor; + using FoldAdaptor = GenericAdaptor<::llvm::ArrayRef<::mlir::Attribute>>; + using Properties = FoldAdaptor::Properties; + static ::llvm::ArrayRef<::llvm::StringRef> getAttributeNames() { + static ::llvm::StringRef attrNames[] = {::llvm::StringRef("value")}; + return ::llvm::ArrayRef(attrNames); + } + + ::mlir::StringAttr getValueAttrName() { + return getAttributeNameForIndex(0); + } + + static ::mlir::StringAttr getValueAttrName(::mlir::OperationName name) { + return getAttributeNameForIndex(name, 0); + } + + static constexpr ::llvm::StringLiteral getOperationName() { + return ::llvm::StringLiteral("toy.constant"); + } + + std::pair getODSOperandIndexAndLength(unsigned index); + ::mlir::Operation::operand_range getODSOperands(unsigned index); + std::pair getODSResultIndexAndLength(unsigned index); + ::mlir::Operation::result_range getODSResults(unsigned index); + static ::mlir::LogicalResult setPropertiesFromAttr(Properties &prop, ::mlir::Attribute attr, ::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError); + static ::mlir::Attribute getPropertiesAsAttr(::mlir::MLIRContext *ctx, const Properties &prop); + static llvm::hash_code computePropertiesHash(const Properties &prop); + static std::optional getInherentAttr(::mlir::MLIRContext *ctx, const Properties &prop, llvm::StringRef name); + static void setInherentAttr(Properties &prop, llvm::StringRef name, mlir::Attribute value); + static void populateInherentAttrs(::mlir::MLIRContext *ctx, const Properties &prop, ::mlir::NamedAttrList &attrs); + static ::mlir::LogicalResult verifyInherentAttrs(::mlir::OperationName opName, ::mlir::NamedAttrList &attrs, llvm::function_ref<::mlir::InFlightDiagnostic()> emitError); + static ::mlir::LogicalResult readProperties(::mlir::DialectBytecodeReader &reader, ::mlir::OperationState &state); + void writeProperties(::mlir::DialectBytecodeWriter &writer); + ::mlir::DenseElementsAttr getValueAttr(); + ::mlir::DenseElementsAttr getValue(); + void setValueAttr(::mlir::DenseElementsAttr attr); + static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, DenseElementsAttr value); + static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, double value); + static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type resultType0, ::mlir::DenseElementsAttr value); + static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::DenseElementsAttr value); + static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {}); + static ::mlir::ParseResult parse(::mlir::OpAsmParser &parser, ::mlir::OperationState &result); + void print(::mlir::OpAsmPrinter &p); + ::mlir::LogicalResult verifyInvariantsImpl(); + ::mlir::LogicalResult verifyInvariants(); + ::mlir::LogicalResult verify(); + void getEffects(::llvm::SmallVectorImpl<::mlir::SideEffects::EffectInstance<::mlir::MemoryEffects::Effect>> &effects); +private: + ::mlir::StringAttr getAttributeNameForIndex(unsigned index) { + return getAttributeNameForIndex((*this)->getName(), index); + } + + static ::mlir::StringAttr getAttributeNameForIndex(::mlir::OperationName name, unsigned index) { + assert(index < 1 && "invalid attribute index"); + assert(name.getStringRef() == getOperationName() && "invalid operation name"); + assert(name.isRegistered() && "Operation isn't registered, missing a " + "dependent dialect loading?"); + return name.getAttributeNames()[index]; + } + +public: +}; +} // namespace toy +} // namespace mlir +MLIR_DECLARE_EXPLICIT_TYPE_ID(::mlir::toy::ConstantOp) + +namespace mlir { +namespace toy { + +//===----------------------------------------------------------------------===// +// ::mlir::toy::FuncOp declarations +//===----------------------------------------------------------------------===// + +namespace detail { +class FuncOpGenericAdaptorBase { +public: + struct Properties { + using arg_attrsTy = ::mlir::ArrayAttr; + arg_attrsTy arg_attrs; + + auto getArgAttrs() { + auto &propStorage = this->arg_attrs; + return ::llvm::dyn_cast_or_null<::mlir::ArrayAttr>(propStorage); + } + void setArgAttrs(const ::mlir::ArrayAttr &propValue) { + this->arg_attrs = propValue; + } + using function_typeTy = ::mlir::TypeAttr; + function_typeTy function_type; + + auto getFunctionType() { + auto &propStorage = this->function_type; + return ::llvm::cast<::mlir::TypeAttr>(propStorage); + } + void setFunctionType(const ::mlir::TypeAttr &propValue) { + this->function_type = propValue; + } + using res_attrsTy = ::mlir::ArrayAttr; + res_attrsTy res_attrs; + + auto getResAttrs() { + auto &propStorage = this->res_attrs; + return ::llvm::dyn_cast_or_null<::mlir::ArrayAttr>(propStorage); + } + void setResAttrs(const ::mlir::ArrayAttr &propValue) { + this->res_attrs = propValue; + } + using sym_nameTy = ::mlir::StringAttr; + sym_nameTy sym_name; + + auto getSymName() { + auto &propStorage = this->sym_name; + return ::llvm::cast<::mlir::StringAttr>(propStorage); + } + void setSymName(const ::mlir::StringAttr &propValue) { + this->sym_name = propValue; + } + bool operator==(const Properties &rhs) const { + return + rhs.arg_attrs == this->arg_attrs && + rhs.function_type == this->function_type && + rhs.res_attrs == this->res_attrs && + rhs.sym_name == this->sym_name && + true; + } + bool operator!=(const Properties &rhs) const { + return !(*this == rhs); + } + }; +protected: + ::mlir::DictionaryAttr odsAttrs; + ::std::optional<::mlir::OperationName> odsOpName; + Properties properties; + ::mlir::RegionRange odsRegions; +public: + FuncOpGenericAdaptorBase(::mlir::DictionaryAttr attrs = nullptr, const Properties &properties = {}, ::mlir::RegionRange regions = {}); + + FuncOpGenericAdaptorBase(FuncOp op); + + std::pair getODSOperandIndexAndLength(unsigned index, unsigned odsOperandsSize); + const Properties &getProperties() { + return properties; + } + + ::mlir::DictionaryAttr getAttributes(); + ::mlir::StringAttr getSymNameAttr(); + ::llvm::StringRef getSymName(); + ::mlir::TypeAttr getFunctionTypeAttr(); + ::mlir::FunctionType getFunctionType(); + ::mlir::ArrayAttr getArgAttrsAttr(); + ::std::optional< ::mlir::ArrayAttr > getArgAttrs(); + ::mlir::ArrayAttr getResAttrsAttr(); + ::std::optional< ::mlir::ArrayAttr > getResAttrs(); + ::mlir::Region &getBody(); + ::mlir::RegionRange getRegions(); +}; +} // namespace detail +template +class FuncOpGenericAdaptor : public detail::FuncOpGenericAdaptorBase { + using ValueT = ::llvm::detail::ValueOfRange; + using Base = detail::FuncOpGenericAdaptorBase; +public: + FuncOpGenericAdaptor(RangeT values, ::mlir::DictionaryAttr attrs = nullptr, const Properties &properties = {}, ::mlir::RegionRange regions = {}) : Base(attrs, properties, regions), odsOperands(values) {} + + FuncOpGenericAdaptor(RangeT values, ::mlir::DictionaryAttr attrs, ::mlir::OpaqueProperties properties, ::mlir::RegionRange regions = {}) : FuncOpGenericAdaptor(values, attrs, (properties ? *properties.as() : Properties{}), regions) {} + + template >> + FuncOpGenericAdaptor(RangeT values, LateInst op) : Base(op), odsOperands(values) {} + + std::pair getODSOperandIndexAndLength(unsigned index) { + return Base::getODSOperandIndexAndLength(index, odsOperands.size()); + } + + RangeT getODSOperands(unsigned index) { + auto valueRange = getODSOperandIndexAndLength(index); + return {std::next(odsOperands.begin(), valueRange.first), + std::next(odsOperands.begin(), valueRange.first + valueRange.second)}; + } + + RangeT getOperands() { + return odsOperands; + } + +private: + RangeT odsOperands; +}; +class FuncOpAdaptor : public FuncOpGenericAdaptor<::mlir::ValueRange> { +public: + using FuncOpGenericAdaptor::FuncOpGenericAdaptor; + FuncOpAdaptor(FuncOp op); + + ::mlir::LogicalResult verify(::mlir::Location loc); +}; +class FuncOp : public ::mlir::Op { +public: + using Op::Op; + using Op::print; + using Adaptor = FuncOpAdaptor; + template + using GenericAdaptor = FuncOpGenericAdaptor; + using FoldAdaptor = GenericAdaptor<::llvm::ArrayRef<::mlir::Attribute>>; + using Properties = FoldAdaptor::Properties; + static ::llvm::ArrayRef<::llvm::StringRef> getAttributeNames() { + static ::llvm::StringRef attrNames[] = {::llvm::StringRef("arg_attrs"), ::llvm::StringRef("function_type"), ::llvm::StringRef("res_attrs"), ::llvm::StringRef("sym_name")}; + return ::llvm::ArrayRef(attrNames); + } + + ::mlir::StringAttr getArgAttrsAttrName() { + return getAttributeNameForIndex(0); + } + + static ::mlir::StringAttr getArgAttrsAttrName(::mlir::OperationName name) { + return getAttributeNameForIndex(name, 0); + } + + ::mlir::StringAttr getFunctionTypeAttrName() { + return getAttributeNameForIndex(1); + } + + static ::mlir::StringAttr getFunctionTypeAttrName(::mlir::OperationName name) { + return getAttributeNameForIndex(name, 1); + } + + ::mlir::StringAttr getResAttrsAttrName() { + return getAttributeNameForIndex(2); + } + + static ::mlir::StringAttr getResAttrsAttrName(::mlir::OperationName name) { + return getAttributeNameForIndex(name, 2); + } + + ::mlir::StringAttr getSymNameAttrName() { + return getAttributeNameForIndex(3); + } + + static ::mlir::StringAttr getSymNameAttrName(::mlir::OperationName name) { + return getAttributeNameForIndex(name, 3); + } + + static constexpr ::llvm::StringLiteral getOperationName() { + return ::llvm::StringLiteral("toy.func"); + } + + std::pair getODSOperandIndexAndLength(unsigned index); + ::mlir::Operation::operand_range getODSOperands(unsigned index); + std::pair getODSResultIndexAndLength(unsigned index); + ::mlir::Operation::result_range getODSResults(unsigned index); + ::mlir::Region &getBody(); + static ::mlir::LogicalResult setPropertiesFromAttr(Properties &prop, ::mlir::Attribute attr, ::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError); + static ::mlir::Attribute getPropertiesAsAttr(::mlir::MLIRContext *ctx, const Properties &prop); + static llvm::hash_code computePropertiesHash(const Properties &prop); + static std::optional getInherentAttr(::mlir::MLIRContext *ctx, const Properties &prop, llvm::StringRef name); + static void setInherentAttr(Properties &prop, llvm::StringRef name, mlir::Attribute value); + static void populateInherentAttrs(::mlir::MLIRContext *ctx, const Properties &prop, ::mlir::NamedAttrList &attrs); + static ::mlir::LogicalResult verifyInherentAttrs(::mlir::OperationName opName, ::mlir::NamedAttrList &attrs, llvm::function_ref<::mlir::InFlightDiagnostic()> emitError); + static ::mlir::LogicalResult readProperties(::mlir::DialectBytecodeReader &reader, ::mlir::OperationState &state); + void writeProperties(::mlir::DialectBytecodeWriter &writer); + ::mlir::StringAttr getSymNameAttr(); + ::llvm::StringRef getSymName(); + ::mlir::TypeAttr getFunctionTypeAttr(); + ::mlir::FunctionType getFunctionType(); + ::mlir::ArrayAttr getArgAttrsAttr(); + ::std::optional< ::mlir::ArrayAttr > getArgAttrs(); + ::mlir::ArrayAttr getResAttrsAttr(); + ::std::optional< ::mlir::ArrayAttr > getResAttrs(); + void setSymNameAttr(::mlir::StringAttr attr); + void setSymName(::llvm::StringRef attrValue); + void setFunctionTypeAttr(::mlir::TypeAttr attr); + void setFunctionType(::mlir::FunctionType attrValue); + void setArgAttrsAttr(::mlir::ArrayAttr attr); + void setResAttrsAttr(::mlir::ArrayAttr attr); + ::mlir::Attribute removeArgAttrsAttr(); + ::mlir::Attribute removeResAttrsAttr(); + static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, StringRef name, FunctionType type, ArrayRef attrs = {}); + static ::mlir::ParseResult parse(::mlir::OpAsmParser &parser, ::mlir::OperationState &result); + void print(::mlir::OpAsmPrinter &p); + ::mlir::LogicalResult verifyInvariantsImpl(); + ::mlir::LogicalResult verifyInvariants(); +private: + ::mlir::StringAttr getAttributeNameForIndex(unsigned index) { + return getAttributeNameForIndex((*this)->getName(), index); + } + + static ::mlir::StringAttr getAttributeNameForIndex(::mlir::OperationName name, unsigned index) { + assert(index < 4 && "invalid attribute index"); + assert(name.getStringRef() == getOperationName() && "invalid operation name"); + assert(name.isRegistered() && "Operation isn't registered, missing a " + "dependent dialect loading?"); + return name.getAttributeNames()[index]; + } + +public: + //===------------------------------------------------------------------===// + // FunctionOpInterface Methods + //===------------------------------------------------------------------===// + + /// Returns the argument types of this function. + ArrayRef getArgumentTypes() { return getFunctionType().getInputs(); } + + /// Returns the result types of this function. + ArrayRef getResultTypes() { return getFunctionType().getResults(); } + + /// Returns the region on the current operation that is callable. + ::mlir::Region *getCallableRegion() { return &getBody(); } +}; +} // namespace toy +} // namespace mlir +MLIR_DECLARE_EXPLICIT_TYPE_ID(::mlir::toy::FuncOp) + +namespace mlir { +namespace toy { + +//===----------------------------------------------------------------------===// +// ::mlir::toy::GenericCallOp declarations +//===----------------------------------------------------------------------===// + +namespace detail { +class GenericCallOpGenericAdaptorBase { +public: + struct Properties { + using calleeTy = ::mlir::FlatSymbolRefAttr; + calleeTy callee; + + auto getCallee() { + auto &propStorage = this->callee; + return ::llvm::cast<::mlir::FlatSymbolRefAttr>(propStorage); + } + void setCallee(const ::mlir::FlatSymbolRefAttr &propValue) { + this->callee = propValue; + } + bool operator==(const Properties &rhs) const { + return + rhs.callee == this->callee && + true; + } + bool operator!=(const Properties &rhs) const { + return !(*this == rhs); + } + }; +protected: + ::mlir::DictionaryAttr odsAttrs; + ::std::optional<::mlir::OperationName> odsOpName; + Properties properties; + ::mlir::RegionRange odsRegions; +public: + GenericCallOpGenericAdaptorBase(::mlir::DictionaryAttr attrs = nullptr, const Properties &properties = {}, ::mlir::RegionRange regions = {}); + + GenericCallOpGenericAdaptorBase(GenericCallOp op); + + std::pair getODSOperandIndexAndLength(unsigned index, unsigned odsOperandsSize); + const Properties &getProperties() { + return properties; + } + + ::mlir::DictionaryAttr getAttributes(); + ::mlir::FlatSymbolRefAttr getCalleeAttr(); + ::llvm::StringRef getCallee(); +}; +} // namespace detail +template +class GenericCallOpGenericAdaptor : public detail::GenericCallOpGenericAdaptorBase { + using ValueT = ::llvm::detail::ValueOfRange; + using Base = detail::GenericCallOpGenericAdaptorBase; +public: + GenericCallOpGenericAdaptor(RangeT values, ::mlir::DictionaryAttr attrs = nullptr, const Properties &properties = {}, ::mlir::RegionRange regions = {}) : Base(attrs, properties, regions), odsOperands(values) {} + + GenericCallOpGenericAdaptor(RangeT values, ::mlir::DictionaryAttr attrs, ::mlir::OpaqueProperties properties, ::mlir::RegionRange regions = {}) : GenericCallOpGenericAdaptor(values, attrs, (properties ? *properties.as() : Properties{}), regions) {} + + template >> + GenericCallOpGenericAdaptor(RangeT values, LateInst op) : Base(op), odsOperands(values) {} + + std::pair getODSOperandIndexAndLength(unsigned index) { + return Base::getODSOperandIndexAndLength(index, odsOperands.size()); + } + + RangeT getODSOperands(unsigned index) { + auto valueRange = getODSOperandIndexAndLength(index); + return {std::next(odsOperands.begin(), valueRange.first), + std::next(odsOperands.begin(), valueRange.first + valueRange.second)}; + } + + RangeT getInputs() { + return getODSOperands(0); + } + + RangeT getOperands() { + return odsOperands; + } + +private: + RangeT odsOperands; +}; +class GenericCallOpAdaptor : public GenericCallOpGenericAdaptor<::mlir::ValueRange> { +public: + using GenericCallOpGenericAdaptor::GenericCallOpGenericAdaptor; + GenericCallOpAdaptor(GenericCallOp op); + + ::mlir::LogicalResult verify(::mlir::Location loc); +}; +class GenericCallOp : public ::mlir::Op::Impl, ::mlir::OpTrait::ZeroSuccessors, ::mlir::OpTrait::VariadicOperands, ::mlir::OpTrait::OpInvariants, ::mlir::BytecodeOpInterface::Trait> { +public: + using Op::Op; + using Op::print; + using Adaptor = GenericCallOpAdaptor; + template + using GenericAdaptor = GenericCallOpGenericAdaptor; + using FoldAdaptor = GenericAdaptor<::llvm::ArrayRef<::mlir::Attribute>>; + using Properties = FoldAdaptor::Properties; + static ::llvm::ArrayRef<::llvm::StringRef> getAttributeNames() { + static ::llvm::StringRef attrNames[] = {::llvm::StringRef("callee")}; + return ::llvm::ArrayRef(attrNames); + } + + ::mlir::StringAttr getCalleeAttrName() { + return getAttributeNameForIndex(0); + } + + static ::mlir::StringAttr getCalleeAttrName(::mlir::OperationName name) { + return getAttributeNameForIndex(name, 0); + } + + static constexpr ::llvm::StringLiteral getOperationName() { + return ::llvm::StringLiteral("toy.generic_call"); + } + + std::pair getODSOperandIndexAndLength(unsigned index); + ::mlir::Operation::operand_range getODSOperands(unsigned index); + ::mlir::Operation::operand_range getInputs(); + ::mlir::MutableOperandRange getInputsMutable(); + std::pair getODSResultIndexAndLength(unsigned index); + ::mlir::Operation::result_range getODSResults(unsigned index); + static ::mlir::LogicalResult setPropertiesFromAttr(Properties &prop, ::mlir::Attribute attr, ::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError); + static ::mlir::Attribute getPropertiesAsAttr(::mlir::MLIRContext *ctx, const Properties &prop); + static llvm::hash_code computePropertiesHash(const Properties &prop); + static std::optional getInherentAttr(::mlir::MLIRContext *ctx, const Properties &prop, llvm::StringRef name); + static void setInherentAttr(Properties &prop, llvm::StringRef name, mlir::Attribute value); + static void populateInherentAttrs(::mlir::MLIRContext *ctx, const Properties &prop, ::mlir::NamedAttrList &attrs); + static ::mlir::LogicalResult verifyInherentAttrs(::mlir::OperationName opName, ::mlir::NamedAttrList &attrs, llvm::function_ref<::mlir::InFlightDiagnostic()> emitError); + static ::mlir::LogicalResult readProperties(::mlir::DialectBytecodeReader &reader, ::mlir::OperationState &state); + void writeProperties(::mlir::DialectBytecodeWriter &writer); + ::mlir::FlatSymbolRefAttr getCalleeAttr(); + ::llvm::StringRef getCallee(); + void setCalleeAttr(::mlir::FlatSymbolRefAttr attr); + void setCallee(::llvm::StringRef attrValue); + static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, StringRef callee, ArrayRef arguments); + static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type resultType0, ::mlir::FlatSymbolRefAttr callee, ::mlir::ValueRange inputs); + static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::FlatSymbolRefAttr callee, ::mlir::ValueRange inputs); + static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type resultType0, ::llvm::StringRef callee, ::mlir::ValueRange inputs); + static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::llvm::StringRef callee, ::mlir::ValueRange inputs); + static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {}); + ::mlir::LogicalResult verifyInvariantsImpl(); + ::mlir::LogicalResult verifyInvariants(); + static ::mlir::ParseResult parse(::mlir::OpAsmParser &parser, ::mlir::OperationState &result); + void print(::mlir::OpAsmPrinter &_odsPrinter); +private: + ::mlir::StringAttr getAttributeNameForIndex(unsigned index) { + return getAttributeNameForIndex((*this)->getName(), index); + } + + static ::mlir::StringAttr getAttributeNameForIndex(::mlir::OperationName name, unsigned index) { + assert(index < 1 && "invalid attribute index"); + assert(name.getStringRef() == getOperationName() && "invalid operation name"); + assert(name.isRegistered() && "Operation isn't registered, missing a " + "dependent dialect loading?"); + return name.getAttributeNames()[index]; + } + +public: +}; +} // namespace toy +} // namespace mlir +MLIR_DECLARE_EXPLICIT_TYPE_ID(::mlir::toy::GenericCallOp) + +namespace mlir { +namespace toy { + +//===----------------------------------------------------------------------===// +// ::mlir::toy::MulOp declarations +//===----------------------------------------------------------------------===// + +namespace detail { +class MulOpGenericAdaptorBase { +public: +protected: + ::mlir::DictionaryAttr odsAttrs; + ::std::optional<::mlir::OperationName> odsOpName; + ::mlir::RegionRange odsRegions; +public: + MulOpGenericAdaptorBase(::mlir::DictionaryAttr attrs = nullptr, const ::mlir::EmptyProperties &properties = {}, ::mlir::RegionRange regions = {}); + + MulOpGenericAdaptorBase(MulOp op); + + std::pair getODSOperandIndexAndLength(unsigned index, unsigned odsOperandsSize); + ::mlir::DictionaryAttr getAttributes(); +}; +} // namespace detail +template +class MulOpGenericAdaptor : public detail::MulOpGenericAdaptorBase { + using ValueT = ::llvm::detail::ValueOfRange; + using Base = detail::MulOpGenericAdaptorBase; +public: + MulOpGenericAdaptor(RangeT values, ::mlir::DictionaryAttr attrs = nullptr, const ::mlir::EmptyProperties &properties = {}, ::mlir::RegionRange regions = {}) : Base(attrs, properties, regions), odsOperands(values) {} + + MulOpGenericAdaptor(RangeT values, ::mlir::DictionaryAttr attrs, ::mlir::OpaqueProperties properties, ::mlir::RegionRange regions = {}) : MulOpGenericAdaptor(values, attrs, (properties ? *properties.as<::mlir::EmptyProperties *>() : ::mlir::EmptyProperties{}), regions) {} + + template >> + MulOpGenericAdaptor(RangeT values, LateInst op) : Base(op), odsOperands(values) {} + + std::pair getODSOperandIndexAndLength(unsigned index) { + return Base::getODSOperandIndexAndLength(index, odsOperands.size()); + } + + RangeT getODSOperands(unsigned index) { + auto valueRange = getODSOperandIndexAndLength(index); + return {std::next(odsOperands.begin(), valueRange.first), + std::next(odsOperands.begin(), valueRange.first + valueRange.second)}; + } + + ValueT getLhs() { + return (*getODSOperands(0).begin()); + } + + ValueT getRhs() { + return (*getODSOperands(1).begin()); + } + + RangeT getOperands() { + return odsOperands; + } + +private: + RangeT odsOperands; +}; +class MulOpAdaptor : public MulOpGenericAdaptor<::mlir::ValueRange> { +public: + using MulOpGenericAdaptor::MulOpGenericAdaptor; + MulOpAdaptor(MulOp op); + + ::mlir::LogicalResult verify(::mlir::Location loc); +}; +class MulOp : public ::mlir::Op::Impl, ::mlir::OpTrait::ZeroSuccessors, ::mlir::OpTrait::NOperands<2>::Impl, ::mlir::OpTrait::OpInvariants, ::mlir::ConditionallySpeculatable::Trait, ::mlir::OpTrait::AlwaysSpeculatableImplTrait, ::mlir::MemoryEffectOpInterface::Trait> { +public: + using Op::Op; + using Op::print; + using Adaptor = MulOpAdaptor; + template + using GenericAdaptor = MulOpGenericAdaptor; + using FoldAdaptor = GenericAdaptor<::llvm::ArrayRef<::mlir::Attribute>>; + static ::llvm::ArrayRef<::llvm::StringRef> getAttributeNames() { + return {}; + } + + static constexpr ::llvm::StringLiteral getOperationName() { + return ::llvm::StringLiteral("toy.mul"); + } + + std::pair getODSOperandIndexAndLength(unsigned index); + ::mlir::Operation::operand_range getODSOperands(unsigned index); + ::mlir::TypedValue<::mlir::TensorType> getLhs(); + ::mlir::TypedValue<::mlir::TensorType> getRhs(); + ::mlir::OpOperand &getLhsMutable(); + ::mlir::OpOperand &getRhsMutable(); + std::pair getODSResultIndexAndLength(unsigned index); + ::mlir::Operation::result_range getODSResults(unsigned index); + static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, Value lhs, Value rhs); + static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type resultType0, ::mlir::Value lhs, ::mlir::Value rhs); + static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::Value lhs, ::mlir::Value rhs); + static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {}); + static ::mlir::ParseResult parse(::mlir::OpAsmParser &parser, ::mlir::OperationState &result); + void print(::mlir::OpAsmPrinter &p); + ::mlir::LogicalResult verifyInvariantsImpl(); + ::mlir::LogicalResult verifyInvariants(); + void getEffects(::llvm::SmallVectorImpl<::mlir::SideEffects::EffectInstance<::mlir::MemoryEffects::Effect>> &effects); +public: +}; +} // namespace toy +} // namespace mlir +MLIR_DECLARE_EXPLICIT_TYPE_ID(::mlir::toy::MulOp) + +namespace mlir { +namespace toy { + +//===----------------------------------------------------------------------===// +// ::mlir::toy::PrintOp declarations +//===----------------------------------------------------------------------===// + +namespace detail { +class PrintOpGenericAdaptorBase { +public: +protected: + ::mlir::DictionaryAttr odsAttrs; + ::std::optional<::mlir::OperationName> odsOpName; + ::mlir::RegionRange odsRegions; +public: + PrintOpGenericAdaptorBase(::mlir::DictionaryAttr attrs = nullptr, const ::mlir::EmptyProperties &properties = {}, ::mlir::RegionRange regions = {}); + + PrintOpGenericAdaptorBase(PrintOp op); + + std::pair getODSOperandIndexAndLength(unsigned index, unsigned odsOperandsSize); + ::mlir::DictionaryAttr getAttributes(); +}; +} // namespace detail +template +class PrintOpGenericAdaptor : public detail::PrintOpGenericAdaptorBase { + using ValueT = ::llvm::detail::ValueOfRange; + using Base = detail::PrintOpGenericAdaptorBase; +public: + PrintOpGenericAdaptor(RangeT values, ::mlir::DictionaryAttr attrs = nullptr, const ::mlir::EmptyProperties &properties = {}, ::mlir::RegionRange regions = {}) : Base(attrs, properties, regions), odsOperands(values) {} + + PrintOpGenericAdaptor(RangeT values, ::mlir::DictionaryAttr attrs, ::mlir::OpaqueProperties properties, ::mlir::RegionRange regions = {}) : PrintOpGenericAdaptor(values, attrs, (properties ? *properties.as<::mlir::EmptyProperties *>() : ::mlir::EmptyProperties{}), regions) {} + + template >> + PrintOpGenericAdaptor(RangeT values, LateInst op) : Base(op), odsOperands(values) {} + + std::pair getODSOperandIndexAndLength(unsigned index) { + return Base::getODSOperandIndexAndLength(index, odsOperands.size()); + } + + RangeT getODSOperands(unsigned index) { + auto valueRange = getODSOperandIndexAndLength(index); + return {std::next(odsOperands.begin(), valueRange.first), + std::next(odsOperands.begin(), valueRange.first + valueRange.second)}; + } + + ValueT getInput() { + return (*getODSOperands(0).begin()); + } + + RangeT getOperands() { + return odsOperands; + } + +private: + RangeT odsOperands; +}; +class PrintOpAdaptor : public PrintOpGenericAdaptor<::mlir::ValueRange> { +public: + using PrintOpGenericAdaptor::PrintOpGenericAdaptor; + PrintOpAdaptor(PrintOp op); + + ::mlir::LogicalResult verify(::mlir::Location loc); +}; +class PrintOp : public ::mlir::Op { +public: + using Op::Op; + using Op::print; + using Adaptor = PrintOpAdaptor; + template + using GenericAdaptor = PrintOpGenericAdaptor; + using FoldAdaptor = GenericAdaptor<::llvm::ArrayRef<::mlir::Attribute>>; + static ::llvm::ArrayRef<::llvm::StringRef> getAttributeNames() { + return {}; + } + + static constexpr ::llvm::StringLiteral getOperationName() { + return ::llvm::StringLiteral("toy.print"); + } + + std::pair getODSOperandIndexAndLength(unsigned index); + ::mlir::Operation::operand_range getODSOperands(unsigned index); + ::mlir::TypedValue<::mlir::TensorType> getInput(); + ::mlir::OpOperand &getInputMutable(); + std::pair getODSResultIndexAndLength(unsigned index); + ::mlir::Operation::result_range getODSResults(unsigned index); + static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Value input); + static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::Value input); + static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {}); + ::mlir::LogicalResult verifyInvariantsImpl(); + ::mlir::LogicalResult verifyInvariants(); + static ::mlir::ParseResult parse(::mlir::OpAsmParser &parser, ::mlir::OperationState &result); + void print(::mlir::OpAsmPrinter &_odsPrinter); +public: +}; +} // namespace toy +} // namespace mlir +MLIR_DECLARE_EXPLICIT_TYPE_ID(::mlir::toy::PrintOp) + +namespace mlir { +namespace toy { + +//===----------------------------------------------------------------------===// +// ::mlir::toy::ReshapeOp declarations +//===----------------------------------------------------------------------===// + +namespace detail { +class ReshapeOpGenericAdaptorBase { +public: +protected: + ::mlir::DictionaryAttr odsAttrs; + ::std::optional<::mlir::OperationName> odsOpName; + ::mlir::RegionRange odsRegions; +public: + ReshapeOpGenericAdaptorBase(::mlir::DictionaryAttr attrs = nullptr, const ::mlir::EmptyProperties &properties = {}, ::mlir::RegionRange regions = {}); + + ReshapeOpGenericAdaptorBase(ReshapeOp op); + + std::pair getODSOperandIndexAndLength(unsigned index, unsigned odsOperandsSize); + ::mlir::DictionaryAttr getAttributes(); +}; +} // namespace detail +template +class ReshapeOpGenericAdaptor : public detail::ReshapeOpGenericAdaptorBase { + using ValueT = ::llvm::detail::ValueOfRange; + using Base = detail::ReshapeOpGenericAdaptorBase; +public: + ReshapeOpGenericAdaptor(RangeT values, ::mlir::DictionaryAttr attrs = nullptr, const ::mlir::EmptyProperties &properties = {}, ::mlir::RegionRange regions = {}) : Base(attrs, properties, regions), odsOperands(values) {} + + ReshapeOpGenericAdaptor(RangeT values, ::mlir::DictionaryAttr attrs, ::mlir::OpaqueProperties properties, ::mlir::RegionRange regions = {}) : ReshapeOpGenericAdaptor(values, attrs, (properties ? *properties.as<::mlir::EmptyProperties *>() : ::mlir::EmptyProperties{}), regions) {} + + template >> + ReshapeOpGenericAdaptor(RangeT values, LateInst op) : Base(op), odsOperands(values) {} + + std::pair getODSOperandIndexAndLength(unsigned index) { + return Base::getODSOperandIndexAndLength(index, odsOperands.size()); + } + + RangeT getODSOperands(unsigned index) { + auto valueRange = getODSOperandIndexAndLength(index); + return {std::next(odsOperands.begin(), valueRange.first), + std::next(odsOperands.begin(), valueRange.first + valueRange.second)}; + } + + ValueT getInput() { + return (*getODSOperands(0).begin()); + } + + RangeT getOperands() { + return odsOperands; + } + +private: + RangeT odsOperands; +}; +class ReshapeOpAdaptor : public ReshapeOpGenericAdaptor<::mlir::ValueRange> { +public: + using ReshapeOpGenericAdaptor::ReshapeOpGenericAdaptor; + ReshapeOpAdaptor(ReshapeOp op); + + ::mlir::LogicalResult verify(::mlir::Location loc); +}; +class ReshapeOp : public ::mlir::Op::Impl, ::mlir::OpTrait::ZeroSuccessors, ::mlir::OpTrait::OneOperand, ::mlir::OpTrait::OpInvariants, ::mlir::ConditionallySpeculatable::Trait, ::mlir::OpTrait::AlwaysSpeculatableImplTrait, ::mlir::MemoryEffectOpInterface::Trait> { +public: + using Op::Op; + using Op::print; + using Adaptor = ReshapeOpAdaptor; + template + using GenericAdaptor = ReshapeOpGenericAdaptor; + using FoldAdaptor = GenericAdaptor<::llvm::ArrayRef<::mlir::Attribute>>; + static ::llvm::ArrayRef<::llvm::StringRef> getAttributeNames() { + return {}; + } + + static constexpr ::llvm::StringLiteral getOperationName() { + return ::llvm::StringLiteral("toy.reshape"); + } + + std::pair getODSOperandIndexAndLength(unsigned index); + ::mlir::Operation::operand_range getODSOperands(unsigned index); + ::mlir::TypedValue<::mlir::TensorType> getInput(); + ::mlir::OpOperand &getInputMutable(); + std::pair getODSResultIndexAndLength(unsigned index); + ::mlir::Operation::result_range getODSResults(unsigned index); + static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type resultType0, ::mlir::Value input); + static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::Value input); + static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {}); + ::mlir::LogicalResult verifyInvariantsImpl(); + ::mlir::LogicalResult verifyInvariants(); + static void getCanonicalizationPatterns(::mlir::RewritePatternSet &results, ::mlir::MLIRContext *context); + static ::mlir::ParseResult parse(::mlir::OpAsmParser &parser, ::mlir::OperationState &result); + void print(::mlir::OpAsmPrinter &_odsPrinter); + void getEffects(::llvm::SmallVectorImpl<::mlir::SideEffects::EffectInstance<::mlir::MemoryEffects::Effect>> &effects); +public: +}; +} // namespace toy +} // namespace mlir +MLIR_DECLARE_EXPLICIT_TYPE_ID(::mlir::toy::ReshapeOp) + +namespace mlir { +namespace toy { + +//===----------------------------------------------------------------------===// +// ::mlir::toy::ReturnOp declarations +//===----------------------------------------------------------------------===// + +namespace detail { +class ReturnOpGenericAdaptorBase { +public: +protected: + ::mlir::DictionaryAttr odsAttrs; + ::std::optional<::mlir::OperationName> odsOpName; + ::mlir::RegionRange odsRegions; +public: + ReturnOpGenericAdaptorBase(::mlir::DictionaryAttr attrs = nullptr, const ::mlir::EmptyProperties &properties = {}, ::mlir::RegionRange regions = {}); + + ReturnOpGenericAdaptorBase(ReturnOp op); + + std::pair getODSOperandIndexAndLength(unsigned index, unsigned odsOperandsSize); + ::mlir::DictionaryAttr getAttributes(); +}; +} // namespace detail +template +class ReturnOpGenericAdaptor : public detail::ReturnOpGenericAdaptorBase { + using ValueT = ::llvm::detail::ValueOfRange; + using Base = detail::ReturnOpGenericAdaptorBase; +public: + ReturnOpGenericAdaptor(RangeT values, ::mlir::DictionaryAttr attrs = nullptr, const ::mlir::EmptyProperties &properties = {}, ::mlir::RegionRange regions = {}) : Base(attrs, properties, regions), odsOperands(values) {} + + ReturnOpGenericAdaptor(RangeT values, ::mlir::DictionaryAttr attrs, ::mlir::OpaqueProperties properties, ::mlir::RegionRange regions = {}) : ReturnOpGenericAdaptor(values, attrs, (properties ? *properties.as<::mlir::EmptyProperties *>() : ::mlir::EmptyProperties{}), regions) {} + + template >> + ReturnOpGenericAdaptor(RangeT values, LateInst op) : Base(op), odsOperands(values) {} + + std::pair getODSOperandIndexAndLength(unsigned index) { + return Base::getODSOperandIndexAndLength(index, odsOperands.size()); + } + + RangeT getODSOperands(unsigned index) { + auto valueRange = getODSOperandIndexAndLength(index); + return {std::next(odsOperands.begin(), valueRange.first), + std::next(odsOperands.begin(), valueRange.first + valueRange.second)}; + } + + RangeT getInput() { + return getODSOperands(0); + } + + RangeT getOperands() { + return odsOperands; + } + +private: + RangeT odsOperands; +}; +class ReturnOpAdaptor : public ReturnOpGenericAdaptor<::mlir::ValueRange> { +public: + using ReturnOpGenericAdaptor::ReturnOpGenericAdaptor; + ReturnOpAdaptor(ReturnOp op); + + ::mlir::LogicalResult verify(::mlir::Location loc); +}; +class ReturnOp : public ::mlir::Op::Impl, ::mlir::OpTrait::OpInvariants, ::mlir::ConditionallySpeculatable::Trait, ::mlir::OpTrait::AlwaysSpeculatableImplTrait, ::mlir::MemoryEffectOpInterface::Trait, ::mlir::OpTrait::IsTerminator> { +public: + using Op::Op; + using Op::print; + using Adaptor = ReturnOpAdaptor; + template + using GenericAdaptor = ReturnOpGenericAdaptor; + using FoldAdaptor = GenericAdaptor<::llvm::ArrayRef<::mlir::Attribute>>; + static ::llvm::ArrayRef<::llvm::StringRef> getAttributeNames() { + return {}; + } + + static constexpr ::llvm::StringLiteral getOperationName() { + return ::llvm::StringLiteral("toy.return"); + } + + std::pair getODSOperandIndexAndLength(unsigned index); + ::mlir::Operation::operand_range getODSOperands(unsigned index); + ::mlir::Operation::operand_range getInput(); + ::mlir::MutableOperandRange getInputMutable(); + std::pair getODSResultIndexAndLength(unsigned index); + ::mlir::Operation::result_range getODSResults(unsigned index); + static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState); + static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::ValueRange input); + static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {}); + ::mlir::LogicalResult verifyInvariantsImpl(); + ::mlir::LogicalResult verifyInvariants(); + ::mlir::LogicalResult verify(); + static ::mlir::ParseResult parse(::mlir::OpAsmParser &parser, ::mlir::OperationState &result); + void print(::mlir::OpAsmPrinter &_odsPrinter); + void getEffects(::llvm::SmallVectorImpl<::mlir::SideEffects::EffectInstance<::mlir::MemoryEffects::Effect>> &effects); +public: + bool hasOperand() { return getNumOperands() != 0; } +}; +} // namespace toy +} // namespace mlir +MLIR_DECLARE_EXPLICIT_TYPE_ID(::mlir::toy::ReturnOp) + +namespace mlir { +namespace toy { + +//===----------------------------------------------------------------------===// +// ::mlir::toy::TransposeOp declarations +//===----------------------------------------------------------------------===// + +namespace detail { +class TransposeOpGenericAdaptorBase { +public: +protected: + ::mlir::DictionaryAttr odsAttrs; + ::std::optional<::mlir::OperationName> odsOpName; + ::mlir::RegionRange odsRegions; +public: + TransposeOpGenericAdaptorBase(::mlir::DictionaryAttr attrs = nullptr, const ::mlir::EmptyProperties &properties = {}, ::mlir::RegionRange regions = {}); + + TransposeOpGenericAdaptorBase(TransposeOp op); + + std::pair getODSOperandIndexAndLength(unsigned index, unsigned odsOperandsSize); + ::mlir::DictionaryAttr getAttributes(); +}; +} // namespace detail +template +class TransposeOpGenericAdaptor : public detail::TransposeOpGenericAdaptorBase { + using ValueT = ::llvm::detail::ValueOfRange; + using Base = detail::TransposeOpGenericAdaptorBase; +public: + TransposeOpGenericAdaptor(RangeT values, ::mlir::DictionaryAttr attrs = nullptr, const ::mlir::EmptyProperties &properties = {}, ::mlir::RegionRange regions = {}) : Base(attrs, properties, regions), odsOperands(values) {} + + TransposeOpGenericAdaptor(RangeT values, ::mlir::DictionaryAttr attrs, ::mlir::OpaqueProperties properties, ::mlir::RegionRange regions = {}) : TransposeOpGenericAdaptor(values, attrs, (properties ? *properties.as<::mlir::EmptyProperties *>() : ::mlir::EmptyProperties{}), regions) {} + + template >> + TransposeOpGenericAdaptor(RangeT values, LateInst op) : Base(op), odsOperands(values) {} + + std::pair getODSOperandIndexAndLength(unsigned index) { + return Base::getODSOperandIndexAndLength(index, odsOperands.size()); + } + + RangeT getODSOperands(unsigned index) { + auto valueRange = getODSOperandIndexAndLength(index); + return {std::next(odsOperands.begin(), valueRange.first), + std::next(odsOperands.begin(), valueRange.first + valueRange.second)}; + } + + ValueT getInput() { + return (*getODSOperands(0).begin()); + } + + RangeT getOperands() { + return odsOperands; + } + +private: + RangeT odsOperands; +}; +class TransposeOpAdaptor : public TransposeOpGenericAdaptor<::mlir::ValueRange> { +public: + using TransposeOpGenericAdaptor::TransposeOpGenericAdaptor; + TransposeOpAdaptor(TransposeOp op); + + ::mlir::LogicalResult verify(::mlir::Location loc); +}; +class TransposeOp : public ::mlir::Op::Impl, ::mlir::OpTrait::ZeroSuccessors, ::mlir::OpTrait::OneOperand, ::mlir::OpTrait::OpInvariants, ::mlir::ConditionallySpeculatable::Trait, ::mlir::OpTrait::AlwaysSpeculatableImplTrait, ::mlir::MemoryEffectOpInterface::Trait> { +public: + using Op::Op; + using Op::print; + using Adaptor = TransposeOpAdaptor; + template + using GenericAdaptor = TransposeOpGenericAdaptor; + using FoldAdaptor = GenericAdaptor<::llvm::ArrayRef<::mlir::Attribute>>; + static ::llvm::ArrayRef<::llvm::StringRef> getAttributeNames() { + return {}; + } + + static constexpr ::llvm::StringLiteral getOperationName() { + return ::llvm::StringLiteral("toy.transpose"); + } + + std::pair getODSOperandIndexAndLength(unsigned index); + ::mlir::Operation::operand_range getODSOperands(unsigned index); + ::mlir::TypedValue<::mlir::TensorType> getInput(); + ::mlir::OpOperand &getInputMutable(); + std::pair getODSResultIndexAndLength(unsigned index); + ::mlir::Operation::result_range getODSResults(unsigned index); + static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, Value input); + static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type resultType0, ::mlir::Value input); + static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::Value input); + static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {}); + ::mlir::LogicalResult verifyInvariantsImpl(); + ::mlir::LogicalResult verifyInvariants(); + ::mlir::LogicalResult verify(); + static void getCanonicalizationPatterns(::mlir::RewritePatternSet &results, ::mlir::MLIRContext *context); + static ::mlir::ParseResult parse(::mlir::OpAsmParser &parser, ::mlir::OperationState &result); + void print(::mlir::OpAsmPrinter &_odsPrinter); + void getEffects(::llvm::SmallVectorImpl<::mlir::SideEffects::EffectInstance<::mlir::MemoryEffects::Effect>> &effects); +public: +}; +} // namespace toy +} // namespace mlir +MLIR_DECLARE_EXPLICIT_TYPE_ID(::mlir::toy::TransposeOp) + + +#endif // GET_OP_CLASSES + diff --git a/Ch3/include/toy/Ops.td b/Ch3/include/toy/Ops.td new file mode 100644 index 0000000..021802b --- /dev/null +++ b/Ch3/include/toy/Ops.td @@ -0,0 +1,339 @@ +//===- Ops.td - Toy dialect operation definitions ----------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Defines the operations of the Toy dialect. +// +//===----------------------------------------------------------------------===// + +#ifndef TOY_OPS +#define TOY_OPS + +include "mlir/Interfaces/FunctionInterfaces.td" +include "mlir/IR/SymbolInterfaces.td" +include "mlir/Interfaces/SideEffectInterfaces.td" + +// Provide a definition of the 'toy' dialect in the ODS framework so that we +// can define our operations. +def Toy_Dialect : Dialect { + let name = "toy"; + let cppNamespace = "::mlir::toy"; +} + +// Base class for toy dialect operations. This operation inherits from the base +// `Op` class in OpBase.td, and provides: +// * The parent dialect of the operation. +// * The mnemonic for the operation, or the name without the dialect prefix. +// * A list of traits for the operation. +class Toy_Op traits = []> : + Op; + +//===----------------------------------------------------------------------===// +// Toy Operations +//===----------------------------------------------------------------------===// + +//===----------------------------------------------------------------------===// +// ConstantOp +//===----------------------------------------------------------------------===// + +// We define a toy operation by inheriting from our base 'Toy_Op' class above. +// Here we provide the mnemonic and a list of traits for the operation. The +// constant operation is marked as 'Pure' as it is a pure operation +// and may be removed if dead. +def ConstantOp : Toy_Op<"constant", [Pure]> { + // Provide a summary and description for this operation. This can be used to + // auto-generate documentation of the operations within our dialect. + let summary = "constant"; + let description = [{ + Constant operation turns a literal into an SSA value. The data is attached + to the operation as an attribute. For example: + + ```mlir + %0 = toy.constant dense<[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]> + : tensor<2x3xf64> + ``` + }]; + + // The constant operation takes an attribute as the only input. + let arguments = (ins F64ElementsAttr:$value); + + // The constant operation returns a single value of TensorType. + let results = (outs F64Tensor); + + // Indicate that the operation has a custom parser and printer method. + let hasCustomAssemblyFormat = 1; + + // Add custom build methods for the constant operation. These method populates + // the `state` that MLIR uses to create operations, i.e. these are used when + // using `builder.create(...)`. + let builders = [ + // Build a constant with a given constant tensor value. + OpBuilder<(ins "DenseElementsAttr":$value), [{ + build($_builder, $_state, value.getType(), value); + }]>, + + // Build a constant with a given constant floating-point value. + OpBuilder<(ins "double":$value)> + ]; + + // Indicate that additional verification for this operation is necessary. + let hasVerifier = 1; +} + +//===----------------------------------------------------------------------===// +// AddOp +//===----------------------------------------------------------------------===// + +def AddOp : Toy_Op<"add", [Pure]> { + let summary = "element-wise addition operation"; + let description = [{ + The "add" operation performs element-wise addition between two tensors. + The shapes of the tensor operands are expected to match. + }]; + + let arguments = (ins F64Tensor:$lhs, F64Tensor:$rhs); + let results = (outs F64Tensor); + + // Indicate that the operation has a custom parser and printer method. + let hasCustomAssemblyFormat = 1; + + // Allow building an AddOp with from the two input operands. + let builders = [ + OpBuilder<(ins "Value":$lhs, "Value":$rhs)> + ]; +} + +//===----------------------------------------------------------------------===// +// FuncOp +//===----------------------------------------------------------------------===// + +def FuncOp : Toy_Op<"func", [ + FunctionOpInterface, IsolatedFromAbove + ]> { + let summary = "user defined function operation"; + let description = [{ + The "toy.func" operation represents a user defined function. These are + callable SSA-region operations that contain toy computations. + + Example: + + ```mlir + toy.func @main() { + %0 = toy.constant dense<5.500000e+00> : tensor + %1 = toy.reshape(%0 : tensor) to tensor<2x2xf64> + toy.print %1 : tensor<2x2xf64> + toy.return + } + ``` + }]; + + let arguments = (ins + SymbolNameAttr:$sym_name, + TypeAttrOf:$function_type, + OptionalAttr:$arg_attrs, + OptionalAttr:$res_attrs + ); + let regions = (region AnyRegion:$body); + + let builders = [OpBuilder<(ins + "StringRef":$name, "FunctionType":$type, + CArg<"ArrayRef", "{}">:$attrs) + >]; + let extraClassDeclaration = [{ + //===------------------------------------------------------------------===// + // FunctionOpInterface Methods + //===------------------------------------------------------------------===// + + /// Returns the argument types of this function. + ArrayRef getArgumentTypes() { return getFunctionType().getInputs(); } + + /// Returns the result types of this function. + ArrayRef getResultTypes() { return getFunctionType().getResults(); } + + /// Returns the region on the current operation that is callable. + ::mlir::Region *getCallableRegion() { return &getBody(); } + }]; + let hasCustomAssemblyFormat = 1; + let skipDefaultBuilders = 1; +} + +//===----------------------------------------------------------------------===// +// GenericCallOp +//===----------------------------------------------------------------------===// + +def GenericCallOp : Toy_Op<"generic_call"> { + let summary = "generic call operation"; + let description = [{ + Generic calls represent calls to a user defined function that needs to + be specialized for the shape of its arguments. The callee name is attached + as a symbol reference via an attribute. The arguments list must match the + arguments expected by the callee. For example: + + ```mlir + %4 = toy.generic_call @my_func(%1, %3) + : (tensor<2x3xf64>, tensor<2x3xf64>) -> tensor<*xf64> + ``` + + This is only valid if a function named "my_func" exists and takes two + arguments. + }]; + + // The generic call operation takes a symbol reference attribute as the + // callee, and inputs for the call. + let arguments = (ins FlatSymbolRefAttr:$callee, Variadic:$inputs); + + // The generic call operation returns a single value of TensorType. + let results = (outs F64Tensor); + + // Specialize assembly printing and parsing using a declarative format. + let assemblyFormat = [{ + $callee `(` $inputs `)` attr-dict `:` functional-type($inputs, results) + }]; + + // Add custom build methods for the generic call operation. + let builders = [ + OpBuilder<(ins "StringRef":$callee, "ArrayRef":$arguments)> + ]; +} + +//===----------------------------------------------------------------------===// +// MulOp +//===----------------------------------------------------------------------===// + +def MulOp : Toy_Op<"mul", [Pure]> { + let summary = "element-wise multiplication operation"; + let description = [{ + The "mul" operation performs element-wise multiplication between two + tensors. The shapes of the tensor operands are expected to match. + }]; + + let arguments = (ins F64Tensor:$lhs, F64Tensor:$rhs); + let results = (outs F64Tensor); + + // Indicate that the operation has a custom parser and printer method. + let hasCustomAssemblyFormat = 1; + + // Allow building a MulOp with from the two input operands. + let builders = [ + OpBuilder<(ins "Value":$lhs, "Value":$rhs)> + ]; +} + +//===----------------------------------------------------------------------===// +// PrintOp +//===----------------------------------------------------------------------===// + +def PrintOp : Toy_Op<"print"> { + let summary = "print operation"; + let description = [{ + The "print" builtin operation prints a given input tensor, and produces + no results. + }]; + + // The print operation takes an input tensor to print. + let arguments = (ins F64Tensor:$input); + + let assemblyFormat = "$input attr-dict `:` type($input)"; +} + +//===----------------------------------------------------------------------===// +// ReshapeOp +//===----------------------------------------------------------------------===// + +def ReshapeOp : Toy_Op<"reshape", [Pure]> { + let summary = "tensor reshape operation"; + let description = [{ + Reshape operation is transforming its input tensor into a new tensor with + the same number of elements but different shapes. For example: + + ```mlir + %0 = toy.reshape (%arg1 : tensor<10xf64>) to tensor<5x2xf64> + ``` + }]; + + let arguments = (ins F64Tensor:$input); + + // We expect that the reshape operation returns a statically shaped tensor. + let results = (outs StaticShapeTensorOf<[F64]>); + + let assemblyFormat = [{ + `(` $input `:` type($input) `)` attr-dict `to` type(results) + }]; + + // Enable registering canonicalization patterns with this operation. + let hasCanonicalizer = 1; +} + +//===----------------------------------------------------------------------===// +// ReturnOp +//===----------------------------------------------------------------------===// + +def ReturnOp : Toy_Op<"return", [Pure, HasParent<"FuncOp">, + Terminator]> { + let summary = "return operation"; + let description = [{ + The "return" operation represents a return operation within a function. + The operation takes an optional tensor operand and produces no results. + The operand type must match the signature of the function that contains + the operation. For example: + + ```mlir + toy.func @foo() -> tensor<2xf64> { + ... + toy.return %0 : tensor<2xf64> + } + ``` + }]; + + // The return operation takes an optional input operand to return. This + // value must match the return type of the enclosing function. + let arguments = (ins Variadic:$input); + + // The return operation only emits the input in the format if it is present. + let assemblyFormat = "($input^ `:` type($input))? attr-dict "; + + // Allow building a ReturnOp with no return operand. + let builders = [ + OpBuilder<(ins), [{ build($_builder, $_state, std::nullopt); }]> + ]; + + // Provide extra utility definitions on the c++ operation class definition. + let extraClassDeclaration = [{ + bool hasOperand() { return getNumOperands() != 0; } + }]; + + // Indicate that additional verification for this operation is necessary. + let hasVerifier = 1; +} + +//===----------------------------------------------------------------------===// +// TransposeOp +//===----------------------------------------------------------------------===// + +def TransposeOp : Toy_Op<"transpose", [Pure]> { + let summary = "transpose operation"; + + let arguments = (ins F64Tensor:$input); + let results = (outs F64Tensor); + + let assemblyFormat = [{ + `(` $input `:` type($input) `)` attr-dict `to` type(results) + }]; + + // Enable registering canonicalization patterns with this operation. + let hasCanonicalizer = 1; + + // Allow building a TransposeOp with from the input operand. + let builders = [ + OpBuilder<(ins "Value":$input)> + ]; + + // Indicate that additional verification for this operation is necessary. + let hasVerifier = 1; +} + +#endif // TOY_OPS diff --git a/Ch3/include/toy/Parser.h b/Ch3/include/toy/Parser.h new file mode 100644 index 0000000..1f20616 --- /dev/null +++ b/Ch3/include/toy/Parser.h @@ -0,0 +1,489 @@ +//===- Parser.h - Toy Language Parser -------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements the parser for the Toy language. It processes the Token +// provided by the Lexer and returns an AST. +// +//===----------------------------------------------------------------------===// + +#ifndef TOY_PARSER_H +#define TOY_PARSER_H + +#include "toy/AST.h" +#include "toy/Lexer.h" + +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/StringExtras.h" +#include "llvm/Support/raw_ostream.h" + +#include +#include +#include +#include + +namespace toy { + +/// This is a simple recursive parser for the Toy language. It produces a well +/// formed AST from a stream of Token supplied by the Lexer. No semantic checks +/// or symbol resolution is performed. For example, variables are referenced by +/// string and the code could reference an undeclared variable and the parsing +/// succeeds. +class Parser { +public: + /// Create a Parser for the supplied lexer. + Parser(Lexer &lexer) : lexer(lexer) {} + + /// Parse a full Module. A module is a list of function definitions. + std::unique_ptr parseModule() { + lexer.getNextToken(); // prime the lexer + + // Parse functions one at a time and accumulate in this vector. + std::vector functions; + while (auto f = parseDefinition()) { + functions.push_back(std::move(*f)); + if (lexer.getCurToken() == tok_eof) + break; + } + // If we didn't reach EOF, there was an error during parsing + if (lexer.getCurToken() != tok_eof) + return parseError("nothing", "at end of module"); + + return std::make_unique(std::move(functions)); + } + +private: + Lexer &lexer; + + /// Parse a return statement. + /// return :== return ; | return expr ; + std::unique_ptr parseReturn() { + auto loc = lexer.getLastLocation(); + lexer.consume(tok_return); + + // return takes an optional argument + std::optional> expr; + if (lexer.getCurToken() != ';') { + expr = parseExpression(); + if (!expr) + return nullptr; + } + return std::make_unique(std::move(loc), std::move(expr)); + } + + /// Parse a literal number. + /// numberexpr ::= number + std::unique_ptr parseNumberExpr() { + auto loc = lexer.getLastLocation(); + auto result = + std::make_unique(std::move(loc), lexer.getValue()); + lexer.consume(tok_number); + return std::move(result); + } + + /// Parse a literal array expression. + /// tensorLiteral ::= [ literalList ] | number + /// literalList ::= tensorLiteral | tensorLiteral, literalList + std::unique_ptr parseTensorLiteralExpr() { + auto loc = lexer.getLastLocation(); + lexer.consume(Token('[')); + + // Hold the list of values at this nesting level. + std::vector> values; + // Hold the dimensions for all the nesting inside this level. + std::vector dims; + do { + // We can have either another nested array or a number literal. + if (lexer.getCurToken() == '[') { + values.push_back(parseTensorLiteralExpr()); + if (!values.back()) + return nullptr; // parse error in the nested array. + } else { + if (lexer.getCurToken() != tok_number) + return parseError(" or [", "in literal expression"); + values.push_back(parseNumberExpr()); + } + + // End of this list on ']' + if (lexer.getCurToken() == ']') + break; + + // Elements are separated by a comma. + if (lexer.getCurToken() != ',') + return parseError("] or ,", "in literal expression"); + + lexer.getNextToken(); // eat , + } while (true); + if (values.empty()) + return parseError("", "to fill literal expression"); + lexer.getNextToken(); // eat ] + + /// Fill in the dimensions now. First the current nesting level: + dims.push_back(values.size()); + + /// If there is any nested array, process all of them and ensure that + /// dimensions are uniform. + if (llvm::any_of(values, [](std::unique_ptr &expr) { + return llvm::isa(expr.get()); + })) { + auto *firstLiteral = llvm::dyn_cast(values.front().get()); + if (!firstLiteral) + return parseError("uniform well-nested dimensions", + "inside literal expression"); + + // Append the nested dimensions to the current level + auto firstDims = firstLiteral->getDims(); + dims.insert(dims.end(), firstDims.begin(), firstDims.end()); + + // Sanity check that shape is uniform across all elements of the list. + for (auto &expr : values) { + auto *exprLiteral = llvm::cast(expr.get()); + if (!exprLiteral) + return parseError("uniform well-nested dimensions", + "inside literal expression"); + if (exprLiteral->getDims() != firstDims) + return parseError("uniform well-nested dimensions", + "inside literal expression"); + } + } + return std::make_unique(std::move(loc), std::move(values), + std::move(dims)); + } + + /// parenexpr ::= '(' expression ')' + std::unique_ptr parseParenExpr() { + lexer.getNextToken(); // eat (. + auto v = parseExpression(); + if (!v) + return nullptr; + + if (lexer.getCurToken() != ')') + return parseError(")", "to close expression with parentheses"); + lexer.consume(Token(')')); + return v; + } + + /// identifierexpr + /// ::= identifier + /// ::= identifier '(' expression ')' + std::unique_ptr parseIdentifierExpr() { + std::string name(lexer.getId()); + + auto loc = lexer.getLastLocation(); + lexer.getNextToken(); // eat identifier. + + if (lexer.getCurToken() != '(') // Simple variable ref. + return std::make_unique(std::move(loc), name); + + // This is a function call. + lexer.consume(Token('(')); + std::vector> args; + if (lexer.getCurToken() != ')') { + while (true) { + if (auto arg = parseExpression()) + args.push_back(std::move(arg)); + else + return nullptr; + + if (lexer.getCurToken() == ')') + break; + + if (lexer.getCurToken() != ',') + return parseError(", or )", "in argument list"); + lexer.getNextToken(); + } + } + lexer.consume(Token(')')); + + // It can be a builtin call to print + if (name == "print") { + if (args.size() != 1) + return parseError("", "as argument to print()"); + + return std::make_unique(std::move(loc), std::move(args[0])); + } + + // Call to a user-defined function + return std::make_unique(std::move(loc), name, std::move(args)); + } + + /// primary + /// ::= identifierexpr + /// ::= numberexpr + /// ::= parenexpr + /// ::= tensorliteral + std::unique_ptr parsePrimary() { + switch (lexer.getCurToken()) { + default: + llvm::errs() << "unknown token '" << lexer.getCurToken() + << "' when expecting an expression\n"; + return nullptr; + case tok_identifier: + return parseIdentifierExpr(); + case tok_number: + return parseNumberExpr(); + case '(': + return parseParenExpr(); + case '[': + return parseTensorLiteralExpr(); + case ';': + return nullptr; + case '}': + return nullptr; + } + } + + /// Recursively parse the right hand side of a binary expression, the ExprPrec + /// argument indicates the precedence of the current binary operator. + /// + /// binoprhs ::= ('+' primary)* + std::unique_ptr parseBinOpRHS(int exprPrec, + std::unique_ptr lhs) { + // If this is a binop, find its precedence. + while (true) { + int tokPrec = getTokPrecedence(); + + // If this is a binop that binds at least as tightly as the current binop, + // consume it, otherwise we are done. + if (tokPrec < exprPrec) + return lhs; + + // Okay, we know this is a binop. + int binOp = lexer.getCurToken(); + lexer.consume(Token(binOp)); + auto loc = lexer.getLastLocation(); + + // Parse the primary expression after the binary operator. + auto rhs = parsePrimary(); + if (!rhs) + return parseError("expression", "to complete binary operator"); + + // If BinOp binds less tightly with rhs than the operator after rhs, let + // the pending operator take rhs as its lhs. + int nextPrec = getTokPrecedence(); + if (tokPrec < nextPrec) { + rhs = parseBinOpRHS(tokPrec + 1, std::move(rhs)); + if (!rhs) + return nullptr; + } + + // Merge lhs/RHS. + lhs = std::make_unique(std::move(loc), binOp, + std::move(lhs), std::move(rhs)); + } + } + + /// expression::= primary binop rhs + std::unique_ptr parseExpression() { + auto lhs = parsePrimary(); + if (!lhs) + return nullptr; + + return parseBinOpRHS(0, std::move(lhs)); + } + + /// type ::= < shape_list > + /// shape_list ::= num | num , shape_list + std::unique_ptr parseType() { + if (lexer.getCurToken() != '<') + return parseError("<", "to begin type"); + lexer.getNextToken(); // eat < + + auto type = std::make_unique(); + + while (lexer.getCurToken() == tok_number) { + type->shape.push_back(lexer.getValue()); + lexer.getNextToken(); + if (lexer.getCurToken() == ',') + lexer.getNextToken(); + } + + if (lexer.getCurToken() != '>') + return parseError(">", "to end type"); + lexer.getNextToken(); // eat > + return type; + } + + /// Parse a variable declaration, it starts with a `var` keyword followed by + /// and identifier and an optional type (shape specification) before the + /// initializer. + /// decl ::= var identifier [ type ] = expr + std::unique_ptr parseDeclaration() { + if (lexer.getCurToken() != tok_var) + return parseError("var", "to begin declaration"); + auto loc = lexer.getLastLocation(); + lexer.getNextToken(); // eat var + + if (lexer.getCurToken() != tok_identifier) + return parseError("identified", + "after 'var' declaration"); + std::string id(lexer.getId()); + lexer.getNextToken(); // eat id + + std::unique_ptr type; // Type is optional, it can be inferred + if (lexer.getCurToken() == '<') { + type = parseType(); + if (!type) + return nullptr; + } + + if (!type) + type = std::make_unique(); + lexer.consume(Token('=')); + auto expr = parseExpression(); + return std::make_unique(std::move(loc), std::move(id), + std::move(*type), std::move(expr)); + } + + /// Parse a block: a list of expression separated by semicolons and wrapped in + /// curly braces. + /// + /// block ::= { expression_list } + /// expression_list ::= block_expr ; expression_list + /// block_expr ::= decl | "return" | expr + std::unique_ptr parseBlock() { + if (lexer.getCurToken() != '{') + return parseError("{", "to begin block"); + lexer.consume(Token('{')); + + auto exprList = std::make_unique(); + + // Ignore empty expressions: swallow sequences of semicolons. + while (lexer.getCurToken() == ';') + lexer.consume(Token(';')); + + while (lexer.getCurToken() != '}' && lexer.getCurToken() != tok_eof) { + if (lexer.getCurToken() == tok_var) { + // Variable declaration + auto varDecl = parseDeclaration(); + if (!varDecl) + return nullptr; + exprList->push_back(std::move(varDecl)); + } else if (lexer.getCurToken() == tok_return) { + // Return statement + auto ret = parseReturn(); + if (!ret) + return nullptr; + exprList->push_back(std::move(ret)); + } else { + // General expression + auto expr = parseExpression(); + if (!expr) + return nullptr; + exprList->push_back(std::move(expr)); + } + // Ensure that elements are separated by a semicolon. + if (lexer.getCurToken() != ';') + return parseError(";", "after expression"); + + // Ignore empty expressions: swallow sequences of semicolons. + while (lexer.getCurToken() == ';') + lexer.consume(Token(';')); + } + + if (lexer.getCurToken() != '}') + return parseError("}", "to close block"); + + lexer.consume(Token('}')); + return exprList; + } + + /// prototype ::= def id '(' decl_list ')' + /// decl_list ::= identifier | identifier, decl_list + std::unique_ptr parsePrototype() { + auto loc = lexer.getLastLocation(); + + if (lexer.getCurToken() != tok_def) + return parseError("def", "in prototype"); + lexer.consume(tok_def); + + if (lexer.getCurToken() != tok_identifier) + return parseError("function name", "in prototype"); + + std::string fnName(lexer.getId()); + lexer.consume(tok_identifier); + + if (lexer.getCurToken() != '(') + return parseError("(", "in prototype"); + lexer.consume(Token('(')); + + std::vector> args; + if (lexer.getCurToken() != ')') { + do { + std::string name(lexer.getId()); + auto loc = lexer.getLastLocation(); + lexer.consume(tok_identifier); + auto decl = std::make_unique(std::move(loc), name); + args.push_back(std::move(decl)); + if (lexer.getCurToken() != ',') + break; + lexer.consume(Token(',')); + if (lexer.getCurToken() != tok_identifier) + return parseError( + "identifier", "after ',' in function parameter list"); + } while (true); + } + if (lexer.getCurToken() != ')') + return parseError(")", "to end function prototype"); + + // success. + lexer.consume(Token(')')); + return std::make_unique(std::move(loc), fnName, + std::move(args)); + } + + /// Parse a function definition, we expect a prototype initiated with the + /// `def` keyword, followed by a block containing a list of expressions. + /// + /// definition ::= prototype block + std::unique_ptr parseDefinition() { + auto proto = parsePrototype(); + if (!proto) + return nullptr; + + if (auto block = parseBlock()) + return std::make_unique(std::move(proto), std::move(block)); + return nullptr; + } + + /// Get the precedence of the pending binary operator token. + int getTokPrecedence() { + if (!isascii(lexer.getCurToken())) + return -1; + + // 1 is lowest precedence. + switch (static_cast(lexer.getCurToken())) { + case '-': + return 20; + case '+': + return 20; + case '*': + return 40; + default: + return -1; + } + } + + /// Helper function to signal errors while parsing, it takes an argument + /// indicating the expected token and another argument giving more context. + /// Location is retrieved from the lexer to enrich the error message. + template + std::unique_ptr parseError(T &&expected, U &&context = "") { + auto curToken = lexer.getCurToken(); + llvm::errs() << "Parse error (" << lexer.getLastLocation().line << ", " + << lexer.getLastLocation().col << "): expected '" << expected + << "' " << context << " but has Token " << curToken; + if (isprint(curToken)) + llvm::errs() << " '" << (char)curToken << "'"; + llvm::errs() << "\n"; + return nullptr; + } +}; + +} // namespace toy + +#endif // TOY_PARSER_H diff --git a/Ch3/include/toy/run.sh b/Ch3/include/toy/run.sh new file mode 100644 index 0000000..b123ce1 --- /dev/null +++ b/Ch3/include/toy/run.sh @@ -0,0 +1,4 @@ +mlir-tblgen-18 -gen-op-decls -I /usr/lib/llvm-18/include Ops.td > Ops.h.inc +mlir-tblgen-18 -gen-op-defs -I /usr/lib/llvm-18/include Ops.td > Ops.cpp.inc +mlir-tblgen-18 -gen-dialect-decls -I /usr/lib/llvm-18/include Ops.td > Dialect.h.inc +mlir-tblgen-18 -gen-dialect-defs -I /usr/lib/llvm-18/include Ops.td > Dialect.cpp.inc \ No newline at end of file diff --git a/Ch3/mlir/Dialect.cpp b/Ch3/mlir/Dialect.cpp new file mode 100644 index 0000000..79d82e5 --- /dev/null +++ b/Ch3/mlir/Dialect.cpp @@ -0,0 +1,323 @@ +//===- Dialect.cpp - Toy IR Dialect registration in MLIR ------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements the dialect for the Toy IR: custom type parsing and +// operation verification. +// +//===----------------------------------------------------------------------===// + +#include "toy/Dialect.h" + +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/OpImplementation.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/OperationSupport.h" +#include "mlir/IR/Value.h" +#include "mlir/Interfaces/FunctionImplementation.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/LogicalResult.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/Casting.h" +#include +#include + +using namespace mlir; +using namespace mlir::toy; + +#include "toy/Dialect.cpp.inc" + +//===----------------------------------------------------------------------===// +// ToyDialect +//===----------------------------------------------------------------------===// + +/// Dialect initialization, the instance will be owned by the context. This is +/// the point of registration of types and operations for the dialect. +void ToyDialect::initialize() { + addOperations< +#define GET_OP_LIST +#include "toy/Ops.cpp.inc" + >(); +} + +//===----------------------------------------------------------------------===// +// Toy Operations +//===----------------------------------------------------------------------===// + +/// A generalized parser for binary operations. This parses the different forms +/// of 'printBinaryOp' below. +static mlir::ParseResult parseBinaryOp(mlir::OpAsmParser &parser, + mlir::OperationState &result) { + SmallVector operands; + SMLoc operandsLoc = parser.getCurrentLocation(); + Type type; + if (parser.parseOperandList(operands, /*requiredOperandCount=*/2) || + parser.parseOptionalAttrDict(result.attributes) || + parser.parseColonType(type)) + return mlir::failure(); + + // If the type is a function type, it contains the input and result types of + // this operation. + if (FunctionType funcType = llvm::dyn_cast(type)) { + if (parser.resolveOperands(operands, funcType.getInputs(), operandsLoc, + result.operands)) + return mlir::failure(); + result.addTypes(funcType.getResults()); + return mlir::success(); + } + + // Otherwise, the parsed type is the type of both operands and results. + if (parser.resolveOperands(operands, type, result.operands)) + return mlir::failure(); + result.addTypes(type); + return mlir::success(); +} + +/// A generalized printer for binary operations. It prints in two different +/// forms depending on if all of the types match. +static void printBinaryOp(mlir::OpAsmPrinter &printer, mlir::Operation *op) { + printer << " " << op->getOperands(); + printer.printOptionalAttrDict(op->getAttrs()); + printer << " : "; + + // If all of the types are the same, print the type directly. + Type resultType = *op->result_type_begin(); + if (llvm::all_of(op->getOperandTypes(), + [=](Type type) { return type == resultType; })) { + printer << resultType; + return; + } + + // Otherwise, print a functional type. + printer.printFunctionalType(op->getOperandTypes(), op->getResultTypes()); +} + +//===----------------------------------------------------------------------===// +// ConstantOp +//===----------------------------------------------------------------------===// + +/// Build a constant operation. +/// The builder is passed as an argument, so is the state that this method is +/// expected to fill in order to build the operation. +void ConstantOp::build(mlir::OpBuilder &builder, mlir::OperationState &state, + double value) { + auto dataType = RankedTensorType::get({}, builder.getF64Type()); + auto dataAttribute = DenseElementsAttr::get(dataType, value); + ConstantOp::build(builder, state, dataType, dataAttribute); +} + +/// The 'OpAsmParser' class provides a collection of methods for parsing +/// various punctuation, as well as attributes, operands, types, etc. Each of +/// these methods returns a `ParseResult`. This class is a wrapper around +/// `LogicalResult` that can be converted to a boolean `true` value on failure, +/// or `false` on success. This allows for easily chaining together a set of +/// parser rules. These rules are used to populate an `mlir::OperationState` +/// similarly to the `build` methods described above. +mlir::ParseResult ConstantOp::parse(mlir::OpAsmParser &parser, + mlir::OperationState &result) { + mlir::DenseElementsAttr value; + if (parser.parseOptionalAttrDict(result.attributes) || + parser.parseAttribute(value, "value", result.attributes)) + return failure(); + + result.addTypes(value.getType()); + return success(); +} + +/// The 'OpAsmPrinter' class is a stream that allows for formatting +/// strings, attributes, operands, types, etc. +void ConstantOp::print(mlir::OpAsmPrinter &printer) { + printer << " "; + printer.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{"value"}); + printer << getValue(); +} + +/// Verifier for the constant operation. This corresponds to the +/// `let hasVerifier = 1` in the op definition. +mlir::LogicalResult ConstantOp::verify() { + // If the return type of the constant is not an unranked tensor, the shape + // must match the shape of the attribute holding the data. + auto resultType = llvm::dyn_cast(getResult().getType()); + if (!resultType) + return success(); + + // Check that the rank of the attribute type matches the rank of the constant + // result type. + auto attrType = llvm::cast(getValue().getType()); + if (attrType.getRank() != resultType.getRank()) { + return emitOpError("return type must match the one of the attached value " + "attribute: ") + << attrType.getRank() << " != " << resultType.getRank(); + } + + // Check that each of the dimensions match between the two types. + for (int dim = 0, dimE = attrType.getRank(); dim < dimE; ++dim) { + if (attrType.getShape()[dim] != resultType.getShape()[dim]) { + return emitOpError( + "return type shape mismatches its attribute at dimension ") + << dim << ": " << attrType.getShape()[dim] + << " != " << resultType.getShape()[dim]; + } + } + return mlir::success(); +} + +//===----------------------------------------------------------------------===// +// AddOp +//===----------------------------------------------------------------------===// + +void AddOp::build(mlir::OpBuilder &builder, mlir::OperationState &state, + mlir::Value lhs, mlir::Value rhs) { + state.addTypes(UnrankedTensorType::get(builder.getF64Type())); + state.addOperands({lhs, rhs}); +} + +mlir::ParseResult AddOp::parse(mlir::OpAsmParser &parser, + mlir::OperationState &result) { + return parseBinaryOp(parser, result); +} + +void AddOp::print(mlir::OpAsmPrinter &p) { printBinaryOp(p, *this); } + +//===----------------------------------------------------------------------===// +// FuncOp +//===----------------------------------------------------------------------===// + +void FuncOp::build(mlir::OpBuilder &builder, mlir::OperationState &state, + llvm::StringRef name, mlir::FunctionType type, + llvm::ArrayRef attrs) { + // FunctionOpInterface provides a convenient `build` method that will populate + // the state of our FuncOp, and create an entry block. + buildWithEntryBlock(builder, state, name, type, attrs, type.getInputs()); +} + +mlir::ParseResult FuncOp::parse(mlir::OpAsmParser &parser, + mlir::OperationState &result) { + // Dispatch to the FunctionOpInterface provided utility method that parses the + // function operation. + auto buildFuncType = + [](mlir::Builder &builder, llvm::ArrayRef argTypes, + llvm::ArrayRef results, + mlir::function_interface_impl::VariadicFlag, + std::string &) { return builder.getFunctionType(argTypes, results); }; + + return mlir::function_interface_impl::parseFunctionOp( + parser, result, /*allowVariadic=*/false, + getFunctionTypeAttrName(result.name), buildFuncType, + getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name)); +} + +void FuncOp::print(mlir::OpAsmPrinter &p) { + // Dispatch to the FunctionOpInterface provided utility method that prints the + // function operation. + mlir::function_interface_impl::printFunctionOp( + p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(), + getArgAttrsAttrName(), getResAttrsAttrName()); +} + +//===----------------------------------------------------------------------===// +// GenericCallOp +//===----------------------------------------------------------------------===// + +void GenericCallOp::build(mlir::OpBuilder &builder, mlir::OperationState &state, + StringRef callee, ArrayRef arguments) { + // Generic call always returns an unranked Tensor initially. + state.addTypes(UnrankedTensorType::get(builder.getF64Type())); + state.addOperands(arguments); + state.addAttribute("callee", + mlir::SymbolRefAttr::get(builder.getContext(), callee)); +} + +//===----------------------------------------------------------------------===// +// MulOp +//===----------------------------------------------------------------------===// + +void MulOp::build(mlir::OpBuilder &builder, mlir::OperationState &state, + mlir::Value lhs, mlir::Value rhs) { + state.addTypes(UnrankedTensorType::get(builder.getF64Type())); + state.addOperands({lhs, rhs}); +} + +mlir::ParseResult MulOp::parse(mlir::OpAsmParser &parser, + mlir::OperationState &result) { + return parseBinaryOp(parser, result); +} + +void MulOp::print(mlir::OpAsmPrinter &p) { printBinaryOp(p, *this); } + +//===----------------------------------------------------------------------===// +// ReturnOp +//===----------------------------------------------------------------------===// + +mlir::LogicalResult ReturnOp::verify() { + // We know that the parent operation is a function, because of the 'HasParent' + // trait attached to the operation definition. + auto function = cast((*this)->getParentOp()); + + /// ReturnOps can only have a single optional operand. + if (getNumOperands() > 1) + return emitOpError() << "expects at most 1 return operand"; + + // The operand number and types must match the function signature. + const auto &results = function.getFunctionType().getResults(); + if (getNumOperands() != results.size()) + return emitOpError() << "does not return the same number of values (" + << getNumOperands() << ") as the enclosing function (" + << results.size() << ")"; + + // If the operation does not have an input, we are done. + if (!hasOperand()) + return mlir::success(); + + auto inputType = *operand_type_begin(); + auto resultType = results.front(); + + // Check that the result type of the function matches the operand type. + if (inputType == resultType || llvm::isa(inputType) || + llvm::isa(resultType)) + return mlir::success(); + + return emitError() << "type of return operand (" << inputType + << ") doesn't match function result type (" << resultType + << ")"; +} + +//===----------------------------------------------------------------------===// +// TransposeOp +//===----------------------------------------------------------------------===// + +void TransposeOp::build(mlir::OpBuilder &builder, mlir::OperationState &state, + mlir::Value value) { + state.addTypes(UnrankedTensorType::get(builder.getF64Type())); + state.addOperands(value); +} + +mlir::LogicalResult TransposeOp::verify() { + auto inputType = llvm::dyn_cast(getOperand().getType()); + auto resultType = llvm::dyn_cast(getType()); + if (!inputType || !resultType) + return mlir::success(); + + auto inputShape = inputType.getShape(); + if (!std::equal(inputShape.begin(), inputShape.end(), + resultType.getShape().rbegin())) { + return emitError() + << "expected result shape to be a transpose of the input"; + } + return mlir::success(); +} + +//===----------------------------------------------------------------------===// +// TableGen'd op method definitions +//===----------------------------------------------------------------------===// + +#define GET_OP_CLASSES +#include "toy/Ops.cpp.inc" diff --git a/Ch3/mlir/MLIRGen.cpp b/Ch3/mlir/MLIRGen.cpp new file mode 100644 index 0000000..2f0a88f --- /dev/null +++ b/Ch3/mlir/MLIRGen.cpp @@ -0,0 +1,457 @@ +//===- MLIRGen.cpp - MLIR Generation from a Toy AST -----------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements a simple IR generation targeting MLIR from a Module AST +// for the Toy language. +// +//===----------------------------------------------------------------------===// + +#include "toy/MLIRGen.h" +#include "mlir/IR/Block.h" +#include "mlir/IR/Diagnostics.h" +#include "mlir/IR/Value.h" +#include "mlir/Support/LogicalResult.h" +#include "toy/AST.h" +#include "toy/Dialect.h" + +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/Verifier.h" +#include "toy/Lexer.h" + +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/ScopedHashTable.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/ADT/Twine.h" +#include +#include +#include +#include +#include +#include + +using namespace mlir::toy; +using namespace toy; + +using llvm::ArrayRef; +using llvm::cast; +using llvm::dyn_cast; +using llvm::isa; +using llvm::ScopedHashTableScope; +using llvm::SmallVector; +using llvm::StringRef; +using llvm::Twine; + +namespace { + +/// Implementation of a simple MLIR emission from the Toy AST. +/// +/// This will emit operations that are specific to the Toy language, preserving +/// the semantics of the language and (hopefully) allow to perform accurate +/// analysis and transformation based on these high level semantics. +class MLIRGenImpl { +public: + MLIRGenImpl(mlir::MLIRContext &context) : builder(&context) {} + + /// Public API: convert the AST for a Toy module (source file) to an MLIR + /// Module operation. + mlir::ModuleOp mlirGen(ModuleAST &moduleAST) { + // We create an empty MLIR module and codegen functions one at a time and + // add them to the module. + theModule = mlir::ModuleOp::create(builder.getUnknownLoc()); + + for (FunctionAST &f : moduleAST) + mlirGen(f); + + // Verify the module after we have finished constructing it, this will check + // the structural properties of the IR and invoke any specific verifiers we + // have on the Toy operations. + if (failed(mlir::verify(theModule))) { + theModule.emitError("module verification error"); + return nullptr; + } + + return theModule; + } + +private: + /// A "module" matches a Toy source file: containing a list of functions. + mlir::ModuleOp theModule; + + /// The builder is a helper class to create IR inside a function. The builder + /// is stateful, in particular it keeps an "insertion point": this is where + /// the next operations will be introduced. + mlir::OpBuilder builder; + + /// The symbol table maps a variable name to a value in the current scope. + /// Entering a function creates a new scope, and the function arguments are + /// added to the mapping. When the processing of a function is terminated, the + /// scope is destroyed and the mappings created in this scope are dropped. + llvm::ScopedHashTable symbolTable; + + /// Helper conversion for a Toy AST location to an MLIR location. + mlir::Location loc(const Location &loc) { + return mlir::FileLineColLoc::get(builder.getStringAttr(*loc.file), loc.line, + loc.col); + } + + /// Declare a variable in the current scope, return success if the variable + /// wasn't declared yet. + mlir::LogicalResult declare(llvm::StringRef var, mlir::Value value) { + if (symbolTable.count(var)) + return mlir::failure(); + symbolTable.insert(var, value); + return mlir::success(); + } + + /// Create the prototype for an MLIR function with as many arguments as the + /// provided Toy AST prototype. + mlir::toy::FuncOp mlirGen(PrototypeAST &proto) { + auto location = loc(proto.loc()); + + // This is a generic function, the return type will be inferred later. + // Arguments type are uniformly unranked tensors. + llvm::SmallVector argTypes(proto.getArgs().size(), + getType(VarType{})); + auto funcType = builder.getFunctionType(argTypes, std::nullopt); + return builder.create(location, proto.getName(), + funcType); + } + + /// Emit a new function and add it to the MLIR module. + mlir::toy::FuncOp mlirGen(FunctionAST &funcAST) { + // Create a scope in the symbol table to hold variable declarations. + ScopedHashTableScope varScope(symbolTable); + + // Create an MLIR function for the given prototype. + builder.setInsertionPointToEnd(theModule.getBody()); + mlir::toy::FuncOp function = mlirGen(*funcAST.getProto()); + if (!function) + return nullptr; + + // Let's start the body of the function now! + mlir::Block &entryBlock = function.front(); + auto protoArgs = funcAST.getProto()->getArgs(); + + // Declare all the function arguments in the symbol table. + for (const auto nameValue : + llvm::zip(protoArgs, entryBlock.getArguments())) { + if (failed(declare(std::get<0>(nameValue)->getName(), + std::get<1>(nameValue)))) + return nullptr; + } + + // Set the insertion point in the builder to the beginning of the function + // body, it will be used throughout the codegen to create operations in this + // function. + builder.setInsertionPointToStart(&entryBlock); + + // Emit the body of the function. + if (mlir::failed(mlirGen(*funcAST.getBody()))) { + function.erase(); + return nullptr; + } + + // Implicitly return void if no return statement was emitted. + // FIXME: we may fix the parser instead to always return the last expression + // (this would possibly help the REPL case later) + ReturnOp returnOp; + if (!entryBlock.empty()) + returnOp = dyn_cast(entryBlock.back()); + if (!returnOp) { + builder.create(loc(funcAST.getProto()->loc())); + } else if (returnOp.hasOperand()) { + // Otherwise, if this return operation has an operand then add a result to + // the function. + function.setType(builder.getFunctionType( + function.getFunctionType().getInputs(), getType(VarType{}))); + } + + return function; + } + + /// Emit a binary operation + mlir::Value mlirGen(BinaryExprAST &binop) { + // First emit the operations for each side of the operation before emitting + // the operation itself. For example if the expression is `a + foo(a)` + // 1) First it will visiting the LHS, which will return a reference to the + // value holding `a`. This value should have been emitted at declaration + // time and registered in the symbol table, so nothing would be + // codegen'd. If the value is not in the symbol table, an error has been + // emitted and nullptr is returned. + // 2) Then the RHS is visited (recursively) and a call to `foo` is emitted + // and the result value is returned. If an error occurs we get a nullptr + // and propagate. + // + mlir::Value lhs = mlirGen(*binop.getLHS()); + if (!lhs) + return nullptr; + mlir::Value rhs = mlirGen(*binop.getRHS()); + if (!rhs) + return nullptr; + auto location = loc(binop.loc()); + + // Derive the operation name from the binary operator. At the moment we only + // support '+' and '*'. + switch (binop.getOp()) { + case '+': + return builder.create(location, lhs, rhs); + case '*': + return builder.create(location, lhs, rhs); + } + + emitError(location, "invalid binary operator '") << binop.getOp() << "'"; + return nullptr; + } + + /// This is a reference to a variable in an expression. The variable is + /// expected to have been declared and so should have a value in the symbol + /// table, otherwise emit an error and return nullptr. + mlir::Value mlirGen(VariableExprAST &expr) { + if (auto variable = symbolTable.lookup(expr.getName())) + return variable; + + emitError(loc(expr.loc()), "error: unknown variable '") + << expr.getName() << "'"; + return nullptr; + } + + /// Emit a return operation. This will return failure if any generation fails. + mlir::LogicalResult mlirGen(ReturnExprAST &ret) { + auto location = loc(ret.loc()); + + // 'return' takes an optional expression, handle that case here. + mlir::Value expr = nullptr; + if (ret.getExpr().has_value()) { + if (!(expr = mlirGen(**ret.getExpr()))) + return mlir::failure(); + } + + // Otherwise, this return operation has zero operands. + builder.create(location, + expr ? ArrayRef(expr) : ArrayRef()); + return mlir::success(); + } + + /// Emit a literal/constant array. It will be emitted as a flattened array of + /// data in an Attribute attached to a `toy.constant` operation. + /// See documentation on [Attributes](LangRef.md#attributes) for more details. + /// Here is an excerpt: + /// + /// Attributes are the mechanism for specifying constant data in MLIR in + /// places where a variable is never allowed [...]. They consist of a name + /// and a concrete attribute value. The set of expected attributes, their + /// structure, and their interpretation are all contextually dependent on + /// what they are attached to. + /// + /// Example, the source level statement: + /// var a<2, 3> = [[1, 2, 3], [4, 5, 6]]; + /// will be converted to: + /// %0 = "toy.constant"() {value: dense, + /// [[1.000000e+00, 2.000000e+00, 3.000000e+00], + /// [4.000000e+00, 5.000000e+00, 6.000000e+00]]>} : () -> tensor<2x3xf64> + /// + mlir::Value mlirGen(LiteralExprAST &lit) { + auto type = getType(lit.getDims()); + + // The attribute is a vector with a floating point value per element + // (number) in the array, see `collectData()` below for more details. + std::vector data; + data.reserve(std::accumulate(lit.getDims().begin(), lit.getDims().end(), 1, + std::multiplies())); + collectData(lit, data); + + // The type of this attribute is tensor of 64-bit floating-point with the + // shape of the literal. + mlir::Type elementType = builder.getF64Type(); + auto dataType = mlir::RankedTensorType::get(lit.getDims(), elementType); + + // This is the actual attribute that holds the list of values for this + // tensor literal. + auto dataAttribute = + mlir::DenseElementsAttr::get(dataType, llvm::ArrayRef(data)); + + // Build the MLIR op `toy.constant`. This invokes the `ConstantOp::build` + // method. + return builder.create(loc(lit.loc()), type, dataAttribute); + } + + /// Recursive helper function to accumulate the data that compose an array + /// literal. It flattens the nested structure in the supplied vector. For + /// example with this array: + /// [[1, 2], [3, 4]] + /// we will generate: + /// [ 1, 2, 3, 4 ] + /// Individual numbers are represented as doubles. + /// Attributes are the way MLIR attaches constant to operations. + void collectData(ExprAST &expr, std::vector &data) { + if (auto *lit = dyn_cast(&expr)) { + for (auto &value : lit->getValues()) + collectData(*value, data); + return; + } + + assert(isa(expr) && "expected literal or number expr"); + data.push_back(cast(expr).getValue()); + } + + /// Emit a call expression. It emits specific operations for the `transpose` + /// builtin. Other identifiers are assumed to be user-defined functions. + mlir::Value mlirGen(CallExprAST &call) { + llvm::StringRef callee = call.getCallee(); + auto location = loc(call.loc()); + + // Codegen the operands first. + SmallVector operands; + for (auto &expr : call.getArgs()) { + auto arg = mlirGen(*expr); + if (!arg) + return nullptr; + operands.push_back(arg); + } + + // Builtin calls have their custom operation, meaning this is a + // straightforward emission. + if (callee == "transpose") { + if (call.getArgs().size() != 1) { + emitError(location, "MLIR codegen encountered an error: toy.transpose " + "does not accept multiple arguments"); + return nullptr; + } + return builder.create(location, operands[0]); + } + + // Otherwise this is a call to a user-defined function. Calls to + // user-defined functions are mapped to a custom call that takes the callee + // name as an attribute. + return builder.create(location, callee, operands); + } + + /// Emit a print expression. It emits specific operations for two builtins: + /// transpose(x) and print(x). + mlir::LogicalResult mlirGen(PrintExprAST &call) { + auto arg = mlirGen(*call.getArg()); + if (!arg) + return mlir::failure(); + + builder.create(loc(call.loc()), arg); + return mlir::success(); + } + + /// Emit a constant for a single number (FIXME: semantic? broadcast?) + mlir::Value mlirGen(NumberExprAST &num) { + return builder.create(loc(num.loc()), num.getValue()); + } + + /// Dispatch codegen for the right expression subclass using RTTI. + mlir::Value mlirGen(ExprAST &expr) { + switch (expr.getKind()) { + case toy::ExprAST::Expr_BinOp: + return mlirGen(cast(expr)); + case toy::ExprAST::Expr_Var: + return mlirGen(cast(expr)); + case toy::ExprAST::Expr_Literal: + return mlirGen(cast(expr)); + case toy::ExprAST::Expr_Call: + return mlirGen(cast(expr)); + case toy::ExprAST::Expr_Num: + return mlirGen(cast(expr)); + default: + emitError(loc(expr.loc())) + << "MLIR codegen encountered an unhandled expr kind '" + << Twine(expr.getKind()) << "'"; + return nullptr; + } + } + + /// Handle a variable declaration, we'll codegen the expression that forms the + /// initializer and record the value in the symbol table before returning it. + /// Future expressions will be able to reference this variable through symbol + /// table lookup. + mlir::Value mlirGen(VarDeclExprAST &vardecl) { + auto *init = vardecl.getInitVal(); + if (!init) { + emitError(loc(vardecl.loc()), + "missing initializer in variable declaration"); + return nullptr; + } + + mlir::Value value = mlirGen(*init); + if (!value) + return nullptr; + + // We have the initializer value, but in case the variable was declared + // with specific shape, we emit a "reshape" operation. It will get + // optimized out later as needed. + if (!vardecl.getType().shape.empty()) { + value = builder.create(loc(vardecl.loc()), + getType(vardecl.getType()), value); + } + + // Register the value in the symbol table. + if (failed(declare(vardecl.getName(), value))) + return nullptr; + return value; + } + + /// Codegen a list of expression, return failure if one of them hit an error. + mlir::LogicalResult mlirGen(ExprASTList &blockAST) { + ScopedHashTableScope varScope(symbolTable); + for (auto &expr : blockAST) { + // Specific handling for variable declarations, return statement, and + // print. These can only appear in block list and not in nested + // expressions. + if (auto *vardecl = dyn_cast(expr.get())) { + if (!mlirGen(*vardecl)) + return mlir::failure(); + continue; + } + if (auto *ret = dyn_cast(expr.get())) + return mlirGen(*ret); + if (auto *print = dyn_cast(expr.get())) { + if (mlir::failed(mlirGen(*print))) + return mlir::success(); + continue; + } + + // Generic expression dispatch codegen. + if (!mlirGen(*expr)) + return mlir::failure(); + } + return mlir::success(); + } + + /// Build a tensor type from a list of shape dimensions. + mlir::Type getType(ArrayRef shape) { + // If the shape is empty, then this type is unranked. + if (shape.empty()) + return mlir::UnrankedTensorType::get(builder.getF64Type()); + + // Otherwise, we use the given shape. + return mlir::RankedTensorType::get(shape, builder.getF64Type()); + } + + /// Build an MLIR type from a Toy AST variable type (forward to the generic + /// getType above). + mlir::Type getType(const VarType &type) { return getType(type.shape); } +}; + +} // namespace + +namespace toy { + +// The public API for codegen. +mlir::OwningOpRef mlirGen(mlir::MLIRContext &context, + ModuleAST &moduleAST) { + return MLIRGenImpl(context).mlirGen(moduleAST); +} + +} // namespace toy diff --git a/Ch3/mlir/ToyCombine.cpp b/Ch3/mlir/ToyCombine.cpp new file mode 100644 index 0000000..3ce35c8 --- /dev/null +++ b/Ch3/mlir/ToyCombine.cpp @@ -0,0 +1,69 @@ +//===- ToyCombine.cpp - Toy High Level Optimizer --------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements a set of simple combiners for optimizing operations in +// the Toy dialect. +// +//===----------------------------------------------------------------------===// + +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/Value.h" +#include "mlir/Support/LogicalResult.h" +#include "toy/Dialect.h" +using namespace mlir; +using namespace toy; + +namespace { +/// Include the patterns defined in the Declarative Rewrite framework. +#include "ToyCombine.inc" +} // namespace + +/// This is an example of a c++ rewrite pattern for the TransposeOp. It +/// optimizes the following scenario: transpose(transpose(x)) -> x +struct SimplifyRedundantTranspose : public mlir::OpRewritePattern { + /// We register this pattern to match every toy.transpose in the IR. + /// The "benefit" is used by the framework to order the patterns and process + /// them in order of profitability. + SimplifyRedundantTranspose(mlir::MLIRContext *context) + : OpRewritePattern(context, /*benefit=*/1) {} + + /// This method attempts to match a pattern and rewrite it. The rewriter + /// argument is the orchestrator of the sequence of rewrites. The pattern is + /// expected to interact with it to perform any changes to the IR from here. + mlir::LogicalResult + matchAndRewrite(TransposeOp op, + mlir::PatternRewriter &rewriter) const override { + // Look through the input of the current transpose. + mlir::Value transposeInput = op.getOperand(); + TransposeOp transposeInputOp = transposeInput.getDefiningOp(); + + // Input defined by another transpose? If not, no match. + if (!transposeInputOp) + return failure(); + + // Otherwise, we have a redundant transpose. Use the rewriter. + rewriter.replaceOp(op, {transposeInputOp.getOperand()}); + return success(); + } +}; + +/// Register our patterns as "canonicalization" patterns on the TransposeOp so +/// that they can be picked up by the Canonicalization framework. +void TransposeOp::getCanonicalizationPatterns(RewritePatternSet &results, + MLIRContext *context) { + results.add(context); +} + +/// Register our patterns as "canonicalization" patterns on the ReshapeOp so +/// that they can be picked up by the Canonicalization framework. +void ReshapeOp::getCanonicalizationPatterns(RewritePatternSet &results, + MLIRContext *context) { + results.add(context); +} diff --git a/Ch3/mlir/ToyCombine.inc b/Ch3/mlir/ToyCombine.inc new file mode 100644 index 0000000..33ec8e1 --- /dev/null +++ b/Ch3/mlir/ToyCombine.inc @@ -0,0 +1,176 @@ +/*===- TableGen'erated file -------------------------------------*- C++ -*-===*\ +|* *| +|* Rewriters *| +|* *| +|* Automatically generated file, do not edit! *| +|* From: ToyCombine.td *| +|* *| +\*===----------------------------------------------------------------------===*/ + +/* Generated from: + ToyCombine.td:47 +*/ +struct FoldConstantReshapeOptPattern : public ::mlir::RewritePattern { + FoldConstantReshapeOptPattern(::mlir::MLIRContext *context) + : ::mlir::RewritePattern("toy.reshape", 2, context, {"toy.constant"}) {} + ::mlir::LogicalResult matchAndRewrite(::mlir::Operation *op0, + ::mlir::PatternRewriter &rewriter) const override { + // Variables for capturing values and attributes used while creating ops + ::mlir::DenseElementsAttr arg; + ::mlir::toy::ReshapeOp res; + ::llvm::SmallVector<::mlir::Operation *, 4> tblgen_ops; + + // Match + tblgen_ops.push_back(op0); + auto castedOp0 = ::llvm::dyn_cast<::mlir::toy::ReshapeOp>(op0); (void)castedOp0; + res = castedOp0; + { + auto *op1 = (*castedOp0.getODSOperands(0).begin()).getDefiningOp(); + if (!(op1)){ + return rewriter.notifyMatchFailure(castedOp0, [&](::mlir::Diagnostic &diag) { + diag << "There's no operation that defines operand 0 of castedOp0"; + }); + } + auto castedOp1 = ::llvm::dyn_cast<::mlir::toy::ConstantOp>(op1); (void)castedOp1; + if (!(castedOp1)){ + return rewriter.notifyMatchFailure(op1, [&](::mlir::Diagnostic &diag) { + diag << "castedOp1 is not ::mlir::toy::ConstantOp type"; + }); + } + { + auto tblgen_attr = op1->getAttrOfType<::mlir::DenseElementsAttr>("value");(void)tblgen_attr; + if (!(tblgen_attr)){ + return rewriter.notifyMatchFailure(op1, [&](::mlir::Diagnostic &diag) { + diag << "expected op 'toy.constant' to have attribute 'value' of type '::mlir::DenseElementsAttr'"; + }); + } + arg = tblgen_attr; + } + tblgen_ops.push_back(op1); + } + + // Rewrite + auto odsLoc = rewriter.getFusedLoc({tblgen_ops[0]->getLoc(), tblgen_ops[1]->getLoc()}); (void)odsLoc; + ::llvm::SmallVector<::mlir::Value, 4> tblgen_repl_values; + auto nativeVar_0 = arg.reshape(::llvm::cast((*res.getODSResults(0).begin()).getType())); (void)nativeVar_0; + ::mlir::toy::ConstantOp tblgen_ConstantOp_1; + { + ::llvm::SmallVector<::mlir::Value, 4> tblgen_values; (void)tblgen_values; + ::llvm::SmallVector<::mlir::NamedAttribute, 4> tblgen_attrs; (void)tblgen_attrs; + if (auto tmpAttr = nativeVar_0) { + tblgen_attrs.emplace_back(rewriter.getStringAttr("value"), tmpAttr); + } + ::llvm::SmallVector<::mlir::Type, 4> tblgen_types; (void)tblgen_types; + for (auto v: castedOp0.getODSResults(0)) { + tblgen_types.push_back(v.getType()); + } + tblgen_ConstantOp_1 = rewriter.create<::mlir::toy::ConstantOp>(odsLoc, tblgen_types, tblgen_values, tblgen_attrs); + } + + for (auto v: ::llvm::SmallVector<::mlir::Value, 4>{ tblgen_ConstantOp_1.getODSResults(0) }) { + tblgen_repl_values.push_back(v); + } + + rewriter.replaceOp(op0, tblgen_repl_values); + return ::mlir::success(); + }; +}; + +/* Generated from: + ToyCombine.td:60 +*/ +struct RedundantReshapeOptPattern : public ::mlir::RewritePattern { + RedundantReshapeOptPattern(::mlir::MLIRContext *context) + : ::mlir::RewritePattern("toy.reshape", 1, context, {}) {} + ::mlir::LogicalResult matchAndRewrite(::mlir::Operation *op0, + ::mlir::PatternRewriter &rewriter) const override { + // Variables for capturing values and attributes used while creating ops + ::mlir::Operation::operand_range arg(op0->getOperands()); + ::mlir::toy::ReshapeOp res; + ::llvm::SmallVector<::mlir::Operation *, 4> tblgen_ops; + + // Match + tblgen_ops.push_back(op0); + auto castedOp0 = ::llvm::dyn_cast<::mlir::toy::ReshapeOp>(op0); (void)castedOp0; + res = castedOp0; + arg = castedOp0.getODSOperands(0); + if (!(((*res.getODSResults(0).begin()).getType() == (*arg.begin()).getType()))){ + return rewriter.notifyMatchFailure(op0, [&](::mlir::Diagnostic &diag) { + diag << "entities 'res, arg' failed to satisfy constraint: ''"; + }); + } + + // Rewrite + auto odsLoc = rewriter.getFusedLoc({tblgen_ops[0]->getLoc()}); (void)odsLoc; + ::llvm::SmallVector<::mlir::Value, 4> tblgen_repl_values; + + for (auto v: ::llvm::SmallVector<::mlir::Value, 4>{ arg }) { + tblgen_repl_values.push_back(v); + } + + rewriter.replaceOp(op0, tblgen_repl_values); + return ::mlir::success(); + }; +}; + +/* Generated from: + ToyCombine.td:34 +*/ +struct ReshapeReshapeOptPattern : public ::mlir::RewritePattern { + ReshapeReshapeOptPattern(::mlir::MLIRContext *context) + : ::mlir::RewritePattern("toy.reshape", 2, context, {"toy.reshape"}) {} + ::mlir::LogicalResult matchAndRewrite(::mlir::Operation *op0, + ::mlir::PatternRewriter &rewriter) const override { + // Variables for capturing values and attributes used while creating ops + ::mlir::Operation::operand_range arg(op0->getOperands()); + ::llvm::SmallVector<::mlir::Operation *, 4> tblgen_ops; + + // Match + tblgen_ops.push_back(op0); + auto castedOp0 = ::llvm::dyn_cast<::mlir::toy::ReshapeOp>(op0); (void)castedOp0; + { + auto *op1 = (*castedOp0.getODSOperands(0).begin()).getDefiningOp(); + if (!(op1)){ + return rewriter.notifyMatchFailure(castedOp0, [&](::mlir::Diagnostic &diag) { + diag << "There's no operation that defines operand 0 of castedOp0"; + }); + } + auto castedOp1 = ::llvm::dyn_cast<::mlir::toy::ReshapeOp>(op1); (void)castedOp1; + if (!(castedOp1)){ + return rewriter.notifyMatchFailure(op1, [&](::mlir::Diagnostic &diag) { + diag << "castedOp1 is not ::mlir::toy::ReshapeOp type"; + }); + } + arg = castedOp1.getODSOperands(0); + tblgen_ops.push_back(op1); + } + + // Rewrite + auto odsLoc = rewriter.getFusedLoc({tblgen_ops[0]->getLoc(), tblgen_ops[1]->getLoc()}); (void)odsLoc; + ::llvm::SmallVector<::mlir::Value, 4> tblgen_repl_values; + ::mlir::toy::ReshapeOp tblgen_ReshapeOp_0; + { + ::llvm::SmallVector<::mlir::Value, 4> tblgen_values; (void)tblgen_values; + ::llvm::SmallVector<::mlir::NamedAttribute, 4> tblgen_attrs; (void)tblgen_attrs; + tblgen_values.push_back((*arg.begin())); + ::llvm::SmallVector<::mlir::Type, 4> tblgen_types; (void)tblgen_types; + for (auto v: castedOp0.getODSResults(0)) { + tblgen_types.push_back(v.getType()); + } + tblgen_ReshapeOp_0 = rewriter.create<::mlir::toy::ReshapeOp>(odsLoc, tblgen_types, tblgen_values, tblgen_attrs); + } + + for (auto v: ::llvm::SmallVector<::mlir::Value, 4>{ tblgen_ReshapeOp_0.getODSResults(0) }) { + tblgen_repl_values.push_back(v); + } + + rewriter.replaceOp(op0, tblgen_repl_values); + return ::mlir::success(); + }; +}; + +void LLVM_ATTRIBUTE_UNUSED populateWithGenerated(::mlir::RewritePatternSet &patterns) { + patterns.add(patterns.getContext()); + patterns.add(patterns.getContext()); + patterns.add(patterns.getContext()); +} diff --git a/Ch3/mlir/ToyCombine.td b/Ch3/mlir/ToyCombine.td new file mode 100644 index 0000000..8bd2b44 --- /dev/null +++ b/Ch3/mlir/ToyCombine.td @@ -0,0 +1,64 @@ +//===- ToyCombine.td - Pattern Match Optimizations for Toy -*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Defines language-specific pattern match optimizations for Toy using +// Declarative Rewrite Rules (DRR) specified using TableGen records. +// +//===----------------------------------------------------------------------===// + +#ifndef TOY_COMBINE +#define TOY_COMBINE + +include "mlir/IR/PatternBase.td" +include "toy/Ops.td" + +/// Note: The DRR definition used for defining patterns is shown below: +/// +/// class Pattern< +/// dag sourcePattern, list resultPatterns, +/// list additionalConstraints = [], +// list supplementalPatterns = [], +/// dag benefitsAdded = (addBenefit 0) +/// >; + +//===----------------------------------------------------------------------===// +// Basic Pattern-Match and Rewrite +//===----------------------------------------------------------------------===// + +// Reshape(Reshape(x)) = Reshape(x) +def ReshapeReshapeOptPattern : Pat<(ReshapeOp(ReshapeOp $arg)), + (ReshapeOp $arg)>; + +//===----------------------------------------------------------------------===// +// Pattern-Match and Rewrite using Native Code Call +//===----------------------------------------------------------------------===// + +// Native Code Calls may be used for more complex transformations using inline +// C++ and C++ helper functions. + +// Reshape(Constant(x)) = x' +def ReshapeConstant : + NativeCodeCall<"$0.reshape(::llvm::cast($1.getType()))">; +def FoldConstantReshapeOptPattern : Pat< + (ReshapeOp:$res (ConstantOp $arg)), + (ConstantOp (ReshapeConstant $arg, $res))>; + +//===----------------------------------------------------------------------===// +// Pattern-Match and Rewrite with Constraints +//===----------------------------------------------------------------------===// + +// DRR allows for constraint checking when the transformation is conditional +// on operand properties. + +// Reshape(x) = x, where input and output shapes are identical +def TypesAreIdentical : Constraint>; +def RedundantReshapeOptPattern : Pat< + (ReshapeOp:$res $arg), (replaceWithValue $arg), + [(TypesAreIdentical $res, $arg)]>; + +#endif // TOY_COMBINE diff --git a/Ch3/mlir/run.sh b/Ch3/mlir/run.sh new file mode 100644 index 0000000..f592fde --- /dev/null +++ b/Ch3/mlir/run.sh @@ -0,0 +1,2 @@ +mlir-tblgen-18 -gen-rewriters -I /usr/lib/llvm-18/include -I ../include ToyCombine.td > ToyCombine.inc + diff --git a/Ch3/parser/AST.cpp b/Ch3/parser/AST.cpp new file mode 100644 index 0000000..2546f2a --- /dev/null +++ b/Ch3/parser/AST.cpp @@ -0,0 +1,237 @@ +//===- AST.cpp - Helper for printing out the Toy AST ----------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements the AST dump for the Toy language. +// +//===----------------------------------------------------------------------===// + +#include "toy/AST.h" + +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/Twine.h" +#include "llvm/ADT/TypeSwitch.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/raw_ostream.h" +#include + +using namespace toy; + +namespace { + +// RAII helper to manage increasing/decreasing the indentation as we traverse +// the AST +struct Indent { + Indent(int &level) : level(level) { ++level; } + ~Indent() { --level; } + int &level; +}; + +/// Helper class that implement the AST tree traversal and print the nodes along +/// the way. The only data member is the current indentation level. +class ASTDumper { +public: + void dump(ModuleAST *node); + +private: + void dump(const VarType &type); + void dump(VarDeclExprAST *varDecl); + void dump(ExprAST *expr); + void dump(ExprASTList *exprList); + void dump(NumberExprAST *num); + void dump(LiteralExprAST *node); + void dump(VariableExprAST *node); + void dump(ReturnExprAST *node); + void dump(BinaryExprAST *node); + void dump(CallExprAST *node); + void dump(PrintExprAST *node); + void dump(PrototypeAST *node); + void dump(FunctionAST *node); + + // Actually print spaces matching the current indentation level + void indent() { + for (int i = 0; i < curIndent; i++) + llvm::errs() << " "; + } + int curIndent = 0; +}; + +} // namespace + +/// Return a formatted string for the location of any node +template +static std::string loc(T *node) { + const auto &loc = node->loc(); + return (llvm::Twine("@") + *loc.file + ":" + llvm::Twine(loc.line) + ":" + + llvm::Twine(loc.col)) + .str(); +} + +// Helper Macro to bump the indentation level and print the leading spaces for +// the current indentations +#define INDENT() \ + Indent level_(curIndent); \ + indent(); + +/// Dispatch to a generic expressions to the appropriate subclass using RTTI +void ASTDumper::dump(ExprAST *expr) { + llvm::TypeSwitch(expr) + .Case( + [&](auto *node) { this->dump(node); }) + .Default([&](ExprAST *) { + // No match, fallback to a generic message + INDENT(); + llvm::errs() << "getKind() << ">\n"; + }); +} + +/// A variable declaration is printing the variable name, the type, and then +/// recurse in the initializer value. +void ASTDumper::dump(VarDeclExprAST *varDecl) { + INDENT(); + llvm::errs() << "VarDecl " << varDecl->getName(); + dump(varDecl->getType()); + llvm::errs() << " " << loc(varDecl) << "\n"; + dump(varDecl->getInitVal()); +} + +/// A "block", or a list of expression +void ASTDumper::dump(ExprASTList *exprList) { + INDENT(); + llvm::errs() << "Block {\n"; + for (auto &expr : *exprList) + dump(expr.get()); + indent(); + llvm::errs() << "} // Block\n"; +} + +/// A literal number, just print the value. +void ASTDumper::dump(NumberExprAST *num) { + INDENT(); + llvm::errs() << num->getValue() << " " << loc(num) << "\n"; +} + +/// Helper to print recursively a literal. This handles nested array like: +/// [ [ 1, 2 ], [ 3, 4 ] ] +/// We print out such array with the dimensions spelled out at every level: +/// <2,2>[<2>[ 1, 2 ], <2>[ 3, 4 ] ] +void printLitHelper(ExprAST *litOrNum) { + // Inside a literal expression we can have either a number or another literal + if (auto *num = llvm::dyn_cast(litOrNum)) { + llvm::errs() << num->getValue(); + return; + } + auto *literal = llvm::cast(litOrNum); + + // Print the dimension for this literal first + llvm::errs() << "<"; + llvm::interleaveComma(literal->getDims(), llvm::errs()); + llvm::errs() << ">"; + + // Now print the content, recursing on every element of the list + llvm::errs() << "[ "; + llvm::interleaveComma(literal->getValues(), llvm::errs(), + [&](auto &elt) { printLitHelper(elt.get()); }); + llvm::errs() << "]"; +} + +/// Print a literal, see the recursive helper above for the implementation. +void ASTDumper::dump(LiteralExprAST *node) { + INDENT(); + llvm::errs() << "Literal: "; + printLitHelper(node); + llvm::errs() << " " << loc(node) << "\n"; +} + +/// Print a variable reference (just a name). +void ASTDumper::dump(VariableExprAST *node) { + INDENT(); + llvm::errs() << "var: " << node->getName() << " " << loc(node) << "\n"; +} + +/// Return statement print the return and its (optional) argument. +void ASTDumper::dump(ReturnExprAST *node) { + INDENT(); + llvm::errs() << "Return\n"; + if (node->getExpr().has_value()) + return dump(*node->getExpr()); + { + INDENT(); + llvm::errs() << "(void)\n"; + } +} + +/// Print a binary operation, first the operator, then recurse into LHS and RHS. +void ASTDumper::dump(BinaryExprAST *node) { + INDENT(); + llvm::errs() << "BinOp: " << node->getOp() << " " << loc(node) << "\n"; + dump(node->getLHS()); + dump(node->getRHS()); +} + +/// Print a call expression, first the callee name and the list of args by +/// recursing into each individual argument. +void ASTDumper::dump(CallExprAST *node) { + INDENT(); + llvm::errs() << "Call '" << node->getCallee() << "' [ " << loc(node) << "\n"; + for (auto &arg : node->getArgs()) + dump(arg.get()); + indent(); + llvm::errs() << "]\n"; +} + +/// Print a builtin print call, first the builtin name and then the argument. +void ASTDumper::dump(PrintExprAST *node) { + INDENT(); + llvm::errs() << "Print [ " << loc(node) << "\n"; + dump(node->getArg()); + indent(); + llvm::errs() << "]\n"; +} + +/// Print type: only the shape is printed in between '<' and '>' +void ASTDumper::dump(const VarType &type) { + llvm::errs() << "<"; + llvm::interleaveComma(type.shape, llvm::errs()); + llvm::errs() << ">"; +} + +/// Print a function prototype, first the function name, and then the list of +/// parameters names. +void ASTDumper::dump(PrototypeAST *node) { + INDENT(); + llvm::errs() << "Proto '" << node->getName() << "' " << loc(node) << "\n"; + indent(); + llvm::errs() << "Params: ["; + llvm::interleaveComma(node->getArgs(), llvm::errs(), + [](auto &arg) { llvm::errs() << arg->getName(); }); + llvm::errs() << "]\n"; +} + +/// Print a function, first the prototype and then the body. +void ASTDumper::dump(FunctionAST *node) { + INDENT(); + llvm::errs() << "Function \n"; + dump(node->getProto()); + dump(node->getBody()); +} + +/// Print a module, actually loop over the functions and print them in sequence. +void ASTDumper::dump(ModuleAST *node) { + INDENT(); + llvm::errs() << "Module:\n"; + for (auto &f : *node) + dump(&f); +} + +namespace toy { + +// Public API +void dump(ModuleAST &module) { ASTDumper().dump(&module); } + +} // namespace toy diff --git a/Ch3/toyc.cpp b/Ch3/toyc.cpp new file mode 100644 index 0000000..c2c5f1f --- /dev/null +++ b/Ch3/toyc.cpp @@ -0,0 +1,170 @@ +//===- toyc.cpp - The Toy Compiler ----------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements the entry point for the Toy compiler. +// +//===----------------------------------------------------------------------===// + +#include "mlir/IR/Diagnostics.h" +#include "mlir/Support/LogicalResult.h" +#include "toy/AST.h" +#include "toy/Dialect.h" +#include "toy/Lexer.h" +#include "toy/MLIRGen.h" +#include "toy/Parser.h" + +#include "mlir/IR/AsmState.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/Verifier.h" +#include "mlir/Parser/Parser.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Transforms/Passes.h" + +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/ErrorOr.h" +#include "llvm/Support/MemoryBuffer.h" +#include "llvm/Support/SourceMgr.h" +#include "llvm/Support/raw_ostream.h" +#include +#include +#include +#include + +using namespace toy; +namespace cl = llvm::cl; + +static cl::opt inputFilename(cl::Positional, + cl::desc(""), + cl::init("-"), + cl::value_desc("filename")); + +namespace { +enum InputType { Toy, MLIR }; +} // namespace +static cl::opt inputType( + "x", cl::init(Toy), cl::desc("Decided the kind of output desired"), + cl::values(clEnumValN(Toy, "toy", "load the input file as a Toy source.")), + cl::values(clEnumValN(MLIR, "mlir", + "load the input file as an MLIR file"))); + +namespace { +enum Action { None, DumpAST, DumpMLIR }; +} // namespace +static cl::opt emitAction( + "emit", cl::desc("Select the kind of output desired"), + cl::values(clEnumValN(DumpAST, "ast", "output the AST dump")), + cl::values(clEnumValN(DumpMLIR, "mlir", "output the MLIR dump"))); + +static cl::opt enableOpt("opt", cl::desc("Enable optimizations")); + +/// Returns a Toy AST resulting from parsing the file or a nullptr on error. +std::unique_ptr parseInputFile(llvm::StringRef filename) { + llvm::ErrorOr> fileOrErr = + llvm::MemoryBuffer::getFileOrSTDIN(filename); + if (std::error_code ec = fileOrErr.getError()) { + llvm::errs() << "Could not open input file: " << ec.message() << "\n"; + return nullptr; + } + auto buffer = fileOrErr.get()->getBuffer(); + LexerBuffer lexer(buffer.begin(), buffer.end(), std::string(filename)); + Parser parser(lexer); + return parser.parseModule(); +} + +int loadMLIR(llvm::SourceMgr &sourceMgr, mlir::MLIRContext &context, + mlir::OwningOpRef &module) { + // Handle '.toy' input to the compiler. + if (inputType != InputType::MLIR && + !llvm::StringRef(inputFilename).ends_with(".mlir")) { + auto moduleAST = parseInputFile(inputFilename); + if (!moduleAST) + return 6; + module = mlirGen(context, *moduleAST); + return !module ? 1 : 0; + } + + // Otherwise, the input is '.mlir'. + llvm::ErrorOr> fileOrErr = + llvm::MemoryBuffer::getFileOrSTDIN(inputFilename); + if (std::error_code ec = fileOrErr.getError()) { + llvm::errs() << "Could not open input file: " << ec.message() << "\n"; + return -1; + } + + // Parse the input mlir. + sourceMgr.AddNewSourceBuffer(std::move(*fileOrErr), llvm::SMLoc()); + module = mlir::parseSourceFile(sourceMgr, &context); + if (!module) { + llvm::errs() << "Error can't load file " << inputFilename << "\n"; + return 3; + } + return 0; +} + +int dumpMLIR() { + mlir::MLIRContext context; + // Load our Dialect in this MLIR Context. + context.getOrLoadDialect(); + + mlir::OwningOpRef module; + llvm::SourceMgr sourceMgr; + mlir::SourceMgrDiagnosticHandler sourceMgrHandler(sourceMgr, &context); + if (int error = loadMLIR(sourceMgr, context, module)) + return error; + + if (enableOpt) { + mlir::PassManager pm(module.get()->getName()); + // Apply any generic pass manager command line options and run the pipeline. + if (mlir::failed(mlir::applyPassManagerCLOptions(pm))) + return 4; + + // Add a run of the canonicalizer to optimize the mlir module. + pm.addNestedPass(mlir::createCanonicalizerPass()); + if (mlir::failed(pm.run(*module))) + return 4; + } + + module->dump(); + return 0; +} + +int dumpAST() { + if (inputType == InputType::MLIR) { + llvm::errs() << "Can't dump a Toy AST when the input is MLIR\n"; + return 5; + } + + auto moduleAST = parseInputFile(inputFilename); + if (!moduleAST) + return 1; + + dump(*moduleAST); + return 0; +} + +int main(int argc, char **argv) { + // Register any command line options. + mlir::registerAsmPrinterCLOptions(); + mlir::registerMLIRContextCLOptions(); + mlir::registerPassManagerCLOptions(); + + cl::ParseCommandLineOptions(argc, argv, "toy compiler\n"); + + switch (emitAction) { + case Action::DumpAST: + return dumpAST(); + case Action::DumpMLIR: + return dumpMLIR(); + default: + llvm::errs() << "No action specified (parsing only?), use -emit=\n"; + } + + return 0; +} diff --git a/Ch4/CMakeLists.txt b/Ch4/CMakeLists.txt new file mode 100644 index 0000000..0352b99 --- /dev/null +++ b/Ch4/CMakeLists.txt @@ -0,0 +1,40 @@ +# For a better template to copy, see examples/standalone +include_directories(include) +add_subdirectory(include) + +set(LLVM_LINK_COMPONENTS + Support + ) + +# set(LLVM_TARGET_DEFINITIONS mlir/ToyCombine.td) +# mlir_tablegen(ToyCombine.inc -gen-rewriters) +# add_public_tablegen_target(ToyCh4CombineIncGen) + +add_executable(toyc-ch4 + toyc.cpp + parser/AST.cpp + mlir/MLIRGen.cpp + mlir/Dialect.cpp + mlir/ShapeInferencePass.cpp + mlir/ToyCombine.cpp + + # DEPENDS + # ToyCh4OpsIncGen + # ToyCh4ShapeInferenceInterfaceIncGen + # ToyCh4CombineIncGen + ) + +include_directories(${CMAKE_CURRENT_BINARY_DIR}) +include_directories(${CMAKE_CURRENT_BINARY_DIR}/include/) +target_link_libraries(toyc-ch4 + PRIVATE + MLIRAnalysis + MLIRCastInterfaces + MLIRCallInterfaces + MLIRFunctionInterfaces + MLIRIR + MLIRParser + MLIRPass + MLIRSideEffectInterfaces + MLIRTransforms) + diff --git a/Ch4/include/CMakeLists.txt b/Ch4/include/CMakeLists.txt new file mode 100644 index 0000000..37c89d0 --- /dev/null +++ b/Ch4/include/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(toy) diff --git a/Ch4/include/run.sh b/Ch4/include/run.sh new file mode 100644 index 0000000..b9d18af --- /dev/null +++ b/Ch4/include/run.sh @@ -0,0 +1,7 @@ +mlir-tblgen-18 -gen-op-decls -I /usr/lib/llvm-18/include toy/Ops.td > toy/Ops.h.inc +mlir-tblgen-18 -gen-op-defs -I /usr/lib/llvm-18/include toy/Ops.td > toy/Ops.cpp.inc +mlir-tblgen-18 -gen-dialect-decls -I /usr/lib/llvm-18/include toy/Ops.td > toy/Dialect.h.inc +mlir-tblgen-18 -gen-dialect-defs -I /usr/lib/llvm-18/include toy/Ops.td > toy/Dialect.cpp.inc + +mlir-tblgen-18 -gen-op-interface-decls -I /usr/lib/llvm-18/include toy/ShapeInferenceInterface.td > toy/ShapeInferenceOpInterfaces.h.inc +mlir-tblgen-18 -gen-op-interface-defs -I /usr/lib/llvm-18/include toy/ShapeInferenceInterface.td > toy/ShapeInferenceOpInterfaces.cpp.inc diff --git a/Ch4/include/toy/AST.h b/Ch4/include/toy/AST.h new file mode 100644 index 0000000..d2ba101 --- /dev/null +++ b/Ch4/include/toy/AST.h @@ -0,0 +1,246 @@ +//===- AST.h - Node definition for the Toy AST ----------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements the AST for the Toy language. It is optimized for +// simplicity, not efficiency. The AST forms a tree structure where each node +// references its children using std::unique_ptr<>. +// +//===----------------------------------------------------------------------===// + +#ifndef TOY_AST_H +#define TOY_AST_H + +#include "toy/Lexer.h" + +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/Casting.h" +#include +#include +#include + +namespace toy { + +/// A variable type with shape information. +struct VarType { + std::vector shape; +}; + +/// Base class for all expression nodes. +class ExprAST { +public: + enum ExprASTKind { + Expr_VarDecl, + Expr_Return, + Expr_Num, + Expr_Literal, + Expr_Var, + Expr_BinOp, + Expr_Call, + Expr_Print, + }; + + ExprAST(ExprASTKind kind, Location location) + : kind(kind), location(std::move(location)) {} + virtual ~ExprAST() = default; + + ExprASTKind getKind() const { return kind; } + + const Location &loc() { return location; } + +private: + const ExprASTKind kind; + Location location; +}; + +/// A block-list of expressions. +using ExprASTList = std::vector>; + +/// Expression class for numeric literals like "1.0". +class NumberExprAST : public ExprAST { + double val; + +public: + NumberExprAST(Location loc, double val) + : ExprAST(Expr_Num, std::move(loc)), val(val) {} + + double getValue() { return val; } + + /// LLVM style RTTI + static bool classof(const ExprAST *c) { return c->getKind() == Expr_Num; } +}; + +/// Expression class for a literal value. +class LiteralExprAST : public ExprAST { + std::vector> values; + std::vector dims; + +public: + LiteralExprAST(Location loc, std::vector> values, + std::vector dims) + : ExprAST(Expr_Literal, std::move(loc)), values(std::move(values)), + dims(std::move(dims)) {} + + llvm::ArrayRef> getValues() { return values; } + llvm::ArrayRef getDims() { return dims; } + + /// LLVM style RTTI + static bool classof(const ExprAST *c) { return c->getKind() == Expr_Literal; } +}; + +/// Expression class for referencing a variable, like "a". +class VariableExprAST : public ExprAST { + std::string name; + +public: + VariableExprAST(Location loc, llvm::StringRef name) + : ExprAST(Expr_Var, std::move(loc)), name(name) {} + + llvm::StringRef getName() { return name; } + + /// LLVM style RTTI + static bool classof(const ExprAST *c) { return c->getKind() == Expr_Var; } +}; + +/// Expression class for defining a variable. +class VarDeclExprAST : public ExprAST { + std::string name; + VarType type; + std::unique_ptr initVal; + +public: + VarDeclExprAST(Location loc, llvm::StringRef name, VarType type, + std::unique_ptr initVal) + : ExprAST(Expr_VarDecl, std::move(loc)), name(name), + type(std::move(type)), initVal(std::move(initVal)) {} + + llvm::StringRef getName() { return name; } + ExprAST *getInitVal() { return initVal.get(); } + const VarType &getType() { return type; } + + /// LLVM style RTTI + static bool classof(const ExprAST *c) { return c->getKind() == Expr_VarDecl; } +}; + +/// Expression class for a return operator. +class ReturnExprAST : public ExprAST { + std::optional> expr; + +public: + ReturnExprAST(Location loc, std::optional> expr) + : ExprAST(Expr_Return, std::move(loc)), expr(std::move(expr)) {} + + std::optional getExpr() { + if (expr.has_value()) + return expr->get(); + return std::nullopt; + } + + /// LLVM style RTTI + static bool classof(const ExprAST *c) { return c->getKind() == Expr_Return; } +}; + +/// Expression class for a binary operator. +class BinaryExprAST : public ExprAST { + char op; + std::unique_ptr lhs, rhs; + +public: + char getOp() { return op; } + ExprAST *getLHS() { return lhs.get(); } + ExprAST *getRHS() { return rhs.get(); } + + BinaryExprAST(Location loc, char op, std::unique_ptr lhs, + std::unique_ptr rhs) + : ExprAST(Expr_BinOp, std::move(loc)), op(op), lhs(std::move(lhs)), + rhs(std::move(rhs)) {} + + /// LLVM style RTTI + static bool classof(const ExprAST *c) { return c->getKind() == Expr_BinOp; } +}; + +/// Expression class for function calls. +class CallExprAST : public ExprAST { + std::string callee; + std::vector> args; + +public: + CallExprAST(Location loc, const std::string &callee, + std::vector> args) + : ExprAST(Expr_Call, std::move(loc)), callee(callee), + args(std::move(args)) {} + + llvm::StringRef getCallee() { return callee; } + llvm::ArrayRef> getArgs() { return args; } + + /// LLVM style RTTI + static bool classof(const ExprAST *c) { return c->getKind() == Expr_Call; } +}; + +/// Expression class for builtin print calls. +class PrintExprAST : public ExprAST { + std::unique_ptr arg; + +public: + PrintExprAST(Location loc, std::unique_ptr arg) + : ExprAST(Expr_Print, std::move(loc)), arg(std::move(arg)) {} + + ExprAST *getArg() { return arg.get(); } + + /// LLVM style RTTI + static bool classof(const ExprAST *c) { return c->getKind() == Expr_Print; } +}; + +/// This class represents the "prototype" for a function, which captures its +/// name, and its argument names (thus implicitly the number of arguments the +/// function takes). +class PrototypeAST { + Location location; + std::string name; + std::vector> args; + +public: + PrototypeAST(Location location, const std::string &name, + std::vector> args) + : location(std::move(location)), name(name), args(std::move(args)) {} + + const Location &loc() { return location; } + llvm::StringRef getName() const { return name; } + llvm::ArrayRef> getArgs() { return args; } +}; + +/// This class represents a function definition itself. +class FunctionAST { + std::unique_ptr proto; + std::unique_ptr body; + +public: + FunctionAST(std::unique_ptr proto, + std::unique_ptr body) + : proto(std::move(proto)), body(std::move(body)) {} + PrototypeAST *getProto() { return proto.get(); } + ExprASTList *getBody() { return body.get(); } +}; + +/// This class represents a list of functions to be processed together +class ModuleAST { + std::vector functions; + +public: + ModuleAST(std::vector functions) + : functions(std::move(functions)) {} + + auto begin() { return functions.begin(); } + auto end() { return functions.end(); } +}; + +void dump(ModuleAST &); + +} // namespace toy + +#endif // TOY_AST_H diff --git a/Ch4/include/toy/CMakeLists.txt b/Ch4/include/toy/CMakeLists.txt new file mode 100644 index 0000000..79a1f71 --- /dev/null +++ b/Ch4/include/toy/CMakeLists.txt @@ -0,0 +1,13 @@ +# Most dialects should use add_mlir_dialect(). See examples/standalone. +# set(LLVM_TARGET_DEFINITIONS Ops.td) +# mlir_tablegen(Ops.h.inc -gen-op-decls) +# mlir_tablegen(Ops.cpp.inc -gen-op-defs) +# mlir_tablegen(Dialect.h.inc -gen-dialect-decls) +# mlir_tablegen(Dialect.cpp.inc -gen-dialect-defs) +# add_public_tablegen_target(ToyCh4OpsIncGen) + +# Most dialects should use add_mlir_interfaces(). +# set(LLVM_TARGET_DEFINITIONS ShapeInferenceInterface.td) +# mlir_tablegen(ShapeInferenceOpInterfaces.h.inc -gen-op-interface-decls) +# mlir_tablegen(ShapeInferenceOpInterfaces.cpp.inc -gen-op-interface-defs) +# add_public_tablegen_target(ToyCh4ShapeInferenceInterfaceIncGen) diff --git a/Ch4/include/toy/Dialect.cpp.inc b/Ch4/include/toy/Dialect.cpp.inc new file mode 100644 index 0000000..8cbc772 --- /dev/null +++ b/Ch4/include/toy/Dialect.cpp.inc @@ -0,0 +1,23 @@ +/*===- TableGen'erated file -------------------------------------*- C++ -*-===*\ +|* *| +|* Dialect Definitions *| +|* *| +|* Automatically generated file, do not edit! *| +|* From: Ops.td *| +|* *| +\*===----------------------------------------------------------------------===*/ + +MLIR_DEFINE_EXPLICIT_TYPE_ID(::mlir::toy::ToyDialect) +namespace mlir { +namespace toy { + +ToyDialect::ToyDialect(::mlir::MLIRContext *context) + : ::mlir::Dialect(getDialectNamespace(), context, ::mlir::TypeID::get()) { + + initialize(); +} + +ToyDialect::~ToyDialect() = default; + +} // namespace toy +} // namespace mlir diff --git a/Ch4/include/toy/Dialect.h b/Ch4/include/toy/Dialect.h new file mode 100644 index 0000000..5db325e --- /dev/null +++ b/Ch4/include/toy/Dialect.h @@ -0,0 +1,36 @@ +//===- Dialect.h - Dialect definition for the Toy IR ----------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements the IR Dialect for the Toy language. +// See docs/Tutorials/Toy/Ch-2.md for more information. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_TUTORIAL_TOY_DIALECT_H_ +#define MLIR_TUTORIAL_TOY_DIALECT_H_ + +#include "mlir/Bytecode/BytecodeOpInterface.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Dialect.h" +#include "mlir/IR/SymbolTable.h" +#include "mlir/Interfaces/CallInterfaces.h" +#include "mlir/Interfaces/CastInterfaces.h" +#include "mlir/Interfaces/FunctionInterfaces.h" +#include "mlir/Interfaces/SideEffectInterfaces.h" +#include "toy/ShapeInferenceInterface.h" + +/// Include the auto-generated header file containing the declaration of the toy +/// dialect. +#include "toy/Dialect.h.inc" + +/// Include the auto-generated header file containing the declarations of the +/// toy operations. +#define GET_OP_CLASSES +#include "toy/Ops.h.inc" + +#endif // MLIR_TUTORIAL_TOY_DIALECT_H_ diff --git a/Ch4/include/toy/Dialect.h.inc b/Ch4/include/toy/Dialect.h.inc new file mode 100644 index 0000000..f19d867 --- /dev/null +++ b/Ch4/include/toy/Dialect.h.inc @@ -0,0 +1,26 @@ +/*===- TableGen'erated file -------------------------------------*- C++ -*-===*\ +|* *| +|* Dialect Declarations *| +|* *| +|* Automatically generated file, do not edit! *| +|* From: Ops.td *| +|* *| +\*===----------------------------------------------------------------------===*/ + +namespace mlir { +namespace toy { + +class ToyDialect : public ::mlir::Dialect { + explicit ToyDialect(::mlir::MLIRContext *context); + + void initialize(); + friend class ::mlir::MLIRContext; +public: + ~ToyDialect() override; + static constexpr ::llvm::StringLiteral getDialectNamespace() { + return ::llvm::StringLiteral("toy"); + } +}; +} // namespace toy +} // namespace mlir +MLIR_DECLARE_EXPLICIT_TYPE_ID(::mlir::toy::ToyDialect) diff --git a/Ch4/include/toy/Lexer.h b/Ch4/include/toy/Lexer.h new file mode 100644 index 0000000..3c59cd9 --- /dev/null +++ b/Ch4/include/toy/Lexer.h @@ -0,0 +1,232 @@ +//===- Lexer.h - Lexer for the Toy language -------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements a simple Lexer for the Toy language. +// +//===----------------------------------------------------------------------===// + +#ifndef TOY_LEXER_H +#define TOY_LEXER_H + +#include "llvm/ADT/StringRef.h" + +#include +#include + +namespace toy { + +/// Structure definition a location in a file. +struct Location { + std::shared_ptr file; ///< filename. + int line; ///< line number. + int col; ///< column number. +}; + +// List of Token returned by the lexer. +enum Token : int { + tok_semicolon = ';', + tok_parenthese_open = '(', + tok_parenthese_close = ')', + tok_bracket_open = '{', + tok_bracket_close = '}', + tok_sbracket_open = '[', + tok_sbracket_close = ']', + + tok_eof = -1, + + // commands + tok_return = -2, + tok_var = -3, + tok_def = -4, + + // primary + tok_identifier = -5, + tok_number = -6, +}; + +/// The Lexer is an abstract base class providing all the facilities that the +/// Parser expects. It goes through the stream one token at a time and keeps +/// track of the location in the file for debugging purpose. +/// It relies on a subclass to provide a `readNextLine()` method. The subclass +/// can proceed by reading the next line from the standard input or from a +/// memory mapped file. +class Lexer { +public: + /// Create a lexer for the given filename. The filename is kept only for + /// debugging purpose (attaching a location to a Token). + Lexer(std::string filename) + : lastLocation( + {std::make_shared(std::move(filename)), 0, 0}) {} + virtual ~Lexer() = default; + + /// Look at the current token in the stream. + Token getCurToken() { return curTok; } + + /// Move to the next token in the stream and return it. + Token getNextToken() { return curTok = getTok(); } + + /// Move to the next token in the stream, asserting on the current token + /// matching the expectation. + void consume(Token tok) { + assert(tok == curTok && "consume Token mismatch expectation"); + getNextToken(); + } + + /// Return the current identifier (prereq: getCurToken() == tok_identifier) + llvm::StringRef getId() { + assert(curTok == tok_identifier); + return identifierStr; + } + + /// Return the current number (prereq: getCurToken() == tok_number) + double getValue() { + assert(curTok == tok_number); + return numVal; + } + + /// Return the location for the beginning of the current token. + Location getLastLocation() { return lastLocation; } + + // Return the current line in the file. + int getLine() { return curLineNum; } + + // Return the current column in the file. + int getCol() { return curCol; } + +private: + /// Delegate to a derived class fetching the next line. Returns an empty + /// string to signal end of file (EOF). Lines are expected to always finish + /// with "\n" + virtual llvm::StringRef readNextLine() = 0; + + /// Return the next character from the stream. This manages the buffer for the + /// current line and request the next line buffer to the derived class as + /// needed. + int getNextChar() { + // The current line buffer should not be empty unless it is the end of file. + if (curLineBuffer.empty()) + return EOF; + ++curCol; + auto nextchar = curLineBuffer.front(); + curLineBuffer = curLineBuffer.drop_front(); + if (curLineBuffer.empty()) + curLineBuffer = readNextLine(); + if (nextchar == '\n') { + ++curLineNum; + curCol = 0; + } + return nextchar; + } + + /// Return the next token from standard input. + Token getTok() { + // Skip any whitespace. + while (isspace(lastChar)) + lastChar = Token(getNextChar()); + + // Save the current location before reading the token characters. + lastLocation.line = curLineNum; + lastLocation.col = curCol; + + // Identifier: [a-zA-Z][a-zA-Z0-9_]* + if (isalpha(lastChar)) { + identifierStr = (char)lastChar; + while (isalnum((lastChar = Token(getNextChar()))) || lastChar == '_') + identifierStr += (char)lastChar; + + if (identifierStr == "return") + return tok_return; + if (identifierStr == "def") + return tok_def; + if (identifierStr == "var") + return tok_var; + return tok_identifier; + } + + // Number: [0-9.]+ + if (isdigit(lastChar) || lastChar == '.') { + std::string numStr; + do { + numStr += lastChar; + lastChar = Token(getNextChar()); + } while (isdigit(lastChar) || lastChar == '.'); + + numVal = strtod(numStr.c_str(), nullptr); + return tok_number; + } + + if (lastChar == '#') { + // Comment until end of line. + do { + lastChar = Token(getNextChar()); + } while (lastChar != EOF && lastChar != '\n' && lastChar != '\r'); + + if (lastChar != EOF) + return getTok(); + } + + // Check for end of file. Don't eat the EOF. + if (lastChar == EOF) + return tok_eof; + + // Otherwise, just return the character as its ascii value. + Token thisChar = Token(lastChar); + lastChar = Token(getNextChar()); + return thisChar; + } + + /// The last token read from the input. + Token curTok = tok_eof; + + /// Location for `curTok`. + Location lastLocation; + + /// If the current Token is an identifier, this string contains the value. + std::string identifierStr; + + /// If the current Token is a number, this contains the value. + double numVal = 0; + + /// The last value returned by getNextChar(). We need to keep it around as we + /// always need to read ahead one character to decide when to end a token and + /// we can't put it back in the stream after reading from it. + Token lastChar = Token(' '); + + /// Keep track of the current line number in the input stream + int curLineNum = 0; + + /// Keep track of the current column number in the input stream + int curCol = 0; + + /// Buffer supplied by the derived class on calls to `readNextLine()` + llvm::StringRef curLineBuffer = "\n"; +}; + +/// A lexer implementation operating on a buffer in memory. +class LexerBuffer final : public Lexer { +public: + LexerBuffer(const char *begin, const char *end, std::string filename) + : Lexer(std::move(filename)), current(begin), end(end) {} + +private: + /// Provide one line at a time to the Lexer, return an empty string when + /// reaching the end of the buffer. + llvm::StringRef readNextLine() override { + auto *begin = current; + while (current <= end && *current && *current != '\n') + ++current; + if (current <= end && *current) + ++current; + llvm::StringRef result{begin, static_cast(current - begin)}; + return result; + } + const char *current, *end; +}; +} // namespace toy + +#endif // TOY_LEXER_H diff --git a/Ch4/include/toy/MLIRGen.h b/Ch4/include/toy/MLIRGen.h new file mode 100644 index 0000000..fe9dbe5 --- /dev/null +++ b/Ch4/include/toy/MLIRGen.h @@ -0,0 +1,35 @@ +//===- MLIRGen.h - MLIR Generation from a Toy AST -------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file declares a simple interface to perform IR generation targeting MLIR +// from a Module AST for the Toy language. +// +//===----------------------------------------------------------------------===// + +#ifndef TOY_MLIRGEN_H +#define TOY_MLIRGEN_H + +#include + +namespace mlir { +class MLIRContext; +template +class OwningOpRef; +class ModuleOp; +} // namespace mlir + +namespace toy { +class ModuleAST; + +/// Emit IR for the given Toy moduleAST, returns a newly created MLIR module +/// or nullptr on failure. +mlir::OwningOpRef mlirGen(mlir::MLIRContext &context, + ModuleAST &moduleAST); +} // namespace toy + +#endif // TOY_MLIRGEN_H diff --git a/Ch4/include/toy/Ops.cpp.inc b/Ch4/include/toy/Ops.cpp.inc new file mode 100644 index 0000000..c1a6abc --- /dev/null +++ b/Ch4/include/toy/Ops.cpp.inc @@ -0,0 +1,2242 @@ +/*===- TableGen'erated file -------------------------------------*- C++ -*-===*\ +|* *| +|* Op Definitions *| +|* *| +|* Automatically generated file, do not edit! *| +|* From: Ops.td *| +|* *| +\*===----------------------------------------------------------------------===*/ + +#ifdef GET_OP_LIST +#undef GET_OP_LIST + +::mlir::toy::AddOp, +::mlir::toy::CastOp, +::mlir::toy::ConstantOp, +::mlir::toy::FuncOp, +::mlir::toy::GenericCallOp, +::mlir::toy::MulOp, +::mlir::toy::PrintOp, +::mlir::toy::ReshapeOp, +::mlir::toy::ReturnOp, +::mlir::toy::TransposeOp +#endif // GET_OP_LIST + +#ifdef GET_OP_CLASSES +#undef GET_OP_CLASSES + + +//===----------------------------------------------------------------------===// +// Local Utility Method Definitions +//===----------------------------------------------------------------------===// + +namespace mlir { +namespace toy { + +static ::mlir::LogicalResult __mlir_ods_local_type_constraint_Ops0( + ::mlir::Operation *op, ::mlir::Type type, ::llvm::StringRef valueKind, + unsigned valueIndex) { + if (!(((::llvm::isa<::mlir::TensorType>(type))) && ([](::mlir::Type elementType) { return (elementType.isF64()); }(::llvm::cast<::mlir::ShapedType>(type).getElementType())))) { + return op->emitOpError(valueKind) << " #" << valueIndex + << " must be tensor of 64-bit float values, but got " << type; + } + return ::mlir::success(); +} + +static ::mlir::LogicalResult __mlir_ods_local_type_constraint_Ops1( + ::mlir::Operation *op, ::mlir::Type type, ::llvm::StringRef valueKind, + unsigned valueIndex) { + if (!(((::llvm::isa<::mlir::TensorType>(type))) && ([](::mlir::Type elementType) { return (elementType.isF64()); }(::llvm::cast<::mlir::ShapedType>(type).getElementType())))) { + return op->emitOpError(valueKind) << " #" << valueIndex + << " must be variadic of tensor of 64-bit float values, but got " << type; + } + return ::mlir::success(); +} + +static ::mlir::LogicalResult __mlir_ods_local_type_constraint_Ops2( + ::mlir::Operation *op, ::mlir::Type type, ::llvm::StringRef valueKind, + unsigned valueIndex) { + if (!((((::llvm::isa<::mlir::RankedTensorType>(type))) && ((::llvm::cast<::mlir::ShapedType>(type).hasStaticShape()))) && ([](::mlir::Type elementType) { return (elementType.isF64()); }(::llvm::cast<::mlir::ShapedType>(type).getElementType())))) { + return op->emitOpError(valueKind) << " #" << valueIndex + << " must be statically shaped tensor of 64-bit float values, but got " << type; + } + return ::mlir::success(); +} + +static ::mlir::LogicalResult __mlir_ods_local_attr_constraint_Ops0( + ::mlir::Attribute attr, ::llvm::StringRef attrName, llvm::function_ref<::mlir::InFlightDiagnostic()> emitError) { + if (attr && !((::llvm::isa<::mlir::DenseFPElementsAttr>(attr) &&::llvm::cast<::mlir::DenseElementsAttr>(attr).getType().getElementType().isF64()))) + return emitError() << "attribute '" << attrName + << "' failed to satisfy constraint: 64-bit float elements attribute"; + return ::mlir::success(); +} +static ::mlir::LogicalResult __mlir_ods_local_attr_constraint_Ops0( + ::mlir::Operation *op, ::mlir::Attribute attr, ::llvm::StringRef attrName) { + return __mlir_ods_local_attr_constraint_Ops0(attr, attrName, [op]() { + return op->emitOpError(); + }); +} + +static ::mlir::LogicalResult __mlir_ods_local_attr_constraint_Ops1( + ::mlir::Attribute attr, ::llvm::StringRef attrName, llvm::function_ref<::mlir::InFlightDiagnostic()> emitError) { + if (attr && !((::llvm::isa<::mlir::StringAttr>(attr)))) + return emitError() << "attribute '" << attrName + << "' failed to satisfy constraint: string attribute"; + return ::mlir::success(); +} +static ::mlir::LogicalResult __mlir_ods_local_attr_constraint_Ops1( + ::mlir::Operation *op, ::mlir::Attribute attr, ::llvm::StringRef attrName) { + return __mlir_ods_local_attr_constraint_Ops1(attr, attrName, [op]() { + return op->emitOpError(); + }); +} + +static ::mlir::LogicalResult __mlir_ods_local_attr_constraint_Ops2( + ::mlir::Attribute attr, ::llvm::StringRef attrName, llvm::function_ref<::mlir::InFlightDiagnostic()> emitError) { + if (attr && !(((::llvm::isa<::mlir::TypeAttr>(attr))) && ((::llvm::isa<::mlir::FunctionType>(::llvm::cast<::mlir::TypeAttr>(attr).getValue()))) && ((::llvm::isa<::mlir::FunctionType>(::llvm::cast<::mlir::TypeAttr>(attr).getValue()))))) + return emitError() << "attribute '" << attrName + << "' failed to satisfy constraint: type attribute of function type"; + return ::mlir::success(); +} +static ::mlir::LogicalResult __mlir_ods_local_attr_constraint_Ops2( + ::mlir::Operation *op, ::mlir::Attribute attr, ::llvm::StringRef attrName) { + return __mlir_ods_local_attr_constraint_Ops2(attr, attrName, [op]() { + return op->emitOpError(); + }); +} + +static ::mlir::LogicalResult __mlir_ods_local_attr_constraint_Ops3( + ::mlir::Attribute attr, ::llvm::StringRef attrName, llvm::function_ref<::mlir::InFlightDiagnostic()> emitError) { + if (attr && !(((::llvm::isa<::mlir::ArrayAttr>(attr))) && (::llvm::all_of(::llvm::cast<::mlir::ArrayAttr>(attr), [&](::mlir::Attribute attr) { return attr && ((::llvm::isa<::mlir::DictionaryAttr>(attr))); })))) + return emitError() << "attribute '" << attrName + << "' failed to satisfy constraint: Array of dictionary attributes"; + return ::mlir::success(); +} +static ::mlir::LogicalResult __mlir_ods_local_attr_constraint_Ops3( + ::mlir::Operation *op, ::mlir::Attribute attr, ::llvm::StringRef attrName) { + return __mlir_ods_local_attr_constraint_Ops3(attr, attrName, [op]() { + return op->emitOpError(); + }); +} + +static ::mlir::LogicalResult __mlir_ods_local_attr_constraint_Ops4( + ::mlir::Attribute attr, ::llvm::StringRef attrName, llvm::function_ref<::mlir::InFlightDiagnostic()> emitError) { + if (attr && !((::llvm::isa<::mlir::FlatSymbolRefAttr>(attr)))) + return emitError() << "attribute '" << attrName + << "' failed to satisfy constraint: flat symbol reference attribute"; + return ::mlir::success(); +} +static ::mlir::LogicalResult __mlir_ods_local_attr_constraint_Ops4( + ::mlir::Operation *op, ::mlir::Attribute attr, ::llvm::StringRef attrName) { + return __mlir_ods_local_attr_constraint_Ops4(attr, attrName, [op]() { + return op->emitOpError(); + }); +} + +static ::mlir::LogicalResult __mlir_ods_local_region_constraint_Ops0( + ::mlir::Operation *op, ::mlir::Region ®ion, ::llvm::StringRef regionName, + unsigned regionIndex) { + if (!((true))) { + return op->emitOpError("region #") << regionIndex + << (regionName.empty() ? " " : " ('" + regionName + "') ") + << "failed to verify constraint: any region"; + } + return ::mlir::success(); +} +} // namespace toy +} // namespace mlir +namespace mlir { +namespace toy { + +//===----------------------------------------------------------------------===// +// ::mlir::toy::AddOp definitions +//===----------------------------------------------------------------------===// + +namespace detail { +AddOpGenericAdaptorBase::AddOpGenericAdaptorBase(::mlir::DictionaryAttr attrs, const ::mlir::EmptyProperties &properties, ::mlir::RegionRange regions) : odsAttrs(attrs), odsRegions(regions) { if (odsAttrs) + odsOpName.emplace("toy.add", odsAttrs.getContext()); +} + +AddOpGenericAdaptorBase::AddOpGenericAdaptorBase(AddOp op) : AddOpGenericAdaptorBase(op->getAttrDictionary(), op.getProperties(), op->getRegions()) {} + +std::pair AddOpGenericAdaptorBase::getODSOperandIndexAndLength(unsigned index, unsigned odsOperandsSize) { + return {index, 1}; +} + +::mlir::DictionaryAttr AddOpGenericAdaptorBase::getAttributes() { + return odsAttrs; +} + +} // namespace detail +AddOpAdaptor::AddOpAdaptor(AddOp op) : AddOpGenericAdaptor(op->getOperands(), op) {} + +::mlir::LogicalResult AddOpAdaptor::verify(::mlir::Location loc) { + return ::mlir::success(); +} + +std::pair AddOp::getODSOperandIndexAndLength(unsigned index) { + return {index, 1}; +} + +::mlir::Operation::operand_range AddOp::getODSOperands(unsigned index) { + auto valueRange = getODSOperandIndexAndLength(index); + return {std::next(getOperation()->operand_begin(), valueRange.first), + std::next(getOperation()->operand_begin(), valueRange.first + valueRange.second)}; +} + +::mlir::TypedValue<::mlir::TensorType> AddOp::getLhs() { + return ::llvm::cast<::mlir::TypedValue<::mlir::TensorType>>(*getODSOperands(0).begin()); +} + +::mlir::TypedValue<::mlir::TensorType> AddOp::getRhs() { + return ::llvm::cast<::mlir::TypedValue<::mlir::TensorType>>(*getODSOperands(1).begin()); +} + +::mlir::OpOperand &AddOp::getLhsMutable() { + auto range = getODSOperandIndexAndLength(0); + return getOperation()->getOpOperand(range.first); +} + +::mlir::OpOperand &AddOp::getRhsMutable() { + auto range = getODSOperandIndexAndLength(1); + return getOperation()->getOpOperand(range.first); +} + +std::pair AddOp::getODSResultIndexAndLength(unsigned index) { + return {index, 1}; +} + +::mlir::Operation::result_range AddOp::getODSResults(unsigned index) { + auto valueRange = getODSResultIndexAndLength(index); + return {std::next(getOperation()->result_begin(), valueRange.first), + std::next(getOperation()->result_begin(), valueRange.first + valueRange.second)}; +} + +void AddOp::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type resultType0, ::mlir::Value lhs, ::mlir::Value rhs) { + odsState.addOperands(lhs); + odsState.addOperands(rhs); + odsState.addTypes(resultType0); +} + +void AddOp::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::Value lhs, ::mlir::Value rhs) { + odsState.addOperands(lhs); + odsState.addOperands(rhs); + assert(resultTypes.size() == 1u && "mismatched number of results"); + odsState.addTypes(resultTypes); +} + +void AddOp::build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes) { + assert(operands.size() == 2u && "mismatched number of parameters"); + odsState.addOperands(operands); + odsState.addAttributes(attributes); + assert(resultTypes.size() == 1u && "mismatched number of return types"); + odsState.addTypes(resultTypes); +} + +::mlir::LogicalResult AddOp::verifyInvariantsImpl() { + { + unsigned index = 0; (void)index; + auto valueGroup0 = getODSOperands(0); + + for (auto v : valueGroup0) { + if (::mlir::failed(__mlir_ods_local_type_constraint_Ops0(*this, v.getType(), "operand", index++))) + return ::mlir::failure(); + } + auto valueGroup1 = getODSOperands(1); + + for (auto v : valueGroup1) { + if (::mlir::failed(__mlir_ods_local_type_constraint_Ops0(*this, v.getType(), "operand", index++))) + return ::mlir::failure(); + } + } + { + unsigned index = 0; (void)index; + auto valueGroup0 = getODSResults(0); + + for (auto v : valueGroup0) { + if (::mlir::failed(__mlir_ods_local_type_constraint_Ops0(*this, v.getType(), "result", index++))) + return ::mlir::failure(); + } + } + return ::mlir::success(); +} + +::mlir::LogicalResult AddOp::verifyInvariants() { + return verifyInvariantsImpl(); +} + +void AddOp::getEffects(::llvm::SmallVectorImpl<::mlir::SideEffects::EffectInstance<::mlir::MemoryEffects::Effect>> &effects) { +} + +} // namespace toy +} // namespace mlir +MLIR_DEFINE_EXPLICIT_TYPE_ID(::mlir::toy::AddOp) + +namespace mlir { +namespace toy { + +//===----------------------------------------------------------------------===// +// ::mlir::toy::CastOp definitions +//===----------------------------------------------------------------------===// + +namespace detail { +CastOpGenericAdaptorBase::CastOpGenericAdaptorBase(::mlir::DictionaryAttr attrs, const ::mlir::EmptyProperties &properties, ::mlir::RegionRange regions) : odsAttrs(attrs), odsRegions(regions) { if (odsAttrs) + odsOpName.emplace("toy.cast", odsAttrs.getContext()); +} + +CastOpGenericAdaptorBase::CastOpGenericAdaptorBase(CastOp op) : CastOpGenericAdaptorBase(op->getAttrDictionary(), op.getProperties(), op->getRegions()) {} + +std::pair CastOpGenericAdaptorBase::getODSOperandIndexAndLength(unsigned index, unsigned odsOperandsSize) { + return {index, 1}; +} + +::mlir::DictionaryAttr CastOpGenericAdaptorBase::getAttributes() { + return odsAttrs; +} + +} // namespace detail +CastOpAdaptor::CastOpAdaptor(CastOp op) : CastOpGenericAdaptor(op->getOperands(), op) {} + +::mlir::LogicalResult CastOpAdaptor::verify(::mlir::Location loc) { + return ::mlir::success(); +} + +std::pair CastOp::getODSOperandIndexAndLength(unsigned index) { + return {index, 1}; +} + +::mlir::Operation::operand_range CastOp::getODSOperands(unsigned index) { + auto valueRange = getODSOperandIndexAndLength(index); + return {std::next(getOperation()->operand_begin(), valueRange.first), + std::next(getOperation()->operand_begin(), valueRange.first + valueRange.second)}; +} + +::mlir::TypedValue<::mlir::TensorType> CastOp::getInput() { + return ::llvm::cast<::mlir::TypedValue<::mlir::TensorType>>(*getODSOperands(0).begin()); +} + +::mlir::OpOperand &CastOp::getInputMutable() { + auto range = getODSOperandIndexAndLength(0); + return getOperation()->getOpOperand(range.first); +} + +std::pair CastOp::getODSResultIndexAndLength(unsigned index) { + return {index, 1}; +} + +::mlir::Operation::result_range CastOp::getODSResults(unsigned index) { + auto valueRange = getODSResultIndexAndLength(index); + return {std::next(getOperation()->result_begin(), valueRange.first), + std::next(getOperation()->result_begin(), valueRange.first + valueRange.second)}; +} + +::mlir::TypedValue<::mlir::TensorType> CastOp::getOutput() { + return ::llvm::cast<::mlir::TypedValue<::mlir::TensorType>>(*getODSResults(0).begin()); +} + +void CastOp::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type output, ::mlir::Value input) { + odsState.addOperands(input); + odsState.addTypes(output); +} + +void CastOp::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::Value input) { + odsState.addOperands(input); + assert(resultTypes.size() == 1u && "mismatched number of results"); + odsState.addTypes(resultTypes); +} + +void CastOp::build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes) { + assert(operands.size() == 1u && "mismatched number of parameters"); + odsState.addOperands(operands); + odsState.addAttributes(attributes); + assert(resultTypes.size() == 1u && "mismatched number of return types"); + odsState.addTypes(resultTypes); +} + +::mlir::LogicalResult CastOp::verifyInvariantsImpl() { + { + unsigned index = 0; (void)index; + auto valueGroup0 = getODSOperands(0); + + for (auto v : valueGroup0) { + if (::mlir::failed(__mlir_ods_local_type_constraint_Ops0(*this, v.getType(), "operand", index++))) + return ::mlir::failure(); + } + } + { + unsigned index = 0; (void)index; + auto valueGroup0 = getODSResults(0); + + for (auto v : valueGroup0) { + if (::mlir::failed(__mlir_ods_local_type_constraint_Ops0(*this, v.getType(), "result", index++))) + return ::mlir::failure(); + } + } + return ::mlir::success(); +} + +::mlir::LogicalResult CastOp::verifyInvariants() { + return verifyInvariantsImpl(); +} + +::mlir::ParseResult CastOp::parse(::mlir::OpAsmParser &parser, ::mlir::OperationState &result) { + ::mlir::OpAsmParser::UnresolvedOperand inputRawOperands[1]; + ::llvm::ArrayRef<::mlir::OpAsmParser::UnresolvedOperand> inputOperands(inputRawOperands); ::llvm::SMLoc inputOperandsLoc; + (void)inputOperandsLoc; + ::mlir::Type inputRawTypes[1]; + ::llvm::ArrayRef<::mlir::Type> inputTypes(inputRawTypes); + ::mlir::Type outputRawTypes[1]; + ::llvm::ArrayRef<::mlir::Type> outputTypes(outputRawTypes); + + inputOperandsLoc = parser.getCurrentLocation(); + if (parser.parseOperand(inputRawOperands[0])) + return ::mlir::failure(); + { + auto loc = parser.getCurrentLocation();(void)loc; + if (parser.parseOptionalAttrDict(result.attributes)) + return ::mlir::failure(); + } + if (parser.parseColon()) + return ::mlir::failure(); + + { + ::mlir::TensorType type; + if (parser.parseCustomTypeWithFallback(type)) + return ::mlir::failure(); + inputRawTypes[0] = type; + } + if (parser.parseKeyword("to")) + return ::mlir::failure(); + + { + ::mlir::TensorType type; + if (parser.parseCustomTypeWithFallback(type)) + return ::mlir::failure(); + outputRawTypes[0] = type; + } + result.addTypes(outputTypes); + if (parser.resolveOperands(inputOperands, inputTypes, inputOperandsLoc, result.operands)) + return ::mlir::failure(); + return ::mlir::success(); +} + +void CastOp::print(::mlir::OpAsmPrinter &_odsPrinter) { + _odsPrinter << ' '; + _odsPrinter << getInput(); + ::llvm::SmallVector<::llvm::StringRef, 2> elidedAttrs; + _odsPrinter.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs); + _odsPrinter << ' ' << ":"; + _odsPrinter << ' '; + { + auto type = getInput().getType(); + if (auto validType = ::llvm::dyn_cast<::mlir::TensorType>(type)) + _odsPrinter.printStrippedAttrOrType(validType); + else + _odsPrinter << type; + } + _odsPrinter << ' ' << "to"; + _odsPrinter << ' '; + { + auto type = getOutput().getType(); + if (auto validType = ::llvm::dyn_cast<::mlir::TensorType>(type)) + _odsPrinter.printStrippedAttrOrType(validType); + else + _odsPrinter << type; + } +} + +void CastOp::getEffects(::llvm::SmallVectorImpl<::mlir::SideEffects::EffectInstance<::mlir::MemoryEffects::Effect>> &effects) { +} + +} // namespace toy +} // namespace mlir +MLIR_DEFINE_EXPLICIT_TYPE_ID(::mlir::toy::CastOp) + +namespace mlir { +namespace toy { + +//===----------------------------------------------------------------------===// +// ::mlir::toy::ConstantOp definitions +//===----------------------------------------------------------------------===// + +namespace detail { +ConstantOpGenericAdaptorBase::ConstantOpGenericAdaptorBase(::mlir::DictionaryAttr attrs, const Properties &properties, ::mlir::RegionRange regions) : odsAttrs(attrs), properties(properties), odsRegions(regions) { if (odsAttrs) + odsOpName.emplace("toy.constant", odsAttrs.getContext()); +} + +ConstantOpGenericAdaptorBase::ConstantOpGenericAdaptorBase(ConstantOp op) : ConstantOpGenericAdaptorBase(op->getDiscardableAttrDictionary(), op.getProperties(), op->getRegions()) {} + +std::pair ConstantOpGenericAdaptorBase::getODSOperandIndexAndLength(unsigned index, unsigned odsOperandsSize) { + return {index, 1}; +} + +::mlir::DictionaryAttr ConstantOpGenericAdaptorBase::getAttributes() { + return odsAttrs; +} + +::mlir::DenseElementsAttr ConstantOpGenericAdaptorBase::getValueAttr() { + auto attr = ::llvm::cast<::mlir::DenseElementsAttr>(getProperties().value); + return attr; +} + +::mlir::DenseElementsAttr ConstantOpGenericAdaptorBase::getValue() { + auto attr = getValueAttr(); + return attr; +} + +} // namespace detail +ConstantOpAdaptor::ConstantOpAdaptor(ConstantOp op) : ConstantOpGenericAdaptor(op->getOperands(), op) {} + +::mlir::LogicalResult ConstantOpAdaptor::verify(::mlir::Location loc) { + auto tblgen_value = getProperties().value; (void)tblgen_value; + if (!tblgen_value) return emitError(loc, "'toy.constant' op ""requires attribute 'value'"); + + if (tblgen_value && !((::llvm::isa<::mlir::DenseFPElementsAttr>(tblgen_value) &&::llvm::cast<::mlir::DenseElementsAttr>(tblgen_value).getType().getElementType().isF64()))) + return emitError(loc, "'toy.constant' op ""attribute 'value' failed to satisfy constraint: 64-bit float elements attribute"); + return ::mlir::success(); +} + +std::pair ConstantOp::getODSOperandIndexAndLength(unsigned index) { + return {index, 1}; +} + +::mlir::Operation::operand_range ConstantOp::getODSOperands(unsigned index) { + auto valueRange = getODSOperandIndexAndLength(index); + return {std::next(getOperation()->operand_begin(), valueRange.first), + std::next(getOperation()->operand_begin(), valueRange.first + valueRange.second)}; +} + +std::pair ConstantOp::getODSResultIndexAndLength(unsigned index) { + return {index, 1}; +} + +::mlir::Operation::result_range ConstantOp::getODSResults(unsigned index) { + auto valueRange = getODSResultIndexAndLength(index); + return {std::next(getOperation()->result_begin(), valueRange.first), + std::next(getOperation()->result_begin(), valueRange.first + valueRange.second)}; +} + +::mlir::LogicalResult ConstantOp::setPropertiesFromAttr(Properties &prop, ::mlir::Attribute attr, ::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError) { + ::mlir::DictionaryAttr dict = ::llvm::dyn_cast<::mlir::DictionaryAttr>(attr); + if (!dict) { + emitError() << "expected DictionaryAttr to set properties"; + return ::mlir::failure(); + } + + { + auto &propStorage = prop.value; + auto attr = dict.get("value"); + if (attr || /*isRequired=*/true) { + if (!attr) { + emitError() << "expected key entry for value in DictionaryAttr to set " + "Properties."; + return ::mlir::failure(); + } + auto convertedAttr = ::llvm::dyn_cast>(attr); + if (convertedAttr) { + propStorage = convertedAttr; + } else { + emitError() << "Invalid attribute `value` in property conversion: " << attr; + return ::mlir::failure(); + } + } + } + return ::mlir::success(); +} + +::mlir::Attribute ConstantOp::getPropertiesAsAttr(::mlir::MLIRContext *ctx, const Properties &prop) { + ::mlir::SmallVector<::mlir::NamedAttribute> attrs; + ::mlir::Builder odsBuilder{ctx}; + + { + const auto &propStorage = prop.value; + if (propStorage) + attrs.push_back(odsBuilder.getNamedAttr("value", + propStorage)); + } + + if (!attrs.empty()) + return odsBuilder.getDictionaryAttr(attrs); + return {}; +} + +llvm::hash_code ConstantOp::computePropertiesHash(const Properties &prop) { + return llvm::hash_combine( + llvm::hash_value(prop.value.getAsOpaquePointer())); +} + +std::optional ConstantOp::getInherentAttr(::mlir::MLIRContext *ctx, const Properties &prop, llvm::StringRef name) { + if (name == "value") + return prop.value; + return std::nullopt; +} + +void ConstantOp::setInherentAttr(Properties &prop, llvm::StringRef name, mlir::Attribute value) { + if (name == "value") { + prop.value = ::llvm::dyn_cast_or_null>(value); + return; + } +} + +void ConstantOp::populateInherentAttrs(::mlir::MLIRContext *ctx, const Properties &prop, ::mlir::NamedAttrList &attrs) { + if (prop.value) attrs.append("value", prop.value); +} + +::mlir::LogicalResult ConstantOp::verifyInherentAttrs(::mlir::OperationName opName, ::mlir::NamedAttrList &attrs, llvm::function_ref<::mlir::InFlightDiagnostic()> emitError) { + { + ::mlir::Attribute attr = attrs.get(getValueAttrName(opName)); + if (attr && ::mlir::failed(__mlir_ods_local_attr_constraint_Ops0(attr, "value", emitError))) + return ::mlir::failure(); + } + return ::mlir::success(); +} + +::mlir::LogicalResult ConstantOp::readProperties(::mlir::DialectBytecodeReader &reader, ::mlir::OperationState &state) { + auto &prop = state.getOrAddProperties(); (void)prop; + if (::mlir::failed(reader.readAttribute(prop.value))) + return ::mlir::failure(); + return ::mlir::success(); +} + +void ConstantOp::writeProperties(::mlir::DialectBytecodeWriter &writer) { + auto &prop = getProperties(); (void)prop; + writer.writeAttribute(prop.value); +} + +::mlir::DenseElementsAttr ConstantOp::getValueAttr() { + return ::llvm::cast<::mlir::DenseElementsAttr>(getProperties().value); +} + +::mlir::DenseElementsAttr ConstantOp::getValue() { + auto attr = getValueAttr(); + return attr; +} + +void ConstantOp::setValueAttr(::mlir::DenseElementsAttr attr) { + (*this)->setAttr(getValueAttrName(), attr); +} + +void ConstantOp::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, DenseElementsAttr value) { + build(odsBuilder, odsState, value.getType(), value); + +} + +void ConstantOp::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type resultType0, ::mlir::DenseElementsAttr value) { + odsState.getOrAddProperties().value = value; + odsState.addTypes(resultType0); +} + +void ConstantOp::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::DenseElementsAttr value) { + odsState.getOrAddProperties().value = value; + assert(resultTypes.size() == 1u && "mismatched number of results"); + odsState.addTypes(resultTypes); +} + +void ConstantOp::build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes) { + assert(operands.size() == 0u && "mismatched number of parameters"); + odsState.addOperands(operands); + odsState.addAttributes(attributes); + assert(resultTypes.size() == 1u && "mismatched number of return types"); + odsState.addTypes(resultTypes); +} + +::mlir::LogicalResult ConstantOp::verifyInvariantsImpl() { + auto tblgen_value = getProperties().value; (void)tblgen_value; + if (!tblgen_value) return emitOpError("requires attribute 'value'"); + + if (::mlir::failed(__mlir_ods_local_attr_constraint_Ops0(*this, tblgen_value, "value"))) + return ::mlir::failure(); + { + unsigned index = 0; (void)index; + auto valueGroup0 = getODSResults(0); + + for (auto v : valueGroup0) { + if (::mlir::failed(__mlir_ods_local_type_constraint_Ops0(*this, v.getType(), "result", index++))) + return ::mlir::failure(); + } + } + return ::mlir::success(); +} + +::mlir::LogicalResult ConstantOp::verifyInvariants() { + if(::mlir::succeeded(verifyInvariantsImpl()) && ::mlir::succeeded(verify())) + return ::mlir::success(); + return ::mlir::failure(); +} + +void ConstantOp::getEffects(::llvm::SmallVectorImpl<::mlir::SideEffects::EffectInstance<::mlir::MemoryEffects::Effect>> &effects) { +} + +} // namespace toy +} // namespace mlir +MLIR_DEFINE_EXPLICIT_TYPE_ID(::mlir::toy::ConstantOp) + +namespace mlir { +namespace toy { + +//===----------------------------------------------------------------------===// +// ::mlir::toy::FuncOp definitions +//===----------------------------------------------------------------------===// + +namespace detail { +FuncOpGenericAdaptorBase::FuncOpGenericAdaptorBase(::mlir::DictionaryAttr attrs, const Properties &properties, ::mlir::RegionRange regions) : odsAttrs(attrs), properties(properties), odsRegions(regions) { if (odsAttrs) + odsOpName.emplace("toy.func", odsAttrs.getContext()); +} + +FuncOpGenericAdaptorBase::FuncOpGenericAdaptorBase(FuncOp op) : FuncOpGenericAdaptorBase(op->getDiscardableAttrDictionary(), op.getProperties(), op->getRegions()) {} + +std::pair FuncOpGenericAdaptorBase::getODSOperandIndexAndLength(unsigned index, unsigned odsOperandsSize) { + return {index, 1}; +} + +::mlir::DictionaryAttr FuncOpGenericAdaptorBase::getAttributes() { + return odsAttrs; +} + +::mlir::StringAttr FuncOpGenericAdaptorBase::getSymNameAttr() { + auto attr = ::llvm::cast<::mlir::StringAttr>(getProperties().sym_name); + return attr; +} + +::llvm::StringRef FuncOpGenericAdaptorBase::getSymName() { + auto attr = getSymNameAttr(); + return attr.getValue(); +} + +::mlir::TypeAttr FuncOpGenericAdaptorBase::getFunctionTypeAttr() { + auto attr = ::llvm::cast<::mlir::TypeAttr>(getProperties().function_type); + return attr; +} + +::mlir::FunctionType FuncOpGenericAdaptorBase::getFunctionType() { + auto attr = getFunctionTypeAttr(); + return ::llvm::cast<::mlir::FunctionType>(attr.getValue()); +} + +::mlir::ArrayAttr FuncOpGenericAdaptorBase::getArgAttrsAttr() { + auto attr = ::llvm::dyn_cast_or_null<::mlir::ArrayAttr>(getProperties().arg_attrs); + return attr; +} + +::std::optional< ::mlir::ArrayAttr > FuncOpGenericAdaptorBase::getArgAttrs() { + auto attr = getArgAttrsAttr(); + return attr ? ::std::optional< ::mlir::ArrayAttr >(attr) : (::std::nullopt); +} + +::mlir::ArrayAttr FuncOpGenericAdaptorBase::getResAttrsAttr() { + auto attr = ::llvm::dyn_cast_or_null<::mlir::ArrayAttr>(getProperties().res_attrs); + return attr; +} + +::std::optional< ::mlir::ArrayAttr > FuncOpGenericAdaptorBase::getResAttrs() { + auto attr = getResAttrsAttr(); + return attr ? ::std::optional< ::mlir::ArrayAttr >(attr) : (::std::nullopt); +} + +::mlir::Region &FuncOpGenericAdaptorBase::getBody() { + return *odsRegions[0]; +} + +::mlir::RegionRange FuncOpGenericAdaptorBase::getRegions() { + return odsRegions; +} + +} // namespace detail +FuncOpAdaptor::FuncOpAdaptor(FuncOp op) : FuncOpGenericAdaptor(op->getOperands(), op) {} + +::mlir::LogicalResult FuncOpAdaptor::verify(::mlir::Location loc) { + auto tblgen_arg_attrs = getProperties().arg_attrs; (void)tblgen_arg_attrs; + auto tblgen_function_type = getProperties().function_type; (void)tblgen_function_type; + if (!tblgen_function_type) return emitError(loc, "'toy.func' op ""requires attribute 'function_type'"); + auto tblgen_res_attrs = getProperties().res_attrs; (void)tblgen_res_attrs; + auto tblgen_sym_name = getProperties().sym_name; (void)tblgen_sym_name; + if (!tblgen_sym_name) return emitError(loc, "'toy.func' op ""requires attribute 'sym_name'"); + + if (tblgen_sym_name && !((::llvm::isa<::mlir::StringAttr>(tblgen_sym_name)))) + return emitError(loc, "'toy.func' op ""attribute 'sym_name' failed to satisfy constraint: string attribute"); + + if (tblgen_function_type && !(((::llvm::isa<::mlir::TypeAttr>(tblgen_function_type))) && ((::llvm::isa<::mlir::FunctionType>(::llvm::cast<::mlir::TypeAttr>(tblgen_function_type).getValue()))) && ((::llvm::isa<::mlir::FunctionType>(::llvm::cast<::mlir::TypeAttr>(tblgen_function_type).getValue()))))) + return emitError(loc, "'toy.func' op ""attribute 'function_type' failed to satisfy constraint: type attribute of function type"); + + if (tblgen_arg_attrs && !(((::llvm::isa<::mlir::ArrayAttr>(tblgen_arg_attrs))) && (::llvm::all_of(::llvm::cast<::mlir::ArrayAttr>(tblgen_arg_attrs), [&](::mlir::Attribute attr) { return attr && ((::llvm::isa<::mlir::DictionaryAttr>(attr))); })))) + return emitError(loc, "'toy.func' op ""attribute 'arg_attrs' failed to satisfy constraint: Array of dictionary attributes"); + + if (tblgen_res_attrs && !(((::llvm::isa<::mlir::ArrayAttr>(tblgen_res_attrs))) && (::llvm::all_of(::llvm::cast<::mlir::ArrayAttr>(tblgen_res_attrs), [&](::mlir::Attribute attr) { return attr && ((::llvm::isa<::mlir::DictionaryAttr>(attr))); })))) + return emitError(loc, "'toy.func' op ""attribute 'res_attrs' failed to satisfy constraint: Array of dictionary attributes"); + return ::mlir::success(); +} + +std::pair FuncOp::getODSOperandIndexAndLength(unsigned index) { + return {index, 1}; +} + +::mlir::Operation::operand_range FuncOp::getODSOperands(unsigned index) { + auto valueRange = getODSOperandIndexAndLength(index); + return {std::next(getOperation()->operand_begin(), valueRange.first), + std::next(getOperation()->operand_begin(), valueRange.first + valueRange.second)}; +} + +std::pair FuncOp::getODSResultIndexAndLength(unsigned index) { + return {index, 1}; +} + +::mlir::Operation::result_range FuncOp::getODSResults(unsigned index) { + auto valueRange = getODSResultIndexAndLength(index); + return {std::next(getOperation()->result_begin(), valueRange.first), + std::next(getOperation()->result_begin(), valueRange.first + valueRange.second)}; +} + +::mlir::Region &FuncOp::getBody() { + return (*this)->getRegion(0); +} + +::mlir::LogicalResult FuncOp::setPropertiesFromAttr(Properties &prop, ::mlir::Attribute attr, ::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError) { + ::mlir::DictionaryAttr dict = ::llvm::dyn_cast<::mlir::DictionaryAttr>(attr); + if (!dict) { + emitError() << "expected DictionaryAttr to set properties"; + return ::mlir::failure(); + } + + { + auto &propStorage = prop.arg_attrs; + auto attr = dict.get("arg_attrs"); + if (attr || /*isRequired=*/false) { + if (!attr) { + emitError() << "expected key entry for arg_attrs in DictionaryAttr to set " + "Properties."; + return ::mlir::failure(); + } + auto convertedAttr = ::llvm::dyn_cast>(attr); + if (convertedAttr) { + propStorage = convertedAttr; + } else { + emitError() << "Invalid attribute `arg_attrs` in property conversion: " << attr; + return ::mlir::failure(); + } + } + } + + { + auto &propStorage = prop.function_type; + auto attr = dict.get("function_type"); + if (attr || /*isRequired=*/true) { + if (!attr) { + emitError() << "expected key entry for function_type in DictionaryAttr to set " + "Properties."; + return ::mlir::failure(); + } + auto convertedAttr = ::llvm::dyn_cast>(attr); + if (convertedAttr) { + propStorage = convertedAttr; + } else { + emitError() << "Invalid attribute `function_type` in property conversion: " << attr; + return ::mlir::failure(); + } + } + } + + { + auto &propStorage = prop.res_attrs; + auto attr = dict.get("res_attrs"); + if (attr || /*isRequired=*/false) { + if (!attr) { + emitError() << "expected key entry for res_attrs in DictionaryAttr to set " + "Properties."; + return ::mlir::failure(); + } + auto convertedAttr = ::llvm::dyn_cast>(attr); + if (convertedAttr) { + propStorage = convertedAttr; + } else { + emitError() << "Invalid attribute `res_attrs` in property conversion: " << attr; + return ::mlir::failure(); + } + } + } + + { + auto &propStorage = prop.sym_name; + auto attr = dict.get("sym_name"); + if (attr || /*isRequired=*/true) { + if (!attr) { + emitError() << "expected key entry for sym_name in DictionaryAttr to set " + "Properties."; + return ::mlir::failure(); + } + auto convertedAttr = ::llvm::dyn_cast>(attr); + if (convertedAttr) { + propStorage = convertedAttr; + } else { + emitError() << "Invalid attribute `sym_name` in property conversion: " << attr; + return ::mlir::failure(); + } + } + } + return ::mlir::success(); +} + +::mlir::Attribute FuncOp::getPropertiesAsAttr(::mlir::MLIRContext *ctx, const Properties &prop) { + ::mlir::SmallVector<::mlir::NamedAttribute> attrs; + ::mlir::Builder odsBuilder{ctx}; + + { + const auto &propStorage = prop.arg_attrs; + if (propStorage) + attrs.push_back(odsBuilder.getNamedAttr("arg_attrs", + propStorage)); + } + + { + const auto &propStorage = prop.function_type; + if (propStorage) + attrs.push_back(odsBuilder.getNamedAttr("function_type", + propStorage)); + } + + { + const auto &propStorage = prop.res_attrs; + if (propStorage) + attrs.push_back(odsBuilder.getNamedAttr("res_attrs", + propStorage)); + } + + { + const auto &propStorage = prop.sym_name; + if (propStorage) + attrs.push_back(odsBuilder.getNamedAttr("sym_name", + propStorage)); + } + + if (!attrs.empty()) + return odsBuilder.getDictionaryAttr(attrs); + return {}; +} + +llvm::hash_code FuncOp::computePropertiesHash(const Properties &prop) { + return llvm::hash_combine( + llvm::hash_value(prop.arg_attrs.getAsOpaquePointer()), + llvm::hash_value(prop.function_type.getAsOpaquePointer()), + llvm::hash_value(prop.res_attrs.getAsOpaquePointer()), + llvm::hash_value(prop.sym_name.getAsOpaquePointer())); +} + +std::optional FuncOp::getInherentAttr(::mlir::MLIRContext *ctx, const Properties &prop, llvm::StringRef name) { + if (name == "arg_attrs") + return prop.arg_attrs; + + if (name == "function_type") + return prop.function_type; + + if (name == "res_attrs") + return prop.res_attrs; + + if (name == "sym_name") + return prop.sym_name; + return std::nullopt; +} + +void FuncOp::setInherentAttr(Properties &prop, llvm::StringRef name, mlir::Attribute value) { + if (name == "arg_attrs") { + prop.arg_attrs = ::llvm::dyn_cast_or_null>(value); + return; + } + + if (name == "function_type") { + prop.function_type = ::llvm::dyn_cast_or_null>(value); + return; + } + + if (name == "res_attrs") { + prop.res_attrs = ::llvm::dyn_cast_or_null>(value); + return; + } + + if (name == "sym_name") { + prop.sym_name = ::llvm::dyn_cast_or_null>(value); + return; + } +} + +void FuncOp::populateInherentAttrs(::mlir::MLIRContext *ctx, const Properties &prop, ::mlir::NamedAttrList &attrs) { + if (prop.arg_attrs) attrs.append("arg_attrs", prop.arg_attrs); + + if (prop.function_type) attrs.append("function_type", prop.function_type); + + if (prop.res_attrs) attrs.append("res_attrs", prop.res_attrs); + + if (prop.sym_name) attrs.append("sym_name", prop.sym_name); +} + +::mlir::LogicalResult FuncOp::verifyInherentAttrs(::mlir::OperationName opName, ::mlir::NamedAttrList &attrs, llvm::function_ref<::mlir::InFlightDiagnostic()> emitError) { + { + ::mlir::Attribute attr = attrs.get(getArgAttrsAttrName(opName)); + if (attr && ::mlir::failed(__mlir_ods_local_attr_constraint_Ops3(attr, "arg_attrs", emitError))) + return ::mlir::failure(); + } + + { + ::mlir::Attribute attr = attrs.get(getFunctionTypeAttrName(opName)); + if (attr && ::mlir::failed(__mlir_ods_local_attr_constraint_Ops2(attr, "function_type", emitError))) + return ::mlir::failure(); + } + + { + ::mlir::Attribute attr = attrs.get(getResAttrsAttrName(opName)); + if (attr && ::mlir::failed(__mlir_ods_local_attr_constraint_Ops3(attr, "res_attrs", emitError))) + return ::mlir::failure(); + } + + { + ::mlir::Attribute attr = attrs.get(getSymNameAttrName(opName)); + if (attr && ::mlir::failed(__mlir_ods_local_attr_constraint_Ops1(attr, "sym_name", emitError))) + return ::mlir::failure(); + } + return ::mlir::success(); +} + +::mlir::LogicalResult FuncOp::readProperties(::mlir::DialectBytecodeReader &reader, ::mlir::OperationState &state) { + auto &prop = state.getOrAddProperties(); (void)prop; + if (::mlir::failed(reader.readOptionalAttribute(prop.arg_attrs))) + return ::mlir::failure(); + + if (::mlir::failed(reader.readAttribute(prop.function_type))) + return ::mlir::failure(); + + if (::mlir::failed(reader.readOptionalAttribute(prop.res_attrs))) + return ::mlir::failure(); + + if (::mlir::failed(reader.readAttribute(prop.sym_name))) + return ::mlir::failure(); + return ::mlir::success(); +} + +void FuncOp::writeProperties(::mlir::DialectBytecodeWriter &writer) { + auto &prop = getProperties(); (void)prop; + + writer.writeOptionalAttribute(prop.arg_attrs); + writer.writeAttribute(prop.function_type); + + writer.writeOptionalAttribute(prop.res_attrs); + writer.writeAttribute(prop.sym_name); +} + +::mlir::StringAttr FuncOp::getSymNameAttr() { + return ::llvm::cast<::mlir::StringAttr>(getProperties().sym_name); +} + +::llvm::StringRef FuncOp::getSymName() { + auto attr = getSymNameAttr(); + return attr.getValue(); +} + +::mlir::TypeAttr FuncOp::getFunctionTypeAttr() { + return ::llvm::cast<::mlir::TypeAttr>(getProperties().function_type); +} + +::mlir::FunctionType FuncOp::getFunctionType() { + auto attr = getFunctionTypeAttr(); + return ::llvm::cast<::mlir::FunctionType>(attr.getValue()); +} + +::mlir::ArrayAttr FuncOp::getArgAttrsAttr() { + return ::llvm::dyn_cast_or_null<::mlir::ArrayAttr>(getProperties().arg_attrs); +} + +::std::optional< ::mlir::ArrayAttr > FuncOp::getArgAttrs() { + auto attr = getArgAttrsAttr(); + return attr ? ::std::optional< ::mlir::ArrayAttr >(attr) : (::std::nullopt); +} + +::mlir::ArrayAttr FuncOp::getResAttrsAttr() { + return ::llvm::dyn_cast_or_null<::mlir::ArrayAttr>(getProperties().res_attrs); +} + +::std::optional< ::mlir::ArrayAttr > FuncOp::getResAttrs() { + auto attr = getResAttrsAttr(); + return attr ? ::std::optional< ::mlir::ArrayAttr >(attr) : (::std::nullopt); +} + +void FuncOp::setSymNameAttr(::mlir::StringAttr attr) { + (*this)->setAttr(getSymNameAttrName(), attr); +} + +void FuncOp::setSymName(::llvm::StringRef attrValue) { + (*this)->setAttr(getSymNameAttrName(), ::mlir::Builder((*this)->getContext()).getStringAttr(attrValue)); +} + +void FuncOp::setFunctionTypeAttr(::mlir::TypeAttr attr) { + (*this)->setAttr(getFunctionTypeAttrName(), attr); +} + +void FuncOp::setFunctionType(::mlir::FunctionType attrValue) { + (*this)->setAttr(getFunctionTypeAttrName(), ::mlir::TypeAttr::get(attrValue)); +} + +void FuncOp::setArgAttrsAttr(::mlir::ArrayAttr attr) { + (*this)->setAttr(getArgAttrsAttrName(), attr); +} + +void FuncOp::setResAttrsAttr(::mlir::ArrayAttr attr) { + (*this)->setAttr(getResAttrsAttrName(), attr); +} + +::mlir::Attribute FuncOp::removeArgAttrsAttr() { + auto &attr = getProperties().arg_attrs; + attr = {}; + return attr; +} + +::mlir::Attribute FuncOp::removeResAttrsAttr() { + auto &attr = getProperties().res_attrs; + attr = {}; + return attr; +} + +::mlir::LogicalResult FuncOp::verifyInvariantsImpl() { + auto tblgen_arg_attrs = getProperties().arg_attrs; (void)tblgen_arg_attrs; + auto tblgen_function_type = getProperties().function_type; (void)tblgen_function_type; + if (!tblgen_function_type) return emitOpError("requires attribute 'function_type'"); + auto tblgen_res_attrs = getProperties().res_attrs; (void)tblgen_res_attrs; + auto tblgen_sym_name = getProperties().sym_name; (void)tblgen_sym_name; + if (!tblgen_sym_name) return emitOpError("requires attribute 'sym_name'"); + + if (::mlir::failed(__mlir_ods_local_attr_constraint_Ops1(*this, tblgen_sym_name, "sym_name"))) + return ::mlir::failure(); + + if (::mlir::failed(__mlir_ods_local_attr_constraint_Ops2(*this, tblgen_function_type, "function_type"))) + return ::mlir::failure(); + + if (::mlir::failed(__mlir_ods_local_attr_constraint_Ops3(*this, tblgen_arg_attrs, "arg_attrs"))) + return ::mlir::failure(); + + if (::mlir::failed(__mlir_ods_local_attr_constraint_Ops3(*this, tblgen_res_attrs, "res_attrs"))) + return ::mlir::failure(); + { + unsigned index = 0; (void)index; + + for (auto ®ion : ::llvm::MutableArrayRef((*this)->getRegion(0))) + if (::mlir::failed(__mlir_ods_local_region_constraint_Ops0(*this, region, "body", index++))) + return ::mlir::failure(); + } + return ::mlir::success(); +} + +::mlir::LogicalResult FuncOp::verifyInvariants() { + return verifyInvariantsImpl(); +} + +} // namespace toy +} // namespace mlir +MLIR_DEFINE_EXPLICIT_TYPE_ID(::mlir::toy::FuncOp) + +namespace mlir { +namespace toy { + +//===----------------------------------------------------------------------===// +// ::mlir::toy::GenericCallOp definitions +//===----------------------------------------------------------------------===// + +namespace detail { +GenericCallOpGenericAdaptorBase::GenericCallOpGenericAdaptorBase(::mlir::DictionaryAttr attrs, const Properties &properties, ::mlir::RegionRange regions) : odsAttrs(attrs), properties(properties), odsRegions(regions) { if (odsAttrs) + odsOpName.emplace("toy.generic_call", odsAttrs.getContext()); +} + +GenericCallOpGenericAdaptorBase::GenericCallOpGenericAdaptorBase(GenericCallOp op) : GenericCallOpGenericAdaptorBase(op->getDiscardableAttrDictionary(), op.getProperties(), op->getRegions()) {} + +std::pair GenericCallOpGenericAdaptorBase::getODSOperandIndexAndLength(unsigned index, unsigned odsOperandsSize) { + bool isVariadic[] = {true}; + int prevVariadicCount = 0; + for (unsigned i = 0; i < index; ++i) + if (isVariadic[i]) ++prevVariadicCount; + + // Calculate how many dynamic values a static variadic operand corresponds to. + // This assumes all static variadic operands have the same dynamic value count. + int variadicSize = (odsOperandsSize - 0) / 1; + // `index` passed in as the parameter is the static index which counts each + // operand (variadic or not) as size 1. So here for each previous static variadic + // operand, we need to offset by (variadicSize - 1) to get where the dynamic + // value pack for this static operand starts. + int start = index + (variadicSize - 1) * prevVariadicCount; + int size = isVariadic[index] ? variadicSize : 1; + return {start, size}; +} + +::mlir::DictionaryAttr GenericCallOpGenericAdaptorBase::getAttributes() { + return odsAttrs; +} + +::mlir::FlatSymbolRefAttr GenericCallOpGenericAdaptorBase::getCalleeAttr() { + auto attr = ::llvm::cast<::mlir::FlatSymbolRefAttr>(getProperties().callee); + return attr; +} + +::llvm::StringRef GenericCallOpGenericAdaptorBase::getCallee() { + auto attr = getCalleeAttr(); + return attr.getValue(); +} + +} // namespace detail +GenericCallOpAdaptor::GenericCallOpAdaptor(GenericCallOp op) : GenericCallOpGenericAdaptor(op->getOperands(), op) {} + +::mlir::LogicalResult GenericCallOpAdaptor::verify(::mlir::Location loc) { + auto tblgen_callee = getProperties().callee; (void)tblgen_callee; + if (!tblgen_callee) return emitError(loc, "'toy.generic_call' op ""requires attribute 'callee'"); + + if (tblgen_callee && !((::llvm::isa<::mlir::FlatSymbolRefAttr>(tblgen_callee)))) + return emitError(loc, "'toy.generic_call' op ""attribute 'callee' failed to satisfy constraint: flat symbol reference attribute"); + return ::mlir::success(); +} + +std::pair GenericCallOp::getODSOperandIndexAndLength(unsigned index) { + bool isVariadic[] = {true}; + int prevVariadicCount = 0; + for (unsigned i = 0; i < index; ++i) + if (isVariadic[i]) ++prevVariadicCount; + + // Calculate how many dynamic values a static variadic operand corresponds to. + // This assumes all static variadic operands have the same dynamic value count. + int variadicSize = (getOperation()->getNumOperands() - 0) / 1; + // `index` passed in as the parameter is the static index which counts each + // operand (variadic or not) as size 1. So here for each previous static variadic + // operand, we need to offset by (variadicSize - 1) to get where the dynamic + // value pack for this static operand starts. + int start = index + (variadicSize - 1) * prevVariadicCount; + int size = isVariadic[index] ? variadicSize : 1; + return {start, size}; +} + +::mlir::Operation::operand_range GenericCallOp::getODSOperands(unsigned index) { + auto valueRange = getODSOperandIndexAndLength(index); + return {std::next(getOperation()->operand_begin(), valueRange.first), + std::next(getOperation()->operand_begin(), valueRange.first + valueRange.second)}; +} + +::mlir::Operation::operand_range GenericCallOp::getInputs() { + return getODSOperands(0); +} + +::mlir::MutableOperandRange GenericCallOp::getInputsMutable() { + auto range = getODSOperandIndexAndLength(0); + auto mutableRange = ::mlir::MutableOperandRange(getOperation(), range.first, range.second); + return mutableRange; +} + +std::pair GenericCallOp::getODSResultIndexAndLength(unsigned index) { + return {index, 1}; +} + +::mlir::Operation::result_range GenericCallOp::getODSResults(unsigned index) { + auto valueRange = getODSResultIndexAndLength(index); + return {std::next(getOperation()->result_begin(), valueRange.first), + std::next(getOperation()->result_begin(), valueRange.first + valueRange.second)}; +} + +::mlir::LogicalResult GenericCallOp::setPropertiesFromAttr(Properties &prop, ::mlir::Attribute attr, ::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError) { + ::mlir::DictionaryAttr dict = ::llvm::dyn_cast<::mlir::DictionaryAttr>(attr); + if (!dict) { + emitError() << "expected DictionaryAttr to set properties"; + return ::mlir::failure(); + } + + { + auto &propStorage = prop.callee; + auto attr = dict.get("callee"); + if (attr || /*isRequired=*/true) { + if (!attr) { + emitError() << "expected key entry for callee in DictionaryAttr to set " + "Properties."; + return ::mlir::failure(); + } + auto convertedAttr = ::llvm::dyn_cast>(attr); + if (convertedAttr) { + propStorage = convertedAttr; + } else { + emitError() << "Invalid attribute `callee` in property conversion: " << attr; + return ::mlir::failure(); + } + } + } + return ::mlir::success(); +} + +::mlir::Attribute GenericCallOp::getPropertiesAsAttr(::mlir::MLIRContext *ctx, const Properties &prop) { + ::mlir::SmallVector<::mlir::NamedAttribute> attrs; + ::mlir::Builder odsBuilder{ctx}; + + { + const auto &propStorage = prop.callee; + if (propStorage) + attrs.push_back(odsBuilder.getNamedAttr("callee", + propStorage)); + } + + if (!attrs.empty()) + return odsBuilder.getDictionaryAttr(attrs); + return {}; +} + +llvm::hash_code GenericCallOp::computePropertiesHash(const Properties &prop) { + return llvm::hash_combine( + llvm::hash_value(prop.callee.getAsOpaquePointer())); +} + +std::optional GenericCallOp::getInherentAttr(::mlir::MLIRContext *ctx, const Properties &prop, llvm::StringRef name) { + if (name == "callee") + return prop.callee; + return std::nullopt; +} + +void GenericCallOp::setInherentAttr(Properties &prop, llvm::StringRef name, mlir::Attribute value) { + if (name == "callee") { + prop.callee = ::llvm::dyn_cast_or_null>(value); + return; + } +} + +void GenericCallOp::populateInherentAttrs(::mlir::MLIRContext *ctx, const Properties &prop, ::mlir::NamedAttrList &attrs) { + if (prop.callee) attrs.append("callee", prop.callee); +} + +::mlir::LogicalResult GenericCallOp::verifyInherentAttrs(::mlir::OperationName opName, ::mlir::NamedAttrList &attrs, llvm::function_ref<::mlir::InFlightDiagnostic()> emitError) { + { + ::mlir::Attribute attr = attrs.get(getCalleeAttrName(opName)); + if (attr && ::mlir::failed(__mlir_ods_local_attr_constraint_Ops4(attr, "callee", emitError))) + return ::mlir::failure(); + } + return ::mlir::success(); +} + +::mlir::LogicalResult GenericCallOp::readProperties(::mlir::DialectBytecodeReader &reader, ::mlir::OperationState &state) { + auto &prop = state.getOrAddProperties(); (void)prop; + if (::mlir::failed(reader.readAttribute(prop.callee))) + return ::mlir::failure(); + return ::mlir::success(); +} + +void GenericCallOp::writeProperties(::mlir::DialectBytecodeWriter &writer) { + auto &prop = getProperties(); (void)prop; + writer.writeAttribute(prop.callee); +} + +::mlir::FlatSymbolRefAttr GenericCallOp::getCalleeAttr() { + return ::llvm::cast<::mlir::FlatSymbolRefAttr>(getProperties().callee); +} + +::llvm::StringRef GenericCallOp::getCallee() { + auto attr = getCalleeAttr(); + return attr.getValue(); +} + +void GenericCallOp::setCalleeAttr(::mlir::FlatSymbolRefAttr attr) { + (*this)->setAttr(getCalleeAttrName(), attr); +} + +void GenericCallOp::setCallee(::llvm::StringRef attrValue) { + (*this)->setAttr(getCalleeAttrName(), ::mlir::SymbolRefAttr::get(::mlir::Builder((*this)->getContext()).getContext(), attrValue)); +} + +void GenericCallOp::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type resultType0, ::mlir::FlatSymbolRefAttr callee, ::mlir::ValueRange inputs) { + odsState.addOperands(inputs); + odsState.getOrAddProperties().callee = callee; + odsState.addTypes(resultType0); +} + +void GenericCallOp::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::FlatSymbolRefAttr callee, ::mlir::ValueRange inputs) { + odsState.addOperands(inputs); + odsState.getOrAddProperties().callee = callee; + assert(resultTypes.size() == 1u && "mismatched number of results"); + odsState.addTypes(resultTypes); +} + +void GenericCallOp::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type resultType0, ::llvm::StringRef callee, ::mlir::ValueRange inputs) { + odsState.addOperands(inputs); + odsState.getOrAddProperties().callee = ::mlir::SymbolRefAttr::get(odsBuilder.getContext(), callee); + odsState.addTypes(resultType0); +} + +void GenericCallOp::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::llvm::StringRef callee, ::mlir::ValueRange inputs) { + odsState.addOperands(inputs); + odsState.getOrAddProperties().callee = ::mlir::SymbolRefAttr::get(odsBuilder.getContext(), callee); + assert(resultTypes.size() == 1u && "mismatched number of results"); + odsState.addTypes(resultTypes); +} + +void GenericCallOp::build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes) { + odsState.addOperands(operands); + odsState.addAttributes(attributes); + assert(resultTypes.size() == 1u && "mismatched number of return types"); + odsState.addTypes(resultTypes); +} + +::mlir::LogicalResult GenericCallOp::verifyInvariantsImpl() { + auto tblgen_callee = getProperties().callee; (void)tblgen_callee; + if (!tblgen_callee) return emitOpError("requires attribute 'callee'"); + + if (::mlir::failed(__mlir_ods_local_attr_constraint_Ops4(*this, tblgen_callee, "callee"))) + return ::mlir::failure(); + { + unsigned index = 0; (void)index; + auto valueGroup0 = getODSOperands(0); + + for (auto v : valueGroup0) { + if (::mlir::failed(__mlir_ods_local_type_constraint_Ops1(*this, v.getType(), "operand", index++))) + return ::mlir::failure(); + } + } + { + unsigned index = 0; (void)index; + auto valueGroup0 = getODSResults(0); + + for (auto v : valueGroup0) { + if (::mlir::failed(__mlir_ods_local_type_constraint_Ops0(*this, v.getType(), "result", index++))) + return ::mlir::failure(); + } + } + return ::mlir::success(); +} + +::mlir::LogicalResult GenericCallOp::verifyInvariants() { + return verifyInvariantsImpl(); +} + +::mlir::ParseResult GenericCallOp::parse(::mlir::OpAsmParser &parser, ::mlir::OperationState &result) { + ::mlir::FlatSymbolRefAttr calleeAttr; + ::llvm::SmallVector<::mlir::OpAsmParser::UnresolvedOperand, 4> inputsOperands; + ::llvm::SMLoc inputsOperandsLoc; + (void)inputsOperandsLoc; + ::llvm::ArrayRef<::mlir::Type> inputsTypes; + ::llvm::ArrayRef<::mlir::Type> allResultTypes; + + if (parser.parseCustomAttributeWithFallback(calleeAttr, parser.getBuilder().getType<::mlir::NoneType>())) { + return ::mlir::failure(); + } + if (calleeAttr) result.getOrAddProperties().callee = calleeAttr; + if (parser.parseLParen()) + return ::mlir::failure(); + + inputsOperandsLoc = parser.getCurrentLocation(); + if (parser.parseOperandList(inputsOperands)) + return ::mlir::failure(); + if (parser.parseRParen()) + return ::mlir::failure(); + { + auto loc = parser.getCurrentLocation();(void)loc; + if (parser.parseOptionalAttrDict(result.attributes)) + return ::mlir::failure(); + if (failed(verifyInherentAttrs(result.name, result.attributes, [&]() { + return parser.emitError(loc) << "'" << result.name.getStringRef() << "' op "; + }))) + return ::mlir::failure(); + } + if (parser.parseColon()) + return ::mlir::failure(); + + ::mlir::FunctionType inputs__allResult_functionType; + if (parser.parseType(inputs__allResult_functionType)) + return ::mlir::failure(); + inputsTypes = inputs__allResult_functionType.getInputs(); + allResultTypes = inputs__allResult_functionType.getResults(); + result.addTypes(allResultTypes); + if (parser.resolveOperands(inputsOperands, inputsTypes, inputsOperandsLoc, result.operands)) + return ::mlir::failure(); + return ::mlir::success(); +} + +void GenericCallOp::print(::mlir::OpAsmPrinter &_odsPrinter) { + _odsPrinter << ' '; + _odsPrinter.printAttributeWithoutType(getCalleeAttr()); + _odsPrinter << "("; + _odsPrinter << getInputs(); + _odsPrinter << ")"; + ::llvm::SmallVector<::llvm::StringRef, 2> elidedAttrs; + elidedAttrs.push_back("callee"); + _odsPrinter.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs); + _odsPrinter << ' ' << ":"; + _odsPrinter << ' '; + _odsPrinter.printFunctionalType(getInputs().getTypes(), getOperation()->getResultTypes()); +} + +} // namespace toy +} // namespace mlir +MLIR_DEFINE_EXPLICIT_TYPE_ID(::mlir::toy::GenericCallOp) + +namespace mlir { +namespace toy { + +//===----------------------------------------------------------------------===// +// ::mlir::toy::MulOp definitions +//===----------------------------------------------------------------------===// + +namespace detail { +MulOpGenericAdaptorBase::MulOpGenericAdaptorBase(::mlir::DictionaryAttr attrs, const ::mlir::EmptyProperties &properties, ::mlir::RegionRange regions) : odsAttrs(attrs), odsRegions(regions) { if (odsAttrs) + odsOpName.emplace("toy.mul", odsAttrs.getContext()); +} + +MulOpGenericAdaptorBase::MulOpGenericAdaptorBase(MulOp op) : MulOpGenericAdaptorBase(op->getAttrDictionary(), op.getProperties(), op->getRegions()) {} + +std::pair MulOpGenericAdaptorBase::getODSOperandIndexAndLength(unsigned index, unsigned odsOperandsSize) { + return {index, 1}; +} + +::mlir::DictionaryAttr MulOpGenericAdaptorBase::getAttributes() { + return odsAttrs; +} + +} // namespace detail +MulOpAdaptor::MulOpAdaptor(MulOp op) : MulOpGenericAdaptor(op->getOperands(), op) {} + +::mlir::LogicalResult MulOpAdaptor::verify(::mlir::Location loc) { + return ::mlir::success(); +} + +std::pair MulOp::getODSOperandIndexAndLength(unsigned index) { + return {index, 1}; +} + +::mlir::Operation::operand_range MulOp::getODSOperands(unsigned index) { + auto valueRange = getODSOperandIndexAndLength(index); + return {std::next(getOperation()->operand_begin(), valueRange.first), + std::next(getOperation()->operand_begin(), valueRange.first + valueRange.second)}; +} + +::mlir::TypedValue<::mlir::TensorType> MulOp::getLhs() { + return ::llvm::cast<::mlir::TypedValue<::mlir::TensorType>>(*getODSOperands(0).begin()); +} + +::mlir::TypedValue<::mlir::TensorType> MulOp::getRhs() { + return ::llvm::cast<::mlir::TypedValue<::mlir::TensorType>>(*getODSOperands(1).begin()); +} + +::mlir::OpOperand &MulOp::getLhsMutable() { + auto range = getODSOperandIndexAndLength(0); + return getOperation()->getOpOperand(range.first); +} + +::mlir::OpOperand &MulOp::getRhsMutable() { + auto range = getODSOperandIndexAndLength(1); + return getOperation()->getOpOperand(range.first); +} + +std::pair MulOp::getODSResultIndexAndLength(unsigned index) { + return {index, 1}; +} + +::mlir::Operation::result_range MulOp::getODSResults(unsigned index) { + auto valueRange = getODSResultIndexAndLength(index); + return {std::next(getOperation()->result_begin(), valueRange.first), + std::next(getOperation()->result_begin(), valueRange.first + valueRange.second)}; +} + +void MulOp::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type resultType0, ::mlir::Value lhs, ::mlir::Value rhs) { + odsState.addOperands(lhs); + odsState.addOperands(rhs); + odsState.addTypes(resultType0); +} + +void MulOp::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::Value lhs, ::mlir::Value rhs) { + odsState.addOperands(lhs); + odsState.addOperands(rhs); + assert(resultTypes.size() == 1u && "mismatched number of results"); + odsState.addTypes(resultTypes); +} + +void MulOp::build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes) { + assert(operands.size() == 2u && "mismatched number of parameters"); + odsState.addOperands(operands); + odsState.addAttributes(attributes); + assert(resultTypes.size() == 1u && "mismatched number of return types"); + odsState.addTypes(resultTypes); +} + +::mlir::LogicalResult MulOp::verifyInvariantsImpl() { + { + unsigned index = 0; (void)index; + auto valueGroup0 = getODSOperands(0); + + for (auto v : valueGroup0) { + if (::mlir::failed(__mlir_ods_local_type_constraint_Ops0(*this, v.getType(), "operand", index++))) + return ::mlir::failure(); + } + auto valueGroup1 = getODSOperands(1); + + for (auto v : valueGroup1) { + if (::mlir::failed(__mlir_ods_local_type_constraint_Ops0(*this, v.getType(), "operand", index++))) + return ::mlir::failure(); + } + } + { + unsigned index = 0; (void)index; + auto valueGroup0 = getODSResults(0); + + for (auto v : valueGroup0) { + if (::mlir::failed(__mlir_ods_local_type_constraint_Ops0(*this, v.getType(), "result", index++))) + return ::mlir::failure(); + } + } + return ::mlir::success(); +} + +::mlir::LogicalResult MulOp::verifyInvariants() { + return verifyInvariantsImpl(); +} + +void MulOp::getEffects(::llvm::SmallVectorImpl<::mlir::SideEffects::EffectInstance<::mlir::MemoryEffects::Effect>> &effects) { +} + +} // namespace toy +} // namespace mlir +MLIR_DEFINE_EXPLICIT_TYPE_ID(::mlir::toy::MulOp) + +namespace mlir { +namespace toy { + +//===----------------------------------------------------------------------===// +// ::mlir::toy::PrintOp definitions +//===----------------------------------------------------------------------===// + +namespace detail { +PrintOpGenericAdaptorBase::PrintOpGenericAdaptorBase(::mlir::DictionaryAttr attrs, const ::mlir::EmptyProperties &properties, ::mlir::RegionRange regions) : odsAttrs(attrs), odsRegions(regions) { if (odsAttrs) + odsOpName.emplace("toy.print", odsAttrs.getContext()); +} + +PrintOpGenericAdaptorBase::PrintOpGenericAdaptorBase(PrintOp op) : PrintOpGenericAdaptorBase(op->getAttrDictionary(), op.getProperties(), op->getRegions()) {} + +std::pair PrintOpGenericAdaptorBase::getODSOperandIndexAndLength(unsigned index, unsigned odsOperandsSize) { + return {index, 1}; +} + +::mlir::DictionaryAttr PrintOpGenericAdaptorBase::getAttributes() { + return odsAttrs; +} + +} // namespace detail +PrintOpAdaptor::PrintOpAdaptor(PrintOp op) : PrintOpGenericAdaptor(op->getOperands(), op) {} + +::mlir::LogicalResult PrintOpAdaptor::verify(::mlir::Location loc) { + return ::mlir::success(); +} + +std::pair PrintOp::getODSOperandIndexAndLength(unsigned index) { + return {index, 1}; +} + +::mlir::Operation::operand_range PrintOp::getODSOperands(unsigned index) { + auto valueRange = getODSOperandIndexAndLength(index); + return {std::next(getOperation()->operand_begin(), valueRange.first), + std::next(getOperation()->operand_begin(), valueRange.first + valueRange.second)}; +} + +::mlir::TypedValue<::mlir::TensorType> PrintOp::getInput() { + return ::llvm::cast<::mlir::TypedValue<::mlir::TensorType>>(*getODSOperands(0).begin()); +} + +::mlir::OpOperand &PrintOp::getInputMutable() { + auto range = getODSOperandIndexAndLength(0); + return getOperation()->getOpOperand(range.first); +} + +std::pair PrintOp::getODSResultIndexAndLength(unsigned index) { + return {index, 1}; +} + +::mlir::Operation::result_range PrintOp::getODSResults(unsigned index) { + auto valueRange = getODSResultIndexAndLength(index); + return {std::next(getOperation()->result_begin(), valueRange.first), + std::next(getOperation()->result_begin(), valueRange.first + valueRange.second)}; +} + +void PrintOp::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Value input) { + odsState.addOperands(input); +} + +void PrintOp::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::Value input) { + odsState.addOperands(input); + assert(resultTypes.size() == 0u && "mismatched number of results"); + odsState.addTypes(resultTypes); +} + +void PrintOp::build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes) { + assert(operands.size() == 1u && "mismatched number of parameters"); + odsState.addOperands(operands); + odsState.addAttributes(attributes); + assert(resultTypes.size() == 0u && "mismatched number of return types"); + odsState.addTypes(resultTypes); +} + +::mlir::LogicalResult PrintOp::verifyInvariantsImpl() { + { + unsigned index = 0; (void)index; + auto valueGroup0 = getODSOperands(0); + + for (auto v : valueGroup0) { + if (::mlir::failed(__mlir_ods_local_type_constraint_Ops0(*this, v.getType(), "operand", index++))) + return ::mlir::failure(); + } + } + return ::mlir::success(); +} + +::mlir::LogicalResult PrintOp::verifyInvariants() { + return verifyInvariantsImpl(); +} + +::mlir::ParseResult PrintOp::parse(::mlir::OpAsmParser &parser, ::mlir::OperationState &result) { + ::mlir::OpAsmParser::UnresolvedOperand inputRawOperands[1]; + ::llvm::ArrayRef<::mlir::OpAsmParser::UnresolvedOperand> inputOperands(inputRawOperands); ::llvm::SMLoc inputOperandsLoc; + (void)inputOperandsLoc; + ::mlir::Type inputRawTypes[1]; + ::llvm::ArrayRef<::mlir::Type> inputTypes(inputRawTypes); + + inputOperandsLoc = parser.getCurrentLocation(); + if (parser.parseOperand(inputRawOperands[0])) + return ::mlir::failure(); + { + auto loc = parser.getCurrentLocation();(void)loc; + if (parser.parseOptionalAttrDict(result.attributes)) + return ::mlir::failure(); + } + if (parser.parseColon()) + return ::mlir::failure(); + + { + ::mlir::TensorType type; + if (parser.parseCustomTypeWithFallback(type)) + return ::mlir::failure(); + inputRawTypes[0] = type; + } + if (parser.resolveOperands(inputOperands, inputTypes, inputOperandsLoc, result.operands)) + return ::mlir::failure(); + return ::mlir::success(); +} + +void PrintOp::print(::mlir::OpAsmPrinter &_odsPrinter) { + _odsPrinter << ' '; + _odsPrinter << getInput(); + ::llvm::SmallVector<::llvm::StringRef, 2> elidedAttrs; + _odsPrinter.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs); + _odsPrinter << ' ' << ":"; + _odsPrinter << ' '; + { + auto type = getInput().getType(); + if (auto validType = ::llvm::dyn_cast<::mlir::TensorType>(type)) + _odsPrinter.printStrippedAttrOrType(validType); + else + _odsPrinter << type; + } +} + +} // namespace toy +} // namespace mlir +MLIR_DEFINE_EXPLICIT_TYPE_ID(::mlir::toy::PrintOp) + +namespace mlir { +namespace toy { + +//===----------------------------------------------------------------------===// +// ::mlir::toy::ReshapeOp definitions +//===----------------------------------------------------------------------===// + +namespace detail { +ReshapeOpGenericAdaptorBase::ReshapeOpGenericAdaptorBase(::mlir::DictionaryAttr attrs, const ::mlir::EmptyProperties &properties, ::mlir::RegionRange regions) : odsAttrs(attrs), odsRegions(regions) { if (odsAttrs) + odsOpName.emplace("toy.reshape", odsAttrs.getContext()); +} + +ReshapeOpGenericAdaptorBase::ReshapeOpGenericAdaptorBase(ReshapeOp op) : ReshapeOpGenericAdaptorBase(op->getAttrDictionary(), op.getProperties(), op->getRegions()) {} + +std::pair ReshapeOpGenericAdaptorBase::getODSOperandIndexAndLength(unsigned index, unsigned odsOperandsSize) { + return {index, 1}; +} + +::mlir::DictionaryAttr ReshapeOpGenericAdaptorBase::getAttributes() { + return odsAttrs; +} + +} // namespace detail +ReshapeOpAdaptor::ReshapeOpAdaptor(ReshapeOp op) : ReshapeOpGenericAdaptor(op->getOperands(), op) {} + +::mlir::LogicalResult ReshapeOpAdaptor::verify(::mlir::Location loc) { + return ::mlir::success(); +} + +std::pair ReshapeOp::getODSOperandIndexAndLength(unsigned index) { + return {index, 1}; +} + +::mlir::Operation::operand_range ReshapeOp::getODSOperands(unsigned index) { + auto valueRange = getODSOperandIndexAndLength(index); + return {std::next(getOperation()->operand_begin(), valueRange.first), + std::next(getOperation()->operand_begin(), valueRange.first + valueRange.second)}; +} + +::mlir::TypedValue<::mlir::TensorType> ReshapeOp::getInput() { + return ::llvm::cast<::mlir::TypedValue<::mlir::TensorType>>(*getODSOperands(0).begin()); +} + +::mlir::OpOperand &ReshapeOp::getInputMutable() { + auto range = getODSOperandIndexAndLength(0); + return getOperation()->getOpOperand(range.first); +} + +std::pair ReshapeOp::getODSResultIndexAndLength(unsigned index) { + return {index, 1}; +} + +::mlir::Operation::result_range ReshapeOp::getODSResults(unsigned index) { + auto valueRange = getODSResultIndexAndLength(index); + return {std::next(getOperation()->result_begin(), valueRange.first), + std::next(getOperation()->result_begin(), valueRange.first + valueRange.second)}; +} + +void ReshapeOp::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type resultType0, ::mlir::Value input) { + odsState.addOperands(input); + odsState.addTypes(resultType0); +} + +void ReshapeOp::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::Value input) { + odsState.addOperands(input); + assert(resultTypes.size() == 1u && "mismatched number of results"); + odsState.addTypes(resultTypes); +} + +void ReshapeOp::build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes) { + assert(operands.size() == 1u && "mismatched number of parameters"); + odsState.addOperands(operands); + odsState.addAttributes(attributes); + assert(resultTypes.size() == 1u && "mismatched number of return types"); + odsState.addTypes(resultTypes); +} + +::mlir::LogicalResult ReshapeOp::verifyInvariantsImpl() { + { + unsigned index = 0; (void)index; + auto valueGroup0 = getODSOperands(0); + + for (auto v : valueGroup0) { + if (::mlir::failed(__mlir_ods_local_type_constraint_Ops0(*this, v.getType(), "operand", index++))) + return ::mlir::failure(); + } + } + { + unsigned index = 0; (void)index; + auto valueGroup0 = getODSResults(0); + + for (auto v : valueGroup0) { + if (::mlir::failed(__mlir_ods_local_type_constraint_Ops2(*this, v.getType(), "result", index++))) + return ::mlir::failure(); + } + } + return ::mlir::success(); +} + +::mlir::LogicalResult ReshapeOp::verifyInvariants() { + return verifyInvariantsImpl(); +} + +::mlir::ParseResult ReshapeOp::parse(::mlir::OpAsmParser &parser, ::mlir::OperationState &result) { + ::mlir::OpAsmParser::UnresolvedOperand inputRawOperands[1]; + ::llvm::ArrayRef<::mlir::OpAsmParser::UnresolvedOperand> inputOperands(inputRawOperands); ::llvm::SMLoc inputOperandsLoc; + (void)inputOperandsLoc; + ::mlir::Type inputRawTypes[1]; + ::llvm::ArrayRef<::mlir::Type> inputTypes(inputRawTypes); + ::llvm::SmallVector<::mlir::Type, 1> allResultTypes; + if (parser.parseLParen()) + return ::mlir::failure(); + + inputOperandsLoc = parser.getCurrentLocation(); + if (parser.parseOperand(inputRawOperands[0])) + return ::mlir::failure(); + if (parser.parseColon()) + return ::mlir::failure(); + + { + ::mlir::TensorType type; + if (parser.parseCustomTypeWithFallback(type)) + return ::mlir::failure(); + inputRawTypes[0] = type; + } + if (parser.parseRParen()) + return ::mlir::failure(); + { + auto loc = parser.getCurrentLocation();(void)loc; + if (parser.parseOptionalAttrDict(result.attributes)) + return ::mlir::failure(); + } + if (parser.parseKeyword("to")) + return ::mlir::failure(); + + if (parser.parseTypeList(allResultTypes)) + return ::mlir::failure(); + result.addTypes(allResultTypes); + if (parser.resolveOperands(inputOperands, inputTypes, inputOperandsLoc, result.operands)) + return ::mlir::failure(); + return ::mlir::success(); +} + +void ReshapeOp::print(::mlir::OpAsmPrinter &_odsPrinter) { + _odsPrinter << "("; + _odsPrinter << getInput(); + _odsPrinter << ' ' << ":"; + _odsPrinter << ' '; + { + auto type = getInput().getType(); + if (auto validType = ::llvm::dyn_cast<::mlir::TensorType>(type)) + _odsPrinter.printStrippedAttrOrType(validType); + else + _odsPrinter << type; + } + _odsPrinter << ")"; + ::llvm::SmallVector<::llvm::StringRef, 2> elidedAttrs; + _odsPrinter.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs); + _odsPrinter << ' ' << "to"; + _odsPrinter << ' '; + _odsPrinter << getOperation()->getResultTypes(); +} + +void ReshapeOp::getEffects(::llvm::SmallVectorImpl<::mlir::SideEffects::EffectInstance<::mlir::MemoryEffects::Effect>> &effects) { +} + +} // namespace toy +} // namespace mlir +MLIR_DEFINE_EXPLICIT_TYPE_ID(::mlir::toy::ReshapeOp) + +namespace mlir { +namespace toy { + +//===----------------------------------------------------------------------===// +// ::mlir::toy::ReturnOp definitions +//===----------------------------------------------------------------------===// + +namespace detail { +ReturnOpGenericAdaptorBase::ReturnOpGenericAdaptorBase(::mlir::DictionaryAttr attrs, const ::mlir::EmptyProperties &properties, ::mlir::RegionRange regions) : odsAttrs(attrs), odsRegions(regions) { if (odsAttrs) + odsOpName.emplace("toy.return", odsAttrs.getContext()); +} + +ReturnOpGenericAdaptorBase::ReturnOpGenericAdaptorBase(ReturnOp op) : ReturnOpGenericAdaptorBase(op->getAttrDictionary(), op.getProperties(), op->getRegions()) {} + +std::pair ReturnOpGenericAdaptorBase::getODSOperandIndexAndLength(unsigned index, unsigned odsOperandsSize) { + bool isVariadic[] = {true}; + int prevVariadicCount = 0; + for (unsigned i = 0; i < index; ++i) + if (isVariadic[i]) ++prevVariadicCount; + + // Calculate how many dynamic values a static variadic operand corresponds to. + // This assumes all static variadic operands have the same dynamic value count. + int variadicSize = (odsOperandsSize - 0) / 1; + // `index` passed in as the parameter is the static index which counts each + // operand (variadic or not) as size 1. So here for each previous static variadic + // operand, we need to offset by (variadicSize - 1) to get where the dynamic + // value pack for this static operand starts. + int start = index + (variadicSize - 1) * prevVariadicCount; + int size = isVariadic[index] ? variadicSize : 1; + return {start, size}; +} + +::mlir::DictionaryAttr ReturnOpGenericAdaptorBase::getAttributes() { + return odsAttrs; +} + +} // namespace detail +ReturnOpAdaptor::ReturnOpAdaptor(ReturnOp op) : ReturnOpGenericAdaptor(op->getOperands(), op) {} + +::mlir::LogicalResult ReturnOpAdaptor::verify(::mlir::Location loc) { + return ::mlir::success(); +} + +std::pair ReturnOp::getODSOperandIndexAndLength(unsigned index) { + bool isVariadic[] = {true}; + int prevVariadicCount = 0; + for (unsigned i = 0; i < index; ++i) + if (isVariadic[i]) ++prevVariadicCount; + + // Calculate how many dynamic values a static variadic operand corresponds to. + // This assumes all static variadic operands have the same dynamic value count. + int variadicSize = (getOperation()->getNumOperands() - 0) / 1; + // `index` passed in as the parameter is the static index which counts each + // operand (variadic or not) as size 1. So here for each previous static variadic + // operand, we need to offset by (variadicSize - 1) to get where the dynamic + // value pack for this static operand starts. + int start = index + (variadicSize - 1) * prevVariadicCount; + int size = isVariadic[index] ? variadicSize : 1; + return {start, size}; +} + +::mlir::Operation::operand_range ReturnOp::getODSOperands(unsigned index) { + auto valueRange = getODSOperandIndexAndLength(index); + return {std::next(getOperation()->operand_begin(), valueRange.first), + std::next(getOperation()->operand_begin(), valueRange.first + valueRange.second)}; +} + +::mlir::Operation::operand_range ReturnOp::getInput() { + return getODSOperands(0); +} + +::mlir::MutableOperandRange ReturnOp::getInputMutable() { + auto range = getODSOperandIndexAndLength(0); + auto mutableRange = ::mlir::MutableOperandRange(getOperation(), range.first, range.second); + return mutableRange; +} + +std::pair ReturnOp::getODSResultIndexAndLength(unsigned index) { + return {index, 1}; +} + +::mlir::Operation::result_range ReturnOp::getODSResults(unsigned index) { + auto valueRange = getODSResultIndexAndLength(index); + return {std::next(getOperation()->result_begin(), valueRange.first), + std::next(getOperation()->result_begin(), valueRange.first + valueRange.second)}; +} + +void ReturnOp::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState) { + build(odsBuilder, odsState, std::nullopt); +} + +void ReturnOp::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::ValueRange input) { + odsState.addOperands(input); +} + +void ReturnOp::build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes) { + odsState.addOperands(operands); + odsState.addAttributes(attributes); + assert(resultTypes.size() == 0u && "mismatched number of return types"); + odsState.addTypes(resultTypes); +} + +::mlir::LogicalResult ReturnOp::verifyInvariantsImpl() { + { + unsigned index = 0; (void)index; + auto valueGroup0 = getODSOperands(0); + + for (auto v : valueGroup0) { + if (::mlir::failed(__mlir_ods_local_type_constraint_Ops1(*this, v.getType(), "operand", index++))) + return ::mlir::failure(); + } + } + return ::mlir::success(); +} + +::mlir::LogicalResult ReturnOp::verifyInvariants() { + if(::mlir::succeeded(verifyInvariantsImpl()) && ::mlir::succeeded(verify())) + return ::mlir::success(); + return ::mlir::failure(); +} + +::mlir::ParseResult ReturnOp::parse(::mlir::OpAsmParser &parser, ::mlir::OperationState &result) { + ::llvm::SmallVector<::mlir::OpAsmParser::UnresolvedOperand, 4> inputOperands; + ::llvm::SMLoc inputOperandsLoc; + (void)inputOperandsLoc; + ::llvm::SmallVector<::mlir::Type, 1> inputTypes; + + inputOperandsLoc = parser.getCurrentLocation(); + if (parser.parseOperandList(inputOperands)) + return ::mlir::failure(); + if (!inputOperands.empty()) { + if (parser.parseColon()) + return ::mlir::failure(); + + if (parser.parseTypeList(inputTypes)) + return ::mlir::failure(); + } + { + auto loc = parser.getCurrentLocation();(void)loc; + if (parser.parseOptionalAttrDict(result.attributes)) + return ::mlir::failure(); + } + if (parser.resolveOperands(inputOperands, inputTypes, inputOperandsLoc, result.operands)) + return ::mlir::failure(); + return ::mlir::success(); +} + +void ReturnOp::print(::mlir::OpAsmPrinter &_odsPrinter) { + if (!getInput().empty()) { + _odsPrinter << ' '; + _odsPrinter << getInput(); + _odsPrinter << ' ' << ":"; + _odsPrinter << ' '; + _odsPrinter << getInput().getTypes(); + } + ::llvm::SmallVector<::llvm::StringRef, 2> elidedAttrs; + _odsPrinter.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs); +} + +void ReturnOp::getEffects(::llvm::SmallVectorImpl<::mlir::SideEffects::EffectInstance<::mlir::MemoryEffects::Effect>> &effects) { +} + +} // namespace toy +} // namespace mlir +MLIR_DEFINE_EXPLICIT_TYPE_ID(::mlir::toy::ReturnOp) + +namespace mlir { +namespace toy { + +//===----------------------------------------------------------------------===// +// ::mlir::toy::TransposeOp definitions +//===----------------------------------------------------------------------===// + +namespace detail { +TransposeOpGenericAdaptorBase::TransposeOpGenericAdaptorBase(::mlir::DictionaryAttr attrs, const ::mlir::EmptyProperties &properties, ::mlir::RegionRange regions) : odsAttrs(attrs), odsRegions(regions) { if (odsAttrs) + odsOpName.emplace("toy.transpose", odsAttrs.getContext()); +} + +TransposeOpGenericAdaptorBase::TransposeOpGenericAdaptorBase(TransposeOp op) : TransposeOpGenericAdaptorBase(op->getAttrDictionary(), op.getProperties(), op->getRegions()) {} + +std::pair TransposeOpGenericAdaptorBase::getODSOperandIndexAndLength(unsigned index, unsigned odsOperandsSize) { + return {index, 1}; +} + +::mlir::DictionaryAttr TransposeOpGenericAdaptorBase::getAttributes() { + return odsAttrs; +} + +} // namespace detail +TransposeOpAdaptor::TransposeOpAdaptor(TransposeOp op) : TransposeOpGenericAdaptor(op->getOperands(), op) {} + +::mlir::LogicalResult TransposeOpAdaptor::verify(::mlir::Location loc) { + return ::mlir::success(); +} + +std::pair TransposeOp::getODSOperandIndexAndLength(unsigned index) { + return {index, 1}; +} + +::mlir::Operation::operand_range TransposeOp::getODSOperands(unsigned index) { + auto valueRange = getODSOperandIndexAndLength(index); + return {std::next(getOperation()->operand_begin(), valueRange.first), + std::next(getOperation()->operand_begin(), valueRange.first + valueRange.second)}; +} + +::mlir::TypedValue<::mlir::TensorType> TransposeOp::getInput() { + return ::llvm::cast<::mlir::TypedValue<::mlir::TensorType>>(*getODSOperands(0).begin()); +} + +::mlir::OpOperand &TransposeOp::getInputMutable() { + auto range = getODSOperandIndexAndLength(0); + return getOperation()->getOpOperand(range.first); +} + +std::pair TransposeOp::getODSResultIndexAndLength(unsigned index) { + return {index, 1}; +} + +::mlir::Operation::result_range TransposeOp::getODSResults(unsigned index) { + auto valueRange = getODSResultIndexAndLength(index); + return {std::next(getOperation()->result_begin(), valueRange.first), + std::next(getOperation()->result_begin(), valueRange.first + valueRange.second)}; +} + +void TransposeOp::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type resultType0, ::mlir::Value input) { + odsState.addOperands(input); + odsState.addTypes(resultType0); +} + +void TransposeOp::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::Value input) { + odsState.addOperands(input); + assert(resultTypes.size() == 1u && "mismatched number of results"); + odsState.addTypes(resultTypes); +} + +void TransposeOp::build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes) { + assert(operands.size() == 1u && "mismatched number of parameters"); + odsState.addOperands(operands); + odsState.addAttributes(attributes); + assert(resultTypes.size() == 1u && "mismatched number of return types"); + odsState.addTypes(resultTypes); +} + +::mlir::LogicalResult TransposeOp::verifyInvariantsImpl() { + { + unsigned index = 0; (void)index; + auto valueGroup0 = getODSOperands(0); + + for (auto v : valueGroup0) { + if (::mlir::failed(__mlir_ods_local_type_constraint_Ops0(*this, v.getType(), "operand", index++))) + return ::mlir::failure(); + } + } + { + unsigned index = 0; (void)index; + auto valueGroup0 = getODSResults(0); + + for (auto v : valueGroup0) { + if (::mlir::failed(__mlir_ods_local_type_constraint_Ops0(*this, v.getType(), "result", index++))) + return ::mlir::failure(); + } + } + return ::mlir::success(); +} + +::mlir::LogicalResult TransposeOp::verifyInvariants() { + if(::mlir::succeeded(verifyInvariantsImpl()) && ::mlir::succeeded(verify())) + return ::mlir::success(); + return ::mlir::failure(); +} + +::mlir::ParseResult TransposeOp::parse(::mlir::OpAsmParser &parser, ::mlir::OperationState &result) { + ::mlir::OpAsmParser::UnresolvedOperand inputRawOperands[1]; + ::llvm::ArrayRef<::mlir::OpAsmParser::UnresolvedOperand> inputOperands(inputRawOperands); ::llvm::SMLoc inputOperandsLoc; + (void)inputOperandsLoc; + ::mlir::Type inputRawTypes[1]; + ::llvm::ArrayRef<::mlir::Type> inputTypes(inputRawTypes); + ::llvm::SmallVector<::mlir::Type, 1> allResultTypes; + if (parser.parseLParen()) + return ::mlir::failure(); + + inputOperandsLoc = parser.getCurrentLocation(); + if (parser.parseOperand(inputRawOperands[0])) + return ::mlir::failure(); + if (parser.parseColon()) + return ::mlir::failure(); + + { + ::mlir::TensorType type; + if (parser.parseCustomTypeWithFallback(type)) + return ::mlir::failure(); + inputRawTypes[0] = type; + } + if (parser.parseRParen()) + return ::mlir::failure(); + { + auto loc = parser.getCurrentLocation();(void)loc; + if (parser.parseOptionalAttrDict(result.attributes)) + return ::mlir::failure(); + } + if (parser.parseKeyword("to")) + return ::mlir::failure(); + + if (parser.parseTypeList(allResultTypes)) + return ::mlir::failure(); + result.addTypes(allResultTypes); + if (parser.resolveOperands(inputOperands, inputTypes, inputOperandsLoc, result.operands)) + return ::mlir::failure(); + return ::mlir::success(); +} + +void TransposeOp::print(::mlir::OpAsmPrinter &_odsPrinter) { + _odsPrinter << "("; + _odsPrinter << getInput(); + _odsPrinter << ' ' << ":"; + _odsPrinter << ' '; + { + auto type = getInput().getType(); + if (auto validType = ::llvm::dyn_cast<::mlir::TensorType>(type)) + _odsPrinter.printStrippedAttrOrType(validType); + else + _odsPrinter << type; + } + _odsPrinter << ")"; + ::llvm::SmallVector<::llvm::StringRef, 2> elidedAttrs; + _odsPrinter.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs); + _odsPrinter << ' ' << "to"; + _odsPrinter << ' '; + _odsPrinter << getOperation()->getResultTypes(); +} + +void TransposeOp::getEffects(::llvm::SmallVectorImpl<::mlir::SideEffects::EffectInstance<::mlir::MemoryEffects::Effect>> &effects) { +} + +} // namespace toy +} // namespace mlir +MLIR_DEFINE_EXPLICIT_TYPE_ID(::mlir::toy::TransposeOp) + + +#endif // GET_OP_CLASSES + diff --git a/Ch4/include/toy/Ops.h.inc b/Ch4/include/toy/Ops.h.inc new file mode 100644 index 0000000..ad68406 --- /dev/null +++ b/Ch4/include/toy/Ops.h.inc @@ -0,0 +1,1360 @@ +/*===- TableGen'erated file -------------------------------------*- C++ -*-===*\ +|* *| +|* Op Declarations *| +|* *| +|* Automatically generated file, do not edit! *| +|* From: Ops.td *| +|* *| +\*===----------------------------------------------------------------------===*/ + +#if defined(GET_OP_CLASSES) || defined(GET_OP_FWD_DEFINES) +#undef GET_OP_FWD_DEFINES +namespace mlir { +namespace toy { +class AddOp; +} // namespace toy +} // namespace mlir +namespace mlir { +namespace toy { +class CastOp; +} // namespace toy +} // namespace mlir +namespace mlir { +namespace toy { +class ConstantOp; +} // namespace toy +} // namespace mlir +namespace mlir { +namespace toy { +class FuncOp; +} // namespace toy +} // namespace mlir +namespace mlir { +namespace toy { +class GenericCallOp; +} // namespace toy +} // namespace mlir +namespace mlir { +namespace toy { +class MulOp; +} // namespace toy +} // namespace mlir +namespace mlir { +namespace toy { +class PrintOp; +} // namespace toy +} // namespace mlir +namespace mlir { +namespace toy { +class ReshapeOp; +} // namespace toy +} // namespace mlir +namespace mlir { +namespace toy { +class ReturnOp; +} // namespace toy +} // namespace mlir +namespace mlir { +namespace toy { +class TransposeOp; +} // namespace toy +} // namespace mlir +#endif + +#ifdef GET_OP_CLASSES +#undef GET_OP_CLASSES + + +//===----------------------------------------------------------------------===// +// Local Utility Method Definitions +//===----------------------------------------------------------------------===// + +namespace mlir { +namespace toy { + +//===----------------------------------------------------------------------===// +// ::mlir::toy::AddOp declarations +//===----------------------------------------------------------------------===// + +namespace detail { +class AddOpGenericAdaptorBase { +public: +protected: + ::mlir::DictionaryAttr odsAttrs; + ::std::optional<::mlir::OperationName> odsOpName; + ::mlir::RegionRange odsRegions; +public: + AddOpGenericAdaptorBase(::mlir::DictionaryAttr attrs = nullptr, const ::mlir::EmptyProperties &properties = {}, ::mlir::RegionRange regions = {}); + + AddOpGenericAdaptorBase(AddOp op); + + std::pair getODSOperandIndexAndLength(unsigned index, unsigned odsOperandsSize); + ::mlir::DictionaryAttr getAttributes(); +}; +} // namespace detail +template +class AddOpGenericAdaptor : public detail::AddOpGenericAdaptorBase { + using ValueT = ::llvm::detail::ValueOfRange; + using Base = detail::AddOpGenericAdaptorBase; +public: + AddOpGenericAdaptor(RangeT values, ::mlir::DictionaryAttr attrs = nullptr, const ::mlir::EmptyProperties &properties = {}, ::mlir::RegionRange regions = {}) : Base(attrs, properties, regions), odsOperands(values) {} + + AddOpGenericAdaptor(RangeT values, ::mlir::DictionaryAttr attrs, ::mlir::OpaqueProperties properties, ::mlir::RegionRange regions = {}) : AddOpGenericAdaptor(values, attrs, (properties ? *properties.as<::mlir::EmptyProperties *>() : ::mlir::EmptyProperties{}), regions) {} + + template >> + AddOpGenericAdaptor(RangeT values, LateInst op) : Base(op), odsOperands(values) {} + + std::pair getODSOperandIndexAndLength(unsigned index) { + return Base::getODSOperandIndexAndLength(index, odsOperands.size()); + } + + RangeT getODSOperands(unsigned index) { + auto valueRange = getODSOperandIndexAndLength(index); + return {std::next(odsOperands.begin(), valueRange.first), + std::next(odsOperands.begin(), valueRange.first + valueRange.second)}; + } + + ValueT getLhs() { + return (*getODSOperands(0).begin()); + } + + ValueT getRhs() { + return (*getODSOperands(1).begin()); + } + + RangeT getOperands() { + return odsOperands; + } + +private: + RangeT odsOperands; +}; +class AddOpAdaptor : public AddOpGenericAdaptor<::mlir::ValueRange> { +public: + using AddOpGenericAdaptor::AddOpGenericAdaptor; + AddOpAdaptor(AddOp op); + + ::mlir::LogicalResult verify(::mlir::Location loc); +}; +class AddOp : public ::mlir::Op::Impl, ::mlir::OpTrait::ZeroSuccessors, ::mlir::OpTrait::NOperands<2>::Impl, ::mlir::OpTrait::OpInvariants, ::mlir::ConditionallySpeculatable::Trait, ::mlir::OpTrait::AlwaysSpeculatableImplTrait, ::mlir::MemoryEffectOpInterface::Trait, ShapeInference::Trait> { +public: + using Op::Op; + using Op::print; + using Adaptor = AddOpAdaptor; + template + using GenericAdaptor = AddOpGenericAdaptor; + using FoldAdaptor = GenericAdaptor<::llvm::ArrayRef<::mlir::Attribute>>; + static ::llvm::ArrayRef<::llvm::StringRef> getAttributeNames() { + return {}; + } + + static constexpr ::llvm::StringLiteral getOperationName() { + return ::llvm::StringLiteral("toy.add"); + } + + std::pair getODSOperandIndexAndLength(unsigned index); + ::mlir::Operation::operand_range getODSOperands(unsigned index); + ::mlir::TypedValue<::mlir::TensorType> getLhs(); + ::mlir::TypedValue<::mlir::TensorType> getRhs(); + ::mlir::OpOperand &getLhsMutable(); + ::mlir::OpOperand &getRhsMutable(); + std::pair getODSResultIndexAndLength(unsigned index); + ::mlir::Operation::result_range getODSResults(unsigned index); + static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, Value lhs, Value rhs); + static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type resultType0, ::mlir::Value lhs, ::mlir::Value rhs); + static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::Value lhs, ::mlir::Value rhs); + static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {}); + static ::mlir::ParseResult parse(::mlir::OpAsmParser &parser, ::mlir::OperationState &result); + void print(::mlir::OpAsmPrinter &p); + ::mlir::LogicalResult verifyInvariantsImpl(); + ::mlir::LogicalResult verifyInvariants(); + void inferShapes(); + void getEffects(::llvm::SmallVectorImpl<::mlir::SideEffects::EffectInstance<::mlir::MemoryEffects::Effect>> &effects); +public: +}; +} // namespace toy +} // namespace mlir +MLIR_DECLARE_EXPLICIT_TYPE_ID(::mlir::toy::AddOp) + +namespace mlir { +namespace toy { + +//===----------------------------------------------------------------------===// +// ::mlir::toy::CastOp declarations +//===----------------------------------------------------------------------===// + +namespace detail { +class CastOpGenericAdaptorBase { +public: +protected: + ::mlir::DictionaryAttr odsAttrs; + ::std::optional<::mlir::OperationName> odsOpName; + ::mlir::RegionRange odsRegions; +public: + CastOpGenericAdaptorBase(::mlir::DictionaryAttr attrs = nullptr, const ::mlir::EmptyProperties &properties = {}, ::mlir::RegionRange regions = {}); + + CastOpGenericAdaptorBase(CastOp op); + + std::pair getODSOperandIndexAndLength(unsigned index, unsigned odsOperandsSize); + ::mlir::DictionaryAttr getAttributes(); +}; +} // namespace detail +template +class CastOpGenericAdaptor : public detail::CastOpGenericAdaptorBase { + using ValueT = ::llvm::detail::ValueOfRange; + using Base = detail::CastOpGenericAdaptorBase; +public: + CastOpGenericAdaptor(RangeT values, ::mlir::DictionaryAttr attrs = nullptr, const ::mlir::EmptyProperties &properties = {}, ::mlir::RegionRange regions = {}) : Base(attrs, properties, regions), odsOperands(values) {} + + CastOpGenericAdaptor(RangeT values, ::mlir::DictionaryAttr attrs, ::mlir::OpaqueProperties properties, ::mlir::RegionRange regions = {}) : CastOpGenericAdaptor(values, attrs, (properties ? *properties.as<::mlir::EmptyProperties *>() : ::mlir::EmptyProperties{}), regions) {} + + template >> + CastOpGenericAdaptor(RangeT values, LateInst op) : Base(op), odsOperands(values) {} + + std::pair getODSOperandIndexAndLength(unsigned index) { + return Base::getODSOperandIndexAndLength(index, odsOperands.size()); + } + + RangeT getODSOperands(unsigned index) { + auto valueRange = getODSOperandIndexAndLength(index); + return {std::next(odsOperands.begin(), valueRange.first), + std::next(odsOperands.begin(), valueRange.first + valueRange.second)}; + } + + ValueT getInput() { + return (*getODSOperands(0).begin()); + } + + RangeT getOperands() { + return odsOperands; + } + +private: + RangeT odsOperands; +}; +class CastOpAdaptor : public CastOpGenericAdaptor<::mlir::ValueRange> { +public: + using CastOpGenericAdaptor::CastOpGenericAdaptor; + CastOpAdaptor(CastOp op); + + ::mlir::LogicalResult verify(::mlir::Location loc); +}; +class CastOp : public ::mlir::Op::Impl, ::mlir::OpTrait::ZeroSuccessors, ::mlir::OpTrait::OneOperand, ::mlir::OpTrait::OpInvariants, ::mlir::CastOpInterface::Trait, ShapeInference::Trait, ::mlir::ConditionallySpeculatable::Trait, ::mlir::OpTrait::AlwaysSpeculatableImplTrait, ::mlir::MemoryEffectOpInterface::Trait, ::mlir::OpTrait::SameOperandsAndResultShape> { +public: + using Op::Op; + using Op::print; + using Adaptor = CastOpAdaptor; + template + using GenericAdaptor = CastOpGenericAdaptor; + using FoldAdaptor = GenericAdaptor<::llvm::ArrayRef<::mlir::Attribute>>; + static ::llvm::ArrayRef<::llvm::StringRef> getAttributeNames() { + return {}; + } + + static constexpr ::llvm::StringLiteral getOperationName() { + return ::llvm::StringLiteral("toy.cast"); + } + + std::pair getODSOperandIndexAndLength(unsigned index); + ::mlir::Operation::operand_range getODSOperands(unsigned index); + ::mlir::TypedValue<::mlir::TensorType> getInput(); + ::mlir::OpOperand &getInputMutable(); + std::pair getODSResultIndexAndLength(unsigned index); + ::mlir::Operation::result_range getODSResults(unsigned index); + ::mlir::TypedValue<::mlir::TensorType> getOutput(); + static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type output, ::mlir::Value input); + static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::Value input); + static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {}); + ::mlir::LogicalResult verifyInvariantsImpl(); + ::mlir::LogicalResult verifyInvariants(); + static bool areCastCompatible(::mlir::TypeRange inputs, ::mlir::TypeRange outputs); + void inferShapes(); + static ::mlir::ParseResult parse(::mlir::OpAsmParser &parser, ::mlir::OperationState &result); + void print(::mlir::OpAsmPrinter &_odsPrinter); + void getEffects(::llvm::SmallVectorImpl<::mlir::SideEffects::EffectInstance<::mlir::MemoryEffects::Effect>> &effects); +public: +}; +} // namespace toy +} // namespace mlir +MLIR_DECLARE_EXPLICIT_TYPE_ID(::mlir::toy::CastOp) + +namespace mlir { +namespace toy { + +//===----------------------------------------------------------------------===// +// ::mlir::toy::ConstantOp declarations +//===----------------------------------------------------------------------===// + +namespace detail { +class ConstantOpGenericAdaptorBase { +public: + struct Properties { + using valueTy = ::mlir::DenseElementsAttr; + valueTy value; + + auto getValue() { + auto &propStorage = this->value; + return ::llvm::cast<::mlir::DenseElementsAttr>(propStorage); + } + void setValue(const ::mlir::DenseElementsAttr &propValue) { + this->value = propValue; + } + bool operator==(const Properties &rhs) const { + return + rhs.value == this->value && + true; + } + bool operator!=(const Properties &rhs) const { + return !(*this == rhs); + } + }; +protected: + ::mlir::DictionaryAttr odsAttrs; + ::std::optional<::mlir::OperationName> odsOpName; + Properties properties; + ::mlir::RegionRange odsRegions; +public: + ConstantOpGenericAdaptorBase(::mlir::DictionaryAttr attrs = nullptr, const Properties &properties = {}, ::mlir::RegionRange regions = {}); + + ConstantOpGenericAdaptorBase(ConstantOp op); + + std::pair getODSOperandIndexAndLength(unsigned index, unsigned odsOperandsSize); + const Properties &getProperties() { + return properties; + } + + ::mlir::DictionaryAttr getAttributes(); + ::mlir::DenseElementsAttr getValueAttr(); + ::mlir::DenseElementsAttr getValue(); +}; +} // namespace detail +template +class ConstantOpGenericAdaptor : public detail::ConstantOpGenericAdaptorBase { + using ValueT = ::llvm::detail::ValueOfRange; + using Base = detail::ConstantOpGenericAdaptorBase; +public: + ConstantOpGenericAdaptor(RangeT values, ::mlir::DictionaryAttr attrs = nullptr, const Properties &properties = {}, ::mlir::RegionRange regions = {}) : Base(attrs, properties, regions), odsOperands(values) {} + + ConstantOpGenericAdaptor(RangeT values, ::mlir::DictionaryAttr attrs, ::mlir::OpaqueProperties properties, ::mlir::RegionRange regions = {}) : ConstantOpGenericAdaptor(values, attrs, (properties ? *properties.as() : Properties{}), regions) {} + + template >> + ConstantOpGenericAdaptor(RangeT values, LateInst op) : Base(op), odsOperands(values) {} + + std::pair getODSOperandIndexAndLength(unsigned index) { + return Base::getODSOperandIndexAndLength(index, odsOperands.size()); + } + + RangeT getODSOperands(unsigned index) { + auto valueRange = getODSOperandIndexAndLength(index); + return {std::next(odsOperands.begin(), valueRange.first), + std::next(odsOperands.begin(), valueRange.first + valueRange.second)}; + } + + RangeT getOperands() { + return odsOperands; + } + +private: + RangeT odsOperands; +}; +class ConstantOpAdaptor : public ConstantOpGenericAdaptor<::mlir::ValueRange> { +public: + using ConstantOpGenericAdaptor::ConstantOpGenericAdaptor; + ConstantOpAdaptor(ConstantOp op); + + ::mlir::LogicalResult verify(::mlir::Location loc); +}; +class ConstantOp : public ::mlir::Op::Impl, ::mlir::OpTrait::ZeroSuccessors, ::mlir::OpTrait::ZeroOperands, ::mlir::OpTrait::OpInvariants, ::mlir::BytecodeOpInterface::Trait, ::mlir::ConditionallySpeculatable::Trait, ::mlir::OpTrait::AlwaysSpeculatableImplTrait, ::mlir::MemoryEffectOpInterface::Trait> { +public: + using Op::Op; + using Op::print; + using Adaptor = ConstantOpAdaptor; + template + using GenericAdaptor = ConstantOpGenericAdaptor; + using FoldAdaptor = GenericAdaptor<::llvm::ArrayRef<::mlir::Attribute>>; + using Properties = FoldAdaptor::Properties; + static ::llvm::ArrayRef<::llvm::StringRef> getAttributeNames() { + static ::llvm::StringRef attrNames[] = {::llvm::StringRef("value")}; + return ::llvm::ArrayRef(attrNames); + } + + ::mlir::StringAttr getValueAttrName() { + return getAttributeNameForIndex(0); + } + + static ::mlir::StringAttr getValueAttrName(::mlir::OperationName name) { + return getAttributeNameForIndex(name, 0); + } + + static constexpr ::llvm::StringLiteral getOperationName() { + return ::llvm::StringLiteral("toy.constant"); + } + + std::pair getODSOperandIndexAndLength(unsigned index); + ::mlir::Operation::operand_range getODSOperands(unsigned index); + std::pair getODSResultIndexAndLength(unsigned index); + ::mlir::Operation::result_range getODSResults(unsigned index); + static ::mlir::LogicalResult setPropertiesFromAttr(Properties &prop, ::mlir::Attribute attr, ::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError); + static ::mlir::Attribute getPropertiesAsAttr(::mlir::MLIRContext *ctx, const Properties &prop); + static llvm::hash_code computePropertiesHash(const Properties &prop); + static std::optional getInherentAttr(::mlir::MLIRContext *ctx, const Properties &prop, llvm::StringRef name); + static void setInherentAttr(Properties &prop, llvm::StringRef name, mlir::Attribute value); + static void populateInherentAttrs(::mlir::MLIRContext *ctx, const Properties &prop, ::mlir::NamedAttrList &attrs); + static ::mlir::LogicalResult verifyInherentAttrs(::mlir::OperationName opName, ::mlir::NamedAttrList &attrs, llvm::function_ref<::mlir::InFlightDiagnostic()> emitError); + static ::mlir::LogicalResult readProperties(::mlir::DialectBytecodeReader &reader, ::mlir::OperationState &state); + void writeProperties(::mlir::DialectBytecodeWriter &writer); + ::mlir::DenseElementsAttr getValueAttr(); + ::mlir::DenseElementsAttr getValue(); + void setValueAttr(::mlir::DenseElementsAttr attr); + static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, DenseElementsAttr value); + static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, double value); + static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type resultType0, ::mlir::DenseElementsAttr value); + static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::DenseElementsAttr value); + static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {}); + static ::mlir::ParseResult parse(::mlir::OpAsmParser &parser, ::mlir::OperationState &result); + void print(::mlir::OpAsmPrinter &p); + ::mlir::LogicalResult verifyInvariantsImpl(); + ::mlir::LogicalResult verifyInvariants(); + ::mlir::LogicalResult verify(); + void getEffects(::llvm::SmallVectorImpl<::mlir::SideEffects::EffectInstance<::mlir::MemoryEffects::Effect>> &effects); +private: + ::mlir::StringAttr getAttributeNameForIndex(unsigned index) { + return getAttributeNameForIndex((*this)->getName(), index); + } + + static ::mlir::StringAttr getAttributeNameForIndex(::mlir::OperationName name, unsigned index) { + assert(index < 1 && "invalid attribute index"); + assert(name.getStringRef() == getOperationName() && "invalid operation name"); + assert(name.isRegistered() && "Operation isn't registered, missing a " + "dependent dialect loading?"); + return name.getAttributeNames()[index]; + } + +public: +}; +} // namespace toy +} // namespace mlir +MLIR_DECLARE_EXPLICIT_TYPE_ID(::mlir::toy::ConstantOp) + +namespace mlir { +namespace toy { + +//===----------------------------------------------------------------------===// +// ::mlir::toy::FuncOp declarations +//===----------------------------------------------------------------------===// + +namespace detail { +class FuncOpGenericAdaptorBase { +public: + struct Properties { + using arg_attrsTy = ::mlir::ArrayAttr; + arg_attrsTy arg_attrs; + + auto getArgAttrs() { + auto &propStorage = this->arg_attrs; + return ::llvm::dyn_cast_or_null<::mlir::ArrayAttr>(propStorage); + } + void setArgAttrs(const ::mlir::ArrayAttr &propValue) { + this->arg_attrs = propValue; + } + using function_typeTy = ::mlir::TypeAttr; + function_typeTy function_type; + + auto getFunctionType() { + auto &propStorage = this->function_type; + return ::llvm::cast<::mlir::TypeAttr>(propStorage); + } + void setFunctionType(const ::mlir::TypeAttr &propValue) { + this->function_type = propValue; + } + using res_attrsTy = ::mlir::ArrayAttr; + res_attrsTy res_attrs; + + auto getResAttrs() { + auto &propStorage = this->res_attrs; + return ::llvm::dyn_cast_or_null<::mlir::ArrayAttr>(propStorage); + } + void setResAttrs(const ::mlir::ArrayAttr &propValue) { + this->res_attrs = propValue; + } + using sym_nameTy = ::mlir::StringAttr; + sym_nameTy sym_name; + + auto getSymName() { + auto &propStorage = this->sym_name; + return ::llvm::cast<::mlir::StringAttr>(propStorage); + } + void setSymName(const ::mlir::StringAttr &propValue) { + this->sym_name = propValue; + } + bool operator==(const Properties &rhs) const { + return + rhs.arg_attrs == this->arg_attrs && + rhs.function_type == this->function_type && + rhs.res_attrs == this->res_attrs && + rhs.sym_name == this->sym_name && + true; + } + bool operator!=(const Properties &rhs) const { + return !(*this == rhs); + } + }; +protected: + ::mlir::DictionaryAttr odsAttrs; + ::std::optional<::mlir::OperationName> odsOpName; + Properties properties; + ::mlir::RegionRange odsRegions; +public: + FuncOpGenericAdaptorBase(::mlir::DictionaryAttr attrs = nullptr, const Properties &properties = {}, ::mlir::RegionRange regions = {}); + + FuncOpGenericAdaptorBase(FuncOp op); + + std::pair getODSOperandIndexAndLength(unsigned index, unsigned odsOperandsSize); + const Properties &getProperties() { + return properties; + } + + ::mlir::DictionaryAttr getAttributes(); + ::mlir::StringAttr getSymNameAttr(); + ::llvm::StringRef getSymName(); + ::mlir::TypeAttr getFunctionTypeAttr(); + ::mlir::FunctionType getFunctionType(); + ::mlir::ArrayAttr getArgAttrsAttr(); + ::std::optional< ::mlir::ArrayAttr > getArgAttrs(); + ::mlir::ArrayAttr getResAttrsAttr(); + ::std::optional< ::mlir::ArrayAttr > getResAttrs(); + ::mlir::Region &getBody(); + ::mlir::RegionRange getRegions(); +}; +} // namespace detail +template +class FuncOpGenericAdaptor : public detail::FuncOpGenericAdaptorBase { + using ValueT = ::llvm::detail::ValueOfRange; + using Base = detail::FuncOpGenericAdaptorBase; +public: + FuncOpGenericAdaptor(RangeT values, ::mlir::DictionaryAttr attrs = nullptr, const Properties &properties = {}, ::mlir::RegionRange regions = {}) : Base(attrs, properties, regions), odsOperands(values) {} + + FuncOpGenericAdaptor(RangeT values, ::mlir::DictionaryAttr attrs, ::mlir::OpaqueProperties properties, ::mlir::RegionRange regions = {}) : FuncOpGenericAdaptor(values, attrs, (properties ? *properties.as() : Properties{}), regions) {} + + template >> + FuncOpGenericAdaptor(RangeT values, LateInst op) : Base(op), odsOperands(values) {} + + std::pair getODSOperandIndexAndLength(unsigned index) { + return Base::getODSOperandIndexAndLength(index, odsOperands.size()); + } + + RangeT getODSOperands(unsigned index) { + auto valueRange = getODSOperandIndexAndLength(index); + return {std::next(odsOperands.begin(), valueRange.first), + std::next(odsOperands.begin(), valueRange.first + valueRange.second)}; + } + + RangeT getOperands() { + return odsOperands; + } + +private: + RangeT odsOperands; +}; +class FuncOpAdaptor : public FuncOpGenericAdaptor<::mlir::ValueRange> { +public: + using FuncOpGenericAdaptor::FuncOpGenericAdaptor; + FuncOpAdaptor(FuncOp op); + + ::mlir::LogicalResult verify(::mlir::Location loc); +}; +class FuncOp : public ::mlir::Op { +public: + using Op::Op; + using Op::print; + using Adaptor = FuncOpAdaptor; + template + using GenericAdaptor = FuncOpGenericAdaptor; + using FoldAdaptor = GenericAdaptor<::llvm::ArrayRef<::mlir::Attribute>>; + using Properties = FoldAdaptor::Properties; + static ::llvm::ArrayRef<::llvm::StringRef> getAttributeNames() { + static ::llvm::StringRef attrNames[] = {::llvm::StringRef("arg_attrs"), ::llvm::StringRef("function_type"), ::llvm::StringRef("res_attrs"), ::llvm::StringRef("sym_name")}; + return ::llvm::ArrayRef(attrNames); + } + + ::mlir::StringAttr getArgAttrsAttrName() { + return getAttributeNameForIndex(0); + } + + static ::mlir::StringAttr getArgAttrsAttrName(::mlir::OperationName name) { + return getAttributeNameForIndex(name, 0); + } + + ::mlir::StringAttr getFunctionTypeAttrName() { + return getAttributeNameForIndex(1); + } + + static ::mlir::StringAttr getFunctionTypeAttrName(::mlir::OperationName name) { + return getAttributeNameForIndex(name, 1); + } + + ::mlir::StringAttr getResAttrsAttrName() { + return getAttributeNameForIndex(2); + } + + static ::mlir::StringAttr getResAttrsAttrName(::mlir::OperationName name) { + return getAttributeNameForIndex(name, 2); + } + + ::mlir::StringAttr getSymNameAttrName() { + return getAttributeNameForIndex(3); + } + + static ::mlir::StringAttr getSymNameAttrName(::mlir::OperationName name) { + return getAttributeNameForIndex(name, 3); + } + + static constexpr ::llvm::StringLiteral getOperationName() { + return ::llvm::StringLiteral("toy.func"); + } + + std::pair getODSOperandIndexAndLength(unsigned index); + ::mlir::Operation::operand_range getODSOperands(unsigned index); + std::pair getODSResultIndexAndLength(unsigned index); + ::mlir::Operation::result_range getODSResults(unsigned index); + ::mlir::Region &getBody(); + static ::mlir::LogicalResult setPropertiesFromAttr(Properties &prop, ::mlir::Attribute attr, ::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError); + static ::mlir::Attribute getPropertiesAsAttr(::mlir::MLIRContext *ctx, const Properties &prop); + static llvm::hash_code computePropertiesHash(const Properties &prop); + static std::optional getInherentAttr(::mlir::MLIRContext *ctx, const Properties &prop, llvm::StringRef name); + static void setInherentAttr(Properties &prop, llvm::StringRef name, mlir::Attribute value); + static void populateInherentAttrs(::mlir::MLIRContext *ctx, const Properties &prop, ::mlir::NamedAttrList &attrs); + static ::mlir::LogicalResult verifyInherentAttrs(::mlir::OperationName opName, ::mlir::NamedAttrList &attrs, llvm::function_ref<::mlir::InFlightDiagnostic()> emitError); + static ::mlir::LogicalResult readProperties(::mlir::DialectBytecodeReader &reader, ::mlir::OperationState &state); + void writeProperties(::mlir::DialectBytecodeWriter &writer); + ::mlir::StringAttr getSymNameAttr(); + ::llvm::StringRef getSymName(); + ::mlir::TypeAttr getFunctionTypeAttr(); + ::mlir::FunctionType getFunctionType(); + ::mlir::ArrayAttr getArgAttrsAttr(); + ::std::optional< ::mlir::ArrayAttr > getArgAttrs(); + ::mlir::ArrayAttr getResAttrsAttr(); + ::std::optional< ::mlir::ArrayAttr > getResAttrs(); + void setSymNameAttr(::mlir::StringAttr attr); + void setSymName(::llvm::StringRef attrValue); + void setFunctionTypeAttr(::mlir::TypeAttr attr); + void setFunctionType(::mlir::FunctionType attrValue); + void setArgAttrsAttr(::mlir::ArrayAttr attr); + void setResAttrsAttr(::mlir::ArrayAttr attr); + ::mlir::Attribute removeArgAttrsAttr(); + ::mlir::Attribute removeResAttrsAttr(); + static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, StringRef name, FunctionType type, ArrayRef attrs = {}); + static ::mlir::ParseResult parse(::mlir::OpAsmParser &parser, ::mlir::OperationState &result); + void print(::mlir::OpAsmPrinter &p); + ::mlir::LogicalResult verifyInvariantsImpl(); + ::mlir::LogicalResult verifyInvariants(); +private: + ::mlir::StringAttr getAttributeNameForIndex(unsigned index) { + return getAttributeNameForIndex((*this)->getName(), index); + } + + static ::mlir::StringAttr getAttributeNameForIndex(::mlir::OperationName name, unsigned index) { + assert(index < 4 && "invalid attribute index"); + assert(name.getStringRef() == getOperationName() && "invalid operation name"); + assert(name.isRegistered() && "Operation isn't registered, missing a " + "dependent dialect loading?"); + return name.getAttributeNames()[index]; + } + +public: + //===------------------------------------------------------------------===// + // FunctionOpInterface Methods + //===------------------------------------------------------------------===// + + /// Returns the argument types of this function. + ArrayRef getArgumentTypes() { return getFunctionType().getInputs(); } + + /// Returns the result types of this function. + ArrayRef getResultTypes() { return getFunctionType().getResults(); } + + Region *getCallableRegion() { return &getBody(); } +}; +} // namespace toy +} // namespace mlir +MLIR_DECLARE_EXPLICIT_TYPE_ID(::mlir::toy::FuncOp) + +namespace mlir { +namespace toy { + +//===----------------------------------------------------------------------===// +// ::mlir::toy::GenericCallOp declarations +//===----------------------------------------------------------------------===// + +namespace detail { +class GenericCallOpGenericAdaptorBase { +public: + struct Properties { + using calleeTy = ::mlir::FlatSymbolRefAttr; + calleeTy callee; + + auto getCallee() { + auto &propStorage = this->callee; + return ::llvm::cast<::mlir::FlatSymbolRefAttr>(propStorage); + } + void setCallee(const ::mlir::FlatSymbolRefAttr &propValue) { + this->callee = propValue; + } + bool operator==(const Properties &rhs) const { + return + rhs.callee == this->callee && + true; + } + bool operator!=(const Properties &rhs) const { + return !(*this == rhs); + } + }; +protected: + ::mlir::DictionaryAttr odsAttrs; + ::std::optional<::mlir::OperationName> odsOpName; + Properties properties; + ::mlir::RegionRange odsRegions; +public: + GenericCallOpGenericAdaptorBase(::mlir::DictionaryAttr attrs = nullptr, const Properties &properties = {}, ::mlir::RegionRange regions = {}); + + GenericCallOpGenericAdaptorBase(GenericCallOp op); + + std::pair getODSOperandIndexAndLength(unsigned index, unsigned odsOperandsSize); + const Properties &getProperties() { + return properties; + } + + ::mlir::DictionaryAttr getAttributes(); + ::mlir::FlatSymbolRefAttr getCalleeAttr(); + ::llvm::StringRef getCallee(); +}; +} // namespace detail +template +class GenericCallOpGenericAdaptor : public detail::GenericCallOpGenericAdaptorBase { + using ValueT = ::llvm::detail::ValueOfRange; + using Base = detail::GenericCallOpGenericAdaptorBase; +public: + GenericCallOpGenericAdaptor(RangeT values, ::mlir::DictionaryAttr attrs = nullptr, const Properties &properties = {}, ::mlir::RegionRange regions = {}) : Base(attrs, properties, regions), odsOperands(values) {} + + GenericCallOpGenericAdaptor(RangeT values, ::mlir::DictionaryAttr attrs, ::mlir::OpaqueProperties properties, ::mlir::RegionRange regions = {}) : GenericCallOpGenericAdaptor(values, attrs, (properties ? *properties.as() : Properties{}), regions) {} + + template >> + GenericCallOpGenericAdaptor(RangeT values, LateInst op) : Base(op), odsOperands(values) {} + + std::pair getODSOperandIndexAndLength(unsigned index) { + return Base::getODSOperandIndexAndLength(index, odsOperands.size()); + } + + RangeT getODSOperands(unsigned index) { + auto valueRange = getODSOperandIndexAndLength(index); + return {std::next(odsOperands.begin(), valueRange.first), + std::next(odsOperands.begin(), valueRange.first + valueRange.second)}; + } + + RangeT getInputs() { + return getODSOperands(0); + } + + RangeT getOperands() { + return odsOperands; + } + +private: + RangeT odsOperands; +}; +class GenericCallOpAdaptor : public GenericCallOpGenericAdaptor<::mlir::ValueRange> { +public: + using GenericCallOpGenericAdaptor::GenericCallOpGenericAdaptor; + GenericCallOpAdaptor(GenericCallOp op); + + ::mlir::LogicalResult verify(::mlir::Location loc); +}; +class GenericCallOp : public ::mlir::Op::Impl, ::mlir::OpTrait::ZeroSuccessors, ::mlir::OpTrait::VariadicOperands, ::mlir::OpTrait::OpInvariants, ::mlir::BytecodeOpInterface::Trait, ::mlir::CallOpInterface::Trait> { +public: + using Op::Op; + using Op::print; + using Adaptor = GenericCallOpAdaptor; + template + using GenericAdaptor = GenericCallOpGenericAdaptor; + using FoldAdaptor = GenericAdaptor<::llvm::ArrayRef<::mlir::Attribute>>; + using Properties = FoldAdaptor::Properties; + static ::llvm::ArrayRef<::llvm::StringRef> getAttributeNames() { + static ::llvm::StringRef attrNames[] = {::llvm::StringRef("callee")}; + return ::llvm::ArrayRef(attrNames); + } + + ::mlir::StringAttr getCalleeAttrName() { + return getAttributeNameForIndex(0); + } + + static ::mlir::StringAttr getCalleeAttrName(::mlir::OperationName name) { + return getAttributeNameForIndex(name, 0); + } + + static constexpr ::llvm::StringLiteral getOperationName() { + return ::llvm::StringLiteral("toy.generic_call"); + } + + std::pair getODSOperandIndexAndLength(unsigned index); + ::mlir::Operation::operand_range getODSOperands(unsigned index); + ::mlir::Operation::operand_range getInputs(); + ::mlir::MutableOperandRange getInputsMutable(); + std::pair getODSResultIndexAndLength(unsigned index); + ::mlir::Operation::result_range getODSResults(unsigned index); + static ::mlir::LogicalResult setPropertiesFromAttr(Properties &prop, ::mlir::Attribute attr, ::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError); + static ::mlir::Attribute getPropertiesAsAttr(::mlir::MLIRContext *ctx, const Properties &prop); + static llvm::hash_code computePropertiesHash(const Properties &prop); + static std::optional getInherentAttr(::mlir::MLIRContext *ctx, const Properties &prop, llvm::StringRef name); + static void setInherentAttr(Properties &prop, llvm::StringRef name, mlir::Attribute value); + static void populateInherentAttrs(::mlir::MLIRContext *ctx, const Properties &prop, ::mlir::NamedAttrList &attrs); + static ::mlir::LogicalResult verifyInherentAttrs(::mlir::OperationName opName, ::mlir::NamedAttrList &attrs, llvm::function_ref<::mlir::InFlightDiagnostic()> emitError); + static ::mlir::LogicalResult readProperties(::mlir::DialectBytecodeReader &reader, ::mlir::OperationState &state); + void writeProperties(::mlir::DialectBytecodeWriter &writer); + ::mlir::FlatSymbolRefAttr getCalleeAttr(); + ::llvm::StringRef getCallee(); + void setCalleeAttr(::mlir::FlatSymbolRefAttr attr); + void setCallee(::llvm::StringRef attrValue); + static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, StringRef callee, ArrayRef arguments); + static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type resultType0, ::mlir::FlatSymbolRefAttr callee, ::mlir::ValueRange inputs); + static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::FlatSymbolRefAttr callee, ::mlir::ValueRange inputs); + static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type resultType0, ::llvm::StringRef callee, ::mlir::ValueRange inputs); + static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::llvm::StringRef callee, ::mlir::ValueRange inputs); + static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {}); + ::mlir::LogicalResult verifyInvariantsImpl(); + ::mlir::LogicalResult verifyInvariants(); + ::mlir::CallInterfaceCallable getCallableForCallee(); + void setCalleeFromCallable(::mlir::CallInterfaceCallable callee); + ::mlir::Operation::operand_range getArgOperands(); + ::mlir::MutableOperandRange getArgOperandsMutable(); + static ::mlir::ParseResult parse(::mlir::OpAsmParser &parser, ::mlir::OperationState &result); + void print(::mlir::OpAsmPrinter &_odsPrinter); +private: + ::mlir::StringAttr getAttributeNameForIndex(unsigned index) { + return getAttributeNameForIndex((*this)->getName(), index); + } + + static ::mlir::StringAttr getAttributeNameForIndex(::mlir::OperationName name, unsigned index) { + assert(index < 1 && "invalid attribute index"); + assert(name.getStringRef() == getOperationName() && "invalid operation name"); + assert(name.isRegistered() && "Operation isn't registered, missing a " + "dependent dialect loading?"); + return name.getAttributeNames()[index]; + } + +public: +}; +} // namespace toy +} // namespace mlir +MLIR_DECLARE_EXPLICIT_TYPE_ID(::mlir::toy::GenericCallOp) + +namespace mlir { +namespace toy { + +//===----------------------------------------------------------------------===// +// ::mlir::toy::MulOp declarations +//===----------------------------------------------------------------------===// + +namespace detail { +class MulOpGenericAdaptorBase { +public: +protected: + ::mlir::DictionaryAttr odsAttrs; + ::std::optional<::mlir::OperationName> odsOpName; + ::mlir::RegionRange odsRegions; +public: + MulOpGenericAdaptorBase(::mlir::DictionaryAttr attrs = nullptr, const ::mlir::EmptyProperties &properties = {}, ::mlir::RegionRange regions = {}); + + MulOpGenericAdaptorBase(MulOp op); + + std::pair getODSOperandIndexAndLength(unsigned index, unsigned odsOperandsSize); + ::mlir::DictionaryAttr getAttributes(); +}; +} // namespace detail +template +class MulOpGenericAdaptor : public detail::MulOpGenericAdaptorBase { + using ValueT = ::llvm::detail::ValueOfRange; + using Base = detail::MulOpGenericAdaptorBase; +public: + MulOpGenericAdaptor(RangeT values, ::mlir::DictionaryAttr attrs = nullptr, const ::mlir::EmptyProperties &properties = {}, ::mlir::RegionRange regions = {}) : Base(attrs, properties, regions), odsOperands(values) {} + + MulOpGenericAdaptor(RangeT values, ::mlir::DictionaryAttr attrs, ::mlir::OpaqueProperties properties, ::mlir::RegionRange regions = {}) : MulOpGenericAdaptor(values, attrs, (properties ? *properties.as<::mlir::EmptyProperties *>() : ::mlir::EmptyProperties{}), regions) {} + + template >> + MulOpGenericAdaptor(RangeT values, LateInst op) : Base(op), odsOperands(values) {} + + std::pair getODSOperandIndexAndLength(unsigned index) { + return Base::getODSOperandIndexAndLength(index, odsOperands.size()); + } + + RangeT getODSOperands(unsigned index) { + auto valueRange = getODSOperandIndexAndLength(index); + return {std::next(odsOperands.begin(), valueRange.first), + std::next(odsOperands.begin(), valueRange.first + valueRange.second)}; + } + + ValueT getLhs() { + return (*getODSOperands(0).begin()); + } + + ValueT getRhs() { + return (*getODSOperands(1).begin()); + } + + RangeT getOperands() { + return odsOperands; + } + +private: + RangeT odsOperands; +}; +class MulOpAdaptor : public MulOpGenericAdaptor<::mlir::ValueRange> { +public: + using MulOpGenericAdaptor::MulOpGenericAdaptor; + MulOpAdaptor(MulOp op); + + ::mlir::LogicalResult verify(::mlir::Location loc); +}; +class MulOp : public ::mlir::Op::Impl, ::mlir::OpTrait::ZeroSuccessors, ::mlir::OpTrait::NOperands<2>::Impl, ::mlir::OpTrait::OpInvariants, ::mlir::ConditionallySpeculatable::Trait, ::mlir::OpTrait::AlwaysSpeculatableImplTrait, ::mlir::MemoryEffectOpInterface::Trait, ShapeInference::Trait> { +public: + using Op::Op; + using Op::print; + using Adaptor = MulOpAdaptor; + template + using GenericAdaptor = MulOpGenericAdaptor; + using FoldAdaptor = GenericAdaptor<::llvm::ArrayRef<::mlir::Attribute>>; + static ::llvm::ArrayRef<::llvm::StringRef> getAttributeNames() { + return {}; + } + + static constexpr ::llvm::StringLiteral getOperationName() { + return ::llvm::StringLiteral("toy.mul"); + } + + std::pair getODSOperandIndexAndLength(unsigned index); + ::mlir::Operation::operand_range getODSOperands(unsigned index); + ::mlir::TypedValue<::mlir::TensorType> getLhs(); + ::mlir::TypedValue<::mlir::TensorType> getRhs(); + ::mlir::OpOperand &getLhsMutable(); + ::mlir::OpOperand &getRhsMutable(); + std::pair getODSResultIndexAndLength(unsigned index); + ::mlir::Operation::result_range getODSResults(unsigned index); + static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, Value lhs, Value rhs); + static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type resultType0, ::mlir::Value lhs, ::mlir::Value rhs); + static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::Value lhs, ::mlir::Value rhs); + static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {}); + static ::mlir::ParseResult parse(::mlir::OpAsmParser &parser, ::mlir::OperationState &result); + void print(::mlir::OpAsmPrinter &p); + ::mlir::LogicalResult verifyInvariantsImpl(); + ::mlir::LogicalResult verifyInvariants(); + void inferShapes(); + void getEffects(::llvm::SmallVectorImpl<::mlir::SideEffects::EffectInstance<::mlir::MemoryEffects::Effect>> &effects); +public: +}; +} // namespace toy +} // namespace mlir +MLIR_DECLARE_EXPLICIT_TYPE_ID(::mlir::toy::MulOp) + +namespace mlir { +namespace toy { + +//===----------------------------------------------------------------------===// +// ::mlir::toy::PrintOp declarations +//===----------------------------------------------------------------------===// + +namespace detail { +class PrintOpGenericAdaptorBase { +public: +protected: + ::mlir::DictionaryAttr odsAttrs; + ::std::optional<::mlir::OperationName> odsOpName; + ::mlir::RegionRange odsRegions; +public: + PrintOpGenericAdaptorBase(::mlir::DictionaryAttr attrs = nullptr, const ::mlir::EmptyProperties &properties = {}, ::mlir::RegionRange regions = {}); + + PrintOpGenericAdaptorBase(PrintOp op); + + std::pair getODSOperandIndexAndLength(unsigned index, unsigned odsOperandsSize); + ::mlir::DictionaryAttr getAttributes(); +}; +} // namespace detail +template +class PrintOpGenericAdaptor : public detail::PrintOpGenericAdaptorBase { + using ValueT = ::llvm::detail::ValueOfRange; + using Base = detail::PrintOpGenericAdaptorBase; +public: + PrintOpGenericAdaptor(RangeT values, ::mlir::DictionaryAttr attrs = nullptr, const ::mlir::EmptyProperties &properties = {}, ::mlir::RegionRange regions = {}) : Base(attrs, properties, regions), odsOperands(values) {} + + PrintOpGenericAdaptor(RangeT values, ::mlir::DictionaryAttr attrs, ::mlir::OpaqueProperties properties, ::mlir::RegionRange regions = {}) : PrintOpGenericAdaptor(values, attrs, (properties ? *properties.as<::mlir::EmptyProperties *>() : ::mlir::EmptyProperties{}), regions) {} + + template >> + PrintOpGenericAdaptor(RangeT values, LateInst op) : Base(op), odsOperands(values) {} + + std::pair getODSOperandIndexAndLength(unsigned index) { + return Base::getODSOperandIndexAndLength(index, odsOperands.size()); + } + + RangeT getODSOperands(unsigned index) { + auto valueRange = getODSOperandIndexAndLength(index); + return {std::next(odsOperands.begin(), valueRange.first), + std::next(odsOperands.begin(), valueRange.first + valueRange.second)}; + } + + ValueT getInput() { + return (*getODSOperands(0).begin()); + } + + RangeT getOperands() { + return odsOperands; + } + +private: + RangeT odsOperands; +}; +class PrintOpAdaptor : public PrintOpGenericAdaptor<::mlir::ValueRange> { +public: + using PrintOpGenericAdaptor::PrintOpGenericAdaptor; + PrintOpAdaptor(PrintOp op); + + ::mlir::LogicalResult verify(::mlir::Location loc); +}; +class PrintOp : public ::mlir::Op { +public: + using Op::Op; + using Op::print; + using Adaptor = PrintOpAdaptor; + template + using GenericAdaptor = PrintOpGenericAdaptor; + using FoldAdaptor = GenericAdaptor<::llvm::ArrayRef<::mlir::Attribute>>; + static ::llvm::ArrayRef<::llvm::StringRef> getAttributeNames() { + return {}; + } + + static constexpr ::llvm::StringLiteral getOperationName() { + return ::llvm::StringLiteral("toy.print"); + } + + std::pair getODSOperandIndexAndLength(unsigned index); + ::mlir::Operation::operand_range getODSOperands(unsigned index); + ::mlir::TypedValue<::mlir::TensorType> getInput(); + ::mlir::OpOperand &getInputMutable(); + std::pair getODSResultIndexAndLength(unsigned index); + ::mlir::Operation::result_range getODSResults(unsigned index); + static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Value input); + static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::Value input); + static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {}); + ::mlir::LogicalResult verifyInvariantsImpl(); + ::mlir::LogicalResult verifyInvariants(); + static ::mlir::ParseResult parse(::mlir::OpAsmParser &parser, ::mlir::OperationState &result); + void print(::mlir::OpAsmPrinter &_odsPrinter); +public: +}; +} // namespace toy +} // namespace mlir +MLIR_DECLARE_EXPLICIT_TYPE_ID(::mlir::toy::PrintOp) + +namespace mlir { +namespace toy { + +//===----------------------------------------------------------------------===// +// ::mlir::toy::ReshapeOp declarations +//===----------------------------------------------------------------------===// + +namespace detail { +class ReshapeOpGenericAdaptorBase { +public: +protected: + ::mlir::DictionaryAttr odsAttrs; + ::std::optional<::mlir::OperationName> odsOpName; + ::mlir::RegionRange odsRegions; +public: + ReshapeOpGenericAdaptorBase(::mlir::DictionaryAttr attrs = nullptr, const ::mlir::EmptyProperties &properties = {}, ::mlir::RegionRange regions = {}); + + ReshapeOpGenericAdaptorBase(ReshapeOp op); + + std::pair getODSOperandIndexAndLength(unsigned index, unsigned odsOperandsSize); + ::mlir::DictionaryAttr getAttributes(); +}; +} // namespace detail +template +class ReshapeOpGenericAdaptor : public detail::ReshapeOpGenericAdaptorBase { + using ValueT = ::llvm::detail::ValueOfRange; + using Base = detail::ReshapeOpGenericAdaptorBase; +public: + ReshapeOpGenericAdaptor(RangeT values, ::mlir::DictionaryAttr attrs = nullptr, const ::mlir::EmptyProperties &properties = {}, ::mlir::RegionRange regions = {}) : Base(attrs, properties, regions), odsOperands(values) {} + + ReshapeOpGenericAdaptor(RangeT values, ::mlir::DictionaryAttr attrs, ::mlir::OpaqueProperties properties, ::mlir::RegionRange regions = {}) : ReshapeOpGenericAdaptor(values, attrs, (properties ? *properties.as<::mlir::EmptyProperties *>() : ::mlir::EmptyProperties{}), regions) {} + + template >> + ReshapeOpGenericAdaptor(RangeT values, LateInst op) : Base(op), odsOperands(values) {} + + std::pair getODSOperandIndexAndLength(unsigned index) { + return Base::getODSOperandIndexAndLength(index, odsOperands.size()); + } + + RangeT getODSOperands(unsigned index) { + auto valueRange = getODSOperandIndexAndLength(index); + return {std::next(odsOperands.begin(), valueRange.first), + std::next(odsOperands.begin(), valueRange.first + valueRange.second)}; + } + + ValueT getInput() { + return (*getODSOperands(0).begin()); + } + + RangeT getOperands() { + return odsOperands; + } + +private: + RangeT odsOperands; +}; +class ReshapeOpAdaptor : public ReshapeOpGenericAdaptor<::mlir::ValueRange> { +public: + using ReshapeOpGenericAdaptor::ReshapeOpGenericAdaptor; + ReshapeOpAdaptor(ReshapeOp op); + + ::mlir::LogicalResult verify(::mlir::Location loc); +}; +class ReshapeOp : public ::mlir::Op::Impl, ::mlir::OpTrait::ZeroSuccessors, ::mlir::OpTrait::OneOperand, ::mlir::OpTrait::OpInvariants, ::mlir::ConditionallySpeculatable::Trait, ::mlir::OpTrait::AlwaysSpeculatableImplTrait, ::mlir::MemoryEffectOpInterface::Trait> { +public: + using Op::Op; + using Op::print; + using Adaptor = ReshapeOpAdaptor; + template + using GenericAdaptor = ReshapeOpGenericAdaptor; + using FoldAdaptor = GenericAdaptor<::llvm::ArrayRef<::mlir::Attribute>>; + static ::llvm::ArrayRef<::llvm::StringRef> getAttributeNames() { + return {}; + } + + static constexpr ::llvm::StringLiteral getOperationName() { + return ::llvm::StringLiteral("toy.reshape"); + } + + std::pair getODSOperandIndexAndLength(unsigned index); + ::mlir::Operation::operand_range getODSOperands(unsigned index); + ::mlir::TypedValue<::mlir::TensorType> getInput(); + ::mlir::OpOperand &getInputMutable(); + std::pair getODSResultIndexAndLength(unsigned index); + ::mlir::Operation::result_range getODSResults(unsigned index); + static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type resultType0, ::mlir::Value input); + static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::Value input); + static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {}); + ::mlir::LogicalResult verifyInvariantsImpl(); + ::mlir::LogicalResult verifyInvariants(); + static void getCanonicalizationPatterns(::mlir::RewritePatternSet &results, ::mlir::MLIRContext *context); + static ::mlir::ParseResult parse(::mlir::OpAsmParser &parser, ::mlir::OperationState &result); + void print(::mlir::OpAsmPrinter &_odsPrinter); + void getEffects(::llvm::SmallVectorImpl<::mlir::SideEffects::EffectInstance<::mlir::MemoryEffects::Effect>> &effects); +public: +}; +} // namespace toy +} // namespace mlir +MLIR_DECLARE_EXPLICIT_TYPE_ID(::mlir::toy::ReshapeOp) + +namespace mlir { +namespace toy { + +//===----------------------------------------------------------------------===// +// ::mlir::toy::ReturnOp declarations +//===----------------------------------------------------------------------===// + +namespace detail { +class ReturnOpGenericAdaptorBase { +public: +protected: + ::mlir::DictionaryAttr odsAttrs; + ::std::optional<::mlir::OperationName> odsOpName; + ::mlir::RegionRange odsRegions; +public: + ReturnOpGenericAdaptorBase(::mlir::DictionaryAttr attrs = nullptr, const ::mlir::EmptyProperties &properties = {}, ::mlir::RegionRange regions = {}); + + ReturnOpGenericAdaptorBase(ReturnOp op); + + std::pair getODSOperandIndexAndLength(unsigned index, unsigned odsOperandsSize); + ::mlir::DictionaryAttr getAttributes(); +}; +} // namespace detail +template +class ReturnOpGenericAdaptor : public detail::ReturnOpGenericAdaptorBase { + using ValueT = ::llvm::detail::ValueOfRange; + using Base = detail::ReturnOpGenericAdaptorBase; +public: + ReturnOpGenericAdaptor(RangeT values, ::mlir::DictionaryAttr attrs = nullptr, const ::mlir::EmptyProperties &properties = {}, ::mlir::RegionRange regions = {}) : Base(attrs, properties, regions), odsOperands(values) {} + + ReturnOpGenericAdaptor(RangeT values, ::mlir::DictionaryAttr attrs, ::mlir::OpaqueProperties properties, ::mlir::RegionRange regions = {}) : ReturnOpGenericAdaptor(values, attrs, (properties ? *properties.as<::mlir::EmptyProperties *>() : ::mlir::EmptyProperties{}), regions) {} + + template >> + ReturnOpGenericAdaptor(RangeT values, LateInst op) : Base(op), odsOperands(values) {} + + std::pair getODSOperandIndexAndLength(unsigned index) { + return Base::getODSOperandIndexAndLength(index, odsOperands.size()); + } + + RangeT getODSOperands(unsigned index) { + auto valueRange = getODSOperandIndexAndLength(index); + return {std::next(odsOperands.begin(), valueRange.first), + std::next(odsOperands.begin(), valueRange.first + valueRange.second)}; + } + + RangeT getInput() { + return getODSOperands(0); + } + + RangeT getOperands() { + return odsOperands; + } + +private: + RangeT odsOperands; +}; +class ReturnOpAdaptor : public ReturnOpGenericAdaptor<::mlir::ValueRange> { +public: + using ReturnOpGenericAdaptor::ReturnOpGenericAdaptor; + ReturnOpAdaptor(ReturnOp op); + + ::mlir::LogicalResult verify(::mlir::Location loc); +}; +class ReturnOp : public ::mlir::Op::Impl, ::mlir::OpTrait::OpInvariants, ::mlir::ConditionallySpeculatable::Trait, ::mlir::OpTrait::AlwaysSpeculatableImplTrait, ::mlir::MemoryEffectOpInterface::Trait, ::mlir::OpTrait::IsTerminator> { +public: + using Op::Op; + using Op::print; + using Adaptor = ReturnOpAdaptor; + template + using GenericAdaptor = ReturnOpGenericAdaptor; + using FoldAdaptor = GenericAdaptor<::llvm::ArrayRef<::mlir::Attribute>>; + static ::llvm::ArrayRef<::llvm::StringRef> getAttributeNames() { + return {}; + } + + static constexpr ::llvm::StringLiteral getOperationName() { + return ::llvm::StringLiteral("toy.return"); + } + + std::pair getODSOperandIndexAndLength(unsigned index); + ::mlir::Operation::operand_range getODSOperands(unsigned index); + ::mlir::Operation::operand_range getInput(); + ::mlir::MutableOperandRange getInputMutable(); + std::pair getODSResultIndexAndLength(unsigned index); + ::mlir::Operation::result_range getODSResults(unsigned index); + static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState); + static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::ValueRange input); + static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {}); + ::mlir::LogicalResult verifyInvariantsImpl(); + ::mlir::LogicalResult verifyInvariants(); + ::mlir::LogicalResult verify(); + static ::mlir::ParseResult parse(::mlir::OpAsmParser &parser, ::mlir::OperationState &result); + void print(::mlir::OpAsmPrinter &_odsPrinter); + void getEffects(::llvm::SmallVectorImpl<::mlir::SideEffects::EffectInstance<::mlir::MemoryEffects::Effect>> &effects); +public: + bool hasOperand() { return getNumOperands() != 0; } +}; +} // namespace toy +} // namespace mlir +MLIR_DECLARE_EXPLICIT_TYPE_ID(::mlir::toy::ReturnOp) + +namespace mlir { +namespace toy { + +//===----------------------------------------------------------------------===// +// ::mlir::toy::TransposeOp declarations +//===----------------------------------------------------------------------===// + +namespace detail { +class TransposeOpGenericAdaptorBase { +public: +protected: + ::mlir::DictionaryAttr odsAttrs; + ::std::optional<::mlir::OperationName> odsOpName; + ::mlir::RegionRange odsRegions; +public: + TransposeOpGenericAdaptorBase(::mlir::DictionaryAttr attrs = nullptr, const ::mlir::EmptyProperties &properties = {}, ::mlir::RegionRange regions = {}); + + TransposeOpGenericAdaptorBase(TransposeOp op); + + std::pair getODSOperandIndexAndLength(unsigned index, unsigned odsOperandsSize); + ::mlir::DictionaryAttr getAttributes(); +}; +} // namespace detail +template +class TransposeOpGenericAdaptor : public detail::TransposeOpGenericAdaptorBase { + using ValueT = ::llvm::detail::ValueOfRange; + using Base = detail::TransposeOpGenericAdaptorBase; +public: + TransposeOpGenericAdaptor(RangeT values, ::mlir::DictionaryAttr attrs = nullptr, const ::mlir::EmptyProperties &properties = {}, ::mlir::RegionRange regions = {}) : Base(attrs, properties, regions), odsOperands(values) {} + + TransposeOpGenericAdaptor(RangeT values, ::mlir::DictionaryAttr attrs, ::mlir::OpaqueProperties properties, ::mlir::RegionRange regions = {}) : TransposeOpGenericAdaptor(values, attrs, (properties ? *properties.as<::mlir::EmptyProperties *>() : ::mlir::EmptyProperties{}), regions) {} + + template >> + TransposeOpGenericAdaptor(RangeT values, LateInst op) : Base(op), odsOperands(values) {} + + std::pair getODSOperandIndexAndLength(unsigned index) { + return Base::getODSOperandIndexAndLength(index, odsOperands.size()); + } + + RangeT getODSOperands(unsigned index) { + auto valueRange = getODSOperandIndexAndLength(index); + return {std::next(odsOperands.begin(), valueRange.first), + std::next(odsOperands.begin(), valueRange.first + valueRange.second)}; + } + + ValueT getInput() { + return (*getODSOperands(0).begin()); + } + + RangeT getOperands() { + return odsOperands; + } + +private: + RangeT odsOperands; +}; +class TransposeOpAdaptor : public TransposeOpGenericAdaptor<::mlir::ValueRange> { +public: + using TransposeOpGenericAdaptor::TransposeOpGenericAdaptor; + TransposeOpAdaptor(TransposeOp op); + + ::mlir::LogicalResult verify(::mlir::Location loc); +}; +class TransposeOp : public ::mlir::Op::Impl, ::mlir::OpTrait::ZeroSuccessors, ::mlir::OpTrait::OneOperand, ::mlir::OpTrait::OpInvariants, ::mlir::ConditionallySpeculatable::Trait, ::mlir::OpTrait::AlwaysSpeculatableImplTrait, ::mlir::MemoryEffectOpInterface::Trait, ShapeInference::Trait> { +public: + using Op::Op; + using Op::print; + using Adaptor = TransposeOpAdaptor; + template + using GenericAdaptor = TransposeOpGenericAdaptor; + using FoldAdaptor = GenericAdaptor<::llvm::ArrayRef<::mlir::Attribute>>; + static ::llvm::ArrayRef<::llvm::StringRef> getAttributeNames() { + return {}; + } + + static constexpr ::llvm::StringLiteral getOperationName() { + return ::llvm::StringLiteral("toy.transpose"); + } + + std::pair getODSOperandIndexAndLength(unsigned index); + ::mlir::Operation::operand_range getODSOperands(unsigned index); + ::mlir::TypedValue<::mlir::TensorType> getInput(); + ::mlir::OpOperand &getInputMutable(); + std::pair getODSResultIndexAndLength(unsigned index); + ::mlir::Operation::result_range getODSResults(unsigned index); + static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, Value input); + static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type resultType0, ::mlir::Value input); + static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::Value input); + static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {}); + ::mlir::LogicalResult verifyInvariantsImpl(); + ::mlir::LogicalResult verifyInvariants(); + ::mlir::LogicalResult verify(); + static void getCanonicalizationPatterns(::mlir::RewritePatternSet &results, ::mlir::MLIRContext *context); + void inferShapes(); + static ::mlir::ParseResult parse(::mlir::OpAsmParser &parser, ::mlir::OperationState &result); + void print(::mlir::OpAsmPrinter &_odsPrinter); + void getEffects(::llvm::SmallVectorImpl<::mlir::SideEffects::EffectInstance<::mlir::MemoryEffects::Effect>> &effects); +public: +}; +} // namespace toy +} // namespace mlir +MLIR_DECLARE_EXPLICIT_TYPE_ID(::mlir::toy::TransposeOp) + + +#endif // GET_OP_CLASSES + diff --git a/Ch4/include/toy/Ops.td b/Ch4/include/toy/Ops.td new file mode 100644 index 0000000..075fd1a --- /dev/null +++ b/Ch4/include/toy/Ops.td @@ -0,0 +1,372 @@ +//===- Ops.td - Toy dialect operation definitions ----------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Defines the operations of the Toy dialect. +// +//===----------------------------------------------------------------------===// + +#ifndef TOY_OPS +#define TOY_OPS + +include "mlir/Interfaces/FunctionInterfaces.td" +include "mlir/IR/SymbolInterfaces.td" +include "mlir/Interfaces/CallInterfaces.td" +include "mlir/Interfaces/CastInterfaces.td" +include "mlir/Interfaces/SideEffectInterfaces.td" +include "toy/ShapeInferenceInterface.td" + +// Provide a definition of the 'toy' dialect in the ODS framework so that we +// can define our operations. +def Toy_Dialect : Dialect { + let name = "toy"; + let cppNamespace = "::mlir::toy"; +} + +// Base class for toy dialect operations. This operation inherits from the base +// `Op` class in OpBase.td, and provides: +// * The parent dialect of the operation. +// * The mnemonic for the operation, or the name without the dialect prefix. +// * A list of traits for the operation. +class Toy_Op traits = []> : + Op; + +//===----------------------------------------------------------------------===// +// Toy Operations +//===----------------------------------------------------------------------===// + +//===----------------------------------------------------------------------===// +// ConstantOp +//===----------------------------------------------------------------------===// + +// We define a toy operation by inheriting from our base 'Toy_Op' class above. +// Here we provide the mnemonic and a list of traits for the operation. The +// constant operation is marked as 'Pure' as it is a pure operation +// and may be removed if dead. +def ConstantOp : Toy_Op<"constant", [Pure]> { + // Provide a summary and description for this operation. This can be used to + // auto-generate documentation of the operations within our dialect. + let summary = "constant"; + let description = [{ + Constant operation turns a literal into an SSA value. The data is attached + to the operation as an attribute. For example: + + ```mlir + %0 = toy.constant dense<[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]> + : tensor<2x3xf64> + ``` + }]; + + // The constant operation takes an attribute as the only input. + let arguments = (ins F64ElementsAttr:$value); + + // The constant operation returns a single value of TensorType. + let results = (outs F64Tensor); + + // Indicate that the operation has a custom parser and printer method. + let hasCustomAssemblyFormat = 1; + + // Add custom build methods for the constant operation. These method populates + // the `state` that MLIR uses to create operations, i.e. these are used when + // using `builder.create(...)`. + let builders = [ + // Build a constant with a given constant tensor value. + OpBuilder<(ins "DenseElementsAttr":$value), [{ + build($_builder, $_state, value.getType(), value); + }]>, + + // Build a constant with a given constant floating-point value. + OpBuilder<(ins "double":$value)> + ]; + + // Indicate that additional verification for this operation is necessary. + let hasVerifier = 1; +} + +//===----------------------------------------------------------------------===// +// AddOp +//===----------------------------------------------------------------------===// + +def AddOp : Toy_Op<"add", + [Pure, DeclareOpInterfaceMethods]> { + let summary = "element-wise addition operation"; + let description = [{ + The "add" operation performs element-wise addition between two tensors. + The shapes of the tensor operands are expected to match. + }]; + + let arguments = (ins F64Tensor:$lhs, F64Tensor:$rhs); + let results = (outs F64Tensor); + + // Indicate that the operation has a custom parser and printer method. + let hasCustomAssemblyFormat = 1; + + // Allow building an AddOp with from the two input operands. + let builders = [ + OpBuilder<(ins "Value":$lhs, "Value":$rhs)> + ]; +} + +//===----------------------------------------------------------------------===// +// CastOp +//===----------------------------------------------------------------------===// + +def CastOp : Toy_Op<"cast", [ + DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods, + Pure, + SameOperandsAndResultShape + ]> { + let summary = "shape cast operation"; + let description = [{ + The "cast" operation converts a tensor from one type to an equivalent type + without changing any data elements. The source and destination types must + both be tensor types with the same element type. If both are ranked, then + shape is required to match. The operation is invalid if converting to a + mismatching constant dimension. + }]; + + let arguments = (ins F64Tensor:$input); + let results = (outs F64Tensor:$output); + + let assemblyFormat = "$input attr-dict `:` type($input) `to` type($output)"; +} + +//===----------------------------------------------------------------------===// +// FuncOp +//===----------------------------------------------------------------------===// + +def FuncOp : Toy_Op<"func", [ + FunctionOpInterface, IsolatedFromAbove + ]> { + let summary = "user defined function operation"; + let description = [{ + The "toy.func" operation represents a user defined function. These are + callable SSA-region operations that contain toy computations. + + Example: + + ```mlir + toy.func @main() { + %0 = toy.constant dense<5.500000e+00> : tensor + %1 = toy.reshape(%0 : tensor) to tensor<2x2xf64> + toy.print %1 : tensor<2x2xf64> + toy.return + } + ``` + }]; + + let arguments = (ins + SymbolNameAttr:$sym_name, + TypeAttrOf:$function_type, + OptionalAttr:$arg_attrs, + OptionalAttr:$res_attrs + ); + let regions = (region AnyRegion:$body); + + let builders = [OpBuilder<(ins + "StringRef":$name, "FunctionType":$type, + CArg<"ArrayRef", "{}">:$attrs) + >]; + + let extraClassDeclaration = [{ + //===------------------------------------------------------------------===// + // FunctionOpInterface Methods + //===------------------------------------------------------------------===// + + /// Returns the argument types of this function. + ArrayRef getArgumentTypes() { return getFunctionType().getInputs(); } + + /// Returns the result types of this function. + ArrayRef getResultTypes() { return getFunctionType().getResults(); } + + Region *getCallableRegion() { return &getBody(); } + }]; + + let hasCustomAssemblyFormat = 1; + let skipDefaultBuilders = 1; +} + +//===----------------------------------------------------------------------===// +// GenericCallOp +//===----------------------------------------------------------------------===// + +def GenericCallOp : Toy_Op<"generic_call", + [DeclareOpInterfaceMethods]> { + let summary = "generic call operation"; + let description = [{ + Generic calls represent calls to a user defined function that needs to + be specialized for the shape of its arguments. The callee name is attached + as a symbol reference via an attribute. The arguments list must match the + arguments expected by the callee. For example: + + ```mlir + %4 = toy.generic_call @my_func(%1, %3) + : (tensor<2x3xf64>, tensor<2x3xf64>) -> tensor<*xf64> + ``` + + This is only valid if a function named "my_func" exists and takes two + arguments. + }]; + + // The generic call operation takes a symbol reference attribute as the + // callee, and inputs for the call. + let arguments = (ins FlatSymbolRefAttr:$callee, Variadic:$inputs); + + // The generic call operation returns a single value of TensorType. + let results = (outs F64Tensor); + + // Specialize assembly printing and parsing using a declarative format. + let assemblyFormat = [{ + $callee `(` $inputs `)` attr-dict `:` functional-type($inputs, results) + }]; + + // Add custom build methods for the generic call operation. + let builders = [ + OpBuilder<(ins "StringRef":$callee, "ArrayRef":$arguments)> + ]; +} + +//===----------------------------------------------------------------------===// +// MulOp +//===----------------------------------------------------------------------===// + +def MulOp : Toy_Op<"mul", + [Pure, DeclareOpInterfaceMethods]> { + let summary = "element-wise multiplication operation"; + let description = [{ + The "mul" operation performs element-wise multiplication between two + tensors. The shapes of the tensor operands are expected to match. + }]; + + let arguments = (ins F64Tensor:$lhs, F64Tensor:$rhs); + let results = (outs F64Tensor); + + // Indicate that the operation has a custom parser and printer method. + let hasCustomAssemblyFormat = 1; + + // Allow building a MulOp with from the two input operands. + let builders = [ + OpBuilder<(ins "Value":$lhs, "Value":$rhs)> + ]; +} + +//===----------------------------------------------------------------------===// +// PrintOp +//===----------------------------------------------------------------------===// + +def PrintOp : Toy_Op<"print"> { + let summary = "print operation"; + let description = [{ + The "print" builtin operation prints a given input tensor, and produces + no results. + }]; + + // The print operation takes an input tensor to print. + let arguments = (ins F64Tensor:$input); + + let assemblyFormat = "$input attr-dict `:` type($input)"; +} + +//===----------------------------------------------------------------------===// +// ReshapeOp +//===----------------------------------------------------------------------===// + +def ReshapeOp : Toy_Op<"reshape", [Pure]> { + let summary = "tensor reshape operation"; + let description = [{ + Reshape operation is transforming its input tensor into a new tensor with + the same number of elements but different shapes. For example: + + ```mlir + %0 = toy.reshape (%arg1 : tensor<10xf64>) to tensor<5x2xf64> + ``` + }]; + + let arguments = (ins F64Tensor:$input); + + // We expect that the reshape operation returns a statically shaped tensor. + let results = (outs StaticShapeTensorOf<[F64]>); + + let assemblyFormat = [{ + `(` $input `:` type($input) `)` attr-dict `to` type(results) + }]; + + // Enable registering canonicalization patterns with this operation. + let hasCanonicalizer = 1; +} + +//===----------------------------------------------------------------------===// +// ReturnOp +//===----------------------------------------------------------------------===// + +def ReturnOp : Toy_Op<"return", [Pure, HasParent<"FuncOp">, + Terminator]> { + let summary = "return operation"; + let description = [{ + The "return" operation represents a return operation within a function. + The operation takes an optional tensor operand and produces no results. + The operand type must match the signature of the function that contains + the operation. For example: + + ```mlir + toy.func @foo() -> tensor<2xf64> { + ... + toy.return %0 : tensor<2xf64> + } + ``` + }]; + + // The return operation takes an optional input operand to return. This + // value must match the return type of the enclosing function. + let arguments = (ins Variadic:$input); + + // The return operation only emits the input in the format if it is present. + let assemblyFormat = "($input^ `:` type($input))? attr-dict "; + + // Allow building a ReturnOp with no return operand. + let builders = [ + OpBuilder<(ins), [{ build($_builder, $_state, std::nullopt); }]> + ]; + + // Provide extra utility definitions on the c++ operation class definition. + let extraClassDeclaration = [{ + bool hasOperand() { return getNumOperands() != 0; } + }]; + + // Indicate that additional verification for this operation is necessary. + let hasVerifier = 1; +} + +//===----------------------------------------------------------------------===// +// TransposeOp +//===----------------------------------------------------------------------===// + +def TransposeOp : Toy_Op<"transpose", + [Pure, DeclareOpInterfaceMethods]> { + let summary = "transpose operation"; + + let arguments = (ins F64Tensor:$input); + let results = (outs F64Tensor); + + let assemblyFormat = [{ + `(` $input `:` type($input) `)` attr-dict `to` type(results) + }]; + + // Enable registering canonicalization patterns with this operation. + let hasCanonicalizer = 1; + + // Allow building a TransposeOp with from the input operand. + let builders = [ + OpBuilder<(ins "Value":$input)> + ]; + + // Indicate that additional verification for this operation is necessary. + let hasVerifier = 1; +} + +#endif // TOY_OPS diff --git a/Ch4/include/toy/Parser.h b/Ch4/include/toy/Parser.h new file mode 100644 index 0000000..1f20616 --- /dev/null +++ b/Ch4/include/toy/Parser.h @@ -0,0 +1,489 @@ +//===- Parser.h - Toy Language Parser -------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements the parser for the Toy language. It processes the Token +// provided by the Lexer and returns an AST. +// +//===----------------------------------------------------------------------===// + +#ifndef TOY_PARSER_H +#define TOY_PARSER_H + +#include "toy/AST.h" +#include "toy/Lexer.h" + +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/StringExtras.h" +#include "llvm/Support/raw_ostream.h" + +#include +#include +#include +#include + +namespace toy { + +/// This is a simple recursive parser for the Toy language. It produces a well +/// formed AST from a stream of Token supplied by the Lexer. No semantic checks +/// or symbol resolution is performed. For example, variables are referenced by +/// string and the code could reference an undeclared variable and the parsing +/// succeeds. +class Parser { +public: + /// Create a Parser for the supplied lexer. + Parser(Lexer &lexer) : lexer(lexer) {} + + /// Parse a full Module. A module is a list of function definitions. + std::unique_ptr parseModule() { + lexer.getNextToken(); // prime the lexer + + // Parse functions one at a time and accumulate in this vector. + std::vector functions; + while (auto f = parseDefinition()) { + functions.push_back(std::move(*f)); + if (lexer.getCurToken() == tok_eof) + break; + } + // If we didn't reach EOF, there was an error during parsing + if (lexer.getCurToken() != tok_eof) + return parseError("nothing", "at end of module"); + + return std::make_unique(std::move(functions)); + } + +private: + Lexer &lexer; + + /// Parse a return statement. + /// return :== return ; | return expr ; + std::unique_ptr parseReturn() { + auto loc = lexer.getLastLocation(); + lexer.consume(tok_return); + + // return takes an optional argument + std::optional> expr; + if (lexer.getCurToken() != ';') { + expr = parseExpression(); + if (!expr) + return nullptr; + } + return std::make_unique(std::move(loc), std::move(expr)); + } + + /// Parse a literal number. + /// numberexpr ::= number + std::unique_ptr parseNumberExpr() { + auto loc = lexer.getLastLocation(); + auto result = + std::make_unique(std::move(loc), lexer.getValue()); + lexer.consume(tok_number); + return std::move(result); + } + + /// Parse a literal array expression. + /// tensorLiteral ::= [ literalList ] | number + /// literalList ::= tensorLiteral | tensorLiteral, literalList + std::unique_ptr parseTensorLiteralExpr() { + auto loc = lexer.getLastLocation(); + lexer.consume(Token('[')); + + // Hold the list of values at this nesting level. + std::vector> values; + // Hold the dimensions for all the nesting inside this level. + std::vector dims; + do { + // We can have either another nested array or a number literal. + if (lexer.getCurToken() == '[') { + values.push_back(parseTensorLiteralExpr()); + if (!values.back()) + return nullptr; // parse error in the nested array. + } else { + if (lexer.getCurToken() != tok_number) + return parseError(" or [", "in literal expression"); + values.push_back(parseNumberExpr()); + } + + // End of this list on ']' + if (lexer.getCurToken() == ']') + break; + + // Elements are separated by a comma. + if (lexer.getCurToken() != ',') + return parseError("] or ,", "in literal expression"); + + lexer.getNextToken(); // eat , + } while (true); + if (values.empty()) + return parseError("", "to fill literal expression"); + lexer.getNextToken(); // eat ] + + /// Fill in the dimensions now. First the current nesting level: + dims.push_back(values.size()); + + /// If there is any nested array, process all of them and ensure that + /// dimensions are uniform. + if (llvm::any_of(values, [](std::unique_ptr &expr) { + return llvm::isa(expr.get()); + })) { + auto *firstLiteral = llvm::dyn_cast(values.front().get()); + if (!firstLiteral) + return parseError("uniform well-nested dimensions", + "inside literal expression"); + + // Append the nested dimensions to the current level + auto firstDims = firstLiteral->getDims(); + dims.insert(dims.end(), firstDims.begin(), firstDims.end()); + + // Sanity check that shape is uniform across all elements of the list. + for (auto &expr : values) { + auto *exprLiteral = llvm::cast(expr.get()); + if (!exprLiteral) + return parseError("uniform well-nested dimensions", + "inside literal expression"); + if (exprLiteral->getDims() != firstDims) + return parseError("uniform well-nested dimensions", + "inside literal expression"); + } + } + return std::make_unique(std::move(loc), std::move(values), + std::move(dims)); + } + + /// parenexpr ::= '(' expression ')' + std::unique_ptr parseParenExpr() { + lexer.getNextToken(); // eat (. + auto v = parseExpression(); + if (!v) + return nullptr; + + if (lexer.getCurToken() != ')') + return parseError(")", "to close expression with parentheses"); + lexer.consume(Token(')')); + return v; + } + + /// identifierexpr + /// ::= identifier + /// ::= identifier '(' expression ')' + std::unique_ptr parseIdentifierExpr() { + std::string name(lexer.getId()); + + auto loc = lexer.getLastLocation(); + lexer.getNextToken(); // eat identifier. + + if (lexer.getCurToken() != '(') // Simple variable ref. + return std::make_unique(std::move(loc), name); + + // This is a function call. + lexer.consume(Token('(')); + std::vector> args; + if (lexer.getCurToken() != ')') { + while (true) { + if (auto arg = parseExpression()) + args.push_back(std::move(arg)); + else + return nullptr; + + if (lexer.getCurToken() == ')') + break; + + if (lexer.getCurToken() != ',') + return parseError(", or )", "in argument list"); + lexer.getNextToken(); + } + } + lexer.consume(Token(')')); + + // It can be a builtin call to print + if (name == "print") { + if (args.size() != 1) + return parseError("", "as argument to print()"); + + return std::make_unique(std::move(loc), std::move(args[0])); + } + + // Call to a user-defined function + return std::make_unique(std::move(loc), name, std::move(args)); + } + + /// primary + /// ::= identifierexpr + /// ::= numberexpr + /// ::= parenexpr + /// ::= tensorliteral + std::unique_ptr parsePrimary() { + switch (lexer.getCurToken()) { + default: + llvm::errs() << "unknown token '" << lexer.getCurToken() + << "' when expecting an expression\n"; + return nullptr; + case tok_identifier: + return parseIdentifierExpr(); + case tok_number: + return parseNumberExpr(); + case '(': + return parseParenExpr(); + case '[': + return parseTensorLiteralExpr(); + case ';': + return nullptr; + case '}': + return nullptr; + } + } + + /// Recursively parse the right hand side of a binary expression, the ExprPrec + /// argument indicates the precedence of the current binary operator. + /// + /// binoprhs ::= ('+' primary)* + std::unique_ptr parseBinOpRHS(int exprPrec, + std::unique_ptr lhs) { + // If this is a binop, find its precedence. + while (true) { + int tokPrec = getTokPrecedence(); + + // If this is a binop that binds at least as tightly as the current binop, + // consume it, otherwise we are done. + if (tokPrec < exprPrec) + return lhs; + + // Okay, we know this is a binop. + int binOp = lexer.getCurToken(); + lexer.consume(Token(binOp)); + auto loc = lexer.getLastLocation(); + + // Parse the primary expression after the binary operator. + auto rhs = parsePrimary(); + if (!rhs) + return parseError("expression", "to complete binary operator"); + + // If BinOp binds less tightly with rhs than the operator after rhs, let + // the pending operator take rhs as its lhs. + int nextPrec = getTokPrecedence(); + if (tokPrec < nextPrec) { + rhs = parseBinOpRHS(tokPrec + 1, std::move(rhs)); + if (!rhs) + return nullptr; + } + + // Merge lhs/RHS. + lhs = std::make_unique(std::move(loc), binOp, + std::move(lhs), std::move(rhs)); + } + } + + /// expression::= primary binop rhs + std::unique_ptr parseExpression() { + auto lhs = parsePrimary(); + if (!lhs) + return nullptr; + + return parseBinOpRHS(0, std::move(lhs)); + } + + /// type ::= < shape_list > + /// shape_list ::= num | num , shape_list + std::unique_ptr parseType() { + if (lexer.getCurToken() != '<') + return parseError("<", "to begin type"); + lexer.getNextToken(); // eat < + + auto type = std::make_unique(); + + while (lexer.getCurToken() == tok_number) { + type->shape.push_back(lexer.getValue()); + lexer.getNextToken(); + if (lexer.getCurToken() == ',') + lexer.getNextToken(); + } + + if (lexer.getCurToken() != '>') + return parseError(">", "to end type"); + lexer.getNextToken(); // eat > + return type; + } + + /// Parse a variable declaration, it starts with a `var` keyword followed by + /// and identifier and an optional type (shape specification) before the + /// initializer. + /// decl ::= var identifier [ type ] = expr + std::unique_ptr parseDeclaration() { + if (lexer.getCurToken() != tok_var) + return parseError("var", "to begin declaration"); + auto loc = lexer.getLastLocation(); + lexer.getNextToken(); // eat var + + if (lexer.getCurToken() != tok_identifier) + return parseError("identified", + "after 'var' declaration"); + std::string id(lexer.getId()); + lexer.getNextToken(); // eat id + + std::unique_ptr type; // Type is optional, it can be inferred + if (lexer.getCurToken() == '<') { + type = parseType(); + if (!type) + return nullptr; + } + + if (!type) + type = std::make_unique(); + lexer.consume(Token('=')); + auto expr = parseExpression(); + return std::make_unique(std::move(loc), std::move(id), + std::move(*type), std::move(expr)); + } + + /// Parse a block: a list of expression separated by semicolons and wrapped in + /// curly braces. + /// + /// block ::= { expression_list } + /// expression_list ::= block_expr ; expression_list + /// block_expr ::= decl | "return" | expr + std::unique_ptr parseBlock() { + if (lexer.getCurToken() != '{') + return parseError("{", "to begin block"); + lexer.consume(Token('{')); + + auto exprList = std::make_unique(); + + // Ignore empty expressions: swallow sequences of semicolons. + while (lexer.getCurToken() == ';') + lexer.consume(Token(';')); + + while (lexer.getCurToken() != '}' && lexer.getCurToken() != tok_eof) { + if (lexer.getCurToken() == tok_var) { + // Variable declaration + auto varDecl = parseDeclaration(); + if (!varDecl) + return nullptr; + exprList->push_back(std::move(varDecl)); + } else if (lexer.getCurToken() == tok_return) { + // Return statement + auto ret = parseReturn(); + if (!ret) + return nullptr; + exprList->push_back(std::move(ret)); + } else { + // General expression + auto expr = parseExpression(); + if (!expr) + return nullptr; + exprList->push_back(std::move(expr)); + } + // Ensure that elements are separated by a semicolon. + if (lexer.getCurToken() != ';') + return parseError(";", "after expression"); + + // Ignore empty expressions: swallow sequences of semicolons. + while (lexer.getCurToken() == ';') + lexer.consume(Token(';')); + } + + if (lexer.getCurToken() != '}') + return parseError("}", "to close block"); + + lexer.consume(Token('}')); + return exprList; + } + + /// prototype ::= def id '(' decl_list ')' + /// decl_list ::= identifier | identifier, decl_list + std::unique_ptr parsePrototype() { + auto loc = lexer.getLastLocation(); + + if (lexer.getCurToken() != tok_def) + return parseError("def", "in prototype"); + lexer.consume(tok_def); + + if (lexer.getCurToken() != tok_identifier) + return parseError("function name", "in prototype"); + + std::string fnName(lexer.getId()); + lexer.consume(tok_identifier); + + if (lexer.getCurToken() != '(') + return parseError("(", "in prototype"); + lexer.consume(Token('(')); + + std::vector> args; + if (lexer.getCurToken() != ')') { + do { + std::string name(lexer.getId()); + auto loc = lexer.getLastLocation(); + lexer.consume(tok_identifier); + auto decl = std::make_unique(std::move(loc), name); + args.push_back(std::move(decl)); + if (lexer.getCurToken() != ',') + break; + lexer.consume(Token(',')); + if (lexer.getCurToken() != tok_identifier) + return parseError( + "identifier", "after ',' in function parameter list"); + } while (true); + } + if (lexer.getCurToken() != ')') + return parseError(")", "to end function prototype"); + + // success. + lexer.consume(Token(')')); + return std::make_unique(std::move(loc), fnName, + std::move(args)); + } + + /// Parse a function definition, we expect a prototype initiated with the + /// `def` keyword, followed by a block containing a list of expressions. + /// + /// definition ::= prototype block + std::unique_ptr parseDefinition() { + auto proto = parsePrototype(); + if (!proto) + return nullptr; + + if (auto block = parseBlock()) + return std::make_unique(std::move(proto), std::move(block)); + return nullptr; + } + + /// Get the precedence of the pending binary operator token. + int getTokPrecedence() { + if (!isascii(lexer.getCurToken())) + return -1; + + // 1 is lowest precedence. + switch (static_cast(lexer.getCurToken())) { + case '-': + return 20; + case '+': + return 20; + case '*': + return 40; + default: + return -1; + } + } + + /// Helper function to signal errors while parsing, it takes an argument + /// indicating the expected token and another argument giving more context. + /// Location is retrieved from the lexer to enrich the error message. + template + std::unique_ptr parseError(T &&expected, U &&context = "") { + auto curToken = lexer.getCurToken(); + llvm::errs() << "Parse error (" << lexer.getLastLocation().line << ", " + << lexer.getLastLocation().col << "): expected '" << expected + << "' " << context << " but has Token " << curToken; + if (isprint(curToken)) + llvm::errs() << " '" << (char)curToken << "'"; + llvm::errs() << "\n"; + return nullptr; + } +}; + +} // namespace toy + +#endif // TOY_PARSER_H diff --git a/Ch4/include/toy/Passes.h b/Ch4/include/toy/Passes.h new file mode 100644 index 0000000..0eafa08 --- /dev/null +++ b/Ch4/include/toy/Passes.h @@ -0,0 +1,26 @@ +//===- Passes.h - Toy Passes Definition -----------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file exposes the entry points to create compiler passes for Toy. +// +//===----------------------------------------------------------------------===// + +#ifndef TOY_PASSES_H +#define TOY_PASSES_H + +#include + +namespace mlir { +class Pass; + +namespace toy { +std::unique_ptr createShapeInferencePass(); +} // namespace toy +} // namespace mlir + +#endif // TOY_PASSES_H diff --git a/Ch4/include/toy/ShapeInferenceInterface.h b/Ch4/include/toy/ShapeInferenceInterface.h new file mode 100644 index 0000000..cfe5a87 --- /dev/null +++ b/Ch4/include/toy/ShapeInferenceInterface.h @@ -0,0 +1,28 @@ +//===- ShapeInferenceInterface.h - Interface definitions for ShapeInference -=// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file contains the declarations of the shape inference interfaces defined +// in ShapeInferenceInterface.td. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_TUTORIAL_TOY_SHAPEINFERENCEINTERFACE_H_ +#define MLIR_TUTORIAL_TOY_SHAPEINFERENCEINTERFACE_H_ + +#include "mlir/IR/OpDefinition.h" + +namespace mlir { +namespace toy { + +/// Include the auto-generated declarations. +#include "toy/ShapeInferenceOpInterfaces.h.inc" + +} // namespace toy +} // namespace mlir + +#endif // MLIR_TUTORIAL_TOY_SHAPEINFERENCEINTERFACE_H_ diff --git a/Ch4/include/toy/ShapeInferenceInterface.td b/Ch4/include/toy/ShapeInferenceInterface.td new file mode 100644 index 0000000..2279015 --- /dev/null +++ b/Ch4/include/toy/ShapeInferenceInterface.td @@ -0,0 +1,30 @@ +//===- ShapeInferenceInterface.td - Shape Inference Interface -*- tablegen -==// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Defines the operations of the Shape Inference Op Interface. +// +//===----------------------------------------------------------------------===// + +#ifndef SHAPE_INFERENCE_INTERFACE +#define SHAPE_INFERENCE_INTERFACE + +include "mlir/IR/OpBase.td" + +def ShapeInferenceOpInterface : OpInterface<"ShapeInference"> { + let description = [{ + Interface to access a registered method to infer the return types for an + operation that can be used during type inference. + }]; + + let methods = [ + InterfaceMethod<"Infer and set the output shape for the current operation.", + "void", "inferShapes"> + ]; +} + +#endif // SHAPE_INFERENCE_INTERFACE diff --git a/Ch4/include/toy/ShapeInferenceOpInterfaces.cpp.inc b/Ch4/include/toy/ShapeInferenceOpInterfaces.cpp.inc new file mode 100644 index 0000000..a481d2e --- /dev/null +++ b/Ch4/include/toy/ShapeInferenceOpInterfaces.cpp.inc @@ -0,0 +1,12 @@ +/*===- TableGen'erated file -------------------------------------*- C++ -*-===*\ +|* *| +|* Interface Definitions *| +|* *| +|* Automatically generated file, do not edit! *| +|* *| +\*===----------------------------------------------------------------------===*/ + +/// Infer and set the output shape for the current operation. +void ShapeInference::inferShapes() { + return getImpl()->inferShapes(getImpl(), getOperation()); + } diff --git a/Ch4/include/toy/ShapeInferenceOpInterfaces.h.inc b/Ch4/include/toy/ShapeInferenceOpInterfaces.h.inc new file mode 100644 index 0000000..bb24654 --- /dev/null +++ b/Ch4/include/toy/ShapeInferenceOpInterfaces.h.inc @@ -0,0 +1,61 @@ +/*===- TableGen'erated file -------------------------------------*- C++ -*-===*\ +|* *| +|* Interface Declarations *| +|* *| +|* Automatically generated file, do not edit! *| +|* *| +\*===----------------------------------------------------------------------===*/ + +class ShapeInference; +namespace detail { +struct ShapeInferenceInterfaceTraits { + struct Concept { + /// The methods defined by the interface. + void (*inferShapes)(const Concept *impl, ::mlir::Operation *); + }; + template + class Model : public Concept { + public: + using Interface = ShapeInference; + Model() : Concept{inferShapes} {} + + static inline void inferShapes(const Concept *impl, ::mlir::Operation *tablegen_opaque_val); + }; + template + class FallbackModel : public Concept { + public: + using Interface = ShapeInference; + FallbackModel() : Concept{inferShapes} {} + + static inline void inferShapes(const Concept *impl, ::mlir::Operation *tablegen_opaque_val); + }; + template + class ExternalModel : public FallbackModel { + public: + using ConcreteEntity = ConcreteOp; + }; +};template +struct ShapeInferenceTrait; + +} // namespace detail +class ShapeInference : public ::mlir::OpInterface { +public: + using ::mlir::OpInterface::OpInterface; + template + struct Trait : public detail::ShapeInferenceTrait {}; + /// Infer and set the output shape for the current operation. + void inferShapes(); +}; +namespace detail { + template + struct ShapeInferenceTrait : public ::mlir::OpInterface::Trait { + }; +}// namespace detail +template +void detail::ShapeInferenceInterfaceTraits::Model::inferShapes(const Concept *impl, ::mlir::Operation *tablegen_opaque_val) { + return (llvm::cast(tablegen_opaque_val)).inferShapes(); +} +template +void detail::ShapeInferenceInterfaceTraits::FallbackModel::inferShapes(const Concept *impl, ::mlir::Operation *tablegen_opaque_val) { + return static_cast(impl)->inferShapes(tablegen_opaque_val); +} diff --git a/Ch4/include/toy/run.sh b/Ch4/include/toy/run.sh new file mode 100644 index 0000000..b9d18af --- /dev/null +++ b/Ch4/include/toy/run.sh @@ -0,0 +1,7 @@ +mlir-tblgen-18 -gen-op-decls -I /usr/lib/llvm-18/include toy/Ops.td > toy/Ops.h.inc +mlir-tblgen-18 -gen-op-defs -I /usr/lib/llvm-18/include toy/Ops.td > toy/Ops.cpp.inc +mlir-tblgen-18 -gen-dialect-decls -I /usr/lib/llvm-18/include toy/Ops.td > toy/Dialect.h.inc +mlir-tblgen-18 -gen-dialect-defs -I /usr/lib/llvm-18/include toy/Ops.td > toy/Dialect.cpp.inc + +mlir-tblgen-18 -gen-op-interface-decls -I /usr/lib/llvm-18/include toy/ShapeInferenceInterface.td > toy/ShapeInferenceOpInterfaces.h.inc +mlir-tblgen-18 -gen-op-interface-defs -I /usr/lib/llvm-18/include toy/ShapeInferenceInterface.td > toy/ShapeInferenceOpInterfaces.cpp.inc diff --git a/Ch4/mlir/Dialect.cpp b/Ch4/mlir/Dialect.cpp new file mode 100644 index 0000000..86a0e1d --- /dev/null +++ b/Ch4/mlir/Dialect.cpp @@ -0,0 +1,444 @@ +//===- Dialect.cpp - Toy IR Dialect registration in MLIR ------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements the dialect for the Toy IR: custom type parsing and +// operation verification. +// +//===----------------------------------------------------------------------===// + +#include "toy/Dialect.h" + +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Location.h" +#include "mlir/IR/OpImplementation.h" +#include "mlir/IR/OperationSupport.h" +#include "mlir/IR/ValueRange.h" +#include "mlir/Interfaces/CallInterfaces.h" +#include "mlir/Interfaces/FunctionImplementation.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/InliningUtils.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/Casting.h" +#include +#include +#include +#include + +using namespace mlir; +using namespace mlir::toy; + +#include "toy/Dialect.cpp.inc" + +//===----------------------------------------------------------------------===// +// ToyInlinerInterface +//===----------------------------------------------------------------------===// + +/// This class defines the interface for handling inlining with Toy +/// operations. +struct ToyInlinerInterface : public DialectInlinerInterface { + using DialectInlinerInterface::DialectInlinerInterface; + + //===--------------------------------------------------------------------===// + // Analysis Hooks + //===--------------------------------------------------------------------===// + + /// All call operations within toy can be inlined. + bool isLegalToInline(Operation *call, Operation *callable, + bool wouldBeCloned) const final { + return true; + } + + /// All operations within toy can be inlined. + bool isLegalToInline(Operation *, Region *, bool, IRMapping &) const final { + return true; + } + + // All functions within toy can be inlined. + bool isLegalToInline(Region *, Region *, bool, IRMapping &) const final { + return true; + } + + //===--------------------------------------------------------------------===// + // Transformation Hooks + //===--------------------------------------------------------------------===// + + /// Handle the given inlined terminator(toy.return) by replacing it with a new + /// operation as necessary. + void handleTerminator(Operation *op, ValueRange valuesToRepl) const final { + // Only "toy.return" needs to be handled here. + auto returnOp = cast(op); + + // Replace the values directly with the return operands. + assert(returnOp.getNumOperands() == valuesToRepl.size()); + for (const auto &it : llvm::enumerate(returnOp.getOperands())) + valuesToRepl[it.index()].replaceAllUsesWith(it.value()); + } + + /// Attempts to materialize a conversion for a type mismatch between a call + /// from this dialect, and a callable region. This method should generate an + /// operation that takes 'input' as the only operand, and produces a single + /// result of 'resultType'. If a conversion can not be generated, nullptr + /// should be returned. + Operation *materializeCallConversion(OpBuilder &builder, Value input, + Type resultType, + Location conversionLoc) const final { + return builder.create(conversionLoc, resultType, input); + } +}; + +//===----------------------------------------------------------------------===// +// ToyDialect +//===----------------------------------------------------------------------===// + +/// Dialect initialization, the instance will be owned by the context. This is +/// the point of registration of types and operations for the dialect. +void ToyDialect::initialize() { + addOperations< +#define GET_OP_LIST +#include "toy/Ops.cpp.inc" + >(); + addInterfaces(); +} + +//===----------------------------------------------------------------------===// +// Toy Operations +//===----------------------------------------------------------------------===// + +/// A generalized parser for binary operations. This parses the different forms +/// of 'printBinaryOp' below. +static mlir::ParseResult parseBinaryOp(mlir::OpAsmParser &parser, + mlir::OperationState &result) { + SmallVector operands; + SMLoc operandsLoc = parser.getCurrentLocation(); + Type type; + if (parser.parseOperandList(operands, /*requiredOperandCount=*/2) || + parser.parseOptionalAttrDict(result.attributes) || + parser.parseColonType(type)) + return mlir::failure(); + + // If the type is a function type, it contains the input and result types of + // this operation. + if (FunctionType funcType = llvm::dyn_cast(type)) { + if (parser.resolveOperands(operands, funcType.getInputs(), operandsLoc, + result.operands)) + return mlir::failure(); + result.addTypes(funcType.getResults()); + return mlir::success(); + } + + // Otherwise, the parsed type is the type of both operands and results. + if (parser.resolveOperands(operands, type, result.operands)) + return mlir::failure(); + result.addTypes(type); + return mlir::success(); +} + +/// A generalized printer for binary operations. It prints in two different +/// forms depending on if all of the types match. +static void printBinaryOp(mlir::OpAsmPrinter &printer, mlir::Operation *op) { + printer << " " << op->getOperands(); + printer.printOptionalAttrDict(op->getAttrs()); + printer << " : "; + + // If all of the types are the same, print the type directly. + Type resultType = *op->result_type_begin(); + if (llvm::all_of(op->getOperandTypes(), + [=](Type type) { return type == resultType; })) { + printer << resultType; + return; + } + + // Otherwise, print a functional type. + printer.printFunctionalType(op->getOperandTypes(), op->getResultTypes()); +} + +//===----------------------------------------------------------------------===// +// ConstantOp +//===----------------------------------------------------------------------===// + +/// Build a constant operation. +/// The builder is passed as an argument, so is the state that this method is +/// expected to fill in order to build the operation. +void ConstantOp::build(mlir::OpBuilder &builder, mlir::OperationState &state, + double value) { + auto dataType = RankedTensorType::get({}, builder.getF64Type()); + auto dataAttribute = DenseElementsAttr::get(dataType, value); + ConstantOp::build(builder, state, dataType, dataAttribute); +} + +/// The 'OpAsmPrinter' class provides a collection of methods for parsing +/// various punctuation, as well as attributes, operands, types, etc. Each of +/// these methods returns a `ParseResult`. This class is a wrapper around +/// `LogicalResult` that can be converted to a boolean `true` value on failure, +/// or `false` on success. This allows for easily chaining together a set of +/// parser rules. These rules are used to populate an `mlir::OperationState` +/// similarly to the `build` methods described above. +mlir::ParseResult ConstantOp::parse(mlir::OpAsmParser &parser, + mlir::OperationState &result) { + mlir::DenseElementsAttr value; + if (parser.parseOptionalAttrDict(result.attributes) || + parser.parseAttribute(value, "value", result.attributes)) + return failure(); + + result.addTypes(value.getType()); + return success(); +} + +/// The 'OpAsmPrinter' class is a stream that allows for formatting +/// strings, attributes, operands, types, etc. +void ConstantOp::print(mlir::OpAsmPrinter &printer) { + printer << " "; + printer.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{"value"}); + printer << getValue(); +} + +/// Verifier for the constant operation. This corresponds to the +/// `let hasVerifier = 1` in the op definition. +mlir::LogicalResult ConstantOp::verify() { + // If the return type of the constant is not an unranked tensor, the shape + // must match the shape of the attribute holding the data. + auto resultType = llvm::dyn_cast(getResult().getType()); + if (!resultType) + return success(); + + // Check that the rank of the attribute type matches the rank of the constant + // result type. + auto attrType = llvm::cast(getValue().getType()); + if (attrType.getRank() != resultType.getRank()) { + return emitOpError("return type must match the one of the attached value " + "attribute: ") + << attrType.getRank() << " != " << resultType.getRank(); + } + + // Check that each of the dimensions match between the two types. + for (int dim = 0, dimE = attrType.getRank(); dim < dimE; ++dim) { + if (attrType.getShape()[dim] != resultType.getShape()[dim]) { + return emitOpError( + "return type shape mismatches its attribute at dimension ") + << dim << ": " << attrType.getShape()[dim] + << " != " << resultType.getShape()[dim]; + } + } + return mlir::success(); +} + +//===----------------------------------------------------------------------===// +// AddOp +//===----------------------------------------------------------------------===// + +void AddOp::build(mlir::OpBuilder &builder, mlir::OperationState &state, + mlir::Value lhs, mlir::Value rhs) { + state.addTypes(UnrankedTensorType::get(builder.getF64Type())); + state.addOperands({lhs, rhs}); +} + +mlir::ParseResult AddOp::parse(mlir::OpAsmParser &parser, + mlir::OperationState &result) { + return parseBinaryOp(parser, result); +} + +void AddOp::print(mlir::OpAsmPrinter &p) { printBinaryOp(p, *this); } + +/// Infer the output shape of the AddOp, this is required by the shape inference +/// interface. +void AddOp::inferShapes() { getResult().setType(getLhs().getType()); } + +//===----------------------------------------------------------------------===// +// CastOp +//===----------------------------------------------------------------------===// + +/// Infer the output shape of the CastOp, this is required by the shape +/// inference interface. +void CastOp::inferShapes() { getResult().setType(getInput().getType()); } + +/// Returns true if the given set of input and result types are compatible with +/// this cast operation. This is required by the `CastOpInterface` to verify +/// this operation and provide other additional utilities. +bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { + if (inputs.size() != 1 || outputs.size() != 1) + return false; + // The inputs must be Tensors with the same element type. + TensorType input = llvm::dyn_cast(inputs.front()); + TensorType output = llvm::dyn_cast(outputs.front()); + if (!input || !output || input.getElementType() != output.getElementType()) + return false; + // The shape is required to match if both types are ranked. + return !input.hasRank() || !output.hasRank() || input == output; +} + +//===----------------------------------------------------------------------===// +// FuncOp +//===----------------------------------------------------------------------===// + +void FuncOp::build(mlir::OpBuilder &builder, mlir::OperationState &state, + llvm::StringRef name, mlir::FunctionType type, + llvm::ArrayRef attrs) { + // FunctionOpInterface provides a convenient `build` method that will populate + // the state of our FuncOp, and create an entry block. + buildWithEntryBlock(builder, state, name, type, attrs, type.getInputs()); +} + +mlir::ParseResult FuncOp::parse(mlir::OpAsmParser &parser, + mlir::OperationState &result) { + // Dispatch to the FunctionOpInterface provided utility method that parses the + // function operation. + auto buildFuncType = + [](mlir::Builder &builder, llvm::ArrayRef argTypes, + llvm::ArrayRef results, + mlir::function_interface_impl::VariadicFlag, + std::string &) { return builder.getFunctionType(argTypes, results); }; + + return mlir::function_interface_impl::parseFunctionOp( + parser, result, /*allowVariadic=*/false, + getFunctionTypeAttrName(result.name), buildFuncType, + getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name)); +} + +void FuncOp::print(mlir::OpAsmPrinter &p) { + // Dispatch to the FunctionOpInterface provided utility method that prints the + // function operation. + mlir::function_interface_impl::printFunctionOp( + p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(), + getArgAttrsAttrName(), getResAttrsAttrName()); +} + +//===----------------------------------------------------------------------===// +// GenericCallOp +//===----------------------------------------------------------------------===// + +void GenericCallOp::build(mlir::OpBuilder &builder, mlir::OperationState &state, + StringRef callee, ArrayRef arguments) { + // Generic call always returns an unranked Tensor initially. + state.addTypes(UnrankedTensorType::get(builder.getF64Type())); + state.addOperands(arguments); + state.addAttribute("callee", + mlir::SymbolRefAttr::get(builder.getContext(), callee)); +} + +/// Return the callee of the generic call operation, this is required by the +/// call interface. +CallInterfaceCallable GenericCallOp::getCallableForCallee() { + return (*this)->getAttrOfType("callee"); +} + +/// Set the callee for the generic call operation, this is required by the call +/// interface. +void GenericCallOp::setCalleeFromCallable(CallInterfaceCallable callee) { + (*this)->setAttr("callee", callee.get()); +} + +/// Get the argument operands to the called function, this is required by the +/// call interface. +Operation::operand_range GenericCallOp::getArgOperands() { return getInputs(); } + +/// Get the argument operands to the called function as a mutable range, this is +/// required by the call interface. +MutableOperandRange GenericCallOp::getArgOperandsMutable() { + return getInputsMutable(); +} + +//===----------------------------------------------------------------------===// +// MulOp +//===----------------------------------------------------------------------===// + +void MulOp::build(mlir::OpBuilder &builder, mlir::OperationState &state, + mlir::Value lhs, mlir::Value rhs) { + state.addTypes(UnrankedTensorType::get(builder.getF64Type())); + state.addOperands({lhs, rhs}); +} + +mlir::ParseResult MulOp::parse(mlir::OpAsmParser &parser, + mlir::OperationState &result) { + return parseBinaryOp(parser, result); +} + +void MulOp::print(mlir::OpAsmPrinter &p) { printBinaryOp(p, *this); } + +/// Infer the output shape of the MulOp, this is required by the shape inference +/// interface. +void MulOp::inferShapes() { getResult().setType(getLhs().getType()); } + +//===----------------------------------------------------------------------===// +// ReturnOp +//===----------------------------------------------------------------------===// + +mlir::LogicalResult ReturnOp::verify() { + // We know that the parent operation is a function, because of the 'HasParent' + // trait attached to the operation definition. + auto function = cast((*this)->getParentOp()); + + /// ReturnOps can only have a single optional operand. + if (getNumOperands() > 1) + return emitOpError() << "expects at most 1 return operand"; + + // The operand number and types must match the function signature. + const auto &results = function.getFunctionType().getResults(); + if (getNumOperands() != results.size()) + return emitOpError() << "does not return the same number of values (" + << getNumOperands() << ") as the enclosing function (" + << results.size() << ")"; + + // If the operation does not have an input, we are done. + if (!hasOperand()) + return mlir::success(); + + auto inputType = *operand_type_begin(); + auto resultType = results.front(); + + // Check that the result type of the function matches the operand type. + if (inputType == resultType || llvm::isa(inputType) || + llvm::isa(resultType)) + return mlir::success(); + + return emitError() << "type of return operand (" << inputType + << ") doesn't match function result type (" << resultType + << ")"; +} + +//===----------------------------------------------------------------------===// +// TransposeOp +//===----------------------------------------------------------------------===// + +void TransposeOp::build(mlir::OpBuilder &builder, mlir::OperationState &state, + mlir::Value value) { + state.addTypes(UnrankedTensorType::get(builder.getF64Type())); + state.addOperands(value); +} + +void TransposeOp::inferShapes() { + auto arrayTy = llvm::cast(getOperand().getType()); + SmallVector dims(llvm::reverse(arrayTy.getShape())); + getResult().setType(RankedTensorType::get(dims, arrayTy.getElementType())); +} + +mlir::LogicalResult TransposeOp::verify() { + auto inputType = llvm::dyn_cast(getOperand().getType()); + auto resultType = llvm::dyn_cast(getType()); + if (!inputType || !resultType) + return mlir::success(); + + auto inputShape = inputType.getShape(); + if (!std::equal(inputShape.begin(), inputShape.end(), + resultType.getShape().rbegin())) { + return emitError() + << "expected result shape to be a transpose of the input"; + } + return mlir::success(); +} + +//===----------------------------------------------------------------------===// +// TableGen'd op method definitions +//===----------------------------------------------------------------------===// + +#define GET_OP_CLASSES +#include "toy/Ops.cpp.inc" diff --git a/Ch4/mlir/MLIRGen.cpp b/Ch4/mlir/MLIRGen.cpp new file mode 100644 index 0000000..6c5474a --- /dev/null +++ b/Ch4/mlir/MLIRGen.cpp @@ -0,0 +1,461 @@ +//===- MLIRGen.cpp - MLIR Generation from a Toy AST -----------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements a simple IR generation targeting MLIR from a Module AST +// for the Toy language. +// +//===----------------------------------------------------------------------===// + +#include "toy/MLIRGen.h" +#include "mlir/IR/Block.h" +#include "mlir/IR/Diagnostics.h" +#include "mlir/IR/Value.h" +#include "mlir/Support/LogicalResult.h" +#include "toy/AST.h" +#include "toy/Dialect.h" + +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/Verifier.h" +#include "toy/Lexer.h" + +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/ScopedHashTable.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/ADT/Twine.h" +#include +#include +#include +#include +#include +#include + +using namespace mlir::toy; +using namespace toy; + +using llvm::ArrayRef; +using llvm::cast; +using llvm::dyn_cast; +using llvm::isa; +using llvm::ScopedHashTableScope; +using llvm::SmallVector; +using llvm::StringRef; +using llvm::Twine; + +namespace { + +/// Implementation of a simple MLIR emission from the Toy AST. +/// +/// This will emit operations that are specific to the Toy language, preserving +/// the semantics of the language and (hopefully) allow to perform accurate +/// analysis and transformation based on these high level semantics. +class MLIRGenImpl { +public: + MLIRGenImpl(mlir::MLIRContext &context) : builder(&context) {} + + /// Public API: convert the AST for a Toy module (source file) to an MLIR + /// Module operation. + mlir::ModuleOp mlirGen(ModuleAST &moduleAST) { + // We create an empty MLIR module and codegen functions one at a time and + // add them to the module. + theModule = mlir::ModuleOp::create(builder.getUnknownLoc()); + + for (FunctionAST &f : moduleAST) + mlirGen(f); + + // Verify the module after we have finished constructing it, this will check + // the structural properties of the IR and invoke any specific verifiers we + // have on the Toy operations. + if (failed(mlir::verify(theModule))) { + theModule.emitError("module verification error"); + return nullptr; + } + + return theModule; + } + +private: + /// A "module" matches a Toy source file: containing a list of functions. + mlir::ModuleOp theModule; + + /// The builder is a helper class to create IR inside a function. The builder + /// is stateful, in particular it keeps an "insertion point": this is where + /// the next operations will be introduced. + mlir::OpBuilder builder; + + /// The symbol table maps a variable name to a value in the current scope. + /// Entering a function creates a new scope, and the function arguments are + /// added to the mapping. When the processing of a function is terminated, the + /// scope is destroyed and the mappings created in this scope are dropped. + llvm::ScopedHashTable symbolTable; + + /// Helper conversion for a Toy AST location to an MLIR location. + mlir::Location loc(const Location &loc) { + return mlir::FileLineColLoc::get(builder.getStringAttr(*loc.file), loc.line, + loc.col); + } + + /// Declare a variable in the current scope, return success if the variable + /// wasn't declared yet. + mlir::LogicalResult declare(llvm::StringRef var, mlir::Value value) { + if (symbolTable.count(var)) + return mlir::failure(); + symbolTable.insert(var, value); + return mlir::success(); + } + + /// Create the prototype for an MLIR function with as many arguments as the + /// provided Toy AST prototype. + mlir::toy::FuncOp mlirGen(PrototypeAST &proto) { + auto location = loc(proto.loc()); + + // This is a generic function, the return type will be inferred later. + // Arguments type are uniformly unranked tensors. + llvm::SmallVector argTypes(proto.getArgs().size(), + getType(VarType{})); + auto funcType = builder.getFunctionType(argTypes, std::nullopt); + return builder.create(location, proto.getName(), + funcType); + } + + /// Emit a new function and add it to the MLIR module. + mlir::toy::FuncOp mlirGen(FunctionAST &funcAST) { + // Create a scope in the symbol table to hold variable declarations. + ScopedHashTableScope varScope(symbolTable); + + // Create an MLIR function for the given prototype. + builder.setInsertionPointToEnd(theModule.getBody()); + mlir::toy::FuncOp function = mlirGen(*funcAST.getProto()); + if (!function) + return nullptr; + + // Let's start the body of the function now! + mlir::Block &entryBlock = function.front(); + auto protoArgs = funcAST.getProto()->getArgs(); + + // Declare all the function arguments in the symbol table. + for (const auto nameValue : + llvm::zip(protoArgs, entryBlock.getArguments())) { + if (failed(declare(std::get<0>(nameValue)->getName(), + std::get<1>(nameValue)))) + return nullptr; + } + + // Set the insertion point in the builder to the beginning of the function + // body, it will be used throughout the codegen to create operations in this + // function. + builder.setInsertionPointToStart(&entryBlock); + + // Emit the body of the function. + if (mlir::failed(mlirGen(*funcAST.getBody()))) { + function.erase(); + return nullptr; + } + + // Implicitly return void if no return statement was emitted. + // FIXME: we may fix the parser instead to always return the last expression + // (this would possibly help the REPL case later) + ReturnOp returnOp; + if (!entryBlock.empty()) + returnOp = dyn_cast(entryBlock.back()); + if (!returnOp) { + builder.create(loc(funcAST.getProto()->loc())); + } else if (returnOp.hasOperand()) { + // Otherwise, if this return operation has an operand then add a result to + // the function. + function.setType(builder.getFunctionType( + function.getFunctionType().getInputs(), getType(VarType{}))); + } + + // If this function isn't main, then set the visibility to private. + if (funcAST.getProto()->getName() != "main") + function.setPrivate(); + + return function; + } + + /// Emit a binary operation + mlir::Value mlirGen(BinaryExprAST &binop) { + // First emit the operations for each side of the operation before emitting + // the operation itself. For example if the expression is `a + foo(a)` + // 1) First it will visiting the LHS, which will return a reference to the + // value holding `a`. This value should have been emitted at declaration + // time and registered in the symbol table, so nothing would be + // codegen'd. If the value is not in the symbol table, an error has been + // emitted and nullptr is returned. + // 2) Then the RHS is visited (recursively) and a call to `foo` is emitted + // and the result value is returned. If an error occurs we get a nullptr + // and propagate. + // + mlir::Value lhs = mlirGen(*binop.getLHS()); + if (!lhs) + return nullptr; + mlir::Value rhs = mlirGen(*binop.getRHS()); + if (!rhs) + return nullptr; + auto location = loc(binop.loc()); + + // Derive the operation name from the binary operator. At the moment we only + // support '+' and '*'. + switch (binop.getOp()) { + case '+': + return builder.create(location, lhs, rhs); + case '*': + return builder.create(location, lhs, rhs); + } + + emitError(location, "invalid binary operator '") << binop.getOp() << "'"; + return nullptr; + } + + /// This is a reference to a variable in an expression. The variable is + /// expected to have been declared and so should have a value in the symbol + /// table, otherwise emit an error and return nullptr. + mlir::Value mlirGen(VariableExprAST &expr) { + if (auto variable = symbolTable.lookup(expr.getName())) + return variable; + + emitError(loc(expr.loc()), "error: unknown variable '") + << expr.getName() << "'"; + return nullptr; + } + + /// Emit a return operation. This will return failure if any generation fails. + mlir::LogicalResult mlirGen(ReturnExprAST &ret) { + auto location = loc(ret.loc()); + + // 'return' takes an optional expression, handle that case here. + mlir::Value expr = nullptr; + if (ret.getExpr().has_value()) { + if (!(expr = mlirGen(**ret.getExpr()))) + return mlir::failure(); + } + + // Otherwise, this return operation has zero operands. + builder.create(location, + expr ? ArrayRef(expr) : ArrayRef()); + return mlir::success(); + } + + /// Emit a literal/constant array. It will be emitted as a flattened array of + /// data in an Attribute attached to a `toy.constant` operation. + /// See documentation on [Attributes](LangRef.md#attributes) for more details. + /// Here is an excerpt: + /// + /// Attributes are the mechanism for specifying constant data in MLIR in + /// places where a variable is never allowed [...]. They consist of a name + /// and a concrete attribute value. The set of expected attributes, their + /// structure, and their interpretation are all contextually dependent on + /// what they are attached to. + /// + /// Example, the source level statement: + /// var a<2, 3> = [[1, 2, 3], [4, 5, 6]]; + /// will be converted to: + /// %0 = "toy.constant"() {value: dense, + /// [[1.000000e+00, 2.000000e+00, 3.000000e+00], + /// [4.000000e+00, 5.000000e+00, 6.000000e+00]]>} : () -> tensor<2x3xf64> + /// + mlir::Value mlirGen(LiteralExprAST &lit) { + auto type = getType(lit.getDims()); + + // The attribute is a vector with a floating point value per element + // (number) in the array, see `collectData()` below for more details. + std::vector data; + data.reserve(std::accumulate(lit.getDims().begin(), lit.getDims().end(), 1, + std::multiplies())); + collectData(lit, data); + + // The type of this attribute is tensor of 64-bit floating-point with the + // shape of the literal. + mlir::Type elementType = builder.getF64Type(); + auto dataType = mlir::RankedTensorType::get(lit.getDims(), elementType); + + // This is the actual attribute that holds the list of values for this + // tensor literal. + auto dataAttribute = + mlir::DenseElementsAttr::get(dataType, llvm::ArrayRef(data)); + + // Build the MLIR op `toy.constant`. This invokes the `ConstantOp::build` + // method. + return builder.create(loc(lit.loc()), type, dataAttribute); + } + + /// Recursive helper function to accumulate the data that compose an array + /// literal. It flattens the nested structure in the supplied vector. For + /// example with this array: + /// [[1, 2], [3, 4]] + /// we will generate: + /// [ 1, 2, 3, 4 ] + /// Individual numbers are represented as doubles. + /// Attributes are the way MLIR attaches constant to operations. + void collectData(ExprAST &expr, std::vector &data) { + if (auto *lit = dyn_cast(&expr)) { + for (auto &value : lit->getValues()) + collectData(*value, data); + return; + } + + assert(isa(expr) && "expected literal or number expr"); + data.push_back(cast(expr).getValue()); + } + + /// Emit a call expression. It emits specific operations for the `transpose` + /// builtin. Other identifiers are assumed to be user-defined functions. + mlir::Value mlirGen(CallExprAST &call) { + llvm::StringRef callee = call.getCallee(); + auto location = loc(call.loc()); + + // Codegen the operands first. + SmallVector operands; + for (auto &expr : call.getArgs()) { + auto arg = mlirGen(*expr); + if (!arg) + return nullptr; + operands.push_back(arg); + } + + // Builtin calls have their custom operation, meaning this is a + // straightforward emission. + if (callee == "transpose") { + if (call.getArgs().size() != 1) { + emitError(location, "MLIR codegen encountered an error: toy.transpose " + "does not accept multiple arguments"); + return nullptr; + } + return builder.create(location, operands[0]); + } + + // Otherwise this is a call to a user-defined function. Calls to + // user-defined functions are mapped to a custom call that takes the callee + // name as an attribute. + return builder.create(location, callee, operands); + } + + /// Emit a print expression. It emits specific operations for two builtins: + /// transpose(x) and print(x). + mlir::LogicalResult mlirGen(PrintExprAST &call) { + auto arg = mlirGen(*call.getArg()); + if (!arg) + return mlir::failure(); + + builder.create(loc(call.loc()), arg); + return mlir::success(); + } + + /// Emit a constant for a single number (FIXME: semantic? broadcast?) + mlir::Value mlirGen(NumberExprAST &num) { + return builder.create(loc(num.loc()), num.getValue()); + } + + /// Dispatch codegen for the right expression subclass using RTTI. + mlir::Value mlirGen(ExprAST &expr) { + switch (expr.getKind()) { + case toy::ExprAST::Expr_BinOp: + return mlirGen(cast(expr)); + case toy::ExprAST::Expr_Var: + return mlirGen(cast(expr)); + case toy::ExprAST::Expr_Literal: + return mlirGen(cast(expr)); + case toy::ExprAST::Expr_Call: + return mlirGen(cast(expr)); + case toy::ExprAST::Expr_Num: + return mlirGen(cast(expr)); + default: + emitError(loc(expr.loc())) + << "MLIR codegen encountered an unhandled expr kind '" + << Twine(expr.getKind()) << "'"; + return nullptr; + } + } + + /// Handle a variable declaration, we'll codegen the expression that forms the + /// initializer and record the value in the symbol table before returning it. + /// Future expressions will be able to reference this variable through symbol + /// table lookup. + mlir::Value mlirGen(VarDeclExprAST &vardecl) { + auto *init = vardecl.getInitVal(); + if (!init) { + emitError(loc(vardecl.loc()), + "missing initializer in variable declaration"); + return nullptr; + } + + mlir::Value value = mlirGen(*init); + if (!value) + return nullptr; + + // We have the initializer value, but in case the variable was declared + // with specific shape, we emit a "reshape" operation. It will get + // optimized out later as needed. + if (!vardecl.getType().shape.empty()) { + value = builder.create(loc(vardecl.loc()), + getType(vardecl.getType()), value); + } + + // Register the value in the symbol table. + if (failed(declare(vardecl.getName(), value))) + return nullptr; + return value; + } + + /// Codegen a list of expression, return failure if one of them hit an error. + mlir::LogicalResult mlirGen(ExprASTList &blockAST) { + ScopedHashTableScope varScope(symbolTable); + for (auto &expr : blockAST) { + // Specific handling for variable declarations, return statement, and + // print. These can only appear in block list and not in nested + // expressions. + if (auto *vardecl = dyn_cast(expr.get())) { + if (!mlirGen(*vardecl)) + return mlir::failure(); + continue; + } + if (auto *ret = dyn_cast(expr.get())) + return mlirGen(*ret); + if (auto *print = dyn_cast(expr.get())) { + if (mlir::failed(mlirGen(*print))) + return mlir::success(); + continue; + } + + // Generic expression dispatch codegen. + if (!mlirGen(*expr)) + return mlir::failure(); + } + return mlir::success(); + } + + /// Build a tensor type from a list of shape dimensions. + mlir::Type getType(ArrayRef shape) { + // If the shape is empty, then this type is unranked. + if (shape.empty()) + return mlir::UnrankedTensorType::get(builder.getF64Type()); + + // Otherwise, we use the given shape. + return mlir::RankedTensorType::get(shape, builder.getF64Type()); + } + + /// Build an MLIR type from a Toy AST variable type (forward to the generic + /// getType above). + mlir::Type getType(const VarType &type) { return getType(type.shape); } +}; + +} // namespace + +namespace toy { + +// The public API for codegen. +mlir::OwningOpRef mlirGen(mlir::MLIRContext &context, + ModuleAST &moduleAST) { + return MLIRGenImpl(context).mlirGen(moduleAST); +} + +} // namespace toy diff --git a/Ch4/mlir/ShapeInferencePass.cpp b/Ch4/mlir/ShapeInferencePass.cpp new file mode 100644 index 0000000..a9e995e --- /dev/null +++ b/Ch4/mlir/ShapeInferencePass.cpp @@ -0,0 +1,122 @@ +//===- ShapeInferencePass.cpp - Shape Inference ---------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements a Function level pass performing interprocedural +// propagation of array shapes through function specialization. +// +//===----------------------------------------------------------------------===// + +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/Types.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/TypeID.h" +#include "toy/Dialect.h" +#include "toy/Passes.h" +#include "toy/ShapeInferenceInterface.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallPtrSet.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/raw_ostream.h" +#include + +#define DEBUG_TYPE "shape-inference" + +using namespace mlir; +using namespace toy; + +/// Include the auto-generated definitions for the shape inference interfaces. +#include "toy/ShapeInferenceOpInterfaces.cpp.inc" + +namespace { +/// The ShapeInferencePass is a pass that performs intra-procedural +/// shape inference. +/// +/// Algorithm: +/// +/// 1) Build a worklist containing all the operations that return a +/// dynamically shaped tensor: these are the operations that need shape +/// inference. +/// 2) Iterate on the worklist: +/// a) find an operation to process: the next ready operation in the +/// worklist has all of its arguments non-generic, +/// b) if no operation is found, break out of the loop, +/// c) remove the operation from the worklist, +/// d) infer the shape of its output from the argument types. +/// 3) If the worklist is empty, the algorithm succeeded. +/// +struct ShapeInferencePass + : public mlir::PassWrapper> { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ShapeInferencePass) + + void runOnOperation() override { + auto f = getOperation(); + + // Populate the worklist with the operations that need shape inference: + // these are operations that return a dynamic shape. + llvm::SmallPtrSet opWorklist; + f.walk([&](mlir::Operation *op) { + if (returnsDynamicShape(op)) + opWorklist.insert(op); + }); + + // Iterate on the operations in the worklist until all operations have been + // inferred or no change happened (fix point). + while (!opWorklist.empty()) { + // Find the next operation ready for inference, that is an operation + // with all operands already resolved (non-generic). + auto nextop = llvm::find_if(opWorklist, allOperandsInferred); + if (nextop == opWorklist.end()) + break; + + Operation *op = *nextop; + opWorklist.erase(op); + + // Ask the operation to infer its output shapes. + LLVM_DEBUG(llvm::dbgs() << "Inferring shape for: " << *op << "\n"); + if (auto shapeOp = dyn_cast(op)) { + shapeOp.inferShapes(); + } else { + op->emitError("unable to infer shape of operation without shape " + "inference interface"); + return signalPassFailure(); + } + } + + // If the operation worklist isn't empty, this indicates a failure. + if (!opWorklist.empty()) { + f.emitError("Shape inference failed, ") + << opWorklist.size() << " operations couldn't be inferred\n"; + signalPassFailure(); + } + } + + /// A utility method that returns if the given operation has all of its + /// operands inferred. + static bool allOperandsInferred(Operation *op) { + return llvm::all_of(op->getOperandTypes(), [](Type operandType) { + return llvm::isa(operandType); + }); + } + + /// A utility method that returns if the given operation has a dynamically + /// shaped result. + static bool returnsDynamicShape(Operation *op) { + return llvm::any_of(op->getResultTypes(), [](Type resultType) { + return !llvm::isa(resultType); + }); + } +}; +} // namespace + +/// Create a Shape Inference pass. +std::unique_ptr mlir::toy::createShapeInferencePass() { + return std::make_unique(); +} diff --git a/Ch4/mlir/ToyCombine.cpp b/Ch4/mlir/ToyCombine.cpp new file mode 100644 index 0000000..3ce35c8 --- /dev/null +++ b/Ch4/mlir/ToyCombine.cpp @@ -0,0 +1,69 @@ +//===- ToyCombine.cpp - Toy High Level Optimizer --------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements a set of simple combiners for optimizing operations in +// the Toy dialect. +// +//===----------------------------------------------------------------------===// + +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/Value.h" +#include "mlir/Support/LogicalResult.h" +#include "toy/Dialect.h" +using namespace mlir; +using namespace toy; + +namespace { +/// Include the patterns defined in the Declarative Rewrite framework. +#include "ToyCombine.inc" +} // namespace + +/// This is an example of a c++ rewrite pattern for the TransposeOp. It +/// optimizes the following scenario: transpose(transpose(x)) -> x +struct SimplifyRedundantTranspose : public mlir::OpRewritePattern { + /// We register this pattern to match every toy.transpose in the IR. + /// The "benefit" is used by the framework to order the patterns and process + /// them in order of profitability. + SimplifyRedundantTranspose(mlir::MLIRContext *context) + : OpRewritePattern(context, /*benefit=*/1) {} + + /// This method attempts to match a pattern and rewrite it. The rewriter + /// argument is the orchestrator of the sequence of rewrites. The pattern is + /// expected to interact with it to perform any changes to the IR from here. + mlir::LogicalResult + matchAndRewrite(TransposeOp op, + mlir::PatternRewriter &rewriter) const override { + // Look through the input of the current transpose. + mlir::Value transposeInput = op.getOperand(); + TransposeOp transposeInputOp = transposeInput.getDefiningOp(); + + // Input defined by another transpose? If not, no match. + if (!transposeInputOp) + return failure(); + + // Otherwise, we have a redundant transpose. Use the rewriter. + rewriter.replaceOp(op, {transposeInputOp.getOperand()}); + return success(); + } +}; + +/// Register our patterns as "canonicalization" patterns on the TransposeOp so +/// that they can be picked up by the Canonicalization framework. +void TransposeOp::getCanonicalizationPatterns(RewritePatternSet &results, + MLIRContext *context) { + results.add(context); +} + +/// Register our patterns as "canonicalization" patterns on the ReshapeOp so +/// that they can be picked up by the Canonicalization framework. +void ReshapeOp::getCanonicalizationPatterns(RewritePatternSet &results, + MLIRContext *context) { + results.add(context); +} diff --git a/Ch4/mlir/ToyCombine.inc b/Ch4/mlir/ToyCombine.inc new file mode 100644 index 0000000..61c6203 --- /dev/null +++ b/Ch4/mlir/ToyCombine.inc @@ -0,0 +1,176 @@ +/*===- TableGen'erated file -------------------------------------*- C++ -*-===*\ +|* *| +|* Rewriters *| +|* *| +|* Automatically generated file, do not edit! *| +|* From: ToyCombine.td *| +|* *| +\*===----------------------------------------------------------------------===*/ + +/* Generated from: + ToyCombine.td:46 +*/ +struct FoldConstantReshapeOptPattern : public ::mlir::RewritePattern { + FoldConstantReshapeOptPattern(::mlir::MLIRContext *context) + : ::mlir::RewritePattern("toy.reshape", 2, context, {"toy.constant"}) {} + ::mlir::LogicalResult matchAndRewrite(::mlir::Operation *op0, + ::mlir::PatternRewriter &rewriter) const override { + // Variables for capturing values and attributes used while creating ops + ::mlir::DenseElementsAttr arg; + ::mlir::toy::ReshapeOp res; + ::llvm::SmallVector<::mlir::Operation *, 4> tblgen_ops; + + // Match + tblgen_ops.push_back(op0); + auto castedOp0 = ::llvm::dyn_cast<::mlir::toy::ReshapeOp>(op0); (void)castedOp0; + res = castedOp0; + { + auto *op1 = (*castedOp0.getODSOperands(0).begin()).getDefiningOp(); + if (!(op1)){ + return rewriter.notifyMatchFailure(castedOp0, [&](::mlir::Diagnostic &diag) { + diag << "There's no operation that defines operand 0 of castedOp0"; + }); + } + auto castedOp1 = ::llvm::dyn_cast<::mlir::toy::ConstantOp>(op1); (void)castedOp1; + if (!(castedOp1)){ + return rewriter.notifyMatchFailure(op1, [&](::mlir::Diagnostic &diag) { + diag << "castedOp1 is not ::mlir::toy::ConstantOp type"; + }); + } + { + auto tblgen_attr = op1->getAttrOfType<::mlir::DenseElementsAttr>("value");(void)tblgen_attr; + if (!(tblgen_attr)){ + return rewriter.notifyMatchFailure(op1, [&](::mlir::Diagnostic &diag) { + diag << "expected op 'toy.constant' to have attribute 'value' of type '::mlir::DenseElementsAttr'"; + }); + } + arg = tblgen_attr; + } + tblgen_ops.push_back(op1); + } + + // Rewrite + auto odsLoc = rewriter.getFusedLoc({tblgen_ops[0]->getLoc(), tblgen_ops[1]->getLoc()}); (void)odsLoc; + ::llvm::SmallVector<::mlir::Value, 4> tblgen_repl_values; + auto nativeVar_0 = arg.reshape(::llvm::cast((*res.getODSResults(0).begin()).getType())); (void)nativeVar_0; + ::mlir::toy::ConstantOp tblgen_ConstantOp_1; + { + ::llvm::SmallVector<::mlir::Value, 4> tblgen_values; (void)tblgen_values; + ::llvm::SmallVector<::mlir::NamedAttribute, 4> tblgen_attrs; (void)tblgen_attrs; + if (auto tmpAttr = nativeVar_0) { + tblgen_attrs.emplace_back(rewriter.getStringAttr("value"), tmpAttr); + } + ::llvm::SmallVector<::mlir::Type, 4> tblgen_types; (void)tblgen_types; + for (auto v: castedOp0.getODSResults(0)) { + tblgen_types.push_back(v.getType()); + } + tblgen_ConstantOp_1 = rewriter.create<::mlir::toy::ConstantOp>(odsLoc, tblgen_types, tblgen_values, tblgen_attrs); + } + + for (auto v: ::llvm::SmallVector<::mlir::Value, 4>{ tblgen_ConstantOp_1.getODSResults(0) }) { + tblgen_repl_values.push_back(v); + } + + rewriter.replaceOp(op0, tblgen_repl_values); + return ::mlir::success(); + }; +}; + +/* Generated from: + ToyCombine.td:59 +*/ +struct RedundantReshapeOptPattern : public ::mlir::RewritePattern { + RedundantReshapeOptPattern(::mlir::MLIRContext *context) + : ::mlir::RewritePattern("toy.reshape", 1, context, {}) {} + ::mlir::LogicalResult matchAndRewrite(::mlir::Operation *op0, + ::mlir::PatternRewriter &rewriter) const override { + // Variables for capturing values and attributes used while creating ops + ::mlir::Operation::operand_range arg(op0->getOperands()); + ::mlir::toy::ReshapeOp res; + ::llvm::SmallVector<::mlir::Operation *, 4> tblgen_ops; + + // Match + tblgen_ops.push_back(op0); + auto castedOp0 = ::llvm::dyn_cast<::mlir::toy::ReshapeOp>(op0); (void)castedOp0; + res = castedOp0; + arg = castedOp0.getODSOperands(0); + if (!(((*res.getODSResults(0).begin()).getType() == (*arg.begin()).getType()))){ + return rewriter.notifyMatchFailure(op0, [&](::mlir::Diagnostic &diag) { + diag << "entities 'res, arg' failed to satisfy constraint: ''"; + }); + } + + // Rewrite + auto odsLoc = rewriter.getFusedLoc({tblgen_ops[0]->getLoc()}); (void)odsLoc; + ::llvm::SmallVector<::mlir::Value, 4> tblgen_repl_values; + + for (auto v: ::llvm::SmallVector<::mlir::Value, 4>{ arg }) { + tblgen_repl_values.push_back(v); + } + + rewriter.replaceOp(op0, tblgen_repl_values); + return ::mlir::success(); + }; +}; + +/* Generated from: + ToyCombine.td:33 +*/ +struct ReshapeReshapeOptPattern : public ::mlir::RewritePattern { + ReshapeReshapeOptPattern(::mlir::MLIRContext *context) + : ::mlir::RewritePattern("toy.reshape", 2, context, {"toy.reshape"}) {} + ::mlir::LogicalResult matchAndRewrite(::mlir::Operation *op0, + ::mlir::PatternRewriter &rewriter) const override { + // Variables for capturing values and attributes used while creating ops + ::mlir::Operation::operand_range arg(op0->getOperands()); + ::llvm::SmallVector<::mlir::Operation *, 4> tblgen_ops; + + // Match + tblgen_ops.push_back(op0); + auto castedOp0 = ::llvm::dyn_cast<::mlir::toy::ReshapeOp>(op0); (void)castedOp0; + { + auto *op1 = (*castedOp0.getODSOperands(0).begin()).getDefiningOp(); + if (!(op1)){ + return rewriter.notifyMatchFailure(castedOp0, [&](::mlir::Diagnostic &diag) { + diag << "There's no operation that defines operand 0 of castedOp0"; + }); + } + auto castedOp1 = ::llvm::dyn_cast<::mlir::toy::ReshapeOp>(op1); (void)castedOp1; + if (!(castedOp1)){ + return rewriter.notifyMatchFailure(op1, [&](::mlir::Diagnostic &diag) { + diag << "castedOp1 is not ::mlir::toy::ReshapeOp type"; + }); + } + arg = castedOp1.getODSOperands(0); + tblgen_ops.push_back(op1); + } + + // Rewrite + auto odsLoc = rewriter.getFusedLoc({tblgen_ops[0]->getLoc(), tblgen_ops[1]->getLoc()}); (void)odsLoc; + ::llvm::SmallVector<::mlir::Value, 4> tblgen_repl_values; + ::mlir::toy::ReshapeOp tblgen_ReshapeOp_0; + { + ::llvm::SmallVector<::mlir::Value, 4> tblgen_values; (void)tblgen_values; + ::llvm::SmallVector<::mlir::NamedAttribute, 4> tblgen_attrs; (void)tblgen_attrs; + tblgen_values.push_back((*arg.begin())); + ::llvm::SmallVector<::mlir::Type, 4> tblgen_types; (void)tblgen_types; + for (auto v: castedOp0.getODSResults(0)) { + tblgen_types.push_back(v.getType()); + } + tblgen_ReshapeOp_0 = rewriter.create<::mlir::toy::ReshapeOp>(odsLoc, tblgen_types, tblgen_values, tblgen_attrs); + } + + for (auto v: ::llvm::SmallVector<::mlir::Value, 4>{ tblgen_ReshapeOp_0.getODSResults(0) }) { + tblgen_repl_values.push_back(v); + } + + rewriter.replaceOp(op0, tblgen_repl_values); + return ::mlir::success(); + }; +}; + +void LLVM_ATTRIBUTE_UNUSED populateWithGenerated(::mlir::RewritePatternSet &patterns) { + patterns.add(patterns.getContext()); + patterns.add(patterns.getContext()); + patterns.add(patterns.getContext()); +} diff --git a/Ch4/mlir/ToyCombine.td b/Ch4/mlir/ToyCombine.td new file mode 100644 index 0000000..11d7831 --- /dev/null +++ b/Ch4/mlir/ToyCombine.td @@ -0,0 +1,63 @@ +//===- ToyCombine.td - Pattern Match Optimizations for Toy -*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Defines language-specific pattern match optimizations for Toy using +// Declarative Rewrite Rules (DRR) specified using TableGen records. +// +//===----------------------------------------------------------------------===// + +#ifndef TOY_COMBINE +#define TOY_COMBINE + +include "mlir/IR/PatternBase.td" +include "toy/Ops.td" + +/// Note: The DRR definition used for defining patterns is shown below: +/// +/// class Pattern< +/// dag sourcePattern, list resultPatterns, +/// list additionalConstraints = [], +/// dag benefitsAdded = (addBenefit 0) +/// >; + +//===----------------------------------------------------------------------===// +// Basic Pattern-Match and Rewrite +//===----------------------------------------------------------------------===// + +// Reshape(Reshape(x)) = Reshape(x) +def ReshapeReshapeOptPattern : Pat<(ReshapeOp(ReshapeOp $arg)), + (ReshapeOp $arg)>; + +//===----------------------------------------------------------------------===// +// Pattern-Match and Rewrite using Native Code Call +//===----------------------------------------------------------------------===// + +// Native Code Calls may be used for more complex transformations using inline +// C++ and C++ helper functions. + +// Reshape(Constant(x)) = x' +def ReshapeConstant : + NativeCodeCall<"$0.reshape(::llvm::cast($1.getType()))">; +def FoldConstantReshapeOptPattern : Pat< + (ReshapeOp:$res (ConstantOp $arg)), + (ConstantOp (ReshapeConstant $arg, $res))>; + +//===----------------------------------------------------------------------===// +// Pattern-Match and Rewrite with Constraints +//===----------------------------------------------------------------------===// + +// DRR allows for constraint checking when the transformation is conditional +// on operand properties. + +// Reshape(x) = x, where input and output shapes are identical +def TypesAreIdentical : Constraint>; +def RedundantReshapeOptPattern : Pat< + (ReshapeOp:$res $arg), (replaceWithValue $arg), + [(TypesAreIdentical $res, $arg)]>; + +#endif // TOY_COMBINE diff --git a/Ch4/mlir/run.sh b/Ch4/mlir/run.sh new file mode 100644 index 0000000..f592fde --- /dev/null +++ b/Ch4/mlir/run.sh @@ -0,0 +1,2 @@ +mlir-tblgen-18 -gen-rewriters -I /usr/lib/llvm-18/include -I ../include ToyCombine.td > ToyCombine.inc + diff --git a/Ch4/parser/AST.cpp b/Ch4/parser/AST.cpp new file mode 100644 index 0000000..2546f2a --- /dev/null +++ b/Ch4/parser/AST.cpp @@ -0,0 +1,237 @@ +//===- AST.cpp - Helper for printing out the Toy AST ----------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements the AST dump for the Toy language. +// +//===----------------------------------------------------------------------===// + +#include "toy/AST.h" + +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/Twine.h" +#include "llvm/ADT/TypeSwitch.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/raw_ostream.h" +#include + +using namespace toy; + +namespace { + +// RAII helper to manage increasing/decreasing the indentation as we traverse +// the AST +struct Indent { + Indent(int &level) : level(level) { ++level; } + ~Indent() { --level; } + int &level; +}; + +/// Helper class that implement the AST tree traversal and print the nodes along +/// the way. The only data member is the current indentation level. +class ASTDumper { +public: + void dump(ModuleAST *node); + +private: + void dump(const VarType &type); + void dump(VarDeclExprAST *varDecl); + void dump(ExprAST *expr); + void dump(ExprASTList *exprList); + void dump(NumberExprAST *num); + void dump(LiteralExprAST *node); + void dump(VariableExprAST *node); + void dump(ReturnExprAST *node); + void dump(BinaryExprAST *node); + void dump(CallExprAST *node); + void dump(PrintExprAST *node); + void dump(PrototypeAST *node); + void dump(FunctionAST *node); + + // Actually print spaces matching the current indentation level + void indent() { + for (int i = 0; i < curIndent; i++) + llvm::errs() << " "; + } + int curIndent = 0; +}; + +} // namespace + +/// Return a formatted string for the location of any node +template +static std::string loc(T *node) { + const auto &loc = node->loc(); + return (llvm::Twine("@") + *loc.file + ":" + llvm::Twine(loc.line) + ":" + + llvm::Twine(loc.col)) + .str(); +} + +// Helper Macro to bump the indentation level and print the leading spaces for +// the current indentations +#define INDENT() \ + Indent level_(curIndent); \ + indent(); + +/// Dispatch to a generic expressions to the appropriate subclass using RTTI +void ASTDumper::dump(ExprAST *expr) { + llvm::TypeSwitch(expr) + .Case( + [&](auto *node) { this->dump(node); }) + .Default([&](ExprAST *) { + // No match, fallback to a generic message + INDENT(); + llvm::errs() << "getKind() << ">\n"; + }); +} + +/// A variable declaration is printing the variable name, the type, and then +/// recurse in the initializer value. +void ASTDumper::dump(VarDeclExprAST *varDecl) { + INDENT(); + llvm::errs() << "VarDecl " << varDecl->getName(); + dump(varDecl->getType()); + llvm::errs() << " " << loc(varDecl) << "\n"; + dump(varDecl->getInitVal()); +} + +/// A "block", or a list of expression +void ASTDumper::dump(ExprASTList *exprList) { + INDENT(); + llvm::errs() << "Block {\n"; + for (auto &expr : *exprList) + dump(expr.get()); + indent(); + llvm::errs() << "} // Block\n"; +} + +/// A literal number, just print the value. +void ASTDumper::dump(NumberExprAST *num) { + INDENT(); + llvm::errs() << num->getValue() << " " << loc(num) << "\n"; +} + +/// Helper to print recursively a literal. This handles nested array like: +/// [ [ 1, 2 ], [ 3, 4 ] ] +/// We print out such array with the dimensions spelled out at every level: +/// <2,2>[<2>[ 1, 2 ], <2>[ 3, 4 ] ] +void printLitHelper(ExprAST *litOrNum) { + // Inside a literal expression we can have either a number or another literal + if (auto *num = llvm::dyn_cast(litOrNum)) { + llvm::errs() << num->getValue(); + return; + } + auto *literal = llvm::cast(litOrNum); + + // Print the dimension for this literal first + llvm::errs() << "<"; + llvm::interleaveComma(literal->getDims(), llvm::errs()); + llvm::errs() << ">"; + + // Now print the content, recursing on every element of the list + llvm::errs() << "[ "; + llvm::interleaveComma(literal->getValues(), llvm::errs(), + [&](auto &elt) { printLitHelper(elt.get()); }); + llvm::errs() << "]"; +} + +/// Print a literal, see the recursive helper above for the implementation. +void ASTDumper::dump(LiteralExprAST *node) { + INDENT(); + llvm::errs() << "Literal: "; + printLitHelper(node); + llvm::errs() << " " << loc(node) << "\n"; +} + +/// Print a variable reference (just a name). +void ASTDumper::dump(VariableExprAST *node) { + INDENT(); + llvm::errs() << "var: " << node->getName() << " " << loc(node) << "\n"; +} + +/// Return statement print the return and its (optional) argument. +void ASTDumper::dump(ReturnExprAST *node) { + INDENT(); + llvm::errs() << "Return\n"; + if (node->getExpr().has_value()) + return dump(*node->getExpr()); + { + INDENT(); + llvm::errs() << "(void)\n"; + } +} + +/// Print a binary operation, first the operator, then recurse into LHS and RHS. +void ASTDumper::dump(BinaryExprAST *node) { + INDENT(); + llvm::errs() << "BinOp: " << node->getOp() << " " << loc(node) << "\n"; + dump(node->getLHS()); + dump(node->getRHS()); +} + +/// Print a call expression, first the callee name and the list of args by +/// recursing into each individual argument. +void ASTDumper::dump(CallExprAST *node) { + INDENT(); + llvm::errs() << "Call '" << node->getCallee() << "' [ " << loc(node) << "\n"; + for (auto &arg : node->getArgs()) + dump(arg.get()); + indent(); + llvm::errs() << "]\n"; +} + +/// Print a builtin print call, first the builtin name and then the argument. +void ASTDumper::dump(PrintExprAST *node) { + INDENT(); + llvm::errs() << "Print [ " << loc(node) << "\n"; + dump(node->getArg()); + indent(); + llvm::errs() << "]\n"; +} + +/// Print type: only the shape is printed in between '<' and '>' +void ASTDumper::dump(const VarType &type) { + llvm::errs() << "<"; + llvm::interleaveComma(type.shape, llvm::errs()); + llvm::errs() << ">"; +} + +/// Print a function prototype, first the function name, and then the list of +/// parameters names. +void ASTDumper::dump(PrototypeAST *node) { + INDENT(); + llvm::errs() << "Proto '" << node->getName() << "' " << loc(node) << "\n"; + indent(); + llvm::errs() << "Params: ["; + llvm::interleaveComma(node->getArgs(), llvm::errs(), + [](auto &arg) { llvm::errs() << arg->getName(); }); + llvm::errs() << "]\n"; +} + +/// Print a function, first the prototype and then the body. +void ASTDumper::dump(FunctionAST *node) { + INDENT(); + llvm::errs() << "Function \n"; + dump(node->getProto()); + dump(node->getBody()); +} + +/// Print a module, actually loop over the functions and print them in sequence. +void ASTDumper::dump(ModuleAST *node) { + INDENT(); + llvm::errs() << "Module:\n"; + for (auto &f : *node) + dump(&f); +} + +namespace toy { + +// Public API +void dump(ModuleAST &module) { ASTDumper().dump(&module); } + +} // namespace toy diff --git a/Ch4/toyc.cpp b/Ch4/toyc.cpp new file mode 100644 index 0000000..a1534ed --- /dev/null +++ b/Ch4/toyc.cpp @@ -0,0 +1,179 @@ +//===- toyc.cpp - The Toy Compiler ----------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements the entry point for the Toy compiler. +// +//===----------------------------------------------------------------------===// + +#include "mlir/IR/Diagnostics.h" +#include "mlir/Support/LogicalResult.h" +#include "toy/AST.h" +#include "toy/Dialect.h" +#include "toy/Lexer.h" +#include "toy/MLIRGen.h" +#include "toy/Parser.h" +#include "toy/Passes.h" + +#include "mlir/IR/AsmState.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/Verifier.h" +#include "mlir/Parser/Parser.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Transforms/Passes.h" + +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/ErrorOr.h" +#include "llvm/Support/MemoryBuffer.h" +#include "llvm/Support/SourceMgr.h" +#include "llvm/Support/raw_ostream.h" +#include +#include +#include +#include + +using namespace toy; +namespace cl = llvm::cl; + +static cl::opt inputFilename(cl::Positional, + cl::desc(""), + cl::init("-"), + cl::value_desc("filename")); + +namespace { +enum InputType { Toy, MLIR }; +} // namespace +static cl::opt inputType( + "x", cl::init(Toy), cl::desc("Decided the kind of output desired"), + cl::values(clEnumValN(Toy, "toy", "load the input file as a Toy source.")), + cl::values(clEnumValN(MLIR, "mlir", + "load the input file as an MLIR file"))); + +namespace { +enum Action { None, DumpAST, DumpMLIR }; +} // namespace +static cl::opt emitAction( + "emit", cl::desc("Select the kind of output desired"), + cl::values(clEnumValN(DumpAST, "ast", "output the AST dump")), + cl::values(clEnumValN(DumpMLIR, "mlir", "output the MLIR dump"))); + +static cl::opt enableOpt("opt", cl::desc("Enable optimizations")); + +/// Returns a Toy AST resulting from parsing the file or a nullptr on error. +std::unique_ptr parseInputFile(llvm::StringRef filename) { + llvm::ErrorOr> fileOrErr = + llvm::MemoryBuffer::getFileOrSTDIN(filename); + if (std::error_code ec = fileOrErr.getError()) { + llvm::errs() << "Could not open input file: " << ec.message() << "\n"; + return nullptr; + } + auto buffer = fileOrErr.get()->getBuffer(); + LexerBuffer lexer(buffer.begin(), buffer.end(), std::string(filename)); + Parser parser(lexer); + return parser.parseModule(); +} + +int loadMLIR(llvm::SourceMgr &sourceMgr, mlir::MLIRContext &context, + mlir::OwningOpRef &module) { + // Handle '.toy' input to the compiler. + if (inputType != InputType::MLIR && + !llvm::StringRef(inputFilename).ends_with(".mlir")) { + auto moduleAST = parseInputFile(inputFilename); + if (!moduleAST) + return 6; + module = mlirGen(context, *moduleAST); + return !module ? 1 : 0; + } + + // Otherwise, the input is '.mlir'. + llvm::ErrorOr> fileOrErr = + llvm::MemoryBuffer::getFileOrSTDIN(inputFilename); + if (std::error_code ec = fileOrErr.getError()) { + llvm::errs() << "Could not open input file: " << ec.message() << "\n"; + return -1; + } + + // Parse the input mlir. + sourceMgr.AddNewSourceBuffer(std::move(*fileOrErr), llvm::SMLoc()); + module = mlir::parseSourceFile(sourceMgr, &context); + if (!module) { + llvm::errs() << "Error can't load file " << inputFilename << "\n"; + return 3; + } + return 0; +} + +int dumpMLIR() { + mlir::MLIRContext context; + // Load our Dialect in this MLIR Context. + context.getOrLoadDialect(); + + mlir::OwningOpRef module; + llvm::SourceMgr sourceMgr; + mlir::SourceMgrDiagnosticHandler sourceMgrHandler(sourceMgr, &context); + if (int error = loadMLIR(sourceMgr, context, module)) + return error; + + if (enableOpt) { + mlir::PassManager pm(module.get()->getName()); + // Apply any generic pass manager command line options and run the pipeline. + if (mlir::failed(mlir::applyPassManagerCLOptions(pm))) + return 4; + + // Inline all functions into main and then delete them. + pm.addPass(mlir::createInlinerPass()); + + // Now that there is only one function, we can infer the shapes of each of + // the operations. + mlir::OpPassManager &optPM = pm.nest(); + optPM.addPass(mlir::toy::createShapeInferencePass()); + optPM.addPass(mlir::createCanonicalizerPass()); + optPM.addPass(mlir::createCSEPass()); + + if (mlir::failed(pm.run(*module))) + return 4; + } + + module->dump(); + return 0; +} + +int dumpAST() { + if (inputType == InputType::MLIR) { + llvm::errs() << "Can't dump a Toy AST when the input is MLIR\n"; + return 5; + } + + auto moduleAST = parseInputFile(inputFilename); + if (!moduleAST) + return 1; + + dump(*moduleAST); + return 0; +} + +int main(int argc, char **argv) { + // Register any command line options. + mlir::registerAsmPrinterCLOptions(); + mlir::registerMLIRContextCLOptions(); + mlir::registerPassManagerCLOptions(); + + cl::ParseCommandLineOptions(argc, argv, "toy compiler\n"); + + switch (emitAction) { + case Action::DumpAST: + return dumpAST(); + case Action::DumpMLIR: + return dumpMLIR(); + default: + llvm::errs() << "No action specified (parsing only?), use -emit=\n"; + } + + return 0; +} diff --git a/Ch5/CMakeLists.txt b/Ch5/CMakeLists.txt new file mode 100644 index 0000000..43ce7ac --- /dev/null +++ b/Ch5/CMakeLists.txt @@ -0,0 +1,45 @@ +# For a better template to copy, see examples/standalone +include_directories(include) +add_subdirectory(include) + +set(LLVM_LINK_COMPONENTS + Support + ) + +# set(LLVM_TARGET_DEFINITIONS mlir/ToyCombine.td) +# mlir_tablegen(ToyCombine.inc -gen-rewriters) +# add_public_tablegen_target(ToyCh5CombineIncGen) + +add_executable(toyc-ch5 + toyc.cpp + parser/AST.cpp + mlir/MLIRGen.cpp + mlir/Dialect.cpp + mlir/LowerToAffineLoops.cpp + mlir/ShapeInferencePass.cpp + mlir/ToyCombine.cpp + + # DEPENDS + # ToyCh5ShapeInferenceInterfaceIncGen + # ToyCh5OpsIncGen + # ToyCh5CombineIncGen + ) + +include_directories(${CMAKE_CURRENT_BINARY_DIR}) +include_directories(${CMAKE_CURRENT_BINARY_DIR}/include/) +get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS) +get_property(extension_libs GLOBAL PROPERTY MLIR_EXTENSION_LIBS) +target_link_libraries(toyc-ch5 + PRIVATE + ${dialect_libs} + ${extension_libs} + MLIRAnalysis + MLIRCallInterfaces + MLIRCastInterfaces + MLIRFunctionInterfaces + MLIRIR + MLIRParser + MLIRPass + MLIRSideEffectInterfaces + MLIRSupport + MLIRTransforms) diff --git a/Ch5/include/CMakeLists.txt b/Ch5/include/CMakeLists.txt new file mode 100644 index 0000000..37c89d0 --- /dev/null +++ b/Ch5/include/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(toy) diff --git a/Ch5/include/run.sh b/Ch5/include/run.sh new file mode 100644 index 0000000..b9d18af --- /dev/null +++ b/Ch5/include/run.sh @@ -0,0 +1,7 @@ +mlir-tblgen-18 -gen-op-decls -I /usr/lib/llvm-18/include toy/Ops.td > toy/Ops.h.inc +mlir-tblgen-18 -gen-op-defs -I /usr/lib/llvm-18/include toy/Ops.td > toy/Ops.cpp.inc +mlir-tblgen-18 -gen-dialect-decls -I /usr/lib/llvm-18/include toy/Ops.td > toy/Dialect.h.inc +mlir-tblgen-18 -gen-dialect-defs -I /usr/lib/llvm-18/include toy/Ops.td > toy/Dialect.cpp.inc + +mlir-tblgen-18 -gen-op-interface-decls -I /usr/lib/llvm-18/include toy/ShapeInferenceInterface.td > toy/ShapeInferenceOpInterfaces.h.inc +mlir-tblgen-18 -gen-op-interface-defs -I /usr/lib/llvm-18/include toy/ShapeInferenceInterface.td > toy/ShapeInferenceOpInterfaces.cpp.inc diff --git a/Ch5/include/toy/AST.h b/Ch5/include/toy/AST.h new file mode 100644 index 0000000..d2ba101 --- /dev/null +++ b/Ch5/include/toy/AST.h @@ -0,0 +1,246 @@ +//===- AST.h - Node definition for the Toy AST ----------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements the AST for the Toy language. It is optimized for +// simplicity, not efficiency. The AST forms a tree structure where each node +// references its children using std::unique_ptr<>. +// +//===----------------------------------------------------------------------===// + +#ifndef TOY_AST_H +#define TOY_AST_H + +#include "toy/Lexer.h" + +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/Casting.h" +#include +#include +#include + +namespace toy { + +/// A variable type with shape information. +struct VarType { + std::vector shape; +}; + +/// Base class for all expression nodes. +class ExprAST { +public: + enum ExprASTKind { + Expr_VarDecl, + Expr_Return, + Expr_Num, + Expr_Literal, + Expr_Var, + Expr_BinOp, + Expr_Call, + Expr_Print, + }; + + ExprAST(ExprASTKind kind, Location location) + : kind(kind), location(std::move(location)) {} + virtual ~ExprAST() = default; + + ExprASTKind getKind() const { return kind; } + + const Location &loc() { return location; } + +private: + const ExprASTKind kind; + Location location; +}; + +/// A block-list of expressions. +using ExprASTList = std::vector>; + +/// Expression class for numeric literals like "1.0". +class NumberExprAST : public ExprAST { + double val; + +public: + NumberExprAST(Location loc, double val) + : ExprAST(Expr_Num, std::move(loc)), val(val) {} + + double getValue() { return val; } + + /// LLVM style RTTI + static bool classof(const ExprAST *c) { return c->getKind() == Expr_Num; } +}; + +/// Expression class for a literal value. +class LiteralExprAST : public ExprAST { + std::vector> values; + std::vector dims; + +public: + LiteralExprAST(Location loc, std::vector> values, + std::vector dims) + : ExprAST(Expr_Literal, std::move(loc)), values(std::move(values)), + dims(std::move(dims)) {} + + llvm::ArrayRef> getValues() { return values; } + llvm::ArrayRef getDims() { return dims; } + + /// LLVM style RTTI + static bool classof(const ExprAST *c) { return c->getKind() == Expr_Literal; } +}; + +/// Expression class for referencing a variable, like "a". +class VariableExprAST : public ExprAST { + std::string name; + +public: + VariableExprAST(Location loc, llvm::StringRef name) + : ExprAST(Expr_Var, std::move(loc)), name(name) {} + + llvm::StringRef getName() { return name; } + + /// LLVM style RTTI + static bool classof(const ExprAST *c) { return c->getKind() == Expr_Var; } +}; + +/// Expression class for defining a variable. +class VarDeclExprAST : public ExprAST { + std::string name; + VarType type; + std::unique_ptr initVal; + +public: + VarDeclExprAST(Location loc, llvm::StringRef name, VarType type, + std::unique_ptr initVal) + : ExprAST(Expr_VarDecl, std::move(loc)), name(name), + type(std::move(type)), initVal(std::move(initVal)) {} + + llvm::StringRef getName() { return name; } + ExprAST *getInitVal() { return initVal.get(); } + const VarType &getType() { return type; } + + /// LLVM style RTTI + static bool classof(const ExprAST *c) { return c->getKind() == Expr_VarDecl; } +}; + +/// Expression class for a return operator. +class ReturnExprAST : public ExprAST { + std::optional> expr; + +public: + ReturnExprAST(Location loc, std::optional> expr) + : ExprAST(Expr_Return, std::move(loc)), expr(std::move(expr)) {} + + std::optional getExpr() { + if (expr.has_value()) + return expr->get(); + return std::nullopt; + } + + /// LLVM style RTTI + static bool classof(const ExprAST *c) { return c->getKind() == Expr_Return; } +}; + +/// Expression class for a binary operator. +class BinaryExprAST : public ExprAST { + char op; + std::unique_ptr lhs, rhs; + +public: + char getOp() { return op; } + ExprAST *getLHS() { return lhs.get(); } + ExprAST *getRHS() { return rhs.get(); } + + BinaryExprAST(Location loc, char op, std::unique_ptr lhs, + std::unique_ptr rhs) + : ExprAST(Expr_BinOp, std::move(loc)), op(op), lhs(std::move(lhs)), + rhs(std::move(rhs)) {} + + /// LLVM style RTTI + static bool classof(const ExprAST *c) { return c->getKind() == Expr_BinOp; } +}; + +/// Expression class for function calls. +class CallExprAST : public ExprAST { + std::string callee; + std::vector> args; + +public: + CallExprAST(Location loc, const std::string &callee, + std::vector> args) + : ExprAST(Expr_Call, std::move(loc)), callee(callee), + args(std::move(args)) {} + + llvm::StringRef getCallee() { return callee; } + llvm::ArrayRef> getArgs() { return args; } + + /// LLVM style RTTI + static bool classof(const ExprAST *c) { return c->getKind() == Expr_Call; } +}; + +/// Expression class for builtin print calls. +class PrintExprAST : public ExprAST { + std::unique_ptr arg; + +public: + PrintExprAST(Location loc, std::unique_ptr arg) + : ExprAST(Expr_Print, std::move(loc)), arg(std::move(arg)) {} + + ExprAST *getArg() { return arg.get(); } + + /// LLVM style RTTI + static bool classof(const ExprAST *c) { return c->getKind() == Expr_Print; } +}; + +/// This class represents the "prototype" for a function, which captures its +/// name, and its argument names (thus implicitly the number of arguments the +/// function takes). +class PrototypeAST { + Location location; + std::string name; + std::vector> args; + +public: + PrototypeAST(Location location, const std::string &name, + std::vector> args) + : location(std::move(location)), name(name), args(std::move(args)) {} + + const Location &loc() { return location; } + llvm::StringRef getName() const { return name; } + llvm::ArrayRef> getArgs() { return args; } +}; + +/// This class represents a function definition itself. +class FunctionAST { + std::unique_ptr proto; + std::unique_ptr body; + +public: + FunctionAST(std::unique_ptr proto, + std::unique_ptr body) + : proto(std::move(proto)), body(std::move(body)) {} + PrototypeAST *getProto() { return proto.get(); } + ExprASTList *getBody() { return body.get(); } +}; + +/// This class represents a list of functions to be processed together +class ModuleAST { + std::vector functions; + +public: + ModuleAST(std::vector functions) + : functions(std::move(functions)) {} + + auto begin() { return functions.begin(); } + auto end() { return functions.end(); } +}; + +void dump(ModuleAST &); + +} // namespace toy + +#endif // TOY_AST_H diff --git a/Ch5/include/toy/CMakeLists.txt b/Ch5/include/toy/CMakeLists.txt new file mode 100644 index 0000000..4cae256 --- /dev/null +++ b/Ch5/include/toy/CMakeLists.txt @@ -0,0 +1,13 @@ +# # Most dialects should use add_mlir_dialect(). See examples/standalone. +# set(LLVM_TARGET_DEFINITIONS Ops.td) +# mlir_tablegen(Ops.h.inc -gen-op-decls) +# mlir_tablegen(Ops.cpp.inc -gen-op-defs) +# mlir_tablegen(Dialect.h.inc -gen-dialect-decls) +# mlir_tablegen(Dialect.cpp.inc -gen-dialect-defs) +# add_public_tablegen_target(ToyCh5OpsIncGen) + +# # Most dialects should use add_mlir_interfaces(). +# set(LLVM_TARGET_DEFINITIONS ShapeInferenceInterface.td) +# mlir_tablegen(ShapeInferenceOpInterfaces.h.inc -gen-op-interface-decls) +# mlir_tablegen(ShapeInferenceOpInterfaces.cpp.inc -gen-op-interface-defs) +# add_public_tablegen_target(ToyCh5ShapeInferenceInterfaceIncGen) diff --git a/Ch5/include/toy/Dialect.cpp.inc b/Ch5/include/toy/Dialect.cpp.inc new file mode 100644 index 0000000..8cbc772 --- /dev/null +++ b/Ch5/include/toy/Dialect.cpp.inc @@ -0,0 +1,23 @@ +/*===- TableGen'erated file -------------------------------------*- C++ -*-===*\ +|* *| +|* Dialect Definitions *| +|* *| +|* Automatically generated file, do not edit! *| +|* From: Ops.td *| +|* *| +\*===----------------------------------------------------------------------===*/ + +MLIR_DEFINE_EXPLICIT_TYPE_ID(::mlir::toy::ToyDialect) +namespace mlir { +namespace toy { + +ToyDialect::ToyDialect(::mlir::MLIRContext *context) + : ::mlir::Dialect(getDialectNamespace(), context, ::mlir::TypeID::get()) { + + initialize(); +} + +ToyDialect::~ToyDialect() = default; + +} // namespace toy +} // namespace mlir diff --git a/Ch5/include/toy/Dialect.h b/Ch5/include/toy/Dialect.h new file mode 100644 index 0000000..5db325e --- /dev/null +++ b/Ch5/include/toy/Dialect.h @@ -0,0 +1,36 @@ +//===- Dialect.h - Dialect definition for the Toy IR ----------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements the IR Dialect for the Toy language. +// See docs/Tutorials/Toy/Ch-2.md for more information. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_TUTORIAL_TOY_DIALECT_H_ +#define MLIR_TUTORIAL_TOY_DIALECT_H_ + +#include "mlir/Bytecode/BytecodeOpInterface.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Dialect.h" +#include "mlir/IR/SymbolTable.h" +#include "mlir/Interfaces/CallInterfaces.h" +#include "mlir/Interfaces/CastInterfaces.h" +#include "mlir/Interfaces/FunctionInterfaces.h" +#include "mlir/Interfaces/SideEffectInterfaces.h" +#include "toy/ShapeInferenceInterface.h" + +/// Include the auto-generated header file containing the declaration of the toy +/// dialect. +#include "toy/Dialect.h.inc" + +/// Include the auto-generated header file containing the declarations of the +/// toy operations. +#define GET_OP_CLASSES +#include "toy/Ops.h.inc" + +#endif // MLIR_TUTORIAL_TOY_DIALECT_H_ diff --git a/Ch5/include/toy/Dialect.h.inc b/Ch5/include/toy/Dialect.h.inc new file mode 100644 index 0000000..f19d867 --- /dev/null +++ b/Ch5/include/toy/Dialect.h.inc @@ -0,0 +1,26 @@ +/*===- TableGen'erated file -------------------------------------*- C++ -*-===*\ +|* *| +|* Dialect Declarations *| +|* *| +|* Automatically generated file, do not edit! *| +|* From: Ops.td *| +|* *| +\*===----------------------------------------------------------------------===*/ + +namespace mlir { +namespace toy { + +class ToyDialect : public ::mlir::Dialect { + explicit ToyDialect(::mlir::MLIRContext *context); + + void initialize(); + friend class ::mlir::MLIRContext; +public: + ~ToyDialect() override; + static constexpr ::llvm::StringLiteral getDialectNamespace() { + return ::llvm::StringLiteral("toy"); + } +}; +} // namespace toy +} // namespace mlir +MLIR_DECLARE_EXPLICIT_TYPE_ID(::mlir::toy::ToyDialect) diff --git a/Ch5/include/toy/Lexer.h b/Ch5/include/toy/Lexer.h new file mode 100644 index 0000000..3c59cd9 --- /dev/null +++ b/Ch5/include/toy/Lexer.h @@ -0,0 +1,232 @@ +//===- Lexer.h - Lexer for the Toy language -------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements a simple Lexer for the Toy language. +// +//===----------------------------------------------------------------------===// + +#ifndef TOY_LEXER_H +#define TOY_LEXER_H + +#include "llvm/ADT/StringRef.h" + +#include +#include + +namespace toy { + +/// Structure definition a location in a file. +struct Location { + std::shared_ptr file; ///< filename. + int line; ///< line number. + int col; ///< column number. +}; + +// List of Token returned by the lexer. +enum Token : int { + tok_semicolon = ';', + tok_parenthese_open = '(', + tok_parenthese_close = ')', + tok_bracket_open = '{', + tok_bracket_close = '}', + tok_sbracket_open = '[', + tok_sbracket_close = ']', + + tok_eof = -1, + + // commands + tok_return = -2, + tok_var = -3, + tok_def = -4, + + // primary + tok_identifier = -5, + tok_number = -6, +}; + +/// The Lexer is an abstract base class providing all the facilities that the +/// Parser expects. It goes through the stream one token at a time and keeps +/// track of the location in the file for debugging purpose. +/// It relies on a subclass to provide a `readNextLine()` method. The subclass +/// can proceed by reading the next line from the standard input or from a +/// memory mapped file. +class Lexer { +public: + /// Create a lexer for the given filename. The filename is kept only for + /// debugging purpose (attaching a location to a Token). + Lexer(std::string filename) + : lastLocation( + {std::make_shared(std::move(filename)), 0, 0}) {} + virtual ~Lexer() = default; + + /// Look at the current token in the stream. + Token getCurToken() { return curTok; } + + /// Move to the next token in the stream and return it. + Token getNextToken() { return curTok = getTok(); } + + /// Move to the next token in the stream, asserting on the current token + /// matching the expectation. + void consume(Token tok) { + assert(tok == curTok && "consume Token mismatch expectation"); + getNextToken(); + } + + /// Return the current identifier (prereq: getCurToken() == tok_identifier) + llvm::StringRef getId() { + assert(curTok == tok_identifier); + return identifierStr; + } + + /// Return the current number (prereq: getCurToken() == tok_number) + double getValue() { + assert(curTok == tok_number); + return numVal; + } + + /// Return the location for the beginning of the current token. + Location getLastLocation() { return lastLocation; } + + // Return the current line in the file. + int getLine() { return curLineNum; } + + // Return the current column in the file. + int getCol() { return curCol; } + +private: + /// Delegate to a derived class fetching the next line. Returns an empty + /// string to signal end of file (EOF). Lines are expected to always finish + /// with "\n" + virtual llvm::StringRef readNextLine() = 0; + + /// Return the next character from the stream. This manages the buffer for the + /// current line and request the next line buffer to the derived class as + /// needed. + int getNextChar() { + // The current line buffer should not be empty unless it is the end of file. + if (curLineBuffer.empty()) + return EOF; + ++curCol; + auto nextchar = curLineBuffer.front(); + curLineBuffer = curLineBuffer.drop_front(); + if (curLineBuffer.empty()) + curLineBuffer = readNextLine(); + if (nextchar == '\n') { + ++curLineNum; + curCol = 0; + } + return nextchar; + } + + /// Return the next token from standard input. + Token getTok() { + // Skip any whitespace. + while (isspace(lastChar)) + lastChar = Token(getNextChar()); + + // Save the current location before reading the token characters. + lastLocation.line = curLineNum; + lastLocation.col = curCol; + + // Identifier: [a-zA-Z][a-zA-Z0-9_]* + if (isalpha(lastChar)) { + identifierStr = (char)lastChar; + while (isalnum((lastChar = Token(getNextChar()))) || lastChar == '_') + identifierStr += (char)lastChar; + + if (identifierStr == "return") + return tok_return; + if (identifierStr == "def") + return tok_def; + if (identifierStr == "var") + return tok_var; + return tok_identifier; + } + + // Number: [0-9.]+ + if (isdigit(lastChar) || lastChar == '.') { + std::string numStr; + do { + numStr += lastChar; + lastChar = Token(getNextChar()); + } while (isdigit(lastChar) || lastChar == '.'); + + numVal = strtod(numStr.c_str(), nullptr); + return tok_number; + } + + if (lastChar == '#') { + // Comment until end of line. + do { + lastChar = Token(getNextChar()); + } while (lastChar != EOF && lastChar != '\n' && lastChar != '\r'); + + if (lastChar != EOF) + return getTok(); + } + + // Check for end of file. Don't eat the EOF. + if (lastChar == EOF) + return tok_eof; + + // Otherwise, just return the character as its ascii value. + Token thisChar = Token(lastChar); + lastChar = Token(getNextChar()); + return thisChar; + } + + /// The last token read from the input. + Token curTok = tok_eof; + + /// Location for `curTok`. + Location lastLocation; + + /// If the current Token is an identifier, this string contains the value. + std::string identifierStr; + + /// If the current Token is a number, this contains the value. + double numVal = 0; + + /// The last value returned by getNextChar(). We need to keep it around as we + /// always need to read ahead one character to decide when to end a token and + /// we can't put it back in the stream after reading from it. + Token lastChar = Token(' '); + + /// Keep track of the current line number in the input stream + int curLineNum = 0; + + /// Keep track of the current column number in the input stream + int curCol = 0; + + /// Buffer supplied by the derived class on calls to `readNextLine()` + llvm::StringRef curLineBuffer = "\n"; +}; + +/// A lexer implementation operating on a buffer in memory. +class LexerBuffer final : public Lexer { +public: + LexerBuffer(const char *begin, const char *end, std::string filename) + : Lexer(std::move(filename)), current(begin), end(end) {} + +private: + /// Provide one line at a time to the Lexer, return an empty string when + /// reaching the end of the buffer. + llvm::StringRef readNextLine() override { + auto *begin = current; + while (current <= end && *current && *current != '\n') + ++current; + if (current <= end && *current) + ++current; + llvm::StringRef result{begin, static_cast(current - begin)}; + return result; + } + const char *current, *end; +}; +} // namespace toy + +#endif // TOY_LEXER_H diff --git a/Ch5/include/toy/MLIRGen.h b/Ch5/include/toy/MLIRGen.h new file mode 100644 index 0000000..fe9dbe5 --- /dev/null +++ b/Ch5/include/toy/MLIRGen.h @@ -0,0 +1,35 @@ +//===- MLIRGen.h - MLIR Generation from a Toy AST -------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file declares a simple interface to perform IR generation targeting MLIR +// from a Module AST for the Toy language. +// +//===----------------------------------------------------------------------===// + +#ifndef TOY_MLIRGEN_H +#define TOY_MLIRGEN_H + +#include + +namespace mlir { +class MLIRContext; +template +class OwningOpRef; +class ModuleOp; +} // namespace mlir + +namespace toy { +class ModuleAST; + +/// Emit IR for the given Toy moduleAST, returns a newly created MLIR module +/// or nullptr on failure. +mlir::OwningOpRef mlirGen(mlir::MLIRContext &context, + ModuleAST &moduleAST); +} // namespace toy + +#endif // TOY_MLIRGEN_H diff --git a/Ch5/include/toy/Ops.cpp.inc b/Ch5/include/toy/Ops.cpp.inc new file mode 100644 index 0000000..6bb98a2 --- /dev/null +++ b/Ch5/include/toy/Ops.cpp.inc @@ -0,0 +1,2252 @@ +/*===- TableGen'erated file -------------------------------------*- C++ -*-===*\ +|* *| +|* Op Definitions *| +|* *| +|* Automatically generated file, do not edit! *| +|* From: Ops.td *| +|* *| +\*===----------------------------------------------------------------------===*/ + +#ifdef GET_OP_LIST +#undef GET_OP_LIST + +::mlir::toy::AddOp, +::mlir::toy::CastOp, +::mlir::toy::ConstantOp, +::mlir::toy::FuncOp, +::mlir::toy::GenericCallOp, +::mlir::toy::MulOp, +::mlir::toy::PrintOp, +::mlir::toy::ReshapeOp, +::mlir::toy::ReturnOp, +::mlir::toy::TransposeOp +#endif // GET_OP_LIST + +#ifdef GET_OP_CLASSES +#undef GET_OP_CLASSES + + +//===----------------------------------------------------------------------===// +// Local Utility Method Definitions +//===----------------------------------------------------------------------===// + +namespace mlir { +namespace toy { + +static ::mlir::LogicalResult __mlir_ods_local_type_constraint_Ops0( + ::mlir::Operation *op, ::mlir::Type type, ::llvm::StringRef valueKind, + unsigned valueIndex) { + if (!(((::llvm::isa<::mlir::TensorType>(type))) && ([](::mlir::Type elementType) { return (elementType.isF64()); }(::llvm::cast<::mlir::ShapedType>(type).getElementType())))) { + return op->emitOpError(valueKind) << " #" << valueIndex + << " must be tensor of 64-bit float values, but got " << type; + } + return ::mlir::success(); +} + +static ::mlir::LogicalResult __mlir_ods_local_type_constraint_Ops1( + ::mlir::Operation *op, ::mlir::Type type, ::llvm::StringRef valueKind, + unsigned valueIndex) { + if (!(((::llvm::isa<::mlir::TensorType>(type))) && ([](::mlir::Type elementType) { return (elementType.isF64()); }(::llvm::cast<::mlir::ShapedType>(type).getElementType())))) { + return op->emitOpError(valueKind) << " #" << valueIndex + << " must be variadic of tensor of 64-bit float values, but got " << type; + } + return ::mlir::success(); +} + +static ::mlir::LogicalResult __mlir_ods_local_type_constraint_Ops2( + ::mlir::Operation *op, ::mlir::Type type, ::llvm::StringRef valueKind, + unsigned valueIndex) { + if (!((((::llvm::isa<::mlir::TensorType>(type))) && ([](::mlir::Type elementType) { return (elementType.isF64()); }(::llvm::cast<::mlir::ShapedType>(type).getElementType()))) || (((::llvm::isa<::mlir::MemRefType>(type))) && ([](::mlir::Type elementType) { return (elementType.isF64()); }(::llvm::cast<::mlir::ShapedType>(type).getElementType()))))) { + return op->emitOpError(valueKind) << " #" << valueIndex + << " must be tensor of 64-bit float values or memref of 64-bit float values, but got " << type; + } + return ::mlir::success(); +} + +static ::mlir::LogicalResult __mlir_ods_local_type_constraint_Ops3( + ::mlir::Operation *op, ::mlir::Type type, ::llvm::StringRef valueKind, + unsigned valueIndex) { + if (!((((::llvm::isa<::mlir::RankedTensorType>(type))) && ((::llvm::cast<::mlir::ShapedType>(type).hasStaticShape()))) && ([](::mlir::Type elementType) { return (elementType.isF64()); }(::llvm::cast<::mlir::ShapedType>(type).getElementType())))) { + return op->emitOpError(valueKind) << " #" << valueIndex + << " must be statically shaped tensor of 64-bit float values, but got " << type; + } + return ::mlir::success(); +} + +static ::mlir::LogicalResult __mlir_ods_local_attr_constraint_Ops0( + ::mlir::Attribute attr, ::llvm::StringRef attrName, llvm::function_ref<::mlir::InFlightDiagnostic()> emitError) { + if (attr && !((::llvm::isa<::mlir::DenseFPElementsAttr>(attr) &&::llvm::cast<::mlir::DenseElementsAttr>(attr).getType().getElementType().isF64()))) + return emitError() << "attribute '" << attrName + << "' failed to satisfy constraint: 64-bit float elements attribute"; + return ::mlir::success(); +} +static ::mlir::LogicalResult __mlir_ods_local_attr_constraint_Ops0( + ::mlir::Operation *op, ::mlir::Attribute attr, ::llvm::StringRef attrName) { + return __mlir_ods_local_attr_constraint_Ops0(attr, attrName, [op]() { + return op->emitOpError(); + }); +} + +static ::mlir::LogicalResult __mlir_ods_local_attr_constraint_Ops1( + ::mlir::Attribute attr, ::llvm::StringRef attrName, llvm::function_ref<::mlir::InFlightDiagnostic()> emitError) { + if (attr && !((::llvm::isa<::mlir::StringAttr>(attr)))) + return emitError() << "attribute '" << attrName + << "' failed to satisfy constraint: string attribute"; + return ::mlir::success(); +} +static ::mlir::LogicalResult __mlir_ods_local_attr_constraint_Ops1( + ::mlir::Operation *op, ::mlir::Attribute attr, ::llvm::StringRef attrName) { + return __mlir_ods_local_attr_constraint_Ops1(attr, attrName, [op]() { + return op->emitOpError(); + }); +} + +static ::mlir::LogicalResult __mlir_ods_local_attr_constraint_Ops2( + ::mlir::Attribute attr, ::llvm::StringRef attrName, llvm::function_ref<::mlir::InFlightDiagnostic()> emitError) { + if (attr && !(((::llvm::isa<::mlir::TypeAttr>(attr))) && ((::llvm::isa<::mlir::FunctionType>(::llvm::cast<::mlir::TypeAttr>(attr).getValue()))) && ((::llvm::isa<::mlir::FunctionType>(::llvm::cast<::mlir::TypeAttr>(attr).getValue()))))) + return emitError() << "attribute '" << attrName + << "' failed to satisfy constraint: type attribute of function type"; + return ::mlir::success(); +} +static ::mlir::LogicalResult __mlir_ods_local_attr_constraint_Ops2( + ::mlir::Operation *op, ::mlir::Attribute attr, ::llvm::StringRef attrName) { + return __mlir_ods_local_attr_constraint_Ops2(attr, attrName, [op]() { + return op->emitOpError(); + }); +} + +static ::mlir::LogicalResult __mlir_ods_local_attr_constraint_Ops3( + ::mlir::Attribute attr, ::llvm::StringRef attrName, llvm::function_ref<::mlir::InFlightDiagnostic()> emitError) { + if (attr && !(((::llvm::isa<::mlir::ArrayAttr>(attr))) && (::llvm::all_of(::llvm::cast<::mlir::ArrayAttr>(attr), [&](::mlir::Attribute attr) { return attr && ((::llvm::isa<::mlir::DictionaryAttr>(attr))); })))) + return emitError() << "attribute '" << attrName + << "' failed to satisfy constraint: Array of dictionary attributes"; + return ::mlir::success(); +} +static ::mlir::LogicalResult __mlir_ods_local_attr_constraint_Ops3( + ::mlir::Operation *op, ::mlir::Attribute attr, ::llvm::StringRef attrName) { + return __mlir_ods_local_attr_constraint_Ops3(attr, attrName, [op]() { + return op->emitOpError(); + }); +} + +static ::mlir::LogicalResult __mlir_ods_local_attr_constraint_Ops4( + ::mlir::Attribute attr, ::llvm::StringRef attrName, llvm::function_ref<::mlir::InFlightDiagnostic()> emitError) { + if (attr && !((::llvm::isa<::mlir::FlatSymbolRefAttr>(attr)))) + return emitError() << "attribute '" << attrName + << "' failed to satisfy constraint: flat symbol reference attribute"; + return ::mlir::success(); +} +static ::mlir::LogicalResult __mlir_ods_local_attr_constraint_Ops4( + ::mlir::Operation *op, ::mlir::Attribute attr, ::llvm::StringRef attrName) { + return __mlir_ods_local_attr_constraint_Ops4(attr, attrName, [op]() { + return op->emitOpError(); + }); +} + +static ::mlir::LogicalResult __mlir_ods_local_region_constraint_Ops0( + ::mlir::Operation *op, ::mlir::Region ®ion, ::llvm::StringRef regionName, + unsigned regionIndex) { + if (!((true))) { + return op->emitOpError("region #") << regionIndex + << (regionName.empty() ? " " : " ('" + regionName + "') ") + << "failed to verify constraint: any region"; + } + return ::mlir::success(); +} +} // namespace toy +} // namespace mlir +namespace mlir { +namespace toy { + +//===----------------------------------------------------------------------===// +// ::mlir::toy::AddOp definitions +//===----------------------------------------------------------------------===// + +namespace detail { +AddOpGenericAdaptorBase::AddOpGenericAdaptorBase(::mlir::DictionaryAttr attrs, const ::mlir::EmptyProperties &properties, ::mlir::RegionRange regions) : odsAttrs(attrs), odsRegions(regions) { if (odsAttrs) + odsOpName.emplace("toy.add", odsAttrs.getContext()); +} + +AddOpGenericAdaptorBase::AddOpGenericAdaptorBase(AddOp op) : AddOpGenericAdaptorBase(op->getAttrDictionary(), op.getProperties(), op->getRegions()) {} + +std::pair AddOpGenericAdaptorBase::getODSOperandIndexAndLength(unsigned index, unsigned odsOperandsSize) { + return {index, 1}; +} + +::mlir::DictionaryAttr AddOpGenericAdaptorBase::getAttributes() { + return odsAttrs; +} + +} // namespace detail +AddOpAdaptor::AddOpAdaptor(AddOp op) : AddOpGenericAdaptor(op->getOperands(), op) {} + +::mlir::LogicalResult AddOpAdaptor::verify(::mlir::Location loc) { + return ::mlir::success(); +} + +std::pair AddOp::getODSOperandIndexAndLength(unsigned index) { + return {index, 1}; +} + +::mlir::Operation::operand_range AddOp::getODSOperands(unsigned index) { + auto valueRange = getODSOperandIndexAndLength(index); + return {std::next(getOperation()->operand_begin(), valueRange.first), + std::next(getOperation()->operand_begin(), valueRange.first + valueRange.second)}; +} + +::mlir::TypedValue<::mlir::TensorType> AddOp::getLhs() { + return ::llvm::cast<::mlir::TypedValue<::mlir::TensorType>>(*getODSOperands(0).begin()); +} + +::mlir::TypedValue<::mlir::TensorType> AddOp::getRhs() { + return ::llvm::cast<::mlir::TypedValue<::mlir::TensorType>>(*getODSOperands(1).begin()); +} + +::mlir::OpOperand &AddOp::getLhsMutable() { + auto range = getODSOperandIndexAndLength(0); + return getOperation()->getOpOperand(range.first); +} + +::mlir::OpOperand &AddOp::getRhsMutable() { + auto range = getODSOperandIndexAndLength(1); + return getOperation()->getOpOperand(range.first); +} + +std::pair AddOp::getODSResultIndexAndLength(unsigned index) { + return {index, 1}; +} + +::mlir::Operation::result_range AddOp::getODSResults(unsigned index) { + auto valueRange = getODSResultIndexAndLength(index); + return {std::next(getOperation()->result_begin(), valueRange.first), + std::next(getOperation()->result_begin(), valueRange.first + valueRange.second)}; +} + +void AddOp::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type resultType0, ::mlir::Value lhs, ::mlir::Value rhs) { + odsState.addOperands(lhs); + odsState.addOperands(rhs); + odsState.addTypes(resultType0); +} + +void AddOp::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::Value lhs, ::mlir::Value rhs) { + odsState.addOperands(lhs); + odsState.addOperands(rhs); + assert(resultTypes.size() == 1u && "mismatched number of results"); + odsState.addTypes(resultTypes); +} + +void AddOp::build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes) { + assert(operands.size() == 2u && "mismatched number of parameters"); + odsState.addOperands(operands); + odsState.addAttributes(attributes); + assert(resultTypes.size() == 1u && "mismatched number of return types"); + odsState.addTypes(resultTypes); +} + +::mlir::LogicalResult AddOp::verifyInvariantsImpl() { + { + unsigned index = 0; (void)index; + auto valueGroup0 = getODSOperands(0); + + for (auto v : valueGroup0) { + if (::mlir::failed(__mlir_ods_local_type_constraint_Ops0(*this, v.getType(), "operand", index++))) + return ::mlir::failure(); + } + auto valueGroup1 = getODSOperands(1); + + for (auto v : valueGroup1) { + if (::mlir::failed(__mlir_ods_local_type_constraint_Ops0(*this, v.getType(), "operand", index++))) + return ::mlir::failure(); + } + } + { + unsigned index = 0; (void)index; + auto valueGroup0 = getODSResults(0); + + for (auto v : valueGroup0) { + if (::mlir::failed(__mlir_ods_local_type_constraint_Ops0(*this, v.getType(), "result", index++))) + return ::mlir::failure(); + } + } + return ::mlir::success(); +} + +::mlir::LogicalResult AddOp::verifyInvariants() { + return verifyInvariantsImpl(); +} + +void AddOp::getEffects(::llvm::SmallVectorImpl<::mlir::SideEffects::EffectInstance<::mlir::MemoryEffects::Effect>> &effects) { +} + +} // namespace toy +} // namespace mlir +MLIR_DEFINE_EXPLICIT_TYPE_ID(::mlir::toy::AddOp) + +namespace mlir { +namespace toy { + +//===----------------------------------------------------------------------===// +// ::mlir::toy::CastOp definitions +//===----------------------------------------------------------------------===// + +namespace detail { +CastOpGenericAdaptorBase::CastOpGenericAdaptorBase(::mlir::DictionaryAttr attrs, const ::mlir::EmptyProperties &properties, ::mlir::RegionRange regions) : odsAttrs(attrs), odsRegions(regions) { if (odsAttrs) + odsOpName.emplace("toy.cast", odsAttrs.getContext()); +} + +CastOpGenericAdaptorBase::CastOpGenericAdaptorBase(CastOp op) : CastOpGenericAdaptorBase(op->getAttrDictionary(), op.getProperties(), op->getRegions()) {} + +std::pair CastOpGenericAdaptorBase::getODSOperandIndexAndLength(unsigned index, unsigned odsOperandsSize) { + return {index, 1}; +} + +::mlir::DictionaryAttr CastOpGenericAdaptorBase::getAttributes() { + return odsAttrs; +} + +} // namespace detail +CastOpAdaptor::CastOpAdaptor(CastOp op) : CastOpGenericAdaptor(op->getOperands(), op) {} + +::mlir::LogicalResult CastOpAdaptor::verify(::mlir::Location loc) { + return ::mlir::success(); +} + +std::pair CastOp::getODSOperandIndexAndLength(unsigned index) { + return {index, 1}; +} + +::mlir::Operation::operand_range CastOp::getODSOperands(unsigned index) { + auto valueRange = getODSOperandIndexAndLength(index); + return {std::next(getOperation()->operand_begin(), valueRange.first), + std::next(getOperation()->operand_begin(), valueRange.first + valueRange.second)}; +} + +::mlir::TypedValue<::mlir::TensorType> CastOp::getInput() { + return ::llvm::cast<::mlir::TypedValue<::mlir::TensorType>>(*getODSOperands(0).begin()); +} + +::mlir::OpOperand &CastOp::getInputMutable() { + auto range = getODSOperandIndexAndLength(0); + return getOperation()->getOpOperand(range.first); +} + +std::pair CastOp::getODSResultIndexAndLength(unsigned index) { + return {index, 1}; +} + +::mlir::Operation::result_range CastOp::getODSResults(unsigned index) { + auto valueRange = getODSResultIndexAndLength(index); + return {std::next(getOperation()->result_begin(), valueRange.first), + std::next(getOperation()->result_begin(), valueRange.first + valueRange.second)}; +} + +::mlir::TypedValue<::mlir::TensorType> CastOp::getOutput() { + return ::llvm::cast<::mlir::TypedValue<::mlir::TensorType>>(*getODSResults(0).begin()); +} + +void CastOp::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type output, ::mlir::Value input) { + odsState.addOperands(input); + odsState.addTypes(output); +} + +void CastOp::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::Value input) { + odsState.addOperands(input); + assert(resultTypes.size() == 1u && "mismatched number of results"); + odsState.addTypes(resultTypes); +} + +void CastOp::build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes) { + assert(operands.size() == 1u && "mismatched number of parameters"); + odsState.addOperands(operands); + odsState.addAttributes(attributes); + assert(resultTypes.size() == 1u && "mismatched number of return types"); + odsState.addTypes(resultTypes); +} + +::mlir::LogicalResult CastOp::verifyInvariantsImpl() { + { + unsigned index = 0; (void)index; + auto valueGroup0 = getODSOperands(0); + + for (auto v : valueGroup0) { + if (::mlir::failed(__mlir_ods_local_type_constraint_Ops0(*this, v.getType(), "operand", index++))) + return ::mlir::failure(); + } + } + { + unsigned index = 0; (void)index; + auto valueGroup0 = getODSResults(0); + + for (auto v : valueGroup0) { + if (::mlir::failed(__mlir_ods_local_type_constraint_Ops0(*this, v.getType(), "result", index++))) + return ::mlir::failure(); + } + } + return ::mlir::success(); +} + +::mlir::LogicalResult CastOp::verifyInvariants() { + return verifyInvariantsImpl(); +} + +::mlir::ParseResult CastOp::parse(::mlir::OpAsmParser &parser, ::mlir::OperationState &result) { + ::mlir::OpAsmParser::UnresolvedOperand inputRawOperands[1]; + ::llvm::ArrayRef<::mlir::OpAsmParser::UnresolvedOperand> inputOperands(inputRawOperands); ::llvm::SMLoc inputOperandsLoc; + (void)inputOperandsLoc; + ::mlir::Type inputRawTypes[1]; + ::llvm::ArrayRef<::mlir::Type> inputTypes(inputRawTypes); + ::mlir::Type outputRawTypes[1]; + ::llvm::ArrayRef<::mlir::Type> outputTypes(outputRawTypes); + + inputOperandsLoc = parser.getCurrentLocation(); + if (parser.parseOperand(inputRawOperands[0])) + return ::mlir::failure(); + { + auto loc = parser.getCurrentLocation();(void)loc; + if (parser.parseOptionalAttrDict(result.attributes)) + return ::mlir::failure(); + } + if (parser.parseColon()) + return ::mlir::failure(); + + { + ::mlir::TensorType type; + if (parser.parseCustomTypeWithFallback(type)) + return ::mlir::failure(); + inputRawTypes[0] = type; + } + if (parser.parseKeyword("to")) + return ::mlir::failure(); + + { + ::mlir::TensorType type; + if (parser.parseCustomTypeWithFallback(type)) + return ::mlir::failure(); + outputRawTypes[0] = type; + } + result.addTypes(outputTypes); + if (parser.resolveOperands(inputOperands, inputTypes, inputOperandsLoc, result.operands)) + return ::mlir::failure(); + return ::mlir::success(); +} + +void CastOp::print(::mlir::OpAsmPrinter &_odsPrinter) { + _odsPrinter << ' '; + _odsPrinter << getInput(); + ::llvm::SmallVector<::llvm::StringRef, 2> elidedAttrs; + _odsPrinter.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs); + _odsPrinter << ' ' << ":"; + _odsPrinter << ' '; + { + auto type = getInput().getType(); + if (auto validType = ::llvm::dyn_cast<::mlir::TensorType>(type)) + _odsPrinter.printStrippedAttrOrType(validType); + else + _odsPrinter << type; + } + _odsPrinter << ' ' << "to"; + _odsPrinter << ' '; + { + auto type = getOutput().getType(); + if (auto validType = ::llvm::dyn_cast<::mlir::TensorType>(type)) + _odsPrinter.printStrippedAttrOrType(validType); + else + _odsPrinter << type; + } +} + +void CastOp::getEffects(::llvm::SmallVectorImpl<::mlir::SideEffects::EffectInstance<::mlir::MemoryEffects::Effect>> &effects) { +} + +} // namespace toy +} // namespace mlir +MLIR_DEFINE_EXPLICIT_TYPE_ID(::mlir::toy::CastOp) + +namespace mlir { +namespace toy { + +//===----------------------------------------------------------------------===// +// ::mlir::toy::ConstantOp definitions +//===----------------------------------------------------------------------===// + +namespace detail { +ConstantOpGenericAdaptorBase::ConstantOpGenericAdaptorBase(::mlir::DictionaryAttr attrs, const Properties &properties, ::mlir::RegionRange regions) : odsAttrs(attrs), properties(properties), odsRegions(regions) { if (odsAttrs) + odsOpName.emplace("toy.constant", odsAttrs.getContext()); +} + +ConstantOpGenericAdaptorBase::ConstantOpGenericAdaptorBase(ConstantOp op) : ConstantOpGenericAdaptorBase(op->getDiscardableAttrDictionary(), op.getProperties(), op->getRegions()) {} + +std::pair ConstantOpGenericAdaptorBase::getODSOperandIndexAndLength(unsigned index, unsigned odsOperandsSize) { + return {index, 1}; +} + +::mlir::DictionaryAttr ConstantOpGenericAdaptorBase::getAttributes() { + return odsAttrs; +} + +::mlir::DenseElementsAttr ConstantOpGenericAdaptorBase::getValueAttr() { + auto attr = ::llvm::cast<::mlir::DenseElementsAttr>(getProperties().value); + return attr; +} + +::mlir::DenseElementsAttr ConstantOpGenericAdaptorBase::getValue() { + auto attr = getValueAttr(); + return attr; +} + +} // namespace detail +ConstantOpAdaptor::ConstantOpAdaptor(ConstantOp op) : ConstantOpGenericAdaptor(op->getOperands(), op) {} + +::mlir::LogicalResult ConstantOpAdaptor::verify(::mlir::Location loc) { + auto tblgen_value = getProperties().value; (void)tblgen_value; + if (!tblgen_value) return emitError(loc, "'toy.constant' op ""requires attribute 'value'"); + + if (tblgen_value && !((::llvm::isa<::mlir::DenseFPElementsAttr>(tblgen_value) &&::llvm::cast<::mlir::DenseElementsAttr>(tblgen_value).getType().getElementType().isF64()))) + return emitError(loc, "'toy.constant' op ""attribute 'value' failed to satisfy constraint: 64-bit float elements attribute"); + return ::mlir::success(); +} + +std::pair ConstantOp::getODSOperandIndexAndLength(unsigned index) { + return {index, 1}; +} + +::mlir::Operation::operand_range ConstantOp::getODSOperands(unsigned index) { + auto valueRange = getODSOperandIndexAndLength(index); + return {std::next(getOperation()->operand_begin(), valueRange.first), + std::next(getOperation()->operand_begin(), valueRange.first + valueRange.second)}; +} + +std::pair ConstantOp::getODSResultIndexAndLength(unsigned index) { + return {index, 1}; +} + +::mlir::Operation::result_range ConstantOp::getODSResults(unsigned index) { + auto valueRange = getODSResultIndexAndLength(index); + return {std::next(getOperation()->result_begin(), valueRange.first), + std::next(getOperation()->result_begin(), valueRange.first + valueRange.second)}; +} + +::mlir::LogicalResult ConstantOp::setPropertiesFromAttr(Properties &prop, ::mlir::Attribute attr, ::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError) { + ::mlir::DictionaryAttr dict = ::llvm::dyn_cast<::mlir::DictionaryAttr>(attr); + if (!dict) { + emitError() << "expected DictionaryAttr to set properties"; + return ::mlir::failure(); + } + + { + auto &propStorage = prop.value; + auto attr = dict.get("value"); + if (attr || /*isRequired=*/true) { + if (!attr) { + emitError() << "expected key entry for value in DictionaryAttr to set " + "Properties."; + return ::mlir::failure(); + } + auto convertedAttr = ::llvm::dyn_cast>(attr); + if (convertedAttr) { + propStorage = convertedAttr; + } else { + emitError() << "Invalid attribute `value` in property conversion: " << attr; + return ::mlir::failure(); + } + } + } + return ::mlir::success(); +} + +::mlir::Attribute ConstantOp::getPropertiesAsAttr(::mlir::MLIRContext *ctx, const Properties &prop) { + ::mlir::SmallVector<::mlir::NamedAttribute> attrs; + ::mlir::Builder odsBuilder{ctx}; + + { + const auto &propStorage = prop.value; + if (propStorage) + attrs.push_back(odsBuilder.getNamedAttr("value", + propStorage)); + } + + if (!attrs.empty()) + return odsBuilder.getDictionaryAttr(attrs); + return {}; +} + +llvm::hash_code ConstantOp::computePropertiesHash(const Properties &prop) { + return llvm::hash_combine( + llvm::hash_value(prop.value.getAsOpaquePointer())); +} + +std::optional ConstantOp::getInherentAttr(::mlir::MLIRContext *ctx, const Properties &prop, llvm::StringRef name) { + if (name == "value") + return prop.value; + return std::nullopt; +} + +void ConstantOp::setInherentAttr(Properties &prop, llvm::StringRef name, mlir::Attribute value) { + if (name == "value") { + prop.value = ::llvm::dyn_cast_or_null>(value); + return; + } +} + +void ConstantOp::populateInherentAttrs(::mlir::MLIRContext *ctx, const Properties &prop, ::mlir::NamedAttrList &attrs) { + if (prop.value) attrs.append("value", prop.value); +} + +::mlir::LogicalResult ConstantOp::verifyInherentAttrs(::mlir::OperationName opName, ::mlir::NamedAttrList &attrs, llvm::function_ref<::mlir::InFlightDiagnostic()> emitError) { + { + ::mlir::Attribute attr = attrs.get(getValueAttrName(opName)); + if (attr && ::mlir::failed(__mlir_ods_local_attr_constraint_Ops0(attr, "value", emitError))) + return ::mlir::failure(); + } + return ::mlir::success(); +} + +::mlir::LogicalResult ConstantOp::readProperties(::mlir::DialectBytecodeReader &reader, ::mlir::OperationState &state) { + auto &prop = state.getOrAddProperties(); (void)prop; + if (::mlir::failed(reader.readAttribute(prop.value))) + return ::mlir::failure(); + return ::mlir::success(); +} + +void ConstantOp::writeProperties(::mlir::DialectBytecodeWriter &writer) { + auto &prop = getProperties(); (void)prop; + writer.writeAttribute(prop.value); +} + +::mlir::DenseElementsAttr ConstantOp::getValueAttr() { + return ::llvm::cast<::mlir::DenseElementsAttr>(getProperties().value); +} + +::mlir::DenseElementsAttr ConstantOp::getValue() { + auto attr = getValueAttr(); + return attr; +} + +void ConstantOp::setValueAttr(::mlir::DenseElementsAttr attr) { + (*this)->setAttr(getValueAttrName(), attr); +} + +void ConstantOp::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, DenseElementsAttr value) { + build(odsBuilder, odsState, value.getType(), value); + +} + +void ConstantOp::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type resultType0, ::mlir::DenseElementsAttr value) { + odsState.getOrAddProperties().value = value; + odsState.addTypes(resultType0); +} + +void ConstantOp::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::DenseElementsAttr value) { + odsState.getOrAddProperties().value = value; + assert(resultTypes.size() == 1u && "mismatched number of results"); + odsState.addTypes(resultTypes); +} + +void ConstantOp::build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes) { + assert(operands.size() == 0u && "mismatched number of parameters"); + odsState.addOperands(operands); + odsState.addAttributes(attributes); + assert(resultTypes.size() == 1u && "mismatched number of return types"); + odsState.addTypes(resultTypes); +} + +::mlir::LogicalResult ConstantOp::verifyInvariantsImpl() { + auto tblgen_value = getProperties().value; (void)tblgen_value; + if (!tblgen_value) return emitOpError("requires attribute 'value'"); + + if (::mlir::failed(__mlir_ods_local_attr_constraint_Ops0(*this, tblgen_value, "value"))) + return ::mlir::failure(); + { + unsigned index = 0; (void)index; + auto valueGroup0 = getODSResults(0); + + for (auto v : valueGroup0) { + if (::mlir::failed(__mlir_ods_local_type_constraint_Ops0(*this, v.getType(), "result", index++))) + return ::mlir::failure(); + } + } + return ::mlir::success(); +} + +::mlir::LogicalResult ConstantOp::verifyInvariants() { + if(::mlir::succeeded(verifyInvariantsImpl()) && ::mlir::succeeded(verify())) + return ::mlir::success(); + return ::mlir::failure(); +} + +void ConstantOp::getEffects(::llvm::SmallVectorImpl<::mlir::SideEffects::EffectInstance<::mlir::MemoryEffects::Effect>> &effects) { +} + +} // namespace toy +} // namespace mlir +MLIR_DEFINE_EXPLICIT_TYPE_ID(::mlir::toy::ConstantOp) + +namespace mlir { +namespace toy { + +//===----------------------------------------------------------------------===// +// ::mlir::toy::FuncOp definitions +//===----------------------------------------------------------------------===// + +namespace detail { +FuncOpGenericAdaptorBase::FuncOpGenericAdaptorBase(::mlir::DictionaryAttr attrs, const Properties &properties, ::mlir::RegionRange regions) : odsAttrs(attrs), properties(properties), odsRegions(regions) { if (odsAttrs) + odsOpName.emplace("toy.func", odsAttrs.getContext()); +} + +FuncOpGenericAdaptorBase::FuncOpGenericAdaptorBase(FuncOp op) : FuncOpGenericAdaptorBase(op->getDiscardableAttrDictionary(), op.getProperties(), op->getRegions()) {} + +std::pair FuncOpGenericAdaptorBase::getODSOperandIndexAndLength(unsigned index, unsigned odsOperandsSize) { + return {index, 1}; +} + +::mlir::DictionaryAttr FuncOpGenericAdaptorBase::getAttributes() { + return odsAttrs; +} + +::mlir::StringAttr FuncOpGenericAdaptorBase::getSymNameAttr() { + auto attr = ::llvm::cast<::mlir::StringAttr>(getProperties().sym_name); + return attr; +} + +::llvm::StringRef FuncOpGenericAdaptorBase::getSymName() { + auto attr = getSymNameAttr(); + return attr.getValue(); +} + +::mlir::TypeAttr FuncOpGenericAdaptorBase::getFunctionTypeAttr() { + auto attr = ::llvm::cast<::mlir::TypeAttr>(getProperties().function_type); + return attr; +} + +::mlir::FunctionType FuncOpGenericAdaptorBase::getFunctionType() { + auto attr = getFunctionTypeAttr(); + return ::llvm::cast<::mlir::FunctionType>(attr.getValue()); +} + +::mlir::ArrayAttr FuncOpGenericAdaptorBase::getArgAttrsAttr() { + auto attr = ::llvm::dyn_cast_or_null<::mlir::ArrayAttr>(getProperties().arg_attrs); + return attr; +} + +::std::optional< ::mlir::ArrayAttr > FuncOpGenericAdaptorBase::getArgAttrs() { + auto attr = getArgAttrsAttr(); + return attr ? ::std::optional< ::mlir::ArrayAttr >(attr) : (::std::nullopt); +} + +::mlir::ArrayAttr FuncOpGenericAdaptorBase::getResAttrsAttr() { + auto attr = ::llvm::dyn_cast_or_null<::mlir::ArrayAttr>(getProperties().res_attrs); + return attr; +} + +::std::optional< ::mlir::ArrayAttr > FuncOpGenericAdaptorBase::getResAttrs() { + auto attr = getResAttrsAttr(); + return attr ? ::std::optional< ::mlir::ArrayAttr >(attr) : (::std::nullopt); +} + +::mlir::Region &FuncOpGenericAdaptorBase::getBody() { + return *odsRegions[0]; +} + +::mlir::RegionRange FuncOpGenericAdaptorBase::getRegions() { + return odsRegions; +} + +} // namespace detail +FuncOpAdaptor::FuncOpAdaptor(FuncOp op) : FuncOpGenericAdaptor(op->getOperands(), op) {} + +::mlir::LogicalResult FuncOpAdaptor::verify(::mlir::Location loc) { + auto tblgen_arg_attrs = getProperties().arg_attrs; (void)tblgen_arg_attrs; + auto tblgen_function_type = getProperties().function_type; (void)tblgen_function_type; + if (!tblgen_function_type) return emitError(loc, "'toy.func' op ""requires attribute 'function_type'"); + auto tblgen_res_attrs = getProperties().res_attrs; (void)tblgen_res_attrs; + auto tblgen_sym_name = getProperties().sym_name; (void)tblgen_sym_name; + if (!tblgen_sym_name) return emitError(loc, "'toy.func' op ""requires attribute 'sym_name'"); + + if (tblgen_sym_name && !((::llvm::isa<::mlir::StringAttr>(tblgen_sym_name)))) + return emitError(loc, "'toy.func' op ""attribute 'sym_name' failed to satisfy constraint: string attribute"); + + if (tblgen_function_type && !(((::llvm::isa<::mlir::TypeAttr>(tblgen_function_type))) && ((::llvm::isa<::mlir::FunctionType>(::llvm::cast<::mlir::TypeAttr>(tblgen_function_type).getValue()))) && ((::llvm::isa<::mlir::FunctionType>(::llvm::cast<::mlir::TypeAttr>(tblgen_function_type).getValue()))))) + return emitError(loc, "'toy.func' op ""attribute 'function_type' failed to satisfy constraint: type attribute of function type"); + + if (tblgen_arg_attrs && !(((::llvm::isa<::mlir::ArrayAttr>(tblgen_arg_attrs))) && (::llvm::all_of(::llvm::cast<::mlir::ArrayAttr>(tblgen_arg_attrs), [&](::mlir::Attribute attr) { return attr && ((::llvm::isa<::mlir::DictionaryAttr>(attr))); })))) + return emitError(loc, "'toy.func' op ""attribute 'arg_attrs' failed to satisfy constraint: Array of dictionary attributes"); + + if (tblgen_res_attrs && !(((::llvm::isa<::mlir::ArrayAttr>(tblgen_res_attrs))) && (::llvm::all_of(::llvm::cast<::mlir::ArrayAttr>(tblgen_res_attrs), [&](::mlir::Attribute attr) { return attr && ((::llvm::isa<::mlir::DictionaryAttr>(attr))); })))) + return emitError(loc, "'toy.func' op ""attribute 'res_attrs' failed to satisfy constraint: Array of dictionary attributes"); + return ::mlir::success(); +} + +std::pair FuncOp::getODSOperandIndexAndLength(unsigned index) { + return {index, 1}; +} + +::mlir::Operation::operand_range FuncOp::getODSOperands(unsigned index) { + auto valueRange = getODSOperandIndexAndLength(index); + return {std::next(getOperation()->operand_begin(), valueRange.first), + std::next(getOperation()->operand_begin(), valueRange.first + valueRange.second)}; +} + +std::pair FuncOp::getODSResultIndexAndLength(unsigned index) { + return {index, 1}; +} + +::mlir::Operation::result_range FuncOp::getODSResults(unsigned index) { + auto valueRange = getODSResultIndexAndLength(index); + return {std::next(getOperation()->result_begin(), valueRange.first), + std::next(getOperation()->result_begin(), valueRange.first + valueRange.second)}; +} + +::mlir::Region &FuncOp::getBody() { + return (*this)->getRegion(0); +} + +::mlir::LogicalResult FuncOp::setPropertiesFromAttr(Properties &prop, ::mlir::Attribute attr, ::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError) { + ::mlir::DictionaryAttr dict = ::llvm::dyn_cast<::mlir::DictionaryAttr>(attr); + if (!dict) { + emitError() << "expected DictionaryAttr to set properties"; + return ::mlir::failure(); + } + + { + auto &propStorage = prop.arg_attrs; + auto attr = dict.get("arg_attrs"); + if (attr || /*isRequired=*/false) { + if (!attr) { + emitError() << "expected key entry for arg_attrs in DictionaryAttr to set " + "Properties."; + return ::mlir::failure(); + } + auto convertedAttr = ::llvm::dyn_cast>(attr); + if (convertedAttr) { + propStorage = convertedAttr; + } else { + emitError() << "Invalid attribute `arg_attrs` in property conversion: " << attr; + return ::mlir::failure(); + } + } + } + + { + auto &propStorage = prop.function_type; + auto attr = dict.get("function_type"); + if (attr || /*isRequired=*/true) { + if (!attr) { + emitError() << "expected key entry for function_type in DictionaryAttr to set " + "Properties."; + return ::mlir::failure(); + } + auto convertedAttr = ::llvm::dyn_cast>(attr); + if (convertedAttr) { + propStorage = convertedAttr; + } else { + emitError() << "Invalid attribute `function_type` in property conversion: " << attr; + return ::mlir::failure(); + } + } + } + + { + auto &propStorage = prop.res_attrs; + auto attr = dict.get("res_attrs"); + if (attr || /*isRequired=*/false) { + if (!attr) { + emitError() << "expected key entry for res_attrs in DictionaryAttr to set " + "Properties."; + return ::mlir::failure(); + } + auto convertedAttr = ::llvm::dyn_cast>(attr); + if (convertedAttr) { + propStorage = convertedAttr; + } else { + emitError() << "Invalid attribute `res_attrs` in property conversion: " << attr; + return ::mlir::failure(); + } + } + } + + { + auto &propStorage = prop.sym_name; + auto attr = dict.get("sym_name"); + if (attr || /*isRequired=*/true) { + if (!attr) { + emitError() << "expected key entry for sym_name in DictionaryAttr to set " + "Properties."; + return ::mlir::failure(); + } + auto convertedAttr = ::llvm::dyn_cast>(attr); + if (convertedAttr) { + propStorage = convertedAttr; + } else { + emitError() << "Invalid attribute `sym_name` in property conversion: " << attr; + return ::mlir::failure(); + } + } + } + return ::mlir::success(); +} + +::mlir::Attribute FuncOp::getPropertiesAsAttr(::mlir::MLIRContext *ctx, const Properties &prop) { + ::mlir::SmallVector<::mlir::NamedAttribute> attrs; + ::mlir::Builder odsBuilder{ctx}; + + { + const auto &propStorage = prop.arg_attrs; + if (propStorage) + attrs.push_back(odsBuilder.getNamedAttr("arg_attrs", + propStorage)); + } + + { + const auto &propStorage = prop.function_type; + if (propStorage) + attrs.push_back(odsBuilder.getNamedAttr("function_type", + propStorage)); + } + + { + const auto &propStorage = prop.res_attrs; + if (propStorage) + attrs.push_back(odsBuilder.getNamedAttr("res_attrs", + propStorage)); + } + + { + const auto &propStorage = prop.sym_name; + if (propStorage) + attrs.push_back(odsBuilder.getNamedAttr("sym_name", + propStorage)); + } + + if (!attrs.empty()) + return odsBuilder.getDictionaryAttr(attrs); + return {}; +} + +llvm::hash_code FuncOp::computePropertiesHash(const Properties &prop) { + return llvm::hash_combine( + llvm::hash_value(prop.arg_attrs.getAsOpaquePointer()), + llvm::hash_value(prop.function_type.getAsOpaquePointer()), + llvm::hash_value(prop.res_attrs.getAsOpaquePointer()), + llvm::hash_value(prop.sym_name.getAsOpaquePointer())); +} + +std::optional FuncOp::getInherentAttr(::mlir::MLIRContext *ctx, const Properties &prop, llvm::StringRef name) { + if (name == "arg_attrs") + return prop.arg_attrs; + + if (name == "function_type") + return prop.function_type; + + if (name == "res_attrs") + return prop.res_attrs; + + if (name == "sym_name") + return prop.sym_name; + return std::nullopt; +} + +void FuncOp::setInherentAttr(Properties &prop, llvm::StringRef name, mlir::Attribute value) { + if (name == "arg_attrs") { + prop.arg_attrs = ::llvm::dyn_cast_or_null>(value); + return; + } + + if (name == "function_type") { + prop.function_type = ::llvm::dyn_cast_or_null>(value); + return; + } + + if (name == "res_attrs") { + prop.res_attrs = ::llvm::dyn_cast_or_null>(value); + return; + } + + if (name == "sym_name") { + prop.sym_name = ::llvm::dyn_cast_or_null>(value); + return; + } +} + +void FuncOp::populateInherentAttrs(::mlir::MLIRContext *ctx, const Properties &prop, ::mlir::NamedAttrList &attrs) { + if (prop.arg_attrs) attrs.append("arg_attrs", prop.arg_attrs); + + if (prop.function_type) attrs.append("function_type", prop.function_type); + + if (prop.res_attrs) attrs.append("res_attrs", prop.res_attrs); + + if (prop.sym_name) attrs.append("sym_name", prop.sym_name); +} + +::mlir::LogicalResult FuncOp::verifyInherentAttrs(::mlir::OperationName opName, ::mlir::NamedAttrList &attrs, llvm::function_ref<::mlir::InFlightDiagnostic()> emitError) { + { + ::mlir::Attribute attr = attrs.get(getArgAttrsAttrName(opName)); + if (attr && ::mlir::failed(__mlir_ods_local_attr_constraint_Ops3(attr, "arg_attrs", emitError))) + return ::mlir::failure(); + } + + { + ::mlir::Attribute attr = attrs.get(getFunctionTypeAttrName(opName)); + if (attr && ::mlir::failed(__mlir_ods_local_attr_constraint_Ops2(attr, "function_type", emitError))) + return ::mlir::failure(); + } + + { + ::mlir::Attribute attr = attrs.get(getResAttrsAttrName(opName)); + if (attr && ::mlir::failed(__mlir_ods_local_attr_constraint_Ops3(attr, "res_attrs", emitError))) + return ::mlir::failure(); + } + + { + ::mlir::Attribute attr = attrs.get(getSymNameAttrName(opName)); + if (attr && ::mlir::failed(__mlir_ods_local_attr_constraint_Ops1(attr, "sym_name", emitError))) + return ::mlir::failure(); + } + return ::mlir::success(); +} + +::mlir::LogicalResult FuncOp::readProperties(::mlir::DialectBytecodeReader &reader, ::mlir::OperationState &state) { + auto &prop = state.getOrAddProperties(); (void)prop; + if (::mlir::failed(reader.readOptionalAttribute(prop.arg_attrs))) + return ::mlir::failure(); + + if (::mlir::failed(reader.readAttribute(prop.function_type))) + return ::mlir::failure(); + + if (::mlir::failed(reader.readOptionalAttribute(prop.res_attrs))) + return ::mlir::failure(); + + if (::mlir::failed(reader.readAttribute(prop.sym_name))) + return ::mlir::failure(); + return ::mlir::success(); +} + +void FuncOp::writeProperties(::mlir::DialectBytecodeWriter &writer) { + auto &prop = getProperties(); (void)prop; + + writer.writeOptionalAttribute(prop.arg_attrs); + writer.writeAttribute(prop.function_type); + + writer.writeOptionalAttribute(prop.res_attrs); + writer.writeAttribute(prop.sym_name); +} + +::mlir::StringAttr FuncOp::getSymNameAttr() { + return ::llvm::cast<::mlir::StringAttr>(getProperties().sym_name); +} + +::llvm::StringRef FuncOp::getSymName() { + auto attr = getSymNameAttr(); + return attr.getValue(); +} + +::mlir::TypeAttr FuncOp::getFunctionTypeAttr() { + return ::llvm::cast<::mlir::TypeAttr>(getProperties().function_type); +} + +::mlir::FunctionType FuncOp::getFunctionType() { + auto attr = getFunctionTypeAttr(); + return ::llvm::cast<::mlir::FunctionType>(attr.getValue()); +} + +::mlir::ArrayAttr FuncOp::getArgAttrsAttr() { + return ::llvm::dyn_cast_or_null<::mlir::ArrayAttr>(getProperties().arg_attrs); +} + +::std::optional< ::mlir::ArrayAttr > FuncOp::getArgAttrs() { + auto attr = getArgAttrsAttr(); + return attr ? ::std::optional< ::mlir::ArrayAttr >(attr) : (::std::nullopt); +} + +::mlir::ArrayAttr FuncOp::getResAttrsAttr() { + return ::llvm::dyn_cast_or_null<::mlir::ArrayAttr>(getProperties().res_attrs); +} + +::std::optional< ::mlir::ArrayAttr > FuncOp::getResAttrs() { + auto attr = getResAttrsAttr(); + return attr ? ::std::optional< ::mlir::ArrayAttr >(attr) : (::std::nullopt); +} + +void FuncOp::setSymNameAttr(::mlir::StringAttr attr) { + (*this)->setAttr(getSymNameAttrName(), attr); +} + +void FuncOp::setSymName(::llvm::StringRef attrValue) { + (*this)->setAttr(getSymNameAttrName(), ::mlir::Builder((*this)->getContext()).getStringAttr(attrValue)); +} + +void FuncOp::setFunctionTypeAttr(::mlir::TypeAttr attr) { + (*this)->setAttr(getFunctionTypeAttrName(), attr); +} + +void FuncOp::setFunctionType(::mlir::FunctionType attrValue) { + (*this)->setAttr(getFunctionTypeAttrName(), ::mlir::TypeAttr::get(attrValue)); +} + +void FuncOp::setArgAttrsAttr(::mlir::ArrayAttr attr) { + (*this)->setAttr(getArgAttrsAttrName(), attr); +} + +void FuncOp::setResAttrsAttr(::mlir::ArrayAttr attr) { + (*this)->setAttr(getResAttrsAttrName(), attr); +} + +::mlir::Attribute FuncOp::removeArgAttrsAttr() { + auto &attr = getProperties().arg_attrs; + attr = {}; + return attr; +} + +::mlir::Attribute FuncOp::removeResAttrsAttr() { + auto &attr = getProperties().res_attrs; + attr = {}; + return attr; +} + +::mlir::LogicalResult FuncOp::verifyInvariantsImpl() { + auto tblgen_arg_attrs = getProperties().arg_attrs; (void)tblgen_arg_attrs; + auto tblgen_function_type = getProperties().function_type; (void)tblgen_function_type; + if (!tblgen_function_type) return emitOpError("requires attribute 'function_type'"); + auto tblgen_res_attrs = getProperties().res_attrs; (void)tblgen_res_attrs; + auto tblgen_sym_name = getProperties().sym_name; (void)tblgen_sym_name; + if (!tblgen_sym_name) return emitOpError("requires attribute 'sym_name'"); + + if (::mlir::failed(__mlir_ods_local_attr_constraint_Ops1(*this, tblgen_sym_name, "sym_name"))) + return ::mlir::failure(); + + if (::mlir::failed(__mlir_ods_local_attr_constraint_Ops2(*this, tblgen_function_type, "function_type"))) + return ::mlir::failure(); + + if (::mlir::failed(__mlir_ods_local_attr_constraint_Ops3(*this, tblgen_arg_attrs, "arg_attrs"))) + return ::mlir::failure(); + + if (::mlir::failed(__mlir_ods_local_attr_constraint_Ops3(*this, tblgen_res_attrs, "res_attrs"))) + return ::mlir::failure(); + { + unsigned index = 0; (void)index; + + for (auto ®ion : ::llvm::MutableArrayRef((*this)->getRegion(0))) + if (::mlir::failed(__mlir_ods_local_region_constraint_Ops0(*this, region, "body", index++))) + return ::mlir::failure(); + } + return ::mlir::success(); +} + +::mlir::LogicalResult FuncOp::verifyInvariants() { + return verifyInvariantsImpl(); +} + +} // namespace toy +} // namespace mlir +MLIR_DEFINE_EXPLICIT_TYPE_ID(::mlir::toy::FuncOp) + +namespace mlir { +namespace toy { + +//===----------------------------------------------------------------------===// +// ::mlir::toy::GenericCallOp definitions +//===----------------------------------------------------------------------===// + +namespace detail { +GenericCallOpGenericAdaptorBase::GenericCallOpGenericAdaptorBase(::mlir::DictionaryAttr attrs, const Properties &properties, ::mlir::RegionRange regions) : odsAttrs(attrs), properties(properties), odsRegions(regions) { if (odsAttrs) + odsOpName.emplace("toy.generic_call", odsAttrs.getContext()); +} + +GenericCallOpGenericAdaptorBase::GenericCallOpGenericAdaptorBase(GenericCallOp op) : GenericCallOpGenericAdaptorBase(op->getDiscardableAttrDictionary(), op.getProperties(), op->getRegions()) {} + +std::pair GenericCallOpGenericAdaptorBase::getODSOperandIndexAndLength(unsigned index, unsigned odsOperandsSize) { + bool isVariadic[] = {true}; + int prevVariadicCount = 0; + for (unsigned i = 0; i < index; ++i) + if (isVariadic[i]) ++prevVariadicCount; + + // Calculate how many dynamic values a static variadic operand corresponds to. + // This assumes all static variadic operands have the same dynamic value count. + int variadicSize = (odsOperandsSize - 0) / 1; + // `index` passed in as the parameter is the static index which counts each + // operand (variadic or not) as size 1. So here for each previous static variadic + // operand, we need to offset by (variadicSize - 1) to get where the dynamic + // value pack for this static operand starts. + int start = index + (variadicSize - 1) * prevVariadicCount; + int size = isVariadic[index] ? variadicSize : 1; + return {start, size}; +} + +::mlir::DictionaryAttr GenericCallOpGenericAdaptorBase::getAttributes() { + return odsAttrs; +} + +::mlir::FlatSymbolRefAttr GenericCallOpGenericAdaptorBase::getCalleeAttr() { + auto attr = ::llvm::cast<::mlir::FlatSymbolRefAttr>(getProperties().callee); + return attr; +} + +::llvm::StringRef GenericCallOpGenericAdaptorBase::getCallee() { + auto attr = getCalleeAttr(); + return attr.getValue(); +} + +} // namespace detail +GenericCallOpAdaptor::GenericCallOpAdaptor(GenericCallOp op) : GenericCallOpGenericAdaptor(op->getOperands(), op) {} + +::mlir::LogicalResult GenericCallOpAdaptor::verify(::mlir::Location loc) { + auto tblgen_callee = getProperties().callee; (void)tblgen_callee; + if (!tblgen_callee) return emitError(loc, "'toy.generic_call' op ""requires attribute 'callee'"); + + if (tblgen_callee && !((::llvm::isa<::mlir::FlatSymbolRefAttr>(tblgen_callee)))) + return emitError(loc, "'toy.generic_call' op ""attribute 'callee' failed to satisfy constraint: flat symbol reference attribute"); + return ::mlir::success(); +} + +std::pair GenericCallOp::getODSOperandIndexAndLength(unsigned index) { + bool isVariadic[] = {true}; + int prevVariadicCount = 0; + for (unsigned i = 0; i < index; ++i) + if (isVariadic[i]) ++prevVariadicCount; + + // Calculate how many dynamic values a static variadic operand corresponds to. + // This assumes all static variadic operands have the same dynamic value count. + int variadicSize = (getOperation()->getNumOperands() - 0) / 1; + // `index` passed in as the parameter is the static index which counts each + // operand (variadic or not) as size 1. So here for each previous static variadic + // operand, we need to offset by (variadicSize - 1) to get where the dynamic + // value pack for this static operand starts. + int start = index + (variadicSize - 1) * prevVariadicCount; + int size = isVariadic[index] ? variadicSize : 1; + return {start, size}; +} + +::mlir::Operation::operand_range GenericCallOp::getODSOperands(unsigned index) { + auto valueRange = getODSOperandIndexAndLength(index); + return {std::next(getOperation()->operand_begin(), valueRange.first), + std::next(getOperation()->operand_begin(), valueRange.first + valueRange.second)}; +} + +::mlir::Operation::operand_range GenericCallOp::getInputs() { + return getODSOperands(0); +} + +::mlir::MutableOperandRange GenericCallOp::getInputsMutable() { + auto range = getODSOperandIndexAndLength(0); + auto mutableRange = ::mlir::MutableOperandRange(getOperation(), range.first, range.second); + return mutableRange; +} + +std::pair GenericCallOp::getODSResultIndexAndLength(unsigned index) { + return {index, 1}; +} + +::mlir::Operation::result_range GenericCallOp::getODSResults(unsigned index) { + auto valueRange = getODSResultIndexAndLength(index); + return {std::next(getOperation()->result_begin(), valueRange.first), + std::next(getOperation()->result_begin(), valueRange.first + valueRange.second)}; +} + +::mlir::LogicalResult GenericCallOp::setPropertiesFromAttr(Properties &prop, ::mlir::Attribute attr, ::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError) { + ::mlir::DictionaryAttr dict = ::llvm::dyn_cast<::mlir::DictionaryAttr>(attr); + if (!dict) { + emitError() << "expected DictionaryAttr to set properties"; + return ::mlir::failure(); + } + + { + auto &propStorage = prop.callee; + auto attr = dict.get("callee"); + if (attr || /*isRequired=*/true) { + if (!attr) { + emitError() << "expected key entry for callee in DictionaryAttr to set " + "Properties."; + return ::mlir::failure(); + } + auto convertedAttr = ::llvm::dyn_cast>(attr); + if (convertedAttr) { + propStorage = convertedAttr; + } else { + emitError() << "Invalid attribute `callee` in property conversion: " << attr; + return ::mlir::failure(); + } + } + } + return ::mlir::success(); +} + +::mlir::Attribute GenericCallOp::getPropertiesAsAttr(::mlir::MLIRContext *ctx, const Properties &prop) { + ::mlir::SmallVector<::mlir::NamedAttribute> attrs; + ::mlir::Builder odsBuilder{ctx}; + + { + const auto &propStorage = prop.callee; + if (propStorage) + attrs.push_back(odsBuilder.getNamedAttr("callee", + propStorage)); + } + + if (!attrs.empty()) + return odsBuilder.getDictionaryAttr(attrs); + return {}; +} + +llvm::hash_code GenericCallOp::computePropertiesHash(const Properties &prop) { + return llvm::hash_combine( + llvm::hash_value(prop.callee.getAsOpaquePointer())); +} + +std::optional GenericCallOp::getInherentAttr(::mlir::MLIRContext *ctx, const Properties &prop, llvm::StringRef name) { + if (name == "callee") + return prop.callee; + return std::nullopt; +} + +void GenericCallOp::setInherentAttr(Properties &prop, llvm::StringRef name, mlir::Attribute value) { + if (name == "callee") { + prop.callee = ::llvm::dyn_cast_or_null>(value); + return; + } +} + +void GenericCallOp::populateInherentAttrs(::mlir::MLIRContext *ctx, const Properties &prop, ::mlir::NamedAttrList &attrs) { + if (prop.callee) attrs.append("callee", prop.callee); +} + +::mlir::LogicalResult GenericCallOp::verifyInherentAttrs(::mlir::OperationName opName, ::mlir::NamedAttrList &attrs, llvm::function_ref<::mlir::InFlightDiagnostic()> emitError) { + { + ::mlir::Attribute attr = attrs.get(getCalleeAttrName(opName)); + if (attr && ::mlir::failed(__mlir_ods_local_attr_constraint_Ops4(attr, "callee", emitError))) + return ::mlir::failure(); + } + return ::mlir::success(); +} + +::mlir::LogicalResult GenericCallOp::readProperties(::mlir::DialectBytecodeReader &reader, ::mlir::OperationState &state) { + auto &prop = state.getOrAddProperties(); (void)prop; + if (::mlir::failed(reader.readAttribute(prop.callee))) + return ::mlir::failure(); + return ::mlir::success(); +} + +void GenericCallOp::writeProperties(::mlir::DialectBytecodeWriter &writer) { + auto &prop = getProperties(); (void)prop; + writer.writeAttribute(prop.callee); +} + +::mlir::FlatSymbolRefAttr GenericCallOp::getCalleeAttr() { + return ::llvm::cast<::mlir::FlatSymbolRefAttr>(getProperties().callee); +} + +::llvm::StringRef GenericCallOp::getCallee() { + auto attr = getCalleeAttr(); + return attr.getValue(); +} + +void GenericCallOp::setCalleeAttr(::mlir::FlatSymbolRefAttr attr) { + (*this)->setAttr(getCalleeAttrName(), attr); +} + +void GenericCallOp::setCallee(::llvm::StringRef attrValue) { + (*this)->setAttr(getCalleeAttrName(), ::mlir::SymbolRefAttr::get(::mlir::Builder((*this)->getContext()).getContext(), attrValue)); +} + +void GenericCallOp::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type resultType0, ::mlir::FlatSymbolRefAttr callee, ::mlir::ValueRange inputs) { + odsState.addOperands(inputs); + odsState.getOrAddProperties().callee = callee; + odsState.addTypes(resultType0); +} + +void GenericCallOp::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::FlatSymbolRefAttr callee, ::mlir::ValueRange inputs) { + odsState.addOperands(inputs); + odsState.getOrAddProperties().callee = callee; + assert(resultTypes.size() == 1u && "mismatched number of results"); + odsState.addTypes(resultTypes); +} + +void GenericCallOp::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type resultType0, ::llvm::StringRef callee, ::mlir::ValueRange inputs) { + odsState.addOperands(inputs); + odsState.getOrAddProperties().callee = ::mlir::SymbolRefAttr::get(odsBuilder.getContext(), callee); + odsState.addTypes(resultType0); +} + +void GenericCallOp::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::llvm::StringRef callee, ::mlir::ValueRange inputs) { + odsState.addOperands(inputs); + odsState.getOrAddProperties().callee = ::mlir::SymbolRefAttr::get(odsBuilder.getContext(), callee); + assert(resultTypes.size() == 1u && "mismatched number of results"); + odsState.addTypes(resultTypes); +} + +void GenericCallOp::build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes) { + odsState.addOperands(operands); + odsState.addAttributes(attributes); + assert(resultTypes.size() == 1u && "mismatched number of return types"); + odsState.addTypes(resultTypes); +} + +::mlir::LogicalResult GenericCallOp::verifyInvariantsImpl() { + auto tblgen_callee = getProperties().callee; (void)tblgen_callee; + if (!tblgen_callee) return emitOpError("requires attribute 'callee'"); + + if (::mlir::failed(__mlir_ods_local_attr_constraint_Ops4(*this, tblgen_callee, "callee"))) + return ::mlir::failure(); + { + unsigned index = 0; (void)index; + auto valueGroup0 = getODSOperands(0); + + for (auto v : valueGroup0) { + if (::mlir::failed(__mlir_ods_local_type_constraint_Ops1(*this, v.getType(), "operand", index++))) + return ::mlir::failure(); + } + } + { + unsigned index = 0; (void)index; + auto valueGroup0 = getODSResults(0); + + for (auto v : valueGroup0) { + if (::mlir::failed(__mlir_ods_local_type_constraint_Ops0(*this, v.getType(), "result", index++))) + return ::mlir::failure(); + } + } + return ::mlir::success(); +} + +::mlir::LogicalResult GenericCallOp::verifyInvariants() { + return verifyInvariantsImpl(); +} + +::mlir::ParseResult GenericCallOp::parse(::mlir::OpAsmParser &parser, ::mlir::OperationState &result) { + ::mlir::FlatSymbolRefAttr calleeAttr; + ::llvm::SmallVector<::mlir::OpAsmParser::UnresolvedOperand, 4> inputsOperands; + ::llvm::SMLoc inputsOperandsLoc; + (void)inputsOperandsLoc; + ::llvm::ArrayRef<::mlir::Type> inputsTypes; + ::llvm::ArrayRef<::mlir::Type> allResultTypes; + + if (parser.parseCustomAttributeWithFallback(calleeAttr, parser.getBuilder().getType<::mlir::NoneType>())) { + return ::mlir::failure(); + } + if (calleeAttr) result.getOrAddProperties().callee = calleeAttr; + if (parser.parseLParen()) + return ::mlir::failure(); + + inputsOperandsLoc = parser.getCurrentLocation(); + if (parser.parseOperandList(inputsOperands)) + return ::mlir::failure(); + if (parser.parseRParen()) + return ::mlir::failure(); + { + auto loc = parser.getCurrentLocation();(void)loc; + if (parser.parseOptionalAttrDict(result.attributes)) + return ::mlir::failure(); + if (failed(verifyInherentAttrs(result.name, result.attributes, [&]() { + return parser.emitError(loc) << "'" << result.name.getStringRef() << "' op "; + }))) + return ::mlir::failure(); + } + if (parser.parseColon()) + return ::mlir::failure(); + + ::mlir::FunctionType inputs__allResult_functionType; + if (parser.parseType(inputs__allResult_functionType)) + return ::mlir::failure(); + inputsTypes = inputs__allResult_functionType.getInputs(); + allResultTypes = inputs__allResult_functionType.getResults(); + result.addTypes(allResultTypes); + if (parser.resolveOperands(inputsOperands, inputsTypes, inputsOperandsLoc, result.operands)) + return ::mlir::failure(); + return ::mlir::success(); +} + +void GenericCallOp::print(::mlir::OpAsmPrinter &_odsPrinter) { + _odsPrinter << ' '; + _odsPrinter.printAttributeWithoutType(getCalleeAttr()); + _odsPrinter << "("; + _odsPrinter << getInputs(); + _odsPrinter << ")"; + ::llvm::SmallVector<::llvm::StringRef, 2> elidedAttrs; + elidedAttrs.push_back("callee"); + _odsPrinter.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs); + _odsPrinter << ' ' << ":"; + _odsPrinter << ' '; + _odsPrinter.printFunctionalType(getInputs().getTypes(), getOperation()->getResultTypes()); +} + +} // namespace toy +} // namespace mlir +MLIR_DEFINE_EXPLICIT_TYPE_ID(::mlir::toy::GenericCallOp) + +namespace mlir { +namespace toy { + +//===----------------------------------------------------------------------===// +// ::mlir::toy::MulOp definitions +//===----------------------------------------------------------------------===// + +namespace detail { +MulOpGenericAdaptorBase::MulOpGenericAdaptorBase(::mlir::DictionaryAttr attrs, const ::mlir::EmptyProperties &properties, ::mlir::RegionRange regions) : odsAttrs(attrs), odsRegions(regions) { if (odsAttrs) + odsOpName.emplace("toy.mul", odsAttrs.getContext()); +} + +MulOpGenericAdaptorBase::MulOpGenericAdaptorBase(MulOp op) : MulOpGenericAdaptorBase(op->getAttrDictionary(), op.getProperties(), op->getRegions()) {} + +std::pair MulOpGenericAdaptorBase::getODSOperandIndexAndLength(unsigned index, unsigned odsOperandsSize) { + return {index, 1}; +} + +::mlir::DictionaryAttr MulOpGenericAdaptorBase::getAttributes() { + return odsAttrs; +} + +} // namespace detail +MulOpAdaptor::MulOpAdaptor(MulOp op) : MulOpGenericAdaptor(op->getOperands(), op) {} + +::mlir::LogicalResult MulOpAdaptor::verify(::mlir::Location loc) { + return ::mlir::success(); +} + +std::pair MulOp::getODSOperandIndexAndLength(unsigned index) { + return {index, 1}; +} + +::mlir::Operation::operand_range MulOp::getODSOperands(unsigned index) { + auto valueRange = getODSOperandIndexAndLength(index); + return {std::next(getOperation()->operand_begin(), valueRange.first), + std::next(getOperation()->operand_begin(), valueRange.first + valueRange.second)}; +} + +::mlir::TypedValue<::mlir::TensorType> MulOp::getLhs() { + return ::llvm::cast<::mlir::TypedValue<::mlir::TensorType>>(*getODSOperands(0).begin()); +} + +::mlir::TypedValue<::mlir::TensorType> MulOp::getRhs() { + return ::llvm::cast<::mlir::TypedValue<::mlir::TensorType>>(*getODSOperands(1).begin()); +} + +::mlir::OpOperand &MulOp::getLhsMutable() { + auto range = getODSOperandIndexAndLength(0); + return getOperation()->getOpOperand(range.first); +} + +::mlir::OpOperand &MulOp::getRhsMutable() { + auto range = getODSOperandIndexAndLength(1); + return getOperation()->getOpOperand(range.first); +} + +std::pair MulOp::getODSResultIndexAndLength(unsigned index) { + return {index, 1}; +} + +::mlir::Operation::result_range MulOp::getODSResults(unsigned index) { + auto valueRange = getODSResultIndexAndLength(index); + return {std::next(getOperation()->result_begin(), valueRange.first), + std::next(getOperation()->result_begin(), valueRange.first + valueRange.second)}; +} + +void MulOp::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type resultType0, ::mlir::Value lhs, ::mlir::Value rhs) { + odsState.addOperands(lhs); + odsState.addOperands(rhs); + odsState.addTypes(resultType0); +} + +void MulOp::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::Value lhs, ::mlir::Value rhs) { + odsState.addOperands(lhs); + odsState.addOperands(rhs); + assert(resultTypes.size() == 1u && "mismatched number of results"); + odsState.addTypes(resultTypes); +} + +void MulOp::build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes) { + assert(operands.size() == 2u && "mismatched number of parameters"); + odsState.addOperands(operands); + odsState.addAttributes(attributes); + assert(resultTypes.size() == 1u && "mismatched number of return types"); + odsState.addTypes(resultTypes); +} + +::mlir::LogicalResult MulOp::verifyInvariantsImpl() { + { + unsigned index = 0; (void)index; + auto valueGroup0 = getODSOperands(0); + + for (auto v : valueGroup0) { + if (::mlir::failed(__mlir_ods_local_type_constraint_Ops0(*this, v.getType(), "operand", index++))) + return ::mlir::failure(); + } + auto valueGroup1 = getODSOperands(1); + + for (auto v : valueGroup1) { + if (::mlir::failed(__mlir_ods_local_type_constraint_Ops0(*this, v.getType(), "operand", index++))) + return ::mlir::failure(); + } + } + { + unsigned index = 0; (void)index; + auto valueGroup0 = getODSResults(0); + + for (auto v : valueGroup0) { + if (::mlir::failed(__mlir_ods_local_type_constraint_Ops0(*this, v.getType(), "result", index++))) + return ::mlir::failure(); + } + } + return ::mlir::success(); +} + +::mlir::LogicalResult MulOp::verifyInvariants() { + return verifyInvariantsImpl(); +} + +void MulOp::getEffects(::llvm::SmallVectorImpl<::mlir::SideEffects::EffectInstance<::mlir::MemoryEffects::Effect>> &effects) { +} + +} // namespace toy +} // namespace mlir +MLIR_DEFINE_EXPLICIT_TYPE_ID(::mlir::toy::MulOp) + +namespace mlir { +namespace toy { + +//===----------------------------------------------------------------------===// +// ::mlir::toy::PrintOp definitions +//===----------------------------------------------------------------------===// + +namespace detail { +PrintOpGenericAdaptorBase::PrintOpGenericAdaptorBase(::mlir::DictionaryAttr attrs, const ::mlir::EmptyProperties &properties, ::mlir::RegionRange regions) : odsAttrs(attrs), odsRegions(regions) { if (odsAttrs) + odsOpName.emplace("toy.print", odsAttrs.getContext()); +} + +PrintOpGenericAdaptorBase::PrintOpGenericAdaptorBase(PrintOp op) : PrintOpGenericAdaptorBase(op->getAttrDictionary(), op.getProperties(), op->getRegions()) {} + +std::pair PrintOpGenericAdaptorBase::getODSOperandIndexAndLength(unsigned index, unsigned odsOperandsSize) { + return {index, 1}; +} + +::mlir::DictionaryAttr PrintOpGenericAdaptorBase::getAttributes() { + return odsAttrs; +} + +} // namespace detail +PrintOpAdaptor::PrintOpAdaptor(PrintOp op) : PrintOpGenericAdaptor(op->getOperands(), op) {} + +::mlir::LogicalResult PrintOpAdaptor::verify(::mlir::Location loc) { + return ::mlir::success(); +} + +std::pair PrintOp::getODSOperandIndexAndLength(unsigned index) { + return {index, 1}; +} + +::mlir::Operation::operand_range PrintOp::getODSOperands(unsigned index) { + auto valueRange = getODSOperandIndexAndLength(index); + return {std::next(getOperation()->operand_begin(), valueRange.first), + std::next(getOperation()->operand_begin(), valueRange.first + valueRange.second)}; +} + +::mlir::Value PrintOp::getInput() { + return ::llvm::cast<::mlir::Value>(*getODSOperands(0).begin()); +} + +::mlir::OpOperand &PrintOp::getInputMutable() { + auto range = getODSOperandIndexAndLength(0); + return getOperation()->getOpOperand(range.first); +} + +std::pair PrintOp::getODSResultIndexAndLength(unsigned index) { + return {index, 1}; +} + +::mlir::Operation::result_range PrintOp::getODSResults(unsigned index) { + auto valueRange = getODSResultIndexAndLength(index); + return {std::next(getOperation()->result_begin(), valueRange.first), + std::next(getOperation()->result_begin(), valueRange.first + valueRange.second)}; +} + +void PrintOp::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Value input) { + odsState.addOperands(input); +} + +void PrintOp::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::Value input) { + odsState.addOperands(input); + assert(resultTypes.size() == 0u && "mismatched number of results"); + odsState.addTypes(resultTypes); +} + +void PrintOp::build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes) { + assert(operands.size() == 1u && "mismatched number of parameters"); + odsState.addOperands(operands); + odsState.addAttributes(attributes); + assert(resultTypes.size() == 0u && "mismatched number of return types"); + odsState.addTypes(resultTypes); +} + +::mlir::LogicalResult PrintOp::verifyInvariantsImpl() { + { + unsigned index = 0; (void)index; + auto valueGroup0 = getODSOperands(0); + + for (auto v : valueGroup0) { + if (::mlir::failed(__mlir_ods_local_type_constraint_Ops2(*this, v.getType(), "operand", index++))) + return ::mlir::failure(); + } + } + return ::mlir::success(); +} + +::mlir::LogicalResult PrintOp::verifyInvariants() { + return verifyInvariantsImpl(); +} + +::mlir::ParseResult PrintOp::parse(::mlir::OpAsmParser &parser, ::mlir::OperationState &result) { + ::mlir::OpAsmParser::UnresolvedOperand inputRawOperands[1]; + ::llvm::ArrayRef<::mlir::OpAsmParser::UnresolvedOperand> inputOperands(inputRawOperands); ::llvm::SMLoc inputOperandsLoc; + (void)inputOperandsLoc; + ::mlir::Type inputRawTypes[1]; + ::llvm::ArrayRef<::mlir::Type> inputTypes(inputRawTypes); + + inputOperandsLoc = parser.getCurrentLocation(); + if (parser.parseOperand(inputRawOperands[0])) + return ::mlir::failure(); + { + auto loc = parser.getCurrentLocation();(void)loc; + if (parser.parseOptionalAttrDict(result.attributes)) + return ::mlir::failure(); + } + if (parser.parseColon()) + return ::mlir::failure(); + + { + ::mlir::Type type; + if (parser.parseCustomTypeWithFallback(type)) + return ::mlir::failure(); + inputRawTypes[0] = type; + } + if (parser.resolveOperands(inputOperands, inputTypes, inputOperandsLoc, result.operands)) + return ::mlir::failure(); + return ::mlir::success(); +} + +void PrintOp::print(::mlir::OpAsmPrinter &_odsPrinter) { + _odsPrinter << ' '; + _odsPrinter << getInput(); + ::llvm::SmallVector<::llvm::StringRef, 2> elidedAttrs; + _odsPrinter.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs); + _odsPrinter << ' ' << ":"; + _odsPrinter << ' '; + { + auto type = getInput().getType(); + if (auto validType = ::llvm::dyn_cast<::mlir::Type>(type)) + _odsPrinter.printStrippedAttrOrType(validType); + else + _odsPrinter << type; + } +} + +} // namespace toy +} // namespace mlir +MLIR_DEFINE_EXPLICIT_TYPE_ID(::mlir::toy::PrintOp) + +namespace mlir { +namespace toy { + +//===----------------------------------------------------------------------===// +// ::mlir::toy::ReshapeOp definitions +//===----------------------------------------------------------------------===// + +namespace detail { +ReshapeOpGenericAdaptorBase::ReshapeOpGenericAdaptorBase(::mlir::DictionaryAttr attrs, const ::mlir::EmptyProperties &properties, ::mlir::RegionRange regions) : odsAttrs(attrs), odsRegions(regions) { if (odsAttrs) + odsOpName.emplace("toy.reshape", odsAttrs.getContext()); +} + +ReshapeOpGenericAdaptorBase::ReshapeOpGenericAdaptorBase(ReshapeOp op) : ReshapeOpGenericAdaptorBase(op->getAttrDictionary(), op.getProperties(), op->getRegions()) {} + +std::pair ReshapeOpGenericAdaptorBase::getODSOperandIndexAndLength(unsigned index, unsigned odsOperandsSize) { + return {index, 1}; +} + +::mlir::DictionaryAttr ReshapeOpGenericAdaptorBase::getAttributes() { + return odsAttrs; +} + +} // namespace detail +ReshapeOpAdaptor::ReshapeOpAdaptor(ReshapeOp op) : ReshapeOpGenericAdaptor(op->getOperands(), op) {} + +::mlir::LogicalResult ReshapeOpAdaptor::verify(::mlir::Location loc) { + return ::mlir::success(); +} + +std::pair ReshapeOp::getODSOperandIndexAndLength(unsigned index) { + return {index, 1}; +} + +::mlir::Operation::operand_range ReshapeOp::getODSOperands(unsigned index) { + auto valueRange = getODSOperandIndexAndLength(index); + return {std::next(getOperation()->operand_begin(), valueRange.first), + std::next(getOperation()->operand_begin(), valueRange.first + valueRange.second)}; +} + +::mlir::TypedValue<::mlir::TensorType> ReshapeOp::getInput() { + return ::llvm::cast<::mlir::TypedValue<::mlir::TensorType>>(*getODSOperands(0).begin()); +} + +::mlir::OpOperand &ReshapeOp::getInputMutable() { + auto range = getODSOperandIndexAndLength(0); + return getOperation()->getOpOperand(range.first); +} + +std::pair ReshapeOp::getODSResultIndexAndLength(unsigned index) { + return {index, 1}; +} + +::mlir::Operation::result_range ReshapeOp::getODSResults(unsigned index) { + auto valueRange = getODSResultIndexAndLength(index); + return {std::next(getOperation()->result_begin(), valueRange.first), + std::next(getOperation()->result_begin(), valueRange.first + valueRange.second)}; +} + +void ReshapeOp::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type resultType0, ::mlir::Value input) { + odsState.addOperands(input); + odsState.addTypes(resultType0); +} + +void ReshapeOp::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::Value input) { + odsState.addOperands(input); + assert(resultTypes.size() == 1u && "mismatched number of results"); + odsState.addTypes(resultTypes); +} + +void ReshapeOp::build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes) { + assert(operands.size() == 1u && "mismatched number of parameters"); + odsState.addOperands(operands); + odsState.addAttributes(attributes); + assert(resultTypes.size() == 1u && "mismatched number of return types"); + odsState.addTypes(resultTypes); +} + +::mlir::LogicalResult ReshapeOp::verifyInvariantsImpl() { + { + unsigned index = 0; (void)index; + auto valueGroup0 = getODSOperands(0); + + for (auto v : valueGroup0) { + if (::mlir::failed(__mlir_ods_local_type_constraint_Ops0(*this, v.getType(), "operand", index++))) + return ::mlir::failure(); + } + } + { + unsigned index = 0; (void)index; + auto valueGroup0 = getODSResults(0); + + for (auto v : valueGroup0) { + if (::mlir::failed(__mlir_ods_local_type_constraint_Ops3(*this, v.getType(), "result", index++))) + return ::mlir::failure(); + } + } + return ::mlir::success(); +} + +::mlir::LogicalResult ReshapeOp::verifyInvariants() { + return verifyInvariantsImpl(); +} + +::mlir::ParseResult ReshapeOp::parse(::mlir::OpAsmParser &parser, ::mlir::OperationState &result) { + ::mlir::OpAsmParser::UnresolvedOperand inputRawOperands[1]; + ::llvm::ArrayRef<::mlir::OpAsmParser::UnresolvedOperand> inputOperands(inputRawOperands); ::llvm::SMLoc inputOperandsLoc; + (void)inputOperandsLoc; + ::mlir::Type inputRawTypes[1]; + ::llvm::ArrayRef<::mlir::Type> inputTypes(inputRawTypes); + ::llvm::SmallVector<::mlir::Type, 1> allResultTypes; + if (parser.parseLParen()) + return ::mlir::failure(); + + inputOperandsLoc = parser.getCurrentLocation(); + if (parser.parseOperand(inputRawOperands[0])) + return ::mlir::failure(); + if (parser.parseColon()) + return ::mlir::failure(); + + { + ::mlir::TensorType type; + if (parser.parseCustomTypeWithFallback(type)) + return ::mlir::failure(); + inputRawTypes[0] = type; + } + if (parser.parseRParen()) + return ::mlir::failure(); + { + auto loc = parser.getCurrentLocation();(void)loc; + if (parser.parseOptionalAttrDict(result.attributes)) + return ::mlir::failure(); + } + if (parser.parseKeyword("to")) + return ::mlir::failure(); + + if (parser.parseTypeList(allResultTypes)) + return ::mlir::failure(); + result.addTypes(allResultTypes); + if (parser.resolveOperands(inputOperands, inputTypes, inputOperandsLoc, result.operands)) + return ::mlir::failure(); + return ::mlir::success(); +} + +void ReshapeOp::print(::mlir::OpAsmPrinter &_odsPrinter) { + _odsPrinter << "("; + _odsPrinter << getInput(); + _odsPrinter << ' ' << ":"; + _odsPrinter << ' '; + { + auto type = getInput().getType(); + if (auto validType = ::llvm::dyn_cast<::mlir::TensorType>(type)) + _odsPrinter.printStrippedAttrOrType(validType); + else + _odsPrinter << type; + } + _odsPrinter << ")"; + ::llvm::SmallVector<::llvm::StringRef, 2> elidedAttrs; + _odsPrinter.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs); + _odsPrinter << ' ' << "to"; + _odsPrinter << ' '; + _odsPrinter << getOperation()->getResultTypes(); +} + +void ReshapeOp::getEffects(::llvm::SmallVectorImpl<::mlir::SideEffects::EffectInstance<::mlir::MemoryEffects::Effect>> &effects) { +} + +} // namespace toy +} // namespace mlir +MLIR_DEFINE_EXPLICIT_TYPE_ID(::mlir::toy::ReshapeOp) + +namespace mlir { +namespace toy { + +//===----------------------------------------------------------------------===// +// ::mlir::toy::ReturnOp definitions +//===----------------------------------------------------------------------===// + +namespace detail { +ReturnOpGenericAdaptorBase::ReturnOpGenericAdaptorBase(::mlir::DictionaryAttr attrs, const ::mlir::EmptyProperties &properties, ::mlir::RegionRange regions) : odsAttrs(attrs), odsRegions(regions) { if (odsAttrs) + odsOpName.emplace("toy.return", odsAttrs.getContext()); +} + +ReturnOpGenericAdaptorBase::ReturnOpGenericAdaptorBase(ReturnOp op) : ReturnOpGenericAdaptorBase(op->getAttrDictionary(), op.getProperties(), op->getRegions()) {} + +std::pair ReturnOpGenericAdaptorBase::getODSOperandIndexAndLength(unsigned index, unsigned odsOperandsSize) { + bool isVariadic[] = {true}; + int prevVariadicCount = 0; + for (unsigned i = 0; i < index; ++i) + if (isVariadic[i]) ++prevVariadicCount; + + // Calculate how many dynamic values a static variadic operand corresponds to. + // This assumes all static variadic operands have the same dynamic value count. + int variadicSize = (odsOperandsSize - 0) / 1; + // `index` passed in as the parameter is the static index which counts each + // operand (variadic or not) as size 1. So here for each previous static variadic + // operand, we need to offset by (variadicSize - 1) to get where the dynamic + // value pack for this static operand starts. + int start = index + (variadicSize - 1) * prevVariadicCount; + int size = isVariadic[index] ? variadicSize : 1; + return {start, size}; +} + +::mlir::DictionaryAttr ReturnOpGenericAdaptorBase::getAttributes() { + return odsAttrs; +} + +} // namespace detail +ReturnOpAdaptor::ReturnOpAdaptor(ReturnOp op) : ReturnOpGenericAdaptor(op->getOperands(), op) {} + +::mlir::LogicalResult ReturnOpAdaptor::verify(::mlir::Location loc) { + return ::mlir::success(); +} + +std::pair ReturnOp::getODSOperandIndexAndLength(unsigned index) { + bool isVariadic[] = {true}; + int prevVariadicCount = 0; + for (unsigned i = 0; i < index; ++i) + if (isVariadic[i]) ++prevVariadicCount; + + // Calculate how many dynamic values a static variadic operand corresponds to. + // This assumes all static variadic operands have the same dynamic value count. + int variadicSize = (getOperation()->getNumOperands() - 0) / 1; + // `index` passed in as the parameter is the static index which counts each + // operand (variadic or not) as size 1. So here for each previous static variadic + // operand, we need to offset by (variadicSize - 1) to get where the dynamic + // value pack for this static operand starts. + int start = index + (variadicSize - 1) * prevVariadicCount; + int size = isVariadic[index] ? variadicSize : 1; + return {start, size}; +} + +::mlir::Operation::operand_range ReturnOp::getODSOperands(unsigned index) { + auto valueRange = getODSOperandIndexAndLength(index); + return {std::next(getOperation()->operand_begin(), valueRange.first), + std::next(getOperation()->operand_begin(), valueRange.first + valueRange.second)}; +} + +::mlir::Operation::operand_range ReturnOp::getInput() { + return getODSOperands(0); +} + +::mlir::MutableOperandRange ReturnOp::getInputMutable() { + auto range = getODSOperandIndexAndLength(0); + auto mutableRange = ::mlir::MutableOperandRange(getOperation(), range.first, range.second); + return mutableRange; +} + +std::pair ReturnOp::getODSResultIndexAndLength(unsigned index) { + return {index, 1}; +} + +::mlir::Operation::result_range ReturnOp::getODSResults(unsigned index) { + auto valueRange = getODSResultIndexAndLength(index); + return {std::next(getOperation()->result_begin(), valueRange.first), + std::next(getOperation()->result_begin(), valueRange.first + valueRange.second)}; +} + +void ReturnOp::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState) { + build(odsBuilder, odsState, std::nullopt); +} + +void ReturnOp::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::ValueRange input) { + odsState.addOperands(input); +} + +void ReturnOp::build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes) { + odsState.addOperands(operands); + odsState.addAttributes(attributes); + assert(resultTypes.size() == 0u && "mismatched number of return types"); + odsState.addTypes(resultTypes); +} + +::mlir::LogicalResult ReturnOp::verifyInvariantsImpl() { + { + unsigned index = 0; (void)index; + auto valueGroup0 = getODSOperands(0); + + for (auto v : valueGroup0) { + if (::mlir::failed(__mlir_ods_local_type_constraint_Ops1(*this, v.getType(), "operand", index++))) + return ::mlir::failure(); + } + } + return ::mlir::success(); +} + +::mlir::LogicalResult ReturnOp::verifyInvariants() { + if(::mlir::succeeded(verifyInvariantsImpl()) && ::mlir::succeeded(verify())) + return ::mlir::success(); + return ::mlir::failure(); +} + +::mlir::ParseResult ReturnOp::parse(::mlir::OpAsmParser &parser, ::mlir::OperationState &result) { + ::llvm::SmallVector<::mlir::OpAsmParser::UnresolvedOperand, 4> inputOperands; + ::llvm::SMLoc inputOperandsLoc; + (void)inputOperandsLoc; + ::llvm::SmallVector<::mlir::Type, 1> inputTypes; + + inputOperandsLoc = parser.getCurrentLocation(); + if (parser.parseOperandList(inputOperands)) + return ::mlir::failure(); + if (!inputOperands.empty()) { + if (parser.parseColon()) + return ::mlir::failure(); + + if (parser.parseTypeList(inputTypes)) + return ::mlir::failure(); + } + { + auto loc = parser.getCurrentLocation();(void)loc; + if (parser.parseOptionalAttrDict(result.attributes)) + return ::mlir::failure(); + } + if (parser.resolveOperands(inputOperands, inputTypes, inputOperandsLoc, result.operands)) + return ::mlir::failure(); + return ::mlir::success(); +} + +void ReturnOp::print(::mlir::OpAsmPrinter &_odsPrinter) { + if (!getInput().empty()) { + _odsPrinter << ' '; + _odsPrinter << getInput(); + _odsPrinter << ' ' << ":"; + _odsPrinter << ' '; + _odsPrinter << getInput().getTypes(); + } + ::llvm::SmallVector<::llvm::StringRef, 2> elidedAttrs; + _odsPrinter.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs); +} + +void ReturnOp::getEffects(::llvm::SmallVectorImpl<::mlir::SideEffects::EffectInstance<::mlir::MemoryEffects::Effect>> &effects) { +} + +} // namespace toy +} // namespace mlir +MLIR_DEFINE_EXPLICIT_TYPE_ID(::mlir::toy::ReturnOp) + +namespace mlir { +namespace toy { + +//===----------------------------------------------------------------------===// +// ::mlir::toy::TransposeOp definitions +//===----------------------------------------------------------------------===// + +namespace detail { +TransposeOpGenericAdaptorBase::TransposeOpGenericAdaptorBase(::mlir::DictionaryAttr attrs, const ::mlir::EmptyProperties &properties, ::mlir::RegionRange regions) : odsAttrs(attrs), odsRegions(regions) { if (odsAttrs) + odsOpName.emplace("toy.transpose", odsAttrs.getContext()); +} + +TransposeOpGenericAdaptorBase::TransposeOpGenericAdaptorBase(TransposeOp op) : TransposeOpGenericAdaptorBase(op->getAttrDictionary(), op.getProperties(), op->getRegions()) {} + +std::pair TransposeOpGenericAdaptorBase::getODSOperandIndexAndLength(unsigned index, unsigned odsOperandsSize) { + return {index, 1}; +} + +::mlir::DictionaryAttr TransposeOpGenericAdaptorBase::getAttributes() { + return odsAttrs; +} + +} // namespace detail +TransposeOpAdaptor::TransposeOpAdaptor(TransposeOp op) : TransposeOpGenericAdaptor(op->getOperands(), op) {} + +::mlir::LogicalResult TransposeOpAdaptor::verify(::mlir::Location loc) { + return ::mlir::success(); +} + +std::pair TransposeOp::getODSOperandIndexAndLength(unsigned index) { + return {index, 1}; +} + +::mlir::Operation::operand_range TransposeOp::getODSOperands(unsigned index) { + auto valueRange = getODSOperandIndexAndLength(index); + return {std::next(getOperation()->operand_begin(), valueRange.first), + std::next(getOperation()->operand_begin(), valueRange.first + valueRange.second)}; +} + +::mlir::TypedValue<::mlir::TensorType> TransposeOp::getInput() { + return ::llvm::cast<::mlir::TypedValue<::mlir::TensorType>>(*getODSOperands(0).begin()); +} + +::mlir::OpOperand &TransposeOp::getInputMutable() { + auto range = getODSOperandIndexAndLength(0); + return getOperation()->getOpOperand(range.first); +} + +std::pair TransposeOp::getODSResultIndexAndLength(unsigned index) { + return {index, 1}; +} + +::mlir::Operation::result_range TransposeOp::getODSResults(unsigned index) { + auto valueRange = getODSResultIndexAndLength(index); + return {std::next(getOperation()->result_begin(), valueRange.first), + std::next(getOperation()->result_begin(), valueRange.first + valueRange.second)}; +} + +void TransposeOp::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type resultType0, ::mlir::Value input) { + odsState.addOperands(input); + odsState.addTypes(resultType0); +} + +void TransposeOp::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::Value input) { + odsState.addOperands(input); + assert(resultTypes.size() == 1u && "mismatched number of results"); + odsState.addTypes(resultTypes); +} + +void TransposeOp::build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes) { + assert(operands.size() == 1u && "mismatched number of parameters"); + odsState.addOperands(operands); + odsState.addAttributes(attributes); + assert(resultTypes.size() == 1u && "mismatched number of return types"); + odsState.addTypes(resultTypes); +} + +::mlir::LogicalResult TransposeOp::verifyInvariantsImpl() { + { + unsigned index = 0; (void)index; + auto valueGroup0 = getODSOperands(0); + + for (auto v : valueGroup0) { + if (::mlir::failed(__mlir_ods_local_type_constraint_Ops0(*this, v.getType(), "operand", index++))) + return ::mlir::failure(); + } + } + { + unsigned index = 0; (void)index; + auto valueGroup0 = getODSResults(0); + + for (auto v : valueGroup0) { + if (::mlir::failed(__mlir_ods_local_type_constraint_Ops0(*this, v.getType(), "result", index++))) + return ::mlir::failure(); + } + } + return ::mlir::success(); +} + +::mlir::LogicalResult TransposeOp::verifyInvariants() { + if(::mlir::succeeded(verifyInvariantsImpl()) && ::mlir::succeeded(verify())) + return ::mlir::success(); + return ::mlir::failure(); +} + +::mlir::ParseResult TransposeOp::parse(::mlir::OpAsmParser &parser, ::mlir::OperationState &result) { + ::mlir::OpAsmParser::UnresolvedOperand inputRawOperands[1]; + ::llvm::ArrayRef<::mlir::OpAsmParser::UnresolvedOperand> inputOperands(inputRawOperands); ::llvm::SMLoc inputOperandsLoc; + (void)inputOperandsLoc; + ::mlir::Type inputRawTypes[1]; + ::llvm::ArrayRef<::mlir::Type> inputTypes(inputRawTypes); + ::llvm::SmallVector<::mlir::Type, 1> allResultTypes; + if (parser.parseLParen()) + return ::mlir::failure(); + + inputOperandsLoc = parser.getCurrentLocation(); + if (parser.parseOperand(inputRawOperands[0])) + return ::mlir::failure(); + if (parser.parseColon()) + return ::mlir::failure(); + + { + ::mlir::TensorType type; + if (parser.parseCustomTypeWithFallback(type)) + return ::mlir::failure(); + inputRawTypes[0] = type; + } + if (parser.parseRParen()) + return ::mlir::failure(); + { + auto loc = parser.getCurrentLocation();(void)loc; + if (parser.parseOptionalAttrDict(result.attributes)) + return ::mlir::failure(); + } + if (parser.parseKeyword("to")) + return ::mlir::failure(); + + if (parser.parseTypeList(allResultTypes)) + return ::mlir::failure(); + result.addTypes(allResultTypes); + if (parser.resolveOperands(inputOperands, inputTypes, inputOperandsLoc, result.operands)) + return ::mlir::failure(); + return ::mlir::success(); +} + +void TransposeOp::print(::mlir::OpAsmPrinter &_odsPrinter) { + _odsPrinter << "("; + _odsPrinter << getInput(); + _odsPrinter << ' ' << ":"; + _odsPrinter << ' '; + { + auto type = getInput().getType(); + if (auto validType = ::llvm::dyn_cast<::mlir::TensorType>(type)) + _odsPrinter.printStrippedAttrOrType(validType); + else + _odsPrinter << type; + } + _odsPrinter << ")"; + ::llvm::SmallVector<::llvm::StringRef, 2> elidedAttrs; + _odsPrinter.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs); + _odsPrinter << ' ' << "to"; + _odsPrinter << ' '; + _odsPrinter << getOperation()->getResultTypes(); +} + +void TransposeOp::getEffects(::llvm::SmallVectorImpl<::mlir::SideEffects::EffectInstance<::mlir::MemoryEffects::Effect>> &effects) { +} + +} // namespace toy +} // namespace mlir +MLIR_DEFINE_EXPLICIT_TYPE_ID(::mlir::toy::TransposeOp) + + +#endif // GET_OP_CLASSES + diff --git a/Ch5/include/toy/Ops.h.inc b/Ch5/include/toy/Ops.h.inc new file mode 100644 index 0000000..a3011f8 --- /dev/null +++ b/Ch5/include/toy/Ops.h.inc @@ -0,0 +1,1361 @@ +/*===- TableGen'erated file -------------------------------------*- C++ -*-===*\ +|* *| +|* Op Declarations *| +|* *| +|* Automatically generated file, do not edit! *| +|* From: Ops.td *| +|* *| +\*===----------------------------------------------------------------------===*/ + +#if defined(GET_OP_CLASSES) || defined(GET_OP_FWD_DEFINES) +#undef GET_OP_FWD_DEFINES +namespace mlir { +namespace toy { +class AddOp; +} // namespace toy +} // namespace mlir +namespace mlir { +namespace toy { +class CastOp; +} // namespace toy +} // namespace mlir +namespace mlir { +namespace toy { +class ConstantOp; +} // namespace toy +} // namespace mlir +namespace mlir { +namespace toy { +class FuncOp; +} // namespace toy +} // namespace mlir +namespace mlir { +namespace toy { +class GenericCallOp; +} // namespace toy +} // namespace mlir +namespace mlir { +namespace toy { +class MulOp; +} // namespace toy +} // namespace mlir +namespace mlir { +namespace toy { +class PrintOp; +} // namespace toy +} // namespace mlir +namespace mlir { +namespace toy { +class ReshapeOp; +} // namespace toy +} // namespace mlir +namespace mlir { +namespace toy { +class ReturnOp; +} // namespace toy +} // namespace mlir +namespace mlir { +namespace toy { +class TransposeOp; +} // namespace toy +} // namespace mlir +#endif + +#ifdef GET_OP_CLASSES +#undef GET_OP_CLASSES + + +//===----------------------------------------------------------------------===// +// Local Utility Method Definitions +//===----------------------------------------------------------------------===// + +namespace mlir { +namespace toy { + +//===----------------------------------------------------------------------===// +// ::mlir::toy::AddOp declarations +//===----------------------------------------------------------------------===// + +namespace detail { +class AddOpGenericAdaptorBase { +public: +protected: + ::mlir::DictionaryAttr odsAttrs; + ::std::optional<::mlir::OperationName> odsOpName; + ::mlir::RegionRange odsRegions; +public: + AddOpGenericAdaptorBase(::mlir::DictionaryAttr attrs = nullptr, const ::mlir::EmptyProperties &properties = {}, ::mlir::RegionRange regions = {}); + + AddOpGenericAdaptorBase(AddOp op); + + std::pair getODSOperandIndexAndLength(unsigned index, unsigned odsOperandsSize); + ::mlir::DictionaryAttr getAttributes(); +}; +} // namespace detail +template +class AddOpGenericAdaptor : public detail::AddOpGenericAdaptorBase { + using ValueT = ::llvm::detail::ValueOfRange; + using Base = detail::AddOpGenericAdaptorBase; +public: + AddOpGenericAdaptor(RangeT values, ::mlir::DictionaryAttr attrs = nullptr, const ::mlir::EmptyProperties &properties = {}, ::mlir::RegionRange regions = {}) : Base(attrs, properties, regions), odsOperands(values) {} + + AddOpGenericAdaptor(RangeT values, ::mlir::DictionaryAttr attrs, ::mlir::OpaqueProperties properties, ::mlir::RegionRange regions = {}) : AddOpGenericAdaptor(values, attrs, (properties ? *properties.as<::mlir::EmptyProperties *>() : ::mlir::EmptyProperties{}), regions) {} + + template >> + AddOpGenericAdaptor(RangeT values, LateInst op) : Base(op), odsOperands(values) {} + + std::pair getODSOperandIndexAndLength(unsigned index) { + return Base::getODSOperandIndexAndLength(index, odsOperands.size()); + } + + RangeT getODSOperands(unsigned index) { + auto valueRange = getODSOperandIndexAndLength(index); + return {std::next(odsOperands.begin(), valueRange.first), + std::next(odsOperands.begin(), valueRange.first + valueRange.second)}; + } + + ValueT getLhs() { + return (*getODSOperands(0).begin()); + } + + ValueT getRhs() { + return (*getODSOperands(1).begin()); + } + + RangeT getOperands() { + return odsOperands; + } + +private: + RangeT odsOperands; +}; +class AddOpAdaptor : public AddOpGenericAdaptor<::mlir::ValueRange> { +public: + using AddOpGenericAdaptor::AddOpGenericAdaptor; + AddOpAdaptor(AddOp op); + + ::mlir::LogicalResult verify(::mlir::Location loc); +}; +class AddOp : public ::mlir::Op::Impl, ::mlir::OpTrait::ZeroSuccessors, ::mlir::OpTrait::NOperands<2>::Impl, ::mlir::OpTrait::OpInvariants, ::mlir::ConditionallySpeculatable::Trait, ::mlir::OpTrait::AlwaysSpeculatableImplTrait, ::mlir::MemoryEffectOpInterface::Trait, ShapeInference::Trait> { +public: + using Op::Op; + using Op::print; + using Adaptor = AddOpAdaptor; + template + using GenericAdaptor = AddOpGenericAdaptor; + using FoldAdaptor = GenericAdaptor<::llvm::ArrayRef<::mlir::Attribute>>; + static ::llvm::ArrayRef<::llvm::StringRef> getAttributeNames() { + return {}; + } + + static constexpr ::llvm::StringLiteral getOperationName() { + return ::llvm::StringLiteral("toy.add"); + } + + std::pair getODSOperandIndexAndLength(unsigned index); + ::mlir::Operation::operand_range getODSOperands(unsigned index); + ::mlir::TypedValue<::mlir::TensorType> getLhs(); + ::mlir::TypedValue<::mlir::TensorType> getRhs(); + ::mlir::OpOperand &getLhsMutable(); + ::mlir::OpOperand &getRhsMutable(); + std::pair getODSResultIndexAndLength(unsigned index); + ::mlir::Operation::result_range getODSResults(unsigned index); + static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, Value lhs, Value rhs); + static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type resultType0, ::mlir::Value lhs, ::mlir::Value rhs); + static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::Value lhs, ::mlir::Value rhs); + static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {}); + static ::mlir::ParseResult parse(::mlir::OpAsmParser &parser, ::mlir::OperationState &result); + void print(::mlir::OpAsmPrinter &p); + ::mlir::LogicalResult verifyInvariantsImpl(); + ::mlir::LogicalResult verifyInvariants(); + void inferShapes(); + void getEffects(::llvm::SmallVectorImpl<::mlir::SideEffects::EffectInstance<::mlir::MemoryEffects::Effect>> &effects); +public: +}; +} // namespace toy +} // namespace mlir +MLIR_DECLARE_EXPLICIT_TYPE_ID(::mlir::toy::AddOp) + +namespace mlir { +namespace toy { + +//===----------------------------------------------------------------------===// +// ::mlir::toy::CastOp declarations +//===----------------------------------------------------------------------===// + +namespace detail { +class CastOpGenericAdaptorBase { +public: +protected: + ::mlir::DictionaryAttr odsAttrs; + ::std::optional<::mlir::OperationName> odsOpName; + ::mlir::RegionRange odsRegions; +public: + CastOpGenericAdaptorBase(::mlir::DictionaryAttr attrs = nullptr, const ::mlir::EmptyProperties &properties = {}, ::mlir::RegionRange regions = {}); + + CastOpGenericAdaptorBase(CastOp op); + + std::pair getODSOperandIndexAndLength(unsigned index, unsigned odsOperandsSize); + ::mlir::DictionaryAttr getAttributes(); +}; +} // namespace detail +template +class CastOpGenericAdaptor : public detail::CastOpGenericAdaptorBase { + using ValueT = ::llvm::detail::ValueOfRange; + using Base = detail::CastOpGenericAdaptorBase; +public: + CastOpGenericAdaptor(RangeT values, ::mlir::DictionaryAttr attrs = nullptr, const ::mlir::EmptyProperties &properties = {}, ::mlir::RegionRange regions = {}) : Base(attrs, properties, regions), odsOperands(values) {} + + CastOpGenericAdaptor(RangeT values, ::mlir::DictionaryAttr attrs, ::mlir::OpaqueProperties properties, ::mlir::RegionRange regions = {}) : CastOpGenericAdaptor(values, attrs, (properties ? *properties.as<::mlir::EmptyProperties *>() : ::mlir::EmptyProperties{}), regions) {} + + template >> + CastOpGenericAdaptor(RangeT values, LateInst op) : Base(op), odsOperands(values) {} + + std::pair getODSOperandIndexAndLength(unsigned index) { + return Base::getODSOperandIndexAndLength(index, odsOperands.size()); + } + + RangeT getODSOperands(unsigned index) { + auto valueRange = getODSOperandIndexAndLength(index); + return {std::next(odsOperands.begin(), valueRange.first), + std::next(odsOperands.begin(), valueRange.first + valueRange.second)}; + } + + ValueT getInput() { + return (*getODSOperands(0).begin()); + } + + RangeT getOperands() { + return odsOperands; + } + +private: + RangeT odsOperands; +}; +class CastOpAdaptor : public CastOpGenericAdaptor<::mlir::ValueRange> { +public: + using CastOpGenericAdaptor::CastOpGenericAdaptor; + CastOpAdaptor(CastOp op); + + ::mlir::LogicalResult verify(::mlir::Location loc); +}; +class CastOp : public ::mlir::Op::Impl, ::mlir::OpTrait::ZeroSuccessors, ::mlir::OpTrait::OneOperand, ::mlir::OpTrait::OpInvariants, ::mlir::CastOpInterface::Trait, ShapeInference::Trait, ::mlir::ConditionallySpeculatable::Trait, ::mlir::OpTrait::AlwaysSpeculatableImplTrait, ::mlir::MemoryEffectOpInterface::Trait, ::mlir::OpTrait::SameOperandsAndResultShape> { +public: + using Op::Op; + using Op::print; + using Adaptor = CastOpAdaptor; + template + using GenericAdaptor = CastOpGenericAdaptor; + using FoldAdaptor = GenericAdaptor<::llvm::ArrayRef<::mlir::Attribute>>; + static ::llvm::ArrayRef<::llvm::StringRef> getAttributeNames() { + return {}; + } + + static constexpr ::llvm::StringLiteral getOperationName() { + return ::llvm::StringLiteral("toy.cast"); + } + + std::pair getODSOperandIndexAndLength(unsigned index); + ::mlir::Operation::operand_range getODSOperands(unsigned index); + ::mlir::TypedValue<::mlir::TensorType> getInput(); + ::mlir::OpOperand &getInputMutable(); + std::pair getODSResultIndexAndLength(unsigned index); + ::mlir::Operation::result_range getODSResults(unsigned index); + ::mlir::TypedValue<::mlir::TensorType> getOutput(); + static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type output, ::mlir::Value input); + static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::Value input); + static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {}); + ::mlir::LogicalResult verifyInvariantsImpl(); + ::mlir::LogicalResult verifyInvariants(); + static bool areCastCompatible(::mlir::TypeRange inputs, ::mlir::TypeRange outputs); + void inferShapes(); + static ::mlir::ParseResult parse(::mlir::OpAsmParser &parser, ::mlir::OperationState &result); + void print(::mlir::OpAsmPrinter &_odsPrinter); + void getEffects(::llvm::SmallVectorImpl<::mlir::SideEffects::EffectInstance<::mlir::MemoryEffects::Effect>> &effects); +public: +}; +} // namespace toy +} // namespace mlir +MLIR_DECLARE_EXPLICIT_TYPE_ID(::mlir::toy::CastOp) + +namespace mlir { +namespace toy { + +//===----------------------------------------------------------------------===// +// ::mlir::toy::ConstantOp declarations +//===----------------------------------------------------------------------===// + +namespace detail { +class ConstantOpGenericAdaptorBase { +public: + struct Properties { + using valueTy = ::mlir::DenseElementsAttr; + valueTy value; + + auto getValue() { + auto &propStorage = this->value; + return ::llvm::cast<::mlir::DenseElementsAttr>(propStorage); + } + void setValue(const ::mlir::DenseElementsAttr &propValue) { + this->value = propValue; + } + bool operator==(const Properties &rhs) const { + return + rhs.value == this->value && + true; + } + bool operator!=(const Properties &rhs) const { + return !(*this == rhs); + } + }; +protected: + ::mlir::DictionaryAttr odsAttrs; + ::std::optional<::mlir::OperationName> odsOpName; + Properties properties; + ::mlir::RegionRange odsRegions; +public: + ConstantOpGenericAdaptorBase(::mlir::DictionaryAttr attrs = nullptr, const Properties &properties = {}, ::mlir::RegionRange regions = {}); + + ConstantOpGenericAdaptorBase(ConstantOp op); + + std::pair getODSOperandIndexAndLength(unsigned index, unsigned odsOperandsSize); + const Properties &getProperties() { + return properties; + } + + ::mlir::DictionaryAttr getAttributes(); + ::mlir::DenseElementsAttr getValueAttr(); + ::mlir::DenseElementsAttr getValue(); +}; +} // namespace detail +template +class ConstantOpGenericAdaptor : public detail::ConstantOpGenericAdaptorBase { + using ValueT = ::llvm::detail::ValueOfRange; + using Base = detail::ConstantOpGenericAdaptorBase; +public: + ConstantOpGenericAdaptor(RangeT values, ::mlir::DictionaryAttr attrs = nullptr, const Properties &properties = {}, ::mlir::RegionRange regions = {}) : Base(attrs, properties, regions), odsOperands(values) {} + + ConstantOpGenericAdaptor(RangeT values, ::mlir::DictionaryAttr attrs, ::mlir::OpaqueProperties properties, ::mlir::RegionRange regions = {}) : ConstantOpGenericAdaptor(values, attrs, (properties ? *properties.as() : Properties{}), regions) {} + + template >> + ConstantOpGenericAdaptor(RangeT values, LateInst op) : Base(op), odsOperands(values) {} + + std::pair getODSOperandIndexAndLength(unsigned index) { + return Base::getODSOperandIndexAndLength(index, odsOperands.size()); + } + + RangeT getODSOperands(unsigned index) { + auto valueRange = getODSOperandIndexAndLength(index); + return {std::next(odsOperands.begin(), valueRange.first), + std::next(odsOperands.begin(), valueRange.first + valueRange.second)}; + } + + RangeT getOperands() { + return odsOperands; + } + +private: + RangeT odsOperands; +}; +class ConstantOpAdaptor : public ConstantOpGenericAdaptor<::mlir::ValueRange> { +public: + using ConstantOpGenericAdaptor::ConstantOpGenericAdaptor; + ConstantOpAdaptor(ConstantOp op); + + ::mlir::LogicalResult verify(::mlir::Location loc); +}; +class ConstantOp : public ::mlir::Op::Impl, ::mlir::OpTrait::ZeroSuccessors, ::mlir::OpTrait::ZeroOperands, ::mlir::OpTrait::OpInvariants, ::mlir::BytecodeOpInterface::Trait, ::mlir::ConditionallySpeculatable::Trait, ::mlir::OpTrait::AlwaysSpeculatableImplTrait, ::mlir::MemoryEffectOpInterface::Trait> { +public: + using Op::Op; + using Op::print; + using Adaptor = ConstantOpAdaptor; + template + using GenericAdaptor = ConstantOpGenericAdaptor; + using FoldAdaptor = GenericAdaptor<::llvm::ArrayRef<::mlir::Attribute>>; + using Properties = FoldAdaptor::Properties; + static ::llvm::ArrayRef<::llvm::StringRef> getAttributeNames() { + static ::llvm::StringRef attrNames[] = {::llvm::StringRef("value")}; + return ::llvm::ArrayRef(attrNames); + } + + ::mlir::StringAttr getValueAttrName() { + return getAttributeNameForIndex(0); + } + + static ::mlir::StringAttr getValueAttrName(::mlir::OperationName name) { + return getAttributeNameForIndex(name, 0); + } + + static constexpr ::llvm::StringLiteral getOperationName() { + return ::llvm::StringLiteral("toy.constant"); + } + + std::pair getODSOperandIndexAndLength(unsigned index); + ::mlir::Operation::operand_range getODSOperands(unsigned index); + std::pair getODSResultIndexAndLength(unsigned index); + ::mlir::Operation::result_range getODSResults(unsigned index); + static ::mlir::LogicalResult setPropertiesFromAttr(Properties &prop, ::mlir::Attribute attr, ::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError); + static ::mlir::Attribute getPropertiesAsAttr(::mlir::MLIRContext *ctx, const Properties &prop); + static llvm::hash_code computePropertiesHash(const Properties &prop); + static std::optional getInherentAttr(::mlir::MLIRContext *ctx, const Properties &prop, llvm::StringRef name); + static void setInherentAttr(Properties &prop, llvm::StringRef name, mlir::Attribute value); + static void populateInherentAttrs(::mlir::MLIRContext *ctx, const Properties &prop, ::mlir::NamedAttrList &attrs); + static ::mlir::LogicalResult verifyInherentAttrs(::mlir::OperationName opName, ::mlir::NamedAttrList &attrs, llvm::function_ref<::mlir::InFlightDiagnostic()> emitError); + static ::mlir::LogicalResult readProperties(::mlir::DialectBytecodeReader &reader, ::mlir::OperationState &state); + void writeProperties(::mlir::DialectBytecodeWriter &writer); + ::mlir::DenseElementsAttr getValueAttr(); + ::mlir::DenseElementsAttr getValue(); + void setValueAttr(::mlir::DenseElementsAttr attr); + static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, DenseElementsAttr value); + static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, double value); + static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type resultType0, ::mlir::DenseElementsAttr value); + static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::DenseElementsAttr value); + static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {}); + static ::mlir::ParseResult parse(::mlir::OpAsmParser &parser, ::mlir::OperationState &result); + void print(::mlir::OpAsmPrinter &p); + ::mlir::LogicalResult verifyInvariantsImpl(); + ::mlir::LogicalResult verifyInvariants(); + ::mlir::LogicalResult verify(); + void getEffects(::llvm::SmallVectorImpl<::mlir::SideEffects::EffectInstance<::mlir::MemoryEffects::Effect>> &effects); +private: + ::mlir::StringAttr getAttributeNameForIndex(unsigned index) { + return getAttributeNameForIndex((*this)->getName(), index); + } + + static ::mlir::StringAttr getAttributeNameForIndex(::mlir::OperationName name, unsigned index) { + assert(index < 1 && "invalid attribute index"); + assert(name.getStringRef() == getOperationName() && "invalid operation name"); + assert(name.isRegistered() && "Operation isn't registered, missing a " + "dependent dialect loading?"); + return name.getAttributeNames()[index]; + } + +public: +}; +} // namespace toy +} // namespace mlir +MLIR_DECLARE_EXPLICIT_TYPE_ID(::mlir::toy::ConstantOp) + +namespace mlir { +namespace toy { + +//===----------------------------------------------------------------------===// +// ::mlir::toy::FuncOp declarations +//===----------------------------------------------------------------------===// + +namespace detail { +class FuncOpGenericAdaptorBase { +public: + struct Properties { + using arg_attrsTy = ::mlir::ArrayAttr; + arg_attrsTy arg_attrs; + + auto getArgAttrs() { + auto &propStorage = this->arg_attrs; + return ::llvm::dyn_cast_or_null<::mlir::ArrayAttr>(propStorage); + } + void setArgAttrs(const ::mlir::ArrayAttr &propValue) { + this->arg_attrs = propValue; + } + using function_typeTy = ::mlir::TypeAttr; + function_typeTy function_type; + + auto getFunctionType() { + auto &propStorage = this->function_type; + return ::llvm::cast<::mlir::TypeAttr>(propStorage); + } + void setFunctionType(const ::mlir::TypeAttr &propValue) { + this->function_type = propValue; + } + using res_attrsTy = ::mlir::ArrayAttr; + res_attrsTy res_attrs; + + auto getResAttrs() { + auto &propStorage = this->res_attrs; + return ::llvm::dyn_cast_or_null<::mlir::ArrayAttr>(propStorage); + } + void setResAttrs(const ::mlir::ArrayAttr &propValue) { + this->res_attrs = propValue; + } + using sym_nameTy = ::mlir::StringAttr; + sym_nameTy sym_name; + + auto getSymName() { + auto &propStorage = this->sym_name; + return ::llvm::cast<::mlir::StringAttr>(propStorage); + } + void setSymName(const ::mlir::StringAttr &propValue) { + this->sym_name = propValue; + } + bool operator==(const Properties &rhs) const { + return + rhs.arg_attrs == this->arg_attrs && + rhs.function_type == this->function_type && + rhs.res_attrs == this->res_attrs && + rhs.sym_name == this->sym_name && + true; + } + bool operator!=(const Properties &rhs) const { + return !(*this == rhs); + } + }; +protected: + ::mlir::DictionaryAttr odsAttrs; + ::std::optional<::mlir::OperationName> odsOpName; + Properties properties; + ::mlir::RegionRange odsRegions; +public: + FuncOpGenericAdaptorBase(::mlir::DictionaryAttr attrs = nullptr, const Properties &properties = {}, ::mlir::RegionRange regions = {}); + + FuncOpGenericAdaptorBase(FuncOp op); + + std::pair getODSOperandIndexAndLength(unsigned index, unsigned odsOperandsSize); + const Properties &getProperties() { + return properties; + } + + ::mlir::DictionaryAttr getAttributes(); + ::mlir::StringAttr getSymNameAttr(); + ::llvm::StringRef getSymName(); + ::mlir::TypeAttr getFunctionTypeAttr(); + ::mlir::FunctionType getFunctionType(); + ::mlir::ArrayAttr getArgAttrsAttr(); + ::std::optional< ::mlir::ArrayAttr > getArgAttrs(); + ::mlir::ArrayAttr getResAttrsAttr(); + ::std::optional< ::mlir::ArrayAttr > getResAttrs(); + ::mlir::Region &getBody(); + ::mlir::RegionRange getRegions(); +}; +} // namespace detail +template +class FuncOpGenericAdaptor : public detail::FuncOpGenericAdaptorBase { + using ValueT = ::llvm::detail::ValueOfRange; + using Base = detail::FuncOpGenericAdaptorBase; +public: + FuncOpGenericAdaptor(RangeT values, ::mlir::DictionaryAttr attrs = nullptr, const Properties &properties = {}, ::mlir::RegionRange regions = {}) : Base(attrs, properties, regions), odsOperands(values) {} + + FuncOpGenericAdaptor(RangeT values, ::mlir::DictionaryAttr attrs, ::mlir::OpaqueProperties properties, ::mlir::RegionRange regions = {}) : FuncOpGenericAdaptor(values, attrs, (properties ? *properties.as() : Properties{}), regions) {} + + template >> + FuncOpGenericAdaptor(RangeT values, LateInst op) : Base(op), odsOperands(values) {} + + std::pair getODSOperandIndexAndLength(unsigned index) { + return Base::getODSOperandIndexAndLength(index, odsOperands.size()); + } + + RangeT getODSOperands(unsigned index) { + auto valueRange = getODSOperandIndexAndLength(index); + return {std::next(odsOperands.begin(), valueRange.first), + std::next(odsOperands.begin(), valueRange.first + valueRange.second)}; + } + + RangeT getOperands() { + return odsOperands; + } + +private: + RangeT odsOperands; +}; +class FuncOpAdaptor : public FuncOpGenericAdaptor<::mlir::ValueRange> { +public: + using FuncOpGenericAdaptor::FuncOpGenericAdaptor; + FuncOpAdaptor(FuncOp op); + + ::mlir::LogicalResult verify(::mlir::Location loc); +}; +class FuncOp : public ::mlir::Op { +public: + using Op::Op; + using Op::print; + using Adaptor = FuncOpAdaptor; + template + using GenericAdaptor = FuncOpGenericAdaptor; + using FoldAdaptor = GenericAdaptor<::llvm::ArrayRef<::mlir::Attribute>>; + using Properties = FoldAdaptor::Properties; + static ::llvm::ArrayRef<::llvm::StringRef> getAttributeNames() { + static ::llvm::StringRef attrNames[] = {::llvm::StringRef("arg_attrs"), ::llvm::StringRef("function_type"), ::llvm::StringRef("res_attrs"), ::llvm::StringRef("sym_name")}; + return ::llvm::ArrayRef(attrNames); + } + + ::mlir::StringAttr getArgAttrsAttrName() { + return getAttributeNameForIndex(0); + } + + static ::mlir::StringAttr getArgAttrsAttrName(::mlir::OperationName name) { + return getAttributeNameForIndex(name, 0); + } + + ::mlir::StringAttr getFunctionTypeAttrName() { + return getAttributeNameForIndex(1); + } + + static ::mlir::StringAttr getFunctionTypeAttrName(::mlir::OperationName name) { + return getAttributeNameForIndex(name, 1); + } + + ::mlir::StringAttr getResAttrsAttrName() { + return getAttributeNameForIndex(2); + } + + static ::mlir::StringAttr getResAttrsAttrName(::mlir::OperationName name) { + return getAttributeNameForIndex(name, 2); + } + + ::mlir::StringAttr getSymNameAttrName() { + return getAttributeNameForIndex(3); + } + + static ::mlir::StringAttr getSymNameAttrName(::mlir::OperationName name) { + return getAttributeNameForIndex(name, 3); + } + + static constexpr ::llvm::StringLiteral getOperationName() { + return ::llvm::StringLiteral("toy.func"); + } + + std::pair getODSOperandIndexAndLength(unsigned index); + ::mlir::Operation::operand_range getODSOperands(unsigned index); + std::pair getODSResultIndexAndLength(unsigned index); + ::mlir::Operation::result_range getODSResults(unsigned index); + ::mlir::Region &getBody(); + static ::mlir::LogicalResult setPropertiesFromAttr(Properties &prop, ::mlir::Attribute attr, ::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError); + static ::mlir::Attribute getPropertiesAsAttr(::mlir::MLIRContext *ctx, const Properties &prop); + static llvm::hash_code computePropertiesHash(const Properties &prop); + static std::optional getInherentAttr(::mlir::MLIRContext *ctx, const Properties &prop, llvm::StringRef name); + static void setInherentAttr(Properties &prop, llvm::StringRef name, mlir::Attribute value); + static void populateInherentAttrs(::mlir::MLIRContext *ctx, const Properties &prop, ::mlir::NamedAttrList &attrs); + static ::mlir::LogicalResult verifyInherentAttrs(::mlir::OperationName opName, ::mlir::NamedAttrList &attrs, llvm::function_ref<::mlir::InFlightDiagnostic()> emitError); + static ::mlir::LogicalResult readProperties(::mlir::DialectBytecodeReader &reader, ::mlir::OperationState &state); + void writeProperties(::mlir::DialectBytecodeWriter &writer); + ::mlir::StringAttr getSymNameAttr(); + ::llvm::StringRef getSymName(); + ::mlir::TypeAttr getFunctionTypeAttr(); + ::mlir::FunctionType getFunctionType(); + ::mlir::ArrayAttr getArgAttrsAttr(); + ::std::optional< ::mlir::ArrayAttr > getArgAttrs(); + ::mlir::ArrayAttr getResAttrsAttr(); + ::std::optional< ::mlir::ArrayAttr > getResAttrs(); + void setSymNameAttr(::mlir::StringAttr attr); + void setSymName(::llvm::StringRef attrValue); + void setFunctionTypeAttr(::mlir::TypeAttr attr); + void setFunctionType(::mlir::FunctionType attrValue); + void setArgAttrsAttr(::mlir::ArrayAttr attr); + void setResAttrsAttr(::mlir::ArrayAttr attr); + ::mlir::Attribute removeArgAttrsAttr(); + ::mlir::Attribute removeResAttrsAttr(); + static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, StringRef name, FunctionType type, ArrayRef attrs = {}); + static ::mlir::ParseResult parse(::mlir::OpAsmParser &parser, ::mlir::OperationState &result); + void print(::mlir::OpAsmPrinter &p); + ::mlir::LogicalResult verifyInvariantsImpl(); + ::mlir::LogicalResult verifyInvariants(); +private: + ::mlir::StringAttr getAttributeNameForIndex(unsigned index) { + return getAttributeNameForIndex((*this)->getName(), index); + } + + static ::mlir::StringAttr getAttributeNameForIndex(::mlir::OperationName name, unsigned index) { + assert(index < 4 && "invalid attribute index"); + assert(name.getStringRef() == getOperationName() && "invalid operation name"); + assert(name.isRegistered() && "Operation isn't registered, missing a " + "dependent dialect loading?"); + return name.getAttributeNames()[index]; + } + +public: + //===------------------------------------------------------------------===// + // FunctionOpInterface Methods + //===------------------------------------------------------------------===// + + /// Returns the argument types of this function. + ArrayRef getArgumentTypes() { return getFunctionType().getInputs(); } + + /// Returns the result types of this function. + ArrayRef getResultTypes() { return getFunctionType().getResults(); } + + /// Returns the region on the function operation that is callable. + Region *getCallableRegion() { return &getBody(); } +}; +} // namespace toy +} // namespace mlir +MLIR_DECLARE_EXPLICIT_TYPE_ID(::mlir::toy::FuncOp) + +namespace mlir { +namespace toy { + +//===----------------------------------------------------------------------===// +// ::mlir::toy::GenericCallOp declarations +//===----------------------------------------------------------------------===// + +namespace detail { +class GenericCallOpGenericAdaptorBase { +public: + struct Properties { + using calleeTy = ::mlir::FlatSymbolRefAttr; + calleeTy callee; + + auto getCallee() { + auto &propStorage = this->callee; + return ::llvm::cast<::mlir::FlatSymbolRefAttr>(propStorage); + } + void setCallee(const ::mlir::FlatSymbolRefAttr &propValue) { + this->callee = propValue; + } + bool operator==(const Properties &rhs) const { + return + rhs.callee == this->callee && + true; + } + bool operator!=(const Properties &rhs) const { + return !(*this == rhs); + } + }; +protected: + ::mlir::DictionaryAttr odsAttrs; + ::std::optional<::mlir::OperationName> odsOpName; + Properties properties; + ::mlir::RegionRange odsRegions; +public: + GenericCallOpGenericAdaptorBase(::mlir::DictionaryAttr attrs = nullptr, const Properties &properties = {}, ::mlir::RegionRange regions = {}); + + GenericCallOpGenericAdaptorBase(GenericCallOp op); + + std::pair getODSOperandIndexAndLength(unsigned index, unsigned odsOperandsSize); + const Properties &getProperties() { + return properties; + } + + ::mlir::DictionaryAttr getAttributes(); + ::mlir::FlatSymbolRefAttr getCalleeAttr(); + ::llvm::StringRef getCallee(); +}; +} // namespace detail +template +class GenericCallOpGenericAdaptor : public detail::GenericCallOpGenericAdaptorBase { + using ValueT = ::llvm::detail::ValueOfRange; + using Base = detail::GenericCallOpGenericAdaptorBase; +public: + GenericCallOpGenericAdaptor(RangeT values, ::mlir::DictionaryAttr attrs = nullptr, const Properties &properties = {}, ::mlir::RegionRange regions = {}) : Base(attrs, properties, regions), odsOperands(values) {} + + GenericCallOpGenericAdaptor(RangeT values, ::mlir::DictionaryAttr attrs, ::mlir::OpaqueProperties properties, ::mlir::RegionRange regions = {}) : GenericCallOpGenericAdaptor(values, attrs, (properties ? *properties.as() : Properties{}), regions) {} + + template >> + GenericCallOpGenericAdaptor(RangeT values, LateInst op) : Base(op), odsOperands(values) {} + + std::pair getODSOperandIndexAndLength(unsigned index) { + return Base::getODSOperandIndexAndLength(index, odsOperands.size()); + } + + RangeT getODSOperands(unsigned index) { + auto valueRange = getODSOperandIndexAndLength(index); + return {std::next(odsOperands.begin(), valueRange.first), + std::next(odsOperands.begin(), valueRange.first + valueRange.second)}; + } + + RangeT getInputs() { + return getODSOperands(0); + } + + RangeT getOperands() { + return odsOperands; + } + +private: + RangeT odsOperands; +}; +class GenericCallOpAdaptor : public GenericCallOpGenericAdaptor<::mlir::ValueRange> { +public: + using GenericCallOpGenericAdaptor::GenericCallOpGenericAdaptor; + GenericCallOpAdaptor(GenericCallOp op); + + ::mlir::LogicalResult verify(::mlir::Location loc); +}; +class GenericCallOp : public ::mlir::Op::Impl, ::mlir::OpTrait::ZeroSuccessors, ::mlir::OpTrait::VariadicOperands, ::mlir::OpTrait::OpInvariants, ::mlir::BytecodeOpInterface::Trait, ::mlir::CallOpInterface::Trait> { +public: + using Op::Op; + using Op::print; + using Adaptor = GenericCallOpAdaptor; + template + using GenericAdaptor = GenericCallOpGenericAdaptor; + using FoldAdaptor = GenericAdaptor<::llvm::ArrayRef<::mlir::Attribute>>; + using Properties = FoldAdaptor::Properties; + static ::llvm::ArrayRef<::llvm::StringRef> getAttributeNames() { + static ::llvm::StringRef attrNames[] = {::llvm::StringRef("callee")}; + return ::llvm::ArrayRef(attrNames); + } + + ::mlir::StringAttr getCalleeAttrName() { + return getAttributeNameForIndex(0); + } + + static ::mlir::StringAttr getCalleeAttrName(::mlir::OperationName name) { + return getAttributeNameForIndex(name, 0); + } + + static constexpr ::llvm::StringLiteral getOperationName() { + return ::llvm::StringLiteral("toy.generic_call"); + } + + std::pair getODSOperandIndexAndLength(unsigned index); + ::mlir::Operation::operand_range getODSOperands(unsigned index); + ::mlir::Operation::operand_range getInputs(); + ::mlir::MutableOperandRange getInputsMutable(); + std::pair getODSResultIndexAndLength(unsigned index); + ::mlir::Operation::result_range getODSResults(unsigned index); + static ::mlir::LogicalResult setPropertiesFromAttr(Properties &prop, ::mlir::Attribute attr, ::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError); + static ::mlir::Attribute getPropertiesAsAttr(::mlir::MLIRContext *ctx, const Properties &prop); + static llvm::hash_code computePropertiesHash(const Properties &prop); + static std::optional getInherentAttr(::mlir::MLIRContext *ctx, const Properties &prop, llvm::StringRef name); + static void setInherentAttr(Properties &prop, llvm::StringRef name, mlir::Attribute value); + static void populateInherentAttrs(::mlir::MLIRContext *ctx, const Properties &prop, ::mlir::NamedAttrList &attrs); + static ::mlir::LogicalResult verifyInherentAttrs(::mlir::OperationName opName, ::mlir::NamedAttrList &attrs, llvm::function_ref<::mlir::InFlightDiagnostic()> emitError); + static ::mlir::LogicalResult readProperties(::mlir::DialectBytecodeReader &reader, ::mlir::OperationState &state); + void writeProperties(::mlir::DialectBytecodeWriter &writer); + ::mlir::FlatSymbolRefAttr getCalleeAttr(); + ::llvm::StringRef getCallee(); + void setCalleeAttr(::mlir::FlatSymbolRefAttr attr); + void setCallee(::llvm::StringRef attrValue); + static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, StringRef callee, ArrayRef arguments); + static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type resultType0, ::mlir::FlatSymbolRefAttr callee, ::mlir::ValueRange inputs); + static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::FlatSymbolRefAttr callee, ::mlir::ValueRange inputs); + static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type resultType0, ::llvm::StringRef callee, ::mlir::ValueRange inputs); + static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::llvm::StringRef callee, ::mlir::ValueRange inputs); + static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {}); + ::mlir::LogicalResult verifyInvariantsImpl(); + ::mlir::LogicalResult verifyInvariants(); + ::mlir::CallInterfaceCallable getCallableForCallee(); + void setCalleeFromCallable(::mlir::CallInterfaceCallable callee); + ::mlir::Operation::operand_range getArgOperands(); + ::mlir::MutableOperandRange getArgOperandsMutable(); + static ::mlir::ParseResult parse(::mlir::OpAsmParser &parser, ::mlir::OperationState &result); + void print(::mlir::OpAsmPrinter &_odsPrinter); +private: + ::mlir::StringAttr getAttributeNameForIndex(unsigned index) { + return getAttributeNameForIndex((*this)->getName(), index); + } + + static ::mlir::StringAttr getAttributeNameForIndex(::mlir::OperationName name, unsigned index) { + assert(index < 1 && "invalid attribute index"); + assert(name.getStringRef() == getOperationName() && "invalid operation name"); + assert(name.isRegistered() && "Operation isn't registered, missing a " + "dependent dialect loading?"); + return name.getAttributeNames()[index]; + } + +public: +}; +} // namespace toy +} // namespace mlir +MLIR_DECLARE_EXPLICIT_TYPE_ID(::mlir::toy::GenericCallOp) + +namespace mlir { +namespace toy { + +//===----------------------------------------------------------------------===// +// ::mlir::toy::MulOp declarations +//===----------------------------------------------------------------------===// + +namespace detail { +class MulOpGenericAdaptorBase { +public: +protected: + ::mlir::DictionaryAttr odsAttrs; + ::std::optional<::mlir::OperationName> odsOpName; + ::mlir::RegionRange odsRegions; +public: + MulOpGenericAdaptorBase(::mlir::DictionaryAttr attrs = nullptr, const ::mlir::EmptyProperties &properties = {}, ::mlir::RegionRange regions = {}); + + MulOpGenericAdaptorBase(MulOp op); + + std::pair getODSOperandIndexAndLength(unsigned index, unsigned odsOperandsSize); + ::mlir::DictionaryAttr getAttributes(); +}; +} // namespace detail +template +class MulOpGenericAdaptor : public detail::MulOpGenericAdaptorBase { + using ValueT = ::llvm::detail::ValueOfRange; + using Base = detail::MulOpGenericAdaptorBase; +public: + MulOpGenericAdaptor(RangeT values, ::mlir::DictionaryAttr attrs = nullptr, const ::mlir::EmptyProperties &properties = {}, ::mlir::RegionRange regions = {}) : Base(attrs, properties, regions), odsOperands(values) {} + + MulOpGenericAdaptor(RangeT values, ::mlir::DictionaryAttr attrs, ::mlir::OpaqueProperties properties, ::mlir::RegionRange regions = {}) : MulOpGenericAdaptor(values, attrs, (properties ? *properties.as<::mlir::EmptyProperties *>() : ::mlir::EmptyProperties{}), regions) {} + + template >> + MulOpGenericAdaptor(RangeT values, LateInst op) : Base(op), odsOperands(values) {} + + std::pair getODSOperandIndexAndLength(unsigned index) { + return Base::getODSOperandIndexAndLength(index, odsOperands.size()); + } + + RangeT getODSOperands(unsigned index) { + auto valueRange = getODSOperandIndexAndLength(index); + return {std::next(odsOperands.begin(), valueRange.first), + std::next(odsOperands.begin(), valueRange.first + valueRange.second)}; + } + + ValueT getLhs() { + return (*getODSOperands(0).begin()); + } + + ValueT getRhs() { + return (*getODSOperands(1).begin()); + } + + RangeT getOperands() { + return odsOperands; + } + +private: + RangeT odsOperands; +}; +class MulOpAdaptor : public MulOpGenericAdaptor<::mlir::ValueRange> { +public: + using MulOpGenericAdaptor::MulOpGenericAdaptor; + MulOpAdaptor(MulOp op); + + ::mlir::LogicalResult verify(::mlir::Location loc); +}; +class MulOp : public ::mlir::Op::Impl, ::mlir::OpTrait::ZeroSuccessors, ::mlir::OpTrait::NOperands<2>::Impl, ::mlir::OpTrait::OpInvariants, ::mlir::ConditionallySpeculatable::Trait, ::mlir::OpTrait::AlwaysSpeculatableImplTrait, ::mlir::MemoryEffectOpInterface::Trait, ShapeInference::Trait> { +public: + using Op::Op; + using Op::print; + using Adaptor = MulOpAdaptor; + template + using GenericAdaptor = MulOpGenericAdaptor; + using FoldAdaptor = GenericAdaptor<::llvm::ArrayRef<::mlir::Attribute>>; + static ::llvm::ArrayRef<::llvm::StringRef> getAttributeNames() { + return {}; + } + + static constexpr ::llvm::StringLiteral getOperationName() { + return ::llvm::StringLiteral("toy.mul"); + } + + std::pair getODSOperandIndexAndLength(unsigned index); + ::mlir::Operation::operand_range getODSOperands(unsigned index); + ::mlir::TypedValue<::mlir::TensorType> getLhs(); + ::mlir::TypedValue<::mlir::TensorType> getRhs(); + ::mlir::OpOperand &getLhsMutable(); + ::mlir::OpOperand &getRhsMutable(); + std::pair getODSResultIndexAndLength(unsigned index); + ::mlir::Operation::result_range getODSResults(unsigned index); + static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, Value lhs, Value rhs); + static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type resultType0, ::mlir::Value lhs, ::mlir::Value rhs); + static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::Value lhs, ::mlir::Value rhs); + static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {}); + static ::mlir::ParseResult parse(::mlir::OpAsmParser &parser, ::mlir::OperationState &result); + void print(::mlir::OpAsmPrinter &p); + ::mlir::LogicalResult verifyInvariantsImpl(); + ::mlir::LogicalResult verifyInvariants(); + void inferShapes(); + void getEffects(::llvm::SmallVectorImpl<::mlir::SideEffects::EffectInstance<::mlir::MemoryEffects::Effect>> &effects); +public: +}; +} // namespace toy +} // namespace mlir +MLIR_DECLARE_EXPLICIT_TYPE_ID(::mlir::toy::MulOp) + +namespace mlir { +namespace toy { + +//===----------------------------------------------------------------------===// +// ::mlir::toy::PrintOp declarations +//===----------------------------------------------------------------------===// + +namespace detail { +class PrintOpGenericAdaptorBase { +public: +protected: + ::mlir::DictionaryAttr odsAttrs; + ::std::optional<::mlir::OperationName> odsOpName; + ::mlir::RegionRange odsRegions; +public: + PrintOpGenericAdaptorBase(::mlir::DictionaryAttr attrs = nullptr, const ::mlir::EmptyProperties &properties = {}, ::mlir::RegionRange regions = {}); + + PrintOpGenericAdaptorBase(PrintOp op); + + std::pair getODSOperandIndexAndLength(unsigned index, unsigned odsOperandsSize); + ::mlir::DictionaryAttr getAttributes(); +}; +} // namespace detail +template +class PrintOpGenericAdaptor : public detail::PrintOpGenericAdaptorBase { + using ValueT = ::llvm::detail::ValueOfRange; + using Base = detail::PrintOpGenericAdaptorBase; +public: + PrintOpGenericAdaptor(RangeT values, ::mlir::DictionaryAttr attrs = nullptr, const ::mlir::EmptyProperties &properties = {}, ::mlir::RegionRange regions = {}) : Base(attrs, properties, regions), odsOperands(values) {} + + PrintOpGenericAdaptor(RangeT values, ::mlir::DictionaryAttr attrs, ::mlir::OpaqueProperties properties, ::mlir::RegionRange regions = {}) : PrintOpGenericAdaptor(values, attrs, (properties ? *properties.as<::mlir::EmptyProperties *>() : ::mlir::EmptyProperties{}), regions) {} + + template >> + PrintOpGenericAdaptor(RangeT values, LateInst op) : Base(op), odsOperands(values) {} + + std::pair getODSOperandIndexAndLength(unsigned index) { + return Base::getODSOperandIndexAndLength(index, odsOperands.size()); + } + + RangeT getODSOperands(unsigned index) { + auto valueRange = getODSOperandIndexAndLength(index); + return {std::next(odsOperands.begin(), valueRange.first), + std::next(odsOperands.begin(), valueRange.first + valueRange.second)}; + } + + ValueT getInput() { + return (*getODSOperands(0).begin()); + } + + RangeT getOperands() { + return odsOperands; + } + +private: + RangeT odsOperands; +}; +class PrintOpAdaptor : public PrintOpGenericAdaptor<::mlir::ValueRange> { +public: + using PrintOpGenericAdaptor::PrintOpGenericAdaptor; + PrintOpAdaptor(PrintOp op); + + ::mlir::LogicalResult verify(::mlir::Location loc); +}; +class PrintOp : public ::mlir::Op { +public: + using Op::Op; + using Op::print; + using Adaptor = PrintOpAdaptor; + template + using GenericAdaptor = PrintOpGenericAdaptor; + using FoldAdaptor = GenericAdaptor<::llvm::ArrayRef<::mlir::Attribute>>; + static ::llvm::ArrayRef<::llvm::StringRef> getAttributeNames() { + return {}; + } + + static constexpr ::llvm::StringLiteral getOperationName() { + return ::llvm::StringLiteral("toy.print"); + } + + std::pair getODSOperandIndexAndLength(unsigned index); + ::mlir::Operation::operand_range getODSOperands(unsigned index); + ::mlir::Value getInput(); + ::mlir::OpOperand &getInputMutable(); + std::pair getODSResultIndexAndLength(unsigned index); + ::mlir::Operation::result_range getODSResults(unsigned index); + static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Value input); + static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::Value input); + static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {}); + ::mlir::LogicalResult verifyInvariantsImpl(); + ::mlir::LogicalResult verifyInvariants(); + static ::mlir::ParseResult parse(::mlir::OpAsmParser &parser, ::mlir::OperationState &result); + void print(::mlir::OpAsmPrinter &_odsPrinter); +public: +}; +} // namespace toy +} // namespace mlir +MLIR_DECLARE_EXPLICIT_TYPE_ID(::mlir::toy::PrintOp) + +namespace mlir { +namespace toy { + +//===----------------------------------------------------------------------===// +// ::mlir::toy::ReshapeOp declarations +//===----------------------------------------------------------------------===// + +namespace detail { +class ReshapeOpGenericAdaptorBase { +public: +protected: + ::mlir::DictionaryAttr odsAttrs; + ::std::optional<::mlir::OperationName> odsOpName; + ::mlir::RegionRange odsRegions; +public: + ReshapeOpGenericAdaptorBase(::mlir::DictionaryAttr attrs = nullptr, const ::mlir::EmptyProperties &properties = {}, ::mlir::RegionRange regions = {}); + + ReshapeOpGenericAdaptorBase(ReshapeOp op); + + std::pair getODSOperandIndexAndLength(unsigned index, unsigned odsOperandsSize); + ::mlir::DictionaryAttr getAttributes(); +}; +} // namespace detail +template +class ReshapeOpGenericAdaptor : public detail::ReshapeOpGenericAdaptorBase { + using ValueT = ::llvm::detail::ValueOfRange; + using Base = detail::ReshapeOpGenericAdaptorBase; +public: + ReshapeOpGenericAdaptor(RangeT values, ::mlir::DictionaryAttr attrs = nullptr, const ::mlir::EmptyProperties &properties = {}, ::mlir::RegionRange regions = {}) : Base(attrs, properties, regions), odsOperands(values) {} + + ReshapeOpGenericAdaptor(RangeT values, ::mlir::DictionaryAttr attrs, ::mlir::OpaqueProperties properties, ::mlir::RegionRange regions = {}) : ReshapeOpGenericAdaptor(values, attrs, (properties ? *properties.as<::mlir::EmptyProperties *>() : ::mlir::EmptyProperties{}), regions) {} + + template >> + ReshapeOpGenericAdaptor(RangeT values, LateInst op) : Base(op), odsOperands(values) {} + + std::pair getODSOperandIndexAndLength(unsigned index) { + return Base::getODSOperandIndexAndLength(index, odsOperands.size()); + } + + RangeT getODSOperands(unsigned index) { + auto valueRange = getODSOperandIndexAndLength(index); + return {std::next(odsOperands.begin(), valueRange.first), + std::next(odsOperands.begin(), valueRange.first + valueRange.second)}; + } + + ValueT getInput() { + return (*getODSOperands(0).begin()); + } + + RangeT getOperands() { + return odsOperands; + } + +private: + RangeT odsOperands; +}; +class ReshapeOpAdaptor : public ReshapeOpGenericAdaptor<::mlir::ValueRange> { +public: + using ReshapeOpGenericAdaptor::ReshapeOpGenericAdaptor; + ReshapeOpAdaptor(ReshapeOp op); + + ::mlir::LogicalResult verify(::mlir::Location loc); +}; +class ReshapeOp : public ::mlir::Op::Impl, ::mlir::OpTrait::ZeroSuccessors, ::mlir::OpTrait::OneOperand, ::mlir::OpTrait::OpInvariants, ::mlir::ConditionallySpeculatable::Trait, ::mlir::OpTrait::AlwaysSpeculatableImplTrait, ::mlir::MemoryEffectOpInterface::Trait> { +public: + using Op::Op; + using Op::print; + using Adaptor = ReshapeOpAdaptor; + template + using GenericAdaptor = ReshapeOpGenericAdaptor; + using FoldAdaptor = GenericAdaptor<::llvm::ArrayRef<::mlir::Attribute>>; + static ::llvm::ArrayRef<::llvm::StringRef> getAttributeNames() { + return {}; + } + + static constexpr ::llvm::StringLiteral getOperationName() { + return ::llvm::StringLiteral("toy.reshape"); + } + + std::pair getODSOperandIndexAndLength(unsigned index); + ::mlir::Operation::operand_range getODSOperands(unsigned index); + ::mlir::TypedValue<::mlir::TensorType> getInput(); + ::mlir::OpOperand &getInputMutable(); + std::pair getODSResultIndexAndLength(unsigned index); + ::mlir::Operation::result_range getODSResults(unsigned index); + static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type resultType0, ::mlir::Value input); + static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::Value input); + static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {}); + ::mlir::LogicalResult verifyInvariantsImpl(); + ::mlir::LogicalResult verifyInvariants(); + static void getCanonicalizationPatterns(::mlir::RewritePatternSet &results, ::mlir::MLIRContext *context); + static ::mlir::ParseResult parse(::mlir::OpAsmParser &parser, ::mlir::OperationState &result); + void print(::mlir::OpAsmPrinter &_odsPrinter); + void getEffects(::llvm::SmallVectorImpl<::mlir::SideEffects::EffectInstance<::mlir::MemoryEffects::Effect>> &effects); +public: +}; +} // namespace toy +} // namespace mlir +MLIR_DECLARE_EXPLICIT_TYPE_ID(::mlir::toy::ReshapeOp) + +namespace mlir { +namespace toy { + +//===----------------------------------------------------------------------===// +// ::mlir::toy::ReturnOp declarations +//===----------------------------------------------------------------------===// + +namespace detail { +class ReturnOpGenericAdaptorBase { +public: +protected: + ::mlir::DictionaryAttr odsAttrs; + ::std::optional<::mlir::OperationName> odsOpName; + ::mlir::RegionRange odsRegions; +public: + ReturnOpGenericAdaptorBase(::mlir::DictionaryAttr attrs = nullptr, const ::mlir::EmptyProperties &properties = {}, ::mlir::RegionRange regions = {}); + + ReturnOpGenericAdaptorBase(ReturnOp op); + + std::pair getODSOperandIndexAndLength(unsigned index, unsigned odsOperandsSize); + ::mlir::DictionaryAttr getAttributes(); +}; +} // namespace detail +template +class ReturnOpGenericAdaptor : public detail::ReturnOpGenericAdaptorBase { + using ValueT = ::llvm::detail::ValueOfRange; + using Base = detail::ReturnOpGenericAdaptorBase; +public: + ReturnOpGenericAdaptor(RangeT values, ::mlir::DictionaryAttr attrs = nullptr, const ::mlir::EmptyProperties &properties = {}, ::mlir::RegionRange regions = {}) : Base(attrs, properties, regions), odsOperands(values) {} + + ReturnOpGenericAdaptor(RangeT values, ::mlir::DictionaryAttr attrs, ::mlir::OpaqueProperties properties, ::mlir::RegionRange regions = {}) : ReturnOpGenericAdaptor(values, attrs, (properties ? *properties.as<::mlir::EmptyProperties *>() : ::mlir::EmptyProperties{}), regions) {} + + template >> + ReturnOpGenericAdaptor(RangeT values, LateInst op) : Base(op), odsOperands(values) {} + + std::pair getODSOperandIndexAndLength(unsigned index) { + return Base::getODSOperandIndexAndLength(index, odsOperands.size()); + } + + RangeT getODSOperands(unsigned index) { + auto valueRange = getODSOperandIndexAndLength(index); + return {std::next(odsOperands.begin(), valueRange.first), + std::next(odsOperands.begin(), valueRange.first + valueRange.second)}; + } + + RangeT getInput() { + return getODSOperands(0); + } + + RangeT getOperands() { + return odsOperands; + } + +private: + RangeT odsOperands; +}; +class ReturnOpAdaptor : public ReturnOpGenericAdaptor<::mlir::ValueRange> { +public: + using ReturnOpGenericAdaptor::ReturnOpGenericAdaptor; + ReturnOpAdaptor(ReturnOp op); + + ::mlir::LogicalResult verify(::mlir::Location loc); +}; +class ReturnOp : public ::mlir::Op::Impl, ::mlir::OpTrait::OpInvariants, ::mlir::ConditionallySpeculatable::Trait, ::mlir::OpTrait::AlwaysSpeculatableImplTrait, ::mlir::MemoryEffectOpInterface::Trait, ::mlir::OpTrait::IsTerminator> { +public: + using Op::Op; + using Op::print; + using Adaptor = ReturnOpAdaptor; + template + using GenericAdaptor = ReturnOpGenericAdaptor; + using FoldAdaptor = GenericAdaptor<::llvm::ArrayRef<::mlir::Attribute>>; + static ::llvm::ArrayRef<::llvm::StringRef> getAttributeNames() { + return {}; + } + + static constexpr ::llvm::StringLiteral getOperationName() { + return ::llvm::StringLiteral("toy.return"); + } + + std::pair getODSOperandIndexAndLength(unsigned index); + ::mlir::Operation::operand_range getODSOperands(unsigned index); + ::mlir::Operation::operand_range getInput(); + ::mlir::MutableOperandRange getInputMutable(); + std::pair getODSResultIndexAndLength(unsigned index); + ::mlir::Operation::result_range getODSResults(unsigned index); + static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState); + static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::ValueRange input); + static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {}); + ::mlir::LogicalResult verifyInvariantsImpl(); + ::mlir::LogicalResult verifyInvariants(); + ::mlir::LogicalResult verify(); + static ::mlir::ParseResult parse(::mlir::OpAsmParser &parser, ::mlir::OperationState &result); + void print(::mlir::OpAsmPrinter &_odsPrinter); + void getEffects(::llvm::SmallVectorImpl<::mlir::SideEffects::EffectInstance<::mlir::MemoryEffects::Effect>> &effects); +public: + bool hasOperand() { return getNumOperands() != 0; } +}; +} // namespace toy +} // namespace mlir +MLIR_DECLARE_EXPLICIT_TYPE_ID(::mlir::toy::ReturnOp) + +namespace mlir { +namespace toy { + +//===----------------------------------------------------------------------===// +// ::mlir::toy::TransposeOp declarations +//===----------------------------------------------------------------------===// + +namespace detail { +class TransposeOpGenericAdaptorBase { +public: +protected: + ::mlir::DictionaryAttr odsAttrs; + ::std::optional<::mlir::OperationName> odsOpName; + ::mlir::RegionRange odsRegions; +public: + TransposeOpGenericAdaptorBase(::mlir::DictionaryAttr attrs = nullptr, const ::mlir::EmptyProperties &properties = {}, ::mlir::RegionRange regions = {}); + + TransposeOpGenericAdaptorBase(TransposeOp op); + + std::pair getODSOperandIndexAndLength(unsigned index, unsigned odsOperandsSize); + ::mlir::DictionaryAttr getAttributes(); +}; +} // namespace detail +template +class TransposeOpGenericAdaptor : public detail::TransposeOpGenericAdaptorBase { + using ValueT = ::llvm::detail::ValueOfRange; + using Base = detail::TransposeOpGenericAdaptorBase; +public: + TransposeOpGenericAdaptor(RangeT values, ::mlir::DictionaryAttr attrs = nullptr, const ::mlir::EmptyProperties &properties = {}, ::mlir::RegionRange regions = {}) : Base(attrs, properties, regions), odsOperands(values) {} + + TransposeOpGenericAdaptor(RangeT values, ::mlir::DictionaryAttr attrs, ::mlir::OpaqueProperties properties, ::mlir::RegionRange regions = {}) : TransposeOpGenericAdaptor(values, attrs, (properties ? *properties.as<::mlir::EmptyProperties *>() : ::mlir::EmptyProperties{}), regions) {} + + template >> + TransposeOpGenericAdaptor(RangeT values, LateInst op) : Base(op), odsOperands(values) {} + + std::pair getODSOperandIndexAndLength(unsigned index) { + return Base::getODSOperandIndexAndLength(index, odsOperands.size()); + } + + RangeT getODSOperands(unsigned index) { + auto valueRange = getODSOperandIndexAndLength(index); + return {std::next(odsOperands.begin(), valueRange.first), + std::next(odsOperands.begin(), valueRange.first + valueRange.second)}; + } + + ValueT getInput() { + return (*getODSOperands(0).begin()); + } + + RangeT getOperands() { + return odsOperands; + } + +private: + RangeT odsOperands; +}; +class TransposeOpAdaptor : public TransposeOpGenericAdaptor<::mlir::ValueRange> { +public: + using TransposeOpGenericAdaptor::TransposeOpGenericAdaptor; + TransposeOpAdaptor(TransposeOp op); + + ::mlir::LogicalResult verify(::mlir::Location loc); +}; +class TransposeOp : public ::mlir::Op::Impl, ::mlir::OpTrait::ZeroSuccessors, ::mlir::OpTrait::OneOperand, ::mlir::OpTrait::OpInvariants, ::mlir::ConditionallySpeculatable::Trait, ::mlir::OpTrait::AlwaysSpeculatableImplTrait, ::mlir::MemoryEffectOpInterface::Trait, ShapeInference::Trait> { +public: + using Op::Op; + using Op::print; + using Adaptor = TransposeOpAdaptor; + template + using GenericAdaptor = TransposeOpGenericAdaptor; + using FoldAdaptor = GenericAdaptor<::llvm::ArrayRef<::mlir::Attribute>>; + static ::llvm::ArrayRef<::llvm::StringRef> getAttributeNames() { + return {}; + } + + static constexpr ::llvm::StringLiteral getOperationName() { + return ::llvm::StringLiteral("toy.transpose"); + } + + std::pair getODSOperandIndexAndLength(unsigned index); + ::mlir::Operation::operand_range getODSOperands(unsigned index); + ::mlir::TypedValue<::mlir::TensorType> getInput(); + ::mlir::OpOperand &getInputMutable(); + std::pair getODSResultIndexAndLength(unsigned index); + ::mlir::Operation::result_range getODSResults(unsigned index); + static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, Value input); + static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type resultType0, ::mlir::Value input); + static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::Value input); + static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {}); + ::mlir::LogicalResult verifyInvariantsImpl(); + ::mlir::LogicalResult verifyInvariants(); + ::mlir::LogicalResult verify(); + static void getCanonicalizationPatterns(::mlir::RewritePatternSet &results, ::mlir::MLIRContext *context); + void inferShapes(); + static ::mlir::ParseResult parse(::mlir::OpAsmParser &parser, ::mlir::OperationState &result); + void print(::mlir::OpAsmPrinter &_odsPrinter); + void getEffects(::llvm::SmallVectorImpl<::mlir::SideEffects::EffectInstance<::mlir::MemoryEffects::Effect>> &effects); +public: +}; +} // namespace toy +} // namespace mlir +MLIR_DECLARE_EXPLICIT_TYPE_ID(::mlir::toy::TransposeOp) + + +#endif // GET_OP_CLASSES + diff --git a/Ch5/include/toy/Ops.td b/Ch5/include/toy/Ops.td new file mode 100644 index 0000000..ec6762f --- /dev/null +++ b/Ch5/include/toy/Ops.td @@ -0,0 +1,372 @@ +//===- Ops.td - Toy dialect operation definitions ----------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Defines the operations of the Toy dialect. +// +//===----------------------------------------------------------------------===// + +#ifndef TOY_OPS +#define TOY_OPS + +include "mlir/Interfaces/FunctionInterfaces.td" +include "mlir/IR/SymbolInterfaces.td" +include "mlir/Interfaces/CallInterfaces.td" +include "mlir/Interfaces/CastInterfaces.td" +include "mlir/Interfaces/SideEffectInterfaces.td" +include "toy/ShapeInferenceInterface.td" + +// Provide a definition of the 'toy' dialect in the ODS framework so that we +// can define our operations. +def Toy_Dialect : Dialect { + let name = "toy"; + let cppNamespace = "::mlir::toy"; +} + +// Base class for toy dialect operations. This operation inherits from the base +// `Op` class in OpBase.td, and provides: +// * The parent dialect of the operation. +// * The mnemonic for the operation, or the name without the dialect prefix. +// * A list of traits for the operation. +class Toy_Op traits = []> : + Op; + +//===----------------------------------------------------------------------===// +// Toy Operations +//===----------------------------------------------------------------------===// + +//===----------------------------------------------------------------------===// +// ConstantOp +//===----------------------------------------------------------------------===// + +// We define a toy operation by inheriting from our base 'Toy_Op' class above. +// Here we provide the mnemonic and a list of traits for the operation. The +// constant operation is marked as 'Pure' as it is a pure operation +// and may be removed if dead. +def ConstantOp : Toy_Op<"constant", [Pure]> { + // Provide a summary and description for this operation. This can be used to + // auto-generate documentation of the operations within our dialect. + let summary = "constant"; + let description = [{ + Constant operation turns a literal into an SSA value. The data is attached + to the operation as an attribute. For example: + + ```mlir + %0 = toy.constant dense<[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]> + : tensor<2x3xf64> + ``` + }]; + + // The constant operation takes an attribute as the only input. + let arguments = (ins F64ElementsAttr:$value); + + // The constant operation returns a single value of TensorType. + let results = (outs F64Tensor); + + // Indicate that the operation has a custom parser and printer method. + let hasCustomAssemblyFormat = 1; + + // Add custom build methods for the constant operation. These method populates + // the `state` that MLIR uses to create operations, i.e. these are used when + // using `builder.create(...)`. + let builders = [ + // Build a constant with a given constant tensor value. + OpBuilder<(ins "DenseElementsAttr":$value), [{ + build($_builder, $_state, value.getType(), value); + }]>, + + // Build a constant with a given constant floating-point value. + OpBuilder<(ins "double":$value)> + ]; + + // Indicate that additional verification for this operation is necessary. + let hasVerifier = 1; +} + +//===----------------------------------------------------------------------===// +// AddOp +//===----------------------------------------------------------------------===// + +def AddOp : Toy_Op<"add", + [Pure, DeclareOpInterfaceMethods]> { + let summary = "element-wise addition operation"; + let description = [{ + The "add" operation performs element-wise addition between two tensors. + The shapes of the tensor operands are expected to match. + }]; + + let arguments = (ins F64Tensor:$lhs, F64Tensor:$rhs); + let results = (outs F64Tensor); + + // Indicate that the operation has a custom parser and printer method. + let hasCustomAssemblyFormat = 1; + + // Allow building an AddOp with from the two input operands. + let builders = [ + OpBuilder<(ins "Value":$lhs, "Value":$rhs)> + ]; +} + +//===----------------------------------------------------------------------===// +// CastOp +//===----------------------------------------------------------------------===// + +def CastOp : Toy_Op<"cast", [ + DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods, + Pure, + SameOperandsAndResultShape + ]> { + let summary = "shape cast operation"; + let description = [{ + The "cast" operation converts a tensor from one type to an equivalent type + without changing any data elements. The source and destination types must + both be tensor types with the same element type. If both are ranked, then + shape is required to match. The operation is invalid if converting to a + mismatching constant dimension. + }]; + + let arguments = (ins F64Tensor:$input); + let results = (outs F64Tensor:$output); + + let assemblyFormat = "$input attr-dict `:` type($input) `to` type($output)"; +} + +//===----------------------------------------------------------------------===// +// FuncOp +//===----------------------------------------------------------------------===// + +def FuncOp : Toy_Op<"func", [ + FunctionOpInterface, IsolatedFromAbove + ]> { + let summary = "user defined function operation"; + let description = [{ + The "toy.func" operation represents a user defined function. These are + callable SSA-region operations that contain toy computations. + + Example: + + ```mlir + toy.func @main() { + %0 = toy.constant dense<5.500000e+00> : tensor + %1 = toy.reshape(%0 : tensor) to tensor<2x2xf64> + toy.print %1 : tensor<2x2xf64> + toy.return + } + ``` + }]; + + let arguments = (ins + SymbolNameAttr:$sym_name, + TypeAttrOf:$function_type, + OptionalAttr:$arg_attrs, + OptionalAttr:$res_attrs + ); + let regions = (region AnyRegion:$body); + + let builders = [OpBuilder<(ins + "StringRef":$name, "FunctionType":$type, + CArg<"ArrayRef", "{}">:$attrs) + >]; + let extraClassDeclaration = [{ + //===------------------------------------------------------------------===// + // FunctionOpInterface Methods + //===------------------------------------------------------------------===// + + /// Returns the argument types of this function. + ArrayRef getArgumentTypes() { return getFunctionType().getInputs(); } + + /// Returns the result types of this function. + ArrayRef getResultTypes() { return getFunctionType().getResults(); } + + /// Returns the region on the function operation that is callable. + Region *getCallableRegion() { return &getBody(); } + }]; + let hasCustomAssemblyFormat = 1; + let skipDefaultBuilders = 1; +} + +//===----------------------------------------------------------------------===// +// GenericCallOp +//===----------------------------------------------------------------------===// + +def GenericCallOp : Toy_Op<"generic_call", + [DeclareOpInterfaceMethods]> { + let summary = "generic call operation"; + let description = [{ + Generic calls represent calls to a user defined function that needs to + be specialized for the shape of its arguments. The callee name is attached + as a symbol reference via an attribute. The arguments list must match the + arguments expected by the callee. For example: + + ```mlir + %4 = toy.generic_call @my_func(%1, %3) + : (tensor<2x3xf64>, tensor<2x3xf64>) -> tensor<*xf64> + ``` + + This is only valid if a function named "my_func" exists and takes two + arguments. + }]; + + // The generic call operation takes a symbol reference attribute as the + // callee, and inputs for the call. + let arguments = (ins FlatSymbolRefAttr:$callee, Variadic:$inputs); + + // The generic call operation returns a single value of TensorType. + let results = (outs F64Tensor); + + // Specialize assembly printing and parsing using a declarative format. + let assemblyFormat = [{ + $callee `(` $inputs `)` attr-dict `:` functional-type($inputs, results) + }]; + + // Add custom build methods for the generic call operation. + let builders = [ + OpBuilder<(ins "StringRef":$callee, "ArrayRef":$arguments)> + ]; +} + +//===----------------------------------------------------------------------===// +// MulOp +//===----------------------------------------------------------------------===// + +def MulOp : Toy_Op<"mul", + [Pure, DeclareOpInterfaceMethods]> { + let summary = "element-wise multiplication operation"; + let description = [{ + The "mul" operation performs element-wise multiplication between two + tensors. The shapes of the tensor operands are expected to match. + }]; + + let arguments = (ins F64Tensor:$lhs, F64Tensor:$rhs); + let results = (outs F64Tensor); + + // Indicate that the operation has a custom parser and printer method. + let hasCustomAssemblyFormat = 1; + + // Allow building a MulOp with from the two input operands. + let builders = [ + OpBuilder<(ins "Value":$lhs, "Value":$rhs)> + ]; +} + +//===----------------------------------------------------------------------===// +// PrintOp +//===----------------------------------------------------------------------===// + +def PrintOp : Toy_Op<"print"> { + let summary = "print operation"; + let description = [{ + The "print" builtin operation prints a given input tensor, and produces + no results. + }]; + + // The print operation takes an input tensor to print. + // We also allow a F64MemRef to enable interop during partial lowering. + let arguments = (ins AnyTypeOf<[F64Tensor, F64MemRef]>:$input); + + let assemblyFormat = "$input attr-dict `:` type($input)"; +} + +//===----------------------------------------------------------------------===// +// ReshapeOp +//===----------------------------------------------------------------------===// + +def ReshapeOp : Toy_Op<"reshape", [Pure]> { + let summary = "tensor reshape operation"; + let description = [{ + Reshape operation is transforming its input tensor into a new tensor with + the same number of elements but different shapes. For example: + + ```mlir + %0 = toy.reshape (%arg1 : tensor<10xf64>) to tensor<5x2xf64> + ``` + }]; + + let arguments = (ins F64Tensor:$input); + + // We expect that the reshape operation returns a statically shaped tensor. + let results = (outs StaticShapeTensorOf<[F64]>); + + let assemblyFormat = [{ + `(` $input `:` type($input) `)` attr-dict `to` type(results) + }]; + + // Enable registering canonicalization patterns with this operation. + let hasCanonicalizer = 1; +} + +//===----------------------------------------------------------------------===// +// ReturnOp +//===----------------------------------------------------------------------===// + +def ReturnOp : Toy_Op<"return", [Pure, HasParent<"FuncOp">, + Terminator]> { + let summary = "return operation"; + let description = [{ + The "return" operation represents a return operation within a function. + The operation takes an optional tensor operand and produces no results. + The operand type must match the signature of the function that contains + the operation. For example: + + ```mlir + toy.func @foo() -> tensor<2xf64> { + ... + toy.return %0 : tensor<2xf64> + } + ``` + }]; + + // The return operation takes an optional input operand to return. This + // value must match the return type of the enclosing function. + let arguments = (ins Variadic:$input); + + // The return operation only emits the input in the format if it is present. + let assemblyFormat = "($input^ `:` type($input))? attr-dict "; + + // Allow building a ReturnOp with no return operand. + let builders = [ + OpBuilder<(ins), [{ build($_builder, $_state, std::nullopt); }]> + ]; + + // Provide extra utility definitions on the c++ operation class definition. + let extraClassDeclaration = [{ + bool hasOperand() { return getNumOperands() != 0; } + }]; + + // Indicate that additional verification for this operation is necessary. + let hasVerifier = 1; +} + +//===----------------------------------------------------------------------===// +// TransposeOp +//===----------------------------------------------------------------------===// + +def TransposeOp : Toy_Op<"transpose", + [Pure, DeclareOpInterfaceMethods]> { + let summary = "transpose operation"; + + let arguments = (ins F64Tensor:$input); + let results = (outs F64Tensor); + + let assemblyFormat = [{ + `(` $input `:` type($input) `)` attr-dict `to` type(results) + }]; + + // Enable registering canonicalization patterns with this operation. + let hasCanonicalizer = 1; + + // Allow building a TransposeOp with from the input operand. + let builders = [ + OpBuilder<(ins "Value":$input)> + ]; + + // Indicate that additional verification for this operation is necessary. + let hasVerifier = 1; +} + +#endif // TOY_OPS diff --git a/Ch5/include/toy/Parser.h b/Ch5/include/toy/Parser.h new file mode 100644 index 0000000..1f20616 --- /dev/null +++ b/Ch5/include/toy/Parser.h @@ -0,0 +1,489 @@ +//===- Parser.h - Toy Language Parser -------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements the parser for the Toy language. It processes the Token +// provided by the Lexer and returns an AST. +// +//===----------------------------------------------------------------------===// + +#ifndef TOY_PARSER_H +#define TOY_PARSER_H + +#include "toy/AST.h" +#include "toy/Lexer.h" + +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/StringExtras.h" +#include "llvm/Support/raw_ostream.h" + +#include +#include +#include +#include + +namespace toy { + +/// This is a simple recursive parser for the Toy language. It produces a well +/// formed AST from a stream of Token supplied by the Lexer. No semantic checks +/// or symbol resolution is performed. For example, variables are referenced by +/// string and the code could reference an undeclared variable and the parsing +/// succeeds. +class Parser { +public: + /// Create a Parser for the supplied lexer. + Parser(Lexer &lexer) : lexer(lexer) {} + + /// Parse a full Module. A module is a list of function definitions. + std::unique_ptr parseModule() { + lexer.getNextToken(); // prime the lexer + + // Parse functions one at a time and accumulate in this vector. + std::vector functions; + while (auto f = parseDefinition()) { + functions.push_back(std::move(*f)); + if (lexer.getCurToken() == tok_eof) + break; + } + // If we didn't reach EOF, there was an error during parsing + if (lexer.getCurToken() != tok_eof) + return parseError("nothing", "at end of module"); + + return std::make_unique(std::move(functions)); + } + +private: + Lexer &lexer; + + /// Parse a return statement. + /// return :== return ; | return expr ; + std::unique_ptr parseReturn() { + auto loc = lexer.getLastLocation(); + lexer.consume(tok_return); + + // return takes an optional argument + std::optional> expr; + if (lexer.getCurToken() != ';') { + expr = parseExpression(); + if (!expr) + return nullptr; + } + return std::make_unique(std::move(loc), std::move(expr)); + } + + /// Parse a literal number. + /// numberexpr ::= number + std::unique_ptr parseNumberExpr() { + auto loc = lexer.getLastLocation(); + auto result = + std::make_unique(std::move(loc), lexer.getValue()); + lexer.consume(tok_number); + return std::move(result); + } + + /// Parse a literal array expression. + /// tensorLiteral ::= [ literalList ] | number + /// literalList ::= tensorLiteral | tensorLiteral, literalList + std::unique_ptr parseTensorLiteralExpr() { + auto loc = lexer.getLastLocation(); + lexer.consume(Token('[')); + + // Hold the list of values at this nesting level. + std::vector> values; + // Hold the dimensions for all the nesting inside this level. + std::vector dims; + do { + // We can have either another nested array or a number literal. + if (lexer.getCurToken() == '[') { + values.push_back(parseTensorLiteralExpr()); + if (!values.back()) + return nullptr; // parse error in the nested array. + } else { + if (lexer.getCurToken() != tok_number) + return parseError(" or [", "in literal expression"); + values.push_back(parseNumberExpr()); + } + + // End of this list on ']' + if (lexer.getCurToken() == ']') + break; + + // Elements are separated by a comma. + if (lexer.getCurToken() != ',') + return parseError("] or ,", "in literal expression"); + + lexer.getNextToken(); // eat , + } while (true); + if (values.empty()) + return parseError("", "to fill literal expression"); + lexer.getNextToken(); // eat ] + + /// Fill in the dimensions now. First the current nesting level: + dims.push_back(values.size()); + + /// If there is any nested array, process all of them and ensure that + /// dimensions are uniform. + if (llvm::any_of(values, [](std::unique_ptr &expr) { + return llvm::isa(expr.get()); + })) { + auto *firstLiteral = llvm::dyn_cast(values.front().get()); + if (!firstLiteral) + return parseError("uniform well-nested dimensions", + "inside literal expression"); + + // Append the nested dimensions to the current level + auto firstDims = firstLiteral->getDims(); + dims.insert(dims.end(), firstDims.begin(), firstDims.end()); + + // Sanity check that shape is uniform across all elements of the list. + for (auto &expr : values) { + auto *exprLiteral = llvm::cast(expr.get()); + if (!exprLiteral) + return parseError("uniform well-nested dimensions", + "inside literal expression"); + if (exprLiteral->getDims() != firstDims) + return parseError("uniform well-nested dimensions", + "inside literal expression"); + } + } + return std::make_unique(std::move(loc), std::move(values), + std::move(dims)); + } + + /// parenexpr ::= '(' expression ')' + std::unique_ptr parseParenExpr() { + lexer.getNextToken(); // eat (. + auto v = parseExpression(); + if (!v) + return nullptr; + + if (lexer.getCurToken() != ')') + return parseError(")", "to close expression with parentheses"); + lexer.consume(Token(')')); + return v; + } + + /// identifierexpr + /// ::= identifier + /// ::= identifier '(' expression ')' + std::unique_ptr parseIdentifierExpr() { + std::string name(lexer.getId()); + + auto loc = lexer.getLastLocation(); + lexer.getNextToken(); // eat identifier. + + if (lexer.getCurToken() != '(') // Simple variable ref. + return std::make_unique(std::move(loc), name); + + // This is a function call. + lexer.consume(Token('(')); + std::vector> args; + if (lexer.getCurToken() != ')') { + while (true) { + if (auto arg = parseExpression()) + args.push_back(std::move(arg)); + else + return nullptr; + + if (lexer.getCurToken() == ')') + break; + + if (lexer.getCurToken() != ',') + return parseError(", or )", "in argument list"); + lexer.getNextToken(); + } + } + lexer.consume(Token(')')); + + // It can be a builtin call to print + if (name == "print") { + if (args.size() != 1) + return parseError("", "as argument to print()"); + + return std::make_unique(std::move(loc), std::move(args[0])); + } + + // Call to a user-defined function + return std::make_unique(std::move(loc), name, std::move(args)); + } + + /// primary + /// ::= identifierexpr + /// ::= numberexpr + /// ::= parenexpr + /// ::= tensorliteral + std::unique_ptr parsePrimary() { + switch (lexer.getCurToken()) { + default: + llvm::errs() << "unknown token '" << lexer.getCurToken() + << "' when expecting an expression\n"; + return nullptr; + case tok_identifier: + return parseIdentifierExpr(); + case tok_number: + return parseNumberExpr(); + case '(': + return parseParenExpr(); + case '[': + return parseTensorLiteralExpr(); + case ';': + return nullptr; + case '}': + return nullptr; + } + } + + /// Recursively parse the right hand side of a binary expression, the ExprPrec + /// argument indicates the precedence of the current binary operator. + /// + /// binoprhs ::= ('+' primary)* + std::unique_ptr parseBinOpRHS(int exprPrec, + std::unique_ptr lhs) { + // If this is a binop, find its precedence. + while (true) { + int tokPrec = getTokPrecedence(); + + // If this is a binop that binds at least as tightly as the current binop, + // consume it, otherwise we are done. + if (tokPrec < exprPrec) + return lhs; + + // Okay, we know this is a binop. + int binOp = lexer.getCurToken(); + lexer.consume(Token(binOp)); + auto loc = lexer.getLastLocation(); + + // Parse the primary expression after the binary operator. + auto rhs = parsePrimary(); + if (!rhs) + return parseError("expression", "to complete binary operator"); + + // If BinOp binds less tightly with rhs than the operator after rhs, let + // the pending operator take rhs as its lhs. + int nextPrec = getTokPrecedence(); + if (tokPrec < nextPrec) { + rhs = parseBinOpRHS(tokPrec + 1, std::move(rhs)); + if (!rhs) + return nullptr; + } + + // Merge lhs/RHS. + lhs = std::make_unique(std::move(loc), binOp, + std::move(lhs), std::move(rhs)); + } + } + + /// expression::= primary binop rhs + std::unique_ptr parseExpression() { + auto lhs = parsePrimary(); + if (!lhs) + return nullptr; + + return parseBinOpRHS(0, std::move(lhs)); + } + + /// type ::= < shape_list > + /// shape_list ::= num | num , shape_list + std::unique_ptr parseType() { + if (lexer.getCurToken() != '<') + return parseError("<", "to begin type"); + lexer.getNextToken(); // eat < + + auto type = std::make_unique(); + + while (lexer.getCurToken() == tok_number) { + type->shape.push_back(lexer.getValue()); + lexer.getNextToken(); + if (lexer.getCurToken() == ',') + lexer.getNextToken(); + } + + if (lexer.getCurToken() != '>') + return parseError(">", "to end type"); + lexer.getNextToken(); // eat > + return type; + } + + /// Parse a variable declaration, it starts with a `var` keyword followed by + /// and identifier and an optional type (shape specification) before the + /// initializer. + /// decl ::= var identifier [ type ] = expr + std::unique_ptr parseDeclaration() { + if (lexer.getCurToken() != tok_var) + return parseError("var", "to begin declaration"); + auto loc = lexer.getLastLocation(); + lexer.getNextToken(); // eat var + + if (lexer.getCurToken() != tok_identifier) + return parseError("identified", + "after 'var' declaration"); + std::string id(lexer.getId()); + lexer.getNextToken(); // eat id + + std::unique_ptr type; // Type is optional, it can be inferred + if (lexer.getCurToken() == '<') { + type = parseType(); + if (!type) + return nullptr; + } + + if (!type) + type = std::make_unique(); + lexer.consume(Token('=')); + auto expr = parseExpression(); + return std::make_unique(std::move(loc), std::move(id), + std::move(*type), std::move(expr)); + } + + /// Parse a block: a list of expression separated by semicolons and wrapped in + /// curly braces. + /// + /// block ::= { expression_list } + /// expression_list ::= block_expr ; expression_list + /// block_expr ::= decl | "return" | expr + std::unique_ptr parseBlock() { + if (lexer.getCurToken() != '{') + return parseError("{", "to begin block"); + lexer.consume(Token('{')); + + auto exprList = std::make_unique(); + + // Ignore empty expressions: swallow sequences of semicolons. + while (lexer.getCurToken() == ';') + lexer.consume(Token(';')); + + while (lexer.getCurToken() != '}' && lexer.getCurToken() != tok_eof) { + if (lexer.getCurToken() == tok_var) { + // Variable declaration + auto varDecl = parseDeclaration(); + if (!varDecl) + return nullptr; + exprList->push_back(std::move(varDecl)); + } else if (lexer.getCurToken() == tok_return) { + // Return statement + auto ret = parseReturn(); + if (!ret) + return nullptr; + exprList->push_back(std::move(ret)); + } else { + // General expression + auto expr = parseExpression(); + if (!expr) + return nullptr; + exprList->push_back(std::move(expr)); + } + // Ensure that elements are separated by a semicolon. + if (lexer.getCurToken() != ';') + return parseError(";", "after expression"); + + // Ignore empty expressions: swallow sequences of semicolons. + while (lexer.getCurToken() == ';') + lexer.consume(Token(';')); + } + + if (lexer.getCurToken() != '}') + return parseError("}", "to close block"); + + lexer.consume(Token('}')); + return exprList; + } + + /// prototype ::= def id '(' decl_list ')' + /// decl_list ::= identifier | identifier, decl_list + std::unique_ptr parsePrototype() { + auto loc = lexer.getLastLocation(); + + if (lexer.getCurToken() != tok_def) + return parseError("def", "in prototype"); + lexer.consume(tok_def); + + if (lexer.getCurToken() != tok_identifier) + return parseError("function name", "in prototype"); + + std::string fnName(lexer.getId()); + lexer.consume(tok_identifier); + + if (lexer.getCurToken() != '(') + return parseError("(", "in prototype"); + lexer.consume(Token('(')); + + std::vector> args; + if (lexer.getCurToken() != ')') { + do { + std::string name(lexer.getId()); + auto loc = lexer.getLastLocation(); + lexer.consume(tok_identifier); + auto decl = std::make_unique(std::move(loc), name); + args.push_back(std::move(decl)); + if (lexer.getCurToken() != ',') + break; + lexer.consume(Token(',')); + if (lexer.getCurToken() != tok_identifier) + return parseError( + "identifier", "after ',' in function parameter list"); + } while (true); + } + if (lexer.getCurToken() != ')') + return parseError(")", "to end function prototype"); + + // success. + lexer.consume(Token(')')); + return std::make_unique(std::move(loc), fnName, + std::move(args)); + } + + /// Parse a function definition, we expect a prototype initiated with the + /// `def` keyword, followed by a block containing a list of expressions. + /// + /// definition ::= prototype block + std::unique_ptr parseDefinition() { + auto proto = parsePrototype(); + if (!proto) + return nullptr; + + if (auto block = parseBlock()) + return std::make_unique(std::move(proto), std::move(block)); + return nullptr; + } + + /// Get the precedence of the pending binary operator token. + int getTokPrecedence() { + if (!isascii(lexer.getCurToken())) + return -1; + + // 1 is lowest precedence. + switch (static_cast(lexer.getCurToken())) { + case '-': + return 20; + case '+': + return 20; + case '*': + return 40; + default: + return -1; + } + } + + /// Helper function to signal errors while parsing, it takes an argument + /// indicating the expected token and another argument giving more context. + /// Location is retrieved from the lexer to enrich the error message. + template + std::unique_ptr parseError(T &&expected, U &&context = "") { + auto curToken = lexer.getCurToken(); + llvm::errs() << "Parse error (" << lexer.getLastLocation().line << ", " + << lexer.getLastLocation().col << "): expected '" << expected + << "' " << context << " but has Token " << curToken; + if (isprint(curToken)) + llvm::errs() << " '" << (char)curToken << "'"; + llvm::errs() << "\n"; + return nullptr; + } +}; + +} // namespace toy + +#endif // TOY_PARSER_H diff --git a/Ch5/include/toy/Passes.h b/Ch5/include/toy/Passes.h new file mode 100644 index 0000000..02a83cf --- /dev/null +++ b/Ch5/include/toy/Passes.h @@ -0,0 +1,31 @@ +//===- Passes.h - Toy Passes Definition -----------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file exposes the entry points to create compiler passes for Toy. +// +//===----------------------------------------------------------------------===// + +#ifndef TOY_PASSES_H +#define TOY_PASSES_H + +#include + +namespace mlir { +class Pass; + +namespace toy { +std::unique_ptr createShapeInferencePass(); + +/// Create a pass for lowering to operations in the `Affine` and `Std` dialects, +/// for a subset of the Toy IR (e.g. matmul). +std::unique_ptr createLowerToAffinePass(); + +} // namespace toy +} // namespace mlir + +#endif // TOY_PASSES_H diff --git a/Ch5/include/toy/ShapeInferenceInterface.h b/Ch5/include/toy/ShapeInferenceInterface.h new file mode 100644 index 0000000..cfe5a87 --- /dev/null +++ b/Ch5/include/toy/ShapeInferenceInterface.h @@ -0,0 +1,28 @@ +//===- ShapeInferenceInterface.h - Interface definitions for ShapeInference -=// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file contains the declarations of the shape inference interfaces defined +// in ShapeInferenceInterface.td. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_TUTORIAL_TOY_SHAPEINFERENCEINTERFACE_H_ +#define MLIR_TUTORIAL_TOY_SHAPEINFERENCEINTERFACE_H_ + +#include "mlir/IR/OpDefinition.h" + +namespace mlir { +namespace toy { + +/// Include the auto-generated declarations. +#include "toy/ShapeInferenceOpInterfaces.h.inc" + +} // namespace toy +} // namespace mlir + +#endif // MLIR_TUTORIAL_TOY_SHAPEINFERENCEINTERFACE_H_ diff --git a/Ch5/include/toy/ShapeInferenceInterface.td b/Ch5/include/toy/ShapeInferenceInterface.td new file mode 100644 index 0000000..2279015 --- /dev/null +++ b/Ch5/include/toy/ShapeInferenceInterface.td @@ -0,0 +1,30 @@ +//===- ShapeInferenceInterface.td - Shape Inference Interface -*- tablegen -==// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Defines the operations of the Shape Inference Op Interface. +// +//===----------------------------------------------------------------------===// + +#ifndef SHAPE_INFERENCE_INTERFACE +#define SHAPE_INFERENCE_INTERFACE + +include "mlir/IR/OpBase.td" + +def ShapeInferenceOpInterface : OpInterface<"ShapeInference"> { + let description = [{ + Interface to access a registered method to infer the return types for an + operation that can be used during type inference. + }]; + + let methods = [ + InterfaceMethod<"Infer and set the output shape for the current operation.", + "void", "inferShapes"> + ]; +} + +#endif // SHAPE_INFERENCE_INTERFACE diff --git a/Ch5/include/toy/ShapeInferenceOpInterfaces.cpp.inc b/Ch5/include/toy/ShapeInferenceOpInterfaces.cpp.inc new file mode 100644 index 0000000..a481d2e --- /dev/null +++ b/Ch5/include/toy/ShapeInferenceOpInterfaces.cpp.inc @@ -0,0 +1,12 @@ +/*===- TableGen'erated file -------------------------------------*- C++ -*-===*\ +|* *| +|* Interface Definitions *| +|* *| +|* Automatically generated file, do not edit! *| +|* *| +\*===----------------------------------------------------------------------===*/ + +/// Infer and set the output shape for the current operation. +void ShapeInference::inferShapes() { + return getImpl()->inferShapes(getImpl(), getOperation()); + } diff --git a/Ch5/include/toy/ShapeInferenceOpInterfaces.h.inc b/Ch5/include/toy/ShapeInferenceOpInterfaces.h.inc new file mode 100644 index 0000000..bb24654 --- /dev/null +++ b/Ch5/include/toy/ShapeInferenceOpInterfaces.h.inc @@ -0,0 +1,61 @@ +/*===- TableGen'erated file -------------------------------------*- C++ -*-===*\ +|* *| +|* Interface Declarations *| +|* *| +|* Automatically generated file, do not edit! *| +|* *| +\*===----------------------------------------------------------------------===*/ + +class ShapeInference; +namespace detail { +struct ShapeInferenceInterfaceTraits { + struct Concept { + /// The methods defined by the interface. + void (*inferShapes)(const Concept *impl, ::mlir::Operation *); + }; + template + class Model : public Concept { + public: + using Interface = ShapeInference; + Model() : Concept{inferShapes} {} + + static inline void inferShapes(const Concept *impl, ::mlir::Operation *tablegen_opaque_val); + }; + template + class FallbackModel : public Concept { + public: + using Interface = ShapeInference; + FallbackModel() : Concept{inferShapes} {} + + static inline void inferShapes(const Concept *impl, ::mlir::Operation *tablegen_opaque_val); + }; + template + class ExternalModel : public FallbackModel { + public: + using ConcreteEntity = ConcreteOp; + }; +};template +struct ShapeInferenceTrait; + +} // namespace detail +class ShapeInference : public ::mlir::OpInterface { +public: + using ::mlir::OpInterface::OpInterface; + template + struct Trait : public detail::ShapeInferenceTrait {}; + /// Infer and set the output shape for the current operation. + void inferShapes(); +}; +namespace detail { + template + struct ShapeInferenceTrait : public ::mlir::OpInterface::Trait { + }; +}// namespace detail +template +void detail::ShapeInferenceInterfaceTraits::Model::inferShapes(const Concept *impl, ::mlir::Operation *tablegen_opaque_val) { + return (llvm::cast(tablegen_opaque_val)).inferShapes(); +} +template +void detail::ShapeInferenceInterfaceTraits::FallbackModel::inferShapes(const Concept *impl, ::mlir::Operation *tablegen_opaque_val) { + return static_cast(impl)->inferShapes(tablegen_opaque_val); +} diff --git a/Ch5/mlir/Dialect.cpp b/Ch5/mlir/Dialect.cpp new file mode 100644 index 0000000..c587dd2 --- /dev/null +++ b/Ch5/mlir/Dialect.cpp @@ -0,0 +1,444 @@ +//===- Dialect.cpp - Toy IR Dialect registration in MLIR ------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements the dialect for the Toy IR: custom type parsing and +// operation verification. +// +//===----------------------------------------------------------------------===// + +#include "toy/Dialect.h" + +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Location.h" +#include "mlir/IR/OpImplementation.h" +#include "mlir/IR/OperationSupport.h" +#include "mlir/IR/ValueRange.h" +#include "mlir/Interfaces/CallInterfaces.h" +#include "mlir/Interfaces/FunctionImplementation.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/InliningUtils.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/Casting.h" +#include +#include +#include +#include + +using namespace mlir; +using namespace mlir::toy; + +#include "toy/Dialect.cpp.inc" + +//===----------------------------------------------------------------------===// +// ToyInlinerInterface +//===----------------------------------------------------------------------===// + +/// This class defines the interface for handling inlining with Toy +/// operations. +struct ToyInlinerInterface : public DialectInlinerInterface { + using DialectInlinerInterface::DialectInlinerInterface; + + //===--------------------------------------------------------------------===// + // Analysis Hooks + //===--------------------------------------------------------------------===// + + /// All call operations within toy can be inlined. + bool isLegalToInline(Operation *call, Operation *callable, + bool wouldBeCloned) const final { + return true; + } + + /// All operations within toy can be inlined. + bool isLegalToInline(Operation *, Region *, bool, IRMapping &) const final { + return true; + } + + // All functions within toy can be inlined. + bool isLegalToInline(Region *, Region *, bool, IRMapping &) const final { + return true; + } + + //===--------------------------------------------------------------------===// + // Transformation Hooks + //===--------------------------------------------------------------------===// + + /// Handle the given inlined terminator(toy.return) by replacing it with a new + /// operation as necessary. + void handleTerminator(Operation *op, ValueRange valuesToRepl) const final { + // Only "toy.return" needs to be handled here. + auto returnOp = cast(op); + + // Replace the values directly with the return operands. + assert(returnOp.getNumOperands() == valuesToRepl.size()); + for (const auto &it : llvm::enumerate(returnOp.getOperands())) + valuesToRepl[it.index()].replaceAllUsesWith(it.value()); + } + + /// Attempts to materialize a conversion for a type mismatch between a call + /// from this dialect, and a callable region. This method should generate an + /// operation that takes 'input' as the only operand, and produces a single + /// result of 'resultType'. If a conversion can not be generated, nullptr + /// should be returned. + Operation *materializeCallConversion(OpBuilder &builder, Value input, + Type resultType, + Location conversionLoc) const final { + return builder.create(conversionLoc, resultType, input); + } +}; + +//===----------------------------------------------------------------------===// +// ToyDialect +//===----------------------------------------------------------------------===// + +/// Dialect initialization, the instance will be owned by the context. This is +/// the point of registration of types and operations for the dialect. +void ToyDialect::initialize() { + addOperations< +#define GET_OP_LIST +#include "toy/Ops.cpp.inc" + >(); + addInterfaces(); +} + +//===----------------------------------------------------------------------===// +// Toy Operations +//===----------------------------------------------------------------------===// + +/// A generalized parser for binary operations. This parses the different forms +/// of 'printBinaryOp' below. +static mlir::ParseResult parseBinaryOp(mlir::OpAsmParser &parser, + mlir::OperationState &result) { + SmallVector operands; + SMLoc operandsLoc = parser.getCurrentLocation(); + Type type; + if (parser.parseOperandList(operands, /*requiredOperandCount=*/2) || + parser.parseOptionalAttrDict(result.attributes) || + parser.parseColonType(type)) + return mlir::failure(); + + // If the type is a function type, it contains the input and result types of + // this operation. + if (FunctionType funcType = llvm::dyn_cast(type)) { + if (parser.resolveOperands(operands, funcType.getInputs(), operandsLoc, + result.operands)) + return mlir::failure(); + result.addTypes(funcType.getResults()); + return mlir::success(); + } + + // Otherwise, the parsed type is the type of both operands and results. + if (parser.resolveOperands(operands, type, result.operands)) + return mlir::failure(); + result.addTypes(type); + return mlir::success(); +} + +/// A generalized printer for binary operations. It prints in two different +/// forms depending on if all of the types match. +static void printBinaryOp(mlir::OpAsmPrinter &printer, mlir::Operation *op) { + printer << " " << op->getOperands(); + printer.printOptionalAttrDict(op->getAttrs()); + printer << " : "; + + // If all of the types are the same, print the type directly. + Type resultType = *op->result_type_begin(); + if (llvm::all_of(op->getOperandTypes(), + [=](Type type) { return type == resultType; })) { + printer << resultType; + return; + } + + // Otherwise, print a functional type. + printer.printFunctionalType(op->getOperandTypes(), op->getResultTypes()); +} + +//===----------------------------------------------------------------------===// +// ConstantOp +//===----------------------------------------------------------------------===// + +/// Build a constant operation. +/// The builder is passed as an argument, so is the state that this method is +/// expected to fill in order to build the operation. +void ConstantOp::build(mlir::OpBuilder &builder, mlir::OperationState &state, + double value) { + auto dataType = RankedTensorType::get({}, builder.getF64Type()); + auto dataAttribute = DenseElementsAttr::get(dataType, value); + ConstantOp::build(builder, state, dataType, dataAttribute); +} + +/// The 'OpAsmParser' class provides a collection of methods for parsing +/// various punctuation, as well as attributes, operands, types, etc. Each of +/// these methods returns a `ParseResult`. This class is a wrapper around +/// `LogicalResult` that can be converted to a boolean `true` value on failure, +/// or `false` on success. This allows for easily chaining together a set of +/// parser rules. These rules are used to populate an `mlir::OperationState` +/// similarly to the `build` methods described above. +mlir::ParseResult ConstantOp::parse(mlir::OpAsmParser &parser, + mlir::OperationState &result) { + mlir::DenseElementsAttr value; + if (parser.parseOptionalAttrDict(result.attributes) || + parser.parseAttribute(value, "value", result.attributes)) + return failure(); + + result.addTypes(value.getType()); + return success(); +} + +/// The 'OpAsmPrinter' class is a stream that allows for formatting +/// strings, attributes, operands, types, etc. +void ConstantOp::print(mlir::OpAsmPrinter &printer) { + printer << " "; + printer.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{"value"}); + printer << getValue(); +} + +/// Verifier for the constant operation. This corresponds to the +/// `let hasVerifier = 1` in the op definition. +mlir::LogicalResult ConstantOp::verify() { + // If the return type of the constant is not an unranked tensor, the shape + // must match the shape of the attribute holding the data. + auto resultType = llvm::dyn_cast(getResult().getType()); + if (!resultType) + return success(); + + // Check that the rank of the attribute type matches the rank of the constant + // result type. + auto attrType = llvm::cast(getValue().getType()); + if (attrType.getRank() != resultType.getRank()) { + return emitOpError("return type must match the one of the attached value " + "attribute: ") + << attrType.getRank() << " != " << resultType.getRank(); + } + + // Check that each of the dimensions match between the two types. + for (int dim = 0, dimE = attrType.getRank(); dim < dimE; ++dim) { + if (attrType.getShape()[dim] != resultType.getShape()[dim]) { + return emitOpError( + "return type shape mismatches its attribute at dimension ") + << dim << ": " << attrType.getShape()[dim] + << " != " << resultType.getShape()[dim]; + } + } + return mlir::success(); +} + +//===----------------------------------------------------------------------===// +// AddOp +//===----------------------------------------------------------------------===// + +void AddOp::build(mlir::OpBuilder &builder, mlir::OperationState &state, + mlir::Value lhs, mlir::Value rhs) { + state.addTypes(UnrankedTensorType::get(builder.getF64Type())); + state.addOperands({lhs, rhs}); +} + +mlir::ParseResult AddOp::parse(mlir::OpAsmParser &parser, + mlir::OperationState &result) { + return parseBinaryOp(parser, result); +} + +void AddOp::print(mlir::OpAsmPrinter &p) { printBinaryOp(p, *this); } + +/// Infer the output shape of the AddOp, this is required by the shape inference +/// interface. +void AddOp::inferShapes() { getResult().setType(getLhs().getType()); } + +//===----------------------------------------------------------------------===// +// CastOp +//===----------------------------------------------------------------------===// + +/// Infer the output shape of the CastOp, this is required by the shape +/// inference interface. +void CastOp::inferShapes() { getResult().setType(getInput().getType()); } + +/// Returns true if the given set of input and result types are compatible with +/// this cast operation. This is required by the `CastOpInterface` to verify +/// this operation and provide other additional utilities. +bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { + if (inputs.size() != 1 || outputs.size() != 1) + return false; + // The inputs must be Tensors with the same element type. + TensorType input = llvm::dyn_cast(inputs.front()); + TensorType output = llvm::dyn_cast(outputs.front()); + if (!input || !output || input.getElementType() != output.getElementType()) + return false; + // The shape is required to match if both types are ranked. + return !input.hasRank() || !output.hasRank() || input == output; +} + +//===----------------------------------------------------------------------===// +// FuncOp +//===----------------------------------------------------------------------===// + +void FuncOp::build(mlir::OpBuilder &builder, mlir::OperationState &state, + llvm::StringRef name, mlir::FunctionType type, + llvm::ArrayRef attrs) { + // FunctionOpInterface provides a convenient `build` method that will populate + // the state of our FuncOp, and create an entry block. + buildWithEntryBlock(builder, state, name, type, attrs, type.getInputs()); +} + +mlir::ParseResult FuncOp::parse(mlir::OpAsmParser &parser, + mlir::OperationState &result) { + // Dispatch to the FunctionOpInterface provided utility method that parses the + // function operation. + auto buildFuncType = + [](mlir::Builder &builder, llvm::ArrayRef argTypes, + llvm::ArrayRef results, + mlir::function_interface_impl::VariadicFlag, + std::string &) { return builder.getFunctionType(argTypes, results); }; + + return mlir::function_interface_impl::parseFunctionOp( + parser, result, /*allowVariadic=*/false, + getFunctionTypeAttrName(result.name), buildFuncType, + getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name)); +} + +void FuncOp::print(mlir::OpAsmPrinter &p) { + // Dispatch to the FunctionOpInterface provided utility method that prints the + // function operation. + mlir::function_interface_impl::printFunctionOp( + p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(), + getArgAttrsAttrName(), getResAttrsAttrName()); +} + +//===----------------------------------------------------------------------===// +// GenericCallOp +//===----------------------------------------------------------------------===// + +void GenericCallOp::build(mlir::OpBuilder &builder, mlir::OperationState &state, + StringRef callee, ArrayRef arguments) { + // Generic call always returns an unranked Tensor initially. + state.addTypes(UnrankedTensorType::get(builder.getF64Type())); + state.addOperands(arguments); + state.addAttribute("callee", + mlir::SymbolRefAttr::get(builder.getContext(), callee)); +} + +/// Return the callee of the generic call operation, this is required by the +/// call interface. +CallInterfaceCallable GenericCallOp::getCallableForCallee() { + return (*this)->getAttrOfType("callee"); +} + +/// Set the callee for the generic call operation, this is required by the call +/// interface. +void GenericCallOp::setCalleeFromCallable(CallInterfaceCallable callee) { + (*this)->setAttr("callee", callee.get()); +} + +/// Get the argument operands to the called function, this is required by the +/// call interface. +Operation::operand_range GenericCallOp::getArgOperands() { return getInputs(); } + +/// Get the argument operands to the called function as a mutable range, this is +/// required by the call interface. +MutableOperandRange GenericCallOp::getArgOperandsMutable() { + return getInputsMutable(); +} + +//===----------------------------------------------------------------------===// +// MulOp +//===----------------------------------------------------------------------===// + +void MulOp::build(mlir::OpBuilder &builder, mlir::OperationState &state, + mlir::Value lhs, mlir::Value rhs) { + state.addTypes(UnrankedTensorType::get(builder.getF64Type())); + state.addOperands({lhs, rhs}); +} + +mlir::ParseResult MulOp::parse(mlir::OpAsmParser &parser, + mlir::OperationState &result) { + return parseBinaryOp(parser, result); +} + +void MulOp::print(mlir::OpAsmPrinter &p) { printBinaryOp(p, *this); } + +/// Infer the output shape of the MulOp, this is required by the shape inference +/// interface. +void MulOp::inferShapes() { getResult().setType(getLhs().getType()); } + +//===----------------------------------------------------------------------===// +// ReturnOp +//===----------------------------------------------------------------------===// + +mlir::LogicalResult ReturnOp::verify() { + // We know that the parent operation is a function, because of the 'HasParent' + // trait attached to the operation definition. + auto function = cast((*this)->getParentOp()); + + /// ReturnOps can only have a single optional operand. + if (getNumOperands() > 1) + return emitOpError() << "expects at most 1 return operand"; + + // The operand number and types must match the function signature. + const auto &results = function.getFunctionType().getResults(); + if (getNumOperands() != results.size()) + return emitOpError() << "does not return the same number of values (" + << getNumOperands() << ") as the enclosing function (" + << results.size() << ")"; + + // If the operation does not have an input, we are done. + if (!hasOperand()) + return mlir::success(); + + auto inputType = *operand_type_begin(); + auto resultType = results.front(); + + // Check that the result type of the function matches the operand type. + if (inputType == resultType || llvm::isa(inputType) || + llvm::isa(resultType)) + return mlir::success(); + + return emitError() << "type of return operand (" << inputType + << ") doesn't match function result type (" << resultType + << ")"; +} + +//===----------------------------------------------------------------------===// +// TransposeOp +//===----------------------------------------------------------------------===// + +void TransposeOp::build(mlir::OpBuilder &builder, mlir::OperationState &state, + mlir::Value value) { + state.addTypes(UnrankedTensorType::get(builder.getF64Type())); + state.addOperands(value); +} + +void TransposeOp::inferShapes() { + auto arrayTy = llvm::cast(getOperand().getType()); + SmallVector dims(llvm::reverse(arrayTy.getShape())); + getResult().setType(RankedTensorType::get(dims, arrayTy.getElementType())); +} + +mlir::LogicalResult TransposeOp::verify() { + auto inputType = llvm::dyn_cast(getOperand().getType()); + auto resultType = llvm::dyn_cast(getType()); + if (!inputType || !resultType) + return mlir::success(); + + auto inputShape = inputType.getShape(); + if (!std::equal(inputShape.begin(), inputShape.end(), + resultType.getShape().rbegin())) { + return emitError() + << "expected result shape to be a transpose of the input"; + } + return mlir::success(); +} + +//===----------------------------------------------------------------------===// +// TableGen'd op method definitions +//===----------------------------------------------------------------------===// + +#define GET_OP_CLASSES +#include "toy/Ops.cpp.inc" diff --git a/Ch5/mlir/LowerToAffineLoops.cpp b/Ch5/mlir/LowerToAffineLoops.cpp new file mode 100644 index 0000000..ae4bd98 --- /dev/null +++ b/Ch5/mlir/LowerToAffineLoops.cpp @@ -0,0 +1,385 @@ +//====- LowerToAffineLoops.cpp - Partial lowering from Toy to Affine+Std --===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements a partial lowering of Toy operations to a combination of +// affine loops, memref operations and standard operations. This lowering +// expects that all calls have been inlined, and all shapes have been resolved. +// +//===----------------------------------------------------------------------===// + +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinDialect.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Diagnostics.h" +#include "mlir/IR/DialectRegistry.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/ValueRange.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Support/TypeID.h" +#include "toy/Dialect.h" +#include "toy/Passes.h" + +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/DialectConversion.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/Sequence.h" +#include "llvm/Support/Casting.h" +#include +#include +#include +#include +#include + +using namespace mlir; + +//===----------------------------------------------------------------------===// +// ToyToAffine RewritePatterns +//===----------------------------------------------------------------------===// + +/// Convert the given RankedTensorType into the corresponding MemRefType. +static MemRefType convertTensorToMemRef(RankedTensorType type) { + return MemRefType::get(type.getShape(), type.getElementType()); +} + +/// Insert an allocation and deallocation for the given MemRefType. +static Value insertAllocAndDealloc(MemRefType type, Location loc, + PatternRewriter &rewriter) { + auto alloc = rewriter.create(loc, type); + + // Make sure to allocate at the beginning of the block. + auto *parentBlock = alloc->getBlock(); + alloc->moveBefore(&parentBlock->front()); + + // Make sure to deallocate this alloc at the end of the block. This is fine + // as toy functions have no control flow. + auto dealloc = rewriter.create(loc, alloc); + dealloc->moveBefore(&parentBlock->back()); + return alloc; +} + +/// This defines the function type used to process an iteration of a lowered +/// loop. It takes as input an OpBuilder, an range of memRefOperands +/// corresponding to the operands of the input operation, and the range of loop +/// induction variables for the iteration. It returns a value to store at the +/// current index of the iteration. +using LoopIterationFn = function_ref; + +static void lowerOpToLoops(Operation *op, ValueRange operands, + PatternRewriter &rewriter, + LoopIterationFn processIteration) { + auto tensorType = llvm::cast((*op->result_type_begin())); + auto loc = op->getLoc(); + + // Insert an allocation and deallocation for the result of this operation. + auto memRefType = convertTensorToMemRef(tensorType); + auto alloc = insertAllocAndDealloc(memRefType, loc, rewriter); + + // Create a nest of affine loops, with one loop per dimension of the shape. + // The buildAffineLoopNest function takes a callback that is used to construct + // the body of the innermost loop given a builder, a location and a range of + // loop induction variables. + SmallVector lowerBounds(tensorType.getRank(), /*Value=*/0); + SmallVector steps(tensorType.getRank(), /*Value=*/1); + affine::buildAffineLoopNest( + rewriter, loc, lowerBounds, tensorType.getShape(), steps, + [&](OpBuilder &nestedBuilder, Location loc, ValueRange ivs) { + // Call the processing function with the rewriter, the memref operands, + // and the loop induction variables. This function will return the value + // to store at the current index. + Value valueToStore = processIteration(nestedBuilder, operands, ivs); + nestedBuilder.create(loc, valueToStore, alloc, + ivs); + }); + + // Replace this operation with the generated alloc. + rewriter.replaceOp(op, alloc); +} + +namespace { +//===----------------------------------------------------------------------===// +// ToyToAffine RewritePatterns: Binary operations +//===----------------------------------------------------------------------===// + +template +struct BinaryOpLowering : public ConversionPattern { + BinaryOpLowering(MLIRContext *ctx) + : ConversionPattern(BinaryOp::getOperationName(), 1, ctx) {} + + LogicalResult + matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const final { + auto loc = op->getLoc(); + lowerOpToLoops(op, operands, rewriter, + [loc](OpBuilder &builder, ValueRange memRefOperands, + ValueRange loopIvs) { + // Generate an adaptor for the remapped operands of the + // BinaryOp. This allows for using the nice named accessors + // that are generated by the ODS. + typename BinaryOp::Adaptor binaryAdaptor(memRefOperands); + + // Generate loads for the element of 'lhs' and 'rhs' at the + // inner loop. + auto loadedLhs = builder.create( + loc, binaryAdaptor.getLhs(), loopIvs); + auto loadedRhs = builder.create( + loc, binaryAdaptor.getRhs(), loopIvs); + + // Create the binary operation performed on the loaded + // values. + return builder.create(loc, loadedLhs, + loadedRhs); + }); + return success(); + } +}; +using AddOpLowering = BinaryOpLowering; +using MulOpLowering = BinaryOpLowering; + +//===----------------------------------------------------------------------===// +// ToyToAffine RewritePatterns: Constant operations +//===----------------------------------------------------------------------===// + +struct ConstantOpLowering : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(toy::ConstantOp op, + PatternRewriter &rewriter) const final { + DenseElementsAttr constantValue = op.getValue(); + Location loc = op.getLoc(); + + // When lowering the constant operation, we allocate and assign the constant + // values to a corresponding memref allocation. + auto tensorType = llvm::cast(op.getType()); + auto memRefType = convertTensorToMemRef(tensorType); + auto alloc = insertAllocAndDealloc(memRefType, loc, rewriter); + + // We will be generating constant indices up-to the largest dimension. + // Create these constants up-front to avoid large amounts of redundant + // operations. + auto valueShape = memRefType.getShape(); + SmallVector constantIndices; + + if (!valueShape.empty()) { + for (auto i : llvm::seq( + 0, *std::max_element(valueShape.begin(), valueShape.end()))) + constantIndices.push_back( + rewriter.create(loc, i)); + } else { + // This is the case of a tensor of rank 0. + constantIndices.push_back( + rewriter.create(loc, 0)); + } + + // The constant operation represents a multi-dimensional constant, so we + // will need to generate a store for each of the elements. The following + // functor recursively walks the dimensions of the constant shape, + // generating a store when the recursion hits the base case. + SmallVector indices; + auto valueIt = constantValue.value_begin(); + std::function storeElements = [&](uint64_t dimension) { + // The last dimension is the base case of the recursion, at this point + // we store the element at the given index. + if (dimension == valueShape.size()) { + rewriter.create( + loc, rewriter.create(loc, *valueIt++), alloc, + llvm::ArrayRef(indices)); + return; + } + + // Otherwise, iterate over the current dimension and add the indices to + // the list. + for (uint64_t i = 0, e = valueShape[dimension]; i != e; ++i) { + indices.push_back(constantIndices[i]); + storeElements(dimension + 1); + indices.pop_back(); + } + }; + + // Start the element storing recursion from the first dimension. + storeElements(/*dimension=*/0); + + // Replace this operation with the generated alloc. + rewriter.replaceOp(op, alloc); + return success(); + } +}; + +//===----------------------------------------------------------------------===// +// ToyToAffine RewritePatterns: Func operations +//===----------------------------------------------------------------------===// + +struct FuncOpLowering : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(toy::FuncOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const final { + // We only lower the main function as we expect that all other functions + // have been inlined. + if (op.getName() != "main") + return failure(); + + // Verify that the given main has no inputs and results. + if (op.getNumArguments() || op.getFunctionType().getNumResults()) { + return rewriter.notifyMatchFailure(op, [](Diagnostic &diag) { + diag << "expected 'main' to have 0 inputs and 0 results"; + }); + } + + // Create a new non-toy function, with the same region. + auto func = rewriter.create(op.getLoc(), op.getName(), + op.getFunctionType()); + rewriter.inlineRegionBefore(op.getRegion(), func.getBody(), func.end()); + rewriter.eraseOp(op); + return success(); + } +}; + +//===----------------------------------------------------------------------===// +// ToyToAffine RewritePatterns: Print operations +//===----------------------------------------------------------------------===// + +struct PrintOpLowering : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(toy::PrintOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const final { + // We don't lower "toy.print" in this pass, but we need to update its + // operands. + rewriter.modifyOpInPlace(op, + [&] { op->setOperands(adaptor.getOperands()); }); + return success(); + } +}; + +//===----------------------------------------------------------------------===// +// ToyToAffine RewritePatterns: Return operations +//===----------------------------------------------------------------------===// + +struct ReturnOpLowering : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(toy::ReturnOp op, + PatternRewriter &rewriter) const final { + // During this lowering, we expect that all function calls have been + // inlined. + if (op.hasOperand()) + return failure(); + + // We lower "toy.return" directly to "func.return". + rewriter.replaceOpWithNewOp(op); + return success(); + } +}; + +//===----------------------------------------------------------------------===// +// ToyToAffine RewritePatterns: Transpose operations +//===----------------------------------------------------------------------===// + +struct TransposeOpLowering : public ConversionPattern { + TransposeOpLowering(MLIRContext *ctx) + : ConversionPattern(toy::TransposeOp::getOperationName(), 1, ctx) {} + + LogicalResult + matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const final { + auto loc = op->getLoc(); + lowerOpToLoops(op, operands, rewriter, + [loc](OpBuilder &builder, ValueRange memRefOperands, + ValueRange loopIvs) { + // Generate an adaptor for the remapped operands of the + // TransposeOp. This allows for using the nice named + // accessors that are generated by the ODS. + toy::TransposeOpAdaptor transposeAdaptor(memRefOperands); + Value input = transposeAdaptor.getInput(); + + // Transpose the elements by generating a load from the + // reverse indices. + SmallVector reverseIvs(llvm::reverse(loopIvs)); + return builder.create(loc, input, + reverseIvs); + }); + return success(); + } +}; + +} // namespace + +//===----------------------------------------------------------------------===// +// ToyToAffineLoweringPass +//===----------------------------------------------------------------------===// + +/// This is a partial lowering to affine loops of the toy operations that are +/// computationally intensive (like matmul for example...) while keeping the +/// rest of the code in the Toy dialect. +namespace { +struct ToyToAffineLoweringPass + : public PassWrapper> { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ToyToAffineLoweringPass) + + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + } + void runOnOperation() final; +}; +} // namespace + +void ToyToAffineLoweringPass::runOnOperation() { + // The first thing to define is the conversion target. This will define the + // final target for this lowering. + ConversionTarget target(getContext()); + + // We define the specific operations, or dialects, that are legal targets for + // this lowering. In our case, we are lowering to a combination of the + // `Affine`, `Arith`, `Func`, and `MemRef` dialects. + target.addLegalDialect(); + + // We also define the Toy dialect as Illegal so that the conversion will fail + // if any of these operations are *not* converted. Given that we actually want + // a partial lowering, we explicitly mark the Toy operations that don't want + // to lower, `toy.print`, as `legal`. `toy.print` will still need its operands + // to be updated though (as we convert from TensorType to MemRefType), so we + // only treat it as `legal` if its operands are legal. + target.addIllegalDialect(); + target.addDynamicallyLegalOp([](toy::PrintOp op) { + return llvm::none_of(op->getOperandTypes(), + [](Type type) { return llvm::isa(type); }); + }); + + // Now that the conversion target has been defined, we just need to provide + // the set of patterns that will lower the Toy operations. + RewritePatternSet patterns(&getContext()); + patterns.add( + &getContext()); + + // With the target and rewrite patterns defined, we can now attempt the + // conversion. The conversion will signal failure if any of our `illegal` + // operations were not converted successfully. + if (failed( + applyPartialConversion(getOperation(), target, std::move(patterns)))) + signalPassFailure(); +} + +/// Create a pass for lowering operations in the `Affine` and `Std` dialects, +/// for a subset of the Toy IR (e.g. matmul). +std::unique_ptr mlir::toy::createLowerToAffinePass() { + return std::make_unique(); +} diff --git a/Ch5/mlir/MLIRGen.cpp b/Ch5/mlir/MLIRGen.cpp new file mode 100644 index 0000000..6c5474a --- /dev/null +++ b/Ch5/mlir/MLIRGen.cpp @@ -0,0 +1,461 @@ +//===- MLIRGen.cpp - MLIR Generation from a Toy AST -----------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements a simple IR generation targeting MLIR from a Module AST +// for the Toy language. +// +//===----------------------------------------------------------------------===// + +#include "toy/MLIRGen.h" +#include "mlir/IR/Block.h" +#include "mlir/IR/Diagnostics.h" +#include "mlir/IR/Value.h" +#include "mlir/Support/LogicalResult.h" +#include "toy/AST.h" +#include "toy/Dialect.h" + +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/Verifier.h" +#include "toy/Lexer.h" + +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/ScopedHashTable.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/ADT/Twine.h" +#include +#include +#include +#include +#include +#include + +using namespace mlir::toy; +using namespace toy; + +using llvm::ArrayRef; +using llvm::cast; +using llvm::dyn_cast; +using llvm::isa; +using llvm::ScopedHashTableScope; +using llvm::SmallVector; +using llvm::StringRef; +using llvm::Twine; + +namespace { + +/// Implementation of a simple MLIR emission from the Toy AST. +/// +/// This will emit operations that are specific to the Toy language, preserving +/// the semantics of the language and (hopefully) allow to perform accurate +/// analysis and transformation based on these high level semantics. +class MLIRGenImpl { +public: + MLIRGenImpl(mlir::MLIRContext &context) : builder(&context) {} + + /// Public API: convert the AST for a Toy module (source file) to an MLIR + /// Module operation. + mlir::ModuleOp mlirGen(ModuleAST &moduleAST) { + // We create an empty MLIR module and codegen functions one at a time and + // add them to the module. + theModule = mlir::ModuleOp::create(builder.getUnknownLoc()); + + for (FunctionAST &f : moduleAST) + mlirGen(f); + + // Verify the module after we have finished constructing it, this will check + // the structural properties of the IR and invoke any specific verifiers we + // have on the Toy operations. + if (failed(mlir::verify(theModule))) { + theModule.emitError("module verification error"); + return nullptr; + } + + return theModule; + } + +private: + /// A "module" matches a Toy source file: containing a list of functions. + mlir::ModuleOp theModule; + + /// The builder is a helper class to create IR inside a function. The builder + /// is stateful, in particular it keeps an "insertion point": this is where + /// the next operations will be introduced. + mlir::OpBuilder builder; + + /// The symbol table maps a variable name to a value in the current scope. + /// Entering a function creates a new scope, and the function arguments are + /// added to the mapping. When the processing of a function is terminated, the + /// scope is destroyed and the mappings created in this scope are dropped. + llvm::ScopedHashTable symbolTable; + + /// Helper conversion for a Toy AST location to an MLIR location. + mlir::Location loc(const Location &loc) { + return mlir::FileLineColLoc::get(builder.getStringAttr(*loc.file), loc.line, + loc.col); + } + + /// Declare a variable in the current scope, return success if the variable + /// wasn't declared yet. + mlir::LogicalResult declare(llvm::StringRef var, mlir::Value value) { + if (symbolTable.count(var)) + return mlir::failure(); + symbolTable.insert(var, value); + return mlir::success(); + } + + /// Create the prototype for an MLIR function with as many arguments as the + /// provided Toy AST prototype. + mlir::toy::FuncOp mlirGen(PrototypeAST &proto) { + auto location = loc(proto.loc()); + + // This is a generic function, the return type will be inferred later. + // Arguments type are uniformly unranked tensors. + llvm::SmallVector argTypes(proto.getArgs().size(), + getType(VarType{})); + auto funcType = builder.getFunctionType(argTypes, std::nullopt); + return builder.create(location, proto.getName(), + funcType); + } + + /// Emit a new function and add it to the MLIR module. + mlir::toy::FuncOp mlirGen(FunctionAST &funcAST) { + // Create a scope in the symbol table to hold variable declarations. + ScopedHashTableScope varScope(symbolTable); + + // Create an MLIR function for the given prototype. + builder.setInsertionPointToEnd(theModule.getBody()); + mlir::toy::FuncOp function = mlirGen(*funcAST.getProto()); + if (!function) + return nullptr; + + // Let's start the body of the function now! + mlir::Block &entryBlock = function.front(); + auto protoArgs = funcAST.getProto()->getArgs(); + + // Declare all the function arguments in the symbol table. + for (const auto nameValue : + llvm::zip(protoArgs, entryBlock.getArguments())) { + if (failed(declare(std::get<0>(nameValue)->getName(), + std::get<1>(nameValue)))) + return nullptr; + } + + // Set the insertion point in the builder to the beginning of the function + // body, it will be used throughout the codegen to create operations in this + // function. + builder.setInsertionPointToStart(&entryBlock); + + // Emit the body of the function. + if (mlir::failed(mlirGen(*funcAST.getBody()))) { + function.erase(); + return nullptr; + } + + // Implicitly return void if no return statement was emitted. + // FIXME: we may fix the parser instead to always return the last expression + // (this would possibly help the REPL case later) + ReturnOp returnOp; + if (!entryBlock.empty()) + returnOp = dyn_cast(entryBlock.back()); + if (!returnOp) { + builder.create(loc(funcAST.getProto()->loc())); + } else if (returnOp.hasOperand()) { + // Otherwise, if this return operation has an operand then add a result to + // the function. + function.setType(builder.getFunctionType( + function.getFunctionType().getInputs(), getType(VarType{}))); + } + + // If this function isn't main, then set the visibility to private. + if (funcAST.getProto()->getName() != "main") + function.setPrivate(); + + return function; + } + + /// Emit a binary operation + mlir::Value mlirGen(BinaryExprAST &binop) { + // First emit the operations for each side of the operation before emitting + // the operation itself. For example if the expression is `a + foo(a)` + // 1) First it will visiting the LHS, which will return a reference to the + // value holding `a`. This value should have been emitted at declaration + // time and registered in the symbol table, so nothing would be + // codegen'd. If the value is not in the symbol table, an error has been + // emitted and nullptr is returned. + // 2) Then the RHS is visited (recursively) and a call to `foo` is emitted + // and the result value is returned. If an error occurs we get a nullptr + // and propagate. + // + mlir::Value lhs = mlirGen(*binop.getLHS()); + if (!lhs) + return nullptr; + mlir::Value rhs = mlirGen(*binop.getRHS()); + if (!rhs) + return nullptr; + auto location = loc(binop.loc()); + + // Derive the operation name from the binary operator. At the moment we only + // support '+' and '*'. + switch (binop.getOp()) { + case '+': + return builder.create(location, lhs, rhs); + case '*': + return builder.create(location, lhs, rhs); + } + + emitError(location, "invalid binary operator '") << binop.getOp() << "'"; + return nullptr; + } + + /// This is a reference to a variable in an expression. The variable is + /// expected to have been declared and so should have a value in the symbol + /// table, otherwise emit an error and return nullptr. + mlir::Value mlirGen(VariableExprAST &expr) { + if (auto variable = symbolTable.lookup(expr.getName())) + return variable; + + emitError(loc(expr.loc()), "error: unknown variable '") + << expr.getName() << "'"; + return nullptr; + } + + /// Emit a return operation. This will return failure if any generation fails. + mlir::LogicalResult mlirGen(ReturnExprAST &ret) { + auto location = loc(ret.loc()); + + // 'return' takes an optional expression, handle that case here. + mlir::Value expr = nullptr; + if (ret.getExpr().has_value()) { + if (!(expr = mlirGen(**ret.getExpr()))) + return mlir::failure(); + } + + // Otherwise, this return operation has zero operands. + builder.create(location, + expr ? ArrayRef(expr) : ArrayRef()); + return mlir::success(); + } + + /// Emit a literal/constant array. It will be emitted as a flattened array of + /// data in an Attribute attached to a `toy.constant` operation. + /// See documentation on [Attributes](LangRef.md#attributes) for more details. + /// Here is an excerpt: + /// + /// Attributes are the mechanism for specifying constant data in MLIR in + /// places where a variable is never allowed [...]. They consist of a name + /// and a concrete attribute value. The set of expected attributes, their + /// structure, and their interpretation are all contextually dependent on + /// what they are attached to. + /// + /// Example, the source level statement: + /// var a<2, 3> = [[1, 2, 3], [4, 5, 6]]; + /// will be converted to: + /// %0 = "toy.constant"() {value: dense, + /// [[1.000000e+00, 2.000000e+00, 3.000000e+00], + /// [4.000000e+00, 5.000000e+00, 6.000000e+00]]>} : () -> tensor<2x3xf64> + /// + mlir::Value mlirGen(LiteralExprAST &lit) { + auto type = getType(lit.getDims()); + + // The attribute is a vector with a floating point value per element + // (number) in the array, see `collectData()` below for more details. + std::vector data; + data.reserve(std::accumulate(lit.getDims().begin(), lit.getDims().end(), 1, + std::multiplies())); + collectData(lit, data); + + // The type of this attribute is tensor of 64-bit floating-point with the + // shape of the literal. + mlir::Type elementType = builder.getF64Type(); + auto dataType = mlir::RankedTensorType::get(lit.getDims(), elementType); + + // This is the actual attribute that holds the list of values for this + // tensor literal. + auto dataAttribute = + mlir::DenseElementsAttr::get(dataType, llvm::ArrayRef(data)); + + // Build the MLIR op `toy.constant`. This invokes the `ConstantOp::build` + // method. + return builder.create(loc(lit.loc()), type, dataAttribute); + } + + /// Recursive helper function to accumulate the data that compose an array + /// literal. It flattens the nested structure in the supplied vector. For + /// example with this array: + /// [[1, 2], [3, 4]] + /// we will generate: + /// [ 1, 2, 3, 4 ] + /// Individual numbers are represented as doubles. + /// Attributes are the way MLIR attaches constant to operations. + void collectData(ExprAST &expr, std::vector &data) { + if (auto *lit = dyn_cast(&expr)) { + for (auto &value : lit->getValues()) + collectData(*value, data); + return; + } + + assert(isa(expr) && "expected literal or number expr"); + data.push_back(cast(expr).getValue()); + } + + /// Emit a call expression. It emits specific operations for the `transpose` + /// builtin. Other identifiers are assumed to be user-defined functions. + mlir::Value mlirGen(CallExprAST &call) { + llvm::StringRef callee = call.getCallee(); + auto location = loc(call.loc()); + + // Codegen the operands first. + SmallVector operands; + for (auto &expr : call.getArgs()) { + auto arg = mlirGen(*expr); + if (!arg) + return nullptr; + operands.push_back(arg); + } + + // Builtin calls have their custom operation, meaning this is a + // straightforward emission. + if (callee == "transpose") { + if (call.getArgs().size() != 1) { + emitError(location, "MLIR codegen encountered an error: toy.transpose " + "does not accept multiple arguments"); + return nullptr; + } + return builder.create(location, operands[0]); + } + + // Otherwise this is a call to a user-defined function. Calls to + // user-defined functions are mapped to a custom call that takes the callee + // name as an attribute. + return builder.create(location, callee, operands); + } + + /// Emit a print expression. It emits specific operations for two builtins: + /// transpose(x) and print(x). + mlir::LogicalResult mlirGen(PrintExprAST &call) { + auto arg = mlirGen(*call.getArg()); + if (!arg) + return mlir::failure(); + + builder.create(loc(call.loc()), arg); + return mlir::success(); + } + + /// Emit a constant for a single number (FIXME: semantic? broadcast?) + mlir::Value mlirGen(NumberExprAST &num) { + return builder.create(loc(num.loc()), num.getValue()); + } + + /// Dispatch codegen for the right expression subclass using RTTI. + mlir::Value mlirGen(ExprAST &expr) { + switch (expr.getKind()) { + case toy::ExprAST::Expr_BinOp: + return mlirGen(cast(expr)); + case toy::ExprAST::Expr_Var: + return mlirGen(cast(expr)); + case toy::ExprAST::Expr_Literal: + return mlirGen(cast(expr)); + case toy::ExprAST::Expr_Call: + return mlirGen(cast(expr)); + case toy::ExprAST::Expr_Num: + return mlirGen(cast(expr)); + default: + emitError(loc(expr.loc())) + << "MLIR codegen encountered an unhandled expr kind '" + << Twine(expr.getKind()) << "'"; + return nullptr; + } + } + + /// Handle a variable declaration, we'll codegen the expression that forms the + /// initializer and record the value in the symbol table before returning it. + /// Future expressions will be able to reference this variable through symbol + /// table lookup. + mlir::Value mlirGen(VarDeclExprAST &vardecl) { + auto *init = vardecl.getInitVal(); + if (!init) { + emitError(loc(vardecl.loc()), + "missing initializer in variable declaration"); + return nullptr; + } + + mlir::Value value = mlirGen(*init); + if (!value) + return nullptr; + + // We have the initializer value, but in case the variable was declared + // with specific shape, we emit a "reshape" operation. It will get + // optimized out later as needed. + if (!vardecl.getType().shape.empty()) { + value = builder.create(loc(vardecl.loc()), + getType(vardecl.getType()), value); + } + + // Register the value in the symbol table. + if (failed(declare(vardecl.getName(), value))) + return nullptr; + return value; + } + + /// Codegen a list of expression, return failure if one of them hit an error. + mlir::LogicalResult mlirGen(ExprASTList &blockAST) { + ScopedHashTableScope varScope(symbolTable); + for (auto &expr : blockAST) { + // Specific handling for variable declarations, return statement, and + // print. These can only appear in block list and not in nested + // expressions. + if (auto *vardecl = dyn_cast(expr.get())) { + if (!mlirGen(*vardecl)) + return mlir::failure(); + continue; + } + if (auto *ret = dyn_cast(expr.get())) + return mlirGen(*ret); + if (auto *print = dyn_cast(expr.get())) { + if (mlir::failed(mlirGen(*print))) + return mlir::success(); + continue; + } + + // Generic expression dispatch codegen. + if (!mlirGen(*expr)) + return mlir::failure(); + } + return mlir::success(); + } + + /// Build a tensor type from a list of shape dimensions. + mlir::Type getType(ArrayRef shape) { + // If the shape is empty, then this type is unranked. + if (shape.empty()) + return mlir::UnrankedTensorType::get(builder.getF64Type()); + + // Otherwise, we use the given shape. + return mlir::RankedTensorType::get(shape, builder.getF64Type()); + } + + /// Build an MLIR type from a Toy AST variable type (forward to the generic + /// getType above). + mlir::Type getType(const VarType &type) { return getType(type.shape); } +}; + +} // namespace + +namespace toy { + +// The public API for codegen. +mlir::OwningOpRef mlirGen(mlir::MLIRContext &context, + ModuleAST &moduleAST) { + return MLIRGenImpl(context).mlirGen(moduleAST); +} + +} // namespace toy diff --git a/Ch5/mlir/ShapeInferencePass.cpp b/Ch5/mlir/ShapeInferencePass.cpp new file mode 100644 index 0000000..a9e995e --- /dev/null +++ b/Ch5/mlir/ShapeInferencePass.cpp @@ -0,0 +1,122 @@ +//===- ShapeInferencePass.cpp - Shape Inference ---------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements a Function level pass performing interprocedural +// propagation of array shapes through function specialization. +// +//===----------------------------------------------------------------------===// + +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/Types.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/TypeID.h" +#include "toy/Dialect.h" +#include "toy/Passes.h" +#include "toy/ShapeInferenceInterface.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallPtrSet.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/raw_ostream.h" +#include + +#define DEBUG_TYPE "shape-inference" + +using namespace mlir; +using namespace toy; + +/// Include the auto-generated definitions for the shape inference interfaces. +#include "toy/ShapeInferenceOpInterfaces.cpp.inc" + +namespace { +/// The ShapeInferencePass is a pass that performs intra-procedural +/// shape inference. +/// +/// Algorithm: +/// +/// 1) Build a worklist containing all the operations that return a +/// dynamically shaped tensor: these are the operations that need shape +/// inference. +/// 2) Iterate on the worklist: +/// a) find an operation to process: the next ready operation in the +/// worklist has all of its arguments non-generic, +/// b) if no operation is found, break out of the loop, +/// c) remove the operation from the worklist, +/// d) infer the shape of its output from the argument types. +/// 3) If the worklist is empty, the algorithm succeeded. +/// +struct ShapeInferencePass + : public mlir::PassWrapper> { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ShapeInferencePass) + + void runOnOperation() override { + auto f = getOperation(); + + // Populate the worklist with the operations that need shape inference: + // these are operations that return a dynamic shape. + llvm::SmallPtrSet opWorklist; + f.walk([&](mlir::Operation *op) { + if (returnsDynamicShape(op)) + opWorklist.insert(op); + }); + + // Iterate on the operations in the worklist until all operations have been + // inferred or no change happened (fix point). + while (!opWorklist.empty()) { + // Find the next operation ready for inference, that is an operation + // with all operands already resolved (non-generic). + auto nextop = llvm::find_if(opWorklist, allOperandsInferred); + if (nextop == opWorklist.end()) + break; + + Operation *op = *nextop; + opWorklist.erase(op); + + // Ask the operation to infer its output shapes. + LLVM_DEBUG(llvm::dbgs() << "Inferring shape for: " << *op << "\n"); + if (auto shapeOp = dyn_cast(op)) { + shapeOp.inferShapes(); + } else { + op->emitError("unable to infer shape of operation without shape " + "inference interface"); + return signalPassFailure(); + } + } + + // If the operation worklist isn't empty, this indicates a failure. + if (!opWorklist.empty()) { + f.emitError("Shape inference failed, ") + << opWorklist.size() << " operations couldn't be inferred\n"; + signalPassFailure(); + } + } + + /// A utility method that returns if the given operation has all of its + /// operands inferred. + static bool allOperandsInferred(Operation *op) { + return llvm::all_of(op->getOperandTypes(), [](Type operandType) { + return llvm::isa(operandType); + }); + } + + /// A utility method that returns if the given operation has a dynamically + /// shaped result. + static bool returnsDynamicShape(Operation *op) { + return llvm::any_of(op->getResultTypes(), [](Type resultType) { + return !llvm::isa(resultType); + }); + } +}; +} // namespace + +/// Create a Shape Inference pass. +std::unique_ptr mlir::toy::createShapeInferencePass() { + return std::make_unique(); +} diff --git a/Ch5/mlir/ToyCombine.cpp b/Ch5/mlir/ToyCombine.cpp new file mode 100644 index 0000000..3ce35c8 --- /dev/null +++ b/Ch5/mlir/ToyCombine.cpp @@ -0,0 +1,69 @@ +//===- ToyCombine.cpp - Toy High Level Optimizer --------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements a set of simple combiners for optimizing operations in +// the Toy dialect. +// +//===----------------------------------------------------------------------===// + +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/Value.h" +#include "mlir/Support/LogicalResult.h" +#include "toy/Dialect.h" +using namespace mlir; +using namespace toy; + +namespace { +/// Include the patterns defined in the Declarative Rewrite framework. +#include "ToyCombine.inc" +} // namespace + +/// This is an example of a c++ rewrite pattern for the TransposeOp. It +/// optimizes the following scenario: transpose(transpose(x)) -> x +struct SimplifyRedundantTranspose : public mlir::OpRewritePattern { + /// We register this pattern to match every toy.transpose in the IR. + /// The "benefit" is used by the framework to order the patterns and process + /// them in order of profitability. + SimplifyRedundantTranspose(mlir::MLIRContext *context) + : OpRewritePattern(context, /*benefit=*/1) {} + + /// This method attempts to match a pattern and rewrite it. The rewriter + /// argument is the orchestrator of the sequence of rewrites. The pattern is + /// expected to interact with it to perform any changes to the IR from here. + mlir::LogicalResult + matchAndRewrite(TransposeOp op, + mlir::PatternRewriter &rewriter) const override { + // Look through the input of the current transpose. + mlir::Value transposeInput = op.getOperand(); + TransposeOp transposeInputOp = transposeInput.getDefiningOp(); + + // Input defined by another transpose? If not, no match. + if (!transposeInputOp) + return failure(); + + // Otherwise, we have a redundant transpose. Use the rewriter. + rewriter.replaceOp(op, {transposeInputOp.getOperand()}); + return success(); + } +}; + +/// Register our patterns as "canonicalization" patterns on the TransposeOp so +/// that they can be picked up by the Canonicalization framework. +void TransposeOp::getCanonicalizationPatterns(RewritePatternSet &results, + MLIRContext *context) { + results.add(context); +} + +/// Register our patterns as "canonicalization" patterns on the ReshapeOp so +/// that they can be picked up by the Canonicalization framework. +void ReshapeOp::getCanonicalizationPatterns(RewritePatternSet &results, + MLIRContext *context) { + results.add(context); +} diff --git a/Ch5/mlir/ToyCombine.inc b/Ch5/mlir/ToyCombine.inc new file mode 100644 index 0000000..61c6203 --- /dev/null +++ b/Ch5/mlir/ToyCombine.inc @@ -0,0 +1,176 @@ +/*===- TableGen'erated file -------------------------------------*- C++ -*-===*\ +|* *| +|* Rewriters *| +|* *| +|* Automatically generated file, do not edit! *| +|* From: ToyCombine.td *| +|* *| +\*===----------------------------------------------------------------------===*/ + +/* Generated from: + ToyCombine.td:46 +*/ +struct FoldConstantReshapeOptPattern : public ::mlir::RewritePattern { + FoldConstantReshapeOptPattern(::mlir::MLIRContext *context) + : ::mlir::RewritePattern("toy.reshape", 2, context, {"toy.constant"}) {} + ::mlir::LogicalResult matchAndRewrite(::mlir::Operation *op0, + ::mlir::PatternRewriter &rewriter) const override { + // Variables for capturing values and attributes used while creating ops + ::mlir::DenseElementsAttr arg; + ::mlir::toy::ReshapeOp res; + ::llvm::SmallVector<::mlir::Operation *, 4> tblgen_ops; + + // Match + tblgen_ops.push_back(op0); + auto castedOp0 = ::llvm::dyn_cast<::mlir::toy::ReshapeOp>(op0); (void)castedOp0; + res = castedOp0; + { + auto *op1 = (*castedOp0.getODSOperands(0).begin()).getDefiningOp(); + if (!(op1)){ + return rewriter.notifyMatchFailure(castedOp0, [&](::mlir::Diagnostic &diag) { + diag << "There's no operation that defines operand 0 of castedOp0"; + }); + } + auto castedOp1 = ::llvm::dyn_cast<::mlir::toy::ConstantOp>(op1); (void)castedOp1; + if (!(castedOp1)){ + return rewriter.notifyMatchFailure(op1, [&](::mlir::Diagnostic &diag) { + diag << "castedOp1 is not ::mlir::toy::ConstantOp type"; + }); + } + { + auto tblgen_attr = op1->getAttrOfType<::mlir::DenseElementsAttr>("value");(void)tblgen_attr; + if (!(tblgen_attr)){ + return rewriter.notifyMatchFailure(op1, [&](::mlir::Diagnostic &diag) { + diag << "expected op 'toy.constant' to have attribute 'value' of type '::mlir::DenseElementsAttr'"; + }); + } + arg = tblgen_attr; + } + tblgen_ops.push_back(op1); + } + + // Rewrite + auto odsLoc = rewriter.getFusedLoc({tblgen_ops[0]->getLoc(), tblgen_ops[1]->getLoc()}); (void)odsLoc; + ::llvm::SmallVector<::mlir::Value, 4> tblgen_repl_values; + auto nativeVar_0 = arg.reshape(::llvm::cast((*res.getODSResults(0).begin()).getType())); (void)nativeVar_0; + ::mlir::toy::ConstantOp tblgen_ConstantOp_1; + { + ::llvm::SmallVector<::mlir::Value, 4> tblgen_values; (void)tblgen_values; + ::llvm::SmallVector<::mlir::NamedAttribute, 4> tblgen_attrs; (void)tblgen_attrs; + if (auto tmpAttr = nativeVar_0) { + tblgen_attrs.emplace_back(rewriter.getStringAttr("value"), tmpAttr); + } + ::llvm::SmallVector<::mlir::Type, 4> tblgen_types; (void)tblgen_types; + for (auto v: castedOp0.getODSResults(0)) { + tblgen_types.push_back(v.getType()); + } + tblgen_ConstantOp_1 = rewriter.create<::mlir::toy::ConstantOp>(odsLoc, tblgen_types, tblgen_values, tblgen_attrs); + } + + for (auto v: ::llvm::SmallVector<::mlir::Value, 4>{ tblgen_ConstantOp_1.getODSResults(0) }) { + tblgen_repl_values.push_back(v); + } + + rewriter.replaceOp(op0, tblgen_repl_values); + return ::mlir::success(); + }; +}; + +/* Generated from: + ToyCombine.td:59 +*/ +struct RedundantReshapeOptPattern : public ::mlir::RewritePattern { + RedundantReshapeOptPattern(::mlir::MLIRContext *context) + : ::mlir::RewritePattern("toy.reshape", 1, context, {}) {} + ::mlir::LogicalResult matchAndRewrite(::mlir::Operation *op0, + ::mlir::PatternRewriter &rewriter) const override { + // Variables for capturing values and attributes used while creating ops + ::mlir::Operation::operand_range arg(op0->getOperands()); + ::mlir::toy::ReshapeOp res; + ::llvm::SmallVector<::mlir::Operation *, 4> tblgen_ops; + + // Match + tblgen_ops.push_back(op0); + auto castedOp0 = ::llvm::dyn_cast<::mlir::toy::ReshapeOp>(op0); (void)castedOp0; + res = castedOp0; + arg = castedOp0.getODSOperands(0); + if (!(((*res.getODSResults(0).begin()).getType() == (*arg.begin()).getType()))){ + return rewriter.notifyMatchFailure(op0, [&](::mlir::Diagnostic &diag) { + diag << "entities 'res, arg' failed to satisfy constraint: ''"; + }); + } + + // Rewrite + auto odsLoc = rewriter.getFusedLoc({tblgen_ops[0]->getLoc()}); (void)odsLoc; + ::llvm::SmallVector<::mlir::Value, 4> tblgen_repl_values; + + for (auto v: ::llvm::SmallVector<::mlir::Value, 4>{ arg }) { + tblgen_repl_values.push_back(v); + } + + rewriter.replaceOp(op0, tblgen_repl_values); + return ::mlir::success(); + }; +}; + +/* Generated from: + ToyCombine.td:33 +*/ +struct ReshapeReshapeOptPattern : public ::mlir::RewritePattern { + ReshapeReshapeOptPattern(::mlir::MLIRContext *context) + : ::mlir::RewritePattern("toy.reshape", 2, context, {"toy.reshape"}) {} + ::mlir::LogicalResult matchAndRewrite(::mlir::Operation *op0, + ::mlir::PatternRewriter &rewriter) const override { + // Variables for capturing values and attributes used while creating ops + ::mlir::Operation::operand_range arg(op0->getOperands()); + ::llvm::SmallVector<::mlir::Operation *, 4> tblgen_ops; + + // Match + tblgen_ops.push_back(op0); + auto castedOp0 = ::llvm::dyn_cast<::mlir::toy::ReshapeOp>(op0); (void)castedOp0; + { + auto *op1 = (*castedOp0.getODSOperands(0).begin()).getDefiningOp(); + if (!(op1)){ + return rewriter.notifyMatchFailure(castedOp0, [&](::mlir::Diagnostic &diag) { + diag << "There's no operation that defines operand 0 of castedOp0"; + }); + } + auto castedOp1 = ::llvm::dyn_cast<::mlir::toy::ReshapeOp>(op1); (void)castedOp1; + if (!(castedOp1)){ + return rewriter.notifyMatchFailure(op1, [&](::mlir::Diagnostic &diag) { + diag << "castedOp1 is not ::mlir::toy::ReshapeOp type"; + }); + } + arg = castedOp1.getODSOperands(0); + tblgen_ops.push_back(op1); + } + + // Rewrite + auto odsLoc = rewriter.getFusedLoc({tblgen_ops[0]->getLoc(), tblgen_ops[1]->getLoc()}); (void)odsLoc; + ::llvm::SmallVector<::mlir::Value, 4> tblgen_repl_values; + ::mlir::toy::ReshapeOp tblgen_ReshapeOp_0; + { + ::llvm::SmallVector<::mlir::Value, 4> tblgen_values; (void)tblgen_values; + ::llvm::SmallVector<::mlir::NamedAttribute, 4> tblgen_attrs; (void)tblgen_attrs; + tblgen_values.push_back((*arg.begin())); + ::llvm::SmallVector<::mlir::Type, 4> tblgen_types; (void)tblgen_types; + for (auto v: castedOp0.getODSResults(0)) { + tblgen_types.push_back(v.getType()); + } + tblgen_ReshapeOp_0 = rewriter.create<::mlir::toy::ReshapeOp>(odsLoc, tblgen_types, tblgen_values, tblgen_attrs); + } + + for (auto v: ::llvm::SmallVector<::mlir::Value, 4>{ tblgen_ReshapeOp_0.getODSResults(0) }) { + tblgen_repl_values.push_back(v); + } + + rewriter.replaceOp(op0, tblgen_repl_values); + return ::mlir::success(); + }; +}; + +void LLVM_ATTRIBUTE_UNUSED populateWithGenerated(::mlir::RewritePatternSet &patterns) { + patterns.add(patterns.getContext()); + patterns.add(patterns.getContext()); + patterns.add(patterns.getContext()); +} diff --git a/Ch5/mlir/ToyCombine.td b/Ch5/mlir/ToyCombine.td new file mode 100644 index 0000000..11d7831 --- /dev/null +++ b/Ch5/mlir/ToyCombine.td @@ -0,0 +1,63 @@ +//===- ToyCombine.td - Pattern Match Optimizations for Toy -*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Defines language-specific pattern match optimizations for Toy using +// Declarative Rewrite Rules (DRR) specified using TableGen records. +// +//===----------------------------------------------------------------------===// + +#ifndef TOY_COMBINE +#define TOY_COMBINE + +include "mlir/IR/PatternBase.td" +include "toy/Ops.td" + +/// Note: The DRR definition used for defining patterns is shown below: +/// +/// class Pattern< +/// dag sourcePattern, list resultPatterns, +/// list additionalConstraints = [], +/// dag benefitsAdded = (addBenefit 0) +/// >; + +//===----------------------------------------------------------------------===// +// Basic Pattern-Match and Rewrite +//===----------------------------------------------------------------------===// + +// Reshape(Reshape(x)) = Reshape(x) +def ReshapeReshapeOptPattern : Pat<(ReshapeOp(ReshapeOp $arg)), + (ReshapeOp $arg)>; + +//===----------------------------------------------------------------------===// +// Pattern-Match and Rewrite using Native Code Call +//===----------------------------------------------------------------------===// + +// Native Code Calls may be used for more complex transformations using inline +// C++ and C++ helper functions. + +// Reshape(Constant(x)) = x' +def ReshapeConstant : + NativeCodeCall<"$0.reshape(::llvm::cast($1.getType()))">; +def FoldConstantReshapeOptPattern : Pat< + (ReshapeOp:$res (ConstantOp $arg)), + (ConstantOp (ReshapeConstant $arg, $res))>; + +//===----------------------------------------------------------------------===// +// Pattern-Match and Rewrite with Constraints +//===----------------------------------------------------------------------===// + +// DRR allows for constraint checking when the transformation is conditional +// on operand properties. + +// Reshape(x) = x, where input and output shapes are identical +def TypesAreIdentical : Constraint>; +def RedundantReshapeOptPattern : Pat< + (ReshapeOp:$res $arg), (replaceWithValue $arg), + [(TypesAreIdentical $res, $arg)]>; + +#endif // TOY_COMBINE diff --git a/Ch5/mlir/run.sh b/Ch5/mlir/run.sh new file mode 100644 index 0000000..f592fde --- /dev/null +++ b/Ch5/mlir/run.sh @@ -0,0 +1,2 @@ +mlir-tblgen-18 -gen-rewriters -I /usr/lib/llvm-18/include -I ../include ToyCombine.td > ToyCombine.inc + diff --git a/Ch5/parser/AST.cpp b/Ch5/parser/AST.cpp new file mode 100644 index 0000000..2546f2a --- /dev/null +++ b/Ch5/parser/AST.cpp @@ -0,0 +1,237 @@ +//===- AST.cpp - Helper for printing out the Toy AST ----------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements the AST dump for the Toy language. +// +//===----------------------------------------------------------------------===// + +#include "toy/AST.h" + +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/Twine.h" +#include "llvm/ADT/TypeSwitch.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/raw_ostream.h" +#include + +using namespace toy; + +namespace { + +// RAII helper to manage increasing/decreasing the indentation as we traverse +// the AST +struct Indent { + Indent(int &level) : level(level) { ++level; } + ~Indent() { --level; } + int &level; +}; + +/// Helper class that implement the AST tree traversal and print the nodes along +/// the way. The only data member is the current indentation level. +class ASTDumper { +public: + void dump(ModuleAST *node); + +private: + void dump(const VarType &type); + void dump(VarDeclExprAST *varDecl); + void dump(ExprAST *expr); + void dump(ExprASTList *exprList); + void dump(NumberExprAST *num); + void dump(LiteralExprAST *node); + void dump(VariableExprAST *node); + void dump(ReturnExprAST *node); + void dump(BinaryExprAST *node); + void dump(CallExprAST *node); + void dump(PrintExprAST *node); + void dump(PrototypeAST *node); + void dump(FunctionAST *node); + + // Actually print spaces matching the current indentation level + void indent() { + for (int i = 0; i < curIndent; i++) + llvm::errs() << " "; + } + int curIndent = 0; +}; + +} // namespace + +/// Return a formatted string for the location of any node +template +static std::string loc(T *node) { + const auto &loc = node->loc(); + return (llvm::Twine("@") + *loc.file + ":" + llvm::Twine(loc.line) + ":" + + llvm::Twine(loc.col)) + .str(); +} + +// Helper Macro to bump the indentation level and print the leading spaces for +// the current indentations +#define INDENT() \ + Indent level_(curIndent); \ + indent(); + +/// Dispatch to a generic expressions to the appropriate subclass using RTTI +void ASTDumper::dump(ExprAST *expr) { + llvm::TypeSwitch(expr) + .Case( + [&](auto *node) { this->dump(node); }) + .Default([&](ExprAST *) { + // No match, fallback to a generic message + INDENT(); + llvm::errs() << "getKind() << ">\n"; + }); +} + +/// A variable declaration is printing the variable name, the type, and then +/// recurse in the initializer value. +void ASTDumper::dump(VarDeclExprAST *varDecl) { + INDENT(); + llvm::errs() << "VarDecl " << varDecl->getName(); + dump(varDecl->getType()); + llvm::errs() << " " << loc(varDecl) << "\n"; + dump(varDecl->getInitVal()); +} + +/// A "block", or a list of expression +void ASTDumper::dump(ExprASTList *exprList) { + INDENT(); + llvm::errs() << "Block {\n"; + for (auto &expr : *exprList) + dump(expr.get()); + indent(); + llvm::errs() << "} // Block\n"; +} + +/// A literal number, just print the value. +void ASTDumper::dump(NumberExprAST *num) { + INDENT(); + llvm::errs() << num->getValue() << " " << loc(num) << "\n"; +} + +/// Helper to print recursively a literal. This handles nested array like: +/// [ [ 1, 2 ], [ 3, 4 ] ] +/// We print out such array with the dimensions spelled out at every level: +/// <2,2>[<2>[ 1, 2 ], <2>[ 3, 4 ] ] +void printLitHelper(ExprAST *litOrNum) { + // Inside a literal expression we can have either a number or another literal + if (auto *num = llvm::dyn_cast(litOrNum)) { + llvm::errs() << num->getValue(); + return; + } + auto *literal = llvm::cast(litOrNum); + + // Print the dimension for this literal first + llvm::errs() << "<"; + llvm::interleaveComma(literal->getDims(), llvm::errs()); + llvm::errs() << ">"; + + // Now print the content, recursing on every element of the list + llvm::errs() << "[ "; + llvm::interleaveComma(literal->getValues(), llvm::errs(), + [&](auto &elt) { printLitHelper(elt.get()); }); + llvm::errs() << "]"; +} + +/// Print a literal, see the recursive helper above for the implementation. +void ASTDumper::dump(LiteralExprAST *node) { + INDENT(); + llvm::errs() << "Literal: "; + printLitHelper(node); + llvm::errs() << " " << loc(node) << "\n"; +} + +/// Print a variable reference (just a name). +void ASTDumper::dump(VariableExprAST *node) { + INDENT(); + llvm::errs() << "var: " << node->getName() << " " << loc(node) << "\n"; +} + +/// Return statement print the return and its (optional) argument. +void ASTDumper::dump(ReturnExprAST *node) { + INDENT(); + llvm::errs() << "Return\n"; + if (node->getExpr().has_value()) + return dump(*node->getExpr()); + { + INDENT(); + llvm::errs() << "(void)\n"; + } +} + +/// Print a binary operation, first the operator, then recurse into LHS and RHS. +void ASTDumper::dump(BinaryExprAST *node) { + INDENT(); + llvm::errs() << "BinOp: " << node->getOp() << " " << loc(node) << "\n"; + dump(node->getLHS()); + dump(node->getRHS()); +} + +/// Print a call expression, first the callee name and the list of args by +/// recursing into each individual argument. +void ASTDumper::dump(CallExprAST *node) { + INDENT(); + llvm::errs() << "Call '" << node->getCallee() << "' [ " << loc(node) << "\n"; + for (auto &arg : node->getArgs()) + dump(arg.get()); + indent(); + llvm::errs() << "]\n"; +} + +/// Print a builtin print call, first the builtin name and then the argument. +void ASTDumper::dump(PrintExprAST *node) { + INDENT(); + llvm::errs() << "Print [ " << loc(node) << "\n"; + dump(node->getArg()); + indent(); + llvm::errs() << "]\n"; +} + +/// Print type: only the shape is printed in between '<' and '>' +void ASTDumper::dump(const VarType &type) { + llvm::errs() << "<"; + llvm::interleaveComma(type.shape, llvm::errs()); + llvm::errs() << ">"; +} + +/// Print a function prototype, first the function name, and then the list of +/// parameters names. +void ASTDumper::dump(PrototypeAST *node) { + INDENT(); + llvm::errs() << "Proto '" << node->getName() << "' " << loc(node) << "\n"; + indent(); + llvm::errs() << "Params: ["; + llvm::interleaveComma(node->getArgs(), llvm::errs(), + [](auto &arg) { llvm::errs() << arg->getName(); }); + llvm::errs() << "]\n"; +} + +/// Print a function, first the prototype and then the body. +void ASTDumper::dump(FunctionAST *node) { + INDENT(); + llvm::errs() << "Function \n"; + dump(node->getProto()); + dump(node->getBody()); +} + +/// Print a module, actually loop over the functions and print them in sequence. +void ASTDumper::dump(ModuleAST *node) { + INDENT(); + llvm::errs() << "Module:\n"; + for (auto &f : *node) + dump(&f); +} + +namespace toy { + +// Public API +void dump(ModuleAST &module) { ASTDumper().dump(&module); } + +} // namespace toy diff --git a/Ch5/toyc.cpp b/Ch5/toyc.cpp new file mode 100644 index 0000000..4eb6fde --- /dev/null +++ b/Ch5/toyc.cpp @@ -0,0 +1,207 @@ +//===- toyc.cpp - The Toy Compiler ----------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements the entry point for the Toy compiler. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Func/Extensions/AllExtensions.h" +#include "mlir/IR/Diagnostics.h" +#include "mlir/Support/LogicalResult.h" +#include "toy/AST.h" +#include "toy/Dialect.h" +#include "toy/Lexer.h" +#include "toy/MLIRGen.h" +#include "toy/Parser.h" +#include "toy/Passes.h" + +#include "mlir/Dialect/Affine/Passes.h" +#include "mlir/IR/AsmState.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/Verifier.h" +#include "mlir/InitAllDialects.h" +#include "mlir/Parser/Parser.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Transforms/Passes.h" + +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/ErrorOr.h" +#include "llvm/Support/MemoryBuffer.h" +#include "llvm/Support/SourceMgr.h" +#include "llvm/Support/raw_ostream.h" +#include +#include +#include +#include + +using namespace toy; +namespace cl = llvm::cl; + +static cl::opt inputFilename(cl::Positional, + cl::desc(""), + cl::init("-"), + cl::value_desc("filename")); + +namespace { +enum InputType { Toy, MLIR }; +} // namespace +static cl::opt inputType( + "x", cl::init(Toy), cl::desc("Decided the kind of output desired"), + cl::values(clEnumValN(Toy, "toy", "load the input file as a Toy source.")), + cl::values(clEnumValN(MLIR, "mlir", + "load the input file as an MLIR file"))); + +namespace { +enum Action { None, DumpAST, DumpMLIR, DumpMLIRAffine }; +} // namespace +static cl::opt emitAction( + "emit", cl::desc("Select the kind of output desired"), + cl::values(clEnumValN(DumpAST, "ast", "output the AST dump")), + cl::values(clEnumValN(DumpMLIR, "mlir", "output the MLIR dump")), + cl::values(clEnumValN(DumpMLIRAffine, "mlir-affine", + "output the MLIR dump after affine lowering"))); + +static cl::opt enableOpt("opt", cl::desc("Enable optimizations")); + +/// Returns a Toy AST resulting from parsing the file or a nullptr on error. +std::unique_ptr parseInputFile(llvm::StringRef filename) { + llvm::ErrorOr> fileOrErr = + llvm::MemoryBuffer::getFileOrSTDIN(filename); + if (std::error_code ec = fileOrErr.getError()) { + llvm::errs() << "Could not open input file: " << ec.message() << "\n"; + return nullptr; + } + auto buffer = fileOrErr.get()->getBuffer(); + LexerBuffer lexer(buffer.begin(), buffer.end(), std::string(filename)); + Parser parser(lexer); + return parser.parseModule(); +} + +int loadMLIR(llvm::SourceMgr &sourceMgr, mlir::MLIRContext &context, + mlir::OwningOpRef &module) { + // Handle '.toy' input to the compiler. + if (inputType != InputType::MLIR && + !llvm::StringRef(inputFilename).ends_with(".mlir")) { + auto moduleAST = parseInputFile(inputFilename); + if (!moduleAST) + return 6; + module = mlirGen(context, *moduleAST); + return !module ? 1 : 0; + } + + // Otherwise, the input is '.mlir'. + llvm::ErrorOr> fileOrErr = + llvm::MemoryBuffer::getFileOrSTDIN(inputFilename); + if (std::error_code ec = fileOrErr.getError()) { + llvm::errs() << "Could not open input file: " << ec.message() << "\n"; + return -1; + } + + // Parse the input mlir. + sourceMgr.AddNewSourceBuffer(std::move(*fileOrErr), llvm::SMLoc()); + module = mlir::parseSourceFile(sourceMgr, &context); + if (!module) { + llvm::errs() << "Error can't load file " << inputFilename << "\n"; + return 3; + } + return 0; +} + +int dumpMLIR() { + mlir::DialectRegistry registry; + mlir::func::registerAllExtensions(registry); + + mlir::MLIRContext context(registry); + // Load our Dialect in this MLIR Context. + context.getOrLoadDialect(); + + mlir::OwningOpRef module; + llvm::SourceMgr sourceMgr; + mlir::SourceMgrDiagnosticHandler sourceMgrHandler(sourceMgr, &context); + if (int error = loadMLIR(sourceMgr, context, module)) + return error; + + mlir::PassManager pm(module.get()->getName()); + // Apply any generic pass manager command line options and run the pipeline. + if (mlir::failed(mlir::applyPassManagerCLOptions(pm))) + return 4; + + // Check to see what granularity of MLIR we are compiling to. + bool isLoweringToAffine = emitAction >= Action::DumpMLIRAffine; + + if (enableOpt || isLoweringToAffine) { + // Inline all functions into main and then delete them. + pm.addPass(mlir::createInlinerPass()); + + // Now that there is only one function, we can infer the shapes of each of + // the operations. + mlir::OpPassManager &optPM = pm.nest(); + optPM.addPass(mlir::toy::createShapeInferencePass()); + optPM.addPass(mlir::createCanonicalizerPass()); + optPM.addPass(mlir::createCSEPass()); + } + + if (isLoweringToAffine) { + // Partially lower the toy dialect. + pm.addPass(mlir::toy::createLowerToAffinePass()); + + // Add a few cleanups post lowering. + mlir::OpPassManager &optPM = pm.nest(); + optPM.addPass(mlir::createCanonicalizerPass()); + optPM.addPass(mlir::createCSEPass()); + + // Add optimizations if enabled. + if (enableOpt) { + optPM.addPass(mlir::affine::createLoopFusionPass()); + optPM.addPass(mlir::affine::createAffineScalarReplacementPass()); + } + } + + if (mlir::failed(pm.run(*module))) + return 4; + + module->dump(); + return 0; +} + +int dumpAST() { + if (inputType == InputType::MLIR) { + llvm::errs() << "Can't dump a Toy AST when the input is MLIR\n"; + return 5; + } + + auto moduleAST = parseInputFile(inputFilename); + if (!moduleAST) + return 1; + + dump(*moduleAST); + return 0; +} + +int main(int argc, char **argv) { + // Register any command line options. + mlir::registerAsmPrinterCLOptions(); + mlir::registerMLIRContextCLOptions(); + mlir::registerPassManagerCLOptions(); + + cl::ParseCommandLineOptions(argc, argv, "toy compiler\n"); + + switch (emitAction) { + case Action::DumpAST: + return dumpAST(); + case Action::DumpMLIR: + case Action::DumpMLIRAffine: + return dumpMLIR(); + default: + llvm::errs() << "No action specified (parsing only?), use -emit=\n"; + } + + return 0; +} diff --git a/Ch6/CMakeLists.txt b/Ch6/CMakeLists.txt new file mode 100644 index 0000000..aef5b99 --- /dev/null +++ b/Ch6/CMakeLists.txt @@ -0,0 +1,65 @@ +# This chapter depends on JIT support enabled. +if(NOT MLIR_ENABLE_EXECUTION_ENGINE) + return() +endif() + + +# For a better template to copy, see examples/standalone +include_directories(include) +add_subdirectory(include) + +set(LLVM_LINK_COMPONENTS + Core + Support + nativecodegen + OrcJIT + ) + +# set(LLVM_TARGET_DEFINITIONS mlir/ToyCombine.td) +# mlir_tablegen(ToyCombine.inc -gen-rewriters) +# add_public_tablegen_target(ToyCh6CombineIncGen) + +add_executable(toyc-ch6 + toyc.cpp + parser/AST.cpp + mlir/MLIRGen.cpp + mlir/Dialect.cpp + mlir/LowerToAffineLoops.cpp + mlir/LowerToLLVM.cpp + mlir/ShapeInferencePass.cpp + mlir/ToyCombine.cpp + + # DEPENDS + # ToyCh6ShapeInferenceInterfaceIncGen + # ToyCh6OpsIncGen + # ToyCh6CombineIncGen + ) + +include_directories(${CMAKE_CURRENT_BINARY_DIR}) +include_directories(${CMAKE_CURRENT_BINARY_DIR}/include/) +get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS) +get_property(conversion_libs GLOBAL PROPERTY MLIR_CONVERSION_LIBS) +get_property(extension_libs GLOBAL PROPERTY MLIR_EXTENSION_LIBS) +target_link_libraries(toyc-ch6 + PRIVATE + ${dialect_libs} + ${conversion_libs} + ${extension_libs} + MLIRAnalysis + MLIRBuiltinToLLVMIRTranslation + MLIRCallInterfaces + MLIRCastInterfaces + MLIRExecutionEngine + MLIRFunctionInterfaces + MLIRIR + MLIRLLVMCommonConversion + MLIRLLVMDialect + MLIRLLVMToLLVMIRTranslation + MLIRMemRefDialect + MLIRParser + MLIRPass + MLIRSideEffectInterfaces + MLIRSupport + MLIRTargetLLVMIRExport + MLIRTransforms + ) diff --git a/Ch6/include/CMakeLists.txt b/Ch6/include/CMakeLists.txt new file mode 100644 index 0000000..37c89d0 --- /dev/null +++ b/Ch6/include/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(toy) diff --git a/Ch6/include/run.sh b/Ch6/include/run.sh new file mode 100644 index 0000000..b9d18af --- /dev/null +++ b/Ch6/include/run.sh @@ -0,0 +1,7 @@ +mlir-tblgen-18 -gen-op-decls -I /usr/lib/llvm-18/include toy/Ops.td > toy/Ops.h.inc +mlir-tblgen-18 -gen-op-defs -I /usr/lib/llvm-18/include toy/Ops.td > toy/Ops.cpp.inc +mlir-tblgen-18 -gen-dialect-decls -I /usr/lib/llvm-18/include toy/Ops.td > toy/Dialect.h.inc +mlir-tblgen-18 -gen-dialect-defs -I /usr/lib/llvm-18/include toy/Ops.td > toy/Dialect.cpp.inc + +mlir-tblgen-18 -gen-op-interface-decls -I /usr/lib/llvm-18/include toy/ShapeInferenceInterface.td > toy/ShapeInferenceOpInterfaces.h.inc +mlir-tblgen-18 -gen-op-interface-defs -I /usr/lib/llvm-18/include toy/ShapeInferenceInterface.td > toy/ShapeInferenceOpInterfaces.cpp.inc diff --git a/Ch6/include/toy/AST.h b/Ch6/include/toy/AST.h new file mode 100644 index 0000000..d2ba101 --- /dev/null +++ b/Ch6/include/toy/AST.h @@ -0,0 +1,246 @@ +//===- AST.h - Node definition for the Toy AST ----------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements the AST for the Toy language. It is optimized for +// simplicity, not efficiency. The AST forms a tree structure where each node +// references its children using std::unique_ptr<>. +// +//===----------------------------------------------------------------------===// + +#ifndef TOY_AST_H +#define TOY_AST_H + +#include "toy/Lexer.h" + +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/Casting.h" +#include +#include +#include + +namespace toy { + +/// A variable type with shape information. +struct VarType { + std::vector shape; +}; + +/// Base class for all expression nodes. +class ExprAST { +public: + enum ExprASTKind { + Expr_VarDecl, + Expr_Return, + Expr_Num, + Expr_Literal, + Expr_Var, + Expr_BinOp, + Expr_Call, + Expr_Print, + }; + + ExprAST(ExprASTKind kind, Location location) + : kind(kind), location(std::move(location)) {} + virtual ~ExprAST() = default; + + ExprASTKind getKind() const { return kind; } + + const Location &loc() { return location; } + +private: + const ExprASTKind kind; + Location location; +}; + +/// A block-list of expressions. +using ExprASTList = std::vector>; + +/// Expression class for numeric literals like "1.0". +class NumberExprAST : public ExprAST { + double val; + +public: + NumberExprAST(Location loc, double val) + : ExprAST(Expr_Num, std::move(loc)), val(val) {} + + double getValue() { return val; } + + /// LLVM style RTTI + static bool classof(const ExprAST *c) { return c->getKind() == Expr_Num; } +}; + +/// Expression class for a literal value. +class LiteralExprAST : public ExprAST { + std::vector> values; + std::vector dims; + +public: + LiteralExprAST(Location loc, std::vector> values, + std::vector dims) + : ExprAST(Expr_Literal, std::move(loc)), values(std::move(values)), + dims(std::move(dims)) {} + + llvm::ArrayRef> getValues() { return values; } + llvm::ArrayRef getDims() { return dims; } + + /// LLVM style RTTI + static bool classof(const ExprAST *c) { return c->getKind() == Expr_Literal; } +}; + +/// Expression class for referencing a variable, like "a". +class VariableExprAST : public ExprAST { + std::string name; + +public: + VariableExprAST(Location loc, llvm::StringRef name) + : ExprAST(Expr_Var, std::move(loc)), name(name) {} + + llvm::StringRef getName() { return name; } + + /// LLVM style RTTI + static bool classof(const ExprAST *c) { return c->getKind() == Expr_Var; } +}; + +/// Expression class for defining a variable. +class VarDeclExprAST : public ExprAST { + std::string name; + VarType type; + std::unique_ptr initVal; + +public: + VarDeclExprAST(Location loc, llvm::StringRef name, VarType type, + std::unique_ptr initVal) + : ExprAST(Expr_VarDecl, std::move(loc)), name(name), + type(std::move(type)), initVal(std::move(initVal)) {} + + llvm::StringRef getName() { return name; } + ExprAST *getInitVal() { return initVal.get(); } + const VarType &getType() { return type; } + + /// LLVM style RTTI + static bool classof(const ExprAST *c) { return c->getKind() == Expr_VarDecl; } +}; + +/// Expression class for a return operator. +class ReturnExprAST : public ExprAST { + std::optional> expr; + +public: + ReturnExprAST(Location loc, std::optional> expr) + : ExprAST(Expr_Return, std::move(loc)), expr(std::move(expr)) {} + + std::optional getExpr() { + if (expr.has_value()) + return expr->get(); + return std::nullopt; + } + + /// LLVM style RTTI + static bool classof(const ExprAST *c) { return c->getKind() == Expr_Return; } +}; + +/// Expression class for a binary operator. +class BinaryExprAST : public ExprAST { + char op; + std::unique_ptr lhs, rhs; + +public: + char getOp() { return op; } + ExprAST *getLHS() { return lhs.get(); } + ExprAST *getRHS() { return rhs.get(); } + + BinaryExprAST(Location loc, char op, std::unique_ptr lhs, + std::unique_ptr rhs) + : ExprAST(Expr_BinOp, std::move(loc)), op(op), lhs(std::move(lhs)), + rhs(std::move(rhs)) {} + + /// LLVM style RTTI + static bool classof(const ExprAST *c) { return c->getKind() == Expr_BinOp; } +}; + +/// Expression class for function calls. +class CallExprAST : public ExprAST { + std::string callee; + std::vector> args; + +public: + CallExprAST(Location loc, const std::string &callee, + std::vector> args) + : ExprAST(Expr_Call, std::move(loc)), callee(callee), + args(std::move(args)) {} + + llvm::StringRef getCallee() { return callee; } + llvm::ArrayRef> getArgs() { return args; } + + /// LLVM style RTTI + static bool classof(const ExprAST *c) { return c->getKind() == Expr_Call; } +}; + +/// Expression class for builtin print calls. +class PrintExprAST : public ExprAST { + std::unique_ptr arg; + +public: + PrintExprAST(Location loc, std::unique_ptr arg) + : ExprAST(Expr_Print, std::move(loc)), arg(std::move(arg)) {} + + ExprAST *getArg() { return arg.get(); } + + /// LLVM style RTTI + static bool classof(const ExprAST *c) { return c->getKind() == Expr_Print; } +}; + +/// This class represents the "prototype" for a function, which captures its +/// name, and its argument names (thus implicitly the number of arguments the +/// function takes). +class PrototypeAST { + Location location; + std::string name; + std::vector> args; + +public: + PrototypeAST(Location location, const std::string &name, + std::vector> args) + : location(std::move(location)), name(name), args(std::move(args)) {} + + const Location &loc() { return location; } + llvm::StringRef getName() const { return name; } + llvm::ArrayRef> getArgs() { return args; } +}; + +/// This class represents a function definition itself. +class FunctionAST { + std::unique_ptr proto; + std::unique_ptr body; + +public: + FunctionAST(std::unique_ptr proto, + std::unique_ptr body) + : proto(std::move(proto)), body(std::move(body)) {} + PrototypeAST *getProto() { return proto.get(); } + ExprASTList *getBody() { return body.get(); } +}; + +/// This class represents a list of functions to be processed together +class ModuleAST { + std::vector functions; + +public: + ModuleAST(std::vector functions) + : functions(std::move(functions)) {} + + auto begin() { return functions.begin(); } + auto end() { return functions.end(); } +}; + +void dump(ModuleAST &); + +} // namespace toy + +#endif // TOY_AST_H diff --git a/Ch6/include/toy/CMakeLists.txt b/Ch6/include/toy/CMakeLists.txt new file mode 100644 index 0000000..2f9729c --- /dev/null +++ b/Ch6/include/toy/CMakeLists.txt @@ -0,0 +1,13 @@ +# # Most dialects should use add_mlir_dialect(). See examples/standalone. +# set(LLVM_TARGET_DEFINITIONS Ops.td) +# mlir_tablegen(Ops.h.inc -gen-op-decls) +# mlir_tablegen(Ops.cpp.inc -gen-op-defs) +# mlir_tablegen(Dialect.h.inc -gen-dialect-decls) +# mlir_tablegen(Dialect.cpp.inc -gen-dialect-defs) +# add_public_tablegen_target(ToyCh6OpsIncGen) + +# # Most dialects should use add_mlir_interfaces(). +# set(LLVM_TARGET_DEFINITIONS ShapeInferenceInterface.td) +# mlir_tablegen(ShapeInferenceOpInterfaces.h.inc -gen-op-interface-decls) +# mlir_tablegen(ShapeInferenceOpInterfaces.cpp.inc -gen-op-interface-defs) +# add_public_tablegen_target(ToyCh6ShapeInferenceInterfaceIncGen) diff --git a/Ch6/include/toy/Dialect.cpp.inc b/Ch6/include/toy/Dialect.cpp.inc new file mode 100644 index 0000000..8cbc772 --- /dev/null +++ b/Ch6/include/toy/Dialect.cpp.inc @@ -0,0 +1,23 @@ +/*===- TableGen'erated file -------------------------------------*- C++ -*-===*\ +|* *| +|* Dialect Definitions *| +|* *| +|* Automatically generated file, do not edit! *| +|* From: Ops.td *| +|* *| +\*===----------------------------------------------------------------------===*/ + +MLIR_DEFINE_EXPLICIT_TYPE_ID(::mlir::toy::ToyDialect) +namespace mlir { +namespace toy { + +ToyDialect::ToyDialect(::mlir::MLIRContext *context) + : ::mlir::Dialect(getDialectNamespace(), context, ::mlir::TypeID::get()) { + + initialize(); +} + +ToyDialect::~ToyDialect() = default; + +} // namespace toy +} // namespace mlir diff --git a/Ch6/include/toy/Dialect.h b/Ch6/include/toy/Dialect.h new file mode 100644 index 0000000..5db325e --- /dev/null +++ b/Ch6/include/toy/Dialect.h @@ -0,0 +1,36 @@ +//===- Dialect.h - Dialect definition for the Toy IR ----------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements the IR Dialect for the Toy language. +// See docs/Tutorials/Toy/Ch-2.md for more information. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_TUTORIAL_TOY_DIALECT_H_ +#define MLIR_TUTORIAL_TOY_DIALECT_H_ + +#include "mlir/Bytecode/BytecodeOpInterface.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Dialect.h" +#include "mlir/IR/SymbolTable.h" +#include "mlir/Interfaces/CallInterfaces.h" +#include "mlir/Interfaces/CastInterfaces.h" +#include "mlir/Interfaces/FunctionInterfaces.h" +#include "mlir/Interfaces/SideEffectInterfaces.h" +#include "toy/ShapeInferenceInterface.h" + +/// Include the auto-generated header file containing the declaration of the toy +/// dialect. +#include "toy/Dialect.h.inc" + +/// Include the auto-generated header file containing the declarations of the +/// toy operations. +#define GET_OP_CLASSES +#include "toy/Ops.h.inc" + +#endif // MLIR_TUTORIAL_TOY_DIALECT_H_ diff --git a/Ch6/include/toy/Dialect.h.inc b/Ch6/include/toy/Dialect.h.inc new file mode 100644 index 0000000..f19d867 --- /dev/null +++ b/Ch6/include/toy/Dialect.h.inc @@ -0,0 +1,26 @@ +/*===- TableGen'erated file -------------------------------------*- C++ -*-===*\ +|* *| +|* Dialect Declarations *| +|* *| +|* Automatically generated file, do not edit! *| +|* From: Ops.td *| +|* *| +\*===----------------------------------------------------------------------===*/ + +namespace mlir { +namespace toy { + +class ToyDialect : public ::mlir::Dialect { + explicit ToyDialect(::mlir::MLIRContext *context); + + void initialize(); + friend class ::mlir::MLIRContext; +public: + ~ToyDialect() override; + static constexpr ::llvm::StringLiteral getDialectNamespace() { + return ::llvm::StringLiteral("toy"); + } +}; +} // namespace toy +} // namespace mlir +MLIR_DECLARE_EXPLICIT_TYPE_ID(::mlir::toy::ToyDialect) diff --git a/Ch6/include/toy/Lexer.h b/Ch6/include/toy/Lexer.h new file mode 100644 index 0000000..3c59cd9 --- /dev/null +++ b/Ch6/include/toy/Lexer.h @@ -0,0 +1,232 @@ +//===- Lexer.h - Lexer for the Toy language -------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements a simple Lexer for the Toy language. +// +//===----------------------------------------------------------------------===// + +#ifndef TOY_LEXER_H +#define TOY_LEXER_H + +#include "llvm/ADT/StringRef.h" + +#include +#include + +namespace toy { + +/// Structure definition a location in a file. +struct Location { + std::shared_ptr file; ///< filename. + int line; ///< line number. + int col; ///< column number. +}; + +// List of Token returned by the lexer. +enum Token : int { + tok_semicolon = ';', + tok_parenthese_open = '(', + tok_parenthese_close = ')', + tok_bracket_open = '{', + tok_bracket_close = '}', + tok_sbracket_open = '[', + tok_sbracket_close = ']', + + tok_eof = -1, + + // commands + tok_return = -2, + tok_var = -3, + tok_def = -4, + + // primary + tok_identifier = -5, + tok_number = -6, +}; + +/// The Lexer is an abstract base class providing all the facilities that the +/// Parser expects. It goes through the stream one token at a time and keeps +/// track of the location in the file for debugging purpose. +/// It relies on a subclass to provide a `readNextLine()` method. The subclass +/// can proceed by reading the next line from the standard input or from a +/// memory mapped file. +class Lexer { +public: + /// Create a lexer for the given filename. The filename is kept only for + /// debugging purpose (attaching a location to a Token). + Lexer(std::string filename) + : lastLocation( + {std::make_shared(std::move(filename)), 0, 0}) {} + virtual ~Lexer() = default; + + /// Look at the current token in the stream. + Token getCurToken() { return curTok; } + + /// Move to the next token in the stream and return it. + Token getNextToken() { return curTok = getTok(); } + + /// Move to the next token in the stream, asserting on the current token + /// matching the expectation. + void consume(Token tok) { + assert(tok == curTok && "consume Token mismatch expectation"); + getNextToken(); + } + + /// Return the current identifier (prereq: getCurToken() == tok_identifier) + llvm::StringRef getId() { + assert(curTok == tok_identifier); + return identifierStr; + } + + /// Return the current number (prereq: getCurToken() == tok_number) + double getValue() { + assert(curTok == tok_number); + return numVal; + } + + /// Return the location for the beginning of the current token. + Location getLastLocation() { return lastLocation; } + + // Return the current line in the file. + int getLine() { return curLineNum; } + + // Return the current column in the file. + int getCol() { return curCol; } + +private: + /// Delegate to a derived class fetching the next line. Returns an empty + /// string to signal end of file (EOF). Lines are expected to always finish + /// with "\n" + virtual llvm::StringRef readNextLine() = 0; + + /// Return the next character from the stream. This manages the buffer for the + /// current line and request the next line buffer to the derived class as + /// needed. + int getNextChar() { + // The current line buffer should not be empty unless it is the end of file. + if (curLineBuffer.empty()) + return EOF; + ++curCol; + auto nextchar = curLineBuffer.front(); + curLineBuffer = curLineBuffer.drop_front(); + if (curLineBuffer.empty()) + curLineBuffer = readNextLine(); + if (nextchar == '\n') { + ++curLineNum; + curCol = 0; + } + return nextchar; + } + + /// Return the next token from standard input. + Token getTok() { + // Skip any whitespace. + while (isspace(lastChar)) + lastChar = Token(getNextChar()); + + // Save the current location before reading the token characters. + lastLocation.line = curLineNum; + lastLocation.col = curCol; + + // Identifier: [a-zA-Z][a-zA-Z0-9_]* + if (isalpha(lastChar)) { + identifierStr = (char)lastChar; + while (isalnum((lastChar = Token(getNextChar()))) || lastChar == '_') + identifierStr += (char)lastChar; + + if (identifierStr == "return") + return tok_return; + if (identifierStr == "def") + return tok_def; + if (identifierStr == "var") + return tok_var; + return tok_identifier; + } + + // Number: [0-9.]+ + if (isdigit(lastChar) || lastChar == '.') { + std::string numStr; + do { + numStr += lastChar; + lastChar = Token(getNextChar()); + } while (isdigit(lastChar) || lastChar == '.'); + + numVal = strtod(numStr.c_str(), nullptr); + return tok_number; + } + + if (lastChar == '#') { + // Comment until end of line. + do { + lastChar = Token(getNextChar()); + } while (lastChar != EOF && lastChar != '\n' && lastChar != '\r'); + + if (lastChar != EOF) + return getTok(); + } + + // Check for end of file. Don't eat the EOF. + if (lastChar == EOF) + return tok_eof; + + // Otherwise, just return the character as its ascii value. + Token thisChar = Token(lastChar); + lastChar = Token(getNextChar()); + return thisChar; + } + + /// The last token read from the input. + Token curTok = tok_eof; + + /// Location for `curTok`. + Location lastLocation; + + /// If the current Token is an identifier, this string contains the value. + std::string identifierStr; + + /// If the current Token is a number, this contains the value. + double numVal = 0; + + /// The last value returned by getNextChar(). We need to keep it around as we + /// always need to read ahead one character to decide when to end a token and + /// we can't put it back in the stream after reading from it. + Token lastChar = Token(' '); + + /// Keep track of the current line number in the input stream + int curLineNum = 0; + + /// Keep track of the current column number in the input stream + int curCol = 0; + + /// Buffer supplied by the derived class on calls to `readNextLine()` + llvm::StringRef curLineBuffer = "\n"; +}; + +/// A lexer implementation operating on a buffer in memory. +class LexerBuffer final : public Lexer { +public: + LexerBuffer(const char *begin, const char *end, std::string filename) + : Lexer(std::move(filename)), current(begin), end(end) {} + +private: + /// Provide one line at a time to the Lexer, return an empty string when + /// reaching the end of the buffer. + llvm::StringRef readNextLine() override { + auto *begin = current; + while (current <= end && *current && *current != '\n') + ++current; + if (current <= end && *current) + ++current; + llvm::StringRef result{begin, static_cast(current - begin)}; + return result; + } + const char *current, *end; +}; +} // namespace toy + +#endif // TOY_LEXER_H diff --git a/Ch6/include/toy/MLIRGen.h b/Ch6/include/toy/MLIRGen.h new file mode 100644 index 0000000..fe9dbe5 --- /dev/null +++ b/Ch6/include/toy/MLIRGen.h @@ -0,0 +1,35 @@ +//===- MLIRGen.h - MLIR Generation from a Toy AST -------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file declares a simple interface to perform IR generation targeting MLIR +// from a Module AST for the Toy language. +// +//===----------------------------------------------------------------------===// + +#ifndef TOY_MLIRGEN_H +#define TOY_MLIRGEN_H + +#include + +namespace mlir { +class MLIRContext; +template +class OwningOpRef; +class ModuleOp; +} // namespace mlir + +namespace toy { +class ModuleAST; + +/// Emit IR for the given Toy moduleAST, returns a newly created MLIR module +/// or nullptr on failure. +mlir::OwningOpRef mlirGen(mlir::MLIRContext &context, + ModuleAST &moduleAST); +} // namespace toy + +#endif // TOY_MLIRGEN_H diff --git a/Ch6/include/toy/Ops.cpp.inc b/Ch6/include/toy/Ops.cpp.inc new file mode 100644 index 0000000..6bb98a2 --- /dev/null +++ b/Ch6/include/toy/Ops.cpp.inc @@ -0,0 +1,2252 @@ +/*===- TableGen'erated file -------------------------------------*- C++ -*-===*\ +|* *| +|* Op Definitions *| +|* *| +|* Automatically generated file, do not edit! *| +|* From: Ops.td *| +|* *| +\*===----------------------------------------------------------------------===*/ + +#ifdef GET_OP_LIST +#undef GET_OP_LIST + +::mlir::toy::AddOp, +::mlir::toy::CastOp, +::mlir::toy::ConstantOp, +::mlir::toy::FuncOp, +::mlir::toy::GenericCallOp, +::mlir::toy::MulOp, +::mlir::toy::PrintOp, +::mlir::toy::ReshapeOp, +::mlir::toy::ReturnOp, +::mlir::toy::TransposeOp +#endif // GET_OP_LIST + +#ifdef GET_OP_CLASSES +#undef GET_OP_CLASSES + + +//===----------------------------------------------------------------------===// +// Local Utility Method Definitions +//===----------------------------------------------------------------------===// + +namespace mlir { +namespace toy { + +static ::mlir::LogicalResult __mlir_ods_local_type_constraint_Ops0( + ::mlir::Operation *op, ::mlir::Type type, ::llvm::StringRef valueKind, + unsigned valueIndex) { + if (!(((::llvm::isa<::mlir::TensorType>(type))) && ([](::mlir::Type elementType) { return (elementType.isF64()); }(::llvm::cast<::mlir::ShapedType>(type).getElementType())))) { + return op->emitOpError(valueKind) << " #" << valueIndex + << " must be tensor of 64-bit float values, but got " << type; + } + return ::mlir::success(); +} + +static ::mlir::LogicalResult __mlir_ods_local_type_constraint_Ops1( + ::mlir::Operation *op, ::mlir::Type type, ::llvm::StringRef valueKind, + unsigned valueIndex) { + if (!(((::llvm::isa<::mlir::TensorType>(type))) && ([](::mlir::Type elementType) { return (elementType.isF64()); }(::llvm::cast<::mlir::ShapedType>(type).getElementType())))) { + return op->emitOpError(valueKind) << " #" << valueIndex + << " must be variadic of tensor of 64-bit float values, but got " << type; + } + return ::mlir::success(); +} + +static ::mlir::LogicalResult __mlir_ods_local_type_constraint_Ops2( + ::mlir::Operation *op, ::mlir::Type type, ::llvm::StringRef valueKind, + unsigned valueIndex) { + if (!((((::llvm::isa<::mlir::TensorType>(type))) && ([](::mlir::Type elementType) { return (elementType.isF64()); }(::llvm::cast<::mlir::ShapedType>(type).getElementType()))) || (((::llvm::isa<::mlir::MemRefType>(type))) && ([](::mlir::Type elementType) { return (elementType.isF64()); }(::llvm::cast<::mlir::ShapedType>(type).getElementType()))))) { + return op->emitOpError(valueKind) << " #" << valueIndex + << " must be tensor of 64-bit float values or memref of 64-bit float values, but got " << type; + } + return ::mlir::success(); +} + +static ::mlir::LogicalResult __mlir_ods_local_type_constraint_Ops3( + ::mlir::Operation *op, ::mlir::Type type, ::llvm::StringRef valueKind, + unsigned valueIndex) { + if (!((((::llvm::isa<::mlir::RankedTensorType>(type))) && ((::llvm::cast<::mlir::ShapedType>(type).hasStaticShape()))) && ([](::mlir::Type elementType) { return (elementType.isF64()); }(::llvm::cast<::mlir::ShapedType>(type).getElementType())))) { + return op->emitOpError(valueKind) << " #" << valueIndex + << " must be statically shaped tensor of 64-bit float values, but got " << type; + } + return ::mlir::success(); +} + +static ::mlir::LogicalResult __mlir_ods_local_attr_constraint_Ops0( + ::mlir::Attribute attr, ::llvm::StringRef attrName, llvm::function_ref<::mlir::InFlightDiagnostic()> emitError) { + if (attr && !((::llvm::isa<::mlir::DenseFPElementsAttr>(attr) &&::llvm::cast<::mlir::DenseElementsAttr>(attr).getType().getElementType().isF64()))) + return emitError() << "attribute '" << attrName + << "' failed to satisfy constraint: 64-bit float elements attribute"; + return ::mlir::success(); +} +static ::mlir::LogicalResult __mlir_ods_local_attr_constraint_Ops0( + ::mlir::Operation *op, ::mlir::Attribute attr, ::llvm::StringRef attrName) { + return __mlir_ods_local_attr_constraint_Ops0(attr, attrName, [op]() { + return op->emitOpError(); + }); +} + +static ::mlir::LogicalResult __mlir_ods_local_attr_constraint_Ops1( + ::mlir::Attribute attr, ::llvm::StringRef attrName, llvm::function_ref<::mlir::InFlightDiagnostic()> emitError) { + if (attr && !((::llvm::isa<::mlir::StringAttr>(attr)))) + return emitError() << "attribute '" << attrName + << "' failed to satisfy constraint: string attribute"; + return ::mlir::success(); +} +static ::mlir::LogicalResult __mlir_ods_local_attr_constraint_Ops1( + ::mlir::Operation *op, ::mlir::Attribute attr, ::llvm::StringRef attrName) { + return __mlir_ods_local_attr_constraint_Ops1(attr, attrName, [op]() { + return op->emitOpError(); + }); +} + +static ::mlir::LogicalResult __mlir_ods_local_attr_constraint_Ops2( + ::mlir::Attribute attr, ::llvm::StringRef attrName, llvm::function_ref<::mlir::InFlightDiagnostic()> emitError) { + if (attr && !(((::llvm::isa<::mlir::TypeAttr>(attr))) && ((::llvm::isa<::mlir::FunctionType>(::llvm::cast<::mlir::TypeAttr>(attr).getValue()))) && ((::llvm::isa<::mlir::FunctionType>(::llvm::cast<::mlir::TypeAttr>(attr).getValue()))))) + return emitError() << "attribute '" << attrName + << "' failed to satisfy constraint: type attribute of function type"; + return ::mlir::success(); +} +static ::mlir::LogicalResult __mlir_ods_local_attr_constraint_Ops2( + ::mlir::Operation *op, ::mlir::Attribute attr, ::llvm::StringRef attrName) { + return __mlir_ods_local_attr_constraint_Ops2(attr, attrName, [op]() { + return op->emitOpError(); + }); +} + +static ::mlir::LogicalResult __mlir_ods_local_attr_constraint_Ops3( + ::mlir::Attribute attr, ::llvm::StringRef attrName, llvm::function_ref<::mlir::InFlightDiagnostic()> emitError) { + if (attr && !(((::llvm::isa<::mlir::ArrayAttr>(attr))) && (::llvm::all_of(::llvm::cast<::mlir::ArrayAttr>(attr), [&](::mlir::Attribute attr) { return attr && ((::llvm::isa<::mlir::DictionaryAttr>(attr))); })))) + return emitError() << "attribute '" << attrName + << "' failed to satisfy constraint: Array of dictionary attributes"; + return ::mlir::success(); +} +static ::mlir::LogicalResult __mlir_ods_local_attr_constraint_Ops3( + ::mlir::Operation *op, ::mlir::Attribute attr, ::llvm::StringRef attrName) { + return __mlir_ods_local_attr_constraint_Ops3(attr, attrName, [op]() { + return op->emitOpError(); + }); +} + +static ::mlir::LogicalResult __mlir_ods_local_attr_constraint_Ops4( + ::mlir::Attribute attr, ::llvm::StringRef attrName, llvm::function_ref<::mlir::InFlightDiagnostic()> emitError) { + if (attr && !((::llvm::isa<::mlir::FlatSymbolRefAttr>(attr)))) + return emitError() << "attribute '" << attrName + << "' failed to satisfy constraint: flat symbol reference attribute"; + return ::mlir::success(); +} +static ::mlir::LogicalResult __mlir_ods_local_attr_constraint_Ops4( + ::mlir::Operation *op, ::mlir::Attribute attr, ::llvm::StringRef attrName) { + return __mlir_ods_local_attr_constraint_Ops4(attr, attrName, [op]() { + return op->emitOpError(); + }); +} + +static ::mlir::LogicalResult __mlir_ods_local_region_constraint_Ops0( + ::mlir::Operation *op, ::mlir::Region ®ion, ::llvm::StringRef regionName, + unsigned regionIndex) { + if (!((true))) { + return op->emitOpError("region #") << regionIndex + << (regionName.empty() ? " " : " ('" + regionName + "') ") + << "failed to verify constraint: any region"; + } + return ::mlir::success(); +} +} // namespace toy +} // namespace mlir +namespace mlir { +namespace toy { + +//===----------------------------------------------------------------------===// +// ::mlir::toy::AddOp definitions +//===----------------------------------------------------------------------===// + +namespace detail { +AddOpGenericAdaptorBase::AddOpGenericAdaptorBase(::mlir::DictionaryAttr attrs, const ::mlir::EmptyProperties &properties, ::mlir::RegionRange regions) : odsAttrs(attrs), odsRegions(regions) { if (odsAttrs) + odsOpName.emplace("toy.add", odsAttrs.getContext()); +} + +AddOpGenericAdaptorBase::AddOpGenericAdaptorBase(AddOp op) : AddOpGenericAdaptorBase(op->getAttrDictionary(), op.getProperties(), op->getRegions()) {} + +std::pair AddOpGenericAdaptorBase::getODSOperandIndexAndLength(unsigned index, unsigned odsOperandsSize) { + return {index, 1}; +} + +::mlir::DictionaryAttr AddOpGenericAdaptorBase::getAttributes() { + return odsAttrs; +} + +} // namespace detail +AddOpAdaptor::AddOpAdaptor(AddOp op) : AddOpGenericAdaptor(op->getOperands(), op) {} + +::mlir::LogicalResult AddOpAdaptor::verify(::mlir::Location loc) { + return ::mlir::success(); +} + +std::pair AddOp::getODSOperandIndexAndLength(unsigned index) { + return {index, 1}; +} + +::mlir::Operation::operand_range AddOp::getODSOperands(unsigned index) { + auto valueRange = getODSOperandIndexAndLength(index); + return {std::next(getOperation()->operand_begin(), valueRange.first), + std::next(getOperation()->operand_begin(), valueRange.first + valueRange.second)}; +} + +::mlir::TypedValue<::mlir::TensorType> AddOp::getLhs() { + return ::llvm::cast<::mlir::TypedValue<::mlir::TensorType>>(*getODSOperands(0).begin()); +} + +::mlir::TypedValue<::mlir::TensorType> AddOp::getRhs() { + return ::llvm::cast<::mlir::TypedValue<::mlir::TensorType>>(*getODSOperands(1).begin()); +} + +::mlir::OpOperand &AddOp::getLhsMutable() { + auto range = getODSOperandIndexAndLength(0); + return getOperation()->getOpOperand(range.first); +} + +::mlir::OpOperand &AddOp::getRhsMutable() { + auto range = getODSOperandIndexAndLength(1); + return getOperation()->getOpOperand(range.first); +} + +std::pair AddOp::getODSResultIndexAndLength(unsigned index) { + return {index, 1}; +} + +::mlir::Operation::result_range AddOp::getODSResults(unsigned index) { + auto valueRange = getODSResultIndexAndLength(index); + return {std::next(getOperation()->result_begin(), valueRange.first), + std::next(getOperation()->result_begin(), valueRange.first + valueRange.second)}; +} + +void AddOp::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type resultType0, ::mlir::Value lhs, ::mlir::Value rhs) { + odsState.addOperands(lhs); + odsState.addOperands(rhs); + odsState.addTypes(resultType0); +} + +void AddOp::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::Value lhs, ::mlir::Value rhs) { + odsState.addOperands(lhs); + odsState.addOperands(rhs); + assert(resultTypes.size() == 1u && "mismatched number of results"); + odsState.addTypes(resultTypes); +} + +void AddOp::build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes) { + assert(operands.size() == 2u && "mismatched number of parameters"); + odsState.addOperands(operands); + odsState.addAttributes(attributes); + assert(resultTypes.size() == 1u && "mismatched number of return types"); + odsState.addTypes(resultTypes); +} + +::mlir::LogicalResult AddOp::verifyInvariantsImpl() { + { + unsigned index = 0; (void)index; + auto valueGroup0 = getODSOperands(0); + + for (auto v : valueGroup0) { + if (::mlir::failed(__mlir_ods_local_type_constraint_Ops0(*this, v.getType(), "operand", index++))) + return ::mlir::failure(); + } + auto valueGroup1 = getODSOperands(1); + + for (auto v : valueGroup1) { + if (::mlir::failed(__mlir_ods_local_type_constraint_Ops0(*this, v.getType(), "operand", index++))) + return ::mlir::failure(); + } + } + { + unsigned index = 0; (void)index; + auto valueGroup0 = getODSResults(0); + + for (auto v : valueGroup0) { + if (::mlir::failed(__mlir_ods_local_type_constraint_Ops0(*this, v.getType(), "result", index++))) + return ::mlir::failure(); + } + } + return ::mlir::success(); +} + +::mlir::LogicalResult AddOp::verifyInvariants() { + return verifyInvariantsImpl(); +} + +void AddOp::getEffects(::llvm::SmallVectorImpl<::mlir::SideEffects::EffectInstance<::mlir::MemoryEffects::Effect>> &effects) { +} + +} // namespace toy +} // namespace mlir +MLIR_DEFINE_EXPLICIT_TYPE_ID(::mlir::toy::AddOp) + +namespace mlir { +namespace toy { + +//===----------------------------------------------------------------------===// +// ::mlir::toy::CastOp definitions +//===----------------------------------------------------------------------===// + +namespace detail { +CastOpGenericAdaptorBase::CastOpGenericAdaptorBase(::mlir::DictionaryAttr attrs, const ::mlir::EmptyProperties &properties, ::mlir::RegionRange regions) : odsAttrs(attrs), odsRegions(regions) { if (odsAttrs) + odsOpName.emplace("toy.cast", odsAttrs.getContext()); +} + +CastOpGenericAdaptorBase::CastOpGenericAdaptorBase(CastOp op) : CastOpGenericAdaptorBase(op->getAttrDictionary(), op.getProperties(), op->getRegions()) {} + +std::pair CastOpGenericAdaptorBase::getODSOperandIndexAndLength(unsigned index, unsigned odsOperandsSize) { + return {index, 1}; +} + +::mlir::DictionaryAttr CastOpGenericAdaptorBase::getAttributes() { + return odsAttrs; +} + +} // namespace detail +CastOpAdaptor::CastOpAdaptor(CastOp op) : CastOpGenericAdaptor(op->getOperands(), op) {} + +::mlir::LogicalResult CastOpAdaptor::verify(::mlir::Location loc) { + return ::mlir::success(); +} + +std::pair CastOp::getODSOperandIndexAndLength(unsigned index) { + return {index, 1}; +} + +::mlir::Operation::operand_range CastOp::getODSOperands(unsigned index) { + auto valueRange = getODSOperandIndexAndLength(index); + return {std::next(getOperation()->operand_begin(), valueRange.first), + std::next(getOperation()->operand_begin(), valueRange.first + valueRange.second)}; +} + +::mlir::TypedValue<::mlir::TensorType> CastOp::getInput() { + return ::llvm::cast<::mlir::TypedValue<::mlir::TensorType>>(*getODSOperands(0).begin()); +} + +::mlir::OpOperand &CastOp::getInputMutable() { + auto range = getODSOperandIndexAndLength(0); + return getOperation()->getOpOperand(range.first); +} + +std::pair CastOp::getODSResultIndexAndLength(unsigned index) { + return {index, 1}; +} + +::mlir::Operation::result_range CastOp::getODSResults(unsigned index) { + auto valueRange = getODSResultIndexAndLength(index); + return {std::next(getOperation()->result_begin(), valueRange.first), + std::next(getOperation()->result_begin(), valueRange.first + valueRange.second)}; +} + +::mlir::TypedValue<::mlir::TensorType> CastOp::getOutput() { + return ::llvm::cast<::mlir::TypedValue<::mlir::TensorType>>(*getODSResults(0).begin()); +} + +void CastOp::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type output, ::mlir::Value input) { + odsState.addOperands(input); + odsState.addTypes(output); +} + +void CastOp::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::Value input) { + odsState.addOperands(input); + assert(resultTypes.size() == 1u && "mismatched number of results"); + odsState.addTypes(resultTypes); +} + +void CastOp::build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes) { + assert(operands.size() == 1u && "mismatched number of parameters"); + odsState.addOperands(operands); + odsState.addAttributes(attributes); + assert(resultTypes.size() == 1u && "mismatched number of return types"); + odsState.addTypes(resultTypes); +} + +::mlir::LogicalResult CastOp::verifyInvariantsImpl() { + { + unsigned index = 0; (void)index; + auto valueGroup0 = getODSOperands(0); + + for (auto v : valueGroup0) { + if (::mlir::failed(__mlir_ods_local_type_constraint_Ops0(*this, v.getType(), "operand", index++))) + return ::mlir::failure(); + } + } + { + unsigned index = 0; (void)index; + auto valueGroup0 = getODSResults(0); + + for (auto v : valueGroup0) { + if (::mlir::failed(__mlir_ods_local_type_constraint_Ops0(*this, v.getType(), "result", index++))) + return ::mlir::failure(); + } + } + return ::mlir::success(); +} + +::mlir::LogicalResult CastOp::verifyInvariants() { + return verifyInvariantsImpl(); +} + +::mlir::ParseResult CastOp::parse(::mlir::OpAsmParser &parser, ::mlir::OperationState &result) { + ::mlir::OpAsmParser::UnresolvedOperand inputRawOperands[1]; + ::llvm::ArrayRef<::mlir::OpAsmParser::UnresolvedOperand> inputOperands(inputRawOperands); ::llvm::SMLoc inputOperandsLoc; + (void)inputOperandsLoc; + ::mlir::Type inputRawTypes[1]; + ::llvm::ArrayRef<::mlir::Type> inputTypes(inputRawTypes); + ::mlir::Type outputRawTypes[1]; + ::llvm::ArrayRef<::mlir::Type> outputTypes(outputRawTypes); + + inputOperandsLoc = parser.getCurrentLocation(); + if (parser.parseOperand(inputRawOperands[0])) + return ::mlir::failure(); + { + auto loc = parser.getCurrentLocation();(void)loc; + if (parser.parseOptionalAttrDict(result.attributes)) + return ::mlir::failure(); + } + if (parser.parseColon()) + return ::mlir::failure(); + + { + ::mlir::TensorType type; + if (parser.parseCustomTypeWithFallback(type)) + return ::mlir::failure(); + inputRawTypes[0] = type; + } + if (parser.parseKeyword("to")) + return ::mlir::failure(); + + { + ::mlir::TensorType type; + if (parser.parseCustomTypeWithFallback(type)) + return ::mlir::failure(); + outputRawTypes[0] = type; + } + result.addTypes(outputTypes); + if (parser.resolveOperands(inputOperands, inputTypes, inputOperandsLoc, result.operands)) + return ::mlir::failure(); + return ::mlir::success(); +} + +void CastOp::print(::mlir::OpAsmPrinter &_odsPrinter) { + _odsPrinter << ' '; + _odsPrinter << getInput(); + ::llvm::SmallVector<::llvm::StringRef, 2> elidedAttrs; + _odsPrinter.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs); + _odsPrinter << ' ' << ":"; + _odsPrinter << ' '; + { + auto type = getInput().getType(); + if (auto validType = ::llvm::dyn_cast<::mlir::TensorType>(type)) + _odsPrinter.printStrippedAttrOrType(validType); + else + _odsPrinter << type; + } + _odsPrinter << ' ' << "to"; + _odsPrinter << ' '; + { + auto type = getOutput().getType(); + if (auto validType = ::llvm::dyn_cast<::mlir::TensorType>(type)) + _odsPrinter.printStrippedAttrOrType(validType); + else + _odsPrinter << type; + } +} + +void CastOp::getEffects(::llvm::SmallVectorImpl<::mlir::SideEffects::EffectInstance<::mlir::MemoryEffects::Effect>> &effects) { +} + +} // namespace toy +} // namespace mlir +MLIR_DEFINE_EXPLICIT_TYPE_ID(::mlir::toy::CastOp) + +namespace mlir { +namespace toy { + +//===----------------------------------------------------------------------===// +// ::mlir::toy::ConstantOp definitions +//===----------------------------------------------------------------------===// + +namespace detail { +ConstantOpGenericAdaptorBase::ConstantOpGenericAdaptorBase(::mlir::DictionaryAttr attrs, const Properties &properties, ::mlir::RegionRange regions) : odsAttrs(attrs), properties(properties), odsRegions(regions) { if (odsAttrs) + odsOpName.emplace("toy.constant", odsAttrs.getContext()); +} + +ConstantOpGenericAdaptorBase::ConstantOpGenericAdaptorBase(ConstantOp op) : ConstantOpGenericAdaptorBase(op->getDiscardableAttrDictionary(), op.getProperties(), op->getRegions()) {} + +std::pair ConstantOpGenericAdaptorBase::getODSOperandIndexAndLength(unsigned index, unsigned odsOperandsSize) { + return {index, 1}; +} + +::mlir::DictionaryAttr ConstantOpGenericAdaptorBase::getAttributes() { + return odsAttrs; +} + +::mlir::DenseElementsAttr ConstantOpGenericAdaptorBase::getValueAttr() { + auto attr = ::llvm::cast<::mlir::DenseElementsAttr>(getProperties().value); + return attr; +} + +::mlir::DenseElementsAttr ConstantOpGenericAdaptorBase::getValue() { + auto attr = getValueAttr(); + return attr; +} + +} // namespace detail +ConstantOpAdaptor::ConstantOpAdaptor(ConstantOp op) : ConstantOpGenericAdaptor(op->getOperands(), op) {} + +::mlir::LogicalResult ConstantOpAdaptor::verify(::mlir::Location loc) { + auto tblgen_value = getProperties().value; (void)tblgen_value; + if (!tblgen_value) return emitError(loc, "'toy.constant' op ""requires attribute 'value'"); + + if (tblgen_value && !((::llvm::isa<::mlir::DenseFPElementsAttr>(tblgen_value) &&::llvm::cast<::mlir::DenseElementsAttr>(tblgen_value).getType().getElementType().isF64()))) + return emitError(loc, "'toy.constant' op ""attribute 'value' failed to satisfy constraint: 64-bit float elements attribute"); + return ::mlir::success(); +} + +std::pair ConstantOp::getODSOperandIndexAndLength(unsigned index) { + return {index, 1}; +} + +::mlir::Operation::operand_range ConstantOp::getODSOperands(unsigned index) { + auto valueRange = getODSOperandIndexAndLength(index); + return {std::next(getOperation()->operand_begin(), valueRange.first), + std::next(getOperation()->operand_begin(), valueRange.first + valueRange.second)}; +} + +std::pair ConstantOp::getODSResultIndexAndLength(unsigned index) { + return {index, 1}; +} + +::mlir::Operation::result_range ConstantOp::getODSResults(unsigned index) { + auto valueRange = getODSResultIndexAndLength(index); + return {std::next(getOperation()->result_begin(), valueRange.first), + std::next(getOperation()->result_begin(), valueRange.first + valueRange.second)}; +} + +::mlir::LogicalResult ConstantOp::setPropertiesFromAttr(Properties &prop, ::mlir::Attribute attr, ::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError) { + ::mlir::DictionaryAttr dict = ::llvm::dyn_cast<::mlir::DictionaryAttr>(attr); + if (!dict) { + emitError() << "expected DictionaryAttr to set properties"; + return ::mlir::failure(); + } + + { + auto &propStorage = prop.value; + auto attr = dict.get("value"); + if (attr || /*isRequired=*/true) { + if (!attr) { + emitError() << "expected key entry for value in DictionaryAttr to set " + "Properties."; + return ::mlir::failure(); + } + auto convertedAttr = ::llvm::dyn_cast>(attr); + if (convertedAttr) { + propStorage = convertedAttr; + } else { + emitError() << "Invalid attribute `value` in property conversion: " << attr; + return ::mlir::failure(); + } + } + } + return ::mlir::success(); +} + +::mlir::Attribute ConstantOp::getPropertiesAsAttr(::mlir::MLIRContext *ctx, const Properties &prop) { + ::mlir::SmallVector<::mlir::NamedAttribute> attrs; + ::mlir::Builder odsBuilder{ctx}; + + { + const auto &propStorage = prop.value; + if (propStorage) + attrs.push_back(odsBuilder.getNamedAttr("value", + propStorage)); + } + + if (!attrs.empty()) + return odsBuilder.getDictionaryAttr(attrs); + return {}; +} + +llvm::hash_code ConstantOp::computePropertiesHash(const Properties &prop) { + return llvm::hash_combine( + llvm::hash_value(prop.value.getAsOpaquePointer())); +} + +std::optional ConstantOp::getInherentAttr(::mlir::MLIRContext *ctx, const Properties &prop, llvm::StringRef name) { + if (name == "value") + return prop.value; + return std::nullopt; +} + +void ConstantOp::setInherentAttr(Properties &prop, llvm::StringRef name, mlir::Attribute value) { + if (name == "value") { + prop.value = ::llvm::dyn_cast_or_null>(value); + return; + } +} + +void ConstantOp::populateInherentAttrs(::mlir::MLIRContext *ctx, const Properties &prop, ::mlir::NamedAttrList &attrs) { + if (prop.value) attrs.append("value", prop.value); +} + +::mlir::LogicalResult ConstantOp::verifyInherentAttrs(::mlir::OperationName opName, ::mlir::NamedAttrList &attrs, llvm::function_ref<::mlir::InFlightDiagnostic()> emitError) { + { + ::mlir::Attribute attr = attrs.get(getValueAttrName(opName)); + if (attr && ::mlir::failed(__mlir_ods_local_attr_constraint_Ops0(attr, "value", emitError))) + return ::mlir::failure(); + } + return ::mlir::success(); +} + +::mlir::LogicalResult ConstantOp::readProperties(::mlir::DialectBytecodeReader &reader, ::mlir::OperationState &state) { + auto &prop = state.getOrAddProperties(); (void)prop; + if (::mlir::failed(reader.readAttribute(prop.value))) + return ::mlir::failure(); + return ::mlir::success(); +} + +void ConstantOp::writeProperties(::mlir::DialectBytecodeWriter &writer) { + auto &prop = getProperties(); (void)prop; + writer.writeAttribute(prop.value); +} + +::mlir::DenseElementsAttr ConstantOp::getValueAttr() { + return ::llvm::cast<::mlir::DenseElementsAttr>(getProperties().value); +} + +::mlir::DenseElementsAttr ConstantOp::getValue() { + auto attr = getValueAttr(); + return attr; +} + +void ConstantOp::setValueAttr(::mlir::DenseElementsAttr attr) { + (*this)->setAttr(getValueAttrName(), attr); +} + +void ConstantOp::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, DenseElementsAttr value) { + build(odsBuilder, odsState, value.getType(), value); + +} + +void ConstantOp::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type resultType0, ::mlir::DenseElementsAttr value) { + odsState.getOrAddProperties().value = value; + odsState.addTypes(resultType0); +} + +void ConstantOp::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::DenseElementsAttr value) { + odsState.getOrAddProperties().value = value; + assert(resultTypes.size() == 1u && "mismatched number of results"); + odsState.addTypes(resultTypes); +} + +void ConstantOp::build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes) { + assert(operands.size() == 0u && "mismatched number of parameters"); + odsState.addOperands(operands); + odsState.addAttributes(attributes); + assert(resultTypes.size() == 1u && "mismatched number of return types"); + odsState.addTypes(resultTypes); +} + +::mlir::LogicalResult ConstantOp::verifyInvariantsImpl() { + auto tblgen_value = getProperties().value; (void)tblgen_value; + if (!tblgen_value) return emitOpError("requires attribute 'value'"); + + if (::mlir::failed(__mlir_ods_local_attr_constraint_Ops0(*this, tblgen_value, "value"))) + return ::mlir::failure(); + { + unsigned index = 0; (void)index; + auto valueGroup0 = getODSResults(0); + + for (auto v : valueGroup0) { + if (::mlir::failed(__mlir_ods_local_type_constraint_Ops0(*this, v.getType(), "result", index++))) + return ::mlir::failure(); + } + } + return ::mlir::success(); +} + +::mlir::LogicalResult ConstantOp::verifyInvariants() { + if(::mlir::succeeded(verifyInvariantsImpl()) && ::mlir::succeeded(verify())) + return ::mlir::success(); + return ::mlir::failure(); +} + +void ConstantOp::getEffects(::llvm::SmallVectorImpl<::mlir::SideEffects::EffectInstance<::mlir::MemoryEffects::Effect>> &effects) { +} + +} // namespace toy +} // namespace mlir +MLIR_DEFINE_EXPLICIT_TYPE_ID(::mlir::toy::ConstantOp) + +namespace mlir { +namespace toy { + +//===----------------------------------------------------------------------===// +// ::mlir::toy::FuncOp definitions +//===----------------------------------------------------------------------===// + +namespace detail { +FuncOpGenericAdaptorBase::FuncOpGenericAdaptorBase(::mlir::DictionaryAttr attrs, const Properties &properties, ::mlir::RegionRange regions) : odsAttrs(attrs), properties(properties), odsRegions(regions) { if (odsAttrs) + odsOpName.emplace("toy.func", odsAttrs.getContext()); +} + +FuncOpGenericAdaptorBase::FuncOpGenericAdaptorBase(FuncOp op) : FuncOpGenericAdaptorBase(op->getDiscardableAttrDictionary(), op.getProperties(), op->getRegions()) {} + +std::pair FuncOpGenericAdaptorBase::getODSOperandIndexAndLength(unsigned index, unsigned odsOperandsSize) { + return {index, 1}; +} + +::mlir::DictionaryAttr FuncOpGenericAdaptorBase::getAttributes() { + return odsAttrs; +} + +::mlir::StringAttr FuncOpGenericAdaptorBase::getSymNameAttr() { + auto attr = ::llvm::cast<::mlir::StringAttr>(getProperties().sym_name); + return attr; +} + +::llvm::StringRef FuncOpGenericAdaptorBase::getSymName() { + auto attr = getSymNameAttr(); + return attr.getValue(); +} + +::mlir::TypeAttr FuncOpGenericAdaptorBase::getFunctionTypeAttr() { + auto attr = ::llvm::cast<::mlir::TypeAttr>(getProperties().function_type); + return attr; +} + +::mlir::FunctionType FuncOpGenericAdaptorBase::getFunctionType() { + auto attr = getFunctionTypeAttr(); + return ::llvm::cast<::mlir::FunctionType>(attr.getValue()); +} + +::mlir::ArrayAttr FuncOpGenericAdaptorBase::getArgAttrsAttr() { + auto attr = ::llvm::dyn_cast_or_null<::mlir::ArrayAttr>(getProperties().arg_attrs); + return attr; +} + +::std::optional< ::mlir::ArrayAttr > FuncOpGenericAdaptorBase::getArgAttrs() { + auto attr = getArgAttrsAttr(); + return attr ? ::std::optional< ::mlir::ArrayAttr >(attr) : (::std::nullopt); +} + +::mlir::ArrayAttr FuncOpGenericAdaptorBase::getResAttrsAttr() { + auto attr = ::llvm::dyn_cast_or_null<::mlir::ArrayAttr>(getProperties().res_attrs); + return attr; +} + +::std::optional< ::mlir::ArrayAttr > FuncOpGenericAdaptorBase::getResAttrs() { + auto attr = getResAttrsAttr(); + return attr ? ::std::optional< ::mlir::ArrayAttr >(attr) : (::std::nullopt); +} + +::mlir::Region &FuncOpGenericAdaptorBase::getBody() { + return *odsRegions[0]; +} + +::mlir::RegionRange FuncOpGenericAdaptorBase::getRegions() { + return odsRegions; +} + +} // namespace detail +FuncOpAdaptor::FuncOpAdaptor(FuncOp op) : FuncOpGenericAdaptor(op->getOperands(), op) {} + +::mlir::LogicalResult FuncOpAdaptor::verify(::mlir::Location loc) { + auto tblgen_arg_attrs = getProperties().arg_attrs; (void)tblgen_arg_attrs; + auto tblgen_function_type = getProperties().function_type; (void)tblgen_function_type; + if (!tblgen_function_type) return emitError(loc, "'toy.func' op ""requires attribute 'function_type'"); + auto tblgen_res_attrs = getProperties().res_attrs; (void)tblgen_res_attrs; + auto tblgen_sym_name = getProperties().sym_name; (void)tblgen_sym_name; + if (!tblgen_sym_name) return emitError(loc, "'toy.func' op ""requires attribute 'sym_name'"); + + if (tblgen_sym_name && !((::llvm::isa<::mlir::StringAttr>(tblgen_sym_name)))) + return emitError(loc, "'toy.func' op ""attribute 'sym_name' failed to satisfy constraint: string attribute"); + + if (tblgen_function_type && !(((::llvm::isa<::mlir::TypeAttr>(tblgen_function_type))) && ((::llvm::isa<::mlir::FunctionType>(::llvm::cast<::mlir::TypeAttr>(tblgen_function_type).getValue()))) && ((::llvm::isa<::mlir::FunctionType>(::llvm::cast<::mlir::TypeAttr>(tblgen_function_type).getValue()))))) + return emitError(loc, "'toy.func' op ""attribute 'function_type' failed to satisfy constraint: type attribute of function type"); + + if (tblgen_arg_attrs && !(((::llvm::isa<::mlir::ArrayAttr>(tblgen_arg_attrs))) && (::llvm::all_of(::llvm::cast<::mlir::ArrayAttr>(tblgen_arg_attrs), [&](::mlir::Attribute attr) { return attr && ((::llvm::isa<::mlir::DictionaryAttr>(attr))); })))) + return emitError(loc, "'toy.func' op ""attribute 'arg_attrs' failed to satisfy constraint: Array of dictionary attributes"); + + if (tblgen_res_attrs && !(((::llvm::isa<::mlir::ArrayAttr>(tblgen_res_attrs))) && (::llvm::all_of(::llvm::cast<::mlir::ArrayAttr>(tblgen_res_attrs), [&](::mlir::Attribute attr) { return attr && ((::llvm::isa<::mlir::DictionaryAttr>(attr))); })))) + return emitError(loc, "'toy.func' op ""attribute 'res_attrs' failed to satisfy constraint: Array of dictionary attributes"); + return ::mlir::success(); +} + +std::pair FuncOp::getODSOperandIndexAndLength(unsigned index) { + return {index, 1}; +} + +::mlir::Operation::operand_range FuncOp::getODSOperands(unsigned index) { + auto valueRange = getODSOperandIndexAndLength(index); + return {std::next(getOperation()->operand_begin(), valueRange.first), + std::next(getOperation()->operand_begin(), valueRange.first + valueRange.second)}; +} + +std::pair FuncOp::getODSResultIndexAndLength(unsigned index) { + return {index, 1}; +} + +::mlir::Operation::result_range FuncOp::getODSResults(unsigned index) { + auto valueRange = getODSResultIndexAndLength(index); + return {std::next(getOperation()->result_begin(), valueRange.first), + std::next(getOperation()->result_begin(), valueRange.first + valueRange.second)}; +} + +::mlir::Region &FuncOp::getBody() { + return (*this)->getRegion(0); +} + +::mlir::LogicalResult FuncOp::setPropertiesFromAttr(Properties &prop, ::mlir::Attribute attr, ::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError) { + ::mlir::DictionaryAttr dict = ::llvm::dyn_cast<::mlir::DictionaryAttr>(attr); + if (!dict) { + emitError() << "expected DictionaryAttr to set properties"; + return ::mlir::failure(); + } + + { + auto &propStorage = prop.arg_attrs; + auto attr = dict.get("arg_attrs"); + if (attr || /*isRequired=*/false) { + if (!attr) { + emitError() << "expected key entry for arg_attrs in DictionaryAttr to set " + "Properties."; + return ::mlir::failure(); + } + auto convertedAttr = ::llvm::dyn_cast>(attr); + if (convertedAttr) { + propStorage = convertedAttr; + } else { + emitError() << "Invalid attribute `arg_attrs` in property conversion: " << attr; + return ::mlir::failure(); + } + } + } + + { + auto &propStorage = prop.function_type; + auto attr = dict.get("function_type"); + if (attr || /*isRequired=*/true) { + if (!attr) { + emitError() << "expected key entry for function_type in DictionaryAttr to set " + "Properties."; + return ::mlir::failure(); + } + auto convertedAttr = ::llvm::dyn_cast>(attr); + if (convertedAttr) { + propStorage = convertedAttr; + } else { + emitError() << "Invalid attribute `function_type` in property conversion: " << attr; + return ::mlir::failure(); + } + } + } + + { + auto &propStorage = prop.res_attrs; + auto attr = dict.get("res_attrs"); + if (attr || /*isRequired=*/false) { + if (!attr) { + emitError() << "expected key entry for res_attrs in DictionaryAttr to set " + "Properties."; + return ::mlir::failure(); + } + auto convertedAttr = ::llvm::dyn_cast>(attr); + if (convertedAttr) { + propStorage = convertedAttr; + } else { + emitError() << "Invalid attribute `res_attrs` in property conversion: " << attr; + return ::mlir::failure(); + } + } + } + + { + auto &propStorage = prop.sym_name; + auto attr = dict.get("sym_name"); + if (attr || /*isRequired=*/true) { + if (!attr) { + emitError() << "expected key entry for sym_name in DictionaryAttr to set " + "Properties."; + return ::mlir::failure(); + } + auto convertedAttr = ::llvm::dyn_cast>(attr); + if (convertedAttr) { + propStorage = convertedAttr; + } else { + emitError() << "Invalid attribute `sym_name` in property conversion: " << attr; + return ::mlir::failure(); + } + } + } + return ::mlir::success(); +} + +::mlir::Attribute FuncOp::getPropertiesAsAttr(::mlir::MLIRContext *ctx, const Properties &prop) { + ::mlir::SmallVector<::mlir::NamedAttribute> attrs; + ::mlir::Builder odsBuilder{ctx}; + + { + const auto &propStorage = prop.arg_attrs; + if (propStorage) + attrs.push_back(odsBuilder.getNamedAttr("arg_attrs", + propStorage)); + } + + { + const auto &propStorage = prop.function_type; + if (propStorage) + attrs.push_back(odsBuilder.getNamedAttr("function_type", + propStorage)); + } + + { + const auto &propStorage = prop.res_attrs; + if (propStorage) + attrs.push_back(odsBuilder.getNamedAttr("res_attrs", + propStorage)); + } + + { + const auto &propStorage = prop.sym_name; + if (propStorage) + attrs.push_back(odsBuilder.getNamedAttr("sym_name", + propStorage)); + } + + if (!attrs.empty()) + return odsBuilder.getDictionaryAttr(attrs); + return {}; +} + +llvm::hash_code FuncOp::computePropertiesHash(const Properties &prop) { + return llvm::hash_combine( + llvm::hash_value(prop.arg_attrs.getAsOpaquePointer()), + llvm::hash_value(prop.function_type.getAsOpaquePointer()), + llvm::hash_value(prop.res_attrs.getAsOpaquePointer()), + llvm::hash_value(prop.sym_name.getAsOpaquePointer())); +} + +std::optional FuncOp::getInherentAttr(::mlir::MLIRContext *ctx, const Properties &prop, llvm::StringRef name) { + if (name == "arg_attrs") + return prop.arg_attrs; + + if (name == "function_type") + return prop.function_type; + + if (name == "res_attrs") + return prop.res_attrs; + + if (name == "sym_name") + return prop.sym_name; + return std::nullopt; +} + +void FuncOp::setInherentAttr(Properties &prop, llvm::StringRef name, mlir::Attribute value) { + if (name == "arg_attrs") { + prop.arg_attrs = ::llvm::dyn_cast_or_null>(value); + return; + } + + if (name == "function_type") { + prop.function_type = ::llvm::dyn_cast_or_null>(value); + return; + } + + if (name == "res_attrs") { + prop.res_attrs = ::llvm::dyn_cast_or_null>(value); + return; + } + + if (name == "sym_name") { + prop.sym_name = ::llvm::dyn_cast_or_null>(value); + return; + } +} + +void FuncOp::populateInherentAttrs(::mlir::MLIRContext *ctx, const Properties &prop, ::mlir::NamedAttrList &attrs) { + if (prop.arg_attrs) attrs.append("arg_attrs", prop.arg_attrs); + + if (prop.function_type) attrs.append("function_type", prop.function_type); + + if (prop.res_attrs) attrs.append("res_attrs", prop.res_attrs); + + if (prop.sym_name) attrs.append("sym_name", prop.sym_name); +} + +::mlir::LogicalResult FuncOp::verifyInherentAttrs(::mlir::OperationName opName, ::mlir::NamedAttrList &attrs, llvm::function_ref<::mlir::InFlightDiagnostic()> emitError) { + { + ::mlir::Attribute attr = attrs.get(getArgAttrsAttrName(opName)); + if (attr && ::mlir::failed(__mlir_ods_local_attr_constraint_Ops3(attr, "arg_attrs", emitError))) + return ::mlir::failure(); + } + + { + ::mlir::Attribute attr = attrs.get(getFunctionTypeAttrName(opName)); + if (attr && ::mlir::failed(__mlir_ods_local_attr_constraint_Ops2(attr, "function_type", emitError))) + return ::mlir::failure(); + } + + { + ::mlir::Attribute attr = attrs.get(getResAttrsAttrName(opName)); + if (attr && ::mlir::failed(__mlir_ods_local_attr_constraint_Ops3(attr, "res_attrs", emitError))) + return ::mlir::failure(); + } + + { + ::mlir::Attribute attr = attrs.get(getSymNameAttrName(opName)); + if (attr && ::mlir::failed(__mlir_ods_local_attr_constraint_Ops1(attr, "sym_name", emitError))) + return ::mlir::failure(); + } + return ::mlir::success(); +} + +::mlir::LogicalResult FuncOp::readProperties(::mlir::DialectBytecodeReader &reader, ::mlir::OperationState &state) { + auto &prop = state.getOrAddProperties(); (void)prop; + if (::mlir::failed(reader.readOptionalAttribute(prop.arg_attrs))) + return ::mlir::failure(); + + if (::mlir::failed(reader.readAttribute(prop.function_type))) + return ::mlir::failure(); + + if (::mlir::failed(reader.readOptionalAttribute(prop.res_attrs))) + return ::mlir::failure(); + + if (::mlir::failed(reader.readAttribute(prop.sym_name))) + return ::mlir::failure(); + return ::mlir::success(); +} + +void FuncOp::writeProperties(::mlir::DialectBytecodeWriter &writer) { + auto &prop = getProperties(); (void)prop; + + writer.writeOptionalAttribute(prop.arg_attrs); + writer.writeAttribute(prop.function_type); + + writer.writeOptionalAttribute(prop.res_attrs); + writer.writeAttribute(prop.sym_name); +} + +::mlir::StringAttr FuncOp::getSymNameAttr() { + return ::llvm::cast<::mlir::StringAttr>(getProperties().sym_name); +} + +::llvm::StringRef FuncOp::getSymName() { + auto attr = getSymNameAttr(); + return attr.getValue(); +} + +::mlir::TypeAttr FuncOp::getFunctionTypeAttr() { + return ::llvm::cast<::mlir::TypeAttr>(getProperties().function_type); +} + +::mlir::FunctionType FuncOp::getFunctionType() { + auto attr = getFunctionTypeAttr(); + return ::llvm::cast<::mlir::FunctionType>(attr.getValue()); +} + +::mlir::ArrayAttr FuncOp::getArgAttrsAttr() { + return ::llvm::dyn_cast_or_null<::mlir::ArrayAttr>(getProperties().arg_attrs); +} + +::std::optional< ::mlir::ArrayAttr > FuncOp::getArgAttrs() { + auto attr = getArgAttrsAttr(); + return attr ? ::std::optional< ::mlir::ArrayAttr >(attr) : (::std::nullopt); +} + +::mlir::ArrayAttr FuncOp::getResAttrsAttr() { + return ::llvm::dyn_cast_or_null<::mlir::ArrayAttr>(getProperties().res_attrs); +} + +::std::optional< ::mlir::ArrayAttr > FuncOp::getResAttrs() { + auto attr = getResAttrsAttr(); + return attr ? ::std::optional< ::mlir::ArrayAttr >(attr) : (::std::nullopt); +} + +void FuncOp::setSymNameAttr(::mlir::StringAttr attr) { + (*this)->setAttr(getSymNameAttrName(), attr); +} + +void FuncOp::setSymName(::llvm::StringRef attrValue) { + (*this)->setAttr(getSymNameAttrName(), ::mlir::Builder((*this)->getContext()).getStringAttr(attrValue)); +} + +void FuncOp::setFunctionTypeAttr(::mlir::TypeAttr attr) { + (*this)->setAttr(getFunctionTypeAttrName(), attr); +} + +void FuncOp::setFunctionType(::mlir::FunctionType attrValue) { + (*this)->setAttr(getFunctionTypeAttrName(), ::mlir::TypeAttr::get(attrValue)); +} + +void FuncOp::setArgAttrsAttr(::mlir::ArrayAttr attr) { + (*this)->setAttr(getArgAttrsAttrName(), attr); +} + +void FuncOp::setResAttrsAttr(::mlir::ArrayAttr attr) { + (*this)->setAttr(getResAttrsAttrName(), attr); +} + +::mlir::Attribute FuncOp::removeArgAttrsAttr() { + auto &attr = getProperties().arg_attrs; + attr = {}; + return attr; +} + +::mlir::Attribute FuncOp::removeResAttrsAttr() { + auto &attr = getProperties().res_attrs; + attr = {}; + return attr; +} + +::mlir::LogicalResult FuncOp::verifyInvariantsImpl() { + auto tblgen_arg_attrs = getProperties().arg_attrs; (void)tblgen_arg_attrs; + auto tblgen_function_type = getProperties().function_type; (void)tblgen_function_type; + if (!tblgen_function_type) return emitOpError("requires attribute 'function_type'"); + auto tblgen_res_attrs = getProperties().res_attrs; (void)tblgen_res_attrs; + auto tblgen_sym_name = getProperties().sym_name; (void)tblgen_sym_name; + if (!tblgen_sym_name) return emitOpError("requires attribute 'sym_name'"); + + if (::mlir::failed(__mlir_ods_local_attr_constraint_Ops1(*this, tblgen_sym_name, "sym_name"))) + return ::mlir::failure(); + + if (::mlir::failed(__mlir_ods_local_attr_constraint_Ops2(*this, tblgen_function_type, "function_type"))) + return ::mlir::failure(); + + if (::mlir::failed(__mlir_ods_local_attr_constraint_Ops3(*this, tblgen_arg_attrs, "arg_attrs"))) + return ::mlir::failure(); + + if (::mlir::failed(__mlir_ods_local_attr_constraint_Ops3(*this, tblgen_res_attrs, "res_attrs"))) + return ::mlir::failure(); + { + unsigned index = 0; (void)index; + + for (auto ®ion : ::llvm::MutableArrayRef((*this)->getRegion(0))) + if (::mlir::failed(__mlir_ods_local_region_constraint_Ops0(*this, region, "body", index++))) + return ::mlir::failure(); + } + return ::mlir::success(); +} + +::mlir::LogicalResult FuncOp::verifyInvariants() { + return verifyInvariantsImpl(); +} + +} // namespace toy +} // namespace mlir +MLIR_DEFINE_EXPLICIT_TYPE_ID(::mlir::toy::FuncOp) + +namespace mlir { +namespace toy { + +//===----------------------------------------------------------------------===// +// ::mlir::toy::GenericCallOp definitions +//===----------------------------------------------------------------------===// + +namespace detail { +GenericCallOpGenericAdaptorBase::GenericCallOpGenericAdaptorBase(::mlir::DictionaryAttr attrs, const Properties &properties, ::mlir::RegionRange regions) : odsAttrs(attrs), properties(properties), odsRegions(regions) { if (odsAttrs) + odsOpName.emplace("toy.generic_call", odsAttrs.getContext()); +} + +GenericCallOpGenericAdaptorBase::GenericCallOpGenericAdaptorBase(GenericCallOp op) : GenericCallOpGenericAdaptorBase(op->getDiscardableAttrDictionary(), op.getProperties(), op->getRegions()) {} + +std::pair GenericCallOpGenericAdaptorBase::getODSOperandIndexAndLength(unsigned index, unsigned odsOperandsSize) { + bool isVariadic[] = {true}; + int prevVariadicCount = 0; + for (unsigned i = 0; i < index; ++i) + if (isVariadic[i]) ++prevVariadicCount; + + // Calculate how many dynamic values a static variadic operand corresponds to. + // This assumes all static variadic operands have the same dynamic value count. + int variadicSize = (odsOperandsSize - 0) / 1; + // `index` passed in as the parameter is the static index which counts each + // operand (variadic or not) as size 1. So here for each previous static variadic + // operand, we need to offset by (variadicSize - 1) to get where the dynamic + // value pack for this static operand starts. + int start = index + (variadicSize - 1) * prevVariadicCount; + int size = isVariadic[index] ? variadicSize : 1; + return {start, size}; +} + +::mlir::DictionaryAttr GenericCallOpGenericAdaptorBase::getAttributes() { + return odsAttrs; +} + +::mlir::FlatSymbolRefAttr GenericCallOpGenericAdaptorBase::getCalleeAttr() { + auto attr = ::llvm::cast<::mlir::FlatSymbolRefAttr>(getProperties().callee); + return attr; +} + +::llvm::StringRef GenericCallOpGenericAdaptorBase::getCallee() { + auto attr = getCalleeAttr(); + return attr.getValue(); +} + +} // namespace detail +GenericCallOpAdaptor::GenericCallOpAdaptor(GenericCallOp op) : GenericCallOpGenericAdaptor(op->getOperands(), op) {} + +::mlir::LogicalResult GenericCallOpAdaptor::verify(::mlir::Location loc) { + auto tblgen_callee = getProperties().callee; (void)tblgen_callee; + if (!tblgen_callee) return emitError(loc, "'toy.generic_call' op ""requires attribute 'callee'"); + + if (tblgen_callee && !((::llvm::isa<::mlir::FlatSymbolRefAttr>(tblgen_callee)))) + return emitError(loc, "'toy.generic_call' op ""attribute 'callee' failed to satisfy constraint: flat symbol reference attribute"); + return ::mlir::success(); +} + +std::pair GenericCallOp::getODSOperandIndexAndLength(unsigned index) { + bool isVariadic[] = {true}; + int prevVariadicCount = 0; + for (unsigned i = 0; i < index; ++i) + if (isVariadic[i]) ++prevVariadicCount; + + // Calculate how many dynamic values a static variadic operand corresponds to. + // This assumes all static variadic operands have the same dynamic value count. + int variadicSize = (getOperation()->getNumOperands() - 0) / 1; + // `index` passed in as the parameter is the static index which counts each + // operand (variadic or not) as size 1. So here for each previous static variadic + // operand, we need to offset by (variadicSize - 1) to get where the dynamic + // value pack for this static operand starts. + int start = index + (variadicSize - 1) * prevVariadicCount; + int size = isVariadic[index] ? variadicSize : 1; + return {start, size}; +} + +::mlir::Operation::operand_range GenericCallOp::getODSOperands(unsigned index) { + auto valueRange = getODSOperandIndexAndLength(index); + return {std::next(getOperation()->operand_begin(), valueRange.first), + std::next(getOperation()->operand_begin(), valueRange.first + valueRange.second)}; +} + +::mlir::Operation::operand_range GenericCallOp::getInputs() { + return getODSOperands(0); +} + +::mlir::MutableOperandRange GenericCallOp::getInputsMutable() { + auto range = getODSOperandIndexAndLength(0); + auto mutableRange = ::mlir::MutableOperandRange(getOperation(), range.first, range.second); + return mutableRange; +} + +std::pair GenericCallOp::getODSResultIndexAndLength(unsigned index) { + return {index, 1}; +} + +::mlir::Operation::result_range GenericCallOp::getODSResults(unsigned index) { + auto valueRange = getODSResultIndexAndLength(index); + return {std::next(getOperation()->result_begin(), valueRange.first), + std::next(getOperation()->result_begin(), valueRange.first + valueRange.second)}; +} + +::mlir::LogicalResult GenericCallOp::setPropertiesFromAttr(Properties &prop, ::mlir::Attribute attr, ::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError) { + ::mlir::DictionaryAttr dict = ::llvm::dyn_cast<::mlir::DictionaryAttr>(attr); + if (!dict) { + emitError() << "expected DictionaryAttr to set properties"; + return ::mlir::failure(); + } + + { + auto &propStorage = prop.callee; + auto attr = dict.get("callee"); + if (attr || /*isRequired=*/true) { + if (!attr) { + emitError() << "expected key entry for callee in DictionaryAttr to set " + "Properties."; + return ::mlir::failure(); + } + auto convertedAttr = ::llvm::dyn_cast>(attr); + if (convertedAttr) { + propStorage = convertedAttr; + } else { + emitError() << "Invalid attribute `callee` in property conversion: " << attr; + return ::mlir::failure(); + } + } + } + return ::mlir::success(); +} + +::mlir::Attribute GenericCallOp::getPropertiesAsAttr(::mlir::MLIRContext *ctx, const Properties &prop) { + ::mlir::SmallVector<::mlir::NamedAttribute> attrs; + ::mlir::Builder odsBuilder{ctx}; + + { + const auto &propStorage = prop.callee; + if (propStorage) + attrs.push_back(odsBuilder.getNamedAttr("callee", + propStorage)); + } + + if (!attrs.empty()) + return odsBuilder.getDictionaryAttr(attrs); + return {}; +} + +llvm::hash_code GenericCallOp::computePropertiesHash(const Properties &prop) { + return llvm::hash_combine( + llvm::hash_value(prop.callee.getAsOpaquePointer())); +} + +std::optional GenericCallOp::getInherentAttr(::mlir::MLIRContext *ctx, const Properties &prop, llvm::StringRef name) { + if (name == "callee") + return prop.callee; + return std::nullopt; +} + +void GenericCallOp::setInherentAttr(Properties &prop, llvm::StringRef name, mlir::Attribute value) { + if (name == "callee") { + prop.callee = ::llvm::dyn_cast_or_null>(value); + return; + } +} + +void GenericCallOp::populateInherentAttrs(::mlir::MLIRContext *ctx, const Properties &prop, ::mlir::NamedAttrList &attrs) { + if (prop.callee) attrs.append("callee", prop.callee); +} + +::mlir::LogicalResult GenericCallOp::verifyInherentAttrs(::mlir::OperationName opName, ::mlir::NamedAttrList &attrs, llvm::function_ref<::mlir::InFlightDiagnostic()> emitError) { + { + ::mlir::Attribute attr = attrs.get(getCalleeAttrName(opName)); + if (attr && ::mlir::failed(__mlir_ods_local_attr_constraint_Ops4(attr, "callee", emitError))) + return ::mlir::failure(); + } + return ::mlir::success(); +} + +::mlir::LogicalResult GenericCallOp::readProperties(::mlir::DialectBytecodeReader &reader, ::mlir::OperationState &state) { + auto &prop = state.getOrAddProperties(); (void)prop; + if (::mlir::failed(reader.readAttribute(prop.callee))) + return ::mlir::failure(); + return ::mlir::success(); +} + +void GenericCallOp::writeProperties(::mlir::DialectBytecodeWriter &writer) { + auto &prop = getProperties(); (void)prop; + writer.writeAttribute(prop.callee); +} + +::mlir::FlatSymbolRefAttr GenericCallOp::getCalleeAttr() { + return ::llvm::cast<::mlir::FlatSymbolRefAttr>(getProperties().callee); +} + +::llvm::StringRef GenericCallOp::getCallee() { + auto attr = getCalleeAttr(); + return attr.getValue(); +} + +void GenericCallOp::setCalleeAttr(::mlir::FlatSymbolRefAttr attr) { + (*this)->setAttr(getCalleeAttrName(), attr); +} + +void GenericCallOp::setCallee(::llvm::StringRef attrValue) { + (*this)->setAttr(getCalleeAttrName(), ::mlir::SymbolRefAttr::get(::mlir::Builder((*this)->getContext()).getContext(), attrValue)); +} + +void GenericCallOp::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type resultType0, ::mlir::FlatSymbolRefAttr callee, ::mlir::ValueRange inputs) { + odsState.addOperands(inputs); + odsState.getOrAddProperties().callee = callee; + odsState.addTypes(resultType0); +} + +void GenericCallOp::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::FlatSymbolRefAttr callee, ::mlir::ValueRange inputs) { + odsState.addOperands(inputs); + odsState.getOrAddProperties().callee = callee; + assert(resultTypes.size() == 1u && "mismatched number of results"); + odsState.addTypes(resultTypes); +} + +void GenericCallOp::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type resultType0, ::llvm::StringRef callee, ::mlir::ValueRange inputs) { + odsState.addOperands(inputs); + odsState.getOrAddProperties().callee = ::mlir::SymbolRefAttr::get(odsBuilder.getContext(), callee); + odsState.addTypes(resultType0); +} + +void GenericCallOp::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::llvm::StringRef callee, ::mlir::ValueRange inputs) { + odsState.addOperands(inputs); + odsState.getOrAddProperties().callee = ::mlir::SymbolRefAttr::get(odsBuilder.getContext(), callee); + assert(resultTypes.size() == 1u && "mismatched number of results"); + odsState.addTypes(resultTypes); +} + +void GenericCallOp::build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes) { + odsState.addOperands(operands); + odsState.addAttributes(attributes); + assert(resultTypes.size() == 1u && "mismatched number of return types"); + odsState.addTypes(resultTypes); +} + +::mlir::LogicalResult GenericCallOp::verifyInvariantsImpl() { + auto tblgen_callee = getProperties().callee; (void)tblgen_callee; + if (!tblgen_callee) return emitOpError("requires attribute 'callee'"); + + if (::mlir::failed(__mlir_ods_local_attr_constraint_Ops4(*this, tblgen_callee, "callee"))) + return ::mlir::failure(); + { + unsigned index = 0; (void)index; + auto valueGroup0 = getODSOperands(0); + + for (auto v : valueGroup0) { + if (::mlir::failed(__mlir_ods_local_type_constraint_Ops1(*this, v.getType(), "operand", index++))) + return ::mlir::failure(); + } + } + { + unsigned index = 0; (void)index; + auto valueGroup0 = getODSResults(0); + + for (auto v : valueGroup0) { + if (::mlir::failed(__mlir_ods_local_type_constraint_Ops0(*this, v.getType(), "result", index++))) + return ::mlir::failure(); + } + } + return ::mlir::success(); +} + +::mlir::LogicalResult GenericCallOp::verifyInvariants() { + return verifyInvariantsImpl(); +} + +::mlir::ParseResult GenericCallOp::parse(::mlir::OpAsmParser &parser, ::mlir::OperationState &result) { + ::mlir::FlatSymbolRefAttr calleeAttr; + ::llvm::SmallVector<::mlir::OpAsmParser::UnresolvedOperand, 4> inputsOperands; + ::llvm::SMLoc inputsOperandsLoc; + (void)inputsOperandsLoc; + ::llvm::ArrayRef<::mlir::Type> inputsTypes; + ::llvm::ArrayRef<::mlir::Type> allResultTypes; + + if (parser.parseCustomAttributeWithFallback(calleeAttr, parser.getBuilder().getType<::mlir::NoneType>())) { + return ::mlir::failure(); + } + if (calleeAttr) result.getOrAddProperties().callee = calleeAttr; + if (parser.parseLParen()) + return ::mlir::failure(); + + inputsOperandsLoc = parser.getCurrentLocation(); + if (parser.parseOperandList(inputsOperands)) + return ::mlir::failure(); + if (parser.parseRParen()) + return ::mlir::failure(); + { + auto loc = parser.getCurrentLocation();(void)loc; + if (parser.parseOptionalAttrDict(result.attributes)) + return ::mlir::failure(); + if (failed(verifyInherentAttrs(result.name, result.attributes, [&]() { + return parser.emitError(loc) << "'" << result.name.getStringRef() << "' op "; + }))) + return ::mlir::failure(); + } + if (parser.parseColon()) + return ::mlir::failure(); + + ::mlir::FunctionType inputs__allResult_functionType; + if (parser.parseType(inputs__allResult_functionType)) + return ::mlir::failure(); + inputsTypes = inputs__allResult_functionType.getInputs(); + allResultTypes = inputs__allResult_functionType.getResults(); + result.addTypes(allResultTypes); + if (parser.resolveOperands(inputsOperands, inputsTypes, inputsOperandsLoc, result.operands)) + return ::mlir::failure(); + return ::mlir::success(); +} + +void GenericCallOp::print(::mlir::OpAsmPrinter &_odsPrinter) { + _odsPrinter << ' '; + _odsPrinter.printAttributeWithoutType(getCalleeAttr()); + _odsPrinter << "("; + _odsPrinter << getInputs(); + _odsPrinter << ")"; + ::llvm::SmallVector<::llvm::StringRef, 2> elidedAttrs; + elidedAttrs.push_back("callee"); + _odsPrinter.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs); + _odsPrinter << ' ' << ":"; + _odsPrinter << ' '; + _odsPrinter.printFunctionalType(getInputs().getTypes(), getOperation()->getResultTypes()); +} + +} // namespace toy +} // namespace mlir +MLIR_DEFINE_EXPLICIT_TYPE_ID(::mlir::toy::GenericCallOp) + +namespace mlir { +namespace toy { + +//===----------------------------------------------------------------------===// +// ::mlir::toy::MulOp definitions +//===----------------------------------------------------------------------===// + +namespace detail { +MulOpGenericAdaptorBase::MulOpGenericAdaptorBase(::mlir::DictionaryAttr attrs, const ::mlir::EmptyProperties &properties, ::mlir::RegionRange regions) : odsAttrs(attrs), odsRegions(regions) { if (odsAttrs) + odsOpName.emplace("toy.mul", odsAttrs.getContext()); +} + +MulOpGenericAdaptorBase::MulOpGenericAdaptorBase(MulOp op) : MulOpGenericAdaptorBase(op->getAttrDictionary(), op.getProperties(), op->getRegions()) {} + +std::pair MulOpGenericAdaptorBase::getODSOperandIndexAndLength(unsigned index, unsigned odsOperandsSize) { + return {index, 1}; +} + +::mlir::DictionaryAttr MulOpGenericAdaptorBase::getAttributes() { + return odsAttrs; +} + +} // namespace detail +MulOpAdaptor::MulOpAdaptor(MulOp op) : MulOpGenericAdaptor(op->getOperands(), op) {} + +::mlir::LogicalResult MulOpAdaptor::verify(::mlir::Location loc) { + return ::mlir::success(); +} + +std::pair MulOp::getODSOperandIndexAndLength(unsigned index) { + return {index, 1}; +} + +::mlir::Operation::operand_range MulOp::getODSOperands(unsigned index) { + auto valueRange = getODSOperandIndexAndLength(index); + return {std::next(getOperation()->operand_begin(), valueRange.first), + std::next(getOperation()->operand_begin(), valueRange.first + valueRange.second)}; +} + +::mlir::TypedValue<::mlir::TensorType> MulOp::getLhs() { + return ::llvm::cast<::mlir::TypedValue<::mlir::TensorType>>(*getODSOperands(0).begin()); +} + +::mlir::TypedValue<::mlir::TensorType> MulOp::getRhs() { + return ::llvm::cast<::mlir::TypedValue<::mlir::TensorType>>(*getODSOperands(1).begin()); +} + +::mlir::OpOperand &MulOp::getLhsMutable() { + auto range = getODSOperandIndexAndLength(0); + return getOperation()->getOpOperand(range.first); +} + +::mlir::OpOperand &MulOp::getRhsMutable() { + auto range = getODSOperandIndexAndLength(1); + return getOperation()->getOpOperand(range.first); +} + +std::pair MulOp::getODSResultIndexAndLength(unsigned index) { + return {index, 1}; +} + +::mlir::Operation::result_range MulOp::getODSResults(unsigned index) { + auto valueRange = getODSResultIndexAndLength(index); + return {std::next(getOperation()->result_begin(), valueRange.first), + std::next(getOperation()->result_begin(), valueRange.first + valueRange.second)}; +} + +void MulOp::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type resultType0, ::mlir::Value lhs, ::mlir::Value rhs) { + odsState.addOperands(lhs); + odsState.addOperands(rhs); + odsState.addTypes(resultType0); +} + +void MulOp::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::Value lhs, ::mlir::Value rhs) { + odsState.addOperands(lhs); + odsState.addOperands(rhs); + assert(resultTypes.size() == 1u && "mismatched number of results"); + odsState.addTypes(resultTypes); +} + +void MulOp::build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes) { + assert(operands.size() == 2u && "mismatched number of parameters"); + odsState.addOperands(operands); + odsState.addAttributes(attributes); + assert(resultTypes.size() == 1u && "mismatched number of return types"); + odsState.addTypes(resultTypes); +} + +::mlir::LogicalResult MulOp::verifyInvariantsImpl() { + { + unsigned index = 0; (void)index; + auto valueGroup0 = getODSOperands(0); + + for (auto v : valueGroup0) { + if (::mlir::failed(__mlir_ods_local_type_constraint_Ops0(*this, v.getType(), "operand", index++))) + return ::mlir::failure(); + } + auto valueGroup1 = getODSOperands(1); + + for (auto v : valueGroup1) { + if (::mlir::failed(__mlir_ods_local_type_constraint_Ops0(*this, v.getType(), "operand", index++))) + return ::mlir::failure(); + } + } + { + unsigned index = 0; (void)index; + auto valueGroup0 = getODSResults(0); + + for (auto v : valueGroup0) { + if (::mlir::failed(__mlir_ods_local_type_constraint_Ops0(*this, v.getType(), "result", index++))) + return ::mlir::failure(); + } + } + return ::mlir::success(); +} + +::mlir::LogicalResult MulOp::verifyInvariants() { + return verifyInvariantsImpl(); +} + +void MulOp::getEffects(::llvm::SmallVectorImpl<::mlir::SideEffects::EffectInstance<::mlir::MemoryEffects::Effect>> &effects) { +} + +} // namespace toy +} // namespace mlir +MLIR_DEFINE_EXPLICIT_TYPE_ID(::mlir::toy::MulOp) + +namespace mlir { +namespace toy { + +//===----------------------------------------------------------------------===// +// ::mlir::toy::PrintOp definitions +//===----------------------------------------------------------------------===// + +namespace detail { +PrintOpGenericAdaptorBase::PrintOpGenericAdaptorBase(::mlir::DictionaryAttr attrs, const ::mlir::EmptyProperties &properties, ::mlir::RegionRange regions) : odsAttrs(attrs), odsRegions(regions) { if (odsAttrs) + odsOpName.emplace("toy.print", odsAttrs.getContext()); +} + +PrintOpGenericAdaptorBase::PrintOpGenericAdaptorBase(PrintOp op) : PrintOpGenericAdaptorBase(op->getAttrDictionary(), op.getProperties(), op->getRegions()) {} + +std::pair PrintOpGenericAdaptorBase::getODSOperandIndexAndLength(unsigned index, unsigned odsOperandsSize) { + return {index, 1}; +} + +::mlir::DictionaryAttr PrintOpGenericAdaptorBase::getAttributes() { + return odsAttrs; +} + +} // namespace detail +PrintOpAdaptor::PrintOpAdaptor(PrintOp op) : PrintOpGenericAdaptor(op->getOperands(), op) {} + +::mlir::LogicalResult PrintOpAdaptor::verify(::mlir::Location loc) { + return ::mlir::success(); +} + +std::pair PrintOp::getODSOperandIndexAndLength(unsigned index) { + return {index, 1}; +} + +::mlir::Operation::operand_range PrintOp::getODSOperands(unsigned index) { + auto valueRange = getODSOperandIndexAndLength(index); + return {std::next(getOperation()->operand_begin(), valueRange.first), + std::next(getOperation()->operand_begin(), valueRange.first + valueRange.second)}; +} + +::mlir::Value PrintOp::getInput() { + return ::llvm::cast<::mlir::Value>(*getODSOperands(0).begin()); +} + +::mlir::OpOperand &PrintOp::getInputMutable() { + auto range = getODSOperandIndexAndLength(0); + return getOperation()->getOpOperand(range.first); +} + +std::pair PrintOp::getODSResultIndexAndLength(unsigned index) { + return {index, 1}; +} + +::mlir::Operation::result_range PrintOp::getODSResults(unsigned index) { + auto valueRange = getODSResultIndexAndLength(index); + return {std::next(getOperation()->result_begin(), valueRange.first), + std::next(getOperation()->result_begin(), valueRange.first + valueRange.second)}; +} + +void PrintOp::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Value input) { + odsState.addOperands(input); +} + +void PrintOp::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::Value input) { + odsState.addOperands(input); + assert(resultTypes.size() == 0u && "mismatched number of results"); + odsState.addTypes(resultTypes); +} + +void PrintOp::build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes) { + assert(operands.size() == 1u && "mismatched number of parameters"); + odsState.addOperands(operands); + odsState.addAttributes(attributes); + assert(resultTypes.size() == 0u && "mismatched number of return types"); + odsState.addTypes(resultTypes); +} + +::mlir::LogicalResult PrintOp::verifyInvariantsImpl() { + { + unsigned index = 0; (void)index; + auto valueGroup0 = getODSOperands(0); + + for (auto v : valueGroup0) { + if (::mlir::failed(__mlir_ods_local_type_constraint_Ops2(*this, v.getType(), "operand", index++))) + return ::mlir::failure(); + } + } + return ::mlir::success(); +} + +::mlir::LogicalResult PrintOp::verifyInvariants() { + return verifyInvariantsImpl(); +} + +::mlir::ParseResult PrintOp::parse(::mlir::OpAsmParser &parser, ::mlir::OperationState &result) { + ::mlir::OpAsmParser::UnresolvedOperand inputRawOperands[1]; + ::llvm::ArrayRef<::mlir::OpAsmParser::UnresolvedOperand> inputOperands(inputRawOperands); ::llvm::SMLoc inputOperandsLoc; + (void)inputOperandsLoc; + ::mlir::Type inputRawTypes[1]; + ::llvm::ArrayRef<::mlir::Type> inputTypes(inputRawTypes); + + inputOperandsLoc = parser.getCurrentLocation(); + if (parser.parseOperand(inputRawOperands[0])) + return ::mlir::failure(); + { + auto loc = parser.getCurrentLocation();(void)loc; + if (parser.parseOptionalAttrDict(result.attributes)) + return ::mlir::failure(); + } + if (parser.parseColon()) + return ::mlir::failure(); + + { + ::mlir::Type type; + if (parser.parseCustomTypeWithFallback(type)) + return ::mlir::failure(); + inputRawTypes[0] = type; + } + if (parser.resolveOperands(inputOperands, inputTypes, inputOperandsLoc, result.operands)) + return ::mlir::failure(); + return ::mlir::success(); +} + +void PrintOp::print(::mlir::OpAsmPrinter &_odsPrinter) { + _odsPrinter << ' '; + _odsPrinter << getInput(); + ::llvm::SmallVector<::llvm::StringRef, 2> elidedAttrs; + _odsPrinter.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs); + _odsPrinter << ' ' << ":"; + _odsPrinter << ' '; + { + auto type = getInput().getType(); + if (auto validType = ::llvm::dyn_cast<::mlir::Type>(type)) + _odsPrinter.printStrippedAttrOrType(validType); + else + _odsPrinter << type; + } +} + +} // namespace toy +} // namespace mlir +MLIR_DEFINE_EXPLICIT_TYPE_ID(::mlir::toy::PrintOp) + +namespace mlir { +namespace toy { + +//===----------------------------------------------------------------------===// +// ::mlir::toy::ReshapeOp definitions +//===----------------------------------------------------------------------===// + +namespace detail { +ReshapeOpGenericAdaptorBase::ReshapeOpGenericAdaptorBase(::mlir::DictionaryAttr attrs, const ::mlir::EmptyProperties &properties, ::mlir::RegionRange regions) : odsAttrs(attrs), odsRegions(regions) { if (odsAttrs) + odsOpName.emplace("toy.reshape", odsAttrs.getContext()); +} + +ReshapeOpGenericAdaptorBase::ReshapeOpGenericAdaptorBase(ReshapeOp op) : ReshapeOpGenericAdaptorBase(op->getAttrDictionary(), op.getProperties(), op->getRegions()) {} + +std::pair ReshapeOpGenericAdaptorBase::getODSOperandIndexAndLength(unsigned index, unsigned odsOperandsSize) { + return {index, 1}; +} + +::mlir::DictionaryAttr ReshapeOpGenericAdaptorBase::getAttributes() { + return odsAttrs; +} + +} // namespace detail +ReshapeOpAdaptor::ReshapeOpAdaptor(ReshapeOp op) : ReshapeOpGenericAdaptor(op->getOperands(), op) {} + +::mlir::LogicalResult ReshapeOpAdaptor::verify(::mlir::Location loc) { + return ::mlir::success(); +} + +std::pair ReshapeOp::getODSOperandIndexAndLength(unsigned index) { + return {index, 1}; +} + +::mlir::Operation::operand_range ReshapeOp::getODSOperands(unsigned index) { + auto valueRange = getODSOperandIndexAndLength(index); + return {std::next(getOperation()->operand_begin(), valueRange.first), + std::next(getOperation()->operand_begin(), valueRange.first + valueRange.second)}; +} + +::mlir::TypedValue<::mlir::TensorType> ReshapeOp::getInput() { + return ::llvm::cast<::mlir::TypedValue<::mlir::TensorType>>(*getODSOperands(0).begin()); +} + +::mlir::OpOperand &ReshapeOp::getInputMutable() { + auto range = getODSOperandIndexAndLength(0); + return getOperation()->getOpOperand(range.first); +} + +std::pair ReshapeOp::getODSResultIndexAndLength(unsigned index) { + return {index, 1}; +} + +::mlir::Operation::result_range ReshapeOp::getODSResults(unsigned index) { + auto valueRange = getODSResultIndexAndLength(index); + return {std::next(getOperation()->result_begin(), valueRange.first), + std::next(getOperation()->result_begin(), valueRange.first + valueRange.second)}; +} + +void ReshapeOp::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type resultType0, ::mlir::Value input) { + odsState.addOperands(input); + odsState.addTypes(resultType0); +} + +void ReshapeOp::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::Value input) { + odsState.addOperands(input); + assert(resultTypes.size() == 1u && "mismatched number of results"); + odsState.addTypes(resultTypes); +} + +void ReshapeOp::build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes) { + assert(operands.size() == 1u && "mismatched number of parameters"); + odsState.addOperands(operands); + odsState.addAttributes(attributes); + assert(resultTypes.size() == 1u && "mismatched number of return types"); + odsState.addTypes(resultTypes); +} + +::mlir::LogicalResult ReshapeOp::verifyInvariantsImpl() { + { + unsigned index = 0; (void)index; + auto valueGroup0 = getODSOperands(0); + + for (auto v : valueGroup0) { + if (::mlir::failed(__mlir_ods_local_type_constraint_Ops0(*this, v.getType(), "operand", index++))) + return ::mlir::failure(); + } + } + { + unsigned index = 0; (void)index; + auto valueGroup0 = getODSResults(0); + + for (auto v : valueGroup0) { + if (::mlir::failed(__mlir_ods_local_type_constraint_Ops3(*this, v.getType(), "result", index++))) + return ::mlir::failure(); + } + } + return ::mlir::success(); +} + +::mlir::LogicalResult ReshapeOp::verifyInvariants() { + return verifyInvariantsImpl(); +} + +::mlir::ParseResult ReshapeOp::parse(::mlir::OpAsmParser &parser, ::mlir::OperationState &result) { + ::mlir::OpAsmParser::UnresolvedOperand inputRawOperands[1]; + ::llvm::ArrayRef<::mlir::OpAsmParser::UnresolvedOperand> inputOperands(inputRawOperands); ::llvm::SMLoc inputOperandsLoc; + (void)inputOperandsLoc; + ::mlir::Type inputRawTypes[1]; + ::llvm::ArrayRef<::mlir::Type> inputTypes(inputRawTypes); + ::llvm::SmallVector<::mlir::Type, 1> allResultTypes; + if (parser.parseLParen()) + return ::mlir::failure(); + + inputOperandsLoc = parser.getCurrentLocation(); + if (parser.parseOperand(inputRawOperands[0])) + return ::mlir::failure(); + if (parser.parseColon()) + return ::mlir::failure(); + + { + ::mlir::TensorType type; + if (parser.parseCustomTypeWithFallback(type)) + return ::mlir::failure(); + inputRawTypes[0] = type; + } + if (parser.parseRParen()) + return ::mlir::failure(); + { + auto loc = parser.getCurrentLocation();(void)loc; + if (parser.parseOptionalAttrDict(result.attributes)) + return ::mlir::failure(); + } + if (parser.parseKeyword("to")) + return ::mlir::failure(); + + if (parser.parseTypeList(allResultTypes)) + return ::mlir::failure(); + result.addTypes(allResultTypes); + if (parser.resolveOperands(inputOperands, inputTypes, inputOperandsLoc, result.operands)) + return ::mlir::failure(); + return ::mlir::success(); +} + +void ReshapeOp::print(::mlir::OpAsmPrinter &_odsPrinter) { + _odsPrinter << "("; + _odsPrinter << getInput(); + _odsPrinter << ' ' << ":"; + _odsPrinter << ' '; + { + auto type = getInput().getType(); + if (auto validType = ::llvm::dyn_cast<::mlir::TensorType>(type)) + _odsPrinter.printStrippedAttrOrType(validType); + else + _odsPrinter << type; + } + _odsPrinter << ")"; + ::llvm::SmallVector<::llvm::StringRef, 2> elidedAttrs; + _odsPrinter.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs); + _odsPrinter << ' ' << "to"; + _odsPrinter << ' '; + _odsPrinter << getOperation()->getResultTypes(); +} + +void ReshapeOp::getEffects(::llvm::SmallVectorImpl<::mlir::SideEffects::EffectInstance<::mlir::MemoryEffects::Effect>> &effects) { +} + +} // namespace toy +} // namespace mlir +MLIR_DEFINE_EXPLICIT_TYPE_ID(::mlir::toy::ReshapeOp) + +namespace mlir { +namespace toy { + +//===----------------------------------------------------------------------===// +// ::mlir::toy::ReturnOp definitions +//===----------------------------------------------------------------------===// + +namespace detail { +ReturnOpGenericAdaptorBase::ReturnOpGenericAdaptorBase(::mlir::DictionaryAttr attrs, const ::mlir::EmptyProperties &properties, ::mlir::RegionRange regions) : odsAttrs(attrs), odsRegions(regions) { if (odsAttrs) + odsOpName.emplace("toy.return", odsAttrs.getContext()); +} + +ReturnOpGenericAdaptorBase::ReturnOpGenericAdaptorBase(ReturnOp op) : ReturnOpGenericAdaptorBase(op->getAttrDictionary(), op.getProperties(), op->getRegions()) {} + +std::pair ReturnOpGenericAdaptorBase::getODSOperandIndexAndLength(unsigned index, unsigned odsOperandsSize) { + bool isVariadic[] = {true}; + int prevVariadicCount = 0; + for (unsigned i = 0; i < index; ++i) + if (isVariadic[i]) ++prevVariadicCount; + + // Calculate how many dynamic values a static variadic operand corresponds to. + // This assumes all static variadic operands have the same dynamic value count. + int variadicSize = (odsOperandsSize - 0) / 1; + // `index` passed in as the parameter is the static index which counts each + // operand (variadic or not) as size 1. So here for each previous static variadic + // operand, we need to offset by (variadicSize - 1) to get where the dynamic + // value pack for this static operand starts. + int start = index + (variadicSize - 1) * prevVariadicCount; + int size = isVariadic[index] ? variadicSize : 1; + return {start, size}; +} + +::mlir::DictionaryAttr ReturnOpGenericAdaptorBase::getAttributes() { + return odsAttrs; +} + +} // namespace detail +ReturnOpAdaptor::ReturnOpAdaptor(ReturnOp op) : ReturnOpGenericAdaptor(op->getOperands(), op) {} + +::mlir::LogicalResult ReturnOpAdaptor::verify(::mlir::Location loc) { + return ::mlir::success(); +} + +std::pair ReturnOp::getODSOperandIndexAndLength(unsigned index) { + bool isVariadic[] = {true}; + int prevVariadicCount = 0; + for (unsigned i = 0; i < index; ++i) + if (isVariadic[i]) ++prevVariadicCount; + + // Calculate how many dynamic values a static variadic operand corresponds to. + // This assumes all static variadic operands have the same dynamic value count. + int variadicSize = (getOperation()->getNumOperands() - 0) / 1; + // `index` passed in as the parameter is the static index which counts each + // operand (variadic or not) as size 1. So here for each previous static variadic + // operand, we need to offset by (variadicSize - 1) to get where the dynamic + // value pack for this static operand starts. + int start = index + (variadicSize - 1) * prevVariadicCount; + int size = isVariadic[index] ? variadicSize : 1; + return {start, size}; +} + +::mlir::Operation::operand_range ReturnOp::getODSOperands(unsigned index) { + auto valueRange = getODSOperandIndexAndLength(index); + return {std::next(getOperation()->operand_begin(), valueRange.first), + std::next(getOperation()->operand_begin(), valueRange.first + valueRange.second)}; +} + +::mlir::Operation::operand_range ReturnOp::getInput() { + return getODSOperands(0); +} + +::mlir::MutableOperandRange ReturnOp::getInputMutable() { + auto range = getODSOperandIndexAndLength(0); + auto mutableRange = ::mlir::MutableOperandRange(getOperation(), range.first, range.second); + return mutableRange; +} + +std::pair ReturnOp::getODSResultIndexAndLength(unsigned index) { + return {index, 1}; +} + +::mlir::Operation::result_range ReturnOp::getODSResults(unsigned index) { + auto valueRange = getODSResultIndexAndLength(index); + return {std::next(getOperation()->result_begin(), valueRange.first), + std::next(getOperation()->result_begin(), valueRange.first + valueRange.second)}; +} + +void ReturnOp::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState) { + build(odsBuilder, odsState, std::nullopt); +} + +void ReturnOp::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::ValueRange input) { + odsState.addOperands(input); +} + +void ReturnOp::build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes) { + odsState.addOperands(operands); + odsState.addAttributes(attributes); + assert(resultTypes.size() == 0u && "mismatched number of return types"); + odsState.addTypes(resultTypes); +} + +::mlir::LogicalResult ReturnOp::verifyInvariantsImpl() { + { + unsigned index = 0; (void)index; + auto valueGroup0 = getODSOperands(0); + + for (auto v : valueGroup0) { + if (::mlir::failed(__mlir_ods_local_type_constraint_Ops1(*this, v.getType(), "operand", index++))) + return ::mlir::failure(); + } + } + return ::mlir::success(); +} + +::mlir::LogicalResult ReturnOp::verifyInvariants() { + if(::mlir::succeeded(verifyInvariantsImpl()) && ::mlir::succeeded(verify())) + return ::mlir::success(); + return ::mlir::failure(); +} + +::mlir::ParseResult ReturnOp::parse(::mlir::OpAsmParser &parser, ::mlir::OperationState &result) { + ::llvm::SmallVector<::mlir::OpAsmParser::UnresolvedOperand, 4> inputOperands; + ::llvm::SMLoc inputOperandsLoc; + (void)inputOperandsLoc; + ::llvm::SmallVector<::mlir::Type, 1> inputTypes; + + inputOperandsLoc = parser.getCurrentLocation(); + if (parser.parseOperandList(inputOperands)) + return ::mlir::failure(); + if (!inputOperands.empty()) { + if (parser.parseColon()) + return ::mlir::failure(); + + if (parser.parseTypeList(inputTypes)) + return ::mlir::failure(); + } + { + auto loc = parser.getCurrentLocation();(void)loc; + if (parser.parseOptionalAttrDict(result.attributes)) + return ::mlir::failure(); + } + if (parser.resolveOperands(inputOperands, inputTypes, inputOperandsLoc, result.operands)) + return ::mlir::failure(); + return ::mlir::success(); +} + +void ReturnOp::print(::mlir::OpAsmPrinter &_odsPrinter) { + if (!getInput().empty()) { + _odsPrinter << ' '; + _odsPrinter << getInput(); + _odsPrinter << ' ' << ":"; + _odsPrinter << ' '; + _odsPrinter << getInput().getTypes(); + } + ::llvm::SmallVector<::llvm::StringRef, 2> elidedAttrs; + _odsPrinter.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs); +} + +void ReturnOp::getEffects(::llvm::SmallVectorImpl<::mlir::SideEffects::EffectInstance<::mlir::MemoryEffects::Effect>> &effects) { +} + +} // namespace toy +} // namespace mlir +MLIR_DEFINE_EXPLICIT_TYPE_ID(::mlir::toy::ReturnOp) + +namespace mlir { +namespace toy { + +//===----------------------------------------------------------------------===// +// ::mlir::toy::TransposeOp definitions +//===----------------------------------------------------------------------===// + +namespace detail { +TransposeOpGenericAdaptorBase::TransposeOpGenericAdaptorBase(::mlir::DictionaryAttr attrs, const ::mlir::EmptyProperties &properties, ::mlir::RegionRange regions) : odsAttrs(attrs), odsRegions(regions) { if (odsAttrs) + odsOpName.emplace("toy.transpose", odsAttrs.getContext()); +} + +TransposeOpGenericAdaptorBase::TransposeOpGenericAdaptorBase(TransposeOp op) : TransposeOpGenericAdaptorBase(op->getAttrDictionary(), op.getProperties(), op->getRegions()) {} + +std::pair TransposeOpGenericAdaptorBase::getODSOperandIndexAndLength(unsigned index, unsigned odsOperandsSize) { + return {index, 1}; +} + +::mlir::DictionaryAttr TransposeOpGenericAdaptorBase::getAttributes() { + return odsAttrs; +} + +} // namespace detail +TransposeOpAdaptor::TransposeOpAdaptor(TransposeOp op) : TransposeOpGenericAdaptor(op->getOperands(), op) {} + +::mlir::LogicalResult TransposeOpAdaptor::verify(::mlir::Location loc) { + return ::mlir::success(); +} + +std::pair TransposeOp::getODSOperandIndexAndLength(unsigned index) { + return {index, 1}; +} + +::mlir::Operation::operand_range TransposeOp::getODSOperands(unsigned index) { + auto valueRange = getODSOperandIndexAndLength(index); + return {std::next(getOperation()->operand_begin(), valueRange.first), + std::next(getOperation()->operand_begin(), valueRange.first + valueRange.second)}; +} + +::mlir::TypedValue<::mlir::TensorType> TransposeOp::getInput() { + return ::llvm::cast<::mlir::TypedValue<::mlir::TensorType>>(*getODSOperands(0).begin()); +} + +::mlir::OpOperand &TransposeOp::getInputMutable() { + auto range = getODSOperandIndexAndLength(0); + return getOperation()->getOpOperand(range.first); +} + +std::pair TransposeOp::getODSResultIndexAndLength(unsigned index) { + return {index, 1}; +} + +::mlir::Operation::result_range TransposeOp::getODSResults(unsigned index) { + auto valueRange = getODSResultIndexAndLength(index); + return {std::next(getOperation()->result_begin(), valueRange.first), + std::next(getOperation()->result_begin(), valueRange.first + valueRange.second)}; +} + +void TransposeOp::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type resultType0, ::mlir::Value input) { + odsState.addOperands(input); + odsState.addTypes(resultType0); +} + +void TransposeOp::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::Value input) { + odsState.addOperands(input); + assert(resultTypes.size() == 1u && "mismatched number of results"); + odsState.addTypes(resultTypes); +} + +void TransposeOp::build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes) { + assert(operands.size() == 1u && "mismatched number of parameters"); + odsState.addOperands(operands); + odsState.addAttributes(attributes); + assert(resultTypes.size() == 1u && "mismatched number of return types"); + odsState.addTypes(resultTypes); +} + +::mlir::LogicalResult TransposeOp::verifyInvariantsImpl() { + { + unsigned index = 0; (void)index; + auto valueGroup0 = getODSOperands(0); + + for (auto v : valueGroup0) { + if (::mlir::failed(__mlir_ods_local_type_constraint_Ops0(*this, v.getType(), "operand", index++))) + return ::mlir::failure(); + } + } + { + unsigned index = 0; (void)index; + auto valueGroup0 = getODSResults(0); + + for (auto v : valueGroup0) { + if (::mlir::failed(__mlir_ods_local_type_constraint_Ops0(*this, v.getType(), "result", index++))) + return ::mlir::failure(); + } + } + return ::mlir::success(); +} + +::mlir::LogicalResult TransposeOp::verifyInvariants() { + if(::mlir::succeeded(verifyInvariantsImpl()) && ::mlir::succeeded(verify())) + return ::mlir::success(); + return ::mlir::failure(); +} + +::mlir::ParseResult TransposeOp::parse(::mlir::OpAsmParser &parser, ::mlir::OperationState &result) { + ::mlir::OpAsmParser::UnresolvedOperand inputRawOperands[1]; + ::llvm::ArrayRef<::mlir::OpAsmParser::UnresolvedOperand> inputOperands(inputRawOperands); ::llvm::SMLoc inputOperandsLoc; + (void)inputOperandsLoc; + ::mlir::Type inputRawTypes[1]; + ::llvm::ArrayRef<::mlir::Type> inputTypes(inputRawTypes); + ::llvm::SmallVector<::mlir::Type, 1> allResultTypes; + if (parser.parseLParen()) + return ::mlir::failure(); + + inputOperandsLoc = parser.getCurrentLocation(); + if (parser.parseOperand(inputRawOperands[0])) + return ::mlir::failure(); + if (parser.parseColon()) + return ::mlir::failure(); + + { + ::mlir::TensorType type; + if (parser.parseCustomTypeWithFallback(type)) + return ::mlir::failure(); + inputRawTypes[0] = type; + } + if (parser.parseRParen()) + return ::mlir::failure(); + { + auto loc = parser.getCurrentLocation();(void)loc; + if (parser.parseOptionalAttrDict(result.attributes)) + return ::mlir::failure(); + } + if (parser.parseKeyword("to")) + return ::mlir::failure(); + + if (parser.parseTypeList(allResultTypes)) + return ::mlir::failure(); + result.addTypes(allResultTypes); + if (parser.resolveOperands(inputOperands, inputTypes, inputOperandsLoc, result.operands)) + return ::mlir::failure(); + return ::mlir::success(); +} + +void TransposeOp::print(::mlir::OpAsmPrinter &_odsPrinter) { + _odsPrinter << "("; + _odsPrinter << getInput(); + _odsPrinter << ' ' << ":"; + _odsPrinter << ' '; + { + auto type = getInput().getType(); + if (auto validType = ::llvm::dyn_cast<::mlir::TensorType>(type)) + _odsPrinter.printStrippedAttrOrType(validType); + else + _odsPrinter << type; + } + _odsPrinter << ")"; + ::llvm::SmallVector<::llvm::StringRef, 2> elidedAttrs; + _odsPrinter.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs); + _odsPrinter << ' ' << "to"; + _odsPrinter << ' '; + _odsPrinter << getOperation()->getResultTypes(); +} + +void TransposeOp::getEffects(::llvm::SmallVectorImpl<::mlir::SideEffects::EffectInstance<::mlir::MemoryEffects::Effect>> &effects) { +} + +} // namespace toy +} // namespace mlir +MLIR_DEFINE_EXPLICIT_TYPE_ID(::mlir::toy::TransposeOp) + + +#endif // GET_OP_CLASSES + diff --git a/Ch6/include/toy/Ops.h.inc b/Ch6/include/toy/Ops.h.inc new file mode 100644 index 0000000..a3011f8 --- /dev/null +++ b/Ch6/include/toy/Ops.h.inc @@ -0,0 +1,1361 @@ +/*===- TableGen'erated file -------------------------------------*- C++ -*-===*\ +|* *| +|* Op Declarations *| +|* *| +|* Automatically generated file, do not edit! *| +|* From: Ops.td *| +|* *| +\*===----------------------------------------------------------------------===*/ + +#if defined(GET_OP_CLASSES) || defined(GET_OP_FWD_DEFINES) +#undef GET_OP_FWD_DEFINES +namespace mlir { +namespace toy { +class AddOp; +} // namespace toy +} // namespace mlir +namespace mlir { +namespace toy { +class CastOp; +} // namespace toy +} // namespace mlir +namespace mlir { +namespace toy { +class ConstantOp; +} // namespace toy +} // namespace mlir +namespace mlir { +namespace toy { +class FuncOp; +} // namespace toy +} // namespace mlir +namespace mlir { +namespace toy { +class GenericCallOp; +} // namespace toy +} // namespace mlir +namespace mlir { +namespace toy { +class MulOp; +} // namespace toy +} // namespace mlir +namespace mlir { +namespace toy { +class PrintOp; +} // namespace toy +} // namespace mlir +namespace mlir { +namespace toy { +class ReshapeOp; +} // namespace toy +} // namespace mlir +namespace mlir { +namespace toy { +class ReturnOp; +} // namespace toy +} // namespace mlir +namespace mlir { +namespace toy { +class TransposeOp; +} // namespace toy +} // namespace mlir +#endif + +#ifdef GET_OP_CLASSES +#undef GET_OP_CLASSES + + +//===----------------------------------------------------------------------===// +// Local Utility Method Definitions +//===----------------------------------------------------------------------===// + +namespace mlir { +namespace toy { + +//===----------------------------------------------------------------------===// +// ::mlir::toy::AddOp declarations +//===----------------------------------------------------------------------===// + +namespace detail { +class AddOpGenericAdaptorBase { +public: +protected: + ::mlir::DictionaryAttr odsAttrs; + ::std::optional<::mlir::OperationName> odsOpName; + ::mlir::RegionRange odsRegions; +public: + AddOpGenericAdaptorBase(::mlir::DictionaryAttr attrs = nullptr, const ::mlir::EmptyProperties &properties = {}, ::mlir::RegionRange regions = {}); + + AddOpGenericAdaptorBase(AddOp op); + + std::pair getODSOperandIndexAndLength(unsigned index, unsigned odsOperandsSize); + ::mlir::DictionaryAttr getAttributes(); +}; +} // namespace detail +template +class AddOpGenericAdaptor : public detail::AddOpGenericAdaptorBase { + using ValueT = ::llvm::detail::ValueOfRange; + using Base = detail::AddOpGenericAdaptorBase; +public: + AddOpGenericAdaptor(RangeT values, ::mlir::DictionaryAttr attrs = nullptr, const ::mlir::EmptyProperties &properties = {}, ::mlir::RegionRange regions = {}) : Base(attrs, properties, regions), odsOperands(values) {} + + AddOpGenericAdaptor(RangeT values, ::mlir::DictionaryAttr attrs, ::mlir::OpaqueProperties properties, ::mlir::RegionRange regions = {}) : AddOpGenericAdaptor(values, attrs, (properties ? *properties.as<::mlir::EmptyProperties *>() : ::mlir::EmptyProperties{}), regions) {} + + template >> + AddOpGenericAdaptor(RangeT values, LateInst op) : Base(op), odsOperands(values) {} + + std::pair getODSOperandIndexAndLength(unsigned index) { + return Base::getODSOperandIndexAndLength(index, odsOperands.size()); + } + + RangeT getODSOperands(unsigned index) { + auto valueRange = getODSOperandIndexAndLength(index); + return {std::next(odsOperands.begin(), valueRange.first), + std::next(odsOperands.begin(), valueRange.first + valueRange.second)}; + } + + ValueT getLhs() { + return (*getODSOperands(0).begin()); + } + + ValueT getRhs() { + return (*getODSOperands(1).begin()); + } + + RangeT getOperands() { + return odsOperands; + } + +private: + RangeT odsOperands; +}; +class AddOpAdaptor : public AddOpGenericAdaptor<::mlir::ValueRange> { +public: + using AddOpGenericAdaptor::AddOpGenericAdaptor; + AddOpAdaptor(AddOp op); + + ::mlir::LogicalResult verify(::mlir::Location loc); +}; +class AddOp : public ::mlir::Op::Impl, ::mlir::OpTrait::ZeroSuccessors, ::mlir::OpTrait::NOperands<2>::Impl, ::mlir::OpTrait::OpInvariants, ::mlir::ConditionallySpeculatable::Trait, ::mlir::OpTrait::AlwaysSpeculatableImplTrait, ::mlir::MemoryEffectOpInterface::Trait, ShapeInference::Trait> { +public: + using Op::Op; + using Op::print; + using Adaptor = AddOpAdaptor; + template + using GenericAdaptor = AddOpGenericAdaptor; + using FoldAdaptor = GenericAdaptor<::llvm::ArrayRef<::mlir::Attribute>>; + static ::llvm::ArrayRef<::llvm::StringRef> getAttributeNames() { + return {}; + } + + static constexpr ::llvm::StringLiteral getOperationName() { + return ::llvm::StringLiteral("toy.add"); + } + + std::pair getODSOperandIndexAndLength(unsigned index); + ::mlir::Operation::operand_range getODSOperands(unsigned index); + ::mlir::TypedValue<::mlir::TensorType> getLhs(); + ::mlir::TypedValue<::mlir::TensorType> getRhs(); + ::mlir::OpOperand &getLhsMutable(); + ::mlir::OpOperand &getRhsMutable(); + std::pair getODSResultIndexAndLength(unsigned index); + ::mlir::Operation::result_range getODSResults(unsigned index); + static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, Value lhs, Value rhs); + static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type resultType0, ::mlir::Value lhs, ::mlir::Value rhs); + static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::Value lhs, ::mlir::Value rhs); + static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {}); + static ::mlir::ParseResult parse(::mlir::OpAsmParser &parser, ::mlir::OperationState &result); + void print(::mlir::OpAsmPrinter &p); + ::mlir::LogicalResult verifyInvariantsImpl(); + ::mlir::LogicalResult verifyInvariants(); + void inferShapes(); + void getEffects(::llvm::SmallVectorImpl<::mlir::SideEffects::EffectInstance<::mlir::MemoryEffects::Effect>> &effects); +public: +}; +} // namespace toy +} // namespace mlir +MLIR_DECLARE_EXPLICIT_TYPE_ID(::mlir::toy::AddOp) + +namespace mlir { +namespace toy { + +//===----------------------------------------------------------------------===// +// ::mlir::toy::CastOp declarations +//===----------------------------------------------------------------------===// + +namespace detail { +class CastOpGenericAdaptorBase { +public: +protected: + ::mlir::DictionaryAttr odsAttrs; + ::std::optional<::mlir::OperationName> odsOpName; + ::mlir::RegionRange odsRegions; +public: + CastOpGenericAdaptorBase(::mlir::DictionaryAttr attrs = nullptr, const ::mlir::EmptyProperties &properties = {}, ::mlir::RegionRange regions = {}); + + CastOpGenericAdaptorBase(CastOp op); + + std::pair getODSOperandIndexAndLength(unsigned index, unsigned odsOperandsSize); + ::mlir::DictionaryAttr getAttributes(); +}; +} // namespace detail +template +class CastOpGenericAdaptor : public detail::CastOpGenericAdaptorBase { + using ValueT = ::llvm::detail::ValueOfRange; + using Base = detail::CastOpGenericAdaptorBase; +public: + CastOpGenericAdaptor(RangeT values, ::mlir::DictionaryAttr attrs = nullptr, const ::mlir::EmptyProperties &properties = {}, ::mlir::RegionRange regions = {}) : Base(attrs, properties, regions), odsOperands(values) {} + + CastOpGenericAdaptor(RangeT values, ::mlir::DictionaryAttr attrs, ::mlir::OpaqueProperties properties, ::mlir::RegionRange regions = {}) : CastOpGenericAdaptor(values, attrs, (properties ? *properties.as<::mlir::EmptyProperties *>() : ::mlir::EmptyProperties{}), regions) {} + + template >> + CastOpGenericAdaptor(RangeT values, LateInst op) : Base(op), odsOperands(values) {} + + std::pair getODSOperandIndexAndLength(unsigned index) { + return Base::getODSOperandIndexAndLength(index, odsOperands.size()); + } + + RangeT getODSOperands(unsigned index) { + auto valueRange = getODSOperandIndexAndLength(index); + return {std::next(odsOperands.begin(), valueRange.first), + std::next(odsOperands.begin(), valueRange.first + valueRange.second)}; + } + + ValueT getInput() { + return (*getODSOperands(0).begin()); + } + + RangeT getOperands() { + return odsOperands; + } + +private: + RangeT odsOperands; +}; +class CastOpAdaptor : public CastOpGenericAdaptor<::mlir::ValueRange> { +public: + using CastOpGenericAdaptor::CastOpGenericAdaptor; + CastOpAdaptor(CastOp op); + + ::mlir::LogicalResult verify(::mlir::Location loc); +}; +class CastOp : public ::mlir::Op::Impl, ::mlir::OpTrait::ZeroSuccessors, ::mlir::OpTrait::OneOperand, ::mlir::OpTrait::OpInvariants, ::mlir::CastOpInterface::Trait, ShapeInference::Trait, ::mlir::ConditionallySpeculatable::Trait, ::mlir::OpTrait::AlwaysSpeculatableImplTrait, ::mlir::MemoryEffectOpInterface::Trait, ::mlir::OpTrait::SameOperandsAndResultShape> { +public: + using Op::Op; + using Op::print; + using Adaptor = CastOpAdaptor; + template + using GenericAdaptor = CastOpGenericAdaptor; + using FoldAdaptor = GenericAdaptor<::llvm::ArrayRef<::mlir::Attribute>>; + static ::llvm::ArrayRef<::llvm::StringRef> getAttributeNames() { + return {}; + } + + static constexpr ::llvm::StringLiteral getOperationName() { + return ::llvm::StringLiteral("toy.cast"); + } + + std::pair getODSOperandIndexAndLength(unsigned index); + ::mlir::Operation::operand_range getODSOperands(unsigned index); + ::mlir::TypedValue<::mlir::TensorType> getInput(); + ::mlir::OpOperand &getInputMutable(); + std::pair getODSResultIndexAndLength(unsigned index); + ::mlir::Operation::result_range getODSResults(unsigned index); + ::mlir::TypedValue<::mlir::TensorType> getOutput(); + static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type output, ::mlir::Value input); + static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::Value input); + static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {}); + ::mlir::LogicalResult verifyInvariantsImpl(); + ::mlir::LogicalResult verifyInvariants(); + static bool areCastCompatible(::mlir::TypeRange inputs, ::mlir::TypeRange outputs); + void inferShapes(); + static ::mlir::ParseResult parse(::mlir::OpAsmParser &parser, ::mlir::OperationState &result); + void print(::mlir::OpAsmPrinter &_odsPrinter); + void getEffects(::llvm::SmallVectorImpl<::mlir::SideEffects::EffectInstance<::mlir::MemoryEffects::Effect>> &effects); +public: +}; +} // namespace toy +} // namespace mlir +MLIR_DECLARE_EXPLICIT_TYPE_ID(::mlir::toy::CastOp) + +namespace mlir { +namespace toy { + +//===----------------------------------------------------------------------===// +// ::mlir::toy::ConstantOp declarations +//===----------------------------------------------------------------------===// + +namespace detail { +class ConstantOpGenericAdaptorBase { +public: + struct Properties { + using valueTy = ::mlir::DenseElementsAttr; + valueTy value; + + auto getValue() { + auto &propStorage = this->value; + return ::llvm::cast<::mlir::DenseElementsAttr>(propStorage); + } + void setValue(const ::mlir::DenseElementsAttr &propValue) { + this->value = propValue; + } + bool operator==(const Properties &rhs) const { + return + rhs.value == this->value && + true; + } + bool operator!=(const Properties &rhs) const { + return !(*this == rhs); + } + }; +protected: + ::mlir::DictionaryAttr odsAttrs; + ::std::optional<::mlir::OperationName> odsOpName; + Properties properties; + ::mlir::RegionRange odsRegions; +public: + ConstantOpGenericAdaptorBase(::mlir::DictionaryAttr attrs = nullptr, const Properties &properties = {}, ::mlir::RegionRange regions = {}); + + ConstantOpGenericAdaptorBase(ConstantOp op); + + std::pair getODSOperandIndexAndLength(unsigned index, unsigned odsOperandsSize); + const Properties &getProperties() { + return properties; + } + + ::mlir::DictionaryAttr getAttributes(); + ::mlir::DenseElementsAttr getValueAttr(); + ::mlir::DenseElementsAttr getValue(); +}; +} // namespace detail +template +class ConstantOpGenericAdaptor : public detail::ConstantOpGenericAdaptorBase { + using ValueT = ::llvm::detail::ValueOfRange; + using Base = detail::ConstantOpGenericAdaptorBase; +public: + ConstantOpGenericAdaptor(RangeT values, ::mlir::DictionaryAttr attrs = nullptr, const Properties &properties = {}, ::mlir::RegionRange regions = {}) : Base(attrs, properties, regions), odsOperands(values) {} + + ConstantOpGenericAdaptor(RangeT values, ::mlir::DictionaryAttr attrs, ::mlir::OpaqueProperties properties, ::mlir::RegionRange regions = {}) : ConstantOpGenericAdaptor(values, attrs, (properties ? *properties.as() : Properties{}), regions) {} + + template >> + ConstantOpGenericAdaptor(RangeT values, LateInst op) : Base(op), odsOperands(values) {} + + std::pair getODSOperandIndexAndLength(unsigned index) { + return Base::getODSOperandIndexAndLength(index, odsOperands.size()); + } + + RangeT getODSOperands(unsigned index) { + auto valueRange = getODSOperandIndexAndLength(index); + return {std::next(odsOperands.begin(), valueRange.first), + std::next(odsOperands.begin(), valueRange.first + valueRange.second)}; + } + + RangeT getOperands() { + return odsOperands; + } + +private: + RangeT odsOperands; +}; +class ConstantOpAdaptor : public ConstantOpGenericAdaptor<::mlir::ValueRange> { +public: + using ConstantOpGenericAdaptor::ConstantOpGenericAdaptor; + ConstantOpAdaptor(ConstantOp op); + + ::mlir::LogicalResult verify(::mlir::Location loc); +}; +class ConstantOp : public ::mlir::Op::Impl, ::mlir::OpTrait::ZeroSuccessors, ::mlir::OpTrait::ZeroOperands, ::mlir::OpTrait::OpInvariants, ::mlir::BytecodeOpInterface::Trait, ::mlir::ConditionallySpeculatable::Trait, ::mlir::OpTrait::AlwaysSpeculatableImplTrait, ::mlir::MemoryEffectOpInterface::Trait> { +public: + using Op::Op; + using Op::print; + using Adaptor = ConstantOpAdaptor; + template + using GenericAdaptor = ConstantOpGenericAdaptor; + using FoldAdaptor = GenericAdaptor<::llvm::ArrayRef<::mlir::Attribute>>; + using Properties = FoldAdaptor::Properties; + static ::llvm::ArrayRef<::llvm::StringRef> getAttributeNames() { + static ::llvm::StringRef attrNames[] = {::llvm::StringRef("value")}; + return ::llvm::ArrayRef(attrNames); + } + + ::mlir::StringAttr getValueAttrName() { + return getAttributeNameForIndex(0); + } + + static ::mlir::StringAttr getValueAttrName(::mlir::OperationName name) { + return getAttributeNameForIndex(name, 0); + } + + static constexpr ::llvm::StringLiteral getOperationName() { + return ::llvm::StringLiteral("toy.constant"); + } + + std::pair getODSOperandIndexAndLength(unsigned index); + ::mlir::Operation::operand_range getODSOperands(unsigned index); + std::pair getODSResultIndexAndLength(unsigned index); + ::mlir::Operation::result_range getODSResults(unsigned index); + static ::mlir::LogicalResult setPropertiesFromAttr(Properties &prop, ::mlir::Attribute attr, ::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError); + static ::mlir::Attribute getPropertiesAsAttr(::mlir::MLIRContext *ctx, const Properties &prop); + static llvm::hash_code computePropertiesHash(const Properties &prop); + static std::optional getInherentAttr(::mlir::MLIRContext *ctx, const Properties &prop, llvm::StringRef name); + static void setInherentAttr(Properties &prop, llvm::StringRef name, mlir::Attribute value); + static void populateInherentAttrs(::mlir::MLIRContext *ctx, const Properties &prop, ::mlir::NamedAttrList &attrs); + static ::mlir::LogicalResult verifyInherentAttrs(::mlir::OperationName opName, ::mlir::NamedAttrList &attrs, llvm::function_ref<::mlir::InFlightDiagnostic()> emitError); + static ::mlir::LogicalResult readProperties(::mlir::DialectBytecodeReader &reader, ::mlir::OperationState &state); + void writeProperties(::mlir::DialectBytecodeWriter &writer); + ::mlir::DenseElementsAttr getValueAttr(); + ::mlir::DenseElementsAttr getValue(); + void setValueAttr(::mlir::DenseElementsAttr attr); + static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, DenseElementsAttr value); + static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, double value); + static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type resultType0, ::mlir::DenseElementsAttr value); + static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::DenseElementsAttr value); + static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {}); + static ::mlir::ParseResult parse(::mlir::OpAsmParser &parser, ::mlir::OperationState &result); + void print(::mlir::OpAsmPrinter &p); + ::mlir::LogicalResult verifyInvariantsImpl(); + ::mlir::LogicalResult verifyInvariants(); + ::mlir::LogicalResult verify(); + void getEffects(::llvm::SmallVectorImpl<::mlir::SideEffects::EffectInstance<::mlir::MemoryEffects::Effect>> &effects); +private: + ::mlir::StringAttr getAttributeNameForIndex(unsigned index) { + return getAttributeNameForIndex((*this)->getName(), index); + } + + static ::mlir::StringAttr getAttributeNameForIndex(::mlir::OperationName name, unsigned index) { + assert(index < 1 && "invalid attribute index"); + assert(name.getStringRef() == getOperationName() && "invalid operation name"); + assert(name.isRegistered() && "Operation isn't registered, missing a " + "dependent dialect loading?"); + return name.getAttributeNames()[index]; + } + +public: +}; +} // namespace toy +} // namespace mlir +MLIR_DECLARE_EXPLICIT_TYPE_ID(::mlir::toy::ConstantOp) + +namespace mlir { +namespace toy { + +//===----------------------------------------------------------------------===// +// ::mlir::toy::FuncOp declarations +//===----------------------------------------------------------------------===// + +namespace detail { +class FuncOpGenericAdaptorBase { +public: + struct Properties { + using arg_attrsTy = ::mlir::ArrayAttr; + arg_attrsTy arg_attrs; + + auto getArgAttrs() { + auto &propStorage = this->arg_attrs; + return ::llvm::dyn_cast_or_null<::mlir::ArrayAttr>(propStorage); + } + void setArgAttrs(const ::mlir::ArrayAttr &propValue) { + this->arg_attrs = propValue; + } + using function_typeTy = ::mlir::TypeAttr; + function_typeTy function_type; + + auto getFunctionType() { + auto &propStorage = this->function_type; + return ::llvm::cast<::mlir::TypeAttr>(propStorage); + } + void setFunctionType(const ::mlir::TypeAttr &propValue) { + this->function_type = propValue; + } + using res_attrsTy = ::mlir::ArrayAttr; + res_attrsTy res_attrs; + + auto getResAttrs() { + auto &propStorage = this->res_attrs; + return ::llvm::dyn_cast_or_null<::mlir::ArrayAttr>(propStorage); + } + void setResAttrs(const ::mlir::ArrayAttr &propValue) { + this->res_attrs = propValue; + } + using sym_nameTy = ::mlir::StringAttr; + sym_nameTy sym_name; + + auto getSymName() { + auto &propStorage = this->sym_name; + return ::llvm::cast<::mlir::StringAttr>(propStorage); + } + void setSymName(const ::mlir::StringAttr &propValue) { + this->sym_name = propValue; + } + bool operator==(const Properties &rhs) const { + return + rhs.arg_attrs == this->arg_attrs && + rhs.function_type == this->function_type && + rhs.res_attrs == this->res_attrs && + rhs.sym_name == this->sym_name && + true; + } + bool operator!=(const Properties &rhs) const { + return !(*this == rhs); + } + }; +protected: + ::mlir::DictionaryAttr odsAttrs; + ::std::optional<::mlir::OperationName> odsOpName; + Properties properties; + ::mlir::RegionRange odsRegions; +public: + FuncOpGenericAdaptorBase(::mlir::DictionaryAttr attrs = nullptr, const Properties &properties = {}, ::mlir::RegionRange regions = {}); + + FuncOpGenericAdaptorBase(FuncOp op); + + std::pair getODSOperandIndexAndLength(unsigned index, unsigned odsOperandsSize); + const Properties &getProperties() { + return properties; + } + + ::mlir::DictionaryAttr getAttributes(); + ::mlir::StringAttr getSymNameAttr(); + ::llvm::StringRef getSymName(); + ::mlir::TypeAttr getFunctionTypeAttr(); + ::mlir::FunctionType getFunctionType(); + ::mlir::ArrayAttr getArgAttrsAttr(); + ::std::optional< ::mlir::ArrayAttr > getArgAttrs(); + ::mlir::ArrayAttr getResAttrsAttr(); + ::std::optional< ::mlir::ArrayAttr > getResAttrs(); + ::mlir::Region &getBody(); + ::mlir::RegionRange getRegions(); +}; +} // namespace detail +template +class FuncOpGenericAdaptor : public detail::FuncOpGenericAdaptorBase { + using ValueT = ::llvm::detail::ValueOfRange; + using Base = detail::FuncOpGenericAdaptorBase; +public: + FuncOpGenericAdaptor(RangeT values, ::mlir::DictionaryAttr attrs = nullptr, const Properties &properties = {}, ::mlir::RegionRange regions = {}) : Base(attrs, properties, regions), odsOperands(values) {} + + FuncOpGenericAdaptor(RangeT values, ::mlir::DictionaryAttr attrs, ::mlir::OpaqueProperties properties, ::mlir::RegionRange regions = {}) : FuncOpGenericAdaptor(values, attrs, (properties ? *properties.as() : Properties{}), regions) {} + + template >> + FuncOpGenericAdaptor(RangeT values, LateInst op) : Base(op), odsOperands(values) {} + + std::pair getODSOperandIndexAndLength(unsigned index) { + return Base::getODSOperandIndexAndLength(index, odsOperands.size()); + } + + RangeT getODSOperands(unsigned index) { + auto valueRange = getODSOperandIndexAndLength(index); + return {std::next(odsOperands.begin(), valueRange.first), + std::next(odsOperands.begin(), valueRange.first + valueRange.second)}; + } + + RangeT getOperands() { + return odsOperands; + } + +private: + RangeT odsOperands; +}; +class FuncOpAdaptor : public FuncOpGenericAdaptor<::mlir::ValueRange> { +public: + using FuncOpGenericAdaptor::FuncOpGenericAdaptor; + FuncOpAdaptor(FuncOp op); + + ::mlir::LogicalResult verify(::mlir::Location loc); +}; +class FuncOp : public ::mlir::Op { +public: + using Op::Op; + using Op::print; + using Adaptor = FuncOpAdaptor; + template + using GenericAdaptor = FuncOpGenericAdaptor; + using FoldAdaptor = GenericAdaptor<::llvm::ArrayRef<::mlir::Attribute>>; + using Properties = FoldAdaptor::Properties; + static ::llvm::ArrayRef<::llvm::StringRef> getAttributeNames() { + static ::llvm::StringRef attrNames[] = {::llvm::StringRef("arg_attrs"), ::llvm::StringRef("function_type"), ::llvm::StringRef("res_attrs"), ::llvm::StringRef("sym_name")}; + return ::llvm::ArrayRef(attrNames); + } + + ::mlir::StringAttr getArgAttrsAttrName() { + return getAttributeNameForIndex(0); + } + + static ::mlir::StringAttr getArgAttrsAttrName(::mlir::OperationName name) { + return getAttributeNameForIndex(name, 0); + } + + ::mlir::StringAttr getFunctionTypeAttrName() { + return getAttributeNameForIndex(1); + } + + static ::mlir::StringAttr getFunctionTypeAttrName(::mlir::OperationName name) { + return getAttributeNameForIndex(name, 1); + } + + ::mlir::StringAttr getResAttrsAttrName() { + return getAttributeNameForIndex(2); + } + + static ::mlir::StringAttr getResAttrsAttrName(::mlir::OperationName name) { + return getAttributeNameForIndex(name, 2); + } + + ::mlir::StringAttr getSymNameAttrName() { + return getAttributeNameForIndex(3); + } + + static ::mlir::StringAttr getSymNameAttrName(::mlir::OperationName name) { + return getAttributeNameForIndex(name, 3); + } + + static constexpr ::llvm::StringLiteral getOperationName() { + return ::llvm::StringLiteral("toy.func"); + } + + std::pair getODSOperandIndexAndLength(unsigned index); + ::mlir::Operation::operand_range getODSOperands(unsigned index); + std::pair getODSResultIndexAndLength(unsigned index); + ::mlir::Operation::result_range getODSResults(unsigned index); + ::mlir::Region &getBody(); + static ::mlir::LogicalResult setPropertiesFromAttr(Properties &prop, ::mlir::Attribute attr, ::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError); + static ::mlir::Attribute getPropertiesAsAttr(::mlir::MLIRContext *ctx, const Properties &prop); + static llvm::hash_code computePropertiesHash(const Properties &prop); + static std::optional getInherentAttr(::mlir::MLIRContext *ctx, const Properties &prop, llvm::StringRef name); + static void setInherentAttr(Properties &prop, llvm::StringRef name, mlir::Attribute value); + static void populateInherentAttrs(::mlir::MLIRContext *ctx, const Properties &prop, ::mlir::NamedAttrList &attrs); + static ::mlir::LogicalResult verifyInherentAttrs(::mlir::OperationName opName, ::mlir::NamedAttrList &attrs, llvm::function_ref<::mlir::InFlightDiagnostic()> emitError); + static ::mlir::LogicalResult readProperties(::mlir::DialectBytecodeReader &reader, ::mlir::OperationState &state); + void writeProperties(::mlir::DialectBytecodeWriter &writer); + ::mlir::StringAttr getSymNameAttr(); + ::llvm::StringRef getSymName(); + ::mlir::TypeAttr getFunctionTypeAttr(); + ::mlir::FunctionType getFunctionType(); + ::mlir::ArrayAttr getArgAttrsAttr(); + ::std::optional< ::mlir::ArrayAttr > getArgAttrs(); + ::mlir::ArrayAttr getResAttrsAttr(); + ::std::optional< ::mlir::ArrayAttr > getResAttrs(); + void setSymNameAttr(::mlir::StringAttr attr); + void setSymName(::llvm::StringRef attrValue); + void setFunctionTypeAttr(::mlir::TypeAttr attr); + void setFunctionType(::mlir::FunctionType attrValue); + void setArgAttrsAttr(::mlir::ArrayAttr attr); + void setResAttrsAttr(::mlir::ArrayAttr attr); + ::mlir::Attribute removeArgAttrsAttr(); + ::mlir::Attribute removeResAttrsAttr(); + static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, StringRef name, FunctionType type, ArrayRef attrs = {}); + static ::mlir::ParseResult parse(::mlir::OpAsmParser &parser, ::mlir::OperationState &result); + void print(::mlir::OpAsmPrinter &p); + ::mlir::LogicalResult verifyInvariantsImpl(); + ::mlir::LogicalResult verifyInvariants(); +private: + ::mlir::StringAttr getAttributeNameForIndex(unsigned index) { + return getAttributeNameForIndex((*this)->getName(), index); + } + + static ::mlir::StringAttr getAttributeNameForIndex(::mlir::OperationName name, unsigned index) { + assert(index < 4 && "invalid attribute index"); + assert(name.getStringRef() == getOperationName() && "invalid operation name"); + assert(name.isRegistered() && "Operation isn't registered, missing a " + "dependent dialect loading?"); + return name.getAttributeNames()[index]; + } + +public: + //===------------------------------------------------------------------===// + // FunctionOpInterface Methods + //===------------------------------------------------------------------===// + + /// Returns the argument types of this function. + ArrayRef getArgumentTypes() { return getFunctionType().getInputs(); } + + /// Returns the result types of this function. + ArrayRef getResultTypes() { return getFunctionType().getResults(); } + + /// Returns the region on the function operation that is callable. + Region *getCallableRegion() { return &getBody(); } +}; +} // namespace toy +} // namespace mlir +MLIR_DECLARE_EXPLICIT_TYPE_ID(::mlir::toy::FuncOp) + +namespace mlir { +namespace toy { + +//===----------------------------------------------------------------------===// +// ::mlir::toy::GenericCallOp declarations +//===----------------------------------------------------------------------===// + +namespace detail { +class GenericCallOpGenericAdaptorBase { +public: + struct Properties { + using calleeTy = ::mlir::FlatSymbolRefAttr; + calleeTy callee; + + auto getCallee() { + auto &propStorage = this->callee; + return ::llvm::cast<::mlir::FlatSymbolRefAttr>(propStorage); + } + void setCallee(const ::mlir::FlatSymbolRefAttr &propValue) { + this->callee = propValue; + } + bool operator==(const Properties &rhs) const { + return + rhs.callee == this->callee && + true; + } + bool operator!=(const Properties &rhs) const { + return !(*this == rhs); + } + }; +protected: + ::mlir::DictionaryAttr odsAttrs; + ::std::optional<::mlir::OperationName> odsOpName; + Properties properties; + ::mlir::RegionRange odsRegions; +public: + GenericCallOpGenericAdaptorBase(::mlir::DictionaryAttr attrs = nullptr, const Properties &properties = {}, ::mlir::RegionRange regions = {}); + + GenericCallOpGenericAdaptorBase(GenericCallOp op); + + std::pair getODSOperandIndexAndLength(unsigned index, unsigned odsOperandsSize); + const Properties &getProperties() { + return properties; + } + + ::mlir::DictionaryAttr getAttributes(); + ::mlir::FlatSymbolRefAttr getCalleeAttr(); + ::llvm::StringRef getCallee(); +}; +} // namespace detail +template +class GenericCallOpGenericAdaptor : public detail::GenericCallOpGenericAdaptorBase { + using ValueT = ::llvm::detail::ValueOfRange; + using Base = detail::GenericCallOpGenericAdaptorBase; +public: + GenericCallOpGenericAdaptor(RangeT values, ::mlir::DictionaryAttr attrs = nullptr, const Properties &properties = {}, ::mlir::RegionRange regions = {}) : Base(attrs, properties, regions), odsOperands(values) {} + + GenericCallOpGenericAdaptor(RangeT values, ::mlir::DictionaryAttr attrs, ::mlir::OpaqueProperties properties, ::mlir::RegionRange regions = {}) : GenericCallOpGenericAdaptor(values, attrs, (properties ? *properties.as() : Properties{}), regions) {} + + template >> + GenericCallOpGenericAdaptor(RangeT values, LateInst op) : Base(op), odsOperands(values) {} + + std::pair getODSOperandIndexAndLength(unsigned index) { + return Base::getODSOperandIndexAndLength(index, odsOperands.size()); + } + + RangeT getODSOperands(unsigned index) { + auto valueRange = getODSOperandIndexAndLength(index); + return {std::next(odsOperands.begin(), valueRange.first), + std::next(odsOperands.begin(), valueRange.first + valueRange.second)}; + } + + RangeT getInputs() { + return getODSOperands(0); + } + + RangeT getOperands() { + return odsOperands; + } + +private: + RangeT odsOperands; +}; +class GenericCallOpAdaptor : public GenericCallOpGenericAdaptor<::mlir::ValueRange> { +public: + using GenericCallOpGenericAdaptor::GenericCallOpGenericAdaptor; + GenericCallOpAdaptor(GenericCallOp op); + + ::mlir::LogicalResult verify(::mlir::Location loc); +}; +class GenericCallOp : public ::mlir::Op::Impl, ::mlir::OpTrait::ZeroSuccessors, ::mlir::OpTrait::VariadicOperands, ::mlir::OpTrait::OpInvariants, ::mlir::BytecodeOpInterface::Trait, ::mlir::CallOpInterface::Trait> { +public: + using Op::Op; + using Op::print; + using Adaptor = GenericCallOpAdaptor; + template + using GenericAdaptor = GenericCallOpGenericAdaptor; + using FoldAdaptor = GenericAdaptor<::llvm::ArrayRef<::mlir::Attribute>>; + using Properties = FoldAdaptor::Properties; + static ::llvm::ArrayRef<::llvm::StringRef> getAttributeNames() { + static ::llvm::StringRef attrNames[] = {::llvm::StringRef("callee")}; + return ::llvm::ArrayRef(attrNames); + } + + ::mlir::StringAttr getCalleeAttrName() { + return getAttributeNameForIndex(0); + } + + static ::mlir::StringAttr getCalleeAttrName(::mlir::OperationName name) { + return getAttributeNameForIndex(name, 0); + } + + static constexpr ::llvm::StringLiteral getOperationName() { + return ::llvm::StringLiteral("toy.generic_call"); + } + + std::pair getODSOperandIndexAndLength(unsigned index); + ::mlir::Operation::operand_range getODSOperands(unsigned index); + ::mlir::Operation::operand_range getInputs(); + ::mlir::MutableOperandRange getInputsMutable(); + std::pair getODSResultIndexAndLength(unsigned index); + ::mlir::Operation::result_range getODSResults(unsigned index); + static ::mlir::LogicalResult setPropertiesFromAttr(Properties &prop, ::mlir::Attribute attr, ::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError); + static ::mlir::Attribute getPropertiesAsAttr(::mlir::MLIRContext *ctx, const Properties &prop); + static llvm::hash_code computePropertiesHash(const Properties &prop); + static std::optional getInherentAttr(::mlir::MLIRContext *ctx, const Properties &prop, llvm::StringRef name); + static void setInherentAttr(Properties &prop, llvm::StringRef name, mlir::Attribute value); + static void populateInherentAttrs(::mlir::MLIRContext *ctx, const Properties &prop, ::mlir::NamedAttrList &attrs); + static ::mlir::LogicalResult verifyInherentAttrs(::mlir::OperationName opName, ::mlir::NamedAttrList &attrs, llvm::function_ref<::mlir::InFlightDiagnostic()> emitError); + static ::mlir::LogicalResult readProperties(::mlir::DialectBytecodeReader &reader, ::mlir::OperationState &state); + void writeProperties(::mlir::DialectBytecodeWriter &writer); + ::mlir::FlatSymbolRefAttr getCalleeAttr(); + ::llvm::StringRef getCallee(); + void setCalleeAttr(::mlir::FlatSymbolRefAttr attr); + void setCallee(::llvm::StringRef attrValue); + static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, StringRef callee, ArrayRef arguments); + static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type resultType0, ::mlir::FlatSymbolRefAttr callee, ::mlir::ValueRange inputs); + static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::FlatSymbolRefAttr callee, ::mlir::ValueRange inputs); + static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type resultType0, ::llvm::StringRef callee, ::mlir::ValueRange inputs); + static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::llvm::StringRef callee, ::mlir::ValueRange inputs); + static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {}); + ::mlir::LogicalResult verifyInvariantsImpl(); + ::mlir::LogicalResult verifyInvariants(); + ::mlir::CallInterfaceCallable getCallableForCallee(); + void setCalleeFromCallable(::mlir::CallInterfaceCallable callee); + ::mlir::Operation::operand_range getArgOperands(); + ::mlir::MutableOperandRange getArgOperandsMutable(); + static ::mlir::ParseResult parse(::mlir::OpAsmParser &parser, ::mlir::OperationState &result); + void print(::mlir::OpAsmPrinter &_odsPrinter); +private: + ::mlir::StringAttr getAttributeNameForIndex(unsigned index) { + return getAttributeNameForIndex((*this)->getName(), index); + } + + static ::mlir::StringAttr getAttributeNameForIndex(::mlir::OperationName name, unsigned index) { + assert(index < 1 && "invalid attribute index"); + assert(name.getStringRef() == getOperationName() && "invalid operation name"); + assert(name.isRegistered() && "Operation isn't registered, missing a " + "dependent dialect loading?"); + return name.getAttributeNames()[index]; + } + +public: +}; +} // namespace toy +} // namespace mlir +MLIR_DECLARE_EXPLICIT_TYPE_ID(::mlir::toy::GenericCallOp) + +namespace mlir { +namespace toy { + +//===----------------------------------------------------------------------===// +// ::mlir::toy::MulOp declarations +//===----------------------------------------------------------------------===// + +namespace detail { +class MulOpGenericAdaptorBase { +public: +protected: + ::mlir::DictionaryAttr odsAttrs; + ::std::optional<::mlir::OperationName> odsOpName; + ::mlir::RegionRange odsRegions; +public: + MulOpGenericAdaptorBase(::mlir::DictionaryAttr attrs = nullptr, const ::mlir::EmptyProperties &properties = {}, ::mlir::RegionRange regions = {}); + + MulOpGenericAdaptorBase(MulOp op); + + std::pair getODSOperandIndexAndLength(unsigned index, unsigned odsOperandsSize); + ::mlir::DictionaryAttr getAttributes(); +}; +} // namespace detail +template +class MulOpGenericAdaptor : public detail::MulOpGenericAdaptorBase { + using ValueT = ::llvm::detail::ValueOfRange; + using Base = detail::MulOpGenericAdaptorBase; +public: + MulOpGenericAdaptor(RangeT values, ::mlir::DictionaryAttr attrs = nullptr, const ::mlir::EmptyProperties &properties = {}, ::mlir::RegionRange regions = {}) : Base(attrs, properties, regions), odsOperands(values) {} + + MulOpGenericAdaptor(RangeT values, ::mlir::DictionaryAttr attrs, ::mlir::OpaqueProperties properties, ::mlir::RegionRange regions = {}) : MulOpGenericAdaptor(values, attrs, (properties ? *properties.as<::mlir::EmptyProperties *>() : ::mlir::EmptyProperties{}), regions) {} + + template >> + MulOpGenericAdaptor(RangeT values, LateInst op) : Base(op), odsOperands(values) {} + + std::pair getODSOperandIndexAndLength(unsigned index) { + return Base::getODSOperandIndexAndLength(index, odsOperands.size()); + } + + RangeT getODSOperands(unsigned index) { + auto valueRange = getODSOperandIndexAndLength(index); + return {std::next(odsOperands.begin(), valueRange.first), + std::next(odsOperands.begin(), valueRange.first + valueRange.second)}; + } + + ValueT getLhs() { + return (*getODSOperands(0).begin()); + } + + ValueT getRhs() { + return (*getODSOperands(1).begin()); + } + + RangeT getOperands() { + return odsOperands; + } + +private: + RangeT odsOperands; +}; +class MulOpAdaptor : public MulOpGenericAdaptor<::mlir::ValueRange> { +public: + using MulOpGenericAdaptor::MulOpGenericAdaptor; + MulOpAdaptor(MulOp op); + + ::mlir::LogicalResult verify(::mlir::Location loc); +}; +class MulOp : public ::mlir::Op::Impl, ::mlir::OpTrait::ZeroSuccessors, ::mlir::OpTrait::NOperands<2>::Impl, ::mlir::OpTrait::OpInvariants, ::mlir::ConditionallySpeculatable::Trait, ::mlir::OpTrait::AlwaysSpeculatableImplTrait, ::mlir::MemoryEffectOpInterface::Trait, ShapeInference::Trait> { +public: + using Op::Op; + using Op::print; + using Adaptor = MulOpAdaptor; + template + using GenericAdaptor = MulOpGenericAdaptor; + using FoldAdaptor = GenericAdaptor<::llvm::ArrayRef<::mlir::Attribute>>; + static ::llvm::ArrayRef<::llvm::StringRef> getAttributeNames() { + return {}; + } + + static constexpr ::llvm::StringLiteral getOperationName() { + return ::llvm::StringLiteral("toy.mul"); + } + + std::pair getODSOperandIndexAndLength(unsigned index); + ::mlir::Operation::operand_range getODSOperands(unsigned index); + ::mlir::TypedValue<::mlir::TensorType> getLhs(); + ::mlir::TypedValue<::mlir::TensorType> getRhs(); + ::mlir::OpOperand &getLhsMutable(); + ::mlir::OpOperand &getRhsMutable(); + std::pair getODSResultIndexAndLength(unsigned index); + ::mlir::Operation::result_range getODSResults(unsigned index); + static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, Value lhs, Value rhs); + static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type resultType0, ::mlir::Value lhs, ::mlir::Value rhs); + static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::Value lhs, ::mlir::Value rhs); + static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {}); + static ::mlir::ParseResult parse(::mlir::OpAsmParser &parser, ::mlir::OperationState &result); + void print(::mlir::OpAsmPrinter &p); + ::mlir::LogicalResult verifyInvariantsImpl(); + ::mlir::LogicalResult verifyInvariants(); + void inferShapes(); + void getEffects(::llvm::SmallVectorImpl<::mlir::SideEffects::EffectInstance<::mlir::MemoryEffects::Effect>> &effects); +public: +}; +} // namespace toy +} // namespace mlir +MLIR_DECLARE_EXPLICIT_TYPE_ID(::mlir::toy::MulOp) + +namespace mlir { +namespace toy { + +//===----------------------------------------------------------------------===// +// ::mlir::toy::PrintOp declarations +//===----------------------------------------------------------------------===// + +namespace detail { +class PrintOpGenericAdaptorBase { +public: +protected: + ::mlir::DictionaryAttr odsAttrs; + ::std::optional<::mlir::OperationName> odsOpName; + ::mlir::RegionRange odsRegions; +public: + PrintOpGenericAdaptorBase(::mlir::DictionaryAttr attrs = nullptr, const ::mlir::EmptyProperties &properties = {}, ::mlir::RegionRange regions = {}); + + PrintOpGenericAdaptorBase(PrintOp op); + + std::pair getODSOperandIndexAndLength(unsigned index, unsigned odsOperandsSize); + ::mlir::DictionaryAttr getAttributes(); +}; +} // namespace detail +template +class PrintOpGenericAdaptor : public detail::PrintOpGenericAdaptorBase { + using ValueT = ::llvm::detail::ValueOfRange; + using Base = detail::PrintOpGenericAdaptorBase; +public: + PrintOpGenericAdaptor(RangeT values, ::mlir::DictionaryAttr attrs = nullptr, const ::mlir::EmptyProperties &properties = {}, ::mlir::RegionRange regions = {}) : Base(attrs, properties, regions), odsOperands(values) {} + + PrintOpGenericAdaptor(RangeT values, ::mlir::DictionaryAttr attrs, ::mlir::OpaqueProperties properties, ::mlir::RegionRange regions = {}) : PrintOpGenericAdaptor(values, attrs, (properties ? *properties.as<::mlir::EmptyProperties *>() : ::mlir::EmptyProperties{}), regions) {} + + template >> + PrintOpGenericAdaptor(RangeT values, LateInst op) : Base(op), odsOperands(values) {} + + std::pair getODSOperandIndexAndLength(unsigned index) { + return Base::getODSOperandIndexAndLength(index, odsOperands.size()); + } + + RangeT getODSOperands(unsigned index) { + auto valueRange = getODSOperandIndexAndLength(index); + return {std::next(odsOperands.begin(), valueRange.first), + std::next(odsOperands.begin(), valueRange.first + valueRange.second)}; + } + + ValueT getInput() { + return (*getODSOperands(0).begin()); + } + + RangeT getOperands() { + return odsOperands; + } + +private: + RangeT odsOperands; +}; +class PrintOpAdaptor : public PrintOpGenericAdaptor<::mlir::ValueRange> { +public: + using PrintOpGenericAdaptor::PrintOpGenericAdaptor; + PrintOpAdaptor(PrintOp op); + + ::mlir::LogicalResult verify(::mlir::Location loc); +}; +class PrintOp : public ::mlir::Op { +public: + using Op::Op; + using Op::print; + using Adaptor = PrintOpAdaptor; + template + using GenericAdaptor = PrintOpGenericAdaptor; + using FoldAdaptor = GenericAdaptor<::llvm::ArrayRef<::mlir::Attribute>>; + static ::llvm::ArrayRef<::llvm::StringRef> getAttributeNames() { + return {}; + } + + static constexpr ::llvm::StringLiteral getOperationName() { + return ::llvm::StringLiteral("toy.print"); + } + + std::pair getODSOperandIndexAndLength(unsigned index); + ::mlir::Operation::operand_range getODSOperands(unsigned index); + ::mlir::Value getInput(); + ::mlir::OpOperand &getInputMutable(); + std::pair getODSResultIndexAndLength(unsigned index); + ::mlir::Operation::result_range getODSResults(unsigned index); + static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Value input); + static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::Value input); + static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {}); + ::mlir::LogicalResult verifyInvariantsImpl(); + ::mlir::LogicalResult verifyInvariants(); + static ::mlir::ParseResult parse(::mlir::OpAsmParser &parser, ::mlir::OperationState &result); + void print(::mlir::OpAsmPrinter &_odsPrinter); +public: +}; +} // namespace toy +} // namespace mlir +MLIR_DECLARE_EXPLICIT_TYPE_ID(::mlir::toy::PrintOp) + +namespace mlir { +namespace toy { + +//===----------------------------------------------------------------------===// +// ::mlir::toy::ReshapeOp declarations +//===----------------------------------------------------------------------===// + +namespace detail { +class ReshapeOpGenericAdaptorBase { +public: +protected: + ::mlir::DictionaryAttr odsAttrs; + ::std::optional<::mlir::OperationName> odsOpName; + ::mlir::RegionRange odsRegions; +public: + ReshapeOpGenericAdaptorBase(::mlir::DictionaryAttr attrs = nullptr, const ::mlir::EmptyProperties &properties = {}, ::mlir::RegionRange regions = {}); + + ReshapeOpGenericAdaptorBase(ReshapeOp op); + + std::pair getODSOperandIndexAndLength(unsigned index, unsigned odsOperandsSize); + ::mlir::DictionaryAttr getAttributes(); +}; +} // namespace detail +template +class ReshapeOpGenericAdaptor : public detail::ReshapeOpGenericAdaptorBase { + using ValueT = ::llvm::detail::ValueOfRange; + using Base = detail::ReshapeOpGenericAdaptorBase; +public: + ReshapeOpGenericAdaptor(RangeT values, ::mlir::DictionaryAttr attrs = nullptr, const ::mlir::EmptyProperties &properties = {}, ::mlir::RegionRange regions = {}) : Base(attrs, properties, regions), odsOperands(values) {} + + ReshapeOpGenericAdaptor(RangeT values, ::mlir::DictionaryAttr attrs, ::mlir::OpaqueProperties properties, ::mlir::RegionRange regions = {}) : ReshapeOpGenericAdaptor(values, attrs, (properties ? *properties.as<::mlir::EmptyProperties *>() : ::mlir::EmptyProperties{}), regions) {} + + template >> + ReshapeOpGenericAdaptor(RangeT values, LateInst op) : Base(op), odsOperands(values) {} + + std::pair getODSOperandIndexAndLength(unsigned index) { + return Base::getODSOperandIndexAndLength(index, odsOperands.size()); + } + + RangeT getODSOperands(unsigned index) { + auto valueRange = getODSOperandIndexAndLength(index); + return {std::next(odsOperands.begin(), valueRange.first), + std::next(odsOperands.begin(), valueRange.first + valueRange.second)}; + } + + ValueT getInput() { + return (*getODSOperands(0).begin()); + } + + RangeT getOperands() { + return odsOperands; + } + +private: + RangeT odsOperands; +}; +class ReshapeOpAdaptor : public ReshapeOpGenericAdaptor<::mlir::ValueRange> { +public: + using ReshapeOpGenericAdaptor::ReshapeOpGenericAdaptor; + ReshapeOpAdaptor(ReshapeOp op); + + ::mlir::LogicalResult verify(::mlir::Location loc); +}; +class ReshapeOp : public ::mlir::Op::Impl, ::mlir::OpTrait::ZeroSuccessors, ::mlir::OpTrait::OneOperand, ::mlir::OpTrait::OpInvariants, ::mlir::ConditionallySpeculatable::Trait, ::mlir::OpTrait::AlwaysSpeculatableImplTrait, ::mlir::MemoryEffectOpInterface::Trait> { +public: + using Op::Op; + using Op::print; + using Adaptor = ReshapeOpAdaptor; + template + using GenericAdaptor = ReshapeOpGenericAdaptor; + using FoldAdaptor = GenericAdaptor<::llvm::ArrayRef<::mlir::Attribute>>; + static ::llvm::ArrayRef<::llvm::StringRef> getAttributeNames() { + return {}; + } + + static constexpr ::llvm::StringLiteral getOperationName() { + return ::llvm::StringLiteral("toy.reshape"); + } + + std::pair getODSOperandIndexAndLength(unsigned index); + ::mlir::Operation::operand_range getODSOperands(unsigned index); + ::mlir::TypedValue<::mlir::TensorType> getInput(); + ::mlir::OpOperand &getInputMutable(); + std::pair getODSResultIndexAndLength(unsigned index); + ::mlir::Operation::result_range getODSResults(unsigned index); + static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type resultType0, ::mlir::Value input); + static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::Value input); + static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {}); + ::mlir::LogicalResult verifyInvariantsImpl(); + ::mlir::LogicalResult verifyInvariants(); + static void getCanonicalizationPatterns(::mlir::RewritePatternSet &results, ::mlir::MLIRContext *context); + static ::mlir::ParseResult parse(::mlir::OpAsmParser &parser, ::mlir::OperationState &result); + void print(::mlir::OpAsmPrinter &_odsPrinter); + void getEffects(::llvm::SmallVectorImpl<::mlir::SideEffects::EffectInstance<::mlir::MemoryEffects::Effect>> &effects); +public: +}; +} // namespace toy +} // namespace mlir +MLIR_DECLARE_EXPLICIT_TYPE_ID(::mlir::toy::ReshapeOp) + +namespace mlir { +namespace toy { + +//===----------------------------------------------------------------------===// +// ::mlir::toy::ReturnOp declarations +//===----------------------------------------------------------------------===// + +namespace detail { +class ReturnOpGenericAdaptorBase { +public: +protected: + ::mlir::DictionaryAttr odsAttrs; + ::std::optional<::mlir::OperationName> odsOpName; + ::mlir::RegionRange odsRegions; +public: + ReturnOpGenericAdaptorBase(::mlir::DictionaryAttr attrs = nullptr, const ::mlir::EmptyProperties &properties = {}, ::mlir::RegionRange regions = {}); + + ReturnOpGenericAdaptorBase(ReturnOp op); + + std::pair getODSOperandIndexAndLength(unsigned index, unsigned odsOperandsSize); + ::mlir::DictionaryAttr getAttributes(); +}; +} // namespace detail +template +class ReturnOpGenericAdaptor : public detail::ReturnOpGenericAdaptorBase { + using ValueT = ::llvm::detail::ValueOfRange; + using Base = detail::ReturnOpGenericAdaptorBase; +public: + ReturnOpGenericAdaptor(RangeT values, ::mlir::DictionaryAttr attrs = nullptr, const ::mlir::EmptyProperties &properties = {}, ::mlir::RegionRange regions = {}) : Base(attrs, properties, regions), odsOperands(values) {} + + ReturnOpGenericAdaptor(RangeT values, ::mlir::DictionaryAttr attrs, ::mlir::OpaqueProperties properties, ::mlir::RegionRange regions = {}) : ReturnOpGenericAdaptor(values, attrs, (properties ? *properties.as<::mlir::EmptyProperties *>() : ::mlir::EmptyProperties{}), regions) {} + + template >> + ReturnOpGenericAdaptor(RangeT values, LateInst op) : Base(op), odsOperands(values) {} + + std::pair getODSOperandIndexAndLength(unsigned index) { + return Base::getODSOperandIndexAndLength(index, odsOperands.size()); + } + + RangeT getODSOperands(unsigned index) { + auto valueRange = getODSOperandIndexAndLength(index); + return {std::next(odsOperands.begin(), valueRange.first), + std::next(odsOperands.begin(), valueRange.first + valueRange.second)}; + } + + RangeT getInput() { + return getODSOperands(0); + } + + RangeT getOperands() { + return odsOperands; + } + +private: + RangeT odsOperands; +}; +class ReturnOpAdaptor : public ReturnOpGenericAdaptor<::mlir::ValueRange> { +public: + using ReturnOpGenericAdaptor::ReturnOpGenericAdaptor; + ReturnOpAdaptor(ReturnOp op); + + ::mlir::LogicalResult verify(::mlir::Location loc); +}; +class ReturnOp : public ::mlir::Op::Impl, ::mlir::OpTrait::OpInvariants, ::mlir::ConditionallySpeculatable::Trait, ::mlir::OpTrait::AlwaysSpeculatableImplTrait, ::mlir::MemoryEffectOpInterface::Trait, ::mlir::OpTrait::IsTerminator> { +public: + using Op::Op; + using Op::print; + using Adaptor = ReturnOpAdaptor; + template + using GenericAdaptor = ReturnOpGenericAdaptor; + using FoldAdaptor = GenericAdaptor<::llvm::ArrayRef<::mlir::Attribute>>; + static ::llvm::ArrayRef<::llvm::StringRef> getAttributeNames() { + return {}; + } + + static constexpr ::llvm::StringLiteral getOperationName() { + return ::llvm::StringLiteral("toy.return"); + } + + std::pair getODSOperandIndexAndLength(unsigned index); + ::mlir::Operation::operand_range getODSOperands(unsigned index); + ::mlir::Operation::operand_range getInput(); + ::mlir::MutableOperandRange getInputMutable(); + std::pair getODSResultIndexAndLength(unsigned index); + ::mlir::Operation::result_range getODSResults(unsigned index); + static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState); + static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::ValueRange input); + static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {}); + ::mlir::LogicalResult verifyInvariantsImpl(); + ::mlir::LogicalResult verifyInvariants(); + ::mlir::LogicalResult verify(); + static ::mlir::ParseResult parse(::mlir::OpAsmParser &parser, ::mlir::OperationState &result); + void print(::mlir::OpAsmPrinter &_odsPrinter); + void getEffects(::llvm::SmallVectorImpl<::mlir::SideEffects::EffectInstance<::mlir::MemoryEffects::Effect>> &effects); +public: + bool hasOperand() { return getNumOperands() != 0; } +}; +} // namespace toy +} // namespace mlir +MLIR_DECLARE_EXPLICIT_TYPE_ID(::mlir::toy::ReturnOp) + +namespace mlir { +namespace toy { + +//===----------------------------------------------------------------------===// +// ::mlir::toy::TransposeOp declarations +//===----------------------------------------------------------------------===// + +namespace detail { +class TransposeOpGenericAdaptorBase { +public: +protected: + ::mlir::DictionaryAttr odsAttrs; + ::std::optional<::mlir::OperationName> odsOpName; + ::mlir::RegionRange odsRegions; +public: + TransposeOpGenericAdaptorBase(::mlir::DictionaryAttr attrs = nullptr, const ::mlir::EmptyProperties &properties = {}, ::mlir::RegionRange regions = {}); + + TransposeOpGenericAdaptorBase(TransposeOp op); + + std::pair getODSOperandIndexAndLength(unsigned index, unsigned odsOperandsSize); + ::mlir::DictionaryAttr getAttributes(); +}; +} // namespace detail +template +class TransposeOpGenericAdaptor : public detail::TransposeOpGenericAdaptorBase { + using ValueT = ::llvm::detail::ValueOfRange; + using Base = detail::TransposeOpGenericAdaptorBase; +public: + TransposeOpGenericAdaptor(RangeT values, ::mlir::DictionaryAttr attrs = nullptr, const ::mlir::EmptyProperties &properties = {}, ::mlir::RegionRange regions = {}) : Base(attrs, properties, regions), odsOperands(values) {} + + TransposeOpGenericAdaptor(RangeT values, ::mlir::DictionaryAttr attrs, ::mlir::OpaqueProperties properties, ::mlir::RegionRange regions = {}) : TransposeOpGenericAdaptor(values, attrs, (properties ? *properties.as<::mlir::EmptyProperties *>() : ::mlir::EmptyProperties{}), regions) {} + + template >> + TransposeOpGenericAdaptor(RangeT values, LateInst op) : Base(op), odsOperands(values) {} + + std::pair getODSOperandIndexAndLength(unsigned index) { + return Base::getODSOperandIndexAndLength(index, odsOperands.size()); + } + + RangeT getODSOperands(unsigned index) { + auto valueRange = getODSOperandIndexAndLength(index); + return {std::next(odsOperands.begin(), valueRange.first), + std::next(odsOperands.begin(), valueRange.first + valueRange.second)}; + } + + ValueT getInput() { + return (*getODSOperands(0).begin()); + } + + RangeT getOperands() { + return odsOperands; + } + +private: + RangeT odsOperands; +}; +class TransposeOpAdaptor : public TransposeOpGenericAdaptor<::mlir::ValueRange> { +public: + using TransposeOpGenericAdaptor::TransposeOpGenericAdaptor; + TransposeOpAdaptor(TransposeOp op); + + ::mlir::LogicalResult verify(::mlir::Location loc); +}; +class TransposeOp : public ::mlir::Op::Impl, ::mlir::OpTrait::ZeroSuccessors, ::mlir::OpTrait::OneOperand, ::mlir::OpTrait::OpInvariants, ::mlir::ConditionallySpeculatable::Trait, ::mlir::OpTrait::AlwaysSpeculatableImplTrait, ::mlir::MemoryEffectOpInterface::Trait, ShapeInference::Trait> { +public: + using Op::Op; + using Op::print; + using Adaptor = TransposeOpAdaptor; + template + using GenericAdaptor = TransposeOpGenericAdaptor; + using FoldAdaptor = GenericAdaptor<::llvm::ArrayRef<::mlir::Attribute>>; + static ::llvm::ArrayRef<::llvm::StringRef> getAttributeNames() { + return {}; + } + + static constexpr ::llvm::StringLiteral getOperationName() { + return ::llvm::StringLiteral("toy.transpose"); + } + + std::pair getODSOperandIndexAndLength(unsigned index); + ::mlir::Operation::operand_range getODSOperands(unsigned index); + ::mlir::TypedValue<::mlir::TensorType> getInput(); + ::mlir::OpOperand &getInputMutable(); + std::pair getODSResultIndexAndLength(unsigned index); + ::mlir::Operation::result_range getODSResults(unsigned index); + static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, Value input); + static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type resultType0, ::mlir::Value input); + static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::Value input); + static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {}); + ::mlir::LogicalResult verifyInvariantsImpl(); + ::mlir::LogicalResult verifyInvariants(); + ::mlir::LogicalResult verify(); + static void getCanonicalizationPatterns(::mlir::RewritePatternSet &results, ::mlir::MLIRContext *context); + void inferShapes(); + static ::mlir::ParseResult parse(::mlir::OpAsmParser &parser, ::mlir::OperationState &result); + void print(::mlir::OpAsmPrinter &_odsPrinter); + void getEffects(::llvm::SmallVectorImpl<::mlir::SideEffects::EffectInstance<::mlir::MemoryEffects::Effect>> &effects); +public: +}; +} // namespace toy +} // namespace mlir +MLIR_DECLARE_EXPLICIT_TYPE_ID(::mlir::toy::TransposeOp) + + +#endif // GET_OP_CLASSES + diff --git a/Ch6/include/toy/Ops.td b/Ch6/include/toy/Ops.td new file mode 100644 index 0000000..a52bebc --- /dev/null +++ b/Ch6/include/toy/Ops.td @@ -0,0 +1,372 @@ +//===- Ops.td - Toy dialect operation definitions ----------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Defines the operations of the Toy dialect. +// +//===----------------------------------------------------------------------===// + +#ifndef TOY_OPS +#define TOY_OPS + +include "mlir/Interfaces/FunctionInterfaces.td" +include "mlir/IR/SymbolInterfaces.td" +include "mlir/Interfaces/CallInterfaces.td" +include "mlir/Interfaces/CastInterfaces.td" +include "mlir/Interfaces/SideEffectInterfaces.td" +include "toy/ShapeInferenceInterface.td" + +// Provide a definition of the 'toy' dialect in the ODS framework so that we +// can define our operations. +def Toy_Dialect : Dialect { + let name = "toy"; + let cppNamespace = "::mlir::toy"; +} + +// Base class for toy dialect operations. This operation inherits from the base +// `Op` class in OpBase.td, and provides: +// * The parent dialect of the operation. +// * The mnemonic for the operation, or the name without the dialect prefix. +// * A list of traits for the operation. +class Toy_Op traits = []> : + Op; + +//===----------------------------------------------------------------------===// +// Toy Operations +//===----------------------------------------------------------------------===// + +//===----------------------------------------------------------------------===// +// ConstantOp +//===----------------------------------------------------------------------===// + +// We define a toy operation by inheriting from our base 'Toy_Op' class above. +// Here we provide the mnemonic and a list of traits for the operation. The +// constant operation is marked as 'Pure' as it is a pure operation +// and may be removed if dead. +def ConstantOp : Toy_Op<"constant", [Pure]> { + // Provide a summary and description for this operation. This can be used to + // auto-generate documentation of the operations within our dialect. + let summary = "constant"; + let description = [{ + Constant operation turns a literal into an SSA value. The data is attached + to the operation as an attribute. For example: + + ```mlir + %0 = toy.constant dense<[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]> + : tensor<2x3xf64> + ``` + }]; + + // The constant operation takes an attribute as the only input. + let arguments = (ins F64ElementsAttr:$value); + + // The constant operation returns a single value of TensorType. + let results = (outs F64Tensor); + + // Indicate that the operation has a custom parser and printer method. + let hasCustomAssemblyFormat = 1; + + // Add custom build methods for the constant operation. These method populates + // the `state` that MLIR uses to create operations, i.e. these are used when + // using `builder.create(...)`. + let builders = [ + // Build a constant with a given constant tensor value. + OpBuilder<(ins "DenseElementsAttr":$value), [{ + build($_builder, $_state, value.getType(), value); + }]>, + + // Build a constant with a given constant floating-point value. + OpBuilder<(ins "double":$value)> + ]; + + // Indicate that additional verification for this operation is necessary. + let hasVerifier = 1; +} + +//===----------------------------------------------------------------------===// +// AddOp +//===----------------------------------------------------------------------===// + +def AddOp : Toy_Op<"add", + [Pure, DeclareOpInterfaceMethods]> { + let summary = "element-wise addition operation"; + let description = [{ + The "add" operation performs element-wise addition between two tensors. + The shapes of the tensor operands are expected to match. + }]; + + let arguments = (ins F64Tensor:$lhs, F64Tensor:$rhs); + let results = (outs F64Tensor); + + // Indicate that the operation has a custom parser and printer method. + let hasCustomAssemblyFormat = 1; + + // Allow building an AddOp with from the two input operands. + let builders = [ + OpBuilder<(ins "Value":$lhs, "Value":$rhs)> + ]; +} + +//===----------------------------------------------------------------------===// +// CastOp +//===----------------------------------------------------------------------===// + +def CastOp : Toy_Op<"cast", [ + DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods, + Pure, + SameOperandsAndResultShape + ]> { + let summary = "shape cast operation"; + let description = [{ + The "cast" operation converts a tensor from one type to an equivalent type + without changing any data elements. The source and destination types must + both be tensor types with the same element type. If both are ranked, then + shape is required to match. The operation is invalid if converting to a + mismatching constant dimension. + }]; + + let arguments = (ins F64Tensor:$input); + let results = (outs F64Tensor:$output); + + let assemblyFormat = "$input attr-dict `:` type($input) `to` type($output)"; +} + +//===----------------------------------------------------------------------===// +// FuncOp +//===----------------------------------------------------------------------===// + +def FuncOp : Toy_Op<"func", [ + FunctionOpInterface, IsolatedFromAbove + ]> { + let summary = "user defined function operation"; + let description = [{ + The "toy.func" operation represents a user defined function. These are + callable SSA-region operations that contain toy computations. + + Example: + + ```mlir + toy.func @main() { + %0 = toy.constant dense<5.500000e+00> : tensor + %1 = toy.reshape(%0 : tensor) to tensor<2x2xf64> + toy.print %1 : tensor<2x2xf64> + toy.return + } + ``` + }]; + + let arguments = (ins + SymbolNameAttr:$sym_name, + TypeAttrOf:$function_type, + OptionalAttr:$arg_attrs, + OptionalAttr:$res_attrs + ); + let regions = (region AnyRegion:$body); + + let builders = [OpBuilder<(ins + "StringRef":$name, "FunctionType":$type, + CArg<"ArrayRef", "{}">:$attrs) + >]; + let extraClassDeclaration = [{ + //===------------------------------------------------------------------===// + // FunctionOpInterface Methods + //===------------------------------------------------------------------===// + + /// Returns the argument types of this function. + ArrayRef getArgumentTypes() { return getFunctionType().getInputs(); } + + /// Returns the result types of this function. + ArrayRef getResultTypes() { return getFunctionType().getResults(); } + + /// Returns the region on the function operation that is callable. + Region *getCallableRegion() { return &getBody(); } + }]; + let hasCustomAssemblyFormat = 1; + let skipDefaultBuilders = 1; +} + +//===----------------------------------------------------------------------===// +// GenericCallOp +//===----------------------------------------------------------------------===// + +def GenericCallOp : Toy_Op<"generic_call", + [DeclareOpInterfaceMethods]> { + let summary = "generic call operation"; + let description = [{ + Generic calls represent calls to a user defined function that needs to + be specialized for the shape of its arguments. The callee name is attached + as a symbol reference via an attribute. The arguments list must match the + arguments expected by the callee. For example: + + ```mlir + %4 = toy.generic_call @my_func(%1, %3) + : (tensor<2x3xf64>, tensor<2x3xf64>) -> tensor<*xf64> + ``` + + This is only valid if a function named "my_func" exists and takes two + arguments. + }]; + + // The generic call operation takes a symbol reference attribute as the + // callee, and inputs for the call. + let arguments = (ins FlatSymbolRefAttr:$callee, Variadic:$inputs); + + // The generic call operation returns a single value of TensorType. + let results = (outs F64Tensor); + + // Specialize assembly printing and parsing using a declarative format. + let assemblyFormat = [{ + $callee `(` $inputs `)` attr-dict `:` functional-type($inputs, results) + }]; + + // Add custom build methods for the generic call operation. + let builders = [ + OpBuilder<(ins "StringRef":$callee, "ArrayRef":$arguments)> + ]; +} + +//===----------------------------------------------------------------------===// +// MulOp +//===----------------------------------------------------------------------===// + +def MulOp : Toy_Op<"mul", + [Pure, DeclareOpInterfaceMethods]> { + let summary = "element-wise multiplication operation"; + let description = [{ + The "mul" operation performs element-wise multiplication between two + tensors. The shapes of the tensor operands are expected to match. + }]; + + let arguments = (ins F64Tensor:$lhs, F64Tensor:$rhs); + let results = (outs F64Tensor); + + // Indicate that the operation has a custom parser and printer method. + let hasCustomAssemblyFormat = 1; + + // Allow building a MulOp with from the two input operands. + let builders = [ + OpBuilder<(ins "Value":$lhs, "Value":$rhs)> + ]; +} + +//===----------------------------------------------------------------------===// +// PrintOp +//===----------------------------------------------------------------------===// + +def PrintOp : Toy_Op<"print"> { + let summary = "print operation"; + let description = [{ + The "print" builtin operation prints a given input tensor, and produces + no results. + }]; + + // The print operation takes an input tensor to print. + // We also allow a F64MemRef to enable interop during partial lowering. + let arguments = (ins AnyTypeOf<[F64Tensor, F64MemRef]>:$input); + + let assemblyFormat = "$input attr-dict `:` type($input)"; +} + +//===----------------------------------------------------------------------===// +// ReshapeOp +//===----------------------------------------------------------------------===// + +def ReshapeOp : Toy_Op<"reshape", [Pure]> { + let summary = "tensor reshape operation"; + let description = [{ + Reshape operation is transforming its input tensor into a new tensor with + the same number of elements but different shapes. For example: + + ```mlir + %0 = toy.reshape (%arg1 : tensor<10xf64>) to tensor<5x2xf64> + ``` + }]; + + let arguments = (ins F64Tensor:$input); + + let assemblyFormat = [{ + `(` $input `:` type($input) `)` attr-dict `to` type(results) + }]; + + // Enable registering canonicalization patterns with this operation. + let hasCanonicalizer = 1; + + // We expect that the reshape operation returns a statically shaped tensor. + let results = (outs StaticShapeTensorOf<[F64]>); +} + +//===----------------------------------------------------------------------===// +// ReturnOp +//===----------------------------------------------------------------------===// + +def ReturnOp : Toy_Op<"return", [Pure, HasParent<"FuncOp">, + Terminator]> { + let summary = "return operation"; + let description = [{ + The "return" operation represents a return operation within a function. + The operation takes an optional tensor operand and produces no results. + The operand type must match the signature of the function that contains + the operation. For example: + + ```mlir + toy.func @foo() -> tensor<2xf64> { + ... + toy.return %0 : tensor<2xf64> + } + ``` + }]; + + // The return operation takes an optional input operand to return. This + // value must match the return type of the enclosing function. + let arguments = (ins Variadic:$input); + + // The return operation only emits the input in the format if it is present. + let assemblyFormat = "($input^ `:` type($input))? attr-dict "; + + // Allow building a ReturnOp with no return operand. + let builders = [ + OpBuilder<(ins), [{ build($_builder, $_state, std::nullopt); }]> + ]; + + // Provide extra utility definitions on the c++ operation class definition. + let extraClassDeclaration = [{ + bool hasOperand() { return getNumOperands() != 0; } + }]; + + // Indicate that additional verification for this operation is necessary. + let hasVerifier = 1; +} + +//===----------------------------------------------------------------------===// +// TransposeOp +//===----------------------------------------------------------------------===// + +def TransposeOp : Toy_Op<"transpose", + [Pure, DeclareOpInterfaceMethods]> { + let summary = "transpose operation"; + + let arguments = (ins F64Tensor:$input); + let results = (outs F64Tensor); + + let assemblyFormat = [{ + `(` $input `:` type($input) `)` attr-dict `to` type(results) + }]; + + // Enable registering canonicalization patterns with this operation. + let hasCanonicalizer = 1; + + // Allow building a TransposeOp with from the input operand. + let builders = [ + OpBuilder<(ins "Value":$input)> + ]; + + // Indicate that additional verification for this operation is necessary. + let hasVerifier = 1; +} + +#endif // TOY_OPS diff --git a/Ch6/include/toy/Parser.h b/Ch6/include/toy/Parser.h new file mode 100644 index 0000000..1f20616 --- /dev/null +++ b/Ch6/include/toy/Parser.h @@ -0,0 +1,489 @@ +//===- Parser.h - Toy Language Parser -------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements the parser for the Toy language. It processes the Token +// provided by the Lexer and returns an AST. +// +//===----------------------------------------------------------------------===// + +#ifndef TOY_PARSER_H +#define TOY_PARSER_H + +#include "toy/AST.h" +#include "toy/Lexer.h" + +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/StringExtras.h" +#include "llvm/Support/raw_ostream.h" + +#include +#include +#include +#include + +namespace toy { + +/// This is a simple recursive parser for the Toy language. It produces a well +/// formed AST from a stream of Token supplied by the Lexer. No semantic checks +/// or symbol resolution is performed. For example, variables are referenced by +/// string and the code could reference an undeclared variable and the parsing +/// succeeds. +class Parser { +public: + /// Create a Parser for the supplied lexer. + Parser(Lexer &lexer) : lexer(lexer) {} + + /// Parse a full Module. A module is a list of function definitions. + std::unique_ptr parseModule() { + lexer.getNextToken(); // prime the lexer + + // Parse functions one at a time and accumulate in this vector. + std::vector functions; + while (auto f = parseDefinition()) { + functions.push_back(std::move(*f)); + if (lexer.getCurToken() == tok_eof) + break; + } + // If we didn't reach EOF, there was an error during parsing + if (lexer.getCurToken() != tok_eof) + return parseError("nothing", "at end of module"); + + return std::make_unique(std::move(functions)); + } + +private: + Lexer &lexer; + + /// Parse a return statement. + /// return :== return ; | return expr ; + std::unique_ptr parseReturn() { + auto loc = lexer.getLastLocation(); + lexer.consume(tok_return); + + // return takes an optional argument + std::optional> expr; + if (lexer.getCurToken() != ';') { + expr = parseExpression(); + if (!expr) + return nullptr; + } + return std::make_unique(std::move(loc), std::move(expr)); + } + + /// Parse a literal number. + /// numberexpr ::= number + std::unique_ptr parseNumberExpr() { + auto loc = lexer.getLastLocation(); + auto result = + std::make_unique(std::move(loc), lexer.getValue()); + lexer.consume(tok_number); + return std::move(result); + } + + /// Parse a literal array expression. + /// tensorLiteral ::= [ literalList ] | number + /// literalList ::= tensorLiteral | tensorLiteral, literalList + std::unique_ptr parseTensorLiteralExpr() { + auto loc = lexer.getLastLocation(); + lexer.consume(Token('[')); + + // Hold the list of values at this nesting level. + std::vector> values; + // Hold the dimensions for all the nesting inside this level. + std::vector dims; + do { + // We can have either another nested array or a number literal. + if (lexer.getCurToken() == '[') { + values.push_back(parseTensorLiteralExpr()); + if (!values.back()) + return nullptr; // parse error in the nested array. + } else { + if (lexer.getCurToken() != tok_number) + return parseError(" or [", "in literal expression"); + values.push_back(parseNumberExpr()); + } + + // End of this list on ']' + if (lexer.getCurToken() == ']') + break; + + // Elements are separated by a comma. + if (lexer.getCurToken() != ',') + return parseError("] or ,", "in literal expression"); + + lexer.getNextToken(); // eat , + } while (true); + if (values.empty()) + return parseError("", "to fill literal expression"); + lexer.getNextToken(); // eat ] + + /// Fill in the dimensions now. First the current nesting level: + dims.push_back(values.size()); + + /// If there is any nested array, process all of them and ensure that + /// dimensions are uniform. + if (llvm::any_of(values, [](std::unique_ptr &expr) { + return llvm::isa(expr.get()); + })) { + auto *firstLiteral = llvm::dyn_cast(values.front().get()); + if (!firstLiteral) + return parseError("uniform well-nested dimensions", + "inside literal expression"); + + // Append the nested dimensions to the current level + auto firstDims = firstLiteral->getDims(); + dims.insert(dims.end(), firstDims.begin(), firstDims.end()); + + // Sanity check that shape is uniform across all elements of the list. + for (auto &expr : values) { + auto *exprLiteral = llvm::cast(expr.get()); + if (!exprLiteral) + return parseError("uniform well-nested dimensions", + "inside literal expression"); + if (exprLiteral->getDims() != firstDims) + return parseError("uniform well-nested dimensions", + "inside literal expression"); + } + } + return std::make_unique(std::move(loc), std::move(values), + std::move(dims)); + } + + /// parenexpr ::= '(' expression ')' + std::unique_ptr parseParenExpr() { + lexer.getNextToken(); // eat (. + auto v = parseExpression(); + if (!v) + return nullptr; + + if (lexer.getCurToken() != ')') + return parseError(")", "to close expression with parentheses"); + lexer.consume(Token(')')); + return v; + } + + /// identifierexpr + /// ::= identifier + /// ::= identifier '(' expression ')' + std::unique_ptr parseIdentifierExpr() { + std::string name(lexer.getId()); + + auto loc = lexer.getLastLocation(); + lexer.getNextToken(); // eat identifier. + + if (lexer.getCurToken() != '(') // Simple variable ref. + return std::make_unique(std::move(loc), name); + + // This is a function call. + lexer.consume(Token('(')); + std::vector> args; + if (lexer.getCurToken() != ')') { + while (true) { + if (auto arg = parseExpression()) + args.push_back(std::move(arg)); + else + return nullptr; + + if (lexer.getCurToken() == ')') + break; + + if (lexer.getCurToken() != ',') + return parseError(", or )", "in argument list"); + lexer.getNextToken(); + } + } + lexer.consume(Token(')')); + + // It can be a builtin call to print + if (name == "print") { + if (args.size() != 1) + return parseError("", "as argument to print()"); + + return std::make_unique(std::move(loc), std::move(args[0])); + } + + // Call to a user-defined function + return std::make_unique(std::move(loc), name, std::move(args)); + } + + /// primary + /// ::= identifierexpr + /// ::= numberexpr + /// ::= parenexpr + /// ::= tensorliteral + std::unique_ptr parsePrimary() { + switch (lexer.getCurToken()) { + default: + llvm::errs() << "unknown token '" << lexer.getCurToken() + << "' when expecting an expression\n"; + return nullptr; + case tok_identifier: + return parseIdentifierExpr(); + case tok_number: + return parseNumberExpr(); + case '(': + return parseParenExpr(); + case '[': + return parseTensorLiteralExpr(); + case ';': + return nullptr; + case '}': + return nullptr; + } + } + + /// Recursively parse the right hand side of a binary expression, the ExprPrec + /// argument indicates the precedence of the current binary operator. + /// + /// binoprhs ::= ('+' primary)* + std::unique_ptr parseBinOpRHS(int exprPrec, + std::unique_ptr lhs) { + // If this is a binop, find its precedence. + while (true) { + int tokPrec = getTokPrecedence(); + + // If this is a binop that binds at least as tightly as the current binop, + // consume it, otherwise we are done. + if (tokPrec < exprPrec) + return lhs; + + // Okay, we know this is a binop. + int binOp = lexer.getCurToken(); + lexer.consume(Token(binOp)); + auto loc = lexer.getLastLocation(); + + // Parse the primary expression after the binary operator. + auto rhs = parsePrimary(); + if (!rhs) + return parseError("expression", "to complete binary operator"); + + // If BinOp binds less tightly with rhs than the operator after rhs, let + // the pending operator take rhs as its lhs. + int nextPrec = getTokPrecedence(); + if (tokPrec < nextPrec) { + rhs = parseBinOpRHS(tokPrec + 1, std::move(rhs)); + if (!rhs) + return nullptr; + } + + // Merge lhs/RHS. + lhs = std::make_unique(std::move(loc), binOp, + std::move(lhs), std::move(rhs)); + } + } + + /// expression::= primary binop rhs + std::unique_ptr parseExpression() { + auto lhs = parsePrimary(); + if (!lhs) + return nullptr; + + return parseBinOpRHS(0, std::move(lhs)); + } + + /// type ::= < shape_list > + /// shape_list ::= num | num , shape_list + std::unique_ptr parseType() { + if (lexer.getCurToken() != '<') + return parseError("<", "to begin type"); + lexer.getNextToken(); // eat < + + auto type = std::make_unique(); + + while (lexer.getCurToken() == tok_number) { + type->shape.push_back(lexer.getValue()); + lexer.getNextToken(); + if (lexer.getCurToken() == ',') + lexer.getNextToken(); + } + + if (lexer.getCurToken() != '>') + return parseError(">", "to end type"); + lexer.getNextToken(); // eat > + return type; + } + + /// Parse a variable declaration, it starts with a `var` keyword followed by + /// and identifier and an optional type (shape specification) before the + /// initializer. + /// decl ::= var identifier [ type ] = expr + std::unique_ptr parseDeclaration() { + if (lexer.getCurToken() != tok_var) + return parseError("var", "to begin declaration"); + auto loc = lexer.getLastLocation(); + lexer.getNextToken(); // eat var + + if (lexer.getCurToken() != tok_identifier) + return parseError("identified", + "after 'var' declaration"); + std::string id(lexer.getId()); + lexer.getNextToken(); // eat id + + std::unique_ptr type; // Type is optional, it can be inferred + if (lexer.getCurToken() == '<') { + type = parseType(); + if (!type) + return nullptr; + } + + if (!type) + type = std::make_unique(); + lexer.consume(Token('=')); + auto expr = parseExpression(); + return std::make_unique(std::move(loc), std::move(id), + std::move(*type), std::move(expr)); + } + + /// Parse a block: a list of expression separated by semicolons and wrapped in + /// curly braces. + /// + /// block ::= { expression_list } + /// expression_list ::= block_expr ; expression_list + /// block_expr ::= decl | "return" | expr + std::unique_ptr parseBlock() { + if (lexer.getCurToken() != '{') + return parseError("{", "to begin block"); + lexer.consume(Token('{')); + + auto exprList = std::make_unique(); + + // Ignore empty expressions: swallow sequences of semicolons. + while (lexer.getCurToken() == ';') + lexer.consume(Token(';')); + + while (lexer.getCurToken() != '}' && lexer.getCurToken() != tok_eof) { + if (lexer.getCurToken() == tok_var) { + // Variable declaration + auto varDecl = parseDeclaration(); + if (!varDecl) + return nullptr; + exprList->push_back(std::move(varDecl)); + } else if (lexer.getCurToken() == tok_return) { + // Return statement + auto ret = parseReturn(); + if (!ret) + return nullptr; + exprList->push_back(std::move(ret)); + } else { + // General expression + auto expr = parseExpression(); + if (!expr) + return nullptr; + exprList->push_back(std::move(expr)); + } + // Ensure that elements are separated by a semicolon. + if (lexer.getCurToken() != ';') + return parseError(";", "after expression"); + + // Ignore empty expressions: swallow sequences of semicolons. + while (lexer.getCurToken() == ';') + lexer.consume(Token(';')); + } + + if (lexer.getCurToken() != '}') + return parseError("}", "to close block"); + + lexer.consume(Token('}')); + return exprList; + } + + /// prototype ::= def id '(' decl_list ')' + /// decl_list ::= identifier | identifier, decl_list + std::unique_ptr parsePrototype() { + auto loc = lexer.getLastLocation(); + + if (lexer.getCurToken() != tok_def) + return parseError("def", "in prototype"); + lexer.consume(tok_def); + + if (lexer.getCurToken() != tok_identifier) + return parseError("function name", "in prototype"); + + std::string fnName(lexer.getId()); + lexer.consume(tok_identifier); + + if (lexer.getCurToken() != '(') + return parseError("(", "in prototype"); + lexer.consume(Token('(')); + + std::vector> args; + if (lexer.getCurToken() != ')') { + do { + std::string name(lexer.getId()); + auto loc = lexer.getLastLocation(); + lexer.consume(tok_identifier); + auto decl = std::make_unique(std::move(loc), name); + args.push_back(std::move(decl)); + if (lexer.getCurToken() != ',') + break; + lexer.consume(Token(',')); + if (lexer.getCurToken() != tok_identifier) + return parseError( + "identifier", "after ',' in function parameter list"); + } while (true); + } + if (lexer.getCurToken() != ')') + return parseError(")", "to end function prototype"); + + // success. + lexer.consume(Token(')')); + return std::make_unique(std::move(loc), fnName, + std::move(args)); + } + + /// Parse a function definition, we expect a prototype initiated with the + /// `def` keyword, followed by a block containing a list of expressions. + /// + /// definition ::= prototype block + std::unique_ptr parseDefinition() { + auto proto = parsePrototype(); + if (!proto) + return nullptr; + + if (auto block = parseBlock()) + return std::make_unique(std::move(proto), std::move(block)); + return nullptr; + } + + /// Get the precedence of the pending binary operator token. + int getTokPrecedence() { + if (!isascii(lexer.getCurToken())) + return -1; + + // 1 is lowest precedence. + switch (static_cast(lexer.getCurToken())) { + case '-': + return 20; + case '+': + return 20; + case '*': + return 40; + default: + return -1; + } + } + + /// Helper function to signal errors while parsing, it takes an argument + /// indicating the expected token and another argument giving more context. + /// Location is retrieved from the lexer to enrich the error message. + template + std::unique_ptr parseError(T &&expected, U &&context = "") { + auto curToken = lexer.getCurToken(); + llvm::errs() << "Parse error (" << lexer.getLastLocation().line << ", " + << lexer.getLastLocation().col << "): expected '" << expected + << "' " << context << " but has Token " << curToken; + if (isprint(curToken)) + llvm::errs() << " '" << (char)curToken << "'"; + llvm::errs() << "\n"; + return nullptr; + } +}; + +} // namespace toy + +#endif // TOY_PARSER_H diff --git a/Ch6/include/toy/Passes.h b/Ch6/include/toy/Passes.h new file mode 100644 index 0000000..62471dd --- /dev/null +++ b/Ch6/include/toy/Passes.h @@ -0,0 +1,35 @@ +//===- Passes.h - Toy Passes Definition -----------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file exposes the entry points to create compiler passes for Toy. +// +//===----------------------------------------------------------------------===// + +#ifndef TOY_PASSES_H +#define TOY_PASSES_H + +#include + +namespace mlir { +class Pass; + +namespace toy { +std::unique_ptr createShapeInferencePass(); + +/// Create a pass for lowering to operations in the `Affine` and `Std` dialects, +/// for a subset of the Toy IR (e.g. matmul). +std::unique_ptr createLowerToAffinePass(); + +/// Create a pass for lowering operations the remaining `Toy` operations, as +/// well as `Affine` and `Std`, to the LLVM dialect for codegen. +std::unique_ptr createLowerToLLVMPass(); + +} // namespace toy +} // namespace mlir + +#endif // TOY_PASSES_H diff --git a/Ch6/include/toy/ShapeInferenceInterface.h b/Ch6/include/toy/ShapeInferenceInterface.h new file mode 100644 index 0000000..cfe5a87 --- /dev/null +++ b/Ch6/include/toy/ShapeInferenceInterface.h @@ -0,0 +1,28 @@ +//===- ShapeInferenceInterface.h - Interface definitions for ShapeInference -=// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file contains the declarations of the shape inference interfaces defined +// in ShapeInferenceInterface.td. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_TUTORIAL_TOY_SHAPEINFERENCEINTERFACE_H_ +#define MLIR_TUTORIAL_TOY_SHAPEINFERENCEINTERFACE_H_ + +#include "mlir/IR/OpDefinition.h" + +namespace mlir { +namespace toy { + +/// Include the auto-generated declarations. +#include "toy/ShapeInferenceOpInterfaces.h.inc" + +} // namespace toy +} // namespace mlir + +#endif // MLIR_TUTORIAL_TOY_SHAPEINFERENCEINTERFACE_H_ diff --git a/Ch6/include/toy/ShapeInferenceInterface.td b/Ch6/include/toy/ShapeInferenceInterface.td new file mode 100644 index 0000000..2279015 --- /dev/null +++ b/Ch6/include/toy/ShapeInferenceInterface.td @@ -0,0 +1,30 @@ +//===- ShapeInferenceInterface.td - Shape Inference Interface -*- tablegen -==// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Defines the operations of the Shape Inference Op Interface. +// +//===----------------------------------------------------------------------===// + +#ifndef SHAPE_INFERENCE_INTERFACE +#define SHAPE_INFERENCE_INTERFACE + +include "mlir/IR/OpBase.td" + +def ShapeInferenceOpInterface : OpInterface<"ShapeInference"> { + let description = [{ + Interface to access a registered method to infer the return types for an + operation that can be used during type inference. + }]; + + let methods = [ + InterfaceMethod<"Infer and set the output shape for the current operation.", + "void", "inferShapes"> + ]; +} + +#endif // SHAPE_INFERENCE_INTERFACE diff --git a/Ch6/include/toy/ShapeInferenceOpInterfaces.cpp.inc b/Ch6/include/toy/ShapeInferenceOpInterfaces.cpp.inc new file mode 100644 index 0000000..a481d2e --- /dev/null +++ b/Ch6/include/toy/ShapeInferenceOpInterfaces.cpp.inc @@ -0,0 +1,12 @@ +/*===- TableGen'erated file -------------------------------------*- C++ -*-===*\ +|* *| +|* Interface Definitions *| +|* *| +|* Automatically generated file, do not edit! *| +|* *| +\*===----------------------------------------------------------------------===*/ + +/// Infer and set the output shape for the current operation. +void ShapeInference::inferShapes() { + return getImpl()->inferShapes(getImpl(), getOperation()); + } diff --git a/Ch6/include/toy/ShapeInferenceOpInterfaces.h.inc b/Ch6/include/toy/ShapeInferenceOpInterfaces.h.inc new file mode 100644 index 0000000..bb24654 --- /dev/null +++ b/Ch6/include/toy/ShapeInferenceOpInterfaces.h.inc @@ -0,0 +1,61 @@ +/*===- TableGen'erated file -------------------------------------*- C++ -*-===*\ +|* *| +|* Interface Declarations *| +|* *| +|* Automatically generated file, do not edit! *| +|* *| +\*===----------------------------------------------------------------------===*/ + +class ShapeInference; +namespace detail { +struct ShapeInferenceInterfaceTraits { + struct Concept { + /// The methods defined by the interface. + void (*inferShapes)(const Concept *impl, ::mlir::Operation *); + }; + template + class Model : public Concept { + public: + using Interface = ShapeInference; + Model() : Concept{inferShapes} {} + + static inline void inferShapes(const Concept *impl, ::mlir::Operation *tablegen_opaque_val); + }; + template + class FallbackModel : public Concept { + public: + using Interface = ShapeInference; + FallbackModel() : Concept{inferShapes} {} + + static inline void inferShapes(const Concept *impl, ::mlir::Operation *tablegen_opaque_val); + }; + template + class ExternalModel : public FallbackModel { + public: + using ConcreteEntity = ConcreteOp; + }; +};template +struct ShapeInferenceTrait; + +} // namespace detail +class ShapeInference : public ::mlir::OpInterface { +public: + using ::mlir::OpInterface::OpInterface; + template + struct Trait : public detail::ShapeInferenceTrait {}; + /// Infer and set the output shape for the current operation. + void inferShapes(); +}; +namespace detail { + template + struct ShapeInferenceTrait : public ::mlir::OpInterface::Trait { + }; +}// namespace detail +template +void detail::ShapeInferenceInterfaceTraits::Model::inferShapes(const Concept *impl, ::mlir::Operation *tablegen_opaque_val) { + return (llvm::cast(tablegen_opaque_val)).inferShapes(); +} +template +void detail::ShapeInferenceInterfaceTraits::FallbackModel::inferShapes(const Concept *impl, ::mlir::Operation *tablegen_opaque_val) { + return static_cast(impl)->inferShapes(tablegen_opaque_val); +} diff --git a/Ch6/mlir/Dialect.cpp b/Ch6/mlir/Dialect.cpp new file mode 100644 index 0000000..c587dd2 --- /dev/null +++ b/Ch6/mlir/Dialect.cpp @@ -0,0 +1,444 @@ +//===- Dialect.cpp - Toy IR Dialect registration in MLIR ------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements the dialect for the Toy IR: custom type parsing and +// operation verification. +// +//===----------------------------------------------------------------------===// + +#include "toy/Dialect.h" + +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Location.h" +#include "mlir/IR/OpImplementation.h" +#include "mlir/IR/OperationSupport.h" +#include "mlir/IR/ValueRange.h" +#include "mlir/Interfaces/CallInterfaces.h" +#include "mlir/Interfaces/FunctionImplementation.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/InliningUtils.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/Casting.h" +#include +#include +#include +#include + +using namespace mlir; +using namespace mlir::toy; + +#include "toy/Dialect.cpp.inc" + +//===----------------------------------------------------------------------===// +// ToyInlinerInterface +//===----------------------------------------------------------------------===// + +/// This class defines the interface for handling inlining with Toy +/// operations. +struct ToyInlinerInterface : public DialectInlinerInterface { + using DialectInlinerInterface::DialectInlinerInterface; + + //===--------------------------------------------------------------------===// + // Analysis Hooks + //===--------------------------------------------------------------------===// + + /// All call operations within toy can be inlined. + bool isLegalToInline(Operation *call, Operation *callable, + bool wouldBeCloned) const final { + return true; + } + + /// All operations within toy can be inlined. + bool isLegalToInline(Operation *, Region *, bool, IRMapping &) const final { + return true; + } + + // All functions within toy can be inlined. + bool isLegalToInline(Region *, Region *, bool, IRMapping &) const final { + return true; + } + + //===--------------------------------------------------------------------===// + // Transformation Hooks + //===--------------------------------------------------------------------===// + + /// Handle the given inlined terminator(toy.return) by replacing it with a new + /// operation as necessary. + void handleTerminator(Operation *op, ValueRange valuesToRepl) const final { + // Only "toy.return" needs to be handled here. + auto returnOp = cast(op); + + // Replace the values directly with the return operands. + assert(returnOp.getNumOperands() == valuesToRepl.size()); + for (const auto &it : llvm::enumerate(returnOp.getOperands())) + valuesToRepl[it.index()].replaceAllUsesWith(it.value()); + } + + /// Attempts to materialize a conversion for a type mismatch between a call + /// from this dialect, and a callable region. This method should generate an + /// operation that takes 'input' as the only operand, and produces a single + /// result of 'resultType'. If a conversion can not be generated, nullptr + /// should be returned. + Operation *materializeCallConversion(OpBuilder &builder, Value input, + Type resultType, + Location conversionLoc) const final { + return builder.create(conversionLoc, resultType, input); + } +}; + +//===----------------------------------------------------------------------===// +// ToyDialect +//===----------------------------------------------------------------------===// + +/// Dialect initialization, the instance will be owned by the context. This is +/// the point of registration of types and operations for the dialect. +void ToyDialect::initialize() { + addOperations< +#define GET_OP_LIST +#include "toy/Ops.cpp.inc" + >(); + addInterfaces(); +} + +//===----------------------------------------------------------------------===// +// Toy Operations +//===----------------------------------------------------------------------===// + +/// A generalized parser for binary operations. This parses the different forms +/// of 'printBinaryOp' below. +static mlir::ParseResult parseBinaryOp(mlir::OpAsmParser &parser, + mlir::OperationState &result) { + SmallVector operands; + SMLoc operandsLoc = parser.getCurrentLocation(); + Type type; + if (parser.parseOperandList(operands, /*requiredOperandCount=*/2) || + parser.parseOptionalAttrDict(result.attributes) || + parser.parseColonType(type)) + return mlir::failure(); + + // If the type is a function type, it contains the input and result types of + // this operation. + if (FunctionType funcType = llvm::dyn_cast(type)) { + if (parser.resolveOperands(operands, funcType.getInputs(), operandsLoc, + result.operands)) + return mlir::failure(); + result.addTypes(funcType.getResults()); + return mlir::success(); + } + + // Otherwise, the parsed type is the type of both operands and results. + if (parser.resolveOperands(operands, type, result.operands)) + return mlir::failure(); + result.addTypes(type); + return mlir::success(); +} + +/// A generalized printer for binary operations. It prints in two different +/// forms depending on if all of the types match. +static void printBinaryOp(mlir::OpAsmPrinter &printer, mlir::Operation *op) { + printer << " " << op->getOperands(); + printer.printOptionalAttrDict(op->getAttrs()); + printer << " : "; + + // If all of the types are the same, print the type directly. + Type resultType = *op->result_type_begin(); + if (llvm::all_of(op->getOperandTypes(), + [=](Type type) { return type == resultType; })) { + printer << resultType; + return; + } + + // Otherwise, print a functional type. + printer.printFunctionalType(op->getOperandTypes(), op->getResultTypes()); +} + +//===----------------------------------------------------------------------===// +// ConstantOp +//===----------------------------------------------------------------------===// + +/// Build a constant operation. +/// The builder is passed as an argument, so is the state that this method is +/// expected to fill in order to build the operation. +void ConstantOp::build(mlir::OpBuilder &builder, mlir::OperationState &state, + double value) { + auto dataType = RankedTensorType::get({}, builder.getF64Type()); + auto dataAttribute = DenseElementsAttr::get(dataType, value); + ConstantOp::build(builder, state, dataType, dataAttribute); +} + +/// The 'OpAsmParser' class provides a collection of methods for parsing +/// various punctuation, as well as attributes, operands, types, etc. Each of +/// these methods returns a `ParseResult`. This class is a wrapper around +/// `LogicalResult` that can be converted to a boolean `true` value on failure, +/// or `false` on success. This allows for easily chaining together a set of +/// parser rules. These rules are used to populate an `mlir::OperationState` +/// similarly to the `build` methods described above. +mlir::ParseResult ConstantOp::parse(mlir::OpAsmParser &parser, + mlir::OperationState &result) { + mlir::DenseElementsAttr value; + if (parser.parseOptionalAttrDict(result.attributes) || + parser.parseAttribute(value, "value", result.attributes)) + return failure(); + + result.addTypes(value.getType()); + return success(); +} + +/// The 'OpAsmPrinter' class is a stream that allows for formatting +/// strings, attributes, operands, types, etc. +void ConstantOp::print(mlir::OpAsmPrinter &printer) { + printer << " "; + printer.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{"value"}); + printer << getValue(); +} + +/// Verifier for the constant operation. This corresponds to the +/// `let hasVerifier = 1` in the op definition. +mlir::LogicalResult ConstantOp::verify() { + // If the return type of the constant is not an unranked tensor, the shape + // must match the shape of the attribute holding the data. + auto resultType = llvm::dyn_cast(getResult().getType()); + if (!resultType) + return success(); + + // Check that the rank of the attribute type matches the rank of the constant + // result type. + auto attrType = llvm::cast(getValue().getType()); + if (attrType.getRank() != resultType.getRank()) { + return emitOpError("return type must match the one of the attached value " + "attribute: ") + << attrType.getRank() << " != " << resultType.getRank(); + } + + // Check that each of the dimensions match between the two types. + for (int dim = 0, dimE = attrType.getRank(); dim < dimE; ++dim) { + if (attrType.getShape()[dim] != resultType.getShape()[dim]) { + return emitOpError( + "return type shape mismatches its attribute at dimension ") + << dim << ": " << attrType.getShape()[dim] + << " != " << resultType.getShape()[dim]; + } + } + return mlir::success(); +} + +//===----------------------------------------------------------------------===// +// AddOp +//===----------------------------------------------------------------------===// + +void AddOp::build(mlir::OpBuilder &builder, mlir::OperationState &state, + mlir::Value lhs, mlir::Value rhs) { + state.addTypes(UnrankedTensorType::get(builder.getF64Type())); + state.addOperands({lhs, rhs}); +} + +mlir::ParseResult AddOp::parse(mlir::OpAsmParser &parser, + mlir::OperationState &result) { + return parseBinaryOp(parser, result); +} + +void AddOp::print(mlir::OpAsmPrinter &p) { printBinaryOp(p, *this); } + +/// Infer the output shape of the AddOp, this is required by the shape inference +/// interface. +void AddOp::inferShapes() { getResult().setType(getLhs().getType()); } + +//===----------------------------------------------------------------------===// +// CastOp +//===----------------------------------------------------------------------===// + +/// Infer the output shape of the CastOp, this is required by the shape +/// inference interface. +void CastOp::inferShapes() { getResult().setType(getInput().getType()); } + +/// Returns true if the given set of input and result types are compatible with +/// this cast operation. This is required by the `CastOpInterface` to verify +/// this operation and provide other additional utilities. +bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { + if (inputs.size() != 1 || outputs.size() != 1) + return false; + // The inputs must be Tensors with the same element type. + TensorType input = llvm::dyn_cast(inputs.front()); + TensorType output = llvm::dyn_cast(outputs.front()); + if (!input || !output || input.getElementType() != output.getElementType()) + return false; + // The shape is required to match if both types are ranked. + return !input.hasRank() || !output.hasRank() || input == output; +} + +//===----------------------------------------------------------------------===// +// FuncOp +//===----------------------------------------------------------------------===// + +void FuncOp::build(mlir::OpBuilder &builder, mlir::OperationState &state, + llvm::StringRef name, mlir::FunctionType type, + llvm::ArrayRef attrs) { + // FunctionOpInterface provides a convenient `build` method that will populate + // the state of our FuncOp, and create an entry block. + buildWithEntryBlock(builder, state, name, type, attrs, type.getInputs()); +} + +mlir::ParseResult FuncOp::parse(mlir::OpAsmParser &parser, + mlir::OperationState &result) { + // Dispatch to the FunctionOpInterface provided utility method that parses the + // function operation. + auto buildFuncType = + [](mlir::Builder &builder, llvm::ArrayRef argTypes, + llvm::ArrayRef results, + mlir::function_interface_impl::VariadicFlag, + std::string &) { return builder.getFunctionType(argTypes, results); }; + + return mlir::function_interface_impl::parseFunctionOp( + parser, result, /*allowVariadic=*/false, + getFunctionTypeAttrName(result.name), buildFuncType, + getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name)); +} + +void FuncOp::print(mlir::OpAsmPrinter &p) { + // Dispatch to the FunctionOpInterface provided utility method that prints the + // function operation. + mlir::function_interface_impl::printFunctionOp( + p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(), + getArgAttrsAttrName(), getResAttrsAttrName()); +} + +//===----------------------------------------------------------------------===// +// GenericCallOp +//===----------------------------------------------------------------------===// + +void GenericCallOp::build(mlir::OpBuilder &builder, mlir::OperationState &state, + StringRef callee, ArrayRef arguments) { + // Generic call always returns an unranked Tensor initially. + state.addTypes(UnrankedTensorType::get(builder.getF64Type())); + state.addOperands(arguments); + state.addAttribute("callee", + mlir::SymbolRefAttr::get(builder.getContext(), callee)); +} + +/// Return the callee of the generic call operation, this is required by the +/// call interface. +CallInterfaceCallable GenericCallOp::getCallableForCallee() { + return (*this)->getAttrOfType("callee"); +} + +/// Set the callee for the generic call operation, this is required by the call +/// interface. +void GenericCallOp::setCalleeFromCallable(CallInterfaceCallable callee) { + (*this)->setAttr("callee", callee.get()); +} + +/// Get the argument operands to the called function, this is required by the +/// call interface. +Operation::operand_range GenericCallOp::getArgOperands() { return getInputs(); } + +/// Get the argument operands to the called function as a mutable range, this is +/// required by the call interface. +MutableOperandRange GenericCallOp::getArgOperandsMutable() { + return getInputsMutable(); +} + +//===----------------------------------------------------------------------===// +// MulOp +//===----------------------------------------------------------------------===// + +void MulOp::build(mlir::OpBuilder &builder, mlir::OperationState &state, + mlir::Value lhs, mlir::Value rhs) { + state.addTypes(UnrankedTensorType::get(builder.getF64Type())); + state.addOperands({lhs, rhs}); +} + +mlir::ParseResult MulOp::parse(mlir::OpAsmParser &parser, + mlir::OperationState &result) { + return parseBinaryOp(parser, result); +} + +void MulOp::print(mlir::OpAsmPrinter &p) { printBinaryOp(p, *this); } + +/// Infer the output shape of the MulOp, this is required by the shape inference +/// interface. +void MulOp::inferShapes() { getResult().setType(getLhs().getType()); } + +//===----------------------------------------------------------------------===// +// ReturnOp +//===----------------------------------------------------------------------===// + +mlir::LogicalResult ReturnOp::verify() { + // We know that the parent operation is a function, because of the 'HasParent' + // trait attached to the operation definition. + auto function = cast((*this)->getParentOp()); + + /// ReturnOps can only have a single optional operand. + if (getNumOperands() > 1) + return emitOpError() << "expects at most 1 return operand"; + + // The operand number and types must match the function signature. + const auto &results = function.getFunctionType().getResults(); + if (getNumOperands() != results.size()) + return emitOpError() << "does not return the same number of values (" + << getNumOperands() << ") as the enclosing function (" + << results.size() << ")"; + + // If the operation does not have an input, we are done. + if (!hasOperand()) + return mlir::success(); + + auto inputType = *operand_type_begin(); + auto resultType = results.front(); + + // Check that the result type of the function matches the operand type. + if (inputType == resultType || llvm::isa(inputType) || + llvm::isa(resultType)) + return mlir::success(); + + return emitError() << "type of return operand (" << inputType + << ") doesn't match function result type (" << resultType + << ")"; +} + +//===----------------------------------------------------------------------===// +// TransposeOp +//===----------------------------------------------------------------------===// + +void TransposeOp::build(mlir::OpBuilder &builder, mlir::OperationState &state, + mlir::Value value) { + state.addTypes(UnrankedTensorType::get(builder.getF64Type())); + state.addOperands(value); +} + +void TransposeOp::inferShapes() { + auto arrayTy = llvm::cast(getOperand().getType()); + SmallVector dims(llvm::reverse(arrayTy.getShape())); + getResult().setType(RankedTensorType::get(dims, arrayTy.getElementType())); +} + +mlir::LogicalResult TransposeOp::verify() { + auto inputType = llvm::dyn_cast(getOperand().getType()); + auto resultType = llvm::dyn_cast(getType()); + if (!inputType || !resultType) + return mlir::success(); + + auto inputShape = inputType.getShape(); + if (!std::equal(inputShape.begin(), inputShape.end(), + resultType.getShape().rbegin())) { + return emitError() + << "expected result shape to be a transpose of the input"; + } + return mlir::success(); +} + +//===----------------------------------------------------------------------===// +// TableGen'd op method definitions +//===----------------------------------------------------------------------===// + +#define GET_OP_CLASSES +#include "toy/Ops.cpp.inc" diff --git a/Ch6/mlir/LowerToAffineLoops.cpp b/Ch6/mlir/LowerToAffineLoops.cpp new file mode 100644 index 0000000..ae4bd98 --- /dev/null +++ b/Ch6/mlir/LowerToAffineLoops.cpp @@ -0,0 +1,385 @@ +//====- LowerToAffineLoops.cpp - Partial lowering from Toy to Affine+Std --===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements a partial lowering of Toy operations to a combination of +// affine loops, memref operations and standard operations. This lowering +// expects that all calls have been inlined, and all shapes have been resolved. +// +//===----------------------------------------------------------------------===// + +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinDialect.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Diagnostics.h" +#include "mlir/IR/DialectRegistry.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/ValueRange.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Support/TypeID.h" +#include "toy/Dialect.h" +#include "toy/Passes.h" + +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/DialectConversion.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/Sequence.h" +#include "llvm/Support/Casting.h" +#include +#include +#include +#include +#include + +using namespace mlir; + +//===----------------------------------------------------------------------===// +// ToyToAffine RewritePatterns +//===----------------------------------------------------------------------===// + +/// Convert the given RankedTensorType into the corresponding MemRefType. +static MemRefType convertTensorToMemRef(RankedTensorType type) { + return MemRefType::get(type.getShape(), type.getElementType()); +} + +/// Insert an allocation and deallocation for the given MemRefType. +static Value insertAllocAndDealloc(MemRefType type, Location loc, + PatternRewriter &rewriter) { + auto alloc = rewriter.create(loc, type); + + // Make sure to allocate at the beginning of the block. + auto *parentBlock = alloc->getBlock(); + alloc->moveBefore(&parentBlock->front()); + + // Make sure to deallocate this alloc at the end of the block. This is fine + // as toy functions have no control flow. + auto dealloc = rewriter.create(loc, alloc); + dealloc->moveBefore(&parentBlock->back()); + return alloc; +} + +/// This defines the function type used to process an iteration of a lowered +/// loop. It takes as input an OpBuilder, an range of memRefOperands +/// corresponding to the operands of the input operation, and the range of loop +/// induction variables for the iteration. It returns a value to store at the +/// current index of the iteration. +using LoopIterationFn = function_ref; + +static void lowerOpToLoops(Operation *op, ValueRange operands, + PatternRewriter &rewriter, + LoopIterationFn processIteration) { + auto tensorType = llvm::cast((*op->result_type_begin())); + auto loc = op->getLoc(); + + // Insert an allocation and deallocation for the result of this operation. + auto memRefType = convertTensorToMemRef(tensorType); + auto alloc = insertAllocAndDealloc(memRefType, loc, rewriter); + + // Create a nest of affine loops, with one loop per dimension of the shape. + // The buildAffineLoopNest function takes a callback that is used to construct + // the body of the innermost loop given a builder, a location and a range of + // loop induction variables. + SmallVector lowerBounds(tensorType.getRank(), /*Value=*/0); + SmallVector steps(tensorType.getRank(), /*Value=*/1); + affine::buildAffineLoopNest( + rewriter, loc, lowerBounds, tensorType.getShape(), steps, + [&](OpBuilder &nestedBuilder, Location loc, ValueRange ivs) { + // Call the processing function with the rewriter, the memref operands, + // and the loop induction variables. This function will return the value + // to store at the current index. + Value valueToStore = processIteration(nestedBuilder, operands, ivs); + nestedBuilder.create(loc, valueToStore, alloc, + ivs); + }); + + // Replace this operation with the generated alloc. + rewriter.replaceOp(op, alloc); +} + +namespace { +//===----------------------------------------------------------------------===// +// ToyToAffine RewritePatterns: Binary operations +//===----------------------------------------------------------------------===// + +template +struct BinaryOpLowering : public ConversionPattern { + BinaryOpLowering(MLIRContext *ctx) + : ConversionPattern(BinaryOp::getOperationName(), 1, ctx) {} + + LogicalResult + matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const final { + auto loc = op->getLoc(); + lowerOpToLoops(op, operands, rewriter, + [loc](OpBuilder &builder, ValueRange memRefOperands, + ValueRange loopIvs) { + // Generate an adaptor for the remapped operands of the + // BinaryOp. This allows for using the nice named accessors + // that are generated by the ODS. + typename BinaryOp::Adaptor binaryAdaptor(memRefOperands); + + // Generate loads for the element of 'lhs' and 'rhs' at the + // inner loop. + auto loadedLhs = builder.create( + loc, binaryAdaptor.getLhs(), loopIvs); + auto loadedRhs = builder.create( + loc, binaryAdaptor.getRhs(), loopIvs); + + // Create the binary operation performed on the loaded + // values. + return builder.create(loc, loadedLhs, + loadedRhs); + }); + return success(); + } +}; +using AddOpLowering = BinaryOpLowering; +using MulOpLowering = BinaryOpLowering; + +//===----------------------------------------------------------------------===// +// ToyToAffine RewritePatterns: Constant operations +//===----------------------------------------------------------------------===// + +struct ConstantOpLowering : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(toy::ConstantOp op, + PatternRewriter &rewriter) const final { + DenseElementsAttr constantValue = op.getValue(); + Location loc = op.getLoc(); + + // When lowering the constant operation, we allocate and assign the constant + // values to a corresponding memref allocation. + auto tensorType = llvm::cast(op.getType()); + auto memRefType = convertTensorToMemRef(tensorType); + auto alloc = insertAllocAndDealloc(memRefType, loc, rewriter); + + // We will be generating constant indices up-to the largest dimension. + // Create these constants up-front to avoid large amounts of redundant + // operations. + auto valueShape = memRefType.getShape(); + SmallVector constantIndices; + + if (!valueShape.empty()) { + for (auto i : llvm::seq( + 0, *std::max_element(valueShape.begin(), valueShape.end()))) + constantIndices.push_back( + rewriter.create(loc, i)); + } else { + // This is the case of a tensor of rank 0. + constantIndices.push_back( + rewriter.create(loc, 0)); + } + + // The constant operation represents a multi-dimensional constant, so we + // will need to generate a store for each of the elements. The following + // functor recursively walks the dimensions of the constant shape, + // generating a store when the recursion hits the base case. + SmallVector indices; + auto valueIt = constantValue.value_begin(); + std::function storeElements = [&](uint64_t dimension) { + // The last dimension is the base case of the recursion, at this point + // we store the element at the given index. + if (dimension == valueShape.size()) { + rewriter.create( + loc, rewriter.create(loc, *valueIt++), alloc, + llvm::ArrayRef(indices)); + return; + } + + // Otherwise, iterate over the current dimension and add the indices to + // the list. + for (uint64_t i = 0, e = valueShape[dimension]; i != e; ++i) { + indices.push_back(constantIndices[i]); + storeElements(dimension + 1); + indices.pop_back(); + } + }; + + // Start the element storing recursion from the first dimension. + storeElements(/*dimension=*/0); + + // Replace this operation with the generated alloc. + rewriter.replaceOp(op, alloc); + return success(); + } +}; + +//===----------------------------------------------------------------------===// +// ToyToAffine RewritePatterns: Func operations +//===----------------------------------------------------------------------===// + +struct FuncOpLowering : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(toy::FuncOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const final { + // We only lower the main function as we expect that all other functions + // have been inlined. + if (op.getName() != "main") + return failure(); + + // Verify that the given main has no inputs and results. + if (op.getNumArguments() || op.getFunctionType().getNumResults()) { + return rewriter.notifyMatchFailure(op, [](Diagnostic &diag) { + diag << "expected 'main' to have 0 inputs and 0 results"; + }); + } + + // Create a new non-toy function, with the same region. + auto func = rewriter.create(op.getLoc(), op.getName(), + op.getFunctionType()); + rewriter.inlineRegionBefore(op.getRegion(), func.getBody(), func.end()); + rewriter.eraseOp(op); + return success(); + } +}; + +//===----------------------------------------------------------------------===// +// ToyToAffine RewritePatterns: Print operations +//===----------------------------------------------------------------------===// + +struct PrintOpLowering : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(toy::PrintOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const final { + // We don't lower "toy.print" in this pass, but we need to update its + // operands. + rewriter.modifyOpInPlace(op, + [&] { op->setOperands(adaptor.getOperands()); }); + return success(); + } +}; + +//===----------------------------------------------------------------------===// +// ToyToAffine RewritePatterns: Return operations +//===----------------------------------------------------------------------===// + +struct ReturnOpLowering : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(toy::ReturnOp op, + PatternRewriter &rewriter) const final { + // During this lowering, we expect that all function calls have been + // inlined. + if (op.hasOperand()) + return failure(); + + // We lower "toy.return" directly to "func.return". + rewriter.replaceOpWithNewOp(op); + return success(); + } +}; + +//===----------------------------------------------------------------------===// +// ToyToAffine RewritePatterns: Transpose operations +//===----------------------------------------------------------------------===// + +struct TransposeOpLowering : public ConversionPattern { + TransposeOpLowering(MLIRContext *ctx) + : ConversionPattern(toy::TransposeOp::getOperationName(), 1, ctx) {} + + LogicalResult + matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const final { + auto loc = op->getLoc(); + lowerOpToLoops(op, operands, rewriter, + [loc](OpBuilder &builder, ValueRange memRefOperands, + ValueRange loopIvs) { + // Generate an adaptor for the remapped operands of the + // TransposeOp. This allows for using the nice named + // accessors that are generated by the ODS. + toy::TransposeOpAdaptor transposeAdaptor(memRefOperands); + Value input = transposeAdaptor.getInput(); + + // Transpose the elements by generating a load from the + // reverse indices. + SmallVector reverseIvs(llvm::reverse(loopIvs)); + return builder.create(loc, input, + reverseIvs); + }); + return success(); + } +}; + +} // namespace + +//===----------------------------------------------------------------------===// +// ToyToAffineLoweringPass +//===----------------------------------------------------------------------===// + +/// This is a partial lowering to affine loops of the toy operations that are +/// computationally intensive (like matmul for example...) while keeping the +/// rest of the code in the Toy dialect. +namespace { +struct ToyToAffineLoweringPass + : public PassWrapper> { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ToyToAffineLoweringPass) + + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + } + void runOnOperation() final; +}; +} // namespace + +void ToyToAffineLoweringPass::runOnOperation() { + // The first thing to define is the conversion target. This will define the + // final target for this lowering. + ConversionTarget target(getContext()); + + // We define the specific operations, or dialects, that are legal targets for + // this lowering. In our case, we are lowering to a combination of the + // `Affine`, `Arith`, `Func`, and `MemRef` dialects. + target.addLegalDialect(); + + // We also define the Toy dialect as Illegal so that the conversion will fail + // if any of these operations are *not* converted. Given that we actually want + // a partial lowering, we explicitly mark the Toy operations that don't want + // to lower, `toy.print`, as `legal`. `toy.print` will still need its operands + // to be updated though (as we convert from TensorType to MemRefType), so we + // only treat it as `legal` if its operands are legal. + target.addIllegalDialect(); + target.addDynamicallyLegalOp([](toy::PrintOp op) { + return llvm::none_of(op->getOperandTypes(), + [](Type type) { return llvm::isa(type); }); + }); + + // Now that the conversion target has been defined, we just need to provide + // the set of patterns that will lower the Toy operations. + RewritePatternSet patterns(&getContext()); + patterns.add( + &getContext()); + + // With the target and rewrite patterns defined, we can now attempt the + // conversion. The conversion will signal failure if any of our `illegal` + // operations were not converted successfully. + if (failed( + applyPartialConversion(getOperation(), target, std::move(patterns)))) + signalPassFailure(); +} + +/// Create a pass for lowering operations in the `Affine` and `Std` dialects, +/// for a subset of the Toy IR (e.g. matmul). +std::unique_ptr mlir::toy::createLowerToAffinePass() { + return std::make_unique(); +} diff --git a/Ch6/mlir/LowerToLLVM.cpp b/Ch6/mlir/LowerToLLVM.cpp new file mode 100644 index 0000000..f91d880 --- /dev/null +++ b/Ch6/mlir/LowerToLLVM.cpp @@ -0,0 +1,241 @@ +//====- LowerToLLVM.cpp - Lowering from Toy+Affine+Std to LLVM ------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements full lowering of Toy operations to LLVM MLIR dialect. +// 'toy.print' is lowered to a loop nest that calls `printf` on each element of +// the input array. The file also sets up the ToyToLLVMLoweringPass. This pass +// lowers the combination of Arithmetic + Affine + SCF + Func dialects to the +// LLVM one: +// +// Affine -- +// | +// v +// Arithmetic + Func --> LLVM (Dialect) +// ^ +// | +// 'toy.print' --> Loop (SCF) -- +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/LLVMIR/LLVMAttrs.h" +#include "mlir/Dialect/LLVMIR/LLVMTypes.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Support/TypeID.h" +#include "toy/Dialect.h" +#include "toy/Passes.h" + +#include "mlir/Conversion/AffineToStandard/AffineToStandard.h" +#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h" +#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h" +#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h" +#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVMPass.h" +#include "mlir/Conversion/LLVMCommon/ConversionTarget.h" +#include "mlir/Conversion/LLVMCommon/TypeConverter.h" +#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h" +#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/DialectConversion.h" +#include "llvm/Support/Casting.h" +#include +#include + +using namespace mlir; + +//===----------------------------------------------------------------------===// +// ToyToLLVM RewritePatterns +//===----------------------------------------------------------------------===// + +namespace { +/// Lowers `toy.print` to a loop nest calling `printf` on each of the individual +/// elements of the array. +class PrintOpLowering : public ConversionPattern { +public: + explicit PrintOpLowering(MLIRContext *context) + : ConversionPattern(toy::PrintOp::getOperationName(), 1, context) {} + + LogicalResult + matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + auto *context = rewriter.getContext(); + auto memRefType = llvm::cast((*op->operand_type_begin())); + auto memRefShape = memRefType.getShape(); + auto loc = op->getLoc(); + + ModuleOp parentModule = op->getParentOfType(); + + // Get a symbol reference to the printf function, inserting it if necessary. + auto printfRef = getOrInsertPrintf(rewriter, parentModule); + Value formatSpecifierCst = getOrCreateGlobalString( + loc, rewriter, "frmt_spec", StringRef("%f \0", 4), parentModule); + Value newLineCst = getOrCreateGlobalString( + loc, rewriter, "nl", StringRef("\n\0", 2), parentModule); + + // Create a loop for each of the dimensions within the shape. + SmallVector loopIvs; + for (unsigned i = 0, e = memRefShape.size(); i != e; ++i) { + auto lowerBound = rewriter.create(loc, 0); + auto upperBound = + rewriter.create(loc, memRefShape[i]); + auto step = rewriter.create(loc, 1); + auto loop = + rewriter.create(loc, lowerBound, upperBound, step); + for (Operation &nested : *loop.getBody()) + rewriter.eraseOp(&nested); + loopIvs.push_back(loop.getInductionVar()); + + // Terminate the loop body. + rewriter.setInsertionPointToEnd(loop.getBody()); + + // Insert a newline after each of the inner dimensions of the shape. + if (i != e - 1) + rewriter.create(loc, getPrintfType(context), printfRef, + newLineCst); + rewriter.create(loc); + rewriter.setInsertionPointToStart(loop.getBody()); + } + + // Generate a call to printf for the current element of the loop. + auto printOp = cast(op); + auto elementLoad = + rewriter.create(loc, printOp.getInput(), loopIvs); + rewriter.create( + loc, getPrintfType(context), printfRef, + ArrayRef({formatSpecifierCst, elementLoad})); + + // Notify the rewriter that this operation has been removed. + rewriter.eraseOp(op); + return success(); + } + +private: + /// Create a function declaration for printf, the signature is: + /// * `i32 (i8*, ...)` + static LLVM::LLVMFunctionType getPrintfType(MLIRContext *context) { + auto llvmI32Ty = IntegerType::get(context, 32); + auto llvmPtrTy = LLVM::LLVMPointerType::get(context); + auto llvmFnType = LLVM::LLVMFunctionType::get(llvmI32Ty, llvmPtrTy, + /*isVarArg=*/true); + return llvmFnType; + } + + /// Return a symbol reference to the printf function, inserting it into the + /// module if necessary. + static FlatSymbolRefAttr getOrInsertPrintf(PatternRewriter &rewriter, + ModuleOp module) { + auto *context = module.getContext(); + if (module.lookupSymbol("printf")) + return SymbolRefAttr::get(context, "printf"); + + // Insert the printf function into the body of the parent module. + PatternRewriter::InsertionGuard insertGuard(rewriter); + rewriter.setInsertionPointToStart(module.getBody()); + rewriter.create(module.getLoc(), "printf", + getPrintfType(context)); + return SymbolRefAttr::get(context, "printf"); + } + + /// Return a value representing an access into a global string with the given + /// name, creating the string if necessary. + static Value getOrCreateGlobalString(Location loc, OpBuilder &builder, + StringRef name, StringRef value, + ModuleOp module) { + // Create the global at the entry of the module. + LLVM::GlobalOp global; + if (!(global = module.lookupSymbol(name))) { + OpBuilder::InsertionGuard insertGuard(builder); + builder.setInsertionPointToStart(module.getBody()); + auto type = LLVM::LLVMArrayType::get( + IntegerType::get(builder.getContext(), 8), value.size()); + global = builder.create(loc, type, /*isConstant=*/true, + LLVM::Linkage::Internal, name, + builder.getStringAttr(value), + /*alignment=*/0); + } + + // Get the pointer to the first character in the global string. + Value globalPtr = builder.create(loc, global); + Value cst0 = builder.create(loc, builder.getI64Type(), + builder.getIndexAttr(0)); + return builder.create( + loc, LLVM::LLVMPointerType::get(builder.getContext()), global.getType(), + globalPtr, ArrayRef({cst0, cst0})); + } +}; +} // namespace + +//===----------------------------------------------------------------------===// +// ToyToLLVMLoweringPass +//===----------------------------------------------------------------------===// + +namespace { +struct ToyToLLVMLoweringPass + : public PassWrapper> { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ToyToLLVMLoweringPass) + + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + } + void runOnOperation() final; +}; +} // namespace + +void ToyToLLVMLoweringPass::runOnOperation() { + // The first thing to define is the conversion target. This will define the + // final target for this lowering. For this lowering, we are only targeting + // the LLVM dialect. + LLVMConversionTarget target(getContext()); + target.addLegalOp(); + + // During this lowering, we will also be lowering the MemRef types, that are + // currently being operated on, to a representation in LLVM. To perform this + // conversion we use a TypeConverter as part of the lowering. This converter + // details how one type maps to another. This is necessary now that we will be + // doing more complicated lowerings, involving loop region arguments. + LLVMTypeConverter typeConverter(&getContext()); + + // Now that the conversion target has been defined, we need to provide the + // patterns used for lowering. At this point of the compilation process, we + // have a combination of `toy`, `affine`, and `std` operations. Luckily, there + // are already exists a set of patterns to transform `affine` and `std` + // dialects. These patterns lowering in multiple stages, relying on transitive + // lowerings. Transitive lowering, or A->B->C lowering, is when multiple + // patterns must be applied to fully transform an illegal operation into a + // set of legal ones. + RewritePatternSet patterns(&getContext()); + populateAffineToStdConversionPatterns(patterns); + populateSCFToControlFlowConversionPatterns(patterns); + mlir::arith::populateArithToLLVMConversionPatterns(typeConverter, patterns); + populateFinalizeMemRefToLLVMConversionPatterns(typeConverter, patterns); + cf::populateControlFlowToLLVMConversionPatterns(typeConverter, patterns); + populateFuncToLLVMConversionPatterns(typeConverter, patterns); + + // The only remaining operation to lower from the `toy` dialect, is the + // PrintOp. + patterns.add(&getContext()); + + // We want to completely lower to LLVM, so we use a `FullConversion`. This + // ensures that only legal operations will remain after the conversion. + auto module = getOperation(); + if (failed(applyFullConversion(module, target, std::move(patterns)))) + signalPassFailure(); +} + +/// Create a pass for lowering operations the remaining `Toy` operations, as +/// well as `Affine` and `Std`, to the LLVM dialect for codegen. +std::unique_ptr mlir::toy::createLowerToLLVMPass() { + return std::make_unique(); +} diff --git a/Ch6/mlir/MLIRGen.cpp b/Ch6/mlir/MLIRGen.cpp new file mode 100644 index 0000000..6c5474a --- /dev/null +++ b/Ch6/mlir/MLIRGen.cpp @@ -0,0 +1,461 @@ +//===- MLIRGen.cpp - MLIR Generation from a Toy AST -----------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements a simple IR generation targeting MLIR from a Module AST +// for the Toy language. +// +//===----------------------------------------------------------------------===// + +#include "toy/MLIRGen.h" +#include "mlir/IR/Block.h" +#include "mlir/IR/Diagnostics.h" +#include "mlir/IR/Value.h" +#include "mlir/Support/LogicalResult.h" +#include "toy/AST.h" +#include "toy/Dialect.h" + +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/Verifier.h" +#include "toy/Lexer.h" + +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/ScopedHashTable.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/ADT/Twine.h" +#include +#include +#include +#include +#include +#include + +using namespace mlir::toy; +using namespace toy; + +using llvm::ArrayRef; +using llvm::cast; +using llvm::dyn_cast; +using llvm::isa; +using llvm::ScopedHashTableScope; +using llvm::SmallVector; +using llvm::StringRef; +using llvm::Twine; + +namespace { + +/// Implementation of a simple MLIR emission from the Toy AST. +/// +/// This will emit operations that are specific to the Toy language, preserving +/// the semantics of the language and (hopefully) allow to perform accurate +/// analysis and transformation based on these high level semantics. +class MLIRGenImpl { +public: + MLIRGenImpl(mlir::MLIRContext &context) : builder(&context) {} + + /// Public API: convert the AST for a Toy module (source file) to an MLIR + /// Module operation. + mlir::ModuleOp mlirGen(ModuleAST &moduleAST) { + // We create an empty MLIR module and codegen functions one at a time and + // add them to the module. + theModule = mlir::ModuleOp::create(builder.getUnknownLoc()); + + for (FunctionAST &f : moduleAST) + mlirGen(f); + + // Verify the module after we have finished constructing it, this will check + // the structural properties of the IR and invoke any specific verifiers we + // have on the Toy operations. + if (failed(mlir::verify(theModule))) { + theModule.emitError("module verification error"); + return nullptr; + } + + return theModule; + } + +private: + /// A "module" matches a Toy source file: containing a list of functions. + mlir::ModuleOp theModule; + + /// The builder is a helper class to create IR inside a function. The builder + /// is stateful, in particular it keeps an "insertion point": this is where + /// the next operations will be introduced. + mlir::OpBuilder builder; + + /// The symbol table maps a variable name to a value in the current scope. + /// Entering a function creates a new scope, and the function arguments are + /// added to the mapping. When the processing of a function is terminated, the + /// scope is destroyed and the mappings created in this scope are dropped. + llvm::ScopedHashTable symbolTable; + + /// Helper conversion for a Toy AST location to an MLIR location. + mlir::Location loc(const Location &loc) { + return mlir::FileLineColLoc::get(builder.getStringAttr(*loc.file), loc.line, + loc.col); + } + + /// Declare a variable in the current scope, return success if the variable + /// wasn't declared yet. + mlir::LogicalResult declare(llvm::StringRef var, mlir::Value value) { + if (symbolTable.count(var)) + return mlir::failure(); + symbolTable.insert(var, value); + return mlir::success(); + } + + /// Create the prototype for an MLIR function with as many arguments as the + /// provided Toy AST prototype. + mlir::toy::FuncOp mlirGen(PrototypeAST &proto) { + auto location = loc(proto.loc()); + + // This is a generic function, the return type will be inferred later. + // Arguments type are uniformly unranked tensors. + llvm::SmallVector argTypes(proto.getArgs().size(), + getType(VarType{})); + auto funcType = builder.getFunctionType(argTypes, std::nullopt); + return builder.create(location, proto.getName(), + funcType); + } + + /// Emit a new function and add it to the MLIR module. + mlir::toy::FuncOp mlirGen(FunctionAST &funcAST) { + // Create a scope in the symbol table to hold variable declarations. + ScopedHashTableScope varScope(symbolTable); + + // Create an MLIR function for the given prototype. + builder.setInsertionPointToEnd(theModule.getBody()); + mlir::toy::FuncOp function = mlirGen(*funcAST.getProto()); + if (!function) + return nullptr; + + // Let's start the body of the function now! + mlir::Block &entryBlock = function.front(); + auto protoArgs = funcAST.getProto()->getArgs(); + + // Declare all the function arguments in the symbol table. + for (const auto nameValue : + llvm::zip(protoArgs, entryBlock.getArguments())) { + if (failed(declare(std::get<0>(nameValue)->getName(), + std::get<1>(nameValue)))) + return nullptr; + } + + // Set the insertion point in the builder to the beginning of the function + // body, it will be used throughout the codegen to create operations in this + // function. + builder.setInsertionPointToStart(&entryBlock); + + // Emit the body of the function. + if (mlir::failed(mlirGen(*funcAST.getBody()))) { + function.erase(); + return nullptr; + } + + // Implicitly return void if no return statement was emitted. + // FIXME: we may fix the parser instead to always return the last expression + // (this would possibly help the REPL case later) + ReturnOp returnOp; + if (!entryBlock.empty()) + returnOp = dyn_cast(entryBlock.back()); + if (!returnOp) { + builder.create(loc(funcAST.getProto()->loc())); + } else if (returnOp.hasOperand()) { + // Otherwise, if this return operation has an operand then add a result to + // the function. + function.setType(builder.getFunctionType( + function.getFunctionType().getInputs(), getType(VarType{}))); + } + + // If this function isn't main, then set the visibility to private. + if (funcAST.getProto()->getName() != "main") + function.setPrivate(); + + return function; + } + + /// Emit a binary operation + mlir::Value mlirGen(BinaryExprAST &binop) { + // First emit the operations for each side of the operation before emitting + // the operation itself. For example if the expression is `a + foo(a)` + // 1) First it will visiting the LHS, which will return a reference to the + // value holding `a`. This value should have been emitted at declaration + // time and registered in the symbol table, so nothing would be + // codegen'd. If the value is not in the symbol table, an error has been + // emitted and nullptr is returned. + // 2) Then the RHS is visited (recursively) and a call to `foo` is emitted + // and the result value is returned. If an error occurs we get a nullptr + // and propagate. + // + mlir::Value lhs = mlirGen(*binop.getLHS()); + if (!lhs) + return nullptr; + mlir::Value rhs = mlirGen(*binop.getRHS()); + if (!rhs) + return nullptr; + auto location = loc(binop.loc()); + + // Derive the operation name from the binary operator. At the moment we only + // support '+' and '*'. + switch (binop.getOp()) { + case '+': + return builder.create(location, lhs, rhs); + case '*': + return builder.create(location, lhs, rhs); + } + + emitError(location, "invalid binary operator '") << binop.getOp() << "'"; + return nullptr; + } + + /// This is a reference to a variable in an expression. The variable is + /// expected to have been declared and so should have a value in the symbol + /// table, otherwise emit an error and return nullptr. + mlir::Value mlirGen(VariableExprAST &expr) { + if (auto variable = symbolTable.lookup(expr.getName())) + return variable; + + emitError(loc(expr.loc()), "error: unknown variable '") + << expr.getName() << "'"; + return nullptr; + } + + /// Emit a return operation. This will return failure if any generation fails. + mlir::LogicalResult mlirGen(ReturnExprAST &ret) { + auto location = loc(ret.loc()); + + // 'return' takes an optional expression, handle that case here. + mlir::Value expr = nullptr; + if (ret.getExpr().has_value()) { + if (!(expr = mlirGen(**ret.getExpr()))) + return mlir::failure(); + } + + // Otherwise, this return operation has zero operands. + builder.create(location, + expr ? ArrayRef(expr) : ArrayRef()); + return mlir::success(); + } + + /// Emit a literal/constant array. It will be emitted as a flattened array of + /// data in an Attribute attached to a `toy.constant` operation. + /// See documentation on [Attributes](LangRef.md#attributes) for more details. + /// Here is an excerpt: + /// + /// Attributes are the mechanism for specifying constant data in MLIR in + /// places where a variable is never allowed [...]. They consist of a name + /// and a concrete attribute value. The set of expected attributes, their + /// structure, and their interpretation are all contextually dependent on + /// what they are attached to. + /// + /// Example, the source level statement: + /// var a<2, 3> = [[1, 2, 3], [4, 5, 6]]; + /// will be converted to: + /// %0 = "toy.constant"() {value: dense, + /// [[1.000000e+00, 2.000000e+00, 3.000000e+00], + /// [4.000000e+00, 5.000000e+00, 6.000000e+00]]>} : () -> tensor<2x3xf64> + /// + mlir::Value mlirGen(LiteralExprAST &lit) { + auto type = getType(lit.getDims()); + + // The attribute is a vector with a floating point value per element + // (number) in the array, see `collectData()` below for more details. + std::vector data; + data.reserve(std::accumulate(lit.getDims().begin(), lit.getDims().end(), 1, + std::multiplies())); + collectData(lit, data); + + // The type of this attribute is tensor of 64-bit floating-point with the + // shape of the literal. + mlir::Type elementType = builder.getF64Type(); + auto dataType = mlir::RankedTensorType::get(lit.getDims(), elementType); + + // This is the actual attribute that holds the list of values for this + // tensor literal. + auto dataAttribute = + mlir::DenseElementsAttr::get(dataType, llvm::ArrayRef(data)); + + // Build the MLIR op `toy.constant`. This invokes the `ConstantOp::build` + // method. + return builder.create(loc(lit.loc()), type, dataAttribute); + } + + /// Recursive helper function to accumulate the data that compose an array + /// literal. It flattens the nested structure in the supplied vector. For + /// example with this array: + /// [[1, 2], [3, 4]] + /// we will generate: + /// [ 1, 2, 3, 4 ] + /// Individual numbers are represented as doubles. + /// Attributes are the way MLIR attaches constant to operations. + void collectData(ExprAST &expr, std::vector &data) { + if (auto *lit = dyn_cast(&expr)) { + for (auto &value : lit->getValues()) + collectData(*value, data); + return; + } + + assert(isa(expr) && "expected literal or number expr"); + data.push_back(cast(expr).getValue()); + } + + /// Emit a call expression. It emits specific operations for the `transpose` + /// builtin. Other identifiers are assumed to be user-defined functions. + mlir::Value mlirGen(CallExprAST &call) { + llvm::StringRef callee = call.getCallee(); + auto location = loc(call.loc()); + + // Codegen the operands first. + SmallVector operands; + for (auto &expr : call.getArgs()) { + auto arg = mlirGen(*expr); + if (!arg) + return nullptr; + operands.push_back(arg); + } + + // Builtin calls have their custom operation, meaning this is a + // straightforward emission. + if (callee == "transpose") { + if (call.getArgs().size() != 1) { + emitError(location, "MLIR codegen encountered an error: toy.transpose " + "does not accept multiple arguments"); + return nullptr; + } + return builder.create(location, operands[0]); + } + + // Otherwise this is a call to a user-defined function. Calls to + // user-defined functions are mapped to a custom call that takes the callee + // name as an attribute. + return builder.create(location, callee, operands); + } + + /// Emit a print expression. It emits specific operations for two builtins: + /// transpose(x) and print(x). + mlir::LogicalResult mlirGen(PrintExprAST &call) { + auto arg = mlirGen(*call.getArg()); + if (!arg) + return mlir::failure(); + + builder.create(loc(call.loc()), arg); + return mlir::success(); + } + + /// Emit a constant for a single number (FIXME: semantic? broadcast?) + mlir::Value mlirGen(NumberExprAST &num) { + return builder.create(loc(num.loc()), num.getValue()); + } + + /// Dispatch codegen for the right expression subclass using RTTI. + mlir::Value mlirGen(ExprAST &expr) { + switch (expr.getKind()) { + case toy::ExprAST::Expr_BinOp: + return mlirGen(cast(expr)); + case toy::ExprAST::Expr_Var: + return mlirGen(cast(expr)); + case toy::ExprAST::Expr_Literal: + return mlirGen(cast(expr)); + case toy::ExprAST::Expr_Call: + return mlirGen(cast(expr)); + case toy::ExprAST::Expr_Num: + return mlirGen(cast(expr)); + default: + emitError(loc(expr.loc())) + << "MLIR codegen encountered an unhandled expr kind '" + << Twine(expr.getKind()) << "'"; + return nullptr; + } + } + + /// Handle a variable declaration, we'll codegen the expression that forms the + /// initializer and record the value in the symbol table before returning it. + /// Future expressions will be able to reference this variable through symbol + /// table lookup. + mlir::Value mlirGen(VarDeclExprAST &vardecl) { + auto *init = vardecl.getInitVal(); + if (!init) { + emitError(loc(vardecl.loc()), + "missing initializer in variable declaration"); + return nullptr; + } + + mlir::Value value = mlirGen(*init); + if (!value) + return nullptr; + + // We have the initializer value, but in case the variable was declared + // with specific shape, we emit a "reshape" operation. It will get + // optimized out later as needed. + if (!vardecl.getType().shape.empty()) { + value = builder.create(loc(vardecl.loc()), + getType(vardecl.getType()), value); + } + + // Register the value in the symbol table. + if (failed(declare(vardecl.getName(), value))) + return nullptr; + return value; + } + + /// Codegen a list of expression, return failure if one of them hit an error. + mlir::LogicalResult mlirGen(ExprASTList &blockAST) { + ScopedHashTableScope varScope(symbolTable); + for (auto &expr : blockAST) { + // Specific handling for variable declarations, return statement, and + // print. These can only appear in block list and not in nested + // expressions. + if (auto *vardecl = dyn_cast(expr.get())) { + if (!mlirGen(*vardecl)) + return mlir::failure(); + continue; + } + if (auto *ret = dyn_cast(expr.get())) + return mlirGen(*ret); + if (auto *print = dyn_cast(expr.get())) { + if (mlir::failed(mlirGen(*print))) + return mlir::success(); + continue; + } + + // Generic expression dispatch codegen. + if (!mlirGen(*expr)) + return mlir::failure(); + } + return mlir::success(); + } + + /// Build a tensor type from a list of shape dimensions. + mlir::Type getType(ArrayRef shape) { + // If the shape is empty, then this type is unranked. + if (shape.empty()) + return mlir::UnrankedTensorType::get(builder.getF64Type()); + + // Otherwise, we use the given shape. + return mlir::RankedTensorType::get(shape, builder.getF64Type()); + } + + /// Build an MLIR type from a Toy AST variable type (forward to the generic + /// getType above). + mlir::Type getType(const VarType &type) { return getType(type.shape); } +}; + +} // namespace + +namespace toy { + +// The public API for codegen. +mlir::OwningOpRef mlirGen(mlir::MLIRContext &context, + ModuleAST &moduleAST) { + return MLIRGenImpl(context).mlirGen(moduleAST); +} + +} // namespace toy diff --git a/Ch6/mlir/ShapeInferencePass.cpp b/Ch6/mlir/ShapeInferencePass.cpp new file mode 100644 index 0000000..a9e995e --- /dev/null +++ b/Ch6/mlir/ShapeInferencePass.cpp @@ -0,0 +1,122 @@ +//===- ShapeInferencePass.cpp - Shape Inference ---------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements a Function level pass performing interprocedural +// propagation of array shapes through function specialization. +// +//===----------------------------------------------------------------------===// + +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/Types.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/TypeID.h" +#include "toy/Dialect.h" +#include "toy/Passes.h" +#include "toy/ShapeInferenceInterface.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallPtrSet.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/raw_ostream.h" +#include + +#define DEBUG_TYPE "shape-inference" + +using namespace mlir; +using namespace toy; + +/// Include the auto-generated definitions for the shape inference interfaces. +#include "toy/ShapeInferenceOpInterfaces.cpp.inc" + +namespace { +/// The ShapeInferencePass is a pass that performs intra-procedural +/// shape inference. +/// +/// Algorithm: +/// +/// 1) Build a worklist containing all the operations that return a +/// dynamically shaped tensor: these are the operations that need shape +/// inference. +/// 2) Iterate on the worklist: +/// a) find an operation to process: the next ready operation in the +/// worklist has all of its arguments non-generic, +/// b) if no operation is found, break out of the loop, +/// c) remove the operation from the worklist, +/// d) infer the shape of its output from the argument types. +/// 3) If the worklist is empty, the algorithm succeeded. +/// +struct ShapeInferencePass + : public mlir::PassWrapper> { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ShapeInferencePass) + + void runOnOperation() override { + auto f = getOperation(); + + // Populate the worklist with the operations that need shape inference: + // these are operations that return a dynamic shape. + llvm::SmallPtrSet opWorklist; + f.walk([&](mlir::Operation *op) { + if (returnsDynamicShape(op)) + opWorklist.insert(op); + }); + + // Iterate on the operations in the worklist until all operations have been + // inferred or no change happened (fix point). + while (!opWorklist.empty()) { + // Find the next operation ready for inference, that is an operation + // with all operands already resolved (non-generic). + auto nextop = llvm::find_if(opWorklist, allOperandsInferred); + if (nextop == opWorklist.end()) + break; + + Operation *op = *nextop; + opWorklist.erase(op); + + // Ask the operation to infer its output shapes. + LLVM_DEBUG(llvm::dbgs() << "Inferring shape for: " << *op << "\n"); + if (auto shapeOp = dyn_cast(op)) { + shapeOp.inferShapes(); + } else { + op->emitError("unable to infer shape of operation without shape " + "inference interface"); + return signalPassFailure(); + } + } + + // If the operation worklist isn't empty, this indicates a failure. + if (!opWorklist.empty()) { + f.emitError("Shape inference failed, ") + << opWorklist.size() << " operations couldn't be inferred\n"; + signalPassFailure(); + } + } + + /// A utility method that returns if the given operation has all of its + /// operands inferred. + static bool allOperandsInferred(Operation *op) { + return llvm::all_of(op->getOperandTypes(), [](Type operandType) { + return llvm::isa(operandType); + }); + } + + /// A utility method that returns if the given operation has a dynamically + /// shaped result. + static bool returnsDynamicShape(Operation *op) { + return llvm::any_of(op->getResultTypes(), [](Type resultType) { + return !llvm::isa(resultType); + }); + } +}; +} // namespace + +/// Create a Shape Inference pass. +std::unique_ptr mlir::toy::createShapeInferencePass() { + return std::make_unique(); +} diff --git a/Ch6/mlir/ToyCombine.cpp b/Ch6/mlir/ToyCombine.cpp new file mode 100644 index 0000000..3ce35c8 --- /dev/null +++ b/Ch6/mlir/ToyCombine.cpp @@ -0,0 +1,69 @@ +//===- ToyCombine.cpp - Toy High Level Optimizer --------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements a set of simple combiners for optimizing operations in +// the Toy dialect. +// +//===----------------------------------------------------------------------===// + +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/Value.h" +#include "mlir/Support/LogicalResult.h" +#include "toy/Dialect.h" +using namespace mlir; +using namespace toy; + +namespace { +/// Include the patterns defined in the Declarative Rewrite framework. +#include "ToyCombine.inc" +} // namespace + +/// This is an example of a c++ rewrite pattern for the TransposeOp. It +/// optimizes the following scenario: transpose(transpose(x)) -> x +struct SimplifyRedundantTranspose : public mlir::OpRewritePattern { + /// We register this pattern to match every toy.transpose in the IR. + /// The "benefit" is used by the framework to order the patterns and process + /// them in order of profitability. + SimplifyRedundantTranspose(mlir::MLIRContext *context) + : OpRewritePattern(context, /*benefit=*/1) {} + + /// This method attempts to match a pattern and rewrite it. The rewriter + /// argument is the orchestrator of the sequence of rewrites. The pattern is + /// expected to interact with it to perform any changes to the IR from here. + mlir::LogicalResult + matchAndRewrite(TransposeOp op, + mlir::PatternRewriter &rewriter) const override { + // Look through the input of the current transpose. + mlir::Value transposeInput = op.getOperand(); + TransposeOp transposeInputOp = transposeInput.getDefiningOp(); + + // Input defined by another transpose? If not, no match. + if (!transposeInputOp) + return failure(); + + // Otherwise, we have a redundant transpose. Use the rewriter. + rewriter.replaceOp(op, {transposeInputOp.getOperand()}); + return success(); + } +}; + +/// Register our patterns as "canonicalization" patterns on the TransposeOp so +/// that they can be picked up by the Canonicalization framework. +void TransposeOp::getCanonicalizationPatterns(RewritePatternSet &results, + MLIRContext *context) { + results.add(context); +} + +/// Register our patterns as "canonicalization" patterns on the ReshapeOp so +/// that they can be picked up by the Canonicalization framework. +void ReshapeOp::getCanonicalizationPatterns(RewritePatternSet &results, + MLIRContext *context) { + results.add(context); +} diff --git a/Ch6/mlir/ToyCombine.inc b/Ch6/mlir/ToyCombine.inc new file mode 100644 index 0000000..61c6203 --- /dev/null +++ b/Ch6/mlir/ToyCombine.inc @@ -0,0 +1,176 @@ +/*===- TableGen'erated file -------------------------------------*- C++ -*-===*\ +|* *| +|* Rewriters *| +|* *| +|* Automatically generated file, do not edit! *| +|* From: ToyCombine.td *| +|* *| +\*===----------------------------------------------------------------------===*/ + +/* Generated from: + ToyCombine.td:46 +*/ +struct FoldConstantReshapeOptPattern : public ::mlir::RewritePattern { + FoldConstantReshapeOptPattern(::mlir::MLIRContext *context) + : ::mlir::RewritePattern("toy.reshape", 2, context, {"toy.constant"}) {} + ::mlir::LogicalResult matchAndRewrite(::mlir::Operation *op0, + ::mlir::PatternRewriter &rewriter) const override { + // Variables for capturing values and attributes used while creating ops + ::mlir::DenseElementsAttr arg; + ::mlir::toy::ReshapeOp res; + ::llvm::SmallVector<::mlir::Operation *, 4> tblgen_ops; + + // Match + tblgen_ops.push_back(op0); + auto castedOp0 = ::llvm::dyn_cast<::mlir::toy::ReshapeOp>(op0); (void)castedOp0; + res = castedOp0; + { + auto *op1 = (*castedOp0.getODSOperands(0).begin()).getDefiningOp(); + if (!(op1)){ + return rewriter.notifyMatchFailure(castedOp0, [&](::mlir::Diagnostic &diag) { + diag << "There's no operation that defines operand 0 of castedOp0"; + }); + } + auto castedOp1 = ::llvm::dyn_cast<::mlir::toy::ConstantOp>(op1); (void)castedOp1; + if (!(castedOp1)){ + return rewriter.notifyMatchFailure(op1, [&](::mlir::Diagnostic &diag) { + diag << "castedOp1 is not ::mlir::toy::ConstantOp type"; + }); + } + { + auto tblgen_attr = op1->getAttrOfType<::mlir::DenseElementsAttr>("value");(void)tblgen_attr; + if (!(tblgen_attr)){ + return rewriter.notifyMatchFailure(op1, [&](::mlir::Diagnostic &diag) { + diag << "expected op 'toy.constant' to have attribute 'value' of type '::mlir::DenseElementsAttr'"; + }); + } + arg = tblgen_attr; + } + tblgen_ops.push_back(op1); + } + + // Rewrite + auto odsLoc = rewriter.getFusedLoc({tblgen_ops[0]->getLoc(), tblgen_ops[1]->getLoc()}); (void)odsLoc; + ::llvm::SmallVector<::mlir::Value, 4> tblgen_repl_values; + auto nativeVar_0 = arg.reshape(::llvm::cast((*res.getODSResults(0).begin()).getType())); (void)nativeVar_0; + ::mlir::toy::ConstantOp tblgen_ConstantOp_1; + { + ::llvm::SmallVector<::mlir::Value, 4> tblgen_values; (void)tblgen_values; + ::llvm::SmallVector<::mlir::NamedAttribute, 4> tblgen_attrs; (void)tblgen_attrs; + if (auto tmpAttr = nativeVar_0) { + tblgen_attrs.emplace_back(rewriter.getStringAttr("value"), tmpAttr); + } + ::llvm::SmallVector<::mlir::Type, 4> tblgen_types; (void)tblgen_types; + for (auto v: castedOp0.getODSResults(0)) { + tblgen_types.push_back(v.getType()); + } + tblgen_ConstantOp_1 = rewriter.create<::mlir::toy::ConstantOp>(odsLoc, tblgen_types, tblgen_values, tblgen_attrs); + } + + for (auto v: ::llvm::SmallVector<::mlir::Value, 4>{ tblgen_ConstantOp_1.getODSResults(0) }) { + tblgen_repl_values.push_back(v); + } + + rewriter.replaceOp(op0, tblgen_repl_values); + return ::mlir::success(); + }; +}; + +/* Generated from: + ToyCombine.td:59 +*/ +struct RedundantReshapeOptPattern : public ::mlir::RewritePattern { + RedundantReshapeOptPattern(::mlir::MLIRContext *context) + : ::mlir::RewritePattern("toy.reshape", 1, context, {}) {} + ::mlir::LogicalResult matchAndRewrite(::mlir::Operation *op0, + ::mlir::PatternRewriter &rewriter) const override { + // Variables for capturing values and attributes used while creating ops + ::mlir::Operation::operand_range arg(op0->getOperands()); + ::mlir::toy::ReshapeOp res; + ::llvm::SmallVector<::mlir::Operation *, 4> tblgen_ops; + + // Match + tblgen_ops.push_back(op0); + auto castedOp0 = ::llvm::dyn_cast<::mlir::toy::ReshapeOp>(op0); (void)castedOp0; + res = castedOp0; + arg = castedOp0.getODSOperands(0); + if (!(((*res.getODSResults(0).begin()).getType() == (*arg.begin()).getType()))){ + return rewriter.notifyMatchFailure(op0, [&](::mlir::Diagnostic &diag) { + diag << "entities 'res, arg' failed to satisfy constraint: ''"; + }); + } + + // Rewrite + auto odsLoc = rewriter.getFusedLoc({tblgen_ops[0]->getLoc()}); (void)odsLoc; + ::llvm::SmallVector<::mlir::Value, 4> tblgen_repl_values; + + for (auto v: ::llvm::SmallVector<::mlir::Value, 4>{ arg }) { + tblgen_repl_values.push_back(v); + } + + rewriter.replaceOp(op0, tblgen_repl_values); + return ::mlir::success(); + }; +}; + +/* Generated from: + ToyCombine.td:33 +*/ +struct ReshapeReshapeOptPattern : public ::mlir::RewritePattern { + ReshapeReshapeOptPattern(::mlir::MLIRContext *context) + : ::mlir::RewritePattern("toy.reshape", 2, context, {"toy.reshape"}) {} + ::mlir::LogicalResult matchAndRewrite(::mlir::Operation *op0, + ::mlir::PatternRewriter &rewriter) const override { + // Variables for capturing values and attributes used while creating ops + ::mlir::Operation::operand_range arg(op0->getOperands()); + ::llvm::SmallVector<::mlir::Operation *, 4> tblgen_ops; + + // Match + tblgen_ops.push_back(op0); + auto castedOp0 = ::llvm::dyn_cast<::mlir::toy::ReshapeOp>(op0); (void)castedOp0; + { + auto *op1 = (*castedOp0.getODSOperands(0).begin()).getDefiningOp(); + if (!(op1)){ + return rewriter.notifyMatchFailure(castedOp0, [&](::mlir::Diagnostic &diag) { + diag << "There's no operation that defines operand 0 of castedOp0"; + }); + } + auto castedOp1 = ::llvm::dyn_cast<::mlir::toy::ReshapeOp>(op1); (void)castedOp1; + if (!(castedOp1)){ + return rewriter.notifyMatchFailure(op1, [&](::mlir::Diagnostic &diag) { + diag << "castedOp1 is not ::mlir::toy::ReshapeOp type"; + }); + } + arg = castedOp1.getODSOperands(0); + tblgen_ops.push_back(op1); + } + + // Rewrite + auto odsLoc = rewriter.getFusedLoc({tblgen_ops[0]->getLoc(), tblgen_ops[1]->getLoc()}); (void)odsLoc; + ::llvm::SmallVector<::mlir::Value, 4> tblgen_repl_values; + ::mlir::toy::ReshapeOp tblgen_ReshapeOp_0; + { + ::llvm::SmallVector<::mlir::Value, 4> tblgen_values; (void)tblgen_values; + ::llvm::SmallVector<::mlir::NamedAttribute, 4> tblgen_attrs; (void)tblgen_attrs; + tblgen_values.push_back((*arg.begin())); + ::llvm::SmallVector<::mlir::Type, 4> tblgen_types; (void)tblgen_types; + for (auto v: castedOp0.getODSResults(0)) { + tblgen_types.push_back(v.getType()); + } + tblgen_ReshapeOp_0 = rewriter.create<::mlir::toy::ReshapeOp>(odsLoc, tblgen_types, tblgen_values, tblgen_attrs); + } + + for (auto v: ::llvm::SmallVector<::mlir::Value, 4>{ tblgen_ReshapeOp_0.getODSResults(0) }) { + tblgen_repl_values.push_back(v); + } + + rewriter.replaceOp(op0, tblgen_repl_values); + return ::mlir::success(); + }; +}; + +void LLVM_ATTRIBUTE_UNUSED populateWithGenerated(::mlir::RewritePatternSet &patterns) { + patterns.add(patterns.getContext()); + patterns.add(patterns.getContext()); + patterns.add(patterns.getContext()); +} diff --git a/Ch6/mlir/ToyCombine.td b/Ch6/mlir/ToyCombine.td new file mode 100644 index 0000000..11d7831 --- /dev/null +++ b/Ch6/mlir/ToyCombine.td @@ -0,0 +1,63 @@ +//===- ToyCombine.td - Pattern Match Optimizations for Toy -*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Defines language-specific pattern match optimizations for Toy using +// Declarative Rewrite Rules (DRR) specified using TableGen records. +// +//===----------------------------------------------------------------------===// + +#ifndef TOY_COMBINE +#define TOY_COMBINE + +include "mlir/IR/PatternBase.td" +include "toy/Ops.td" + +/// Note: The DRR definition used for defining patterns is shown below: +/// +/// class Pattern< +/// dag sourcePattern, list resultPatterns, +/// list additionalConstraints = [], +/// dag benefitsAdded = (addBenefit 0) +/// >; + +//===----------------------------------------------------------------------===// +// Basic Pattern-Match and Rewrite +//===----------------------------------------------------------------------===// + +// Reshape(Reshape(x)) = Reshape(x) +def ReshapeReshapeOptPattern : Pat<(ReshapeOp(ReshapeOp $arg)), + (ReshapeOp $arg)>; + +//===----------------------------------------------------------------------===// +// Pattern-Match and Rewrite using Native Code Call +//===----------------------------------------------------------------------===// + +// Native Code Calls may be used for more complex transformations using inline +// C++ and C++ helper functions. + +// Reshape(Constant(x)) = x' +def ReshapeConstant : + NativeCodeCall<"$0.reshape(::llvm::cast($1.getType()))">; +def FoldConstantReshapeOptPattern : Pat< + (ReshapeOp:$res (ConstantOp $arg)), + (ConstantOp (ReshapeConstant $arg, $res))>; + +//===----------------------------------------------------------------------===// +// Pattern-Match and Rewrite with Constraints +//===----------------------------------------------------------------------===// + +// DRR allows for constraint checking when the transformation is conditional +// on operand properties. + +// Reshape(x) = x, where input and output shapes are identical +def TypesAreIdentical : Constraint>; +def RedundantReshapeOptPattern : Pat< + (ReshapeOp:$res $arg), (replaceWithValue $arg), + [(TypesAreIdentical $res, $arg)]>; + +#endif // TOY_COMBINE diff --git a/Ch6/mlir/run.sh b/Ch6/mlir/run.sh new file mode 100644 index 0000000..f592fde --- /dev/null +++ b/Ch6/mlir/run.sh @@ -0,0 +1,2 @@ +mlir-tblgen-18 -gen-rewriters -I /usr/lib/llvm-18/include -I ../include ToyCombine.td > ToyCombine.inc + diff --git a/Ch6/parser/AST.cpp b/Ch6/parser/AST.cpp new file mode 100644 index 0000000..2546f2a --- /dev/null +++ b/Ch6/parser/AST.cpp @@ -0,0 +1,237 @@ +//===- AST.cpp - Helper for printing out the Toy AST ----------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements the AST dump for the Toy language. +// +//===----------------------------------------------------------------------===// + +#include "toy/AST.h" + +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/Twine.h" +#include "llvm/ADT/TypeSwitch.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/raw_ostream.h" +#include + +using namespace toy; + +namespace { + +// RAII helper to manage increasing/decreasing the indentation as we traverse +// the AST +struct Indent { + Indent(int &level) : level(level) { ++level; } + ~Indent() { --level; } + int &level; +}; + +/// Helper class that implement the AST tree traversal and print the nodes along +/// the way. The only data member is the current indentation level. +class ASTDumper { +public: + void dump(ModuleAST *node); + +private: + void dump(const VarType &type); + void dump(VarDeclExprAST *varDecl); + void dump(ExprAST *expr); + void dump(ExprASTList *exprList); + void dump(NumberExprAST *num); + void dump(LiteralExprAST *node); + void dump(VariableExprAST *node); + void dump(ReturnExprAST *node); + void dump(BinaryExprAST *node); + void dump(CallExprAST *node); + void dump(PrintExprAST *node); + void dump(PrototypeAST *node); + void dump(FunctionAST *node); + + // Actually print spaces matching the current indentation level + void indent() { + for (int i = 0; i < curIndent; i++) + llvm::errs() << " "; + } + int curIndent = 0; +}; + +} // namespace + +/// Return a formatted string for the location of any node +template +static std::string loc(T *node) { + const auto &loc = node->loc(); + return (llvm::Twine("@") + *loc.file + ":" + llvm::Twine(loc.line) + ":" + + llvm::Twine(loc.col)) + .str(); +} + +// Helper Macro to bump the indentation level and print the leading spaces for +// the current indentations +#define INDENT() \ + Indent level_(curIndent); \ + indent(); + +/// Dispatch to a generic expressions to the appropriate subclass using RTTI +void ASTDumper::dump(ExprAST *expr) { + llvm::TypeSwitch(expr) + .Case( + [&](auto *node) { this->dump(node); }) + .Default([&](ExprAST *) { + // No match, fallback to a generic message + INDENT(); + llvm::errs() << "getKind() << ">\n"; + }); +} + +/// A variable declaration is printing the variable name, the type, and then +/// recurse in the initializer value. +void ASTDumper::dump(VarDeclExprAST *varDecl) { + INDENT(); + llvm::errs() << "VarDecl " << varDecl->getName(); + dump(varDecl->getType()); + llvm::errs() << " " << loc(varDecl) << "\n"; + dump(varDecl->getInitVal()); +} + +/// A "block", or a list of expression +void ASTDumper::dump(ExprASTList *exprList) { + INDENT(); + llvm::errs() << "Block {\n"; + for (auto &expr : *exprList) + dump(expr.get()); + indent(); + llvm::errs() << "} // Block\n"; +} + +/// A literal number, just print the value. +void ASTDumper::dump(NumberExprAST *num) { + INDENT(); + llvm::errs() << num->getValue() << " " << loc(num) << "\n"; +} + +/// Helper to print recursively a literal. This handles nested array like: +/// [ [ 1, 2 ], [ 3, 4 ] ] +/// We print out such array with the dimensions spelled out at every level: +/// <2,2>[<2>[ 1, 2 ], <2>[ 3, 4 ] ] +void printLitHelper(ExprAST *litOrNum) { + // Inside a literal expression we can have either a number or another literal + if (auto *num = llvm::dyn_cast(litOrNum)) { + llvm::errs() << num->getValue(); + return; + } + auto *literal = llvm::cast(litOrNum); + + // Print the dimension for this literal first + llvm::errs() << "<"; + llvm::interleaveComma(literal->getDims(), llvm::errs()); + llvm::errs() << ">"; + + // Now print the content, recursing on every element of the list + llvm::errs() << "[ "; + llvm::interleaveComma(literal->getValues(), llvm::errs(), + [&](auto &elt) { printLitHelper(elt.get()); }); + llvm::errs() << "]"; +} + +/// Print a literal, see the recursive helper above for the implementation. +void ASTDumper::dump(LiteralExprAST *node) { + INDENT(); + llvm::errs() << "Literal: "; + printLitHelper(node); + llvm::errs() << " " << loc(node) << "\n"; +} + +/// Print a variable reference (just a name). +void ASTDumper::dump(VariableExprAST *node) { + INDENT(); + llvm::errs() << "var: " << node->getName() << " " << loc(node) << "\n"; +} + +/// Return statement print the return and its (optional) argument. +void ASTDumper::dump(ReturnExprAST *node) { + INDENT(); + llvm::errs() << "Return\n"; + if (node->getExpr().has_value()) + return dump(*node->getExpr()); + { + INDENT(); + llvm::errs() << "(void)\n"; + } +} + +/// Print a binary operation, first the operator, then recurse into LHS and RHS. +void ASTDumper::dump(BinaryExprAST *node) { + INDENT(); + llvm::errs() << "BinOp: " << node->getOp() << " " << loc(node) << "\n"; + dump(node->getLHS()); + dump(node->getRHS()); +} + +/// Print a call expression, first the callee name and the list of args by +/// recursing into each individual argument. +void ASTDumper::dump(CallExprAST *node) { + INDENT(); + llvm::errs() << "Call '" << node->getCallee() << "' [ " << loc(node) << "\n"; + for (auto &arg : node->getArgs()) + dump(arg.get()); + indent(); + llvm::errs() << "]\n"; +} + +/// Print a builtin print call, first the builtin name and then the argument. +void ASTDumper::dump(PrintExprAST *node) { + INDENT(); + llvm::errs() << "Print [ " << loc(node) << "\n"; + dump(node->getArg()); + indent(); + llvm::errs() << "]\n"; +} + +/// Print type: only the shape is printed in between '<' and '>' +void ASTDumper::dump(const VarType &type) { + llvm::errs() << "<"; + llvm::interleaveComma(type.shape, llvm::errs()); + llvm::errs() << ">"; +} + +/// Print a function prototype, first the function name, and then the list of +/// parameters names. +void ASTDumper::dump(PrototypeAST *node) { + INDENT(); + llvm::errs() << "Proto '" << node->getName() << "' " << loc(node) << "\n"; + indent(); + llvm::errs() << "Params: ["; + llvm::interleaveComma(node->getArgs(), llvm::errs(), + [](auto &arg) { llvm::errs() << arg->getName(); }); + llvm::errs() << "]\n"; +} + +/// Print a function, first the prototype and then the body. +void ASTDumper::dump(FunctionAST *node) { + INDENT(); + llvm::errs() << "Function \n"; + dump(node->getProto()); + dump(node->getBody()); +} + +/// Print a module, actually loop over the functions and print them in sequence. +void ASTDumper::dump(ModuleAST *node) { + INDENT(); + llvm::errs() << "Module:\n"; + for (auto &f : *node) + dump(&f); +} + +namespace toy { + +// Public API +void dump(ModuleAST &module) { ASTDumper().dump(&module); } + +} // namespace toy diff --git a/Ch6/toyc.cpp b/Ch6/toyc.cpp new file mode 100644 index 0000000..ddc0c25 --- /dev/null +++ b/Ch6/toyc.cpp @@ -0,0 +1,329 @@ +//===- toyc.cpp - The Toy Compiler ----------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements the entry point for the Toy compiler. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Func/Extensions/AllExtensions.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Support/LogicalResult.h" +#include "toy/AST.h" +#include "toy/Dialect.h" +#include "toy/Lexer.h" +#include "toy/MLIRGen.h" +#include "toy/Parser.h" +#include "toy/Passes.h" + +#include "mlir/Dialect/Affine/Passes.h" +#include "mlir/Dialect/LLVMIR/Transforms/Passes.h" +#include "mlir/ExecutionEngine/ExecutionEngine.h" +#include "mlir/ExecutionEngine/OptUtils.h" +#include "mlir/IR/AsmState.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/Verifier.h" +#include "mlir/InitAllDialects.h" +#include "mlir/Parser/Parser.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h" +#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h" +#include "mlir/Target/LLVMIR/Export.h" +#include "mlir/Transforms/Passes.h" + +#include "llvm/ADT/StringRef.h" +#include "llvm/ExecutionEngine/Orc/JITTargetMachineBuilder.h" +#include "llvm/IR/Module.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/ErrorOr.h" +#include "llvm/Support/MemoryBuffer.h" +#include "llvm/Support/SourceMgr.h" +#include "llvm/Support/TargetSelect.h" +#include "llvm/Support/raw_ostream.h" +#include +#include +#include +#include +#include + +using namespace toy; +namespace cl = llvm::cl; + +static cl::opt inputFilename(cl::Positional, + cl::desc(""), + cl::init("-"), + cl::value_desc("filename")); + +namespace { +enum InputType { Toy, MLIR }; +} // namespace +static cl::opt inputType( + "x", cl::init(Toy), cl::desc("Decided the kind of output desired"), + cl::values(clEnumValN(Toy, "toy", "load the input file as a Toy source.")), + cl::values(clEnumValN(MLIR, "mlir", + "load the input file as an MLIR file"))); + +namespace { +enum Action { + None, + DumpAST, + DumpMLIR, + DumpMLIRAffine, + DumpMLIRLLVM, + DumpLLVMIR, + RunJIT +}; +} // namespace +static cl::opt emitAction( + "emit", cl::desc("Select the kind of output desired"), + cl::values(clEnumValN(DumpAST, "ast", "output the AST dump")), + cl::values(clEnumValN(DumpMLIR, "mlir", "output the MLIR dump")), + cl::values(clEnumValN(DumpMLIRAffine, "mlir-affine", + "output the MLIR dump after affine lowering")), + cl::values(clEnumValN(DumpMLIRLLVM, "mlir-llvm", + "output the MLIR dump after llvm lowering")), + cl::values(clEnumValN(DumpLLVMIR, "llvm", "output the LLVM IR dump")), + cl::values( + clEnumValN(RunJIT, "jit", + "JIT the code and run it by invoking the main function"))); + +static cl::opt enableOpt("opt", cl::desc("Enable optimizations")); + +/// Returns a Toy AST resulting from parsing the file or a nullptr on error. +std::unique_ptr parseInputFile(llvm::StringRef filename) { + llvm::ErrorOr> fileOrErr = + llvm::MemoryBuffer::getFileOrSTDIN(filename); + if (std::error_code ec = fileOrErr.getError()) { + llvm::errs() << "Could not open input file: " << ec.message() << "\n"; + return nullptr; + } + auto buffer = fileOrErr.get()->getBuffer(); + LexerBuffer lexer(buffer.begin(), buffer.end(), std::string(filename)); + Parser parser(lexer); + return parser.parseModule(); +} + +int loadMLIR(mlir::MLIRContext &context, + mlir::OwningOpRef &module) { + // Handle '.toy' input to the compiler. + if (inputType != InputType::MLIR && + !llvm::StringRef(inputFilename).ends_with(".mlir")) { + auto moduleAST = parseInputFile(inputFilename); + if (!moduleAST) + return 6; + module = mlirGen(context, *moduleAST); + return !module ? 1 : 0; + } + + // Otherwise, the input is '.mlir'. + llvm::ErrorOr> fileOrErr = + llvm::MemoryBuffer::getFileOrSTDIN(inputFilename); + if (std::error_code ec = fileOrErr.getError()) { + llvm::errs() << "Could not open input file: " << ec.message() << "\n"; + return -1; + } + + // Parse the input mlir. + llvm::SourceMgr sourceMgr; + sourceMgr.AddNewSourceBuffer(std::move(*fileOrErr), llvm::SMLoc()); + module = mlir::parseSourceFile(sourceMgr, &context); + if (!module) { + llvm::errs() << "Error can't load file " << inputFilename << "\n"; + return 3; + } + return 0; +} + +int loadAndProcessMLIR(mlir::MLIRContext &context, + mlir::OwningOpRef &module) { + if (int error = loadMLIR(context, module)) + return error; + + mlir::PassManager pm(module.get()->getName()); + // Apply any generic pass manager command line options and run the pipeline. + if (mlir::failed(mlir::applyPassManagerCLOptions(pm))) + return 4; + + // Check to see what granularity of MLIR we are compiling to. + bool isLoweringToAffine = emitAction >= Action::DumpMLIRAffine; + bool isLoweringToLLVM = emitAction >= Action::DumpMLIRLLVM; + + if (enableOpt || isLoweringToAffine) { + // Inline all functions into main and then delete them. + pm.addPass(mlir::createInlinerPass()); + + // Now that there is only one function, we can infer the shapes of each of + // the operations. + mlir::OpPassManager &optPM = pm.nest(); + optPM.addPass(mlir::toy::createShapeInferencePass()); + optPM.addPass(mlir::createCanonicalizerPass()); + optPM.addPass(mlir::createCSEPass()); + } + + if (isLoweringToAffine) { + // Partially lower the toy dialect. + pm.addPass(mlir::toy::createLowerToAffinePass()); + + // Add a few cleanups post lowering. + mlir::OpPassManager &optPM = pm.nest(); + optPM.addPass(mlir::createCanonicalizerPass()); + optPM.addPass(mlir::createCSEPass()); + + // Add optimizations if enabled. + if (enableOpt) { + optPM.addPass(mlir::affine::createLoopFusionPass()); + optPM.addPass(mlir::affine::createAffineScalarReplacementPass()); + } + } + + if (isLoweringToLLVM) { + // Finish lowering the toy IR to the LLVM dialect. + pm.addPass(mlir::toy::createLowerToLLVMPass()); + // This is necessary to have line tables emitted and basic + // debugger working. In the future we will add proper debug information + // emission directly from our frontend. + pm.addPass(mlir::LLVM::createDIScopeForLLVMFuncOpPass()); + } + + if (mlir::failed(pm.run(*module))) + return 4; + return 0; +} + +int dumpAST() { + if (inputType == InputType::MLIR) { + llvm::errs() << "Can't dump a Toy AST when the input is MLIR\n"; + return 5; + } + + auto moduleAST = parseInputFile(inputFilename); + if (!moduleAST) + return 1; + + dump(*moduleAST); + return 0; +} + +int dumpLLVMIR(mlir::ModuleOp module) { + // Register the translation to LLVM IR with the MLIR context. + mlir::registerBuiltinDialectTranslation(*module->getContext()); + mlir::registerLLVMDialectTranslation(*module->getContext()); + + // Convert the module to LLVM IR in a new LLVM IR context. + llvm::LLVMContext llvmContext; + auto llvmModule = mlir::translateModuleToLLVMIR(module, llvmContext); + if (!llvmModule) { + llvm::errs() << "Failed to emit LLVM IR\n"; + return -1; + } + + // Initialize LLVM targets. + llvm::InitializeNativeTarget(); + llvm::InitializeNativeTargetAsmPrinter(); + + // Configure the LLVM Module + auto tmBuilderOrError = llvm::orc::JITTargetMachineBuilder::detectHost(); + if (!tmBuilderOrError) { + llvm::errs() << "Could not create JITTargetMachineBuilder\n"; + return -1; + } + + auto tmOrError = tmBuilderOrError->createTargetMachine(); + if (!tmOrError) { + llvm::errs() << "Could not create TargetMachine\n"; + return -1; + } + mlir::ExecutionEngine::setupTargetTripleAndDataLayout(llvmModule.get(), + tmOrError.get().get()); + + /// Optionally run an optimization pipeline over the llvm module. + auto optPipeline = mlir::makeOptimizingTransformer( + /*optLevel=*/enableOpt ? 3 : 0, /*sizeLevel=*/0, + /*targetMachine=*/nullptr); + if (auto err = optPipeline(llvmModule.get())) { + llvm::errs() << "Failed to optimize LLVM IR " << err << "\n"; + return -1; + } + llvm::errs() << *llvmModule << "\n"; + return 0; +} + +int runJit(mlir::ModuleOp module) { + // Initialize LLVM targets. + llvm::InitializeNativeTarget(); + llvm::InitializeNativeTargetAsmPrinter(); + + // Register the translation from MLIR to LLVM IR, which must happen before we + // can JIT-compile. + mlir::registerBuiltinDialectTranslation(*module->getContext()); + mlir::registerLLVMDialectTranslation(*module->getContext()); + + // An optimization pipeline to use within the execution engine. + auto optPipeline = mlir::makeOptimizingTransformer( + /*optLevel=*/enableOpt ? 3 : 0, /*sizeLevel=*/0, + /*targetMachine=*/nullptr); + + // Create an MLIR execution engine. The execution engine eagerly JIT-compiles + // the module. + mlir::ExecutionEngineOptions engineOptions; + engineOptions.transformer = optPipeline; + auto maybeEngine = mlir::ExecutionEngine::create(module, engineOptions); + assert(maybeEngine && "failed to construct an execution engine"); + auto &engine = maybeEngine.get(); + + // Invoke the JIT-compiled function. + auto invocationResult = engine->invokePacked("main"); + if (invocationResult) { + llvm::errs() << "JIT invocation failed\n"; + return -1; + } + + return 0; +} + +int main(int argc, char **argv) { + // Register any command line options. + mlir::registerAsmPrinterCLOptions(); + mlir::registerMLIRContextCLOptions(); + mlir::registerPassManagerCLOptions(); + + cl::ParseCommandLineOptions(argc, argv, "toy compiler\n"); + + if (emitAction == Action::DumpAST) + return dumpAST(); + + // If we aren't dumping the AST, then we are compiling with/to MLIR. + mlir::DialectRegistry registry; + mlir::func::registerAllExtensions(registry); + + mlir::MLIRContext context(registry); + // Load our Dialect in this MLIR Context. + context.getOrLoadDialect(); + + mlir::OwningOpRef module; + if (int error = loadAndProcessMLIR(context, module)) + return error; + + // If we aren't exporting to non-mlir, then we are done. + bool isOutputingMLIR = emitAction <= Action::DumpMLIRLLVM; + if (isOutputingMLIR) { + module->dump(); + return 0; + } + + // Check to see if we are compiling to LLVM IR. + if (emitAction == Action::DumpLLVMIR) + return dumpLLVMIR(*module); + + // Otherwise, we must be running the jit. + if (emitAction == Action::RunJIT) + return runJit(*module); + + llvm::errs() << "No action specified (parsing only?), use -emit=\n"; + return -1; +} diff --git a/Ch7/CMakeLists.txt b/Ch7/CMakeLists.txt new file mode 100644 index 0000000..360b554 --- /dev/null +++ b/Ch7/CMakeLists.txt @@ -0,0 +1,62 @@ +# This chapter depends on JIT support enabled. +if(NOT MLIR_ENABLE_EXECUTION_ENGINE) + return() +endif() + +# For a better template to copy, see examples/standalone +include_directories(include) +add_subdirectory(include) + +set(LLVM_LINK_COMPONENTS + Core + Support + nativecodegen + OrcJIT + ) + +# set(LLVM_TARGET_DEFINITIONS mlir/ToyCombine.td) +# mlir_tablegen(ToyCombine.inc -gen-rewriters) +# add_public_tablegen_target(ToyCh7CombineIncGen) + +add_executable(toyc-ch7 + toyc.cpp + parser/AST.cpp + mlir/MLIRGen.cpp + mlir/Dialect.cpp + mlir/LowerToAffineLoops.cpp + mlir/LowerToLLVM.cpp + mlir/ShapeInferencePass.cpp + mlir/ToyCombine.cpp + + # DEPENDS + # ToyCh7ShapeInferenceInterfaceIncGen + # ToyCh7OpsIncGen + # ToyCh7CombineIncGen + ) + +include_directories(${CMAKE_CURRENT_BINARY_DIR}) +include_directories(${CMAKE_CURRENT_BINARY_DIR}/include/) +get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS) +get_property(conversion_libs GLOBAL PROPERTY MLIR_CONVERSION_LIBS) +get_property(extension_libs GLOBAL PROPERTY MLIR_EXTENSION_LIBS) +target_link_libraries(toyc-ch7 + PRIVATE + ${dialect_libs} + ${conversion_libs} + ${extension_libs} + MLIRAnalysis + MLIRBuiltinToLLVMIRTranslation + MLIRCallInterfaces + MLIRCastInterfaces + MLIRExecutionEngine + MLIRFunctionInterfaces + MLIRIR + MLIRLLVMCommonConversion + MLIRLLVMToLLVMIRTranslation + MLIRMemRefDialect + MLIRParser + MLIRPass + MLIRSideEffectInterfaces + MLIRTargetLLVMIRExport + MLIRTransforms + ) diff --git a/Ch7/include/CMakeLists.txt b/Ch7/include/CMakeLists.txt new file mode 100644 index 0000000..37c89d0 --- /dev/null +++ b/Ch7/include/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(toy) diff --git a/Ch7/include/run.sh b/Ch7/include/run.sh new file mode 100644 index 0000000..b9d18af --- /dev/null +++ b/Ch7/include/run.sh @@ -0,0 +1,7 @@ +mlir-tblgen-18 -gen-op-decls -I /usr/lib/llvm-18/include toy/Ops.td > toy/Ops.h.inc +mlir-tblgen-18 -gen-op-defs -I /usr/lib/llvm-18/include toy/Ops.td > toy/Ops.cpp.inc +mlir-tblgen-18 -gen-dialect-decls -I /usr/lib/llvm-18/include toy/Ops.td > toy/Dialect.h.inc +mlir-tblgen-18 -gen-dialect-defs -I /usr/lib/llvm-18/include toy/Ops.td > toy/Dialect.cpp.inc + +mlir-tblgen-18 -gen-op-interface-decls -I /usr/lib/llvm-18/include toy/ShapeInferenceInterface.td > toy/ShapeInferenceOpInterfaces.h.inc +mlir-tblgen-18 -gen-op-interface-defs -I /usr/lib/llvm-18/include toy/ShapeInferenceInterface.td > toy/ShapeInferenceOpInterfaces.cpp.inc diff --git a/Ch7/include/toy/AST.h b/Ch7/include/toy/AST.h new file mode 100644 index 0000000..42d64ed --- /dev/null +++ b/Ch7/include/toy/AST.h @@ -0,0 +1,313 @@ +//===- AST.h - Node definition for the Toy AST ----------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements the AST for the Toy language. It is optimized for +// simplicity, not efficiency. The AST forms a tree structure where each node +// references its children using std::unique_ptr<>. +// +//===----------------------------------------------------------------------===// + +#ifndef TOY_AST_H +#define TOY_AST_H + +#include "toy/Lexer.h" + +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/Casting.h" +#include +#include +#include + +namespace toy { + +/// A variable type with either name or shape information. +struct VarType { + std::string name; + std::vector shape; +}; + +/// Base class for all expression nodes. +class ExprAST { +public: + enum ExprASTKind { + Expr_VarDecl, + Expr_Return, + Expr_Num, + Expr_Literal, + Expr_StructLiteral, + Expr_Var, + Expr_BinOp, + Expr_Call, + Expr_Print, + }; + + ExprAST(ExprASTKind kind, Location location) + : kind(kind), location(std::move(location)) {} + virtual ~ExprAST() = default; + + ExprASTKind getKind() const { return kind; } + + const Location &loc() { return location; } + +private: + const ExprASTKind kind; + Location location; +}; + +/// A block-list of expressions. +using ExprASTList = std::vector>; + +/// Expression class for numeric literals like "1.0". +class NumberExprAST : public ExprAST { + double val; + +public: + NumberExprAST(Location loc, double val) + : ExprAST(Expr_Num, std::move(loc)), val(val) {} + + double getValue() { return val; } + + /// LLVM style RTTI + static bool classof(const ExprAST *c) { return c->getKind() == Expr_Num; } +}; + +/// Expression class for a literal value. +class LiteralExprAST : public ExprAST { + std::vector> values; + std::vector dims; + +public: + LiteralExprAST(Location loc, std::vector> values, + std::vector dims) + : ExprAST(Expr_Literal, std::move(loc)), values(std::move(values)), + dims(std::move(dims)) {} + + llvm::ArrayRef> getValues() { return values; } + llvm::ArrayRef getDims() { return dims; } + + /// LLVM style RTTI + static bool classof(const ExprAST *c) { return c->getKind() == Expr_Literal; } +}; + +/// Expression class for a literal struct value. +class StructLiteralExprAST : public ExprAST { + std::vector> values; + +public: + StructLiteralExprAST(Location loc, + std::vector> values) + : ExprAST(Expr_StructLiteral, std::move(loc)), values(std::move(values)) { + } + + llvm::ArrayRef> getValues() { return values; } + + /// LLVM style RTTI + static bool classof(const ExprAST *c) { + return c->getKind() == Expr_StructLiteral; + } +}; + +/// Expression class for referencing a variable, like "a". +class VariableExprAST : public ExprAST { + std::string name; + +public: + VariableExprAST(Location loc, llvm::StringRef name) + : ExprAST(Expr_Var, std::move(loc)), name(name) {} + + llvm::StringRef getName() { return name; } + + /// LLVM style RTTI + static bool classof(const ExprAST *c) { return c->getKind() == Expr_Var; } +}; + +/// Expression class for defining a variable. +class VarDeclExprAST : public ExprAST { + std::string name; + VarType type; + std::unique_ptr initVal; + +public: + VarDeclExprAST(Location loc, llvm::StringRef name, VarType type, + std::unique_ptr initVal = nullptr) + : ExprAST(Expr_VarDecl, std::move(loc)), name(name), + type(std::move(type)), initVal(std::move(initVal)) {} + + llvm::StringRef getName() { return name; } + ExprAST *getInitVal() { return initVal.get(); } + const VarType &getType() { return type; } + + /// LLVM style RTTI + static bool classof(const ExprAST *c) { return c->getKind() == Expr_VarDecl; } +}; + +/// Expression class for a return operator. +class ReturnExprAST : public ExprAST { + std::optional> expr; + +public: + ReturnExprAST(Location loc, std::optional> expr) + : ExprAST(Expr_Return, std::move(loc)), expr(std::move(expr)) {} + + std::optional getExpr() { + if (expr.has_value()) + return expr->get(); + return std::nullopt; + } + + /// LLVM style RTTI + static bool classof(const ExprAST *c) { return c->getKind() == Expr_Return; } +}; + +/// Expression class for a binary operator. +class BinaryExprAST : public ExprAST { + char op; + std::unique_ptr lhs, rhs; + +public: + char getOp() { return op; } + ExprAST *getLHS() { return lhs.get(); } + ExprAST *getRHS() { return rhs.get(); } + + BinaryExprAST(Location loc, char op, std::unique_ptr lhs, + std::unique_ptr rhs) + : ExprAST(Expr_BinOp, std::move(loc)), op(op), lhs(std::move(lhs)), + rhs(std::move(rhs)) {} + + /// LLVM style RTTI + static bool classof(const ExprAST *c) { return c->getKind() == Expr_BinOp; } +}; + +/// Expression class for function calls. +class CallExprAST : public ExprAST { + std::string callee; + std::vector> args; + +public: + CallExprAST(Location loc, const std::string &callee, + std::vector> args) + : ExprAST(Expr_Call, std::move(loc)), callee(callee), + args(std::move(args)) {} + + llvm::StringRef getCallee() { return callee; } + llvm::ArrayRef> getArgs() { return args; } + + /// LLVM style RTTI + static bool classof(const ExprAST *c) { return c->getKind() == Expr_Call; } +}; + +/// Expression class for builtin print calls. +class PrintExprAST : public ExprAST { + std::unique_ptr arg; + +public: + PrintExprAST(Location loc, std::unique_ptr arg) + : ExprAST(Expr_Print, std::move(loc)), arg(std::move(arg)) {} + + ExprAST *getArg() { return arg.get(); } + + /// LLVM style RTTI + static bool classof(const ExprAST *c) { return c->getKind() == Expr_Print; } +}; + +/// This class represents the "prototype" for a function, which captures its +/// name, and its argument names (thus implicitly the number of arguments the +/// function takes). +class PrototypeAST { + Location location; + std::string name; + std::vector> args; + +public: + PrototypeAST(Location location, const std::string &name, + std::vector> args) + : location(std::move(location)), name(name), args(std::move(args)) {} + + const Location &loc() { return location; } + llvm::StringRef getName() const { return name; } + llvm::ArrayRef> getArgs() { return args; } +}; + +/// This class represents a top level record in a module. +class RecordAST { +public: + enum RecordASTKind { + Record_Function, + Record_Struct, + }; + + RecordAST(RecordASTKind kind) : kind(kind) {} + virtual ~RecordAST() = default; + + RecordASTKind getKind() const { return kind; } + +private: + const RecordASTKind kind; +}; + +/// This class represents a function definition itself. +class FunctionAST : public RecordAST { + std::unique_ptr proto; + std::unique_ptr body; + +public: + FunctionAST(std::unique_ptr proto, + std::unique_ptr body) + : RecordAST(Record_Function), proto(std::move(proto)), + body(std::move(body)) {} + PrototypeAST *getProto() { return proto.get(); } + ExprASTList *getBody() { return body.get(); } + + /// LLVM style RTTI + static bool classof(const RecordAST *r) { + return r->getKind() == Record_Function; + } +}; + +/// This class represents a struct definition. +class StructAST : public RecordAST { + Location location; + std::string name; + std::vector> variables; + +public: + StructAST(Location location, const std::string &name, + std::vector> variables) + : RecordAST(Record_Struct), location(std::move(location)), name(name), + variables(std::move(variables)) {} + + const Location &loc() { return location; } + llvm::StringRef getName() const { return name; } + llvm::ArrayRef> getVariables() { + return variables; + } + + /// LLVM style RTTI + static bool classof(const RecordAST *r) { + return r->getKind() == Record_Struct; + } +}; + +/// This class represents a list of functions to be processed together +class ModuleAST { + std::vector> records; + +public: + ModuleAST(std::vector> records) + : records(std::move(records)) {} + + auto begin() { return records.begin(); } + auto end() { return records.end(); } +}; + +void dump(ModuleAST &); + +} // namespace toy + +#endif // TOY_AST_H diff --git a/Ch7/include/toy/CMakeLists.txt b/Ch7/include/toy/CMakeLists.txt new file mode 100644 index 0000000..5ff56e8 --- /dev/null +++ b/Ch7/include/toy/CMakeLists.txt @@ -0,0 +1,13 @@ +# # Most dialects should use add_mlir_dialect(). See examples/standalone. +# set(LLVM_TARGET_DEFINITIONS Ops.td) +# mlir_tablegen(Ops.h.inc -gen-op-decls) +# mlir_tablegen(Ops.cpp.inc -gen-op-defs) +# mlir_tablegen(Dialect.h.inc -gen-dialect-decls) +# mlir_tablegen(Dialect.cpp.inc -gen-dialect-defs) +# add_public_tablegen_target(ToyCh7OpsIncGen) + +# # Most dialects should use add_mlir_interfaces(). +# set(LLVM_TARGET_DEFINITIONS ShapeInferenceInterface.td) +# mlir_tablegen(ShapeInferenceOpInterfaces.h.inc -gen-op-interface-decls) +# mlir_tablegen(ShapeInferenceOpInterfaces.cpp.inc -gen-op-interface-defs) +# add_public_tablegen_target(ToyCh7ShapeInferenceInterfaceIncGen) diff --git a/Ch7/include/toy/Dialect.cpp.inc b/Ch7/include/toy/Dialect.cpp.inc new file mode 100644 index 0000000..8cbc772 --- /dev/null +++ b/Ch7/include/toy/Dialect.cpp.inc @@ -0,0 +1,23 @@ +/*===- TableGen'erated file -------------------------------------*- C++ -*-===*\ +|* *| +|* Dialect Definitions *| +|* *| +|* Automatically generated file, do not edit! *| +|* From: Ops.td *| +|* *| +\*===----------------------------------------------------------------------===*/ + +MLIR_DEFINE_EXPLICIT_TYPE_ID(::mlir::toy::ToyDialect) +namespace mlir { +namespace toy { + +ToyDialect::ToyDialect(::mlir::MLIRContext *context) + : ::mlir::Dialect(getDialectNamespace(), context, ::mlir::TypeID::get()) { + + initialize(); +} + +ToyDialect::~ToyDialect() = default; + +} // namespace toy +} // namespace mlir diff --git a/Ch7/include/toy/Dialect.h b/Ch7/include/toy/Dialect.h new file mode 100644 index 0000000..64094c3 --- /dev/null +++ b/Ch7/include/toy/Dialect.h @@ -0,0 +1,82 @@ +//===- Dialect.h - Dialect definition for the Toy IR ----------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements the IR Dialect for the Toy language. +// See docs/Tutorials/Toy/Ch-2.md for more information. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_TUTORIAL_TOY_DIALECT_H_ +#define MLIR_TUTORIAL_TOY_DIALECT_H_ + +#include "mlir/Bytecode/BytecodeOpInterface.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Dialect.h" +#include "mlir/IR/SymbolTable.h" +#include "mlir/Interfaces/CallInterfaces.h" +#include "mlir/Interfaces/CastInterfaces.h" +#include "mlir/Interfaces/FunctionInterfaces.h" +#include "mlir/Interfaces/SideEffectInterfaces.h" +#include "toy/ShapeInferenceInterface.h" + +namespace mlir { +namespace toy { +namespace detail { +struct StructTypeStorage; +} // namespace detail +} // namespace toy +} // namespace mlir + +/// Include the auto-generated header file containing the declaration of the toy +/// dialect. +#include "toy/Dialect.h.inc" + +//===----------------------------------------------------------------------===// +// Toy Operations +//===----------------------------------------------------------------------===// + +/// Include the auto-generated header file containing the declarations of the +/// toy operations. +#define GET_OP_CLASSES +#include "toy/Ops.h.inc" + +namespace mlir { +namespace toy { + +//===----------------------------------------------------------------------===// +// Toy Types +//===----------------------------------------------------------------------===// + +/// This class defines the Toy struct type. It represents a collection of +/// element types. All derived types in MLIR must inherit from the CRTP class +/// 'Type::TypeBase'. It takes as template parameters the concrete type +/// (StructType), the base class to use (Type), and the storage class +/// (StructTypeStorage). +class StructType : public mlir::Type::TypeBase { +public: + /// Inherit some necessary constructors from 'TypeBase'. + using Base::Base; + + /// Create an instance of a `StructType` with the given element types. There + /// *must* be atleast one element type. + static StructType get(llvm::ArrayRef elementTypes); + + /// Returns the element types of this struct type. + llvm::ArrayRef getElementTypes(); + + /// Returns the number of element type held by this struct. + size_t getNumElementTypes() { return getElementTypes().size(); } + + /// The name of this struct type. + static constexpr StringLiteral name = "toy.struct"; +}; +} // namespace toy +} // namespace mlir + +#endif // MLIR_TUTORIAL_TOY_DIALECT_H_ diff --git a/Ch7/include/toy/Dialect.h.inc b/Ch7/include/toy/Dialect.h.inc new file mode 100644 index 0000000..44dfbd5 --- /dev/null +++ b/Ch7/include/toy/Dialect.h.inc @@ -0,0 +1,40 @@ +/*===- TableGen'erated file -------------------------------------*- C++ -*-===*\ +|* *| +|* Dialect Declarations *| +|* *| +|* Automatically generated file, do not edit! *| +|* From: Ops.td *| +|* *| +\*===----------------------------------------------------------------------===*/ + +namespace mlir { +namespace toy { + +class ToyDialect : public ::mlir::Dialect { + explicit ToyDialect(::mlir::MLIRContext *context); + + void initialize(); + friend class ::mlir::MLIRContext; +public: + ~ToyDialect() override; + static constexpr ::llvm::StringLiteral getDialectNamespace() { + return ::llvm::StringLiteral("toy"); + } + + /// Parse a type registered to this dialect. + ::mlir::Type parseType(::mlir::DialectAsmParser &parser) const override; + + /// Print a type registered to this dialect. + void printType(::mlir::Type type, + ::mlir::DialectAsmPrinter &os) const override; + + /// Materialize a single constant operation from a given attribute value with + /// the desired resultant type. + ::mlir::Operation *materializeConstant(::mlir::OpBuilder &builder, + ::mlir::Attribute value, + ::mlir::Type type, + ::mlir::Location loc) override; +}; +} // namespace toy +} // namespace mlir +MLIR_DECLARE_EXPLICIT_TYPE_ID(::mlir::toy::ToyDialect) diff --git a/Ch7/include/toy/Lexer.h b/Ch7/include/toy/Lexer.h new file mode 100644 index 0000000..a3fde91 --- /dev/null +++ b/Ch7/include/toy/Lexer.h @@ -0,0 +1,235 @@ +//===- Lexer.h - Lexer for the Toy language -------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements a simple Lexer for the Toy language. +// +//===----------------------------------------------------------------------===// + +#ifndef TOY_LEXER_H +#define TOY_LEXER_H + +#include "llvm/ADT/StringRef.h" + +#include +#include + +namespace toy { + +/// Structure definition a location in a file. +struct Location { + std::shared_ptr file; ///< filename. + int line; ///< line number. + int col; ///< column number. +}; + +// List of Token returned by the lexer. +enum Token : int { + tok_semicolon = ';', + tok_parenthese_open = '(', + tok_parenthese_close = ')', + tok_bracket_open = '{', + tok_bracket_close = '}', + tok_sbracket_open = '[', + tok_sbracket_close = ']', + + tok_eof = -1, + + // commands + tok_return = -2, + tok_var = -3, + tok_def = -4, + tok_struct = -5, + + // primary + tok_identifier = -6, + tok_number = -7, +}; + +/// The Lexer is an abstract base class providing all the facilities that the +/// Parser expects. It goes through the stream one token at a time and keeps +/// track of the location in the file for debugging purpose. +/// It relies on a subclass to provide a `readNextLine()` method. The subclass +/// can proceed by reading the next line from the standard input or from a +/// memory mapped file. +class Lexer { +public: + /// Create a lexer for the given filename. The filename is kept only for + /// debugging purpose (attaching a location to a Token). + Lexer(std::string filename) + : lastLocation( + {std::make_shared(std::move(filename)), 0, 0}) {} + virtual ~Lexer() = default; + + /// Look at the current token in the stream. + Token getCurToken() { return curTok; } + + /// Move to the next token in the stream and return it. + Token getNextToken() { return curTok = getTok(); } + + /// Move to the next token in the stream, asserting on the current token + /// matching the expectation. + void consume(Token tok) { + assert(tok == curTok && "consume Token mismatch expectation"); + getNextToken(); + } + + /// Return the current identifier (prereq: getCurToken() == tok_identifier) + llvm::StringRef getId() { + assert(curTok == tok_identifier); + return identifierStr; + } + + /// Return the current number (prereq: getCurToken() == tok_number) + double getValue() { + assert(curTok == tok_number); + return numVal; + } + + /// Return the location for the beginning of the current token. + Location getLastLocation() { return lastLocation; } + + // Return the current line in the file. + int getLine() { return curLineNum; } + + // Return the current column in the file. + int getCol() { return curCol; } + +private: + /// Delegate to a derived class fetching the next line. Returns an empty + /// string to signal end of file (EOF). Lines are expected to always finish + /// with "\n" + virtual llvm::StringRef readNextLine() = 0; + + /// Return the next character from the stream. This manages the buffer for the + /// current line and request the next line buffer to the derived class as + /// needed. + int getNextChar() { + // The current line buffer should not be empty unless it is the end of file. + if (curLineBuffer.empty()) + return EOF; + ++curCol; + auto nextchar = curLineBuffer.front(); + curLineBuffer = curLineBuffer.drop_front(); + if (curLineBuffer.empty()) + curLineBuffer = readNextLine(); + if (nextchar == '\n') { + ++curLineNum; + curCol = 0; + } + return nextchar; + } + + /// Return the next token from standard input. + Token getTok() { + // Skip any whitespace. + while (isspace(lastChar)) + lastChar = Token(getNextChar()); + + // Save the current location before reading the token characters. + lastLocation.line = curLineNum; + lastLocation.col = curCol; + + // Identifier: [a-zA-Z][a-zA-Z0-9_]* + if (isalpha(lastChar)) { + identifierStr = (char)lastChar; + while (isalnum((lastChar = Token(getNextChar()))) || lastChar == '_') + identifierStr += (char)lastChar; + + if (identifierStr == "return") + return tok_return; + if (identifierStr == "def") + return tok_def; + if (identifierStr == "struct") + return tok_struct; + if (identifierStr == "var") + return tok_var; + return tok_identifier; + } + + // Number: [0-9] ([0-9.])* + if (isdigit(lastChar)) { + std::string numStr; + do { + numStr += lastChar; + lastChar = Token(getNextChar()); + } while (isdigit(lastChar) || lastChar == '.'); + + numVal = strtod(numStr.c_str(), nullptr); + return tok_number; + } + + if (lastChar == '#') { + // Comment until end of line. + do { + lastChar = Token(getNextChar()); + } while (lastChar != EOF && lastChar != '\n' && lastChar != '\r'); + + if (lastChar != EOF) + return getTok(); + } + + // Check for end of file. Don't eat the EOF. + if (lastChar == EOF) + return tok_eof; + + // Otherwise, just return the character as its ascii value. + Token thisChar = Token(lastChar); + lastChar = Token(getNextChar()); + return thisChar; + } + + /// The last token read from the input. + Token curTok = tok_eof; + + /// Location for `curTok`. + Location lastLocation; + + /// If the current Token is an identifier, this string contains the value. + std::string identifierStr; + + /// If the current Token is a number, this contains the value. + double numVal = 0; + + /// The last value returned by getNextChar(). We need to keep it around as we + /// always need to read ahead one character to decide when to end a token and + /// we can't put it back in the stream after reading from it. + Token lastChar = Token(' '); + + /// Keep track of the current line number in the input stream + int curLineNum = 0; + + /// Keep track of the current column number in the input stream + int curCol = 0; + + /// Buffer supplied by the derived class on calls to `readNextLine()` + llvm::StringRef curLineBuffer = "\n"; +}; + +/// A lexer implementation operating on a buffer in memory. +class LexerBuffer final : public Lexer { +public: + LexerBuffer(const char *begin, const char *end, std::string filename) + : Lexer(std::move(filename)), current(begin), end(end) {} + +private: + /// Provide one line at a time to the Lexer, return an empty string when + /// reaching the end of the buffer. + llvm::StringRef readNextLine() override { + auto *begin = current; + while (current <= end && *current && *current != '\n') + ++current; + if (current <= end && *current) + ++current; + llvm::StringRef result{begin, static_cast(current - begin)}; + return result; + } + const char *current, *end; +}; +} // namespace toy + +#endif // TOY_LEXER_H diff --git a/Ch7/include/toy/MLIRGen.h b/Ch7/include/toy/MLIRGen.h new file mode 100644 index 0000000..fe9dbe5 --- /dev/null +++ b/Ch7/include/toy/MLIRGen.h @@ -0,0 +1,35 @@ +//===- MLIRGen.h - MLIR Generation from a Toy AST -------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file declares a simple interface to perform IR generation targeting MLIR +// from a Module AST for the Toy language. +// +//===----------------------------------------------------------------------===// + +#ifndef TOY_MLIRGEN_H +#define TOY_MLIRGEN_H + +#include + +namespace mlir { +class MLIRContext; +template +class OwningOpRef; +class ModuleOp; +} // namespace mlir + +namespace toy { +class ModuleAST; + +/// Emit IR for the given Toy moduleAST, returns a newly created MLIR module +/// or nullptr on failure. +mlir::OwningOpRef mlirGen(mlir::MLIRContext &context, + ModuleAST &moduleAST); +} // namespace toy + +#endif // TOY_MLIRGEN_H diff --git a/Ch7/include/toy/Ops.cpp.inc b/Ch7/include/toy/Ops.cpp.inc new file mode 100644 index 0000000..f7af5e9 --- /dev/null +++ b/Ch7/include/toy/Ops.cpp.inc @@ -0,0 +1,2907 @@ +/*===- TableGen'erated file -------------------------------------*- C++ -*-===*\ +|* *| +|* Op Definitions *| +|* *| +|* Automatically generated file, do not edit! *| +|* From: Ops.td *| +|* *| +\*===----------------------------------------------------------------------===*/ + +#ifdef GET_OP_LIST +#undef GET_OP_LIST + +::mlir::toy::AddOp, +::mlir::toy::CastOp, +::mlir::toy::ConstantOp, +::mlir::toy::FuncOp, +::mlir::toy::GenericCallOp, +::mlir::toy::MulOp, +::mlir::toy::PrintOp, +::mlir::toy::ReshapeOp, +::mlir::toy::ReturnOp, +::mlir::toy::StructAccessOp, +::mlir::toy::StructConstantOp, +::mlir::toy::TransposeOp +#endif // GET_OP_LIST + +#ifdef GET_OP_CLASSES +#undef GET_OP_CLASSES + + +//===----------------------------------------------------------------------===// +// Local Utility Method Definitions +//===----------------------------------------------------------------------===// + +namespace mlir { +namespace toy { + +static ::mlir::LogicalResult __mlir_ods_local_type_constraint_Ops0( + ::mlir::Operation *op, ::mlir::Type type, ::llvm::StringRef valueKind, + unsigned valueIndex) { + if (!(((::llvm::isa<::mlir::TensorType>(type))) && ([](::mlir::Type elementType) { return (elementType.isF64()); }(::llvm::cast<::mlir::ShapedType>(type).getElementType())))) { + return op->emitOpError(valueKind) << " #" << valueIndex + << " must be tensor of 64-bit float values, but got " << type; + } + return ::mlir::success(); +} + +static ::mlir::LogicalResult __mlir_ods_local_type_constraint_Ops1( + ::mlir::Operation *op, ::mlir::Type type, ::llvm::StringRef valueKind, + unsigned valueIndex) { + if (!((((::llvm::isa<::mlir::TensorType>(type))) && ([](::mlir::Type elementType) { return (elementType.isF64()); }(::llvm::cast<::mlir::ShapedType>(type).getElementType()))) || ((::llvm::isa(type))))) { + return op->emitOpError(valueKind) << " #" << valueIndex + << " must be variadic of tensor of 64-bit float values or Toy struct type, but got " << type; + } + return ::mlir::success(); +} + +static ::mlir::LogicalResult __mlir_ods_local_type_constraint_Ops2( + ::mlir::Operation *op, ::mlir::Type type, ::llvm::StringRef valueKind, + unsigned valueIndex) { + if (!((((::llvm::isa<::mlir::TensorType>(type))) && ([](::mlir::Type elementType) { return (elementType.isF64()); }(::llvm::cast<::mlir::ShapedType>(type).getElementType()))) || ((::llvm::isa(type))))) { + return op->emitOpError(valueKind) << " #" << valueIndex + << " must be tensor of 64-bit float values or Toy struct type, but got " << type; + } + return ::mlir::success(); +} + +static ::mlir::LogicalResult __mlir_ods_local_type_constraint_Ops3( + ::mlir::Operation *op, ::mlir::Type type, ::llvm::StringRef valueKind, + unsigned valueIndex) { + if (!((((::llvm::isa<::mlir::TensorType>(type))) && ([](::mlir::Type elementType) { return (elementType.isF64()); }(::llvm::cast<::mlir::ShapedType>(type).getElementType()))) || (((::llvm::isa<::mlir::MemRefType>(type))) && ([](::mlir::Type elementType) { return (elementType.isF64()); }(::llvm::cast<::mlir::ShapedType>(type).getElementType()))))) { + return op->emitOpError(valueKind) << " #" << valueIndex + << " must be tensor of 64-bit float values or memref of 64-bit float values, but got " << type; + } + return ::mlir::success(); +} + +static ::mlir::LogicalResult __mlir_ods_local_type_constraint_Ops4( + ::mlir::Operation *op, ::mlir::Type type, ::llvm::StringRef valueKind, + unsigned valueIndex) { + if (!((((::llvm::isa<::mlir::RankedTensorType>(type))) && ((::llvm::cast<::mlir::ShapedType>(type).hasStaticShape()))) && ([](::mlir::Type elementType) { return (elementType.isF64()); }(::llvm::cast<::mlir::ShapedType>(type).getElementType())))) { + return op->emitOpError(valueKind) << " #" << valueIndex + << " must be statically shaped tensor of 64-bit float values, but got " << type; + } + return ::mlir::success(); +} + +static ::mlir::LogicalResult __mlir_ods_local_type_constraint_Ops5( + ::mlir::Operation *op, ::mlir::Type type, ::llvm::StringRef valueKind, + unsigned valueIndex) { + if (!((::llvm::isa(type)))) { + return op->emitOpError(valueKind) << " #" << valueIndex + << " must be Toy struct type, but got " << type; + } + return ::mlir::success(); +} + +static ::mlir::LogicalResult __mlir_ods_local_attr_constraint_Ops0( + ::mlir::Attribute attr, ::llvm::StringRef attrName, llvm::function_ref<::mlir::InFlightDiagnostic()> emitError) { + if (attr && !((::llvm::isa<::mlir::DenseFPElementsAttr>(attr) &&::llvm::cast<::mlir::DenseElementsAttr>(attr).getType().getElementType().isF64()))) + return emitError() << "attribute '" << attrName + << "' failed to satisfy constraint: 64-bit float elements attribute"; + return ::mlir::success(); +} +static ::mlir::LogicalResult __mlir_ods_local_attr_constraint_Ops0( + ::mlir::Operation *op, ::mlir::Attribute attr, ::llvm::StringRef attrName) { + return __mlir_ods_local_attr_constraint_Ops0(attr, attrName, [op]() { + return op->emitOpError(); + }); +} + +static ::mlir::LogicalResult __mlir_ods_local_attr_constraint_Ops1( + ::mlir::Attribute attr, ::llvm::StringRef attrName, llvm::function_ref<::mlir::InFlightDiagnostic()> emitError) { + if (attr && !((::llvm::isa<::mlir::StringAttr>(attr)))) + return emitError() << "attribute '" << attrName + << "' failed to satisfy constraint: string attribute"; + return ::mlir::success(); +} +static ::mlir::LogicalResult __mlir_ods_local_attr_constraint_Ops1( + ::mlir::Operation *op, ::mlir::Attribute attr, ::llvm::StringRef attrName) { + return __mlir_ods_local_attr_constraint_Ops1(attr, attrName, [op]() { + return op->emitOpError(); + }); +} + +static ::mlir::LogicalResult __mlir_ods_local_attr_constraint_Ops2( + ::mlir::Attribute attr, ::llvm::StringRef attrName, llvm::function_ref<::mlir::InFlightDiagnostic()> emitError) { + if (attr && !(((::llvm::isa<::mlir::TypeAttr>(attr))) && ((::llvm::isa<::mlir::FunctionType>(::llvm::cast<::mlir::TypeAttr>(attr).getValue()))) && ((::llvm::isa<::mlir::FunctionType>(::llvm::cast<::mlir::TypeAttr>(attr).getValue()))))) + return emitError() << "attribute '" << attrName + << "' failed to satisfy constraint: type attribute of function type"; + return ::mlir::success(); +} +static ::mlir::LogicalResult __mlir_ods_local_attr_constraint_Ops2( + ::mlir::Operation *op, ::mlir::Attribute attr, ::llvm::StringRef attrName) { + return __mlir_ods_local_attr_constraint_Ops2(attr, attrName, [op]() { + return op->emitOpError(); + }); +} + +static ::mlir::LogicalResult __mlir_ods_local_attr_constraint_Ops3( + ::mlir::Attribute attr, ::llvm::StringRef attrName, llvm::function_ref<::mlir::InFlightDiagnostic()> emitError) { + if (attr && !(((::llvm::isa<::mlir::ArrayAttr>(attr))) && (::llvm::all_of(::llvm::cast<::mlir::ArrayAttr>(attr), [&](::mlir::Attribute attr) { return attr && ((::llvm::isa<::mlir::DictionaryAttr>(attr))); })))) + return emitError() << "attribute '" << attrName + << "' failed to satisfy constraint: Array of dictionary attributes"; + return ::mlir::success(); +} +static ::mlir::LogicalResult __mlir_ods_local_attr_constraint_Ops3( + ::mlir::Operation *op, ::mlir::Attribute attr, ::llvm::StringRef attrName) { + return __mlir_ods_local_attr_constraint_Ops3(attr, attrName, [op]() { + return op->emitOpError(); + }); +} + +static ::mlir::LogicalResult __mlir_ods_local_attr_constraint_Ops4( + ::mlir::Attribute attr, ::llvm::StringRef attrName, llvm::function_ref<::mlir::InFlightDiagnostic()> emitError) { + if (attr && !((::llvm::isa<::mlir::FlatSymbolRefAttr>(attr)))) + return emitError() << "attribute '" << attrName + << "' failed to satisfy constraint: flat symbol reference attribute"; + return ::mlir::success(); +} +static ::mlir::LogicalResult __mlir_ods_local_attr_constraint_Ops4( + ::mlir::Operation *op, ::mlir::Attribute attr, ::llvm::StringRef attrName) { + return __mlir_ods_local_attr_constraint_Ops4(attr, attrName, [op]() { + return op->emitOpError(); + }); +} + +static ::mlir::LogicalResult __mlir_ods_local_attr_constraint_Ops5( + ::mlir::Attribute attr, ::llvm::StringRef attrName, llvm::function_ref<::mlir::InFlightDiagnostic()> emitError) { + if (attr && !(((::llvm::isa<::mlir::IntegerAttr>(attr))) && ((::llvm::cast<::mlir::IntegerAttr>(attr).getType().isSignlessInteger(64))))) + return emitError() << "attribute '" << attrName + << "' failed to satisfy constraint: 64-bit signless integer attribute"; + return ::mlir::success(); +} +static ::mlir::LogicalResult __mlir_ods_local_attr_constraint_Ops5( + ::mlir::Operation *op, ::mlir::Attribute attr, ::llvm::StringRef attrName) { + return __mlir_ods_local_attr_constraint_Ops5(attr, attrName, [op]() { + return op->emitOpError(); + }); +} + +static ::mlir::LogicalResult __mlir_ods_local_attr_constraint_Ops6( + ::mlir::Attribute attr, ::llvm::StringRef attrName, llvm::function_ref<::mlir::InFlightDiagnostic()> emitError) { + if (attr && !((::llvm::isa<::mlir::ArrayAttr>(attr)))) + return emitError() << "attribute '" << attrName + << "' failed to satisfy constraint: array attribute"; + return ::mlir::success(); +} +static ::mlir::LogicalResult __mlir_ods_local_attr_constraint_Ops6( + ::mlir::Operation *op, ::mlir::Attribute attr, ::llvm::StringRef attrName) { + return __mlir_ods_local_attr_constraint_Ops6(attr, attrName, [op]() { + return op->emitOpError(); + }); +} + +static ::mlir::LogicalResult __mlir_ods_local_region_constraint_Ops0( + ::mlir::Operation *op, ::mlir::Region ®ion, ::llvm::StringRef regionName, + unsigned regionIndex) { + if (!((true))) { + return op->emitOpError("region #") << regionIndex + << (regionName.empty() ? " " : " ('" + regionName + "') ") + << "failed to verify constraint: any region"; + } + return ::mlir::success(); +} +} // namespace toy +} // namespace mlir +namespace mlir { +namespace toy { + +//===----------------------------------------------------------------------===// +// ::mlir::toy::AddOp definitions +//===----------------------------------------------------------------------===// + +namespace detail { +AddOpGenericAdaptorBase::AddOpGenericAdaptorBase(::mlir::DictionaryAttr attrs, const ::mlir::EmptyProperties &properties, ::mlir::RegionRange regions) : odsAttrs(attrs), odsRegions(regions) { if (odsAttrs) + odsOpName.emplace("toy.add", odsAttrs.getContext()); +} + +AddOpGenericAdaptorBase::AddOpGenericAdaptorBase(AddOp op) : AddOpGenericAdaptorBase(op->getAttrDictionary(), op.getProperties(), op->getRegions()) {} + +std::pair AddOpGenericAdaptorBase::getODSOperandIndexAndLength(unsigned index, unsigned odsOperandsSize) { + return {index, 1}; +} + +::mlir::DictionaryAttr AddOpGenericAdaptorBase::getAttributes() { + return odsAttrs; +} + +} // namespace detail +AddOpAdaptor::AddOpAdaptor(AddOp op) : AddOpGenericAdaptor(op->getOperands(), op) {} + +::mlir::LogicalResult AddOpAdaptor::verify(::mlir::Location loc) { + return ::mlir::success(); +} + +std::pair AddOp::getODSOperandIndexAndLength(unsigned index) { + return {index, 1}; +} + +::mlir::Operation::operand_range AddOp::getODSOperands(unsigned index) { + auto valueRange = getODSOperandIndexAndLength(index); + return {std::next(getOperation()->operand_begin(), valueRange.first), + std::next(getOperation()->operand_begin(), valueRange.first + valueRange.second)}; +} + +::mlir::TypedValue<::mlir::TensorType> AddOp::getLhs() { + return ::llvm::cast<::mlir::TypedValue<::mlir::TensorType>>(*getODSOperands(0).begin()); +} + +::mlir::TypedValue<::mlir::TensorType> AddOp::getRhs() { + return ::llvm::cast<::mlir::TypedValue<::mlir::TensorType>>(*getODSOperands(1).begin()); +} + +::mlir::OpOperand &AddOp::getLhsMutable() { + auto range = getODSOperandIndexAndLength(0); + return getOperation()->getOpOperand(range.first); +} + +::mlir::OpOperand &AddOp::getRhsMutable() { + auto range = getODSOperandIndexAndLength(1); + return getOperation()->getOpOperand(range.first); +} + +std::pair AddOp::getODSResultIndexAndLength(unsigned index) { + return {index, 1}; +} + +::mlir::Operation::result_range AddOp::getODSResults(unsigned index) { + auto valueRange = getODSResultIndexAndLength(index); + return {std::next(getOperation()->result_begin(), valueRange.first), + std::next(getOperation()->result_begin(), valueRange.first + valueRange.second)}; +} + +void AddOp::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type resultType0, ::mlir::Value lhs, ::mlir::Value rhs) { + odsState.addOperands(lhs); + odsState.addOperands(rhs); + odsState.addTypes(resultType0); +} + +void AddOp::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::Value lhs, ::mlir::Value rhs) { + odsState.addOperands(lhs); + odsState.addOperands(rhs); + assert(resultTypes.size() == 1u && "mismatched number of results"); + odsState.addTypes(resultTypes); +} + +void AddOp::build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes) { + assert(operands.size() == 2u && "mismatched number of parameters"); + odsState.addOperands(operands); + odsState.addAttributes(attributes); + assert(resultTypes.size() == 1u && "mismatched number of return types"); + odsState.addTypes(resultTypes); +} + +::mlir::LogicalResult AddOp::verifyInvariantsImpl() { + { + unsigned index = 0; (void)index; + auto valueGroup0 = getODSOperands(0); + + for (auto v : valueGroup0) { + if (::mlir::failed(__mlir_ods_local_type_constraint_Ops0(*this, v.getType(), "operand", index++))) + return ::mlir::failure(); + } + auto valueGroup1 = getODSOperands(1); + + for (auto v : valueGroup1) { + if (::mlir::failed(__mlir_ods_local_type_constraint_Ops0(*this, v.getType(), "operand", index++))) + return ::mlir::failure(); + } + } + { + unsigned index = 0; (void)index; + auto valueGroup0 = getODSResults(0); + + for (auto v : valueGroup0) { + if (::mlir::failed(__mlir_ods_local_type_constraint_Ops0(*this, v.getType(), "result", index++))) + return ::mlir::failure(); + } + } + return ::mlir::success(); +} + +::mlir::LogicalResult AddOp::verifyInvariants() { + return verifyInvariantsImpl(); +} + +void AddOp::getEffects(::llvm::SmallVectorImpl<::mlir::SideEffects::EffectInstance<::mlir::MemoryEffects::Effect>> &effects) { +} + +} // namespace toy +} // namespace mlir +MLIR_DEFINE_EXPLICIT_TYPE_ID(::mlir::toy::AddOp) + +namespace mlir { +namespace toy { + +//===----------------------------------------------------------------------===// +// ::mlir::toy::CastOp definitions +//===----------------------------------------------------------------------===// + +namespace detail { +CastOpGenericAdaptorBase::CastOpGenericAdaptorBase(::mlir::DictionaryAttr attrs, const ::mlir::EmptyProperties &properties, ::mlir::RegionRange regions) : odsAttrs(attrs), odsRegions(regions) { if (odsAttrs) + odsOpName.emplace("toy.cast", odsAttrs.getContext()); +} + +CastOpGenericAdaptorBase::CastOpGenericAdaptorBase(CastOp op) : CastOpGenericAdaptorBase(op->getAttrDictionary(), op.getProperties(), op->getRegions()) {} + +std::pair CastOpGenericAdaptorBase::getODSOperandIndexAndLength(unsigned index, unsigned odsOperandsSize) { + return {index, 1}; +} + +::mlir::DictionaryAttr CastOpGenericAdaptorBase::getAttributes() { + return odsAttrs; +} + +} // namespace detail +CastOpAdaptor::CastOpAdaptor(CastOp op) : CastOpGenericAdaptor(op->getOperands(), op) {} + +::mlir::LogicalResult CastOpAdaptor::verify(::mlir::Location loc) { + return ::mlir::success(); +} + +std::pair CastOp::getODSOperandIndexAndLength(unsigned index) { + return {index, 1}; +} + +::mlir::Operation::operand_range CastOp::getODSOperands(unsigned index) { + auto valueRange = getODSOperandIndexAndLength(index); + return {std::next(getOperation()->operand_begin(), valueRange.first), + std::next(getOperation()->operand_begin(), valueRange.first + valueRange.second)}; +} + +::mlir::TypedValue<::mlir::TensorType> CastOp::getInput() { + return ::llvm::cast<::mlir::TypedValue<::mlir::TensorType>>(*getODSOperands(0).begin()); +} + +::mlir::OpOperand &CastOp::getInputMutable() { + auto range = getODSOperandIndexAndLength(0); + return getOperation()->getOpOperand(range.first); +} + +std::pair CastOp::getODSResultIndexAndLength(unsigned index) { + return {index, 1}; +} + +::mlir::Operation::result_range CastOp::getODSResults(unsigned index) { + auto valueRange = getODSResultIndexAndLength(index); + return {std::next(getOperation()->result_begin(), valueRange.first), + std::next(getOperation()->result_begin(), valueRange.first + valueRange.second)}; +} + +::mlir::TypedValue<::mlir::TensorType> CastOp::getOutput() { + return ::llvm::cast<::mlir::TypedValue<::mlir::TensorType>>(*getODSResults(0).begin()); +} + +void CastOp::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type output, ::mlir::Value input) { + odsState.addOperands(input); + odsState.addTypes(output); +} + +void CastOp::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::Value input) { + odsState.addOperands(input); + assert(resultTypes.size() == 1u && "mismatched number of results"); + odsState.addTypes(resultTypes); +} + +void CastOp::build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes) { + assert(operands.size() == 1u && "mismatched number of parameters"); + odsState.addOperands(operands); + odsState.addAttributes(attributes); + assert(resultTypes.size() == 1u && "mismatched number of return types"); + odsState.addTypes(resultTypes); +} + +::mlir::LogicalResult CastOp::verifyInvariantsImpl() { + { + unsigned index = 0; (void)index; + auto valueGroup0 = getODSOperands(0); + + for (auto v : valueGroup0) { + if (::mlir::failed(__mlir_ods_local_type_constraint_Ops0(*this, v.getType(), "operand", index++))) + return ::mlir::failure(); + } + } + { + unsigned index = 0; (void)index; + auto valueGroup0 = getODSResults(0); + + for (auto v : valueGroup0) { + if (::mlir::failed(__mlir_ods_local_type_constraint_Ops0(*this, v.getType(), "result", index++))) + return ::mlir::failure(); + } + } + return ::mlir::success(); +} + +::mlir::LogicalResult CastOp::verifyInvariants() { + return verifyInvariantsImpl(); +} + +::mlir::ParseResult CastOp::parse(::mlir::OpAsmParser &parser, ::mlir::OperationState &result) { + ::mlir::OpAsmParser::UnresolvedOperand inputRawOperands[1]; + ::llvm::ArrayRef<::mlir::OpAsmParser::UnresolvedOperand> inputOperands(inputRawOperands); ::llvm::SMLoc inputOperandsLoc; + (void)inputOperandsLoc; + ::mlir::Type inputRawTypes[1]; + ::llvm::ArrayRef<::mlir::Type> inputTypes(inputRawTypes); + ::mlir::Type outputRawTypes[1]; + ::llvm::ArrayRef<::mlir::Type> outputTypes(outputRawTypes); + + inputOperandsLoc = parser.getCurrentLocation(); + if (parser.parseOperand(inputRawOperands[0])) + return ::mlir::failure(); + { + auto loc = parser.getCurrentLocation();(void)loc; + if (parser.parseOptionalAttrDict(result.attributes)) + return ::mlir::failure(); + } + if (parser.parseColon()) + return ::mlir::failure(); + + { + ::mlir::TensorType type; + if (parser.parseCustomTypeWithFallback(type)) + return ::mlir::failure(); + inputRawTypes[0] = type; + } + if (parser.parseKeyword("to")) + return ::mlir::failure(); + + { + ::mlir::TensorType type; + if (parser.parseCustomTypeWithFallback(type)) + return ::mlir::failure(); + outputRawTypes[0] = type; + } + result.addTypes(outputTypes); + if (parser.resolveOperands(inputOperands, inputTypes, inputOperandsLoc, result.operands)) + return ::mlir::failure(); + return ::mlir::success(); +} + +void CastOp::print(::mlir::OpAsmPrinter &_odsPrinter) { + _odsPrinter << ' '; + _odsPrinter << getInput(); + ::llvm::SmallVector<::llvm::StringRef, 2> elidedAttrs; + _odsPrinter.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs); + _odsPrinter << ' ' << ":"; + _odsPrinter << ' '; + { + auto type = getInput().getType(); + if (auto validType = ::llvm::dyn_cast<::mlir::TensorType>(type)) + _odsPrinter.printStrippedAttrOrType(validType); + else + _odsPrinter << type; + } + _odsPrinter << ' ' << "to"; + _odsPrinter << ' '; + { + auto type = getOutput().getType(); + if (auto validType = ::llvm::dyn_cast<::mlir::TensorType>(type)) + _odsPrinter.printStrippedAttrOrType(validType); + else + _odsPrinter << type; + } +} + +void CastOp::getEffects(::llvm::SmallVectorImpl<::mlir::SideEffects::EffectInstance<::mlir::MemoryEffects::Effect>> &effects) { +} + +} // namespace toy +} // namespace mlir +MLIR_DEFINE_EXPLICIT_TYPE_ID(::mlir::toy::CastOp) + +namespace mlir { +namespace toy { + +//===----------------------------------------------------------------------===// +// ::mlir::toy::ConstantOp definitions +//===----------------------------------------------------------------------===// + +namespace detail { +ConstantOpGenericAdaptorBase::ConstantOpGenericAdaptorBase(::mlir::DictionaryAttr attrs, const Properties &properties, ::mlir::RegionRange regions) : odsAttrs(attrs), properties(properties), odsRegions(regions) { if (odsAttrs) + odsOpName.emplace("toy.constant", odsAttrs.getContext()); +} + +ConstantOpGenericAdaptorBase::ConstantOpGenericAdaptorBase(ConstantOp op) : ConstantOpGenericAdaptorBase(op->getDiscardableAttrDictionary(), op.getProperties(), op->getRegions()) {} + +std::pair ConstantOpGenericAdaptorBase::getODSOperandIndexAndLength(unsigned index, unsigned odsOperandsSize) { + return {index, 1}; +} + +::mlir::DictionaryAttr ConstantOpGenericAdaptorBase::getAttributes() { + return odsAttrs; +} + +::mlir::DenseElementsAttr ConstantOpGenericAdaptorBase::getValueAttr() { + auto attr = ::llvm::cast<::mlir::DenseElementsAttr>(getProperties().value); + return attr; +} + +::mlir::DenseElementsAttr ConstantOpGenericAdaptorBase::getValue() { + auto attr = getValueAttr(); + return attr; +} + +} // namespace detail +ConstantOpAdaptor::ConstantOpAdaptor(ConstantOp op) : ConstantOpGenericAdaptor(op->getOperands(), op) {} + +::mlir::LogicalResult ConstantOpAdaptor::verify(::mlir::Location loc) { + auto tblgen_value = getProperties().value; (void)tblgen_value; + if (!tblgen_value) return emitError(loc, "'toy.constant' op ""requires attribute 'value'"); + + if (tblgen_value && !((::llvm::isa<::mlir::DenseFPElementsAttr>(tblgen_value) &&::llvm::cast<::mlir::DenseElementsAttr>(tblgen_value).getType().getElementType().isF64()))) + return emitError(loc, "'toy.constant' op ""attribute 'value' failed to satisfy constraint: 64-bit float elements attribute"); + return ::mlir::success(); +} + +std::pair ConstantOp::getODSOperandIndexAndLength(unsigned index) { + return {index, 1}; +} + +::mlir::Operation::operand_range ConstantOp::getODSOperands(unsigned index) { + auto valueRange = getODSOperandIndexAndLength(index); + return {std::next(getOperation()->operand_begin(), valueRange.first), + std::next(getOperation()->operand_begin(), valueRange.first + valueRange.second)}; +} + +std::pair ConstantOp::getODSResultIndexAndLength(unsigned index) { + return {index, 1}; +} + +::mlir::Operation::result_range ConstantOp::getODSResults(unsigned index) { + auto valueRange = getODSResultIndexAndLength(index); + return {std::next(getOperation()->result_begin(), valueRange.first), + std::next(getOperation()->result_begin(), valueRange.first + valueRange.second)}; +} + +::mlir::LogicalResult ConstantOp::setPropertiesFromAttr(Properties &prop, ::mlir::Attribute attr, ::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError) { + ::mlir::DictionaryAttr dict = ::llvm::dyn_cast<::mlir::DictionaryAttr>(attr); + if (!dict) { + emitError() << "expected DictionaryAttr to set properties"; + return ::mlir::failure(); + } + + { + auto &propStorage = prop.value; + auto attr = dict.get("value"); + if (attr || /*isRequired=*/true) { + if (!attr) { + emitError() << "expected key entry for value in DictionaryAttr to set " + "Properties."; + return ::mlir::failure(); + } + auto convertedAttr = ::llvm::dyn_cast>(attr); + if (convertedAttr) { + propStorage = convertedAttr; + } else { + emitError() << "Invalid attribute `value` in property conversion: " << attr; + return ::mlir::failure(); + } + } + } + return ::mlir::success(); +} + +::mlir::Attribute ConstantOp::getPropertiesAsAttr(::mlir::MLIRContext *ctx, const Properties &prop) { + ::mlir::SmallVector<::mlir::NamedAttribute> attrs; + ::mlir::Builder odsBuilder{ctx}; + + { + const auto &propStorage = prop.value; + if (propStorage) + attrs.push_back(odsBuilder.getNamedAttr("value", + propStorage)); + } + + if (!attrs.empty()) + return odsBuilder.getDictionaryAttr(attrs); + return {}; +} + +llvm::hash_code ConstantOp::computePropertiesHash(const Properties &prop) { + return llvm::hash_combine( + llvm::hash_value(prop.value.getAsOpaquePointer())); +} + +std::optional ConstantOp::getInherentAttr(::mlir::MLIRContext *ctx, const Properties &prop, llvm::StringRef name) { + if (name == "value") + return prop.value; + return std::nullopt; +} + +void ConstantOp::setInherentAttr(Properties &prop, llvm::StringRef name, mlir::Attribute value) { + if (name == "value") { + prop.value = ::llvm::dyn_cast_or_null>(value); + return; + } +} + +void ConstantOp::populateInherentAttrs(::mlir::MLIRContext *ctx, const Properties &prop, ::mlir::NamedAttrList &attrs) { + if (prop.value) attrs.append("value", prop.value); +} + +::mlir::LogicalResult ConstantOp::verifyInherentAttrs(::mlir::OperationName opName, ::mlir::NamedAttrList &attrs, llvm::function_ref<::mlir::InFlightDiagnostic()> emitError) { + { + ::mlir::Attribute attr = attrs.get(getValueAttrName(opName)); + if (attr && ::mlir::failed(__mlir_ods_local_attr_constraint_Ops0(attr, "value", emitError))) + return ::mlir::failure(); + } + return ::mlir::success(); +} + +::mlir::LogicalResult ConstantOp::readProperties(::mlir::DialectBytecodeReader &reader, ::mlir::OperationState &state) { + auto &prop = state.getOrAddProperties(); (void)prop; + if (::mlir::failed(reader.readAttribute(prop.value))) + return ::mlir::failure(); + return ::mlir::success(); +} + +void ConstantOp::writeProperties(::mlir::DialectBytecodeWriter &writer) { + auto &prop = getProperties(); (void)prop; + writer.writeAttribute(prop.value); +} + +::mlir::DenseElementsAttr ConstantOp::getValueAttr() { + return ::llvm::cast<::mlir::DenseElementsAttr>(getProperties().value); +} + +::mlir::DenseElementsAttr ConstantOp::getValue() { + auto attr = getValueAttr(); + return attr; +} + +void ConstantOp::setValueAttr(::mlir::DenseElementsAttr attr) { + (*this)->setAttr(getValueAttrName(), attr); +} + +void ConstantOp::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, DenseElementsAttr value) { + build(odsBuilder, odsState, value.getType(), value); + +} + +void ConstantOp::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type resultType0, ::mlir::DenseElementsAttr value) { + odsState.getOrAddProperties().value = value; + odsState.addTypes(resultType0); +} + +void ConstantOp::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::DenseElementsAttr value) { + odsState.getOrAddProperties().value = value; + assert(resultTypes.size() == 1u && "mismatched number of results"); + odsState.addTypes(resultTypes); +} + +void ConstantOp::build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes) { + assert(operands.size() == 0u && "mismatched number of parameters"); + odsState.addOperands(operands); + odsState.addAttributes(attributes); + assert(resultTypes.size() == 1u && "mismatched number of return types"); + odsState.addTypes(resultTypes); +} + +::mlir::LogicalResult ConstantOp::verifyInvariantsImpl() { + auto tblgen_value = getProperties().value; (void)tblgen_value; + if (!tblgen_value) return emitOpError("requires attribute 'value'"); + + if (::mlir::failed(__mlir_ods_local_attr_constraint_Ops0(*this, tblgen_value, "value"))) + return ::mlir::failure(); + { + unsigned index = 0; (void)index; + auto valueGroup0 = getODSResults(0); + + for (auto v : valueGroup0) { + if (::mlir::failed(__mlir_ods_local_type_constraint_Ops0(*this, v.getType(), "result", index++))) + return ::mlir::failure(); + } + } + return ::mlir::success(); +} + +::mlir::LogicalResult ConstantOp::verifyInvariants() { + if(::mlir::succeeded(verifyInvariantsImpl()) && ::mlir::succeeded(verify())) + return ::mlir::success(); + return ::mlir::failure(); +} + +void ConstantOp::getEffects(::llvm::SmallVectorImpl<::mlir::SideEffects::EffectInstance<::mlir::MemoryEffects::Effect>> &effects) { +} + +} // namespace toy +} // namespace mlir +MLIR_DEFINE_EXPLICIT_TYPE_ID(::mlir::toy::ConstantOp) + +namespace mlir { +namespace toy { + +//===----------------------------------------------------------------------===// +// ::mlir::toy::FuncOp definitions +//===----------------------------------------------------------------------===// + +namespace detail { +FuncOpGenericAdaptorBase::FuncOpGenericAdaptorBase(::mlir::DictionaryAttr attrs, const Properties &properties, ::mlir::RegionRange regions) : odsAttrs(attrs), properties(properties), odsRegions(regions) { if (odsAttrs) + odsOpName.emplace("toy.func", odsAttrs.getContext()); +} + +FuncOpGenericAdaptorBase::FuncOpGenericAdaptorBase(FuncOp op) : FuncOpGenericAdaptorBase(op->getDiscardableAttrDictionary(), op.getProperties(), op->getRegions()) {} + +std::pair FuncOpGenericAdaptorBase::getODSOperandIndexAndLength(unsigned index, unsigned odsOperandsSize) { + return {index, 1}; +} + +::mlir::DictionaryAttr FuncOpGenericAdaptorBase::getAttributes() { + return odsAttrs; +} + +::mlir::StringAttr FuncOpGenericAdaptorBase::getSymNameAttr() { + auto attr = ::llvm::cast<::mlir::StringAttr>(getProperties().sym_name); + return attr; +} + +::llvm::StringRef FuncOpGenericAdaptorBase::getSymName() { + auto attr = getSymNameAttr(); + return attr.getValue(); +} + +::mlir::TypeAttr FuncOpGenericAdaptorBase::getFunctionTypeAttr() { + auto attr = ::llvm::cast<::mlir::TypeAttr>(getProperties().function_type); + return attr; +} + +::mlir::FunctionType FuncOpGenericAdaptorBase::getFunctionType() { + auto attr = getFunctionTypeAttr(); + return ::llvm::cast<::mlir::FunctionType>(attr.getValue()); +} + +::mlir::ArrayAttr FuncOpGenericAdaptorBase::getArgAttrsAttr() { + auto attr = ::llvm::dyn_cast_or_null<::mlir::ArrayAttr>(getProperties().arg_attrs); + return attr; +} + +::std::optional< ::mlir::ArrayAttr > FuncOpGenericAdaptorBase::getArgAttrs() { + auto attr = getArgAttrsAttr(); + return attr ? ::std::optional< ::mlir::ArrayAttr >(attr) : (::std::nullopt); +} + +::mlir::ArrayAttr FuncOpGenericAdaptorBase::getResAttrsAttr() { + auto attr = ::llvm::dyn_cast_or_null<::mlir::ArrayAttr>(getProperties().res_attrs); + return attr; +} + +::std::optional< ::mlir::ArrayAttr > FuncOpGenericAdaptorBase::getResAttrs() { + auto attr = getResAttrsAttr(); + return attr ? ::std::optional< ::mlir::ArrayAttr >(attr) : (::std::nullopt); +} + +::mlir::Region &FuncOpGenericAdaptorBase::getBody() { + return *odsRegions[0]; +} + +::mlir::RegionRange FuncOpGenericAdaptorBase::getRegions() { + return odsRegions; +} + +} // namespace detail +FuncOpAdaptor::FuncOpAdaptor(FuncOp op) : FuncOpGenericAdaptor(op->getOperands(), op) {} + +::mlir::LogicalResult FuncOpAdaptor::verify(::mlir::Location loc) { + auto tblgen_arg_attrs = getProperties().arg_attrs; (void)tblgen_arg_attrs; + auto tblgen_function_type = getProperties().function_type; (void)tblgen_function_type; + if (!tblgen_function_type) return emitError(loc, "'toy.func' op ""requires attribute 'function_type'"); + auto tblgen_res_attrs = getProperties().res_attrs; (void)tblgen_res_attrs; + auto tblgen_sym_name = getProperties().sym_name; (void)tblgen_sym_name; + if (!tblgen_sym_name) return emitError(loc, "'toy.func' op ""requires attribute 'sym_name'"); + + if (tblgen_sym_name && !((::llvm::isa<::mlir::StringAttr>(tblgen_sym_name)))) + return emitError(loc, "'toy.func' op ""attribute 'sym_name' failed to satisfy constraint: string attribute"); + + if (tblgen_function_type && !(((::llvm::isa<::mlir::TypeAttr>(tblgen_function_type))) && ((::llvm::isa<::mlir::FunctionType>(::llvm::cast<::mlir::TypeAttr>(tblgen_function_type).getValue()))) && ((::llvm::isa<::mlir::FunctionType>(::llvm::cast<::mlir::TypeAttr>(tblgen_function_type).getValue()))))) + return emitError(loc, "'toy.func' op ""attribute 'function_type' failed to satisfy constraint: type attribute of function type"); + + if (tblgen_arg_attrs && !(((::llvm::isa<::mlir::ArrayAttr>(tblgen_arg_attrs))) && (::llvm::all_of(::llvm::cast<::mlir::ArrayAttr>(tblgen_arg_attrs), [&](::mlir::Attribute attr) { return attr && ((::llvm::isa<::mlir::DictionaryAttr>(attr))); })))) + return emitError(loc, "'toy.func' op ""attribute 'arg_attrs' failed to satisfy constraint: Array of dictionary attributes"); + + if (tblgen_res_attrs && !(((::llvm::isa<::mlir::ArrayAttr>(tblgen_res_attrs))) && (::llvm::all_of(::llvm::cast<::mlir::ArrayAttr>(tblgen_res_attrs), [&](::mlir::Attribute attr) { return attr && ((::llvm::isa<::mlir::DictionaryAttr>(attr))); })))) + return emitError(loc, "'toy.func' op ""attribute 'res_attrs' failed to satisfy constraint: Array of dictionary attributes"); + return ::mlir::success(); +} + +std::pair FuncOp::getODSOperandIndexAndLength(unsigned index) { + return {index, 1}; +} + +::mlir::Operation::operand_range FuncOp::getODSOperands(unsigned index) { + auto valueRange = getODSOperandIndexAndLength(index); + return {std::next(getOperation()->operand_begin(), valueRange.first), + std::next(getOperation()->operand_begin(), valueRange.first + valueRange.second)}; +} + +std::pair FuncOp::getODSResultIndexAndLength(unsigned index) { + return {index, 1}; +} + +::mlir::Operation::result_range FuncOp::getODSResults(unsigned index) { + auto valueRange = getODSResultIndexAndLength(index); + return {std::next(getOperation()->result_begin(), valueRange.first), + std::next(getOperation()->result_begin(), valueRange.first + valueRange.second)}; +} + +::mlir::Region &FuncOp::getBody() { + return (*this)->getRegion(0); +} + +::mlir::LogicalResult FuncOp::setPropertiesFromAttr(Properties &prop, ::mlir::Attribute attr, ::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError) { + ::mlir::DictionaryAttr dict = ::llvm::dyn_cast<::mlir::DictionaryAttr>(attr); + if (!dict) { + emitError() << "expected DictionaryAttr to set properties"; + return ::mlir::failure(); + } + + { + auto &propStorage = prop.arg_attrs; + auto attr = dict.get("arg_attrs"); + if (attr || /*isRequired=*/false) { + if (!attr) { + emitError() << "expected key entry for arg_attrs in DictionaryAttr to set " + "Properties."; + return ::mlir::failure(); + } + auto convertedAttr = ::llvm::dyn_cast>(attr); + if (convertedAttr) { + propStorage = convertedAttr; + } else { + emitError() << "Invalid attribute `arg_attrs` in property conversion: " << attr; + return ::mlir::failure(); + } + } + } + + { + auto &propStorage = prop.function_type; + auto attr = dict.get("function_type"); + if (attr || /*isRequired=*/true) { + if (!attr) { + emitError() << "expected key entry for function_type in DictionaryAttr to set " + "Properties."; + return ::mlir::failure(); + } + auto convertedAttr = ::llvm::dyn_cast>(attr); + if (convertedAttr) { + propStorage = convertedAttr; + } else { + emitError() << "Invalid attribute `function_type` in property conversion: " << attr; + return ::mlir::failure(); + } + } + } + + { + auto &propStorage = prop.res_attrs; + auto attr = dict.get("res_attrs"); + if (attr || /*isRequired=*/false) { + if (!attr) { + emitError() << "expected key entry for res_attrs in DictionaryAttr to set " + "Properties."; + return ::mlir::failure(); + } + auto convertedAttr = ::llvm::dyn_cast>(attr); + if (convertedAttr) { + propStorage = convertedAttr; + } else { + emitError() << "Invalid attribute `res_attrs` in property conversion: " << attr; + return ::mlir::failure(); + } + } + } + + { + auto &propStorage = prop.sym_name; + auto attr = dict.get("sym_name"); + if (attr || /*isRequired=*/true) { + if (!attr) { + emitError() << "expected key entry for sym_name in DictionaryAttr to set " + "Properties."; + return ::mlir::failure(); + } + auto convertedAttr = ::llvm::dyn_cast>(attr); + if (convertedAttr) { + propStorage = convertedAttr; + } else { + emitError() << "Invalid attribute `sym_name` in property conversion: " << attr; + return ::mlir::failure(); + } + } + } + return ::mlir::success(); +} + +::mlir::Attribute FuncOp::getPropertiesAsAttr(::mlir::MLIRContext *ctx, const Properties &prop) { + ::mlir::SmallVector<::mlir::NamedAttribute> attrs; + ::mlir::Builder odsBuilder{ctx}; + + { + const auto &propStorage = prop.arg_attrs; + if (propStorage) + attrs.push_back(odsBuilder.getNamedAttr("arg_attrs", + propStorage)); + } + + { + const auto &propStorage = prop.function_type; + if (propStorage) + attrs.push_back(odsBuilder.getNamedAttr("function_type", + propStorage)); + } + + { + const auto &propStorage = prop.res_attrs; + if (propStorage) + attrs.push_back(odsBuilder.getNamedAttr("res_attrs", + propStorage)); + } + + { + const auto &propStorage = prop.sym_name; + if (propStorage) + attrs.push_back(odsBuilder.getNamedAttr("sym_name", + propStorage)); + } + + if (!attrs.empty()) + return odsBuilder.getDictionaryAttr(attrs); + return {}; +} + +llvm::hash_code FuncOp::computePropertiesHash(const Properties &prop) { + return llvm::hash_combine( + llvm::hash_value(prop.arg_attrs.getAsOpaquePointer()), + llvm::hash_value(prop.function_type.getAsOpaquePointer()), + llvm::hash_value(prop.res_attrs.getAsOpaquePointer()), + llvm::hash_value(prop.sym_name.getAsOpaquePointer())); +} + +std::optional FuncOp::getInherentAttr(::mlir::MLIRContext *ctx, const Properties &prop, llvm::StringRef name) { + if (name == "arg_attrs") + return prop.arg_attrs; + + if (name == "function_type") + return prop.function_type; + + if (name == "res_attrs") + return prop.res_attrs; + + if (name == "sym_name") + return prop.sym_name; + return std::nullopt; +} + +void FuncOp::setInherentAttr(Properties &prop, llvm::StringRef name, mlir::Attribute value) { + if (name == "arg_attrs") { + prop.arg_attrs = ::llvm::dyn_cast_or_null>(value); + return; + } + + if (name == "function_type") { + prop.function_type = ::llvm::dyn_cast_or_null>(value); + return; + } + + if (name == "res_attrs") { + prop.res_attrs = ::llvm::dyn_cast_or_null>(value); + return; + } + + if (name == "sym_name") { + prop.sym_name = ::llvm::dyn_cast_or_null>(value); + return; + } +} + +void FuncOp::populateInherentAttrs(::mlir::MLIRContext *ctx, const Properties &prop, ::mlir::NamedAttrList &attrs) { + if (prop.arg_attrs) attrs.append("arg_attrs", prop.arg_attrs); + + if (prop.function_type) attrs.append("function_type", prop.function_type); + + if (prop.res_attrs) attrs.append("res_attrs", prop.res_attrs); + + if (prop.sym_name) attrs.append("sym_name", prop.sym_name); +} + +::mlir::LogicalResult FuncOp::verifyInherentAttrs(::mlir::OperationName opName, ::mlir::NamedAttrList &attrs, llvm::function_ref<::mlir::InFlightDiagnostic()> emitError) { + { + ::mlir::Attribute attr = attrs.get(getArgAttrsAttrName(opName)); + if (attr && ::mlir::failed(__mlir_ods_local_attr_constraint_Ops3(attr, "arg_attrs", emitError))) + return ::mlir::failure(); + } + + { + ::mlir::Attribute attr = attrs.get(getFunctionTypeAttrName(opName)); + if (attr && ::mlir::failed(__mlir_ods_local_attr_constraint_Ops2(attr, "function_type", emitError))) + return ::mlir::failure(); + } + + { + ::mlir::Attribute attr = attrs.get(getResAttrsAttrName(opName)); + if (attr && ::mlir::failed(__mlir_ods_local_attr_constraint_Ops3(attr, "res_attrs", emitError))) + return ::mlir::failure(); + } + + { + ::mlir::Attribute attr = attrs.get(getSymNameAttrName(opName)); + if (attr && ::mlir::failed(__mlir_ods_local_attr_constraint_Ops1(attr, "sym_name", emitError))) + return ::mlir::failure(); + } + return ::mlir::success(); +} + +::mlir::LogicalResult FuncOp::readProperties(::mlir::DialectBytecodeReader &reader, ::mlir::OperationState &state) { + auto &prop = state.getOrAddProperties(); (void)prop; + if (::mlir::failed(reader.readOptionalAttribute(prop.arg_attrs))) + return ::mlir::failure(); + + if (::mlir::failed(reader.readAttribute(prop.function_type))) + return ::mlir::failure(); + + if (::mlir::failed(reader.readOptionalAttribute(prop.res_attrs))) + return ::mlir::failure(); + + if (::mlir::failed(reader.readAttribute(prop.sym_name))) + return ::mlir::failure(); + return ::mlir::success(); +} + +void FuncOp::writeProperties(::mlir::DialectBytecodeWriter &writer) { + auto &prop = getProperties(); (void)prop; + + writer.writeOptionalAttribute(prop.arg_attrs); + writer.writeAttribute(prop.function_type); + + writer.writeOptionalAttribute(prop.res_attrs); + writer.writeAttribute(prop.sym_name); +} + +::mlir::StringAttr FuncOp::getSymNameAttr() { + return ::llvm::cast<::mlir::StringAttr>(getProperties().sym_name); +} + +::llvm::StringRef FuncOp::getSymName() { + auto attr = getSymNameAttr(); + return attr.getValue(); +} + +::mlir::TypeAttr FuncOp::getFunctionTypeAttr() { + return ::llvm::cast<::mlir::TypeAttr>(getProperties().function_type); +} + +::mlir::FunctionType FuncOp::getFunctionType() { + auto attr = getFunctionTypeAttr(); + return ::llvm::cast<::mlir::FunctionType>(attr.getValue()); +} + +::mlir::ArrayAttr FuncOp::getArgAttrsAttr() { + return ::llvm::dyn_cast_or_null<::mlir::ArrayAttr>(getProperties().arg_attrs); +} + +::std::optional< ::mlir::ArrayAttr > FuncOp::getArgAttrs() { + auto attr = getArgAttrsAttr(); + return attr ? ::std::optional< ::mlir::ArrayAttr >(attr) : (::std::nullopt); +} + +::mlir::ArrayAttr FuncOp::getResAttrsAttr() { + return ::llvm::dyn_cast_or_null<::mlir::ArrayAttr>(getProperties().res_attrs); +} + +::std::optional< ::mlir::ArrayAttr > FuncOp::getResAttrs() { + auto attr = getResAttrsAttr(); + return attr ? ::std::optional< ::mlir::ArrayAttr >(attr) : (::std::nullopt); +} + +void FuncOp::setSymNameAttr(::mlir::StringAttr attr) { + (*this)->setAttr(getSymNameAttrName(), attr); +} + +void FuncOp::setSymName(::llvm::StringRef attrValue) { + (*this)->setAttr(getSymNameAttrName(), ::mlir::Builder((*this)->getContext()).getStringAttr(attrValue)); +} + +void FuncOp::setFunctionTypeAttr(::mlir::TypeAttr attr) { + (*this)->setAttr(getFunctionTypeAttrName(), attr); +} + +void FuncOp::setFunctionType(::mlir::FunctionType attrValue) { + (*this)->setAttr(getFunctionTypeAttrName(), ::mlir::TypeAttr::get(attrValue)); +} + +void FuncOp::setArgAttrsAttr(::mlir::ArrayAttr attr) { + (*this)->setAttr(getArgAttrsAttrName(), attr); +} + +void FuncOp::setResAttrsAttr(::mlir::ArrayAttr attr) { + (*this)->setAttr(getResAttrsAttrName(), attr); +} + +::mlir::Attribute FuncOp::removeArgAttrsAttr() { + auto &attr = getProperties().arg_attrs; + attr = {}; + return attr; +} + +::mlir::Attribute FuncOp::removeResAttrsAttr() { + auto &attr = getProperties().res_attrs; + attr = {}; + return attr; +} + +::mlir::LogicalResult FuncOp::verifyInvariantsImpl() { + auto tblgen_arg_attrs = getProperties().arg_attrs; (void)tblgen_arg_attrs; + auto tblgen_function_type = getProperties().function_type; (void)tblgen_function_type; + if (!tblgen_function_type) return emitOpError("requires attribute 'function_type'"); + auto tblgen_res_attrs = getProperties().res_attrs; (void)tblgen_res_attrs; + auto tblgen_sym_name = getProperties().sym_name; (void)tblgen_sym_name; + if (!tblgen_sym_name) return emitOpError("requires attribute 'sym_name'"); + + if (::mlir::failed(__mlir_ods_local_attr_constraint_Ops1(*this, tblgen_sym_name, "sym_name"))) + return ::mlir::failure(); + + if (::mlir::failed(__mlir_ods_local_attr_constraint_Ops2(*this, tblgen_function_type, "function_type"))) + return ::mlir::failure(); + + if (::mlir::failed(__mlir_ods_local_attr_constraint_Ops3(*this, tblgen_arg_attrs, "arg_attrs"))) + return ::mlir::failure(); + + if (::mlir::failed(__mlir_ods_local_attr_constraint_Ops3(*this, tblgen_res_attrs, "res_attrs"))) + return ::mlir::failure(); + { + unsigned index = 0; (void)index; + + for (auto ®ion : ::llvm::MutableArrayRef((*this)->getRegion(0))) + if (::mlir::failed(__mlir_ods_local_region_constraint_Ops0(*this, region, "body", index++))) + return ::mlir::failure(); + } + return ::mlir::success(); +} + +::mlir::LogicalResult FuncOp::verifyInvariants() { + return verifyInvariantsImpl(); +} + +} // namespace toy +} // namespace mlir +MLIR_DEFINE_EXPLICIT_TYPE_ID(::mlir::toy::FuncOp) + +namespace mlir { +namespace toy { + +//===----------------------------------------------------------------------===// +// ::mlir::toy::GenericCallOp definitions +//===----------------------------------------------------------------------===// + +namespace detail { +GenericCallOpGenericAdaptorBase::GenericCallOpGenericAdaptorBase(::mlir::DictionaryAttr attrs, const Properties &properties, ::mlir::RegionRange regions) : odsAttrs(attrs), properties(properties), odsRegions(regions) { if (odsAttrs) + odsOpName.emplace("toy.generic_call", odsAttrs.getContext()); +} + +GenericCallOpGenericAdaptorBase::GenericCallOpGenericAdaptorBase(GenericCallOp op) : GenericCallOpGenericAdaptorBase(op->getDiscardableAttrDictionary(), op.getProperties(), op->getRegions()) {} + +std::pair GenericCallOpGenericAdaptorBase::getODSOperandIndexAndLength(unsigned index, unsigned odsOperandsSize) { + bool isVariadic[] = {true}; + int prevVariadicCount = 0; + for (unsigned i = 0; i < index; ++i) + if (isVariadic[i]) ++prevVariadicCount; + + // Calculate how many dynamic values a static variadic operand corresponds to. + // This assumes all static variadic operands have the same dynamic value count. + int variadicSize = (odsOperandsSize - 0) / 1; + // `index` passed in as the parameter is the static index which counts each + // operand (variadic or not) as size 1. So here for each previous static variadic + // operand, we need to offset by (variadicSize - 1) to get where the dynamic + // value pack for this static operand starts. + int start = index + (variadicSize - 1) * prevVariadicCount; + int size = isVariadic[index] ? variadicSize : 1; + return {start, size}; +} + +::mlir::DictionaryAttr GenericCallOpGenericAdaptorBase::getAttributes() { + return odsAttrs; +} + +::mlir::FlatSymbolRefAttr GenericCallOpGenericAdaptorBase::getCalleeAttr() { + auto attr = ::llvm::cast<::mlir::FlatSymbolRefAttr>(getProperties().callee); + return attr; +} + +::llvm::StringRef GenericCallOpGenericAdaptorBase::getCallee() { + auto attr = getCalleeAttr(); + return attr.getValue(); +} + +} // namespace detail +GenericCallOpAdaptor::GenericCallOpAdaptor(GenericCallOp op) : GenericCallOpGenericAdaptor(op->getOperands(), op) {} + +::mlir::LogicalResult GenericCallOpAdaptor::verify(::mlir::Location loc) { + auto tblgen_callee = getProperties().callee; (void)tblgen_callee; + if (!tblgen_callee) return emitError(loc, "'toy.generic_call' op ""requires attribute 'callee'"); + + if (tblgen_callee && !((::llvm::isa<::mlir::FlatSymbolRefAttr>(tblgen_callee)))) + return emitError(loc, "'toy.generic_call' op ""attribute 'callee' failed to satisfy constraint: flat symbol reference attribute"); + return ::mlir::success(); +} + +std::pair GenericCallOp::getODSOperandIndexAndLength(unsigned index) { + bool isVariadic[] = {true}; + int prevVariadicCount = 0; + for (unsigned i = 0; i < index; ++i) + if (isVariadic[i]) ++prevVariadicCount; + + // Calculate how many dynamic values a static variadic operand corresponds to. + // This assumes all static variadic operands have the same dynamic value count. + int variadicSize = (getOperation()->getNumOperands() - 0) / 1; + // `index` passed in as the parameter is the static index which counts each + // operand (variadic or not) as size 1. So here for each previous static variadic + // operand, we need to offset by (variadicSize - 1) to get where the dynamic + // value pack for this static operand starts. + int start = index + (variadicSize - 1) * prevVariadicCount; + int size = isVariadic[index] ? variadicSize : 1; + return {start, size}; +} + +::mlir::Operation::operand_range GenericCallOp::getODSOperands(unsigned index) { + auto valueRange = getODSOperandIndexAndLength(index); + return {std::next(getOperation()->operand_begin(), valueRange.first), + std::next(getOperation()->operand_begin(), valueRange.first + valueRange.second)}; +} + +::mlir::Operation::operand_range GenericCallOp::getInputs() { + return getODSOperands(0); +} + +::mlir::MutableOperandRange GenericCallOp::getInputsMutable() { + auto range = getODSOperandIndexAndLength(0); + auto mutableRange = ::mlir::MutableOperandRange(getOperation(), range.first, range.second); + return mutableRange; +} + +std::pair GenericCallOp::getODSResultIndexAndLength(unsigned index) { + return {index, 1}; +} + +::mlir::Operation::result_range GenericCallOp::getODSResults(unsigned index) { + auto valueRange = getODSResultIndexAndLength(index); + return {std::next(getOperation()->result_begin(), valueRange.first), + std::next(getOperation()->result_begin(), valueRange.first + valueRange.second)}; +} + +::mlir::LogicalResult GenericCallOp::setPropertiesFromAttr(Properties &prop, ::mlir::Attribute attr, ::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError) { + ::mlir::DictionaryAttr dict = ::llvm::dyn_cast<::mlir::DictionaryAttr>(attr); + if (!dict) { + emitError() << "expected DictionaryAttr to set properties"; + return ::mlir::failure(); + } + + { + auto &propStorage = prop.callee; + auto attr = dict.get("callee"); + if (attr || /*isRequired=*/true) { + if (!attr) { + emitError() << "expected key entry for callee in DictionaryAttr to set " + "Properties."; + return ::mlir::failure(); + } + auto convertedAttr = ::llvm::dyn_cast>(attr); + if (convertedAttr) { + propStorage = convertedAttr; + } else { + emitError() << "Invalid attribute `callee` in property conversion: " << attr; + return ::mlir::failure(); + } + } + } + return ::mlir::success(); +} + +::mlir::Attribute GenericCallOp::getPropertiesAsAttr(::mlir::MLIRContext *ctx, const Properties &prop) { + ::mlir::SmallVector<::mlir::NamedAttribute> attrs; + ::mlir::Builder odsBuilder{ctx}; + + { + const auto &propStorage = prop.callee; + if (propStorage) + attrs.push_back(odsBuilder.getNamedAttr("callee", + propStorage)); + } + + if (!attrs.empty()) + return odsBuilder.getDictionaryAttr(attrs); + return {}; +} + +llvm::hash_code GenericCallOp::computePropertiesHash(const Properties &prop) { + return llvm::hash_combine( + llvm::hash_value(prop.callee.getAsOpaquePointer())); +} + +std::optional GenericCallOp::getInherentAttr(::mlir::MLIRContext *ctx, const Properties &prop, llvm::StringRef name) { + if (name == "callee") + return prop.callee; + return std::nullopt; +} + +void GenericCallOp::setInherentAttr(Properties &prop, llvm::StringRef name, mlir::Attribute value) { + if (name == "callee") { + prop.callee = ::llvm::dyn_cast_or_null>(value); + return; + } +} + +void GenericCallOp::populateInherentAttrs(::mlir::MLIRContext *ctx, const Properties &prop, ::mlir::NamedAttrList &attrs) { + if (prop.callee) attrs.append("callee", prop.callee); +} + +::mlir::LogicalResult GenericCallOp::verifyInherentAttrs(::mlir::OperationName opName, ::mlir::NamedAttrList &attrs, llvm::function_ref<::mlir::InFlightDiagnostic()> emitError) { + { + ::mlir::Attribute attr = attrs.get(getCalleeAttrName(opName)); + if (attr && ::mlir::failed(__mlir_ods_local_attr_constraint_Ops4(attr, "callee", emitError))) + return ::mlir::failure(); + } + return ::mlir::success(); +} + +::mlir::LogicalResult GenericCallOp::readProperties(::mlir::DialectBytecodeReader &reader, ::mlir::OperationState &state) { + auto &prop = state.getOrAddProperties(); (void)prop; + if (::mlir::failed(reader.readAttribute(prop.callee))) + return ::mlir::failure(); + return ::mlir::success(); +} + +void GenericCallOp::writeProperties(::mlir::DialectBytecodeWriter &writer) { + auto &prop = getProperties(); (void)prop; + writer.writeAttribute(prop.callee); +} + +::mlir::FlatSymbolRefAttr GenericCallOp::getCalleeAttr() { + return ::llvm::cast<::mlir::FlatSymbolRefAttr>(getProperties().callee); +} + +::llvm::StringRef GenericCallOp::getCallee() { + auto attr = getCalleeAttr(); + return attr.getValue(); +} + +void GenericCallOp::setCalleeAttr(::mlir::FlatSymbolRefAttr attr) { + (*this)->setAttr(getCalleeAttrName(), attr); +} + +void GenericCallOp::setCallee(::llvm::StringRef attrValue) { + (*this)->setAttr(getCalleeAttrName(), ::mlir::SymbolRefAttr::get(::mlir::Builder((*this)->getContext()).getContext(), attrValue)); +} + +void GenericCallOp::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type resultType0, ::mlir::FlatSymbolRefAttr callee, ::mlir::ValueRange inputs) { + odsState.addOperands(inputs); + odsState.getOrAddProperties().callee = callee; + odsState.addTypes(resultType0); +} + +void GenericCallOp::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::FlatSymbolRefAttr callee, ::mlir::ValueRange inputs) { + odsState.addOperands(inputs); + odsState.getOrAddProperties().callee = callee; + assert(resultTypes.size() == 1u && "mismatched number of results"); + odsState.addTypes(resultTypes); +} + +void GenericCallOp::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type resultType0, ::llvm::StringRef callee, ::mlir::ValueRange inputs) { + odsState.addOperands(inputs); + odsState.getOrAddProperties().callee = ::mlir::SymbolRefAttr::get(odsBuilder.getContext(), callee); + odsState.addTypes(resultType0); +} + +void GenericCallOp::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::llvm::StringRef callee, ::mlir::ValueRange inputs) { + odsState.addOperands(inputs); + odsState.getOrAddProperties().callee = ::mlir::SymbolRefAttr::get(odsBuilder.getContext(), callee); + assert(resultTypes.size() == 1u && "mismatched number of results"); + odsState.addTypes(resultTypes); +} + +void GenericCallOp::build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes) { + odsState.addOperands(operands); + odsState.addAttributes(attributes); + assert(resultTypes.size() == 1u && "mismatched number of return types"); + odsState.addTypes(resultTypes); +} + +::mlir::LogicalResult GenericCallOp::verifyInvariantsImpl() { + auto tblgen_callee = getProperties().callee; (void)tblgen_callee; + if (!tblgen_callee) return emitOpError("requires attribute 'callee'"); + + if (::mlir::failed(__mlir_ods_local_attr_constraint_Ops4(*this, tblgen_callee, "callee"))) + return ::mlir::failure(); + { + unsigned index = 0; (void)index; + auto valueGroup0 = getODSOperands(0); + + for (auto v : valueGroup0) { + if (::mlir::failed(__mlir_ods_local_type_constraint_Ops1(*this, v.getType(), "operand", index++))) + return ::mlir::failure(); + } + } + { + unsigned index = 0; (void)index; + auto valueGroup0 = getODSResults(0); + + for (auto v : valueGroup0) { + if (::mlir::failed(__mlir_ods_local_type_constraint_Ops2(*this, v.getType(), "result", index++))) + return ::mlir::failure(); + } + } + return ::mlir::success(); +} + +::mlir::LogicalResult GenericCallOp::verifyInvariants() { + return verifyInvariantsImpl(); +} + +::mlir::ParseResult GenericCallOp::parse(::mlir::OpAsmParser &parser, ::mlir::OperationState &result) { + ::mlir::FlatSymbolRefAttr calleeAttr; + ::llvm::SmallVector<::mlir::OpAsmParser::UnresolvedOperand, 4> inputsOperands; + ::llvm::SMLoc inputsOperandsLoc; + (void)inputsOperandsLoc; + ::llvm::ArrayRef<::mlir::Type> inputsTypes; + ::llvm::ArrayRef<::mlir::Type> allResultTypes; + + if (parser.parseCustomAttributeWithFallback(calleeAttr, parser.getBuilder().getType<::mlir::NoneType>())) { + return ::mlir::failure(); + } + if (calleeAttr) result.getOrAddProperties().callee = calleeAttr; + if (parser.parseLParen()) + return ::mlir::failure(); + + inputsOperandsLoc = parser.getCurrentLocation(); + if (parser.parseOperandList(inputsOperands)) + return ::mlir::failure(); + if (parser.parseRParen()) + return ::mlir::failure(); + { + auto loc = parser.getCurrentLocation();(void)loc; + if (parser.parseOptionalAttrDict(result.attributes)) + return ::mlir::failure(); + if (failed(verifyInherentAttrs(result.name, result.attributes, [&]() { + return parser.emitError(loc) << "'" << result.name.getStringRef() << "' op "; + }))) + return ::mlir::failure(); + } + if (parser.parseColon()) + return ::mlir::failure(); + + ::mlir::FunctionType inputs__allResult_functionType; + if (parser.parseType(inputs__allResult_functionType)) + return ::mlir::failure(); + inputsTypes = inputs__allResult_functionType.getInputs(); + allResultTypes = inputs__allResult_functionType.getResults(); + result.addTypes(allResultTypes); + if (parser.resolveOperands(inputsOperands, inputsTypes, inputsOperandsLoc, result.operands)) + return ::mlir::failure(); + return ::mlir::success(); +} + +void GenericCallOp::print(::mlir::OpAsmPrinter &_odsPrinter) { + _odsPrinter << ' '; + _odsPrinter.printAttributeWithoutType(getCalleeAttr()); + _odsPrinter << "("; + _odsPrinter << getInputs(); + _odsPrinter << ")"; + ::llvm::SmallVector<::llvm::StringRef, 2> elidedAttrs; + elidedAttrs.push_back("callee"); + _odsPrinter.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs); + _odsPrinter << ' ' << ":"; + _odsPrinter << ' '; + _odsPrinter.printFunctionalType(getInputs().getTypes(), getOperation()->getResultTypes()); +} + +} // namespace toy +} // namespace mlir +MLIR_DEFINE_EXPLICIT_TYPE_ID(::mlir::toy::GenericCallOp) + +namespace mlir { +namespace toy { + +//===----------------------------------------------------------------------===// +// ::mlir::toy::MulOp definitions +//===----------------------------------------------------------------------===// + +namespace detail { +MulOpGenericAdaptorBase::MulOpGenericAdaptorBase(::mlir::DictionaryAttr attrs, const ::mlir::EmptyProperties &properties, ::mlir::RegionRange regions) : odsAttrs(attrs), odsRegions(regions) { if (odsAttrs) + odsOpName.emplace("toy.mul", odsAttrs.getContext()); +} + +MulOpGenericAdaptorBase::MulOpGenericAdaptorBase(MulOp op) : MulOpGenericAdaptorBase(op->getAttrDictionary(), op.getProperties(), op->getRegions()) {} + +std::pair MulOpGenericAdaptorBase::getODSOperandIndexAndLength(unsigned index, unsigned odsOperandsSize) { + return {index, 1}; +} + +::mlir::DictionaryAttr MulOpGenericAdaptorBase::getAttributes() { + return odsAttrs; +} + +} // namespace detail +MulOpAdaptor::MulOpAdaptor(MulOp op) : MulOpGenericAdaptor(op->getOperands(), op) {} + +::mlir::LogicalResult MulOpAdaptor::verify(::mlir::Location loc) { + return ::mlir::success(); +} + +std::pair MulOp::getODSOperandIndexAndLength(unsigned index) { + return {index, 1}; +} + +::mlir::Operation::operand_range MulOp::getODSOperands(unsigned index) { + auto valueRange = getODSOperandIndexAndLength(index); + return {std::next(getOperation()->operand_begin(), valueRange.first), + std::next(getOperation()->operand_begin(), valueRange.first + valueRange.second)}; +} + +::mlir::TypedValue<::mlir::TensorType> MulOp::getLhs() { + return ::llvm::cast<::mlir::TypedValue<::mlir::TensorType>>(*getODSOperands(0).begin()); +} + +::mlir::TypedValue<::mlir::TensorType> MulOp::getRhs() { + return ::llvm::cast<::mlir::TypedValue<::mlir::TensorType>>(*getODSOperands(1).begin()); +} + +::mlir::OpOperand &MulOp::getLhsMutable() { + auto range = getODSOperandIndexAndLength(0); + return getOperation()->getOpOperand(range.first); +} + +::mlir::OpOperand &MulOp::getRhsMutable() { + auto range = getODSOperandIndexAndLength(1); + return getOperation()->getOpOperand(range.first); +} + +std::pair MulOp::getODSResultIndexAndLength(unsigned index) { + return {index, 1}; +} + +::mlir::Operation::result_range MulOp::getODSResults(unsigned index) { + auto valueRange = getODSResultIndexAndLength(index); + return {std::next(getOperation()->result_begin(), valueRange.first), + std::next(getOperation()->result_begin(), valueRange.first + valueRange.second)}; +} + +void MulOp::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type resultType0, ::mlir::Value lhs, ::mlir::Value rhs) { + odsState.addOperands(lhs); + odsState.addOperands(rhs); + odsState.addTypes(resultType0); +} + +void MulOp::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::Value lhs, ::mlir::Value rhs) { + odsState.addOperands(lhs); + odsState.addOperands(rhs); + assert(resultTypes.size() == 1u && "mismatched number of results"); + odsState.addTypes(resultTypes); +} + +void MulOp::build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes) { + assert(operands.size() == 2u && "mismatched number of parameters"); + odsState.addOperands(operands); + odsState.addAttributes(attributes); + assert(resultTypes.size() == 1u && "mismatched number of return types"); + odsState.addTypes(resultTypes); +} + +::mlir::LogicalResult MulOp::verifyInvariantsImpl() { + { + unsigned index = 0; (void)index; + auto valueGroup0 = getODSOperands(0); + + for (auto v : valueGroup0) { + if (::mlir::failed(__mlir_ods_local_type_constraint_Ops0(*this, v.getType(), "operand", index++))) + return ::mlir::failure(); + } + auto valueGroup1 = getODSOperands(1); + + for (auto v : valueGroup1) { + if (::mlir::failed(__mlir_ods_local_type_constraint_Ops0(*this, v.getType(), "operand", index++))) + return ::mlir::failure(); + } + } + { + unsigned index = 0; (void)index; + auto valueGroup0 = getODSResults(0); + + for (auto v : valueGroup0) { + if (::mlir::failed(__mlir_ods_local_type_constraint_Ops0(*this, v.getType(), "result", index++))) + return ::mlir::failure(); + } + } + return ::mlir::success(); +} + +::mlir::LogicalResult MulOp::verifyInvariants() { + return verifyInvariantsImpl(); +} + +void MulOp::getEffects(::llvm::SmallVectorImpl<::mlir::SideEffects::EffectInstance<::mlir::MemoryEffects::Effect>> &effects) { +} + +} // namespace toy +} // namespace mlir +MLIR_DEFINE_EXPLICIT_TYPE_ID(::mlir::toy::MulOp) + +namespace mlir { +namespace toy { + +//===----------------------------------------------------------------------===// +// ::mlir::toy::PrintOp definitions +//===----------------------------------------------------------------------===// + +namespace detail { +PrintOpGenericAdaptorBase::PrintOpGenericAdaptorBase(::mlir::DictionaryAttr attrs, const ::mlir::EmptyProperties &properties, ::mlir::RegionRange regions) : odsAttrs(attrs), odsRegions(regions) { if (odsAttrs) + odsOpName.emplace("toy.print", odsAttrs.getContext()); +} + +PrintOpGenericAdaptorBase::PrintOpGenericAdaptorBase(PrintOp op) : PrintOpGenericAdaptorBase(op->getAttrDictionary(), op.getProperties(), op->getRegions()) {} + +std::pair PrintOpGenericAdaptorBase::getODSOperandIndexAndLength(unsigned index, unsigned odsOperandsSize) { + return {index, 1}; +} + +::mlir::DictionaryAttr PrintOpGenericAdaptorBase::getAttributes() { + return odsAttrs; +} + +} // namespace detail +PrintOpAdaptor::PrintOpAdaptor(PrintOp op) : PrintOpGenericAdaptor(op->getOperands(), op) {} + +::mlir::LogicalResult PrintOpAdaptor::verify(::mlir::Location loc) { + return ::mlir::success(); +} + +std::pair PrintOp::getODSOperandIndexAndLength(unsigned index) { + return {index, 1}; +} + +::mlir::Operation::operand_range PrintOp::getODSOperands(unsigned index) { + auto valueRange = getODSOperandIndexAndLength(index); + return {std::next(getOperation()->operand_begin(), valueRange.first), + std::next(getOperation()->operand_begin(), valueRange.first + valueRange.second)}; +} + +::mlir::Value PrintOp::getInput() { + return ::llvm::cast<::mlir::Value>(*getODSOperands(0).begin()); +} + +::mlir::OpOperand &PrintOp::getInputMutable() { + auto range = getODSOperandIndexAndLength(0); + return getOperation()->getOpOperand(range.first); +} + +std::pair PrintOp::getODSResultIndexAndLength(unsigned index) { + return {index, 1}; +} + +::mlir::Operation::result_range PrintOp::getODSResults(unsigned index) { + auto valueRange = getODSResultIndexAndLength(index); + return {std::next(getOperation()->result_begin(), valueRange.first), + std::next(getOperation()->result_begin(), valueRange.first + valueRange.second)}; +} + +void PrintOp::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Value input) { + odsState.addOperands(input); +} + +void PrintOp::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::Value input) { + odsState.addOperands(input); + assert(resultTypes.size() == 0u && "mismatched number of results"); + odsState.addTypes(resultTypes); +} + +void PrintOp::build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes) { + assert(operands.size() == 1u && "mismatched number of parameters"); + odsState.addOperands(operands); + odsState.addAttributes(attributes); + assert(resultTypes.size() == 0u && "mismatched number of return types"); + odsState.addTypes(resultTypes); +} + +::mlir::LogicalResult PrintOp::verifyInvariantsImpl() { + { + unsigned index = 0; (void)index; + auto valueGroup0 = getODSOperands(0); + + for (auto v : valueGroup0) { + if (::mlir::failed(__mlir_ods_local_type_constraint_Ops3(*this, v.getType(), "operand", index++))) + return ::mlir::failure(); + } + } + return ::mlir::success(); +} + +::mlir::LogicalResult PrintOp::verifyInvariants() { + return verifyInvariantsImpl(); +} + +::mlir::ParseResult PrintOp::parse(::mlir::OpAsmParser &parser, ::mlir::OperationState &result) { + ::mlir::OpAsmParser::UnresolvedOperand inputRawOperands[1]; + ::llvm::ArrayRef<::mlir::OpAsmParser::UnresolvedOperand> inputOperands(inputRawOperands); ::llvm::SMLoc inputOperandsLoc; + (void)inputOperandsLoc; + ::mlir::Type inputRawTypes[1]; + ::llvm::ArrayRef<::mlir::Type> inputTypes(inputRawTypes); + + inputOperandsLoc = parser.getCurrentLocation(); + if (parser.parseOperand(inputRawOperands[0])) + return ::mlir::failure(); + { + auto loc = parser.getCurrentLocation();(void)loc; + if (parser.parseOptionalAttrDict(result.attributes)) + return ::mlir::failure(); + } + if (parser.parseColon()) + return ::mlir::failure(); + + { + ::mlir::Type type; + if (parser.parseCustomTypeWithFallback(type)) + return ::mlir::failure(); + inputRawTypes[0] = type; + } + if (parser.resolveOperands(inputOperands, inputTypes, inputOperandsLoc, result.operands)) + return ::mlir::failure(); + return ::mlir::success(); +} + +void PrintOp::print(::mlir::OpAsmPrinter &_odsPrinter) { + _odsPrinter << ' '; + _odsPrinter << getInput(); + ::llvm::SmallVector<::llvm::StringRef, 2> elidedAttrs; + _odsPrinter.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs); + _odsPrinter << ' ' << ":"; + _odsPrinter << ' '; + { + auto type = getInput().getType(); + if (auto validType = ::llvm::dyn_cast<::mlir::Type>(type)) + _odsPrinter.printStrippedAttrOrType(validType); + else + _odsPrinter << type; + } +} + +} // namespace toy +} // namespace mlir +MLIR_DEFINE_EXPLICIT_TYPE_ID(::mlir::toy::PrintOp) + +namespace mlir { +namespace toy { + +//===----------------------------------------------------------------------===// +// ::mlir::toy::ReshapeOp definitions +//===----------------------------------------------------------------------===// + +namespace detail { +ReshapeOpGenericAdaptorBase::ReshapeOpGenericAdaptorBase(::mlir::DictionaryAttr attrs, const ::mlir::EmptyProperties &properties, ::mlir::RegionRange regions) : odsAttrs(attrs), odsRegions(regions) { if (odsAttrs) + odsOpName.emplace("toy.reshape", odsAttrs.getContext()); +} + +ReshapeOpGenericAdaptorBase::ReshapeOpGenericAdaptorBase(ReshapeOp op) : ReshapeOpGenericAdaptorBase(op->getAttrDictionary(), op.getProperties(), op->getRegions()) {} + +std::pair ReshapeOpGenericAdaptorBase::getODSOperandIndexAndLength(unsigned index, unsigned odsOperandsSize) { + return {index, 1}; +} + +::mlir::DictionaryAttr ReshapeOpGenericAdaptorBase::getAttributes() { + return odsAttrs; +} + +} // namespace detail +ReshapeOpAdaptor::ReshapeOpAdaptor(ReshapeOp op) : ReshapeOpGenericAdaptor(op->getOperands(), op) {} + +::mlir::LogicalResult ReshapeOpAdaptor::verify(::mlir::Location loc) { + return ::mlir::success(); +} + +std::pair ReshapeOp::getODSOperandIndexAndLength(unsigned index) { + return {index, 1}; +} + +::mlir::Operation::operand_range ReshapeOp::getODSOperands(unsigned index) { + auto valueRange = getODSOperandIndexAndLength(index); + return {std::next(getOperation()->operand_begin(), valueRange.first), + std::next(getOperation()->operand_begin(), valueRange.first + valueRange.second)}; +} + +::mlir::TypedValue<::mlir::TensorType> ReshapeOp::getInput() { + return ::llvm::cast<::mlir::TypedValue<::mlir::TensorType>>(*getODSOperands(0).begin()); +} + +::mlir::OpOperand &ReshapeOp::getInputMutable() { + auto range = getODSOperandIndexAndLength(0); + return getOperation()->getOpOperand(range.first); +} + +std::pair ReshapeOp::getODSResultIndexAndLength(unsigned index) { + return {index, 1}; +} + +::mlir::Operation::result_range ReshapeOp::getODSResults(unsigned index) { + auto valueRange = getODSResultIndexAndLength(index); + return {std::next(getOperation()->result_begin(), valueRange.first), + std::next(getOperation()->result_begin(), valueRange.first + valueRange.second)}; +} + +void ReshapeOp::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type resultType0, ::mlir::Value input) { + odsState.addOperands(input); + odsState.addTypes(resultType0); +} + +void ReshapeOp::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::Value input) { + odsState.addOperands(input); + assert(resultTypes.size() == 1u && "mismatched number of results"); + odsState.addTypes(resultTypes); +} + +void ReshapeOp::build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes) { + assert(operands.size() == 1u && "mismatched number of parameters"); + odsState.addOperands(operands); + odsState.addAttributes(attributes); + assert(resultTypes.size() == 1u && "mismatched number of return types"); + odsState.addTypes(resultTypes); +} + +::mlir::LogicalResult ReshapeOp::verifyInvariantsImpl() { + { + unsigned index = 0; (void)index; + auto valueGroup0 = getODSOperands(0); + + for (auto v : valueGroup0) { + if (::mlir::failed(__mlir_ods_local_type_constraint_Ops0(*this, v.getType(), "operand", index++))) + return ::mlir::failure(); + } + } + { + unsigned index = 0; (void)index; + auto valueGroup0 = getODSResults(0); + + for (auto v : valueGroup0) { + if (::mlir::failed(__mlir_ods_local_type_constraint_Ops4(*this, v.getType(), "result", index++))) + return ::mlir::failure(); + } + } + return ::mlir::success(); +} + +::mlir::LogicalResult ReshapeOp::verifyInvariants() { + return verifyInvariantsImpl(); +} + +::mlir::ParseResult ReshapeOp::parse(::mlir::OpAsmParser &parser, ::mlir::OperationState &result) { + ::mlir::OpAsmParser::UnresolvedOperand inputRawOperands[1]; + ::llvm::ArrayRef<::mlir::OpAsmParser::UnresolvedOperand> inputOperands(inputRawOperands); ::llvm::SMLoc inputOperandsLoc; + (void)inputOperandsLoc; + ::mlir::Type inputRawTypes[1]; + ::llvm::ArrayRef<::mlir::Type> inputTypes(inputRawTypes); + ::llvm::SmallVector<::mlir::Type, 1> allResultTypes; + if (parser.parseLParen()) + return ::mlir::failure(); + + inputOperandsLoc = parser.getCurrentLocation(); + if (parser.parseOperand(inputRawOperands[0])) + return ::mlir::failure(); + if (parser.parseColon()) + return ::mlir::failure(); + + { + ::mlir::TensorType type; + if (parser.parseCustomTypeWithFallback(type)) + return ::mlir::failure(); + inputRawTypes[0] = type; + } + if (parser.parseRParen()) + return ::mlir::failure(); + { + auto loc = parser.getCurrentLocation();(void)loc; + if (parser.parseOptionalAttrDict(result.attributes)) + return ::mlir::failure(); + } + if (parser.parseKeyword("to")) + return ::mlir::failure(); + + if (parser.parseTypeList(allResultTypes)) + return ::mlir::failure(); + result.addTypes(allResultTypes); + if (parser.resolveOperands(inputOperands, inputTypes, inputOperandsLoc, result.operands)) + return ::mlir::failure(); + return ::mlir::success(); +} + +void ReshapeOp::print(::mlir::OpAsmPrinter &_odsPrinter) { + _odsPrinter << "("; + _odsPrinter << getInput(); + _odsPrinter << ' ' << ":"; + _odsPrinter << ' '; + { + auto type = getInput().getType(); + if (auto validType = ::llvm::dyn_cast<::mlir::TensorType>(type)) + _odsPrinter.printStrippedAttrOrType(validType); + else + _odsPrinter << type; + } + _odsPrinter << ")"; + ::llvm::SmallVector<::llvm::StringRef, 2> elidedAttrs; + _odsPrinter.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs); + _odsPrinter << ' ' << "to"; + _odsPrinter << ' '; + _odsPrinter << getOperation()->getResultTypes(); +} + +void ReshapeOp::getEffects(::llvm::SmallVectorImpl<::mlir::SideEffects::EffectInstance<::mlir::MemoryEffects::Effect>> &effects) { +} + +} // namespace toy +} // namespace mlir +MLIR_DEFINE_EXPLICIT_TYPE_ID(::mlir::toy::ReshapeOp) + +namespace mlir { +namespace toy { + +//===----------------------------------------------------------------------===// +// ::mlir::toy::ReturnOp definitions +//===----------------------------------------------------------------------===// + +namespace detail { +ReturnOpGenericAdaptorBase::ReturnOpGenericAdaptorBase(::mlir::DictionaryAttr attrs, const ::mlir::EmptyProperties &properties, ::mlir::RegionRange regions) : odsAttrs(attrs), odsRegions(regions) { if (odsAttrs) + odsOpName.emplace("toy.return", odsAttrs.getContext()); +} + +ReturnOpGenericAdaptorBase::ReturnOpGenericAdaptorBase(ReturnOp op) : ReturnOpGenericAdaptorBase(op->getAttrDictionary(), op.getProperties(), op->getRegions()) {} + +std::pair ReturnOpGenericAdaptorBase::getODSOperandIndexAndLength(unsigned index, unsigned odsOperandsSize) { + bool isVariadic[] = {true}; + int prevVariadicCount = 0; + for (unsigned i = 0; i < index; ++i) + if (isVariadic[i]) ++prevVariadicCount; + + // Calculate how many dynamic values a static variadic operand corresponds to. + // This assumes all static variadic operands have the same dynamic value count. + int variadicSize = (odsOperandsSize - 0) / 1; + // `index` passed in as the parameter is the static index which counts each + // operand (variadic or not) as size 1. So here for each previous static variadic + // operand, we need to offset by (variadicSize - 1) to get where the dynamic + // value pack for this static operand starts. + int start = index + (variadicSize - 1) * prevVariadicCount; + int size = isVariadic[index] ? variadicSize : 1; + return {start, size}; +} + +::mlir::DictionaryAttr ReturnOpGenericAdaptorBase::getAttributes() { + return odsAttrs; +} + +} // namespace detail +ReturnOpAdaptor::ReturnOpAdaptor(ReturnOp op) : ReturnOpGenericAdaptor(op->getOperands(), op) {} + +::mlir::LogicalResult ReturnOpAdaptor::verify(::mlir::Location loc) { + return ::mlir::success(); +} + +std::pair ReturnOp::getODSOperandIndexAndLength(unsigned index) { + bool isVariadic[] = {true}; + int prevVariadicCount = 0; + for (unsigned i = 0; i < index; ++i) + if (isVariadic[i]) ++prevVariadicCount; + + // Calculate how many dynamic values a static variadic operand corresponds to. + // This assumes all static variadic operands have the same dynamic value count. + int variadicSize = (getOperation()->getNumOperands() - 0) / 1; + // `index` passed in as the parameter is the static index which counts each + // operand (variadic or not) as size 1. So here for each previous static variadic + // operand, we need to offset by (variadicSize - 1) to get where the dynamic + // value pack for this static operand starts. + int start = index + (variadicSize - 1) * prevVariadicCount; + int size = isVariadic[index] ? variadicSize : 1; + return {start, size}; +} + +::mlir::Operation::operand_range ReturnOp::getODSOperands(unsigned index) { + auto valueRange = getODSOperandIndexAndLength(index); + return {std::next(getOperation()->operand_begin(), valueRange.first), + std::next(getOperation()->operand_begin(), valueRange.first + valueRange.second)}; +} + +::mlir::Operation::operand_range ReturnOp::getInput() { + return getODSOperands(0); +} + +::mlir::MutableOperandRange ReturnOp::getInputMutable() { + auto range = getODSOperandIndexAndLength(0); + auto mutableRange = ::mlir::MutableOperandRange(getOperation(), range.first, range.second); + return mutableRange; +} + +std::pair ReturnOp::getODSResultIndexAndLength(unsigned index) { + return {index, 1}; +} + +::mlir::Operation::result_range ReturnOp::getODSResults(unsigned index) { + auto valueRange = getODSResultIndexAndLength(index); + return {std::next(getOperation()->result_begin(), valueRange.first), + std::next(getOperation()->result_begin(), valueRange.first + valueRange.second)}; +} + +void ReturnOp::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState) { + build(odsBuilder, odsState, std::nullopt); +} + +void ReturnOp::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::ValueRange input) { + odsState.addOperands(input); +} + +void ReturnOp::build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes) { + odsState.addOperands(operands); + odsState.addAttributes(attributes); + assert(resultTypes.size() == 0u && "mismatched number of return types"); + odsState.addTypes(resultTypes); +} + +::mlir::LogicalResult ReturnOp::verifyInvariantsImpl() { + { + unsigned index = 0; (void)index; + auto valueGroup0 = getODSOperands(0); + + for (auto v : valueGroup0) { + if (::mlir::failed(__mlir_ods_local_type_constraint_Ops1(*this, v.getType(), "operand", index++))) + return ::mlir::failure(); + } + } + return ::mlir::success(); +} + +::mlir::LogicalResult ReturnOp::verifyInvariants() { + if(::mlir::succeeded(verifyInvariantsImpl()) && ::mlir::succeeded(verify())) + return ::mlir::success(); + return ::mlir::failure(); +} + +::mlir::ParseResult ReturnOp::parse(::mlir::OpAsmParser &parser, ::mlir::OperationState &result) { + ::llvm::SmallVector<::mlir::OpAsmParser::UnresolvedOperand, 4> inputOperands; + ::llvm::SMLoc inputOperandsLoc; + (void)inputOperandsLoc; + ::llvm::SmallVector<::mlir::Type, 1> inputTypes; + + inputOperandsLoc = parser.getCurrentLocation(); + if (parser.parseOperandList(inputOperands)) + return ::mlir::failure(); + if (!inputOperands.empty()) { + if (parser.parseColon()) + return ::mlir::failure(); + + if (parser.parseTypeList(inputTypes)) + return ::mlir::failure(); + } + { + auto loc = parser.getCurrentLocation();(void)loc; + if (parser.parseOptionalAttrDict(result.attributes)) + return ::mlir::failure(); + } + if (parser.resolveOperands(inputOperands, inputTypes, inputOperandsLoc, result.operands)) + return ::mlir::failure(); + return ::mlir::success(); +} + +void ReturnOp::print(::mlir::OpAsmPrinter &_odsPrinter) { + if (!getInput().empty()) { + _odsPrinter << ' '; + _odsPrinter << getInput(); + _odsPrinter << ' ' << ":"; + _odsPrinter << ' '; + _odsPrinter << getInput().getTypes(); + } + ::llvm::SmallVector<::llvm::StringRef, 2> elidedAttrs; + _odsPrinter.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs); +} + +void ReturnOp::getEffects(::llvm::SmallVectorImpl<::mlir::SideEffects::EffectInstance<::mlir::MemoryEffects::Effect>> &effects) { +} + +} // namespace toy +} // namespace mlir +MLIR_DEFINE_EXPLICIT_TYPE_ID(::mlir::toy::ReturnOp) + +namespace mlir { +namespace toy { + +//===----------------------------------------------------------------------===// +// ::mlir::toy::StructAccessOp definitions +//===----------------------------------------------------------------------===// + +namespace detail { +StructAccessOpGenericAdaptorBase::StructAccessOpGenericAdaptorBase(::mlir::DictionaryAttr attrs, const Properties &properties, ::mlir::RegionRange regions) : odsAttrs(attrs), properties(properties), odsRegions(regions) { if (odsAttrs) + odsOpName.emplace("toy.struct_access", odsAttrs.getContext()); +} + +StructAccessOpGenericAdaptorBase::StructAccessOpGenericAdaptorBase(StructAccessOp op) : StructAccessOpGenericAdaptorBase(op->getDiscardableAttrDictionary(), op.getProperties(), op->getRegions()) {} + +std::pair StructAccessOpGenericAdaptorBase::getODSOperandIndexAndLength(unsigned index, unsigned odsOperandsSize) { + return {index, 1}; +} + +::mlir::DictionaryAttr StructAccessOpGenericAdaptorBase::getAttributes() { + return odsAttrs; +} + +::mlir::IntegerAttr StructAccessOpGenericAdaptorBase::getIndexAttr() { + auto attr = ::llvm::cast<::mlir::IntegerAttr>(getProperties().index); + return attr; +} + +uint64_t StructAccessOpGenericAdaptorBase::getIndex() { + auto attr = getIndexAttr(); + return attr.getValue().getZExtValue(); +} + +} // namespace detail +StructAccessOpAdaptor::StructAccessOpAdaptor(StructAccessOp op) : StructAccessOpGenericAdaptor(op->getOperands(), op) {} + +::mlir::LogicalResult StructAccessOpAdaptor::verify(::mlir::Location loc) { + auto tblgen_index = getProperties().index; (void)tblgen_index; + if (!tblgen_index) return emitError(loc, "'toy.struct_access' op ""requires attribute 'index'"); + + if (tblgen_index && !(((::llvm::isa<::mlir::IntegerAttr>(tblgen_index))) && ((::llvm::cast<::mlir::IntegerAttr>(tblgen_index).getType().isSignlessInteger(64))))) + return emitError(loc, "'toy.struct_access' op ""attribute 'index' failed to satisfy constraint: 64-bit signless integer attribute"); + return ::mlir::success(); +} + +std::pair StructAccessOp::getODSOperandIndexAndLength(unsigned index) { + return {index, 1}; +} + +::mlir::Operation::operand_range StructAccessOp::getODSOperands(unsigned index) { + auto valueRange = getODSOperandIndexAndLength(index); + return {std::next(getOperation()->operand_begin(), valueRange.first), + std::next(getOperation()->operand_begin(), valueRange.first + valueRange.second)}; +} + +::mlir::Value StructAccessOp::getInput() { + return ::llvm::cast<::mlir::Value>(*getODSOperands(0).begin()); +} + +::mlir::OpOperand &StructAccessOp::getInputMutable() { + auto range = getODSOperandIndexAndLength(0); + return getOperation()->getOpOperand(range.first); +} + +std::pair StructAccessOp::getODSResultIndexAndLength(unsigned index) { + return {index, 1}; +} + +::mlir::Operation::result_range StructAccessOp::getODSResults(unsigned index) { + auto valueRange = getODSResultIndexAndLength(index); + return {std::next(getOperation()->result_begin(), valueRange.first), + std::next(getOperation()->result_begin(), valueRange.first + valueRange.second)}; +} + +::mlir::Value StructAccessOp::getOutput() { + return ::llvm::cast<::mlir::Value>(*getODSResults(0).begin()); +} + +::mlir::LogicalResult StructAccessOp::setPropertiesFromAttr(Properties &prop, ::mlir::Attribute attr, ::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError) { + ::mlir::DictionaryAttr dict = ::llvm::dyn_cast<::mlir::DictionaryAttr>(attr); + if (!dict) { + emitError() << "expected DictionaryAttr to set properties"; + return ::mlir::failure(); + } + + { + auto &propStorage = prop.index; + auto attr = dict.get("index"); + if (attr || /*isRequired=*/true) { + if (!attr) { + emitError() << "expected key entry for index in DictionaryAttr to set " + "Properties."; + return ::mlir::failure(); + } + auto convertedAttr = ::llvm::dyn_cast>(attr); + if (convertedAttr) { + propStorage = convertedAttr; + } else { + emitError() << "Invalid attribute `index` in property conversion: " << attr; + return ::mlir::failure(); + } + } + } + return ::mlir::success(); +} + +::mlir::Attribute StructAccessOp::getPropertiesAsAttr(::mlir::MLIRContext *ctx, const Properties &prop) { + ::mlir::SmallVector<::mlir::NamedAttribute> attrs; + ::mlir::Builder odsBuilder{ctx}; + + { + const auto &propStorage = prop.index; + if (propStorage) + attrs.push_back(odsBuilder.getNamedAttr("index", + propStorage)); + } + + if (!attrs.empty()) + return odsBuilder.getDictionaryAttr(attrs); + return {}; +} + +llvm::hash_code StructAccessOp::computePropertiesHash(const Properties &prop) { + return llvm::hash_combine( + llvm::hash_value(prop.index.getAsOpaquePointer())); +} + +std::optional StructAccessOp::getInherentAttr(::mlir::MLIRContext *ctx, const Properties &prop, llvm::StringRef name) { + if (name == "index") + return prop.index; + return std::nullopt; +} + +void StructAccessOp::setInherentAttr(Properties &prop, llvm::StringRef name, mlir::Attribute value) { + if (name == "index") { + prop.index = ::llvm::dyn_cast_or_null>(value); + return; + } +} + +void StructAccessOp::populateInherentAttrs(::mlir::MLIRContext *ctx, const Properties &prop, ::mlir::NamedAttrList &attrs) { + if (prop.index) attrs.append("index", prop.index); +} + +::mlir::LogicalResult StructAccessOp::verifyInherentAttrs(::mlir::OperationName opName, ::mlir::NamedAttrList &attrs, llvm::function_ref<::mlir::InFlightDiagnostic()> emitError) { + { + ::mlir::Attribute attr = attrs.get(getIndexAttrName(opName)); + if (attr && ::mlir::failed(__mlir_ods_local_attr_constraint_Ops5(attr, "index", emitError))) + return ::mlir::failure(); + } + return ::mlir::success(); +} + +::mlir::LogicalResult StructAccessOp::readProperties(::mlir::DialectBytecodeReader &reader, ::mlir::OperationState &state) { + auto &prop = state.getOrAddProperties(); (void)prop; + if (::mlir::failed(reader.readAttribute(prop.index))) + return ::mlir::failure(); + return ::mlir::success(); +} + +void StructAccessOp::writeProperties(::mlir::DialectBytecodeWriter &writer) { + auto &prop = getProperties(); (void)prop; + writer.writeAttribute(prop.index); +} + +::mlir::IntegerAttr StructAccessOp::getIndexAttr() { + return ::llvm::cast<::mlir::IntegerAttr>(getProperties().index); +} + +uint64_t StructAccessOp::getIndex() { + auto attr = getIndexAttr(); + return attr.getValue().getZExtValue(); +} + +void StructAccessOp::setIndexAttr(::mlir::IntegerAttr attr) { + (*this)->setAttr(getIndexAttrName(), attr); +} + +void StructAccessOp::setIndex(uint64_t attrValue) { + (*this)->setAttr(getIndexAttrName(), ::mlir::Builder((*this)->getContext()).getIntegerAttr(::mlir::Builder((*this)->getContext()).getIntegerType(64), attrValue)); +} + +void StructAccessOp::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type output, ::mlir::Value input, ::mlir::IntegerAttr index) { + odsState.addOperands(input); + odsState.getOrAddProperties().index = index; + odsState.addTypes(output); +} + +void StructAccessOp::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::Value input, ::mlir::IntegerAttr index) { + odsState.addOperands(input); + odsState.getOrAddProperties().index = index; + assert(resultTypes.size() == 1u && "mismatched number of results"); + odsState.addTypes(resultTypes); +} + +void StructAccessOp::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type output, ::mlir::Value input, uint64_t index) { + odsState.addOperands(input); + odsState.getOrAddProperties().index = odsBuilder.getIntegerAttr(odsBuilder.getIntegerType(64), index); + odsState.addTypes(output); +} + +void StructAccessOp::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::Value input, uint64_t index) { + odsState.addOperands(input); + odsState.getOrAddProperties().index = odsBuilder.getIntegerAttr(odsBuilder.getIntegerType(64), index); + assert(resultTypes.size() == 1u && "mismatched number of results"); + odsState.addTypes(resultTypes); +} + +void StructAccessOp::build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes) { + assert(operands.size() == 1u && "mismatched number of parameters"); + odsState.addOperands(operands); + odsState.addAttributes(attributes); + assert(resultTypes.size() == 1u && "mismatched number of return types"); + odsState.addTypes(resultTypes); +} + +::mlir::LogicalResult StructAccessOp::verifyInvariantsImpl() { + auto tblgen_index = getProperties().index; (void)tblgen_index; + if (!tblgen_index) return emitOpError("requires attribute 'index'"); + + if (::mlir::failed(__mlir_ods_local_attr_constraint_Ops5(*this, tblgen_index, "index"))) + return ::mlir::failure(); + { + unsigned index = 0; (void)index; + auto valueGroup0 = getODSOperands(0); + + for (auto v : valueGroup0) { + if (::mlir::failed(__mlir_ods_local_type_constraint_Ops5(*this, v.getType(), "operand", index++))) + return ::mlir::failure(); + } + } + { + unsigned index = 0; (void)index; + auto valueGroup0 = getODSResults(0); + + for (auto v : valueGroup0) { + if (::mlir::failed(__mlir_ods_local_type_constraint_Ops2(*this, v.getType(), "result", index++))) + return ::mlir::failure(); + } + } + return ::mlir::success(); +} + +::mlir::LogicalResult StructAccessOp::verifyInvariants() { + if(::mlir::succeeded(verifyInvariantsImpl()) && ::mlir::succeeded(verify())) + return ::mlir::success(); + return ::mlir::failure(); +} + +::mlir::ParseResult StructAccessOp::parse(::mlir::OpAsmParser &parser, ::mlir::OperationState &result) { + ::mlir::OpAsmParser::UnresolvedOperand inputRawOperands[1]; + ::llvm::ArrayRef<::mlir::OpAsmParser::UnresolvedOperand> inputOperands(inputRawOperands); ::llvm::SMLoc inputOperandsLoc; + (void)inputOperandsLoc; + ::mlir::IntegerAttr indexAttr; + ::mlir::Type inputRawTypes[1]; + ::llvm::ArrayRef<::mlir::Type> inputTypes(inputRawTypes); + ::mlir::Type outputRawTypes[1]; + ::llvm::ArrayRef<::mlir::Type> outputTypes(outputRawTypes); + + inputOperandsLoc = parser.getCurrentLocation(); + if (parser.parseOperand(inputRawOperands[0])) + return ::mlir::failure(); + if (parser.parseLSquare()) + return ::mlir::failure(); + + if (parser.parseCustomAttributeWithFallback(indexAttr, parser.getBuilder().getIntegerType(64))) { + return ::mlir::failure(); + } + if (indexAttr) result.getOrAddProperties().index = indexAttr; + if (parser.parseRSquare()) + return ::mlir::failure(); + { + auto loc = parser.getCurrentLocation();(void)loc; + if (parser.parseOptionalAttrDict(result.attributes)) + return ::mlir::failure(); + if (failed(verifyInherentAttrs(result.name, result.attributes, [&]() { + return parser.emitError(loc) << "'" << result.name.getStringRef() << "' op "; + }))) + return ::mlir::failure(); + } + if (parser.parseColon()) + return ::mlir::failure(); + + { + ::mlir::Type type; + if (parser.parseCustomTypeWithFallback(type)) + return ::mlir::failure(); + inputRawTypes[0] = type; + } + if (parser.parseArrow()) + return ::mlir::failure(); + + { + ::mlir::Type type; + if (parser.parseCustomTypeWithFallback(type)) + return ::mlir::failure(); + outputRawTypes[0] = type; + } + result.addTypes(outputTypes); + if (parser.resolveOperands(inputOperands, inputTypes, inputOperandsLoc, result.operands)) + return ::mlir::failure(); + return ::mlir::success(); +} + +void StructAccessOp::print(::mlir::OpAsmPrinter &_odsPrinter) { + _odsPrinter << ' '; + _odsPrinter << getInput(); + _odsPrinter << "["; + _odsPrinter.printAttributeWithoutType(getIndexAttr()); + _odsPrinter << "]"; + ::llvm::SmallVector<::llvm::StringRef, 2> elidedAttrs; + elidedAttrs.push_back("index"); + _odsPrinter.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs); + _odsPrinter << ' ' << ":"; + _odsPrinter << ' '; + { + auto type = getInput().getType(); + if (auto validType = ::llvm::dyn_cast<::mlir::Type>(type)) + _odsPrinter.printStrippedAttrOrType(validType); + else + _odsPrinter << type; + } + _odsPrinter << ' ' << "->"; + _odsPrinter << ' '; + { + auto type = getOutput().getType(); + if (auto validType = ::llvm::dyn_cast<::mlir::Type>(type)) + _odsPrinter.printStrippedAttrOrType(validType); + else + _odsPrinter << type; + } +} + +void StructAccessOp::getEffects(::llvm::SmallVectorImpl<::mlir::SideEffects::EffectInstance<::mlir::MemoryEffects::Effect>> &effects) { +} + +} // namespace toy +} // namespace mlir +MLIR_DEFINE_EXPLICIT_TYPE_ID(::mlir::toy::StructAccessOp) + +namespace mlir { +namespace toy { + +//===----------------------------------------------------------------------===// +// ::mlir::toy::StructConstantOp definitions +//===----------------------------------------------------------------------===// + +namespace detail { +StructConstantOpGenericAdaptorBase::StructConstantOpGenericAdaptorBase(::mlir::DictionaryAttr attrs, const Properties &properties, ::mlir::RegionRange regions) : odsAttrs(attrs), properties(properties), odsRegions(regions) { if (odsAttrs) + odsOpName.emplace("toy.struct_constant", odsAttrs.getContext()); +} + +StructConstantOpGenericAdaptorBase::StructConstantOpGenericAdaptorBase(StructConstantOp op) : StructConstantOpGenericAdaptorBase(op->getDiscardableAttrDictionary(), op.getProperties(), op->getRegions()) {} + +std::pair StructConstantOpGenericAdaptorBase::getODSOperandIndexAndLength(unsigned index, unsigned odsOperandsSize) { + return {index, 1}; +} + +::mlir::DictionaryAttr StructConstantOpGenericAdaptorBase::getAttributes() { + return odsAttrs; +} + +::mlir::ArrayAttr StructConstantOpGenericAdaptorBase::getValueAttr() { + auto attr = ::llvm::cast<::mlir::ArrayAttr>(getProperties().value); + return attr; +} + +::mlir::ArrayAttr StructConstantOpGenericAdaptorBase::getValue() { + auto attr = getValueAttr(); + return attr; +} + +} // namespace detail +StructConstantOpAdaptor::StructConstantOpAdaptor(StructConstantOp op) : StructConstantOpGenericAdaptor(op->getOperands(), op) {} + +::mlir::LogicalResult StructConstantOpAdaptor::verify(::mlir::Location loc) { + auto tblgen_value = getProperties().value; (void)tblgen_value; + if (!tblgen_value) return emitError(loc, "'toy.struct_constant' op ""requires attribute 'value'"); + + if (tblgen_value && !((::llvm::isa<::mlir::ArrayAttr>(tblgen_value)))) + return emitError(loc, "'toy.struct_constant' op ""attribute 'value' failed to satisfy constraint: array attribute"); + return ::mlir::success(); +} + +std::pair StructConstantOp::getODSOperandIndexAndLength(unsigned index) { + return {index, 1}; +} + +::mlir::Operation::operand_range StructConstantOp::getODSOperands(unsigned index) { + auto valueRange = getODSOperandIndexAndLength(index); + return {std::next(getOperation()->operand_begin(), valueRange.first), + std::next(getOperation()->operand_begin(), valueRange.first + valueRange.second)}; +} + +std::pair StructConstantOp::getODSResultIndexAndLength(unsigned index) { + return {index, 1}; +} + +::mlir::Operation::result_range StructConstantOp::getODSResults(unsigned index) { + auto valueRange = getODSResultIndexAndLength(index); + return {std::next(getOperation()->result_begin(), valueRange.first), + std::next(getOperation()->result_begin(), valueRange.first + valueRange.second)}; +} + +::mlir::Value StructConstantOp::getOutput() { + return ::llvm::cast<::mlir::Value>(*getODSResults(0).begin()); +} + +::mlir::LogicalResult StructConstantOp::setPropertiesFromAttr(Properties &prop, ::mlir::Attribute attr, ::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError) { + ::mlir::DictionaryAttr dict = ::llvm::dyn_cast<::mlir::DictionaryAttr>(attr); + if (!dict) { + emitError() << "expected DictionaryAttr to set properties"; + return ::mlir::failure(); + } + + { + auto &propStorage = prop.value; + auto attr = dict.get("value"); + if (attr || /*isRequired=*/true) { + if (!attr) { + emitError() << "expected key entry for value in DictionaryAttr to set " + "Properties."; + return ::mlir::failure(); + } + auto convertedAttr = ::llvm::dyn_cast>(attr); + if (convertedAttr) { + propStorage = convertedAttr; + } else { + emitError() << "Invalid attribute `value` in property conversion: " << attr; + return ::mlir::failure(); + } + } + } + return ::mlir::success(); +} + +::mlir::Attribute StructConstantOp::getPropertiesAsAttr(::mlir::MLIRContext *ctx, const Properties &prop) { + ::mlir::SmallVector<::mlir::NamedAttribute> attrs; + ::mlir::Builder odsBuilder{ctx}; + + { + const auto &propStorage = prop.value; + if (propStorage) + attrs.push_back(odsBuilder.getNamedAttr("value", + propStorage)); + } + + if (!attrs.empty()) + return odsBuilder.getDictionaryAttr(attrs); + return {}; +} + +llvm::hash_code StructConstantOp::computePropertiesHash(const Properties &prop) { + return llvm::hash_combine( + llvm::hash_value(prop.value.getAsOpaquePointer())); +} + +std::optional StructConstantOp::getInherentAttr(::mlir::MLIRContext *ctx, const Properties &prop, llvm::StringRef name) { + if (name == "value") + return prop.value; + return std::nullopt; +} + +void StructConstantOp::setInherentAttr(Properties &prop, llvm::StringRef name, mlir::Attribute value) { + if (name == "value") { + prop.value = ::llvm::dyn_cast_or_null>(value); + return; + } +} + +void StructConstantOp::populateInherentAttrs(::mlir::MLIRContext *ctx, const Properties &prop, ::mlir::NamedAttrList &attrs) { + if (prop.value) attrs.append("value", prop.value); +} + +::mlir::LogicalResult StructConstantOp::verifyInherentAttrs(::mlir::OperationName opName, ::mlir::NamedAttrList &attrs, llvm::function_ref<::mlir::InFlightDiagnostic()> emitError) { + { + ::mlir::Attribute attr = attrs.get(getValueAttrName(opName)); + if (attr && ::mlir::failed(__mlir_ods_local_attr_constraint_Ops6(attr, "value", emitError))) + return ::mlir::failure(); + } + return ::mlir::success(); +} + +::mlir::LogicalResult StructConstantOp::readProperties(::mlir::DialectBytecodeReader &reader, ::mlir::OperationState &state) { + auto &prop = state.getOrAddProperties(); (void)prop; + if (::mlir::failed(reader.readAttribute(prop.value))) + return ::mlir::failure(); + return ::mlir::success(); +} + +void StructConstantOp::writeProperties(::mlir::DialectBytecodeWriter &writer) { + auto &prop = getProperties(); (void)prop; + writer.writeAttribute(prop.value); +} + +::mlir::ArrayAttr StructConstantOp::getValueAttr() { + return ::llvm::cast<::mlir::ArrayAttr>(getProperties().value); +} + +::mlir::ArrayAttr StructConstantOp::getValue() { + auto attr = getValueAttr(); + return attr; +} + +void StructConstantOp::setValueAttr(::mlir::ArrayAttr attr) { + (*this)->setAttr(getValueAttrName(), attr); +} + +void StructConstantOp::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type output, ::mlir::ArrayAttr value) { + odsState.getOrAddProperties().value = value; + odsState.addTypes(output); +} + +void StructConstantOp::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ArrayAttr value) { + odsState.getOrAddProperties().value = value; + assert(resultTypes.size() == 1u && "mismatched number of results"); + odsState.addTypes(resultTypes); +} + +void StructConstantOp::build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes) { + assert(operands.size() == 0u && "mismatched number of parameters"); + odsState.addOperands(operands); + odsState.addAttributes(attributes); + assert(resultTypes.size() == 1u && "mismatched number of return types"); + odsState.addTypes(resultTypes); +} + +::mlir::LogicalResult StructConstantOp::verifyInvariantsImpl() { + auto tblgen_value = getProperties().value; (void)tblgen_value; + if (!tblgen_value) return emitOpError("requires attribute 'value'"); + + if (::mlir::failed(__mlir_ods_local_attr_constraint_Ops6(*this, tblgen_value, "value"))) + return ::mlir::failure(); + { + unsigned index = 0; (void)index; + auto valueGroup0 = getODSResults(0); + + for (auto v : valueGroup0) { + if (::mlir::failed(__mlir_ods_local_type_constraint_Ops5(*this, v.getType(), "result", index++))) + return ::mlir::failure(); + } + } + return ::mlir::success(); +} + +::mlir::LogicalResult StructConstantOp::verifyInvariants() { + if(::mlir::succeeded(verifyInvariantsImpl()) && ::mlir::succeeded(verify())) + return ::mlir::success(); + return ::mlir::failure(); +} + +::mlir::ParseResult StructConstantOp::parse(::mlir::OpAsmParser &parser, ::mlir::OperationState &result) { + ::mlir::ArrayAttr valueAttr; + ::mlir::Type outputRawTypes[1]; + ::llvm::ArrayRef<::mlir::Type> outputTypes(outputRawTypes); + + if (parser.parseCustomAttributeWithFallback(valueAttr, parser.getBuilder().getType<::mlir::NoneType>())) { + return ::mlir::failure(); + } + if (valueAttr) result.getOrAddProperties().value = valueAttr; + { + auto loc = parser.getCurrentLocation();(void)loc; + if (parser.parseOptionalAttrDict(result.attributes)) + return ::mlir::failure(); + if (failed(verifyInherentAttrs(result.name, result.attributes, [&]() { + return parser.emitError(loc) << "'" << result.name.getStringRef() << "' op "; + }))) + return ::mlir::failure(); + } + if (parser.parseColon()) + return ::mlir::failure(); + + { + ::mlir::Type type; + if (parser.parseCustomTypeWithFallback(type)) + return ::mlir::failure(); + outputRawTypes[0] = type; + } + result.addTypes(outputTypes); + return ::mlir::success(); +} + +void StructConstantOp::print(::mlir::OpAsmPrinter &_odsPrinter) { + _odsPrinter << ' '; + _odsPrinter.printAttributeWithoutType(getValueAttr()); + ::llvm::SmallVector<::llvm::StringRef, 2> elidedAttrs; + elidedAttrs.push_back("value"); + _odsPrinter.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs); + _odsPrinter << ' ' << ":"; + _odsPrinter << ' '; + { + auto type = getOutput().getType(); + if (auto validType = ::llvm::dyn_cast<::mlir::Type>(type)) + _odsPrinter.printStrippedAttrOrType(validType); + else + _odsPrinter << type; + } +} + +void StructConstantOp::getEffects(::llvm::SmallVectorImpl<::mlir::SideEffects::EffectInstance<::mlir::MemoryEffects::Effect>> &effects) { +} + +} // namespace toy +} // namespace mlir +MLIR_DEFINE_EXPLICIT_TYPE_ID(::mlir::toy::StructConstantOp) + +namespace mlir { +namespace toy { + +//===----------------------------------------------------------------------===// +// ::mlir::toy::TransposeOp definitions +//===----------------------------------------------------------------------===// + +namespace detail { +TransposeOpGenericAdaptorBase::TransposeOpGenericAdaptorBase(::mlir::DictionaryAttr attrs, const ::mlir::EmptyProperties &properties, ::mlir::RegionRange regions) : odsAttrs(attrs), odsRegions(regions) { if (odsAttrs) + odsOpName.emplace("toy.transpose", odsAttrs.getContext()); +} + +TransposeOpGenericAdaptorBase::TransposeOpGenericAdaptorBase(TransposeOp op) : TransposeOpGenericAdaptorBase(op->getAttrDictionary(), op.getProperties(), op->getRegions()) {} + +std::pair TransposeOpGenericAdaptorBase::getODSOperandIndexAndLength(unsigned index, unsigned odsOperandsSize) { + return {index, 1}; +} + +::mlir::DictionaryAttr TransposeOpGenericAdaptorBase::getAttributes() { + return odsAttrs; +} + +} // namespace detail +TransposeOpAdaptor::TransposeOpAdaptor(TransposeOp op) : TransposeOpGenericAdaptor(op->getOperands(), op) {} + +::mlir::LogicalResult TransposeOpAdaptor::verify(::mlir::Location loc) { + return ::mlir::success(); +} + +std::pair TransposeOp::getODSOperandIndexAndLength(unsigned index) { + return {index, 1}; +} + +::mlir::Operation::operand_range TransposeOp::getODSOperands(unsigned index) { + auto valueRange = getODSOperandIndexAndLength(index); + return {std::next(getOperation()->operand_begin(), valueRange.first), + std::next(getOperation()->operand_begin(), valueRange.first + valueRange.second)}; +} + +::mlir::TypedValue<::mlir::TensorType> TransposeOp::getInput() { + return ::llvm::cast<::mlir::TypedValue<::mlir::TensorType>>(*getODSOperands(0).begin()); +} + +::mlir::OpOperand &TransposeOp::getInputMutable() { + auto range = getODSOperandIndexAndLength(0); + return getOperation()->getOpOperand(range.first); +} + +std::pair TransposeOp::getODSResultIndexAndLength(unsigned index) { + return {index, 1}; +} + +::mlir::Operation::result_range TransposeOp::getODSResults(unsigned index) { + auto valueRange = getODSResultIndexAndLength(index); + return {std::next(getOperation()->result_begin(), valueRange.first), + std::next(getOperation()->result_begin(), valueRange.first + valueRange.second)}; +} + +void TransposeOp::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type resultType0, ::mlir::Value input) { + odsState.addOperands(input); + odsState.addTypes(resultType0); +} + +void TransposeOp::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::Value input) { + odsState.addOperands(input); + assert(resultTypes.size() == 1u && "mismatched number of results"); + odsState.addTypes(resultTypes); +} + +void TransposeOp::build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes) { + assert(operands.size() == 1u && "mismatched number of parameters"); + odsState.addOperands(operands); + odsState.addAttributes(attributes); + assert(resultTypes.size() == 1u && "mismatched number of return types"); + odsState.addTypes(resultTypes); +} + +::mlir::LogicalResult TransposeOp::verifyInvariantsImpl() { + { + unsigned index = 0; (void)index; + auto valueGroup0 = getODSOperands(0); + + for (auto v : valueGroup0) { + if (::mlir::failed(__mlir_ods_local_type_constraint_Ops0(*this, v.getType(), "operand", index++))) + return ::mlir::failure(); + } + } + { + unsigned index = 0; (void)index; + auto valueGroup0 = getODSResults(0); + + for (auto v : valueGroup0) { + if (::mlir::failed(__mlir_ods_local_type_constraint_Ops0(*this, v.getType(), "result", index++))) + return ::mlir::failure(); + } + } + return ::mlir::success(); +} + +::mlir::LogicalResult TransposeOp::verifyInvariants() { + if(::mlir::succeeded(verifyInvariantsImpl()) && ::mlir::succeeded(verify())) + return ::mlir::success(); + return ::mlir::failure(); +} + +::mlir::ParseResult TransposeOp::parse(::mlir::OpAsmParser &parser, ::mlir::OperationState &result) { + ::mlir::OpAsmParser::UnresolvedOperand inputRawOperands[1]; + ::llvm::ArrayRef<::mlir::OpAsmParser::UnresolvedOperand> inputOperands(inputRawOperands); ::llvm::SMLoc inputOperandsLoc; + (void)inputOperandsLoc; + ::mlir::Type inputRawTypes[1]; + ::llvm::ArrayRef<::mlir::Type> inputTypes(inputRawTypes); + ::llvm::SmallVector<::mlir::Type, 1> allResultTypes; + if (parser.parseLParen()) + return ::mlir::failure(); + + inputOperandsLoc = parser.getCurrentLocation(); + if (parser.parseOperand(inputRawOperands[0])) + return ::mlir::failure(); + if (parser.parseColon()) + return ::mlir::failure(); + + { + ::mlir::TensorType type; + if (parser.parseCustomTypeWithFallback(type)) + return ::mlir::failure(); + inputRawTypes[0] = type; + } + if (parser.parseRParen()) + return ::mlir::failure(); + { + auto loc = parser.getCurrentLocation();(void)loc; + if (parser.parseOptionalAttrDict(result.attributes)) + return ::mlir::failure(); + } + if (parser.parseKeyword("to")) + return ::mlir::failure(); + + if (parser.parseTypeList(allResultTypes)) + return ::mlir::failure(); + result.addTypes(allResultTypes); + if (parser.resolveOperands(inputOperands, inputTypes, inputOperandsLoc, result.operands)) + return ::mlir::failure(); + return ::mlir::success(); +} + +void TransposeOp::print(::mlir::OpAsmPrinter &_odsPrinter) { + _odsPrinter << "("; + _odsPrinter << getInput(); + _odsPrinter << ' ' << ":"; + _odsPrinter << ' '; + { + auto type = getInput().getType(); + if (auto validType = ::llvm::dyn_cast<::mlir::TensorType>(type)) + _odsPrinter.printStrippedAttrOrType(validType); + else + _odsPrinter << type; + } + _odsPrinter << ")"; + ::llvm::SmallVector<::llvm::StringRef, 2> elidedAttrs; + _odsPrinter.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs); + _odsPrinter << ' ' << "to"; + _odsPrinter << ' '; + _odsPrinter << getOperation()->getResultTypes(); +} + +void TransposeOp::getEffects(::llvm::SmallVectorImpl<::mlir::SideEffects::EffectInstance<::mlir::MemoryEffects::Effect>> &effects) { +} + +} // namespace toy +} // namespace mlir +MLIR_DEFINE_EXPLICIT_TYPE_ID(::mlir::toy::TransposeOp) + + +#endif // GET_OP_CLASSES + diff --git a/Ch7/include/toy/Ops.h.inc b/Ch7/include/toy/Ops.h.inc new file mode 100644 index 0000000..70c176e --- /dev/null +++ b/Ch7/include/toy/Ops.h.inc @@ -0,0 +1,1698 @@ +/*===- TableGen'erated file -------------------------------------*- C++ -*-===*\ +|* *| +|* Op Declarations *| +|* *| +|* Automatically generated file, do not edit! *| +|* From: Ops.td *| +|* *| +\*===----------------------------------------------------------------------===*/ + +#if defined(GET_OP_CLASSES) || defined(GET_OP_FWD_DEFINES) +#undef GET_OP_FWD_DEFINES +namespace mlir { +namespace toy { +class AddOp; +} // namespace toy +} // namespace mlir +namespace mlir { +namespace toy { +class CastOp; +} // namespace toy +} // namespace mlir +namespace mlir { +namespace toy { +class ConstantOp; +} // namespace toy +} // namespace mlir +namespace mlir { +namespace toy { +class FuncOp; +} // namespace toy +} // namespace mlir +namespace mlir { +namespace toy { +class GenericCallOp; +} // namespace toy +} // namespace mlir +namespace mlir { +namespace toy { +class MulOp; +} // namespace toy +} // namespace mlir +namespace mlir { +namespace toy { +class PrintOp; +} // namespace toy +} // namespace mlir +namespace mlir { +namespace toy { +class ReshapeOp; +} // namespace toy +} // namespace mlir +namespace mlir { +namespace toy { +class ReturnOp; +} // namespace toy +} // namespace mlir +namespace mlir { +namespace toy { +class StructAccessOp; +} // namespace toy +} // namespace mlir +namespace mlir { +namespace toy { +class StructConstantOp; +} // namespace toy +} // namespace mlir +namespace mlir { +namespace toy { +class TransposeOp; +} // namespace toy +} // namespace mlir +#endif + +#ifdef GET_OP_CLASSES +#undef GET_OP_CLASSES + + +//===----------------------------------------------------------------------===// +// Local Utility Method Definitions +//===----------------------------------------------------------------------===// + +namespace mlir { +namespace toy { + +//===----------------------------------------------------------------------===// +// ::mlir::toy::AddOp declarations +//===----------------------------------------------------------------------===// + +namespace detail { +class AddOpGenericAdaptorBase { +public: +protected: + ::mlir::DictionaryAttr odsAttrs; + ::std::optional<::mlir::OperationName> odsOpName; + ::mlir::RegionRange odsRegions; +public: + AddOpGenericAdaptorBase(::mlir::DictionaryAttr attrs = nullptr, const ::mlir::EmptyProperties &properties = {}, ::mlir::RegionRange regions = {}); + + AddOpGenericAdaptorBase(AddOp op); + + std::pair getODSOperandIndexAndLength(unsigned index, unsigned odsOperandsSize); + ::mlir::DictionaryAttr getAttributes(); +}; +} // namespace detail +template +class AddOpGenericAdaptor : public detail::AddOpGenericAdaptorBase { + using ValueT = ::llvm::detail::ValueOfRange; + using Base = detail::AddOpGenericAdaptorBase; +public: + AddOpGenericAdaptor(RangeT values, ::mlir::DictionaryAttr attrs = nullptr, const ::mlir::EmptyProperties &properties = {}, ::mlir::RegionRange regions = {}) : Base(attrs, properties, regions), odsOperands(values) {} + + AddOpGenericAdaptor(RangeT values, ::mlir::DictionaryAttr attrs, ::mlir::OpaqueProperties properties, ::mlir::RegionRange regions = {}) : AddOpGenericAdaptor(values, attrs, (properties ? *properties.as<::mlir::EmptyProperties *>() : ::mlir::EmptyProperties{}), regions) {} + + template >> + AddOpGenericAdaptor(RangeT values, LateInst op) : Base(op), odsOperands(values) {} + + std::pair getODSOperandIndexAndLength(unsigned index) { + return Base::getODSOperandIndexAndLength(index, odsOperands.size()); + } + + RangeT getODSOperands(unsigned index) { + auto valueRange = getODSOperandIndexAndLength(index); + return {std::next(odsOperands.begin(), valueRange.first), + std::next(odsOperands.begin(), valueRange.first + valueRange.second)}; + } + + ValueT getLhs() { + return (*getODSOperands(0).begin()); + } + + ValueT getRhs() { + return (*getODSOperands(1).begin()); + } + + RangeT getOperands() { + return odsOperands; + } + +private: + RangeT odsOperands; +}; +class AddOpAdaptor : public AddOpGenericAdaptor<::mlir::ValueRange> { +public: + using AddOpGenericAdaptor::AddOpGenericAdaptor; + AddOpAdaptor(AddOp op); + + ::mlir::LogicalResult verify(::mlir::Location loc); +}; +class AddOp : public ::mlir::Op::Impl, ::mlir::OpTrait::ZeroSuccessors, ::mlir::OpTrait::NOperands<2>::Impl, ::mlir::OpTrait::OpInvariants, ::mlir::ConditionallySpeculatable::Trait, ::mlir::OpTrait::AlwaysSpeculatableImplTrait, ::mlir::MemoryEffectOpInterface::Trait, ShapeInference::Trait> { +public: + using Op::Op; + using Op::print; + using Adaptor = AddOpAdaptor; + template + using GenericAdaptor = AddOpGenericAdaptor; + using FoldAdaptor = GenericAdaptor<::llvm::ArrayRef<::mlir::Attribute>>; + static ::llvm::ArrayRef<::llvm::StringRef> getAttributeNames() { + return {}; + } + + static constexpr ::llvm::StringLiteral getOperationName() { + return ::llvm::StringLiteral("toy.add"); + } + + std::pair getODSOperandIndexAndLength(unsigned index); + ::mlir::Operation::operand_range getODSOperands(unsigned index); + ::mlir::TypedValue<::mlir::TensorType> getLhs(); + ::mlir::TypedValue<::mlir::TensorType> getRhs(); + ::mlir::OpOperand &getLhsMutable(); + ::mlir::OpOperand &getRhsMutable(); + std::pair getODSResultIndexAndLength(unsigned index); + ::mlir::Operation::result_range getODSResults(unsigned index); + static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, Value lhs, Value rhs); + static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type resultType0, ::mlir::Value lhs, ::mlir::Value rhs); + static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::Value lhs, ::mlir::Value rhs); + static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {}); + static ::mlir::ParseResult parse(::mlir::OpAsmParser &parser, ::mlir::OperationState &result); + void print(::mlir::OpAsmPrinter &p); + ::mlir::LogicalResult verifyInvariantsImpl(); + ::mlir::LogicalResult verifyInvariants(); + void inferShapes(); + void getEffects(::llvm::SmallVectorImpl<::mlir::SideEffects::EffectInstance<::mlir::MemoryEffects::Effect>> &effects); +public: +}; +} // namespace toy +} // namespace mlir +MLIR_DECLARE_EXPLICIT_TYPE_ID(::mlir::toy::AddOp) + +namespace mlir { +namespace toy { + +//===----------------------------------------------------------------------===// +// ::mlir::toy::CastOp declarations +//===----------------------------------------------------------------------===// + +namespace detail { +class CastOpGenericAdaptorBase { +public: +protected: + ::mlir::DictionaryAttr odsAttrs; + ::std::optional<::mlir::OperationName> odsOpName; + ::mlir::RegionRange odsRegions; +public: + CastOpGenericAdaptorBase(::mlir::DictionaryAttr attrs = nullptr, const ::mlir::EmptyProperties &properties = {}, ::mlir::RegionRange regions = {}); + + CastOpGenericAdaptorBase(CastOp op); + + std::pair getODSOperandIndexAndLength(unsigned index, unsigned odsOperandsSize); + ::mlir::DictionaryAttr getAttributes(); +}; +} // namespace detail +template +class CastOpGenericAdaptor : public detail::CastOpGenericAdaptorBase { + using ValueT = ::llvm::detail::ValueOfRange; + using Base = detail::CastOpGenericAdaptorBase; +public: + CastOpGenericAdaptor(RangeT values, ::mlir::DictionaryAttr attrs = nullptr, const ::mlir::EmptyProperties &properties = {}, ::mlir::RegionRange regions = {}) : Base(attrs, properties, regions), odsOperands(values) {} + + CastOpGenericAdaptor(RangeT values, ::mlir::DictionaryAttr attrs, ::mlir::OpaqueProperties properties, ::mlir::RegionRange regions = {}) : CastOpGenericAdaptor(values, attrs, (properties ? *properties.as<::mlir::EmptyProperties *>() : ::mlir::EmptyProperties{}), regions) {} + + template >> + CastOpGenericAdaptor(RangeT values, LateInst op) : Base(op), odsOperands(values) {} + + std::pair getODSOperandIndexAndLength(unsigned index) { + return Base::getODSOperandIndexAndLength(index, odsOperands.size()); + } + + RangeT getODSOperands(unsigned index) { + auto valueRange = getODSOperandIndexAndLength(index); + return {std::next(odsOperands.begin(), valueRange.first), + std::next(odsOperands.begin(), valueRange.first + valueRange.second)}; + } + + ValueT getInput() { + return (*getODSOperands(0).begin()); + } + + RangeT getOperands() { + return odsOperands; + } + +private: + RangeT odsOperands; +}; +class CastOpAdaptor : public CastOpGenericAdaptor<::mlir::ValueRange> { +public: + using CastOpGenericAdaptor::CastOpGenericAdaptor; + CastOpAdaptor(CastOp op); + + ::mlir::LogicalResult verify(::mlir::Location loc); +}; +class CastOp : public ::mlir::Op::Impl, ::mlir::OpTrait::ZeroSuccessors, ::mlir::OpTrait::OneOperand, ::mlir::OpTrait::OpInvariants, ::mlir::CastOpInterface::Trait, ShapeInference::Trait, ::mlir::ConditionallySpeculatable::Trait, ::mlir::OpTrait::AlwaysSpeculatableImplTrait, ::mlir::MemoryEffectOpInterface::Trait, ::mlir::OpTrait::SameOperandsAndResultShape> { +public: + using Op::Op; + using Op::print; + using Adaptor = CastOpAdaptor; + template + using GenericAdaptor = CastOpGenericAdaptor; + using FoldAdaptor = GenericAdaptor<::llvm::ArrayRef<::mlir::Attribute>>; + static ::llvm::ArrayRef<::llvm::StringRef> getAttributeNames() { + return {}; + } + + static constexpr ::llvm::StringLiteral getOperationName() { + return ::llvm::StringLiteral("toy.cast"); + } + + std::pair getODSOperandIndexAndLength(unsigned index); + ::mlir::Operation::operand_range getODSOperands(unsigned index); + ::mlir::TypedValue<::mlir::TensorType> getInput(); + ::mlir::OpOperand &getInputMutable(); + std::pair getODSResultIndexAndLength(unsigned index); + ::mlir::Operation::result_range getODSResults(unsigned index); + ::mlir::TypedValue<::mlir::TensorType> getOutput(); + static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type output, ::mlir::Value input); + static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::Value input); + static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {}); + ::mlir::LogicalResult verifyInvariantsImpl(); + ::mlir::LogicalResult verifyInvariants(); + static bool areCastCompatible(::mlir::TypeRange inputs, ::mlir::TypeRange outputs); + void inferShapes(); + static ::mlir::ParseResult parse(::mlir::OpAsmParser &parser, ::mlir::OperationState &result); + void print(::mlir::OpAsmPrinter &_odsPrinter); + void getEffects(::llvm::SmallVectorImpl<::mlir::SideEffects::EffectInstance<::mlir::MemoryEffects::Effect>> &effects); +public: +}; +} // namespace toy +} // namespace mlir +MLIR_DECLARE_EXPLICIT_TYPE_ID(::mlir::toy::CastOp) + +namespace mlir { +namespace toy { + +//===----------------------------------------------------------------------===// +// ::mlir::toy::ConstantOp declarations +//===----------------------------------------------------------------------===// + +namespace detail { +class ConstantOpGenericAdaptorBase { +public: + struct Properties { + using valueTy = ::mlir::DenseElementsAttr; + valueTy value; + + auto getValue() { + auto &propStorage = this->value; + return ::llvm::cast<::mlir::DenseElementsAttr>(propStorage); + } + void setValue(const ::mlir::DenseElementsAttr &propValue) { + this->value = propValue; + } + bool operator==(const Properties &rhs) const { + return + rhs.value == this->value && + true; + } + bool operator!=(const Properties &rhs) const { + return !(*this == rhs); + } + }; +protected: + ::mlir::DictionaryAttr odsAttrs; + ::std::optional<::mlir::OperationName> odsOpName; + Properties properties; + ::mlir::RegionRange odsRegions; +public: + ConstantOpGenericAdaptorBase(::mlir::DictionaryAttr attrs = nullptr, const Properties &properties = {}, ::mlir::RegionRange regions = {}); + + ConstantOpGenericAdaptorBase(ConstantOp op); + + std::pair getODSOperandIndexAndLength(unsigned index, unsigned odsOperandsSize); + const Properties &getProperties() { + return properties; + } + + ::mlir::DictionaryAttr getAttributes(); + ::mlir::DenseElementsAttr getValueAttr(); + ::mlir::DenseElementsAttr getValue(); +}; +} // namespace detail +template +class ConstantOpGenericAdaptor : public detail::ConstantOpGenericAdaptorBase { + using ValueT = ::llvm::detail::ValueOfRange; + using Base = detail::ConstantOpGenericAdaptorBase; +public: + ConstantOpGenericAdaptor(RangeT values, ::mlir::DictionaryAttr attrs = nullptr, const Properties &properties = {}, ::mlir::RegionRange regions = {}) : Base(attrs, properties, regions), odsOperands(values) {} + + ConstantOpGenericAdaptor(RangeT values, ::mlir::DictionaryAttr attrs, ::mlir::OpaqueProperties properties, ::mlir::RegionRange regions = {}) : ConstantOpGenericAdaptor(values, attrs, (properties ? *properties.as() : Properties{}), regions) {} + + template >> + ConstantOpGenericAdaptor(RangeT values, LateInst op) : Base(op), odsOperands(values) {} + + std::pair getODSOperandIndexAndLength(unsigned index) { + return Base::getODSOperandIndexAndLength(index, odsOperands.size()); + } + + RangeT getODSOperands(unsigned index) { + auto valueRange = getODSOperandIndexAndLength(index); + return {std::next(odsOperands.begin(), valueRange.first), + std::next(odsOperands.begin(), valueRange.first + valueRange.second)}; + } + + RangeT getOperands() { + return odsOperands; + } + +private: + RangeT odsOperands; +}; +class ConstantOpAdaptor : public ConstantOpGenericAdaptor<::mlir::ValueRange> { +public: + using ConstantOpGenericAdaptor::ConstantOpGenericAdaptor; + ConstantOpAdaptor(ConstantOp op); + + ::mlir::LogicalResult verify(::mlir::Location loc); +}; +class ConstantOp : public ::mlir::Op::Impl, ::mlir::OpTrait::ZeroSuccessors, ::mlir::OpTrait::ZeroOperands, ::mlir::OpTrait::OpInvariants, ::mlir::BytecodeOpInterface::Trait, ::mlir::OpTrait::ConstantLike, ::mlir::ConditionallySpeculatable::Trait, ::mlir::OpTrait::AlwaysSpeculatableImplTrait, ::mlir::MemoryEffectOpInterface::Trait, ShapeInference::Trait> { +public: + using Op::Op; + using Op::print; + using Adaptor = ConstantOpAdaptor; + template + using GenericAdaptor = ConstantOpGenericAdaptor; + using FoldAdaptor = GenericAdaptor<::llvm::ArrayRef<::mlir::Attribute>>; + using Properties = FoldAdaptor::Properties; + static ::llvm::ArrayRef<::llvm::StringRef> getAttributeNames() { + static ::llvm::StringRef attrNames[] = {::llvm::StringRef("value")}; + return ::llvm::ArrayRef(attrNames); + } + + ::mlir::StringAttr getValueAttrName() { + return getAttributeNameForIndex(0); + } + + static ::mlir::StringAttr getValueAttrName(::mlir::OperationName name) { + return getAttributeNameForIndex(name, 0); + } + + static constexpr ::llvm::StringLiteral getOperationName() { + return ::llvm::StringLiteral("toy.constant"); + } + + std::pair getODSOperandIndexAndLength(unsigned index); + ::mlir::Operation::operand_range getODSOperands(unsigned index); + std::pair getODSResultIndexAndLength(unsigned index); + ::mlir::Operation::result_range getODSResults(unsigned index); + static ::mlir::LogicalResult setPropertiesFromAttr(Properties &prop, ::mlir::Attribute attr, ::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError); + static ::mlir::Attribute getPropertiesAsAttr(::mlir::MLIRContext *ctx, const Properties &prop); + static llvm::hash_code computePropertiesHash(const Properties &prop); + static std::optional getInherentAttr(::mlir::MLIRContext *ctx, const Properties &prop, llvm::StringRef name); + static void setInherentAttr(Properties &prop, llvm::StringRef name, mlir::Attribute value); + static void populateInherentAttrs(::mlir::MLIRContext *ctx, const Properties &prop, ::mlir::NamedAttrList &attrs); + static ::mlir::LogicalResult verifyInherentAttrs(::mlir::OperationName opName, ::mlir::NamedAttrList &attrs, llvm::function_ref<::mlir::InFlightDiagnostic()> emitError); + static ::mlir::LogicalResult readProperties(::mlir::DialectBytecodeReader &reader, ::mlir::OperationState &state); + void writeProperties(::mlir::DialectBytecodeWriter &writer); + ::mlir::DenseElementsAttr getValueAttr(); + ::mlir::DenseElementsAttr getValue(); + void setValueAttr(::mlir::DenseElementsAttr attr); + static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, DenseElementsAttr value); + static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, double value); + static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type resultType0, ::mlir::DenseElementsAttr value); + static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::DenseElementsAttr value); + static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {}); + static ::mlir::ParseResult parse(::mlir::OpAsmParser &parser, ::mlir::OperationState &result); + void print(::mlir::OpAsmPrinter &p); + ::mlir::LogicalResult verifyInvariantsImpl(); + ::mlir::LogicalResult verifyInvariants(); + ::mlir::LogicalResult verify(); + ::mlir::OpFoldResult fold(FoldAdaptor adaptor); + void inferShapes(); + void getEffects(::llvm::SmallVectorImpl<::mlir::SideEffects::EffectInstance<::mlir::MemoryEffects::Effect>> &effects); +private: + ::mlir::StringAttr getAttributeNameForIndex(unsigned index) { + return getAttributeNameForIndex((*this)->getName(), index); + } + + static ::mlir::StringAttr getAttributeNameForIndex(::mlir::OperationName name, unsigned index) { + assert(index < 1 && "invalid attribute index"); + assert(name.getStringRef() == getOperationName() && "invalid operation name"); + assert(name.isRegistered() && "Operation isn't registered, missing a " + "dependent dialect loading?"); + return name.getAttributeNames()[index]; + } + +public: +}; +} // namespace toy +} // namespace mlir +MLIR_DECLARE_EXPLICIT_TYPE_ID(::mlir::toy::ConstantOp) + +namespace mlir { +namespace toy { + +//===----------------------------------------------------------------------===// +// ::mlir::toy::FuncOp declarations +//===----------------------------------------------------------------------===// + +namespace detail { +class FuncOpGenericAdaptorBase { +public: + struct Properties { + using arg_attrsTy = ::mlir::ArrayAttr; + arg_attrsTy arg_attrs; + + auto getArgAttrs() { + auto &propStorage = this->arg_attrs; + return ::llvm::dyn_cast_or_null<::mlir::ArrayAttr>(propStorage); + } + void setArgAttrs(const ::mlir::ArrayAttr &propValue) { + this->arg_attrs = propValue; + } + using function_typeTy = ::mlir::TypeAttr; + function_typeTy function_type; + + auto getFunctionType() { + auto &propStorage = this->function_type; + return ::llvm::cast<::mlir::TypeAttr>(propStorage); + } + void setFunctionType(const ::mlir::TypeAttr &propValue) { + this->function_type = propValue; + } + using res_attrsTy = ::mlir::ArrayAttr; + res_attrsTy res_attrs; + + auto getResAttrs() { + auto &propStorage = this->res_attrs; + return ::llvm::dyn_cast_or_null<::mlir::ArrayAttr>(propStorage); + } + void setResAttrs(const ::mlir::ArrayAttr &propValue) { + this->res_attrs = propValue; + } + using sym_nameTy = ::mlir::StringAttr; + sym_nameTy sym_name; + + auto getSymName() { + auto &propStorage = this->sym_name; + return ::llvm::cast<::mlir::StringAttr>(propStorage); + } + void setSymName(const ::mlir::StringAttr &propValue) { + this->sym_name = propValue; + } + bool operator==(const Properties &rhs) const { + return + rhs.arg_attrs == this->arg_attrs && + rhs.function_type == this->function_type && + rhs.res_attrs == this->res_attrs && + rhs.sym_name == this->sym_name && + true; + } + bool operator!=(const Properties &rhs) const { + return !(*this == rhs); + } + }; +protected: + ::mlir::DictionaryAttr odsAttrs; + ::std::optional<::mlir::OperationName> odsOpName; + Properties properties; + ::mlir::RegionRange odsRegions; +public: + FuncOpGenericAdaptorBase(::mlir::DictionaryAttr attrs = nullptr, const Properties &properties = {}, ::mlir::RegionRange regions = {}); + + FuncOpGenericAdaptorBase(FuncOp op); + + std::pair getODSOperandIndexAndLength(unsigned index, unsigned odsOperandsSize); + const Properties &getProperties() { + return properties; + } + + ::mlir::DictionaryAttr getAttributes(); + ::mlir::StringAttr getSymNameAttr(); + ::llvm::StringRef getSymName(); + ::mlir::TypeAttr getFunctionTypeAttr(); + ::mlir::FunctionType getFunctionType(); + ::mlir::ArrayAttr getArgAttrsAttr(); + ::std::optional< ::mlir::ArrayAttr > getArgAttrs(); + ::mlir::ArrayAttr getResAttrsAttr(); + ::std::optional< ::mlir::ArrayAttr > getResAttrs(); + ::mlir::Region &getBody(); + ::mlir::RegionRange getRegions(); +}; +} // namespace detail +template +class FuncOpGenericAdaptor : public detail::FuncOpGenericAdaptorBase { + using ValueT = ::llvm::detail::ValueOfRange; + using Base = detail::FuncOpGenericAdaptorBase; +public: + FuncOpGenericAdaptor(RangeT values, ::mlir::DictionaryAttr attrs = nullptr, const Properties &properties = {}, ::mlir::RegionRange regions = {}) : Base(attrs, properties, regions), odsOperands(values) {} + + FuncOpGenericAdaptor(RangeT values, ::mlir::DictionaryAttr attrs, ::mlir::OpaqueProperties properties, ::mlir::RegionRange regions = {}) : FuncOpGenericAdaptor(values, attrs, (properties ? *properties.as() : Properties{}), regions) {} + + template >> + FuncOpGenericAdaptor(RangeT values, LateInst op) : Base(op), odsOperands(values) {} + + std::pair getODSOperandIndexAndLength(unsigned index) { + return Base::getODSOperandIndexAndLength(index, odsOperands.size()); + } + + RangeT getODSOperands(unsigned index) { + auto valueRange = getODSOperandIndexAndLength(index); + return {std::next(odsOperands.begin(), valueRange.first), + std::next(odsOperands.begin(), valueRange.first + valueRange.second)}; + } + + RangeT getOperands() { + return odsOperands; + } + +private: + RangeT odsOperands; +}; +class FuncOpAdaptor : public FuncOpGenericAdaptor<::mlir::ValueRange> { +public: + using FuncOpGenericAdaptor::FuncOpGenericAdaptor; + FuncOpAdaptor(FuncOp op); + + ::mlir::LogicalResult verify(::mlir::Location loc); +}; +class FuncOp : public ::mlir::Op { +public: + using Op::Op; + using Op::print; + using Adaptor = FuncOpAdaptor; + template + using GenericAdaptor = FuncOpGenericAdaptor; + using FoldAdaptor = GenericAdaptor<::llvm::ArrayRef<::mlir::Attribute>>; + using Properties = FoldAdaptor::Properties; + static ::llvm::ArrayRef<::llvm::StringRef> getAttributeNames() { + static ::llvm::StringRef attrNames[] = {::llvm::StringRef("arg_attrs"), ::llvm::StringRef("function_type"), ::llvm::StringRef("res_attrs"), ::llvm::StringRef("sym_name")}; + return ::llvm::ArrayRef(attrNames); + } + + ::mlir::StringAttr getArgAttrsAttrName() { + return getAttributeNameForIndex(0); + } + + static ::mlir::StringAttr getArgAttrsAttrName(::mlir::OperationName name) { + return getAttributeNameForIndex(name, 0); + } + + ::mlir::StringAttr getFunctionTypeAttrName() { + return getAttributeNameForIndex(1); + } + + static ::mlir::StringAttr getFunctionTypeAttrName(::mlir::OperationName name) { + return getAttributeNameForIndex(name, 1); + } + + ::mlir::StringAttr getResAttrsAttrName() { + return getAttributeNameForIndex(2); + } + + static ::mlir::StringAttr getResAttrsAttrName(::mlir::OperationName name) { + return getAttributeNameForIndex(name, 2); + } + + ::mlir::StringAttr getSymNameAttrName() { + return getAttributeNameForIndex(3); + } + + static ::mlir::StringAttr getSymNameAttrName(::mlir::OperationName name) { + return getAttributeNameForIndex(name, 3); + } + + static constexpr ::llvm::StringLiteral getOperationName() { + return ::llvm::StringLiteral("toy.func"); + } + + std::pair getODSOperandIndexAndLength(unsigned index); + ::mlir::Operation::operand_range getODSOperands(unsigned index); + std::pair getODSResultIndexAndLength(unsigned index); + ::mlir::Operation::result_range getODSResults(unsigned index); + ::mlir::Region &getBody(); + static ::mlir::LogicalResult setPropertiesFromAttr(Properties &prop, ::mlir::Attribute attr, ::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError); + static ::mlir::Attribute getPropertiesAsAttr(::mlir::MLIRContext *ctx, const Properties &prop); + static llvm::hash_code computePropertiesHash(const Properties &prop); + static std::optional getInherentAttr(::mlir::MLIRContext *ctx, const Properties &prop, llvm::StringRef name); + static void setInherentAttr(Properties &prop, llvm::StringRef name, mlir::Attribute value); + static void populateInherentAttrs(::mlir::MLIRContext *ctx, const Properties &prop, ::mlir::NamedAttrList &attrs); + static ::mlir::LogicalResult verifyInherentAttrs(::mlir::OperationName opName, ::mlir::NamedAttrList &attrs, llvm::function_ref<::mlir::InFlightDiagnostic()> emitError); + static ::mlir::LogicalResult readProperties(::mlir::DialectBytecodeReader &reader, ::mlir::OperationState &state); + void writeProperties(::mlir::DialectBytecodeWriter &writer); + ::mlir::StringAttr getSymNameAttr(); + ::llvm::StringRef getSymName(); + ::mlir::TypeAttr getFunctionTypeAttr(); + ::mlir::FunctionType getFunctionType(); + ::mlir::ArrayAttr getArgAttrsAttr(); + ::std::optional< ::mlir::ArrayAttr > getArgAttrs(); + ::mlir::ArrayAttr getResAttrsAttr(); + ::std::optional< ::mlir::ArrayAttr > getResAttrs(); + void setSymNameAttr(::mlir::StringAttr attr); + void setSymName(::llvm::StringRef attrValue); + void setFunctionTypeAttr(::mlir::TypeAttr attr); + void setFunctionType(::mlir::FunctionType attrValue); + void setArgAttrsAttr(::mlir::ArrayAttr attr); + void setResAttrsAttr(::mlir::ArrayAttr attr); + ::mlir::Attribute removeArgAttrsAttr(); + ::mlir::Attribute removeResAttrsAttr(); + static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, StringRef name, FunctionType type, ArrayRef attrs = {}); + static ::mlir::ParseResult parse(::mlir::OpAsmParser &parser, ::mlir::OperationState &result); + void print(::mlir::OpAsmPrinter &p); + ::mlir::LogicalResult verifyInvariantsImpl(); + ::mlir::LogicalResult verifyInvariants(); +private: + ::mlir::StringAttr getAttributeNameForIndex(unsigned index) { + return getAttributeNameForIndex((*this)->getName(), index); + } + + static ::mlir::StringAttr getAttributeNameForIndex(::mlir::OperationName name, unsigned index) { + assert(index < 4 && "invalid attribute index"); + assert(name.getStringRef() == getOperationName() && "invalid operation name"); + assert(name.isRegistered() && "Operation isn't registered, missing a " + "dependent dialect loading?"); + return name.getAttributeNames()[index]; + } + +public: + //===------------------------------------------------------------------===// + // FunctionOpInterface Methods + //===------------------------------------------------------------------===// + + /// Returns the argument types of this function. + ArrayRef getArgumentTypes() { return getFunctionType().getInputs(); } + + /// Returns the result types of this function. + ArrayRef getResultTypes() { return getFunctionType().getResults(); } + + Region *getCallableRegion() { return &getBody(); } +}; +} // namespace toy +} // namespace mlir +MLIR_DECLARE_EXPLICIT_TYPE_ID(::mlir::toy::FuncOp) + +namespace mlir { +namespace toy { + +//===----------------------------------------------------------------------===// +// ::mlir::toy::GenericCallOp declarations +//===----------------------------------------------------------------------===// + +namespace detail { +class GenericCallOpGenericAdaptorBase { +public: + struct Properties { + using calleeTy = ::mlir::FlatSymbolRefAttr; + calleeTy callee; + + auto getCallee() { + auto &propStorage = this->callee; + return ::llvm::cast<::mlir::FlatSymbolRefAttr>(propStorage); + } + void setCallee(const ::mlir::FlatSymbolRefAttr &propValue) { + this->callee = propValue; + } + bool operator==(const Properties &rhs) const { + return + rhs.callee == this->callee && + true; + } + bool operator!=(const Properties &rhs) const { + return !(*this == rhs); + } + }; +protected: + ::mlir::DictionaryAttr odsAttrs; + ::std::optional<::mlir::OperationName> odsOpName; + Properties properties; + ::mlir::RegionRange odsRegions; +public: + GenericCallOpGenericAdaptorBase(::mlir::DictionaryAttr attrs = nullptr, const Properties &properties = {}, ::mlir::RegionRange regions = {}); + + GenericCallOpGenericAdaptorBase(GenericCallOp op); + + std::pair getODSOperandIndexAndLength(unsigned index, unsigned odsOperandsSize); + const Properties &getProperties() { + return properties; + } + + ::mlir::DictionaryAttr getAttributes(); + ::mlir::FlatSymbolRefAttr getCalleeAttr(); + ::llvm::StringRef getCallee(); +}; +} // namespace detail +template +class GenericCallOpGenericAdaptor : public detail::GenericCallOpGenericAdaptorBase { + using ValueT = ::llvm::detail::ValueOfRange; + using Base = detail::GenericCallOpGenericAdaptorBase; +public: + GenericCallOpGenericAdaptor(RangeT values, ::mlir::DictionaryAttr attrs = nullptr, const Properties &properties = {}, ::mlir::RegionRange regions = {}) : Base(attrs, properties, regions), odsOperands(values) {} + + GenericCallOpGenericAdaptor(RangeT values, ::mlir::DictionaryAttr attrs, ::mlir::OpaqueProperties properties, ::mlir::RegionRange regions = {}) : GenericCallOpGenericAdaptor(values, attrs, (properties ? *properties.as() : Properties{}), regions) {} + + template >> + GenericCallOpGenericAdaptor(RangeT values, LateInst op) : Base(op), odsOperands(values) {} + + std::pair getODSOperandIndexAndLength(unsigned index) { + return Base::getODSOperandIndexAndLength(index, odsOperands.size()); + } + + RangeT getODSOperands(unsigned index) { + auto valueRange = getODSOperandIndexAndLength(index); + return {std::next(odsOperands.begin(), valueRange.first), + std::next(odsOperands.begin(), valueRange.first + valueRange.second)}; + } + + RangeT getInputs() { + return getODSOperands(0); + } + + RangeT getOperands() { + return odsOperands; + } + +private: + RangeT odsOperands; +}; +class GenericCallOpAdaptor : public GenericCallOpGenericAdaptor<::mlir::ValueRange> { +public: + using GenericCallOpGenericAdaptor::GenericCallOpGenericAdaptor; + GenericCallOpAdaptor(GenericCallOp op); + + ::mlir::LogicalResult verify(::mlir::Location loc); +}; +class GenericCallOp : public ::mlir::Op::Impl, ::mlir::OpTrait::ZeroSuccessors, ::mlir::OpTrait::VariadicOperands, ::mlir::OpTrait::OpInvariants, ::mlir::BytecodeOpInterface::Trait, ::mlir::CallOpInterface::Trait> { +public: + using Op::Op; + using Op::print; + using Adaptor = GenericCallOpAdaptor; + template + using GenericAdaptor = GenericCallOpGenericAdaptor; + using FoldAdaptor = GenericAdaptor<::llvm::ArrayRef<::mlir::Attribute>>; + using Properties = FoldAdaptor::Properties; + static ::llvm::ArrayRef<::llvm::StringRef> getAttributeNames() { + static ::llvm::StringRef attrNames[] = {::llvm::StringRef("callee")}; + return ::llvm::ArrayRef(attrNames); + } + + ::mlir::StringAttr getCalleeAttrName() { + return getAttributeNameForIndex(0); + } + + static ::mlir::StringAttr getCalleeAttrName(::mlir::OperationName name) { + return getAttributeNameForIndex(name, 0); + } + + static constexpr ::llvm::StringLiteral getOperationName() { + return ::llvm::StringLiteral("toy.generic_call"); + } + + std::pair getODSOperandIndexAndLength(unsigned index); + ::mlir::Operation::operand_range getODSOperands(unsigned index); + ::mlir::Operation::operand_range getInputs(); + ::mlir::MutableOperandRange getInputsMutable(); + std::pair getODSResultIndexAndLength(unsigned index); + ::mlir::Operation::result_range getODSResults(unsigned index); + static ::mlir::LogicalResult setPropertiesFromAttr(Properties &prop, ::mlir::Attribute attr, ::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError); + static ::mlir::Attribute getPropertiesAsAttr(::mlir::MLIRContext *ctx, const Properties &prop); + static llvm::hash_code computePropertiesHash(const Properties &prop); + static std::optional getInherentAttr(::mlir::MLIRContext *ctx, const Properties &prop, llvm::StringRef name); + static void setInherentAttr(Properties &prop, llvm::StringRef name, mlir::Attribute value); + static void populateInherentAttrs(::mlir::MLIRContext *ctx, const Properties &prop, ::mlir::NamedAttrList &attrs); + static ::mlir::LogicalResult verifyInherentAttrs(::mlir::OperationName opName, ::mlir::NamedAttrList &attrs, llvm::function_ref<::mlir::InFlightDiagnostic()> emitError); + static ::mlir::LogicalResult readProperties(::mlir::DialectBytecodeReader &reader, ::mlir::OperationState &state); + void writeProperties(::mlir::DialectBytecodeWriter &writer); + ::mlir::FlatSymbolRefAttr getCalleeAttr(); + ::llvm::StringRef getCallee(); + void setCalleeAttr(::mlir::FlatSymbolRefAttr attr); + void setCallee(::llvm::StringRef attrValue); + static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, StringRef callee, ArrayRef arguments); + static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type resultType0, ::mlir::FlatSymbolRefAttr callee, ::mlir::ValueRange inputs); + static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::FlatSymbolRefAttr callee, ::mlir::ValueRange inputs); + static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type resultType0, ::llvm::StringRef callee, ::mlir::ValueRange inputs); + static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::llvm::StringRef callee, ::mlir::ValueRange inputs); + static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {}); + ::mlir::LogicalResult verifyInvariantsImpl(); + ::mlir::LogicalResult verifyInvariants(); + ::mlir::CallInterfaceCallable getCallableForCallee(); + void setCalleeFromCallable(::mlir::CallInterfaceCallable callee); + ::mlir::Operation::operand_range getArgOperands(); + ::mlir::MutableOperandRange getArgOperandsMutable(); + static ::mlir::ParseResult parse(::mlir::OpAsmParser &parser, ::mlir::OperationState &result); + void print(::mlir::OpAsmPrinter &_odsPrinter); +private: + ::mlir::StringAttr getAttributeNameForIndex(unsigned index) { + return getAttributeNameForIndex((*this)->getName(), index); + } + + static ::mlir::StringAttr getAttributeNameForIndex(::mlir::OperationName name, unsigned index) { + assert(index < 1 && "invalid attribute index"); + assert(name.getStringRef() == getOperationName() && "invalid operation name"); + assert(name.isRegistered() && "Operation isn't registered, missing a " + "dependent dialect loading?"); + return name.getAttributeNames()[index]; + } + +public: +}; +} // namespace toy +} // namespace mlir +MLIR_DECLARE_EXPLICIT_TYPE_ID(::mlir::toy::GenericCallOp) + +namespace mlir { +namespace toy { + +//===----------------------------------------------------------------------===// +// ::mlir::toy::MulOp declarations +//===----------------------------------------------------------------------===// + +namespace detail { +class MulOpGenericAdaptorBase { +public: +protected: + ::mlir::DictionaryAttr odsAttrs; + ::std::optional<::mlir::OperationName> odsOpName; + ::mlir::RegionRange odsRegions; +public: + MulOpGenericAdaptorBase(::mlir::DictionaryAttr attrs = nullptr, const ::mlir::EmptyProperties &properties = {}, ::mlir::RegionRange regions = {}); + + MulOpGenericAdaptorBase(MulOp op); + + std::pair getODSOperandIndexAndLength(unsigned index, unsigned odsOperandsSize); + ::mlir::DictionaryAttr getAttributes(); +}; +} // namespace detail +template +class MulOpGenericAdaptor : public detail::MulOpGenericAdaptorBase { + using ValueT = ::llvm::detail::ValueOfRange; + using Base = detail::MulOpGenericAdaptorBase; +public: + MulOpGenericAdaptor(RangeT values, ::mlir::DictionaryAttr attrs = nullptr, const ::mlir::EmptyProperties &properties = {}, ::mlir::RegionRange regions = {}) : Base(attrs, properties, regions), odsOperands(values) {} + + MulOpGenericAdaptor(RangeT values, ::mlir::DictionaryAttr attrs, ::mlir::OpaqueProperties properties, ::mlir::RegionRange regions = {}) : MulOpGenericAdaptor(values, attrs, (properties ? *properties.as<::mlir::EmptyProperties *>() : ::mlir::EmptyProperties{}), regions) {} + + template >> + MulOpGenericAdaptor(RangeT values, LateInst op) : Base(op), odsOperands(values) {} + + std::pair getODSOperandIndexAndLength(unsigned index) { + return Base::getODSOperandIndexAndLength(index, odsOperands.size()); + } + + RangeT getODSOperands(unsigned index) { + auto valueRange = getODSOperandIndexAndLength(index); + return {std::next(odsOperands.begin(), valueRange.first), + std::next(odsOperands.begin(), valueRange.first + valueRange.second)}; + } + + ValueT getLhs() { + return (*getODSOperands(0).begin()); + } + + ValueT getRhs() { + return (*getODSOperands(1).begin()); + } + + RangeT getOperands() { + return odsOperands; + } + +private: + RangeT odsOperands; +}; +class MulOpAdaptor : public MulOpGenericAdaptor<::mlir::ValueRange> { +public: + using MulOpGenericAdaptor::MulOpGenericAdaptor; + MulOpAdaptor(MulOp op); + + ::mlir::LogicalResult verify(::mlir::Location loc); +}; +class MulOp : public ::mlir::Op::Impl, ::mlir::OpTrait::ZeroSuccessors, ::mlir::OpTrait::NOperands<2>::Impl, ::mlir::OpTrait::OpInvariants, ::mlir::ConditionallySpeculatable::Trait, ::mlir::OpTrait::AlwaysSpeculatableImplTrait, ::mlir::MemoryEffectOpInterface::Trait, ShapeInference::Trait> { +public: + using Op::Op; + using Op::print; + using Adaptor = MulOpAdaptor; + template + using GenericAdaptor = MulOpGenericAdaptor; + using FoldAdaptor = GenericAdaptor<::llvm::ArrayRef<::mlir::Attribute>>; + static ::llvm::ArrayRef<::llvm::StringRef> getAttributeNames() { + return {}; + } + + static constexpr ::llvm::StringLiteral getOperationName() { + return ::llvm::StringLiteral("toy.mul"); + } + + std::pair getODSOperandIndexAndLength(unsigned index); + ::mlir::Operation::operand_range getODSOperands(unsigned index); + ::mlir::TypedValue<::mlir::TensorType> getLhs(); + ::mlir::TypedValue<::mlir::TensorType> getRhs(); + ::mlir::OpOperand &getLhsMutable(); + ::mlir::OpOperand &getRhsMutable(); + std::pair getODSResultIndexAndLength(unsigned index); + ::mlir::Operation::result_range getODSResults(unsigned index); + static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, Value lhs, Value rhs); + static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type resultType0, ::mlir::Value lhs, ::mlir::Value rhs); + static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::Value lhs, ::mlir::Value rhs); + static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {}); + static ::mlir::ParseResult parse(::mlir::OpAsmParser &parser, ::mlir::OperationState &result); + void print(::mlir::OpAsmPrinter &p); + ::mlir::LogicalResult verifyInvariantsImpl(); + ::mlir::LogicalResult verifyInvariants(); + void inferShapes(); + void getEffects(::llvm::SmallVectorImpl<::mlir::SideEffects::EffectInstance<::mlir::MemoryEffects::Effect>> &effects); +public: +}; +} // namespace toy +} // namespace mlir +MLIR_DECLARE_EXPLICIT_TYPE_ID(::mlir::toy::MulOp) + +namespace mlir { +namespace toy { + +//===----------------------------------------------------------------------===// +// ::mlir::toy::PrintOp declarations +//===----------------------------------------------------------------------===// + +namespace detail { +class PrintOpGenericAdaptorBase { +public: +protected: + ::mlir::DictionaryAttr odsAttrs; + ::std::optional<::mlir::OperationName> odsOpName; + ::mlir::RegionRange odsRegions; +public: + PrintOpGenericAdaptorBase(::mlir::DictionaryAttr attrs = nullptr, const ::mlir::EmptyProperties &properties = {}, ::mlir::RegionRange regions = {}); + + PrintOpGenericAdaptorBase(PrintOp op); + + std::pair getODSOperandIndexAndLength(unsigned index, unsigned odsOperandsSize); + ::mlir::DictionaryAttr getAttributes(); +}; +} // namespace detail +template +class PrintOpGenericAdaptor : public detail::PrintOpGenericAdaptorBase { + using ValueT = ::llvm::detail::ValueOfRange; + using Base = detail::PrintOpGenericAdaptorBase; +public: + PrintOpGenericAdaptor(RangeT values, ::mlir::DictionaryAttr attrs = nullptr, const ::mlir::EmptyProperties &properties = {}, ::mlir::RegionRange regions = {}) : Base(attrs, properties, regions), odsOperands(values) {} + + PrintOpGenericAdaptor(RangeT values, ::mlir::DictionaryAttr attrs, ::mlir::OpaqueProperties properties, ::mlir::RegionRange regions = {}) : PrintOpGenericAdaptor(values, attrs, (properties ? *properties.as<::mlir::EmptyProperties *>() : ::mlir::EmptyProperties{}), regions) {} + + template >> + PrintOpGenericAdaptor(RangeT values, LateInst op) : Base(op), odsOperands(values) {} + + std::pair getODSOperandIndexAndLength(unsigned index) { + return Base::getODSOperandIndexAndLength(index, odsOperands.size()); + } + + RangeT getODSOperands(unsigned index) { + auto valueRange = getODSOperandIndexAndLength(index); + return {std::next(odsOperands.begin(), valueRange.first), + std::next(odsOperands.begin(), valueRange.first + valueRange.second)}; + } + + ValueT getInput() { + return (*getODSOperands(0).begin()); + } + + RangeT getOperands() { + return odsOperands; + } + +private: + RangeT odsOperands; +}; +class PrintOpAdaptor : public PrintOpGenericAdaptor<::mlir::ValueRange> { +public: + using PrintOpGenericAdaptor::PrintOpGenericAdaptor; + PrintOpAdaptor(PrintOp op); + + ::mlir::LogicalResult verify(::mlir::Location loc); +}; +class PrintOp : public ::mlir::Op { +public: + using Op::Op; + using Op::print; + using Adaptor = PrintOpAdaptor; + template + using GenericAdaptor = PrintOpGenericAdaptor; + using FoldAdaptor = GenericAdaptor<::llvm::ArrayRef<::mlir::Attribute>>; + static ::llvm::ArrayRef<::llvm::StringRef> getAttributeNames() { + return {}; + } + + static constexpr ::llvm::StringLiteral getOperationName() { + return ::llvm::StringLiteral("toy.print"); + } + + std::pair getODSOperandIndexAndLength(unsigned index); + ::mlir::Operation::operand_range getODSOperands(unsigned index); + ::mlir::Value getInput(); + ::mlir::OpOperand &getInputMutable(); + std::pair getODSResultIndexAndLength(unsigned index); + ::mlir::Operation::result_range getODSResults(unsigned index); + static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Value input); + static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::Value input); + static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {}); + ::mlir::LogicalResult verifyInvariantsImpl(); + ::mlir::LogicalResult verifyInvariants(); + static ::mlir::ParseResult parse(::mlir::OpAsmParser &parser, ::mlir::OperationState &result); + void print(::mlir::OpAsmPrinter &_odsPrinter); +public: +}; +} // namespace toy +} // namespace mlir +MLIR_DECLARE_EXPLICIT_TYPE_ID(::mlir::toy::PrintOp) + +namespace mlir { +namespace toy { + +//===----------------------------------------------------------------------===// +// ::mlir::toy::ReshapeOp declarations +//===----------------------------------------------------------------------===// + +namespace detail { +class ReshapeOpGenericAdaptorBase { +public: +protected: + ::mlir::DictionaryAttr odsAttrs; + ::std::optional<::mlir::OperationName> odsOpName; + ::mlir::RegionRange odsRegions; +public: + ReshapeOpGenericAdaptorBase(::mlir::DictionaryAttr attrs = nullptr, const ::mlir::EmptyProperties &properties = {}, ::mlir::RegionRange regions = {}); + + ReshapeOpGenericAdaptorBase(ReshapeOp op); + + std::pair getODSOperandIndexAndLength(unsigned index, unsigned odsOperandsSize); + ::mlir::DictionaryAttr getAttributes(); +}; +} // namespace detail +template +class ReshapeOpGenericAdaptor : public detail::ReshapeOpGenericAdaptorBase { + using ValueT = ::llvm::detail::ValueOfRange; + using Base = detail::ReshapeOpGenericAdaptorBase; +public: + ReshapeOpGenericAdaptor(RangeT values, ::mlir::DictionaryAttr attrs = nullptr, const ::mlir::EmptyProperties &properties = {}, ::mlir::RegionRange regions = {}) : Base(attrs, properties, regions), odsOperands(values) {} + + ReshapeOpGenericAdaptor(RangeT values, ::mlir::DictionaryAttr attrs, ::mlir::OpaqueProperties properties, ::mlir::RegionRange regions = {}) : ReshapeOpGenericAdaptor(values, attrs, (properties ? *properties.as<::mlir::EmptyProperties *>() : ::mlir::EmptyProperties{}), regions) {} + + template >> + ReshapeOpGenericAdaptor(RangeT values, LateInst op) : Base(op), odsOperands(values) {} + + std::pair getODSOperandIndexAndLength(unsigned index) { + return Base::getODSOperandIndexAndLength(index, odsOperands.size()); + } + + RangeT getODSOperands(unsigned index) { + auto valueRange = getODSOperandIndexAndLength(index); + return {std::next(odsOperands.begin(), valueRange.first), + std::next(odsOperands.begin(), valueRange.first + valueRange.second)}; + } + + ValueT getInput() { + return (*getODSOperands(0).begin()); + } + + RangeT getOperands() { + return odsOperands; + } + +private: + RangeT odsOperands; +}; +class ReshapeOpAdaptor : public ReshapeOpGenericAdaptor<::mlir::ValueRange> { +public: + using ReshapeOpGenericAdaptor::ReshapeOpGenericAdaptor; + ReshapeOpAdaptor(ReshapeOp op); + + ::mlir::LogicalResult verify(::mlir::Location loc); +}; +class ReshapeOp : public ::mlir::Op::Impl, ::mlir::OpTrait::ZeroSuccessors, ::mlir::OpTrait::OneOperand, ::mlir::OpTrait::OpInvariants, ::mlir::ConditionallySpeculatable::Trait, ::mlir::OpTrait::AlwaysSpeculatableImplTrait, ::mlir::MemoryEffectOpInterface::Trait> { +public: + using Op::Op; + using Op::print; + using Adaptor = ReshapeOpAdaptor; + template + using GenericAdaptor = ReshapeOpGenericAdaptor; + using FoldAdaptor = GenericAdaptor<::llvm::ArrayRef<::mlir::Attribute>>; + static ::llvm::ArrayRef<::llvm::StringRef> getAttributeNames() { + return {}; + } + + static constexpr ::llvm::StringLiteral getOperationName() { + return ::llvm::StringLiteral("toy.reshape"); + } + + std::pair getODSOperandIndexAndLength(unsigned index); + ::mlir::Operation::operand_range getODSOperands(unsigned index); + ::mlir::TypedValue<::mlir::TensorType> getInput(); + ::mlir::OpOperand &getInputMutable(); + std::pair getODSResultIndexAndLength(unsigned index); + ::mlir::Operation::result_range getODSResults(unsigned index); + static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type resultType0, ::mlir::Value input); + static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::Value input); + static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {}); + ::mlir::LogicalResult verifyInvariantsImpl(); + ::mlir::LogicalResult verifyInvariants(); + static void getCanonicalizationPatterns(::mlir::RewritePatternSet &results, ::mlir::MLIRContext *context); + static ::mlir::ParseResult parse(::mlir::OpAsmParser &parser, ::mlir::OperationState &result); + void print(::mlir::OpAsmPrinter &_odsPrinter); + void getEffects(::llvm::SmallVectorImpl<::mlir::SideEffects::EffectInstance<::mlir::MemoryEffects::Effect>> &effects); +public: +}; +} // namespace toy +} // namespace mlir +MLIR_DECLARE_EXPLICIT_TYPE_ID(::mlir::toy::ReshapeOp) + +namespace mlir { +namespace toy { + +//===----------------------------------------------------------------------===// +// ::mlir::toy::ReturnOp declarations +//===----------------------------------------------------------------------===// + +namespace detail { +class ReturnOpGenericAdaptorBase { +public: +protected: + ::mlir::DictionaryAttr odsAttrs; + ::std::optional<::mlir::OperationName> odsOpName; + ::mlir::RegionRange odsRegions; +public: + ReturnOpGenericAdaptorBase(::mlir::DictionaryAttr attrs = nullptr, const ::mlir::EmptyProperties &properties = {}, ::mlir::RegionRange regions = {}); + + ReturnOpGenericAdaptorBase(ReturnOp op); + + std::pair getODSOperandIndexAndLength(unsigned index, unsigned odsOperandsSize); + ::mlir::DictionaryAttr getAttributes(); +}; +} // namespace detail +template +class ReturnOpGenericAdaptor : public detail::ReturnOpGenericAdaptorBase { + using ValueT = ::llvm::detail::ValueOfRange; + using Base = detail::ReturnOpGenericAdaptorBase; +public: + ReturnOpGenericAdaptor(RangeT values, ::mlir::DictionaryAttr attrs = nullptr, const ::mlir::EmptyProperties &properties = {}, ::mlir::RegionRange regions = {}) : Base(attrs, properties, regions), odsOperands(values) {} + + ReturnOpGenericAdaptor(RangeT values, ::mlir::DictionaryAttr attrs, ::mlir::OpaqueProperties properties, ::mlir::RegionRange regions = {}) : ReturnOpGenericAdaptor(values, attrs, (properties ? *properties.as<::mlir::EmptyProperties *>() : ::mlir::EmptyProperties{}), regions) {} + + template >> + ReturnOpGenericAdaptor(RangeT values, LateInst op) : Base(op), odsOperands(values) {} + + std::pair getODSOperandIndexAndLength(unsigned index) { + return Base::getODSOperandIndexAndLength(index, odsOperands.size()); + } + + RangeT getODSOperands(unsigned index) { + auto valueRange = getODSOperandIndexAndLength(index); + return {std::next(odsOperands.begin(), valueRange.first), + std::next(odsOperands.begin(), valueRange.first + valueRange.second)}; + } + + RangeT getInput() { + return getODSOperands(0); + } + + RangeT getOperands() { + return odsOperands; + } + +private: + RangeT odsOperands; +}; +class ReturnOpAdaptor : public ReturnOpGenericAdaptor<::mlir::ValueRange> { +public: + using ReturnOpGenericAdaptor::ReturnOpGenericAdaptor; + ReturnOpAdaptor(ReturnOp op); + + ::mlir::LogicalResult verify(::mlir::Location loc); +}; +class ReturnOp : public ::mlir::Op::Impl, ::mlir::OpTrait::OpInvariants, ::mlir::ConditionallySpeculatable::Trait, ::mlir::OpTrait::AlwaysSpeculatableImplTrait, ::mlir::MemoryEffectOpInterface::Trait, ::mlir::OpTrait::IsTerminator> { +public: + using Op::Op; + using Op::print; + using Adaptor = ReturnOpAdaptor; + template + using GenericAdaptor = ReturnOpGenericAdaptor; + using FoldAdaptor = GenericAdaptor<::llvm::ArrayRef<::mlir::Attribute>>; + static ::llvm::ArrayRef<::llvm::StringRef> getAttributeNames() { + return {}; + } + + static constexpr ::llvm::StringLiteral getOperationName() { + return ::llvm::StringLiteral("toy.return"); + } + + std::pair getODSOperandIndexAndLength(unsigned index); + ::mlir::Operation::operand_range getODSOperands(unsigned index); + ::mlir::Operation::operand_range getInput(); + ::mlir::MutableOperandRange getInputMutable(); + std::pair getODSResultIndexAndLength(unsigned index); + ::mlir::Operation::result_range getODSResults(unsigned index); + static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState); + static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::ValueRange input); + static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {}); + ::mlir::LogicalResult verifyInvariantsImpl(); + ::mlir::LogicalResult verifyInvariants(); + ::mlir::LogicalResult verify(); + static ::mlir::ParseResult parse(::mlir::OpAsmParser &parser, ::mlir::OperationState &result); + void print(::mlir::OpAsmPrinter &_odsPrinter); + void getEffects(::llvm::SmallVectorImpl<::mlir::SideEffects::EffectInstance<::mlir::MemoryEffects::Effect>> &effects); +public: + bool hasOperand() { return getNumOperands() != 0; } +}; +} // namespace toy +} // namespace mlir +MLIR_DECLARE_EXPLICIT_TYPE_ID(::mlir::toy::ReturnOp) + +namespace mlir { +namespace toy { + +//===----------------------------------------------------------------------===// +// ::mlir::toy::StructAccessOp declarations +//===----------------------------------------------------------------------===// + +namespace detail { +class StructAccessOpGenericAdaptorBase { +public: + struct Properties { + using indexTy = ::mlir::IntegerAttr; + indexTy index; + + auto getIndex() { + auto &propStorage = this->index; + return ::llvm::cast<::mlir::IntegerAttr>(propStorage); + } + void setIndex(const ::mlir::IntegerAttr &propValue) { + this->index = propValue; + } + bool operator==(const Properties &rhs) const { + return + rhs.index == this->index && + true; + } + bool operator!=(const Properties &rhs) const { + return !(*this == rhs); + } + }; +protected: + ::mlir::DictionaryAttr odsAttrs; + ::std::optional<::mlir::OperationName> odsOpName; + Properties properties; + ::mlir::RegionRange odsRegions; +public: + StructAccessOpGenericAdaptorBase(::mlir::DictionaryAttr attrs = nullptr, const Properties &properties = {}, ::mlir::RegionRange regions = {}); + + StructAccessOpGenericAdaptorBase(StructAccessOp op); + + std::pair getODSOperandIndexAndLength(unsigned index, unsigned odsOperandsSize); + const Properties &getProperties() { + return properties; + } + + ::mlir::DictionaryAttr getAttributes(); + ::mlir::IntegerAttr getIndexAttr(); + uint64_t getIndex(); +}; +} // namespace detail +template +class StructAccessOpGenericAdaptor : public detail::StructAccessOpGenericAdaptorBase { + using ValueT = ::llvm::detail::ValueOfRange; + using Base = detail::StructAccessOpGenericAdaptorBase; +public: + StructAccessOpGenericAdaptor(RangeT values, ::mlir::DictionaryAttr attrs = nullptr, const Properties &properties = {}, ::mlir::RegionRange regions = {}) : Base(attrs, properties, regions), odsOperands(values) {} + + StructAccessOpGenericAdaptor(RangeT values, ::mlir::DictionaryAttr attrs, ::mlir::OpaqueProperties properties, ::mlir::RegionRange regions = {}) : StructAccessOpGenericAdaptor(values, attrs, (properties ? *properties.as() : Properties{}), regions) {} + + template >> + StructAccessOpGenericAdaptor(RangeT values, LateInst op) : Base(op), odsOperands(values) {} + + std::pair getODSOperandIndexAndLength(unsigned index) { + return Base::getODSOperandIndexAndLength(index, odsOperands.size()); + } + + RangeT getODSOperands(unsigned index) { + auto valueRange = getODSOperandIndexAndLength(index); + return {std::next(odsOperands.begin(), valueRange.first), + std::next(odsOperands.begin(), valueRange.first + valueRange.second)}; + } + + ValueT getInput() { + return (*getODSOperands(0).begin()); + } + + RangeT getOperands() { + return odsOperands; + } + +private: + RangeT odsOperands; +}; +class StructAccessOpAdaptor : public StructAccessOpGenericAdaptor<::mlir::ValueRange> { +public: + using StructAccessOpGenericAdaptor::StructAccessOpGenericAdaptor; + StructAccessOpAdaptor(StructAccessOp op); + + ::mlir::LogicalResult verify(::mlir::Location loc); +}; +class StructAccessOp : public ::mlir::Op::Impl, ::mlir::OpTrait::ZeroSuccessors, ::mlir::OpTrait::OneOperand, ::mlir::OpTrait::OpInvariants, ::mlir::BytecodeOpInterface::Trait, ::mlir::ConditionallySpeculatable::Trait, ::mlir::OpTrait::AlwaysSpeculatableImplTrait, ::mlir::MemoryEffectOpInterface::Trait> { +public: + using Op::Op; + using Op::print; + using Adaptor = StructAccessOpAdaptor; + template + using GenericAdaptor = StructAccessOpGenericAdaptor; + using FoldAdaptor = GenericAdaptor<::llvm::ArrayRef<::mlir::Attribute>>; + using Properties = FoldAdaptor::Properties; + static ::llvm::ArrayRef<::llvm::StringRef> getAttributeNames() { + static ::llvm::StringRef attrNames[] = {::llvm::StringRef("index")}; + return ::llvm::ArrayRef(attrNames); + } + + ::mlir::StringAttr getIndexAttrName() { + return getAttributeNameForIndex(0); + } + + static ::mlir::StringAttr getIndexAttrName(::mlir::OperationName name) { + return getAttributeNameForIndex(name, 0); + } + + static constexpr ::llvm::StringLiteral getOperationName() { + return ::llvm::StringLiteral("toy.struct_access"); + } + + std::pair getODSOperandIndexAndLength(unsigned index); + ::mlir::Operation::operand_range getODSOperands(unsigned index); + ::mlir::Value getInput(); + ::mlir::OpOperand &getInputMutable(); + std::pair getODSResultIndexAndLength(unsigned index); + ::mlir::Operation::result_range getODSResults(unsigned index); + ::mlir::Value getOutput(); + static ::mlir::LogicalResult setPropertiesFromAttr(Properties &prop, ::mlir::Attribute attr, ::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError); + static ::mlir::Attribute getPropertiesAsAttr(::mlir::MLIRContext *ctx, const Properties &prop); + static llvm::hash_code computePropertiesHash(const Properties &prop); + static std::optional getInherentAttr(::mlir::MLIRContext *ctx, const Properties &prop, llvm::StringRef name); + static void setInherentAttr(Properties &prop, llvm::StringRef name, mlir::Attribute value); + static void populateInherentAttrs(::mlir::MLIRContext *ctx, const Properties &prop, ::mlir::NamedAttrList &attrs); + static ::mlir::LogicalResult verifyInherentAttrs(::mlir::OperationName opName, ::mlir::NamedAttrList &attrs, llvm::function_ref<::mlir::InFlightDiagnostic()> emitError); + static ::mlir::LogicalResult readProperties(::mlir::DialectBytecodeReader &reader, ::mlir::OperationState &state); + void writeProperties(::mlir::DialectBytecodeWriter &writer); + ::mlir::IntegerAttr getIndexAttr(); + uint64_t getIndex(); + void setIndexAttr(::mlir::IntegerAttr attr); + void setIndex(uint64_t attrValue); + static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, Value input, size_t index); + static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type output, ::mlir::Value input, ::mlir::IntegerAttr index); + static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::Value input, ::mlir::IntegerAttr index); + static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type output, ::mlir::Value input, uint64_t index); + static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::Value input, uint64_t index); + static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {}); + ::mlir::LogicalResult verifyInvariantsImpl(); + ::mlir::LogicalResult verifyInvariants(); + ::mlir::LogicalResult verify(); + ::mlir::OpFoldResult fold(FoldAdaptor adaptor); + static ::mlir::ParseResult parse(::mlir::OpAsmParser &parser, ::mlir::OperationState &result); + void print(::mlir::OpAsmPrinter &_odsPrinter); + void getEffects(::llvm::SmallVectorImpl<::mlir::SideEffects::EffectInstance<::mlir::MemoryEffects::Effect>> &effects); +private: + ::mlir::StringAttr getAttributeNameForIndex(unsigned index) { + return getAttributeNameForIndex((*this)->getName(), index); + } + + static ::mlir::StringAttr getAttributeNameForIndex(::mlir::OperationName name, unsigned index) { + assert(index < 1 && "invalid attribute index"); + assert(name.getStringRef() == getOperationName() && "invalid operation name"); + assert(name.isRegistered() && "Operation isn't registered, missing a " + "dependent dialect loading?"); + return name.getAttributeNames()[index]; + } + +public: +}; +} // namespace toy +} // namespace mlir +MLIR_DECLARE_EXPLICIT_TYPE_ID(::mlir::toy::StructAccessOp) + +namespace mlir { +namespace toy { + +//===----------------------------------------------------------------------===// +// ::mlir::toy::StructConstantOp declarations +//===----------------------------------------------------------------------===// + +namespace detail { +class StructConstantOpGenericAdaptorBase { +public: + struct Properties { + using valueTy = ::mlir::ArrayAttr; + valueTy value; + + auto getValue() { + auto &propStorage = this->value; + return ::llvm::cast<::mlir::ArrayAttr>(propStorage); + } + void setValue(const ::mlir::ArrayAttr &propValue) { + this->value = propValue; + } + bool operator==(const Properties &rhs) const { + return + rhs.value == this->value && + true; + } + bool operator!=(const Properties &rhs) const { + return !(*this == rhs); + } + }; +protected: + ::mlir::DictionaryAttr odsAttrs; + ::std::optional<::mlir::OperationName> odsOpName; + Properties properties; + ::mlir::RegionRange odsRegions; +public: + StructConstantOpGenericAdaptorBase(::mlir::DictionaryAttr attrs = nullptr, const Properties &properties = {}, ::mlir::RegionRange regions = {}); + + StructConstantOpGenericAdaptorBase(StructConstantOp op); + + std::pair getODSOperandIndexAndLength(unsigned index, unsigned odsOperandsSize); + const Properties &getProperties() { + return properties; + } + + ::mlir::DictionaryAttr getAttributes(); + ::mlir::ArrayAttr getValueAttr(); + ::mlir::ArrayAttr getValue(); +}; +} // namespace detail +template +class StructConstantOpGenericAdaptor : public detail::StructConstantOpGenericAdaptorBase { + using ValueT = ::llvm::detail::ValueOfRange; + using Base = detail::StructConstantOpGenericAdaptorBase; +public: + StructConstantOpGenericAdaptor(RangeT values, ::mlir::DictionaryAttr attrs = nullptr, const Properties &properties = {}, ::mlir::RegionRange regions = {}) : Base(attrs, properties, regions), odsOperands(values) {} + + StructConstantOpGenericAdaptor(RangeT values, ::mlir::DictionaryAttr attrs, ::mlir::OpaqueProperties properties, ::mlir::RegionRange regions = {}) : StructConstantOpGenericAdaptor(values, attrs, (properties ? *properties.as() : Properties{}), regions) {} + + template >> + StructConstantOpGenericAdaptor(RangeT values, LateInst op) : Base(op), odsOperands(values) {} + + std::pair getODSOperandIndexAndLength(unsigned index) { + return Base::getODSOperandIndexAndLength(index, odsOperands.size()); + } + + RangeT getODSOperands(unsigned index) { + auto valueRange = getODSOperandIndexAndLength(index); + return {std::next(odsOperands.begin(), valueRange.first), + std::next(odsOperands.begin(), valueRange.first + valueRange.second)}; + } + + RangeT getOperands() { + return odsOperands; + } + +private: + RangeT odsOperands; +}; +class StructConstantOpAdaptor : public StructConstantOpGenericAdaptor<::mlir::ValueRange> { +public: + using StructConstantOpGenericAdaptor::StructConstantOpGenericAdaptor; + StructConstantOpAdaptor(StructConstantOp op); + + ::mlir::LogicalResult verify(::mlir::Location loc); +}; +class StructConstantOp : public ::mlir::Op::Impl, ::mlir::OpTrait::ZeroSuccessors, ::mlir::OpTrait::ZeroOperands, ::mlir::OpTrait::OpInvariants, ::mlir::BytecodeOpInterface::Trait, ::mlir::OpTrait::ConstantLike, ::mlir::ConditionallySpeculatable::Trait, ::mlir::OpTrait::AlwaysSpeculatableImplTrait, ::mlir::MemoryEffectOpInterface::Trait> { +public: + using Op::Op; + using Op::print; + using Adaptor = StructConstantOpAdaptor; + template + using GenericAdaptor = StructConstantOpGenericAdaptor; + using FoldAdaptor = GenericAdaptor<::llvm::ArrayRef<::mlir::Attribute>>; + using Properties = FoldAdaptor::Properties; + static ::llvm::ArrayRef<::llvm::StringRef> getAttributeNames() { + static ::llvm::StringRef attrNames[] = {::llvm::StringRef("value")}; + return ::llvm::ArrayRef(attrNames); + } + + ::mlir::StringAttr getValueAttrName() { + return getAttributeNameForIndex(0); + } + + static ::mlir::StringAttr getValueAttrName(::mlir::OperationName name) { + return getAttributeNameForIndex(name, 0); + } + + static constexpr ::llvm::StringLiteral getOperationName() { + return ::llvm::StringLiteral("toy.struct_constant"); + } + + std::pair getODSOperandIndexAndLength(unsigned index); + ::mlir::Operation::operand_range getODSOperands(unsigned index); + std::pair getODSResultIndexAndLength(unsigned index); + ::mlir::Operation::result_range getODSResults(unsigned index); + ::mlir::Value getOutput(); + static ::mlir::LogicalResult setPropertiesFromAttr(Properties &prop, ::mlir::Attribute attr, ::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError); + static ::mlir::Attribute getPropertiesAsAttr(::mlir::MLIRContext *ctx, const Properties &prop); + static llvm::hash_code computePropertiesHash(const Properties &prop); + static std::optional getInherentAttr(::mlir::MLIRContext *ctx, const Properties &prop, llvm::StringRef name); + static void setInherentAttr(Properties &prop, llvm::StringRef name, mlir::Attribute value); + static void populateInherentAttrs(::mlir::MLIRContext *ctx, const Properties &prop, ::mlir::NamedAttrList &attrs); + static ::mlir::LogicalResult verifyInherentAttrs(::mlir::OperationName opName, ::mlir::NamedAttrList &attrs, llvm::function_ref<::mlir::InFlightDiagnostic()> emitError); + static ::mlir::LogicalResult readProperties(::mlir::DialectBytecodeReader &reader, ::mlir::OperationState &state); + void writeProperties(::mlir::DialectBytecodeWriter &writer); + ::mlir::ArrayAttr getValueAttr(); + ::mlir::ArrayAttr getValue(); + void setValueAttr(::mlir::ArrayAttr attr); + static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type output, ::mlir::ArrayAttr value); + static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ArrayAttr value); + static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {}); + ::mlir::LogicalResult verifyInvariantsImpl(); + ::mlir::LogicalResult verifyInvariants(); + ::mlir::LogicalResult verify(); + ::mlir::OpFoldResult fold(FoldAdaptor adaptor); + static ::mlir::ParseResult parse(::mlir::OpAsmParser &parser, ::mlir::OperationState &result); + void print(::mlir::OpAsmPrinter &_odsPrinter); + void getEffects(::llvm::SmallVectorImpl<::mlir::SideEffects::EffectInstance<::mlir::MemoryEffects::Effect>> &effects); +private: + ::mlir::StringAttr getAttributeNameForIndex(unsigned index) { + return getAttributeNameForIndex((*this)->getName(), index); + } + + static ::mlir::StringAttr getAttributeNameForIndex(::mlir::OperationName name, unsigned index) { + assert(index < 1 && "invalid attribute index"); + assert(name.getStringRef() == getOperationName() && "invalid operation name"); + assert(name.isRegistered() && "Operation isn't registered, missing a " + "dependent dialect loading?"); + return name.getAttributeNames()[index]; + } + +public: +}; +} // namespace toy +} // namespace mlir +MLIR_DECLARE_EXPLICIT_TYPE_ID(::mlir::toy::StructConstantOp) + +namespace mlir { +namespace toy { + +//===----------------------------------------------------------------------===// +// ::mlir::toy::TransposeOp declarations +//===----------------------------------------------------------------------===// + +namespace detail { +class TransposeOpGenericAdaptorBase { +public: +protected: + ::mlir::DictionaryAttr odsAttrs; + ::std::optional<::mlir::OperationName> odsOpName; + ::mlir::RegionRange odsRegions; +public: + TransposeOpGenericAdaptorBase(::mlir::DictionaryAttr attrs = nullptr, const ::mlir::EmptyProperties &properties = {}, ::mlir::RegionRange regions = {}); + + TransposeOpGenericAdaptorBase(TransposeOp op); + + std::pair getODSOperandIndexAndLength(unsigned index, unsigned odsOperandsSize); + ::mlir::DictionaryAttr getAttributes(); +}; +} // namespace detail +template +class TransposeOpGenericAdaptor : public detail::TransposeOpGenericAdaptorBase { + using ValueT = ::llvm::detail::ValueOfRange; + using Base = detail::TransposeOpGenericAdaptorBase; +public: + TransposeOpGenericAdaptor(RangeT values, ::mlir::DictionaryAttr attrs = nullptr, const ::mlir::EmptyProperties &properties = {}, ::mlir::RegionRange regions = {}) : Base(attrs, properties, regions), odsOperands(values) {} + + TransposeOpGenericAdaptor(RangeT values, ::mlir::DictionaryAttr attrs, ::mlir::OpaqueProperties properties, ::mlir::RegionRange regions = {}) : TransposeOpGenericAdaptor(values, attrs, (properties ? *properties.as<::mlir::EmptyProperties *>() : ::mlir::EmptyProperties{}), regions) {} + + template >> + TransposeOpGenericAdaptor(RangeT values, LateInst op) : Base(op), odsOperands(values) {} + + std::pair getODSOperandIndexAndLength(unsigned index) { + return Base::getODSOperandIndexAndLength(index, odsOperands.size()); + } + + RangeT getODSOperands(unsigned index) { + auto valueRange = getODSOperandIndexAndLength(index); + return {std::next(odsOperands.begin(), valueRange.first), + std::next(odsOperands.begin(), valueRange.first + valueRange.second)}; + } + + ValueT getInput() { + return (*getODSOperands(0).begin()); + } + + RangeT getOperands() { + return odsOperands; + } + +private: + RangeT odsOperands; +}; +class TransposeOpAdaptor : public TransposeOpGenericAdaptor<::mlir::ValueRange> { +public: + using TransposeOpGenericAdaptor::TransposeOpGenericAdaptor; + TransposeOpAdaptor(TransposeOp op); + + ::mlir::LogicalResult verify(::mlir::Location loc); +}; +class TransposeOp : public ::mlir::Op::Impl, ::mlir::OpTrait::ZeroSuccessors, ::mlir::OpTrait::OneOperand, ::mlir::OpTrait::OpInvariants, ::mlir::ConditionallySpeculatable::Trait, ::mlir::OpTrait::AlwaysSpeculatableImplTrait, ::mlir::MemoryEffectOpInterface::Trait, ShapeInference::Trait> { +public: + using Op::Op; + using Op::print; + using Adaptor = TransposeOpAdaptor; + template + using GenericAdaptor = TransposeOpGenericAdaptor; + using FoldAdaptor = GenericAdaptor<::llvm::ArrayRef<::mlir::Attribute>>; + static ::llvm::ArrayRef<::llvm::StringRef> getAttributeNames() { + return {}; + } + + static constexpr ::llvm::StringLiteral getOperationName() { + return ::llvm::StringLiteral("toy.transpose"); + } + + std::pair getODSOperandIndexAndLength(unsigned index); + ::mlir::Operation::operand_range getODSOperands(unsigned index); + ::mlir::TypedValue<::mlir::TensorType> getInput(); + ::mlir::OpOperand &getInputMutable(); + std::pair getODSResultIndexAndLength(unsigned index); + ::mlir::Operation::result_range getODSResults(unsigned index); + static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, Value input); + static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type resultType0, ::mlir::Value input); + static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::Value input); + static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {}); + ::mlir::LogicalResult verifyInvariantsImpl(); + ::mlir::LogicalResult verifyInvariants(); + ::mlir::LogicalResult verify(); + static void getCanonicalizationPatterns(::mlir::RewritePatternSet &results, ::mlir::MLIRContext *context); + void inferShapes(); + static ::mlir::ParseResult parse(::mlir::OpAsmParser &parser, ::mlir::OperationState &result); + void print(::mlir::OpAsmPrinter &_odsPrinter); + void getEffects(::llvm::SmallVectorImpl<::mlir::SideEffects::EffectInstance<::mlir::MemoryEffects::Effect>> &effects); +public: +}; +} // namespace toy +} // namespace mlir +MLIR_DECLARE_EXPLICIT_TYPE_ID(::mlir::toy::TransposeOp) + + +#endif // GET_OP_CLASSES + diff --git a/Ch7/include/toy/Ops.td b/Ch7/include/toy/Ops.td new file mode 100644 index 0000000..cfd6859 --- /dev/null +++ b/Ch7/include/toy/Ops.td @@ -0,0 +1,453 @@ +//===- Ops.td - Toy dialect operation definitions ----------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Defines the operations of the Toy dialect. +// +//===----------------------------------------------------------------------===// + +#ifndef TOY_OPS +#define TOY_OPS + +include "mlir/Interfaces/FunctionInterfaces.td" +include "mlir/IR/SymbolInterfaces.td" +include "mlir/Interfaces/CallInterfaces.td" +include "mlir/Interfaces/CastInterfaces.td" +include "mlir/Interfaces/SideEffectInterfaces.td" +include "toy/ShapeInferenceInterface.td" + +// Provide a definition of the 'toy' dialect in the ODS framework so that we +// can define our operations. +def Toy_Dialect : Dialect { + let name = "toy"; + let cppNamespace = "::mlir::toy"; + + // We set this bit to generate a declaration of the `materializeConstant` + // method so that we can materialize constants for our toy operations. + let hasConstantMaterializer = 1; + + // We set this bit to generate the declarations for the dialect's type parsing + // and printing hooks. + let useDefaultTypePrinterParser = 1; + +} + +// Base class for toy dialect operations. This operation inherits from the base +// `Op` class in OpBase.td, and provides: +// * The parent dialect of the operation. +// * The mnemonic for the operation, or the name without the dialect prefix. +// * A list of traits for the operation. +class Toy_Op traits = []> : + Op; + +// Provide a definition for the Toy StructType for use in ODS. This allows for +// using StructType in a similar way to Tensor or MemRef. We use `DialectType` +// to demarcate the StructType as belonging to the Toy dialect. +def Toy_StructType : + DialectType($_self)">, + "Toy struct type">; + +// Provide a definition of the types that are used within the Toy dialect. +def Toy_Type : AnyTypeOf<[F64Tensor, Toy_StructType]>; + +//===----------------------------------------------------------------------===// +// Toy Operations +//===----------------------------------------------------------------------===// + +//===----------------------------------------------------------------------===// +// ConstantOp +//===----------------------------------------------------------------------===// + +// We define a toy operation by inheriting from our base 'Toy_Op' class above. +// Here we provide the mnemonic and a list of traits for the operation. The +// constant operation is marked as 'Pure' as it is a pure operation +// and may be removed if dead. +def ConstantOp : Toy_Op<"constant", + [ConstantLike, Pure, + DeclareOpInterfaceMethods]> { + // Provide a summary and description for this operation. This can be used to + // auto-generate documentation of the operations within our dialect. + let summary = "constant"; + let description = [{ + Constant operation turns a literal into an SSA value. The data is attached + to the operation as an attribute. For example: + + ```mlir + %0 = toy.constant dense<[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]> + : tensor<2x3xf64> + ``` + }]; + + // The constant operation takes an attribute as the only input. + let arguments = (ins F64ElementsAttr:$value); + + // The constant operation returns a single value of TensorType. + let results = (outs F64Tensor); + + // Indicate that the operation has a custom parser and printer method. + let hasCustomAssemblyFormat = 1; + + // Add custom build methods for the constant operation. These method populates + // the `state` that MLIR uses to create operations, i.e. these are used when + // using `builder.create(...)`. + let builders = [ + // Build a constant with a given constant tensor value. + OpBuilder<(ins "DenseElementsAttr":$value), [{ + build($_builder, $_state, value.getType(), value); + }]>, + + // Build a constant with a given constant floating-point value. + OpBuilder<(ins "double":$value)> + ]; + + // Indicate that additional verification for this operation is necessary. + let hasVerifier = 1; + + // Set the folder bit so that we can implement constant folders. + let hasFolder = 1; +} + +//===----------------------------------------------------------------------===// +// AddOp +//===----------------------------------------------------------------------===// + +def AddOp : Toy_Op<"add", + [Pure, DeclareOpInterfaceMethods]> { + let summary = "element-wise addition operation"; + let description = [{ + The "add" operation performs element-wise addition between two tensors. + The shapes of the tensor operands are expected to match. + }]; + + let arguments = (ins F64Tensor:$lhs, F64Tensor:$rhs); + let results = (outs F64Tensor); + + // Indicate that the operation has a custom parser and printer method. + let hasCustomAssemblyFormat = 1; + + // Allow building an AddOp with from the two input operands. + let builders = [ + OpBuilder<(ins "Value":$lhs, "Value":$rhs)> + ]; +} + +//===----------------------------------------------------------------------===// +// CastOp +//===----------------------------------------------------------------------===// + +def CastOp : Toy_Op<"cast", [ + DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods, + Pure, + SameOperandsAndResultShape + ]> { + let summary = "shape cast operation"; + let description = [{ + The "cast" operation converts a tensor from one type to an equivalent type + without changing any data elements. The source and destination types must + both be tensor types with the same element type. If both are ranked, then + shape is required to match. The operation is invalid if converting to a + mismatching constant dimension. + }]; + + let arguments = (ins F64Tensor:$input); + let results = (outs F64Tensor:$output); + + let assemblyFormat = "$input attr-dict `:` type($input) `to` type($output)"; +} + +//===----------------------------------------------------------------------===// +// FuncOp +//===----------------------------------------------------------------------===// + +def FuncOp : Toy_Op<"func", [ + FunctionOpInterface, IsolatedFromAbove + ]> { + let summary = "user defined function operation"; + let description = [{ + The "toy.func" operation represents a user defined function. These are + callable SSA-region operations that contain toy computations. + + Example: + + ```mlir + toy.func @main() { + %0 = toy.constant dense<5.500000e+00> : tensor + %1 = toy.reshape(%0 : tensor) to tensor<2x2xf64> + toy.print %1 : tensor<2x2xf64> + toy.return + } + ``` + }]; + + let arguments = (ins + SymbolNameAttr:$sym_name, + TypeAttrOf:$function_type, + OptionalAttr:$arg_attrs, + OptionalAttr:$res_attrs + ); + let regions = (region AnyRegion:$body); + + let builders = [OpBuilder<(ins + "StringRef":$name, "FunctionType":$type, + CArg<"ArrayRef", "{}">:$attrs) + >]; + let extraClassDeclaration = [{ + //===------------------------------------------------------------------===// + // FunctionOpInterface Methods + //===------------------------------------------------------------------===// + + /// Returns the argument types of this function. + ArrayRef getArgumentTypes() { return getFunctionType().getInputs(); } + + /// Returns the result types of this function. + ArrayRef getResultTypes() { return getFunctionType().getResults(); } + + Region *getCallableRegion() { return &getBody(); } + }]; + let hasCustomAssemblyFormat = 1; + let skipDefaultBuilders = 1; +} + +//===----------------------------------------------------------------------===// +// GenericCallOp +//===----------------------------------------------------------------------===// + +def GenericCallOp : Toy_Op<"generic_call", + [DeclareOpInterfaceMethods]> { + let summary = "generic call operation"; + let description = [{ + Generic calls represent calls to a user defined function that needs to + be specialized for the shape of its arguments. The callee name is attached + as a symbol reference via an attribute. The arguments list must match the + arguments expected by the callee. For example: + + ```mlir + %4 = toy.generic_call @my_func(%1, %3) + : (tensor<2x3xf64>, tensor<2x3xf64>) -> tensor<*xf64> + ``` + + This is only valid if a function named "my_func" exists and takes two + arguments. + }]; + + // The generic call operation takes a symbol reference attribute as the + // callee, and inputs for the call. + let arguments = (ins FlatSymbolRefAttr:$callee, Variadic:$inputs); + + // The generic call operation returns a single value of TensorType or + // StructType. + let results = (outs Toy_Type); + + // Specialize assembly printing and parsing using a declarative format. + let assemblyFormat = [{ + $callee `(` $inputs `)` attr-dict `:` functional-type($inputs, results) + }]; + + // Add custom build methods for the generic call operation. + let builders = [ + OpBuilder<(ins "StringRef":$callee, "ArrayRef":$arguments)> + ]; +} + +//===----------------------------------------------------------------------===// +// MulOp +//===----------------------------------------------------------------------===// + +def MulOp : Toy_Op<"mul", + [Pure, DeclareOpInterfaceMethods]> { + let summary = "element-wise multiplication operation"; + let description = [{ + The "mul" operation performs element-wise multiplication between two + tensors. The shapes of the tensor operands are expected to match. + }]; + + let arguments = (ins F64Tensor:$lhs, F64Tensor:$rhs); + let results = (outs F64Tensor); + + // Indicate that the operation has a custom parser and printer method. + let hasCustomAssemblyFormat = 1; + + // Allow building a MulOp with from the two input operands. + let builders = [ + OpBuilder<(ins "Value":$lhs, "Value":$rhs)> + ]; +} + +//===----------------------------------------------------------------------===// +// PrintOp +//===----------------------------------------------------------------------===// + +def PrintOp : Toy_Op<"print"> { + let summary = "print operation"; + let description = [{ + The "print" builtin operation prints a given input tensor, and produces + no results. + }]; + + // The print operation takes an input tensor to print. + // We also allow a F64MemRef to enable interop during partial lowering. + let arguments = (ins AnyTypeOf<[F64Tensor, F64MemRef]>:$input); + + let assemblyFormat = "$input attr-dict `:` type($input)"; +} + +//===----------------------------------------------------------------------===// +// ReshapeOp +//===----------------------------------------------------------------------===// + +def ReshapeOp : Toy_Op<"reshape", [Pure]> { + let summary = "tensor reshape operation"; + let description = [{ + Reshape operation is transforming its input tensor into a new tensor with + the same number of elements but different shapes. For example: + + ```mlir + %0 = toy.reshape (%arg1 : tensor<10xf64>) to tensor<5x2xf64> + ``` + }]; + + let arguments = (ins F64Tensor:$input); + + let assemblyFormat = [{ + `(` $input `:` type($input) `)` attr-dict `to` type(results) + }]; + + // Enable registering canonicalization patterns with this operation. + let hasCanonicalizer = 1; + + // We expect that the reshape operation returns a statically shaped tensor. + let results = (outs StaticShapeTensorOf<[F64]>); +} + +//===----------------------------------------------------------------------===// +// ReturnOp +//===----------------------------------------------------------------------===// + +def ReturnOp : Toy_Op<"return", [Pure, HasParent<"FuncOp">, + Terminator]> { + let summary = "return operation"; + let description = [{ + The "return" operation represents a return operation within a function. + The operation takes an optional operand and produces no results. + The operand type must match the signature of the function that contains + the operation. For example: + + ```mlir + toy.func @foo() -> tensor<2xf64> { + ... + toy.return %0 : tensor<2xf64> + } + ``` + }]; + + // The return operation takes an optional input operand to return. This + // value must match the return type of the enclosing function. + let arguments = (ins Variadic:$input); + + // The return operation only emits the input in the format if it is present. + let assemblyFormat = "($input^ `:` type($input))? attr-dict "; + + // Allow building a ReturnOp with no return operand. + let builders = [ + OpBuilder<(ins), [{ build($_builder, $_state, std::nullopt); }]> + ]; + + // Provide extra utility definitions on the c++ operation class definition. + let extraClassDeclaration = [{ + bool hasOperand() { return getNumOperands() != 0; } + }]; + + // Indicate that additional verification for this operation is necessary. + let hasVerifier = 1; +} + +//===----------------------------------------------------------------------===// +// StructAccessOp +//===----------------------------------------------------------------------===// + +def StructAccessOp : Toy_Op<"struct_access", [Pure]> { + let summary = "struct access"; + let description = [{ + Access the Nth element of a value returning a struct type. + }]; + + let arguments = (ins Toy_StructType:$input, I64Attr:$index); + let results = (outs Toy_Type:$output); + + let assemblyFormat = [{ + $input `[` $index `]` attr-dict `:` type($input) `->` type($output) + }]; + + // Allow building a StructAccessOp with just a struct value and an index. + let builders = [ + OpBuilder<(ins "Value":$input, "size_t":$index)> + ]; + + // Indicate that additional verification for this operation is necessary. + let hasVerifier = 1; + + // Set the folder bit so that we can fold constant accesses. + let hasFolder = 1; +} + +//===----------------------------------------------------------------------===// +// StructConstantOp +//===----------------------------------------------------------------------===// + +def StructConstantOp : Toy_Op<"struct_constant", [ConstantLike, Pure]> { + let summary = "struct constant"; + let description = [{ + Constant operation turns a literal struct value into an SSA value. The data + is attached to the operation as an attribute. The struct constant is encoded + as an array of other constant values. For example: + + ```mlir + %0 = toy.struct_constant [ + dense<[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]> : tensor<2x3xf64> + ] : !toy.struct> + ``` + }]; + + let arguments = (ins ArrayAttr:$value); + let results = (outs Toy_StructType:$output); + + let assemblyFormat = "$value attr-dict `:` type($output)"; + + // Indicate that additional verification for this operation is necessary. + let hasVerifier = 1; + let hasFolder = 1; +} + +//===----------------------------------------------------------------------===// +// TransposeOp +//===----------------------------------------------------------------------===// + +def TransposeOp : Toy_Op<"transpose", + [Pure, DeclareOpInterfaceMethods]> { + let summary = "transpose operation"; + + let arguments = (ins F64Tensor:$input); + let results = (outs F64Tensor); + + let assemblyFormat = [{ + `(` $input `:` type($input) `)` attr-dict `to` type(results) + }]; + + // Enable registering canonicalization patterns with this operation. + let hasCanonicalizer = 1; + + // Allow building a TransposeOp with from the input operand. + let builders = [ + OpBuilder<(ins "Value":$input)> + ]; + + // Indicate that additional verification for this operation is necessary. + let hasVerifier = 1; +} + +#endif // TOY_OPS diff --git a/Ch7/include/toy/Parser.h b/Ch7/include/toy/Parser.h new file mode 100644 index 0000000..7ba7b8f --- /dev/null +++ b/Ch7/include/toy/Parser.h @@ -0,0 +1,683 @@ +//===- Parser.h - Toy Language Parser -------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements the parser for the Toy language. It processes the Token +// provided by the Lexer and returns an AST. +// +//===----------------------------------------------------------------------===// + +#ifndef TOY_PARSER_H +#define TOY_PARSER_H + +#include "toy/AST.h" +#include "toy/Lexer.h" + +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/StringExtras.h" +#include "llvm/Support/raw_ostream.h" + +#include +#include +#include +#include + +namespace toy { + +/// This is a simple recursive parser for the Toy language. It produces a well +/// formed AST from a stream of Token supplied by the Lexer. No semantic checks +/// or symbol resolution is performed. For example, variables are referenced by +/// string and the code could reference an undeclared variable and the parsing +/// succeeds. +class Parser { +public: + /// Create a Parser for the supplied lexer. + Parser(Lexer &lexer) : lexer(lexer) {} + + /// Parse a full Module. A module is a list of function definitions. + std::unique_ptr parseModule() { + lexer.getNextToken(); // prime the lexer + + // Parse functions and structs one at a time and accumulate in this vector. + std::vector> records; + while (true) { + std::unique_ptr record; + switch (lexer.getCurToken()) { + case tok_eof: + break; + case tok_def: + record = parseDefinition(); + break; + case tok_struct: + record = parseStruct(); + break; + default: + return parseError("'def' or 'struct'", + "when parsing top level module records"); + } + if (!record) + break; + records.push_back(std::move(record)); + } + + // If we didn't reach EOF, there was an error during parsing + if (lexer.getCurToken() != tok_eof) + return parseError("nothing", "at end of module"); + + return std::make_unique(std::move(records)); + } + +private: + Lexer &lexer; + + /// Parse a return statement. + /// return :== return ; | return expr ; + std::unique_ptr parseReturn() { + auto loc = lexer.getLastLocation(); + lexer.consume(tok_return); + + // return takes an optional argument + std::optional> expr; + if (lexer.getCurToken() != ';') { + expr = parseExpression(); + if (!expr) + return nullptr; + } + return std::make_unique(std::move(loc), std::move(expr)); + } + + /// Parse a literal number. + /// numberexpr ::= number + std::unique_ptr parseNumberExpr() { + auto loc = lexer.getLastLocation(); + auto result = + std::make_unique(std::move(loc), lexer.getValue()); + lexer.consume(tok_number); + return std::move(result); + } + + /// Parse a literal array expression. + /// tensorLiteral ::= [ literalList ] | number + /// literalList ::= tensorLiteral | tensorLiteral, literalList + std::unique_ptr parseTensorLiteralExpr() { + auto loc = lexer.getLastLocation(); + lexer.consume(Token('[')); + + // Hold the list of values at this nesting level. + std::vector> values; + // Hold the dimensions for all the nesting inside this level. + std::vector dims; + do { + // We can have either another nested array or a number literal. + if (lexer.getCurToken() == '[') { + values.push_back(parseTensorLiteralExpr()); + if (!values.back()) + return nullptr; // parse error in the nested array. + } else { + if (lexer.getCurToken() != tok_number) + return parseError(" or [", "in literal expression"); + values.push_back(parseNumberExpr()); + } + + // End of this list on ']' + if (lexer.getCurToken() == ']') + break; + + // Elements are separated by a comma. + if (lexer.getCurToken() != ',') + return parseError("] or ,", "in literal expression"); + + lexer.getNextToken(); // eat , + } while (true); + if (values.empty()) + return parseError("", "to fill literal expression"); + lexer.getNextToken(); // eat ] + + /// Fill in the dimensions now. First the current nesting level: + dims.push_back(values.size()); + + /// If there is any nested array, process all of them and ensure that + /// dimensions are uniform. + if (llvm::any_of(values, [](std::unique_ptr &expr) { + return llvm::isa(expr.get()); + })) { + auto *firstLiteral = llvm::dyn_cast(values.front().get()); + if (!firstLiteral) + return parseError("uniform well-nested dimensions", + "inside literal expression"); + + // Append the nested dimensions to the current level + auto firstDims = firstLiteral->getDims(); + dims.insert(dims.end(), firstDims.begin(), firstDims.end()); + + // Sanity check that shape is uniform across all elements of the list. + for (auto &expr : values) { + auto *exprLiteral = llvm::cast(expr.get()); + if (!exprLiteral) + return parseError("uniform well-nested dimensions", + "inside literal expression"); + if (exprLiteral->getDims() != firstDims) + return parseError("uniform well-nested dimensions", + "inside literal expression"); + } + } + return std::make_unique(std::move(loc), std::move(values), + std::move(dims)); + } + + /// Parse a literal struct expression. + /// structLiteral ::= { (structLiteral | tensorLiteral)+ } + std::unique_ptr parseStructLiteralExpr() { + auto loc = lexer.getLastLocation(); + lexer.consume(Token('{')); + + // Hold the list of values. + std::vector> values; + do { + // We can have either another nested array or a number literal. + if (lexer.getCurToken() == '[') { + values.push_back(parseTensorLiteralExpr()); + if (!values.back()) + return nullptr; + } else if (lexer.getCurToken() == tok_number) { + values.push_back(parseNumberExpr()); + if (!values.back()) + return nullptr; + } else { + if (lexer.getCurToken() != '{') + return parseError("{, [, or number", + "in struct literal expression"); + values.push_back(parseStructLiteralExpr()); + } + + // End of this list on '}' + if (lexer.getCurToken() == '}') + break; + + // Elements are separated by a comma. + if (lexer.getCurToken() != ',') + return parseError("} or ,", "in struct literal expression"); + + lexer.getNextToken(); // eat , + } while (true); + if (values.empty()) + return parseError("", + "to fill struct literal expression"); + lexer.getNextToken(); // eat } + + return std::make_unique(std::move(loc), + std::move(values)); + } + + /// parenexpr ::= '(' expression ')' + std::unique_ptr parseParenExpr() { + lexer.getNextToken(); // eat (. + auto v = parseExpression(); + if (!v) + return nullptr; + + if (lexer.getCurToken() != ')') + return parseError(")", "to close expression with parentheses"); + lexer.consume(Token(')')); + return v; + } + + /// Parse a call expression. + std::unique_ptr parseCallExpr(llvm::StringRef name, + const Location &loc) { + lexer.consume(Token('(')); + std::vector> args; + if (lexer.getCurToken() != ')') { + while (true) { + if (auto arg = parseExpression()) + args.push_back(std::move(arg)); + else + return nullptr; + + if (lexer.getCurToken() == ')') + break; + + if (lexer.getCurToken() != ',') + return parseError(", or )", "in argument list"); + lexer.getNextToken(); + } + } + lexer.consume(Token(')')); + + // It can be a builtin call to print + if (name == "print") { + if (args.size() != 1) + return parseError("", "as argument to print()"); + + return std::make_unique(loc, std::move(args[0])); + } + + // Call to a user-defined function + return std::make_unique(loc, std::string(name), + std::move(args)); + } + + /// identifierexpr + /// ::= identifier + /// ::= identifier '(' expression ')' + std::unique_ptr parseIdentifierExpr() { + std::string name(lexer.getId()); + + auto loc = lexer.getLastLocation(); + lexer.getNextToken(); // eat identifier. + + if (lexer.getCurToken() != '(') // Simple variable ref. + return std::make_unique(std::move(loc), name); + + // This is a function call. + return parseCallExpr(name, loc); + } + + /// primary + /// ::= identifierexpr + /// ::= numberexpr + /// ::= parenexpr + /// ::= tensorliteral + std::unique_ptr parsePrimary() { + switch (lexer.getCurToken()) { + default: + llvm::errs() << "unknown token '" << lexer.getCurToken() + << "' when expecting an expression\n"; + return nullptr; + case tok_identifier: + return parseIdentifierExpr(); + case tok_number: + return parseNumberExpr(); + case '(': + return parseParenExpr(); + case '[': + return parseTensorLiteralExpr(); + case '{': + return parseStructLiteralExpr(); + case ';': + return nullptr; + case '}': + return nullptr; + } + } + + /// Recursively parse the right hand side of a binary expression, the ExprPrec + /// argument indicates the precedence of the current binary operator. + /// + /// binoprhs ::= ('+' primary)* + std::unique_ptr parseBinOpRHS(int exprPrec, + std::unique_ptr lhs) { + // If this is a binop, find its precedence. + while (true) { + int tokPrec = getTokPrecedence(); + + // If this is a binop that binds at least as tightly as the current binop, + // consume it, otherwise we are done. + if (tokPrec < exprPrec) + return lhs; + + // Okay, we know this is a binop. + int binOp = lexer.getCurToken(); + lexer.consume(Token(binOp)); + auto loc = lexer.getLastLocation(); + + // Parse the primary expression after the binary operator. + auto rhs = parsePrimary(); + if (!rhs) + return parseError("expression", "to complete binary operator"); + + // If BinOp binds less tightly with rhs than the operator after rhs, let + // the pending operator take rhs as its lhs. + int nextPrec = getTokPrecedence(); + if (tokPrec < nextPrec) { + rhs = parseBinOpRHS(tokPrec + 1, std::move(rhs)); + if (!rhs) + return nullptr; + } + + // Merge lhs/RHS. + lhs = std::make_unique(std::move(loc), binOp, + std::move(lhs), std::move(rhs)); + } + } + + /// expression::= primary binop rhs + std::unique_ptr parseExpression() { + auto lhs = parsePrimary(); + if (!lhs) + return nullptr; + + return parseBinOpRHS(0, std::move(lhs)); + } + + /// type ::= < shape_list > + /// shape_list ::= num | num , shape_list + std::unique_ptr parseType() { + if (lexer.getCurToken() != '<') + return parseError("<", "to begin type"); + lexer.getNextToken(); // eat < + + auto type = std::make_unique(); + + while (lexer.getCurToken() == tok_number) { + type->shape.push_back(lexer.getValue()); + lexer.getNextToken(); + if (lexer.getCurToken() == ',') + lexer.getNextToken(); + } + + if (lexer.getCurToken() != '>') + return parseError(">", "to end type"); + lexer.getNextToken(); // eat > + return type; + } + + /// Parse either a variable declaration or a call expression. + std::unique_ptr parseDeclarationOrCallExpr() { + auto loc = lexer.getLastLocation(); + std::string id(lexer.getId()); + lexer.consume(tok_identifier); + + // Check for a call expression. + if (lexer.getCurToken() == '(') + return parseCallExpr(id, loc); + + // Otherwise, this is a variable declaration. + return parseTypedDeclaration(id, /*requiresInitializer=*/true, loc); + } + + /// Parse a typed variable declaration. + std::unique_ptr + parseTypedDeclaration(llvm::StringRef typeName, bool requiresInitializer, + const Location &loc) { + // Parse the variable name. + if (lexer.getCurToken() != tok_identifier) + return parseError("name", "in variable declaration"); + std::string id(lexer.getId()); + lexer.getNextToken(); // eat id + + // Parse the initializer. + std::unique_ptr expr; + if (requiresInitializer) { + if (lexer.getCurToken() != '=') + return parseError("initializer", + "in variable declaration"); + lexer.consume(Token('=')); + expr = parseExpression(); + } + + VarType type; + type.name = std::string(typeName); + return std::make_unique(loc, std::move(id), std::move(type), + std::move(expr)); + } + + /// Parse a variable declaration, for either a tensor value or a struct value, + /// with an optionally required initializer. + /// decl ::= var identifier [ type ] (= expr)? + /// decl ::= identifier identifier (= expr)? + std::unique_ptr parseDeclaration(bool requiresInitializer) { + // Check to see if this is a 'var' declaration. + if (lexer.getCurToken() == tok_var) + return parseVarDeclaration(requiresInitializer); + + // Parse the type name. + if (lexer.getCurToken() != tok_identifier) + return parseError("type name", "in variable declaration"); + auto loc = lexer.getLastLocation(); + std::string typeName(lexer.getId()); + lexer.getNextToken(); // eat id + + // Parse the rest of the declaration. + return parseTypedDeclaration(typeName, requiresInitializer, loc); + } + + /// Parse a variable declaration, it starts with a `var` keyword followed by + /// and identifier and an optional type (shape specification) before the + /// optionally required initializer. + /// decl ::= var identifier [ type ] (= expr)? + std::unique_ptr + parseVarDeclaration(bool requiresInitializer) { + if (lexer.getCurToken() != tok_var) + return parseError("var", "to begin declaration"); + auto loc = lexer.getLastLocation(); + lexer.getNextToken(); // eat var + + if (lexer.getCurToken() != tok_identifier) + return parseError("identified", + "after 'var' declaration"); + std::string id(lexer.getId()); + lexer.getNextToken(); // eat id + + std::unique_ptr type; // Type is optional, it can be inferred + if (lexer.getCurToken() == '<') { + type = parseType(); + if (!type) + return nullptr; + } + if (!type) + type = std::make_unique(); + + std::unique_ptr expr; + if (requiresInitializer) { + lexer.consume(Token('=')); + expr = parseExpression(); + } + return std::make_unique(std::move(loc), std::move(id), + std::move(*type), std::move(expr)); + } + + /// Parse a block: a list of expression separated by semicolons and wrapped in + /// curly braces. + /// + /// block ::= { expression_list } + /// expression_list ::= block_expr ; expression_list + /// block_expr ::= decl | "return" | expr + std::unique_ptr parseBlock() { + if (lexer.getCurToken() != '{') + return parseError("{", "to begin block"); + lexer.consume(Token('{')); + + auto exprList = std::make_unique(); + + // Ignore empty expressions: swallow sequences of semicolons. + while (lexer.getCurToken() == ';') + lexer.consume(Token(';')); + + while (lexer.getCurToken() != '}' && lexer.getCurToken() != tok_eof) { + if (lexer.getCurToken() == tok_identifier) { + // Variable declaration or call + auto expr = parseDeclarationOrCallExpr(); + if (!expr) + return nullptr; + exprList->push_back(std::move(expr)); + } else if (lexer.getCurToken() == tok_var) { + // Variable declaration + auto varDecl = parseDeclaration(/*requiresInitializer=*/true); + if (!varDecl) + return nullptr; + exprList->push_back(std::move(varDecl)); + } else if (lexer.getCurToken() == tok_return) { + // Return statement + auto ret = parseReturn(); + if (!ret) + return nullptr; + exprList->push_back(std::move(ret)); + } else { + // General expression + auto expr = parseExpression(); + if (!expr) + return nullptr; + exprList->push_back(std::move(expr)); + } + // Ensure that elements are separated by a semicolon. + if (lexer.getCurToken() != ';') + return parseError(";", "after expression"); + + // Ignore empty expressions: swallow sequences of semicolons. + while (lexer.getCurToken() == ';') + lexer.consume(Token(';')); + } + + if (lexer.getCurToken() != '}') + return parseError("}", "to close block"); + + lexer.consume(Token('}')); + return exprList; + } + + /// prototype ::= def id '(' decl_list ')' + /// decl_list ::= identifier | identifier, decl_list + std::unique_ptr parsePrototype() { + auto loc = lexer.getLastLocation(); + + if (lexer.getCurToken() != tok_def) + return parseError("def", "in prototype"); + lexer.consume(tok_def); + + if (lexer.getCurToken() != tok_identifier) + return parseError("function name", "in prototype"); + + std::string fnName(lexer.getId()); + lexer.consume(tok_identifier); + + if (lexer.getCurToken() != '(') + return parseError("(", "in prototype"); + lexer.consume(Token('(')); + + std::vector> args; + if (lexer.getCurToken() != ')') { + do { + VarType type; + std::string name; + + // Parse either the name of the variable, or its type. + std::string nameOrType(lexer.getId()); + auto loc = lexer.getLastLocation(); + lexer.consume(tok_identifier); + + // If the next token is an identifier, we just parsed the type. + if (lexer.getCurToken() == tok_identifier) { + type.name = std::move(nameOrType); + + // Parse the name. + name = std::string(lexer.getId()); + lexer.consume(tok_identifier); + } else { + // Otherwise, we just parsed the name. + name = std::move(nameOrType); + } + + args.push_back( + std::make_unique(std::move(loc), name, type)); + if (lexer.getCurToken() != ',') + break; + lexer.consume(Token(',')); + if (lexer.getCurToken() != tok_identifier) + return parseError( + "identifier", "after ',' in function parameter list"); + } while (true); + } + if (lexer.getCurToken() != ')') + return parseError(")", "to end function prototype"); + + // success. + lexer.consume(Token(')')); + return std::make_unique(std::move(loc), fnName, + std::move(args)); + } + + /// Parse a function definition, we expect a prototype initiated with the + /// `def` keyword, followed by a block containing a list of expressions. + /// + /// definition ::= prototype block + std::unique_ptr parseDefinition() { + auto proto = parsePrototype(); + if (!proto) + return nullptr; + + if (auto block = parseBlock()) + return std::make_unique(std::move(proto), std::move(block)); + return nullptr; + } + + /// Parse a struct definition, we expect a struct initiated with the + /// `struct` keyword, followed by a block containing a list of variable + /// declarations. + /// + /// definition ::= `struct` identifier `{` decl+ `}` + std::unique_ptr parseStruct() { + auto loc = lexer.getLastLocation(); + lexer.consume(tok_struct); + if (lexer.getCurToken() != tok_identifier) + return parseError("name", "in struct definition"); + std::string name(lexer.getId()); + lexer.consume(tok_identifier); + + // Parse: '{' + if (lexer.getCurToken() != '{') + return parseError("{", "in struct definition"); + lexer.consume(Token('{')); + + // Parse: decl+ + std::vector> decls; + do { + auto decl = parseDeclaration(/*requiresInitializer=*/false); + if (!decl) + return nullptr; + decls.push_back(std::move(decl)); + + if (lexer.getCurToken() != ';') + return parseError(";", + "after variable in struct definition"); + lexer.consume(Token(';')); + } while (lexer.getCurToken() != '}'); + + // Parse: '}' + lexer.consume(Token('}')); + return std::make_unique(loc, name, std::move(decls)); + } + + /// Get the precedence of the pending binary operator token. + int getTokPrecedence() { + if (!isascii(lexer.getCurToken())) + return -1; + + // 1 is lowest precedence. + switch (static_cast(lexer.getCurToken())) { + case '-': + return 20; + case '+': + return 20; + case '*': + return 40; + case '.': + return 60; + default: + return -1; + } + } + + /// Helper function to signal errors while parsing, it takes an argument + /// indicating the expected token and another argument giving more context. + /// Location is retrieved from the lexer to enrich the error message. + template + std::unique_ptr parseError(T &&expected, U &&context = "") { + auto curToken = lexer.getCurToken(); + llvm::errs() << "Parse error (" << lexer.getLastLocation().line << ", " + << lexer.getLastLocation().col << "): expected '" << expected + << "' " << context << " but has Token " << curToken; + if (isprint(curToken)) + llvm::errs() << " '" << (char)curToken << "'"; + llvm::errs() << "\n"; + return nullptr; + } +}; + +} // namespace toy + +#endif // TOY_PARSER_H diff --git a/Ch7/include/toy/Passes.h b/Ch7/include/toy/Passes.h new file mode 100644 index 0000000..62471dd --- /dev/null +++ b/Ch7/include/toy/Passes.h @@ -0,0 +1,35 @@ +//===- Passes.h - Toy Passes Definition -----------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file exposes the entry points to create compiler passes for Toy. +// +//===----------------------------------------------------------------------===// + +#ifndef TOY_PASSES_H +#define TOY_PASSES_H + +#include + +namespace mlir { +class Pass; + +namespace toy { +std::unique_ptr createShapeInferencePass(); + +/// Create a pass for lowering to operations in the `Affine` and `Std` dialects, +/// for a subset of the Toy IR (e.g. matmul). +std::unique_ptr createLowerToAffinePass(); + +/// Create a pass for lowering operations the remaining `Toy` operations, as +/// well as `Affine` and `Std`, to the LLVM dialect for codegen. +std::unique_ptr createLowerToLLVMPass(); + +} // namespace toy +} // namespace mlir + +#endif // TOY_PASSES_H diff --git a/Ch7/include/toy/ShapeInferenceInterface.h b/Ch7/include/toy/ShapeInferenceInterface.h new file mode 100644 index 0000000..cfe5a87 --- /dev/null +++ b/Ch7/include/toy/ShapeInferenceInterface.h @@ -0,0 +1,28 @@ +//===- ShapeInferenceInterface.h - Interface definitions for ShapeInference -=// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file contains the declarations of the shape inference interfaces defined +// in ShapeInferenceInterface.td. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_TUTORIAL_TOY_SHAPEINFERENCEINTERFACE_H_ +#define MLIR_TUTORIAL_TOY_SHAPEINFERENCEINTERFACE_H_ + +#include "mlir/IR/OpDefinition.h" + +namespace mlir { +namespace toy { + +/// Include the auto-generated declarations. +#include "toy/ShapeInferenceOpInterfaces.h.inc" + +} // namespace toy +} // namespace mlir + +#endif // MLIR_TUTORIAL_TOY_SHAPEINFERENCEINTERFACE_H_ diff --git a/Ch7/include/toy/ShapeInferenceInterface.td b/Ch7/include/toy/ShapeInferenceInterface.td new file mode 100644 index 0000000..2279015 --- /dev/null +++ b/Ch7/include/toy/ShapeInferenceInterface.td @@ -0,0 +1,30 @@ +//===- ShapeInferenceInterface.td - Shape Inference Interface -*- tablegen -==// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Defines the operations of the Shape Inference Op Interface. +// +//===----------------------------------------------------------------------===// + +#ifndef SHAPE_INFERENCE_INTERFACE +#define SHAPE_INFERENCE_INTERFACE + +include "mlir/IR/OpBase.td" + +def ShapeInferenceOpInterface : OpInterface<"ShapeInference"> { + let description = [{ + Interface to access a registered method to infer the return types for an + operation that can be used during type inference. + }]; + + let methods = [ + InterfaceMethod<"Infer and set the output shape for the current operation.", + "void", "inferShapes"> + ]; +} + +#endif // SHAPE_INFERENCE_INTERFACE diff --git a/Ch7/include/toy/ShapeInferenceOpInterfaces.cpp.inc b/Ch7/include/toy/ShapeInferenceOpInterfaces.cpp.inc new file mode 100644 index 0000000..a481d2e --- /dev/null +++ b/Ch7/include/toy/ShapeInferenceOpInterfaces.cpp.inc @@ -0,0 +1,12 @@ +/*===- TableGen'erated file -------------------------------------*- C++ -*-===*\ +|* *| +|* Interface Definitions *| +|* *| +|* Automatically generated file, do not edit! *| +|* *| +\*===----------------------------------------------------------------------===*/ + +/// Infer and set the output shape for the current operation. +void ShapeInference::inferShapes() { + return getImpl()->inferShapes(getImpl(), getOperation()); + } diff --git a/Ch7/include/toy/ShapeInferenceOpInterfaces.h.inc b/Ch7/include/toy/ShapeInferenceOpInterfaces.h.inc new file mode 100644 index 0000000..bb24654 --- /dev/null +++ b/Ch7/include/toy/ShapeInferenceOpInterfaces.h.inc @@ -0,0 +1,61 @@ +/*===- TableGen'erated file -------------------------------------*- C++ -*-===*\ +|* *| +|* Interface Declarations *| +|* *| +|* Automatically generated file, do not edit! *| +|* *| +\*===----------------------------------------------------------------------===*/ + +class ShapeInference; +namespace detail { +struct ShapeInferenceInterfaceTraits { + struct Concept { + /// The methods defined by the interface. + void (*inferShapes)(const Concept *impl, ::mlir::Operation *); + }; + template + class Model : public Concept { + public: + using Interface = ShapeInference; + Model() : Concept{inferShapes} {} + + static inline void inferShapes(const Concept *impl, ::mlir::Operation *tablegen_opaque_val); + }; + template + class FallbackModel : public Concept { + public: + using Interface = ShapeInference; + FallbackModel() : Concept{inferShapes} {} + + static inline void inferShapes(const Concept *impl, ::mlir::Operation *tablegen_opaque_val); + }; + template + class ExternalModel : public FallbackModel { + public: + using ConcreteEntity = ConcreteOp; + }; +};template +struct ShapeInferenceTrait; + +} // namespace detail +class ShapeInference : public ::mlir::OpInterface { +public: + using ::mlir::OpInterface::OpInterface; + template + struct Trait : public detail::ShapeInferenceTrait {}; + /// Infer and set the output shape for the current operation. + void inferShapes(); +}; +namespace detail { + template + struct ShapeInferenceTrait : public ::mlir::OpInterface::Trait { + }; +}// namespace detail +template +void detail::ShapeInferenceInterfaceTraits::Model::inferShapes(const Concept *impl, ::mlir::Operation *tablegen_opaque_val) { + return (llvm::cast(tablegen_opaque_val)).inferShapes(); +} +template +void detail::ShapeInferenceInterfaceTraits::FallbackModel::inferShapes(const Concept *impl, ::mlir::Operation *tablegen_opaque_val) { + return static_cast(impl)->inferShapes(tablegen_opaque_val); +} diff --git a/Ch7/mlir/Dialect.cpp b/Ch7/mlir/Dialect.cpp new file mode 100644 index 0000000..b268b1e --- /dev/null +++ b/Ch7/mlir/Dialect.cpp @@ -0,0 +1,665 @@ +//===- Dialect.cpp - Toy IR Dialect registration in MLIR ------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements the dialect for the Toy IR: custom type parsing and +// operation verification. +// +//===----------------------------------------------------------------------===// + +#include "toy/Dialect.h" + +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/DialectImplementation.h" +#include "mlir/IR/Location.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/OpImplementation.h" +#include "mlir/IR/OperationSupport.h" +#include "mlir/IR/TypeSupport.h" +#include "mlir/IR/ValueRange.h" +#include "mlir/Interfaces/CallInterfaces.h" +#include "mlir/Interfaces/FunctionImplementation.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/InliningUtils.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/Hashing.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/Casting.h" +#include +#include +#include +#include +#include + +using namespace mlir; +using namespace mlir::toy; + +#include "toy/Dialect.cpp.inc" + +//===----------------------------------------------------------------------===// +// ToyInlinerInterface +//===----------------------------------------------------------------------===// + +/// This class defines the interface for handling inlining with Toy +/// operations. +struct ToyInlinerInterface : public DialectInlinerInterface { + using DialectInlinerInterface::DialectInlinerInterface; + + //===--------------------------------------------------------------------===// + // Analysis Hooks + //===--------------------------------------------------------------------===// + + /// All call operations within toy can be inlined. + bool isLegalToInline(Operation *call, Operation *callable, + bool wouldBeCloned) const final { + return true; + } + + /// All operations within toy can be inlined. + bool isLegalToInline(Operation *, Region *, bool, IRMapping &) const final { + return true; + } + + // All functions within toy can be inlined. + bool isLegalToInline(Region *, Region *, bool, IRMapping &) const final { + return true; + } + + //===--------------------------------------------------------------------===// + // Transformation Hooks + //===--------------------------------------------------------------------===// + + /// Handle the given inlined terminator(toy.return) by replacing it with a new + /// operation as necessary. + void handleTerminator(Operation *op, ValueRange valuesToRepl) const final { + // Only "toy.return" needs to be handled here. + auto returnOp = cast(op); + + // Replace the values directly with the return operands. + assert(returnOp.getNumOperands() == valuesToRepl.size()); + for (const auto &it : llvm::enumerate(returnOp.getOperands())) + valuesToRepl[it.index()].replaceAllUsesWith(it.value()); + } + + /// Attempts to materialize a conversion for a type mismatch between a call + /// from this dialect, and a callable region. This method should generate an + /// operation that takes 'input' as the only operand, and produces a single + /// result of 'resultType'. If a conversion can not be generated, nullptr + /// should be returned. + Operation *materializeCallConversion(OpBuilder &builder, Value input, + Type resultType, + Location conversionLoc) const final { + return builder.create(conversionLoc, resultType, input); + } +}; + +//===----------------------------------------------------------------------===// +// Toy Operations +//===----------------------------------------------------------------------===// + +/// A generalized parser for binary operations. This parses the different forms +/// of 'printBinaryOp' below. +static mlir::ParseResult parseBinaryOp(mlir::OpAsmParser &parser, + mlir::OperationState &result) { + SmallVector operands; + SMLoc operandsLoc = parser.getCurrentLocation(); + Type type; + if (parser.parseOperandList(operands, /*requiredOperandCount=*/2) || + parser.parseOptionalAttrDict(result.attributes) || + parser.parseColonType(type)) + return mlir::failure(); + + // If the type is a function type, it contains the input and result types of + // this operation. + if (FunctionType funcType = llvm::dyn_cast(type)) { + if (parser.resolveOperands(operands, funcType.getInputs(), operandsLoc, + result.operands)) + return mlir::failure(); + result.addTypes(funcType.getResults()); + return mlir::success(); + } + + // Otherwise, the parsed type is the type of both operands and results. + if (parser.resolveOperands(operands, type, result.operands)) + return mlir::failure(); + result.addTypes(type); + return mlir::success(); +} + +/// A generalized printer for binary operations. It prints in two different +/// forms depending on if all of the types match. +static void printBinaryOp(mlir::OpAsmPrinter &printer, mlir::Operation *op) { + printer << " " << op->getOperands(); + printer.printOptionalAttrDict(op->getAttrs()); + printer << " : "; + + // If all of the types are the same, print the type directly. + Type resultType = *op->result_type_begin(); + if (llvm::all_of(op->getOperandTypes(), + [=](Type type) { return type == resultType; })) { + printer << resultType; + return; + } + + // Otherwise, print a functional type. + printer.printFunctionalType(op->getOperandTypes(), op->getResultTypes()); +} + +//===----------------------------------------------------------------------===// +// ConstantOp +//===----------------------------------------------------------------------===// + +/// Build a constant operation. +/// The builder is passed as an argument, so is the state that this method is +/// expected to fill in order to build the operation. +void ConstantOp::build(mlir::OpBuilder &builder, mlir::OperationState &state, + double value) { + auto dataType = RankedTensorType::get({}, builder.getF64Type()); + auto dataAttribute = DenseElementsAttr::get(dataType, value); + ConstantOp::build(builder, state, dataType, dataAttribute); +} + +/// The 'OpAsmParser' class provides a collection of methods for parsing +/// various punctuation, as well as attributes, operands, types, etc. Each of +/// these methods returns a `ParseResult`. This class is a wrapper around +/// `LogicalResult` that can be converted to a boolean `true` value on failure, +/// or `false` on success. This allows for easily chaining together a set of +/// parser rules. These rules are used to populate an `mlir::OperationState` +/// similarly to the `build` methods described above. +mlir::ParseResult ConstantOp::parse(mlir::OpAsmParser &parser, + mlir::OperationState &result) { + mlir::DenseElementsAttr value; + if (parser.parseOptionalAttrDict(result.attributes) || + parser.parseAttribute(value, "value", result.attributes)) + return failure(); + + result.addTypes(value.getType()); + return success(); +} + +/// The 'OpAsmPrinter' class is a stream that allows for formatting +/// strings, attributes, operands, types, etc. +void ConstantOp::print(mlir::OpAsmPrinter &printer) { + printer << " "; + printer.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{"value"}); + printer << getValue(); +} + +/// Verify that the given attribute value is valid for the given type. +static mlir::LogicalResult verifyConstantForType(mlir::Type type, + mlir::Attribute opaqueValue, + mlir::Operation *op) { + if (llvm::isa(type)) { + // Check that the value is an elements attribute. + auto attrValue = llvm::dyn_cast(opaqueValue); + if (!attrValue) + return op->emitError("constant of TensorType must be initialized by " + "a DenseFPElementsAttr, got ") + << opaqueValue; + + // If the return type of the constant is not an unranked tensor, the shape + // must match the shape of the attribute holding the data. + auto resultType = llvm::dyn_cast(type); + if (!resultType) + return success(); + + // Check that the rank of the attribute type matches the rank of the + // constant result type. + auto attrType = llvm::cast(attrValue.getType()); + if (attrType.getRank() != resultType.getRank()) { + return op->emitOpError("return type must match the one of the attached " + "value attribute: ") + << attrType.getRank() << " != " << resultType.getRank(); + } + + // Check that each of the dimensions match between the two types. + for (int dim = 0, dimE = attrType.getRank(); dim < dimE; ++dim) { + if (attrType.getShape()[dim] != resultType.getShape()[dim]) { + return op->emitOpError( + "return type shape mismatches its attribute at dimension ") + << dim << ": " << attrType.getShape()[dim] + << " != " << resultType.getShape()[dim]; + } + } + return mlir::success(); + } + auto resultType = llvm::cast(type); + llvm::ArrayRef resultElementTypes = resultType.getElementTypes(); + + // Verify that the initializer is an Array. + auto attrValue = llvm::dyn_cast(opaqueValue); + if (!attrValue || attrValue.getValue().size() != resultElementTypes.size()) + return op->emitError("constant of StructType must be initialized by an " + "ArrayAttr with the same number of elements, got ") + << opaqueValue; + + // Check that each of the elements are valid. + llvm::ArrayRef attrElementValues = attrValue.getValue(); + for (const auto it : llvm::zip(resultElementTypes, attrElementValues)) + if (failed(verifyConstantForType(std::get<0>(it), std::get<1>(it), op))) + return mlir::failure(); + return mlir::success(); +} + +/// Verifier for the constant operation. This corresponds to the `::verify(...)` +/// in the op definition. +mlir::LogicalResult ConstantOp::verify() { + return verifyConstantForType(getResult().getType(), getValue(), *this); +} + +mlir::LogicalResult StructConstantOp::verify() { + return verifyConstantForType(getResult().getType(), getValue(), *this); +} + +/// Infer the output shape of the ConstantOp, this is required by the shape +/// inference interface. +void ConstantOp::inferShapes() { + getResult().setType(cast(getValue().getType())); +} + +//===----------------------------------------------------------------------===// +// AddOp +//===----------------------------------------------------------------------===// + +void AddOp::build(mlir::OpBuilder &builder, mlir::OperationState &state, + mlir::Value lhs, mlir::Value rhs) { + state.addTypes(UnrankedTensorType::get(builder.getF64Type())); + state.addOperands({lhs, rhs}); +} + +mlir::ParseResult AddOp::parse(mlir::OpAsmParser &parser, + mlir::OperationState &result) { + return parseBinaryOp(parser, result); +} + +void AddOp::print(mlir::OpAsmPrinter &p) { printBinaryOp(p, *this); } + +/// Infer the output shape of the AddOp, this is required by the shape inference +/// interface. +void AddOp::inferShapes() { getResult().setType(getLhs().getType()); } + +//===----------------------------------------------------------------------===// +// CastOp +//===----------------------------------------------------------------------===// + +/// Infer the output shape of the CastOp, this is required by the shape +/// inference interface. +void CastOp::inferShapes() { getResult().setType(getInput().getType()); } + +/// Returns true if the given set of input and result types are compatible with +/// this cast operation. This is required by the `CastOpInterface` to verify +/// this operation and provide other additional utilities. +bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { + if (inputs.size() != 1 || outputs.size() != 1) + return false; + // The inputs must be Tensors with the same element type. + TensorType input = llvm::dyn_cast(inputs.front()); + TensorType output = llvm::dyn_cast(outputs.front()); + if (!input || !output || input.getElementType() != output.getElementType()) + return false; + // The shape is required to match if both types are ranked. + return !input.hasRank() || !output.hasRank() || input == output; +} + +//===----------------------------------------------------------------------===// +// FuncOp +//===----------------------------------------------------------------------===// + +void FuncOp::build(mlir::OpBuilder &builder, mlir::OperationState &state, + llvm::StringRef name, mlir::FunctionType type, + llvm::ArrayRef attrs) { + // FunctionOpInterface provides a convenient `build` method that will populate + // the state of our FuncOp, and create an entry block. + buildWithEntryBlock(builder, state, name, type, attrs, type.getInputs()); +} + +mlir::ParseResult FuncOp::parse(mlir::OpAsmParser &parser, + mlir::OperationState &result) { + // Dispatch to the FunctionOpInterface provided utility method that parses the + // function operation. + auto buildFuncType = + [](mlir::Builder &builder, llvm::ArrayRef argTypes, + llvm::ArrayRef results, + mlir::function_interface_impl::VariadicFlag, + std::string &) { return builder.getFunctionType(argTypes, results); }; + + return mlir::function_interface_impl::parseFunctionOp( + parser, result, /*allowVariadic=*/false, + getFunctionTypeAttrName(result.name), buildFuncType, + getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name)); +} + +void FuncOp::print(mlir::OpAsmPrinter &p) { + // Dispatch to the FunctionOpInterface provided utility method that prints the + // function operation. + mlir::function_interface_impl::printFunctionOp( + p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(), + getArgAttrsAttrName(), getResAttrsAttrName()); +} + +//===----------------------------------------------------------------------===// +// GenericCallOp +//===----------------------------------------------------------------------===// + +void GenericCallOp::build(mlir::OpBuilder &builder, mlir::OperationState &state, + StringRef callee, ArrayRef arguments) { + // Generic call always returns an unranked Tensor initially. + state.addTypes(UnrankedTensorType::get(builder.getF64Type())); + state.addOperands(arguments); + state.addAttribute("callee", + mlir::SymbolRefAttr::get(builder.getContext(), callee)); +} + +/// Return the callee of the generic call operation, this is required by the +/// call interface. +CallInterfaceCallable GenericCallOp::getCallableForCallee() { + return (*this)->getAttrOfType("callee"); +} + +/// Set the callee for the generic call operation, this is required by the call +/// interface. +void GenericCallOp::setCalleeFromCallable(CallInterfaceCallable callee) { + (*this)->setAttr("callee", callee.get()); +} + +/// Get the argument operands to the called function, this is required by the +/// call interface. +Operation::operand_range GenericCallOp::getArgOperands() { return getInputs(); } + +/// Get the argument operands to the called function as a mutable range, this is +/// required by the call interface. +MutableOperandRange GenericCallOp::getArgOperandsMutable() { + return getInputsMutable(); +} + +//===----------------------------------------------------------------------===// +// MulOp +//===----------------------------------------------------------------------===// + +void MulOp::build(mlir::OpBuilder &builder, mlir::OperationState &state, + mlir::Value lhs, mlir::Value rhs) { + state.addTypes(UnrankedTensorType::get(builder.getF64Type())); + state.addOperands({lhs, rhs}); +} + +mlir::ParseResult MulOp::parse(mlir::OpAsmParser &parser, + mlir::OperationState &result) { + return parseBinaryOp(parser, result); +} + +void MulOp::print(mlir::OpAsmPrinter &p) { printBinaryOp(p, *this); } + +/// Infer the output shape of the MulOp, this is required by the shape inference +/// interface. +void MulOp::inferShapes() { getResult().setType(getLhs().getType()); } + +//===----------------------------------------------------------------------===// +// ReturnOp +//===----------------------------------------------------------------------===// + +mlir::LogicalResult ReturnOp::verify() { + // We know that the parent operation is a function, because of the 'HasParent' + // trait attached to the operation definition. + auto function = cast((*this)->getParentOp()); + + /// ReturnOps can only have a single optional operand. + if (getNumOperands() > 1) + return emitOpError() << "expects at most 1 return operand"; + + // The operand number and types must match the function signature. + const auto &results = function.getFunctionType().getResults(); + if (getNumOperands() != results.size()) + return emitOpError() << "does not return the same number of values (" + << getNumOperands() << ") as the enclosing function (" + << results.size() << ")"; + + // If the operation does not have an input, we are done. + if (!hasOperand()) + return mlir::success(); + + auto inputType = *operand_type_begin(); + auto resultType = results.front(); + + // Check that the result type of the function matches the operand type. + if (inputType == resultType || llvm::isa(inputType) || + llvm::isa(resultType)) + return mlir::success(); + + return emitError() << "type of return operand (" << inputType + << ") doesn't match function result type (" << resultType + << ")"; +} + +//===----------------------------------------------------------------------===// +// StructAccessOp +//===----------------------------------------------------------------------===// + +void StructAccessOp::build(mlir::OpBuilder &b, mlir::OperationState &state, + mlir::Value input, size_t index) { + // Extract the result type from the input type. + StructType structTy = llvm::cast(input.getType()); + assert(index < structTy.getNumElementTypes()); + mlir::Type resultType = structTy.getElementTypes()[index]; + + // Call into the auto-generated build method. + build(b, state, resultType, input, b.getI64IntegerAttr(index)); +} + +mlir::LogicalResult StructAccessOp::verify() { + StructType structTy = llvm::cast(getInput().getType()); + size_t indexValue = getIndex(); + if (indexValue >= structTy.getNumElementTypes()) + return emitOpError() + << "index should be within the range of the input struct type"; + mlir::Type resultType = getResult().getType(); + if (resultType != structTy.getElementTypes()[indexValue]) + return emitOpError() << "must have the same result type as the struct " + "element referred to by the index"; + return mlir::success(); +} + +//===----------------------------------------------------------------------===// +// TransposeOp +//===----------------------------------------------------------------------===// + +void TransposeOp::build(mlir::OpBuilder &builder, mlir::OperationState &state, + mlir::Value value) { + state.addTypes(UnrankedTensorType::get(builder.getF64Type())); + state.addOperands(value); +} + +void TransposeOp::inferShapes() { + auto arrayTy = llvm::cast(getOperand().getType()); + SmallVector dims(llvm::reverse(arrayTy.getShape())); + getResult().setType(RankedTensorType::get(dims, arrayTy.getElementType())); +} + +mlir::LogicalResult TransposeOp::verify() { + auto inputType = llvm::dyn_cast(getOperand().getType()); + auto resultType = llvm::dyn_cast(getType()); + if (!inputType || !resultType) + return mlir::success(); + + auto inputShape = inputType.getShape(); + if (!std::equal(inputShape.begin(), inputShape.end(), + resultType.getShape().rbegin())) { + return emitError() + << "expected result shape to be a transpose of the input"; + } + return mlir::success(); +} + +//===----------------------------------------------------------------------===// +// Toy Types +//===----------------------------------------------------------------------===// + +namespace mlir { +namespace toy { +namespace detail { +/// This class represents the internal storage of the Toy `StructType`. +struct StructTypeStorage : public mlir::TypeStorage { + /// The `KeyTy` is a required type that provides an interface for the storage + /// instance. This type will be used when uniquing an instance of the type + /// storage. For our struct type, we will unique each instance structurally on + /// the elements that it contains. + using KeyTy = llvm::ArrayRef; + + /// A constructor for the type storage instance. + StructTypeStorage(llvm::ArrayRef elementTypes) + : elementTypes(elementTypes) {} + + /// Define the comparison function for the key type with the current storage + /// instance. This is used when constructing a new instance to ensure that we + /// haven't already uniqued an instance of the given key. + bool operator==(const KeyTy &key) const { return key == elementTypes; } + + /// Define a hash function for the key type. This is used when uniquing + /// instances of the storage, see the `StructType::get` method. + /// Note: This method isn't necessary as both llvm::ArrayRef and mlir::Type + /// have hash functions available, so we could just omit this entirely. + static llvm::hash_code hashKey(const KeyTy &key) { + return llvm::hash_value(key); + } + + /// Define a construction function for the key type from a set of parameters. + /// These parameters will be provided when constructing the storage instance + /// itself. + /// Note: This method isn't necessary because KeyTy can be directly + /// constructed with the given parameters. + static KeyTy getKey(llvm::ArrayRef elementTypes) { + return KeyTy(elementTypes); + } + + /// Define a construction method for creating a new instance of this storage. + /// This method takes an instance of a storage allocator, and an instance of a + /// `KeyTy`. The given allocator must be used for *all* necessary dynamic + /// allocations used to create the type storage and its internal. + static StructTypeStorage *construct(mlir::TypeStorageAllocator &allocator, + const KeyTy &key) { + // Copy the elements from the provided `KeyTy` into the allocator. + llvm::ArrayRef elementTypes = allocator.copyInto(key); + + // Allocate the storage instance and construct it. + return new (allocator.allocate()) + StructTypeStorage(elementTypes); + } + + /// The following field contains the element types of the struct. + llvm::ArrayRef elementTypes; +}; +} // namespace detail +} // namespace toy +} // namespace mlir + +/// Create an instance of a `StructType` with the given element types. There +/// *must* be at least one element type. +StructType StructType::get(llvm::ArrayRef elementTypes) { + assert(!elementTypes.empty() && "expected at least 1 element type"); + + // Call into a helper 'get' method in 'TypeBase' to get a uniqued instance + // of this type. The first parameter is the context to unique in. The + // parameters after the context are forwarded to the storage instance. + mlir::MLIRContext *ctx = elementTypes.front().getContext(); + return Base::get(ctx, elementTypes); +} + +/// Returns the element types of this struct type. +llvm::ArrayRef StructType::getElementTypes() { + // 'getImpl' returns a pointer to the internal storage instance. + return getImpl()->elementTypes; +} + +/// Parse an instance of a type registered to the toy dialect. +mlir::Type ToyDialect::parseType(mlir::DialectAsmParser &parser) const { + // Parse a struct type in the following form: + // struct-type ::= `struct` `<` type (`,` type)* `>` + + // NOTE: All MLIR parser function return a ParseResult. This is a + // specialization of LogicalResult that auto-converts to a `true` boolean + // value on failure to allow for chaining, but may be used with explicit + // `mlir::failed/mlir::succeeded` as desired. + + // Parse: `struct` `<` + if (parser.parseKeyword("struct") || parser.parseLess()) + return Type(); + + // Parse the element types of the struct. + SmallVector elementTypes; + do { + // Parse the current element type. + SMLoc typeLoc = parser.getCurrentLocation(); + mlir::Type elementType; + if (parser.parseType(elementType)) + return nullptr; + + // Check that the type is either a TensorType or another StructType. + if (!llvm::isa(elementType)) { + parser.emitError(typeLoc, "element type for a struct must either " + "be a TensorType or a StructType, got: ") + << elementType; + return Type(); + } + elementTypes.push_back(elementType); + + // Parse the optional: `,` + } while (succeeded(parser.parseOptionalComma())); + + // Parse: `>` + if (parser.parseGreater()) + return Type(); + return StructType::get(elementTypes); +} + +/// Print an instance of a type registered to the toy dialect. +void ToyDialect::printType(mlir::Type type, + mlir::DialectAsmPrinter &printer) const { + // Currently the only toy type is a struct type. + StructType structType = llvm::cast(type); + + // Print the struct type according to the parser format. + printer << "struct<"; + llvm::interleaveComma(structType.getElementTypes(), printer); + printer << '>'; +} + +//===----------------------------------------------------------------------===// +// TableGen'd op method definitions +//===----------------------------------------------------------------------===// + +#define GET_OP_CLASSES +#include "toy/Ops.cpp.inc" + +//===----------------------------------------------------------------------===// +// ToyDialect +//===----------------------------------------------------------------------===// + +/// Dialect initialization, the instance will be owned by the context. This is +/// the point of registration of types and operations for the dialect. +void ToyDialect::initialize() { + addOperations< +#define GET_OP_LIST +#include "toy/Ops.cpp.inc" + >(); + addInterfaces(); + addTypes(); +} + +mlir::Operation *ToyDialect::materializeConstant(mlir::OpBuilder &builder, + mlir::Attribute value, + mlir::Type type, + mlir::Location loc) { + if (llvm::isa(type)) + return builder.create(loc, type, + llvm::cast(value)); + return builder.create(loc, type, + llvm::cast(value)); +} diff --git a/Ch7/mlir/LowerToAffineLoops.cpp b/Ch7/mlir/LowerToAffineLoops.cpp new file mode 100644 index 0000000..ae4bd98 --- /dev/null +++ b/Ch7/mlir/LowerToAffineLoops.cpp @@ -0,0 +1,385 @@ +//====- LowerToAffineLoops.cpp - Partial lowering from Toy to Affine+Std --===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements a partial lowering of Toy operations to a combination of +// affine loops, memref operations and standard operations. This lowering +// expects that all calls have been inlined, and all shapes have been resolved. +// +//===----------------------------------------------------------------------===// + +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinDialect.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Diagnostics.h" +#include "mlir/IR/DialectRegistry.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/ValueRange.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Support/TypeID.h" +#include "toy/Dialect.h" +#include "toy/Passes.h" + +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/DialectConversion.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/Sequence.h" +#include "llvm/Support/Casting.h" +#include +#include +#include +#include +#include + +using namespace mlir; + +//===----------------------------------------------------------------------===// +// ToyToAffine RewritePatterns +//===----------------------------------------------------------------------===// + +/// Convert the given RankedTensorType into the corresponding MemRefType. +static MemRefType convertTensorToMemRef(RankedTensorType type) { + return MemRefType::get(type.getShape(), type.getElementType()); +} + +/// Insert an allocation and deallocation for the given MemRefType. +static Value insertAllocAndDealloc(MemRefType type, Location loc, + PatternRewriter &rewriter) { + auto alloc = rewriter.create(loc, type); + + // Make sure to allocate at the beginning of the block. + auto *parentBlock = alloc->getBlock(); + alloc->moveBefore(&parentBlock->front()); + + // Make sure to deallocate this alloc at the end of the block. This is fine + // as toy functions have no control flow. + auto dealloc = rewriter.create(loc, alloc); + dealloc->moveBefore(&parentBlock->back()); + return alloc; +} + +/// This defines the function type used to process an iteration of a lowered +/// loop. It takes as input an OpBuilder, an range of memRefOperands +/// corresponding to the operands of the input operation, and the range of loop +/// induction variables for the iteration. It returns a value to store at the +/// current index of the iteration. +using LoopIterationFn = function_ref; + +static void lowerOpToLoops(Operation *op, ValueRange operands, + PatternRewriter &rewriter, + LoopIterationFn processIteration) { + auto tensorType = llvm::cast((*op->result_type_begin())); + auto loc = op->getLoc(); + + // Insert an allocation and deallocation for the result of this operation. + auto memRefType = convertTensorToMemRef(tensorType); + auto alloc = insertAllocAndDealloc(memRefType, loc, rewriter); + + // Create a nest of affine loops, with one loop per dimension of the shape. + // The buildAffineLoopNest function takes a callback that is used to construct + // the body of the innermost loop given a builder, a location and a range of + // loop induction variables. + SmallVector lowerBounds(tensorType.getRank(), /*Value=*/0); + SmallVector steps(tensorType.getRank(), /*Value=*/1); + affine::buildAffineLoopNest( + rewriter, loc, lowerBounds, tensorType.getShape(), steps, + [&](OpBuilder &nestedBuilder, Location loc, ValueRange ivs) { + // Call the processing function with the rewriter, the memref operands, + // and the loop induction variables. This function will return the value + // to store at the current index. + Value valueToStore = processIteration(nestedBuilder, operands, ivs); + nestedBuilder.create(loc, valueToStore, alloc, + ivs); + }); + + // Replace this operation with the generated alloc. + rewriter.replaceOp(op, alloc); +} + +namespace { +//===----------------------------------------------------------------------===// +// ToyToAffine RewritePatterns: Binary operations +//===----------------------------------------------------------------------===// + +template +struct BinaryOpLowering : public ConversionPattern { + BinaryOpLowering(MLIRContext *ctx) + : ConversionPattern(BinaryOp::getOperationName(), 1, ctx) {} + + LogicalResult + matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const final { + auto loc = op->getLoc(); + lowerOpToLoops(op, operands, rewriter, + [loc](OpBuilder &builder, ValueRange memRefOperands, + ValueRange loopIvs) { + // Generate an adaptor for the remapped operands of the + // BinaryOp. This allows for using the nice named accessors + // that are generated by the ODS. + typename BinaryOp::Adaptor binaryAdaptor(memRefOperands); + + // Generate loads for the element of 'lhs' and 'rhs' at the + // inner loop. + auto loadedLhs = builder.create( + loc, binaryAdaptor.getLhs(), loopIvs); + auto loadedRhs = builder.create( + loc, binaryAdaptor.getRhs(), loopIvs); + + // Create the binary operation performed on the loaded + // values. + return builder.create(loc, loadedLhs, + loadedRhs); + }); + return success(); + } +}; +using AddOpLowering = BinaryOpLowering; +using MulOpLowering = BinaryOpLowering; + +//===----------------------------------------------------------------------===// +// ToyToAffine RewritePatterns: Constant operations +//===----------------------------------------------------------------------===// + +struct ConstantOpLowering : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(toy::ConstantOp op, + PatternRewriter &rewriter) const final { + DenseElementsAttr constantValue = op.getValue(); + Location loc = op.getLoc(); + + // When lowering the constant operation, we allocate and assign the constant + // values to a corresponding memref allocation. + auto tensorType = llvm::cast(op.getType()); + auto memRefType = convertTensorToMemRef(tensorType); + auto alloc = insertAllocAndDealloc(memRefType, loc, rewriter); + + // We will be generating constant indices up-to the largest dimension. + // Create these constants up-front to avoid large amounts of redundant + // operations. + auto valueShape = memRefType.getShape(); + SmallVector constantIndices; + + if (!valueShape.empty()) { + for (auto i : llvm::seq( + 0, *std::max_element(valueShape.begin(), valueShape.end()))) + constantIndices.push_back( + rewriter.create(loc, i)); + } else { + // This is the case of a tensor of rank 0. + constantIndices.push_back( + rewriter.create(loc, 0)); + } + + // The constant operation represents a multi-dimensional constant, so we + // will need to generate a store for each of the elements. The following + // functor recursively walks the dimensions of the constant shape, + // generating a store when the recursion hits the base case. + SmallVector indices; + auto valueIt = constantValue.value_begin(); + std::function storeElements = [&](uint64_t dimension) { + // The last dimension is the base case of the recursion, at this point + // we store the element at the given index. + if (dimension == valueShape.size()) { + rewriter.create( + loc, rewriter.create(loc, *valueIt++), alloc, + llvm::ArrayRef(indices)); + return; + } + + // Otherwise, iterate over the current dimension and add the indices to + // the list. + for (uint64_t i = 0, e = valueShape[dimension]; i != e; ++i) { + indices.push_back(constantIndices[i]); + storeElements(dimension + 1); + indices.pop_back(); + } + }; + + // Start the element storing recursion from the first dimension. + storeElements(/*dimension=*/0); + + // Replace this operation with the generated alloc. + rewriter.replaceOp(op, alloc); + return success(); + } +}; + +//===----------------------------------------------------------------------===// +// ToyToAffine RewritePatterns: Func operations +//===----------------------------------------------------------------------===// + +struct FuncOpLowering : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(toy::FuncOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const final { + // We only lower the main function as we expect that all other functions + // have been inlined. + if (op.getName() != "main") + return failure(); + + // Verify that the given main has no inputs and results. + if (op.getNumArguments() || op.getFunctionType().getNumResults()) { + return rewriter.notifyMatchFailure(op, [](Diagnostic &diag) { + diag << "expected 'main' to have 0 inputs and 0 results"; + }); + } + + // Create a new non-toy function, with the same region. + auto func = rewriter.create(op.getLoc(), op.getName(), + op.getFunctionType()); + rewriter.inlineRegionBefore(op.getRegion(), func.getBody(), func.end()); + rewriter.eraseOp(op); + return success(); + } +}; + +//===----------------------------------------------------------------------===// +// ToyToAffine RewritePatterns: Print operations +//===----------------------------------------------------------------------===// + +struct PrintOpLowering : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(toy::PrintOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const final { + // We don't lower "toy.print" in this pass, but we need to update its + // operands. + rewriter.modifyOpInPlace(op, + [&] { op->setOperands(adaptor.getOperands()); }); + return success(); + } +}; + +//===----------------------------------------------------------------------===// +// ToyToAffine RewritePatterns: Return operations +//===----------------------------------------------------------------------===// + +struct ReturnOpLowering : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(toy::ReturnOp op, + PatternRewriter &rewriter) const final { + // During this lowering, we expect that all function calls have been + // inlined. + if (op.hasOperand()) + return failure(); + + // We lower "toy.return" directly to "func.return". + rewriter.replaceOpWithNewOp(op); + return success(); + } +}; + +//===----------------------------------------------------------------------===// +// ToyToAffine RewritePatterns: Transpose operations +//===----------------------------------------------------------------------===// + +struct TransposeOpLowering : public ConversionPattern { + TransposeOpLowering(MLIRContext *ctx) + : ConversionPattern(toy::TransposeOp::getOperationName(), 1, ctx) {} + + LogicalResult + matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const final { + auto loc = op->getLoc(); + lowerOpToLoops(op, operands, rewriter, + [loc](OpBuilder &builder, ValueRange memRefOperands, + ValueRange loopIvs) { + // Generate an adaptor for the remapped operands of the + // TransposeOp. This allows for using the nice named + // accessors that are generated by the ODS. + toy::TransposeOpAdaptor transposeAdaptor(memRefOperands); + Value input = transposeAdaptor.getInput(); + + // Transpose the elements by generating a load from the + // reverse indices. + SmallVector reverseIvs(llvm::reverse(loopIvs)); + return builder.create(loc, input, + reverseIvs); + }); + return success(); + } +}; + +} // namespace + +//===----------------------------------------------------------------------===// +// ToyToAffineLoweringPass +//===----------------------------------------------------------------------===// + +/// This is a partial lowering to affine loops of the toy operations that are +/// computationally intensive (like matmul for example...) while keeping the +/// rest of the code in the Toy dialect. +namespace { +struct ToyToAffineLoweringPass + : public PassWrapper> { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ToyToAffineLoweringPass) + + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + } + void runOnOperation() final; +}; +} // namespace + +void ToyToAffineLoweringPass::runOnOperation() { + // The first thing to define is the conversion target. This will define the + // final target for this lowering. + ConversionTarget target(getContext()); + + // We define the specific operations, or dialects, that are legal targets for + // this lowering. In our case, we are lowering to a combination of the + // `Affine`, `Arith`, `Func`, and `MemRef` dialects. + target.addLegalDialect(); + + // We also define the Toy dialect as Illegal so that the conversion will fail + // if any of these operations are *not* converted. Given that we actually want + // a partial lowering, we explicitly mark the Toy operations that don't want + // to lower, `toy.print`, as `legal`. `toy.print` will still need its operands + // to be updated though (as we convert from TensorType to MemRefType), so we + // only treat it as `legal` if its operands are legal. + target.addIllegalDialect(); + target.addDynamicallyLegalOp([](toy::PrintOp op) { + return llvm::none_of(op->getOperandTypes(), + [](Type type) { return llvm::isa(type); }); + }); + + // Now that the conversion target has been defined, we just need to provide + // the set of patterns that will lower the Toy operations. + RewritePatternSet patterns(&getContext()); + patterns.add( + &getContext()); + + // With the target and rewrite patterns defined, we can now attempt the + // conversion. The conversion will signal failure if any of our `illegal` + // operations were not converted successfully. + if (failed( + applyPartialConversion(getOperation(), target, std::move(patterns)))) + signalPassFailure(); +} + +/// Create a pass for lowering operations in the `Affine` and `Std` dialects, +/// for a subset of the Toy IR (e.g. matmul). +std::unique_ptr mlir::toy::createLowerToAffinePass() { + return std::make_unique(); +} diff --git a/Ch7/mlir/LowerToLLVM.cpp b/Ch7/mlir/LowerToLLVM.cpp new file mode 100644 index 0000000..f91d880 --- /dev/null +++ b/Ch7/mlir/LowerToLLVM.cpp @@ -0,0 +1,241 @@ +//====- LowerToLLVM.cpp - Lowering from Toy+Affine+Std to LLVM ------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements full lowering of Toy operations to LLVM MLIR dialect. +// 'toy.print' is lowered to a loop nest that calls `printf` on each element of +// the input array. The file also sets up the ToyToLLVMLoweringPass. This pass +// lowers the combination of Arithmetic + Affine + SCF + Func dialects to the +// LLVM one: +// +// Affine -- +// | +// v +// Arithmetic + Func --> LLVM (Dialect) +// ^ +// | +// 'toy.print' --> Loop (SCF) -- +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/LLVMIR/LLVMAttrs.h" +#include "mlir/Dialect/LLVMIR/LLVMTypes.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Support/TypeID.h" +#include "toy/Dialect.h" +#include "toy/Passes.h" + +#include "mlir/Conversion/AffineToStandard/AffineToStandard.h" +#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h" +#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h" +#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h" +#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVMPass.h" +#include "mlir/Conversion/LLVMCommon/ConversionTarget.h" +#include "mlir/Conversion/LLVMCommon/TypeConverter.h" +#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h" +#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/DialectConversion.h" +#include "llvm/Support/Casting.h" +#include +#include + +using namespace mlir; + +//===----------------------------------------------------------------------===// +// ToyToLLVM RewritePatterns +//===----------------------------------------------------------------------===// + +namespace { +/// Lowers `toy.print` to a loop nest calling `printf` on each of the individual +/// elements of the array. +class PrintOpLowering : public ConversionPattern { +public: + explicit PrintOpLowering(MLIRContext *context) + : ConversionPattern(toy::PrintOp::getOperationName(), 1, context) {} + + LogicalResult + matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + auto *context = rewriter.getContext(); + auto memRefType = llvm::cast((*op->operand_type_begin())); + auto memRefShape = memRefType.getShape(); + auto loc = op->getLoc(); + + ModuleOp parentModule = op->getParentOfType(); + + // Get a symbol reference to the printf function, inserting it if necessary. + auto printfRef = getOrInsertPrintf(rewriter, parentModule); + Value formatSpecifierCst = getOrCreateGlobalString( + loc, rewriter, "frmt_spec", StringRef("%f \0", 4), parentModule); + Value newLineCst = getOrCreateGlobalString( + loc, rewriter, "nl", StringRef("\n\0", 2), parentModule); + + // Create a loop for each of the dimensions within the shape. + SmallVector loopIvs; + for (unsigned i = 0, e = memRefShape.size(); i != e; ++i) { + auto lowerBound = rewriter.create(loc, 0); + auto upperBound = + rewriter.create(loc, memRefShape[i]); + auto step = rewriter.create(loc, 1); + auto loop = + rewriter.create(loc, lowerBound, upperBound, step); + for (Operation &nested : *loop.getBody()) + rewriter.eraseOp(&nested); + loopIvs.push_back(loop.getInductionVar()); + + // Terminate the loop body. + rewriter.setInsertionPointToEnd(loop.getBody()); + + // Insert a newline after each of the inner dimensions of the shape. + if (i != e - 1) + rewriter.create(loc, getPrintfType(context), printfRef, + newLineCst); + rewriter.create(loc); + rewriter.setInsertionPointToStart(loop.getBody()); + } + + // Generate a call to printf for the current element of the loop. + auto printOp = cast(op); + auto elementLoad = + rewriter.create(loc, printOp.getInput(), loopIvs); + rewriter.create( + loc, getPrintfType(context), printfRef, + ArrayRef({formatSpecifierCst, elementLoad})); + + // Notify the rewriter that this operation has been removed. + rewriter.eraseOp(op); + return success(); + } + +private: + /// Create a function declaration for printf, the signature is: + /// * `i32 (i8*, ...)` + static LLVM::LLVMFunctionType getPrintfType(MLIRContext *context) { + auto llvmI32Ty = IntegerType::get(context, 32); + auto llvmPtrTy = LLVM::LLVMPointerType::get(context); + auto llvmFnType = LLVM::LLVMFunctionType::get(llvmI32Ty, llvmPtrTy, + /*isVarArg=*/true); + return llvmFnType; + } + + /// Return a symbol reference to the printf function, inserting it into the + /// module if necessary. + static FlatSymbolRefAttr getOrInsertPrintf(PatternRewriter &rewriter, + ModuleOp module) { + auto *context = module.getContext(); + if (module.lookupSymbol("printf")) + return SymbolRefAttr::get(context, "printf"); + + // Insert the printf function into the body of the parent module. + PatternRewriter::InsertionGuard insertGuard(rewriter); + rewriter.setInsertionPointToStart(module.getBody()); + rewriter.create(module.getLoc(), "printf", + getPrintfType(context)); + return SymbolRefAttr::get(context, "printf"); + } + + /// Return a value representing an access into a global string with the given + /// name, creating the string if necessary. + static Value getOrCreateGlobalString(Location loc, OpBuilder &builder, + StringRef name, StringRef value, + ModuleOp module) { + // Create the global at the entry of the module. + LLVM::GlobalOp global; + if (!(global = module.lookupSymbol(name))) { + OpBuilder::InsertionGuard insertGuard(builder); + builder.setInsertionPointToStart(module.getBody()); + auto type = LLVM::LLVMArrayType::get( + IntegerType::get(builder.getContext(), 8), value.size()); + global = builder.create(loc, type, /*isConstant=*/true, + LLVM::Linkage::Internal, name, + builder.getStringAttr(value), + /*alignment=*/0); + } + + // Get the pointer to the first character in the global string. + Value globalPtr = builder.create(loc, global); + Value cst0 = builder.create(loc, builder.getI64Type(), + builder.getIndexAttr(0)); + return builder.create( + loc, LLVM::LLVMPointerType::get(builder.getContext()), global.getType(), + globalPtr, ArrayRef({cst0, cst0})); + } +}; +} // namespace + +//===----------------------------------------------------------------------===// +// ToyToLLVMLoweringPass +//===----------------------------------------------------------------------===// + +namespace { +struct ToyToLLVMLoweringPass + : public PassWrapper> { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ToyToLLVMLoweringPass) + + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + } + void runOnOperation() final; +}; +} // namespace + +void ToyToLLVMLoweringPass::runOnOperation() { + // The first thing to define is the conversion target. This will define the + // final target for this lowering. For this lowering, we are only targeting + // the LLVM dialect. + LLVMConversionTarget target(getContext()); + target.addLegalOp(); + + // During this lowering, we will also be lowering the MemRef types, that are + // currently being operated on, to a representation in LLVM. To perform this + // conversion we use a TypeConverter as part of the lowering. This converter + // details how one type maps to another. This is necessary now that we will be + // doing more complicated lowerings, involving loop region arguments. + LLVMTypeConverter typeConverter(&getContext()); + + // Now that the conversion target has been defined, we need to provide the + // patterns used for lowering. At this point of the compilation process, we + // have a combination of `toy`, `affine`, and `std` operations. Luckily, there + // are already exists a set of patterns to transform `affine` and `std` + // dialects. These patterns lowering in multiple stages, relying on transitive + // lowerings. Transitive lowering, or A->B->C lowering, is when multiple + // patterns must be applied to fully transform an illegal operation into a + // set of legal ones. + RewritePatternSet patterns(&getContext()); + populateAffineToStdConversionPatterns(patterns); + populateSCFToControlFlowConversionPatterns(patterns); + mlir::arith::populateArithToLLVMConversionPatterns(typeConverter, patterns); + populateFinalizeMemRefToLLVMConversionPatterns(typeConverter, patterns); + cf::populateControlFlowToLLVMConversionPatterns(typeConverter, patterns); + populateFuncToLLVMConversionPatterns(typeConverter, patterns); + + // The only remaining operation to lower from the `toy` dialect, is the + // PrintOp. + patterns.add(&getContext()); + + // We want to completely lower to LLVM, so we use a `FullConversion`. This + // ensures that only legal operations will remain after the conversion. + auto module = getOperation(); + if (failed(applyFullConversion(module, target, std::move(patterns)))) + signalPassFailure(); +} + +/// Create a pass for lowering operations the remaining `Toy` operations, as +/// well as `Affine` and `Std`, to the LLVM dialect for codegen. +std::unique_ptr mlir::toy::createLowerToLLVMPass() { + return std::make_unique(); +} diff --git a/Ch7/mlir/MLIRGen.cpp b/Ch7/mlir/MLIRGen.cpp new file mode 100644 index 0000000..0f8e8df --- /dev/null +++ b/Ch7/mlir/MLIRGen.cpp @@ -0,0 +1,692 @@ +//===- MLIRGen.cpp - MLIR Generation from a Toy AST -----------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements a simple IR generation targeting MLIR from a Module AST +// for the Toy language. +// +//===----------------------------------------------------------------------===// + +#include "toy/MLIRGen.h" +#include "mlir/IR/Block.h" +#include "mlir/IR/Diagnostics.h" +#include "mlir/IR/Value.h" +#include "mlir/Support/LogicalResult.h" +#include "toy/AST.h" +#include "toy/Dialect.h" + +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/Verifier.h" +#include "toy/Lexer.h" + +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/ScopedHashTable.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringMap.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/ADT/Twine.h" +#include "llvm/Support/ErrorHandling.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include + +using namespace mlir::toy; +using namespace toy; + +using llvm::ArrayRef; +using llvm::cast; +using llvm::dyn_cast; +using llvm::isa; +using llvm::ScopedHashTableScope; +using llvm::SmallVector; +using llvm::StringRef; +using llvm::Twine; + +namespace { + +/// Implementation of a simple MLIR emission from the Toy AST. +/// +/// This will emit operations that are specific to the Toy language, preserving +/// the semantics of the language and (hopefully) allow to perform accurate +/// analysis and transformation based on these high level semantics. +class MLIRGenImpl { +public: + MLIRGenImpl(mlir::MLIRContext &context) : builder(&context) {} + + /// Public API: convert the AST for a Toy module (source file) to an MLIR + /// Module operation. + mlir::ModuleOp mlirGen(ModuleAST &moduleAST) { + // We create an empty MLIR module and codegen functions one at a time and + // add them to the module. + theModule = mlir::ModuleOp::create(builder.getUnknownLoc()); + + for (auto &record : moduleAST) { + if (FunctionAST *funcAST = llvm::dyn_cast(record.get())) { + mlir::toy::FuncOp func = mlirGen(*funcAST); + if (!func) + return nullptr; + functionMap.insert({func.getName(), func}); + } else if (StructAST *str = llvm::dyn_cast(record.get())) { + if (failed(mlirGen(*str))) + return nullptr; + } else { + llvm_unreachable("unknown record type"); + } + } + + // Verify the module after we have finished constructing it, this will check + // the structural properties of the IR and invoke any specific verifiers we + // have on the Toy operations. + if (failed(mlir::verify(theModule))) { + theModule.emitError("module verification error"); + return nullptr; + } + + return theModule; + } + +private: + /// A "module" matches a Toy source file: containing a list of functions. + mlir::ModuleOp theModule; + + /// The builder is a helper class to create IR inside a function. The builder + /// is stateful, in particular it keeps an "insertion point": this is where + /// the next operations will be introduced. + mlir::OpBuilder builder; + + /// The symbol table maps a variable name to a value in the current scope. + /// Entering a function creates a new scope, and the function arguments are + /// added to the mapping. When the processing of a function is terminated, the + /// scope is destroyed and the mappings created in this scope are dropped. + llvm::ScopedHashTable> + symbolTable; + using SymbolTableScopeT = + llvm::ScopedHashTableScope>; + + /// A mapping for the functions that have been code generated to MLIR. + llvm::StringMap functionMap; + + /// A mapping for named struct types to the underlying MLIR type and the + /// original AST node. + llvm::StringMap> structMap; + + /// Helper conversion for a Toy AST location to an MLIR location. + mlir::Location loc(const Location &loc) { + return mlir::FileLineColLoc::get(builder.getStringAttr(*loc.file), loc.line, + loc.col); + } + + /// Declare a variable in the current scope, return success if the variable + /// wasn't declared yet. + mlir::LogicalResult declare(VarDeclExprAST &var, mlir::Value value) { + if (symbolTable.count(var.getName())) + return mlir::failure(); + symbolTable.insert(var.getName(), {value, &var}); + return mlir::success(); + } + + /// Create an MLIR type for the given struct. + mlir::LogicalResult mlirGen(StructAST &str) { + if (structMap.count(str.getName())) + return emitError(loc(str.loc())) << "error: struct type with name `" + << str.getName() << "' already exists"; + + auto variables = str.getVariables(); + std::vector elementTypes; + elementTypes.reserve(variables.size()); + for (auto &variable : variables) { + if (variable->getInitVal()) + return emitError(loc(variable->loc())) + << "error: variables within a struct definition must not have " + "initializers"; + if (!variable->getType().shape.empty()) + return emitError(loc(variable->loc())) + << "error: variables within a struct definition must not have " + "initializers"; + + mlir::Type type = getType(variable->getType(), variable->loc()); + if (!type) + return mlir::failure(); + elementTypes.push_back(type); + } + + structMap.try_emplace(str.getName(), StructType::get(elementTypes), &str); + return mlir::success(); + } + + /// Create the prototype for an MLIR function with as many arguments as the + /// provided Toy AST prototype. + mlir::toy::FuncOp mlirGen(PrototypeAST &proto) { + auto location = loc(proto.loc()); + + // This is a generic function, the return type will be inferred later. + llvm::SmallVector argTypes; + argTypes.reserve(proto.getArgs().size()); + for (auto &arg : proto.getArgs()) { + mlir::Type type = getType(arg->getType(), arg->loc()); + if (!type) + return nullptr; + argTypes.push_back(type); + } + auto funcType = builder.getFunctionType(argTypes, std::nullopt); + return builder.create(location, proto.getName(), + funcType); + } + + /// Emit a new function and add it to the MLIR module. + mlir::toy::FuncOp mlirGen(FunctionAST &funcAST) { + // Create a scope in the symbol table to hold variable declarations. + SymbolTableScopeT varScope(symbolTable); + + // Create an MLIR function for the given prototype. + builder.setInsertionPointToEnd(theModule.getBody()); + mlir::toy::FuncOp function = mlirGen(*funcAST.getProto()); + if (!function) + return nullptr; + + // Let's start the body of the function now! + mlir::Block &entryBlock = function.front(); + auto protoArgs = funcAST.getProto()->getArgs(); + + // Declare all the function arguments in the symbol table. + for (const auto nameValue : + llvm::zip(protoArgs, entryBlock.getArguments())) { + if (failed(declare(*std::get<0>(nameValue), std::get<1>(nameValue)))) + return nullptr; + } + + // Set the insertion point in the builder to the beginning of the function + // body, it will be used throughout the codegen to create operations in this + // function. + builder.setInsertionPointToStart(&entryBlock); + + // Emit the body of the function. + if (mlir::failed(mlirGen(*funcAST.getBody()))) { + function.erase(); + return nullptr; + } + + // Implicitly return void if no return statement was emitted. + // FIXME: we may fix the parser instead to always return the last expression + // (this would possibly help the REPL case later) + ReturnOp returnOp; + if (!entryBlock.empty()) + returnOp = dyn_cast(entryBlock.back()); + if (!returnOp) { + builder.create(loc(funcAST.getProto()->loc())); + } else if (returnOp.hasOperand()) { + // Otherwise, if this return operation has an operand then add a result to + // the function. + function.setType( + builder.getFunctionType(function.getFunctionType().getInputs(), + *returnOp.operand_type_begin())); + } + + // If this function isn't main, then set the visibility to private. + if (funcAST.getProto()->getName() != "main") + function.setPrivate(); + + return function; + } + + /// Return the struct type that is the result of the given expression, or null + /// if it cannot be inferred. + StructAST *getStructFor(ExprAST *expr) { + llvm::StringRef structName; + if (auto *decl = llvm::dyn_cast(expr)) { + auto varIt = symbolTable.lookup(decl->getName()); + if (!varIt.first) + return nullptr; + structName = varIt.second->getType().name; + } else if (auto *access = llvm::dyn_cast(expr)) { + if (access->getOp() != '.') + return nullptr; + // The name being accessed should be in the RHS. + auto *name = llvm::dyn_cast(access->getRHS()); + if (!name) + return nullptr; + StructAST *parentStruct = getStructFor(access->getLHS()); + if (!parentStruct) + return nullptr; + + // Get the element within the struct corresponding to the name. + VarDeclExprAST *decl = nullptr; + for (auto &var : parentStruct->getVariables()) { + if (var->getName() == name->getName()) { + decl = var.get(); + break; + } + } + if (!decl) + return nullptr; + structName = decl->getType().name; + } + if (structName.empty()) + return nullptr; + + // If the struct name was valid, check for an entry in the struct map. + auto structIt = structMap.find(structName); + if (structIt == structMap.end()) + return nullptr; + return structIt->second.second; + } + + /// Return the numeric member index of the given struct access expression. + std::optional getMemberIndex(BinaryExprAST &accessOp) { + assert(accessOp.getOp() == '.' && "expected access operation"); + + // Lookup the struct node for the LHS. + StructAST *structAST = getStructFor(accessOp.getLHS()); + if (!structAST) + return std::nullopt; + + // Get the name from the RHS. + VariableExprAST *name = llvm::dyn_cast(accessOp.getRHS()); + if (!name) + return std::nullopt; + + auto structVars = structAST->getVariables(); + const auto *it = llvm::find_if(structVars, [&](auto &var) { + return var->getName() == name->getName(); + }); + if (it == structVars.end()) + return std::nullopt; + return it - structVars.begin(); + } + + /// Emit a binary operation + mlir::Value mlirGen(BinaryExprAST &binop) { + // First emit the operations for each side of the operation before emitting + // the operation itself. For example if the expression is `a + foo(a)` + // 1) First it will visiting the LHS, which will return a reference to the + // value holding `a`. This value should have been emitted at declaration + // time and registered in the symbol table, so nothing would be + // codegen'd. If the value is not in the symbol table, an error has been + // emitted and nullptr is returned. + // 2) Then the RHS is visited (recursively) and a call to `foo` is emitted + // and the result value is returned. If an error occurs we get a nullptr + // and propagate. + // + mlir::Value lhs = mlirGen(*binop.getLHS()); + if (!lhs) + return nullptr; + auto location = loc(binop.loc()); + + // If this is an access operation, handle it immediately. + if (binop.getOp() == '.') { + std::optional accessIndex = getMemberIndex(binop); + if (!accessIndex) { + emitError(location, "invalid access into struct expression"); + return nullptr; + } + return builder.create(location, lhs, *accessIndex); + } + + // Otherwise, this is a normal binary op. + mlir::Value rhs = mlirGen(*binop.getRHS()); + if (!rhs) + return nullptr; + + // Derive the operation name from the binary operator. At the moment we only + // support '+' and '*'. + switch (binop.getOp()) { + case '+': + return builder.create(location, lhs, rhs); + case '*': + return builder.create(location, lhs, rhs); + } + + emitError(location, "invalid binary operator '") << binop.getOp() << "'"; + return nullptr; + } + + /// This is a reference to a variable in an expression. The variable is + /// expected to have been declared and so should have a value in the symbol + /// table, otherwise emit an error and return nullptr. + mlir::Value mlirGen(VariableExprAST &expr) { + if (auto variable = symbolTable.lookup(expr.getName()).first) + return variable; + + emitError(loc(expr.loc()), "error: unknown variable '") + << expr.getName() << "'"; + return nullptr; + } + + /// Emit a return operation. This will return failure if any generation fails. + mlir::LogicalResult mlirGen(ReturnExprAST &ret) { + auto location = loc(ret.loc()); + + // 'return' takes an optional expression, handle that case here. + mlir::Value expr = nullptr; + if (ret.getExpr().has_value()) { + if (!(expr = mlirGen(**ret.getExpr()))) + return mlir::failure(); + } + + // Otherwise, this return operation has zero operands. + builder.create(location, + expr ? ArrayRef(expr) : ArrayRef()); + return mlir::success(); + } + + /// Emit a constant for a literal/constant array. It will be emitted as a + /// flattened array of data in an Attribute attached to a `toy.constant` + /// operation. See documentation on [Attributes](LangRef.md#attributes) for + /// more details. Here is an excerpt: + /// + /// Attributes are the mechanism for specifying constant data in MLIR in + /// places where a variable is never allowed [...]. They consist of a name + /// and a concrete attribute value. The set of expected attributes, their + /// structure, and their interpretation are all contextually dependent on + /// what they are attached to. + /// + /// Example, the source level statement: + /// var a<2, 3> = [[1, 2, 3], [4, 5, 6]]; + /// will be converted to: + /// %0 = "toy.constant"() {value: dense, + /// [[1.000000e+00, 2.000000e+00, 3.000000e+00], + /// [4.000000e+00, 5.000000e+00, 6.000000e+00]]>} : () -> tensor<2x3xf64> + /// + mlir::DenseElementsAttr getConstantAttr(LiteralExprAST &lit) { + // The attribute is a vector with a floating point value per element + // (number) in the array, see `collectData()` below for more details. + std::vector data; + data.reserve(std::accumulate(lit.getDims().begin(), lit.getDims().end(), 1, + std::multiplies())); + collectData(lit, data); + + // The type of this attribute is tensor of 64-bit floating-point with the + // shape of the literal. + mlir::Type elementType = builder.getF64Type(); + auto dataType = mlir::RankedTensorType::get(lit.getDims(), elementType); + + // This is the actual attribute that holds the list of values for this + // tensor literal. + return mlir::DenseElementsAttr::get(dataType, llvm::ArrayRef(data)); + } + mlir::DenseElementsAttr getConstantAttr(NumberExprAST &lit) { + // The type of this attribute is tensor of 64-bit floating-point with no + // shape. + mlir::Type elementType = builder.getF64Type(); + auto dataType = mlir::RankedTensorType::get({}, elementType); + + // This is the actual attribute that holds the list of values for this + // tensor literal. + return mlir::DenseElementsAttr::get(dataType, + llvm::ArrayRef(lit.getValue())); + } + /// Emit a constant for a struct literal. It will be emitted as an array of + /// other literals in an Attribute attached to a `toy.struct_constant` + /// operation. This function returns the generated constant, along with the + /// corresponding struct type. + std::pair + getConstantAttr(StructLiteralExprAST &lit) { + std::vector attrElements; + std::vector typeElements; + + for (auto &var : lit.getValues()) { + if (auto *number = llvm::dyn_cast(var.get())) { + attrElements.push_back(getConstantAttr(*number)); + typeElements.push_back(getType(std::nullopt)); + } else if (auto *lit = llvm::dyn_cast(var.get())) { + attrElements.push_back(getConstantAttr(*lit)); + typeElements.push_back(getType(std::nullopt)); + } else { + auto *structLit = llvm::cast(var.get()); + auto attrTypePair = getConstantAttr(*structLit); + attrElements.push_back(attrTypePair.first); + typeElements.push_back(attrTypePair.second); + } + } + mlir::ArrayAttr dataAttr = builder.getArrayAttr(attrElements); + mlir::Type dataType = StructType::get(typeElements); + return std::make_pair(dataAttr, dataType); + } + + /// Emit an array literal. + mlir::Value mlirGen(LiteralExprAST &lit) { + mlir::Type type = getType(lit.getDims()); + mlir::DenseElementsAttr dataAttribute = getConstantAttr(lit); + + // Build the MLIR op `toy.constant`. This invokes the `ConstantOp::build` + // method. + return builder.create(loc(lit.loc()), type, dataAttribute); + } + + /// Emit a struct literal. It will be emitted as an array of + /// other literals in an Attribute attached to a `toy.struct_constant` + /// operation. + mlir::Value mlirGen(StructLiteralExprAST &lit) { + mlir::ArrayAttr dataAttr; + mlir::Type dataType; + std::tie(dataAttr, dataType) = getConstantAttr(lit); + + // Build the MLIR op `toy.struct_constant`. This invokes the + // `StructConstantOp::build` method. + return builder.create(loc(lit.loc()), dataType, dataAttr); + } + + /// Recursive helper function to accumulate the data that compose an array + /// literal. It flattens the nested structure in the supplied vector. For + /// example with this array: + /// [[1, 2], [3, 4]] + /// we will generate: + /// [ 1, 2, 3, 4 ] + /// Individual numbers are represented as doubles. + /// Attributes are the way MLIR attaches constant to operations. + void collectData(ExprAST &expr, std::vector &data) { + if (auto *lit = dyn_cast(&expr)) { + for (auto &value : lit->getValues()) + collectData(*value, data); + return; + } + + assert(isa(expr) && "expected literal or number expr"); + data.push_back(cast(expr).getValue()); + } + + /// Emit a call expression. It emits specific operations for the `transpose` + /// builtin. Other identifiers are assumed to be user-defined functions. + mlir::Value mlirGen(CallExprAST &call) { + llvm::StringRef callee = call.getCallee(); + auto location = loc(call.loc()); + + // Codegen the operands first. + SmallVector operands; + for (auto &expr : call.getArgs()) { + auto arg = mlirGen(*expr); + if (!arg) + return nullptr; + operands.push_back(arg); + } + + // Builtin calls have their custom operation, meaning this is a + // straightforward emission. + if (callee == "transpose") { + if (call.getArgs().size() != 1) { + emitError(location, "MLIR codegen encountered an error: toy.transpose " + "does not accept multiple arguments"); + return nullptr; + } + return builder.create(location, operands[0]); + } + + // Otherwise this is a call to a user-defined function. Calls to + // user-defined functions are mapped to a custom call that takes the callee + // name as an attribute. + auto calledFuncIt = functionMap.find(callee); + if (calledFuncIt == functionMap.end()) { + emitError(location) << "no defined function found for '" << callee << "'"; + return nullptr; + } + mlir::toy::FuncOp calledFunc = calledFuncIt->second; + return builder.create( + location, calledFunc.getFunctionType().getResult(0), + mlir::SymbolRefAttr::get(builder.getContext(), callee), operands); + } + + /// Emit a print expression. It emits specific operations for two builtins: + /// transpose(x) and print(x). + mlir::LogicalResult mlirGen(PrintExprAST &call) { + auto arg = mlirGen(*call.getArg()); + if (!arg) + return mlir::failure(); + + builder.create(loc(call.loc()), arg); + return mlir::success(); + } + + /// Emit a constant for a single number (FIXME: semantic? broadcast?) + mlir::Value mlirGen(NumberExprAST &num) { + return builder.create(loc(num.loc()), num.getValue()); + } + + /// Dispatch codegen for the right expression subclass using RTTI. + mlir::Value mlirGen(ExprAST &expr) { + switch (expr.getKind()) { + case toy::ExprAST::Expr_BinOp: + return mlirGen(cast(expr)); + case toy::ExprAST::Expr_Var: + return mlirGen(cast(expr)); + case toy::ExprAST::Expr_Literal: + return mlirGen(cast(expr)); + case toy::ExprAST::Expr_StructLiteral: + return mlirGen(cast(expr)); + case toy::ExprAST::Expr_Call: + return mlirGen(cast(expr)); + case toy::ExprAST::Expr_Num: + return mlirGen(cast(expr)); + default: + emitError(loc(expr.loc())) + << "MLIR codegen encountered an unhandled expr kind '" + << Twine(expr.getKind()) << "'"; + return nullptr; + } + } + + /// Handle a variable declaration, we'll codegen the expression that forms the + /// initializer and record the value in the symbol table before returning it. + /// Future expressions will be able to reference this variable through symbol + /// table lookup. + mlir::Value mlirGen(VarDeclExprAST &vardecl) { + auto *init = vardecl.getInitVal(); + if (!init) { + emitError(loc(vardecl.loc()), + "missing initializer in variable declaration"); + return nullptr; + } + + mlir::Value value = mlirGen(*init); + if (!value) + return nullptr; + + // Handle the case where we are initializing a struct value. + VarType varType = vardecl.getType(); + if (!varType.name.empty()) { + // Check that the initializer type is the same as the variable + // declaration. + mlir::Type type = getType(varType, vardecl.loc()); + if (!type) + return nullptr; + if (type != value.getType()) { + emitError(loc(vardecl.loc())) + << "struct type of initializer is different than the variable " + "declaration. Got " + << value.getType() << ", but expected " << type; + return nullptr; + } + + // Otherwise, we have the initializer value, but in case the variable was + // declared with specific shape, we emit a "reshape" operation. It will + // get optimized out later as needed. + } else if (!varType.shape.empty()) { + value = builder.create(loc(vardecl.loc()), + getType(varType.shape), value); + } + + // Register the value in the symbol table. + if (failed(declare(vardecl, value))) + return nullptr; + return value; + } + + /// Codegen a list of expression, return failure if one of them hit an error. + mlir::LogicalResult mlirGen(ExprASTList &blockAST) { + SymbolTableScopeT varScope(symbolTable); + for (auto &expr : blockAST) { + // Specific handling for variable declarations, return statement, and + // print. These can only appear in block list and not in nested + // expressions. + if (auto *vardecl = dyn_cast(expr.get())) { + if (!mlirGen(*vardecl)) + return mlir::failure(); + continue; + } + if (auto *ret = dyn_cast(expr.get())) + return mlirGen(*ret); + if (auto *print = dyn_cast(expr.get())) { + if (mlir::failed(mlirGen(*print))) + return mlir::success(); + continue; + } + + // Generic expression dispatch codegen. + if (!mlirGen(*expr)) + return mlir::failure(); + } + return mlir::success(); + } + + /// Build a tensor type from a list of shape dimensions. + mlir::Type getType(ArrayRef shape) { + // If the shape is empty, then this type is unranked. + if (shape.empty()) + return mlir::UnrankedTensorType::get(builder.getF64Type()); + + // Otherwise, we use the given shape. + return mlir::RankedTensorType::get(shape, builder.getF64Type()); + } + + /// Build an MLIR type from a Toy AST variable type (forward to the generic + /// getType above for non-struct types). + mlir::Type getType(const VarType &type, const Location &location) { + if (!type.name.empty()) { + auto it = structMap.find(type.name); + if (it == structMap.end()) { + emitError(loc(location)) + << "error: unknown struct type '" << type.name << "'"; + return nullptr; + } + return it->second.first; + } + + return getType(type.shape); + } +}; + +} // namespace + +namespace toy { + +// The public API for codegen. +mlir::OwningOpRef mlirGen(mlir::MLIRContext &context, + ModuleAST &moduleAST) { + return MLIRGenImpl(context).mlirGen(moduleAST); +} + +} // namespace toy diff --git a/Ch7/mlir/ShapeInferencePass.cpp b/Ch7/mlir/ShapeInferencePass.cpp new file mode 100644 index 0000000..a9e995e --- /dev/null +++ b/Ch7/mlir/ShapeInferencePass.cpp @@ -0,0 +1,122 @@ +//===- ShapeInferencePass.cpp - Shape Inference ---------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements a Function level pass performing interprocedural +// propagation of array shapes through function specialization. +// +//===----------------------------------------------------------------------===// + +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/Types.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/TypeID.h" +#include "toy/Dialect.h" +#include "toy/Passes.h" +#include "toy/ShapeInferenceInterface.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallPtrSet.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/raw_ostream.h" +#include + +#define DEBUG_TYPE "shape-inference" + +using namespace mlir; +using namespace toy; + +/// Include the auto-generated definitions for the shape inference interfaces. +#include "toy/ShapeInferenceOpInterfaces.cpp.inc" + +namespace { +/// The ShapeInferencePass is a pass that performs intra-procedural +/// shape inference. +/// +/// Algorithm: +/// +/// 1) Build a worklist containing all the operations that return a +/// dynamically shaped tensor: these are the operations that need shape +/// inference. +/// 2) Iterate on the worklist: +/// a) find an operation to process: the next ready operation in the +/// worklist has all of its arguments non-generic, +/// b) if no operation is found, break out of the loop, +/// c) remove the operation from the worklist, +/// d) infer the shape of its output from the argument types. +/// 3) If the worklist is empty, the algorithm succeeded. +/// +struct ShapeInferencePass + : public mlir::PassWrapper> { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ShapeInferencePass) + + void runOnOperation() override { + auto f = getOperation(); + + // Populate the worklist with the operations that need shape inference: + // these are operations that return a dynamic shape. + llvm::SmallPtrSet opWorklist; + f.walk([&](mlir::Operation *op) { + if (returnsDynamicShape(op)) + opWorklist.insert(op); + }); + + // Iterate on the operations in the worklist until all operations have been + // inferred or no change happened (fix point). + while (!opWorklist.empty()) { + // Find the next operation ready for inference, that is an operation + // with all operands already resolved (non-generic). + auto nextop = llvm::find_if(opWorklist, allOperandsInferred); + if (nextop == opWorklist.end()) + break; + + Operation *op = *nextop; + opWorklist.erase(op); + + // Ask the operation to infer its output shapes. + LLVM_DEBUG(llvm::dbgs() << "Inferring shape for: " << *op << "\n"); + if (auto shapeOp = dyn_cast(op)) { + shapeOp.inferShapes(); + } else { + op->emitError("unable to infer shape of operation without shape " + "inference interface"); + return signalPassFailure(); + } + } + + // If the operation worklist isn't empty, this indicates a failure. + if (!opWorklist.empty()) { + f.emitError("Shape inference failed, ") + << opWorklist.size() << " operations couldn't be inferred\n"; + signalPassFailure(); + } + } + + /// A utility method that returns if the given operation has all of its + /// operands inferred. + static bool allOperandsInferred(Operation *op) { + return llvm::all_of(op->getOperandTypes(), [](Type operandType) { + return llvm::isa(operandType); + }); + } + + /// A utility method that returns if the given operation has a dynamically + /// shaped result. + static bool returnsDynamicShape(Operation *op) { + return llvm::any_of(op->getResultTypes(), [](Type resultType) { + return !llvm::isa(resultType); + }); + } +}; +} // namespace + +/// Create a Shape Inference pass. +std::unique_ptr mlir::toy::createShapeInferencePass() { + return std::make_unique(); +} diff --git a/Ch7/mlir/ToyCombine.cpp b/Ch7/mlir/ToyCombine.cpp new file mode 100644 index 0000000..72f5e4b --- /dev/null +++ b/Ch7/mlir/ToyCombine.cpp @@ -0,0 +1,90 @@ +//===- ToyCombine.cpp - Toy High Level Optimizer --------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements a set of simple combiners for optimizing operations in +// the Toy dialect. +// +//===----------------------------------------------------------------------===// + +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/Value.h" +#include "mlir/Support/LogicalResult.h" +#include "toy/Dialect.h" +#include "llvm/Support/Casting.h" +#include +using namespace mlir; +using namespace toy; + +namespace { +/// Include the patterns defined in the Declarative Rewrite framework. +#include "ToyCombine.inc" +} // namespace + +/// Fold constants. +OpFoldResult ConstantOp::fold(FoldAdaptor adaptor) { return getValue(); } + +/// Fold struct constants. +OpFoldResult StructConstantOp::fold(FoldAdaptor adaptor) { return getValue(); } + +/// Fold simple struct access operations that access into a constant. +OpFoldResult StructAccessOp::fold(FoldAdaptor adaptor) { + auto structAttr = + llvm::dyn_cast_if_present(adaptor.getInput()); + if (!structAttr) + return nullptr; + + size_t elementIndex = getIndex(); + return structAttr[elementIndex]; +} + +/// This is an example of a c++ rewrite pattern for the TransposeOp. It +/// optimizes the following scenario: transpose(transpose(x)) -> x +struct SimplifyRedundantTranspose : public mlir::OpRewritePattern { + /// We register this pattern to match every toy.transpose in the IR. + /// The "benefit" is used by the framework to order the patterns and process + /// them in order of profitability. + SimplifyRedundantTranspose(mlir::MLIRContext *context) + : OpRewritePattern(context, /*benefit=*/1) {} + + /// This method attempts to match a pattern and rewrite it. The rewriter + /// argument is the orchestrator of the sequence of rewrites. The pattern is + /// expected to interact with it to perform any changes to the IR from here. + mlir::LogicalResult + matchAndRewrite(TransposeOp op, + mlir::PatternRewriter &rewriter) const override { + // Look through the input of the current transpose. + mlir::Value transposeInput = op.getOperand(); + TransposeOp transposeInputOp = transposeInput.getDefiningOp(); + + // Input defined by another transpose? If not, no match. + if (!transposeInputOp) + return failure(); + + // Otherwise, we have a redundant transpose. Use the rewriter. + rewriter.replaceOp(op, {transposeInputOp.getOperand()}); + return success(); + } +}; + +/// Register our patterns as "canonicalization" patterns on the TransposeOp so +/// that they can be picked up by the Canonicalization framework. +void TransposeOp::getCanonicalizationPatterns(RewritePatternSet &results, + MLIRContext *context) { + results.add(context); +} + +/// Register our patterns as "canonicalization" patterns on the ReshapeOp so +/// that they can be picked up by the Canonicalization framework. +void ReshapeOp::getCanonicalizationPatterns(RewritePatternSet &results, + MLIRContext *context) { + results.add(context); +} diff --git a/Ch7/mlir/ToyCombine.inc b/Ch7/mlir/ToyCombine.inc new file mode 100644 index 0000000..61c6203 --- /dev/null +++ b/Ch7/mlir/ToyCombine.inc @@ -0,0 +1,176 @@ +/*===- TableGen'erated file -------------------------------------*- C++ -*-===*\ +|* *| +|* Rewriters *| +|* *| +|* Automatically generated file, do not edit! *| +|* From: ToyCombine.td *| +|* *| +\*===----------------------------------------------------------------------===*/ + +/* Generated from: + ToyCombine.td:46 +*/ +struct FoldConstantReshapeOptPattern : public ::mlir::RewritePattern { + FoldConstantReshapeOptPattern(::mlir::MLIRContext *context) + : ::mlir::RewritePattern("toy.reshape", 2, context, {"toy.constant"}) {} + ::mlir::LogicalResult matchAndRewrite(::mlir::Operation *op0, + ::mlir::PatternRewriter &rewriter) const override { + // Variables for capturing values and attributes used while creating ops + ::mlir::DenseElementsAttr arg; + ::mlir::toy::ReshapeOp res; + ::llvm::SmallVector<::mlir::Operation *, 4> tblgen_ops; + + // Match + tblgen_ops.push_back(op0); + auto castedOp0 = ::llvm::dyn_cast<::mlir::toy::ReshapeOp>(op0); (void)castedOp0; + res = castedOp0; + { + auto *op1 = (*castedOp0.getODSOperands(0).begin()).getDefiningOp(); + if (!(op1)){ + return rewriter.notifyMatchFailure(castedOp0, [&](::mlir::Diagnostic &diag) { + diag << "There's no operation that defines operand 0 of castedOp0"; + }); + } + auto castedOp1 = ::llvm::dyn_cast<::mlir::toy::ConstantOp>(op1); (void)castedOp1; + if (!(castedOp1)){ + return rewriter.notifyMatchFailure(op1, [&](::mlir::Diagnostic &diag) { + diag << "castedOp1 is not ::mlir::toy::ConstantOp type"; + }); + } + { + auto tblgen_attr = op1->getAttrOfType<::mlir::DenseElementsAttr>("value");(void)tblgen_attr; + if (!(tblgen_attr)){ + return rewriter.notifyMatchFailure(op1, [&](::mlir::Diagnostic &diag) { + diag << "expected op 'toy.constant' to have attribute 'value' of type '::mlir::DenseElementsAttr'"; + }); + } + arg = tblgen_attr; + } + tblgen_ops.push_back(op1); + } + + // Rewrite + auto odsLoc = rewriter.getFusedLoc({tblgen_ops[0]->getLoc(), tblgen_ops[1]->getLoc()}); (void)odsLoc; + ::llvm::SmallVector<::mlir::Value, 4> tblgen_repl_values; + auto nativeVar_0 = arg.reshape(::llvm::cast((*res.getODSResults(0).begin()).getType())); (void)nativeVar_0; + ::mlir::toy::ConstantOp tblgen_ConstantOp_1; + { + ::llvm::SmallVector<::mlir::Value, 4> tblgen_values; (void)tblgen_values; + ::llvm::SmallVector<::mlir::NamedAttribute, 4> tblgen_attrs; (void)tblgen_attrs; + if (auto tmpAttr = nativeVar_0) { + tblgen_attrs.emplace_back(rewriter.getStringAttr("value"), tmpAttr); + } + ::llvm::SmallVector<::mlir::Type, 4> tblgen_types; (void)tblgen_types; + for (auto v: castedOp0.getODSResults(0)) { + tblgen_types.push_back(v.getType()); + } + tblgen_ConstantOp_1 = rewriter.create<::mlir::toy::ConstantOp>(odsLoc, tblgen_types, tblgen_values, tblgen_attrs); + } + + for (auto v: ::llvm::SmallVector<::mlir::Value, 4>{ tblgen_ConstantOp_1.getODSResults(0) }) { + tblgen_repl_values.push_back(v); + } + + rewriter.replaceOp(op0, tblgen_repl_values); + return ::mlir::success(); + }; +}; + +/* Generated from: + ToyCombine.td:59 +*/ +struct RedundantReshapeOptPattern : public ::mlir::RewritePattern { + RedundantReshapeOptPattern(::mlir::MLIRContext *context) + : ::mlir::RewritePattern("toy.reshape", 1, context, {}) {} + ::mlir::LogicalResult matchAndRewrite(::mlir::Operation *op0, + ::mlir::PatternRewriter &rewriter) const override { + // Variables for capturing values and attributes used while creating ops + ::mlir::Operation::operand_range arg(op0->getOperands()); + ::mlir::toy::ReshapeOp res; + ::llvm::SmallVector<::mlir::Operation *, 4> tblgen_ops; + + // Match + tblgen_ops.push_back(op0); + auto castedOp0 = ::llvm::dyn_cast<::mlir::toy::ReshapeOp>(op0); (void)castedOp0; + res = castedOp0; + arg = castedOp0.getODSOperands(0); + if (!(((*res.getODSResults(0).begin()).getType() == (*arg.begin()).getType()))){ + return rewriter.notifyMatchFailure(op0, [&](::mlir::Diagnostic &diag) { + diag << "entities 'res, arg' failed to satisfy constraint: ''"; + }); + } + + // Rewrite + auto odsLoc = rewriter.getFusedLoc({tblgen_ops[0]->getLoc()}); (void)odsLoc; + ::llvm::SmallVector<::mlir::Value, 4> tblgen_repl_values; + + for (auto v: ::llvm::SmallVector<::mlir::Value, 4>{ arg }) { + tblgen_repl_values.push_back(v); + } + + rewriter.replaceOp(op0, tblgen_repl_values); + return ::mlir::success(); + }; +}; + +/* Generated from: + ToyCombine.td:33 +*/ +struct ReshapeReshapeOptPattern : public ::mlir::RewritePattern { + ReshapeReshapeOptPattern(::mlir::MLIRContext *context) + : ::mlir::RewritePattern("toy.reshape", 2, context, {"toy.reshape"}) {} + ::mlir::LogicalResult matchAndRewrite(::mlir::Operation *op0, + ::mlir::PatternRewriter &rewriter) const override { + // Variables for capturing values and attributes used while creating ops + ::mlir::Operation::operand_range arg(op0->getOperands()); + ::llvm::SmallVector<::mlir::Operation *, 4> tblgen_ops; + + // Match + tblgen_ops.push_back(op0); + auto castedOp0 = ::llvm::dyn_cast<::mlir::toy::ReshapeOp>(op0); (void)castedOp0; + { + auto *op1 = (*castedOp0.getODSOperands(0).begin()).getDefiningOp(); + if (!(op1)){ + return rewriter.notifyMatchFailure(castedOp0, [&](::mlir::Diagnostic &diag) { + diag << "There's no operation that defines operand 0 of castedOp0"; + }); + } + auto castedOp1 = ::llvm::dyn_cast<::mlir::toy::ReshapeOp>(op1); (void)castedOp1; + if (!(castedOp1)){ + return rewriter.notifyMatchFailure(op1, [&](::mlir::Diagnostic &diag) { + diag << "castedOp1 is not ::mlir::toy::ReshapeOp type"; + }); + } + arg = castedOp1.getODSOperands(0); + tblgen_ops.push_back(op1); + } + + // Rewrite + auto odsLoc = rewriter.getFusedLoc({tblgen_ops[0]->getLoc(), tblgen_ops[1]->getLoc()}); (void)odsLoc; + ::llvm::SmallVector<::mlir::Value, 4> tblgen_repl_values; + ::mlir::toy::ReshapeOp tblgen_ReshapeOp_0; + { + ::llvm::SmallVector<::mlir::Value, 4> tblgen_values; (void)tblgen_values; + ::llvm::SmallVector<::mlir::NamedAttribute, 4> tblgen_attrs; (void)tblgen_attrs; + tblgen_values.push_back((*arg.begin())); + ::llvm::SmallVector<::mlir::Type, 4> tblgen_types; (void)tblgen_types; + for (auto v: castedOp0.getODSResults(0)) { + tblgen_types.push_back(v.getType()); + } + tblgen_ReshapeOp_0 = rewriter.create<::mlir::toy::ReshapeOp>(odsLoc, tblgen_types, tblgen_values, tblgen_attrs); + } + + for (auto v: ::llvm::SmallVector<::mlir::Value, 4>{ tblgen_ReshapeOp_0.getODSResults(0) }) { + tblgen_repl_values.push_back(v); + } + + rewriter.replaceOp(op0, tblgen_repl_values); + return ::mlir::success(); + }; +}; + +void LLVM_ATTRIBUTE_UNUSED populateWithGenerated(::mlir::RewritePatternSet &patterns) { + patterns.add(patterns.getContext()); + patterns.add(patterns.getContext()); + patterns.add(patterns.getContext()); +} diff --git a/Ch7/mlir/ToyCombine.td b/Ch7/mlir/ToyCombine.td new file mode 100644 index 0000000..11d7831 --- /dev/null +++ b/Ch7/mlir/ToyCombine.td @@ -0,0 +1,63 @@ +//===- ToyCombine.td - Pattern Match Optimizations for Toy -*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Defines language-specific pattern match optimizations for Toy using +// Declarative Rewrite Rules (DRR) specified using TableGen records. +// +//===----------------------------------------------------------------------===// + +#ifndef TOY_COMBINE +#define TOY_COMBINE + +include "mlir/IR/PatternBase.td" +include "toy/Ops.td" + +/// Note: The DRR definition used for defining patterns is shown below: +/// +/// class Pattern< +/// dag sourcePattern, list resultPatterns, +/// list additionalConstraints = [], +/// dag benefitsAdded = (addBenefit 0) +/// >; + +//===----------------------------------------------------------------------===// +// Basic Pattern-Match and Rewrite +//===----------------------------------------------------------------------===// + +// Reshape(Reshape(x)) = Reshape(x) +def ReshapeReshapeOptPattern : Pat<(ReshapeOp(ReshapeOp $arg)), + (ReshapeOp $arg)>; + +//===----------------------------------------------------------------------===// +// Pattern-Match and Rewrite using Native Code Call +//===----------------------------------------------------------------------===// + +// Native Code Calls may be used for more complex transformations using inline +// C++ and C++ helper functions. + +// Reshape(Constant(x)) = x' +def ReshapeConstant : + NativeCodeCall<"$0.reshape(::llvm::cast($1.getType()))">; +def FoldConstantReshapeOptPattern : Pat< + (ReshapeOp:$res (ConstantOp $arg)), + (ConstantOp (ReshapeConstant $arg, $res))>; + +//===----------------------------------------------------------------------===// +// Pattern-Match and Rewrite with Constraints +//===----------------------------------------------------------------------===// + +// DRR allows for constraint checking when the transformation is conditional +// on operand properties. + +// Reshape(x) = x, where input and output shapes are identical +def TypesAreIdentical : Constraint>; +def RedundantReshapeOptPattern : Pat< + (ReshapeOp:$res $arg), (replaceWithValue $arg), + [(TypesAreIdentical $res, $arg)]>; + +#endif // TOY_COMBINE diff --git a/Ch7/mlir/run.sh b/Ch7/mlir/run.sh new file mode 100644 index 0000000..f592fde --- /dev/null +++ b/Ch7/mlir/run.sh @@ -0,0 +1,2 @@ +mlir-tblgen-18 -gen-rewriters -I /usr/lib/llvm-18/include -I ../include ToyCombine.td > ToyCombine.inc + diff --git a/Ch7/parser/AST.cpp b/Ch7/parser/AST.cpp new file mode 100644 index 0000000..e38a743 --- /dev/null +++ b/Ch7/parser/AST.cpp @@ -0,0 +1,274 @@ +//===- AST.cpp - Helper for printing out the Toy AST ----------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements the AST dump for the Toy language. +// +//===----------------------------------------------------------------------===// + +#include "toy/AST.h" + +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/Twine.h" +#include "llvm/ADT/TypeSwitch.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/raw_ostream.h" +#include + +using namespace toy; + +namespace { + +// RAII helper to manage increasing/decreasing the indentation as we traverse +// the AST +struct Indent { + Indent(int &level) : level(level) { ++level; } + ~Indent() { --level; } + int &level; +}; + +/// Helper class that implement the AST tree traversal and print the nodes along +/// the way. The only data member is the current indentation level. +class ASTDumper { +public: + void dump(ModuleAST *node); + +private: + void dump(const VarType &type); + void dump(VarDeclExprAST *varDecl); + void dump(ExprAST *expr); + void dump(ExprASTList *exprList); + void dump(NumberExprAST *num); + void dump(LiteralExprAST *node); + void dump(StructLiteralExprAST *node); + void dump(VariableExprAST *node); + void dump(ReturnExprAST *node); + void dump(BinaryExprAST *node); + void dump(CallExprAST *node); + void dump(PrintExprAST *node); + void dump(PrototypeAST *node); + void dump(FunctionAST *node); + void dump(StructAST *node); + + // Actually print spaces matching the current indentation level + void indent() { + for (int i = 0; i < curIndent; i++) + llvm::errs() << " "; + } + int curIndent = 0; +}; + +} // namespace + +/// Return a formatted string for the location of any node +template +static std::string loc(T *node) { + const auto &loc = node->loc(); + return (llvm::Twine("@") + *loc.file + ":" + llvm::Twine(loc.line) + ":" + + llvm::Twine(loc.col)) + .str(); +} + +// Helper Macro to bump the indentation level and print the leading spaces for +// the current indentations +#define INDENT() \ + Indent level_(curIndent); \ + indent(); + +/// Dispatch to a generic expressions to the appropriate subclass using RTTI +void ASTDumper::dump(ExprAST *expr) { + llvm::TypeSwitch(expr) + .Case([&](auto *node) { this->dump(node); }) + .Default([&](ExprAST *) { + // No match, fallback to a generic message + INDENT(); + llvm::errs() << "getKind() << ">\n"; + }); +} + +/// A variable declaration is printing the variable name, the type, and then +/// recurse in the initializer value. +void ASTDumper::dump(VarDeclExprAST *varDecl) { + INDENT(); + llvm::errs() << "VarDecl " << varDecl->getName(); + dump(varDecl->getType()); + llvm::errs() << " " << loc(varDecl) << "\n"; + if (auto *initVal = varDecl->getInitVal()) + dump(initVal); +} + +/// A "block", or a list of expression +void ASTDumper::dump(ExprASTList *exprList) { + INDENT(); + llvm::errs() << "Block {\n"; + for (auto &expr : *exprList) + dump(expr.get()); + indent(); + llvm::errs() << "} // Block\n"; +} + +/// A literal number, just print the value. +void ASTDumper::dump(NumberExprAST *num) { + INDENT(); + llvm::errs() << num->getValue() << " " << loc(num) << "\n"; +} + +/// Helper to print recursively a literal. This handles nested array like: +/// [ [ 1, 2 ], [ 3, 4 ] ] +/// We print out such array with the dimensions spelled out at every level: +/// <2,2>[<2>[ 1, 2 ], <2>[ 3, 4 ] ] +void printLitHelper(ExprAST *litOrNum) { + // Inside a literal expression we can have either a number or another literal + if (auto *num = llvm::dyn_cast(litOrNum)) { + llvm::errs() << num->getValue(); + return; + } + auto *literal = llvm::cast(litOrNum); + + // Print the dimension for this literal first + llvm::errs() << "<"; + llvm::interleaveComma(literal->getDims(), llvm::errs()); + llvm::errs() << ">"; + + // Now print the content, recursing on every element of the list + llvm::errs() << "[ "; + llvm::interleaveComma(literal->getValues(), llvm::errs(), + [&](auto &elt) { printLitHelper(elt.get()); }); + llvm::errs() << "]"; +} + +/// Print a literal, see the recursive helper above for the implementation. +void ASTDumper::dump(LiteralExprAST *node) { + INDENT(); + llvm::errs() << "Literal: "; + printLitHelper(node); + llvm::errs() << " " << loc(node) << "\n"; +} + +/// Print a struct literal. +void ASTDumper::dump(StructLiteralExprAST *node) { + INDENT(); + llvm::errs() << "Struct Literal: "; + for (auto &value : node->getValues()) + dump(value.get()); + indent(); + llvm::errs() << " " << loc(node) << "\n"; +} + +/// Print a variable reference (just a name). +void ASTDumper::dump(VariableExprAST *node) { + INDENT(); + llvm::errs() << "var: " << node->getName() << " " << loc(node) << "\n"; +} + +/// Return statement print the return and its (optional) argument. +void ASTDumper::dump(ReturnExprAST *node) { + INDENT(); + llvm::errs() << "Return\n"; + if (node->getExpr().has_value()) + return dump(*node->getExpr()); + { + INDENT(); + llvm::errs() << "(void)\n"; + } +} + +/// Print a binary operation, first the operator, then recurse into LHS and RHS. +void ASTDumper::dump(BinaryExprAST *node) { + INDENT(); + llvm::errs() << "BinOp: " << node->getOp() << " " << loc(node) << "\n"; + dump(node->getLHS()); + dump(node->getRHS()); +} + +/// Print a call expression, first the callee name and the list of args by +/// recursing into each individual argument. +void ASTDumper::dump(CallExprAST *node) { + INDENT(); + llvm::errs() << "Call '" << node->getCallee() << "' [ " << loc(node) << "\n"; + for (auto &arg : node->getArgs()) + dump(arg.get()); + indent(); + llvm::errs() << "]\n"; +} + +/// Print a builtin print call, first the builtin name and then the argument. +void ASTDumper::dump(PrintExprAST *node) { + INDENT(); + llvm::errs() << "Print [ " << loc(node) << "\n"; + dump(node->getArg()); + indent(); + llvm::errs() << "]\n"; +} + +/// Print type: only the shape is printed in between '<' and '>' +void ASTDumper::dump(const VarType &type) { + llvm::errs() << "<"; + if (!type.name.empty()) + llvm::errs() << type.name; + else + llvm::interleaveComma(type.shape, llvm::errs()); + llvm::errs() << ">"; +} + +/// Print a function prototype, first the function name, and then the list of +/// parameters names. +void ASTDumper::dump(PrototypeAST *node) { + INDENT(); + llvm::errs() << "Proto '" << node->getName() << "' " << loc(node) << "\n"; + indent(); + llvm::errs() << "Params: ["; + llvm::interleaveComma(node->getArgs(), llvm::errs(), + [](auto &arg) { llvm::errs() << arg->getName(); }); + llvm::errs() << "]\n"; +} + +/// Print a function, first the prototype and then the body. +void ASTDumper::dump(FunctionAST *node) { + INDENT(); + llvm::errs() << "Function \n"; + dump(node->getProto()); + dump(node->getBody()); +} + +/// Print a struct. +void ASTDumper::dump(StructAST *node) { + INDENT(); + llvm::errs() << "Struct: " << node->getName() << " " << loc(node) << "\n"; + + { + INDENT(); + llvm::errs() << "Variables: [\n"; + for (auto &variable : node->getVariables()) + dump(variable.get()); + indent(); + llvm::errs() << "]\n"; + } +} + +/// Print a module, actually loop over the functions and print them in sequence. +void ASTDumper::dump(ModuleAST *node) { + INDENT(); + llvm::errs() << "Module:\n"; + for (auto &record : *node) { + if (FunctionAST *function = llvm::dyn_cast(record.get())) + dump(function); + else if (StructAST *str = llvm::dyn_cast(record.get())) + dump(str); + else + llvm::errs() << "getKind() << ">\n"; + } +} + +namespace toy { + +// Public API +void dump(ModuleAST &module) { ASTDumper().dump(&module); } + +} // namespace toy diff --git a/Ch7/toyc.cpp b/Ch7/toyc.cpp new file mode 100644 index 0000000..5eb40b7 --- /dev/null +++ b/Ch7/toyc.cpp @@ -0,0 +1,330 @@ +//===- toyc.cpp - The Toy Compiler ----------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements the entry point for the Toy compiler. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Func/Extensions/AllExtensions.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Support/LogicalResult.h" +#include "toy/AST.h" +#include "toy/Dialect.h" +#include "toy/Lexer.h" +#include "toy/MLIRGen.h" +#include "toy/Parser.h" +#include "toy/Passes.h" + +#include "mlir/Dialect/Affine/Passes.h" +#include "mlir/Dialect/LLVMIR/Transforms/Passes.h" +#include "mlir/ExecutionEngine/ExecutionEngine.h" +#include "mlir/ExecutionEngine/OptUtils.h" +#include "mlir/IR/AsmState.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/Verifier.h" +#include "mlir/InitAllDialects.h" +#include "mlir/Parser/Parser.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h" +#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h" +#include "mlir/Target/LLVMIR/Export.h" +#include "mlir/Transforms/Passes.h" + +#include "llvm/ADT/StringRef.h" +#include "llvm/ExecutionEngine/Orc/JITTargetMachineBuilder.h" +#include "llvm/IR/Module.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/ErrorOr.h" +#include "llvm/Support/MemoryBuffer.h" +#include "llvm/Support/SourceMgr.h" +#include "llvm/Support/TargetSelect.h" +#include "llvm/Support/raw_ostream.h" +#include +#include +#include +#include +#include + +using namespace toy; +namespace cl = llvm::cl; + +static cl::opt inputFilename(cl::Positional, + cl::desc(""), + cl::init("-"), + cl::value_desc("filename")); + +namespace { +enum InputType { Toy, MLIR }; +} // namespace +static cl::opt inputType( + "x", cl::init(Toy), cl::desc("Decided the kind of output desired"), + cl::values(clEnumValN(Toy, "toy", "load the input file as a Toy source.")), + cl::values(clEnumValN(MLIR, "mlir", + "load the input file as an MLIR file"))); + +namespace { +enum Action { + None, + DumpAST, + DumpMLIR, + DumpMLIRAffine, + DumpMLIRLLVM, + DumpLLVMIR, + RunJIT +}; +} // namespace +static cl::opt emitAction( + "emit", cl::desc("Select the kind of output desired"), + cl::values(clEnumValN(DumpAST, "ast", "output the AST dump")), + cl::values(clEnumValN(DumpMLIR, "mlir", "output the MLIR dump")), + cl::values(clEnumValN(DumpMLIRAffine, "mlir-affine", + "output the MLIR dump after affine lowering")), + cl::values(clEnumValN(DumpMLIRLLVM, "mlir-llvm", + "output the MLIR dump after llvm lowering")), + cl::values(clEnumValN(DumpLLVMIR, "llvm", "output the LLVM IR dump")), + cl::values( + clEnumValN(RunJIT, "jit", + "JIT the code and run it by invoking the main function"))); + +static cl::opt enableOpt("opt", cl::desc("Enable optimizations")); + +/// Returns a Toy AST resulting from parsing the file or a nullptr on error. +std::unique_ptr parseInputFile(llvm::StringRef filename) { + llvm::ErrorOr> fileOrErr = + llvm::MemoryBuffer::getFileOrSTDIN(filename); + if (std::error_code ec = fileOrErr.getError()) { + llvm::errs() << "Could not open input file: " << ec.message() << "\n"; + return nullptr; + } + auto buffer = fileOrErr.get()->getBuffer(); + LexerBuffer lexer(buffer.begin(), buffer.end(), std::string(filename)); + Parser parser(lexer); + return parser.parseModule(); +} + +int loadMLIR(mlir::MLIRContext &context, + mlir::OwningOpRef &module) { + // Handle '.toy' input to the compiler. + if (inputType != InputType::MLIR && + !llvm::StringRef(inputFilename).ends_with(".mlir")) { + auto moduleAST = parseInputFile(inputFilename); + if (!moduleAST) + return 6; + module = mlirGen(context, *moduleAST); + return !module ? 1 : 0; + } + + // Otherwise, the input is '.mlir'. + llvm::ErrorOr> fileOrErr = + llvm::MemoryBuffer::getFileOrSTDIN(inputFilename); + if (std::error_code ec = fileOrErr.getError()) { + llvm::errs() << "Could not open input file: " << ec.message() << "\n"; + return -1; + } + + // Parse the input mlir. + llvm::SourceMgr sourceMgr; + sourceMgr.AddNewSourceBuffer(std::move(*fileOrErr), llvm::SMLoc()); + module = mlir::parseSourceFile(sourceMgr, &context); + if (!module) { + llvm::errs() << "Error can't load file " << inputFilename << "\n"; + return 3; + } + return 0; +} + +int loadAndProcessMLIR(mlir::MLIRContext &context, + mlir::OwningOpRef &module) { + if (int error = loadMLIR(context, module)) + return error; + + mlir::PassManager pm(module.get()->getName()); + // Apply any generic pass manager command line options and run the pipeline. + if (mlir::failed(mlir::applyPassManagerCLOptions(pm))) + return 4; + + // Check to see what granularity of MLIR we are compiling to. + bool isLoweringToAffine = emitAction >= Action::DumpMLIRAffine; + bool isLoweringToLLVM = emitAction >= Action::DumpMLIRLLVM; + + if (enableOpt || isLoweringToAffine) { + // Inline all functions into main and then delete them. + pm.addPass(mlir::createInlinerPass()); + + // Now that there is only one function, we can infer the shapes of each of + // the operations. + mlir::OpPassManager &optPM = pm.nest(); + optPM.addPass(mlir::createCanonicalizerPass()); + optPM.addPass(mlir::toy::createShapeInferencePass()); + optPM.addPass(mlir::createCanonicalizerPass()); + optPM.addPass(mlir::createCSEPass()); + } + + if (isLoweringToAffine) { + // Partially lower the toy dialect. + pm.addPass(mlir::toy::createLowerToAffinePass()); + + // Add a few cleanups post lowering. + mlir::OpPassManager &optPM = pm.nest(); + optPM.addPass(mlir::createCanonicalizerPass()); + optPM.addPass(mlir::createCSEPass()); + + // Add optimizations if enabled. + if (enableOpt) { + optPM.addPass(mlir::affine::createLoopFusionPass()); + optPM.addPass(mlir::affine::createAffineScalarReplacementPass()); + } + } + + if (isLoweringToLLVM) { + // Finish lowering the toy IR to the LLVM dialect. + pm.addPass(mlir::toy::createLowerToLLVMPass()); + // This is necessary to have line tables emitted and basic + // debugger working. In the future we will add proper debug information + // emission directly from our frontend. + pm.addPass(mlir::LLVM::createDIScopeForLLVMFuncOpPass()); + } + + if (mlir::failed(pm.run(*module))) + return 4; + return 0; +} + +int dumpAST() { + if (inputType == InputType::MLIR) { + llvm::errs() << "Can't dump a Toy AST when the input is MLIR\n"; + return 5; + } + + auto moduleAST = parseInputFile(inputFilename); + if (!moduleAST) + return 1; + + dump(*moduleAST); + return 0; +} + +int dumpLLVMIR(mlir::ModuleOp module) { + // Register the translation to LLVM IR with the MLIR context. + mlir::registerBuiltinDialectTranslation(*module->getContext()); + mlir::registerLLVMDialectTranslation(*module->getContext()); + + // Convert the module to LLVM IR in a new LLVM IR context. + llvm::LLVMContext llvmContext; + auto llvmModule = mlir::translateModuleToLLVMIR(module, llvmContext); + if (!llvmModule) { + llvm::errs() << "Failed to emit LLVM IR\n"; + return -1; + } + + // Initialize LLVM targets. + llvm::InitializeNativeTarget(); + llvm::InitializeNativeTargetAsmPrinter(); + + // Create target machine and configure the LLVM Module + auto tmBuilderOrError = llvm::orc::JITTargetMachineBuilder::detectHost(); + if (!tmBuilderOrError) { + llvm::errs() << "Could not create JITTargetMachineBuilder\n"; + return -1; + } + + auto tmOrError = tmBuilderOrError->createTargetMachine(); + if (!tmOrError) { + llvm::errs() << "Could not create TargetMachine\n"; + return -1; + } + mlir::ExecutionEngine::setupTargetTripleAndDataLayout(llvmModule.get(), + tmOrError.get().get()); + + /// Optionally run an optimization pipeline over the llvm module. + auto optPipeline = mlir::makeOptimizingTransformer( + /*optLevel=*/enableOpt ? 3 : 0, /*sizeLevel=*/0, + /*targetMachine=*/nullptr); + if (auto err = optPipeline(llvmModule.get())) { + llvm::errs() << "Failed to optimize LLVM IR " << err << "\n"; + return -1; + } + llvm::errs() << *llvmModule << "\n"; + return 0; +} + +int runJit(mlir::ModuleOp module) { + // Initialize LLVM targets. + llvm::InitializeNativeTarget(); + llvm::InitializeNativeTargetAsmPrinter(); + + // Register the translation from MLIR to LLVM IR, which must happen before we + // can JIT-compile. + mlir::registerBuiltinDialectTranslation(*module->getContext()); + mlir::registerLLVMDialectTranslation(*module->getContext()); + + // An optimization pipeline to use within the execution engine. + auto optPipeline = mlir::makeOptimizingTransformer( + /*optLevel=*/enableOpt ? 3 : 0, /*sizeLevel=*/0, + /*targetMachine=*/nullptr); + + // Create an MLIR execution engine. The execution engine eagerly JIT-compiles + // the module. + mlir::ExecutionEngineOptions engineOptions; + engineOptions.transformer = optPipeline; + auto maybeEngine = mlir::ExecutionEngine::create(module, engineOptions); + assert(maybeEngine && "failed to construct an execution engine"); + auto &engine = maybeEngine.get(); + + // Invoke the JIT-compiled function. + auto invocationResult = engine->invokePacked("main"); + if (invocationResult) { + llvm::errs() << "JIT invocation failed\n"; + return -1; + } + + return 0; +} + +int main(int argc, char **argv) { + // Register any command line options. + mlir::registerAsmPrinterCLOptions(); + mlir::registerMLIRContextCLOptions(); + mlir::registerPassManagerCLOptions(); + + cl::ParseCommandLineOptions(argc, argv, "toy compiler\n"); + + if (emitAction == Action::DumpAST) + return dumpAST(); + + // If we aren't dumping the AST, then we are compiling with/to MLIR. + mlir::DialectRegistry registry; + mlir::func::registerAllExtensions(registry); + + mlir::MLIRContext context(registry); + // Load our Dialect in this MLIR Context. + context.getOrLoadDialect(); + + mlir::OwningOpRef module; + if (int error = loadAndProcessMLIR(context, module)) + return error; + + // If we aren't exporting to non-mlir, then we are done. + bool isOutputingMLIR = emitAction <= Action::DumpMLIRLLVM; + if (isOutputingMLIR) { + module->dump(); + return 0; + } + + // Check to see if we are compiling to LLVM IR. + if (emitAction == Action::DumpLLVMIR) + return dumpLLVMIR(*module); + + // Otherwise, we must be running the jit. + if (emitAction == Action::RunJIT) + return runJit(*module); + + llvm::errs() << "No action specified (parsing only?), use -emit=\n"; + return -1; +} diff --git a/Examples/Toy/Ch1/ast.toy b/Examples/Toy/Ch1/ast.toy new file mode 100644 index 0000000..4af2d25 --- /dev/null +++ b/Examples/Toy/Ch1/ast.toy @@ -0,0 +1,74 @@ +# RUN: toyc-ch1 %s -emit=ast 2>&1 | FileCheck %s + +# User defined generic function that operates on unknown shaped arguments. +def multiply_transpose(a, b) { + return transpose(a) * transpose(b); +} + +def main() { + # Define a variable `a` with shape <2, 3>, initialized with the literal value. + # The shape is inferred from the supplied literal. + var a = [[1, 2, 3], [4, 5, 6]]; + # b is identical to a, the literal array is implicitly reshaped: defining new + # variables is the way to reshape arrays (element count in literal must match + # the size of specified shape). + var b<2, 3> = [1, 2, 3, 4, 5, 6]; + + # This call will specialize `multiply_transpose` with <2, 3> for both + # arguments and deduce a return type of <3, 2> in initialization of `c`. + var c = multiply_transpose(a, b); + # A second call to `multiply_transpose` with <2, 3> for both arguments will + # reuse the previously specialized and inferred version and return `<3, 2>` + var d = multiply_transpose(b, a); + # A new call with `<3, 2>` for both dimension will trigger another + # specialization of `multiply_transpose`. + var e = multiply_transpose(c, d); + # Finally, calling into `multiply_transpose` with incompatible shapes + # (<2, 3> and <3, 2>) will trigger a shape inference error. + var f = multiply_transpose(a, c); +} + + +# CHECK: Module: +# CHECK-NEXT: Function +# CHECK-NEXT: Proto 'multiply_transpose' @{{.*}}ast.toy:4:1 +# CHECK-NEXT: Params: [a, b] +# CHECK-NEXT: Block { +# CHECK-NEXT: Return +# CHECK-NEXT: BinOp: * @{{.*}}ast.toy:5:25 +# CHECK-NEXT: Call 'transpose' [ @{{.*}}ast.toy:5:10 +# CHECK-NEXT: var: a @{{.*}}ast.toy:5:20 +# CHECK-NEXT: ] +# CHECK-NEXT: Call 'transpose' [ @{{.*}}ast.toy:5:25 +# CHECK-NEXT: var: b @{{.*}}ast.toy:5:35 +# CHECK-NEXT: ] +# CHECK-NEXT: } // Block +# CHECK-NEXT: Function +# CHECK-NEXT: Proto 'main' @{{.*}}ast.toy:8:1 +# CHECK-NEXT: Params: [] +# CHECK-NEXT: Block { +# CHECK-NEXT: VarDecl a<> @{{.*}}ast.toy:11:3 +# CHECK-NEXT: Literal: <2, 3>[ <3>[ 1.000000e+00, 2.000000e+00, 3.000000e+00], <3>[ 4.000000e+00, 5.000000e+00, 6.000000e+00]] @{{.*}}ast.toy:11:11 +# CHECK-NEXT: VarDecl b<2, 3> @{{.*}}ast.toy:15:3 +# CHECK-NEXT: Literal: <6>[ 1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 5.000000e+00, 6.000000e+00] @{{.*}}ast.toy:15:17 +# CHECK-NEXT: VarDecl c<> @{{.*}}ast.toy:19:3 +# CHECK-NEXT: Call 'multiply_transpose' [ @{{.*}}ast.toy:19:11 +# CHECK-NEXT: var: a @{{.*}}ast.toy:19:30 +# CHECK-NEXT: var: b @{{.*}}ast.toy:19:33 +# CHECK-NEXT: ] +# CHECK-NEXT: VarDecl d<> @{{.*}}ast.toy:22:3 +# CHECK-NEXT: Call 'multiply_transpose' [ @{{.*}}ast.toy:22:11 +# CHECK-NEXT: var: b @{{.*}}ast.toy:22:30 +# CHECK-NEXT: var: a @{{.*}}ast.toy:22:33 +# CHECK-NEXT: ] +# CHECK-NEXT: VarDecl e<> @{{.*}}ast.toy:25:3 +# CHECK-NEXT: Call 'multiply_transpose' [ @{{.*}}ast.toy:25:11 +# CHECK-NEXT: var: c @{{.*}}ast.toy:25:30 +# CHECK-NEXT: var: d @{{.*}}ast.toy:25:33 +# CHECK-NEXT: ] +# CHECK-NEXT: VarDecl f<> @{{.*}}ast.toy:28:3 +# CHECK-NEXT: Call 'multiply_transpose' [ @{{.*}}ast.toy:28:11 +# CHECK-NEXT: var: a @{{.*}}ast.toy:28:30 +# CHECK-NEXT: var: c @{{.*}}ast.toy:28:33 +# CHECK-NEXT: ] +# CHECK-NEXT: } // Block diff --git a/Examples/Toy/Ch1/empty.toy b/Examples/Toy/Ch1/empty.toy new file mode 100644 index 0000000..1e1e83a --- /dev/null +++ b/Examples/Toy/Ch1/empty.toy @@ -0,0 +1,3 @@ +# RUN: toyc-ch1 %s -emit=ast 2>&1 | FileCheck %s +# CHECK-NOT: Assert +# CHECK: Parse error diff --git a/Examples/Toy/Ch2/ast.toy b/Examples/Toy/Ch2/ast.toy new file mode 100644 index 0000000..48bd443 --- /dev/null +++ b/Examples/Toy/Ch2/ast.toy @@ -0,0 +1,76 @@ +# RUN: toyc-ch2 %s -emit=ast 2>&1 | FileCheck %s + +# User defined generic function that operates on unknown shaped arguments. +def multiply_transpose(a, b) { + return transpose(a) * transpose(b); +} + +def main() { + # Define a variable `a` with shape <2, 3>, initialized with the literal value. + # The shape is inferred from the supplied literal. + var a = [[1, 2, 3], [4, 5, 6]]; + # b is identical to a, the literal array is implicitly reshaped: defining new + # variables is the way to reshape arrays (element count in literal must match + # the size of specified shape). + var b<2, 3> = [1, 2, 3, 4, 5, 6]; + + # This call will specialize `multiply_transpose` with <2, 3> for both + # arguments and deduce a return type of <2, 2> in initialization of `c`. + var c = multiply_transpose(a, b); + # A second call to `multiply_transpose` with <2, 3> for both arguments will + # reuse the previously specialized and inferred version and return `<2, 2>` + var d = multiply_transpose(b, a); + # A new call with `<2, 2>` for both dimension will trigger another + # specialization of `multiply_transpose`. + var e = multiply_transpose(b, c); + # Finally, calling into `multiply_transpose` with incompatible shape will + # trigger a shape inference error. + var f = multiply_transpose(transpose(a), c); +} + + +# CHECK: Module: +# CHECK-NEXT: Function +# CHECK-NEXT: Proto 'multiply_transpose' @{{.*}}ast.toy:4:1 +# CHECK-NEXT: Params: [a, b] +# CHECK-NEXT: Block { +# CHECK-NEXT: Return +# CHECK-NEXT: BinOp: * @{{.*}}ast.toy:5:25 +# CHECK-NEXT: Call 'transpose' [ @{{.*}}ast.toy:5:10 +# CHECK-NEXT: var: a @{{.*}}ast.toy:5:20 +# CHECK-NEXT: ] +# CHECK-NEXT: Call 'transpose' [ @{{.*}}ast.toy:5:25 +# CHECK-NEXT: var: b @{{.*}}ast.toy:5:35 +# CHECK-NEXT: ] +# CHECK-NEXT: } // Block +# CHECK-NEXT: Function +# CHECK-NEXT: Proto 'main' @{{.*}}ast.toy:8:1 +# CHECK-NEXT: Params: [] +# CHECK-NEXT: Block { +# CHECK-NEXT: VarDecl a<> @{{.*}}ast.toy:11:3 +# CHECK-NEXT: Literal: <2, 3>[ <3>[ 1.000000e+00, 2.000000e+00, 3.000000e+00], <3>[ 4.000000e+00, 5.000000e+00, 6.000000e+00]] @{{.*}}ast.toy:11:11 +# CHECK-NEXT: VarDecl b<2, 3> @{{.*}}ast.toy:15:3 +# CHECK-NEXT: Literal: <6>[ 1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 5.000000e+00, 6.000000e+00] @{{.*}}ast.toy:15:17 +# CHECK-NEXT: VarDecl c<> @{{.*}}ast.toy:19:3 +# CHECK-NEXT: Call 'multiply_transpose' [ @{{.*}}ast.toy:19:11 +# CHECK-NEXT: var: a @{{.*}}ast.toy:19:30 +# CHECK-NEXT: var: b @{{.*}}ast.toy:19:33 +# CHECK-NEXT: ] +# CHECK-NEXT: VarDecl d<> @{{.*}}ast.toy:22:3 +# CHECK-NEXT: Call 'multiply_transpose' [ @{{.*}}ast.toy:22:11 +# CHECK-NEXT: var: b @{{.*}}ast.toy:22:30 +# CHECK-NEXT: var: a @{{.*}}ast.toy:22:33 +# CHECK-NEXT: ] +# CHECK-NEXT: VarDecl e<> @{{.*}}ast.toy:25:3 +# CHECK-NEXT: Call 'multiply_transpose' [ @{{.*}}ast.toy:25:11 +# CHECK-NEXT: var: b @{{.*}}ast.toy:25:30 +# CHECK-NEXT: var: c @{{.*}}ast.toy:25:33 +# CHECK-NEXT: ] +# CHECK-NEXT: VarDecl f<> @{{.*}}ast.toy:28:3 +# CHECK-NEXT: Call 'multiply_transpose' [ @{{.*}}ast.toy:28:11 +# CHECK-NEXT: Call 'transpose' [ @{{.*}}ast.toy:28:30 +# CHECK-NEXT: var: a @{{.*}}ast.toy:28:40 +# CHECK-NEXT: ] +# CHECK-NEXT: var: c @{{.*}}ast.toy:28:44 +# CHECK-NEXT: ] + diff --git a/Examples/Toy/Ch2/codegen.toy b/Examples/Toy/Ch2/codegen.toy new file mode 100644 index 0000000..12178d6 --- /dev/null +++ b/Examples/Toy/Ch2/codegen.toy @@ -0,0 +1,31 @@ +# RUN: toyc-ch2 %s -emit=mlir 2>&1 | FileCheck %s + +# User defined generic function that operates on unknown shaped arguments +def multiply_transpose(a, b) { + return transpose(a) * transpose(b); +} + +def main() { + var a<2, 3> = [[1, 2, 3], [4, 5, 6]]; + var b<2, 3> = [1, 2, 3, 4, 5, 6]; + var c = multiply_transpose(a, b); + var d = multiply_transpose(b, a); + print(d); +} + +# CHECK-LABEL: toy.func @multiply_transpose( +# CHECK-SAME: [[VAL_0:%.*]]: tensor<*xf64>, [[VAL_1:%.*]]: tensor<*xf64>) -> tensor<*xf64> +# CHECK: [[VAL_2:%.*]] = toy.transpose([[VAL_0]] : tensor<*xf64>) to tensor<*xf64> +# CHECK-NEXT: [[VAL_3:%.*]] = toy.transpose([[VAL_1]] : tensor<*xf64>) to tensor<*xf64> +# CHECK-NEXT: [[VAL_4:%.*]] = toy.mul [[VAL_2]], [[VAL_3]] : tensor<*xf64> +# CHECK-NEXT: toy.return [[VAL_4]] : tensor<*xf64> + +# CHECK-LABEL: toy.func @main() +# CHECK-NEXT: [[VAL_5:%.*]] = toy.constant dense<{{\[\[}}1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64> +# CHECK-NEXT: [[VAL_6:%.*]] = toy.reshape([[VAL_5]] : tensor<2x3xf64>) to tensor<2x3xf64> +# CHECK-NEXT: [[VAL_7:%.*]] = toy.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 5.000000e+00, 6.000000e+00]> : tensor<6xf64> +# CHECK-NEXT: [[VAL_8:%.*]] = toy.reshape([[VAL_7]] : tensor<6xf64>) to tensor<2x3xf64> +# CHECK-NEXT: [[VAL_9:%.*]] = toy.generic_call @multiply_transpose([[VAL_6]], [[VAL_8]]) : (tensor<2x3xf64>, tensor<2x3xf64>) -> tensor<*xf64> +# CHECK-NEXT: [[VAL_10:%.*]] = toy.generic_call @multiply_transpose([[VAL_8]], [[VAL_6]]) : (tensor<2x3xf64>, tensor<2x3xf64>) -> tensor<*xf64> +# CHECK-NEXT: toy.print [[VAL_10]] : tensor<*xf64> +# CHECK-NEXT: toy.return diff --git a/Examples/Toy/Ch2/empty.toy b/Examples/Toy/Ch2/empty.toy new file mode 100644 index 0000000..36d092e --- /dev/null +++ b/Examples/Toy/Ch2/empty.toy @@ -0,0 +1,3 @@ +# RUN: toyc-ch2 %s -emit=ast 2>&1 | FileCheck %s +# CHECK-NOT: Assert +# CHECK: Parse error diff --git a/Examples/Toy/Ch2/invalid.mlir b/Examples/Toy/Ch2/invalid.mlir new file mode 100644 index 0000000..b3ff353 --- /dev/null +++ b/Examples/Toy/Ch2/invalid.mlir @@ -0,0 +1,9 @@ +// RUN: not toyc-ch2 %s -emit=mlir 2>&1 + +// The following IR is not "valid": +// - toy.print should not return a value. +// - toy.print should take an argument. +// - There should be a block terminator. +toy.func @main() { + %0 = "toy.print"() : () -> tensor<2x3xf64> +} diff --git a/Examples/Toy/Ch2/scalar.toy b/Examples/Toy/Ch2/scalar.toy new file mode 100644 index 0000000..b109898 --- /dev/null +++ b/Examples/Toy/Ch2/scalar.toy @@ -0,0 +1,14 @@ +# RUN: toyc-ch2 %s -emit=mlir 2>&1 | FileCheck %s + +def main() { + var a<2, 2> = 5.5; + print(a); +} + +# CHECK-LABEL: toy.func @main() { +# CHECK-NEXT: %0 = toy.constant dense<5.500000e+00> : tensor +# CHECK-NEXT: %1 = toy.reshape(%0 : tensor) to tensor<2x2xf64> +# CHECK-NEXT: toy.print %1 : tensor<2x2xf64> +# CHECK-NEXT: toy.return +# CHECK-NEXT: } + diff --git a/Examples/Toy/Ch3/ast.toy b/Examples/Toy/Ch3/ast.toy new file mode 100644 index 0000000..15ac242 --- /dev/null +++ b/Examples/Toy/Ch3/ast.toy @@ -0,0 +1,76 @@ +# RUN: toyc-ch3 %s -emit=ast 2>&1 | FileCheck %s + +# User defined generic function that operates on unknown shaped arguments. +def multiply_transpose(a, b) { + return transpose(a) * transpose(b); +} + +def main() { + # Define a variable `a` with shape <2, 3>, initialized with the literal value. + # The shape is inferred from the supplied literal. + var a = [[1, 2, 3], [4, 5, 6]]; + # b is identical to a, the literal array is implicitly reshaped: defining new + # variables is the way to reshape arrays (element count in literal must match + # the size of specified shape). + var b<2, 3> = [1, 2, 3, 4, 5, 6]; + + # This call will specialize `multiply_transpose` with <2, 3> for both + # arguments and deduce a return type of <2, 2> in initialization of `c`. + var c = multiply_transpose(a, b); + # A second call to `multiply_transpose` with <2, 3> for both arguments will + # reuse the previously specialized and inferred version and return `<2, 2>` + var d = multiply_transpose(b, a); + # A new call with `<2, 2>` for both dimension will trigger another + # specialization of `multiply_transpose`. + var e = multiply_transpose(b, c); + # Finally, calling into `multiply_transpose` with incompatible shape will + # trigger a shape inference error. + var f = multiply_transpose(transpose(a), c); +} + + +# CHECK: Module: +# CHECK-NEXT: Function +# CHECK-NEXT: Proto 'multiply_transpose' @{{.*}}ast.toy:4:1 +# CHECK-NEXT: Params: [a, b] +# CHECK-NEXT: Block { +# CHECK-NEXT: Return +# CHECK-NEXT: BinOp: * @{{.*}}ast.toy:5:25 +# CHECK-NEXT: Call 'transpose' [ @{{.*}}ast.toy:5:10 +# CHECK-NEXT: var: a @{{.*}}ast.toy:5:20 +# CHECK-NEXT: ] +# CHECK-NEXT: Call 'transpose' [ @{{.*}}ast.toy:5:25 +# CHECK-NEXT: var: b @{{.*}}ast.toy:5:35 +# CHECK-NEXT: ] +# CHECK-NEXT: } // Block +# CHECK-NEXT: Function +# CHECK-NEXT: Proto 'main' @{{.*}}ast.toy:8:1 +# CHECK-NEXT: Params: [] +# CHECK-NEXT: Block { +# CHECK-NEXT: VarDecl a<> @{{.*}}ast.toy:11:3 +# CHECK-NEXT: Literal: <2, 3>[ <3>[ 1.000000e+00, 2.000000e+00, 3.000000e+00], <3>[ 4.000000e+00, 5.000000e+00, 6.000000e+00]] @{{.*}}ast.toy:11:11 +# CHECK-NEXT: VarDecl b<2, 3> @{{.*}}ast.toy:15:3 +# CHECK-NEXT: Literal: <6>[ 1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 5.000000e+00, 6.000000e+00] @{{.*}}ast.toy:15:17 +# CHECK-NEXT: VarDecl c<> @{{.*}}ast.toy:19:3 +# CHECK-NEXT: Call 'multiply_transpose' [ @{{.*}}ast.toy:19:11 +# CHECK-NEXT: var: a @{{.*}}ast.toy:19:30 +# CHECK-NEXT: var: b @{{.*}}ast.toy:19:33 +# CHECK-NEXT: ] +# CHECK-NEXT: VarDecl d<> @{{.*}}ast.toy:22:3 +# CHECK-NEXT: Call 'multiply_transpose' [ @{{.*}}ast.toy:22:11 +# CHECK-NEXT: var: b @{{.*}}ast.toy:22:30 +# CHECK-NEXT: var: a @{{.*}}ast.toy:22:33 +# CHECK-NEXT: ] +# CHECK-NEXT: VarDecl e<> @{{.*}}ast.toy:25:3 +# CHECK-NEXT: Call 'multiply_transpose' [ @{{.*}}ast.toy:25:11 +# CHECK-NEXT: var: b @{{.*}}ast.toy:25:30 +# CHECK-NEXT: var: c @{{.*}}ast.toy:25:33 +# CHECK-NEXT: ] +# CHECK-NEXT: VarDecl f<> @{{.*}}ast.toy:28:3 +# CHECK-NEXT: Call 'multiply_transpose' [ @{{.*}}ast.toy:28:11 +# CHECK-NEXT: Call 'transpose' [ @{{.*}}ast.toy:28:30 +# CHECK-NEXT: var: a @{{.*}}ast.toy:28:40 +# CHECK-NEXT: ] +# CHECK-NEXT: var: c @{{.*}}ast.toy:28:44 +# CHECK-NEXT: ] + diff --git a/Examples/Toy/Ch3/codegen.toy b/Examples/Toy/Ch3/codegen.toy new file mode 100644 index 0000000..b5d4c14 --- /dev/null +++ b/Examples/Toy/Ch3/codegen.toy @@ -0,0 +1,31 @@ +# RUN: toyc-ch3 %s -emit=mlir 2>&1 | FileCheck %s + +# User defined generic function that operates on unknown shaped arguments +def multiply_transpose(a, b) { + return transpose(a) * transpose(b); +} + +def main() { + var a<2, 3> = [[1, 2, 3], [4, 5, 6]]; + var b<2, 3> = [1, 2, 3, 4, 5, 6]; + var c = multiply_transpose(a, b); + var d = multiply_transpose(b, a); + print(d); +} + +# CHECK-LABEL: toy.func @multiply_transpose( +# CHECK-SAME: [[VAL_0:%.*]]: tensor<*xf64>, [[VAL_1:%.*]]: tensor<*xf64>) -> tensor<*xf64> +# CHECK: [[VAL_2:%.*]] = toy.transpose([[VAL_0]] : tensor<*xf64>) to tensor<*xf64> +# CHECK-NEXT: [[VAL_3:%.*]] = toy.transpose([[VAL_1]] : tensor<*xf64>) to tensor<*xf64> +# CHECK-NEXT: [[VAL_4:%.*]] = toy.mul [[VAL_2]], [[VAL_3]] : tensor<*xf64> +# CHECK-NEXT: toy.return [[VAL_4]] : tensor<*xf64> + +# CHECK-LABEL: toy.func @main() +# CHECK-NEXT: [[VAL_5:%.*]] = toy.constant dense<{{\[\[}}1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64> +# CHECK-NEXT: [[VAL_6:%.*]] = toy.reshape([[VAL_5]] : tensor<2x3xf64>) to tensor<2x3xf64> +# CHECK-NEXT: [[VAL_7:%.*]] = toy.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 5.000000e+00, 6.000000e+00]> : tensor<6xf64> +# CHECK-NEXT: [[VAL_8:%.*]] = toy.reshape([[VAL_7]] : tensor<6xf64>) to tensor<2x3xf64> +# CHECK-NEXT: [[VAL_9:%.*]] = toy.generic_call @multiply_transpose([[VAL_6]], [[VAL_8]]) : (tensor<2x3xf64>, tensor<2x3xf64>) -> tensor<*xf64> +# CHECK-NEXT: [[VAL_10:%.*]] = toy.generic_call @multiply_transpose([[VAL_8]], [[VAL_6]]) : (tensor<2x3xf64>, tensor<2x3xf64>) -> tensor<*xf64> +# CHECK-NEXT: toy.print [[VAL_10]] : tensor<*xf64> +# CHECK-NEXT: toy.return diff --git a/Examples/Toy/Ch3/empty.toy b/Examples/Toy/Ch3/empty.toy new file mode 100644 index 0000000..87baab2 --- /dev/null +++ b/Examples/Toy/Ch3/empty.toy @@ -0,0 +1,3 @@ +# RUN: toyc-ch3 %s -emit=ast 2>&1 | FileCheck %s +# CHECK-NOT: Assert +# CHECK: Parse error diff --git a/Examples/Toy/Ch3/invalid.mlir b/Examples/Toy/Ch3/invalid.mlir new file mode 100644 index 0000000..7e251cc --- /dev/null +++ b/Examples/Toy/Ch3/invalid.mlir @@ -0,0 +1,9 @@ +// RUN: not toyc-ch3 %s -emit=mlir 2>&1 + +// The following IR is not "valid": +// - toy.print should not return a value. +// - toy.print should take an argument. +// - There should be a block terminator. +toy.func @main() { + %0 = "toy.print"() : () -> tensor<2x3xf64> +} diff --git a/Examples/Toy/Ch3/scalar.toy b/Examples/Toy/Ch3/scalar.toy new file mode 100644 index 0000000..958808a --- /dev/null +++ b/Examples/Toy/Ch3/scalar.toy @@ -0,0 +1,14 @@ +# RUN: toyc-ch3 %s -emit=mlir 2>&1 | FileCheck %s + +def main() { + var a<2, 2> = 5.5; + print(a); +} + +# CHECK-LABEL: toy.func @main() { +# CHECK-NEXT: %0 = toy.constant dense<5.500000e+00> : tensor +# CHECK-NEXT: %1 = toy.reshape(%0 : tensor) to tensor<2x2xf64> +# CHECK-NEXT: toy.print %1 : tensor<2x2xf64> +# CHECK-NEXT: toy.return +# CHECK-NEXT: } + diff --git a/Examples/Toy/Ch3/transpose_transpose.toy b/Examples/Toy/Ch3/transpose_transpose.toy new file mode 100644 index 0000000..2f13a3f --- /dev/null +++ b/Examples/Toy/Ch3/transpose_transpose.toy @@ -0,0 +1,22 @@ +# RUN: toyc-ch3 %s -emit=mlir -opt 2>&1 | FileCheck %s + +# User defined generic function that operates on unknown shaped arguments +def transpose_transpose(x) { + return transpose(transpose(x)); +} + +def main() { + var a<2, 3> = [[1, 2, 3], [4, 5, 6]]; + var b = transpose_transpose(a); + print(b); +} + +# CHECK-LABEL: toy.func @transpose_transpose( +# CHECK-SAME: [[VAL_0:%.*]]: tensor<*xf64>) -> tensor<*xf64> +# CHECK-NEXT: toy.return [[VAL_0]] : tensor<*xf64> + +# CHECK-LABEL: toy.func @main() +# CHECK-NEXT: [[VAL_1:%.*]] = toy.constant dense<{{\[\[}}1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64> +# CHECK-NEXT: [[VAL_2:%.*]] = toy.generic_call @transpose_transpose([[VAL_1]]) : (tensor<2x3xf64>) -> tensor<*xf64> +# CHECK-NEXT: toy.print [[VAL_2]] : tensor<*xf64> +# CHECK-NEXT: toy.return \ No newline at end of file diff --git a/Examples/Toy/Ch3/trivial_reshape.toy b/Examples/Toy/Ch3/trivial_reshape.toy new file mode 100644 index 0000000..0bdbe22 --- /dev/null +++ b/Examples/Toy/Ch3/trivial_reshape.toy @@ -0,0 +1,16 @@ +# RUN: toyc-ch3 %s -emit=mlir -opt 2>&1 | FileCheck %s + +def main() { + var a<2,1> = [1, 2]; + var b<2,1> = a; + var c<2,1> = b; + print(c); +} + +# CHECK-LABEL: toy.func @main() +# CHECK-NEXT: [[VAL_0:%.*]] = toy.constant +# CHECK-SAME: dense<[ +# CHECK-SAME: [1.000000e+00], [2.000000e+00] +# CHECK-SAME: ]> : tensor<2x1xf64> +# CHECK-NEXT: toy.print [[VAL_0]] : tensor<2x1xf64> +# CHECK-NEXT: toy.return \ No newline at end of file diff --git a/Examples/Toy/Ch4/ast.toy b/Examples/Toy/Ch4/ast.toy new file mode 100644 index 0000000..d665be5 --- /dev/null +++ b/Examples/Toy/Ch4/ast.toy @@ -0,0 +1,76 @@ +# RUN: toyc-ch4 %s -emit=ast 2>&1 | FileCheck %s + +# User defined generic function that operates on unknown shaped arguments. +def multiply_transpose(a, b) { + return transpose(a) * transpose(b); +} + +def main() { + # Define a variable `a` with shape <2, 3>, initialized with the literal value. + # The shape is inferred from the supplied literal. + var a = [[1, 2, 3], [4, 5, 6]]; + # b is identical to a, the literal array is implicitly reshaped: defining new + # variables is the way to reshape arrays (element count in literal must match + # the size of specified shape). + var b<2, 3> = [1, 2, 3, 4, 5, 6]; + + # This call will specialize `multiply_transpose` with <2, 3> for both + # arguments and deduce a return type of <2, 2> in initialization of `c`. + var c = multiply_transpose(a, b); + # A second call to `multiply_transpose` with <2, 3> for both arguments will + # reuse the previously specialized and inferred version and return `<2, 2>` + var d = multiply_transpose(b, a); + # A new call with `<2, 2>` for both dimension will trigger another + # specialization of `multiply_transpose`. + var e = multiply_transpose(b, c); + # Finally, calling into `multiply_transpose` with incompatible shape will + # trigger a shape inference error. + var f = multiply_transpose(transpose(a), c); +} + + +# CHECK: Module: +# CHECK-NEXT: Function +# CHECK-NEXT: Proto 'multiply_transpose' @{{.*}}ast.toy:4:1 +# CHECK-NEXT: Params: [a, b] +# CHECK-NEXT: Block { +# CHECK-NEXT: Return +# CHECK-NEXT: BinOp: * @{{.*}}ast.toy:5:25 +# CHECK-NEXT: Call 'transpose' [ @{{.*}}ast.toy:5:10 +# CHECK-NEXT: var: a @{{.*}}ast.toy:5:20 +# CHECK-NEXT: ] +# CHECK-NEXT: Call 'transpose' [ @{{.*}}ast.toy:5:25 +# CHECK-NEXT: var: b @{{.*}}ast.toy:5:35 +# CHECK-NEXT: ] +# CHECK-NEXT: } // Block +# CHECK-NEXT: Function +# CHECK-NEXT: Proto 'main' @{{.*}}ast.toy:8:1 +# CHECK-NEXT: Params: [] +# CHECK-NEXT: Block { +# CHECK-NEXT: VarDecl a<> @{{.*}}ast.toy:11:3 +# CHECK-NEXT: Literal: <2, 3>[ <3>[ 1.000000e+00, 2.000000e+00, 3.000000e+00], <3>[ 4.000000e+00, 5.000000e+00, 6.000000e+00]] @{{.*}}ast.toy:11:11 +# CHECK-NEXT: VarDecl b<2, 3> @{{.*}}ast.toy:15:3 +# CHECK-NEXT: Literal: <6>[ 1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 5.000000e+00, 6.000000e+00] @{{.*}}ast.toy:15:17 +# CHECK-NEXT: VarDecl c<> @{{.*}}ast.toy:19:3 +# CHECK-NEXT: Call 'multiply_transpose' [ @{{.*}}ast.toy:19:11 +# CHECK-NEXT: var: a @{{.*}}ast.toy:19:30 +# CHECK-NEXT: var: b @{{.*}}ast.toy:19:33 +# CHECK-NEXT: ] +# CHECK-NEXT: VarDecl d<> @{{.*}}ast.toy:22:3 +# CHECK-NEXT: Call 'multiply_transpose' [ @{{.*}}ast.toy:22:11 +# CHECK-NEXT: var: b @{{.*}}ast.toy:22:30 +# CHECK-NEXT: var: a @{{.*}}ast.toy:22:33 +# CHECK-NEXT: ] +# CHECK-NEXT: VarDecl e<> @{{.*}}ast.toy:25:3 +# CHECK-NEXT: Call 'multiply_transpose' [ @{{.*}}ast.toy:25:11 +# CHECK-NEXT: var: b @{{.*}}ast.toy:25:30 +# CHECK-NEXT: var: c @{{.*}}ast.toy:25:33 +# CHECK-NEXT: ] +# CHECK-NEXT: VarDecl f<> @{{.*}}ast.toy:28:3 +# CHECK-NEXT: Call 'multiply_transpose' [ @{{.*}}ast.toy:28:11 +# CHECK-NEXT: Call 'transpose' [ @{{.*}}ast.toy:28:30 +# CHECK-NEXT: var: a @{{.*}}ast.toy:28:40 +# CHECK-NEXT: ] +# CHECK-NEXT: var: c @{{.*}}ast.toy:28:44 +# CHECK-NEXT: ] + diff --git a/Examples/Toy/Ch4/codegen.toy b/Examples/Toy/Ch4/codegen.toy new file mode 100644 index 0000000..594ddc0 --- /dev/null +++ b/Examples/Toy/Ch4/codegen.toy @@ -0,0 +1,31 @@ +# RUN: toyc-ch4 %s -emit=mlir 2>&1 | FileCheck %s + +# User defined generic function that operates on unknown shaped arguments +def multiply_transpose(a, b) { + return transpose(a) * transpose(b); +} + +def main() { + var a<2, 3> = [[1, 2, 3], [4, 5, 6]]; + var b<2, 3> = [1, 2, 3, 4, 5, 6]; + var c = multiply_transpose(a, b); + var d = multiply_transpose(b, a); + print(d); +} + +# CHECK-LABEL: toy.func private @multiply_transpose( +# CHECK-SAME: [[VAL_0:%.*]]: tensor<*xf64>, [[VAL_1:%.*]]: tensor<*xf64>) -> tensor<*xf64> +# CHECK: [[VAL_2:%.*]] = toy.transpose([[VAL_0]] : tensor<*xf64>) to tensor<*xf64> +# CHECK-NEXT: [[VAL_3:%.*]] = toy.transpose([[VAL_1]] : tensor<*xf64>) to tensor<*xf64> +# CHECK-NEXT: [[VAL_4:%.*]] = toy.mul [[VAL_2]], [[VAL_3]] : tensor<*xf64> +# CHECK-NEXT: toy.return [[VAL_4]] : tensor<*xf64> + +# CHECK-LABEL: toy.func @main() +# CHECK-NEXT: [[VAL_5:%.*]] = toy.constant dense<{{\[\[}}1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64> +# CHECK-NEXT: [[VAL_6:%.*]] = toy.reshape([[VAL_5]] : tensor<2x3xf64>) to tensor<2x3xf64> +# CHECK-NEXT: [[VAL_7:%.*]] = toy.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 5.000000e+00, 6.000000e+00]> : tensor<6xf64> +# CHECK-NEXT: [[VAL_8:%.*]] = toy.reshape([[VAL_7]] : tensor<6xf64>) to tensor<2x3xf64> +# CHECK-NEXT: [[VAL_9:%.*]] = toy.generic_call @multiply_transpose([[VAL_6]], [[VAL_8]]) : (tensor<2x3xf64>, tensor<2x3xf64>) -> tensor<*xf64> +# CHECK-NEXT: [[VAL_10:%.*]] = toy.generic_call @multiply_transpose([[VAL_8]], [[VAL_6]]) : (tensor<2x3xf64>, tensor<2x3xf64>) -> tensor<*xf64> +# CHECK-NEXT: toy.print [[VAL_10]] : tensor<*xf64> +# CHECK-NEXT: toy.return diff --git a/Examples/Toy/Ch4/empty.toy b/Examples/Toy/Ch4/empty.toy new file mode 100644 index 0000000..5ad37f7 --- /dev/null +++ b/Examples/Toy/Ch4/empty.toy @@ -0,0 +1,3 @@ +# RUN: toyc-ch4 %s -emit=ast 2>&1 | FileCheck %s +# CHECK-NOT: Assert +# CHECK: Parse error diff --git a/Examples/Toy/Ch4/invalid.mlir b/Examples/Toy/Ch4/invalid.mlir new file mode 100644 index 0000000..2bdb6ef --- /dev/null +++ b/Examples/Toy/Ch4/invalid.mlir @@ -0,0 +1,9 @@ +// RUN: not toyc-ch4 %s -emit=mlir 2>&1 + +// The following IR is not "valid": +// - toy.print should not return a value. +// - toy.print should take an argument. +// - There should be a block terminator. +toy.func @main() { + %0 = "toy.print"() : () -> tensor<2x3xf64> +} diff --git a/Examples/Toy/Ch4/scalar.toy b/Examples/Toy/Ch4/scalar.toy new file mode 100644 index 0000000..39cc7a6 --- /dev/null +++ b/Examples/Toy/Ch4/scalar.toy @@ -0,0 +1,14 @@ +# RUN: toyc-ch4 %s -emit=mlir 2>&1 | FileCheck %s + +def main() { + var a<2, 2> = 5.5; + print(a); +} + +# CHECK-LABEL: toy.func @main() { +# CHECK-NEXT: %0 = toy.constant dense<5.500000e+00> : tensor +# CHECK-NEXT: %1 = toy.reshape(%0 : tensor) to tensor<2x2xf64> +# CHECK-NEXT: toy.print %1 : tensor<2x2xf64> +# CHECK-NEXT: toy.return +# CHECK-NEXT: } + diff --git a/Examples/Toy/Ch4/shape_inference.mlir b/Examples/Toy/Ch4/shape_inference.mlir new file mode 100644 index 0000000..dbe859e --- /dev/null +++ b/Examples/Toy/Ch4/shape_inference.mlir @@ -0,0 +1,30 @@ +// RUN: toyc-ch4 %s -emit=mlir -opt 2>&1 | FileCheck %s + +// Check the result of inlining+shape inference on an input module. + +toy.func private @multiply_transpose(%arg0: tensor<*xf64>, %arg1: tensor<*xf64>) -> tensor<*xf64> { + %0 = toy.transpose(%arg0 : tensor<*xf64>) to tensor<*xf64> + %1 = toy.transpose(%arg1 : tensor<*xf64>) to tensor<*xf64> + %2 = toy.mul %0, %1 : tensor<*xf64> + toy.return %2 : tensor<*xf64> +} +toy.func @main() { + %0 = toy.constant dense<[[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64> + %1 = toy.reshape(%0 : tensor<2x3xf64>) to tensor<2x3xf64> + %2 = toy.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 5.000000e+00, 6.000000e+00]> : tensor<6xf64> + %3 = toy.reshape(%2 : tensor<6xf64>) to tensor<2x3xf64> + %4 = toy.generic_call @multiply_transpose(%1, %3) : (tensor<2x3xf64>, tensor<2x3xf64>) -> tensor<*xf64> + %5 = toy.generic_call @multiply_transpose(%3, %1) : (tensor<2x3xf64>, tensor<2x3xf64>) -> tensor<*xf64> + toy.print %5 : tensor<*xf64> + toy.return +} + +// CHECK-NOT: toy.func private @multiply_transpose +// CHECK-NOT: tensor<*xf64> + +// CHECK-LABEL: toy.func @main() +// CHECK: [[VAL_0:%.*]] = toy.constant dense<{{\[\[}}1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64> +// CHECK: [[VAL_1:%.*]] = toy.transpose([[VAL_0]] : tensor<2x3xf64>) to tensor<3x2xf64> +// CHECK: [[VAL_2:%.*]] = toy.mul [[VAL_1]], [[VAL_1]] : tensor<3x2xf64> +// CHECK: toy.print [[VAL_2]] : tensor<3x2xf64> +// CHECK: toy.return diff --git a/Examples/Toy/Ch4/transpose_transpose.toy b/Examples/Toy/Ch4/transpose_transpose.toy new file mode 100644 index 0000000..e4f08f5 --- /dev/null +++ b/Examples/Toy/Ch4/transpose_transpose.toy @@ -0,0 +1,17 @@ +# RUN: toyc-ch4 %s -emit=mlir -opt 2>&1 | FileCheck %s + +# User defined generic function that operates on unknown shaped arguments +def transpose_transpose(x) { + return transpose(transpose(x)); +} + +def main() { + var a<2, 3> = [[1, 2, 3], [4, 5, 6]]; + var b = transpose_transpose(a); + print(b); +} + +# CHECK-LABEL: toy.func @main() +# CHECK-NEXT: [[VAL_1:%.*]] = toy.constant dense<{{\[\[}}1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64> +# CHECK-NEXT: toy.print [[VAL_1]] : tensor<2x3xf64> +# CHECK-NEXT: toy.return \ No newline at end of file diff --git a/Examples/Toy/Ch4/trivial_reshape.toy b/Examples/Toy/Ch4/trivial_reshape.toy new file mode 100644 index 0000000..d692991 --- /dev/null +++ b/Examples/Toy/Ch4/trivial_reshape.toy @@ -0,0 +1,16 @@ +# RUN: toyc-ch4 %s -emit=mlir -opt 2>&1 | FileCheck %s + +def main() { + var a<2,1> = [1, 2]; + var b<2,1> = a; + var c<2,1> = b; + print(c); +} + +# CHECK-LABEL: toy.func @main() +# CHECK-NEXT: [[VAL_0:%.*]] = toy.constant +# CHECK-SAME: dense<[ +# CHECK-SAME: [1.000000e+00], [2.000000e+00] +# CHECK-SAME: ]> : tensor<2x1xf64> +# CHECK-NEXT: toy.print [[VAL_0]] : tensor<2x1xf64> +# CHECK-NEXT: toy.return \ No newline at end of file diff --git a/Examples/Toy/Ch5/affine-lowering.mlir b/Examples/Toy/Ch5/affine-lowering.mlir new file mode 100644 index 0000000..034474d --- /dev/null +++ b/Examples/Toy/Ch5/affine-lowering.mlir @@ -0,0 +1,64 @@ +// RUN: toyc-ch5 %s -emit=mlir-affine 2>&1 | FileCheck %s +// RUN: toyc-ch5 %s -emit=mlir-affine -opt 2>&1 | FileCheck %s --check-prefix=OPT + +toy.func @main() { + %0 = toy.constant dense<[[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64> + %2 = toy.transpose(%0 : tensor<2x3xf64>) to tensor<3x2xf64> + %3 = toy.mul %2, %2 : tensor<3x2xf64> + toy.print %3 : tensor<3x2xf64> + toy.return +} + +// CHECK-LABEL: func @main() +// CHECK-DAG: [[VAL_0:%.*]] = arith.constant 1.000000e+00 : f64 +// CHECK-DAG: [[VAL_1:%.*]] = arith.constant 2.000000e+00 : f64 +// CHECK-DAG: [[VAL_2:%.*]] = arith.constant 3.000000e+00 : f64 +// CHECK-DAG: [[VAL_3:%.*]] = arith.constant 4.000000e+00 : f64 +// CHECK-DAG: [[VAL_4:%.*]] = arith.constant 5.000000e+00 : f64 +// CHECK-DAG: [[VAL_5:%.*]] = arith.constant 6.000000e+00 : f64 +// CHECK: [[VAL_6:%.*]] = memref.alloc() : memref<3x2xf64> +// CHECK: [[VAL_7:%.*]] = memref.alloc() : memref<3x2xf64> +// CHECK: [[VAL_8:%.*]] = memref.alloc() : memref<2x3xf64> +// CHECK: affine.store [[VAL_0]], [[VAL_8]][0, 0] : memref<2x3xf64> +// CHECK: affine.store [[VAL_1]], [[VAL_8]][0, 1] : memref<2x3xf64> +// CHECK: affine.store [[VAL_2]], [[VAL_8]][0, 2] : memref<2x3xf64> +// CHECK: affine.store [[VAL_3]], [[VAL_8]][1, 0] : memref<2x3xf64> +// CHECK: affine.store [[VAL_4]], [[VAL_8]][1, 1] : memref<2x3xf64> +// CHECK: affine.store [[VAL_5]], [[VAL_8]][1, 2] : memref<2x3xf64> +// CHECK: affine.for [[VAL_9:%.*]] = 0 to 3 { +// CHECK: affine.for [[VAL_10:%.*]] = 0 to 2 { +// CHECK: [[VAL_11:%.*]] = affine.load [[VAL_8]]{{\[}}[[VAL_10]], [[VAL_9]]] : memref<2x3xf64> +// CHECK: affine.store [[VAL_11]], [[VAL_7]]{{\[}}[[VAL_9]], [[VAL_10]]] : memref<3x2xf64> +// CHECK: affine.for [[VAL_12:%.*]] = 0 to 3 { +// CHECK: affine.for [[VAL_13:%.*]] = 0 to 2 { +// CHECK: [[VAL_14:%.*]] = affine.load [[VAL_7]]{{\[}}[[VAL_12]], [[VAL_13]]] : memref<3x2xf64> +// CHECK: [[VAL_16:%.*]] = arith.mulf [[VAL_14]], [[VAL_14]] : f64 +// CHECK: affine.store [[VAL_16]], [[VAL_6]]{{\[}}[[VAL_12]], [[VAL_13]]] : memref<3x2xf64> +// CHECK: toy.print [[VAL_6]] : memref<3x2xf64> +// CHECK: memref.dealloc [[VAL_8]] : memref<2x3xf64> +// CHECK: memref.dealloc [[VAL_7]] : memref<3x2xf64> +// CHECK: memref.dealloc [[VAL_6]] : memref<3x2xf64> + +// OPT-LABEL: func @main() +// OPT-DAG: [[VAL_0:%.*]] = arith.constant 1.000000e+00 : f64 +// OPT-DAG: [[VAL_1:%.*]] = arith.constant 2.000000e+00 : f64 +// OPT-DAG: [[VAL_2:%.*]] = arith.constant 3.000000e+00 : f64 +// OPT-DAG: [[VAL_3:%.*]] = arith.constant 4.000000e+00 : f64 +// OPT-DAG: [[VAL_4:%.*]] = arith.constant 5.000000e+00 : f64 +// OPT-DAG: [[VAL_5:%.*]] = arith.constant 6.000000e+00 : f64 +// OPT: [[VAL_6:%.*]] = memref.alloc() : memref<3x2xf64> +// OPT: [[VAL_7:%.*]] = memref.alloc() : memref<2x3xf64> +// OPT: affine.store [[VAL_0]], [[VAL_7]][0, 0] : memref<2x3xf64> +// OPT: affine.store [[VAL_1]], [[VAL_7]][0, 1] : memref<2x3xf64> +// OPT: affine.store [[VAL_2]], [[VAL_7]][0, 2] : memref<2x3xf64> +// OPT: affine.store [[VAL_3]], [[VAL_7]][1, 0] : memref<2x3xf64> +// OPT: affine.store [[VAL_4]], [[VAL_7]][1, 1] : memref<2x3xf64> +// OPT: affine.store [[VAL_5]], [[VAL_7]][1, 2] : memref<2x3xf64> +// OPT: affine.for [[VAL_8:%.*]] = 0 to 3 { +// OPT: affine.for [[VAL_9:%.*]] = 0 to 2 { +// OPT: [[VAL_10:%.*]] = affine.load [[VAL_7]]{{\[}}[[VAL_9]], [[VAL_8]]] : memref<2x3xf64> +// OPT: [[VAL_11:%.*]] = arith.mulf [[VAL_10]], [[VAL_10]] : f64 +// OPT: affine.store [[VAL_11]], [[VAL_6]]{{\[}}[[VAL_8]], [[VAL_9]]] : memref<3x2xf64> +// OPT: toy.print [[VAL_6]] : memref<3x2xf64> +// OPT: memref.dealloc [[VAL_7]] : memref<2x3xf64> +// OPT: memref.dealloc [[VAL_6]] : memref<3x2xf64> diff --git a/Examples/Toy/Ch5/ast.toy b/Examples/Toy/Ch5/ast.toy new file mode 100644 index 0000000..9840ea2 --- /dev/null +++ b/Examples/Toy/Ch5/ast.toy @@ -0,0 +1,76 @@ +# RUN: toyc-ch5 %s -emit=ast 2>&1 | FileCheck %s + +# User defined generic function that operates on unknown shaped arguments. +def multiply_transpose(a, b) { + return transpose(a) * transpose(b); +} + +def main() { + # Define a variable `a` with shape <2, 3>, initialized with the literal value. + # The shape is inferred from the supplied literal. + var a = [[1, 2, 3], [4, 5, 6]]; + # b is identical to a, the literal array is implicitly reshaped: defining new + # variables is the way to reshape arrays (element count in literal must match + # the size of specified shape). + var b<2, 3> = [1, 2, 3, 4, 5, 6]; + + # This call will specialize `multiply_transpose` with <2, 3> for both + # arguments and deduce a return type of <2, 2> in initialization of `c`. + var c = multiply_transpose(a, b); + # A second call to `multiply_transpose` with <2, 3> for both arguments will + # reuse the previously specialized and inferred version and return `<2, 2>` + var d = multiply_transpose(b, a); + # A new call with `<2, 2>` for both dimension will trigger another + # specialization of `multiply_transpose`. + var e = multiply_transpose(b, c); + # Finally, calling into `multiply_transpose` with incompatible shape will + # trigger a shape inference error. + var f = multiply_transpose(transpose(a), c); +} + + +# CHECK: Module: +# CHECK-NEXT: Function +# CHECK-NEXT: Proto 'multiply_transpose' @{{.*}}ast.toy:4:1 +# CHECK-NEXT: Params: [a, b] +# CHECK-NEXT: Block { +# CHECK-NEXT: Return +# CHECK-NEXT: BinOp: * @{{.*}}ast.toy:5:25 +# CHECK-NEXT: Call 'transpose' [ @{{.*}}ast.toy:5:10 +# CHECK-NEXT: var: a @{{.*}}ast.toy:5:20 +# CHECK-NEXT: ] +# CHECK-NEXT: Call 'transpose' [ @{{.*}}ast.toy:5:25 +# CHECK-NEXT: var: b @{{.*}}ast.toy:5:35 +# CHECK-NEXT: ] +# CHECK-NEXT: } // Block +# CHECK-NEXT: Function +# CHECK-NEXT: Proto 'main' @{{.*}}ast.toy:8:1 +# CHECK-NEXT: Params: [] +# CHECK-NEXT: Block { +# CHECK-NEXT: VarDecl a<> @{{.*}}ast.toy:11:3 +# CHECK-NEXT: Literal: <2, 3>[ <3>[ 1.000000e+00, 2.000000e+00, 3.000000e+00], <3>[ 4.000000e+00, 5.000000e+00, 6.000000e+00]] @{{.*}}ast.toy:11:11 +# CHECK-NEXT: VarDecl b<2, 3> @{{.*}}ast.toy:15:3 +# CHECK-NEXT: Literal: <6>[ 1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 5.000000e+00, 6.000000e+00] @{{.*}}ast.toy:15:17 +# CHECK-NEXT: VarDecl c<> @{{.*}}ast.toy:19:3 +# CHECK-NEXT: Call 'multiply_transpose' [ @{{.*}}ast.toy:19:11 +# CHECK-NEXT: var: a @{{.*}}ast.toy:19:30 +# CHECK-NEXT: var: b @{{.*}}ast.toy:19:33 +# CHECK-NEXT: ] +# CHECK-NEXT: VarDecl d<> @{{.*}}ast.toy:22:3 +# CHECK-NEXT: Call 'multiply_transpose' [ @{{.*}}ast.toy:22:11 +# CHECK-NEXT: var: b @{{.*}}ast.toy:22:30 +# CHECK-NEXT: var: a @{{.*}}ast.toy:22:33 +# CHECK-NEXT: ] +# CHECK-NEXT: VarDecl e<> @{{.*}}ast.toy:25:3 +# CHECK-NEXT: Call 'multiply_transpose' [ @{{.*}}ast.toy:25:11 +# CHECK-NEXT: var: b @{{.*}}ast.toy:25:30 +# CHECK-NEXT: var: c @{{.*}}ast.toy:25:33 +# CHECK-NEXT: ] +# CHECK-NEXT: VarDecl f<> @{{.*}}ast.toy:28:3 +# CHECK-NEXT: Call 'multiply_transpose' [ @{{.*}}ast.toy:28:11 +# CHECK-NEXT: Call 'transpose' [ @{{.*}}ast.toy:28:30 +# CHECK-NEXT: var: a @{{.*}}ast.toy:28:40 +# CHECK-NEXT: ] +# CHECK-NEXT: var: c @{{.*}}ast.toy:28:44 +# CHECK-NEXT: ] + diff --git a/Examples/Toy/Ch5/codegen.toy b/Examples/Toy/Ch5/codegen.toy new file mode 100644 index 0000000..1010502 --- /dev/null +++ b/Examples/Toy/Ch5/codegen.toy @@ -0,0 +1,31 @@ +# RUN: toyc-ch5 %s -emit=mlir 2>&1 | FileCheck %s + +# User defined generic function that operates on unknown shaped arguments +def multiply_transpose(a, b) { + return transpose(a) * transpose(b); +} + +def main() { + var a<2, 3> = [[1, 2, 3], [4, 5, 6]]; + var b<2, 3> = [1, 2, 3, 4, 5, 6]; + var c = multiply_transpose(a, b); + var d = multiply_transpose(b, a); + print(d); +} + +# CHECK-LABEL: toy.func private @multiply_transpose( +# CHECK-SAME: [[VAL_0:%.*]]: tensor<*xf64>, [[VAL_1:%.*]]: tensor<*xf64>) -> tensor<*xf64> +# CHECK: [[VAL_2:%.*]] = toy.transpose([[VAL_0]] : tensor<*xf64>) to tensor<*xf64> +# CHECK-NEXT: [[VAL_3:%.*]] = toy.transpose([[VAL_1]] : tensor<*xf64>) to tensor<*xf64> +# CHECK-NEXT: [[VAL_4:%.*]] = toy.mul [[VAL_2]], [[VAL_3]] : tensor<*xf64> +# CHECK-NEXT: toy.return [[VAL_4]] : tensor<*xf64> + +# CHECK-LABEL: toy.func @main() +# CHECK-NEXT: [[VAL_5:%.*]] = toy.constant dense<{{\[\[}}1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64> +# CHECK-NEXT: [[VAL_6:%.*]] = toy.reshape([[VAL_5]] : tensor<2x3xf64>) to tensor<2x3xf64> +# CHECK-NEXT: [[VAL_7:%.*]] = toy.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 5.000000e+00, 6.000000e+00]> : tensor<6xf64> +# CHECK-NEXT: [[VAL_8:%.*]] = toy.reshape([[VAL_7]] : tensor<6xf64>) to tensor<2x3xf64> +# CHECK-NEXT: [[VAL_9:%.*]] = toy.generic_call @multiply_transpose([[VAL_6]], [[VAL_8]]) : (tensor<2x3xf64>, tensor<2x3xf64>) -> tensor<*xf64> +# CHECK-NEXT: [[VAL_10:%.*]] = toy.generic_call @multiply_transpose([[VAL_8]], [[VAL_6]]) : (tensor<2x3xf64>, tensor<2x3xf64>) -> tensor<*xf64> +# CHECK-NEXT: toy.print [[VAL_10]] : tensor<*xf64> +# CHECK-NEXT: toy.return diff --git a/Examples/Toy/Ch5/empty.toy b/Examples/Toy/Ch5/empty.toy new file mode 100644 index 0000000..d43f34e --- /dev/null +++ b/Examples/Toy/Ch5/empty.toy @@ -0,0 +1,3 @@ +# RUN: toyc-ch5 %s -emit=ast 2>&1 | FileCheck %s +# CHECK-NOT: Assert +# CHECK: Parse error diff --git a/Examples/Toy/Ch5/invalid.mlir b/Examples/Toy/Ch5/invalid.mlir new file mode 100644 index 0000000..05f818c --- /dev/null +++ b/Examples/Toy/Ch5/invalid.mlir @@ -0,0 +1,9 @@ +// RUN: not toyc-ch5 %s -emit=mlir 2>&1 + +// The following IR is not "valid": +// - toy.print should not return a value. +// - toy.print should take an argument. +// - There should be a block terminator. +toy.func @main() { + %0 = "toy.print"() : () -> tensor<2x3xf64> +} diff --git a/Examples/Toy/Ch5/scalar.toy b/Examples/Toy/Ch5/scalar.toy new file mode 100644 index 0000000..b8f5384 --- /dev/null +++ b/Examples/Toy/Ch5/scalar.toy @@ -0,0 +1,14 @@ +# RUN: toyc-ch5 %s -emit=mlir 2>&1 | FileCheck %s + +def main() { + var a<2, 2> = 5.5; + print(a); +} + +# CHECK-LABEL: func @main() { +# CHECK-NEXT: %0 = toy.constant dense<5.500000e+00> : tensor +# CHECK-NEXT: %1 = toy.reshape(%0 : tensor) to tensor<2x2xf64> +# CHECK-NEXT: toy.print %1 : tensor<2x2xf64> +# CHECK-NEXT: toy.return +# CHECK-NEXT: } + diff --git a/Examples/Toy/Ch5/shape_inference.mlir b/Examples/Toy/Ch5/shape_inference.mlir new file mode 100644 index 0000000..50cc492 --- /dev/null +++ b/Examples/Toy/Ch5/shape_inference.mlir @@ -0,0 +1,30 @@ +// RUN: toyc-ch5 %s -emit=mlir -opt 2>&1 | FileCheck %s + +// Check the result of inlining+shape inference on an input module. + +toy.func private @multiply_transpose(%arg0: tensor<*xf64>, %arg1: tensor<*xf64>) -> tensor<*xf64> { + %0 = toy.transpose(%arg0 : tensor<*xf64>) to tensor<*xf64> + %1 = toy.transpose(%arg1 : tensor<*xf64>) to tensor<*xf64> + %2 = toy.mul %0, %1 : tensor<*xf64> + toy.return %2 : tensor<*xf64> +} +toy.func @main() { + %0 = toy.constant dense<[[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64> + %1 = toy.reshape(%0 : tensor<2x3xf64>) to tensor<2x3xf64> + %2 = toy.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 5.000000e+00, 6.000000e+00]> : tensor<6xf64> + %3 = toy.reshape(%2 : tensor<6xf64>) to tensor<2x3xf64> + %4 = toy.generic_call @multiply_transpose(%1, %3) : (tensor<2x3xf64>, tensor<2x3xf64>) -> tensor<*xf64> + %5 = toy.generic_call @multiply_transpose(%3, %1) : (tensor<2x3xf64>, tensor<2x3xf64>) -> tensor<*xf64> + toy.print %5 : tensor<*xf64> + toy.return +} + +// CHECK-NOT: toy.func @multiply_transpose +// CHECK-NOT: tensor<*xf64> + +// CHECK-LABEL: toy.func @main() +// CHECK: [[VAL_0:%.*]] = toy.constant dense<{{\[\[}}1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64> +// CHECK: [[VAL_1:%.*]] = toy.transpose([[VAL_0]] : tensor<2x3xf64>) to tensor<3x2xf64> +// CHECK: [[VAL_2:%.*]] = toy.mul [[VAL_1]], [[VAL_1]] : tensor<3x2xf64> +// CHECK: toy.print [[VAL_2]] : tensor<3x2xf64> +// CHECK: toy.return diff --git a/Examples/Toy/Ch5/transpose_transpose.toy b/Examples/Toy/Ch5/transpose_transpose.toy new file mode 100644 index 0000000..df74fd4 --- /dev/null +++ b/Examples/Toy/Ch5/transpose_transpose.toy @@ -0,0 +1,17 @@ +# RUN: toyc-ch5 %s -emit=mlir -opt 2>&1 | FileCheck %s + +# User defined generic function that operates on unknown shaped arguments +def transpose_transpose(x) { + return transpose(transpose(x)); +} + +def main() { + var a<2, 3> = [[1, 2, 3], [4, 5, 6]]; + var b = transpose_transpose(a); + print(b); +} + +# CHECK-LABEL: toy.func @main() +# CHECK-NEXT: [[VAL_1:%.*]] = toy.constant dense<{{\[\[}}1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64> +# CHECK-NEXT: toy.print [[VAL_1]] : tensor<2x3xf64> +# CHECK-NEXT: toy.return \ No newline at end of file diff --git a/Examples/Toy/Ch5/trivial_reshape.toy b/Examples/Toy/Ch5/trivial_reshape.toy new file mode 100644 index 0000000..9fee8c6 --- /dev/null +++ b/Examples/Toy/Ch5/trivial_reshape.toy @@ -0,0 +1,16 @@ +# RUN: toyc-ch5 %s -emit=mlir -opt 2>&1 | FileCheck %s + +def main() { + var a<2,1> = [1, 2]; + var b<2,1> = a; + var c<2,1> = b; + print(c); +} + +# CHECK-LABEL: toy.func @main() +# CHECK-NEXT: [[VAL_0:%.*]] = toy.constant +# CHECK-SAME: dense<[ +# CHECK-SAME: [1.000000e+00], [2.000000e+00] +# CHECK-SAME: ]> : tensor<2x1xf64> +# CHECK-NEXT: toy.print [[VAL_0]] : tensor<2x1xf64> +# CHECK-NEXT: toy.return diff --git a/Examples/Toy/Ch6/affine-lowering.mlir b/Examples/Toy/Ch6/affine-lowering.mlir new file mode 100644 index 0000000..51dedaf --- /dev/null +++ b/Examples/Toy/Ch6/affine-lowering.mlir @@ -0,0 +1,64 @@ +// RUN: toyc-ch6 %s -emit=mlir-affine 2>&1 | FileCheck %s +// RUN: toyc-ch6 %s -emit=mlir-affine -opt 2>&1 | FileCheck %s --check-prefix=OPT + +toy.func @main() { + %0 = toy.constant dense<[[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64> + %2 = toy.transpose(%0 : tensor<2x3xf64>) to tensor<3x2xf64> + %3 = toy.mul %2, %2 : tensor<3x2xf64> + toy.print %3 : tensor<3x2xf64> + toy.return +} + +// CHECK-LABEL: func @main() +// CHECK-DAG: [[VAL_0:%.*]] = arith.constant 1.000000e+00 : f64 +// CHECK-DAG: [[VAL_1:%.*]] = arith.constant 2.000000e+00 : f64 +// CHECK-DAG: [[VAL_2:%.*]] = arith.constant 3.000000e+00 : f64 +// CHECK-DAG: [[VAL_3:%.*]] = arith.constant 4.000000e+00 : f64 +// CHECK-DAG: [[VAL_4:%.*]] = arith.constant 5.000000e+00 : f64 +// CHECK-DAG: [[VAL_5:%.*]] = arith.constant 6.000000e+00 : f64 +// CHECK: [[VAL_6:%.*]] = memref.alloc() : memref<3x2xf64> +// CHECK: [[VAL_7:%.*]] = memref.alloc() : memref<3x2xf64> +// CHECK: [[VAL_8:%.*]] = memref.alloc() : memref<2x3xf64> +// CHECK: affine.store [[VAL_0]], [[VAL_8]][0, 0] : memref<2x3xf64> +// CHECK: affine.store [[VAL_1]], [[VAL_8]][0, 1] : memref<2x3xf64> +// CHECK: affine.store [[VAL_2]], [[VAL_8]][0, 2] : memref<2x3xf64> +// CHECK: affine.store [[VAL_3]], [[VAL_8]][1, 0] : memref<2x3xf64> +// CHECK: affine.store [[VAL_4]], [[VAL_8]][1, 1] : memref<2x3xf64> +// CHECK: affine.store [[VAL_5]], [[VAL_8]][1, 2] : memref<2x3xf64> +// CHECK: affine.for [[VAL_9:%.*]] = 0 to 3 { +// CHECK: affine.for [[VAL_10:%.*]] = 0 to 2 { +// CHECK: [[VAL_11:%.*]] = affine.load [[VAL_8]]{{\[}}[[VAL_10]], [[VAL_9]]] : memref<2x3xf64> +// CHECK: affine.store [[VAL_11]], [[VAL_7]]{{\[}}[[VAL_9]], [[VAL_10]]] : memref<3x2xf64> +// CHECK: affine.for [[VAL_12:%.*]] = 0 to 3 { +// CHECK: affine.for [[VAL_13:%.*]] = 0 to 2 { +// CHECK: [[VAL_14:%.*]] = affine.load [[VAL_7]]{{\[}}[[VAL_12]], [[VAL_13]]] : memref<3x2xf64> +// CHECK: [[VAL_16:%.*]] = arith.mulf [[VAL_14]], [[VAL_14]] : f64 +// CHECK: affine.store [[VAL_16]], [[VAL_6]]{{\[}}[[VAL_12]], [[VAL_13]]] : memref<3x2xf64> +// CHECK: toy.print [[VAL_6]] : memref<3x2xf64> +// CHECK: memref.dealloc [[VAL_8]] : memref<2x3xf64> +// CHECK: memref.dealloc [[VAL_7]] : memref<3x2xf64> +// CHECK: memref.dealloc [[VAL_6]] : memref<3x2xf64> + +// OPT-LABEL: func @main() +// OPT-DAG: [[VAL_0:%.*]] = arith.constant 1.000000e+00 : f64 +// OPT-DAG: [[VAL_1:%.*]] = arith.constant 2.000000e+00 : f64 +// OPT-DAG: [[VAL_2:%.*]] = arith.constant 3.000000e+00 : f64 +// OPT-DAG: [[VAL_3:%.*]] = arith.constant 4.000000e+00 : f64 +// OPT-DAG: [[VAL_4:%.*]] = arith.constant 5.000000e+00 : f64 +// OPT-DAG: [[VAL_5:%.*]] = arith.constant 6.000000e+00 : f64 +// OPT: [[VAL_6:%.*]] = memref.alloc() : memref<3x2xf64> +// OPT: [[VAL_7:%.*]] = memref.alloc() : memref<2x3xf64> +// OPT: affine.store [[VAL_0]], [[VAL_7]][0, 0] : memref<2x3xf64> +// OPT: affine.store [[VAL_1]], [[VAL_7]][0, 1] : memref<2x3xf64> +// OPT: affine.store [[VAL_2]], [[VAL_7]][0, 2] : memref<2x3xf64> +// OPT: affine.store [[VAL_3]], [[VAL_7]][1, 0] : memref<2x3xf64> +// OPT: affine.store [[VAL_4]], [[VAL_7]][1, 1] : memref<2x3xf64> +// OPT: affine.store [[VAL_5]], [[VAL_7]][1, 2] : memref<2x3xf64> +// OPT: affine.for [[VAL_8:%.*]] = 0 to 3 { +// OPT: affine.for [[VAL_9:%.*]] = 0 to 2 { +// OPT: [[VAL_10:%.*]] = affine.load [[VAL_7]]{{\[}}[[VAL_9]], [[VAL_8]]] : memref<2x3xf64> +// OPT: [[VAL_11:%.*]] = arith.mulf [[VAL_10]], [[VAL_10]] : f64 +// OPT: affine.store [[VAL_11]], [[VAL_6]]{{\[}}[[VAL_8]], [[VAL_9]]] : memref<3x2xf64> +// OPT: toy.print [[VAL_6]] : memref<3x2xf64> +// OPT: memref.dealloc [[VAL_7]] : memref<2x3xf64> +// OPT: memref.dealloc [[VAL_6]] : memref<3x2xf64> diff --git a/Examples/Toy/Ch6/ast.toy b/Examples/Toy/Ch6/ast.toy new file mode 100644 index 0000000..f5fc278 --- /dev/null +++ b/Examples/Toy/Ch6/ast.toy @@ -0,0 +1,76 @@ +# RUN: toyc-ch6 %s -emit=ast 2>&1 | FileCheck %s + +# User defined generic function that operates on unknown shaped arguments. +def multiply_transpose(a, b) { + return transpose(a) * transpose(b); +} + +def main() { + # Define a variable `a` with shape <2, 3>, initialized with the literal value. + # The shape is inferred from the supplied literal. + var a = [[1, 2, 3], [4, 5, 6]]; + # b is identical to a, the literal array is implicitly reshaped: defining new + # variables is the way to reshape arrays (element count in literal must match + # the size of specified shape). + var b<2, 3> = [1, 2, 3, 4, 5, 6]; + + # This call will specialize `multiply_transpose` with <2, 3> for both + # arguments and deduce a return type of <2, 2> in initialization of `c`. + var c = multiply_transpose(a, b); + # A second call to `multiply_transpose` with <2, 3> for both arguments will + # reuse the previously specialized and inferred version and return `<2, 2>` + var d = multiply_transpose(b, a); + # A new call with `<2, 2>` for both dimension will trigger another + # specialization of `multiply_transpose`. + var e = multiply_transpose(b, c); + # Finally, calling into `multiply_transpose` with incompatible shape will + # trigger a shape inference error. + var f = multiply_transpose(transpose(a), c); +} + + +# CHECK: Module: +# CHECK-NEXT: Function +# CHECK-NEXT: Proto 'multiply_transpose' @{{.*}}ast.toy:4:1 +# CHECK-NEXT: Params: [a, b] +# CHECK-NEXT: Block { +# CHECK-NEXT: Return +# CHECK-NEXT: BinOp: * @{{.*}}ast.toy:5:25 +# CHECK-NEXT: Call 'transpose' [ @{{.*}}ast.toy:5:10 +# CHECK-NEXT: var: a @{{.*}}ast.toy:5:20 +# CHECK-NEXT: ] +# CHECK-NEXT: Call 'transpose' [ @{{.*}}ast.toy:5:25 +# CHECK-NEXT: var: b @{{.*}}ast.toy:5:35 +# CHECK-NEXT: ] +# CHECK-NEXT: } // Block +# CHECK-NEXT: Function +# CHECK-NEXT: Proto 'main' @{{.*}}ast.toy:8:1 +# CHECK-NEXT: Params: [] +# CHECK-NEXT: Block { +# CHECK-NEXT: VarDecl a<> @{{.*}}ast.toy:11:3 +# CHECK-NEXT: Literal: <2, 3>[ <3>[ 1.000000e+00, 2.000000e+00, 3.000000e+00], <3>[ 4.000000e+00, 5.000000e+00, 6.000000e+00]] @{{.*}}ast.toy:11:11 +# CHECK-NEXT: VarDecl b<2, 3> @{{.*}}ast.toy:15:3 +# CHECK-NEXT: Literal: <6>[ 1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 5.000000e+00, 6.000000e+00] @{{.*}}ast.toy:15:17 +# CHECK-NEXT: VarDecl c<> @{{.*}}ast.toy:19:3 +# CHECK-NEXT: Call 'multiply_transpose' [ @{{.*}}ast.toy:19:11 +# CHECK-NEXT: var: a @{{.*}}ast.toy:19:30 +# CHECK-NEXT: var: b @{{.*}}ast.toy:19:33 +# CHECK-NEXT: ] +# CHECK-NEXT: VarDecl d<> @{{.*}}ast.toy:22:3 +# CHECK-NEXT: Call 'multiply_transpose' [ @{{.*}}ast.toy:22:11 +# CHECK-NEXT: var: b @{{.*}}ast.toy:22:30 +# CHECK-NEXT: var: a @{{.*}}ast.toy:22:33 +# CHECK-NEXT: ] +# CHECK-NEXT: VarDecl e<> @{{.*}}ast.toy:25:3 +# CHECK-NEXT: Call 'multiply_transpose' [ @{{.*}}ast.toy:25:11 +# CHECK-NEXT: var: b @{{.*}}ast.toy:25:30 +# CHECK-NEXT: var: c @{{.*}}ast.toy:25:33 +# CHECK-NEXT: ] +# CHECK-NEXT: VarDecl f<> @{{.*}}ast.toy:28:3 +# CHECK-NEXT: Call 'multiply_transpose' [ @{{.*}}ast.toy:28:11 +# CHECK-NEXT: Call 'transpose' [ @{{.*}}ast.toy:28:30 +# CHECK-NEXT: var: a @{{.*}}ast.toy:28:40 +# CHECK-NEXT: ] +# CHECK-NEXT: var: c @{{.*}}ast.toy:28:44 +# CHECK-NEXT: ] + diff --git a/Examples/Toy/Ch6/codegen.toy b/Examples/Toy/Ch6/codegen.toy new file mode 100644 index 0000000..1d121cd --- /dev/null +++ b/Examples/Toy/Ch6/codegen.toy @@ -0,0 +1,31 @@ +# RUN: toyc-ch6 %s -emit=mlir 2>&1 | FileCheck %s + +# User defined generic function that operates on unknown shaped arguments +def multiply_transpose(a, b) { + return transpose(a) * transpose(b); +} + +def main() { + var a<2, 3> = [[1, 2, 3], [4, 5, 6]]; + var b<2, 3> = [1, 2, 3, 4, 5, 6]; + var c = multiply_transpose(a, b); + var d = multiply_transpose(b, a); + print(d); +} + +# CHECK-LABEL: toy.func private @multiply_transpose( +# CHECK-SAME: [[VAL_0:%.*]]: tensor<*xf64>, [[VAL_1:%.*]]: tensor<*xf64>) -> tensor<*xf64> +# CHECK: [[VAL_2:%.*]] = toy.transpose([[VAL_0]] : tensor<*xf64>) to tensor<*xf64> +# CHECK-NEXT: [[VAL_3:%.*]] = toy.transpose([[VAL_1]] : tensor<*xf64>) to tensor<*xf64> +# CHECK-NEXT: [[VAL_4:%.*]] = toy.mul [[VAL_2]], [[VAL_3]] : tensor<*xf64> +# CHECK-NEXT: toy.return [[VAL_4]] : tensor<*xf64> + +# CHECK-LABEL: toy.func @main() +# CHECK-NEXT: [[VAL_5:%.*]] = toy.constant dense<{{\[\[}}1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64> +# CHECK-NEXT: [[VAL_6:%.*]] = toy.reshape([[VAL_5]] : tensor<2x3xf64>) to tensor<2x3xf64> +# CHECK-NEXT: [[VAL_7:%.*]] = toy.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 5.000000e+00, 6.000000e+00]> : tensor<6xf64> +# CHECK-NEXT: [[VAL_8:%.*]] = toy.reshape([[VAL_7]] : tensor<6xf64>) to tensor<2x3xf64> +# CHECK-NEXT: [[VAL_9:%.*]] = toy.generic_call @multiply_transpose([[VAL_6]], [[VAL_8]]) : (tensor<2x3xf64>, tensor<2x3xf64>) -> tensor<*xf64> +# CHECK-NEXT: [[VAL_10:%.*]] = toy.generic_call @multiply_transpose([[VAL_8]], [[VAL_6]]) : (tensor<2x3xf64>, tensor<2x3xf64>) -> tensor<*xf64> +# CHECK-NEXT: toy.print [[VAL_10]] : tensor<*xf64> +# CHECK-NEXT: toy.return diff --git a/Examples/Toy/Ch6/empty.toy b/Examples/Toy/Ch6/empty.toy new file mode 100644 index 0000000..f221f38 --- /dev/null +++ b/Examples/Toy/Ch6/empty.toy @@ -0,0 +1,3 @@ +# RUN: toyc-ch6 %s -emit=ast 2>&1 | FileCheck %s +# CHECK-NOT: Assert +# CHECK: Parse error diff --git a/Examples/Toy/Ch6/invalid.mlir b/Examples/Toy/Ch6/invalid.mlir new file mode 100644 index 0000000..8d6d497 --- /dev/null +++ b/Examples/Toy/Ch6/invalid.mlir @@ -0,0 +1,9 @@ +// RUN: not toyc-ch6 %s -emit=mlir 2>&1 + +// The following IR is not "valid": +// - toy.print should not return a value. +// - toy.print should take an argument. +// - There should be a block terminator. +toy.func @main() { + %0 = "toy.print"() : () -> tensor<2x3xf64> +} diff --git a/Examples/Toy/Ch6/jit.toy b/Examples/Toy/Ch6/jit.toy new file mode 100644 index 0000000..c5be603 --- /dev/null +++ b/Examples/Toy/Ch6/jit.toy @@ -0,0 +1,6 @@ +# RUN: toyc-ch6 -emit=jit %s +# UNSUPPORTED: target={{.*windows.*}} + +def main() { + print([[1, 2], [3, 4]]); +} diff --git a/Examples/Toy/Ch6/lit.local.cfg b/Examples/Toy/Ch6/lit.local.cfg new file mode 100644 index 0000000..0d9aa10 --- /dev/null +++ b/Examples/Toy/Ch6/lit.local.cfg @@ -0,0 +1,3 @@ +# Requires native execution. +if "host-supports-jit" not in config.available_features: + config.unsupported = True diff --git a/Examples/Toy/Ch6/llvm-lowering.mlir b/Examples/Toy/Ch6/llvm-lowering.mlir new file mode 100644 index 0000000..37ad2bd --- /dev/null +++ b/Examples/Toy/Ch6/llvm-lowering.mlir @@ -0,0 +1,23 @@ +// RUN: toyc-ch6 %s -emit=llvm -opt + +toy.func @main() { + %0 = toy.constant dense<[[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64> + %2 = toy.transpose(%0 : tensor<2x3xf64>) to tensor<3x2xf64> + %3 = toy.mul %2, %2 : tensor<3x2xf64> + toy.print %3 : tensor<3x2xf64> + toy.return +} + +// CHECK-LABEL: define void @main() +// CHECK: @printf +// CHECK-SAME: 1.000000e+00 +// CHECK: @printf +// CHECK-SAME: 1.600000e+01 +// CHECK: @printf +// CHECK-SAME: 4.000000e+00 +// CHECK: @printf +// CHECK-SAME: 2.500000e+01 +// CHECK: @printf +// CHECK-SAME: 9.000000e+00 +// CHECK: @printf +// CHECK-SAME: 3.000000e+01 diff --git a/Examples/Toy/Ch6/scalar.toy b/Examples/Toy/Ch6/scalar.toy new file mode 100644 index 0000000..351d1e7 --- /dev/null +++ b/Examples/Toy/Ch6/scalar.toy @@ -0,0 +1,14 @@ +# RUN: toyc-ch6 %s -emit=mlir 2>&1 | FileCheck %s + +def main() { + var a<2, 2> = 5.5; + print(a); +} + +# CHECK-LABEL: toy.func @main() { +# CHECK-NEXT: %0 = toy.constant dense<5.500000e+00> : tensor +# CHECK-NEXT: %1 = toy.reshape(%0 : tensor) to tensor<2x2xf64> +# CHECK-NEXT: toy.print %1 : tensor<2x2xf64> +# CHECK-NEXT: toy.return +# CHECK-NEXT: } + diff --git a/Examples/Toy/Ch6/shape_inference.mlir b/Examples/Toy/Ch6/shape_inference.mlir new file mode 100644 index 0000000..7d23f88 --- /dev/null +++ b/Examples/Toy/Ch6/shape_inference.mlir @@ -0,0 +1,30 @@ +// RUN: toyc-ch6 %s -emit=mlir -opt 2>&1 | FileCheck %s + +// Check the result of inlining+shape inference on an input module. + +toy.func private @multiply_transpose(%arg0: tensor<*xf64>, %arg1: tensor<*xf64>) -> tensor<*xf64> { + %0 = toy.transpose(%arg0 : tensor<*xf64>) to tensor<*xf64> + %1 = toy.transpose(%arg1 : tensor<*xf64>) to tensor<*xf64> + %2 = toy.mul %0, %1 : tensor<*xf64> + toy.return %2 : tensor<*xf64> +} +toy.func @main() { + %0 = toy.constant dense<[[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64> + %1 = toy.reshape(%0 : tensor<2x3xf64>) to tensor<2x3xf64> + %2 = toy.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 5.000000e+00, 6.000000e+00]> : tensor<6xf64> + %3 = toy.reshape(%2 : tensor<6xf64>) to tensor<2x3xf64> + %4 = toy.generic_call @multiply_transpose(%1, %3) : (tensor<2x3xf64>, tensor<2x3xf64>) -> tensor<*xf64> + %5 = toy.generic_call @multiply_transpose(%3, %1) : (tensor<2x3xf64>, tensor<2x3xf64>) -> tensor<*xf64> + toy.print %5 : tensor<*xf64> + toy.return +} + +// CHECK-NOT: toy.func @multiply_transpose +// CHECK-NOT: tensor<*xf64> + +// CHECK-LABEL: toy.func @main() +// CHECK: [[VAL_0:%.*]] = toy.constant dense<{{\[\[}}1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64> +// CHECK: [[VAL_1:%.*]] = toy.transpose([[VAL_0]] : tensor<2x3xf64>) to tensor<3x2xf64> +// CHECK: [[VAL_2:%.*]] = toy.mul [[VAL_1]], [[VAL_1]] : tensor<3x2xf64> +// CHECK: toy.print [[VAL_2]] : tensor<3x2xf64> +// CHECK: toy.return diff --git a/Examples/Toy/Ch6/transpose_transpose.toy b/Examples/Toy/Ch6/transpose_transpose.toy new file mode 100644 index 0000000..c17f2f4 --- /dev/null +++ b/Examples/Toy/Ch6/transpose_transpose.toy @@ -0,0 +1,17 @@ +# RUN: toyc-ch6 %s -emit=mlir -opt 2>&1 | FileCheck %s + +# User defined generic function that operates on unknown shaped arguments +def transpose_transpose(x) { + return transpose(transpose(x)); +} + +def main() { + var a<2, 3> = [[1, 2, 3], [4, 5, 6]]; + var b = transpose_transpose(a); + print(b); +} + +# CHECK-LABEL: toy.func @main() +# CHECK-NEXT: [[VAL_1:%.*]] = toy.constant dense<{{\[\[}}1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64> +# CHECK-NEXT: toy.print [[VAL_1]] : tensor<2x3xf64> +# CHECK-NEXT: toy.return \ No newline at end of file diff --git a/Examples/Toy/Ch6/trivial_reshape.toy b/Examples/Toy/Ch6/trivial_reshape.toy new file mode 100644 index 0000000..2beb9e7 --- /dev/null +++ b/Examples/Toy/Ch6/trivial_reshape.toy @@ -0,0 +1,16 @@ +# RUN: toyc-ch6 %s -emit=mlir -opt 2>&1 | FileCheck %s + +def main() { + var a<2,1> = [1, 2]; + var b<2,1> = a; + var c<2,1> = b; + print(c); +} + +# CHECK-LABEL: toy.func @main() +# CHECK-NEXT: [[VAL_0:%.*]] = toy.constant +# CHECK-SAME: dense<[ +# CHECK-SAME: [1.000000e+00], [2.000000e+00] +# CHECK-SAME: ]> : tensor<2x1xf64> +# CHECK-NEXT: toy.print [[VAL_0]] : tensor<2x1xf64> +# CHECK-NEXT: toy.return diff --git a/Examples/Toy/Ch7/affine-lowering.mlir b/Examples/Toy/Ch7/affine-lowering.mlir new file mode 100644 index 0000000..3cefd0e --- /dev/null +++ b/Examples/Toy/Ch7/affine-lowering.mlir @@ -0,0 +1,64 @@ +// RUN: toyc-ch7 %s -emit=mlir-affine 2>&1 | FileCheck %s +// RUN: toyc-ch7 %s -emit=mlir-affine -opt 2>&1 | FileCheck %s --check-prefix=OPT + +toy.func @main() { + %0 = toy.constant dense<[[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64> + %2 = toy.transpose(%0 : tensor<2x3xf64>) to tensor<3x2xf64> + %3 = toy.mul %2, %2 : tensor<3x2xf64> + toy.print %3 : tensor<3x2xf64> + toy.return +} + +// CHECK-LABEL: func @main() +// CHECK-DAG: [[VAL_0:%.*]] = arith.constant 1.000000e+00 : f64 +// CHECK-DAG: [[VAL_1:%.*]] = arith.constant 2.000000e+00 : f64 +// CHECK-DAG: [[VAL_2:%.*]] = arith.constant 3.000000e+00 : f64 +// CHECK-DAG: [[VAL_3:%.*]] = arith.constant 4.000000e+00 : f64 +// CHECK-DAG: [[VAL_4:%.*]] = arith.constant 5.000000e+00 : f64 +// CHECK-DAG: [[VAL_5:%.*]] = arith.constant 6.000000e+00 : f64 +// CHECK: [[VAL_6:%.*]] = memref.alloc() : memref<3x2xf64> +// CHECK: [[VAL_7:%.*]] = memref.alloc() : memref<3x2xf64> +// CHECK: [[VAL_8:%.*]] = memref.alloc() : memref<2x3xf64> +// CHECK: affine.store [[VAL_0]], [[VAL_8]][0, 0] : memref<2x3xf64> +// CHECK: affine.store [[VAL_1]], [[VAL_8]][0, 1] : memref<2x3xf64> +// CHECK: affine.store [[VAL_2]], [[VAL_8]][0, 2] : memref<2x3xf64> +// CHECK: affine.store [[VAL_3]], [[VAL_8]][1, 0] : memref<2x3xf64> +// CHECK: affine.store [[VAL_4]], [[VAL_8]][1, 1] : memref<2x3xf64> +// CHECK: affine.store [[VAL_5]], [[VAL_8]][1, 2] : memref<2x3xf64> +// CHECK: affine.for [[VAL_9:%.*]] = 0 to 3 { +// CHECK: affine.for [[VAL_10:%.*]] = 0 to 2 { +// CHECK: [[VAL_11:%.*]] = affine.load [[VAL_8]]{{\[}}[[VAL_10]], [[VAL_9]]] : memref<2x3xf64> +// CHECK: affine.store [[VAL_11]], [[VAL_7]]{{\[}}[[VAL_9]], [[VAL_10]]] : memref<3x2xf64> +// CHECK: affine.for [[VAL_12:%.*]] = 0 to 3 { +// CHECK: affine.for [[VAL_13:%.*]] = 0 to 2 { +// CHECK: [[VAL_14:%.*]] = affine.load [[VAL_7]]{{\[}}[[VAL_12]], [[VAL_13]]] : memref<3x2xf64> +// CHECK: [[VAL_16:%.*]] = arith.mulf [[VAL_14]], [[VAL_14]] : f64 +// CHECK: affine.store [[VAL_16]], [[VAL_6]]{{\[}}[[VAL_12]], [[VAL_13]]] : memref<3x2xf64> +// CHECK: toy.print [[VAL_6]] : memref<3x2xf64> +// CHECK: memref.dealloc [[VAL_8]] : memref<2x3xf64> +// CHECK: memref.dealloc [[VAL_7]] : memref<3x2xf64> +// CHECK: memref.dealloc [[VAL_6]] : memref<3x2xf64> + +// OPT-LABEL: func @main() +// OPT-DAG: [[VAL_0:%.*]] = arith.constant 1.000000e+00 : f64 +// OPT-DAG: [[VAL_1:%.*]] = arith.constant 2.000000e+00 : f64 +// OPT-DAG: [[VAL_2:%.*]] = arith.constant 3.000000e+00 : f64 +// OPT-DAG: [[VAL_3:%.*]] = arith.constant 4.000000e+00 : f64 +// OPT-DAG: [[VAL_4:%.*]] = arith.constant 5.000000e+00 : f64 +// OPT-DAG: [[VAL_5:%.*]] = arith.constant 6.000000e+00 : f64 +// OPT: [[VAL_6:%.*]] = memref.alloc() : memref<3x2xf64> +// OPT: [[VAL_7:%.*]] = memref.alloc() : memref<2x3xf64> +// OPT: affine.store [[VAL_0]], [[VAL_7]][0, 0] : memref<2x3xf64> +// OPT: affine.store [[VAL_1]], [[VAL_7]][0, 1] : memref<2x3xf64> +// OPT: affine.store [[VAL_2]], [[VAL_7]][0, 2] : memref<2x3xf64> +// OPT: affine.store [[VAL_3]], [[VAL_7]][1, 0] : memref<2x3xf64> +// OPT: affine.store [[VAL_4]], [[VAL_7]][1, 1] : memref<2x3xf64> +// OPT: affine.store [[VAL_5]], [[VAL_7]][1, 2] : memref<2x3xf64> +// OPT: affine.for [[VAL_8:%.*]] = 0 to 3 { +// OPT: affine.for [[VAL_9:%.*]] = 0 to 2 { +// OPT: [[VAL_10:%.*]] = affine.load [[VAL_7]]{{\[}}[[VAL_9]], [[VAL_8]]] : memref<2x3xf64> +// OPT: [[VAL_11:%.*]] = arith.mulf [[VAL_10]], [[VAL_10]] : f64 +// OPT: affine.store [[VAL_11]], [[VAL_6]]{{\[}}[[VAL_8]], [[VAL_9]]] : memref<3x2xf64> +// OPT: toy.print [[VAL_6]] : memref<3x2xf64> +// OPT: memref.dealloc [[VAL_7]] : memref<2x3xf64> +// OPT: memref.dealloc [[VAL_6]] : memref<3x2xf64> diff --git a/Examples/Toy/Ch7/ast.toy b/Examples/Toy/Ch7/ast.toy new file mode 100644 index 0000000..878450a --- /dev/null +++ b/Examples/Toy/Ch7/ast.toy @@ -0,0 +1,76 @@ +# RUN: toyc-ch7 %s -emit=ast 2>&1 | FileCheck %s + +# User defined generic function that operates on unknown shaped arguments. +def multiply_transpose(a, b) { + return transpose(a) * transpose(b); +} + +def main() { + # Define a variable `a` with shape <2, 3>, initialized with the literal value. + # The shape is inferred from the supplied literal. + var a = [[1, 2, 3], [4, 5, 6]]; + # b is identical to a, the literal array is implicitly reshaped: defining new + # variables is the way to reshape arrays (element count in literal must match + # the size of specified shape). + var b<2, 3> = [1, 2, 3, 4, 5, 6]; + + # This call will specialize `multiply_transpose` with <2, 3> for both + # arguments and deduce a return type of <2, 2> in initialization of `c`. + var c = multiply_transpose(a, b); + # A second call to `multiply_transpose` with <2, 3> for both arguments will + # reuse the previously specialized and inferred version and return `<2, 2>` + var d = multiply_transpose(b, a); + # A new call with `<2, 2>` for both dimension will trigger another + # specialization of `multiply_transpose`. + var e = multiply_transpose(b, c); + # Finally, calling into `multiply_transpose` with incompatible shape will + # trigger a shape inference error. + var f = multiply_transpose(transpose(a), c); +} + + +# CHECK: Module: +# CHECK-NEXT: Function +# CHECK-NEXT: Proto 'multiply_transpose' @{{.*}}ast.toy:4:1 +# CHECK-NEXT: Params: [a, b] +# CHECK-NEXT: Block { +# CHECK-NEXT: Return +# CHECK-NEXT: BinOp: * @{{.*}}ast.toy:5:25 +# CHECK-NEXT: Call 'transpose' [ @{{.*}}ast.toy:5:10 +# CHECK-NEXT: var: a @{{.*}}ast.toy:5:20 +# CHECK-NEXT: ] +# CHECK-NEXT: Call 'transpose' [ @{{.*}}ast.toy:5:25 +# CHECK-NEXT: var: b @{{.*}}ast.toy:5:35 +# CHECK-NEXT: ] +# CHECK-NEXT: } // Block +# CHECK-NEXT: Function +# CHECK-NEXT: Proto 'main' @{{.*}}ast.toy:8:1 +# CHECK-NEXT: Params: [] +# CHECK-NEXT: Block { +# CHECK-NEXT: VarDecl a<> @{{.*}}ast.toy:11:3 +# CHECK-NEXT: Literal: <2, 3>[ <3>[ 1.000000e+00, 2.000000e+00, 3.000000e+00], <3>[ 4.000000e+00, 5.000000e+00, 6.000000e+00]] @{{.*}}ast.toy:11:11 +# CHECK-NEXT: VarDecl b<2, 3> @{{.*}}ast.toy:15:3 +# CHECK-NEXT: Literal: <6>[ 1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 5.000000e+00, 6.000000e+00] @{{.*}}ast.toy:15:17 +# CHECK-NEXT: VarDecl c<> @{{.*}}ast.toy:19:3 +# CHECK-NEXT: Call 'multiply_transpose' [ @{{.*}}ast.toy:19:11 +# CHECK-NEXT: var: a @{{.*}}ast.toy:19:30 +# CHECK-NEXT: var: b @{{.*}}ast.toy:19:33 +# CHECK-NEXT: ] +# CHECK-NEXT: VarDecl d<> @{{.*}}ast.toy:22:3 +# CHECK-NEXT: Call 'multiply_transpose' [ @{{.*}}ast.toy:22:11 +# CHECK-NEXT: var: b @{{.*}}ast.toy:22:30 +# CHECK-NEXT: var: a @{{.*}}ast.toy:22:33 +# CHECK-NEXT: ] +# CHECK-NEXT: VarDecl e<> @{{.*}}ast.toy:25:3 +# CHECK-NEXT: Call 'multiply_transpose' [ @{{.*}}ast.toy:25:11 +# CHECK-NEXT: var: b @{{.*}}ast.toy:25:30 +# CHECK-NEXT: var: c @{{.*}}ast.toy:25:33 +# CHECK-NEXT: ] +# CHECK-NEXT: VarDecl f<> @{{.*}}ast.toy:28:3 +# CHECK-NEXT: Call 'multiply_transpose' [ @{{.*}}ast.toy:28:11 +# CHECK-NEXT: Call 'transpose' [ @{{.*}}ast.toy:28:30 +# CHECK-NEXT: var: a @{{.*}}ast.toy:28:40 +# CHECK-NEXT: ] +# CHECK-NEXT: var: c @{{.*}}ast.toy:28:44 +# CHECK-NEXT: ] + diff --git a/Examples/Toy/Ch7/codegen.toy b/Examples/Toy/Ch7/codegen.toy new file mode 100644 index 0000000..af6a3bd --- /dev/null +++ b/Examples/Toy/Ch7/codegen.toy @@ -0,0 +1,31 @@ +# RUN: toyc-ch7 %s -emit=mlir 2>&1 | FileCheck %s + +# User defined generic function that operates on unknown shaped arguments +def multiply_transpose(a, b) { + return transpose(a) * transpose(b); +} + +def main() { + var a<2, 3> = [[1, 2, 3], [4, 5, 6]]; + var b<2, 3> = [1, 2, 3, 4, 5, 6]; + var c = multiply_transpose(a, b); + var d = multiply_transpose(b, a); + print(d); +} + +# CHECK-LABEL: toy.func private @multiply_transpose( +# CHECK-SAME: [[VAL_0:%.*]]: tensor<*xf64>, [[VAL_1:%.*]]: tensor<*xf64>) -> tensor<*xf64> +# CHECK: [[VAL_2:%.*]] = toy.transpose([[VAL_0]] : tensor<*xf64>) to tensor<*xf64> +# CHECK-NEXT: [[VAL_3:%.*]] = toy.transpose([[VAL_1]] : tensor<*xf64>) to tensor<*xf64> +# CHECK-NEXT: [[VAL_4:%.*]] = toy.mul [[VAL_2]], [[VAL_3]] : tensor<*xf64> +# CHECK-NEXT: toy.return [[VAL_4]] : tensor<*xf64> + +# CHECK-LABEL: toy.func @main() +# CHECK-NEXT: [[VAL_5:%.*]] = toy.constant dense<{{\[\[}}1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64> +# CHECK-NEXT: [[VAL_6:%.*]] = toy.reshape([[VAL_5]] : tensor<2x3xf64>) to tensor<2x3xf64> +# CHECK-NEXT: [[VAL_7:%.*]] = toy.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 5.000000e+00, 6.000000e+00]> : tensor<6xf64> +# CHECK-NEXT: [[VAL_8:%.*]] = toy.reshape([[VAL_7]] : tensor<6xf64>) to tensor<2x3xf64> +# CHECK-NEXT: [[VAL_9:%.*]] = toy.generic_call @multiply_transpose([[VAL_6]], [[VAL_8]]) : (tensor<2x3xf64>, tensor<2x3xf64>) -> tensor<*xf64> +# CHECK-NEXT: [[VAL_10:%.*]] = toy.generic_call @multiply_transpose([[VAL_8]], [[VAL_6]]) : (tensor<2x3xf64>, tensor<2x3xf64>) -> tensor<*xf64> +# CHECK-NEXT: toy.print [[VAL_10]] : tensor<*xf64> +# CHECK-NEXT: toy.return diff --git a/Examples/Toy/Ch7/empty.toy b/Examples/Toy/Ch7/empty.toy new file mode 100644 index 0000000..d9d4b8e --- /dev/null +++ b/Examples/Toy/Ch7/empty.toy @@ -0,0 +1,4 @@ +# RUN: toyc-ch7 %s -emit=ast 2>&1 | FileCheck %s +# CHECK-NOT: Assert +# CHECK-NOT: Parse error +# CHECK: Module diff --git a/Examples/Toy/Ch7/invalid.mlir b/Examples/Toy/Ch7/invalid.mlir new file mode 100644 index 0000000..c0aa707 --- /dev/null +++ b/Examples/Toy/Ch7/invalid.mlir @@ -0,0 +1,9 @@ +// RUN: not toyc-ch7 %s -emit=mlir 2>&1 + +// The following IR is not "valid": +// - toy.print should not return a value. +// - toy.print should take an argument. +// - There should be a block terminator. +toy.func @main() { + %0 = "toy.print"() : () -> tensor<2x3xf64> +} diff --git a/Examples/Toy/Ch7/jit.toy b/Examples/Toy/Ch7/jit.toy new file mode 100644 index 0000000..82469c7 --- /dev/null +++ b/Examples/Toy/Ch7/jit.toy @@ -0,0 +1,6 @@ +# RUN: toyc-ch7 -emit=jit %s +# UNSUPPORTED: target={{.*windows.*}} + +def main() { + print([[1, 2], [3, 4]]); +} diff --git a/Examples/Toy/Ch7/lit.local.cfg b/Examples/Toy/Ch7/lit.local.cfg new file mode 100644 index 0000000..0d9aa10 --- /dev/null +++ b/Examples/Toy/Ch7/lit.local.cfg @@ -0,0 +1,3 @@ +# Requires native execution. +if "host-supports-jit" not in config.available_features: + config.unsupported = True diff --git a/Examples/Toy/Ch7/llvm-lowering.mlir b/Examples/Toy/Ch7/llvm-lowering.mlir new file mode 100644 index 0000000..fc4c8d5 --- /dev/null +++ b/Examples/Toy/Ch7/llvm-lowering.mlir @@ -0,0 +1,23 @@ +// RUN: toyc-ch7 %s -emit=llvm -opt + +toy.func @main() { + %0 = toy.constant dense<[[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64> + %2 = toy.transpose(%0 : tensor<2x3xf64>) to tensor<3x2xf64> + %3 = toy.mul %2, %2 : tensor<3x2xf64> + toy.print %3 : tensor<3x2xf64> + toy.return +} + +// CHECK-LABEL: define void @main() +// CHECK: @printf +// CHECK-SAME: 1.000000e+00 +// CHECK: @printf +// CHECK-SAME: 1.600000e+01 +// CHECK: @printf +// CHECK-SAME: 4.000000e+00 +// CHECK: @printf +// CHECK-SAME: 2.500000e+01 +// CHECK: @printf +// CHECK-SAME: 9.000000e+00 +// CHECK: @printf +// CHECK-SAME: 3.000000e+01 diff --git a/Examples/Toy/Ch7/scalar.toy b/Examples/Toy/Ch7/scalar.toy new file mode 100644 index 0000000..f449028 --- /dev/null +++ b/Examples/Toy/Ch7/scalar.toy @@ -0,0 +1,14 @@ +# RUN: toyc-ch7 %s -emit=mlir 2>&1 | FileCheck %s + +def main() { + var a<2, 2> = 5.5; + print(a); +} + +# CHECK-LABEL: toy.func @main() { +# CHECK-NEXT: %0 = toy.constant dense<5.500000e+00> : tensor +# CHECK-NEXT: %1 = toy.reshape(%0 : tensor) to tensor<2x2xf64> +# CHECK-NEXT: toy.print %1 : tensor<2x2xf64> +# CHECK-NEXT: toy.return +# CHECK-NEXT: } + diff --git a/Examples/Toy/Ch7/shape_inference.mlir b/Examples/Toy/Ch7/shape_inference.mlir new file mode 100644 index 0000000..083305f --- /dev/null +++ b/Examples/Toy/Ch7/shape_inference.mlir @@ -0,0 +1,30 @@ +// RUN: toyc-ch7 %s -emit=mlir -opt 2>&1 | FileCheck %s + +// Check the result of inlining+shape inference on an input module. + +toy.func private @multiply_transpose(%arg0: tensor<*xf64>, %arg1: tensor<*xf64>) -> tensor<*xf64> { + %0 = toy.transpose(%arg0 : tensor<*xf64>) to tensor<*xf64> + %1 = toy.transpose(%arg1 : tensor<*xf64>) to tensor<*xf64> + %2 = toy.mul %0, %1 : tensor<*xf64> + toy.return %2 : tensor<*xf64> +} +toy.func @main() { + %0 = toy.constant dense<[[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64> + %1 = toy.reshape(%0 : tensor<2x3xf64>) to tensor<2x3xf64> + %2 = toy.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 5.000000e+00, 6.000000e+00]> : tensor<6xf64> + %3 = toy.reshape(%2 : tensor<6xf64>) to tensor<2x3xf64> + %4 = toy.generic_call @multiply_transpose(%1, %3) : (tensor<2x3xf64>, tensor<2x3xf64>) -> tensor<*xf64> + %5 = toy.generic_call @multiply_transpose(%3, %1) : (tensor<2x3xf64>, tensor<2x3xf64>) -> tensor<*xf64> + toy.print %5 : tensor<*xf64> + toy.return +} + +// CHECK-NOT: func @multiply_transpose +// CHECK-NOT: tensor<*xf64> + +// CHECK-LABEL: func @main() +// CHECK: [[VAL_0:%.*]] = toy.constant dense<{{\[\[}}1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64> +// CHECK: [[VAL_1:%.*]] = toy.transpose([[VAL_0]] : tensor<2x3xf64>) to tensor<3x2xf64> +// CHECK: [[VAL_2:%.*]] = toy.mul [[VAL_1]], [[VAL_1]] : tensor<3x2xf64> +// CHECK: toy.print [[VAL_2]] : tensor<3x2xf64> +// CHECK: toy.return diff --git a/Examples/Toy/Ch7/struct-ast.toy b/Examples/Toy/Ch7/struct-ast.toy new file mode 100644 index 0000000..d2ccc5a --- /dev/null +++ b/Examples/Toy/Ch7/struct-ast.toy @@ -0,0 +1,61 @@ +# RUN: toyc-ch7 %s -emit=ast 2>&1 | FileCheck %s + +struct Struct { + var a; + var b; +} + +# User defined generic function may operate on struct types as well. +def multiply_transpose(Struct value) { + # We can access the elements of a struct via the '.' operator. + return transpose(value.a) * transpose(value.b); +} + +def main() { + # We initialize struct values using a composite initializer. + Struct value = {[[1, 2, 3], [4, 5, 6]], [[1, 2, 3], [4, 5, 6]]}; + + # We pass these arguments to functions like we do with variables. + var c = multiply_transpose(value); + print(c); +} + +# CHECK: Module: +# CHECK-NEXT: Struct: Struct @{{.*}}struct-ast.toy:3:1 +# CHECK-NEXT: Variables: [ +# CHECK-NEXT: VarDecl a<> @{{.*}}struct-ast.toy:4:3 +# CHECK-NEXT: VarDecl b<> @{{.*}}struct-ast.toy:5:3 +# CHECK-NEXT: ] +# CHECK-NEXT:Function +# CHECK-NEXT: Proto 'multiply_transpose' @{{.*}}struct-ast.toy:9:1 +# CHECK-NEXT: Params: [value] +# CHECK-NEXT: Block { +# CHECK-NEXT: Return +# CHECK-NEXT: BinOp: * @{{.*}}struct-ast.toy:11:31 +# CHECK-NEXT: Call 'transpose' [ @{{.*}}struct-ast.toy:11:10 +# CHECK-NEXT: BinOp: . @{{.*}}struct-ast.toy:11:26 +# CHECK-NEXT: var: value @{{.*}}struct-ast.toy:11:20 +# CHECK-NEXT: var: a @{{.*}}struct-ast.toy:11:26 +# CHECK-NEXT: ] +# CHECK-NEXT: Call 'transpose' [ @{{.*}}struct-ast.toy:11:31 +# CHECK-NEXT: BinOp: . @{{.*}}struct-ast.toy:11:47 +# CHECK-NEXT: var: value @{{.*}}struct-ast.toy:11:41 +# CHECK-NEXT: var: b @{{.*}}struct-ast.toy:11:47 +# CHECK-NEXT: ] +# CHECK-NEXT: } +# CHECK-NEXT:Function +# CHECK-NEXT: Proto 'main' @{{.*}}struct-ast.toy:14:1 +# CHECK-NEXT: Params: [] +# CHECK-NEXT: Block { +# CHECK-NEXT: VarDecl value @{{.*}}struct-ast.toy:16:3 +# CHECK-NEXT: Struct Literal: Literal: <2, 3>[ <3>[ 1.000000e+00, 2.000000e+00, 3.000000e+00], <3>[ 4.000000e+00, 5.000000e+00, 6.000000e+00]] @{{.*}}struct-ast.toy:16:19 +# CHECK-NEXT: Literal: <2, 3>[ <3>[ 1.000000e+00, 2.000000e+00, 3.000000e+00], <3>[ 4.000000e+00, 5.000000e+00, 6.000000e+00]] @{{.*}}struct-ast.toy:16:43 +# CHECK-NEXT: @{{.*}}struct-ast.toy:16:18 +# CHECK-NEXT: VarDecl c<> @{{.*}}struct-ast.toy:19:3 +# CHECK-NEXT: Call 'multiply_transpose' [ @{{.*}}struct-ast.toy:19:11 +# CHECK-NEXT: var: value @{{.*}}struct-ast.toy:19:30 +# CHECK-NEXT: ] +# CHECK-NEXT: Print [ @{{.*}}struct-ast.toy:20:3 +# CHECK-NEXT: var: c @{{.*}}struct-ast.toy:20:9 +# CHECK-NEXT: ] +# CHECK-NEXT: } \ No newline at end of file diff --git a/Examples/Toy/Ch7/struct-codegen.toy b/Examples/Toy/Ch7/struct-codegen.toy new file mode 100644 index 0000000..74dcad0 --- /dev/null +++ b/Examples/Toy/Ch7/struct-codegen.toy @@ -0,0 +1,44 @@ +# RUN: toyc-ch7 %s -emit=mlir 2>&1 | FileCheck %s +# RUN: toyc-ch7 %s -emit=mlir -opt 2>&1 | FileCheck %s --check-prefix=OPT + +struct Struct { + var a; + var b; +} + +# User defined generic function may operate on struct types as well. +def multiply_transpose(Struct value) { + # We can access the elements of a struct via the '.' operator. + return transpose(value.a) * transpose(value.b); +} + +def main() { + # We initialize struct values using a composite initializer. + Struct value = {[[1, 2, 3], [4, 5, 6]], [[1, 2, 3], [4, 5, 6]]}; + + # We pass these arguments to functions like we do with variables. + var c = multiply_transpose(value); + print(c); +} + +# CHECK-LABEL: toy.func private @multiply_transpose( +# CHECK-SAME: [[VAL_0:%.*]]: !toy.struct, tensor<*xf64>>) -> tensor<*xf64> +# CHECK-NEXT: [[VAL_1:%.*]] = toy.struct_access [[VAL_0]][0] : !toy.struct, tensor<*xf64>> -> tensor<*xf64> +# CHECK-NEXT: [[VAL_2:%.*]] = toy.transpose([[VAL_1]] : tensor<*xf64>) to tensor<*xf64> +# CHECK-NEXT: [[VAL_3:%.*]] = toy.struct_access [[VAL_0]][1] : !toy.struct, tensor<*xf64>> -> tensor<*xf64> +# CHECK-NEXT: [[VAL_4:%.*]] = toy.transpose([[VAL_3]] : tensor<*xf64>) to tensor<*xf64> +# CHECK-NEXT: [[VAL_5:%.*]] = toy.mul [[VAL_2]], [[VAL_4]] : tensor<*xf64> +# CHECK-NEXT: toy.return [[VAL_5]] : tensor<*xf64> + +# CHECK-LABEL: toy.func @main() +# CHECK-NEXT: [[VAL_6:%.*]] = toy.struct_constant [dense<{{\[\[}}1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>, dense<{{\[\[}}1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>] : !toy.struct, tensor<*xf64>> +# CHECK-NEXT: [[VAL_7:%.*]] = toy.generic_call @multiply_transpose([[VAL_6]]) : (!toy.struct, tensor<*xf64>>) -> tensor<*xf64> +# CHECK-NEXT: toy.print [[VAL_7]] : tensor<*xf64> +# CHECK-NEXT: toy.return + +# OPT-LABEL: toy.func @main() +# OPT-NEXT: [[VAL_0:%.*]] = toy.constant dense<{{\[\[}}1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64> +# OPT-NEXT: [[VAL_1:%.*]] = toy.transpose([[VAL_0]] : tensor<2x3xf64>) to tensor<3x2xf64> +# OPT-NEXT: [[VAL_2:%.*]] = toy.mul [[VAL_1]], [[VAL_1]] : tensor<3x2xf64> +# OPT-NEXT: toy.print [[VAL_2]] : tensor<3x2xf64> +# OPT-NEXT: toy.return diff --git a/Examples/Toy/Ch7/struct-opt.mlir b/Examples/Toy/Ch7/struct-opt.mlir new file mode 100644 index 0000000..3faae88 --- /dev/null +++ b/Examples/Toy/Ch7/struct-opt.mlir @@ -0,0 +1,15 @@ +// RUN: toyc-ch7 %s -emit=mlir -opt 2>&1 | FileCheck %s + +toy.func @main() { + %0 = toy.struct_constant [ + [dense<4.000000e+00> : tensor<2x2xf64>], dense<4.000000e+00> : tensor<2x2xf64> + ] : !toy.struct>, tensor<*xf64>> + %1 = toy.struct_access %0[0] : !toy.struct>, tensor<*xf64>> -> !toy.struct> + %2 = toy.struct_access %1[0] : !toy.struct> -> tensor<*xf64> + toy.print %2 : tensor<*xf64> + toy.return +} + +// CHECK-LABEL: toy.func @main +// CHECK-NEXT: %[[CST:.*]] = toy.constant dense<4.0 +// CHECK-NEXT: toy.print %[[CST]] diff --git a/Examples/Toy/Ch7/transpose_transpose.toy b/Examples/Toy/Ch7/transpose_transpose.toy new file mode 100644 index 0000000..571d7b1 --- /dev/null +++ b/Examples/Toy/Ch7/transpose_transpose.toy @@ -0,0 +1,17 @@ +# RUN: toyc-ch7 %s -emit=mlir -opt 2>&1 | FileCheck %s + +# User defined generic function that operates on unknown shaped arguments +def transpose_transpose(x) { + return transpose(transpose(x)); +} + +def main() { + var a<2, 3> = [[1, 2, 3], [4, 5, 6]]; + var b = transpose_transpose(a); + print(b); +} + +# CHECK-LABEL: toy.func @main() +# CHECK-NEXT: [[VAL_1:%.*]] = toy.constant dense<{{\[\[}}1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64> +# CHECK-NEXT: toy.print [[VAL_1]] : tensor<2x3xf64> +# CHECK-NEXT: toy.return \ No newline at end of file diff --git a/Examples/Toy/Ch7/trivial_reshape.toy b/Examples/Toy/Ch7/trivial_reshape.toy new file mode 100644 index 0000000..06df1c1 --- /dev/null +++ b/Examples/Toy/Ch7/trivial_reshape.toy @@ -0,0 +1,16 @@ +# RUN: toyc-ch7 %s -emit=mlir -opt 2>&1 | FileCheck %s + +def main() { + var a<2,1> = [1, 2]; + var b<2,1> = a; + var c<2,1> = b; + print(c); +} + +# CHECK-LABEL: toy.func @main() +# CHECK-NEXT: [[VAL_0:%.*]] = toy.constant +# CHECK-SAME: dense<[ +# CHECK-SAME: [1.000000e+00], [2.000000e+00] +# CHECK-SAME: ]> : tensor<2x1xf64> +# CHECK-NEXT: toy.print [[VAL_0]] : tensor<2x1xf64> +# CHECK-NEXT: toy.return diff --git a/README.md b/README.md new file mode 100644 index 0000000..4384deb --- /dev/null +++ b/README.md @@ -0,0 +1,34 @@ +**Note: This repository is only for reference purposes, if you want to use it, you need to find your own solutions to configure** + +# CMake MLIR Toy Tutorial + +This contains sample code to support the tutorial on using MLIR for building a compiler for a simple Toy language. + +See [docs/Tutorials/Toy](../../docs/Tutorials/Toy) at the root of the project for more informations. + +In this repository, you can **run the MLIR toy tutorial (on Debian) without compile the LLVM project** 😋 + +## Environment + +- Debian +- CMake +- Ninja-Build +- LLVM18 +- Clang18 +- MLIR18 + +``` +apt install llvm-18 clang-18 cmake ninja-build mlir-18-tools libmlir-18-dev +``` + +## Note + +The `.td` file need run shell script to generate `.h` and `.cpp` , The reference shell script are on the folder + +``` +mlir-tblgen-18 -gen-op-decls -I /usr/lib/llvm-18/include Ops.td > Ops.h.inc +mlir-tblgen-18 -gen-op-defs -I /usr/lib/llvm-18/include Ops.td > Ops.cpp.inc +mlir-tblgen-18 -gen-dialect-decls -I /usr/lib/llvm-18/include Ops.td > Dialect.h.inc +mlir-tblgen-18 -gen-dialect-defs -I /usr/lib/llvm-18/include Ops.td > Dialect.cpp.inc +``` +