diff --git a/mlir/dialects/linalg.py b/mlir/dialects/linalg.py index 569bc27..bbd6685 100644 --- a/mlir/dialects/linalg.py +++ b/mlir/dialects/linalg.py @@ -21,11 +21,11 @@ class LinalgBatchMatmul(DialectOp): out_type: Optional[mast.Type] = None _syntax_ = [("linalg.batch_matmul" - " ins( {a_id.ssa_id} , {b_id.ssa_id} : {a_type.type} , {b_type.type} )" - " outs( {c_id.ssa_id} : {c_type.type} )"), + " ins ( {a_id.ssa_id} , {b_id.ssa_id} : {a_type.type} , {b_type.type} )" + " outs ( {c_id.ssa_id} : {c_type.type} )"), ("linalg.batch_matmul" - " ins( {a_id.ssa_id} , {b_id.ssa_id} : {a_type.type} , {b_type.type} )" - " init( {c_id.ssa_id} : {c_type.type} ) -> {out_type.type}")] + " ins ( {a_id.ssa_id} , {b_id.ssa_id} : {a_type.type} , {b_type.type} )" + " init ( {c_id.ssa_id} : {c_type.type} ) -> {out_type.type}")] @dataclass @@ -38,8 +38,8 @@ class LinalgConvW(DialectOp): out_type: mast.Type _syntax_ = [("linalg.conv_1d" - " ins( {in_id.ssa_id} , {filter_id.ssa_id} : {in_type.type} , {filter_type.type} )" - " outs( {out_id.ssa_id} : {out_type.type} )")] + " ins ( {in_id.ssa_id} , {filter_id.ssa_id} : {in_type.type} , {filter_type.type} )" + " outs ( {out_id.ssa_id} : {out_type.type} )")] @dataclass @@ -52,8 +52,8 @@ class LinalgConvHW(DialectOp): out_type: mast.Type _syntax_ = [("linalg.conv_2d" - " ins( {in_id.ssa_id} , {filter_id.ssa_id} : {in_type.type} , {filter_type.type} )" - " outs( {out_id.ssa_id} : {out_type.type} )")] + " ins ( {in_id.ssa_id} , {filter_id.ssa_id} : {in_type.type} , {filter_type.type} )" + " outs ( {out_id.ssa_id} : {out_type.type} )")] @dataclass @@ -66,8 +66,8 @@ class LinalgConvDHW(DialectOp): out_type: mast.Type _syntax_ = [("linalg.conv_3d" - " ins( {in_id.ssa_id} , {filter_id.ssa_id} : {in_type.type} , {filter_type.type} )" - " outs( {out_id.ssa_id} : {out_type.type} )")] + " ins ( {in_id.ssa_id} , {filter_id.ssa_id} : {in_type.type} , {filter_type.type} )" + " outs ( {out_id.ssa_id} : {out_type.type} )")] @dataclass @@ -110,8 +110,8 @@ class LinalgDot(DialectOp): out_type: mast.Type _syntax_ = [("linalg.dot" - " ins( {in_a_id.ssa_id} , {in_b_id.ssa_id} : {in_a_type.type} , {in_b_type.type} )" - " outs( {out_id.ssa_id} : {out_type.type} )")] + " ins ( {in_a_id.ssa_id} , {in_b_id.ssa_id} : {in_a_type.type} , {in_b_type.type} )" + " outs ( {out_id.ssa_id} : {out_type.type} )")] @dataclass @@ -124,19 +124,19 @@ class LinalgFill(DialectOp): attr: Optional[mast.Attribute] = None _syntax_ = [("linalg.fill" - " ins( {in_id.ssa_id} : {in_type.type} )" - " outs( {out_id.ssa_id} : {out_type.type} )" + " ins ( {in_id.ssa_id} : {in_type.type} )" + " outs ( {out_id.ssa_id} : {out_type.type} )" " {attr.attribute_value}"), ("linalg.fill" - " ins( {in_id.ssa_id} : {in_type.type} )" - " outs( {out_id.ssa_id} : {out_type.type} )"), + " ins ( {in_id.ssa_id} : {in_type.type} )" + " outs ( {out_id.ssa_id} : {out_type.type} )"), ("linalg.fill" - " ins( {in_id.ssa_id} : {in_type.type} )" - " outs( {out_id.ssa_id} : {out_type.type} )" + " ins ( {in_id.ssa_id} : {in_type.type} )" + " outs ( {out_id.ssa_id} : {out_type.type} )" " {attr.attribute_value} -> {res_type.type}"), ("linalg.fill" - " ins( {in_id.ssa_id} : {in_type.type} )" - " outs( {out_id.ssa_id} : {out_type.type} )" + " ins ( {in_id.ssa_id} : {in_type.type} )" + " outs ( {out_id.ssa_id} : {out_type.type} )" " -> {res_type.type}")] @@ -153,16 +153,16 @@ class LinalgGeneric(DialectOp): attr: Optional[mast.Attribute] = None _syntax_ = [("linalg.generic {attr.attribute_value} " - " ins( {inargs.ssa_id_list} : {in_types.type_list_no_parens} )" - " outs( {outargs.ssa_id_list} : {out_types.type_list_no_parens} )" + " ins ( {inargs.ssa_id_list} : {in_types.type_list_no_parens} )" + " outs ( {outargs.ssa_id_list} : {out_types.type_list_no_parens} )" " {region.region}"), ("linalg.generic {attr.attribute_value} " - " ins( {inargs.ssa_id_list} : {in_types.type_list_no_parens} )" - " outs( {outargs.ssa_id_list} : {out_types.type_list_no_parens} )" + " ins ( {inargs.ssa_id_list} : {in_types.type_list_no_parens} )" + " outs ( {outargs.ssa_id_list} : {out_types.type_list_no_parens} )" " {region.region} -> {out_type.type}"), ("linalg.generic {attr.attribute_value} " - " ins( {inargs.ssa_id_list} : {in_types.type_list_no_parens} )" - " init( {init_args.ssa_id_list} : {init_types.type_list_no_parens} )" + " ins ( {inargs.ssa_id_list} : {in_types.type_list_no_parens} )" + " init ( {init_args.ssa_id_list} : {init_types.type_list_no_parens} )" " {region.region} -> {out_type.type}")] @@ -179,12 +179,12 @@ class LinalgIndexedGeneric(DialectOp): attr: Optional[mast.Attribute] = None _syntax_ = [("linalg.indexed_generic {attr.attribute_value} " - " ins( {inargs.ssa_id_list} : {in_types.type_list_no_parens} )" - " outs( {outargs.ssa_id_list} : {out_types.type_list_no_parens} )" + " ins ( {inargs.ssa_id_list} : {in_types.type_list_no_parens} )" + " outs ( {outargs.ssa_id_list} : {out_types.type_list_no_parens} )" " {region.region}"), ("linalg.indexed_generic {attr.attribute_value} " - " ins( {inargs.ssa_id_list} : {in_types.type_list_no_parens} )" - " init( {init_args.ssa_id_list} : {init_types.type_list_no_parens} )" + " ins ( {inargs.ssa_id_list} : {in_types.type_list_no_parens} )" + " init ( {init_args.ssa_id_list} : {init_types.type_list_no_parens} )" " {region.region} -> {out_type.type}")] @@ -213,8 +213,8 @@ class LinalgReduce(DialectOp): args: List[Tuple[mast.SsaId, mast.Type]] _syntax_ = [("linalg.reduce" - " ins( {inargs.ssa_id_list} : {in_types.type_list_no_parens} )" - " outs( {outargs.ssa_id_list} : {out_types.type_list_no_parens} )" + " ins ( {inargs.ssa_id_list} : {in_types.type_list_no_parens} )" + " outs ( {outargs.ssa_id_list} : {out_types.type_list_no_parens} )" " dimensions = [ {dimensions.ssa_use_list} ]" " ( {args.argument_list} ) {region.region}")] @@ -299,14 +299,14 @@ class LinalgMatmul(DialectOp): out_type: Optional[mast.Type] = None _syntax_ = [("linalg.matmul" - " ins( {a_id.ssa_id} , {b_id.ssa_id} : {a_type.type} , {b_type.type} )" - " outs( {c_id.ssa_id} : {c_type.type} )"), + " ins ( {a_id.ssa_id} , {b_id.ssa_id} : {a_type.type} , {b_type.type} )" + " outs ( {c_id.ssa_id} : {c_type.type} )"), ("linalg.matmul" - " ins( {a_id.ssa_id} , {b_id.ssa_id} : {a_type.type} , {b_type.type} )" - " outs( {c_id.ssa_id} : {c_type.type} ) -> {out_type.type}"), + " ins ( {a_id.ssa_id} , {b_id.ssa_id} : {a_type.type} , {b_type.type} )" + " outs ( {c_id.ssa_id} : {c_type.type} ) -> {out_type.type}"), ("linalg.matmul" - " ins( {a_id.ssa_id} , {b_id.ssa_id} : {a_type.type} , {b_type.type} )" - " init( {c_id.ssa_id} : {c_type.type} ) -> {out_type.type}")] + " ins ( {a_id.ssa_id} , {b_id.ssa_id} : {a_type.type} , {b_type.type} )" + " init ( {c_id.ssa_id} : {c_type.type} ) -> {out_type.type}")] @dataclass @@ -319,8 +319,22 @@ class LinalgMatvec(DialectOp): c_type: mast.Type _syntax_ = [("linalg.matvec" - " ins( {a_id.ssa_id} , {b_id.ssa_id} : {a_type.type} , {b_type.type} )" - " outs( {c_id.ssa_id} : {c_type.type} )")] + " ins ( {a_id.ssa_id} , {b_id.ssa_id} : {a_type.type} , {b_type.type} )" + " outs ( {c_id.ssa_id} : {c_type.type} )")] + + +@dataclass +class LinalgTranspose(DialectOp): + inarg: List[mast.SsaId] + in_type: List[mast.Type] + init: List[mast.SsaId] + init_type: List[mast.Type] + permutation: List[int] + + _syntax_ = [("linalg.transpose" + " ins ( {inarg.ssa_id_list} : {in_type.type_list_no_parens} )" + " outs ( {init.ssa_id_list} : {init_type.type_list_no_parens} )" + " permutation = [ {permutation.ssa_use_list} ]")] # Inspect current module to get all classes defined above diff --git a/tests/test_linalg.py b/tests/test_linalg.py index 1558d92..60614d1 100644 --- a/tests/test_linalg.py +++ b/tests/test_linalg.py @@ -13,10 +13,10 @@ def assert_roundtrip_equivalence(source): def test_batch_matmul(): assert_roundtrip_equivalence("""module { func.func @named_ops(%a3: memref, %b3: memref, %c3: memref, %ta3: tensor, %tb3: tensor, %tc3: tensor) -> (tensor, tensor) { - linalg.batch_matmul ins( %a3 , %b3 : memref , memref ) outs( %c3 : memref ) - linalg.batch_matmul ins( %ta3 , %tb3 : tensor , tensor ) outs( %c3 : memref ) - %res1 = linalg.batch_matmul ins( %ta3 , %tb3 : tensor , tensor ) init( %tc3 : tensor ) -> tensor - %res2 = linalg.batch_matmul ins( %ta3 , %b3 : tensor , memref ) init( %tc3 : tensor ) -> tensor + linalg.batch_matmul ins ( %a3 , %b3 : memref , memref ) outs ( %c3 : memref ) + linalg.batch_matmul ins ( %ta3 , %tb3 : tensor , tensor ) outs ( %c3 : memref ) + %res1 = linalg.batch_matmul ins ( %ta3 , %tb3 : tensor , tensor ) init ( %tc3 : tensor ) -> tensor + %res2 = linalg.batch_matmul ins ( %ta3 , %b3 : tensor , memref ) init ( %tc3 : tensor ) -> tensor return %res1, %res2 : tensor, tensor } }""") @@ -25,15 +25,15 @@ def test_batch_matmul(): def test_conv(): assert_roundtrip_equivalence("""module { func.func @conv1d_no_symbols(%in: memref, %filter: memref, %out: memref) { - linalg.conv_1d ins( %in , %filter : memref , memref ) outs( %out : memref ) + linalg.conv_1d ins ( %in , %filter : memref , memref ) outs ( %out : memref ) return } func.func @conv2d_no_symbols(%in: memref, %filter: memref, %out: memref) { - linalg.conv_2d ins( %in , %filter : memref , memref ) outs( %out : memref ) + linalg.conv_2d ins ( %in , %filter : memref , memref ) outs ( %out : memref ) return } func.func @conv3d_no_symbols(%in: memref, %filter: memref, %out: memref) { - linalg.conv_3d ins( %in , %filter : memref , memref ) outs( %out : memref ) + linalg.conv_3d ins ( %in , %filter : memref , memref ) outs ( %out : memref ) return } }""") @@ -60,7 +60,7 @@ def test_dot(): %1 = view %arg0 [ %c0 ] [ %M ] : memref to memref %2 = view %arg0 [ %c0 ] [ %M ] : memref to memref %3 = view %arg0 [ %c0 ] [ ] : memref to memref - linalg.dot ins( %1 , %2 : memref , memref ) outs( %3 : memref ) + linalg.dot ins ( %1 , %2 : memref , memref ) outs ( %3 : memref ) return } }""") @@ -69,8 +69,8 @@ def test_dot(): def test_fill(): assert_roundtrip_equivalence("""module { func.func @fill_view(%arg0: f32, %arg1: tensor) { - linalg.fill ins( %arg0 : f32 ) outs( %arg1 : tensor ) -> tensor - linalg.fill ins( %arg0 : f32 ) outs( %arg1 : tensor ) + linalg.fill ins ( %arg0 : f32 ) outs ( %arg1 : tensor ) -> tensor + linalg.fill ins ( %arg0 : f32 ) outs ( %arg1 : tensor ) return } }""") @@ -79,7 +79,7 @@ def test_fill(): def test_generic(): assert_roundtrip_equivalence("""module { func.func @example(%A: memref, %B: memref, %C: memref) { - linalg.generic {indexing_maps = [affine_map<(i, j) -> (i, j)>, affine_map<(i, j) -> (i, j)>, affine_map<(i, j) -> (i, j)>], iterator_types = ["parallel", "parallel"]} ins( %A, %B : memref, memref ) outs( %C : memref ) { + linalg.generic {indexing_maps = [affine_map<(i, j) -> (i, j)>, affine_map<(i, j) -> (i, j)>, affine_map<(i, j) -> (i, j)>], iterator_types = ["parallel", "parallel"]} ins ( %A, %B : memref, memref ) outs ( %C : memref ) { ^bb0 (%a: f64, %b: f64, %c: f64): %c0 = constant 3.14 : f64 %d = addf %a , %b : f64 @@ -93,7 +93,7 @@ def test_generic(): def test_indexed_generic(): assert_roundtrip_equivalence("""module { func.func @indexed_generic_region(%arg0: memref>, %arg1: memref>, %arg2: memref>) { - linalg.indexed_generic {args_in = 1, args_out = 2, iterator_types = ["parallel", "parallel", "parallel"], indexing_maps = [affine_map<(i, j, k) -> (i, j)>, affine_map<(i, j, k) -> (i, j, k)>, affine_map<(i, j, k) -> (i, k, j)>], library_call = "some_external_function_name_2", doc = "B(i,j,k), C(i,k,j) = foo(A(i, j) * B(i,j,k), i * j * k + C(i,k,j))"} ins( %arg0 : memref> ) outs( %arg1, %arg2 : memref>, memref> ) { + linalg.indexed_generic {args_in = 1, args_out = 2, iterator_types = ["parallel", "parallel", "parallel"], indexing_maps = [affine_map<(i, j, k) -> (i, j)>, affine_map<(i, j, k) -> (i, j, k)>, affine_map<(i, j, k) -> (i, k, j)>], library_call = "some_external_function_name_2", doc = "B(i,j,k), C(i,k,j) = foo(A(i, j) * B(i,j,k), i * j * k + C(i,k,j))"} ins ( %arg0 : memref> ) outs ( %arg1, %arg2 : memref>, memref> ) { ^bb0 (%i: index, %j: index, %k: index, %a: f32, %b: f32, %c: f32): %result_1 = mulf %a , %b : f32 %ij = addi %i , %j : index @@ -110,7 +110,7 @@ def test_indexed_generic(): def test_reduce(): assert_roundtrip_equivalence("""module { func.func @reduce(%arg0: tensor<16x32x64xf32>, %arg1: tensor<16x64xf32>) { - %reduce = linalg.reduce ins( %arg0 : tensor<16x32x64xf32> ) outs( %arg1 : tensor<16x64xf32> ) dimensions = [ 1 ] ( %in: f32, %out: f32 ) { + %reduce = linalg.reduce ins ( %arg0 : tensor<16x32x64xf32> ) outs ( %arg1 : tensor<16x64xf32> ) dimensions = [ 1 ] ( %in: f32, %out: f32 ) { %0 = arith.addf %out, %in : f32 linalg.yield %0 : f32 } @@ -146,8 +146,8 @@ def test_matmul(): %A = view %arg0 [ %c0 ] [ %M, %K ] : memref to memref %B = view %arg0 [ %c0 ] [ %K, %N ] : memref to memref %C = view %arg0 [ %c0 ] [ %M, %N ] : memref to memref - linalg.matmul ins( %A , %B : memref , memref ) outs( %C : memref ) - linalg.matmul ins( %A , %B : memref , memref ) outs( %C : memref ) -> memref + linalg.matmul ins ( %A , %B : memref , memref ) outs ( %C : memref ) + linalg.matmul ins ( %A , %B : memref , memref ) outs ( %C : memref ) -> memref return } }""") @@ -161,7 +161,16 @@ def test_matvec(): %2 = view %arg0 [ %c0 ] [ %M, %N ] : memref to memref %3 = view %arg0 [ %c0 ] [ %M ] : memref to memref %4 = view %arg0 [ %c0 ] [ %N ] : memref to memref - linalg.matvec ins( %2 , %3 : memref , memref ) outs( %4 : memref ) + linalg.matvec ins ( %2 , %3 : memref , memref ) outs ( %4 : memref ) + return + } +}""") + + +def test_transpose(): + assert_roundtrip_equivalence("""module { + func.func @transpose(%arg0: memref, %arg1: memref) { + %transpose = linalg.transpose ins ( %arg0 : memref ) outs ( %arg1 : memref ) permutation = [ 1, 0 ] return } }""")