diff --git a/decompiler/pipeline/controlflowanalysis/__init__.py b/decompiler/pipeline/controlflowanalysis/__init__.py index ebbc4b673..a0cd40244 100644 --- a/decompiler/pipeline/controlflowanalysis/__init__.py +++ b/decompiler/pipeline/controlflowanalysis/__init__.py @@ -1,4 +1,4 @@ -from .expression_simplification import ExpressionSimplification +from .expression_simplification.stages import ExpressionSimplificationAst, ExpressionSimplificationCfg from .instruction_length_handler import InstructionLengthHandler from .readability_based_refinement import ReadabilityBasedRefinement from .variable_name_generation import VariableNameGeneration diff --git a/decompiler/pipeline/controlflowanalysis/expression_simplification.py b/decompiler/pipeline/controlflowanalysis/expression_simplification.py deleted file mode 100644 index 5064e58a3..000000000 --- a/decompiler/pipeline/controlflowanalysis/expression_simplification.py +++ /dev/null @@ -1,142 +0,0 @@ -"""Module implementing basic simplifications for expressions.""" -from typing import Optional - -from decompiler.pipeline.stage import PipelineStage -from decompiler.structures.ast.ast_nodes import CodeNode -from decompiler.structures.pseudo.expressions import Constant, Expression -from decompiler.structures.pseudo.instructions import Instruction -from decompiler.structures.pseudo.operations import BinaryOperation, Operation, OperationType, UnaryOperation -from decompiler.structures.pseudo.typing import Float, Integer -from decompiler.task import DecompilerTask - - -class ExpressionSimplification(PipelineStage): - """The ExpressionSimplification makes various simplifications to expressions on the AST, like a + 0 = a.""" - - name = "expression-simplification" - - def __init__(self): - self.HANDLERS = { - OperationType.plus: self._simplify_addition, - OperationType.minus: self._simplify_subtraction, - OperationType.multiply: self._simplify_multiplication, - OperationType.divide: self._simplify_division, - OperationType.divide_us: self._simplify_division, - OperationType.divide_float: self._simplify_division, - OperationType.dereference: self._simplify_dereference, - } - - def run(self, task: DecompilerTask): - """Run the task expression simplification on each instruction of the AST.""" - if task.syntax_tree is None: - for instruction in task.graph.instructions: - self.simplify(instruction) - else: - for node in task.syntax_tree.topological_order(): - if not isinstance(node, CodeNode): - continue - for instruction in node.instructions: - self.simplify(instruction) - - def simplify(self, instruction: Instruction): - """Simplify all subexpressions of the given instruction recursively.""" - todo = list(instruction) - while todo and (expression := todo.pop()): - if self.simplify_expression(expression, instruction): - todo = list(instruction) - else: - todo.extend(expression) - - def simplify_expression(self, expression: Expression, parent: Instruction) -> Optional[Expression]: - """Simplify the given instruction utilizing the registered OperationType handlers.""" - if isinstance(expression, Operation) and expression.operation in self.HANDLERS: - if simplified := self.HANDLERS[expression.operation](expression): - parent.substitute(expression, simplified) - return simplified - - def _simplify_addition(self, expression: BinaryOperation) -> Optional[Expression]: - """ - Simplifies the given addition in the given instruction. - - -> Simplifies a+0, 0+a, a-0 and -0 + a to a - """ - if any(self.is_zero_constant(zero := op) for op in expression.operands): - return self.get_other_operand(expression, zero).copy() - - def _simplify_subtraction(self, expression: BinaryOperation) -> Optional[Expression]: - """ - Simplifies the given subtraction in the given instruction. - - -> Simplifies a-0, a-(-0) to a - -> Simplifies 0-a, -0-a to -a - """ - if self.is_zero_constant(expression.operands[1]): - return expression.operands[0].copy() - if self.is_zero_constant(expression.operands[0]): - return self.negate_expression(expression.operands[1]) - - def _simplify_multiplication(self, expression: BinaryOperation) -> Optional[Expression]: - """ - Simplifies the given multiplication in the given instruction. - - -> Simplifies a*0, 0*a, a*(-0) and (-0) * a to 0 - -> Simplifies a*1, 1*a to a - -> Simplifies a*(-1), (-1)*a to -a - """ - if any(self.is_zero_constant(zero := op) for op in expression.operands): - return zero.copy() - if any(self.is_one_constant(one := op) for op in expression.operands): - return self.get_other_operand(expression, one).copy() - if any(self.is_minus_one_constant(minus_one := op) for op in expression.operands): - return self.negate_expression(self.get_other_operand(expression, minus_one)) - - def _simplify_division(self, expression: BinaryOperation) -> Optional[Expression]: - """ - Simplifies the given division in the given instruction. - - -> Simplifies a/1 to a and a/(-1) to -a - """ - if self.is_one_constant(expression.operands[1]): - return expression.operands[0].copy() - if self.is_minus_one_constant(expression.operands[1]): - return self.negate_expression(expression.operands[0]) - - def _simplify_dereference(self, expression: UnaryOperation) -> Optional[Expression]: - """ - Simplifies dereference expression with nested address-of expressions. - - -> Simplifies *(&(x)) to x - """ - if isinstance(expression.operand, UnaryOperation) and expression.operand.operation == OperationType.address: - return expression.operand.operand.copy() - - @staticmethod - def is_zero_constant(expression: Expression) -> bool: - """Checks whether the given expression is 0.""" - return isinstance(expression, Constant) and expression.value == 0 - - @staticmethod - def is_one_constant(expression: Expression) -> bool: - """Checks whether the given expression is 1.""" - return isinstance(expression, Constant) and expression.value == 1 - - @staticmethod - def is_minus_one_constant(expression: Expression) -> bool: - """Checks whether the given expression is -1.""" - return isinstance(expression, Constant) and expression.value == -1 - - @staticmethod - def negate_expression(expression: Expression) -> Expression: - """Negate the given expression and return it.""" - match expression: - case Constant(value=0): return expression - case UnaryOperation(operation=OperationType.negate): return expression.operand - case Constant(type=Integer(is_signed=True) | Float()): return Constant(-expression.value, expression.type) - case _: return UnaryOperation(OperationType.negate, [expression]) - - @staticmethod - def get_other_operand(binary_operation: BinaryOperation, expression: Expression) -> Expression: - """Returns the operand that is not equal to expression.""" - if binary_operation.operands[0] == expression: - return binary_operation.operands[1] - return binary_operation.operands[0] diff --git a/decompiler/pipeline/controlflowanalysis/expression_simplification/constant_folding.py b/decompiler/pipeline/controlflowanalysis/expression_simplification/constant_folding.py new file mode 100644 index 000000000..81c87498c --- /dev/null +++ b/decompiler/pipeline/controlflowanalysis/expression_simplification/constant_folding.py @@ -0,0 +1,145 @@ +import operator +from functools import partial +from typing import Callable, Optional + +from decompiler.structures.pseudo import Constant, Integer, OperationType + + +def constant_fold(operation: OperationType, constants: list[Constant]) -> Constant: + """ + Fold operation with constants as operands. + + :param operation: The operation. + :param constants: All constant operands of the operation. + :return: A constant representing the result of the operation. + """ + + if operation not in _OPERATION_TO_FOLD_FUNCTION: + raise ValueError(f"Constant folding not implemented for operation '{operation}'.") + + return _OPERATION_TO_FOLD_FUNCTION[operation](constants) + + +def _constant_fold_arithmetic_binary( + constants: list[Constant], + fun: Callable[[int, int], int], + norm_sign: Optional[bool] = None +) -> Constant: + """ + Fold an arithmetic binary operation with constants as operands. + + :param constants: A list of exactly 2 constant operands. + :param fun: The binary function to perform on the constants. + :param norm_sign: Optional boolean flag to indicate if/how to normalize the input constants to 'fun': + - None (default): no normalization + - True: normalize inputs, interpreted as signed values + - False: normalize inputs, interpreted as unsigned values + :return: A constant representing the result of the operation. + """ + + if len(constants) != 2: + raise ValueError(f"Expected exactly 2 constants to fold, got {len(constants)}.") + if not all(constant.type == constants[0].type for constant in constants): + raise ValueError(f"Can not fold constants with different types: {(constant.type for constant in constants)}") + if not all(isinstance(constant.type, Integer) for constant in constants): + raise ValueError(f"All constants must be integers, got {list(constant.type for constant in constants)}.") + + left, right = constants + + left_value = left.value + right_value = right.value + if norm_sign is not None: + left_value = normalize_int(left_value, left.type.size, norm_sign) + right_value = normalize_int(right_value, right.type.size, norm_sign) + + return Constant( + normalize_int(fun(left_value, right_value), left.type.size, left.type.signed), + left.type + ) + + +def _constant_fold_arithmetic_unary(constants: list[Constant], fun: Callable[[int], int]) -> Constant: + """ + Fold an arithmetic unary operation with a constant operand. + + :param constants: A list containing a single constant operand. + :param fun: The unary function to perform on the constant. + :return: A constant representing the result of the operation. + """ + + if len(constants) != 1: + raise ValueError("Expected exactly 1 constant to fold") + if not isinstance(constants[0].type, Integer): + raise ValueError(f"Constant must be of type integer: {constants[0].type}") + + return Constant(normalize_int(fun(constants[0].value), constants[0].type.size, constants[0].type.signed), constants[0].type) + + +def _constant_fold_shift(constants: list[Constant], fun: Callable[[int, int], int], signed: bool) -> Constant: + """ + Fold a shift operation with constants as operands. + + :param constants: A list of exactly 2 constant operands. + :param fun: The shift function to perform on the constants. + :param signed: Boolean flag indicating whether the shift is signed. + This is used to normalize the sign of the input constant to simulate unsigned shifts. + :return: A constant representing the result of the operation. + """ + + if len(constants) != 2: + raise ValueError("Expected exactly 2 constants to fold") + if not all(isinstance(constant.type, Integer) for constant in constants): + raise ValueError("All constants must be integers") + + left, right = constants + + shifted_value = fun( + normalize_int(left.value, left.type.size, left.type.signed and signed), + right.value + ) + return Constant( + normalize_int(shifted_value, left.type.size, left.type.signed), + left.type + ) + + +def normalize_int(v: int, size: int, signed: bool) -> int: + """ + Normalizes an integer value to a specific size and signedness. + + This function takes an integer value 'v' and normalizes it to fit within + the specified 'size' in bits by discarding overflowing bits. If 'signed' is + true, the value is treated as a signed integer, i.e. interpreted as a two's complement. + Therefore the return value will be negative iff 'signed' is true and the most-significant bit is set. + + :param v: The value to be normalized. + :param size: The desired bit size for the normalized integer. + :param signed: True if the integer should be treated as signed. + :return: The normalized integer value. + """ + value = v & ((1 << size) - 1) + if signed and value & (1 << (size - 1)): + return value - (1 << size) + else: + return value + + +_OPERATION_TO_FOLD_FUNCTION: dict[OperationType, Callable[[list[Constant]], Constant]] = { + OperationType.minus: partial(_constant_fold_arithmetic_binary, fun=operator.sub), + OperationType.plus: partial(_constant_fold_arithmetic_binary, fun=operator.add), + OperationType.multiply: partial(_constant_fold_arithmetic_binary, fun=operator.mul, norm_sign=True), + OperationType.multiply_us: partial(_constant_fold_arithmetic_binary, fun=operator.mul, norm_sign=False), + OperationType.divide: partial(_constant_fold_arithmetic_binary, fun=operator.floordiv, norm_sign=True), + OperationType.divide_us: partial(_constant_fold_arithmetic_binary, fun=operator.floordiv, norm_sign=False), + OperationType.negate: partial(_constant_fold_arithmetic_unary, fun=operator.neg), + OperationType.left_shift: partial(_constant_fold_shift, fun=operator.lshift, signed=True), + OperationType.right_shift: partial(_constant_fold_shift, fun=operator.rshift, signed=True), + OperationType.right_shift_us: partial(_constant_fold_shift, fun=operator.rshift, signed=False), + OperationType.bitwise_or: partial(_constant_fold_arithmetic_binary, fun=operator.or_), + OperationType.bitwise_and: partial(_constant_fold_arithmetic_binary, fun=operator.and_), + OperationType.bitwise_xor: partial(_constant_fold_arithmetic_binary, fun=operator.xor), + OperationType.bitwise_not: partial(_constant_fold_arithmetic_unary, fun=operator.inv), +} + + +FOLDABLE_OPERATIONS = _OPERATION_TO_FOLD_FUNCTION.keys() diff --git a/decompiler/pipeline/controlflowanalysis/expression_simplification/rules/collapse_add_neg.py b/decompiler/pipeline/controlflowanalysis/expression_simplification/rules/collapse_add_neg.py new file mode 100644 index 000000000..c978991fc --- /dev/null +++ b/decompiler/pipeline/controlflowanalysis/expression_simplification/rules/collapse_add_neg.py @@ -0,0 +1,30 @@ +from decompiler.pipeline.controlflowanalysis.expression_simplification.rules.rule import SimplificationRule +from decompiler.structures.pseudo import BinaryOperation, Expression, Operation, OperationType, UnaryOperation + + +class CollapseAddNeg(SimplificationRule): + """ + Simplifies additions/subtraction with negated expression. + + - `e0 + -(e1) -> e0 - e1` + - `e0 - -(e1) -> e0 + e1` + """ + + def apply(self, operation: Operation) -> list[tuple[Expression, Expression]]: + if operation.operation not in [OperationType.plus, OperationType.minus]: + return [] + if not isinstance(operation, BinaryOperation): + raise TypeError(f"Expected BinaryOperation, got {type(operation)}") + + right = operation.right + if not isinstance(right, UnaryOperation) or right.operation != OperationType.negate: + return [] + + return [( + operation, + BinaryOperation( + OperationType.minus if operation.operation == OperationType.plus else OperationType.plus, + [operation.left, right.operand], + operation.type + ) + )] diff --git a/decompiler/pipeline/controlflowanalysis/expression_simplification/rules/collapse_constants.py b/decompiler/pipeline/controlflowanalysis/expression_simplification/rules/collapse_constants.py new file mode 100644 index 000000000..03db3c00d --- /dev/null +++ b/decompiler/pipeline/controlflowanalysis/expression_simplification/rules/collapse_constants.py @@ -0,0 +1,20 @@ +from decompiler.pipeline.controlflowanalysis.expression_simplification.constant_folding import FOLDABLE_OPERATIONS, constant_fold +from decompiler.pipeline.controlflowanalysis.expression_simplification.rules.rule import SimplificationRule +from decompiler.structures.pseudo import Constant, Expression, Operation + + +class CollapseConstants(SimplificationRule): + """ + Fold operations with only constants as operands: + """ + + def apply(self, operation: Operation) -> list[tuple[Expression, Expression]]: + if not all(isinstance(o, Constant) for o in operation.operands): + return [] + if operation.operation not in FOLDABLE_OPERATIONS: + return [] + + return [( + operation, + constant_fold(operation.operation, operation.operands) + )] diff --git a/decompiler/pipeline/controlflowanalysis/expression_simplification/rules/collapse_nested_constants.py b/decompiler/pipeline/controlflowanalysis/expression_simplification/rules/collapse_nested_constants.py new file mode 100644 index 000000000..f3559daae --- /dev/null +++ b/decompiler/pipeline/controlflowanalysis/expression_simplification/rules/collapse_nested_constants.py @@ -0,0 +1,81 @@ +from functools import reduce +from typing import Iterator + +from decompiler.pipeline.controlflowanalysis.expression_simplification.constant_folding import constant_fold +from decompiler.pipeline.controlflowanalysis.expression_simplification.rules.rule import SimplificationRule +from decompiler.structures.pseudo import Constant, Expression, Operation, OperationType, Type +from decompiler.structures.pseudo.operations import COMMUTATIVE_OPERATIONS + + +class CollapseNestedConstants(SimplificationRule): + """ + This rule walks the dafaflow tree and collects and folds constants in commutative operations. + The first constant of the tree is replaced with the folded result and all remaining constants are replaced with the identity. + This stage exploits associativity and is the only stage doing so. Therefore, it cannot be replaced by a combination of `TermOrder` and `CollapseConstants`. + """ + def apply(self, operation: Operation) -> list[tuple[Expression, Expression]]: + if operation.operation not in COMMUTATIVE_OPERATIONS: + return [] + if not isinstance(operation, Operation): + raise TypeError(f"Expected Operation, got {type(operation)}") + + constants = list(_collect_constants(operation)) + if len(constants) <= 1: + return [] + + first, *rest = constants + + folded_constant = reduce( + lambda c0, c1: constant_fold(operation.operation, [c0, c1]), + rest, + first + ) + + identity_constant = _identity_constant(operation.operation, operation.type) + return [ + (first, folded_constant), + *((constant, identity_constant) for constant in rest) + ] + + +def _collect_constants(operation: Operation) -> Iterator[Constant]: + """ + Collects constants of potentially multiple nested commutative operations of the same type. + + This function traverses the subtree rooted at the provided operation and collects + all constants that belong to operations with the same operation type as the root operation. + The subtree includes only operations that have matching operation types. + """ + + operation_type = operation.operation + operand_type = operation.type + + context_stack: list[Operation] = [operation] + while context_stack: + current_operation = context_stack.pop() + + for i, operand in enumerate(current_operation.operands): + if operand.type != operand_type: + continue + + if isinstance(operand, Operation): + if operand.operation == operation_type: + context_stack.append(operand) + continue + elif isinstance(operand, Constant) and _identity_constant(operation_type, operand_type).value != operand.value: + yield operand + + +def _identity_constant(operation: OperationType, var_type: Type) -> Constant: + """ + Return a const containing the identity element for the specified operation and variable type. + """ + match operation: + case OperationType.plus | OperationType.bitwise_xor | OperationType.bitwise_or: + return Constant(0, var_type) + case OperationType.multiply | OperationType.multiply_us: + return Constant(1, var_type) + case OperationType.bitwise_and: + return constant_fold(OperationType.bitwise_not, [Constant(0, var_type)]) + case _: + raise NotImplementedError() diff --git a/decompiler/pipeline/controlflowanalysis/expression_simplification/rules/positive_constants.py b/decompiler/pipeline/controlflowanalysis/expression_simplification/rules/positive_constants.py new file mode 100644 index 000000000..6b358dd81 --- /dev/null +++ b/decompiler/pipeline/controlflowanalysis/expression_simplification/rules/positive_constants.py @@ -0,0 +1,43 @@ +from decompiler.pipeline.controlflowanalysis.expression_simplification.constant_folding import normalize_int +from decompiler.pipeline.controlflowanalysis.expression_simplification.rules.rule import SimplificationRule +from decompiler.structures.pseudo import BinaryOperation, Constant, Expression, Integer, Operation, OperationType + + +class PositiveConstants(SimplificationRule): + """ + Changes add/sub so that the right operand constant is always positive. + For unsigned arithmetic, choose the operation with the lesser constant (e.g.: V - 4294967293 -> V + 3 for 32 bit ints). + + - `V - a -> E + (-a)` when signed(a) < 0 + - `V + a -> E - (-a)` when signed(a) < 0 + """ + + def apply(self, operation: Operation) -> list[tuple[Expression, Expression]]: + if operation.operation not in (OperationType.plus, OperationType.minus): + return [] + if not isinstance(operation, BinaryOperation): + raise TypeError(f"Expected BinaryOperation, got {type(operation)}") + + right = operation.right + if not isinstance(right, Constant): + return [] + + constant_type = right.type + if not isinstance(constant_type, Integer): + return [] + + signed_normalized_constant = normalize_int(right.value, constant_type.size, True) + if signed_normalized_constant >= 0: + return [] + + neg_constant = Constant( + normalize_int(-signed_normalized_constant, constant_type.size, constant_type.signed), + constant_type + ) + return [( + operation, + BinaryOperation( + OperationType.plus if operation.operation == OperationType.minus else OperationType.minus, + [operation.left, neg_constant] + ) + )] diff --git a/decompiler/pipeline/controlflowanalysis/expression_simplification/rules/rule.py b/decompiler/pipeline/controlflowanalysis/expression_simplification/rules/rule.py new file mode 100644 index 000000000..a4f70334a --- /dev/null +++ b/decompiler/pipeline/controlflowanalysis/expression_simplification/rules/rule.py @@ -0,0 +1,20 @@ +from abc import ABC, abstractmethod + +from decompiler.structures.pseudo import Expression, Operation + + +class SimplificationRule(ABC): + """ + This class defines the interface for simplification rules that can be applied to expressions. + """ + + @abstractmethod + def apply(self, operation: Operation) -> list[tuple[Expression, Expression]]: + """ + Apply the simplification rule to the given operation. + + :param operation: The operation to which the simplification rule should be applied. + :return: A list of tuples, each containing a pair of expressions representing the original + and simplified versions resulting from applying the simplification rule to the given operation. + """ + pass diff --git a/decompiler/pipeline/controlflowanalysis/expression_simplification/rules/simplify_redundant_reference.py b/decompiler/pipeline/controlflowanalysis/expression_simplification/rules/simplify_redundant_reference.py new file mode 100644 index 000000000..9337be30b --- /dev/null +++ b/decompiler/pipeline/controlflowanalysis/expression_simplification/rules/simplify_redundant_reference.py @@ -0,0 +1,20 @@ +from decompiler.pipeline.controlflowanalysis.expression_simplification.rules.rule import SimplificationRule +from decompiler.structures.pseudo import Expression, Operation, OperationType, UnaryOperation + + +class SimplifyRedundantReference(SimplificationRule): + """ + Removes redundant nesting of referencing, immediately followed by referencing. + + `*(&(e0)) -> e0` + """ + + def apply(self, operation: Operation) -> list[tuple[Expression, Expression]]: + match operation: + case UnaryOperation( + operation=OperationType.dereference, + operand=UnaryOperation(operation=OperationType.address, operand=inner_operand) + ): + return [(operation, inner_operand)] + case _: + return [] diff --git a/decompiler/pipeline/controlflowanalysis/expression_simplification/rules/simplify_trivial_arithmetic.py b/decompiler/pipeline/controlflowanalysis/expression_simplification/rules/simplify_trivial_arithmetic.py new file mode 100644 index 000000000..07ebac6c2 --- /dev/null +++ b/decompiler/pipeline/controlflowanalysis/expression_simplification/rules/simplify_trivial_arithmetic.py @@ -0,0 +1,39 @@ +from decompiler.pipeline.controlflowanalysis.expression_simplification.rules.rule import SimplificationRule +from decompiler.structures.pseudo import BinaryOperation, Constant, Expression, Operation, OperationType, UnaryOperation + + +class SimplifyTrivialArithmetic(SimplificationRule): + """ + Simplifies trivial arithmetic: + + - `e + 0 -> e` + - `e - 0 -> e` + - `e * 0 -> 0` + - `e u* 0 -> 0` + - `e * 1 -> e` + - `e u* 1 -> e` + - `e * -1 -> -e` + - `e u* -1 -> -e` + - `e / 1 -> e` + - `e u/ 1 -> e` + - `e / -1 -> -e` + """ + + def apply(self, operation: Operation) -> list[tuple[Expression, Expression]]: + match operation: + case BinaryOperation(operation=OperationType.plus | OperationType.minus, right=Constant(value=0)): + return [(operation, operation.left)] + case BinaryOperation( + operation=OperationType.multiply | OperationType.multiply_us | OperationType.divide | OperationType.divide_us, + right=Constant(value=1), + ): + return [(operation, operation.left)] + case BinaryOperation(operation=OperationType.multiply | OperationType.multiply_us, right=Constant(value=0)): + return [(operation, Constant(0, operation.type))] + case BinaryOperation( + operation=OperationType.multiply | OperationType.multiply_us | OperationType.divide, + right=Constant(value=-1) + ): + return [(operation, UnaryOperation(OperationType.negate, [operation.left]))] + case _: + return [] diff --git a/decompiler/pipeline/controlflowanalysis/expression_simplification/rules/simplify_trivial_bit_arithmetic.py b/decompiler/pipeline/controlflowanalysis/expression_simplification/rules/simplify_trivial_bit_arithmetic.py new file mode 100644 index 000000000..5a6a62c58 --- /dev/null +++ b/decompiler/pipeline/controlflowanalysis/expression_simplification/rules/simplify_trivial_bit_arithmetic.py @@ -0,0 +1,28 @@ +from decompiler.pipeline.controlflowanalysis.expression_simplification.rules.rule import SimplificationRule +from decompiler.structures.pseudo import BinaryOperation, Constant, Expression, Operation, OperationType + + +class SimplifyTrivialBitArithmetic(SimplificationRule): + """ + Simplifies trivial bit arithmetic: + + - `e | 0 -> e` + - `e | e -> e` + - `e & 0 -> 0` + - `e & e -> e` + - `e ^ 0 -> e` + - `e ^ e -> 0` + """ + + def apply(self, operation: Operation) -> list[tuple[Expression, Expression]]: + match operation: + case BinaryOperation(operation=OperationType.bitwise_or | OperationType.bitwise_xor, right=Constant(value=0)): + return [(operation, operation.left)] + case BinaryOperation(operation=OperationType.bitwise_and, right=Constant(value=0)): + return [(operation, Constant(0, operation.type))] + case BinaryOperation(operation=OperationType.bitwise_or | OperationType.bitwise_and, left=left, right=right) if left == right: + return [(operation, operation.left)] + case BinaryOperation(operation=OperationType.bitwise_xor, left=left, right=right) if left == right: + return [(operation, Constant(0, operation.type))] + case _: + return [] diff --git a/decompiler/pipeline/controlflowanalysis/expression_simplification/rules/simplify_trivial_logic_arithmetic.py b/decompiler/pipeline/controlflowanalysis/expression_simplification/rules/simplify_trivial_logic_arithmetic.py new file mode 100644 index 000000000..08994b1cc --- /dev/null +++ b/decompiler/pipeline/controlflowanalysis/expression_simplification/rules/simplify_trivial_logic_arithmetic.py @@ -0,0 +1,26 @@ +from decompiler.pipeline.controlflowanalysis.expression_simplification.rules.rule import SimplificationRule +from decompiler.structures.pseudo import BinaryOperation, Constant, Expression, Operation, OperationType + + +class SimplifyTrivialLogicArithmetic(SimplificationRule): + """ + Simplifies trivial logic arithmetic. + + - `e || false -> e` + - `e || true -> true` + - `e && false -> false` + - `e && true -> e` + """ + + def apply(self, operation: Operation) -> list[tuple[Expression, Expression]]: + match operation: + case BinaryOperation(operation=OperationType.logical_or, right=Constant(value=0)): + return [(operation, operation.left)] + case BinaryOperation(operation=OperationType.logical_and, right=Constant(value=0)): + return [(operation, Constant(0, operation.type))] + case BinaryOperation(operation=OperationType.logical_or, right=Constant(value=value)) if value != 0: + return [(operation, Constant(1, operation.type))] + case BinaryOperation(operation=OperationType.logical_and, right=Constant(value=value)) if value != 0: + return [(operation, operation.left)] + case _: + return [] diff --git a/decompiler/pipeline/controlflowanalysis/expression_simplification/rules/simplify_trivial_shift.py b/decompiler/pipeline/controlflowanalysis/expression_simplification/rules/simplify_trivial_shift.py new file mode 100644 index 000000000..fcf1fcfa9 --- /dev/null +++ b/decompiler/pipeline/controlflowanalysis/expression_simplification/rules/simplify_trivial_shift.py @@ -0,0 +1,28 @@ +from decompiler.pipeline.controlflowanalysis.expression_simplification.rules.rule import SimplificationRule +from decompiler.structures.pseudo import BinaryOperation, Constant, Expression, Operation, OperationType + + +class SimplifyTrivialShift(SimplificationRule): + """ + Simplifies trivial shift/rotate arithmetic: + + - `e << 0 -> e` + - `e >> 0 -> e` + - `e u>> 0 -> e` + - `e lrot 0 -> e` + - `e rrot 0 -> e` + """ + + def apply(self, operation: Operation) -> list[tuple[Expression, Expression]]: + match operation: + case BinaryOperation( + operation=OperationType.left_shift + | OperationType.right_shift + | OperationType.right_shift_us + | OperationType.left_rotate + | OperationType.right_rotate, + right=Constant(value=0), + ): + return [(operation, operation.left)] + case _: + return [] diff --git a/decompiler/pipeline/controlflowanalysis/expression_simplification/rules/sub_to_add.py b/decompiler/pipeline/controlflowanalysis/expression_simplification/rules/sub_to_add.py new file mode 100644 index 000000000..9943d365d --- /dev/null +++ b/decompiler/pipeline/controlflowanalysis/expression_simplification/rules/sub_to_add.py @@ -0,0 +1,27 @@ +from decompiler.pipeline.controlflowanalysis.expression_simplification.rules.rule import SimplificationRule +from decompiler.structures.pseudo import BinaryOperation, Expression, Operation, OperationType, UnaryOperation + + +class SubToAdd(SimplificationRule): + """ + Replace subtractions with additions. + + `e0 - e1 -> e0 + (-e1)` + """ + + def apply(self, operation: Operation) -> list[tuple[Expression, Expression]]: + if operation.operation != OperationType.minus: + return [] + if not isinstance(operation, BinaryOperation): + raise TypeError(f"Expected BinaryOperation, got {type(operation)}") + + neg_op = UnaryOperation(OperationType.negate, [operation.right]) + + return [( + operation, + BinaryOperation( + OperationType.plus, + [operation.left, neg_op], + operation.type + ) + )] diff --git a/decompiler/pipeline/controlflowanalysis/expression_simplification/rules/term_order.py b/decompiler/pipeline/controlflowanalysis/expression_simplification/rules/term_order.py new file mode 100644 index 000000000..638f4d08a --- /dev/null +++ b/decompiler/pipeline/controlflowanalysis/expression_simplification/rules/term_order.py @@ -0,0 +1,28 @@ +from decompiler.pipeline.controlflowanalysis.expression_simplification.rules.rule import SimplificationRule +from decompiler.structures.pseudo import BinaryOperation, Constant, Expression, Operation +from decompiler.structures.pseudo.operations import COMMUTATIVE_OPERATIONS + + +class TermOrder(SimplificationRule): + """ + Swap constants of commutative operations to the right. + This stage is important because other stages expect constants to be on the right side. + Associativity is not exploited, i.e. nested operations of the same type are not considered. + + - `c + e -> e + c` + - `c * e -> e * c` + - `c & e -> e & c` + - `c | e -> e | c` + - `c ^ e -> e ^ c` + """ + + def apply(self, operation: Operation) -> list[tuple[Expression, Expression]]: + if operation.operation not in COMMUTATIVE_OPERATIONS: + return [] + if not isinstance(operation, BinaryOperation): + raise ValueError(f"Expected BinaryOperation, got {operation}") + + if isinstance(operation.left, Constant) and not isinstance(operation.right, Constant): + return [(operation, BinaryOperation(operation.operation, [operation.right, operation.left], operation.type, operation.tags))] + else: + return [] diff --git a/decompiler/pipeline/controlflowanalysis/expression_simplification/stages.py b/decompiler/pipeline/controlflowanalysis/expression_simplification/stages.py new file mode 100644 index 000000000..cc05a59ba --- /dev/null +++ b/decompiler/pipeline/controlflowanalysis/expression_simplification/stages.py @@ -0,0 +1,169 @@ +import logging +from abc import ABC, abstractmethod + +from decompiler.backend.cexpressiongenerator import CExpressionGenerator +from decompiler.pipeline.controlflowanalysis.expression_simplification.rules.collapse_add_neg import CollapseAddNeg +from decompiler.pipeline.controlflowanalysis.expression_simplification.rules.collapse_constants import CollapseConstants +from decompiler.pipeline.controlflowanalysis.expression_simplification.rules.collapse_nested_constants import CollapseNestedConstants +from decompiler.pipeline.controlflowanalysis.expression_simplification.rules.positive_constants import PositiveConstants +from decompiler.pipeline.controlflowanalysis.expression_simplification.rules.rule import SimplificationRule +from decompiler.pipeline.controlflowanalysis.expression_simplification.rules.simplify_redundant_reference import SimplifyRedundantReference +from decompiler.pipeline.controlflowanalysis.expression_simplification.rules.simplify_trivial_arithmetic import SimplifyTrivialArithmetic +from decompiler.pipeline.controlflowanalysis.expression_simplification.rules.simplify_trivial_bit_arithmetic import ( + SimplifyTrivialBitArithmetic, +) +from decompiler.pipeline.controlflowanalysis.expression_simplification.rules.simplify_trivial_logic_arithmetic import ( + SimplifyTrivialLogicArithmetic, +) +from decompiler.pipeline.controlflowanalysis.expression_simplification.rules.simplify_trivial_shift import SimplifyTrivialShift +from decompiler.pipeline.controlflowanalysis.expression_simplification.rules.sub_to_add import SubToAdd +from decompiler.pipeline.controlflowanalysis.expression_simplification.rules.term_order import TermOrder +from decompiler.pipeline.stage import PipelineStage +from decompiler.structures.ast.ast_nodes import CodeNode +from decompiler.structures.pseudo import Instruction, Operation +from decompiler.structures.visitors.substitute_visitor import SubstituteVisitor +from decompiler.task import DecompilerTask + + +class _ExpressionSimplificationBase(PipelineStage, ABC): + + def run(self, task: DecompilerTask): + max_iterations = task.options.getint("expression-simplification.max_iterations") + self._simplify_instructions(self._get_instructions(task), max_iterations) + + @abstractmethod + def _get_instructions(self, task: DecompilerTask) -> list[Instruction]: + pass + + @classmethod + def _simplify_instructions(cls, instructions: list[Instruction], max_iterations: int): + rule_sets = [ + ("pre-rules", _pre_rules), + ("rules", _rules), + ("post-rules", _post_rules) + ] + for rule_name, rule_set in rule_sets: + # max_iterations is counted per rule_set + iteration_count = cls._simplify_instructions_with_rule_set(instructions, rule_set, max_iterations) + if iteration_count <= max_iterations: + logging.info(f"Expression simplification took {iteration_count} iterations for {rule_name}") + else: + logging.warning(f"Exceeded max iteration count for {rule_name}") + + @classmethod + def _simplify_instructions_with_rule_set( + cls, + instructions: list[Instruction], + rule_set: list[SimplificationRule], + max_iterations: int + ) -> int: + iteration_count = 0 + + changes = True + while changes: + changes = False + + for rule in rule_set: + for instruction in instructions: + additional_iterations = cls._simplify_instruction_with_rule(instruction, rule, max_iterations - iteration_count) + if additional_iterations > 0: + changes = True + + iteration_count += additional_iterations + if iteration_count > max_iterations: + return iteration_count + + return iteration_count + + @classmethod + def _simplify_instruction_with_rule( + cls, + instruction: Instruction, + rule: SimplificationRule, + max_iterations: int + ) -> int: + iteration_count = 0 + for expression in instruction.subexpressions(): + while True: + # NOTE: By breaking out of the inner endless loop, we just continue with the next subexpression. + if expression is None: + break + if not isinstance(expression, Operation): + break + + substitutions = rule.apply(expression) + if not substitutions: + break + + iteration_count += 1 + + if iteration_count > max_iterations: + logging.warning("Took to many iterations for rule set to finish") + return iteration_count + + for i, (replacee, replacement) in enumerate(substitutions): + expression_gen = CExpressionGenerator() + logging.debug( + f"[{rule.__class__.__name__}] {i}. Substituting: '{replacee.accept(expression_gen)}'" + f" with '{replacement.accept(expression_gen)}' in '{expression.accept(expression_gen)}'" + ) + instruction_repl = instruction.accept(SubstituteVisitor.identity(replacee, replacement)) + + # instruction.accept should never replace itself and consequently return none, because replacee is + # of type Expression and doesn't permit Instruction + assert instruction_repl is None + + # This is modifying the expression tree, while we are iterating over it. + # This works because we are iterating depth first and only + # modifying already visited nodes. + + # if expression got replaced, we need to update the reference + if replacee == expression: + expression = replacement + + return iteration_count + + +class ExpressionSimplificationCfg(_ExpressionSimplificationBase): + """ + Pipeline stage that simplifies cfg expressions by applying a set of simplification rules. + """ + + name = "expression-simplification-cfg" + + def _get_instructions(self, task: DecompilerTask) -> list[Instruction]: + return list(task.graph.instructions) + + +class ExpressionSimplificationAst(_ExpressionSimplificationBase): + """ + Pipeline stage that simplifies ast expressions by applying a set of simplification rules. + """ + + name = "expression-simplification-ast" + + def _get_instructions(self, task: DecompilerTask) -> list[Instruction]: + instructions = [] + for node in task.syntax_tree.topological_order(): + if isinstance(node, CodeNode): + instructions.extend(node.instructions) + + return instructions + + +_pre_rules: list[SimplificationRule] = [] +_rules: list[SimplificationRule] = [ + TermOrder(), + SubToAdd(), + SimplifyRedundantReference(), + SimplifyTrivialArithmetic(), + SimplifyTrivialBitArithmetic(), + SimplifyTrivialLogicArithmetic(), + SimplifyTrivialShift(), + CollapseConstants(), + CollapseNestedConstants(), +] +_post_rules: list[SimplificationRule] = [ + CollapseAddNeg(), + PositiveConstants() +] diff --git a/decompiler/pipeline/default.py b/decompiler/pipeline/default.py index 85b0e6389..7266e0596 100644 --- a/decompiler/pipeline/default.py +++ b/decompiler/pipeline/default.py @@ -1,7 +1,8 @@ """Module defining the available pipelines.""" from decompiler.pipeline.controlflowanalysis import ( - ExpressionSimplification, + ExpressionSimplificationAst, + ExpressionSimplificationCfg, InstructionLengthHandler, ReadabilityBasedRefinement, VariableNameGeneration, @@ -36,10 +37,15 @@ IdentityElimination, CommonSubexpressionElimination, ArrayAccessDetection, - ExpressionSimplification, + ExpressionSimplificationCfg, DeadComponentPruner, GraphExpressionFolding, EdgePruner, ] -AST_STAGES = [ReadabilityBasedRefinement, ExpressionSimplification, InstructionLengthHandler, VariableNameGeneration] +AST_STAGES = [ + ReadabilityBasedRefinement, + ExpressionSimplificationAst, + InstructionLengthHandler, + VariableNameGeneration +] diff --git a/decompiler/structures/pseudo/operations.py b/decompiler/structures/pseudo/operations.py index 28d6ab9a6..c21002068 100644 --- a/decompiler/structures/pseudo/operations.py +++ b/decompiler/structures/pseudo/operations.py @@ -153,9 +153,14 @@ class OperationType(Enum): COMMUTATIVE_OPERATIONS = { OperationType.plus, OperationType.multiply, + OperationType.multiply_us, OperationType.bitwise_and, OperationType.bitwise_xor, OperationType.bitwise_or, + OperationType.logical_or, + OperationType.logical_and, + OperationType.equal, + OperationType.not_equal } NON_COMPOUNDABLE_OPERATIONS = { @@ -164,6 +169,10 @@ class OperationType(Enum): OperationType.left_rotate, OperationType.left_rotate_carry, OperationType.power, + OperationType.logical_or, + OperationType.logical_and, + OperationType.equal, + OperationType.not_equal } diff --git a/decompiler/structures/visitors/substitute_visitor.py b/decompiler/structures/visitors/substitute_visitor.py new file mode 100644 index 000000000..b4d7646e2 --- /dev/null +++ b/decompiler/structures/visitors/substitute_visitor.py @@ -0,0 +1,220 @@ +from typing import Callable, Optional, TypeVar, Union + +from decompiler.structures.pseudo import ( + Assignment, + BinaryOperation, + Break, + Call, + Comment, + Condition, + Constant, + Continue, + DataflowObject, + Expression, + FunctionSymbol, + GenericBranch, + ImportedFunctionSymbol, + IntrinsicSymbol, + ListOperation, + MemPhi, + Operation, + Phi, + RegisterPair, + Return, + TernaryExpression, + UnaryOperation, + UnknownExpression, + Variable, +) +from decompiler.structures.pseudo.operations import ArrayInfo +from decompiler.structures.visitors.interfaces import DataflowObjectVisitorInterface + +T = TypeVar("T", bound=DataflowObject) + + +def _assert_type(obj: DataflowObject, t: type[T]) -> T: + if not isinstance(obj, t): + raise TypeError() + else: + return obj + + +class SubstituteVisitor(DataflowObjectVisitorInterface[Optional[DataflowObject]]): + """ + A visitor class for performing substitutions in a dataflow tree. + + This class allows you to create instances that can traverse a dataflow graph and perform substitutions + based on a provided mapping function. The mapping function is applied to each visited node in the graph, + and if the mapping function returns a non-None value, the node is replaced with the returned value. + + Note: + - Modifications to the dataflow tree happen in place. Only if the whole node that is being visited is replaced, + the visit method returns the replacement and not none. + - Even if a visit method returns a replacement, modifications could have happened to the original dataflow tree. + - Care should be taken when using this visitor, as substitution can leave the dataflow tree in an invalid state. + For example a dereference UnaryOperation could be updated without the changes being reflected in its ArrayInfo. + Same with changes to Phi and its origin_block + """ + + @classmethod + def identity(cls, replacee: DataflowObject, replacement: DataflowObject) -> "SubstituteVisitor": + """ + Create a SubstituteVisitor instance for identity-based substitution. + + This class method creates a SubstituteVisitor instance that replaces nodes equal to the 'replacee' + parameter with the 'replacement' parameter based on identity comparison (is). + + Note: + While SubstituteVisitor.equality() creates copies of the specified replacement, this one does not! + Be careful as to not introduce the same dataflow object twice into the dataflow tree. + + :param replacee: The object to be replaced based on identity. + :param replacement: The object to replace 'replacee' with. + :return: A SubstituteVisitor instance for identity-based substitution. + """ + + return SubstituteVisitor(lambda o: replacement if o is replacee else None) + + @classmethod + def equality(cls, replacee: DataflowObject, replacement: DataflowObject) -> "SubstituteVisitor": + """ + Create a SubstituteVisitor instance for equality-based substitution. + + This class method creates a SubstituteVisitor instance that replaces nodes equal to the 'replacee' + parameter with the 'replacement' parameter based on equality comparison (==). + + Note: + This visitor creates copies of the specified replacement when substituting. + + :param replacee: The object to be replaced based on equality. + :param replacement: The object to replace 'replacee' with. + :return: A SubstituteVisitor instance for equality-based substitution. + """ + + return SubstituteVisitor(lambda o: replacement.copy() if o == replacee else None) + + def __init__(self, mapper: Callable[[DataflowObject], Optional[DataflowObject]]) -> None: + """ + Initialize a SubstituteVisitor instance. + + :param mapper: A callable object that takes a DataflowObject as input and returns an Optional[DataflowObject]. + This function is used to determine replacements for visited nodes. + """ + + self._mapper = mapper + + def visit_unknown_expression(self, expr: UnknownExpression) -> Optional[DataflowObject]: + return self._mapper(expr) + + def visit_constant(self, expr: Constant) -> Optional[DataflowObject]: + return self._mapper(expr) + + def visit_variable(self, expr: Variable) -> Optional[DataflowObject]: + return self._mapper(expr) + + def visit_register_pair(self, expr: RegisterPair) -> Optional[DataflowObject]: + if (low_replacement := expr.low.accept(self)) is not None: + expr._low = _assert_type(low_replacement, Variable) + + if (high_replacement := expr.high.accept(self)) is not None: + expr._high = _assert_type(high_replacement, Variable) + + return self._mapper(expr) + + def _visit_operation(self, op: Operation) -> Optional[DataflowObject]: + """Base visit function used for all operation related visit functions""" + for index, operand in enumerate(op.operands): + if (repl := operand.accept(self)) is not None: + op.operands[index] = _assert_type(repl, Expression) + + return self._mapper(op) + + def visit_list_operation(self, op: ListOperation) -> Optional[DataflowObject]: + return self._visit_operation(op) + + def _substitute_array_info(self, array_info: ArrayInfo): + if (base_replacement := array_info.base.accept(self)) is not None: + array_info.base = _assert_type(base_replacement, Variable) + + # array_info.index can either be Variable or int. Only try substituting if not int + if isinstance(array_info.index, Variable): + if (index_replacement := array_info.index.accept(self)) is not None: + array_info.index = _assert_type(index_replacement, Variable) + + def visit_unary_operation(self, op: UnaryOperation) -> Optional[DataflowObject]: + if op.array_info is not None: + self._substitute_array_info(op.array_info) + + return self._visit_operation(op) + + def visit_binary_operation(self, op: BinaryOperation) -> Optional[DataflowObject]: + return self._visit_operation(op) + + def visit_call(self, op: Call) -> Optional[DataflowObject]: + if (function_replacement := op.function.accept(self)) is not None: + op._function = _assert_type( + function_replacement, + Union[FunctionSymbol, ImportedFunctionSymbol, IntrinsicSymbol, Variable] + ) + + return self._visit_operation(op) + + def visit_condition(self, op: Condition) -> Optional[DataflowObject]: + return self._visit_operation(op) + + def visit_ternary_expression(self, op: TernaryExpression) -> Optional[DataflowObject]: + return self._visit_operation(op) + + def visit_comment(self, instr: Comment) -> Optional[DataflowObject]: + return self._mapper(instr) + + def visit_assignment(self, instr: Assignment) -> Optional[DataflowObject]: + if (value_replacement := instr.value.accept(self)) is not None: + instr._value = _assert_type(value_replacement, Expression) + if (destination_replacement := instr.destination.accept(self)) is not None: + instr._destination = _assert_type(destination_replacement, Expression) + + return self._mapper(instr) + + def visit_generic_branch(self, instr: GenericBranch) -> Optional[DataflowObject]: + if (condition_replacement := instr.condition.accept(self)) is not None: + instr._condition = _assert_type(condition_replacement, Expression) + + return self._mapper(instr) + + def visit_return(self, instr: Return) -> Optional[DataflowObject]: + if (values_replacement := instr.values.accept(self)) is not None: + instr._values = _assert_type(values_replacement, ListOperation) + + return self._mapper(instr) + + def visit_break(self, instr: Break) -> Optional[DataflowObject]: + return self._mapper(instr) + + def visit_continue(self, instr: Continue) -> Optional[DataflowObject]: + return self._mapper(instr) + + def _visit_phi_base(self, instr: Phi, value_type: type[DataflowObject]): + if (repl := instr.value.accept(self)) is not None: + # Phi only accepts ListOperation with 'value_type' as valid values + for operand in _assert_type(repl, ListOperation).operands: + _assert_type(operand, value_type) + + instr._value = repl + + for node, expression in instr.origin_block.items(): + if (replacement := expression.accept(self)) is not None: + instr.origin_block[node] = _assert_type(replacement, Union[Variable, Constant]) + + if (destination_replacement := instr.destination.accept(self)) is not None: + instr._destination = _assert_type(destination_replacement, Variable) + + return self._mapper(instr) + + def visit_phi(self, instr: Phi) -> Optional[DataflowObject]: + return self._visit_phi_base(instr, Union[Variable, Constant]) + + def visit_mem_phi(self, instr: MemPhi) -> Optional[DataflowObject]: + """We do not want substitute capabilities for MemPhi, since we remove it while preprocessing.""" + # return self._visit_phi_base(instr, Union[Variable]) + pass diff --git a/decompiler/util/default.json b/decompiler/util/default.json index ed48b1984..5ae9ba0b5 100644 --- a/decompiler/util/default.json +++ b/decompiler/util/default.json @@ -581,6 +581,16 @@ "is_hidden_from_gui": true, "is_hidden_from_cli": false, "argument_name": "--loop_break_in_cases" + }, + { + "dest": "expression-simplification.max_iterations", + "default": 10000, + "type": "number", + "title": "The maximum number of iterations any rule set in the expression simplification stage is allowed to take", + "description": "Stop simplifying with a rule set after this number of iterations is exceeded, even if more possible simplifications are possible", + "is_hidden_from_gui": false, + "is_hidden_from_cli": false, + "argument_name": "--max_expression_simplification_iterations" } ] }, @@ -599,7 +609,7 @@ "dead-code-elimination", "expression-propagation-memory", "expression-propagation-function-call", - "expression-simplification", + "expression-simplification-cfg", "dead-code-elimination", "redundant-casts-elimination", "identity-elimination", @@ -619,7 +629,7 @@ "dest": "pipeline.ast_stages", "default": [ "readability-based-refinement", - "expression-simplification", + "expression-simplification-ast", "instruction-length-handler", "variable-name-generation" ], diff --git a/tests/pipeline/controlflowanalysis/expression_simplification/rules/test_collapse_add_neg.py b/tests/pipeline/controlflowanalysis/expression_simplification/rules/test_collapse_add_neg.py new file mode 100644 index 000000000..d7300107f --- /dev/null +++ b/tests/pipeline/controlflowanalysis/expression_simplification/rules/test_collapse_add_neg.py @@ -0,0 +1,23 @@ +import pytest +from decompiler.pipeline.controlflowanalysis.expression_simplification.rules.collapse_add_neg import CollapseAddNeg +from decompiler.structures.pseudo import BinaryOperation, Expression, Operation, OperationType, UnaryOperation, Variable + +var_x = Variable("x") +var_y = Variable("y") + + +@pytest.mark.parametrize( + ["operation", "result"], + [ + ( + BinaryOperation(OperationType.plus, [var_x, UnaryOperation(OperationType.negate, [var_y])]), + [BinaryOperation(OperationType.minus, [var_x, var_y])], + ), + ( + BinaryOperation(OperationType.minus, [var_x, UnaryOperation(OperationType.negate, [var_y])]), + [BinaryOperation(OperationType.plus, [var_x, var_y])], + ), + ], +) +def test_collapse_add_neg(operation: Operation, result: list[Expression]): + assert CollapseAddNeg().apply(operation) == [(operation, e) for e in result] diff --git a/tests/pipeline/controlflowanalysis/expression_simplification/rules/test_collapse_constant.py b/tests/pipeline/controlflowanalysis/expression_simplification/rules/test_collapse_constant.py new file mode 100644 index 000000000..1a4a5febf --- /dev/null +++ b/tests/pipeline/controlflowanalysis/expression_simplification/rules/test_collapse_constant.py @@ -0,0 +1,23 @@ +import pytest +from decompiler.pipeline.controlflowanalysis.expression_simplification.rules.collapse_constants import CollapseConstants +from decompiler.structures.pseudo import BinaryOperation, Constant, Expression, Float, Integer, Operation, OperationType, Variable + + +def _c_i32(value: int) -> Constant: + return Constant(value, Integer.int32_t()) + + +def _c_float(value: float) -> Constant: + return Constant(value, Float.float()) + + +@pytest.mark.parametrize( + ["operation", "result"], + [ + (BinaryOperation(OperationType.plus, [_c_i32(3), _c_i32(4)]), [_c_i32(7)]), + (BinaryOperation(OperationType.plus, [_c_i32(3), Variable("x")]), []), + (BinaryOperation(OperationType.plus_float, [_c_float(3.0), _c_float(4.0)]), []), + ], +) +def test_collapse_constant(operation: Operation, result: list[Expression]): + assert CollapseConstants().apply(operation) == [(operation, e) for e in result] diff --git a/tests/pipeline/controlflowanalysis/expression_simplification/rules/test_collapse_nested_constants.py b/tests/pipeline/controlflowanalysis/expression_simplification/rules/test_collapse_nested_constants.py new file mode 100644 index 000000000..897c58472 --- /dev/null +++ b/tests/pipeline/controlflowanalysis/expression_simplification/rules/test_collapse_nested_constants.py @@ -0,0 +1,149 @@ +import pytest +from decompiler.pipeline.controlflowanalysis.expression_simplification.rules.collapse_nested_constants import CollapseNestedConstants +from decompiler.structures.pseudo import BinaryOperation, Constant, Expression, Integer, Operation, OperationType, Variable +from decompiler.structures.visitors.substitute_visitor import SubstituteVisitor + + +def _var_i32(name: str) -> Variable: + return Variable(name, Integer.int32_t()) + + +def _c_i32(value: int) -> Constant: + return Constant(value, Integer.int32_t()) + + +def _plus(e0: Expression, e1: Expression) -> BinaryOperation: + return BinaryOperation(OperationType.plus, [e0, e1]) + + +def _mul(e: Expression, factor: Expression) -> BinaryOperation: + return BinaryOperation(OperationType.multiply, [e, factor]) + + +def _mul_us(e: Expression, factor: Expression) -> BinaryOperation: + return BinaryOperation(OperationType.multiply_us, [e, factor]) + + +def _bit_and(e0: Expression, e1: Expression) -> BinaryOperation: + return BinaryOperation(OperationType.bitwise_and, [e0, e1]) + + +def _bit_xor(e0: Expression, e1: Expression) -> BinaryOperation: + return BinaryOperation(OperationType.bitwise_xor, [e0, e1]) + + +def _bit_or(e0: Expression, e1: Expression) -> BinaryOperation: + return BinaryOperation(OperationType.bitwise_or, [e0, e1]) + + +@pytest.mark.parametrize( + ["operation", "possible_results"], + [ + ( # plus + _plus(_plus(_c_i32(7), _c_i32(11)), _c_i32(42)), + { + _plus(_plus(_c_i32(0), _c_i32(0)), _c_i32(60)), + _plus(_plus(_c_i32(0), _c_i32(60)), _c_i32(0)), + _plus(_plus(_c_i32(60), _c_i32(0)), _c_i32(0)), + }, + ), + ( + _plus(_plus(_var_i32("a"), _c_i32(2)), _plus(_var_i32("b"), _c_i32(3))), + { + _plus(_plus(_var_i32("a"), _c_i32(5)), _plus(_var_i32("b"), _c_i32(0))), + _plus(_plus(_var_i32("a"), _c_i32(0)), _plus(_var_i32("b"), _c_i32(5))), + }, + ), + ( # multiply + _mul(_mul(_c_i32(7), _c_i32(11)), _c_i32(2)), + { + _mul(_mul(_c_i32(1), _c_i32(1)), _c_i32(154)), + _mul(_mul(_c_i32(1), _c_i32(154)), _c_i32(1)), + _mul(_mul(_c_i32(154), _c_i32(1)), _c_i32(1)), + }, + ), + ( + _mul(_mul(_var_i32("a"), _c_i32(2)), _mul(_var_i32("b"), _c_i32(3))), + { + _mul(_mul(_var_i32("a"), _c_i32(6)), _mul(_var_i32("b"), _c_i32(1))), + _mul(_mul(_var_i32("a"), _c_i32(1)), _mul(_var_i32("b"), _c_i32(6))), + }, + ), + ( # multiply_us + _mul_us(_mul_us(_c_i32(7), _c_i32(11)), _c_i32(2)), + { + _mul_us(_mul_us(_c_i32(1), _c_i32(1)), _c_i32(154)), + _mul_us(_mul_us(_c_i32(1), _c_i32(154)), _c_i32(1)), + _mul_us(_mul_us(_c_i32(154), _c_i32(1)), _c_i32(1)), + }, + ), + ( + _mul_us(_mul_us(_var_i32("a"), _c_i32(2)), _mul_us(_var_i32("b"), _c_i32(3))), + { + _mul_us(_mul_us(_var_i32("a"), _c_i32(6)), _mul_us(_var_i32("b"), _c_i32(1))), + _mul_us(_mul_us(_var_i32("a"), _c_i32(1)), _mul_us(_var_i32("b"), _c_i32(6))), + }, + ), + ( # bitwise_and + _bit_and(_bit_and(_c_i32(7), _c_i32(11)), _c_i32(2)), + { + _bit_and(_bit_and(_c_i32(-1), _c_i32(-1)), _c_i32(2)), + _bit_and(_bit_and(_c_i32(-1), _c_i32(2)), _c_i32(-1)), + _bit_and(_bit_and(_c_i32(2), _c_i32(-1)), _c_i32(-1)), + }, + ), + ( + _bit_and(_bit_and(_var_i32("a"), _c_i32(2)), _bit_and(_var_i32("b"), _c_i32(3))), + { + _bit_and(_bit_and(_var_i32("a"), _c_i32(2)), _bit_and(_var_i32("b"), _c_i32(-1))), + _bit_and(_bit_and(_var_i32("a"), _c_i32(-1)), _bit_and(_var_i32("b"), _c_i32(2))), + }, + ), + ( # bitwise_xor + _bit_xor(_bit_xor(_c_i32(7), _c_i32(11)), _c_i32(2)), + { + _bit_xor(_bit_xor(_c_i32(0), _c_i32(0)), _c_i32(14)), + _bit_xor(_bit_xor(_c_i32(0), _c_i32(14)), _c_i32(0)), + _bit_xor(_bit_xor(_c_i32(14), _c_i32(0)), _c_i32(0)), + }, + ), + ( + _bit_xor(_bit_xor(_var_i32("a"), _c_i32(2)), _bit_xor(_var_i32("b"), _c_i32(3))), + { + _bit_xor(_bit_xor(_var_i32("a"), _c_i32(1)), _bit_xor(_var_i32("b"), _c_i32(0))), + _bit_xor(_bit_xor(_var_i32("a"), _c_i32(0)), _bit_xor(_var_i32("b"), _c_i32(1))), + }, + ), + ( # bitwise_or + _bit_or(_bit_or(_c_i32(7), _c_i32(11)), _c_i32(2)), + { + _bit_or(_bit_or(_c_i32(0), _c_i32(0)), _c_i32(15)), + _bit_or(_bit_or(_c_i32(0), _c_i32(15)), _c_i32(0)), + _bit_or(_bit_or(_c_i32(15), _c_i32(0)), _c_i32(0)), + }, + ), + ( + _bit_or(_bit_or(_var_i32("a"), _c_i32(2)), _bit_or(_var_i32("b"), _c_i32(3))), + { + _bit_or(_bit_or(_var_i32("a"), _c_i32(3)), _bit_or(_var_i32("b"), _c_i32(0))), + _bit_or(_bit_or(_var_i32("a"), _c_i32(0)), _bit_or(_var_i32("b"), _c_i32(3))), + }, + ), + ], +) +def test_collect_terms(operation: Operation, possible_results: set[Expression]): + collect_terms = CollapseNestedConstants() + + for i in range(100): + substitutions = collect_terms.apply(operation) + if not substitutions: + break + + for replacee, replacement in substitutions: + new_operation = operation.accept(SubstituteVisitor.identity(replacee, replacement)) + if new_operation is not None: + operation = new_operation + else: + raise RuntimeError("Max iterations exceeded") + + assert operation in possible_results diff --git a/tests/pipeline/controlflowanalysis/expression_simplification/rules/test_positive_constants.py b/tests/pipeline/controlflowanalysis/expression_simplification/rules/test_positive_constants.py new file mode 100644 index 000000000..8142753f9 --- /dev/null +++ b/tests/pipeline/controlflowanalysis/expression_simplification/rules/test_positive_constants.py @@ -0,0 +1,36 @@ +import pytest +from decompiler.pipeline.controlflowanalysis.expression_simplification.rules.positive_constants import PositiveConstants +from decompiler.structures.pseudo import BinaryOperation, Constant, Expression, Integer, Operation, OperationType, Variable + +var_x_i = Variable("x", Integer.int32_t()) +var_x_u = Variable("x", Integer.uint32_t()) + + +@pytest.mark.parametrize( + ["operation", "result"], + [ + ( + BinaryOperation(OperationType.minus, [var_x_i, (Constant(-3, Integer.int32_t()))]), + [BinaryOperation(OperationType.plus, [var_x_i, Constant(3, Integer.int32_t())])], + ), + ( + BinaryOperation(OperationType.plus, [var_x_i, (Constant(-3, Integer.int32_t()))]), + [BinaryOperation(OperationType.minus, [var_x_i, Constant(3, Integer.int32_t())])], + ), + (BinaryOperation(OperationType.plus, [var_x_i, (Constant(3, Integer.int32_t()))]), []), + (BinaryOperation(OperationType.minus, [var_x_i, (Constant(3, Integer.int32_t()))]), []), + + ( + BinaryOperation(OperationType.minus, [var_x_u, (Constant(4294967293, Integer.uint32_t()))]), + [BinaryOperation(OperationType.plus, [var_x_u, Constant(3, Integer.uint32_t())])], + ), + ( + BinaryOperation(OperationType.plus, [var_x_u, (Constant(4294967293, Integer.uint32_t()))]), + [BinaryOperation(OperationType.minus, [var_x_u, Constant(3, Integer.uint32_t())])], + ), + (BinaryOperation(OperationType.plus, [var_x_u, (Constant(3, Integer.uint32_t()))]), []), + (BinaryOperation(OperationType.minus, [var_x_u, (Constant(3, Integer.uint32_t()))]), []), + ], +) +def test_fix_add_sub_sign(operation: Operation, result: list[Expression]): + assert PositiveConstants().apply(operation) == [(operation, e) for e in result] diff --git a/tests/pipeline/controlflowanalysis/expression_simplification/rules/test_simplify_redundant_reference.py b/tests/pipeline/controlflowanalysis/expression_simplification/rules/test_simplify_redundant_reference.py new file mode 100644 index 000000000..f269d5ec9 --- /dev/null +++ b/tests/pipeline/controlflowanalysis/expression_simplification/rules/test_simplify_redundant_reference.py @@ -0,0 +1,15 @@ +import pytest +from decompiler.pipeline.controlflowanalysis.expression_simplification.rules.simplify_redundant_reference import SimplifyRedundantReference +from decompiler.structures.pseudo import Expression, Operation, OperationType, UnaryOperation, Variable + + +@pytest.mark.parametrize( + ["operation", "result"], + [ + (UnaryOperation(OperationType.dereference, [UnaryOperation(OperationType.address, [var := Variable("x")])]), [var]), + (UnaryOperation(OperationType.address, [Variable("x")]), []), + (UnaryOperation(OperationType.dereference, [Variable("x")]), []), + ], +) +def test_simplify_redundant_reference(operation: Operation, result: list[Expression]): + assert SimplifyRedundantReference().apply(operation) == [(operation, e) for e in result] diff --git a/tests/pipeline/controlflowanalysis/expression_simplification/rules/test_simplify_trivial_arithmetic.py b/tests/pipeline/controlflowanalysis/expression_simplification/rules/test_simplify_trivial_arithmetic.py new file mode 100644 index 000000000..14faaa795 --- /dev/null +++ b/tests/pipeline/controlflowanalysis/expression_simplification/rules/test_simplify_trivial_arithmetic.py @@ -0,0 +1,29 @@ +import pytest +from decompiler.pipeline.controlflowanalysis.expression_simplification.rules.simplify_trivial_arithmetic import SimplifyTrivialArithmetic +from decompiler.structures.pseudo import BinaryOperation, Constant, Expression, Integer, Operation, OperationType, UnaryOperation, Variable + +var = Variable("x") +con_0 = Constant(0, Integer.int32_t()) +con_1 = Constant(1, Integer.int32_t()) +con_neg1 = Constant(-1, Integer.int32_t()) + + +@pytest.mark.parametrize( + ["operation", "result"], + [ + (BinaryOperation(OperationType.plus, [var, con_0]), [var]), + (BinaryOperation(OperationType.minus, [var, con_0]), [var]), + (BinaryOperation(OperationType.multiply, [var, con_0]), [con_0]), + (BinaryOperation(OperationType.multiply_us, [var, con_0]), [con_0]), + (BinaryOperation(OperationType.multiply, [var, con_1]), [var]), + (BinaryOperation(OperationType.multiply_us, [var, con_1]), [var]), + (BinaryOperation(OperationType.multiply, [var, con_neg1]), [UnaryOperation(OperationType.negate, [var])]), + (BinaryOperation(OperationType.multiply_us, [var, con_neg1]), [UnaryOperation(OperationType.negate, [var])]), + (BinaryOperation(OperationType.divide, [var, con_1]), [var]), + (BinaryOperation(OperationType.divide_us, [var, con_1]), [var]), + (BinaryOperation(OperationType.divide, [var, con_neg1]), [UnaryOperation(OperationType.negate, [var])]), + (BinaryOperation(OperationType.divide_us, [var, con_neg1]), []), + ], +) +def test_simplify_trivial_arithmetic(operation: Operation, result: list[Expression]): + assert SimplifyTrivialArithmetic().apply(operation) == [(operation, e) for e in result] diff --git a/tests/pipeline/controlflowanalysis/expression_simplification/rules/test_simplify_trivial_bit_arithmetic.py b/tests/pipeline/controlflowanalysis/expression_simplification/rules/test_simplify_trivial_bit_arithmetic.py new file mode 100644 index 000000000..e42771f66 --- /dev/null +++ b/tests/pipeline/controlflowanalysis/expression_simplification/rules/test_simplify_trivial_bit_arithmetic.py @@ -0,0 +1,23 @@ +import pytest +from decompiler.pipeline.controlflowanalysis.expression_simplification.rules.simplify_trivial_bit_arithmetic import ( + SimplifyTrivialBitArithmetic, +) +from decompiler.structures.pseudo import BinaryOperation, Constant, Expression, Integer, Operation, OperationType, Variable + +var = Variable("x", Integer.int32_t()) +con_0 = Constant(0, Integer.int32_t()) + + +@pytest.mark.parametrize( + ["operation", "result"], + [ + (BinaryOperation(OperationType.bitwise_or, [var, con_0]), [var]), + (BinaryOperation(OperationType.bitwise_or, [var, var]), [var]), + (BinaryOperation(OperationType.bitwise_and, [var, con_0]), [con_0]), + (BinaryOperation(OperationType.bitwise_and, [var, var]), [var]), + (BinaryOperation(OperationType.bitwise_xor, [var, con_0]), [var]), + (BinaryOperation(OperationType.bitwise_xor, [var, var]), [con_0]), + ], +) +def test_simplify_trivial_bit_arithmetic(operation: Operation, result: list[Expression]): + assert SimplifyTrivialBitArithmetic().apply(operation) == [(operation, e) for e in result] diff --git a/tests/pipeline/controlflowanalysis/expression_simplification/rules/test_simplify_trivial_logic_arithmetic.py b/tests/pipeline/controlflowanalysis/expression_simplification/rules/test_simplify_trivial_logic_arithmetic.py new file mode 100644 index 000000000..8d74ff8b1 --- /dev/null +++ b/tests/pipeline/controlflowanalysis/expression_simplification/rules/test_simplify_trivial_logic_arithmetic.py @@ -0,0 +1,22 @@ +import pytest +from decompiler.pipeline.controlflowanalysis.expression_simplification.rules.simplify_trivial_logic_arithmetic import ( + SimplifyTrivialLogicArithmetic, +) +from decompiler.structures.pseudo import BinaryOperation, Constant, CustomType, Expression, Operation, OperationType, Variable + +var = Variable("x", CustomType.bool()) +con_false = Constant(0, CustomType.bool()) +con_true = Constant(1, CustomType.bool()) + + +@pytest.mark.parametrize( + ["operation", "result"], + [ + (BinaryOperation(OperationType.logical_or, [var, con_false]), [var]), + (BinaryOperation(OperationType.logical_or, [var, con_true]), [con_true]), + (BinaryOperation(OperationType.logical_and, [var, con_false]), [con_false]), + (BinaryOperation(OperationType.logical_and, [var, con_true]), [var]), + ], +) +def test_simplify_trivial_logic_arithmetic(operation: Operation, result: list[Expression]): + assert SimplifyTrivialLogicArithmetic().apply(operation) == [(operation, e) for e in result] diff --git a/tests/pipeline/controlflowanalysis/expression_simplification/rules/test_simplify_trivial_shift.py b/tests/pipeline/controlflowanalysis/expression_simplification/rules/test_simplify_trivial_shift.py new file mode 100644 index 000000000..a281caa33 --- /dev/null +++ b/tests/pipeline/controlflowanalysis/expression_simplification/rules/test_simplify_trivial_shift.py @@ -0,0 +1,20 @@ +import pytest +from decompiler.pipeline.controlflowanalysis.expression_simplification.rules.simplify_trivial_shift import SimplifyTrivialShift +from decompiler.structures.pseudo import BinaryOperation, Constant, Expression, Integer, Operation, OperationType, Variable + +var = Variable("x") +con_0 = Constant(0, Integer.int32_t()) + + +@pytest.mark.parametrize( + ["operation", "result"], + [ + (BinaryOperation(OperationType.left_shift, [var, con_0]), [var]), + (BinaryOperation(OperationType.right_shift, [var, con_0]), [var]), + (BinaryOperation(OperationType.right_shift_us, [var, con_0]), [var]), + (BinaryOperation(OperationType.left_rotate, [var, con_0]), [var]), + (BinaryOperation(OperationType.right_rotate, [var, con_0]), [var]), + ], +) +def test_simplify_trivial_shift(operation: Operation, result: list[Expression]): + assert SimplifyTrivialShift().apply(operation) == [(operation, e) for e in result] diff --git a/tests/pipeline/controlflowanalysis/expression_simplification/rules/test_sub_to_add.py b/tests/pipeline/controlflowanalysis/expression_simplification/rules/test_sub_to_add.py new file mode 100644 index 000000000..d171aeb93 --- /dev/null +++ b/tests/pipeline/controlflowanalysis/expression_simplification/rules/test_sub_to_add.py @@ -0,0 +1,20 @@ +import pytest +from decompiler.pipeline.controlflowanalysis.expression_simplification.rules.sub_to_add import SubToAdd +from decompiler.structures.pseudo import BinaryOperation, Constant, Expression, Integer, Operation, OperationType, UnaryOperation, Variable + +var_x = Variable("x", Integer.int32_t()) +var_y = Variable("y", Integer.int32_t()) +con_neg1 = Constant(-1, Integer.int32_t()) + + +@pytest.mark.parametrize( + ["operation", "result"], + [ + ( + BinaryOperation(OperationType.minus, [var_x, var_y]), + [BinaryOperation(OperationType.plus, [var_x, UnaryOperation(OperationType.negate, [var_y])])], + ), + ], +) +def test_sub_to_add(operation: Operation, result: list[Expression]): + assert SubToAdd().apply(operation) == [(operation, e) for e in result] diff --git a/tests/pipeline/controlflowanalysis/expression_simplification/rules/test_term_order.py b/tests/pipeline/controlflowanalysis/expression_simplification/rules/test_term_order.py new file mode 100644 index 000000000..5846004e9 --- /dev/null +++ b/tests/pipeline/controlflowanalysis/expression_simplification/rules/test_term_order.py @@ -0,0 +1,15 @@ +import pytest +from decompiler.pipeline.controlflowanalysis.expression_simplification.rules.term_order import TermOrder +from decompiler.structures.pseudo import BinaryOperation, Constant, Expression, Integer, Operation, Variable +from decompiler.structures.pseudo.operations import COMMUTATIVE_OPERATIONS + +var = Variable("x") +con = Constant(42, Integer.int32_t()) + + +@pytest.mark.parametrize( + ["operation", "result"], + [(BinaryOperation(operation, [con, var]), [BinaryOperation(operation, [var, con])]) for operation in COMMUTATIVE_OPERATIONS], +) +def test_term_order(operation: Operation, result: list[Expression]): + assert TermOrder().apply(operation) == [(operation, e) for e in result] diff --git a/tests/pipeline/controlflowanalysis/expression_simplification/test_constant_folding.py b/tests/pipeline/controlflowanalysis/expression_simplification/test_constant_folding.py new file mode 100644 index 000000000..24f88ed93 --- /dev/null +++ b/tests/pipeline/controlflowanalysis/expression_simplification/test_constant_folding.py @@ -0,0 +1,138 @@ +from contextlib import nullcontext + +import pytest +from decompiler.pipeline.controlflowanalysis.expression_simplification.constant_folding import FOLDABLE_OPERATIONS, constant_fold +from decompiler.structures.pseudo import Constant, Float, Integer, OperationType + + +def _c_i32(value: int) -> Constant: + return Constant(value, Integer.int32_t()) + + +def _c_u32(value: int) -> Constant: + return Constant(value, Integer.uint32_t()) + + +def _c_i16(value: int) -> Constant: + return Constant(value, Integer.int16_t()) + + +def _c_float(value: float) -> Constant: + return Constant(value, Float.float()) + + +@pytest.mark.parametrize( + ["operation"], + [(operation,) for operation in OperationType if operation not in FOLDABLE_OPERATIONS] +) +def test_constant_fold_invalid_operations(operation: OperationType): + with pytest.raises(ValueError): + constant_fold(operation, []) + + +@pytest.mark.parametrize( + ["operation", "constants", "result", "context"], + [ + (OperationType.plus, [_c_i32(3), _c_i32(4)], _c_i32(7), nullcontext()), + (OperationType.plus, [_c_i32(2147483647), _c_i32(1)], _c_i32(-2147483648), nullcontext()), + (OperationType.plus, [_c_u32(2147483658), _c_u32(2147483652)], _c_u32(14), nullcontext()), + (OperationType.plus, [_c_u32(3), _c_i32(4)], None, pytest.raises(ValueError)), + (OperationType.plus, [_c_i32(3), _c_i16(4)], None, pytest.raises(ValueError)), + (OperationType.plus, [_c_i32(3)], None, pytest.raises(ValueError)), + (OperationType.plus, [_c_i32(3), _c_i32(3), _c_i32(3)], None, pytest.raises(ValueError)), + + (OperationType.minus, [_c_i32(3), _c_i32(4)], _c_i32(-1), nullcontext()), + (OperationType.minus, [_c_i32(-2147483648), _c_i32(1)], _c_i32(2147483647), nullcontext()), + (OperationType.minus, [_c_u32(3), _c_u32(4)], _c_u32(4294967295), nullcontext()), + (OperationType.minus, [_c_u32(3), _c_i32(4)], None, pytest.raises(ValueError)), + (OperationType.minus, [_c_i32(3), _c_i16(4)], None, pytest.raises(ValueError)), + (OperationType.minus, [_c_i32(3)], None, pytest.raises(ValueError)), + (OperationType.minus, [_c_i32(3), _c_i32(3), _c_i32(3)], None, pytest.raises(ValueError)), + + (OperationType.multiply, [_c_i32(3), _c_i32(4)], _c_i32(12), nullcontext()), + (OperationType.multiply, [_c_i32(-1073741824), _c_i32(2)], _c_i32(-2147483648), nullcontext()), + (OperationType.multiply, [_c_u32(3221225472), _c_u32(2)], _c_u32(2147483648), nullcontext()), + (OperationType.multiply, [_c_u32(3), _c_i32(4)], None, pytest.raises(ValueError)), + (OperationType.multiply, [_c_i32(3), _c_i16(4)], None, pytest.raises(ValueError)), + (OperationType.multiply, [_c_i32(3)], None, pytest.raises(ValueError)), + (OperationType.multiply, [_c_i32(3), _c_i32(3), _c_i32(3)], None, pytest.raises(ValueError)), + + (OperationType.multiply_us, [_c_i32(3), _c_i32(4)], _c_i32(12), nullcontext()), + (OperationType.multiply_us, [_c_i32(-1073741824), _c_i32(2)], _c_i32(-2147483648), nullcontext()), + (OperationType.multiply_us, [_c_u32(3221225472), _c_u32(2)], _c_u32(2147483648), nullcontext()), + (OperationType.multiply_us, [_c_u32(3), _c_i32(4)], None, pytest.raises(ValueError)), + (OperationType.multiply_us, [_c_i32(3), _c_i16(4)], None, pytest.raises(ValueError)), + (OperationType.multiply_us, [_c_i32(3)], None, pytest.raises(ValueError)), + (OperationType.multiply_us, [_c_i32(3), _c_i32(3), _c_i32(3)], None, pytest.raises(ValueError)), + + (OperationType.divide, [_c_i32(12), _c_i32(4)], _c_i32(3), nullcontext()), + (OperationType.divide, [_c_i32(-2147483648), _c_i32(2)], _c_i32(-1073741824), nullcontext()), + (OperationType.divide, [_c_u32(3), _c_i32(4)], None, pytest.raises(ValueError)), + (OperationType.divide, [_c_i32(3), _c_i16(4)], None, pytest.raises(ValueError)), + (OperationType.divide, [_c_i32(3)], None, pytest.raises(ValueError)), + (OperationType.divide, [_c_i32(3), _c_i32(3), _c_i32(3)], None, pytest.raises(ValueError)), + + (OperationType.divide_us, [_c_i32(12), _c_i32(4)], _c_i32(3), nullcontext()), + (OperationType.divide_us, [_c_i32(-2147483648), _c_i32(2)], _c_i32(1073741824), nullcontext()), + (OperationType.divide_us, [_c_u32(3), _c_i32(4)], None, pytest.raises(ValueError)), + (OperationType.divide_us, [_c_i32(3), _c_i16(4)], None, pytest.raises(ValueError)), + (OperationType.divide_us, [_c_i32(3)], None, pytest.raises(ValueError)), + (OperationType.divide_us, [_c_i32(3), _c_i32(3), _c_i32(3)], None, pytest.raises(ValueError)), + + (OperationType.negate, [_c_i32(3)], _c_i32(-3), nullcontext()), + (OperationType.negate, [_c_i32(-2147483648)], _c_i32(-2147483648), nullcontext()), + (OperationType.negate, [], None, pytest.raises(ValueError)), + (OperationType.negate, [_c_i32(3), _c_i32(3)], None, pytest.raises(ValueError)), + + (OperationType.left_shift, [_c_i32(3), _c_i32(4)], _c_i32(48), nullcontext()), + (OperationType.left_shift, [_c_i32(1073741824), _c_i32(1)], _c_i32(-2147483648), nullcontext()), + (OperationType.left_shift, [_c_u32(1073741824), _c_u32(1)], _c_u32(2147483648), nullcontext()), + (OperationType.left_shift, [_c_i32(3)], None, pytest.raises(ValueError)), + (OperationType.left_shift, [_c_i32(3), _c_i32(3), _c_i32(3)], None, pytest.raises(ValueError)), + + (OperationType.right_shift, [_c_i32(32), _c_i32(4)], _c_i32(2), nullcontext()), + (OperationType.right_shift, [_c_i32(-2147483648), _c_i32(1)], _c_i32(-1073741824), nullcontext()), + (OperationType.right_shift, [_c_u32(2147483648), _c_u32(1)], _c_u32(1073741824), nullcontext()), + (OperationType.right_shift, [_c_i32(3)], None, pytest.raises(ValueError)), + (OperationType.right_shift, [_c_i32(3), _c_i32(3), _c_i32(3)], None, pytest.raises(ValueError)), + + (OperationType.right_shift_us, [_c_i32(32), _c_i32(4)], _c_i32(2), nullcontext()), + (OperationType.right_shift_us, [_c_i32(-2147483648), _c_i32(1)], _c_i32(1073741824), nullcontext()), + (OperationType.right_shift_us, [_c_u32(2147483648), _c_u32(1)], _c_u32(1073741824), nullcontext()), + (OperationType.right_shift_us, [_c_i32(3)], None, pytest.raises(ValueError)), + (OperationType.right_shift_us, [_c_i32(3), _c_i32(3), _c_i32(3)], None, pytest.raises(ValueError)), + + (OperationType.bitwise_or, [_c_i32(85), _c_i32(34)], _c_i32(119), nullcontext()), + (OperationType.bitwise_or, [_c_i32(-2147483648), _c_i32(1)], _c_i32(-2147483647), nullcontext()), + (OperationType.bitwise_or, [_c_u32(2147483648), _c_u32(1)], _c_u32(2147483649), nullcontext()), + (OperationType.bitwise_or, [_c_u32(3), _c_i32(4)], None, pytest.raises(ValueError)), + (OperationType.bitwise_or, [_c_i32(3), _c_i16(4)], None, pytest.raises(ValueError)), + (OperationType.bitwise_or, [_c_i32(3)], None, pytest.raises(ValueError)), + (OperationType.bitwise_or, [_c_i32(3), _c_i32(3), _c_i32(3)], None, pytest.raises(ValueError)), + + (OperationType.bitwise_and, [_c_i32(85), _c_i32(51)], _c_i32(17), nullcontext()), + (OperationType.bitwise_and, [_c_i32(-2147483647), _c_i32(3)], _c_i32(1), nullcontext()), + (OperationType.bitwise_and, [_c_u32(2147483649), _c_u32(3)], _c_u32(1), nullcontext()), + (OperationType.bitwise_and, [_c_u32(3), _c_i32(4)], None, pytest.raises(ValueError)), + (OperationType.bitwise_and, [_c_i32(3), _c_i16(4)], None, pytest.raises(ValueError)), + (OperationType.bitwise_and, [_c_i32(3)], None, pytest.raises(ValueError)), + (OperationType.bitwise_and, [_c_i32(3), _c_i32(3), _c_i32(3)], None, pytest.raises(ValueError)), + + (OperationType.bitwise_xor, [_c_i32(85), _c_i32(51)], _c_i32(102), nullcontext()), + (OperationType.bitwise_xor, [_c_i32(-2147483647), _c_i32(-2147483646)], _c_i32(3), nullcontext()), + (OperationType.bitwise_xor, [_c_u32(2147483649), _c_u32(2147483650)], _c_u32(3), nullcontext()), + (OperationType.bitwise_xor, [_c_u32(3), _c_i32(4)], None, pytest.raises(ValueError)), + (OperationType.bitwise_xor, [_c_i32(3), _c_i16(4)], None, pytest.raises(ValueError)), + (OperationType.bitwise_xor, [_c_i32(3)], None, pytest.raises(ValueError)), + (OperationType.bitwise_xor, [_c_i32(3), _c_i32(3), _c_i32(3)], None, pytest.raises(ValueError)), + + (OperationType.bitwise_not, [_c_i32(6)], _c_i32(-7), nullcontext()), + (OperationType.bitwise_not, [_c_i32(-2147483648)], _c_i32(2147483647), nullcontext()), + (OperationType.bitwise_not, [_c_u32(2147483648)], _c_u32(2147483647), nullcontext()), + (OperationType.bitwise_not, [], None, pytest.raises(ValueError)), + (OperationType.bitwise_not, [_c_i32(3), _c_i32(3)], None, pytest.raises(ValueError)), + ] +) +def test_constant_fold(operation: OperationType, constants: list[Constant], result: Constant, context): + with context: + assert constant_fold(operation, constants) == result diff --git a/tests/pipeline/controlflowanalysis/expression_simplification/test_stage.py b/tests/pipeline/controlflowanalysis/expression_simplification/test_stage.py new file mode 100644 index 000000000..a9b74833c --- /dev/null +++ b/tests/pipeline/controlflowanalysis/expression_simplification/test_stage.py @@ -0,0 +1,109 @@ +import pytest +from decompiler.pipeline.controlflowanalysis.expression_simplification.rules.collapse_constants import CollapseConstants +from decompiler.pipeline.controlflowanalysis.expression_simplification.rules.collapse_nested_constants import CollapseNestedConstants +from decompiler.pipeline.controlflowanalysis.expression_simplification.rules.rule import SimplificationRule +from decompiler.pipeline.controlflowanalysis.expression_simplification.rules.simplify_trivial_arithmetic import SimplifyTrivialArithmetic +from decompiler.pipeline.controlflowanalysis.expression_simplification.rules.sub_to_add import SubToAdd +from decompiler.pipeline.controlflowanalysis.expression_simplification.rules.term_order import TermOrder +from decompiler.pipeline.controlflowanalysis.expression_simplification.stages import _ExpressionSimplificationBase +from decompiler.structures.pseudo import ( + Assignment, + BinaryOperation, + Constant, + Expression, + Instruction, + Integer, + Operation, + OperationType, + Variable, +) + + +class _RedundantChanges(SimplificationRule): + def apply(self, operation: Operation) -> list[tuple[Expression, Expression]]: + return [(operation, operation)] + + +class _NoChanges(SimplificationRule): + def apply(self, operation: Operation) -> list[tuple[Expression, Expression]]: + return [] + + +def _add(left: Expression, right: Expression) -> BinaryOperation: + return BinaryOperation(OperationType.plus, [left, right]) + + +def _sub(left: Expression, right: Expression) -> BinaryOperation: + return BinaryOperation(OperationType.minus, [left, right]) + + +def _c_i32(value: int) -> Constant: + return Constant(value, Integer.int32_t()) + + +def _v_i32(name: str) -> Variable: + return Variable(name, Integer.int32_t()) + + +@pytest.mark.parametrize( + ["rule_set", "instruction", "expected_result"], + [ + ( + [TermOrder()], + Assignment(_v_i32("a"), _add(_c_i32(1), _v_i32("b"))), + Assignment(_v_i32("a"), _add(_v_i32("b"), _c_i32(1))) + ), + ( + [CollapseConstants()], + Assignment(_v_i32("a"), _sub(_c_i32(10), _add(_c_i32(3), _c_i32(2)))), + Assignment(_v_i32("a"), _c_i32(5)) + ), + ( + [SubToAdd(), SimplifyTrivialArithmetic(), CollapseConstants(), CollapseNestedConstants()], + Assignment(_v_i32("a"), _sub(_add(_v_i32("a"), _c_i32(5)), _c_i32(5))), + Assignment(_v_i32("a"), _v_i32("a")) + ), + ] +) +def test_simplify_instructions_with_rule_set( + rule_set: list[SimplificationRule], + instruction: Instruction, + expected_result: Instruction +): + _ExpressionSimplificationBase._simplify_instructions_with_rule_set( + [instruction], + rule_set, + 100 + ) + assert instruction == expected_result + + +@pytest.mark.parametrize( + ["rule_set", "instruction", "max_iterations", "expect_exceed_max_iterations"], + [ + ( + [_RedundantChanges()], + Assignment(Variable("a"), BinaryOperation(OperationType.plus, [Constant(1), Variable("b")])), + 10, + True + ), + ( + [_NoChanges()], + Assignment(_v_i32("a"), _v_i32("b")), + 0, + False + ) + ] +) +def test_simplify_instructions_with_rule_set_max_iterations( + rule_set: list[SimplificationRule], + instruction: Instruction, + max_iterations: int, + expect_exceed_max_iterations: bool +): + iterations = _ExpressionSimplificationBase._simplify_instructions_with_rule_set( + [instruction], + rule_set, + max_iterations + ) + assert (iterations > max_iterations) == expect_exceed_max_iterations diff --git a/tests/pipeline/controlflowanalysis/test_expression_simplification.py b/tests/pipeline/controlflowanalysis/test_expression_simplification.py deleted file mode 100644 index 11a9a7042..000000000 --- a/tests/pipeline/controlflowanalysis/test_expression_simplification.py +++ /dev/null @@ -1,244 +0,0 @@ -from typing import Optional - -import pytest -from decompiler.pipeline.controlflowanalysis import ExpressionSimplification -from decompiler.structures.ast.ast_nodes import CodeNode -from decompiler.structures.ast.syntaxtree import AbstractSyntaxTree -from decompiler.structures.graphs.cfg import BasicBlock, ControlFlowGraph -from decompiler.structures.logic.logic_condition import LogicCondition -from decompiler.structures.pseudo.expressions import Constant, ImportedFunctionSymbol, Variable -from decompiler.structures.pseudo.instructions import Assignment, Instruction, Return -from decompiler.structures.pseudo.operations import BinaryOperation, Call, ListOperation, OperationType, UnaryOperation -from decompiler.structures.pseudo.typing import Integer -from decompiler.task import DecompilerTask - - -def _task(ast: Optional[AbstractSyntaxTree] = None, cfg: Optional[ControlFlowGraph] = None) -> DecompilerTask: - cfg = ControlFlowGraph() if cfg is None else cfg - task = DecompilerTask("test_function", cfg, ast) - return task - - -x = Variable("x") -y = Variable("y") -const_0 = Constant(0, Integer(32, signed=True)) -const_m0 = Constant(-0, Integer(32, signed=True)) -const_1 = Constant(1, Integer(32, signed=True)) -const_m1 = Constant(-1, Integer(32, signed=True)) - - -@pytest.mark.parametrize( - "instruction, result", - [ - (Assignment(y, BinaryOperation(OperationType.plus, [x, const_0])), Assignment(y, x)), - (Assignment(y, BinaryOperation(OperationType.plus, [const_0, x])), Assignment(y, x)), - (Assignment(y, BinaryOperation(OperationType.plus, [const_0, const_0])), Assignment(y, const_0)), - (Assignment(y, BinaryOperation(OperationType.plus, [x, const_m0])), Assignment(y, x)), - ( - Assignment(y, BinaryOperation(OperationType.plus, [x, BinaryOperation(OperationType.plus, [const_0, const_0])])), - Assignment(y, x), - ), - ( - Assignment(y, BinaryOperation(OperationType.plus, [const_0, BinaryOperation(OperationType.plus, [const_0, x])])), - Assignment(y, x), - ), - ], -) -def test_easy_simplification_with_zero_addition(instruction, result): - true_value = LogicCondition.initialize_true(LogicCondition.generate_new_context()) - task = _task(AbstractSyntaxTree(CodeNode([instruction], true_value.copy()), dict())) - ExpressionSimplification().run(task) - assert task.syntax_tree.root == CodeNode([result], true_value.copy()) - - -@pytest.mark.parametrize( - "instruction, result", - [ - (Assignment(y, BinaryOperation(OperationType.multiply, [x, const_0])), Assignment(y, const_0)), - (Assignment(y, BinaryOperation(OperationType.multiply, [const_0, x])), Assignment(y, const_0)), - (Assignment(y, BinaryOperation(OperationType.multiply, [const_0, const_0])), Assignment(y, const_0)), - (Assignment(y, BinaryOperation(OperationType.multiply, [x, const_m0])), Assignment(y, const_0)), - ( - Assignment(y, BinaryOperation(OperationType.multiply, [x, BinaryOperation(OperationType.multiply, [x, const_0])])), - Assignment(y, const_0), - ), - ], -) -def test_simplification_with_zero_multiplication(instruction, result): - true_value = LogicCondition.initialize_true(LogicCondition.generate_new_context()) - task = _task(AbstractSyntaxTree(CodeNode([instruction], true_value.copy()), dict())) - ExpressionSimplification().run(task) - assert task.syntax_tree.root == CodeNode([result], true_value.copy()) - - -@pytest.mark.parametrize( - "instruction, result", - [ - (Assignment(y, BinaryOperation(OperationType.minus, (x, const_0))), Assignment(y, x)), - (Assignment(y, BinaryOperation(OperationType.minus, (const_0, x))), Assignment(y, UnaryOperation(OperationType.negate, [x]))), - (Assignment(y, BinaryOperation(OperationType.minus, [const_0, UnaryOperation(OperationType.negate, [x])])), Assignment(y, x)), - (Assignment(y, BinaryOperation(OperationType.minus, [const_0, const_m1])), Assignment(y, const_1)), - (Assignment(y, BinaryOperation(OperationType.minus, (x, const_m0))), Assignment(y, x)), - ], -) -def test_simplification_with_zero_subtraction(instruction, result): - true_value = LogicCondition.initialize_true(LogicCondition.generate_new_context()) - task = _task(AbstractSyntaxTree(CodeNode([instruction], true_value.copy()), dict())) - ExpressionSimplification().run(task) - assert task.syntax_tree.root == CodeNode([result], true_value.copy()) - - -@pytest.mark.parametrize( - "instruction, result", - [ - ( - Assignment(y, BinaryOperation(OperationType.plus, [x, BinaryOperation(OperationType.multiply, [x, const_0])])), - Assignment(y, x), - ), - ( - Assignment(y, BinaryOperation(OperationType.minus, [y, BinaryOperation(OperationType.minus, [x, const_0])])), - Assignment(y, BinaryOperation(OperationType.minus, [y, x])), - ), - ], -) -def test_simplification_with_zero_mix(instruction, result): - cfg = ControlFlowGraph() - cfg.add_node(BasicBlock(1, [instruction])) - true_value = LogicCondition.initialize_true(LogicCondition.generate_new_context()) - task = _task(AbstractSyntaxTree(CodeNode([instruction], true_value.copy()), dict()), cfg) - ExpressionSimplification().run(task) - assert task.syntax_tree.root == CodeNode([result], true_value.copy()) - - -@pytest.mark.parametrize( - "instruction, result", - [ - (Assignment(y, BinaryOperation(OperationType.multiply, [x, const_1])), Assignment(y, x)), - (Assignment(y, BinaryOperation(OperationType.multiply, [const_1, x])), Assignment(y, x)), - (Assignment(y, BinaryOperation(OperationType.multiply, [const_1, const_1])), Assignment(y, const_1)), - ( - Assignment(y, BinaryOperation(OperationType.multiply, [x, const_m1])), - Assignment(y, UnaryOperation(OperationType.negate, [x])), - ), - ( - Assignment(y, BinaryOperation(OperationType.multiply, [Constant(2, Integer(32, signed=False)), const_m1])), - Assignment(y, UnaryOperation(OperationType.negate, [Constant(2, Integer(32, signed=False))])), - ), - ( - Assignment(y, BinaryOperation(OperationType.multiply, [Constant(2, Integer(32, signed=True)), const_m1])), - Assignment(y, Constant(-2, Integer(32, signed=True))), - ), - ( - Assignment(y, BinaryOperation(OperationType.multiply, [Constant(-2, Integer(32, signed=True)), const_m1])), - Assignment(y, Constant(2, Integer(32, signed=True))), - ), - ( - Assignment(y, BinaryOperation(OperationType.multiply, [x, BinaryOperation(OperationType.multiply, [const_1, const_1])])), - Assignment(y, x), - ), - ( - Assignment(y, BinaryOperation(OperationType.multiply, [const_1, BinaryOperation(OperationType.multiply, [x, const_1])])), - Assignment(y, x), - ), - ], -) -def test_simplification_with_one_multiplication(instruction, result): - true_value = LogicCondition.initialize_true(LogicCondition.generate_new_context()) - task = _task(AbstractSyntaxTree(CodeNode([instruction], true_value.copy()), dict())) - ExpressionSimplification().run(task) - assert task.syntax_tree.root == CodeNode([result], true_value.copy()) - - -@pytest.mark.parametrize( - "instruction, result", - [ - (Assignment(y, BinaryOperation(OperationType.divide, [x, const_1])), Assignment(y, x)), - (Assignment(y, BinaryOperation(OperationType.divide, [const_1, const_1])), Assignment(y, const_1)), - (Assignment(y, BinaryOperation(OperationType.divide, [const_m1, const_1])), Assignment(y, const_m1)), - (Assignment(y, BinaryOperation(OperationType.divide, [const_1, const_m1])), Assignment(y, const_m1)), - (Assignment(y, BinaryOperation(OperationType.divide, [const_m1, const_m1])), Assignment(y, const_1)), - (Assignment(y, BinaryOperation(OperationType.divide, [x, const_m1])), Assignment(y, UnaryOperation(OperationType.negate, [x]))), - ( - Assignment(y, BinaryOperation(OperationType.divide, [const_0, const_1])), - Assignment(y, const_0), - ), - ( - Assignment(y, BinaryOperation(OperationType.divide, [const_0, const_m1])), - Assignment(y, const_0), - ), - ( - Assignment(y, BinaryOperation(OperationType.divide, [Constant(2, Integer(32, signed=False)), const_m1])), - Assignment(y, UnaryOperation(OperationType.negate, [Constant(2, Integer(32, signed=False))])), - ), - ], -) -def test_simplification_with_one_division(instruction, result): - true_value = LogicCondition.initialize_true(LogicCondition.generate_new_context()) - task = _task(AbstractSyntaxTree(CodeNode([instruction], true_value.copy()), dict())) - ExpressionSimplification().run(task) - assert task.syntax_tree.root == CodeNode([result], true_value.copy()) - - -@pytest.mark.parametrize( - "instruction, result", - [ - ( - Assignment(UnaryOperation(OperationType.dereference, [UnaryOperation(OperationType.address, [Variable("x")])]), Constant(0)), - Assignment(Variable("x"), Constant(0)), - ), - ( - Assignment( - ListOperation([]), - Call( - ImportedFunctionSymbol("foo", 0x42), - [ - BinaryOperation( - OperationType.minus, - [ - UnaryOperation(OperationType.dereference, [UnaryOperation(OperationType.address, [Variable("x")])]), - Constant(2), - ], - ) - ], - ), - ), - Assignment( - ListOperation([]), - Call(ImportedFunctionSymbol("foo", 0x42), [BinaryOperation(OperationType.minus, [Variable("x"), Constant(2)])]), - ), - ), - ( - Return([UnaryOperation(OperationType.dereference, [UnaryOperation(OperationType.address, [Variable("y")])])]), - Return([Variable("y")]), - ), - ], -) -def test_simplification_of_dereference_operations(instruction: Instruction, result: Instruction): - """Check if dereference operations with address-of operands are simplified correctly.""" - ExpressionSimplification().simplify(instruction) - assert instruction == result - - -@pytest.mark.parametrize( - "instruction, result", - [ - (Assignment(y, BinaryOperation(OperationType.divide, [x, const_1])), Assignment(y, x)), - (Assignment(y, BinaryOperation(OperationType.divide, [const_1, const_1])), Assignment(y, const_1)), - (Assignment(y, BinaryOperation(OperationType.multiply, [const_1, const_1])), Assignment(y, const_1)), - ( - Assignment(y, BinaryOperation(OperationType.multiply, [x, const_m1])), - Assignment(y, UnaryOperation(OperationType.negate, [x])), - ), - ( - Assignment(y, BinaryOperation(OperationType.plus, [x, BinaryOperation(OperationType.multiply, [x, const_0])])), - Assignment(y, x), - ), - (Assignment(y, BinaryOperation(OperationType.plus, [x, const_m0])), Assignment(y, x)), - ], -) -def test_for_cfg(instruction, result): - cfg = ControlFlowGraph() - cfg.add_node(BasicBlock(0, [instruction])) - task = _task(cfg=cfg) - ExpressionSimplification().run(task) - assert list(task.graph.instructions) == [result] diff --git a/tests/structures/visitors/test_substitute_visitor.py b/tests/structures/visitors/test_substitute_visitor.py new file mode 100644 index 000000000..89a874081 --- /dev/null +++ b/tests/structures/visitors/test_substitute_visitor.py @@ -0,0 +1,197 @@ +import pytest +from decompiler.structures.graphs.basicblock import BasicBlock +from decompiler.structures.pseudo import ( + Assignment, + BinaryOperation, + Branch, + Call, + Constant, + DataflowObject, + Integer, + Phi, + Pointer, + RegisterPair, + Return, + UnaryOperation, + Variable, +) +from decompiler.structures.pseudo.operations import ArrayInfo, Condition, OperationType +from decompiler.structures.visitors.substitute_visitor import SubstituteVisitor + +_i32 = Integer.int32_t() +_p_i32 = Pointer(Integer.int32_t()) + +_a = Variable("a", Integer.int32_t(), 0) +_b = Variable("b", Integer.int32_t(), 1) +_c = Variable("c", Integer.int32_t(), 2) +_d = Variable("d", Integer.int32_t(), 3) + + +@pytest.mark.parametrize( + ["initial_obj", "expected_result", "visitor"], + [ + ( + o := Variable("v", _i32, 0), + r := Variable("x", _i32, 1), + SubstituteVisitor.identity(o, r) + ), + ( + o := Variable("v", _i32, 0), + r := Variable("x", _i32, 1), + SubstituteVisitor.equality(o, r) + ), + ( + o := Variable("v", _i32, 0), + o, + SubstituteVisitor.identity(Variable("v", _i32, 0), Variable("x", _i32, 1)) + ), + ( + o := Variable("v", _i32, 0), + r := Variable("x", _i32, 1), + SubstituteVisitor.equality(Variable("v", _i32, 0), r) + ), + ( + Assignment(a := Variable("a"), b := Variable("b")), + Assignment(a, c := Variable("c")), + SubstituteVisitor.identity(b, c) + ), + ( + Assignment(a := Variable("a"), b := Variable("b")), + Assignment(c := Variable("c"), b), + SubstituteVisitor.identity(a, c) + ), + ( + UnaryOperation(OperationType.dereference, [a := Variable("a")]), + UnaryOperation(OperationType.dereference, [b := Variable("b")]), + SubstituteVisitor.identity(a, b) + ), + ( + UnaryOperation( + OperationType.dereference, + [BinaryOperation(OperationType.plus, [a := Variable("a", _p_i32), Constant(4, _i32)])], + array_info=ArrayInfo(a, 1) + ), + UnaryOperation( + OperationType.dereference, + [BinaryOperation(OperationType.plus, [b := Variable("b", _p_i32), Constant(4, _i32)])], + array_info=ArrayInfo(b, 1) + ), + SubstituteVisitor.identity(a, b) + ), + ( + UnaryOperation( + OperationType.dereference, + [BinaryOperation( + OperationType.plus, + [ + a := Variable("a", _p_i32), + BinaryOperation(OperationType.multiply, [b := Variable("b", _i32), Constant(4, _i32)]) + ] + )], + array_info=ArrayInfo(a, b) + ), + UnaryOperation( + OperationType.dereference, + [BinaryOperation( + OperationType.plus, + [ + a := Variable("a", _p_i32), + BinaryOperation(OperationType.multiply, [c := Variable("c", _i32), Constant(4, _i32)]) + ] + )], + array_info=ArrayInfo(a, c) + ), + SubstituteVisitor.identity(b, c) + ), + ( + BinaryOperation(OperationType.multiply, [a := Variable("a"), b := Variable("b")]), + BinaryOperation(OperationType.multiply, [a, c := Variable("c")]), + SubstituteVisitor.identity(b, c) + ), + ( + RegisterPair(a := Variable("a"), b := Variable("b")), + RegisterPair(a, c := Variable("c")), + SubstituteVisitor.identity(b, c) + ), + ( + Call(f := Variable("f"), [a := Variable("a")]), + Call(f, [b := Variable("b")]), + SubstituteVisitor.identity(a, b) + ), + ( + Call(f := Variable("f"), [a := Variable("a")]), + Call(g := Variable("g"), [a]), + SubstituteVisitor.identity(f, g) + ), + ( + Phi( + a3 := Variable("a", _i32, 3), + [ + a2 := Variable("a", _i32, 2), + a1 := Variable("a", _i32, 1) + ], + { + BasicBlock(2): a2, + BasicBlock(1): a1, + } + ), + Phi( + a3, + [ + a2, + a0 := Variable("a", _i32, 0) + ], + { + BasicBlock(2): a2, + BasicBlock(1): a0, + } + ), + SubstituteVisitor.identity(a1, a0) + ), + ( + Phi( + a3 := Variable("a", _i32, 3), + [ + a2 := Variable("a", _i32, 2), + a1 := Variable("a", _i32, 1) + ], + { + BasicBlock(2): a2, + BasicBlock(1): a1, + } + ), + Phi( + a4 := Variable("a", _i32, 4), + [ + a2, + a1 + ], + { + BasicBlock(2): a2, + BasicBlock(1): a1, + } + ), + SubstituteVisitor.identity(a3, a4) + ), + ( + Branch(a := Condition(OperationType.equal, [])), + Branch(b := Condition(OperationType.not_equal, [])), + SubstituteVisitor.identity(a, b) + ), + ( + Return([a := Variable("a")]), + Return([b := Variable("b")]), + SubstituteVisitor.identity(a, b) + ), + ] +) +def test_substitute(initial_obj: DataflowObject, expected_result: DataflowObject, visitor: SubstituteVisitor): + result = initial_obj.accept(visitor) + if result is None: + result = initial_obj + + assert result == expected_result + + # if expected result is Phi also test for origin_block equality, as that is not covered by object equality + if isinstance(expected_result, Phi): + assert result.origin_block == expected_result.origin_block