From 9a4b28c6e31cc5e18dc36b632b0a00eff6fe15d7 Mon Sep 17 00:00:00 2001 From: Emilien Bauer Date: Thu, 1 Feb 2024 11:03:57 +0000 Subject: [PATCH] core: Fix and test empty variadic directives in declarative assembly format. (#2069) --- tests/test_declarative_assembly_format.py | 10 ++++++++++ xdsl/irdl/declarative_assembly_format.py | 18 ++++++++++++------ 2 files changed, 22 insertions(+), 6 deletions(-) diff --git a/tests/test_declarative_assembly_format.py b/tests/test_declarative_assembly_format.py index 3fa0f5ed60..8f28661425 100644 --- a/tests/test_declarative_assembly_format.py +++ b/tests/test_declarative_assembly_format.py @@ -527,6 +527,11 @@ class TwoOperandsOp(IRDLOperation): @pytest.mark.parametrize( "format, program, generic_program", [ + ( + "$args type($args) attr-dict", + '%0 = "test.op"() : () -> i32\n' "test.variadic_operand ", + '%0 = "test.op"() : () -> i32\n' '"test.variadic_operand"() : () -> ()', + ), ( "$args type($args) attr-dict", '%0 = "test.op"() : () -> i32\n' "test.variadic_operand %0 i32", @@ -718,6 +723,11 @@ class TwoResultOp(IRDLOperation): @pytest.mark.parametrize( "format, program, generic_program", [ + ( + "`:` type($res) attr-dict", + "test.variadic_result : ", + '"test.variadic_result"() : () -> ()', + ), ( "`:` type($res) attr-dict", "%0 = test.variadic_result : i32", diff --git a/xdsl/irdl/declarative_assembly_format.py b/xdsl/irdl/declarative_assembly_format.py index 2a150e1e5a..bdb646a83d 100644 --- a/xdsl/irdl/declarative_assembly_format.py +++ b/xdsl/irdl/declarative_assembly_format.py @@ -364,9 +364,11 @@ class VariadicOperandVariable(OperandVariable): """ def parse(self, parser: Parser, state: ParsingState) -> None: - operands = parser.parse_comma_separated_list( - parser.Delimiter.NONE, parser.parse_unresolved_operand + operands = parser.parse_optional_undelimited_comma_separated_list( + parser.parse_optional_unresolved_operand, parser.parse_unresolved_operand ) + if operands is None: + operands = [] state.operands[self.index] = cast(list[UnresolvedOperand | None], operands) def print(self, printer: Printer, state: PrintingState, op: IRDLOperation) -> None: @@ -411,9 +413,11 @@ class VariadicOperandTypeDirective(OperandTypeDirective): """ def parse(self, parser: Parser, state: ParsingState) -> None: - operand_types = parser.parse_comma_separated_list( - parser.Delimiter.NONE, parser.parse_type + operand_types = parser.parse_optional_undelimited_comma_separated_list( + parser.parse_optional_type, parser.parse_type ) + if operand_types is None: + operand_types = [] state.operand_types[self.index] = cast(list[Attribute | None], operand_types) def print(self, printer: Printer, state: PrintingState, op: IRDLOperation) -> None: @@ -540,9 +544,11 @@ class VariadicResultTypeDirective(ResultTypeDirective): """ def parse(self, parser: Parser, state: ParsingState) -> None: - result_types = parser.parse_comma_separated_list( - parser.Delimiter.NONE, parser.parse_type + result_types = parser.parse_optional_undelimited_comma_separated_list( + parser.parse_optional_type, parser.parse_type ) + if result_types is None: + result_types = [] state.result_types[self.index] = cast(list[Attribute | None], result_types) def print(self, printer: Printer, state: PrintingState, op: IRDLOperation) -> None: