From f908f47ba09eba5d3e112d7620fcd6aca8f306e8 Mon Sep 17 00:00:00 2001 From: Amanda Tang Date: Tue, 17 Sep 2024 14:15:51 -0700 Subject: [PATCH 1/6] Add and update scf condition, if, while, yield --- mlir/dialects/scf.py | 29 +++++++++++++++++++++++++++-- 1 file changed, 27 insertions(+), 2 deletions(-) diff --git a/mlir/dialects/scf.py b/mlir/dialects/scf.py index 1c54cc7..09d60ae 100644 --- a/mlir/dialects/scf.py +++ b/mlir/dialects/scf.py @@ -8,6 +8,14 @@ from typing import Optional, List, Tuple +@dataclass +class SCFConditionOp(DialectOp): + condition: mast.SsaId + args: List[mast.SsaId] + out_types: List[mast.Type] + _syntax_ = ['scf.condition ( {condition.ssa_id} ) {args.ssa_id_list} : {out_types.type_list_no_parens}'] + + @dataclass class SCFForOp(DialectOp): index: mast.SsaId @@ -29,11 +37,28 @@ class SCFIfOp(DialectOp): cond: mast.SsaId body: mast.Region elsebody: Optional[mast.Region] = None + out_types: Optional[List[mast.Type]] = None _syntax_ = ['scf.if {cond.ssa_id} {body.region}', - 'scf.if {cond.ssa_id} {body.region} else {elsebody.region}'] + 'scf.if {cond.ssa_id} {body.region} else {elsebody.region}', + 'scf.if {cond.ssa_id} -> {out_types.type_list_parens} {body.region}', + 'scf.if {cond.ssa_id} -> {out_types.type_list_parens} {body.region} else {elsebody.region}'] -class SCFYield(UnaryOperation): _opname_ = 'scf.yield' +@dataclass +class SCFWhileOp(DialectOp): + assignments: List[Tuple[mast.SsaId, mast.Type]] + out_type: mast.FunctionType + while_body: mast.Region + do_body: mast.Region + _syntax_ = ['scf.while {assignments.argument_assignment_list_parens} : {out_type.function_type} {while_body.region} do {do_body.region}'] + + +@dataclass +class SCFYield(DialectOp): + results: Optional[List[mast.SsaId]] = None + result_types: Optional[List[mast.Type]] = None + _syntax_ = ['scf.yield', + 'scf.yield {results.ssa_id_list} : {result_types.type_list_no_parens}'] # Inspect current module to get all classes defined above From 8af8a7a7dd2c5822291f268ae34df09b34ce9f52 Mon Sep 17 00:00:00 2001 From: Amanda Tang Date: Wed, 25 Sep 2024 13:18:35 -0700 Subject: [PATCH 2/6] Test cases for scf dialect --- mlir/astnodes.py | 10 +++++++ mlir/dialects/func.py | 2 +- mlir/dialects/scf.py | 10 +++---- mlir/parser_transformer.py | 1 + tests/test_scf.py | 58 ++++++++++++++++++++++++++++++++++++++ tests/test_syntax.py | 20 ------------- 6 files changed, 75 insertions(+), 26 deletions(-) create mode 100644 tests/test_scf.py diff --git a/mlir/astnodes.py b/mlir/astnodes.py index 3dbb1ff..c5b3cc8 100644 --- a/mlir/astnodes.py +++ b/mlir/astnodes.py @@ -720,6 +720,16 @@ def dump(self, indent: int = 0) -> str: return result +@dataclass +class ArgumentAssignment(Node): + name: SsaId + value: SsaId + + def dump(self, indent: int = 0) -> str: + return '%s = %s' % (dump_or_value(self.name, indent), + dump_or_value(self.value, indent)) + + @dataclass class MLIRFile(Node): definitions: List["Definition"] diff --git a/mlir/dialects/func.py b/mlir/dialects/func.py index d94cefd..f1add08 100644 --- a/mlir/dialects/func.py +++ b/mlir/dialects/func.py @@ -23,7 +23,7 @@ class CallIndirectOperation(DialectOp): @dataclass class CallOperation(DialectOp): func: mast.SymbolRefId - type: mast.FunctionType + type: Optional[mast.FunctionType] = None args: Optional[List[SsaUse]] = None argtypes: Optional[List[mast.Type]] = None _syntax_ = ['func.call {func.symbol_ref_id} () : {type.function_type}', diff --git a/mlir/dialects/scf.py b/mlir/dialects/scf.py index 09d60ae..a4c227d 100644 --- a/mlir/dialects/scf.py +++ b/mlir/dialects/scf.py @@ -28,8 +28,8 @@ class SCFForOp(DialectOp): out_type: Optional[mast.Type] = None _syntax_ = ['scf.for {index.ssa_id} = {begin.ssa_id} to {end.ssa_id} step {step.ssa_id} {body.region}', 'scf.for {index.ssa_id} = {begin.ssa_id} to {end.ssa_id} step {step.ssa_id} : {out_type.type} {body.region}', - 'scf.for {index.ssa_id} = {begin.ssa_id} to {end.ssa_id} step {step.ssa_id} iter_args {iter_args.argument_assignment_list_parens} -> {iter_args_types.type_list_parens} {body.region}', - 'scf.for {index.ssa_id} = {begin.ssa_id} to {end.ssa_id} step {step.ssa_id} iter_args {iter_args.argument_assignment_list_parens} -> {iter_args_types.type_list_parens} : {out_type.type} {body.region}'] + 'scf.for {index.ssa_id} = {begin.ssa_id} to {end.ssa_id} step {step.ssa_id} iter_args ( {iter_args.argument_assignment_list_no_parens} ) -> ( {iter_args_types.type_list_no_parens} ) {body.region}', + 'scf.for {index.ssa_id} = {begin.ssa_id} to {end.ssa_id} step {step.ssa_id} iter_args ( {iter_args.argument_assignment_list_no_parens} ) -> ( {iter_args_types.type_list_no_parens} ) : {out_type.type} {body.region}'] @dataclass @@ -40,8 +40,8 @@ class SCFIfOp(DialectOp): out_types: Optional[List[mast.Type]] = None _syntax_ = ['scf.if {cond.ssa_id} {body.region}', 'scf.if {cond.ssa_id} {body.region} else {elsebody.region}', - 'scf.if {cond.ssa_id} -> {out_types.type_list_parens} {body.region}', - 'scf.if {cond.ssa_id} -> {out_types.type_list_parens} {body.region} else {elsebody.region}'] + 'scf.if {cond.ssa_id} -> ( {out_types.type_list_no_parens} ) {body.region}', + 'scf.if {cond.ssa_id} -> ( {out_types.type_list_no_parens} ) {body.region} else {elsebody.region}'] @dataclass @@ -50,7 +50,7 @@ class SCFWhileOp(DialectOp): out_type: mast.FunctionType while_body: mast.Region do_body: mast.Region - _syntax_ = ['scf.while {assignments.argument_assignment_list_parens} : {out_type.function_type} {while_body.region} do {do_body.region}'] + _syntax_ = ['scf.while ( {assignments.argument_assignment_list_no_parens} ) : {out_type.function_type} {while_body.region} do {do_body.region}'] @dataclass diff --git a/mlir/parser_transformer.py b/mlir/parser_transformer.py index d552c1e..c856572 100644 --- a/mlir/parser_transformer.py +++ b/mlir/parser_transformer.py @@ -128,6 +128,7 @@ def block_label(self, value): function = astnodes.Function.from_lark generic_module = astnodes.GenericModule.from_lark named_argument = astnodes.NamedArgument.from_lark + argument_assignment = astnodes.ArgumentAssignment.from_lark ############################################################### # (semi-)Affine expressions, maps, and integer sets diff --git a/tests/test_scf.py b/tests/test_scf.py new file mode 100644 index 0000000..286ef60 --- /dev/null +++ b/tests/test_scf.py @@ -0,0 +1,58 @@ +import sys +import mlir +from mlir.dialects.func import func + +# All source strings taken from examples in https://mlir.llvm.org/docs/Dialects/SCFDialect/ + + +def assert_roundtrip_equivalence(source): + assert source == mlir.parse_string(source).dump() + + +def test_scf_for(): + assert_roundtrip_equivalence("""module { + func.func @reduce(%buffer: memref<1024xf32>, %lb: index, %ub: index, %step: index, %sum_0: f32) -> (f32) { + %sum = scf.for %iv = %lb to %ub step %step iter_args ( %sum_iter = %sum_0 ) -> ( f32 ) { + %t = load %buffer [ %iv ] : memref<1024xf32> + %sum_next = arith.addf %sum_iter, %t : f32 + scf.yield %sum_next : f32 + } + return %sum : f32 + } +}""") + + +def test_scf_if(): + assert_roundtrip_equivalence("""module { + func.func @example(%A: f32, %B: f32, %C: f32, %D: f32) { + %x, %y = scf.if %b -> ( f32, f32 ) { + scf.yield %A, %B : f32, f32 + } else { + scf.yield %C, %D : f32, f32 + } + return + } +}""") + + +def test_scf_while(): + assert_roundtrip_equivalence("""module { + func.func @example(%A: f32, %B: f32, %C: f32, %D: f32) { + %res = scf.while ( %arg1 = %init1 ) : (f32) -> f32 { + %condition = func.call @evaluate_condition ( %arg1 ) : (f32) -> i1 + scf.condition ( %condition ) %arg1 : f32 + } do { + ^bb0 (%arg2: f32): + %next = func.call @payload ( %arg2 ) : (f32) -> f32 + scf.yield %next : f32 + } + } +}""") + + +if __name__ == "__main__": + if len(sys.argv) > 1: + exec(sys.argv[1]) + else: + from pytest import main + main([__file__]) diff --git a/tests/test_syntax.py b/tests/test_syntax.py index 325aaba..087c662 100644 --- a/tests/test_syntax.py +++ b/tests/test_syntax.py @@ -158,26 +158,6 @@ def test_affine(parser: Optional[Parser] = None): module = parser.parse(code) print(module.pretty()) -def test_scf_for(parser: Optional[Parser] = None): - code = """ -module { - func.func @reduce(%buffer: memref<1024xf32>, %lb: index, - %ub: index, %step: index) -> (f32) { - %sum_0 = arith.constant 0.0 : f32 - %sum = scf.for %iv = %lb to %ub step %step - iter_args(%sum_iter = %sum_0) -> (f32) { - %t = load %buffer[%iv] : memref<1024xf32> - %sum_next = arith.addf %sum_iter, %t : f32 - scf.yield %sum_next : f32 - } - return %sum : f32 - } -} - """ - parser = parser or Parser() - module = parser.parse(code) - print(module.pretty()) - def test_definitions(parser: Optional[Parser] = None): code = ''' #map0 = affine_map<(d0, d1) -> (d0, d1)> From 296ca2154a208cb9520083aa644f0f0f1451e62e Mon Sep 17 00:00:00 2001 From: Amanda Tang Date: Mon, 30 Sep 2024 11:02:38 -0700 Subject: [PATCH 3/6] additional scf for syntaxes --- mlir/dialects/scf.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/mlir/dialects/scf.py b/mlir/dialects/scf.py index a4c227d..8c71848 100644 --- a/mlir/dialects/scf.py +++ b/mlir/dialects/scf.py @@ -28,6 +28,8 @@ class SCFForOp(DialectOp): out_type: Optional[mast.Type] = None _syntax_ = ['scf.for {index.ssa_id} = {begin.ssa_id} to {end.ssa_id} step {step.ssa_id} {body.region}', 'scf.for {index.ssa_id} = {begin.ssa_id} to {end.ssa_id} step {step.ssa_id} : {out_type.type} {body.region}', + 'scf.for {index.ssa_id} = {begin.ssa_id} to {end.ssa_id} step {step.ssa_id} iter_args ( {iter_args.argument_assignment_list_no_parens} ) -> {iter_args_types.type_list_no_parens} {body.region}', + 'scf.for {index.ssa_id} = {begin.ssa_id} to {end.ssa_id} step {step.ssa_id} iter_args ( {iter_args.argument_assignment_list_no_parens} ) -> {iter_args_types.type_list_no_parens} : {out_type.type} {body.region}', 'scf.for {index.ssa_id} = {begin.ssa_id} to {end.ssa_id} step {step.ssa_id} iter_args ( {iter_args.argument_assignment_list_no_parens} ) -> ( {iter_args_types.type_list_no_parens} ) {body.region}', 'scf.for {index.ssa_id} = {begin.ssa_id} to {end.ssa_id} step {step.ssa_id} iter_args ( {iter_args.argument_assignment_list_no_parens} ) -> ( {iter_args_types.type_list_no_parens} ) : {out_type.type} {body.region}'] From e8f89031cf32daca29029f0e7dec0ea9b6699a72 Mon Sep 17 00:00:00 2001 From: Amanda Tang Date: Tue, 1 Oct 2024 09:03:31 -0700 Subject: [PATCH 4/6] Address comments --- mlir/dialects/func.py | 6 ++---- tests/test_linalg.py | 8 ++------ tests/test_scf.py | 8 ++------ 3 files changed, 6 insertions(+), 16 deletions(-) diff --git a/mlir/dialects/func.py b/mlir/dialects/func.py index f1add08..c27d86a 100644 --- a/mlir/dialects/func.py +++ b/mlir/dialects/func.py @@ -15,7 +15,6 @@ class CallIndirectOperation(DialectOp): func: mast.SymbolRefId type: mast.FunctionType args: Optional[List[SsaUse]] = None - argtypes: Optional[List[mast.Type]] = None _syntax_ = ['func.call_indirect {func.symbol_ref_id} () : {type.function_type}', 'func.call_indirect {func.symbol_ref_id} ( {args.ssa_use_list} ) : {type.function_type}'] @@ -23,11 +22,10 @@ class CallIndirectOperation(DialectOp): @dataclass class CallOperation(DialectOp): func: mast.SymbolRefId - type: Optional[mast.FunctionType] = None + type: mast.FunctionType args: Optional[List[SsaUse]] = None - argtypes: Optional[List[mast.Type]] = None _syntax_ = ['func.call {func.symbol_ref_id} () : {type.function_type}', - 'func.call {func.symbol_ref_id} ( {args.ssa_use_list} ) : {argtypes.function_type}'] + 'func.call {func.symbol_ref_id} ( {args.ssa_use_list} ) : {type.function_type}'] @dataclass class ConstantOperation(DialectOp): diff --git a/tests/test_linalg.py b/tests/test_linalg.py index 1558d92..b032b4d 100644 --- a/tests/test_linalg.py +++ b/tests/test_linalg.py @@ -167,9 +167,5 @@ def test_matvec(): }""") -if __name__ == "__main__": - if len(sys.argv) > 1: - exec(sys.argv[1]) - else: - from pytest import main - main([__file__]) +from pytest import main + main([__file__]) diff --git a/tests/test_scf.py b/tests/test_scf.py index 286ef60..4fb070e 100644 --- a/tests/test_scf.py +++ b/tests/test_scf.py @@ -50,9 +50,5 @@ def test_scf_while(): }""") -if __name__ == "__main__": - if len(sys.argv) > 1: - exec(sys.argv[1]) - else: - from pytest import main - main([__file__]) +from pytest import main + main([__file__]) From 56825c23824b9f1d1fde123d650770c3fe9c3385 Mon Sep 17 00:00:00 2001 From: Amanda Tang Date: Tue, 1 Oct 2024 09:12:38 -0700 Subject: [PATCH 5/6] Invoke tests without exec --- tests/test_builder.py | 10 ++++------ tests/test_linalg.py | 14 ++++++++++++-- tests/test_scf.py | 6 ++++-- 3 files changed, 20 insertions(+), 10 deletions(-) diff --git a/tests/test_builder.py b/tests/test_builder.py index 840cfe6..5dafd5f 100644 --- a/tests/test_builder.py +++ b/tests/test_builder.py @@ -103,9 +103,7 @@ def index(expr): assert index(Reads(a0)) == 2 -if __name__ == "__main__": - if len(sys.argv) > 1: - exec(sys.argv[1]) - else: - from pytest import main - main([__file__]) +if __name__ == '__main__': + test_saxpy_builder() + test_query() + test_build_with_queries() diff --git a/tests/test_linalg.py b/tests/test_linalg.py index b032b4d..ef93f36 100644 --- a/tests/test_linalg.py +++ b/tests/test_linalg.py @@ -167,5 +167,15 @@ def test_matvec(): }""") -from pytest import main - main([__file__]) +if __name__ == '__main__': + test_batch_matmul() + test_conv() + test_copy() + test_dot() + test_fill() + test_generic() + test_indexed_generic() + test_reduce() + test_view() + test_matmul() + test_matvec() diff --git a/tests/test_scf.py b/tests/test_scf.py index 4fb070e..31b9454 100644 --- a/tests/test_scf.py +++ b/tests/test_scf.py @@ -50,5 +50,7 @@ def test_scf_while(): }""") -from pytest import main - main([__file__]) +if __name__ == '__main__': + test_scf_for() + test_scf_if() + test_scf_while() From 4d7ed6f7cbc00e517c036ea27599ac93af3c6ab1 Mon Sep 17 00:00:00 2001 From: Amanda Tang Date: Tue, 1 Oct 2024 09:27:48 -0700 Subject: [PATCH 6/6] Clean up imports --- tests/test_builder.py | 1 - tests/test_linalg.py | 2 -- tests/test_scf.py | 2 -- 3 files changed, 5 deletions(-) diff --git a/tests/test_builder.py b/tests/test_builder.py index 5dafd5f..f97eb43 100644 --- a/tests/test_builder.py +++ b/tests/test_builder.py @@ -1,4 +1,3 @@ -import sys from mlir import parse_string from mlir.builder import IRBuilder from mlir.builder import Reads, Writes, Isa diff --git a/tests/test_linalg.py b/tests/test_linalg.py index ef93f36..1d783ca 100644 --- a/tests/test_linalg.py +++ b/tests/test_linalg.py @@ -1,6 +1,4 @@ -import sys import mlir -from mlir.dialects.func import func # All source strings have been taken from MLIR's codebase. # See llvm-project/mlir/test/Dialect/Linalg diff --git a/tests/test_scf.py b/tests/test_scf.py index 31b9454..7f267ac 100644 --- a/tests/test_scf.py +++ b/tests/test_scf.py @@ -1,6 +1,4 @@ -import sys import mlir -from mlir.dialects.func import func # All source strings taken from examples in https://mlir.llvm.org/docs/Dialects/SCFDialect/