Skip to content

Commit

Permalink
Fix ordering of inserted common subexpressions (#403)
Browse files Browse the repository at this point in the history
  • Loading branch information
rihi authored Apr 4, 2024
1 parent 4d9cd14 commit 87fe3e3
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,10 @@ class CfgInstruction:

instruction: Instruction
block: BasicBlock
index: int

@property
def index(self):
return next(index for index, instruction in enumerate(self.block.instructions) if id(instruction) == id(self.instruction))


@dataclass()
Expand Down Expand Up @@ -221,7 +224,7 @@ def from_cfg(cls, cfg: ControlFlowGraph) -> DefinitionGenerator:
usages: DefaultDict[Expression, Counter[CfgInstruction]] = defaultdict(Counter)
for basic_block in cfg:
for index, instruction in enumerate(basic_block.instructions):
instruction_with_position = CfgInstruction(instruction, basic_block, index)
instruction_with_position = CfgInstruction(instruction, basic_block)
for subexpression in _subexpression_dfs(instruction):
usages[subexpression][instruction_with_position] += 1
return cls(usages, cfg.dominator_tree)
Expand All @@ -236,7 +239,7 @@ def define(self, expression: Expression, variable: Variable):
basic_block, index = self._find_location_for_insertion(expression)
for usage in self._usages[expression]:
usage.instruction.substitute(expression, variable.copy())
self._insert_definition(CfgInstruction(Assignment(variable, expression), basic_block, index))
self._insert_definition(Assignment(variable, expression), basic_block, index)

def _find_location_for_insertion(self, expression) -> Tuple[BasicBlock, int]:
"""
Expand Down Expand Up @@ -265,18 +268,19 @@ def _is_invalid_dominator(self, basic_block: BasicBlock, expression: Expression)
usages_in_the_same_block = [usage for usage in self._usages[expression] if usage.block == basic_block]
return any([isinstance(usage.instruction, Phi) for usage in usages_in_the_same_block])

def _insert_definition(self, definition: CfgInstruction):
def _insert_definition(self, instruction: Instruction, block: BasicBlock, index: int):
"""Insert a new intermediate definition for the given expression at the given location."""
definition.block.instructions.insert(definition.index, definition.instruction)
for subexpression in _subexpression_dfs(definition.instruction):
self._usages[subexpression][definition] += 1
block.instructions.insert(index, instruction)
cfg_instruction = CfgInstruction(instruction, block)
for subexpression in _subexpression_dfs(instruction):
self._usages[subexpression][cfg_instruction] += 1

@staticmethod
def _find_insertion_index(basic_block: BasicBlock, usages: Iterable[CfgInstruction]) -> int:
"""Find the first index in the given basic block where a definition could be inserted."""
usage = min((usage for usage in usages if usage.block == basic_block), default=None, key=lambda x: x.index)
if usage:
return basic_block.instructions.index(usage.instruction, usage.index)
first_usage_index = min((usage.index for usage in usages if usage.block == basic_block), default=None)
if first_usage_index is not None:
return first_usage_index
if not basic_block.instructions:
return 0
if isinstance(basic_block.instructions[-1], GenericBranch):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1283,3 +1283,32 @@ def test_common_subexpression_elimination_correct_place():
Assignment(Variable("d", ssa_label=4), BinaryOperation(OperationType.plus, [Variable("e", ssa_label=2), replacement0])),
Assignment(Variable("f", ssa_label=1), BinaryOperation(OperationType.minus, [Variable("g", ssa_label=4), replacement1])),
]


def test_common_subexpression_elimination_correct_place2():
"""Check that the instruction is inserted at the correct position"""
expr0 = BinaryOperation(OperationType.multiply, [Variable("a", Integer.int32_t()), Constant(2, Integer.int32_t())])
expr1 = BinaryOperation(OperationType.plus, [expr0.copy(), Constant(5, Integer.int32_t())])

cfg = ControlFlowGraph()
cfg.add_node(
node := BasicBlock(
0,
instructions=[
Assignment(ListOperation([]), Call(FunctionSymbol("func", 0), [expr0.copy(), expr1.copy()])),
Assignment(ListOperation([]), Call(FunctionSymbol("func", 0), [expr1.copy()])),
],
)
)
_run_cse(cfg, _generate_options(threshold=2))

replacement0 = Variable("c0", Integer.int32_t(), ssa_label=0)
replacement1 = Variable("c1", Integer.int32_t(), ssa_label=0)
expr_new = BinaryOperation(OperationType.plus, [replacement1.copy(), Constant(5, Integer.int32_t())])

assert node.instructions == [
Assignment(replacement1.copy(), expr0.copy()),
Assignment(replacement0.copy(), expr_new.copy()),
Assignment(ListOperation([]), Call(FunctionSymbol("func", 0), [replacement1.copy(), replacement0.copy()])),
Assignment(ListOperation([]), Call(FunctionSymbol("func", 0), [replacement0.copy()])),
]

0 comments on commit 87fe3e3

Please sign in to comment.