Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix ordering of inserted common subexpressions #403

Merged
merged 8 commits into from
Apr 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,10 @@ class CfgInstruction:

instruction: Instruction
block: BasicBlock
index: int

@property
def index(self):
return next(index for index, instruction in enumerate(self.block.instructions) if id(instruction) == id(self.instruction))


@dataclass()
Expand Down Expand Up @@ -221,7 +224,7 @@ def from_cfg(cls, cfg: ControlFlowGraph) -> DefinitionGenerator:
usages: DefaultDict[Expression, Counter[CfgInstruction]] = defaultdict(Counter)
for basic_block in cfg:
for index, instruction in enumerate(basic_block.instructions):
instruction_with_position = CfgInstruction(instruction, basic_block, index)
instruction_with_position = CfgInstruction(instruction, basic_block)
for subexpression in _subexpression_dfs(instruction):
usages[subexpression][instruction_with_position] += 1
return cls(usages, cfg.dominator_tree)
Expand All @@ -236,7 +239,7 @@ def define(self, expression: Expression, variable: Variable):
basic_block, index = self._find_location_for_insertion(expression)
for usage in self._usages[expression]:
usage.instruction.substitute(expression, variable.copy())
self._insert_definition(CfgInstruction(Assignment(variable, expression), basic_block, index))
self._insert_definition(Assignment(variable, expression), basic_block, index)

def _find_location_for_insertion(self, expression) -> Tuple[BasicBlock, int]:
"""
Expand Down Expand Up @@ -265,18 +268,19 @@ def _is_invalid_dominator(self, basic_block: BasicBlock, expression: Expression)
usages_in_the_same_block = [usage for usage in self._usages[expression] if usage.block == basic_block]
return any([isinstance(usage.instruction, Phi) for usage in usages_in_the_same_block])

def _insert_definition(self, definition: CfgInstruction):
def _insert_definition(self, instruction: Instruction, block: BasicBlock, index: int):
"""Insert a new intermediate definition for the given expression at the given location."""
definition.block.instructions.insert(definition.index, definition.instruction)
for subexpression in _subexpression_dfs(definition.instruction):
self._usages[subexpression][definition] += 1
block.instructions.insert(index, instruction)
cfg_instruction = CfgInstruction(instruction, block)
for subexpression in _subexpression_dfs(instruction):
self._usages[subexpression][cfg_instruction] += 1

@staticmethod
def _find_insertion_index(basic_block: BasicBlock, usages: Iterable[CfgInstruction]) -> int:
"""Find the first index in the given basic block where a definition could be inserted."""
usage = min((usage for usage in usages if usage.block == basic_block), default=None, key=lambda x: x.index)
if usage:
return basic_block.instructions.index(usage.instruction, usage.index)
first_usage_index = min((usage.index for usage in usages if usage.block == basic_block), default=None)
if first_usage_index is not None:
return first_usage_index
if not basic_block.instructions:
return 0
if isinstance(basic_block.instructions[-1], GenericBranch):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1283,3 +1283,32 @@ def test_common_subexpression_elimination_correct_place():
Assignment(Variable("d", ssa_label=4), BinaryOperation(OperationType.plus, [Variable("e", ssa_label=2), replacement0])),
Assignment(Variable("f", ssa_label=1), BinaryOperation(OperationType.minus, [Variable("g", ssa_label=4), replacement1])),
]


def test_common_subexpression_elimination_correct_place2():
"""Check that the instruction is inserted at the correct position"""
expr0 = BinaryOperation(OperationType.multiply, [Variable("a", Integer.int32_t()), Constant(2, Integer.int32_t())])
expr1 = BinaryOperation(OperationType.plus, [expr0.copy(), Constant(5, Integer.int32_t())])

cfg = ControlFlowGraph()
cfg.add_node(
node := BasicBlock(
0,
instructions=[
Assignment(ListOperation([]), Call(FunctionSymbol("func", 0), [expr0.copy(), expr1.copy()])),
Assignment(ListOperation([]), Call(FunctionSymbol("func", 0), [expr1.copy()])),
],
)
)
_run_cse(cfg, _generate_options(threshold=2))

replacement0 = Variable("c0", Integer.int32_t(), ssa_label=0)
replacement1 = Variable("c1", Integer.int32_t(), ssa_label=0)
expr_new = BinaryOperation(OperationType.plus, [replacement1.copy(), Constant(5, Integer.int32_t())])

assert node.instructions == [
Assignment(replacement1.copy(), expr0.copy()),
Assignment(replacement0.copy(), expr_new.copy()),
Assignment(ListOperation([]), Call(FunctionSymbol("func", 0), [replacement1.copy(), replacement0.copy()])),
Assignment(ListOperation([]), Call(FunctionSymbol("func", 0), [replacement0.copy()])),
]
Loading