From fa18db78f1994342ff0cc9057972e5818dbfb808 Mon Sep 17 00:00:00 2001 From: 0xIryna <43921510+0xIryna@users.noreply.github.com> Date: Wed, 20 Nov 2024 18:39:06 -0800 Subject: [PATCH] fix: registrar key/list update race condition (#22) --- src/HubPortal.sol | 21 +++-- src/SpokePortal.sol | 30 +++++-- src/interfaces/ISpokePortal.sol | 18 +++- src/libs/PayloadEncoder.sol | 16 +++- test/unit/HubPortal.t.sol | 6 +- test/unit/SpokePortal.t.sol | 132 ++++++++++++++++++++++++++-- test/unit/libs/PayloadEncoder.t.sol | 27 ++++-- 7 files changed, 216 insertions(+), 34 deletions(-) diff --git a/src/HubPortal.sol b/src/HubPortal.sol index b60f389..b871541 100644 --- a/src/HubPortal.sol +++ b/src/HubPortal.sol @@ -57,7 +57,7 @@ contract HubPortal is IHubPortal, Portal { ) external payable returns (bytes32 messageId_) { uint128 index_ = _currentIndex(); bytes memory payload_ = PayloadEncoder.encodeIndex(index_, destinationChainId_); - messageId_ = _sendMessage(destinationChainId_, refundAddress_, payload_); + messageId_ = _sendMessage(destinationChainId_, refundAddress_, _useMessageSequence(), payload_); emit MTokenIndexSent(destinationChainId_, messageId_, index_); } @@ -69,8 +69,9 @@ contract HubPortal is IHubPortal, Portal { bytes32 refundAddress_ ) external payable returns (bytes32 messageId_) { bytes32 value_ = IRegistrarLike(registrar).get(key_); - bytes memory payload_ = PayloadEncoder.encodeKey(key_, value_, destinationChainId_); - messageId_ = _sendMessage(destinationChainId_, refundAddress_, payload_); + uint64 sequence_ = _useMessageSequence(); + bytes memory payload_ = PayloadEncoder.encodeKey(key_, value_, sequence_, destinationChainId_); + messageId_ = _sendMessage(destinationChainId_, refundAddress_, sequence_, payload_); emit RegistrarKeySent(destinationChainId_, messageId_, key_, value_); } @@ -83,8 +84,15 @@ contract HubPortal is IHubPortal, Portal { bytes32 refundAddress_ ) external payable returns (bytes32 messageId_) { bool status_ = IRegistrarLike(registrar).listContains(listName_, account_); - bytes memory payload_ = PayloadEncoder.encodeListUpdate(listName_, account_, status_, destinationChainId_); - messageId_ = _sendMessage(destinationChainId_, refundAddress_, payload_); + uint64 sequence_ = _useMessageSequence(); + bytes memory payload_ = PayloadEncoder.encodeListUpdate( + listName_, + account_, + status_, + sequence_, + destinationChainId_ + ); + messageId_ = _sendMessage(destinationChainId_, refundAddress_, sequence_, payload_); emit RegistrarListStatusSent(destinationChainId_, messageId_, listName_, account_, status_); } @@ -137,6 +145,7 @@ contract HubPortal is IHubPortal, Portal { function _sendMessage( uint16 destinationChainId_, bytes32 refundAddress_, + uint64 _sequence, bytes memory payload_ ) private returns (bytes32 messageId_) { if (refundAddress_ == bytes32(0)) revert InvalidRefundAddress(); @@ -149,7 +158,7 @@ contract HubPortal is IHubPortal, Portal { ) = _prepareForTransfer(destinationChainId_, DEFAULT_TRANSCEIVER_INSTRUCTIONS); TransceiverStructs.NttManagerMessage memory message_ = TransceiverStructs.NttManagerMessage( - bytes32(uint256(_useMessageSequence())), + bytes32(uint256(_sequence)), msg.sender.toBytes32(), payload_ ); diff --git a/src/SpokePortal.sol b/src/SpokePortal.sol index c2187ba..c904a6c 100644 --- a/src/SpokePortal.sol +++ b/src/SpokePortal.sol @@ -24,6 +24,9 @@ contract SpokePortal is ISpokePortal, Portal { /// @inheritdoc ISpokePortal uint112 public outstandingPrincipal; + /// @inheritdoc ISpokePortal + uint64 public lastProcessedSequence; + /** * @notice Constructs the contract. * @param mToken_ The address of the M token to bridge. @@ -48,7 +51,7 @@ contract SpokePortal is ISpokePortal, Portal { } } - /* ============ Internal Interactive Functions ============ */ + /* ============ Internal/Private Interactive Functions ============ */ function _receiveCustomPayload( bytes32 messageId_, @@ -79,22 +82,31 @@ contract SpokePortal is ISpokePortal, Portal { /// @notice Sets a Registrar key received from the Hub chain. function _setRegistrarKey(bytes32 messageId_, bytes memory payload_) private { - (bytes32 key_, bytes32 value_, uint16 destinationChainId_) = payload_.decodeKey(); + (bytes32 key_, bytes32 value_, uint64 sequence_, uint16 destinationChainId_) = payload_.decodeKey(); _verifyDestinationChain(destinationChainId_); - emit RegistrarKeyReceived(messageId_, key_, value_); + emit RegistrarKeyReceived(messageId_, key_, value_, sequence_); + + _verifyMessageSequence(sequence_); + + lastProcessedSequence = sequence_; IRegistrarLike(registrar).setKey(key_, value_); } /// @notice Adds or removes an account from the Registrar List based on the message from the Hub chain. function _updateRegistrarList(bytes32 messageId_, bytes memory payload_) private { - (bytes32 listName_, address account_, bool add_, uint16 destinationChainId_) = payload_.decodeListUpdate(); + (bytes32 listName_, address account_, bool add_, uint64 sequence_, uint16 destinationChainId_) = payload_ + .decodeListUpdate(); _verifyDestinationChain(destinationChainId_); - emit RegistrarListStatusReceived(messageId_, listName_, account_, add_); + emit RegistrarListStatusReceived(messageId_, listName_, account_, add_, sequence_); + + _verifyMessageSequence(sequence_); + + lastProcessedSequence = sequence_; if (add_) { IRegistrarLike(registrar).addToList(listName_, account_); @@ -133,6 +145,14 @@ contract SpokePortal is ISpokePortal, Portal { } } + /// @dev Checks if the incoming message sequence is greater than the last processed one to prevent + /// Registrar data from being overwritten due to message reordering. + function _verifyMessageSequence(uint64 sequence_) private view { + uint64 lastProcessedSequence_ = lastProcessedSequence; + if (lastProcessedSequence_ != 0 && sequence_ < lastProcessedSequence_) + revert ObsoleteMessageSequence(sequence_, lastProcessedSequence_); + } + /// @dev Returns the current M token index used by the Spoke Portal. function _currentIndex() internal view override returns (uint128) { return ISpokeMTokenLike(mToken()).currentIndex(); diff --git a/src/interfaces/ISpokePortal.sol b/src/interfaces/ISpokePortal.sol index 577a250..706fd94 100644 --- a/src/interfaces/ISpokePortal.sol +++ b/src/interfaces/ISpokePortal.sol @@ -23,8 +23,9 @@ interface ISpokePortal is IPortal { * @param messageId The unique identifier of the received message. * @param key The Registrar key of some value. * @param value The value. + * @param sequence The sequence of the message on the Hub. */ - event RegistrarKeyReceived(bytes32 indexed messageId, bytes32 indexed key, bytes32 value); + event RegistrarKeyReceived(bytes32 indexed messageId, bytes32 indexed key, bytes32 value, uint64 sequence); /** * @notice Emitted when the Registrar list status is received from Mainnet. @@ -32,21 +33,32 @@ interface ISpokePortal is IPortal { * @param listName The name of the list. * @param account The account. * @param status Indicates if the account is added or removed from the list. + * @param sequence The sequence of the message on the Hub. */ event RegistrarListStatusReceived( bytes32 indexed messageId, bytes32 indexed listName, address indexed account, - bool status + bool status, + uint64 sequence ); + /* ============ Custom Errors ============ */ + + /// @notice Emitted when processing Registrar Key and Registrar List Update messages + /// if an incoming message sequence is less than the last processed message sequence. + error ObsoleteMessageSequence(uint64 sequence, uint64 lastProcessedSequence); + /* ============ View/Pure Functions ============ */ /// @notice The maximum possible principal of the total bridged-in M tokens, - /// it will be used for calculations of excess of M yield in the Hub Portal. + /// it will be used for calculations of excess of M yield in the Hub Portal. function outstandingPrincipal() external view returns (uint112); /// @notice The excess of M yield in the Hub Portal contributed by the Spoke Portal, /// total Hub Portal M yield excess equals to sum of all Spoke Portal M excesses. function excess() external view returns (uint240); + + /// @notice The message sequence of the last Set Registrar Key or Update List Status message received from the Hub. + function lastProcessedSequence() external view returns (uint64); } diff --git a/src/libs/PayloadEncoder.sol b/src/libs/PayloadEncoder.sol index ff5be88..5483d1e 100644 --- a/src/libs/PayloadEncoder.sol +++ b/src/libs/PayloadEncoder.sol @@ -77,18 +77,20 @@ library PayloadEncoder { function encodeKey( bytes32 key_, bytes32 value_, + uint64 sequence_, uint16 destinationChainId_ ) internal pure returns (bytes memory encoded_) { - return abi.encodePacked(KEY_TRANSFER_PREFIX, key_, value_, destinationChainId_); + return abi.encodePacked(KEY_TRANSFER_PREFIX, key_, value_, sequence_, destinationChainId_); } function decodeKey( bytes memory payload_ - ) internal pure returns (bytes32 key_, bytes32 value_, uint16 destinationChainId_) { + ) internal pure returns (bytes32 key_, bytes32 value_, uint64 sequence_, uint16 destinationChainId_) { uint256 offset_ = PAYLOAD_PREFIX_LENGTH; (key_, offset_) = payload_.asBytes32Unchecked(offset_); (value_, offset_) = payload_.asBytes32Unchecked(offset_); + (sequence_, offset_) = payload_.asUint64Unchecked(offset_); (destinationChainId_, offset_) = payload_.asUint16Unchecked(offset_); payload_.checkLength(offset_); @@ -98,19 +100,25 @@ library PayloadEncoder { bytes32 listName_, address account_, bool add_, + uint64 sequence_, uint16 destinationChainId_ ) internal pure returns (bytes memory encoded_) { - return abi.encodePacked(LIST_UPDATE_PREFIX, listName_, account_, add_, destinationChainId_); + return abi.encodePacked(LIST_UPDATE_PREFIX, listName_, account_, add_, sequence_, destinationChainId_); } function decodeListUpdate( bytes memory payload_ - ) internal pure returns (bytes32 listName_, address account_, bool add_, uint16 destinationChainId_) { + ) + internal + pure + returns (bytes32 listName_, address account_, bool add_, uint64 sequence_, uint16 destinationChainId_) + { uint256 offset_ = PAYLOAD_PREFIX_LENGTH; (listName_, offset_) = payload_.asBytes32Unchecked(offset_); (account_, offset_) = payload_.asAddressUnchecked(offset_); (add_, offset_) = payload_.asBoolUnchecked(offset_); + (sequence_, offset_) = payload_.asUint64Unchecked(offset_); (destinationChainId_, offset_) = payload_.asUint16Unchecked(offset_); payload_.checkLength(offset_); diff --git a/test/unit/HubPortal.t.sol b/test/unit/HubPortal.t.sol index bf3a34f..baf7e2d 100644 --- a/test/unit/HubPortal.t.sol +++ b/test/unit/HubPortal.t.sol @@ -199,13 +199,14 @@ contract HubPortalTests is UnitTestBase { bytes32 key_ = bytes32("key"); bytes32 value_ = bytes32("value"); bytes32 refundAddress_ = _alice.toBytes32(); + uint64 sequence_ = _portal.nextMessageSequence(); uint256 fee_ = 1; _registrar.set(key_, value_); vm.deal(_alice, fee_); (TransceiverStructs.NttManagerMessage memory message_, bytes32 messageId_) = _createMessage( - PayloadEncoder.encodeKey(key_, value_, _REMOTE_CHAIN_ID), + PayloadEncoder.encodeKey(key_, value_, sequence_, _REMOTE_CHAIN_ID), _LOCAL_CHAIN_ID ); @@ -238,13 +239,14 @@ contract HubPortalTests is UnitTestBase { bool status_ = true; address account_ = _bob; bytes32 refundAddress_ = _alice.toBytes32(); + uint64 sequence_ = _portal.nextMessageSequence(); uint256 fee_ = 1; vm.deal(_alice, fee_); _registrar.setListContains(listName_, account_, status_); (TransceiverStructs.NttManagerMessage memory message_, bytes32 messageId_) = _createMessage( - PayloadEncoder.encodeListUpdate(listName_, account_, status_, _REMOTE_CHAIN_ID), + PayloadEncoder.encodeListUpdate(listName_, account_, status_, sequence_, _REMOTE_CHAIN_ID), _LOCAL_CHAIN_ID ); diff --git a/test/unit/SpokePortal.t.sol b/test/unit/SpokePortal.t.sol index d67189c..b9ca272 100644 --- a/test/unit/SpokePortal.t.sol +++ b/test/unit/SpokePortal.t.sol @@ -122,60 +122,174 @@ contract SpokePortalTests is UnitTestBase { /* ============ _setRegistrarKey ============ */ - function test_setRegistrarKey() external { + function test_setRegistrarKey_sequenceZero() external { bytes32 key_ = bytes32("key"); bytes32 value_ = bytes32("value"); + uint64 sequence_ = 0; + + assertEq(_portal.lastProcessedSequence(), 0); (TransceiverStructs.NttManagerMessage memory message_, bytes32 messageId_) = _createMessage( - PayloadEncoder.encodeKey(key_, value_, _LOCAL_CHAIN_ID), + PayloadEncoder.encodeKey(key_, value_, sequence_, _LOCAL_CHAIN_ID), _REMOTE_CHAIN_ID ); vm.expectEmit(); - emit ISpokePortal.RegistrarKeyReceived(messageId_, key_, value_); + emit ISpokePortal.RegistrarKeyReceived(messageId_, key_, value_, sequence_); vm.expectCall(address(_registrar), abi.encodeCall(_registrar.setKey, (key_, value_))); vm.prank(address(_transceiver)); _portal.attestationReceived(_REMOTE_CHAIN_ID, _PEER, message_); + + assertEq(_portal.lastProcessedSequence(), sequence_); + } + + function test_setRegistrarKey_sequenceHigher() external { + bytes32 key_ = bytes32("key"); + bytes32 value_ = bytes32("value"); + uint64 sequence_ = 1; + + assertEq(_portal.lastProcessedSequence(), 0); + + (TransceiverStructs.NttManagerMessage memory message_, bytes32 messageId_) = _createMessage( + PayloadEncoder.encodeKey(key_, value_, sequence_, _LOCAL_CHAIN_ID), + _REMOTE_CHAIN_ID + ); + + vm.expectEmit(); + emit ISpokePortal.RegistrarKeyReceived(messageId_, key_, value_, sequence_); + + vm.expectCall(address(_registrar), abi.encodeCall(_registrar.setKey, (key_, value_))); + + vm.prank(address(_transceiver)); + _portal.attestationReceived(_REMOTE_CHAIN_ID, _PEER, message_); + + assertEq(_portal.lastProcessedSequence(), sequence_); + } + + function test_setRegistrarKey_sequenceLower() external { + bytes32 key_ = bytes32("key"); + bytes32 value_ = bytes32("value"); + uint64 sequence_ = 1; + + (TransceiverStructs.NttManagerMessage memory message_, bytes32 messageId_) = _createMessage( + PayloadEncoder.encodeKey(key_, value_, sequence_, _LOCAL_CHAIN_ID), + _REMOTE_CHAIN_ID + ); + + vm.prank(address(_transceiver)); + _portal.attestationReceived(_REMOTE_CHAIN_ID, _PEER, message_); + + assertEq(_portal.lastProcessedSequence(), sequence_); + + // sequence < lastProcessedSequence + sequence_ = 0; + value_ = bytes32("old_value"); + + (message_, messageId_) = _createMessage( + PayloadEncoder.encodeKey(key_, value_, sequence_, _LOCAL_CHAIN_ID), + _REMOTE_CHAIN_ID + ); + + vm.expectRevert( + abi.encodeWithSelector( + ISpokePortal.ObsoleteMessageSequence.selector, + sequence_, + _portal.lastProcessedSequence() + ) + ); + + vm.prank(address(_transceiver)); + _portal.attestationReceived(_REMOTE_CHAIN_ID, _PEER, message_); } /* ============ setRegistrarListStatus ============ */ - function test_setRegistrarListStatus_addToList() external { + function test_setRegistrarListStatus_addToList_sequenceZero() external { bytes32 listName_ = bytes32("listName"); bool status_ = true; + uint64 sequence_ = 0; + + assertEq(_portal.lastProcessedSequence(), 0); (TransceiverStructs.NttManagerMessage memory message_, bytes32 messageId_) = _createMessage( - PayloadEncoder.encodeListUpdate(listName_, _bob, status_, _LOCAL_CHAIN_ID), + PayloadEncoder.encodeListUpdate(listName_, _bob, status_, sequence_, _LOCAL_CHAIN_ID), _REMOTE_CHAIN_ID ); vm.expectEmit(); - emit ISpokePortal.RegistrarListStatusReceived(messageId_, listName_, _bob, status_); + emit ISpokePortal.RegistrarListStatusReceived(messageId_, listName_, _bob, status_, sequence_); vm.expectCall(address(_registrar), abi.encodeCall(_registrar.addToList, (listName_, _bob))); vm.prank(address(_transceiver)); _portal.attestationReceived(_REMOTE_CHAIN_ID, _PEER, message_); + + assertEq(_portal.lastProcessedSequence(), 0); } - function test_setRegistrarListStatus_removeFromList() external { + function test_setRegistrarListStatus_removeFromList_sequenceHigher() external { bytes32 listName_ = bytes32("listName"); bool status_ = false; + uint64 sequence_ = 1; + + assertEq(_portal.lastProcessedSequence(), 0); (TransceiverStructs.NttManagerMessage memory message_, bytes32 messageId_) = _createMessage( - PayloadEncoder.encodeListUpdate(listName_, _bob, status_, _LOCAL_CHAIN_ID), + PayloadEncoder.encodeListUpdate(listName_, _bob, status_, sequence_, _LOCAL_CHAIN_ID), _REMOTE_CHAIN_ID ); vm.expectEmit(); - emit ISpokePortal.RegistrarListStatusReceived(messageId_, listName_, _bob, status_); + emit ISpokePortal.RegistrarListStatusReceived(messageId_, listName_, _bob, status_, sequence_); vm.expectCall(address(_registrar), abi.encodeCall(_registrar.removeFromList, (listName_, _bob))); vm.prank(address(_transceiver)); _portal.attestationReceived(_REMOTE_CHAIN_ID, _PEER, message_); + + assertEq(_portal.lastProcessedSequence(), sequence_); + } + + function test_setRegistrarListStatus_removeFromList_sequenceLower() external { + bytes32 listName_ = bytes32("listName"); + bool status_ = true; + uint64 sequence_ = 1; + + // sequence > lastProcessedSequence + assertEq(_portal.lastProcessedSequence(), 0); + + (TransceiverStructs.NttManagerMessage memory message_, bytes32 messageId_) = _createMessage( + PayloadEncoder.encodeListUpdate(listName_, _bob, status_, sequence_, _LOCAL_CHAIN_ID), + _REMOTE_CHAIN_ID + ); + + vm.prank(address(_transceiver)); + _portal.attestationReceived(_REMOTE_CHAIN_ID, _PEER, message_); + + // lastProcessedSequence updated + assertEq(_portal.lastProcessedSequence(), 1); + + // sequence < lastProcessedSequence + status_ = false; + sequence_ = 0; + + (message_, messageId_) = _createMessage( + PayloadEncoder.encodeListUpdate(listName_, _bob, status_, sequence_, _LOCAL_CHAIN_ID), + _REMOTE_CHAIN_ID + ); + + vm.expectRevert( + abi.encodeWithSelector( + ISpokePortal.ObsoleteMessageSequence.selector, + sequence_, + _portal.lastProcessedSequence() + ) + ); + + vm.prank(address(_transceiver)); + _portal.attestationReceived(_REMOTE_CHAIN_ID, _PEER, message_); } /* ============ transfer ============ */ diff --git a/test/unit/libs/PayloadEncoder.t.sol b/test/unit/libs/PayloadEncoder.t.sol index dd2c986..6214aea 100644 --- a/test/unit/libs/PayloadEncoder.t.sol +++ b/test/unit/libs/PayloadEncoder.t.sol @@ -151,32 +151,40 @@ contract PayloadEncoderTest is Test { function test_encodeKey() external { bytes32 key_ = "key"; bytes32 value_ = "value"; + uint64 sequence_ = 2; bytes memory payload_ = abi.encodePacked( PayloadEncoder.KEY_TRANSFER_PREFIX, key_, value_, + sequence_, _DESTINATION_CHAIN_ID ); - assertEq(PayloadEncoder.encodeKey(key_, value_, _DESTINATION_CHAIN_ID), payload_); + assertEq(PayloadEncoder.encodeKey(key_, value_, sequence_, _DESTINATION_CHAIN_ID), payload_); } function test_decodeKey() external { bytes32 encodedKey_ = "key"; bytes32 encodedValue_ = "value"; + uint64 encodedSequence_ = 2; bytes memory payload_ = abi.encodePacked( PayloadEncoder.KEY_TRANSFER_PREFIX, encodedKey_, encodedValue_, + encodedSequence_, _DESTINATION_CHAIN_ID ); - (bytes32 decodedKey_, bytes32 decodedValue_, uint16 decodedDestinationChainId_) = PayloadEncoder.decodeKey( - payload_ - ); + ( + bytes32 decodedKey_, + bytes32 decodedValue_, + uint64 decodedSequence_, + uint16 decodedDestinationChainId_ + ) = PayloadEncoder.decodeKey(payload_); assertEq(decodedKey_, encodedKey_); assertEq(decodedValue_, encodedValue_); + assertEq(decodedSequence_, encodedSequence_); assertEq(decodedDestinationChainId_, _DESTINATION_CHAIN_ID); } @@ -184,26 +192,33 @@ contract PayloadEncoderTest is Test { bytes32 listName_ = "list"; address account_ = makeAddr("account"); bool add_ = true; + uint64 sequence_ = 3; bytes memory payload_ = abi.encodePacked( PayloadEncoder.LIST_UPDATE_PREFIX, listName_, account_, add_, + sequence_, _DESTINATION_CHAIN_ID ); - assertEq(PayloadEncoder.encodeListUpdate(listName_, account_, add_, _DESTINATION_CHAIN_ID), payload_); + assertEq( + PayloadEncoder.encodeListUpdate(listName_, account_, add_, sequence_, _DESTINATION_CHAIN_ID), + payload_ + ); } function test_decodeListUpdate() external { bytes32 encodedListName_ = "list"; address encodedAccount_ = makeAddr("account"); bool encodedStatus_ = true; + uint64 encodedSequence_ = 3; bytes memory payload_ = abi.encodePacked( PayloadEncoder.LIST_UPDATE_PREFIX, encodedListName_, encodedAccount_, encodedStatus_, + encodedSequence_, _DESTINATION_CHAIN_ID ); @@ -211,12 +226,14 @@ contract PayloadEncoderTest is Test { bytes32 decodedListName_, address decodedAccount_, bool decodedStatus_, + uint64 decodedSequence_, uint16 decodedDestinationChainId_ ) = PayloadEncoder.decodeListUpdate(payload_); assertEq(decodedListName_, encodedListName_); assertEq(decodedAccount_, encodedAccount_); assertEq(decodedStatus_, encodedStatus_); + assertEq(decodedSequence_, encodedSequence_); assertEq(decodedDestinationChainId_, _DESTINATION_CHAIN_ID); } }