Skip to content

Commit

Permalink
Revert "dialects: (llvm) Add support for overflow flags" (#3290)
Browse files Browse the repository at this point in the history
  • Loading branch information
lfrenot authored Oct 11, 2024
1 parent 51bc3dc commit c33f7b3
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 154 deletions.
26 changes: 5 additions & 21 deletions tests/dialects/test_llvm.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,18 @@
@pytest.mark.parametrize(
"op_type, attributes",
[
(llvm.AddOp, {}),
(llvm.AddOp, {"attr1": UnitAttr()}),
(llvm.SubOp, {}),
(llvm.MulOp, {}),
(llvm.UDivOp, {}),
(llvm.SDivOp, {}),
(llvm.URemOp, {}),
(llvm.SRemOp, {}),
(llvm.AndOp, {}),
(llvm.OrOp, {}),
(llvm.XOrOp, {}),
(llvm.ShlOp, {}),
(llvm.LShrOp, {}),
(llvm.AShrOp, {}),
],
Expand All @@ -36,27 +41,6 @@ def test_llvm_arithmetic_ops(
)


@pytest.mark.parametrize(
"op_type, attributes, overflow",
[
(llvm.AddOp, {}, llvm.OverflowAttr(None)),
(llvm.AddOp, {"attr1": UnitAttr()}, llvm.OverflowAttr(None)),
(llvm.SubOp, {}, llvm.OverflowAttr(None)),
(llvm.MulOp, {}, llvm.OverflowAttr(None)),
(llvm.ShlOp, {}, llvm.OverflowAttr(None)),
],
)
def test_llvm_overflow_arithmetic_ops(
op_type: type[llvm.ArithmeticBinOpOverflow],
attributes: dict[str, Attribute],
overflow: llvm.OverflowAttr,
):
op1, op2 = test.TestOp(result_types=[i32, i32]).results
assert op_type(op1, op2, attributes).is_structurally_equivalent(
op_type(lhs=op1, rhs=op2, attributes=attributes, overflow=overflow)
)


def test_llvm_pointer_ops():
module = builtin.ModuleOp(
[
Expand Down
81 changes: 31 additions & 50 deletions tests/filecheck/dialects/llvm/arithmetic.mlir
Original file line number Diff line number Diff line change
@@ -1,66 +1,47 @@
// RUN: XDSL_ROUNDTRIP

%arg0, %arg1 = "test.op"() : () -> (i32, i32)
builtin.module {
%arg0, %arg1 = "test.op"() : () -> (i32, i32)

%add_nsw = llvm.add %arg0, %arg1 overflow<nsw> : i32
// CHECK: %add_nsw = llvm.add %arg0, %arg1 overflow<nsw> : i32
%add = llvm.add %arg0, %arg1 : i32
// CHECK: %add = llvm.add %arg0, %arg1 : i32

%add_nuw = llvm.add %arg0, %arg1 {"overflowFlags" = #llvm.overflow<nuw>} : i32
// CHECK: %add_nuw = llvm.add %arg0, %arg1 overflow<nuw> : i32
%add2 = llvm.add %arg0, %arg1 {"nsw"} : i32
// CHECK: %add2 = llvm.add %arg0, %arg1 {"nsw"} : i32

%add_none = llvm.add %arg0, %arg1 {"overflowFlags" = #llvm.overflow<none>} : i32
// CHECK: %add_none = llvm.add %arg0, %arg1 : i32
%sub = llvm.sub %arg0, %arg1 : i32
// CHECK: %sub = llvm.sub %arg0, %arg1 : i32

%add_both = llvm.add %arg0, %arg1 {"overflowFlags" = #llvm.overflow<nsw, nuw>} : i32
// CHECK: %add_both = llvm.add %arg0, %arg1 overflow<nsw,nuw> : i32
%mul = llvm.mul %arg0, %arg1 : i32
// CHECK: %mul = llvm.mul %arg0, %arg1 : i32

%add_both_reverse = llvm.add %arg0, %arg1 {"overflowFlags" = #llvm.overflow<nuw, nsw>} : i32
// CHECK: %add_both_reverse = llvm.add %arg0, %arg1 overflow<nsw,nuw> : i32
%udiv = llvm.udiv %arg0, %arg1 : i32
// CHECK: %udiv = llvm.udiv %arg0, %arg1 : i32

%add_both_pretty = llvm.add %arg0, %arg1 overflow<nsw, nuw> : i32
// CHECK: %add_both_pretty = llvm.add %arg0, %arg1 overflow<nsw,nuw> : i32
%sdiv = llvm.sdiv %arg0, %arg1 : i32
// CHECK: %sdiv = llvm.sdiv %arg0, %arg1 : i32

%sub = llvm.sub %arg0, %arg1 : i32
// CHECK: %sub = llvm.sub %arg0, %arg1 : i32
%urem = llvm.urem %arg0, %arg1 : i32
// CHECK: %urem = llvm.urem %arg0, %arg1 : i32

%sub_overflow = llvm.sub %arg0, %arg1 overflow<nsw> : i32
// CHECK: %sub_overflow = llvm.sub %arg0, %arg1 overflow<nsw> : i32
%srem = llvm.srem %arg0, %arg1 : i32
// CHECK: %srem = llvm.srem %arg0, %arg1 : i32

%mul = llvm.mul %arg0, %arg1 : i32
// CHECK: %mul = llvm.mul %arg0, %arg1 : i32
%and = llvm.and %arg0, %arg1 : i32
// CHECK: %and = llvm.and %arg0, %arg1 : i32

%mul_overflow = llvm.mul %arg0, %arg1 overflow<nsw> : i32
// CHECK: %mul_overflow = llvm.mul %arg0, %arg1 overflow<nsw> : i32
%or = llvm.or %arg0, %arg1 : i32
// CHECK: %or = llvm.or %arg0, %arg1 : i32

%udiv = llvm.udiv %arg0, %arg1 : i32
// CHECK: %udiv = llvm.udiv %arg0, %arg1 : i32
%xor = llvm.xor %arg0, %arg1 : i32
// CHECK: %xor = llvm.xor %arg0, %arg1 : i32

%sdiv = llvm.sdiv %arg0, %arg1 : i32
// CHECK: %sdiv = llvm.sdiv %arg0, %arg1 : i32
%shl = llvm.shl %arg0, %arg1 : i32
// CHECK: %shl = llvm.shl %arg0, %arg1 : i32

%urem = llvm.urem %arg0, %arg1 : i32
// CHECK: %urem = llvm.urem %arg0, %arg1 : i32
%lshr = llvm.lshr %arg0, %arg1 : i32
// CHECK: %lshr = llvm.lshr %arg0, %arg1 : i32

%srem = llvm.srem %arg0, %arg1 : i32
// CHECK: %srem = llvm.srem %arg0, %arg1 : i32

%and = llvm.and %arg0, %arg1 : i32
// CHECK: %and = llvm.and %arg0, %arg1 : i32

%or = llvm.or %arg0, %arg1 : i32
// CHECK: %or = llvm.or %arg0, %arg1 : i32

%xor = llvm.xor %arg0, %arg1 : i32
// CHECK: %xor = llvm.xor %arg0, %arg1 : i32

%shl = llvm.shl %arg0, %arg1 : i32
// CHECK: %shl = llvm.shl %arg0, %arg1 : i32

%shl_overflow = llvm.shl %arg0, %arg1 overflow<nsw> : i32
// CHECK: %shl_overflow = llvm.shl %arg0, %arg1 overflow<nsw> : i32

%lshr = llvm.lshr %arg0, %arg1 : i32
// CHECK: %lshr = llvm.lshr %arg0, %arg1 : i32

%ashr = llvm.ashr %arg0, %arg1 : i32
// CHECK: %ashr = llvm.ashr %arg0, %arg1 : i32
%ashr = llvm.ashr %arg0, %arg1 : i32
// CHECK: %ashr = llvm.ashr %arg0, %arg1 : i32
}
87 changes: 4 additions & 83 deletions xdsl/dialects/llvm.py
Original file line number Diff line number Diff line change
Expand Up @@ -394,96 +394,18 @@ def print(self, printer: Printer) -> None:
printer.print(self.lhs.type)


class OverflowFlag(StrEnum):
NO_SIGNED_WRAP = "nsw"
NO_UNSIGNED_WRAP = "nuw"


@dataclass(frozen=True, init=False)
class OverflowAttrBase(BitEnumAttribute[OverflowFlag]):
none_value = "none"


@irdl_attr_definition
class OverflowAttr(OverflowAttrBase):
name = "llvm.overflow"


class ArithmeticBinOpOverflow(IRDLOperation, ABC):
"""Class for arithmetic binary operations that use overflow flags."""

T: ClassVar[VarConstraint[IntegerType]] = VarConstraint("T", BaseAttr(IntegerType))

lhs = operand_def(T)
rhs = operand_def(T)
res = result_def(T)
overflowFlags = opt_prop_def(OverflowAttr)

traits = frozenset([NoMemoryEffect()])

def __init__(
self,
lhs: SSAValue,
rhs: SSAValue,
attributes: dict[str, Attribute] = {},
overflow: OverflowAttr = OverflowAttr(None),
):
super().__init__(
operands=[lhs, rhs],
attributes=attributes,
result_types=[lhs.type],
properties={
"overflowFlags": overflow,
},
)

@classmethod
def parse_overflow(cls, parser: Parser) -> OverflowAttr:
if parser.parse_optional_keyword("overflow") is not None:
return OverflowAttr(OverflowAttr.parse_parameter(parser))
return OverflowAttr("none")

def print_overflow(self, printer: Printer) -> None:
if self.overflowFlags and self.overflowFlags.flags:
printer.print(" overflow")
self.overflowFlags.print_parameter(printer)

@classmethod
def parse(cls, parser: Parser):
lhs = parser.parse_unresolved_operand()
parser.parse_characters(",")
rhs = parser.parse_unresolved_operand()
overflowFlags = cls.parse_overflow(parser)
attributes = parser.parse_optional_attr_dict()
if "overflowFlags" in attributes:
flags = attributes.pop("overflowFlags")
if isinstance(flags, OverflowAttr):
overflowFlags = flags
parser.parse_characters(":")
type = parser.parse_type()
operands = parser.resolve_operands([lhs, rhs], [type, type], parser.pos)
return cls(operands[0], operands[1], attributes, overflowFlags)

def print(self, printer: Printer) -> None:
printer.print(" ", self.lhs, ", ", self.rhs)
self.print_overflow(printer)
printer.print_op_attributes(self.attributes)
printer.print(" : ")
printer.print(self.lhs.type)


@irdl_op_definition
class AddOp(ArithmeticBinOpOverflow):
class AddOp(ArithmeticBinOperation):
name = "llvm.add"


@irdl_op_definition
class SubOp(ArithmeticBinOpOverflow):
class SubOp(ArithmeticBinOperation):
name = "llvm.sub"


@irdl_op_definition
class MulOp(ArithmeticBinOpOverflow):
class MulOp(ArithmeticBinOperation):
name = "llvm.mul"


Expand Down Expand Up @@ -523,7 +445,7 @@ class XOrOp(ArithmeticBinOperation):


@irdl_op_definition
class ShlOp(ArithmeticBinOpOverflow):
class ShlOp(ArithmeticBinOperation):
name = "llvm.shl"


Expand Down Expand Up @@ -1400,6 +1322,5 @@ class ZeroOp(IRDLOperation):
LinkageAttr,
CallingConventionAttr,
FastMathAttr,
OverflowAttr,
],
)

0 comments on commit c33f7b3

Please sign in to comment.