From 4c147bfca9328f18d0e2a7d47f277eaf277cb22d Mon Sep 17 00:00:00 2001 From: adam Date: Tue, 11 Jun 2024 14:32:35 -0400 Subject: [PATCH] Add Allowlist sample plugin, refactor test base --- src/account/UpgradeableModularAccount.sol | 13 +- .../permissionhooks/AllowlistPlugin.sol | 142 ++++++++ test/account/AccountLoupe.t.sol | 7 - test/account/DefaultValidationTest.t.sol | 11 - test/account/MultiValidation.t.sol | 3 - test/account/PerHookData.t.sol | 65 ++-- test/account/UpgradeableModularAccount.t.sol | 11 +- .../mocks/DefaultValidationFactoryFixture.sol | 7 +- test/samples/AllowlistPlugin.t.sol | 318 ++++++++++++++++++ test/utils/AccountTestBase.sol | 134 ++++++++ test/utils/CustomValidationTestBase.sol | 44 +++ 11 files changed, 679 insertions(+), 76 deletions(-) create mode 100644 src/samples/permissionhooks/AllowlistPlugin.sol create mode 100644 test/samples/AllowlistPlugin.t.sol create mode 100644 test/utils/CustomValidationTestBase.sol diff --git a/src/account/UpgradeableModularAccount.sol b/src/account/UpgradeableModularAccount.sol index 3e1917aa..a09ac8c7 100644 --- a/src/account/UpgradeableModularAccount.sol +++ b/src/account/UpgradeableModularAccount.sol @@ -240,11 +240,14 @@ contract UpgradeableModularAccount is /// with user install configs. /// @dev This function is only callable once, and only by the EntryPoint. - function initializeDefaultValidation(FunctionReference validationFunction, bytes calldata installData) - external - initializer - { - _installValidation(validationFunction, true, new bytes4[](0), installData, bytes("")); + function initializeWithValidation( + FunctionReference validationFunction, + bool shared, + bytes4[] memory selectors, + bytes calldata installData, + bytes calldata preValidationHooks + ) external initializer { + _installValidation(validationFunction, shared, selectors, installData, preValidationHooks); emit ModularAccountInitialized(_ENTRY_POINT); } diff --git a/src/samples/permissionhooks/AllowlistPlugin.sol b/src/samples/permissionhooks/AllowlistPlugin.sol new file mode 100644 index 00000000..209d8370 --- /dev/null +++ b/src/samples/permissionhooks/AllowlistPlugin.sol @@ -0,0 +1,142 @@ +// SPDX-License-Identifier: UNLICENSED +pragma solidity ^0.8.25; + +import {PackedUserOperation} from "@eth-infinitism/account-abstraction/interfaces/PackedUserOperation.sol"; + +import {PluginMetadata, PluginManifest} from "../../interfaces/IPlugin.sol"; +import {IValidationHook} from "../../interfaces/IValidationHook.sol"; +import {IStandardExecutor, Call} from "../../interfaces/IStandardExecutor.sol"; +import {BasePlugin} from "../../plugins/BasePlugin.sol"; + +contract AllowlistPlugin is IValidationHook, BasePlugin { + enum FunctionId { + PRE_VALIDATION_HOOK + } + + struct AllowlistInit { + address target; + bool hasSelectorAllowlist; + bytes4[] selectors; + } + + struct AllowlistEntry { + bool allowed; + bool hasSelectorAllowlist; + } + + mapping(address target => mapping(address account => AllowlistEntry)) public targetAllowlist; + mapping(address target => mapping(bytes4 selector => mapping(address account => bool))) public + selectorAllowlist; + + error TargetNotAllowed(); + error SelectorNotAllowed(); + error NoSelectorSpecified(); + + function onInstall(bytes calldata data) external override { + AllowlistInit[] memory init = abi.decode(data, (AllowlistInit[])); + + for (uint256 i = 0; i < init.length; i++) { + targetAllowlist[init[i].target][msg.sender] = AllowlistEntry(true, init[i].hasSelectorAllowlist); + + if (init[i].hasSelectorAllowlist) { + for (uint256 j = 0; j < init[i].selectors.length; j++) { + selectorAllowlist[init[i].target][init[i].selectors[j]][msg.sender] = true; + } + } + } + } + + function onUninstall(bytes calldata data) external override { + AllowlistInit[] memory init = abi.decode(data, (AllowlistInit[])); + + for (uint256 i = 0; i < init.length; i++) { + delete targetAllowlist[init[i].target][msg.sender]; + + if (init[i].hasSelectorAllowlist) { + for (uint256 j = 0; j < init[i].selectors.length; j++) { + delete selectorAllowlist[init[i].target][init[i].selectors[j]][msg.sender]; + } + } + } + } + + function setAllowlistTarget(address target, bool allowed, bool hasSelectorAllowlist) external { + targetAllowlist[target][msg.sender] = AllowlistEntry(allowed, hasSelectorAllowlist); + } + + function setAllowlistSelector(address target, bytes4 selector, bool allowed) external { + selectorAllowlist[target][selector][msg.sender] = allowed; + } + + function preUserOpValidationHook(uint8 functionId, PackedUserOperation calldata userOp, bytes32) + external + view + override + returns (uint256) + { + if (functionId == uint8(FunctionId.PRE_VALIDATION_HOOK)) { + _checkAllowlistCalldata(userOp.callData); + return 0; + } + revert NotImplemented(); + } + + function preRuntimeValidationHook(uint8 functionId, address, uint256, bytes calldata data, bytes calldata) + external + view + override + { + if (functionId == uint8(FunctionId.PRE_VALIDATION_HOOK)) { + _checkAllowlistCalldata(data); + return; + } + + revert NotImplemented(); + } + + function pluginMetadata() external pure override returns (PluginMetadata memory) { + PluginMetadata memory metadata; + metadata.name = "Allowlist Plugin"; + metadata.version = "v0.0.1"; + metadata.author = "ERC-6900 Working Group"; + + return metadata; + } + + // solhint-disable-next-line no-empty-blocks + function pluginManifest() external pure override returns (PluginManifest memory) {} + + function _checkAllowlistCalldata(bytes calldata callData) internal view { + if (bytes4(callData[:4]) == IStandardExecutor.execute.selector) { + (address target,, bytes memory data) = abi.decode(callData[4:], (address, uint256, bytes)); + _checkCallPermission(msg.sender, target, data); + } else if (bytes4(callData[:4]) == IStandardExecutor.executeBatch.selector) { + Call[] memory calls = abi.decode(callData[4:], (Call[])); + + for (uint256 i = 0; i < calls.length; i++) { + _checkCallPermission(msg.sender, calls[i].target, calls[i].data); + } + } + } + + function _checkCallPermission(address account, address target, bytes memory data) internal view { + AllowlistEntry storage entry = targetAllowlist[target][account]; + (bool allowed, bool hasSelectorAllowlist) = (entry.allowed, entry.hasSelectorAllowlist); + + if (!allowed) { + revert TargetNotAllowed(); + } + + if (hasSelectorAllowlist) { + if (data.length < 4) { + revert NoSelectorSpecified(); + } + + bytes4 selector = bytes4(data); + + if (!selectorAllowlist[target][selector][account]) { + revert SelectorNotAllowed(); + } + } + } +} diff --git a/test/account/AccountLoupe.t.sol b/test/account/AccountLoupe.t.sol index fa92ab00..8d05c647 100644 --- a/test/account/AccountLoupe.t.sol +++ b/test/account/AccountLoupe.t.sol @@ -7,7 +7,6 @@ import {FunctionReference, FunctionReferenceLib} from "../../src/helpers/Functio import {ExecutionHook} from "../../src/interfaces/IAccountLoupe.sol"; import {IPluginManager} from "../../src/interfaces/IPluginManager.sol"; import {IStandardExecutor} from "../../src/interfaces/IStandardExecutor.sol"; -import {ISingleOwnerPlugin} from "../../src/plugins/owner/ISingleOwnerPlugin.sol"; import {ComprehensivePlugin} from "../mocks/plugins/ComprehensivePlugin.sol"; import {AccountTestBase} from "../utils/AccountTestBase.sol"; @@ -15,8 +14,6 @@ import {AccountTestBase} from "../utils/AccountTestBase.sol"; contract AccountLoupeTest is AccountTestBase { ComprehensivePlugin public comprehensivePlugin; - FunctionReference public ownerValidation; - event ReceivedCall(bytes msgData, uint256 msgValue); function setUp() public { @@ -28,10 +25,6 @@ contract AccountLoupeTest is AccountTestBase { vm.prank(address(entryPoint)); account1.installPlugin(address(comprehensivePlugin), manifestHash, "", new FunctionReference[](0)); - ownerValidation = FunctionReferenceLib.pack( - address(singleOwnerPlugin), uint8(ISingleOwnerPlugin.FunctionId.VALIDATION_OWNER) - ); - FunctionReference[] memory preValidationHooks = new FunctionReference[](2); preValidationHooks[0] = FunctionReferenceLib.pack( address(comprehensivePlugin), uint8(ComprehensivePlugin.FunctionId.PRE_VALIDATION_HOOK_1) diff --git a/test/account/DefaultValidationTest.t.sol b/test/account/DefaultValidationTest.t.sol index c2f118de..604de650 100644 --- a/test/account/DefaultValidationTest.t.sol +++ b/test/account/DefaultValidationTest.t.sol @@ -5,8 +5,6 @@ import {PackedUserOperation} from "@eth-infinitism/account-abstraction/interface import {MessageHashUtils} from "@openzeppelin/contracts/utils/cryptography/MessageHashUtils.sol"; import {UpgradeableModularAccount} from "../../src/account/UpgradeableModularAccount.sol"; -import {FunctionReference, FunctionReferenceLib} from "../../src/helpers/FunctionReferenceLib.sol"; -import {ISingleOwnerPlugin} from "../../src/plugins/owner/ISingleOwnerPlugin.sol"; import {AccountTestBase} from "../utils/AccountTestBase.sol"; import {DefaultValidationFactoryFixture} from "../mocks/DefaultValidationFactoryFixture.sol"; @@ -16,11 +14,6 @@ contract DefaultValidationTest is AccountTestBase { DefaultValidationFactoryFixture public defaultValidationFactoryFixture; - uint256 public constant CALL_GAS_LIMIT = 50000; - uint256 public constant VERIFICATION_GAS_LIMIT = 1200000; - - FunctionReference public ownerValidation; - address public ethRecipient; function setUp() public { @@ -32,10 +25,6 @@ contract DefaultValidationTest is AccountTestBase { ethRecipient = makeAddr("ethRecipient"); vm.deal(ethRecipient, 1 wei); - - ownerValidation = FunctionReferenceLib.pack( - address(singleOwnerPlugin), uint8(ISingleOwnerPlugin.FunctionId.VALIDATION_OWNER) - ); } function test_defaultValidation_userOp_simple() public { diff --git a/test/account/MultiValidation.t.sol b/test/account/MultiValidation.t.sol index e80d022c..78867f55 100644 --- a/test/account/MultiValidation.t.sol +++ b/test/account/MultiValidation.t.sol @@ -25,9 +25,6 @@ contract MultiValidationTest is AccountTestBase { address public owner2; uint256 public owner2Key; - uint256 public constant CALL_GAS_LIMIT = 50000; - uint256 public constant VERIFICATION_GAS_LIMIT = 1200000; - function setUp() public { validator2 = new SingleOwnerPlugin(); diff --git a/test/account/PerHookData.t.sol b/test/account/PerHookData.t.sol index a7542c0c..708f4fde 100644 --- a/test/account/PerHookData.t.sol +++ b/test/account/PerHookData.t.sol @@ -3,63 +3,28 @@ pragma solidity ^0.8.25; import {PackedUserOperation} from "@eth-infinitism/account-abstraction/interfaces/PackedUserOperation.sol"; import {IEntryPoint} from "@eth-infinitism/account-abstraction/interfaces/IEntryPoint.sol"; -import {ERC1967Proxy} from "@openzeppelin/contracts/proxy/ERC1967/ERC1967Proxy.sol"; import {MessageHashUtils} from "@openzeppelin/contracts/utils/cryptography/MessageHashUtils.sol"; import {UpgradeableModularAccount} from "../../src/account/UpgradeableModularAccount.sol"; -import {ISingleOwnerPlugin} from "../../src/plugins/owner/ISingleOwnerPlugin.sol"; import {FunctionReference, FunctionReferenceLib} from "../../src/helpers/FunctionReferenceLib.sol"; import {MockAccessControlHookPlugin} from "../mocks/plugins/MockAccessControlHookPlugin.sol"; import {Counter} from "../mocks/Counter.sol"; -import {AccountTestBase} from "../utils/AccountTestBase.sol"; +import {CustomValidationTestBase} from "../utils/CustomValidationTestBase.sol"; -contract PerHookDataTest is AccountTestBase { +contract PerHookDataTest is CustomValidationTestBase { using MessageHashUtils for bytes32; MockAccessControlHookPlugin internal accessControlHookPlugin; Counter internal counter; - FunctionReference internal ownerValidation; - - uint256 public constant CALL_GAS_LIMIT = 50000; - uint256 public constant VERIFICATION_GAS_LIMIT = 1200000; - function setUp() public { counter = new Counter(); accessControlHookPlugin = new MockAccessControlHookPlugin(); - // Write over `account1` with a new account proxy, with different initialization. - - address accountImplementation = address(factory.accountImplementation()); - - account1 = UpgradeableModularAccount(payable(new ERC1967Proxy(accountImplementation, ""))); - - ownerValidation = FunctionReferenceLib.pack( - address(singleOwnerPlugin), uint8(ISingleOwnerPlugin.FunctionId.VALIDATION_OWNER) - ); - - FunctionReference accessControlHook = FunctionReferenceLib.pack( - address(accessControlHookPlugin), uint8(MockAccessControlHookPlugin.FunctionId.PRE_VALIDATION_HOOK) - ); - - FunctionReference[] memory preValidationHooks = new FunctionReference[](1); - preValidationHooks[0] = accessControlHook; - - bytes[] memory preValidationHookData = new bytes[](1); - // Access control is restricted to only the counter - preValidationHookData[0] = abi.encode(counter); - - bytes memory packedPreValidationHooks = abi.encode(preValidationHooks, preValidationHookData); - - vm.prank(address(entryPoint)); - account1.installValidation( - ownerValidation, true, new bytes4[](0), abi.encode(owner1), packedPreValidationHooks - ); - - vm.deal(address(account1), 100 ether); + _customValidationSetup(); } function test_passAccessControl_userOp() public { @@ -310,4 +275,28 @@ contract PerHookDataTest is AccountTestBase { return (userOp, userOpHash); } + + // Test config + + function _initialValidationConfig() + internal + virtual + override + returns (FunctionReference, bool, bytes4[] memory, bytes memory, bytes memory) + { + FunctionReference accessControlHook = FunctionReferenceLib.pack( + address(accessControlHookPlugin), uint8(MockAccessControlHookPlugin.FunctionId.PRE_VALIDATION_HOOK) + ); + + FunctionReference[] memory preValidationHooks = new FunctionReference[](1); + preValidationHooks[0] = accessControlHook; + + bytes[] memory preValidationHookData = new bytes[](1); + // Access control is restricted to only the counter + preValidationHookData[0] = abi.encode(counter); + + bytes memory packedPreValidationHooks = abi.encode(preValidationHooks, preValidationHookData); + + return (ownerValidation, true, new bytes4[](0), abi.encode(owner1), packedPreValidationHooks); + } } diff --git a/test/account/UpgradeableModularAccount.t.sol b/test/account/UpgradeableModularAccount.t.sol index bea72dc6..8b7c4f59 100644 --- a/test/account/UpgradeableModularAccount.t.sol +++ b/test/account/UpgradeableModularAccount.t.sol @@ -10,7 +10,7 @@ import {IERC1271} from "@openzeppelin/contracts/interfaces/IERC1271.sol"; import {PluginManagerInternals} from "../../src/account/PluginManagerInternals.sol"; import {UpgradeableModularAccount} from "../../src/account/UpgradeableModularAccount.sol"; -import {FunctionReference, FunctionReferenceLib} from "../../src/helpers/FunctionReferenceLib.sol"; +import {FunctionReference} from "../../src/helpers/FunctionReferenceLib.sol"; import {IPlugin, PluginManifest} from "../../src/interfaces/IPlugin.sol"; import {IAccountLoupe} from "../../src/interfaces/IAccountLoupe.sol"; import {IPluginManager} from "../../src/interfaces/IPluginManager.sol"; @@ -39,11 +39,6 @@ contract UpgradeableModularAccountTest is AccountTestBase { Counter public counter; PluginManifest internal manifest; - FunctionReference public ownerValidation; - - uint256 public constant CALL_GAS_LIMIT = 50000; - uint256 public constant VERIFICATION_GAS_LIMIT = 1200000; - event PluginInstalled(address indexed plugin, bytes32 manifestHash, FunctionReference[] dependencies); event PluginUninstalled(address indexed plugin, bool indexed callbacksSucceeded); event ReceivedCall(bytes msgData, uint256 msgValue); @@ -61,10 +56,6 @@ contract UpgradeableModularAccountTest is AccountTestBase { vm.deal(ethRecipient, 1 wei); counter = new Counter(); counter.increment(); // amoritze away gas cost of zero->nonzero transition - - ownerValidation = FunctionReferenceLib.pack( - address(singleOwnerPlugin), uint8(ISingleOwnerPlugin.FunctionId.VALIDATION_OWNER) - ); } function test_deployAccount() public { diff --git a/test/mocks/DefaultValidationFactoryFixture.sol b/test/mocks/DefaultValidationFactoryFixture.sol index a4836ad8..54663a7c 100644 --- a/test/mocks/DefaultValidationFactoryFixture.sol +++ b/test/mocks/DefaultValidationFactoryFixture.sol @@ -55,11 +55,14 @@ contract DefaultValidationFactoryFixture is OptimizedTest { new ERC1967Proxy{salt: getSalt(owner, salt)}(address(accountImplementation), ""); // point proxy to actual implementation and init plugins - UpgradeableModularAccount(payable(addr)).initializeDefaultValidation( + UpgradeableModularAccount(payable(addr)).initializeWithValidation( FunctionReferenceLib.pack( address(singleOwnerPlugin), uint8(ISingleOwnerPlugin.FunctionId.VALIDATION_OWNER) ), - pluginInstallData + true, + new bytes4[](0), + pluginInstallData, + "" ); } diff --git a/test/samples/AllowlistPlugin.t.sol b/test/samples/AllowlistPlugin.t.sol new file mode 100644 index 00000000..d3564d87 --- /dev/null +++ b/test/samples/AllowlistPlugin.t.sol @@ -0,0 +1,318 @@ +// SPDX-License-Identifier: UNLICENSED +pragma solidity ^0.8.25; + +import {IEntryPoint} from "@eth-infinitism/account-abstraction/interfaces/IEntryPoint.sol"; + +import {Call} from "../../src/interfaces/IStandardExecutor.sol"; +import {FunctionReference, FunctionReferenceLib} from "../../src/helpers/FunctionReferenceLib.sol"; +import {UpgradeableModularAccount} from "../../src/account/UpgradeableModularAccount.sol"; +import {AllowlistPlugin} from "../../src/samples/permissionhooks/AllowlistPlugin.sol"; + +import {CustomValidationTestBase} from "../utils/CustomValidationTestBase.sol"; +import {Counter} from "../mocks/Counter.sol"; + +contract AllowlistPluginTest is CustomValidationTestBase { + AllowlistPlugin public allowlistPlugin; + + AllowlistPlugin.AllowlistInit[] internal allowlistInit; + + Counter[] public counters; + + function setUp() public { + allowlistPlugin = new AllowlistPlugin(); + + counters = new Counter[](10); + + for (uint256 i = 0; i < counters.length; i++) { + counters[i] = new Counter(); + } + + // Don't call `_customValidationSetup` here, as we want to test various configurations of install data. + } + + function testFuzz_allowlistHook_userOp_single(uint256 seed) public { + AllowlistPlugin.AllowlistInit[] memory inits; + (inits, seed) = _generateRandomizedAllowlistInit(seed); + + _copyInitToStorage(inits); + _customValidationSetup(); + + Call[] memory calls = new Call[](1); + (calls[0], seed) = _generateRandomCall(seed); + bytes memory expectedError = _getExpectedUserOpError(calls); + + _runExecUserOp(calls[0].target, calls[0].data, expectedError); + } + + function testFuzz_allowlistHook_userOp_batch(uint256 seed) public { + AllowlistPlugin.AllowlistInit[] memory inits; + (inits, seed) = _generateRandomizedAllowlistInit(seed); + + _copyInitToStorage(inits); + _customValidationSetup(); + + Call[] memory calls; + (calls, seed) = _generateRandomCalls(seed); + bytes memory expectedError = _getExpectedUserOpError(calls); + + _runExecBatchUserOp(calls, expectedError); + } + + function testFuzz_allowlistHook_runtime_single(uint256 seed) public { + AllowlistPlugin.AllowlistInit[] memory inits; + (inits, seed) = _generateRandomizedAllowlistInit(seed); + + _copyInitToStorage(inits); + _customValidationSetup(); + + Call[] memory calls = new Call[](1); + (calls[0], seed) = _generateRandomCall(seed); + bytes memory expectedError = _getExpectedRuntimeError(calls); + + if (keccak256(expectedError) == keccak256("emptyrevert")) { + _runtimeExecExpFail(calls[0].target, calls[0].data, ""); + } else { + _runtimeExec(calls[0].target, calls[0].data, expectedError); + } + } + + function testFuzz_allowlistHook_runtime_batch(uint256 seed) public { + AllowlistPlugin.AllowlistInit[] memory inits; + (inits, seed) = _generateRandomizedAllowlistInit(seed); + + _copyInitToStorage(inits); + _customValidationSetup(); + + Call[] memory calls; + (calls, seed) = _generateRandomCalls(seed); + bytes memory expectedError = _getExpectedRuntimeError(calls); + + if (keccak256(expectedError) == keccak256("emptyrevert")) { + _runtimeExecBatchExpFail(calls, ""); + } else { + _runtimeExecBatch(calls, expectedError); + } + } + + function _generateRandomCalls(uint256 seed) internal view returns (Call[] memory, uint256) { + uint256 length = seed % 10; + seed = _next(seed); + + Call[] memory calls = new Call[](length); + + for (uint256 i = 0; i < length; i++) { + (calls[i], seed) = _generateRandomCall(seed); + } + + return (calls, seed); + } + + function _generateRandomCall(uint256 seed) internal view returns (Call memory call, uint256 newSeed) { + // Half of the time, the target is a random counter, the other half, it's a random address. + bool isCounter = seed % 2 == 0; + seed = _next(seed); + + call.target = isCounter ? address(counters[seed % counters.length]) : address(uint160(uint256(seed))); + seed = _next(seed); + + bool validSelector = seed % 2 == 0; + seed = _next(seed); + + if (validSelector) { + uint256 selectorIndex = seed % 3; + seed = _next(seed); + + if (selectorIndex == 0) { + call.data = abi.encodeCall(Counter.setNumber, (seed % 100)); + } else if (selectorIndex == 1) { + call.data = abi.encodeCall(Counter.increment, ()); + } else { + call.data = abi.encodeWithSignature("number()"); + } + + seed = _next(seed); + } else { + call.data = abi.encodePacked(bytes4(uint32(uint256(seed)))); + seed = _next(seed); + } + + return (call, seed); + } + + function _getExpectedUserOpError(Call[] memory calls) internal view returns (bytes memory) { + for (uint256 i = 0; i < calls.length; i++) { + Call memory call = calls[i]; + + (bool allowed, bool hasSelectorAllowlist) = + allowlistPlugin.targetAllowlist(call.target, address(account1)); + if (allowed) { + if ( + hasSelectorAllowlist + && !allowlistPlugin.selectorAllowlist(call.target, bytes4(call.data), address(account1)) + ) { + return abi.encodeWithSelector( + IEntryPoint.FailedOpWithRevert.selector, + 0, + "AA23 reverted", + abi.encodeWithSelector(AllowlistPlugin.SelectorNotAllowed.selector) + ); + } + } else { + return abi.encodeWithSelector( + IEntryPoint.FailedOpWithRevert.selector, + 0, + "AA23 reverted", + abi.encodeWithSelector(AllowlistPlugin.TargetNotAllowed.selector) + ); + } + } + + return ""; + } + + function _getExpectedRuntimeError(Call[] memory calls) internal view returns (bytes memory) { + for (uint256 i = 0; i < calls.length; i++) { + Call memory call = calls[i]; + + (bool allowed, bool hasSelectorAllowlist) = + allowlistPlugin.targetAllowlist(call.target, address(account1)); + if (allowed) { + if ( + hasSelectorAllowlist + && !allowlistPlugin.selectorAllowlist(call.target, bytes4(call.data), address(account1)) + ) { + return abi.encodeWithSelector( + UpgradeableModularAccount.PreRuntimeValidationHookFailed.selector, + address(allowlistPlugin), + uint8(AllowlistPlugin.FunctionId.PRE_VALIDATION_HOOK), + abi.encodeWithSelector(AllowlistPlugin.SelectorNotAllowed.selector) + ); + } + } else { + return abi.encodeWithSelector( + UpgradeableModularAccount.PreRuntimeValidationHookFailed.selector, + address(allowlistPlugin), + uint8(AllowlistPlugin.FunctionId.PRE_VALIDATION_HOOK), + abi.encodeWithSelector(AllowlistPlugin.TargetNotAllowed.selector) + ); + } + } + + // At this point, we have returned any error that would come from the AllowlistPlugin. + // But, because this is in the runtime path, the Counter itself may throw if it is not a valid selector. + + for (uint256 i = 0; i < calls.length; i++) { + Call memory call = calls[i]; + bytes4 selector = bytes4(call.data); + + if ( + selector != Counter.setNumber.selector && selector != Counter.increment.selector + && selector != bytes4(abi.encodeWithSignature("number()")) + ) { + //todo: better define a way to handle empty reverts. + return "emptyrevert"; + } + } + + return ""; + } + + function _generateRandomizedAllowlistInit(uint256 seed) + internal + view + returns (AllowlistPlugin.AllowlistInit[] memory, uint256) + { + uint256 length = seed % 10; + seed = _next(seed); + + AllowlistPlugin.AllowlistInit[] memory init = new AllowlistPlugin.AllowlistInit[](length); + + for (uint256 i = 0; i < length; i++) { + // Half the time, the target is a random counter, the other half, it's a random address. + bool isCounter = seed % 2 == 0; + seed = _next(seed); + + address target = + isCounter ? address(counters[seed % counters.length]) : address(uint160(uint256(seed))); + + bool hasSelectorAllowlist = seed % 2 == 0; + seed = _next(seed); + + uint256 selectorLength = seed % 10; + seed = _next(seed); + + bytes4[] memory selectors = new bytes4[](selectorLength); + + for (uint256 j = 0; j < selectorLength; j++) { + // half of the time, the selector is a valid selector on counter, the other half it's a random + // selector + + bool isCounterSelector = seed % 2 == 0; + seed = _next(seed); + + if (isCounterSelector) { + uint256 selectorIndex = seed % 3; + seed = _next(seed); + + if (selectorIndex == 0) { + selectors[j] = Counter.setNumber.selector; + } else if (selectorIndex == 1) { + selectors[j] = Counter.increment.selector; + } else { + selectors[j] = bytes4(abi.encodeWithSignature("number()")); + } + } else { + selectors[j] = bytes4(uint32(uint256(seed))); + seed = _next(seed); + } + + selectors[j] = bytes4(uint32(uint256(keccak256(abi.encodePacked(seed, j))))); + seed = _next(seed); + } + + init[i] = AllowlistPlugin.AllowlistInit(target, hasSelectorAllowlist, selectors); + } + + return (init, seed); + } + + // todo: runtime paths + + // fuzz targets, fuzz target selectors. + + // Maybe pull out the helper function for running user ops and possibly expect a failure? + + function _next(uint256 seed) internal pure returns (uint256) { + return uint256(keccak256(abi.encodePacked(seed))); + } + + function _initialValidationConfig() + internal + virtual + override + returns (FunctionReference, bool, bytes4[] memory, bytes memory, bytes memory) + { + FunctionReference accessControlHook = FunctionReferenceLib.pack( + address(allowlistPlugin), uint8(AllowlistPlugin.FunctionId.PRE_VALIDATION_HOOK) + ); + + FunctionReference[] memory preValidationHooks = new FunctionReference[](1); + preValidationHooks[0] = accessControlHook; + + bytes[] memory preValidationHookData = new bytes[](1); + // Access control is restricted to only the counter + preValidationHookData[0] = abi.encode(allowlistInit); + + bytes memory packedPreValidationHooks = abi.encode(preValidationHooks, preValidationHookData); + + return (ownerValidation, true, new bytes4[](0), abi.encode(owner1), packedPreValidationHooks); + } + + // Unfortunately, this is a feature that solidity has only implemented in via-ir, so we need to do it manually + // to be able to run the tests in lite mode. + function _copyInitToStorage(AllowlistPlugin.AllowlistInit[] memory init) internal { + for (uint256 i = 0; i < init.length; i++) { + allowlistInit.push(init[i]); + } + } +} diff --git a/test/utils/AccountTestBase.sol b/test/utils/AccountTestBase.sol index 2d48c6b4..60a5d53f 100644 --- a/test/utils/AccountTestBase.sol +++ b/test/utils/AccountTestBase.sol @@ -2,8 +2,11 @@ pragma solidity ^0.8.19; import {EntryPoint} from "@eth-infinitism/account-abstraction/core/EntryPoint.sol"; +import {PackedUserOperation} from "@eth-infinitism/account-abstraction/interfaces/PackedUserOperation.sol"; +import {MessageHashUtils} from "@openzeppelin/contracts/utils/cryptography/MessageHashUtils.sol"; import {FunctionReference, FunctionReferenceLib} from "../../src/helpers/FunctionReferenceLib.sol"; +import {IStandardExecutor, Call} from "../../src/interfaces/IStandardExecutor.sol"; import {UpgradeableModularAccount} from "../../src/account/UpgradeableModularAccount.sol"; import {ISingleOwnerPlugin} from "../../src/plugins/owner/ISingleOwnerPlugin.sol"; import {SingleOwnerPlugin} from "../../src/plugins/owner/SingleOwnerPlugin.sol"; @@ -16,6 +19,7 @@ import {MSCAFactoryFixture} from "../mocks/MSCAFactoryFixture.sol"; /// SingleOwnerPlugin. abstract contract AccountTestBase is OptimizedTest { using FunctionReferenceLib for FunctionReference; + using MessageHashUtils for bytes32; EntryPoint public entryPoint; address payable public beneficiary; @@ -26,9 +30,14 @@ abstract contract AccountTestBase is OptimizedTest { uint256 public owner1Key; UpgradeableModularAccount public account1; + FunctionReference internal ownerValidation; + uint8 public constant SELECTOR_ASSOCIATED_VALIDATION = 0; uint8 public constant DEFAULT_VALIDATION = 1; + uint256 public constant CALL_GAS_LIMIT = 50000; + uint256 public constant VERIFICATION_GAS_LIMIT = 1200000; + struct PreValidationHookData { uint8 index; bytes validationData; @@ -44,6 +53,131 @@ abstract contract AccountTestBase is OptimizedTest { account1 = factory.createAccount(owner1, 0); vm.deal(address(account1), 100 ether); + + ownerValidation = FunctionReferenceLib.pack( + address(singleOwnerPlugin), uint8(ISingleOwnerPlugin.FunctionId.VALIDATION_OWNER) + ); + } + + function _runExecUserOp(address target, bytes memory callData) internal { + _runUserOp(abi.encodeCall(IStandardExecutor.execute, (target, 0, callData))); + } + + function _runExecUserOp(address target, bytes memory callData, bytes memory revertReason) internal { + _runUserOp(abi.encodeCall(IStandardExecutor.execute, (target, 0, callData)), revertReason); + } + + function _runExecBatchUserOp(Call[] memory calls) internal { + _runUserOp(abi.encodeCall(IStandardExecutor.executeBatch, (calls))); + } + + function _runExecBatchUserOp(Call[] memory calls, bytes memory revertReason) internal { + _runUserOp(abi.encodeCall(IStandardExecutor.executeBatch, (calls)), revertReason); + } + + function _runUserOp(bytes memory callData) internal { + // Run user op without expecting a revert + _runUserOp(callData, hex""); + } + + function _runUserOp(bytes memory callData, bytes memory expectedRevertData) internal { + uint256 nonce = entryPoint.getNonce(address(account1), 0); + + PackedUserOperation memory userOp = PackedUserOperation({ + sender: address(account1), + nonce: nonce, + initCode: hex"", + callData: callData, + accountGasLimits: _encodeGas(VERIFICATION_GAS_LIMIT, CALL_GAS_LIMIT), + preVerificationGas: 0, + gasFees: _encodeGas(1, 1), + paymasterAndData: hex"", + signature: hex"" + }); + + bytes32 userOpHash = entryPoint.getUserOpHash(userOp); + (uint8 v, bytes32 r, bytes32 s) = vm.sign(owner1Key, userOpHash.toEthSignedMessageHash()); + + userOp.signature = _encodeSignature( + FunctionReferenceLib.pack( + address(singleOwnerPlugin), uint8(ISingleOwnerPlugin.FunctionId.VALIDATION_OWNER) + ), + DEFAULT_VALIDATION, + abi.encodePacked(r, s, v) + ); + + PackedUserOperation[] memory userOps = new PackedUserOperation[](1); + userOps[0] = userOp; + + if (expectedRevertData.length > 0) { + vm.expectRevert(expectedRevertData); + } + entryPoint.handleOps(userOps, beneficiary); + } + + function _runtimeExec(address target, bytes memory callData) internal { + _runtimeCall(abi.encodeCall(IStandardExecutor.execute, (target, 0, callData))); + } + + function _runtimeExec(address target, bytes memory callData, bytes memory expectedRevertData) internal { + _runtimeCall(abi.encodeCall(IStandardExecutor.execute, (target, 0, callData)), expectedRevertData); + } + + function _runtimeExecExpFail(address target, bytes memory callData, bytes memory expectedRevertData) + internal + { + _runtimeCallExpFail(abi.encodeCall(IStandardExecutor.execute, (target, 0, callData)), expectedRevertData); + } + + function _runtimeExecBatch(Call[] memory calls) internal { + _runtimeCall(abi.encodeCall(IStandardExecutor.executeBatch, (calls))); + } + + function _runtimeExecBatch(Call[] memory calls, bytes memory expectedRevertData) internal { + _runtimeCall(abi.encodeCall(IStandardExecutor.executeBatch, (calls)), expectedRevertData); + } + + function _runtimeExecBatchExpFail(Call[] memory calls, bytes memory expectedRevertData) internal { + _runtimeCallExpFail(abi.encodeCall(IStandardExecutor.executeBatch, (calls)), expectedRevertData); + } + + function _runtimeCall(bytes memory callData) internal { + _runtimeCall(callData, ""); + } + + function _runtimeCall(bytes memory callData, bytes memory expectedRevertData) internal { + if (expectedRevertData.length > 0) { + vm.expectRevert(expectedRevertData); + } + + vm.prank(owner1); + account1.executeWithAuthorization( + callData, + _encodeSignature( + FunctionReferenceLib.pack( + address(singleOwnerPlugin), uint8(ISingleOwnerPlugin.FunctionId.VALIDATION_OWNER) + ), + DEFAULT_VALIDATION, + "" + ) + ); + } + + // Always expects a revert, even if the revert data is zero-length. + function _runtimeCallExpFail(bytes memory callData, bytes memory expectedRevertData) internal { + vm.expectRevert(expectedRevertData); + + vm.prank(owner1); + account1.executeWithAuthorization( + callData, + _encodeSignature( + FunctionReferenceLib.pack( + address(singleOwnerPlugin), uint8(ISingleOwnerPlugin.FunctionId.VALIDATION_OWNER) + ), + DEFAULT_VALIDATION, + "" + ) + ); } function _transferOwnershipToTest() internal { diff --git a/test/utils/CustomValidationTestBase.sol b/test/utils/CustomValidationTestBase.sol new file mode 100644 index 00000000..8bcdd406 --- /dev/null +++ b/test/utils/CustomValidationTestBase.sol @@ -0,0 +1,44 @@ +// SPDX-License-Identifier: UNLICENSED +pragma solidity ^0.8.25; + +import {ERC1967Proxy} from "@openzeppelin/contracts/proxy/ERC1967/ERC1967Proxy.sol"; + +import {FunctionReference} from "../../src/helpers/FunctionReferenceLib.sol"; +import {UpgradeableModularAccount} from "../../src/account/UpgradeableModularAccount.sol"; + +import {AccountTestBase} from "./AccountTestBase.sol"; + +/// @dev This test contract base is used to test custom validation logic. +/// To use this, override the _initialValidationConfig function to return the desired validation configuration. +/// Then, call _customValidationSetup in the test setup. +/// Make sure to do so after any state variables that `_initialValidationConfig` relies on are set. +abstract contract CustomValidationTestBase is AccountTestBase { + function _customValidationSetup() internal { + ( + FunctionReference validationFunction, + bool shared, + bytes4[] memory selectors, + bytes memory installData, + bytes memory preValidationHooks + ) = _initialValidationConfig(); + + address accountImplementation = address(factory.accountImplementation()); + + account1 = UpgradeableModularAccount(payable(new ERC1967Proxy{salt: 0}(accountImplementation, ""))); + + account1.initializeWithValidation(validationFunction, shared, selectors, installData, preValidationHooks); + + vm.deal(address(account1), 100 ether); + } + + function _initialValidationConfig() + internal + virtual + returns ( + FunctionReference validationFunction, + bool shared, + bytes4[] memory selectors, + bytes memory installData, + bytes memory preValidationHooks + ); +}