Skip to content

Commit

Permalink
Update linalg fill syntax
Browse files Browse the repository at this point in the history
  • Loading branch information
amanda849 committed Nov 10, 2023
1 parent bddedc2 commit ab4239a
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 10 deletions.
21 changes: 13 additions & 8 deletions mlir/dialects/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,16 +114,21 @@ class LinalgDot(DialectOp):

@dataclass
class LinalgFill(DialectOp):
output_id: mast.SsaId
value_id: mast.SsaId
output_type: mast.Type
value_type: mast.Type
in_id: mast.SsaId
in_type: mast.Type
out_id: mast.SsaId
out_type: mast.Type
res_type: mast.Type
attr: Optional[mast.Attribute] = None

_syntax_ = [("linalg.fill( {output_id.ssa_id} , {value_id.ssa_id} ) "
"{attr.attribute_value} : {output_type.type} , {value_type.type}"),
("linalg.fill( {output_id.ssa_id} , {value_id.ssa_id} ) "
" : {output_type.type} , {value_type.type}")]
_syntax_ = [("linalg.fill"
" ins( {in_id.ssa_id} : {in_type.type} )"
" outs( {out_id.ssa_id} : {out_type.type} )"
" {attr.attribute_value} -> {res_type.type}"),
("linalg.fill"
" ins( {in_id.ssa_id} : {in_type.type} )"
" outs( {out_id.ssa_id} : {out_type.type} )"
" -> {res_type.type}")]


@dataclass
Expand Down
4 changes: 2 additions & 2 deletions tests/test_linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,8 @@ def test_dot():

def test_fill():
assert_roundtrip_equivalence("""module {
func.func @fill_view(%arg0: memref<?xf32, strided<[1], offset: ?>>, %arg1: f32) {
linalg.fill( %arg0 , %arg1 ) : memref<?xf32, strided<[1], offset: ?>> , f32
func.func @fill_view(%arg0: f32, %arg1: tensor<?x?xf32>) {
linalg.fill ins( %arg0 : f32 ) outs( %arg1 : tensor<?x?xf32> ) -> tensor<?x?xf32>
return
}
}""")
Expand Down

0 comments on commit ab4239a

Please sign in to comment.