Skip to content

Commit

Permalink
Expand scf dialect implementation (#38)
Browse files Browse the repository at this point in the history
* Add and update scf condition, if, while, yield

* Test cases for scf dialect

* Additional scf.for syntax
  • Loading branch information
amanda849 authored Oct 16, 2024
1 parent 28b8ada commit 4f0986a
Show file tree
Hide file tree
Showing 8 changed files with 113 additions and 42 deletions.
10 changes: 10 additions & 0 deletions mlir/astnodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -720,6 +720,16 @@ def dump(self, indent: int = 0) -> str:
return result


@dataclass
class ArgumentAssignment(Node):
name: SsaId
value: SsaId

def dump(self, indent: int = 0) -> str:
return '%s = %s' % (dump_or_value(self.name, indent),
dump_or_value(self.value, indent))


@dataclass
class MLIRFile(Node):
definitions: List["Definition"]
Expand Down
4 changes: 1 addition & 3 deletions mlir/dialects/func.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ class CallIndirectOperation(DialectOp):
func: mast.SymbolRefId
type: mast.FunctionType
args: Optional[List[SsaUse]] = None
argtypes: Optional[List[mast.Type]] = None
_syntax_ = ['func.call_indirect {func.symbol_ref_id} () : {type.function_type}',
'func.call_indirect {func.symbol_ref_id} ( {args.ssa_use_list} ) : {type.function_type}']

Expand All @@ -25,9 +24,8 @@ class CallOperation(DialectOp):
func: mast.SymbolRefId
type: mast.FunctionType
args: Optional[List[SsaUse]] = None
argtypes: Optional[List[mast.Type]] = None
_syntax_ = ['func.call {func.symbol_ref_id} () : {type.function_type}',
'func.call {func.symbol_ref_id} ( {args.ssa_use_list} ) : {argtypes.function_type}']
'func.call {func.symbol_ref_id} ( {args.ssa_use_list} ) : {type.function_type}']

@dataclass
class ConstantOperation(DialectOp):
Expand Down
35 changes: 31 additions & 4 deletions mlir/dialects/scf.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,14 @@
from typing import Optional, List, Tuple


@dataclass
class SCFConditionOp(DialectOp):
condition: mast.SsaId
args: List[mast.SsaId]
out_types: List[mast.Type]
_syntax_ = ['scf.condition ( {condition.ssa_id} ) {args.ssa_id_list} : {out_types.type_list_no_parens}']


@dataclass
class SCFForOp(DialectOp):
index: mast.SsaId
Expand All @@ -20,20 +28,39 @@ class SCFForOp(DialectOp):
out_type: Optional[mast.Type] = None
_syntax_ = ['scf.for {index.ssa_id} = {begin.ssa_id} to {end.ssa_id} step {step.ssa_id} {body.region}',
'scf.for {index.ssa_id} = {begin.ssa_id} to {end.ssa_id} step {step.ssa_id} : {out_type.type} {body.region}',
'scf.for {index.ssa_id} = {begin.ssa_id} to {end.ssa_id} step {step.ssa_id} iter_args {iter_args.argument_assignment_list_parens} -> {iter_args_types.type_list_parens} {body.region}',
'scf.for {index.ssa_id} = {begin.ssa_id} to {end.ssa_id} step {step.ssa_id} iter_args {iter_args.argument_assignment_list_parens} -> {iter_args_types.type_list_parens} : {out_type.type} {body.region}']
'scf.for {index.ssa_id} = {begin.ssa_id} to {end.ssa_id} step {step.ssa_id} iter_args ( {iter_args.argument_assignment_list_no_parens} ) -> {iter_args_types.type_list_no_parens} {body.region}',
'scf.for {index.ssa_id} = {begin.ssa_id} to {end.ssa_id} step {step.ssa_id} iter_args ( {iter_args.argument_assignment_list_no_parens} ) -> {iter_args_types.type_list_no_parens} : {out_type.type} {body.region}',
'scf.for {index.ssa_id} = {begin.ssa_id} to {end.ssa_id} step {step.ssa_id} iter_args ( {iter_args.argument_assignment_list_no_parens} ) -> ( {iter_args_types.type_list_no_parens} ) {body.region}',
'scf.for {index.ssa_id} = {begin.ssa_id} to {end.ssa_id} step {step.ssa_id} iter_args ( {iter_args.argument_assignment_list_no_parens} ) -> ( {iter_args_types.type_list_no_parens} ) : {out_type.type} {body.region}']


@dataclass
class SCFIfOp(DialectOp):
cond: mast.SsaId
body: mast.Region
elsebody: Optional[mast.Region] = None
out_types: Optional[List[mast.Type]] = None
_syntax_ = ['scf.if {cond.ssa_id} {body.region}',
'scf.if {cond.ssa_id} {body.region} else {elsebody.region}']
'scf.if {cond.ssa_id} {body.region} else {elsebody.region}',
'scf.if {cond.ssa_id} -> ( {out_types.type_list_no_parens} ) {body.region}',
'scf.if {cond.ssa_id} -> ( {out_types.type_list_no_parens} ) {body.region} else {elsebody.region}']


class SCFYield(UnaryOperation): _opname_ = 'scf.yield'
@dataclass
class SCFWhileOp(DialectOp):
assignments: List[Tuple[mast.SsaId, mast.Type]]
out_type: mast.FunctionType
while_body: mast.Region
do_body: mast.Region
_syntax_ = ['scf.while ( {assignments.argument_assignment_list_no_parens} ) : {out_type.function_type} {while_body.region} do {do_body.region}']


@dataclass
class SCFYield(DialectOp):
results: Optional[List[mast.SsaId]] = None
result_types: Optional[List[mast.Type]] = None
_syntax_ = ['scf.yield',
'scf.yield {results.ssa_id_list} : {result_types.type_list_no_parens}']


# Inspect current module to get all classes defined above
Expand Down
1 change: 1 addition & 0 deletions mlir/parser_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@ def block_label(self, value):
function = astnodes.Function.from_lark
generic_module = astnodes.GenericModule.from_lark
named_argument = astnodes.NamedArgument.from_lark
argument_assignment = astnodes.ArgumentAssignment.from_lark

###############################################################
# (semi-)Affine expressions, maps, and integer sets
Expand Down
11 changes: 4 additions & 7 deletions tests/test_builder.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import sys
from mlir import parse_string
from mlir.builder import IRBuilder
from mlir.builder import Reads, Writes, Isa
Expand Down Expand Up @@ -103,9 +102,7 @@ def index(expr):
assert index(Reads(a0)) == 2


if __name__ == "__main__":
if len(sys.argv) > 1:
exec(sys.argv[1])
else:
from pytest import main
main([__file__])
if __name__ == '__main__':
test_saxpy_builder()
test_query()
test_build_with_queries()
20 changes: 12 additions & 8 deletions tests/test_linalg.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
import sys
import mlir
from mlir.dialects.func import func

# All source strings have been taken from MLIR's codebase.
# See llvm-project/mlir/test/Dialect/Linalg
Expand Down Expand Up @@ -186,9 +184,15 @@ def test_transpose():
}""")


if __name__ == "__main__":
if len(sys.argv) > 1:
exec(sys.argv[1])
else:
from pytest import main
main([__file__])
if __name__ == '__main__':
test_batch_matmul()
test_conv()
test_copy()
test_dot()
test_fill()
test_generic()
test_indexed_generic()
test_reduce()
test_view()
test_matmul()
test_matvec()
54 changes: 54 additions & 0 deletions tests/test_scf.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
import mlir

# All source strings taken from examples in https://mlir.llvm.org/docs/Dialects/SCFDialect/


def assert_roundtrip_equivalence(source):
assert source == mlir.parse_string(source).dump()


def test_scf_for():
assert_roundtrip_equivalence("""module {
func.func @reduce(%buffer: memref<1024xf32>, %lb: index, %ub: index, %step: index, %sum_0: f32) -> (f32) {
%sum = scf.for %iv = %lb to %ub step %step iter_args ( %sum_iter = %sum_0 ) -> ( f32 ) {
%t = load %buffer [ %iv ] : memref<1024xf32>
%sum_next = arith.addf %sum_iter, %t : f32
scf.yield %sum_next : f32
}
return %sum : f32
}
}""")


def test_scf_if():
assert_roundtrip_equivalence("""module {
func.func @example(%A: f32, %B: f32, %C: f32, %D: f32) {
%x, %y = scf.if %b -> ( f32, f32 ) {
scf.yield %A, %B : f32, f32
} else {
scf.yield %C, %D : f32, f32
}
return
}
}""")


def test_scf_while():
assert_roundtrip_equivalence("""module {
func.func @example(%A: f32, %B: f32, %C: f32, %D: f32) {
%res = scf.while ( %arg1 = %init1 ) : (f32) -> f32 {
%condition = func.call @evaluate_condition ( %arg1 ) : (f32) -> i1
scf.condition ( %condition ) %arg1 : f32
} do {
^bb0 (%arg2: f32):
%next = func.call @payload ( %arg2 ) : (f32) -> f32
scf.yield %next : f32
}
}
}""")


if __name__ == '__main__':
test_scf_for()
test_scf_if()
test_scf_while()
20 changes: 0 additions & 20 deletions tests/test_syntax.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,26 +158,6 @@ def test_affine(parser: Optional[Parser] = None):
module = parser.parse(code)
print(module.pretty())

def test_scf_for(parser: Optional[Parser] = None):
code = """
module {
func.func @reduce(%buffer: memref<1024xf32>, %lb: index,
%ub: index, %step: index) -> (f32) {
%sum_0 = arith.constant 0.0 : f32
%sum = scf.for %iv = %lb to %ub step %step
iter_args(%sum_iter = %sum_0) -> (f32) {
%t = load %buffer[%iv] : memref<1024xf32>
%sum_next = arith.addf %sum_iter, %t : f32
scf.yield %sum_next : f32
}
return %sum : f32
}
}
"""
parser = parser or Parser()
module = parser.parse(code)
print(module.pretty())

def test_definitions(parser: Optional[Parser] = None):
code = '''
#map0 = affine_map<(d0, d1) -> (d0, d1)>
Expand Down

0 comments on commit 4f0986a

Please sign in to comment.