diff --git a/contracts/Arbitrator.sol b/contracts/Arbitrator.sol index 410866e..f955daa 100644 --- a/contracts/Arbitrator.sol +++ b/contracts/Arbitrator.sol @@ -28,8 +28,8 @@ contract Arbitrator is IArbitrator, OwnableUpgradeable, UUPSUpgradeable, Reentra mapping(IL1Gateway => DoubleEndedQueueUpgradeable.Bytes32Deque) public secondaryChainMessageHashQueues; /// @notice List of permitted relayers mapping(address relayerAddress => bool isRelayer) public relayers; - /// @dev The forward params are used to forward a l2 message from source chain to target chains - bytes private forwardParams; + /// @dev A transient storage value for forwarding message from source chain to target chains + bytes32 private finalizeMessageHash; /** * @dev This empty reserved space is put in place to allow future versions to add new * variables without shifting down storage in the inheritance chain. @@ -143,13 +143,13 @@ contract Arbitrator is IArbitrator, OwnableUpgradeable, UUPSUpgradeable, Reentra function enqueueMessage(uint256 _value, bytes calldata _callData) external payable { require(msg.value == _value, "Invalid msg value"); // store message hash for forwarding - bytes32 finalizeMessageHash = keccak256(abi.encode(_value, _callData)); + bytes32 _finalizeMessageHash = keccak256(abi.encode(_value, _callData)); IL1Gateway gateway = IL1Gateway(msg.sender); if (gateway == primaryChainGateway) { - primaryChainMessageHashQueue.pushBack(finalizeMessageHash); + primaryChainMessageHashQueue.pushBack(_finalizeMessageHash); } else { require(secondaryChainGateways[gateway], "Not secondary chain gateway"); - secondaryChainMessageHashQueues[gateway].pushBack(finalizeMessageHash); + secondaryChainMessageHashQueues[gateway].pushBack(_finalizeMessageHash); } emit MessageReceived(_value, _callData); } @@ -157,48 +157,12 @@ contract Arbitrator is IArbitrator, OwnableUpgradeable, UUPSUpgradeable, Reentra /// @dev This function is called within the `claimMessageCallback` of L1 gateway function receiveMessage(uint256 _value, bytes calldata _callData) external payable { require(msg.value == _value, "Invalid msg value"); - IL1Gateway sourceGateway = IL1Gateway(msg.sender); - // `forwardParams` is set in `claimMessage` - bytes memory _forwardParams; + // temporary store message hash for forwarding + IL1Gateway gateway = IL1Gateway(msg.sender); + require(gateway == primaryChainGateway || secondaryChainGateways[gateway], "Invalid gateway"); + bytes32 _finalizeMessageHash = keccak256(abi.encode(msg.sender, _value, _callData)); assembly { - _forwardParams := tload(forwardParams.slot) - } - // Ensure the caller is L1 gateway - if (sourceGateway == primaryChainGateway) { - // Unpack destination chain and final callData - bytes[] memory gatewayDataList = abi.decode(_callData, (bytes[])); - bytes[] memory gatewayForwardParamsList = abi.decode(_forwardParams, (bytes[])); - uint256 gatewayLength = gatewayDataList.length; - require(gatewayLength == gatewayForwardParamsList.length, "Invalid forward params length"); - unchecked { - for (uint256 i = 0; i < gatewayLength; ++i) { - bytes memory gatewayData = gatewayDataList[i]; - bytes memory gatewayForwardParams = gatewayForwardParamsList[i]; - (IL1Gateway targetGateway, uint256 targetCallValue, bytes memory targetCallData) = abi.decode( - gatewayData, - (IL1Gateway, uint256, bytes) - ); - require(secondaryChainGateways[targetGateway], "Invalid secondary chain gateway"); - (uint256 sendMsgFee, bytes memory adapterParams) = abi.decode( - gatewayForwardParams, - (uint256, bytes) - ); - // Forward fee to send message - targetGateway.sendMessage{value: sendMsgFee + targetCallValue}( - targetCallValue, - targetCallData, - adapterParams - ); - emit MessageForwarded(targetGateway, targetCallValue, targetCallData); - } - } - } else { - require(secondaryChainGateways[sourceGateway], "Not secondary chain gateway"); - (uint256 sendMsgFee, bytes memory adapterParams) = abi.decode(_forwardParams, (uint256, bytes)); - // Forward fee to send message - IL1Gateway targetGateway = primaryChainGateway; - targetGateway.sendMessage{value: sendMsgFee + _value}(_value, _callData, adapterParams); - emit MessageForwarded(targetGateway, _value, _callData); + tstore(finalizeMessageHash.slot, _finalizeMessageHash) } } @@ -208,9 +172,9 @@ contract Arbitrator is IArbitrator, OwnableUpgradeable, UUPSUpgradeable, Reentra bytes calldata _callData, bytes calldata _adapterParams ) external payable nonReentrant onlyRelayer { - bytes32 finalizeMessageHash = keccak256(abi.encode(_value, _callData)); + bytes32 _finalizeMessageHash = keccak256(abi.encode(_value, _callData)); if (_gateway == primaryChainGateway) { - require(finalizeMessageHash == primaryChainMessageHashQueue.popFront(), "Invalid finalize message hash"); + require(_finalizeMessageHash == primaryChainMessageHashQueue.popFront(), "Invalid finalize message hash"); // Unpack destination chain and final callData (IL1Gateway secondaryChainGateway, bytes memory finalCallData) = abi.decode(_callData, (IL1Gateway, bytes)); require(secondaryChainGateways[secondaryChainGateway], "Invalid secondary chain gateway"); @@ -219,7 +183,7 @@ contract Arbitrator is IArbitrator, OwnableUpgradeable, UUPSUpgradeable, Reentra } else { require(secondaryChainGateways[_gateway], "Not secondary chain gateway"); require( - finalizeMessageHash == secondaryChainMessageHashQueues[_gateway].popFront(), + _finalizeMessageHash == secondaryChainMessageHashQueues[_gateway].popFront(), "Invalid finalize message hash" ); // Forward fee to send message @@ -231,16 +195,72 @@ contract Arbitrator is IArbitrator, OwnableUpgradeable, UUPSUpgradeable, Reentra function claimMessage( address _sourceChainCanonicalMessageService, bytes calldata _sourceChainClaimCallData, - bytes memory _forwardParams + IL1Gateway _sourceChainL1Gateway, + uint256 _receiveValue, + bytes calldata _receiveCallData, + bytes calldata _forwardParams ) external payable nonReentrant onlyRelayer { - // The `forwardParams` will be cleared after tx executed - assembly { - tstore(forwardParams.slot, _forwardParams) - } // Call the claim interface of source chain message service // And it will inner call the `claimMessageCallback` interface of source chain L1Gateway // In the `claimMessageCallback` of L1Gateway, it will inner call `receiveMessage` of Arbitrator // No use of return value Address.functionCall(_sourceChainCanonicalMessageService, _sourceChainClaimCallData); + + // Load the transient `finalizeMessageHash` + bytes32 _finalizeMessageHash; + assembly { + _finalizeMessageHash := tload(finalizeMessageHash.slot) + } + require( + _finalizeMessageHash == keccak256(abi.encode(_sourceChainL1Gateway, _receiveValue, _receiveCallData)), + "Incorrect finalize data" + ); + + // The msg value should be equal to the combined cost of all messages delivered from l1 to l2 + // The excess fees will be refunded to the relayer by rollup canonical message service + if (_sourceChainL1Gateway == primaryChainGateway) { + // Unpack destination chain and final callData + bytes[] memory gatewayDataList = abi.decode(_receiveCallData, (bytes[])); + bytes[] memory gatewayForwardParamsList = abi.decode(_forwardParams, (bytes[])); + uint256 gatewayLength = gatewayDataList.length; + require(gatewayLength == gatewayForwardParamsList.length, "Invalid forward params length"); + uint256 totalCallValue; + uint256 totalSendMsgFee; + unchecked { + for (uint256 i = 0; i < gatewayLength; ++i) { + bytes memory gatewayData = gatewayDataList[i]; + bytes memory gatewayForwardParams = gatewayForwardParamsList[i]; + (IL1Gateway targetGateway, uint256 targetCallValue, bytes memory targetCallData) = abi.decode( + gatewayData, + (IL1Gateway, uint256, bytes) + ); + require(secondaryChainGateways[targetGateway], "Invalid secondary chain gateway"); + totalCallValue += targetCallValue; + (uint256 sendMsgFee, bytes memory adapterParams) = abi.decode( + gatewayForwardParams, + (uint256, bytes) + ); + totalSendMsgFee += sendMsgFee; + // Forward fee to send message + targetGateway.sendMessage{value: sendMsgFee + targetCallValue}( + targetCallValue, + targetCallData, + adapterParams + ); + emit MessageForwarded(targetGateway, targetCallValue, targetCallData); + } + } + require(totalCallValue == _receiveValue, "Invalid call value"); + require(totalSendMsgFee == msg.value, "Invalid send msg fee"); + } else { + IL1Gateway targetGateway = primaryChainGateway; + // Forward fee to send message + targetGateway.sendMessage{value: msg.value + _receiveValue}( + _receiveValue, + _receiveCallData, + _forwardParams + ); + emit MessageForwarded(targetGateway, _receiveValue, _receiveCallData); + } } } diff --git a/contracts/dev-contracts/DummyArbitrator.sol b/contracts/dev-contracts/DummyArbitrator.sol index cf8cba3..1d5dad8 100644 --- a/contracts/dev-contracts/DummyArbitrator.sol +++ b/contracts/dev-contracts/DummyArbitrator.sol @@ -39,7 +39,14 @@ contract DummyArbitrator is IArbitrator, OwnableUpgradeable, UUPSUpgradeable, Re _gateway.sendMessage{value: msg.value + _value}(_value, _callData, _adapterParams); } - function claimMessage(address, bytes calldata, bytes memory) external payable { + function claimMessage( + address, + bytes calldata, + IL1Gateway, + uint256, + bytes calldata, + bytes calldata + ) external payable { // do nothing } } diff --git a/contracts/interfaces/IArbitrator.sol b/contracts/interfaces/IArbitrator.sol index a113283..5d13d7c 100644 --- a/contracts/interfaces/IArbitrator.sol +++ b/contracts/interfaces/IArbitrator.sol @@ -31,10 +31,16 @@ interface IArbitrator { /// @notice Claim a message of source chain and deliver it to the target chain /// @param _sourceChainCanonicalMessageService The message service to claim message /// @param _sourceChainClaimCallData The call data that need to claim message from source chain + /// @param _sourceChainL1Gateway The msg.sender passed in the `receiveMessage` interface + /// @param _receiveValue The value passed in the `receiveMessage` interface + /// @param _receiveCallData The call data passed in the `receiveMessage` interface /// @param _forwardParams Some params need to call canonical message service of target chain function claimMessage( address _sourceChainCanonicalMessageService, bytes calldata _sourceChainClaimCallData, - bytes memory _forwardParams + IL1Gateway _sourceChainL1Gateway, + uint256 _receiveValue, + bytes calldata _receiveCallData, + bytes calldata _forwardParams ) external payable; }