From fe5af61741df5ff3741704710590b13c4cedfc25 Mon Sep 17 00:00:00 2001 From: zer0dot Date: Wed, 10 Jul 2024 20:22:33 +0800 Subject: [PATCH] feat: initial impl of simple plugin direct calls with validation hooks --- src/account/AccountStorage.sol | 6 +++++ src/account/UpgradeableModularAccount.sol | 31 ++++++++++++++++++++++- 2 files changed, 36 insertions(+), 1 deletion(-) diff --git a/src/account/AccountStorage.sol b/src/account/AccountStorage.sol index 78c06259..1c3d426a 100644 --- a/src/account/AccountStorage.sol +++ b/src/account/AccountStorage.sol @@ -40,6 +40,11 @@ struct ValidationData { EnumerableSet.Bytes32Set selectors; } +struct DirectCallValidationData { + bool allowed; // Whether or not this direct call is allowed. + FunctionReference[] preValidationHooks; // The set of pre validation hooks for this direct call. +} + struct AccountStorage { // AccountStorageInitializable variables uint8 initialized; @@ -51,6 +56,7 @@ struct AccountStorage { mapping(FunctionReference validationFunction => ValidationData) validationData; // For ERC165 introspection mapping(bytes4 => uint256) supportedIfaces; + mapping(address caller => mapping(bytes4 selector => DirectCallValidationData)) directCallData; } function getAccountStorage() pure returns (AccountStorage storage _storage) { diff --git a/src/account/UpgradeableModularAccount.sol b/src/account/UpgradeableModularAccount.sol index 42db6967..7c1180ed 100644 --- a/src/account/UpgradeableModularAccount.sol +++ b/src/account/UpgradeableModularAccount.sol @@ -9,6 +9,7 @@ import {UUPSUpgradeable} from "@openzeppelin/contracts/proxy/utils/UUPSUpgradeab import {IERC165} from "@openzeppelin/contracts/utils/introspection/IERC165.sol"; import {IERC1271} from "@openzeppelin/contracts/interfaces/IERC1271.sol"; import {EnumerableSet} from "@openzeppelin/contracts/utils/structs/EnumerableSet.sol"; +import {EnumerableMap} from "@openzeppelin/contracts/utils/structs/EnumerableMap.sol"; import {FunctionReferenceLib} from "../helpers/FunctionReferenceLib.sol"; import {SparseCalldataSegmentLib} from "../helpers/SparseCalldataSegmentLib.sol"; @@ -40,6 +41,7 @@ contract UpgradeableModularAccount is UUPSUpgradeable { using EnumerableSet for EnumerableSet.Bytes32Set; + using EnumerableSet for EnumerableSet.AddressSet; using FunctionReferenceLib for FunctionReference; using SparseCalldataSegmentLib for bytes; @@ -80,6 +82,7 @@ contract UpgradeableModularAccount is error ValidationDoesNotApply(bytes4 selector, address plugin, uint8 functionId, bool isGlobal); error ValidationSignatureSegmentMissing(); error SignatureSegmentOutOfOrder(); + error DirectCallDisallowed(); // Wraps execution of a native function with runtime validation and hooks // Used for upgradeTo, upgradeToAndCall, execute, executeBatch, installPlugin, uninstallPlugin @@ -696,7 +699,7 @@ contract UpgradeableModularAccount is return getAccountStorage().selectorData[selector].allowGlobalValidation; } - function _checkPermittedCallerIfNotFromEP() internal view { + function _checkPermittedCallerIfNotFromEP() internal { AccountStorage storage _storage = getAccountStorage(); if ( @@ -705,5 +708,31 @@ contract UpgradeableModularAccount is ) { revert ExecFromPluginNotPermitted(msg.sender, msg.sig); } + + // If direct calling isn't allowed OR direct calling is allowed, but the plugin is no longer installed, + // revert. TBD if there's a better way to do this, e.g. deleting this storage or segmenting per + // installation ID. + if ( + !_storage.directCallData[msg.sender][msg.sig].allowed + || !_storage.plugins.contains(msg.sender) + ) { + revert DirectCallDisallowed(); + } + + FunctionReference[] storage hooks = _storage.directCallData[msg.sender][msg.sig].preValidationHooks; + + uint256 hookLen = hooks.length; + for (uint256 i = 0; i < hookLen; ++i) { + (address hookPlugin, uint8 hookFunctionId) = hooks[i].unpack(); + try IValidationHook(hookPlugin).preRuntimeValidationHook( + hookFunctionId, msg.sender, msg.value, msg.data, "" + ) + // forgefmt: disable-start + // solhint-disable-next-line no-empty-blocks + {} catch (bytes memory revertReason) { + // forgefmt: disable-end + revert PreRuntimeValidationHookFailed(hookPlugin, hookFunctionId, revertReason); + } + } } }