diff --git a/decompiler/backend/cexpressiongenerator.py b/decompiler/backend/cexpressiongenerator.py index 09bcfac65..6ff59150a 100644 --- a/decompiler/backend/cexpressiongenerator.py +++ b/decompiler/backend/cexpressiongenerator.py @@ -1,5 +1,4 @@ import logging -from ctypes import c_byte, c_int, c_long, c_short, c_ubyte, c_uint, c_ulong, c_ushort from itertools import chain, repeat from decompiler.structures import pseudo as expressions @@ -8,6 +7,7 @@ from decompiler.structures.pseudo import operations as operations from decompiler.structures.pseudo.operations import MemberAccess from decompiler.structures.visitors.interfaces import DataflowObjectVisitorInterface +from decompiler.util.integer_util import normalize_int class CExpressionGenerator(DataflowObjectVisitorInterface): @@ -80,20 +80,6 @@ class CExpressionGenerator(DataflowObjectVisitorInterface): # OperationType.adc: "adc", } - SIGNED_FORMATS = { - 8: lambda x: c_byte(x).value, - 16: lambda x: c_short(x).value, - 32: lambda x: c_int(x).value, - 64: lambda x: c_long(x).value, - } - - UNSIGNED_FORMATS = { - 8: lambda x: c_ubyte(x).value, - 16: lambda x: c_ushort(x).value, - 32: lambda x: c_uint(x).value, - 64: lambda x: c_ulong(x).value, - } - """ Precedence used for correctly generating brackets. Higher precedence is more tightly binding. @@ -298,13 +284,7 @@ def _get_integer_literal_value(self, literal: expressions.Constant) -> int: Return the right integer value for the given type, assuming that the re-compilation host has the same sizes as the decompilation host. """ - if literal.type.is_signed: - if handler := self.SIGNED_FORMATS.get(literal.type.size, None): - return handler(literal.value) - elif literal.value < 0: - if handler := self.UNSIGNED_FORMATS.get(literal.type.size, None): - return handler(literal.value) - return literal.value + return normalize_int(literal.value, literal.type.size, literal.type.is_signed) @staticmethod def _interpret_integer_literal_type(value: int) -> Integer: diff --git a/decompiler/frontend/binaryninja/handlers/assignments.py b/decompiler/frontend/binaryninja/handlers/assignments.py index 85d81038c..f68a77f53 100644 --- a/decompiler/frontend/binaryninja/handlers/assignments.py +++ b/decompiler/frontend/binaryninja/handlers/assignments.py @@ -18,7 +18,7 @@ RegisterPair, UnaryOperation, ) -from decompiler.structures.pseudo.complextypes import Struct, Union +from decompiler.structures.pseudo.complextypes import Class, Struct, Union from decompiler.structures.pseudo.operations import MemberAccess @@ -67,9 +67,8 @@ def lift_set_field(self, assignment: mediumlevelil.MediumLevelILSetVarField, is_ """ # case 1 (struct), avoid set field of named integers: dest_type = self._lifter.lift(assignment.dest.type) - if isinstance(assignment.dest.type, binaryninja.NamedTypeReferenceType) and not ( - isinstance(dest_type, Pointer) and isinstance(dest_type.type, Integer) - ): + if isinstance(assignment.dest.type, binaryninja.NamedTypeReferenceType) and ( + isinstance(dest_type, Struct) or isinstance(dest_type, Class)): # otherwise get_member_by_offset not available struct_variable = self._lifter.lift(assignment.dest, is_aliased=True, parent=assignment) destination = MemberAccess( offset=assignment.offset, @@ -95,7 +94,7 @@ def lift_get_field(self, instruction: mediumlevelil.MediumLevelILVarField, is_al case 1: struct member read access e.g. (x = )book.title lift as (x = ) struct_member(book, title) case 2: accessing register portion e.g. (x = )eax.ah - lift as (x = ) eax & 0x0000ff00 + lift as (x = ) (uint8_t)(eax >> 8) (x = ) <- for the sake of example, only rhs expression is lifted here. """ source = self._lifter.lift(instruction.src, is_aliased=is_aliased, parent=instruction) @@ -103,10 +102,13 @@ def lift_get_field(self, instruction: mediumlevelil.MediumLevelILVarField, is_al return self._get_field_as_member_access(instruction, source, **kwargs) cast_type = source.type.resize(instruction.size * self.BYTE_SIZE) if instruction.offset: - return BinaryOperation( - OperationType.bitwise_and, - [source, Constant(self._get_all_ones_mask_for_type(instruction.size) << instruction.offset)], - vartype=cast_type, + return UnaryOperation( + OperationType.cast, + [BinaryOperation( + OperationType.right_shift_us, + [source, Constant(instruction.offset, Integer.int32_t())] + )], + cast_type ) return UnaryOperation(OperationType.cast, [source], vartype=cast_type, contraction=True) @@ -213,8 +215,14 @@ def lift_store_struct(self, instruction: mediumlevelil.MediumLevelILStoreStruct, """Lift a MLIL_STORE_STRUCT_SSA instruction to pseudo (e.g. object->field = x).""" vartype = self._lifter.lift(instruction.dest.expr_type) struct_variable = self._lifter.lift(instruction.dest, is_aliased=True, parent=instruction) + member = vartype.type.get_member_by_offset(instruction.offset) + if member is not None: + name = member.name + else: + name = f"__offset_{instruction.offset}" + name.replace("-", "minus_") struct_member_access = MemberAccess( - member_name=vartype.type.members.get(instruction.offset), + member_name=name, offset=instruction.offset, operands=[struct_variable], vartype=vartype, diff --git a/decompiler/frontend/binaryninja/handlers/types.py b/decompiler/frontend/binaryninja/handlers/types.py index 353a5922a..eeee02724 100644 --- a/decompiler/frontend/binaryninja/handlers/types.py +++ b/decompiler/frontend/binaryninja/handlers/types.py @@ -22,7 +22,7 @@ ) from decompiler.frontend.lifter import Handler from decompiler.structures.pseudo import CustomType, Float, FunctionTypeDef, Integer, Pointer, UnknownType, Variable -from decompiler.structures.pseudo.complextypes import ComplexTypeMember, ComplexTypeName, Enum, Struct +from decompiler.structures.pseudo.complextypes import Class, ComplexTypeMember, ComplexTypeName, Enum, Struct from decompiler.structures.pseudo.complextypes import Union as Union_ @@ -75,39 +75,60 @@ def lift_named_type_reference_type(self, custom: NamedTypeReferenceType, **kwarg def lift_enum(self, binja_enum: EnumerationType, name: str = None, **kwargs) -> Enum: """Lift enum type.""" - enum_name = name if name else self._get_data_type_name(binja_enum, keyword="enum") + type_id = hash(binja_enum) + enum_name = self._get_data_type_name(binja_enum, keyword="enum", provided_name=name) enum = Enum(binja_enum.width * self.BYTE_SIZE, enum_name, {}) for member in binja_enum.members: enum.add_member(self._lifter.lift(member)) - self._lifter.complex_types.add(enum) + self._lifter.complex_types.add(enum, type_id) return enum def lift_enum_member(self, enum_member: EnumerationMember, **kwargs) -> ComplexTypeMember: """Lift enum member type.""" return ComplexTypeMember(size=0, name=enum_member.name, offset=-1, type=Integer(32), value=int(enum_member.value)) - def lift_struct(self, struct: StructureType, name: str = None, **kwargs) -> Union[Struct, ComplexTypeName]: + def lift_struct(self, struct: StructureType, name: str = None, **kwargs) -> Union[Struct, Union_, Class, ComplexTypeName]: + type_id = hash(struct) + cached_type = self._lifter.complex_types.retrieve_by_id(type_id) + if cached_type is not None: + return cached_type + """Lift struct or union type.""" if struct.type == StructureVariant.StructStructureType: - type_name = name if name else self._get_data_type_name(struct, keyword="struct") - lifted_struct = Struct(struct.width * self.BYTE_SIZE, type_name, {}) + keyword, type, members = "struct", Struct, {} elif struct.type == StructureVariant.UnionStructureType: - type_name = name if name else self._get_data_type_name(struct, keyword="union") - lifted_struct = Union_(struct.width * self.BYTE_SIZE, type_name, []) + keyword, type, members = "union", Union_, [] + elif struct.type == StructureVariant.ClassStructureType: + keyword, type, members = "class", Class, {} else: raise RuntimeError(f"Unknown struct type {struct.type.name}") + + type_name = self._get_data_type_name(struct, keyword=keyword, provided_name=name) + lifted_struct = type(struct.width * self.BYTE_SIZE, type_name, members) + + self._lifter.complex_types.add(lifted_struct, type_id) for member in struct.members: lifted_struct.add_member(self.lift_struct_member(member, type_name)) - self._lifter.complex_types.add(lifted_struct) return lifted_struct @abstractmethod - def _get_data_type_name(self, complex_type: Union[StructureType, EnumerationType], keyword: str) -> str: - """Parse out the name of complex type.""" - string = complex_type.get_string() - if keyword in string: - return complex_type.get_string().split(keyword)[1] - return string + def _get_data_type_name(self, complex_type: Union[StructureType, EnumerationType], keyword: str, provided_name:str) -> str: + """Parse out the name of complex type. Empty and duplicate names are changed. + Calling this function has the side effect of incrementing a counter in the UniqueNameProvider.""" + if provided_name: + name = provided_name + else: + type_string = complex_type.get_string() + if keyword in type_string: + name = complex_type.get_string().split(keyword)[1] + else: + name = type_string + + if name.strip() == "": + name = f"__anonymous_{keyword}" + name = self._lifter.unique_name_provider.get_unique_name(name) + + return name def lift_struct_member(self, member: StructureMember, parent_struct_name: str = None) -> ComplexTypeMember: """Lift struct or union member.""" @@ -117,7 +138,7 @@ def lift_struct_member(self, member: StructureMember, parent_struct_name: str = else: # if member is an embedded struct/union, the name is already available member_type = self._lifter.lift(member.type, name=member.name) - return ComplexTypeMember(0, name=member.name, offset=member.offset, type=member_type) + return ComplexTypeMember(member_type.size, name=member.name, offset=member.offset, type=member_type) @abstractmethod def _get_member_pointer_on_the_parent_struct(self, member: StructureMember, parent_struct_name: str) -> ComplexTypeMember: diff --git a/decompiler/frontend/binaryninja/handlers/unary.py b/decompiler/frontend/binaryninja/handlers/unary.py index 180aecfd0..824f23f7c 100644 --- a/decompiler/frontend/binaryninja/handlers/unary.py +++ b/decompiler/frontend/binaryninja/handlers/unary.py @@ -99,7 +99,12 @@ def _lift_load_struct(self, instruction: mediumlevelil.MediumLevelILLoadStruct, struct_variable = self._lifter.lift(instruction.src) struct_ptr: Pointer = self._lifter.lift(instruction.src.expr_type) struct_member = struct_ptr.type.get_member_by_offset(instruction.offset) - return MemberAccess(vartype=struct_ptr, operands=[struct_variable], offset=struct_member.offset, member_name=struct_member.name) + if struct_member is not None: + name = struct_member.name + else: + name = f"__offset_{instruction.offset}" + name.replace("-", "minus_") + return MemberAccess(vartype=struct_ptr, operands=[struct_variable], offset=instruction.offset, member_name=name) def _lift_ftrunc(self, instruction: mediumlevelil.MediumLevelILFtrunc, **kwargs) -> UnaryOperation: """Lift a MLIL_FTRUNC operation.""" diff --git a/decompiler/frontend/binaryninja/lifter.py b/decompiler/frontend/binaryninja/lifter.py index e42761763..244df6ed4 100644 --- a/decompiler/frontend/binaryninja/lifter.py +++ b/decompiler/frontend/binaryninja/lifter.py @@ -6,7 +6,7 @@ from decompiler.frontend.lifter import ObserverLifter from decompiler.structures.pseudo import DataflowObject, Tag, UnknownExpression, UnknownType -from ...structures.pseudo.complextypes import ComplexTypeMap +from ...structures.pseudo.complextypes import ComplexTypeMap, UniqueNameProvider from .handlers import HANDLERS @@ -17,6 +17,7 @@ def __init__(self, no_bit_masks: bool = True, bv: BinaryView = None): self.no_bit_masks = no_bit_masks self.bv: BinaryView = bv self.complex_types: ComplexTypeMap = ComplexTypeMap() + self.unique_name_provider: UniqueNameProvider = UniqueNameProvider() for handler in HANDLERS: handler(self).register() diff --git a/decompiler/frontend/binaryninja/parser.py b/decompiler/frontend/binaryninja/parser.py index 43ccc684d..9c24b97dc 100644 --- a/decompiler/frontend/binaryninja/parser.py +++ b/decompiler/frontend/binaryninja/parser.py @@ -9,6 +9,7 @@ MediumLevelILBasicBlock, MediumLevelILConstPtr, MediumLevelILInstruction, + MediumLevelILJump, MediumLevelILJumpTo, MediumLevelILTailcallSsa, RegisterValueType, @@ -18,6 +19,7 @@ from decompiler.structures.graphs.cfg import BasicBlock, ControlFlowGraph, FalseCase, IndirectEdge, SwitchCase, TrueCase, UnconditionalEdge from decompiler.structures.pseudo import Constant, Instruction from decompiler.structures.pseudo.complextypes import ComplexTypeMap +from decompiler.structures.pseudo.instructions import Comment class BinaryninjaParser(Parser): @@ -135,6 +137,10 @@ def _get_lookup_table(self, block: MediumLevelILBasicBlock) -> Dict[int, List[Co lookup[target] += [Constant(value)] return lookup + def _has_undetermined_jump(self, basic_block: MediumLevelILBasicBlock) -> bool: + """Return True if basic-block is ending in a jump and has no outgoing edges""" + return bool(len(basic_block) and isinstance(basic_block[-1], MediumLevelILJump) and not basic_block.outgoing_edges) + def _lift_instructions(self, basic_block: MediumLevelILBasicBlock) -> Iterator[Instruction]: """Yield the lifted versions of all instructions in the given basic block.""" for instruction in basic_block: @@ -144,6 +150,8 @@ def _lift_instructions(self, basic_block: MediumLevelILBasicBlock) -> Iterator[I self._unlifted_instructions.append(instruction) continue yield lifted_instruction + if self._has_undetermined_jump(basic_block): + yield Comment("jump -> undetermined") def _report_lifter_errors(self): """Report instructions which could not be lifted and reset their counter.""" diff --git a/decompiler/pipeline/controlflowanalysis/__init__.py b/decompiler/pipeline/controlflowanalysis/__init__.py index a0cd40244..cf0a9593b 100644 --- a/decompiler/pipeline/controlflowanalysis/__init__.py +++ b/decompiler/pipeline/controlflowanalysis/__init__.py @@ -1,4 +1,5 @@ from .expression_simplification.stages import ExpressionSimplificationAst, ExpressionSimplificationCfg from .instruction_length_handler import InstructionLengthHandler +from .loop_name_generator import LoopNameGenerator from .readability_based_refinement import ReadabilityBasedRefinement from .variable_name_generation import VariableNameGeneration diff --git a/decompiler/pipeline/controlflowanalysis/expression_simplification/constant_folding.py b/decompiler/pipeline/controlflowanalysis/expression_simplification/constant_folding.py index 81c87498c..f770ff6b1 100644 --- a/decompiler/pipeline/controlflowanalysis/expression_simplification/constant_folding.py +++ b/decompiler/pipeline/controlflowanalysis/expression_simplification/constant_folding.py @@ -3,6 +3,7 @@ from typing import Callable, Optional from decompiler.structures.pseudo import Constant, Integer, OperationType +from decompiler.util.integer_util import normalize_int def constant_fold(operation: OperationType, constants: list[Constant]) -> Constant: @@ -103,27 +104,6 @@ def _constant_fold_shift(constants: list[Constant], fun: Callable[[int, int], in ) -def normalize_int(v: int, size: int, signed: bool) -> int: - """ - Normalizes an integer value to a specific size and signedness. - - This function takes an integer value 'v' and normalizes it to fit within - the specified 'size' in bits by discarding overflowing bits. If 'signed' is - true, the value is treated as a signed integer, i.e. interpreted as a two's complement. - Therefore the return value will be negative iff 'signed' is true and the most-significant bit is set. - - :param v: The value to be normalized. - :param size: The desired bit size for the normalized integer. - :param signed: True if the integer should be treated as signed. - :return: The normalized integer value. - """ - value = v & ((1 << size) - 1) - if signed and value & (1 << (size - 1)): - return value - (1 << size) - else: - return value - - _OPERATION_TO_FOLD_FUNCTION: dict[OperationType, Callable[[list[Constant]], Constant]] = { OperationType.minus: partial(_constant_fold_arithmetic_binary, fun=operator.sub), OperationType.plus: partial(_constant_fold_arithmetic_binary, fun=operator.add), diff --git a/decompiler/pipeline/controlflowanalysis/expression_simplification/rules/collapse_nested_constants.py b/decompiler/pipeline/controlflowanalysis/expression_simplification/rules/collapse_nested_constants.py index f3559daae..2f9346a72 100644 --- a/decompiler/pipeline/controlflowanalysis/expression_simplification/rules/collapse_nested_constants.py +++ b/decompiler/pipeline/controlflowanalysis/expression_simplification/rules/collapse_nested_constants.py @@ -1,11 +1,12 @@ from functools import reduce from typing import Iterator -from decompiler.pipeline.controlflowanalysis.expression_simplification.constant_folding import constant_fold +from decompiler.pipeline.controlflowanalysis.expression_simplification.constant_folding import FOLDABLE_OPERATIONS, constant_fold from decompiler.pipeline.controlflowanalysis.expression_simplification.rules.rule import SimplificationRule from decompiler.structures.pseudo import Constant, Expression, Operation, OperationType, Type from decompiler.structures.pseudo.operations import COMMUTATIVE_OPERATIONS +_COLLAPSIBLE_OPERATIONS = COMMUTATIVE_OPERATIONS & FOLDABLE_OPERATIONS class CollapseNestedConstants(SimplificationRule): """ @@ -14,7 +15,7 @@ class CollapseNestedConstants(SimplificationRule): This stage exploits associativity and is the only stage doing so. Therefore, it cannot be replaced by a combination of `TermOrder` and `CollapseConstants`. """ def apply(self, operation: Operation) -> list[tuple[Expression, Expression]]: - if operation.operation not in COMMUTATIVE_OPERATIONS: + if operation.operation not in _COLLAPSIBLE_OPERATIONS: return [] if not isinstance(operation, Operation): raise TypeError(f"Expected Operation, got {type(operation)}") diff --git a/decompiler/pipeline/controlflowanalysis/expression_simplification/rules/positive_constants.py b/decompiler/pipeline/controlflowanalysis/expression_simplification/rules/positive_constants.py index 6b358dd81..42da06986 100644 --- a/decompiler/pipeline/controlflowanalysis/expression_simplification/rules/positive_constants.py +++ b/decompiler/pipeline/controlflowanalysis/expression_simplification/rules/positive_constants.py @@ -1,6 +1,6 @@ -from decompiler.pipeline.controlflowanalysis.expression_simplification.constant_folding import normalize_int from decompiler.pipeline.controlflowanalysis.expression_simplification.rules.rule import SimplificationRule from decompiler.structures.pseudo import BinaryOperation, Constant, Expression, Integer, Operation, OperationType +from decompiler.util.integer_util import normalize_int class PositiveConstants(SimplificationRule): diff --git a/decompiler/pipeline/controlflowanalysis/loop_name_generator.py b/decompiler/pipeline/controlflowanalysis/loop_name_generator.py new file mode 100644 index 000000000..169a111d3 --- /dev/null +++ b/decompiler/pipeline/controlflowanalysis/loop_name_generator.py @@ -0,0 +1,123 @@ +from typing import List + +from decompiler.pipeline.controlflowanalysis.loop_utility_methods import ( + AstInstruction, + _find_continuation_instruction, + _get_variable_initialisation, + _requirement_without_reinitialization, + _single_defininition_reaches_node, +) +from decompiler.pipeline.stage import PipelineStage +from decompiler.structures.ast.ast_nodes import LoopNode +from decompiler.structures.ast.syntaxtree import AbstractSyntaxTree +from decompiler.structures.pseudo import Assignment, Expression, Operation, Variable +from decompiler.task import DecompilerTask + + +class WhileLoopVariableRenamer: + """Iterate over While-Loop Nodes and rename their counter variables to counter, counter1, ...""" + + def __init__(self, ast: AbstractSyntaxTree): + self._ast = ast + self._variable_counter: int = 0 + + def rename(self): + """ + Iterate over While-Loop Nodes and rename their counter variables to counter, counter1, ... + + Only rename counter variables that suffice the following conditions: + -> any variable x is used in the loop condition + -> variable x is set inside the loop body + -> single definition of variable x reaches loop entry (x is initialized/used only once) + """ + + for loop_node in self._ast.get_while_loop_nodes_topological_order(): + if loop_node.is_endless_loop: + continue + for condition_var in loop_node.get_required_variables(self._ast.condition_map): + if not (variable_init := _get_variable_initialisation(self._ast, condition_var)): + continue + if not _find_continuation_instruction(self._ast, loop_node, condition_var, renaming=True): + continue + if not _single_defininition_reaches_node(self._ast, variable_init, loop_node): + continue + self._replace_variables(loop_node, variable_init) + break + + def _replace_variables(self, loop_node: LoopNode, variable_init: AstInstruction): + """ + Rename old variable usages to counter variable in: + - variable initialization + - condition/condition map + - loop body + Also add a copy instruction if the variable is used after the loop without reinitialization. + """ + new_variable = Variable(self._get_variable_name(), variable_init.instruction.destination.type) + self._ast.replace_variable_in_subtree(loop_node, variable_init.instruction.destination, new_variable) + if _requirement_without_reinitialization(self._ast, loop_node, variable_init.instruction.destination): + self._ast.add_instructions_after(loop_node, Assignment(variable_init.instruction.destination, new_variable)) + variable_init.node.replace_variable(variable_init.instruction.destination, new_variable) + + def _get_variable_name(self) -> str: + variable_name = f"counter{self._variable_counter if self._variable_counter > 0 else ''}" + self._variable_counter += 1 + return variable_name + + +class ForLoopVariableRenamer: + """Iterate over ForLoopNodes and rename their variables to i, j, ..., i1, j1, ...""" + + def __init__(self, ast: AbstractSyntaxTree, candidates: list[str]): + self._ast = ast + self._iteration: int = 0 + self._variable_counter: int = -1 + self._candidates: list[str] = candidates + + def rename(self): + """ + Iterate over ForLoopNodes and rename their variables to i, j, k, ... + We skip renaming for loops that are not initialized in its declaration. + """ + for loop_node in self._ast.get_for_loop_nodes_topological_order(): + if not isinstance(loop_node.declaration, Assignment): + continue + + old_variable: Variable = self._get_variable_from_assignment(loop_node.declaration.destination) + new_variable = Variable(self._get_variable_name(), old_variable.type, ssa_name=old_variable.ssa_name) + self._ast.replace_variable_in_subtree(loop_node, old_variable, new_variable) + + if _requirement_without_reinitialization(self._ast, loop_node, old_variable): + self._ast.add_instructions_after(loop_node, Assignment(old_variable, new_variable)) + + def _get_variable_name(self) -> str: + """Return variable names in the form of [i, j, ..., i1, j1, ...]""" + self._variable_counter += 1 + if self._variable_counter >= len(self._candidates): + self._variable_counter = 0 + self._iteration += 1 + return f"{self._candidates[self._variable_counter]}{self._iteration if self._iteration > 0 else ''}" + + def _get_variable_from_assignment(self, expr: Expression) -> Variable: + if isinstance(expr, Variable): + return expr + if isinstance(expr, Operation) and len(expr.operands) == 1: + return expr.operands[0] + raise ValueError("Did not expect a Constant/Unknown/Operation with more then 1 operand as a ForLoop declaration") + + +class LoopNameGenerator(PipelineStage): + """ + Stage which renames while/for-loops to custom names. + """ + + name = "loop-name-generator" + + def run(self, task: DecompilerTask): + rename_while_loops: bool = task.options.getboolean("loop-name-generator.rename_while_loop_variables", fallback=False) + for_loop_names: List[str] = task.options.getlist("loop-name-generator.for_loop_variable_names", fallback=[]) + + if rename_while_loops: + WhileLoopVariableRenamer(task._ast).rename() + + if for_loop_names: + ForLoopVariableRenamer(task._ast, for_loop_names).rename() diff --git a/decompiler/pipeline/controlflowanalysis/loop_utility_methods.py b/decompiler/pipeline/controlflowanalysis/loop_utility_methods.py new file mode 100644 index 000000000..6a2f7a4bc --- /dev/null +++ b/decompiler/pipeline/controlflowanalysis/loop_utility_methods.py @@ -0,0 +1,211 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import Dict, Optional + +from decompiler.structures.ast.ast_nodes import AbstractSyntaxTreeNode, CaseNode, CodeNode, ConditionNode, LoopNode, SeqNode, SwitchNode +from decompiler.structures.ast.syntaxtree import AbstractSyntaxTree +from decompiler.structures.logic.logic_condition import LogicCondition +from decompiler.structures.pseudo import Assignment, Condition, Variable +from decompiler.structures.visitors.assignment_visitor import AssignmentVisitor + + +@dataclass +class AstInstruction: + instruction: Assignment + position: int + node: CodeNode + +def _is_single_instruction_loop_node(loop_node: LoopNode) -> bool: + """ + Check if the loop body contains only one instruction. + + :param loop_node: LoopNode with a body + :return: True if body contains only one instruction else False + """ + body: AbstractSyntaxTreeNode = loop_node.body + if isinstance(body, CodeNode): + return len(body.instructions) == 1 + if isinstance(body, LoopNode): + return _is_single_instruction_loop_node(body) + if isinstance(body, (SeqNode, SwitchNode)): + return False + return False + + +def _has_deep_requirement(condition_map: Dict[LogicCondition, Condition], node: AbstractSyntaxTreeNode, variable: Variable) -> bool: + """ + Check if a variable is required in a node or any of its children. + + :param condition_map: logic condition to condition mapping + :param node: start node + :param variable: requirement to search for + :return: True if a requirement was found, else False + """ + if node is None: + return False + + if variable in node.get_required_variables(condition_map): + return True + + if isinstance(node, (SeqNode, SwitchNode, CaseNode)): + return any([_has_deep_requirement(condition_map, child, variable) for child in node.children]) + elif isinstance(node, ConditionNode): + return any( + [ + _has_deep_requirement(condition_map, node.true_branch_child, variable), + _has_deep_requirement(condition_map, node.false_branch_child, variable), + ] + ) + elif isinstance(node, LoopNode): + return _has_deep_requirement(condition_map, node.body, variable) + + +def _get_last_definition_index_of(node: CodeNode, variable: Variable) -> int: + """ + Iterate over CodeNode returning the index of last assignment to variable. + + :param node: node in which to search for last definition of variable + :param variable: check if definition contains this variable + :return: index of last definition or -1 if not found + """ + candidate = -1 + for position, instr in enumerate(node.instructions): + if variable in instr.definitions: + candidate = position + return candidate + + +def _get_last_requirement_index_of(node: CodeNode, variable: Variable) -> int: + """ + Iterate over CodeNode returning the index of last instruction using variable. + + :param node: node in which to search for last requirement of variable + :param variable: check if requirements contains this variable + :return: index of last definition or -1 if not found + """ + candidate = -1 + for position, instr in enumerate(node.instructions): + if variable in instr.requirements: + candidate = position + return candidate + + +def _find_continuation_instruction( + ast: AbstractSyntaxTree, node: AbstractSyntaxTreeNode, variable: Variable, renaming: bool = False +) -> Optional[AstInstruction]: + """ + Find a valid continuation instruction for a given variable inside a node. A valid continuation instruction defines the variable without + having requirements in later instructions. + + If we only want to rename the continuation instruction (instead of converting a while to a for-loop) we can additionally look at + switch / case nodes. + + :param node: node in which to search for last definition + :param variable: search instruction defining variable + :param renaming: continuation assignment for renaming purposes only + :return: AstInstruction if a definition without later requirement was found, else None + """ + iter_types = (SeqNode, SwitchNode) if renaming else SeqNode + if isinstance(node, iter_types): + for child in node.children[::-1]: + if instruction := _find_continuation_instruction(ast, child, variable, renaming): + return instruction + elif _has_deep_requirement(ast.condition_map, child, variable): + return None + elif renaming and isinstance(node, CaseNode): + return _find_continuation_instruction(ast, node.child, variable, renaming) + elif isinstance(node, LoopNode): + return _find_continuation_instruction(ast, node.body, variable, renaming) + elif isinstance(node, CodeNode): + last_req_index = _get_last_requirement_index_of(node, variable) + last_def_index = _get_last_definition_index_of(node, variable) + if last_req_index <= last_def_index != -1: + return AstInstruction(node.instructions[last_def_index], last_def_index, node) + + +def _get_variable_initialisation(ast: AbstractSyntaxTree, variable: Variable) -> Optional[AstInstruction]: + """ + Iterates over CodeNodes returning the first definition of variable. + + :param ast: AbstractSyntaxTree to search in + :param variable: find initialization of this variable + """ + for code_node in ast.get_code_nodes_topological_order(): + for position, instruction in enumerate(code_node.instructions): + if variable in instruction.definitions: + return AstInstruction(instruction, position, code_node) + + +def _single_defininition_reaches_node(ast: AbstractSyntaxTree, variable_init: AstInstruction, target_node: AbstractSyntaxTreeNode) -> bool: + """ + Check if a variable initialisation is redefined or used before target node. + + If we did not find the target node on the way down we still can assume there was no redefinition or usage. + + :param ast: AbstractSyntaxTree to search in + :param variable_init: AstInstruction containing the first variable initialisation + :param target_node: Search for redefinition or usages until this node is reached + """ + for ast_node in ast.get_reachable_nodes_pre_order(variable_init.node): + if ast_node is target_node: + return True + + defined_vars = list(ast_node.get_defined_variables(ast.condition_map)) + required_vars = list(ast_node.get_required_variables(ast.condition_map)) + used_variables = defined_vars + required_vars + + if ast_node is variable_init.node: + if used_variables.count(variable_init.instruction.destination) > 1: + return False + elif variable_init.instruction.destination in used_variables: + return False + return True + + +def _initialization_reaches_loop_node(init_node: AbstractSyntaxTreeNode, usage_node: AbstractSyntaxTreeNode) -> bool: + """ + Check if init node always reaches the usage node + + This is not the case if: + - nodes are separated by a LoopNode + - init-nodes parent is not a sequence node or not on a path from root to usage-node (only initialized under certain conditions) + + :param init_node: node where initialization takes place + :param usage_node: node that is potentially inside a LoopNode + :return: True if init and usage node are separated by a LoopNode else False + """ + init_parent = init_node.parent + iter_parent = usage_node.parent + if not isinstance(init_parent, SeqNode): + return False + while iter_parent is not init_parent: + if isinstance(iter_parent, LoopNode): + return False + iter_parent = iter_parent.parent + if iter_parent is None: + return False + return True + + +def _requirement_without_reinitialization(ast: AbstractSyntaxTree, node: AbstractSyntaxTreeNode, variable: Variable) -> bool: + """ + Check if a variable is used without prior initialization starting at a given node. + Edge case: definition and requirement in same instruction + + :param ast: + :param node: + :param variable: + :return: True if has requirement that is not prior reinitialized else False + """ + + for ast_node in ast.get_reachable_nodes_pre_order(node): + assignment_visitor = AssignmentVisitor() + assignment_visitor.visit(ast_node) + for assignment in assignment_visitor.assignments: + if variable in assignment.definitions and variable not in assignment.requirements: + return False + elif variable in assignment.definitions and variable in assignment.requirements: + return True + elif variable in assignment.requirements: + return True \ No newline at end of file diff --git a/decompiler/pipeline/controlflowanalysis/readability_based_refinement.py b/decompiler/pipeline/controlflowanalysis/readability_based_refinement.py index 048d4c098..c48621eb3 100644 --- a/decompiler/pipeline/controlflowanalysis/readability_based_refinement.py +++ b/decompiler/pipeline/controlflowanalysis/readability_based_refinement.py @@ -1,225 +1,24 @@ -"""Module implementing various readbility based refinements.""" +"""Module implementing various readability based refinements.""" from __future__ import annotations -from dataclasses import dataclass -from typing import Dict, Optional, Union +from typing import Union -from decompiler.pipeline.stage import PipelineStage -from decompiler.structures.ast.ast_nodes import ( - AbstractSyntaxTreeNode, - CaseNode, - CodeNode, - ConditionNode, - DoWhileLoopNode, - ForLoopNode, - LoopNode, - SeqNode, - SwitchNode, - WhileLoopNode, +from decompiler.pipeline.controlflowanalysis.loop_utility_methods import ( + AstInstruction, + _find_continuation_instruction, + _get_variable_initialisation, + _initialization_reaches_loop_node, + _is_single_instruction_loop_node, + _single_defininition_reaches_node, ) +from decompiler.pipeline.stage import PipelineStage +from decompiler.structures.ast.ast_nodes import ConditionNode, DoWhileLoopNode, ForLoopNode, WhileLoopNode from decompiler.structures.ast.syntaxtree import AbstractSyntaxTree -from decompiler.structures.logic.logic_condition import LogicCondition -from decompiler.structures.pseudo import Assignment, Condition, Expression, Operation, Variable -from decompiler.structures.visitors.assignment_visitor import AssignmentVisitor +from decompiler.structures.pseudo import Assignment from decompiler.task import DecompilerTask from decompiler.util.options import Options -def _is_single_instruction_loop_node(loop_node: LoopNode) -> bool: - """ - Check if the loop body contains only one instruction. - - :param loop_node: LoopNode with a body - :return: True if body contains only one instruction else False - """ - body: AbstractSyntaxTreeNode = loop_node.body - if isinstance(body, CodeNode): - return len(body.instructions) == 1 - if isinstance(body, LoopNode): - return _is_single_instruction_loop_node(body) - if isinstance(body, (SeqNode, SwitchNode)): - return False - return False - - -def _has_deep_requirement(condition_map: Dict[LogicCondition, Condition], node: AbstractSyntaxTreeNode, variable: Variable) -> bool: - """ - Check if a variable is required in a node or any of its children. - - :param condition_map: logic condition to condition mapping - :param node: start node - :param variable: requirement to search for - :return: True if a requirement was found, else False - """ - if node is None: - return False - - if variable in node.get_required_variables(condition_map): - return True - - if isinstance(node, (SeqNode, SwitchNode, CaseNode)): - return any([_has_deep_requirement(condition_map, child, variable) for child in node.children]) - elif isinstance(node, ConditionNode): - return any( - [ - _has_deep_requirement(condition_map, node.true_branch_child, variable), - _has_deep_requirement(condition_map, node.false_branch_child, variable), - ] - ) - elif isinstance(node, LoopNode): - return _has_deep_requirement(condition_map, node.body, variable) - - -def _get_last_definition_index_of(node: CodeNode, variable: Variable) -> int: - """ - Iterate over CodeNode returning the index of last assignment to variable. - - :param node: node in which to search for last definition of variable - :param variable: check if definition contains this variable - :return: index of last definition or -1 if not found - """ - candidate = -1 - for position, instr in enumerate(node.instructions): - if variable in instr.definitions: - candidate = position - return candidate - - -def _get_last_requirement_index_of(node: CodeNode, variable: Variable) -> int: - """ - Iterate over CodeNode returning the index of last instruction using variable. - - :param node: node in which to search for last requirement of variable - :param variable: check if requirements contains this variable - :return: index of last definition or -1 if not found - """ - candidate = -1 - for position, instr in enumerate(node.instructions): - if variable in instr.requirements: - candidate = position - return candidate - - -def _find_continuation_instruction( - ast: AbstractSyntaxTree, node: AbstractSyntaxTreeNode, variable: Variable, renaming: bool = False -) -> Optional[AstInstruction]: - """ - Find a valid continuation instruction for a given variable inside a node. A valid continuation instruction defines the variable without - having requirements in later instructions. - - If we only want to rename the continuation instruction (instead of converting a while to a for-loop) we can additionally look at - switch / case nodes. - - :param node: node in which to search for last definition - :param variable: search instruction defining variable - :param renaming: continuation assignment for renaming purposes only - :return: AstInstruction if a definition without later requirement was found, else None - """ - iter_types = (SeqNode, SwitchNode) if renaming else SeqNode - if isinstance(node, iter_types): - for child in node.children[::-1]: - if instruction := _find_continuation_instruction(ast, child, variable, renaming): - return instruction - elif _has_deep_requirement(ast.condition_map, child, variable): - return None - elif renaming and isinstance(node, CaseNode): - return _find_continuation_instruction(ast, node.child, variable, renaming) - elif isinstance(node, LoopNode): - return _find_continuation_instruction(ast, node.body, variable, renaming) - elif isinstance(node, CodeNode): - last_req_index = _get_last_requirement_index_of(node, variable) - last_def_index = _get_last_definition_index_of(node, variable) - if last_req_index <= last_def_index != -1: - return AstInstruction(node.instructions[last_def_index], last_def_index, node) - - -def _get_variable_initialisation(ast: AbstractSyntaxTree, variable: Variable) -> Optional[AstInstruction]: - """ - Iterates over CodeNodes returning the first definition of variable. - - :param ast: AbstractSyntaxTree to search in - :param variable: find initialization of this variable - """ - for code_node in ast.get_code_nodes_topological_order(): - for position, instruction in enumerate(code_node.instructions): - if variable in instruction.definitions: - return AstInstruction(instruction, position, code_node) - - -def _single_defininition_reaches_node(ast: AbstractSyntaxTree, variable_init: AstInstruction, target_node: AbstractSyntaxTreeNode) -> bool: - """ - Check if a variable initialisation is redefined or used before target node. - - If we did not find the target node on the way down we still can assume there was no redefinition or usage. - - :param ast: AbstractSyntaxTree to search in - :param variable_init: AstInstruction containing the first variable initialisation - :param target_node: Search for redefinition or usages until this node is reached - """ - for ast_node in ast.get_reachable_nodes_pre_order(variable_init.node): - if ast_node is target_node: - return True - - defined_vars = list(ast_node.get_defined_variables(ast.condition_map)) - required_vars = list(ast_node.get_required_variables(ast.condition_map)) - used_variables = defined_vars + required_vars - - if ast_node is variable_init.node: - if used_variables.count(variable_init.instruction.destination) > 1: - return False - elif variable_init.instruction.destination in used_variables: - return False - return True - - -def _initialization_reaches_loop_node(init_node: AbstractSyntaxTreeNode, usage_node: AbstractSyntaxTreeNode) -> bool: - """ - Check if init node always reaches the usage node - - This is not the case if: - - nodes are separated by a LoopNode - - init-nodes parent is not a sequence node or not on a path from root to usage-node (only initialized under certain conditions) - - :param init_node: node where initialization takes place - :param usage_node: node that is potentially inside a LoopNode - :return: True if init and usage node are separated by a LoopNode else False - """ - init_parent = init_node.parent - iter_parent = usage_node.parent - if not isinstance(init_parent, SeqNode): - return False - while iter_parent is not init_parent: - if isinstance(iter_parent, LoopNode): - return False - iter_parent = iter_parent.parent - if iter_parent is None: - return False - return True - - -def _requirement_without_reinitialization(ast: AbstractSyntaxTree, node: AbstractSyntaxTreeNode, variable: Variable) -> bool: - """ - Check if a variable is used without prior initialization starting at a given node. - Edge case: definition and requirement in same instruction - - :param ast: - :param node: - :param variable: - :return: True if has requirement that is not prior reinitialized else False - """ - - for ast_node in ast.get_reachable_nodes_pre_order(node): - assignment_visitor = AssignmentVisitor() - assignment_visitor.visit(ast_node) - for assignment in assignment_visitor.assignments: - if variable in assignment.definitions and variable not in assignment.requirements: - return False - elif variable in assignment.definitions and variable in assignment.requirements: - return True - elif variable in assignment.requirements: - return True - - def _get_potential_guarded_do_while_loops(ast: AbstractSyntaxTree) -> tuple(Union[DoWhileLoopNode, WhileLoopNode], ConditionNode): for loop_node in list(ast.get_loop_nodes_post_order()): if isinstance(loop_node, DoWhileLoopNode) and isinstance(loop_node.parent.parent, ConditionNode): @@ -227,11 +26,11 @@ def _get_potential_guarded_do_while_loops(ast: AbstractSyntaxTree) -> tuple(Unio def remove_guarded_do_while(ast: AbstractSyntaxTree): - """ Removes a if statement which guards a do-while loop/while loop when: - -> there is nothing in between the if-node and the do-while-node/while-node - -> the if-node has only one branch (true branch) - -> the condition of the branch is the same as the condition of the do-while-node - Replacement is a WhileLoop, otherwise the control flow would not be correct + """Removes a if statement which guards a do-while loop/while loop when: + -> there is nothing in between the if-node and the do-while-node/while-node + -> the if-node has only one branch (true branch) + -> the condition of the branch is the same as the condition of the do-while-node + Replacement is a WhileLoop, otherwise the control flow would not be correct """ for do_while_node, condition_node in _get_potential_guarded_do_while_loops(ast): if condition_node.false_branch: @@ -242,49 +41,53 @@ def remove_guarded_do_while(ast: AbstractSyntaxTree): ast.substitute_loop_node(do_while_node, WhileLoopNode(do_while_node.condition, do_while_node.reaching_condition)) -@dataclass -class AstInstruction: - instruction: Assignment - position: int - node: CodeNode - - class WhileLoopReplacer: """Convert WhileLoopNodes to ForLoopNodes depending on the configuration. - -> keep_empty_for_loops will keep empty for-loops in the code - -> force_for_loops will transform every while-loop into a for-loop, worst case with empty declaration/modification statement - -> forbidden_condition_types_in_simple_for_loops will not transform trivial for-loop candidates (with only one condition) into for-loops - if the operator matches one of the forbidden operator list - -> max_condition_complexity_for_loop_recovery will transform for-loop candidates only into for-loops if the condition complexity is - less/equal then the threshold - -> max_modification_complexity_for_loop_recovery will transform for-loop candidates only into for-loops if the modification complexity is - less/equal then the threshold + -> keep_empty_for_loops will keep empty for-loops in the code + -> force_for_loops will transform every while-loop into a for-loop, worst case with empty declaration/modification statement + -> forbidden_condition_types_in_simple_for_loops will not transform trivial for-loop candidates (with only one condition) into for-loops + if the operator matches one of the forbidden operator list + -> max_condition_complexity_for_loop_recovery will transform for-loop candidates only into for-loops if the condition complexity is + less/equal then the threshold + -> max_modification_complexity_for_loop_recovery will transform for-loop candidates only into for-loops if the modification complexity is + less/equal then the threshold """ def __init__(self, ast: AbstractSyntaxTree, options: Options): self._ast = ast + self._restructure_for_loops = options.getboolean("readability-based-refinement.restructure_for_loops", fallback=True) self._keep_empty_for_loops = options.getboolean("readability-based-refinement.keep_empty_for_loops", fallback=False) self._hide_non_init_decl = options.getboolean("readability-based-refinement.hide_non_initializing_declaration", fallback=False) self._force_for_loops = options.getboolean("readability-based-refinement.force_for_loops", fallback=False) - self._forbidden_condition_types = options.getlist("readability-based-refinement.forbidden_condition_types_in_simple_for_loops", fallback=[]) - self._condition_max_complexity = options.getint("readability-based-refinement.max_condition_complexity_for_loop_recovery", fallback=100) - self._modification_max_complexity = options.getint("readability-based-refinement.max_modification_complexity_for_loop_recovery", fallback=100) + self._forbidden_condition_types = options.getlist( + "readability-based-refinement.forbidden_condition_types_in_simple_for_loops", fallback=[] + ) + self._condition_max_complexity = options.getint( + "readability-based-refinement.max_condition_complexity_for_loop_recovery", fallback=100 + ) + self._modification_max_complexity = options.getint( + "readability-based-refinement.max_modification_complexity_for_loop_recovery", fallback=100 + ) def run(self): """For each WhileLoop in AST check the following conditions: -> any variable in loop condition has a valid continuation instruction in loop body -> variable is initialized - -> loop condition complexity < condition complexity + -> loop condition complexity < condition complexity -> possible modification complexity < modification complexity -> if condition is only a symbol: check condition type for allowed one - - If 'force_for_loops' is enabled, the complexity options are ignored and every while loop after the - initial transformation will be forced into a for loop with an empty declaration/modification - """ + If 'force_for_loops' is enabled, the complexity options are ignored and every while loop after the + initial transformation will be forced into a for loop with an empty declaration/modification + """ + if not self._restructure_for_loops: + return for loop_node in list(self._ast.get_while_loop_nodes_topological_order()): - if loop_node.is_endless_loop or (not self._keep_empty_for_loops and _is_single_instruction_loop_node(loop_node)) \ - or self._invalid_simple_for_loop_condition_type(loop_node.condition): + if ( + loop_node.is_endless_loop + or (not self._keep_empty_for_loops and _is_single_instruction_loop_node(loop_node)) + or self._invalid_simple_for_loop_condition_type(loop_node.condition) + ): continue if any(node.does_end_with_continue for node in loop_node.body.get_descendant_code_nodes_interrupting_ancestor_loop()): @@ -308,11 +111,11 @@ def run(self): self._ast.substitute_loop_node( loop_node, ForLoopNode( - declaration=None, - condition=loop_node.condition, - modification=None, - reaching_condition=loop_node.reaching_condition, - ) + declaration=None, + condition=loop_node.condition, + modification=None, + reaching_condition=loop_node.reaching_condition, + ), ) def _replace_with_for_loop(self, loop_node: WhileLoopNode, continuation: AstInstruction, init: AstInstruction): @@ -347,9 +150,9 @@ def _replace_with_for_loop(self, loop_node: WhileLoopNode, continuation: AstInst ) continuation.node.instructions.remove(continuation.instruction) self._ast.clean_up() - + def _invalid_simple_for_loop_condition_type(self, logic_condition) -> bool: - """ Checks if a logic condition is only a symbol, if true checks condition type of symbol for forbidden ones""" + """Checks if a logic condition is only a symbol, if true checks condition type of symbol for forbidden ones""" if not logic_condition.is_symbol or not self._forbidden_condition_types: return False @@ -363,105 +166,13 @@ def _invalid_simple_for_loop_condition_type(self, logic_condition) -> bool: return False -class WhileLoopVariableRenamer: - """Iterate over While-Loop Nodes and rename their counter variables to counter, counter1, ...""" - - def __init__(self, ast: AbstractSyntaxTree): - self._ast = ast - self._variable_counter: int = 0 - - def rename(self): - """ - Iterate over While-Loop Nodes and rename their counter variables to counter, counter1, ... - - Only rename counter variables that suffice the following conditions: - -> any variable x is used in the loop condition - -> variable x is set inside the loop body - -> single definition of variable x reaches loop entry (x is initialized/used only once) - """ - - for loop_node in self._ast.get_while_loop_nodes_topological_order(): - if loop_node.is_endless_loop: - continue - for condition_var in loop_node.get_required_variables(self._ast.condition_map): - if not (variable_init := _get_variable_initialisation(self._ast, condition_var)): - continue - if not _find_continuation_instruction(self._ast, loop_node, condition_var, renaming=True): - continue - if not _single_defininition_reaches_node(self._ast, variable_init, loop_node): - continue - self._replace_variables(loop_node, variable_init) - break - - def _replace_variables(self, loop_node: LoopNode, variable_init: AstInstruction): - """ - Rename old variable usages to counter variable in: - - variable initialization - - condition/condition map - - loop body - Also add a copy instruction if the variable is used after the loop without reinitialization. - """ - new_variable = Variable(self._get_variable_name(), variable_init.instruction.destination.type) - self._ast.replace_variable_in_subtree(loop_node, variable_init.instruction.destination, new_variable) - if _requirement_without_reinitialization(self._ast, loop_node, variable_init.instruction.destination): - self._ast.add_instructions_after(loop_node, Assignment(variable_init.instruction.destination, new_variable)) - variable_init.node.replace_variable(variable_init.instruction.destination, new_variable) - - def _get_variable_name(self) -> str: - variable_name = f"counter{self._variable_counter if self._variable_counter > 0 else ''}" - self._variable_counter += 1 - return variable_name - - -class ForLoopVariableRenamer: - """Iterate over ForLoopNodes and rename their variables to i, j, ..., i1, j1, ...""" - - def __init__(self, ast: AbstractSyntaxTree, candidates: list[str]): - self._ast = ast - self._iteration: int = 0 - self._variable_counter: int = -1 - self._candidates: list[str] = candidates - - def rename(self): - """ - Iterate over ForLoopNodes and rename their variables to i, j, k, ... - We skip renaming for loops that are not initialized in its declaration. - """ - for loop_node in self._ast.get_for_loop_nodes_topological_order(): - if not isinstance(loop_node.declaration, Assignment): - continue - - old_variable: Variable = self._get_variable_from_assignment(loop_node.declaration.destination) - new_variable = Variable(self._get_variable_name(), old_variable.type, ssa_name=old_variable.ssa_name) - self._ast.replace_variable_in_subtree(loop_node, old_variable, new_variable) - loop_node.declaration.value.substitute(new_variable, old_variable) - - if _requirement_without_reinitialization(self._ast, loop_node, old_variable): - self._ast.add_instructions_after(loop_node, Assignment(old_variable, new_variable)) - - def _get_variable_name(self) -> str: - """Return variable names in the form of [i, j, ..., i1, j1, ...]""" - self._variable_counter += 1 - if self._variable_counter >= len(self._candidates): - self._variable_counter = 0 - self._iteration += 1 - return f"{self._candidates[self._variable_counter]}{self._iteration if self._iteration > 0 else ''}" - - def _get_variable_from_assignment(self, expr: Expression) -> Variable: - if isinstance(expr, Variable): - return expr - if isinstance(expr, Operation) and len(expr.operands) == 1: - return expr.operands[0] - raise ValueError("Did not expect a Constant/Unknown/Operation with more then 1 operand as a ForLoop declaration") - class ReadabilityBasedRefinement(PipelineStage): """ The ReadabilityBasedRefinement makes various transformations to improve readability based on the AST. Currently implemented transformations: - 1. while-loop to for-loop transformation - 2. for-loop variable renaming (e.g i, j, k, ...) - 3. while-loop variable renaming (e.g. counter, counter1, ...) + 1. remove guarded do while loops + 2. while-loop to for-loop transformation The AST is cleaned up before the first transformation and after every while- to for-loop transformation. """ @@ -472,12 +183,4 @@ def run(self, task: DecompilerTask): task.syntax_tree.clean_up() remove_guarded_do_while(task.syntax_tree) - WhileLoopReplacer(task.syntax_tree, task.options).run() - - variableNames = task.options.getlist("readability-based-refinement.for_loop_variable_names", fallback=[]) - if variableNames: - ForLoopVariableRenamer(task.syntax_tree, variableNames).rename() - - if task.options.getboolean("readability-based-refinement.rename_while_loop_variables"): - WhileLoopVariableRenamer(task.syntax_tree).rename() \ No newline at end of file diff --git a/decompiler/pipeline/default.py b/decompiler/pipeline/default.py index 7266e0596..c33a6c9ff 100644 --- a/decompiler/pipeline/default.py +++ b/decompiler/pipeline/default.py @@ -4,6 +4,7 @@ ExpressionSimplificationAst, ExpressionSimplificationCfg, InstructionLengthHandler, + LoopNameGenerator, ReadabilityBasedRefinement, VariableNameGeneration, ) @@ -47,5 +48,6 @@ ReadabilityBasedRefinement, ExpressionSimplificationAst, InstructionLengthHandler, - VariableNameGeneration + VariableNameGeneration, + LoopNameGenerator ] diff --git a/decompiler/pipeline/pipeline.py b/decompiler/pipeline/pipeline.py index 37fe1f7ed..829efa2ce 100644 --- a/decompiler/pipeline/pipeline.py +++ b/decompiler/pipeline/pipeline.py @@ -19,6 +19,9 @@ from decompiler.task import DecompilerTask from decompiler.util.decoration import DecoratedAST, DecoratedCFG +from ..structures.ast.ast_nodes import CodeNode +from ..structures.ast.syntaxtree import AbstractSyntaxTree +from ..structures.pseudo import Instruction from .default import AST_STAGES, CFG_STAGES from .stage import PipelineStage @@ -86,6 +89,7 @@ def run(self, task: DecompilerTask): print_ascii = output_format == "ascii" or output_format == "ascii_and_tabs" show_in_tabs = output_format == "tabs" or output_format == "ascii_and_tabs" debug_mode = task.options.getboolean("pipeline.debug", fallback=False) + validate_no_dataflow_dup = task.options.getboolean("pipeline.validate_no_dataflow_dup", fallback=False) self.validate() @@ -109,6 +113,12 @@ def run(self, task: DecompilerTask): raise e break + if validate_no_dataflow_dup: + if task.graph is not None: + self._assert_no_dataflow_duplicates(list(task.graph.instructions)) + if task.syntax_tree is not None: + self._assert_no_ast_duplicates(task.syntax_tree) + @staticmethod def _show_stage(task: DecompilerTask, stage_name: str, print_ascii: bool, show_in_tabs: bool): """Based on the task either an AST or a CFG is shown on the console (ASCII) and/or in BinaryNinja (FlowGraph) tabs.""" @@ -122,3 +132,23 @@ def _show_stage(task: DecompilerTask, stage_name: str, print_ascii: bool, show_i DecoratedCFG.print_ascii(task.graph, stage_name) if show_in_tabs: DecoratedCFG.show_flowgraph(task.graph, stage_name) + + @staticmethod + def _assert_no_ast_duplicates(ast: AbstractSyntaxTree): + instructions = [] + for node in ast.topological_order(): + if isinstance(node, CodeNode): + instructions.extend(node.instructions) + + DecompilerPipeline._assert_no_dataflow_duplicates(instructions) + + @staticmethod + def _assert_no_dataflow_duplicates(instructions: list[Instruction]): + encountered_ids: set[int] = set() + + for instruction in instructions: + for obj in instruction.subexpressions(): + if id(obj) in encountered_ids: + raise AssertionError(f"Found duplicated DataflowObject in cfg: {obj}") + + encountered_ids.add(id(obj)) diff --git a/decompiler/pipeline/preprocessing/missing_definitions.py b/decompiler/pipeline/preprocessing/missing_definitions.py index a322a8a05..d9c93a4a4 100644 --- a/decompiler/pipeline/preprocessing/missing_definitions.py +++ b/decompiler/pipeline/preprocessing/missing_definitions.py @@ -67,9 +67,19 @@ def get_smallest_label_copy(self, variable: Union[str, Variable]): return self._sorted_copies_of[variable][0] return min(self._copies_of_variable[variable], key=lambda var: var.ssa_label) + def _check_duplicated(self, var_name: str): + """ + Due to mixing of ssa_labels and memory versions, it can happen that we have duplicates in the copy pool. + E.g., [edx#0, edx#0, ...] + """ + ssa_labels = [var.ssa_label for var in self._sorted_copies_of[var_name]] + if any(i == j for i, j in zip(ssa_labels, ssa_labels[1:])): + raise ValueError(f"duplicate entries in copy pool for {var_name}") + def possible_missing_definitions_for(self, variable: Union[str, Variable]) -> List[Variable]: """Returns all variables whose definition may be missing because it is not the first in the order.""" var_name = self._get_variable_name(variable) + self._check_duplicated(var_name) return self._sorted_copies_of[var_name][1:] @staticmethod diff --git a/decompiler/structures/pseudo/complextypes.py b/decompiler/structures/pseudo/complextypes.py index b32528b4a..764143fa2 100644 --- a/decompiler/structures/pseudo/complextypes.py +++ b/decompiler/structures/pseudo/complextypes.py @@ -2,7 +2,7 @@ import logging from dataclasses import dataclass, field from enum import Enum -from typing import Dict, List, Optional +from typing import Any, Dict, List, Optional from decompiler.structures.pseudo.typing import Type @@ -28,6 +28,10 @@ def copy(self, **kwargs) -> Type: def declaration(self) -> str: raise NotImplementedError + @property + def complex_type_name(self): + return ComplexTypeName(0, self.name) + @dataclass(frozen=True, order=True) class ComplexTypeMember(ComplexType): @@ -54,16 +58,16 @@ def declaration(self) -> str: @dataclass(frozen=True, order=True) -class Struct(ComplexType): +class _BaseStruct(ComplexType): """Class representing a struct type.""" members: Dict[int, ComplexTypeMember] = field(compare=False) - type_specifier: ComplexTypeSpecifier = ComplexTypeSpecifier.STRUCT + type_specifier: ComplexTypeSpecifier def add_member(self, member: ComplexTypeMember): self.members[member.offset] = member - def get_member_by_offset(self, offset: int) -> ComplexTypeMember: + def get_member_by_offset(self, offset: int) -> Optional[ComplexTypeMember]: return self.members.get(offset) def declaration(self) -> str: @@ -71,6 +75,16 @@ def declaration(self) -> str: return f"{self.type_specifier.value} {self.name} {{\n\t{members}\n}}" +@dataclass(frozen=True, order=True) +class Struct(_BaseStruct): + type_specifier: ComplexTypeSpecifier = ComplexTypeSpecifier.STRUCT + + +@dataclass(frozen=True, order=True) +class Class(_BaseStruct): + type_specifier: ComplexTypeSpecifier = ComplexTypeSpecifier.CLASS + + @dataclass(frozen=True, order=True) class Union(ComplexType): members: List[ComplexTypeMember] = field(compare=False) @@ -98,8 +112,9 @@ class Enum(ComplexType): def add_member(self, member: ComplexTypeMember): self.members[member.value] = member - def get_name_by_value(self, value: int) -> str: - return self.members.get(value).name + def get_name_by_value(self, value: int) -> Optional[str]: + member = self.members.get(value) + return member.name if member is not None else None def declaration(self) -> str: members = ",\n\t".join(f"{x.name} = {x.value}" for x in self.members.values()) @@ -117,19 +132,46 @@ def __str__(self) -> str: return self.name +class UniqueNameProvider: + """ The purpose of this class is to provide unique names for types, as duplicate names can potentially be encountered in the lifting stage (especially anonymous structs, etc.) + This class keeps track of all the names already used. If duplicates are found, they are renamed by appending suffixes with incrementing numbers. + E.g. `classname`, `classname__2`, `classname__3`, ... + """ + + def __init__(self): + self._name_to_count: Dict[str, int] = {} + + def get_unique_name(self, name: str) -> str: + """ This method returns the input name if it was unique so far. + Otherwise it returns the name with an added incrementing suffix. + In any case, the name occurence of the name is counted. + """ + if name not in self._name_to_count: + self._name_to_count[name] = 1 + return name + else: + self._name_to_count[name] += 1 + return f"{name}__{self._name_to_count[name]}" + + class ComplexTypeMap: """A class in charge of storing complex custom/user defined types by their string representation""" def __init__(self): self._name_to_type_map: Dict[ComplexTypeName, ComplexType] = {} + self._id_to_type_map: Dict[int, ComplexType] = {} - def retrieve_by_name(self, typename: ComplexTypeName) -> ComplexType: + def retrieve_by_name(self, typename: ComplexTypeName) -> Optional[ComplexType]: """Get complex type by name; used to avoid recursion.""" return self._name_to_type_map.get(typename, None) - def add(self, complex_type: ComplexType): + def retrieve_by_id(self, id: int) -> Optional[ComplexType]: + return self._id_to_type_map.get(id, None) + + def add(self, complex_type: ComplexType, type_id: int): """Add complex type to the mapping.""" - self._name_to_type_map[ComplexTypeName(0, complex_type.name)] = complex_type + self._id_to_type_map[type_id] = complex_type + self._name_to_type_map[complex_type.complex_type_name] = complex_type def pretty_print(self): for t in self._name_to_type_map.values(): diff --git a/decompiler/structures/pseudo/expressions.py b/decompiler/structures/pseudo/expressions.py index 41ef63f27..582af3002 100644 --- a/decompiler/structures/pseudo/expressions.py +++ b/decompiler/structures/pseudo/expressions.py @@ -193,7 +193,10 @@ def __str__(self) -> str: Constants of type Enum are represented as strings (corresponding enumerator identifiers). """ if isinstance(self._type, Enum): - return self._type.get_name_by_value(self.value) + name = self._type.get_name_by_value(self.value) + if name is not None: + return name + # otherwise, i.e. if value is not found in Enum class, fall through if self._type.is_boolean: return "true" if self.value else "false" if isinstance(self.value, str): diff --git a/decompiler/structures/pseudo/instructions.py b/decompiler/structures/pseudo/instructions.py index 6db63f018..1d89e76e0 100644 --- a/decompiler/structures/pseudo/instructions.py +++ b/decompiler/structures/pseudo/instructions.py @@ -142,7 +142,7 @@ def writes_memory(self) -> Optional[int]: """Return the memory version generated by this assignment, if any.""" if isinstance(self.value, Call): return self.value.writes_memory - if isinstance(self.destination, UnaryOperation) and self.destination.operation == OperationType.dereference: + if isinstance(self.destination, UnaryOperation) and self.destination.operation in {OperationType.member_access, OperationType.dereference}: return self.destination.writes_memory for variable in self.definitions: if variable.is_aliased: diff --git a/decompiler/task.py b/decompiler/task.py index 38f149a7b..d476e71ec 100644 --- a/decompiler/task.py +++ b/decompiler/task.py @@ -39,6 +39,7 @@ def __init__( self._failed = False self._failure_origin = None self._complex_types = complex_types if complex_types else ComplexTypeMap() + self._code = None @property def name(self) -> str: @@ -99,3 +100,13 @@ def failure_message(self) -> str: def complex_types(self) -> ComplexTypeMap: """Return complex types present in the function (structs, unions, enums, etc.).""" return self._complex_types + + @property + def code(self) -> str: + """Return C-Code representation for the Task.""" + return self._code + + @code.setter + def code(self, value): + """Setter function for C-Code representation of the Task""" + self._code = value \ No newline at end of file diff --git a/decompiler/util/bugfinder/bugfinder.py b/decompiler/util/bugfinder/bugfinder.py index f72f92984..22b7ec9d6 100644 --- a/decompiler/util/bugfinder/bugfinder.py +++ b/decompiler/util/bugfinder/bugfinder.py @@ -14,7 +14,18 @@ # Add project root to path (script located in dewolf/decompiler/util/bugfinder/) project_root = Path(__file__).resolve().parents[3] sys.path.append(str(project_root)) -from binaryninja import BinaryViewType, Function, core_version +from binaryninja import Function, core_version + +# use binaryninja.load for BN 3.5 up +version_numbers = core_version().split(".") +major, minor = int(version_numbers[0]), int(version_numbers[1]) +if major >= 3 and minor >= 5: + from binaryninja import load +else: + from binaryninja import BinaryViewType + + load = BinaryViewType.get_view_of_file + from decompile import Decompiler from decompiler.frontend import BinaryninjaFrontend from decompiler.logger import configure_logging @@ -124,7 +135,7 @@ def get_function_info(function: Function) -> dict: "function_size": function.highest_address - function.start, "function_arch": str(function.arch), "function_platform": str(function.platform), - "timestamp": datetime.now() + "timestamp": datetime.now(), } @staticmethod @@ -198,7 +209,7 @@ def iter_function_reports(self, sample) -> Iterator[dict]: def store_reports_from_sample(sample: Path, db_reports: DBConnector, max_size: int): """Store all reports from sample into database""" logging.info(f"processing {sample}") - if not (binary_view := BinaryViewType.get_view_of_file(sample)): + if not (binary_view := load(sample)): logging.warning(f"Could not get BinaryView '{sample}'") return try: diff --git a/decompiler/util/default.json b/decompiler/util/default.json index 5ae9ba0b5..a478f55bd 100644 --- a/decompiler/util/default.json +++ b/decompiler/util/default.json @@ -275,6 +275,16 @@ "is_hidden_from_cli": false, "argument_name": "--return-complexity-threshold" }, + { + "dest": "readability-based-refinement.restructure_for_loops", + "default": true, + "type": "boolean", + "title": "Enable for-loop recovery", + "description": "If enabled, certain while-loops will be transformed to for-loops. If set to false, no for-loops will be emitted at all.", + "is_hidden_from_gui": false, + "is_hidden_from_cli": false, + "argument_name": "--restructure-for-loops" + }, { "dest": "readability-based-refinement.keep_empty_for_loops", "default": false, @@ -285,24 +295,6 @@ "is_hidden_from_cli": false, "argument_name": "--empty-for-loops" }, - { - "dest": "readability-based-refinement.for_loop_variable_names", - "default": [ - "i", - "j", - "k", - "l", - "m", - "n" - ], - "type": "array", - "elementType": "string", - "title": "Rename for-loop variables into desired names", - "description": "Rename for-loop variables to values from list", - "is_hidden_from_gui": false, - "is_hidden_from_cli": false, - "argument_name": "--for-loop-variable-names" - }, { "dest": "readability-based-refinement.max_condition_complexity_for_loop_recovery", "default": 100, @@ -347,16 +339,6 @@ "is_hidden_from_cli": false, "argument_name": "--for-loop-exclude-conditions" }, - { - "dest": "readability-based-refinement.rename_while_loop_variables", - "default": true, - "type": "boolean", - "title": "Rename while-loop variables", - "description": "Rename while-loop counter variables to counter, counter1, ...", - "is_hidden_from_gui": false, - "is_hidden_from_cli": false, - "argument_name": "--rename-while-loop-variables" - }, { "dest": "variable-name-generation.notation", "default": "default", @@ -412,6 +394,34 @@ "is_hidden_from_cli": false, "argument_name": "--variable-generation-counter-separator" }, + { + "dest": "loop-name-generator.rename_while_loop_variables", + "default": true, + "type": "boolean", + "title": "Rename while-loop variables", + "description": "Rename while-loop counter variables to counter, counter1, ...", + "is_hidden_from_gui": false, + "is_hidden_from_cli": false, + "argument_name": "--rename-while-loop-variables" + }, + { + "dest": "loop-name-generator.for_loop_variable_names", + "default": [ + "i", + "j", + "k", + "l", + "m", + "n" + ], + "type": "array", + "elementType": "string", + "title": "Rename for-loop variables", + "description": "Rename for-loop variables to values from given list", + "is_hidden_from_gui": false, + "is_hidden_from_cli": false, + "argument_name": "--for-loop-variable-names" + }, { "dest": "code-generator.max_complexity", "default": 100, @@ -631,7 +641,8 @@ "readability-based-refinement", "expression-simplification-ast", "instruction-length-handler", - "variable-name-generation" + "variable-name-generation", + "loop-name-generator" ], "title": "AST pipeline stages", "type": "array", @@ -641,6 +652,16 @@ "is_hidden_from_cli": false, "argument_name": "--ast-stages" }, + { + "dest": "pipeline.validate_no_dataflow_dup", + "default": false, + "title": "Validate no DataflowObject duplication", + "type": "boolean", + "description": "Throw exception if duplicate DataflowObjects exist after any stage", + "is_hidden_from_gui": true, + "is_hidden_from_cli": false, + "argument_name": "--validate-no-dataflow-dup" + }, { "dest": "pipeline.debug", "default": false, diff --git a/decompiler/util/integer_util.py b/decompiler/util/integer_util.py new file mode 100644 index 000000000..1e96f62bf --- /dev/null +++ b/decompiler/util/integer_util.py @@ -0,0 +1,19 @@ +def normalize_int(v: int, size: int, signed: bool) -> int: + """ + Normalizes an integer value to a specific size and signedness. + + This function takes an integer value 'v' and normalizes it to fit within + the specified 'size' in bits by discarding overflowing bits. If 'signed' is + true, the value is treated as a signed integer, i.e. interpreted as a two's complement. + Therefore the return value will be negative iff 'signed' is true and the most-significant bit is set. + + :param v: The value to be normalized. + :param size: The desired bit size for the normalized integer. + :param signed: True if the integer should be treated as signed. + :return: The normalized integer value. + """ + value = v & ((1 << size) - 1) + if signed and value & (1 << (size - 1)): + return value - (1 << size) + else: + return value \ No newline at end of file diff --git a/tests/pipeline/controlflowanalysis/test_loop_name_generator.py b/tests/pipeline/controlflowanalysis/test_loop_name_generator.py new file mode 100644 index 000000000..8a1a1aeec --- /dev/null +++ b/tests/pipeline/controlflowanalysis/test_loop_name_generator.py @@ -0,0 +1,1334 @@ +from typing import List + +import pytest +from decompiler.pipeline.controlflowanalysis.loop_name_generator import ForLoopVariableRenamer, LoopNameGenerator, WhileLoopVariableRenamer +from decompiler.pipeline.controlflowanalysis.loop_utility_methods import _initialization_reaches_loop_node +from decompiler.pipeline.controlflowanalysis.readability_based_refinement import ReadabilityBasedRefinement +from decompiler.structures.ast.ast_nodes import CaseNode, CodeNode, ConditionNode, ForLoopNode, SeqNode, SwitchNode, WhileLoopNode +from decompiler.structures.ast.syntaxtree import AbstractSyntaxTree +from decompiler.structures.logic.logic_condition import LogicCondition +from decompiler.structures.pseudo import ( + Assignment, + BinaryOperation, + Break, + Call, + Condition, + Constant, + ImportedFunctionSymbol, + ListOperation, + OperationType, + Variable, +) +from decompiler.structures.pseudo.operations import ArrayInfo, OperationType, UnaryOperation +from decompiler.task import DecompilerTask +from decompiler.util.options import Options + +# Test For/WhileLoop Renamer + +def logic_cond(name: str, context) -> LogicCondition: + return LogicCondition.initialize_symbol(name, context) + +@pytest.fixture +def ast_call_for_loop() -> AbstractSyntaxTree: + """ + a = 5; + while(b = foo; b <= 5; b++){ + a++; + } + """ + true_value = LogicCondition.initialize_true(context := LogicCondition.generate_new_context()) + ast = AbstractSyntaxTree( + root := SeqNode(true_value), + condition_map={logic_cond("x1", context): Condition(OperationType.less_or_equal, [Variable("b"), Constant(5)])}, + ) + code_node = ast._add_code_node( + instructions=[ + Assignment(Variable("a"), Constant(5)), + ] + ) + loop_node = ast.factory.create_for_loop_node(Assignment(ListOperation([Variable("b")]), Call(ImportedFunctionSymbol("foo", 0), [])), logic_cond("x1", context), Assignment(Variable("b"), BinaryOperation(OperationType.plus, [Variable("b"), Constant(1)]))) + loop_node_body = ast._add_code_node( + [ + Assignment(Variable("a"), BinaryOperation(OperationType.plus, [Variable("a"), Variable("1")])), + ] + ) + ast._add_node(loop_node) + ast._add_edges_from(((root, code_node), (root, loop_node), (loop_node, loop_node_body))) + ast._code_node_reachability_graph.add_reachability(code_node, loop_node_body) + root._sorted_children = (code_node, loop_node) + return ast + + +def test_declaration_listop(ast_call_for_loop): + """Test renaming with ListOperation as Declaration""" + ForLoopVariableRenamer(ast_call_for_loop, ["i"]).rename() + for node in ast_call_for_loop: + if isinstance(node, ForLoopNode): + assert node.declaration.destination.operands[0].name == "i" + + +def test_for_loop_variable_generation(): + renamer = ForLoopVariableRenamer( + AbstractSyntaxTree(SeqNode(LogicCondition.initialize_true(LogicCondition.generate_new_context())), {}), + ["i", "j", "k", "l", "m", "n"] + ) + assert [renamer._get_variable_name() for _ in range(14)] == [ + "i", + "j", + "k", + "l", + "m", + "n", + "i1", + "j1", + "k1", + "l1", + "m1", + "n1", + "i2", + "j2", + ] + + +def test_while_loop_variable_generation(): + renamer = WhileLoopVariableRenamer( + AbstractSyntaxTree(SeqNode(LogicCondition.initialize_true(LogicCondition.generate_new_context())), {}) + ) + assert [renamer._get_variable_name() for _ in range(5)] == ["counter", "counter1", "counter2", "counter3", "counter4"] + +# Test Readabilitybasedrefinement + LoopNameGenerator together + + +def _generate_options(empty_loops: bool = False, hide_decl: bool = False, rename_for: bool = True, rename_while: bool = True, \ + max_condition: int = 100, max_modification: int = 100, force_for_loops: bool = False, blacklist : List[str] = []) -> Options: + options = Options() + options.set("readability-based-refinement.keep_empty_for_loops", empty_loops) + options.set("readability-based-refinement.hide_non_initializing_declaration", hide_decl) + options.set("readability-based-refinement.max_condition_complexity_for_loop_recovery", max_condition) + options.set("readability-based-refinement.max_modification_complexity_for_loop_recovery", max_modification) + options.set("readability-based-refinement.force_for_loops", force_for_loops) + options.set("readability-based-refinement.forbidden_condition_types_in_simple_for_loops", blacklist) + if rename_for: + names = ["i", "j", "k", "l", "m", "n"] + options.set("loop-name-generator.for_loop_variable_names", names) + options.set("loop-name-generator.rename_while_loop_variables", rename_while) + return options + + +@pytest.fixture +def ast_array_access_for_loop() -> AbstractSyntaxTree: + """ + for (var_0 = 0; var_0 < 10; var_0 = var_0 + 1) { + *(var_1 + var_0) = var_0; + } + """ + true_value = LogicCondition.initialize_true(context := LogicCondition.generate_new_context()) + ast = AbstractSyntaxTree( + root := SeqNode(true_value), + condition_map={logic_cond("x1", context): Condition(OperationType.less, [Variable("var_0"), Constant(10)])}, + ) + declaration = Assignment(Variable("var_0"), Constant(0)) + condition = logic_cond("x1", context) + modification = Assignment(Variable("var_0"), BinaryOperation(OperationType.plus, [Variable("var_0"), Constant(1)])) + for_loop = ast.factory.create_for_loop_node(declaration, condition, modification) + array_info = ArrayInfo(Variable("var_1"), Variable("var_0")) + array_access_unary_operation = UnaryOperation( + OperationType.dereference, [BinaryOperation(OperationType.plus, [Variable("var_1"), Variable("var_0")])], array_info=array_info + ) + for_loop_body = ast._add_code_node([Assignment(array_access_unary_operation, Variable("var_0"))]) + ast._add_node(for_loop) + ast._add_edges_from([(root, for_loop), (for_loop, for_loop_body)]) + return ast + + +@pytest.fixture +def ast_while_true() -> AbstractSyntaxTree: + """ + a = 0; + b = 0; + while(true){ + a = a + 1; + b = b + 1; + } + """ + true_value = LogicCondition.initialize_true(LogicCondition.generate_new_context()) + ast = AbstractSyntaxTree(root := SeqNode(true_value), {}) + code_node = ast._add_code_node([Assignment(Variable("a"), Constant(0)), Assignment(Variable("b"), Constant(0))]) + loop_node_body = ast._add_code_node( + [ + Assignment(Variable("a"), BinaryOperation(OperationType.plus, [Variable("a"), Constant(1)])), + Assignment(Variable("b"), BinaryOperation(OperationType.plus, [Variable("b"), Constant(1)])), + ] + ) + loop_node = ast.add_endless_loop_with_body(loop_node_body) + ast._add_edges_from(((root, code_node), (root, loop_node))) + return ast + + +@pytest.fixture +def ast_single_instruction_while() -> AbstractSyntaxTree: + """ + a = 0; + while (a < 10) { + a = a + 1; + } + """ + true_value = LogicCondition.initialize_true(context := LogicCondition.generate_new_context()) + ast = AbstractSyntaxTree( + root := SeqNode(true_value), condition_map={logic_cond("x1", context): Condition(OperationType.less, [Variable("a"), Constant(10)])} + ) + init_code_node = ast._add_code_node([Assignment(Variable("a"), Constant(0))]) + while_loop = ast.factory.create_while_loop_node(logic_cond("x1", context)) + while_loop_body = ast._add_code_node([Assignment(Variable("a"), BinaryOperation(OperationType.plus, [Variable("a"), Constant(1)]))]) + ast._add_node(while_loop) + ast._add_edges_from([(root, init_code_node), (root, while_loop), (while_loop, while_loop_body)]) + return ast + + +@pytest.fixture +def ast_replaceable_while() -> AbstractSyntaxTree: + """ + a = 0; + while (x < 10) { + printf("counter: %d", x); + a = a + 1; + } + """ + true_value = LogicCondition.initialize_true(context := LogicCondition.generate_new_context()) + ast = AbstractSyntaxTree( + root := SeqNode(true_value), condition_map={logic_cond("x1", context): Condition(OperationType.less, [Variable("a"), Constant(10)])} + ) + + init_code_node = ast._add_code_node([Assignment(Variable("a"), Constant(0))]) + + while_loop = ast.factory.create_while_loop_node(logic_cond("x1", context)) + while_loop_body = ast._add_code_node( + [ + Assignment(ListOperation([]), Call(ImportedFunctionSymbol("printf", 0), [Constant("counter: %d\n"), Variable("a")])), + Assignment(Variable("a"), BinaryOperation(OperationType.plus, [Variable("a"), Constant(1)])), + ] + ) + + ast._add_node(while_loop) + ast._add_edges_from([(root, init_code_node), (root, while_loop), (while_loop, while_loop_body)]) + root._sorted_children = (init_code_node, while_loop) + return ast + + +@pytest.fixture +def ast_replaceable_while_usage() -> AbstractSyntaxTree: + """ + a = 0; + while (a < 10) { + printf("counter: %d", a); + a = a + 1; + } + printf("final counter: %d", a); + """ + true_value = LogicCondition.initialize_true(context := LogicCondition.generate_new_context()) + ast = AbstractSyntaxTree( + root := SeqNode(true_value), condition_map={logic_cond("x1", context): Condition(OperationType.less, [Variable("a"), Constant(10)])} + ) + + init_code_node = ast._add_code_node([Assignment(Variable("a"), Constant(0))]) + + while_loop = ast.factory.create_while_loop_node(logic_cond("x1", context)) + while_loop_body = ast._add_code_node( + [ + Assignment(ListOperation([]), Call(ImportedFunctionSymbol("printf", 0), [Constant("counter: %d\n"), Variable("a")])), + Assignment(Variable("a"), BinaryOperation(OperationType.plus, [Variable("a"), Constant(1)])), + ] + ) + + exit_code_node = ast._add_code_node( + [Assignment(ListOperation([]), Call(ImportedFunctionSymbol("printf", 0), [Constant("final counter: %d"), Variable("a")]))] + ) + + ast._add_node(while_loop) + ast._add_edges_from([(root, init_code_node), (root, while_loop), (root, exit_code_node), (while_loop, while_loop_body)]) + return ast + + +@pytest.fixture +def ast_replaceable_while_reinit_usage() -> AbstractSyntaxTree: + """ + a = 0; + while (a < 10) { + printf("counter: %d", a); + a = a + 1; + } + a = 50; + printf("50 = %d", a); + """ + true_value = LogicCondition.initialize_true(context := LogicCondition.generate_new_context()) + ast = AbstractSyntaxTree( + root := SeqNode(true_value), condition_map={logic_cond("x1", context): Condition(OperationType.less, [Variable("a"), Constant(10)])} + ) + + init_code_node = ast._add_code_node([Assignment(Variable("a"), Constant(0))]) + + while_loop = ast.factory.create_while_loop_node(logic_cond("x1", context)) + while_loop_body = ast._add_code_node( + [ + Assignment(ListOperation([]), Call(ImportedFunctionSymbol("printf", 0), [Constant("counter: %d\n"), Variable("a")])), + Assignment(Variable("a"), BinaryOperation(OperationType.plus, [Variable("a"), Constant(1)])), + ] + ) + + exit_code_node = ast._add_code_node( + [ + Assignment(Variable("a"), Constant(50)), + Assignment(ListOperation([]), Call(ImportedFunctionSymbol("printf", 0), [Constant("final counter: %d"), Variable("a")])), + ] + ) + + ast._add_node(while_loop) + ast._add_edges_from([(root, init_code_node), (root, while_loop), (root, exit_code_node), (while_loop, while_loop_body)]) + return ast + + +@pytest.fixture +def ast_replaceable_while_compound_usage() -> AbstractSyntaxTree: + """ + a = 0; + while (a < 10) { + printf("counter: %d", a); + a = a + 1; + } + a = a + 50; + printf("50 = %d", a); + """ + true_value = LogicCondition.initialize_true(context := LogicCondition.generate_new_context()) + ast = AbstractSyntaxTree( + root := SeqNode(true_value), condition_map={logic_cond("x1", context): Condition(OperationType.less, [Variable("a"), Constant(10)])} + ) + + init_code_node = ast._add_code_node([Assignment(Variable("a"), Constant(0))]) + + while_loop = ast.factory.create_while_loop_node(logic_cond("x1", context)) + while_loop_body = ast._add_code_node( + [ + Assignment(ListOperation([]), Call(ImportedFunctionSymbol("printf", 0), [Constant("counter: %d\n"), Variable("a")])), + Assignment(Variable("a"), BinaryOperation(OperationType.plus, [Variable("a"), Constant(1)])), + ] + ) + + exit_code_node = ast._add_code_node( + [ + Assignment(Variable("a"), BinaryOperation(OperationType.plus, [Variable("a"), Constant(50)])), + Assignment(ListOperation([]), Call(ImportedFunctionSymbol("printf", 0), [Constant("final counter: %d"), Variable("a")])), + ] + ) + + ast._add_node(while_loop) + ast._add_edges_from([(root, init_code_node), (root, while_loop), (root, exit_code_node), (while_loop, while_loop_body)]) + return ast + + +@pytest.fixture +def ast_endless_while_init_outside() -> AbstractSyntaxTree: + """ + a = 0; + while (true) { + while (a < 5) { + printf("%d\n", a); + a = a + 1; + } + } + """ + true_value = LogicCondition.initialize_true(context := LogicCondition.generate_new_context()) + ast = AbstractSyntaxTree( + root := SeqNode(true_value), condition_map={logic_cond("x1", context): Condition(OperationType.less, [Variable("a"), Constant(2)])} + ) + + init_code_node = ast._add_code_node([Assignment(Variable("a"), Constant(0))]) + + inner_while = ast.factory.create_while_loop_node(logic_cond("x1", context)) + ast._add_node(inner_while) + endless_loop = ast.add_endless_loop_with_body(inner_while) + + inner_while_body = ast._add_code_node( + [ + Assignment(ListOperation([]), Call(ImportedFunctionSymbol("printf", 0), [Constant("%d\n"), Variable("a")])), + Assignment(Variable("a"), BinaryOperation(OperationType.plus, [Variable("a"), Constant(1)])), + ] + ) + + ast._add_edges_from([(root, init_code_node), (root, endless_loop), (endless_loop, inner_while), (inner_while, inner_while_body)]) + return ast + + +@pytest.fixture +def ast_nested_while() -> AbstractSyntaxTree: + """ + a = 0; + while (a < 1) { + b = 0; + while (b < 1) { + b = b + 1; + } + a = a + 1; + } + """ + true_value = LogicCondition.initialize_true(context := LogicCondition.generate_new_context()) + ast = AbstractSyntaxTree( + root := SeqNode(true_value), + condition_map={ + logic_cond("x1", context): Condition(OperationType.less, [Variable("a"), Constant(5)]), + logic_cond("x2", context): Condition(OperationType.less, [Variable("b"), Constant(5)]), + }, + ) + + init_code_node = ast._add_code_node([Assignment(Variable("a"), Constant(0))]) + + outer_while = ast.factory.create_while_loop_node(logic_cond("x1", context)) + outer_while_body = ast.factory.create_seq_node() + outer_while_init = ast._add_code_node([Assignment(Variable("b"), Constant(0))]) + outer_while_exit = ast._add_code_node([Assignment(Variable("a"), BinaryOperation(OperationType.plus, [Variable("a"), Constant(1)]))]) + + inner_while = ast.factory.create_while_loop_node(logic_cond("x2", context)) + inner_while_body = ast._add_code_node([Assignment(Variable("b"), BinaryOperation(OperationType.plus, [Variable("b"), Constant(1)]))]) + + ast._add_nodes_from((outer_while, outer_while_body, inner_while)) + ast._add_edges_from( + [ + (root, init_code_node), + (root, outer_while), + (outer_while, outer_while_body), + (outer_while_body, outer_while_init), + (outer_while_body, inner_while), + (outer_while_body, outer_while_exit), + (inner_while, inner_while_body), + ] + ) + return ast + + +@pytest.fixture +def ast_call_init() -> AbstractSyntaxTree: + """ + a = 5; + b = foo(); + while(b <= 5){ + a = a + b; + b = b + 1; + } + """ + true_value = LogicCondition.initialize_true(context := LogicCondition.generate_new_context()) + ast = AbstractSyntaxTree( + root := SeqNode(true_value), + condition_map={logic_cond("x1", context): Condition(OperationType.less_or_equal, [Variable("b"), Constant(5)])}, + ) + code_node = ast._add_code_node( + instructions=[ + Assignment(Variable("a"), Constant(5)), + Assignment(ListOperation([Variable("b")]), Call(ImportedFunctionSymbol("foo", 0), [])), + ] + ) + loop_node = ast.factory.create_while_loop_node(condition=logic_cond("x1", context)) + loop_node_body = ast._add_code_node( + [ + Assignment(Variable("a"), BinaryOperation(OperationType.plus, [Variable("a"), Variable("b")])), + Assignment(Variable("b"), BinaryOperation(OperationType.plus, [Variable("b"), Constant(1)])), + ] + ) + ast._add_node(loop_node) + ast._add_edges_from(((root, code_node), (root, loop_node), (loop_node, loop_node_body))) + ast._code_node_reachability_graph.add_reachability(code_node, loop_node_body) + root._sorted_children = (code_node, loop_node) + return ast + + +@pytest.fixture +def ast_redundant_init() -> AbstractSyntaxTree: + """ + b = 0; + a = 5; + b = 2; + + while(b <= 5){ + a = a + b; + b = b + 1; + } + """ + true_value = LogicCondition.initialize_true(context := LogicCondition.generate_new_context()) + ast = AbstractSyntaxTree( + root := SeqNode(true_value), condition_map={logic_cond("x1", context): Condition(OperationType.less, [Variable("b"), Constant(5)])} + ) + code_node = ast._add_code_node( + instructions=[ + Assignment(Variable("b"), Constant(0)), + Assignment(Variable("a"), Constant(5)), + Assignment(Variable("b"), Constant(2)), + ] + ) + loop_node = ast.factory.create_while_loop_node(condition=logic_cond("x1", context)) + loop_node_body = ast._add_code_node( + [ + Assignment(Variable("a"), BinaryOperation(OperationType.plus, [Variable("a"), Variable("b")])), + Assignment(Variable("b"), BinaryOperation(OperationType.plus, [Variable("b"), Constant(1)])), + ] + ) + ast._add_node(loop_node) + ast._add_edges_from(((root, code_node), (root, loop_node), (loop_node, loop_node_body))) + ast._code_node_reachability_graph.add_reachability(code_node, loop_node_body) + root._sorted_children = (code_node, loop_node) + return ast + + +@pytest.fixture +def ast_reinit_in_condition_true() -> AbstractSyntaxTree: + """ + int x = 1; + int i = 0; + + if (x == 1) { + i = 1; + } + + while (i < 10) { + x = x * 2; + i = i + 1; + } + """ + true_value = LogicCondition.initialize_true(context := LogicCondition.generate_new_context()) + ast = AbstractSyntaxTree( + root := SeqNode(true_value), + condition_map={ + logic_cond("a", context): Condition(OperationType.less, [Variable("i"), Constant(10)]), + logic_cond("b", context): Condition(OperationType.equal, [Variable("x"), Constant(1)]), + }, + ) + code_node = ast._add_code_node(instructions=[Assignment(Variable("x"), Constant(1)), Assignment(Variable("i"), Constant(0))]) + code_node_true = ast._add_code_node([Assignment(Variable("i"), Constant(1))]) + condition_node = ast._add_condition_node_with(logic_cond("b", context), code_node_true) + loop_node = ast.factory.create_while_loop_node(condition=logic_cond("a", context)) + loop_node_body = ast._add_code_node( + [ + Assignment(Variable("x"), BinaryOperation(OperationType.multiply, [Variable("x"), Constant(2)])), + Assignment(Variable("i"), BinaryOperation(OperationType.plus, [Variable("i"), Constant(1)])), + ] + ) + ast._add_nodes_from((condition_node, loop_node)) + ast._add_edges_from(((root, code_node), (root, condition_node), (root, loop_node), (loop_node, loop_node_body))) + ast._code_node_reachability_graph.add_reachability(code_node, loop_node_body) + root._sorted_children = (code_node, loop_node) + return ast + + +@pytest.fixture +def ast_usage_in_condition() -> AbstractSyntaxTree: + """ + int a = 1; + int b = 0; + + if (b == 1) { + a = 1; + } + + while (b < 10) { + a = a * 2; + b = b + 1; + } + """ + true_value = LogicCondition.initialize_true(context := LogicCondition.generate_new_context()) + ast = AbstractSyntaxTree( + root := SeqNode(true_value), + condition_map={ + logic_cond("x1", context): Condition(OperationType.less, [Variable("b"), Constant(10)]), + logic_cond("x2", context): Condition(OperationType.equal, [Variable("b"), Constant(1)]), + }, + ) + init_code_node = ast._add_code_node([Assignment(Variable("a"), Constant(1)), Assignment(Variable("b"), Constant(0))]) + code_node_true = ast._add_code_node([Assignment(Variable("a"), Constant(1))]) + condition_node = ast._add_condition_node_with(logic_cond("x2", context), code_node_true) + loop_node = ast.factory.create_while_loop_node(condition=logic_cond("x1", context)) + loop_node_body = ast._add_code_node( + [ + Assignment(Variable("a"), BinaryOperation(OperationType.multiply, [Variable("a"), Constant(2)])), + Assignment(Variable("b"), BinaryOperation(OperationType.plus, [Variable("b"), Constant(1)])), + ] + ) + ast._add_node(loop_node) + ast._add_edges_from(((root, init_code_node), (root, condition_node), (root, loop_node), (loop_node, loop_node_body))) + ast._code_node_reachability_graph.add_reachability(init_code_node, loop_node_body) + root._sorted_children = (init_code_node, loop_node) + return ast + + +@pytest.fixture +def ast_sequenced_while_loops() -> AbstractSyntaxTree: + """ + a = 0; + b = 0; + + while (a < 5) { + printf("%d\n", a); + a++; + } + + while (b < 5) { + printf("%d\n", b); + b++; + } + """ + true_value = LogicCondition.initialize_true(context := LogicCondition.generate_new_context()) + ast = AbstractSyntaxTree( + root := SeqNode(true_value), + condition_map={ + logic_cond("x1", context): Condition(OperationType.less, [Variable("a"), Constant(5)]), + logic_cond("x2", context): Condition(OperationType.less, [Variable("b"), Constant(5)]), + }, + ) + + init_code_node = ast._add_code_node([Assignment(Variable("a"), Constant(0)), Assignment(Variable("b"), Constant(0))]) + + while_loop_1 = ast.factory.create_while_loop_node(logic_cond("x1", context)) + while_loop_1_body = ast._add_code_node( + [ + Assignment(ListOperation([]), Call(ImportedFunctionSymbol("printf", 0), [Constant("%d\n"), Variable("a")])), + Assignment(Variable("a"), BinaryOperation(OperationType.plus, [Variable("a"), Constant(1)])), + ] + ) + + while_loop_2 = ast.factory.create_while_loop_node(logic_cond("x2", context)) + while_loop_2_body = ast._add_code_node( + [ + Assignment(ListOperation([]), Call(ImportedFunctionSymbol("printf", 0), [Constant("%d\n"), Variable("b")])), + Assignment(Variable("b"), BinaryOperation(OperationType.plus, [Variable("b"), Constant(1)])), + ] + ) + + ast._add_nodes_from((while_loop_1, while_loop_2)) + ast._add_edges_from( + ( + (root, init_code_node), + (root, while_loop_1), + (root, while_loop_2), + (while_loop_1, while_loop_1_body), + (while_loop_2, while_loop_2_body), + ) + ) + return ast + + +@pytest.fixture +def ast_switch_as_loop_body() -> AbstractSyntaxTree: + """ + This while-loop should not be replaced with a for-loop because we don't know wich value 'a' has. + + Code of AST: + a = 5; + b = 0; + while(b <= 5){ + switch(a) { + case 0: + a = a + b: + break; + case 1: + b = b + 1; + break; + } + } + """ + true_value = LogicCondition.initialize_true(context := LogicCondition.generate_new_context()) + ast = AbstractSyntaxTree( + root := SeqNode(true_value), + condition_map={logic_cond("a", context): Condition(OperationType.less_or_equal, [Variable("b"), Constant(5)])}, + ) + code_node = ast._add_code_node([Assignment(Variable("a"), Constant(5)), Assignment(Variable("b"), Constant(0))]) + loop_node = ast.factory.create_while_loop_node(condition=logic_cond("a", context)) + root._sorted_children = (code_node, loop_node) + loop_body_switch = ast.factory.create_switch_node(Variable("a")) + loop_body_case_1 = ast.factory.create_case_node(Variable("a"), Constant(0), break_case=True) + code_node_case_1 = ast._add_code_node([Assignment(Variable("a"), BinaryOperation(OperationType.plus, [Variable("a"), Variable("b")]))]) + loop_body_case_2 = ast.factory.create_case_node(Variable("a"), Constant(1), break_case=True) + code_node_case_2 = ast._add_code_node( + [ + Assignment(Variable("b"), BinaryOperation(OperationType.plus, [Variable("b"), Constant(1)])), + ] + ) + ast._add_nodes_from((code_node, loop_node, loop_body_switch, loop_body_case_1, loop_body_case_2)) + ast._add_edges_from( + ( + (root, code_node), + (root, loop_node), + (loop_node, loop_body_switch), + (loop_body_switch, loop_body_case_1), + (loop_body_switch, loop_body_case_2), + (loop_body_case_1, code_node_case_1), + (loop_body_case_2, code_node_case_2), + ) + ) + ast._code_node_reachability_graph.add_reachability_from(((code_node, code_node_case_1), (code_node, code_node_case_2))) + return ast + + +@pytest.fixture +def ast_switch_as_loop_body_increment() -> AbstractSyntaxTree: + """ + This loop should be replaced with a for-loop because b has no usages after last definition, is in condition and is initialized + before loop without any usages in between. + + Code of AST: + a = 5; + b = 0; + while(b <= 5){ + switch(a) { + case 0: + a = a + b: + break; + case 1: + b = b + 1; + break; + } + b = b + 1; + } + """ + true_value = LogicCondition.initialize_true(context := LogicCondition.generate_new_context()) + ast = AbstractSyntaxTree( + root := SeqNode(true_value), condition_map={logic_cond("x1", context): Condition(OperationType.less, [Variable("b"), Constant(5)])} + ) + + init_code_node = ast._add_code_node([Assignment(Variable("a"), Constant(5)), Assignment(Variable("b"), Constant(0))]) + + while_loop = ast.factory.create_while_loop_node(logic_cond("x1", context)) + while_loop_seq = ast.factory.create_seq_node() + + switch_node = ast.factory.create_switch_node(Variable("a")) + case_1 = ast.factory.create_case_node(Variable("a"), Constant(0), break_case=True) + case_1_code = ast._add_code_node([Assignment(Variable("a"), BinaryOperation(OperationType.plus, [Variable("a"), Variable("b")]))]) + case_2 = ast.factory.create_case_node(Variable("a"), Constant(0), break_case=True) + case_2_code = ast._add_code_node([Assignment(Variable("b"), BinaryOperation(OperationType.plus, [Variable("b"), Constant(1)]))]) + + increment_code = ast._add_code_node([Assignment(Variable("b"), BinaryOperation(OperationType.plus, [Variable("b"), Constant(1)]))]) + + ast._add_nodes_from((while_loop, while_loop_seq, switch_node, case_1, case_2)) + ast._add_edges_from( + [ + (root, init_code_node), + (root, while_loop), + (while_loop, while_loop_seq), + (while_loop_seq, switch_node), + (while_loop_seq, increment_code), + (switch_node, case_1), + (switch_node, case_2), + (case_1, case_1_code), + (case_2, case_2_code), + ] + ) + return ast + + +@pytest.fixture +def ast_init_in_switch() -> AbstractSyntaxTree: + """ + a = 5; + b = 0; + switch(a){ + case 0: + a = b; + } + while(b <= (5 + a)){ + a = a + b; + b = b + 1; + } + """ + true_value = LogicCondition.initialize_true(context := LogicCondition.generate_new_context()) + ast = AbstractSyntaxTree( + root := SeqNode(true_value), + condition_map={ + logic_cond("x1", context): Condition( + OperationType.less_or_equal, + [Variable("b"), BinaryOperation(OperationType.plus, [Constant(5), Variable("a")])], + ) + }, + ) + init_code_node = ast._add_code_node(instructions=[Assignment(Variable("a"), Constant(5)), Assignment(Variable("b"), Constant(0))]) + switch_node = ast.factory.create_switch_node(Variable("a")) + loop_node = ast.factory.create_while_loop_node(condition=logic_cond("x1", context)) + case_node = ast.factory.create_case_node(Variable("a"), Constant(0)) + case_child = ast._add_code_node([Assignment(Variable("a"), Variable("b"))]) + loop_body = ast.factory.create_seq_node() + loop_body_child = ast._add_code_node( + [ + Assignment(Variable("a"), BinaryOperation(OperationType.plus, [Variable("a"), Variable("b")])), + Assignment(Variable("b"), BinaryOperation(OperationType.plus, [Variable("b"), Constant(1)])), + ] + ) + ast._add_nodes_from((switch_node, loop_node, loop_body, case_node)) + ast._add_edges_from( + ( + (root, init_code_node), + (root, switch_node), + (switch_node, case_node), + (case_node, case_child), + (root, loop_node), + (loop_node, loop_body), + (loop_body, loop_body_child), + ) + ) + ast._code_node_reachability_graph.add_reachability_from([(case_child, loop_body_child)]) + root._sorted_children = (init_code_node, switch_node, loop_node) + loop_body._sorted_children = (loop_body_child,) + switch_node._sorted_cases = (case_node,) + return ast + + +@pytest.fixture +def ast_while_in_else() -> AbstractSyntaxTree: + """ + while (true) { + if (b < 2) { + break; + } else { + a = 0; + while (a < 5) { + printf("%d\n", a); + a = a + 1; + } + } + } + """ + true_value = LogicCondition.initialize_true(context := LogicCondition.generate_new_context()) + ast = AbstractSyntaxTree( + root := SeqNode(true_value), + condition_map={ + logic_cond("x1", context): Condition(OperationType.less, [Variable("a"), Constant(2)]), + logic_cond("x2", context): Condition(OperationType.less, [Variable("b"), Constant(2)]), + }, + ) + + inner_while = ast.factory.create_while_loop_node(logic_cond("x1", context)) + ast._add_node(inner_while) + + true_branch_child = ast._add_code_node([Break()]) + inner_seq = ast.factory.create_seq_node() + ast._add_node(inner_seq) + condition_node = ast._add_condition_node_with(logic_cond("x2", context), true_branch_child, inner_seq) + + init_code_node = ast._add_code_node([Assignment(Variable("a"), Constant(0))]) + + endless_loop = ast.add_endless_loop_with_body(condition_node) + + inner_while_body = ast._add_code_node( + [ + Assignment(ListOperation([]), Call(ImportedFunctionSymbol("printf", 0), [Constant("%d\n"), Variable("a")])), + Assignment(Variable("a"), BinaryOperation(OperationType.plus, [Variable("a"), Constant(1)])), + ] + ) + + ast._add_edges_from( + [ + (root, endless_loop), + (endless_loop, condition_node), + (inner_seq, init_code_node), + (inner_seq, inner_while), + (inner_while, inner_while_body), + ] + ) + return ast + + +@pytest.fixture +def ast_continuation_is_not_first_var() -> AbstractSyntaxTree: + """ + a = 0; + b = 0; + while (a < b) { + printf("%d\n", a); + b = b + 1; + } + """ + true_value = LogicCondition.initialize_true(context := LogicCondition.generate_new_context()) + ast = AbstractSyntaxTree( + root := SeqNode(true_value), + condition_map={logic_cond("x1", context): Condition(OperationType.less, [Variable("a"), Variable("b")])}, + ) + + init_code_node = ast._add_code_node([Assignment(Variable("a"), Constant(0)), Assignment(Variable("b"), Constant(0))]) + + while_loop = ast.factory.create_while_loop_node(logic_cond("x1", context)) + while_loop_body = ast._add_code_node( + [ + Assignment(ListOperation([]), Call(ImportedFunctionSymbol("printf", 0), [Constant("%d\n"), Variable("a")])), + Assignment(Variable("b"), BinaryOperation(OperationType.plus, [Variable("b"), Constant(1)])), + ] + ) + + ast._add_node(while_loop) + ast._add_edges_from([(root, init_code_node), (root, while_loop), (while_loop, while_loop_body)]) + root._sorted_children = (init_code_node, while_loop) + return ast + + +@pytest.fixture +def ast_initialization_in_condition() -> AbstractSyntaxTree: + """ + if(b < 10 ){ + a = 5; + while (x < 10) { + printf("counter: %d", a); + a = a + 1; + } + """ + true_value = LogicCondition.initialize_true(context := LogicCondition.generate_new_context()) + ast = AbstractSyntaxTree( + root := SeqNode(true_value), + condition_map={ + logic_cond("x0", context): Condition(OperationType.less, [Variable("b"), Constant(10)]), + logic_cond("x1", context): Condition(OperationType.less, [Variable("a"), Constant(10)]), + }, + ) + + true_branch = ast._add_code_node([Assignment(Variable("a"), Constant(5))]) + if_condition = ast._add_condition_node_with(logic_cond("x0", context), true_branch) + while_loop = ast.factory.create_while_loop_node(logic_cond("x1", context)) + while_loop_body = ast._add_code_node( + [ + Assignment(ListOperation([]), Call(ImportedFunctionSymbol("printf", 0), [Constant("counter: %d\n"), Variable("a")])), + Assignment(Variable("a"), BinaryOperation(OperationType.plus, [Variable("a"), Constant(1)])), + ] + ) + + ast._add_node(while_loop) + ast._add_edges_from([(root, if_condition), (root, while_loop), (while_loop, while_loop_body)]) + root._sorted_children = (if_condition, while_loop) + return ast + + +@pytest.fixture +def ast_initialization_in_condition_sequence() -> AbstractSyntaxTree: + """ + if(b < 10 ){ + if(c < 10){ + b = 5; + } + a = 5; + while (x < 10) { + printf("counter: %d", a); + a = a + 1; + } + """ + true_value = LogicCondition.initialize_true(context := LogicCondition.generate_new_context()) + ast = AbstractSyntaxTree( + root := SeqNode(true_value), + condition_map={ + logic_cond("x0", context): Condition(OperationType.less, [Variable("b"), Constant(10)]), + logic_cond("x1", context): Condition(OperationType.less, [Variable("c"), Constant(10)]), + logic_cond("x2", context): Condition(OperationType.less, [Variable("a"), Constant(10)]), + }, + ) + + true_branch_c = ast._add_code_node([Assignment(Variable("b"), Constant(5))]) + code_node = ast._add_code_node([Assignment(Variable("a"), Constant(5))]) + if_condition_c = ast._add_condition_node_with(logic_cond("x1", context), true_branch_c) + ast._add_node(true_branch_b := ast.factory.create_seq_node()) + if_condition_b = ast._add_condition_node_with(logic_cond("x1", context), true_branch_b) + while_loop = ast.factory.create_while_loop_node(logic_cond("x2", context)) + while_loop_body = ast._add_code_node( + [ + Assignment(ListOperation([]), Call(ImportedFunctionSymbol("printf", 0), [Constant("counter: %d\n"), Variable("a")])), + Assignment(Variable("a"), BinaryOperation(OperationType.plus, [Variable("a"), Constant(1)])), + ] + ) + + ast._add_node(while_loop) + ast._add_edges_from( + [ + (root, if_condition_b), + (root, while_loop), + (while_loop, while_loop_body), + (true_branch_b, if_condition_c), + (true_branch_b, code_node), + ] + ) + true_branch_b._sorted_children = (if_condition_c, code_node) + root._sorted_children = (if_condition_b, while_loop) + return ast + + +class TestReadabilityBasedRefinementAndLoopNameGenerator: + """Test Readability functionality with all its substages.""" + + @staticmethod + def run_rbr(ast: AbstractSyntaxTree, options: Options = _generate_options()): + task = DecompilerTask("func", cfg=None, ast=ast, options=options) + ReadabilityBasedRefinement().run(task) + LoopNameGenerator().run(task) + + + def test_no_replacement(self, ast_while_true): + self.run_rbr(ast_while_true) + assert all(not isinstance(node, ForLoopNode) for node in ast_while_true.topological_order()) + + def test_simple_replacement(self, ast_replaceable_while): + self.run_rbr(ast_replaceable_while) + + assert ast_replaceable_while.condition_map == { + logic_cond("x1", LogicCondition.generate_new_context()): Condition(OperationType.less, [Variable("i"), Constant(10)]) + } + + loop_node = ast_replaceable_while.root + assert isinstance(loop_node, ForLoopNode) + assert loop_node.declaration == Assignment(Variable("i"), Constant(0)) + assert loop_node.modification == Assignment(Variable("i"), BinaryOperation(OperationType.plus, [Variable("i"), Constant(1)])) + + loop_body = loop_node.body + assert isinstance(loop_body, CodeNode) + assert loop_body.instructions == [ + Assignment(ListOperation([]), Call(ImportedFunctionSymbol("printf", 0), [Constant("counter: %d\n"), Variable("i")])), + ] + + def test_with_usage(self, ast_replaceable_while_usage): + self.run_rbr(ast_replaceable_while_usage) + + for_loop = ast_replaceable_while_usage.root.children[0] + assert isinstance(for_loop, ForLoopNode) + assert for_loop.declaration == Assignment(Variable("i"), Constant(0)) + + copy_instr_node = ast_replaceable_while_usage.root.children[1] + assert isinstance(copy_instr_node, CodeNode) + assert copy_instr_node.instructions == [Assignment(Variable("a"), Variable("i"))] + + def test_with_usage_redefinition(self, ast_replaceable_while_reinit_usage): + self.run_rbr(ast_replaceable_while_reinit_usage) + + for_loop = ast_replaceable_while_reinit_usage.root.children[0] + assert isinstance(for_loop, ForLoopNode) + assert for_loop.declaration == Assignment(Variable("i"), Constant(0)) + assert for_loop.modification == Assignment(Variable("i"), BinaryOperation(OperationType.plus, [Variable("i"), Constant(1)])) + + exit_code_node = ast_replaceable_while_reinit_usage.root.children[1] + assert isinstance(exit_code_node, CodeNode) + assert exit_code_node.instructions == [ + Assignment(Variable("a"), Constant(50)), + Assignment(ListOperation([]), Call(ImportedFunctionSymbol("printf", 0), [Constant("final counter: %d"), Variable("a")])), + ] + + def test_with_usage_redefenition_2(self, ast_replaceable_while_compound_usage): + self.run_rbr(ast_replaceable_while_compound_usage) + + for_loop = ast_replaceable_while_compound_usage.root.children[0] + assert isinstance(for_loop, ForLoopNode) + assert for_loop.declaration == Assignment(Variable("i"), Constant(0)) + assert for_loop.modification == Assignment(Variable("i"), BinaryOperation(OperationType.plus, [Variable("i"), Constant(1)])) + + copy_instr_node = ast_replaceable_while_compound_usage.root.children[1] + assert isinstance(copy_instr_node, CodeNode) + assert copy_instr_node.instructions == [Assignment(Variable("a"), Variable("i"))] + + def test_continuation_is_not_first_var(self, ast_continuation_is_not_first_var): + self.run_rbr(ast_continuation_is_not_first_var) + + init_code_node = ast_continuation_is_not_first_var.root.children[0] + assert isinstance(init_code_node, CodeNode) + assert init_code_node.instructions == [Assignment(Variable("a"), Constant(0))] + + loop_node = ast_continuation_is_not_first_var.root.children[1] + assert isinstance(loop_node, ForLoopNode) + assert loop_node.declaration == Assignment(Variable("i"), Constant(0)) + assert loop_node.modification == Assignment(Variable("i"), BinaryOperation(OperationType.plus, [Variable("i"), Constant(1)])) + + loop_node_body = loop_node.body + assert isinstance(loop_node_body, CodeNode) + assert loop_node_body.instructions == [ + Assignment(ListOperation([]), Call(ImportedFunctionSymbol("printf", 0), [Constant("%d\n"), Variable("a")])) + ] + + def test_init_with_call(self, ast_call_init): + self.run_rbr(ast_call_init, _generate_options(rename_for=True)) + + code_node = ast_call_init.root.children[0] + assert isinstance(code_node, CodeNode) + assert code_node.instructions == [Assignment(Variable("a"), Constant(5))] + + for_loop_node = ast_call_init.root.children[1] + assert isinstance(for_loop_node, ForLoopNode) + assert for_loop_node.declaration == Assignment(Variable("i"), Call(ImportedFunctionSymbol("foo", 0), [])) + assert for_loop_node.modification == Assignment(Variable("i"), BinaryOperation(OperationType.plus, [Variable("i"), Constant(1)])) + + loop_node_body = for_loop_node.body + assert isinstance(loop_node_body, CodeNode) + assert loop_node_body.instructions == [ + Assignment(Variable("a"), BinaryOperation(OperationType.plus, [Variable("a"), Variable("i")])) + ] + + assert for_loop_node.condition == logic_cond("x1", context := LogicCondition.generate_new_context()) + assert ast_call_init.condition_map == { + logic_cond("x1", context): Condition(OperationType.less_or_equal, [Variable("i"), Constant(5)]) + } + + def test_double_init(self, ast_redundant_init): + self.run_rbr(ast_redundant_init) + + code_node = ast_redundant_init.root.children[0] + assert isinstance(code_node, CodeNode) + assert code_node.instructions == [ + Assignment(Variable("b"), Constant(0)), + Assignment(Variable("a"), Constant(5)), + Assignment(Variable("b"), Constant(2)), + ] + + for_loop_node = ast_redundant_init.root.children[1] + assert isinstance(for_loop_node, ForLoopNode) + assert for_loop_node.declaration == Variable("b") + assert for_loop_node.modification == Assignment(Variable("b"), BinaryOperation(OperationType.plus, [Variable("b"), Constant(1)])) + + loop_node_body = for_loop_node.body + assert isinstance(loop_node_body, CodeNode) + assert loop_node_body.instructions == [ + Assignment(Variable("a"), BinaryOperation(OperationType.plus, [Variable("a"), Variable("b")])), + ] + + assert for_loop_node.condition == logic_cond("x1", context := LogicCondition.generate_new_context()) + assert ast_redundant_init.condition_map == {logic_cond("x1", context): Condition(OperationType.less, [Variable("b"), Constant(5)])} + + def test_double_init_condition_node(self, ast_reinit_in_condition_true): + self.run_rbr(ast_reinit_in_condition_true) + + def test_init_in_switch(self, ast_init_in_switch): + self.run_rbr(ast_init_in_switch) + + init_code_node = ast_init_in_switch.root.children[0] + assert isinstance(init_code_node, CodeNode) + assert init_code_node.instructions == [Assignment(Variable("a"), Constant(5)), Assignment(Variable("b"), Constant(0))] + + loop_node = ast_init_in_switch.root.children[2] + assert isinstance(loop_node, ForLoopNode) + assert loop_node.declaration == Variable("b") + assert loop_node.modification == Assignment(Variable("b"), BinaryOperation(OperationType.plus, [Variable("b"), Constant(1)])) + + loop_node_body = loop_node.body + assert isinstance(loop_node_body, CodeNode) + assert loop_node_body.instructions == [ + Assignment(Variable("a"), BinaryOperation(OperationType.plus, [Variable("a"), Variable("b")])) + ] + + def test_usage_in_condition(self, ast_usage_in_condition): + self.run_rbr(ast_usage_in_condition) + + code_node = ast_usage_in_condition.root.children[0] + assert isinstance(code_node, CodeNode) + assert code_node.instructions == [Assignment(Variable("a"), Constant(1)), Assignment(Variable("b"), Constant(0))] + + condition_node = ast_usage_in_condition.root.children[1] + assert isinstance(condition_node, ConditionNode) + assert condition_node.condition == logic_cond("x2", context := LogicCondition.generate_new_context()) + + loop_node = ast_usage_in_condition.root.children[2] + assert isinstance(loop_node, ForLoopNode) + assert loop_node.declaration == Variable("b") + assert loop_node.condition == logic_cond("x1", context) + assert loop_node.modification == Assignment(Variable("b"), BinaryOperation(OperationType.plus, [Variable("b"), Constant(1)])) + + loop_body = loop_node.body + assert isinstance(loop_body, CodeNode) + assert loop_body.instructions == [Assignment(Variable("a"), BinaryOperation(OperationType.multiply, [Variable("a"), Constant(2)]))] + + def test_while_in_else(self, ast_while_in_else): + self.run_rbr(ast_while_in_else) + + endless_loop = ast_while_in_else.root + assert isinstance(endless_loop, WhileLoopNode) + + condition_node = endless_loop.body + assert isinstance(condition_node, ConditionNode) + + loop_node = condition_node.false_branch_child + assert isinstance(loop_node, ForLoopNode) + assert loop_node.declaration == Assignment(Variable("i"), Constant(0)) + assert loop_node.modification == Assignment(Variable("i"), BinaryOperation(OperationType.plus, [Variable("i"), Constant(1)])) + + loop_node_body = loop_node.body + assert isinstance(loop_node_body, CodeNode) + assert loop_node_body.instructions == [ + Assignment(ListOperation([]), Call(ImportedFunctionSymbol("printf", 0), [Constant("%d\n"), Variable("i")])) + ] + + def test_nested_while(self, ast_nested_while): + self.run_rbr(ast_nested_while, _generate_options(empty_loops=True)) + + outer_loop = ast_nested_while.root + assert isinstance(outer_loop, ForLoopNode) + assert outer_loop.declaration == Assignment(Variable("i"), Constant(0)) + assert ast_nested_while.condition_map[outer_loop.condition] == Condition(OperationType.less, [Variable("i"), Constant(5)]) + assert outer_loop.modification == Assignment(Variable("i"), BinaryOperation(OperationType.plus, [Variable("i"), Constant(1)])) + + inner_loop = outer_loop.children[0] + assert isinstance(inner_loop, ForLoopNode) + assert inner_loop.declaration == Assignment(Variable("j"), Constant(0)) + assert ast_nested_while.condition_map[inner_loop.condition] == Condition(OperationType.less, [Variable("j"), Constant(5)]) + assert inner_loop.modification == Assignment(Variable("j"), BinaryOperation(OperationType.plus, [Variable("j"), Constant(1)])) + + def test_nested_while_loop(self, ast_endless_while_init_outside): + self.run_rbr(ast_endless_while_init_outside) + + endless_loop = ast_endless_while_init_outside.root.children[1] + assert isinstance(endless_loop, WhileLoopNode) + + for_loop = endless_loop.body + assert isinstance(for_loop, ForLoopNode) + assert for_loop.declaration == Variable("a") + + def test_sequenced_loops(self, ast_sequenced_while_loops): + self.run_rbr(ast_sequenced_while_loops) + + loop_1 = ast_sequenced_while_loops.root.children[0] + assert isinstance(loop_1, ForLoopNode) + assert loop_1.declaration == Assignment(Variable("i"), Constant(0)) + assert loop_1.modification == Assignment(Variable("i"), BinaryOperation(OperationType.plus, [Variable("i"), Constant(1)])) + + loop_1_body = loop_1.body + assert isinstance(loop_1_body, CodeNode) + assert loop_1_body.instructions == [ + Assignment(ListOperation([]), Call(ImportedFunctionSymbol("printf", 0), [Constant("%d\n"), Variable("i")])), + ] + + loop_2 = ast_sequenced_while_loops.root.children[1] + assert isinstance(loop_2, ForLoopNode) + assert loop_2.declaration == Assignment(Variable("j"), Constant(0)) + assert loop_2.modification == Assignment(Variable("j"), BinaryOperation(OperationType.plus, [Variable("j"), Constant(1)])) + + loop_2_body = loop_2.body + assert isinstance(loop_2_body, CodeNode) + assert loop_2_body.instructions == [ + Assignment(ListOperation([]), Call(ImportedFunctionSymbol("printf", 0), [Constant("%d\n"), Variable("j")])), + ] + + def test_switch_as_loop_body(self, ast_switch_as_loop_body): + self.run_rbr(ast_switch_as_loop_body) + + assert all(not isinstance(node, ForLoopNode) for node in ast_switch_as_loop_body.topological_order()) + + init_code_node = ast_switch_as_loop_body.root.children[0] + assert isinstance(init_code_node, CodeNode) + assert init_code_node.instructions == [Assignment(Variable("a"), Constant(5)), Assignment(Variable("counter"), Constant(0))] + + while_node = ast_switch_as_loop_body.root.children[1] + assert isinstance(while_node, WhileLoopNode) + + switch_node = while_node.body + assert isinstance(switch_node, SwitchNode) + + case_1_body = switch_node.children[0].child + assert isinstance(case_1_body, CodeNode) + assert case_1_body.instructions == [ + Assignment(Variable("a"), BinaryOperation(OperationType.plus, [Variable("a"), Variable("counter")])) + ] + + case_2_body = switch_node.children[1].child + assert isinstance(case_2_body, CodeNode) + assert case_2_body.instructions == [ + Assignment(Variable("counter"), BinaryOperation(OperationType.plus, [Variable("counter"), Constant(1)])) + ] + + def test_switch_as_loop_with_increment(self, ast_switch_as_loop_body_increment): + self.run_rbr(ast_switch_as_loop_body_increment) + + init_code_node = ast_switch_as_loop_body_increment.root.children[0] + assert isinstance(init_code_node, CodeNode) + assert init_code_node.instructions == [Assignment(Variable("a"), Constant(5))] + + loop_node = ast_switch_as_loop_body_increment.root.children[1] + assert isinstance(loop_node, ForLoopNode) + assert loop_node.declaration == Assignment(Variable("i"), Constant(0)) + assert loop_node.modification == Assignment(Variable("i"), BinaryOperation(OperationType.plus, [Variable("i"), Constant(1)])) + + switch_node = loop_node.body + assert isinstance(switch_node, SwitchNode) + + case_1 = switch_node.children[0] + assert isinstance(case_1, CaseNode) + + case_1_body = case_1.child + assert isinstance(case_1_body, CodeNode) + assert case_1_body.instructions == [Assignment(Variable("a"), BinaryOperation(OperationType.plus, [Variable("a"), Variable("i")]))] + + case_2 = switch_node.children[1] + assert isinstance(case_2, CaseNode) + + case_2_body = case_2.child + assert isinstance(case_2_body, CodeNode) + assert case_2_body.instructions == [Assignment(Variable("i"), BinaryOperation(OperationType.plus, [Variable("i"), Constant(1)]))] + + assert ast_switch_as_loop_body_increment.condition_map == { + logic_cond("x1", LogicCondition.generate_new_context()): Condition(OperationType.less, [Variable("i"), Constant(5)]) + } + + def test_rename_unary_operation_updates_array_info(self, ast_array_access_for_loop): + """Test if UnaryOperation.ArrayInfo gets updated on renaming""" + self.run_rbr(ast_array_access_for_loop, _generate_options(rename_for=True)) + + def find_unary_op(ast): + """look for UnaryOperation in AST""" + for node in ast.get_code_nodes_topological_order(): + for instr in node.instructions: + for unary_op in instr: + if isinstance(unary_op, UnaryOperation): + return unary_op + return None + + unary_operation = find_unary_op(ast_array_access_for_loop) + if not isinstance(unary_operation, UnaryOperation): + assert False, "Did not find UnaryOperation in AST (expect it for array access)" + assert unary_operation.array_info is not None + assert unary_operation.array_info.base in unary_operation.requirements + assert unary_operation.array_info.index in unary_operation.requirements + + def test_no_for_loop_renaming(self, ast_replaceable_while): + self.run_rbr(ast_replaceable_while, _generate_options(rename_for=False)) + + assert ast_replaceable_while.condition_map == { + logic_cond("x1", LogicCondition.generate_new_context()): Condition(OperationType.less, [Variable("a"), Constant(10)]) + } + + loop_node = ast_replaceable_while.root + assert isinstance(loop_node, ForLoopNode) + assert loop_node.declaration == Assignment(Variable("a"), Constant(0)) + assert loop_node.modification == Assignment(Variable("a"), BinaryOperation(OperationType.plus, [Variable("a"), Constant(1)])) + + loop_body = loop_node.body + assert isinstance(loop_body, CodeNode) + assert loop_body.instructions == [ + Assignment(ListOperation([]), Call(ImportedFunctionSymbol("printf", 0), [Constant("counter: %d\n"), Variable("a")])), + ] + + def test_init_may_not_reach_loop_1(self, ast_initialization_in_condition): + assert ( + _initialization_reaches_loop_node( + ast_initialization_in_condition.root.children[0].true_branch_child, ast_initialization_in_condition.root.children[1] + ) + is False + ) + + self.run_rbr(ast_initialization_in_condition, _generate_options()) + assert any( + isinstance(for_loop := loop, ForLoopNode) for loop in ast_initialization_in_condition.get_for_loop_nodes_topological_order() + ) + assert for_loop.declaration == Variable("a") + + def test_init_may_not_reach_loop_2(self, ast_initialization_in_condition_sequence): + assert ( + _initialization_reaches_loop_node( + ast_initialization_in_condition_sequence.root.children[0].true_branch_child.children[1], + ast_initialization_in_condition_sequence.root.children[1], + ) + is False + ) + + self.run_rbr(ast_initialization_in_condition_sequence, _generate_options()) + assert any( + isinstance(for_loop := loop, ForLoopNode) + for loop in ast_initialization_in_condition_sequence.get_for_loop_nodes_topological_order() + ) + assert for_loop.declaration == Variable("a") + + @pytest.mark.parametrize("keep_empty_for_loops", [True, False]) + def test_keep_empty_for_loop(self, keep_empty_for_loops: bool, ast_single_instruction_while): + self.run_rbr(ast_single_instruction_while, _generate_options(keep_empty_for_loops)) + + if keep_empty_for_loops: + assert isinstance(ast_single_instruction_while.root, ForLoopNode) + else: + assert isinstance(ast_single_instruction_while.root.children[1], WhileLoopNode) diff --git a/tests/pipeline/controlflowanalysis/test_readability_based_refinement.py b/tests/pipeline/controlflowanalysis/test_readability_based_refinement.py index eecd157a3..7cd0c82f4 100644 --- a/tests/pipeline/controlflowanalysis/test_readability_based_refinement.py +++ b/tests/pipeline/controlflowanalysis/test_readability_based_refinement.py @@ -1,16 +1,13 @@ from typing import List import pytest -from decompiler.pipeline.controlflowanalysis.readability_based_refinement import ( - ForLoopVariableRenamer, - ReadabilityBasedRefinement, - WhileLoopReplacer, - WhileLoopVariableRenamer, +from decompiler.pipeline.controlflowanalysis.loop_utility_methods import ( _find_continuation_instruction, _has_deep_requirement, _initialization_reaches_loop_node, ) -from decompiler.structures.ast.ast_nodes import CaseNode, CodeNode, ConditionNode, ForLoopNode, SeqNode, SwitchNode, WhileLoopNode +from decompiler.pipeline.controlflowanalysis.readability_based_refinement import ReadabilityBasedRefinement, WhileLoopReplacer +from decompiler.structures.ast.ast_nodes import ConditionNode, ForLoopNode, SeqNode, WhileLoopNode from decompiler.structures.ast.syntaxtree import AbstractSyntaxTree from decompiler.structures.logic.logic_condition import LogicCondition from decompiler.structures.pseudo import ( @@ -26,7 +23,7 @@ OperationType, Variable, ) -from decompiler.structures.pseudo.operations import ArrayInfo, OperationType, UnaryOperation +from decompiler.structures.pseudo.operations import OperationType from decompiler.task import DecompilerTask from decompiler.util.options import Options @@ -35,15 +32,19 @@ def logic_cond(name: str, context) -> LogicCondition: return LogicCondition.initialize_symbol(name, context) -def _generate_options(empty_loops: bool = False, hide_decl: bool = False, rename_for: bool = True, rename_while: bool = True, \ - max_condition: int = 100, max_modification: int = 100, force_for_loops: bool = False, blacklist : List[str] = []) -> Options: +def _generate_options( + restructure: bool = True, + empty_loops: bool = False, + hide_decl: bool = False, + max_condition: int = 100, + max_modification: int = 100, + force_for_loops: bool = False, + blacklist: List[str] = [], +) -> Options: options = Options() + options.set("readability-based-refinement.restructure_for_loops", restructure) options.set("readability-based-refinement.keep_empty_for_loops", empty_loops) options.set("readability-based-refinement.hide_non_initializing_declaration", hide_decl) - if rename_for: - names = ["i", "j", "k", "l", "m", "n"] - options.set("readability-based-refinement.for_loop_variable_names", names) - options.set("readability-based-refinement.rename_while_loop_variables", rename_while) options.set("readability-based-refinement.max_condition_complexity_for_loop_recovery", max_condition) options.set("readability-based-refinement.max_modification_complexity_for_loop_recovery", max_modification) options.set("readability-based-refinement.force_for_loops", force_for_loops) @@ -51,905 +52,6 @@ def _generate_options(empty_loops: bool = False, hide_decl: bool = False, rename return options -@pytest.fixture -def ast_array_access_for_loop() -> AbstractSyntaxTree: - """ - for (var_0 = 0; var_0 < 10; var_0 = var_0 + 1) { - *(var_1 + var_0) = var_0; - } - """ - true_value = LogicCondition.initialize_true(context := LogicCondition.generate_new_context()) - ast = AbstractSyntaxTree( - root := SeqNode(true_value), - condition_map={logic_cond("x1", context): Condition(OperationType.less, [Variable("var_0"), Constant(10)])}, - ) - declaration = Assignment(Variable("var_0"), Constant(0)) - condition = logic_cond("x1", context) - modification = Assignment(Variable("var_0"), BinaryOperation(OperationType.plus, [Variable("var_0"), Constant(1)])) - for_loop = ast.factory.create_for_loop_node(declaration, condition, modification) - array_info = ArrayInfo(Variable("var_1"), Variable("var_0")) - array_access_unary_operation = UnaryOperation( - OperationType.dereference, [BinaryOperation(OperationType.plus, [Variable("var_1"), Variable("var_0")])], array_info=array_info - ) - for_loop_body = ast._add_code_node([Assignment(array_access_unary_operation, Variable("var_0"))]) - ast._add_node(for_loop) - ast._add_edges_from([(root, for_loop), (for_loop, for_loop_body)]) - return ast - - -@pytest.fixture -def ast_while_true() -> AbstractSyntaxTree: - """ - a = 0; - b = 0; - while(true){ - a = a + 1; - b = b + 1; - } - """ - true_value = LogicCondition.initialize_true(LogicCondition.generate_new_context()) - ast = AbstractSyntaxTree(root := SeqNode(true_value), {}) - code_node = ast._add_code_node([Assignment(Variable("a"), Constant(0)), Assignment(Variable("b"), Constant(0))]) - loop_node_body = ast._add_code_node( - [ - Assignment(Variable("a"), BinaryOperation(OperationType.plus, [Variable("a"), Constant(1)])), - Assignment(Variable("b"), BinaryOperation(OperationType.plus, [Variable("b"), Constant(1)])), - ] - ) - loop_node = ast.add_endless_loop_with_body(loop_node_body) - ast._add_edges_from(((root, code_node), (root, loop_node))) - return ast - - -@pytest.fixture -def ast_single_instruction_while() -> AbstractSyntaxTree: - """ - a = 0; - while (a < 10) { - a = a + 1; - } - """ - true_value = LogicCondition.initialize_true(context := LogicCondition.generate_new_context()) - ast = AbstractSyntaxTree( - root := SeqNode(true_value), condition_map={logic_cond("x1", context): Condition(OperationType.less, [Variable("a"), Constant(10)])} - ) - init_code_node = ast._add_code_node([Assignment(Variable("a"), Constant(0))]) - while_loop = ast.factory.create_while_loop_node(logic_cond("x1", context)) - while_loop_body = ast._add_code_node([Assignment(Variable("a"), BinaryOperation(OperationType.plus, [Variable("a"), Constant(1)]))]) - ast._add_node(while_loop) - ast._add_edges_from([(root, init_code_node), (root, while_loop), (while_loop, while_loop_body)]) - return ast - - -@pytest.fixture -def ast_replaceable_while() -> AbstractSyntaxTree: - """ - a = 0; - while (x < 10) { - printf("counter: %d", x); - a = a + 1; - } - """ - true_value = LogicCondition.initialize_true(context := LogicCondition.generate_new_context()) - ast = AbstractSyntaxTree( - root := SeqNode(true_value), condition_map={logic_cond("x1", context): Condition(OperationType.less, [Variable("a"), Constant(10)])} - ) - - init_code_node = ast._add_code_node([Assignment(Variable("a"), Constant(0))]) - - while_loop = ast.factory.create_while_loop_node(logic_cond("x1", context)) - while_loop_body = ast._add_code_node( - [ - Assignment(ListOperation([]), Call(ImportedFunctionSymbol("printf", 0), [Constant("counter: %d\n"), Variable("a")])), - Assignment(Variable("a"), BinaryOperation(OperationType.plus, [Variable("a"), Constant(1)])), - ] - ) - - ast._add_node(while_loop) - ast._add_edges_from([(root, init_code_node), (root, while_loop), (while_loop, while_loop_body)]) - root._sorted_children = (init_code_node, while_loop) - return ast - - -@pytest.fixture -def ast_replaceable_while_usage() -> AbstractSyntaxTree: - """ - a = 0; - while (a < 10) { - printf("counter: %d", a); - a = a + 1; - } - printf("final counter: %d", a); - """ - true_value = LogicCondition.initialize_true(context := LogicCondition.generate_new_context()) - ast = AbstractSyntaxTree( - root := SeqNode(true_value), condition_map={logic_cond("x1", context): Condition(OperationType.less, [Variable("a"), Constant(10)])} - ) - - init_code_node = ast._add_code_node([Assignment(Variable("a"), Constant(0))]) - - while_loop = ast.factory.create_while_loop_node(logic_cond("x1", context)) - while_loop_body = ast._add_code_node( - [ - Assignment(ListOperation([]), Call(ImportedFunctionSymbol("printf", 0), [Constant("counter: %d\n"), Variable("a")])), - Assignment(Variable("a"), BinaryOperation(OperationType.plus, [Variable("a"), Constant(1)])), - ] - ) - - exit_code_node = ast._add_code_node( - [Assignment(ListOperation([]), Call(ImportedFunctionSymbol("printf", 0), [Constant("final counter: %d"), Variable("a")]))] - ) - - ast._add_node(while_loop) - ast._add_edges_from([(root, init_code_node), (root, while_loop), (root, exit_code_node), (while_loop, while_loop_body)]) - return ast - - -@pytest.fixture -def ast_replaceable_while_reinit_usage() -> AbstractSyntaxTree: - """ - a = 0; - while (a < 10) { - printf("counter: %d", a); - a = a + 1; - } - a = 50; - printf("50 = %d", a); - """ - true_value = LogicCondition.initialize_true(context := LogicCondition.generate_new_context()) - ast = AbstractSyntaxTree( - root := SeqNode(true_value), condition_map={logic_cond("x1", context): Condition(OperationType.less, [Variable("a"), Constant(10)])} - ) - - init_code_node = ast._add_code_node([Assignment(Variable("a"), Constant(0))]) - - while_loop = ast.factory.create_while_loop_node(logic_cond("x1", context)) - while_loop_body = ast._add_code_node( - [ - Assignment(ListOperation([]), Call(ImportedFunctionSymbol("printf", 0), [Constant("counter: %d\n"), Variable("a")])), - Assignment(Variable("a"), BinaryOperation(OperationType.plus, [Variable("a"), Constant(1)])), - ] - ) - - exit_code_node = ast._add_code_node( - [ - Assignment(Variable("a"), Constant(50)), - Assignment(ListOperation([]), Call(ImportedFunctionSymbol("printf", 0), [Constant("final counter: %d"), Variable("a")])), - ] - ) - - ast._add_node(while_loop) - ast._add_edges_from([(root, init_code_node), (root, while_loop), (root, exit_code_node), (while_loop, while_loop_body)]) - return ast - - -@pytest.fixture -def ast_replaceable_while_compound_usage() -> AbstractSyntaxTree: - """ - a = 0; - while (a < 10) { - printf("counter: %d", a); - a = a + 1; - } - a = a + 50; - printf("50 = %d", a); - """ - true_value = LogicCondition.initialize_true(context := LogicCondition.generate_new_context()) - ast = AbstractSyntaxTree( - root := SeqNode(true_value), condition_map={logic_cond("x1", context): Condition(OperationType.less, [Variable("a"), Constant(10)])} - ) - - init_code_node = ast._add_code_node([Assignment(Variable("a"), Constant(0))]) - - while_loop = ast.factory.create_while_loop_node(logic_cond("x1", context)) - while_loop_body = ast._add_code_node( - [ - Assignment(ListOperation([]), Call(ImportedFunctionSymbol("printf", 0), [Constant("counter: %d\n"), Variable("a")])), - Assignment(Variable("a"), BinaryOperation(OperationType.plus, [Variable("a"), Constant(1)])), - ] - ) - - exit_code_node = ast._add_code_node( - [ - Assignment(Variable("a"), BinaryOperation(OperationType.plus, [Variable("a"), Constant(50)])), - Assignment(ListOperation([]), Call(ImportedFunctionSymbol("printf", 0), [Constant("final counter: %d"), Variable("a")])), - ] - ) - - ast._add_node(while_loop) - ast._add_edges_from([(root, init_code_node), (root, while_loop), (root, exit_code_node), (while_loop, while_loop_body)]) - return ast - - -@pytest.fixture -def ast_endless_while_init_outside() -> AbstractSyntaxTree: - """ - a = 0; - while (true) { - while (a < 5) { - printf("%d\n", a); - a = a + 1; - } - } - """ - true_value = LogicCondition.initialize_true(context := LogicCondition.generate_new_context()) - ast = AbstractSyntaxTree( - root := SeqNode(true_value), condition_map={logic_cond("x1", context): Condition(OperationType.less, [Variable("a"), Constant(2)])} - ) - - init_code_node = ast._add_code_node([Assignment(Variable("a"), Constant(0))]) - - inner_while = ast.factory.create_while_loop_node(logic_cond("x1", context)) - ast._add_node(inner_while) - endless_loop = ast.add_endless_loop_with_body(inner_while) - - inner_while_body = ast._add_code_node( - [ - Assignment(ListOperation([]), Call(ImportedFunctionSymbol("printf", 0), [Constant("%d\n"), Variable("a")])), - Assignment(Variable("a"), BinaryOperation(OperationType.plus, [Variable("a"), Constant(1)])), - ] - ) - - ast._add_edges_from([(root, init_code_node), (root, endless_loop), (endless_loop, inner_while), (inner_while, inner_while_body)]) - return ast - - -@pytest.fixture -def ast_nested_while() -> AbstractSyntaxTree: - """ - a = 0; - while (a < 1) { - b = 0; - while (b < 1) { - b = b + 1; - } - a = a + 1; - } - """ - true_value = LogicCondition.initialize_true(context := LogicCondition.generate_new_context()) - ast = AbstractSyntaxTree( - root := SeqNode(true_value), - condition_map={ - logic_cond("x1", context): Condition(OperationType.less, [Variable("a"), Constant(5)]), - logic_cond("x2", context): Condition(OperationType.less, [Variable("b"), Constant(5)]), - }, - ) - - init_code_node = ast._add_code_node([Assignment(Variable("a"), Constant(0))]) - - outer_while = ast.factory.create_while_loop_node(logic_cond("x1", context)) - outer_while_body = ast.factory.create_seq_node() - outer_while_init = ast._add_code_node([Assignment(Variable("b"), Constant(0))]) - outer_while_exit = ast._add_code_node([Assignment(Variable("a"), BinaryOperation(OperationType.plus, [Variable("a"), Constant(1)]))]) - - inner_while = ast.factory.create_while_loop_node(logic_cond("x2", context)) - inner_while_body = ast._add_code_node([Assignment(Variable("b"), BinaryOperation(OperationType.plus, [Variable("b"), Constant(1)]))]) - - ast._add_nodes_from((outer_while, outer_while_body, inner_while)) - ast._add_edges_from( - [ - (root, init_code_node), - (root, outer_while), - (outer_while, outer_while_body), - (outer_while_body, outer_while_init), - (outer_while_body, inner_while), - (outer_while_body, outer_while_exit), - (inner_while, inner_while_body), - ] - ) - return ast - - -@pytest.fixture -def ast_call_init() -> AbstractSyntaxTree: - """ - a = 5; - b = foo(); - while(b <= 5){ - a = a + b; - b = b + 1; - } - """ - true_value = LogicCondition.initialize_true(context := LogicCondition.generate_new_context()) - ast = AbstractSyntaxTree( - root := SeqNode(true_value), - condition_map={logic_cond("x1", context): Condition(OperationType.less_or_equal, [Variable("b"), Constant(5)])}, - ) - code_node = ast._add_code_node( - instructions=[ - Assignment(Variable("a"), Constant(5)), - Assignment(ListOperation([Variable("b")]), Call(ImportedFunctionSymbol("foo", 0), [])), - ] - ) - loop_node = ast.factory.create_while_loop_node(condition=logic_cond("x1", context)) - loop_node_body = ast._add_code_node( - [ - Assignment(Variable("a"), BinaryOperation(OperationType.plus, [Variable("a"), Variable("b")])), - Assignment(Variable("b"), BinaryOperation(OperationType.plus, [Variable("b"), Constant(1)])), - ] - ) - ast._add_node(loop_node) - ast._add_edges_from(((root, code_node), (root, loop_node), (loop_node, loop_node_body))) - ast._code_node_reachability_graph.add_reachability(code_node, loop_node_body) - root._sorted_children = (code_node, loop_node) - return ast - - -@pytest.fixture -def ast_self_referential_init() -> AbstractSyntaxTree: - """ - a = 5; - b = foo(b); - while(b <= 5){ - a = a + b; - b = b + 1; - } - """ - true_value = LogicCondition.initialize_true(context := LogicCondition.generate_new_context()) - ast = AbstractSyntaxTree( - root := SeqNode(true_value), - condition_map={logic_cond("x1", context): Condition(OperationType.less_or_equal, [Variable("b"), Constant(5)])}, - ) - code_node = ast._add_code_node( - instructions=[ - Assignment(Variable("a"), Constant(5)), - Assignment(ListOperation([Variable("b")]), Call(ImportedFunctionSymbol("foo", 0), [Variable("b")])), - ] - ) - loop_node = ast.factory.create_while_loop_node(condition=logic_cond("x1", context)) - loop_node_body = ast._add_code_node( - [ - Assignment(Variable("a"), BinaryOperation(OperationType.plus, [Variable("a"), Variable("b")])), - Assignment(Variable("b"), BinaryOperation(OperationType.plus, [Variable("b"), Constant(1)])), - ] - ) - ast._add_node(loop_node) - ast._add_edges_from(((root, code_node), (root, loop_node), (loop_node, loop_node_body))) - ast._code_node_reachability_graph.add_reachability(code_node, loop_node_body) - root._sorted_children = (code_node, loop_node) - return ast - - -@pytest.fixture -def ast_call_for_loop() -> AbstractSyntaxTree: - """ - a = 5; - while(b = foo; b <= 5; b++){ - a++; - } - """ - true_value = LogicCondition.initialize_true(context := LogicCondition.generate_new_context()) - ast = AbstractSyntaxTree( - root := SeqNode(true_value), - condition_map={logic_cond("x1", context): Condition(OperationType.less_or_equal, [Variable("b"), Constant(5)])}, - ) - code_node = ast._add_code_node( - instructions=[ - Assignment(Variable("a"), Constant(5)), - ] - ) - loop_node = ast.factory.create_for_loop_node(Assignment(ListOperation([Variable("b")]), Call(ImportedFunctionSymbol("foo", 0), [])), logic_cond("x1", context), Assignment(Variable("b"), BinaryOperation(OperationType.plus, [Variable("b"), Constant(1)]))) - loop_node_body = ast._add_code_node( - [ - Assignment(Variable("a"), BinaryOperation(OperationType.plus, [Variable("a"), Variable("1")])), - ] - ) - ast._add_node(loop_node) - ast._add_edges_from(((root, code_node), (root, loop_node), (loop_node, loop_node_body))) - ast._code_node_reachability_graph.add_reachability(code_node, loop_node_body) - root._sorted_children = (code_node, loop_node) - return ast - - -@pytest.fixture -def ast_redundant_init() -> AbstractSyntaxTree: - """ - b = 0; - a = 5; - b = 2; - - while(b <= 5){ - a = a + b; - b = b + 1; - } - """ - true_value = LogicCondition.initialize_true(context := LogicCondition.generate_new_context()) - ast = AbstractSyntaxTree( - root := SeqNode(true_value), condition_map={logic_cond("x1", context): Condition(OperationType.less, [Variable("b"), Constant(5)])} - ) - code_node = ast._add_code_node( - instructions=[ - Assignment(Variable("b"), Constant(0)), - Assignment(Variable("a"), Constant(5)), - Assignment(Variable("b"), Constant(2)), - ] - ) - loop_node = ast.factory.create_while_loop_node(condition=logic_cond("x1", context)) - loop_node_body = ast._add_code_node( - [ - Assignment(Variable("a"), BinaryOperation(OperationType.plus, [Variable("a"), Variable("b")])), - Assignment(Variable("b"), BinaryOperation(OperationType.plus, [Variable("b"), Constant(1)])), - ] - ) - ast._add_node(loop_node) - ast._add_edges_from(((root, code_node), (root, loop_node), (loop_node, loop_node_body))) - ast._code_node_reachability_graph.add_reachability(code_node, loop_node_body) - root._sorted_children = (code_node, loop_node) - return ast - - -@pytest.fixture -def ast_reinit_in_condition_true() -> AbstractSyntaxTree: - """ - int x = 1; - int i = 0; - - if (x == 1) { - i = 1; - } - - while (i < 10) { - x = x * 2; - i = i + 1; - } - """ - true_value = LogicCondition.initialize_true(context := LogicCondition.generate_new_context()) - ast = AbstractSyntaxTree( - root := SeqNode(true_value), - condition_map={ - logic_cond("a", context): Condition(OperationType.less, [Variable("i"), Constant(10)]), - logic_cond("b", context): Condition(OperationType.equal, [Variable("x"), Constant(1)]), - }, - ) - code_node = ast._add_code_node(instructions=[Assignment(Variable("x"), Constant(1)), Assignment(Variable("i"), Constant(0))]) - code_node_true = ast._add_code_node([Assignment(Variable("i"), Constant(1))]) - condition_node = ast._add_condition_node_with(logic_cond("b", context), code_node_true) - loop_node = ast.factory.create_while_loop_node(condition=logic_cond("a", context)) - loop_node_body = ast._add_code_node( - [ - Assignment(Variable("x"), BinaryOperation(OperationType.multiply, [Variable("x"), Constant(2)])), - Assignment(Variable("i"), BinaryOperation(OperationType.plus, [Variable("i"), Constant(1)])), - ] - ) - ast._add_nodes_from((condition_node, loop_node)) - ast._add_edges_from(((root, code_node), (root, condition_node), (root, loop_node), (loop_node, loop_node_body))) - ast._code_node_reachability_graph.add_reachability(code_node, loop_node_body) - root._sorted_children = (code_node, loop_node) - return ast - - -@pytest.fixture -def ast_usage_in_condition() -> AbstractSyntaxTree: - """ - int a = 1; - int b = 0; - - if (b == 1) { - a = 1; - } - - while (b < 10) { - a = a * 2; - b = b + 1; - } - """ - true_value = LogicCondition.initialize_true(context := LogicCondition.generate_new_context()) - ast = AbstractSyntaxTree( - root := SeqNode(true_value), - condition_map={ - logic_cond("x1", context): Condition(OperationType.less, [Variable("b"), Constant(10)]), - logic_cond("x2", context): Condition(OperationType.equal, [Variable("b"), Constant(1)]), - }, - ) - init_code_node = ast._add_code_node([Assignment(Variable("a"), Constant(1)), Assignment(Variable("b"), Constant(0))]) - code_node_true = ast._add_code_node([Assignment(Variable("a"), Constant(1))]) - condition_node = ast._add_condition_node_with(logic_cond("x2", context), code_node_true) - loop_node = ast.factory.create_while_loop_node(condition=logic_cond("x1", context)) - loop_node_body = ast._add_code_node( - [ - Assignment(Variable("a"), BinaryOperation(OperationType.multiply, [Variable("a"), Constant(2)])), - Assignment(Variable("b"), BinaryOperation(OperationType.plus, [Variable("b"), Constant(1)])), - ] - ) - ast._add_node(loop_node) - ast._add_edges_from(((root, init_code_node), (root, condition_node), (root, loop_node), (loop_node, loop_node_body))) - ast._code_node_reachability_graph.add_reachability(init_code_node, loop_node_body) - root._sorted_children = (init_code_node, loop_node) - return ast - - -@pytest.fixture -def ast_sequenced_while_loops() -> AbstractSyntaxTree: - """ - a = 0; - b = 0; - - while (a < 5) { - printf("%d\n", a); - a++; - } - - while (b < 5) { - printf("%d\n", b); - b++; - } - """ - true_value = LogicCondition.initialize_true(context := LogicCondition.generate_new_context()) - ast = AbstractSyntaxTree( - root := SeqNode(true_value), - condition_map={ - logic_cond("x1", context): Condition(OperationType.less, [Variable("a"), Constant(5)]), - logic_cond("x2", context): Condition(OperationType.less, [Variable("b"), Constant(5)]), - }, - ) - - init_code_node = ast._add_code_node([Assignment(Variable("a"), Constant(0)), Assignment(Variable("b"), Constant(0))]) - - while_loop_1 = ast.factory.create_while_loop_node(logic_cond("x1", context)) - while_loop_1_body = ast._add_code_node( - [ - Assignment(ListOperation([]), Call(ImportedFunctionSymbol("printf", 0), [Constant("%d\n"), Variable("a")])), - Assignment(Variable("a"), BinaryOperation(OperationType.plus, [Variable("a"), Constant(1)])), - ] - ) - - while_loop_2 = ast.factory.create_while_loop_node(logic_cond("x2", context)) - while_loop_2_body = ast._add_code_node( - [ - Assignment(ListOperation([]), Call(ImportedFunctionSymbol("printf", 0), [Constant("%d\n"), Variable("b")])), - Assignment(Variable("b"), BinaryOperation(OperationType.plus, [Variable("b"), Constant(1)])), - ] - ) - - ast._add_nodes_from((while_loop_1, while_loop_2)) - ast._add_edges_from( - ( - (root, init_code_node), - (root, while_loop_1), - (root, while_loop_2), - (while_loop_1, while_loop_1_body), - (while_loop_2, while_loop_2_body), - ) - ) - return ast - - -@pytest.fixture -def ast_switch_as_loop_body() -> AbstractSyntaxTree: - """ - This while-loop should not be replaced with a for-loop because we don't know wich value 'a' has. - - Code of AST: - a = 5; - b = 0; - while(b <= 5){ - switch(a) { - case 0: - a = a + b: - break; - case 1: - b = b + 1; - break; - } - } - """ - true_value = LogicCondition.initialize_true(context := LogicCondition.generate_new_context()) - ast = AbstractSyntaxTree( - root := SeqNode(true_value), - condition_map={logic_cond("a", context): Condition(OperationType.less_or_equal, [Variable("b"), Constant(5)])}, - ) - code_node = ast._add_code_node([Assignment(Variable("a"), Constant(5)), Assignment(Variable("b"), Constant(0))]) - loop_node = ast.factory.create_while_loop_node(condition=logic_cond("a", context)) - root._sorted_children = (code_node, loop_node) - loop_body_switch = ast.factory.create_switch_node(Variable("a")) - loop_body_case_1 = ast.factory.create_case_node(Variable("a"), Constant(0), break_case=True) - code_node_case_1 = ast._add_code_node([Assignment(Variable("a"), BinaryOperation(OperationType.plus, [Variable("a"), Variable("b")]))]) - loop_body_case_2 = ast.factory.create_case_node(Variable("a"), Constant(1), break_case=True) - code_node_case_2 = ast._add_code_node( - [ - Assignment(Variable("b"), BinaryOperation(OperationType.plus, [Variable("b"), Constant(1)])), - ] - ) - ast._add_nodes_from((code_node, loop_node, loop_body_switch, loop_body_case_1, loop_body_case_2)) - ast._add_edges_from( - ( - (root, code_node), - (root, loop_node), - (loop_node, loop_body_switch), - (loop_body_switch, loop_body_case_1), - (loop_body_switch, loop_body_case_2), - (loop_body_case_1, code_node_case_1), - (loop_body_case_2, code_node_case_2), - ) - ) - ast._code_node_reachability_graph.add_reachability_from(((code_node, code_node_case_1), (code_node, code_node_case_2))) - return ast - - -@pytest.fixture -def ast_switch_as_loop_body_increment() -> AbstractSyntaxTree: - """ - This loop should be replaced with a for-loop because b has no usages after last definition, is in condition and is initialized - before loop without any usages in between. - - Code of AST: - a = 5; - b = 0; - while(b <= 5){ - switch(a) { - case 0: - a = a + b: - break; - case 1: - b = b + 1; - break; - } - b = b + 1; - } - """ - true_value = LogicCondition.initialize_true(context := LogicCondition.generate_new_context()) - ast = AbstractSyntaxTree( - root := SeqNode(true_value), condition_map={logic_cond("x1", context): Condition(OperationType.less, [Variable("b"), Constant(5)])} - ) - - init_code_node = ast._add_code_node([Assignment(Variable("a"), Constant(5)), Assignment(Variable("b"), Constant(0))]) - - while_loop = ast.factory.create_while_loop_node(logic_cond("x1", context)) - while_loop_seq = ast.factory.create_seq_node() - - switch_node = ast.factory.create_switch_node(Variable("a")) - case_1 = ast.factory.create_case_node(Variable("a"), Constant(0), break_case=True) - case_1_code = ast._add_code_node([Assignment(Variable("a"), BinaryOperation(OperationType.plus, [Variable("a"), Variable("b")]))]) - case_2 = ast.factory.create_case_node(Variable("a"), Constant(0), break_case=True) - case_2_code = ast._add_code_node([Assignment(Variable("b"), BinaryOperation(OperationType.plus, [Variable("b"), Constant(1)]))]) - - increment_code = ast._add_code_node([Assignment(Variable("b"), BinaryOperation(OperationType.plus, [Variable("b"), Constant(1)]))]) - - ast._add_nodes_from((while_loop, while_loop_seq, switch_node, case_1, case_2)) - ast._add_edges_from( - [ - (root, init_code_node), - (root, while_loop), - (while_loop, while_loop_seq), - (while_loop_seq, switch_node), - (while_loop_seq, increment_code), - (switch_node, case_1), - (switch_node, case_2), - (case_1, case_1_code), - (case_2, case_2_code), - ] - ) - return ast - - -@pytest.fixture -def ast_init_in_switch() -> AbstractSyntaxTree: - """ - a = 5; - b = 0; - switch(a){ - case 0: - a = b; - } - while(b <= (5 + a)){ - a = a + b; - b = b + 1; - } - """ - true_value = LogicCondition.initialize_true(context := LogicCondition.generate_new_context()) - ast = AbstractSyntaxTree( - root := SeqNode(true_value), - condition_map={ - logic_cond("x1", context): Condition( - OperationType.less_or_equal, - [Variable("b"), BinaryOperation(OperationType.plus, [Constant(5), Variable("a")])], - ) - }, - ) - init_code_node = ast._add_code_node(instructions=[Assignment(Variable("a"), Constant(5)), Assignment(Variable("b"), Constant(0))]) - switch_node = ast.factory.create_switch_node(Variable("a")) - loop_node = ast.factory.create_while_loop_node(condition=logic_cond("x1", context)) - case_node = ast.factory.create_case_node(Variable("a"), Constant(0)) - case_child = ast._add_code_node([Assignment(Variable("a"), Variable("b"))]) - loop_body = ast.factory.create_seq_node() - loop_body_child = ast._add_code_node( - [ - Assignment(Variable("a"), BinaryOperation(OperationType.plus, [Variable("a"), Variable("b")])), - Assignment(Variable("b"), BinaryOperation(OperationType.plus, [Variable("b"), Constant(1)])), - ] - ) - ast._add_nodes_from((switch_node, loop_node, loop_body, case_node)) - ast._add_edges_from( - ( - (root, init_code_node), - (root, switch_node), - (switch_node, case_node), - (case_node, case_child), - (root, loop_node), - (loop_node, loop_body), - (loop_body, loop_body_child), - ) - ) - ast._code_node_reachability_graph.add_reachability_from([(case_child, loop_body_child)]) - root._sorted_children = (init_code_node, switch_node, loop_node) - loop_body._sorted_children = (loop_body_child,) - switch_node._sorted_cases = (case_node,) - return ast - - -@pytest.fixture -def ast_while_in_else() -> AbstractSyntaxTree: - """ - while (true) { - if (b < 2) { - break; - } else { - a = 0; - while (a < 5) { - printf("%d\n", a); - a = a + 1; - } - } - } - """ - true_value = LogicCondition.initialize_true(context := LogicCondition.generate_new_context()) - ast = AbstractSyntaxTree( - root := SeqNode(true_value), - condition_map={ - logic_cond("x1", context): Condition(OperationType.less, [Variable("a"), Constant(2)]), - logic_cond("x2", context): Condition(OperationType.less, [Variable("b"), Constant(2)]), - }, - ) - - inner_while = ast.factory.create_while_loop_node(logic_cond("x1", context)) - ast._add_node(inner_while) - - true_branch_child = ast._add_code_node([Break()]) - inner_seq = ast.factory.create_seq_node() - ast._add_node(inner_seq) - condition_node = ast._add_condition_node_with(logic_cond("x2", context), true_branch_child, inner_seq) - - init_code_node = ast._add_code_node([Assignment(Variable("a"), Constant(0))]) - - endless_loop = ast.add_endless_loop_with_body(condition_node) - - inner_while_body = ast._add_code_node( - [ - Assignment(ListOperation([]), Call(ImportedFunctionSymbol("printf", 0), [Constant("%d\n"), Variable("a")])), - Assignment(Variable("a"), BinaryOperation(OperationType.plus, [Variable("a"), Constant(1)])), - ] - ) - - ast._add_edges_from( - [ - (root, endless_loop), - (endless_loop, condition_node), - (inner_seq, init_code_node), - (inner_seq, inner_while), - (inner_while, inner_while_body), - ] - ) - return ast - - -@pytest.fixture -def ast_continuation_is_not_first_var() -> AbstractSyntaxTree: - """ - a = 0; - b = 0; - while (a < b) { - printf("%d\n", a); - b = b + 1; - } - """ - true_value = LogicCondition.initialize_true(context := LogicCondition.generate_new_context()) - ast = AbstractSyntaxTree( - root := SeqNode(true_value), - condition_map={logic_cond("x1", context): Condition(OperationType.less, [Variable("a"), Variable("b")])}, - ) - - init_code_node = ast._add_code_node([Assignment(Variable("a"), Constant(0)), Assignment(Variable("b"), Constant(0))]) - - while_loop = ast.factory.create_while_loop_node(logic_cond("x1", context)) - while_loop_body = ast._add_code_node( - [ - Assignment(ListOperation([]), Call(ImportedFunctionSymbol("printf", 0), [Constant("%d\n"), Variable("a")])), - Assignment(Variable("b"), BinaryOperation(OperationType.plus, [Variable("b"), Constant(1)])), - ] - ) - - ast._add_node(while_loop) - ast._add_edges_from([(root, init_code_node), (root, while_loop), (while_loop, while_loop_body)]) - root._sorted_children = (init_code_node, while_loop) - return ast - - -@pytest.fixture -def ast_initialization_in_condition() -> AbstractSyntaxTree: - """ - if(b < 10 ){ - a = 5; - while (x < 10) { - printf("counter: %d", a); - a = a + 1; - } - """ - true_value = LogicCondition.initialize_true(context := LogicCondition.generate_new_context()) - ast = AbstractSyntaxTree( - root := SeqNode(true_value), - condition_map={ - logic_cond("x0", context): Condition(OperationType.less, [Variable("b"), Constant(10)]), - logic_cond("x1", context): Condition(OperationType.less, [Variable("a"), Constant(10)]), - }, - ) - - true_branch = ast._add_code_node([Assignment(Variable("a"), Constant(5))]) - if_condition = ast._add_condition_node_with(logic_cond("x0", context), true_branch) - while_loop = ast.factory.create_while_loop_node(logic_cond("x1", context)) - while_loop_body = ast._add_code_node( - [ - Assignment(ListOperation([]), Call(ImportedFunctionSymbol("printf", 0), [Constant("counter: %d\n"), Variable("a")])), - Assignment(Variable("a"), BinaryOperation(OperationType.plus, [Variable("a"), Constant(1)])), - ] - ) - - ast._add_node(while_loop) - ast._add_edges_from([(root, if_condition), (root, while_loop), (while_loop, while_loop_body)]) - root._sorted_children = (if_condition, while_loop) - return ast - - -@pytest.fixture -def ast_initialization_in_condition_sequence() -> AbstractSyntaxTree: - """ - if(b < 10 ){ - if(c < 10){ - b = 5; - } - a = 5; - while (x < 10) { - printf("counter: %d", a); - a = a + 1; - } - """ - true_value = LogicCondition.initialize_true(context := LogicCondition.generate_new_context()) - ast = AbstractSyntaxTree( - root := SeqNode(true_value), - condition_map={ - logic_cond("x0", context): Condition(OperationType.less, [Variable("b"), Constant(10)]), - logic_cond("x1", context): Condition(OperationType.less, [Variable("c"), Constant(10)]), - logic_cond("x2", context): Condition(OperationType.less, [Variable("a"), Constant(10)]), - }, - ) - - true_branch_c = ast._add_code_node([Assignment(Variable("b"), Constant(5))]) - code_node = ast._add_code_node([Assignment(Variable("a"), Constant(5))]) - if_condition_c = ast._add_condition_node_with(logic_cond("x1", context), true_branch_c) - ast._add_node(true_branch_b := ast.factory.create_seq_node()) - if_condition_b = ast._add_condition_node_with(logic_cond("x1", context), true_branch_b) - while_loop = ast.factory.create_while_loop_node(logic_cond("x2", context)) - while_loop_body = ast._add_code_node( - [ - Assignment(ListOperation([]), Call(ImportedFunctionSymbol("printf", 0), [Constant("counter: %d\n"), Variable("a")])), - Assignment(Variable("a"), BinaryOperation(OperationType.plus, [Variable("a"), Constant(1)])), - ] - ) - - ast._add_node(while_loop) - ast._add_edges_from( - [ - (root, if_condition_b), - (root, while_loop), - (while_loop, while_loop_body), - (true_branch_b, if_condition_c), - (true_branch_b, code_node), - ] - ) - true_branch_b._sorted_children = (if_condition_c, code_node) - root._sorted_children = (if_condition_b, while_loop) - return ast - - @pytest.fixture def ast_innerWhile_simple_condition_complexity() -> AbstractSyntaxTree: """ @@ -981,14 +83,19 @@ def ast_innerWhile_simple_condition_complexity() -> AbstractSyntaxTree: outer_while = ast.factory.create_while_loop_node(logic_cond("x1", context)) outer_while_body = ast.factory.create_seq_node() - outer_while_init = ast._add_code_node([Assignment(Variable("b"), Constant(0)), Assignment(Variable("c"), Constant(0)) - , Assignment(Variable("d"), Constant(0))]) + outer_while_init = ast._add_code_node( + [Assignment(Variable("b"), Constant(0)), Assignment(Variable("c"), Constant(0)), Assignment(Variable("d"), Constant(0))] + ) outer_while_exit = ast._add_code_node([Assignment(Variable("a"), BinaryOperation(OperationType.plus, [Variable("a"), Constant(1)]))]) inner_while = ast.factory.create_while_loop_node(logic_cond("x2", context) & logic_cond("x3", context) & logic_cond("x4", context)) - inner_while_body = ast._add_code_node([Assignment(Variable("b"), BinaryOperation(OperationType.plus, [Variable("b"), Constant(1)])), - Assignment(Variable("c"), BinaryOperation(OperationType.plus, [Variable("c"), Constant(1)])), - Assignment(Variable("d"), BinaryOperation(OperationType.plus, [Variable("d"), Constant(1)]))]) + inner_while_body = ast._add_code_node( + [ + Assignment(Variable("b"), BinaryOperation(OperationType.plus, [Variable("b"), Constant(1)])), + Assignment(Variable("c"), BinaryOperation(OperationType.plus, [Variable("c"), Constant(1)])), + Assignment(Variable("d"), BinaryOperation(OperationType.plus, [Variable("d"), Constant(1)])), + ] + ) ast._add_nodes_from((outer_while, outer_while_body, inner_while)) ast._add_edges_from( @@ -1005,7 +112,7 @@ def ast_innerWhile_simple_condition_complexity() -> AbstractSyntaxTree: return ast -def generate_ast_with_modification_complexity(complexity : int) -> AbstractSyntaxTree: +def generate_ast_with_modification_complexity(complexity: int) -> AbstractSyntaxTree: """ a = 0; while (a < 10) { @@ -1027,7 +134,7 @@ def generate_ast_with_modification_complexity(complexity : int) -> AbstractSynta return ast -def generate_ast_with_condition_type(op : OperationType) -> AbstractSyntaxTree: +def generate_ast_with_condition_type(op: OperationType) -> AbstractSyntaxTree: """ a = 0; while (a 10) { @@ -1067,9 +174,17 @@ def ast_guarded_do_while_if() -> AbstractSyntaxTree: ast._add_node(cond_node) ast._add_node(true_branch) ast._add_node(do_while_loop) - ast._add_edges_from([(root, init_code_node), (root, cond_node), (cond_node, true_branch), (true_branch, do_while_loop), (do_while_loop, do_while_loop_body)]) + ast._add_edges_from( + [ + (root, init_code_node), + (root, cond_node), + (cond_node, true_branch), + (true_branch, do_while_loop), + (do_while_loop, do_while_loop_body), + ] + ) return ast - + @pytest.fixture def ast_guarded_do_while_else() -> AbstractSyntaxTree: @@ -1094,419 +209,76 @@ def ast_guarded_do_while_else() -> AbstractSyntaxTree: ast._add_node(cond_node) ast._add_node(false_branch) ast._add_node(do_while_loop) - ast._add_edges_from([(root, init_code_node), (root, cond_node), (cond_node, false_branch), (false_branch, do_while_loop), (do_while_loop, do_while_loop_body)]) - return ast - - -class TestReadabilityBasedRefinement: - """Test Readability functionality with all its substages.""" - - @staticmethod - def run_rbr(ast: AbstractSyntaxTree, options: Options = _generate_options()): - ReadabilityBasedRefinement().run(DecompilerTask("func", cfg=None, ast=ast, options=options)) - - def test_no_replacement(self, ast_while_true): - self.run_rbr(ast_while_true) - assert all(not isinstance(node, ForLoopNode) for node in ast_while_true.topological_order()) - - def test_simple_replacement(self, ast_replaceable_while): - self.run_rbr(ast_replaceable_while) - - assert ast_replaceable_while.condition_map == { - logic_cond("x1", LogicCondition.generate_new_context()): Condition(OperationType.less, [Variable("i"), Constant(10)]) - } - - loop_node = ast_replaceable_while.root - assert isinstance(loop_node, ForLoopNode) - assert loop_node.declaration == Assignment(Variable("i"), Constant(0)) - assert loop_node.modification == Assignment(Variable("i"), BinaryOperation(OperationType.plus, [Variable("i"), Constant(1)])) - - loop_body = loop_node.body - assert isinstance(loop_body, CodeNode) - assert loop_body.instructions == [ - Assignment(ListOperation([]), Call(ImportedFunctionSymbol("printf", 0), [Constant("counter: %d\n"), Variable("i")])), - ] - - def test_with_usage(self, ast_replaceable_while_usage): - self.run_rbr(ast_replaceable_while_usage) - - for_loop = ast_replaceable_while_usage.root.children[0] - assert isinstance(for_loop, ForLoopNode) - assert for_loop.declaration == Assignment(Variable("i"), Constant(0)) - - copy_instr_node = ast_replaceable_while_usage.root.children[1] - assert isinstance(copy_instr_node, CodeNode) - assert copy_instr_node.instructions == [Assignment(Variable("a"), Variable("i"))] - - def test_with_usage_redefinition(self, ast_replaceable_while_reinit_usage): - self.run_rbr(ast_replaceable_while_reinit_usage) - - for_loop = ast_replaceable_while_reinit_usage.root.children[0] - assert isinstance(for_loop, ForLoopNode) - assert for_loop.declaration == Assignment(Variable("i"), Constant(0)) - assert for_loop.modification == Assignment(Variable("i"), BinaryOperation(OperationType.plus, [Variable("i"), Constant(1)])) - - exit_code_node = ast_replaceable_while_reinit_usage.root.children[1] - assert isinstance(exit_code_node, CodeNode) - assert exit_code_node.instructions == [ - Assignment(Variable("a"), Constant(50)), - Assignment(ListOperation([]), Call(ImportedFunctionSymbol("printf", 0), [Constant("final counter: %d"), Variable("a")])), - ] - - def test_with_usage_redefenition_2(self, ast_replaceable_while_compound_usage): - self.run_rbr(ast_replaceable_while_compound_usage) - - for_loop = ast_replaceable_while_compound_usage.root.children[0] - assert isinstance(for_loop, ForLoopNode) - assert for_loop.declaration == Assignment(Variable("i"), Constant(0)) - assert for_loop.modification == Assignment(Variable("i"), BinaryOperation(OperationType.plus, [Variable("i"), Constant(1)])) - - copy_instr_node = ast_replaceable_while_compound_usage.root.children[1] - assert isinstance(copy_instr_node, CodeNode) - assert copy_instr_node.instructions == [Assignment(Variable("a"), Variable("i"))] - - def test_continuation_is_not_first_var(self, ast_continuation_is_not_first_var): - self.run_rbr(ast_continuation_is_not_first_var) - - init_code_node = ast_continuation_is_not_first_var.root.children[0] - assert isinstance(init_code_node, CodeNode) - assert init_code_node.instructions == [Assignment(Variable("a"), Constant(0))] - - loop_node = ast_continuation_is_not_first_var.root.children[1] - assert isinstance(loop_node, ForLoopNode) - assert loop_node.declaration == Assignment(Variable("i"), Constant(0)) - assert loop_node.modification == Assignment(Variable("i"), BinaryOperation(OperationType.plus, [Variable("i"), Constant(1)])) - - loop_node_body = loop_node.body - assert isinstance(loop_node_body, CodeNode) - assert loop_node_body.instructions == [ - Assignment(ListOperation([]), Call(ImportedFunctionSymbol("printf", 0), [Constant("%d\n"), Variable("a")])) + ast._add_edges_from( + [ + (root, init_code_node), + (root, cond_node), + (cond_node, false_branch), + (false_branch, do_while_loop), + (do_while_loop, do_while_loop_body), ] + ) + return ast - def test_init_with_call(self, ast_call_init): - self.run_rbr(ast_call_init, _generate_options(rename_for=True)) - - code_node = ast_call_init.root.children[0] - assert isinstance(code_node, CodeNode) - assert code_node.instructions == [Assignment(Variable("a"), Constant(5))] - - for_loop_node = ast_call_init.root.children[1] - assert isinstance(for_loop_node, ForLoopNode) - assert for_loop_node.declaration == Assignment(Variable("i"), Call(ImportedFunctionSymbol("foo", 0), [])) - assert for_loop_node.modification == Assignment(Variable("i"), BinaryOperation(OperationType.plus, [Variable("i"), Constant(1)])) - - loop_node_body = for_loop_node.body - assert isinstance(loop_node_body, CodeNode) - assert loop_node_body.instructions == [ - Assignment(Variable("a"), BinaryOperation(OperationType.plus, [Variable("a"), Variable("i")])) - ] - assert for_loop_node.condition == logic_cond("x1", context := LogicCondition.generate_new_context()) - assert ast_call_init.condition_map == { - logic_cond("x1", context): Condition(OperationType.less_or_equal, [Variable("i"), Constant(5)]) +@pytest.fixture +def ast_while_in_else() -> AbstractSyntaxTree: + """ + while (true) { + if (b < 2) { + break; + } else { + a = 0; + while (a < 5) { + printf("%d\n", a); + a = a + 1; + } } + } + """ + true_value = LogicCondition.initialize_true(context := LogicCondition.generate_new_context()) + ast = AbstractSyntaxTree( + root := SeqNode(true_value), + condition_map={ + logic_cond("x1", context): Condition(OperationType.less, [Variable("a"), Constant(2)]), + logic_cond("x2", context): Condition(OperationType.less, [Variable("b"), Constant(2)]), + }, + ) - def test_double_init(self, ast_redundant_init): - self.run_rbr(ast_redundant_init) - - code_node = ast_redundant_init.root.children[0] - assert isinstance(code_node, CodeNode) - assert code_node.instructions == [ - Assignment(Variable("b"), Constant(0)), - Assignment(Variable("a"), Constant(5)), - Assignment(Variable("b"), Constant(2)), - ] - - for_loop_node = ast_redundant_init.root.children[1] - assert isinstance(for_loop_node, ForLoopNode) - assert for_loop_node.declaration == Variable("b") - assert for_loop_node.modification == Assignment(Variable("b"), BinaryOperation(OperationType.plus, [Variable("b"), Constant(1)])) - - loop_node_body = for_loop_node.body - assert isinstance(loop_node_body, CodeNode) - assert loop_node_body.instructions == [ - Assignment(Variable("a"), BinaryOperation(OperationType.plus, [Variable("a"), Variable("b")])), - ] - - assert for_loop_node.condition == logic_cond("x1", context := LogicCondition.generate_new_context()) - assert ast_redundant_init.condition_map == {logic_cond("x1", context): Condition(OperationType.less, [Variable("b"), Constant(5)])} - - def test_double_init_condition_node(self, ast_reinit_in_condition_true): - self.run_rbr(ast_reinit_in_condition_true) - - def test_init_in_switch(self, ast_init_in_switch): - self.run_rbr(ast_init_in_switch) - - init_code_node = ast_init_in_switch.root.children[0] - assert isinstance(init_code_node, CodeNode) - assert init_code_node.instructions == [Assignment(Variable("a"), Constant(5)), Assignment(Variable("b"), Constant(0))] - - loop_node = ast_init_in_switch.root.children[2] - assert isinstance(loop_node, ForLoopNode) - assert loop_node.declaration == Variable("b") - assert loop_node.modification == Assignment(Variable("b"), BinaryOperation(OperationType.plus, [Variable("b"), Constant(1)])) - - loop_node_body = loop_node.body - assert isinstance(loop_node_body, CodeNode) - assert loop_node_body.instructions == [ - Assignment(Variable("a"), BinaryOperation(OperationType.plus, [Variable("a"), Variable("b")])) - ] - - def test_usage_in_condition(self, ast_usage_in_condition): - self.run_rbr(ast_usage_in_condition) - - code_node = ast_usage_in_condition.root.children[0] - assert isinstance(code_node, CodeNode) - assert code_node.instructions == [Assignment(Variable("a"), Constant(1)), Assignment(Variable("b"), Constant(0))] - - condition_node = ast_usage_in_condition.root.children[1] - assert isinstance(condition_node, ConditionNode) - assert condition_node.condition == logic_cond("x2", context := LogicCondition.generate_new_context()) - - loop_node = ast_usage_in_condition.root.children[2] - assert isinstance(loop_node, ForLoopNode) - assert loop_node.declaration == Variable("b") - assert loop_node.condition == logic_cond("x1", context) - assert loop_node.modification == Assignment(Variable("b"), BinaryOperation(OperationType.plus, [Variable("b"), Constant(1)])) - - loop_body = loop_node.body - assert isinstance(loop_body, CodeNode) - assert loop_body.instructions == [Assignment(Variable("a"), BinaryOperation(OperationType.multiply, [Variable("a"), Constant(2)]))] - - def test_while_in_else(self, ast_while_in_else): - self.run_rbr(ast_while_in_else) - - endless_loop = ast_while_in_else.root - assert isinstance(endless_loop, WhileLoopNode) - - condition_node = endless_loop.body - assert isinstance(condition_node, ConditionNode) - - loop_node = condition_node.false_branch_child - assert isinstance(loop_node, ForLoopNode) - assert loop_node.declaration == Assignment(Variable("i"), Constant(0)) - assert loop_node.modification == Assignment(Variable("i"), BinaryOperation(OperationType.plus, [Variable("i"), Constant(1)])) - - loop_node_body = loop_node.body - assert isinstance(loop_node_body, CodeNode) - assert loop_node_body.instructions == [ - Assignment(ListOperation([]), Call(ImportedFunctionSymbol("printf", 0), [Constant("%d\n"), Variable("i")])) - ] - - def test_nested_while(self, ast_nested_while): - self.run_rbr(ast_nested_while, _generate_options(empty_loops=True)) - - outer_loop = ast_nested_while.root - assert isinstance(outer_loop, ForLoopNode) - assert outer_loop.declaration == Assignment(Variable("i"), Constant(0)) - assert ast_nested_while.condition_map[outer_loop.condition] == Condition(OperationType.less, [Variable("i"), Constant(5)]) - assert outer_loop.modification == Assignment(Variable("i"), BinaryOperation(OperationType.plus, [Variable("i"), Constant(1)])) - - inner_loop = outer_loop.children[0] - assert isinstance(inner_loop, ForLoopNode) - assert inner_loop.declaration == Assignment(Variable("j"), Constant(0)) - assert ast_nested_while.condition_map[inner_loop.condition] == Condition(OperationType.less, [Variable("j"), Constant(5)]) - assert inner_loop.modification == Assignment(Variable("j"), BinaryOperation(OperationType.plus, [Variable("j"), Constant(1)])) - - def test_nested_while_loop(self, ast_endless_while_init_outside): - self.run_rbr(ast_endless_while_init_outside) - - endless_loop = ast_endless_while_init_outside.root.children[1] - assert isinstance(endless_loop, WhileLoopNode) - - for_loop = endless_loop.body - assert isinstance(for_loop, ForLoopNode) - assert for_loop.declaration == Variable("a") - - def test_sequenced_loops(self, ast_sequenced_while_loops): - self.run_rbr(ast_sequenced_while_loops) - - loop_1 = ast_sequenced_while_loops.root.children[0] - assert isinstance(loop_1, ForLoopNode) - assert loop_1.declaration == Assignment(Variable("i"), Constant(0)) - assert loop_1.modification == Assignment(Variable("i"), BinaryOperation(OperationType.plus, [Variable("i"), Constant(1)])) - - loop_1_body = loop_1.body - assert isinstance(loop_1_body, CodeNode) - assert loop_1_body.instructions == [ - Assignment(ListOperation([]), Call(ImportedFunctionSymbol("printf", 0), [Constant("%d\n"), Variable("i")])), - ] - - loop_2 = ast_sequenced_while_loops.root.children[1] - assert isinstance(loop_2, ForLoopNode) - assert loop_2.declaration == Assignment(Variable("j"), Constant(0)) - assert loop_2.modification == Assignment(Variable("j"), BinaryOperation(OperationType.plus, [Variable("j"), Constant(1)])) - - loop_2_body = loop_2.body - assert isinstance(loop_2_body, CodeNode) - assert loop_2_body.instructions == [ - Assignment(ListOperation([]), Call(ImportedFunctionSymbol("printf", 0), [Constant("%d\n"), Variable("j")])), - ] - - def test_switch_as_loop_body(self, ast_switch_as_loop_body): - self.run_rbr(ast_switch_as_loop_body) - - assert all(not isinstance(node, ForLoopNode) for node in ast_switch_as_loop_body.topological_order()) - - init_code_node = ast_switch_as_loop_body.root.children[0] - assert isinstance(init_code_node, CodeNode) - assert init_code_node.instructions == [Assignment(Variable("a"), Constant(5)), Assignment(Variable("counter"), Constant(0))] + inner_while = ast.factory.create_while_loop_node(logic_cond("x1", context)) + ast._add_node(inner_while) - while_node = ast_switch_as_loop_body.root.children[1] - assert isinstance(while_node, WhileLoopNode) + true_branch_child = ast._add_code_node([Break()]) + inner_seq = ast.factory.create_seq_node() + ast._add_node(inner_seq) + condition_node = ast._add_condition_node_with(logic_cond("x2", context), true_branch_child, inner_seq) - switch_node = while_node.body - assert isinstance(switch_node, SwitchNode) + init_code_node = ast._add_code_node([Assignment(Variable("a"), Constant(0))]) - case_1_body = switch_node.children[0].child - assert isinstance(case_1_body, CodeNode) - assert case_1_body.instructions == [ - Assignment(Variable("a"), BinaryOperation(OperationType.plus, [Variable("a"), Variable("counter")])) - ] + endless_loop = ast.add_endless_loop_with_body(condition_node) - case_2_body = switch_node.children[1].child - assert isinstance(case_2_body, CodeNode) - assert case_2_body.instructions == [ - Assignment(Variable("counter"), BinaryOperation(OperationType.plus, [Variable("counter"), Constant(1)])) + inner_while_body = ast._add_code_node( + [ + Assignment(ListOperation([]), Call(ImportedFunctionSymbol("printf", 0), [Constant("%d\n"), Variable("a")])), + Assignment(Variable("a"), BinaryOperation(OperationType.plus, [Variable("a"), Constant(1)])), ] + ) - def test_switch_as_loop_with_increment(self, ast_switch_as_loop_body_increment): - self.run_rbr(ast_switch_as_loop_body_increment) - - init_code_node = ast_switch_as_loop_body_increment.root.children[0] - assert isinstance(init_code_node, CodeNode) - assert init_code_node.instructions == [Assignment(Variable("a"), Constant(5))] - - loop_node = ast_switch_as_loop_body_increment.root.children[1] - assert isinstance(loop_node, ForLoopNode) - assert loop_node.declaration == Assignment(Variable("i"), Constant(0)) - assert loop_node.modification == Assignment(Variable("i"), BinaryOperation(OperationType.plus, [Variable("i"), Constant(1)])) - - switch_node = loop_node.body - assert isinstance(switch_node, SwitchNode) - - case_1 = switch_node.children[0] - assert isinstance(case_1, CaseNode) - - case_1_body = case_1.child - assert isinstance(case_1_body, CodeNode) - assert case_1_body.instructions == [Assignment(Variable("a"), BinaryOperation(OperationType.plus, [Variable("a"), Variable("i")]))] - - case_2 = switch_node.children[1] - assert isinstance(case_2, CaseNode) - - case_2_body = case_2.child - assert isinstance(case_2_body, CodeNode) - assert case_2_body.instructions == [Assignment(Variable("i"), BinaryOperation(OperationType.plus, [Variable("i"), Constant(1)]))] - - assert ast_switch_as_loop_body_increment.condition_map == { - logic_cond("x1", LogicCondition.generate_new_context()): Condition(OperationType.less, [Variable("i"), Constant(5)]) - } - - def test_rename_unary_operation_updates_array_info(self, ast_array_access_for_loop): - """Test if UnaryOperation.ArrayInfo gets updated on renaming""" - self.run_rbr(ast_array_access_for_loop, _generate_options(rename_for=True)) - - def find_unary_op(ast): - """look for UnaryOperation in AST""" - for node in ast.get_code_nodes_topological_order(): - for instr in node.instructions: - for unary_op in instr: - if isinstance(unary_op, UnaryOperation): - return unary_op - return None - - unary_operation = find_unary_op(ast_array_access_for_loop) - if not isinstance(unary_operation, UnaryOperation): - assert False, "Did not find UnaryOperation in AST (expect it for array access)" - assert unary_operation.array_info is not None - assert unary_operation.array_info.base in unary_operation.requirements - assert unary_operation.array_info.index in unary_operation.requirements - - def test_no_for_loop_renaming(self, ast_replaceable_while): - self.run_rbr(ast_replaceable_while, _generate_options(rename_for=False)) - - assert ast_replaceable_while.condition_map == { - logic_cond("x1", LogicCondition.generate_new_context()): Condition(OperationType.less, [Variable("a"), Constant(10)]) - } - - loop_node = ast_replaceable_while.root - assert isinstance(loop_node, ForLoopNode) - assert loop_node.declaration == Assignment(Variable("a"), Constant(0)) - assert loop_node.modification == Assignment(Variable("a"), BinaryOperation(OperationType.plus, [Variable("a"), Constant(1)])) - - loop_body = loop_node.body - assert isinstance(loop_body, CodeNode) - assert loop_body.instructions == [ - Assignment(ListOperation([]), Call(ImportedFunctionSymbol("printf", 0), [Constant("counter: %d\n"), Variable("a")])), + ast._add_edges_from( + [ + (root, endless_loop), + (endless_loop, condition_node), + (inner_seq, init_code_node), + (inner_seq, inner_while), + (inner_while, inner_while_body), ] - - def test_init_may_not_reach_loop_1(self, ast_initialization_in_condition): - assert ( - _initialization_reaches_loop_node( - ast_initialization_in_condition.root.children[0].true_branch_child, ast_initialization_in_condition.root.children[1] - ) - is False - ) - - self.run_rbr(ast_initialization_in_condition, _generate_options()) - assert any( - isinstance(for_loop := loop, ForLoopNode) for loop in ast_initialization_in_condition.get_for_loop_nodes_topological_order() - ) - assert for_loop.declaration == Variable("a") - - def test_init_may_not_reach_loop_2(self, ast_initialization_in_condition_sequence): - assert ( - _initialization_reaches_loop_node( - ast_initialization_in_condition_sequence.root.children[0].true_branch_child.children[1], - ast_initialization_in_condition_sequence.root.children[1], - ) - is False - ) - - self.run_rbr(ast_initialization_in_condition_sequence, _generate_options()) - assert any( - isinstance(for_loop := loop, ForLoopNode) - for loop in ast_initialization_in_condition_sequence.get_for_loop_nodes_topological_order() - ) - assert for_loop.declaration == Variable("a") - - def test_guarded_do_while_if(self, ast_guarded_do_while_if): - self.run_rbr(ast_guarded_do_while_if, _generate_options()) - - for cond_node in ast_guarded_do_while_if.get_condition_nodes_post_order(): - assert False, "There should be no condition node" - - for loop_node in ast_guarded_do_while_if.get_loop_nodes_post_order(): - assert isinstance(loop_node, WhileLoopNode) - - def test_guarded_do_while_else(self, ast_guarded_do_while_else): - self.run_rbr(ast_guarded_do_while_else, _generate_options()) - - for cond_node in ast_guarded_do_while_else.get_condition_nodes_post_order(): - assert False, "There should be no condition node" - - for loop_node in ast_guarded_do_while_else.get_loop_nodes_post_order(): - assert isinstance(loop_node, WhileLoopNode) - - @pytest.mark.parametrize("keep_empty_for_loops", [True, False]) - def test_keep_empty_for_loop(self, keep_empty_for_loops: bool, ast_single_instruction_while): - self.run_rbr(ast_single_instruction_while, _generate_options(keep_empty_for_loops)) - - if keep_empty_for_loops: - assert isinstance(ast_single_instruction_while.root, ForLoopNode) - else: - assert isinstance(ast_single_instruction_while.root.children[1], WhileLoopNode) - - def test_rhs_of_for_loop_declaration_not_renamed(self, ast_self_referential_init: AbstractSyntaxTree): - self.run_rbr(ast_self_referential_init) - for_loops = list(ast_self_referential_init.get_for_loop_nodes_topological_order()) - assert len(for_loops) == 1 - assert for_loops[0].declaration == Assignment(Variable("i"), Call(ImportedFunctionSymbol("foo", 0), [Variable("b")])) + ) + return ast class TestForLoopRecovery: - """ Test options for for-loop recovery """ + """Test options for for-loop recovery""" + @staticmethod def run_rbr(ast: AbstractSyntaxTree, options: Options = _generate_options()): ReadabilityBasedRefinement().run(DecompilerTask("func", cfg=None, ast=ast, options=options)) @@ -1520,7 +292,6 @@ def test_max_condition_complexity(self, ast_innerWhile_simple_condition_complexi else: assert isinstance(loop_node, WhileLoopNode) - @pytest.mark.parametrize("modification_nesting", [1, 2]) def test_max_modification_complexity(self, modification_nesting): ast = generate_ast_with_modification_complexity(modification_nesting) @@ -1534,11 +305,10 @@ def test_max_modification_complexity(self, modification_nesting): assert isinstance(loop_node, WhileLoopNode) for condition_variable in loop_node.get_required_variables(ast.condition_map): instruction = _find_continuation_instruction(ast, loop_node, condition_variable) - assert instruction is not None + assert instruction is not None assert instruction.instruction.complexity > max_modi_complexity - - @pytest.mark.parametrize("operation", [OperationType.equal, OperationType.not_equal ,OperationType.less_or_equal, OperationType.less]) + @pytest.mark.parametrize("operation", [OperationType.equal, OperationType.not_equal, OperationType.less_or_equal, OperationType.less]) def test_for_loop_recovery_blacklist(self, operation): ast = generate_ast_with_condition_type(operation) forbidden_conditon_types = ["not_equal", "equal"] @@ -1550,6 +320,39 @@ def test_for_loop_recovery_blacklist(self, operation): else: assert isinstance(loop_node, ForLoopNode) + @pytest.mark.parametrize("restructure", [True, False]) + def test_restructure_for_loop_option(self, restructure, ast_while_in_else): + self.run_rbr(ast_while_in_else, _generate_options(restructure=restructure)) + for_loop = list(ast_while_in_else.get_for_loop_nodes_topological_order()) + if restructure: + assert len(for_loop) == 1 + else: + assert len(for_loop) == 0 + + +class TestGuardedDoWhile: + @staticmethod + def run_rbr(ast: AbstractSyntaxTree, options: Options = _generate_options()): + ReadabilityBasedRefinement().run(DecompilerTask("func", cfg=None, ast=ast, options=options)) + + def test_guarded_do_while_if(self, ast_guarded_do_while_if): + self.run_rbr(ast_guarded_do_while_if, _generate_options()) + + for _ in ast_guarded_do_while_if.get_condition_nodes_post_order(): + assert False, "There should be no condition node" + + for loop_node in ast_guarded_do_while_if.get_loop_nodes_post_order(): + assert isinstance(loop_node, WhileLoopNode) + + def test_guarded_do_while_else(self, ast_guarded_do_while_else): + self.run_rbr(ast_guarded_do_while_else, _generate_options()) + + for _ in ast_guarded_do_while_else.get_condition_nodes_post_order(): + assert False, "There should be no condition node" + + for loop_node in ast_guarded_do_while_else.get_loop_nodes_post_order(): + assert isinstance(loop_node, WhileLoopNode) + class TestReadabilityUtils: def test_find_continuation_instruction_1(self): @@ -2045,41 +848,6 @@ def test_separated_by_loop_node_4(self, ast_while_in_else): assert _initialization_reaches_loop_node(init_code_node, inner_while) is False - def test_for_loop_variable_generation(self): - renamer = ForLoopVariableRenamer( - AbstractSyntaxTree(SeqNode(LogicCondition.initialize_true(LogicCondition.generate_new_context())), {}), - ["i", "j", "k", "l", "m", "n"] - ) - assert [renamer._get_variable_name() for _ in range(14)] == [ - "i", - "j", - "k", - "l", - "m", - "n", - "i1", - "j1", - "k1", - "l1", - "m1", - "n1", - "i2", - "j2", - ] - - def test_while_loop_variable_generation(self): - renamer = WhileLoopVariableRenamer( - AbstractSyntaxTree(SeqNode(LogicCondition.initialize_true(LogicCondition.generate_new_context())), {}) - ) - assert [renamer._get_variable_name() for _ in range(5)] == ["counter", "counter1", "counter2", "counter3", "counter4"] - - def test_declaration_listop(self, ast_call_for_loop): - """Test renaming with ListOperation as Declaration""" - ForLoopVariableRenamer(ast_call_for_loop, ["i"]).rename() - for node in ast_call_for_loop: - if isinstance(node, ForLoopNode): - assert node.declaration.destination.operands[0].name == "i" - def test_skip_for_loop_recovery_if_continue_in_while(self): """ a = 0 @@ -2096,15 +864,12 @@ def test_skip_for_loop_recovery_if_continue_in_while(self): root := SeqNode(true_value), condition_map={ logic_cond("x1", context): Condition(OperationType.less, [Variable("a"), Constant(10)]), - logic_cond("x2", context): Condition(OperationType.equal, [Variable("a"), Constant(2)]) - } + logic_cond("x2", context): Condition(OperationType.equal, [Variable("a"), Constant(2)]), + }, ) true_branch = ast._add_code_node( - [ - Assignment(Variable("a"), BinaryOperation(OperationType.plus, [Variable("a"), Constant(2)])), - Continue() - ] + [Assignment(Variable("a"), BinaryOperation(OperationType.plus, [Variable("a"), Constant(2)])), Continue()] ) if_condition = ast._add_condition_node_with(logic_cond("x2", context), true_branch) @@ -2112,7 +877,9 @@ def test_skip_for_loop_recovery_if_continue_in_while(self): while_loop = ast.factory.create_while_loop_node(logic_cond("x1", context)) while_loop_body = ast.factory.create_seq_node() - while_loop_iteration = ast._add_code_node([Assignment(Variable("a"), BinaryOperation(OperationType.plus, [Variable("a"), Constant(1)]))]) + while_loop_iteration = ast._add_code_node( + [Assignment(Variable("a"), BinaryOperation(OperationType.plus, [Variable("a"), Constant(1)]))] + ) ast._add_node(while_loop) ast._add_node(while_loop_body) @@ -2122,7 +889,7 @@ def test_skip_for_loop_recovery_if_continue_in_while(self): (root, while_loop), (while_loop, while_loop_body), (while_loop_body, if_condition), - (while_loop_body, while_loop_iteration) + (while_loop_body, while_loop_iteration), ] ) @@ -2149,28 +916,31 @@ def test_skip_for_loop_recovery_if_continue_in_nested_while(self): condition_map={ logic_cond("x1", context): Condition(OperationType.less, [Variable("a"), Constant(5)]), logic_cond("x2", context): Condition(OperationType.less, [Variable("b"), Constant(10)]), - logic_cond("x3", context): Condition(OperationType.less, [Variable("b"), Constant(0)]) - } + logic_cond("x3", context): Condition(OperationType.less, [Variable("b"), Constant(0)]), + }, ) true_branch = ast._add_code_node( - [ - Assignment(Variable("b"), BinaryOperation(OperationType.plus, [Variable("b"), Constant(2)])), - Continue() - ] + [Assignment(Variable("b"), BinaryOperation(OperationType.plus, [Variable("b"), Constant(2)])), Continue()] ) if_condition = ast._add_condition_node_with(logic_cond("x3", context), true_branch) while_loop_outer = ast.factory.create_while_loop_node(logic_cond("x1", context)) while_loop_body_outer = ast.factory.create_seq_node() - while_loop_iteration_outer_1 = ast._add_code_node([Assignment(Variable("a"), BinaryOperation(OperationType.plus, [Variable("a"), Variable("b")]))]) - while_loop_iteration_outer_2 = ast._add_code_node([Assignment(Variable("a"), BinaryOperation(OperationType.plus, [Variable("a"), Constant(1)]))]) + while_loop_iteration_outer_1 = ast._add_code_node( + [Assignment(Variable("a"), BinaryOperation(OperationType.plus, [Variable("a"), Variable("b")]))] + ) + while_loop_iteration_outer_2 = ast._add_code_node( + [Assignment(Variable("a"), BinaryOperation(OperationType.plus, [Variable("a"), Constant(1)]))] + ) ast._add_node(while_loop_outer) ast._add_node(while_loop_body_outer) while_loop_inner = ast.factory.create_while_loop_node(logic_cond("x2", context)) while_loop_body_inner = ast.factory.create_seq_node() - while_loop_iteration_inner = ast._add_code_node([Assignment(Variable("b"), BinaryOperation(OperationType.plus, [Variable("b"), Constant(1)]))]) + while_loop_iteration_inner = ast._add_code_node( + [Assignment(Variable("b"), BinaryOperation(OperationType.plus, [Variable("b"), Constant(1)]))] + ) ast._add_node(while_loop_inner) ast._add_node(while_loop_body_inner) @@ -2183,10 +953,10 @@ def test_skip_for_loop_recovery_if_continue_in_nested_while(self): (while_loop_body_outer, while_loop_iteration_outer_2), (while_loop_inner, while_loop_body_inner), (while_loop_body_inner, if_condition), - (while_loop_body_inner, while_loop_iteration_inner) + (while_loop_body_inner, while_loop_iteration_inner), ] ) WhileLoopReplacer(ast, _generate_options()).run() loop_nodes = list(ast.get_loop_nodes_post_order()) - assert not isinstance(loop_nodes[0], ForLoopNode) and isinstance(loop_nodes[1], ForLoopNode) \ No newline at end of file + assert not isinstance(loop_nodes[0], ForLoopNode) and isinstance(loop_nodes[1], ForLoopNode) diff --git a/tests/structures/pseudo/test_complextypes.py b/tests/structures/pseudo/test_complextypes.py index 3bad97d60..c5a7d5c13 100644 --- a/tests/structures/pseudo/test_complextypes.py +++ b/tests/structures/pseudo/test_complextypes.py @@ -190,7 +190,7 @@ def blue(): class TestComplexTypeMap: def test_declarations(self, complex_types: ComplexTypeMap, book: Struct, color: Enum, record_id: Union): assert complex_types.declarations() == f"{book.declaration()};\n{color.declaration()};\n{record_id.declaration()};" - complex_types.add(book) + complex_types.add(book, 0) assert complex_types.declarations() == f"{book.declaration()};\n{color.declaration()};\n{record_id.declaration()};" def test_retrieve_by_name(self, complex_types: ComplexTypeMap, book: Struct, color: Enum, record_id: Union): @@ -201,7 +201,7 @@ def test_retrieve_by_name(self, complex_types: ComplexTypeMap, book: Struct, col @pytest.fixture def complex_types(self, book: Struct, color: Enum, record_id: Union): complex_types = ComplexTypeMap() - complex_types.add(book) - complex_types.add(color) - complex_types.add(record_id) + complex_types.add(book, 0) + complex_types.add(color, 1) + complex_types.add(record_id, 2) return complex_types