diff --git a/decompiler/pipeline/dataflowanalysis/common_subexpression_elimination.py b/decompiler/pipeline/dataflowanalysis/common_subexpression_elimination.py index 60f64f332..bbe2d3fe3 100644 --- a/decompiler/pipeline/dataflowanalysis/common_subexpression_elimination.py +++ b/decompiler/pipeline/dataflowanalysis/common_subexpression_elimination.py @@ -1,7 +1,7 @@ """Module implementing common subexpression elimination.""" from __future__ import annotations -from collections import defaultdict, deque +from collections import Counter, defaultdict, deque from dataclasses import dataclass from itertools import chain from logging import info, warning @@ -17,7 +17,7 @@ from networkx import dfs_postorder_nodes -@dataclass(frozen=True) +@dataclass(frozen=True, eq=False) class CfgInstruction: """ dataclass in charge of tracking the location of Instruction objects in the cfg @@ -199,7 +199,7 @@ class DefinitionGenerator: def __init__( self, - expression_usages: DefaultDict[Expression, List[CfgInstruction]], + expression_usages: DefaultDict[Expression, Counter[CfgInstruction]], dominator_tree: NetworkXGraph, ): """Generate a new instance based on data parsed from a cfg.""" @@ -209,16 +209,16 @@ def __init__( @classmethod def from_cfg(cls, cfg: ControlFlowGraph) -> DefinitionGenerator: """Initialize a DefinitionGenerator utilizing the data of the given cfg.""" - usages: DefaultDict[Expression, List[CfgInstruction]] = defaultdict(list) + usages: DefaultDict[Expression, Counter[CfgInstruction]] = defaultdict(Counter) for basic_block in cfg: for index, instruction in enumerate(basic_block.instructions): instruction_with_position = CfgInstruction(instruction, basic_block, index) for subexpression in _subexpression_dfs(instruction): - usages[subexpression].append(instruction_with_position) + usages[subexpression][instruction_with_position] += 1 return cls(usages, cfg.dominator_tree) @property - def usages(self) -> DefaultDict[Expression, List[CfgInstruction]]: + def usages(self) -> DefaultDict[Expression, Counter[CfgInstruction]]: """Return a mapping from expressions to a set of instructions using them.""" return self._usages @@ -240,7 +240,7 @@ def _find_location_for_insertion(self, expression) -> Tuple[BasicBlock, int]: candidate: BasicBlock = next(iter(usage_blocks)) while not self._is_common_dominator(candidate, usage_blocks) or self._is_invalid_dominator(candidate, expression): candidate = self._dominator_tree.get_predecessors(candidate)[0] - return candidate, self._find_insertion_index(candidate, set(self._usages[expression])) + return candidate, self._find_insertion_index(candidate, self._usages[expression].keys()) # not a set... def _is_common_dominator(self, candidate: BasicBlock, basic_blocks: Set[BasicBlock]) -> bool: """Check if the given candidate is the common dominator all of given basic blocks.""" @@ -260,7 +260,7 @@ def _insert_definition(self, definition: CfgInstruction): """Insert a new intermediate definition for the given expression at the given location.""" definition.block.instructions.insert(definition.index, definition.instruction) for subexpression in _subexpression_dfs(definition.instruction): - self._usages[subexpression].append(definition) + self._usages[subexpression][definition] += 1 @staticmethod def _find_insertion_index(basic_block: BasicBlock, usages: Set[CfgInstruction]) -> int: @@ -315,7 +315,7 @@ def eliminate_common_subexpressions(self, definition_generator: DefinitionGenera except StopIteration: warning(f"[{self.name}] No dominating basic block could be found for {replacee}") - def _find_elimination_candidates(self, usages: DefaultDict[Expression, List[CfgInstruction]]) -> Iterator[Expression]: + def _find_elimination_candidates(self, usages: DefaultDict[Expression, Counter[CfgInstruction]]) -> Iterator[Expression]: """ Iterate all expressions, yielding the expressions which should be eliminated. @@ -324,11 +324,12 @@ def _find_elimination_candidates(self, usages: DefaultDict[Expression, List[CfgI expressions_by_complexity = sorted(usages.keys(), reverse=True, key=lambda x: x.complexity) for expression in expressions_by_complexity: if self._is_cse_candidate(expression, usages): + expression_usage = usages[expression] for subexpression in _subexpression_dfs(expression): - usages[subexpression] = [x for x in usages[subexpression] if x not in usages[expression]] + usages[subexpression].subtract(expression_usage) yield expression - def _is_cse_candidate(self, expression: Expression, usages: DefaultDict[Expression, List[CfgInstruction]]): + def _is_cse_candidate(self, expression: Expression, usages: DefaultDict[Expression, Counter[CfgInstruction]]): """Checks that we can add a common subexpression for the given expression.""" return ( self._is_elimination_candidate(expression, usages[expression]) @@ -346,15 +347,16 @@ def _is_complex_string(self, expression: Expression) -> bool: return isinstance(expression.value, str) and len(expression.value) >= self._min_string_length return False - def _check_inter_instruction(self, expression: Expression, instructions: List[CfgInstruction]) -> bool: + def _check_inter_instruction(self, expression: Expression, instructions: Counter[CfgInstruction]) -> bool: """Check if the given expressions should be eliminated based on its global occurrences.""" - referencing_instructions_count = len(set(instructions)) + referencing_instructions_count = sum(1 for _, count in instructions.items() if count > 0) return (expression.complexity >= 2 and referencing_instructions_count >= self._threshold) or ( self._is_complex_string(expression) and referencing_instructions_count >= self._string_threshold ) - def _check_intra_instruction(self, expression: Expression, instructions: List[CfgInstruction]) -> bool: + def _check_intra_instruction(self, expression: Expression, instructions: Counter[CfgInstruction]) -> bool: """Check if this expression should be eliminated based on the amount of unique instructions utilizing it.""" - return (expression.complexity >= 2 and len(instructions) >= self._threshold) or ( - self._is_complex_string(expression) and len(instructions) >= self._string_threshold + referencing_count = instructions.total() + return (expression.complexity >= 2 and referencing_count >= self._threshold) or ( + self._is_complex_string(expression) and referencing_count >= self._string_threshold )