diff --git a/src/account/UpgradeableModularAccount.sol b/src/account/UpgradeableModularAccount.sol index 410c8610..04c133e3 100644 --- a/src/account/UpgradeableModularAccount.sol +++ b/src/account/UpgradeableModularAccount.sol @@ -19,7 +19,7 @@ import {SparseCalldataSegmentLib} from "../helpers/SparseCalldataSegmentLib.sol" import {ValidationConfigLib} from "../helpers/ValidationConfigLib.sol"; import {_coalescePreValidation, _coalesceValidation} from "../helpers/ValidationResHelpers.sol"; -import {DIRECT_CALL_VALIDATION_ENTITYID, RESERVED_VALIDATION_DATA_INDEX} from "../helpers/Constants.sol"; +import {DIRECT_CALL_VALIDATION_ENTITYID} from "../helpers/Constants.sol"; import {IExecutionHookModule} from "../interfaces/IExecutionHookModule.sol"; import {ExecutionManifest} from "../interfaces/IExecutionModule.sol"; @@ -67,7 +67,6 @@ contract UpgradeableModularAccount is bytes4 internal constant _1271_MAGIC_VALUE = 0x1626ba7e; bytes4 internal constant _1271_INVALID = 0xffffffff; - error NonCanonicalEncoding(); error NotEntryPoint(); error PostExecHookReverted(address module, uint32 entityId, bytes revertReason); error PreExecHookReverted(address module, uint32 entityId, bytes revertReason); @@ -79,8 +78,6 @@ contract UpgradeableModularAccount is error UnexpectedAggregator(address module, uint32 entityId, address aggregator); error UnrecognizedFunction(bytes4 selector); error ValidationFunctionMissing(bytes4 selector); - error ValidationSignatureSegmentMissing(); - error SignatureSegmentOutOfOrder(); // Wraps execution of a native function with runtime validation and hooks // Used for upgradeTo, upgradeToAndCall, execute, executeBatch, installExecution, uninstallExecution @@ -347,10 +344,6 @@ contract UpgradeableModularAccount is bytes calldata signature, bytes32 userOpHash ) internal returns (uint256) { - // Set up the per-hook data tracking fields - bytes calldata signatureSegment; - (signatureSegment, signature) = signature.getNextSegment(); - uint256 validationRes; // Do preUserOpValidation hooks @@ -358,25 +351,7 @@ contract UpgradeableModularAccount is getAccountStorage().validationData[userOpValidationFunction].preValidationHooks; for (uint256 i = 0; i < preUserOpValidationHooks.length; ++i) { - // Load per-hook data, if any is present - // The segment index is the first byte of the signature - if (signatureSegment.getIndex() == i) { - // Use the current segment - userOp.signature = signatureSegment.getBody(); - - if (userOp.signature.length == 0) { - revert NonCanonicalEncoding(); - } - - // Load the next per-hook data segment - (signatureSegment, signature) = signature.getNextSegment(); - - if (signatureSegment.getIndex() <= i) { - revert SignatureSegmentOutOfOrder(); - } - } else { - userOp.signature = ""; - } + (userOp.signature, signature) = signature.advanceSegmentIfAtIndex(uint8(i)); (address module, uint32 entityId) = preUserOpValidationHooks[i].unpack(); uint256 currentValidationRes = @@ -389,13 +364,9 @@ contract UpgradeableModularAccount is validationRes = _coalescePreValidation(validationRes, currentValidationRes); } - // Run the user op validationFunction + // Run the user op validation function { - if (signatureSegment.getIndex() != RESERVED_VALIDATION_DATA_INDEX) { - revert ValidationSignatureSegmentMissing(); - } - - userOp.signature = signatureSegment.getBody(); + userOp.signature = signature.getFinalSegment(); uint256 currentValidationRes = _execUserOpValidation(userOpValidationFunction, userOp, userOpHash); @@ -415,42 +386,21 @@ contract UpgradeableModularAccount is bytes calldata callData, bytes calldata authorizationData ) internal { - // Set up the per-hook data tracking fields - bytes calldata authSegment; - (authSegment, authorizationData) = authorizationData.getNextSegment(); - // run all preRuntimeValidation hooks ModuleEntity[] memory preRuntimeValidationHooks = getAccountStorage().validationData[runtimeValidationFunction].preValidationHooks; for (uint256 i = 0; i < preRuntimeValidationHooks.length; ++i) { - bytes memory currentAuthData; - - if (authSegment.getIndex() == i) { - // Use the current segment - currentAuthData = authSegment.getBody(); - - if (currentAuthData.length == 0) { - revert NonCanonicalEncoding(); - } + bytes memory currentAuthSegment; - // Load the next per-hook data segment - (authSegment, authorizationData) = authorizationData.getNextSegment(); + (currentAuthSegment, authorizationData) = authorizationData.advanceSegmentIfAtIndex(uint8(i)); - if (authSegment.getIndex() <= i) { - revert SignatureSegmentOutOfOrder(); - } - } else { - currentAuthData = ""; - } - _doPreRuntimeValidationHook(preRuntimeValidationHooks[i], callData, currentAuthData); + _doPreRuntimeValidationHook(preRuntimeValidationHooks[i], callData, currentAuthSegment); } - if (authSegment.getIndex() != RESERVED_VALIDATION_DATA_INDEX) { - revert ValidationSignatureSegmentMissing(); - } + authorizationData = authorizationData.getFinalSegment(); - _execRuntimeValidation(runtimeValidationFunction, callData, authSegment.getBody()); + _execRuntimeValidation(runtimeValidationFunction, callData, authorizationData); } function _doPreHooks(EnumerableSet.Bytes32Set storage executionHooks, bytes memory data) diff --git a/src/helpers/SparseCalldataSegmentLib.sol b/src/helpers/SparseCalldataSegmentLib.sol index 0a6cc541..e6711c86 100644 --- a/src/helpers/SparseCalldataSegmentLib.sol +++ b/src/helpers/SparseCalldataSegmentLib.sol @@ -1,12 +1,18 @@ // SPDX-License-Identifier: GPL-3.0 pragma solidity ^0.8.25; +import {RESERVED_VALIDATION_DATA_INDEX} from "./Constants.sol"; + /// @title Sparse Calldata Segment Library /// @notice Library for working with sparsely-packed calldata segments, identified with an index. /// @dev The first byte of each segment is the index of the segment. /// To prevent accidental stack-to-deep errors, the body and index of the segment are extracted separately, rather /// than inline as part of the tuple returned by `getNextSegment`. library SparseCalldataSegmentLib { + error NonCanonicalEncoding(); + error SegmentOutOfOrder(); + error ValidationSignatureSegmentMissing(); + /// @notice Splits out a segment of calldata, sparsely-packed. /// The expected format is: /// [uint32(len(segment0)), segment0, uint32(len(segment1)), segment1, ... uint32(len(segmentN)), segmentN] @@ -33,6 +39,61 @@ library SparseCalldataSegmentLib { remainder = source[remainderOffset:]; } + /// @notice If the index of the next segment in the source equals the provided index, return the next body and + /// advance the source by one segment. + /// @dev Reverts if the index of the next segment is less than the provided index, or if the extracted segment + /// has length 0. + /// @param source The calldata to extract the segment from. + /// @param index The index of the segment to extract. + /// @return A tuple containing the extracted segment's body, or an empty buffer if the index is not found, and + /// the remaining calldata. + function advanceSegmentIfAtIndex(bytes calldata source, uint8 index) + internal + pure + returns (bytes memory, bytes calldata) + { + uint8 nextIndex = peekIndex(source); + + if (nextIndex < index) { + revert SegmentOutOfOrder(); + } + + if (nextIndex == index) { + (bytes calldata segment, bytes calldata remainder) = getNextSegment(source); + + segment = getBody(segment); + + if (segment.length == 0) { + revert NonCanonicalEncoding(); + } + + return (segment, remainder); + } + + return ("", source); + } + + function getFinalSegment(bytes calldata source) internal pure returns (bytes calldata) { + (bytes calldata segment, bytes calldata remainder) = getNextSegment(source); + + if (getIndex(segment) != RESERVED_VALIDATION_DATA_INDEX) { + revert ValidationSignatureSegmentMissing(); + } + + if (remainder.length != 0) { + revert NonCanonicalEncoding(); + } + + return getBody(segment); + } + + /// @notice Returns the index of the next segment in the source. + /// @param source The calldata to extract the index from. + /// @return The index of the next segment. + function peekIndex(bytes calldata source) internal pure returns (uint8) { + return uint8(source[4]); + } + /// @notice Extracts the index from a segment. /// @dev The first byte of the segment is the index. /// @param segment The segment to extract the index from diff --git a/test/account/PerHookData.t.sol b/test/account/PerHookData.t.sol index 2be42f86..83941df9 100644 --- a/test/account/PerHookData.t.sol +++ b/test/account/PerHookData.t.sol @@ -9,6 +9,7 @@ import {UpgradeableModularAccount} from "../../src/account/UpgradeableModularAcc import {HookConfigLib} from "../../src/helpers/HookConfigLib.sol"; import {ModuleEntity, ModuleEntityLib} from "../../src/helpers/ModuleEntityLib.sol"; +import {SparseCalldataSegmentLib} from "../../src/helpers/SparseCalldataSegmentLib.sol"; import {Counter} from "../mocks/Counter.sol"; import {MockAccessControlHookModule} from "../mocks/modules/MockAccessControlHookModule.sol"; @@ -123,7 +124,7 @@ contract PerHookDataTest is CustomValidationTestBase { IEntryPoint.FailedOpWithRevert.selector, 0, "AA23 reverted", - abi.encodeWithSelector(UpgradeableModularAccount.ValidationSignatureSegmentMissing.selector) + abi.encodeWithSelector(SparseCalldataSegmentLib.ValidationSignatureSegmentMissing.selector) ) ); entryPoint.handleOps(userOps, beneficiary); @@ -187,7 +188,35 @@ contract PerHookDataTest is CustomValidationTestBase { IEntryPoint.FailedOpWithRevert.selector, 0, "AA23 reverted", - abi.encodeWithSelector(UpgradeableModularAccount.NonCanonicalEncoding.selector) + abi.encodeWithSelector(SparseCalldataSegmentLib.NonCanonicalEncoding.selector) + ) + ); + entryPoint.handleOps(userOps, beneficiary); + } + + function test_failPerHookData_excessData_userOp() public { + (PackedUserOperation memory userOp, bytes32 userOpHash) = _getCounterUserOP(); + (uint8 v, bytes32 r, bytes32 s) = vm.sign(owner1Key, userOpHash.toEthSignedMessageHash()); + + PreValidationHookData[] memory preValidationHookData = new PreValidationHookData[](1); + preValidationHookData[0] = PreValidationHookData({index: 0, validationData: abi.encodePacked(_counter)}); + + userOp.signature = abi.encodePacked( + _encodeSignature( + _signerValidation, GLOBAL_VALIDATION, preValidationHookData, abi.encodePacked(r, s, v) + ), + "extra data" + ); + + PackedUserOperation[] memory userOps = new PackedUserOperation[](1); + userOps[0] = userOp; + + vm.expectRevert( + abi.encodeWithSelector( + IEntryPoint.FailedOpWithRevert.selector, + 0, + "AA23 reverted", + abi.encodeWithSelector(SparseCalldataSegmentLib.NonCanonicalEncoding.selector) ) ); entryPoint.handleOps(userOps, beneficiary); @@ -262,7 +291,7 @@ contract PerHookDataTest is CustomValidationTestBase { vm.prank(owner1); vm.expectRevert( - abi.encodeWithSelector(UpgradeableModularAccount.ValidationSignatureSegmentMissing.selector) + abi.encodeWithSelector(SparseCalldataSegmentLib.ValidationSignatureSegmentMissing.selector) ); account1.executeWithAuthorization( abi.encodeCall( @@ -299,7 +328,7 @@ contract PerHookDataTest is CustomValidationTestBase { preValidationHookData[0] = PreValidationHookData({index: 0, validationData: ""}); vm.prank(owner1); - vm.expectRevert(abi.encodeWithSelector(UpgradeableModularAccount.NonCanonicalEncoding.selector)); + vm.expectRevert(abi.encodeWithSelector(SparseCalldataSegmentLib.NonCanonicalEncoding.selector)); account1.executeWithAuthorization( abi.encodeCall( UpgradeableModularAccount.execute, @@ -309,6 +338,23 @@ contract PerHookDataTest is CustomValidationTestBase { ); } + function test_failPerHookData_excessData_runtime() public { + PreValidationHookData[] memory preValidationHookData = new PreValidationHookData[](1); + preValidationHookData[0] = PreValidationHookData({index: 0, validationData: abi.encodePacked(_counter)}); + + vm.prank(owner1); + vm.expectRevert(abi.encodeWithSelector(SparseCalldataSegmentLib.NonCanonicalEncoding.selector)); + account1.executeWithAuthorization( + abi.encodeCall( + UpgradeableModularAccount.execute, + (address(_counter), 0 wei, abi.encodeCall(Counter.increment, ())) + ), + abi.encodePacked( + _encodeSignature(_signerValidation, GLOBAL_VALIDATION, preValidationHookData, ""), "extra data" + ) + ); + } + function _getCounterUserOP() internal view returns (PackedUserOperation memory, bytes32) { PackedUserOperation memory userOp = PackedUserOperation({ sender: address(account1),