Skip to content

Commit

Permalink
Merge pull request #1 from xdslproject/areid-int-ops
Browse files Browse the repository at this point in the history
Extend ASL support: strings, comparisons, etc.
  • Loading branch information
alastairreid authored Jan 15, 2025
2 parents 53b30ee + 2804869 commit 011134f
Show file tree
Hide file tree
Showing 7 changed files with 250 additions and 20 deletions.
70 changes: 64 additions & 6 deletions asl_xdsl/dialects/asl.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,14 @@ def parse_parameter(cls, parser: AttrParser) -> bool:

def print_parameter(self, printer: Printer) -> None:
"""Print the attribute parameter."""
printer.print("true" if self.data else "false")
printer.print("<true>" if self.data else "<false>")


@irdl_attr_definition
class StringType(ParametrizedAttribute, TypeAttribute):
"""A string type."""

name = "asl.string"


@irdl_attr_definition
Expand Down Expand Up @@ -323,7 +330,7 @@ def __init__(
@classmethod
def parse(cls, parser: Parser) -> ConstantIntOp:
"""Parse the operation."""
value = parser.parse_integer(allow_boolean=False, allow_negative=False)
value = parser.parse_integer(allow_boolean=False, allow_negative=True)
attr_dict = parser.parse_optional_attr_dict()
return ConstantIntOp(value, attr_dict)

Expand Down Expand Up @@ -377,6 +384,29 @@ def print(self, printer: Printer) -> None:
printer.print_attr_dict(self.attributes)


@irdl_op_definition
class ConstantStringOp(IRDLOperation):
"""A constant string operation."""

name = "asl.constant_string"

value = prop_def(builtin.StringAttr)
res = result_def(StringType)

assembly_format = "$value attr-dict"

def __init__(
self, value: str | builtin.StringAttr, attr_dict: Mapping[str, Attribute] = {}
):
if isinstance(value, str):
value = builtin.StringAttr(value)
super().__init__(
result_types=[StringType()],
properties={"value": value},
attributes=attr_dict,
)


@irdl_op_definition
class NotOp(IRDLOperation):
"""A bitwise NOT operation."""
Expand All @@ -396,6 +426,25 @@ def __init__(self, arg: SSAValue, attr_dict: Mapping[str, Attribute] = {}):
)


@irdl_op_definition
class BoolToI1Op(IRDLOperation):
"""A hack to convert !asl.bool to i1 so that we can use scf.if."""

name = "asl.bool_to_i1"

arg = operand_def(BoolType())
res = result_def(builtin.IntegerType(1))

assembly_format = "$arg `:` type($arg) `->` type($res) attr-dict"

def __init__(self, arg: SSAValue, attr_dict: Mapping[str, Attribute] = {}):
super().__init__(
operands=[arg],
result_types=[builtin.IntegerType(1)],
attributes=attr_dict,
)


class BinaryBoolOp(IRDLOperation):
"""A binary boolean operation."""

Expand Down Expand Up @@ -464,7 +513,7 @@ class EquivBoolOp(BinaryBoolOp):
class NegateIntOp(IRDLOperation):
"""An integer negation operation."""

name = "asl.negate_int"
name = "asl.neg_int"

arg = operand_def(IntegerType)
res = result_def(IntegerType)
Expand Down Expand Up @@ -535,14 +584,14 @@ class ExpIntOp(BinaryIntOp):
class ShiftLeftIntOp(BinaryIntOp):
"""An integer left shift operation."""

name = "asl.shiftleft_int"
name = "asl.shl_int"


@irdl_op_definition
class ShiftRightIntOp(BinaryIntOp):
"""An integer right shift operation."""

name = "asl.shiftright_int"
name = "asl.shr_int"


@irdl_op_definition
Expand Down Expand Up @@ -1052,7 +1101,9 @@ def __init__(
ConstantBoolOp,
ConstantIntOp,
ConstantBitVectorOp,
ConstantStringOp,
# Boolean operations
BoolToI1Op,
NotOp,
AndBoolOp,
OrBoolOp,
Expand Down Expand Up @@ -1095,5 +1146,12 @@ def __init__(
# Slices
SliceSingleOp,
],
[BoolType, BoolAttr, IntegerType, BitVectorType, BitVectorAttr],
[
BoolType,
BoolAttr,
IntegerType,
BitVectorType,
BitVectorAttr,
StringType,
],
)
148 changes: 145 additions & 3 deletions asl_xdsl/interpreters/asl.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,14 @@ def run_call(
) -> tuple[Any, ...]:
return interpreter.call_op(op.callee.string_value(), args)

@impl(asl.NegateIntOp)
def run_neg_int(
self, interpreter: Interpreter, op: asl.NegateIntOp, args: tuple[Any, ...]
) -> tuple[Any, ...]:
arg: int
[arg] = args
return (0 - arg,)

@impl(asl.AddIntOp)
def run_add_int(
self, interpreter: Interpreter, op: asl.AddIntOp, args: tuple[Any, ...]
Expand All @@ -49,16 +57,132 @@ def run_add_int(
(lhs, rhs) = args
return (lhs + rhs,)

@impl(asl.SubIntOp)
def run_sub_int(
self, interpreter: Interpreter, op: asl.SubIntOp, args: tuple[Any, ...]
) -> tuple[Any, ...]:
lhs: int
rhs: int
(lhs, rhs) = args
return (lhs - rhs,)

@impl(asl.MulIntOp)
def run_mul_int(
self, interpreter: Interpreter, op: asl.MulIntOp, args: tuple[Any, ...]
) -> tuple[Any, ...]:
lhs: int
rhs: int
(lhs, rhs) = args
return (lhs * rhs,)

@impl(asl.ShiftLeftIntOp)
def run_shl_int(
self, interpreter: Interpreter, op: asl.ShiftLeftIntOp, args: tuple[Any, ...]
) -> tuple[Any, ...]:
lhs: int
rhs: int
(lhs, rhs) = args
assert rhs >= 0
return (lhs << rhs,)

@impl(asl.ShiftRightIntOp)
def run_shr_int(
self, interpreter: Interpreter, op: asl.ShiftRightIntOp, args: tuple[Any, ...]
) -> tuple[Any, ...]:
lhs: int
rhs: int
(lhs, rhs) = args
assert rhs >= 0
return (lhs >> rhs,)

@impl(asl.EqIntOp)
def run_eq_int(
self, interpreter: Interpreter, op: asl.EqIntOp, args: tuple[Any, ...]
) -> tuple[Any, ...]:
lhs: int
rhs: int
(lhs, rhs) = args
return (lhs == rhs,)

@impl(asl.NeIntOp)
def run_ne_int(
self, interpreter: Interpreter, op: asl.NeIntOp, args: tuple[Any, ...]
) -> tuple[Any, ...]:
lhs: int
rhs: int
(lhs, rhs) = args
return (lhs != rhs,)

@impl(asl.LeIntOp)
def run_le_int(
self, interpreter: Interpreter, op: asl.LeIntOp, args: tuple[Any, ...]
) -> tuple[Any, ...]:
lhs: int
rhs: int
(lhs, rhs) = args
return (lhs <= rhs,)

@impl(asl.LtIntOp)
def run_lt_int(
self, interpreter: Interpreter, op: asl.LtIntOp, args: tuple[Any, ...]
) -> tuple[Any, ...]:
lhs: int
rhs: int
(lhs, rhs) = args
return (lhs < rhs,)

@impl(asl.GeIntOp)
def run_ge_int(
self, interpreter: Interpreter, op: asl.GeIntOp, args: tuple[Any, ...]
) -> tuple[Any, ...]:
lhs: int
rhs: int
(lhs, rhs) = args
return (lhs >= rhs,)

@impl(asl.GtIntOp)
def run_gt_int(
self, interpreter: Interpreter, op: asl.GtIntOp, args: tuple[Any, ...]
) -> tuple[Any, ...]:
lhs: int
rhs: int
(lhs, rhs) = args
return (lhs > rhs,)

@impl(asl.ConstantIntOp)
def run_constant(
def run_constant_int(
self, interpreter: Interpreter, op: asl.ConstantIntOp, args: PythonValues
) -> PythonValues:
value = op.value
return (value.data,)

@impl(asl.ConstantStringOp)
def run_constant_string(
self, interpreter: Interpreter, op: asl.ConstantStringOp, args: PythonValues
) -> PythonValues:
value = op.value
return (value.data,)

@impl(asl.BoolToI1Op)
def run_bool_to_i1(
self, interpreter: Interpreter, op: asl.BoolToI1Op, args: PythonValues
) -> PythonValues:
arg: int
[arg] = args
return (arg,)

# region built-in function implementations

@impl_external("asl_print_int_dec")
@impl_external("print_bits_hex.0")
def asl_print_bits_hex(
self, interpreter: Interpreter, op: Operation, args: PythonValues
) -> PythonValues:
arg: int
(arg,) = args
interpreter.print(hex(arg))
return ()

@impl_external("print_int_dec.0")
def asl_print_int_dec(
self, interpreter: Interpreter, op: Operation, args: PythonValues
) -> PythonValues:
Expand All @@ -67,7 +191,16 @@ def asl_print_int_dec(
interpreter.print(arg)
return ()

@impl_external("asl_print_char")
@impl_external("print_int_hex.0")
def asl_print_int_hex(
self, interpreter: Interpreter, op: Operation, args: PythonValues
) -> PythonValues:
arg: int
(arg,) = args
interpreter.print(hex(arg))
return ()

@impl_external("print_char.0")
def asl_print_char(
self, interpreter: Interpreter, op: Operation, args: PythonValues
) -> PythonValues:
Expand All @@ -76,4 +209,13 @@ def asl_print_char(
interpreter.print(chr(arg))
return ()

@impl_external("print_str.0")
def asl_print_string(
self, interpreter: Interpreter, op: Operation, args: PythonValues
) -> PythonValues:
arg: str
(arg,) = args
interpreter.print(arg)
return ()

# endregion
25 changes: 25 additions & 0 deletions tests/filecheck/dialects/asl/cf.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
// RUN: asl-opt %s | asl-opt %s | filecheck %s

builtin.module {
asl.func @print_str.0(%x : !asl.string) -> ()
%c = asl.constant_bool true {attr_dict}
%0 = asl.bool_to_i1 %c : !asl.bool -> i1
scf.if %0 {
%1 = asl.constant_string "TRUE" {attr_dict}
asl.call @print_str.0(%1) : (!asl.string) -> ()
} else {
%2 = asl.constant_string "FALSE" {attr_dict}
asl.call @print_str.0(%2) : (!asl.string) -> ()
}

// CHECK: %c = asl.constant_bool true {attr_dict}
// CHECK-NEXT: %0 = asl.bool_to_i1 %c : !asl.bool -> i1
// CHECK-NEXT: scf.if %0 {
// CHECK-NEXT: %1 = asl.constant_string "TRUE" {attr_dict}
// CHECK-NEXT: asl.call @print_str.0(%1) : (!asl.string) -> ()
// CHECK-NEXT: } else {
// CHECK-NEXT: %2 = asl.constant_string "FALSE" {attr_dict}
// CHECK-NEXT: asl.call @print_str.0(%2) : (!asl.string) -> ()
// CHECK-NEXT: }

}
3 changes: 3 additions & 0 deletions tests/filecheck/dialects/asl/constant_ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,14 @@ builtin.module {
%fourty_two = asl.constant_int 42 {attr_dict}

%fourty_two_bits = asl.constant_bits 42 : !asl.bits<32> {attr_dict}

%fourty_two_string = asl.constant_string "Forty Two" {attr_dict}
}

// CHECK: builtin.module {
// CHECK-NEXT: %true = asl.constant_bool true {attr_dict}
// CHECK-NEXT: %false = asl.constant_bool true {attr_dict}
// CHECK-NEXT: %fourty_two = asl.constant_int 42 {attr_dict}
// CHECK-NEXT: %fourty_two_bits = asl.constant_bits 42 : !asl.bits<32> {attr_dict}
// CHECK-NEXT: %fourty_two_string = asl.constant_string "Forty Two" {attr_dict}
// CHECK-NEXT: }
12 changes: 6 additions & 6 deletions tests/filecheck/dialects/asl/primitives.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -24,24 +24,24 @@ builtin.module {
%int1, %int2 = "test.op"() : () -> (!asl.int, !asl.int)
// CHECK-NEXT: %int1, %int2 = "test.op"() : () -> (!asl.int, !asl.int)

%negate_int = asl.negate_int %int1 : !asl.int -> !asl.int
%neg_int = asl.neg_int %int1 : !asl.int -> !asl.int
%add_int = asl.add_int %int1, %int2 : (!asl.int, !asl.int) -> !asl.int
%sub_int = asl.sub_int %int1, %int2 : (!asl.int, !asl.int) -> !asl.int
%mul_int = asl.mul_int %int1, %int2 : (!asl.int, !asl.int) -> !asl.int
%exp_int = asl.exp_int %int1, %int2 : (!asl.int, !asl.int) -> !asl.int
%shiftleft_int = asl.shiftleft_int %int1, %int2 : (!asl.int, !asl.int) -> !asl.int
%shiftright_int = asl.shiftright_int %int1, %int2 : (!asl.int, !asl.int) -> !asl.int
%shiftleft_int = asl.shl_int %int1, %int2 : (!asl.int, !asl.int) -> !asl.int
%shiftright_int = asl.shr_int %int1, %int2 : (!asl.int, !asl.int) -> !asl.int
%div_int = asl.div_int %int1, %int2 : (!asl.int, !asl.int) -> !asl.int
%fdiv_int = asl.fdiv_int %int1, %int2 : (!asl.int, !asl.int) -> !asl.int
%frem_int = asl.frem_int %int1, %int2 : (!asl.int, !asl.int) -> !asl.int

// CHECK-NEXT: %negate_int = asl.negate_int %int1
// CHECK-NEXT: %neg_int = asl.neg_int %int1
// CHECK-NEXT: %add_int = asl.add_int %int1, %int2
// CHECK-NEXT: %sub_int = asl.sub_int %int1, %int2
// CHECK-NEXT: %mul_int = asl.mul_int %int1, %int2
// CHECK-NEXT: %exp_int = asl.exp_int %int1, %int2
// CHECK-NEXT: %shiftleft_int = asl.shiftleft_int %int1, %int2
// CHECK-NEXT: %shiftright_int = asl.shiftright_int %int1, %int2
// CHECK-NEXT: %shiftleft_int = asl.shl_int %int1, %int2
// CHECK-NEXT: %shiftright_int = asl.shr_int %int1, %int2
// CHECK-NEXT: %div_int = asl.div_int %int1, %int2
// CHECK-NEXT: %fdiv_int = asl.fdiv_int %int1, %int2
// CHECK-NEXT: %frem_int = asl.frem_int %int1, %int2
Expand Down
Loading

0 comments on commit 011134f

Please sign in to comment.