From e07c3d6816c776eeeb3179c96dbd3f827f313df6 Mon Sep 17 00:00:00 2001 From: Jay Paik Date: Wed, 20 Sep 2023 11:19:00 -0400 Subject: [PATCH] fix: address best practices --- src/CustomSlotInitializable.sol | 60 +++++++++++++++++-------- src/LightAccount.sol | 71 ++++++++++++++++++++++-------- test/CustomSlotInitializable.t.sol | 8 ++-- test/LightAccount.t.sol | 16 +++---- 4 files changed, 107 insertions(+), 48 deletions(-) diff --git a/src/CustomSlotInitializable.sol b/src/CustomSlotInitializable.sol index 8a2ab8f..f752fb1 100644 --- a/src/CustomSlotInitializable.sol +++ b/src/CustomSlotInitializable.sol @@ -3,8 +3,6 @@ pragma solidity ^0.8.21; -import {Address} from "@openzeppelin/contracts/utils/Address.sol"; - /** * @dev Identical to OpenZeppelin's `Initializable`, except that its state variables are kept at a custom storage slot * instead of at the start of storage. @@ -73,6 +71,16 @@ abstract contract CustomSlotInitializable { bool initializing; } + /** + * @dev The contract is already initialized. + */ + error InvalidInitialization(); + + /** + * @dev The contract is not initializing. + */ + error NotInitializing(); + /** * @dev Triggered when the contract has been initialized or reinitialized. */ @@ -92,13 +100,23 @@ abstract contract CustomSlotInitializable { * Emits an {Initialized} event. */ modifier initializer() { - CustomSlotInitializableStorage storage _storage = _getInitialiazableStorage(); + CustomSlotInitializableStorage storage _storage = _getInitializableStorage(); + + // Cache values to avoid duplicated sloads bool isTopLevelCall = !_storage.initializing; - require( - (isTopLevelCall && _storage.initialized < 1) - || (!Address.isContract(address(this)) && _storage.initialized == 1), - "Initializable: contract is already initialized" - ); + uint64 initialized = _storage.initialized; + + // Allowed calls: + // - initialSetup: the contract is not in the initializing state and no previous version was + // initialized + // - construction: the contract is initialized at version 1 (no reininitialization) and the + // current contract is just being deployed + bool initialSetup = initialized == 0 && isTopLevelCall; + bool construction = initialized == 1 && address(this).code.length == 0; + + if (!initialSetup && !construction) { + revert InvalidInitialization(); + } _storage.initialized = 1; if (isTopLevelCall) { _storage.initializing = true; @@ -129,10 +147,11 @@ abstract contract CustomSlotInitializable { * Emits an {Initialized} event. */ modifier reinitializer(uint64 version) { - CustomSlotInitializableStorage storage _storage = _getInitialiazableStorage(); - require( - !_storage.initializing && _storage.initialized < version, "Initializable: contract is already initialized" - ); + CustomSlotInitializableStorage storage _storage = _getInitializableStorage(); + + if (_storage.initializing || _storage.initialized >= version) { + revert InvalidInitialization(); + } _storage.initialized = version; _storage.initializing = true; _; @@ -145,7 +164,9 @@ abstract contract CustomSlotInitializable { * {initializer} and {reinitializer} modifiers, directly or indirectly. */ modifier onlyInitializing() { - require(_getInitialiazableStorage().initializing, "Initializable: contract is not initializing"); + if (!_isInitializing()) { + revert NotInitializing(); + } _; } @@ -158,8 +179,11 @@ abstract contract CustomSlotInitializable { * Emits an {Initialized} event the first time it is successfully executed. */ function _disableInitializers() internal virtual { - CustomSlotInitializableStorage storage _storage = _getInitialiazableStorage(); - require(!_storage.initializing, "Initializable: contract is initializing"); + CustomSlotInitializableStorage storage _storage = _getInitializableStorage(); + + if (_storage.initializing) { + revert InvalidInitialization(); + } if (_storage.initialized != type(uint64).max) { _storage.initialized = type(uint64).max; emit Initialized(type(uint64).max); @@ -170,17 +194,17 @@ abstract contract CustomSlotInitializable { * @dev Returns the highest version that has been initialized. See {reinitializer}. */ function _getInitializedVersion() internal view returns (uint64) { - return _getInitialiazableStorage().initialized; + return _getInitializableStorage().initialized; } /** * @dev Returns `true` if the contract is currently initializing. See {onlyInitializing}. */ function _isInitializing() internal view returns (bool) { - return _getInitialiazableStorage().initializing; + return _getInitializableStorage().initializing; } - function _getInitialiazableStorage() private view returns (CustomSlotInitializableStorage storage _storage) { + function _getInitializableStorage() private view returns (CustomSlotInitializableStorage storage _storage) { bytes32 position = _storagePosition; assembly { _storage.slot := position diff --git a/src/LightAccount.sol b/src/LightAccount.sol index 3b0324c..3e026c8 100644 --- a/src/LightAccount.sol +++ b/src/LightAccount.sol @@ -42,6 +42,8 @@ import {CustomSlotInitializable} from "./CustomSlotInitializable.sol"; * user operations through a bundler. * * 4. Event `SimpleAccountInitialized` renamed to `LightAccountInitialized`. + * + * 5. Uses custom errors. */ contract LightAccount is BaseAccount, TokenCallbackHandler, UUPSUpgradeable, CustomSlotInitializable, IERC1271 { using ECDSA for bytes32; @@ -74,6 +76,22 @@ contract LightAccount is BaseAccount, TokenCallbackHandler, UUPSUpgradeable, Cus */ event OwnershipTransferred(address indexed previousOwner, address indexed newOwner); + /** + * @dev The length of the array does not match the expected length. + */ + error ArrayLengthMismatch(); + + /** + * @dev The new owner is not a valid owner (e.g., `address(0)` or the + * account itself). + */ + error InvalidOwner(address owner); + + /** + * @dev The caller is not authorized. + */ + error NotAuthorized(address caller); + modifier onlyOwner() { _onlyOwner(); _; @@ -108,9 +126,15 @@ contract LightAccount is BaseAccount, TokenCallbackHandler, UUPSUpgradeable, Cus */ function executeBatch(address[] calldata dest, bytes[] calldata func) external { _requireFromEntryPointOrOwner(); - require(dest.length == func.length, "wrong array lengths"); - for (uint256 i = 0; i < dest.length; i++) { + if (dest.length != func.length) { + revert ArrayLengthMismatch(); + } + uint256 length = dest.length; + for (uint256 i = 0; i < length;) { _call(dest[i], 0, func[i]); + unchecked { + ++i; + } } } @@ -124,12 +148,31 @@ contract LightAccount is BaseAccount, TokenCallbackHandler, UUPSUpgradeable, Cus */ function executeBatch(address[] calldata dest, uint256[] calldata value, bytes[] calldata func) external { _requireFromEntryPointOrOwner(); - require(dest.length == func.length && dest.length == value.length, "wrong array lengths"); - for (uint256 i = 0; i < dest.length; i++) { + if (dest.length != func.length || dest.length != value.length) { + revert ArrayLengthMismatch(); + } + uint256 length = dest.length; + for (uint256 i = 0; i < length;) { _call(dest[i], value[i], func[i]); + unchecked { + ++i; + } } } + /** + * @notice Transfers ownership of the contract to a new account (`newOwner`). + * Can only be called by the current owner or from the entry point via a + * user operation signed by the current owner. + * @param newOwner The new owner + */ + function transferOwnership(address newOwner) external virtual onlyOwner { + if (newOwner == address(0) || newOwner == address(this)) { + revert InvalidOwner(newOwner); + } + _transferOwnership(newOwner); + } + /** * @notice Called once as part of initialization, either during initial deployment or when first upgrading to * this contract. @@ -158,18 +201,6 @@ contract LightAccount is BaseAccount, TokenCallbackHandler, UUPSUpgradeable, Cus entryPoint().withdrawTo(withdrawAddress, amount); } - /** - * @notice Transfers ownership of the contract to a new account (`newOwner`). - * Can only be called by the current owner or from the entry point via a - * user operation signed by the current owner. - * @param newOwner The new owner - */ - function transferOwnership(address newOwner) public virtual onlyOwner { - require(newOwner != address(0), "account: new owner is the zero address"); - require(newOwner != address(this), "account: new owner is self"); - _transferOwnership(newOwner); - } - /// @inheritdoc BaseAccount function entryPoint() public view virtual override returns (IEntryPoint) { return _entryPoint; @@ -252,12 +283,16 @@ contract LightAccount is BaseAccount, TokenCallbackHandler, UUPSUpgradeable, Cus function _onlyOwner() internal view { //directly from EOA owner, or through the account itself (which gets redirected through execute()) - require(msg.sender == owner() || msg.sender == address(this), "only owner"); + if (msg.sender != address(this) && msg.sender != owner()) { + revert NotAuthorized(msg.sender); + } } // Require the function call went through EntryPoint or owner function _requireFromEntryPointOrOwner() internal view { - require(msg.sender == address(entryPoint()) || msg.sender == owner(), "account: not Owner or EntryPoint"); + if (msg.sender != address(entryPoint()) && msg.sender != owner()) { + revert NotAuthorized(msg.sender); + } } function _call(address target, uint256 value, bytes memory data) internal { diff --git a/test/CustomSlotInitializable.t.sol b/test/CustomSlotInitializable.t.sol index eb405ec..afb2a31 100644 --- a/test/CustomSlotInitializable.t.sol +++ b/test/CustomSlotInitializable.t.sol @@ -40,26 +40,26 @@ contract CustomSlotInitializableTest is Test { } function testCannotReinitialize() public { - vm.expectRevert(bytes("Initializable: contract is already initialized")); + vm.expectRevert(CustomSlotInitializable.InvalidInitialization.selector); v1Proxy.upgradeToAndCall(v1Impl, abi.encodeCall(V1.initialize, ())); } function testCannotUpgradeBackwards() public { v1Proxy.upgradeToAndCall(v2Impl, abi.encodeCall(V2.initialize, ())); V2 v2Proxy = V2(address(v1Proxy)); - vm.expectRevert(bytes("Initializable: contract is already initialized")); + vm.expectRevert(CustomSlotInitializable.InvalidInitialization.selector); v2Proxy.upgradeToAndCall(v1Impl, abi.encodeCall(V1.initialize, ())); } function testDisableInitializers() public { v1Proxy.disableInitializers(); - vm.expectRevert(bytes("Initializable: contract is already initialized")); + vm.expectRevert(CustomSlotInitializable.InvalidInitialization.selector); v1Proxy.upgradeToAndCall(v2Impl, abi.encodeCall(V2.initialize, ())); } function testCannotCallDisableInitializersInInitializer() public { DisablesInitializersWhileInitializing account = new DisablesInitializersWhileInitializing(); - vm.expectRevert("Initializable: contract is initializing"); + vm.expectRevert(CustomSlotInitializable.InvalidInitialization.selector); account.initialize(); } diff --git a/test/LightAccount.t.sol b/test/LightAccount.t.sol index 4662ae5..f0e0a2c 100644 --- a/test/LightAccount.t.sol +++ b/test/LightAccount.t.sol @@ -82,7 +82,7 @@ contract LightAccountTest is Test { } function testExecuteCannotBeCalledByRandos() public { - vm.expectRevert(bytes("account: not Owner or EntryPoint")); + vm.expectRevert(abi.encodeWithSelector(LightAccount.NotAuthorized.selector, (address(this)))); account.execute(address(lightSwitch), 0, abi.encodeCall(LightSwitch.turnOn, ())); } @@ -110,7 +110,7 @@ contract LightAccountTest is Test { dest[1] = address(lightSwitch); bytes[] memory func = new bytes[](1); func[0] = abi.encodeCall(LightSwitch.turnOn, ()); - vm.expectRevert(bytes("wrong array lengths")); + vm.expectRevert(LightAccount.ArrayLengthMismatch.selector); account.executeBatch(dest, func); } @@ -136,7 +136,7 @@ contract LightAccountTest is Test { value[1] = uint256(1 ether); bytes[] memory func = new bytes[](1); func[0] = abi.encodeCall(LightSwitch.turnOn, ()); - vm.expectRevert(bytes("wrong array lengths")); + vm.expectRevert(LightAccount.ArrayLengthMismatch.selector); account.executeBatch(dest, value, func); } @@ -163,7 +163,7 @@ contract LightAccountTest is Test { function testWithdrawDepositToCannotBeCalledByRandos() public { account.addDeposit{value: 10}(); - vm.expectRevert(bytes("only owner")); + vm.expectRevert(abi.encodeWithSelector(LightAccount.NotAuthorized.selector, (address(this)))); account.withdrawDepositTo(BENEFICIARY, 5); } @@ -189,19 +189,19 @@ contract LightAccountTest is Test { } function testRandosCannotTransferOwnership() public { - vm.expectRevert(bytes("only owner")); + vm.expectRevert(abi.encodeWithSelector(LightAccount.NotAuthorized.selector, (address(this)))); account.transferOwnership(address(0x100)); } function testCannotTransferOwnershipToZero() public { vm.prank(eoaAddress); - vm.expectRevert(bytes("account: new owner is the zero address")); + vm.expectRevert(abi.encodeWithSelector(LightAccount.InvalidOwner.selector, (address(0)))); account.transferOwnership(address(0)); } function testCannotTransferOwnershipToLightContractItself() public { vm.prank(eoaAddress); - vm.expectRevert(bytes("account: new owner is self")); + vm.expectRevert(abi.encodeWithSelector(LightAccount.InvalidOwner.selector, (address(account)))); account.transferOwnership(address(account)); } @@ -244,7 +244,7 @@ contract LightAccountTest is Test { // Try to upgrade to a normal SimpleAccount with a different entry point. IEntryPoint newEntryPoint = IEntryPoint(address(0x2000)); SimpleAccount newImplementation = new SimpleAccount(newEntryPoint); - vm.expectRevert(bytes("only owner")); + vm.expectRevert(abi.encodeWithSelector(LightAccount.NotAuthorized.selector, (address(this)))); account.upgradeToAndCall(address(newImplementation), abi.encodeCall(SimpleAccount.initialize, (address(this)))); }