Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove Unnecessary Non-Terminal Key checks #248

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 24 additions & 8 deletions contracts/OperatorRewardsCollector.sol
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import { ISDUtilityPool, UserData, OperatorLiquidation } from "./interfaces/ISDU
import { ISDCollateral } from "./interfaces/SDCollateral/ISDCollateral.sol";
import { IWETH } from "./interfaces/IWETH.sol";
import { IStaderOracle } from "../contracts/interfaces/IStaderOracle.sol";
import { IPoolUtils } from "../contracts/interfaces/IPoolUtils.sol";

contract OperatorRewardsCollector is IOperatorRewardsCollector, AccessControlUpgradeable {
IStaderConfig public staderConfig;
Expand Down Expand Up @@ -52,10 +53,15 @@ contract OperatorRewardsCollector is IOperatorRewardsCollector, AccessControlUpg
* @dev This function first checks for any unpaid liquidations for the operator and repays them if necessary. Then, it transfers any remaining balance to the operator's reward address.
*/
function claim() external {
claimLiquidation(msg.sender);
uint256 amount = balances[msg.sender] > withdrawableInEth(msg.sender)
? withdrawableInEth(msg.sender)
: balances[msg.sender];
uint256 amount;
if (_isPermissionlessCaller(msg.sender)) {
claimLiquidation(msg.sender);
amount = balances[msg.sender] > withdrawableInEth(msg.sender)
? withdrawableInEth(msg.sender)
: balances[msg.sender];
} else {
amount = balances[msg.sender];
}
_claim(msg.sender, amount);
}

Expand All @@ -66,7 +72,12 @@ contract OperatorRewardsCollector is IOperatorRewardsCollector, AccessControlUpg
* @param _amount amount of ETH to claim
*/
function claimWithAmount(uint256 _amount) external {
claimLiquidation(msg.sender);
if (_isPermissionlessCaller(msg.sender)) {
claimLiquidation(msg.sender);
uint256 maxWithdrawableInEth = withdrawableInEth(msg.sender);
if (_amount > maxWithdrawableInEth) revert InsufficientBalance();
}
if (_amount > balances[msg.sender]) revert InsufficientBalance();
_claim(msg.sender, _amount);
}

Expand Down Expand Up @@ -113,6 +124,14 @@ contract OperatorRewardsCollector is IOperatorRewardsCollector, AccessControlUpg
return balances[operator];
}

function _isPermissionlessCaller(address caller) internal returns (bool) {
IPoolUtils poolUtils = IPoolUtils(staderConfig.getPoolUtils());
uint8 poolId = poolUtils.getOperatorPoolId(caller);
address permissionlessNodeRegistry = staderConfig.getPermissionlessNodeRegistry();

return INodeRegistry(permissionlessNodeRegistry).POOL_ID() == poolId;
}

/**
* @notice Completes any pending liquidation for an operator if exists.
* @dev Internal function to handle liquidation completion.
Expand Down Expand Up @@ -181,9 +200,6 @@ contract OperatorRewardsCollector is IOperatorRewardsCollector, AccessControlUpg
* @param amount The amount to be claimed.
*/
function _claim(address operator, uint256 amount) internal {
uint256 maxWithdrawableInEth = withdrawableInEth(operator);
if (amount > maxWithdrawableInEth || amount > balances[operator]) revert InsufficientBalance();

balances[operator] -= amount;

// If there's an amount to send, transfer it to the operator's rewards address
Expand Down
27 changes: 0 additions & 27 deletions contracts/PermissionedNodeRegistry.sol
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,10 @@
import { Operator, Validator, INodeRegistry } from "./interfaces/INodeRegistry.sol";
import { IPermissionedPool } from "./interfaces/IPermissionedPool.sol";
import { IValidatorWithdrawalVault } from "./interfaces/IValidatorWithdrawalVault.sol";
import { ISDCollateral } from "./interfaces/SDCollateral/ISDCollateral.sol";

Check warning on line 18 in contracts/PermissionedNodeRegistry.sol

View workflow job for this annotation

GitHub Actions / Run linters

imported name ISDCollateral is not used
import { IPermissionedNodeRegistry } from "./interfaces/IPermissionedNodeRegistry.sol";

contract PermissionedNodeRegistry is

Check warning on line 21 in contracts/PermissionedNodeRegistry.sol

View workflow job for this annotation

GitHub Actions / Run linters

Contract has 19 states declarations but allowed no more than 15
INodeRegistry,
IPermissionedNodeRegistry,
AccessControlUpgradeable,
Expand Down Expand Up @@ -432,17 +432,6 @@
emit UpdatedOperatorName(msg.sender, _operatorName);
}

/**
* @notice update the maximum non terminal key limit per operator
* @dev only `MANAGER` role can call
* @param _maxNonTerminalKeyPerOperator updated maximum non terminal key per operator limit
*/
function updateMaxNonTerminalKeyPerOperator(uint64 _maxNonTerminalKeyPerOperator) external override {
UtilLib.onlyManagerRole(msg.sender, staderConfig);
maxNonTerminalKeyPerOperator = _maxNonTerminalKeyPerOperator;
emit UpdatedMaxNonTerminalKeyPerOperator(maxNonTerminalKeyPerOperator);
}

/**
* @notice update number of validator keys that can be added in a single tx by the operator
* @dev only `OPERATOR` role can call
Expand Down Expand Up @@ -624,7 +613,7 @@
}
}
// If the result array isn't full, resize it to remove the unused elements
assembly {

Check warning on line 616 in contracts/PermissionedNodeRegistry.sol

View workflow job for this annotation

GitHub Actions / Run linters

Avoid to use inline assembly. It is acceptable only in rare cases
mstore(validators, validatorCount)
}

Expand Down Expand Up @@ -720,22 +709,6 @@
revert InvalidKeyCount();
}
totalKeys = getOperatorTotalKeys(_operatorId);
uint256 totalNonTerminalKeys = getOperatorTotalNonTerminalKeys(msg.sender, 0, totalKeys);
if ((totalNonTerminalKeys + keyCount) > maxNonTerminalKeyPerOperator) {
revert MaxKeyLimitReached();
}

//checks if operator has enough SD collateral for adding `keyCount` keys
//SD threshold for permissioned NOs is 0 for phase1
if (
!ISDCollateral(staderConfig.getSDCollateral()).hasEnoughSDCollateral(
msg.sender,
POOL_ID,
totalNonTerminalKeys + keyCount
)
) {
revert NotEnoughSDCollateral();
}
}

// operator in active state
Expand Down
15 changes: 15 additions & 0 deletions contracts/PermissionlessNodeRegistry.sol
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import { IPermissionlessNodeRegistry } from "./interfaces/IPermissionlessNodeRegistry.sol";
import { IOperatorRewardsCollector } from "./interfaces/IOperatorRewardsCollector.sol";

contract PermissionlessNodeRegistry is

Check warning on line 24 in contracts/PermissionlessNodeRegistry.sol

View workflow job for this annotation

GitHub Actions / Run linters

Contract has 19 states declarations but allowed no more than 15
INodeRegistry,
IPermissionlessNodeRegistry,
AccessControlUpgradeable,
Expand Down Expand Up @@ -59,6 +59,7 @@
//mapping of operator address with nodeELReward vault address
mapping(uint256 => address) public override nodeELRewardVaultByOperatorId;
mapping(uint256 => address) public proposedRewardAddressByOperatorId;
uint256 public maxKeysPerOperator;

/// @custom:oz-upgrades-unsafe-allow constructor
constructor() {
Expand Down Expand Up @@ -150,6 +151,10 @@
) public payable override nonReentrant whenNotPaused {
uint256 operatorId = onlyActiveOperator(msg.sender);
uint256 keyCount = _pubkey.length;
uint256 totalKeys = getOperatorTotalKeys(operatorId);
if (totalKeys + keyCount > maxKeysPerOperator) {
revert MaxKeyLimitExceed();
}
if (keyCount != _preDepositSignature.length || keyCount != _depositSignature.length) {
revert MisMatchingInputKeysSize();
}
Expand Down Expand Up @@ -444,6 +449,16 @@
emit TransferredCollateralToPool(_amount);
}

/**
* @notice update the max validator per operator value
* @dev only `MANAGER` role can update
*/
function updateMaxKeysPerOperator(uint256 _maxKeysPerOperator) external {
UtilLib.onlyManagerRole(msg.sender, staderConfig);
maxKeysPerOperator = _maxKeysPerOperator;
emit UpdateMaxKeysPerOperator(_maxKeysPerOperator);
}

/**
* @param _nodeOperator @notice operator total non terminal keys within a specified validator list
* @param _startIndex start index in validator queue to start with
Expand Down Expand Up @@ -555,7 +570,7 @@
}
}
// If the result array isn't full, resize it to remove the unused elements
assembly {

Check warning on line 573 in contracts/PermissionlessNodeRegistry.sol

View workflow job for this annotation

GitHub Actions / Run linters

Avoid to use inline assembly. It is acceptable only in rare cases
mstore(validators, validatorCount)
}

Expand Down
2 changes: 0 additions & 2 deletions contracts/interfaces/IPermissionedNodeRegistry.sol
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,6 @@ interface IPermissionedNodeRegistry {

function markValidatorStatusAsPreDeposit(bytes calldata _pubkey) external;

function updateMaxNonTerminalKeyPerOperator(uint64 _maxNonTerminalKeyPerOperator) external;

function updateInputKeyCountLimit(uint16 _inputKeyCountLimit) external;

function proposeRewardAddress(address _operatorAddress, address _newRewardAddress) external;
Expand Down
2 changes: 2 additions & 0 deletions contracts/interfaces/IPermissionlessNodeRegistry.sol
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ interface IPermissionlessNodeRegistry {
error InSufficientBalance();
error CooldownNotComplete();
error NoChangeInState();
error MaxKeyLimitExceed();

// Events
event OnboardedOperator(
Expand All @@ -21,6 +22,7 @@ interface IPermissionlessNodeRegistry {
event UpdatedSocializingPoolState(uint256 operatorId, bool optedForSocializingPool, uint256 block);
event TransferredCollateralToPool(uint256 amount);
event ValidatorAddedViaReferral(uint256 amount, string referralId);
event UpdateMaxKeysPerOperator(uint256 maxKeysPerOperator);

//Getters

Expand Down
56 changes: 56 additions & 0 deletions test/fork/reward-claim-optimisation.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
import { ethers, network } from "hardhat";
import "dotenv/config";
import { impersonateAccount, setBalance } from "@nomicfoundation/hardhat-network-helpers";

const PROXY_OWNER = "0x1112D5C55670Cb5144BF36114C20a122908068B9"
const PROXY_ADMIN = "0x67B12264Ca3e0037Fc7E22F2457b42643a04C86e";
const OPERATOR_REWARDS_COLLECTOR_ADDRESS = "0x84ffDC9De310144D889540A49052F6d1AdB2C335";
const OPERATOR = "0xb851788Fa34B0d9215F54531061D4e2e06A74AEE"

async function setForkBlock(blockNumber: number) {
await network.provider.request({
method: "hardhat_reset",
params: [
{
forking: {
jsonRpcUrl: process.env.PROVIDER_URL_MAINNET,
blockNumber: blockNumber,
},
},
],
});
}

async function configureNewContract (contractName: String, contractAddress: String) {
await setBalance(PROXY_OWNER, ethers.parseEther("1"))
await impersonateAccount(PROXY_OWNER)

const impersonatedProxyOwner = await ethers.getSigner(PROXY_OWNER);

const contractFactory = await ethers.getContractFactory(contractName);
const contractImpl = await contractFactory.deploy();
console.log(`${contractName} Implementation deployed to:`, await contractImpl.getAddress());

const proxyAdminContract = await ethers.getContractAt("ProxyAdmin", PROXY_ADMIN);
await proxyAdminContract.connect(impersonatedProxyOwner).upgrade(contractAddress, await contractImpl.getAddress());

const contract = await ethers.getContractAt(contractName, contractAddress)

return contract;
}

describe("Gas Coverage", function () {
it("should consume less gas after upgrade", async () => {
await setForkBlock(21270988);
await setBalance(OPERATOR, ethers.parseEther("100"))
await impersonateAccount(OPERATOR);
const impersonatedOperator = await ethers.getSigner(OPERATOR);

const newOperatorRewardsCollector = await configureNewContract("OperatorRewardsCollector", OPERATOR_REWARDS_COLLECTOR_ADDRESS);

// Firing a txn with updated contracts
const claimTxn = await newOperatorRewardsCollector.connect(impersonatedOperator).claim();
const claimTxnReceipt = await claimTxn.wait();
console.log("Claim Txn Gas Estimate:", claimTxnReceipt.gasUsed);
});
});
10 changes: 10 additions & 0 deletions test/foundry_tests/NodeELRewardVault.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import { PoolUtilsMock } from "../mocks/PoolUtilsMock.sol";
import { StakePoolManagerMock } from "../mocks/StakePoolManagerMock.sol";
import { StaderOracleMock } from "../mocks/StaderOracleMock.sol";
import { SDUtilityPoolMock } from "../mocks/SDUtilityPoolMock.sol";
import { PermissionlessNodeRegistryMock } from "../mocks/PermissionlessNodeRegistryMock.sol";

contract NodeELRewardVaultTest is Test {
address private constant OPERATOR_ADDRESSS = address(500);
Expand Down Expand Up @@ -59,6 +60,7 @@ contract NodeELRewardVaultTest is Test {
address operator = OPERATOR_ADDRESSS;
mockStaderOracle(staderOracleMock);
mockSdUtilityPool(sdUtilityPoolMock, operator);
mockPermissionlessNodeRegistry(vm.addr(105));

OperatorRewardsCollector operatorRCImpl = new OperatorRewardsCollector();
TransparentUpgradeableProxy operatorRCProxy = new TransparentUpgradeableProxy(
Expand Down Expand Up @@ -95,6 +97,7 @@ contract NodeELRewardVaultTest is Test {
staderConfig.updateSDCollateral(address(sdCollateral));
staderConfig.updateSDUtilityPool(sdUtilityPoolMock);
staderConfig.updateStaderOracle(staderOracleMock);
staderConfig.updatePermissionlessNodeRegistry(vm.addr(105));
staderConfig.grantRole(staderConfig.MANAGER(), staderManager);
vaultFactory.grantRole(vaultFactory.NODE_REGISTRY_CONTRACT(), address(poolUtils.nodeRegistry()));
vm.stopPrank();
Expand Down Expand Up @@ -265,4 +268,11 @@ contract NodeELRewardVaultTest is Test {
bytes memory mockCode = address(implementation).code;
vm.etch(staderOracleMock, mockCode);
}

function mockPermissionlessNodeRegistry(address _permissionlessNodeRegistry) private {
emit log_named_address("permissionlessNodeRegistry", _permissionlessNodeRegistry);
PermissionlessNodeRegistryMock nodeRegistryMock = new PermissionlessNodeRegistryMock();
bytes memory mockCode = address(nodeRegistryMock).code;
vm.etch(_permissionlessNodeRegistry, mockCode);
}
}
108 changes: 108 additions & 0 deletions test/foundry_tests/OperatorRewardsCollector.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ contract OperatorRewardsCollectorTest is Test {
event SDWithdrawn(address indexed operator, uint256 sdAmount);
event SDRepaid(address operator, uint256 repayAmount);

error InsufficientBalance();

address staderAdmin;
address staderManager;
address staderTreasury;
Expand Down Expand Up @@ -173,6 +175,112 @@ contract OperatorRewardsCollectorTest is Test {
vm.stopPrank();
}

function test_Claim_PermissionedPoolOperator(uint256 amount) public {
vm.assume(amount < 100000 ether);

operatorRewardsCollector.depositFor{ value: amount }(staderManager);
assertEq(operatorRewardsCollector.balances(staderManager), amount);
vm.mockCall(
address(staderOracle),
abi.encodeWithSelector(IStaderOracle.getSDPriceInETH.selector),
abi.encode(1e14)
);

vm.mockCall(
address(poolUtils),
abi.encodeWithSelector(IPoolUtils.getOperatorPoolId.selector, staderManager),
abi.encode(uint8(2)) // Assuming `POOL_ID()` for PermissionedNodeRegistry is `2`
);

vm.mockCall(
address(permissionlessNodeRegistryMock),
abi.encodeWithSelector(INodeRegistry.POOL_ID.selector),
abi.encode(uint8(1))
);

vm.startPrank(staderManager);
operatorRewardsCollector.claim();

assertEq(operatorRewardsCollector.balances(staderManager), 0);
vm.stopPrank();
}

function test_ClaimWithAmount_PermissionlessNodeRegistry() public {
uint256 amount = 10 ether; // Deposit amount

// Deposit funds into the contract
operatorRewardsCollector.depositFor{ value: amount }(staderManager);
assertEq(operatorRewardsCollector.balances(staderManager), amount);

uint256 withdrawableAmount = operatorRewardsCollector.withdrawableInEth(staderManager); // Withdrawable amount

// Mock the operator's pool ID to match the PermissionlessPool
vm.mockCall(
address(poolUtils),
abi.encodeWithSelector(IPoolUtils.getOperatorPoolId.selector, staderManager),
abi.encode(uint8(1)) // PermissionlessPool
);

// Mock Permissionless Node Registry POOL_ID
vm.mockCall(
address(permissionlessNodeRegistryMock),
abi.encodeWithSelector(INodeRegistry.POOL_ID.selector),
abi.encode(uint8(1)) // PermissionlessNodeRegistry
);

vm.startPrank(staderManager);

// Case 1: Claiming more than `withdrawableInEth`
vm.expectRevert(InsufficientBalance.selector);
operatorRewardsCollector.claimWithAmount(withdrawableAmount + 1 ether);

// Case 2: Claiming more than `balances[msg.sender]`
vm.expectRevert(InsufficientBalance.selector);
operatorRewardsCollector.claimWithAmount(amount + 1 ether);

// Case 3: Valid claim within limits
uint256 validClaim = 3 ether;
operatorRewardsCollector.claimWithAmount(validClaim);
assertEq(operatorRewardsCollector.balances(staderManager), amount - validClaim);

vm.stopPrank();
}

function test_ClaimWithAmount_PermissionedNodeRegistry(uint256 amount, uint256 claimAmount) public {
// Assume reasonable values for deposits and claims
vm.assume(amount > 0 && amount < 100000 ether);

operatorRewardsCollector.depositFor{ value: amount }(staderManager);
assertEq(operatorRewardsCollector.balances(staderManager), amount);

// Mock the operator's pool ID to match the PermissionedPool
vm.mockCall(
address(poolUtils),
abi.encodeWithSelector(IPoolUtils.getOperatorPoolId.selector, staderManager),
abi.encode(uint8(2))
);

vm.mockCall(
address(permissionlessNodeRegistryMock),
abi.encodeWithSelector(INodeRegistry.POOL_ID.selector),
abi.encode(uint8(1))
);

vm.startPrank(staderManager);

if (claimAmount > amount) {
// Case: Attempting to claim more than balance, expect revert
vm.expectRevert(InsufficientBalance.selector);
operatorRewardsCollector.claimWithAmount(claimAmount);
} else {
// Case: Claiming valid amount
operatorRewardsCollector.claimWithAmount(claimAmount);
assertEq(operatorRewardsCollector.balances(staderManager), amount - claimAmount);
}

vm.stopPrank();
}

function test_claimLiquidationZeroAmount(uint256 amount) public {
vm.assume(amount < 100000 ether);

Expand Down
Loading
Loading