Skip to content

Commit

Permalink
feat: add interface checks for validations and hooks (#147)
Browse files Browse the repository at this point in the history
  • Loading branch information
adamegyed authored Aug 23, 2024
1 parent a1510cd commit 094605e
Show file tree
Hide file tree
Showing 9 changed files with 110 additions and 41 deletions.
62 changes: 32 additions & 30 deletions src/account/ModuleManagerInternals.sol
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,13 @@ import {HookConfigLib} from "../helpers/HookConfigLib.sol";
import {KnownSelectors} from "../helpers/KnownSelectors.sol";
import {ModuleEntityLib} from "../helpers/ModuleEntityLib.sol";
import {ValidationConfigLib} from "../helpers/ValidationConfigLib.sol";
import {IExecutionHookModule} from "../interfaces/IExecutionHookModule.sol";
import {ExecutionManifest, ManifestExecutionHook} from "../interfaces/IExecutionModule.sol";
import {HookConfig, IModularAccount, ModuleEntity, ValidationConfig} from "../interfaces/IModularAccount.sol";
import {IModule} from "../interfaces/IModule.sol";
import {IValidationHookModule} from "../interfaces/IValidationHookModule.sol";
import {IValidationModule} from "../interfaces/IValidationModule.sol";

import {
AccountStorage,
ExecutionData,
Expand All @@ -32,11 +36,11 @@ abstract contract ModuleManagerInternals is IModularAccount {
error Erc4337FunctionNotAllowed(bytes4 selector);
error ExecutionFunctionAlreadySet(bytes4 selector);
error IModuleFunctionNotAllowed(bytes4 selector);
error InterfaceNotSupported(address module);
error NativeFunctionNotAllowed(bytes4 selector);
error NullModule();
error PermissionAlreadySet(ModuleEntity validationFunction, HookConfig hookConfig);
error ModuleInstallCallbackFailed(address module, bytes revertReason);
error ModuleInterfaceNotSupported(address module);
error ModuleNotInstalled(address module);
error PreValidationHookLimitExceeded();
error ValidationAlreadySet(bytes4 selector, ModuleEntity validationFunction);
Expand Down Expand Up @@ -125,21 +129,17 @@ abstract contract ModuleManagerInternals is IModularAccount {
hooks.remove(toSetValue(hookConfig));
}

function _installExecution(address module, ExecutionManifest calldata manifest, bytes memory moduleInstallData)
internal
{
function _installExecution(
address module,
ExecutionManifest calldata manifest,
bytes calldata moduleInstallData
) internal {
AccountStorage storage _storage = getAccountStorage();

if (module == address(0)) {
revert NullModule();
}

// TODO: do we need this check? Or switch to a non-165 checking function?
// Check that the module supports the IModule interface.
if (!ERC165Checker.supportsInterface(module, type(IModule).interfaceId)) {
revert ModuleInterfaceNotSupported(module);
}

// Update components according to the manifest.
uint256 length = manifest.executionFunctions.length;
for (uint256 i = 0; i < length; ++i) {
Expand Down Expand Up @@ -168,18 +168,12 @@ abstract contract ModuleManagerInternals is IModularAccount {
_storage.supportedIfaces[manifest.interfaceIds[i]] += 1;
}

// Initialize the module storage for the account.
// solhint-disable-next-line no-empty-blocks
try IModule(module).onInstall(moduleInstallData) {}
catch {
bytes memory revertReason = collectReturnData();
revert ModuleInstallCallbackFailed(module, revertReason);
}
_onInstall(module, moduleInstallData, type(IModule).interfaceId);

emit ExecutionInstalled(module, manifest);
}

function _uninstallExecution(address module, ExecutionManifest calldata manifest, bytes memory uninstallData)
function _uninstallExecution(address module, ExecutionManifest calldata manifest, bytes calldata uninstallData)
internal
{
AccountStorage storage _storage = getAccountStorage();
Expand Down Expand Up @@ -212,19 +206,22 @@ abstract contract ModuleManagerInternals is IModularAccount {
}

// Clear the module storage for the account.
bool onUninstallSuccess = true;
// solhint-disable-next-line no-empty-blocks
try IModule(module).onUninstall(uninstallData) {}
catch {
onUninstallSuccess = false;
}
bool onUninstallSuccess = _onUninstall(module, uninstallData);

emit ExecutionUninstalled(module, onUninstallSuccess, manifest);
}

function _onInstall(address module, bytes calldata data) internal {
function _onInstall(address module, bytes calldata data, bytes4 interfaceId) internal {
if (data.length > 0) {
IModule(module).onInstall(data);
if (!ERC165Checker.supportsInterface(module, interfaceId)) {
revert InterfaceNotSupported(module);
}
// solhint-disable-next-line no-empty-blocks
try IModule(module).onInstall(data) {}
catch {
bytes memory revertReason = collectReturnData();
revert ModuleInstallCallbackFailed(module, revertReason);
}
}
}

Expand Down Expand Up @@ -261,12 +258,17 @@ abstract contract ModuleManagerInternals is IModularAccount {
if (_validationData.preValidationHooks.length > MAX_PRE_VALIDATION_HOOKS) {
revert PreValidationHookLimitExceeded();
}
} // Hook is an execution hook
else if (!_validationData.permissionHooks.add(toSetValue(hookConfig))) {

_onInstall(hookConfig.module(), hookData, type(IValidationHookModule).interfaceId);

continue;
}
// Hook is a permission hook
if (!_validationData.permissionHooks.add(toSetValue(hookConfig))) {
revert PermissionAlreadySet(moduleEntity, hookConfig);
}

_onInstall(hookConfig.module(), hookData);
_onInstall(hookConfig.module(), hookData, type(IExecutionHookModule).interfaceId);
}

for (uint256 i = 0; i < selectors.length; ++i) {
Expand All @@ -279,7 +281,7 @@ abstract contract ModuleManagerInternals is IModularAccount {
_validationData.isGlobal = validationConfig.isGlobal();
_validationData.isSignatureValidation = validationConfig.isSignatureValidation();

_onInstall(validationConfig.module(), installData);
_onInstall(validationConfig.module(), installData, type(IValidationModule).interfaceId);
emit ValidationInstalled(validationConfig.module(), validationConfig.entityId());
}

Expand Down
2 changes: 1 addition & 1 deletion src/modules/ERC20TokenLimitModule.sol
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ contract ERC20TokenLimitModule is BaseModule, IExecutionHookModule {

/// @inheritdoc BaseModule
function supportsInterface(bytes4 interfaceId) public view override(BaseModule, IERC165) returns (bool) {
return super.supportsInterface(interfaceId);
return interfaceId == type(IExecutionHookModule).interfaceId || super.supportsInterface(interfaceId);
}

function _decrementLimit(uint32 entityId, address token, bytes memory innerCalldata) internal {
Expand Down
11 changes: 11 additions & 0 deletions src/modules/permissionhooks/AllowlistModule.sol
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
pragma solidity ^0.8.25;

import {PackedUserOperation} from "@eth-infinitism/account-abstraction/interfaces/PackedUserOperation.sol";
import {IERC165} from "@openzeppelin/contracts/interfaces/IERC165.sol";

import {IModule} from "../../interfaces/IModule.sol";

Expand Down Expand Up @@ -118,6 +119,16 @@ contract AllowlistModule is IValidationHookModule, BaseModule {
}
}

function supportsInterface(bytes4 interfaceId)
public
view
virtual
override(BaseModule, IERC165)
returns (bool)
{
return interfaceId == type(IValidationHookModule).interfaceId || super.supportsInterface(interfaceId);
}

function _checkCallPermission(uint32 entityId, address account, address target, bytes memory data)
internal
view
Expand Down
18 changes: 15 additions & 3 deletions src/modules/validation/SingleSignerValidationModule.sol
Original file line number Diff line number Diff line change
@@ -1,15 +1,17 @@
// SPDX-License-Identifier: GPL-3.0
pragma solidity ^0.8.25;

import {PackedUserOperation} from "@eth-infinitism/account-abstraction/interfaces/PackedUserOperation.sol";
import {IERC165} from "@openzeppelin/contracts/interfaces/IERC165.sol";
import {MessageHashUtils} from "@openzeppelin/contracts/utils/cryptography/MessageHashUtils.sol";
import {SignatureChecker} from "@openzeppelin/contracts/utils/cryptography/SignatureChecker.sol";

import {IModule} from "../../interfaces/IModule.sol";
import {IValidationModule} from "../../interfaces/IValidationModule.sol";
import {BaseModule} from "../BaseModule.sol";

import {ReplaySafeWrapper} from "../ReplaySafeWrapper.sol";
import {ISingleSignerValidationModule} from "./ISingleSignerValidationModule.sol";
import {PackedUserOperation} from "@eth-infinitism/account-abstraction/interfaces/PackedUserOperation.sol";
import {MessageHashUtils} from "@openzeppelin/contracts/utils/cryptography/MessageHashUtils.sol";
import {SignatureChecker} from "@openzeppelin/contracts/utils/cryptography/SignatureChecker.sol";

/// @title ECSDA Validation
/// @author ERC-6900 Authors
Expand Down Expand Up @@ -115,6 +117,16 @@ contract SingleSignerValidationModule is ISingleSignerValidationModule, ReplaySa
return "erc6900/single-signer-validation-module/1.0.0";
}

function supportsInterface(bytes4 interfaceId)
public
view
virtual
override(BaseModule, IERC165)
returns (bool)
{
return (interfaceId == type(IValidationModule).interfaceId || super.supportsInterface(interfaceId));
}

// ┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
// ┃ Internal / Private functions ┃
// ┗━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┛
Expand Down
8 changes: 4 additions & 4 deletions test/account/AccountExecHooks.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -162,27 +162,27 @@ contract AccountExecHooksTest is AccountTestBase {
mockModule1 = new MockModule(_m1);

vm.expectEmit(true, true, true, true);
emit ReceivedCall(abi.encodeCall(IModule.onInstall, (bytes(""))), 0);
emit ReceivedCall(abi.encodeCall(IModule.onInstall, (bytes("a"))), 0);
vm.expectEmit(true, true, true, true);
emit ExecutionInstalled(address(mockModule1), _m1);

vm.startPrank(address(entryPoint));
account1.installExecution({
module: address(mockModule1),
manifest: mockModule1.executionManifest(),
moduleInstallData: bytes("")
moduleInstallData: bytes("a")
});
vm.stopPrank();
}

function _uninstallExecution(MockModule module) internal {
vm.expectEmit(true, true, true, true);
emit ReceivedCall(abi.encodeCall(IModule.onUninstall, (bytes(""))), 0);
emit ReceivedCall(abi.encodeCall(IModule.onUninstall, (bytes("b"))), 0);
vm.expectEmit(true, true, true, true);
emit ExecutionUninstalled(address(module), true, module.executionManifest());

vm.startPrank(address(entryPoint));
account1.uninstallExecution(address(module), module.executionManifest(), bytes(""));
account1.uninstallExecution(address(module), module.executionManifest(), bytes("b"));
vm.stopPrank();
}
}
23 changes: 22 additions & 1 deletion test/account/DirectCallsFromModule.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ import {DirectCallModule} from "../mocks/modules/DirectCallModule.sol";

import {AccountTestBase} from "../utils/AccountTestBase.sol";

import {DIRECT_CALL_VALIDATION_ENTITYID} from "../../src/helpers/Constants.sol";

contract DirectCallsFromModuleTest is AccountTestBase {
using ValidationConfigLib for ValidationConfig;

Expand All @@ -23,7 +25,7 @@ contract DirectCallsFromModuleTest is AccountTestBase {
_module = new DirectCallModule();
assertFalse(_module.preHookRan());
assertFalse(_module.postHookRan());
_moduleEntity = ModuleEntityLib.pack(address(_module), type(uint32).max);
_moduleEntity = ModuleEntityLib.pack(address(_module), DIRECT_CALL_VALIDATION_ENTITYID);
}

/* -------------------------------------------------------------------------- */
Expand Down Expand Up @@ -104,6 +106,25 @@ contract DirectCallsFromModuleTest is AccountTestBase {
account1.execute(address(0), 0, "");
}

function test_directCallsFromEOA() external {
address extraOwner = makeAddr("extraOwner");

bytes4[] memory selectors = new bytes4[](1);
selectors[0] = IModularAccount.execute.selector;

vm.prank(address(entryPoint));

account1.installValidation(
ValidationConfigLib.pack(extraOwner, DIRECT_CALL_VALIDATION_ENTITYID, false, false),
selectors,
"",
new bytes[](0)
);

vm.prank(extraOwner);
account1.execute(makeAddr("dead"), 0, "");
}

/* -------------------------------------------------------------------------- */
/* Internals */
/* -------------------------------------------------------------------------- */
Expand Down
4 changes: 2 additions & 2 deletions test/account/UpgradeableModularAccount.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -288,12 +288,12 @@ contract UpgradeableModularAccountTest is AccountTestBase {

address badModule = address(1);
vm.expectRevert(
abi.encodeWithSelector(ModuleManagerInternals.ModuleInterfaceNotSupported.selector, address(badModule))
abi.encodeWithSelector(ModuleManagerInternals.InterfaceNotSupported.selector, address(badModule))
);

ExecutionManifest memory m;

account1.installExecution({module: address(badModule), manifest: m, moduleInstallData: ""});
account1.installExecution({module: address(badModule), manifest: m, moduleInstallData: "a"});
}

function test_installExecution_alreadyInstalled() public {
Expand Down
12 changes: 12 additions & 0 deletions test/mocks/modules/DirectCallModule.sol
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
// SPDX-License-Identifier: UNLICENSED
pragma solidity ^0.8.19;

import {IERC165} from "@openzeppelin/contracts/interfaces/IERC165.sol";

import {IExecutionHookModule} from "../../../src/interfaces/IExecutionHookModule.sol";
import {IModularAccount} from "../../../src/interfaces/IModularAccount.sol";
import {BaseModule} from "../../../src/modules/BaseModule.sol";
Expand Down Expand Up @@ -42,4 +44,14 @@ contract DirectCallModule is BaseModule, IExecutionHookModule {
);
postHookRan = true;
}

function supportsInterface(bytes4 interfaceId)
public
view
virtual
override(BaseModule, IERC165)
returns (bool)
{
return interfaceId == type(IExecutionHookModule).interfaceId || super.supportsInterface(interfaceId);
}
}
11 changes: 11 additions & 0 deletions test/mocks/modules/MockAccessControlHookModule.sol
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
pragma solidity ^0.8.25;

import {PackedUserOperation} from "@eth-infinitism/account-abstraction/interfaces/PackedUserOperation.sol";
import {IERC165} from "@openzeppelin/contracts/interfaces/IERC165.sol";

import {IModularAccount} from "../../../src/interfaces/IModularAccount.sol";
import {IValidationHookModule} from "../../../src/interfaces/IValidationHookModule.sol";
Expand Down Expand Up @@ -86,4 +87,14 @@ contract MockAccessControlHookModule is IValidationHookModule, BaseModule {
function moduleId() external pure returns (string memory) {
return "erc6900/mock-access-control-hook-module/1.0.0";
}

function supportsInterface(bytes4 interfaceId)
public
view
virtual
override(BaseModule, IERC165)
returns (bool)
{
return interfaceId == type(IValidationHookModule).interfaceId || super.supportsInterface(interfaceId);
}
}

0 comments on commit 094605e

Please sign in to comment.