diff --git a/src/plugins/ERC20TokenLimitPlugin.sol b/src/plugins/ERC20TokenLimitPlugin.sol index d8e7ded0..c3ddfed8 100644 --- a/src/plugins/ERC20TokenLimitPlugin.sol +++ b/src/plugins/ERC20TokenLimitPlugin.sol @@ -4,13 +4,13 @@ pragma solidity ^0.8.25; import {PackedUserOperation} from "@eth-infinitism/account-abstraction/interfaces/PackedUserOperation.sol"; import {UserOperationLib} from "@eth-infinitism/account-abstraction/core/UserOperationLib.sol"; import {EnumerableSet} from "@openzeppelin/contracts/utils/structs/EnumerableSet.sol"; -import {IAccountExecute} from "@eth-infinitism/account-abstraction/interfaces/IAccountExecute.sol"; import {IERC20} from "@openzeppelin/contracts/token/ERC20/IERC20.sol"; import {PluginManifest, PluginMetadata} from "../interfaces/IPlugin.sol"; import {IStandardExecutor, Call} from "../interfaces/IStandardExecutor.sol"; import {IPlugin} from "../interfaces/IPlugin.sol"; import {IExecutionHook} from "../interfaces/IExecutionHook.sol"; +import {IPermissionHook} from "../interfaces/IPermissionHook.sol"; import {BasePlugin, IERC165} from "./BasePlugin.sol"; /// @title ERC20 Token Limit Plugin @@ -20,7 +20,7 @@ import {BasePlugin, IERC165} from "./BasePlugin.sol"; /// Note: this plugin is opinionated on what selectors can be called for token contracts to guard against weird /// edge cases like DAI. You wouldn't be able to use uni v2 pairs directly as the pair contract is also the LP /// token contract -contract ERC20TokenLimitPlugin is BasePlugin, IExecutionHook { +contract ERC20TokenLimitPlugin is BasePlugin, IExecutionHook, IPermissionHook { using UserOperationLib for PackedUserOperation; using EnumerableSet for EnumerableSet.AddressSet; @@ -53,7 +53,13 @@ contract ERC20TokenLimitPlugin is BasePlugin, IExecutionHook { limits[msg.sender][token][functionId] = newLimit; } - function _checkAndDecrementLimit(uint8 functionId, bytes4 selector, uint256 spend, address token) internal { + function _decrementLimit(uint8 functionId, address token, bytes memory innerCalldata) internal { + bytes4 selector; + uint256 spend; + assembly { + selector := mload(add(innerCalldata, 32)) // 0:32 is arr len, 32:36 is selector + spend := mload(add(innerCalldata, 68)) // 36:68 is recipient, 68:100 is spend + } if (selector == IERC20.transfer.selector || selector == IERC20.approve.selector) { uint256 limit = limits[msg.sender][token][functionId]; if (spend > limit) { @@ -65,48 +71,40 @@ contract ERC20TokenLimitPlugin is BasePlugin, IExecutionHook { } } - /// @inheritdoc IExecutionHook - function preExecutionHook(uint8 functionId, bytes calldata data) external override returns (bytes memory) { - bytes calldata topLevelCallData; - bytes4 topLevelSelector; - - // TODO: plugins should never have to do these gymnastics - topLevelSelector = bytes4(data[52:56]); - if (topLevelSelector == IAccountExecute.executeUserOp.selector) { - topLevelCallData = data[56:]; - topLevelSelector = bytes4(topLevelCallData); - } else { - topLevelCallData = data[52:]; - } + /// @inheritdoc IPermissionHook + function preUserOpExecutionHook(uint8 functionId, PackedUserOperation calldata uo) + external + override + returns (bytes memory) + { + return _checkSelectorAndDecrementLimit(functionId, uo.callData); + } - if (topLevelSelector == IStandardExecutor.execute.selector) { - address token = address(uint160(uint256(bytes32(topLevelCallData[4:36])))); - if (_tokenList[msg.sender].contains(token)) { - bytes calldata executeCalldata; - uint256 offset = uint256(bytes32(topLevelCallData[68:100])); + /// @inheritdoc IExecutionHook + function preExecutionHook(uint8 functionId, address, uint256, bytes calldata data) + external + override + returns (bytes memory) + { + return _checkSelectorAndDecrementLimit(functionId, data); + } - assembly { - let relativeOffset := add(add(topLevelCallData.offset, offset), 4) - executeCalldata.offset := add(relativeOffset, 32) - executeCalldata.length := calldataload(relativeOffset) - } + function _checkSelectorAndDecrementLimit(uint8 functionId, bytes calldata data) + internal + returns (bytes memory) + { + bytes4 selector = bytes4(data[:4]); - _checkAndDecrementLimit( - functionId, bytes4(executeCalldata[:4]), uint256(bytes32(executeCalldata[36:68])), token - ); + if (selector == IStandardExecutor.execute.selector) { + (address token,, bytes memory innerCalldata) = abi.decode(data[4:], (address, uint256, bytes)); + if (_tokenList[msg.sender].contains(token)) { + _decrementLimit(functionId, token, innerCalldata); } - } else if (topLevelSelector == IStandardExecutor.executeBatch.selector) { - Call[] memory calls = abi.decode(topLevelCallData[4:], (Call[])); + } else if (selector == IStandardExecutor.executeBatch.selector) { + Call[] memory calls = abi.decode(data[4:], (Call[])); for (uint256 i = 0; i < calls.length; i++) { if (_tokenList[msg.sender].contains(calls[i].target)) { - bytes memory tokenContractCallData = calls[i].data; - bytes4 selector; - uint256 spend; - assembly { - selector := mload(add(tokenContractCallData, 32)) // 0:32 is arr len, 32:36 is selector - spend := mload(add(tokenContractCallData, 68)) // 36:68 is recipient, 68:100 is spend - } - _checkAndDecrementLimit(functionId, selector, spend, calls[i].target); + _decrementLimit(functionId, calls[i].target, calls[i].data); } } } diff --git a/test/plugin/ERC20TokenLimitPlugin.t.sol b/test/plugin/ERC20TokenLimitPlugin.t.sol index c8dd926f..c96d467c 100644 --- a/test/plugin/ERC20TokenLimitPlugin.t.sol +++ b/test/plugin/ERC20TokenLimitPlugin.t.sol @@ -44,8 +44,7 @@ contract ERC20TokenLimitPluginTest is OptimizedTest { permissionHooks[0] = ExecutionHook({ hookFunction: FunctionReferenceLib.pack(address(plugin), 0), isPreHook: true, - isPostHook: false, - requireUOContext: false + isPostHook: false }); // arr idx 0 => functionId of 0 has that spend