Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support global direct call validation #164

Merged
merged 1 commit into from
Aug 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 52 additions & 14 deletions src/account/ReferenceModularAccount.sol
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,12 @@ contract ReferenceModularAccount is
ModuleEntity postExecHook;
}

enum ValidationCheckingType {
GLOBAL,
SELECTOR,
EITHER
}

IEntryPoint private immutable _ENTRY_POINT;

// As per the EIP-165 spec, no interface should ever match 0xffffffff
Expand Down Expand Up @@ -187,7 +193,11 @@ contract ReferenceModularAccount is

// Check if the runtime validation function is allowed to be called
bool isGlobalValidation = uint8(authorization[24]) == 1;
_checkIfValidationAppliesCallData(data, runtimeValidationFunction, isGlobalValidation);
_checkIfValidationAppliesCallData(
data,
runtimeValidationFunction,
isGlobalValidation ? ValidationCheckingType.GLOBAL : ValidationCheckingType.SELECTOR
);

_doRuntimeValidation(runtimeValidationFunction, data, authorization[25:]);

Expand Down Expand Up @@ -343,7 +353,11 @@ contract ReferenceModularAccount is
ModuleEntity userOpValidationFunction = ModuleEntity.wrap(bytes24(userOp.signature[:24]));
bool isGlobalValidation = uint8(userOp.signature[24]) == 1;

_checkIfValidationAppliesCallData(userOp.callData, userOpValidationFunction, isGlobalValidation);
_checkIfValidationAppliesCallData(
userOp.callData,
userOpValidationFunction,
isGlobalValidation ? ValidationCheckingType.GLOBAL : ValidationCheckingType.SELECTOR
);

// Check if there are permission hooks associated with the validator, and revert if the call isn't to
// `executeUserOp`
Expand Down Expand Up @@ -549,7 +563,7 @@ contract ReferenceModularAccount is
ModuleEntity directCallValidationKey =
ModuleEntityLib.pack(msg.sender, DIRECT_CALL_VALIDATION_ENTITYID);

_checkIfValidationAppliesCallData(msg.data, directCallValidationKey, false);
_checkIfValidationAppliesCallData(msg.data, directCallValidationKey, ValidationCheckingType.EITHER);

// Direct call is allowed, run associated permission & validation hooks

Expand Down Expand Up @@ -645,7 +659,7 @@ contract ReferenceModularAccount is
function _checkIfValidationAppliesCallData(
bytes calldata callData,
ModuleEntity validationFunction,
bool isGlobal
ValidationCheckingType checkingType
) internal view {
bytes4 outerSelector = bytes4(callData[:4]);
if (outerSelector == this.executeUserOp.selector) {
Expand All @@ -655,7 +669,7 @@ contract ReferenceModularAccount is
outerSelector = bytes4(callData[:4]);
}

_checkIfValidationAppliesSelector(outerSelector, validationFunction, isGlobal);
_checkIfValidationAppliesSelector(outerSelector, validationFunction, checkingType);

if (outerSelector == IModularAccount.execute.selector) {
(address target,,) = abi.decode(callData[4:], (address, uint256, bytes));
Expand Down Expand Up @@ -689,26 +703,50 @@ contract ReferenceModularAccount is
revert SelfCallRecursionDepthExceeded();
}

_checkIfValidationAppliesSelector(nestedSelector, validationFunction, isGlobal);
_checkIfValidationAppliesSelector(nestedSelector, validationFunction, checkingType);
}
}
}
}

function _checkIfValidationAppliesSelector(bytes4 selector, ModuleEntity validationFunction, bool isGlobal)
internal
view
{
function _checkIfValidationAppliesSelector(
bytes4 selector,
ModuleEntity validationFunction,
ValidationCheckingType checkingType
) internal view {
// Check that the provided validation function is applicable to the selector
if (isGlobal) {
if (!_globalValidationAllowed(selector) || !_isValidationGlobal(validationFunction)) {

if (checkingType == ValidationCheckingType.GLOBAL) {
if (!_globalValidationApplies(selector, validationFunction)) {
revert ValidationFunctionMissing(selector);
}
} else if (checkingType == ValidationCheckingType.SELECTOR) {
if (!_selectorValidationApplies(selector, validationFunction)) {
revert ValidationFunctionMissing(selector);
}
} else {
// Not global validation, but per-selector
if (!getAccountStorage().validationData[validationFunction].selectors.contains(toSetValue(selector))) {
if (
!_globalValidationApplies(selector, validationFunction)
&& !_selectorValidationApplies(selector, validationFunction)
) {
revert ValidationFunctionMissing(selector);
}
}
}

function _globalValidationApplies(bytes4 selector, ModuleEntity validationFunction)
internal
view
returns (bool)
{
return _globalValidationAllowed(selector) && _isValidationGlobal(validationFunction);
}

function _selectorValidationApplies(bytes4 selector, ModuleEntity validationFunction)
internal
view
returns (bool)
{
return getAccountStorage().validationData[validationFunction].selectors.contains(toSetValue(selector));
}
}
55 changes: 41 additions & 14 deletions test/account/DirectCallsFromModule.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,15 @@ contract DirectCallsFromModuleTest is AccountTestBase {

event ValidationUninstalled(address indexed module, uint32 indexed entityId, bool onUninstallSucceeded);

modifier randomizedValidationType(bool selectorValidation) {
if (selectorValidation) {
_installValidationSelector();
} else {
_installValidationGlobal();
}
_;
}

function setUp() public {
_module = new DirectCallModule();
assertFalse(_module.preHookRan());
Expand All @@ -38,9 +47,10 @@ contract DirectCallsFromModuleTest is AccountTestBase {
account1.execute(address(0), 0, "");
}

function test_Fail_DirectCallModuleUninstalled() external {
_installValidation();

function testFuzz_Fail_DirectCallModuleUninstalled(bool validationType)
external
randomizedValidationType(validationType)
{
_uninstallValidation();

vm.prank(address(_module));
Expand All @@ -49,7 +59,7 @@ contract DirectCallsFromModuleTest is AccountTestBase {
}

function test_Fail_DirectCallModuleCallOtherSelector() external {
_installValidation();
_installValidationSelector();

Call[] memory calls = new Call[](0);

Expand All @@ -62,19 +72,21 @@ contract DirectCallsFromModuleTest is AccountTestBase {
/* Positives */
/* -------------------------------------------------------------------------- */

function test_Pass_DirectCallFromModulePrank() external {
_installValidation();

function testFuzz_Pass_DirectCallFromModulePrank(bool validationType)
external
randomizedValidationType(validationType)
{
vm.prank(address(_module));
account1.execute(address(0), 0, "");

assertTrue(_module.preHookRan());
assertTrue(_module.postHookRan());
}

function test_Pass_DirectCallFromModuleCallback() external {
_installValidation();

function testFuzz_Pass_DirectCallFromModuleCallback(bool validationType)
external
randomizedValidationType(validationType)
{
bytes memory encodedCall = abi.encodeCall(DirectCallModule.directCall, ());

vm.prank(address(entryPoint));
Expand All @@ -88,11 +100,12 @@ contract DirectCallsFromModuleTest is AccountTestBase {
assertEq(abi.decode(result, (bytes)), abi.encode(_module.getData()));
}

function test_Flow_DirectCallFromModuleSequence() external {
function testFuzz_Flow_DirectCallFromModuleSequence(bool validationType)
external
randomizedValidationType(validationType)
{
// Install => Succeesfully call => uninstall => fail to call

_installValidation();

vm.prank(address(_module));
account1.execute(address(0), 0, "");

Expand Down Expand Up @@ -129,7 +142,7 @@ contract DirectCallsFromModuleTest is AccountTestBase {
/* Internals */
/* -------------------------------------------------------------------------- */

function _installValidation() internal {
function _installValidationSelector() internal {
bytes4[] memory selectors = new bytes4[](1);
selectors[0] = IModularAccount.execute.selector;

Expand All @@ -146,6 +159,20 @@ contract DirectCallsFromModuleTest is AccountTestBase {
account1.installValidation(validationConfig, selectors, "", hooks);
}

function _installValidationGlobal() internal {
bytes[] memory hooks = new bytes[](1);
hooks[0] = abi.encodePacked(
HookConfigLib.packExecHook({_hookFunction: _moduleEntity, _hasPre: true, _hasPost: true}),
hex"00" // onInstall data
);

vm.prank(address(entryPoint));

ValidationConfig validationConfig = ValidationConfigLib.pack(_moduleEntity, true, false);

account1.installValidation(validationConfig, new bytes4[](0), "", hooks);
}

function _uninstallValidation() internal {
(address module, uint32 entityId) = ModuleEntityLib.unpack(_moduleEntity);
vm.prank(address(entryPoint));
Expand Down
Loading