From c42cd11c8ceff1943faa5f2aee44ed9792c924fc Mon Sep 17 00:00:00 2001 From: Zer0dot Date: Mon, 12 Aug 2024 13:34:34 -0400 Subject: [PATCH] [DRAFT] refactor/SMA: Inheritable Account Refactor (#133) --- .env.example | 4 +- .github/workflows/test.yml | 10 +- src/account/AccountFactory.sol | 2 +- src/account/AccountStorage.sol | 5 - src/account/SemiModularAccount.sol | 162 ++++++++++++++ src/account/UpgradeableModularAccount.sol | 208 ++++++------------ .../SingleSignerValidationModule.sol | 9 +- test/account/AccountReturnData.t.sol | 2 - test/account/FallbackValidationTest.t.sol | 77 ------- test/account/ImmutableAppend.t.sol | 15 +- test/account/MultiValidation.t.sol | 2 +- test/account/PerHookData.t.sol | 6 +- test/account/ReplaceModule.t.sol | 1 - test/account/UpgradeableModularAccount.t.sol | 37 +++- test/mocks/SingleSignerFactoryFixture.sol | 12 +- test/module/AllowlistModule.t.sol | 6 +- .../module/SingleSignerValidationModule.t.sol | 2 +- test/utils/AccountTestBase.sol | 34 ++- test/utils/CustomValidationTestBase.sol | 25 ++- test/utils/OptimizedTest.sol | 11 + test/utils/TestConstants.sol | 2 +- 21 files changed, 360 insertions(+), 272 deletions(-) create mode 100644 src/account/SemiModularAccount.sol delete mode 100644 test/account/FallbackValidationTest.t.sol diff --git a/.env.example b/.env.example index 612c3436..a7bd543e 100644 --- a/.env.example +++ b/.env.example @@ -1,4 +1,3 @@ - # Factory owner capable only of managing stake OWNER= # EP 0.7 address @@ -22,3 +21,6 @@ UNSTAKE_DELAY= # Allowlist Module ALLOWLIST_MODULE= ALLOWLIST_MODULE_SALT= + +# Whether to test the semi-modular or full modular account +SMA_TEST=false diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index dfa5cf41..3187d269 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -66,7 +66,10 @@ jobs: run: FOUNDRY_PROFILE=optimized-build forge build - name: Run tests - run: FOUNDRY_PROFILE=optimized-test-deep forge test -vvv + run: FOUNDRY_PROFILE=optimized-test-deep SMA_TEST=false forge test -vvv + + - name: Run SMA tests + run: FOUNDRY_PROFILE=optimized-test-deep SMA_TEST=true forge test -vvv test-default: name: Run Forge Tests (default) @@ -88,4 +91,7 @@ jobs: run: forge build - name: Run tests - run: forge test -vvv + run: SMA_TEST=false forge test -vvv + + - name: Run SMA tests + run: SMA_TEST=true forge test -vvv diff --git a/src/account/AccountFactory.sol b/src/account/AccountFactory.sol index 8dbea6bc..285771f0 100644 --- a/src/account/AccountFactory.sol +++ b/src/account/AccountFactory.sol @@ -114,7 +114,7 @@ contract AccountFactory is Ownable { return keccak256(abi.encodePacked(owner, salt, entityId)); } - function _getAddressFallbackSigner(bytes memory immutables, bytes32 salt) public view returns (address) { + function _getAddressFallbackSigner(bytes memory immutables, bytes32 salt) internal view returns (address) { return LibClone.predictDeterministicAddressERC1967(address(ACCOUNT_IMPL), immutables, salt, address(this)); } diff --git a/src/account/AccountStorage.sol b/src/account/AccountStorage.sol index f5f4e49d..f488e325 100644 --- a/src/account/AccountStorage.sol +++ b/src/account/AccountStorage.sol @@ -41,11 +41,6 @@ struct AccountStorage { // AccountStorageInitializable variables uint8 initialized; bool initializing; - // Address for fallback single signer validation - address fallbackSigner; - // Whether or not the fallback signer is enabled, we can't use a zero fallbackSigner for this since it defaults - // to reading the bytecode-appended signer. - bool fallbackSignerDisabled; // Execution functions and their associated functions mapping(bytes4 selector => ExecutionData) executionData; mapping(ModuleEntity validationFunction => ValidationData) validationData; diff --git a/src/account/SemiModularAccount.sol b/src/account/SemiModularAccount.sol new file mode 100644 index 00000000..2c4e0b6d --- /dev/null +++ b/src/account/SemiModularAccount.sol @@ -0,0 +1,162 @@ +// SPDX-License-Identifier: GPL-3.0 +pragma solidity ^0.8.25; + +import {UpgradeableModularAccount} from "./UpgradeableModularAccount.sol"; +import {IEntryPoint} from "@eth-infinitism/account-abstraction/interfaces/IEntryPoint.sol"; +import {PackedUserOperation} from "@eth-infinitism/account-abstraction/interfaces/PackedUserOperation.sol"; + +import {ModuleEntityLib} from "../helpers/ModuleEntityLib.sol"; + +import {ModuleEntity, ValidationConfig} from "../interfaces/IModuleManager.sol"; + +import {MessageHashUtils} from "@openzeppelin/contracts/utils/cryptography/MessageHashUtils.sol"; +import {SignatureChecker} from "@openzeppelin/contracts/utils/cryptography/SignatureChecker.sol"; +import {LibClone} from "solady/utils/LibClone.sol"; + +contract SemiModularAccount is UpgradeableModularAccount { + using MessageHashUtils for bytes32; + using ModuleEntityLib for ModuleEntity; + + struct SemiModularAccountStorage { + address fallbackSigner; + bool fallbackSignerDisabled; + } + + // keccak256("ERC6900.SemiModularAccount.Storage") + uint256 internal constant _SEMI_MODULAR_ACCOUNT_STORAGE_SLOT = + 0x5b9dc9aa943f8fa2653ceceda5e3798f0686455280432166ba472eca0bc17a32; + + ModuleEntity internal constant _FALLBACK_VALIDATION = ModuleEntity.wrap(bytes24(type(uint192).max)); + + event FallbackSignerSet(address indexed previousFallbackSigner, address indexed newFallbackSigner); + event FallbackSignerEnabledSet(bool prevEnabled, bool newEnabled); + + error FallbackSignerMismatch(); + error FallbackSignerDisabled(); + error InitializerDisabled(); + + constructor(IEntryPoint anEntryPoint) UpgradeableModularAccount(anEntryPoint) {} + + /// Override reverts on initialization, effectively disabling the initializer. + function initializeWithValidation(ValidationConfig, bytes4[] calldata, bytes calldata, bytes[] calldata) + external + override + initializer + { + revert InitializerDisabled(); + } + + /// @notice Updates the fallback signer address in storage. + /// @dev This function causes the fallback signer getter to ignore the bytecode signer if it is nonzero. It can + /// also be used to revert back to the bytecode signer by setting to zero. + /// @param fallbackSigner The new signer to set. + function updateFallbackSigner(address fallbackSigner) external wrapNativeFunction { + SemiModularAccountStorage storage _storage = _getSemiModularAccountStorage(); + emit FallbackSignerSet(_storage.fallbackSigner, fallbackSigner); + + _storage.fallbackSigner = fallbackSigner; + } + + /// @notice Sets whether the fallback signer validation should be enabled or disabled. + function setFallbackSignerEnabled(bool enabled) external wrapNativeFunction { + SemiModularAccountStorage storage _storage = _getSemiModularAccountStorage(); + emit FallbackSignerEnabledSet(!_storage.fallbackSignerDisabled, enabled); + + _storage.fallbackSignerDisabled = !enabled; + } + + function isFallbackSignerEnabled() external view returns (bool) { + return !_getSemiModularAccountStorage().fallbackSignerDisabled; + } + + function getFallbackSigner() external view returns (address) { + return _getFallbackSigner(); + } + + function _execUserOpValidation( + ModuleEntity userOpValidationFunction, + PackedUserOperation memory userOp, + bytes32 userOpHash + ) internal override returns (uint256) { + if (userOpValidationFunction.eq(_FALLBACK_VALIDATION)) { + address fallbackSigner = _getFallbackSigner(); + + if ( + SignatureChecker.isValidSignatureNow( + fallbackSigner, userOpHash.toEthSignedMessageHash(), userOp.signature + ) + ) { + return _SIG_VALIDATION_PASSED; + } + return _SIG_VALIDATION_FAILED; + } + + return super._execUserOpValidation(userOpValidationFunction, userOp, userOpHash); + } + + function _execRuntimeValidation( + ModuleEntity runtimeValidationFunction, + bytes calldata callData, + bytes calldata authorization + ) internal override { + if (runtimeValidationFunction.eq(_FALLBACK_VALIDATION)) { + address fallbackSigner = _getFallbackSigner(); + + if (msg.sender != fallbackSigner) { + revert FallbackSignerMismatch(); + } + return; + } + super._execRuntimeValidation(runtimeValidationFunction, callData, authorization); + } + + function _exec1271Validation(ModuleEntity sigValidation, bytes32 hash, bytes calldata signature) + internal + view + override + returns (bytes4) + { + if (sigValidation.eq(_FALLBACK_VALIDATION)) { + address fallbackSigner = _getFallbackSigner(); + + if (SignatureChecker.isValidSignatureNow(fallbackSigner, hash, signature)) { + return _1271_MAGIC_VALUE; + } + return _1271_INVALID; + } + return super._exec1271Validation(sigValidation, hash, signature); + } + + function _globalValidationAllowed(bytes4 selector) internal view override returns (bool) { + return selector == this.updateFallbackSigner.selector || super._globalValidationAllowed(selector); + } + + function _isValidationGlobal(ModuleEntity validationFunction) internal view override returns (bool) { + return validationFunction.eq(_FALLBACK_VALIDATION) || super._isValidationGlobal(validationFunction); + } + + function _getFallbackSigner() internal view returns (address) { + SemiModularAccountStorage storage _storage = _getSemiModularAccountStorage(); + + if (_storage.fallbackSignerDisabled) { + revert FallbackSignerDisabled(); + } + + address storageFallbackSigner = _storage.fallbackSigner; + if (storageFallbackSigner != address(0)) { + return storageFallbackSigner; + } + + bytes memory appendedData = LibClone.argsOnERC1967(address(this), 0, 20); + + return address(uint160(bytes20(appendedData))); + } + + function _getSemiModularAccountStorage() internal pure returns (SemiModularAccountStorage storage) { + SemiModularAccountStorage storage _storage; + assembly ("memory-safe") { + _storage.slot := _SEMI_MODULAR_ACCOUNT_STORAGE_SLOT + } + return _storage; + } +} diff --git a/src/account/UpgradeableModularAccount.sol b/src/account/UpgradeableModularAccount.sol index 3d795b70..083a4b32 100644 --- a/src/account/UpgradeableModularAccount.sol +++ b/src/account/UpgradeableModularAccount.sol @@ -36,9 +36,6 @@ import {ModuleManagerInternals} from "./ModuleManagerInternals.sol"; import {ECDSA} from "@openzeppelin/contracts/utils/cryptography/ECDSA.sol"; import {MessageHashUtils} from "@openzeppelin/contracts/utils/cryptography/MessageHashUtils.sol"; -import {SignatureChecker} from "@openzeppelin/contracts/utils/cryptography/SignatureChecker.sol"; - -import {LibClone} from "solady/utils/LibClone.sol"; contract UpgradeableModularAccount is AccountExecutor, @@ -78,10 +75,7 @@ contract UpgradeableModularAccount is uint256 internal constant _SIG_VALIDATION_PASSED = 0; uint256 internal constant _SIG_VALIDATION_FAILED = 1; - ModuleEntity internal constant _FALLBACK_VALIDATION = ModuleEntity.wrap(bytes24(type(uint192).max)); - event ModularAccountInitialized(IEntryPoint indexed entryPoint); - event FallbackSignerSet(address indexed previousFallbackSigner, address indexed newFallbackSigner); error NonCanonicalEncoding(); error NotEntryPoint(); @@ -97,8 +91,6 @@ contract UpgradeableModularAccount is error ValidationFunctionMissing(bytes4 selector); error ValidationSignatureSegmentMissing(); error SignatureSegmentOutOfOrder(); - error FallbackSignerMismatch(); - error FallbackSignerDisabled(); // Wraps execution of a native function with runtime validation and hooks // Used for upgradeTo, upgradeToAndCall, execute, executeBatch, installExecution, uninstallExecution @@ -266,23 +258,10 @@ contract UpgradeableModularAccount is bytes4[] calldata selectors, bytes calldata installData, bytes[] calldata hooks - ) external initializer { + ) external virtual initializer { _installValidation(validationConfig, selectors, installData, hooks); } - function updateFallbackSigner(address fallbackSigner) external wrapNativeFunction { - AccountStorage storage _storage = getAccountStorage(); - - emit FallbackSignerSet(_storage.fallbackSigner, fallbackSigner); - _storage.fallbackSigner = fallbackSigner; - } - - function setFallbackSignerEnabled(bool enabled) external wrapNativeFunction { - AccountStorage storage _storage = getAccountStorage(); - _storage.fallbackSignerDisabled = !enabled; - // TODO: event - } - /// @inheritdoc IModuleManager /// @notice May be validated by a global validation. function installValidation( @@ -332,27 +311,9 @@ contract UpgradeableModularAccount is } function isValidSignature(bytes32 hash, bytes calldata signature) public view override returns (bytes4) { - AccountStorage storage _storage = getAccountStorage(); - ModuleEntity sigValidation = ModuleEntity.wrap(bytes24(signature)); - if (sigValidation.eq(_FALLBACK_VALIDATION)) { - // do sig validation - return _fallbackSignatureValidation(hash, signature[24:]); - } - - (address module, uint32 entityId) = sigValidation.unpack(); - if (!_storage.validationData[sigValidation].isSignatureValidation) { - revert SignatureValidationInvalid(module, entityId); - } - - if ( - IValidationModule(module).validateSignature(address(this), entityId, msg.sender, hash, signature[24:]) - == _1271_MAGIC_VALUE - ) { - return _1271_MAGIC_VALUE; - } - return _1271_INVALID; + return _exec1271Validation(sigValidation, hash, signature[24:]); } /// @notice Gets the entry point for this account @@ -366,7 +327,6 @@ contract UpgradeableModularAccount is // Parent function validateUserOp enforces that this call can only be made by the EntryPoint function _validateSignature(PackedUserOperation calldata userOp, bytes32 userOpHash) internal - virtual override returns (uint256 validationData) { @@ -451,15 +411,7 @@ contract UpgradeableModularAccount is userOp.signature = signatureSegment.getBody(); - uint256 currentValidationRes; - if (userOpValidationFunction.eq(_FALLBACK_VALIDATION)) { - // fallback userop validation - currentValidationRes = _fallbackUserOpValidation(userOp, userOpHash); - } else { - (address module, uint32 entityId) = userOpValidationFunction.unpack(); - - currentValidationRes = IValidation(module).validateUserOp(entityId, userOp, userOpHash); - } + uint256 currentValidationRes = _execUserOpValidation(userOpValidationFunction, userOp, userOpHash); if (preUserOpValidationHooks.length != 0) { // If we have other validation data we need to coalesce with @@ -512,22 +464,7 @@ contract UpgradeableModularAccount is revert ValidationSignatureSegmentMissing(); } - if (runtimeValidationFunction.eq(_FALLBACK_VALIDATION)) { - _fallbackRuntimeValidation(); - return; - } - - (address module, uint32 entityId) = runtimeValidationFunction.unpack(); - - try IValidationModule(module).validateRuntime( - address(this), entityId, msg.sender, msg.value, callData, authSegment.getBody() - ) - // forgefmt: disable-start - // solhint-disable-next-line no-empty-blocks - {} catch (bytes memory revertReason){ - // forgefmt: disable-end - revert RuntimeValidationFunctionReverted(module, entityId, revertReason); - } + _execRuntimeValidation(runtimeValidationFunction, callData, authSegment.getBody()); } function _doPreHooks(EnumerableSet.Bytes32Set storage executionHooks, bytes memory data) @@ -677,6 +614,73 @@ contract UpgradeableModularAccount is return (postPermissionHooks, postExecutionHooks); } + function _execUserOpValidation( + ModuleEntity userOpValidationFunction, + PackedUserOperation memory userOp, + bytes32 userOpHash + ) internal virtual returns (uint256) { + (address module, uint32 entityId) = userOpValidationFunction.unpack(); + + return IValidation(module).validateUserOp(entityId, userOp, userOpHash); + } + + function _execRuntimeValidation( + ModuleEntity runtimeValidationFunction, + bytes calldata callData, + bytes calldata authorization + ) internal virtual { + (address module, uint32 entityId) = runtimeValidationFunction.unpack(); + + try IValidation(module).validateRuntime( + address(this), entityId, msg.sender, msg.value, callData, authorization + ) + // forgefmt: disable-start + // solhint-disable-next-line no-empty-blocks + {} catch (bytes memory revertReason){ + // forgefmt: disable-end + revert RuntimeValidationFunctionReverted(module, entityId, revertReason); + } + } + + function _exec1271Validation(ModuleEntity sigValidation, bytes32 hash, bytes calldata signature) + internal + view + virtual + returns (bytes4) + { + AccountStorage storage _storage = getAccountStorage(); + + (address module, uint32 entityId) = sigValidation.unpack(); + if (!_storage.validationData[sigValidation].isSignatureValidation) { + revert SignatureValidationInvalid(module, entityId); + } + + if ( + IValidation(module).validateSignature(address(this), entityId, msg.sender, hash, signature) + == _1271_MAGIC_VALUE + ) { + return _1271_MAGIC_VALUE; + } + return _1271_INVALID; + } + + function _globalValidationAllowed(bytes4 selector) internal view virtual returns (bool) { + if ( + selector == this.execute.selector || selector == this.executeBatch.selector + || selector == this.installExecution.selector || selector == this.uninstallExecution.selector + || selector == this.installValidation.selector || selector == this.uninstallValidation.selector + || selector == this.upgradeToAndCall.selector + ) { + return true; + } + + return getAccountStorage().executionData[selector].allowGlobalValidation; + } + + function _isValidationGlobal(ModuleEntity validationFunction) internal view virtual returns (bool) { + return getAccountStorage().validationData[validationFunction].isGlobal; + } + function _checkIfValidationAppliesCallData( bytes calldata callData, ModuleEntity validationFunction, @@ -730,34 +734,13 @@ contract UpgradeableModularAccount is } } - function _globalValidationAllowed(bytes4 selector) internal view returns (bool) { - if ( - selector == this.execute.selector || selector == this.executeBatch.selector - || selector == this.installExecution.selector || selector == this.uninstallExecution.selector - || selector == this.installValidation.selector || selector == this.uninstallValidation.selector - || selector == this.upgradeToAndCall.selector || selector == this.updateFallbackSigner.selector - ) { - return true; - } - - return getAccountStorage().executionData[selector].allowGlobalValidation; - } - function _checkIfValidationAppliesSelector(bytes4 selector, ModuleEntity validationFunction, bool isGlobal) internal view { - AccountStorage storage _storage = getAccountStorage(); - // Check that the provided validation function is applicable to the selector if (isGlobal) { - if ( - _globalValidationAllowed(selector) - && ( - _storage.validationData[validationFunction].isGlobal - || validationFunction.eq(_FALLBACK_VALIDATION) - ) - ) { + if (_globalValidationAllowed(selector) && _isValidationGlobal(validationFunction)) { return; } revert ValidationFunctionMissing(selector); @@ -768,51 +751,4 @@ contract UpgradeableModularAccount is } } } - - function _fallbackRuntimeValidation() internal view { - if (msg.sender != _getFallbackSigner()) { - revert FallbackSignerMismatch(); - } - } - - function _fallbackUserOpValidation(PackedUserOperation memory userOp, bytes32 userOpHash) - internal - view - returns (uint256) - { - // Validate the user op signature against the owner. - (address sigSigner,,) = (userOpHash.toEthSignedMessageHash()).tryRecover(userOp.signature); - if (sigSigner == address(0) || sigSigner != _getFallbackSigner()) { - return _SIG_VALIDATION_FAILED; - } - return _SIG_VALIDATION_PASSED; - } - - function _fallbackSignatureValidation(bytes32 digest, bytes calldata signature) - internal - view - returns (bytes4) - { - if (SignatureChecker.isValidSignatureNow(_getFallbackSigner(), digest, signature)) { - return _1271_MAGIC_VALUE; - } - return _1271_INVALID; - } - - function _getFallbackSigner() internal view returns (address) { - AccountStorage storage _storage = getAccountStorage(); - - if (_storage.fallbackSignerDisabled) { - revert FallbackSignerDisabled(); - } - - address storageFallbackSigner = _storage.fallbackSigner; - if (storageFallbackSigner != address(0)) { - return storageFallbackSigner; - } - - bytes memory appendedData = LibClone.argsOnERC1967(address(this), 0, 20); - - return address(uint160(bytes20(appendedData))); - } } diff --git a/src/modules/validation/SingleSignerValidationModule.sol b/src/modules/validation/SingleSignerValidationModule.sol index 7ad1c768..8281f1b8 100644 --- a/src/modules/validation/SingleSignerValidationModule.sol +++ b/src/modules/validation/SingleSignerValidationModule.sol @@ -1,14 +1,13 @@ // SPDX-License-Identifier: GPL-3.0 pragma solidity ^0.8.25; -import {PackedUserOperation} from "@eth-infinitism/account-abstraction/interfaces/PackedUserOperation.sol"; -import {MessageHashUtils} from "@openzeppelin/contracts/utils/cryptography/MessageHashUtils.sol"; -import {SignatureChecker} from "@openzeppelin/contracts/utils/cryptography/SignatureChecker.sol"; - import {IModule, ModuleMetadata} from "../../interfaces/IModule.sol"; import {IValidationModule} from "../../interfaces/IValidationModule.sol"; import {BaseModule} from "../BaseModule.sol"; -import {ISingleSignerValidationModule} from "./ISingleSignerValidationModule.sol"; +import {ISingleSignerValidation} from "./ISingleSignerValidation.sol"; +import {PackedUserOperation} from "@eth-infinitism/account-abstraction/interfaces/PackedUserOperation.sol"; +import {MessageHashUtils} from "@openzeppelin/contracts/utils/cryptography/MessageHashUtils.sol"; +import {SignatureChecker} from "@openzeppelin/contracts/utils/cryptography/SignatureChecker.sol"; /// @title ECSDA Validation /// @author ERC-6900 Authors diff --git a/test/account/AccountReturnData.t.sol b/test/account/AccountReturnData.t.sol index 4e5647cd..3542f3db 100644 --- a/test/account/AccountReturnData.t.sol +++ b/test/account/AccountReturnData.t.sol @@ -2,7 +2,6 @@ pragma solidity ^0.8.19; import {DIRECT_CALL_VALIDATION_ENTITYID} from "../../src/helpers/Constants.sol"; -import {ModuleEntityLib} from "../../src/helpers/ModuleEntityLib.sol"; import {ValidationConfigLib} from "../../src/helpers/ValidationConfigLib.sol"; import {Call} from "../../src/interfaces/IStandardExecutor.sol"; import {IStandardExecutor} from "../../src/interfaces/IStandardExecutor.sol"; @@ -13,7 +12,6 @@ import { ResultCreatorModule } from "../mocks/modules/ReturnDataModuleMocks.sol"; import {AccountTestBase} from "../utils/AccountTestBase.sol"; -import {TEST_DEFAULT_VALIDATION_ENTITY_ID} from "../utils/TestConstants.sol"; // Tests all the different ways that return data can be read from modules through an account contract AccountReturnDataTest is AccountTestBase { diff --git a/test/account/FallbackValidationTest.t.sol b/test/account/FallbackValidationTest.t.sol deleted file mode 100644 index f4efc4b2..00000000 --- a/test/account/FallbackValidationTest.t.sol +++ /dev/null @@ -1,77 +0,0 @@ -// SPDX-License-Identifier: UNLICENSED -pragma solidity ^0.8.25; - -import {PackedUserOperation} from "@eth-infinitism/account-abstraction/interfaces/PackedUserOperation.sol"; -import {MessageHashUtils} from "@openzeppelin/contracts/utils/cryptography/MessageHashUtils.sol"; - -import {UpgradeableModularAccount} from "../../src/account/UpgradeableModularAccount.sol"; - -import {ModuleEntityLib} from "../../src/helpers/ModuleEntityLib.sol"; -import {ModuleEntity} from "../../src/interfaces/IModuleManager.sol"; - -import {AccountTestBase} from "../utils/AccountTestBase.sol"; - -contract FallbackValidationTest is AccountTestBase { - using MessageHashUtils for bytes32; - - address public ethRecipient; - - // A separate account and owner that isn't deployed yet, used to test initcode - address public owner2; - uint256 public owner2Key; - UpgradeableModularAccount public account2; - - ModuleEntity constant FALLBACK_VALIDATION = ModuleEntity.wrap(bytes24(type(uint192).max)); - - function setUp() public { - (owner2, owner2Key) = makeAddrAndKey("owner2"); - - // Compute counterfactual address - account2 = UpgradeableModularAccount(payable(factory.getAddressFallbackSigner(owner2, 0))); - vm.deal(address(account2), 100 ether); - - ethRecipient = makeAddr("ethRecipient"); - vm.deal(ethRecipient, 1 wei); - } - - function test_fallbackValidation_userOp_simple() public { - PackedUserOperation memory userOp = PackedUserOperation({ - sender: address(account2), - nonce: 0, - initCode: abi.encodePacked( - address(factory), abi.encodeCall(factory.createAccountWithFallbackValidation, (owner2, 0)) - ), - callData: abi.encodeCall(UpgradeableModularAccount.execute, (ethRecipient, 1 wei, "")), - accountGasLimits: _encodeGas(VERIFICATION_GAS_LIMIT, CALL_GAS_LIMIT), - preVerificationGas: 0, - gasFees: _encodeGas(1, 1), - paymasterAndData: "", - signature: "" - }); - - // Generate signature - bytes32 userOpHash = entryPoint.getUserOpHash(userOp); - (uint8 v, bytes32 r, bytes32 s) = vm.sign(owner2Key, userOpHash.toEthSignedMessageHash()); - userOp.signature = _encodeSignature(FALLBACK_VALIDATION, GLOBAL_VALIDATION, abi.encodePacked(r, s, v)); - - PackedUserOperation[] memory userOps = new PackedUserOperation[](1); - userOps[0] = userOp; - - entryPoint.handleOps(userOps, beneficiary); - - assertEq(ethRecipient.balance, 2 wei); - } - - function test_fallbackValidation_runtime_simple() public { - // Deploy the account first - factory.createAccountWithFallbackValidation(owner2, 0); - - vm.prank(owner2); - account2.executeWithAuthorization( - abi.encodeCall(UpgradeableModularAccount.execute, (ethRecipient, 1 wei, "")), - _encodeSignature(FALLBACK_VALIDATION, GLOBAL_VALIDATION, "") - ); - - assertEq(ethRecipient.balance, 2 wei); - } -} diff --git a/test/account/ImmutableAppend.t.sol b/test/account/ImmutableAppend.t.sol index a63f0408..dd75f4a5 100644 --- a/test/account/ImmutableAppend.t.sol +++ b/test/account/ImmutableAppend.t.sol @@ -1,20 +1,10 @@ // SPDX-License-Identifier: GPL-3.0 pragma solidity ^0.8.19; -import {IEntryPoint, UpgradeableModularAccount} from "../../src/account/UpgradeableModularAccount.sol"; -import {ModuleEntity, ModuleEntityLib} from "../../src/helpers/ModuleEntityLib.sol"; -import {ValidationConfig, ValidationConfigLib} from "../../src/helpers/ValidationConfigLib.sol"; -import {Call, IStandardExecutor} from "../../src/interfaces/IStandardExecutor.sol"; -import {DirectCallModule} from "../mocks/modules/DirectCallModule.sol"; - import {AccountTestBase} from "../utils/AccountTestBase.sol"; - -import {ERC1967Proxy} from "@openzeppelin/contracts/proxy/ERC1967/ERC1967Proxy.sol"; import {LibClone} from "solady/utils/LibClone.sol"; contract ImmutableAppendTest is AccountTestBase { - using ValidationConfigLib for ValidationConfig; - /* -------------------------------------------------------------------------- */ /* Negatives */ /* -------------------------------------------------------------------------- */ @@ -24,6 +14,11 @@ contract ImmutableAppendTest is AccountTestBase { /* -------------------------------------------------------------------------- */ function test_success_getData() public { + if (!vm.envBool("SMA_TEST")) { + // this test isn't relevant at all for non-SMA, and is temporary. + return; + } + bytes memory expectedArgs = abi.encodePacked(owner1); assertEq(keccak256(LibClone.argsOnERC1967(address(account1))), keccak256(expectedArgs)); diff --git a/test/account/MultiValidation.t.sol b/test/account/MultiValidation.t.sol index f24a2094..78de0657 100644 --- a/test/account/MultiValidation.t.sol +++ b/test/account/MultiValidation.t.sol @@ -64,7 +64,7 @@ contract MultiValidationTest is AccountTestBase { abi.encodeWithSelector( UpgradeableModularAccount.RuntimeValidationFunctionReverted.selector, address(validator2), - 1, + TEST_DEFAULT_VALIDATION_ENTITY_ID, abi.encodeWithSignature("NotAuthorized()") ) ); diff --git a/test/account/PerHookData.t.sol b/test/account/PerHookData.t.sol index 65691a01..f9bfb823 100644 --- a/test/account/PerHookData.t.sol +++ b/test/account/PerHookData.t.sol @@ -344,13 +344,15 @@ contract PerHookDataTest is CustomValidationTestBase { ), abi.encode(_counter) ); - + // patched to also work during SMA tests by differentiating the validation + _signerValidation = + ModuleEntityLib.pack(address(singleSignerValidation), TEST_DEFAULT_VALIDATION_ENTITY_ID - 1); return ( _signerValidation, true, true, new bytes4[](0), - abi.encode(TEST_DEFAULT_VALIDATION_ENTITY_ID, owner1), + abi.encode(TEST_DEFAULT_VALIDATION_ENTITY_ID - 1, owner1), hooks ); } diff --git a/test/account/ReplaceModule.t.sol b/test/account/ReplaceModule.t.sol index c70c20ac..9de8247b 100644 --- a/test/account/ReplaceModule.t.sol +++ b/test/account/ReplaceModule.t.sol @@ -21,7 +21,6 @@ import {IValidationHookModule} from "../../src/interfaces/IValidationHookModule. import {SingleSignerValidationModule} from "../../src/modules/validation/SingleSignerValidationModule.sol"; import {MockModule} from "../mocks/MockModule.sol"; import {AccountTestBase} from "../utils/AccountTestBase.sol"; -import {TEST_DEFAULT_VALIDATION_ENTITY_ID} from "../utils/TestConstants.sol"; interface TestModule { function testFunction() external; diff --git a/test/account/UpgradeableModularAccount.t.sol b/test/account/UpgradeableModularAccount.t.sol index 66e30bb7..29663e13 100644 --- a/test/account/UpgradeableModularAccount.t.sol +++ b/test/account/UpgradeableModularAccount.t.sol @@ -10,6 +10,8 @@ import {ECDSA} from "@openzeppelin/contracts/utils/cryptography/ECDSA.sol"; import {MessageHashUtils} from "@openzeppelin/contracts/utils/cryptography/MessageHashUtils.sol"; import {ModuleManagerInternals} from "../../src/account/ModuleManagerInternals.sol"; + +import {SemiModularAccount} from "../../src/account/SemiModularAccount.sol"; import {UpgradeableModularAccount} from "../../src/account/UpgradeableModularAccount.sol"; import {ExecutionDataView, IAccountLoupe} from "../../src/interfaces/IAccountLoupe.sol"; @@ -52,7 +54,7 @@ contract UpgradeableModularAccountTest is AccountTestBase { (owner2, owner2Key) = makeAddrAndKey("owner2"); // Compute counterfactual address - account2 = UpgradeableModularAccount(payable(factory.getAddressFallbackSigner(owner2, 0))); + account2 = UpgradeableModularAccount(payable(factory.getAddress(owner2, 0))); vm.deal(address(account2), 100 ether); ethRecipient = makeAddr("ethRecipient"); @@ -92,13 +94,22 @@ contract UpgradeableModularAccountTest is AccountTestBase { } function test_basicUserOp_withInitCode() public { + bytes memory callData = vm.envBool("SMA_TEST") + ? abi.encodeCall(SemiModularAccount(payable(account1)).updateFallbackSigner, (owner2)) + : abi.encodeCall( + UpgradeableModularAccount.execute, + ( + address(singleSignerValidation), + 0, + abi.encodeCall(SingleSignerValidation.transferSigner, (TEST_DEFAULT_VALIDATION_ENTITY_ID, owner2)) + ) + ); + PackedUserOperation memory userOp = PackedUserOperation({ sender: address(account2), nonce: 0, - initCode: abi.encodePacked( - address(factory), abi.encodeCall(factory.createAccountWithFallbackValidation, (owner2, 0)) - ), - callData: abi.encodeCall(UpgradeableModularAccount.updateFallbackSigner, (owner2)), + initCode: abi.encodePacked(address(factory), abi.encodeCall(factory.createAccount, (owner2, 0))), + callData: callData, accountGasLimits: _encodeGas(VERIFICATION_GAS_LIMIT, CALL_GAS_LIMIT), preVerificationGas: 0, gasFees: _encodeGas(1, 2), @@ -123,9 +134,7 @@ contract UpgradeableModularAccountTest is AccountTestBase { PackedUserOperation memory userOp = PackedUserOperation({ sender: address(account2), nonce: 0, - initCode: abi.encodePacked( - address(factory), abi.encodeCall(factory.createAccountWithFallbackValidation, (owner2, 0)) - ), + initCode: abi.encodePacked(address(factory), abi.encodeCall(factory.createAccount, (owner2, 0))), callData: abi.encodeCall(UpgradeableModularAccount.execute, (recipient, 1 wei, "")), accountGasLimits: _encodeGas(VERIFICATION_GAS_LIMIT, CALL_GAS_LIMIT), preVerificationGas: 0, @@ -352,9 +361,15 @@ contract UpgradeableModularAccountTest is AccountTestBase { // TODO: Consider if this test belongs here or in the tests specific to the SingleSignerValidation function test_transferOwnership() public { - // Note: replaced "owner1" with address(0), this doesn't actually affect the account, but allows the test - // to pass by ensuring the signer can be set in the validation - assertEq(singleSignerValidation.signers(TEST_DEFAULT_VALIDATION_ENTITY_ID, address(account1)), address(0)); + if (vm.envBool("SMA_TEST")) { + // Note: replaced "owner1" with address(0), this doesn't actually affect the account, but allows the + // test to pass by ensuring the signer can be set in the validation. + assertEq( + singleSignerValidation.signers(TEST_DEFAULT_VALIDATION_ENTITY_ID, address(account1)), address(0) + ); + } else { + assertEq(singleSignerValidation.signers(TEST_DEFAULT_VALIDATION_ENTITY_ID, address(account1)), owner1); + } vm.prank(address(entryPoint)); account1.execute( diff --git a/test/mocks/SingleSignerFactoryFixture.sol b/test/mocks/SingleSignerFactoryFixture.sol index c0dfbe64..1214f9d1 100644 --- a/test/mocks/SingleSignerFactoryFixture.sol +++ b/test/mocks/SingleSignerFactoryFixture.sol @@ -27,7 +27,10 @@ contract SingleSignerFactoryFixture is OptimizedTest { constructor(IEntryPoint _entryPoint, SingleSignerValidationModule _singleSignerValidationModule) { entryPoint = _entryPoint; - accountImplementation = _deployUpgradeableModularAccount(_entryPoint); + + accountImplementation = vm.envBool("SMA_TEST") + ? _deploySemiModularAccount(_entryPoint) + : _deployUpgradeableModularAccount(_entryPoint); _PROXY_BYTECODE_HASH = keccak256( abi.encodePacked(type(ERC1967Proxy).creationCode, abi.encode(address(accountImplementation), "")) ); @@ -45,6 +48,10 @@ contract SingleSignerFactoryFixture is OptimizedTest { * account creation */ function createAccount(address owner, uint256 salt) public returns (UpgradeableModularAccount) { + if (vm.envBool("SMA_TEST")) { + return createAccountWithFallbackValidation(owner, salt); + } + address addr = Create2.computeAddress(getSalt(owner, salt), _PROXY_BYTECODE_HASH); // short circuit if exists @@ -92,6 +99,9 @@ contract SingleSignerFactoryFixture is OptimizedTest { * calculate the counterfactual address of this account as it would be returned by createAccount() */ function getAddress(address owner, uint256 salt) public view returns (address) { + if (vm.envBool("SMA_TEST")) { + return getAddressFallbackSigner(owner, salt); + } return Create2.computeAddress(getSalt(owner, salt), _PROXY_BYTECODE_HASH); } diff --git a/test/module/AllowlistModule.t.sol b/test/module/AllowlistModule.t.sol index c17a9bf9..1edd422a 100644 --- a/test/module/AllowlistModule.t.sol +++ b/test/module/AllowlistModule.t.sol @@ -341,13 +341,15 @@ contract AllowlistModuleTest is CustomValidationTestBase { HookConfigLib.packValidationHook(address(allowlistModule), HOOK_ENTITY_ID), abi.encode(HOOK_ENTITY_ID, allowlistInit) ); - + // patched to also work during SMA tests by differentiating the validation + _signerValidation = + ModuleEntityLib.pack(address(singleSignerValidation), TEST_DEFAULT_VALIDATION_ENTITY_ID - 1); return ( _signerValidation, true, true, new bytes4[](0), - abi.encode(TEST_DEFAULT_VALIDATION_ENTITY_ID, owner1), + abi.encode(TEST_DEFAULT_VALIDATION_ENTITY_ID - 1, owner1), hooks ); } diff --git a/test/module/SingleSignerValidationModule.t.sol b/test/module/SingleSignerValidationModule.t.sol index 7bddd5d8..24b97c2d 100644 --- a/test/module/SingleSignerValidationModule.t.sol +++ b/test/module/SingleSignerValidationModule.t.sol @@ -83,7 +83,7 @@ contract SingleSignerValidationModuleTest is AccountTestBase { } function test_runtime_with2SameValidationInstalled() public { - uint32 newEntityId = TEST_DEFAULT_VALIDATION_ENTITY_ID + 1; + uint32 newEntityId = TEST_DEFAULT_VALIDATION_ENTITY_ID - 1; vm.prank(address(entryPoint)); vm.expectEmit(address(singleSignerValidationModule)); diff --git a/test/utils/AccountTestBase.sol b/test/utils/AccountTestBase.sol index 34824635..297a0914 100644 --- a/test/utils/AccountTestBase.sol +++ b/test/utils/AccountTestBase.sol @@ -5,6 +5,7 @@ import {EntryPoint} from "@eth-infinitism/account-abstraction/core/EntryPoint.so import {PackedUserOperation} from "@eth-infinitism/account-abstraction/interfaces/PackedUserOperation.sol"; import {MessageHashUtils} from "@openzeppelin/contracts/utils/cryptography/MessageHashUtils.sol"; +import {SemiModularAccount} from "../../src/account/SemiModularAccount.sol"; import {UpgradeableModularAccount} from "../../src/account/UpgradeableModularAccount.sol"; import {ModuleEntity, ModuleEntityLib} from "../../src/helpers/ModuleEntityLib.sol"; import {Call, IStandardExecutor} from "../../src/interfaces/IStandardExecutor.sol"; @@ -53,13 +54,20 @@ abstract contract AccountTestBase is OptimizedTest { (owner1, owner1Key) = makeAddrAndKey("owner1"); beneficiary = payable(makeAddr("beneficiary")); - singleSignerValidationModule = _deploySingleSignerValidationModule(); - factory = new SingleSignerFactoryFixture(entryPoint, singleSignerValidationModule); + address deployedSingleSignerValidation = address(_deploySingleSignerValidation()); - account1 = factory.createAccountWithFallbackValidation(owner1, 0); + // We etch the single signer validation to the max address, such that it coincides with the fallback + // validation module entity for semi modular account tests. + singleSignerValidation = SingleSignerValidation(address(type(uint160).max)); + vm.etch(address(singleSignerValidation), deployedSingleSignerValidation.code); + + factory = new SingleSignerFactoryFixture(entryPoint, singleSignerValidation); + + account1 = factory.createAccount(owner1, 0); vm.deal(address(account1), 100 ether); - _signerValidation = ModuleEntity.wrap(bytes24(type(uint192).max)); + _signerValidation = + ModuleEntityLib.pack(address(singleSignerValidation), TEST_DEFAULT_VALIDATION_ENTITY_ID); } function _runExecUserOp(address target, bytes memory callData) internal { @@ -162,8 +170,24 @@ abstract contract AccountTestBase is OptimizedTest { function _transferOwnershipToTest() internal { // Transfer ownership to test contract for easier invocation. vm.prank(owner1); + if (vm.envBool("SMA_TEST")) { + account1.executeWithAuthorization( + abi.encodeCall(SemiModularAccount(payable(account1)).updateFallbackSigner, (address(this))), + _encodeSignature(_signerValidation, GLOBAL_VALIDATION, "") + ); + return; + } account1.executeWithAuthorization( - abi.encodeCall(account1.updateFallbackSigner, (address(this))), + abi.encodeCall( + account1.execute, + ( + address(singleSignerValidation), + 0, + abi.encodeCall( + SingleSignerValidation.transferSigner, (TEST_DEFAULT_VALIDATION_ENTITY_ID, address(this)) + ) + ) + ), _encodeSignature(_signerValidation, GLOBAL_VALIDATION, "") ); } diff --git a/test/utils/CustomValidationTestBase.sol b/test/utils/CustomValidationTestBase.sol index a7920623..ccac8da7 100644 --- a/test/utils/CustomValidationTestBase.sol +++ b/test/utils/CustomValidationTestBase.sol @@ -28,14 +28,23 @@ abstract contract CustomValidationTestBase is AccountTestBase { account1 = UpgradeableModularAccount(payable(new ERC1967Proxy{salt: 0}(accountImplementation, ""))); - _beforeInstallStep(address(account1)); - - account1.initializeWithValidation( - ValidationConfigLib.pack(validationFunction, isGlobal, isSignatureValidation), - selectors, - installData, - hooks - ); + if (vm.envBool("SMA_TEST")) { + vm.prank(address(entryPoint)); + // The initializer doesn't work on the SMA + account1.installValidation( + ValidationConfigLib.pack(validationFunction, isGlobal, isSignatureValidation), + selectors, + installData, + hooks + ); + } else { + account1.initializeWithValidation( + ValidationConfigLib.pack(validationFunction, isGlobal, isSignatureValidation), + selectors, + installData, + hooks + ); + } vm.deal(address(account1), 100 ether); } diff --git a/test/utils/OptimizedTest.sol b/test/utils/OptimizedTest.sol index 7792badb..a2fa3c22 100644 --- a/test/utils/OptimizedTest.sol +++ b/test/utils/OptimizedTest.sol @@ -5,6 +5,7 @@ import {Test} from "forge-std/Test.sol"; import {IEntryPoint} from "@eth-infinitism/account-abstraction/interfaces/IEntryPoint.sol"; +import {SemiModularAccount} from "../../src/account/SemiModularAccount.sol"; import {UpgradeableModularAccount} from "../../src/account/UpgradeableModularAccount.sol"; import {TokenReceiverModule} from "../../src/modules/TokenReceiverModule.sol"; @@ -45,6 +46,16 @@ abstract contract OptimizedTest is Test { : new UpgradeableModularAccount(entryPoint); } + function _deploySemiModularAccount(IEntryPoint entryPoint) internal returns (UpgradeableModularAccount) { + return _isOptimizedTest() + ? UpgradeableModularAccount( + payable( + deployCode("out-optimized/SemiModularAccount.sol/SemiModularAccount.json", abi.encode(entryPoint)) + ) + ) + : UpgradeableModularAccount(new SemiModularAccount(entryPoint)); + } + function _deployTokenReceiverModule() internal returns (TokenReceiverModule) { return _isOptimizedTest() ? TokenReceiverModule(deployCode("out-optimized/TokenReceiverModule.sol/TokenReceiverModule.json")) diff --git a/test/utils/TestConstants.sol b/test/utils/TestConstants.sol index c15b2dd3..119bcd0b 100644 --- a/test/utils/TestConstants.sol +++ b/test/utils/TestConstants.sol @@ -1,4 +1,4 @@ // SPDX-License-Identifier: UNLICENSED pragma solidity ^0.8.25; -uint32 constant TEST_DEFAULT_VALIDATION_ENTITY_ID = 1; +uint32 constant TEST_DEFAULT_VALIDATION_ENTITY_ID = type(uint32).max;