diff --git a/.gitmodules b/.gitmodules index 813d955e..05bd137f 100644 --- a/.gitmodules +++ b/.gitmodules @@ -7,3 +7,6 @@ [submodule "lib/forge-std"] path = lib/forge-std url = https://github.com/foundry-rs/forge-std +[submodule "lib/modular-account-libs"] + path = lib/modular-account-libs + url = https://github.com/erc6900/modular-account-libs diff --git a/lib/modular-account-libs b/lib/modular-account-libs new file mode 160000 index 00000000..5d9d0e40 --- /dev/null +++ b/lib/modular-account-libs @@ -0,0 +1 @@ +Subproject commit 5d9d0e403332251045eee2954c2a8b7ea0bae953 diff --git a/remappings.txt b/remappings.txt index 3d0ee0df..bc2ce0be 100644 --- a/remappings.txt +++ b/remappings.txt @@ -2,3 +2,4 @@ ds-test/=lib/forge-std/lib/ds-test/src/ forge-std/=lib/forge-std/src/ @eth-infinitism/account-abstraction/=lib/account-abstraction/contracts/ @openzeppelin/=lib/openzeppelin-contracts/ +@modular-account-libs/=lib/modular-account-libs/src/ \ No newline at end of file diff --git a/src/plugins/ERC20TokenLimitPlugin.sol b/src/plugins/ERC20TokenLimitPlugin.sol new file mode 100644 index 00000000..1df5bcfd --- /dev/null +++ b/src/plugins/ERC20TokenLimitPlugin.sol @@ -0,0 +1,154 @@ +// SPDX-License-Identifier: GPL-3.0 +pragma solidity ^0.8.25; + +import {PackedUserOperation} from "@eth-infinitism/account-abstraction/interfaces/PackedUserOperation.sol"; +import {UserOperationLib} from "@eth-infinitism/account-abstraction/core/UserOperationLib.sol"; +import {EnumerableSet} from "@openzeppelin/contracts/utils/structs/EnumerableSet.sol"; +import {IERC20} from "@openzeppelin/contracts/token/ERC20/IERC20.sol"; +import { + SetValue, + AssociatedLinkedListSet, + AssociatedLinkedListSetLib +} from "@modular-account-libs/libraries/AssociatedLinkedListSetLib.sol"; + +import {PluginManifest, PluginMetadata} from "../interfaces/IPlugin.sol"; +import {IStandardExecutor, Call} from "../interfaces/IStandardExecutor.sol"; +import {IPlugin} from "../interfaces/IPlugin.sol"; +import {IExecutionHook} from "../interfaces/IExecutionHook.sol"; +import {BasePlugin, IERC165} from "./BasePlugin.sol"; + +/// @title ERC20 Token Limit Plugin +/// @author ERC-6900 Authors +/// @notice This plugin supports an ERC20 token spend limit. This should be combined with a contract whitelist +/// plugin to make sure that token transfers not tracked by the plugin don't happen. +/// Note: this plugin is opinionated on what selectors can be called for token contracts to guard against weird +/// edge cases like DAI. You wouldn't be able to use uni v2 pairs directly as the pair contract is also the LP +/// token contract +contract ERC20TokenLimitPlugin is BasePlugin, IExecutionHook { + using UserOperationLib for PackedUserOperation; + using EnumerableSet for EnumerableSet.AddressSet; + using AssociatedLinkedListSetLib for AssociatedLinkedListSet; + + struct ERC20SpendLimit { + address token; + uint256[] limits; + } + + string internal constant _NAME = "ERC20 Token Limit Plugin"; + string internal constant _VERSION = "1.0.0"; + string internal constant _AUTHOR = "ERC-6900 Authors"; + + mapping(uint8 functionId => mapping(address token => mapping(address account => uint256 limit))) public limits; + AssociatedLinkedListSet internal _tokenList; + + error ExceededTokenLimit(); + error ExceededNumberOfEntities(); + error SelectorNotAllowed(); + + function updateLimits(uint8 functionId, address token, uint256 newLimit) external { + _tokenList.tryAdd(msg.sender, SetValue.wrap(bytes30(bytes20(token)))); + limits[functionId][token][msg.sender] = newLimit; + } + + /// @inheritdoc IExecutionHook + function preExecutionHook(uint8 functionId, address, uint256, bytes calldata data) + external + override + returns (bytes memory) + { + (bytes4 selector, bytes memory callData) = _getSelectorAndCalldata(data); + + if (selector == IStandardExecutor.execute.selector) { + (address token,, bytes memory innerCalldata) = abi.decode(callData, (address, uint256, bytes)); + if (_tokenList.contains(msg.sender, SetValue.wrap(bytes30(bytes20(token))))) { + _decrementLimit(functionId, token, innerCalldata); + } + } else if (selector == IStandardExecutor.executeBatch.selector) { + Call[] memory calls = abi.decode(callData, (Call[])); + for (uint256 i = 0; i < calls.length; i++) { + if (_tokenList.contains(msg.sender, SetValue.wrap(bytes30(bytes20(calls[i].target))))) { + _decrementLimit(functionId, calls[i].target, calls[i].data); + } + } + } + + return ""; + } + + /// @inheritdoc IPlugin + function onInstall(bytes calldata data) external override { + (uint8 startFunctionId, ERC20SpendLimit[] memory spendLimits) = + abi.decode(data, (uint8, ERC20SpendLimit[])); + + if (startFunctionId + spendLimits.length > type(uint8).max) { + revert ExceededNumberOfEntities(); + } + + for (uint8 i = 0; i < spendLimits.length; i++) { + _tokenList.tryAdd(msg.sender, SetValue.wrap(bytes30(bytes20(spendLimits[i].token)))); + for (uint256 j = 0; j < spendLimits[i].limits.length; j++) { + limits[i + startFunctionId][spendLimits[i].token][msg.sender] = spendLimits[i].limits[j]; + } + } + } + + /// @inheritdoc IPlugin + function onUninstall(bytes calldata data) external override { + (address token, uint8 functionId) = abi.decode(data, (address, uint8)); + delete limits[functionId][token][msg.sender]; + } + + function getTokensForAccount(address account) external view returns (address[] memory tokens) { + SetValue[] memory set = _tokenList.getAll(account); + tokens = new address[](set.length); + for (uint256 i = 0; i < tokens.length; i++) { + tokens[i] = address(bytes20(bytes32(SetValue.unwrap(set[i])))); + } + return tokens; + } + + /// @inheritdoc IExecutionHook + function postExecutionHook(uint8, bytes calldata) external pure override { + revert NotImplemented(); + } + + /// @inheritdoc IPlugin + // solhint-disable-next-line no-empty-blocks + function pluginManifest() external pure override returns (PluginManifest memory) {} + + /// @inheritdoc IPlugin + function pluginMetadata() external pure virtual override returns (PluginMetadata memory) { + PluginMetadata memory metadata; + metadata.name = _NAME; + metadata.version = _VERSION; + metadata.author = _AUTHOR; + + metadata.permissionRequest = new string[](1); + metadata.permissionRequest[0] = "erc20-token-limit"; + return metadata; + } + + /// @inheritdoc BasePlugin + function supportsInterface(bytes4 interfaceId) public view override(BasePlugin, IERC165) returns (bool) { + return super.supportsInterface(interfaceId); + } + + function _decrementLimit(uint8 functionId, address token, bytes memory innerCalldata) internal { + bytes4 selector; + uint256 spend; + assembly { + selector := mload(add(innerCalldata, 32)) // 0:32 is arr len, 32:36 is selector + spend := mload(add(innerCalldata, 68)) // 36:68 is recipient, 68:100 is spend + } + if (selector == IERC20.transfer.selector || selector == IERC20.approve.selector) { + uint256 limit = limits[functionId][token][msg.sender]; + if (spend > limit) { + revert ExceededTokenLimit(); + } + // solhint-disable-next-line reentrancy + limits[functionId][token][msg.sender] = limit - spend; + } else { + revert SelectorNotAllowed(); + } + } +} diff --git a/test/mocks/MockERC20.sol b/test/mocks/MockERC20.sol new file mode 100644 index 00000000..131e0d1a --- /dev/null +++ b/test/mocks/MockERC20.sol @@ -0,0 +1,12 @@ +// SPDX-License-Identifier: UNLICENSED +pragma solidity ^0.8.19; + +import {ERC20} from "@openzeppelin/contracts/token/ERC20/ERC20.sol"; + +contract MockERC20 is ERC20 { + constructor() ERC20("MockERC20", "MERC") {} + + function mint(address to, uint256 amount) external { + _mint(to, amount); + } +} diff --git a/test/plugin/ERC20TokenLimitPlugin.t.sol b/test/plugin/ERC20TokenLimitPlugin.t.sol new file mode 100644 index 00000000..96a18c20 --- /dev/null +++ b/test/plugin/ERC20TokenLimitPlugin.t.sol @@ -0,0 +1,184 @@ +// SPDX-License-Identifier: UNLICENSED +pragma solidity ^0.8.19; + +import {PackedUserOperation} from "@eth-infinitism/account-abstraction/interfaces/PackedUserOperation.sol"; +import {MockERC20} from "../mocks/MockERC20.sol"; +import {IERC20} from "@openzeppelin/contracts/token/ERC20/IERC20.sol"; + +import {UpgradeableModularAccount} from "../../src/account/UpgradeableModularAccount.sol"; +import {FunctionReference} from "../../src/helpers/FunctionReferenceLib.sol"; +import {ERC20TokenLimitPlugin} from "../../src/plugins/ERC20TokenLimitPlugin.sol"; +import {MockPlugin} from "../mocks/MockPlugin.sol"; +import {ExecutionHook} from "../../src/interfaces/IAccountLoupe.sol"; +import {FunctionReferenceLib} from "../../src/helpers/FunctionReferenceLib.sol"; +import {IStandardExecutor, Call} from "../../src/interfaces/IStandardExecutor.sol"; +import {PluginManifest} from "../../src/interfaces/IPlugin.sol"; +import {ValidationConfigLib} from "../../src/helpers/ValidationConfigLib.sol"; + +import {MSCAFactoryFixture} from "../mocks/MSCAFactoryFixture.sol"; +import {AccountTestBase} from "../utils/AccountTestBase.sol"; + +contract ERC20TokenLimitPluginTest is AccountTestBase { + address public recipient = address(1); + MockERC20 public erc20; + address payable public bundler = payable(address(2)); + PluginManifest internal _m; + MockPlugin public validationPlugin = new MockPlugin(_m); + FunctionReference public validationFunction; + + UpgradeableModularAccount public acct; + ERC20TokenLimitPlugin public plugin = new ERC20TokenLimitPlugin(); + uint256 public spendLimit = 10 ether; + + function setUp() public { + // Set up a validator with hooks from the erc20 spend limit plugin attached + MSCAFactoryFixture factory = new MSCAFactoryFixture(entryPoint, _deploySingleOwnerPlugin()); + + acct = factory.createAccount(address(this), 0); + + erc20 = new MockERC20(); + erc20.mint(address(acct), 10 ether); + + ExecutionHook[] memory permissionHooks = new ExecutionHook[](1); + permissionHooks[0] = ExecutionHook({ + hookFunction: FunctionReferenceLib.pack(address(plugin), 0), + isPreHook: true, + isPostHook: false + }); + + // arr idx 0 => functionId of 0 has that spend + uint256[] memory limits = new uint256[](1); + limits[0] = spendLimit; + + ERC20TokenLimitPlugin.ERC20SpendLimit[] memory limit = new ERC20TokenLimitPlugin.ERC20SpendLimit[](1); + limit[0] = ERC20TokenLimitPlugin.ERC20SpendLimit({token: address(erc20), limits: limits}); + + bytes[] memory permissionInitDatas = new bytes[](1); + permissionInitDatas[0] = abi.encode(uint8(0), limit); + + vm.prank(address(acct)); + acct.installValidation( + ValidationConfigLib.pack(address(validationPlugin), 0, true, true), + new bytes4[](0), + new bytes(0), + new bytes(0), + abi.encode(permissionHooks, permissionInitDatas) + ); + + validationFunction = FunctionReferenceLib.pack(address(validationPlugin), 0); + } + + function _getPackedUO(bytes memory callData) internal view returns (PackedUserOperation memory uo) { + uo = PackedUserOperation({ + sender: address(acct), + nonce: 0, + initCode: "", + callData: abi.encodePacked(UpgradeableModularAccount.executeUserOp.selector, callData), + accountGasLimits: bytes32(bytes16(uint128(200000))) | bytes32(uint256(200000)), + preVerificationGas: 200000, + gasFees: bytes32(uint256(uint128(0))), + paymasterAndData: "", + signature: _encodeSignature(FunctionReferenceLib.pack(address(validationPlugin), 0), 1, "") + }); + } + + function _getExecuteWithSpend(uint256 value) internal view returns (bytes memory) { + return abi.encodeCall( + UpgradeableModularAccount.execute, + (address(erc20), 0, abi.encodeCall(IERC20.transfer, (recipient, value))) + ); + } + + function test_userOp_executeLimit() public { + vm.startPrank(address(entryPoint)); + assertEq(plugin.limits(0, address(erc20), address(acct)), 10 ether); + acct.executeUserOp(_getPackedUO(_getExecuteWithSpend(5 ether)), bytes32(0)); + assertEq(plugin.limits(0, address(erc20), address(acct)), 5 ether); + } + + function test_userOp_executeBatchLimit() public { + Call[] memory calls = new Call[](3); + calls[0] = + Call({target: address(erc20), value: 0, data: abi.encodeCall(IERC20.transfer, (recipient, 1 wei))}); + calls[1] = + Call({target: address(erc20), value: 0, data: abi.encodeCall(IERC20.transfer, (recipient, 1 ether))}); + calls[2] = Call({ + target: address(erc20), + value: 0, + data: abi.encodeCall(IERC20.transfer, (recipient, 5 ether + 100000)) + }); + + vm.startPrank(address(entryPoint)); + assertEq(plugin.limits(0, address(erc20), address(acct)), 10 ether); + acct.executeUserOp(_getPackedUO(abi.encodeCall(IStandardExecutor.executeBatch, (calls))), bytes32(0)); + assertEq(plugin.limits(0, address(erc20), address(acct)), 10 ether - 6 ether - 100001); + } + + function test_userOp_executeBatch_approveAndTransferLimit() public { + Call[] memory calls = new Call[](3); + calls[0] = + Call({target: address(erc20), value: 0, data: abi.encodeCall(IERC20.approve, (recipient, 1 wei))}); + calls[1] = + Call({target: address(erc20), value: 0, data: abi.encodeCall(IERC20.transfer, (recipient, 1 ether))}); + calls[2] = Call({ + target: address(erc20), + value: 0, + data: abi.encodeCall(IERC20.approve, (recipient, 5 ether + 100000)) + }); + + vm.startPrank(address(entryPoint)); + assertEq(plugin.limits(0, address(erc20), address(acct)), 10 ether); + acct.executeUserOp(_getPackedUO(abi.encodeCall(IStandardExecutor.executeBatch, (calls))), bytes32(0)); + assertEq(plugin.limits(0, address(erc20), address(acct)), 10 ether - 6 ether - 100001); + } + + function test_userOp_executeBatch_approveAndTransferLimit_fail() public { + Call[] memory calls = new Call[](3); + calls[0] = + Call({target: address(erc20), value: 0, data: abi.encodeCall(IERC20.approve, (recipient, 1 wei))}); + calls[1] = + Call({target: address(erc20), value: 0, data: abi.encodeCall(IERC20.transfer, (recipient, 1 ether))}); + calls[2] = Call({ + target: address(erc20), + value: 0, + data: abi.encodeCall(IERC20.approve, (recipient, 9 ether + 100000)) + }); + + vm.startPrank(address(entryPoint)); + assertEq(plugin.limits(0, address(erc20), address(acct)), 10 ether); + PackedUserOperation[] memory uos = new PackedUserOperation[](1); + uos[0] = _getPackedUO(abi.encodeCall(IStandardExecutor.executeBatch, (calls))); + entryPoint.handleOps(uos, bundler); + // no spend consumed + assertEq(plugin.limits(0, address(erc20), address(acct)), 10 ether); + } + + function test_runtime_executeLimit() public { + assertEq(plugin.limits(0, address(erc20), address(acct)), 10 ether); + acct.executeWithAuthorization( + _getExecuteWithSpend(5 ether), + _encodeSignature(FunctionReferenceLib.pack(address(validationPlugin), 0), 1, "") + ); + assertEq(plugin.limits(0, address(erc20), address(acct)), 5 ether); + } + + function test_runtime_executeBatchLimit() public { + Call[] memory calls = new Call[](3); + calls[0] = + Call({target: address(erc20), value: 0, data: abi.encodeCall(IERC20.approve, (recipient, 1 wei))}); + calls[1] = + Call({target: address(erc20), value: 0, data: abi.encodeCall(IERC20.transfer, (recipient, 1 ether))}); + calls[2] = Call({ + target: address(erc20), + value: 0, + data: abi.encodeCall(IERC20.approve, (recipient, 5 ether + 100000)) + }); + + assertEq(plugin.limits(0, address(erc20), address(acct)), 10 ether); + acct.executeWithAuthorization( + abi.encodeCall(IStandardExecutor.executeBatch, (calls)), + _encodeSignature(FunctionReferenceLib.pack(address(validationPlugin), 0), 1, "") + ); + assertEq(plugin.limits(0, address(erc20), address(acct)), 10 ether - 6 ether - 100001); + } +}