From 1e8344f07313b1b9631e9cd6815f18a2ac63e1df Mon Sep 17 00:00:00 2001 From: zer0dot Date: Wed, 10 Jul 2024 20:22:33 +0800 Subject: [PATCH] feat: initial impl of simple plugin direct calls with validation hooks --- src/account/AccountStorage.sol | 6 +++++ src/account/UpgradeableModularAccount.sol | 31 ++++++++++++++++++++++- 2 files changed, 36 insertions(+), 1 deletion(-) diff --git a/src/account/AccountStorage.sol b/src/account/AccountStorage.sol index ffdaff26..b8558589 100644 --- a/src/account/AccountStorage.sol +++ b/src/account/AccountStorage.sol @@ -42,6 +42,11 @@ struct ValidationData { FunctionReference[] preValidationHooks; } +struct DirectCallValidationData { + bool allowed; // Whether or not this direct call is allowed. + FunctionReference[] preValidationHooks; // The set of pre validation hooks for this direct call. +} + struct AccountStorage { // AccountStorageInitializable variables uint8 initialized; @@ -55,6 +60,7 @@ struct AccountStorage { mapping(address caller => mapping(bytes4 selector => bool)) callPermitted; // For ERC165 introspection mapping(bytes4 => uint256) supportedIfaces; + mapping(address caller => mapping(bytes4 selector => DirectCallValidationData)) directCallData; } function getAccountStorage() pure returns (AccountStorage storage _storage) { diff --git a/src/account/UpgradeableModularAccount.sol b/src/account/UpgradeableModularAccount.sol index a642aec7..330b97d1 100644 --- a/src/account/UpgradeableModularAccount.sol +++ b/src/account/UpgradeableModularAccount.sol @@ -8,6 +8,7 @@ import {UUPSUpgradeable} from "@openzeppelin/contracts/proxy/utils/UUPSUpgradeab import {IERC165} from "@openzeppelin/contracts/utils/introspection/IERC165.sol"; import {IERC1271} from "@openzeppelin/contracts/interfaces/IERC1271.sol"; import {EnumerableSet} from "@openzeppelin/contracts/utils/structs/EnumerableSet.sol"; +import {EnumerableMap} from "@openzeppelin/contracts/utils/structs/EnumerableMap.sol"; import {FunctionReferenceLib} from "../helpers/FunctionReferenceLib.sol"; import {SparseCalldataSegmentLib} from "../helpers/SparseCalldataSegmentLib.sol"; @@ -38,6 +39,7 @@ contract UpgradeableModularAccount is UUPSUpgradeable { using EnumerableSet for EnumerableSet.Bytes32Set; + using EnumerableSet for EnumerableSet.AddressSet; using FunctionReferenceLib for FunctionReference; using SparseCalldataSegmentLib for bytes; @@ -75,6 +77,7 @@ contract UpgradeableModularAccount is error ValidationDoesNotApply(bytes4 selector, address plugin, uint8 functionId, bool isDefault); error ValidationSignatureSegmentMissing(); error SignatureSegmentOutOfOrder(); + error DirectCallDisallowed(); // Wraps execution of a native function with runtime validation and hooks // Used for upgradeTo, upgradeToAndCall, execute, executeBatch, installPlugin, uninstallPlugin @@ -596,7 +599,7 @@ contract UpgradeableModularAccount is return getAccountStorage().selectorData[selector].allowDefaultValidation; } - function _checkPermittedCallerIfNotFromEP() internal view { + function _checkPermittedCallerIfNotFromEP() internal { AccountStorage storage _storage = getAccountStorage(); if ( @@ -607,5 +610,31 @@ contract UpgradeableModularAccount is if (!_storage.callPermitted[msg.sender][msg.sig]) { revert ExecFromPluginNotPermitted(msg.sender, msg.sig); } + + // If direct calling isn't allowed OR direct calling is allowed, but the plugin is no longer installed, + // revert. TBD if there's a better way to do this, e.g. deleting this storage or segmenting per + // installation ID. + if ( + !_storage.directCallData[msg.sender][msg.sig].allowed + || !_storage.plugins.contains(msg.sender) + ) { + revert DirectCallDisallowed(); + } + + FunctionReference[] storage hooks = _storage.directCallData[msg.sender][msg.sig].preValidationHooks; + + uint256 hookLen = hooks.length; + for (uint256 i = 0; i < hookLen; ++i) { + (address hookPlugin, uint8 hookFunctionId) = hooks[i].unpack(); + try IValidationHook(hookPlugin).preRuntimeValidationHook( + hookFunctionId, msg.sender, msg.value, msg.data, "" + ) + // forgefmt: disable-start + // solhint-disable-next-line no-empty-blocks + {} catch (bytes memory revertReason) { + // forgefmt: disable-end + revert PreRuntimeValidationHookFailed(hookPlugin, hookFunctionId, revertReason); + } + } } }