Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: merge segment collection logic [1/2] #143

Merged
merged 1 commit into from
Aug 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 9 additions & 59 deletions src/account/UpgradeableModularAccount.sol
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ import {SparseCalldataSegmentLib} from "../helpers/SparseCalldataSegmentLib.sol"
import {ValidationConfigLib} from "../helpers/ValidationConfigLib.sol";
import {_coalescePreValidation, _coalesceValidation} from "../helpers/ValidationResHelpers.sol";

import {DIRECT_CALL_VALIDATION_ENTITYID, RESERVED_VALIDATION_DATA_INDEX} from "../helpers/Constants.sol";
import {DIRECT_CALL_VALIDATION_ENTITYID} from "../helpers/Constants.sol";

import {IExecutionHookModule} from "../interfaces/IExecutionHookModule.sol";
import {ExecutionManifest} from "../interfaces/IExecutionModule.sol";
Expand Down Expand Up @@ -67,7 +67,6 @@ contract UpgradeableModularAccount is
bytes4 internal constant _1271_MAGIC_VALUE = 0x1626ba7e;
bytes4 internal constant _1271_INVALID = 0xffffffff;

error NonCanonicalEncoding();
error NotEntryPoint();
error PostExecHookReverted(address module, uint32 entityId, bytes revertReason);
error PreExecHookReverted(address module, uint32 entityId, bytes revertReason);
Expand All @@ -79,8 +78,6 @@ contract UpgradeableModularAccount is
error UnexpectedAggregator(address module, uint32 entityId, address aggregator);
error UnrecognizedFunction(bytes4 selector);
error ValidationFunctionMissing(bytes4 selector);
error ValidationSignatureSegmentMissing();
error SignatureSegmentOutOfOrder();

// Wraps execution of a native function with runtime validation and hooks
// Used for upgradeTo, upgradeToAndCall, execute, executeBatch, installExecution, uninstallExecution
Expand Down Expand Up @@ -347,36 +344,14 @@ contract UpgradeableModularAccount is
bytes calldata signature,
bytes32 userOpHash
) internal returns (uint256) {
// Set up the per-hook data tracking fields
bytes calldata signatureSegment;
(signatureSegment, signature) = signature.getNextSegment();

uint256 validationRes;

// Do preUserOpValidation hooks
ModuleEntity[] memory preUserOpValidationHooks =
getAccountStorage().validationData[userOpValidationFunction].preValidationHooks;

for (uint256 i = 0; i < preUserOpValidationHooks.length; ++i) {
// Load per-hook data, if any is present
// The segment index is the first byte of the signature
if (signatureSegment.getIndex() == i) {
// Use the current segment
userOp.signature = signatureSegment.getBody();

if (userOp.signature.length == 0) {
revert NonCanonicalEncoding();
}

// Load the next per-hook data segment
(signatureSegment, signature) = signature.getNextSegment();

if (signatureSegment.getIndex() <= i) {
revert SignatureSegmentOutOfOrder();
}
} else {
userOp.signature = "";
}
(userOp.signature, signature) = signature.advanceSegmentIfAtIndex(uint8(i));

(address module, uint32 entityId) = preUserOpValidationHooks[i].unpack();
uint256 currentValidationRes =
Expand All @@ -389,13 +364,9 @@ contract UpgradeableModularAccount is
validationRes = _coalescePreValidation(validationRes, currentValidationRes);
}

// Run the user op validationFunction
// Run the user op validation function
{
if (signatureSegment.getIndex() != RESERVED_VALIDATION_DATA_INDEX) {
revert ValidationSignatureSegmentMissing();
}

userOp.signature = signatureSegment.getBody();
userOp.signature = signature.getFinalSegment();

uint256 currentValidationRes = _execUserOpValidation(userOpValidationFunction, userOp, userOpHash);

Expand All @@ -415,42 +386,21 @@ contract UpgradeableModularAccount is
bytes calldata callData,
bytes calldata authorizationData
) internal {
// Set up the per-hook data tracking fields
bytes calldata authSegment;
(authSegment, authorizationData) = authorizationData.getNextSegment();

// run all preRuntimeValidation hooks
ModuleEntity[] memory preRuntimeValidationHooks =
getAccountStorage().validationData[runtimeValidationFunction].preValidationHooks;

for (uint256 i = 0; i < preRuntimeValidationHooks.length; ++i) {
bytes memory currentAuthData;

if (authSegment.getIndex() == i) {
// Use the current segment
currentAuthData = authSegment.getBody();

if (currentAuthData.length == 0) {
revert NonCanonicalEncoding();
}
bytes memory currentAuthSegment;

// Load the next per-hook data segment
(authSegment, authorizationData) = authorizationData.getNextSegment();
(currentAuthSegment, authorizationData) = authorizationData.advanceSegmentIfAtIndex(uint8(i));

if (authSegment.getIndex() <= i) {
revert SignatureSegmentOutOfOrder();
}
} else {
currentAuthData = "";
}
_doPreRuntimeValidationHook(preRuntimeValidationHooks[i], callData, currentAuthData);
_doPreRuntimeValidationHook(preRuntimeValidationHooks[i], callData, currentAuthSegment);
}

if (authSegment.getIndex() != RESERVED_VALIDATION_DATA_INDEX) {
revert ValidationSignatureSegmentMissing();
}
authorizationData = authorizationData.getFinalSegment();

_execRuntimeValidation(runtimeValidationFunction, callData, authSegment.getBody());
_execRuntimeValidation(runtimeValidationFunction, callData, authorizationData);
}

function _doPreHooks(EnumerableSet.Bytes32Set storage executionHooks, bytes memory data)
Expand Down
61 changes: 61 additions & 0 deletions src/helpers/SparseCalldataSegmentLib.sol
Original file line number Diff line number Diff line change
@@ -1,12 +1,18 @@
// SPDX-License-Identifier: GPL-3.0
pragma solidity ^0.8.25;

import {RESERVED_VALIDATION_DATA_INDEX} from "./Constants.sol";

/// @title Sparse Calldata Segment Library
/// @notice Library for working with sparsely-packed calldata segments, identified with an index.
/// @dev The first byte of each segment is the index of the segment.
/// To prevent accidental stack-to-deep errors, the body and index of the segment are extracted separately, rather
/// than inline as part of the tuple returned by `getNextSegment`.
library SparseCalldataSegmentLib {
error NonCanonicalEncoding();
error SegmentOutOfOrder();
error ValidationSignatureSegmentMissing();

/// @notice Splits out a segment of calldata, sparsely-packed.
/// The expected format is:
/// [uint32(len(segment0)), segment0, uint32(len(segment1)), segment1, ... uint32(len(segmentN)), segmentN]
Expand All @@ -33,6 +39,61 @@ library SparseCalldataSegmentLib {
remainder = source[remainderOffset:];
}

/// @notice If the index of the next segment in the source equals the provided index, return the next body and
/// advance the source by one segment.
/// @dev Reverts if the index of the next segment is less than the provided index, or if the extracted segment
/// has length 0.
/// @param source The calldata to extract the segment from.
/// @param index The index of the segment to extract.
/// @return A tuple containing the extracted segment's body, or an empty buffer if the index is not found, and
/// the remaining calldata.
function advanceSegmentIfAtIndex(bytes calldata source, uint8 index)
internal
pure
returns (bytes memory, bytes calldata)
{
uint8 nextIndex = peekIndex(source);

if (nextIndex < index) {
revert SegmentOutOfOrder();
}

if (nextIndex == index) {
(bytes calldata segment, bytes calldata remainder) = getNextSegment(source);

segment = getBody(segment);

if (segment.length == 0) {
revert NonCanonicalEncoding();
}

return (segment, remainder);
}

return ("", source);
}

function getFinalSegment(bytes calldata source) internal pure returns (bytes calldata) {
(bytes calldata segment, bytes calldata remainder) = getNextSegment(source);

if (getIndex(segment) != RESERVED_VALIDATION_DATA_INDEX) {
revert ValidationSignatureSegmentMissing();
}

if (remainder.length != 0) {
revert NonCanonicalEncoding();
}

return getBody(segment);
}

/// @notice Returns the index of the next segment in the source.
/// @param source The calldata to extract the index from.
/// @return The index of the next segment.
function peekIndex(bytes calldata source) internal pure returns (uint8) {
return uint8(source[4]);
}

/// @notice Extracts the index from a segment.
/// @dev The first byte of the segment is the index.
/// @param segment The segment to extract the index from
Expand Down
54 changes: 50 additions & 4 deletions test/account/PerHookData.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import {UpgradeableModularAccount} from "../../src/account/UpgradeableModularAcc

import {HookConfigLib} from "../../src/helpers/HookConfigLib.sol";
import {ModuleEntity, ModuleEntityLib} from "../../src/helpers/ModuleEntityLib.sol";
import {SparseCalldataSegmentLib} from "../../src/helpers/SparseCalldataSegmentLib.sol";

import {Counter} from "../mocks/Counter.sol";
import {MockAccessControlHookModule} from "../mocks/modules/MockAccessControlHookModule.sol";
Expand Down Expand Up @@ -123,7 +124,7 @@ contract PerHookDataTest is CustomValidationTestBase {
IEntryPoint.FailedOpWithRevert.selector,
0,
"AA23 reverted",
abi.encodeWithSelector(UpgradeableModularAccount.ValidationSignatureSegmentMissing.selector)
abi.encodeWithSelector(SparseCalldataSegmentLib.ValidationSignatureSegmentMissing.selector)
)
);
entryPoint.handleOps(userOps, beneficiary);
Expand Down Expand Up @@ -187,7 +188,35 @@ contract PerHookDataTest is CustomValidationTestBase {
IEntryPoint.FailedOpWithRevert.selector,
0,
"AA23 reverted",
abi.encodeWithSelector(UpgradeableModularAccount.NonCanonicalEncoding.selector)
abi.encodeWithSelector(SparseCalldataSegmentLib.NonCanonicalEncoding.selector)
)
);
entryPoint.handleOps(userOps, beneficiary);
}

function test_failPerHookData_excessData_userOp() public {
(PackedUserOperation memory userOp, bytes32 userOpHash) = _getCounterUserOP();
(uint8 v, bytes32 r, bytes32 s) = vm.sign(owner1Key, userOpHash.toEthSignedMessageHash());

PreValidationHookData[] memory preValidationHookData = new PreValidationHookData[](1);
preValidationHookData[0] = PreValidationHookData({index: 0, validationData: abi.encodePacked(_counter)});

userOp.signature = abi.encodePacked(
_encodeSignature(
_signerValidation, GLOBAL_VALIDATION, preValidationHookData, abi.encodePacked(r, s, v)
),
"extra data"
);

PackedUserOperation[] memory userOps = new PackedUserOperation[](1);
userOps[0] = userOp;

vm.expectRevert(
abi.encodeWithSelector(
IEntryPoint.FailedOpWithRevert.selector,
0,
"AA23 reverted",
abi.encodeWithSelector(SparseCalldataSegmentLib.NonCanonicalEncoding.selector)
)
);
entryPoint.handleOps(userOps, beneficiary);
Expand Down Expand Up @@ -262,7 +291,7 @@ contract PerHookDataTest is CustomValidationTestBase {

vm.prank(owner1);
vm.expectRevert(
abi.encodeWithSelector(UpgradeableModularAccount.ValidationSignatureSegmentMissing.selector)
abi.encodeWithSelector(SparseCalldataSegmentLib.ValidationSignatureSegmentMissing.selector)
);
account1.executeWithAuthorization(
abi.encodeCall(
Expand Down Expand Up @@ -299,7 +328,7 @@ contract PerHookDataTest is CustomValidationTestBase {
preValidationHookData[0] = PreValidationHookData({index: 0, validationData: ""});

vm.prank(owner1);
vm.expectRevert(abi.encodeWithSelector(UpgradeableModularAccount.NonCanonicalEncoding.selector));
vm.expectRevert(abi.encodeWithSelector(SparseCalldataSegmentLib.NonCanonicalEncoding.selector));
account1.executeWithAuthorization(
abi.encodeCall(
UpgradeableModularAccount.execute,
Expand All @@ -309,6 +338,23 @@ contract PerHookDataTest is CustomValidationTestBase {
);
}

function test_failPerHookData_excessData_runtime() public {
PreValidationHookData[] memory preValidationHookData = new PreValidationHookData[](1);
preValidationHookData[0] = PreValidationHookData({index: 0, validationData: abi.encodePacked(_counter)});

vm.prank(owner1);
vm.expectRevert(abi.encodeWithSelector(SparseCalldataSegmentLib.NonCanonicalEncoding.selector));
account1.executeWithAuthorization(
abi.encodeCall(
UpgradeableModularAccount.execute,
(address(_counter), 0 wei, abi.encodeCall(Counter.increment, ()))
),
abi.encodePacked(
_encodeSignature(_signerValidation, GLOBAL_VALIDATION, preValidationHookData, ""), "extra data"
)
);
}

function _getCounterUserOP() internal view returns (PackedUserOperation memory, bytes32) {
PackedUserOperation memory userOp = PackedUserOperation({
sender: address(account1),
Expand Down
Loading