From 1dc15d4c91b5e987feb893350b9d0d55638d054c Mon Sep 17 00:00:00 2001 From: Rihi <19492038+rihi@users.noreply.github.com> Date: Wed, 25 Oct 2023 17:22:48 +0200 Subject: [PATCH] [Expression simplifcation] Improve error handling (#352) * Fix incorrect error type in term_order.py * Improve error handling of constant folding * Catch exception caused by simplification rules if in debug * Catch empty operands operation in collapse_constants.py * Update documentation of constant_folding.py * Create MalformedInput exception * Handle malformed input in collapse_constants.py * Handle errors in collapse_nested_constants.py * Add clarifying comment in collapse_nested_constants.py * Catch exception in expression simplification earlier * Fix documentation in constant_folding.py * Change how exceptions are propagated * Update outdated comment * Remove unnecessary check in collapse_constants.py * fix broken merge --------- Co-authored-by: Manuel Blatt --- .../constant_folding.py | 105 ++++++--- .../rules/collapse_constants.py | 19 +- .../rules/collapse_nested_constants.py | 29 ++- .../expression_simplification/rules/rule.py | 8 + .../rules/term_order.py | 2 +- .../expression_simplification/stages.py | 26 ++- .../test_constant_folding.py | 209 ++++++++++-------- .../expression_simplification/test_stage.py | 4 +- 8 files changed, 257 insertions(+), 145 deletions(-) diff --git a/decompiler/pipeline/controlflowanalysis/expression_simplification/constant_folding.py b/decompiler/pipeline/controlflowanalysis/expression_simplification/constant_folding.py index 07827f126..2706987b0 100644 --- a/decompiler/pipeline/controlflowanalysis/expression_simplification/constant_folding.py +++ b/decompiler/pipeline/controlflowanalysis/expression_simplification/constant_folding.py @@ -2,46 +2,91 @@ from functools import partial from typing import Callable, Optional -from decompiler.structures.pseudo import Constant, Integer, OperationType +from decompiler.structures.pseudo import Constant, Integer, OperationType, Type from decompiler.util.integer_util import normalize_int +# The first three exception types indicate that an operation is not suitable for constant folding. +# They do NOT indicate that the input was malformed in any way. +# The idea is that the caller of constant_fold does not need to verify that folding is possible. -def constant_fold(operation: OperationType, constants: list[Constant]) -> Constant: + +class UnsupportedOperationType(Exception): + """Indicates that the specified Operation is not supported""" + + pass + + +class UnsupportedValueType(Exception): + + """Indicates that the value type of one constant is not supported.""" + + pass + + +class UnsupportedMismatchedSizes(Exception): + """Indicates that folding of different sized constants is not supported for the specified operation.""" + + pass + + +class IncompatibleOperandCount(Exception): + """Indicates that the specified operation type is not defined for the number of constants specified""" + + pass + + +def constant_fold(operation: OperationType, constants: list[Constant], result_type: Type) -> Constant: """ Fold operation with constants as operands. :param operation: The operation. :param constants: All constant operands of the operation. + Count of operands must be compatible with the specified operation type. + :param result_type: What type the folded constant should have. :return: A constant representing the result of the operation. + :raises: + UnsupportedOperationType: Thrown if the specified operation is not supported. + UnsupportedValueType: Thrown if constants contain value of types not supported. Currently only ints are supported. + UnsupportedMismatchedSizes: Thrown if constants types have different sizes and folding of different sized + constants is not supported for the specified operation. + IncompatibleOperandCount: Thrown if the specified operation type is not defined for the number of constants in `constants`. """ if operation not in _OPERATION_TO_FOLD_FUNCTION: - raise ValueError(f"Constant folding not implemented for operation '{operation}'.") + raise UnsupportedOperationType(f"Constant folding not implemented for operation '{operation}'.") + + if not all(isinstance(v, int) for v in [c.value for c in constants]): # For now we only support integer value folding + raise UnsupportedValueType(f"Constant folding is not implemented for non int constant values: {[c.value for c in constants]}") - return _OPERATION_TO_FOLD_FUNCTION[operation](constants) + return Constant( + normalize_int( + _OPERATION_TO_FOLD_FUNCTION[operation](constants), result_type.size, isinstance(result_type, Integer) and result_type.signed + ), + result_type, + ) -def _constant_fold_arithmetic_binary( - constants: list[Constant], fun: Callable[[int, int], int], norm_sign: Optional[bool] = None -) -> Constant: +def _constant_fold_arithmetic_binary(constants: list[Constant], fun: Callable[[int, int], int], norm_sign: Optional[bool] = None) -> int: """ Fold an arithmetic binary operation with constants as operands. - :param constants: A list of exactly 2 constant operands. + :param constants: A list of exactly 2 constant values. :param fun: The binary function to perform on the constants. :param norm_sign: Optional boolean flag to indicate if/how to normalize the input constants to 'fun': - None (default): no normalization - - True: normalize inputs, interpreted as signed values + - True: normalize inputs, interpreted as signed values - False: normalize inputs, interpreted as unsigned values - :return: A constant representing the result of the operation. + :return: The result of the operation. + :raises: + UnsupportedMismatchedSizes: Thrown if constants types have different sizes and folding of different sized + constants is not supported for the specified operation. + IncompatibleOperandCount: Thrown if the number of constants is not equal to 2. """ if len(constants) != 2: - raise ValueError(f"Expected exactly 2 constants to fold, got {len(constants)}.") - if not all(constant.type == constants[0].type for constant in constants): - raise ValueError(f"Can not fold constants with different types: {(constant.type for constant in constants)}") - if not all(isinstance(constant.type, Integer) for constant in constants): - raise ValueError(f"All constants must be integers, got {list(constant.type for constant in constants)}.") + raise IncompatibleOperandCount(f"Expected exactly 2 constants to fold, got {len(constants)}.") + if not all(constant.type.size == constants[0].type.size for constant in constants): + raise UnsupportedMismatchedSizes(f"Can not fold constants with different sizes: {[constant.type for constant in constants]}") left, right = constants @@ -51,27 +96,27 @@ def _constant_fold_arithmetic_binary( left_value = normalize_int(left_value, left.type.size, norm_sign) right_value = normalize_int(right_value, right.type.size, norm_sign) - return Constant(normalize_int(fun(left_value, right_value), left.type.size, left.type.signed), left.type) + return fun(left_value, right_value) -def _constant_fold_arithmetic_unary(constants: list[Constant], fun: Callable[[int], int]) -> Constant: +def _constant_fold_arithmetic_unary(constants: list[Constant], fun: Callable[[int], int]) -> int: """ Fold an arithmetic unary operation with a constant operand. :param constants: A list containing a single constant operand. :param fun: The unary function to perform on the constant. - :return: A constant representing the result of the operation. + :return: The result of the operation. + :raises: + IncompatibleOperandCount: Thrown if the number of constants is not equal to 1. """ if len(constants) != 1: - raise ValueError("Expected exactly 1 constant to fold") - if not isinstance(constants[0].type, Integer): - raise ValueError(f"Constant must be of type integer: {constants[0].type}") + raise IncompatibleOperandCount("Expected exactly 1 constant to fold") - return Constant(normalize_int(fun(constants[0].value), constants[0].type.size, constants[0].type.signed), constants[0].type) + return fun(constants[0].value) -def _constant_fold_shift(constants: list[Constant], fun: Callable[[int, int], int], signed: bool) -> Constant: +def _constant_fold_shift(constants: list[Constant], fun: Callable[[int, int], int], signed: bool) -> int: """ Fold a shift operation with constants as operands. @@ -79,21 +124,20 @@ def _constant_fold_shift(constants: list[Constant], fun: Callable[[int, int], in :param fun: The shift function to perform on the constants. :param signed: Boolean flag indicating whether the shift is signed. This is used to normalize the sign of the input constant to simulate unsigned shifts. - :return: A constant representing the result of the operation. + :return: The result of the operation. + :raises: + IncompatibleOperandCount: Thrown if the number of constants is not equal to 2. """ if len(constants) != 2: - raise ValueError("Expected exactly 2 constants to fold") - if not all(isinstance(constant.type, Integer) for constant in constants): - raise ValueError("All constants must be integers") + raise IncompatibleOperandCount("Expected exactly 2 constants to fold") left, right = constants - shifted_value = fun(normalize_int(left.value, left.type.size, left.type.signed and signed), right.value) - return Constant(normalize_int(shifted_value, left.type.size, left.type.signed), left.type) + return fun(normalize_int(left.value, left.type.size, left.type.signed and signed), right.value) -_OPERATION_TO_FOLD_FUNCTION: dict[OperationType, Callable[[list[Constant]], Constant]] = { +_OPERATION_TO_FOLD_FUNCTION: dict[OperationType, Callable[[list[Constant]], int]] = { OperationType.minus: partial(_constant_fold_arithmetic_binary, fun=operator.sub), OperationType.plus: partial(_constant_fold_arithmetic_binary, fun=operator.add), OperationType.multiply: partial(_constant_fold_arithmetic_binary, fun=operator.mul, norm_sign=True), @@ -110,5 +154,4 @@ def _constant_fold_shift(constants: list[Constant], fun: Callable[[int, int], in OperationType.bitwise_not: partial(_constant_fold_arithmetic_unary, fun=operator.inv), } - FOLDABLE_OPERATIONS = _OPERATION_TO_FOLD_FUNCTION.keys() diff --git a/decompiler/pipeline/controlflowanalysis/expression_simplification/rules/collapse_constants.py b/decompiler/pipeline/controlflowanalysis/expression_simplification/rules/collapse_constants.py index 295a0bb03..7e5734686 100644 --- a/decompiler/pipeline/controlflowanalysis/expression_simplification/rules/collapse_constants.py +++ b/decompiler/pipeline/controlflowanalysis/expression_simplification/rules/collapse_constants.py @@ -1,5 +1,11 @@ -from decompiler.pipeline.controlflowanalysis.expression_simplification.constant_folding import FOLDABLE_OPERATIONS, constant_fold -from decompiler.pipeline.controlflowanalysis.expression_simplification.rules.rule import SimplificationRule +from decompiler.pipeline.controlflowanalysis.expression_simplification.constant_folding import ( + IncompatibleOperandCount, + UnsupportedMismatchedSizes, + UnsupportedOperationType, + UnsupportedValueType, + constant_fold, +) +from decompiler.pipeline.controlflowanalysis.expression_simplification.rules.rule import MalformedData, SimplificationRule from decompiler.structures.pseudo import Constant, Expression, Operation @@ -11,7 +17,12 @@ class CollapseConstants(SimplificationRule): def apply(self, operation: Operation) -> list[tuple[Expression, Expression]]: if not all(isinstance(o, Constant) for o in operation.operands): return [] - if operation.operation not in FOLDABLE_OPERATIONS: + + try: + folded_constant = constant_fold(operation.operation, operation.operands, operation.type) + except (UnsupportedOperationType, UnsupportedValueType, UnsupportedMismatchedSizes): return [] + except IncompatibleOperandCount as e: + raise MalformedData() from e - return [(operation, constant_fold(operation.operation, operation.operands))] + return [(operation, folded_constant)] diff --git a/decompiler/pipeline/controlflowanalysis/expression_simplification/rules/collapse_nested_constants.py b/decompiler/pipeline/controlflowanalysis/expression_simplification/rules/collapse_nested_constants.py index cf36ce84b..62426e6dd 100644 --- a/decompiler/pipeline/controlflowanalysis/expression_simplification/rules/collapse_nested_constants.py +++ b/decompiler/pipeline/controlflowanalysis/expression_simplification/rules/collapse_nested_constants.py @@ -1,8 +1,13 @@ from functools import reduce from typing import Iterator -from decompiler.pipeline.controlflowanalysis.expression_simplification.constant_folding import FOLDABLE_OPERATIONS, constant_fold -from decompiler.pipeline.controlflowanalysis.expression_simplification.rules.rule import SimplificationRule +from decompiler.pipeline.controlflowanalysis.expression_simplification.constant_folding import ( + FOLDABLE_OPERATIONS, + IncompatibleOperandCount, + UnsupportedValueType, + constant_fold, +) +from decompiler.pipeline.controlflowanalysis.expression_simplification.rules.rule import MalformedData, SimplificationRule from decompiler.structures.pseudo import Constant, Expression, Operation, OperationType, Type from decompiler.structures.pseudo.operations import COMMUTATIVE_OPERATIONS @@ -19,8 +24,6 @@ class CollapseNestedConstants(SimplificationRule): def apply(self, operation: Operation) -> list[tuple[Expression, Expression]]: if operation.operation not in _COLLAPSIBLE_OPERATIONS: return [] - if not isinstance(operation, Operation): - raise TypeError(f"Expected Operation, got {type(operation)}") constants = list(_collect_constants(operation)) if len(constants) <= 1: @@ -28,7 +31,14 @@ def apply(self, operation: Operation) -> list[tuple[Expression, Expression]]: first, *rest = constants - folded_constant = reduce(lambda c0, c1: constant_fold(operation.operation, [c0, c1]), rest, first) + # We don't need to catch UnsupportedOperationType, because check that operation is in _COLLAPSIBLE_OPERATIONS + # We don't need to catch UnsupportedMismatchedSizes, because '_collect_constants' only returns constants of the same type + try: + folded_constant = reduce(lambda c0, c1: constant_fold(operation.operation, [c0, c1], operation.type), rest, first) + except UnsupportedValueType: + return [] + except IncompatibleOperandCount as e: + raise MalformedData() from e identity_constant = _identity_constant(operation.operation, operation.type) return [(first, folded_constant), *((constant, identity_constant) for constant in rest)] @@ -51,7 +61,7 @@ def _collect_constants(operation: Operation) -> Iterator[Constant]: current_operation = context_stack.pop() for i, operand in enumerate(current_operation.operands): - if operand.type != operand_type: + if operand.type != operand_type: # This check could potentially be relaxed to only check for equal size continue if isinstance(operand, Operation): @@ -72,6 +82,11 @@ def _identity_constant(operation: OperationType, var_type: Type) -> Constant: case OperationType.multiply | OperationType.multiply_us: return Constant(1, var_type) case OperationType.bitwise_and: - return constant_fold(OperationType.bitwise_not, [Constant(0, var_type)]) + # Should not throw any exception because: + # - OperationType.bitwise_not is foldable (UnsupportedOperationType) + # - constant has integer value, which is supported (UnsupportedValueType) + # - with only 1 constant there cant be mismatched sizes (UnsupportedMismatchedSizes) + # - bitwise_not has exactly one operand (IncompatibleOperandCount) + return constant_fold(OperationType.bitwise_not, [Constant(0, var_type)], var_type) case _: raise NotImplementedError() diff --git a/decompiler/pipeline/controlflowanalysis/expression_simplification/rules/rule.py b/decompiler/pipeline/controlflowanalysis/expression_simplification/rules/rule.py index a4f70334a..ef46c9a3d 100644 --- a/decompiler/pipeline/controlflowanalysis/expression_simplification/rules/rule.py +++ b/decompiler/pipeline/controlflowanalysis/expression_simplification/rules/rule.py @@ -16,5 +16,13 @@ def apply(self, operation: Operation) -> list[tuple[Expression, Expression]]: :param operation: The operation to which the simplification rule should be applied. :return: A list of tuples, each containing a pair of expressions representing the original and simplified versions resulting from applying the simplification rule to the given operation. + :raises: + MalformedData: Thrown inf malformed data, like a dereference operation with two operands, is encountered. """ pass + + +class MalformedData(Exception): + """Used to indicate that malformed data was encountered""" + + pass diff --git a/decompiler/pipeline/controlflowanalysis/expression_simplification/rules/term_order.py b/decompiler/pipeline/controlflowanalysis/expression_simplification/rules/term_order.py index 638f4d08a..5af7dd9fe 100644 --- a/decompiler/pipeline/controlflowanalysis/expression_simplification/rules/term_order.py +++ b/decompiler/pipeline/controlflowanalysis/expression_simplification/rules/term_order.py @@ -20,7 +20,7 @@ def apply(self, operation: Operation) -> list[tuple[Expression, Expression]]: if operation.operation not in COMMUTATIVE_OPERATIONS: return [] if not isinstance(operation, BinaryOperation): - raise ValueError(f"Expected BinaryOperation, got {operation}") + raise TypeError(f"Expected BinaryOperation, got {type(operation)}") if isinstance(operation.left, Constant) and not isinstance(operation.right, Constant): return [(operation, BinaryOperation(operation.operation, [operation.right, operation.left], operation.type, operation.tags))] diff --git a/decompiler/pipeline/controlflowanalysis/expression_simplification/stages.py b/decompiler/pipeline/controlflowanalysis/expression_simplification/stages.py index bdad4333f..708049020 100644 --- a/decompiler/pipeline/controlflowanalysis/expression_simplification/stages.py +++ b/decompiler/pipeline/controlflowanalysis/expression_simplification/stages.py @@ -28,18 +28,20 @@ class _ExpressionSimplificationBase(PipelineStage, ABC): def run(self, task: DecompilerTask): max_iterations = task.options.getint("expression-simplification.max_iterations") - self._simplify_instructions(self._get_instructions(task), max_iterations) + debug = task.options.getboolean("pipeline.debug", fallback=False) + + self._simplify_instructions(self._get_instructions(task), max_iterations, debug) @abstractmethod def _get_instructions(self, task: DecompilerTask) -> list[Instruction]: pass @classmethod - def _simplify_instructions(cls, instructions: list[Instruction], max_iterations: int): + def _simplify_instructions(cls, instructions: list[Instruction], max_iterations: int, debug: bool): rule_sets = [("pre-rules", _pre_rules), ("rules", _rules), ("post-rules", _post_rules)] for rule_name, rule_set in rule_sets: # max_iterations is counted per rule_set - iteration_count = cls._simplify_instructions_with_rule_set(instructions, rule_set, max_iterations) + iteration_count = cls._simplify_instructions_with_rule_set(instructions, rule_set, max_iterations, debug) if iteration_count <= max_iterations: logging.info(f"Expression simplification took {iteration_count} iterations for {rule_name}") else: @@ -47,7 +49,7 @@ def _simplify_instructions(cls, instructions: list[Instruction], max_iterations: @classmethod def _simplify_instructions_with_rule_set( - cls, instructions: list[Instruction], rule_set: list[SimplificationRule], max_iterations: int + cls, instructions: list[Instruction], rule_set: list[SimplificationRule], max_iterations: int, debug: bool ) -> int: iteration_count = 0 @@ -57,7 +59,7 @@ def _simplify_instructions_with_rule_set( for rule in rule_set: for instruction in instructions: - additional_iterations = cls._simplify_instruction_with_rule(instruction, rule, max_iterations - iteration_count) + additional_iterations = cls._simplify_instruction_with_rule(instruction, rule, max_iterations - iteration_count, debug) if additional_iterations > 0: changes = True @@ -68,7 +70,7 @@ def _simplify_instructions_with_rule_set( return iteration_count @classmethod - def _simplify_instruction_with_rule(cls, instruction: Instruction, rule: SimplificationRule, max_iterations: int) -> int: + def _simplify_instruction_with_rule(cls, instruction: Instruction, rule: SimplificationRule, max_iterations: int, debug: bool) -> int: iteration_count = 0 for expression in instruction.subexpressions(): while True: @@ -78,9 +80,17 @@ def _simplify_instruction_with_rule(cls, instruction: Instruction, rule: Simplif if not isinstance(expression, Operation): break - substitutions = rule.apply(expression) + try: + substitutions = rule.apply(expression) + except Exception as e: + if debug: + raise # re-raise the exception + else: + logging.exception(f"An unexpected error occurred while simplifying: {e}") + break # continue with next subexpression + if not substitutions: - break + break # continue with next subexpression iteration_count += 1 diff --git a/tests/pipeline/controlflowanalysis/expression_simplification/test_constant_folding.py b/tests/pipeline/controlflowanalysis/expression_simplification/test_constant_folding.py index 40efe858b..44ff99ea7 100644 --- a/tests/pipeline/controlflowanalysis/expression_simplification/test_constant_folding.py +++ b/tests/pipeline/controlflowanalysis/expression_simplification/test_constant_folding.py @@ -1,8 +1,16 @@ from contextlib import nullcontext +from typing import Optional import pytest -from decompiler.pipeline.controlflowanalysis.expression_simplification.constant_folding import FOLDABLE_OPERATIONS, constant_fold -from decompiler.structures.pseudo import Constant, Float, Integer, OperationType +from decompiler.pipeline.controlflowanalysis.expression_simplification.constant_folding import ( + FOLDABLE_OPERATIONS, + IncompatibleOperandCount, + UnsupportedMismatchedSizes, + UnsupportedOperationType, + UnsupportedValueType, + constant_fold, +) +from decompiler.structures.pseudo import Constant, Float, Integer, OperationType, Type def _c_i32(value: int) -> Constant: @@ -23,100 +31,117 @@ def _c_float(value: float) -> Constant: @pytest.mark.parametrize(["operation"], [(operation,) for operation in OperationType if operation not in FOLDABLE_OPERATIONS]) def test_constant_fold_invalid_operations(operation: OperationType): - with pytest.raises(ValueError): - constant_fold(operation, []) + with pytest.raises(UnsupportedOperationType): + constant_fold(operation, [], Integer.int32_t()) + + +@pytest.mark.parametrize( + ["operation", "constants", "result_type", "expected_result", "context"], + [ + (OperationType.plus, [_c_i32(0), _c_i32(0)], Integer.int32_t(), _c_i32(0), nullcontext()), + (OperationType.plus, [_c_float(0.0), _c_float(0.0)], Float.float(), _c_float(0.0), pytest.raises(UnsupportedValueType)), + (OperationType.plus, [_c_i32(0), _c_float(0.0)], Integer.int32_t(), _c_i32(0), pytest.raises(UnsupportedValueType)), + ], +) +def test_constant_fold_invalid_value_type( + operation: OperationType, constants: list[Constant], result_type: Type, expected_result: Optional[Constant], context +): + with context: + assert constant_fold(operation, constants, result_type) == expected_result @pytest.mark.parametrize( - ["operation", "constants", "result", "context"], + ["operation", "constants", "result_type", "expected_result", "context"], [ - (OperationType.plus, [_c_i32(3), _c_i32(4)], _c_i32(7), nullcontext()), - (OperationType.plus, [_c_i32(2147483647), _c_i32(1)], _c_i32(-2147483648), nullcontext()), - (OperationType.plus, [_c_u32(2147483658), _c_u32(2147483652)], _c_u32(14), nullcontext()), - (OperationType.plus, [_c_u32(3), _c_i32(4)], None, pytest.raises(ValueError)), - (OperationType.plus, [_c_i32(3), _c_i16(4)], None, pytest.raises(ValueError)), - (OperationType.plus, [_c_i32(3)], None, pytest.raises(ValueError)), - (OperationType.plus, [_c_i32(3), _c_i32(3), _c_i32(3)], None, pytest.raises(ValueError)), - (OperationType.minus, [_c_i32(3), _c_i32(4)], _c_i32(-1), nullcontext()), - (OperationType.minus, [_c_i32(-2147483648), _c_i32(1)], _c_i32(2147483647), nullcontext()), - (OperationType.minus, [_c_u32(3), _c_u32(4)], _c_u32(4294967295), nullcontext()), - (OperationType.minus, [_c_u32(3), _c_i32(4)], None, pytest.raises(ValueError)), - (OperationType.minus, [_c_i32(3), _c_i16(4)], None, pytest.raises(ValueError)), - (OperationType.minus, [_c_i32(3)], None, pytest.raises(ValueError)), - (OperationType.minus, [_c_i32(3), _c_i32(3), _c_i32(3)], None, pytest.raises(ValueError)), - (OperationType.multiply, [_c_i32(3), _c_i32(4)], _c_i32(12), nullcontext()), - (OperationType.multiply, [_c_i32(-1073741824), _c_i32(2)], _c_i32(-2147483648), nullcontext()), - (OperationType.multiply, [_c_u32(3221225472), _c_u32(2)], _c_u32(2147483648), nullcontext()), - (OperationType.multiply, [_c_u32(3), _c_i32(4)], None, pytest.raises(ValueError)), - (OperationType.multiply, [_c_i32(3), _c_i16(4)], None, pytest.raises(ValueError)), - (OperationType.multiply, [_c_i32(3)], None, pytest.raises(ValueError)), - (OperationType.multiply, [_c_i32(3), _c_i32(3), _c_i32(3)], None, pytest.raises(ValueError)), - (OperationType.multiply_us, [_c_i32(3), _c_i32(4)], _c_i32(12), nullcontext()), - (OperationType.multiply_us, [_c_i32(-1073741824), _c_i32(2)], _c_i32(-2147483648), nullcontext()), - (OperationType.multiply_us, [_c_u32(3221225472), _c_u32(2)], _c_u32(2147483648), nullcontext()), - (OperationType.multiply_us, [_c_u32(3), _c_i32(4)], None, pytest.raises(ValueError)), - (OperationType.multiply_us, [_c_i32(3), _c_i16(4)], None, pytest.raises(ValueError)), - (OperationType.multiply_us, [_c_i32(3)], None, pytest.raises(ValueError)), - (OperationType.multiply_us, [_c_i32(3), _c_i32(3), _c_i32(3)], None, pytest.raises(ValueError)), - (OperationType.divide, [_c_i32(12), _c_i32(4)], _c_i32(3), nullcontext()), - (OperationType.divide, [_c_i32(-2147483648), _c_i32(2)], _c_i32(-1073741824), nullcontext()), - (OperationType.divide, [_c_u32(3), _c_i32(4)], None, pytest.raises(ValueError)), - (OperationType.divide, [_c_i32(3), _c_i16(4)], None, pytest.raises(ValueError)), - (OperationType.divide, [_c_i32(3)], None, pytest.raises(ValueError)), - (OperationType.divide, [_c_i32(3), _c_i32(3), _c_i32(3)], None, pytest.raises(ValueError)), - (OperationType.divide_us, [_c_i32(12), _c_i32(4)], _c_i32(3), nullcontext()), - (OperationType.divide_us, [_c_i32(-2147483648), _c_i32(2)], _c_i32(1073741824), nullcontext()), - (OperationType.divide_us, [_c_u32(3), _c_i32(4)], None, pytest.raises(ValueError)), - (OperationType.divide_us, [_c_i32(3), _c_i16(4)], None, pytest.raises(ValueError)), - (OperationType.divide_us, [_c_i32(3)], None, pytest.raises(ValueError)), - (OperationType.divide_us, [_c_i32(3), _c_i32(3), _c_i32(3)], None, pytest.raises(ValueError)), - (OperationType.negate, [_c_i32(3)], _c_i32(-3), nullcontext()), - (OperationType.negate, [_c_i32(-2147483648)], _c_i32(-2147483648), nullcontext()), - (OperationType.negate, [], None, pytest.raises(ValueError)), - (OperationType.negate, [_c_i32(3), _c_i32(3)], None, pytest.raises(ValueError)), - (OperationType.left_shift, [_c_i32(3), _c_i32(4)], _c_i32(48), nullcontext()), - (OperationType.left_shift, [_c_i32(1073741824), _c_i32(1)], _c_i32(-2147483648), nullcontext()), - (OperationType.left_shift, [_c_u32(1073741824), _c_u32(1)], _c_u32(2147483648), nullcontext()), - (OperationType.left_shift, [_c_i32(3)], None, pytest.raises(ValueError)), - (OperationType.left_shift, [_c_i32(3), _c_i32(3), _c_i32(3)], None, pytest.raises(ValueError)), - (OperationType.right_shift, [_c_i32(32), _c_i32(4)], _c_i32(2), nullcontext()), - (OperationType.right_shift, [_c_i32(-2147483648), _c_i32(1)], _c_i32(-1073741824), nullcontext()), - (OperationType.right_shift, [_c_u32(2147483648), _c_u32(1)], _c_u32(1073741824), nullcontext()), - (OperationType.right_shift, [_c_i32(3)], None, pytest.raises(ValueError)), - (OperationType.right_shift, [_c_i32(3), _c_i32(3), _c_i32(3)], None, pytest.raises(ValueError)), - (OperationType.right_shift_us, [_c_i32(32), _c_i32(4)], _c_i32(2), nullcontext()), - (OperationType.right_shift_us, [_c_i32(-2147483648), _c_i32(1)], _c_i32(1073741824), nullcontext()), - (OperationType.right_shift_us, [_c_u32(2147483648), _c_u32(1)], _c_u32(1073741824), nullcontext()), - (OperationType.right_shift_us, [_c_i32(3)], None, pytest.raises(ValueError)), - (OperationType.right_shift_us, [_c_i32(3), _c_i32(3), _c_i32(3)], None, pytest.raises(ValueError)), - (OperationType.bitwise_or, [_c_i32(85), _c_i32(34)], _c_i32(119), nullcontext()), - (OperationType.bitwise_or, [_c_i32(-2147483648), _c_i32(1)], _c_i32(-2147483647), nullcontext()), - (OperationType.bitwise_or, [_c_u32(2147483648), _c_u32(1)], _c_u32(2147483649), nullcontext()), - (OperationType.bitwise_or, [_c_u32(3), _c_i32(4)], None, pytest.raises(ValueError)), - (OperationType.bitwise_or, [_c_i32(3), _c_i16(4)], None, pytest.raises(ValueError)), - (OperationType.bitwise_or, [_c_i32(3)], None, pytest.raises(ValueError)), - (OperationType.bitwise_or, [_c_i32(3), _c_i32(3), _c_i32(3)], None, pytest.raises(ValueError)), - (OperationType.bitwise_and, [_c_i32(85), _c_i32(51)], _c_i32(17), nullcontext()), - (OperationType.bitwise_and, [_c_i32(-2147483647), _c_i32(3)], _c_i32(1), nullcontext()), - (OperationType.bitwise_and, [_c_u32(2147483649), _c_u32(3)], _c_u32(1), nullcontext()), - (OperationType.bitwise_and, [_c_u32(3), _c_i32(4)], None, pytest.raises(ValueError)), - (OperationType.bitwise_and, [_c_i32(3), _c_i16(4)], None, pytest.raises(ValueError)), - (OperationType.bitwise_and, [_c_i32(3)], None, pytest.raises(ValueError)), - (OperationType.bitwise_and, [_c_i32(3), _c_i32(3), _c_i32(3)], None, pytest.raises(ValueError)), - (OperationType.bitwise_xor, [_c_i32(85), _c_i32(51)], _c_i32(102), nullcontext()), - (OperationType.bitwise_xor, [_c_i32(-2147483647), _c_i32(-2147483646)], _c_i32(3), nullcontext()), - (OperationType.bitwise_xor, [_c_u32(2147483649), _c_u32(2147483650)], _c_u32(3), nullcontext()), - (OperationType.bitwise_xor, [_c_u32(3), _c_i32(4)], None, pytest.raises(ValueError)), - (OperationType.bitwise_xor, [_c_i32(3), _c_i16(4)], None, pytest.raises(ValueError)), - (OperationType.bitwise_xor, [_c_i32(3)], None, pytest.raises(ValueError)), - (OperationType.bitwise_xor, [_c_i32(3), _c_i32(3), _c_i32(3)], None, pytest.raises(ValueError)), - (OperationType.bitwise_not, [_c_i32(6)], _c_i32(-7), nullcontext()), - (OperationType.bitwise_not, [_c_i32(-2147483648)], _c_i32(2147483647), nullcontext()), - (OperationType.bitwise_not, [_c_u32(2147483648)], _c_u32(2147483647), nullcontext()), - (OperationType.bitwise_not, [], None, pytest.raises(ValueError)), - (OperationType.bitwise_not, [_c_i32(3), _c_i32(3)], None, pytest.raises(ValueError)), + (OperationType.plus, [_c_i32(3), _c_i32(4)], Integer.int32_t(), _c_i32(7), nullcontext()), + (OperationType.plus, [_c_i32(2147483647), _c_i32(1)], Integer.int32_t(), _c_i32(-2147483648), nullcontext()), + (OperationType.plus, [_c_u32(2147483658), _c_u32(2147483652)], Integer.uint32_t(), _c_u32(14), nullcontext()), + (OperationType.plus, [_c_u32(3), _c_i32(4)], Integer.int32_t(), _c_i32(7), nullcontext()), + (OperationType.plus, [_c_i32(3), _c_i16(4)], Integer.int32_t(), None, pytest.raises(UnsupportedMismatchedSizes)), + (OperationType.plus, [_c_i32(3)], Integer.int32_t(), None, pytest.raises(IncompatibleOperandCount)), + (OperationType.plus, [_c_i32(3), _c_i32(3), _c_i32(3)], Integer.int32_t(), None, pytest.raises(IncompatibleOperandCount)), + (OperationType.minus, [_c_i32(3), _c_i32(4)], Integer.int32_t(), _c_i32(-1), nullcontext()), + (OperationType.minus, [_c_i32(-2147483648), _c_i32(1)], Integer.int32_t(), _c_i32(2147483647), nullcontext()), + (OperationType.minus, [_c_u32(3), _c_u32(4)], Integer.uint32_t(), _c_u32(4294967295), nullcontext()), + (OperationType.minus, [_c_u32(3), _c_i32(4)], Integer.int32_t(), _c_i32(-1), nullcontext()), + (OperationType.minus, [_c_i32(3), _c_i16(4)], Integer.int32_t(), None, pytest.raises(UnsupportedMismatchedSizes)), + (OperationType.minus, [_c_i32(3)], Integer.int32_t(), None, pytest.raises(IncompatibleOperandCount)), + (OperationType.minus, [_c_i32(3), _c_i32(3), _c_i32(3)], Integer.int32_t(), None, pytest.raises(IncompatibleOperandCount)), + (OperationType.multiply, [_c_i32(3), _c_i32(4)], Integer.int32_t(), _c_i32(12), nullcontext()), + (OperationType.multiply, [_c_i32(-1073741824), _c_i32(2)], Integer.int32_t(), _c_i32(-2147483648), nullcontext()), + (OperationType.multiply, [_c_u32(3221225472), _c_u32(2)], Integer.uint32_t(), _c_u32(2147483648), nullcontext()), + (OperationType.multiply, [_c_u32(3), _c_i32(4)], Integer.int32_t(), _c_i32(12), nullcontext()), + (OperationType.multiply, [_c_i32(3), _c_i16(4)], Integer.int32_t(), None, pytest.raises(UnsupportedMismatchedSizes)), + (OperationType.multiply, [_c_i32(3)], Integer.int32_t(), None, pytest.raises(IncompatibleOperandCount)), + (OperationType.multiply, [_c_i32(3), _c_i32(3), _c_i32(3)], Integer.int32_t(), None, pytest.raises(IncompatibleOperandCount)), + (OperationType.multiply_us, [_c_i32(3), _c_i32(4)], Integer.int32_t(), _c_i32(12), nullcontext()), + (OperationType.multiply_us, [_c_i32(-1073741824), _c_i32(2)], Integer.int32_t(), _c_i32(-2147483648), nullcontext()), + (OperationType.multiply_us, [_c_u32(3221225472), _c_u32(2)], Integer.uint32_t(), _c_u32(2147483648), nullcontext()), + (OperationType.multiply_us, [_c_u32(3), _c_i32(4)], Integer.int32_t(), _c_i32(12), nullcontext()), + (OperationType.multiply_us, [_c_i32(3), _c_i16(4)], Integer.int32_t(), None, pytest.raises(UnsupportedMismatchedSizes)), + (OperationType.multiply_us, [_c_i32(3)], Integer.int32_t(), None, pytest.raises(IncompatibleOperandCount)), + (OperationType.multiply_us, [_c_i32(3), _c_i32(3), _c_i32(3)], Integer.int32_t(), None, pytest.raises(IncompatibleOperandCount)), + (OperationType.divide, [_c_i32(12), _c_i32(4)], Integer.int32_t(), _c_i32(3), nullcontext()), + (OperationType.divide, [_c_i32(-2147483648), _c_i32(2)], Integer.int32_t(), _c_i32(-1073741824), nullcontext()), + (OperationType.divide, [_c_u32(3), _c_i32(4)], Integer.int32_t(), _c_i32(0), nullcontext()), + (OperationType.divide, [_c_i32(3), _c_i16(4)], Integer.int32_t(), None, pytest.raises(UnsupportedMismatchedSizes)), + (OperationType.divide, [_c_i32(3)], Integer.int32_t(), None, pytest.raises(IncompatibleOperandCount)), + (OperationType.divide, [_c_i32(3), _c_i32(3), _c_i32(3)], Integer.int32_t(), None, pytest.raises(IncompatibleOperandCount)), + (OperationType.divide_us, [_c_i32(12), _c_i32(4)], Integer.int32_t(), _c_i32(3), nullcontext()), + (OperationType.divide_us, [_c_i32(-2147483648), _c_i32(2)], Integer.int32_t(), _c_i32(1073741824), nullcontext()), + (OperationType.divide_us, [_c_u32(3), _c_i32(4)], Integer.int32_t(), _c_i32(0), nullcontext()), + (OperationType.divide_us, [_c_i32(3), _c_i16(4)], Integer.int32_t(), None, pytest.raises(UnsupportedMismatchedSizes)), + (OperationType.divide_us, [_c_i32(3)], Integer.int32_t(), None, pytest.raises(IncompatibleOperandCount)), + (OperationType.divide_us, [_c_i32(3), _c_i32(3), _c_i32(3)], Integer.int32_t(), None, pytest.raises(IncompatibleOperandCount)), + (OperationType.negate, [_c_i32(3)], Integer.int32_t(), _c_i32(-3), nullcontext()), + (OperationType.negate, [_c_i32(-2147483648)], Integer.int32_t(), _c_i32(-2147483648), nullcontext()), + (OperationType.negate, [], Integer.int32_t(), None, pytest.raises(IncompatibleOperandCount)), + (OperationType.negate, [_c_i32(3), _c_i32(3)], Integer.int32_t(), None, pytest.raises(IncompatibleOperandCount)), + (OperationType.left_shift, [_c_i32(3), _c_i32(4)], Integer.int32_t(), _c_i32(48), nullcontext()), + (OperationType.left_shift, [_c_i32(1073741824), _c_i32(1)], Integer.int32_t(), _c_i32(-2147483648), nullcontext()), + (OperationType.left_shift, [_c_u32(1073741824), _c_u32(1)], Integer.uint32_t(), _c_u32(2147483648), nullcontext()), + (OperationType.left_shift, [_c_i32(3)], Integer.int32_t(), None, pytest.raises(IncompatibleOperandCount)), + (OperationType.left_shift, [_c_i32(3), _c_i32(3), _c_i32(3)], Integer.int32_t(), None, pytest.raises(IncompatibleOperandCount)), + (OperationType.right_shift, [_c_i32(32), _c_i32(4)], Integer.int32_t(), _c_i32(2), nullcontext()), + (OperationType.right_shift, [_c_i32(-2147483648), _c_i32(1)], Integer.int32_t(), _c_i32(-1073741824), nullcontext()), + (OperationType.right_shift, [_c_u32(2147483648), _c_u32(1)], Integer.uint32_t(), _c_u32(1073741824), nullcontext()), + (OperationType.right_shift, [_c_i32(3)], Integer.int32_t(), None, pytest.raises(IncompatibleOperandCount)), + (OperationType.right_shift, [_c_i32(3), _c_i32(3), _c_i32(3)], Integer.int32_t(), None, pytest.raises(IncompatibleOperandCount)), + (OperationType.right_shift_us, [_c_i32(32), _c_i32(4)], Integer.int32_t(), _c_i32(2), nullcontext()), + (OperationType.right_shift_us, [_c_i32(-2147483648), _c_i32(1)], Integer.int32_t(), _c_i32(1073741824), nullcontext()), + (OperationType.right_shift_us, [_c_u32(2147483648), _c_u32(1)], Integer.uint32_t(), _c_u32(1073741824), nullcontext()), + (OperationType.right_shift_us, [_c_i32(3)], Integer.int32_t(), None, pytest.raises(IncompatibleOperandCount)), + (OperationType.right_shift_us, [_c_i32(3), _c_i32(3), _c_i32(3)], Integer.int32_t(), None, pytest.raises(IncompatibleOperandCount)), + (OperationType.bitwise_or, [_c_i32(85), _c_i32(34)], Integer.int32_t(), _c_i32(119), nullcontext()), + (OperationType.bitwise_or, [_c_i32(-2147483648), _c_i32(1)], Integer.int32_t(), _c_i32(-2147483647), nullcontext()), + (OperationType.bitwise_or, [_c_u32(2147483648), _c_u32(1)], Integer.uint32_t(), _c_u32(2147483649), nullcontext()), + (OperationType.bitwise_or, [_c_u32(3), _c_i32(4)], Integer.int32_t(), _c_i32(7), nullcontext()), + (OperationType.bitwise_or, [_c_i32(3), _c_i16(4)], Integer.int32_t(), None, pytest.raises(UnsupportedMismatchedSizes)), + (OperationType.bitwise_or, [_c_i32(3)], Integer.int32_t(), None, pytest.raises(IncompatibleOperandCount)), + (OperationType.bitwise_or, [_c_i32(3), _c_i32(3), _c_i32(3)], Integer.int32_t(), None, pytest.raises(IncompatibleOperandCount)), + (OperationType.bitwise_and, [_c_i32(85), _c_i32(51)], Integer.int32_t(), _c_i32(17), nullcontext()), + (OperationType.bitwise_and, [_c_i32(-2147483647), _c_i32(3)], Integer.int32_t(), _c_i32(1), nullcontext()), + (OperationType.bitwise_and, [_c_u32(2147483649), _c_u32(3)], Integer.uint32_t(), _c_u32(1), nullcontext()), + (OperationType.bitwise_and, [_c_u32(3), _c_i32(4)], Integer.int32_t(), _c_i32(0), nullcontext()), + (OperationType.bitwise_and, [_c_i32(3), _c_i16(4)], Integer.int32_t(), None, pytest.raises(UnsupportedMismatchedSizes)), + (OperationType.bitwise_and, [_c_i32(3)], Integer.int32_t(), None, pytest.raises(IncompatibleOperandCount)), + (OperationType.bitwise_and, [_c_i32(3), _c_i32(3), _c_i32(3)], Integer.int32_t(), None, pytest.raises(IncompatibleOperandCount)), + (OperationType.bitwise_xor, [_c_i32(85), _c_i32(51)], Integer.int32_t(), _c_i32(102), nullcontext()), + (OperationType.bitwise_xor, [_c_i32(-2147483647), _c_i32(-2147483646)], Integer.int32_t(), _c_i32(3), nullcontext()), + (OperationType.bitwise_xor, [_c_u32(2147483649), _c_u32(2147483650)], Integer.uint32_t(), _c_u32(3), nullcontext()), + (OperationType.bitwise_xor, [_c_u32(3), _c_i32(4)], Integer.int32_t(), _c_i32(7), nullcontext()), + (OperationType.bitwise_xor, [_c_i32(3), _c_i16(4)], Integer.int32_t(), None, pytest.raises(UnsupportedMismatchedSizes)), + (OperationType.bitwise_xor, [_c_i32(3)], Integer.int32_t(), None, pytest.raises(IncompatibleOperandCount)), + (OperationType.bitwise_xor, [_c_i32(3), _c_i32(3), _c_i32(3)], Integer.int32_t(), None, pytest.raises(IncompatibleOperandCount)), + (OperationType.bitwise_not, [_c_i32(6)], Integer.int32_t(), _c_i32(-7), nullcontext()), + (OperationType.bitwise_not, [_c_i32(-2147483648)], Integer.int32_t(), _c_i32(2147483647), nullcontext()), + (OperationType.bitwise_not, [_c_u32(2147483648)], Integer.uint32_t(), _c_u32(2147483647), nullcontext()), + (OperationType.bitwise_not, [], Integer.int32_t(), None, pytest.raises(IncompatibleOperandCount)), + (OperationType.bitwise_not, [_c_i32(3), _c_i32(3)], Integer.int32_t(), None, pytest.raises(IncompatibleOperandCount)), ], ) -def test_constant_fold(operation: OperationType, constants: list[Constant], result: Constant, context): +def test_constant_fold( + operation: OperationType, constants: list[Constant], result_type: Type, expected_result: Optional[Constant], context +): with context: - assert constant_fold(operation, constants) == result + assert constant_fold(operation, constants, result_type) == expected_result diff --git a/tests/pipeline/controlflowanalysis/expression_simplification/test_stage.py b/tests/pipeline/controlflowanalysis/expression_simplification/test_stage.py index 449f1428a..d2b811abe 100644 --- a/tests/pipeline/controlflowanalysis/expression_simplification/test_stage.py +++ b/tests/pipeline/controlflowanalysis/expression_simplification/test_stage.py @@ -58,7 +58,7 @@ def _v_i32(name: str) -> Variable: ], ) def test_simplify_instructions_with_rule_set(rule_set: list[SimplificationRule], instruction: Instruction, expected_result: Instruction): - _ExpressionSimplificationBase._simplify_instructions_with_rule_set([instruction], rule_set, 100) + _ExpressionSimplificationBase._simplify_instructions_with_rule_set([instruction], rule_set, 100, True) assert instruction == expected_result @@ -72,5 +72,5 @@ def test_simplify_instructions_with_rule_set(rule_set: list[SimplificationRule], def test_simplify_instructions_with_rule_set_max_iterations( rule_set: list[SimplificationRule], instruction: Instruction, max_iterations: int, expect_exceed_max_iterations: bool ): - iterations = _ExpressionSimplificationBase._simplify_instructions_with_rule_set([instruction], rule_set, max_iterations) + iterations = _ExpressionSimplificationBase._simplify_instructions_with_rule_set([instruction], rule_set, max_iterations, True) assert (iterations > max_iterations) == expect_exceed_max_iterations