From 496af4ec1f6e7a65cc72d765b106b9f75045c00b Mon Sep 17 00:00:00 2001 From: adam Date: Thu, 15 Aug 2024 16:24:08 -0400 Subject: [PATCH] feat: add signature validation hooks --- src/account/UpgradeableModularAccount.sol | 21 ++++++- src/interfaces/IValidationHookModule.sol | 10 +--- src/modules/NativeTokenLimitModule.sol | 3 + .../permissionhooks/AllowlistModule.sol | 3 + test/account/PerHookData.t.sol | 47 ++++++++++++++++ test/account/UpgradeableModularAccount.t.sol | 2 +- test/mocks/modules/ComprehensiveModule.sol | 4 ++ .../modules/MockAccessControlHookModule.sol | 12 ++++ test/mocks/modules/ValidationModuleMocks.sol | 2 + test/utils/AccountTestBase.sol | 56 +++++++++++++++---- 10 files changed, 141 insertions(+), 19 deletions(-) diff --git a/src/account/UpgradeableModularAccount.sol b/src/account/UpgradeableModularAccount.sol index 04c133e3..8c897879 100644 --- a/src/account/UpgradeableModularAccount.sol +++ b/src/account/UpgradeableModularAccount.sol @@ -295,8 +295,27 @@ contract UpgradeableModularAccount is function isValidSignature(bytes32 hash, bytes calldata signature) public view override returns (bytes4) { ModuleEntity sigValidation = ModuleEntity.wrap(bytes24(signature)); + signature = signature[24:]; - return _exec1271Validation(sigValidation, hash, signature[24:]); + ModuleEntity[] memory preSignatureValidationHooks = + getAccountStorage().validationData[sigValidation].preValidationHooks; + + for (uint256 i = 0; i < preSignatureValidationHooks.length; ++i) { + (address hookModule, uint32 hookEntityId) = preSignatureValidationHooks[i].unpack(); + + bytes memory currentSignatureSegment; + + (currentSignatureSegment, signature) = signature.advanceSegmentIfAtIndex(uint8(i)); + + // If this reverts, bubble up revert reason. + IValidationHookModule(hookModule).preSignatureValidationHook( + hookEntityId, msg.sender, hash, currentSignatureSegment + ); + } + + signature = signature.getFinalSegment(); + + return _exec1271Validation(sigValidation, hash, signature); } /// @notice Gets the entry point for this account diff --git a/src/interfaces/IValidationHookModule.sol b/src/interfaces/IValidationHookModule.sol index e7801266..54286eee 100644 --- a/src/interfaces/IValidationHookModule.sol +++ b/src/interfaces/IValidationHookModule.sol @@ -32,8 +32,6 @@ interface IValidationHookModule is IModule { bytes calldata authorization ) external; - // TODO: support this hook type within the account & in the manifest - /// @notice Run the pre signature validation hook specified by the `entityId`. /// @dev To indicate the call should revert, the function MUST revert. /// @param entityId An identifier that routes the call to different internal implementations, should there @@ -41,9 +39,7 @@ interface IValidationHookModule is IModule { /// @param sender The caller address. /// @param hash The hash of the message being signed. /// @param signature The signature of the message. - // function preSignatureValidationHook(uint32 entityId, address sender, bytes32 hash, bytes calldata - // signature) - // external - // view - // returns (bytes4); + function preSignatureValidationHook(uint32 entityId, address sender, bytes32 hash, bytes calldata signature) + external + view; } diff --git a/src/modules/NativeTokenLimitModule.sol b/src/modules/NativeTokenLimitModule.sol index 1007e8ae..967f2ab3 100644 --- a/src/modules/NativeTokenLimitModule.sol +++ b/src/modules/NativeTokenLimitModule.sol @@ -116,6 +116,9 @@ contract NativeTokenLimitModule is BaseModule, IExecutionHookModule, IValidation override {} // solhint-disable-line no-empty-blocks + // solhint-disable-next-line no-empty-blocks + function preSignatureValidationHook(uint32, address, bytes32, bytes calldata) external pure override {} + /// @inheritdoc IModule function moduleMetadata() external pure virtual override returns (ModuleMetadata memory) { ModuleMetadata memory metadata; diff --git a/src/modules/permissionhooks/AllowlistModule.sol b/src/modules/permissionhooks/AllowlistModule.sol index dd23a6c9..0302a4a6 100644 --- a/src/modules/permissionhooks/AllowlistModule.sol +++ b/src/modules/permissionhooks/AllowlistModule.sol @@ -85,6 +85,9 @@ contract AllowlistModule is IValidationHookModule, BaseModule { return; } + // solhint-disable-next-line no-empty-blocks + function preSignatureValidationHook(uint32, address, bytes32, bytes calldata) external pure override {} + function moduleMetadata() external pure override returns (ModuleMetadata memory) { ModuleMetadata memory metadata; metadata.name = "Allowlist Module"; diff --git a/test/account/PerHookData.t.sol b/test/account/PerHookData.t.sol index 83941df9..505fd00f 100644 --- a/test/account/PerHookData.t.sol +++ b/test/account/PerHookData.t.sol @@ -355,6 +355,53 @@ contract PerHookDataTest is CustomValidationTestBase { ); } + function test_pass1271AccessControl() public { + string memory message = "Hello, world!"; + + bytes32 messageHash = keccak256(abi.encodePacked(message)); + + (uint8 v, bytes32 r, bytes32 s) = vm.sign(owner1Key, messageHash); + + PreValidationHookData[] memory preValidationHookData = new PreValidationHookData[](1); + preValidationHookData[0] = PreValidationHookData({index: 0, validationData: abi.encodePacked(message)}); + + bytes4 result = account1.isValidSignature( + messageHash, _encode1271Signature(_signerValidation, preValidationHookData, abi.encodePacked(r, s, v)) + ); + + assertEq(result, bytes4(0x1626ba7e)); + } + + function test_fail1271AccessControl_badSigData() public { + string memory message = "Hello, world!"; + + bytes32 messageHash = keccak256(abi.encodePacked(message)); + + (uint8 v, bytes32 r, bytes32 s) = vm.sign(owner1Key, messageHash); + + PreValidationHookData[] memory preValidationHookData = new PreValidationHookData[](1); + preValidationHookData[0] = PreValidationHookData({ + index: 0, + validationData: abi.encodePacked(address(0x1234123412341234123412341234123412341234)) + }); + + vm.expectRevert("Preimage not provided"); + account1.isValidSignature( + messageHash, _encode1271Signature(_signerValidation, preValidationHookData, abi.encodePacked(r, s, v)) + ); + } + + function test_fail1271AccessControl_noSigData() public { + string memory message = "Hello, world!"; + + bytes32 messageHash = keccak256(abi.encodePacked(message)); + + (uint8 v, bytes32 r, bytes32 s) = vm.sign(owner1Key, messageHash); + + vm.expectRevert("Preimage not provided"); + account1.isValidSignature(messageHash, _encode1271Signature(_signerValidation, abi.encodePacked(r, s, v))); + } + function _getCounterUserOP() internal view returns (PackedUserOperation memory, bytes32) { PackedUserOperation memory userOp = PackedUserOperation({ sender: address(account1), diff --git a/test/account/UpgradeableModularAccount.t.sol b/test/account/UpgradeableModularAccount.t.sol index e6b9d28a..7e9eb8b7 100644 --- a/test/account/UpgradeableModularAccount.t.sol +++ b/test/account/UpgradeableModularAccount.t.sol @@ -397,7 +397,7 @@ contract UpgradeableModularAccountTest is AccountTestBase { // singleSignerValidationModule.ownerOf(address(account1)); - bytes memory signature = abi.encodePacked(_signerValidation, r, s, v); + bytes memory signature = _encode1271Signature(_signerValidation, abi.encodePacked(r, s, v)); bytes4 validationResult = IERC1271(address(account1)).isValidSignature(message, signature); diff --git a/test/mocks/modules/ComprehensiveModule.sol b/test/mocks/modules/ComprehensiveModule.sol index b16e5c0f..3fb98041 100644 --- a/test/mocks/modules/ComprehensiveModule.sol +++ b/test/mocks/modules/ComprehensiveModule.sol @@ -104,6 +104,10 @@ contract ComprehensiveModule is revert NotImplemented(); } + function preSignatureValidationHook(uint32, address, bytes32, bytes calldata) external pure override { + return; + } + function validateSignature(address, uint32 entityId, address, bytes32, bytes calldata) external pure diff --git a/test/mocks/modules/MockAccessControlHookModule.sol b/test/mocks/modules/MockAccessControlHookModule.sol index 4ef33181..91b2b331 100644 --- a/test/mocks/modules/MockAccessControlHookModule.sol +++ b/test/mocks/modules/MockAccessControlHookModule.sol @@ -72,5 +72,17 @@ contract MockAccessControlHookModule is IValidationHookModule, BaseModule { revert NotImplemented(); } + function preSignatureValidationHook(uint32, address, bytes32 hash, bytes calldata signature) + external + pure + override + { + // Simulates some signature checking by requiring a preimage of the hash. + + require(keccak256(signature) == hash, "Preimage not provided"); + + return; + } + function moduleMetadata() external pure override returns (ModuleMetadata memory) {} } diff --git a/test/mocks/modules/ValidationModuleMocks.sol b/test/mocks/modules/ValidationModuleMocks.sol index 4fae17c9..a143e03a 100644 --- a/test/mocks/modules/ValidationModuleMocks.sol +++ b/test/mocks/modules/ValidationModuleMocks.sol @@ -64,6 +64,8 @@ abstract contract MockBaseUserOpValidationModule is revert NotImplemented(); } + function preSignatureValidationHook(uint32, address, bytes32, bytes calldata) external pure override {} + function validateSignature(address, uint32, address, bytes32, bytes calldata) external pure diff --git a/test/utils/AccountTestBase.sol b/test/utils/AccountTestBase.sol index 705690a8..05406b36 100644 --- a/test/utils/AccountTestBase.sol +++ b/test/utils/AccountTestBase.sol @@ -198,6 +198,21 @@ abstract contract AccountTestBase is OptimizedTest { return bytes32(uint256((g1 << 128) + uint128(g2))); } + // helper function to encode a 1271 signature, according to the per-hook and per-validation data format. + function _encode1271Signature( + ModuleEntity validationFunction, + PreValidationHookData[] memory preValidationHookData, + bytes memory validationData + ) internal pure returns (bytes memory) { + bytes memory sig = abi.encodePacked(validationFunction); + + sig = abi.encodePacked(sig, _packPreHookDatas(preValidationHookData)); + + sig = abi.encodePacked(sig, _packValidationResWithIndex(255, validationData)); + + return sig; + } + // helper function to encode a signature, according to the per-hook and per-validation data format. function _encodeSignature( ModuleEntity validationFunction, @@ -207,17 +222,8 @@ abstract contract AccountTestBase is OptimizedTest { ) internal pure returns (bytes memory) { bytes memory sig = abi.encodePacked(validationFunction, globalOrNot); - for (uint256 i = 0; i < preValidationHookData.length; ++i) { - sig = abi.encodePacked( - sig, - _packValidationResWithIndex( - preValidationHookData[i].index, preValidationHookData[i].validationData - ) - ); - } + sig = abi.encodePacked(sig, _packPreHookDatas(preValidationHookData)); - // Index of the actual validation data is the length of the preValidationHooksRetrieved - aka - // one-past-the-end sig = abi.encodePacked(sig, _packValidationResWithIndex(255, validationData)); return sig; @@ -233,6 +239,36 @@ abstract contract AccountTestBase is OptimizedTest { return _encodeSignature(validationFunction, globalOrNot, emptyPreValidationHookData, validationData); } + // overload for the case where there are no pre-validation hooks + function _encode1271Signature(ModuleEntity validationFunction, bytes memory validationData) + internal + pure + returns (bytes memory) + { + PreValidationHookData[] memory emptyPreValidationHookData = new PreValidationHookData[](0); + return _encode1271Signature(validationFunction, emptyPreValidationHookData, validationData); + } + + // helper function to pack pre-validation hook datas, according to the sparse calldata segment spec. + function _packPreHookDatas(PreValidationHookData[] memory preValidationHookData) + internal + pure + returns (bytes memory) + { + bytes memory res = ""; + + for (uint256 i = 0; i < preValidationHookData.length; ++i) { + res = abi.encodePacked( + res, + _packValidationResWithIndex( + preValidationHookData[i].index, preValidationHookData[i].validationData + ) + ); + } + + return res; + } + // helper function to pack validation data with an index, according to the sparse calldata segment spec. function _packValidationResWithIndex(uint8 index, bytes memory validationData) internal