Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: [v0.8-develop] remove hook overlap support #47

Merged
merged 6 commits into from
May 3, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 8 additions & 44 deletions src/account/AccountLoupe.sol
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
pragma solidity ^0.8.25;

import {UUPSUpgradeable} from "@openzeppelin/contracts/proxy/utils/UUPSUpgradeable.sol";
import {EnumerableMap} from "@openzeppelin/contracts/utils/structs/EnumerableMap.sol";
import {EnumerableSet} from "@openzeppelin/contracts/utils/structs/EnumerableSet.sol";

import {IAccountLoupe} from "../interfaces/IAccountLoupe.sol";
Expand All @@ -11,7 +10,7 @@ import {IStandardExecutor} from "../interfaces/IStandardExecutor.sol";
import {AccountStorage, getAccountStorage, SelectorData, toFunctionReferenceArray} from "./AccountStorage.sol";

abstract contract AccountLoupe is IAccountLoupe {
using EnumerableMap for EnumerableMap.Bytes32ToUintMap;
using EnumerableSet for EnumerableSet.Bytes32Set;
using EnumerableSet for EnumerableSet.AddressSet;

/// @inheritdoc IAccountLoupe
Expand Down Expand Up @@ -41,56 +40,21 @@ abstract contract AccountLoupe is IAccountLoupe {
SelectorData storage selectorData = getAccountStorage().selectorData[selector];
uint256 preExecHooksLength = selectorData.preHooks.length();
uint256 postOnlyExecHooksLength = selectorData.postOnlyHooks.length();
uint256 maxExecHooksLength = postOnlyExecHooksLength;

// There can only be as many associated post hooks to run as there are pre hooks.
for (uint256 i = 0; i < preExecHooksLength; ++i) {
(, uint256 count) = selectorData.preHooks.at(i);
unchecked {
maxExecHooksLength += (count + 1);
}
}

// Overallocate on length - not all of this may get filled up. We set the correct length later.
execHooks = new ExecutionHooks[](maxExecHooksLength);
uint256 actualExecHooksLength;
execHooks = new ExecutionHooks[](preExecHooksLength + postOnlyExecHooksLength);

for (uint256 i = 0; i < preExecHooksLength; ++i) {
(bytes32 key,) = selectorData.preHooks.at(i);
bytes32 key = selectorData.preHooks.at(i);
FunctionReference preExecHook = FunctionReference.wrap(bytes21(key));
FunctionReference associatedPostExecHook = selectorData.associatedPostHooks[preExecHook];

uint256 associatedPostExecHooksLength = selectorData.associatedPostHooks[preExecHook].length();
if (associatedPostExecHooksLength > 0) {
for (uint256 j = 0; j < associatedPostExecHooksLength; ++j) {
execHooks[actualExecHooksLength].preExecHook = preExecHook;
(key,) = selectorData.associatedPostHooks[preExecHook].at(j);
execHooks[actualExecHooksLength].postExecHook = FunctionReference.wrap(bytes21(key));

unchecked {
++actualExecHooksLength;
}
}
} else {
execHooks[actualExecHooksLength].preExecHook = preExecHook;

unchecked {
++actualExecHooksLength;
}
}
execHooks[i].preExecHook = preExecHook;
execHooks[i].postExecHook = associatedPostExecHook;
}

for (uint256 i = 0; i < postOnlyExecHooksLength; ++i) {
(bytes32 key,) = selectorData.postOnlyHooks.at(i);
execHooks[actualExecHooksLength].postExecHook = FunctionReference.wrap(bytes21(key));

unchecked {
++actualExecHooksLength;
}
}

// Trim the exec hooks array to the actual length, since we may have overallocated.
assembly ("memory-safe") {
mstore(execHooks, actualExecHooksLength)
bytes32 key = selectorData.postOnlyHooks.at(i);
execHooks[preExecHooksLength + i].postExecHook = FunctionReference.wrap(bytes21(key));
}
}

Expand Down
24 changes: 14 additions & 10 deletions src/account/AccountStorage.sol
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
// SPDX-License-Identifier: GPL-3.0
pragma solidity ^0.8.25;

import {EnumerableMap} from "@openzeppelin/contracts/utils/structs/EnumerableMap.sol";
import {EnumerableSet} from "@openzeppelin/contracts/utils/structs/EnumerableSet.sol";

import {IPlugin} from "../interfaces/IPlugin.sol";
Expand Down Expand Up @@ -35,15 +34,20 @@ struct SelectorData {
// The plugin that implements this execution function.
// If this is a native function, the address must remain address(0).
address plugin;
// How many times a `PRE_HOOK_ALWAYS_DENY` has been added for this function.
// Since that is the only type of hook that may overlap, we can use this to track the number of times it has
// been applied, and whether or not the deny should apply. The size `uint48` was chosen somewhat arbitrarily,
// but it packs alongside `plugin` while still leaving some other space in the slot for future packing.
uint48 denyExecutionCount;
adam-alchemy marked this conversation as resolved.
Show resolved Hide resolved
// User operation validation and runtime validation share a function reference.
FunctionReference validation;
// The pre validation hooks for this function selector.
EnumerableMap.Bytes32ToUintMap preValidationHooks;
EnumerableSet.Bytes32Set preValidationHooks;
// The execution hooks for this function selector.
EnumerableMap.Bytes32ToUintMap preHooks;
EnumerableSet.Bytes32Set preHooks;
// bytes21 key = pre hook function reference
mapping(FunctionReference => EnumerableMap.Bytes32ToUintMap) associatedPostHooks;
EnumerableMap.Bytes32ToUintMap postOnlyHooks;
mapping(FunctionReference => FunctionReference) associatedPostHooks;
EnumerableSet.Bytes32Set postOnlyHooks;
}

struct AccountStorage {
Expand Down Expand Up @@ -73,17 +77,17 @@ function getPermittedCallKey(address addr, bytes4 selector) pure returns (bytes2
return bytes24(bytes20(addr)) | (bytes24(selector) >> 160);
}

// Helper function to get all elements of a set into memory.
using EnumerableMap for EnumerableMap.Bytes32ToUintMap;
using EnumerableSet for EnumerableSet.Bytes32Set;

function toFunctionReferenceArray(EnumerableMap.Bytes32ToUintMap storage map)
/// @dev Helper function to get all elements of a set into memory.
function toFunctionReferenceArray(EnumerableSet.Bytes32Set storage set)
view
returns (FunctionReference[] memory)
{
uint256 length = map.length();
uint256 length = set.length();
FunctionReference[] memory result = new FunctionReference[](length);
for (uint256 i = 0; i < length; ++i) {
(bytes32 key,) = map.at(i);
bytes32 key = set.at(i);
result[i] = FunctionReference.wrap(bytes21(key));
}
return result;
Expand Down
85 changes: 36 additions & 49 deletions src/account/PluginManagerInternals.sol
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
pragma solidity ^0.8.25;

import {ERC165Checker} from "@openzeppelin/contracts/utils/introspection/ERC165Checker.sol";
import {EnumerableMap} from "@openzeppelin/contracts/utils/structs/EnumerableMap.sol";
import {EnumerableSet} from "@openzeppelin/contracts/utils/structs/EnumerableSet.sol";

import {FunctionReferenceLib} from "../helpers/FunctionReferenceLib.sol";
Expand All @@ -25,7 +24,7 @@ import {
} from "./AccountStorage.sol";

abstract contract PluginManagerInternals is IPluginManager {
using EnumerableMap for EnumerableMap.Bytes32ToUintMap;
using EnumerableSet for EnumerableSet.Bytes32Set;
using EnumerableSet for EnumerableSet.AddressSet;
using FunctionReferenceLib for FunctionReference;

Expand Down Expand Up @@ -103,19 +102,28 @@ abstract contract PluginManagerInternals is IPluginManager {
SelectorData storage _selectorData = getAccountStorage().selectorData[selector];

if (!preExecHook.isEmpty()) {
adam-alchemy marked this conversation as resolved.
Show resolved Hide resolved
_addOrIncrement(_selectorData.preHooks, _toSetValue(preExecHook));
if (preExecHook.eq(FunctionReferenceLib._PRE_HOOK_ALWAYS_DENY)) {
// Increment `denyExecutionCount`, because this pre exec hook may be applied multiple times.
_selectorData.denyExecutionCount += 1;
return;
}

// Don't need to check for duplicates, as the hook can be run at most once.
_selectorData.preHooks.add(_toSetValue(preExecHook));

if (!postExecHook.isEmpty()) {
_addOrIncrement(_selectorData.associatedPostHooks[preExecHook], _toSetValue(postExecHook));
}
} else {
if (postExecHook.isEmpty()) {
// both pre and post hooks cannot be null
revert NullFunctionReference();
_selectorData.associatedPostHooks[preExecHook] = postExecHook;
}

_addOrIncrement(_selectorData.postOnlyHooks, _toSetValue(postExecHook));
return;
}

if (postExecHook.isEmpty()) {
// both pre and post hooks cannot be null
revert NullFunctionReference();
}

_selectorData.postOnlyHooks.add(_toSetValue(postExecHook));
}

function _removeExecHooks(bytes4 selector, FunctionReference preExecHook, FunctionReference postExecHook)
Expand All @@ -124,36 +132,40 @@ abstract contract PluginManagerInternals is IPluginManager {
SelectorData storage _selectorData = getAccountStorage().selectorData[selector];

if (!preExecHook.isEmpty()) {
_removeOrDecrement(_selectorData.preHooks, _toSetValue(preExecHook));
if (preExecHook.eq(FunctionReferenceLib._PRE_HOOK_ALWAYS_DENY)) {
// Decrement `denyExecutionCount`, because this pre exec hook may be applied multiple times.
_selectorData.denyExecutionCount -= 1;
return;
}

_selectorData.preHooks.remove(_toSetValue(preExecHook));

if (!postExecHook.isEmpty()) {
_removeOrDecrement(_selectorData.associatedPostHooks[preExecHook], _toSetValue(postExecHook));
_selectorData.associatedPostHooks[preExecHook] = FunctionReferenceLib._EMPTY_FUNCTION_REFERENCE;
}
} else {
// The case where both pre and post hooks are null was checked during installation.

// May ignore return value, as the manifest hash is validated to ensure that the hook exists.
_removeOrDecrement(_selectorData.postOnlyHooks, _toSetValue(postExecHook));
return;
}

// The case where both pre and post hooks are null was checked during installation.

// May ignore return value, as the manifest hash is validated to ensure that the hook exists.
_selectorData.postOnlyHooks.remove(_toSetValue(postExecHook));
}

function _addPreValidationHook(bytes4 selector, FunctionReference preValidationHook)
internal
notNullFunction(preValidationHook)
{
_addOrIncrement(
getAccountStorage().selectorData[selector].preValidationHooks, _toSetValue(preValidationHook)
);
getAccountStorage().selectorData[selector].preValidationHooks.add(_toSetValue(preValidationHook));
}

function _removePreValidationHook(bytes4 selector, FunctionReference preValidationHook)
internal
notNullFunction(preValidationHook)
{
// May ignore return value, as the manifest hash is validated to ensure that the hook exists.
_removeOrDecrement(
getAccountStorage().selectorData[selector].preValidationHooks, _toSetValue(preValidationHook)
);
getAccountStorage().selectorData[selector].preValidationHooks.remove(_toSetValue(preValidationHook));
}

function _installPlugin(
Expand Down Expand Up @@ -276,10 +288,7 @@ abstract contract PluginManagerInternals is IPluginManager {
_addPreValidationHook(
mh.executionSelector,
_resolveManifestFunction(
mh.associatedFunction,
plugin,
emptyDependencies,
ManifestAssociatedFunctionType.PRE_HOOK_ALWAYS_DENY
mh.associatedFunction, plugin, emptyDependencies, ManifestAssociatedFunctionType.NONE
)
);
}
Expand Down Expand Up @@ -370,10 +379,7 @@ abstract contract PluginManagerInternals is IPluginManager {
_removePreValidationHook(
mh.executionSelector,
_resolveManifestFunction(
mh.associatedFunction,
plugin,
emptyDependencies,
ManifestAssociatedFunctionType.PRE_HOOK_ALWAYS_DENY
mh.associatedFunction, plugin, emptyDependencies, ManifestAssociatedFunctionType.NONE
)
);
}
Expand Down Expand Up @@ -449,25 +455,6 @@ abstract contract PluginManagerInternals is IPluginManager {
emit PluginUninstalled(plugin, onUninstallSuccess);
}

function _addOrIncrement(EnumerableMap.Bytes32ToUintMap storage map, bytes32 key) internal {
(bool success, uint256 value) = map.tryGet(key);
map.set(key, success ? value + 1 : 0);
}

/// @return True if the key was removed or its value was decremented, false if the key was not found.
function _removeOrDecrement(EnumerableMap.Bytes32ToUintMap storage map, bytes32 key) internal returns (bool) {
(bool success, uint256 value) = map.tryGet(key);
if (!success) {
return false;
}
if (value == 0) {
map.remove(key);
} else {
map.set(key, value - 1);
}
return true;
}

function _toSetValue(FunctionReference functionReference) internal pure returns (bytes32) {
return bytes32(FunctionReference.unwrap(functionReference));
}
Expand Down
Loading
Loading