Skip to content

Commit

Permalink
Consistent whitespace in linalg dialect
Browse files Browse the repository at this point in the history
  • Loading branch information
Amanda Tang committed Oct 7, 2024
1 parent ea148bc commit 12bb9dd
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 56 deletions.
80 changes: 40 additions & 40 deletions mlir/dialects/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,11 @@ class LinalgBatchMatmul(DialectOp):
out_type: Optional[mast.Type] = None

_syntax_ = [("linalg.batch_matmul"
" ins( {a_id.ssa_id} , {b_id.ssa_id} : {a_type.type} , {b_type.type} )"
" outs( {c_id.ssa_id} : {c_type.type} )"),
" ins ( {a_id.ssa_id} , {b_id.ssa_id} : {a_type.type} , {b_type.type} )"
" outs ( {c_id.ssa_id} : {c_type.type} )"),
("linalg.batch_matmul"
" ins( {a_id.ssa_id} , {b_id.ssa_id} : {a_type.type} , {b_type.type} )"
" init( {c_id.ssa_id} : {c_type.type} ) -> {out_type.type}")]
" ins ( {a_id.ssa_id} , {b_id.ssa_id} : {a_type.type} , {b_type.type} )"
" init ( {c_id.ssa_id} : {c_type.type} ) -> {out_type.type}")]


@dataclass
Expand All @@ -38,8 +38,8 @@ class LinalgConvW(DialectOp):
out_type: mast.Type

_syntax_ = [("linalg.conv_1d"
" ins( {in_id.ssa_id} , {filter_id.ssa_id} : {in_type.type} , {filter_type.type} )"
" outs( {out_id.ssa_id} : {out_type.type} )")]
" ins ( {in_id.ssa_id} , {filter_id.ssa_id} : {in_type.type} , {filter_type.type} )"
" outs ( {out_id.ssa_id} : {out_type.type} )")]


@dataclass
Expand All @@ -52,8 +52,8 @@ class LinalgConvHW(DialectOp):
out_type: mast.Type

_syntax_ = [("linalg.conv_2d"
" ins( {in_id.ssa_id} , {filter_id.ssa_id} : {in_type.type} , {filter_type.type} )"
" outs( {out_id.ssa_id} : {out_type.type} )")]
" ins ( {in_id.ssa_id} , {filter_id.ssa_id} : {in_type.type} , {filter_type.type} )"
" outs ( {out_id.ssa_id} : {out_type.type} )")]


@dataclass
Expand All @@ -66,8 +66,8 @@ class LinalgConvDHW(DialectOp):
out_type: mast.Type

_syntax_ = [("linalg.conv_3d"
" ins( {in_id.ssa_id} , {filter_id.ssa_id} : {in_type.type} , {filter_type.type} )"
" outs( {out_id.ssa_id} : {out_type.type} )")]
" ins ( {in_id.ssa_id} , {filter_id.ssa_id} : {in_type.type} , {filter_type.type} )"
" outs ( {out_id.ssa_id} : {out_type.type} )")]


@dataclass
Expand Down Expand Up @@ -110,8 +110,8 @@ class LinalgDot(DialectOp):
out_type: mast.Type

_syntax_ = [("linalg.dot"
" ins( {in_a_id.ssa_id} , {in_b_id.ssa_id} : {in_a_type.type} , {in_b_type.type} )"
" outs( {out_id.ssa_id} : {out_type.type} )")]
" ins ( {in_a_id.ssa_id} , {in_b_id.ssa_id} : {in_a_type.type} , {in_b_type.type} )"
" outs ( {out_id.ssa_id} : {out_type.type} )")]


@dataclass
Expand All @@ -124,19 +124,19 @@ class LinalgFill(DialectOp):
attr: Optional[mast.Attribute] = None

_syntax_ = [("linalg.fill"
" ins( {in_id.ssa_id} : {in_type.type} )"
" outs( {out_id.ssa_id} : {out_type.type} )"
" ins ( {in_id.ssa_id} : {in_type.type} )"
" outs ( {out_id.ssa_id} : {out_type.type} )"
" {attr.attribute_value}"),
("linalg.fill"
" ins( {in_id.ssa_id} : {in_type.type} )"
" outs( {out_id.ssa_id} : {out_type.type} )"),
" ins ( {in_id.ssa_id} : {in_type.type} )"
" outs ( {out_id.ssa_id} : {out_type.type} )"),
("linalg.fill"
" ins( {in_id.ssa_id} : {in_type.type} )"
" outs( {out_id.ssa_id} : {out_type.type} )"
" ins ( {in_id.ssa_id} : {in_type.type} )"
" outs ( {out_id.ssa_id} : {out_type.type} )"
" {attr.attribute_value} -> {res_type.type}"),
("linalg.fill"
" ins( {in_id.ssa_id} : {in_type.type} )"
" outs( {out_id.ssa_id} : {out_type.type} )"
" ins ( {in_id.ssa_id} : {in_type.type} )"
" outs ( {out_id.ssa_id} : {out_type.type} )"
" -> {res_type.type}")]


Expand All @@ -153,16 +153,16 @@ class LinalgGeneric(DialectOp):
attr: Optional[mast.Attribute] = None

_syntax_ = [("linalg.generic {attr.attribute_value} "
" ins( {inargs.ssa_id_list} : {in_types.type_list_no_parens} )"
" outs( {outargs.ssa_id_list} : {out_types.type_list_no_parens} )"
" ins ( {inargs.ssa_id_list} : {in_types.type_list_no_parens} )"
" outs ( {outargs.ssa_id_list} : {out_types.type_list_no_parens} )"
" {region.region}"),
("linalg.generic {attr.attribute_value} "
" ins( {inargs.ssa_id_list} : {in_types.type_list_no_parens} )"
" outs( {outargs.ssa_id_list} : {out_types.type_list_no_parens} )"
" ins ( {inargs.ssa_id_list} : {in_types.type_list_no_parens} )"
" outs ( {outargs.ssa_id_list} : {out_types.type_list_no_parens} )"
" {region.region} -> {out_type.type}"),
("linalg.generic {attr.attribute_value} "
" ins( {inargs.ssa_id_list} : {in_types.type_list_no_parens} )"
" init( {init_args.ssa_id_list} : {init_types.type_list_no_parens} )"
" ins ( {inargs.ssa_id_list} : {in_types.type_list_no_parens} )"
" init ( {init_args.ssa_id_list} : {init_types.type_list_no_parens} )"
" {region.region} -> {out_type.type}")]


Expand All @@ -179,12 +179,12 @@ class LinalgIndexedGeneric(DialectOp):
attr: Optional[mast.Attribute] = None

_syntax_ = [("linalg.indexed_generic {attr.attribute_value} "
" ins( {inargs.ssa_id_list} : {in_types.type_list_no_parens} )"
" outs( {outargs.ssa_id_list} : {out_types.type_list_no_parens} )"
" ins ( {inargs.ssa_id_list} : {in_types.type_list_no_parens} )"
" outs ( {outargs.ssa_id_list} : {out_types.type_list_no_parens} )"
" {region.region}"),
("linalg.indexed_generic {attr.attribute_value} "
" ins( {inargs.ssa_id_list} : {in_types.type_list_no_parens} )"
" init( {init_args.ssa_id_list} : {init_types.type_list_no_parens} )"
" ins ( {inargs.ssa_id_list} : {in_types.type_list_no_parens} )"
" init ( {init_args.ssa_id_list} : {init_types.type_list_no_parens} )"
" {region.region} -> {out_type.type}")]


Expand Down Expand Up @@ -213,8 +213,8 @@ class LinalgReduce(DialectOp):
args: List[Tuple[mast.SsaId, mast.Type]]

_syntax_ = [("linalg.reduce"
" ins( {inargs.ssa_id_list} : {in_types.type_list_no_parens} )"
" outs( {outargs.ssa_id_list} : {out_types.type_list_no_parens} )"
" ins ( {inargs.ssa_id_list} : {in_types.type_list_no_parens} )"
" outs ( {outargs.ssa_id_list} : {out_types.type_list_no_parens} )"
" dimensions = [ {dimensions.ssa_use_list} ]"
" ( {args.argument_list} ) {region.region}")]

Expand Down Expand Up @@ -299,14 +299,14 @@ class LinalgMatmul(DialectOp):
out_type: Optional[mast.Type] = None

_syntax_ = [("linalg.matmul"
" ins( {a_id.ssa_id} , {b_id.ssa_id} : {a_type.type} , {b_type.type} )"
" outs( {c_id.ssa_id} : {c_type.type} )"),
" ins ( {a_id.ssa_id} , {b_id.ssa_id} : {a_type.type} , {b_type.type} )"
" outs ( {c_id.ssa_id} : {c_type.type} )"),
("linalg.matmul"
" ins( {a_id.ssa_id} , {b_id.ssa_id} : {a_type.type} , {b_type.type} )"
" outs( {c_id.ssa_id} : {c_type.type} ) -> {out_type.type}"),
" ins ( {a_id.ssa_id} , {b_id.ssa_id} : {a_type.type} , {b_type.type} )"
" outs ( {c_id.ssa_id} : {c_type.type} ) -> {out_type.type}"),
("linalg.matmul"
" ins( {a_id.ssa_id} , {b_id.ssa_id} : {a_type.type} , {b_type.type} )"
" init( {c_id.ssa_id} : {c_type.type} ) -> {out_type.type}")]
" ins ( {a_id.ssa_id} , {b_id.ssa_id} : {a_type.type} , {b_type.type} )"
" init ( {c_id.ssa_id} : {c_type.type} ) -> {out_type.type}")]


@dataclass
Expand All @@ -319,8 +319,8 @@ class LinalgMatvec(DialectOp):
c_type: mast.Type

_syntax_ = [("linalg.matvec"
" ins( {a_id.ssa_id} , {b_id.ssa_id} : {a_type.type} , {b_type.type} )"
" outs( {c_id.ssa_id} : {c_type.type} )")]
" ins ( {a_id.ssa_id} , {b_id.ssa_id} : {a_type.type} , {b_type.type} )"
" outs ( {c_id.ssa_id} : {c_type.type} )")]


@dataclass
Expand Down
32 changes: 16 additions & 16 deletions tests/test_linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,10 @@ def assert_roundtrip_equivalence(source):
def test_batch_matmul():
assert_roundtrip_equivalence("""module {
func.func @named_ops(%a3: memref<?x?x?xf32>, %b3: memref<?x?x?xf32>, %c3: memref<?x?x?xf32>, %ta3: tensor<?x?x?xf32>, %tb3: tensor<?x?x?xf32>, %tc3: tensor<?x?x?xf32>) -> (tensor<?x?x?xf32>, tensor<?x?x?xf32>) {
linalg.batch_matmul ins( %a3 , %b3 : memref<?x?x?xf32> , memref<?x?x?xf32> ) outs( %c3 : memref<?x?x?xf32> )
linalg.batch_matmul ins( %ta3 , %tb3 : tensor<?x?x?xf32> , tensor<?x?x?xf32> ) outs( %c3 : memref<?x?x?xf32> )
%res1 = linalg.batch_matmul ins( %ta3 , %tb3 : tensor<?x?x?xf32> , tensor<?x?x?xf32> ) init( %tc3 : tensor<?x?x?xf32> ) -> tensor<?x?x?xf32>
%res2 = linalg.batch_matmul ins( %ta3 , %b3 : tensor<?x?x?xf32> , memref<?x?x?xf32> ) init( %tc3 : tensor<?x?x?xf32> ) -> tensor<?x?x?xf32>
linalg.batch_matmul ins ( %a3 , %b3 : memref<?x?x?xf32> , memref<?x?x?xf32> ) outs ( %c3 : memref<?x?x?xf32> )
linalg.batch_matmul ins ( %ta3 , %tb3 : tensor<?x?x?xf32> , tensor<?x?x?xf32> ) outs ( %c3 : memref<?x?x?xf32> )
%res1 = linalg.batch_matmul ins ( %ta3 , %tb3 : tensor<?x?x?xf32> , tensor<?x?x?xf32> ) init ( %tc3 : tensor<?x?x?xf32> ) -> tensor<?x?x?xf32>
%res2 = linalg.batch_matmul ins ( %ta3 , %b3 : tensor<?x?x?xf32> , memref<?x?x?xf32> ) init ( %tc3 : tensor<?x?x?xf32> ) -> tensor<?x?x?xf32>
return %res1, %res2 : tensor<?x?x?xf32>, tensor<?x?x?xf32>
}
}""")
Expand All @@ -25,15 +25,15 @@ def test_batch_matmul():
def test_conv():
assert_roundtrip_equivalence("""module {
func.func @conv1d_no_symbols(%in: memref<?xf32>, %filter: memref<?xf32>, %out: memref<?xf32>) {
linalg.conv_1d ins( %in , %filter : memref<?xf32> , memref<?xf32> ) outs( %out : memref<?xf32> )
linalg.conv_1d ins ( %in , %filter : memref<?xf32> , memref<?xf32> ) outs ( %out : memref<?xf32> )
return
}
func.func @conv2d_no_symbols(%in: memref<?x?xf32>, %filter: memref<?x?xf32>, %out: memref<?x?xf32>) {
linalg.conv_2d ins( %in , %filter : memref<?x?xf32> , memref<?x?xf32> ) outs( %out : memref<?x?xf32> )
linalg.conv_2d ins ( %in , %filter : memref<?x?xf32> , memref<?x?xf32> ) outs ( %out : memref<?x?xf32> )
return
}
func.func @conv3d_no_symbols(%in: memref<?x?x?xf32>, %filter: memref<?x?x?xf32>, %out: memref<?x?x?xf32>) {
linalg.conv_3d ins( %in , %filter : memref<?x?x?xf32> , memref<?x?x?xf32> ) outs( %out : memref<?x?x?xf32> )
linalg.conv_3d ins ( %in , %filter : memref<?x?x?xf32> , memref<?x?x?xf32> ) outs ( %out : memref<?x?x?xf32> )
return
}
}""")
Expand All @@ -60,7 +60,7 @@ def test_dot():
%1 = view %arg0 [ %c0 ] [ %M ] : memref<?xi8> to memref<?xf32>
%2 = view %arg0 [ %c0 ] [ %M ] : memref<?xi8> to memref<?xf32>
%3 = view %arg0 [ %c0 ] [ ] : memref<?xi8> to memref<f32>
linalg.dot ins( %1 , %2 : memref<?xf32> , memref<?xf32> ) outs( %3 : memref<f32> )
linalg.dot ins ( %1 , %2 : memref<?xf32> , memref<?xf32> ) outs ( %3 : memref<f32> )
return
}
}""")
Expand All @@ -69,8 +69,8 @@ def test_dot():
def test_fill():
assert_roundtrip_equivalence("""module {
func.func @fill_view(%arg0: f32, %arg1: tensor<?x?xf32>) {
linalg.fill ins( %arg0 : f32 ) outs( %arg1 : tensor<?x?xf32> ) -> tensor<?x?xf32>
linalg.fill ins( %arg0 : f32 ) outs( %arg1 : tensor<?x?xf32> )
linalg.fill ins ( %arg0 : f32 ) outs ( %arg1 : tensor<?x?xf32> ) -> tensor<?x?xf32>
linalg.fill ins ( %arg0 : f32 ) outs ( %arg1 : tensor<?x?xf32> )
return
}
}""")
Expand All @@ -79,7 +79,7 @@ def test_fill():
def test_generic():
assert_roundtrip_equivalence("""module {
func.func @example(%A: memref<?x?xf64>, %B: memref<?x?xf64>, %C: memref<?x?xf64>) {
linalg.generic {indexing_maps = [affine_map<(i, j) -> (i, j)>, affine_map<(i, j) -> (i, j)>, affine_map<(i, j) -> (i, j)>], iterator_types = ["parallel", "parallel"]} ins( %A, %B : memref<?x?xf64>, memref<?x?xf64> ) outs( %C : memref<?x?xf64> ) {
linalg.generic {indexing_maps = [affine_map<(i, j) -> (i, j)>, affine_map<(i, j) -> (i, j)>, affine_map<(i, j) -> (i, j)>], iterator_types = ["parallel", "parallel"]} ins ( %A, %B : memref<?x?xf64>, memref<?x?xf64> ) outs ( %C : memref<?x?xf64> ) {
^bb0 (%a: f64, %b: f64, %c: f64):
%c0 = constant 3.14 : f64
%d = addf %a , %b : f64
Expand All @@ -93,7 +93,7 @@ def test_generic():
def test_indexed_generic():
assert_roundtrip_equivalence("""module {
func.func @indexed_generic_region(%arg0: memref<?x?xf32, strided<[?, 1], offset: ?>>, %arg1: memref<?x?x?xf32, strided<[?, ?, 1], offset: ?>>, %arg2: memref<?x?x?xf32, strided<[?, ?, 1], offset: ?>>) {
linalg.indexed_generic {args_in = 1, args_out = 2, iterator_types = ["parallel", "parallel", "parallel"], indexing_maps = [affine_map<(i, j, k) -> (i, j)>, affine_map<(i, j, k) -> (i, j, k)>, affine_map<(i, j, k) -> (i, k, j)>], library_call = "some_external_function_name_2", doc = "B(i,j,k), C(i,k,j) = foo(A(i, j) * B(i,j,k), i * j * k + C(i,k,j))"} ins( %arg0 : memref<?x?xf32, strided<[?, 1], offset: ?>> ) outs( %arg1, %arg2 : memref<?x?x?xf32, strided<[?, ?, 1], offset: ?>>, memref<?x?x?xf32, strided<[?, ?, 1], offset: ?>> ) {
linalg.indexed_generic {args_in = 1, args_out = 2, iterator_types = ["parallel", "parallel", "parallel"], indexing_maps = [affine_map<(i, j, k) -> (i, j)>, affine_map<(i, j, k) -> (i, j, k)>, affine_map<(i, j, k) -> (i, k, j)>], library_call = "some_external_function_name_2", doc = "B(i,j,k), C(i,k,j) = foo(A(i, j) * B(i,j,k), i * j * k + C(i,k,j))"} ins ( %arg0 : memref<?x?xf32, strided<[?, 1], offset: ?>> ) outs ( %arg1, %arg2 : memref<?x?x?xf32, strided<[?, ?, 1], offset: ?>>, memref<?x?x?xf32, strided<[?, ?, 1], offset: ?>> ) {
^bb0 (%i: index, %j: index, %k: index, %a: f32, %b: f32, %c: f32):
%result_1 = mulf %a , %b : f32
%ij = addi %i , %j : index
Expand All @@ -110,7 +110,7 @@ def test_indexed_generic():
def test_reduce():
assert_roundtrip_equivalence("""module {
func.func @reduce(%arg0: tensor<16x32x64xf32>, %arg1: tensor<16x64xf32>) {
%reduce = linalg.reduce ins( %arg0 : tensor<16x32x64xf32> ) outs( %arg1 : tensor<16x64xf32> ) dimensions = [ 1 ] ( %in: f32, %out: f32 ) {
%reduce = linalg.reduce ins ( %arg0 : tensor<16x32x64xf32> ) outs ( %arg1 : tensor<16x64xf32> ) dimensions = [ 1 ] ( %in: f32, %out: f32 ) {
%0 = arith.addf %out, %in : f32
linalg.yield %0 : f32
}
Expand Down Expand Up @@ -146,8 +146,8 @@ def test_matmul():
%A = view %arg0 [ %c0 ] [ %M, %K ] : memref<?xi8> to memref<?x?xf32>
%B = view %arg0 [ %c0 ] [ %K, %N ] : memref<?xi8> to memref<?x?xf32>
%C = view %arg0 [ %c0 ] [ %M, %N ] : memref<?xi8> to memref<?x?xf32>
linalg.matmul ins( %A , %B : memref<?x?xf32> , memref<?x?xf32> ) outs( %C : memref<?x?xf32> )
linalg.matmul ins( %A , %B : memref<?x?xf32> , memref<?x?xf32> ) outs( %C : memref<?x?xf32> ) -> memref<?x?xf32>
linalg.matmul ins ( %A , %B : memref<?x?xf32> , memref<?x?xf32> ) outs ( %C : memref<?x?xf32> )
linalg.matmul ins ( %A , %B : memref<?x?xf32> , memref<?x?xf32> ) outs ( %C : memref<?x?xf32> ) -> memref<?x?xf32>
return
}
}""")
Expand All @@ -161,7 +161,7 @@ def test_matvec():
%2 = view %arg0 [ %c0 ] [ %M, %N ] : memref<?xi8> to memref<?x?xf32>
%3 = view %arg0 [ %c0 ] [ %M ] : memref<?xi8> to memref<?xf32>
%4 = view %arg0 [ %c0 ] [ %N ] : memref<?xi8> to memref<?xf32>
linalg.matvec ins( %2 , %3 : memref<?x?xf32> , memref<?xf32> ) outs( %4 : memref<?xf32> )
linalg.matvec ins ( %2 , %3 : memref<?x?xf32> , memref<?xf32> ) outs ( %4 : memref<?xf32> )
return
}
}""")
Expand Down

0 comments on commit 12bb9dd

Please sign in to comment.