diff --git a/mlir/dialects/linalg.py b/mlir/dialects/linalg.py index 92e2d3b..41cd566 100644 --- a/mlir/dialects/linalg.py +++ b/mlir/dialects/linalg.py @@ -283,6 +283,9 @@ class LinalgMatmul(DialectOp): _syntax_ = [("linalg.matmul" " ins( {a_id.ssa_id} , {b_id.ssa_id} : {a_type.type} , {b_type.type} )" " outs( {c_id.ssa_id} : {c_type.type} )"), + ("linalg.matmul" + " ins( {a_id.ssa_id} , {b_id.ssa_id} : {a_type.type} , {b_type.type} )" + " outs( {c_id.ssa_id} : {c_type.type} ) -> {out_type.type}"), ("linalg.matmul" " ins( {a_id.ssa_id} , {b_id.ssa_id} : {a_type.type} , {b_type.type} )" " init( {c_id.ssa_id} : {c_type.type} ) -> {out_type.type}")] diff --git a/tests/test_linalg.py b/tests/test_linalg.py index 4a7364f..213853a 100644 --- a/tests/test_linalg.py +++ b/tests/test_linalg.py @@ -136,6 +136,7 @@ def test_matmul(): %B = view %arg0 [ %c0 ] [ %K, %N ] : memref to memref %C = view %arg0 [ %c0 ] [ %M, %N ] : memref to memref linalg.matmul ins( %A , %B : memref , memref ) outs( %C : memref ) + linalg.matmul ins( %A , %B : memref , memref ) outs( %C : memref ) -> memref return } }""")