diff --git a/.env.example b/.env.example index 4d35e4c1..7412b0ad 100644 --- a/.env.example +++ b/.env.example @@ -7,8 +7,8 @@ ENTRYPOINT= # Create2 expected addresses of the contracts. # When running for the first time, the error message will contain the expected addresses. ACCOUNT_IMPL= -FACTORY= SINGLE_SIGNER_VALIDATION= +FACTORY= # Optional, defaults to bytes32(0) ACCOUNT_IMPL_SALT= @@ -18,3 +18,7 @@ SINGLE_SIGNER_VALIDATION_SALT= # Optional, defaults to 0.1 ether and 1 day, respectively STAKE_AMOUNT= UNSTAKE_DELAY= + +# Allowlist Module +ALLOWLIST_MODULE= +ALLOWLIST_MODULE_SALT= diff --git a/.solhint-script.json b/.solhint-script.json new file mode 100644 index 00000000..307eb64c --- /dev/null +++ b/.solhint-script.json @@ -0,0 +1,21 @@ +{ + "extends": "solhint:recommended", + "rules": { + "func-name-mixedcase": "off", + "immutable-vars-naming": ["error"], + "no-unused-import": ["error"], + "compiler-version": ["error", ">=0.8.19"], + "custom-errors": "off", + "no-console": "off", + "func-visibility": ["error", { "ignoreConstructors": true }], + "max-line-length": ["error", 120], + "max-states-count": ["warn", 30], + "modifier-name-mixedcase": ["error"], + "private-vars-leading-underscore": ["error"], + "no-inline-assembly": "warn", + "avoid-low-level-calls": "off", + "one-contract-per-file": "off", + "no-empty-blocks": "off", + "reason-string": "off" + } +} diff --git a/package.json b/package.json index 1c1540e7..0f8b51ea 100644 --- a/package.json +++ b/package.json @@ -4,8 +4,9 @@ "solhint": "^3.6.2" }, "scripts": { - "lint": "pnpm lint:src && pnpm lint:test", + "lint": "pnpm lint:src && pnpm lint:test && pnpm lint:script", "lint:src": "solhint --max-warnings 0 -c .solhint-src.json './src/**/*.sol'", - "lint:test": "solhint --max-warnings 0 -c .solhint-test.json './test/**/*.sol'" + "lint:test": "solhint --max-warnings 0 -c .solhint-test.json './test/**/*.sol'", + "lint:script": "solhint --max-warnings 0 -c .solhint-script.json './script/**/*.sol'" } } diff --git a/script/Deploy.s.sol b/script/Deploy.s.sol index 042a0ab8..6c6afb95 100644 --- a/script/Deploy.s.sol +++ b/script/Deploy.s.sol @@ -2,8 +2,7 @@ pragma solidity ^0.8.25; import {IEntryPoint} from "@eth-infinitism/account-abstraction/interfaces/IEntryPoint.sol"; -import {Script} from "forge-std/Script.sol"; -import {console2} from "forge-std/Test.sol"; +import {Script, console} from "forge-std/Script.sol"; import {Create2} from "@openzeppelin/contracts/utils/Create2.sol"; @@ -28,10 +27,10 @@ contract DeployScript is Script { uint256 public requiredUnstakeDelay = vm.envOr("UNSTAKE_DELAY", uint256(1 days)); function run() public { - console2.log("******** Deploying ERC-6900 Reference Implementation ********"); - console2.log("Chain: ", block.chainid); - console2.log("EP: ", address(entryPoint)); - console2.log("Factory owner: ", owner); + console.log("******** Deploying ERC-6900 Reference Implementation ********"); + console.log("Chain: ", block.chainid); + console.log("EP: ", address(entryPoint)); + console.log("Factory owner: ", owner); vm.startBroadcast(); _deployAccountImpl(accountImplSalt, accountImpl); @@ -42,7 +41,7 @@ contract DeployScript is Script { } function _deployAccountImpl(bytes32 salt, address expected) internal { - console2.log(string.concat("Deploying AccountImpl with salt: ", vm.toString(salt))); + console.log(string.concat("Deploying AccountImpl with salt: ", vm.toString(salt))); address addr = Create2.computeAddress( salt, @@ -50,61 +49,61 @@ contract DeployScript is Script { CREATE2_FACTORY ); if (addr != expected) { - console2.log("Expected address mismatch"); - console2.log("Expected: ", expected); - console2.log("Actual: ", addr); + console.log("Expected address mismatch"); + console.log("Expected: ", expected); + console.log("Actual: ", addr); revert(); } if (addr.code.length == 0) { - console2.log("No code found at expected address, deploying..."); + console.log("No code found at expected address, deploying..."); UpgradeableModularAccount deployed = new UpgradeableModularAccount{salt: salt}(entryPoint); if (address(deployed) != expected) { - console2.log("Deployed address mismatch"); - console2.log("Expected: ", expected); - console2.log("Deployed: ", address(deployed)); + console.log("Deployed address mismatch"); + console.log("Expected: ", expected); + console.log("Deployed: ", address(deployed)); revert(); } - console2.log("Deployed AccountImpl at: ", address(deployed)); + console.log("Deployed AccountImpl at: ", address(deployed)); } else { - console2.log("Code found at expected address, skipping deployment"); + console.log("Code found at expected address, skipping deployment"); } } function _deploySingleSignerValidation(bytes32 salt, address expected) internal { - console2.log(string.concat("Deploying SingleSignerValidation with salt: ", vm.toString(salt))); + console.log(string.concat("Deploying SingleSignerValidation with salt: ", vm.toString(salt))); address addr = Create2.computeAddress( salt, keccak256(abi.encodePacked(type(SingleSignerValidation).creationCode)), CREATE2_FACTORY ); if (addr != expected) { - console2.log("Expected address mismatch"); - console2.log("Expected: ", expected); - console2.log("Actual: ", addr); + console.log("Expected address mismatch"); + console.log("Expected: ", expected); + console.log("Actual: ", addr); revert(); } if (addr.code.length == 0) { - console2.log("No code found at expected address, deploying..."); + console.log("No code found at expected address, deploying..."); SingleSignerValidation deployed = new SingleSignerValidation{salt: salt}(); if (address(deployed) != expected) { - console2.log("Deployed address mismatch"); - console2.log("Expected: ", expected); - console2.log("Deployed: ", address(deployed)); + console.log("Deployed address mismatch"); + console.log("Expected: ", expected); + console.log("Deployed: ", address(deployed)); revert(); } - console2.log("Deployed SingleSignerValidation at: ", address(deployed)); + console.log("Deployed SingleSignerValidation at: ", address(deployed)); } else { - console2.log("Code found at expected address, skipping deployment"); + console.log("Code found at expected address, skipping deployment"); } } function _deployAccountFactory(bytes32 salt, address expected) internal { - console2.log(string.concat("Deploying AccountFactory with salt: ", vm.toString(salt))); + console.log(string.concat("Deploying AccountFactory with salt: ", vm.toString(salt))); address addr = Create2.computeAddress( salt, @@ -117,46 +116,46 @@ contract DeployScript is Script { CREATE2_FACTORY ); if (addr != expected) { - console2.log("Expected address mismatch"); - console2.log("Expected: ", expected); - console2.log("Actual: ", addr); + console.log("Expected address mismatch"); + console.log("Expected: ", expected); + console.log("Actual: ", addr); revert(); } if (addr.code.length == 0) { - console2.log("No code found at expected address, deploying..."); + console.log("No code found at expected address, deploying..."); AccountFactory deployed = new AccountFactory{salt: salt}( entryPoint, UpgradeableModularAccount(payable(accountImpl)), singleSignerValidation, owner ); if (address(deployed) != expected) { - console2.log("Deployed address mismatch"); - console2.log("Expected: ", expected); - console2.log("Deployed: ", address(deployed)); + console.log("Deployed address mismatch"); + console.log("Expected: ", expected); + console.log("Deployed: ", address(deployed)); revert(); } - console2.log("Deployed AccountFactory at: ", address(deployed)); + console.log("Deployed AccountFactory at: ", address(deployed)); } else { - console2.log("Code found at expected address, skipping deployment"); + console.log("Code found at expected address, skipping deployment"); } } function _addStakeForFactory(uint32 unstakeDelay, uint256 stakeAmount) internal { - console2.log("Adding stake to factory"); + console.log("Adding stake to factory"); uint256 currentStake = entryPoint.getDepositInfo(address(factory)).stake; - console2.log("Current stake: ", currentStake); + console.log("Current stake: ", currentStake); uint256 stakeToAdd = stakeAmount - currentStake; if (stakeToAdd > 0) { - console2.log("Adding stake: ", stakeToAdd); - entryPoint.addStake{value: stakeToAdd}(unstakeDelay); - console2.log("Staked factory: ", address(factory)); - console2.log("Total stake amount: ", entryPoint.getDepositInfo(address(factory)).stake); - console2.log("Unstake delay: ", entryPoint.getDepositInfo(address(factory)).unstakeDelaySec); + console.log("Adding stake: ", stakeToAdd); + AccountFactory(factory).addStake{value: stakeToAdd}(unstakeDelay); + console.log("Staked factory: ", address(factory)); + console.log("Total stake amount: ", entryPoint.getDepositInfo(address(factory)).stake); + console.log("Unstake delay: ", entryPoint.getDepositInfo(address(factory)).unstakeDelaySec); } else { - console2.log("No stake to add"); + console.log("No stake to add"); } } } diff --git a/script/DeployAllowlistModule.s.sol b/script/DeployAllowlistModule.s.sol new file mode 100644 index 00000000..42cde52f --- /dev/null +++ b/script/DeployAllowlistModule.s.sol @@ -0,0 +1,51 @@ +// SPDX-License-Identifier: UNLICENSED +pragma solidity ^0.8.25; + +import {Script, console} from "forge-std/Script.sol"; + +import {Create2} from "@openzeppelin/contracts/utils/Create2.sol"; + +import {AllowlistModule} from "../src/modules/permissionhooks/AllowlistModule.sol"; + +contract DeployAllowlistModuleScript is Script { + address public allowlistModule = vm.envOr("ALLOWLIST_MODULE", address(0)); + + bytes32 public allowlistModuleSalt = bytes32(vm.envOr("ALLOWLIST_MODULE_SALT", uint256(0))); + + function run() public { + console.log("******** Deploying AllowlistModule ********"); + console.log("Chain: ", block.chainid); + + vm.startBroadcast(); + _deployAllowlistModule(allowlistModuleSalt, allowlistModule); + vm.stopBroadcast(); + } + + function _deployAllowlistModule(bytes32 salt, address expected) internal { + console.log(string.concat("Deploying AllowlistModule with salt: ", vm.toString(salt))); + + address addr = Create2.computeAddress(salt, keccak256(type(AllowlistModule).creationCode), CREATE2_FACTORY); + if (addr != expected) { + console.log("Expected address mismatch"); + console.log("Expected: ", expected); + console.log("Actual: ", addr); + revert(); + } + + if (addr.code.length == 0) { + console.log("No code found at expected address, deploying..."); + AllowlistModule deployed = new AllowlistModule{salt: salt}(); + + if (address(deployed) != expected) { + console.log("Deployed address mismatch"); + console.log("Expected: ", expected); + console.log("Actual: ", address(deployed)); + revert(); + } + + console.log("Deployed AllowlistModule at: ", address(deployed)); + } else { + console.log("Code found at expected address, skipping deployment"); + } + } +} diff --git a/src/account/AccountFactory.sol b/src/account/AccountFactory.sol index 07a74d09..4119461f 100644 --- a/src/account/AccountFactory.sol +++ b/src/account/AccountFactory.sol @@ -13,7 +13,6 @@ import {ValidationConfigLib} from "../helpers/ValidationConfigLib.sol"; contract AccountFactory is Ownable { UpgradeableModularAccount public immutable ACCOUNT_IMPL; bytes32 private immutable _PROXY_BYTECODE_HASH; - uint32 public constant UNSTAKE_DELAY = 1 weeks; IEntryPoint public immutable ENTRY_POINT; address public immutable SINGLE_SIGNER_VALIDATION; @@ -61,8 +60,8 @@ contract AccountFactory is Ownable { return UpgradeableModularAccount(payable(addr)); } - function addStake() external payable onlyOwner { - ENTRY_POINT.addStake{value: msg.value}(UNSTAKE_DELAY); + function addStake(uint32 unstakeDelay) external payable onlyOwner { + ENTRY_POINT.addStake{value: msg.value}(unstakeDelay); } function unlockStake() external onlyOwner { diff --git a/src/modules/permissionhooks/AllowlistModule.sol b/src/modules/permissionhooks/AllowlistModule.sol index 2fdf1362..0ef310ab 100644 --- a/src/modules/permissionhooks/AllowlistModule.sol +++ b/src/modules/permissionhooks/AllowlistModule.sol @@ -10,10 +10,6 @@ import {IValidationHook} from "../../interfaces/IValidationHook.sol"; import {BaseModule} from "../../modules/BaseModule.sol"; contract AllowlistModule is IValidationHook, BaseModule { - enum EntityId { - PRE_VALIDATION_HOOK - } - struct AllowlistInit { address target; bool hasSelectorAllowlist; @@ -25,61 +21,59 @@ contract AllowlistModule is IValidationHook, BaseModule { bool hasSelectorAllowlist; } - mapping(address target => mapping(address account => AllowlistEntry)) public targetAllowlist; - mapping(address target => mapping(bytes4 selector => mapping(address account => bool))) public - selectorAllowlist; + mapping(uint32 entityId => mapping(address target => mapping(address account => AllowlistEntry))) public + targetAllowlist; + mapping( + uint32 entityId => mapping(address target => mapping(bytes4 selector => mapping(address account => bool))) + ) public selectorAllowlist; + + event AllowlistTargetUpdated( + uint32 indexed entityId, address indexed account, address indexed target, AllowlistEntry entry + ); + event AllowlistSelectorUpdated( + uint32 indexed entityId, address indexed account, bytes24 indexed targetAndSelector, bool allowed + ); error TargetNotAllowed(); error SelectorNotAllowed(); error NoSelectorSpecified(); function onInstall(bytes calldata data) external override { - AllowlistInit[] memory init = abi.decode(data, (AllowlistInit[])); + (uint32 entityId, AllowlistInit[] memory init) = abi.decode(data, (uint32, AllowlistInit[])); for (uint256 i = 0; i < init.length; i++) { - targetAllowlist[init[i].target][msg.sender] = AllowlistEntry(true, init[i].hasSelectorAllowlist); + setAllowlistTarget(entityId, init[i].target, true, init[i].hasSelectorAllowlist); if (init[i].hasSelectorAllowlist) { for (uint256 j = 0; j < init[i].selectors.length; j++) { - selectorAllowlist[init[i].target][init[i].selectors[j]][msg.sender] = true; + setAllowlistSelector(entityId, init[i].target, init[i].selectors[j], true); } } } } function onUninstall(bytes calldata data) external override { - AllowlistInit[] memory init = abi.decode(data, (AllowlistInit[])); + (uint32 entityId, AllowlistInit[] memory init) = abi.decode(data, (uint32, AllowlistInit[])); for (uint256 i = 0; i < init.length; i++) { - delete targetAllowlist[init[i].target][msg.sender]; + setAllowlistTarget(entityId, init[i].target, false, false); if (init[i].hasSelectorAllowlist) { for (uint256 j = 0; j < init[i].selectors.length; j++) { - delete selectorAllowlist[init[i].target][init[i].selectors[j]][msg.sender]; + setAllowlistSelector(entityId, init[i].target, init[i].selectors[j], false); } } } } - function setAllowlistTarget(address target, bool allowed, bool hasSelectorAllowlist) external { - targetAllowlist[target][msg.sender] = AllowlistEntry(allowed, hasSelectorAllowlist); - } - - function setAllowlistSelector(address target, bytes4 selector, bool allowed) external { - selectorAllowlist[target][selector][msg.sender] = allowed; - } - function preUserOpValidationHook(uint32 entityId, PackedUserOperation calldata userOp, bytes32) external view override returns (uint256) { - if (entityId == uint32(EntityId.PRE_VALIDATION_HOOK)) { - _checkAllowlistCalldata(userOp.callData); - return 0; - } - revert NotImplemented(); + checkAllowlistCalldata(entityId, userOp.callData); + return 0; } function preRuntimeValidationHook(uint32 entityId, address, uint256, bytes calldata data, bytes calldata) @@ -87,12 +81,8 @@ contract AllowlistModule is IValidationHook, BaseModule { view override { - if (entityId == uint32(EntityId.PRE_VALIDATION_HOOK)) { - _checkAllowlistCalldata(data); - return; - } - - revert NotImplemented(); + checkAllowlistCalldata(entityId, data); + return; } function moduleMetadata() external pure override returns (ModuleMetadata memory) { @@ -104,21 +94,36 @@ contract AllowlistModule is IValidationHook, BaseModule { return metadata; } - function _checkAllowlistCalldata(bytes calldata callData) internal view { + function setAllowlistTarget(uint32 entityId, address target, bool allowed, bool hasSelectorAllowlist) public { + AllowlistEntry memory entry = AllowlistEntry(allowed, hasSelectorAllowlist); + targetAllowlist[entityId][target][msg.sender] = entry; + emit AllowlistTargetUpdated(entityId, msg.sender, target, entry); + } + + function setAllowlistSelector(uint32 entityId, address target, bytes4 selector, bool allowed) public { + selectorAllowlist[entityId][target][selector][msg.sender] = allowed; + bytes24 targetAndSelector = bytes24(bytes24(bytes20(target)) | (bytes24(selector) >> 160)); + emit AllowlistSelectorUpdated(entityId, msg.sender, targetAndSelector, allowed); + } + + function checkAllowlistCalldata(uint32 entityId, bytes calldata callData) public view { if (bytes4(callData[:4]) == IStandardExecutor.execute.selector) { (address target,, bytes memory data) = abi.decode(callData[4:], (address, uint256, bytes)); - _checkCallPermission(msg.sender, target, data); + _checkCallPermission(entityId, msg.sender, target, data); } else if (bytes4(callData[:4]) == IStandardExecutor.executeBatch.selector) { Call[] memory calls = abi.decode(callData[4:], (Call[])); for (uint256 i = 0; i < calls.length; i++) { - _checkCallPermission(msg.sender, calls[i].target, calls[i].data); + _checkCallPermission(entityId, msg.sender, calls[i].target, calls[i].data); } } } - function _checkCallPermission(address account, address target, bytes memory data) internal view { - AllowlistEntry storage entry = targetAllowlist[target][account]; + function _checkCallPermission(uint32 entityId, address account, address target, bytes memory data) + internal + view + { + AllowlistEntry storage entry = targetAllowlist[entityId][target][account]; (bool allowed, bool hasSelectorAllowlist) = (entry.allowed, entry.hasSelectorAllowlist); if (!allowed) { @@ -132,7 +137,7 @@ contract AllowlistModule is IValidationHook, BaseModule { bytes4 selector = bytes4(data); - if (!selectorAllowlist[target][selector][account]) { + if (!selectorAllowlist[entityId][target][selector][account]) { revert SelectorNotAllowed(); } } diff --git a/test/module/AllowlistModule.t.sol b/test/module/AllowlistModule.t.sol index b0e09c90..5f0f243c 100644 --- a/test/module/AllowlistModule.t.sol +++ b/test/module/AllowlistModule.t.sol @@ -21,6 +21,18 @@ contract AllowlistModuleTest is CustomValidationTestBase { Counter[] public counters; + uint32 public constant HOOK_ENTITY_ID = 0; + + event AllowlistTargetUpdated( + uint32 indexed entityId, + address indexed account, + address indexed target, + AllowlistModule.AllowlistEntry entry + ); + event AllowlistSelectorUpdated( + uint32 indexed entityId, address indexed account, bytes24 indexed targetAndSelector, bool allowed + ); + function setUp() public { allowlistModule = new AllowlistModule(); @@ -97,6 +109,34 @@ contract AllowlistModuleTest is CustomValidationTestBase { } } + function _beforeInstallStep(address accountImpl) internal override { + // Expect events to be emitted from onInstall + for (uint256 i = 0; i < allowlistInit.length; i++) { + vm.expectEmit(address(allowlistModule)); + emit AllowlistTargetUpdated( + HOOK_ENTITY_ID, + accountImpl, + allowlistInit[i].target, + AllowlistModule.AllowlistEntry({ + allowed: true, + hasSelectorAllowlist: allowlistInit[i].hasSelectorAllowlist + }) + ); + + if (!allowlistInit[i].hasSelectorAllowlist) { + continue; + } + + for (uint256 j = 0; j < allowlistInit[i].selectors.length; j++) { + bytes24 targetAndSelector = bytes24( + bytes24(bytes20(allowlistInit[i].target)) | (bytes24(allowlistInit[i].selectors[j]) >> 160) + ); + vm.expectEmit(address(allowlistModule)); + emit AllowlistSelectorUpdated(HOOK_ENTITY_ID, accountImpl, targetAndSelector, true); + } + } + } + function _generateRandomCalls(uint256 seed) internal view returns (Call[] memory, uint256) { uint256 length = seed % 10; seed = _next(seed); @@ -147,11 +187,13 @@ contract AllowlistModuleTest is CustomValidationTestBase { Call memory call = calls[i]; (bool allowed, bool hasSelectorAllowlist) = - allowlistModule.targetAllowlist(call.target, address(account1)); + allowlistModule.targetAllowlist(HOOK_ENTITY_ID, call.target, address(account1)); if (allowed) { if ( hasSelectorAllowlist - && !allowlistModule.selectorAllowlist(call.target, bytes4(call.data), address(account1)) + && !allowlistModule.selectorAllowlist( + HOOK_ENTITY_ID, call.target, bytes4(call.data), address(account1) + ) ) { return abi.encodeWithSelector( IEntryPoint.FailedOpWithRevert.selector, @@ -178,16 +220,18 @@ contract AllowlistModuleTest is CustomValidationTestBase { Call memory call = calls[i]; (bool allowed, bool hasSelectorAllowlist) = - allowlistModule.targetAllowlist(call.target, address(account1)); + allowlistModule.targetAllowlist(HOOK_ENTITY_ID, call.target, address(account1)); if (allowed) { if ( hasSelectorAllowlist - && !allowlistModule.selectorAllowlist(call.target, bytes4(call.data), address(account1)) + && !allowlistModule.selectorAllowlist( + HOOK_ENTITY_ID, call.target, bytes4(call.data), address(account1) + ) ) { return abi.encodeWithSelector( UpgradeableModularAccount.PreRuntimeValidationHookFailed.selector, address(allowlistModule), - uint32(AllowlistModule.EntityId.PRE_VALIDATION_HOOK), + HOOK_ENTITY_ID, abi.encodeWithSelector(AllowlistModule.SelectorNotAllowed.selector) ); } @@ -195,7 +239,7 @@ contract AllowlistModuleTest is CustomValidationTestBase { return abi.encodeWithSelector( UpgradeableModularAccount.PreRuntimeValidationHookFailed.selector, address(allowlistModule), - uint32(AllowlistModule.EntityId.PRE_VALIDATION_HOOK), + HOOK_ENTITY_ID, abi.encodeWithSelector(AllowlistModule.TargetNotAllowed.selector) ); } @@ -279,12 +323,6 @@ contract AllowlistModuleTest is CustomValidationTestBase { return (init, seed); } - // todo: runtime paths - - // fuzz targets, fuzz target selectors. - - // Maybe pull out the helper function for running user ops and possibly expect a failure? - function _next(uint256 seed) internal pure returns (uint256) { return uint256(keccak256(abi.encodePacked(seed))); } @@ -297,10 +335,8 @@ contract AllowlistModuleTest is CustomValidationTestBase { { bytes[] memory hooks = new bytes[](1); hooks[0] = abi.encodePacked( - HookConfigLib.packValidationHook( - address(allowlistModule), uint32(AllowlistModule.EntityId.PRE_VALIDATION_HOOK) - ), - abi.encode(allowlistInit) + HookConfigLib.packValidationHook(address(allowlistModule), HOOK_ENTITY_ID), + abi.encode(HOOK_ENTITY_ID, allowlistInit) ); return ( diff --git a/test/script/Deploy.s.t.sol b/test/script/Deploy.s.t.sol index 105ce9b8..b25bf7fe 100644 --- a/test/script/Deploy.s.t.sol +++ b/test/script/Deploy.s.t.sol @@ -4,6 +4,7 @@ pragma solidity ^0.8.25; import {Test} from "forge-std/Test.sol"; import {EntryPoint} from "@eth-infinitism/account-abstraction/core/EntryPoint.sol"; +import {IStakeManager} from "@eth-infinitism/account-abstraction/interfaces/IStakeManager.sol"; import {Create2} from "@openzeppelin/contracts/utils/Create2.sol"; import {DeployScript} from "../../script/Deploy.s.sol"; @@ -25,7 +26,10 @@ contract DeployTest is Test { function setUp() public { _entryPoint = new EntryPoint(); - _owner = makeAddr("OWNER"); + + // Set the owner to the foundry default sender, as this is what will be used as the sender within the + // `startBroadcast` segment of the script. + _owner = DEFAULT_SENDER; vm.setEnv("ENTRYPOINT", vm.toString(address(_entryPoint))); vm.setEnv("OWNER", vm.toString(_owner)); @@ -72,5 +76,34 @@ contract DeployTest is Test { assertTrue(_accountImpl.code.length > 0); assertTrue(_factory.code.length > 0); assertTrue(_singleSignerValidation.code.length > 0); + + assertEq( + _singleSignerValidation.code, + type(SingleSignerValidation).runtimeCode, + "SingleSignerValidation runtime code mismatch" + ); + + // Check factory stake + IStakeManager.DepositInfo memory depositInfo = _entryPoint.getDepositInfo(_factory); + + assertTrue(depositInfo.staked, "Factory not staked"); + assertEq(depositInfo.stake, 0.1 ether, "Unexpected factory stake amount"); + assertEq(depositInfo.unstakeDelaySec, 1 days, "Unexpected factory unstake delay"); + } + + function test_deployScript_addStake() public { + test_deployScript_run(); + + vm.setEnv("STAKE_AMOUNT", vm.toString(uint256(0.3 ether))); + + // Refresh script's env vars + + _deployScript = new DeployScript(); + + _deployScript.run(); + + IStakeManager.DepositInfo memory depositInfo = _entryPoint.getDepositInfo(_factory); + + assertEq(depositInfo.stake, 0.3 ether, "Unexpected factory stake amount"); } } diff --git a/test/script/DeployAllowlistModule.s.t.sol b/test/script/DeployAllowlistModule.s.t.sol new file mode 100644 index 00000000..27d4666a --- /dev/null +++ b/test/script/DeployAllowlistModule.s.t.sol @@ -0,0 +1,32 @@ +// SPDX-License-Identifier: UNLICENSED +pragma solidity ^0.8.25; + +import {Test} from "forge-std/Test.sol"; + +import {Create2} from "@openzeppelin/contracts/utils/Create2.sol"; + +import {DeployAllowlistModuleScript} from "../../script/DeployAllowlistModule.s.sol"; + +import {AllowlistModule} from "../../src/modules/permissionhooks/AllowlistModule.sol"; + +contract DeployAllowlistModuleTest is Test { + DeployAllowlistModuleScript internal _deployScript; + + address internal _allowlistModule; + + function setUp() public { + _allowlistModule = + Create2.computeAddress(bytes32(0), keccak256(type(AllowlistModule).creationCode), CREATE2_FACTORY); + + vm.setEnv("ALLOWLIST_MODULE", vm.toString(address(_allowlistModule))); + + _deployScript = new DeployAllowlistModuleScript(); + } + + function test_deployAllowlistModuleScript_run() public { + _deployScript.run(); + + assertTrue(_allowlistModule.code.length > 0, "AllowlistModule not deployed"); + assertEq(_allowlistModule.code, type(AllowlistModule).runtimeCode, "AllowlistModule runtime code mismatch"); + } +} diff --git a/test/utils/CustomValidationTestBase.sol b/test/utils/CustomValidationTestBase.sol index 3b313039..a7920623 100644 --- a/test/utils/CustomValidationTestBase.sol +++ b/test/utils/CustomValidationTestBase.sol @@ -28,6 +28,8 @@ abstract contract CustomValidationTestBase is AccountTestBase { account1 = UpgradeableModularAccount(payable(new ERC1967Proxy{salt: 0}(accountImplementation, ""))); + _beforeInstallStep(address(account1)); + account1.initializeWithValidation( ValidationConfigLib.pack(validationFunction, isGlobal, isSignatureValidation), selectors, @@ -49,4 +51,12 @@ abstract contract CustomValidationTestBase is AccountTestBase { bytes memory installData, bytes[] memory hooks ); + + // If the test needs to perform any setup or checks after the account is created, but before the call to + // `initializeWithValidation`, + // it should override this function. + function _beforeInstallStep(address accountImpl) internal virtual { + // Does nothing by default + (accountImpl); + } }