diff --git a/.prettierrc b/.prettierrc new file mode 100644 index 0000000..420a7b3 --- /dev/null +++ b/.prettierrc @@ -0,0 +1,16 @@ +{ + "tabWidth": 2, + "useTabs": false, + "semi": false, + "singleQuote": true, + "trailingComma": "none", + "overrides": [ + { + "files": "*.sol", + "options": { + "printWidth": 120, + "singleQuote": false + } + } + ] +} diff --git a/test/StateReceiver.t.sol b/test/StateReceiver.t.sol index d8969f0..0cbbc57 100644 --- a/test/StateReceiver.t.sol +++ b/test/StateReceiver.t.sol @@ -3,6 +3,8 @@ pragma solidity 0.8.26; import "../lib/forge-std/src/Test.sol"; import "./helpers/IStateReceiver.sol"; +import {TestReenterer} from "test/helpers/TestReenterer.sol"; +import {TestRevertingReceiver} from "test/helpers/TestRevertingReceiver.sol"; contract StateReceiverTest is Test { address public constant SYSTEM_ADDRESS = 0xffffFFFfFFffffffffffffffFfFFFfffFFFfFFfE; @@ -11,12 +13,16 @@ contract StateReceiverTest is Test { IStateReceiver internal stateReceiver = IStateReceiver(0x0000000000000000000000000000000000001001); address internal rootSetter = makeAddr("rootSetter"); + TestReenterer internal reenterer = new TestReenterer(); + TestRevertingReceiver internal revertingReceiver = new TestRevertingReceiver(); + function setUp() public { address tmp = deployCode("out/StateReceiver.sol/StateReceiver.json", abi.encode(rootSetter)); vm.etch(address(stateReceiver), tmp.code); + vm.label(address(stateReceiver), "stateReceiver"); } - function test_deployment() public { + function test_deployment() public view { assertEq(stateReceiver.rootSetter(), rootSetter); } @@ -61,6 +67,7 @@ contract StateReceiverTest is Test { emit StateCommitted(stateId, false); assertFalse(stateReceiver.commitState(0, recordBytes)); assertEq(stateReceiver.lastStateId(), 1); + assertEq(stateReceiver.failedStateSyncs(stateId), abi.encode(receiver, stateData)); } function test_commitState_Success() public { @@ -83,6 +90,69 @@ contract StateReceiverTest is Test { assertEq(stateReceiver.lastStateId(), 1); } + function test_revertReplayFailedStateSync(uint256 stateId, bytes memory callData) public { + vm.assume(stateId > 0); + vm.store(address(stateReceiver), bytes32(0), bytes32(stateId - 1)); + assertTrue(revertingReceiver.shouldIRevert()); + bytes memory recordBytes = _encodeRecord(stateId, address(revertingReceiver), callData); + + vm.prank(SYSTEM_ADDRESS); + vm.expectEmit(); + emit StateCommitted(stateId, false); + assertFalse(stateReceiver.commitState(0, recordBytes)); + assertEq(stateReceiver.failedStateSyncs(stateId), abi.encode(address(revertingReceiver), callData)); + + assertTrue(revertingReceiver.shouldIRevert()); + + vm.expectRevert("TestRevertingReceiver"); + stateReceiver.replayFailedStateSync(stateId); + } + + function test_ReplayFailedStateSync(uint256 stateId, bytes memory callData) public { + vm.assume(stateId > 0); + vm.store(address(stateReceiver), bytes32(0), bytes32(stateId - 1)); + assertTrue(revertingReceiver.shouldIRevert()); + bytes memory recordBytes = _encodeRecord(stateId, address(revertingReceiver), callData); + + vm.prank(SYSTEM_ADDRESS); + vm.expectEmit(); + emit StateCommitted(stateId, false); + assertFalse(stateReceiver.commitState(0, recordBytes)); + assertEq(stateReceiver.failedStateSyncs(stateId), abi.encode(address(revertingReceiver), callData)); + + revertingReceiver.toggle(); + assertFalse(revertingReceiver.shouldIRevert()); + + vm.expectCall( + address(revertingReceiver), + 0, + abi.encodeWithSignature("onStateReceive(uint256,bytes)", stateId, callData) + ); + stateReceiver.replayFailedStateSync(stateId); + + vm.expectRevert("!found"); + stateReceiver.replayFailedStateSync(stateId); + } + + function test_ReplayFailFromReenterer(uint256 stateId, bytes memory callData) public { + vm.assume(stateId > 0); + vm.store(address(stateReceiver), bytes32(0), bytes32(stateId - 1)); + bytes memory recordBytes = _encodeRecord(stateId, address(reenterer), callData); + + vm.prank(SYSTEM_ADDRESS); + vm.expectEmit(); + emit StateCommitted(stateId, false); + assertFalse(stateReceiver.commitState(0, recordBytes)); + assertEq(stateReceiver.failedStateSyncs(stateId), abi.encode(address(reenterer), callData)); + + revertingReceiver.toggle(); + assertFalse(revertingReceiver.shouldIRevert()); + + vm.expectCall(address(reenterer), 0, abi.encodeWithSignature("onStateReceive(uint256,bytes)", stateId, callData)); + vm.expectRevert("!found"); + stateReceiver.replayFailedStateSync(stateId); + } + function _encodeRecord( uint256 stateId, address receiver, diff --git a/test/helpers/IStateReceiver.sol b/test/helpers/IStateReceiver.sol index 2b0ac10..dbdae1b 100644 --- a/test/helpers/IStateReceiver.sol +++ b/test/helpers/IStateReceiver.sol @@ -2,10 +2,27 @@ pragma solidity >0.5.0; event StateCommitted(uint256 indexed stateId, bool success); +event StateSyncReplay(uint256 indexed stateId); interface IStateReceiver { function SYSTEM_ADDRESS() external view returns (address); function commitState(uint256 syncTime, bytes memory recordBytes) external returns (bool success); function lastStateId() external view returns (uint256); function rootSetter() external view returns (address); + function failedStateSyncsRoot() external view returns (bytes32); + function nullifier(bytes32) external view returns (bool); + function failedStateSyncs(uint256) external view returns (bytes memory); + function leafCount() external view returns (uint256); + function replayCount() external view returns (uint256); + function TREE_DEPTH() external view returns (uint256); + + function replayFailedStateSync(uint256 stateId) external; + function setRootAndLeafCount(bytes32 _root, uint256 _leafCount) external; + function replayHistoricFailedStateSync( + bytes32[16] calldata proof, + uint256 leafIndex, + uint256 stateId, + address receiver, + bytes calldata data + ) external; } diff --git a/test/helpers/TestReenterer.sol b/test/helpers/TestReenterer.sol new file mode 100644 index 0000000..f19a351 --- /dev/null +++ b/test/helpers/TestReenterer.sol @@ -0,0 +1,17 @@ +pragma solidity 0.8.26; + +contract TestReenterer { + uint256 public reenterCount; + + function onStateReceive(uint256 id, bytes calldata _data) external { + if (reenterCount++ == 0) { + (bool success, bytes memory ret) = msg.sender.call(abi.encodeWithSignature("replayFailedStateSync(uint256)", id)); + // bubble up revert for tests + if (!success) { + assembly { + revert(add(ret, 0x20), mload(ret)) + } + } + } + } +} diff --git a/test/helpers/TestRevertingReceiver.sol b/test/helpers/TestRevertingReceiver.sol new file mode 100644 index 0000000..89fea93 --- /dev/null +++ b/test/helpers/TestRevertingReceiver.sol @@ -0,0 +1,12 @@ +pragma solidity 0.8.26; + +contract TestRevertingReceiver { + bool public shouldIRevert = true; + function onStateReceive(uint256 _id, bytes calldata _data) external { + if (shouldIRevert) revert("TestRevertingReceiver"); + } + + function toggle() external { + shouldIRevert = !shouldIRevert; + } +} diff --git a/test/helpers/merkle.js b/test/helpers/merkle.js new file mode 100644 index 0000000..5da1281 --- /dev/null +++ b/test/helpers/merkle.js @@ -0,0 +1,106 @@ +const AbiCoder = require('web3-eth-abi') +const { keccak256 } = require('web3-utils') + +const abi = AbiCoder + +class SparseMerkleTree { + constructor(height) { + if (height <= 1) { + throw new Error('invalid height, must be greater than 1') + } + this.height = height + this.zeroHashes = this.generateZeroHashes(height) + const tree = [] + for (let i = 0; i <= height; i++) { + tree.push([]) + } + this.tree = tree + this.leafCount = 0 + this.dirty = false + } + + add(leaf) { + this.dirty = true + this.leafCount++ + this.tree[0].push(leaf) + } + + calcBranches() { + for (let i = 0; i < this.height; i++) { + const parent = this.tree[i + 1] + const child = this.tree[i] + for (let j = 0; j < child.length; j += 2) { + const leftNode = child[j] + const rightNode = + j + 1 < child.length ? child[j + 1] : this.zeroHashes[i] + parent[j / 2] = keccak256( + abi.encodeParameters(['bytes32', 'bytes32'], [leftNode, rightNode]) + ) + } + } + this.dirty = false + } + + getProofTreeByIndex(index) { + if (this.dirty) this.calcBranches() + const proof = [] + let currentIndex = index + for (let i = 0; i < this.height; i++) { + currentIndex = + currentIndex % 2 === 1 ? currentIndex - 1 : currentIndex + 1 + if (currentIndex < this.tree[i].length) + proof.push(this.tree[i][currentIndex]) + else proof.push(this.zeroHashes[i]) + currentIndex = Math.floor(currentIndex / 2) + } + + return proof + } + + getProofTreeByValue(value) { + const index = this.tree[0].indexOf(value) + if (index === -1) throw new Error('value not found') + return this.getProofTreeByIndex(index) + } + + getRoot() { + if (this.tree[0][0] === undefined) { + // No leafs in the tree, calculate root with all leafs to 0 + return keccak256( + abi.encodeParameters( + ['bytes32', 'bytes32'], + [this.zeroHashes[this.height - 1], this.zeroHashes[this.height - 1]] + ) + ) + } + if (this.dirty) this.calcBranches() + + return this.tree[this.height][0] + } + + generateZeroHashes(height) { + // keccak256(abi.encode(uint256(0), address(0), new bytes(0))); + const zeroHashes = [ + keccak256( + abi.encodeParameters( + ['uint256', 'address', 'bytes'], + [0, '0x' + '0'.repeat(40), '0x'] + ) + ) + ] + for (let i = 1; i < height; i++) { + zeroHashes.push( + keccak256( + abi.encodeParameters( + ['bytes32', 'bytes32'], + [zeroHashes[i - 1], zeroHashes[i - 1]] + ) + ) + ) + } + + return zeroHashes + } +} + +module.exports = { SparseMerkleTree }