Skip to content

Commit

Permalink
feat: add permission hooks to manifest and install flow and cache req…
Browse files Browse the repository at this point in the history
…uired context count
  • Loading branch information
howydev committed Jun 10, 2024
1 parent bbc7d6a commit b183d5a
Show file tree
Hide file tree
Showing 11 changed files with 432 additions and 72 deletions.
36 changes: 25 additions & 11 deletions src/account/AccountLoupe.sol
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,7 @@ import {EnumerableSet} from "@openzeppelin/contracts/utils/structs/EnumerableSet
import {IAccountLoupe, ExecutionHook} from "../interfaces/IAccountLoupe.sol";
import {FunctionReference, IPluginManager} from "../interfaces/IPluginManager.sol";
import {IStandardExecutor} from "../interfaces/IStandardExecutor.sol";
import {
AccountStorage,
getAccountStorage,
SelectorData,
toFunctionReferenceArray,
toExecutionHook
} from "./AccountStorage.sol";
import {AccountStorage, getAccountStorage, toFunctionReferenceArray, toExecutionHook} from "./AccountStorage.sol";

abstract contract AccountLoupe is IAccountLoupe {
using EnumerableSet for EnumerableSet.Bytes32Set;
Expand Down Expand Up @@ -47,15 +41,35 @@ abstract contract AccountLoupe is IAccountLoupe {
override
returns (ExecutionHook[] memory execHooks)
{
SelectorData storage selectorData = getAccountStorage().selectorData[selector];
uint256 executionHooksLength = selectorData.executionHooks.length();
EnumerableSet.Bytes32Set storage hooks = getAccountStorage().selectorData[selector].executionHooks;
uint256 executionHooksLength = hooks.length();

execHooks = new ExecutionHook[](executionHooksLength);

for (uint256 i = 0; i < executionHooksLength; ++i) {
bytes32 key = selectorData.executionHooks.at(i);
bytes32 key = hooks.at(i);
ExecutionHook memory execHook = execHooks[i];
(execHook.hookFunction, execHook.isPreHook, execHook.isPostHook) = toExecutionHook(key);
(execHook.hookFunction, execHook.isPreHook, execHook.isPostHook, execHook.requireUOContext) =
toExecutionHook(key);
}
}

/// @inheritdoc IAccountLoupe
function getPermissionHooks(FunctionReference validationFunction)
external
view
override
returns (ExecutionHook[] memory permissionHooks)
{
EnumerableSet.Bytes32Set storage hooks =
getAccountStorage().validationData[validationFunction].permissionHooks;
uint256 executionHooksLength = hooks.length();
permissionHooks = new ExecutionHook[](executionHooksLength);
for (uint256 i = 0; i < executionHooksLength; ++i) {
bytes32 key = hooks.at(i);
ExecutionHook memory execHook = permissionHooks[i];
(execHook.hookFunction, execHook.isPreHook, execHook.isPostHook, execHook.requireUOContext) =
toExecutionHook(key);
}
}

Expand Down
10 changes: 8 additions & 2 deletions src/account/AccountStorage.sol
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,12 @@ struct ValidationData {
bool isShared;
// Whether or not this validation is a signature validator.
bool isSignatureValidation;
// How many execution hooks require the UO context.
uint8 requireUOHookCount;
// The pre validation hooks for this function selector.
EnumerableSet.Bytes32Set preValidationHooks;
// Permission hooks for this validation function.
EnumerableSet.Bytes32Set permissionHooks;
}

struct AccountStorage {
Expand Down Expand Up @@ -83,16 +87,18 @@ function toFunctionReference(bytes32 setValue) pure returns (FunctionReference)
function toSetValue(ExecutionHook memory executionHook) pure returns (bytes32) {
return bytes32(FunctionReference.unwrap(executionHook.hookFunction))
| bytes32(executionHook.isPreHook ? uint256(1) << 80 : 0)
| bytes32(executionHook.isPostHook ? uint256(1) << 72 : 0);
| bytes32(executionHook.isPostHook ? uint256(1) << 72 : 0)
| bytes32(executionHook.requireUOContext ? uint256(1) << 64 : 0);
}

function toExecutionHook(bytes32 setValue)
pure
returns (FunctionReference hookFunction, bool isPreHook, bool isPostHook)
returns (FunctionReference hookFunction, bool isPreHook, bool isPostHook, bool requireUOContext)
{
hookFunction = FunctionReference.wrap(bytes21(setValue));
isPreHook = (uint256(setValue) >> 80) & 0xFF == 1;
isPostHook = (uint256(setValue) >> 72) & 0xFF == 1;
requireUOContext = (uint256(setValue) >> 64) & 0xFF == 1;
}

/// @dev Helper function to get all elements of a set into memory.
Expand Down
60 changes: 49 additions & 11 deletions src/account/PluginManagerInternals.sol
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import {FunctionReferenceLib} from "../helpers/FunctionReferenceLib.sol";
import {
IPlugin,
ManifestExecutionHook,
ManifestPermissionHook,
ManifestFunction,
ManifestAssociatedFunctionType,
ManifestAssociatedFunction,
Expand Down Expand Up @@ -100,27 +101,39 @@ abstract contract PluginManagerInternals is IPluginManager {
}

function _addExecHooks(
bytes4 selector,
EnumerableSet.Bytes32Set storage hooks,
FunctionReference hookFunction,
bool isPreExecHook,
bool isPostExecHook
bool isPostExecHook,
bool requireUOContext
) internal {
getAccountStorage().selectorData[selector].executionHooks.add(
hooks.add(
toSetValue(
ExecutionHook({hookFunction: hookFunction, isPreHook: isPreExecHook, isPostHook: isPostExecHook})
ExecutionHook({
hookFunction: hookFunction,
isPreHook: isPreExecHook,
isPostHook: isPostExecHook,
requireUOContext: requireUOContext
})
)
);
}

function _removeExecHooks(
bytes4 selector,
EnumerableSet.Bytes32Set storage hooks,
FunctionReference hookFunction,
bool isPreExecHook,
bool isPostExecHook
bool isPostExecHook,
bool requireUOContext
) internal {
getAccountStorage().selectorData[selector].executionHooks.remove(
hooks.remove(
toSetValue(
ExecutionHook({hookFunction: hookFunction, isPreHook: isPreExecHook, isPostHook: isPostExecHook})
ExecutionHook({
hookFunction: hookFunction,
isPreHook: isPreExecHook,
isPostHook: isPostExecHook,
requireUOContext: requireUOContext
})
)
);
}
Expand Down Expand Up @@ -205,8 +218,21 @@ abstract contract PluginManagerInternals is IPluginManager {
length = manifest.executionHooks.length;
for (uint256 i = 0; i < length; ++i) {
ManifestExecutionHook memory mh = manifest.executionHooks[i];
EnumerableSet.Bytes32Set storage execHooks = _storage.selectorData[mh.executionSelector].executionHooks;
FunctionReference hookFunction = FunctionReferenceLib.pack(plugin, mh.functionId);
_addExecHooks(mh.executionSelector, hookFunction, mh.isPreHook, mh.isPostHook);
_addExecHooks(execHooks, hookFunction, mh.isPreHook, mh.isPostHook, mh.requireUOContext);
}

length = manifest.permissionHooks.length;
for (uint256 i = 0; i < length; ++i) {
ManifestPermissionHook memory mh = manifest.permissionHooks[i];
EnumerableSet.Bytes32Set storage permissionHooks =
_storage.validationData[mh.validationFunction].permissionHooks;
FunctionReference hookFunction = FunctionReferenceLib.pack(plugin, mh.functionId);
_addExecHooks(permissionHooks, hookFunction, mh.isPreHook, mh.isPostHook, mh.requireUOContext);
if (mh.requireUOContext) {
_storage.validationData[mh.validationFunction].requireUOHookCount += 1;
}
}

length = manifest.interfaceIds.length;
Expand Down Expand Up @@ -257,12 +283,24 @@ abstract contract PluginManagerInternals is IPluginManager {
}

// Remove components according to the manifest, in reverse order (by component type) of their installation.

length = manifest.executionHooks.length;
for (uint256 i = 0; i < length; ++i) {
ManifestExecutionHook memory mh = manifest.executionHooks[i];
FunctionReference hookFunction = FunctionReferenceLib.pack(plugin, mh.functionId);
_removeExecHooks(mh.executionSelector, hookFunction, mh.isPreHook, mh.isPostHook);
EnumerableSet.Bytes32Set storage execHooks = _storage.selectorData[mh.executionSelector].executionHooks;
_removeExecHooks(execHooks, hookFunction, mh.isPreHook, mh.isPostHook, mh.requireUOContext);
}

length = manifest.permissionHooks.length;
for (uint256 i = 0; i < length; ++i) {
ManifestPermissionHook memory mh = manifest.permissionHooks[i];
FunctionReference hookFunction = FunctionReferenceLib.pack(plugin, mh.functionId);
EnumerableSet.Bytes32Set storage permissionHooks =
_storage.validationData[mh.validationFunction].permissionHooks;
_removeExecHooks(permissionHooks, hookFunction, mh.isPreHook, mh.isPostHook, mh.requireUOContext);
if (mh.requireUOContext) {
_storage.validationData[mh.validationFunction].requireUOHookCount -= 1;
}
}

length = manifest.signatureValidationFunctions.length;
Expand Down
56 changes: 37 additions & 19 deletions src/account/UpgradeableModularAccount.sol
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ import {AccountLoupe} from "./AccountLoupe.sol";
import {
AccountStorage,
getAccountStorage,
SelectorData,
toSetValue,
toFunctionReference,
toExecutionHook
Expand Down Expand Up @@ -72,6 +71,7 @@ contract UpgradeableModularAccount is
error PostExecHookReverted(address plugin, uint8 functionId, bytes revertReason);
error PreExecHookReverted(address plugin, uint8 functionId, bytes revertReason);
error PreRuntimeValidationHookFailed(address plugin, uint8 functionId, bytes revertReason);
error RequireUserOperationContext();
error RuntimeValidationFunctionMissing(bytes4 selector);
error RuntimeValidationFunctionReverted(address plugin, uint8 functionId, bytes revertReason);
error SignatureValidationInvalid(address plugin, uint8 functionId);
Expand All @@ -85,7 +85,8 @@ contract UpgradeableModularAccount is
modifier wrapNativeFunction() {
_checkPermittedCallerIfNotFromEP();

PostExecToRun[] memory postExecHooks = _doPreExecHooks(msg.sig, msg.data);
PostExecToRun[] memory postExecHooks =
_doPreExecHooks(getAccountStorage().selectorData[msg.sig].executionHooks, msg.data);

_;

Expand Down Expand Up @@ -139,7 +140,7 @@ contract UpgradeableModularAccount is

PostExecToRun[] memory postExecHooks;
// Cache post-exec hooks in memory
postExecHooks = _doPreExecHooks(msg.sig, msg.data);
postExecHooks = _doPreExecHooks(getAccountStorage().selectorData[msg.sig].executionHooks, msg.data);

// execute the function, bubbling up any reverts
(bool execSuccess, bytes memory execReturnData) = execPlugin.call(msg.data);
Expand Down Expand Up @@ -195,7 +196,8 @@ contract UpgradeableModularAccount is
revert UnrecognizedFunction(selector);
}

PostExecToRun[] memory postExecHooks = _doPreExecHooks(selector, data);
PostExecToRun[] memory postExecHooks =
_doPreExecHooks(_storage.selectorData[selector].executionHooks, data);

(bool success, bytes memory returnData) = execFunctionPlugin.call(data);

Expand All @@ -222,8 +224,10 @@ contract UpgradeableModularAccount is
}

// Run any pre exec hooks for this selector
PostExecToRun[] memory postExecHooks =
_doPreExecHooks(IPluginExecutor.executeFromPluginExternal.selector, msg.data);
PostExecToRun[] memory postExecHooks = _doPreExecHooks(
getAccountStorage().selectorData[IPluginExecutor.executeFromPluginExternal.selector].executionHooks,
msg.data
);

// Perform the external call
bytes memory returnData = _exec(target, value, data);
Expand Down Expand Up @@ -382,7 +386,7 @@ contract UpgradeableModularAccount is
internal
virtual
override
returns (uint256 validationData)
returns (uint256 returnValData)
{
if (userOp.callData.length < 4) {
revert UnrecognizedFunction(bytes4(userOp.callData));
Expand All @@ -394,7 +398,18 @@ contract UpgradeableModularAccount is

_checkIfValidationApplies(selector, userOpValidationFunction, isSharedValidation);

validationData =
// Check if there are exec hooks associated with the validator that require UO context, and revert if the
// call isn't to `executeUserOp`
// This check must be here because if context isn't passed, we wouldn't be able to get the exec hooks
// associated with the validator
if (getAccountStorage().validationData[userOpValidationFunction].requireUOHookCount > 0) {
/**
* && msg.sig != this.executeUserOp.selector
*/
revert RequireUserOperationContext();
}

returnValData =
_doUserOpValidation(selector, userOpValidationFunction, userOp, userOp.signature[22:], userOpHash);
}

Expand Down Expand Up @@ -485,22 +500,26 @@ contract UpgradeableModularAccount is
}
}

function _doPreExecHooks(bytes4 selector, bytes calldata data)
function _doPreExecHooks(EnumerableSet.Bytes32Set storage executionHooks, bytes calldata data)
internal
returns (PostExecToRun[] memory postHooksToRun)
{
SelectorData storage selectorData = getAccountStorage().selectorData[selector];

uint256 hooksLength = selectorData.executionHooks.length();
uint256 hooksLength = executionHooks.length();

// Overallocate on length - not all of this may get filled up. We set the correct length later.
postHooksToRun = new PostExecToRun[](hooksLength);

// Copy all post hooks to the array. This happens before any pre hooks are run, so we can
// be sure that the set of hooks to run will not be affected by state changes mid-execution.
for (uint256 i = 0; i < hooksLength; ++i) {
bytes32 key = selectorData.executionHooks.at(i);
(FunctionReference hookFunction,, bool isPostHook) = toExecutionHook(key);
bytes32 key = executionHooks.at(i);
(FunctionReference hookFunction,, bool isPostHook, bool requireUOContext) = toExecutionHook(key);
if (requireUOContext) {
/**
* && msg.sig != this.executeUserOp.selector
*/
revert RequireUserOperationContext();
}
if (isPostHook) {
postHooksToRun[i].postExecHook = hookFunction;
}
Expand All @@ -509,8 +528,8 @@ contract UpgradeableModularAccount is
// Run the pre hooks and copy their return data to the post hooks array, if an associated post-exec hook
// exists.
for (uint256 i = 0; i < hooksLength; ++i) {
bytes32 key = selectorData.executionHooks.at(i);
(FunctionReference hookFunction, bool isPreHook, bool isPostHook) = toExecutionHook(key);
bytes32 key = executionHooks.at(i);
(FunctionReference hookFunction, bool isPreHook, bool isPostHook,) = toExecutionHook(key);

if (isPreHook) {
bytes memory preExecHookReturnData = _runPreExecHook(hookFunction, data);
Expand All @@ -528,9 +547,8 @@ contract UpgradeableModularAccount is
returns (bytes memory preExecHookReturnData)
{
(address plugin, uint8 functionId) = preExecHook.unpack();
try IExecutionHook(plugin).preExecutionHook(functionId, msg.sender, msg.value, data) returns (
bytes memory returnData
) {
try IExecutionHook(plugin).preExecutionHook(functionId, abi.encodePacked(msg.sender, msg.value, data))
returns (bytes memory returnData) {
preExecHookReturnData = returnData;
} catch (bytes memory revertReason) {
revert PreExecHookReverted(plugin, functionId, revertReason);
Expand Down
9 changes: 9 additions & 0 deletions src/interfaces/IAccountLoupe.sol
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ struct ExecutionHook {
FunctionReference hookFunction;
bool isPreHook;
bool isPostHook;
bool requireUOContext;
}

interface IAccountLoupe {
Expand All @@ -28,6 +29,14 @@ interface IAccountLoupe {
/// @return The pre and post execution hooks for this selector.
function getExecutionHooks(bytes4 selector) external view returns (ExecutionHook[] memory);

/// @notice Get the pre and post execution hooks for a validation function.
/// @param validationFunction The validation function to get the hooks for.
/// @return The pre and post execution hooks for this validation function.
function getPermissionHooks(FunctionReference validationFunction)
external
view
returns (ExecutionHook[] memory);

/// @notice Get the pre user op and runtime validation hooks associated with a selector.
/// @param validationFunction The validation function to get the hooks for.
/// @return preValidationHooks The pre validation hooks for this selector.
Expand Down
9 changes: 3 additions & 6 deletions src/interfaces/IExecutionHook.sol
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,10 @@ interface IExecutionHook is IPlugin {
/// @dev To indicate the entire call should revert, the function MUST revert.
/// @param functionId An identifier that routes the call to different internal implementations, should there be
/// more than one.
/// @param sender The caller address.
/// @param value The call value.
/// @param data The calldata sent.
/// @param data If hook requires UO context, data is abi.encode(PackedUserOperation), else its
/// abi.encodePacked(sender, value, calldata)
/// @return Context to pass to a post execution hook, if present. An empty bytes array MAY be returned.
function preExecutionHook(uint8 functionId, address sender, uint256 value, bytes calldata data)
external
returns (bytes memory);
function preExecutionHook(uint8 functionId, bytes calldata data) external returns (bytes memory);

/// @notice Run the post execution hook specified by the `functionId`.
/// @dev To indicate the entire call should revert, the function MUST revert.
Expand Down
Loading

0 comments on commit b183d5a

Please sign in to comment.