diff --git a/contracts/evm/contracts/adapters/ClusterConnection.sol b/contracts/evm/contracts/adapters/ClusterConnection.sol index 2f333310a..b5e76c1cc 100644 --- a/contracts/evm/contracts/adapters/ClusterConnection.sol +++ b/contracts/evm/contracts/adapters/ClusterConnection.sol @@ -2,7 +2,7 @@ pragma solidity >=0.8.0; pragma abicoder v2; -import {console2 } from "forge-std/Test.sol"; +import {console2} from "forge-std/Test.sol"; import "openzeppelin-contracts-upgradeable/contracts/proxy/utils/Initializable.sol"; import "@xcall/utils/Types.sol"; @@ -10,13 +10,18 @@ import "@xcall/contracts/xcall/interfaces/IConnection.sol"; import "@iconfoundation/xcall-solidity-library/interfaces/ICallService.sol"; import "@iconfoundation/xcall-solidity-library/utils/RLPEncode.sol"; import "@iconfoundation/xcall-solidity-library/utils/RLPEncode.sol"; +import "@iconfoundation/xcall-solidity-library/utils/Strings.sol"; +import "@iconfoundation/xcall-solidity-library/utils/Integers.sol"; +/// @custom:oz-upgrades-from contracts/adapters/ClusterConnectionV1.sol:ClusterConnectionV1 contract ClusterConnection is Initializable, IConnection { - using RLPEncode for bytes; using RLPEncode for string; using RLPEncode for uint256; + using Strings for bytes; + using Integers for uint256; + mapping(string => uint256) private messageFees; mapping(string => uint256) private responseFees; mapping(string => mapping(uint256 => bool)) receipts; @@ -51,12 +56,18 @@ contract ClusterConnection is Initializable, IConnection { return validators; } - function updateValidators(bytes[] memory _validators, uint8 _threshold) external onlyAdmin { + function updateValidators( + bytes[] memory _validators, + uint8 _threshold + ) external onlyAdmin { delete validators; for (uint i = 0; i < _validators.length; i++) { address validators_address = publicKeyToAddress(_validators[i]); - if(!isValidator(validators_address) && validators_address != address(0)) { - validators.push(validators_address); + if ( + !isValidator(validators_address) && + validators_address != address(0) + ) { + validators.push(validators_address); } } require(validators.length >= _threshold, "Not enough validators"); @@ -143,40 +154,65 @@ contract ClusterConnection is Initializable, IConnection { bytes calldata _msg, bytes[] calldata _signedMessages ) public onlyRelayer { - require(_signedMessages.length >= validatorsThreshold, "Not enough signatures passed"); - bytes32 messageHash = getMessageHash(srcNetwork, _connSn, _msg); + require( + _signedMessages.length >= validatorsThreshold, + "Not enough signatures passed" + ); + + string memory dstNetwork = ICallService(xCall).getNetworkId(); + + bytes32 messageHash = getMessageHash( + srcNetwork, + _connSn, + _msg, + dstNetwork + ); uint signerCount = 0; - address[] memory collectedSigners = new address[](_signedMessages.length); - + address[] memory collectedSigners = new address[]( + _signedMessages.length + ); + for (uint i = 0; i < _signedMessages.length; i++) { address signer = recoverSigner(messageHash, _signedMessages[i]); require(signer != address(0), "Invalid signature"); - if (!isValidatorProcessed(collectedSigners, signer) && existsInValidators(signer)){ + if ( + !isValidatorProcessed(collectedSigners, signer) && + existsInValidators(signer) + ) { collectedSigners[signerCount] = signer; signerCount++; } } - require(signerCount >= validatorsThreshold,"Not enough valid signatures passed"); - recvMessage(srcNetwork,_connSn,_msg); + require( + signerCount >= validatorsThreshold, + "Not enough valid signatures passed" + ); + recvMessage(srcNetwork, _connSn, _msg); } function existsInValidators(address signer) internal view returns (bool) { - for (uint i = 0; i < validators.length; i++){ + for (uint i = 0; i < validators.length; i++) { if (validators[i] == signer) return true; } return false; } - function isValidatorProcessed(address[] memory processedSigners, address signer) public pure returns (bool) { + function isValidatorProcessed( + address[] memory processedSigners, + address signer + ) public pure returns (bool) { for (uint i = 0; i < processedSigners.length; i++) { if (processedSigners[i] == signer) { return true; } } return false; - } + } - function recoverSigner(bytes32 messageHash, bytes memory signature) public pure returns (address) { + function recoverSigner( + bytes32 messageHash, + bytes memory signature + ) public pure returns (address) { require(signature.length == 65, "Invalid signature length"); bytes32 r; bytes32 s; @@ -240,7 +276,7 @@ contract ClusterConnection is Initializable, IConnection { adminAddress = _address; } - /** + /** @notice Set the address of the relayer. @param _address The address of the relayer. */ @@ -268,7 +304,7 @@ contract ClusterConnection is Initializable, IConnection { @notice Set the required signature count for verification. @param _count The desired count. */ - function setRequiredValidatorCount(uint8 _count) external onlyAdmin() { + function setRequiredValidatorCount(uint8 _count) external onlyAdmin { validatorsThreshold = _count; } @@ -276,16 +312,25 @@ contract ClusterConnection is Initializable, IConnection { return validatorsThreshold; } - function getMessageHash(string memory srcNetwork, uint256 _connSn, bytes calldata _msg) internal pure returns (bytes32) { - bytes memory rlp = abi.encodePacked( - srcNetwork.encodeString(), - _connSn.encodeUint(), - _msg.encodeBytes() - ).encodeList(); - return keccak256(rlp); + function getMessageHash( + string memory srcNetwork, + uint256 _connSn, + bytes calldata _msg, + string memory dstNetwork + ) internal pure returns (bytes32) { + bytes memory encoded = abi + .encodePacked( + srcNetwork, + _connSn.toString(), + _msg, + dstNetwork + ); + return keccak256(encoded); } - function publicKeyToAddress(bytes memory publicKey) internal pure returns (address addr) { + function publicKeyToAddress( + bytes memory publicKey + ) internal pure returns (address addr) { require(publicKey.length == 65, "Invalid public key length"); bytes32 hash; @@ -298,5 +343,4 @@ contract ClusterConnection is Initializable, IConnection { addr = address(uint160(uint256(hash))); } - } diff --git a/contracts/evm/library/xcall/utils/Strings.sol b/contracts/evm/library/xcall/utils/Strings.sol index 0e86716ed..3ac7020b2 100644 --- a/contracts/evm/library/xcall/utils/Strings.sol +++ b/contracts/evm/library/xcall/utils/Strings.sol @@ -47,7 +47,7 @@ library Strings { converted[i * 2 + 1] = _base[uint8(buffer[i]) & 0xf]; } - return string(abi.encodePacked("0x", converted)); + return string(converted); } /** diff --git a/contracts/evm/test/adapters/ClusterConnection.t.sol b/contracts/evm/test/adapters/ClusterConnection.t.sol index 2612acaba..dccef0cce 100644 --- a/contracts/evm/test/adapters/ClusterConnection.t.sol +++ b/contracts/evm/test/adapters/ClusterConnection.t.sol @@ -1,22 +1,25 @@ // SPDX-License-Identifier: UNLICENSED pragma solidity ^0.8.13; -import {Test, console2 } from "forge-std/Test.sol"; +import {Test, console2} from "forge-std/Test.sol"; import {LZEndpointMock} from "@lz-contracts/mocks/LZEndpointMock.sol"; import "@xcall/contracts/adapters/ClusterConnection.sol"; import "@xcall/contracts/xcall/CallService.sol"; import "@xcall/contracts/mocks/multi-protocol-dapp/MultiProtocolSampleDapp.sol"; import "@xcall/utils/Types.sol"; import "@iconfoundation/xcall-solidity-library/utils/RLPEncode.sol"; - +import "@iconfoundation/xcall-solidity-library/utils/Strings.sol"; +import "@iconfoundation/xcall-solidity-library/utils/Integers.sol"; contract ClusterConnectionTest is Test { - using RLPEncode for bytes; using RLPEncode for string; using RLPEncode for uint256; using RLPEncodeStruct for Types.CSMessage; using RLPEncodeStruct for Types.CSMessageRequestV2; + using Strings for bytes; + using Integers for uint256; + event CallExecuted(uint256 indexed _reqId, int _code, string _msg); event RollbackExecuted(uint256 indexed _sn); @@ -42,8 +45,8 @@ contract ClusterConnectionTest is Test { address public owner = address(uint160(uint256(keccak256("owner")))); address public admin = address(uint160(uint256(keccak256("admin")))); - address public user = address(uint160(uint256(keccak256("user")))); - + address public user = address(uint160(uint256(keccak256("user")))); + event CallMessage( string indexed _from, string indexed _to, @@ -84,7 +87,7 @@ contract ClusterConnectionTest is Test { dappTarget.initialize(address(xCallTarget)); adapterTarget = new ClusterConnection(); - adapterTarget.initialize(destination_relayer, address(xCallTarget)); + adapterTarget.initialize(destination_relayer, address(xCallTarget)); xCallTarget.setDefaultConnection(nidSource, address(adapterTarget)); } @@ -192,7 +195,7 @@ contract ClusterConnectionTest is Test { vm.stopPrank(); assert(source_relayer.balance == 10 ether); - } + } function testRecvMessageWithMultiSignatures() public { bytes memory data = bytes("test"); @@ -216,26 +219,51 @@ contract ClusterConnectionTest is Test { uint256 pk2 = 0x47e179ec197488593b187f80a00eb0da91f1b9d0b13f8733639f19c30a34926a; uint256 pk3 = 0x59c6995e998f97a5a0044966f0945389dc9e86dae88c7a8412f4603b6b78690d; uint256 pk4 = 0x2a871d0798f97d79848a013d4936a73bf4cc922c825d33c1cf7073dff6d409c6; - bytes32 hash = getMessageHash(nidSource, 1, RLPEncodeStruct.encodeCSMessage(message)); + bytes32 hash = getMessageHash( + nidSource, + 1, + RLPEncodeStruct.encodeCSMessage(message), + nidTarget + ); vm.startPrank(owner); bytes[] memory validators = new bytes[](4); - validators[0] = bytes(hex"048318535b54105d4a7aae60c08fc45f9687181b4fdfc625bd1a753fa7397fed753547f11ca8696646f2f3acb08e31016afac23e630c5d11f59f61fef57b0d2aa5"); - validators[1] = bytes(hex"04bf6ee64a8d2fdc551ec8bb9ef862ef6b4bcb1805cdc520c3aa5866c0575fd3b514c5562c3caae7aec5cd6f144b57135c75b6f6cea059c3d08d1f39a9c227219d"); - validators[2] = bytes(hex"04ba5734d8f7091719471e7f7ed6b9df170dc70cc661ca05e688601ad984f068b0d67351e5f06073092499336ab0839ef8a521afd334e53807205fa2f08eec74f4"); - validators[3] = bytes(hex"043255458e24278e31d5940f304b16300fdff3f6efd3e2a030b5818310ac67af45e28d057e6a332d07e0c5ab09d6947fd4eed1a646edbf224e2d2fec6f49f90abc"); + validators[0] = bytes( + hex"048318535b54105d4a7aae60c08fc45f9687181b4fdfc625bd1a753fa7397fed753547f11ca8696646f2f3acb08e31016afac23e630c5d11f59f61fef57b0d2aa5" + ); + validators[1] = bytes( + hex"04bf6ee64a8d2fdc551ec8bb9ef862ef6b4bcb1805cdc520c3aa5866c0575fd3b514c5562c3caae7aec5cd6f144b57135c75b6f6cea059c3d08d1f39a9c227219d" + ); + validators[2] = bytes( + hex"04ba5734d8f7091719471e7f7ed6b9df170dc70cc661ca05e688601ad984f068b0d67351e5f06073092499336ab0839ef8a521afd334e53807205fa2f08eec74f4" + ); + validators[3] = bytes( + hex"043255458e24278e31d5940f304b16300fdff3f6efd3e2a030b5818310ac67af45e28d057e6a332d07e0c5ab09d6947fd4eed1a646edbf224e2d2fec6f49f90abc" + ); adapterTarget.updateValidators(validators, 4); adapterTarget.listValidators(); vm.stopPrank(); vm.startPrank(destination_relayer); vm.expectEmit(); - emit CallMessage(iconDapp, ParseAddress.toString(address(dappSource)), 1, 1, data); - vm.expectCall(address(xCallTarget), abi.encodeCall(xCallTarget.handleMessage, (nidSource,RLPEncodeStruct.encodeCSMessage(message)))); - bytes[] memory signatures = new bytes[](4) ; - signatures[0] = signMessage(pk,hash); - signatures[1] = signMessage(pk2,hash); - signatures[2] = signMessage(pk3,hash); - signatures[3] = signMessage(pk4,hash); + emit CallMessage( + iconDapp, + ParseAddress.toString(address(dappSource)), + 1, + 1, + data + ); + vm.expectCall( + address(xCallTarget), + abi.encodeCall( + xCallTarget.handleMessage, + (nidSource, RLPEncodeStruct.encodeCSMessage(message)) + ) + ); + bytes[] memory signatures = new bytes[](4); + signatures[0] = signMessage(pk, hash); + signatures[1] = signMessage(pk2, hash); + signatures[2] = signMessage(pk3, hash); + signatures[3] = signMessage(pk4, hash); adapterTarget.recvMessageWithSignatures( nidSource, 1, @@ -245,20 +273,29 @@ contract ClusterConnectionTest is Test { vm.stopPrank(); } - function signMessage(uint256 pk,bytes32 hash) private pure returns (bytes memory){ + function signMessage( + uint256 pk, + bytes32 hash + ) private pure returns (bytes memory) { (uint8 v, bytes32 r, bytes32 s) = vm.sign(pk, hash); address signer = vm.addr(pk); - bytes memory signature = combineSignature(r,s,v); + bytes memory signature = combineSignature(r, s, v); - address recoverSigner=ecrecover(hash,v,r,s); + address recoverSigner = ecrecover(hash, v, r, s); return signature; } - - function combineSignature(bytes32 r, bytes32 s, uint8 v) private pure returns (bytes memory) { + + function combineSignature( + bytes32 r, + bytes32 s, + uint8 v + ) private pure returns (bytes memory) { return abi.encodePacked(r, s, v); } - function hexStringToUint256(string memory hexString) public pure returns (uint256) { + function hexStringToUint256( + string memory hexString + ) public pure returns (uint256) { bytes memory hexBytes = bytes(hexString); uint256 number = 0; @@ -283,17 +320,24 @@ contract ClusterConnectionTest is Test { vm.startPrank(owner); bytes[] memory validators = new bytes[](4); - validators[0] = bytes(hex"048318535b54105d4a7aae60c08fc45f9687181b4fdfc625bd1a753fa7397fed753547f11ca8696646f2f3acb08e31016afac23e630c5d11f59f61fef57b0d2aa5"); - validators[1] = bytes(hex"04bf6ee64a8d2fdc551ec8bb9ef862ef6b4bcb1805cdc520c3aa5866c0575fd3b514c5562c3caae7aec5cd6f144b57135c75b6f6cea059c3d08d1f39a9c227219d"); - validators[2] = bytes(hex"04ba5734d8f7091719471e7f7ed6b9df170dc70cc661ca05e688601ad984f068b0d67351e5f06073092499336ab0839ef8a521afd334e53807205fa2f08eec74f4"); - validators[3] = bytes(hex"043255458e24278e31d5940f304b16300fdff3f6efd3e2a030b5818310ac67af45e28d057e6a332d07e0c5ab09d6947fd4eed1a646edbf224e2d2fec6f49f90abc"); + validators[0] = bytes( + hex"048318535b54105d4a7aae60c08fc45f9687181b4fdfc625bd1a753fa7397fed753547f11ca8696646f2f3acb08e31016afac23e630c5d11f59f61fef57b0d2aa5" + ); + validators[1] = bytes( + hex"04bf6ee64a8d2fdc551ec8bb9ef862ef6b4bcb1805cdc520c3aa5866c0575fd3b514c5562c3caae7aec5cd6f144b57135c75b6f6cea059c3d08d1f39a9c227219d" + ); + validators[2] = bytes( + hex"04ba5734d8f7091719471e7f7ed6b9df170dc70cc661ca05e688601ad984f068b0d67351e5f06073092499336ab0839ef8a521afd334e53807205fa2f08eec74f4" + ); + validators[3] = bytes( + hex"043255458e24278e31d5940f304b16300fdff3f6efd3e2a030b5818310ac67af45e28d057e6a332d07e0c5ab09d6947fd4eed1a646edbf224e2d2fec6f49f90abc" + ); adapterTarget.updateValidators(validators, 4); console2.log(adapterTarget.listValidators()[0]); assertEq(4, adapterTarget.listValidators().length); vm.stopPrank(); } - function testRequiredCount() public { vm.startPrank(owner); adapterTarget.setRequiredValidatorCount(3); @@ -319,18 +363,24 @@ contract ClusterConnectionTest is Test { Types.CS_REQUEST, request.encodeCSMessageRequestV2() ); - uint256 pk = hexStringToUint256("ac0974bec39a17e36ba4a6b4d238ff944bacb478cbed5efcae784d7bf4f2ff80"); + uint256 pk = hexStringToUint256( + "ac0974bec39a17e36ba4a6b4d238ff944bacb478cbed5efcae784d7bf4f2ff80" + ); bytes32 hash = keccak256(RLPEncodeStruct.encodeCSMessage(message)); vm.startPrank(owner); bytes[] memory validators = new bytes[](2); - validators[0] = bytes(hex"048318535b54105d4a7aae60c08fc45f9687181b4fdfc625bd1a753fa7397fed753547f11ca8696646f2f3acb08e31016afac23e630c5d11f59f61fef57b0d2aa5"); - validators[1] = bytes(hex"04bf6ee64a8d2fdc551ec8bb9ef862ef6b4bcb1805cdc520c3aa5866c0575fd3b514c5562c3caae7aec5cd6f144b57135c75b6f6cea059c3d08d1f39a9c227219d"); + validators[0] = bytes( + hex"048318535b54105d4a7aae60c08fc45f9687181b4fdfc625bd1a753fa7397fed753547f11ca8696646f2f3acb08e31016afac23e630c5d11f59f61fef57b0d2aa5" + ); + validators[1] = bytes( + hex"04bf6ee64a8d2fdc551ec8bb9ef862ef6b4bcb1805cdc520c3aa5866c0575fd3b514c5562c3caae7aec5cd6f144b57135c75b6f6cea059c3d08d1f39a9c227219d" + ); adapterTarget.updateValidators(validators, 2); vm.stopPrank(); vm.startPrank(destination_relayer); - bytes[] memory signatures = new bytes[](2) ; - signatures[0] = signMessage(pk,hash); - signatures[1] = signMessage(pk,hash); + bytes[] memory signatures = new bytes[](2); + signatures[0] = signMessage(pk, hash); + signatures[1] = signMessage(pk, hash); vm.expectRevert("Not enough valid signatures passed"); adapterTarget.recvMessageWithSignatures( nidSource, @@ -341,12 +391,19 @@ contract ClusterConnectionTest is Test { vm.stopPrank(); } - function getMessageHash(string memory srcNetwork, uint256 _connSn, bytes memory _msg) internal pure returns (bytes32) { - bytes memory rlp = abi.encodePacked( - srcNetwork.encodeString(), - _connSn.encodeUint(), - _msg.encodeBytes() - ).encodeList(); - return keccak256(rlp); + function getMessageHash( + string memory srcNetwork, + uint256 _connSn, + bytes memory _msg, + string memory dstNetwork + ) internal pure returns (bytes32) { + bytes memory encoded = abi + .encodePacked( + srcNetwork, + _connSn.toString(), + _msg, + dstNetwork + ); + return keccak256(encoded); } -} +} \ No newline at end of file diff --git a/contracts/javascore/cluster-connection/src/main/java/xcall/adapter/cluster/ClusterConnection.java b/contracts/javascore/cluster-connection/src/main/java/xcall/adapter/cluster/ClusterConnection.java index 6546187fa..42bcbf882 100644 --- a/contracts/javascore/cluster-connection/src/main/java/xcall/adapter/cluster/ClusterConnection.java +++ b/contracts/javascore/cluster-connection/src/main/java/xcall/adapter/cluster/ClusterConnection.java @@ -31,23 +31,21 @@ import score.annotation.External; import score.annotation.Payable; -import java.util.Arrays; import java.util.List; - - public class ClusterConnection { protected final VarDB
xCall = Context.newVarDB("callService", Address.class); protected final VarDB
adminAddress = Context.newVarDB("admin", Address.class); protected final VarDB
relayerAddress = Context.newVarDB("relayer", Address.class); protected final VarDB validatorsThreshold = Context.newVarDB("reqValidatorCnt", BigInteger.class); private final VarDB connSn = Context.newVarDB("connSn", BigInteger.class); - private final ArrayDB validators = Context.newArrayDB("signers", String.class); + private final ArrayDB validators = Context.newArrayDB("signers", String.class); protected final DictDB messageFees = Context.newDictDB("messageFees", BigInteger.class); protected final DictDB responseFees = Context.newDictDB("responseFees", BigInteger.class); protected final BranchDB> receipts = Context.newBranchDB("receipts", Boolean.class); + public ClusterConnection(Address _relayer, Address _xCall) { if (xCall.get() == null) { xCall.set(_xCall); @@ -57,7 +55,7 @@ public ClusterConnection(Address _relayer, Address _xCall) { } } - /** + /** * Retrieves the validators. * * @return The validators . @@ -65,30 +63,31 @@ public ClusterConnection(Address _relayer, Address _xCall) { @External(readonly = true) public String[] listValidators() { String[] sgs = new String[validators.size()]; - for(int i = 0; i < validators.size(); i++) { + for (int i = 0; i < validators.size(); i++) { sgs[i] = validators.get(i); } return sgs; } -/** - * Adds a list of validators and sets the validation threshold. - * - * Clears existing validators and adds the provided addresses as validators. - * Ensures that the caller is an admin and that the number of validators - * meets or exceeds the specified threshold. - * - * @param _validators an array of compressed publickey bytes to be added as validators - * @param _threshold the minimum required number of validators - * @throws Exception if the number of validators is less than the threshold - */ + /** + * Adds a list of validators and sets the validation threshold. + * + * Clears existing validators and adds the provided addresses as validators. + * Ensures that the caller is an admin and that the number of validators + * meets or exceeds the specified threshold. + * + * @param _validators an array of compressed publickey bytes to be added as + * validators + * @param _threshold the minimum required number of validators + * @throws Exception if the number of validators is less than the threshold + */ @External public void updateValidators(byte[][] _validators, BigInteger _threshold) { OnlyAdmin(); clearValidators(); for (byte[] validator : _validators) { String hexValidator = bytesToHex(validator); - if(!isValidator(hexValidator)) { + if (!isValidator(hexValidator)) { validators.add(bytesToHex(validator)); } } @@ -103,20 +102,20 @@ public void updateValidators(byte[][] _validators, BigInteger _threshold) { * This is a private helper method called by addValidator. */ private void clearValidators() { - for(int i = 0; i < validators.size(); i++) { + for (int i = 0; i < validators.size(); i++) { validators.set(i, null); } } -/** - * Checks if the provided compressed pubkey bytes is a validator. - * - * @param validator the compressed publickey bytes to check for validation - * @return true if the compressed pubkey bytes is a validator, false otherwise - */ + /** + * Checks if the provided compressed pubkey bytes is a validator. + * + * @param validator the compressed publickey bytes to check for validation + * @return true if the compressed pubkey bytes is a validator, false otherwise + */ private boolean isValidator(String validator) { - for(int i = 0; i < validators.size(); i++) { - if(validator.equals(validators.get(i))) { + for (int i = 0; i < validators.size(); i++) { + if (validator.equals(validators.get(i))) { return true; } } @@ -140,8 +139,8 @@ public void ValidatorSetAdded(String _validators, BigInteger _threshold) { public void setRelayer(Address _relayer) { OnlyAdmin(); relayerAddress.set(_relayer); - } - + } + /** * Sets the admin address. * @@ -174,7 +173,7 @@ public void setRequiredValidatorCount(BigInteger _validatorCnt) { validatorsThreshold.set(_validatorCnt); } - /** + /** * Retrieves the required validator count. * * @return The required validator count. @@ -248,26 +247,33 @@ public void sendMessage(String to, String svc, BigInteger sn, byte[] msg) { * @param srcNetwork the source network id from which the message is received * @param _connSn the serial number of the connection message * @param msg serialized bytes of Service Message + * @param dstNetwork the destination network id * @param signatures array of signatures */ - @External - public void recvMessageWithSignatures(String srcNetwork, BigInteger _connSn, byte[] msg, - byte[][] signatures) { - OnlyRelayer(); - Context.require(signatures.length >= validatorsThreshold.get().intValue(), "Not enough signatures"); - byte[] messageHash = getMessageHash(srcNetwork, _connSn, msg); - List uniqueValidators = new ArrayList<>(); - for (byte[] signature : signatures) { - byte[] validator = getValidator(messageHash, signature); - String hexValidator = bytesToHex(validator); - Context.require(isValidator(hexValidator), "Invalid signature provided"); - if (!uniqueValidators.contains(hexValidator)) { - uniqueValidators.add(hexValidator); - } - } - Context.require(uniqueValidators.size() >= validatorsThreshold.get().intValue(), "Not enough valid signatures"); - recvMessage(srcNetwork, _connSn, msg); - } + @External + public void recvMessageWithSignatures( + String srcNetwork, + BigInteger _connSn, + byte[] msg, + byte[][] signatures) { + OnlyRelayer(); + Context.require(signatures.length >= validatorsThreshold.get().intValue(), "Not enough signatures"); + + String dstNetwork = Context.call(String.class, xCall.get(), "getNetworkId"); + + byte[] messageHash = getMessageHash(srcNetwork, _connSn, msg, dstNetwork); + List uniqueValidators = new ArrayList<>(); + for (byte[] signature : signatures) { + byte[] validator = getValidator(messageHash, signature); + String hexValidator = bytesToHex(validator); + Context.require(isValidator(hexValidator), "Invalid signature provided"); + if (!uniqueValidators.contains(hexValidator)) { + uniqueValidators.add(hexValidator); + } + } + Context.require(uniqueValidators.size() >= validatorsThreshold.get().intValue(), "Not enough valid signatures"); + recvMessage(srcNetwork, _connSn, msg); + } private void recvMessage(String srcNetwork, BigInteger _connSn, byte[] msg) { Context.require(!receipts.at(srcNetwork).getOrDefault(_connSn, false), "Duplicate Message"); @@ -278,16 +284,16 @@ private void recvMessage(String srcNetwork, BigInteger _connSn, byte[] msg) { private String bytesToHex(byte[] bytes) { StringBuilder hexString = new StringBuilder(); for (byte b : bytes) { - String hex = Integer.toHexString(0xff & b); // Mask with 0xff to handle negative values correctly + String hex = Integer.toHexString(0xff & b); // Mask with 0xff to handle negative values correctly if (hex.length() == 1) { - hexString.append('0'); // Add a leading zero if hex length is 1 + hexString.append('0'); // Add a leading zero if hex length is 1 } hexString.append(hex); } return hexString.toString(); } - private byte[] getValidator(byte[] msg, byte[] sig){ + private byte[] getValidator(byte[] msg, byte[] sig) { return Context.recoverKey("ecdsa-secp256k1", msg, sig, false); } @@ -349,16 +355,27 @@ private void OnlyAdmin() { * @param srcNetwork the source network id * @param _connSn the serial number of connection message * @param msg the message to hash + * @param dstNetwork the destination network id * @return the hash of the message */ - private byte[] getMessageHash(String srcNetwork, BigInteger _connSn, byte[] msg) { - ByteArrayObjectWriter writer = Context.newByteArrayObjectWriter("RLPn"); - writer.beginList(3); - writer.write(srcNetwork); - writer.write(_connSn); - writer.write(msg); - writer.end(); - return Context.hash("keccak-256", writer.toByteArray()); + private byte[] getMessageHash(String srcNetwork, BigInteger _connSn, byte[] msg, String dstNetwork) { + byte[] result = concatBytes(srcNetwork.getBytes(), String.valueOf(_connSn).getBytes(), msg, dstNetwork.getBytes()); + return Context.hash("keccak-256", result); + } + + private static byte[] concatBytes(byte[]... arrays) { + int totalLength = 0; + for (byte[] array : arrays) { + totalLength += array.length; + } + byte[] result = new byte[totalLength]; + int currentIndex = 0; + for (byte[] array : arrays) { + System.arraycopy(array, 0, result, currentIndex, array.length); + currentIndex += array.length; + } + return result; } + } \ No newline at end of file diff --git a/contracts/javascore/cluster-connection/src/test/java/xcall/adapter/cluster/ClusterConnectionTest.java b/contracts/javascore/cluster-connection/src/test/java/xcall/adapter/cluster/ClusterConnectionTest.java index c856f5a79..5d28e744d 100644 --- a/contracts/javascore/cluster-connection/src/test/java/xcall/adapter/cluster/ClusterConnectionTest.java +++ b/contracts/javascore/cluster-connection/src/test/java/xcall/adapter/cluster/ClusterConnectionTest.java @@ -5,11 +5,9 @@ import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.*; -import java.nio.charset.StandardCharsets; import java.security.*; import java.math.BigInteger; -import java.util.Arrays; import score.Context; @@ -17,8 +15,6 @@ import foundation.icon.icx.KeyWallet; - - import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.mockito.MockedStatic; @@ -30,12 +26,10 @@ import com.iconloop.score.test.TestBase; import score.UserRevertedException; -import score.Address; import score.ByteArrayObjectWriter; import foundation.icon.xcall.CallService; import foundation.icon.xcall.CallServiceScoreInterface; - import xcall.icon.test.MockContract; public class ClusterConnectionTest extends TestBase { @@ -65,7 +59,6 @@ public void setup() throws Exception { Security.addProvider(new BouncyCastleProvider()); callservice = new MockContract<>(CallServiceScoreInterface.class, CallService.class, sm, owner); - connection = sm.deploy(owner, ClusterConnection.class, source_relayer.getAddress(), callservice.getAddress()); connectionSpy = (ClusterConnection) spy(connection.getInstance()); @@ -112,47 +105,51 @@ public void testRevertMessage() { } @Test - public void testRevertMessage_unauthorized(){ - UserRevertedException e = assertThrows(UserRevertedException.class, ()->connection.invoke(user, "revertMessage", BigInteger.ONE)); - assertEquals("Reverted(0): "+"Only relayer can call this function", e.getMessage()); - + public void testRevertMessage_unauthorized() { + UserRevertedException e = assertThrows(UserRevertedException.class, + () -> connection.invoke(user, "revertMessage", BigInteger.ONE)); + assertEquals("Reverted(0): " + "Only relayer can call this function", e.getMessage()); + } @Test - public void testSetFeesUnauthorized(){ - UserRevertedException e = assertThrows(UserRevertedException.class,() -> connection.invoke(user, "setFee", "0xevm", - BigInteger.TEN, BigInteger.TEN)); - assertEquals("Reverted(0): "+"Only relayer can call this function", e.getMessage()); + public void testSetFeesUnauthorized() { + UserRevertedException e = assertThrows(UserRevertedException.class, + () -> connection.invoke(user, "setFee", "0xevm", + BigInteger.TEN, BigInteger.TEN)); + assertEquals("Reverted(0): " + "Only relayer can call this function", e.getMessage()); } @Test - public void testClaimFees(){ + public void testClaimFees() { setFee(); connection.invoke(source_relayer, "claimFees"); assertEquals(source_relayer.getBalance(), BigInteger.ZERO); - UserRevertedException e = assertThrows(UserRevertedException.class,() -> connection.invoke(callservice.account, "sendMessage", nidTarget, - "xcall", BigInteger.ONE, "null".getBytes())); + UserRevertedException e = assertThrows(UserRevertedException.class, + () -> connection.invoke(callservice.account, "sendMessage", nidTarget, + "xcall", BigInteger.ONE, "null".getBytes())); assertEquals(e.getMessage(), "Reverted(0): Insufficient balance"); try (MockedStatic contextMock = Mockito.mockStatic(Context.class, Mockito.CALLS_REAL_METHODS)) { contextMock.when(() -> Context.getValue()).thenReturn(BigInteger.valueOf(20)); - connection.invoke(callservice.account, "sendMessage", nidTarget,"xcall", BigInteger.ONE, "null".getBytes()); + connection.invoke(callservice.account, "sendMessage", nidTarget, "xcall", BigInteger.ONE, + "null".getBytes()); } - try (MockedStatic contextMock = Mockito.mockStatic(Context.class, Mockito.CALLS_REAL_METHODS)) { contextMock.when(() -> Context.getBalance(connection.getAddress())).thenReturn(BigInteger.valueOf(20)); - contextMock.when(() -> Context.transfer(source_relayer.getAddress(),BigInteger.valueOf(20))).then(invocationOnMock -> null); + contextMock.when(() -> Context.transfer(source_relayer.getAddress(), BigInteger.valueOf(20))) + .then(invocationOnMock -> null); connection.invoke(source_relayer, "claimFees"); } } @Test - public void testClaimFees_unauthorized(){ + public void testClaimFees_unauthorized() { setFee(); - UserRevertedException e = assertThrows(UserRevertedException.class,() -> connection.invoke(user, "claimFees")); - assertEquals(e.getMessage(), "Reverted(0): "+"Only relayer can call this function"); + UserRevertedException e = assertThrows(UserRevertedException.class, () -> connection.invoke(user, "claimFees")); + assertEquals(e.getMessage(), "Reverted(0): " + "Only relayer can call this function"); } public MockedStatic.Verification value() { @@ -160,25 +157,28 @@ public MockedStatic.Verification value() { } @Test - public void testRecvMessageWithSignatures() throws Exception{ + public void testRecvMessageWithSignatures() throws Exception { byte[] data = "test".getBytes(); - byte[] messageHash = getMessageHash(nidSource, BigInteger.ONE, data); + byte[] messageHash = getMessageHash(nidSource, BigInteger.ONE, data, nidTarget); byte[][] byteArray = new byte[1][]; KeyWallet wallet = KeyWallet.create(); byteArray[0] = wallet.sign(messageHash); byte[][] validators = new byte[][] { wallet.getPublicKey().toByteArray(), }; - connection.invoke(owner, "updateValidators", validators, BigInteger.ONE); - connection.invoke(source_relayer, "recvMessageWithSignatures", nidSource, BigInteger.ONE, data, byteArray); + + connection.invoke(owner, "updateValidators", validators, BigInteger.ONE); + + when(callservice.mock.getNetworkId()).thenReturn(nidTarget); + connection.invoke(source_relayer, "recvMessageWithSignatures", nidSource, BigInteger.ONE, data, + byteArray); verify(callservice.mock).handleMessage(eq(nidSource), eq("test".getBytes())); } - @Test - public void testRecvMessageWithMultiSignatures() throws Exception{ + public void testRecvMessageWithMultiSignatures() throws Exception { byte[] data = "test".getBytes(); - byte[] messageHash = getMessageHash(nidSource, BigInteger.ONE, data); + byte[] messageHash = getMessageHash(nidSource, BigInteger.ONE, data, nidTarget); byte[][] byteArray = new byte[2][]; KeyWallet wallet = KeyWallet.create(); KeyWallet wallet2 = KeyWallet.create(); @@ -188,15 +188,17 @@ public void testRecvMessageWithMultiSignatures() throws Exception{ wallet.getPublicKey().toByteArray(), wallet2.getPublicKey().toByteArray(), }; - connection.invoke(owner, "updateValidators", validators, BigInteger.TWO); - connection.invoke(source_relayer, "recvMessageWithSignatures", nidSource, BigInteger.ONE, data, byteArray); + connection.invoke(owner, "updateValidators", validators, BigInteger.TWO); + when(callservice.mock.getNetworkId()).thenReturn(nidTarget); + connection.invoke(source_relayer, "recvMessageWithSignatures", nidSource, BigInteger.ONE, data, + byteArray); verify(callservice.mock).handleMessage(eq(nidSource), eq("test".getBytes())); } @Test - public void testRecvMessageWithSignaturesNotEnoughSignatures() throws Exception{ + public void testRecvMessageWithSignaturesNotEnoughSignatures() throws Exception { byte[] data = "test".getBytes(); - byte[] messageHash = getMessageHash(nidSource, BigInteger.ONE, data); + byte[] messageHash = getMessageHash(nidSource, BigInteger.ONE, data, nidTarget); KeyWallet wallet = KeyWallet.create(); KeyWallet wallet2 = KeyWallet.create(); byte[][] byteArray = new byte[1][]; @@ -206,16 +208,18 @@ public void testRecvMessageWithSignaturesNotEnoughSignatures() throws Exception{ wallet2.getPublicKey().toByteArray(), }; connection.invoke(owner, "updateValidators", validators, BigInteger.TWO); + when(callservice.mock.getNetworkId()).thenReturn(nidTarget); UserRevertedException e = assertThrows(UserRevertedException.class, - ()->connection.invoke(source_relayer, "recvMessageWithSignatures", nidSource, BigInteger.ONE, data, byteArray)); + () -> connection.invoke(source_relayer, "recvMessageWithSignatures", nidSource, BigInteger.ONE, data, + byteArray)); assertEquals("Reverted(0): Not enough signatures", e.getMessage()); verifyNoInteractions(callservice.mock); } @Test - public void testRecvMessageWithSignaturesNotEnoughValidSignatures() throws Exception{ + public void testRecvMessageWithSignaturesNotEnoughValidSignatures() throws Exception { byte[] data = "test".getBytes(); - byte[] messageHash = getMessageHash(nidSource, BigInteger.ONE, data); + byte[] messageHash = getMessageHash(nidSource, BigInteger.ONE, data, nidTarget); KeyWallet wallet = KeyWallet.create(); KeyWallet wallet2 = KeyWallet.create(); byte[][] byteArray = new byte[2][]; @@ -226,24 +230,16 @@ public void testRecvMessageWithSignaturesNotEnoughValidSignatures() throws Excep wallet2.getPublicKey().toByteArray(), }; connection.invoke(owner, "updateValidators", validators, BigInteger.TWO); + + when(callservice.mock.getNetworkId()).thenReturn(nidTarget); UserRevertedException e = assertThrows(UserRevertedException.class, - ()->connection.invoke(source_relayer, "recvMessageWithSignatures", nidSource, BigInteger.ONE, data, byteArray)); + () -> connection.invoke(source_relayer, "recvMessageWithSignatures", nidSource, BigInteger.ONE, data, + byteArray)); assertEquals("Reverted(0): Not enough valid signatures", e.getMessage()); - verifyNoInteractions(callservice.mock); } - - public static byte[] getMessageHash(String srcNetwork, BigInteger _connSn, byte[] msg) { - ByteArrayObjectWriter writer = Context.newByteArrayObjectWriter("RLPn"); - writer.beginList(3); - writer.write(srcNetwork); - writer.write(_connSn); - writer.write(msg); - writer.end(); - return Context.hash("keccak-256", writer.toByteArray()); - } @Test - public void testAddSigners() throws Exception{ + public void testAddSigners() throws Exception { KeyWallet wallet = KeyWallet.create(); KeyWallet wallet2 = KeyWallet.create(); byte[][] validators = new byte[][] { @@ -251,8 +247,28 @@ public void testAddSigners() throws Exception{ wallet2.getPublicKey().toByteArray(), }; connection.invoke(owner, "updateValidators", validators, BigInteger.TWO); - String[] signers = connection.call(String[].class,"listValidators"); + String[] signers = connection.call(String[].class, "listValidators"); assertEquals(signers.length, 2); } + private byte[] getMessageHash(String srcNetwork, BigInteger _connSn, byte[] msg, String dstNetwork) { + byte[] result = concatBytes(srcNetwork.getBytes(), String.valueOf(_connSn).getBytes(), msg, dstNetwork.getBytes()); + return Context.hash("keccak-256", result); + } + + private static byte[] concatBytes(byte[]... arrays) { + int totalLength = 0; + for (byte[] array : arrays) { + totalLength += array.length; + } + byte[] result = new byte[totalLength]; + int currentIndex = 0; + for (byte[] array : arrays) { + System.arraycopy(array, 0, result, currentIndex, array.length); + currentIndex += array.length; + } + return result; + } + + } \ No newline at end of file diff --git a/contracts/javascore/xcall-lib/src/main/java/foundation/icon/xcall/CallService.java b/contracts/javascore/xcall-lib/src/main/java/foundation/icon/xcall/CallService.java index 338fb2b0d..2a99a07a6 100644 --- a/contracts/javascore/xcall-lib/src/main/java/foundation/icon/xcall/CallService.java +++ b/contracts/javascore/xcall-lib/src/main/java/foundation/icon/xcall/CallService.java @@ -29,119 +29,130 @@ @ScoreClient @ScoreInterface public interface CallService { - /** - * The name of CallService. - */ - String NAME = "xcallM"; - - /*======== At the source CALL_BSH ========*/ - /** - * Sends a call message to the contract on the destination chain. - * - * @param _to The BTP address of the callee on the destination chain - * @param _data The calldata specific to the target contract - * @param _rollback (Optional) The data for restoring the caller state when an error occurred - * @return The serial number of the request - */ - @Payable - @External - BigInteger sendCallMessage(String _to, - byte[] _data, - @Optional byte[] _rollback, - @Optional String[] _sources, - @Optional String[] _destinations); - - /** - * Handles incoming Messages. - * - * @param _from String ( Network id of source network ) - * @param _msg Bytes ( serialized bytes of CallMessage ) - */ - @External - void handleMessage(String _from, byte[] _msg); - - /** - * Handle the error on delivering the message. - * - * @param _sn Integer ( serial number of the original message ) - */ - - @External - void handleError(BigInteger _sn); - - /** - * Notifies that the requested call message has been sent. - * - * @param _from The chain-specific address of the caller - * @param _to The BTP address of the callee on the destination chain - * @param _sn The serial number of the request - */ - @EventLog(indexed=3) - void CallMessageSent(Address _from, String _to, BigInteger _sn); - - /** - * Notifies that a response message has arrived for the `_sn` if the request was a two-way message. - * - * @param _sn The serial number of the previous request - * @param _code The response code - * {@code (0: Success, -1: Unknown generic failure, >=1: User defined error code)} - */ - @EventLog(indexed=1) - void ResponseMessage(BigInteger _sn, int _code); - - /** - * Notifies the user that a rollback operation is required for the request '_sn'. - * - * @param _sn The serial number of the previous request - */ - @EventLog(indexed=1) - void RollbackMessage(BigInteger _sn); - - /** - * Rollbacks the caller state of the request '_sn'. - * - * @param _sn The serial number of the previous request - */ - @External - void executeRollback(BigInteger _sn); - - /** - * Notifies that the rollback has been executed. - * - * @param _sn The serial number for the rollback - */ - @EventLog(indexed=1) - void RollbackExecuted(BigInteger _sn); - - /*======== At the destination CALL_BSH ========*/ - /** - * Notifies the user that a new call message has arrived. - * - * @param _from The BTP address of the caller on the source chain - * @param _to A string representation of the callee address - * @param _sn The serial number of the request from the source - * @param _reqId The request id of the destination chain - * @param _data The calldata - */ - @EventLog(indexed=3) - void CallMessage(String _from, String _to, BigInteger _sn, BigInteger _reqId, byte[] _data); - - /** - * Executes the requested call message. - * - * @param _reqId The request id - */ - @External - void executeCall(BigInteger _reqId, byte[] _data); - - /** - * Notifies that the call message has been executed. - * - * @param _reqId The request id for the call message - * @param _code The execution result code - * {@code (0: Success, -1: Unknown generic failure, >=1: User defined error code)} - * @param _msg The result message if any - */ - @EventLog(indexed=1) - void CallExecuted(BigInteger _reqId, int _code, String _msg); + /** + * The name of CallService. + */ + String NAME = "xcallM"; + + /* ======== At the source CALL_BSH ======== */ + /** + * Sends a call message to the contract on the destination chain. + * + * @param _to The BTP address of the callee on the destination chain + * @param _data The calldata specific to the target contract + * @param _rollback (Optional) The data for restoring the caller state when an + * error occurred + * @return The serial number of the request + */ + @Payable + @External + BigInteger sendCallMessage(String _to, + byte[] _data, + @Optional byte[] _rollback, + @Optional String[] _sources, + @Optional String[] _destinations); + + /** + * Handles incoming Messages. + * + * @param _from String ( Network id of source network ) + * @param _msg Bytes ( serialized bytes of CallMessage ) + */ + @External + void handleMessage(String _from, byte[] _msg); + + /** + * Handle the error on delivering the message. + * + * @param _sn Integer ( serial number of the original message ) + */ + + @External + void handleError(BigInteger _sn); + + /** + * Notifies that the requested call message has been sent. + * + * @param _from The chain-specific address of the caller + * @param _to The BTP address of the callee on the destination chain + * @param _sn The serial number of the request + */ + @EventLog(indexed = 3) + void CallMessageSent(Address _from, String _to, BigInteger _sn); + + /** + * Notifies that a response message has arrived for the `_sn` if the request was + * a two-way message. + * + * @param _sn The serial number of the previous request + * @param _code The response code + * {@code (0: Success, -1: Unknown generic failure, >=1: User defined error code)} + */ + @EventLog(indexed = 1) + void ResponseMessage(BigInteger _sn, int _code); + + /** + * Notifies the user that a rollback operation is required for the request + * '_sn'. + * + * @param _sn The serial number of the previous request + */ + @EventLog(indexed = 1) + void RollbackMessage(BigInteger _sn); + + /** + * Rollbacks the caller state of the request '_sn'. + * + * @param _sn The serial number of the previous request + */ + @External + void executeRollback(BigInteger _sn); + + /** + * Notifies that the rollback has been executed. + * + * @param _sn The serial number for the rollback + */ + @EventLog(indexed = 1) + void RollbackExecuted(BigInteger _sn); + + /* ======== At the destination CALL_BSH ======== */ + /** + * Notifies the user that a new call message has arrived. + * + * @param _from The BTP address of the caller on the source chain + * @param _to A string representation of the callee address + * @param _sn The serial number of the request from the source + * @param _reqId The request id of the destination chain + * @param _data The calldata + */ + @EventLog(indexed = 3) + void CallMessage(String _from, String _to, BigInteger _sn, BigInteger _reqId, byte[] _data); + + /** + * Executes the requested call message. + * + * @param _reqId The request id + */ + @External + void executeCall(BigInteger _reqId, byte[] _data); + + /** + * Notifies that the call message has been executed. + * + * @param _reqId The request id for the call message + * @param _code The execution result code + * {@code (0: Success, -1: Unknown generic failure, >=1: User defined error code)} + * @param _msg The result message if any + */ + @EventLog(indexed = 1) + void CallExecuted(BigInteger _reqId, int _code, String _msg); + + /** + * Returns the network id of the chain + * + * @return nid network id of the chain + */ + @External(readonly = true) + public String getNetworkId(); }