diff --git a/.env.example b/.env.example index 612c3436..a7bd543e 100644 --- a/.env.example +++ b/.env.example @@ -1,4 +1,3 @@ - # Factory owner capable only of managing stake OWNER= # EP 0.7 address @@ -22,3 +21,6 @@ UNSTAKE_DELAY= # Allowlist Module ALLOWLIST_MODULE= ALLOWLIST_MODULE_SALT= + +# Whether to test the semi-modular or full modular account +SMA_TEST=false diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index dfa5cf41..3187d269 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -66,7 +66,10 @@ jobs: run: FOUNDRY_PROFILE=optimized-build forge build - name: Run tests - run: FOUNDRY_PROFILE=optimized-test-deep forge test -vvv + run: FOUNDRY_PROFILE=optimized-test-deep SMA_TEST=false forge test -vvv + + - name: Run SMA tests + run: FOUNDRY_PROFILE=optimized-test-deep SMA_TEST=true forge test -vvv test-default: name: Run Forge Tests (default) @@ -88,4 +91,7 @@ jobs: run: forge build - name: Run tests - run: forge test -vvv + run: SMA_TEST=false forge test -vvv + + - name: Run SMA tests + run: SMA_TEST=true forge test -vvv diff --git a/.gitmodules b/.gitmodules index 05bd137f..229aff1b 100644 --- a/.gitmodules +++ b/.gitmodules @@ -10,3 +10,6 @@ [submodule "lib/modular-account-libs"] path = lib/modular-account-libs url = https://github.com/erc6900/modular-account-libs +[submodule "lib/solady"] + path = lib/solady + url = https://github.com/vectorized/solady diff --git a/foundry.toml b/foundry.toml index 516e1a1f..56131b8f 100644 --- a/foundry.toml +++ b/foundry.toml @@ -25,7 +25,7 @@ depth = 10 [profile.optimized-build] via_ir = true test = 'src' -optimizer_runs = 15000 +optimizer_runs = 10000 out = 'out-optimized' [profile.optimized-test] diff --git a/lib/solady b/lib/solady new file mode 160000 index 00000000..a1f9be98 --- /dev/null +++ b/lib/solady @@ -0,0 +1 @@ +Subproject commit a1f9be988d3c12655692cb8cdfc6864cc393cff6 diff --git a/remappings.txt b/remappings.txt index bc2ce0be..8d9639cf 100644 --- a/remappings.txt +++ b/remappings.txt @@ -2,4 +2,5 @@ ds-test/=lib/forge-std/lib/ds-test/src/ forge-std/=lib/forge-std/src/ @eth-infinitism/account-abstraction/=lib/account-abstraction/contracts/ @openzeppelin/=lib/openzeppelin-contracts/ -@modular-account-libs/=lib/modular-account-libs/src/ \ No newline at end of file +@modular-account-libs/=lib/modular-account-libs/src/ +solady=lib/solady/src/ diff --git a/script/Deploy.s.sol b/script/Deploy.s.sol index 1ad11be5..efbf0f05 100644 --- a/script/Deploy.s.sol +++ b/script/Deploy.s.sol @@ -7,6 +7,8 @@ import {Script, console} from "forge-std/Script.sol"; import {Create2} from "@openzeppelin/contracts/utils/Create2.sol"; import {AccountFactory} from "../src/account/AccountFactory.sol"; + +import {SemiModularAccount} from "../src/account/SemiModularAccount.sol"; import {UpgradeableModularAccount} from "../src/account/UpgradeableModularAccount.sol"; import {SingleSignerValidationModule} from "../src/modules/validation/SingleSignerValidationModule.sol"; @@ -16,10 +18,12 @@ contract DeployScript is Script { address public owner = vm.envAddress("OWNER"); address public accountImpl = vm.envOr("ACCOUNT_IMPL", address(0)); + address public semiModularAccountImpl = vm.envOr("SMA_IMPL", address(0)); address public factory = vm.envOr("FACTORY", address(0)); address public singleSignerValidationModule = vm.envOr("SINGLE_SIGNER_VALIDATION_MODULE", address(0)); bytes32 public accountImplSalt = bytes32(vm.envOr("ACCOUNT_IMPL_SALT", uint256(0))); + bytes32 public semiModularAccountImplSalt = bytes32(vm.envOr("SMA_IMPL_SALT", uint256(0))); bytes32 public factorySalt = bytes32(vm.envOr("FACTORY_SALT", uint256(0))); bytes32 public singleSignerValidationModuleSalt = bytes32(vm.envOr("SINGLE_SIGNER_VALIDATION_MODULE_SALT", uint256(0))); @@ -35,7 +39,8 @@ contract DeployScript is Script { vm.startBroadcast(); _deployAccountImpl(accountImplSalt, accountImpl); - _deploySingleSignerValidationModule(singleSignerValidationModuleSalt, singleSignerValidationModule); + _deploySemiModularAccountImpl(semiModularAccountImplSalt, semiModularAccountImpl); + _deploySingleSignerValidation(singleSignerValidationModuleSalt, singleSignerValidationModule); _deployAccountFactory(factorySalt, factory); _addStakeForFactory(uint32(requiredUnstakeDelay), requiredStakeAmount); vm.stopBroadcast(); @@ -73,7 +78,39 @@ contract DeployScript is Script { } } - function _deploySingleSignerValidationModule(bytes32 salt, address expected) internal { + function _deploySemiModularAccountImpl(bytes32 salt, address expected) internal { + console.log(string.concat("Deploying SemiModularAccountImpl with salt: ", vm.toString(salt))); + + address addr = Create2.computeAddress( + salt, + keccak256(abi.encodePacked(type(SemiModularAccount).creationCode, abi.encode(entryPoint))), + CREATE2_FACTORY + ); + if (addr != expected) { + console.log("Expected address mismatch"); + console.log("Expected: ", expected); + console.log("Actual: ", addr); + revert(); + } + + if (addr.code.length == 0) { + console.log("No code found at expected address, deploying..."); + SemiModularAccount deployed = new SemiModularAccount{salt: salt}(entryPoint); + + if (address(deployed) != expected) { + console.log("Deployed address mismatch"); + console.log("Expected: ", expected); + console.log("Deployed: ", address(deployed)); + revert(); + } + + console.log("Deployed SemiModularAccount at: ", address(deployed)); + } else { + console.log("Code found at expected address, skipping deployment"); + } + } + + function _deploySingleSignerValidation(bytes32 salt, address expected) internal { console.log(string.concat("Deploying SingleSignerValidationModule with salt: ", vm.toString(salt))); address addr = Create2.computeAddress( @@ -111,7 +148,9 @@ contract DeployScript is Script { keccak256( abi.encodePacked( type(AccountFactory).creationCode, - abi.encode(entryPoint, accountImpl, singleSignerValidationModule, owner) + abi.encode( + entryPoint, accountImpl, semiModularAccountImpl, singleSignerValidationModule, owner + ) ) ), CREATE2_FACTORY @@ -126,7 +165,11 @@ contract DeployScript is Script { if (addr.code.length == 0) { console.log("No code found at expected address, deploying..."); AccountFactory deployed = new AccountFactory{salt: salt}( - entryPoint, UpgradeableModularAccount(payable(accountImpl)), singleSignerValidationModule, owner + entryPoint, + UpgradeableModularAccount(payable(accountImpl)), + SemiModularAccount(payable(semiModularAccountImpl)), + singleSignerValidationModule, + owner ); if (address(deployed) != expected) { diff --git a/src/account/AccountFactory.sol b/src/account/AccountFactory.sol index e768dee4..11ad5684 100644 --- a/src/account/AccountFactory.sol +++ b/src/account/AccountFactory.sol @@ -7,20 +7,26 @@ import {Ownable} from "@openzeppelin/contracts/access/Ownable.sol"; import {ERC1967Proxy} from "@openzeppelin/contracts/proxy/ERC1967/ERC1967Proxy.sol"; import {Create2} from "@openzeppelin/contracts/utils/Create2.sol"; +import {SemiModularAccount} from "../account/SemiModularAccount.sol"; import {UpgradeableModularAccount} from "../account/UpgradeableModularAccount.sol"; import {ValidationConfigLib} from "../helpers/ValidationConfigLib.sol"; +import {LibClone} from "solady/utils/LibClone.sol"; + contract AccountFactory is Ownable { UpgradeableModularAccount public immutable ACCOUNT_IMPL; + SemiModularAccount public immutable SEMI_MODULAR_ACCOUNT_IMPL; bytes32 private immutable _PROXY_BYTECODE_HASH; IEntryPoint public immutable ENTRY_POINT; address public immutable SINGLE_SIGNER_VALIDATION_MODULE; event ModularAccountDeployed(address indexed account, address indexed owner, uint256 salt); + event SemiModularAccountDeployed(address indexed account, address indexed owner, uint256 salt); constructor( IEntryPoint _entryPoint, UpgradeableModularAccount _accountImpl, + SemiModularAccount _semiModularImpl, address _singleSignerValidationModule, address owner ) Ownable(owner) { @@ -28,6 +34,7 @@ contract AccountFactory is Ownable { _PROXY_BYTECODE_HASH = keccak256(abi.encodePacked(type(ERC1967Proxy).creationCode, abi.encode(address(_accountImpl), ""))); ACCOUNT_IMPL = _accountImpl; + SEMI_MODULAR_ACCOUNT_IMPL = _semiModularImpl; SINGLE_SIGNER_VALIDATION_MODULE = _singleSignerValidationModule; } @@ -63,6 +70,23 @@ contract AccountFactory is Ownable { return UpgradeableModularAccount(payable(addr)); } + function createSemiModularAccount(address owner, uint256 salt) external returns (SemiModularAccount) { + // both module address and entityId for fallback validations are hardcoded at the maximum value. + bytes32 fullSalt = getSalt(owner, salt, type(uint32).max); + + bytes memory immutables = _getImmutableArgs(owner); + + // LibClone short-circuits if it's already deployed. + (bool alreadyDeployed, address instance) = + LibClone.createDeterministicERC1967(address(SEMI_MODULAR_ACCOUNT_IMPL), immutables, fullSalt); + + if (!alreadyDeployed) { + emit SemiModularAccountDeployed(instance, owner, salt); + } + + return SemiModularAccount(payable(instance)); + } + function addStake(uint32 unstakeDelay) external payable onlyOwner { ENTRY_POINT.addStake{value: msg.value}(unstakeDelay); } @@ -82,7 +106,21 @@ contract AccountFactory is Ownable { return Create2.computeAddress(getSalt(owner, salt, entityId), _PROXY_BYTECODE_HASH); } + function getAddressSemiModular(address owner, uint256 salt) public view returns (address) { + bytes32 fullSalt = getSalt(owner, salt, type(uint32).max); + bytes memory immutables = _getImmutableArgs(owner); + return _getAddressSemiModular(immutables, fullSalt); + } + function getSalt(address owner, uint256 salt, uint32 entityId) public pure returns (bytes32) { return keccak256(abi.encodePacked(owner, salt, entityId)); } + + function _getAddressSemiModular(bytes memory immutables, bytes32 salt) internal view returns (address) { + return LibClone.predictDeterministicAddressERC1967(address(ACCOUNT_IMPL), immutables, salt, address(this)); + } + + function _getImmutableArgs(address owner) private pure returns (bytes memory) { + return abi.encodePacked(owner); + } } diff --git a/src/account/SemiModularAccount.sol b/src/account/SemiModularAccount.sol new file mode 100644 index 00000000..4a6e918d --- /dev/null +++ b/src/account/SemiModularAccount.sol @@ -0,0 +1,181 @@ +// SPDX-License-Identifier: GPL-3.0 +pragma solidity ^0.8.25; + +import {UpgradeableModularAccount} from "./UpgradeableModularAccount.sol"; +import {IEntryPoint} from "@eth-infinitism/account-abstraction/interfaces/IEntryPoint.sol"; +import {PackedUserOperation} from "@eth-infinitism/account-abstraction/interfaces/PackedUserOperation.sol"; + +import {ModuleEntityLib} from "../helpers/ModuleEntityLib.sol"; + +import {ModuleEntity, ValidationConfig} from "../interfaces/IModuleManager.sol"; + +import {MessageHashUtils} from "@openzeppelin/contracts/utils/cryptography/MessageHashUtils.sol"; +import {SignatureChecker} from "@openzeppelin/contracts/utils/cryptography/SignatureChecker.sol"; +import {LibClone} from "solady/utils/LibClone.sol"; + +contract SemiModularAccount is UpgradeableModularAccount { + using MessageHashUtils for bytes32; + using ModuleEntityLib for ModuleEntity; + + struct SemiModularAccountStorage { + address fallbackSigner; + bool fallbackSignerDisabled; + } + + // keccak256("ERC6900.SemiModularAccount.Storage") + uint256 internal constant _SEMI_MODULAR_ACCOUNT_STORAGE_SLOT = + 0x5b9dc9aa943f8fa2653ceceda5e3798f0686455280432166ba472eca0bc17a32; + + ModuleEntity internal constant _FALLBACK_VALIDATION = ModuleEntity.wrap(bytes24(type(uint192).max)); + + uint256 internal constant _SIG_VALIDATION_PASSED = 0; + uint256 internal constant _SIG_VALIDATION_FAILED = 1; + + event FallbackSignerSet(address indexed previousFallbackSigner, address indexed newFallbackSigner); + event FallbackSignerDisabledSet(bool prevDisabled, bool newDisabled); + + error FallbackSignerMismatch(); + error FallbackSignerDisabled(); + error InitializerDisabled(); + + constructor(IEntryPoint anEntryPoint) UpgradeableModularAccount(anEntryPoint) {} + + /// Override reverts on initialization, effectively disabling the initializer. + function initializeWithValidation(ValidationConfig, bytes4[] calldata, bytes calldata, bytes[] calldata) + external + override + initializer + { + revert InitializerDisabled(); + } + + /// @notice Updates the fallback signer address in storage. + /// @dev This function causes the fallback signer getter to ignore the bytecode signer if it is nonzero. It can + /// also be used to revert back to the bytecode signer by setting to zero. + /// @param fallbackSigner The new signer to set. + function updateFallbackSigner(address fallbackSigner) external wrapNativeFunction { + SemiModularAccountStorage storage _storage = _getSemiModularAccountStorage(); + emit FallbackSignerSet(_storage.fallbackSigner, fallbackSigner); + + _storage.fallbackSigner = fallbackSigner; + } + + /// @notice Sets whether the fallback signer validation should be enabled or disabled. + /// @dev Due to being initially zero, we need to store "disabled" rather than "enabled" in storage. + /// @param isDisabled True to disable fallback signer validation, false to enable it. + function setFallbackSignerDisabled(bool isDisabled) external wrapNativeFunction { + SemiModularAccountStorage storage _storage = _getSemiModularAccountStorage(); + emit FallbackSignerDisabledSet(_storage.fallbackSignerDisabled, isDisabled); + + _storage.fallbackSignerDisabled = isDisabled; + } + + /// @notice Returns whether the fallback signer validation is disabled. + /// @return True if the fallback signer validation is disabled, false if it is enabled. + function isFallbackSignerDisabled() external view returns (bool) { + return _getSemiModularAccountStorage().fallbackSignerDisabled; + } + + /// @notice Returns the fallback signer associated with this account, regardless if the fallback signer + /// validation is enabled or not. + /// @return The fallback signer address, either overriden in storage, or read from bytecode. + function getFallbackSigner() external view returns (address) { + return _retrieveFallbackSignerUnchecked(_getSemiModularAccountStorage()); + } + + function _execUserOpValidation( + ModuleEntity userOpValidationFunction, + PackedUserOperation memory userOp, + bytes32 userOpHash + ) internal override returns (uint256) { + if (userOpValidationFunction.eq(_FALLBACK_VALIDATION)) { + address fallbackSigner = _getFallbackSigner(); + + if ( + SignatureChecker.isValidSignatureNow( + fallbackSigner, userOpHash.toEthSignedMessageHash(), userOp.signature + ) + ) { + return _SIG_VALIDATION_PASSED; + } + return _SIG_VALIDATION_FAILED; + } + + return super._execUserOpValidation(userOpValidationFunction, userOp, userOpHash); + } + + function _execRuntimeValidation( + ModuleEntity runtimeValidationFunction, + bytes calldata callData, + bytes calldata authorization + ) internal override { + if (runtimeValidationFunction.eq(_FALLBACK_VALIDATION)) { + address fallbackSigner = _getFallbackSigner(); + + if (msg.sender != fallbackSigner) { + revert FallbackSignerMismatch(); + } + return; + } + super._execRuntimeValidation(runtimeValidationFunction, callData, authorization); + } + + function _exec1271Validation(ModuleEntity sigValidation, bytes32 hash, bytes calldata signature) + internal + view + override + returns (bytes4) + { + if (sigValidation.eq(_FALLBACK_VALIDATION)) { + address fallbackSigner = _getFallbackSigner(); + + if (SignatureChecker.isValidSignatureNow(fallbackSigner, hash, signature)) { + return _1271_MAGIC_VALUE; + } + return _1271_INVALID; + } + return super._exec1271Validation(sigValidation, hash, signature); + } + + function _globalValidationAllowed(bytes4 selector) internal view override returns (bool) { + return selector == this.setFallbackSignerDisabled.selector + || selector == this.updateFallbackSigner.selector || super._globalValidationAllowed(selector); + } + + function _isValidationGlobal(ModuleEntity validationFunction) internal view override returns (bool) { + return validationFunction.eq(_FALLBACK_VALIDATION) || super._isValidationGlobal(validationFunction); + } + + function _getFallbackSigner() internal view returns (address) { + SemiModularAccountStorage storage _storage = _getSemiModularAccountStorage(); + + if (_storage.fallbackSignerDisabled) { + revert FallbackSignerDisabled(); + } + + return _retrieveFallbackSignerUnchecked(_storage); + } + + function _retrieveFallbackSignerUnchecked(SemiModularAccountStorage storage _storage) + internal + view + returns (address) + { + address storageFallbackSigner = _storage.fallbackSigner; + if (storageFallbackSigner != address(0)) { + return storageFallbackSigner; + } + + bytes memory appendedData = LibClone.argsOnERC1967(address(this), 0, 20); + + return address(uint160(bytes20(appendedData))); + } + + function _getSemiModularAccountStorage() internal pure returns (SemiModularAccountStorage storage) { + SemiModularAccountStorage storage _storage; + assembly ("memory-safe") { + _storage.slot := _SEMI_MODULAR_ACCOUNT_STORAGE_SLOT + } + return _storage; + } +} diff --git a/src/account/UpgradeableModularAccount.sol b/src/account/UpgradeableModularAccount.sol index 2904e1e6..410c8610 100644 --- a/src/account/UpgradeableModularAccount.sol +++ b/src/account/UpgradeableModularAccount.sol @@ -238,15 +238,13 @@ contract UpgradeableModularAccount is } /// @notice Initializes the account with a validation function added to the global pool. - /// TODO: remove and merge with regular initialization, after we figure out a better install/uninstall workflow - /// with user install configs. /// @dev This function is only callable once. function initializeWithValidation( ValidationConfig validationConfig, bytes4[] calldata selectors, bytes calldata installData, bytes[] calldata hooks - ) external initializer { + ) external virtual initializer { _installValidation(validationConfig, selectors, installData, hooks); } @@ -299,22 +297,9 @@ contract UpgradeableModularAccount is } function isValidSignature(bytes32 hash, bytes calldata signature) public view override returns (bytes4) { - AccountStorage storage _storage = getAccountStorage(); - ModuleEntity sigValidation = ModuleEntity.wrap(bytes24(signature)); - (address module, uint32 entityId) = sigValidation.unpack(); - if (!_storage.validationData[sigValidation].isSignatureValidation) { - revert SignatureValidationInvalid(module, entityId); - } - - if ( - IValidationModule(module).validateSignature(address(this), entityId, msg.sender, hash, signature[24:]) - == _1271_MAGIC_VALUE - ) { - return _1271_MAGIC_VALUE; - } - return _1271_INVALID; + return _exec1271Validation(sigValidation, hash, signature[24:]); } /// @notice Gets the entry point for this account @@ -328,7 +313,6 @@ contract UpgradeableModularAccount is // Parent function validateUserOp enforces that this call can only be made by the EntryPoint function _validateSignature(PackedUserOperation calldata userOp, bytes32 userOpHash) internal - virtual override returns (uint256 validationData) { @@ -413,8 +397,7 @@ contract UpgradeableModularAccount is userOp.signature = signatureSegment.getBody(); - (address module, uint32 entityId) = userOpValidationFunction.unpack(); - uint256 currentValidationRes = IValidationModule(module).validateUserOp(entityId, userOp, userOpHash); + uint256 currentValidationRes = _execUserOpValidation(userOpValidationFunction, userOp, userOpHash); if (preUserOpValidationHooks.length != 0) { // If we have other validation data we need to coalesce with @@ -467,17 +450,7 @@ contract UpgradeableModularAccount is revert ValidationSignatureSegmentMissing(); } - (address module, uint32 entityId) = runtimeValidationFunction.unpack(); - - try IValidationModule(module).validateRuntime( - address(this), entityId, msg.sender, msg.value, callData, authSegment.getBody() - ) - // forgefmt: disable-start - // solhint-disable-next-line no-empty-blocks - {} catch (bytes memory revertReason){ - // forgefmt: disable-end - revert RuntimeValidationFunctionReverted(module, entityId, revertReason); - } + _execRuntimeValidation(runtimeValidationFunction, callData, authSegment.getBody()); } function _doPreHooks(EnumerableSet.Bytes32Set storage executionHooks, bytes memory data) @@ -627,6 +600,73 @@ contract UpgradeableModularAccount is return (postPermissionHooks, postExecutionHooks); } + function _execUserOpValidation( + ModuleEntity userOpValidationFunction, + PackedUserOperation memory userOp, + bytes32 userOpHash + ) internal virtual returns (uint256) { + (address module, uint32 entityId) = userOpValidationFunction.unpack(); + + return IValidationModule(module).validateUserOp(entityId, userOp, userOpHash); + } + + function _execRuntimeValidation( + ModuleEntity runtimeValidationFunction, + bytes calldata callData, + bytes calldata authorization + ) internal virtual { + (address module, uint32 entityId) = runtimeValidationFunction.unpack(); + + try IValidationModule(module).validateRuntime( + address(this), entityId, msg.sender, msg.value, callData, authorization + ) + // forgefmt: disable-start + // solhint-disable-next-line no-empty-blocks + {} catch (bytes memory revertReason){ + // forgefmt: disable-end + revert RuntimeValidationFunctionReverted(module, entityId, revertReason); + } + } + + function _exec1271Validation(ModuleEntity sigValidation, bytes32 hash, bytes calldata signature) + internal + view + virtual + returns (bytes4) + { + AccountStorage storage _storage = getAccountStorage(); + + (address module, uint32 entityId) = sigValidation.unpack(); + if (!_storage.validationData[sigValidation].isSignatureValidation) { + revert SignatureValidationInvalid(module, entityId); + } + + if ( + IValidationModule(module).validateSignature(address(this), entityId, msg.sender, hash, signature) + == _1271_MAGIC_VALUE + ) { + return _1271_MAGIC_VALUE; + } + return _1271_INVALID; + } + + function _globalValidationAllowed(bytes4 selector) internal view virtual returns (bool) { + if ( + selector == this.execute.selector || selector == this.executeBatch.selector + || selector == this.installExecution.selector || selector == this.uninstallExecution.selector + || selector == this.installValidation.selector || selector == this.uninstallValidation.selector + || selector == this.upgradeToAndCall.selector + ) { + return true; + } + + return getAccountStorage().executionData[selector].allowGlobalValidation; + } + + function _isValidationGlobal(ModuleEntity validationFunction) internal view virtual returns (bool) { + return getAccountStorage().validationData[validationFunction].isGlobal; + } + function _checkIfValidationAppliesCallData( bytes calldata callData, ModuleEntity validationFunction, @@ -680,28 +720,13 @@ contract UpgradeableModularAccount is } } - function _globalValidationAllowed(bytes4 selector) internal view returns (bool) { - if ( - selector == this.execute.selector || selector == this.executeBatch.selector - || selector == this.installExecution.selector || selector == this.uninstallExecution.selector - || selector == this.installValidation.selector || selector == this.uninstallValidation.selector - || selector == this.upgradeToAndCall.selector - ) { - return true; - } - - return getAccountStorage().executionData[selector].allowGlobalValidation; - } - function _checkIfValidationAppliesSelector(bytes4 selector, ModuleEntity validationFunction, bool isGlobal) internal view { - AccountStorage storage _storage = getAccountStorage(); - // Check that the provided validation function is applicable to the selector if (isGlobal) { - if (!_globalValidationAllowed(selector) || !_storage.validationData[validationFunction].isGlobal) { + if (!_globalValidationAllowed(selector) || !_isValidationGlobal(validationFunction)) { revert ValidationFunctionMissing(selector); } } else { diff --git a/src/modules/validation/SingleSignerValidationModule.sol b/src/modules/validation/SingleSignerValidationModule.sol index 7ad1c768..c869af22 100644 --- a/src/modules/validation/SingleSignerValidationModule.sol +++ b/src/modules/validation/SingleSignerValidationModule.sol @@ -1,14 +1,13 @@ // SPDX-License-Identifier: GPL-3.0 pragma solidity ^0.8.25; -import {PackedUserOperation} from "@eth-infinitism/account-abstraction/interfaces/PackedUserOperation.sol"; -import {MessageHashUtils} from "@openzeppelin/contracts/utils/cryptography/MessageHashUtils.sol"; -import {SignatureChecker} from "@openzeppelin/contracts/utils/cryptography/SignatureChecker.sol"; - import {IModule, ModuleMetadata} from "../../interfaces/IModule.sol"; import {IValidationModule} from "../../interfaces/IValidationModule.sol"; import {BaseModule} from "../BaseModule.sol"; import {ISingleSignerValidationModule} from "./ISingleSignerValidationModule.sol"; +import {PackedUserOperation} from "@eth-infinitism/account-abstraction/interfaces/PackedUserOperation.sol"; +import {MessageHashUtils} from "@openzeppelin/contracts/utils/cryptography/MessageHashUtils.sol"; +import {SignatureChecker} from "@openzeppelin/contracts/utils/cryptography/SignatureChecker.sol"; /// @title ECSDA Validation /// @author ERC-6900 Authors diff --git a/test/account/AccountFactory.t.sol b/test/account/AccountFactory.t.sol index 25194f1d..92d79752 100644 --- a/test/account/AccountFactory.t.sol +++ b/test/account/AccountFactory.t.sol @@ -2,16 +2,23 @@ pragma solidity ^0.8.19; import {AccountFactory} from "../../src/account/AccountFactory.sol"; + +import {SemiModularAccount} from "../../src/account/SemiModularAccount.sol"; import {UpgradeableModularAccount} from "../../src/account/UpgradeableModularAccount.sol"; import {AccountTestBase} from "../utils/AccountTestBase.sol"; contract AccountFactoryTest is AccountTestBase { AccountFactory internal _factory; UpgradeableModularAccount internal _account; + SemiModularAccount internal _semiModularAccount; function setUp() public { _account = new UpgradeableModularAccount(entryPoint); - _factory = new AccountFactory(entryPoint, _account, address(singleSignerValidationModule), address(this)); + _semiModularAccount = new SemiModularAccount(entryPoint); + + _factory = new AccountFactory( + entryPoint, _account, _semiModularAccount, address(singleSignerValidationModule), address(this) + ); } function test_createAccount() public { diff --git a/test/account/AccountReturnData.t.sol b/test/account/AccountReturnData.t.sol index 5f4cb98d..3542f3db 100644 --- a/test/account/AccountReturnData.t.sol +++ b/test/account/AccountReturnData.t.sol @@ -2,7 +2,6 @@ pragma solidity ^0.8.19; import {DIRECT_CALL_VALIDATION_ENTITYID} from "../../src/helpers/Constants.sol"; -import {ModuleEntityLib} from "../../src/helpers/ModuleEntityLib.sol"; import {ValidationConfigLib} from "../../src/helpers/ValidationConfigLib.sol"; import {Call} from "../../src/interfaces/IStandardExecutor.sol"; import {IStandardExecutor} from "../../src/interfaces/IStandardExecutor.sol"; @@ -13,7 +12,6 @@ import { ResultCreatorModule } from "../mocks/modules/ReturnDataModuleMocks.sol"; import {AccountTestBase} from "../utils/AccountTestBase.sol"; -import {TEST_DEFAULT_VALIDATION_ENTITY_ID} from "../utils/TestConstants.sol"; // Tests all the different ways that return data can be read from modules through an account contract AccountReturnDataTest is AccountTestBase { @@ -67,11 +65,7 @@ contract AccountReturnDataTest is AccountTestBase { account1.execute, (address(regularResultContract), 0, abi.encodeCall(RegularResultContract.foo, ())) ), - _encodeSignature( - ModuleEntityLib.pack(address(singleSignerValidationModule), TEST_DEFAULT_VALIDATION_ENTITY_ID), - GLOBAL_VALIDATION, - "" - ) + _encodeSignature(_signerValidation, GLOBAL_VALIDATION, "") ); bytes32 result = abi.decode(abi.decode(returnData, (bytes)), (bytes32)); @@ -95,11 +89,7 @@ contract AccountReturnDataTest is AccountTestBase { bytes memory retData = account1.executeWithAuthorization( abi.encodeCall(account1.executeBatch, (calls)), - _encodeSignature( - ModuleEntityLib.pack(address(singleSignerValidationModule), TEST_DEFAULT_VALIDATION_ENTITY_ID), - GLOBAL_VALIDATION, - "" - ) + _encodeSignature(_signerValidation, GLOBAL_VALIDATION, "") ); bytes[] memory returnDatas = abi.decode(retData, (bytes[])); diff --git a/test/account/DirectCallsFromModule.t.sol b/test/account/DirectCallsFromModule.t.sol index 5ffeecf3..5f013055 100644 --- a/test/account/DirectCallsFromModule.t.sol +++ b/test/account/DirectCallsFromModule.t.sol @@ -36,9 +36,9 @@ contract DirectCallsFromModuleTest is AccountTestBase { } function test_Fail_DirectCallModuleUninstalled() external { - _installExecution(); + _installValidation(); - _uninstallExecution(); + _uninstallValidation(); vm.prank(address(_module)); vm.expectRevert(_buildDirectCallDisallowedError(IStandardExecutor.execute.selector)); @@ -46,7 +46,7 @@ contract DirectCallsFromModuleTest is AccountTestBase { } function test_Fail_DirectCallModuleCallOtherSelector() external { - _installExecution(); + _installValidation(); Call[] memory calls = new Call[](0); @@ -60,7 +60,7 @@ contract DirectCallsFromModuleTest is AccountTestBase { /* -------------------------------------------------------------------------- */ function test_Pass_DirectCallFromModulePrank() external { - _installExecution(); + _installValidation(); vm.prank(address(_module)); account1.execute(address(0), 0, ""); @@ -70,7 +70,7 @@ contract DirectCallsFromModuleTest is AccountTestBase { } function test_Pass_DirectCallFromModuleCallback() external { - _installExecution(); + _installValidation(); bytes memory encodedCall = abi.encodeCall(DirectCallModule.directCall, ()); @@ -88,7 +88,7 @@ contract DirectCallsFromModuleTest is AccountTestBase { function test_Flow_DirectCallFromModuleSequence() external { // Install => Succeesfully call => uninstall => fail to call - _installExecution(); + _installValidation(); vm.prank(address(_module)); account1.execute(address(0), 0, ""); @@ -96,7 +96,7 @@ contract DirectCallsFromModuleTest is AccountTestBase { assertTrue(_module.preHookRan()); assertTrue(_module.postHookRan()); - _uninstallExecution(); + _uninstallValidation(); vm.prank(address(_module)); vm.expectRevert(_buildDirectCallDisallowedError(IStandardExecutor.execute.selector)); @@ -107,7 +107,7 @@ contract DirectCallsFromModuleTest is AccountTestBase { /* Internals */ /* -------------------------------------------------------------------------- */ - function _installExecution() internal { + function _installValidation() internal { bytes4[] memory selectors = new bytes4[](1); selectors[0] = IStandardExecutor.execute.selector; @@ -124,7 +124,7 @@ contract DirectCallsFromModuleTest is AccountTestBase { account1.installValidation(validationConfig, selectors, "", hooks); } - function _uninstallExecution() internal { + function _uninstallValidation() internal { (address module, uint32 entityId) = ModuleEntityLib.unpack(_moduleEntity); vm.prank(address(entryPoint)); vm.expectEmit(true, true, true, true); diff --git a/test/account/MultiValidation.t.sol b/test/account/MultiValidation.t.sol index 298f8e87..78de0657 100644 --- a/test/account/MultiValidation.t.sol +++ b/test/account/MultiValidation.t.sol @@ -43,8 +43,7 @@ contract MultiValidationTest is AccountTestBase { ); ModuleEntity[] memory validations = new ModuleEntity[](2); - validations[0] = - ModuleEntityLib.pack(address(singleSignerValidationModule), TEST_DEFAULT_VALIDATION_ENTITY_ID); + validations[0] = _signerValidation; validations[1] = ModuleEntityLib.pack(address(validator2), TEST_DEFAULT_VALIDATION_ENTITY_ID); bytes4[] memory selectors0 = account1.getValidationData(validations[0]).selectors; @@ -65,7 +64,7 @@ contract MultiValidationTest is AccountTestBase { abi.encodeWithSelector( UpgradeableModularAccount.RuntimeValidationFunctionReverted.selector, address(validator2), - 1, + TEST_DEFAULT_VALIDATION_ENTITY_ID, abi.encodeWithSignature("NotAuthorized()") ) ); diff --git a/test/account/PerHookData.t.sol b/test/account/PerHookData.t.sol index cc3f3415..2be42f86 100644 --- a/test/account/PerHookData.t.sol +++ b/test/account/PerHookData.t.sol @@ -8,7 +8,7 @@ import {MessageHashUtils} from "@openzeppelin/contracts/utils/cryptography/Messa import {UpgradeableModularAccount} from "../../src/account/UpgradeableModularAccount.sol"; import {HookConfigLib} from "../../src/helpers/HookConfigLib.sol"; -import {ModuleEntity} from "../../src/helpers/ModuleEntityLib.sol"; +import {ModuleEntity, ModuleEntityLib} from "../../src/helpers/ModuleEntityLib.sol"; import {Counter} from "../mocks/Counter.sol"; import {MockAccessControlHookModule} from "../mocks/modules/MockAccessControlHookModule.sol"; @@ -22,6 +22,9 @@ contract PerHookDataTest is CustomValidationTestBase { Counter internal _counter; function setUp() public { + _signerValidation = + ModuleEntityLib.pack(address(singleSignerValidationModule), TEST_DEFAULT_VALIDATION_ENTITY_ID); + _counter = new Counter(); _accessControlHookModule = new MockAccessControlHookModule(); @@ -341,14 +344,8 @@ contract PerHookDataTest is CustomValidationTestBase { ), abi.encode(_counter) ); - - return ( - _signerValidation, - true, - true, - new bytes4[](0), - abi.encode(TEST_DEFAULT_VALIDATION_ENTITY_ID, owner1), - hooks - ); + // patched to also work during SMA tests by differentiating the validation + _signerValidation = ModuleEntityLib.pack(address(singleSignerValidationModule), type(uint32).max - 1); + return (_signerValidation, true, true, new bytes4[](0), abi.encode(type(uint32).max - 1, owner1), hooks); } } diff --git a/test/account/ReplaceModule.t.sol b/test/account/ReplaceModule.t.sol index 32d4c7d0..9de8247b 100644 --- a/test/account/ReplaceModule.t.sol +++ b/test/account/ReplaceModule.t.sol @@ -21,7 +21,6 @@ import {IValidationHookModule} from "../../src/interfaces/IValidationHookModule. import {SingleSignerValidationModule} from "../../src/modules/validation/SingleSignerValidationModule.sol"; import {MockModule} from "../mocks/MockModule.sol"; import {AccountTestBase} from "../utils/AccountTestBase.sol"; -import {TEST_DEFAULT_VALIDATION_ENTITY_ID} from "../utils/TestConstants.sol"; interface TestModule { function testFunction() external; @@ -86,11 +85,7 @@ contract UpgradeModuleTest is AccountTestBase { }); account1.executeWithAuthorization( abi.encodeCall(account1.executeBatch, (calls)), - _encodeSignature( - ModuleEntityLib.pack(address(singleSignerValidationModule), TEST_DEFAULT_VALIDATION_ENTITY_ID), - GLOBAL_VALIDATION, - "" - ) + _encodeSignature(_signerValidation, GLOBAL_VALIDATION, "") ); // test installed, test if old module still installed @@ -189,11 +184,7 @@ contract UpgradeModuleTest is AccountTestBase { }); account1.executeWithAuthorization( abi.encodeCall(account1.executeBatch, (calls)), - _encodeSignature( - ModuleEntityLib.pack(address(singleSignerValidationModule), TEST_DEFAULT_VALIDATION_ENTITY_ID), - GLOBAL_VALIDATION, - "" - ) + _encodeSignature(_signerValidation, GLOBAL_VALIDATION, "") ); // Test if old validation still works, expect fail diff --git a/test/account/UpgradeableModularAccount.t.sol b/test/account/UpgradeableModularAccount.t.sol index e4394f34..e6b9d28a 100644 --- a/test/account/UpgradeableModularAccount.t.sol +++ b/test/account/UpgradeableModularAccount.t.sol @@ -10,6 +10,8 @@ import {ECDSA} from "@openzeppelin/contracts/utils/cryptography/ECDSA.sol"; import {MessageHashUtils} from "@openzeppelin/contracts/utils/cryptography/MessageHashUtils.sol"; import {ModuleManagerInternals} from "../../src/account/ModuleManagerInternals.sol"; + +import {SemiModularAccount} from "../../src/account/SemiModularAccount.sol"; import {UpgradeableModularAccount} from "../../src/account/UpgradeableModularAccount.sol"; import {ExecutionDataView, IAccountLoupe} from "../../src/interfaces/IAccountLoupe.sol"; @@ -92,11 +94,9 @@ contract UpgradeableModularAccountTest is AccountTestBase { } function test_basicUserOp_withInitCode() public { - PackedUserOperation memory userOp = PackedUserOperation({ - sender: address(account2), - nonce: 0, - initCode: abi.encodePacked(address(factory), abi.encodeCall(factory.createAccount, (owner2, 0))), - callData: abi.encodeCall( + bytes memory callData = vm.envOr("SMA_TEST", false) + ? abi.encodeCall(SemiModularAccount(payable(account1)).updateFallbackSigner, (owner2)) + : abi.encodeCall( UpgradeableModularAccount.execute, ( address(singleSignerValidationModule), @@ -105,7 +105,13 @@ contract UpgradeableModularAccountTest is AccountTestBase { SingleSignerValidationModule.transferSigner, (TEST_DEFAULT_VALIDATION_ENTITY_ID, owner2) ) ) - ), + ); + + PackedUserOperation memory userOp = PackedUserOperation({ + sender: address(account2), + nonce: 0, + initCode: abi.encodePacked(address(factory), abi.encodeCall(factory.createAccount, (owner2, 0))), + callData: callData, accountGasLimits: _encodeGas(VERIFICATION_GAS_LIMIT, CALL_GAS_LIMIT), preVerificationGas: 0, gasFees: _encodeGas(1, 2), @@ -355,10 +361,20 @@ contract UpgradeableModularAccountTest is AccountTestBase { assertEq(address(account3), address(uint160(uint256(vm.load(address(account1), slot))))); } + // TODO: Consider if this test belongs here or in the tests specific to the SingleSignerValidationModule function test_transferOwnership() public { - assertEq( - singleSignerValidationModule.signers(TEST_DEFAULT_VALIDATION_ENTITY_ID, address(account1)), owner1 - ); + if (vm.envOr("SMA_TEST", false)) { + // Note: replaced "owner1" with address(0), this doesn't actually affect the account, but allows the + // test to pass by ensuring the signer can be set in the validation. + assertEq( + singleSignerValidationModule.signers(TEST_DEFAULT_VALIDATION_ENTITY_ID, address(account1)), + address(0) + ); + } else { + assertEq( + singleSignerValidationModule.signers(TEST_DEFAULT_VALIDATION_ENTITY_ID, address(account1)), owner1 + ); + } vm.prank(address(entryPoint)); account1.execute( @@ -381,8 +397,7 @@ contract UpgradeableModularAccountTest is AccountTestBase { // singleSignerValidationModule.ownerOf(address(account1)); - bytes memory signature = - abi.encodePacked(address(singleSignerValidationModule), TEST_DEFAULT_VALIDATION_ENTITY_ID, r, s, v); + bytes memory signature = abi.encodePacked(_signerValidation, r, s, v); bytes4 validationResult = IERC1271(address(account1)).isValidSignature(message, signature); diff --git a/test/mocks/SingleSignerFactoryFixture.sol b/test/mocks/SingleSignerFactoryFixture.sol index 38453951..1594d29e 100644 --- a/test/mocks/SingleSignerFactoryFixture.sol +++ b/test/mocks/SingleSignerFactoryFixture.sol @@ -12,6 +12,8 @@ import {SingleSignerValidationModule} from "../../src/modules/validation/SingleS import {OptimizedTest} from "../utils/OptimizedTest.sol"; import {TEST_DEFAULT_VALIDATION_ENTITY_ID} from "../utils/TestConstants.sol"; +import {LibClone} from "solady/utils/LibClone.sol"; + contract SingleSignerFactoryFixture is OptimizedTest { UpgradeableModularAccount public accountImplementation; SingleSignerValidationModule public singleSignerValidationModule; @@ -23,9 +25,14 @@ contract SingleSignerFactoryFixture is OptimizedTest { address public self; + error SemiModularAccountAddressMismatch(address expected, address returned); + constructor(IEntryPoint _entryPoint, SingleSignerValidationModule _singleSignerValidationModule) { entryPoint = _entryPoint; - accountImplementation = _deployUpgradeableModularAccount(_entryPoint); + + accountImplementation = vm.envOr("SMA_TEST", false) + ? _deploySemiModularAccount(_entryPoint) + : _deployUpgradeableModularAccount(_entryPoint); _PROXY_BYTECODE_HASH = keccak256( abi.encodePacked(type(ERC1967Proxy).creationCode, abi.encode(address(accountImplementation), "")) ); @@ -41,6 +48,12 @@ contract SingleSignerFactoryFixture is OptimizedTest { * account creation */ function createAccount(address owner, uint256 salt) public returns (UpgradeableModularAccount) { + // We cast the SemiModularAccount to an UpgradeableModularAccount to facilitate equivalence testing. + // However, we don't do this in the actual factory. + if (vm.envOr("SMA_TEST", false)) { + return createSemiModularAccount(owner, salt); + } + address addr = Create2.computeAddress(getSalt(owner, salt), _PROXY_BYTECODE_HASH); // short circuit if exists @@ -63,13 +76,40 @@ contract SingleSignerFactoryFixture is OptimizedTest { return UpgradeableModularAccount(payable(addr)); } + function createSemiModularAccount(address owner, uint256 salt) public returns (UpgradeableModularAccount) { + bytes32 fullSalt = getSalt(owner, salt); + + bytes memory immutables = _getImmutableArgs(owner); + + address addr = _getAddressSemiModular(immutables, fullSalt); + + // LibClone short-circuits if it's already deployed. + (, address instance) = + LibClone.createDeterministicERC1967(address(accountImplementation), immutables, fullSalt); + + if (instance != addr) { + revert SemiModularAccountAddressMismatch(addr, instance); + } + + return UpgradeableModularAccount(payable(addr)); + } + /** * calculate the counterfactual address of this account as it would be returned by createAccount() */ - function getAddress(address owner, uint256 salt) public view returns (address) { + function getAddress(address owner, uint256 salt) public returns (address) { + if (vm.envOr("SMA_TEST", false)) { + return getAddressSemiModular(owner, salt); + } return Create2.computeAddress(getSalt(owner, salt), _PROXY_BYTECODE_HASH); } + function getAddressSemiModular(address owner, uint256 salt) public view returns (address) { + bytes32 fullSalt = getSalt(owner, salt); + bytes memory immutables = _getImmutableArgs(owner); + return _getAddressSemiModular(immutables, fullSalt); + } + function addStake() external payable { entryPoint.addStake{value: msg.value}(UNSTAKE_DELAY); } @@ -77,4 +117,14 @@ contract SingleSignerFactoryFixture is OptimizedTest { function getSalt(address owner, uint256 salt) public pure returns (bytes32) { return keccak256(abi.encodePacked(owner, salt)); } + + function _getAddressSemiModular(bytes memory immutables, bytes32 salt) internal view returns (address) { + return LibClone.predictDeterministicAddressERC1967( + address(accountImplementation), immutables, salt, address(this) + ); + } + + function _getImmutableArgs(address owner) private pure returns (bytes memory) { + return abi.encodePacked(owner); + } } diff --git a/test/module/AllowlistModule.t.sol b/test/module/AllowlistModule.t.sol index 5f0f243c..0d66ed85 100644 --- a/test/module/AllowlistModule.t.sol +++ b/test/module/AllowlistModule.t.sol @@ -6,7 +6,7 @@ import {IEntryPoint} from "@eth-infinitism/account-abstraction/interfaces/IEntry import {UpgradeableModularAccount} from "../../src/account/UpgradeableModularAccount.sol"; import {HookConfigLib} from "../../src/helpers/HookConfigLib.sol"; -import {ModuleEntity} from "../../src/helpers/ModuleEntityLib.sol"; +import {ModuleEntity, ModuleEntityLib} from "../../src/helpers/ModuleEntityLib.sol"; import {Call} from "../../src/interfaces/IStandardExecutor.sol"; import {AllowlistModule} from "../../src/modules/permissionhooks/AllowlistModule.sol"; @@ -34,6 +34,9 @@ contract AllowlistModuleTest is CustomValidationTestBase { ); function setUp() public { + _signerValidation = + ModuleEntityLib.pack(address(singleSignerValidationModule), TEST_DEFAULT_VALIDATION_ENTITY_ID); + allowlistModule = new AllowlistModule(); counters = new Counter[](10); @@ -338,15 +341,9 @@ contract AllowlistModuleTest is CustomValidationTestBase { HookConfigLib.packValidationHook(address(allowlistModule), HOOK_ENTITY_ID), abi.encode(HOOK_ENTITY_ID, allowlistInit) ); - - return ( - _signerValidation, - true, - true, - new bytes4[](0), - abi.encode(TEST_DEFAULT_VALIDATION_ENTITY_ID, owner1), - hooks - ); + // patched to also work during SMA tests by differentiating the validation + _signerValidation = ModuleEntityLib.pack(address(singleSignerValidationModule), type(uint32).max - 1); + return (_signerValidation, true, true, new bytes4[](0), abi.encode(type(uint32).max - 1, owner1), hooks); } // Unfortunately, this is a feature that solidity has only implemented in via-ir, so we need to do it manually diff --git a/test/module/SingleSignerValidationModule.t.sol b/test/module/SingleSignerValidationModule.t.sol index 7bddd5d8..5235ad87 100644 --- a/test/module/SingleSignerValidationModule.t.sol +++ b/test/module/SingleSignerValidationModule.t.sol @@ -83,7 +83,7 @@ contract SingleSignerValidationModuleTest is AccountTestBase { } function test_runtime_with2SameValidationInstalled() public { - uint32 newEntityId = TEST_DEFAULT_VALIDATION_ENTITY_ID + 1; + uint32 newEntityId = type(uint32).max - 1; vm.prank(address(entryPoint)); vm.expectEmit(address(singleSignerValidationModule)); diff --git a/test/script/Deploy.s.t.sol b/test/script/Deploy.s.t.sol index ae97ac0d..329d2501 100644 --- a/test/script/Deploy.s.t.sol +++ b/test/script/Deploy.s.t.sol @@ -10,6 +10,8 @@ import {Create2} from "@openzeppelin/contracts/utils/Create2.sol"; import {DeployScript} from "../../script/Deploy.s.sol"; import {AccountFactory} from "../../src/account/AccountFactory.sol"; + +import {SemiModularAccount} from "../../src/account/SemiModularAccount.sol"; import {UpgradeableModularAccount} from "../../src/account/UpgradeableModularAccount.sol"; import {SingleSignerValidationModule} from "../../src/modules/validation/SingleSignerValidationModule.sol"; @@ -21,6 +23,7 @@ contract DeployTest is Test { address internal _owner; address internal _accountImpl; + address internal _smaImpl; address internal _singleSignerValidationModule; address internal _factory; @@ -42,6 +45,12 @@ contract DeployTest is Test { CREATE2_FACTORY ); + _smaImpl = Create2.computeAddress( + bytes32(0), + keccak256(abi.encodePacked(type(SemiModularAccount).creationCode, abi.encode(address(_entryPoint)))), + CREATE2_FACTORY + ); + _singleSignerValidationModule = Create2.computeAddress( bytes32(0), keccak256(abi.encodePacked(type(SingleSignerValidationModule).creationCode)), @@ -53,17 +62,19 @@ contract DeployTest is Test { keccak256( abi.encodePacked( type(AccountFactory).creationCode, - abi.encode(address(_entryPoint), _accountImpl, _singleSignerValidationModule, _owner) + abi.encode(address(_entryPoint), _accountImpl, _smaImpl, _singleSignerValidationModule, _owner) ) ), CREATE2_FACTORY ); vm.setEnv("ACCOUNT_IMPL", vm.toString(address(_accountImpl))); + vm.setEnv("SMA_IMPL", vm.toString(address(_smaImpl))); vm.setEnv("FACTORY", vm.toString(address(_factory))); vm.setEnv("SINGLE_SIGNER_VALIDATION_MODULE", vm.toString(_singleSignerValidationModule)); vm.setEnv("ACCOUNT_IMPL_SALT", vm.toString(uint256(0))); + vm.setEnv("SMA_IMPL_SALT", vm.toString(uint256(0))); vm.setEnv("FACTORY_SALT", vm.toString(uint256(0))); vm.setEnv("SINGLE_SIGNER_VALIDATION_MODULE_SALT", vm.toString(uint256(0))); @@ -76,6 +87,7 @@ contract DeployTest is Test { _deployScript.run(); assertTrue(_accountImpl.code.length > 0); + assertTrue(_smaImpl.code.length > 0); assertTrue(_factory.code.length > 0); assertTrue(_singleSignerValidationModule.code.length > 0); diff --git a/test/utils/AccountTestBase.sol b/test/utils/AccountTestBase.sol index 8641f79b..705690a8 100644 --- a/test/utils/AccountTestBase.sol +++ b/test/utils/AccountTestBase.sol @@ -5,6 +5,7 @@ import {EntryPoint} from "@eth-infinitism/account-abstraction/core/EntryPoint.so import {PackedUserOperation} from "@eth-infinitism/account-abstraction/interfaces/PackedUserOperation.sol"; import {MessageHashUtils} from "@openzeppelin/contracts/utils/cryptography/MessageHashUtils.sol"; +import {SemiModularAccount} from "../../src/account/SemiModularAccount.sol"; import {UpgradeableModularAccount} from "../../src/account/UpgradeableModularAccount.sol"; import {ModuleEntity, ModuleEntityLib} from "../../src/helpers/ModuleEntityLib.sol"; import {Call, IStandardExecutor} from "../../src/interfaces/IStandardExecutor.sol"; @@ -53,7 +54,13 @@ abstract contract AccountTestBase is OptimizedTest { (owner1, owner1Key) = makeAddrAndKey("owner1"); beneficiary = payable(makeAddr("beneficiary")); - singleSignerValidationModule = _deploySingleSignerValidationModule(); + address deployedSingleSignerValidation = address(_deploySingleSignerValidationModule()); + + // We etch the single signer validation to the max address, such that it coincides with the fallback + // validation module entity for semi modular account tests. + singleSignerValidationModule = SingleSignerValidationModule(address(type(uint160).max)); + vm.etch(address(singleSignerValidationModule), deployedSingleSignerValidation.code); + factory = new SingleSignerFactoryFixture(entryPoint, singleSignerValidationModule); account1 = factory.createAccount(owner1, 0); @@ -102,11 +109,7 @@ abstract contract AccountTestBase is OptimizedTest { bytes32 userOpHash = entryPoint.getUserOpHash(userOp); (uint8 v, bytes32 r, bytes32 s) = vm.sign(owner1Key, userOpHash.toEthSignedMessageHash()); - userOp.signature = _encodeSignature( - ModuleEntityLib.pack(address(singleSignerValidationModule), TEST_DEFAULT_VALIDATION_ENTITY_ID), - GLOBAL_VALIDATION, - abi.encodePacked(r, s, v) - ); + userOp.signature = _encodeSignature(_signerValidation, GLOBAL_VALIDATION, abi.encodePacked(r, s, v)); PackedUserOperation[] memory userOps = new PackedUserOperation[](1); userOps[0] = userOp; @@ -153,14 +156,7 @@ abstract contract AccountTestBase is OptimizedTest { } vm.prank(owner1); - account1.executeWithAuthorization( - callData, - _encodeSignature( - ModuleEntityLib.pack(address(singleSignerValidationModule), TEST_DEFAULT_VALIDATION_ENTITY_ID), - GLOBAL_VALIDATION, - "" - ) - ); + account1.executeWithAuthorization(callData, _encodeSignature(_signerValidation, GLOBAL_VALIDATION, "")); } // Always expects a revert, even if the revert data is zero-length. @@ -168,19 +164,19 @@ abstract contract AccountTestBase is OptimizedTest { vm.expectRevert(expectedRevertData); vm.prank(owner1); - account1.executeWithAuthorization( - callData, - _encodeSignature( - ModuleEntityLib.pack(address(singleSignerValidationModule), TEST_DEFAULT_VALIDATION_ENTITY_ID), - GLOBAL_VALIDATION, - "" - ) - ); + account1.executeWithAuthorization(callData, _encodeSignature(_signerValidation, GLOBAL_VALIDATION, "")); } function _transferOwnershipToTest() internal { // Transfer ownership to test contract for easier invocation. vm.prank(owner1); + if (vm.envOr("SMA_TEST", false)) { + account1.executeWithAuthorization( + abi.encodeCall(SemiModularAccount(payable(account1)).updateFallbackSigner, (address(this))), + _encodeSignature(_signerValidation, GLOBAL_VALIDATION, "") + ); + return; + } account1.executeWithAuthorization( abi.encodeCall( account1.execute, @@ -193,11 +189,7 @@ abstract contract AccountTestBase is OptimizedTest { ) ) ), - _encodeSignature( - ModuleEntityLib.pack(address(singleSignerValidationModule), TEST_DEFAULT_VALIDATION_ENTITY_ID), - GLOBAL_VALIDATION, - "" - ) + _encodeSignature(_signerValidation, GLOBAL_VALIDATION, "") ); } diff --git a/test/utils/CustomValidationTestBase.sol b/test/utils/CustomValidationTestBase.sol index a7920623..57643a1d 100644 --- a/test/utils/CustomValidationTestBase.sol +++ b/test/utils/CustomValidationTestBase.sol @@ -28,14 +28,23 @@ abstract contract CustomValidationTestBase is AccountTestBase { account1 = UpgradeableModularAccount(payable(new ERC1967Proxy{salt: 0}(accountImplementation, ""))); - _beforeInstallStep(address(account1)); - - account1.initializeWithValidation( - ValidationConfigLib.pack(validationFunction, isGlobal, isSignatureValidation), - selectors, - installData, - hooks - ); + if (vm.envOr("SMA_TEST", false)) { + vm.prank(address(entryPoint)); + // The initializer doesn't work on the SMA + account1.installValidation( + ValidationConfigLib.pack(validationFunction, isGlobal, isSignatureValidation), + selectors, + installData, + hooks + ); + } else { + account1.initializeWithValidation( + ValidationConfigLib.pack(validationFunction, isGlobal, isSignatureValidation), + selectors, + installData, + hooks + ); + } vm.deal(address(account1), 100 ether); } diff --git a/test/utils/OptimizedTest.sol b/test/utils/OptimizedTest.sol index 7792badb..a2fa3c22 100644 --- a/test/utils/OptimizedTest.sol +++ b/test/utils/OptimizedTest.sol @@ -5,6 +5,7 @@ import {Test} from "forge-std/Test.sol"; import {IEntryPoint} from "@eth-infinitism/account-abstraction/interfaces/IEntryPoint.sol"; +import {SemiModularAccount} from "../../src/account/SemiModularAccount.sol"; import {UpgradeableModularAccount} from "../../src/account/UpgradeableModularAccount.sol"; import {TokenReceiverModule} from "../../src/modules/TokenReceiverModule.sol"; @@ -45,6 +46,16 @@ abstract contract OptimizedTest is Test { : new UpgradeableModularAccount(entryPoint); } + function _deploySemiModularAccount(IEntryPoint entryPoint) internal returns (UpgradeableModularAccount) { + return _isOptimizedTest() + ? UpgradeableModularAccount( + payable( + deployCode("out-optimized/SemiModularAccount.sol/SemiModularAccount.json", abi.encode(entryPoint)) + ) + ) + : UpgradeableModularAccount(new SemiModularAccount(entryPoint)); + } + function _deployTokenReceiverModule() internal returns (TokenReceiverModule) { return _isOptimizedTest() ? TokenReceiverModule(deployCode("out-optimized/TokenReceiverModule.sol/TokenReceiverModule.json")) diff --git a/test/utils/TestConstants.sol b/test/utils/TestConstants.sol index c15b2dd3..119bcd0b 100644 --- a/test/utils/TestConstants.sol +++ b/test/utils/TestConstants.sol @@ -1,4 +1,4 @@ // SPDX-License-Identifier: UNLICENSED pragma solidity ^0.8.25; -uint32 constant TEST_DEFAULT_VALIDATION_ENTITY_ID = 1; +uint32 constant TEST_DEFAULT_VALIDATION_ENTITY_ID = type(uint32).max;