Skip to content

Commit

Permalink
fix: registrar key/list update race condition
Browse files Browse the repository at this point in the history
  • Loading branch information
0xIryna committed Nov 20, 2024
1 parent b8b570f commit b3ea906
Show file tree
Hide file tree
Showing 7 changed files with 207 additions and 37 deletions.
21 changes: 15 additions & 6 deletions src/HubPortal.sol
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ contract HubPortal is IHubPortal, Portal {
) external payable returns (bytes32 messageId_) {
uint128 index_ = _currentIndex();
bytes memory payload_ = PayloadEncoder.encodeIndex(index_, destinationChainId_);
messageId_ = _sendMessage(destinationChainId_, refundAddress_, payload_);
messageId_ = _sendMessage(destinationChainId_, refundAddress_, _useMessageSequence(), payload_);

emit MTokenIndexSent(destinationChainId_, messageId_, index_);
}
Expand All @@ -69,8 +69,9 @@ contract HubPortal is IHubPortal, Portal {
bytes32 refundAddress_
) external payable returns (bytes32 messageId_) {
bytes32 value_ = IRegistrarLike(registrar).get(key_);
bytes memory payload_ = PayloadEncoder.encodeKey(key_, value_, destinationChainId_);
messageId_ = _sendMessage(destinationChainId_, refundAddress_, payload_);
uint64 sequence_ = _useMessageSequence();
bytes memory payload_ = PayloadEncoder.encodeKey(key_, value_, sequence_, destinationChainId_);
messageId_ = _sendMessage(destinationChainId_, refundAddress_, sequence_, payload_);

emit RegistrarKeySent(destinationChainId_, messageId_, key_, value_);
}
Expand All @@ -83,8 +84,15 @@ contract HubPortal is IHubPortal, Portal {
bytes32 refundAddress_
) external payable returns (bytes32 messageId_) {
bool status_ = IRegistrarLike(registrar).listContains(listName_, account_);
bytes memory payload_ = PayloadEncoder.encodeListUpdate(listName_, account_, status_, destinationChainId_);
messageId_ = _sendMessage(destinationChainId_, refundAddress_, payload_);
uint64 sequence_ = _useMessageSequence();
bytes memory payload_ = PayloadEncoder.encodeListUpdate(
listName_,
account_,
status_,
sequence_,
destinationChainId_
);
messageId_ = _sendMessage(destinationChainId_, refundAddress_, sequence_, payload_);

emit RegistrarListStatusSent(destinationChainId_, messageId_, listName_, account_, status_);
}
Expand Down Expand Up @@ -137,6 +145,7 @@ contract HubPortal is IHubPortal, Portal {
function _sendMessage(
uint16 destinationChainId_,
bytes32 refundAddress_,
uint64 _sequence,
bytes memory payload_
) private returns (bytes32 messageId_) {
if (refundAddress_ == bytes32(0)) revert InvalidRefundAddress();
Expand All @@ -149,7 +158,7 @@ contract HubPortal is IHubPortal, Portal {
) = _prepareForTransfer(destinationChainId_, DEFAULT_TRANSCEIVER_INSTRUCTIONS);

TransceiverStructs.NttManagerMessage memory message_ = TransceiverStructs.NttManagerMessage(
bytes32(uint256(_useMessageSequence())),
bytes32(uint256(_sequence)),
msg.sender.toBytes32(),
payload_
);
Expand Down
39 changes: 30 additions & 9 deletions src/SpokePortal.sol
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,12 @@ contract SpokePortal is ISpokePortal, Portal {
/// @inheritdoc ISpokePortal
uint112 public outstandingPrincipal;

/// @dev The message sequence of the latest Set Registrar Key message received from the Hub
uint64 public lastSetKeySequence;

/// @dev The message sequence of the latest Update List Status message received from the Hub
uint64 public lastUpdateListSequence;

/**
* @notice Constructs the contract.
* @param mToken_ The address of the M token to bridge.
Expand Down Expand Up @@ -79,27 +85,42 @@ contract SpokePortal is ISpokePortal, Portal {

/// @notice Sets a Registrar key received from the Hub chain.
function _setRegistrarKey(bytes32 messageId_, bytes memory payload_) private {
(bytes32 key_, bytes32 value_, uint16 destinationChainId_) = payload_.decodeKey();
(bytes32 key_, bytes32 value_, uint64 sequence_, uint16 destinationChainId_) = payload_.decodeKey();

_verifyDestinationChain(destinationChainId_);

emit RegistrarKeyReceived(messageId_, key_, value_);
emit RegistrarKeyReceived(messageId_, key_, value_, sequence_);

uint64 lastSetKeySequence_ = lastSetKeySequence;

IRegistrarLike(registrar).setKey(key_, value_);
// Update the key only if the incoming message has the higher sequence or is the fist message
// to prevent the race condition
if (lastSetKeySequence_ == 0 || sequence_ > lastSetKeySequence_) {
IRegistrarLike(registrar).setKey(key_, value_);
lastSetKeySequence = sequence_;
}
}

/// @notice Adds or removes an account from the Registrar List based on the message from the Hub chain.
function _updateRegistrarList(bytes32 messageId_, bytes memory payload_) private {
(bytes32 listName_, address account_, bool add_, uint16 destinationChainId_) = payload_.decodeListUpdate();
(bytes32 listName_, address account_, bool add_, uint64 sequence_, uint16 destinationChainId_) = payload_
.decodeListUpdate();

_verifyDestinationChain(destinationChainId_);

emit RegistrarListStatusReceived(messageId_, listName_, account_, add_);
emit RegistrarListStatusReceived(messageId_, listName_, account_, add_, sequence_);

if (add_) {
IRegistrarLike(registrar).addToList(listName_, account_);
} else {
IRegistrarLike(registrar).removeFromList(listName_, account_);
uint64 lastUpdateListSequence_ = lastUpdateListSequence;
// Update the status only if the incoming message has the higher sequence or is the fist message
// to prevent the race condition
if (lastUpdateListSequence_ == 0 || sequence_ > lastUpdateListSequence_) {
if (add_) {
IRegistrarLike(registrar).addToList(listName_, account_);
} else {
IRegistrarLike(registrar).removeFromList(listName_, account_);
}

lastUpdateListSequence = sequence_;
}
}

Expand Down
7 changes: 5 additions & 2 deletions src/interfaces/ISpokePortal.sol
Original file line number Diff line number Diff line change
Expand Up @@ -23,21 +23,24 @@ interface ISpokePortal is IPortal {
* @param messageId The unique identifier of the received message.
* @param key The Registrar key of some value.
* @param value The value.
* @param sequence The sequence of the message on the Hub.
*/
event RegistrarKeyReceived(bytes32 indexed messageId, bytes32 indexed key, bytes32 value);
event RegistrarKeyReceived(bytes32 indexed messageId, bytes32 indexed key, bytes32 value, uint64 sequence);

/**
* @notice Emitted when the Registrar list status is received from Mainnet.
* @param messageId The unique identifier of the received message.
* @param listName The name of the list.
* @param account The account.
* @param status Indicates if the account is added or removed from the list.
* @param sequence The sequence of the message on the Hub.
*/
event RegistrarListStatusReceived(
bytes32 indexed messageId,
bytes32 indexed listName,
address indexed account,
bool status
bool status,
uint64 sequence
);

/* ============ View/Pure Functions ============ */
Expand Down
16 changes: 12 additions & 4 deletions src/libs/PayloadEncoder.sol
Original file line number Diff line number Diff line change
Expand Up @@ -77,18 +77,20 @@ library PayloadEncoder {
function encodeKey(
bytes32 key_,
bytes32 value_,
uint64 sequence_,
uint16 destinationChainId_
) internal pure returns (bytes memory encoded_) {
return abi.encodePacked(KEY_TRANSFER_PREFIX, key_, value_, destinationChainId_);
return abi.encodePacked(KEY_TRANSFER_PREFIX, key_, value_, sequence_, destinationChainId_);
}

function decodeKey(
bytes memory payload_
) internal pure returns (bytes32 key_, bytes32 value_, uint16 destinationChainId_) {
) internal pure returns (bytes32 key_, bytes32 value_, uint64 sequence_, uint16 destinationChainId_) {
uint256 offset_ = PAYLOAD_PREFIX_LENGTH;

(key_, offset_) = payload_.asBytes32Unchecked(offset_);
(value_, offset_) = payload_.asBytes32Unchecked(offset_);
(sequence_, offset_) = payload_.asUint64Unchecked(offset_);
(destinationChainId_, offset_) = payload_.asUint16Unchecked(offset_);

payload_.checkLength(offset_);
Expand All @@ -98,19 +100,25 @@ library PayloadEncoder {
bytes32 listName_,
address account_,
bool add_,
uint64 sequence_,
uint16 destinationChainId_
) internal pure returns (bytes memory encoded_) {
return abi.encodePacked(LIST_UPDATE_PREFIX, listName_, account_, add_, destinationChainId_);
return abi.encodePacked(LIST_UPDATE_PREFIX, listName_, account_, add_, sequence_, destinationChainId_);
}

function decodeListUpdate(
bytes memory payload_
) internal pure returns (bytes32 listName_, address account_, bool add_, uint16 destinationChainId_) {
)
internal
pure
returns (bytes32 listName_, address account_, bool add_, uint64 sequence_, uint16 destinationChainId_)
{
uint256 offset_ = PAYLOAD_PREFIX_LENGTH;

(listName_, offset_) = payload_.asBytes32Unchecked(offset_);
(account_, offset_) = payload_.asAddressUnchecked(offset_);
(add_, offset_) = payload_.asBoolUnchecked(offset_);
(sequence_, offset_) = payload_.asUint64Unchecked(offset_);
(destinationChainId_, offset_) = payload_.asUint16Unchecked(offset_);

payload_.checkLength(offset_);
Expand Down
6 changes: 4 additions & 2 deletions test/unit/HubPortal.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -199,13 +199,14 @@ contract HubPortalTests is UnitTestBase {
bytes32 key_ = bytes32("key");
bytes32 value_ = bytes32("value");
bytes32 refundAddress_ = _alice.toBytes32();
uint64 sequence_ = _portal.nextMessageSequence();
uint256 fee_ = 1;

_registrar.set(key_, value_);
vm.deal(_alice, fee_);

(TransceiverStructs.NttManagerMessage memory message_, bytes32 messageId_) = _createMessage(
PayloadEncoder.encodeKey(key_, value_, _REMOTE_CHAIN_ID),
PayloadEncoder.encodeKey(key_, value_, sequence_, _REMOTE_CHAIN_ID),
_LOCAL_CHAIN_ID
);

Expand Down Expand Up @@ -238,13 +239,14 @@ contract HubPortalTests is UnitTestBase {
bool status_ = true;
address account_ = _bob;
bytes32 refundAddress_ = _alice.toBytes32();
uint64 sequence_ = _portal.nextMessageSequence();
uint256 fee_ = 1;

vm.deal(_alice, fee_);
_registrar.setListContains(listName_, account_, status_);

(TransceiverStructs.NttManagerMessage memory message_, bytes32 messageId_) = _createMessage(
PayloadEncoder.encodeListUpdate(listName_, account_, status_, _REMOTE_CHAIN_ID),
PayloadEncoder.encodeListUpdate(listName_, account_, status_, sequence_, _REMOTE_CHAIN_ID),
_LOCAL_CHAIN_ID
);

Expand Down
Loading

0 comments on commit b3ea906

Please sign in to comment.