diff --git a/src/Kernel.sol b/src/Kernel.sol index c210276a..46e6673c 100644 --- a/src/Kernel.sol +++ b/src/Kernel.sol @@ -83,7 +83,7 @@ contract Kernel is IAccount, IAccountExecute, IERC7579Account, ValidationManager if (validator.isModuleType(4)) { bytes memory ret = IHook(address(validator)).preCheck(msg.sender, msg.value, msg.data); _; - IHook(address(validator)).postCheck(ret, true, hex""); // TODO don't support try catch hook here + IHook(address(validator)).postCheck(ret); // TODO don't support try catch hook here } else { revert InvalidCaller(); } @@ -171,7 +171,7 @@ contract Kernel is IAccount, IAccountExecute, IERC7579Account, ValidationManager revert NotSupportedCallType(); } if (address(config.hook) != address(1)) { - _doPostHook(config.hook, context, success, result); + _doPostHook(config.hook, context); } } if (!success) { @@ -255,7 +255,7 @@ contract Kernel is IAccount, IAccountExecute, IERC7579Account, ValidationManager } (bool success, bytes memory ret) = ExecLib.executeDelegatecall(address(this), userOp.callData[4:]); if (address(hook) != address(1)) { - _doPostHook(hook, context, success, ret); + _doPostHook(hook, context); } else if (!success) { revert ExecutionReverted(); } @@ -277,7 +277,7 @@ contract Kernel is IAccount, IAccountExecute, IERC7579Account, ValidationManager } returnData = ExecLib.execute(execMode, executionCalldata); if (address(hook) != address(1)) { - _doPostHook(hook, context, true, abi.encode(returnData)); + _doPostHook(hook, context); } } diff --git a/src/core/HookManager.sol b/src/core/HookManager.sol index f622f515..2e5a5878 100644 --- a/src/core/HookManager.sol +++ b/src/core/HookManager.sol @@ -17,10 +17,10 @@ abstract contract HookManager { context = hook.preCheck(msg.sender, value, callData); } - function _doPostHook(IHook hook, bytes memory context, bool success, bytes memory result) internal { + function _doPostHook(IHook hook, bytes memory context) internal { // bool success, // bytes memory result - hook.postCheck(context, success, result); + hook.postCheck(context); } // @notice if hook is not initialized before, kernel will call hook.onInstall no matter what flag it shows, with hookData[1:] diff --git a/src/interfaces/IERC7579Modules.sol b/src/interfaces/IERC7579Modules.sol index 3d778be6..d07da1b7 100644 --- a/src/interfaces/IERC7579Modules.sol +++ b/src/interfaces/IERC7579Modules.sol @@ -69,30 +69,12 @@ interface IValidator is IModule { interface IExecutor is IModule {} interface IHook is IModule { - /** - * @dev Called by the smart account before execution - * @param msgSender the address that called the smart account - * @param value the value that was sent to the smart account - * @param msgData the data that was sent to the smart account - * - * MAY return arbitrary data in the `hookData` return value - */ - function preCheck(address msgSender, uint256 value, bytes calldata msgData) + function preCheck(address msgSender, uint256 msgValue, bytes calldata msgData) external payable returns (bytes memory hookData); - /** - * @dev Called by the smart account after execution - * @param hookData the data that was returned by the `preCheck` function - * @param executionSuccess whether the execution(s) was (were) successful - * @param executionReturn the return/revert data of the execution(s) - * - * MAY validate the `hookData` to validate transaction context of the `preCheck` function - */ - function postCheck(bytes calldata hookData, bool executionSuccess, bytes calldata executionReturn) - external - payable; + function postCheck(bytes calldata hookData) external payable; } interface IFallback is IModule {} diff --git a/src/mock/MockHook.sol b/src/mock/MockHook.sol index 7acaadac..912c7c76 100644 --- a/src/mock/MockHook.sol +++ b/src/mock/MockHook.sol @@ -34,7 +34,7 @@ contract MockHook is IHook { return data[msg.sender]; } - function postCheck(bytes calldata hookData, bool success, bytes memory res) external payable override { + function postCheck(bytes calldata hookData) external payable override { postHookData[msg.sender] = hookData; } } diff --git a/src/mock/MockValidator.sol b/src/mock/MockValidator.sol index 7e221212..96677dca 100644 --- a/src/mock/MockValidator.sol +++ b/src/mock/MockValidator.sol @@ -73,10 +73,7 @@ contract MockValidator is IValidator, IHook { return hex""; } - function postCheck(bytes calldata hookData, bool executionSuccess, bytes calldata executionReturn) - external - payable - { + function postCheck(bytes calldata hookData) external payable { return; } } diff --git a/src/sdk/KernelTestBase.sol b/src/sdk/KernelTestBase.sol index aedbaafd..cd2ba2a4 100644 --- a/src/sdk/KernelTestBase.sol +++ b/src/sdk/KernelTestBase.sol @@ -17,13 +17,24 @@ import "../mock/MockERC721.sol"; import "../mock/MockERC1155.sol"; import "../core/ValidationManager.sol"; import "./TestBase/erc4337Util.sol"; +import "../types/Types.sol"; +import "../types/Structs.sol"; contract MockCallee { uint256 public value; + event MockEvent(address indexed caller, address indexed here); + function setValue(uint256 _value) public { value = _value; } + + function emitEvent(bool shouldFail) public { + if (shouldFail) { + revert("Hello"); + } + emit MockEvent(msg.sender, address(this)); + } } abstract contract KernelTestBase is Test { @@ -853,6 +864,46 @@ abstract contract KernelTestBase is Test { entrypoint.handleOps(ops, payable(address(0xdeadbeef))); } + function testExecute(CallType callType, ExecType execType, bool shouldFail) external whenInitialized { + unchecked { + vm.assume(uint8(CallType.unwrap(callType)) + 1 < 3); //only call/batch/delegatecall + vm.assume(uint8(ExecType.unwrap(execType)) < 2); + } + vm.startPrank(address(entrypoint)); + ExecMode code = ExecLib.encode(callType, execType, ExecModeSelector.wrap(0x00), ExecModePayload.wrap(0x00)); + if (callType == CALLTYPE_BATCH) { + Execution[] memory execs = new Execution[](1); + execs[0] = Execution({ + target: address(callee), + value: 0, + callData: abi.encodeWithSelector(MockCallee.emitEvent.selector, shouldFail) + }); + bytes memory data = ExecLib.encodeBatch(execs); + if (execType == EXECTYPE_DEFAULT && shouldFail) { + vm.expectRevert(); + } + kernel.execute(code, data); + } else if (callType == CALLTYPE_SINGLE) { + if (execType == EXECTYPE_DEFAULT && shouldFail) { + vm.expectRevert(); + } + kernel.execute( + code, + abi.encodePacked( + address(callee), uint256(0), abi.encodeWithSelector(MockCallee.emitEvent.selector, shouldFail) + ) + ); + } else { + if (execType == EXECTYPE_DEFAULT && shouldFail) { + vm.expectRevert(); + } + kernel.execute( + code, + abi.encodePacked(address(callee), abi.encodeWithSelector(MockCallee.emitEvent.selector, shouldFail)) + ); + } + } + function testExecutorInstall(bool withHook) external whenInitialized { _installExecutor(withHook); assertEq(mockExecutor.data(address(kernel)), abi.encodePacked("executorData")); diff --git a/src/utils/ExecLib.sol b/src/utils/ExecLib.sol index 631ca0b8..08485b2c 100644 --- a/src/utils/ExecLib.sol +++ b/src/utils/ExecLib.sol @@ -49,9 +49,18 @@ library ExecLib { revert("Unsupported"); } } else if (callType == CALLTYPE_DELEGATECALL) { + returnData = new bytes[](1); address delegate = address(bytes20(executionCalldata[0:20])); bytes calldata callData = executionCalldata[20:]; - executeDelegatecall(delegate, callData); + bool success; + (success, returnData[0]) = executeDelegatecall(delegate, callData); + if (execType == EXECTYPE_TRY) { + if (!success) emit TryExecuteUnsuccessful(0, returnData[0]); + } else if (execType == EXECTYPE_DEFAULT) { + if (!success) revert("Delegatecall failed"); + } else { + revert("Unsupported"); + } } else { revert("Unsupported"); } diff --git a/src/validator/ECDSAValidator.sol b/src/validator/ECDSAValidator.sol index 00e6aca5..4d9e98c1 100644 --- a/src/validator/ECDSAValidator.sol +++ b/src/validator/ECDSAValidator.sol @@ -94,5 +94,5 @@ contract ECDSAValidator is IValidator, IHook { return hex""; } - function postCheck(bytes calldata hookData, bool success, bytes calldata res) external payable override {} + function postCheck(bytes calldata hookData) external payable override {} } diff --git a/src/validator/MultiSignatureECDSAValidator.sol b/src/validator/MultiSignatureECDSAValidator.sol index e9084de6..98121e4b 100644 --- a/src/validator/MultiSignatureECDSAValidator.sol +++ b/src/validator/MultiSignatureECDSAValidator.sol @@ -129,5 +129,5 @@ contract MultiSignatureECDSAValidator is IValidator, IHook { return hex""; } - function postCheck(bytes calldata hookData, bool success, bytes calldata res) external payable override {} + function postCheck(bytes calldata hookData) external payable override {} }