diff --git a/src/account/ReferenceModularAccount.sol b/src/account/ReferenceModularAccount.sol index 9be42767..78aa23e9 100644 --- a/src/account/ReferenceModularAccount.sol +++ b/src/account/ReferenceModularAccount.sol @@ -52,6 +52,12 @@ contract ReferenceModularAccount is ModuleEntity postExecHook; } + enum ValidationCheckingType { + GLOBAL, + SELECTOR, + EITHER + } + IEntryPoint private immutable _ENTRY_POINT; // As per the EIP-165 spec, no interface should ever match 0xffffffff @@ -187,7 +193,11 @@ contract ReferenceModularAccount is // Check if the runtime validation function is allowed to be called bool isGlobalValidation = uint8(authorization[24]) == 1; - _checkIfValidationAppliesCallData(data, runtimeValidationFunction, isGlobalValidation); + _checkIfValidationAppliesCallData( + data, + runtimeValidationFunction, + isGlobalValidation ? ValidationCheckingType.GLOBAL : ValidationCheckingType.SELECTOR + ); _doRuntimeValidation(runtimeValidationFunction, data, authorization[25:]); @@ -343,7 +353,11 @@ contract ReferenceModularAccount is ModuleEntity userOpValidationFunction = ModuleEntity.wrap(bytes24(userOp.signature[:24])); bool isGlobalValidation = uint8(userOp.signature[24]) == 1; - _checkIfValidationAppliesCallData(userOp.callData, userOpValidationFunction, isGlobalValidation); + _checkIfValidationAppliesCallData( + userOp.callData, + userOpValidationFunction, + isGlobalValidation ? ValidationCheckingType.GLOBAL : ValidationCheckingType.SELECTOR + ); // Check if there are permission hooks associated with the validator, and revert if the call isn't to // `executeUserOp` @@ -549,7 +563,7 @@ contract ReferenceModularAccount is ModuleEntity directCallValidationKey = ModuleEntityLib.pack(msg.sender, DIRECT_CALL_VALIDATION_ENTITYID); - _checkIfValidationAppliesCallData(msg.data, directCallValidationKey, false); + _checkIfValidationAppliesCallData(msg.data, directCallValidationKey, ValidationCheckingType.EITHER); // Direct call is allowed, run associated permission & validation hooks @@ -645,7 +659,7 @@ contract ReferenceModularAccount is function _checkIfValidationAppliesCallData( bytes calldata callData, ModuleEntity validationFunction, - bool isGlobal + ValidationCheckingType checkingType ) internal view { bytes4 outerSelector = bytes4(callData[:4]); if (outerSelector == this.executeUserOp.selector) { @@ -655,7 +669,7 @@ contract ReferenceModularAccount is outerSelector = bytes4(callData[:4]); } - _checkIfValidationAppliesSelector(outerSelector, validationFunction, isGlobal); + _checkIfValidationAppliesSelector(outerSelector, validationFunction, checkingType); if (outerSelector == IModularAccount.execute.selector) { (address target,,) = abi.decode(callData[4:], (address, uint256, bytes)); @@ -689,26 +703,50 @@ contract ReferenceModularAccount is revert SelfCallRecursionDepthExceeded(); } - _checkIfValidationAppliesSelector(nestedSelector, validationFunction, isGlobal); + _checkIfValidationAppliesSelector(nestedSelector, validationFunction, checkingType); } } } } - function _checkIfValidationAppliesSelector(bytes4 selector, ModuleEntity validationFunction, bool isGlobal) - internal - view - { + function _checkIfValidationAppliesSelector( + bytes4 selector, + ModuleEntity validationFunction, + ValidationCheckingType checkingType + ) internal view { // Check that the provided validation function is applicable to the selector - if (isGlobal) { - if (!_globalValidationAllowed(selector) || !_isValidationGlobal(validationFunction)) { + + if (checkingType == ValidationCheckingType.GLOBAL) { + if (!_globalValidationApplies(selector, validationFunction)) { + revert ValidationFunctionMissing(selector); + } + } else if (checkingType == ValidationCheckingType.SELECTOR) { + if (!_selectorValidationApplies(selector, validationFunction)) { revert ValidationFunctionMissing(selector); } } else { - // Not global validation, but per-selector - if (!getAccountStorage().validationData[validationFunction].selectors.contains(toSetValue(selector))) { + if ( + !_globalValidationApplies(selector, validationFunction) + && !_selectorValidationApplies(selector, validationFunction) + ) { revert ValidationFunctionMissing(selector); } } } + + function _globalValidationApplies(bytes4 selector, ModuleEntity validationFunction) + internal + view + returns (bool) + { + return _globalValidationAllowed(selector) && _isValidationGlobal(validationFunction); + } + + function _selectorValidationApplies(bytes4 selector, ModuleEntity validationFunction) + internal + view + returns (bool) + { + return getAccountStorage().validationData[validationFunction].selectors.contains(toSetValue(selector)); + } } diff --git a/test/account/DirectCallsFromModule.t.sol b/test/account/DirectCallsFromModule.t.sol index 98544c11..c85a9c31 100644 --- a/test/account/DirectCallsFromModule.t.sol +++ b/test/account/DirectCallsFromModule.t.sol @@ -21,6 +21,15 @@ contract DirectCallsFromModuleTest is AccountTestBase { event ValidationUninstalled(address indexed module, uint32 indexed entityId, bool onUninstallSucceeded); + modifier randomizedValidationType(bool selectorValidation) { + if (selectorValidation) { + _installValidationSelector(); + } else { + _installValidationGlobal(); + } + _; + } + function setUp() public { _module = new DirectCallModule(); assertFalse(_module.preHookRan()); @@ -38,9 +47,10 @@ contract DirectCallsFromModuleTest is AccountTestBase { account1.execute(address(0), 0, ""); } - function test_Fail_DirectCallModuleUninstalled() external { - _installValidation(); - + function testFuzz_Fail_DirectCallModuleUninstalled(bool validationType) + external + randomizedValidationType(validationType) + { _uninstallValidation(); vm.prank(address(_module)); @@ -49,7 +59,7 @@ contract DirectCallsFromModuleTest is AccountTestBase { } function test_Fail_DirectCallModuleCallOtherSelector() external { - _installValidation(); + _installValidationSelector(); Call[] memory calls = new Call[](0); @@ -62,9 +72,10 @@ contract DirectCallsFromModuleTest is AccountTestBase { /* Positives */ /* -------------------------------------------------------------------------- */ - function test_Pass_DirectCallFromModulePrank() external { - _installValidation(); - + function testFuzz_Pass_DirectCallFromModulePrank(bool validationType) + external + randomizedValidationType(validationType) + { vm.prank(address(_module)); account1.execute(address(0), 0, ""); @@ -72,9 +83,10 @@ contract DirectCallsFromModuleTest is AccountTestBase { assertTrue(_module.postHookRan()); } - function test_Pass_DirectCallFromModuleCallback() external { - _installValidation(); - + function testFuzz_Pass_DirectCallFromModuleCallback(bool validationType) + external + randomizedValidationType(validationType) + { bytes memory encodedCall = abi.encodeCall(DirectCallModule.directCall, ()); vm.prank(address(entryPoint)); @@ -88,11 +100,12 @@ contract DirectCallsFromModuleTest is AccountTestBase { assertEq(abi.decode(result, (bytes)), abi.encode(_module.getData())); } - function test_Flow_DirectCallFromModuleSequence() external { + function testFuzz_Flow_DirectCallFromModuleSequence(bool validationType) + external + randomizedValidationType(validationType) + { // Install => Succeesfully call => uninstall => fail to call - _installValidation(); - vm.prank(address(_module)); account1.execute(address(0), 0, ""); @@ -129,7 +142,7 @@ contract DirectCallsFromModuleTest is AccountTestBase { /* Internals */ /* -------------------------------------------------------------------------- */ - function _installValidation() internal { + function _installValidationSelector() internal { bytes4[] memory selectors = new bytes4[](1); selectors[0] = IModularAccount.execute.selector; @@ -146,6 +159,20 @@ contract DirectCallsFromModuleTest is AccountTestBase { account1.installValidation(validationConfig, selectors, "", hooks); } + function _installValidationGlobal() internal { + bytes[] memory hooks = new bytes[](1); + hooks[0] = abi.encodePacked( + HookConfigLib.packExecHook({_hookFunction: _moduleEntity, _hasPre: true, _hasPost: true}), + hex"00" // onInstall data + ); + + vm.prank(address(entryPoint)); + + ValidationConfig validationConfig = ValidationConfigLib.pack(_moduleEntity, true, false); + + account1.installValidation(validationConfig, new bytes4[](0), "", hooks); + } + function _uninstallValidation() internal { (address module, uint32 entityId) = ModuleEntityLib.unpack(_moduleEntity); vm.prank(address(entryPoint));