Skip to content

Commit

Permalink
fix: catch data in all cases (#151)
Browse files Browse the repository at this point in the history
  • Loading branch information
adamegyed authored Aug 22, 2024
1 parent 78bc489 commit 0c18f92
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 6 deletions.
4 changes: 3 additions & 1 deletion src/account/ModuleManagerInternals.sol
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ pragma solidity ^0.8.25;
import {ERC165Checker} from "@openzeppelin/contracts/utils/introspection/ERC165Checker.sol";
import {EnumerableSet} from "@openzeppelin/contracts/utils/structs/EnumerableSet.sol";

import {collectReturnData} from "../helpers/CollectReturnData.sol";
import {MAX_PRE_VALIDATION_HOOKS} from "../helpers/Constants.sol";
import {HookConfigLib} from "../helpers/HookConfigLib.sol";
import {KnownSelectors} from "../helpers/KnownSelectors.sol";
Expand Down Expand Up @@ -170,7 +171,8 @@ abstract contract ModuleManagerInternals is IModularAccount {
// Initialize the module storage for the account.
// solhint-disable-next-line no-empty-blocks
try IModule(module).onInstall(moduleInstallData) {}
catch (bytes memory revertReason) {
catch {
bytes memory revertReason = collectReturnData();
revert ModuleInstallCallbackFailed(module, revertReason);
}

Expand Down
14 changes: 9 additions & 5 deletions src/account/UpgradeableModularAccount.sol
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import {UUPSUpgradeable} from "@openzeppelin/contracts/proxy/utils/UUPSUpgradeab
import {IERC165} from "@openzeppelin/contracts/utils/introspection/IERC165.sol";
import {EnumerableSet} from "@openzeppelin/contracts/utils/structs/EnumerableSet.sol";

import {collectReturnData} from "../helpers/CollectReturnData.sol";
import {DIRECT_CALL_VALIDATION_ENTITYID} from "../helpers/Constants.sol";
import {HookConfig, HookConfigLib} from "../helpers/HookConfigLib.sol";
import {ModuleEntityLib} from "../helpers/ModuleEntityLib.sol";
Expand Down Expand Up @@ -460,8 +461,8 @@ contract UpgradeableModularAccount is
bytes memory returnData
) {
preExecHookReturnData = returnData;
} catch (bytes memory revertReason) {
// TODO: same issue with EP0.6 - we can't do bytes4 error codes in modules
} catch {
bytes memory revertReason = collectReturnData();
revert PreExecHookReverted(module, entityId, revertReason);
}
}
Expand All @@ -483,7 +484,8 @@ contract UpgradeableModularAccount is
(address module, uint32 entityId) = postHookToRun.postExecHook.unpack();
// solhint-disable-next-line no-empty-blocks
try IExecutionHookModule(module).postExecutionHook(entityId, postHookToRun.preExecHookReturnData) {}
catch (bytes memory revertReason) {
catch {
bytes memory revertReason = collectReturnData();
revert PostExecHookReverted(module, entityId, revertReason);
}
}
Expand All @@ -500,8 +502,9 @@ contract UpgradeableModularAccount is
)
// forgefmt: disable-start
// solhint-disable-next-line no-empty-blocks
{} catch (bytes memory revertReason){
{} catch{
// forgefmt: disable-end
bytes memory revertReason = collectReturnData();
revert PreRuntimeValidationHookFailed(hookModule, hookEntityId, revertReason);
}
}
Expand Down Expand Up @@ -585,8 +588,9 @@ contract UpgradeableModularAccount is
)
// forgefmt: disable-start
// solhint-disable-next-line no-empty-blocks
{} catch (bytes memory revertReason){
{} catch{
// forgefmt: disable-end
bytes memory revertReason = collectReturnData();
revert RuntimeValidationFunctionReverted(module, entityId, revertReason);
}
}
Expand Down
14 changes: 14 additions & 0 deletions src/helpers/CollectReturnData.sol
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
// SPDX-License-Identifier: UNLICENSED
pragma solidity ^0.8.25;

function collectReturnData() pure returns (bytes memory returnData) {
assembly ("memory-safe") {
// Allocate a buffer of that size, advancing the memory pointer to the nearest word
returnData := mload(0x40)
mstore(returnData, returndatasize())
mstore(0x40, and(add(add(returnData, returndatasize()), 0x3f), not(0x1f)))

// Copy over the return data
returndatacopy(add(returnData, 0x20), 0, returndatasize())
}
}

0 comments on commit 0c18f92

Please sign in to comment.