Skip to content

Commit

Permalink
feat: update allowlist module to use entity id
Browse files Browse the repository at this point in the history
  • Loading branch information
adamegyed committed Aug 1, 2024
1 parent f9b20ad commit d3922e2
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 44 deletions.
65 changes: 31 additions & 34 deletions src/modules/permissionhooks/AllowlistModule.sol
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,6 @@ import {IValidationHook} from "../../interfaces/IValidationHook.sol";
import {BaseModule} from "../../modules/BaseModule.sol";

contract AllowlistModule is IValidationHook, BaseModule {
enum EntityId {
PRE_VALIDATION_HOOK
}

struct AllowlistInit {
address target;
bool hasSelectorAllowlist;
Expand All @@ -25,48 +21,53 @@ contract AllowlistModule is IValidationHook, BaseModule {
bool hasSelectorAllowlist;
}

mapping(address target => mapping(address account => AllowlistEntry)) public targetAllowlist;
mapping(address target => mapping(bytes4 selector => mapping(address account => bool))) public
selectorAllowlist;
mapping(uint32 entityId => mapping(address target => mapping(address account => AllowlistEntry))) public
targetAllowlist;
mapping(
uint32 entityId => mapping(address target => mapping(bytes4 selector => mapping(address account => bool)))
) public selectorAllowlist;

error TargetNotAllowed();
error SelectorNotAllowed();
error NoSelectorSpecified();

function onInstall(bytes calldata data) external override {
AllowlistInit[] memory init = abi.decode(data, (AllowlistInit[]));
(uint32 entityId, AllowlistInit[] memory init) = abi.decode(data, (uint32, AllowlistInit[]));

for (uint256 i = 0; i < init.length; i++) {
targetAllowlist[init[i].target][msg.sender] = AllowlistEntry(true, init[i].hasSelectorAllowlist);
targetAllowlist[entityId][init[i].target][msg.sender] =
AllowlistEntry(true, init[i].hasSelectorAllowlist);

if (init[i].hasSelectorAllowlist) {
for (uint256 j = 0; j < init[i].selectors.length; j++) {
selectorAllowlist[init[i].target][init[i].selectors[j]][msg.sender] = true;
selectorAllowlist[entityId][init[i].target][init[i].selectors[j]][msg.sender] = true;
}
}
}
}

function onUninstall(bytes calldata data) external override {
AllowlistInit[] memory init = abi.decode(data, (AllowlistInit[]));
(uint32 entityId, AllowlistInit[] memory init) = abi.decode(data, (uint32, AllowlistInit[]));

for (uint256 i = 0; i < init.length; i++) {
delete targetAllowlist[init[i].target][msg.sender];
delete targetAllowlist[entityId][init[i].target][msg.sender];

if (init[i].hasSelectorAllowlist) {
for (uint256 j = 0; j < init[i].selectors.length; j++) {
delete selectorAllowlist[init[i].target][init[i].selectors[j]][msg.sender];
delete selectorAllowlist[entityId][init[i].target][init[i].selectors[j]][msg.sender];
}
}
}
}

function setAllowlistTarget(address target, bool allowed, bool hasSelectorAllowlist) external {
targetAllowlist[target][msg.sender] = AllowlistEntry(allowed, hasSelectorAllowlist);
function setAllowlistTarget(uint32 entityId, address target, bool allowed, bool hasSelectorAllowlist)
external
{
targetAllowlist[entityId][target][msg.sender] = AllowlistEntry(allowed, hasSelectorAllowlist);
}

function setAllowlistSelector(address target, bytes4 selector, bool allowed) external {
selectorAllowlist[target][selector][msg.sender] = allowed;
function setAllowlistSelector(uint32 entityId, address target, bytes4 selector, bool allowed) external {
selectorAllowlist[entityId][target][selector][msg.sender] = allowed;
}

function preUserOpValidationHook(uint32 entityId, PackedUserOperation calldata userOp, bytes32)
Expand All @@ -75,24 +76,17 @@ contract AllowlistModule is IValidationHook, BaseModule {
override
returns (uint256)
{
if (entityId == uint32(EntityId.PRE_VALIDATION_HOOK)) {
_checkAllowlistCalldata(userOp.callData);
return 0;
}
revert NotImplemented();
checkAllowlistCalldata(entityId, userOp.callData);
return 0;
}

function preRuntimeValidationHook(uint32 entityId, address, uint256, bytes calldata data, bytes calldata)
external
view
override
{
if (entityId == uint32(EntityId.PRE_VALIDATION_HOOK)) {
_checkAllowlistCalldata(data);
return;
}

revert NotImplemented();
checkAllowlistCalldata(entityId, data);
return;
}

function moduleMetadata() external pure override returns (ModuleMetadata memory) {
Expand All @@ -104,21 +98,24 @@ contract AllowlistModule is IValidationHook, BaseModule {
return metadata;
}

function _checkAllowlistCalldata(bytes calldata callData) internal view {
function checkAllowlistCalldata(uint32 entityId, bytes calldata callData) public view {
if (bytes4(callData[:4]) == IStandardExecutor.execute.selector) {
(address target,, bytes memory data) = abi.decode(callData[4:], (address, uint256, bytes));
_checkCallPermission(msg.sender, target, data);
_checkCallPermission(entityId, msg.sender, target, data);
} else if (bytes4(callData[:4]) == IStandardExecutor.executeBatch.selector) {
Call[] memory calls = abi.decode(callData[4:], (Call[]));

for (uint256 i = 0; i < calls.length; i++) {
_checkCallPermission(msg.sender, calls[i].target, calls[i].data);
_checkCallPermission(entityId, msg.sender, calls[i].target, calls[i].data);
}
}
}

function _checkCallPermission(address account, address target, bytes memory data) internal view {
AllowlistEntry storage entry = targetAllowlist[target][account];
function _checkCallPermission(uint32 entityId, address account, address target, bytes memory data)
internal
view
{
AllowlistEntry storage entry = targetAllowlist[entityId][target][account];
(bool allowed, bool hasSelectorAllowlist) = (entry.allowed, entry.hasSelectorAllowlist);

if (!allowed) {
Expand All @@ -132,7 +129,7 @@ contract AllowlistModule is IValidationHook, BaseModule {

bytes4 selector = bytes4(data);

if (!selectorAllowlist[target][selector][account]) {
if (!selectorAllowlist[entityId][target][selector][account]) {
revert SelectorNotAllowed();
}
}
Expand Down
24 changes: 14 additions & 10 deletions test/module/AllowlistModule.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ contract AllowlistModuleTest is CustomValidationTestBase {

Counter[] public counters;

uint32 public constant HOOK_ENTITY_ID = 0;

function setUp() public {
allowlistModule = new AllowlistModule();

Expand Down Expand Up @@ -147,11 +149,13 @@ contract AllowlistModuleTest is CustomValidationTestBase {
Call memory call = calls[i];

(bool allowed, bool hasSelectorAllowlist) =
allowlistModule.targetAllowlist(call.target, address(account1));
allowlistModule.targetAllowlist(HOOK_ENTITY_ID, call.target, address(account1));
if (allowed) {
if (
hasSelectorAllowlist
&& !allowlistModule.selectorAllowlist(call.target, bytes4(call.data), address(account1))
&& !allowlistModule.selectorAllowlist(
HOOK_ENTITY_ID, call.target, bytes4(call.data), address(account1)
)
) {
return abi.encodeWithSelector(
IEntryPoint.FailedOpWithRevert.selector,
Expand All @@ -178,24 +182,26 @@ contract AllowlistModuleTest is CustomValidationTestBase {
Call memory call = calls[i];

(bool allowed, bool hasSelectorAllowlist) =
allowlistModule.targetAllowlist(call.target, address(account1));
allowlistModule.targetAllowlist(HOOK_ENTITY_ID, call.target, address(account1));
if (allowed) {
if (
hasSelectorAllowlist
&& !allowlistModule.selectorAllowlist(call.target, bytes4(call.data), address(account1))
&& !allowlistModule.selectorAllowlist(
HOOK_ENTITY_ID, call.target, bytes4(call.data), address(account1)
)
) {
return abi.encodeWithSelector(
UpgradeableModularAccount.PreRuntimeValidationHookFailed.selector,
address(allowlistModule),
uint32(AllowlistModule.EntityId.PRE_VALIDATION_HOOK),
HOOK_ENTITY_ID,
abi.encodeWithSelector(AllowlistModule.SelectorNotAllowed.selector)
);
}
} else {
return abi.encodeWithSelector(
UpgradeableModularAccount.PreRuntimeValidationHookFailed.selector,
address(allowlistModule),
uint32(AllowlistModule.EntityId.PRE_VALIDATION_HOOK),
HOOK_ENTITY_ID,
abi.encodeWithSelector(AllowlistModule.TargetNotAllowed.selector)
);
}
Expand Down Expand Up @@ -297,10 +303,8 @@ contract AllowlistModuleTest is CustomValidationTestBase {
{
bytes[] memory hooks = new bytes[](1);
hooks[0] = abi.encodePacked(
HookConfigLib.packValidationHook(
address(allowlistModule), uint32(AllowlistModule.EntityId.PRE_VALIDATION_HOOK)
),
abi.encode(allowlistInit)
HookConfigLib.packValidationHook(address(allowlistModule), HOOK_ENTITY_ID),
abi.encode(HOOK_ENTITY_ID, allowlistInit)
);

return (
Expand Down

0 comments on commit d3922e2

Please sign in to comment.