Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: [v0.8-develop] invert validation mapping #85

Merged
merged 1 commit into from
Jul 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 11 additions & 3 deletions src/account/AccountLoupe.sol
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import {EnumerableSet} from "@openzeppelin/contracts/utils/structs/EnumerableSet
import {IAccountLoupe, ExecutionHook} from "../interfaces/IAccountLoupe.sol";
import {FunctionReference, IPluginManager} from "../interfaces/IPluginManager.sol";
import {IStandardExecutor} from "../interfaces/IStandardExecutor.sol";
import {getAccountStorage, toFunctionReferenceArray, toExecutionHook} from "./AccountStorage.sol";
import {getAccountStorage, toExecutionHook, toSelector} from "./AccountStorage.sol";

abstract contract AccountLoupe is IAccountLoupe {
using EnumerableSet for EnumerableSet.Bytes32Set;
Expand All @@ -28,8 +28,16 @@ abstract contract AccountLoupe is IAccountLoupe {
}

/// @inheritdoc IAccountLoupe
function getValidations(bytes4 selector) external view override returns (FunctionReference[] memory) {
return toFunctionReferenceArray(getAccountStorage().selectorData[selector].validations);
function getSelectors(FunctionReference validationFunction) external view returns (bytes4[] memory) {
uint256 length = getAccountStorage().validationData[validationFunction].selectors.length();

bytes4[] memory selectors = new bytes4[](length);

for (uint256 i = 0; i < length; ++i) {
selectors[i] = toSelector(getAccountStorage().validationData[validationFunction].selectors.at(i));
}

return selectors;
}

/// @inheritdoc IAccountLoupe
Expand Down
12 changes: 10 additions & 2 deletions src/account/AccountStorage.sol
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,6 @@ struct SelectorData {
bool allowDefaultValidation;
// The execution hooks for this function selector.
EnumerableSet.Bytes32Set executionHooks;
// Which validation functions are associated with this function selector.
EnumerableSet.Bytes32Set validations;
}

struct ValidationData {
Expand All @@ -44,6 +42,8 @@ struct ValidationData {
FunctionReference[] preValidationHooks;
// Permission hooks for this validation function.
EnumerableSet.Bytes32Set permissionHooks;
// The set of selectors that may be validated by this validation function.
EnumerableSet.Bytes32Set selectors;
}

struct AccountStorage {
Expand Down Expand Up @@ -96,6 +96,14 @@ function toExecutionHook(bytes32 setValue)
isPostHook = (uint256(setValue) >> 72) & 0xFF == 1;
}

function toSetValue(bytes4 selector) pure returns (bytes32) {
return bytes32(selector);
}

function toSelector(bytes32 setValue) pure returns (bytes4) {
return bytes4(setValue);
}

/// @dev Helper function to get all elements of a set into memory.
function toFunctionReferenceArray(EnumerableSet.Bytes32Set storage set)
view
Expand Down
15 changes: 5 additions & 10 deletions src/account/PluginManager2.sol
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ abstract contract PluginManager2 {

for (uint256 i = 0; i < selectors.length; ++i) {
bytes4 selector = selectors[i];
if (!_storage.selectorData[selector].validations.add(toSetValue(validationFunction))) {
if (!_storage.validationData[validationFunction].selectors.add(toSetValue(selector))) {
revert ValidationAlreadySet(selector, validationFunction);
}
}
Expand All @@ -102,7 +102,6 @@ abstract contract PluginManager2 {

function _uninstallValidation(
FunctionReference validationFunction,
bytes4[] calldata selectors,
bytes calldata uninstallData,
bytes calldata preValidationHookUninstallData,
bytes calldata permissionHookUninstallData
Expand Down Expand Up @@ -144,14 +143,10 @@ abstract contract PluginManager2 {
}
delete _storage.validationData[validationFunction].preValidationHooks;

// Because this function also calls `onUninstall`, and removes the default flag from validation, we must
// assume these selectors passed in to be exhaustive.
// TODO: consider enforcing this from user-supplied install config.
for (uint256 i = 0; i < selectors.length; ++i) {
bytes4 selector = selectors[i];
if (!_storage.selectorData[selector].validations.remove(toSetValue(validationFunction))) {
revert ValidationNotSet(selector, validationFunction);
}
// Clear selectors
while (_storage.validationData[validationFunction].selectors.length() > 0) {
bytes32 selector = _storage.validationData[validationFunction].selectors.at(0);
_storage.validationData[validationFunction].selectors.remove(selector);
}

if (uninstallData.length > 0) {
Expand Down
8 changes: 2 additions & 6 deletions src/account/PluginManagerInternals.sol
Original file line number Diff line number Diff line change
Expand Up @@ -103,12 +103,10 @@ abstract contract PluginManagerInternals is IPluginManager {
internal
notNullFunction(validationFunction)
{
SelectorData storage _selectorData = getAccountStorage().selectorData[selector];

// Fail on duplicate validation functions. Otherwise, dependency validation functions could shadow
// non-depdency validation functions. Then, if a either plugin is uninstalled, it would cause a partial
// uninstall of the other.
if (!_selectorData.validations.add(toSetValue(validationFunction))) {
if (!getAccountStorage().validationData[validationFunction].selectors.add(toSetValue(selector))) {
revert ValidationFunctionAlreadySet(selector, validationFunction);
}
}
Expand All @@ -117,11 +115,9 @@ abstract contract PluginManagerInternals is IPluginManager {
internal
notNullFunction(validationFunction)
{
SelectorData storage _selectorData = getAccountStorage().selectorData[selector];

// May ignore return value, as the manifest hash is validated to ensure that the validation function
// exists.
_selectorData.validations.remove(toSetValue(validationFunction));
getAccountStorage().validationData[validationFunction].selectors.remove(toSetValue(selector));
}

function _addExecHooks(
Expand Down
9 changes: 2 additions & 7 deletions src/account/UpgradeableModularAccount.sol
Original file line number Diff line number Diff line change
Expand Up @@ -309,17 +309,12 @@ contract UpgradeableModularAccount is
/// @notice May be validated by a default validation.
function uninstallValidation(
FunctionReference validationFunction,
bytes4[] calldata selectors,
bytes calldata uninstallData,
bytes calldata preValidationHookUninstallData,
bytes calldata permissionHookUninstallData
) external wrapNativeFunction {
_uninstallValidation(
validationFunction,
selectors,
uninstallData,
preValidationHookUninstallData,
permissionHookUninstallData
validationFunction, uninstallData, preValidationHookUninstallData, permissionHookUninstallData
);
}

Expand Down Expand Up @@ -685,7 +680,7 @@ contract UpgradeableModularAccount is
}
} else {
// Not default validation, but per-selector
if (!getAccountStorage().selectorData[selector].validations.contains(toSetValue(validationFunction))) {
if (!getAccountStorage().validationData[validationFunction].selectors.contains(toSetValue(selector))) {
revert UserOpValidationFunctionMissing(selector);
}
}
Expand Down
3 changes: 1 addition & 2 deletions src/helpers/KnownSelectors.sol
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,7 @@ library KnownSelectors {
|| selector == IStandardExecutor.executeWithAuthorization.selector
// check against IAccountLoupe methods
|| selector == IAccountLoupe.getExecutionFunctionHandler.selector
|| selector == IAccountLoupe.getValidations.selector
|| selector == IAccountLoupe.getExecutionHooks.selector
|| selector == IAccountLoupe.getSelectors.selector || selector == IAccountLoupe.getExecutionHooks.selector
|| selector == IAccountLoupe.getPreValidationHooks.selector
|| selector == IAccountLoupe.getInstalledPlugins.selector;
}
Expand Down
8 changes: 4 additions & 4 deletions src/interfaces/IAccountLoupe.sol
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,10 @@ interface IAccountLoupe {
/// @return plugin The plugin address for this selector.
function getExecutionFunctionHandler(bytes4 selector) external view returns (address plugin);

/// @notice Get the validation functions for a selector.
/// @param selector The selector to get the validation functions for.
/// @return The validation functions for this selector.
function getValidations(bytes4 selector) external view returns (FunctionReference[] memory);
/// @notice Get the selectors for a validation function.
/// @param validationFunction The validation function to get the selectors for.
/// @return The allowed selectors for this validation function.
function getSelectors(FunctionReference validationFunction) external view returns (bytes4[] memory);

/// @notice Get the pre and post execution hooks for a selector.
/// @param selector The selector to get the hooks for.
Expand Down
2 changes: 0 additions & 2 deletions src/interfaces/IPluginManager.sol
Original file line number Diff line number Diff line change
Expand Up @@ -46,15 +46,13 @@ interface IPluginManager {
/// @notice Uninstall a validation function from a set of execution selectors.
/// TODO: remove or update.
/// @param validationFunction The validation function to uninstall.
/// @param selectors The selectors to uninstall the validation function for.
/// @param uninstallData Optional data to be decoded and used by the plugin to clear plugin data for the
/// account.
/// @param preValidationHookUninstallData Optional data to be decoded and used by the plugin to clear account
/// data
/// @param permissionHookUninstallData Optional data to be decoded and used by the plugin to clear account data
function uninstallValidation(
FunctionReference validationFunction,
bytes4[] calldata selectors,
bytes calldata uninstallData,
bytes calldata preValidationHookUninstallData,
bytes calldata permissionHookUninstallData
Expand Down
20 changes: 6 additions & 14 deletions test/account/AccountLoupe.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -88,23 +88,15 @@ contract AccountLoupeTest is AccountTestBase {
}
}

function test_pluginLoupe_getValidationFunctions() public {
FunctionReference[] memory validations = account1.getValidations(comprehensivePlugin.foo.selector);

assertEq(validations.length, 1);
assertEq(
FunctionReference.unwrap(validations[0]),
FunctionReference.unwrap(
FunctionReferenceLib.pack(
address(comprehensivePlugin), uint8(ComprehensivePlugin.FunctionId.VALIDATION)
)
)
function test_pluginLoupe_getSelectors() public {
FunctionReference comprehensivePluginValidation = FunctionReferenceLib.pack(
address(comprehensivePlugin), uint8(ComprehensivePlugin.FunctionId.VALIDATION)
);

validations = account1.getValidations(account1.execute.selector);
bytes4[] memory selectors = account1.getSelectors(comprehensivePluginValidation);

assertEq(validations.length, 1);
assertEq(FunctionReference.unwrap(validations[0]), FunctionReference.unwrap(_ownerValidation));
assertEq(selectors.length, 1);
assertEq(selectors[0], comprehensivePlugin.foo.selector);
}

function test_pluginLoupe_getExecutionHooks() public {
Expand Down
11 changes: 7 additions & 4 deletions test/account/MultiValidation.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,13 @@ contract MultiValidationTest is AccountTestBase {
);
validations[1] =
FunctionReferenceLib.pack(address(validator2), uint8(ISingleOwnerPlugin.FunctionId.VALIDATION_OWNER));
FunctionReference[] memory validations2 = account1.getValidations(IStandardExecutor.execute.selector);
assertEq(validations2.length, 2);
assertEq(FunctionReference.unwrap(validations2[0]), FunctionReference.unwrap(validations[0]));
assertEq(FunctionReference.unwrap(validations2[1]), FunctionReference.unwrap(validations[1]));

bytes4[] memory selectors0 = account1.getSelectors(validations[0]);
bytes4[] memory selectors1 = account1.getSelectors(validations[1]);
assertEq(selectors0.length, selectors1.length);
for (uint256 i = 0; i < selectors0.length; i++) {
assertEq(selectors0[i], selectors1[i]);
}
}

function test_runtimeValidation_specify() public {
Expand Down
Loading