diff --git a/lib/forge-std b/lib/forge-std index 978ac6f..58d3051 160000 --- a/lib/forge-std +++ b/lib/forge-std @@ -1 +1 @@ -Subproject commit 978ac6fadb62f5f0b723c996f64be52eddba6801 +Subproject commit 58d30519826c313ce47345abedfdc07679e944d1 diff --git a/test/StateReceiver.t.sol b/test/StateReceiver.t.sol index 0cbbc57..f2fb639 100644 --- a/test/StateReceiver.t.sol +++ b/test/StateReceiver.t.sol @@ -123,6 +123,8 @@ contract StateReceiverTest is Test { revertingReceiver.toggle(); assertFalse(revertingReceiver.shouldIRevert()); + vm.expectEmit(); + emit StateSyncReplay(stateId); vm.expectCall( address(revertingReceiver), 0, @@ -153,6 +155,88 @@ contract StateReceiverTest is Test { stateReceiver.replayFailedStateSync(stateId); } + function test_rootSetter(address random) public { + vm.prank(random); + if (random != rootSetter) vm.expectRevert("!rootSetter"); + stateReceiver.setRootAndLeafCount(bytes32(uint(0x1337)), 0); + + vm.prank(rootSetter); + if (random == rootSetter) vm.expectRevert("!zero"); + stateReceiver.setRootAndLeafCount(bytes32(uint(0x1337)), 0); + } + + function test_shouldNotReplayZeroLeaf(bytes32 root, bytes32[16] memory proof) public { + vm.prank(rootSetter); + stateReceiver.setRootAndLeafCount(root, 1); + + vm.expectRevert(bytes("used")); + stateReceiver.replayHistoricFailedStateSync(proof, 0, 0, address(0), new bytes(0)); + } + + function test_shouldNotReplayInvalidProof(bytes32 root, bytes32[16] memory proof, bytes memory stateData) public { + vm.prank(rootSetter); + stateReceiver.setRootAndLeafCount(root, 1); + + vm.expectRevert("!proof"); + stateReceiver.replayHistoricFailedStateSync( + proof, + vm.randomUint(0, 2 ** 16), + vm.randomUint(), + address(uint160(vm.randomUint())), + stateData + ); + } + + function test_FailedStateSyncs(bytes[] memory stateDatas) external { + vm.assume(stateDatas.length > 1 && stateDatas.length < 10); + + address receiver = address(revertingReceiver); + + for (uint256 i = 0; i < stateDatas.length; ++i) { + bytes memory recordBytes = _encodeRecord(i + 1, receiver, stateDatas[i]); + + vm.prank(SYSTEM_ADDRESS); + vm.expectEmit(); + emit StateCommitted(i + 1, false); + assertFalse(stateReceiver.commitState(0, recordBytes)); + } + + uint256 leafCount = stateDatas.length; + bytes32 root; + bytes[] memory proofs = new bytes[](leafCount); + (root, proofs) = _getRootAndProofs(receiver, abi.encode(stateDatas)); + + vm.prank(rootSetter); + stateReceiver.setRootAndLeafCount(root, leafCount); + + revertingReceiver.toggle(); + + for (uint256 i = 0; i < stateDatas.length; ++i) { + vm.expectEmit(); + emit StateSyncReplay(i + 1); + stateReceiver.replayHistoricFailedStateSync( + abi.decode(proofs[i], (bytes32[16])), + i, + i + 1, + receiver, + stateDatas[i] + ); + } + } + + function _getRootAndProofs( + address receiver, + bytes memory stateDatasEncoded + ) internal returns (bytes32 root, bytes[] memory proofs) { + string[] memory inputs = new string[](4); + inputs[0] = "node"; + inputs[1] = "test/helpers/merkle.js"; + inputs[2] = vm.toString(receiver); + inputs[3] = vm.toString(stateDatasEncoded); + + (root, proofs) = abi.decode(vm.ffi(inputs), (bytes32, bytes[])); + } + function _encodeRecord( uint256 stateId, address receiver, diff --git a/test/helpers/merkle.js b/test/helpers/merkle.js index 5da1281..6107e35 100644 --- a/test/helpers/merkle.js +++ b/test/helpers/merkle.js @@ -103,4 +103,30 @@ class SparseMerkleTree { } } -module.exports = { SparseMerkleTree } +function getLeaf(stateID, receiverAddress, stateData) { + return keccak256( + abi.encodeParameters( + ['uint256', 'address', 'bytes'], + [stateID, receiverAddress, stateData] + ) + ) +} + +const [receiver, stateDatasEncoded] = process.argv.slice(2) + +const stateDatas = abi.decodeParameter('bytes[]', stateDatasEncoded) + +const tree = new SparseMerkleTree(16) + +for (let i = 0; i < stateDatas.length; i++) { + tree.add(getLeaf(i + 1, receiver, stateDatas[i])) +} +const root = tree.getRoot() +const proofs = stateDatas.map((_, i) => tree.getProofTreeByIndex(i)) + +console.log( + abi.encodeParameters( + ['bytes32', 'bytes[]'], + [root, proofs.map((proof) => abi.encodeParameter('bytes32[16]', proof))] + ) +)