From 547b99a8ce6d30b788a3d541b2f93acb3c0f72f1 Mon Sep 17 00:00:00 2001 From: howydev <132113803+howydev@users.noreply.github.com> Date: Wed, 12 Jun 2024 16:04:58 -0400 Subject: [PATCH] feat: add erc20 token spend limit plugin --- src/plugins/ERC20TokenLimitPlugin.sol | 170 ++++++++++++++++++++++ test/mocks/MockERC20.sol | 12 ++ test/plugin/ERC20TokenLimitPlugin.t.sol | 185 ++++++++++++++++++++++++ 3 files changed, 367 insertions(+) create mode 100644 src/plugins/ERC20TokenLimitPlugin.sol create mode 100644 test/mocks/MockERC20.sol create mode 100644 test/plugin/ERC20TokenLimitPlugin.t.sol diff --git a/src/plugins/ERC20TokenLimitPlugin.sol b/src/plugins/ERC20TokenLimitPlugin.sol new file mode 100644 index 00000000..d8e7ded0 --- /dev/null +++ b/src/plugins/ERC20TokenLimitPlugin.sol @@ -0,0 +1,170 @@ +// SPDX-License-Identifier: GPL-3.0 +pragma solidity ^0.8.25; + +import {PackedUserOperation} from "@eth-infinitism/account-abstraction/interfaces/PackedUserOperation.sol"; +import {UserOperationLib} from "@eth-infinitism/account-abstraction/core/UserOperationLib.sol"; +import {EnumerableSet} from "@openzeppelin/contracts/utils/structs/EnumerableSet.sol"; +import {IAccountExecute} from "@eth-infinitism/account-abstraction/interfaces/IAccountExecute.sol"; +import {IERC20} from "@openzeppelin/contracts/token/ERC20/IERC20.sol"; + +import {PluginManifest, PluginMetadata} from "../interfaces/IPlugin.sol"; +import {IStandardExecutor, Call} from "../interfaces/IStandardExecutor.sol"; +import {IPlugin} from "../interfaces/IPlugin.sol"; +import {IExecutionHook} from "../interfaces/IExecutionHook.sol"; +import {BasePlugin, IERC165} from "./BasePlugin.sol"; + +/// @title ERC20 Token Limit Plugin +/// @author ERC-6900 Authors +/// @notice This plugin supports an ERC20 token spend limit. This should be combined with a contract whitelist +/// plugin to make sure that token transfers not tracked by the plugin don't happen. +/// Note: this plugin is opinionated on what selectors can be called for token contracts to guard against weird +/// edge cases like DAI. You wouldn't be able to use uni v2 pairs directly as the pair contract is also the LP +/// token contract +contract ERC20TokenLimitPlugin is BasePlugin, IExecutionHook { + using UserOperationLib for PackedUserOperation; + using EnumerableSet for EnumerableSet.AddressSet; + + string public constant NAME = "ERC20 Token Limit Plugin"; + string public constant VERSION = "1.0.0"; + string public constant AUTHOR = "ERC-6900 Authors"; + + struct ERC20SpendLimit { + address token; + uint256[] limits; + } + + mapping(address account => mapping(address => uint256[] limits)) public limits; + mapping(address account => EnumerableSet.AddressSet) internal _tokenList; + + error ExceededNativeTokenLimit(); + error ExceededNumberOfEntities(); + error SelectorNotAllowed(); + + function getTokensForAccount(address account) external view returns (address[] memory tokens) { + tokens = new address[](_tokenList[account].length()); + for (uint256 i = 0; i < _tokenList[account].length(); i++) { + tokens[i] = _tokenList[account].at(i); + } + return tokens; + } + + function updateLimits(uint8 functionId, address token, uint256 newLimit) external { + _tokenList[msg.sender].add(token); + limits[msg.sender][token][functionId] = newLimit; + } + + function _checkAndDecrementLimit(uint8 functionId, bytes4 selector, uint256 spend, address token) internal { + if (selector == IERC20.transfer.selector || selector == IERC20.approve.selector) { + uint256 limit = limits[msg.sender][token][functionId]; + if (spend > limit) { + revert ExceededNativeTokenLimit(); + } + limits[msg.sender][token][functionId] = limit - spend; + } else { + revert SelectorNotAllowed(); + } + } + + /// @inheritdoc IExecutionHook + function preExecutionHook(uint8 functionId, bytes calldata data) external override returns (bytes memory) { + bytes calldata topLevelCallData; + bytes4 topLevelSelector; + + // TODO: plugins should never have to do these gymnastics + topLevelSelector = bytes4(data[52:56]); + if (topLevelSelector == IAccountExecute.executeUserOp.selector) { + topLevelCallData = data[56:]; + topLevelSelector = bytes4(topLevelCallData); + } else { + topLevelCallData = data[52:]; + } + + if (topLevelSelector == IStandardExecutor.execute.selector) { + address token = address(uint160(uint256(bytes32(topLevelCallData[4:36])))); + if (_tokenList[msg.sender].contains(token)) { + bytes calldata executeCalldata; + uint256 offset = uint256(bytes32(topLevelCallData[68:100])); + + assembly { + let relativeOffset := add(add(topLevelCallData.offset, offset), 4) + executeCalldata.offset := add(relativeOffset, 32) + executeCalldata.length := calldataload(relativeOffset) + } + + _checkAndDecrementLimit( + functionId, bytes4(executeCalldata[:4]), uint256(bytes32(executeCalldata[36:68])), token + ); + } + } else if (topLevelSelector == IStandardExecutor.executeBatch.selector) { + Call[] memory calls = abi.decode(topLevelCallData[4:], (Call[])); + for (uint256 i = 0; i < calls.length; i++) { + if (_tokenList[msg.sender].contains(calls[i].target)) { + bytes memory tokenContractCallData = calls[i].data; + bytes4 selector; + uint256 spend; + assembly { + selector := mload(add(tokenContractCallData, 32)) // 0:32 is arr len, 32:36 is selector + spend := mload(add(tokenContractCallData, 68)) // 36:68 is recipient, 68:100 is spend + } + _checkAndDecrementLimit(functionId, selector, spend, calls[i].target); + } + } + } + + return ""; + } + + /// @inheritdoc IExecutionHook + function postExecutionHook(uint8, bytes calldata) external pure override { + revert NotImplemented(); + } + + /// @inheritdoc IPlugin + function onInstall(bytes calldata data) external override { + ERC20SpendLimit[] memory spendLimits = abi.decode(data, (ERC20SpendLimit[])); + + for (uint256 i = 0; i < spendLimits.length; i++) { + _tokenList[msg.sender].add(spendLimits[i].token); + for (uint256 j = 0; j < spendLimits[i].limits.length; j++) { + limits[msg.sender][spendLimits[i].token].push(spendLimits[i].limits[j]); + } + if (limits[msg.sender][spendLimits[i].token].length > type(uint8).max) { + revert ExceededNumberOfEntities(); + } + } + } + + /// @inheritdoc IPlugin + function onUninstall(bytes calldata data) external override { + (address token, uint8 functionId) = abi.decode(data, (address, uint8)); + delete limits[msg.sender][token][functionId]; + } + + /// @inheritdoc IPlugin + function pluginManifest() external pure override returns (PluginManifest memory) { + // silence warnings + PluginManifest memory manifest; + return manifest; + } + + /// @inheritdoc IPlugin + function pluginMetadata() external pure virtual override returns (PluginMetadata memory) { + PluginMetadata memory metadata; + metadata.name = NAME; + metadata.version = VERSION; + metadata.author = AUTHOR; + + metadata.permissionRequest = new string[](1); + metadata.permissionRequest[0] = "erc20-token-limit"; + return metadata; + } + + // ┏━━━━━━━━━━━━━━━┓ + // ┃ EIP-165 ┃ + // ┗━━━━━━━━━━━━━━━┛ + + /// @inheritdoc BasePlugin + function supportsInterface(bytes4 interfaceId) public view override(BasePlugin, IERC165) returns (bool) { + return super.supportsInterface(interfaceId); + } +} diff --git a/test/mocks/MockERC20.sol b/test/mocks/MockERC20.sol new file mode 100644 index 00000000..131e0d1a --- /dev/null +++ b/test/mocks/MockERC20.sol @@ -0,0 +1,12 @@ +// SPDX-License-Identifier: UNLICENSED +pragma solidity ^0.8.19; + +import {ERC20} from "@openzeppelin/contracts/token/ERC20/ERC20.sol"; + +contract MockERC20 is ERC20 { + constructor() ERC20("MockERC20", "MERC") {} + + function mint(address to, uint256 amount) external { + _mint(to, amount); + } +} diff --git a/test/plugin/ERC20TokenLimitPlugin.t.sol b/test/plugin/ERC20TokenLimitPlugin.t.sol new file mode 100644 index 00000000..c8dd926f --- /dev/null +++ b/test/plugin/ERC20TokenLimitPlugin.t.sol @@ -0,0 +1,185 @@ +// SPDX-License-Identifier: UNLICENSED +pragma solidity ^0.8.19; + +import {EntryPoint} from "@eth-infinitism/account-abstraction/core/EntryPoint.sol"; +import {PackedUserOperation} from "@eth-infinitism/account-abstraction/interfaces/PackedUserOperation.sol"; +import {MockERC20} from "../mocks/MockERC20.sol"; +import {IERC20} from "@openzeppelin/contracts/token/ERC20/IERC20.sol"; + +import {UpgradeableModularAccount} from "../../src/account/UpgradeableModularAccount.sol"; +import {FunctionReference} from "../../src/helpers/FunctionReferenceLib.sol"; +import {ERC20TokenLimitPlugin} from "../../src/plugins/ERC20TokenLimitPlugin.sol"; +import {MockPlugin} from "../mocks/MockPlugin.sol"; +import {ExecutionHook} from "../../src/interfaces/IAccountLoupe.sol"; +import {FunctionReferenceLib} from "../../src/helpers/FunctionReferenceLib.sol"; +import {IStandardExecutor, Call} from "../../src/interfaces/IStandardExecutor.sol"; +import {PluginManifest} from "../../src/interfaces/IPlugin.sol"; + +import {MSCAFactoryFixture} from "../mocks/MSCAFactoryFixture.sol"; +import {OptimizedTest} from "../utils/OptimizedTest.sol"; + +contract ERC20TokenLimitPluginTest is OptimizedTest { + EntryPoint public entryPoint = new EntryPoint(); + address public recipient = address(1); + MockERC20 public erc20; + address payable public bundler = payable(address(2)); + PluginManifest internal _m; + MockPlugin public validationPlugin = new MockPlugin(_m); + FunctionReference public validationFunction; + + UpgradeableModularAccount public acct; + ERC20TokenLimitPlugin public plugin = new ERC20TokenLimitPlugin(); + uint256 public spendLimit = 10 ether; + + function setUp() public { + // Set up a validator with hooks from the erc20 spend limit plugin attached + MSCAFactoryFixture factory = new MSCAFactoryFixture(entryPoint, _deploySingleOwnerPlugin()); + + acct = factory.createAccount(address(this), 0); + + erc20 = new MockERC20(); + erc20.mint(address(acct), 10 ether); + + ExecutionHook[] memory permissionHooks = new ExecutionHook[](1); + permissionHooks[0] = ExecutionHook({ + hookFunction: FunctionReferenceLib.pack(address(plugin), 0), + isPreHook: true, + isPostHook: false, + requireUOContext: false + }); + + // arr idx 0 => functionId of 0 has that spend + uint256[] memory limits = new uint256[](1); + limits[0] = spendLimit; + + ERC20TokenLimitPlugin.ERC20SpendLimit[] memory limit = new ERC20TokenLimitPlugin.ERC20SpendLimit[](1); + limit[0] = ERC20TokenLimitPlugin.ERC20SpendLimit({token: address(erc20), limits: limits}); + + bytes[] memory permissionInitDatas = new bytes[](1); + permissionInitDatas[0] = abi.encode(limit); + + vm.prank(address(acct)); + acct.installValidation( + FunctionReferenceLib.pack(address(validationPlugin), 0), + true, + new bytes4[](0), + new bytes(0), + new bytes(0), + abi.encode(permissionHooks, permissionInitDatas) + ); + + validationFunction = FunctionReferenceLib.pack(address(validationPlugin), 0); + } + + function _getPackedUO(bytes memory callData) internal view returns (PackedUserOperation memory uo) { + uo = PackedUserOperation({ + sender: address(acct), + nonce: 0, + initCode: "", + callData: abi.encodePacked(UpgradeableModularAccount.executeUserOp.selector, callData), + accountGasLimits: bytes32(bytes16(uint128(200000))) | bytes32(uint256(200000)), + preVerificationGas: 200000, + gasFees: bytes32(uint256(uint128(0))), + paymasterAndData: "", + signature: abi.encodePacked(FunctionReferenceLib.pack(address(validationPlugin), 0), uint8(1)) + }); + } + + function _getExecuteWithSpend(uint256 value) internal view returns (bytes memory) { + return abi.encodeCall( + UpgradeableModularAccount.execute, + (address(erc20), 0, abi.encodeCall(IERC20.transfer, (recipient, value))) + ); + } + + function test_userOp_executeLimit() public { + vm.startPrank(address(entryPoint)); + assertEq(plugin.limits(address(acct), address(erc20), 0), 10 ether); + acct.executeUserOp(_getPackedUO(_getExecuteWithSpend(5 ether)), bytes32(0)); + assertEq(plugin.limits(address(acct), address(erc20), 0), 5 ether); + } + + function test_userOp_executeBatchLimit() public { + Call[] memory calls = new Call[](3); + calls[0] = + Call({target: address(erc20), value: 0, data: abi.encodeCall(IERC20.transfer, (recipient, 1 wei))}); + calls[1] = + Call({target: address(erc20), value: 0, data: abi.encodeCall(IERC20.transfer, (recipient, 1 ether))}); + calls[2] = Call({ + target: address(erc20), + value: 0, + data: abi.encodeCall(IERC20.transfer, (recipient, 5 ether + 100000)) + }); + + vm.startPrank(address(entryPoint)); + assertEq(plugin.limits(address(acct), address(erc20), 0), 10 ether); + acct.executeUserOp(_getPackedUO(abi.encodeCall(IStandardExecutor.executeBatch, (calls))), bytes32(0)); + assertEq(plugin.limits(address(acct), address(erc20), 0), 10 ether - 6 ether - 100001); + } + + function test_userOp_executeBatch_approveAndTransferLimit() public { + Call[] memory calls = new Call[](3); + calls[0] = + Call({target: address(erc20), value: 0, data: abi.encodeCall(IERC20.approve, (recipient, 1 wei))}); + calls[1] = + Call({target: address(erc20), value: 0, data: abi.encodeCall(IERC20.transfer, (recipient, 1 ether))}); + calls[2] = Call({ + target: address(erc20), + value: 0, + data: abi.encodeCall(IERC20.approve, (recipient, 5 ether + 100000)) + }); + + vm.startPrank(address(entryPoint)); + assertEq(plugin.limits(address(acct), address(erc20), 0), 10 ether); + acct.executeUserOp(_getPackedUO(abi.encodeCall(IStandardExecutor.executeBatch, (calls))), bytes32(0)); + assertEq(plugin.limits(address(acct), address(erc20), 0), 10 ether - 6 ether - 100001); + } + + function test_userOp_executeBatch_approveAndTransferLimit_fail() public { + Call[] memory calls = new Call[](3); + calls[0] = + Call({target: address(erc20), value: 0, data: abi.encodeCall(IERC20.approve, (recipient, 1 wei))}); + calls[1] = + Call({target: address(erc20), value: 0, data: abi.encodeCall(IERC20.transfer, (recipient, 1 ether))}); + calls[2] = Call({ + target: address(erc20), + value: 0, + data: abi.encodeCall(IERC20.approve, (recipient, 9 ether + 100000)) + }); + + vm.startPrank(address(entryPoint)); + assertEq(plugin.limits(address(acct), address(erc20), 0), 10 ether); + PackedUserOperation[] memory uos = new PackedUserOperation[](1); + uos[0] = _getPackedUO(abi.encodeCall(IStandardExecutor.executeBatch, (calls))); + entryPoint.handleOps(uos, bundler); + // no spend consumed + assertEq(plugin.limits(address(acct), address(erc20), 0), 10 ether); + } + + function test_runtime_executeLimit() public { + assertEq(plugin.limits(address(acct), address(erc20), 0), 10 ether); + acct.executeWithAuthorization( + _getExecuteWithSpend(5 ether), abi.encodePacked(validationFunction, uint8(1)) + ); + assertEq(plugin.limits(address(acct), address(erc20), 0), 5 ether); + } + + function test_runtime_executeBatchLimit() public { + Call[] memory calls = new Call[](3); + calls[0] = + Call({target: address(erc20), value: 0, data: abi.encodeCall(IERC20.approve, (recipient, 1 wei))}); + calls[1] = + Call({target: address(erc20), value: 0, data: abi.encodeCall(IERC20.transfer, (recipient, 1 ether))}); + calls[2] = Call({ + target: address(erc20), + value: 0, + data: abi.encodeCall(IERC20.approve, (recipient, 5 ether + 100000)) + }); + + assertEq(plugin.limits(address(acct), address(erc20), 0), 10 ether); + acct.executeWithAuthorization( + abi.encodeCall(IStandardExecutor.executeBatch, (calls)), abi.encodePacked(validationFunction, uint8(1)) + ); + assertEq(plugin.limits(address(acct), address(erc20), 0), 10 ether - 6 ether - 100001); + } +}