Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Expand scf dialect implementation #38

Merged
merged 7 commits into from
Oct 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions mlir/astnodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -720,6 +720,16 @@ def dump(self, indent: int = 0) -> str:
return result


@dataclass
class ArgumentAssignment(Node):
name: SsaId
value: SsaId

def dump(self, indent: int = 0) -> str:
return '%s = %s' % (dump_or_value(self.name, indent),
dump_or_value(self.value, indent))


@dataclass
class MLIRFile(Node):
definitions: List["Definition"]
Expand Down
4 changes: 1 addition & 3 deletions mlir/dialects/func.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ class CallIndirectOperation(DialectOp):
func: mast.SymbolRefId
type: mast.FunctionType
args: Optional[List[SsaUse]] = None
argtypes: Optional[List[mast.Type]] = None
_syntax_ = ['func.call_indirect {func.symbol_ref_id} () : {type.function_type}',
'func.call_indirect {func.symbol_ref_id} ( {args.ssa_use_list} ) : {type.function_type}']

Expand All @@ -25,9 +24,8 @@ class CallOperation(DialectOp):
func: mast.SymbolRefId
type: mast.FunctionType
args: Optional[List[SsaUse]] = None
argtypes: Optional[List[mast.Type]] = None
_syntax_ = ['func.call {func.symbol_ref_id} () : {type.function_type}',
'func.call {func.symbol_ref_id} ( {args.ssa_use_list} ) : {argtypes.function_type}']
'func.call {func.symbol_ref_id} ( {args.ssa_use_list} ) : {type.function_type}']

@dataclass
class ConstantOperation(DialectOp):
Expand Down
35 changes: 31 additions & 4 deletions mlir/dialects/scf.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,14 @@
from typing import Optional, List, Tuple


@dataclass
class SCFConditionOp(DialectOp):
condition: mast.SsaId
args: List[mast.SsaId]
out_types: List[mast.Type]
_syntax_ = ['scf.condition ( {condition.ssa_id} ) {args.ssa_id_list} : {out_types.type_list_no_parens}']


@dataclass
class SCFForOp(DialectOp):
index: mast.SsaId
Expand All @@ -20,20 +28,39 @@ class SCFForOp(DialectOp):
out_type: Optional[mast.Type] = None
_syntax_ = ['scf.for {index.ssa_id} = {begin.ssa_id} to {end.ssa_id} step {step.ssa_id} {body.region}',
'scf.for {index.ssa_id} = {begin.ssa_id} to {end.ssa_id} step {step.ssa_id} : {out_type.type} {body.region}',
'scf.for {index.ssa_id} = {begin.ssa_id} to {end.ssa_id} step {step.ssa_id} iter_args {iter_args.argument_assignment_list_parens} -> {iter_args_types.type_list_parens} {body.region}',
'scf.for {index.ssa_id} = {begin.ssa_id} to {end.ssa_id} step {step.ssa_id} iter_args {iter_args.argument_assignment_list_parens} -> {iter_args_types.type_list_parens} : {out_type.type} {body.region}']
'scf.for {index.ssa_id} = {begin.ssa_id} to {end.ssa_id} step {step.ssa_id} iter_args ( {iter_args.argument_assignment_list_no_parens} ) -> {iter_args_types.type_list_no_parens} {body.region}',
'scf.for {index.ssa_id} = {begin.ssa_id} to {end.ssa_id} step {step.ssa_id} iter_args ( {iter_args.argument_assignment_list_no_parens} ) -> {iter_args_types.type_list_no_parens} : {out_type.type} {body.region}',
'scf.for {index.ssa_id} = {begin.ssa_id} to {end.ssa_id} step {step.ssa_id} iter_args ( {iter_args.argument_assignment_list_no_parens} ) -> ( {iter_args_types.type_list_no_parens} ) {body.region}',
'scf.for {index.ssa_id} = {begin.ssa_id} to {end.ssa_id} step {step.ssa_id} iter_args ( {iter_args.argument_assignment_list_no_parens} ) -> ( {iter_args_types.type_list_no_parens} ) : {out_type.type} {body.region}']


@dataclass
class SCFIfOp(DialectOp):
cond: mast.SsaId
body: mast.Region
elsebody: Optional[mast.Region] = None
out_types: Optional[List[mast.Type]] = None
_syntax_ = ['scf.if {cond.ssa_id} {body.region}',
'scf.if {cond.ssa_id} {body.region} else {elsebody.region}']
'scf.if {cond.ssa_id} {body.region} else {elsebody.region}',
'scf.if {cond.ssa_id} -> ( {out_types.type_list_no_parens} ) {body.region}',
'scf.if {cond.ssa_id} -> ( {out_types.type_list_no_parens} ) {body.region} else {elsebody.region}']


class SCFYield(UnaryOperation): _opname_ = 'scf.yield'
@dataclass
class SCFWhileOp(DialectOp):
assignments: List[Tuple[mast.SsaId, mast.Type]]
out_type: mast.FunctionType
while_body: mast.Region
do_body: mast.Region
_syntax_ = ['scf.while ( {assignments.argument_assignment_list_no_parens} ) : {out_type.function_type} {while_body.region} do {do_body.region}']


@dataclass
class SCFYield(DialectOp):
results: Optional[List[mast.SsaId]] = None
result_types: Optional[List[mast.Type]] = None
_syntax_ = ['scf.yield',
'scf.yield {results.ssa_id_list} : {result_types.type_list_no_parens}']


# Inspect current module to get all classes defined above
Expand Down
1 change: 1 addition & 0 deletions mlir/parser_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@ def block_label(self, value):
function = astnodes.Function.from_lark
generic_module = astnodes.GenericModule.from_lark
named_argument = astnodes.NamedArgument.from_lark
argument_assignment = astnodes.ArgumentAssignment.from_lark

###############################################################
# (semi-)Affine expressions, maps, and integer sets
Expand Down
11 changes: 4 additions & 7 deletions tests/test_builder.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import sys
from mlir import parse_string
from mlir.builder import IRBuilder
from mlir.builder import Reads, Writes, Isa
Expand Down Expand Up @@ -103,9 +102,7 @@ def index(expr):
assert index(Reads(a0)) == 2


if __name__ == "__main__":
if len(sys.argv) > 1:
exec(sys.argv[1])
else:
from pytest import main
main([__file__])
if __name__ == '__main__':
test_saxpy_builder()
test_query()
test_build_with_queries()
20 changes: 12 additions & 8 deletions tests/test_linalg.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
import sys
import mlir
from mlir.dialects.func import func

# All source strings have been taken from MLIR's codebase.
# See llvm-project/mlir/test/Dialect/Linalg
Expand Down Expand Up @@ -167,9 +165,15 @@ def test_matvec():
}""")


if __name__ == "__main__":
if len(sys.argv) > 1:
exec(sys.argv[1])
else:
from pytest import main
main([__file__])
if __name__ == '__main__':
test_batch_matmul()
test_conv()
test_copy()
test_dot()
test_fill()
test_generic()
test_indexed_generic()
test_reduce()
test_view()
test_matmul()
test_matvec()
54 changes: 54 additions & 0 deletions tests/test_scf.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
import mlir

# All source strings taken from examples in https://mlir.llvm.org/docs/Dialects/SCFDialect/


def assert_roundtrip_equivalence(source):
assert source == mlir.parse_string(source).dump()


def test_scf_for():
assert_roundtrip_equivalence("""module {
func.func @reduce(%buffer: memref<1024xf32>, %lb: index, %ub: index, %step: index, %sum_0: f32) -> (f32) {
%sum = scf.for %iv = %lb to %ub step %step iter_args ( %sum_iter = %sum_0 ) -> ( f32 ) {
%t = load %buffer [ %iv ] : memref<1024xf32>
%sum_next = arith.addf %sum_iter, %t : f32
scf.yield %sum_next : f32
}
return %sum : f32
}
}""")


def test_scf_if():
assert_roundtrip_equivalence("""module {
func.func @example(%A: f32, %B: f32, %C: f32, %D: f32) {
%x, %y = scf.if %b -> ( f32, f32 ) {
scf.yield %A, %B : f32, f32
} else {
scf.yield %C, %D : f32, f32
}
return
}
}""")


def test_scf_while():
assert_roundtrip_equivalence("""module {
func.func @example(%A: f32, %B: f32, %C: f32, %D: f32) {
%res = scf.while ( %arg1 = %init1 ) : (f32) -> f32 {
%condition = func.call @evaluate_condition ( %arg1 ) : (f32) -> i1
scf.condition ( %condition ) %arg1 : f32
} do {
^bb0 (%arg2: f32):
%next = func.call @payload ( %arg2 ) : (f32) -> f32
scf.yield %next : f32
}
}
}""")


if __name__ == '__main__':
test_scf_for()
test_scf_if()
test_scf_while()
20 changes: 0 additions & 20 deletions tests/test_syntax.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,26 +158,6 @@ def test_affine(parser: Optional[Parser] = None):
module = parser.parse(code)
print(module.pretty())

def test_scf_for(parser: Optional[Parser] = None):
code = """
module {
func.func @reduce(%buffer: memref<1024xf32>, %lb: index,
%ub: index, %step: index) -> (f32) {
%sum_0 = arith.constant 0.0 : f32
%sum = scf.for %iv = %lb to %ub step %step
iter_args(%sum_iter = %sum_0) -> (f32) {
%t = load %buffer[%iv] : memref<1024xf32>
%sum_next = arith.addf %sum_iter, %t : f32
scf.yield %sum_next : f32
}
return %sum : f32
}
}
"""
parser = parser or Parser()
module = parser.parse(code)
print(module.pretty())

def test_definitions(parser: Optional[Parser] = None):
code = '''
#map0 = affine_map<(d0, d1) -> (d0, d1)>
Expand Down
Loading