diff --git a/.solhint-test.json b/.solhint-test.json index cbd7bf02..fd2b1007 100644 --- a/.solhint-test.json +++ b/.solhint-test.json @@ -5,6 +5,7 @@ "immutable-vars-naming": ["error"], "no-unused-import": ["error"], "compiler-version": ["error", ">=0.8.19"], + "custom-errors": "off", "func-visibility": ["error", { "ignoreConstructors": true }], "max-line-length": ["error", 120], "max-states-count": ["warn", 30], diff --git a/src/account/AccountLoupe.sol b/src/account/AccountLoupe.sol index 2afe6638..3e62fe9a 100644 --- a/src/account/AccountLoupe.sol +++ b/src/account/AccountLoupe.sol @@ -76,8 +76,7 @@ abstract contract AccountLoupe is IAccountLoupe { override returns (FunctionReference[] memory preValidationHooks) { - preValidationHooks = - toFunctionReferenceArray(getAccountStorage().validationData[validationFunction].preValidationHooks); + preValidationHooks = getAccountStorage().validationData[validationFunction].preValidationHooks; } /// @inheritdoc IAccountLoupe diff --git a/src/account/AccountStorage.sol b/src/account/AccountStorage.sol index 352a3d5b..ddd8f900 100644 --- a/src/account/AccountStorage.sol +++ b/src/account/AccountStorage.sol @@ -41,7 +41,7 @@ struct ValidationData { // How many execution hooks require the UO context. uint8 requireUOHookCount; // The pre validation hooks for this function selector. - EnumerableSet.Bytes32Set preValidationHooks; + FunctionReference[] preValidationHooks; // Permission hooks for this validation function. EnumerableSet.Bytes32Set permissionHooks; } diff --git a/src/account/PluginManager2.sol b/src/account/PluginManager2.sol index f733eb60..effa1a15 100644 --- a/src/account/PluginManager2.sol +++ b/src/account/PluginManager2.sol @@ -13,11 +13,15 @@ import {ExecutionHook} from "../interfaces/IAccountLoupe.sol"; abstract contract PluginManager2 { using EnumerableSet for EnumerableSet.Bytes32Set; + // Index marking the start of the data for the validation function. + uint8 internal constant _RESERVED_VALIDATION_DATA_INDEX = 255; + error DefaultValidationAlreadySet(FunctionReference validationFunction); error PreValidationAlreadySet(FunctionReference validationFunction, FunctionReference preValidationFunction); error ValidationAlreadySet(bytes4 selector, FunctionReference validationFunction); error ValidationNotSet(bytes4 selector, FunctionReference validationFunction); error PermissionAlreadySet(FunctionReference validationFunction, ExecutionHook hook); + error PreValidationHookLimitExceeded(); function _installValidation( FunctionReference validationFunction, @@ -39,19 +43,21 @@ abstract contract PluginManager2 { for (uint256 i = 0; i < preValidationFunctions.length; ++i) { FunctionReference preValidationFunction = preValidationFunctions[i]; - if ( - !_storage.validationData[validationFunction].preValidationHooks.add( - toSetValue(preValidationFunction) - ) - ) { - revert PreValidationAlreadySet(validationFunction, preValidationFunction); - } + _storage.validationData[validationFunction].preValidationHooks.push(preValidationFunction); if (initDatas[i].length > 0) { (address preValidationPlugin,) = FunctionReferenceLib.unpack(preValidationFunction); IPlugin(preValidationPlugin).onInstall(initDatas[i]); } } + + // Avoid collision between reserved index and actual indices + if ( + _storage.validationData[validationFunction].preValidationHooks.length + > _RESERVED_VALIDATION_DATA_INDEX + ) { + revert PreValidationHookLimitExceeded(); + } } if (permissionHooks.length > 0) { @@ -110,15 +116,16 @@ abstract contract PluginManager2 { bytes[] memory preValidationHookUninstallDatas = abi.decode(preValidationHookUninstallData, (bytes[])); // Clear pre validation hooks - EnumerableSet.Bytes32Set storage preValidationHooks = + FunctionReference[] storage preValidationHooks = _storage.validationData[validationFunction].preValidationHooks; - uint256 i = 0; - while (preValidationHooks.length() > 0) { - FunctionReference preValidationFunction = toFunctionReference(preValidationHooks.at(0)); - preValidationHooks.remove(toSetValue(preValidationFunction)); - (address preValidationPlugin,) = FunctionReferenceLib.unpack(preValidationFunction); - IPlugin(preValidationPlugin).onUninstall(preValidationHookUninstallDatas[i++]); + for (uint256 i = 0; i < preValidationHooks.length; ++i) { + FunctionReference preValidationFunction = preValidationHooks[i]; + if (preValidationHookUninstallDatas[0].length > 0) { + (address preValidationPlugin,) = FunctionReferenceLib.unpack(preValidationFunction); + IPlugin(preValidationPlugin).onUninstall(preValidationHookUninstallDatas[0]); + } } + delete _storage.validationData[validationFunction].preValidationHooks; } { @@ -135,6 +142,7 @@ abstract contract PluginManager2 { IPlugin(permissionHookPlugin).onUninstall(permissionHookUninstallDatas[i++]); } } + delete _storage.validationData[validationFunction].preValidationHooks; // Because this function also calls `onUninstall`, and removes the default flag from validation, we must // assume these selectors passed in to be exhaustive. diff --git a/src/account/UpgradeableModularAccount.sol b/src/account/UpgradeableModularAccount.sol index 4175c5bc..51137641 100644 --- a/src/account/UpgradeableModularAccount.sol +++ b/src/account/UpgradeableModularAccount.sol @@ -11,6 +11,7 @@ import {IERC1271} from "@openzeppelin/contracts/interfaces/IERC1271.sol"; import {EnumerableSet} from "@openzeppelin/contracts/utils/structs/EnumerableSet.sol"; import {FunctionReferenceLib} from "../helpers/FunctionReferenceLib.sol"; +import {SparseCalldataSegmentLib} from "../helpers/SparseCalldataSegmentLib.sol"; import {_coalescePreValidation, _coalesceValidation} from "../helpers/ValidationDataHelpers.sol"; import {IPlugin, PluginManifest} from "../interfaces/IPlugin.sol"; import {IValidation} from "../interfaces/IValidation.sol"; @@ -20,13 +21,7 @@ import {FunctionReference, IPluginManager} from "../interfaces/IPluginManager.so import {IStandardExecutor, Call} from "../interfaces/IStandardExecutor.sol"; import {AccountExecutor} from "./AccountExecutor.sol"; import {AccountLoupe} from "./AccountLoupe.sol"; -import { - AccountStorage, - getAccountStorage, - toSetValue, - toFunctionReference, - toExecutionHook -} from "./AccountStorage.sol"; +import {AccountStorage, getAccountStorage, toSetValue, toExecutionHook} from "./AccountStorage.sol"; import {AccountStorageInitializable} from "./AccountStorageInitializable.sol"; import {PluginManagerInternals} from "./PluginManagerInternals.sol"; import {PluginManager2} from "./PluginManager2.sol"; @@ -46,6 +41,7 @@ contract UpgradeableModularAccount is { using EnumerableSet for EnumerableSet.Bytes32Set; using FunctionReferenceLib for FunctionReference; + using SparseCalldataSegmentLib for bytes; struct PostExecToRun { bytes preExecHookReturnData; @@ -68,6 +64,7 @@ contract UpgradeableModularAccount is error ExecFromPluginNotPermitted(address plugin, bytes4 selector); error ExecFromPluginExternalNotPermitted(address plugin, address target, uint256 value, bytes data); error NativeTokenSpendingNotPermitted(address plugin); + error NonCanonicalEncoding(); error NotEntryPoint(); error PostExecHookReverted(address plugin, uint8 functionId, bytes revertReason); error PreExecHookReverted(address plugin, uint8 functionId, bytes revertReason); @@ -80,6 +77,8 @@ contract UpgradeableModularAccount is error UnrecognizedFunction(bytes4 selector); error UserOpValidationFunctionMissing(bytes4 selector); error ValidationDoesNotApply(bytes4 selector, address plugin, uint8 functionId, bool isDefault); + error ValidationSignatureSegmentMissing(); + error SignatureSegmentOutOfOrder(); // Wraps execution of a native function with runtime validation and hooks // Used for upgradeTo, upgradeToAndCall, execute, executeBatch, installPlugin, uninstallPlugin @@ -407,38 +406,50 @@ contract UpgradeableModularAccount is revert RequireUserOperationContext(); } - validationData = - _doUserOpValidation(selector, userOpValidationFunction, userOp, userOp.signature[22:], userOpHash); + validationData = _doUserOpValidation(userOpValidationFunction, userOp, userOp.signature[22:], userOpHash); } // To support gas estimation, we don't fail early when the failure is caused by a signature failure function _doUserOpValidation( - bytes4 selector, FunctionReference userOpValidationFunction, PackedUserOperation memory userOp, bytes calldata signature, bytes32 userOpHash - ) internal returns (uint256 validationData) { - userOp.signature = signature; + ) internal returns (uint256) { + // Set up the per-hook data tracking fields + bytes calldata signatureSegment; + (signatureSegment, signature) = signature.getNextSegment(); - if (userOpValidationFunction.isEmpty()) { - // If the validation function is empty, then the call cannot proceed. - revert UserOpValidationFunctionMissing(selector); - } - - uint256 currentValidationData; + uint256 validationData; // Do preUserOpValidation hooks - EnumerableSet.Bytes32Set storage preUserOpValidationHooks = + FunctionReference[] memory preUserOpValidationHooks = getAccountStorage().validationData[userOpValidationFunction].preValidationHooks; - uint256 preUserOpValidationHooksLength = preUserOpValidationHooks.length(); - for (uint256 i = 0; i < preUserOpValidationHooksLength; ++i) { - bytes32 key = preUserOpValidationHooks.at(i); - FunctionReference preUserOpValidationHook = toFunctionReference(key); + for (uint256 i = 0; i < preUserOpValidationHooks.length; ++i) { + // Load per-hook data, if any is present + // The segment index is the first byte of the signature + if (signatureSegment.getIndex() == i) { + // Use the current segment + userOp.signature = signatureSegment.getBody(); + + if (userOp.signature.length == 0) { + revert NonCanonicalEncoding(); + } + + // Load the next per-hook data segment + (signatureSegment, signature) = signature.getNextSegment(); + + if (signatureSegment.getIndex() <= i) { + revert SignatureSegmentOutOfOrder(); + } + } else { + userOp.signature = ""; + } - (address plugin, uint8 functionId) = preUserOpValidationHook.unpack(); - currentValidationData = IValidationHook(plugin).preUserOpValidationHook(functionId, userOp, userOpHash); + (address plugin, uint8 functionId) = preUserOpValidationHooks[i].unpack(); + uint256 currentValidationData = + IValidationHook(plugin).preUserOpValidationHook(functionId, userOp, userOpHash); if (uint160(currentValidationData) > 1) { // If the aggregator is not 0 or 1, it is an unexpected value @@ -449,16 +460,24 @@ contract UpgradeableModularAccount is // Run the user op validationFunction { + if (signatureSegment.getIndex() != _RESERVED_VALIDATION_DATA_INDEX) { + revert ValidationSignatureSegmentMissing(); + } + + userOp.signature = signatureSegment.getBody(); + (address plugin, uint8 functionId) = userOpValidationFunction.unpack(); - currentValidationData = IValidation(plugin).validateUserOp(functionId, userOp, userOpHash); + uint256 currentValidationData = IValidation(plugin).validateUserOp(functionId, userOp, userOpHash); - if (preUserOpValidationHooksLength != 0) { + if (preUserOpValidationHooks.length != 0) { // If we have other validation data we need to coalesce with validationData = _coalesceValidation(validationData, currentValidationData); } else { validationData = currentValidationData; } } + + return validationData; } function _doRuntimeValidation( @@ -466,18 +485,38 @@ contract UpgradeableModularAccount is bytes calldata callData, bytes calldata authorizationData ) internal { + // Set up the per-hook data tracking fields + bytes calldata authSegment; + (authSegment, authorizationData) = authorizationData.getNextSegment(); + // run all preRuntimeValidation hooks - EnumerableSet.Bytes32Set storage preRuntimeValidationHooks = + FunctionReference[] memory preRuntimeValidationHooks = getAccountStorage().validationData[runtimeValidationFunction].preValidationHooks; - uint256 preRuntimeValidationHooksLength = preRuntimeValidationHooks.length(); - for (uint256 i = 0; i < preRuntimeValidationHooksLength; ++i) { - bytes32 key = preRuntimeValidationHooks.at(i); - FunctionReference preRuntimeValidationHook = toFunctionReference(key); + for (uint256 i = 0; i < preRuntimeValidationHooks.length; ++i) { + bytes memory currentAuthData; + + if (authSegment.getIndex() == i) { + // Use the current segment + currentAuthData = authSegment.getBody(); + + if (currentAuthData.length == 0) { + revert NonCanonicalEncoding(); + } + + // Load the next per-hook data segment + (authSegment, authorizationData) = authorizationData.getNextSegment(); - (address hookPlugin, uint8 hookFunctionId) = preRuntimeValidationHook.unpack(); + if (authSegment.getIndex() <= i) { + revert SignatureSegmentOutOfOrder(); + } + } else { + currentAuthData = ""; + } + + (address hookPlugin, uint8 hookFunctionId) = preRuntimeValidationHooks[i].unpack(); try IValidationHook(hookPlugin).preRuntimeValidationHook( - hookFunctionId, msg.sender, msg.value, callData + hookFunctionId, msg.sender, msg.value, callData, currentAuthData ) // forgefmt: disable-start // solhint-disable-next-line no-empty-blocks @@ -487,9 +526,13 @@ contract UpgradeableModularAccount is } } + if (authSegment.getIndex() != _RESERVED_VALIDATION_DATA_INDEX) { + revert ValidationSignatureSegmentMissing(); + } + (address plugin, uint8 functionId) = runtimeValidationFunction.unpack(); - try IValidation(plugin).validateRuntime(functionId, msg.sender, msg.value, callData, authorizationData) + try IValidation(plugin).validateRuntime(functionId, msg.sender, msg.value, callData, authSegment.getBody()) // forgefmt: disable-start // solhint-disable-next-line no-empty-blocks {} catch (bytes memory revertReason) { diff --git a/src/helpers/SparseCalldataSegmentLib.sol b/src/helpers/SparseCalldataSegmentLib.sol new file mode 100644 index 00000000..0a6cc541 --- /dev/null +++ b/src/helpers/SparseCalldataSegmentLib.sol @@ -0,0 +1,51 @@ +// SPDX-License-Identifier: GPL-3.0 +pragma solidity ^0.8.25; + +/// @title Sparse Calldata Segment Library +/// @notice Library for working with sparsely-packed calldata segments, identified with an index. +/// @dev The first byte of each segment is the index of the segment. +/// To prevent accidental stack-to-deep errors, the body and index of the segment are extracted separately, rather +/// than inline as part of the tuple returned by `getNextSegment`. +library SparseCalldataSegmentLib { + /// @notice Splits out a segment of calldata, sparsely-packed. + /// The expected format is: + /// [uint32(len(segment0)), segment0, uint32(len(segment1)), segment1, ... uint32(len(segmentN)), segmentN] + /// @param source The calldata to extract the segment from. + /// @return segment The extracted segment. Using the above example, this would be segment0. + /// @return remainder The remaining calldata. Using the above example, + /// this would start at uint32(len(segment1)) and continue to the end at segmentN. + function getNextSegment(bytes calldata source) + internal + pure + returns (bytes calldata segment, bytes calldata remainder) + { + // The first 4 bytes hold the length of the segment, excluding the index. + uint32 length = uint32(bytes4(source[:4])); + + // The offset of the remainder of the calldata. + uint256 remainderOffset = 4 + length; + + // The segment is the next `length` + 1 bytes, to account for the index. + // By convention, the first byte of each segment is the index of the segment. + segment = source[4:remainderOffset]; + + // The remainder is the rest of the calldata. + remainder = source[remainderOffset:]; + } + + /// @notice Extracts the index from a segment. + /// @dev The first byte of the segment is the index. + /// @param segment The segment to extract the index from + /// @return The index of the segment + function getIndex(bytes calldata segment) internal pure returns (uint8) { + return uint8(segment[0]); + } + + /// @notice Extracts the body from a segment. + /// @dev The body is the segment without the index. + /// @param segment The segment to extract the body from + /// @return The body of the segment. + function getBody(bytes calldata segment) internal pure returns (bytes calldata) { + return segment[1:]; + } +} diff --git a/src/interfaces/IValidation.sol b/src/interfaces/IValidation.sol index b3adcd3d..38c8a139 100644 --- a/src/interfaces/IValidation.sol +++ b/src/interfaces/IValidation.sol @@ -23,6 +23,7 @@ interface IValidation is IPlugin { /// @param sender The caller address. /// @param value The call value. /// @param data The calldata sent. + /// @param authorization Additional data for the validation function to use. function validateRuntime( uint8 functionId, address sender, diff --git a/src/interfaces/IValidationHook.sol b/src/interfaces/IValidationHook.sol index 8eb7a61d..8300bbb8 100644 --- a/src/interfaces/IValidationHook.sol +++ b/src/interfaces/IValidationHook.sol @@ -24,8 +24,13 @@ interface IValidationHook is IPlugin { /// @param sender The caller address. /// @param value The call value. /// @param data The calldata sent. - function preRuntimeValidationHook(uint8 functionId, address sender, uint256 value, bytes calldata data) - external; + function preRuntimeValidationHook( + uint8 functionId, + address sender, + uint256 value, + bytes calldata data, + bytes calldata authorization + ) external; // TODO: support this hook type within the account & in the manifest diff --git a/test/account/AccountReturnData.t.sol b/test/account/AccountReturnData.t.sol index 8e8f3215..fc9fd615 100644 --- a/test/account/AccountReturnData.t.sol +++ b/test/account/AccountReturnData.t.sol @@ -1,7 +1,7 @@ // SPDX-License-Identifier: UNLICENSED pragma solidity ^0.8.19; -import {FunctionReference} from "../../src/helpers/FunctionReferenceLib.sol"; +import {FunctionReference, FunctionReferenceLib} from "../../src/helpers/FunctionReferenceLib.sol"; import {Call} from "../../src/interfaces/IStandardExecutor.sol"; import {ISingleOwnerPlugin} from "../../src/plugins/owner/ISingleOwnerPlugin.sol"; @@ -59,8 +59,12 @@ contract AccountReturnDataTest is AccountTestBase { account1.execute, (address(regularResultContract), 0, abi.encodeCall(RegularResultContract.foo, ())) ), - abi.encodePacked( - singleOwnerPlugin, ISingleOwnerPlugin.FunctionId.VALIDATION_OWNER, SELECTOR_ASSOCIATED_VALIDATION + _encodeSignature( + FunctionReferenceLib.pack( + address(singleOwnerPlugin), uint8(ISingleOwnerPlugin.FunctionId.VALIDATION_OWNER) + ), + SELECTOR_ASSOCIATED_VALIDATION, + "" ) ); @@ -85,8 +89,12 @@ contract AccountReturnDataTest is AccountTestBase { bytes memory retData = account1.executeWithAuthorization( abi.encodeCall(account1.executeBatch, (calls)), - abi.encodePacked( - singleOwnerPlugin, ISingleOwnerPlugin.FunctionId.VALIDATION_OWNER, SELECTOR_ASSOCIATED_VALIDATION + _encodeSignature( + FunctionReferenceLib.pack( + address(singleOwnerPlugin), uint8(ISingleOwnerPlugin.FunctionId.VALIDATION_OWNER) + ), + SELECTOR_ASSOCIATED_VALIDATION, + "" ) ); diff --git a/test/account/DefaultValidationTest.t.sol b/test/account/DefaultValidationTest.t.sol index fc93060d..c2f118de 100644 --- a/test/account/DefaultValidationTest.t.sol +++ b/test/account/DefaultValidationTest.t.sol @@ -57,7 +57,7 @@ contract DefaultValidationTest is AccountTestBase { // Generate signature bytes32 userOpHash = entryPoint.getUserOpHash(userOp); (uint8 v, bytes32 r, bytes32 s) = vm.sign(owner1Key, userOpHash.toEthSignedMessageHash()); - userOp.signature = abi.encodePacked(ownerValidation, DEFAULT_VALIDATION, r, s, v); + userOp.signature = _encodeSignature(ownerValidation, DEFAULT_VALIDATION, abi.encodePacked(r, s, v)); PackedUserOperation[] memory userOps = new PackedUserOperation[](1); userOps[0] = userOp; @@ -74,7 +74,7 @@ contract DefaultValidationTest is AccountTestBase { vm.prank(owner1); account1.executeWithAuthorization( abi.encodeCall(UpgradeableModularAccount.execute, (ethRecipient, 1 wei, "")), - abi.encodePacked(ownerValidation, DEFAULT_VALIDATION) + _encodeSignature(ownerValidation, DEFAULT_VALIDATION, "") ); assertEq(ethRecipient.balance, 2 wei); diff --git a/test/account/MultiValidation.t.sol b/test/account/MultiValidation.t.sol index 9b22f5a0..e80d022c 100644 --- a/test/account/MultiValidation.t.sol +++ b/test/account/MultiValidation.t.sol @@ -67,20 +67,24 @@ contract MultiValidationTest is AccountTestBase { ); account1.executeWithAuthorization( abi.encodeCall(IStandardExecutor.execute, (address(0), 0, "")), - abi.encodePacked( - address(validator2), - uint8(ISingleOwnerPlugin.FunctionId.VALIDATION_OWNER), - SELECTOR_ASSOCIATED_VALIDATION + _encodeSignature( + FunctionReferenceLib.pack( + address(validator2), uint8(ISingleOwnerPlugin.FunctionId.VALIDATION_OWNER) + ), + SELECTOR_ASSOCIATED_VALIDATION, + "" ) ); vm.prank(owner2); account1.executeWithAuthorization( abi.encodeCall(IStandardExecutor.execute, (address(0), 0, "")), - abi.encodePacked( - address(validator2), - uint8(ISingleOwnerPlugin.FunctionId.VALIDATION_OWNER), - SELECTOR_ASSOCIATED_VALIDATION + _encodeSignature( + FunctionReferenceLib.pack( + address(validator2), uint8(ISingleOwnerPlugin.FunctionId.VALIDATION_OWNER) + ), + SELECTOR_ASSOCIATED_VALIDATION, + "" ) ); } @@ -105,13 +109,10 @@ contract MultiValidationTest is AccountTestBase { // Generate signature bytes32 userOpHash = entryPoint.getUserOpHash(userOp); (uint8 v, bytes32 r, bytes32 s) = vm.sign(owner2Key, userOpHash.toEthSignedMessageHash()); - userOp.signature = abi.encodePacked( - address(validator2), + userOp.signature = _encodeSignature( + FunctionReferenceLib.pack(address(validator2), uint8(ISingleOwnerPlugin.FunctionId.VALIDATION_OWNER)), SELECTOR_ASSOCIATED_VALIDATION, - uint8(ISingleOwnerPlugin.FunctionId.VALIDATION_OWNER), - r, - s, - v + abi.encodePacked(r, s, v) ); PackedUserOperation[] memory userOps = new PackedUserOperation[](1); @@ -123,8 +124,11 @@ contract MultiValidationTest is AccountTestBase { userOp.nonce = 1; (v, r, s) = vm.sign(owner1Key, userOpHash.toEthSignedMessageHash()); - userOp.signature = - abi.encodePacked(address(validator2), uint8(ISingleOwnerPlugin.FunctionId.VALIDATION_OWNER), r, s, v); + userOp.signature = _encodeSignature( + FunctionReferenceLib.pack(address(validator2), uint8(ISingleOwnerPlugin.FunctionId.VALIDATION_OWNER)), + SELECTOR_ASSOCIATED_VALIDATION, + abi.encodePacked(r, s, v) + ); userOps[0] = userOp; vm.expectRevert(abi.encodeWithSelector(IEntryPoint.FailedOp.selector, 0, "AA24 signature error")); diff --git a/test/account/PerHookData.t.sol b/test/account/PerHookData.t.sol new file mode 100644 index 00000000..b677951f --- /dev/null +++ b/test/account/PerHookData.t.sol @@ -0,0 +1,361 @@ +// SPDX-License-Identifier: UNLICENSED +pragma solidity ^0.8.25; + +import {PackedUserOperation} from "@eth-infinitism/account-abstraction/interfaces/PackedUserOperation.sol"; +import {IEntryPoint} from "@eth-infinitism/account-abstraction/interfaces/IEntryPoint.sol"; +import {ERC1967Proxy} from "@openzeppelin/contracts/proxy/ERC1967/ERC1967Proxy.sol"; +import {MessageHashUtils} from "@openzeppelin/contracts/utils/cryptography/MessageHashUtils.sol"; + +import {UpgradeableModularAccount} from "../../src/account/UpgradeableModularAccount.sol"; +import {ISingleOwnerPlugin} from "../../src/plugins/owner/ISingleOwnerPlugin.sol"; +import {FunctionReference, FunctionReferenceLib} from "../../src/helpers/FunctionReferenceLib.sol"; + +import {MockAccessControlHookPlugin} from "../mocks/plugins/MockAccessControlHookPlugin.sol"; +import {Counter} from "../mocks/Counter.sol"; +import {AccountTestBase} from "../utils/AccountTestBase.sol"; + +contract PerHookDataTest is AccountTestBase { + using MessageHashUtils for bytes32; + + MockAccessControlHookPlugin internal _accessControlHookPlugin; + + Counter internal _counter; + + FunctionReference internal _ownerValidation; + + uint256 public constant CALL_GAS_LIMIT = 50000; + uint256 public constant VERIFICATION_GAS_LIMIT = 1200000; + + function setUp() public { + _counter = new Counter(); + + _accessControlHookPlugin = new MockAccessControlHookPlugin(); + + // Write over `account1` with a new account proxy, with different initialization. + + address accountImplementation = address(factory.accountImplementation()); + + account1 = UpgradeableModularAccount(payable(new ERC1967Proxy(accountImplementation, ""))); + + _ownerValidation = FunctionReferenceLib.pack( + address(singleOwnerPlugin), uint8(ISingleOwnerPlugin.FunctionId.VALIDATION_OWNER) + ); + + FunctionReference accessControlHook = FunctionReferenceLib.pack( + address(_accessControlHookPlugin), uint8(MockAccessControlHookPlugin.FunctionId.PRE_VALIDATION_HOOK) + ); + + FunctionReference[] memory preValidationHooks = new FunctionReference[](1); + preValidationHooks[0] = accessControlHook; + + bytes[] memory preValidationHookData = new bytes[](1); + // Access control is restricted to only the _counter + preValidationHookData[0] = abi.encode(_counter); + + bytes memory packedPreValidationHooks = abi.encode(preValidationHooks, preValidationHookData); + + vm.prank(address(entryPoint)); + account1.installValidation( + _ownerValidation, true, new bytes4[](0), abi.encode(owner1), packedPreValidationHooks, "" + ); + + vm.deal(address(account1), 100 ether); + } + + function test_passAccessControl_userOp() public { + assertEq(_counter.number(), 0); + + (PackedUserOperation memory userOp, bytes32 userOpHash) = _getCounterUserOP(); + + (uint8 v, bytes32 r, bytes32 s) = vm.sign(owner1Key, userOpHash.toEthSignedMessageHash()); + + PreValidationHookData[] memory preValidationHookData = new PreValidationHookData[](1); + preValidationHookData[0] = PreValidationHookData({index: 0, validationData: abi.encodePacked(_counter)}); + + userOp.signature = _encodeSignature( + _ownerValidation, DEFAULT_VALIDATION, preValidationHookData, abi.encodePacked(r, s, v) + ); + + PackedUserOperation[] memory userOps = new PackedUserOperation[](1); + userOps[0] = userOp; + + entryPoint.handleOps(userOps, beneficiary); + + assertEq(_counter.number(), 1); + } + + function test_failAccessControl_badSigData_userOp() public { + (PackedUserOperation memory userOp, bytes32 userOpHash) = _getCounterUserOP(); + + (uint8 v, bytes32 r, bytes32 s) = vm.sign(owner1Key, userOpHash.toEthSignedMessageHash()); + + PreValidationHookData[] memory preValidationHookData = new PreValidationHookData[](1); + preValidationHookData[0] = PreValidationHookData({ + index: 0, + validationData: abi.encodePacked(address(0x1234123412341234123412341234123412341234)) + }); + + userOp.signature = _encodeSignature( + _ownerValidation, DEFAULT_VALIDATION, preValidationHookData, abi.encodePacked(r, s, v) + ); + + PackedUserOperation[] memory userOps = new PackedUserOperation[](1); + userOps[0] = userOp; + + vm.expectRevert( + abi.encodeWithSelector( + IEntryPoint.FailedOpWithRevert.selector, + 0, + "AA23 reverted", + abi.encodeWithSignature("Error(string)", "Proof doesn't match target") + ) + ); + entryPoint.handleOps(userOps, beneficiary); + } + + function test_failAccessControl_noSigData_userOp() public { + (PackedUserOperation memory userOp, bytes32 userOpHash) = _getCounterUserOP(); + (uint8 v, bytes32 r, bytes32 s) = vm.sign(owner1Key, userOpHash.toEthSignedMessageHash()); + + userOp.signature = _encodeSignature(_ownerValidation, DEFAULT_VALIDATION, abi.encodePacked(r, s, v)); + + PackedUserOperation[] memory userOps = new PackedUserOperation[](1); + userOps[0] = userOp; + + vm.expectRevert( + abi.encodeWithSelector( + IEntryPoint.FailedOpWithRevert.selector, + 0, + "AA23 reverted", + abi.encodeWithSignature("Error(string)", "Proof doesn't match target") + ) + ); + entryPoint.handleOps(userOps, beneficiary); + } + + function test_failAccessControl_badIndexProvided_userOp() public { + (PackedUserOperation memory userOp, bytes32 userOpHash) = _getCounterUserOP(); + (uint8 v, bytes32 r, bytes32 s) = vm.sign(owner1Key, userOpHash.toEthSignedMessageHash()); + + PreValidationHookData[] memory preValidationHookData = new PreValidationHookData[](2); + preValidationHookData[0] = PreValidationHookData({index: 0, validationData: abi.encodePacked(_counter)}); + preValidationHookData[1] = PreValidationHookData({index: 1, validationData: abi.encodePacked(_counter)}); + + userOp.signature = _encodeSignature( + _ownerValidation, DEFAULT_VALIDATION, preValidationHookData, abi.encodePacked(r, s, v) + ); + + PackedUserOperation[] memory userOps = new PackedUserOperation[](1); + userOps[0] = userOp; + + vm.expectRevert( + abi.encodeWithSelector( + IEntryPoint.FailedOpWithRevert.selector, + 0, + "AA23 reverted", + abi.encodeWithSelector(UpgradeableModularAccount.ValidationSignatureSegmentMissing.selector) + ) + ); + entryPoint.handleOps(userOps, beneficiary); + } + + // todo: index out of order failure case with 2 pre hooks + + function test_failAccessControl_badTarget_userOp() public { + PackedUserOperation memory userOp = PackedUserOperation({ + sender: address(account1), + nonce: 0, + initCode: "", + callData: abi.encodeCall(UpgradeableModularAccount.execute, (beneficiary, 1 wei, "")), + accountGasLimits: _encodeGas(VERIFICATION_GAS_LIMIT, CALL_GAS_LIMIT), + preVerificationGas: 0, + gasFees: _encodeGas(1, 1), + paymasterAndData: "", + signature: "" + }); + + bytes32 userOpHash = entryPoint.getUserOpHash(userOp); + (uint8 v, bytes32 r, bytes32 s) = vm.sign(owner1Key, userOpHash.toEthSignedMessageHash()); + + PreValidationHookData[] memory preValidationHookData = new PreValidationHookData[](1); + preValidationHookData[0] = PreValidationHookData({index: 0, validationData: abi.encodePacked(beneficiary)}); + + userOp.signature = _encodeSignature( + _ownerValidation, DEFAULT_VALIDATION, preValidationHookData, abi.encodePacked(r, s, v) + ); + + PackedUserOperation[] memory userOps = new PackedUserOperation[](1); + userOps[0] = userOp; + + vm.expectRevert( + abi.encodeWithSelector( + IEntryPoint.FailedOpWithRevert.selector, + 0, + "AA23 reverted", + abi.encodeWithSignature("Error(string)", "Target not allowed") + ) + ); + entryPoint.handleOps(userOps, beneficiary); + } + + function test_failPerHookData_nonCanonicalEncoding_userOp() public { + (PackedUserOperation memory userOp, bytes32 userOpHash) = _getCounterUserOP(); + (uint8 v, bytes32 r, bytes32 s) = vm.sign(owner1Key, userOpHash.toEthSignedMessageHash()); + + PreValidationHookData[] memory preValidationHookData = new PreValidationHookData[](1); + preValidationHookData[0] = PreValidationHookData({index: 0, validationData: ""}); + + userOp.signature = _encodeSignature( + _ownerValidation, DEFAULT_VALIDATION, preValidationHookData, abi.encodePacked(r, s, v) + ); + + PackedUserOperation[] memory userOps = new PackedUserOperation[](1); + userOps[0] = userOp; + + vm.expectRevert( + abi.encodeWithSelector( + IEntryPoint.FailedOpWithRevert.selector, + 0, + "AA23 reverted", + abi.encodeWithSelector(UpgradeableModularAccount.NonCanonicalEncoding.selector) + ) + ); + entryPoint.handleOps(userOps, beneficiary); + } + + function test_passAccessControl_runtime() public { + assertEq(_counter.number(), 0); + + PreValidationHookData[] memory preValidationHookData = new PreValidationHookData[](1); + preValidationHookData[0] = PreValidationHookData({index: 0, validationData: abi.encodePacked(_counter)}); + + vm.prank(owner1); + account1.executeWithAuthorization( + abi.encodeCall( + UpgradeableModularAccount.execute, + (address(_counter), 0 wei, abi.encodeCall(Counter.increment, ())) + ), + _encodeSignature(_ownerValidation, DEFAULT_VALIDATION, preValidationHookData, "") + ); + + assertEq(_counter.number(), 1); + } + + function test_failAccessControl_badSigData_runtime() public { + PreValidationHookData[] memory preValidationHookData = new PreValidationHookData[](1); + preValidationHookData[0] = PreValidationHookData({ + index: 0, + validationData: abi.encodePacked(address(0x1234123412341234123412341234123412341234)) + }); + + vm.prank(owner1); + vm.expectRevert( + abi.encodeWithSelector( + UpgradeableModularAccount.PreRuntimeValidationHookFailed.selector, + _accessControlHookPlugin, + uint8(MockAccessControlHookPlugin.FunctionId.PRE_VALIDATION_HOOK), + abi.encodeWithSignature("Error(string)", "Proof doesn't match target") + ) + ); + account1.executeWithAuthorization( + abi.encodeCall( + UpgradeableModularAccount.execute, + (address(_counter), 0 wei, abi.encodeCall(Counter.increment, ())) + ), + _encodeSignature(_ownerValidation, DEFAULT_VALIDATION, preValidationHookData, "") + ); + } + + function test_failAccessControl_noSigData_runtime() public { + vm.prank(owner1); + vm.expectRevert( + abi.encodeWithSelector( + UpgradeableModularAccount.PreRuntimeValidationHookFailed.selector, + _accessControlHookPlugin, + uint8(MockAccessControlHookPlugin.FunctionId.PRE_VALIDATION_HOOK), + abi.encodeWithSignature("Error(string)", "Proof doesn't match target") + ) + ); + account1.executeWithAuthorization( + abi.encodeCall( + UpgradeableModularAccount.execute, + (address(_counter), 0 wei, abi.encodeCall(Counter.increment, ())) + ), + _encodeSignature(_ownerValidation, DEFAULT_VALIDATION, "") + ); + } + + function test_failAccessControl_badIndexProvided_runtime() public { + PreValidationHookData[] memory preValidationHookData = new PreValidationHookData[](2); + preValidationHookData[0] = PreValidationHookData({index: 0, validationData: abi.encodePacked(_counter)}); + preValidationHookData[1] = PreValidationHookData({index: 1, validationData: abi.encodePacked(_counter)}); + + vm.prank(owner1); + vm.expectRevert( + abi.encodeWithSelector(UpgradeableModularAccount.ValidationSignatureSegmentMissing.selector) + ); + account1.executeWithAuthorization( + abi.encodeCall( + UpgradeableModularAccount.execute, + (address(_counter), 0 wei, abi.encodeCall(Counter.increment, ())) + ), + _encodeSignature(_ownerValidation, DEFAULT_VALIDATION, preValidationHookData, "") + ); + } + + //todo: index out of order failure case with 2 pre hooks + + function test_failAccessControl_badTarget_runtime() public { + PreValidationHookData[] memory preValidationHookData = new PreValidationHookData[](1); + preValidationHookData[0] = PreValidationHookData({index: 0, validationData: abi.encodePacked(beneficiary)}); + + vm.prank(owner1); + vm.expectRevert( + abi.encodeWithSelector( + UpgradeableModularAccount.PreRuntimeValidationHookFailed.selector, + _accessControlHookPlugin, + uint8(MockAccessControlHookPlugin.FunctionId.PRE_VALIDATION_HOOK), + abi.encodeWithSignature("Error(string)", "Target not allowed") + ) + ); + account1.executeWithAuthorization( + abi.encodeCall(UpgradeableModularAccount.execute, (beneficiary, 1 wei, "")), + _encodeSignature(_ownerValidation, DEFAULT_VALIDATION, preValidationHookData, "") + ); + } + + function test_failPerHookData_nonCanonicalEncoding_runtime() public { + PreValidationHookData[] memory preValidationHookData = new PreValidationHookData[](1); + preValidationHookData[0] = PreValidationHookData({index: 0, validationData: ""}); + + vm.prank(owner1); + vm.expectRevert(abi.encodeWithSelector(UpgradeableModularAccount.NonCanonicalEncoding.selector)); + account1.executeWithAuthorization( + abi.encodeCall( + UpgradeableModularAccount.execute, + (address(_counter), 0 wei, abi.encodeCall(Counter.increment, ())) + ), + _encodeSignature(_ownerValidation, DEFAULT_VALIDATION, preValidationHookData, "") + ); + } + + function _getCounterUserOP() internal view returns (PackedUserOperation memory, bytes32) { + PackedUserOperation memory userOp = PackedUserOperation({ + sender: address(account1), + nonce: 0, + initCode: "", + callData: abi.encodeCall( + UpgradeableModularAccount.execute, (address(_counter), 0 wei, abi.encodeCall(Counter.increment, ())) + ), + accountGasLimits: _encodeGas(VERIFICATION_GAS_LIMIT, CALL_GAS_LIMIT), + preVerificationGas: 0, + gasFees: _encodeGas(1, 1), + paymasterAndData: "", + signature: "" + }); + + bytes32 userOpHash = entryPoint.getUserOpHash(userOp); + + return (userOp, userOpHash); + } +} diff --git a/test/account/UpgradeableModularAccount.t.sol b/test/account/UpgradeableModularAccount.t.sol index 46ea1d5f..3f386a60 100644 --- a/test/account/UpgradeableModularAccount.t.sol +++ b/test/account/UpgradeableModularAccount.t.sol @@ -87,7 +87,8 @@ contract UpgradeableModularAccountTest is AccountTestBase { // Generate signature bytes32 userOpHash = entryPoint.getUserOpHash(userOp); (uint8 v, bytes32 r, bytes32 s) = vm.sign(owner1Key, userOpHash.toEthSignedMessageHash()); - userOp.signature = abi.encodePacked(ownerValidation, SELECTOR_ASSOCIATED_VALIDATION, r, s, v); + userOp.signature = + _encodeSignature(ownerValidation, SELECTOR_ASSOCIATED_VALIDATION, abi.encodePacked(r, s, v)); PackedUserOperation[] memory userOps = new PackedUserOperation[](1); userOps[0] = userOp; @@ -116,7 +117,8 @@ contract UpgradeableModularAccountTest is AccountTestBase { // Generate signature bytes32 userOpHash = entryPoint.getUserOpHash(userOp); (uint8 v, bytes32 r, bytes32 s) = vm.sign(owner2Key, userOpHash.toEthSignedMessageHash()); - userOp.signature = abi.encodePacked(ownerValidation, SELECTOR_ASSOCIATED_VALIDATION, r, s, v); + userOp.signature = + _encodeSignature(ownerValidation, SELECTOR_ASSOCIATED_VALIDATION, abi.encodePacked(r, s, v)); PackedUserOperation[] memory userOps = new PackedUserOperation[](1); userOps[0] = userOp; @@ -142,7 +144,8 @@ contract UpgradeableModularAccountTest is AccountTestBase { // Generate signature bytes32 userOpHash = entryPoint.getUserOpHash(userOp); (uint8 v, bytes32 r, bytes32 s) = vm.sign(owner2Key, userOpHash.toEthSignedMessageHash()); - userOp.signature = abi.encodePacked(ownerValidation, SELECTOR_ASSOCIATED_VALIDATION, r, s, v); + userOp.signature = + _encodeSignature(ownerValidation, SELECTOR_ASSOCIATED_VALIDATION, abi.encodePacked(r, s, v)); PackedUserOperation[] memory userOps = new PackedUserOperation[](1); userOps[0] = userOp; @@ -168,7 +171,8 @@ contract UpgradeableModularAccountTest is AccountTestBase { // Generate signature bytes32 userOpHash = entryPoint.getUserOpHash(userOp); (uint8 v, bytes32 r, bytes32 s) = vm.sign(owner1Key, userOpHash.toEthSignedMessageHash()); - userOp.signature = abi.encodePacked(ownerValidation, SELECTOR_ASSOCIATED_VALIDATION, r, s, v); + userOp.signature = + _encodeSignature(ownerValidation, SELECTOR_ASSOCIATED_VALIDATION, abi.encodePacked(r, s, v)); PackedUserOperation[] memory userOps = new PackedUserOperation[](1); userOps[0] = userOp; @@ -196,7 +200,8 @@ contract UpgradeableModularAccountTest is AccountTestBase { // Generate signature bytes32 userOpHash = entryPoint.getUserOpHash(userOp); (uint8 v, bytes32 r, bytes32 s) = vm.sign(owner1Key, userOpHash.toEthSignedMessageHash()); - userOp.signature = abi.encodePacked(ownerValidation, SELECTOR_ASSOCIATED_VALIDATION, r, s, v); + userOp.signature = + _encodeSignature(ownerValidation, SELECTOR_ASSOCIATED_VALIDATION, abi.encodePacked(r, s, v)); PackedUserOperation[] memory userOps = new PackedUserOperation[](1); userOps[0] = userOp; @@ -227,7 +232,8 @@ contract UpgradeableModularAccountTest is AccountTestBase { // Generate signature bytes32 userOpHash = entryPoint.getUserOpHash(userOp); (uint8 v, bytes32 r, bytes32 s) = vm.sign(owner1Key, userOpHash.toEthSignedMessageHash()); - userOp.signature = abi.encodePacked(ownerValidation, SELECTOR_ASSOCIATED_VALIDATION, r, s, v); + userOp.signature = + _encodeSignature(ownerValidation, SELECTOR_ASSOCIATED_VALIDATION, abi.encodePacked(r, s, v)); PackedUserOperation[] memory userOps = new PackedUserOperation[](1); userOps[0] = userOp; diff --git a/test/account/ValidationIntersection.t.sol b/test/account/ValidationIntersection.t.sol index 6f451a16..faa24074 100644 --- a/test/account/ValidationIntersection.t.sol +++ b/test/account/ValidationIntersection.t.sol @@ -107,7 +107,7 @@ contract ValidationIntersectionTest is AccountTestBase { PackedUserOperation memory userOp; userOp.callData = bytes.concat(noHookPlugin.foo.selector); - userOp.signature = abi.encodePacked(noHookValidation, SELECTOR_ASSOCIATED_VALIDATION); + userOp.signature = _encodeSignature(noHookValidation, SELECTOR_ASSOCIATED_VALIDATION, ""); bytes32 uoHash = entryPoint.getUserOpHash(userOp); vm.prank(address(entryPoint)); @@ -124,7 +124,7 @@ contract ValidationIntersectionTest is AccountTestBase { PackedUserOperation memory userOp; userOp.callData = bytes.concat(oneHookPlugin.bar.selector); - userOp.signature = abi.encodePacked(oneHookValidation, SELECTOR_ASSOCIATED_VALIDATION); + userOp.signature = _encodeSignature(oneHookValidation, SELECTOR_ASSOCIATED_VALIDATION, ""); bytes32 uoHash = entryPoint.getUserOpHash(userOp); vm.prank(address(entryPoint)); @@ -142,7 +142,7 @@ contract ValidationIntersectionTest is AccountTestBase { PackedUserOperation memory userOp; userOp.callData = bytes.concat(oneHookPlugin.bar.selector); - userOp.signature = abi.encodePacked(oneHookValidation, SELECTOR_ASSOCIATED_VALIDATION); + userOp.signature = _encodeSignature(oneHookValidation, SELECTOR_ASSOCIATED_VALIDATION, ""); bytes32 uoHash = entryPoint.getUserOpHash(userOp); vm.prank(address(entryPoint)); @@ -165,7 +165,7 @@ contract ValidationIntersectionTest is AccountTestBase { PackedUserOperation memory userOp; userOp.callData = bytes.concat(oneHookPlugin.bar.selector); - userOp.signature = abi.encodePacked(oneHookValidation, SELECTOR_ASSOCIATED_VALIDATION); + userOp.signature = _encodeSignature(oneHookValidation, SELECTOR_ASSOCIATED_VALIDATION, ""); bytes32 uoHash = entryPoint.getUserOpHash(userOp); vm.prank(address(entryPoint)); @@ -187,7 +187,7 @@ contract ValidationIntersectionTest is AccountTestBase { PackedUserOperation memory userOp; userOp.callData = bytes.concat(oneHookPlugin.bar.selector); - userOp.signature = abi.encodePacked(oneHookValidation, SELECTOR_ASSOCIATED_VALIDATION); + userOp.signature = _encodeSignature(oneHookValidation, SELECTOR_ASSOCIATED_VALIDATION, ""); bytes32 uoHash = entryPoint.getUserOpHash(userOp); vm.prank(address(entryPoint)); @@ -207,7 +207,7 @@ contract ValidationIntersectionTest is AccountTestBase { PackedUserOperation memory userOp; userOp.callData = bytes.concat(oneHookPlugin.bar.selector); - userOp.signature = abi.encodePacked(oneHookValidation, SELECTOR_ASSOCIATED_VALIDATION); + userOp.signature = _encodeSignature(oneHookValidation, SELECTOR_ASSOCIATED_VALIDATION, ""); bytes32 uoHash = entryPoint.getUserOpHash(userOp); vm.prank(address(entryPoint)); @@ -232,7 +232,7 @@ contract ValidationIntersectionTest is AccountTestBase { PackedUserOperation memory userOp; userOp.callData = bytes.concat(oneHookPlugin.bar.selector); - userOp.signature = abi.encodePacked(oneHookValidation, SELECTOR_ASSOCIATED_VALIDATION); + userOp.signature = _encodeSignature(oneHookValidation, SELECTOR_ASSOCIATED_VALIDATION, ""); bytes32 uoHash = entryPoint.getUserOpHash(userOp); vm.prank(address(entryPoint)); @@ -256,7 +256,7 @@ contract ValidationIntersectionTest is AccountTestBase { PackedUserOperation memory userOp; userOp.callData = bytes.concat(oneHookPlugin.bar.selector); - userOp.signature = abi.encodePacked(oneHookValidation, SELECTOR_ASSOCIATED_VALIDATION); + userOp.signature = _encodeSignature(oneHookValidation, SELECTOR_ASSOCIATED_VALIDATION, ""); bytes32 uoHash = entryPoint.getUserOpHash(userOp); vm.prank(address(entryPoint)); @@ -280,7 +280,7 @@ contract ValidationIntersectionTest is AccountTestBase { PackedUserOperation memory userOp; userOp.callData = bytes.concat(twoHookPlugin.baz.selector); - userOp.signature = abi.encodePacked(twoHookValidation, SELECTOR_ASSOCIATED_VALIDATION); + userOp.signature = _encodeSignature(twoHookValidation, SELECTOR_ASSOCIATED_VALIDATION, ""); bytes32 uoHash = entryPoint.getUserOpHash(userOp); vm.prank(address(entryPoint)); @@ -299,7 +299,7 @@ contract ValidationIntersectionTest is AccountTestBase { PackedUserOperation memory userOp; userOp.callData = bytes.concat(twoHookPlugin.baz.selector); - userOp.signature = abi.encodePacked(twoHookValidation, SELECTOR_ASSOCIATED_VALIDATION); + userOp.signature = _encodeSignature(twoHookValidation, SELECTOR_ASSOCIATED_VALIDATION, ""); bytes32 uoHash = entryPoint.getUserOpHash(userOp); vm.prank(address(entryPoint)); diff --git a/test/libraries/SparseCalldataSegmentLib.t.sol b/test/libraries/SparseCalldataSegmentLib.t.sol new file mode 100644 index 00000000..7edea4e4 --- /dev/null +++ b/test/libraries/SparseCalldataSegmentLib.t.sol @@ -0,0 +1,111 @@ +// SPDX-License-Identifier: UNLICENSED +pragma solidity ^0.8.25; + +import {Test} from "forge-std/Test.sol"; + +import {SparseCalldataSegmentLib} from "../../src/helpers/SparseCalldataSegmentLib.sol"; + +contract SparseCalldataSegmentLibTest is Test { + using SparseCalldataSegmentLib for bytes; + + function testFuzz_sparseCalldataSegmentLib_encodeDecode_simple(bytes[] memory segments) public { + bytes memory encoded = _encodeSimple(segments); + bytes[] memory decoded = this.decodeSimple(encoded, segments.length); + + assertEq(decoded.length, segments.length, "decoded.length != segments.length"); + + for (uint256 i = 0; i < segments.length; i++) { + assertEq(decoded[i], segments[i]); + } + } + + function testFuzz_sparseCalldataSegmentLib_encodeDecode_withIndex(bytes[] memory segments, uint256 indexSeed) + public + { + // Generate random indices + uint8[] memory indices = new uint8[](segments.length); + for (uint256 i = 0; i < segments.length; i++) { + uint8 nextIndex = uint8(uint256(keccak256(abi.encodePacked(indexSeed, i)))); + indices[i] = nextIndex; + } + + // Encode + bytes memory encoded = _encodeWithIndex(segments, indices); + + // Decode + (bytes[] memory decodedBodies, uint8[] memory decodedIndices) = + this.decodeWithIndex(encoded, segments.length); + + assertEq(decodedBodies.length, segments.length, "decodedBodies.length != segments.length"); + assertEq(decodedIndices.length, segments.length, "decodedIndices.length != segments.length"); + + for (uint256 i = 0; i < segments.length; i++) { + assertEq(decodedBodies[i], segments[i]); + assertEq(decodedIndices[i], indices[i]); + } + } + + function _encodeSimple(bytes[] memory segments) internal pure returns (bytes memory) { + bytes memory result = ""; + + for (uint256 i = 0; i < segments.length; i++) { + result = abi.encodePacked(result, uint32(segments[i].length), segments[i]); + } + + return result; + } + + function _encodeWithIndex(bytes[] memory segments, uint8[] memory indices) + internal + pure + returns (bytes memory) + { + require(segments.length == indices.length, "segments len != indices len"); + + bytes memory result = ""; + + for (uint256 i = 0; i < segments.length; i++) { + result = abi.encodePacked(result, uint32(segments[i].length + 1), indices[i], segments[i]); + } + + return result; + } + + function decodeSimple(bytes calldata encoded, uint256 capacityHint) external pure returns (bytes[] memory) { + bytes[] memory result = new bytes[](capacityHint); + + bytes calldata remainder = encoded; + + uint256 index = 0; + while (remainder.length > 0) { + bytes calldata segment; + (segment, remainder) = remainder.getNextSegment(); + result[index] = segment; + index++; + } + + return result; + } + + function decodeWithIndex(bytes calldata encoded, uint256 capacityHint) + external + pure + returns (bytes[] memory, uint8[] memory) + { + bytes[] memory bodies = new bytes[](capacityHint); + uint8[] memory indices = new uint8[](capacityHint); + + bytes calldata remainder = encoded; + + uint256 index = 0; + while (remainder.length > 0) { + bytes calldata segment; + (segment, remainder) = remainder.getNextSegment(); + bodies[index] = segment.getBody(); + indices[index] = segment.getIndex(); + index++; + } + + return (bodies, indices); + } +} diff --git a/test/mocks/plugins/ComprehensivePlugin.sol b/test/mocks/plugins/ComprehensivePlugin.sol index 6ef654c7..4062218b 100644 --- a/test/mocks/plugins/ComprehensivePlugin.sol +++ b/test/mocks/plugins/ComprehensivePlugin.sol @@ -74,7 +74,11 @@ contract ComprehensivePlugin is IValidation, IValidationHook, IExecutionHook, Ba revert NotImplemented(); } - function preRuntimeValidationHook(uint8 functionId, address, uint256, bytes calldata) external pure override { + function preRuntimeValidationHook(uint8 functionId, address, uint256, bytes calldata, bytes calldata) + external + pure + override + { if (functionId == uint8(FunctionId.PRE_VALIDATION_HOOK_1)) { return; } else if (functionId == uint8(FunctionId.PRE_VALIDATION_HOOK_2)) { diff --git a/test/mocks/plugins/MockAccessControlHookPlugin.sol b/test/mocks/plugins/MockAccessControlHookPlugin.sol new file mode 100644 index 00000000..c17868a8 --- /dev/null +++ b/test/mocks/plugins/MockAccessControlHookPlugin.sol @@ -0,0 +1,78 @@ +// SPDX-License-Identifier: UNLICENSED +pragma solidity ^0.8.25; + +import {PackedUserOperation} from "@eth-infinitism/account-abstraction/interfaces/PackedUserOperation.sol"; + +import {PluginMetadata, PluginManifest} from "../../../src/interfaces/IPlugin.sol"; +import {IValidationHook} from "../../../src/interfaces/IValidationHook.sol"; +import {IStandardExecutor} from "../../../src/interfaces/IStandardExecutor.sol"; +import {BasePlugin} from "../../../src/plugins/BasePlugin.sol"; + +// A pre validaiton hook plugin that uses per-hook data. +// This example enforces that the target of an `execute` call must only be the previously specified address. +// This is just a mock - it does not enforce this over `executeBatch` and other methods of making calls, and should +// not be used in production.. +contract MockAccessControlHookPlugin is IValidationHook, BasePlugin { + enum FunctionId { + PRE_VALIDATION_HOOK + } + + mapping(address account => address allowedTarget) public allowedTargets; + + function onInstall(bytes calldata data) external override { + address allowedTarget = abi.decode(data, (address)); + allowedTargets[msg.sender] = allowedTarget; + } + + function onUninstall(bytes calldata) external override { + delete allowedTargets[msg.sender]; + } + + function preUserOpValidationHook(uint8 functionId, PackedUserOperation calldata userOp, bytes32) + external + view + override + returns (uint256) + { + if (functionId == uint8(FunctionId.PRE_VALIDATION_HOOK)) { + if (bytes4(userOp.callData[:4]) == IStandardExecutor.execute.selector) { + address target = abi.decode(userOp.callData[4:36], (address)); + + // Simulate a merkle proof - require that the target address is also provided in the signature + address proof = address(bytes20(userOp.signature)); + require(proof == target, "Proof doesn't match target"); + require(target == allowedTargets[msg.sender], "Target not allowed"); + return 0; + } + } + revert NotImplemented(); + } + + function preRuntimeValidationHook( + uint8 functionId, + address, + uint256, + bytes calldata data, + bytes calldata authorization + ) external view override { + if (functionId == uint8(FunctionId.PRE_VALIDATION_HOOK)) { + if (bytes4(data[:4]) == IStandardExecutor.execute.selector) { + address target = abi.decode(data[4:36], (address)); + + // Simulate a merkle proof - require that the target address is also provided in the authorization + // data + address proof = address(bytes20(authorization)); + require(proof == target, "Proof doesn't match target"); + require(target == allowedTargets[msg.sender], "Target not allowed"); + + return; + } + } + + revert NotImplemented(); + } + + function pluginMetadata() external pure override returns (PluginMetadata memory) {} + + function pluginManifest() external pure override returns (PluginManifest memory) {} +} diff --git a/test/mocks/plugins/ReturnDataPluginMocks.sol b/test/mocks/plugins/ReturnDataPluginMocks.sol index 485b8456..c0fc0cfe 100644 --- a/test/mocks/plugins/ReturnDataPluginMocks.sol +++ b/test/mocks/plugins/ReturnDataPluginMocks.sol @@ -101,7 +101,8 @@ contract ResultConsumerPlugin is BasePlugin, IValidation { // This result should be allowed based on the manifest permission request bytes memory returnData = IStandardExecutor(msg.sender).executeWithAuthorization( abi.encodeCall(IStandardExecutor.execute, (target, 0, abi.encodeCall(RegularResultContract.foo, ()))), - abi.encodePacked(this, uint8(0), uint8(0)) // Validation function of self, selector-associated + abi.encodePacked(this, uint8(0), uint8(0), uint32(1), uint8(255)) // Validation function of self, + // selector-associated, with no auth data ); bytes32 actual = abi.decode(abi.decode(returnData, (bytes)), (bytes32)); diff --git a/test/mocks/plugins/ValidationPluginMocks.sol b/test/mocks/plugins/ValidationPluginMocks.sol index d5f75e99..f6ed4a5f 100644 --- a/test/mocks/plugins/ValidationPluginMocks.sol +++ b/test/mocks/plugins/ValidationPluginMocks.sol @@ -67,7 +67,11 @@ abstract contract MockBaseUserOpValidationPlugin is IValidation, IValidationHook // Empty stubs function pluginMetadata() external pure override returns (PluginMetadata memory) {} - function preRuntimeValidationHook(uint8, address, uint256, bytes calldata) external pure override { + function preRuntimeValidationHook(uint8, address, uint256, bytes calldata, bytes calldata) + external + pure + override + { revert NotImplemented(); } diff --git a/test/utils/AccountTestBase.sol b/test/utils/AccountTestBase.sol index 059e9cac..736b6041 100644 --- a/test/utils/AccountTestBase.sol +++ b/test/utils/AccountTestBase.sol @@ -3,6 +3,7 @@ pragma solidity ^0.8.19; import {EntryPoint} from "@eth-infinitism/account-abstraction/core/EntryPoint.sol"; +import {FunctionReference, FunctionReferenceLib} from "../../src/helpers/FunctionReferenceLib.sol"; import {UpgradeableModularAccount} from "../../src/account/UpgradeableModularAccount.sol"; import {ISingleOwnerPlugin} from "../../src/plugins/owner/ISingleOwnerPlugin.sol"; import {SingleOwnerPlugin} from "../../src/plugins/owner/SingleOwnerPlugin.sol"; @@ -14,6 +15,8 @@ import {MSCAFactoryFixture} from "../mocks/MSCAFactoryFixture.sol"; /// @dev This contract handles common boilerplate setup for tests using UpgradeableModularAccount with /// SingleOwnerPlugin. abstract contract AccountTestBase is OptimizedTest { + using FunctionReferenceLib for FunctionReference; + EntryPoint public entryPoint; address payable public beneficiary; SingleOwnerPlugin public singleOwnerPlugin; @@ -26,6 +29,11 @@ abstract contract AccountTestBase is OptimizedTest { uint8 public constant SELECTOR_ASSOCIATED_VALIDATION = 0; uint8 public constant DEFAULT_VALIDATION = 1; + struct PreValidationHookData { + uint8 index; + bytes validationData; + } + constructor() { entryPoint = new EntryPoint(); (owner1, owner1Key) = makeAddrAndKey("owner1"); @@ -50,10 +58,12 @@ abstract contract AccountTestBase is OptimizedTest { abi.encodeCall(SingleOwnerPlugin.transferOwnership, (address(this))) ) ), - abi.encodePacked( - address(singleOwnerPlugin), - ISingleOwnerPlugin.FunctionId.VALIDATION_OWNER, - SELECTOR_ASSOCIATED_VALIDATION + _encodeSignature( + FunctionReferenceLib.pack( + address(singleOwnerPlugin), uint8(ISingleOwnerPlugin.FunctionId.VALIDATION_OWNER) + ), + SELECTOR_ASSOCIATED_VALIDATION, + "" ) ); } @@ -62,4 +72,48 @@ abstract contract AccountTestBase is OptimizedTest { function _encodeGas(uint256 g1, uint256 g2) internal pure returns (bytes32) { return bytes32(uint256((g1 << 128) + uint128(g2))); } + + // helper function to encode a signature, according to the per-hook and per-validation data format. + function _encodeSignature( + FunctionReference validationFunction, + uint8 defaultOrNot, + PreValidationHookData[] memory preValidationHookData, + bytes memory validationData + ) internal pure returns (bytes memory) { + bytes memory sig = abi.encodePacked(validationFunction, defaultOrNot); + + for (uint256 i = 0; i < preValidationHookData.length; ++i) { + sig = abi.encodePacked( + sig, + _packValidationDataWithIndex( + preValidationHookData[i].index, preValidationHookData[i].validationData + ) + ); + } + + // Index of the actual validation data is the length of the preValidationHooksRetrieved - aka + // one-past-the-end + sig = abi.encodePacked(sig, _packValidationDataWithIndex(255, validationData)); + + return sig; + } + + // overload for the case where there are no pre-validation hooks + function _encodeSignature( + FunctionReference validationFunction, + uint8 defaultOrNot, + bytes memory validationData + ) internal pure returns (bytes memory) { + PreValidationHookData[] memory emptyPreValidationHookData = new PreValidationHookData[](0); + return _encodeSignature(validationFunction, defaultOrNot, emptyPreValidationHookData, validationData); + } + + // helper function to pack validation data with an index, according to the sparse calldata segment spec. + function _packValidationDataWithIndex(uint8 index, bytes memory validationData) + internal + pure + returns (bytes memory) + { + return abi.encodePacked(uint32(validationData.length + 1), index, validationData); + } }