Skip to content

Commit

Permalink
core: Fix and test empty variadic directives in declarative assembly …
Browse files Browse the repository at this point in the history
…format. (#2069)
  • Loading branch information
PapyChacal authored Feb 1, 2024
1 parent 9864f8d commit 9a4b28c
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 6 deletions.
10 changes: 10 additions & 0 deletions tests/test_declarative_assembly_format.py
Original file line number Diff line number Diff line change
Expand Up @@ -527,6 +527,11 @@ class TwoOperandsOp(IRDLOperation):
@pytest.mark.parametrize(
"format, program, generic_program",
[
(
"$args type($args) attr-dict",
'%0 = "test.op"() : () -> i32\n' "test.variadic_operand ",
'%0 = "test.op"() : () -> i32\n' '"test.variadic_operand"() : () -> ()',
),
(
"$args type($args) attr-dict",
'%0 = "test.op"() : () -> i32\n' "test.variadic_operand %0 i32",
Expand Down Expand Up @@ -718,6 +723,11 @@ class TwoResultOp(IRDLOperation):
@pytest.mark.parametrize(
"format, program, generic_program",
[
(
"`:` type($res) attr-dict",
"test.variadic_result : ",
'"test.variadic_result"() : () -> ()',
),
(
"`:` type($res) attr-dict",
"%0 = test.variadic_result : i32",
Expand Down
18 changes: 12 additions & 6 deletions xdsl/irdl/declarative_assembly_format.py
Original file line number Diff line number Diff line change
Expand Up @@ -364,9 +364,11 @@ class VariadicOperandVariable(OperandVariable):
"""

def parse(self, parser: Parser, state: ParsingState) -> None:
operands = parser.parse_comma_separated_list(
parser.Delimiter.NONE, parser.parse_unresolved_operand
operands = parser.parse_optional_undelimited_comma_separated_list(
parser.parse_optional_unresolved_operand, parser.parse_unresolved_operand
)
if operands is None:
operands = []
state.operands[self.index] = cast(list[UnresolvedOperand | None], operands)

def print(self, printer: Printer, state: PrintingState, op: IRDLOperation) -> None:
Expand Down Expand Up @@ -411,9 +413,11 @@ class VariadicOperandTypeDirective(OperandTypeDirective):
"""

def parse(self, parser: Parser, state: ParsingState) -> None:
operand_types = parser.parse_comma_separated_list(
parser.Delimiter.NONE, parser.parse_type
operand_types = parser.parse_optional_undelimited_comma_separated_list(
parser.parse_optional_type, parser.parse_type
)
if operand_types is None:
operand_types = []
state.operand_types[self.index] = cast(list[Attribute | None], operand_types)

def print(self, printer: Printer, state: PrintingState, op: IRDLOperation) -> None:
Expand Down Expand Up @@ -540,9 +544,11 @@ class VariadicResultTypeDirective(ResultTypeDirective):
"""

def parse(self, parser: Parser, state: ParsingState) -> None:
result_types = parser.parse_comma_separated_list(
parser.Delimiter.NONE, parser.parse_type
result_types = parser.parse_optional_undelimited_comma_separated_list(
parser.parse_optional_type, parser.parse_type
)
if result_types is None:
result_types = []
state.result_types[self.index] = cast(list[Attribute | None], result_types)

def print(self, printer: Printer, state: PrintingState, op: IRDLOperation) -> None:
Expand Down

0 comments on commit 9a4b28c

Please sign in to comment.