diff --git a/decompiler/pipeline/dataflowanalysis/common_subexpression_elimination.py b/decompiler/pipeline/dataflowanalysis/common_subexpression_elimination.py index 91848bbbb..0184bf74f 100644 --- a/decompiler/pipeline/dataflowanalysis/common_subexpression_elimination.py +++ b/decompiler/pipeline/dataflowanalysis/common_subexpression_elimination.py @@ -36,7 +36,10 @@ class CfgInstruction: instruction: Instruction block: BasicBlock - index: int + + @property + def index(self): + return next(index for index, instruction in enumerate(self.block.instructions) if id(instruction) == id(self.instruction)) @dataclass() @@ -221,7 +224,7 @@ def from_cfg(cls, cfg: ControlFlowGraph) -> DefinitionGenerator: usages: DefaultDict[Expression, Counter[CfgInstruction]] = defaultdict(Counter) for basic_block in cfg: for index, instruction in enumerate(basic_block.instructions): - instruction_with_position = CfgInstruction(instruction, basic_block, index) + instruction_with_position = CfgInstruction(instruction, basic_block) for subexpression in _subexpression_dfs(instruction): usages[subexpression][instruction_with_position] += 1 return cls(usages, cfg.dominator_tree) @@ -236,7 +239,7 @@ def define(self, expression: Expression, variable: Variable): basic_block, index = self._find_location_for_insertion(expression) for usage in self._usages[expression]: usage.instruction.substitute(expression, variable.copy()) - self._insert_definition(CfgInstruction(Assignment(variable, expression), basic_block, index)) + self._insert_definition(Assignment(variable, expression), basic_block, index) def _find_location_for_insertion(self, expression) -> Tuple[BasicBlock, int]: """ @@ -265,18 +268,19 @@ def _is_invalid_dominator(self, basic_block: BasicBlock, expression: Expression) usages_in_the_same_block = [usage for usage in self._usages[expression] if usage.block == basic_block] return any([isinstance(usage.instruction, Phi) for usage in usages_in_the_same_block]) - def _insert_definition(self, definition: CfgInstruction): + def _insert_definition(self, instruction: Instruction, block: BasicBlock, index: int): """Insert a new intermediate definition for the given expression at the given location.""" - definition.block.instructions.insert(definition.index, definition.instruction) - for subexpression in _subexpression_dfs(definition.instruction): - self._usages[subexpression][definition] += 1 + block.instructions.insert(index, instruction) + cfg_instruction = CfgInstruction(instruction, block) + for subexpression in _subexpression_dfs(instruction): + self._usages[subexpression][cfg_instruction] += 1 @staticmethod def _find_insertion_index(basic_block: BasicBlock, usages: Iterable[CfgInstruction]) -> int: """Find the first index in the given basic block where a definition could be inserted.""" - usage = min((usage for usage in usages if usage.block == basic_block), default=None, key=lambda x: x.index) - if usage: - return basic_block.instructions.index(usage.instruction, usage.index) + first_usage_index = min((usage.index for usage in usages if usage.block == basic_block), default=None) + if first_usage_index is not None: + return first_usage_index if not basic_block.instructions: return 0 if isinstance(basic_block.instructions[-1], GenericBranch): diff --git a/tests/pipeline/dataflowanalysis/test_common_subexpression_elimination.py b/tests/pipeline/dataflowanalysis/test_common_subexpression_elimination.py index 009f70a8a..f5e3bc0bf 100644 --- a/tests/pipeline/dataflowanalysis/test_common_subexpression_elimination.py +++ b/tests/pipeline/dataflowanalysis/test_common_subexpression_elimination.py @@ -1283,3 +1283,32 @@ def test_common_subexpression_elimination_correct_place(): Assignment(Variable("d", ssa_label=4), BinaryOperation(OperationType.plus, [Variable("e", ssa_label=2), replacement0])), Assignment(Variable("f", ssa_label=1), BinaryOperation(OperationType.minus, [Variable("g", ssa_label=4), replacement1])), ] + + +def test_common_subexpression_elimination_correct_place2(): + """Check that the instruction is inserted at the correct position""" + expr0 = BinaryOperation(OperationType.multiply, [Variable("a", Integer.int32_t()), Constant(2, Integer.int32_t())]) + expr1 = BinaryOperation(OperationType.plus, [expr0.copy(), Constant(5, Integer.int32_t())]) + + cfg = ControlFlowGraph() + cfg.add_node( + node := BasicBlock( + 0, + instructions=[ + Assignment(ListOperation([]), Call(FunctionSymbol("func", 0), [expr0.copy(), expr1.copy()])), + Assignment(ListOperation([]), Call(FunctionSymbol("func", 0), [expr1.copy()])), + ], + ) + ) + _run_cse(cfg, _generate_options(threshold=2)) + + replacement0 = Variable("c0", Integer.int32_t(), ssa_label=0) + replacement1 = Variable("c1", Integer.int32_t(), ssa_label=0) + expr_new = BinaryOperation(OperationType.plus, [replacement1.copy(), Constant(5, Integer.int32_t())]) + + assert node.instructions == [ + Assignment(replacement1.copy(), expr0.copy()), + Assignment(replacement0.copy(), expr_new.copy()), + Assignment(ListOperation([]), Call(FunctionSymbol("func", 0), [replacement1.copy(), replacement0.copy()])), + Assignment(ListOperation([]), Call(FunctionSymbol("func", 0), [replacement0.copy()])), + ]