diff --git a/core/src/fees.rs b/core/src/fees.rs index 30424d41..3f820e17 100644 --- a/core/src/fees.rs +++ b/core/src/fees.rs @@ -87,24 +87,20 @@ impl Fees { self.current_fee(current_epoch).wallet } - fn get_updatable_fee_mut(&mut self, current_epoch: u64) -> Result<&mut Fee, TipRouterError> { - let next_epoch = current_epoch - .checked_add(1) - .ok_or(TipRouterError::ArithmeticOverflow)?; - + fn get_updatable_fee_mut(&mut self, current_epoch: u64) -> &mut Fee { // If either fee is scheduled for next epoch, return that one - if self.fee_1.activation_epoch == next_epoch { - return Ok(&mut self.fee_1); + if self.fee_1.activation_epoch > current_epoch { + return &mut self.fee_1; } - if self.fee_2.activation_epoch == next_epoch { - return Ok(&mut self.fee_2); + if self.fee_2.activation_epoch > current_epoch { + return &mut self.fee_2; } // Otherwise return the one with lower activation epoch if self.fee_1.activation_epoch <= self.fee_2.activation_epoch { - Ok(&mut self.fee_1) + &mut self.fee_1 } else { - Ok(&mut self.fee_2) + &mut self.fee_2 } } @@ -117,7 +113,7 @@ impl Fees { current_epoch: u64, ) -> Result<(), TipRouterError> { let current_fees = *self.current_fee(current_epoch); - let new_fees = self.get_updatable_fee_mut(current_epoch)?; + let new_fees = self.get_updatable_fee_mut(current_epoch); *new_fees = current_fees; if let Some(new_dao_fee_bps) = new_dao_fee_bps { @@ -176,4 +172,95 @@ impl Fee { } } -// TODO Some tests for fees +#[cfg(test)] +mod tests { + use solana_program::pubkey::Pubkey; + + use super::*; + + #[test] + fn test_update_fees() { + let mut fees = Fees::new(Pubkey::new_unique(), 100, 200, 300, 5); + let new_wallet = Pubkey::new_unique(); + + fees.set_new_fees(Some(400), None, None, Some(new_wallet), 10) + .unwrap(); + assert_eq!(fees.fee_1.dao_share_bps, 400); + assert_eq!(fees.fee_1.wallet, new_wallet); + assert_eq!(fees.fee_1.activation_epoch, 11); + } + + #[test] + fn test_update_fees_no_changes() { + let original = Fee::new(Pubkey::new_unique(), 100, 200, 300, 5); + let mut fees = Fees::new(Pubkey::new_unique(), 100, 200, 300, 5); + fees.fee_1 = original.clone(); + + fees.set_new_fees(None, None, None, None, 10).unwrap(); + assert_eq!(fees.fee_1.dao_share_bps, original.dao_share_bps); + assert_eq!(fees.fee_1.ncn_share_bps, original.ncn_share_bps); + assert_eq!( + fees.fee_1.block_engine_fee_bps, + original.block_engine_fee_bps + ); + assert_eq!(fees.fee_1.wallet, original.wallet); + assert_eq!(fees.fee_1.activation_epoch, 11); + } + + #[test] + fn test_update_fees_errors() { + let mut fees = Fees::new(Pubkey::new_unique(), 100, 200, 300, 5); + + assert!(matches!( + fees.set_new_fees(Some(10001), None, None, None, 10), + Err(TipRouterError::FeeCapExceeded) + )); + + let mut fees = Fees::new(Pubkey::new_unique(), 100, 200, 300, 5); + + assert!(matches!( + fees.set_new_fees(None, None, None, None, u64::MAX), + Err(TipRouterError::ArithmeticOverflow) + )); + } + + #[test] + fn test_current_fee() { + let mut fees = Fees::new(Pubkey::new_unique(), 100, 200, 300, 5); + + assert_eq!(fees.current_fee(5).activation_epoch, 5); + + fees.fee_1.activation_epoch = 10; + + assert_eq!(fees.current_fee(5).activation_epoch, 5); + assert_eq!(fees.current_fee(10).activation_epoch, 10); + + fees.fee_2.activation_epoch = 15; + + assert_eq!(fees.current_fee(12).activation_epoch, 10); + assert_eq!(fees.current_fee(15).activation_epoch, 15); + } + + #[test] + fn test_get_updatable_fee_mut() { + let mut fees = Fees::new(Pubkey::new_unique(), 100, 200, 300, 5); + + let fee = fees.get_updatable_fee_mut(10); + fee.dao_share_bps = 400; + fee.activation_epoch = 11; + + assert_eq!(fees.fee_1.dao_share_bps, 400); + assert_eq!(fees.fee_1.activation_epoch, 11); + + fees.fee_2.activation_epoch = 13; + + let fee = fees.get_updatable_fee_mut(12); + fee.dao_share_bps = 500; + fee.activation_epoch = 13; + + assert_eq!(fees.fee_2.dao_share_bps, 500); + assert_eq!(fees.fee_2.activation_epoch, 13); + + assert_eq!(fees.get_updatable_fee_mut(u64::MAX).activation_epoch, 11); + } +} diff --git a/core/src/instruction.rs b/core/src/instruction.rs index 3411ff17..3ffa68dc 100644 --- a/core/src/instruction.rs +++ b/core/src/instruction.rs @@ -16,7 +16,7 @@ pub enum WeightTableInstruction { /// Initialize the global configuration for this NCN #[account(0, writable, name = "config")] #[account(1, name = "ncn")] - #[account(2, writable, name = "ncn_admin")] + #[account(2, signer, name = "ncn_admin")] #[account(3, name = "fee_wallet")] #[account(4, name = "tie_breaker_admin")] #[account(5, name = "restaking_program_id")] diff --git a/core/src/ncn_config.rs b/core/src/ncn_config.rs index c77cc75d..d68a0fa9 100644 --- a/core/src/ncn_config.rs +++ b/core/src/ncn_config.rs @@ -56,8 +56,8 @@ impl NcnConfig { } } - pub fn seeds() -> Vec> { - vec![b"config".to_vec()] + pub fn seeds(ncn: &Pubkey) -> Vec> { + vec![b"config".to_vec(), ncn.to_bytes().to_vec()] } pub fn find_program_address(program_id: &Pubkey, ncn: &Pubkey) -> (Pubkey, u8, Vec>) {