Skip to content

Commit

Permalink
feat: [v0.8-develop] per validation hook data (#66)
Browse files Browse the repository at this point in the history
  • Loading branch information
adamegyed authored Jul 9, 2024
1 parent f63137b commit de28c60
Show file tree
Hide file tree
Showing 20 changed files with 839 additions and 100 deletions.
1 change: 1 addition & 0 deletions .solhint-test.json
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
"immutable-vars-naming": ["error"],
"no-unused-import": ["error"],
"compiler-version": ["error", ">=0.8.19"],
"custom-errors": "off",
"func-visibility": ["error", { "ignoreConstructors": true }],
"max-line-length": ["error", 120],
"max-states-count": ["warn", 30],
Expand Down
3 changes: 1 addition & 2 deletions src/account/AccountLoupe.sol
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,7 @@ abstract contract AccountLoupe is IAccountLoupe {
override
returns (FunctionReference[] memory preValidationHooks)
{
preValidationHooks =
toFunctionReferenceArray(getAccountStorage().validationData[validationFunction].preValidationHooks);
preValidationHooks = getAccountStorage().validationData[validationFunction].preValidationHooks;
}

/// @inheritdoc IAccountLoupe
Expand Down
2 changes: 1 addition & 1 deletion src/account/AccountStorage.sol
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ struct ValidationData {
// How many execution hooks require the UO context.
uint8 requireUOHookCount;
// The pre validation hooks for this function selector.
EnumerableSet.Bytes32Set preValidationHooks;
FunctionReference[] preValidationHooks;
// Permission hooks for this validation function.
EnumerableSet.Bytes32Set permissionHooks;
}
Expand Down
36 changes: 22 additions & 14 deletions src/account/PluginManager2.sol
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,15 @@ import {ExecutionHook} from "../interfaces/IAccountLoupe.sol";
abstract contract PluginManager2 {
using EnumerableSet for EnumerableSet.Bytes32Set;

// Index marking the start of the data for the validation function.
uint8 internal constant _RESERVED_VALIDATION_DATA_INDEX = 255;

error DefaultValidationAlreadySet(FunctionReference validationFunction);
error PreValidationAlreadySet(FunctionReference validationFunction, FunctionReference preValidationFunction);
error ValidationAlreadySet(bytes4 selector, FunctionReference validationFunction);
error ValidationNotSet(bytes4 selector, FunctionReference validationFunction);
error PermissionAlreadySet(FunctionReference validationFunction, ExecutionHook hook);
error PreValidationHookLimitExceeded();

function _installValidation(
FunctionReference validationFunction,
Expand All @@ -39,19 +43,21 @@ abstract contract PluginManager2 {
for (uint256 i = 0; i < preValidationFunctions.length; ++i) {
FunctionReference preValidationFunction = preValidationFunctions[i];

if (
!_storage.validationData[validationFunction].preValidationHooks.add(
toSetValue(preValidationFunction)
)
) {
revert PreValidationAlreadySet(validationFunction, preValidationFunction);
}
_storage.validationData[validationFunction].preValidationHooks.push(preValidationFunction);

if (initDatas[i].length > 0) {
(address preValidationPlugin,) = FunctionReferenceLib.unpack(preValidationFunction);
IPlugin(preValidationPlugin).onInstall(initDatas[i]);
}
}

// Avoid collision between reserved index and actual indices
if (
_storage.validationData[validationFunction].preValidationHooks.length
> _RESERVED_VALIDATION_DATA_INDEX
) {
revert PreValidationHookLimitExceeded();
}
}

if (permissionHooks.length > 0) {
Expand Down Expand Up @@ -110,15 +116,16 @@ abstract contract PluginManager2 {
bytes[] memory preValidationHookUninstallDatas = abi.decode(preValidationHookUninstallData, (bytes[]));

// Clear pre validation hooks
EnumerableSet.Bytes32Set storage preValidationHooks =
FunctionReference[] storage preValidationHooks =
_storage.validationData[validationFunction].preValidationHooks;
uint256 i = 0;
while (preValidationHooks.length() > 0) {
FunctionReference preValidationFunction = toFunctionReference(preValidationHooks.at(0));
preValidationHooks.remove(toSetValue(preValidationFunction));
(address preValidationPlugin,) = FunctionReferenceLib.unpack(preValidationFunction);
IPlugin(preValidationPlugin).onUninstall(preValidationHookUninstallDatas[i++]);
for (uint256 i = 0; i < preValidationHooks.length; ++i) {
FunctionReference preValidationFunction = preValidationHooks[i];
if (preValidationHookUninstallDatas[0].length > 0) {
(address preValidationPlugin,) = FunctionReferenceLib.unpack(preValidationFunction);
IPlugin(preValidationPlugin).onUninstall(preValidationHookUninstallDatas[0]);
}
}
delete _storage.validationData[validationFunction].preValidationHooks;
}

{
Expand All @@ -135,6 +142,7 @@ abstract contract PluginManager2 {
IPlugin(permissionHookPlugin).onUninstall(permissionHookUninstallDatas[i++]);
}
}
delete _storage.validationData[validationFunction].preValidationHooks;

// Because this function also calls `onUninstall`, and removes the default flag from validation, we must
// assume these selectors passed in to be exhaustive.
Expand Down
113 changes: 78 additions & 35 deletions src/account/UpgradeableModularAccount.sol
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import {IERC1271} from "@openzeppelin/contracts/interfaces/IERC1271.sol";
import {EnumerableSet} from "@openzeppelin/contracts/utils/structs/EnumerableSet.sol";

import {FunctionReferenceLib} from "../helpers/FunctionReferenceLib.sol";
import {SparseCalldataSegmentLib} from "../helpers/SparseCalldataSegmentLib.sol";
import {_coalescePreValidation, _coalesceValidation} from "../helpers/ValidationDataHelpers.sol";
import {IPlugin, PluginManifest} from "../interfaces/IPlugin.sol";
import {IValidation} from "../interfaces/IValidation.sol";
Expand All @@ -20,13 +21,7 @@ import {FunctionReference, IPluginManager} from "../interfaces/IPluginManager.so
import {IStandardExecutor, Call} from "../interfaces/IStandardExecutor.sol";
import {AccountExecutor} from "./AccountExecutor.sol";
import {AccountLoupe} from "./AccountLoupe.sol";
import {
AccountStorage,
getAccountStorage,
toSetValue,
toFunctionReference,
toExecutionHook
} from "./AccountStorage.sol";
import {AccountStorage, getAccountStorage, toSetValue, toExecutionHook} from "./AccountStorage.sol";
import {AccountStorageInitializable} from "./AccountStorageInitializable.sol";
import {PluginManagerInternals} from "./PluginManagerInternals.sol";
import {PluginManager2} from "./PluginManager2.sol";
Expand All @@ -46,6 +41,7 @@ contract UpgradeableModularAccount is
{
using EnumerableSet for EnumerableSet.Bytes32Set;
using FunctionReferenceLib for FunctionReference;
using SparseCalldataSegmentLib for bytes;

struct PostExecToRun {
bytes preExecHookReturnData;
Expand All @@ -68,6 +64,7 @@ contract UpgradeableModularAccount is
error ExecFromPluginNotPermitted(address plugin, bytes4 selector);
error ExecFromPluginExternalNotPermitted(address plugin, address target, uint256 value, bytes data);
error NativeTokenSpendingNotPermitted(address plugin);
error NonCanonicalEncoding();
error NotEntryPoint();
error PostExecHookReverted(address plugin, uint8 functionId, bytes revertReason);
error PreExecHookReverted(address plugin, uint8 functionId, bytes revertReason);
Expand All @@ -80,6 +77,8 @@ contract UpgradeableModularAccount is
error UnrecognizedFunction(bytes4 selector);
error UserOpValidationFunctionMissing(bytes4 selector);
error ValidationDoesNotApply(bytes4 selector, address plugin, uint8 functionId, bool isDefault);
error ValidationSignatureSegmentMissing();
error SignatureSegmentOutOfOrder();

// Wraps execution of a native function with runtime validation and hooks
// Used for upgradeTo, upgradeToAndCall, execute, executeBatch, installPlugin, uninstallPlugin
Expand Down Expand Up @@ -407,38 +406,50 @@ contract UpgradeableModularAccount is
revert RequireUserOperationContext();
}

validationData =
_doUserOpValidation(selector, userOpValidationFunction, userOp, userOp.signature[22:], userOpHash);
validationData = _doUserOpValidation(userOpValidationFunction, userOp, userOp.signature[22:], userOpHash);
}

// To support gas estimation, we don't fail early when the failure is caused by a signature failure
function _doUserOpValidation(
bytes4 selector,
FunctionReference userOpValidationFunction,
PackedUserOperation memory userOp,
bytes calldata signature,
bytes32 userOpHash
) internal returns (uint256 validationData) {
userOp.signature = signature;
) internal returns (uint256) {
// Set up the per-hook data tracking fields
bytes calldata signatureSegment;
(signatureSegment, signature) = signature.getNextSegment();

if (userOpValidationFunction.isEmpty()) {
// If the validation function is empty, then the call cannot proceed.
revert UserOpValidationFunctionMissing(selector);
}

uint256 currentValidationData;
uint256 validationData;

// Do preUserOpValidation hooks
EnumerableSet.Bytes32Set storage preUserOpValidationHooks =
FunctionReference[] memory preUserOpValidationHooks =
getAccountStorage().validationData[userOpValidationFunction].preValidationHooks;

uint256 preUserOpValidationHooksLength = preUserOpValidationHooks.length();
for (uint256 i = 0; i < preUserOpValidationHooksLength; ++i) {
bytes32 key = preUserOpValidationHooks.at(i);
FunctionReference preUserOpValidationHook = toFunctionReference(key);
for (uint256 i = 0; i < preUserOpValidationHooks.length; ++i) {
// Load per-hook data, if any is present
// The segment index is the first byte of the signature
if (signatureSegment.getIndex() == i) {
// Use the current segment
userOp.signature = signatureSegment.getBody();

if (userOp.signature.length == 0) {
revert NonCanonicalEncoding();
}

// Load the next per-hook data segment
(signatureSegment, signature) = signature.getNextSegment();

if (signatureSegment.getIndex() <= i) {
revert SignatureSegmentOutOfOrder();
}
} else {
userOp.signature = "";
}

(address plugin, uint8 functionId) = preUserOpValidationHook.unpack();
currentValidationData = IValidationHook(plugin).preUserOpValidationHook(functionId, userOp, userOpHash);
(address plugin, uint8 functionId) = preUserOpValidationHooks[i].unpack();
uint256 currentValidationData =
IValidationHook(plugin).preUserOpValidationHook(functionId, userOp, userOpHash);

if (uint160(currentValidationData) > 1) {
// If the aggregator is not 0 or 1, it is an unexpected value
Expand All @@ -449,35 +460,63 @@ contract UpgradeableModularAccount is

// Run the user op validationFunction
{
if (signatureSegment.getIndex() != _RESERVED_VALIDATION_DATA_INDEX) {
revert ValidationSignatureSegmentMissing();
}

userOp.signature = signatureSegment.getBody();

(address plugin, uint8 functionId) = userOpValidationFunction.unpack();
currentValidationData = IValidation(plugin).validateUserOp(functionId, userOp, userOpHash);
uint256 currentValidationData = IValidation(plugin).validateUserOp(functionId, userOp, userOpHash);

if (preUserOpValidationHooksLength != 0) {
if (preUserOpValidationHooks.length != 0) {
// If we have other validation data we need to coalesce with
validationData = _coalesceValidation(validationData, currentValidationData);
} else {
validationData = currentValidationData;
}
}

return validationData;
}

function _doRuntimeValidation(
FunctionReference runtimeValidationFunction,
bytes calldata callData,
bytes calldata authorizationData
) internal {
// Set up the per-hook data tracking fields
bytes calldata authSegment;
(authSegment, authorizationData) = authorizationData.getNextSegment();

// run all preRuntimeValidation hooks
EnumerableSet.Bytes32Set storage preRuntimeValidationHooks =
FunctionReference[] memory preRuntimeValidationHooks =
getAccountStorage().validationData[runtimeValidationFunction].preValidationHooks;

uint256 preRuntimeValidationHooksLength = preRuntimeValidationHooks.length();
for (uint256 i = 0; i < preRuntimeValidationHooksLength; ++i) {
bytes32 key = preRuntimeValidationHooks.at(i);
FunctionReference preRuntimeValidationHook = toFunctionReference(key);
for (uint256 i = 0; i < preRuntimeValidationHooks.length; ++i) {
bytes memory currentAuthData;

if (authSegment.getIndex() == i) {
// Use the current segment
currentAuthData = authSegment.getBody();

if (currentAuthData.length == 0) {
revert NonCanonicalEncoding();
}

// Load the next per-hook data segment
(authSegment, authorizationData) = authorizationData.getNextSegment();

(address hookPlugin, uint8 hookFunctionId) = preRuntimeValidationHook.unpack();
if (authSegment.getIndex() <= i) {
revert SignatureSegmentOutOfOrder();
}
} else {
currentAuthData = "";
}

(address hookPlugin, uint8 hookFunctionId) = preRuntimeValidationHooks[i].unpack();
try IValidationHook(hookPlugin).preRuntimeValidationHook(
hookFunctionId, msg.sender, msg.value, callData
hookFunctionId, msg.sender, msg.value, callData, currentAuthData
)
// forgefmt: disable-start
// solhint-disable-next-line no-empty-blocks
Expand All @@ -487,9 +526,13 @@ contract UpgradeableModularAccount is
}
}

if (authSegment.getIndex() != _RESERVED_VALIDATION_DATA_INDEX) {
revert ValidationSignatureSegmentMissing();
}

(address plugin, uint8 functionId) = runtimeValidationFunction.unpack();

try IValidation(plugin).validateRuntime(functionId, msg.sender, msg.value, callData, authorizationData)
try IValidation(plugin).validateRuntime(functionId, msg.sender, msg.value, callData, authSegment.getBody())
// forgefmt: disable-start
// solhint-disable-next-line no-empty-blocks
{} catch (bytes memory revertReason) {
Expand Down
51 changes: 51 additions & 0 deletions src/helpers/SparseCalldataSegmentLib.sol
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
// SPDX-License-Identifier: GPL-3.0
pragma solidity ^0.8.25;

/// @title Sparse Calldata Segment Library
/// @notice Library for working with sparsely-packed calldata segments, identified with an index.
/// @dev The first byte of each segment is the index of the segment.
/// To prevent accidental stack-to-deep errors, the body and index of the segment are extracted separately, rather
/// than inline as part of the tuple returned by `getNextSegment`.
library SparseCalldataSegmentLib {
/// @notice Splits out a segment of calldata, sparsely-packed.
/// The expected format is:
/// [uint32(len(segment0)), segment0, uint32(len(segment1)), segment1, ... uint32(len(segmentN)), segmentN]
/// @param source The calldata to extract the segment from.
/// @return segment The extracted segment. Using the above example, this would be segment0.
/// @return remainder The remaining calldata. Using the above example,
/// this would start at uint32(len(segment1)) and continue to the end at segmentN.
function getNextSegment(bytes calldata source)
internal
pure
returns (bytes calldata segment, bytes calldata remainder)
{
// The first 4 bytes hold the length of the segment, excluding the index.
uint32 length = uint32(bytes4(source[:4]));

// The offset of the remainder of the calldata.
uint256 remainderOffset = 4 + length;

// The segment is the next `length` + 1 bytes, to account for the index.
// By convention, the first byte of each segment is the index of the segment.
segment = source[4:remainderOffset];

// The remainder is the rest of the calldata.
remainder = source[remainderOffset:];
}

/// @notice Extracts the index from a segment.
/// @dev The first byte of the segment is the index.
/// @param segment The segment to extract the index from
/// @return The index of the segment
function getIndex(bytes calldata segment) internal pure returns (uint8) {
return uint8(segment[0]);
}

/// @notice Extracts the body from a segment.
/// @dev The body is the segment without the index.
/// @param segment The segment to extract the body from
/// @return The body of the segment.
function getBody(bytes calldata segment) internal pure returns (bytes calldata) {
return segment[1:];
}
}
1 change: 1 addition & 0 deletions src/interfaces/IValidation.sol
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ interface IValidation is IPlugin {
/// @param sender The caller address.
/// @param value The call value.
/// @param data The calldata sent.
/// @param authorization Additional data for the validation function to use.
function validateRuntime(
uint8 functionId,
address sender,
Expand Down
9 changes: 7 additions & 2 deletions src/interfaces/IValidationHook.sol
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,13 @@ interface IValidationHook is IPlugin {
/// @param sender The caller address.
/// @param value The call value.
/// @param data The calldata sent.
function preRuntimeValidationHook(uint8 functionId, address sender, uint256 value, bytes calldata data)
external;
function preRuntimeValidationHook(
uint8 functionId,
address sender,
uint256 value,
bytes calldata data,
bytes calldata authorization
) external;

// TODO: support this hook type within the account & in the manifest

Expand Down
Loading

0 comments on commit de28c60

Please sign in to comment.