diff --git a/mlir/dialects/linalg.py b/mlir/dialects/linalg.py index e348694..8b5c938 100644 --- a/mlir/dialects/linalg.py +++ b/mlir/dialects/linalg.py @@ -114,16 +114,21 @@ class LinalgDot(DialectOp): @dataclass class LinalgFill(DialectOp): - output_id: mast.SsaId - value_id: mast.SsaId - output_type: mast.Type - value_type: mast.Type + in_id: mast.SsaId + in_type: mast.Type + out_id: mast.SsaId + out_type: mast.Type + res_type: mast.Type attr: Optional[mast.Attribute] = None - _syntax_ = [("linalg.fill( {output_id.ssa_id} , {value_id.ssa_id} ) " - "{attr.attribute_value} : {output_type.type} , {value_type.type}"), - ("linalg.fill( {output_id.ssa_id} , {value_id.ssa_id} ) " - " : {output_type.type} , {value_type.type}")] + _syntax_ = [("linalg.fill" + " ins( {in_id.ssa_id} : {in_type.type} )" + " outs( {out_id.ssa_id} : {out_type.type} )" + " {attr.attribute_value} -> {res_type.type}"), + ("linalg.fill" + " ins( {in_id.ssa_id} : {in_type.type} )" + " outs( {out_id.ssa_id} : {out_type.type} )" + " -> {res_type.type}")] @dataclass diff --git a/tests/test_linalg.py b/tests/test_linalg.py index 1b4f61b..ef5694f 100644 --- a/tests/test_linalg.py +++ b/tests/test_linalg.py @@ -68,8 +68,8 @@ def test_dot(): def test_fill(): assert_roundtrip_equivalence("""module { - func.func @fill_view(%arg0: memref>, %arg1: f32) { - linalg.fill( %arg0 , %arg1 ) : memref> , f32 + func.func @fill_view(%arg0: f32, %arg1: tensor) { + linalg.fill ins( %arg0 : f32 ) outs( %arg1 : tensor ) -> tensor return } }""")