def _new_block():
    """Create a fresh context with a ``test`` function; return (ctx, fn, entry block)."""
    ctx = IRContext()
    fn = ctx.create_function("test")
    return ctx, fn, fn.get_basic_block()


def _run_cse(fn, *, expand_stores=False):
    """Run CSE on ``fn``, optionally preceded by ``StoreExpansionPass``."""
    ac = IRAnalysesCache(fn)
    if expand_stores:
        StoreExpansionPass(ac, fn).run_pass()
    CSE(ac, fn).run_pass()


def _count(bb, opcode):
    """Return how many instructions in ``bb`` carry the given opcode."""
    return len([inst for inst in bb.instructions if inst.opcode == opcode])


def test_common_subexpression_elimination():
    """Two identical add/mul chains must collapse into a single chain."""
    _ctx, fn, bb = _new_block()
    ten = bb.append_instruction("store", 10)
    for _ in range(2):
        total = bb.append_instruction("add", ten, 10)
        bb.append_instruction("mul", total, 10)
    bb.append_instruction("stop")

    _run_cse(fn)

    assert _count(bb, "add") == 1, "wrong number of adds"
    assert _count(bb, "mul") == 1, "wrong number of muls"


def test_common_subexpression_elimination_commutative():
    """``add 10, x`` and ``add x, 10`` must be recognised as the same expression."""
    _ctx, fn, bb = _new_block()
    ten = bb.append_instruction("store", 10)
    for lhs, rhs in ((10, ten), (ten, 10)):
        total = bb.append_instruction("add", lhs, rhs)
        bb.append_instruction("mul", total, 10)
    bb.append_instruction("stop")

    _run_cse(fn)

    assert _count(bb, "add") == 1, "wrong number of adds"
    assert _count(bb, "mul") == 1, "wrong number of muls"


def test_common_subexpression_elimination_effects_1():
    """Adds fed by mloads on either side of an mstore must both survive."""
    _ctx, fn, bb = _new_block()
    before = bb.append_instruction("mload", 0)
    value = bb.append_instruction("store", 10)
    bb.append_instruction("mstore", value, 0)
    after = bb.append_instruction("mload", 0)
    for loaded in (before, after):
        bb.append_instruction("add", loaded, 10)
    bb.append_instruction("stop")

    _run_cse(fn)

    assert _count(bb, "add") == 2, "wrong number of adds"


def test_common_subexpression_elimination_effects_2():
    """Of three adds (two on the first mload, one on the second) exactly two remain."""
    _ctx, fn, bb = _new_block()
    before = bb.append_instruction("mload", 0)
    bb.append_instruction("add", before, 10)
    value = bb.append_instruction("store", 10)
    bb.append_instruction("mstore", value, 0)
    after = bb.append_instruction("mload", 0)
    for loaded in (before, after):
        bb.append_instruction("add", loaded, 10)
    bb.append_instruction("stop")

    _run_cse(fn)

    assert _count(bb, "add") == 2, "wrong number of adds"


def test_common_subexpression_elimination_logs():
    """All four ``log`` instructions must be kept, even with repeated operands."""
    _ctx, fn, bb = _new_block()
    num2 = bb.append_instruction("store", 10)
    num1 = bb.append_instruction("store", 20)
    num3 = bb.append_instruction("store", 20)
    for logged in (num1, num2, num1, num3):
        bb.append_instruction("log", logged)
    bb.append_instruction("stop")

    _run_cse(fn)

    assert _count(bb, "log") == 4, "wrong number of log"


def test_common_subexpression_elimination_effects_3():
    """A repeated mstore separated by a different mstore must not be removed."""
    _ctx, fn, bb = _new_block()
    addr1 = bb.append_instruction("store", 10)
    addr2 = bb.append_instruction("store", 10)
    for first, second in ((0, addr1), (2, addr2), (0, addr1)):
        bb.append_instruction("mstore", first, second)
    bb.append_instruction("stop")

    _run_cse(fn)

    assert _count(bb, "mstore") == 3, "wrong number of mstores"


def test_common_subexpression_elimination_effect_mstore():
    """Two back-to-back identical mstore/mload pairs must collapse to one pair."""
    _ctx, fn, bb = _new_block()
    loads = []
    for _ in range(2):
        value = bb.append_instruction("store", 10)
        bb.append_instruction("mstore", value, 0)
        loads.append(bb.append_instruction("mload", 0))
    bb.append_instruction("add", loads[0], loads[1])
    bb.append_instruction("stop")

    _run_cse(fn)

    assert _count(bb, "mstore") == 1, "wrong number of mstores"
    assert _count(bb, "mload") == 1, "wrong number of mloads"


def test_common_subexpression_elimination_effect_mstore_with_msize():
    """With an ``msize`` read in the function both mstore/mload pairs must be kept."""
    _ctx, fn, bb = _new_block()
    loads = []
    for _ in range(2):
        value = bb.append_instruction("store", 10)
        bb.append_instruction("mstore", value, 0)
        loads.append(bb.append_instruction("mload", 0))
    msize_read = bb.append_instruction("msize")
    for loaded in loads:
        bb.append_instruction("add", loaded, msize_read)
    bb.append_instruction("stop")

    _run_cse(fn, expand_stores=True)

    assert _count(bb, "mstore") == 2, "wrong number of mstores"
    assert _count(bb, "mload") == 2, "wrong number of mloads"


def test_common_subexpression_elimination_different_branches():
    """The same add/mul sequence in two branches and the join keeps one add per block."""
    _ctx, fn, entry = _new_block()
    addr = entry.append_instruction("store", 10)
    rand_cond = entry.append_instruction("mload", addr)

    # create and register each block in turn, exactly as a hand-written sequence would
    created = []
    for name in ("br1", "br2", "join_bb"):
        block = IRBasicBlock(IRLabel(name), fn)
        fn.append_basic_block(block)
        created.append(block)
    br1, br2, join_bb = created

    entry.append_instruction("jnz", rand_cond, br1.label, br2.label)

    def emit_shared_ops(target: IRBasicBlock, rand: int):
        # identical store/store/add prefix, followed by a mul with a per-block constant
        a = target.append_instruction("store", 10)
        b = target.append_instruction("store", 20)
        c = target.append_instruction("add", a, b)
        target.append_instruction("mul", c, rand)

    emit_shared_ops(br1, 1)
    br1.append_instruction("jmp", join_bb.label)
    emit_shared_ops(br2, 2)
    br2.append_instruction("jmp", join_bb.label)
    emit_shared_ops(join_bb, 3)
    join_bb.append_instruction("stop")

    _run_cse(fn, expand_stores=True)

    assert _count(br1, "add") == 1, "wrong number of adds"
    assert _count(br2, "add") == 1, "wrong number of adds"
    assert _count(join_bb, "add") == 1, "wrong number of adds"
@dataclass
class _Expression:
    """Expression tree rooted at one instruction.

    ``inst`` is the instruction the expression was built from, ``opcode`` its
    opcode, ``operands`` the operands (either plain IR operands or nested
    expressions) and ``ignore_msize`` tells the effect accessors to drop the
    MSIZE effect from their results.

    ``==`` and ``hash`` only look at ``inst``; structural comparison is done
    with :meth:`same`.
    """

    inst: "IRInstruction"
    opcode: str
    # the child is either an expression or an operand since
    # there are possibilities for cycles
    operands: list["IROperand | _Expression"]
    ignore_msize: bool

    # equality for lattices is only based on the original instruction
    def __eq__(self, other) -> bool:
        if not isinstance(other, _Expression):
            return False

        return self.inst == other.inst

    def __hash__(self) -> int:
        return hash(self.inst)

    # Full equality for expressions based on opcode and operands
    def same(self, other, eq_vars: "VarEquivalenceAnalysis") -> bool:
        """Structural comparison; delegates to the module-level ``same``."""
        return same(self, other, eq_vars)

    def __repr__(self) -> str:
        if self.opcode == "store":
            assert len(self.operands) == 1, "wrong store"
            return repr(self.operands[0])
        inner = "".join(f"{op!r} " for op in self.operands)
        return f"{self.opcode} [ {inner}]"

    def _without_msize(self, effects):
        """Drop the MSIZE effect from ``effects`` when ``ignore_msize`` is set."""
        if self.ignore_msize:
            effects &= ~Effects.MSIZE
        return effects

    @cached_property
    def get_depth(self) -> int:
        """Height of the expression tree; an expression without nested expressions has depth 1."""
        nested = (op.get_depth for op in self.operands if isinstance(op, _Expression))
        return max(nested, default=0) + 1

    @cached_property
    def get_reads_deep(self) -> "Effects":
        """Read effects of this instruction merged with those of every nested expression.

        The merge recurses through the ``_deep`` accessor of the children, so
        effects at any depth of the tree are included (not only the direct
        children).
        """
        effects = self.inst.get_read_effects()
        for op in self.operands:
            if isinstance(op, _Expression):
                effects = effects | op.get_reads_deep
        return self._without_msize(effects)

    @cached_property
    def get_reads(self) -> "Effects":
        """Read effects of this instruction alone (nested expressions are not consulted)."""
        return self._without_msize(self.inst.get_read_effects())

    @cached_property
    def get_writes_deep(self) -> "Effects":
        """Write effects of this instruction merged with those of every nested expression.

        Recurses through the ``_deep`` accessor of the children, like
        ``get_reads_deep``.
        """
        effects = self.inst.get_write_effects()
        for op in self.operands:
            if isinstance(op, _Expression):
                effects = effects | op.get_writes_deep
        return self._without_msize(effects)

    @cached_property
    def get_writes(self) -> "Effects":
        """Write effects of this instruction alone (nested expressions are not consulted)."""
        return self._without_msize(self.inst.get_write_effects())

    @property
    def is_commutative(self) -> bool:
        return self.inst.is_commutative


def same(
    a: "IROperand | _Expression",
    b: "IROperand | _Expression",
    eq_vars: "VarEquivalenceAnalysis",
) -> bool:
    """Return True when ``a`` and ``b`` denote the same operand or expression.

    Two plain operands are compared by ``value``. Two expressions match when
    they wrap the same instruction, or when they have the same opcode, the same
    number of operands and pairwise matching operands (identical object,
    equivalent according to ``eq_vars``, or equal). For commutative
    two-operand expressions the swapped operand order is accepted as well.
    A plain operand never matches an expression.
    """
    if isinstance(a, IROperand) and isinstance(b, IROperand):
        return a.value == b.value
    if not isinstance(a, _Expression) or not isinstance(b, _Expression):
        return False

    if a.inst == b.inst:
        return True

    if a.opcode != b.opcode:
        return False

    # zip() below would silently ignore surplus operands, so instructions with
    # the same opcode but a different operand count must be rejected up front
    if len(a.operands) != len(b.operands):
        return False

    # Early return special case for commutative instructions
    if a.is_commutative and len(a.operands) == 2:
        if same(a.operands[0], b.operands[1], eq_vars) and same(
            a.operands[1], b.operands[0], eq_vars
        ):
            return True

    # General case
    for self_op, other_op in zip(a.operands, b.operands):
        if (
            self_op is not other_op
            and not eq_vars.equivalent(self_op, other_op)
            and self_op != other_op
        ):
            return False

    return True
class CSEAnalysis(IRAnalysis):
    """Available-expression analysis.

    For every non-terminator instruction it records the set of expressions
    available right before the instruction (``inst_to_available``) and the
    expression the instruction itself maps to (``inst_to_expr``).
    ``bb_outs`` holds the set of expressions available at the end of each
    processed basic block.
    """

    inst_to_expr: dict[IRInstruction, _Expression]
    dfg: DFGAnalysis
    inst_to_available: dict[IRInstruction, OrderedSet[_Expression]]
    bb_outs: dict[IRBasicBlock, OrderedSet[_Expression]]
    eq_vars: VarEquivalenceAnalysis

    ignore_msize: bool

    def __init__(self, analyses_cache: IRAnalysesCache, function: IRFunction):
        super().__init__(analyses_cache, function)
        self.analyses_cache.request_analysis(CFGAnalysis)
        dfg = self.analyses_cache.request_analysis(DFGAnalysis)
        assert isinstance(dfg, DFGAnalysis)
        self.dfg = dfg
        self.eq_vars = self.analyses_cache.request_analysis(VarEquivalenceAnalysis)  # type: ignore

        self.inst_to_expr = dict()
        self.inst_to_available = dict()
        self.bb_outs = dict()

        self.ignore_msize = not self._contains_msize()

    def analyze(self):
        """Process basic blocks from the entry until no block output changes."""
        worklist: OrderedSet = OrderedSet()
        worklist.add(self.function.entry)
        while len(worklist) > 0:
            bb: IRBasicBlock = worklist.pop()
            # successors only need to be revisited when this block's
            # outgoing set changed
            if self._handle_bb(bb):
                for out in bb.cfg_out:
                    worklist.add(out)

    # the msize effect only needs to be handled when there is
    # a possibility of an msize read; otherwise it should make
    # no difference for this analysis
    def _contains_msize(self) -> bool:
        """Return True when any instruction of the function has opcode ``msize``."""
        return any(
            inst.opcode == "msize"
            for bb in self.function.get_basic_blocks()
            for inst in bb.instructions
        )

    def _handle_bb(self, bb: IRBasicBlock) -> bool:
        """Recompute the available expressions through ``bb``.

        Returns True when the set available at the end of ``bb`` differs from
        the one stored in ``bb_outs`` (or none was stored yet).
        """
        available_expr: OrderedSet[_Expression] = OrderedSet()
        if len(bb.cfg_in) > 0:
            # predecessors without a recorded output contribute an empty set
            available_expr = OrderedSet.intersection(
                *(self.bb_outs.get(in_bb, OrderedSet()) for in_bb in bb.cfg_in)
            )

        for inst in bb.instructions:
            if inst.opcode in BB_TERMINATORS:
                continue

            # REVIEW: why replace inst_to_available if they are not equal?
            if inst not in self.inst_to_available or available_expr != self.inst_to_available[inst]:
                self.inst_to_available[inst] = available_expr.copy()
            inst_expr = self.get_expression(inst, available_expr)
            write_effects = inst_expr.get_writes

            # drop every expression whose reads or writes overlap with
            # what this instruction writes
            for expr in available_expr.copy():
                if (expr.get_reads & write_effects) != EMPTY or (
                    expr.get_writes & write_effects
                ) != EMPTY:
                    available_expr.remove(expr)

            # an expression that writes something it also reads is never made available
            if (inst_expr.get_writes_deep & inst_expr.get_reads_deep) == EMPTY:
                available_expr.add(inst_expr)

        if bb not in self.bb_outs or available_expr != self.bb_outs[bb]:
            self.bb_outs[bb] = available_expr.copy()
            # a change is only reported when the output of the
            # basic block changed (otherwise it won't affect the rest)
            return True

        return False

    def _get_operand(
        self, op: IROperand, available_exprs: OrderedSet[_Expression]
    ) -> IROperand | _Expression:
        """Resolve ``op`` to the expression producing it, or return ``op`` itself."""
        if isinstance(op, IRVariable):
            inst = self.dfg.get_producing_instruction(op)
            assert inst is not None
            # this can both create better solutions and is necessary
            # for correct effect handling, otherwise you could go over
            # effect boundaries
            # the phi condition is here because it is the only way to
            # create a call loop
            if inst.is_volatile or inst.opcode == "phi":
                return op
            if inst.opcode == "store":
                # look through stores to the stored operand
                return self._get_operand(inst.operands[0], available_exprs)
            if inst in self.inst_to_expr:
                return self.inst_to_expr[inst]
            return self.get_expression(inst, available_exprs)
        return op

    def _get_operands(
        self, inst: IRInstruction, available_exprs: OrderedSet[_Expression]
    ) -> list[IROperand | _Expression]:
        return [self._get_operand(op, available_exprs) for op in inst.operands]

    def get_expression(
        self, inst: IRInstruction, available_exprs: OrderedSet[_Expression] | None = None
    ) -> _Expression:
        """Return the expression for ``inst`` and remember it in ``inst_to_expr``.

        ``available_exprs`` defaults to the set recorded for ``inst`` (empty
        when nothing was recorded). If the expression previously stored for
        ``inst`` is in ``available_exprs`` it is returned; otherwise the first
        available expression that is ``same`` as the one built from ``inst``
        is stored and returned; otherwise the newly built expression is.
        """
        # an explicit ``is None`` check: an empty set passed by the caller is
        # falsy and must not be mistaken for "no set given"
        if available_exprs is None:
            available_exprs = self.inst_to_available.get(inst, OrderedSet())
        operands: list[IROperand | _Expression] = self._get_operands(inst, available_exprs)
        expr = _Expression(inst, inst.opcode, operands, self.ignore_msize)

        if inst in self.inst_to_expr and self.inst_to_expr[inst] in available_exprs:
            return self.inst_to_expr[inst]

        # REVIEW: performance issue - loop over available_exprs.
        for e in available_exprs:
            if expr.same(e, self.eq_vars):
                self.inst_to_expr[inst] = e
                return e

        self.inst_to_expr[inst] = expr
        return expr

    def get_available(self, inst: IRInstruction) -> OrderedSet[_Expression]:
        """Return the set recorded for ``inst``, or an empty set when none was recorded."""
        return self.inst_to_available.get(inst, OrderedSet())
# instructions that are not useful to
# substitute
UNINTERESTING_OPCODES = frozenset(
    [
        "store",
        "param",
        "offset",
        "phi",
        "nop",
        "calldatasize",
        "returndatasize",
        "gas",
        "gaslimit",
        "gasprice",
        "address",
        "origin",
        "codesize",
        "caller",
        "callvalue",
        "coinbase",
        "timestamp",
        "number",
        "prevrandao",
        "chainid",
        "basefee",
        "blobbasefee",
        "pc",
        "msize",
    ]
)
# instructions that cannot be substituted (without further analysis)
NONIDEMPOTENT_INSTRUCTIONS = frozenset(["log", "call", "staticcall", "delegatecall", "invoke"])


class CSE(IRPass):
    """Common subexpression elimination driven by ``CSEAnalysis``."""

    expression_analysis: CSEAnalysis

    def run_pass(self):
        """Replace instructions until the analysis finds nothing more to replace."""
        available_expression_analysis = self.analyses_cache.request_analysis(CSEAnalysis)
        assert isinstance(available_expression_analysis, CSEAnalysis)
        self.expression_analysis = available_expression_analysis

        while True:
            replace_dict = self._find_replaceble()
            if len(replace_dict) == 0:
                return
            self._replace(replace_dict)
            self.analyses_cache.invalidate_analysis(DFGAnalysis)
            self.analyses_cache.invalidate_analysis(LivenessAnalysis)
            # the instructions changed, so the expression analysis is recomputed
            self.expression_analysis = self.analyses_cache.force_analysis(
                CSEAnalysis
            )  # type: ignore

    # returns each replaceable instruction together with
    # the instruction it could be replaced by
    def _find_replaceble(self) -> dict[IRInstruction, IRInstruction]:
        res: dict[IRInstruction, IRInstruction] = dict()

        for bb in self.function.get_basic_blocks():
            for inst in bb.instructions:
                # skip instructions that for sure
                # won't be substituted
                if (
                    inst.opcode in UNINTERESTING_OPCODES
                    or inst.opcode in NONIDEMPOTENT_INSTRUCTIONS
                ):
                    continue
                inst_expr = self.expression_analysis.get_expression(inst)
                avail = self.expression_analysis.get_available(inst)
                # heuristic: do not replace small expressions across
                # basic block boundaries (it can create better codesize)
                if inst_expr in avail and (
                    inst_expr.get_depth > 1 or inst.parent == inst_expr.inst.parent
                ):
                    res[inst] = inst_expr.inst

        return res

    def _replace(self, replace_dict: dict[IRInstruction, IRInstruction]):
        """Apply every replacement in ``replace_dict``."""
        for orig, to in replace_dict.items():
            # follow the chain: the target may itself be scheduled for replacement
            while to in replace_dict:
                to = replace_dict[to]
            self._replace_inst(orig, to)

    def _replace_inst(self, orig_inst: IRInstruction, to_inst: IRInstruction):
        """Redirect uses of ``orig_inst``'s output to ``to_inst``'s output and remove ``orig_inst``."""
        visited: OrderedSet[IRBasicBlock] = OrderedSet()
        if orig_inst.output is not None:
            assert isinstance(orig_inst.output, IRVariable), f"not var {orig_inst}"
            assert isinstance(to_inst.output, IRVariable), f"not var {to_inst}"
            self._replace_inst_r(orig_inst.parent, orig_inst.output, to_inst.output, visited)
        orig_inst.parent.remove_instruction(orig_inst)

    def _replace_inst_r(
        self, bb: IRBasicBlock, orig: IRVariable, to: IRVariable, visited: OrderedSet[IRBasicBlock]
    ):
        """Replace operand ``orig`` by ``to`` in ``bb`` and every block reachable from it.

        Blocks already in ``visited`` are skipped; every processed block is
        added to ``visited``. The traversal uses an explicit worklist instead
        of recursion so that long chains of basic blocks cannot exhaust the
        Python call stack.
        """
        pending = [bb]
        while pending:
            current = pending.pop()
            if current in visited:
                continue
            visited.add(current)

            for inst in current.instructions:
                for i, op in enumerate(inst.operands):
                    if op == orig:
                        inst.operands[i] = to

            pending.extend(current.cfg_out)