Skip to content

Commit

Permalink
Bit operations: change types and implement
Browse files Browse the repository at this point in the history
This changes the way that types are printed to be consistent
with the integer operations:

    asl.foo %0, %1 : (!asl.bits<4>, !asl.bits<4>) -> !asl.bits<4>

And it adds a bunch more operations and implements them
  • Loading branch information
alastairreid committed Jan 15, 2025
1 parent 9b06be6 commit ba5eef3
Show file tree
Hide file tree
Showing 3 changed files with 384 additions and 22 deletions.
194 changes: 187 additions & 7 deletions asl_xdsl/dialects/asl.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,7 @@ def __init__(self, value: int, type: BitVectorType):
super().__init__([builtin.IntAttr(value), type])

def _verify(self) -> None:
if self.value.data < 0 or self.value.data >= self.maximum_value():
if self.value.data < 0 or self.value.data > self.maximum_value():
raise VerifyException(
f"Value {self.value.data} is out of range "
f"for width {self.type.width.data}"
Expand Down Expand Up @@ -623,7 +623,9 @@ class BinaryBitsOp(IRDLOperation, ABC):
rhs = operand_def(T)
res = result_def(T)

assembly_format = "$lhs `,` $rhs `:` type($res) attr-dict"
assembly_format = (
"$lhs `,` $rhs `:` `(` type($lhs) `,` type($rhs) `)` `->` type($res) attr-dict"
)

def __init__(
self,
Expand Down Expand Up @@ -652,6 +654,13 @@ class SubBitsOp(BinaryBitsOp):
name = "asl.sub_bits"


@irdl_op_definition
class MulBitsOp(BinaryBitsOp):
"""A bit vector multiplication operation."""

name = "asl.mul_bits"


@irdl_op_definition
class AndBitsOp(BinaryBitsOp):
"""A bit vector AND operation."""
Expand Down Expand Up @@ -685,7 +694,9 @@ class AddBitsIntOp(IRDLOperation):
rhs = operand_def(IntegerType())
res = result_def(T)

assembly_format = "$lhs `,` $rhs `:` type($res) attr-dict"
assembly_format = (
"$lhs `,` $rhs `:` `(` type($lhs) `,` type($rhs) `)` `->` type($res) attr-dict"
)

def __init__(
self,
Expand All @@ -712,7 +723,96 @@ class SubBitsIntOp(IRDLOperation):
rhs = operand_def(IntegerType())
res = result_def(T)

assembly_format = "$lhs `,` $rhs `:` type($res) attr-dict"
assembly_format = (
"$lhs `,` $rhs `:` `(` type($lhs) `,` type($rhs) `)` `->` type($res) attr-dict"
)

def __init__(
self,
lhs: SSAValue,
rhs: SSAValue,
attr_dict: Mapping[str, Attribute] = {},
):
super().__init__(
operands=[lhs, rhs],
result_types=[lhs.type],
attributes=attr_dict,
)


@irdl_op_definition
class LslBitsOp(IRDLOperation):
"""A bit vector logical left shift operation."""

name = "asl.lsl_bits"

T: ClassVar = VarConstraint("T", BaseAttr(BitVectorType))

lhs = operand_def(T)
rhs = operand_def(IntegerType())
res = result_def(T)

assembly_format = (
"$lhs `,` $rhs `:` `(` type($lhs) `,` type($rhs) `)` `->` type($res) attr-dict"
)

def __init__(
self,
lhs: SSAValue,
rhs: SSAValue,
attr_dict: Mapping[str, Attribute] = {},
):
super().__init__(
operands=[lhs, rhs],
result_types=[lhs.type],
attributes=attr_dict,
)


@irdl_op_definition
class LsrBitsOp(IRDLOperation):
"""A bit vector logical shift right operation."""

name = "asl.lsr_bits"

T: ClassVar = VarConstraint("T", BaseAttr(BitVectorType))

lhs = operand_def(T)
rhs = operand_def(IntegerType())
res = result_def(T)

assembly_format = (
"$lhs `,` $rhs `:` `(` type($lhs) `,` type($rhs) `)` `->` type($res) attr-dict"
)

def __init__(
self,
lhs: SSAValue,
rhs: SSAValue,
attr_dict: Mapping[str, Attribute] = {},
):
super().__init__(
operands=[lhs, rhs],
result_types=[lhs.type],
attributes=attr_dict,
)


@irdl_op_definition
class AsrBitsOp(IRDLOperation):
"""A bit vector arithmetic shift right operation."""

name = "asl.asr_bits"

T: ClassVar = VarConstraint("T", BaseAttr(BitVectorType))

lhs = operand_def(T)
rhs = operand_def(IntegerType())
res = result_def(T)

assembly_format = (
"$lhs `,` $rhs `:` `(` type($lhs) `,` type($rhs) `)` `->` type($res) attr-dict"
)

def __init__(
self,
Expand All @@ -738,7 +838,7 @@ class NotBitsOp(IRDLOperation):
arg = operand_def(T)
res = result_def(T)

assembly_format = "$arg `:` type($res) attr-dict"
assembly_format = "$arg `:` type($arg) `->` type($res) attr-dict"

def __init__(self, arg: SSAValue, attr_dict: Mapping[str, Attribute] = {}):
super().__init__(
Expand All @@ -748,6 +848,48 @@ def __init__(self, arg: SSAValue, attr_dict: Mapping[str, Attribute] = {}):
)


@irdl_op_definition
class CvtBitsSIntOp(IRDLOperation):
"""A conversion operation from bitvectors to signed integers."""

name = "asl.cvt_bits_sint"

T: ClassVar = VarConstraint("T", BaseAttr(BitVectorType))

arg = operand_def(T)
res = result_def(IntegerType)

assembly_format = "$arg `:` type($arg) `->` type($res) attr-dict"

def __init__(self, arg: SSAValue, attr_dict: Mapping[str, Attribute] = {}):
super().__init__(
operands=[arg],
result_types=[IntegerType()],
attributes=attr_dict,
)


@irdl_op_definition
class CvtBitsUIntOp(IRDLOperation):
"""A conversion operation from bitvectors to unsigned integers."""

name = "asl.cvt_bits_uint"

T: ClassVar = VarConstraint("T", BaseAttr(BitVectorType))

arg = operand_def(T)
res = result_def(IntegerType)

assembly_format = "$arg `:` type($arg) `->` type($res) attr-dict"

def __init__(self, arg: SSAValue, attr_dict: Mapping[str, Attribute] = {}):
super().__init__(
operands=[arg],
result_types=[IntegerType()],
attributes=attr_dict,
)


@irdl_op_definition
class EqBitsOp(IRDLOperation):
"""A bit vector EQ operation."""
Expand All @@ -760,7 +902,9 @@ class EqBitsOp(IRDLOperation):
rhs = operand_def(T)
res = result_def(builtin.IntegerType(1))

assembly_format = "$lhs `,` $rhs `:` type($lhs) attr-dict"
assembly_format = (
"$lhs `,` $rhs `:` `(` type($lhs) `,` type($rhs) `)` `->` type($res) attr-dict"
)

def __init__(
self,
Expand All @@ -787,7 +931,9 @@ class NeBitsOp(IRDLOperation):
rhs = operand_def(T)
res = result_def(builtin.IntegerType(1))

assembly_format = "$lhs `,` $rhs `:` type($lhs) attr-dict"
assembly_format = (
"$lhs `,` $rhs `:` `(` type($lhs) `,` type($rhs) `)` `->` type($res) attr-dict"
)

def __init__(
self,
Expand All @@ -802,6 +948,33 @@ def __init__(
)


@irdl_op_definition
class PrintBitsHexOp(IRDLOperation):
"""A bit vector print function."""

# Eventually, this should be an external function
# This is just a workaround until we can cope with
# bitwidth polymorphism.

name = "asl.print_bits_hex"

T: ClassVar = VarConstraint("T", BaseAttr(BitVectorType))
arg = operand_def(T)

assembly_format = "$arg `:` type($arg) `->` `(` `)` attr-dict"

def __init__(
self,
arg: SSAValue,
attr_dict: Mapping[str, Attribute] = {},
):
super().__init__(
operands=[arg],
result_types=[],
attributes=attr_dict,
)


class FuncOpCallableInterface(CallableOpInterface):
@classmethod
def get_callable_region(cls, op: Operation) -> Region:
Expand Down Expand Up @@ -1051,14 +1224,21 @@ def __init__(
# Bits operations
AddBitsOp,
SubBitsOp,
MulBitsOp,
AndBitsOp,
OrBitsOp,
XorBitsOp,
LslBitsOp,
LsrBitsOp,
AsrBitsOp,
AddBitsIntOp,
SubBitsIntOp,
NotBitsOp,
CvtBitsSIntOp,
CvtBitsUIntOp,
EqBitsOp,
NeBitsOp,
PrintBitsHexOp,
# Functions
ReturnOp,
FuncOp,
Expand Down
Loading

0 comments on commit ba5eef3

Please sign in to comment.