Skip to content

Commit

Permalink
[5/n permissions] feat: add erc20 token spend limit plugin (#80)
Browse files Browse the repository at this point in the history
Co-authored-by: fangting-alchemy <[email protected]>
  • Loading branch information
howydev and fangting-alchemy authored Jul 16, 2024
1 parent e3b9d2a commit 3be43b7
Show file tree
Hide file tree
Showing 6 changed files with 355 additions and 0 deletions.
3 changes: 3 additions & 0 deletions .gitmodules
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,6 @@
[submodule "lib/forge-std"]
path = lib/forge-std
url = https://github.com/foundry-rs/forge-std
[submodule "lib/modular-account-libs"]
path = lib/modular-account-libs
url = https://github.com/erc6900/modular-account-libs
1 change: 1 addition & 0 deletions lib/modular-account-libs
Submodule modular-account-libs added at 5d9d0e
1 change: 1 addition & 0 deletions remappings.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@ ds-test/=lib/forge-std/lib/ds-test/src/
forge-std/=lib/forge-std/src/
@eth-infinitism/account-abstraction/=lib/account-abstraction/contracts/
@openzeppelin/=lib/openzeppelin-contracts/
@modular-account-libs/=lib/modular-account-libs/src/
154 changes: 154 additions & 0 deletions src/plugins/ERC20TokenLimitPlugin.sol
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
// SPDX-License-Identifier: GPL-3.0
pragma solidity ^0.8.25;

import {PackedUserOperation} from "@eth-infinitism/account-abstraction/interfaces/PackedUserOperation.sol";
import {UserOperationLib} from "@eth-infinitism/account-abstraction/core/UserOperationLib.sol";
import {EnumerableSet} from "@openzeppelin/contracts/utils/structs/EnumerableSet.sol";
import {IERC20} from "@openzeppelin/contracts/token/ERC20/IERC20.sol";
import {
SetValue,
AssociatedLinkedListSet,
AssociatedLinkedListSetLib
} from "@modular-account-libs/libraries/AssociatedLinkedListSetLib.sol";

import {PluginManifest, PluginMetadata} from "../interfaces/IPlugin.sol";
import {IStandardExecutor, Call} from "../interfaces/IStandardExecutor.sol";
import {IPlugin} from "../interfaces/IPlugin.sol";
import {IExecutionHook} from "../interfaces/IExecutionHook.sol";
import {BasePlugin, IERC165} from "./BasePlugin.sol";

/// @title ERC20 Token Limit Plugin
/// @author ERC-6900 Authors
/// @notice This plugin supports an ERC20 token spend limit. This should be combined with a contract whitelist
/// plugin to make sure that token transfers not tracked by the plugin don't happen.
/// Note: this plugin is opinionated on what selectors can be called for token contracts to guard against weird
/// edge cases like DAI. You wouldn't be able to use uni v2 pairs directly as the pair contract is also the LP
/// token contract
contract ERC20TokenLimitPlugin is BasePlugin, IExecutionHook {
using UserOperationLib for PackedUserOperation;
using EnumerableSet for EnumerableSet.AddressSet;
using AssociatedLinkedListSetLib for AssociatedLinkedListSet;

struct ERC20SpendLimit {
address token;
uint256[] limits;
}

string internal constant _NAME = "ERC20 Token Limit Plugin";
string internal constant _VERSION = "1.0.0";
string internal constant _AUTHOR = "ERC-6900 Authors";

mapping(uint8 functionId => mapping(address token => mapping(address account => uint256 limit))) public limits;
AssociatedLinkedListSet internal _tokenList;

error ExceededTokenLimit();
error ExceededNumberOfEntities();
error SelectorNotAllowed();

function updateLimits(uint8 functionId, address token, uint256 newLimit) external {
_tokenList.tryAdd(msg.sender, SetValue.wrap(bytes30(bytes20(token))));
limits[functionId][token][msg.sender] = newLimit;
}

/// @inheritdoc IExecutionHook
function preExecutionHook(uint8 functionId, address, uint256, bytes calldata data)
external
override
returns (bytes memory)
{
(bytes4 selector, bytes memory callData) = _getSelectorAndCalldata(data);

if (selector == IStandardExecutor.execute.selector) {
(address token,, bytes memory innerCalldata) = abi.decode(callData, (address, uint256, bytes));
if (_tokenList.contains(msg.sender, SetValue.wrap(bytes30(bytes20(token))))) {
_decrementLimit(functionId, token, innerCalldata);
}
} else if (selector == IStandardExecutor.executeBatch.selector) {
Call[] memory calls = abi.decode(callData, (Call[]));
for (uint256 i = 0; i < calls.length; i++) {
if (_tokenList.contains(msg.sender, SetValue.wrap(bytes30(bytes20(calls[i].target))))) {
_decrementLimit(functionId, calls[i].target, calls[i].data);
}
}
}

return "";
}

/// @inheritdoc IPlugin
function onInstall(bytes calldata data) external override {
(uint8 startFunctionId, ERC20SpendLimit[] memory spendLimits) =
abi.decode(data, (uint8, ERC20SpendLimit[]));

if (startFunctionId + spendLimits.length > type(uint8).max) {
revert ExceededNumberOfEntities();
}

for (uint8 i = 0; i < spendLimits.length; i++) {
_tokenList.tryAdd(msg.sender, SetValue.wrap(bytes30(bytes20(spendLimits[i].token))));
for (uint256 j = 0; j < spendLimits[i].limits.length; j++) {
limits[i + startFunctionId][spendLimits[i].token][msg.sender] = spendLimits[i].limits[j];
}
}
}

/// @inheritdoc IPlugin
function onUninstall(bytes calldata data) external override {
(address token, uint8 functionId) = abi.decode(data, (address, uint8));
delete limits[functionId][token][msg.sender];
}

function getTokensForAccount(address account) external view returns (address[] memory tokens) {
SetValue[] memory set = _tokenList.getAll(account);
tokens = new address[](set.length);
for (uint256 i = 0; i < tokens.length; i++) {
tokens[i] = address(bytes20(bytes32(SetValue.unwrap(set[i]))));
}
return tokens;
}

/// @inheritdoc IExecutionHook
function postExecutionHook(uint8, bytes calldata) external pure override {
revert NotImplemented();
}

/// @inheritdoc IPlugin
// solhint-disable-next-line no-empty-blocks
function pluginManifest() external pure override returns (PluginManifest memory) {}

/// @inheritdoc IPlugin
function pluginMetadata() external pure virtual override returns (PluginMetadata memory) {
PluginMetadata memory metadata;
metadata.name = _NAME;
metadata.version = _VERSION;
metadata.author = _AUTHOR;

metadata.permissionRequest = new string[](1);
metadata.permissionRequest[0] = "erc20-token-limit";
return metadata;
}

/// @inheritdoc BasePlugin
function supportsInterface(bytes4 interfaceId) public view override(BasePlugin, IERC165) returns (bool) {
return super.supportsInterface(interfaceId);
}

function _decrementLimit(uint8 functionId, address token, bytes memory innerCalldata) internal {
bytes4 selector;
uint256 spend;
assembly {
selector := mload(add(innerCalldata, 32)) // 0:32 is arr len, 32:36 is selector
spend := mload(add(innerCalldata, 68)) // 36:68 is recipient, 68:100 is spend
}
if (selector == IERC20.transfer.selector || selector == IERC20.approve.selector) {
uint256 limit = limits[functionId][token][msg.sender];
if (spend > limit) {
revert ExceededTokenLimit();
}
// solhint-disable-next-line reentrancy
limits[functionId][token][msg.sender] = limit - spend;
} else {
revert SelectorNotAllowed();
}
}
}
12 changes: 12 additions & 0 deletions test/mocks/MockERC20.sol
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
// SPDX-License-Identifier: UNLICENSED
pragma solidity ^0.8.19;

import {ERC20} from "@openzeppelin/contracts/token/ERC20/ERC20.sol";

contract MockERC20 is ERC20 {
constructor() ERC20("MockERC20", "MERC") {}

function mint(address to, uint256 amount) external {
_mint(to, amount);
}
}
184 changes: 184 additions & 0 deletions test/plugin/ERC20TokenLimitPlugin.t.sol
Original file line number Diff line number Diff line change
@@ -0,0 +1,184 @@
// SPDX-License-Identifier: UNLICENSED
pragma solidity ^0.8.19;

import {PackedUserOperation} from "@eth-infinitism/account-abstraction/interfaces/PackedUserOperation.sol";
import {MockERC20} from "../mocks/MockERC20.sol";
import {IERC20} from "@openzeppelin/contracts/token/ERC20/IERC20.sol";

import {UpgradeableModularAccount} from "../../src/account/UpgradeableModularAccount.sol";
import {FunctionReference} from "../../src/helpers/FunctionReferenceLib.sol";
import {ERC20TokenLimitPlugin} from "../../src/plugins/ERC20TokenLimitPlugin.sol";
import {MockPlugin} from "../mocks/MockPlugin.sol";
import {ExecutionHook} from "../../src/interfaces/IAccountLoupe.sol";
import {FunctionReferenceLib} from "../../src/helpers/FunctionReferenceLib.sol";
import {IStandardExecutor, Call} from "../../src/interfaces/IStandardExecutor.sol";
import {PluginManifest} from "../../src/interfaces/IPlugin.sol";
import {ValidationConfigLib} from "../../src/helpers/ValidationConfigLib.sol";

import {MSCAFactoryFixture} from "../mocks/MSCAFactoryFixture.sol";
import {AccountTestBase} from "../utils/AccountTestBase.sol";

contract ERC20TokenLimitPluginTest is AccountTestBase {
address public recipient = address(1);
MockERC20 public erc20;
address payable public bundler = payable(address(2));
PluginManifest internal _m;
MockPlugin public validationPlugin = new MockPlugin(_m);
FunctionReference public validationFunction;

UpgradeableModularAccount public acct;
ERC20TokenLimitPlugin public plugin = new ERC20TokenLimitPlugin();
uint256 public spendLimit = 10 ether;

function setUp() public {
// Set up a validator with hooks from the erc20 spend limit plugin attached
MSCAFactoryFixture factory = new MSCAFactoryFixture(entryPoint, _deploySingleOwnerPlugin());

acct = factory.createAccount(address(this), 0);

erc20 = new MockERC20();
erc20.mint(address(acct), 10 ether);

ExecutionHook[] memory permissionHooks = new ExecutionHook[](1);
permissionHooks[0] = ExecutionHook({
hookFunction: FunctionReferenceLib.pack(address(plugin), 0),
isPreHook: true,
isPostHook: false
});

// arr idx 0 => functionId of 0 has that spend
uint256[] memory limits = new uint256[](1);
limits[0] = spendLimit;

ERC20TokenLimitPlugin.ERC20SpendLimit[] memory limit = new ERC20TokenLimitPlugin.ERC20SpendLimit[](1);
limit[0] = ERC20TokenLimitPlugin.ERC20SpendLimit({token: address(erc20), limits: limits});

bytes[] memory permissionInitDatas = new bytes[](1);
permissionInitDatas[0] = abi.encode(uint8(0), limit);

vm.prank(address(acct));
acct.installValidation(
ValidationConfigLib.pack(address(validationPlugin), 0, true, true),
new bytes4[](0),
new bytes(0),
new bytes(0),
abi.encode(permissionHooks, permissionInitDatas)
);

validationFunction = FunctionReferenceLib.pack(address(validationPlugin), 0);
}

function _getPackedUO(bytes memory callData) internal view returns (PackedUserOperation memory uo) {
uo = PackedUserOperation({
sender: address(acct),
nonce: 0,
initCode: "",
callData: abi.encodePacked(UpgradeableModularAccount.executeUserOp.selector, callData),
accountGasLimits: bytes32(bytes16(uint128(200000))) | bytes32(uint256(200000)),
preVerificationGas: 200000,
gasFees: bytes32(uint256(uint128(0))),
paymasterAndData: "",
signature: _encodeSignature(FunctionReferenceLib.pack(address(validationPlugin), 0), 1, "")
});
}

function _getExecuteWithSpend(uint256 value) internal view returns (bytes memory) {
return abi.encodeCall(
UpgradeableModularAccount.execute,
(address(erc20), 0, abi.encodeCall(IERC20.transfer, (recipient, value)))
);
}

function test_userOp_executeLimit() public {
vm.startPrank(address(entryPoint));
assertEq(plugin.limits(0, address(erc20), address(acct)), 10 ether);
acct.executeUserOp(_getPackedUO(_getExecuteWithSpend(5 ether)), bytes32(0));
assertEq(plugin.limits(0, address(erc20), address(acct)), 5 ether);
}

function test_userOp_executeBatchLimit() public {
Call[] memory calls = new Call[](3);
calls[0] =
Call({target: address(erc20), value: 0, data: abi.encodeCall(IERC20.transfer, (recipient, 1 wei))});
calls[1] =
Call({target: address(erc20), value: 0, data: abi.encodeCall(IERC20.transfer, (recipient, 1 ether))});
calls[2] = Call({
target: address(erc20),
value: 0,
data: abi.encodeCall(IERC20.transfer, (recipient, 5 ether + 100000))
});

vm.startPrank(address(entryPoint));
assertEq(plugin.limits(0, address(erc20), address(acct)), 10 ether);
acct.executeUserOp(_getPackedUO(abi.encodeCall(IStandardExecutor.executeBatch, (calls))), bytes32(0));
assertEq(plugin.limits(0, address(erc20), address(acct)), 10 ether - 6 ether - 100001);
}

function test_userOp_executeBatch_approveAndTransferLimit() public {
Call[] memory calls = new Call[](3);
calls[0] =
Call({target: address(erc20), value: 0, data: abi.encodeCall(IERC20.approve, (recipient, 1 wei))});
calls[1] =
Call({target: address(erc20), value: 0, data: abi.encodeCall(IERC20.transfer, (recipient, 1 ether))});
calls[2] = Call({
target: address(erc20),
value: 0,
data: abi.encodeCall(IERC20.approve, (recipient, 5 ether + 100000))
});

vm.startPrank(address(entryPoint));
assertEq(plugin.limits(0, address(erc20), address(acct)), 10 ether);
acct.executeUserOp(_getPackedUO(abi.encodeCall(IStandardExecutor.executeBatch, (calls))), bytes32(0));
assertEq(plugin.limits(0, address(erc20), address(acct)), 10 ether - 6 ether - 100001);
}

function test_userOp_executeBatch_approveAndTransferLimit_fail() public {
Call[] memory calls = new Call[](3);
calls[0] =
Call({target: address(erc20), value: 0, data: abi.encodeCall(IERC20.approve, (recipient, 1 wei))});
calls[1] =
Call({target: address(erc20), value: 0, data: abi.encodeCall(IERC20.transfer, (recipient, 1 ether))});
calls[2] = Call({
target: address(erc20),
value: 0,
data: abi.encodeCall(IERC20.approve, (recipient, 9 ether + 100000))
});

vm.startPrank(address(entryPoint));
assertEq(plugin.limits(0, address(erc20), address(acct)), 10 ether);
PackedUserOperation[] memory uos = new PackedUserOperation[](1);
uos[0] = _getPackedUO(abi.encodeCall(IStandardExecutor.executeBatch, (calls)));
entryPoint.handleOps(uos, bundler);
// no spend consumed
assertEq(plugin.limits(0, address(erc20), address(acct)), 10 ether);
}

function test_runtime_executeLimit() public {
assertEq(plugin.limits(0, address(erc20), address(acct)), 10 ether);
acct.executeWithAuthorization(
_getExecuteWithSpend(5 ether),
_encodeSignature(FunctionReferenceLib.pack(address(validationPlugin), 0), 1, "")
);
assertEq(plugin.limits(0, address(erc20), address(acct)), 5 ether);
}

function test_runtime_executeBatchLimit() public {
Call[] memory calls = new Call[](3);
calls[0] =
Call({target: address(erc20), value: 0, data: abi.encodeCall(IERC20.approve, (recipient, 1 wei))});
calls[1] =
Call({target: address(erc20), value: 0, data: abi.encodeCall(IERC20.transfer, (recipient, 1 ether))});
calls[2] = Call({
target: address(erc20),
value: 0,
data: abi.encodeCall(IERC20.approve, (recipient, 5 ether + 100000))
});

assertEq(plugin.limits(0, address(erc20), address(acct)), 10 ether);
acct.executeWithAuthorization(
abi.encodeCall(IStandardExecutor.executeBatch, (calls)),
_encodeSignature(FunctionReferenceLib.pack(address(validationPlugin), 0), 1, "")
);
assertEq(plugin.limits(0, address(erc20), address(acct)), 10 ether - 6 ether - 100001);
}
}

0 comments on commit 3be43b7

Please sign in to comment.