Skip to content

Commit

Permalink
Update internal structures to use hookConfig
Browse files Browse the repository at this point in the history
  • Loading branch information
adamegyed committed Jul 23, 2024
1 parent e849987 commit 167bfcb
Show file tree
Hide file tree
Showing 7 changed files with 56 additions and 68 deletions.
20 changes: 15 additions & 5 deletions src/account/AccountLoupe.sol
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,16 @@ import {UUPSUpgradeable} from "@openzeppelin/contracts/proxy/utils/UUPSUpgradeab
import {EnumerableMap} from "@openzeppelin/contracts/utils/structs/EnumerableMap.sol";
import {EnumerableSet} from "@openzeppelin/contracts/utils/structs/EnumerableSet.sol";

import {HookConfigLib} from "../helpers/HookConfigLib.sol";
import {ExecutionHook, IAccountLoupe} from "../interfaces/IAccountLoupe.sol";
import {IModuleManager, ModuleEntity} from "../interfaces/IModuleManager.sol";
import {HookConfig, IModuleManager, ModuleEntity} from "../interfaces/IModuleManager.sol";
import {IStandardExecutor} from "../interfaces/IStandardExecutor.sol";
import {getAccountStorage, toExecutionHook, toSelector} from "./AccountStorage.sol";

abstract contract AccountLoupe is IAccountLoupe {
using EnumerableSet for EnumerableSet.Bytes32Set;
using EnumerableMap for EnumerableMap.AddressToUintMap;
using HookConfigLib for HookConfig;

/// @inheritdoc IAccountLoupe
function getExecutionFunctionHandler(bytes4 selector) external view override returns (address module) {
Expand Down Expand Up @@ -56,8 +58,12 @@ abstract contract AccountLoupe is IAccountLoupe {

for (uint256 i = 0; i < executionHooksLength; ++i) {
bytes32 key = hooks.at(i);
ExecutionHook memory execHook = execHooks[i];
(execHook.hookFunction, execHook.isPreHook, execHook.isPostHook) = toExecutionHook(key);
HookConfig hook = toExecutionHook(key);
execHooks[i] = ExecutionHook({
hookFunction: hook.moduleEntity(),
isPreHook: hook.hasPreHook(),
isPostHook: hook.hasPostHook()
});
}
}

Expand All @@ -74,8 +80,12 @@ abstract contract AccountLoupe is IAccountLoupe {
permissionHooks = new ExecutionHook[](executionHooksLength);
for (uint256 i = 0; i < executionHooksLength; ++i) {
bytes32 key = hooks.at(i);
ExecutionHook memory execHook = permissionHooks[i];
(execHook.hookFunction, execHook.isPreHook, execHook.isPostHook) = toExecutionHook(key);
HookConfig hook = toExecutionHook(key);
permissionHooks[i] = ExecutionHook({
hookFunction: hook.moduleEntity(),
isPreHook: hook.hasPreHook(),
isPostHook: hook.hasPostHook()
});
}
}

Expand Down
18 changes: 5 additions & 13 deletions src/account/AccountStorage.sol
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,7 @@ pragma solidity ^0.8.25;

import {EnumerableSet} from "@openzeppelin/contracts/utils/structs/EnumerableSet.sol";

import {ExecutionHook} from "../interfaces/IAccountLoupe.sol";
import {ModuleEntity} from "../interfaces/IModuleManager.sol";
import {HookConfig, ModuleEntity} from "../interfaces/IModuleManager.sol";

// bytes = keccak256("ERC6900.UpgradeableModularAccount.Storage")
bytes32 constant _ACCOUNT_STORAGE_SLOT = 0x9f09680beaa4e5c9f38841db2460c401499164f368baef687948c315d9073e40;
Expand Down Expand Up @@ -70,19 +69,12 @@ function toModuleEntity(bytes32 setValue) pure returns (ModuleEntity) {
// 0x________________________________________________AA____________________ is pre hook
// 0x__________________________________________________BB__________________ is post hook

function toSetValue(ExecutionHook memory executionHook) pure returns (bytes32) {
return bytes32(ModuleEntity.unwrap(executionHook.hookFunction))
| bytes32(executionHook.isPreHook ? uint256(1) << 56 : 0)
| bytes32(executionHook.isPostHook ? uint256(1) << 48 : 0);
function toSetValue(HookConfig hookConfig) pure returns (bytes32) {
return bytes32(HookConfig.unwrap(hookConfig));
}

function toExecutionHook(bytes32 setValue)
pure
returns (ModuleEntity hookFunction, bool isPreHook, bool isPostHook)
{
hookFunction = ModuleEntity.wrap(bytes24(setValue));
isPreHook = (uint256(setValue) >> 56) & 0xFF == 1;
isPostHook = (uint256(setValue) >> 48) & 0xFF == 1;
function toExecutionHook(bytes32 setValue) pure returns (HookConfig) {
return HookConfig.wrap(bytes26(setValue));
}

function toSetValue(bytes4 selector) pure returns (bytes32) {
Expand Down
58 changes: 22 additions & 36 deletions src/account/ModuleManagerInternals.sol
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ import {HookConfigLib} from "../helpers/HookConfigLib.sol";
import {KnownSelectors} from "../helpers/KnownSelectors.sol";
import {ModuleEntityLib} from "../helpers/ModuleEntityLib.sol";
import {ValidationConfigLib} from "../helpers/ValidationConfigLib.sol";
import {ExecutionHook} from "../interfaces/IAccountLoupe.sol";
import {IModule, ManifestExecutionHook, ModuleManifest} from "../interfaces/IModule.sol";
import {HookConfig, IModuleManager, ModuleEntity, ValidationConfig} from "../interfaces/IModuleManager.sol";
import {
Expand All @@ -34,7 +33,7 @@ abstract contract ModuleManagerInternals is IModuleManager {
error IModuleFunctionNotAllowed(bytes4 selector);
error NativeFunctionNotAllowed(bytes4 selector);
error NullModule();
error PermissionAlreadySet(ModuleEntity validationFunction, ExecutionHook hook);
error PermissionAlreadySet(ModuleEntity validationFunction, HookConfig hookFunction);
error ModuleInstallCallbackFailed(address module, bytes revertReason);
error ModuleInterfaceNotSupported(address module);
error ModuleNotInstalled(address module);
Expand Down Expand Up @@ -117,30 +116,12 @@ abstract contract ModuleManagerInternals is IModuleManager {
}
}

function _addExecHooks(
EnumerableSet.Bytes32Set storage hooks,
ModuleEntity hookFunction,
bool isPreExecHook,
bool isPostExecHook
) internal {
hooks.add(
toSetValue(
ExecutionHook({hookFunction: hookFunction, isPreHook: isPreExecHook, isPostHook: isPostExecHook})
)
);
function _addExecHooks(EnumerableSet.Bytes32Set storage hooks, HookConfig hookFunction) internal {
hooks.add(toSetValue(hookFunction));
}

function _removeExecHooks(
EnumerableSet.Bytes32Set storage hooks,
ModuleEntity hookFunction,
bool isPreExecHook,
bool isPostExecHook
) internal {
hooks.remove(
toSetValue(
ExecutionHook({hookFunction: hookFunction, isPreHook: isPreExecHook, isPostHook: isPostExecHook})
)
);
function _removeExecHooks(EnumerableSet.Bytes32Set storage hooks, HookConfig hookFunction) internal {
hooks.remove(toSetValue(hookFunction));
}

function _installModule(address module, ModuleManifest calldata manifest, bytes memory moduleInstallData)
Expand Down Expand Up @@ -171,8 +152,13 @@ abstract contract ModuleManagerInternals is IModuleManager {
for (uint256 i = 0; i < length; ++i) {
ManifestExecutionHook memory mh = manifest.executionHooks[i];
EnumerableSet.Bytes32Set storage execHooks = _storage.selectorData[mh.executionSelector].executionHooks;
ModuleEntity hookFunction = ModuleEntityLib.pack(module, mh.entityId);
_addExecHooks(execHooks, hookFunction, mh.isPreHook, mh.isPostHook);
HookConfig hookFunction = HookConfigLib.packExecHook({
_module: module,
_entityId: mh.entityId,
_hasPre: mh.isPreHook,
_hasPost: mh.isPostHook
});
_addExecHooks(execHooks, hookFunction);
}

length = manifest.interfaceIds.length;
Expand Down Expand Up @@ -200,9 +186,14 @@ abstract contract ModuleManagerInternals is IModuleManager {
uint256 length = manifest.executionHooks.length;
for (uint256 i = 0; i < length; ++i) {
ManifestExecutionHook memory mh = manifest.executionHooks[i];
ModuleEntity hookFunction = ModuleEntityLib.pack(module, mh.entityId);
EnumerableSet.Bytes32Set storage execHooks = _storage.selectorData[mh.executionSelector].executionHooks;
_removeExecHooks(execHooks, hookFunction, mh.isPreHook, mh.isPostHook);
HookConfig hookFunction = HookConfigLib.packExecHook({
_module: module,
_entityId: mh.entityId,
_hasPre: mh.isPreHook,
_hasPost: mh.isPostHook
});
_removeExecHooks(execHooks, hookFunction);
}

length = manifest.executionFunctions.length;
Expand Down Expand Up @@ -241,7 +232,7 @@ abstract contract ModuleManagerInternals is IModuleManager {

function _installValidation(
ValidationConfig validationConfig,
bytes4[] memory selectors,
bytes4[] calldata selectors,
bytes calldata installData,
bytes[] calldata hooks
) internal {
Expand All @@ -263,13 +254,8 @@ abstract contract ModuleManagerInternals is IModuleManager {
_onInstall(hook.module(), hookData);
} else {
// Hook is an execution hook
(ModuleEntity hookFunction, bool hasPre, bool hasPost) = hook.unpackExecHook();

ExecutionHook memory executionHook =
ExecutionHook({hookFunction: hookFunction, isPreHook: hasPre, isPostHook: hasPost});

if (!_validationData.permissionHooks.add(toSetValue(executionHook))) {
revert PermissionAlreadySet(validationConfig.moduleEntity(), executionHook);
if (!_validationData.permissionHooks.add(toSetValue(hook))) {
revert PermissionAlreadySet(validationConfig.moduleEntity(), hook);
}

_onInstall(hook.module(), hookData);
Expand Down
22 changes: 11 additions & 11 deletions src/account/UpgradeableModularAccount.sol
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import {UUPSUpgradeable} from "@openzeppelin/contracts/proxy/utils/UUPSUpgradeab
import {IERC165} from "@openzeppelin/contracts/utils/introspection/IERC165.sol";
import {EnumerableSet} from "@openzeppelin/contracts/utils/structs/EnumerableSet.sol";

import {HookConfig, HookConfigLib} from "../helpers/HookConfigLib.sol";
import {ModuleEntityLib} from "../helpers/ModuleEntityLib.sol";

import {SparseCalldataSegmentLib} from "../helpers/SparseCalldataSegmentLib.sol";
Expand Down Expand Up @@ -46,6 +47,7 @@ contract UpgradeableModularAccount is
using EnumerableSet for EnumerableSet.Bytes32Set;
using ModuleEntityLib for ModuleEntity;
using ValidationConfigLib for ValidationConfig;
using HookConfigLib for HookConfig;
using SparseCalldataSegmentLib for bytes;

struct PostExecToRun {
Expand Down Expand Up @@ -249,7 +251,7 @@ contract UpgradeableModularAccount is
/// @dev This function is only callable once, and only by the EntryPoint.
function initializeWithValidation(
ValidationConfig validationConfig,
bytes4[] memory selectors,
bytes4[] calldata selectors,
bytes calldata installData,
bytes[] calldata hooks
) external initializer {
Expand All @@ -261,7 +263,7 @@ contract UpgradeableModularAccount is
/// @notice May be validated by a global validation.
function installValidation(
ValidationConfig validationConfig,
bytes4[] memory selectors,
bytes4[] calldata selectors,
bytes calldata installData,
bytes[] calldata hooks
) external wrapNativeFunction {
Expand Down Expand Up @@ -498,26 +500,24 @@ contract UpgradeableModularAccount is
// Copy all post hooks to the array. This happens before any pre hooks are run, so we can
// be sure that the set of hooks to run will not be affected by state changes mid-execution.
for (uint256 i = 0; i < hooksLength; ++i) {
bytes32 key = executionHooks.at(i);
(ModuleEntity hookFunction,, bool isPostHook) = toExecutionHook(key);
if (isPostHook) {
postHooksToRun[i].postExecHook = hookFunction;
HookConfig hookFunction = toExecutionHook(executionHooks.at(i));
if (hookFunction.hasPostHook()) {
postHooksToRun[i].postExecHook = hookFunction.moduleEntity();
}
}

// Run the pre hooks and copy their return data to the post hooks array, if an associated post-exec hook
// exists.
for (uint256 i = 0; i < hooksLength; ++i) {
bytes32 key = executionHooks.at(i);
(ModuleEntity hookFunction, bool isPreHook, bool isPostHook) = toExecutionHook(key);
HookConfig hookFunction = toExecutionHook(executionHooks.at(i));

if (isPreHook) {
if (hookFunction.hasPreHook()) {
bytes memory preExecHookReturnData;

preExecHookReturnData = _runPreExecHook(hookFunction, data);
preExecHookReturnData = _runPreExecHook(hookFunction.moduleEntity(), data);

// If there is an associated post-exec hook, save the return data.
if (isPostHook) {
if (hookFunction.hasPostHook()) {
postHooksToRun[i].preExecHookReturnData = preExecHookReturnData;
}
}
Expand Down
2 changes: 1 addition & 1 deletion src/interfaces/IModuleManager.sol
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ interface IModuleManager {
/// install data, if any.
function installValidation(
ValidationConfig validationConfig,
bytes4[] memory selectors,
bytes4[] calldata selectors,
bytes calldata installData,
bytes[] calldata hooks
) external;
Expand Down
2 changes: 1 addition & 1 deletion test/module/ERC20TokenLimitModule.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ contract ERC20TokenLimitModuleTest is AccountTestBase {
bytes[] memory hooks = new bytes[](1);
hooks[0] = abi.encodePacked(
HookConfigLib.packExecHook({_module: address(module), _entityId: 0, _hasPre: true, _hasPost: false}),
abi.encode(uint8(0), limit) // TODO: should this still be uint8?
abi.encode(uint32(0), limit)
);

vm.prank(address(acct));
Expand Down
2 changes: 1 addition & 1 deletion test/module/NativeTokenLimitModule.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ contract NativeTokenLimitModuleTest is AccountTestBase {
UpgradeableModularAccount.PreExecHookReverted.selector,
abi.encode(
address(module),
uint8(0),
uint32(0),
abi.encodePacked(NativeTokenLimitModule.ExceededNativeTokenLimit.selector)
)
)
Expand Down

0 comments on commit 167bfcb

Please sign in to comment.