Skip to content

Commit

Permalink
small calldatacopy optimization
Browse files Browse the repository at this point in the history
  • Loading branch information
HodanPlodky committed Nov 27, 2024
1 parent 1e43010 commit b87dfe6
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 17 deletions.
19 changes: 19 additions & 0 deletions tests/unit/compiler/venom/test_memmerging.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,3 +282,22 @@ def test_memzeroing_3():
assert bb.instructions[1].operands[0].value == 256 + 2 * 32
assert bb.instructions[1].operands[2].value == 64
assert len(bb.instructions) == 3


def test_memzeroing_small_calldatacopy():
ctx = IRContext()
fn = ctx.create_function("_global")

bb = fn.get_basic_block()

calldatasize = bb.append_instruction("calldatasize")
bb.append_instruction("calldatacopy", 32, calldatasize, 64)
bb.append_instruction("stop")

ac = IRAnalysesCache(fn)
MemMergePass(ac, fn).run_pass()
RemoveUnusedVariablesPass(ac, fn).run_pass()

assert bb.instructions[0].opcode == "mstore"
assert bb.instructions[0].operands[0].value == 0
assert bb.instructions[0].operands[1].value == 64
36 changes: 19 additions & 17 deletions vyper/venom/passes/memmerging.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,24 +188,26 @@ def _barrier():
# optimize memzeroing operations
def _optimize_memzero(self, bb: IRBasicBlock, intervals: list[_Interval]):
for interval in intervals:
if interval.length <= 32:
# TODO: if interval.length == 32, then we actually want to
# turn calldatacopy into (mstore 0)
inst = interval.insts[-1]
if interval.length == 32 and inst.opcode == "calldatacopy":
dst = inst.operands[2]
inst.opcode = "mstore"
inst.operands = [IRLiteral(0), dst]
elif interval.length <= 32:
continue
inst = interval.insts[0]

index = bb.instructions.index(inst)
calldatasize = bb.parent.get_next_variable()
bb.insert_instruction(IRInstruction("calldatasize", [], output=calldatasize), index)

inst.output = None
inst.opcode = "calldatacopy"
inst.operands = [
IRLiteral(interval.length),
calldatasize,
IRLiteral(interval.dst_start),
]
for inst in interval.insts[1:]:
else:
index = bb.instructions.index(inst)
calldatasize = bb.parent.get_next_variable()
bb.insert_instruction(IRInstruction("calldatasize", [], output=calldatasize), index)

inst.output = None
inst.opcode = "calldatacopy"
inst.operands = [
IRLiteral(interval.length),
calldatasize,
IRLiteral(interval.dst_start),
]
for inst in interval.insts[0:-1]:
bb.remove_instruction(inst)

intervals.clear()
Expand Down

0 comments on commit b87dfe6

Please sign in to comment.