diff --git a/.solhint-test.json b/.solhint-test.json index fd2b1007..3224b9d0 100644 --- a/.solhint-test.json +++ b/.solhint-test.json @@ -1,20 +1,20 @@ { - "extends": "solhint:recommended", - "rules": { - "func-name-mixedcase": "off", - "immutable-vars-naming": ["error"], - "no-unused-import": ["error"], - "compiler-version": ["error", ">=0.8.19"], - "custom-errors": "off", - "func-visibility": ["error", { "ignoreConstructors": true }], - "max-line-length": ["error", 120], - "max-states-count": ["warn", 30], - "modifier-name-mixedcase": ["error"], - "private-vars-leading-underscore": ["error"], - "no-inline-assembly": "off", - "avoid-low-level-calls": "off", - "one-contract-per-file": "off", - "no-empty-blocks": "off" - } + "extends": "solhint:recommended", + "rules": { + "func-name-mixedcase": "off", + "immutable-vars-naming": ["error"], + "no-unused-import": ["error"], + "compiler-version": ["error", ">=0.8.19"], + "custom-errors": "off", + "func-visibility": ["error", { "ignoreConstructors": true }], + "max-line-length": ["error", 120], + "max-states-count": ["warn", 30], + "modifier-name-mixedcase": ["error"], + "private-vars-leading-underscore": ["error"], + "no-inline-assembly": "off", + "avoid-low-level-calls": "off", + "one-contract-per-file": "off", + "no-empty-blocks": "off", + "reason-string": ["warn", { "maxLength": 64 }] } - \ No newline at end of file +} diff --git a/src/account/AccountStorage.sol b/src/account/AccountStorage.sol index df081bd0..aa56f75d 100644 --- a/src/account/AccountStorage.sol +++ b/src/account/AccountStorage.sol @@ -30,7 +30,7 @@ struct ValidationData { bool isGlobal; // Whether or not this validation is a signature validator. bool isSignatureValidation; - // The pre validation hooks for this function selector. + // The pre validation hooks for this validation function. PluginEntity[] preValidationHooks; // Permission hooks for this validation function. EnumerableSet.Bytes32Set permissionHooks; diff --git a/src/account/PluginManager2.sol b/src/account/PluginManager2.sol index 1f121880..2ff6a9d1 100644 --- a/src/account/PluginManager2.sol +++ b/src/account/PluginManager2.sol @@ -7,7 +7,7 @@ import {IPlugin} from "../interfaces/IPlugin.sol"; import {PluginEntity, ValidationConfig} from "../interfaces/IPluginManager.sol"; import {PluginEntityLib} from "../helpers/PluginEntityLib.sol"; import {ValidationConfigLib} from "../helpers/ValidationConfigLib.sol"; -import {ValidationData, getAccountStorage, toSetValue, toPluginEntity} from "./AccountStorage.sol"; +import {ValidationData, getAccountStorage, toSetValue} from "./AccountStorage.sol"; import {ExecutionHook} from "../interfaces/IAccountLoupe.sol"; // Temporary additional functions for a user-controlled install flow for validation functions. @@ -17,6 +17,7 @@ abstract contract PluginManager2 { // Index marking the start of the data for the validation function. uint8 internal constant _RESERVED_VALIDATION_DATA_INDEX = 255; + uint32 internal constant _SELF_PERMIT_VALIDATION_FUNCTIONID = type(uint32).max; error PreValidationAlreadySet(PluginEntity validationFunction, PluginEntity preValidationFunction); error ValidationAlreadySet(bytes4 selector, PluginEntity validationFunction); @@ -32,7 +33,7 @@ abstract contract PluginManager2 { bytes memory permissionHooks ) internal { ValidationData storage _validationData = - getAccountStorage().validationData[validationConfig.functionReference()]; + getAccountStorage().validationData[validationConfig.pluginEntity()]; if (preValidationHooks.length > 0) { (PluginEntity[] memory preValidationFunctions, bytes[] memory initDatas) = @@ -63,7 +64,7 @@ abstract contract PluginManager2 { ExecutionHook memory permissionFunction = permissionFunctions[i]; if (!_validationData.permissionHooks.add(toSetValue(permissionFunction))) { - revert PermissionAlreadySet(validationConfig.functionReference(), permissionFunction); + revert PermissionAlreadySet(validationConfig.pluginEntity(), permissionFunction); } if (initDatas[i].length > 0) { @@ -73,19 +74,21 @@ abstract contract PluginManager2 { } } - _validationData.isGlobal = validationConfig.isGlobal(); - _validationData.isSignatureValidation = validationConfig.isSignatureValidation(); - for (uint256 i = 0; i < selectors.length; ++i) { bytes4 selector = selectors[i]; if (!_validationData.selectors.add(toSetValue(selector))) { - revert ValidationAlreadySet(selector, validationConfig.functionReference()); + revert ValidationAlreadySet(selector, validationConfig.pluginEntity()); } } - if (installData.length > 0) { - address plugin = validationConfig.plugin(); - IPlugin(plugin).onInstall(installData); + if (validationConfig.entityId() != _SELF_PERMIT_VALIDATION_FUNCTIONID) { + // Only allow global validations and signature validations if they're not direct-call validations. + + _validationData.isGlobal = validationConfig.isGlobal(); + _validationData.isSignatureValidation = validationConfig.isSignatureValidation(); + if (installData.length > 0) { + IPlugin(validationConfig.plugin()).onInstall(installData); + } } } @@ -120,12 +123,12 @@ abstract contract PluginManager2 { // Clear permission hooks EnumerableSet.Bytes32Set storage permissionHooks = _validationData.permissionHooks; - uint256 i = 0; - while (permissionHooks.length() > 0) { - PluginEntity permissionHook = toPluginEntity(permissionHooks.at(0)); - permissionHooks.remove(toSetValue(permissionHook)); - (address permissionHookPlugin,) = PluginEntityLib.unpack(permissionHook); - IPlugin(permissionHookPlugin).onUninstall(permissionHookUninstallDatas[i++]); + uint256 len = permissionHooks.length(); + for (uint256 i = 0; i < len; ++i) { + bytes32 permissionHook = permissionHooks.at(0); + permissionHooks.remove(permissionHook); + address permissionHookPlugin = address(uint160(bytes20(permissionHook))); + IPlugin(permissionHookPlugin).onUninstall(permissionHookUninstallDatas[i]); } } delete _validationData.preValidationHooks; diff --git a/src/account/UpgradeableModularAccount.sol b/src/account/UpgradeableModularAccount.sol index acbd6de5..6c577ed2 100644 --- a/src/account/UpgradeableModularAccount.sol +++ b/src/account/UpgradeableModularAccount.sol @@ -78,7 +78,7 @@ contract UpgradeableModularAccount is error SignatureValidationInvalid(address plugin, uint32 entityId); error UnexpectedAggregator(address plugin, uint32 entityId, address aggregator); error UnrecognizedFunction(bytes4 selector); - error UserOpValidationFunctionMissing(bytes4 selector); + error ValidationFunctionMissing(bytes4 selector); error ValidationDoesNotApply(bytes4 selector, address plugin, uint32 entityId, bool isGlobal); error ValidationSignatureSegmentMissing(); error SignatureSegmentOutOfOrder(); @@ -86,14 +86,13 @@ contract UpgradeableModularAccount is // Wraps execution of a native function with runtime validation and hooks // Used for upgradeTo, upgradeToAndCall, execute, executeBatch, installPlugin, uninstallPlugin modifier wrapNativeFunction() { - _checkPermittedCallerIfNotFromEP(); - - PostExecToRun[] memory postExecHooks = - _doPreHooks(getAccountStorage().selectorData[msg.sig].executionHooks, msg.data); + (PostExecToRun[] memory postPermissionHooks, PostExecToRun[] memory postExecHooks) = + _checkPermittedCallerAndAssociatedHooks(); _; _doCachedPostExecHooks(postExecHooks); + _doCachedPostExecHooks(postPermissionHooks); } constructor(IEntryPoint anEntryPoint) { @@ -136,7 +135,7 @@ contract UpgradeableModularAccount is revert UnrecognizedFunction(msg.sig); } - _checkPermittedCallerIfNotFromEP(); + _checkPermittedCallerAndAssociatedHooks(); PostExecToRun[] memory postExecHooks; // Cache post-exec hooks in memory @@ -500,17 +499,7 @@ contract UpgradeableModularAccount is } else { currentAuthData = ""; } - - (address hookPlugin, uint32 hookEntityId) = preRuntimeValidationHooks[i].unpack(); - try IValidationHook(hookPlugin).preRuntimeValidationHook( - hookEntityId, msg.sender, msg.value, callData, currentAuthData - ) - // forgefmt: disable-start - // solhint-disable-next-line no-empty-blocks - {} catch (bytes memory revertReason) { - // forgefmt: disable-end - revert PreRuntimeValidationHookFailed(hookPlugin, hookEntityId, revertReason); - } + _doPreRuntimeValidationHook(preRuntimeValidationHooks[i], callData, currentAuthData); } if (authSegment.getIndex() != _RESERVED_VALIDATION_DATA_INDEX) { @@ -605,9 +594,78 @@ contract UpgradeableModularAccount is } } + function _doPreRuntimeValidationHook( + PluginEntity validationHook, + bytes memory callData, + bytes memory currentAuthData + ) internal { + (address hookPlugin, uint32 hookEntityId) = validationHook.unpack(); + try IValidationHook(hookPlugin).preRuntimeValidationHook( + hookEntityId, msg.sender, msg.value, callData, currentAuthData + ) + // forgefmt: disable-start + // solhint-disable-next-line no-empty-blocks + {} catch (bytes memory revertReason) { + // forgefmt: disable-end + revert PreRuntimeValidationHookFailed(hookPlugin, hookEntityId, revertReason); + } + } + // solhint-disable-next-line no-empty-blocks function _authorizeUpgrade(address newImplementation) internal override {} + /** + * Order of operations: + * 1. Check if the sender is the entry point, the account itself, or the selector called is public. + * - Yes: Return an empty array, there are no post-permissionHooks. + * - No: Continue + * 2. Check if the called selector (msg.sig) is included in the set of selectors the msg.sender can + * directly call. + * - Yes: Continue + * - No: Revert, the caller is not allowed to call this selector + * 3. If there are runtime validation hooks associated with this caller-sig combination, run them. + * 4. Run the pre-permissionHooks associated with this caller-sig combination, and return the + * post-permissionHooks to run later. + */ + function _checkPermittedCallerAndAssociatedHooks() + internal + returns (PostExecToRun[] memory, PostExecToRun[] memory) + { + AccountStorage storage _storage = getAccountStorage(); + + if ( + msg.sender == address(_ENTRY_POINT) || msg.sender == address(this) + || _storage.selectorData[msg.sig].isPublic + ) { + return (new PostExecToRun[](0), new PostExecToRun[](0)); + } + + PluginEntity directCallValidationKey = PluginEntityLib.pack(msg.sender, _SELF_PERMIT_VALIDATION_FUNCTIONID); + + _checkIfValidationAppliesCallData(msg.data, directCallValidationKey, false); + + // Direct call is allowed, run associated permission & validation hooks + + // Validation hooks + PluginEntity[] memory preRuntimeValidationHooks = + _storage.validationData[directCallValidationKey].preValidationHooks; + + uint256 hookLen = preRuntimeValidationHooks.length; + for (uint256 i = 0; i < hookLen; ++i) { + _doPreRuntimeValidationHook(preRuntimeValidationHooks[i], msg.data, ""); + } + + // Permission hooks + PostExecToRun[] memory postPermissionHooks = + _doPreHooks(_storage.validationData[directCallValidationKey].permissionHooks, msg.data); + + // Exec hooks + PostExecToRun[] memory postExecutionHooks = + _doPreHooks(_storage.selectorData[msg.sig].executionHooks, msg.data); + + return (postPermissionHooks, postExecutionHooks); + } + function _checkIfValidationAppliesCallData( bytes calldata callData, PluginEntity validationFunction, @@ -661,25 +719,6 @@ contract UpgradeableModularAccount is } } - function _checkIfValidationAppliesSelector(bytes4 selector, PluginEntity validationFunction, bool isGlobal) - internal - view - { - AccountStorage storage _storage = getAccountStorage(); - - // Check that the provided validation function is applicable to the selector - if (isGlobal) { - if (!_globalValidationAllowed(selector) || !_storage.validationData[validationFunction].isGlobal) { - revert UserOpValidationFunctionMissing(selector); - } - } else { - // Not global validation, but per-selector - if (!getAccountStorage().validationData[validationFunction].selectors.contains(toSetValue(selector))) { - revert UserOpValidationFunctionMissing(selector); - } - } - } - function _globalValidationAllowed(bytes4 selector) internal view returns (bool) { if ( selector == this.execute.selector || selector == this.executeBatch.selector @@ -693,14 +732,22 @@ contract UpgradeableModularAccount is return getAccountStorage().selectorData[selector].allowGlobalValidation; } - function _checkPermittedCallerIfNotFromEP() internal view { + function _checkIfValidationAppliesSelector(bytes4 selector, PluginEntity validationFunction, bool isGlobal) + internal + view + { AccountStorage storage _storage = getAccountStorage(); - if ( - msg.sender != address(_ENTRY_POINT) && msg.sender != address(this) - && !_storage.selectorData[msg.sig].isPublic - ) { - revert ExecFromPluginNotPermitted(msg.sender, msg.sig); + // Check that the provided validation function is applicable to the selector + if (isGlobal) { + if (!_globalValidationAllowed(selector) || !_storage.validationData[validationFunction].isGlobal) { + revert ValidationFunctionMissing(selector); + } + } else { + // Not global validation, but per-selector + if (!getAccountStorage().validationData[validationFunction].selectors.contains(toSetValue(selector))) { + revert ValidationFunctionMissing(selector); + } } } } diff --git a/src/helpers/ValidationConfigLib.sol b/src/helpers/ValidationConfigLib.sol index 95e8ea90..6d27b907 100644 --- a/src/helpers/ValidationConfigLib.sol +++ b/src/helpers/ValidationConfigLib.sol @@ -78,7 +78,7 @@ library ValidationConfigLib { return uint32(bytes4(ValidationConfig.unwrap(config) << 160)); } - function functionReference(ValidationConfig config) internal pure returns (PluginEntity) { + function pluginEntity(ValidationConfig config) internal pure returns (PluginEntity) { return PluginEntity.wrap(bytes24(ValidationConfig.unwrap(config))); } diff --git a/test/account/DirectCallsFromPlugin.t.sol b/test/account/DirectCallsFromPlugin.t.sol new file mode 100644 index 00000000..1c0fcca8 --- /dev/null +++ b/test/account/DirectCallsFromPlugin.t.sol @@ -0,0 +1,133 @@ +pragma solidity ^0.8.19; + +import {DirectCallPlugin} from "../mocks/plugins/DirectCallPlugin.sol"; +import {ExecutionHook} from "../../src/interfaces/IAccountLoupe.sol"; +import {IStandardExecutor, Call} from "../../src/interfaces/IStandardExecutor.sol"; +import {PluginEntityLib, PluginEntity} from "../../src/helpers/PluginEntityLib.sol"; +import {ValidationConfig, ValidationConfigLib} from "../../src/helpers/ValidationConfigLib.sol"; +import {UpgradeableModularAccount} from "../../src/account/UpgradeableModularAccount.sol"; + +import {AccountTestBase} from "../utils/AccountTestBase.sol"; + +contract DirectCallsFromPluginTest is AccountTestBase { + using ValidationConfigLib for ValidationConfig; + + DirectCallPlugin internal _plugin; + PluginEntity internal _pluginEntity; + + function setUp() public { + _plugin = new DirectCallPlugin(); + assertFalse(_plugin.preHookRan()); + assertFalse(_plugin.postHookRan()); + _pluginEntity = PluginEntityLib.pack(address(_plugin), type(uint32).max); + } + + /* -------------------------------------------------------------------------- */ + /* Negatives */ + /* -------------------------------------------------------------------------- */ + + function test_Fail_DirectCallPluginNotInstalled() external { + vm.prank(address(_plugin)); + vm.expectRevert(_buildDirectCallDisallowedError(IStandardExecutor.execute.selector)); + account1.execute(address(0), 0, ""); + } + + function test_Fail_DirectCallPluginUninstalled() external { + _installPlugin(); + + _uninstallPlugin(); + + vm.prank(address(_plugin)); + vm.expectRevert(_buildDirectCallDisallowedError(IStandardExecutor.execute.selector)); + account1.execute(address(0), 0, ""); + } + + function test_Fail_DirectCallPluginCallOtherSelector() external { + _installPlugin(); + + Call[] memory calls = new Call[](0); + + vm.prank(address(_plugin)); + vm.expectRevert(_buildDirectCallDisallowedError(IStandardExecutor.executeBatch.selector)); + account1.executeBatch(calls); + } + + /* -------------------------------------------------------------------------- */ + /* Positives */ + /* -------------------------------------------------------------------------- */ + + function test_Pass_DirectCallFromPluginPrank() external { + _installPlugin(); + + vm.prank(address(_plugin)); + account1.execute(address(0), 0, ""); + + assertTrue(_plugin.preHookRan()); + assertTrue(_plugin.postHookRan()); + } + + function test_Pass_DirectCallFromPluginCallback() external { + _installPlugin(); + + bytes memory encodedCall = abi.encodeCall(DirectCallPlugin.directCall, ()); + + vm.prank(address(entryPoint)); + bytes memory result = account1.execute(address(_plugin), 0, encodedCall); + + assertTrue(_plugin.preHookRan()); + assertTrue(_plugin.postHookRan()); + + // the directCall() function in the _plugin calls back into `execute()` with an encoded call back into the + // _plugin's getData() function. + assertEq(abi.decode(result, (bytes)), abi.encode(_plugin.getData())); + } + + function test_Flow_DirectCallFromPluginSequence() external { + // Install => Succeesfully call => uninstall => fail to call + + _installPlugin(); + + vm.prank(address(_plugin)); + account1.execute(address(0), 0, ""); + + assertTrue(_plugin.preHookRan()); + assertTrue(_plugin.postHookRan()); + + _uninstallPlugin(); + + vm.prank(address(_plugin)); + vm.expectRevert(_buildDirectCallDisallowedError(IStandardExecutor.execute.selector)); + account1.execute(address(0), 0, ""); + } + + /* -------------------------------------------------------------------------- */ + /* Internals */ + /* -------------------------------------------------------------------------- */ + + function _installPlugin() internal { + bytes4[] memory selectors = new bytes4[](1); + selectors[0] = IStandardExecutor.execute.selector; + + ExecutionHook[] memory permissionHooks = new ExecutionHook[](1); + bytes[] memory permissionHookInitDatas = new bytes[](1); + + permissionHooks[0] = ExecutionHook({hookFunction: _pluginEntity, isPreHook: true, isPostHook: true}); + + bytes memory encodedPermissionHooks = abi.encode(permissionHooks, permissionHookInitDatas); + + vm.prank(address(entryPoint)); + + ValidationConfig validationConfig = ValidationConfigLib.pack(_pluginEntity, false, false); + + account1.installValidation(validationConfig, selectors, "", "", encodedPermissionHooks); + } + + function _uninstallPlugin() internal { + vm.prank(address(entryPoint)); + account1.uninstallValidation(_pluginEntity, "", abi.encode(new bytes[](0)), abi.encode(new bytes[](1))); + } + + function _buildDirectCallDisallowedError(bytes4 selector) internal pure returns (bytes memory) { + return abi.encodeWithSelector(UpgradeableModularAccount.ValidationFunctionMissing.selector, selector); + } +} diff --git a/test/account/PermittedCallPermissions.t.sol b/test/account/PermittedCallPermissions.t.sol index 18257955..885015b0 100644 --- a/test/account/PermittedCallPermissions.t.sol +++ b/test/account/PermittedCallPermissions.t.sol @@ -48,9 +48,7 @@ contract PermittedCallPermissionsTest is AccountTestBase { function test_permittedCall_NotAllowed() public { vm.expectRevert( abi.encodeWithSelector( - UpgradeableModularAccount.ExecFromPluginNotPermitted.selector, - address(permittedCallerPlugin), - ResultCreatorPlugin.bar.selector + UpgradeableModularAccount.ValidationFunctionMissing.selector, ResultCreatorPlugin.bar.selector ) ); PermittedCallerPlugin(address(account1)).usePermittedCallNotAllowed(); diff --git a/test/account/SelfCallAuthorization.t.sol b/test/account/SelfCallAuthorization.t.sol index c490eea4..4b6d04c7 100644 --- a/test/account/SelfCallAuthorization.t.sol +++ b/test/account/SelfCallAuthorization.t.sol @@ -40,8 +40,7 @@ contract SelfCallAuthorizationTest is AccountTestBase { 0, "AA23 reverted", abi.encodeWithSelector( - UpgradeableModularAccount.UserOpValidationFunctionMissing.selector, - ComprehensivePlugin.foo.selector + UpgradeableModularAccount.ValidationFunctionMissing.selector, ComprehensivePlugin.foo.selector ) ) ); @@ -56,8 +55,7 @@ contract SelfCallAuthorizationTest is AccountTestBase { 0, "AA23 reverted", abi.encodeWithSelector( - UpgradeableModularAccount.UserOpValidationFunctionMissing.selector, - ComprehensivePlugin.foo.selector + UpgradeableModularAccount.ValidationFunctionMissing.selector, ComprehensivePlugin.foo.selector ) ) ); @@ -68,8 +66,7 @@ contract SelfCallAuthorizationTest is AccountTestBase { _runtimeCall( abi.encodeCall(ComprehensivePlugin.foo, ()), abi.encodeWithSelector( - UpgradeableModularAccount.UserOpValidationFunctionMissing.selector, - ComprehensivePlugin.foo.selector + UpgradeableModularAccount.ValidationFunctionMissing.selector, ComprehensivePlugin.foo.selector ) ); } @@ -99,8 +96,7 @@ contract SelfCallAuthorizationTest is AccountTestBase { 0, "AA23 reverted", abi.encodeWithSelector( - UpgradeableModularAccount.UserOpValidationFunctionMissing.selector, - ComprehensivePlugin.foo.selector + UpgradeableModularAccount.ValidationFunctionMissing.selector, ComprehensivePlugin.foo.selector ) ) ); @@ -136,8 +132,7 @@ contract SelfCallAuthorizationTest is AccountTestBase { 0, "AA23 reverted", abi.encodeWithSelector( - UpgradeableModularAccount.UserOpValidationFunctionMissing.selector, - ComprehensivePlugin.foo.selector + UpgradeableModularAccount.ValidationFunctionMissing.selector, ComprehensivePlugin.foo.selector ) ) ); @@ -159,8 +154,7 @@ contract SelfCallAuthorizationTest is AccountTestBase { _runtimeExecBatchExpFail( calls, abi.encodeWithSelector( - UpgradeableModularAccount.UserOpValidationFunctionMissing.selector, - ComprehensivePlugin.foo.selector + UpgradeableModularAccount.ValidationFunctionMissing.selector, ComprehensivePlugin.foo.selector ) ); } diff --git a/test/mocks/plugins/DirectCallPlugin.sol b/test/mocks/plugins/DirectCallPlugin.sol new file mode 100644 index 00000000..8ab5dd42 --- /dev/null +++ b/test/mocks/plugins/DirectCallPlugin.sol @@ -0,0 +1,47 @@ +// SPDX-License-Identifier: UNLICENSED +pragma solidity ^0.8.19; + +import {PluginManifest, PluginMetadata} from "../../../src/interfaces/IPlugin.sol"; +import {IStandardExecutor} from "../../../src/interfaces/IStandardExecutor.sol"; +import {IExecutionHook} from "../../../src/interfaces/IExecutionHook.sol"; + +import {BasePlugin} from "../../../src/plugins/BasePlugin.sol"; + +contract DirectCallPlugin is BasePlugin, IExecutionHook { + bool public preHookRan = false; + bool public postHookRan = false; + + function onInstall(bytes calldata) external override {} + + function onUninstall(bytes calldata) external override {} + + function pluginManifest() external pure override returns (PluginManifest memory) {} + + function directCall() external returns (bytes memory) { + return IStandardExecutor(msg.sender).execute(address(this), 0, abi.encodeCall(this.getData, ())); + } + + function getData() external pure returns (bytes memory) { + return hex"04546b"; + } + + function pluginMetadata() external pure override returns (PluginMetadata memory) {} + + function preExecutionHook(uint32, address sender, uint256, bytes calldata) + external + override + returns (bytes memory) + { + require(sender == address(this), "mock direct call pre permission hook failed"); + preHookRan = true; + return abi.encode(keccak256(hex"04546b")); + } + + function postExecutionHook(uint32, bytes calldata preExecHookData) external override { + require( + abi.decode(preExecHookData, (bytes32)) == keccak256(hex"04546b"), + "mock direct call post permission hook failed" + ); + postHookRan = true; + } +}