Skip to content

Commit

Permalink
refactor: merge segment collection logic
Browse files Browse the repository at this point in the history
  • Loading branch information
adamegyed committed Aug 20, 2024
1 parent 98507f8 commit d042a7d
Show file tree
Hide file tree
Showing 3 changed files with 120 additions and 63 deletions.
68 changes: 9 additions & 59 deletions src/account/UpgradeableModularAccount.sol
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ import {SparseCalldataSegmentLib} from "../helpers/SparseCalldataSegmentLib.sol"
import {ValidationConfigLib} from "../helpers/ValidationConfigLib.sol";
import {_coalescePreValidation, _coalesceValidation} from "../helpers/ValidationResHelpers.sol";

import {DIRECT_CALL_VALIDATION_ENTITYID, RESERVED_VALIDATION_DATA_INDEX} from "../helpers/Constants.sol";
import {DIRECT_CALL_VALIDATION_ENTITYID} from "../helpers/Constants.sol";

import {IExecutionHookModule} from "../interfaces/IExecutionHookModule.sol";
import {ExecutionManifest} from "../interfaces/IExecutionModule.sol";
Expand Down Expand Up @@ -67,7 +67,6 @@ contract UpgradeableModularAccount is
bytes4 internal constant _1271_MAGIC_VALUE = 0x1626ba7e;
bytes4 internal constant _1271_INVALID = 0xffffffff;

error NonCanonicalEncoding();
error NotEntryPoint();
error PostExecHookReverted(address module, uint32 entityId, bytes revertReason);
error PreExecHookReverted(address module, uint32 entityId, bytes revertReason);
Expand All @@ -79,8 +78,6 @@ contract UpgradeableModularAccount is
error UnexpectedAggregator(address module, uint32 entityId, address aggregator);
error UnrecognizedFunction(bytes4 selector);
error ValidationFunctionMissing(bytes4 selector);
error ValidationSignatureSegmentMissing();
error SignatureSegmentOutOfOrder();

// Wraps execution of a native function with runtime validation and hooks
// Used for upgradeTo, upgradeToAndCall, execute, executeBatch, installExecution, uninstallExecution
Expand Down Expand Up @@ -347,36 +344,14 @@ contract UpgradeableModularAccount is
bytes calldata signature,
bytes32 userOpHash
) internal returns (uint256) {
// Set up the per-hook data tracking fields
bytes calldata signatureSegment;
(signatureSegment, signature) = signature.getNextSegment();

uint256 validationRes;

// Do preUserOpValidation hooks
ModuleEntity[] memory preUserOpValidationHooks =
getAccountStorage().validationData[userOpValidationFunction].preValidationHooks;

for (uint256 i = 0; i < preUserOpValidationHooks.length; ++i) {
// Load per-hook data, if any is present
// The segment index is the first byte of the signature
if (signatureSegment.getIndex() == i) {
// Use the current segment
userOp.signature = signatureSegment.getBody();

if (userOp.signature.length == 0) {
revert NonCanonicalEncoding();
}

// Load the next per-hook data segment
(signatureSegment, signature) = signature.getNextSegment();

if (signatureSegment.getIndex() <= i) {
revert SignatureSegmentOutOfOrder();
}
} else {
userOp.signature = "";
}
(userOp.signature, signature) = signature.advanceSegmentIfAtIndex(uint8(i));

(address module, uint32 entityId) = preUserOpValidationHooks[i].unpack();
uint256 currentValidationRes =
Expand All @@ -389,13 +364,9 @@ contract UpgradeableModularAccount is
validationRes = _coalescePreValidation(validationRes, currentValidationRes);
}

// Run the user op validationFunction
// Run the user op validation function
{
if (signatureSegment.getIndex() != RESERVED_VALIDATION_DATA_INDEX) {
revert ValidationSignatureSegmentMissing();
}

userOp.signature = signatureSegment.getBody();
userOp.signature = signature.getFinalSegment();

uint256 currentValidationRes = _execUserOpValidation(userOpValidationFunction, userOp, userOpHash);

Expand All @@ -415,42 +386,21 @@ contract UpgradeableModularAccount is
bytes calldata callData,
bytes calldata authorizationData
) internal {
// Set up the per-hook data tracking fields
bytes calldata authSegment;
(authSegment, authorizationData) = authorizationData.getNextSegment();

// run all preRuntimeValidation hooks
ModuleEntity[] memory preRuntimeValidationHooks =
getAccountStorage().validationData[runtimeValidationFunction].preValidationHooks;

for (uint256 i = 0; i < preRuntimeValidationHooks.length; ++i) {
bytes memory currentAuthData;

if (authSegment.getIndex() == i) {
// Use the current segment
currentAuthData = authSegment.getBody();

if (currentAuthData.length == 0) {
revert NonCanonicalEncoding();
}
bytes memory currentAuthSegment;

// Load the next per-hook data segment
(authSegment, authorizationData) = authorizationData.getNextSegment();
(currentAuthSegment, authorizationData) = authorizationData.advanceSegmentIfAtIndex(uint8(i));

if (authSegment.getIndex() <= i) {
revert SignatureSegmentOutOfOrder();
}
} else {
currentAuthData = "";
}
_doPreRuntimeValidationHook(preRuntimeValidationHooks[i], callData, currentAuthData);
_doPreRuntimeValidationHook(preRuntimeValidationHooks[i], callData, currentAuthSegment);
}

if (authSegment.getIndex() != RESERVED_VALIDATION_DATA_INDEX) {
revert ValidationSignatureSegmentMissing();
}
authorizationData = authorizationData.getFinalSegment();

_execRuntimeValidation(runtimeValidationFunction, callData, authSegment.getBody());
_execRuntimeValidation(runtimeValidationFunction, callData, authorizationData);
}

function _doPreHooks(EnumerableSet.Bytes32Set storage executionHooks, bytes memory data)
Expand Down
61 changes: 61 additions & 0 deletions src/helpers/SparseCalldataSegmentLib.sol
Original file line number Diff line number Diff line change
@@ -1,12 +1,18 @@
// SPDX-License-Identifier: GPL-3.0
pragma solidity ^0.8.25;

import {RESERVED_VALIDATION_DATA_INDEX} from "./Constants.sol";

/// @title Sparse Calldata Segment Library
/// @notice Library for working with sparsely-packed calldata segments, identified with an index.
/// @dev The first byte of each segment is the index of the segment.
/// To prevent accidental stack-to-deep errors, the body and index of the segment are extracted separately, rather
/// than inline as part of the tuple returned by `getNextSegment`.
library SparseCalldataSegmentLib {
error NonCanonicalEncoding();
error SegmentOutOfOrder();
error ValidationSignatureSegmentMissing();

/// @notice Splits out a segment of calldata, sparsely-packed.
/// The expected format is:
/// [uint32(len(segment0)), segment0, uint32(len(segment1)), segment1, ... uint32(len(segmentN)), segmentN]
Expand All @@ -33,6 +39,61 @@ library SparseCalldataSegmentLib {
remainder = source[remainderOffset:];
}

/// @notice If the index of the next segment in the source equals the provided index, return the next body and
/// advance the source by one segment.
/// @dev Reverts if the index of the next segment is less than the provided index, or if the extracted segment
/// has length 0.
/// @param source The calldata to extract the segment from.
/// @param index The index of the segment to extract.
/// @return A tuple containing the extracted segment's body, or an empty buffer if the index is not found, and
/// the remaining calldata.
function advanceSegmentIfAtIndex(bytes calldata source, uint8 index)
internal
pure
returns (bytes memory, bytes calldata)
{
uint8 nextIndex = peekIndex(source);

if (nextIndex < index) {
revert SegmentOutOfOrder();
}

if (nextIndex == index) {
(bytes calldata segment, bytes calldata remainder) = getNextSegment(source);

segment = getBody(segment);

if (segment.length == 0) {
revert NonCanonicalEncoding();
}

return (segment, remainder);
}

return ("", source);
}

function getFinalSegment(bytes calldata source) internal pure returns (bytes calldata) {
(bytes calldata segment, bytes calldata remainder) = getNextSegment(source);

if (getIndex(segment) != RESERVED_VALIDATION_DATA_INDEX) {
revert ValidationSignatureSegmentMissing();
}

if (remainder.length != 0) {
revert NonCanonicalEncoding();
}

return getBody(segment);
}

/// @notice Returns the index of the next segment in the source.
/// @param source The calldata to extract the index from.
/// @return The index of the next segment.
function peekIndex(bytes calldata source) internal pure returns (uint8) {
return uint8(source[4]);
}

/// @notice Extracts the index from a segment.
/// @dev The first byte of the segment is the index.
/// @param segment The segment to extract the index from
Expand Down
54 changes: 50 additions & 4 deletions test/account/PerHookData.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import {UpgradeableModularAccount} from "../../src/account/UpgradeableModularAcc

import {HookConfigLib} from "../../src/helpers/HookConfigLib.sol";
import {ModuleEntity, ModuleEntityLib} from "../../src/helpers/ModuleEntityLib.sol";
import {SparseCalldataSegmentLib} from "../../src/helpers/SparseCalldataSegmentLib.sol";

import {Counter} from "../mocks/Counter.sol";
import {MockAccessControlHookModule} from "../mocks/modules/MockAccessControlHookModule.sol";
Expand Down Expand Up @@ -123,7 +124,7 @@ contract PerHookDataTest is CustomValidationTestBase {
IEntryPoint.FailedOpWithRevert.selector,
0,
"AA23 reverted",
abi.encodeWithSelector(UpgradeableModularAccount.ValidationSignatureSegmentMissing.selector)
abi.encodeWithSelector(SparseCalldataSegmentLib.ValidationSignatureSegmentMissing.selector)
)
);
entryPoint.handleOps(userOps, beneficiary);
Expand Down Expand Up @@ -187,7 +188,35 @@ contract PerHookDataTest is CustomValidationTestBase {
IEntryPoint.FailedOpWithRevert.selector,
0,
"AA23 reverted",
abi.encodeWithSelector(UpgradeableModularAccount.NonCanonicalEncoding.selector)
abi.encodeWithSelector(SparseCalldataSegmentLib.NonCanonicalEncoding.selector)
)
);
entryPoint.handleOps(userOps, beneficiary);
}

function test_failPerHookData_excessData_userOp() public {
(PackedUserOperation memory userOp, bytes32 userOpHash) = _getCounterUserOP();
(uint8 v, bytes32 r, bytes32 s) = vm.sign(owner1Key, userOpHash.toEthSignedMessageHash());

PreValidationHookData[] memory preValidationHookData = new PreValidationHookData[](1);
preValidationHookData[0] = PreValidationHookData({index: 0, validationData: abi.encodePacked(_counter)});

userOp.signature = abi.encodePacked(
_encodeSignature(
_signerValidation, GLOBAL_VALIDATION, preValidationHookData, abi.encodePacked(r, s, v)
),
"extra data"
);

PackedUserOperation[] memory userOps = new PackedUserOperation[](1);
userOps[0] = userOp;

vm.expectRevert(
abi.encodeWithSelector(
IEntryPoint.FailedOpWithRevert.selector,
0,
"AA23 reverted",
abi.encodeWithSelector(SparseCalldataSegmentLib.NonCanonicalEncoding.selector)
)
);
entryPoint.handleOps(userOps, beneficiary);
Expand Down Expand Up @@ -262,7 +291,7 @@ contract PerHookDataTest is CustomValidationTestBase {

vm.prank(owner1);
vm.expectRevert(
abi.encodeWithSelector(UpgradeableModularAccount.ValidationSignatureSegmentMissing.selector)
abi.encodeWithSelector(SparseCalldataSegmentLib.ValidationSignatureSegmentMissing.selector)
);
account1.executeWithAuthorization(
abi.encodeCall(
Expand Down Expand Up @@ -299,7 +328,7 @@ contract PerHookDataTest is CustomValidationTestBase {
preValidationHookData[0] = PreValidationHookData({index: 0, validationData: ""});

vm.prank(owner1);
vm.expectRevert(abi.encodeWithSelector(UpgradeableModularAccount.NonCanonicalEncoding.selector));
vm.expectRevert(abi.encodeWithSelector(SparseCalldataSegmentLib.NonCanonicalEncoding.selector));
account1.executeWithAuthorization(
abi.encodeCall(
UpgradeableModularAccount.execute,
Expand All @@ -309,6 +338,23 @@ contract PerHookDataTest is CustomValidationTestBase {
);
}

function test_failPerHookData_excessData_runtime() public {
PreValidationHookData[] memory preValidationHookData = new PreValidationHookData[](1);
preValidationHookData[0] = PreValidationHookData({index: 0, validationData: abi.encodePacked(_counter)});

vm.prank(owner1);
vm.expectRevert(abi.encodeWithSelector(SparseCalldataSegmentLib.NonCanonicalEncoding.selector));
account1.executeWithAuthorization(
abi.encodeCall(
UpgradeableModularAccount.execute,
(address(_counter), 0 wei, abi.encodeCall(Counter.increment, ()))
),
abi.encodePacked(
_encodeSignature(_signerValidation, GLOBAL_VALIDATION, preValidationHookData, ""), "extra data"
)
);
}

function _getCounterUserOP() internal view returns (PackedUserOperation memory, bytes32) {
PackedUserOperation memory userOp = PackedUserOperation({
sender: address(account1),
Expand Down

0 comments on commit d042a7d

Please sign in to comment.