Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[5/n permissions] feat: add erc20 token spend limit plugin #80

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitmodules
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,6 @@
[submodule "lib/forge-std"]
path = lib/forge-std
url = https://github.com/foundry-rs/forge-std
[submodule "lib/modular-account-libs"]
path = lib/modular-account-libs
url = https://github.com/erc6900/modular-account-libs
1 change: 1 addition & 0 deletions lib/modular-account-libs
Submodule modular-account-libs added at 5d9d0e
1 change: 1 addition & 0 deletions remappings.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@ ds-test/=lib/forge-std/lib/ds-test/src/
forge-std/=lib/forge-std/src/
@eth-infinitism/account-abstraction/=lib/account-abstraction/contracts/
@openzeppelin/=lib/openzeppelin-contracts/
@modular-account-libs/=lib/modular-account-libs/src/
154 changes: 154 additions & 0 deletions src/plugins/ERC20TokenLimitPlugin.sol
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
// SPDX-License-Identifier: GPL-3.0
pragma solidity ^0.8.25;

import {PackedUserOperation} from "@eth-infinitism/account-abstraction/interfaces/PackedUserOperation.sol";
import {UserOperationLib} from "@eth-infinitism/account-abstraction/core/UserOperationLib.sol";
import {EnumerableSet} from "@openzeppelin/contracts/utils/structs/EnumerableSet.sol";
import {IERC20} from "@openzeppelin/contracts/token/ERC20/IERC20.sol";
import {
SetValue,
AssociatedLinkedListSet,
AssociatedLinkedListSetLib
} from "@modular-account-libs/libraries/AssociatedLinkedListSetLib.sol";

import {PluginManifest, PluginMetadata} from "../interfaces/IPlugin.sol";
import {IStandardExecutor, Call} from "../interfaces/IStandardExecutor.sol";
import {IPlugin} from "../interfaces/IPlugin.sol";
import {IExecutionHook} from "../interfaces/IExecutionHook.sol";
import {BasePlugin, IERC165} from "./BasePlugin.sol";

/// @title ERC20 Token Limit Plugin
/// @author ERC-6900 Authors
/// @notice This plugin supports an ERC20 token spend limit. This should be combined with a contract whitelist
/// plugin to make sure that token transfers not tracked by the plugin don't happen.
/// Note: this plugin is opinionated on what selectors can be called for token contracts to guard against weird
/// edge cases like DAI. You wouldn't be able to use uni v2 pairs directly as the pair contract is also the LP
/// token contract
contract ERC20TokenLimitPlugin is BasePlugin, IExecutionHook {
using UserOperationLib for PackedUserOperation;
using EnumerableSet for EnumerableSet.AddressSet;
using AssociatedLinkedListSetLib for AssociatedLinkedListSet;

struct ERC20SpendLimit {
address token;
uint256[] limits;
}

string internal constant _NAME = "ERC20 Token Limit Plugin";
string internal constant _VERSION = "1.0.0";
string internal constant _AUTHOR = "ERC-6900 Authors";

mapping(uint8 functionId => mapping(address token => mapping(address account => uint256 limit))) public limits;
AssociatedLinkedListSet internal _tokenList;

error ExceededTokenLimit();
error ExceededNumberOfEntities();
error SelectorNotAllowed();

function updateLimits(uint8 functionId, address token, uint256 newLimit) external {
_tokenList.tryAdd(msg.sender, SetValue.wrap(bytes30(bytes20(token))));
limits[functionId][token][msg.sender] = newLimit;
}

/// @inheritdoc IExecutionHook
function preExecutionHook(uint8 functionId, address, uint256, bytes calldata data)
external
override
returns (bytes memory)
{
(bytes4 selector, bytes memory callData) = _getSelectorAndCalldata(data);

if (selector == IStandardExecutor.execute.selector) {
(address token,, bytes memory innerCalldata) = abi.decode(callData, (address, uint256, bytes));
if (_tokenList.contains(msg.sender, SetValue.wrap(bytes30(bytes20(token))))) {
_decrementLimit(functionId, token, innerCalldata);
}
} else if (selector == IStandardExecutor.executeBatch.selector) {
Call[] memory calls = abi.decode(callData, (Call[]));
for (uint256 i = 0; i < calls.length; i++) {
if (_tokenList.contains(msg.sender, SetValue.wrap(bytes30(bytes20(calls[i].target))))) {
_decrementLimit(functionId, calls[i].target, calls[i].data);
}
}
}

return "";
}

/// @inheritdoc IPlugin
function onInstall(bytes calldata data) external override {
(uint8 startFunctionId, ERC20SpendLimit[] memory spendLimits) =
abi.decode(data, (uint8, ERC20SpendLimit[]));

if (startFunctionId + spendLimits.length > type(uint8).max) {
revert ExceededNumberOfEntities();
}

for (uint8 i = 0; i < spendLimits.length; i++) {
_tokenList.tryAdd(msg.sender, SetValue.wrap(bytes30(bytes20(spendLimits[i].token))));
for (uint256 j = 0; j < spendLimits[i].limits.length; j++) {
limits[i + startFunctionId][spendLimits[i].token][msg.sender] = spendLimits[i].limits[j];
}
}
}

/// @inheritdoc IPlugin
function onUninstall(bytes calldata data) external override {
(address token, uint8 functionId) = abi.decode(data, (address, uint8));
delete limits[functionId][token][msg.sender];
}

function getTokensForAccount(address account) external view returns (address[] memory tokens) {
SetValue[] memory set = _tokenList.getAll(account);
tokens = new address[](set.length);
for (uint256 i = 0; i < tokens.length; i++) {
tokens[i] = address(bytes20(bytes32(SetValue.unwrap(set[i]))));
}
return tokens;
}

/// @inheritdoc IExecutionHook
function postExecutionHook(uint8, bytes calldata) external pure override {
revert NotImplemented();
}

/// @inheritdoc IPlugin
// solhint-disable-next-line no-empty-blocks
function pluginManifest() external pure override returns (PluginManifest memory) {}

/// @inheritdoc IPlugin
function pluginMetadata() external pure virtual override returns (PluginMetadata memory) {
PluginMetadata memory metadata;
metadata.name = _NAME;
metadata.version = _VERSION;
metadata.author = _AUTHOR;

metadata.permissionRequest = new string[](1);
metadata.permissionRequest[0] = "erc20-token-limit";
return metadata;
}

/// @inheritdoc BasePlugin
function supportsInterface(bytes4 interfaceId) public view override(BasePlugin, IERC165) returns (bool) {
return super.supportsInterface(interfaceId);
}

function _decrementLimit(uint8 functionId, address token, bytes memory innerCalldata) internal {
bytes4 selector;
uint256 spend;
assembly {
selector := mload(add(innerCalldata, 32)) // 0:32 is arr len, 32:36 is selector
spend := mload(add(innerCalldata, 68)) // 36:68 is recipient, 68:100 is spend
}
if (selector == IERC20.transfer.selector || selector == IERC20.approve.selector) {
uint256 limit = limits[functionId][token][msg.sender];
if (spend > limit) {
revert ExceededTokenLimit();
}
// solhint-disable-next-line reentrancy
limits[functionId][token][msg.sender] = limit - spend;
} else {
revert SelectorNotAllowed();
}
}
}
12 changes: 12 additions & 0 deletions test/mocks/MockERC20.sol
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
// SPDX-License-Identifier: UNLICENSED
pragma solidity ^0.8.19;

import {ERC20} from "@openzeppelin/contracts/token/ERC20/ERC20.sol";

contract MockERC20 is ERC20 {
constructor() ERC20("MockERC20", "MERC") {}

function mint(address to, uint256 amount) external {
_mint(to, amount);
}
}
184 changes: 184 additions & 0 deletions test/plugin/ERC20TokenLimitPlugin.t.sol
Original file line number Diff line number Diff line change
@@ -0,0 +1,184 @@
// SPDX-License-Identifier: UNLICENSED
pragma solidity ^0.8.19;

import {PackedUserOperation} from "@eth-infinitism/account-abstraction/interfaces/PackedUserOperation.sol";
import {MockERC20} from "../mocks/MockERC20.sol";
import {IERC20} from "@openzeppelin/contracts/token/ERC20/IERC20.sol";

import {UpgradeableModularAccount} from "../../src/account/UpgradeableModularAccount.sol";
import {FunctionReference} from "../../src/helpers/FunctionReferenceLib.sol";
import {ERC20TokenLimitPlugin} from "../../src/plugins/ERC20TokenLimitPlugin.sol";
import {MockPlugin} from "../mocks/MockPlugin.sol";
import {ExecutionHook} from "../../src/interfaces/IAccountLoupe.sol";
import {FunctionReferenceLib} from "../../src/helpers/FunctionReferenceLib.sol";
import {IStandardExecutor, Call} from "../../src/interfaces/IStandardExecutor.sol";
import {PluginManifest} from "../../src/interfaces/IPlugin.sol";
import {ValidationConfigLib} from "../../src/helpers/ValidationConfigLib.sol";

import {MSCAFactoryFixture} from "../mocks/MSCAFactoryFixture.sol";
import {AccountTestBase} from "../utils/AccountTestBase.sol";

contract ERC20TokenLimitPluginTest is AccountTestBase {
address public recipient = address(1);
MockERC20 public erc20;
address payable public bundler = payable(address(2));
PluginManifest internal _m;
MockPlugin public validationPlugin = new MockPlugin(_m);
FunctionReference public validationFunction;

UpgradeableModularAccount public acct;
ERC20TokenLimitPlugin public plugin = new ERC20TokenLimitPlugin();
uint256 public spendLimit = 10 ether;

function setUp() public {
// Set up a validator with hooks from the erc20 spend limit plugin attached
MSCAFactoryFixture factory = new MSCAFactoryFixture(entryPoint, _deploySingleOwnerPlugin());

acct = factory.createAccount(address(this), 0);

erc20 = new MockERC20();
erc20.mint(address(acct), 10 ether);

ExecutionHook[] memory permissionHooks = new ExecutionHook[](1);
permissionHooks[0] = ExecutionHook({
hookFunction: FunctionReferenceLib.pack(address(plugin), 0),
isPreHook: true,
isPostHook: false
});

// arr idx 0 => functionId of 0 has that spend
uint256[] memory limits = new uint256[](1);
limits[0] = spendLimit;

ERC20TokenLimitPlugin.ERC20SpendLimit[] memory limit = new ERC20TokenLimitPlugin.ERC20SpendLimit[](1);
limit[0] = ERC20TokenLimitPlugin.ERC20SpendLimit({token: address(erc20), limits: limits});

bytes[] memory permissionInitDatas = new bytes[](1);
permissionInitDatas[0] = abi.encode(uint8(0), limit);

vm.prank(address(acct));
acct.installValidation(
ValidationConfigLib.pack(address(validationPlugin), 0, true, true),
new bytes4[](0),
new bytes(0),
new bytes(0),
abi.encode(permissionHooks, permissionInitDatas)
);

validationFunction = FunctionReferenceLib.pack(address(validationPlugin), 0);
}

function _getPackedUO(bytes memory callData) internal view returns (PackedUserOperation memory uo) {
uo = PackedUserOperation({
sender: address(acct),
nonce: 0,
initCode: "",
callData: abi.encodePacked(UpgradeableModularAccount.executeUserOp.selector, callData),
accountGasLimits: bytes32(bytes16(uint128(200000))) | bytes32(uint256(200000)),
preVerificationGas: 200000,
gasFees: bytes32(uint256(uint128(0))),
paymasterAndData: "",
signature: _encodeSignature(FunctionReferenceLib.pack(address(validationPlugin), 0), 1, "")
});
}

function _getExecuteWithSpend(uint256 value) internal view returns (bytes memory) {
return abi.encodeCall(
UpgradeableModularAccount.execute,
(address(erc20), 0, abi.encodeCall(IERC20.transfer, (recipient, value)))
);
}

function test_userOp_executeLimit() public {
vm.startPrank(address(entryPoint));
assertEq(plugin.limits(0, address(erc20), address(acct)), 10 ether);
acct.executeUserOp(_getPackedUO(_getExecuteWithSpend(5 ether)), bytes32(0));
assertEq(plugin.limits(0, address(erc20), address(acct)), 5 ether);
}

function test_userOp_executeBatchLimit() public {
Call[] memory calls = new Call[](3);
calls[0] =
Call({target: address(erc20), value: 0, data: abi.encodeCall(IERC20.transfer, (recipient, 1 wei))});
calls[1] =
Call({target: address(erc20), value: 0, data: abi.encodeCall(IERC20.transfer, (recipient, 1 ether))});
calls[2] = Call({
target: address(erc20),
value: 0,
data: abi.encodeCall(IERC20.transfer, (recipient, 5 ether + 100000))
});

vm.startPrank(address(entryPoint));
assertEq(plugin.limits(0, address(erc20), address(acct)), 10 ether);
acct.executeUserOp(_getPackedUO(abi.encodeCall(IStandardExecutor.executeBatch, (calls))), bytes32(0));
assertEq(plugin.limits(0, address(erc20), address(acct)), 10 ether - 6 ether - 100001);
}

function test_userOp_executeBatch_approveAndTransferLimit() public {
Call[] memory calls = new Call[](3);
calls[0] =
Call({target: address(erc20), value: 0, data: abi.encodeCall(IERC20.approve, (recipient, 1 wei))});
calls[1] =
Call({target: address(erc20), value: 0, data: abi.encodeCall(IERC20.transfer, (recipient, 1 ether))});
calls[2] = Call({
target: address(erc20),
value: 0,
data: abi.encodeCall(IERC20.approve, (recipient, 5 ether + 100000))
});

vm.startPrank(address(entryPoint));
assertEq(plugin.limits(0, address(erc20), address(acct)), 10 ether);
acct.executeUserOp(_getPackedUO(abi.encodeCall(IStandardExecutor.executeBatch, (calls))), bytes32(0));
assertEq(plugin.limits(0, address(erc20), address(acct)), 10 ether - 6 ether - 100001);
}

function test_userOp_executeBatch_approveAndTransferLimit_fail() public {
Call[] memory calls = new Call[](3);
calls[0] =
Call({target: address(erc20), value: 0, data: abi.encodeCall(IERC20.approve, (recipient, 1 wei))});
calls[1] =
Call({target: address(erc20), value: 0, data: abi.encodeCall(IERC20.transfer, (recipient, 1 ether))});
calls[2] = Call({
target: address(erc20),
value: 0,
data: abi.encodeCall(IERC20.approve, (recipient, 9 ether + 100000))
});

vm.startPrank(address(entryPoint));
assertEq(plugin.limits(0, address(erc20), address(acct)), 10 ether);
PackedUserOperation[] memory uos = new PackedUserOperation[](1);
uos[0] = _getPackedUO(abi.encodeCall(IStandardExecutor.executeBatch, (calls)));
entryPoint.handleOps(uos, bundler);
// no spend consumed
assertEq(plugin.limits(0, address(erc20), address(acct)), 10 ether);
}

function test_runtime_executeLimit() public {
assertEq(plugin.limits(0, address(erc20), address(acct)), 10 ether);
acct.executeWithAuthorization(
_getExecuteWithSpend(5 ether),
_encodeSignature(FunctionReferenceLib.pack(address(validationPlugin), 0), 1, "")
);
assertEq(plugin.limits(0, address(erc20), address(acct)), 5 ether);
}

function test_runtime_executeBatchLimit() public {
Call[] memory calls = new Call[](3);
calls[0] =
Call({target: address(erc20), value: 0, data: abi.encodeCall(IERC20.approve, (recipient, 1 wei))});
calls[1] =
Call({target: address(erc20), value: 0, data: abi.encodeCall(IERC20.transfer, (recipient, 1 ether))});
calls[2] = Call({
target: address(erc20),
value: 0,
data: abi.encodeCall(IERC20.approve, (recipient, 5 ether + 100000))
});

assertEq(plugin.limits(0, address(erc20), address(acct)), 10 ether);
acct.executeWithAuthorization(
abi.encodeCall(IStandardExecutor.executeBatch, (calls)),
_encodeSignature(FunctionReferenceLib.pack(address(validationPlugin), 0), 1, "")
);
assertEq(plugin.limits(0, address(erc20), address(acct)), 10 ether - 6 ether - 100001);
}
}
Loading