Skip to content

Commit

Permalink
Optimize common_subexpression_elimination
Browse files Browse the repository at this point in the history
  • Loading branch information
rihi committed Jan 24, 2024
1 parent 05b1bc3 commit a81a325
Showing 1 changed file with 18 additions and 16 deletions.
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Module implementing common subexpression elimination."""
from __future__ import annotations

from collections import defaultdict, deque
from collections import Counter, defaultdict, deque
from dataclasses import dataclass
from itertools import chain
from logging import info, warning
Expand All @@ -17,7 +17,7 @@
from networkx import dfs_postorder_nodes


@dataclass(frozen=True)
@dataclass(frozen=True, eq=False)
class CfgInstruction:
"""
dataclass in charge of tracking the location of Instruction objects in the cfg
Expand Down Expand Up @@ -199,7 +199,7 @@ class DefinitionGenerator:

def __init__(
self,
expression_usages: DefaultDict[Expression, List[CfgInstruction]],
expression_usages: DefaultDict[Expression, Counter[CfgInstruction]],
dominator_tree: NetworkXGraph,
):
"""Generate a new instance based on data parsed from a cfg."""
Expand All @@ -209,16 +209,16 @@ def __init__(
@classmethod
def from_cfg(cls, cfg: ControlFlowGraph) -> DefinitionGenerator:
"""Initialize a DefinitionGenerator utilizing the data of the given cfg."""
usages: DefaultDict[Expression, List[CfgInstruction]] = defaultdict(list)
usages: DefaultDict[Expression, Counter[CfgInstruction]] = defaultdict(Counter)
for basic_block in cfg:
for index, instruction in enumerate(basic_block.instructions):
instruction_with_position = CfgInstruction(instruction, basic_block, index)
for subexpression in _subexpression_dfs(instruction):
usages[subexpression].append(instruction_with_position)
usages[subexpression][instruction_with_position] += 1
return cls(usages, cfg.dominator_tree)

@property
def usages(self) -> DefaultDict[Expression, List[CfgInstruction]]:
def usages(self) -> DefaultDict[Expression, Counter[CfgInstruction]]:
"""Return a mapping from expressions to a set of instructions using them."""
return self._usages

Expand All @@ -240,7 +240,7 @@ def _find_location_for_insertion(self, expression) -> Tuple[BasicBlock, int]:
candidate: BasicBlock = next(iter(usage_blocks))
while not self._is_common_dominator(candidate, usage_blocks) or self._is_invalid_dominator(candidate, expression):
candidate = self._dominator_tree.get_predecessors(candidate)[0]
return candidate, self._find_insertion_index(candidate, set(self._usages[expression]))
return candidate, self._find_insertion_index(candidate, self._usages[expression].keys()) # not a set...

def _is_common_dominator(self, candidate: BasicBlock, basic_blocks: Set[BasicBlock]) -> bool:
"""Check if the given candidate is the common dominator all of given basic blocks."""
Expand All @@ -260,7 +260,7 @@ def _insert_definition(self, definition: CfgInstruction):
"""Insert a new intermediate definition for the given expression at the given location."""
definition.block.instructions.insert(definition.index, definition.instruction)
for subexpression in _subexpression_dfs(definition.instruction):
self._usages[subexpression].append(definition)
self._usages[subexpression][definition] += 1

@staticmethod
def _find_insertion_index(basic_block: BasicBlock, usages: Set[CfgInstruction]) -> int:
Expand Down Expand Up @@ -315,7 +315,7 @@ def eliminate_common_subexpressions(self, definition_generator: DefinitionGenera
except StopIteration:
warning(f"[{self.name}] No dominating basic block could be found for {replacee}")

def _find_elimination_candidates(self, usages: DefaultDict[Expression, List[CfgInstruction]]) -> Iterator[Expression]:
def _find_elimination_candidates(self, usages: DefaultDict[Expression, Counter[CfgInstruction]]) -> Iterator[Expression]:
"""
Iterate all expressions, yielding the expressions which should be eliminated.
Expand All @@ -324,11 +324,12 @@ def _find_elimination_candidates(self, usages: DefaultDict[Expression, List[CfgI
expressions_by_complexity = sorted(usages.keys(), reverse=True, key=lambda x: x.complexity)
for expression in expressions_by_complexity:
if self._is_cse_candidate(expression, usages):
expression_usage = usages[expression]
for subexpression in _subexpression_dfs(expression):
usages[subexpression] = [x for x in usages[subexpression] if x not in usages[expression]]
usages[subexpression].subtract(expression_usage)
yield expression

def _is_cse_candidate(self, expression: Expression, usages: DefaultDict[Expression, List[CfgInstruction]]):
def _is_cse_candidate(self, expression: Expression, usages: DefaultDict[Expression, Counter[CfgInstruction]]):
"""Checks that we can add a common subexpression for the given expression."""
return (
self._is_elimination_candidate(expression, usages[expression])
Expand All @@ -346,15 +347,16 @@ def _is_complex_string(self, expression: Expression) -> bool:
return isinstance(expression.value, str) and len(expression.value) >= self._min_string_length
return False

def _check_inter_instruction(self, expression: Expression, instructions: List[CfgInstruction]) -> bool:
def _check_inter_instruction(self, expression: Expression, instructions: Counter[CfgInstruction]) -> bool:
"""Check if the given expressions should be eliminated based on its global occurrences."""
referencing_instructions_count = len(set(instructions))
referencing_instructions_count = sum(1 for _, count in instructions.items() if count > 0)
return (expression.complexity >= 2 and referencing_instructions_count >= self._threshold) or (
self._is_complex_string(expression) and referencing_instructions_count >= self._string_threshold
)

def _check_intra_instruction(self, expression: Expression, instructions: List[CfgInstruction]) -> bool:
def _check_intra_instruction(self, expression: Expression, instructions: Counter[CfgInstruction]) -> bool:
"""Check if this expression should be eliminated based on the amount of unique instructions utilizing it."""
return (expression.complexity >= 2 and len(instructions) >= self._threshold) or (
self._is_complex_string(expression) and len(instructions) >= self._string_threshold
referencing_count = instructions.total()
return (expression.complexity >= 2 and referencing_count >= self._threshold) or (
self._is_complex_string(expression) and referencing_count >= self._string_threshold
)

0 comments on commit a81a325

Please sign in to comment.