diff --git a/mlir/astnodes.py b/mlir/astnodes.py index 3dbb1ff..c5b3cc8 100644 --- a/mlir/astnodes.py +++ b/mlir/astnodes.py @@ -720,6 +720,16 @@ def dump(self, indent: int = 0) -> str: return result +@dataclass +class ArgumentAssignment(Node): + name: SsaId + value: SsaId + + def dump(self, indent: int = 0) -> str: + return '%s = %s' % (dump_or_value(self.name, indent), + dump_or_value(self.value, indent)) + + @dataclass class MLIRFile(Node): definitions: List["Definition"] diff --git a/mlir/dialects/func.py b/mlir/dialects/func.py index d94cefd..c27d86a 100644 --- a/mlir/dialects/func.py +++ b/mlir/dialects/func.py @@ -15,7 +15,6 @@ class CallIndirectOperation(DialectOp): func: mast.SymbolRefId type: mast.FunctionType args: Optional[List[SsaUse]] = None - argtypes: Optional[List[mast.Type]] = None _syntax_ = ['func.call_indirect {func.symbol_ref_id} () : {type.function_type}', 'func.call_indirect {func.symbol_ref_id} ( {args.ssa_use_list} ) : {type.function_type}'] @@ -25,9 +24,8 @@ class CallOperation(DialectOp): func: mast.SymbolRefId type: mast.FunctionType args: Optional[List[SsaUse]] = None - argtypes: Optional[List[mast.Type]] = None _syntax_ = ['func.call {func.symbol_ref_id} () : {type.function_type}', - 'func.call {func.symbol_ref_id} ( {args.ssa_use_list} ) : {argtypes.function_type}'] + 'func.call {func.symbol_ref_id} ( {args.ssa_use_list} ) : {type.function_type}'] @dataclass class ConstantOperation(DialectOp): diff --git a/mlir/dialects/scf.py b/mlir/dialects/scf.py index 1c54cc7..8c71848 100644 --- a/mlir/dialects/scf.py +++ b/mlir/dialects/scf.py @@ -8,6 +8,14 @@ from typing import Optional, List, Tuple +@dataclass +class SCFConditionOp(DialectOp): + condition: mast.SsaId + args: List[mast.SsaId] + out_types: List[mast.Type] + _syntax_ = ['scf.condition ( {condition.ssa_id} ) {args.ssa_id_list} : {out_types.type_list_no_parens}'] + + @dataclass class SCFForOp(DialectOp): index: mast.SsaId @@ -20,8 +28,10 @@ class SCFForOp(DialectOp): out_type: Optional[mast.Type] = None _syntax_ = ['scf.for {index.ssa_id} = {begin.ssa_id} to {end.ssa_id} step {step.ssa_id} {body.region}', 'scf.for {index.ssa_id} = {begin.ssa_id} to {end.ssa_id} step {step.ssa_id} : {out_type.type} {body.region}', - 'scf.for {index.ssa_id} = {begin.ssa_id} to {end.ssa_id} step {step.ssa_id} iter_args {iter_args.argument_assignment_list_parens} -> {iter_args_types.type_list_parens} {body.region}', - 'scf.for {index.ssa_id} = {begin.ssa_id} to {end.ssa_id} step {step.ssa_id} iter_args {iter_args.argument_assignment_list_parens} -> {iter_args_types.type_list_parens} : {out_type.type} {body.region}'] + 'scf.for {index.ssa_id} = {begin.ssa_id} to {end.ssa_id} step {step.ssa_id} iter_args ( {iter_args.argument_assignment_list_no_parens} ) -> {iter_args_types.type_list_no_parens} {body.region}', + 'scf.for {index.ssa_id} = {begin.ssa_id} to {end.ssa_id} step {step.ssa_id} iter_args ( {iter_args.argument_assignment_list_no_parens} ) -> {iter_args_types.type_list_no_parens} : {out_type.type} {body.region}', + 'scf.for {index.ssa_id} = {begin.ssa_id} to {end.ssa_id} step {step.ssa_id} iter_args ( {iter_args.argument_assignment_list_no_parens} ) -> ( {iter_args_types.type_list_no_parens} ) {body.region}', + 'scf.for {index.ssa_id} = {begin.ssa_id} to {end.ssa_id} step {step.ssa_id} iter_args ( {iter_args.argument_assignment_list_no_parens} ) -> ( {iter_args_types.type_list_no_parens} ) : {out_type.type} {body.region}'] @dataclass @@ -29,11 +39,28 @@ class SCFIfOp(DialectOp): cond: mast.SsaId body: mast.Region elsebody: Optional[mast.Region] = None + out_types: Optional[List[mast.Type]] = None _syntax_ = ['scf.if {cond.ssa_id} {body.region}', - 'scf.if {cond.ssa_id} {body.region} else {elsebody.region}'] + 'scf.if {cond.ssa_id} {body.region} else {elsebody.region}', + 'scf.if {cond.ssa_id} -> ( {out_types.type_list_no_parens} ) {body.region}', + 'scf.if {cond.ssa_id} -> ( {out_types.type_list_no_parens} ) {body.region} else {elsebody.region}'] -class SCFYield(UnaryOperation): _opname_ = 'scf.yield' +@dataclass +class SCFWhileOp(DialectOp): + assignments: List[Tuple[mast.SsaId, mast.Type]] + out_type: mast.FunctionType + while_body: mast.Region + do_body: mast.Region + _syntax_ = ['scf.while ( {assignments.argument_assignment_list_no_parens} ) : {out_type.function_type} {while_body.region} do {do_body.region}'] + + +@dataclass +class SCFYield(DialectOp): + results: Optional[List[mast.SsaId]] = None + result_types: Optional[List[mast.Type]] = None + _syntax_ = ['scf.yield', + 'scf.yield {results.ssa_id_list} : {result_types.type_list_no_parens}'] # Inspect current module to get all classes defined above diff --git a/mlir/parser_transformer.py b/mlir/parser_transformer.py index d552c1e..c856572 100644 --- a/mlir/parser_transformer.py +++ b/mlir/parser_transformer.py @@ -128,6 +128,7 @@ def block_label(self, value): function = astnodes.Function.from_lark generic_module = astnodes.GenericModule.from_lark named_argument = astnodes.NamedArgument.from_lark + argument_assignment = astnodes.ArgumentAssignment.from_lark ############################################################### # (semi-)Affine expressions, maps, and integer sets diff --git a/tests/test_builder.py b/tests/test_builder.py index 840cfe6..f97eb43 100644 --- a/tests/test_builder.py +++ b/tests/test_builder.py @@ -1,4 +1,3 @@ -import sys from mlir import parse_string from mlir.builder import IRBuilder from mlir.builder import Reads, Writes, Isa @@ -103,9 +102,7 @@ def index(expr): assert index(Reads(a0)) == 2 -if __name__ == "__main__": - if len(sys.argv) > 1: - exec(sys.argv[1]) - else: - from pytest import main - main([__file__]) +if __name__ == '__main__': + test_saxpy_builder() + test_query() + test_build_with_queries() diff --git a/tests/test_linalg.py b/tests/test_linalg.py index 1558d92..1d783ca 100644 --- a/tests/test_linalg.py +++ b/tests/test_linalg.py @@ -1,6 +1,4 @@ -import sys import mlir -from mlir.dialects.func import func # All source strings have been taken from MLIR's codebase. # See llvm-project/mlir/test/Dialect/Linalg @@ -167,9 +165,15 @@ def test_matvec(): }""") -if __name__ == "__main__": - if len(sys.argv) > 1: - exec(sys.argv[1]) - else: - from pytest import main - main([__file__]) +if __name__ == '__main__': + test_batch_matmul() + test_conv() + test_copy() + test_dot() + test_fill() + test_generic() + test_indexed_generic() + test_reduce() + test_view() + test_matmul() + test_matvec() diff --git a/tests/test_scf.py b/tests/test_scf.py new file mode 100644 index 0000000..7f267ac --- /dev/null +++ b/tests/test_scf.py @@ -0,0 +1,54 @@ +import mlir + +# All source strings taken from examples in https://mlir.llvm.org/docs/Dialects/SCFDialect/ + + +def assert_roundtrip_equivalence(source): + assert source == mlir.parse_string(source).dump() + + +def test_scf_for(): + assert_roundtrip_equivalence("""module { + func.func @reduce(%buffer: memref<1024xf32>, %lb: index, %ub: index, %step: index, %sum_0: f32) -> (f32) { + %sum = scf.for %iv = %lb to %ub step %step iter_args ( %sum_iter = %sum_0 ) -> ( f32 ) { + %t = load %buffer [ %iv ] : memref<1024xf32> + %sum_next = arith.addf %sum_iter, %t : f32 + scf.yield %sum_next : f32 + } + return %sum : f32 + } +}""") + + +def test_scf_if(): + assert_roundtrip_equivalence("""module { + func.func @example(%A: f32, %B: f32, %C: f32, %D: f32) { + %x, %y = scf.if %b -> ( f32, f32 ) { + scf.yield %A, %B : f32, f32 + } else { + scf.yield %C, %D : f32, f32 + } + return + } +}""") + + +def test_scf_while(): + assert_roundtrip_equivalence("""module { + func.func @example(%A: f32, %B: f32, %C: f32, %D: f32) { + %res = scf.while ( %arg1 = %init1 ) : (f32) -> f32 { + %condition = func.call @evaluate_condition ( %arg1 ) : (f32) -> i1 + scf.condition ( %condition ) %arg1 : f32 + } do { + ^bb0 (%arg2: f32): + %next = func.call @payload ( %arg2 ) : (f32) -> f32 + scf.yield %next : f32 + } + } +}""") + + +if __name__ == '__main__': + test_scf_for() + test_scf_if() + test_scf_while() diff --git a/tests/test_syntax.py b/tests/test_syntax.py index 325aaba..087c662 100644 --- a/tests/test_syntax.py +++ b/tests/test_syntax.py @@ -158,26 +158,6 @@ def test_affine(parser: Optional[Parser] = None): module = parser.parse(code) print(module.pretty()) -def test_scf_for(parser: Optional[Parser] = None): - code = """ -module { - func.func @reduce(%buffer: memref<1024xf32>, %lb: index, - %ub: index, %step: index) -> (f32) { - %sum_0 = arith.constant 0.0 : f32 - %sum = scf.for %iv = %lb to %ub step %step - iter_args(%sum_iter = %sum_0) -> (f32) { - %t = load %buffer[%iv] : memref<1024xf32> - %sum_next = arith.addf %sum_iter, %t : f32 - scf.yield %sum_next : f32 - } - return %sum : f32 - } -} - """ - parser = parser or Parser() - module = parser.parse(code) - print(module.pretty()) - def test_definitions(parser: Optional[Parser] = None): code = ''' #map0 = affine_map<(d0, d1) -> (d0, d1)>