diff --git a/decompiler/pipeline/controlflowanalysis/__init__.py b/decompiler/pipeline/controlflowanalysis/__init__.py index ebbc4b673..cf0a9593b 100644 --- a/decompiler/pipeline/controlflowanalysis/__init__.py +++ b/decompiler/pipeline/controlflowanalysis/__init__.py @@ -1,4 +1,5 @@ -from .expression_simplification import ExpressionSimplification +from .expression_simplification.stages import ExpressionSimplificationAst, ExpressionSimplificationCfg from .instruction_length_handler import InstructionLengthHandler +from .loop_name_generator import LoopNameGenerator from .readability_based_refinement import ReadabilityBasedRefinement from .variable_name_generation import VariableNameGeneration diff --git a/decompiler/pipeline/controlflowanalysis/expression_simplification.py b/decompiler/pipeline/controlflowanalysis/expression_simplification.py deleted file mode 100644 index 5064e58a3..000000000 --- a/decompiler/pipeline/controlflowanalysis/expression_simplification.py +++ /dev/null @@ -1,142 +0,0 @@ -"""Module implementing basic simplifications for expressions.""" -from typing import Optional - -from decompiler.pipeline.stage import PipelineStage -from decompiler.structures.ast.ast_nodes import CodeNode -from decompiler.structures.pseudo.expressions import Constant, Expression -from decompiler.structures.pseudo.instructions import Instruction -from decompiler.structures.pseudo.operations import BinaryOperation, Operation, OperationType, UnaryOperation -from decompiler.structures.pseudo.typing import Float, Integer -from decompiler.task import DecompilerTask - - -class ExpressionSimplification(PipelineStage): - """The ExpressionSimplification makes various simplifications to expressions on the AST, like a + 0 = a.""" - - name = "expression-simplification" - - def __init__(self): - self.HANDLERS = { - OperationType.plus: self._simplify_addition, - OperationType.minus: self._simplify_subtraction, - OperationType.multiply: self._simplify_multiplication, - 
OperationType.divide: self._simplify_division, - OperationType.divide_us: self._simplify_division, - OperationType.divide_float: self._simplify_division, - OperationType.dereference: self._simplify_dereference, - } - - def run(self, task: DecompilerTask): - """Run the task expression simplification on each instruction of the AST.""" - if task.syntax_tree is None: - for instruction in task.graph.instructions: - self.simplify(instruction) - else: - for node in task.syntax_tree.topological_order(): - if not isinstance(node, CodeNode): - continue - for instruction in node.instructions: - self.simplify(instruction) - - def simplify(self, instruction: Instruction): - """Simplify all subexpressions of the given instruction recursively.""" - todo = list(instruction) - while todo and (expression := todo.pop()): - if self.simplify_expression(expression, instruction): - todo = list(instruction) - else: - todo.extend(expression) - - def simplify_expression(self, expression: Expression, parent: Instruction) -> Optional[Expression]: - """Simplify the given instruction utilizing the registered OperationType handlers.""" - if isinstance(expression, Operation) and expression.operation in self.HANDLERS: - if simplified := self.HANDLERS[expression.operation](expression): - parent.substitute(expression, simplified) - return simplified - - def _simplify_addition(self, expression: BinaryOperation) -> Optional[Expression]: - """ - Simplifies the given addition in the given instruction. - - -> Simplifies a+0, 0+a, a-0 and -0 + a to a - """ - if any(self.is_zero_constant(zero := op) for op in expression.operands): - return self.get_other_operand(expression, zero).copy() - - def _simplify_subtraction(self, expression: BinaryOperation) -> Optional[Expression]: - """ - Simplifies the given subtraction in the given instruction. 
- - -> Simplifies a-0, a-(-0) to a - -> Simplifies 0-a, -0-a to -a - """ - if self.is_zero_constant(expression.operands[1]): - return expression.operands[0].copy() - if self.is_zero_constant(expression.operands[0]): - return self.negate_expression(expression.operands[1]) - - def _simplify_multiplication(self, expression: BinaryOperation) -> Optional[Expression]: - """ - Simplifies the given multiplication in the given instruction. - - -> Simplifies a*0, 0*a, a*(-0) and (-0) * a to 0 - -> Simplifies a*1, 1*a to a - -> Simplifies a*(-1), (-1)*a to -a - """ - if any(self.is_zero_constant(zero := op) for op in expression.operands): - return zero.copy() - if any(self.is_one_constant(one := op) for op in expression.operands): - return self.get_other_operand(expression, one).copy() - if any(self.is_minus_one_constant(minus_one := op) for op in expression.operands): - return self.negate_expression(self.get_other_operand(expression, minus_one)) - - def _simplify_division(self, expression: BinaryOperation) -> Optional[Expression]: - """ - Simplifies the given division in the given instruction. - - -> Simplifies a/1 to a and a/(-1) to -a - """ - if self.is_one_constant(expression.operands[1]): - return expression.operands[0].copy() - if self.is_minus_one_constant(expression.operands[1]): - return self.negate_expression(expression.operands[0]) - - def _simplify_dereference(self, expression: UnaryOperation) -> Optional[Expression]: - """ - Simplifies dereference expression with nested address-of expressions. 
- - -> Simplifies *(&(x)) to x - """ - if isinstance(expression.operand, UnaryOperation) and expression.operand.operation == OperationType.address: - return expression.operand.operand.copy() - - @staticmethod - def is_zero_constant(expression: Expression) -> bool: - """Checks whether the given expression is 0.""" - return isinstance(expression, Constant) and expression.value == 0 - - @staticmethod - def is_one_constant(expression: Expression) -> bool: - """Checks whether the given expression is 1.""" - return isinstance(expression, Constant) and expression.value == 1 - - @staticmethod - def is_minus_one_constant(expression: Expression) -> bool: - """Checks whether the given expression is -1.""" - return isinstance(expression, Constant) and expression.value == -1 - - @staticmethod - def negate_expression(expression: Expression) -> Expression: - """Negate the given expression and return it.""" - match expression: - case Constant(value=0): return expression - case UnaryOperation(operation=OperationType.negate): return expression.operand - case Constant(type=Integer(is_signed=True) | Float()): return Constant(-expression.value, expression.type) - case _: return UnaryOperation(OperationType.negate, [expression]) - - @staticmethod - def get_other_operand(binary_operation: BinaryOperation, expression: Expression) -> Expression: - """Returns the operand that is not equal to expression.""" - if binary_operation.operands[0] == expression: - return binary_operation.operands[1] - return binary_operation.operands[0] diff --git a/decompiler/pipeline/controlflowanalysis/expression_simplification/constant_folding.py b/decompiler/pipeline/controlflowanalysis/expression_simplification/constant_folding.py new file mode 100644 index 000000000..81c87498c --- /dev/null +++ b/decompiler/pipeline/controlflowanalysis/expression_simplification/constant_folding.py @@ -0,0 +1,145 @@ +import operator +from functools import partial +from typing import Callable, Optional + +from 
decompiler.structures.pseudo import Constant, Integer, OperationType + + +def constant_fold(operation: OperationType, constants: list[Constant]) -> Constant: + """ + Fold operation with constants as operands. + + :param operation: The operation. + :param constants: All constant operands of the operation. + :return: A constant representing the result of the operation. + """ + + if operation not in _OPERATION_TO_FOLD_FUNCTION: + raise ValueError(f"Constant folding not implemented for operation '{operation}'.") + + return _OPERATION_TO_FOLD_FUNCTION[operation](constants) + + +def _constant_fold_arithmetic_binary( + constants: list[Constant], + fun: Callable[[int, int], int], + norm_sign: Optional[bool] = None +) -> Constant: + """ + Fold an arithmetic binary operation with constants as operands. + + :param constants: A list of exactly 2 constant operands. + :param fun: The binary function to perform on the constants. + :param norm_sign: Optional boolean flag to indicate if/how to normalize the input constants to 'fun': + - None (default): no normalization + - True: normalize inputs, interpreted as signed values + - False: normalize inputs, interpreted as unsigned values + :return: A constant representing the result of the operation. 
+ """ + + if len(constants) != 2: + raise ValueError(f"Expected exactly 2 constants to fold, got {len(constants)}.") + if not all(constant.type == constants[0].type for constant in constants): + raise ValueError(f"Can not fold constants with different types: {(constant.type for constant in constants)}") + if not all(isinstance(constant.type, Integer) for constant in constants): + raise ValueError(f"All constants must be integers, got {list(constant.type for constant in constants)}.") + + left, right = constants + + left_value = left.value + right_value = right.value + if norm_sign is not None: + left_value = normalize_int(left_value, left.type.size, norm_sign) + right_value = normalize_int(right_value, right.type.size, norm_sign) + + return Constant( + normalize_int(fun(left_value, right_value), left.type.size, left.type.signed), + left.type + ) + + +def _constant_fold_arithmetic_unary(constants: list[Constant], fun: Callable[[int], int]) -> Constant: + """ + Fold an arithmetic unary operation with a constant operand. + + :param constants: A list containing a single constant operand. + :param fun: The unary function to perform on the constant. + :return: A constant representing the result of the operation. + """ + + if len(constants) != 1: + raise ValueError("Expected exactly 1 constant to fold") + if not isinstance(constants[0].type, Integer): + raise ValueError(f"Constant must be of type integer: {constants[0].type}") + + return Constant(normalize_int(fun(constants[0].value), constants[0].type.size, constants[0].type.signed), constants[0].type) + + +def _constant_fold_shift(constants: list[Constant], fun: Callable[[int, int], int], signed: bool) -> Constant: + """ + Fold a shift operation with constants as operands. + + :param constants: A list of exactly 2 constant operands. + :param fun: The shift function to perform on the constants. + :param signed: Boolean flag indicating whether the shift is signed. 
+ This is used to normalize the sign of the input constant to simulate unsigned shifts. + :return: A constant representing the result of the operation. + """ + + if len(constants) != 2: + raise ValueError("Expected exactly 2 constants to fold") + if not all(isinstance(constant.type, Integer) for constant in constants): + raise ValueError("All constants must be integers") + + left, right = constants + + shifted_value = fun( + normalize_int(left.value, left.type.size, left.type.signed and signed), + right.value + ) + return Constant( + normalize_int(shifted_value, left.type.size, left.type.signed), + left.type + ) + + +def normalize_int(v: int, size: int, signed: bool) -> int: + """ + Normalizes an integer value to a specific size and signedness. + + This function takes an integer value 'v' and normalizes it to fit within + the specified 'size' in bits by discarding overflowing bits. If 'signed' is + true, the value is treated as a signed integer, i.e. interpreted as a two's complement. + Therefore the return value will be negative iff 'signed' is true and the most-significant bit is set. + + :param v: The value to be normalized. + :param size: The desired bit size for the normalized integer. + :param signed: True if the integer should be treated as signed. + :return: The normalized integer value. 
+ """ + value = v & ((1 << size) - 1) + if signed and value & (1 << (size - 1)): + return value - (1 << size) + else: + return value + + +_OPERATION_TO_FOLD_FUNCTION: dict[OperationType, Callable[[list[Constant]], Constant]] = { + OperationType.minus: partial(_constant_fold_arithmetic_binary, fun=operator.sub), + OperationType.plus: partial(_constant_fold_arithmetic_binary, fun=operator.add), + OperationType.multiply: partial(_constant_fold_arithmetic_binary, fun=operator.mul, norm_sign=True), + OperationType.multiply_us: partial(_constant_fold_arithmetic_binary, fun=operator.mul, norm_sign=False), + OperationType.divide: partial(_constant_fold_arithmetic_binary, fun=operator.floordiv, norm_sign=True), + OperationType.divide_us: partial(_constant_fold_arithmetic_binary, fun=operator.floordiv, norm_sign=False), + OperationType.negate: partial(_constant_fold_arithmetic_unary, fun=operator.neg), + OperationType.left_shift: partial(_constant_fold_shift, fun=operator.lshift, signed=True), + OperationType.right_shift: partial(_constant_fold_shift, fun=operator.rshift, signed=True), + OperationType.right_shift_us: partial(_constant_fold_shift, fun=operator.rshift, signed=False), + OperationType.bitwise_or: partial(_constant_fold_arithmetic_binary, fun=operator.or_), + OperationType.bitwise_and: partial(_constant_fold_arithmetic_binary, fun=operator.and_), + OperationType.bitwise_xor: partial(_constant_fold_arithmetic_binary, fun=operator.xor), + OperationType.bitwise_not: partial(_constant_fold_arithmetic_unary, fun=operator.inv), +} + + +FOLDABLE_OPERATIONS = _OPERATION_TO_FOLD_FUNCTION.keys() diff --git a/decompiler/pipeline/controlflowanalysis/expression_simplification/rules/collapse_add_neg.py b/decompiler/pipeline/controlflowanalysis/expression_simplification/rules/collapse_add_neg.py new file mode 100644 index 000000000..c978991fc --- /dev/null +++ b/decompiler/pipeline/controlflowanalysis/expression_simplification/rules/collapse_add_neg.py @@ -0,0 +1,30 @@ +from 
decompiler.pipeline.controlflowanalysis.expression_simplification.rules.rule import SimplificationRule +from decompiler.structures.pseudo import BinaryOperation, Expression, Operation, OperationType, UnaryOperation + + +class CollapseAddNeg(SimplificationRule): + """ + Simplifies additions/subtraction with negated expression. + + - `e0 + -(e1) -> e0 - e1` + - `e0 - -(e1) -> e0 + e1` + """ + + def apply(self, operation: Operation) -> list[tuple[Expression, Expression]]: + if operation.operation not in [OperationType.plus, OperationType.minus]: + return [] + if not isinstance(operation, BinaryOperation): + raise TypeError(f"Expected BinaryOperation, got {type(operation)}") + + right = operation.right + if not isinstance(right, UnaryOperation) or right.operation != OperationType.negate: + return [] + + return [( + operation, + BinaryOperation( + OperationType.minus if operation.operation == OperationType.plus else OperationType.plus, + [operation.left, right.operand], + operation.type + ) + )] diff --git a/decompiler/pipeline/controlflowanalysis/expression_simplification/rules/collapse_constants.py b/decompiler/pipeline/controlflowanalysis/expression_simplification/rules/collapse_constants.py new file mode 100644 index 000000000..03db3c00d --- /dev/null +++ b/decompiler/pipeline/controlflowanalysis/expression_simplification/rules/collapse_constants.py @@ -0,0 +1,20 @@ +from decompiler.pipeline.controlflowanalysis.expression_simplification.constant_folding import FOLDABLE_OPERATIONS, constant_fold +from decompiler.pipeline.controlflowanalysis.expression_simplification.rules.rule import SimplificationRule +from decompiler.structures.pseudo import Constant, Expression, Operation + + +class CollapseConstants(SimplificationRule): + """ + Fold operations with only constants as operands: + """ + + def apply(self, operation: Operation) -> list[tuple[Expression, Expression]]: + if not all(isinstance(o, Constant) for o in operation.operands): + return [] + if 
operation.operation not in FOLDABLE_OPERATIONS: + return [] + + return [( + operation, + constant_fold(operation.operation, operation.operands) + )] diff --git a/decompiler/pipeline/controlflowanalysis/expression_simplification/rules/collapse_nested_constants.py b/decompiler/pipeline/controlflowanalysis/expression_simplification/rules/collapse_nested_constants.py new file mode 100644 index 000000000..f3559daae --- /dev/null +++ b/decompiler/pipeline/controlflowanalysis/expression_simplification/rules/collapse_nested_constants.py @@ -0,0 +1,81 @@ +from functools import reduce +from typing import Iterator + +from decompiler.pipeline.controlflowanalysis.expression_simplification.constant_folding import constant_fold +from decompiler.pipeline.controlflowanalysis.expression_simplification.rules.rule import SimplificationRule +from decompiler.structures.pseudo import Constant, Expression, Operation, OperationType, Type +from decompiler.structures.pseudo.operations import COMMUTATIVE_OPERATIONS + + +class CollapseNestedConstants(SimplificationRule): + """ + This rule walks the dafaflow tree and collects and folds constants in commutative operations. + The first constant of the tree is replaced with the folded result and all remaining constants are replaced with the identity. + This stage exploits associativity and is the only stage doing so. Therefore, it cannot be replaced by a combination of `TermOrder` and `CollapseConstants`. 
+ """ + def apply(self, operation: Operation) -> list[tuple[Expression, Expression]]: + if operation.operation not in COMMUTATIVE_OPERATIONS: + return [] + if not isinstance(operation, Operation): + raise TypeError(f"Expected Operation, got {type(operation)}") + + constants = list(_collect_constants(operation)) + if len(constants) <= 1: + return [] + + first, *rest = constants + + folded_constant = reduce( + lambda c0, c1: constant_fold(operation.operation, [c0, c1]), + rest, + first + ) + + identity_constant = _identity_constant(operation.operation, operation.type) + return [ + (first, folded_constant), + *((constant, identity_constant) for constant in rest) + ] + + +def _collect_constants(operation: Operation) -> Iterator[Constant]: + """ + Collects constants of potentially multiple nested commutative operations of the same type. + + This function traverses the subtree rooted at the provided operation and collects + all constants that belong to operations with the same operation type as the root operation. + The subtree includes only operations that have matching operation types. + """ + + operation_type = operation.operation + operand_type = operation.type + + context_stack: list[Operation] = [operation] + while context_stack: + current_operation = context_stack.pop() + + for i, operand in enumerate(current_operation.operands): + if operand.type != operand_type: + continue + + if isinstance(operand, Operation): + if operand.operation == operation_type: + context_stack.append(operand) + continue + elif isinstance(operand, Constant) and _identity_constant(operation_type, operand_type).value != operand.value: + yield operand + + +def _identity_constant(operation: OperationType, var_type: Type) -> Constant: + """ + Return a const containing the identity element for the specified operation and variable type. 
+ """ + match operation: + case OperationType.plus | OperationType.bitwise_xor | OperationType.bitwise_or: + return Constant(0, var_type) + case OperationType.multiply | OperationType.multiply_us: + return Constant(1, var_type) + case OperationType.bitwise_and: + return constant_fold(OperationType.bitwise_not, [Constant(0, var_type)]) + case _: + raise NotImplementedError() diff --git a/decompiler/pipeline/controlflowanalysis/expression_simplification/rules/positive_constants.py b/decompiler/pipeline/controlflowanalysis/expression_simplification/rules/positive_constants.py new file mode 100644 index 000000000..6b358dd81 --- /dev/null +++ b/decompiler/pipeline/controlflowanalysis/expression_simplification/rules/positive_constants.py @@ -0,0 +1,43 @@ +from decompiler.pipeline.controlflowanalysis.expression_simplification.constant_folding import normalize_int +from decompiler.pipeline.controlflowanalysis.expression_simplification.rules.rule import SimplificationRule +from decompiler.structures.pseudo import BinaryOperation, Constant, Expression, Integer, Operation, OperationType + + +class PositiveConstants(SimplificationRule): + """ + Changes add/sub so that the right operand constant is always positive. + For unsigned arithmetic, choose the operation with the lesser constant (e.g.: V - 4294967293 -> V + 3 for 32 bit ints). 
+ + - `V - a -> E + (-a)` when signed(a) < 0 + - `V + a -> E - (-a)` when signed(a) < 0 + """ + + def apply(self, operation: Operation) -> list[tuple[Expression, Expression]]: + if operation.operation not in (OperationType.plus, OperationType.minus): + return [] + if not isinstance(operation, BinaryOperation): + raise TypeError(f"Expected BinaryOperation, got {type(operation)}") + + right = operation.right + if not isinstance(right, Constant): + return [] + + constant_type = right.type + if not isinstance(constant_type, Integer): + return [] + + signed_normalized_constant = normalize_int(right.value, constant_type.size, True) + if signed_normalized_constant >= 0: + return [] + + neg_constant = Constant( + normalize_int(-signed_normalized_constant, constant_type.size, constant_type.signed), + constant_type + ) + return [( + operation, + BinaryOperation( + OperationType.plus if operation.operation == OperationType.minus else OperationType.minus, + [operation.left, neg_constant] + ) + )] diff --git a/decompiler/pipeline/controlflowanalysis/expression_simplification/rules/rule.py b/decompiler/pipeline/controlflowanalysis/expression_simplification/rules/rule.py new file mode 100644 index 000000000..a4f70334a --- /dev/null +++ b/decompiler/pipeline/controlflowanalysis/expression_simplification/rules/rule.py @@ -0,0 +1,20 @@ +from abc import ABC, abstractmethod + +from decompiler.structures.pseudo import Expression, Operation + + +class SimplificationRule(ABC): + """ + This class defines the interface for simplification rules that can be applied to expressions. + """ + + @abstractmethod + def apply(self, operation: Operation) -> list[tuple[Expression, Expression]]: + """ + Apply the simplification rule to the given operation. + + :param operation: The operation to which the simplification rule should be applied. 
+ :return: A list of tuples, each containing a pair of expressions representing the original + and simplified versions resulting from applying the simplification rule to the given operation. + """ + pass diff --git a/decompiler/pipeline/controlflowanalysis/expression_simplification/rules/simplify_redundant_reference.py b/decompiler/pipeline/controlflowanalysis/expression_simplification/rules/simplify_redundant_reference.py new file mode 100644 index 000000000..9337be30b --- /dev/null +++ b/decompiler/pipeline/controlflowanalysis/expression_simplification/rules/simplify_redundant_reference.py @@ -0,0 +1,20 @@ +from decompiler.pipeline.controlflowanalysis.expression_simplification.rules.rule import SimplificationRule +from decompiler.structures.pseudo import Expression, Operation, OperationType, UnaryOperation + + +class SimplifyRedundantReference(SimplificationRule): + """ + Removes redundant nesting of referencing, immediately followed by referencing. + + `*(&(e0)) -> e0` + """ + + def apply(self, operation: Operation) -> list[tuple[Expression, Expression]]: + match operation: + case UnaryOperation( + operation=OperationType.dereference, + operand=UnaryOperation(operation=OperationType.address, operand=inner_operand) + ): + return [(operation, inner_operand)] + case _: + return [] diff --git a/decompiler/pipeline/controlflowanalysis/expression_simplification/rules/simplify_trivial_arithmetic.py b/decompiler/pipeline/controlflowanalysis/expression_simplification/rules/simplify_trivial_arithmetic.py new file mode 100644 index 000000000..07ebac6c2 --- /dev/null +++ b/decompiler/pipeline/controlflowanalysis/expression_simplification/rules/simplify_trivial_arithmetic.py @@ -0,0 +1,39 @@ +from decompiler.pipeline.controlflowanalysis.expression_simplification.rules.rule import SimplificationRule +from decompiler.structures.pseudo import BinaryOperation, Constant, Expression, Operation, OperationType, UnaryOperation + + +class SimplifyTrivialArithmetic(SimplificationRule): 
+ """ + Simplifies trivial arithmetic: + + - `e + 0 -> e` + - `e - 0 -> e` + - `e * 0 -> 0` + - `e u* 0 -> 0` + - `e * 1 -> e` + - `e u* 1 -> e` + - `e * -1 -> -e` + - `e u* -1 -> -e` + - `e / 1 -> e` + - `e u/ 1 -> e` + - `e / -1 -> -e` + """ + + def apply(self, operation: Operation) -> list[tuple[Expression, Expression]]: + match operation: + case BinaryOperation(operation=OperationType.plus | OperationType.minus, right=Constant(value=0)): + return [(operation, operation.left)] + case BinaryOperation( + operation=OperationType.multiply | OperationType.multiply_us | OperationType.divide | OperationType.divide_us, + right=Constant(value=1), + ): + return [(operation, operation.left)] + case BinaryOperation(operation=OperationType.multiply | OperationType.multiply_us, right=Constant(value=0)): + return [(operation, Constant(0, operation.type))] + case BinaryOperation( + operation=OperationType.multiply | OperationType.multiply_us | OperationType.divide, + right=Constant(value=-1) + ): + return [(operation, UnaryOperation(OperationType.negate, [operation.left]))] + case _: + return [] diff --git a/decompiler/pipeline/controlflowanalysis/expression_simplification/rules/simplify_trivial_bit_arithmetic.py b/decompiler/pipeline/controlflowanalysis/expression_simplification/rules/simplify_trivial_bit_arithmetic.py new file mode 100644 index 000000000..5a6a62c58 --- /dev/null +++ b/decompiler/pipeline/controlflowanalysis/expression_simplification/rules/simplify_trivial_bit_arithmetic.py @@ -0,0 +1,28 @@ +from decompiler.pipeline.controlflowanalysis.expression_simplification.rules.rule import SimplificationRule +from decompiler.structures.pseudo import BinaryOperation, Constant, Expression, Operation, OperationType + + +class SimplifyTrivialBitArithmetic(SimplificationRule): + """ + Simplifies trivial bit arithmetic: + + - `e | 0 -> e` + - `e | e -> e` + - `e & 0 -> 0` + - `e & e -> e` + - `e ^ 0 -> e` + - `e ^ e -> 0` + """ + + def apply(self, operation: Operation) -> 
list[tuple[Expression, Expression]]: + match operation: + case BinaryOperation(operation=OperationType.bitwise_or | OperationType.bitwise_xor, right=Constant(value=0)): + return [(operation, operation.left)] + case BinaryOperation(operation=OperationType.bitwise_and, right=Constant(value=0)): + return [(operation, Constant(0, operation.type))] + case BinaryOperation(operation=OperationType.bitwise_or | OperationType.bitwise_and, left=left, right=right) if left == right: + return [(operation, operation.left)] + case BinaryOperation(operation=OperationType.bitwise_xor, left=left, right=right) if left == right: + return [(operation, Constant(0, operation.type))] + case _: + return [] diff --git a/decompiler/pipeline/controlflowanalysis/expression_simplification/rules/simplify_trivial_logic_arithmetic.py b/decompiler/pipeline/controlflowanalysis/expression_simplification/rules/simplify_trivial_logic_arithmetic.py new file mode 100644 index 000000000..08994b1cc --- /dev/null +++ b/decompiler/pipeline/controlflowanalysis/expression_simplification/rules/simplify_trivial_logic_arithmetic.py @@ -0,0 +1,26 @@ +from decompiler.pipeline.controlflowanalysis.expression_simplification.rules.rule import SimplificationRule +from decompiler.structures.pseudo import BinaryOperation, Constant, Expression, Operation, OperationType + + +class SimplifyTrivialLogicArithmetic(SimplificationRule): + """ + Simplifies trivial logic arithmetic. 
+ + - `e || false -> e` + - `e || true -> true` + - `e && false -> false` + - `e && true -> e` + """ + + def apply(self, operation: Operation) -> list[tuple[Expression, Expression]]: + match operation: + case BinaryOperation(operation=OperationType.logical_or, right=Constant(value=0)): + return [(operation, operation.left)] + case BinaryOperation(operation=OperationType.logical_and, right=Constant(value=0)): + return [(operation, Constant(0, operation.type))] + case BinaryOperation(operation=OperationType.logical_or, right=Constant(value=value)) if value != 0: + return [(operation, Constant(1, operation.type))] + case BinaryOperation(operation=OperationType.logical_and, right=Constant(value=value)) if value != 0: + return [(operation, operation.left)] + case _: + return [] diff --git a/decompiler/pipeline/controlflowanalysis/expression_simplification/rules/simplify_trivial_shift.py b/decompiler/pipeline/controlflowanalysis/expression_simplification/rules/simplify_trivial_shift.py new file mode 100644 index 000000000..fcf1fcfa9 --- /dev/null +++ b/decompiler/pipeline/controlflowanalysis/expression_simplification/rules/simplify_trivial_shift.py @@ -0,0 +1,28 @@ +from decompiler.pipeline.controlflowanalysis.expression_simplification.rules.rule import SimplificationRule +from decompiler.structures.pseudo import BinaryOperation, Constant, Expression, Operation, OperationType + + +class SimplifyTrivialShift(SimplificationRule): + """ + Simplifies trivial shift/rotate arithmetic: + + - `e << 0 -> e` + - `e >> 0 -> e` + - `e u>> 0 -> e` + - `e lrot 0 -> e` + - `e rrot 0 -> e` + """ + + def apply(self, operation: Operation) -> list[tuple[Expression, Expression]]: + match operation: + case BinaryOperation( + operation=OperationType.left_shift + | OperationType.right_shift + | OperationType.right_shift_us + | OperationType.left_rotate + | OperationType.right_rotate, + right=Constant(value=0), + ): + return [(operation, operation.left)] + case _: + return [] diff --git 
a/decompiler/pipeline/controlflowanalysis/expression_simplification/rules/sub_to_add.py b/decompiler/pipeline/controlflowanalysis/expression_simplification/rules/sub_to_add.py new file mode 100644 index 000000000..9943d365d --- /dev/null +++ b/decompiler/pipeline/controlflowanalysis/expression_simplification/rules/sub_to_add.py @@ -0,0 +1,27 @@ +from decompiler.pipeline.controlflowanalysis.expression_simplification.rules.rule import SimplificationRule +from decompiler.structures.pseudo import BinaryOperation, Expression, Operation, OperationType, UnaryOperation + + +class SubToAdd(SimplificationRule): + """ + Replace subtractions with additions. + + `e0 - e1 -> e0 + (-e1)` + """ + + def apply(self, operation: Operation) -> list[tuple[Expression, Expression]]: + if operation.operation != OperationType.minus: + return [] + if not isinstance(operation, BinaryOperation): + raise TypeError(f"Expected BinaryOperation, got {type(operation)}") + + neg_op = UnaryOperation(OperationType.negate, [operation.right]) + + return [( + operation, + BinaryOperation( + OperationType.plus, + [operation.left, neg_op], + operation.type + ) + )] diff --git a/decompiler/pipeline/controlflowanalysis/expression_simplification/rules/term_order.py b/decompiler/pipeline/controlflowanalysis/expression_simplification/rules/term_order.py new file mode 100644 index 000000000..638f4d08a --- /dev/null +++ b/decompiler/pipeline/controlflowanalysis/expression_simplification/rules/term_order.py @@ -0,0 +1,28 @@ +from decompiler.pipeline.controlflowanalysis.expression_simplification.rules.rule import SimplificationRule +from decompiler.structures.pseudo import BinaryOperation, Constant, Expression, Operation +from decompiler.structures.pseudo.operations import COMMUTATIVE_OPERATIONS + + +class TermOrder(SimplificationRule): + """ + Swap constants of commutative operations to the right. + This stage is important because other stages expect constants to be on the right side. 
+ Associativity is not exploited, i.e. nested operations of the same type are not considered. + + - `c + e -> e + c` + - `c * e -> e * c` + - `c & e -> e & c` + - `c | e -> e | c` + - `c ^ e -> e ^ c` + """ + + def apply(self, operation: Operation) -> list[tuple[Expression, Expression]]: + if operation.operation not in COMMUTATIVE_OPERATIONS: + return [] + if not isinstance(operation, BinaryOperation): + raise ValueError(f"Expected BinaryOperation, got {operation}") + + if isinstance(operation.left, Constant) and not isinstance(operation.right, Constant): + return [(operation, BinaryOperation(operation.operation, [operation.right, operation.left], operation.type, operation.tags))] + else: + return [] diff --git a/decompiler/pipeline/controlflowanalysis/expression_simplification/stages.py b/decompiler/pipeline/controlflowanalysis/expression_simplification/stages.py new file mode 100644 index 000000000..cc05a59ba --- /dev/null +++ b/decompiler/pipeline/controlflowanalysis/expression_simplification/stages.py @@ -0,0 +1,169 @@ +import logging +from abc import ABC, abstractmethod + +from decompiler.backend.cexpressiongenerator import CExpressionGenerator +from decompiler.pipeline.controlflowanalysis.expression_simplification.rules.collapse_add_neg import CollapseAddNeg +from decompiler.pipeline.controlflowanalysis.expression_simplification.rules.collapse_constants import CollapseConstants +from decompiler.pipeline.controlflowanalysis.expression_simplification.rules.collapse_nested_constants import CollapseNestedConstants +from decompiler.pipeline.controlflowanalysis.expression_simplification.rules.positive_constants import PositiveConstants +from decompiler.pipeline.controlflowanalysis.expression_simplification.rules.rule import SimplificationRule +from decompiler.pipeline.controlflowanalysis.expression_simplification.rules.simplify_redundant_reference import SimplifyRedundantReference +from 
class _ExpressionSimplificationBase(PipelineStage, ABC):
    """
    Shared driver of the expression simplification stages.

    Subclasses only select the instructions to work on (`_get_instructions`). This class applies the
    module-level rule sets `_pre_rules`, `_rules` and `_post_rules` to them, one rule set after the other.
    """

    def run(self, task: DecompilerTask):
        """Read the iteration limit from the task options and simplify the instructions chosen by the subclass."""
        max_iterations = task.options.getint("expression-simplification.max_iterations")
        self._simplify_instructions(self._get_instructions(task), max_iterations)

    @abstractmethod
    def _get_instructions(self, task: DecompilerTask) -> list[Instruction]:
        """Return the instructions of the given task that should be simplified."""
        pass

    @classmethod
    def _simplify_instructions(cls, instructions: list[Instruction], max_iterations: int):
        """
        Apply the pre-rules, the rules and the post-rules to the instructions, in that order.

        :param instructions: instructions that are simplified in place
        :param max_iterations: iteration budget, granted to each rule set separately
        """
        rule_sets = [
            ("pre-rules", _pre_rules),
            ("rules", _rules),
            ("post-rules", _post_rules)
        ]
        for rule_name, rule_set in rule_sets:
            # max_iterations is counted per rule_set
            iteration_count = cls._simplify_instructions_with_rule_set(instructions, rule_set, max_iterations)
            if iteration_count <= max_iterations:
                logging.info(f"Expression simplification took {iteration_count} iterations for {rule_name}")
            else:
                logging.warning(f"Exceeded max iteration count for {rule_name}")

    @classmethod
    def _simplify_instructions_with_rule_set(
            cls,
            instructions: list[Instruction],
            rule_set: list[SimplificationRule],
            max_iterations: int
    ) -> int:
        """
        Apply every rule of the rule set to every instruction until a full pass changes nothing.

        :param instructions: instructions that are simplified in place
        :param rule_set: rules to apply
        :param max_iterations: iteration budget for this rule set
        :return: number of iterations spent; a value greater than max_iterations means the budget was exceeded
        """
        iteration_count = 0

        changes = True
        while changes:
            changes = False

            for rule in rule_set:
                for instruction in instructions:
                    # Hand down only what is left of the budget.
                    additional_iterations = cls._simplify_instruction_with_rule(instruction, rule, max_iterations - iteration_count)
                    if additional_iterations > 0:
                        changes = True

                    iteration_count += additional_iterations
                    if iteration_count > max_iterations:
                        return iteration_count

        return iteration_count

    @classmethod
    def _simplify_instruction_with_rule(
            cls,
            instruction: Instruction,
            rule: SimplificationRule,
            max_iterations: int
    ) -> int:
        """
        Apply one rule to all subexpressions of one instruction, re-applying it to an expression while it still matches.

        :param instruction: instruction whose subexpressions are simplified in place
        :param rule: rule to apply
        :param max_iterations: remaining iteration budget
        :return: number of successful rule applications; greater than max_iterations if the budget was exceeded
        """
        iteration_count = 0
        for expression in instruction.subexpressions():
            while True:
                # NOTE: By breaking out of the inner endless loop, we just continue with the next subexpression.
                if expression is None:
                    break
                if not isinstance(expression, Operation):
                    break

                substitutions = rule.apply(expression)
                if not substitutions:
                    break

                iteration_count += 1

                if iteration_count > max_iterations:
                    logging.warning("Took too many iterations for rule set to finish")
                    return iteration_count

                for i, (replacee, replacement) in enumerate(substitutions):
                    expression_gen = CExpressionGenerator()
                    logging.debug(
                        f"[{rule.__class__.__name__}] {i}. Substituting: '{replacee.accept(expression_gen)}'"
                        f" with '{replacement.accept(expression_gen)}' in '{expression.accept(expression_gen)}'"
                    )
                    instruction_repl = instruction.accept(SubstituteVisitor.identity(replacee, replacement))

                    # instruction.accept should never replace itself and consequently return none, because replacee is
                    # of type Expression and doesn't permit Instruction
                    assert instruction_repl is None

                    # This is modifying the expression tree, while we are iterating over it.
                    # This works because we are iterating depth first and only
                    # modifying already visited nodes.

                    # if expression got replaced, we need to update the reference
                    if replacee == expression:
                        expression = replacement

        return iteration_count
class WhileLoopVariableRenamer:
    """Walk over the while-loops of a syntax tree and name their counter variables counter, counter1, ..."""

    def __init__(self, ast: AbstractSyntaxTree):
        self._ast = ast
        self._variable_counter: int = 0

    def rename(self):
        """
        Rename at most one counter variable per non-endless while-loop to counter, counter1, ...

        A variable x is only renamed if all of the following hold:
            -> x is used in the loop condition
            -> x is set inside the loop body
            -> a single definition of x reaches the loop entry (x is initialized/used only once)
        """
        for while_loop in self._ast.get_while_loop_nodes_topological_order():
            if while_loop.is_endless_loop:
                continue
            initialisation = self._find_renamable_initialisation(while_loop)
            if initialisation:
                self._replace_variables(while_loop, initialisation)

    def _find_renamable_initialisation(self, while_loop: LoopNode):
        """Return the initialisation of the first condition variable that qualifies for renaming, or None."""
        for candidate in while_loop.get_required_variables(self._ast.condition_map):
            initialisation = _get_variable_initialisation(self._ast, candidate)
            if (
                initialisation
                and _find_continuation_instruction(self._ast, while_loop, candidate, renaming=True)
                and _single_defininition_reaches_node(self._ast, initialisation, while_loop)
            ):
                return initialisation
        return None

    def _replace_variables(self, loop_node: LoopNode, variable_init: AstInstruction):
        """
        Switch all usages of the old variable over to a fresh counter variable in:
            - the loop body
            - the condition/condition map
            - the variable initialization
        If the old variable is still required after the loop without being reinitialized,
        an instruction copying the counter back into it is added behind the loop.
        """
        # The destination is looked up anew for every call instead of being cached, exactly
        # like before. NOTE(review): the replacements below may touch the initialisation
        # itself, so caching it could change behavior - confirm before simplifying.
        counter_variable = Variable(self._get_variable_name(), variable_init.instruction.destination.type)
        self._ast.replace_variable_in_subtree(loop_node, variable_init.instruction.destination, counter_variable)
        needed_after_loop = _requirement_without_reinitialization(self._ast, loop_node, variable_init.instruction.destination)
        if needed_after_loop:
            copy_back = Assignment(variable_init.instruction.destination, counter_variable)
            self._ast.add_instructions_after(loop_node, copy_back)
        variable_init.node.replace_variable(variable_init.instruction.destination, counter_variable)

    def _get_variable_name(self) -> str:
        """Return the next unused name: counter, counter1, counter2, ..."""
        suffix = str(self._variable_counter) if self._variable_counter > 0 else ""
        self._variable_counter += 1
        return f"counter{suffix}"
class LoopNameGenerator(PipelineStage):
    """
    Stage which renames the variables of while- and for-loops to custom names.

    Options read from the task:
        - loop-name-generator.rename_while_loop_variables: rename while-loop counters (disabled if not set)
        - loop-name-generator.for_loop_variable_names: candidate names for for-loop variables (no renaming if empty)
    """

    name = "loop-name-generator"

    def run(self, task: DecompilerTask):
        """Rename while-loop variables first and for-loop variables afterwards, each only if enabled via the options."""
        rename_while_loops: bool = task.options.getboolean("loop-name-generator.rename_while_loop_variables", fallback=False)
        for_loop_names: List[str] = task.options.getlist("loop-name-generator.for_loop_variable_names", fallback=[])

        # Use the public syntax_tree accessor like the other pipeline stages do,
        # instead of reaching into the private task._ast attribute.
        if rename_while_loops:
            WhileLoopVariableRenamer(task.syntax_tree).rename()

        if for_loop_names:
            ForLoopVariableRenamer(task.syntax_tree, for_loop_names).rename()
+ + :param loop_node: LoopNode with a body + :return: True if body contains only one instruction else False + """ + body: AbstractSyntaxTreeNode = loop_node.body + if isinstance(body, CodeNode): + return len(body.instructions) == 1 + if isinstance(body, LoopNode): + return _is_single_instruction_loop_node(body) + if isinstance(body, (SeqNode, SwitchNode)): + return False + return False + + +def _has_deep_requirement(condition_map: Dict[LogicCondition, Condition], node: AbstractSyntaxTreeNode, variable: Variable) -> bool: + """ + Check if a variable is required in a node or any of its children. + + :param condition_map: logic condition to condition mapping + :param node: start node + :param variable: requirement to search for + :return: True if a requirement was found, else False + """ + if node is None: + return False + + if variable in node.get_required_variables(condition_map): + return True + + if isinstance(node, (SeqNode, SwitchNode, CaseNode)): + return any([_has_deep_requirement(condition_map, child, variable) for child in node.children]) + elif isinstance(node, ConditionNode): + return any( + [ + _has_deep_requirement(condition_map, node.true_branch_child, variable), + _has_deep_requirement(condition_map, node.false_branch_child, variable), + ] + ) + elif isinstance(node, LoopNode): + return _has_deep_requirement(condition_map, node.body, variable) + + +def _get_last_definition_index_of(node: CodeNode, variable: Variable) -> int: + """ + Iterate over CodeNode returning the index of last assignment to variable. 
+ + :param node: node in which to search for last definition of variable + :param variable: check if definition contains this variable + :return: index of last definition or -1 if not found + """ + candidate = -1 + for position, instr in enumerate(node.instructions): + if variable in instr.definitions: + candidate = position + return candidate + + +def _get_last_requirement_index_of(node: CodeNode, variable: Variable) -> int: + """ + Iterate over CodeNode returning the index of last instruction using variable. + + :param node: node in which to search for last requirement of variable + :param variable: check if requirements contains this variable + :return: index of last definition or -1 if not found + """ + candidate = -1 + for position, instr in enumerate(node.instructions): + if variable in instr.requirements: + candidate = position + return candidate + + +def _find_continuation_instruction( + ast: AbstractSyntaxTree, node: AbstractSyntaxTreeNode, variable: Variable, renaming: bool = False +) -> Optional[AstInstruction]: + """ + Find a valid continuation instruction for a given variable inside a node. A valid continuation instruction defines the variable without + having requirements in later instructions. + + If we only want to rename the continuation instruction (instead of converting a while to a for-loop) we can additionally look at + switch / case nodes. 
+ + :param node: node in which to search for last definition + :param variable: search instruction defining variable + :param renaming: continuation assignment for renaming purposes only + :return: AstInstruction if a definition without later requirement was found, else None + """ + iter_types = (SeqNode, SwitchNode) if renaming else SeqNode + if isinstance(node, iter_types): + for child in node.children[::-1]: + if instruction := _find_continuation_instruction(ast, child, variable, renaming): + return instruction + elif _has_deep_requirement(ast.condition_map, child, variable): + return None + elif renaming and isinstance(node, CaseNode): + return _find_continuation_instruction(ast, node.child, variable, renaming) + elif isinstance(node, LoopNode): + return _find_continuation_instruction(ast, node.body, variable, renaming) + elif isinstance(node, CodeNode): + last_req_index = _get_last_requirement_index_of(node, variable) + last_def_index = _get_last_definition_index_of(node, variable) + if last_req_index <= last_def_index != -1: + return AstInstruction(node.instructions[last_def_index], last_def_index, node) + + +def _get_variable_initialisation(ast: AbstractSyntaxTree, variable: Variable) -> Optional[AstInstruction]: + """ + Iterates over CodeNodes returning the first definition of variable. + + :param ast: AbstractSyntaxTree to search in + :param variable: find initialization of this variable + """ + for code_node in ast.get_code_nodes_topological_order(): + for position, instruction in enumerate(code_node.instructions): + if variable in instruction.definitions: + return AstInstruction(instruction, position, code_node) + + +def _single_defininition_reaches_node(ast: AbstractSyntaxTree, variable_init: AstInstruction, target_node: AbstractSyntaxTreeNode) -> bool: + """ + Check if a variable initialisation is redefined or used before target node. + + If we did not find the target node on the way down we still can assume there was no redefinition or usage. 
def _initialization_reaches_loop_node(init_node: AbstractSyntaxTreeNode, usage_node: AbstractSyntaxTreeNode) -> bool:
    """
    Check if init node always reaches the usage node

    This is not the case if:
        - nodes are separated by a LoopNode
        - init-nodes parent is not a sequence node or not on a path from root to usage-node (only initialized under certain conditions)

    :param init_node: node where initialization takes place
    :param usage_node: node that is potentially inside a LoopNode
    :return: True if the parent of init node is a sequence node that is reached by walking up from usage node
             without passing a LoopNode, else False
    """
    init_parent = init_node.parent
    if not isinstance(init_parent, SeqNode):
        return False

    # Walk up from the usage node until the parent of the init node is found.
    ancestor = usage_node.parent
    while ancestor is not init_parent:
        # Running past the root (also when usage_node itself has no parent) means init_parent is not
        # on the path; checking for None before touching .parent avoids an AttributeError in that case.
        if ancestor is None or isinstance(ancestor, LoopNode):
            return False
        ancestor = ancestor.parent
    return True
+ Edge case: definition and requirement in same instruction + + :param ast: + :param node: + :param variable: + :return: True if has requirement that is not prior reinitialized else False + """ + + for ast_node in ast.get_reachable_nodes_pre_order(node): + assignment_visitor = AssignmentVisitor() + assignment_visitor.visit(ast_node) + for assignment in assignment_visitor.assignments: + if variable in assignment.definitions and variable not in assignment.requirements: + return False + elif variable in assignment.definitions and variable in assignment.requirements: + return True + elif variable in assignment.requirements: + return True \ No newline at end of file diff --git a/decompiler/pipeline/controlflowanalysis/readability_based_refinement.py b/decompiler/pipeline/controlflowanalysis/readability_based_refinement.py index 048d4c098..7a334fd44 100644 --- a/decompiler/pipeline/controlflowanalysis/readability_based_refinement.py +++ b/decompiler/pipeline/controlflowanalysis/readability_based_refinement.py @@ -1,225 +1,24 @@ -"""Module implementing various readbility based refinements.""" +"""Module implementing various readability based refinements.""" from __future__ import annotations -from dataclasses import dataclass -from typing import Dict, Optional, Union +from typing import Union -from decompiler.pipeline.stage import PipelineStage -from decompiler.structures.ast.ast_nodes import ( - AbstractSyntaxTreeNode, - CaseNode, - CodeNode, - ConditionNode, - DoWhileLoopNode, - ForLoopNode, - LoopNode, - SeqNode, - SwitchNode, - WhileLoopNode, +from decompiler.pipeline.controlflowanalysis.loop_utility_methods import ( + AstInstruction, + _find_continuation_instruction, + _get_variable_initialisation, + _initialization_reaches_loop_node, + _is_single_instruction_loop_node, + _single_defininition_reaches_node, ) +from decompiler.pipeline.stage import PipelineStage +from decompiler.structures.ast.ast_nodes import ConditionNode, DoWhileLoopNode, ForLoopNode, WhileLoopNode 
from decompiler.structures.ast.syntaxtree import AbstractSyntaxTree -from decompiler.structures.logic.logic_condition import LogicCondition -from decompiler.structures.pseudo import Assignment, Condition, Expression, Operation, Variable -from decompiler.structures.visitors.assignment_visitor import AssignmentVisitor +from decompiler.structures.pseudo import Assignment from decompiler.task import DecompilerTask from decompiler.util.options import Options -def _is_single_instruction_loop_node(loop_node: LoopNode) -> bool: - """ - Check if the loop body contains only one instruction. - - :param loop_node: LoopNode with a body - :return: True if body contains only one instruction else False - """ - body: AbstractSyntaxTreeNode = loop_node.body - if isinstance(body, CodeNode): - return len(body.instructions) == 1 - if isinstance(body, LoopNode): - return _is_single_instruction_loop_node(body) - if isinstance(body, (SeqNode, SwitchNode)): - return False - return False - - -def _has_deep_requirement(condition_map: Dict[LogicCondition, Condition], node: AbstractSyntaxTreeNode, variable: Variable) -> bool: - """ - Check if a variable is required in a node or any of its children. 
- - :param condition_map: logic condition to condition mapping - :param node: start node - :param variable: requirement to search for - :return: True if a requirement was found, else False - """ - if node is None: - return False - - if variable in node.get_required_variables(condition_map): - return True - - if isinstance(node, (SeqNode, SwitchNode, CaseNode)): - return any([_has_deep_requirement(condition_map, child, variable) for child in node.children]) - elif isinstance(node, ConditionNode): - return any( - [ - _has_deep_requirement(condition_map, node.true_branch_child, variable), - _has_deep_requirement(condition_map, node.false_branch_child, variable), - ] - ) - elif isinstance(node, LoopNode): - return _has_deep_requirement(condition_map, node.body, variable) - - -def _get_last_definition_index_of(node: CodeNode, variable: Variable) -> int: - """ - Iterate over CodeNode returning the index of last assignment to variable. - - :param node: node in which to search for last definition of variable - :param variable: check if definition contains this variable - :return: index of last definition or -1 if not found - """ - candidate = -1 - for position, instr in enumerate(node.instructions): - if variable in instr.definitions: - candidate = position - return candidate - - -def _get_last_requirement_index_of(node: CodeNode, variable: Variable) -> int: - """ - Iterate over CodeNode returning the index of last instruction using variable. 
- - :param node: node in which to search for last requirement of variable - :param variable: check if requirements contains this variable - :return: index of last definition or -1 if not found - """ - candidate = -1 - for position, instr in enumerate(node.instructions): - if variable in instr.requirements: - candidate = position - return candidate - - -def _find_continuation_instruction( - ast: AbstractSyntaxTree, node: AbstractSyntaxTreeNode, variable: Variable, renaming: bool = False -) -> Optional[AstInstruction]: - """ - Find a valid continuation instruction for a given variable inside a node. A valid continuation instruction defines the variable without - having requirements in later instructions. - - If we only want to rename the continuation instruction (instead of converting a while to a for-loop) we can additionally look at - switch / case nodes. - - :param node: node in which to search for last definition - :param variable: search instruction defining variable - :param renaming: continuation assignment for renaming purposes only - :return: AstInstruction if a definition without later requirement was found, else None - """ - iter_types = (SeqNode, SwitchNode) if renaming else SeqNode - if isinstance(node, iter_types): - for child in node.children[::-1]: - if instruction := _find_continuation_instruction(ast, child, variable, renaming): - return instruction - elif _has_deep_requirement(ast.condition_map, child, variable): - return None - elif renaming and isinstance(node, CaseNode): - return _find_continuation_instruction(ast, node.child, variable, renaming) - elif isinstance(node, LoopNode): - return _find_continuation_instruction(ast, node.body, variable, renaming) - elif isinstance(node, CodeNode): - last_req_index = _get_last_requirement_index_of(node, variable) - last_def_index = _get_last_definition_index_of(node, variable) - if last_req_index <= last_def_index != -1: - return AstInstruction(node.instructions[last_def_index], last_def_index, node) - - 
-def _get_variable_initialisation(ast: AbstractSyntaxTree, variable: Variable) -> Optional[AstInstruction]: - """ - Iterates over CodeNodes returning the first definition of variable. - - :param ast: AbstractSyntaxTree to search in - :param variable: find initialization of this variable - """ - for code_node in ast.get_code_nodes_topological_order(): - for position, instruction in enumerate(code_node.instructions): - if variable in instruction.definitions: - return AstInstruction(instruction, position, code_node) - - -def _single_defininition_reaches_node(ast: AbstractSyntaxTree, variable_init: AstInstruction, target_node: AbstractSyntaxTreeNode) -> bool: - """ - Check if a variable initialisation is redefined or used before target node. - - If we did not find the target node on the way down we still can assume there was no redefinition or usage. - - :param ast: AbstractSyntaxTree to search in - :param variable_init: AstInstruction containing the first variable initialisation - :param target_node: Search for redefinition or usages until this node is reached - """ - for ast_node in ast.get_reachable_nodes_pre_order(variable_init.node): - if ast_node is target_node: - return True - - defined_vars = list(ast_node.get_defined_variables(ast.condition_map)) - required_vars = list(ast_node.get_required_variables(ast.condition_map)) - used_variables = defined_vars + required_vars - - if ast_node is variable_init.node: - if used_variables.count(variable_init.instruction.destination) > 1: - return False - elif variable_init.instruction.destination in used_variables: - return False - return True - - -def _initialization_reaches_loop_node(init_node: AbstractSyntaxTreeNode, usage_node: AbstractSyntaxTreeNode) -> bool: - """ - Check if init node always reaches the usage node - - This is not the case if: - - nodes are separated by a LoopNode - - init-nodes parent is not a sequence node or not on a path from root to usage-node (only initialized under certain conditions) - - :param 
init_node: node where initialization takes place - :param usage_node: node that is potentially inside a LoopNode - :return: True if init and usage node are separated by a LoopNode else False - """ - init_parent = init_node.parent - iter_parent = usage_node.parent - if not isinstance(init_parent, SeqNode): - return False - while iter_parent is not init_parent: - if isinstance(iter_parent, LoopNode): - return False - iter_parent = iter_parent.parent - if iter_parent is None: - return False - return True - - -def _requirement_without_reinitialization(ast: AbstractSyntaxTree, node: AbstractSyntaxTreeNode, variable: Variable) -> bool: - """ - Check if a variable is used without prior initialization starting at a given node. - Edge case: definition and requirement in same instruction - - :param ast: - :param node: - :param variable: - :return: True if has requirement that is not prior reinitialized else False - """ - - for ast_node in ast.get_reachable_nodes_pre_order(node): - assignment_visitor = AssignmentVisitor() - assignment_visitor.visit(ast_node) - for assignment in assignment_visitor.assignments: - if variable in assignment.definitions and variable not in assignment.requirements: - return False - elif variable in assignment.definitions and variable in assignment.requirements: - return True - elif variable in assignment.requirements: - return True - - def _get_potential_guarded_do_while_loops(ast: AbstractSyntaxTree) -> tuple(Union[DoWhileLoopNode, WhileLoopNode], ConditionNode): for loop_node in list(ast.get_loop_nodes_post_order()): if isinstance(loop_node, DoWhileLoopNode) and isinstance(loop_node.parent.parent, ConditionNode): @@ -242,13 +41,6 @@ def remove_guarded_do_while(ast: AbstractSyntaxTree): ast.substitute_loop_node(do_while_node, WhileLoopNode(do_while_node.condition, do_while_node.reaching_condition)) -@dataclass -class AstInstruction: - instruction: Assignment - position: int - node: CodeNode - - class WhileLoopReplacer: """Convert WhileLoopNodes to 
ForLoopNodes depending on the configuration. -> keep_empty_for_loops will keep empty for-loops in the code @@ -363,105 +155,13 @@ def _invalid_simple_for_loop_condition_type(self, logic_condition) -> bool: return False -class WhileLoopVariableRenamer: - """Iterate over While-Loop Nodes and rename their counter variables to counter, counter1, ...""" - - def __init__(self, ast: AbstractSyntaxTree): - self._ast = ast - self._variable_counter: int = 0 - - def rename(self): - """ - Iterate over While-Loop Nodes and rename their counter variables to counter, counter1, ... - - Only rename counter variables that suffice the following conditions: - -> any variable x is used in the loop condition - -> variable x is set inside the loop body - -> single definition of variable x reaches loop entry (x is initialized/used only once) - """ - - for loop_node in self._ast.get_while_loop_nodes_topological_order(): - if loop_node.is_endless_loop: - continue - for condition_var in loop_node.get_required_variables(self._ast.condition_map): - if not (variable_init := _get_variable_initialisation(self._ast, condition_var)): - continue - if not _find_continuation_instruction(self._ast, loop_node, condition_var, renaming=True): - continue - if not _single_defininition_reaches_node(self._ast, variable_init, loop_node): - continue - self._replace_variables(loop_node, variable_init) - break - - def _replace_variables(self, loop_node: LoopNode, variable_init: AstInstruction): - """ - Rename old variable usages to counter variable in: - - variable initialization - - condition/condition map - - loop body - Also add a copy instruction if the variable is used after the loop without reinitialization. 
- """ - new_variable = Variable(self._get_variable_name(), variable_init.instruction.destination.type) - self._ast.replace_variable_in_subtree(loop_node, variable_init.instruction.destination, new_variable) - if _requirement_without_reinitialization(self._ast, loop_node, variable_init.instruction.destination): - self._ast.add_instructions_after(loop_node, Assignment(variable_init.instruction.destination, new_variable)) - variable_init.node.replace_variable(variable_init.instruction.destination, new_variable) - - def _get_variable_name(self) -> str: - variable_name = f"counter{self._variable_counter if self._variable_counter > 0 else ''}" - self._variable_counter += 1 - return variable_name - - -class ForLoopVariableRenamer: - """Iterate over ForLoopNodes and rename their variables to i, j, ..., i1, j1, ...""" - - def __init__(self, ast: AbstractSyntaxTree, candidates: list[str]): - self._ast = ast - self._iteration: int = 0 - self._variable_counter: int = -1 - self._candidates: list[str] = candidates - - def rename(self): - """ - Iterate over ForLoopNodes and rename their variables to i, j, k, ... - We skip renaming for loops that are not initialized in its declaration. 
- """ - for loop_node in self._ast.get_for_loop_nodes_topological_order(): - if not isinstance(loop_node.declaration, Assignment): - continue - - old_variable: Variable = self._get_variable_from_assignment(loop_node.declaration.destination) - new_variable = Variable(self._get_variable_name(), old_variable.type, ssa_name=old_variable.ssa_name) - self._ast.replace_variable_in_subtree(loop_node, old_variable, new_variable) - loop_node.declaration.value.substitute(new_variable, old_variable) - - if _requirement_without_reinitialization(self._ast, loop_node, old_variable): - self._ast.add_instructions_after(loop_node, Assignment(old_variable, new_variable)) - - def _get_variable_name(self) -> str: - """Return variable names in the form of [i, j, ..., i1, j1, ...]""" - self._variable_counter += 1 - if self._variable_counter >= len(self._candidates): - self._variable_counter = 0 - self._iteration += 1 - return f"{self._candidates[self._variable_counter]}{self._iteration if self._iteration > 0 else ''}" - - def _get_variable_from_assignment(self, expr: Expression) -> Variable: - if isinstance(expr, Variable): - return expr - if isinstance(expr, Operation) and len(expr.operands) == 1: - return expr.operands[0] - raise ValueError("Did not expect a Constant/Unknown/Operation with more then 1 operand as a ForLoop declaration") - class ReadabilityBasedRefinement(PipelineStage): """ The ReadabilityBasedRefinement makes various transformations to improve readability based on the AST. Currently implemented transformations: - 1. while-loop to for-loop transformation - 2. for-loop variable renaming (e.g i, j, k, ...) - 3. while-loop variable renaming (e.g. counter, counter1, ...) + 1. remove guarded do while loops + 2. while-loop to for-loop transformation The AST is cleaned up before the first transformation and after every while- to for-loop transformation. 
""" @@ -472,12 +172,4 @@ def run(self, task: DecompilerTask): task.syntax_tree.clean_up() remove_guarded_do_while(task.syntax_tree) - WhileLoopReplacer(task.syntax_tree, task.options).run() - - variableNames = task.options.getlist("readability-based-refinement.for_loop_variable_names", fallback=[]) - if variableNames: - ForLoopVariableRenamer(task.syntax_tree, variableNames).rename() - - if task.options.getboolean("readability-based-refinement.rename_while_loop_variables"): - WhileLoopVariableRenamer(task.syntax_tree).rename() \ No newline at end of file diff --git a/decompiler/pipeline/default.py b/decompiler/pipeline/default.py index 85b0e6389..c33a6c9ff 100644 --- a/decompiler/pipeline/default.py +++ b/decompiler/pipeline/default.py @@ -1,8 +1,10 @@ """Module defining the available pipelines.""" from decompiler.pipeline.controlflowanalysis import ( - ExpressionSimplification, + ExpressionSimplificationAst, + ExpressionSimplificationCfg, InstructionLengthHandler, + LoopNameGenerator, ReadabilityBasedRefinement, VariableNameGeneration, ) @@ -36,10 +38,16 @@ IdentityElimination, CommonSubexpressionElimination, ArrayAccessDetection, - ExpressionSimplification, + ExpressionSimplificationCfg, DeadComponentPruner, GraphExpressionFolding, EdgePruner, ] -AST_STAGES = [ReadabilityBasedRefinement, ExpressionSimplification, InstructionLengthHandler, VariableNameGeneration] +AST_STAGES = [ + ReadabilityBasedRefinement, + ExpressionSimplificationAst, + InstructionLengthHandler, + VariableNameGeneration, + LoopNameGenerator +] diff --git a/decompiler/pipeline/preprocessing/missing_definitions.py b/decompiler/pipeline/preprocessing/missing_definitions.py index a322a8a05..d9c93a4a4 100644 --- a/decompiler/pipeline/preprocessing/missing_definitions.py +++ b/decompiler/pipeline/preprocessing/missing_definitions.py @@ -67,9 +67,19 @@ def get_smallest_label_copy(self, variable: Union[str, Variable]): return self._sorted_copies_of[variable][0] return 
min(self._copies_of_variable[variable], key=lambda var: var.ssa_label) + def _check_duplicated(self, var_name: str): + """ + Due to mixing of ssa_labels and memory versions, it can happen that we have duplicates in the copy pool. + E.g., [edx#0, edx#0, ...] + """ + ssa_labels = [var.ssa_label for var in self._sorted_copies_of[var_name]] + if any(i == j for i, j in zip(ssa_labels, ssa_labels[1:])): + raise ValueError(f"duplicate entries in copy pool for {var_name}") + def possible_missing_definitions_for(self, variable: Union[str, Variable]) -> List[Variable]: """Returns all variables whose definition may be missing because it is not the first in the order.""" var_name = self._get_variable_name(variable) + self._check_duplicated(var_name) return self._sorted_copies_of[var_name][1:] @staticmethod diff --git a/decompiler/structures/pseudo/operations.py b/decompiler/structures/pseudo/operations.py index 28d6ab9a6..c21002068 100644 --- a/decompiler/structures/pseudo/operations.py +++ b/decompiler/structures/pseudo/operations.py @@ -153,9 +153,14 @@ class OperationType(Enum): COMMUTATIVE_OPERATIONS = { OperationType.plus, OperationType.multiply, + OperationType.multiply_us, OperationType.bitwise_and, OperationType.bitwise_xor, OperationType.bitwise_or, + OperationType.logical_or, + OperationType.logical_and, + OperationType.equal, + OperationType.not_equal } NON_COMPOUNDABLE_OPERATIONS = { @@ -164,6 +169,10 @@ class OperationType(Enum): OperationType.left_rotate, OperationType.left_rotate_carry, OperationType.power, + OperationType.logical_or, + OperationType.logical_and, + OperationType.equal, + OperationType.not_equal } diff --git a/decompiler/structures/visitors/substitute_visitor.py b/decompiler/structures/visitors/substitute_visitor.py new file mode 100644 index 000000000..b4d7646e2 --- /dev/null +++ b/decompiler/structures/visitors/substitute_visitor.py @@ -0,0 +1,220 @@ +from typing import Callable, Optional, TypeVar, Union + +from decompiler.structures.pseudo 
import ( + Assignment, + BinaryOperation, + Break, + Call, + Comment, + Condition, + Constant, + Continue, + DataflowObject, + Expression, + FunctionSymbol, + GenericBranch, + ImportedFunctionSymbol, + IntrinsicSymbol, + ListOperation, + MemPhi, + Operation, + Phi, + RegisterPair, + Return, + TernaryExpression, + UnaryOperation, + UnknownExpression, + Variable, +) +from decompiler.structures.pseudo.operations import ArrayInfo +from decompiler.structures.visitors.interfaces import DataflowObjectVisitorInterface + +T = TypeVar("T", bound=DataflowObject) + + +def _assert_type(obj: DataflowObject, t: type[T]) -> T: + if not isinstance(obj, t): + raise TypeError() + else: + return obj + + +class SubstituteVisitor(DataflowObjectVisitorInterface[Optional[DataflowObject]]): + """ + A visitor class for performing substitutions in a dataflow tree. + + This class allows you to create instances that can traverse a dataflow graph and perform substitutions + based on a provided mapping function. The mapping function is applied to each visited node in the graph, + and if the mapping function returns a non-None value, the node is replaced with the returned value. + + Note: + - Modifications to the dataflow tree happen in place. Only if the whole node that is being visited is replaced, + the visit method returns the replacement and not none. + - Even if a visit method returns a replacement, modifications could have happened to the original dataflow tree. + - Care should be taken when using this visitor, as substitution can leave the dataflow tree in an invalid state. + For example a dereference UnaryOperation could be updated without the changes being reflected in its ArrayInfo. + Same with changes to Phi and its origin_block + """ + + @classmethod + def identity(cls, replacee: DataflowObject, replacement: DataflowObject) -> "SubstituteVisitor": + """ + Create a SubstituteVisitor instance for identity-based substitution. 
+ + This class method creates a SubstituteVisitor instance that replaces nodes equal to the 'replacee' + parameter with the 'replacement' parameter based on identity comparison (is). + + Note: + While SubstituteVisitor.equality() creates copies of the specified replacement, this one does not! + Be careful as to not introduce the same dataflow object twice into the dataflow tree. + + :param replacee: The object to be replaced based on identity. + :param replacement: The object to replace 'replacee' with. + :return: A SubstituteVisitor instance for identity-based substitution. + """ + + return SubstituteVisitor(lambda o: replacement if o is replacee else None) + + @classmethod + def equality(cls, replacee: DataflowObject, replacement: DataflowObject) -> "SubstituteVisitor": + """ + Create a SubstituteVisitor instance for equality-based substitution. + + This class method creates a SubstituteVisitor instance that replaces nodes equal to the 'replacee' + parameter with the 'replacement' parameter based on equality comparison (==). + + Note: + This visitor creates copies of the specified replacement when substituting. + + :param replacee: The object to be replaced based on equality. + :param replacement: The object to replace 'replacee' with. + :return: A SubstituteVisitor instance for equality-based substitution. + """ + + return SubstituteVisitor(lambda o: replacement.copy() if o == replacee else None) + + def __init__(self, mapper: Callable[[DataflowObject], Optional[DataflowObject]]) -> None: + """ + Initialize a SubstituteVisitor instance. + + :param mapper: A callable object that takes a DataflowObject as input and returns an Optional[DataflowObject]. + This function is used to determine replacements for visited nodes. 
+ """ + + self._mapper = mapper + + def visit_unknown_expression(self, expr: UnknownExpression) -> Optional[DataflowObject]: + return self._mapper(expr) + + def visit_constant(self, expr: Constant) -> Optional[DataflowObject]: + return self._mapper(expr) + + def visit_variable(self, expr: Variable) -> Optional[DataflowObject]: + return self._mapper(expr) + + def visit_register_pair(self, expr: RegisterPair) -> Optional[DataflowObject]: + if (low_replacement := expr.low.accept(self)) is not None: + expr._low = _assert_type(low_replacement, Variable) + + if (high_replacement := expr.high.accept(self)) is not None: + expr._high = _assert_type(high_replacement, Variable) + + return self._mapper(expr) + + def _visit_operation(self, op: Operation) -> Optional[DataflowObject]: + """Base visit function used for all operation related visit functions""" + for index, operand in enumerate(op.operands): + if (repl := operand.accept(self)) is not None: + op.operands[index] = _assert_type(repl, Expression) + + return self._mapper(op) + + def visit_list_operation(self, op: ListOperation) -> Optional[DataflowObject]: + return self._visit_operation(op) + + def _substitute_array_info(self, array_info: ArrayInfo): + if (base_replacement := array_info.base.accept(self)) is not None: + array_info.base = _assert_type(base_replacement, Variable) + + # array_info.index can either be Variable or int. 
Only try substituting if not int + if isinstance(array_info.index, Variable): + if (index_replacement := array_info.index.accept(self)) is not None: + array_info.index = _assert_type(index_replacement, Variable) + + def visit_unary_operation(self, op: UnaryOperation) -> Optional[DataflowObject]: + if op.array_info is not None: + self._substitute_array_info(op.array_info) + + return self._visit_operation(op) + + def visit_binary_operation(self, op: BinaryOperation) -> Optional[DataflowObject]: + return self._visit_operation(op) + + def visit_call(self, op: Call) -> Optional[DataflowObject]: + if (function_replacement := op.function.accept(self)) is not None: + op._function = _assert_type( + function_replacement, + Union[FunctionSymbol, ImportedFunctionSymbol, IntrinsicSymbol, Variable] + ) + + return self._visit_operation(op) + + def visit_condition(self, op: Condition) -> Optional[DataflowObject]: + return self._visit_operation(op) + + def visit_ternary_expression(self, op: TernaryExpression) -> Optional[DataflowObject]: + return self._visit_operation(op) + + def visit_comment(self, instr: Comment) -> Optional[DataflowObject]: + return self._mapper(instr) + + def visit_assignment(self, instr: Assignment) -> Optional[DataflowObject]: + if (value_replacement := instr.value.accept(self)) is not None: + instr._value = _assert_type(value_replacement, Expression) + if (destination_replacement := instr.destination.accept(self)) is not None: + instr._destination = _assert_type(destination_replacement, Expression) + + return self._mapper(instr) + + def visit_generic_branch(self, instr: GenericBranch) -> Optional[DataflowObject]: + if (condition_replacement := instr.condition.accept(self)) is not None: + instr._condition = _assert_type(condition_replacement, Expression) + + return self._mapper(instr) + + def visit_return(self, instr: Return) -> Optional[DataflowObject]: + if (values_replacement := instr.values.accept(self)) is not None: + instr._values = 
_assert_type(values_replacement, ListOperation) + + return self._mapper(instr) + + def visit_break(self, instr: Break) -> Optional[DataflowObject]: + return self._mapper(instr) + + def visit_continue(self, instr: Continue) -> Optional[DataflowObject]: + return self._mapper(instr) + + def _visit_phi_base(self, instr: Phi, value_type: type[DataflowObject]): + if (repl := instr.value.accept(self)) is not None: + # Phi only accepts ListOperation with 'value_type' as valid values + for operand in _assert_type(repl, ListOperation).operands: + _assert_type(operand, value_type) + + instr._value = repl + + for node, expression in instr.origin_block.items(): + if (replacement := expression.accept(self)) is not None: + instr.origin_block[node] = _assert_type(replacement, Union[Variable, Constant]) + + if (destination_replacement := instr.destination.accept(self)) is not None: + instr._destination = _assert_type(destination_replacement, Variable) + + return self._mapper(instr) + + def visit_phi(self, instr: Phi) -> Optional[DataflowObject]: + return self._visit_phi_base(instr, Union[Variable, Constant]) + + def visit_mem_phi(self, instr: MemPhi) -> Optional[DataflowObject]: + """We do not want substitute capabilities for MemPhi, since we remove it while preprocessing.""" + # return self._visit_phi_base(instr, Union[Variable]) + pass diff --git a/decompiler/util/default.json b/decompiler/util/default.json index 1fe8aabfd..c4bca6cff 100644 --- a/decompiler/util/default.json +++ b/decompiler/util/default.json @@ -285,24 +285,6 @@ "is_hidden_from_cli": false, "argument_name": "--empty-for-loops" }, - { - "dest": "readability-based-refinement.for_loop_variable_names", - "default": [ - "i", - "j", - "k", - "l", - "m", - "n" - ], - "type": "array", - "elementType": "string", - "title": "Rename for-loop variables into desired names", - "description": "Rename for-loop variables to values from list", - "is_hidden_from_gui": false, - "is_hidden_from_cli": false, - "argument_name": 
"--for-loop-variable-names" - }, { "dest": "readability-based-refinement.max_condition_complexity_for_loop_recovery", "default": 100, @@ -347,16 +329,6 @@ "is_hidden_from_cli": false, "argument_name": "--for-loop-exclude-conditions" }, - { - "dest": "readability-based-refinement.rename_while_loop_variables", - "default": true, - "type": "boolean", - "title": "Rename while-loop variables", - "description": "Rename while-loop counter variables to counter, counter1, ...", - "is_hidden_from_gui": false, - "is_hidden_from_cli": false, - "argument_name": "--rename-while-loop-variables" - }, { "dest": "variable-name-generation.notation", "default": "default", @@ -412,6 +384,34 @@ "is_hidden_from_cli": false, "argument_name": "--variable-generation-counter-separator" }, + { + "dest": "loop-name-generator.rename_while_loop_variables", + "default": true, + "type": "boolean", + "title": "Rename while-loop variables", + "description": "Rename while-loop counter variables to counter, counter1, ...", + "is_hidden_from_gui": false, + "is_hidden_from_cli": false, + "argument_name": "--rename-while-loop-variables" + }, + { + "dest": "loop-name-generator.for_loop_variable_names", + "default": [ + "i", + "j", + "k", + "l", + "m", + "n" + ], + "type": "array", + "elementType": "string", + "title": "Rename for-loop variables", + "description": "Rename for-loop variables to values from given list", + "is_hidden_from_gui": false, + "is_hidden_from_cli": false, + "argument_name": "--for-loop-variable-names" + }, { "dest": "code-generator.max_complexity", "default": 100, @@ -581,6 +581,16 @@ "is_hidden_from_gui": true, "is_hidden_from_cli": false, "argument_name": "--loop_break_in_cases" + }, + { + "dest": "expression-simplification.max_iterations", + "default": 10000, + "type": "number", + "title": "The maximum number of iterations any rule set in the expression simplification stage is allowed to take", + "description": "Stop simplifying with a rule set after this number of iterations is 
exceeded, even if more possible simplifications are possible", + "is_hidden_from_gui": false, + "is_hidden_from_cli": false, + "argument_name": "--max_expression_simplification_iterations" } ] }, @@ -599,7 +609,7 @@ "dead-code-elimination", "expression-propagation-memory", "expression-propagation-function-call", - "expression-simplification", + "expression-simplification-cfg", "dead-code-elimination", "redundant-casts-elimination", "identity-elimination", @@ -619,9 +629,10 @@ "dest": "pipeline.ast_stages", "default": [ "readability-based-refinement", - "expression-simplification", + "expression-simplification-ast", "instruction-length-handler", - "variable-name-generation" + "variable-name-generation", + "loop-name-generator" ], "title": "AST pipeline stages", "type": "array", diff --git a/tests/pipeline/controlflowanalysis/expression_simplification/rules/test_collapse_add_neg.py b/tests/pipeline/controlflowanalysis/expression_simplification/rules/test_collapse_add_neg.py new file mode 100644 index 000000000..d7300107f --- /dev/null +++ b/tests/pipeline/controlflowanalysis/expression_simplification/rules/test_collapse_add_neg.py @@ -0,0 +1,23 @@ +import pytest +from decompiler.pipeline.controlflowanalysis.expression_simplification.rules.collapse_add_neg import CollapseAddNeg +from decompiler.structures.pseudo import BinaryOperation, Expression, Operation, OperationType, UnaryOperation, Variable + +var_x = Variable("x") +var_y = Variable("y") + + +@pytest.mark.parametrize( + ["operation", "result"], + [ + ( + BinaryOperation(OperationType.plus, [var_x, UnaryOperation(OperationType.negate, [var_y])]), + [BinaryOperation(OperationType.minus, [var_x, var_y])], + ), + ( + BinaryOperation(OperationType.minus, [var_x, UnaryOperation(OperationType.negate, [var_y])]), + [BinaryOperation(OperationType.plus, [var_x, var_y])], + ), + ], +) +def test_collapse_add_neg(operation: Operation, result: list[Expression]): + assert CollapseAddNeg().apply(operation) == [(operation, e) 
for e in result] diff --git a/tests/pipeline/controlflowanalysis/expression_simplification/rules/test_collapse_constant.py b/tests/pipeline/controlflowanalysis/expression_simplification/rules/test_collapse_constant.py new file mode 100644 index 000000000..1a4a5febf --- /dev/null +++ b/tests/pipeline/controlflowanalysis/expression_simplification/rules/test_collapse_constant.py @@ -0,0 +1,23 @@ +import pytest +from decompiler.pipeline.controlflowanalysis.expression_simplification.rules.collapse_constants import CollapseConstants +from decompiler.structures.pseudo import BinaryOperation, Constant, Expression, Float, Integer, Operation, OperationType, Variable + + +def _c_i32(value: int) -> Constant: + return Constant(value, Integer.int32_t()) + + +def _c_float(value: float) -> Constant: + return Constant(value, Float.float()) + + +@pytest.mark.parametrize( + ["operation", "result"], + [ + (BinaryOperation(OperationType.plus, [_c_i32(3), _c_i32(4)]), [_c_i32(7)]), + (BinaryOperation(OperationType.plus, [_c_i32(3), Variable("x")]), []), + (BinaryOperation(OperationType.plus_float, [_c_float(3.0), _c_float(4.0)]), []), + ], +) +def test_collapse_constant(operation: Operation, result: list[Expression]): + assert CollapseConstants().apply(operation) == [(operation, e) for e in result] diff --git a/tests/pipeline/controlflowanalysis/expression_simplification/rules/test_collapse_nested_constants.py b/tests/pipeline/controlflowanalysis/expression_simplification/rules/test_collapse_nested_constants.py new file mode 100644 index 000000000..897c58472 --- /dev/null +++ b/tests/pipeline/controlflowanalysis/expression_simplification/rules/test_collapse_nested_constants.py @@ -0,0 +1,149 @@ +import pytest +from decompiler.pipeline.controlflowanalysis.expression_simplification.rules.collapse_nested_constants import CollapseNestedConstants +from decompiler.structures.pseudo import BinaryOperation, Constant, Expression, Integer, Operation, OperationType, Variable +from 
decompiler.structures.visitors.substitute_visitor import SubstituteVisitor + + +def _var_i32(name: str) -> Variable: + return Variable(name, Integer.int32_t()) + + +def _c_i32(value: int) -> Constant: + return Constant(value, Integer.int32_t()) + + +def _plus(e0: Expression, e1: Expression) -> BinaryOperation: + return BinaryOperation(OperationType.plus, [e0, e1]) + + +def _mul(e: Expression, factor: Expression) -> BinaryOperation: + return BinaryOperation(OperationType.multiply, [e, factor]) + + +def _mul_us(e: Expression, factor: Expression) -> BinaryOperation: + return BinaryOperation(OperationType.multiply_us, [e, factor]) + + +def _bit_and(e0: Expression, e1: Expression) -> BinaryOperation: + return BinaryOperation(OperationType.bitwise_and, [e0, e1]) + + +def _bit_xor(e0: Expression, e1: Expression) -> BinaryOperation: + return BinaryOperation(OperationType.bitwise_xor, [e0, e1]) + + +def _bit_or(e0: Expression, e1: Expression) -> BinaryOperation: + return BinaryOperation(OperationType.bitwise_or, [e0, e1]) + + +@pytest.mark.parametrize( + ["operation", "possible_results"], + [ + ( # plus + _plus(_plus(_c_i32(7), _c_i32(11)), _c_i32(42)), + { + _plus(_plus(_c_i32(0), _c_i32(0)), _c_i32(60)), + _plus(_plus(_c_i32(0), _c_i32(60)), _c_i32(0)), + _plus(_plus(_c_i32(60), _c_i32(0)), _c_i32(0)), + }, + ), + ( + _plus(_plus(_var_i32("a"), _c_i32(2)), _plus(_var_i32("b"), _c_i32(3))), + { + _plus(_plus(_var_i32("a"), _c_i32(5)), _plus(_var_i32("b"), _c_i32(0))), + _plus(_plus(_var_i32("a"), _c_i32(0)), _plus(_var_i32("b"), _c_i32(5))), + }, + ), + ( # multiply + _mul(_mul(_c_i32(7), _c_i32(11)), _c_i32(2)), + { + _mul(_mul(_c_i32(1), _c_i32(1)), _c_i32(154)), + _mul(_mul(_c_i32(1), _c_i32(154)), _c_i32(1)), + _mul(_mul(_c_i32(154), _c_i32(1)), _c_i32(1)), + }, + ), + ( + _mul(_mul(_var_i32("a"), _c_i32(2)), _mul(_var_i32("b"), _c_i32(3))), + { + _mul(_mul(_var_i32("a"), _c_i32(6)), _mul(_var_i32("b"), _c_i32(1))), + _mul(_mul(_var_i32("a"), _c_i32(1)), 
_mul(_var_i32("b"), _c_i32(6))), + }, + ), + ( # multiply_us + _mul_us(_mul_us(_c_i32(7), _c_i32(11)), _c_i32(2)), + { + _mul_us(_mul_us(_c_i32(1), _c_i32(1)), _c_i32(154)), + _mul_us(_mul_us(_c_i32(1), _c_i32(154)), _c_i32(1)), + _mul_us(_mul_us(_c_i32(154), _c_i32(1)), _c_i32(1)), + }, + ), + ( + _mul_us(_mul_us(_var_i32("a"), _c_i32(2)), _mul_us(_var_i32("b"), _c_i32(3))), + { + _mul_us(_mul_us(_var_i32("a"), _c_i32(6)), _mul_us(_var_i32("b"), _c_i32(1))), + _mul_us(_mul_us(_var_i32("a"), _c_i32(1)), _mul_us(_var_i32("b"), _c_i32(6))), + }, + ), + ( # bitwise_and + _bit_and(_bit_and(_c_i32(7), _c_i32(11)), _c_i32(2)), + { + _bit_and(_bit_and(_c_i32(-1), _c_i32(-1)), _c_i32(2)), + _bit_and(_bit_and(_c_i32(-1), _c_i32(2)), _c_i32(-1)), + _bit_and(_bit_and(_c_i32(2), _c_i32(-1)), _c_i32(-1)), + }, + ), + ( + _bit_and(_bit_and(_var_i32("a"), _c_i32(2)), _bit_and(_var_i32("b"), _c_i32(3))), + { + _bit_and(_bit_and(_var_i32("a"), _c_i32(2)), _bit_and(_var_i32("b"), _c_i32(-1))), + _bit_and(_bit_and(_var_i32("a"), _c_i32(-1)), _bit_and(_var_i32("b"), _c_i32(2))), + }, + ), + ( # bitwise_xor + _bit_xor(_bit_xor(_c_i32(7), _c_i32(11)), _c_i32(2)), + { + _bit_xor(_bit_xor(_c_i32(0), _c_i32(0)), _c_i32(14)), + _bit_xor(_bit_xor(_c_i32(0), _c_i32(14)), _c_i32(0)), + _bit_xor(_bit_xor(_c_i32(14), _c_i32(0)), _c_i32(0)), + }, + ), + ( + _bit_xor(_bit_xor(_var_i32("a"), _c_i32(2)), _bit_xor(_var_i32("b"), _c_i32(3))), + { + _bit_xor(_bit_xor(_var_i32("a"), _c_i32(1)), _bit_xor(_var_i32("b"), _c_i32(0))), + _bit_xor(_bit_xor(_var_i32("a"), _c_i32(0)), _bit_xor(_var_i32("b"), _c_i32(1))), + }, + ), + ( # bitwise_or + _bit_or(_bit_or(_c_i32(7), _c_i32(11)), _c_i32(2)), + { + _bit_or(_bit_or(_c_i32(0), _c_i32(0)), _c_i32(15)), + _bit_or(_bit_or(_c_i32(0), _c_i32(15)), _c_i32(0)), + _bit_or(_bit_or(_c_i32(15), _c_i32(0)), _c_i32(0)), + }, + ), + ( + _bit_or(_bit_or(_var_i32("a"), _c_i32(2)), _bit_or(_var_i32("b"), _c_i32(3))), + { + _bit_or(_bit_or(_var_i32("a"), _c_i32(3)), 
_bit_or(_var_i32("b"), _c_i32(0))), + _bit_or(_bit_or(_var_i32("a"), _c_i32(0)), _bit_or(_var_i32("b"), _c_i32(3))), + }, + ), + ], +) +def test_collect_terms(operation: Operation, possible_results: set[Expression]): + collect_terms = CollapseNestedConstants() + + for i in range(100): + substitutions = collect_terms.apply(operation) + if not substitutions: + break + + for replacee, replacement in substitutions: + new_operation = operation.accept(SubstituteVisitor.identity(replacee, replacement)) + if new_operation is not None: + operation = new_operation + else: + raise RuntimeError("Max iterations exceeded") + + assert operation in possible_results diff --git a/tests/pipeline/controlflowanalysis/expression_simplification/rules/test_positive_constants.py b/tests/pipeline/controlflowanalysis/expression_simplification/rules/test_positive_constants.py new file mode 100644 index 000000000..8142753f9 --- /dev/null +++ b/tests/pipeline/controlflowanalysis/expression_simplification/rules/test_positive_constants.py @@ -0,0 +1,36 @@ +import pytest +from decompiler.pipeline.controlflowanalysis.expression_simplification.rules.positive_constants import PositiveConstants +from decompiler.structures.pseudo import BinaryOperation, Constant, Expression, Integer, Operation, OperationType, Variable + +var_x_i = Variable("x", Integer.int32_t()) +var_x_u = Variable("x", Integer.uint32_t()) + + +@pytest.mark.parametrize( + ["operation", "result"], + [ + ( + BinaryOperation(OperationType.minus, [var_x_i, (Constant(-3, Integer.int32_t()))]), + [BinaryOperation(OperationType.plus, [var_x_i, Constant(3, Integer.int32_t())])], + ), + ( + BinaryOperation(OperationType.plus, [var_x_i, (Constant(-3, Integer.int32_t()))]), + [BinaryOperation(OperationType.minus, [var_x_i, Constant(3, Integer.int32_t())])], + ), + (BinaryOperation(OperationType.plus, [var_x_i, (Constant(3, Integer.int32_t()))]), []), + (BinaryOperation(OperationType.minus, [var_x_i, (Constant(3, Integer.int32_t()))]), []), + + ( 
+ BinaryOperation(OperationType.minus, [var_x_u, (Constant(4294967293, Integer.uint32_t()))]), + [BinaryOperation(OperationType.plus, [var_x_u, Constant(3, Integer.uint32_t())])], + ), + ( + BinaryOperation(OperationType.plus, [var_x_u, (Constant(4294967293, Integer.uint32_t()))]), + [BinaryOperation(OperationType.minus, [var_x_u, Constant(3, Integer.uint32_t())])], + ), + (BinaryOperation(OperationType.plus, [var_x_u, (Constant(3, Integer.uint32_t()))]), []), + (BinaryOperation(OperationType.minus, [var_x_u, (Constant(3, Integer.uint32_t()))]), []), + ], +) +def test_fix_add_sub_sign(operation: Operation, result: list[Expression]): + assert PositiveConstants().apply(operation) == [(operation, e) for e in result] diff --git a/tests/pipeline/controlflowanalysis/expression_simplification/rules/test_simplify_redundant_reference.py b/tests/pipeline/controlflowanalysis/expression_simplification/rules/test_simplify_redundant_reference.py new file mode 100644 index 000000000..f269d5ec9 --- /dev/null +++ b/tests/pipeline/controlflowanalysis/expression_simplification/rules/test_simplify_redundant_reference.py @@ -0,0 +1,15 @@ +import pytest +from decompiler.pipeline.controlflowanalysis.expression_simplification.rules.simplify_redundant_reference import SimplifyRedundantReference +from decompiler.structures.pseudo import Expression, Operation, OperationType, UnaryOperation, Variable + + +@pytest.mark.parametrize( + ["operation", "result"], + [ + (UnaryOperation(OperationType.dereference, [UnaryOperation(OperationType.address, [var := Variable("x")])]), [var]), + (UnaryOperation(OperationType.address, [Variable("x")]), []), + (UnaryOperation(OperationType.dereference, [Variable("x")]), []), + ], +) +def test_simplify_redundant_reference(operation: Operation, result: list[Expression]): + assert SimplifyRedundantReference().apply(operation) == [(operation, e) for e in result] diff --git 
a/tests/pipeline/controlflowanalysis/expression_simplification/rules/test_simplify_trivial_arithmetic.py b/tests/pipeline/controlflowanalysis/expression_simplification/rules/test_simplify_trivial_arithmetic.py new file mode 100644 index 000000000..14faaa795 --- /dev/null +++ b/tests/pipeline/controlflowanalysis/expression_simplification/rules/test_simplify_trivial_arithmetic.py @@ -0,0 +1,29 @@ +import pytest +from decompiler.pipeline.controlflowanalysis.expression_simplification.rules.simplify_trivial_arithmetic import SimplifyTrivialArithmetic +from decompiler.structures.pseudo import BinaryOperation, Constant, Expression, Integer, Operation, OperationType, UnaryOperation, Variable + +var = Variable("x") +con_0 = Constant(0, Integer.int32_t()) +con_1 = Constant(1, Integer.int32_t()) +con_neg1 = Constant(-1, Integer.int32_t()) + + +@pytest.mark.parametrize( + ["operation", "result"], + [ + (BinaryOperation(OperationType.plus, [var, con_0]), [var]), + (BinaryOperation(OperationType.minus, [var, con_0]), [var]), + (BinaryOperation(OperationType.multiply, [var, con_0]), [con_0]), + (BinaryOperation(OperationType.multiply_us, [var, con_0]), [con_0]), + (BinaryOperation(OperationType.multiply, [var, con_1]), [var]), + (BinaryOperation(OperationType.multiply_us, [var, con_1]), [var]), + (BinaryOperation(OperationType.multiply, [var, con_neg1]), [UnaryOperation(OperationType.negate, [var])]), + (BinaryOperation(OperationType.multiply_us, [var, con_neg1]), [UnaryOperation(OperationType.negate, [var])]), + (BinaryOperation(OperationType.divide, [var, con_1]), [var]), + (BinaryOperation(OperationType.divide_us, [var, con_1]), [var]), + (BinaryOperation(OperationType.divide, [var, con_neg1]), [UnaryOperation(OperationType.negate, [var])]), + (BinaryOperation(OperationType.divide_us, [var, con_neg1]), []), + ], +) +def test_simplify_trivial_arithmetic(operation: Operation, result: list[Expression]): + assert SimplifyTrivialArithmetic().apply(operation) == [(operation, e) for 
e in result] diff --git a/tests/pipeline/controlflowanalysis/expression_simplification/rules/test_simplify_trivial_bit_arithmetic.py b/tests/pipeline/controlflowanalysis/expression_simplification/rules/test_simplify_trivial_bit_arithmetic.py new file mode 100644 index 000000000..e42771f66 --- /dev/null +++ b/tests/pipeline/controlflowanalysis/expression_simplification/rules/test_simplify_trivial_bit_arithmetic.py @@ -0,0 +1,23 @@ +import pytest +from decompiler.pipeline.controlflowanalysis.expression_simplification.rules.simplify_trivial_bit_arithmetic import ( + SimplifyTrivialBitArithmetic, +) +from decompiler.structures.pseudo import BinaryOperation, Constant, Expression, Integer, Operation, OperationType, Variable + +var = Variable("x", Integer.int32_t()) +con_0 = Constant(0, Integer.int32_t()) + + +@pytest.mark.parametrize( + ["operation", "result"], + [ + (BinaryOperation(OperationType.bitwise_or, [var, con_0]), [var]), + (BinaryOperation(OperationType.bitwise_or, [var, var]), [var]), + (BinaryOperation(OperationType.bitwise_and, [var, con_0]), [con_0]), + (BinaryOperation(OperationType.bitwise_and, [var, var]), [var]), + (BinaryOperation(OperationType.bitwise_xor, [var, con_0]), [var]), + (BinaryOperation(OperationType.bitwise_xor, [var, var]), [con_0]), + ], +) +def test_simplify_trivial_bit_arithmetic(operation: Operation, result: list[Expression]): + assert SimplifyTrivialBitArithmetic().apply(operation) == [(operation, e) for e in result] diff --git a/tests/pipeline/controlflowanalysis/expression_simplification/rules/test_simplify_trivial_logic_arithmetic.py b/tests/pipeline/controlflowanalysis/expression_simplification/rules/test_simplify_trivial_logic_arithmetic.py new file mode 100644 index 000000000..8d74ff8b1 --- /dev/null +++ b/tests/pipeline/controlflowanalysis/expression_simplification/rules/test_simplify_trivial_logic_arithmetic.py @@ -0,0 +1,22 @@ +import pytest +from 
decompiler.pipeline.controlflowanalysis.expression_simplification.rules.simplify_trivial_logic_arithmetic import ( + SimplifyTrivialLogicArithmetic, +) +from decompiler.structures.pseudo import BinaryOperation, Constant, CustomType, Expression, Operation, OperationType, Variable + +var = Variable("x", CustomType.bool()) +con_false = Constant(0, CustomType.bool()) +con_true = Constant(1, CustomType.bool()) + + +@pytest.mark.parametrize( + ["operation", "result"], + [ + (BinaryOperation(OperationType.logical_or, [var, con_false]), [var]), + (BinaryOperation(OperationType.logical_or, [var, con_true]), [con_true]), + (BinaryOperation(OperationType.logical_and, [var, con_false]), [con_false]), + (BinaryOperation(OperationType.logical_and, [var, con_true]), [var]), + ], +) +def test_simplify_trivial_logic_arithmetic(operation: Operation, result: list[Expression]): + assert SimplifyTrivialLogicArithmetic().apply(operation) == [(operation, e) for e in result] diff --git a/tests/pipeline/controlflowanalysis/expression_simplification/rules/test_simplify_trivial_shift.py b/tests/pipeline/controlflowanalysis/expression_simplification/rules/test_simplify_trivial_shift.py new file mode 100644 index 000000000..a281caa33 --- /dev/null +++ b/tests/pipeline/controlflowanalysis/expression_simplification/rules/test_simplify_trivial_shift.py @@ -0,0 +1,20 @@ +import pytest +from decompiler.pipeline.controlflowanalysis.expression_simplification.rules.simplify_trivial_shift import SimplifyTrivialShift +from decompiler.structures.pseudo import BinaryOperation, Constant, Expression, Integer, Operation, OperationType, Variable + +var = Variable("x") +con_0 = Constant(0, Integer.int32_t()) + + +@pytest.mark.parametrize( + ["operation", "result"], + [ + (BinaryOperation(OperationType.left_shift, [var, con_0]), [var]), + (BinaryOperation(OperationType.right_shift, [var, con_0]), [var]), + (BinaryOperation(OperationType.right_shift_us, [var, con_0]), [var]), + 
(BinaryOperation(OperationType.left_rotate, [var, con_0]), [var]), + (BinaryOperation(OperationType.right_rotate, [var, con_0]), [var]), + ], +) +def test_simplify_trivial_shift(operation: Operation, result: list[Expression]): + assert SimplifyTrivialShift().apply(operation) == [(operation, e) for e in result] diff --git a/tests/pipeline/controlflowanalysis/expression_simplification/rules/test_sub_to_add.py b/tests/pipeline/controlflowanalysis/expression_simplification/rules/test_sub_to_add.py new file mode 100644 index 000000000..d171aeb93 --- /dev/null +++ b/tests/pipeline/controlflowanalysis/expression_simplification/rules/test_sub_to_add.py @@ -0,0 +1,20 @@ +import pytest +from decompiler.pipeline.controlflowanalysis.expression_simplification.rules.sub_to_add import SubToAdd +from decompiler.structures.pseudo import BinaryOperation, Constant, Expression, Integer, Operation, OperationType, UnaryOperation, Variable + +var_x = Variable("x", Integer.int32_t()) +var_y = Variable("y", Integer.int32_t()) +con_neg1 = Constant(-1, Integer.int32_t()) + + +@pytest.mark.parametrize( + ["operation", "result"], + [ + ( + BinaryOperation(OperationType.minus, [var_x, var_y]), + [BinaryOperation(OperationType.plus, [var_x, UnaryOperation(OperationType.negate, [var_y])])], + ), + ], +) +def test_sub_to_add(operation: Operation, result: list[Expression]): + assert SubToAdd().apply(operation) == [(operation, e) for e in result] diff --git a/tests/pipeline/controlflowanalysis/expression_simplification/rules/test_term_order.py b/tests/pipeline/controlflowanalysis/expression_simplification/rules/test_term_order.py new file mode 100644 index 000000000..5846004e9 --- /dev/null +++ b/tests/pipeline/controlflowanalysis/expression_simplification/rules/test_term_order.py @@ -0,0 +1,15 @@ +import pytest +from decompiler.pipeline.controlflowanalysis.expression_simplification.rules.term_order import TermOrder +from decompiler.structures.pseudo import BinaryOperation, Constant, Expression, 
Integer, Operation, Variable +from decompiler.structures.pseudo.operations import COMMUTATIVE_OPERATIONS + +var = Variable("x") +con = Constant(42, Integer.int32_t()) + + +@pytest.mark.parametrize( + ["operation", "result"], + [(BinaryOperation(operation, [con, var]), [BinaryOperation(operation, [var, con])]) for operation in COMMUTATIVE_OPERATIONS], +) +def test_term_order(operation: Operation, result: list[Expression]): + assert TermOrder().apply(operation) == [(operation, e) for e in result] diff --git a/tests/pipeline/controlflowanalysis/expression_simplification/test_constant_folding.py b/tests/pipeline/controlflowanalysis/expression_simplification/test_constant_folding.py new file mode 100644 index 000000000..24f88ed93 --- /dev/null +++ b/tests/pipeline/controlflowanalysis/expression_simplification/test_constant_folding.py @@ -0,0 +1,138 @@ +from contextlib import nullcontext + +import pytest +from decompiler.pipeline.controlflowanalysis.expression_simplification.constant_folding import FOLDABLE_OPERATIONS, constant_fold +from decompiler.structures.pseudo import Constant, Float, Integer, OperationType + + +def _c_i32(value: int) -> Constant: + return Constant(value, Integer.int32_t()) + + +def _c_u32(value: int) -> Constant: + return Constant(value, Integer.uint32_t()) + + +def _c_i16(value: int) -> Constant: + return Constant(value, Integer.int16_t()) + + +def _c_float(value: float) -> Constant: + return Constant(value, Float.float()) + + +@pytest.mark.parametrize( + ["operation"], + [(operation,) for operation in OperationType if operation not in FOLDABLE_OPERATIONS] +) +def test_constant_fold_invalid_operations(operation: OperationType): + with pytest.raises(ValueError): + constant_fold(operation, []) + + +@pytest.mark.parametrize( + ["operation", "constants", "result", "context"], + [ + (OperationType.plus, [_c_i32(3), _c_i32(4)], _c_i32(7), nullcontext()), + (OperationType.plus, [_c_i32(2147483647), _c_i32(1)], _c_i32(-2147483648), nullcontext()), + 
(OperationType.plus, [_c_u32(2147483658), _c_u32(2147483652)], _c_u32(14), nullcontext()), + (OperationType.plus, [_c_u32(3), _c_i32(4)], None, pytest.raises(ValueError)), + (OperationType.plus, [_c_i32(3), _c_i16(4)], None, pytest.raises(ValueError)), + (OperationType.plus, [_c_i32(3)], None, pytest.raises(ValueError)), + (OperationType.plus, [_c_i32(3), _c_i32(3), _c_i32(3)], None, pytest.raises(ValueError)), + + (OperationType.minus, [_c_i32(3), _c_i32(4)], _c_i32(-1), nullcontext()), + (OperationType.minus, [_c_i32(-2147483648), _c_i32(1)], _c_i32(2147483647), nullcontext()), + (OperationType.minus, [_c_u32(3), _c_u32(4)], _c_u32(4294967295), nullcontext()), + (OperationType.minus, [_c_u32(3), _c_i32(4)], None, pytest.raises(ValueError)), + (OperationType.minus, [_c_i32(3), _c_i16(4)], None, pytest.raises(ValueError)), + (OperationType.minus, [_c_i32(3)], None, pytest.raises(ValueError)), + (OperationType.minus, [_c_i32(3), _c_i32(3), _c_i32(3)], None, pytest.raises(ValueError)), + + (OperationType.multiply, [_c_i32(3), _c_i32(4)], _c_i32(12), nullcontext()), + (OperationType.multiply, [_c_i32(-1073741824), _c_i32(2)], _c_i32(-2147483648), nullcontext()), + (OperationType.multiply, [_c_u32(3221225472), _c_u32(2)], _c_u32(2147483648), nullcontext()), + (OperationType.multiply, [_c_u32(3), _c_i32(4)], None, pytest.raises(ValueError)), + (OperationType.multiply, [_c_i32(3), _c_i16(4)], None, pytest.raises(ValueError)), + (OperationType.multiply, [_c_i32(3)], None, pytest.raises(ValueError)), + (OperationType.multiply, [_c_i32(3), _c_i32(3), _c_i32(3)], None, pytest.raises(ValueError)), + + (OperationType.multiply_us, [_c_i32(3), _c_i32(4)], _c_i32(12), nullcontext()), + (OperationType.multiply_us, [_c_i32(-1073741824), _c_i32(2)], _c_i32(-2147483648), nullcontext()), + (OperationType.multiply_us, [_c_u32(3221225472), _c_u32(2)], _c_u32(2147483648), nullcontext()), + (OperationType.multiply_us, [_c_u32(3), _c_i32(4)], None, pytest.raises(ValueError)), + 
(OperationType.multiply_us, [_c_i32(3), _c_i16(4)], None, pytest.raises(ValueError)), + (OperationType.multiply_us, [_c_i32(3)], None, pytest.raises(ValueError)), + (OperationType.multiply_us, [_c_i32(3), _c_i32(3), _c_i32(3)], None, pytest.raises(ValueError)), + + (OperationType.divide, [_c_i32(12), _c_i32(4)], _c_i32(3), nullcontext()), + (OperationType.divide, [_c_i32(-2147483648), _c_i32(2)], _c_i32(-1073741824), nullcontext()), + (OperationType.divide, [_c_u32(3), _c_i32(4)], None, pytest.raises(ValueError)), + (OperationType.divide, [_c_i32(3), _c_i16(4)], None, pytest.raises(ValueError)), + (OperationType.divide, [_c_i32(3)], None, pytest.raises(ValueError)), + (OperationType.divide, [_c_i32(3), _c_i32(3), _c_i32(3)], None, pytest.raises(ValueError)), + + (OperationType.divide_us, [_c_i32(12), _c_i32(4)], _c_i32(3), nullcontext()), + (OperationType.divide_us, [_c_i32(-2147483648), _c_i32(2)], _c_i32(1073741824), nullcontext()), + (OperationType.divide_us, [_c_u32(3), _c_i32(4)], None, pytest.raises(ValueError)), + (OperationType.divide_us, [_c_i32(3), _c_i16(4)], None, pytest.raises(ValueError)), + (OperationType.divide_us, [_c_i32(3)], None, pytest.raises(ValueError)), + (OperationType.divide_us, [_c_i32(3), _c_i32(3), _c_i32(3)], None, pytest.raises(ValueError)), + + (OperationType.negate, [_c_i32(3)], _c_i32(-3), nullcontext()), + (OperationType.negate, [_c_i32(-2147483648)], _c_i32(-2147483648), nullcontext()), + (OperationType.negate, [], None, pytest.raises(ValueError)), + (OperationType.negate, [_c_i32(3), _c_i32(3)], None, pytest.raises(ValueError)), + + (OperationType.left_shift, [_c_i32(3), _c_i32(4)], _c_i32(48), nullcontext()), + (OperationType.left_shift, [_c_i32(1073741824), _c_i32(1)], _c_i32(-2147483648), nullcontext()), + (OperationType.left_shift, [_c_u32(1073741824), _c_u32(1)], _c_u32(2147483648), nullcontext()), + (OperationType.left_shift, [_c_i32(3)], None, pytest.raises(ValueError)), + (OperationType.left_shift, [_c_i32(3), _c_i32(3), 
_c_i32(3)], None, pytest.raises(ValueError)), + + (OperationType.right_shift, [_c_i32(32), _c_i32(4)], _c_i32(2), nullcontext()), + (OperationType.right_shift, [_c_i32(-2147483648), _c_i32(1)], _c_i32(-1073741824), nullcontext()), + (OperationType.right_shift, [_c_u32(2147483648), _c_u32(1)], _c_u32(1073741824), nullcontext()), + (OperationType.right_shift, [_c_i32(3)], None, pytest.raises(ValueError)), + (OperationType.right_shift, [_c_i32(3), _c_i32(3), _c_i32(3)], None, pytest.raises(ValueError)), + + (OperationType.right_shift_us, [_c_i32(32), _c_i32(4)], _c_i32(2), nullcontext()), + (OperationType.right_shift_us, [_c_i32(-2147483648), _c_i32(1)], _c_i32(1073741824), nullcontext()), + (OperationType.right_shift_us, [_c_u32(2147483648), _c_u32(1)], _c_u32(1073741824), nullcontext()), + (OperationType.right_shift_us, [_c_i32(3)], None, pytest.raises(ValueError)), + (OperationType.right_shift_us, [_c_i32(3), _c_i32(3), _c_i32(3)], None, pytest.raises(ValueError)), + + (OperationType.bitwise_or, [_c_i32(85), _c_i32(34)], _c_i32(119), nullcontext()), + (OperationType.bitwise_or, [_c_i32(-2147483648), _c_i32(1)], _c_i32(-2147483647), nullcontext()), + (OperationType.bitwise_or, [_c_u32(2147483648), _c_u32(1)], _c_u32(2147483649), nullcontext()), + (OperationType.bitwise_or, [_c_u32(3), _c_i32(4)], None, pytest.raises(ValueError)), + (OperationType.bitwise_or, [_c_i32(3), _c_i16(4)], None, pytest.raises(ValueError)), + (OperationType.bitwise_or, [_c_i32(3)], None, pytest.raises(ValueError)), + (OperationType.bitwise_or, [_c_i32(3), _c_i32(3), _c_i32(3)], None, pytest.raises(ValueError)), + + (OperationType.bitwise_and, [_c_i32(85), _c_i32(51)], _c_i32(17), nullcontext()), + (OperationType.bitwise_and, [_c_i32(-2147483647), _c_i32(3)], _c_i32(1), nullcontext()), + (OperationType.bitwise_and, [_c_u32(2147483649), _c_u32(3)], _c_u32(1), nullcontext()), + (OperationType.bitwise_and, [_c_u32(3), _c_i32(4)], None, pytest.raises(ValueError)), + (OperationType.bitwise_and, 
[_c_i32(3), _c_i16(4)], None, pytest.raises(ValueError)), + (OperationType.bitwise_and, [_c_i32(3)], None, pytest.raises(ValueError)), + (OperationType.bitwise_and, [_c_i32(3), _c_i32(3), _c_i32(3)], None, pytest.raises(ValueError)), + + (OperationType.bitwise_xor, [_c_i32(85), _c_i32(51)], _c_i32(102), nullcontext()), + (OperationType.bitwise_xor, [_c_i32(-2147483647), _c_i32(-2147483646)], _c_i32(3), nullcontext()), + (OperationType.bitwise_xor, [_c_u32(2147483649), _c_u32(2147483650)], _c_u32(3), nullcontext()), + (OperationType.bitwise_xor, [_c_u32(3), _c_i32(4)], None, pytest.raises(ValueError)), + (OperationType.bitwise_xor, [_c_i32(3), _c_i16(4)], None, pytest.raises(ValueError)), + (OperationType.bitwise_xor, [_c_i32(3)], None, pytest.raises(ValueError)), + (OperationType.bitwise_xor, [_c_i32(3), _c_i32(3), _c_i32(3)], None, pytest.raises(ValueError)), + + (OperationType.bitwise_not, [_c_i32(6)], _c_i32(-7), nullcontext()), + (OperationType.bitwise_not, [_c_i32(-2147483648)], _c_i32(2147483647), nullcontext()), + (OperationType.bitwise_not, [_c_u32(2147483648)], _c_u32(2147483647), nullcontext()), + (OperationType.bitwise_not, [], None, pytest.raises(ValueError)), + (OperationType.bitwise_not, [_c_i32(3), _c_i32(3)], None, pytest.raises(ValueError)), + ] +) +def test_constant_fold(operation: OperationType, constants: list[Constant], result: Constant, context): + with context: + assert constant_fold(operation, constants) == result diff --git a/tests/pipeline/controlflowanalysis/expression_simplification/test_stage.py b/tests/pipeline/controlflowanalysis/expression_simplification/test_stage.py new file mode 100644 index 000000000..a9b74833c --- /dev/null +++ b/tests/pipeline/controlflowanalysis/expression_simplification/test_stage.py @@ -0,0 +1,109 @@ +import pytest +from decompiler.pipeline.controlflowanalysis.expression_simplification.rules.collapse_constants import CollapseConstants +from 
decompiler.pipeline.controlflowanalysis.expression_simplification.rules.collapse_nested_constants import CollapseNestedConstants +from decompiler.pipeline.controlflowanalysis.expression_simplification.rules.rule import SimplificationRule +from decompiler.pipeline.controlflowanalysis.expression_simplification.rules.simplify_trivial_arithmetic import SimplifyTrivialArithmetic +from decompiler.pipeline.controlflowanalysis.expression_simplification.rules.sub_to_add import SubToAdd +from decompiler.pipeline.controlflowanalysis.expression_simplification.rules.term_order import TermOrder +from decompiler.pipeline.controlflowanalysis.expression_simplification.stages import _ExpressionSimplificationBase +from decompiler.structures.pseudo import ( + Assignment, + BinaryOperation, + Constant, + Expression, + Instruction, + Integer, + Operation, + OperationType, + Variable, +) + + +class _RedundantChanges(SimplificationRule): + def apply(self, operation: Operation) -> list[tuple[Expression, Expression]]: + return [(operation, operation)] + + +class _NoChanges(SimplificationRule): + def apply(self, operation: Operation) -> list[tuple[Expression, Expression]]: + return [] + + +def _add(left: Expression, right: Expression) -> BinaryOperation: + return BinaryOperation(OperationType.plus, [left, right]) + + +def _sub(left: Expression, right: Expression) -> BinaryOperation: + return BinaryOperation(OperationType.minus, [left, right]) + + +def _c_i32(value: int) -> Constant: + return Constant(value, Integer.int32_t()) + + +def _v_i32(name: str) -> Variable: + return Variable(name, Integer.int32_t()) + + +@pytest.mark.parametrize( + ["rule_set", "instruction", "expected_result"], + [ + ( + [TermOrder()], + Assignment(_v_i32("a"), _add(_c_i32(1), _v_i32("b"))), + Assignment(_v_i32("a"), _add(_v_i32("b"), _c_i32(1))) + ), + ( + [CollapseConstants()], + Assignment(_v_i32("a"), _sub(_c_i32(10), _add(_c_i32(3), _c_i32(2)))), + Assignment(_v_i32("a"), _c_i32(5)) + ), + ( + [SubToAdd(), 
SimplifyTrivialArithmetic(), CollapseConstants(), CollapseNestedConstants()], + Assignment(_v_i32("a"), _sub(_add(_v_i32("a"), _c_i32(5)), _c_i32(5))), + Assignment(_v_i32("a"), _v_i32("a")) + ), + ] +) +def test_simplify_instructions_with_rule_set( + rule_set: list[SimplificationRule], + instruction: Instruction, + expected_result: Instruction +): + _ExpressionSimplificationBase._simplify_instructions_with_rule_set( + [instruction], + rule_set, + 100 + ) + assert instruction == expected_result + + +@pytest.mark.parametrize( + ["rule_set", "instruction", "max_iterations", "expect_exceed_max_iterations"], + [ + ( + [_RedundantChanges()], + Assignment(Variable("a"), BinaryOperation(OperationType.plus, [Constant(1), Variable("b")])), + 10, + True + ), + ( + [_NoChanges()], + Assignment(_v_i32("a"), _v_i32("b")), + 0, + False + ) + ] +) +def test_simplify_instructions_with_rule_set_max_iterations( + rule_set: list[SimplificationRule], + instruction: Instruction, + max_iterations: int, + expect_exceed_max_iterations: bool +): + iterations = _ExpressionSimplificationBase._simplify_instructions_with_rule_set( + [instruction], + rule_set, + max_iterations + ) + assert (iterations > max_iterations) == expect_exceed_max_iterations diff --git a/tests/pipeline/controlflowanalysis/test_expression_simplification.py b/tests/pipeline/controlflowanalysis/test_expression_simplification.py deleted file mode 100644 index 11a9a7042..000000000 --- a/tests/pipeline/controlflowanalysis/test_expression_simplification.py +++ /dev/null @@ -1,244 +0,0 @@ -from typing import Optional - -import pytest -from decompiler.pipeline.controlflowanalysis import ExpressionSimplification -from decompiler.structures.ast.ast_nodes import CodeNode -from decompiler.structures.ast.syntaxtree import AbstractSyntaxTree -from decompiler.structures.graphs.cfg import BasicBlock, ControlFlowGraph -from decompiler.structures.logic.logic_condition import LogicCondition -from decompiler.structures.pseudo.expressions 
import Constant, ImportedFunctionSymbol, Variable -from decompiler.structures.pseudo.instructions import Assignment, Instruction, Return -from decompiler.structures.pseudo.operations import BinaryOperation, Call, ListOperation, OperationType, UnaryOperation -from decompiler.structures.pseudo.typing import Integer -from decompiler.task import DecompilerTask - - -def _task(ast: Optional[AbstractSyntaxTree] = None, cfg: Optional[ControlFlowGraph] = None) -> DecompilerTask: - cfg = ControlFlowGraph() if cfg is None else cfg - task = DecompilerTask("test_function", cfg, ast) - return task - - -x = Variable("x") -y = Variable("y") -const_0 = Constant(0, Integer(32, signed=True)) -const_m0 = Constant(-0, Integer(32, signed=True)) -const_1 = Constant(1, Integer(32, signed=True)) -const_m1 = Constant(-1, Integer(32, signed=True)) - - -@pytest.mark.parametrize( - "instruction, result", - [ - (Assignment(y, BinaryOperation(OperationType.plus, [x, const_0])), Assignment(y, x)), - (Assignment(y, BinaryOperation(OperationType.plus, [const_0, x])), Assignment(y, x)), - (Assignment(y, BinaryOperation(OperationType.plus, [const_0, const_0])), Assignment(y, const_0)), - (Assignment(y, BinaryOperation(OperationType.plus, [x, const_m0])), Assignment(y, x)), - ( - Assignment(y, BinaryOperation(OperationType.plus, [x, BinaryOperation(OperationType.plus, [const_0, const_0])])), - Assignment(y, x), - ), - ( - Assignment(y, BinaryOperation(OperationType.plus, [const_0, BinaryOperation(OperationType.plus, [const_0, x])])), - Assignment(y, x), - ), - ], -) -def test_easy_simplification_with_zero_addition(instruction, result): - true_value = LogicCondition.initialize_true(LogicCondition.generate_new_context()) - task = _task(AbstractSyntaxTree(CodeNode([instruction], true_value.copy()), dict())) - ExpressionSimplification().run(task) - assert task.syntax_tree.root == CodeNode([result], true_value.copy()) - - -@pytest.mark.parametrize( - "instruction, result", - [ - (Assignment(y, 
BinaryOperation(OperationType.multiply, [x, const_0])), Assignment(y, const_0)), - (Assignment(y, BinaryOperation(OperationType.multiply, [const_0, x])), Assignment(y, const_0)), - (Assignment(y, BinaryOperation(OperationType.multiply, [const_0, const_0])), Assignment(y, const_0)), - (Assignment(y, BinaryOperation(OperationType.multiply, [x, const_m0])), Assignment(y, const_0)), - ( - Assignment(y, BinaryOperation(OperationType.multiply, [x, BinaryOperation(OperationType.multiply, [x, const_0])])), - Assignment(y, const_0), - ), - ], -) -def test_simplification_with_zero_multiplication(instruction, result): - true_value = LogicCondition.initialize_true(LogicCondition.generate_new_context()) - task = _task(AbstractSyntaxTree(CodeNode([instruction], true_value.copy()), dict())) - ExpressionSimplification().run(task) - assert task.syntax_tree.root == CodeNode([result], true_value.copy()) - - -@pytest.mark.parametrize( - "instruction, result", - [ - (Assignment(y, BinaryOperation(OperationType.minus, (x, const_0))), Assignment(y, x)), - (Assignment(y, BinaryOperation(OperationType.minus, (const_0, x))), Assignment(y, UnaryOperation(OperationType.negate, [x]))), - (Assignment(y, BinaryOperation(OperationType.minus, [const_0, UnaryOperation(OperationType.negate, [x])])), Assignment(y, x)), - (Assignment(y, BinaryOperation(OperationType.minus, [const_0, const_m1])), Assignment(y, const_1)), - (Assignment(y, BinaryOperation(OperationType.minus, (x, const_m0))), Assignment(y, x)), - ], -) -def test_simplification_with_zero_subtraction(instruction, result): - true_value = LogicCondition.initialize_true(LogicCondition.generate_new_context()) - task = _task(AbstractSyntaxTree(CodeNode([instruction], true_value.copy()), dict())) - ExpressionSimplification().run(task) - assert task.syntax_tree.root == CodeNode([result], true_value.copy()) - - -@pytest.mark.parametrize( - "instruction, result", - [ - ( - Assignment(y, BinaryOperation(OperationType.plus, [x, 
BinaryOperation(OperationType.multiply, [x, const_0])])), - Assignment(y, x), - ), - ( - Assignment(y, BinaryOperation(OperationType.minus, [y, BinaryOperation(OperationType.minus, [x, const_0])])), - Assignment(y, BinaryOperation(OperationType.minus, [y, x])), - ), - ], -) -def test_simplification_with_zero_mix(instruction, result): - cfg = ControlFlowGraph() - cfg.add_node(BasicBlock(1, [instruction])) - true_value = LogicCondition.initialize_true(LogicCondition.generate_new_context()) - task = _task(AbstractSyntaxTree(CodeNode([instruction], true_value.copy()), dict()), cfg) - ExpressionSimplification().run(task) - assert task.syntax_tree.root == CodeNode([result], true_value.copy()) - - -@pytest.mark.parametrize( - "instruction, result", - [ - (Assignment(y, BinaryOperation(OperationType.multiply, [x, const_1])), Assignment(y, x)), - (Assignment(y, BinaryOperation(OperationType.multiply, [const_1, x])), Assignment(y, x)), - (Assignment(y, BinaryOperation(OperationType.multiply, [const_1, const_1])), Assignment(y, const_1)), - ( - Assignment(y, BinaryOperation(OperationType.multiply, [x, const_m1])), - Assignment(y, UnaryOperation(OperationType.negate, [x])), - ), - ( - Assignment(y, BinaryOperation(OperationType.multiply, [Constant(2, Integer(32, signed=False)), const_m1])), - Assignment(y, UnaryOperation(OperationType.negate, [Constant(2, Integer(32, signed=False))])), - ), - ( - Assignment(y, BinaryOperation(OperationType.multiply, [Constant(2, Integer(32, signed=True)), const_m1])), - Assignment(y, Constant(-2, Integer(32, signed=True))), - ), - ( - Assignment(y, BinaryOperation(OperationType.multiply, [Constant(-2, Integer(32, signed=True)), const_m1])), - Assignment(y, Constant(2, Integer(32, signed=True))), - ), - ( - Assignment(y, BinaryOperation(OperationType.multiply, [x, BinaryOperation(OperationType.multiply, [const_1, const_1])])), - Assignment(y, x), - ), - ( - Assignment(y, BinaryOperation(OperationType.multiply, [const_1, 
BinaryOperation(OperationType.multiply, [x, const_1])])), - Assignment(y, x), - ), - ], -) -def test_simplification_with_one_multiplication(instruction, result): - true_value = LogicCondition.initialize_true(LogicCondition.generate_new_context()) - task = _task(AbstractSyntaxTree(CodeNode([instruction], true_value.copy()), dict())) - ExpressionSimplification().run(task) - assert task.syntax_tree.root == CodeNode([result], true_value.copy()) - - -@pytest.mark.parametrize( - "instruction, result", - [ - (Assignment(y, BinaryOperation(OperationType.divide, [x, const_1])), Assignment(y, x)), - (Assignment(y, BinaryOperation(OperationType.divide, [const_1, const_1])), Assignment(y, const_1)), - (Assignment(y, BinaryOperation(OperationType.divide, [const_m1, const_1])), Assignment(y, const_m1)), - (Assignment(y, BinaryOperation(OperationType.divide, [const_1, const_m1])), Assignment(y, const_m1)), - (Assignment(y, BinaryOperation(OperationType.divide, [const_m1, const_m1])), Assignment(y, const_1)), - (Assignment(y, BinaryOperation(OperationType.divide, [x, const_m1])), Assignment(y, UnaryOperation(OperationType.negate, [x]))), - ( - Assignment(y, BinaryOperation(OperationType.divide, [const_0, const_1])), - Assignment(y, const_0), - ), - ( - Assignment(y, BinaryOperation(OperationType.divide, [const_0, const_m1])), - Assignment(y, const_0), - ), - ( - Assignment(y, BinaryOperation(OperationType.divide, [Constant(2, Integer(32, signed=False)), const_m1])), - Assignment(y, UnaryOperation(OperationType.negate, [Constant(2, Integer(32, signed=False))])), - ), - ], -) -def test_simplification_with_one_division(instruction, result): - true_value = LogicCondition.initialize_true(LogicCondition.generate_new_context()) - task = _task(AbstractSyntaxTree(CodeNode([instruction], true_value.copy()), dict())) - ExpressionSimplification().run(task) - assert task.syntax_tree.root == CodeNode([result], true_value.copy()) - - -@pytest.mark.parametrize( - "instruction, result", - [ - ( - 
Assignment(UnaryOperation(OperationType.dereference, [UnaryOperation(OperationType.address, [Variable("x")])]), Constant(0)), - Assignment(Variable("x"), Constant(0)), - ), - ( - Assignment( - ListOperation([]), - Call( - ImportedFunctionSymbol("foo", 0x42), - [ - BinaryOperation( - OperationType.minus, - [ - UnaryOperation(OperationType.dereference, [UnaryOperation(OperationType.address, [Variable("x")])]), - Constant(2), - ], - ) - ], - ), - ), - Assignment( - ListOperation([]), - Call(ImportedFunctionSymbol("foo", 0x42), [BinaryOperation(OperationType.minus, [Variable("x"), Constant(2)])]), - ), - ), - ( - Return([UnaryOperation(OperationType.dereference, [UnaryOperation(OperationType.address, [Variable("y")])])]), - Return([Variable("y")]), - ), - ], -) -def test_simplification_of_dereference_operations(instruction: Instruction, result: Instruction): - """Check if dereference operations with address-of operands are simplified correctly.""" - ExpressionSimplification().simplify(instruction) - assert instruction == result - - -@pytest.mark.parametrize( - "instruction, result", - [ - (Assignment(y, BinaryOperation(OperationType.divide, [x, const_1])), Assignment(y, x)), - (Assignment(y, BinaryOperation(OperationType.divide, [const_1, const_1])), Assignment(y, const_1)), - (Assignment(y, BinaryOperation(OperationType.multiply, [const_1, const_1])), Assignment(y, const_1)), - ( - Assignment(y, BinaryOperation(OperationType.multiply, [x, const_m1])), - Assignment(y, UnaryOperation(OperationType.negate, [x])), - ), - ( - Assignment(y, BinaryOperation(OperationType.plus, [x, BinaryOperation(OperationType.multiply, [x, const_0])])), - Assignment(y, x), - ), - (Assignment(y, BinaryOperation(OperationType.plus, [x, const_m0])), Assignment(y, x)), - ], -) -def test_for_cfg(instruction, result): - cfg = ControlFlowGraph() - cfg.add_node(BasicBlock(0, [instruction])) - task = _task(cfg=cfg) - ExpressionSimplification().run(task) - assert list(task.graph.instructions) == 
[result] diff --git a/tests/pipeline/controlflowanalysis/test_loop_name_generator.py b/tests/pipeline/controlflowanalysis/test_loop_name_generator.py new file mode 100644 index 000000000..8a1a1aeec --- /dev/null +++ b/tests/pipeline/controlflowanalysis/test_loop_name_generator.py @@ -0,0 +1,1334 @@ +from typing import List + +import pytest +from decompiler.pipeline.controlflowanalysis.loop_name_generator import ForLoopVariableRenamer, LoopNameGenerator, WhileLoopVariableRenamer +from decompiler.pipeline.controlflowanalysis.loop_utility_methods import _initialization_reaches_loop_node +from decompiler.pipeline.controlflowanalysis.readability_based_refinement import ReadabilityBasedRefinement +from decompiler.structures.ast.ast_nodes import CaseNode, CodeNode, ConditionNode, ForLoopNode, SeqNode, SwitchNode, WhileLoopNode +from decompiler.structures.ast.syntaxtree import AbstractSyntaxTree +from decompiler.structures.logic.logic_condition import LogicCondition +from decompiler.structures.pseudo import ( + Assignment, + BinaryOperation, + Break, + Call, + Condition, + Constant, + ImportedFunctionSymbol, + ListOperation, + OperationType, + Variable, +) +from decompiler.structures.pseudo.operations import ArrayInfo, OperationType, UnaryOperation +from decompiler.task import DecompilerTask +from decompiler.util.options import Options + +# Test For/WhileLoop Renamer + +def logic_cond(name: str, context) -> LogicCondition: + return LogicCondition.initialize_symbol(name, context) + +@pytest.fixture +def ast_call_for_loop() -> AbstractSyntaxTree: + """ + a = 5; + while(b = foo; b <= 5; b++){ + a++; + } + """ + true_value = LogicCondition.initialize_true(context := LogicCondition.generate_new_context()) + ast = AbstractSyntaxTree( + root := SeqNode(true_value), + condition_map={logic_cond("x1", context): Condition(OperationType.less_or_equal, [Variable("b"), Constant(5)])}, + ) + code_node = ast._add_code_node( + instructions=[ + Assignment(Variable("a"), Constant(5)), + ] + 
) + loop_node = ast.factory.create_for_loop_node(Assignment(ListOperation([Variable("b")]), Call(ImportedFunctionSymbol("foo", 0), [])), logic_cond("x1", context), Assignment(Variable("b"), BinaryOperation(OperationType.plus, [Variable("b"), Constant(1)]))) + loop_node_body = ast._add_code_node( + [ + Assignment(Variable("a"), BinaryOperation(OperationType.plus, [Variable("a"), Variable("1")])), + ] + ) + ast._add_node(loop_node) + ast._add_edges_from(((root, code_node), (root, loop_node), (loop_node, loop_node_body))) + ast._code_node_reachability_graph.add_reachability(code_node, loop_node_body) + root._sorted_children = (code_node, loop_node) + return ast + + +def test_declaration_listop(ast_call_for_loop): + """Test renaming with ListOperation as Declaration""" + ForLoopVariableRenamer(ast_call_for_loop, ["i"]).rename() + for node in ast_call_for_loop: + if isinstance(node, ForLoopNode): + assert node.declaration.destination.operands[0].name == "i" + + +def test_for_loop_variable_generation(): + renamer = ForLoopVariableRenamer( + AbstractSyntaxTree(SeqNode(LogicCondition.initialize_true(LogicCondition.generate_new_context())), {}), + ["i", "j", "k", "l", "m", "n"] + ) + assert [renamer._get_variable_name() for _ in range(14)] == [ + "i", + "j", + "k", + "l", + "m", + "n", + "i1", + "j1", + "k1", + "l1", + "m1", + "n1", + "i2", + "j2", + ] + + +def test_while_loop_variable_generation(): + renamer = WhileLoopVariableRenamer( + AbstractSyntaxTree(SeqNode(LogicCondition.initialize_true(LogicCondition.generate_new_context())), {}) + ) + assert [renamer._get_variable_name() for _ in range(5)] == ["counter", "counter1", "counter2", "counter3", "counter4"] + +# Test Readabilitybasedrefinement + LoopNameGenerator together + + +def _generate_options(empty_loops: bool = False, hide_decl: bool = False, rename_for: bool = True, rename_while: bool = True, \ + max_condition: int = 100, max_modification: int = 100, force_for_loops: bool = False, blacklist : List[str] = []) -> 
Options: + options = Options() + options.set("readability-based-refinement.keep_empty_for_loops", empty_loops) + options.set("readability-based-refinement.hide_non_initializing_declaration", hide_decl) + options.set("readability-based-refinement.max_condition_complexity_for_loop_recovery", max_condition) + options.set("readability-based-refinement.max_modification_complexity_for_loop_recovery", max_modification) + options.set("readability-based-refinement.force_for_loops", force_for_loops) + options.set("readability-based-refinement.forbidden_condition_types_in_simple_for_loops", blacklist) + if rename_for: + names = ["i", "j", "k", "l", "m", "n"] + options.set("loop-name-generator.for_loop_variable_names", names) + options.set("loop-name-generator.rename_while_loop_variables", rename_while) + return options + + +@pytest.fixture +def ast_array_access_for_loop() -> AbstractSyntaxTree: + """ + for (var_0 = 0; var_0 < 10; var_0 = var_0 + 1) { + *(var_1 + var_0) = var_0; + } + """ + true_value = LogicCondition.initialize_true(context := LogicCondition.generate_new_context()) + ast = AbstractSyntaxTree( + root := SeqNode(true_value), + condition_map={logic_cond("x1", context): Condition(OperationType.less, [Variable("var_0"), Constant(10)])}, + ) + declaration = Assignment(Variable("var_0"), Constant(0)) + condition = logic_cond("x1", context) + modification = Assignment(Variable("var_0"), BinaryOperation(OperationType.plus, [Variable("var_0"), Constant(1)])) + for_loop = ast.factory.create_for_loop_node(declaration, condition, modification) + array_info = ArrayInfo(Variable("var_1"), Variable("var_0")) + array_access_unary_operation = UnaryOperation( + OperationType.dereference, [BinaryOperation(OperationType.plus, [Variable("var_1"), Variable("var_0")])], array_info=array_info + ) + for_loop_body = ast._add_code_node([Assignment(array_access_unary_operation, Variable("var_0"))]) + ast._add_node(for_loop) + ast._add_edges_from([(root, for_loop), (for_loop, 
for_loop_body)]) + return ast + + +@pytest.fixture +def ast_while_true() -> AbstractSyntaxTree: + """ + a = 0; + b = 0; + while(true){ + a = a + 1; + b = b + 1; + } + """ + true_value = LogicCondition.initialize_true(LogicCondition.generate_new_context()) + ast = AbstractSyntaxTree(root := SeqNode(true_value), {}) + code_node = ast._add_code_node([Assignment(Variable("a"), Constant(0)), Assignment(Variable("b"), Constant(0))]) + loop_node_body = ast._add_code_node( + [ + Assignment(Variable("a"), BinaryOperation(OperationType.plus, [Variable("a"), Constant(1)])), + Assignment(Variable("b"), BinaryOperation(OperationType.plus, [Variable("b"), Constant(1)])), + ] + ) + loop_node = ast.add_endless_loop_with_body(loop_node_body) + ast._add_edges_from(((root, code_node), (root, loop_node))) + return ast + + +@pytest.fixture +def ast_single_instruction_while() -> AbstractSyntaxTree: + """ + a = 0; + while (a < 10) { + a = a + 1; + } + """ + true_value = LogicCondition.initialize_true(context := LogicCondition.generate_new_context()) + ast = AbstractSyntaxTree( + root := SeqNode(true_value), condition_map={logic_cond("x1", context): Condition(OperationType.less, [Variable("a"), Constant(10)])} + ) + init_code_node = ast._add_code_node([Assignment(Variable("a"), Constant(0))]) + while_loop = ast.factory.create_while_loop_node(logic_cond("x1", context)) + while_loop_body = ast._add_code_node([Assignment(Variable("a"), BinaryOperation(OperationType.plus, [Variable("a"), Constant(1)]))]) + ast._add_node(while_loop) + ast._add_edges_from([(root, init_code_node), (root, while_loop), (while_loop, while_loop_body)]) + return ast + + +@pytest.fixture +def ast_replaceable_while() -> AbstractSyntaxTree: + """ + a = 0; + while (x < 10) { + printf("counter: %d", x); + a = a + 1; + } + """ + true_value = LogicCondition.initialize_true(context := LogicCondition.generate_new_context()) + ast = AbstractSyntaxTree( + root := SeqNode(true_value), condition_map={logic_cond("x1", context): 
Condition(OperationType.less, [Variable("a"), Constant(10)])} + ) + + init_code_node = ast._add_code_node([Assignment(Variable("a"), Constant(0))]) + + while_loop = ast.factory.create_while_loop_node(logic_cond("x1", context)) + while_loop_body = ast._add_code_node( + [ + Assignment(ListOperation([]), Call(ImportedFunctionSymbol("printf", 0), [Constant("counter: %d\n"), Variable("a")])), + Assignment(Variable("a"), BinaryOperation(OperationType.plus, [Variable("a"), Constant(1)])), + ] + ) + + ast._add_node(while_loop) + ast._add_edges_from([(root, init_code_node), (root, while_loop), (while_loop, while_loop_body)]) + root._sorted_children = (init_code_node, while_loop) + return ast + + +@pytest.fixture +def ast_replaceable_while_usage() -> AbstractSyntaxTree: + """ + a = 0; + while (a < 10) { + printf("counter: %d", a); + a = a + 1; + } + printf("final counter: %d", a); + """ + true_value = LogicCondition.initialize_true(context := LogicCondition.generate_new_context()) + ast = AbstractSyntaxTree( + root := SeqNode(true_value), condition_map={logic_cond("x1", context): Condition(OperationType.less, [Variable("a"), Constant(10)])} + ) + + init_code_node = ast._add_code_node([Assignment(Variable("a"), Constant(0))]) + + while_loop = ast.factory.create_while_loop_node(logic_cond("x1", context)) + while_loop_body = ast._add_code_node( + [ + Assignment(ListOperation([]), Call(ImportedFunctionSymbol("printf", 0), [Constant("counter: %d\n"), Variable("a")])), + Assignment(Variable("a"), BinaryOperation(OperationType.plus, [Variable("a"), Constant(1)])), + ] + ) + + exit_code_node = ast._add_code_node( + [Assignment(ListOperation([]), Call(ImportedFunctionSymbol("printf", 0), [Constant("final counter: %d"), Variable("a")]))] + ) + + ast._add_node(while_loop) + ast._add_edges_from([(root, init_code_node), (root, while_loop), (root, exit_code_node), (while_loop, while_loop_body)]) + return ast + + +@pytest.fixture +def ast_replaceable_while_reinit_usage() -> 
AbstractSyntaxTree: + """ + a = 0; + while (a < 10) { + printf("counter: %d", a); + a = a + 1; + } + a = 50; + printf("50 = %d", a); + """ + true_value = LogicCondition.initialize_true(context := LogicCondition.generate_new_context()) + ast = AbstractSyntaxTree( + root := SeqNode(true_value), condition_map={logic_cond("x1", context): Condition(OperationType.less, [Variable("a"), Constant(10)])} + ) + + init_code_node = ast._add_code_node([Assignment(Variable("a"), Constant(0))]) + + while_loop = ast.factory.create_while_loop_node(logic_cond("x1", context)) + while_loop_body = ast._add_code_node( + [ + Assignment(ListOperation([]), Call(ImportedFunctionSymbol("printf", 0), [Constant("counter: %d\n"), Variable("a")])), + Assignment(Variable("a"), BinaryOperation(OperationType.plus, [Variable("a"), Constant(1)])), + ] + ) + + exit_code_node = ast._add_code_node( + [ + Assignment(Variable("a"), Constant(50)), + Assignment(ListOperation([]), Call(ImportedFunctionSymbol("printf", 0), [Constant("final counter: %d"), Variable("a")])), + ] + ) + + ast._add_node(while_loop) + ast._add_edges_from([(root, init_code_node), (root, while_loop), (root, exit_code_node), (while_loop, while_loop_body)]) + return ast + + +@pytest.fixture +def ast_replaceable_while_compound_usage() -> AbstractSyntaxTree: + """ + a = 0; + while (a < 10) { + printf("counter: %d", a); + a = a + 1; + } + a = a + 50; + printf("50 = %d", a); + """ + true_value = LogicCondition.initialize_true(context := LogicCondition.generate_new_context()) + ast = AbstractSyntaxTree( + root := SeqNode(true_value), condition_map={logic_cond("x1", context): Condition(OperationType.less, [Variable("a"), Constant(10)])} + ) + + init_code_node = ast._add_code_node([Assignment(Variable("a"), Constant(0))]) + + while_loop = ast.factory.create_while_loop_node(logic_cond("x1", context)) + while_loop_body = ast._add_code_node( + [ + Assignment(ListOperation([]), Call(ImportedFunctionSymbol("printf", 0), [Constant("counter: %d\n"), 
Variable("a")])), + Assignment(Variable("a"), BinaryOperation(OperationType.plus, [Variable("a"), Constant(1)])), + ] + ) + + exit_code_node = ast._add_code_node( + [ + Assignment(Variable("a"), BinaryOperation(OperationType.plus, [Variable("a"), Constant(50)])), + Assignment(ListOperation([]), Call(ImportedFunctionSymbol("printf", 0), [Constant("final counter: %d"), Variable("a")])), + ] + ) + + ast._add_node(while_loop) + ast._add_edges_from([(root, init_code_node), (root, while_loop), (root, exit_code_node), (while_loop, while_loop_body)]) + return ast + + +@pytest.fixture +def ast_endless_while_init_outside() -> AbstractSyntaxTree: + """ + a = 0; + while (true) { + while (a < 5) { + printf("%d\n", a); + a = a + 1; + } + } + """ + true_value = LogicCondition.initialize_true(context := LogicCondition.generate_new_context()) + ast = AbstractSyntaxTree( + root := SeqNode(true_value), condition_map={logic_cond("x1", context): Condition(OperationType.less, [Variable("a"), Constant(2)])} + ) + + init_code_node = ast._add_code_node([Assignment(Variable("a"), Constant(0))]) + + inner_while = ast.factory.create_while_loop_node(logic_cond("x1", context)) + ast._add_node(inner_while) + endless_loop = ast.add_endless_loop_with_body(inner_while) + + inner_while_body = ast._add_code_node( + [ + Assignment(ListOperation([]), Call(ImportedFunctionSymbol("printf", 0), [Constant("%d\n"), Variable("a")])), + Assignment(Variable("a"), BinaryOperation(OperationType.plus, [Variable("a"), Constant(1)])), + ] + ) + + ast._add_edges_from([(root, init_code_node), (root, endless_loop), (endless_loop, inner_while), (inner_while, inner_while_body)]) + return ast + + +@pytest.fixture +def ast_nested_while() -> AbstractSyntaxTree: + """ + a = 0; + while (a < 1) { + b = 0; + while (b < 1) { + b = b + 1; + } + a = a + 1; + } + """ + true_value = LogicCondition.initialize_true(context := LogicCondition.generate_new_context()) + ast = AbstractSyntaxTree( + root := SeqNode(true_value), + 
condition_map={ + logic_cond("x1", context): Condition(OperationType.less, [Variable("a"), Constant(5)]), + logic_cond("x2", context): Condition(OperationType.less, [Variable("b"), Constant(5)]), + }, + ) + + init_code_node = ast._add_code_node([Assignment(Variable("a"), Constant(0))]) + + outer_while = ast.factory.create_while_loop_node(logic_cond("x1", context)) + outer_while_body = ast.factory.create_seq_node() + outer_while_init = ast._add_code_node([Assignment(Variable("b"), Constant(0))]) + outer_while_exit = ast._add_code_node([Assignment(Variable("a"), BinaryOperation(OperationType.plus, [Variable("a"), Constant(1)]))]) + + inner_while = ast.factory.create_while_loop_node(logic_cond("x2", context)) + inner_while_body = ast._add_code_node([Assignment(Variable("b"), BinaryOperation(OperationType.plus, [Variable("b"), Constant(1)]))]) + + ast._add_nodes_from((outer_while, outer_while_body, inner_while)) + ast._add_edges_from( + [ + (root, init_code_node), + (root, outer_while), + (outer_while, outer_while_body), + (outer_while_body, outer_while_init), + (outer_while_body, inner_while), + (outer_while_body, outer_while_exit), + (inner_while, inner_while_body), + ] + ) + return ast + + +@pytest.fixture +def ast_call_init() -> AbstractSyntaxTree: + """ + a = 5; + b = foo(); + while(b <= 5){ + a = a + b; + b = b + 1; + } + """ + true_value = LogicCondition.initialize_true(context := LogicCondition.generate_new_context()) + ast = AbstractSyntaxTree( + root := SeqNode(true_value), + condition_map={logic_cond("x1", context): Condition(OperationType.less_or_equal, [Variable("b"), Constant(5)])}, + ) + code_node = ast._add_code_node( + instructions=[ + Assignment(Variable("a"), Constant(5)), + Assignment(ListOperation([Variable("b")]), Call(ImportedFunctionSymbol("foo", 0), [])), + ] + ) + loop_node = ast.factory.create_while_loop_node(condition=logic_cond("x1", context)) + loop_node_body = ast._add_code_node( + [ + Assignment(Variable("a"), 
BinaryOperation(OperationType.plus, [Variable("a"), Variable("b")])), + Assignment(Variable("b"), BinaryOperation(OperationType.plus, [Variable("b"), Constant(1)])), + ] + ) + ast._add_node(loop_node) + ast._add_edges_from(((root, code_node), (root, loop_node), (loop_node, loop_node_body))) + ast._code_node_reachability_graph.add_reachability(code_node, loop_node_body) + root._sorted_children = (code_node, loop_node) + return ast + + +@pytest.fixture +def ast_redundant_init() -> AbstractSyntaxTree: + """ + b = 0; + a = 5; + b = 2; + + while(b <= 5){ + a = a + b; + b = b + 1; + } + """ + true_value = LogicCondition.initialize_true(context := LogicCondition.generate_new_context()) + ast = AbstractSyntaxTree( + root := SeqNode(true_value), condition_map={logic_cond("x1", context): Condition(OperationType.less, [Variable("b"), Constant(5)])} + ) + code_node = ast._add_code_node( + instructions=[ + Assignment(Variable("b"), Constant(0)), + Assignment(Variable("a"), Constant(5)), + Assignment(Variable("b"), Constant(2)), + ] + ) + loop_node = ast.factory.create_while_loop_node(condition=logic_cond("x1", context)) + loop_node_body = ast._add_code_node( + [ + Assignment(Variable("a"), BinaryOperation(OperationType.plus, [Variable("a"), Variable("b")])), + Assignment(Variable("b"), BinaryOperation(OperationType.plus, [Variable("b"), Constant(1)])), + ] + ) + ast._add_node(loop_node) + ast._add_edges_from(((root, code_node), (root, loop_node), (loop_node, loop_node_body))) + ast._code_node_reachability_graph.add_reachability(code_node, loop_node_body) + root._sorted_children = (code_node, loop_node) + return ast + + +@pytest.fixture +def ast_reinit_in_condition_true() -> AbstractSyntaxTree: + """ + int x = 1; + int i = 0; + + if (x == 1) { + i = 1; + } + + while (i < 10) { + x = x * 2; + i = i + 1; + } + """ + true_value = LogicCondition.initialize_true(context := LogicCondition.generate_new_context()) + ast = AbstractSyntaxTree( + root := SeqNode(true_value), + 
condition_map={ + logic_cond("a", context): Condition(OperationType.less, [Variable("i"), Constant(10)]), + logic_cond("b", context): Condition(OperationType.equal, [Variable("x"), Constant(1)]), + }, + ) + code_node = ast._add_code_node(instructions=[Assignment(Variable("x"), Constant(1)), Assignment(Variable("i"), Constant(0))]) + code_node_true = ast._add_code_node([Assignment(Variable("i"), Constant(1))]) + condition_node = ast._add_condition_node_with(logic_cond("b", context), code_node_true) + loop_node = ast.factory.create_while_loop_node(condition=logic_cond("a", context)) + loop_node_body = ast._add_code_node( + [ + Assignment(Variable("x"), BinaryOperation(OperationType.multiply, [Variable("x"), Constant(2)])), + Assignment(Variable("i"), BinaryOperation(OperationType.plus, [Variable("i"), Constant(1)])), + ] + ) + ast._add_nodes_from((condition_node, loop_node)) + ast._add_edges_from(((root, code_node), (root, condition_node), (root, loop_node), (loop_node, loop_node_body))) + ast._code_node_reachability_graph.add_reachability(code_node, loop_node_body) + root._sorted_children = (code_node, loop_node) + return ast + + +@pytest.fixture +def ast_usage_in_condition() -> AbstractSyntaxTree: + """ + int a = 1; + int b = 0; + + if (b == 1) { + a = 1; + } + + while (b < 10) { + a = a * 2; + b = b + 1; + } + """ + true_value = LogicCondition.initialize_true(context := LogicCondition.generate_new_context()) + ast = AbstractSyntaxTree( + root := SeqNode(true_value), + condition_map={ + logic_cond("x1", context): Condition(OperationType.less, [Variable("b"), Constant(10)]), + logic_cond("x2", context): Condition(OperationType.equal, [Variable("b"), Constant(1)]), + }, + ) + init_code_node = ast._add_code_node([Assignment(Variable("a"), Constant(1)), Assignment(Variable("b"), Constant(0))]) + code_node_true = ast._add_code_node([Assignment(Variable("a"), Constant(1))]) + condition_node = ast._add_condition_node_with(logic_cond("x2", context), code_node_true) + 
loop_node = ast.factory.create_while_loop_node(condition=logic_cond("x1", context)) + loop_node_body = ast._add_code_node( + [ + Assignment(Variable("a"), BinaryOperation(OperationType.multiply, [Variable("a"), Constant(2)])), + Assignment(Variable("b"), BinaryOperation(OperationType.plus, [Variable("b"), Constant(1)])), + ] + ) + ast._add_node(loop_node) + ast._add_edges_from(((root, init_code_node), (root, condition_node), (root, loop_node), (loop_node, loop_node_body))) + ast._code_node_reachability_graph.add_reachability(init_code_node, loop_node_body) + root._sorted_children = (init_code_node, loop_node) + return ast + + +@pytest.fixture +def ast_sequenced_while_loops() -> AbstractSyntaxTree: + """ + a = 0; + b = 0; + + while (a < 5) { + printf("%d\n", a); + a++; + } + + while (b < 5) { + printf("%d\n", b); + b++; + } + """ + true_value = LogicCondition.initialize_true(context := LogicCondition.generate_new_context()) + ast = AbstractSyntaxTree( + root := SeqNode(true_value), + condition_map={ + logic_cond("x1", context): Condition(OperationType.less, [Variable("a"), Constant(5)]), + logic_cond("x2", context): Condition(OperationType.less, [Variable("b"), Constant(5)]), + }, + ) + + init_code_node = ast._add_code_node([Assignment(Variable("a"), Constant(0)), Assignment(Variable("b"), Constant(0))]) + + while_loop_1 = ast.factory.create_while_loop_node(logic_cond("x1", context)) + while_loop_1_body = ast._add_code_node( + [ + Assignment(ListOperation([]), Call(ImportedFunctionSymbol("printf", 0), [Constant("%d\n"), Variable("a")])), + Assignment(Variable("a"), BinaryOperation(OperationType.plus, [Variable("a"), Constant(1)])), + ] + ) + + while_loop_2 = ast.factory.create_while_loop_node(logic_cond("x2", context)) + while_loop_2_body = ast._add_code_node( + [ + Assignment(ListOperation([]), Call(ImportedFunctionSymbol("printf", 0), [Constant("%d\n"), Variable("b")])), + Assignment(Variable("b"), BinaryOperation(OperationType.plus, [Variable("b"), 
Constant(1)])), + ] + ) + + ast._add_nodes_from((while_loop_1, while_loop_2)) + ast._add_edges_from( + ( + (root, init_code_node), + (root, while_loop_1), + (root, while_loop_2), + (while_loop_1, while_loop_1_body), + (while_loop_2, while_loop_2_body), + ) + ) + return ast + + +@pytest.fixture +def ast_switch_as_loop_body() -> AbstractSyntaxTree: + """ + This while-loop should not be replaced with a for-loop because we don't know wich value 'a' has. + + Code of AST: + a = 5; + b = 0; + while(b <= 5){ + switch(a) { + case 0: + a = a + b: + break; + case 1: + b = b + 1; + break; + } + } + """ + true_value = LogicCondition.initialize_true(context := LogicCondition.generate_new_context()) + ast = AbstractSyntaxTree( + root := SeqNode(true_value), + condition_map={logic_cond("a", context): Condition(OperationType.less_or_equal, [Variable("b"), Constant(5)])}, + ) + code_node = ast._add_code_node([Assignment(Variable("a"), Constant(5)), Assignment(Variable("b"), Constant(0))]) + loop_node = ast.factory.create_while_loop_node(condition=logic_cond("a", context)) + root._sorted_children = (code_node, loop_node) + loop_body_switch = ast.factory.create_switch_node(Variable("a")) + loop_body_case_1 = ast.factory.create_case_node(Variable("a"), Constant(0), break_case=True) + code_node_case_1 = ast._add_code_node([Assignment(Variable("a"), BinaryOperation(OperationType.plus, [Variable("a"), Variable("b")]))]) + loop_body_case_2 = ast.factory.create_case_node(Variable("a"), Constant(1), break_case=True) + code_node_case_2 = ast._add_code_node( + [ + Assignment(Variable("b"), BinaryOperation(OperationType.plus, [Variable("b"), Constant(1)])), + ] + ) + ast._add_nodes_from((code_node, loop_node, loop_body_switch, loop_body_case_1, loop_body_case_2)) + ast._add_edges_from( + ( + (root, code_node), + (root, loop_node), + (loop_node, loop_body_switch), + (loop_body_switch, loop_body_case_1), + (loop_body_switch, loop_body_case_2), + (loop_body_case_1, code_node_case_1), + 
(loop_body_case_2, code_node_case_2), + ) + ) + ast._code_node_reachability_graph.add_reachability_from(((code_node, code_node_case_1), (code_node, code_node_case_2))) + return ast + + +@pytest.fixture +def ast_switch_as_loop_body_increment() -> AbstractSyntaxTree: + """ + This loop should be replaced with a for-loop because b has no usages after last definition, is in condition and is initialized + before loop without any usages in between. + + Code of AST: + a = 5; + b = 0; + while(b <= 5){ + switch(a) { + case 0: + a = a + b: + break; + case 1: + b = b + 1; + break; + } + b = b + 1; + } + """ + true_value = LogicCondition.initialize_true(context := LogicCondition.generate_new_context()) + ast = AbstractSyntaxTree( + root := SeqNode(true_value), condition_map={logic_cond("x1", context): Condition(OperationType.less, [Variable("b"), Constant(5)])} + ) + + init_code_node = ast._add_code_node([Assignment(Variable("a"), Constant(5)), Assignment(Variable("b"), Constant(0))]) + + while_loop = ast.factory.create_while_loop_node(logic_cond("x1", context)) + while_loop_seq = ast.factory.create_seq_node() + + switch_node = ast.factory.create_switch_node(Variable("a")) + case_1 = ast.factory.create_case_node(Variable("a"), Constant(0), break_case=True) + case_1_code = ast._add_code_node([Assignment(Variable("a"), BinaryOperation(OperationType.plus, [Variable("a"), Variable("b")]))]) + case_2 = ast.factory.create_case_node(Variable("a"), Constant(0), break_case=True) + case_2_code = ast._add_code_node([Assignment(Variable("b"), BinaryOperation(OperationType.plus, [Variable("b"), Constant(1)]))]) + + increment_code = ast._add_code_node([Assignment(Variable("b"), BinaryOperation(OperationType.plus, [Variable("b"), Constant(1)]))]) + + ast._add_nodes_from((while_loop, while_loop_seq, switch_node, case_1, case_2)) + ast._add_edges_from( + [ + (root, init_code_node), + (root, while_loop), + (while_loop, while_loop_seq), + (while_loop_seq, switch_node), + (while_loop_seq, 
increment_code), + (switch_node, case_1), + (switch_node, case_2), + (case_1, case_1_code), + (case_2, case_2_code), + ] + ) + return ast + + +@pytest.fixture +def ast_init_in_switch() -> AbstractSyntaxTree: + """ + a = 5; + b = 0; + switch(a){ + case 0: + a = b; + } + while(b <= (5 + a)){ + a = a + b; + b = b + 1; + } + """ + true_value = LogicCondition.initialize_true(context := LogicCondition.generate_new_context()) + ast = AbstractSyntaxTree( + root := SeqNode(true_value), + condition_map={ + logic_cond("x1", context): Condition( + OperationType.less_or_equal, + [Variable("b"), BinaryOperation(OperationType.plus, [Constant(5), Variable("a")])], + ) + }, + ) + init_code_node = ast._add_code_node(instructions=[Assignment(Variable("a"), Constant(5)), Assignment(Variable("b"), Constant(0))]) + switch_node = ast.factory.create_switch_node(Variable("a")) + loop_node = ast.factory.create_while_loop_node(condition=logic_cond("x1", context)) + case_node = ast.factory.create_case_node(Variable("a"), Constant(0)) + case_child = ast._add_code_node([Assignment(Variable("a"), Variable("b"))]) + loop_body = ast.factory.create_seq_node() + loop_body_child = ast._add_code_node( + [ + Assignment(Variable("a"), BinaryOperation(OperationType.plus, [Variable("a"), Variable("b")])), + Assignment(Variable("b"), BinaryOperation(OperationType.plus, [Variable("b"), Constant(1)])), + ] + ) + ast._add_nodes_from((switch_node, loop_node, loop_body, case_node)) + ast._add_edges_from( + ( + (root, init_code_node), + (root, switch_node), + (switch_node, case_node), + (case_node, case_child), + (root, loop_node), + (loop_node, loop_body), + (loop_body, loop_body_child), + ) + ) + ast._code_node_reachability_graph.add_reachability_from([(case_child, loop_body_child)]) + root._sorted_children = (init_code_node, switch_node, loop_node) + loop_body._sorted_children = (loop_body_child,) + switch_node._sorted_cases = (case_node,) + return ast + + +@pytest.fixture +def ast_while_in_else() -> 
AbstractSyntaxTree: + """ + while (true) { + if (b < 2) { + break; + } else { + a = 0; + while (a < 5) { + printf("%d\n", a); + a = a + 1; + } + } + } + """ + true_value = LogicCondition.initialize_true(context := LogicCondition.generate_new_context()) + ast = AbstractSyntaxTree( + root := SeqNode(true_value), + condition_map={ + logic_cond("x1", context): Condition(OperationType.less, [Variable("a"), Constant(2)]), + logic_cond("x2", context): Condition(OperationType.less, [Variable("b"), Constant(2)]), + }, + ) + + inner_while = ast.factory.create_while_loop_node(logic_cond("x1", context)) + ast._add_node(inner_while) + + true_branch_child = ast._add_code_node([Break()]) + inner_seq = ast.factory.create_seq_node() + ast._add_node(inner_seq) + condition_node = ast._add_condition_node_with(logic_cond("x2", context), true_branch_child, inner_seq) + + init_code_node = ast._add_code_node([Assignment(Variable("a"), Constant(0))]) + + endless_loop = ast.add_endless_loop_with_body(condition_node) + + inner_while_body = ast._add_code_node( + [ + Assignment(ListOperation([]), Call(ImportedFunctionSymbol("printf", 0), [Constant("%d\n"), Variable("a")])), + Assignment(Variable("a"), BinaryOperation(OperationType.plus, [Variable("a"), Constant(1)])), + ] + ) + + ast._add_edges_from( + [ + (root, endless_loop), + (endless_loop, condition_node), + (inner_seq, init_code_node), + (inner_seq, inner_while), + (inner_while, inner_while_body), + ] + ) + return ast + + +@pytest.fixture +def ast_continuation_is_not_first_var() -> AbstractSyntaxTree: + """ + a = 0; + b = 0; + while (a < b) { + printf("%d\n", a); + b = b + 1; + } + """ + true_value = LogicCondition.initialize_true(context := LogicCondition.generate_new_context()) + ast = AbstractSyntaxTree( + root := SeqNode(true_value), + condition_map={logic_cond("x1", context): Condition(OperationType.less, [Variable("a"), Variable("b")])}, + ) + + init_code_node = ast._add_code_node([Assignment(Variable("a"), Constant(0)), 
Assignment(Variable("b"), Constant(0))]) + + while_loop = ast.factory.create_while_loop_node(logic_cond("x1", context)) + while_loop_body = ast._add_code_node( + [ + Assignment(ListOperation([]), Call(ImportedFunctionSymbol("printf", 0), [Constant("%d\n"), Variable("a")])), + Assignment(Variable("b"), BinaryOperation(OperationType.plus, [Variable("b"), Constant(1)])), + ] + ) + + ast._add_node(while_loop) + ast._add_edges_from([(root, init_code_node), (root, while_loop), (while_loop, while_loop_body)]) + root._sorted_children = (init_code_node, while_loop) + return ast + + +@pytest.fixture +def ast_initialization_in_condition() -> AbstractSyntaxTree: + """ + if(b < 10 ){ + a = 5; + while (x < 10) { + printf("counter: %d", a); + a = a + 1; + } + """ + true_value = LogicCondition.initialize_true(context := LogicCondition.generate_new_context()) + ast = AbstractSyntaxTree( + root := SeqNode(true_value), + condition_map={ + logic_cond("x0", context): Condition(OperationType.less, [Variable("b"), Constant(10)]), + logic_cond("x1", context): Condition(OperationType.less, [Variable("a"), Constant(10)]), + }, + ) + + true_branch = ast._add_code_node([Assignment(Variable("a"), Constant(5))]) + if_condition = ast._add_condition_node_with(logic_cond("x0", context), true_branch) + while_loop = ast.factory.create_while_loop_node(logic_cond("x1", context)) + while_loop_body = ast._add_code_node( + [ + Assignment(ListOperation([]), Call(ImportedFunctionSymbol("printf", 0), [Constant("counter: %d\n"), Variable("a")])), + Assignment(Variable("a"), BinaryOperation(OperationType.plus, [Variable("a"), Constant(1)])), + ] + ) + + ast._add_node(while_loop) + ast._add_edges_from([(root, if_condition), (root, while_loop), (while_loop, while_loop_body)]) + root._sorted_children = (if_condition, while_loop) + return ast + + +@pytest.fixture +def ast_initialization_in_condition_sequence() -> AbstractSyntaxTree: + """ + if(b < 10 ){ + if(c < 10){ + b = 5; + } + a = 5; + while (x < 10) { + 
printf("counter: %d", a); + a = a + 1; + } + """ + true_value = LogicCondition.initialize_true(context := LogicCondition.generate_new_context()) + ast = AbstractSyntaxTree( + root := SeqNode(true_value), + condition_map={ + logic_cond("x0", context): Condition(OperationType.less, [Variable("b"), Constant(10)]), + logic_cond("x1", context): Condition(OperationType.less, [Variable("c"), Constant(10)]), + logic_cond("x2", context): Condition(OperationType.less, [Variable("a"), Constant(10)]), + }, + ) + + true_branch_c = ast._add_code_node([Assignment(Variable("b"), Constant(5))]) + code_node = ast._add_code_node([Assignment(Variable("a"), Constant(5))]) + if_condition_c = ast._add_condition_node_with(logic_cond("x1", context), true_branch_c) + ast._add_node(true_branch_b := ast.factory.create_seq_node()) + if_condition_b = ast._add_condition_node_with(logic_cond("x1", context), true_branch_b) + while_loop = ast.factory.create_while_loop_node(logic_cond("x2", context)) + while_loop_body = ast._add_code_node( + [ + Assignment(ListOperation([]), Call(ImportedFunctionSymbol("printf", 0), [Constant("counter: %d\n"), Variable("a")])), + Assignment(Variable("a"), BinaryOperation(OperationType.plus, [Variable("a"), Constant(1)])), + ] + ) + + ast._add_node(while_loop) + ast._add_edges_from( + [ + (root, if_condition_b), + (root, while_loop), + (while_loop, while_loop_body), + (true_branch_b, if_condition_c), + (true_branch_b, code_node), + ] + ) + true_branch_b._sorted_children = (if_condition_c, code_node) + root._sorted_children = (if_condition_b, while_loop) + return ast + + +class TestReadabilityBasedRefinementAndLoopNameGenerator: + """Test Readability functionality with all its substages.""" + + @staticmethod + def run_rbr(ast: AbstractSyntaxTree, options: Options = _generate_options()): + task = DecompilerTask("func", cfg=None, ast=ast, options=options) + ReadabilityBasedRefinement().run(task) + LoopNameGenerator().run(task) + + + def test_no_replacement(self, 
ast_while_true): + self.run_rbr(ast_while_true) + assert all(not isinstance(node, ForLoopNode) for node in ast_while_true.topological_order()) + + def test_simple_replacement(self, ast_replaceable_while): + self.run_rbr(ast_replaceable_while) + + assert ast_replaceable_while.condition_map == { + logic_cond("x1", LogicCondition.generate_new_context()): Condition(OperationType.less, [Variable("i"), Constant(10)]) + } + + loop_node = ast_replaceable_while.root + assert isinstance(loop_node, ForLoopNode) + assert loop_node.declaration == Assignment(Variable("i"), Constant(0)) + assert loop_node.modification == Assignment(Variable("i"), BinaryOperation(OperationType.plus, [Variable("i"), Constant(1)])) + + loop_body = loop_node.body + assert isinstance(loop_body, CodeNode) + assert loop_body.instructions == [ + Assignment(ListOperation([]), Call(ImportedFunctionSymbol("printf", 0), [Constant("counter: %d\n"), Variable("i")])), + ] + + def test_with_usage(self, ast_replaceable_while_usage): + self.run_rbr(ast_replaceable_while_usage) + + for_loop = ast_replaceable_while_usage.root.children[0] + assert isinstance(for_loop, ForLoopNode) + assert for_loop.declaration == Assignment(Variable("i"), Constant(0)) + + copy_instr_node = ast_replaceable_while_usage.root.children[1] + assert isinstance(copy_instr_node, CodeNode) + assert copy_instr_node.instructions == [Assignment(Variable("a"), Variable("i"))] + + def test_with_usage_redefinition(self, ast_replaceable_while_reinit_usage): + self.run_rbr(ast_replaceable_while_reinit_usage) + + for_loop = ast_replaceable_while_reinit_usage.root.children[0] + assert isinstance(for_loop, ForLoopNode) + assert for_loop.declaration == Assignment(Variable("i"), Constant(0)) + assert for_loop.modification == Assignment(Variable("i"), BinaryOperation(OperationType.plus, [Variable("i"), Constant(1)])) + + exit_code_node = ast_replaceable_while_reinit_usage.root.children[1] + assert isinstance(exit_code_node, CodeNode) + assert 
exit_code_node.instructions == [ + Assignment(Variable("a"), Constant(50)), + Assignment(ListOperation([]), Call(ImportedFunctionSymbol("printf", 0), [Constant("final counter: %d"), Variable("a")])), + ] + + def test_with_usage_redefenition_2(self, ast_replaceable_while_compound_usage): + self.run_rbr(ast_replaceable_while_compound_usage) + + for_loop = ast_replaceable_while_compound_usage.root.children[0] + assert isinstance(for_loop, ForLoopNode) + assert for_loop.declaration == Assignment(Variable("i"), Constant(0)) + assert for_loop.modification == Assignment(Variable("i"), BinaryOperation(OperationType.plus, [Variable("i"), Constant(1)])) + + copy_instr_node = ast_replaceable_while_compound_usage.root.children[1] + assert isinstance(copy_instr_node, CodeNode) + assert copy_instr_node.instructions == [Assignment(Variable("a"), Variable("i"))] + + def test_continuation_is_not_first_var(self, ast_continuation_is_not_first_var): + self.run_rbr(ast_continuation_is_not_first_var) + + init_code_node = ast_continuation_is_not_first_var.root.children[0] + assert isinstance(init_code_node, CodeNode) + assert init_code_node.instructions == [Assignment(Variable("a"), Constant(0))] + + loop_node = ast_continuation_is_not_first_var.root.children[1] + assert isinstance(loop_node, ForLoopNode) + assert loop_node.declaration == Assignment(Variable("i"), Constant(0)) + assert loop_node.modification == Assignment(Variable("i"), BinaryOperation(OperationType.plus, [Variable("i"), Constant(1)])) + + loop_node_body = loop_node.body + assert isinstance(loop_node_body, CodeNode) + assert loop_node_body.instructions == [ + Assignment(ListOperation([]), Call(ImportedFunctionSymbol("printf", 0), [Constant("%d\n"), Variable("a")])) + ] + + def test_init_with_call(self, ast_call_init): + self.run_rbr(ast_call_init, _generate_options(rename_for=True)) + + code_node = ast_call_init.root.children[0] + assert isinstance(code_node, CodeNode) + assert code_node.instructions == 
[Assignment(Variable("a"), Constant(5))] + + for_loop_node = ast_call_init.root.children[1] + assert isinstance(for_loop_node, ForLoopNode) + assert for_loop_node.declaration == Assignment(Variable("i"), Call(ImportedFunctionSymbol("foo", 0), [])) + assert for_loop_node.modification == Assignment(Variable("i"), BinaryOperation(OperationType.plus, [Variable("i"), Constant(1)])) + + loop_node_body = for_loop_node.body + assert isinstance(loop_node_body, CodeNode) + assert loop_node_body.instructions == [ + Assignment(Variable("a"), BinaryOperation(OperationType.plus, [Variable("a"), Variable("i")])) + ] + + assert for_loop_node.condition == logic_cond("x1", context := LogicCondition.generate_new_context()) + assert ast_call_init.condition_map == { + logic_cond("x1", context): Condition(OperationType.less_or_equal, [Variable("i"), Constant(5)]) + } + + def test_double_init(self, ast_redundant_init): + self.run_rbr(ast_redundant_init) + + code_node = ast_redundant_init.root.children[0] + assert isinstance(code_node, CodeNode) + assert code_node.instructions == [ + Assignment(Variable("b"), Constant(0)), + Assignment(Variable("a"), Constant(5)), + Assignment(Variable("b"), Constant(2)), + ] + + for_loop_node = ast_redundant_init.root.children[1] + assert isinstance(for_loop_node, ForLoopNode) + assert for_loop_node.declaration == Variable("b") + assert for_loop_node.modification == Assignment(Variable("b"), BinaryOperation(OperationType.plus, [Variable("b"), Constant(1)])) + + loop_node_body = for_loop_node.body + assert isinstance(loop_node_body, CodeNode) + assert loop_node_body.instructions == [ + Assignment(Variable("a"), BinaryOperation(OperationType.plus, [Variable("a"), Variable("b")])), + ] + + assert for_loop_node.condition == logic_cond("x1", context := LogicCondition.generate_new_context()) + assert ast_redundant_init.condition_map == {logic_cond("x1", context): Condition(OperationType.less, [Variable("b"), Constant(5)])} + + def 
test_double_init_condition_node(self, ast_reinit_in_condition_true): + self.run_rbr(ast_reinit_in_condition_true) + + def test_init_in_switch(self, ast_init_in_switch): + self.run_rbr(ast_init_in_switch) + + init_code_node = ast_init_in_switch.root.children[0] + assert isinstance(init_code_node, CodeNode) + assert init_code_node.instructions == [Assignment(Variable("a"), Constant(5)), Assignment(Variable("b"), Constant(0))] + + loop_node = ast_init_in_switch.root.children[2] + assert isinstance(loop_node, ForLoopNode) + assert loop_node.declaration == Variable("b") + assert loop_node.modification == Assignment(Variable("b"), BinaryOperation(OperationType.plus, [Variable("b"), Constant(1)])) + + loop_node_body = loop_node.body + assert isinstance(loop_node_body, CodeNode) + assert loop_node_body.instructions == [ + Assignment(Variable("a"), BinaryOperation(OperationType.plus, [Variable("a"), Variable("b")])) + ] + + def test_usage_in_condition(self, ast_usage_in_condition): + self.run_rbr(ast_usage_in_condition) + + code_node = ast_usage_in_condition.root.children[0] + assert isinstance(code_node, CodeNode) + assert code_node.instructions == [Assignment(Variable("a"), Constant(1)), Assignment(Variable("b"), Constant(0))] + + condition_node = ast_usage_in_condition.root.children[1] + assert isinstance(condition_node, ConditionNode) + assert condition_node.condition == logic_cond("x2", context := LogicCondition.generate_new_context()) + + loop_node = ast_usage_in_condition.root.children[2] + assert isinstance(loop_node, ForLoopNode) + assert loop_node.declaration == Variable("b") + assert loop_node.condition == logic_cond("x1", context) + assert loop_node.modification == Assignment(Variable("b"), BinaryOperation(OperationType.plus, [Variable("b"), Constant(1)])) + + loop_body = loop_node.body + assert isinstance(loop_body, CodeNode) + assert loop_body.instructions == [Assignment(Variable("a"), BinaryOperation(OperationType.multiply, [Variable("a"), Constant(2)]))] + + 
def test_while_in_else(self, ast_while_in_else): + self.run_rbr(ast_while_in_else) + + endless_loop = ast_while_in_else.root + assert isinstance(endless_loop, WhileLoopNode) + + condition_node = endless_loop.body + assert isinstance(condition_node, ConditionNode) + + loop_node = condition_node.false_branch_child + assert isinstance(loop_node, ForLoopNode) + assert loop_node.declaration == Assignment(Variable("i"), Constant(0)) + assert loop_node.modification == Assignment(Variable("i"), BinaryOperation(OperationType.plus, [Variable("i"), Constant(1)])) + + loop_node_body = loop_node.body + assert isinstance(loop_node_body, CodeNode) + assert loop_node_body.instructions == [ + Assignment(ListOperation([]), Call(ImportedFunctionSymbol("printf", 0), [Constant("%d\n"), Variable("i")])) + ] + + def test_nested_while(self, ast_nested_while): + self.run_rbr(ast_nested_while, _generate_options(empty_loops=True)) + + outer_loop = ast_nested_while.root + assert isinstance(outer_loop, ForLoopNode) + assert outer_loop.declaration == Assignment(Variable("i"), Constant(0)) + assert ast_nested_while.condition_map[outer_loop.condition] == Condition(OperationType.less, [Variable("i"), Constant(5)]) + assert outer_loop.modification == Assignment(Variable("i"), BinaryOperation(OperationType.plus, [Variable("i"), Constant(1)])) + + inner_loop = outer_loop.children[0] + assert isinstance(inner_loop, ForLoopNode) + assert inner_loop.declaration == Assignment(Variable("j"), Constant(0)) + assert ast_nested_while.condition_map[inner_loop.condition] == Condition(OperationType.less, [Variable("j"), Constant(5)]) + assert inner_loop.modification == Assignment(Variable("j"), BinaryOperation(OperationType.plus, [Variable("j"), Constant(1)])) + + def test_nested_while_loop(self, ast_endless_while_init_outside): + self.run_rbr(ast_endless_while_init_outside) + + endless_loop = ast_endless_while_init_outside.root.children[1] + assert isinstance(endless_loop, WhileLoopNode) + + for_loop = 
endless_loop.body + assert isinstance(for_loop, ForLoopNode) + assert for_loop.declaration == Variable("a") + + def test_sequenced_loops(self, ast_sequenced_while_loops): + self.run_rbr(ast_sequenced_while_loops) + + loop_1 = ast_sequenced_while_loops.root.children[0] + assert isinstance(loop_1, ForLoopNode) + assert loop_1.declaration == Assignment(Variable("i"), Constant(0)) + assert loop_1.modification == Assignment(Variable("i"), BinaryOperation(OperationType.plus, [Variable("i"), Constant(1)])) + + loop_1_body = loop_1.body + assert isinstance(loop_1_body, CodeNode) + assert loop_1_body.instructions == [ + Assignment(ListOperation([]), Call(ImportedFunctionSymbol("printf", 0), [Constant("%d\n"), Variable("i")])), + ] + + loop_2 = ast_sequenced_while_loops.root.children[1] + assert isinstance(loop_2, ForLoopNode) + assert loop_2.declaration == Assignment(Variable("j"), Constant(0)) + assert loop_2.modification == Assignment(Variable("j"), BinaryOperation(OperationType.plus, [Variable("j"), Constant(1)])) + + loop_2_body = loop_2.body + assert isinstance(loop_2_body, CodeNode) + assert loop_2_body.instructions == [ + Assignment(ListOperation([]), Call(ImportedFunctionSymbol("printf", 0), [Constant("%d\n"), Variable("j")])), + ] + + def test_switch_as_loop_body(self, ast_switch_as_loop_body): + self.run_rbr(ast_switch_as_loop_body) + + assert all(not isinstance(node, ForLoopNode) for node in ast_switch_as_loop_body.topological_order()) + + init_code_node = ast_switch_as_loop_body.root.children[0] + assert isinstance(init_code_node, CodeNode) + assert init_code_node.instructions == [Assignment(Variable("a"), Constant(5)), Assignment(Variable("counter"), Constant(0))] + + while_node = ast_switch_as_loop_body.root.children[1] + assert isinstance(while_node, WhileLoopNode) + + switch_node = while_node.body + assert isinstance(switch_node, SwitchNode) + + case_1_body = switch_node.children[0].child + assert isinstance(case_1_body, CodeNode) + assert 
case_1_body.instructions == [ + Assignment(Variable("a"), BinaryOperation(OperationType.plus, [Variable("a"), Variable("counter")])) + ] + + case_2_body = switch_node.children[1].child + assert isinstance(case_2_body, CodeNode) + assert case_2_body.instructions == [ + Assignment(Variable("counter"), BinaryOperation(OperationType.plus, [Variable("counter"), Constant(1)])) + ] + + def test_switch_as_loop_with_increment(self, ast_switch_as_loop_body_increment): + self.run_rbr(ast_switch_as_loop_body_increment) + + init_code_node = ast_switch_as_loop_body_increment.root.children[0] + assert isinstance(init_code_node, CodeNode) + assert init_code_node.instructions == [Assignment(Variable("a"), Constant(5))] + + loop_node = ast_switch_as_loop_body_increment.root.children[1] + assert isinstance(loop_node, ForLoopNode) + assert loop_node.declaration == Assignment(Variable("i"), Constant(0)) + assert loop_node.modification == Assignment(Variable("i"), BinaryOperation(OperationType.plus, [Variable("i"), Constant(1)])) + + switch_node = loop_node.body + assert isinstance(switch_node, SwitchNode) + + case_1 = switch_node.children[0] + assert isinstance(case_1, CaseNode) + + case_1_body = case_1.child + assert isinstance(case_1_body, CodeNode) + assert case_1_body.instructions == [Assignment(Variable("a"), BinaryOperation(OperationType.plus, [Variable("a"), Variable("i")]))] + + case_2 = switch_node.children[1] + assert isinstance(case_2, CaseNode) + + case_2_body = case_2.child + assert isinstance(case_2_body, CodeNode) + assert case_2_body.instructions == [Assignment(Variable("i"), BinaryOperation(OperationType.plus, [Variable("i"), Constant(1)]))] + + assert ast_switch_as_loop_body_increment.condition_map == { + logic_cond("x1", LogicCondition.generate_new_context()): Condition(OperationType.less, [Variable("i"), Constant(5)]) + } + + def test_rename_unary_operation_updates_array_info(self, ast_array_access_for_loop): + """Test if UnaryOperation.ArrayInfo gets updated on 
renaming""" + self.run_rbr(ast_array_access_for_loop, _generate_options(rename_for=True)) + + def find_unary_op(ast): + """look for UnaryOperation in AST""" + for node in ast.get_code_nodes_topological_order(): + for instr in node.instructions: + for unary_op in instr: + if isinstance(unary_op, UnaryOperation): + return unary_op + return None + + unary_operation = find_unary_op(ast_array_access_for_loop) + if not isinstance(unary_operation, UnaryOperation): + assert False, "Did not find UnaryOperation in AST (expect it for array access)" + assert unary_operation.array_info is not None + assert unary_operation.array_info.base in unary_operation.requirements + assert unary_operation.array_info.index in unary_operation.requirements + + def test_no_for_loop_renaming(self, ast_replaceable_while): + self.run_rbr(ast_replaceable_while, _generate_options(rename_for=False)) + + assert ast_replaceable_while.condition_map == { + logic_cond("x1", LogicCondition.generate_new_context()): Condition(OperationType.less, [Variable("a"), Constant(10)]) + } + + loop_node = ast_replaceable_while.root + assert isinstance(loop_node, ForLoopNode) + assert loop_node.declaration == Assignment(Variable("a"), Constant(0)) + assert loop_node.modification == Assignment(Variable("a"), BinaryOperation(OperationType.plus, [Variable("a"), Constant(1)])) + + loop_body = loop_node.body + assert isinstance(loop_body, CodeNode) + assert loop_body.instructions == [ + Assignment(ListOperation([]), Call(ImportedFunctionSymbol("printf", 0), [Constant("counter: %d\n"), Variable("a")])), + ] + + def test_init_may_not_reach_loop_1(self, ast_initialization_in_condition): + assert ( + _initialization_reaches_loop_node( + ast_initialization_in_condition.root.children[0].true_branch_child, ast_initialization_in_condition.root.children[1] + ) + is False + ) + + self.run_rbr(ast_initialization_in_condition, _generate_options()) + assert any( + isinstance(for_loop := loop, ForLoopNode) for loop in 
ast_initialization_in_condition.get_for_loop_nodes_topological_order() + ) + assert for_loop.declaration == Variable("a") + + def test_init_may_not_reach_loop_2(self, ast_initialization_in_condition_sequence): + assert ( + _initialization_reaches_loop_node( + ast_initialization_in_condition_sequence.root.children[0].true_branch_child.children[1], + ast_initialization_in_condition_sequence.root.children[1], + ) + is False + ) + + self.run_rbr(ast_initialization_in_condition_sequence, _generate_options()) + assert any( + isinstance(for_loop := loop, ForLoopNode) + for loop in ast_initialization_in_condition_sequence.get_for_loop_nodes_topological_order() + ) + assert for_loop.declaration == Variable("a") + + @pytest.mark.parametrize("keep_empty_for_loops", [True, False]) + def test_keep_empty_for_loop(self, keep_empty_for_loops: bool, ast_single_instruction_while): + self.run_rbr(ast_single_instruction_while, _generate_options(keep_empty_for_loops)) + + if keep_empty_for_loops: + assert isinstance(ast_single_instruction_while.root, ForLoopNode) + else: + assert isinstance(ast_single_instruction_while.root.children[1], WhileLoopNode) diff --git a/tests/pipeline/controlflowanalysis/test_readability_based_refinement.py b/tests/pipeline/controlflowanalysis/test_readability_based_refinement.py index eecd157a3..15a602a1b 100644 --- a/tests/pipeline/controlflowanalysis/test_readability_based_refinement.py +++ b/tests/pipeline/controlflowanalysis/test_readability_based_refinement.py @@ -1,16 +1,13 @@ from typing import List import pytest -from decompiler.pipeline.controlflowanalysis.readability_based_refinement import ( - ForLoopVariableRenamer, - ReadabilityBasedRefinement, - WhileLoopReplacer, - WhileLoopVariableRenamer, +from decompiler.pipeline.controlflowanalysis.loop_utility_methods import ( _find_continuation_instruction, _has_deep_requirement, _initialization_reaches_loop_node, ) -from decompiler.structures.ast.ast_nodes import CaseNode, CodeNode, ConditionNode, 
ForLoopNode, SeqNode, SwitchNode, WhileLoopNode +from decompiler.pipeline.controlflowanalysis.readability_based_refinement import ReadabilityBasedRefinement, WhileLoopReplacer +from decompiler.structures.ast.ast_nodes import ConditionNode, ForLoopNode, SeqNode, WhileLoopNode from decompiler.structures.ast.syntaxtree import AbstractSyntaxTree from decompiler.structures.logic.logic_condition import LogicCondition from decompiler.structures.pseudo import ( @@ -26,7 +23,7 @@ OperationType, Variable, ) -from decompiler.structures.pseudo.operations import ArrayInfo, OperationType, UnaryOperation +from decompiler.structures.pseudo.operations import OperationType from decompiler.task import DecompilerTask from decompiler.util.options import Options @@ -35,15 +32,11 @@ def logic_cond(name: str, context) -> LogicCondition: return LogicCondition.initialize_symbol(name, context) -def _generate_options(empty_loops: bool = False, hide_decl: bool = False, rename_for: bool = True, rename_while: bool = True, \ - max_condition: int = 100, max_modification: int = 100, force_for_loops: bool = False, blacklist : List[str] = []) -> Options: +def _generate_options(empty_loops: bool = False, hide_decl: bool = False, max_condition: int = 100, max_modification: int = 100, \ + force_for_loops: bool = False, blacklist : List[str] = []) -> Options: options = Options() options.set("readability-based-refinement.keep_empty_for_loops", empty_loops) options.set("readability-based-refinement.hide_non_initializing_declaration", hide_decl) - if rename_for: - names = ["i", "j", "k", "l", "m", "n"] - options.set("readability-based-refinement.for_loop_variable_names", names) - options.set("readability-based-refinement.rename_while_loop_variables", rename_while) options.set("readability-based-refinement.max_condition_complexity_for_loop_recovery", max_condition) options.set("readability-based-refinement.max_modification_complexity_for_loop_recovery", max_modification) 
options.set("readability-based-refinement.force_for_loops", force_for_loops) @@ -51,905 +44,6 @@ def _generate_options(empty_loops: bool = False, hide_decl: bool = False, rename return options -@pytest.fixture -def ast_array_access_for_loop() -> AbstractSyntaxTree: - """ - for (var_0 = 0; var_0 < 10; var_0 = var_0 + 1) { - *(var_1 + var_0) = var_0; - } - """ - true_value = LogicCondition.initialize_true(context := LogicCondition.generate_new_context()) - ast = AbstractSyntaxTree( - root := SeqNode(true_value), - condition_map={logic_cond("x1", context): Condition(OperationType.less, [Variable("var_0"), Constant(10)])}, - ) - declaration = Assignment(Variable("var_0"), Constant(0)) - condition = logic_cond("x1", context) - modification = Assignment(Variable("var_0"), BinaryOperation(OperationType.plus, [Variable("var_0"), Constant(1)])) - for_loop = ast.factory.create_for_loop_node(declaration, condition, modification) - array_info = ArrayInfo(Variable("var_1"), Variable("var_0")) - array_access_unary_operation = UnaryOperation( - OperationType.dereference, [BinaryOperation(OperationType.plus, [Variable("var_1"), Variable("var_0")])], array_info=array_info - ) - for_loop_body = ast._add_code_node([Assignment(array_access_unary_operation, Variable("var_0"))]) - ast._add_node(for_loop) - ast._add_edges_from([(root, for_loop), (for_loop, for_loop_body)]) - return ast - - -@pytest.fixture -def ast_while_true() -> AbstractSyntaxTree: - """ - a = 0; - b = 0; - while(true){ - a = a + 1; - b = b + 1; - } - """ - true_value = LogicCondition.initialize_true(LogicCondition.generate_new_context()) - ast = AbstractSyntaxTree(root := SeqNode(true_value), {}) - code_node = ast._add_code_node([Assignment(Variable("a"), Constant(0)), Assignment(Variable("b"), Constant(0))]) - loop_node_body = ast._add_code_node( - [ - Assignment(Variable("a"), BinaryOperation(OperationType.plus, [Variable("a"), Constant(1)])), - Assignment(Variable("b"), BinaryOperation(OperationType.plus, 
[Variable("b"), Constant(1)])), - ] - ) - loop_node = ast.add_endless_loop_with_body(loop_node_body) - ast._add_edges_from(((root, code_node), (root, loop_node))) - return ast - - -@pytest.fixture -def ast_single_instruction_while() -> AbstractSyntaxTree: - """ - a = 0; - while (a < 10) { - a = a + 1; - } - """ - true_value = LogicCondition.initialize_true(context := LogicCondition.generate_new_context()) - ast = AbstractSyntaxTree( - root := SeqNode(true_value), condition_map={logic_cond("x1", context): Condition(OperationType.less, [Variable("a"), Constant(10)])} - ) - init_code_node = ast._add_code_node([Assignment(Variable("a"), Constant(0))]) - while_loop = ast.factory.create_while_loop_node(logic_cond("x1", context)) - while_loop_body = ast._add_code_node([Assignment(Variable("a"), BinaryOperation(OperationType.plus, [Variable("a"), Constant(1)]))]) - ast._add_node(while_loop) - ast._add_edges_from([(root, init_code_node), (root, while_loop), (while_loop, while_loop_body)]) - return ast - - -@pytest.fixture -def ast_replaceable_while() -> AbstractSyntaxTree: - """ - a = 0; - while (x < 10) { - printf("counter: %d", x); - a = a + 1; - } - """ - true_value = LogicCondition.initialize_true(context := LogicCondition.generate_new_context()) - ast = AbstractSyntaxTree( - root := SeqNode(true_value), condition_map={logic_cond("x1", context): Condition(OperationType.less, [Variable("a"), Constant(10)])} - ) - - init_code_node = ast._add_code_node([Assignment(Variable("a"), Constant(0))]) - - while_loop = ast.factory.create_while_loop_node(logic_cond("x1", context)) - while_loop_body = ast._add_code_node( - [ - Assignment(ListOperation([]), Call(ImportedFunctionSymbol("printf", 0), [Constant("counter: %d\n"), Variable("a")])), - Assignment(Variable("a"), BinaryOperation(OperationType.plus, [Variable("a"), Constant(1)])), - ] - ) - - ast._add_node(while_loop) - ast._add_edges_from([(root, init_code_node), (root, while_loop), (while_loop, while_loop_body)]) - 
root._sorted_children = (init_code_node, while_loop) - return ast - - -@pytest.fixture -def ast_replaceable_while_usage() -> AbstractSyntaxTree: - """ - a = 0; - while (a < 10) { - printf("counter: %d", a); - a = a + 1; - } - printf("final counter: %d", a); - """ - true_value = LogicCondition.initialize_true(context := LogicCondition.generate_new_context()) - ast = AbstractSyntaxTree( - root := SeqNode(true_value), condition_map={logic_cond("x1", context): Condition(OperationType.less, [Variable("a"), Constant(10)])} - ) - - init_code_node = ast._add_code_node([Assignment(Variable("a"), Constant(0))]) - - while_loop = ast.factory.create_while_loop_node(logic_cond("x1", context)) - while_loop_body = ast._add_code_node( - [ - Assignment(ListOperation([]), Call(ImportedFunctionSymbol("printf", 0), [Constant("counter: %d\n"), Variable("a")])), - Assignment(Variable("a"), BinaryOperation(OperationType.plus, [Variable("a"), Constant(1)])), - ] - ) - - exit_code_node = ast._add_code_node( - [Assignment(ListOperation([]), Call(ImportedFunctionSymbol("printf", 0), [Constant("final counter: %d"), Variable("a")]))] - ) - - ast._add_node(while_loop) - ast._add_edges_from([(root, init_code_node), (root, while_loop), (root, exit_code_node), (while_loop, while_loop_body)]) - return ast - - -@pytest.fixture -def ast_replaceable_while_reinit_usage() -> AbstractSyntaxTree: - """ - a = 0; - while (a < 10) { - printf("counter: %d", a); - a = a + 1; - } - a = 50; - printf("50 = %d", a); - """ - true_value = LogicCondition.initialize_true(context := LogicCondition.generate_new_context()) - ast = AbstractSyntaxTree( - root := SeqNode(true_value), condition_map={logic_cond("x1", context): Condition(OperationType.less, [Variable("a"), Constant(10)])} - ) - - init_code_node = ast._add_code_node([Assignment(Variable("a"), Constant(0))]) - - while_loop = ast.factory.create_while_loop_node(logic_cond("x1", context)) - while_loop_body = ast._add_code_node( - [ - Assignment(ListOperation([]), 
Call(ImportedFunctionSymbol("printf", 0), [Constant("counter: %d\n"), Variable("a")])), - Assignment(Variable("a"), BinaryOperation(OperationType.plus, [Variable("a"), Constant(1)])), - ] - ) - - exit_code_node = ast._add_code_node( - [ - Assignment(Variable("a"), Constant(50)), - Assignment(ListOperation([]), Call(ImportedFunctionSymbol("printf", 0), [Constant("final counter: %d"), Variable("a")])), - ] - ) - - ast._add_node(while_loop) - ast._add_edges_from([(root, init_code_node), (root, while_loop), (root, exit_code_node), (while_loop, while_loop_body)]) - return ast - - -@pytest.fixture -def ast_replaceable_while_compound_usage() -> AbstractSyntaxTree: - """ - a = 0; - while (a < 10) { - printf("counter: %d", a); - a = a + 1; - } - a = a + 50; - printf("50 = %d", a); - """ - true_value = LogicCondition.initialize_true(context := LogicCondition.generate_new_context()) - ast = AbstractSyntaxTree( - root := SeqNode(true_value), condition_map={logic_cond("x1", context): Condition(OperationType.less, [Variable("a"), Constant(10)])} - ) - - init_code_node = ast._add_code_node([Assignment(Variable("a"), Constant(0))]) - - while_loop = ast.factory.create_while_loop_node(logic_cond("x1", context)) - while_loop_body = ast._add_code_node( - [ - Assignment(ListOperation([]), Call(ImportedFunctionSymbol("printf", 0), [Constant("counter: %d\n"), Variable("a")])), - Assignment(Variable("a"), BinaryOperation(OperationType.plus, [Variable("a"), Constant(1)])), - ] - ) - - exit_code_node = ast._add_code_node( - [ - Assignment(Variable("a"), BinaryOperation(OperationType.plus, [Variable("a"), Constant(50)])), - Assignment(ListOperation([]), Call(ImportedFunctionSymbol("printf", 0), [Constant("final counter: %d"), Variable("a")])), - ] - ) - - ast._add_node(while_loop) - ast._add_edges_from([(root, init_code_node), (root, while_loop), (root, exit_code_node), (while_loop, while_loop_body)]) - return ast - - -@pytest.fixture -def ast_endless_while_init_outside() -> 
AbstractSyntaxTree: - """ - a = 0; - while (true) { - while (a < 5) { - printf("%d\n", a); - a = a + 1; - } - } - """ - true_value = LogicCondition.initialize_true(context := LogicCondition.generate_new_context()) - ast = AbstractSyntaxTree( - root := SeqNode(true_value), condition_map={logic_cond("x1", context): Condition(OperationType.less, [Variable("a"), Constant(2)])} - ) - - init_code_node = ast._add_code_node([Assignment(Variable("a"), Constant(0))]) - - inner_while = ast.factory.create_while_loop_node(logic_cond("x1", context)) - ast._add_node(inner_while) - endless_loop = ast.add_endless_loop_with_body(inner_while) - - inner_while_body = ast._add_code_node( - [ - Assignment(ListOperation([]), Call(ImportedFunctionSymbol("printf", 0), [Constant("%d\n"), Variable("a")])), - Assignment(Variable("a"), BinaryOperation(OperationType.plus, [Variable("a"), Constant(1)])), - ] - ) - - ast._add_edges_from([(root, init_code_node), (root, endless_loop), (endless_loop, inner_while), (inner_while, inner_while_body)]) - return ast - - -@pytest.fixture -def ast_nested_while() -> AbstractSyntaxTree: - """ - a = 0; - while (a < 1) { - b = 0; - while (b < 1) { - b = b + 1; - } - a = a + 1; - } - """ - true_value = LogicCondition.initialize_true(context := LogicCondition.generate_new_context()) - ast = AbstractSyntaxTree( - root := SeqNode(true_value), - condition_map={ - logic_cond("x1", context): Condition(OperationType.less, [Variable("a"), Constant(5)]), - logic_cond("x2", context): Condition(OperationType.less, [Variable("b"), Constant(5)]), - }, - ) - - init_code_node = ast._add_code_node([Assignment(Variable("a"), Constant(0))]) - - outer_while = ast.factory.create_while_loop_node(logic_cond("x1", context)) - outer_while_body = ast.factory.create_seq_node() - outer_while_init = ast._add_code_node([Assignment(Variable("b"), Constant(0))]) - outer_while_exit = ast._add_code_node([Assignment(Variable("a"), BinaryOperation(OperationType.plus, [Variable("a"), 
Constant(1)]))]) - - inner_while = ast.factory.create_while_loop_node(logic_cond("x2", context)) - inner_while_body = ast._add_code_node([Assignment(Variable("b"), BinaryOperation(OperationType.plus, [Variable("b"), Constant(1)]))]) - - ast._add_nodes_from((outer_while, outer_while_body, inner_while)) - ast._add_edges_from( - [ - (root, init_code_node), - (root, outer_while), - (outer_while, outer_while_body), - (outer_while_body, outer_while_init), - (outer_while_body, inner_while), - (outer_while_body, outer_while_exit), - (inner_while, inner_while_body), - ] - ) - return ast - - -@pytest.fixture -def ast_call_init() -> AbstractSyntaxTree: - """ - a = 5; - b = foo(); - while(b <= 5){ - a = a + b; - b = b + 1; - } - """ - true_value = LogicCondition.initialize_true(context := LogicCondition.generate_new_context()) - ast = AbstractSyntaxTree( - root := SeqNode(true_value), - condition_map={logic_cond("x1", context): Condition(OperationType.less_or_equal, [Variable("b"), Constant(5)])}, - ) - code_node = ast._add_code_node( - instructions=[ - Assignment(Variable("a"), Constant(5)), - Assignment(ListOperation([Variable("b")]), Call(ImportedFunctionSymbol("foo", 0), [])), - ] - ) - loop_node = ast.factory.create_while_loop_node(condition=logic_cond("x1", context)) - loop_node_body = ast._add_code_node( - [ - Assignment(Variable("a"), BinaryOperation(OperationType.plus, [Variable("a"), Variable("b")])), - Assignment(Variable("b"), BinaryOperation(OperationType.plus, [Variable("b"), Constant(1)])), - ] - ) - ast._add_node(loop_node) - ast._add_edges_from(((root, code_node), (root, loop_node), (loop_node, loop_node_body))) - ast._code_node_reachability_graph.add_reachability(code_node, loop_node_body) - root._sorted_children = (code_node, loop_node) - return ast - - -@pytest.fixture -def ast_self_referential_init() -> AbstractSyntaxTree: - """ - a = 5; - b = foo(b); - while(b <= 5){ - a = a + b; - b = b + 1; - } - """ - true_value = LogicCondition.initialize_true(context 
:= LogicCondition.generate_new_context()) - ast = AbstractSyntaxTree( - root := SeqNode(true_value), - condition_map={logic_cond("x1", context): Condition(OperationType.less_or_equal, [Variable("b"), Constant(5)])}, - ) - code_node = ast._add_code_node( - instructions=[ - Assignment(Variable("a"), Constant(5)), - Assignment(ListOperation([Variable("b")]), Call(ImportedFunctionSymbol("foo", 0), [Variable("b")])), - ] - ) - loop_node = ast.factory.create_while_loop_node(condition=logic_cond("x1", context)) - loop_node_body = ast._add_code_node( - [ - Assignment(Variable("a"), BinaryOperation(OperationType.plus, [Variable("a"), Variable("b")])), - Assignment(Variable("b"), BinaryOperation(OperationType.plus, [Variable("b"), Constant(1)])), - ] - ) - ast._add_node(loop_node) - ast._add_edges_from(((root, code_node), (root, loop_node), (loop_node, loop_node_body))) - ast._code_node_reachability_graph.add_reachability(code_node, loop_node_body) - root._sorted_children = (code_node, loop_node) - return ast - - -@pytest.fixture -def ast_call_for_loop() -> AbstractSyntaxTree: - """ - a = 5; - while(b = foo; b <= 5; b++){ - a++; - } - """ - true_value = LogicCondition.initialize_true(context := LogicCondition.generate_new_context()) - ast = AbstractSyntaxTree( - root := SeqNode(true_value), - condition_map={logic_cond("x1", context): Condition(OperationType.less_or_equal, [Variable("b"), Constant(5)])}, - ) - code_node = ast._add_code_node( - instructions=[ - Assignment(Variable("a"), Constant(5)), - ] - ) - loop_node = ast.factory.create_for_loop_node(Assignment(ListOperation([Variable("b")]), Call(ImportedFunctionSymbol("foo", 0), [])), logic_cond("x1", context), Assignment(Variable("b"), BinaryOperation(OperationType.plus, [Variable("b"), Constant(1)]))) - loop_node_body = ast._add_code_node( - [ - Assignment(Variable("a"), BinaryOperation(OperationType.plus, [Variable("a"), Variable("1")])), - ] - ) - ast._add_node(loop_node) - ast._add_edges_from(((root, code_node), 
(root, loop_node), (loop_node, loop_node_body))) - ast._code_node_reachability_graph.add_reachability(code_node, loop_node_body) - root._sorted_children = (code_node, loop_node) - return ast - - -@pytest.fixture -def ast_redundant_init() -> AbstractSyntaxTree: - """ - b = 0; - a = 5; - b = 2; - - while(b <= 5){ - a = a + b; - b = b + 1; - } - """ - true_value = LogicCondition.initialize_true(context := LogicCondition.generate_new_context()) - ast = AbstractSyntaxTree( - root := SeqNode(true_value), condition_map={logic_cond("x1", context): Condition(OperationType.less, [Variable("b"), Constant(5)])} - ) - code_node = ast._add_code_node( - instructions=[ - Assignment(Variable("b"), Constant(0)), - Assignment(Variable("a"), Constant(5)), - Assignment(Variable("b"), Constant(2)), - ] - ) - loop_node = ast.factory.create_while_loop_node(condition=logic_cond("x1", context)) - loop_node_body = ast._add_code_node( - [ - Assignment(Variable("a"), BinaryOperation(OperationType.plus, [Variable("a"), Variable("b")])), - Assignment(Variable("b"), BinaryOperation(OperationType.plus, [Variable("b"), Constant(1)])), - ] - ) - ast._add_node(loop_node) - ast._add_edges_from(((root, code_node), (root, loop_node), (loop_node, loop_node_body))) - ast._code_node_reachability_graph.add_reachability(code_node, loop_node_body) - root._sorted_children = (code_node, loop_node) - return ast - - -@pytest.fixture -def ast_reinit_in_condition_true() -> AbstractSyntaxTree: - """ - int x = 1; - int i = 0; - - if (x == 1) { - i = 1; - } - - while (i < 10) { - x = x * 2; - i = i + 1; - } - """ - true_value = LogicCondition.initialize_true(context := LogicCondition.generate_new_context()) - ast = AbstractSyntaxTree( - root := SeqNode(true_value), - condition_map={ - logic_cond("a", context): Condition(OperationType.less, [Variable("i"), Constant(10)]), - logic_cond("b", context): Condition(OperationType.equal, [Variable("x"), Constant(1)]), - }, - ) - code_node = 
ast._add_code_node(instructions=[Assignment(Variable("x"), Constant(1)), Assignment(Variable("i"), Constant(0))]) - code_node_true = ast._add_code_node([Assignment(Variable("i"), Constant(1))]) - condition_node = ast._add_condition_node_with(logic_cond("b", context), code_node_true) - loop_node = ast.factory.create_while_loop_node(condition=logic_cond("a", context)) - loop_node_body = ast._add_code_node( - [ - Assignment(Variable("x"), BinaryOperation(OperationType.multiply, [Variable("x"), Constant(2)])), - Assignment(Variable("i"), BinaryOperation(OperationType.plus, [Variable("i"), Constant(1)])), - ] - ) - ast._add_nodes_from((condition_node, loop_node)) - ast._add_edges_from(((root, code_node), (root, condition_node), (root, loop_node), (loop_node, loop_node_body))) - ast._code_node_reachability_graph.add_reachability(code_node, loop_node_body) - root._sorted_children = (code_node, loop_node) - return ast - - -@pytest.fixture -def ast_usage_in_condition() -> AbstractSyntaxTree: - """ - int a = 1; - int b = 0; - - if (b == 1) { - a = 1; - } - - while (b < 10) { - a = a * 2; - b = b + 1; - } - """ - true_value = LogicCondition.initialize_true(context := LogicCondition.generate_new_context()) - ast = AbstractSyntaxTree( - root := SeqNode(true_value), - condition_map={ - logic_cond("x1", context): Condition(OperationType.less, [Variable("b"), Constant(10)]), - logic_cond("x2", context): Condition(OperationType.equal, [Variable("b"), Constant(1)]), - }, - ) - init_code_node = ast._add_code_node([Assignment(Variable("a"), Constant(1)), Assignment(Variable("b"), Constant(0))]) - code_node_true = ast._add_code_node([Assignment(Variable("a"), Constant(1))]) - condition_node = ast._add_condition_node_with(logic_cond("x2", context), code_node_true) - loop_node = ast.factory.create_while_loop_node(condition=logic_cond("x1", context)) - loop_node_body = ast._add_code_node( - [ - Assignment(Variable("a"), BinaryOperation(OperationType.multiply, [Variable("a"), 
Constant(2)])), - Assignment(Variable("b"), BinaryOperation(OperationType.plus, [Variable("b"), Constant(1)])), - ] - ) - ast._add_node(loop_node) - ast._add_edges_from(((root, init_code_node), (root, condition_node), (root, loop_node), (loop_node, loop_node_body))) - ast._code_node_reachability_graph.add_reachability(init_code_node, loop_node_body) - root._sorted_children = (init_code_node, loop_node) - return ast - - -@pytest.fixture -def ast_sequenced_while_loops() -> AbstractSyntaxTree: - """ - a = 0; - b = 0; - - while (a < 5) { - printf("%d\n", a); - a++; - } - - while (b < 5) { - printf("%d\n", b); - b++; - } - """ - true_value = LogicCondition.initialize_true(context := LogicCondition.generate_new_context()) - ast = AbstractSyntaxTree( - root := SeqNode(true_value), - condition_map={ - logic_cond("x1", context): Condition(OperationType.less, [Variable("a"), Constant(5)]), - logic_cond("x2", context): Condition(OperationType.less, [Variable("b"), Constant(5)]), - }, - ) - - init_code_node = ast._add_code_node([Assignment(Variable("a"), Constant(0)), Assignment(Variable("b"), Constant(0))]) - - while_loop_1 = ast.factory.create_while_loop_node(logic_cond("x1", context)) - while_loop_1_body = ast._add_code_node( - [ - Assignment(ListOperation([]), Call(ImportedFunctionSymbol("printf", 0), [Constant("%d\n"), Variable("a")])), - Assignment(Variable("a"), BinaryOperation(OperationType.plus, [Variable("a"), Constant(1)])), - ] - ) - - while_loop_2 = ast.factory.create_while_loop_node(logic_cond("x2", context)) - while_loop_2_body = ast._add_code_node( - [ - Assignment(ListOperation([]), Call(ImportedFunctionSymbol("printf", 0), [Constant("%d\n"), Variable("b")])), - Assignment(Variable("b"), BinaryOperation(OperationType.plus, [Variable("b"), Constant(1)])), - ] - ) - - ast._add_nodes_from((while_loop_1, while_loop_2)) - ast._add_edges_from( - ( - (root, init_code_node), - (root, while_loop_1), - (root, while_loop_2), - (while_loop_1, while_loop_1_body), - 
(while_loop_2, while_loop_2_body), - ) - ) - return ast - - -@pytest.fixture -def ast_switch_as_loop_body() -> AbstractSyntaxTree: - """ - This while-loop should not be replaced with a for-loop because we don't know wich value 'a' has. - - Code of AST: - a = 5; - b = 0; - while(b <= 5){ - switch(a) { - case 0: - a = a + b: - break; - case 1: - b = b + 1; - break; - } - } - """ - true_value = LogicCondition.initialize_true(context := LogicCondition.generate_new_context()) - ast = AbstractSyntaxTree( - root := SeqNode(true_value), - condition_map={logic_cond("a", context): Condition(OperationType.less_or_equal, [Variable("b"), Constant(5)])}, - ) - code_node = ast._add_code_node([Assignment(Variable("a"), Constant(5)), Assignment(Variable("b"), Constant(0))]) - loop_node = ast.factory.create_while_loop_node(condition=logic_cond("a", context)) - root._sorted_children = (code_node, loop_node) - loop_body_switch = ast.factory.create_switch_node(Variable("a")) - loop_body_case_1 = ast.factory.create_case_node(Variable("a"), Constant(0), break_case=True) - code_node_case_1 = ast._add_code_node([Assignment(Variable("a"), BinaryOperation(OperationType.plus, [Variable("a"), Variable("b")]))]) - loop_body_case_2 = ast.factory.create_case_node(Variable("a"), Constant(1), break_case=True) - code_node_case_2 = ast._add_code_node( - [ - Assignment(Variable("b"), BinaryOperation(OperationType.plus, [Variable("b"), Constant(1)])), - ] - ) - ast._add_nodes_from((code_node, loop_node, loop_body_switch, loop_body_case_1, loop_body_case_2)) - ast._add_edges_from( - ( - (root, code_node), - (root, loop_node), - (loop_node, loop_body_switch), - (loop_body_switch, loop_body_case_1), - (loop_body_switch, loop_body_case_2), - (loop_body_case_1, code_node_case_1), - (loop_body_case_2, code_node_case_2), - ) - ) - ast._code_node_reachability_graph.add_reachability_from(((code_node, code_node_case_1), (code_node, code_node_case_2))) - return ast - - -@pytest.fixture -def 
ast_switch_as_loop_body_increment() -> AbstractSyntaxTree: - """ - This loop should be replaced with a for-loop because b has no usages after last definition, is in condition and is initialized - before loop without any usages in between. - - Code of AST: - a = 5; - b = 0; - while(b <= 5){ - switch(a) { - case 0: - a = a + b: - break; - case 1: - b = b + 1; - break; - } - b = b + 1; - } - """ - true_value = LogicCondition.initialize_true(context := LogicCondition.generate_new_context()) - ast = AbstractSyntaxTree( - root := SeqNode(true_value), condition_map={logic_cond("x1", context): Condition(OperationType.less, [Variable("b"), Constant(5)])} - ) - - init_code_node = ast._add_code_node([Assignment(Variable("a"), Constant(5)), Assignment(Variable("b"), Constant(0))]) - - while_loop = ast.factory.create_while_loop_node(logic_cond("x1", context)) - while_loop_seq = ast.factory.create_seq_node() - - switch_node = ast.factory.create_switch_node(Variable("a")) - case_1 = ast.factory.create_case_node(Variable("a"), Constant(0), break_case=True) - case_1_code = ast._add_code_node([Assignment(Variable("a"), BinaryOperation(OperationType.plus, [Variable("a"), Variable("b")]))]) - case_2 = ast.factory.create_case_node(Variable("a"), Constant(0), break_case=True) - case_2_code = ast._add_code_node([Assignment(Variable("b"), BinaryOperation(OperationType.plus, [Variable("b"), Constant(1)]))]) - - increment_code = ast._add_code_node([Assignment(Variable("b"), BinaryOperation(OperationType.plus, [Variable("b"), Constant(1)]))]) - - ast._add_nodes_from((while_loop, while_loop_seq, switch_node, case_1, case_2)) - ast._add_edges_from( - [ - (root, init_code_node), - (root, while_loop), - (while_loop, while_loop_seq), - (while_loop_seq, switch_node), - (while_loop_seq, increment_code), - (switch_node, case_1), - (switch_node, case_2), - (case_1, case_1_code), - (case_2, case_2_code), - ] - ) - return ast - - -@pytest.fixture -def ast_init_in_switch() -> AbstractSyntaxTree: - """ - 
a = 5; - b = 0; - switch(a){ - case 0: - a = b; - } - while(b <= (5 + a)){ - a = a + b; - b = b + 1; - } - """ - true_value = LogicCondition.initialize_true(context := LogicCondition.generate_new_context()) - ast = AbstractSyntaxTree( - root := SeqNode(true_value), - condition_map={ - logic_cond("x1", context): Condition( - OperationType.less_or_equal, - [Variable("b"), BinaryOperation(OperationType.plus, [Constant(5), Variable("a")])], - ) - }, - ) - init_code_node = ast._add_code_node(instructions=[Assignment(Variable("a"), Constant(5)), Assignment(Variable("b"), Constant(0))]) - switch_node = ast.factory.create_switch_node(Variable("a")) - loop_node = ast.factory.create_while_loop_node(condition=logic_cond("x1", context)) - case_node = ast.factory.create_case_node(Variable("a"), Constant(0)) - case_child = ast._add_code_node([Assignment(Variable("a"), Variable("b"))]) - loop_body = ast.factory.create_seq_node() - loop_body_child = ast._add_code_node( - [ - Assignment(Variable("a"), BinaryOperation(OperationType.plus, [Variable("a"), Variable("b")])), - Assignment(Variable("b"), BinaryOperation(OperationType.plus, [Variable("b"), Constant(1)])), - ] - ) - ast._add_nodes_from((switch_node, loop_node, loop_body, case_node)) - ast._add_edges_from( - ( - (root, init_code_node), - (root, switch_node), - (switch_node, case_node), - (case_node, case_child), - (root, loop_node), - (loop_node, loop_body), - (loop_body, loop_body_child), - ) - ) - ast._code_node_reachability_graph.add_reachability_from([(case_child, loop_body_child)]) - root._sorted_children = (init_code_node, switch_node, loop_node) - loop_body._sorted_children = (loop_body_child,) - switch_node._sorted_cases = (case_node,) - return ast - - -@pytest.fixture -def ast_while_in_else() -> AbstractSyntaxTree: - """ - while (true) { - if (b < 2) { - break; - } else { - a = 0; - while (a < 5) { - printf("%d\n", a); - a = a + 1; - } - } - } - """ - true_value = LogicCondition.initialize_true(context := 
LogicCondition.generate_new_context()) - ast = AbstractSyntaxTree( - root := SeqNode(true_value), - condition_map={ - logic_cond("x1", context): Condition(OperationType.less, [Variable("a"), Constant(2)]), - logic_cond("x2", context): Condition(OperationType.less, [Variable("b"), Constant(2)]), - }, - ) - - inner_while = ast.factory.create_while_loop_node(logic_cond("x1", context)) - ast._add_node(inner_while) - - true_branch_child = ast._add_code_node([Break()]) - inner_seq = ast.factory.create_seq_node() - ast._add_node(inner_seq) - condition_node = ast._add_condition_node_with(logic_cond("x2", context), true_branch_child, inner_seq) - - init_code_node = ast._add_code_node([Assignment(Variable("a"), Constant(0))]) - - endless_loop = ast.add_endless_loop_with_body(condition_node) - - inner_while_body = ast._add_code_node( - [ - Assignment(ListOperation([]), Call(ImportedFunctionSymbol("printf", 0), [Constant("%d\n"), Variable("a")])), - Assignment(Variable("a"), BinaryOperation(OperationType.plus, [Variable("a"), Constant(1)])), - ] - ) - - ast._add_edges_from( - [ - (root, endless_loop), - (endless_loop, condition_node), - (inner_seq, init_code_node), - (inner_seq, inner_while), - (inner_while, inner_while_body), - ] - ) - return ast - - -@pytest.fixture -def ast_continuation_is_not_first_var() -> AbstractSyntaxTree: - """ - a = 0; - b = 0; - while (a < b) { - printf("%d\n", a); - b = b + 1; - } - """ - true_value = LogicCondition.initialize_true(context := LogicCondition.generate_new_context()) - ast = AbstractSyntaxTree( - root := SeqNode(true_value), - condition_map={logic_cond("x1", context): Condition(OperationType.less, [Variable("a"), Variable("b")])}, - ) - - init_code_node = ast._add_code_node([Assignment(Variable("a"), Constant(0)), Assignment(Variable("b"), Constant(0))]) - - while_loop = ast.factory.create_while_loop_node(logic_cond("x1", context)) - while_loop_body = ast._add_code_node( - [ - Assignment(ListOperation([]), 
Call(ImportedFunctionSymbol("printf", 0), [Constant("%d\n"), Variable("a")])), - Assignment(Variable("b"), BinaryOperation(OperationType.plus, [Variable("b"), Constant(1)])), - ] - ) - - ast._add_node(while_loop) - ast._add_edges_from([(root, init_code_node), (root, while_loop), (while_loop, while_loop_body)]) - root._sorted_children = (init_code_node, while_loop) - return ast - - -@pytest.fixture -def ast_initialization_in_condition() -> AbstractSyntaxTree: - """ - if(b < 10 ){ - a = 5; - while (x < 10) { - printf("counter: %d", a); - a = a + 1; - } - """ - true_value = LogicCondition.initialize_true(context := LogicCondition.generate_new_context()) - ast = AbstractSyntaxTree( - root := SeqNode(true_value), - condition_map={ - logic_cond("x0", context): Condition(OperationType.less, [Variable("b"), Constant(10)]), - logic_cond("x1", context): Condition(OperationType.less, [Variable("a"), Constant(10)]), - }, - ) - - true_branch = ast._add_code_node([Assignment(Variable("a"), Constant(5))]) - if_condition = ast._add_condition_node_with(logic_cond("x0", context), true_branch) - while_loop = ast.factory.create_while_loop_node(logic_cond("x1", context)) - while_loop_body = ast._add_code_node( - [ - Assignment(ListOperation([]), Call(ImportedFunctionSymbol("printf", 0), [Constant("counter: %d\n"), Variable("a")])), - Assignment(Variable("a"), BinaryOperation(OperationType.plus, [Variable("a"), Constant(1)])), - ] - ) - - ast._add_node(while_loop) - ast._add_edges_from([(root, if_condition), (root, while_loop), (while_loop, while_loop_body)]) - root._sorted_children = (if_condition, while_loop) - return ast - - -@pytest.fixture -def ast_initialization_in_condition_sequence() -> AbstractSyntaxTree: - """ - if(b < 10 ){ - if(c < 10){ - b = 5; - } - a = 5; - while (x < 10) { - printf("counter: %d", a); - a = a + 1; - } - """ - true_value = LogicCondition.initialize_true(context := LogicCondition.generate_new_context()) - ast = AbstractSyntaxTree( - root := 
SeqNode(true_value), - condition_map={ - logic_cond("x0", context): Condition(OperationType.less, [Variable("b"), Constant(10)]), - logic_cond("x1", context): Condition(OperationType.less, [Variable("c"), Constant(10)]), - logic_cond("x2", context): Condition(OperationType.less, [Variable("a"), Constant(10)]), - }, - ) - - true_branch_c = ast._add_code_node([Assignment(Variable("b"), Constant(5))]) - code_node = ast._add_code_node([Assignment(Variable("a"), Constant(5))]) - if_condition_c = ast._add_condition_node_with(logic_cond("x1", context), true_branch_c) - ast._add_node(true_branch_b := ast.factory.create_seq_node()) - if_condition_b = ast._add_condition_node_with(logic_cond("x1", context), true_branch_b) - while_loop = ast.factory.create_while_loop_node(logic_cond("x2", context)) - while_loop_body = ast._add_code_node( - [ - Assignment(ListOperation([]), Call(ImportedFunctionSymbol("printf", 0), [Constant("counter: %d\n"), Variable("a")])), - Assignment(Variable("a"), BinaryOperation(OperationType.plus, [Variable("a"), Constant(1)])), - ] - ) - - ast._add_node(while_loop) - ast._add_edges_from( - [ - (root, if_condition_b), - (root, while_loop), - (while_loop, while_loop_body), - (true_branch_b, if_condition_c), - (true_branch_b, code_node), - ] - ) - true_branch_b._sorted_children = (if_condition_c, code_node) - root._sorted_children = (if_condition_b, while_loop) - return ast - - @pytest.fixture def ast_innerWhile_simple_condition_complexity() -> AbstractSyntaxTree: """ @@ -1096,413 +190,61 @@ def ast_guarded_do_while_else() -> AbstractSyntaxTree: ast._add_node(do_while_loop) ast._add_edges_from([(root, init_code_node), (root, cond_node), (cond_node, false_branch), (false_branch, do_while_loop), (do_while_loop, do_while_loop_body)]) return ast - - -class TestReadabilityBasedRefinement: - """Test Readability functionality with all its substages.""" - - @staticmethod - def run_rbr(ast: AbstractSyntaxTree, options: Options = _generate_options()): - 
ReadabilityBasedRefinement().run(DecompilerTask("func", cfg=None, ast=ast, options=options)) - - def test_no_replacement(self, ast_while_true): - self.run_rbr(ast_while_true) - assert all(not isinstance(node, ForLoopNode) for node in ast_while_true.topological_order()) - - def test_simple_replacement(self, ast_replaceable_while): - self.run_rbr(ast_replaceable_while) - - assert ast_replaceable_while.condition_map == { - logic_cond("x1", LogicCondition.generate_new_context()): Condition(OperationType.less, [Variable("i"), Constant(10)]) - } - - loop_node = ast_replaceable_while.root - assert isinstance(loop_node, ForLoopNode) - assert loop_node.declaration == Assignment(Variable("i"), Constant(0)) - assert loop_node.modification == Assignment(Variable("i"), BinaryOperation(OperationType.plus, [Variable("i"), Constant(1)])) - - loop_body = loop_node.body - assert isinstance(loop_body, CodeNode) - assert loop_body.instructions == [ - Assignment(ListOperation([]), Call(ImportedFunctionSymbol("printf", 0), [Constant("counter: %d\n"), Variable("i")])), - ] - - def test_with_usage(self, ast_replaceable_while_usage): - self.run_rbr(ast_replaceable_while_usage) - - for_loop = ast_replaceable_while_usage.root.children[0] - assert isinstance(for_loop, ForLoopNode) - assert for_loop.declaration == Assignment(Variable("i"), Constant(0)) - - copy_instr_node = ast_replaceable_while_usage.root.children[1] - assert isinstance(copy_instr_node, CodeNode) - assert copy_instr_node.instructions == [Assignment(Variable("a"), Variable("i"))] - - def test_with_usage_redefinition(self, ast_replaceable_while_reinit_usage): - self.run_rbr(ast_replaceable_while_reinit_usage) - - for_loop = ast_replaceable_while_reinit_usage.root.children[0] - assert isinstance(for_loop, ForLoopNode) - assert for_loop.declaration == Assignment(Variable("i"), Constant(0)) - assert for_loop.modification == Assignment(Variable("i"), BinaryOperation(OperationType.plus, [Variable("i"), Constant(1)])) - - 
exit_code_node = ast_replaceable_while_reinit_usage.root.children[1] - assert isinstance(exit_code_node, CodeNode) - assert exit_code_node.instructions == [ - Assignment(Variable("a"), Constant(50)), - Assignment(ListOperation([]), Call(ImportedFunctionSymbol("printf", 0), [Constant("final counter: %d"), Variable("a")])), - ] - - def test_with_usage_redefenition_2(self, ast_replaceable_while_compound_usage): - self.run_rbr(ast_replaceable_while_compound_usage) - - for_loop = ast_replaceable_while_compound_usage.root.children[0] - assert isinstance(for_loop, ForLoopNode) - assert for_loop.declaration == Assignment(Variable("i"), Constant(0)) - assert for_loop.modification == Assignment(Variable("i"), BinaryOperation(OperationType.plus, [Variable("i"), Constant(1)])) - - copy_instr_node = ast_replaceable_while_compound_usage.root.children[1] - assert isinstance(copy_instr_node, CodeNode) - assert copy_instr_node.instructions == [Assignment(Variable("a"), Variable("i"))] - - def test_continuation_is_not_first_var(self, ast_continuation_is_not_first_var): - self.run_rbr(ast_continuation_is_not_first_var) - - init_code_node = ast_continuation_is_not_first_var.root.children[0] - assert isinstance(init_code_node, CodeNode) - assert init_code_node.instructions == [Assignment(Variable("a"), Constant(0))] - loop_node = ast_continuation_is_not_first_var.root.children[1] - assert isinstance(loop_node, ForLoopNode) - assert loop_node.declaration == Assignment(Variable("i"), Constant(0)) - assert loop_node.modification == Assignment(Variable("i"), BinaryOperation(OperationType.plus, [Variable("i"), Constant(1)])) - loop_node_body = loop_node.body - assert isinstance(loop_node_body, CodeNode) - assert loop_node_body.instructions == [ - Assignment(ListOperation([]), Call(ImportedFunctionSymbol("printf", 0), [Constant("%d\n"), Variable("a")])) - ] - - def test_init_with_call(self, ast_call_init): - self.run_rbr(ast_call_init, _generate_options(rename_for=True)) - - code_node = 
ast_call_init.root.children[0] - assert isinstance(code_node, CodeNode) - assert code_node.instructions == [Assignment(Variable("a"), Constant(5))] - - for_loop_node = ast_call_init.root.children[1] - assert isinstance(for_loop_node, ForLoopNode) - assert for_loop_node.declaration == Assignment(Variable("i"), Call(ImportedFunctionSymbol("foo", 0), [])) - assert for_loop_node.modification == Assignment(Variable("i"), BinaryOperation(OperationType.plus, [Variable("i"), Constant(1)])) - - loop_node_body = for_loop_node.body - assert isinstance(loop_node_body, CodeNode) - assert loop_node_body.instructions == [ - Assignment(Variable("a"), BinaryOperation(OperationType.plus, [Variable("a"), Variable("i")])) - ] - - assert for_loop_node.condition == logic_cond("x1", context := LogicCondition.generate_new_context()) - assert ast_call_init.condition_map == { - logic_cond("x1", context): Condition(OperationType.less_or_equal, [Variable("i"), Constant(5)]) +@pytest.fixture +def ast_while_in_else() -> AbstractSyntaxTree: + """ + while (true) { + if (b < 2) { + break; + } else { + a = 0; + while (a < 5) { + printf("%d\n", a); + a = a + 1; + } } + } + """ + true_value = LogicCondition.initialize_true(context := LogicCondition.generate_new_context()) + ast = AbstractSyntaxTree( + root := SeqNode(true_value), + condition_map={ + logic_cond("x1", context): Condition(OperationType.less, [Variable("a"), Constant(2)]), + logic_cond("x2", context): Condition(OperationType.less, [Variable("b"), Constant(2)]), + }, + ) - def test_double_init(self, ast_redundant_init): - self.run_rbr(ast_redundant_init) - - code_node = ast_redundant_init.root.children[0] - assert isinstance(code_node, CodeNode) - assert code_node.instructions == [ - Assignment(Variable("b"), Constant(0)), - Assignment(Variable("a"), Constant(5)), - Assignment(Variable("b"), Constant(2)), - ] - - for_loop_node = ast_redundant_init.root.children[1] - assert isinstance(for_loop_node, ForLoopNode) - assert 
for_loop_node.declaration == Variable("b") - assert for_loop_node.modification == Assignment(Variable("b"), BinaryOperation(OperationType.plus, [Variable("b"), Constant(1)])) - - loop_node_body = for_loop_node.body - assert isinstance(loop_node_body, CodeNode) - assert loop_node_body.instructions == [ - Assignment(Variable("a"), BinaryOperation(OperationType.plus, [Variable("a"), Variable("b")])), - ] - - assert for_loop_node.condition == logic_cond("x1", context := LogicCondition.generate_new_context()) - assert ast_redundant_init.condition_map == {logic_cond("x1", context): Condition(OperationType.less, [Variable("b"), Constant(5)])} - - def test_double_init_condition_node(self, ast_reinit_in_condition_true): - self.run_rbr(ast_reinit_in_condition_true) - - def test_init_in_switch(self, ast_init_in_switch): - self.run_rbr(ast_init_in_switch) - - init_code_node = ast_init_in_switch.root.children[0] - assert isinstance(init_code_node, CodeNode) - assert init_code_node.instructions == [Assignment(Variable("a"), Constant(5)), Assignment(Variable("b"), Constant(0))] - - loop_node = ast_init_in_switch.root.children[2] - assert isinstance(loop_node, ForLoopNode) - assert loop_node.declaration == Variable("b") - assert loop_node.modification == Assignment(Variable("b"), BinaryOperation(OperationType.plus, [Variable("b"), Constant(1)])) - - loop_node_body = loop_node.body - assert isinstance(loop_node_body, CodeNode) - assert loop_node_body.instructions == [ - Assignment(Variable("a"), BinaryOperation(OperationType.plus, [Variable("a"), Variable("b")])) - ] - - def test_usage_in_condition(self, ast_usage_in_condition): - self.run_rbr(ast_usage_in_condition) - - code_node = ast_usage_in_condition.root.children[0] - assert isinstance(code_node, CodeNode) - assert code_node.instructions == [Assignment(Variable("a"), Constant(1)), Assignment(Variable("b"), Constant(0))] - - condition_node = ast_usage_in_condition.root.children[1] - assert isinstance(condition_node, 
ConditionNode) - assert condition_node.condition == logic_cond("x2", context := LogicCondition.generate_new_context()) - - loop_node = ast_usage_in_condition.root.children[2] - assert isinstance(loop_node, ForLoopNode) - assert loop_node.declaration == Variable("b") - assert loop_node.condition == logic_cond("x1", context) - assert loop_node.modification == Assignment(Variable("b"), BinaryOperation(OperationType.plus, [Variable("b"), Constant(1)])) - - loop_body = loop_node.body - assert isinstance(loop_body, CodeNode) - assert loop_body.instructions == [Assignment(Variable("a"), BinaryOperation(OperationType.multiply, [Variable("a"), Constant(2)]))] - - def test_while_in_else(self, ast_while_in_else): - self.run_rbr(ast_while_in_else) - - endless_loop = ast_while_in_else.root - assert isinstance(endless_loop, WhileLoopNode) - - condition_node = endless_loop.body - assert isinstance(condition_node, ConditionNode) - - loop_node = condition_node.false_branch_child - assert isinstance(loop_node, ForLoopNode) - assert loop_node.declaration == Assignment(Variable("i"), Constant(0)) - assert loop_node.modification == Assignment(Variable("i"), BinaryOperation(OperationType.plus, [Variable("i"), Constant(1)])) - - loop_node_body = loop_node.body - assert isinstance(loop_node_body, CodeNode) - assert loop_node_body.instructions == [ - Assignment(ListOperation([]), Call(ImportedFunctionSymbol("printf", 0), [Constant("%d\n"), Variable("i")])) - ] - - def test_nested_while(self, ast_nested_while): - self.run_rbr(ast_nested_while, _generate_options(empty_loops=True)) - - outer_loop = ast_nested_while.root - assert isinstance(outer_loop, ForLoopNode) - assert outer_loop.declaration == Assignment(Variable("i"), Constant(0)) - assert ast_nested_while.condition_map[outer_loop.condition] == Condition(OperationType.less, [Variable("i"), Constant(5)]) - assert outer_loop.modification == Assignment(Variable("i"), BinaryOperation(OperationType.plus, [Variable("i"), Constant(1)])) - - 
inner_loop = outer_loop.children[0] - assert isinstance(inner_loop, ForLoopNode) - assert inner_loop.declaration == Assignment(Variable("j"), Constant(0)) - assert ast_nested_while.condition_map[inner_loop.condition] == Condition(OperationType.less, [Variable("j"), Constant(5)]) - assert inner_loop.modification == Assignment(Variable("j"), BinaryOperation(OperationType.plus, [Variable("j"), Constant(1)])) - - def test_nested_while_loop(self, ast_endless_while_init_outside): - self.run_rbr(ast_endless_while_init_outside) - - endless_loop = ast_endless_while_init_outside.root.children[1] - assert isinstance(endless_loop, WhileLoopNode) - - for_loop = endless_loop.body - assert isinstance(for_loop, ForLoopNode) - assert for_loop.declaration == Variable("a") - - def test_sequenced_loops(self, ast_sequenced_while_loops): - self.run_rbr(ast_sequenced_while_loops) - - loop_1 = ast_sequenced_while_loops.root.children[0] - assert isinstance(loop_1, ForLoopNode) - assert loop_1.declaration == Assignment(Variable("i"), Constant(0)) - assert loop_1.modification == Assignment(Variable("i"), BinaryOperation(OperationType.plus, [Variable("i"), Constant(1)])) - - loop_1_body = loop_1.body - assert isinstance(loop_1_body, CodeNode) - assert loop_1_body.instructions == [ - Assignment(ListOperation([]), Call(ImportedFunctionSymbol("printf", 0), [Constant("%d\n"), Variable("i")])), - ] - - loop_2 = ast_sequenced_while_loops.root.children[1] - assert isinstance(loop_2, ForLoopNode) - assert loop_2.declaration == Assignment(Variable("j"), Constant(0)) - assert loop_2.modification == Assignment(Variable("j"), BinaryOperation(OperationType.plus, [Variable("j"), Constant(1)])) - - loop_2_body = loop_2.body - assert isinstance(loop_2_body, CodeNode) - assert loop_2_body.instructions == [ - Assignment(ListOperation([]), Call(ImportedFunctionSymbol("printf", 0), [Constant("%d\n"), Variable("j")])), - ] - - def test_switch_as_loop_body(self, ast_switch_as_loop_body): - 
self.run_rbr(ast_switch_as_loop_body) - - assert all(not isinstance(node, ForLoopNode) for node in ast_switch_as_loop_body.topological_order()) - - init_code_node = ast_switch_as_loop_body.root.children[0] - assert isinstance(init_code_node, CodeNode) - assert init_code_node.instructions == [Assignment(Variable("a"), Constant(5)), Assignment(Variable("counter"), Constant(0))] + inner_while = ast.factory.create_while_loop_node(logic_cond("x1", context)) + ast._add_node(inner_while) - while_node = ast_switch_as_loop_body.root.children[1] - assert isinstance(while_node, WhileLoopNode) + true_branch_child = ast._add_code_node([Break()]) + inner_seq = ast.factory.create_seq_node() + ast._add_node(inner_seq) + condition_node = ast._add_condition_node_with(logic_cond("x2", context), true_branch_child, inner_seq) - switch_node = while_node.body - assert isinstance(switch_node, SwitchNode) + init_code_node = ast._add_code_node([Assignment(Variable("a"), Constant(0))]) - case_1_body = switch_node.children[0].child - assert isinstance(case_1_body, CodeNode) - assert case_1_body.instructions == [ - Assignment(Variable("a"), BinaryOperation(OperationType.plus, [Variable("a"), Variable("counter")])) - ] + endless_loop = ast.add_endless_loop_with_body(condition_node) - case_2_body = switch_node.children[1].child - assert isinstance(case_2_body, CodeNode) - assert case_2_body.instructions == [ - Assignment(Variable("counter"), BinaryOperation(OperationType.plus, [Variable("counter"), Constant(1)])) + inner_while_body = ast._add_code_node( + [ + Assignment(ListOperation([]), Call(ImportedFunctionSymbol("printf", 0), [Constant("%d\n"), Variable("a")])), + Assignment(Variable("a"), BinaryOperation(OperationType.plus, [Variable("a"), Constant(1)])), ] + ) - def test_switch_as_loop_with_increment(self, ast_switch_as_loop_body_increment): - self.run_rbr(ast_switch_as_loop_body_increment) - - init_code_node = ast_switch_as_loop_body_increment.root.children[0] - assert 
isinstance(init_code_node, CodeNode) - assert init_code_node.instructions == [Assignment(Variable("a"), Constant(5))] - - loop_node = ast_switch_as_loop_body_increment.root.children[1] - assert isinstance(loop_node, ForLoopNode) - assert loop_node.declaration == Assignment(Variable("i"), Constant(0)) - assert loop_node.modification == Assignment(Variable("i"), BinaryOperation(OperationType.plus, [Variable("i"), Constant(1)])) - - switch_node = loop_node.body - assert isinstance(switch_node, SwitchNode) - - case_1 = switch_node.children[0] - assert isinstance(case_1, CaseNode) - - case_1_body = case_1.child - assert isinstance(case_1_body, CodeNode) - assert case_1_body.instructions == [Assignment(Variable("a"), BinaryOperation(OperationType.plus, [Variable("a"), Variable("i")]))] - - case_2 = switch_node.children[1] - assert isinstance(case_2, CaseNode) - - case_2_body = case_2.child - assert isinstance(case_2_body, CodeNode) - assert case_2_body.instructions == [Assignment(Variable("i"), BinaryOperation(OperationType.plus, [Variable("i"), Constant(1)]))] - - assert ast_switch_as_loop_body_increment.condition_map == { - logic_cond("x1", LogicCondition.generate_new_context()): Condition(OperationType.less, [Variable("i"), Constant(5)]) - } - - def test_rename_unary_operation_updates_array_info(self, ast_array_access_for_loop): - """Test if UnaryOperation.ArrayInfo gets updated on renaming""" - self.run_rbr(ast_array_access_for_loop, _generate_options(rename_for=True)) - - def find_unary_op(ast): - """look for UnaryOperation in AST""" - for node in ast.get_code_nodes_topological_order(): - for instr in node.instructions: - for unary_op in instr: - if isinstance(unary_op, UnaryOperation): - return unary_op - return None - - unary_operation = find_unary_op(ast_array_access_for_loop) - if not isinstance(unary_operation, UnaryOperation): - assert False, "Did not find UnaryOperation in AST (expect it for array access)" - assert unary_operation.array_info is not None - 
assert unary_operation.array_info.base in unary_operation.requirements - assert unary_operation.array_info.index in unary_operation.requirements - - def test_no_for_loop_renaming(self, ast_replaceable_while): - self.run_rbr(ast_replaceable_while, _generate_options(rename_for=False)) - - assert ast_replaceable_while.condition_map == { - logic_cond("x1", LogicCondition.generate_new_context()): Condition(OperationType.less, [Variable("a"), Constant(10)]) - } - - loop_node = ast_replaceable_while.root - assert isinstance(loop_node, ForLoopNode) - assert loop_node.declaration == Assignment(Variable("a"), Constant(0)) - assert loop_node.modification == Assignment(Variable("a"), BinaryOperation(OperationType.plus, [Variable("a"), Constant(1)])) - - loop_body = loop_node.body - assert isinstance(loop_body, CodeNode) - assert loop_body.instructions == [ - Assignment(ListOperation([]), Call(ImportedFunctionSymbol("printf", 0), [Constant("counter: %d\n"), Variable("a")])), + ast._add_edges_from( + [ + (root, endless_loop), + (endless_loop, condition_node), + (inner_seq, init_code_node), + (inner_seq, inner_while), + (inner_while, inner_while_body), ] - - def test_init_may_not_reach_loop_1(self, ast_initialization_in_condition): - assert ( - _initialization_reaches_loop_node( - ast_initialization_in_condition.root.children[0].true_branch_child, ast_initialization_in_condition.root.children[1] - ) - is False - ) - - self.run_rbr(ast_initialization_in_condition, _generate_options()) - assert any( - isinstance(for_loop := loop, ForLoopNode) for loop in ast_initialization_in_condition.get_for_loop_nodes_topological_order() - ) - assert for_loop.declaration == Variable("a") - - def test_init_may_not_reach_loop_2(self, ast_initialization_in_condition_sequence): - assert ( - _initialization_reaches_loop_node( - ast_initialization_in_condition_sequence.root.children[0].true_branch_child.children[1], - ast_initialization_in_condition_sequence.root.children[1], - ) - is False - ) - - 
self.run_rbr(ast_initialization_in_condition_sequence, _generate_options()) - assert any( - isinstance(for_loop := loop, ForLoopNode) - for loop in ast_initialization_in_condition_sequence.get_for_loop_nodes_topological_order() - ) - assert for_loop.declaration == Variable("a") - - def test_guarded_do_while_if(self, ast_guarded_do_while_if): - self.run_rbr(ast_guarded_do_while_if, _generate_options()) - - for cond_node in ast_guarded_do_while_if.get_condition_nodes_post_order(): - assert False, "There should be no condition node" - - for loop_node in ast_guarded_do_while_if.get_loop_nodes_post_order(): - assert isinstance(loop_node, WhileLoopNode) - - def test_guarded_do_while_else(self, ast_guarded_do_while_else): - self.run_rbr(ast_guarded_do_while_else, _generate_options()) - - for cond_node in ast_guarded_do_while_else.get_condition_nodes_post_order(): - assert False, "There should be no condition node" - - for loop_node in ast_guarded_do_while_else.get_loop_nodes_post_order(): - assert isinstance(loop_node, WhileLoopNode) - - @pytest.mark.parametrize("keep_empty_for_loops", [True, False]) - def test_keep_empty_for_loop(self, keep_empty_for_loops: bool, ast_single_instruction_while): - self.run_rbr(ast_single_instruction_while, _generate_options(keep_empty_for_loops)) - - if keep_empty_for_loops: - assert isinstance(ast_single_instruction_while.root, ForLoopNode) - else: - assert isinstance(ast_single_instruction_while.root.children[1], WhileLoopNode) - - def test_rhs_of_for_loop_declaration_not_renamed(self, ast_self_referential_init: AbstractSyntaxTree): - self.run_rbr(ast_self_referential_init) - for_loops = list(ast_self_referential_init.get_for_loop_nodes_topological_order()) - assert len(for_loops) == 1 - assert for_loops[0].declaration == Assignment(Variable("i"), Call(ImportedFunctionSymbol("foo", 0), [Variable("b")])) + ) + return ast class TestForLoopRecovery: @@ -1551,6 +293,30 @@ def test_for_loop_recovery_blacklist(self, operation): assert 
isinstance(loop_node, ForLoopNode) +class TestGuardedDoWhile: + @staticmethod + def run_rbr(ast: AbstractSyntaxTree, options: Options = _generate_options()): + ReadabilityBasedRefinement().run(DecompilerTask("func", cfg=None, ast=ast, options=options)) + + def test_guarded_do_while_if(self, ast_guarded_do_while_if): + self.run_rbr(ast_guarded_do_while_if, _generate_options()) + + for _ in ast_guarded_do_while_if.get_condition_nodes_post_order(): + assert False, "There should be no condition node" + + for loop_node in ast_guarded_do_while_if.get_loop_nodes_post_order(): + assert isinstance(loop_node, WhileLoopNode) + + def test_guarded_do_while_else(self, ast_guarded_do_while_else): + self.run_rbr(ast_guarded_do_while_else, _generate_options()) + + for _ in ast_guarded_do_while_else.get_condition_nodes_post_order(): + assert False, "There should be no condition node" + + for loop_node in ast_guarded_do_while_else.get_loop_nodes_post_order(): + assert isinstance(loop_node, WhileLoopNode) + + class TestReadabilityUtils: def test_find_continuation_instruction_1(self): """ @@ -2045,40 +811,6 @@ def test_separated_by_loop_node_4(self, ast_while_in_else): assert _initialization_reaches_loop_node(init_code_node, inner_while) is False - def test_for_loop_variable_generation(self): - renamer = ForLoopVariableRenamer( - AbstractSyntaxTree(SeqNode(LogicCondition.initialize_true(LogicCondition.generate_new_context())), {}), - ["i", "j", "k", "l", "m", "n"] - ) - assert [renamer._get_variable_name() for _ in range(14)] == [ - "i", - "j", - "k", - "l", - "m", - "n", - "i1", - "j1", - "k1", - "l1", - "m1", - "n1", - "i2", - "j2", - ] - - def test_while_loop_variable_generation(self): - renamer = WhileLoopVariableRenamer( - AbstractSyntaxTree(SeqNode(LogicCondition.initialize_true(LogicCondition.generate_new_context())), {}) - ) - assert [renamer._get_variable_name() for _ in range(5)] == ["counter", "counter1", "counter2", "counter3", "counter4"] - - def 
test_declaration_listop(self, ast_call_for_loop): - """Test renaming with ListOperation as Declaration""" - ForLoopVariableRenamer(ast_call_for_loop, ["i"]).rename() - for node in ast_call_for_loop: - if isinstance(node, ForLoopNode): - assert node.declaration.destination.operands[0].name == "i" def test_skip_for_loop_recovery_if_continue_in_while(self): """
diff --git a/tests/structures/visitors/test_substitute_visitor.py b/tests/structures/visitors/test_substitute_visitor.py
new file mode 100644
index 000000000..89a874081
--- /dev/null
+++ b/tests/structures/visitors/test_substitute_visitor.py
@@ -0,0 +1,197 @@
+import pytest
+from decompiler.structures.graphs.basicblock import BasicBlock
+from decompiler.structures.pseudo import (
+    Assignment,
+    BinaryOperation,
+    Branch,
+    Call,
+    Constant,
+    DataflowObject,
+    Integer,
+    Phi,
+    Pointer,
+    RegisterPair,
+    Return,
+    UnaryOperation,
+    Variable,
+)
+from decompiler.structures.pseudo.operations import ArrayInfo, Condition, OperationType
+from decompiler.structures.visitors.substitute_visitor import SubstituteVisitor
+
+_i32 = Integer.int32_t()  # shorthand type used by the typed Variables/Constants in the cases below
+_p_i32 = Pointer(Integer.int32_t())  # pointer-to-int32 shorthand used for the array base variables
+
+_a = Variable("a", Integer.int32_t(), 0)  # NOTE(review): _a, _b, _c and _d are never referenced in this module
+_b = Variable("b", Integer.int32_t(), 1)
+_c = Variable("c", Integer.int32_t(), 2)
+_d = Variable("d", Integer.int32_t(), 3)
+
+
+@pytest.mark.parametrize(
+    ["initial_obj", "expected_result", "visitor"],  # each case below is (object to visit, expected outcome, visitor)
+    [
+        (  # root Variable is itself the identity target -> expected to be replaced by r
+            o := Variable("v", _i32, 0),
+            r := Variable("x", _i32, 1),
+            SubstituteVisitor.identity(o, r)
+        ),
+        (  # same root replacement, using the equality-based visitor
+            o := Variable("v", _i32, 0),
+            r := Variable("x", _i32, 1),
+            SubstituteVisitor.equality(o, r)
+        ),
+        (  # identity visitor built from a different-but-equal Variable -> root expected unchanged
+            o := Variable("v", _i32, 0),
+            o,
+            SubstituteVisitor.identity(Variable("v", _i32, 0), Variable("x", _i32, 1))
+        ),
+        (  # equality visitor built from a different-but-equal Variable -> root expected to be replaced
+            o := Variable("v", _i32, 0),
+            r := Variable("x", _i32, 1),
+            SubstituteVisitor.equality(Variable("v", _i32, 0), r)
+        ),
+        (  # second Assignment operand (b) expected to become c
+            Assignment(a := Variable("a"), b := Variable("b")),
+            Assignment(a, c := Variable("c")),
+            SubstituteVisitor.identity(b, c)
+        ),
+        (  # first Assignment operand (a) expected to become c
+            Assignment(a := Variable("a"), b := Variable("b")),
+            Assignment(c := Variable("c"), b),
+            SubstituteVisitor.identity(a, c)
+        ),
+        (  # single operand of a UnaryOperation expected to be replaced
+            UnaryOperation(OperationType.dereference, [a := Variable("a")]),
+            UnaryOperation(OperationType.dereference, [b := Variable("b")]),
+            SubstituteVisitor.identity(a, b)
+        ),
+        (  # replaced variable occurs in the operand tree and as first ArrayInfo argument; both expected to change
+            UnaryOperation(
+                OperationType.dereference,
+                [BinaryOperation(OperationType.plus, [a := Variable("a", _p_i32), Constant(4, _i32)])],
+                array_info=ArrayInfo(a, 1)
+            ),
+            UnaryOperation(
+                OperationType.dereference,
+                [BinaryOperation(OperationType.plus, [b := Variable("b", _p_i32), Constant(4, _i32)])],
+                array_info=ArrayInfo(b, 1)
+            ),
+            SubstituteVisitor.identity(a, b)
+        ),
+        (  # replaced variable occurs in the operand tree and as second ArrayInfo argument; both expected to change
+            UnaryOperation(
+                OperationType.dereference,
+                [BinaryOperation(
+                    OperationType.plus,
+                    [
+                        a := Variable("a", _p_i32),
+                        BinaryOperation(OperationType.multiply, [b := Variable("b", _i32), Constant(4, _i32)])
+                    ]
+                )],
+                array_info=ArrayInfo(a, b)
+            ),
+            UnaryOperation(
+                OperationType.dereference,
+                [BinaryOperation(
+                    OperationType.plus,
+                    [
+                        a := Variable("a", _p_i32),  # rebinds `a` to a fresh, equal Variable for the expected side
+                        BinaryOperation(OperationType.multiply, [c := Variable("c", _i32), Constant(4, _i32)])
+                    ]
+                )],
+                array_info=ArrayInfo(a, c)
+            ),
+            SubstituteVisitor.identity(b, c)
+        ),
+        (  # second operand of a BinaryOperation expected to be replaced
+            BinaryOperation(OperationType.multiply, [a := Variable("a"), b := Variable("b")]),
+            BinaryOperation(OperationType.multiply, [a, c := Variable("c")]),
+            SubstituteVisitor.identity(b, c)
+        ),
+        (  # second RegisterPair argument expected to be replaced
+            RegisterPair(a := Variable("a"), b := Variable("b")),
+            RegisterPair(a, c := Variable("c")),
+            SubstituteVisitor.identity(b, c)
+        ),
+        (  # element of the Call's list argument expected to be replaced
+            Call(f := Variable("f"), [a := Variable("a")]),
+            Call(f, [b := Variable("b")]),
+            SubstituteVisitor.identity(a, b)
+        ),
+        (  # first Call argument (f) expected to be replaced by g
+            Call(f := Variable("f"), [a := Variable("a")]),
+            Call(g := Variable("g"), [a]),
+            SubstituteVisitor.identity(f, g)
+        ),
+        (  # Phi list entry a1 -> a0; the BasicBlock mapping is expected to carry a0 as well
+            Phi(
+                a3 := Variable("a", _i32, 3),
+                [
+                    a2 := Variable("a", _i32, 2),
+                    a1 := Variable("a", _i32, 1)
+                ],
+                {
+                    BasicBlock(2): a2,
+                    BasicBlock(1): a1,
+                }
+            ),
+            Phi(
+                a3,
+                [
+                    a2,
+                    a0 := Variable("a", _i32, 0)
+                ],
+                {
+                    BasicBlock(2): a2,
+                    BasicBlock(1): a0,
+                }
+            ),
+            SubstituteVisitor.identity(a1, a0)
+        ),
+        (  # first Phi argument a3 -> a4; list entries and BasicBlock mapping expected unchanged
+            Phi(
+                a3 := Variable("a", _i32, 3),
+                [
+                    a2 := Variable("a", _i32, 2),
+                    a1 := Variable("a", _i32, 1)
+                ],
+                {
+                    BasicBlock(2): a2,
+                    BasicBlock(1): a1,
+                }
+            ),
+            Phi(
+                a4 := Variable("a", _i32, 4),
+                [
+                    a2,
+                    a1
+                ],
+                {
+                    BasicBlock(2): a2,
+                    BasicBlock(1): a1,
+                }
+            ),
+            SubstituteVisitor.identity(a3, a4)
+        ),
+        (  # the Condition passed to Branch expected to be replaced
+            Branch(a := Condition(OperationType.equal, [])),
+            Branch(b := Condition(OperationType.not_equal, [])),
+            SubstituteVisitor.identity(a, b)
+        ),
+        (  # element of the Return's list argument expected to be replaced
+            Return([a := Variable("a")]),
+            Return([b := Variable("b")]),
+            SubstituteVisitor.identity(a, b)
+        ),
+    ]
+)
+def test_substitute(initial_obj: DataflowObject, expected_result: DataflowObject, visitor: SubstituteVisitor):
+    result = initial_obj.accept(visitor)  # run the visitor on the root object of the case
+    if result is None:  # presumably None means "root not replaced", with nested substitutions applied in place - confirm against SubstituteVisitor
+        result = initial_obj  # fall back to the visited object itself for the comparison
+
+    assert result == expected_result  # compared with ==, so expected objects only need to be equal, not identical
+
+    # Per the original note, Phi equality does not cover origin_block, so that mapping is compared explicitly.
+    if isinstance(expected_result, Phi):
+        assert result.origin_block == expected_result.origin_block