Skip to content

Commit

Permalink
chore: add exec batch and test
Browse files Browse the repository at this point in the history
  • Loading branch information
howydev committed Jun 12, 2024
1 parent 0c46196 commit afe6455
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 17 deletions.
19 changes: 16 additions & 3 deletions src/plugins/NativeTokenLimitPlugin.sol
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import {UserOperationLib} from "@eth-infinitism/account-abstraction/core/UserOpe
import {EnumerableSet} from "@openzeppelin/contracts/utils/structs/EnumerableSet.sol";

import {PluginManifest, PluginMetadata} from "../interfaces/IPlugin.sol";
import {IStandardExecutor} from "../interfaces/IStandardExecutor.sol";
import {IStandardExecutor, Call} from "../interfaces/IStandardExecutor.sol";
import {IPlugin} from "../interfaces/IPlugin.sol";
import {IExecutionHook} from "../interfaces/IExecutionHook.sol";
import {IValidationHook} from "../interfaces/IValidationHook.sol";
Expand Down Expand Up @@ -45,6 +45,19 @@ contract NativeTokenLimitPlugin is BasePlugin, IExecutionHook, IValidationHook {
if (execSelector == IStandardExecutor.execute.selector) {
// Get value being sent
uint256 value = uint256(bytes32((_copyBytes(40, 32, uo.callData))));
uint256 limit = limits[msg.sender][functionId];
if (value > limit) {
revert ExceededNativeTokenLimit();
}
limits[msg.sender][functionId] = limit - value;
} else if (execSelector == IStandardExecutor.executeBatch.selector) {
// Get value being sent
uint256 value;
Call[] memory calls = abi.decode(_copyBytes(8, uo.callData.length - 8, uo.callData), (Call[]));
for (uint256 i = 0; i < calls.length; i++) {
value += calls[i].value;
}

uint256 limit = limits[msg.sender][functionId];
if (value > limit) {
revert ExceededNativeTokenLimit();
Expand Down Expand Up @@ -108,15 +121,15 @@ contract NativeTokenLimitPlugin is BasePlugin, IExecutionHook, IValidationHook {

function _copyBytes(uint256 offset, uint256 length, bytes memory data)
internal
view
pure
returns (bytes memory out)
{
assembly {
out := mload(0x40)
mstore(0x40, add(out, add(length, 0x20)))
mstore(out, length)
for { let i := 0 } lt(i, length) { i := add(i, 32) } {
let word := mload(add(add(add(data, offset), 0x20), mul(i, 0x20)))
let word := mload(add(add(add(data, offset), 0x20), i))
mstore(add(add(out, 0x20), i), word)
}
}
Expand Down
59 changes: 45 additions & 14 deletions test/plugin/NativeTokenLimitPlugin.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import {NativeTokenLimitPlugin} from "../../src/plugins/NativeTokenLimitPlugin.s
import {MockUserOpValidationPlugin} from "../mocks/plugins/ValidationPluginMocks.sol";
import {ExecutionHook} from "../../src/interfaces/IAccountLoupe.sol";
import {FunctionReferenceLib} from "../../src/helpers/FunctionReferenceLib.sol";
import {IStandardExecutor, Call} from "../../src/interfaces/IStandardExecutor.sol";

import {MSCAFactoryFixture} from "../mocks/MSCAFactoryFixture.sol";
import {OptimizedTest} from "../utils/OptimizedTest.sol";
Expand All @@ -27,7 +28,7 @@ contract NativeTokenLimitPluginTest is OptimizedTest {

function setUp() public {
// Set up a validator with hooks from the gas spend limit plugin attached

MSCAFactoryFixture factory = new MSCAFactoryFixture(entryPoint, _deploySingleOwnerPlugin());

acct = factory.createAccount(address(this), 0);
Expand Down Expand Up @@ -68,7 +69,11 @@ contract NativeTokenLimitPluginTest is OptimizedTest {
validationFunction = FunctionReferenceLib.pack(address(validationPlugin), 0);
}

function _getPackedUO(uint256 gas1, uint256 gas2, uint256 gas3, uint256 gasPrice, uint256 sendAmount)
function _getExecuteWithValue(uint256 value) internal view returns (bytes memory) {
return abi.encodeCall(UpgradeableModularAccount.execute, (recipient, value, ""));
}

function _getPackedUO(uint256 gas1, uint256 gas2, uint256 gas3, uint256 gasPrice, bytes memory callData)
internal
view
returns (PackedUserOperation memory uo)
Expand All @@ -77,10 +82,7 @@ contract NativeTokenLimitPluginTest is OptimizedTest {
sender: address(acct),
nonce: 0,
initCode: "",
callData: abi.encodePacked(
UpgradeableModularAccount.executeUserOp.selector,
abi.encodeCall(UpgradeableModularAccount.execute, (recipient, sendAmount, ""))
),
callData: abi.encodePacked(UpgradeableModularAccount.executeUserOp.selector, callData),
accountGasLimits: bytes32(bytes16(uint128(gas1))) | bytes32(uint256(gas2)),
preVerificationGas: gas3,
gasFees: bytes32(uint256(uint128(gasPrice))),
Expand All @@ -94,23 +96,25 @@ contract NativeTokenLimitPluginTest is OptimizedTest {

// uses 10e - 200000 of gas
assertEq(plugin.limits(address(acct), 0), 10 ether);
uint256 result = acct.validateUserOp(_getPackedUO(100000, 100000, 10 ether - 400000, 1, 0), bytes32(0), 0);
uint256 result = acct.validateUserOp(
_getPackedUO(100000, 100000, 10 ether - 400000, 1, _getExecuteWithValue(0)), bytes32(0), 0
);
assertEq(plugin.limits(address(acct), 0), 200000);

uint256 expected = uint256(type(uint48).max) << 160;
assertEq(result, expected);

// uses 200k + 1 wei of gas
vm.expectRevert(NativeTokenLimitPlugin.ExceededNativeTokenLimit.selector);
result = acct.validateUserOp(_getPackedUO(100000, 100000, 1, 1, 0), bytes32(0), 0);
result = acct.validateUserOp(_getPackedUO(100000, 100000, 1, 1, _getExecuteWithValue(0)), bytes32(0), 0);
}

function test_executeGasLimit() public {
vm.startPrank(address(entryPoint));

// uses 5e of native tokens
assertEq(plugin.limits(address(acct), 0), 10 ether);
acct.executeUserOp(_getPackedUO(0, 0, 0, 0, 5 ether), bytes32(0));
acct.executeUserOp(_getPackedUO(0, 0, 0, 0, _getExecuteWithValue(5 ether)), bytes32(0));
assertEq(plugin.limits(address(acct), 0), 5 ether);

// uses 5e + 1wei of native tokens
Expand All @@ -124,27 +128,54 @@ contract NativeTokenLimitPluginTest is OptimizedTest {
)
)
);
acct.executeUserOp(_getPackedUO(0, 0, 0, 0, 5 ether + 1), bytes32(0));
acct.executeUserOp(_getPackedUO(0, 0, 0, 0, _getExecuteWithValue(5 ether + 1)), bytes32(0));
}

function test_combinedGasLimit_success() public {
function test_executeBatchGasLimit() public {
Call[] memory calls = new Call[](3);
calls[0] = Call({target: recipient, value: 1, data: ""});
calls[1] = Call({target: recipient, value: 1 ether, data: ""});
calls[2] = Call({target: recipient, value: 5 ether + 100000, data: ""});

vm.startPrank(address(entryPoint));
assertEq(plugin.limits(address(acct), 0), 10 ether);
acct.executeUserOp(
_getPackedUO(0, 0, 0, 0, abi.encodeCall(IStandardExecutor.executeBatch, (calls))), bytes32(0)
);
assertEq(plugin.limits(address(acct), 0), 10 ether - 6 ether - 100001);
assertEq(recipient.balance, 6 ether + 100001);
}

function test_combinedExecGasLimit_success() public {
assertEq(plugin.limits(address(acct), 0), 10 ether);
PackedUserOperation[] memory uos = new PackedUserOperation[](1);
uos[0] = _getPackedUO(100000, 100000, 100000, 1, 5 ether);
uos[0] = _getPackedUO(100000, 100000, 100000, 1, _getExecuteWithValue(5 ether));
entryPoint.handleOps(uos, bundler);

assertEq(plugin.limits(address(acct), 0), 5 ether - 300000);
assertEq(recipient.balance, 5 ether);
}

function test_combinedGasLimit_failExec() public {
function test_combinedExecBatchGasLimit_success() public {
Call[] memory calls = new Call[](3);
calls[0] = Call({target: recipient, value: 1, data: ""});
calls[1] = Call({target: recipient, value: 1 ether, data: ""});
calls[2] = Call({target: recipient, value: 5 ether + 100000, data: ""});

vm.startPrank(address(entryPoint));
assertEq(plugin.limits(address(acct), 0), 10 ether);
PackedUserOperation[] memory uos = new PackedUserOperation[](1);
uos[0] = _getPackedUO(100000, 100000, 100000, 1, abi.encodeCall(IStandardExecutor.executeBatch, (calls)));
entryPoint.handleOps(uos, bundler);

assertEq(plugin.limits(address(acct), 0), 10 ether - 6 ether - 400001);
assertEq(recipient.balance, 6 ether + 100001);
}

function test_combinedExecGasLimit_failExec() public {
assertEq(plugin.limits(address(acct), 0), 10 ether);
PackedUserOperation[] memory uos = new PackedUserOperation[](1);
uos[0] = _getPackedUO(100000, 100000, 100000, 1, 10 ether);
uos[0] = _getPackedUO(100000, 100000, 100000, 1, _getExecuteWithValue(10 ether));
entryPoint.handleOps(uos, bundler);

assertEq(plugin.limits(address(acct), 0), 10 ether - 300000);
Expand Down

0 comments on commit afe6455

Please sign in to comment.