Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Post-only hooks #16

Merged
merged 7 commits into from
Nov 30, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions src/account/AccountLoupe.sol
Original file line number Diff line number Diff line change
Expand Up @@ -50,14 +50,15 @@ abstract contract AccountLoupe is IAccountLoupe {
AccountStorage storage _storage = getAccountStorage();

FunctionReference[] memory preExecHooks =
toFunctionReferenceArray(_storage.selectorData[selector].preExecHooks);
toFunctionReferenceArray(_storage.selectorData[selector].executionHooks.preHooks);

uint256 numHooks = preExecHooks.length;
execHooks = new ExecutionHooks[](numHooks);

for (uint256 i = 0; i < numHooks;) {
execHooks[i].preExecHook = preExecHooks[i];
execHooks[i].postExecHook = _storage.selectorData[selector].associatedPostExecHooks[preExecHooks[i]];
execHooks[i].postExecHook =
_storage.selectorData[selector].executionHooks.associatedPostHooks[preExecHooks[i]];

unchecked {
++i;
Expand All @@ -76,15 +77,15 @@ abstract contract AccountLoupe is IAccountLoupe {
bytes24 key = getPermittedCallKey(callingPlugin, selector);

FunctionReference[] memory prePermittedCallHooks =
toFunctionReferenceArray(_storage.permittedCalls[key].prePermittedCallHooks);
toFunctionReferenceArray(_storage.permittedCalls[key].permittedCallHooks.preHooks);

uint256 numHooks = prePermittedCallHooks.length;
execHooks = new ExecutionHooks[](numHooks);

for (uint256 i = 0; i < numHooks;) {
execHooks[i].preExecHook = prePermittedCallHooks[i];
execHooks[i].postExecHook =
_storage.permittedCalls[key].associatedPostPermittedCallHooks[prePermittedCallHooks[i]];
_storage.permittedCalls[key].permittedCallHooks.associatedPostHooks[prePermittedCallHooks[i]];

unchecked {
++i;
Expand Down
87 changes: 51 additions & 36 deletions src/account/PluginManagerInternals.sol
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import {
SelectorData,
PermittedCallData,
getPermittedCallKey,
HookGroup,
PermittedExternalCallData,
StoredInjectedHook
} from "../libraries/AccountStorage.sol";
Expand Down Expand Up @@ -129,34 +130,18 @@ abstract contract PluginManagerInternals is IPluginManager {

function _addExecHooks(bytes4 selector, FunctionReference preExecHook, FunctionReference postExecHook)
internal
notNullFunction(preExecHook)
{
SelectorData storage _selectorData = getAccountStorage().selectorData[selector];

if (!_selectorData.preExecHooks.add(_toSetValue(preExecHook))) {
// Treat the pre-exec and post-exec hook as a single unit, identified by the pre-exec hook.
// If the pre-exec hook exists, revert.
revert ExecutionHookAlreadySet(selector, preExecHook);
}

if (postExecHook != FunctionReferenceLib._EMPTY_FUNCTION_REFERENCE) {
_selectorData.associatedPostExecHooks[preExecHook] = postExecHook;
}
_addHooks(_selectorData.executionHooks, selector, preExecHook, postExecHook);
}

function _removeExecHooks(bytes4 selector, FunctionReference preExecHook, FunctionReference postExecHook)
internal
notNullFunction(preExecHook)
{
SelectorData storage _selectorData = getAccountStorage().selectorData[selector];

// May ignore return value, as the manifest hash is validated to ensure that the hook exists.
_selectorData.preExecHooks.remove(_toSetValue(preExecHook));

// If the post exec hook is set, clear it.
if (postExecHook != FunctionReferenceLib._EMPTY_FUNCTION_REFERENCE) {
_selectorData.associatedPostExecHooks[preExecHook] = FunctionReferenceLib._EMPTY_FUNCTION_REFERENCE;
}
_removeHooks(_selectorData.executionHooks, preExecHook, postExecHook);
}

function _enableExecFromPlugin(bytes4 selector, address plugin, AccountStorage storage accountStorage)
Expand All @@ -181,37 +166,67 @@ abstract contract PluginManagerInternals is IPluginManager {
address plugin,
FunctionReference preExecHook,
FunctionReference postExecHook
) internal notNullPlugin(plugin) notNullFunction(preExecHook) {
) internal notNullPlugin(plugin) {
bytes24 permittedCallKey = getPermittedCallKey(plugin, selector);
PermittedCallData storage _permittedCalldata = getAccountStorage().permittedCalls[permittedCallKey];

if (!_permittedCalldata.prePermittedCallHooks.add(_toSetValue(preExecHook))) {
// Treat the pre-exec and post-exec hook as a single unit, identified by the pre-exec hook.
// If the pre-exec hook exists, revert.
revert PermittedCallHookAlreadySet(selector, plugin, preExecHook);
}

if (postExecHook != FunctionReferenceLib._EMPTY_FUNCTION_REFERENCE) {
_permittedCalldata.associatedPostPermittedCallHooks[preExecHook] = postExecHook;
}
_addHooks(_permittedCalldata.permittedCallHooks, selector, preExecHook, postExecHook);
}

function _removePermittedCallHooks(
bytes4 selector,
address plugin,
FunctionReference preExecHook,
FunctionReference postExecHook
) internal notNullPlugin(plugin) notNullFunction(preExecHook) {
) internal notNullPlugin(plugin) {
bytes24 permittedCallKey = getPermittedCallKey(plugin, selector);
PermittedCallData storage _permittedCalldata = getAccountStorage().permittedCalls[permittedCallKey];
PermittedCallData storage _permittedCallData = getAccountStorage().permittedCalls[permittedCallKey];

// May ignore return value, as the manifest hash is validated to ensure that the hook exists.
_permittedCalldata.prePermittedCallHooks.remove(_toSetValue(preExecHook));
_removeHooks(_permittedCallData.permittedCallHooks, preExecHook, postExecHook);
}

function _addHooks(
HookGroup storage hooks,
bytes4 selector,
FunctionReference preExecHook,
FunctionReference postExecHook
) internal {
if (preExecHook != FunctionReferenceLib._EMPTY_FUNCTION_REFERENCE) {
// add pre or pre/post pair of exec hooks
if (!hooks.preHooks.add(_toSetValue(preExecHook))) {
revert ExecutionHookAlreadySet(selector, preExecHook);
}

if (postExecHook != FunctionReferenceLib._EMPTY_FUNCTION_REFERENCE) {
hooks.associatedPostHooks[preExecHook] = postExecHook;
}
} else {
if (postExecHook == FunctionReferenceLib._EMPTY_FUNCTION_REFERENCE) {
// both pre and post hooks cannot be null
revert NullFunctionReference();
}

hooks.postOnlyHooks.add(_toSetValue(postExecHook));
}
}

function _removeHooks(HookGroup storage hooks, FunctionReference preExecHook, FunctionReference postExecHook)
internal
{
if (preExecHook != FunctionReferenceLib._EMPTY_FUNCTION_REFERENCE) {
// remove pre or pre/post pair of exec hooks

// May ignore return value, as the manifest hash is validated to ensure that the hook exists.
hooks.preHooks.remove(_toSetValue(preExecHook));

if (postExecHook != FunctionReferenceLib._EMPTY_FUNCTION_REFERENCE) {
hooks.associatedPostHooks[preExecHook] = FunctionReferenceLib._EMPTY_FUNCTION_REFERENCE;
}
} else {
// THe case where both pre and post hooks are null was checked during installation.

// If the post permitted call exec hook is set, clear it.
if (postExecHook != FunctionReferenceLib._EMPTY_FUNCTION_REFERENCE) {
_permittedCalldata.associatedPostPermittedCallHooks[preExecHook] =
FunctionReferenceLib._EMPTY_FUNCTION_REFERENCE;
// May ignore return value, as the manifest hash is validated to ensure that the hook exists.
hooks.postOnlyHooks.remove(_toSetValue(postExecHook));
}
}

Expand Down
104 changes: 41 additions & 63 deletions src/account/UpgradeableModularAccount.sol
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import {UUPSUpgradeable} from "@openzeppelin/contracts/proxy/utils/UUPSUpgradeab

import {AccountExecutor} from "./AccountExecutor.sol";
import {AccountLoupe} from "./AccountLoupe.sol";
import {AccountStorage, getAccountStorage, getPermittedCallKey} from "../libraries/AccountStorage.sol";
import {AccountStorage, HookGroup, getAccountStorage, getPermittedCallKey} from "../libraries/AccountStorage.sol";
import {AccountStorageInitializable} from "./AccountStorageInitializable.sol";
import {FunctionReference, FunctionReferenceLib} from "../libraries/FunctionReferenceLib.sol";
import {IPlugin, PluginManifest} from "../interfaces/IPlugin.sol";
Expand Down Expand Up @@ -66,7 +66,7 @@ contract UpgradeableModularAccount is
modifier wrapNativeFunction() {
_doRuntimeValidationIfNotFromEP();

PostExecToRun[] memory postExecHooks = _doPreExecHooks(msg.sig);
PostExecToRun[] memory postExecHooks = _doPreExecHooks(msg.sig, msg.data);

_;

Expand Down Expand Up @@ -127,7 +127,7 @@ contract UpgradeableModularAccount is

PostExecToRun[] memory postExecHooks;
// Cache post-exec hooks in memory
postExecHooks = _doPreExecHooks(msg.sig);
postExecHooks = _doPreExecHooks(msg.sig, msg.data);

// execute the function, bubbling up any reverts
(bool execSuccess, bytes memory execReturnData) = execPlugin.call(msg.data);
Expand Down Expand Up @@ -188,15 +188,16 @@ contract UpgradeableModularAccount is
revert ExecFromPluginNotPermitted(callingPlugin, selector);
}

PostExecToRun[] memory postPermittedCallHooks = _doPrePermittedCallHooks(selector, callingPlugin);
PostExecToRun[] memory postPermittedCallHooks =
_doPrePermittedCallHooks(getPermittedCallKey(callingPlugin, selector), data);

address execFunctionPlugin = _storage.selectorData[selector].plugin;

if (execFunctionPlugin == address(0)) {
revert UnrecognizedFunction(selector);
}

PostExecToRun[] memory postExecHooks = _doPreExecHooks(selector);
PostExecToRun[] memory postExecHooks = _doPreExecHooks(selector, data);

(bool success, bytes memory returnData) = execFunctionPlugin.call(data);

Expand Down Expand Up @@ -250,11 +251,13 @@ contract UpgradeableModularAccount is

// Run any pre plugin exec specific to this caller and the `executeFromPluginExternal` selector

PostExecToRun[] memory postPermittedCallHooks =
_doPrePermittedCallHooks(IPluginExecutor.executeFromPluginExternal.selector, msg.sender);
PostExecToRun[] memory postPermittedCallHooks = _doPrePermittedCallHooks(
getPermittedCallKey(msg.sender, IPluginExecutor.executeFromPluginExternal.selector), msg.data
);

// Run any pre exec hooks for this selector
PostExecToRun[] memory postExecHooks = _doPreExecHooks(IPluginExecutor.executeFromPluginExternal.selector);
PostExecToRun[] memory postExecHooks =
_doPreExecHooks(IPluginExecutor.executeFromPluginExternal.selector, msg.data);

// Perform the external call
bytes memory returnData = _exec(target, value, data);
Expand Down Expand Up @@ -476,68 +479,34 @@ contract UpgradeableModularAccount is
}
}

function _doPreExecHooks(bytes4 selector) internal returns (PostExecToRun[] memory postHooksToRun) {
EnumerableSet.Bytes32Set storage preExecHooks = getAccountStorage().selectorData[selector].preExecHooks;

uint256 postExecHooksLength = 0;
uint256 preExecHooksLength = preExecHooks.length();

// Over-allocate on length, but not all of this may get filled up.
postHooksToRun = new PostExecToRun[](preExecHooksLength);
for (uint256 i = 0; i < preExecHooksLength;) {
FunctionReference preExecHook = _toFunctionReference(preExecHooks.at(i));

if (preExecHook.isEmptyOrMagicValue()) {
if (preExecHook == FunctionReferenceLib._PRE_HOOK_ALWAYS_DENY) {
revert AlwaysDenyRule();
}
// Function reference cannot be 0. If _RUNTIME_VALIDATION_ALWAYS_ALLOW, revert since it's an
// invalid configuration.
revert InvalidConfiguration();
}

(address plugin, uint8 functionId) = preExecHook.unpack();
bytes memory preExecHookReturnData;
try IPlugin(plugin).preExecutionHook(functionId, msg.sender, msg.value, msg.data) returns (
bytes memory returnData
) {
preExecHookReturnData = returnData;
} catch (bytes memory revertReason) {
revert PreExecHookReverted(plugin, functionId, revertReason);
}

// Check to see if there is a postExec hook set for this preExec hook
FunctionReference postExecHook =
getAccountStorage().selectorData[selector].associatedPostExecHooks[preExecHook];
if (postExecHook != FunctionReferenceLib._EMPTY_FUNCTION_REFERENCE) {
postHooksToRun[postExecHooksLength].postExecHook = postExecHook;
postHooksToRun[postExecHooksLength].preExecHookReturnData = preExecHookReturnData;
unchecked {
++postExecHooksLength;
}
}
function _doPreExecHooks(bytes4 selector, bytes calldata data)
internal
returns (PostExecToRun[] memory postHooksToRun)
{
HookGroup storage hooks = getAccountStorage().selectorData[selector].executionHooks;

unchecked {
++i;
}
}
return _doPreHooks(hooks, data);
}

function _doPrePermittedCallHooks(bytes4 executionSelector, address callerPlugin)
function _doPrePermittedCallHooks(bytes24 permittedCallKey, bytes calldata data)
internal
returns (PostExecToRun[] memory postHooksToRun)
{
bytes24 permittedCallKey = getPermittedCallKey(callerPlugin, executionSelector);
HookGroup storage hooks = getAccountStorage().permittedCalls[permittedCallKey].permittedCallHooks;

EnumerableSet.Bytes32Set storage preExecHooks =
getAccountStorage().permittedCalls[permittedCallKey].prePermittedCallHooks;
return _doPreHooks(hooks, data);
}

function _doPreHooks(HookGroup storage hooks, bytes calldata data)
internal
returns (PostExecToRun[] memory postHooksToRun)
{
uint256 postExecHooksLength = 0;
uint256 preExecHooksLength = preExecHooks.length();
postHooksToRun = new PostExecToRun[](preExecHooksLength); // Over-allocate on length, but not all of this
// may get filled up.
uint256 preExecHooksLength = hooks.preHooks.length();
// Over-allocate on length, but not all of this may get filled up.
postHooksToRun = new PostExecToRun[](preExecHooksLength + hooks.postOnlyHooks.length());
for (uint256 i = 0; i < preExecHooksLength;) {
FunctionReference preExecHook = _toFunctionReference(preExecHooks.at(i));
FunctionReference preExecHook = _toFunctionReference(hooks.preHooks.at(i));

if (preExecHook.isEmptyOrMagicValue()) {
if (preExecHook == FunctionReferenceLib._PRE_HOOK_ALWAYS_DENY) {
Expand All @@ -550,7 +519,7 @@ contract UpgradeableModularAccount is

(address plugin, uint8 functionId) = preExecHook.unpack();
bytes memory preExecHookReturnData;
try IPlugin(plugin).preExecutionHook(functionId, msg.sender, msg.value, msg.data) returns (
try IPlugin(plugin).preExecutionHook(functionId, msg.sender, msg.value, data) returns (
bytes memory returnData
) {
preExecHookReturnData = returnData;
Expand All @@ -559,8 +528,7 @@ contract UpgradeableModularAccount is
}

// Check to see if there is a postExec hook set for this preExec hook
FunctionReference postExecHook =
getAccountStorage().permittedCalls[permittedCallKey].associatedPostPermittedCallHooks[preExecHook];
FunctionReference postExecHook = hooks.associatedPostHooks[preExecHook];
if (FunctionReference.unwrap(postExecHook) != 0) {
postHooksToRun[postExecHooksLength].postExecHook = postExecHook;
postHooksToRun[postExecHooksLength].preExecHookReturnData = preExecHookReturnData;
Expand All @@ -573,6 +541,16 @@ contract UpgradeableModularAccount is
++i;
}
}

// Copy post-only hooks to the end of the array
uint256 postOnlyHooksLength = hooks.postOnlyHooks.length();
for (uint256 i = 0; i < postOnlyHooksLength;) {
postHooksToRun[postExecHooksLength].postExecHook = _toFunctionReference(hooks.postOnlyHooks.at(i));
unchecked {
++postExecHooksLength;
++i;
}
}
}

function _doCachedPostExecHooks(PostExecToRun[] memory postHooksToRun) internal {
Expand Down
Loading