Skip to content

Commit

Permalink
Allow direct plugin calls with validation & permission hooks (#90)
Browse files Browse the repository at this point in the history
Zer0dot authored and adamegyed committed Aug 5, 2024
1 parent a1377b3 commit 44f80f4
Showing 9 changed files with 315 additions and 93 deletions.
36 changes: 18 additions & 18 deletions .solhint-test.json
Original file line number Diff line number Diff line change
@@ -1,20 +1,20 @@
{
"extends": "solhint:recommended",
"rules": {
"func-name-mixedcase": "off",
"immutable-vars-naming": ["error"],
"no-unused-import": ["error"],
"compiler-version": ["error", ">=0.8.19"],
"custom-errors": "off",
"func-visibility": ["error", { "ignoreConstructors": true }],
"max-line-length": ["error", 120],
"max-states-count": ["warn", 30],
"modifier-name-mixedcase": ["error"],
"private-vars-leading-underscore": ["error"],
"no-inline-assembly": "off",
"avoid-low-level-calls": "off",
"one-contract-per-file": "off",
"no-empty-blocks": "off"
}
"extends": "solhint:recommended",
"rules": {
"func-name-mixedcase": "off",
"immutable-vars-naming": ["error"],
"no-unused-import": ["error"],
"compiler-version": ["error", ">=0.8.19"],
"custom-errors": "off",
"func-visibility": ["error", { "ignoreConstructors": true }],
"max-line-length": ["error", 120],
"max-states-count": ["warn", 30],
"modifier-name-mixedcase": ["error"],
"private-vars-leading-underscore": ["error"],
"no-inline-assembly": "off",
"avoid-low-level-calls": "off",
"one-contract-per-file": "off",
"no-empty-blocks": "off",
"reason-string": ["warn", { "maxLength": 64 }]
}
}
2 changes: 1 addition & 1 deletion src/account/AccountStorage.sol
Original file line number Diff line number Diff line change
@@ -30,7 +30,7 @@ struct ValidationData {
bool isGlobal;
// Whether or not this validation is a signature validator.
bool isSignatureValidation;
// The pre validation hooks for this function selector.
// The pre validation hooks for this validation function.
PluginEntity[] preValidationHooks;
// Permission hooks for this validation function.
EnumerableSet.Bytes32Set permissionHooks;
35 changes: 19 additions & 16 deletions src/account/PluginManager2.sol
Original file line number Diff line number Diff line change
@@ -7,7 +7,7 @@ import {IPlugin} from "../interfaces/IPlugin.sol";
import {PluginEntity, ValidationConfig} from "../interfaces/IPluginManager.sol";
import {PluginEntityLib} from "../helpers/PluginEntityLib.sol";
import {ValidationConfigLib} from "../helpers/ValidationConfigLib.sol";
import {ValidationData, getAccountStorage, toSetValue, toPluginEntity} from "./AccountStorage.sol";
import {ValidationData, getAccountStorage, toSetValue} from "./AccountStorage.sol";
import {ExecutionHook} from "../interfaces/IAccountLoupe.sol";

// Temporary additional functions for a user-controlled install flow for validation functions.
@@ -17,6 +17,7 @@ abstract contract PluginManager2 {

// Index marking the start of the data for the validation function.
uint8 internal constant _RESERVED_VALIDATION_DATA_INDEX = 255;
uint32 internal constant _SELF_PERMIT_VALIDATION_FUNCTIONID = type(uint32).max;

error PreValidationAlreadySet(PluginEntity validationFunction, PluginEntity preValidationFunction);
error ValidationAlreadySet(bytes4 selector, PluginEntity validationFunction);
@@ -32,7 +33,7 @@ abstract contract PluginManager2 {
bytes memory permissionHooks
) internal {
ValidationData storage _validationData =
getAccountStorage().validationData[validationConfig.functionReference()];
getAccountStorage().validationData[validationConfig.pluginEntity()];

if (preValidationHooks.length > 0) {
(PluginEntity[] memory preValidationFunctions, bytes[] memory initDatas) =
@@ -63,7 +64,7 @@ abstract contract PluginManager2 {
ExecutionHook memory permissionFunction = permissionFunctions[i];

if (!_validationData.permissionHooks.add(toSetValue(permissionFunction))) {
revert PermissionAlreadySet(validationConfig.functionReference(), permissionFunction);
revert PermissionAlreadySet(validationConfig.pluginEntity(), permissionFunction);
}

if (initDatas[i].length > 0) {
@@ -73,19 +74,21 @@ abstract contract PluginManager2 {
}
}

_validationData.isGlobal = validationConfig.isGlobal();
_validationData.isSignatureValidation = validationConfig.isSignatureValidation();

for (uint256 i = 0; i < selectors.length; ++i) {
bytes4 selector = selectors[i];
if (!_validationData.selectors.add(toSetValue(selector))) {
revert ValidationAlreadySet(selector, validationConfig.functionReference());
revert ValidationAlreadySet(selector, validationConfig.pluginEntity());
}
}

if (installData.length > 0) {
address plugin = validationConfig.plugin();
IPlugin(plugin).onInstall(installData);
if (validationConfig.entityId() != _SELF_PERMIT_VALIDATION_FUNCTIONID) {
// Only allow global validations and signature validations if they're not direct-call validations.

_validationData.isGlobal = validationConfig.isGlobal();
_validationData.isSignatureValidation = validationConfig.isSignatureValidation();
if (installData.length > 0) {
IPlugin(validationConfig.plugin()).onInstall(installData);
}
}
}

@@ -120,12 +123,12 @@ abstract contract PluginManager2 {

// Clear permission hooks
EnumerableSet.Bytes32Set storage permissionHooks = _validationData.permissionHooks;
uint256 i = 0;
while (permissionHooks.length() > 0) {
PluginEntity permissionHook = toPluginEntity(permissionHooks.at(0));
permissionHooks.remove(toSetValue(permissionHook));
(address permissionHookPlugin,) = PluginEntityLib.unpack(permissionHook);
IPlugin(permissionHookPlugin).onUninstall(permissionHookUninstallDatas[i++]);
uint256 len = permissionHooks.length();
for (uint256 i = 0; i < len; ++i) {
bytes32 permissionHook = permissionHooks.at(0);
permissionHooks.remove(permissionHook);
address permissionHookPlugin = address(uint160(bytes20(permissionHook)));
IPlugin(permissionHookPlugin).onUninstall(permissionHookUninstallDatas[i]);
}
}
delete _validationData.preValidationHooks;
131 changes: 89 additions & 42 deletions src/account/UpgradeableModularAccount.sol
Original file line number Diff line number Diff line change
@@ -78,22 +78,21 @@ contract UpgradeableModularAccount is
error SignatureValidationInvalid(address plugin, uint32 entityId);
error UnexpectedAggregator(address plugin, uint32 entityId, address aggregator);
error UnrecognizedFunction(bytes4 selector);
error UserOpValidationFunctionMissing(bytes4 selector);
error ValidationFunctionMissing(bytes4 selector);
error ValidationDoesNotApply(bytes4 selector, address plugin, uint32 entityId, bool isGlobal);
error ValidationSignatureSegmentMissing();
error SignatureSegmentOutOfOrder();

// Wraps execution of a native function with runtime validation and hooks
// Used for upgradeTo, upgradeToAndCall, execute, executeBatch, installPlugin, uninstallPlugin
modifier wrapNativeFunction() {
_checkPermittedCallerIfNotFromEP();

PostExecToRun[] memory postExecHooks =
_doPreHooks(getAccountStorage().selectorData[msg.sig].executionHooks, msg.data);
(PostExecToRun[] memory postPermissionHooks, PostExecToRun[] memory postExecHooks) =
_checkPermittedCallerAndAssociatedHooks();

_;

_doCachedPostExecHooks(postExecHooks);
_doCachedPostExecHooks(postPermissionHooks);
}

constructor(IEntryPoint anEntryPoint) {
@@ -136,7 +135,7 @@ contract UpgradeableModularAccount is
revert UnrecognizedFunction(msg.sig);
}

_checkPermittedCallerIfNotFromEP();
_checkPermittedCallerAndAssociatedHooks();

PostExecToRun[] memory postExecHooks;
// Cache post-exec hooks in memory
@@ -500,17 +499,7 @@ contract UpgradeableModularAccount is
} else {
currentAuthData = "";
}

(address hookPlugin, uint32 hookEntityId) = preRuntimeValidationHooks[i].unpack();
try IValidationHook(hookPlugin).preRuntimeValidationHook(
hookEntityId, msg.sender, msg.value, callData, currentAuthData
)
// forgefmt: disable-start
// solhint-disable-next-line no-empty-blocks
{} catch (bytes memory revertReason) {
// forgefmt: disable-end
revert PreRuntimeValidationHookFailed(hookPlugin, hookEntityId, revertReason);
}
_doPreRuntimeValidationHook(preRuntimeValidationHooks[i], callData, currentAuthData);
}

if (authSegment.getIndex() != _RESERVED_VALIDATION_DATA_INDEX) {
@@ -605,9 +594,78 @@ contract UpgradeableModularAccount is
}
}

function _doPreRuntimeValidationHook(
PluginEntity validationHook,
bytes memory callData,
bytes memory currentAuthData
) internal {
(address hookPlugin, uint32 hookEntityId) = validationHook.unpack();
try IValidationHook(hookPlugin).preRuntimeValidationHook(
hookEntityId, msg.sender, msg.value, callData, currentAuthData
)
// forgefmt: disable-start
// solhint-disable-next-line no-empty-blocks
{} catch (bytes memory revertReason) {
// forgefmt: disable-end
revert PreRuntimeValidationHookFailed(hookPlugin, hookEntityId, revertReason);
}
}

// solhint-disable-next-line no-empty-blocks
function _authorizeUpgrade(address newImplementation) internal override {}

/**
* Order of operations:
* 1. Check if the sender is the entry point, the account itself, or the selector called is public.
* - Yes: Return an empty array, there are no post-permissionHooks.
* - No: Continue
* 2. Check if the called selector (msg.sig) is included in the set of selectors the msg.sender can
* directly call.
* - Yes: Continue
* - No: Revert, the caller is not allowed to call this selector
* 3. If there are runtime validation hooks associated with this caller-sig combination, run them.
* 4. Run the pre-permissionHooks associated with this caller-sig combination, and return the
* post-permissionHooks to run later.
*/
function _checkPermittedCallerAndAssociatedHooks()
internal
returns (PostExecToRun[] memory, PostExecToRun[] memory)
{
AccountStorage storage _storage = getAccountStorage();

if (
msg.sender == address(_ENTRY_POINT) || msg.sender == address(this)
|| _storage.selectorData[msg.sig].isPublic
) {
return (new PostExecToRun[](0), new PostExecToRun[](0));
}

PluginEntity directCallValidationKey = PluginEntityLib.pack(msg.sender, _SELF_PERMIT_VALIDATION_FUNCTIONID);

_checkIfValidationAppliesCallData(msg.data, directCallValidationKey, false);

// Direct call is allowed, run associated permission & validation hooks

// Validation hooks
PluginEntity[] memory preRuntimeValidationHooks =
_storage.validationData[directCallValidationKey].preValidationHooks;

uint256 hookLen = preRuntimeValidationHooks.length;
for (uint256 i = 0; i < hookLen; ++i) {
_doPreRuntimeValidationHook(preRuntimeValidationHooks[i], msg.data, "");
}

// Permission hooks
PostExecToRun[] memory postPermissionHooks =
_doPreHooks(_storage.validationData[directCallValidationKey].permissionHooks, msg.data);

// Exec hooks
PostExecToRun[] memory postExecutionHooks =
_doPreHooks(_storage.selectorData[msg.sig].executionHooks, msg.data);

return (postPermissionHooks, postExecutionHooks);
}

function _checkIfValidationAppliesCallData(
bytes calldata callData,
PluginEntity validationFunction,
@@ -661,25 +719,6 @@ contract UpgradeableModularAccount is
}
}

function _checkIfValidationAppliesSelector(bytes4 selector, PluginEntity validationFunction, bool isGlobal)
internal
view
{
AccountStorage storage _storage = getAccountStorage();

// Check that the provided validation function is applicable to the selector
if (isGlobal) {
if (!_globalValidationAllowed(selector) || !_storage.validationData[validationFunction].isGlobal) {
revert UserOpValidationFunctionMissing(selector);
}
} else {
// Not global validation, but per-selector
if (!getAccountStorage().validationData[validationFunction].selectors.contains(toSetValue(selector))) {
revert UserOpValidationFunctionMissing(selector);
}
}
}

function _globalValidationAllowed(bytes4 selector) internal view returns (bool) {
if (
selector == this.execute.selector || selector == this.executeBatch.selector
@@ -693,14 +732,22 @@ contract UpgradeableModularAccount is
return getAccountStorage().selectorData[selector].allowGlobalValidation;
}

function _checkPermittedCallerIfNotFromEP() internal view {
function _checkIfValidationAppliesSelector(bytes4 selector, PluginEntity validationFunction, bool isGlobal)
internal
view
{
AccountStorage storage _storage = getAccountStorage();

if (
msg.sender != address(_ENTRY_POINT) && msg.sender != address(this)
&& !_storage.selectorData[msg.sig].isPublic
) {
revert ExecFromPluginNotPermitted(msg.sender, msg.sig);
// Check that the provided validation function is applicable to the selector
if (isGlobal) {
if (!_globalValidationAllowed(selector) || !_storage.validationData[validationFunction].isGlobal) {
revert ValidationFunctionMissing(selector);
}
} else {
// Not global validation, but per-selector
if (!getAccountStorage().validationData[validationFunction].selectors.contains(toSetValue(selector))) {
revert ValidationFunctionMissing(selector);
}
}
}
}
2 changes: 1 addition & 1 deletion src/helpers/ValidationConfigLib.sol
Original file line number Diff line number Diff line change
@@ -78,7 +78,7 @@ library ValidationConfigLib {
return uint32(bytes4(ValidationConfig.unwrap(config) << 160));
}

function functionReference(ValidationConfig config) internal pure returns (PluginEntity) {
function pluginEntity(ValidationConfig config) internal pure returns (PluginEntity) {
return PluginEntity.wrap(bytes24(ValidationConfig.unwrap(config)));
}

133 changes: 133 additions & 0 deletions test/account/DirectCallsFromPlugin.t.sol
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
pragma solidity ^0.8.19;

import {DirectCallPlugin} from "../mocks/plugins/DirectCallPlugin.sol";
import {ExecutionHook} from "../../src/interfaces/IAccountLoupe.sol";
import {IStandardExecutor, Call} from "../../src/interfaces/IStandardExecutor.sol";
import {PluginEntityLib, PluginEntity} from "../../src/helpers/PluginEntityLib.sol";
import {ValidationConfig, ValidationConfigLib} from "../../src/helpers/ValidationConfigLib.sol";
import {UpgradeableModularAccount} from "../../src/account/UpgradeableModularAccount.sol";

import {AccountTestBase} from "../utils/AccountTestBase.sol";

contract DirectCallsFromPluginTest is AccountTestBase {
using ValidationConfigLib for ValidationConfig;

DirectCallPlugin internal _plugin;
PluginEntity internal _pluginEntity;

function setUp() public {
_plugin = new DirectCallPlugin();
assertFalse(_plugin.preHookRan());
assertFalse(_plugin.postHookRan());
_pluginEntity = PluginEntityLib.pack(address(_plugin), type(uint32).max);
}

/* -------------------------------------------------------------------------- */
/* Negatives */
/* -------------------------------------------------------------------------- */

function test_Fail_DirectCallPluginNotInstalled() external {
vm.prank(address(_plugin));
vm.expectRevert(_buildDirectCallDisallowedError(IStandardExecutor.execute.selector));
account1.execute(address(0), 0, "");
}

function test_Fail_DirectCallPluginUninstalled() external {
_installPlugin();

_uninstallPlugin();

vm.prank(address(_plugin));
vm.expectRevert(_buildDirectCallDisallowedError(IStandardExecutor.execute.selector));
account1.execute(address(0), 0, "");
}

function test_Fail_DirectCallPluginCallOtherSelector() external {
_installPlugin();

Call[] memory calls = new Call[](0);

vm.prank(address(_plugin));
vm.expectRevert(_buildDirectCallDisallowedError(IStandardExecutor.executeBatch.selector));
account1.executeBatch(calls);
}

/* -------------------------------------------------------------------------- */
/* Positives */
/* -------------------------------------------------------------------------- */

function test_Pass_DirectCallFromPluginPrank() external {
_installPlugin();

vm.prank(address(_plugin));
account1.execute(address(0), 0, "");

assertTrue(_plugin.preHookRan());
assertTrue(_plugin.postHookRan());
}

function test_Pass_DirectCallFromPluginCallback() external {
_installPlugin();

bytes memory encodedCall = abi.encodeCall(DirectCallPlugin.directCall, ());

vm.prank(address(entryPoint));
bytes memory result = account1.execute(address(_plugin), 0, encodedCall);

assertTrue(_plugin.preHookRan());
assertTrue(_plugin.postHookRan());

// the directCall() function in the _plugin calls back into `execute()` with an encoded call back into the
// _plugin's getData() function.
assertEq(abi.decode(result, (bytes)), abi.encode(_plugin.getData()));
}

function test_Flow_DirectCallFromPluginSequence() external {
// Install => Succeesfully call => uninstall => fail to call

_installPlugin();

vm.prank(address(_plugin));
account1.execute(address(0), 0, "");

assertTrue(_plugin.preHookRan());
assertTrue(_plugin.postHookRan());

_uninstallPlugin();

vm.prank(address(_plugin));
vm.expectRevert(_buildDirectCallDisallowedError(IStandardExecutor.execute.selector));
account1.execute(address(0), 0, "");
}

/* -------------------------------------------------------------------------- */
/* Internals */
/* -------------------------------------------------------------------------- */

function _installPlugin() internal {
bytes4[] memory selectors = new bytes4[](1);
selectors[0] = IStandardExecutor.execute.selector;

ExecutionHook[] memory permissionHooks = new ExecutionHook[](1);
bytes[] memory permissionHookInitDatas = new bytes[](1);

permissionHooks[0] = ExecutionHook({hookFunction: _pluginEntity, isPreHook: true, isPostHook: true});

bytes memory encodedPermissionHooks = abi.encode(permissionHooks, permissionHookInitDatas);

vm.prank(address(entryPoint));

ValidationConfig validationConfig = ValidationConfigLib.pack(_pluginEntity, false, false);

account1.installValidation(validationConfig, selectors, "", "", encodedPermissionHooks);
}

function _uninstallPlugin() internal {
vm.prank(address(entryPoint));
account1.uninstallValidation(_pluginEntity, "", abi.encode(new bytes[](0)), abi.encode(new bytes[](1)));
}

function _buildDirectCallDisallowedError(bytes4 selector) internal pure returns (bytes memory) {
return abi.encodeWithSelector(UpgradeableModularAccount.ValidationFunctionMissing.selector, selector);
}
}
4 changes: 1 addition & 3 deletions test/account/PermittedCallPermissions.t.sol
Original file line number Diff line number Diff line change
@@ -48,9 +48,7 @@ contract PermittedCallPermissionsTest is AccountTestBase {
function test_permittedCall_NotAllowed() public {
vm.expectRevert(
abi.encodeWithSelector(
UpgradeableModularAccount.ExecFromPluginNotPermitted.selector,
address(permittedCallerPlugin),
ResultCreatorPlugin.bar.selector
UpgradeableModularAccount.ValidationFunctionMissing.selector, ResultCreatorPlugin.bar.selector
)
);
PermittedCallerPlugin(address(account1)).usePermittedCallNotAllowed();
18 changes: 6 additions & 12 deletions test/account/SelfCallAuthorization.t.sol
Original file line number Diff line number Diff line change
@@ -40,8 +40,7 @@ contract SelfCallAuthorizationTest is AccountTestBase {
0,
"AA23 reverted",
abi.encodeWithSelector(
UpgradeableModularAccount.UserOpValidationFunctionMissing.selector,
ComprehensivePlugin.foo.selector
UpgradeableModularAccount.ValidationFunctionMissing.selector, ComprehensivePlugin.foo.selector
)
)
);
@@ -56,8 +55,7 @@ contract SelfCallAuthorizationTest is AccountTestBase {
0,
"AA23 reverted",
abi.encodeWithSelector(
UpgradeableModularAccount.UserOpValidationFunctionMissing.selector,
ComprehensivePlugin.foo.selector
UpgradeableModularAccount.ValidationFunctionMissing.selector, ComprehensivePlugin.foo.selector
)
)
);
@@ -68,8 +66,7 @@ contract SelfCallAuthorizationTest is AccountTestBase {
_runtimeCall(
abi.encodeCall(ComprehensivePlugin.foo, ()),
abi.encodeWithSelector(
UpgradeableModularAccount.UserOpValidationFunctionMissing.selector,
ComprehensivePlugin.foo.selector
UpgradeableModularAccount.ValidationFunctionMissing.selector, ComprehensivePlugin.foo.selector
)
);
}
@@ -99,8 +96,7 @@ contract SelfCallAuthorizationTest is AccountTestBase {
0,
"AA23 reverted",
abi.encodeWithSelector(
UpgradeableModularAccount.UserOpValidationFunctionMissing.selector,
ComprehensivePlugin.foo.selector
UpgradeableModularAccount.ValidationFunctionMissing.selector, ComprehensivePlugin.foo.selector
)
)
);
@@ -136,8 +132,7 @@ contract SelfCallAuthorizationTest is AccountTestBase {
0,
"AA23 reverted",
abi.encodeWithSelector(
UpgradeableModularAccount.UserOpValidationFunctionMissing.selector,
ComprehensivePlugin.foo.selector
UpgradeableModularAccount.ValidationFunctionMissing.selector, ComprehensivePlugin.foo.selector
)
)
);
@@ -159,8 +154,7 @@ contract SelfCallAuthorizationTest is AccountTestBase {
_runtimeExecBatchExpFail(
calls,
abi.encodeWithSelector(
UpgradeableModularAccount.UserOpValidationFunctionMissing.selector,
ComprehensivePlugin.foo.selector
UpgradeableModularAccount.ValidationFunctionMissing.selector, ComprehensivePlugin.foo.selector
)
);
}
47 changes: 47 additions & 0 deletions test/mocks/plugins/DirectCallPlugin.sol
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
// SPDX-License-Identifier: UNLICENSED
pragma solidity ^0.8.19;

import {PluginManifest, PluginMetadata} from "../../../src/interfaces/IPlugin.sol";
import {IStandardExecutor} from "../../../src/interfaces/IStandardExecutor.sol";
import {IExecutionHook} from "../../../src/interfaces/IExecutionHook.sol";

import {BasePlugin} from "../../../src/plugins/BasePlugin.sol";

contract DirectCallPlugin is BasePlugin, IExecutionHook {
bool public preHookRan = false;
bool public postHookRan = false;

function onInstall(bytes calldata) external override {}

function onUninstall(bytes calldata) external override {}

function pluginManifest() external pure override returns (PluginManifest memory) {}

function directCall() external returns (bytes memory) {
return IStandardExecutor(msg.sender).execute(address(this), 0, abi.encodeCall(this.getData, ()));
}

function getData() external pure returns (bytes memory) {
return hex"04546b";
}

function pluginMetadata() external pure override returns (PluginMetadata memory) {}

function preExecutionHook(uint32, address sender, uint256, bytes calldata)
external
override
returns (bytes memory)
{
require(sender == address(this), "mock direct call pre permission hook failed");
preHookRan = true;
return abi.encode(keccak256(hex"04546b"));
}

function postExecutionHook(uint32, bytes calldata preExecHookData) external override {
require(
abi.decode(preExecHookData, (bytes32)) == keccak256(hex"04546b"),
"mock direct call post permission hook failed"
);
postHookRan = true;
}
}

0 comments on commit 44f80f4

Please sign in to comment.