Skip to content

Commit

Permalink
Update scf.for syntax (#30)
Browse files Browse the repository at this point in the history
  • Loading branch information
amanda849 authored Nov 27, 2023
1 parent 5a9a019 commit 9d0d8f8
Show file tree
Hide file tree
Showing 4 changed files with 32 additions and 4 deletions.
13 changes: 9 additions & 4 deletions mlir/dialects/scf.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,23 @@
from mlir.dialect import Dialect, DialectOp, is_op, UnaryOperation
import mlir.astnodes as mast
from dataclasses import dataclass
from typing import Optional
from typing import Optional, List, Tuple


@dataclass
class SCFForOp(DialectOp):
index: mast.SsaId
begin: mast.SsaId
end: mast.SsaId
step: mast.SsaId
body: mast.Region
step: Optional[mast.SsaId] = None
_syntax_ = ['scf.for {index.ssa_id} = {begin.ssa_id} to {end.ssa_id} {body.region}',
'scf.for {index.ssa_id} = {begin.ssa_id} to {end.ssa_id} step {step.ssa_id} {body.region}']
iter_args: Optional[List[Tuple[mast.SsaId, mast.SsaId]]] = None
iter_args_types: Optional[List[mast.Type]] = None
out_type: Optional[mast.Type] = None
_syntax_ = ['scf.for {index.ssa_id} = {begin.ssa_id} to {end.ssa_id} step {step.ssa_id} {body.region}',
'scf.for {index.ssa_id} = {begin.ssa_id} to {end.ssa_id} step {step.ssa_id} : {out_type.type} {body.region}',
'scf.for {index.ssa_id} = {begin.ssa_id} to {end.ssa_id} step {step.ssa_id} iter_args {iter_args.argument_assignment_list_parens} -> {iter_args_types.type_list_parens} {body.region}',
'scf.for {index.ssa_id} = {begin.ssa_id} to {end.ssa_id} step {step.ssa_id} iter_args {iter_args.argument_assignment_list_parens} -> {iter_args_types.type_list_parens} : {out_type.type} {body.region}']


@dataclass
Expand Down
3 changes: 3 additions & 0 deletions mlir/lark/mlir.lark
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,9 @@ region_list : "(" region? ("," region)* ")"
// Arguments
named_argument : ssa_id ":" type optional_attr_dict
argument_list : (named_argument ("," named_argument)*) | (type optional_attr_dict ("," type optional_attr_dict)*)
argument_assignment : ssa_id "=" ssa_id
argument_assignment_list_no_parens : argument_assignment ("," argument_assignment)*
argument_assignment_list_parens : ("(" ")") | ("(" argument_assignment_list_no_parens ")")

// Return values
function_result : type optional_attr_dict
Expand Down
1 change: 1 addition & 0 deletions mlir/parser_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,7 @@ def block_label(self, value):
symbol_use_list = list
operation_list = list
argument_list = list
argument_assignment_list_no_parens = list
definition_list = list
function_list = list
module_list = list
Expand Down
19 changes: 19 additions & 0 deletions tests/test_syntax.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,25 @@ def test_affine(parser: Optional[Parser] = None):
module = parser.parse(code)
print(module.pretty())

def test_scf_for(parser: Optional[Parser] = None):
code = """
module {
func.func @reduce(%buffer: memref<1024xf32>, %lb: index,
%ub: index, %step: index) -> (f32) {
%sum_0 = arith.constant 0.0 : f32
%sum = scf.for %iv = %lb to %ub step %step
iter_args(%sum_iter = %sum_0) -> (f32) {
%t = load %buffer[%iv] : memref<1024xf32>
%sum_next = arith.addf %sum_iter, %t : f32
scf.yield %sum_next : f32
}
return %sum : f32
}
}
"""
parser = parser or Parser()
module = parser.parse(code)
print(module.pretty())

def test_definitions(parser: Optional[Parser] = None):
code = '''
Expand Down

0 comments on commit 9d0d8f8

Please sign in to comment.