Skip to content

Commit

Permalink
add buffer for decimals in shares calculations
Browse files Browse the repository at this point in the history
  • Loading branch information
julianmrodri committed Oct 26, 2023
1 parent a7f76a7 commit ca53604
Show file tree
Hide file tree
Showing 2 changed files with 139 additions and 27 deletions.
5 changes: 4 additions & 1 deletion contracts/plugins/assets/erc20/RewardableERC20.sol
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@ import "@openzeppelin/contracts/token/ERC20/ERC20.sol";
import "@openzeppelin/contracts/token/ERC20/IERC20.sol";
import "../../../interfaces/IRewardable.sol";

uint256 constant SHARES_BUFFER_DECIMALS = 9; // to prevent reward rounding issues

/*
/**
* @title RewardableERC20
* @notice An abstract class that can be extended to create rewardable wrapper.
Expand Down Expand Up @@ -35,7 +38,7 @@ abstract contract RewardableERC20 is IRewardable, ERC20, ReentrancyGuard {
/// @dev Extending class must ensure ERC20 constructor is called
constructor(IERC20 _rewardToken, uint8 _decimals) {
rewardToken = _rewardToken;
one = 10**_decimals; // set via pass-in to prevent inheritance issues
one = 10**(_decimals + SHARES_BUFFER_DECIMALS); // set via pass-in to prevent inheritance issues
}

function claimRewards() external nonReentrant {
Expand Down
161 changes: 135 additions & 26 deletions test/plugins/RewardableERC20.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@ import snapshotGasCost from '../utils/snapshotGasCost'
import { formatUnits, parseUnits } from 'ethers/lib/utils'
import { MAX_UINT256 } from '#/common/constants'

const SHARE_DECIMALS = 9 // decimals buffer for shares and rewards per share
const BN_SHARE_FACTOR = bn(10).pow(SHARE_DECIMALS)

type Fixture<T> = () => Promise<T>

interface RewardableERC20Fixture {
Expand Down Expand Up @@ -120,7 +123,7 @@ for (const wrapperName of wrapperNames) {
describe(wrapperName, () => {
// Decimals
let shareDecimals: number

let rewardShareDecimals: number
// Assets
let rewardableVault: RewardableERC20WrapperTest | RewardableERC4626VaultTest
let rewardableAsset: ERC20MockRewarding
Expand Down Expand Up @@ -152,7 +155,8 @@ for (const wrapperName of wrapperNames) {
await rewardableAsset.mint(bob.address, initBalance)
await rewardableAsset.connect(bob).approve(rewardableVault.address, initBalance)

shareDecimals = await rewardableVault.decimals()
shareDecimals = (await rewardableVault.decimals()) + SHARE_DECIMALS
rewardShareDecimals = rewardDecimals + SHARE_DECIMALS
initShares = toShares(initBalance, assetDecimals, shareDecimals)
oneShare = bn('1').mul(bn(10).pow(shareDecimals))
})
Expand Down Expand Up @@ -185,7 +189,9 @@ for (const wrapperName of wrapperNames) {
expect(await rewardableVault.lastRewardsPerShare(alice.address)).to.equal(bn(0))
await rewardToken.mint(rewardableVault.address, parseUnits('10', rewardDecimals))
await rewardableVault.sync()
expect(await rewardableVault.rewardsPerShare()).to.equal(parseUnits('1', rewardDecimals))
expect(await rewardableVault.rewardsPerShare()).to.equal(
parseUnits('1', rewardShareDecimals)
)
})

it('correctly handles reward tracking if supply is burned', async () => {
Expand All @@ -196,7 +202,9 @@ for (const wrapperName of wrapperNames) {
expect(await rewardableVault.lastRewardsPerShare(alice.address)).to.equal(bn(0))
await rewardToken.mint(rewardableVault.address, parseUnits('10', rewardDecimals))
await rewardableVault.sync()
expect(await rewardableVault.rewardsPerShare()).to.equal(parseUnits('1', rewardDecimals))
expect(await rewardableVault.rewardsPerShare()).to.equal(
parseUnits('1', rewardShareDecimals)
)

// Setting supply to 0
await withdrawAll(rewardableVault.connect(alice))
Expand All @@ -215,7 +223,9 @@ for (const wrapperName of wrapperNames) {

// Nothing updates.. as totalSupply as totalSupply is 0
await rewardableVault.sync()
expect(await rewardableVault.rewardsPerShare()).to.equal(parseUnits('1', rewardDecimals))
expect(await rewardableVault.rewardsPerShare()).to.equal(
parseUnits('1', rewardShareDecimals)
)
await rewardableVault
.connect(alice)
.deposit(parseUnits('10', assetDecimals), alice.address)
Expand Down Expand Up @@ -280,15 +290,19 @@ for (const wrapperName of wrapperNames) {
})

it('alice shows correct balance', async () => {
expect(initShares.mul(3).div(8)).equal(await rewardableVault.balanceOf(alice.address))
expect(initShares.mul(3).div(8).div(BN_SHARE_FACTOR)).equal(
await rewardableVault.balanceOf(alice.address)
)
})

it('alice shows correct lastRewardsPerShare', async () => {
expect(rewardsPerShare).equal(await rewardableVault.lastRewardsPerShare(alice.address))
})

it('bob shows correct balance', async () => {
expect(initShares.div(8)).equal(await rewardableVault.balanceOf(bob.address))
expect(initShares.div(8).div(BN_SHARE_FACTOR)).equal(
await rewardableVault.balanceOf(bob.address)
)
})

it('bob shows correct lastRewardsPerShare', async () => {
Expand All @@ -297,7 +311,9 @@ for (const wrapperName of wrapperNames) {

it('rewardsPerShare is correct', async () => {
// rewards / alice's deposit
expect(rewardsPerShare).equal(rewardAmount.mul(oneShare).div(initShares.div(4)))
expect(rewardsPerShare).equal(
rewardAmount.mul(oneShare).div(initShares.div(4)).mul(BN_SHARE_FACTOR)
)
})
})

Expand All @@ -324,7 +340,9 @@ for (const wrapperName of wrapperNames) {

it('alice shows correct lastRewardsPerShare', async () => {
// rewards / alice's deposit
expect(initRewardsPerShare).equal(rewardAmount.mul(oneShare).div(initShares.div(4)))
expect(initRewardsPerShare).equal(
rewardAmount.mul(oneShare).div(initShares.div(4)).mul(BN_SHARE_FACTOR)
)
expect(initRewardsPerShare).equal(
await rewardableVault.lastRewardsPerShare(alice.address)
)
Expand All @@ -335,6 +353,7 @@ for (const wrapperName of wrapperNames) {
.mul(oneShare)
.div(initShares.div(4))
.add(rewardAmount.mul(oneShare).div(initShares.div(2)))
.mul(BN_SHARE_FACTOR)
expect(rewardsPerShare).equal(expectedRewardsPerShare)
expect(rewardsPerShare).equal(await rewardableVault.lastRewardsPerShare(bob.address))
})
Expand All @@ -358,7 +377,9 @@ for (const wrapperName of wrapperNames) {

it('rewardsPerShare is correct', async () => {
// rewards / alice's deposit
expect(rewardsPerShare).equal(rewardAmount.mul(oneShare).div(initShares.div(4)))
expect(rewardsPerShare).equal(
rewardAmount.mul(oneShare).div(initShares.div(4)).mul(BN_SHARE_FACTOR)
)
})
})

Expand Down Expand Up @@ -399,7 +420,9 @@ for (const wrapperName of wrapperNames) {

it('rewardsPerShare is correct', async () => {
// rewards / alice's deposit
expect(rewardsPerShare).equal(rewardAmount.mul(oneShare).div(initShares.div(4)))
expect(rewardsPerShare).equal(
rewardAmount.mul(oneShare).div(initShares.div(4)).mul(BN_SHARE_FACTOR)
)
})
})

Expand All @@ -425,7 +448,9 @@ for (const wrapperName of wrapperNames) {
})

it('bob shows correct balance', async () => {
expect(initShares.div(4)).equal(await rewardableVault.balanceOf(bob.address))
expect(initShares.div(4).div(BN_SHARE_FACTOR)).equal(
await rewardableVault.balanceOf(bob.address)
)
})

it('bob shows correct lastRewardsPerShare', async () => {
Expand All @@ -434,7 +459,9 @@ for (const wrapperName of wrapperNames) {

it('rewardsPerShare is correct', async () => {
// rewards / alice's deposit
expect(rewardsPerShare).equal(rewardAmount.mul(oneShare).div(initShares.div(4)))
expect(rewardsPerShare).equal(
rewardAmount.mul(oneShare).div(initShares.div(4)).mul(BN_SHARE_FACTOR)
)
})
})

Expand All @@ -454,7 +481,9 @@ for (const wrapperName of wrapperNames) {
})

it('alice shows correct balance', async () => {
expect(initShares.div(4)).equal(await rewardableVault.balanceOf(alice.address))
expect(initShares.div(4).div(BN_SHARE_FACTOR)).equal(
await rewardableVault.balanceOf(alice.address)
)
})

it('alice has claimed rewards', async () => {
Expand All @@ -466,7 +495,9 @@ for (const wrapperName of wrapperNames) {
})

it('bob shows correct balance', async () => {
expect(initShares.div(8)).equal(await rewardableVault.balanceOf(bob.address))
expect(initShares.div(8).div(BN_SHARE_FACTOR)).equal(
await rewardableVault.balanceOf(bob.address)
)
})

it('bob shows correct lastRewardsPerShare', async () => {
Expand All @@ -475,7 +506,9 @@ for (const wrapperName of wrapperNames) {

it('rewardsPerShare is correct', async () => {
// rewards / alice's deposit
expect(rewardsPerShare).equal(rewardAmount.mul(oneShare).div(initShares.div(4)))
expect(rewardsPerShare).equal(
rewardAmount.mul(oneShare).div(initShares.div(4)).mul(BN_SHARE_FACTOR)
)
})
})

Expand All @@ -501,7 +534,9 @@ for (const wrapperName of wrapperNames) {
})

it('alice shows correct balance', async () => {
expect(initShares.div(4)).equal(await rewardableVault.balanceOf(alice.address))
expect(initShares.div(4).div(BN_SHARE_FACTOR)).equal(
await rewardableVault.balanceOf(alice.address)
)
})

it('alice has claimed rewards', async () => {
Expand All @@ -515,7 +550,9 @@ for (const wrapperName of wrapperNames) {
})

it('bob shows correct balance', async () => {
expect(initShares.div(4)).equal(await rewardableVault.balanceOf(bob.address))
expect(initShares.div(4).div(BN_SHARE_FACTOR)).equal(
await rewardableVault.balanceOf(bob.address)
)
})

it('bob shows correct lastRewardsPerShare', async () => {
Expand All @@ -532,6 +569,7 @@ for (const wrapperName of wrapperNames) {
.mul(oneShare)
.div(initShares.div(4))
.add(rewardAmount.mul(oneShare).div(initShares.div(2)))
.mul(BN_SHARE_FACTOR)
expect(rewardsPerShare).equal(expectedRewardsPerShare)
})
})
Expand Down Expand Up @@ -561,7 +599,9 @@ for (const wrapperName of wrapperNames) {
})

it('alice shows correct balance', async () => {
expect(initShares.div(4)).equal(await rewardableVault.balanceOf(alice.address))
expect(initShares.div(4).div(BN_SHARE_FACTOR)).equal(
await rewardableVault.balanceOf(alice.address)
)
})

it('alice shows correct lastRewardsPerShare', async () => {
Expand All @@ -573,7 +613,9 @@ for (const wrapperName of wrapperNames) {
})

it('bob shows correct balance', async () => {
expect(initShares.div(4)).equal(await rewardableVault.balanceOf(bob.address))
expect(initShares.div(4).div(BN_SHARE_FACTOR)).equal(
await rewardableVault.balanceOf(bob.address)
)
})

it('bob shows correct lastRewardsPerShare', async () => {
Expand All @@ -586,7 +628,9 @@ for (const wrapperName of wrapperNames) {

it('rewardsPerShare is correct', async () => {
// (rewards / alice's deposit) + (rewards / bob's deposit)
expect(rewardsPerShare).equal(rewardAmount.mul(oneShare).div(initShares.div(4)).mul(2))
expect(rewardsPerShare).equal(
rewardAmount.mul(oneShare).div(initShares.div(4)).mul(2).mul(BN_SHARE_FACTOR)
)
})
})

Expand All @@ -597,7 +641,9 @@ for (const wrapperName of wrapperNames) {
await rewardableVault.connect(alice).deposit(initBalance.div(4), alice.address)
await rewardableAsset.accrueRewards(rewardAmount, rewardableVault.address)
await rewardableVault.connect(bob).deposit(initBalance.div(4), bob.address)
await rewardableVault.connect(alice).transfer(bob.address, initShares.div(4))
await rewardableVault
.connect(alice)
.transfer(bob.address, initShares.div(4).div(BN_SHARE_FACTOR))
await rewardableAsset.accrueRewards(rewardAmount, rewardableVault.address)
await rewardableVault.connect(alice).deposit(initBalance.div(4), alice.address)
await rewardableVault.connect(bob).claimRewards()
Expand All @@ -607,7 +653,9 @@ for (const wrapperName of wrapperNames) {
})

it('alice shows correct balance', async () => {
expect(initShares.div(4)).equal(await rewardableVault.balanceOf(alice.address))
expect(initShares.div(4).div(BN_SHARE_FACTOR)).equal(
await rewardableVault.balanceOf(alice.address)
)
})

it('alice shows correct lastRewardsPerShare', async () => {
Expand All @@ -619,7 +667,9 @@ for (const wrapperName of wrapperNames) {
})

it('bob shows correct balance', async () => {
expect(initShares.div(2)).equal(await rewardableVault.balanceOf(bob.address))
expect(initShares.div(2).div(BN_SHARE_FACTOR)).equal(
await rewardableVault.balanceOf(bob.address)
)
})

it('bob shows correct lastRewardsPerShare', async () => {
Expand All @@ -637,6 +687,7 @@ for (const wrapperName of wrapperNames) {
.mul(oneShare)
.div(initShares.div(4))
.add(rewardAmount.mul(oneShare).div(initShares.div(2)))
.mul(BN_SHARE_FACTOR)
)
})
})
Expand Down Expand Up @@ -688,12 +739,70 @@ for (const wrapperName of wrapperNames) {
for (let i = 0; i < 10; i++) {
await rewardableAsset.accrueRewards(rewardAmount, rewardableVault.address)
await rewardableVault.claimRewards()

expect(await rewardableVault.rewardsPerShare()).to.equal(Math.floor(1.9 * (i + 1)))
expect(await rewardableVault.rewardsPerShare()).to.equal(
bn(`1.9e${SHARE_DECIMALS}`).mul(i + 1)
)
}
})
})

describe(`${wrapperName.replace('Test', '')} Special Case: Rounding - Regression test`, () => {
// Assets
let rewardableVault: RewardableERC20WrapperTest | RewardableERC4626VaultTest
let rewardableAsset: ERC20MockRewarding
let rewardToken: ERC20MockDecimals
// Main
let alice: Wallet
let bob: Wallet

const initBalance = parseUnits('1000000', 18)
const rewardAmount = parseUnits('1.7', 6)

const fixture = getFixture(18, 6)

before('load wallets', async () => {
;[alice, bob] = (await ethers.getSigners()) as unknown as Wallet[]
})

beforeEach(async () => {
// Deploy fixture
;({ rewardableVault, rewardableAsset, rewardToken } = await loadFixture(fixture))

await rewardableAsset.mint(alice.address, initBalance)
await rewardableAsset.connect(alice).approve(rewardableVault.address, MAX_UINT256)
await rewardableAsset.mint(bob.address, initBalance)
await rewardableAsset.connect(bob).approve(rewardableVault.address, MAX_UINT256)
})

it('Avoids wrong distribution of rewards when rounding', async () => {
expect(await rewardToken.balanceOf(alice.address)).to.equal(bn(0))
expect(await rewardToken.balanceOf(bob.address)).to.equal(bn(0))
expect(await rewardableVault.rewardsPerShare()).to.equal(0)

// alice deposit and accrue rewards
await rewardableVault.connect(alice).deposit(initBalance, alice.address)
await rewardableAsset.accrueRewards(rewardAmount, rewardableVault.address)

// bob deposit
await rewardableVault.connect(bob).deposit(initBalance, bob.address)

// accrue additional rewards (twice the amount)
await rewardableAsset.accrueRewards(rewardAmount.mul(2), rewardableVault.address)

// claim all rewards
await rewardableVault.connect(bob).claimRewards()
await rewardableVault.connect(alice).claimRewards()

// Alice got all first rewards plus half of the second
expect(await rewardToken.balanceOf(alice.address)).to.equal(bn(3.4e6))

// Bob only got half of the second rewards
expect(await rewardToken.balanceOf(bob.address)).to.equal(bn(1.7e6))

expect(await rewardableVault.rewardsPerShare()).to.equal(bn(`3.4e${SHARE_DECIMALS}`))
})
})

const IMPLEMENTATION: Implementation =
useEnv('PROTO_IMPL') == Implementation.P1.toString() ? Implementation.P1 : Implementation.P0

Expand Down

0 comments on commit ca53604

Please sign in to comment.