diff --git a/src/account/ModuleManagerInternals.sol b/src/account/ModuleManagerInternals.sol index d11c8c6e..f8cd2f2f 100644 --- a/src/account/ModuleManagerInternals.sol +++ b/src/account/ModuleManagerInternals.sol @@ -10,9 +10,13 @@ import {HookConfigLib} from "../helpers/HookConfigLib.sol"; import {KnownSelectors} from "../helpers/KnownSelectors.sol"; import {ModuleEntityLib} from "../helpers/ModuleEntityLib.sol"; import {ValidationConfigLib} from "../helpers/ValidationConfigLib.sol"; +import {IExecutionHookModule} from "../interfaces/IExecutionHookModule.sol"; import {ExecutionManifest, ManifestExecutionHook} from "../interfaces/IExecutionModule.sol"; import {IModule} from "../interfaces/IModule.sol"; import {HookConfig, IModuleManager, ModuleEntity, ValidationConfig} from "../interfaces/IModuleManager.sol"; +import {IValidationHookModule} from "../interfaces/IValidationHookModule.sol"; +import {IValidationModule} from "../interfaces/IValidationModule.sol"; + import { AccountStorage, ExecutionData, @@ -32,11 +36,11 @@ abstract contract ModuleManagerInternals is IModuleManager { error Erc4337FunctionNotAllowed(bytes4 selector); error ExecutionFunctionAlreadySet(bytes4 selector); error IModuleFunctionNotAllowed(bytes4 selector); + error InterfaceNotSupported(address module); error NativeFunctionNotAllowed(bytes4 selector); error NullModule(); error PermissionAlreadySet(ModuleEntity validationFunction, HookConfig hookConfig); error ModuleInstallCallbackFailed(address module, bytes revertReason); - error ModuleInterfaceNotSupported(address module); error ModuleNotInstalled(address module); error PreValidationHookLimitExceeded(); error ValidationAlreadySet(bytes4 selector, ModuleEntity validationFunction); @@ -134,10 +138,8 @@ abstract contract ModuleManagerInternals is IModuleManager { revert NullModule(); } - // TODO: do we need this check? Or switch to a non-165 checking function? - // Check that the module supports the IModule interface. if (!ERC165Checker.supportsInterface(module, type(IModule).interfaceId)) { - revert ModuleInterfaceNotSupported(module); + revert InterfaceNotSupported(module); } // Update components according to the manifest. @@ -227,6 +229,16 @@ abstract contract ModuleManagerInternals is IModuleManager { } } + function _onInstall(address module, bytes calldata data, bytes4 interfaceId) internal { + if (data.length > 0) { + if (!ERC165Checker.supportsInterface(module, interfaceId)) { + revert InterfaceNotSupported(module); + } + + IModule(module).onInstall(data); + } + } + function _onUninstall(address module, bytes calldata data) internal returns (bool onUninstallSuccess) { onUninstallSuccess = true; if (data.length > 0) { @@ -260,12 +272,17 @@ abstract contract ModuleManagerInternals is IModuleManager { if (_validationData.preValidationHooks.length > MAX_PRE_VALIDATION_HOOKS) { revert PreValidationHookLimitExceeded(); } - } // Hook is an execution hook - else if (!_validationData.permissionHooks.add(toSetValue(hookConfig))) { + + _onInstall(hookConfig.module(), hookData, type(IValidationHookModule).interfaceId); + + continue; + } + // Hook is a permission hook + if (!_validationData.permissionHooks.add(toSetValue(hookConfig))) { revert PermissionAlreadySet(moduleEntity, hookConfig); } - _onInstall(hookConfig.module(), hookData); + _onInstall(hookConfig.module(), hookData, type(IExecutionHookModule).interfaceId); } for (uint256 i = 0; i < selectors.length; ++i) { @@ -278,7 +295,7 @@ abstract contract ModuleManagerInternals is IModuleManager { _validationData.isGlobal = validationConfig.isGlobal(); _validationData.isSignatureValidation = validationConfig.isSignatureValidation(); - _onInstall(validationConfig.module(), installData); + _onInstall(validationConfig.module(), installData, type(IValidationModule).interfaceId); emit ValidationInstalled(validationConfig.module(), validationConfig.entityId()); } diff --git a/src/modules/ERC20TokenLimitModule.sol b/src/modules/ERC20TokenLimitModule.sol index 6eb93ea8..77a19752 100644 --- a/src/modules/ERC20TokenLimitModule.sol +++ b/src/modules/ERC20TokenLimitModule.sol @@ -127,7 +127,7 @@ contract ERC20TokenLimitModule is BaseModule, IExecutionHookModule { /// @inheritdoc BaseModule function supportsInterface(bytes4 interfaceId) public view override(BaseModule, IERC165) returns (bool) { - return super.supportsInterface(interfaceId); + return interfaceId == type(IExecutionHookModule).interfaceId || super.supportsInterface(interfaceId); } function _decrementLimit(uint32 entityId, address token, bytes memory innerCalldata) internal { diff --git a/src/modules/permissionhooks/AllowlistModule.sol b/src/modules/permissionhooks/AllowlistModule.sol index dd23a6c9..6f158d50 100644 --- a/src/modules/permissionhooks/AllowlistModule.sol +++ b/src/modules/permissionhooks/AllowlistModule.sol @@ -2,6 +2,7 @@ pragma solidity ^0.8.25; import {PackedUserOperation} from "@eth-infinitism/account-abstraction/interfaces/PackedUserOperation.sol"; +import {IERC165} from "@openzeppelin/contracts/interfaces/IERC165.sol"; import {ModuleMetadata} from "../../interfaces/IModule.sol"; @@ -119,6 +120,16 @@ contract AllowlistModule is IValidationHookModule, BaseModule { } } + function supportsInterface(bytes4 interfaceId) + public + view + virtual + override(BaseModule, IERC165) + returns (bool) + { + return interfaceId == type(IValidationHookModule).interfaceId || super.supportsInterface(interfaceId); + } + function _checkCallPermission(uint32 entityId, address account, address target, bytes memory data) internal view diff --git a/src/modules/validation/SingleSignerValidationModule.sol b/src/modules/validation/SingleSignerValidationModule.sol index 7ad1c768..de490971 100644 --- a/src/modules/validation/SingleSignerValidationModule.sol +++ b/src/modules/validation/SingleSignerValidationModule.sol @@ -2,6 +2,8 @@ pragma solidity ^0.8.25; import {PackedUserOperation} from "@eth-infinitism/account-abstraction/interfaces/PackedUserOperation.sol"; + +import {IERC165} from "@openzeppelin/contracts/interfaces/IERC165.sol"; import {MessageHashUtils} from "@openzeppelin/contracts/utils/cryptography/MessageHashUtils.sol"; import {SignatureChecker} from "@openzeppelin/contracts/utils/cryptography/SignatureChecker.sol"; @@ -120,6 +122,16 @@ contract SingleSignerValidationModule is ISingleSignerValidationModule, BaseModu return metadata; } + function supportsInterface(bytes4 interfaceId) + public + view + virtual + override(BaseModule, IERC165) + returns (bool) + { + return (interfaceId == type(IValidationModule).interfaceId || super.supportsInterface(interfaceId)); + } + // ┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓ // ┃ Internal / Private functions ┃ // ┗━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┛ diff --git a/test/account/UpgradeableModularAccount.t.sol b/test/account/UpgradeableModularAccount.t.sol index e4394f34..32747cf3 100644 --- a/test/account/UpgradeableModularAccount.t.sol +++ b/test/account/UpgradeableModularAccount.t.sol @@ -273,7 +273,7 @@ contract UpgradeableModularAccountTest is AccountTestBase { address badModule = address(1); vm.expectRevert( - abi.encodeWithSelector(ModuleManagerInternals.ModuleInterfaceNotSupported.selector, address(badModule)) + abi.encodeWithSelector(ModuleManagerInternals.InterfaceNotSupported.selector, address(badModule)) ); ExecutionManifest memory m; diff --git a/test/mocks/modules/DirectCallModule.sol b/test/mocks/modules/DirectCallModule.sol index c5ad271c..53acff8e 100644 --- a/test/mocks/modules/DirectCallModule.sol +++ b/test/mocks/modules/DirectCallModule.sol @@ -1,6 +1,8 @@ // SPDX-License-Identifier: UNLICENSED pragma solidity ^0.8.19; +import {IERC165} from "@openzeppelin/contracts/interfaces/IERC165.sol"; + import {IExecutionHookModule} from "../../../src/interfaces/IExecutionHookModule.sol"; import {ModuleMetadata} from "../../../src/interfaces/IModule.sol"; import {IStandardExecutor} from "../../../src/interfaces/IStandardExecutor.sol"; @@ -42,4 +44,14 @@ contract DirectCallModule is BaseModule, IExecutionHookModule { ); postHookRan = true; } + + function supportsInterface(bytes4 interfaceId) + public + view + virtual + override(BaseModule, IERC165) + returns (bool) + { + return interfaceId == type(IExecutionHookModule).interfaceId || super.supportsInterface(interfaceId); + } } diff --git a/test/mocks/modules/MockAccessControlHookModule.sol b/test/mocks/modules/MockAccessControlHookModule.sol index 4ef33181..23c75a9c 100644 --- a/test/mocks/modules/MockAccessControlHookModule.sol +++ b/test/mocks/modules/MockAccessControlHookModule.sol @@ -2,6 +2,7 @@ pragma solidity ^0.8.25; import {PackedUserOperation} from "@eth-infinitism/account-abstraction/interfaces/PackedUserOperation.sol"; +import {IERC165} from "@openzeppelin/contracts/interfaces/IERC165.sol"; import {ModuleMetadata} from "../../../src/interfaces/IModule.sol"; import {IStandardExecutor} from "../../../src/interfaces/IStandardExecutor.sol"; @@ -73,4 +74,14 @@ contract MockAccessControlHookModule is IValidationHookModule, BaseModule { } function moduleMetadata() external pure override returns (ModuleMetadata memory) {} + + function supportsInterface(bytes4 interfaceId) + public + view + virtual + override(BaseModule, IERC165) + returns (bool) + { + return interfaceId == type(IValidationHookModule).interfaceId || super.supportsInterface(interfaceId); + } }