Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: registrar key/list update race condition #22

Merged
merged 2 commits into from
Nov 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 15 additions & 6 deletions src/HubPortal.sol
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ contract HubPortal is IHubPortal, Portal {
) external payable returns (bytes32 messageId_) {
uint128 index_ = _currentIndex();
bytes memory payload_ = PayloadEncoder.encodeIndex(index_, destinationChainId_);
messageId_ = _sendMessage(destinationChainId_, refundAddress_, payload_);
messageId_ = _sendMessage(destinationChainId_, refundAddress_, _useMessageSequence(), payload_);

emit MTokenIndexSent(destinationChainId_, messageId_, index_);
}
Expand All @@ -69,8 +69,9 @@ contract HubPortal is IHubPortal, Portal {
bytes32 refundAddress_
) external payable returns (bytes32 messageId_) {
bytes32 value_ = IRegistrarLike(registrar).get(key_);
bytes memory payload_ = PayloadEncoder.encodeKey(key_, value_, destinationChainId_);
messageId_ = _sendMessage(destinationChainId_, refundAddress_, payload_);
uint64 sequence_ = _useMessageSequence();
bytes memory payload_ = PayloadEncoder.encodeKey(key_, value_, sequence_, destinationChainId_);
messageId_ = _sendMessage(destinationChainId_, refundAddress_, sequence_, payload_);

emit RegistrarKeySent(destinationChainId_, messageId_, key_, value_);
}
Expand All @@ -83,8 +84,15 @@ contract HubPortal is IHubPortal, Portal {
bytes32 refundAddress_
) external payable returns (bytes32 messageId_) {
bool status_ = IRegistrarLike(registrar).listContains(listName_, account_);
bytes memory payload_ = PayloadEncoder.encodeListUpdate(listName_, account_, status_, destinationChainId_);
messageId_ = _sendMessage(destinationChainId_, refundAddress_, payload_);
uint64 sequence_ = _useMessageSequence();
bytes memory payload_ = PayloadEncoder.encodeListUpdate(
listName_,
account_,
status_,
sequence_,
destinationChainId_
);
messageId_ = _sendMessage(destinationChainId_, refundAddress_, sequence_, payload_);

emit RegistrarListStatusSent(destinationChainId_, messageId_, listName_, account_, status_);
}
Expand Down Expand Up @@ -137,6 +145,7 @@ contract HubPortal is IHubPortal, Portal {
function _sendMessage(
uint16 destinationChainId_,
bytes32 refundAddress_,
uint64 _sequence,
bytes memory payload_
) private returns (bytes32 messageId_) {
if (refundAddress_ == bytes32(0)) revert InvalidRefundAddress();
Expand All @@ -149,7 +158,7 @@ contract HubPortal is IHubPortal, Portal {
) = _prepareForTransfer(destinationChainId_, DEFAULT_TRANSCEIVER_INSTRUCTIONS);

TransceiverStructs.NttManagerMessage memory message_ = TransceiverStructs.NttManagerMessage(
bytes32(uint256(_useMessageSequence())),
bytes32(uint256(_sequence)),
msg.sender.toBytes32(),
payload_
);
Expand Down
30 changes: 25 additions & 5 deletions src/SpokePortal.sol
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@ contract SpokePortal is ISpokePortal, Portal {
/// @inheritdoc ISpokePortal
uint112 public outstandingPrincipal;

/// @inheritdoc ISpokePortal
uint64 public lastProcessedSequence;

/**
* @notice Constructs the contract.
* @param mToken_ The address of the M token to bridge.
Expand All @@ -48,7 +51,7 @@ contract SpokePortal is ISpokePortal, Portal {
}
}

/* ============ Internal Interactive Functions ============ */
/* ============ Internal/Private Interactive Functions ============ */

function _receiveCustomPayload(
bytes32 messageId_,
Expand Down Expand Up @@ -79,22 +82,31 @@ contract SpokePortal is ISpokePortal, Portal {

/// @notice Sets a Registrar key received from the Hub chain.
function _setRegistrarKey(bytes32 messageId_, bytes memory payload_) private {
(bytes32 key_, bytes32 value_, uint16 destinationChainId_) = payload_.decodeKey();
(bytes32 key_, bytes32 value_, uint64 sequence_, uint16 destinationChainId_) = payload_.decodeKey();

_verifyDestinationChain(destinationChainId_);

emit RegistrarKeyReceived(messageId_, key_, value_);
emit RegistrarKeyReceived(messageId_, key_, value_, sequence_);

_verifyMessageSequence(sequence_);

lastProcessedSequence = sequence_;

IRegistrarLike(registrar).setKey(key_, value_);
}

/// @notice Adds or removes an account from the Registrar List based on the message from the Hub chain.
function _updateRegistrarList(bytes32 messageId_, bytes memory payload_) private {
(bytes32 listName_, address account_, bool add_, uint16 destinationChainId_) = payload_.decodeListUpdate();
(bytes32 listName_, address account_, bool add_, uint64 sequence_, uint16 destinationChainId_) = payload_
.decodeListUpdate();

_verifyDestinationChain(destinationChainId_);

emit RegistrarListStatusReceived(messageId_, listName_, account_, add_);
emit RegistrarListStatusReceived(messageId_, listName_, account_, add_, sequence_);

_verifyMessageSequence(sequence_);

lastProcessedSequence = sequence_;

if (add_) {
IRegistrarLike(registrar).addToList(listName_, account_);
Expand Down Expand Up @@ -133,6 +145,14 @@ contract SpokePortal is ISpokePortal, Portal {
}
}

/// @dev Checks if the incoming message sequence is greater than the last processed one to prevent
/// Registrar data from being overwritten due to message reordering.
function _verifyMessageSequence(uint64 sequence_) private view {
uint64 lastProcessedSequence_ = lastProcessedSequence;
if (lastProcessedSequence_ != 0 && sequence_ < lastProcessedSequence_)
revert ObsoleteMessageSequence(sequence_, lastProcessedSequence_);
}

/// @dev Returns the current M token index used by the Spoke Portal.
function _currentIndex() internal view override returns (uint128) {
return ISpokeMTokenLike(mToken()).currentIndex();
Expand Down
18 changes: 15 additions & 3 deletions src/interfaces/ISpokePortal.sol
Original file line number Diff line number Diff line change
Expand Up @@ -23,30 +23,42 @@ interface ISpokePortal is IPortal {
* @param messageId The unique identifier of the received message.
* @param key The Registrar key of some value.
* @param value The value.
* @param sequence The sequence of the message on the Hub.
*/
event RegistrarKeyReceived(bytes32 indexed messageId, bytes32 indexed key, bytes32 value);
event RegistrarKeyReceived(bytes32 indexed messageId, bytes32 indexed key, bytes32 value, uint64 sequence);

/**
* @notice Emitted when the Registrar list status is received from Mainnet.
* @param messageId The unique identifier of the received message.
* @param listName The name of the list.
* @param account The account.
* @param status Indicates if the account is added or removed from the list.
* @param sequence The sequence of the message on the Hub.
*/
event RegistrarListStatusReceived(
bytes32 indexed messageId,
bytes32 indexed listName,
address indexed account,
bool status
bool status,
uint64 sequence
);

/* ============ Custom Errors ============ */

/// @notice Emitted when processing Registrar Key and Registrar List Update messages
/// if an incoming message sequence is less than the last processed message sequence.
error ObsoleteMessageSequence(uint64 sequence, uint64 lastProcessedSequence);

/* ============ View/Pure Functions ============ */

/// @notice The maximum possible principal of the total bridged-in M tokens,
/// it will be used for calculations of excess of M yield in the Hub Portal.
/// it will be used for calculations of excess of M yield in the Hub Portal.
function outstandingPrincipal() external view returns (uint112);

/// @notice The excess of M yield in the Hub Portal contributed by the Spoke Portal,
/// total Hub Portal M yield excess equals to sum of all Spoke Portal M excesses.
function excess() external view returns (uint240);

/// @notice The message sequence of the last Set Registrar Key or Update List Status message received from the Hub.
function lastProcessedSequence() external view returns (uint64);
}
16 changes: 12 additions & 4 deletions src/libs/PayloadEncoder.sol
Original file line number Diff line number Diff line change
Expand Up @@ -77,18 +77,20 @@ library PayloadEncoder {
function encodeKey(
bytes32 key_,
bytes32 value_,
uint64 sequence_,
uint16 destinationChainId_
) internal pure returns (bytes memory encoded_) {
return abi.encodePacked(KEY_TRANSFER_PREFIX, key_, value_, destinationChainId_);
return abi.encodePacked(KEY_TRANSFER_PREFIX, key_, value_, sequence_, destinationChainId_);
}

function decodeKey(
bytes memory payload_
) internal pure returns (bytes32 key_, bytes32 value_, uint16 destinationChainId_) {
) internal pure returns (bytes32 key_, bytes32 value_, uint64 sequence_, uint16 destinationChainId_) {
uint256 offset_ = PAYLOAD_PREFIX_LENGTH;

(key_, offset_) = payload_.asBytes32Unchecked(offset_);
(value_, offset_) = payload_.asBytes32Unchecked(offset_);
(sequence_, offset_) = payload_.asUint64Unchecked(offset_);
(destinationChainId_, offset_) = payload_.asUint16Unchecked(offset_);

payload_.checkLength(offset_);
Expand All @@ -98,19 +100,25 @@ library PayloadEncoder {
bytes32 listName_,
address account_,
bool add_,
uint64 sequence_,
uint16 destinationChainId_
) internal pure returns (bytes memory encoded_) {
return abi.encodePacked(LIST_UPDATE_PREFIX, listName_, account_, add_, destinationChainId_);
return abi.encodePacked(LIST_UPDATE_PREFIX, listName_, account_, add_, sequence_, destinationChainId_);
}

function decodeListUpdate(
bytes memory payload_
) internal pure returns (bytes32 listName_, address account_, bool add_, uint16 destinationChainId_) {
)
internal
pure
returns (bytes32 listName_, address account_, bool add_, uint64 sequence_, uint16 destinationChainId_)
{
uint256 offset_ = PAYLOAD_PREFIX_LENGTH;

(listName_, offset_) = payload_.asBytes32Unchecked(offset_);
(account_, offset_) = payload_.asAddressUnchecked(offset_);
(add_, offset_) = payload_.asBoolUnchecked(offset_);
(sequence_, offset_) = payload_.asUint64Unchecked(offset_);
(destinationChainId_, offset_) = payload_.asUint16Unchecked(offset_);

payload_.checkLength(offset_);
Expand Down
6 changes: 4 additions & 2 deletions test/unit/HubPortal.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -199,13 +199,14 @@ contract HubPortalTests is UnitTestBase {
bytes32 key_ = bytes32("key");
bytes32 value_ = bytes32("value");
bytes32 refundAddress_ = _alice.toBytes32();
uint64 sequence_ = _portal.nextMessageSequence();
uint256 fee_ = 1;

_registrar.set(key_, value_);
vm.deal(_alice, fee_);

(TransceiverStructs.NttManagerMessage memory message_, bytes32 messageId_) = _createMessage(
PayloadEncoder.encodeKey(key_, value_, _REMOTE_CHAIN_ID),
PayloadEncoder.encodeKey(key_, value_, sequence_, _REMOTE_CHAIN_ID),
_LOCAL_CHAIN_ID
);

Expand Down Expand Up @@ -238,13 +239,14 @@ contract HubPortalTests is UnitTestBase {
bool status_ = true;
address account_ = _bob;
bytes32 refundAddress_ = _alice.toBytes32();
uint64 sequence_ = _portal.nextMessageSequence();
uint256 fee_ = 1;

vm.deal(_alice, fee_);
_registrar.setListContains(listName_, account_, status_);

(TransceiverStructs.NttManagerMessage memory message_, bytes32 messageId_) = _createMessage(
PayloadEncoder.encodeListUpdate(listName_, account_, status_, _REMOTE_CHAIN_ID),
PayloadEncoder.encodeListUpdate(listName_, account_, status_, sequence_, _REMOTE_CHAIN_ID),
_LOCAL_CHAIN_ID
);

Expand Down
Loading
Loading