Skip to content

Commit

Permalink
feat: support global direct call validation
Browse files Browse the repository at this point in the history
  • Loading branch information
adamegyed committed Aug 27, 2024
1 parent 7c85a43 commit 76f4270
Show file tree
Hide file tree
Showing 2 changed files with 93 additions and 28 deletions.
66 changes: 52 additions & 14 deletions src/account/ReferenceModularAccount.sol
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,12 @@ contract ReferenceModularAccount is
ModuleEntity postExecHook;
}

enum ValidationCheckingType {
GLOBAL,
SELECTOR,
EITHER
}

IEntryPoint private immutable _ENTRY_POINT;

// As per the EIP-165 spec, no interface should ever match 0xffffffff
Expand Down Expand Up @@ -187,7 +193,11 @@ contract ReferenceModularAccount is

// Check if the runtime validation function is allowed to be called
bool isGlobalValidation = uint8(authorization[24]) == 1;
_checkIfValidationAppliesCallData(data, runtimeValidationFunction, isGlobalValidation);
_checkIfValidationAppliesCallData(
data,
runtimeValidationFunction,
isGlobalValidation ? ValidationCheckingType.GLOBAL : ValidationCheckingType.SELECTOR
);

_doRuntimeValidation(runtimeValidationFunction, data, authorization[25:]);

Expand Down Expand Up @@ -343,7 +353,11 @@ contract ReferenceModularAccount is
ModuleEntity userOpValidationFunction = ModuleEntity.wrap(bytes24(userOp.signature[:24]));
bool isGlobalValidation = uint8(userOp.signature[24]) == 1;

_checkIfValidationAppliesCallData(userOp.callData, userOpValidationFunction, isGlobalValidation);
_checkIfValidationAppliesCallData(
userOp.callData,
userOpValidationFunction,
isGlobalValidation ? ValidationCheckingType.GLOBAL : ValidationCheckingType.SELECTOR
);

// Check if there are permission hooks associated with the validator, and revert if the call isn't to
// `executeUserOp`
Expand Down Expand Up @@ -549,7 +563,7 @@ contract ReferenceModularAccount is
ModuleEntity directCallValidationKey =
ModuleEntityLib.pack(msg.sender, DIRECT_CALL_VALIDATION_ENTITYID);

_checkIfValidationAppliesCallData(msg.data, directCallValidationKey, false);
_checkIfValidationAppliesCallData(msg.data, directCallValidationKey, ValidationCheckingType.EITHER);

// Direct call is allowed, run associated permission & validation hooks

Expand Down Expand Up @@ -645,7 +659,7 @@ contract ReferenceModularAccount is
function _checkIfValidationAppliesCallData(
bytes calldata callData,
ModuleEntity validationFunction,
bool isGlobal
ValidationCheckingType checkingType
) internal view {
bytes4 outerSelector = bytes4(callData[:4]);
if (outerSelector == this.executeUserOp.selector) {
Expand All @@ -655,7 +669,7 @@ contract ReferenceModularAccount is
outerSelector = bytes4(callData[:4]);
}

_checkIfValidationAppliesSelector(outerSelector, validationFunction, isGlobal);
_checkIfValidationAppliesSelector(outerSelector, validationFunction, checkingType);

if (outerSelector == IModularAccount.execute.selector) {
(address target,,) = abi.decode(callData[4:], (address, uint256, bytes));
Expand Down Expand Up @@ -689,26 +703,50 @@ contract ReferenceModularAccount is
revert SelfCallRecursionDepthExceeded();
}

_checkIfValidationAppliesSelector(nestedSelector, validationFunction, isGlobal);
_checkIfValidationAppliesSelector(nestedSelector, validationFunction, checkingType);
}
}
}
}

function _checkIfValidationAppliesSelector(bytes4 selector, ModuleEntity validationFunction, bool isGlobal)
internal
view
{
function _checkIfValidationAppliesSelector(
bytes4 selector,
ModuleEntity validationFunction,
ValidationCheckingType checkingType
) internal view {
// Check that the provided validation function is applicable to the selector
if (isGlobal) {
if (!_globalValidationAllowed(selector) || !_isValidationGlobal(validationFunction)) {

if (checkingType == ValidationCheckingType.GLOBAL) {
if (!_globalValidationApplies(selector, validationFunction)) {
revert ValidationFunctionMissing(selector);
}
} else if (checkingType == ValidationCheckingType.SELECTOR) {
if (!_selectorValidationApplies(selector, validationFunction)) {
revert ValidationFunctionMissing(selector);
}
} else {
// Not global validation, but per-selector
if (!getAccountStorage().validationData[validationFunction].selectors.contains(toSetValue(selector))) {
if (
!_globalValidationApplies(selector, validationFunction)
&& !_selectorValidationApplies(selector, validationFunction)
) {
revert ValidationFunctionMissing(selector);
}
}
}

function _globalValidationApplies(bytes4 selector, ModuleEntity validationFunction)
internal
view
returns (bool)
{
return _globalValidationAllowed(selector) && _isValidationGlobal(validationFunction);
}

function _selectorValidationApplies(bytes4 selector, ModuleEntity validationFunction)
internal
view
returns (bool)
{
return getAccountStorage().validationData[validationFunction].selectors.contains(toSetValue(selector));
}
}
55 changes: 41 additions & 14 deletions test/account/DirectCallsFromModule.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,15 @@ contract DirectCallsFromModuleTest is AccountTestBase {

event ValidationUninstalled(address indexed module, uint32 indexed entityId, bool onUninstallSucceeded);

modifier randomizedValidationType(bool selectorValidation) {
if (selectorValidation) {
_installValidationSelector();
} else {
_installValidationGlobal();
}
_;
}

function setUp() public {
_module = new DirectCallModule();
assertFalse(_module.preHookRan());
Expand All @@ -38,9 +47,10 @@ contract DirectCallsFromModuleTest is AccountTestBase {
account1.execute(address(0), 0, "");
}

function test_Fail_DirectCallModuleUninstalled() external {
_installValidation();

function testFuzz_Fail_DirectCallModuleUninstalled(bool validationType)
external
randomizedValidationType(validationType)
{
_uninstallValidation();

vm.prank(address(_module));
Expand All @@ -49,7 +59,7 @@ contract DirectCallsFromModuleTest is AccountTestBase {
}

function test_Fail_DirectCallModuleCallOtherSelector() external {
_installValidation();
_installValidationSelector();

Call[] memory calls = new Call[](0);

Expand All @@ -62,19 +72,21 @@ contract DirectCallsFromModuleTest is AccountTestBase {
/* Positives */
/* -------------------------------------------------------------------------- */

function test_Pass_DirectCallFromModulePrank() external {
_installValidation();

function testFuzz_Pass_DirectCallFromModulePrank(bool validationType)
external
randomizedValidationType(validationType)
{
vm.prank(address(_module));
account1.execute(address(0), 0, "");

assertTrue(_module.preHookRan());
assertTrue(_module.postHookRan());
}

function test_Pass_DirectCallFromModuleCallback() external {
_installValidation();

function testFuzz_Pass_DirectCallFromModuleCallback(bool validationType)
external
randomizedValidationType(validationType)
{
bytes memory encodedCall = abi.encodeCall(DirectCallModule.directCall, ());

vm.prank(address(entryPoint));
Expand All @@ -88,11 +100,12 @@ contract DirectCallsFromModuleTest is AccountTestBase {
assertEq(abi.decode(result, (bytes)), abi.encode(_module.getData()));
}

function test_Flow_DirectCallFromModuleSequence() external {
function testFuzz_Flow_DirectCallFromModuleSequence(bool validationType)
external
randomizedValidationType(validationType)
{
// Install => Succeesfully call => uninstall => fail to call

_installValidation();

vm.prank(address(_module));
account1.execute(address(0), 0, "");

Expand Down Expand Up @@ -129,7 +142,7 @@ contract DirectCallsFromModuleTest is AccountTestBase {
/* Internals */
/* -------------------------------------------------------------------------- */

function _installValidation() internal {
function _installValidationSelector() internal {
bytes4[] memory selectors = new bytes4[](1);
selectors[0] = IModularAccount.execute.selector;

Expand All @@ -146,6 +159,20 @@ contract DirectCallsFromModuleTest is AccountTestBase {
account1.installValidation(validationConfig, selectors, "", hooks);
}

function _installValidationGlobal() internal {
bytes[] memory hooks = new bytes[](1);
hooks[0] = abi.encodePacked(
HookConfigLib.packExecHook({_hookFunction: _moduleEntity, _hasPre: true, _hasPost: true}),
hex"00" // onInstall data
);

vm.prank(address(entryPoint));

ValidationConfig validationConfig = ValidationConfigLib.pack(_moduleEntity, true, false);

account1.installValidation(validationConfig, new bytes4[](0), "", hooks);
}

function _uninstallValidation() internal {
(address module, uint32 entityId) = ModuleEntityLib.unpack(_moduleEntity);
vm.prank(address(entryPoint));
Expand Down

0 comments on commit 76f4270

Please sign in to comment.