diff --git a/core/src/epoch_snapshot.rs b/core/src/epoch_snapshot.rs index 6427801..34bc67c 100644 --- a/core/src/epoch_snapshot.rs +++ b/core/src/epoch_snapshot.rs @@ -1,6 +1,6 @@ use bytemuck::{Pod, Zeroable}; use jito_bytemuck::{ - types::{PodBool, PodU128, PodU16, PodU64}, + types::{PodBool, PodU16, PodU64}, AccountDeserialize, Discriminator, }; use jito_vault_core::vault_operator_delegation::VaultOperatorDelegation; @@ -9,7 +9,8 @@ use solana_program::{account_info::AccountInfo, msg, program_error::ProgramError use spl_math::precise_number::PreciseNumber; use crate::{ - discriminators::Discriminators, error::TipRouterError, fees::Fees, weight_table::WeightTable, + discriminators::Discriminators, error::TipRouterError, fees::Fees, stake_weight::StakeWeight, + weight_table::WeightTable, }; // PDA'd ["epoch_snapshot", NCN, NCN_EPOCH_SLOT] @@ -37,9 +38,7 @@ pub struct EpochSnapshot { operators_registered: PodU64, valid_operator_vault_delegations: PodU64, - /// Counted as each delegate gets added - ///TODO What happens if `finalized() && total_votes() == 0`? - stake_weight: PodU128, + stake_weight: StakeWeight, /// Reserved space reserved: [u8; 128], @@ -70,7 +69,7 @@ impl EpochSnapshot { vault_count: PodU64::from(vault_count), operators_registered: PodU64::from(0), valid_operator_vault_delegations: PodU64::from(0), - stake_weight: PodU128::from(0), + stake_weight: StakeWeight::default(), reserved: [0; 128], } } @@ -147,8 +146,8 @@ impl EpochSnapshot { self.valid_operator_vault_delegations.into() } - pub fn stake_weight(&self) -> u128 { - self.stake_weight.into() + pub const fn stake_weight(&self) -> &StakeWeight { + &self.stake_weight } pub fn finalized(&self) -> bool { @@ -159,7 +158,7 @@ impl EpochSnapshot { &mut self, current_slot: u64, vault_operator_delegations: u64, - stake_weight: u128, + stake_weight: &StakeWeight, ) -> Result<(), TipRouterError> { if self.finalized() { return Err(TipRouterError::OperatorFinalized); @@ -177,11 +176,7 @@ impl EpochSnapshot { .ok_or(TipRouterError::ArithmeticOverflow)?, ); - self.stake_weight = PodU128::from( - self.stake_weight() - .checked_add(stake_weight) - .ok_or(TipRouterError::ArithmeticOverflow)?, - ); + self.stake_weight.increment(stake_weight)?; if self.finalized() { self.slot_finalized = PodU64::from(current_slot); @@ -213,7 +208,7 @@ pub struct OperatorSnapshot { vault_operator_delegations_registered: PodU64, valid_operator_vault_delegations: PodU64, - stake_weight: PodU128, + stake_weight: StakeWeight, reserved: [u8; 256], //TODO change to 64 @@ -224,8 +219,8 @@ pub struct OperatorSnapshot { #[repr(C)] pub struct VaultOperatorStakeWeight { vault: Pubkey, - stake_weight: PodU128, vault_index: PodU64, + stake_weight: StakeWeight, reserved: [u8; 32], } @@ -234,18 +229,18 @@ impl Default for VaultOperatorStakeWeight { Self { vault: Pubkey::default(), vault_index: PodU64::from(u64::MAX), - stake_weight: PodU128::from(0), + stake_weight: StakeWeight::default(), reserved: [0; 32], } } } impl VaultOperatorStakeWeight { - pub fn new(vault: Pubkey, stake_weight: u128, vault_index: u64) -> Self { + pub fn new(vault: Pubkey, vault_index: u64, stake_weight: &StakeWeight) -> Self { Self { vault, vault_index: PodU64::from(vault_index), - stake_weight: PodU128::from(stake_weight), + stake_weight: *stake_weight, reserved: [0; 32], } } @@ -258,8 +253,8 @@ impl VaultOperatorStakeWeight { self.vault_index.into() } - pub fn stake_weight(&self) -> u128 { - self.stake_weight.into() + pub fn stake_weight(&self) -> &StakeWeight { + &self.stake_weight } } @@ -301,7 +296,7 @@ impl OperatorSnapshot { vault_operator_delegation_count: PodU64::from(vault_operator_delegation_count), vault_operator_delegations_registered: PodU64::from(0), valid_operator_vault_delegations: PodU64::from(0), - stake_weight: PodU128::from(0), + stake_weight: StakeWeight::default(), reserved: [0; 256], vault_operator_stake_weight: [VaultOperatorStakeWeight::default(); 32], }) @@ -430,8 +425,8 @@ impl OperatorSnapshot { self.valid_operator_vault_delegations.into() } - pub fn stake_weight(&self) -> u128 { - self.stake_weight.into() + pub fn stake_weight(&self) -> &StakeWeight { + &self.stake_weight } pub fn finalized(&self) -> bool { @@ -448,7 +443,7 @@ impl OperatorSnapshot { &mut self, vault: Pubkey, vault_index: u64, - stake_weight: u128, + stake_weight: &StakeWeight, ) -> Result<(), TipRouterError> { if self.vault_operator_delegations_registered() > Self::MAX_VAULT_OPERATOR_STAKE_WEIGHT as u64 @@ -461,7 +456,7 @@ impl OperatorSnapshot { } self.vault_operator_stake_weight[self.vault_operator_delegations_registered() as usize] = - VaultOperatorStakeWeight::new(vault, stake_weight, vault_index); + VaultOperatorStakeWeight::new(vault, vault_index, stake_weight); Ok(()) } @@ -471,7 +466,7 @@ impl OperatorSnapshot { current_slot: u64, vault: Pubkey, vault_index: u64, - stake_weight: u128, + stake_weight: &StakeWeight, ) -> Result<(), TipRouterError> { if self.finalized() { return Err(TipRouterError::VaultOperatorDelegationFinalized); @@ -485,7 +480,7 @@ impl OperatorSnapshot { .ok_or(TipRouterError::ArithmeticOverflow)?, ); - if stake_weight > 0 { + if stake_weight.stake_weight() > 0 { self.valid_operator_vault_delegations = PodU64::from( self.valid_operator_vault_delegations() .checked_add(1) @@ -493,11 +488,7 @@ impl OperatorSnapshot { ); } - self.stake_weight = PodU128::from( - self.stake_weight() - .checked_add(stake_weight) - .ok_or(TipRouterError::ArithmeticOverflow)?, - ); + self.stake_weight.increment(stake_weight)?; if self.finalized() { self.slot_finalized = PodU64::from(current_slot); diff --git a/core/src/error.rs b/core/src/error.rs index 4c5f7da..489d56d 100644 --- a/core/src/error.rs +++ b/core/src/error.rs @@ -74,6 +74,8 @@ pub enum TipRouterError { ConsensusAlreadyReached, #[error("Consensus not reached")] ConsensusNotReached, + #[error("Not a valid NCN fee group")] + InvalidNcnFeeGroup, } impl DecodeError for TipRouterError { diff --git a/core/src/fees.rs b/core/src/fees.rs index 2ec6982..d026bb5 100644 --- a/core/src/fees.rs +++ b/core/src/fees.rs @@ -1,42 +1,46 @@ use bytemuck::{Pod, Zeroable}; use jito_bytemuck::types::PodU64; use shank::ShankType; -use solana_program::{msg, pubkey::Pubkey}; +use solana_program::pubkey::Pubkey; use spl_math::precise_number::PreciseNumber; -use crate::{constants::MAX_FEE_BPS, error::TipRouterError}; +use crate::{constants::MAX_FEE_BPS, error::TipRouterError, ncn_fee_group::NcnFeeGroup}; -/// Fee account. Allows for fee updates to take place in a future epoch without requiring an update. +/// Fee Config. Allows for fee updates to take place in a future epoch without requiring an update. /// This is important so all operators calculate the same Merkle root regardless of when fee changes take place. #[derive(Debug, Clone, Copy, Zeroable, ShankType, Pod)] #[repr(C)] -pub struct Fees { - fee_1: Fee, - fee_2: Fee, +pub struct FeeConfig { + dao_fee_wallet: Pubkey, + + fee_1: Fees, + fee_2: Fees, } -impl Fees { +impl FeeConfig { pub fn new( - wallet: Pubkey, - dao_fee_share_bps: u64, - ncn_fee_share_bps: u64, + dao_fee_wallet: Pubkey, block_engine_fee_bps: u64, + dao_fee_bps: u64, + default_ncn_fee_bps: u64, current_epoch: u64, - ) -> Self { - let fee = Fee::new( - wallet, - dao_fee_share_bps, - ncn_fee_share_bps, + ) -> Result { + let fee = Fees::new( block_engine_fee_bps, + dao_fee_bps, + default_ncn_fee_bps, current_epoch, - ); - Self { + )?; + + Ok(Self { + dao_fee_wallet, fee_1: fee, fee_2: fee, - } + }) } - fn current_fee(&self, current_epoch: u64) -> &Fee { + // ------------- Getters ------------- + fn current_fees(&self, current_epoch: u64) -> &Fees { // If either fee is not yet active, return the other one if self.fee_1.activation_epoch() > current_epoch { return &self.fee_2; @@ -53,108 +57,182 @@ impl Fees { } } - pub fn check_fees_okay(&self, current_epoch: u64) -> Result<(), TipRouterError> { - let _ = self.precise_block_engine_fee(current_epoch)?; - let _ = self.precise_dao_fee(current_epoch)?; - let _ = self.precise_ncn_fee(current_epoch)?; + pub fn total_fees_bps(&self, current_epoch: u64) -> Result { + let mut total_fees_bps = self.dao_fee_bps(current_epoch); - Ok(()) + for group in NcnFeeGroup::all_groups().iter() { + let ncn_fee_bps = self.ncn_fee_bps(*group, current_epoch)?; + + total_fees_bps = total_fees_bps + .checked_add(ncn_fee_bps) + .ok_or(TipRouterError::ArithmeticOverflow)?; + } + + Ok(total_fees_bps) + } + + pub fn precise_total_fee_bps( + &self, + current_epoch: u64, + ) -> Result { + let mut precise_total_fees_bps = self.precise_dao_fee_bps(current_epoch)?; + + for group in NcnFeeGroup::all_groups().iter() { + let precise_ncn_fee_bps = self.precise_ncn_fee_bps(*group, current_epoch)?; + + precise_total_fees_bps = precise_total_fees_bps + .checked_add(&precise_ncn_fee_bps) + .ok_or(TipRouterError::ArithmeticOverflow)?; + } + + Ok(precise_total_fees_bps) } - pub fn block_engine_fee(&self, current_epoch: u64) -> u64 { - self.current_fee(current_epoch).block_engine_fee_bps() + pub fn block_engine_fee_bps(&self, current_epoch: u64) -> u64 { + let current_fees = self.current_fees(current_epoch); + current_fees.block_engine_fee_bps() } - pub fn precise_block_engine_fee( + pub fn precise_block_engine_fee_bps( &self, current_epoch: u64, ) -> Result { - let fee = self.current_fee(current_epoch); + let current_fees = self.current_fees(current_epoch); - PreciseNumber::new(fee.block_engine_fee_bps() as u128) - .ok_or(TipRouterError::NewPreciseNumberError) + current_fees.precise_block_engine_fee_bps() } - /// Calculate fee as a portion of remaining BPS after block engine fee - /// new_fee = dao_fee_bps / ((10000 - block_engine_fee_bps) / 10000) - /// = dao_fee_bps * 10000 / (10000 - block_engine_fee_bps) - pub fn dao_fee(&self, current_epoch: u64) -> Result { - let fee = self.current_fee(current_epoch); - let remaining_bps = MAX_FEE_BPS - .checked_sub(fee.block_engine_fee_bps()) - .ok_or(TipRouterError::ArithmeticOverflow)?; - fee.dao_share_bps() - .checked_mul(MAX_FEE_BPS) - .and_then(|x| x.checked_div(remaining_bps)) - .ok_or(TipRouterError::DenominatorIsZero) + pub fn dao_fee_bps(&self, current_epoch: u64) -> u64 { + let current_fees = self.current_fees(current_epoch); + current_fees.dao_fee_bps() } - pub fn precise_dao_fee(&self, current_epoch: u64) -> Result { - let fee = self.current_fee(current_epoch); + pub fn precise_dao_fee_bps(&self, current_epoch: u64) -> Result { + let current_fees = self.current_fees(current_epoch); + current_fees.precise_dao_fee_bps() + } - let remaining_bps = MAX_FEE_BPS - .checked_sub(fee.block_engine_fee_bps()) - .ok_or(TipRouterError::ArithmeticOverflow)?; + pub fn adjusted_dao_fee_bps(&self, current_epoch: u64) -> Result { + let current_fees = self.current_fees(current_epoch); + current_fees.adjusted_dao_fee_bps() + } - let precise_remaining_bps = PreciseNumber::new(remaining_bps as u128) - .ok_or(TipRouterError::NewPreciseNumberError)?; + pub fn adjusted_precise_dao_fee_bps( + &self, + current_epoch: u64, + ) -> Result { + let current_fees = self.current_fees(current_epoch); + current_fees.adjusted_precise_dao_fee_bps() + } - let dao_fee = fee - .ncn_share_bps() - .checked_mul(MAX_FEE_BPS) - .ok_or(TipRouterError::ArithmeticOverflow)?; + pub fn ncn_fee_bps( + &self, + ncn_fee_group: NcnFeeGroup, + current_epoch: u64, + ) -> Result { + let current_fees = self.current_fees(current_epoch); + current_fees.ncn_fee_bps(ncn_fee_group) + } - let precise_dao_fee = - PreciseNumber::new(dao_fee as u128).ok_or(TipRouterError::NewPreciseNumberError)?; + pub fn precise_ncn_fee_bps( + &self, + ncn_fee_group: NcnFeeGroup, + current_epoch: u64, + ) -> Result { + let current_fees = self.current_fees(current_epoch); + current_fees.precise_ncn_fee_bps(ncn_fee_group) + } - precise_dao_fee - .checked_div(&precise_remaining_bps) - .ok_or(TipRouterError::DenominatorIsZero) + pub fn adjusted_ncn_fee_bps( + &self, + ncn_fee_group: NcnFeeGroup, + current_epoch: u64, + ) -> Result { + let current_fees = self.current_fees(current_epoch); + current_fees.adjusted_ncn_fee_bps(ncn_fee_group) } - /// Calculate fee as a portion of remaining BPS after block engine fee - /// new_fee = ncn_fee_bps / ((10000 - block_engine_fee_bps) / 10000) - /// = ncn_fee_bps * 10000 / (10000 - block_engine_fee_bps) - pub fn ncn_fee(&self, current_epoch: u64) -> Result { - let fee = self.current_fee(current_epoch); + pub fn adjusted_precise_ncn_fee_bps( + &self, + ncn_fee_group: NcnFeeGroup, + current_epoch: u64, + ) -> Result { + let current_fees = self.current_fees(current_epoch); + current_fees.adjusted_precise_ncn_fee_bps(ncn_fee_group) + } - let remaining_bps = MAX_FEE_BPS - .checked_sub(fee.block_engine_fee_bps()) - .ok_or(TipRouterError::ArithmeticOverflow)?; - fee.ncn_share_bps() - .checked_mul(MAX_FEE_BPS) - .and_then(|x| x.checked_div(remaining_bps)) - .ok_or(TipRouterError::DenominatorIsZero) + pub const fn fee_wallet(&self) -> Pubkey { + self.dao_fee_wallet } - pub fn precise_ncn_fee(&self, current_epoch: u64) -> Result { - let fee = self.current_fee(current_epoch); + // ------------- Setters ------------- + /// Updates the Fee Config + /// Any option set to None will be ignored + /// `new_wallet`` and `new_block_engine_fee_bps` will take effect immediately + /// `new_ncn_fee_bps` will set the fee group specified in `new_ncn_fee_group` + /// if no `new_ncn_fee_group` is specified, the default ncn group will be set + pub fn update_fee_config( + &mut self, + new_wallet: Option, + new_block_engine_fee_bps: Option, + new_dao_fee_bps: Option, + new_ncn_fee_bps: Option, + new_ncn_fee_group: Option, + current_epoch: u64, + ) -> Result<(), TipRouterError> { + // Set Wallet + if let Some(new_wallet) = new_wallet { + self.dao_fee_wallet = new_wallet; + } - let remaining_bps = MAX_FEE_BPS - .checked_sub(fee.block_engine_fee_bps()) - .ok_or(TipRouterError::ArithmeticOverflow)?; + // Set new block engine fee + if let Some(new_block_engine_fee_bps) = new_block_engine_fee_bps { + self.fee_1 + .set_block_engine_fee_bps(new_block_engine_fee_bps); + self.fee_2 + .set_block_engine_fee_bps(new_block_engine_fee_bps); + } - let precise_remaining_bps = PreciseNumber::new(remaining_bps as u128) - .ok_or(TipRouterError::NewPreciseNumberError)?; + // Change Fees + { + let current_fees = *self.current_fees(current_epoch); + let new_fees = self.get_updatable_fee_mut(current_epoch); + *new_fees = current_fees; + + if let Some(new_dao_fee_bps) = new_dao_fee_bps { + if new_dao_fee_bps > MAX_FEE_BPS { + return Err(TipRouterError::FeeCapExceeded); + } + new_fees.set_dao_fee_bps(new_dao_fee_bps); + } - let ncn_fee = fee - .ncn_share_bps() - .checked_mul(MAX_FEE_BPS) - .ok_or(TipRouterError::ArithmeticOverflow)?; + // If no fee group is set, use the default + if let Some(new_ncn_fee_bps) = new_ncn_fee_bps { + if new_ncn_fee_bps > MAX_FEE_BPS { + return Err(TipRouterError::FeeCapExceeded); + } + + if let Some(new_ncn_fee_group) = new_ncn_fee_group { + new_fees.set_ncn_fee_bps(new_ncn_fee_group, new_ncn_fee_bps)?; + } else { + new_fees.set_ncn_fee_bps(NcnFeeGroup::default(), new_ncn_fee_bps)?; + } + } - let precise_ncn_fee = - PreciseNumber::new(ncn_fee as u128).ok_or(TipRouterError::NewPreciseNumberError)?; + let next_epoch = current_epoch + .checked_add(1) + .ok_or(TipRouterError::ArithmeticOverflow)?; - precise_ncn_fee - .checked_div(&precise_remaining_bps) - .ok_or(TipRouterError::DenominatorIsZero) - } + new_fees.set_activation_epoch(next_epoch); - pub fn fee_wallet(&self, current_epoch: u64) -> Pubkey { - self.current_fee(current_epoch).wallet + self.check_fees_okay(next_epoch)?; + } + + Ok(()) } - fn get_updatable_fee_mut(&mut self, current_epoch: u64) -> &mut Fee { + // ----------------- Helpers ----------------- + fn get_updatable_fee_mut(&mut self, current_epoch: u64) -> &mut Fees { // If either fee is scheduled for next epoch, return that one if self.fee_1.activation_epoch() > current_epoch { return &mut self.fee_1; @@ -171,51 +249,15 @@ impl Fees { } } - pub fn set_new_fees( - &mut self, - new_dao_fee_bps: Option, - new_ncn_fee_bps: Option, - new_block_engine_fee_bps: Option, - new_wallet: Option, - current_epoch: u64, - ) -> Result<(), TipRouterError> { - let current_fees = *self.current_fee(current_epoch); - let new_fees = self.get_updatable_fee_mut(current_epoch); - *new_fees = current_fees; - - if let Some(new_dao_fee_bps) = new_dao_fee_bps { - if new_dao_fee_bps > MAX_FEE_BPS { - return Err(TipRouterError::FeeCapExceeded); - } - new_fees.set_dao_share_bps(new_dao_fee_bps); - } - if let Some(new_ncn_fee_bps) = new_ncn_fee_bps { - if new_ncn_fee_bps > MAX_FEE_BPS { - return Err(TipRouterError::FeeCapExceeded); - } - new_fees.set_ncn_share_bps(new_ncn_fee_bps); - } - if let Some(new_block_engine_fee_bps) = new_block_engine_fee_bps { - // Block engine fee must be less than MAX_FEE_BPS, - // otherwise we'll divide by zero when calculating - // the other fees - if new_block_engine_fee_bps >= MAX_FEE_BPS { - msg!("Block engine fee cannot equal or exceed MAX_FEE_BPS"); - return Err(TipRouterError::FeeCapExceeded); - } - new_fees.set_block_engine_fee_bps(new_block_engine_fee_bps); - } - if let Some(new_wallet) = new_wallet { - new_fees.wallet = new_wallet; - } - - let next_epoch = current_epoch - .checked_add(1) - .ok_or(TipRouterError::ArithmeticOverflow)?; + pub fn check_fees_okay(&self, current_epoch: u64) -> Result<(), TipRouterError> { + let _ = self.precise_block_engine_fee_bps(current_epoch)?; + let _ = self.adjusted_precise_dao_fee_bps(current_epoch)?; - new_fees.set_activation_epoch(next_epoch); + let all_fee_groups = NcnFeeGroup::all_groups(); - self.check_fees_okay(next_epoch)?; + for group in all_fee_groups.iter() { + let _ = self.adjusted_precise_ncn_fee_bps(*group, current_epoch)?; + } Ok(()) } @@ -223,186 +265,395 @@ impl Fees { #[derive(Debug, Clone, Copy, Zeroable, ShankType, Pod)] #[repr(C)] -pub struct Fee { - wallet: Pubkey, - dao_share_bps: PodU64, - ncn_share_bps: PodU64, - block_engine_fee_bps: PodU64, +pub struct Fees { activation_epoch: PodU64, + + block_engine_fee_bps: PodU64, + dao_fee_bps: PodU64, + + ncn_fee_groups_bps: [NcnFee; NcnFeeGroup::FEE_GROUP_COUNT], + + // Reserves + reserved: [u8; 64], } -impl Fee { +impl Fees { pub fn new( - wallet: Pubkey, - dao_share_bps: u64, - ncn_share_bps: u64, block_engine_fee_bps: u64, + dao_fee_bps: u64, + default_ncn_fee_bps: u64, epoch: u64, - ) -> Self { - Self { - wallet, - dao_share_bps: PodU64::from(dao_share_bps), - ncn_share_bps: PodU64::from(ncn_share_bps), - block_engine_fee_bps: PodU64::from(block_engine_fee_bps), + ) -> Result { + let mut fees = Self { activation_epoch: PodU64::from(epoch), - } - } + block_engine_fee_bps: PodU64::from(block_engine_fee_bps), + dao_fee_bps: PodU64::from(dao_fee_bps), + ncn_fee_groups_bps: [NcnFee::default(); NcnFeeGroup::FEE_GROUP_COUNT], + reserved: [0; 64], + }; + + fees.set_ncn_fee_bps(NcnFeeGroup::default(), default_ncn_fee_bps)?; - pub fn dao_share_bps(&self) -> u64 { - self.dao_share_bps.into() + Ok(fees) } - pub fn ncn_share_bps(&self) -> u64 { - self.ncn_share_bps.into() + // ------ Getters ----------------- + pub fn activation_epoch(&self) -> u64 { + self.activation_epoch.into() } pub fn block_engine_fee_bps(&self) -> u64 { self.block_engine_fee_bps.into() } - pub fn activation_epoch(&self) -> u64 { - self.activation_epoch.into() + pub fn precise_block_engine_fee_bps(&self) -> Result { + PreciseNumber::new(self.block_engine_fee_bps().into()) + .ok_or(TipRouterError::NewPreciseNumberError) } - fn set_dao_share_bps(&mut self, value: u64) { - self.dao_share_bps = PodU64::from(value); + pub fn dao_fee_bps(&self) -> u64 { + self.dao_fee_bps.into() } - fn set_ncn_share_bps(&mut self, value: u64) { - self.ncn_share_bps = PodU64::from(value); + pub fn precise_dao_fee_bps(&self) -> Result { + PreciseNumber::new(self.dao_fee_bps().into()).ok_or(TipRouterError::NewPreciseNumberError) } - fn set_block_engine_fee_bps(&mut self, value: u64) { - self.block_engine_fee_bps = PodU64::from(value); + pub fn adjusted_dao_fee_bps(&self) -> Result { + self.adjusted_fee_bps(self.dao_fee_bps()) } - fn set_activation_epoch(&mut self, value: u64) { - self.activation_epoch = PodU64::from(value); + pub fn adjusted_precise_dao_fee_bps(&self) -> Result { + self.adjusted_precise_fee_bps(self.dao_fee_bps()) } -} -#[cfg(test)] -mod tests { - use solana_program::pubkey::Pubkey; + pub fn ncn_fee_bps(&self, ncn_fee_group: NcnFeeGroup) -> Result { + let group_index = ncn_fee_group.group_index()?; - use super::*; + Ok(self.ncn_fee_groups_bps[group_index].fee()) + } - #[test] - fn test_update_fees() { - let mut fees = Fees::new(Pubkey::new_unique(), 100, 200, 300, 5); - let new_wallet = Pubkey::new_unique(); + pub fn precise_ncn_fee_bps( + &self, + ncn_fee_group: NcnFeeGroup, + ) -> Result { + let fee = self.ncn_fee_bps(ncn_fee_group)?; - fees.set_new_fees(Some(400), None, None, Some(new_wallet), 10) - .unwrap(); - assert_eq!(fees.fee_1.dao_share_bps(), 400); - assert_eq!(fees.fee_1.wallet, new_wallet); - assert_eq!(fees.fee_1.activation_epoch(), 11); + PreciseNumber::new(fee.into()).ok_or(TipRouterError::NewPreciseNumberError) } - #[test] - fn test_update_all_fees() { - let mut fees = Fees::new(Pubkey::new_unique(), 0, 0, 0, 5); + pub fn adjusted_ncn_fee_bps(&self, ncn_fee_group: NcnFeeGroup) -> Result { + let fee = self.ncn_fee_bps(ncn_fee_group)?; - fees.set_new_fees(Some(100), Some(200), Some(300), None, 10) - .unwrap(); - assert_eq!(fees.fee_1.dao_share_bps(), 100); - assert_eq!(fees.fee_1.ncn_share_bps(), 200); - assert_eq!(fees.fee_1.block_engine_fee_bps(), 300); - assert_eq!(fees.fee_1.activation_epoch(), 11); + self.adjusted_fee_bps(fee) } - #[test] - fn test_update_fees_no_changes() { - let original = Fee::new(Pubkey::new_unique(), 100, 200, 300, 5); - let mut fees = Fees::new(Pubkey::new_unique(), 100, 200, 300, 5); - fees.fee_1 = original; - - fees.set_new_fees(None, None, None, None, 10).unwrap(); - assert_eq!(fees.fee_1.dao_share_bps(), original.dao_share_bps()); - assert_eq!(fees.fee_1.ncn_share_bps(), original.ncn_share_bps()); - assert_eq!( - fees.fee_1.block_engine_fee_bps(), - original.block_engine_fee_bps() - ); - assert_eq!(fees.fee_1.wallet, original.wallet); - assert_eq!(fees.fee_1.activation_epoch(), 11); + pub fn adjusted_precise_ncn_fee_bps( + &self, + ncn_fee_group: NcnFeeGroup, + ) -> Result { + let fee = self.ncn_fee_bps(ncn_fee_group)?; + + self.adjusted_precise_fee_bps(fee) } - #[test] - fn test_update_fees_errors() { - let mut fees = Fees::new(Pubkey::new_unique(), 100, 200, 300, 5); + // ------ Setters ----------------- + fn set_activation_epoch(&mut self, value: u64) { + self.activation_epoch = PodU64::from(value); + } - assert_eq!( - fees.set_new_fees(Some(10001), None, None, None, 10), - Err(TipRouterError::FeeCapExceeded) - ); + fn set_block_engine_fee_bps(&mut self, value: u64) { + self.block_engine_fee_bps = PodU64::from(value); + } - let mut fees = Fees::new(Pubkey::new_unique(), 100, 200, 300, 5); + fn set_dao_fee_bps(&mut self, value: u64) { + self.dao_fee_bps = PodU64::from(value); + } - assert_eq!( - fees.set_new_fees(None, None, None, None, u64::MAX), - Err(TipRouterError::ArithmeticOverflow) - ); + pub fn set_ncn_fee_bps( + &mut self, + ncn_fee_group: NcnFeeGroup, + value: u64, + ) -> Result<(), TipRouterError> { + let group_index = ncn_fee_group.group_index()?; - let mut fees = Fees::new(Pubkey::new_unique(), 100, 200, 300, 5); + self.ncn_fee_groups_bps[group_index] = NcnFee::new(value); - assert_eq!( - fees.set_new_fees(None, None, Some(MAX_FEE_BPS), None, 10), - Err(TipRouterError::FeeCapExceeded) - ); + Ok(()) } - #[test] - fn test_check_fees_okay() { - let fees = Fees::new(Pubkey::new_unique(), 0, 0, 0, 5); + // ------ Helpers ----------------- + fn adjusted_fee_bps(&self, fee: u64) -> Result { + let remaining_bps = MAX_FEE_BPS + .checked_sub(self.block_engine_fee_bps()) + .ok_or(TipRouterError::ArithmeticOverflow)?; + fee.checked_mul(MAX_FEE_BPS) + .and_then(|x| x.checked_div(remaining_bps)) + .ok_or(TipRouterError::DenominatorIsZero) + } - fees.check_fees_okay(5).unwrap(); + fn adjusted_precise_fee_bps(&self, fee: u64) -> Result { + let remaining_bps = MAX_FEE_BPS + .checked_sub(self.block_engine_fee_bps()) + .ok_or(TipRouterError::ArithmeticOverflow)?; - let fees = Fees::new(Pubkey::new_unique(), 0, 0, MAX_FEE_BPS, 5); + let precise_remaining_bps = PreciseNumber::new(remaining_bps as u128) + .ok_or(TipRouterError::NewPreciseNumberError)?; - assert_eq!( - fees.check_fees_okay(5), - Err(TipRouterError::DenominatorIsZero) - ); - } + let adjusted_fee = fee + .checked_mul(MAX_FEE_BPS) + .ok_or(TipRouterError::ArithmeticOverflow)?; - #[test] - fn test_current_fee() { - let mut fees = Fees::new(Pubkey::new_unique(), 100, 200, 300, 5); + let precise_adjusted_fee = PreciseNumber::new(adjusted_fee as u128) + .ok_or(TipRouterError::NewPreciseNumberError)?; - assert_eq!(fees.current_fee(5).activation_epoch(), 5); + precise_adjusted_fee + .checked_div(&precise_remaining_bps) + .ok_or(TipRouterError::DenominatorIsZero) + } +} - fees.fee_1.set_activation_epoch(10); +#[derive(Debug, Clone, Copy, Zeroable, ShankType, Pod)] +#[repr(C)] +pub struct NcnFee { + fee: PodU64, + reserved: [u8; 64], +} - assert_eq!(fees.current_fee(5).activation_epoch(), 5); - assert_eq!(fees.current_fee(10).activation_epoch(), 10); +impl Default for NcnFee { + fn default() -> Self { + Self { + fee: PodU64::from(0), + reserved: [0; 64], + } + } +} - fees.fee_2.set_activation_epoch(15); +impl NcnFee { + pub fn new(fee: u64) -> Self { + Self { + fee: PodU64::from(fee), + reserved: [0; 64], + } + } - assert_eq!(fees.current_fee(12).activation_epoch(), 10); - assert_eq!(fees.current_fee(15).activation_epoch(), 15); + pub fn fee(&self) -> u64 { + self.fee.into() } +} + +#[cfg(test)] +mod tests { + use solana_program::pubkey::Pubkey; + + use super::*; #[test] - fn test_get_updatable_fee_mut() { - let mut fees = Fees::new(Pubkey::new_unique(), 100, 200, 300, 5); + fn test_update_fees() { + const BLOCK_ENGINE_FEE: u64 = 100; + const DAO_FEE: u64 = 200; + const DEFAULT_NCN_FEE: u64 = 300; + const STARTING_EPOCH: u64 = 10; + + let dao_fee_wallet = Pubkey::new_unique(); + + let mut fee_config = FeeConfig::new( + dao_fee_wallet, + BLOCK_ENGINE_FEE, + DAO_FEE, + DEFAULT_NCN_FEE, + STARTING_EPOCH, + ) + .unwrap(); + + assert_eq!(fee_config.fee_wallet(), dao_fee_wallet); + + assert_eq!(fee_config.fee_1.activation_epoch(), STARTING_EPOCH); + assert_eq!(fee_config.fee_1.block_engine_fee_bps(), BLOCK_ENGINE_FEE); + assert_eq!(fee_config.fee_1.dao_fee_bps(), DAO_FEE); + assert_eq!( + fee_config + .fee_1 + .ncn_fee_bps(NcnFeeGroup::default()) + .unwrap(), + 0 + ); - let fee = fees.get_updatable_fee_mut(10); - fee.set_dao_share_bps(400); - fee.set_activation_epoch(11); + assert_eq!(fee_config.fee_2.activation_epoch(), STARTING_EPOCH); + assert_eq!(fee_config.fee_2.block_engine_fee_bps(), BLOCK_ENGINE_FEE); + assert_eq!(fee_config.fee_2.dao_fee_bps(), DAO_FEE); + assert_eq!( + fee_config + .fee_2 + .ncn_fee_bps(NcnFeeGroup::default()) + .unwrap(), + 0 + ); - assert_eq!(fees.fee_1.dao_share_bps(), 400); - assert_eq!(fees.fee_1.activation_epoch(), 11); + let new_fees = Fees::new(500, 600, 700, 10).unwrap(); + let new_wallet = Pubkey::new_unique(); - fees.fee_2.set_activation_epoch(13); + fee_config + .update_fee_config( + Some(new_wallet), + Some(new_fees.block_engine_fee_bps()), + Some(new_fees.dao_fee_bps()), + Some(new_fees.ncn_fee_bps(NcnFeeGroup::default()).unwrap()), + None, + STARTING_EPOCH, + ) + .unwrap(); - let fee = fees.get_updatable_fee_mut(12); - fee.set_dao_share_bps(500); - fee.set_activation_epoch(13); + assert_eq!(fee_config.fee_wallet(), new_wallet); - assert_eq!(fees.fee_2.dao_share_bps(), 500); - assert_eq!(fees.fee_2.activation_epoch(), 13); + assert_eq!(fee_config.fee_1.activation_epoch(), STARTING_EPOCH + 1); + assert_eq!(fee_config.fee_1.block_engine_fee_bps(), 500); + assert_eq!(fee_config.fee_1.dao_fee_bps(), 600); + assert_eq!( + fee_config + .fee_1 + .ncn_fee_bps(NcnFeeGroup::default()) + .unwrap(), + 700 + ); - assert_eq!(fees.get_updatable_fee_mut(u64::MAX).activation_epoch(), 11); + assert_eq!(fee_config.fee_2.activation_epoch(), STARTING_EPOCH); + assert_eq!(fee_config.fee_2.block_engine_fee_bps(), 500); // This will change regardless + assert_eq!(fee_config.fee_2.dao_fee_bps(), DAO_FEE); + assert_eq!( + fee_config + .fee_2 + .ncn_fee_bps(NcnFeeGroup::default()) + .unwrap(), + DEFAULT_NCN_FEE + ); } + + // #[test] + // fn test_update_fees() { + // let mut fees = Fees::new(Pubkey::new_unique(), 100, 200, 300, 5); + // let new_wallet = Pubkey::new_unique(); + + // fees.set_new_fees(Some(400), None, None, Some(new_wallet), 10) + // .unwrap(); + // assert_eq!(fees.fee_1.dao_share_bps(), 400); + // assert_eq!(fees.wallet, new_wallet); + // assert_eq!(fees.fee_1.activation_epoch(), 11); + // } + + // #[test] + // fn test_update_all_fees() { + // let mut fees = Fees::new(Pubkey::new_unique(), 0, 0, 0, 5); + + // fees.set_new_fees(Some(100), Some(200), Some(300), None, 10) + // .unwrap(); + // assert_eq!(fees.fee_1.dao_share_bps(), 100); + // assert_eq!(fees.fee_1.ncn_share_bps(), 200); + // assert_eq!(fees.block_engine_fee_bps(), 300); + // assert_eq!(fees.fee_1.activation_epoch(), 11); + // } + + // #[test] + // fn test_update_fees_no_changes() { + // const DAO_SHARE_FEE_BPS: u64 = 100; + // const NCN_SHARE_FEE_BPS: u64 = 100; + // const BLOCK_ENGINE_FEE: u64 = 100; + // const STARTING_EPOCH: u64 = 10; + + // let wallet = Pubkey::new_unique(); + + // let mut fees = Fees::new( + // wallet, + // DAO_SHARE_FEE_BPS, + // NCN_SHARE_FEE_BPS, + // BLOCK_ENGINE_FEE, + // STARTING_EPOCH, + // ); + + // fees.set_new_fees(None, None, None, None, STARTING_EPOCH) + // .unwrap(); + // assert_eq!(fees.fee_1.dao_share_bps(), DAO_SHARE_FEE_BPS); + // assert_eq!(fees.fee_1.ncn_share_bps(), NCN_SHARE_FEE_BPS); + // assert_eq!(fees.block_engine_fee_bps(), BLOCK_ENGINE_FEE); + // assert_eq!(fees.wallet, wallet); + // assert_eq!(fees.fee_1.activation_epoch(), STARTING_EPOCH + 1); + // } + + // #[test] + // fn test_update_fees_errors() { + // let mut fees = Fees::new(Pubkey::new_unique(), 100, 200, 300, 5); + + // assert_eq!( + // fees.set_new_fees(Some(10001), None, None, None, 10), + // Err(TipRouterError::FeeCapExceeded) + // ); + + // let mut fees = Fees::new(Pubkey::new_unique(), 100, 200, 300, 5); + + // assert_eq!( + // fees.set_new_fees(None, None, None, None, u64::MAX), + // Err(TipRouterError::ArithmeticOverflow) + // ); + + // let mut fees = Fees::new(Pubkey::new_unique(), 100, 200, 300, 5); + + // assert_eq!( + // fees.set_new_fees(None, None, Some(MAX_FEE_BPS), None, 10), + // Err(TipRouterError::FeeCapExceeded) + // ); + // } + + // #[test] + // fn test_check_fees_okay() { + // let fees = Fees::new(Pubkey::new_unique(), 0, 0, 0, 5); + + // fees.check_fees_okay(5).unwrap(); + + // let fees = Fees::new(Pubkey::new_unique(), 0, 0, MAX_FEE_BPS, 5); + + // assert_eq!( + // fees.check_fees_okay(5), + // Err(TipRouterError::DenominatorIsZero) + // ); + // } + + // #[test] + // fn test_current_fee() { + // let mut fees = Fees::new(Pubkey::new_unique(), 100, 200, 300, 5); + + // assert_eq!(fees.current_fee(5).activation_epoch(), 5); + + // fees.fee_1.set_activation_epoch(10); + + // assert_eq!(fees.current_fee(5).activation_epoch(), 5); + // assert_eq!(fees.current_fee(10).activation_epoch(), 10); + + // fees.fee_2.set_activation_epoch(15); + + // assert_eq!(fees.current_fee(12).activation_epoch(), 10); + // assert_eq!(fees.current_fee(15).activation_epoch(), 15); + // } + + // #[test] + // fn test_get_updatable_fee_mut() { + // let mut fees = Fees::new(Pubkey::new_unique(), 100, 200, 300, 5); + + // let fee = fees.get_updatable_fee_mut(10); + // fee.set_dao_share_bps(400); + // fee.set_activation_epoch(11); + + // assert_eq!(fees.fee_1.dao_share_bps(), 400); + // assert_eq!(fees.fee_1.activation_epoch(), 11); + + // fees.fee_2.set_activation_epoch(13); + + // let fee = fees.get_updatable_fee_mut(12); + // fee.set_dao_share_bps(500); + // fee.set_activation_epoch(13); + + // assert_eq!(fees.fee_2.dao_share_bps(), 500); + // assert_eq!(fees.fee_2.activation_epoch(), 13); + + // assert_eq!(fees.get_updatable_fee_mut(u64::MAX).activation_epoch(), 11); + // } } diff --git a/core/src/lib.rs b/core/src/lib.rs index ac78271..a47124c 100644 --- a/core/src/lib.rs +++ b/core/src/lib.rs @@ -7,6 +7,8 @@ pub mod fees; pub mod instruction; pub mod loaders; pub mod ncn_config; +pub mod ncn_fee_group; +pub mod stake_weight; pub mod tracked_mints; pub mod weight_entry; pub mod weight_table; diff --git a/core/src/ncn_config.rs b/core/src/ncn_config.rs index 2ca4808..22a0da2 100644 --- a/core/src/ncn_config.rs +++ b/core/src/ncn_config.rs @@ -3,7 +3,7 @@ use jito_bytemuck::{AccountDeserialize, Discriminator}; use shank::{ShankAccount, ShankType}; use solana_program::{account_info::AccountInfo, msg, program_error::ProgramError, pubkey::Pubkey}; -use crate::{discriminators::Discriminators, fees::Fees}; +use crate::{discriminators::Discriminators, fees::FeeConfig}; #[derive(Debug, Clone, Copy, Zeroable, ShankType, Pod, AccountDeserialize, ShankAccount)] #[repr(C)] @@ -15,7 +15,7 @@ pub struct NcnConfig { pub fee_admin: Pubkey, - pub fees: Fees, + pub fee_config: FeeConfig, /// Bump seed for the PDA pub bump: u8, @@ -32,13 +32,13 @@ impl NcnConfig { ncn: Pubkey, tie_breaker_admin: Pubkey, fee_admin: Pubkey, - fees: Fees, + fee_config: &FeeConfig, ) -> Self { Self { ncn, tie_breaker_admin, fee_admin, - fees, + fee_config: *fee_config, bump: 0, reserved: [0; 127], } diff --git a/core/src/ncn_fee_group.rs b/core/src/ncn_fee_group.rs new file mode 100644 index 0000000..983f7d1 --- /dev/null +++ b/core/src/ncn_fee_group.rs @@ -0,0 +1,135 @@ +use bytemuck::{Pod, Zeroable}; +use shank::ShankType; + +use crate::error::TipRouterError; + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +#[repr(u8)] +pub enum NcnFeeGroupType { + Default = 0x0, + JTO = 0x1, + Reserved2 = 0x2, + Reserved3 = 0x3, + Reserved4 = 0x4, + Reserved5 = 0x5, + Reserved6 = 0x6, + Reserved7 = 0x7, + Reserved8 = 0x8, + Reserved9 = 0x9, + ReservedA = 0xA, + ReservedB = 0xB, + ReservedC = 0xC, + ReservedD = 0xD, + ReservedE = 0xE, + ReservedF = 0xF, +} + +#[derive(Debug, Clone, Copy, Zeroable, ShankType, Pod)] +#[repr(C)] +pub struct NcnFeeGroup { + pub group: u8, +} + +impl Default for NcnFeeGroup { + fn default() -> Self { + Self { + group: NcnFeeGroupType::Default as u8, + } + } +} + +impl NcnFeeGroup { + pub const FEE_GROUP_COUNT: usize = 16; + + pub const fn new(group: NcnFeeGroupType) -> Self { + // So compiler will yell at us if we miss a group + match group { + NcnFeeGroupType::Default => Self { group: group as u8 }, + NcnFeeGroupType::JTO => Self { group: group as u8 }, + NcnFeeGroupType::Reserved2 => Self { group: group as u8 }, + NcnFeeGroupType::Reserved3 => Self { group: group as u8 }, + NcnFeeGroupType::Reserved4 => Self { group: group as u8 }, + NcnFeeGroupType::Reserved5 => Self { group: group as u8 }, + NcnFeeGroupType::Reserved6 => Self { group: group as u8 }, + NcnFeeGroupType::Reserved7 => Self { group: group as u8 }, + NcnFeeGroupType::Reserved8 => Self { group: group as u8 }, + NcnFeeGroupType::Reserved9 => Self { group: group as u8 }, + NcnFeeGroupType::ReservedA => Self { group: group as u8 }, + NcnFeeGroupType::ReservedB => Self { group: group as u8 }, + NcnFeeGroupType::ReservedC => Self { group: group as u8 }, + NcnFeeGroupType::ReservedD => Self { group: group as u8 }, + NcnFeeGroupType::ReservedE => Self { group: group as u8 }, + NcnFeeGroupType::ReservedF => Self { group: group as u8 }, + } + } + + pub const fn from_u8(group: u8) -> Result { + match group { + 0x0 => Ok(Self::new(NcnFeeGroupType::Default)), + 0x1 => Ok(Self::new(NcnFeeGroupType::JTO)), + 0x2 => Ok(Self::new(NcnFeeGroupType::Reserved2)), + 0x3 => Ok(Self::new(NcnFeeGroupType::Reserved3)), + 0x4 => Ok(Self::new(NcnFeeGroupType::Reserved4)), + 0x5 => Ok(Self::new(NcnFeeGroupType::Reserved5)), + 0x6 => Ok(Self::new(NcnFeeGroupType::Reserved6)), + 0x7 => Ok(Self::new(NcnFeeGroupType::Reserved7)), + 0x8 => Ok(Self::new(NcnFeeGroupType::Reserved8)), + 0x9 => Ok(Self::new(NcnFeeGroupType::Reserved9)), + 0xA => Ok(Self::new(NcnFeeGroupType::ReservedA)), + 0xB => Ok(Self::new(NcnFeeGroupType::ReservedB)), + 0xC => Ok(Self::new(NcnFeeGroupType::ReservedC)), + 0xD => Ok(Self::new(NcnFeeGroupType::ReservedD)), + 0xE => Ok(Self::new(NcnFeeGroupType::ReservedE)), + 0xF => Ok(Self::new(NcnFeeGroupType::ReservedF)), + _ => Err(TipRouterError::InvalidNcnFeeGroup), + } + } + + pub const fn group_type(&self) -> Result { + match self.group { + 0x0 => Ok(NcnFeeGroupType::Default), + 0x1 => Ok(NcnFeeGroupType::JTO), + 0x2 => Ok(NcnFeeGroupType::Reserved2), + 0x3 => Ok(NcnFeeGroupType::Reserved3), + 0x4 => Ok(NcnFeeGroupType::Reserved4), + 0x5 => Ok(NcnFeeGroupType::Reserved5), + 0x6 => Ok(NcnFeeGroupType::Reserved6), + 0x7 => Ok(NcnFeeGroupType::Reserved7), + 0x8 => Ok(NcnFeeGroupType::Reserved8), + 0x9 => Ok(NcnFeeGroupType::Reserved9), + 0xA => Ok(NcnFeeGroupType::ReservedA), + 0xB => Ok(NcnFeeGroupType::ReservedB), + 0xC => Ok(NcnFeeGroupType::ReservedC), + 0xD => Ok(NcnFeeGroupType::ReservedD), + 0xE => Ok(NcnFeeGroupType::ReservedE), + 0xF => Ok(NcnFeeGroupType::ReservedF), + _ => Err(TipRouterError::InvalidNcnFeeGroup), + } + } + + pub fn group_index(&self) -> Result { + let group = self.group_type()?; + Ok(group as usize) + } + + pub fn all_groups() -> Vec { + vec![ + Self::new(NcnFeeGroupType::Default), + Self::new(NcnFeeGroupType::JTO), + Self::new(NcnFeeGroupType::Reserved2), + Self::new(NcnFeeGroupType::Reserved3), + Self::new(NcnFeeGroupType::Reserved4), + Self::new(NcnFeeGroupType::Reserved5), + Self::new(NcnFeeGroupType::Reserved6), + Self::new(NcnFeeGroupType::Reserved7), + Self::new(NcnFeeGroupType::Reserved8), + Self::new(NcnFeeGroupType::Reserved9), + Self::new(NcnFeeGroupType::ReservedA), + Self::new(NcnFeeGroupType::ReservedB), + Self::new(NcnFeeGroupType::ReservedC), + Self::new(NcnFeeGroupType::ReservedD), + Self::new(NcnFeeGroupType::ReservedE), + Self::new(NcnFeeGroupType::ReservedF), + ] + } +} diff --git a/core/src/stake_weight.rs b/core/src/stake_weight.rs new file mode 100644 index 0000000..e9d98ec --- /dev/null +++ b/core/src/stake_weight.rs @@ -0,0 +1,101 @@ +use bytemuck::{Pod, Zeroable}; +use jito_bytemuck::types::{PodU128, PodU64}; +use shank::ShankType; + +use crate::{error::TipRouterError, ncn_fee_group::NcnFeeGroup}; + +#[derive(Debug, Clone, Copy, Zeroable, ShankType, Pod)] +#[repr(C)] +pub struct StakeWeight { + stake_weight: PodU128, + reward_stake_weights: [RewardStakeWeight; NcnFeeGroup::FEE_GROUP_COUNT], + // Reserves + reserved: [u8; 64], +} + +impl Default for StakeWeight { + fn default() -> Self { + Self { + stake_weight: PodU128::from(0), + reward_stake_weights: [RewardStakeWeight::default(); NcnFeeGroup::FEE_GROUP_COUNT], + reserved: [0; 64], + } + } +} + +impl StakeWeight { + pub fn stake_weight(&self) -> u128 { + self.stake_weight.into() + } + + pub fn reward_stake_weight(&self, ncn_fee_group: NcnFeeGroup) -> Result { + let group_index = ncn_fee_group.group_index()?; + + Ok(self.reward_stake_weights[group_index].reward_stake_weight()) + } + + pub fn increment(&mut self, stake_weight: &StakeWeight) -> Result<(), TipRouterError> { + self.increment_stake_weight(stake_weight.stake_weight())?; + + for group in NcnFeeGroup::all_groups().iter() { + self.increment_reward_stake_weight(*group, stake_weight.reward_stake_weight(*group)?)?; + } + + Ok(()) + } + + pub fn increment_stake_weight(&mut self, stake_weight: u128) -> Result<(), TipRouterError> { + self.stake_weight = PodU128::from( + self.stake_weight() + .checked_add(stake_weight) + .ok_or(TipRouterError::ArithmeticOverflow)?, + ); + + Ok(()) + } + + pub fn increment_reward_stake_weight( + &mut self, + ncn_fee_group: NcnFeeGroup, + stake_weight: u64, + ) -> Result<(), TipRouterError> { + let group_index = ncn_fee_group.group_index()?; + + self.reward_stake_weights[group_index].reward_stake_weight = PodU64::from( + self.reward_stake_weight(ncn_fee_group)? + .checked_add(stake_weight) + .ok_or(TipRouterError::ArithmeticOverflow)?, + ); + + Ok(()) + } +} + +#[derive(Debug, Clone, Copy, Zeroable, ShankType, Pod)] +#[repr(C)] +pub struct RewardStakeWeight { + reward_stake_weight: PodU64, + reserved: [u8; 64], +} + +impl Default for RewardStakeWeight { + fn default() -> Self { + Self { + reward_stake_weight: PodU64::from(0), + reserved: [0; 64], + } + } +} + +impl RewardStakeWeight { + pub fn new(reward_stake_weight: u64) -> Self { + Self { + reward_stake_weight: PodU64::from(reward_stake_weight), + reserved: [0; 64], + } + } + + pub fn reward_stake_weight(&self) -> u64 { + self.reward_stake_weight.into() + } +} diff --git a/integration_tests/tests/tip_router/set_config_fees.rs b/integration_tests/tests/tip_router/set_config_fees.rs index b99312d..03923ec 100644 --- a/integration_tests/tests/tip_router/set_config_fees.rs +++ b/integration_tests/tests/tip_router/set_config_fees.rs @@ -102,11 +102,11 @@ mod tests { .get_ncn_config(ncn_root.ncn_pubkey) .await?; let clock = fixture.clock().await; - assert_eq!(config.fees.dao_fee(clock.epoch as u64).unwrap(), 100); - assert_eq!(config.fees.ncn_fee(clock.epoch as u64).unwrap(), 200); - assert_eq!(config.fees.block_engine_fee(clock.epoch as u64), 0); + assert_eq!(config.fee_config.dao_fee(clock.epoch as u64).unwrap(), 100); + assert_eq!(config.fee_config.ncn_fee(clock.epoch as u64).unwrap(), 200); + assert_eq!(config.fee_config.block_engine_fee(clock.epoch as u64), 0); assert_eq!( - config.fees.fee_wallet(clock.epoch as u64), + config.fee_config.fee_wallet(clock.epoch as u64), new_fee_wallet.pubkey() ); diff --git a/program/src/initialize_epoch_snapshot.rs b/program/src/initialize_epoch_snapshot.rs index 1814bb4..df0b939 100644 --- a/program/src/initialize_epoch_snapshot.rs +++ b/program/src/initialize_epoch_snapshot.rs @@ -86,7 +86,7 @@ pub fn process_initialize_epoch_snapshot( let ncn_fees: fees::Fees = { let ncn_config_data = ncn_config.data.borrow(); let ncn_config_account = NcnConfig::try_from_slice_unchecked(&ncn_config_data)?; - ncn_config_account.fees + ncn_config_account.fee_config }; let operator_count: u64 = { diff --git a/program/src/initialize_ncn_config.rs b/program/src/initialize_ncn_config.rs index 83b450e..cf6b6ca 100644 --- a/program/src/initialize_ncn_config.rs +++ b/program/src/initialize_ncn_config.rs @@ -5,7 +5,7 @@ use jito_jsm_core::{ }; use jito_restaking_core::{config::Config, ncn::Ncn}; use jito_tip_router_core::{ - constants::MAX_FEE_BPS, error::TipRouterError, fees::Fees, ncn_config::NcnConfig, + constants::MAX_FEE_BPS, error::TipRouterError, fees::FeeConfig, ncn_config::NcnConfig, }; use solana_program::{ account_info::AccountInfo, clock::Clock, entrypoint::ProgramResult, @@ -84,21 +84,24 @@ pub fn process_initialize_ncn_config( let mut config_data = ncn_config.try_borrow_mut_data()?; config_data[0] = NcnConfig::DISCRIMINATOR; let config = NcnConfig::try_from_slice_unchecked_mut(&mut config_data)?; + + let fee_config = FeeConfig::new( + *fee_wallet.key, + block_engine_fee_bps, + dao_fee_bps, + ncn_fee_bps, + epoch, + )?; + *config = NcnConfig::new( *ncn_account.key, *tie_breaker_admin.key, *ncn_admin.key, - Fees::new( - *fee_wallet.key, - dao_fee_bps, - ncn_fee_bps, - block_engine_fee_bps, - epoch, - ), + &fee_config, ); config.bump = config_bump; - config.fees.check_fees_okay(epoch)?; + config.fee_config.check_fees_okay(epoch)?; Ok(()) } diff --git a/program/src/set_config_fees.rs b/program/src/set_config_fees.rs index 7df9a95..0cff379 100644 --- a/program/src/set_config_fees.rs +++ b/program/src/set_config_fees.rs @@ -10,10 +10,11 @@ use solana_program::{ pub fn process_set_config_fees( program_id: &Pubkey, accounts: &[AccountInfo], + new_fee_wallet: Option, + new_block_engine_fee_bps: Option, new_dao_fee_bps: Option, new_ncn_fee_bps: Option, - new_block_engine_fee_bps: Option, - new_fee_wallet: Option, + new_ncn_fee_group: Option, ) -> ProgramResult { let [restaking_config, config, ncn_account, fee_admin, restaking_program] = accounts else { return Err(ProgramError::NotEnoughAccountKeys); @@ -53,11 +54,18 @@ pub fn process_set_config_fees( return Err(TipRouterError::IncorrectFeeAdmin.into()); } - config.fees.set_new_fees( + let new_ncn_fee_group = let Some(new_ncn_fee_group) = new_ncn_fee_group { + new_ncn_fee_group + } else { + config.fee_config.ncn_fee_group() + }; + + config.fee_config.update_fee_config( new_dao_fee_bps, new_ncn_fee_bps, new_block_engine_fee_bps, new_fee_wallet, + new_ncn_fee_group, epoch, )?;