Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add default hook, now only support msg.sender == entrypoint as default #110

Merged
merged 2 commits into from
May 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 16 additions & 4 deletions src/Kernel.sol
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ contract Kernel is IAccount, IAccountExecute, IERC7579Account, ValidationManager

function _domainNameAndVersion() internal pure override returns (string memory name, string memory version) {
name = "Kernel";
version = "0.3.0-beta";
version = "0.3.1-beta";
}

receive() external payable {
Expand Down Expand Up @@ -151,8 +151,16 @@ contract Kernel is IAccount, IAccountExecute, IERC7579Account, ValidationManager
} else {
// action installed
bytes memory context;
if (address(config.hook) != address(1)) {
if (
address(config.hook) != address(1) && address(config.hook) != 0xFFfFfFffFFfffFFfFFfFFFFFffFFFffffFfFFFfF
) {
context = _doPreHook(config.hook, msg.value, msg.data);
} else if (address(config.hook) == 0xFFfFfFffFFfffFFfFFfFFFFFffFFFffffFfFFFfF) {
// for selector manager, address(0) for the hook will default to type(address).max,
// and this will only allow entrypoints to interact
if (msg.sender != address(entrypoint)) {
revert InvalidCaller();
}
}
// execute action
if (config.callType == CALLTYPE_SINGLE) {
Expand Down Expand Up @@ -312,14 +320,18 @@ contract Kernel is IAccount, IAccountExecute, IERC7579Account, ValidationManager
ValidationConfig({nonce: vs.currentNonce, hook: IHook(address(bytes20(initData[0:20])))});
bytes calldata validatorData;
bytes calldata hookData;
bytes calldata selectorData;
assembly {
validatorData.offset := add(add(initData.offset, 52), calldataload(add(initData.offset, 20)))
validatorData.length := calldataload(sub(validatorData.offset, 32))
hookData.offset := add(add(initData.offset, 52), calldataload(add(initData.offset, 52)))
hookData.length := calldataload(sub(hookData.offset, 32))
selectorData.offset := add(add(initData.offset, 52), calldataload(add(initData.offset, 84)))
selectorData.length := calldataload(sub(selectorData.offset, 32))
}
_installValidation(vId, config, validatorData, hookData);
//_installHook(config.hook, hookData); hook install is handled inside installvalidation
// NOTE: we don't allow configure on selector data on v3.1, but using bytes instead of bytes4 for selector data to make sure we are future proof
_setSelector(vId, bytes4(selectorData[0:4]), true);
} else if (moduleType == MODULE_TYPE_EXECUTOR) {
bytes calldata executorData;
bytes calldata hookData;
Expand Down Expand Up @@ -470,7 +482,7 @@ contract Kernel is IAccount, IAccountExecute, IERC7579Account, ValidationManager
}

function accountId() external pure override returns (string memory accountImplementationId) {
return "kernel.advanced.v0.3.0-beta";
return "kernel.advanced.v0.3.1";
}

function supportsExecutionMode(ExecMode mode) external pure override returns (bool) {
Expand Down
2 changes: 1 addition & 1 deletion src/core/SelectorManager.sol
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ abstract contract SelectorManager {

function _installSelector(bytes4 selector, address target, IHook hook, bytes calldata selectorData) internal {
if (address(hook) == address(0)) {
hook = IHook(address(1));
hook = IHook(address(0xFFfFfFffFFfffFFfFFfFFFFFffFFFffffFfFFFfF));
}
SelectorConfig storage ss = _selectorConfig(selector);
// we are going to install only through call/delegatecall
Expand Down
110 changes: 83 additions & 27 deletions src/sdk/KernelTestBase.sol
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@ abstract contract KernelTestBase is Test {
);

bytes32 digest = keccak256(
abi.encodePacked("\x19\x01", _buildDomainSeparator("Kernel", "0.3.0-beta", address(kernel)), hash)
abi.encodePacked("\x19\x01", _buildDomainSeparator("Kernel", "0.3.1-beta", address(kernel)), hash)
);

return digest;
Expand Down Expand Up @@ -638,7 +638,13 @@ abstract contract KernelTestBase is Test {
return abi.encode(permissions);
}

function _installAction(bool withHook) internal {
enum HookInfo {
NoHook,
DefaultHook,
WithHook
}

function _installAction(HookInfo withHook) internal {
vm.deal(address(kernel), 1e18);
MockAction mockAction = new MockAction();
PackedUserOperation[] memory ops = new PackedUserOperation[](1);
Expand All @@ -652,8 +658,10 @@ abstract contract KernelTestBase is Test {
address(mockAction),
abi.encodePacked(
MockAction.doSomething.selector,
withHook ? address(mockHook) : address(0),
withHook
withHook == HookInfo.WithHook
? address(mockHook)
: withHook == HookInfo.NoHook ? address(1) : address(0),
withHook == HookInfo.WithHook
? abi.encode(hex"ff", abi.encodePacked(bytes1(0xff), "hookData"))
: abi.encode(hex"ff", hex"")
)
Expand All @@ -664,14 +672,33 @@ abstract contract KernelTestBase is Test {
entrypoint.handleOps(ops, payable(address(0xdeadbeef)));
}

function testActionInstall(bool withHook) external whenInitialized {
function testActionInstall(uint8 hookUint) external whenInitialized {
vm.assume(uint8(hookUint) < 3);
HookInfo withHook = HookInfo(hookUint);
_installAction(withHook);
SelectorManager.SelectorConfig memory config = kernel.selectorConfig(MockAction.doSomething.selector);
assertEq(address(config.hook), withHook ? address(mockHook) : address(1));
vm.expectEmit(address(kernel));
emit MockAction.MockActionEvent(address(kernel));
MockAction(address(kernel)).doSomething();
if (withHook) {
assertEq(
address(config.hook),
withHook == HookInfo.WithHook
? address(mockHook)
: withHook == HookInfo.NoHook ? address(1) : address(0xFFfFfFffFFfffFFfFFfFFFFFffFFFffffFfFFFfF)
);
if (withHook != HookInfo.DefaultHook) {
vm.expectEmit(address(kernel));
emit MockAction.MockActionEvent(address(kernel));
MockAction(address(kernel)).doSomething();
} else {
vm.expectRevert();
MockAction(address(kernel)).doSomething();
PackedUserOperation memory op = _prepareUserOp(
VALIDATION_TYPE_ROOT, false, false, abi.encodeWithSelector(MockAction.doSomething.selector), true, true
);
PackedUserOperation[] memory ops = new PackedUserOperation[](1);
ops[0] = op;
entrypoint.handleOps(ops, payable(address(0xdeadbeef)));
}

if (withHook == HookInfo.WithHook) {
assertEq(mockHook.data(address(kernel)), abi.encodePacked("hookData"));
assertEq(
mockHook.preHookData(address(kernel)), abi.encodePacked(address(this), MockAction.doSomething.selector)
Expand All @@ -680,7 +707,9 @@ abstract contract KernelTestBase is Test {
}
}

function testActionUninstall(bool withHook) external whenInitialized {
function testActionUninstall(uint8 hookUint) external whenInitialized {
vm.assume(uint8(hookUint) < 3);
HookInfo withHook = HookInfo(hookUint);
_installAction(withHook);
PackedUserOperation[] memory ops = new PackedUserOperation[](1);
ops[0] = _prepareUserOp(
Expand All @@ -703,7 +732,7 @@ abstract contract KernelTestBase is Test {
assertEq(address(config.target), address(0));
}

function _installFallback(bool withHook) internal {
function _installFallback(HookInfo withHook) internal {
vm.deal(address(kernel), 1e18);
PackedUserOperation[] memory ops = new PackedUserOperation[](1);
ops[0] = _prepareUserOp(
Expand All @@ -716,8 +745,10 @@ abstract contract KernelTestBase is Test {
address(mockFallback),
abi.encodePacked(
MockFallback.fallbackFunction.selector,
withHook ? address(mockHook) : address(0),
withHook
withHook == HookInfo.WithHook
? address(mockHook)
: withHook == HookInfo.NoHook ? address(1) : address(0),
withHook == HookInfo.WithHook
? abi.encode(abi.encodePacked(hex"00", "fallbackData"), abi.encodePacked(bytes1(0xff), "hookData"))
: abi.encode(abi.encodePacked(hex"00", "fallbackData"), abi.encodePacked(""))
)
Expand All @@ -728,20 +759,43 @@ abstract contract KernelTestBase is Test {
entrypoint.handleOps(ops, payable(address(0xdeadbeef)));
}

function testFallbackInstall(bool withHook) external whenInitialized {
function testFallbackInstall(uint8 hookUint) external whenInitialized {
vm.assume(uint8(hookUint) < 3);
HookInfo withHook = HookInfo(hookUint);
_installFallback(withHook);
assertEq(mockFallback.data(address(kernel)), abi.encodePacked("fallbackData"));

SelectorManager.SelectorConfig memory config = kernel.selectorConfig(MockFallback.fallbackFunction.selector);
assertEq(address(config.hook), withHook ? address(mockHook) : address(1));
assertEq(
address(config.hook),
withHook == HookInfo.WithHook
? address(mockHook)
: withHook == HookInfo.NoHook ? address(1) : address(0xFFfFfFffFFfffFFfFFfFFFFFffFFFffffFfFFFfF)
);
assertEq(address(config.target), address(mockFallback));

(bool success, bytes memory result) =
address(kernel).call(abi.encodeWithSelector(MockFallback.fallbackFunction.selector, uint256(10)));
assertTrue(success);
(uint256 res) = abi.decode(result, (uint256));
assertEq(res, 100);
if (withHook) {
if (withHook != HookInfo.DefaultHook) {
(bool success, bytes memory result) =
address(kernel).call(abi.encodeWithSelector(MockFallback.fallbackFunction.selector, uint256(10)));
assertTrue(success);
(uint256 res) = abi.decode(result, (uint256));
assertEq(res, 100);
} else {
(bool success, bytes memory result) =
address(kernel).call(abi.encodeWithSelector(MockFallback.fallbackFunction.selector, uint256(10)));
assertFalse(success);
PackedUserOperation memory op = _prepareUserOp(
VALIDATION_TYPE_ROOT,
false,
false,
abi.encodeWithSelector(MockFallback.fallbackFunction.selector, uint256(10)),
true,
true
);
PackedUserOperation[] memory ops = new PackedUserOperation[](1);
ops[0] = op;
entrypoint.handleOps(ops, payable(address(0xdeadbeef)));
}
if (withHook == HookInfo.WithHook) {
assertEq(mockHook.data(address(kernel)), abi.encodePacked("hookData"));
assertEq(
mockHook.preHookData(address(kernel)),
Expand All @@ -751,7 +805,9 @@ abstract contract KernelTestBase is Test {
}
}

function testFallbackUninstall(bool withHook) external whenInitialized {
function testFallbackUninstall(uint8 hookUint) external whenInitialized {
vm.assume(uint8(hookUint) < 3);
HookInfo withHook = HookInfo(hookUint);
_installFallback(withHook);
PackedUserOperation[] memory ops = new PackedUserOperation[](1);
ops[0] = _prepareUserOp(
Expand Down Expand Up @@ -840,7 +896,7 @@ abstract contract KernelTestBase is Test {
function testSignatureRoot(bytes32 hash) external whenInitialized {
bytes32 wrappedHash = keccak256(abi.encode(keccak256("Kernel(bytes32 hash)"), hash));
bytes32 digest = keccak256(
abi.encodePacked("\x19\x01", _buildDomainSeparator("Kernel", "0.3.0-beta", address(kernel)), wrappedHash)
abi.encodePacked("\x19\x01", _buildDomainSeparator("Kernel", "0.3.1-beta", address(kernel)), wrappedHash)
);
bytes memory sig = _rootSignDigest(digest, true);
sig = abi.encodePacked(hex"00", sig);
Expand Down Expand Up @@ -868,7 +924,7 @@ abstract contract KernelTestBase is Test {

bytes32 wrappedHash = keccak256(abi.encode(keccak256("Kernel(bytes32 hash)"), hash));
bytes32 digest = keccak256(
abi.encodePacked("\x19\x01", _buildDomainSeparator("Kernel", "0.3.0-beta", address(kernel)), wrappedHash)
abi.encodePacked("\x19\x01", _buildDomainSeparator("Kernel", "0.3.1-beta", address(kernel)), wrappedHash)
);
bytes memory sig = _validatorSignDigest(digest, true);
sig = abi.encodePacked(hex"01", address(enabledValidator), sig);
Expand All @@ -895,7 +951,7 @@ abstract contract KernelTestBase is Test {
);
bytes32 wrappedHash = keccak256(abi.encode(keccak256("Kernel(bytes32 hash)"), hash));
bytes32 digest = keccak256(
abi.encodePacked("\x19\x01", _buildDomainSeparator("Kernel", "0.3.0-beta", address(kernel)), wrappedHash)
abi.encodePacked("\x19\x01", _buildDomainSeparator("Kernel", "0.3.1-beta", address(kernel)), wrappedHash)
);
bytes memory sig = _permissionSignDigest(digest, true);
sig = abi.encodePacked(hex"02", PermissionId.unwrap(enabledPermission), hex"ff", sig);
Expand Down