diff --git a/contracts/plugins/assets/erc20/RewardableERC20.sol b/contracts/plugins/assets/erc20/RewardableERC20.sol index 58fd23855..6f0a606b6 100644 --- a/contracts/plugins/assets/erc20/RewardableERC20.sol +++ b/contracts/plugins/assets/erc20/RewardableERC20.sol @@ -7,6 +7,9 @@ import "@openzeppelin/contracts/token/ERC20/ERC20.sol"; import "@openzeppelin/contracts/token/ERC20/IERC20.sol"; import "../../../interfaces/IRewardable.sol"; +uint256 constant SHARES_BUFFER_DECIMALS = 9; // to prevent reward rounding issues + +/* /** * @title RewardableERC20 * @notice An abstract class that can be extended to create rewardable wrapper. @@ -35,7 +38,7 @@ abstract contract RewardableERC20 is IRewardable, ERC20, ReentrancyGuard { /// @dev Extending class must ensure ERC20 constructor is called constructor(IERC20 _rewardToken, uint8 _decimals) { rewardToken = _rewardToken; - one = 10**_decimals; // set via pass-in to prevent inheritance issues + one = 10**(_decimals + SHARES_BUFFER_DECIMALS); // set via pass-in to prevent inheritance issues } function claimRewards() external nonReentrant { diff --git a/test/plugins/RewardableERC20.test.ts b/test/plugins/RewardableERC20.test.ts index 55d71e569..0ab22f22c 100644 --- a/test/plugins/RewardableERC20.test.ts +++ b/test/plugins/RewardableERC20.test.ts @@ -18,6 +18,9 @@ import snapshotGasCost from '../utils/snapshotGasCost' import { formatUnits, parseUnits } from 'ethers/lib/utils' import { MAX_UINT256 } from '#/common/constants' +const SHARE_DECIMALS = 9 // decimals buffer for shares and rewards per share +const BN_SHARE_FACTOR = bn(10).pow(SHARE_DECIMALS) + type Fixture = () => Promise interface RewardableERC20Fixture { @@ -120,7 +123,7 @@ for (const wrapperName of wrapperNames) { describe(wrapperName, () => { // Decimals let shareDecimals: number - + let rewardShareDecimals: number // Assets let rewardableVault: RewardableERC20WrapperTest | RewardableERC4626VaultTest let rewardableAsset: ERC20MockRewarding @@ -152,7 +155,8 @@ for (const wrapperName of wrapperNames) { await rewardableAsset.mint(bob.address, initBalance) await rewardableAsset.connect(bob).approve(rewardableVault.address, initBalance) - shareDecimals = await rewardableVault.decimals() + shareDecimals = (await rewardableVault.decimals()) + SHARE_DECIMALS + rewardShareDecimals = rewardDecimals + SHARE_DECIMALS initShares = toShares(initBalance, assetDecimals, shareDecimals) oneShare = bn('1').mul(bn(10).pow(shareDecimals)) }) @@ -185,7 +189,9 @@ for (const wrapperName of wrapperNames) { expect(await rewardableVault.lastRewardsPerShare(alice.address)).to.equal(bn(0)) await rewardToken.mint(rewardableVault.address, parseUnits('10', rewardDecimals)) await rewardableVault.sync() - expect(await rewardableVault.rewardsPerShare()).to.equal(parseUnits('1', rewardDecimals)) + expect(await rewardableVault.rewardsPerShare()).to.equal( + parseUnits('1', rewardShareDecimals) + ) }) it('correctly handles reward tracking if supply is burned', async () => { @@ -196,7 +202,9 @@ for (const wrapperName of wrapperNames) { expect(await rewardableVault.lastRewardsPerShare(alice.address)).to.equal(bn(0)) await rewardToken.mint(rewardableVault.address, parseUnits('10', rewardDecimals)) await rewardableVault.sync() - expect(await rewardableVault.rewardsPerShare()).to.equal(parseUnits('1', rewardDecimals)) + expect(await rewardableVault.rewardsPerShare()).to.equal( + parseUnits('1', rewardShareDecimals) + ) // Setting supply to 0 await withdrawAll(rewardableVault.connect(alice)) @@ -215,7 +223,9 @@ for (const wrapperName of wrapperNames) { // Nothing updates.. as totalSupply as totalSupply is 0 await rewardableVault.sync() - expect(await rewardableVault.rewardsPerShare()).to.equal(parseUnits('1', rewardDecimals)) + expect(await rewardableVault.rewardsPerShare()).to.equal( + parseUnits('1', rewardShareDecimals) + ) await rewardableVault .connect(alice) .deposit(parseUnits('10', assetDecimals), alice.address) @@ -280,7 +290,9 @@ for (const wrapperName of wrapperNames) { }) it('alice shows correct balance', async () => { - expect(initShares.mul(3).div(8)).equal(await rewardableVault.balanceOf(alice.address)) + expect(initShares.mul(3).div(8).div(BN_SHARE_FACTOR)).equal( + await rewardableVault.balanceOf(alice.address) + ) }) it('alice shows correct lastRewardsPerShare', async () => { @@ -288,7 +300,9 @@ for (const wrapperName of wrapperNames) { }) it('bob shows correct balance', async () => { - expect(initShares.div(8)).equal(await rewardableVault.balanceOf(bob.address)) + expect(initShares.div(8).div(BN_SHARE_FACTOR)).equal( + await rewardableVault.balanceOf(bob.address) + ) }) it('bob shows correct lastRewardsPerShare', async () => { @@ -297,7 +311,9 @@ for (const wrapperName of wrapperNames) { it('rewardsPerShare is correct', async () => { // rewards / alice's deposit - expect(rewardsPerShare).equal(rewardAmount.mul(oneShare).div(initShares.div(4))) + expect(rewardsPerShare).equal( + rewardAmount.mul(oneShare).div(initShares.div(4)).mul(BN_SHARE_FACTOR) + ) }) }) @@ -324,7 +340,9 @@ for (const wrapperName of wrapperNames) { it('alice shows correct lastRewardsPerShare', async () => { // rewards / alice's deposit - expect(initRewardsPerShare).equal(rewardAmount.mul(oneShare).div(initShares.div(4))) + expect(initRewardsPerShare).equal( + rewardAmount.mul(oneShare).div(initShares.div(4)).mul(BN_SHARE_FACTOR) + ) expect(initRewardsPerShare).equal( await rewardableVault.lastRewardsPerShare(alice.address) ) @@ -335,6 +353,7 @@ for (const wrapperName of wrapperNames) { .mul(oneShare) .div(initShares.div(4)) .add(rewardAmount.mul(oneShare).div(initShares.div(2))) + .mul(BN_SHARE_FACTOR) expect(rewardsPerShare).equal(expectedRewardsPerShare) expect(rewardsPerShare).equal(await rewardableVault.lastRewardsPerShare(bob.address)) }) @@ -358,7 +377,9 @@ for (const wrapperName of wrapperNames) { it('rewardsPerShare is correct', async () => { // rewards / alice's deposit - expect(rewardsPerShare).equal(rewardAmount.mul(oneShare).div(initShares.div(4))) + expect(rewardsPerShare).equal( + rewardAmount.mul(oneShare).div(initShares.div(4)).mul(BN_SHARE_FACTOR) + ) }) }) @@ -399,7 +420,9 @@ for (const wrapperName of wrapperNames) { it('rewardsPerShare is correct', async () => { // rewards / alice's deposit - expect(rewardsPerShare).equal(rewardAmount.mul(oneShare).div(initShares.div(4))) + expect(rewardsPerShare).equal( + rewardAmount.mul(oneShare).div(initShares.div(4)).mul(BN_SHARE_FACTOR) + ) }) }) @@ -425,7 +448,9 @@ for (const wrapperName of wrapperNames) { }) it('bob shows correct balance', async () => { - expect(initShares.div(4)).equal(await rewardableVault.balanceOf(bob.address)) + expect(initShares.div(4).div(BN_SHARE_FACTOR)).equal( + await rewardableVault.balanceOf(bob.address) + ) }) it('bob shows correct lastRewardsPerShare', async () => { @@ -434,7 +459,9 @@ for (const wrapperName of wrapperNames) { it('rewardsPerShare is correct', async () => { // rewards / alice's deposit - expect(rewardsPerShare).equal(rewardAmount.mul(oneShare).div(initShares.div(4))) + expect(rewardsPerShare).equal( + rewardAmount.mul(oneShare).div(initShares.div(4)).mul(BN_SHARE_FACTOR) + ) }) }) @@ -454,7 +481,9 @@ for (const wrapperName of wrapperNames) { }) it('alice shows correct balance', async () => { - expect(initShares.div(4)).equal(await rewardableVault.balanceOf(alice.address)) + expect(initShares.div(4).div(BN_SHARE_FACTOR)).equal( + await rewardableVault.balanceOf(alice.address) + ) }) it('alice has claimed rewards', async () => { @@ -466,7 +495,9 @@ for (const wrapperName of wrapperNames) { }) it('bob shows correct balance', async () => { - expect(initShares.div(8)).equal(await rewardableVault.balanceOf(bob.address)) + expect(initShares.div(8).div(BN_SHARE_FACTOR)).equal( + await rewardableVault.balanceOf(bob.address) + ) }) it('bob shows correct lastRewardsPerShare', async () => { @@ -475,7 +506,9 @@ for (const wrapperName of wrapperNames) { it('rewardsPerShare is correct', async () => { // rewards / alice's deposit - expect(rewardsPerShare).equal(rewardAmount.mul(oneShare).div(initShares.div(4))) + expect(rewardsPerShare).equal( + rewardAmount.mul(oneShare).div(initShares.div(4)).mul(BN_SHARE_FACTOR) + ) }) }) @@ -501,7 +534,9 @@ for (const wrapperName of wrapperNames) { }) it('alice shows correct balance', async () => { - expect(initShares.div(4)).equal(await rewardableVault.balanceOf(alice.address)) + expect(initShares.div(4).div(BN_SHARE_FACTOR)).equal( + await rewardableVault.balanceOf(alice.address) + ) }) it('alice has claimed rewards', async () => { @@ -515,7 +550,9 @@ for (const wrapperName of wrapperNames) { }) it('bob shows correct balance', async () => { - expect(initShares.div(4)).equal(await rewardableVault.balanceOf(bob.address)) + expect(initShares.div(4).div(BN_SHARE_FACTOR)).equal( + await rewardableVault.balanceOf(bob.address) + ) }) it('bob shows correct lastRewardsPerShare', async () => { @@ -532,6 +569,7 @@ for (const wrapperName of wrapperNames) { .mul(oneShare) .div(initShares.div(4)) .add(rewardAmount.mul(oneShare).div(initShares.div(2))) + .mul(BN_SHARE_FACTOR) expect(rewardsPerShare).equal(expectedRewardsPerShare) }) }) @@ -561,7 +599,9 @@ for (const wrapperName of wrapperNames) { }) it('alice shows correct balance', async () => { - expect(initShares.div(4)).equal(await rewardableVault.balanceOf(alice.address)) + expect(initShares.div(4).div(BN_SHARE_FACTOR)).equal( + await rewardableVault.balanceOf(alice.address) + ) }) it('alice shows correct lastRewardsPerShare', async () => { @@ -573,7 +613,9 @@ for (const wrapperName of wrapperNames) { }) it('bob shows correct balance', async () => { - expect(initShares.div(4)).equal(await rewardableVault.balanceOf(bob.address)) + expect(initShares.div(4).div(BN_SHARE_FACTOR)).equal( + await rewardableVault.balanceOf(bob.address) + ) }) it('bob shows correct lastRewardsPerShare', async () => { @@ -586,7 +628,9 @@ for (const wrapperName of wrapperNames) { it('rewardsPerShare is correct', async () => { // (rewards / alice's deposit) + (rewards / bob's deposit) - expect(rewardsPerShare).equal(rewardAmount.mul(oneShare).div(initShares.div(4)).mul(2)) + expect(rewardsPerShare).equal( + rewardAmount.mul(oneShare).div(initShares.div(4)).mul(2).mul(BN_SHARE_FACTOR) + ) }) }) @@ -597,7 +641,9 @@ for (const wrapperName of wrapperNames) { await rewardableVault.connect(alice).deposit(initBalance.div(4), alice.address) await rewardableAsset.accrueRewards(rewardAmount, rewardableVault.address) await rewardableVault.connect(bob).deposit(initBalance.div(4), bob.address) - await rewardableVault.connect(alice).transfer(bob.address, initShares.div(4)) + await rewardableVault + .connect(alice) + .transfer(bob.address, initShares.div(4).div(BN_SHARE_FACTOR)) await rewardableAsset.accrueRewards(rewardAmount, rewardableVault.address) await rewardableVault.connect(alice).deposit(initBalance.div(4), alice.address) await rewardableVault.connect(bob).claimRewards() @@ -607,7 +653,9 @@ for (const wrapperName of wrapperNames) { }) it('alice shows correct balance', async () => { - expect(initShares.div(4)).equal(await rewardableVault.balanceOf(alice.address)) + expect(initShares.div(4).div(BN_SHARE_FACTOR)).equal( + await rewardableVault.balanceOf(alice.address) + ) }) it('alice shows correct lastRewardsPerShare', async () => { @@ -619,7 +667,9 @@ for (const wrapperName of wrapperNames) { }) it('bob shows correct balance', async () => { - expect(initShares.div(2)).equal(await rewardableVault.balanceOf(bob.address)) + expect(initShares.div(2).div(BN_SHARE_FACTOR)).equal( + await rewardableVault.balanceOf(bob.address) + ) }) it('bob shows correct lastRewardsPerShare', async () => { @@ -637,6 +687,7 @@ for (const wrapperName of wrapperNames) { .mul(oneShare) .div(initShares.div(4)) .add(rewardAmount.mul(oneShare).div(initShares.div(2))) + .mul(BN_SHARE_FACTOR) ) }) }) @@ -688,12 +739,70 @@ for (const wrapperName of wrapperNames) { for (let i = 0; i < 10; i++) { await rewardableAsset.accrueRewards(rewardAmount, rewardableVault.address) await rewardableVault.claimRewards() - - expect(await rewardableVault.rewardsPerShare()).to.equal(Math.floor(1.9 * (i + 1))) + expect(await rewardableVault.rewardsPerShare()).to.equal( + bn(`1.9e${SHARE_DECIMALS}`).mul(i + 1) + ) } }) }) + describe(`${wrapperName.replace('Test', '')} Special Case: Rounding - Regression test`, () => { + // Assets + let rewardableVault: RewardableERC20WrapperTest | RewardableERC4626VaultTest + let rewardableAsset: ERC20MockRewarding + let rewardToken: ERC20MockDecimals + // Main + let alice: Wallet + let bob: Wallet + + const initBalance = parseUnits('1000000', 18) + const rewardAmount = parseUnits('1.7', 6) + + const fixture = getFixture(18, 6) + + before('load wallets', async () => { + ;[alice, bob] = (await ethers.getSigners()) as unknown as Wallet[] + }) + + beforeEach(async () => { + // Deploy fixture + ;({ rewardableVault, rewardableAsset, rewardToken } = await loadFixture(fixture)) + + await rewardableAsset.mint(alice.address, initBalance) + await rewardableAsset.connect(alice).approve(rewardableVault.address, MAX_UINT256) + await rewardableAsset.mint(bob.address, initBalance) + await rewardableAsset.connect(bob).approve(rewardableVault.address, MAX_UINT256) + }) + + it('Avoids wrong distribution of rewards when rounding', async () => { + expect(await rewardToken.balanceOf(alice.address)).to.equal(bn(0)) + expect(await rewardToken.balanceOf(bob.address)).to.equal(bn(0)) + expect(await rewardableVault.rewardsPerShare()).to.equal(0) + + // alice deposit and accrue rewards + await rewardableVault.connect(alice).deposit(initBalance, alice.address) + await rewardableAsset.accrueRewards(rewardAmount, rewardableVault.address) + + // bob deposit + await rewardableVault.connect(bob).deposit(initBalance, bob.address) + + // accrue additional rewards (twice the amount) + await rewardableAsset.accrueRewards(rewardAmount.mul(2), rewardableVault.address) + + // claim all rewards + await rewardableVault.connect(bob).claimRewards() + await rewardableVault.connect(alice).claimRewards() + + // Alice got all first rewards plus half of the second + expect(await rewardToken.balanceOf(alice.address)).to.equal(bn(3.4e6)) + + // Bob only got half of the second rewards + expect(await rewardToken.balanceOf(bob.address)).to.equal(bn(1.7e6)) + + expect(await rewardableVault.rewardsPerShare()).to.equal(bn(`3.4e${SHARE_DECIMALS}`)) + }) + }) + const IMPLEMENTATION: Implementation = useEnv('PROTO_IMPL') == Implementation.P1.toString() ? Implementation.P1 : Implementation.P0