Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix ordering of inserted common subexpressions #403

Merged
merged 8 commits into from
Apr 4, 2024
Merged
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from __future__ import annotations

import dataclasses
from collections import Counter, defaultdict, deque
from dataclasses import dataclass
from itertools import chain
Expand All @@ -18,26 +19,21 @@
from networkx import dfs_postorder_nodes


@dataclass(frozen=True)
class CfgInstructionLocation:
    """
    Location of an instruction inside the cfg, given as a (block, index) pair.

    The instruction itself is not stored; it is looked up in the block on every access.
    A location is therefore purely positional: once something is inserted into
    block.instructions at or before index, the same location yields a different instruction.

    Instances are frozen and compare (and hash) by their fields, so two locations with the
    same block and index are interchangeable as dictionary or Counter keys.
    """

    block: BasicBlock  # basic block containing the instruction
    index: int  # position of the instruction in block.instructions

    @property
    def instruction(self) -> Instruction:
        """Return the instruction currently stored at this location."""
        return self.block.instructions[self.index]


@dataclass()
class DefinedVariable:
Expand Down Expand Up @@ -208,7 +204,7 @@ class DefinitionGenerator:

def __init__(
self,
expression_usages: DefaultDict[Expression, Counter[CfgInstruction]],
expression_usages: DefaultDict[Expression, Counter[CfgInstructionLocation]],
dominator_tree: NetworkXGraph,
):
"""Generate a new instance based on data parsed from a cfg."""
Expand All @@ -218,16 +214,16 @@ def __init__(
@classmethod
def from_cfg(cls, cfg: ControlFlowGraph) -> DefinitionGenerator:
    """Build a DefinitionGenerator from the expression usages and the dominator tree of the given cfg."""
    usages: DefaultDict[Expression, Counter[CfgInstructionLocation]] = defaultdict(Counter)
    for basic_block in cfg:
        for position, instruction in enumerate(basic_block.instructions):
            location = CfgInstructionLocation(basic_block, position)
            # one count per subexpression yielded for the instruction at this location
            for subexpression in _subexpression_dfs(instruction):
                usages[subexpression][location] += 1
    return cls(usages, cfg.dominator_tree)

@property
def usages(self) -> DefaultDict[Expression, Counter[CfgInstructionLocation]]:
    """Return the mapping from each expression to a Counter of the locations using it."""
    return self._usages

Expand All @@ -236,7 +232,7 @@ def define(self, expression: Expression, variable: Variable):
basic_block, index = self._find_location_for_insertion(expression)
for usage in self._usages[expression]:
usage.instruction.substitute(expression, variable.copy())
self._insert_definition(CfgInstruction(Assignment(variable, expression), basic_block, index))
self._insert_definition(Assignment(variable, expression), basic_block, index)

def _find_location_for_insertion(self, expression) -> Tuple[BasicBlock, int]:
"""
Expand Down Expand Up @@ -265,18 +261,31 @@ def _is_invalid_dominator(self, basic_block: BasicBlock, expression: Expression)
usages_in_the_same_block = [usage for usage in self._usages[expression] if usage.block == basic_block]
return any([isinstance(usage.instruction, Phi) for usage in usages_in_the_same_block])

def _insert_definition(self, instruction: Instruction, block: BasicBlock, index: int):
    """
    Insert the given instruction into the given block at the given index.

    Besides the insertion itself, the usage bookkeeping is kept consistent: locations in the
    same block at or behind the insertion index are moved one position further, and the
    subexpressions of the inserted instruction are registered at the new location.
    """
    block.instructions.insert(index, instruction)

    # Every instruction of the block formerly at a position >= index moved by one,
    # so the recorded locations have to follow.
    for occurrences in self._usages.values():
        displaced = [location for location in occurrences if location.block == block and location.index >= index]
        # Remove all displaced keys before adding the shifted ones: re-keying them one at a
        # time would overwrite the count stored at index + 1 with the one moved up from index.
        counts = [occurrences.pop(location) for location in displaced]
        for location, count in zip(displaced, counts):
            occurrences[dataclasses.replace(location, index=location.index + 1)] = count

    # register the subexpressions of the freshly inserted instruction
    inserted_at = CfgInstructionLocation(block, index)
    for subexpression in _subexpression_dfs(instruction):
        self._usages[subexpression][inserted_at] += 1
rihi marked this conversation as resolved.
Show resolved Hide resolved

@staticmethod
def _find_insertion_index(basic_block: BasicBlock, usages: Iterable[CfgInstruction]) -> int:
def _find_insertion_index(basic_block: BasicBlock, usages: Iterable[CfgInstructionLocation]) -> int:
"""Find the first index in the given basic block where a definition could be inserted."""
usage = min((usage for usage in usages if usage.block == basic_block), default=None, key=lambda x: x.index)
if usage:
return basic_block.instructions.index(usage.instruction, usage.index)
return usage.index
if not basic_block.instructions:
return 0
if isinstance(basic_block.instructions[-1], GenericBranch):
Expand Down Expand Up @@ -324,7 +333,7 @@ def eliminate_common_subexpressions(self, definition_generator: DefinitionGenera
except StopIteration:
warning(f"[{self.name}] No dominating basic block could be found for {replacee}")

def _find_elimination_candidates(self, usages: DefaultDict[Expression, Counter[CfgInstruction]]) -> Iterator[Expression]:
def _find_elimination_candidates(self, usages: DefaultDict[Expression, Counter[CfgInstructionLocation]]) -> Iterator[Expression]:
"""
Iterate all expressions, yielding the expressions which should be eliminated.

Expand All @@ -338,7 +347,7 @@ def _find_elimination_candidates(self, usages: DefaultDict[Expression, Counter[C
usages[subexpression].subtract(expression_usage)
yield expression

def _is_cse_candidate(self, expression: Expression, usages: DefaultDict[Expression, Counter[CfgInstruction]]):
def _is_cse_candidate(self, expression: Expression, usages: DefaultDict[Expression, Counter[CfgInstructionLocation]]):
"""Checks that we can add a common subexpression for the given expression."""
return (
self._is_elimination_candidate(expression, usages[expression])
Expand All @@ -356,14 +365,14 @@ def _is_complex_string(self, expression: Expression) -> bool:
return isinstance(expression.value, str) and len(expression.value) >= self._min_string_length
return False

def _check_inter_instruction(self, expression: Expression, instructions: Counter[CfgInstructionLocation]) -> bool:
    """Decide whether the expression is referenced at enough distinct locations to be eliminated."""
    # only locations with a positive count are considered to reference the expression
    referencing_instructions_count = len([count for count in instructions.values() if count > 0])
    if expression.complexity >= 2 and referencing_instructions_count >= self._threshold:
        return True
    return self._is_complex_string(expression) and referencing_instructions_count >= self._string_threshold

def _check_intra_instruction(self, expression: Expression, instructions: Counter[CfgInstruction]) -> bool:
def _check_intra_instruction(self, expression: Expression, instructions: Counter[CfgInstructionLocation]) -> bool:
"""Check if this expression should be eliminated based on the amount of unique instructions utilizing it."""
referencing_count = instructions.total()
return (expression.complexity >= 2 and referencing_count >= self._threshold) or (
Expand Down
Loading