diff --git a/decompiler/pipeline/controlflowanalysis/expression_simplification/modification.py b/decompiler/pipeline/controlflowanalysis/expression_simplification/modification.py new file mode 100644 index 000000000..99eca34d5 --- /dev/null +++ b/decompiler/pipeline/controlflowanalysis/expression_simplification/modification.py @@ -0,0 +1,117 @@ +from typing import Callable, Optional + +from decompiler.structures.pseudo import BinaryOperation, Constant, Expression, Integer, OperationType + + +def multiply_int_with_constant(expression: Expression, constant: Constant) -> Expression: + """ + Multiply an integer expression with an integer constant. + + :param expression: The integer expression to be multiplied. + :param constant: The constant value to multiply the expression by. + :return: A simplified expression representing the multiplication result. + """ + + if not isinstance(expression.type, Integer): + raise ValueError(f"Expression must have integer type, got {expression.type}.") + if not isinstance(constant.type, Integer): + raise ValueError(f"Constant must have integer type, got {constant.type}.") + if expression.type != constant.type: + raise ValueError(f"Expression and constant type must equal. {expression.type} != {constant.type}") + + if isinstance(expression, Constant): + return constant_fold(OperationType.multiply, [expression, constant]) + else: + return BinaryOperation(OperationType.multiply, [expression, constant]) + + +_FOLD_HANDLER: dict[OperationType, Callable[[list[Constant]], Constant]] = { + OperationType.minus: lambda constants: _constant_fold_arithmetic_binary(constants, lambda x, y: x - y), + OperationType.plus: lambda constants: _constant_fold_arithmetic_binary(constants, lambda x, y: x + y), + OperationType.multiply: lambda constants: _constant_fold_arithmetic_binary(constants, lambda x, y: x * y, True), + OperationType.multiply_us: lambda constants: _constant_fold_arithmetic_binary(constants, lambda x, y: x * y, False), + OperationType.divide: lambda constants: _constant_fold_arithmetic_binary(constants, lambda x, y: x // y, True), + OperationType.divide_us: lambda constants: _constant_fold_arithmetic_binary(constants, lambda x, y: x // y, False), + OperationType.negate: lambda constants: _constant_fold_arithmetic_unary(constants, lambda value: -value), + OperationType.left_shift: lambda constants: _constant_fold_shift(constants, lambda value, shift, size: value << shift), + OperationType.right_shift: lambda constants: _constant_fold_shift(constants, lambda value, shift, size: value >> shift), + OperationType.right_shift_us: lambda constants: _constant_fold_shift( + constants, lambda value, shift, size: normalize_int(value >> shift, size - shift, False) + ), + OperationType.bitwise_or: lambda constants: _constant_fold_arithmetic_binary(constants, lambda x, y: x | y), + OperationType.bitwise_and: lambda constants: _constant_fold_arithmetic_binary(constants, lambda x, y: x & y), + OperationType.bitwise_xor: lambda constants: _constant_fold_arithmetic_binary(constants, lambda x, y: x ^ y), + OperationType.bitwise_not: lambda constants: _constant_fold_arithmetic_unary(constants, lambda x: ~x), +} + + +FOLDABLE_CONSTANTS = _FOLD_HANDLER.keys() + + +def constant_fold(operation: OperationType, constants: list[Constant]) -> Constant: + """ + Fold operation with constants as operands. + + :param operation: The operation. + :param constants: All constant operands of the operation. + :return: A constant representing the result of the operation. + """ + + if operation not in _FOLD_HANDLER: + raise ValueError(f"Constant folding not implemented for operation '{operation}'.") + + return _FOLD_HANDLER[operation](constants) + + +def _constant_fold_arithmetic_binary( + constants: list[Constant], + fun: Callable[[int, int], int], + norm_sign: Optional[bool] = None +) -> Constant: + if len(constants) != 2: + raise ValueError(f"Expected exactly 2 constants to fold, got {len(constants)}.") + if not all(constant.type == constants[0].type for constant in constants): + raise ValueError(f"Can not fold constants with different types: {(constant.type for constant in constants)}") + if not all(isinstance(constant.type, Integer) for constant in constants): + raise ValueError(f"All constants must be integers, got {list(constant.type for constant in constants)}.") + + left, right = constants + + left_value = left.value + right_value = right.value + if norm_sign is not None: + left_value = normalize_int(left_value, left.type.size, norm_sign) + right_value = normalize_int(right_value, right.type.size, norm_sign) + + return Constant( + normalize_int(fun(left_value, right_value), left.type.size, left.type.signed), + left.type + ) + + +def _constant_fold_arithmetic_unary(constants: list[Constant], fun: Callable[[int], int]) -> Constant: + if len(constants) != 1: + raise ValueError("Expected exactly 1 constant to fold") + if not isinstance(constants[0].type, Integer): + raise ValueError(f"Constant must be of type integer: {constants[0].type}") + + return Constant(normalize_int(fun(constants[0].value), constants[0].type.size, constants[0].type.signed), constants[0].type) + + +def _constant_fold_shift(constants: list[Constant], fun: Callable[[int, int, int], int]) -> Constant: + if len(constants) != 2: + raise ValueError("Expected exactly 2 constants to fold") + if not all(isinstance(constant.type, Integer) for constant in constants): + raise ValueError("All constants must be integers") + + left, right = constants + + return Constant(normalize_int(fun(left.value, right.value, left.type.size), left.type.size, left.type.signed), left.type) + + +def normalize_int(v: int, size: int, signed: bool) -> int: + value = v & ((1 << size) - 1) + if signed and value & (1 << (size - 1)): + return value - (1 << size) + else: + return value diff --git a/decompiler/pipeline/controlflowanalysis/expression_simplification/rules/collapse_add_neg.py b/decompiler/pipeline/controlflowanalysis/expression_simplification/rules/collapse_add_neg.py new file mode 100644 index 000000000..c978991fc --- /dev/null +++ b/decompiler/pipeline/controlflowanalysis/expression_simplification/rules/collapse_add_neg.py @@ -0,0 +1,30 @@ +from decompiler.pipeline.controlflowanalysis.expression_simplification.rules.rule import SimplificationRule +from decompiler.structures.pseudo import BinaryOperation, Expression, Operation, OperationType, UnaryOperation + + +class CollapseAddNeg(SimplificationRule): + """ + Simplifies additions/subtraction with negated expression. + + - `e0 + -(e1) -> e0 - e1` + - `e0 - -(e1) -> e0 + e1` + """ + + def apply(self, operation: Operation) -> list[tuple[Expression, Expression]]: + if operation.operation not in [OperationType.plus, OperationType.minus]: + return [] + if not isinstance(operation, BinaryOperation): + raise TypeError(f"Expected BinaryOperation, got {type(operation)}") + + right = operation.right + if not isinstance(right, UnaryOperation) or right.operation != OperationType.negate: + return [] + + return [( + operation, + BinaryOperation( + OperationType.minus if operation.operation == OperationType.plus else OperationType.plus, + [operation.left, right.operand], + operation.type + ) + )] diff --git a/decompiler/pipeline/controlflowanalysis/expression_simplification/rules/collapse_constants.py b/decompiler/pipeline/controlflowanalysis/expression_simplification/rules/collapse_constants.py new file mode 100644 index 000000000..18709913f --- /dev/null +++ b/decompiler/pipeline/controlflowanalysis/expression_simplification/rules/collapse_constants.py @@ -0,0 +1,20 @@ +from decompiler.pipeline.controlflowanalysis.expression_simplification.modification import FOLDABLE_CONSTANTS, constant_fold +from decompiler.pipeline.controlflowanalysis.expression_simplification.rules.rule import SimplificationRule +from decompiler.structures.pseudo import Constant, Expression, Operation + + +class CollapseConstants(SimplificationRule): + """ + Fold operations with only constants as operands: + """ + + def apply(self, operation: Operation) -> list[tuple[Expression, Expression]]: + if not all(isinstance(o, Constant) for o in operation.operands): + return [] + if operation.operation not in FOLDABLE_CONSTANTS: + return [] + + return [( + operation, + constant_fold(operation.operation, operation.operands) + )] diff --git a/decompiler/pipeline/controlflowanalysis/expression_simplification/rules/collapse_mult_neg_one.py b/decompiler/pipeline/controlflowanalysis/expression_simplification/rules/collapse_mult_neg_one.py new file mode 100644 index 000000000..2b3cfc52f --- /dev/null +++ b/decompiler/pipeline/controlflowanalysis/expression_simplification/rules/collapse_mult_neg_one.py @@ -0,0 +1,29 @@ +from decompiler.pipeline.controlflowanalysis.expression_simplification.rules.rule import SimplificationRule +from decompiler.structures.pseudo import BinaryOperation, Constant, Expression, Operation, OperationType, UnaryOperation + + +class CollapseMultNegOne(SimplificationRule): + """ + Simplifies expressions multiplied with -1. + + `e0 * -1 -> -(e0)` + """ + + def apply(self, operation: Operation) -> list[tuple[Expression, Expression]]: + if operation.operation != OperationType.multiply: + return [] + if not isinstance(operation, BinaryOperation): + raise TypeError(f"Expected BinaryOperation, got {type(operation)}") + + right = operation.right + if not isinstance(right, Constant) or right.value != -1: + return [] + + return [( + operation, + UnaryOperation( + OperationType.negate, + [operation.left], + operation.type + ) + )] diff --git a/decompiler/pipeline/controlflowanalysis/expression_simplification/rules/collect_terms.py b/decompiler/pipeline/controlflowanalysis/expression_simplification/rules/collect_terms.py new file mode 100644 index 000000000..532ff93fb --- /dev/null +++ b/decompiler/pipeline/controlflowanalysis/expression_simplification/rules/collect_terms.py @@ -0,0 +1,65 @@ +from functools import reduce +from typing import Iterator + +from decompiler.pipeline.controlflowanalysis.expression_simplification.modification import constant_fold +from decompiler.pipeline.controlflowanalysis.expression_simplification.rules.rule import SimplificationRule +from decompiler.structures.pseudo import Constant, Expression, Operation, OperationType, Type +from decompiler.structures.pseudo.operations import COMMUTATIVE_OPERATIONS + + +class CollectTerms(SimplificationRule): + def apply(self, operation: Operation) -> list[tuple[Expression, Expression]]: + if operation.operation not in COMMUTATIVE_OPERATIONS: + return [] + if not isinstance(operation, Operation): + raise TypeError(f"Expected Operation, got {type(operation)}") + + operands = list(_collect_constants(operation)) + if len(operands) <= 1: + return [] + + first, *rest = operands + + folded_constant = reduce( + lambda c0, c1: constant_fold(operation.operation, [c0, c1]), + rest, + first + ) + + identity_constant = _identity_constant(operation.operation, operation.type) + return [ + (first, folded_constant), + *((o, identity_constant) for o in rest) + ] + + +def _collect_constants(operation: Operation) -> Iterator[Constant]: + operation_type = operation.operation + operand_type = operation.type + + context_stack: list[Operation] = [operation] + while context_stack: + current_operation = context_stack.pop() + + for i, operand in enumerate(current_operation.operands): + if operand.type != operand_type: + continue + + if isinstance(operand, Operation): + if operand.operation == operation_type: + context_stack.append(operand) + continue + elif isinstance(operand, Constant) and _identity_constant(operation_type, operand_type).value != operand.value: + yield operand + + +def _identity_constant(operation: OperationType, var_type: Type) -> Constant: + match operation: + case OperationType.plus | OperationType.bitwise_xor | OperationType.bitwise_or: + return Constant(0, var_type) + case OperationType.multiply | OperationType.multiply_us: + return Constant(1, var_type) + case OperationType.bitwise_and: + return constant_fold(OperationType.bitwise_not, [Constant(0, var_type)]) + case _: + raise NotImplementedError() diff --git a/decompiler/pipeline/controlflowanalysis/expression_simplification/rules/fix_add_sub_sign.py b/decompiler/pipeline/controlflowanalysis/expression_simplification/rules/fix_add_sub_sign.py new file mode 100644 index 000000000..2e60c74dd --- /dev/null +++ b/decompiler/pipeline/controlflowanalysis/expression_simplification/rules/fix_add_sub_sign.py @@ -0,0 +1,42 @@ +from decompiler.pipeline.controlflowanalysis.expression_simplification.modification import normalize_int +from decompiler.pipeline.controlflowanalysis.expression_simplification.rules.rule import SimplificationRule +from decompiler.structures.pseudo import BinaryOperation, Constant, Expression, Integer, Operation, OperationType + + +class FixAddSubSign(SimplificationRule): + """ + Changes add/sub when variable type is signed. + + - `V - a -> E + (-a)` when signed(a) < 0 + - `V + a -> E - (-a)` when signed(a) < 0 + """ + + def apply(self, operation: Operation) -> list[tuple[Expression, Expression]]: + if operation.operation not in (OperationType.plus, OperationType.minus): + return [] + if not isinstance(operation, BinaryOperation): + raise TypeError(f"Expected BinaryOperation, got {type(operation)}") + + right = operation.right + if not isinstance(right, Constant): + return [] + + con_type = right.type + if not isinstance(con_type, Integer): + return [] + + a = normalize_int(right.value, con_type.size, True) + if a >= 0: + return [] + + neg_a = Constant( + normalize_int(-a, con_type.size, con_type.signed), + con_type + ) + return [( + operation, + BinaryOperation( + OperationType.plus if operation.operation == OperationType.minus else OperationType.minus, + [operation.left, neg_a] + ) + )] diff --git a/decompiler/pipeline/controlflowanalysis/expression_simplification/rules/simplify_redundant_reference.py b/decompiler/pipeline/controlflowanalysis/expression_simplification/rules/simplify_redundant_reference.py new file mode 100644 index 000000000..70780a062 --- /dev/null +++ b/decompiler/pipeline/controlflowanalysis/expression_simplification/rules/simplify_redundant_reference.py @@ -0,0 +1,20 @@ +from decompiler.pipeline.controlflowanalysis.expression_simplification.rules.rule import SimplificationRule +from decompiler.structures.pseudo import Expression, Operation, OperationType, UnaryOperation + + +class SimplifyRedundantReference(SimplificationRule): + """ + Removes redundant nesting of referencing, immediately followed by referencing. + + `*(&(e0)) -> e0` + """ + + def apply(self, operation: Operation) -> list[tuple[Expression, Expression]]: + match operation: + case UnaryOperation( + operation=OperationType.dereference, + operand=UnaryOperation(operation=OperationType.address, operand=operand) + ): + return [(operation, operand)] + case _: + return [] diff --git a/decompiler/pipeline/controlflowanalysis/expression_simplification/rules/simplify_trivial_arithmetic.py b/decompiler/pipeline/controlflowanalysis/expression_simplification/rules/simplify_trivial_arithmetic.py new file mode 100644 index 000000000..1c4e58ee2 --- /dev/null +++ b/decompiler/pipeline/controlflowanalysis/expression_simplification/rules/simplify_trivial_arithmetic.py @@ -0,0 +1,37 @@ +from decompiler.pipeline.controlflowanalysis.expression_simplification.rules.rule import SimplificationRule +from decompiler.structures.pseudo import BinaryOperation, Constant, Expression, Operation, OperationType, UnaryOperation + + +class SimplifyTrivialArithmetic(SimplificationRule): + """ + Simplifies trivial arithmetic: + + - `e + 0 -> e` + - `e - 0 -> e` + - `e * 1 -> e` + - `e u* 1 -> e` + - `e * -1 -> -e` + - `e u* -1 -> -e` + - `e / 1 -> e` + - `e u/ 1 -> e` + - `e / -1 -> -e` + """ + + def apply(self, operation: Operation) -> list[tuple[Expression, Expression]]: + match operation: + case BinaryOperation(operation=OperationType.plus | OperationType.minus, right=Constant(value=0)): + return [(operation, operation.left)] + case BinaryOperation( + operation=OperationType.multiply | OperationType.multiply_us | OperationType.divide | OperationType.divide_us, + right=Constant(value=1), + ): + return [(operation, operation.left)] + case BinaryOperation(operation=OperationType.multiply, right=Constant(value=0)): + return [(operation, Constant(0, operation.type))] + case BinaryOperation( + operation=OperationType.multiply | OperationType.multiply_us | OperationType.divide, + right=Constant(value=-1) + ): + return [(operation, UnaryOperation(OperationType.negate, [operation.left]))] + case _: + return [] diff --git a/decompiler/pipeline/controlflowanalysis/expression_simplification/rules/simplify_trivial_bit_arithmetic.py b/decompiler/pipeline/controlflowanalysis/expression_simplification/rules/simplify_trivial_bit_arithmetic.py new file mode 100644 index 000000000..5a6a62c58 --- /dev/null +++ b/decompiler/pipeline/controlflowanalysis/expression_simplification/rules/simplify_trivial_bit_arithmetic.py @@ -0,0 +1,28 @@ +from decompiler.pipeline.controlflowanalysis.expression_simplification.rules.rule import SimplificationRule +from decompiler.structures.pseudo import BinaryOperation, Constant, Expression, Operation, OperationType + + +class SimplifyTrivialBitArithmetic(SimplificationRule): + """ + Simplifies trivial bit arithmetic: + + - `e | 0 -> e` + - `e | e -> e` + - `e & 0 -> 0` + - `e & e -> e` + - `e ^ 0 -> e` + - `e ^ e -> 0` + """ + + def apply(self, operation: Operation) -> list[tuple[Expression, Expression]]: + match operation: + case BinaryOperation(operation=OperationType.bitwise_or | OperationType.bitwise_xor, right=Constant(value=0)): + return [(operation, operation.left)] + case BinaryOperation(operation=OperationType.bitwise_and, right=Constant(value=0)): + return [(operation, Constant(0, operation.type))] + case BinaryOperation(operation=OperationType.bitwise_or | OperationType.bitwise_and, left=left, right=right) if left == right: + return [(operation, operation.left)] + case BinaryOperation(operation=OperationType.bitwise_xor, left=left, right=right) if left == right: + return [(operation, Constant(0, operation.type))] + case _: + return [] diff --git a/decompiler/pipeline/controlflowanalysis/expression_simplification/rules/simplify_trivial_logic_arithmetic.py b/decompiler/pipeline/controlflowanalysis/expression_simplification/rules/simplify_trivial_logic_arithmetic.py new file mode 100644 index 000000000..08994b1cc --- /dev/null +++ b/decompiler/pipeline/controlflowanalysis/expression_simplification/rules/simplify_trivial_logic_arithmetic.py @@ -0,0 +1,26 @@ +from decompiler.pipeline.controlflowanalysis.expression_simplification.rules.rule import SimplificationRule +from decompiler.structures.pseudo import BinaryOperation, Constant, Expression, Operation, OperationType + + +class SimplifyTrivialLogicArithmetic(SimplificationRule): + """ + Simplifies trivial logic arithmetic. + + - `e || false -> e` + - `e || true -> true` + - `e && false -> false` + - `e && true -> e` + """ + + def apply(self, operation: Operation) -> list[tuple[Expression, Expression]]: + match operation: + case BinaryOperation(operation=OperationType.logical_or, right=Constant(value=0)): + return [(operation, operation.left)] + case BinaryOperation(operation=OperationType.logical_and, right=Constant(value=0)): + return [(operation, Constant(0, operation.type))] + case BinaryOperation(operation=OperationType.logical_or, right=Constant(value=value)) if value != 0: + return [(operation, Constant(1, operation.type))] + case BinaryOperation(operation=OperationType.logical_and, right=Constant(value=value)) if value != 0: + return [(operation, operation.left)] + case _: + return [] diff --git a/decompiler/pipeline/controlflowanalysis/expression_simplification/rules/simplify_trivial_shift.py b/decompiler/pipeline/controlflowanalysis/expression_simplification/rules/simplify_trivial_shift.py new file mode 100644 index 000000000..fcf1fcfa9 --- /dev/null +++ b/decompiler/pipeline/controlflowanalysis/expression_simplification/rules/simplify_trivial_shift.py @@ -0,0 +1,28 @@ +from decompiler.pipeline.controlflowanalysis.expression_simplification.rules.rule import SimplificationRule +from decompiler.structures.pseudo import BinaryOperation, Constant, Expression, Operation, OperationType + + +class SimplifyTrivialShift(SimplificationRule): + """ + Simplifies trivial shift/rotate arithmetic: + + - `e << 0 -> e` + - `e >> 0 -> e` + - `e u>> 0 -> e` + - `e lrot 0 -> e` + - `e rrot 0 -> e` + """ + + def apply(self, operation: Operation) -> list[tuple[Expression, Expression]]: + match operation: + case BinaryOperation( + operation=OperationType.left_shift + | OperationType.right_shift + | OperationType.right_shift_us + | OperationType.left_rotate + | OperationType.right_rotate, + right=Constant(value=0), + ): + return [(operation, operation.left)] + case _: + return [] diff --git a/decompiler/pipeline/controlflowanalysis/expression_simplification/rules/sub_to_add.py b/decompiler/pipeline/controlflowanalysis/expression_simplification/rules/sub_to_add.py new file mode 100644 index 000000000..7ba105fba --- /dev/null +++ b/decompiler/pipeline/controlflowanalysis/expression_simplification/rules/sub_to_add.py @@ -0,0 +1,27 @@ +from decompiler.pipeline.controlflowanalysis.expression_simplification.rules.rule import SimplificationRule +from decompiler.structures.pseudo import BinaryOperation, Constant, Expression, Operation, OperationType + + +class SubToAdd(SimplificationRule): + """ + Replace subtractions with additions. + + `e0 - e1 -> e0 + (e1 * -1)` + """ + + def apply(self, operation: Operation) -> list[tuple[Expression, Expression]]: + if operation.operation != OperationType.minus: + return [] + if not isinstance(operation, BinaryOperation): + raise TypeError(f"Expected BinaryOperation, got {type(operation)}") + + mul_op = BinaryOperation(OperationType.multiply, [operation.right, Constant(-1, operation.type)]) + + return [( + operation, + BinaryOperation( + OperationType.plus, + [operation.left, mul_op], + operation.type + ) + )] diff --git a/decompiler/pipeline/controlflowanalysis/expression_simplification/rules/term_order.py b/decompiler/pipeline/controlflowanalysis/expression_simplification/rules/term_order.py new file mode 100644 index 000000000..0c5721a28 --- /dev/null +++ b/decompiler/pipeline/controlflowanalysis/expression_simplification/rules/term_order.py @@ -0,0 +1,26 @@ +from decompiler.pipeline.controlflowanalysis.expression_simplification.rules.rule import SimplificationRule +from decompiler.structures.pseudo import BinaryOperation, Constant, Expression, Operation +from decompiler.structures.pseudo.operations import COMMUTATIVE_OPERATIONS + + +class TermOrder(SimplificationRule): + """ + Swap constants of commutative operations to the right. + + - `c + e -> e + c` + - `c * e -> e * c` + - `c & e -> e & c` + - `c | e -> e | c` + - `c ^ e -> e ^ c` + """ + + def apply(self, operation: Operation) -> list[tuple[Expression, Expression]]: + if operation.operation not in COMMUTATIVE_OPERATIONS: + return [] + if not isinstance(operation, BinaryOperation): + raise ValueError(f"Expected BinaryOperation, got {operation}") + + if isinstance(operation.left, Constant) and not isinstance(operation.right, Constant): + return [(operation, BinaryOperation(operation.operation, [operation.right, operation.left], operation.type, operation.tags))] + else: + return [] diff --git a/decompiler/pipeline/controlflowanalysis/expression_simplification_rules.py b/decompiler/pipeline/controlflowanalysis/expression_simplification_rules.py index 997711a34..30acad13d 100644 --- a/decompiler/pipeline/controlflowanalysis/expression_simplification_rules.py +++ b/decompiler/pipeline/controlflowanalysis/expression_simplification_rules.py @@ -2,7 +2,22 @@ from abc import ABC, abstractmethod from decompiler.backend.cexpressiongenerator import CExpressionGenerator +from decompiler.pipeline.controlflowanalysis.expression_simplification.rules.collapse_add_neg import CollapseAddNeg +from decompiler.pipeline.controlflowanalysis.expression_simplification.rules.collapse_constants import CollapseConstants +from decompiler.pipeline.controlflowanalysis.expression_simplification.rules.collect_terms import CollectTerms +from decompiler.pipeline.controlflowanalysis.expression_simplification.rules.fix_add_sub_sign import FixAddSubSign from decompiler.pipeline.controlflowanalysis.expression_simplification.rules.rule import SimplificationRule +from decompiler.pipeline.controlflowanalysis.expression_simplification.rules.simplify_redundant_reference import SimplifyRedundantReference +from decompiler.pipeline.controlflowanalysis.expression_simplification.rules.simplify_trivial_arithmetic import SimplifyTrivialArithmetic +from decompiler.pipeline.controlflowanalysis.expression_simplification.rules.simplify_trivial_bit_arithmetic import ( + SimplifyTrivialBitArithmetic, +) +from decompiler.pipeline.controlflowanalysis.expression_simplification.rules.simplify_trivial_logic_arithmetic import ( + SimplifyTrivialLogicArithmetic, +) +from decompiler.pipeline.controlflowanalysis.expression_simplification.rules.simplify_trivial_shift import SimplifyTrivialShift +from decompiler.pipeline.controlflowanalysis.expression_simplification.rules.sub_to_add import SubToAdd +from decompiler.pipeline.controlflowanalysis.expression_simplification.rules.term_order import TermOrder from decompiler.pipeline.stage import PipelineStage from decompiler.structures.ast.ast_nodes import CodeNode from decompiler.structures.pseudo import Instruction, Operation @@ -49,8 +64,21 @@ def _get_instructions(self, task: DecompilerTask) -> list[Instruction]: _pre_rules: list[SimplificationRule] = [] -_rules: list[SimplificationRule] = [] -_post_rules: list[SimplificationRule] = [] +_rules: list[SimplificationRule] = [ + TermOrder(), + SubToAdd(), + SimplifyRedundantReference(), + SimplifyTrivialArithmetic(), + SimplifyTrivialBitArithmetic(), + SimplifyTrivialLogicArithmetic(), + SimplifyTrivialShift(), + CollapseConstants(), + CollectTerms(), +] +_post_rules: list[SimplificationRule] = [ + CollapseAddNeg(), + FixAddSubSign() +] def simplify_instructions(instructions: list[Instruction], max_iterations: int): diff --git a/decompiler/structures/pseudo/operations.py b/decompiler/structures/pseudo/operations.py index 43cf989e9..aeaa1b514 100644 --- a/decompiler/structures/pseudo/operations.py +++ b/decompiler/structures/pseudo/operations.py @@ -152,9 +152,14 @@ class OperationType(Enum): COMMUTATIVE_OPERATIONS = { OperationType.plus, OperationType.multiply, + OperationType.multiply_us, OperationType.bitwise_and, OperationType.bitwise_xor, OperationType.bitwise_or, + OperationType.logical_or, + OperationType.logical_and, + OperationType.equal, + OperationType.not_equal } NON_COMPOUNDABLE_OPERATIONS = { @@ -163,6 +168,10 @@ class OperationType(Enum): OperationType.left_rotate, OperationType.left_rotate_carry, OperationType.power, + OperationType.logical_or, + OperationType.logical_and, + OperationType.equal, + OperationType.not_equal } diff --git a/tests/pipeline/controlflowanalysis/expression_simplification/rules/test_collapse_add_neg.py b/tests/pipeline/controlflowanalysis/expression_simplification/rules/test_collapse_add_neg.py new file mode 100644 index 000000000..d7300107f --- /dev/null +++ b/tests/pipeline/controlflowanalysis/expression_simplification/rules/test_collapse_add_neg.py @@ -0,0 +1,23 @@ +import pytest +from decompiler.pipeline.controlflowanalysis.expression_simplification.rules.collapse_add_neg import CollapseAddNeg +from decompiler.structures.pseudo import BinaryOperation, Expression, Operation, OperationType, UnaryOperation, Variable + +var_x = Variable("x") +var_y = Variable("y") + + +@pytest.mark.parametrize( + ["operation", "result"], + [ + ( + BinaryOperation(OperationType.plus, [var_x, UnaryOperation(OperationType.negate, [var_y])]), + [BinaryOperation(OperationType.minus, [var_x, var_y])], + ), + ( + BinaryOperation(OperationType.minus, [var_x, UnaryOperation(OperationType.negate, [var_y])]), + [BinaryOperation(OperationType.plus, [var_x, var_y])], + ), + ], +) +def test_collapse_add_neg(operation: Operation, result: list[Expression]): + assert CollapseAddNeg().apply(operation) == [(operation, e) for e in result] diff --git a/tests/pipeline/controlflowanalysis/expression_simplification/rules/test_collapse_constant.py b/tests/pipeline/controlflowanalysis/expression_simplification/rules/test_collapse_constant.py new file mode 100644 index 000000000..1a4a5febf --- /dev/null +++ b/tests/pipeline/controlflowanalysis/expression_simplification/rules/test_collapse_constant.py @@ -0,0 +1,23 @@ +import pytest +from decompiler.pipeline.controlflowanalysis.expression_simplification.rules.collapse_constants import CollapseConstants +from decompiler.structures.pseudo import BinaryOperation, Constant, Expression, Float, Integer, Operation, OperationType, Variable + + +def _c_i32(value: int) -> Constant: + return Constant(value, Integer.int32_t()) + + +def _c_float(value: float) -> Constant: + return Constant(value, Float.float()) + + +@pytest.mark.parametrize( + ["operation", "result"], + [ + (BinaryOperation(OperationType.plus, [_c_i32(3), _c_i32(4)]), [_c_i32(7)]), + (BinaryOperation(OperationType.plus, [_c_i32(3), Variable("x")]), []), + (BinaryOperation(OperationType.plus_float, [_c_float(3.0), _c_float(4.0)]), []), + ], +) +def test_collapse_constant(operation: Operation, result: list[Expression]): + assert CollapseConstants().apply(operation) == [(operation, e) for e in result] diff --git a/tests/pipeline/controlflowanalysis/expression_simplification/rules/test_collapse_mult_neg_one.py b/tests/pipeline/controlflowanalysis/expression_simplification/rules/test_collapse_mult_neg_one.py new file mode 100644 index 000000000..88b945f91 --- /dev/null +++ b/tests/pipeline/controlflowanalysis/expression_simplification/rules/test_collapse_mult_neg_one.py @@ -0,0 +1,16 @@ +import pytest +from decompiler.pipeline.controlflowanalysis.expression_simplification.rules.collapse_mult_neg_one import CollapseMultNegOne +from decompiler.structures.pseudo import BinaryOperation, Constant, Expression, Operation, OperationType, UnaryOperation, Variable + + +@pytest.mark.parametrize( + ["operation", "result"], + [ + ( + BinaryOperation(OperationType.multiply, [var := Variable("x"), Constant(-1)]), + [UnaryOperation(OperationType.negate, [var])], + ) + ], +) +def test_mult_neg_one(operation: Operation, result: list[Expression]): + assert CollapseMultNegOne().apply(operation) == [(operation, e) for e in result] diff --git a/tests/pipeline/controlflowanalysis/expression_simplification/rules/test_collect_terms.py b/tests/pipeline/controlflowanalysis/expression_simplification/rules/test_collect_terms.py new file mode 100644 index 000000000..f81304f8e --- /dev/null +++ b/tests/pipeline/controlflowanalysis/expression_simplification/rules/test_collect_terms.py @@ -0,0 +1,149 @@ +import pytest +from decompiler.pipeline.controlflowanalysis.expression_simplification.rules.collect_terms import CollectTerms +from decompiler.structures.pseudo import BinaryOperation, Constant, Expression, Integer, Operation, OperationType, Variable +from decompiler.structures.visitors.substitute_visitor import SubstituteVisitor + + +def _var_i32(name: str) -> Variable: + return Variable(name, Integer.int32_t()) + + +def _c_i32(value: int) -> Constant: + return Constant(value, Integer.int32_t()) + + +def _plus(e0: Expression, e1: Expression) -> BinaryOperation: + return BinaryOperation(OperationType.plus, [e0, e1]) + + +def _mul(e: Expression, factor: Expression) -> BinaryOperation: + return BinaryOperation(OperationType.multiply, [e, factor]) + + +def _mul_us(e: Expression, factor: Expression) -> BinaryOperation: + return BinaryOperation(OperationType.multiply_us, [e, factor]) + + +def _bit_and(e0: Expression, e1: Expression) -> BinaryOperation: + return BinaryOperation(OperationType.bitwise_and, [e0, e1]) + + +def _bit_xor(e0: Expression, e1: Expression) -> BinaryOperation: + return BinaryOperation(OperationType.bitwise_xor, [e0, e1]) + + +def _bit_or(e0: Expression, e1: Expression) -> BinaryOperation: + return BinaryOperation(OperationType.bitwise_or, [e0, e1]) + + +@pytest.mark.parametrize( + ["operation", "possible_results"], + [ + ( # plus + _plus(_plus(_c_i32(7), _c_i32(11)), _c_i32(42)), + { + _plus(_plus(_c_i32(0), _c_i32(0)), _c_i32(60)), + _plus(_plus(_c_i32(0), _c_i32(60)), _c_i32(0)), + _plus(_plus(_c_i32(60), _c_i32(0)), _c_i32(0)), + }, + ), + ( + _plus(_plus(_var_i32("a"), _c_i32(2)), _plus(_var_i32("b"), _c_i32(3))), + { + _plus(_plus(_var_i32("a"), _c_i32(5)), _plus(_var_i32("b"), _c_i32(0))), + _plus(_plus(_var_i32("a"), _c_i32(0)), _plus(_var_i32("b"), _c_i32(5))), + }, + ), + ( # multiply + _mul(_mul(_c_i32(7), _c_i32(11)), _c_i32(2)), + { + _mul(_mul(_c_i32(1), _c_i32(1)), _c_i32(154)), + _mul(_mul(_c_i32(1), _c_i32(154)), _c_i32(1)), + _mul(_mul(_c_i32(154), _c_i32(1)), _c_i32(1)), + }, + ), + ( + _mul(_mul(_var_i32("a"), _c_i32(2)), _mul(_var_i32("b"), _c_i32(3))), + { + _mul(_mul(_var_i32("a"), _c_i32(6)), _mul(_var_i32("b"), _c_i32(1))), + _mul(_mul(_var_i32("a"), _c_i32(1)), _mul(_var_i32("b"), _c_i32(6))), + }, + ), + ( # multiply_us + _mul_us(_mul_us(_c_i32(7), _c_i32(11)), _c_i32(2)), + { + _mul_us(_mul_us(_c_i32(1), _c_i32(1)), _c_i32(154)), + _mul_us(_mul_us(_c_i32(1), _c_i32(154)), _c_i32(1)), + _mul_us(_mul_us(_c_i32(154), _c_i32(1)), _c_i32(1)), + }, + ), + ( + _mul_us(_mul_us(_var_i32("a"), _c_i32(2)), _mul_us(_var_i32("b"), _c_i32(3))), + { + _mul_us(_mul_us(_var_i32("a"), _c_i32(6)), _mul_us(_var_i32("b"), _c_i32(1))), + _mul_us(_mul_us(_var_i32("a"), _c_i32(1)), _mul_us(_var_i32("b"), _c_i32(6))), + }, + ), + ( # bitwise_and + _bit_and(_bit_and(_c_i32(7), _c_i32(11)), _c_i32(2)), + { + _bit_and(_bit_and(_c_i32(-1), _c_i32(-1)), _c_i32(2)), + _bit_and(_bit_and(_c_i32(-1), _c_i32(2)), _c_i32(-1)), + _bit_and(_bit_and(_c_i32(2), _c_i32(-1)), _c_i32(-1)), + }, + ), + ( + _bit_and(_bit_and(_var_i32("a"), _c_i32(2)), _bit_and(_var_i32("b"), _c_i32(3))), + { + _bit_and(_bit_and(_var_i32("a"), _c_i32(2)), _bit_and(_var_i32("b"), _c_i32(-1))), + _bit_and(_bit_and(_var_i32("a"), _c_i32(-1)), _bit_and(_var_i32("b"), _c_i32(2))), + }, + ), + ( # bitwise_xor + _bit_xor(_bit_xor(_c_i32(7), _c_i32(11)), _c_i32(2)), + { + _bit_xor(_bit_xor(_c_i32(0), _c_i32(0)), _c_i32(14)), + _bit_xor(_bit_xor(_c_i32(0), _c_i32(14)), _c_i32(0)), + _bit_xor(_bit_xor(_c_i32(14), _c_i32(0)), _c_i32(0)), + }, + ), + ( + _bit_xor(_bit_xor(_var_i32("a"), _c_i32(2)), _bit_xor(_var_i32("b"), _c_i32(3))), + { + _bit_xor(_bit_xor(_var_i32("a"), _c_i32(1)), _bit_xor(_var_i32("b"), _c_i32(0))), + _bit_xor(_bit_xor(_var_i32("a"), _c_i32(0)), _bit_xor(_var_i32("b"), _c_i32(1))), + }, + ), + ( # bitwise_or + _bit_or(_bit_or(_c_i32(7), _c_i32(11)), _c_i32(2)), + { + _bit_or(_bit_or(_c_i32(0), _c_i32(0)), _c_i32(15)), + _bit_or(_bit_or(_c_i32(0), _c_i32(15)), _c_i32(0)), + _bit_or(_bit_or(_c_i32(15), _c_i32(0)), _c_i32(0)), + }, + ), + ( + _bit_or(_bit_or(_var_i32("a"), _c_i32(2)), _bit_or(_var_i32("b"), _c_i32(3))), + { + _bit_or(_bit_or(_var_i32("a"), _c_i32(3)), _bit_or(_var_i32("b"), _c_i32(0))), + _bit_or(_bit_or(_var_i32("a"), _c_i32(0)), _bit_or(_var_i32("b"), _c_i32(3))), + }, + ), + ], +) +def test_collect_terms(operation: Operation, possible_results: set[Expression]): + collect_terms = CollectTerms() + + for i in range(100): + substitutions = collect_terms.apply(operation) + if not substitutions: + break + + for replacee, replacement in substitutions: + new_operation = operation.accept(SubstituteVisitor.identity(replacee, replacement)) + if new_operation is not None: + operation = new_operation + else: + raise RuntimeError("Max iterations exceeded") + + assert operation in possible_results diff --git a/tests/pipeline/controlflowanalysis/expression_simplification/rules/test_fix_add_sub_sign.py b/tests/pipeline/controlflowanalysis/expression_simplification/rules/test_fix_add_sub_sign.py new file mode 100644 index 000000000..47f787d3a --- /dev/null +++ b/tests/pipeline/controlflowanalysis/expression_simplification/rules/test_fix_add_sub_sign.py @@ -0,0 +1,36 @@ +import pytest +from decompiler.pipeline.controlflowanalysis.expression_simplification.rules.fix_add_sub_sign import FixAddSubSign +from decompiler.structures.pseudo import BinaryOperation, Constant, Expression, Integer, Operation, OperationType, Variable + +var_x_i = Variable("x", Integer.int32_t()) +var_x_u = Variable("x", Integer.uint32_t()) + + +@pytest.mark.parametrize( + ["operation", "result"], + [ + ( + BinaryOperation(OperationType.minus, [var_x_i, (Constant(-3, Integer.int32_t()))]), + [BinaryOperation(OperationType.plus, [var_x_i, Constant(3, Integer.int32_t())])], + ), + ( + BinaryOperation(OperationType.plus, [var_x_i, (Constant(-3, Integer.int32_t()))]), + [BinaryOperation(OperationType.minus, [var_x_i, Constant(3, Integer.int32_t())])], + ), + (BinaryOperation(OperationType.plus, [var_x_i, (Constant(3, Integer.int32_t()))]), []), + (BinaryOperation(OperationType.minus, [var_x_i, (Constant(3, Integer.int32_t()))]), []), + + ( + BinaryOperation(OperationType.minus, [var_x_u, (Constant(4294967293, Integer.uint32_t()))]), + [BinaryOperation(OperationType.plus, [var_x_u, Constant(3, Integer.uint32_t())])], + ), + ( + BinaryOperation(OperationType.plus, [var_x_u, (Constant(4294967293, Integer.uint32_t()))]), + [BinaryOperation(OperationType.minus, [var_x_u, Constant(3, Integer.uint32_t())])], + ), + (BinaryOperation(OperationType.plus, [var_x_u, (Constant(3, Integer.uint32_t()))]), []), + (BinaryOperation(OperationType.minus, [var_x_u, (Constant(3, Integer.uint32_t()))]), []), + ], +) +def test_fix_add_sub_sign(operation: Operation, result: list[Expression]): + assert FixAddSubSign().apply(operation) == [(operation, e) for e in result] diff --git a/tests/pipeline/controlflowanalysis/expression_simplification/rules/test_simplify_redundant_reference.py b/tests/pipeline/controlflowanalysis/expression_simplification/rules/test_simplify_redundant_reference.py new file mode 100644 index 000000000..f269d5ec9 --- /dev/null +++ b/tests/pipeline/controlflowanalysis/expression_simplification/rules/test_simplify_redundant_reference.py @@ -0,0 +1,15 @@ +import pytest +from decompiler.pipeline.controlflowanalysis.expression_simplification.rules.simplify_redundant_reference import SimplifyRedundantReference +from decompiler.structures.pseudo import Expression, Operation, OperationType, UnaryOperation, Variable + + +@pytest.mark.parametrize( + ["operation", "result"], + [ + (UnaryOperation(OperationType.dereference, [UnaryOperation(OperationType.address, [var := Variable("x")])]), [var]), + (UnaryOperation(OperationType.address, [Variable("x")]), []), + (UnaryOperation(OperationType.dereference, [Variable("x")]), []), + ], +) +def test_simplify_redundant_reference(operation: Operation, result: list[Expression]): + assert SimplifyRedundantReference().apply(operation) == [(operation, e) for e in result] diff --git a/tests/pipeline/controlflowanalysis/expression_simplification/rules/test_simplify_trivial_arithmetic.py b/tests/pipeline/controlflowanalysis/expression_simplification/rules/test_simplify_trivial_arithmetic.py new file mode 100644 index 000000000..f1f9818bc --- /dev/null +++ b/tests/pipeline/controlflowanalysis/expression_simplification/rules/test_simplify_trivial_arithmetic.py @@ -0,0 +1,27 @@ +import pytest +from decompiler.pipeline.controlflowanalysis.expression_simplification.rules.simplify_trivial_arithmetic import SimplifyTrivialArithmetic +from decompiler.structures.pseudo import BinaryOperation, Constant, Expression, Integer, Operation, OperationType, UnaryOperation, Variable + +var = Variable("x") +con_0 = Constant(0, Integer.int32_t()) +con_1 = Constant(1, Integer.int32_t()) +con_neg1 = Constant(-1, Integer.int32_t()) + + +@pytest.mark.parametrize( + ["operation", "result"], + [ + (BinaryOperation(OperationType.plus, [var, con_0]), [var]), + (BinaryOperation(OperationType.minus, [var, con_0]), [var]), + (BinaryOperation(OperationType.multiply, [var, con_1]), [var]), + (BinaryOperation(OperationType.multiply_us, [var, con_1]), [var]), + (BinaryOperation(OperationType.multiply, [var, con_neg1]), [UnaryOperation(OperationType.negate, [var])]), + (BinaryOperation(OperationType.multiply_us, [var, con_neg1]), [UnaryOperation(OperationType.negate, [var])]), + (BinaryOperation(OperationType.divide, [var, con_1]), [var]), + (BinaryOperation(OperationType.divide_us, [var, con_1]), [var]), + (BinaryOperation(OperationType.divide, [var, con_neg1]), [UnaryOperation(OperationType.negate, [var])]), + (BinaryOperation(OperationType.divide_us, [var, con_neg1]), []), + ], +) +def test_simplify_trivial_arithmetic(operation: Operation, result: list[Expression]): + assert SimplifyTrivialArithmetic().apply(operation) == [(operation, e) for e in result] diff --git a/tests/pipeline/controlflowanalysis/expression_simplification/rules/test_simplify_trivial_bit_arithmetic.py b/tests/pipeline/controlflowanalysis/expression_simplification/rules/test_simplify_trivial_bit_arithmetic.py new file mode 100644 index 000000000..e42771f66 --- /dev/null +++ b/tests/pipeline/controlflowanalysis/expression_simplification/rules/test_simplify_trivial_bit_arithmetic.py @@ -0,0 +1,23 @@ +import pytest +from decompiler.pipeline.controlflowanalysis.expression_simplification.rules.simplify_trivial_bit_arithmetic import ( + SimplifyTrivialBitArithmetic, +) +from decompiler.structures.pseudo import BinaryOperation, Constant, Expression, Integer, Operation, OperationType, Variable + +var = Variable("x", Integer.int32_t()) +con_0 = Constant(0, Integer.int32_t()) + + +@pytest.mark.parametrize( + ["operation", "result"], + [ + (BinaryOperation(OperationType.bitwise_or, [var, con_0]), [var]), + (BinaryOperation(OperationType.bitwise_or, [var, var]), [var]), + (BinaryOperation(OperationType.bitwise_and, [var, con_0]), [con_0]), + (BinaryOperation(OperationType.bitwise_and, [var, var]), [var]), + (BinaryOperation(OperationType.bitwise_xor, [var, con_0]), [var]), + (BinaryOperation(OperationType.bitwise_xor, [var, var]), [con_0]), + ], +) +def test_simplify_trivial_bit_arithmetic(operation: Operation, result: list[Expression]): + assert SimplifyTrivialBitArithmetic().apply(operation) == [(operation, e) for e in result] diff --git a/tests/pipeline/controlflowanalysis/expression_simplification/rules/test_simplify_trivial_logic_arithmetic.py b/tests/pipeline/controlflowanalysis/expression_simplification/rules/test_simplify_trivial_logic_arithmetic.py new file mode 100644 index 000000000..8d74ff8b1 --- /dev/null +++ b/tests/pipeline/controlflowanalysis/expression_simplification/rules/test_simplify_trivial_logic_arithmetic.py @@ -0,0 +1,22 @@ +import pytest +from decompiler.pipeline.controlflowanalysis.expression_simplification.rules.simplify_trivial_logic_arithmetic import ( + SimplifyTrivialLogicArithmetic, +) +from decompiler.structures.pseudo import BinaryOperation, Constant, CustomType, Expression, Operation, OperationType, Variable + +var = Variable("x", CustomType.bool()) +con_false = Constant(0, CustomType.bool()) +con_true = Constant(1, CustomType.bool()) + + +@pytest.mark.parametrize( + ["operation", "result"], + [ + (BinaryOperation(OperationType.logical_or, [var, con_false]), [var]), + (BinaryOperation(OperationType.logical_or, [var, con_true]), [con_true]), + (BinaryOperation(OperationType.logical_and, [var, con_false]), [con_false]), + (BinaryOperation(OperationType.logical_and, [var, con_true]), [var]), + ], +) +def test_simplify_trivial_logic_arithmetic(operation: Operation, result: list[Expression]): + assert SimplifyTrivialLogicArithmetic().apply(operation) == [(operation, e) for e in result] diff --git a/tests/pipeline/controlflowanalysis/expression_simplification/rules/test_simplify_trivial_shift.py b/tests/pipeline/controlflowanalysis/expression_simplification/rules/test_simplify_trivial_shift.py new file mode 100644 index 000000000..a281caa33 --- /dev/null +++ b/tests/pipeline/controlflowanalysis/expression_simplification/rules/test_simplify_trivial_shift.py @@ -0,0 +1,20 @@ +import pytest +from decompiler.pipeline.controlflowanalysis.expression_simplification.rules.simplify_trivial_shift import SimplifyTrivialShift +from decompiler.structures.pseudo import BinaryOperation, Constant, Expression, Integer, Operation, OperationType, Variable + +var = Variable("x") +con_0 = Constant(0, Integer.int32_t()) + + +@pytest.mark.parametrize( + ["operation", "result"], + [ + (BinaryOperation(OperationType.left_shift, [var, con_0]), [var]), + (BinaryOperation(OperationType.right_shift, [var, con_0]), [var]), + (BinaryOperation(OperationType.right_shift_us, [var, con_0]), [var]), + (BinaryOperation(OperationType.left_rotate, [var, con_0]), [var]), + (BinaryOperation(OperationType.right_rotate, [var, con_0]), [var]), + ], +) +def test_simplify_trivial_shift(operation: Operation, result: list[Expression]): + assert SimplifyTrivialShift().apply(operation) == [(operation, e) for e in result] diff --git a/tests/pipeline/controlflowanalysis/expression_simplification/rules/test_sub_to_add.py b/tests/pipeline/controlflowanalysis/expression_simplification/rules/test_sub_to_add.py new file mode 100644 index 000000000..565b9067e --- /dev/null +++ b/tests/pipeline/controlflowanalysis/expression_simplification/rules/test_sub_to_add.py @@ -0,0 +1,20 @@ +import pytest +from decompiler.pipeline.controlflowanalysis.expression_simplification.rules.sub_to_add import SubToAdd +from decompiler.structures.pseudo import BinaryOperation, Constant, Expression, Integer, Operation, OperationType, Variable + +var_x = Variable("x", Integer.int32_t()) +var_y = Variable("y", Integer.int32_t()) +con_neg1 = Constant(-1, Integer.int32_t()) + + +@pytest.mark.parametrize( + ["operation", "result"], + [ + ( + BinaryOperation(OperationType.minus, [var_x, var_y]), + [BinaryOperation(OperationType.plus, [var_x, BinaryOperation(OperationType.multiply, [var_y, con_neg1])])], + ), + ], +) +def test_sub_to_add(operation: Operation, result: list[Expression]): + assert SubToAdd().apply(operation) == [(operation, e) for e in result] diff --git a/tests/pipeline/controlflowanalysis/expression_simplification/rules/test_term_order.py b/tests/pipeline/controlflowanalysis/expression_simplification/rules/test_term_order.py new file mode 100644 index 000000000..5846004e9 --- /dev/null +++ b/tests/pipeline/controlflowanalysis/expression_simplification/rules/test_term_order.py @@ -0,0 +1,15 @@ +import pytest +from decompiler.pipeline.controlflowanalysis.expression_simplification.rules.term_order import TermOrder +from decompiler.structures.pseudo import BinaryOperation, Constant, Expression, Integer, Operation, Variable +from decompiler.structures.pseudo.operations import COMMUTATIVE_OPERATIONS + +var = Variable("x") +con = Constant(42, Integer.int32_t()) + + +@pytest.mark.parametrize( + ["operation", "result"], + [(BinaryOperation(operation, [con, var]), [BinaryOperation(operation, [var, con])]) for operation in COMMUTATIVE_OPERATIONS], +) +def test_term_order(operation: Operation, result: list[Expression]): + assert TermOrder().apply(operation) == [(operation, e) for e in result] diff --git a/tests/pipeline/controlflowanalysis/expression_simplification/test_modification.py b/tests/pipeline/controlflowanalysis/expression_simplification/test_modification.py new file mode 100644 index 000000000..32e9c1bbe --- /dev/null +++ b/tests/pipeline/controlflowanalysis/expression_simplification/test_modification.py @@ -0,0 +1,162 @@ +from contextlib import nullcontext + +import pytest +from decompiler.pipeline.controlflowanalysis.expression_simplification.modification import ( + FOLDABLE_CONSTANTS, + constant_fold, + multiply_int_with_constant, +) +from decompiler.structures.pseudo import BinaryOperation, Constant, Expression, Float, Integer, OperationType, Variable + + +def _c_i32(value: int) -> Constant: + return Constant(value, Integer.int32_t()) + + +def _c_u32(value: int) -> Constant: + return Constant(value, Integer.uint32_t()) + + +def _c_i16(value: int) -> Constant: + return Constant(value, Integer.int16_t()) + + +def _c_float(value: float) -> Constant: + return Constant(value, Float.float()) + + +@pytest.mark.parametrize( + ["expression", "constant", "result", "context"], + [ + ( + var := Variable("x", Integer.int32_t()), + con := _c_i32(4), + BinaryOperation(OperationType.multiply, [var, con], Integer.int32_t()), + nullcontext(), + ), + (_c_i32(3), _c_i32(4), _c_i32(12), nullcontext()), + (_c_float(3.0), _c_i32(4), None, pytest.raises(ValueError)), + (_c_i32(4), _c_float(3.0), None, pytest.raises(ValueError)), + (_c_i32(3), _c_i16(4), None, pytest.raises(ValueError)), + ], +) +def test_multiply_int_with_constant(expression: Expression, constant: Constant, result: Expression, context): + with context: + assert multiply_int_with_constant(expression, constant) == result + + +@pytest.mark.parametrize( + ["operation"], + [(operation,) for operation in OperationType if operation not in FOLDABLE_CONSTANTS] +) +def test_constant_fold_invalid_operations(operation: OperationType): + with pytest.raises(ValueError): + constant_fold(operation, []) + + +@pytest.mark.parametrize( + ["operation", "constants", "result", "context"], + [ + (OperationType.plus, [_c_i32(3), _c_i32(4)], _c_i32(7), nullcontext()), + (OperationType.plus, [_c_i32(2147483647), _c_i32(1)], _c_i32(-2147483648), nullcontext()), + (OperationType.plus, [_c_u32(2147483658), _c_u32(2147483652)], _c_u32(14), nullcontext()), + (OperationType.plus, [_c_u32(3), _c_i32(4)], None, pytest.raises(ValueError)), + (OperationType.plus, [_c_i32(3), _c_i16(4)], None, pytest.raises(ValueError)), + (OperationType.plus, [_c_i32(3)], None, pytest.raises(ValueError)), + (OperationType.plus, [_c_i32(3), _c_i32(3), _c_i32(3)], None, pytest.raises(ValueError)), + + (OperationType.minus, [_c_i32(3), _c_i32(4)], _c_i32(-1), nullcontext()), + (OperationType.minus, [_c_i32(-2147483648), _c_i32(1)], _c_i32(2147483647), nullcontext()), + (OperationType.minus, [_c_u32(3), _c_u32(4)], _c_u32(4294967295), nullcontext()), + (OperationType.minus, [_c_u32(3), _c_i32(4)], None, pytest.raises(ValueError)), + (OperationType.minus, [_c_i32(3), _c_i16(4)], None, pytest.raises(ValueError)), + (OperationType.minus, [_c_i32(3)], None, pytest.raises(ValueError)), + (OperationType.minus, [_c_i32(3), _c_i32(3), _c_i32(3)], None, pytest.raises(ValueError)), + + (OperationType.multiply, [_c_i32(3), _c_i32(4)], _c_i32(12), nullcontext()), + (OperationType.multiply, [_c_i32(-1073741824), _c_i32(2)], _c_i32(-2147483648), nullcontext()), + (OperationType.multiply, [_c_u32(3221225472), _c_u32(2)], _c_u32(2147483648), nullcontext()), + (OperationType.multiply, [_c_u32(3), _c_i32(4)], None, pytest.raises(ValueError)), + (OperationType.multiply, [_c_i32(3), _c_i16(4)], None, pytest.raises(ValueError)), + (OperationType.multiply, [_c_i32(3)], None, pytest.raises(ValueError)), + (OperationType.multiply, [_c_i32(3), _c_i32(3), _c_i32(3)], None, pytest.raises(ValueError)), + + (OperationType.multiply_us, [_c_i32(3), _c_i32(4)], _c_i32(12), nullcontext()), + (OperationType.multiply_us, [_c_i32(-1073741824), _c_i32(2)], _c_i32(-2147483648), nullcontext()), + (OperationType.multiply_us, [_c_u32(3221225472), _c_u32(2)], _c_u32(2147483648), nullcontext()), + (OperationType.multiply_us, [_c_u32(3), _c_i32(4)], None, pytest.raises(ValueError)), + (OperationType.multiply_us, [_c_i32(3), _c_i16(4)], None, pytest.raises(ValueError)), + (OperationType.multiply_us, [_c_i32(3)], None, pytest.raises(ValueError)), + (OperationType.multiply_us, [_c_i32(3), _c_i32(3), _c_i32(3)], None, pytest.raises(ValueError)), + + (OperationType.divide, [_c_i32(12), _c_i32(4)], _c_i32(3), nullcontext()), + (OperationType.divide, [_c_i32(-2147483648), _c_i32(2)], _c_i32(-1073741824), nullcontext()), + (OperationType.divide, [_c_u32(3), _c_i32(4)], None, pytest.raises(ValueError)), + (OperationType.divide, [_c_i32(3), _c_i16(4)], None, pytest.raises(ValueError)), + (OperationType.divide, [_c_i32(3)], None, pytest.raises(ValueError)), + (OperationType.divide, [_c_i32(3), _c_i32(3), _c_i32(3)], None, pytest.raises(ValueError)), + + (OperationType.divide_us, [_c_i32(12), _c_i32(4)], _c_i32(3), nullcontext()), + (OperationType.divide_us, [_c_i32(-2147483648), _c_i32(2)], _c_i32(1073741824), nullcontext()), + (OperationType.divide_us, [_c_u32(3), _c_i32(4)], None, pytest.raises(ValueError)), + (OperationType.divide_us, [_c_i32(3), _c_i16(4)], None, pytest.raises(ValueError)), + (OperationType.divide_us, [_c_i32(3)], None, pytest.raises(ValueError)), + (OperationType.divide_us, [_c_i32(3), _c_i32(3), _c_i32(3)], None, pytest.raises(ValueError)), + + (OperationType.negate, [_c_i32(3)], _c_i32(-3), nullcontext()), + (OperationType.negate, [_c_i32(-2147483648)], _c_i32(-2147483648), nullcontext()), + (OperationType.negate, [], None, pytest.raises(ValueError)), + (OperationType.negate, [_c_i32(3), _c_i32(3)], None, pytest.raises(ValueError)), + + (OperationType.left_shift, [_c_i32(3), _c_i32(4)], _c_i32(48), nullcontext()), + (OperationType.left_shift, [_c_i32(1073741824), _c_i32(1)], _c_i32(-2147483648), nullcontext()), + (OperationType.left_shift, [_c_u32(1073741824), _c_u32(1)], _c_u32(2147483648), nullcontext()), + (OperationType.left_shift, [_c_i32(3)], None, pytest.raises(ValueError)), + (OperationType.left_shift, [_c_i32(3), _c_i32(3), _c_i32(3)], None, pytest.raises(ValueError)), + + (OperationType.right_shift, [_c_i32(32), _c_i32(4)], _c_i32(2), nullcontext()), + (OperationType.right_shift, [_c_i32(-2147483648), _c_i32(1)], _c_i32(-1073741824), nullcontext()), + (OperationType.right_shift, [_c_u32(2147483648), _c_u32(1)], _c_u32(1073741824), nullcontext()), + (OperationType.right_shift, [_c_i32(3)], None, pytest.raises(ValueError)), + (OperationType.right_shift, [_c_i32(3), _c_i32(3), _c_i32(3)], None, pytest.raises(ValueError)), + + (OperationType.right_shift_us, [_c_i32(32), _c_i32(4)], _c_i32(2), nullcontext()), + (OperationType.right_shift_us, [_c_i32(-2147483648), _c_i32(1)], _c_i32(1073741824), nullcontext()), + (OperationType.right_shift_us, [_c_u32(2147483648), _c_u32(1)], _c_u32(1073741824), nullcontext()), + (OperationType.right_shift_us, [_c_i32(3)], None, pytest.raises(ValueError)), + (OperationType.right_shift_us, [_c_i32(3), _c_i32(3), _c_i32(3)], None, pytest.raises(ValueError)), + + (OperationType.bitwise_or, [_c_i32(85), _c_i32(34)], _c_i32(119), nullcontext()), + (OperationType.bitwise_or, [_c_i32(-2147483648), _c_i32(1)], _c_i32(-2147483647), nullcontext()), + (OperationType.bitwise_or, [_c_u32(2147483648), _c_u32(1)], _c_u32(2147483649), nullcontext()), + (OperationType.bitwise_or, [_c_u32(3), _c_i32(4)], None, pytest.raises(ValueError)), + (OperationType.bitwise_or, [_c_i32(3), _c_i16(4)], None, pytest.raises(ValueError)), + (OperationType.bitwise_or, [_c_i32(3)], None, pytest.raises(ValueError)), + (OperationType.bitwise_or, [_c_i32(3), _c_i32(3), _c_i32(3)], None, pytest.raises(ValueError)), + + (OperationType.bitwise_and, [_c_i32(85), _c_i32(51)], _c_i32(17), nullcontext()), + (OperationType.bitwise_and, [_c_i32(-2147483647), _c_i32(3)], _c_i32(1), nullcontext()), + (OperationType.bitwise_and, [_c_u32(2147483649), _c_u32(3)], _c_u32(1), nullcontext()), + (OperationType.bitwise_and, [_c_u32(3), _c_i32(4)], None, pytest.raises(ValueError)), + (OperationType.bitwise_and, [_c_i32(3), _c_i16(4)], None, pytest.raises(ValueError)), + (OperationType.bitwise_and, [_c_i32(3)], None, pytest.raises(ValueError)), + (OperationType.bitwise_and, [_c_i32(3), _c_i32(3), _c_i32(3)], None, pytest.raises(ValueError)), + + (OperationType.bitwise_xor, [_c_i32(85), _c_i32(51)], _c_i32(102), nullcontext()), + (OperationType.bitwise_xor, [_c_i32(-2147483647), _c_i32(-2147483646)], _c_i32(3), nullcontext()), + (OperationType.bitwise_xor, [_c_u32(2147483649), _c_u32(2147483650)], _c_u32(3), nullcontext()), + (OperationType.bitwise_xor, [_c_u32(3), _c_i32(4)], None, pytest.raises(ValueError)), + (OperationType.bitwise_xor, [_c_i32(3), _c_i16(4)], None, pytest.raises(ValueError)), + (OperationType.bitwise_xor, [_c_i32(3)], None, pytest.raises(ValueError)), + (OperationType.bitwise_xor, [_c_i32(3), _c_i32(3), _c_i32(3)], None, pytest.raises(ValueError)), + + (OperationType.bitwise_not, [_c_i32(6)], _c_i32(-7), nullcontext()), + (OperationType.bitwise_not, [_c_i32(-2147483648)], _c_i32(2147483647), nullcontext()), + (OperationType.bitwise_not, [_c_u32(2147483648)], _c_u32(2147483647), nullcontext()), + (OperationType.bitwise_not, [], None, pytest.raises(ValueError)), + (OperationType.bitwise_not, [_c_i32(3), _c_i32(3)], None, pytest.raises(ValueError)), + ] +) +def test_constant_fold(operation: OperationType, constants: list[Constant], result: Constant, context): + with context: + assert constant_fold(operation, constants) == result