Skip to content

Commit

Permalink
refactor validation mapping
Browse files Browse the repository at this point in the history
  • Loading branch information
adamegyed committed Jun 26, 2024
1 parent 1e415b1 commit 364e06b
Show file tree
Hide file tree
Showing 10 changed files with 48 additions and 50 deletions.
14 changes: 11 additions & 3 deletions src/account/AccountLoupe.sol
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import {EnumerableSet} from "@openzeppelin/contracts/utils/structs/EnumerableSet
import {IAccountLoupe, ExecutionHook} from "../interfaces/IAccountLoupe.sol";
import {FunctionReference, IPluginManager} from "../interfaces/IPluginManager.sol";
import {IStandardExecutor} from "../interfaces/IStandardExecutor.sol";
import {getAccountStorage, SelectorData, toFunctionReferenceArray, toExecutionHook} from "./AccountStorage.sol";
import {getAccountStorage, SelectorData, toExecutionHook, toSelector} from "./AccountStorage.sol";

abstract contract AccountLoupe is IAccountLoupe {
using EnumerableSet for EnumerableSet.Bytes32Set;
Expand All @@ -28,8 +28,16 @@ abstract contract AccountLoupe is IAccountLoupe {
}

/// @inheritdoc IAccountLoupe
function getValidations(bytes4 selector) external view override returns (FunctionReference[] memory) {
return toFunctionReferenceArray(getAccountStorage().selectorData[selector].validations);
function getSelectors(FunctionReference validationFunction) external view returns (bytes4[] memory) {
uint256 length = getAccountStorage().validationData[validationFunction].selectors.length();

bytes4[] memory selectors = new bytes4[](length);

for (uint256 i = 0; i < length; ++i) {
selectors[i] = toSelector(getAccountStorage().validationData[validationFunction].selectors.at(i));
}

return selectors;
}

/// @inheritdoc IAccountLoupe
Expand Down
12 changes: 10 additions & 2 deletions src/account/AccountStorage.sol
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,6 @@ struct SelectorData {
bool allowDefaultValidation;
// The execution hooks for this function selector.
EnumerableSet.Bytes32Set executionHooks;
// Which validation functions are associated with this function selector.
EnumerableSet.Bytes32Set validations;
}

struct ValidationData {
Expand All @@ -40,6 +38,8 @@ struct ValidationData {
bool isSignatureValidation;
// The pre validation hooks for this function selector.
FunctionReference[] preValidationHooks;
// The set of selectors that may be validated by this validation function.
EnumerableSet.Bytes32Set selectors;
}

struct AccountStorage {
Expand Down Expand Up @@ -93,6 +93,14 @@ function toExecutionHook(bytes32 setValue)
isPostHook = (uint256(setValue) >> 72) & 0xFF == 1;
}

function toSetValue(bytes4 selector) pure returns (bytes32) {
return bytes32(selector);
}

function toSelector(bytes32 setValue) pure returns (bytes4) {
return bytes4(setValue);
}

/// @dev Helper function to get all elements of a set into memory.
function toFunctionReferenceArray(EnumerableSet.Bytes32Set storage set)
view
Expand Down
15 changes: 5 additions & 10 deletions src/account/PluginManager2.sol
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ abstract contract PluginManager2 {

for (uint256 i = 0; i < selectors.length; ++i) {
bytes4 selector = selectors[i];
if (!_storage.selectorData[selector].validations.add(toSetValue(validationFunction))) {
if (!_storage.validationData[validationFunction].selectors.add(toSetValue(selector))) {
revert ValidationAlreadySet(selector, validationFunction);
}
}
Expand All @@ -79,7 +79,6 @@ abstract contract PluginManager2 {

function _uninstallValidation(
FunctionReference validationFunction,
bytes4[] calldata selectors,
bytes calldata uninstallData,
bytes calldata preValidationHookUninstallData
) internal {
Expand All @@ -102,14 +101,10 @@ abstract contract PluginManager2 {
}
delete _storage.validationData[validationFunction].preValidationHooks;

// Because this function also calls `onUninstall`, and removes the default flag from validation, we must
// assume these selectors passed in to be exhaustive.
// TODO: consider enforcing this from user-supplied install config.
for (uint256 i = 0; i < selectors.length; ++i) {
bytes4 selector = selectors[i];
if (!_storage.selectorData[selector].validations.remove(toSetValue(validationFunction))) {
revert ValidationNotSet(selector, validationFunction);
}
// Clear selectors
while (_storage.validationData[validationFunction].selectors.length() > 0) {
bytes32 selector = _storage.validationData[validationFunction].selectors.at(0);
_storage.validationData[validationFunction].selectors.remove(selector);
}

if (uninstallData.length > 0) {
Expand Down
8 changes: 2 additions & 6 deletions src/account/PluginManagerInternals.sol
Original file line number Diff line number Diff line change
Expand Up @@ -103,12 +103,10 @@ abstract contract PluginManagerInternals is IPluginManager {
internal
notNullFunction(validationFunction)
{
SelectorData storage _selectorData = getAccountStorage().selectorData[selector];

// Fail on duplicate validation functions. Otherwise, dependency validation functions could shadow
// non-depdency validation functions. Then, if a either plugin is uninstalled, it would cause a partial
// uninstall of the other.
if (!_selectorData.validations.add(toSetValue(validationFunction))) {
if (!getAccountStorage().validationData[validationFunction].selectors.add(toSetValue(selector))) {
revert ValidationFunctionAlreadySet(selector, validationFunction);
}
}
Expand All @@ -117,11 +115,9 @@ abstract contract PluginManagerInternals is IPluginManager {
internal
notNullFunction(validationFunction)
{
SelectorData storage _selectorData = getAccountStorage().selectorData[selector];

// May ignore return value, as the manifest hash is validated to ensure that the validation function
// exists.
_selectorData.validations.remove(toSetValue(validationFunction));
getAccountStorage().validationData[validationFunction].selectors.remove(toSetValue(selector));
}

function _addExecHooks(
Expand Down
5 changes: 2 additions & 3 deletions src/account/UpgradeableModularAccount.sol
Original file line number Diff line number Diff line change
Expand Up @@ -271,11 +271,10 @@ contract UpgradeableModularAccount is
/// @notice May be validated by a default validation.
function uninstallValidation(
FunctionReference validationFunction,
bytes4[] calldata selectors,
bytes calldata uninstallData,
bytes calldata preValidationHookUninstallData
) external wrapNativeFunction {
_uninstallValidation(validationFunction, selectors, uninstallData, preValidationHookUninstallData);
_uninstallValidation(validationFunction, uninstallData, preValidationHookUninstallData);
}

/// @notice ERC165 introspection
Expand Down Expand Up @@ -623,7 +622,7 @@ contract UpgradeableModularAccount is
}
} else {
// Not default validation, but per-selector
if (!getAccountStorage().selectorData[selector].validations.contains(toSetValue(validationFunction))) {
if (!getAccountStorage().validationData[validationFunction].selectors.contains(toSetValue(selector))) {
revert UserOpValidationFunctionMissing(selector);
}
}
Expand Down
3 changes: 1 addition & 2 deletions src/helpers/KnownSelectors.sol
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,7 @@ library KnownSelectors {
|| selector == IStandardExecutor.executeWithAuthorization.selector
// check against IAccountLoupe methods
|| selector == IAccountLoupe.getExecutionFunctionHandler.selector
|| selector == IAccountLoupe.getValidations.selector
|| selector == IAccountLoupe.getExecutionHooks.selector
|| selector == IAccountLoupe.getSelectors.selector || selector == IAccountLoupe.getExecutionHooks.selector
|| selector == IAccountLoupe.getPreValidationHooks.selector
|| selector == IAccountLoupe.getInstalledPlugins.selector;
}
Expand Down
8 changes: 4 additions & 4 deletions src/interfaces/IAccountLoupe.sol
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,10 @@ interface IAccountLoupe {
/// @return plugin The plugin address for this selector.
function getExecutionFunctionHandler(bytes4 selector) external view returns (address plugin);

/// @notice Get the validation functions for a selector.
/// @param selector The selector to get the validation functions for.
/// @return The validation functions for this selector.
function getValidations(bytes4 selector) external view returns (FunctionReference[] memory);
/// @notice Get the selectors for a validation function.
/// @param validationFunction The validation function to get the selectors for.
/// @return The allowed selectors for this validation function.
function getSelectors(FunctionReference validationFunction) external view returns (bytes4[] memory);

/// @notice Get the pre and post execution hooks for a selector.
/// @param selector The selector to get the hooks for.
Expand Down
2 changes: 0 additions & 2 deletions src/interfaces/IPluginManager.sol
Original file line number Diff line number Diff line change
Expand Up @@ -43,12 +43,10 @@ interface IPluginManager {
/// @notice Uninstall a validation function from a set of execution selectors.
/// TODO: remove or update.
/// @param validationFunction The validation function to uninstall.
/// @param selectors The selectors to uninstall the validation function for.
/// @param uninstallData Optional data to be decoded and used by the plugin to clear plugin data for the
/// account.
function uninstallValidation(
FunctionReference validationFunction,
bytes4[] calldata selectors,
bytes calldata uninstallData,
bytes calldata preValidationHookUninstallData
) external;
Expand Down
20 changes: 6 additions & 14 deletions test/account/AccountLoupe.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -83,23 +83,15 @@ contract AccountLoupeTest is AccountTestBase {
}
}

function test_pluginLoupe_getValidationFunctions() public {
FunctionReference[] memory validations = account1.getValidations(comprehensivePlugin.foo.selector);

assertEq(validations.length, 1);
assertEq(
FunctionReference.unwrap(validations[0]),
FunctionReference.unwrap(
FunctionReferenceLib.pack(
address(comprehensivePlugin), uint8(ComprehensivePlugin.FunctionId.VALIDATION)
)
)
function test_pluginLoupe_getSelectors() public {
FunctionReference comprehensivePluginValidation = FunctionReferenceLib.pack(
address(comprehensivePlugin), uint8(ComprehensivePlugin.FunctionId.VALIDATION)
);

validations = account1.getValidations(account1.execute.selector);
bytes4[] memory selectors = account1.getSelectors(comprehensivePluginValidation);

assertEq(validations.length, 1);
assertEq(FunctionReference.unwrap(validations[0]), FunctionReference.unwrap(_ownerValidation));
assertEq(selectors.length, 1);
assertEq(selectors[0], comprehensivePlugin.foo.selector);
}

function test_pluginLoupe_getExecutionHooks() public {
Expand Down
11 changes: 7 additions & 4 deletions test/account/MultiValidation.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,13 @@ contract MultiValidationTest is AccountTestBase {
);
validations[1] =
FunctionReferenceLib.pack(address(validator2), uint8(ISingleOwnerPlugin.FunctionId.VALIDATION_OWNER));
FunctionReference[] memory validations2 = account1.getValidations(IStandardExecutor.execute.selector);
assertEq(validations2.length, 2);
assertEq(FunctionReference.unwrap(validations2[0]), FunctionReference.unwrap(validations[0]));
assertEq(FunctionReference.unwrap(validations2[1]), FunctionReference.unwrap(validations[1]));

bytes4[] memory selectors0 = account1.getSelectors(validations[0]);
bytes4[] memory selectors1 = account1.getSelectors(validations[1]);
assertEq(selectors0.length, selectors1.length);
for (uint256 i = 0; i < selectors0.length; i++) {
assertEq(selectors0[i], selectors1[i]);
}
}

function test_runtimeValidation_specify() public {
Expand Down

0 comments on commit 364e06b

Please sign in to comment.