From 0e510c6d6a04faf40feacc57024044411c13dfa5 Mon Sep 17 00:00:00 2001 From: adam Date: Thu, 15 Aug 2024 16:24:08 -0400 Subject: [PATCH] feat: add signature validation hooks --- src/account/UpgradeableModularAccount.sol | 29 ++++++++-- src/interfaces/IModuleManager.sol | 2 - src/interfaces/IValidationHookModule.sol | 10 +--- src/modules/NativeTokenLimitModule.sol | 4 ++ .../permissionhooks/AllowlistModule.sol | 4 ++ test/account/PerHookData.t.sol | 47 ++++++++++++++++ test/account/UpgradeableModularAccount.t.sol | 8 ++- test/mocks/modules/ComprehensiveModule.sol | 4 ++ .../modules/MockAccessControlHookModule.sol | 12 ++++ test/mocks/modules/ValidationModuleMocks.sol | 4 ++ test/utils/AccountTestBase.sol | 56 +++++++++++++++---- 11 files changed, 154 insertions(+), 26 deletions(-) diff --git a/src/account/UpgradeableModularAccount.sol b/src/account/UpgradeableModularAccount.sol index 506706c5..3cc5d716 100644 --- a/src/account/UpgradeableModularAccount.sol +++ b/src/account/UpgradeableModularAccount.sol @@ -234,8 +234,6 @@ contract UpgradeableModularAccount is } /// @notice Initializes the account with a validation function added to the global pool. - /// TODO: remove and merge with regular initialization, after we figure out a better install/uninstall workflow - /// with user install configs. /// @dev This function is only callable once. function initializeWithValidation( ValidationConfig validationConfig, @@ -298,14 +296,35 @@ contract UpgradeableModularAccount is AccountStorage storage _storage = getAccountStorage(); ModuleEntity sigValidation = ModuleEntity.wrap(bytes24(signature)); + signature = signature[24:]; - (address module, uint32 entityId) = sigValidation.unpack(); if (!_storage.validationData[sigValidation].isSignatureValidation) { - revert SignatureValidationInvalid(module, entityId); + (address _module, uint32 _entityId) = sigValidation.unpack(); + revert SignatureValidationInvalid(_module, _entityId); + } + + ModuleEntity[] memory preSignatureValidationHooks = + getAccountStorage().validationData[sigValidation].preValidationHooks; + + for (uint256 i = 0; i < preSignatureValidationHooks.length; ++i) { + (address hookModule, uint32 hookEntityId) = preSignatureValidationHooks[i].unpack(); + + bytes memory currentSignatureSegment; + + (currentSignatureSegment, signature) = signature.advanceSegmentIfAtIndex(uint8(i)); + + // If this reverts, bubble up revert reason. + IValidationHookModule(hookModule).preSignatureValidationHook( + hookEntityId, msg.sender, hash, currentSignatureSegment + ); } + signature = signature.getFinalSegment(); + + (address module, uint32 entityId) = sigValidation.unpack(); + if ( - IValidationModule(module).validateSignature(address(this), entityId, msg.sender, hash, signature[24:]) + IValidationModule(module).validateSignature(address(this), entityId, msg.sender, hash, signature) == _1271_MAGIC_VALUE ) { return _1271_MAGIC_VALUE; diff --git a/src/interfaces/IModuleManager.sol b/src/interfaces/IModuleManager.sol index 6a435bf2..a171eb5a 100644 --- a/src/interfaces/IModuleManager.sol +++ b/src/interfaces/IModuleManager.sol @@ -30,7 +30,6 @@ interface IModuleManager { /// path. /// Installs a validation function across a set of execution selectors, and optionally mark it as a global /// validation. - /// TODO: remove or update. /// @dev This does not validate anything against the manifest - the caller must ensure validity. /// @param validationConfig The validation function to install, along with configuration flags. /// @param selectors The selectors to install the validation function for. @@ -46,7 +45,6 @@ interface IModuleManager { ) external; /// @notice Uninstall a validation function from a set of execution selectors. - /// TODO: remove or update. /// @param validationFunction The validation function to uninstall. /// @param uninstallData Optional data to be decoded and used by the module to clear module data for the /// account. diff --git a/src/interfaces/IValidationHookModule.sol b/src/interfaces/IValidationHookModule.sol index e7801266..54286eee 100644 --- a/src/interfaces/IValidationHookModule.sol +++ b/src/interfaces/IValidationHookModule.sol @@ -32,8 +32,6 @@ interface IValidationHookModule is IModule { bytes calldata authorization ) external; - // TODO: support this hook type within the account & in the manifest - /// @notice Run the pre signature validation hook specified by the `entityId`. /// @dev To indicate the call should revert, the function MUST revert. /// @param entityId An identifier that routes the call to different internal implementations, should there @@ -41,9 +39,7 @@ interface IValidationHookModule is IModule { /// @param sender The caller address. /// @param hash The hash of the message being signed. /// @param signature The signature of the message. - // function preSignatureValidationHook(uint32 entityId, address sender, bytes32 hash, bytes calldata - // signature) - // external - // view - // returns (bytes4); + function preSignatureValidationHook(uint32 entityId, address sender, bytes32 hash, bytes calldata signature) + external + view; } diff --git a/src/modules/NativeTokenLimitModule.sol b/src/modules/NativeTokenLimitModule.sol index 1007e8ae..976735bb 100644 --- a/src/modules/NativeTokenLimitModule.sol +++ b/src/modules/NativeTokenLimitModule.sol @@ -116,6 +116,10 @@ contract NativeTokenLimitModule is BaseModule, IExecutionHookModule, IValidation override {} // solhint-disable-line no-empty-blocks + function preSignatureValidationHook(uint32, address, bytes32, bytes calldata) external pure override { + return; + } + /// @inheritdoc IModule function moduleMetadata() external pure virtual override returns (ModuleMetadata memory) { ModuleMetadata memory metadata; diff --git a/src/modules/permissionhooks/AllowlistModule.sol b/src/modules/permissionhooks/AllowlistModule.sol index dd23a6c9..cdb8a108 100644 --- a/src/modules/permissionhooks/AllowlistModule.sol +++ b/src/modules/permissionhooks/AllowlistModule.sol @@ -85,6 +85,10 @@ contract AllowlistModule is IValidationHookModule, BaseModule { return; } + function preSignatureValidationHook(uint32, address, bytes32, bytes calldata) external pure override { + return; + } + function moduleMetadata() external pure override returns (ModuleMetadata memory) { ModuleMetadata memory metadata; metadata.name = "Allowlist Module"; diff --git a/test/account/PerHookData.t.sol b/test/account/PerHookData.t.sol index 9bce88ed..d180be2b 100644 --- a/test/account/PerHookData.t.sol +++ b/test/account/PerHookData.t.sol @@ -352,6 +352,53 @@ contract PerHookDataTest is CustomValidationTestBase { ); } + function test_pass1271AccessControl() public { + string memory message = "Hello, world!"; + + bytes32 messageHash = keccak256(abi.encodePacked(message)); + + (uint8 v, bytes32 r, bytes32 s) = vm.sign(owner1Key, messageHash); + + PreValidationHookData[] memory preValidationHookData = new PreValidationHookData[](1); + preValidationHookData[0] = PreValidationHookData({index: 0, validationData: abi.encodePacked(message)}); + + bytes4 result = account1.isValidSignature( + messageHash, _encode1271Signature(_signerValidation, preValidationHookData, abi.encodePacked(r, s, v)) + ); + + assertEq(result, bytes4(0x1626ba7e)); + } + + function test_fail1271AccessControl_badSigData() public { + string memory message = "Hello, world!"; + + bytes32 messageHash = keccak256(abi.encodePacked(message)); + + (uint8 v, bytes32 r, bytes32 s) = vm.sign(owner1Key, messageHash); + + PreValidationHookData[] memory preValidationHookData = new PreValidationHookData[](1); + preValidationHookData[0] = PreValidationHookData({ + index: 0, + validationData: abi.encodePacked(address(0x1234123412341234123412341234123412341234)) + }); + + vm.expectRevert("Preimage not provided"); + account1.isValidSignature( + messageHash, _encode1271Signature(_signerValidation, preValidationHookData, abi.encodePacked(r, s, v)) + ); + } + + function test_fail1271AccessControl_noSigData() public { + string memory message = "Hello, world!"; + + bytes32 messageHash = keccak256(abi.encodePacked(message)); + + (uint8 v, bytes32 r, bytes32 s) = vm.sign(owner1Key, messageHash); + + vm.expectRevert("Preimage not provided"); + account1.isValidSignature(messageHash, _encode1271Signature(_signerValidation, abi.encodePacked(r, s, v))); + } + function _getCounterUserOP() internal view returns (PackedUserOperation memory, bytes32) { PackedUserOperation memory userOp = PackedUserOperation({ sender: address(account1), diff --git a/test/account/UpgradeableModularAccount.t.sol b/test/account/UpgradeableModularAccount.t.sol index e4394f34..92ad67ba 100644 --- a/test/account/UpgradeableModularAccount.t.sol +++ b/test/account/UpgradeableModularAccount.t.sol @@ -12,6 +12,8 @@ import {MessageHashUtils} from "@openzeppelin/contracts/utils/cryptography/Messa import {ModuleManagerInternals} from "../../src/account/ModuleManagerInternals.sol"; import {UpgradeableModularAccount} from "../../src/account/UpgradeableModularAccount.sol"; +import {ModuleEntityLib} from "../../src/helpers/ModuleEntityLib.sol"; + import {ExecutionDataView, IAccountLoupe} from "../../src/interfaces/IAccountLoupe.sol"; import {ExecutionManifest} from "../../src/interfaces/IExecutionModule.sol"; import {IModuleManager} from "../../src/interfaces/IModuleManager.sol"; @@ -381,8 +383,10 @@ contract UpgradeableModularAccountTest is AccountTestBase { // singleSignerValidationModule.ownerOf(address(account1)); - bytes memory signature = - abi.encodePacked(address(singleSignerValidationModule), TEST_DEFAULT_VALIDATION_ENTITY_ID, r, s, v); + bytes memory signature = _encode1271Signature( + ModuleEntityLib.pack(address(singleSignerValidationModule), TEST_DEFAULT_VALIDATION_ENTITY_ID), + abi.encodePacked(r, s, v) + ); bytes4 validationResult = IERC1271(address(account1)).isValidSignature(message, signature); diff --git a/test/mocks/modules/ComprehensiveModule.sol b/test/mocks/modules/ComprehensiveModule.sol index b16e5c0f..3fb98041 100644 --- a/test/mocks/modules/ComprehensiveModule.sol +++ b/test/mocks/modules/ComprehensiveModule.sol @@ -104,6 +104,10 @@ contract ComprehensiveModule is revert NotImplemented(); } + function preSignatureValidationHook(uint32, address, bytes32, bytes calldata) external pure override { + return; + } + function validateSignature(address, uint32 entityId, address, bytes32, bytes calldata) external pure diff --git a/test/mocks/modules/MockAccessControlHookModule.sol b/test/mocks/modules/MockAccessControlHookModule.sol index 4ef33181..91b2b331 100644 --- a/test/mocks/modules/MockAccessControlHookModule.sol +++ b/test/mocks/modules/MockAccessControlHookModule.sol @@ -72,5 +72,17 @@ contract MockAccessControlHookModule is IValidationHookModule, BaseModule { revert NotImplemented(); } + function preSignatureValidationHook(uint32, address, bytes32 hash, bytes calldata signature) + external + pure + override + { + // Simulates some signature checking by requiring a preimage of the hash. + + require(keccak256(signature) == hash, "Preimage not provided"); + + return; + } + function moduleMetadata() external pure override returns (ModuleMetadata memory) {} } diff --git a/test/mocks/modules/ValidationModuleMocks.sol b/test/mocks/modules/ValidationModuleMocks.sol index 4fae17c9..25154a3d 100644 --- a/test/mocks/modules/ValidationModuleMocks.sol +++ b/test/mocks/modules/ValidationModuleMocks.sol @@ -64,6 +64,10 @@ abstract contract MockBaseUserOpValidationModule is revert NotImplemented(); } + function preSignatureValidationHook(uint32, address, bytes32, bytes calldata) external pure override { + return; + } + function validateSignature(address, uint32, address, bytes32, bytes calldata) external pure diff --git a/test/utils/AccountTestBase.sol b/test/utils/AccountTestBase.sol index 8641f79b..81b1ade7 100644 --- a/test/utils/AccountTestBase.sol +++ b/test/utils/AccountTestBase.sol @@ -206,6 +206,21 @@ abstract contract AccountTestBase is OptimizedTest { return bytes32(uint256((g1 << 128) + uint128(g2))); } + // helper function to encode a 1271 signature, according to the per-hook and per-validation data format. + function _encode1271Signature( + ModuleEntity validationFunction, + PreValidationHookData[] memory preValidationHookData, + bytes memory validationData + ) internal pure returns (bytes memory) { + bytes memory sig = abi.encodePacked(validationFunction); + + sig = abi.encodePacked(sig, _packPreHookDatas(preValidationHookData)); + + sig = abi.encodePacked(sig, _packValidationResWithIndex(255, validationData)); + + return sig; + } + // helper function to encode a signature, according to the per-hook and per-validation data format. function _encodeSignature( ModuleEntity validationFunction, @@ -215,17 +230,8 @@ abstract contract AccountTestBase is OptimizedTest { ) internal pure returns (bytes memory) { bytes memory sig = abi.encodePacked(validationFunction, globalOrNot); - for (uint256 i = 0; i < preValidationHookData.length; ++i) { - sig = abi.encodePacked( - sig, - _packValidationResWithIndex( - preValidationHookData[i].index, preValidationHookData[i].validationData - ) - ); - } + sig = abi.encodePacked(sig, _packPreHookDatas(preValidationHookData)); - // Index of the actual validation data is the length of the preValidationHooksRetrieved - aka - // one-past-the-end sig = abi.encodePacked(sig, _packValidationResWithIndex(255, validationData)); return sig; @@ -241,6 +247,36 @@ abstract contract AccountTestBase is OptimizedTest { return _encodeSignature(validationFunction, globalOrNot, emptyPreValidationHookData, validationData); } + // overload for the case where there are no pre-validation hooks + function _encode1271Signature(ModuleEntity validationFunction, bytes memory validationData) + internal + pure + returns (bytes memory) + { + PreValidationHookData[] memory emptyPreValidationHookData = new PreValidationHookData[](0); + return _encode1271Signature(validationFunction, emptyPreValidationHookData, validationData); + } + + // helper function to pack pre-validation hook datas, according to the sparse calldata segment spec. + function _packPreHookDatas(PreValidationHookData[] memory preValidationHookData) + internal + pure + returns (bytes memory) + { + bytes memory res = ""; + + for (uint256 i = 0; i < preValidationHookData.length; ++i) { + res = abi.encodePacked( + res, + _packValidationResWithIndex( + preValidationHookData[i].index, preValidationHookData[i].validationData + ) + ); + } + + return res; + } + // helper function to pack validation data with an index, according to the sparse calldata segment spec. function _packValidationResWithIndex(uint8 index, bytes memory validationData) internal