From e32e1266a2a7dc9c49cde71f84d16a7d57df9e5a Mon Sep 17 00:00:00 2001 From: Alessandro Decina Date: Fri, 2 Aug 2024 08:39:18 +0700 Subject: [PATCH] Introduce VoteState::deserialize_into_uninit (#2272) * VoteState::deserialize_into: take &mut MaybeUninit Deserializing into MaybeUninit saves the extra cost of initializing into a value initialized with VoteState::default() --- sdk/program/src/serialize_utils/cursor.rs | 31 +++- sdk/program/src/vote/state/mod.rs | 165 ++++++++++++++++-- .../src/vote/state/vote_state_0_23_5.rs | 10 +- .../src/vote/state/vote_state_1_14_11.rs | 10 +- .../src/vote/state/vote_state_deserialize.rs | 128 +++++++++----- 5 files changed, 272 insertions(+), 72 deletions(-) diff --git a/sdk/program/src/serialize_utils/cursor.rs b/sdk/program/src/serialize_utils/cursor.rs index 6e78d88ef9a73a..3d4dedd092ed3a 100644 --- a/sdk/program/src/serialize_utils/cursor.rs +++ b/sdk/program/src/serialize_utils/cursor.rs @@ -1,6 +1,12 @@ use { - crate::{instruction::InstructionError, pubkey::Pubkey}, - std::io::{Cursor, Read}, + crate::{ + instruction::InstructionError, + pubkey::{Pubkey, PUBKEY_BYTES}, + }, + std::{ + io::{BufRead as _, Cursor, Read}, + ptr, + }, }; pub(crate) fn read_u8>(cursor: &mut Cursor) -> Result { @@ -50,6 +56,27 @@ pub(crate) fn read_i64>(cursor: &mut Cursor) -> Result, + pubkey: *mut Pubkey, +) -> Result<(), InstructionError> { + match cursor.fill_buf() { + Ok(buf) if buf.len() >= PUBKEY_BYTES => { + // Safety: `buf` is guaranteed to be at least `PUBKEY_BYTES` bytes + // long. Pubkey a #[repr(transparent)] wrapper around a byte array, + // so this is a byte to byte copy and it's safe. + unsafe { + ptr::copy_nonoverlapping(buf.as_ptr(), pubkey as *mut u8, PUBKEY_BYTES); + } + + cursor.consume(PUBKEY_BYTES); + } + _ => return Err(InstructionError::InvalidAccountData), + } + + Ok(()) +} + pub(crate) fn read_pubkey>( cursor: &mut Cursor, ) -> Result { diff --git a/sdk/program/src/vote/state/mod.rs b/sdk/program/src/vote/state/mod.rs index a40a7ba3476a26..abac8f5abff61f 100644 --- a/sdk/program/src/vote/state/mod.rs +++ b/sdk/program/src/vote/state/mod.rs @@ -20,7 +20,12 @@ use { }, bincode::{serialize_into, ErrorKind}, serde_derive::{Deserialize, Serialize}, - std::{collections::VecDeque, fmt::Debug, io::Cursor}, + std::{ + collections::VecDeque, + fmt::Debug, + io::Cursor, + mem::{self, MaybeUninit}, + }, }; mod vote_state_0_23_5; @@ -479,13 +484,81 @@ impl VoteState { } } - /// Deserializes the input `VoteStateVersions` buffer directly into a provided `VoteState` struct + /// Deserializes the input `VoteStateVersions` buffer directly into the provided `VoteState`. + /// + /// In a SBPF context, V0_23_5 is not supported, but in non-SBPF, all versions are supported for + /// compatibility with `bincode::deserialize`. /// - /// In a BPF context, V0_23_5 is not supported, but in non-BPF, all versions are supported for - /// compatibility with `bincode::deserialize` + /// On success, `vote_state` reflects the state of the input data. On failure, `vote_state` is + /// reset to `VoteState::default()`. pub fn deserialize_into( input: &[u8], vote_state: &mut VoteState, + ) -> Result<(), InstructionError> { + // Rebind vote_state to *mut VoteState so that the &mut binding isn't + // accessible anymore, preventing accidental use after this point. + // + // NOTE: switch to ptr::from_mut() once platform-tools moves to rustc >= 1.76 + let vote_state = vote_state as *mut VoteState; + + // Safety: vote_state is valid to_drop (see drop_in_place() docs). After + // dropping, the pointer is treated as uninitialized and only accessed + // through ptr::write, which is safe as per drop_in_place docs. + unsafe { + std::ptr::drop_in_place(vote_state); + } + + // This is to reset vote_state to VoteState::default() if deserialize fails or panics. + struct DropGuard { + vote_state: *mut VoteState, + } + + impl Drop for DropGuard { + fn drop(&mut self) { + // Safety: + // + // Deserialize failed or panicked so at this point vote_state is uninitialized. We + // must write a new _valid_ value into it or after returning (or unwinding) from + // this function the caller is left with an uninitialized `&mut VoteState`, which is + // UB (references must always be valid). + // + // This is always safe and doesn't leak memory because deserialize_into_ptr() writes + // into the fields that heap alloc only when it returns Ok(). + unsafe { + self.vote_state.write(VoteState::default()); + } + } + } + + let guard = DropGuard { vote_state }; + + let res = VoteState::deserialize_into_ptr(input, vote_state); + if res.is_ok() { + mem::forget(guard); + } + + res + } + + /// Deserializes the input `VoteStateVersions` buffer directly into the provided + /// `MaybeUninit`. + /// + /// In a SBPF context, V0_23_5 is not supported, but in non-SBPF, all versions are supported for + /// compatibility with `bincode::deserialize`. + /// + /// On success, `vote_state` is fully initialized and can be converted to `VoteState` using + /// [MaybeUninit::assume_init]. On failure, `vote_state` may still be uninitialized and must not + /// be converted to `VoteState`. + pub fn deserialize_into_uninit( + input: &[u8], + vote_state: &mut MaybeUninit, + ) -> Result<(), InstructionError> { + VoteState::deserialize_into_ptr(input, vote_state.as_mut_ptr()) + } + + fn deserialize_into_ptr( + input: &[u8], + vote_state: *mut VoteState, ) -> Result<(), InstructionError> { let mut cursor = Cursor::new(input); @@ -496,10 +569,18 @@ impl VoteState { 0 => { #[cfg(not(target_os = "solana"))] { - *vote_state = bincode::deserialize::(input) - .map(|versioned| versioned.convert_to_current()) - .map_err(|_| InstructionError::InvalidAccountData)?; - + // Safety: vote_state is valid as it comes from `&mut MaybeUninit` or + // `&mut VoteState`. In the first case, the value is uninitialized so we write() + // to avoid dropping invalid data; in the latter case, we `drop_in_place()` + // before writing so the value has already been dropped and we just write a new + // one in place. + unsafe { + vote_state.write( + bincode::deserialize::(input) + .map(|versioned| versioned.convert_to_current()) + .map_err(|_| InstructionError::InvalidAccountData)?, + ); + } Ok(()) } #[cfg(target_os = "solana")] @@ -1129,10 +1210,56 @@ mod tests { } #[test] - fn test_vote_deserialize_into_nopanic() { - // base case + fn test_vote_deserialize_into_error() { + let target_vote_state = VoteState::new_rand_for_tests(Pubkey::new_unique(), 42); + let mut vote_state_buf = + bincode::serialize(&VoteStateVersions::new_current(target_vote_state.clone())).unwrap(); + let len = vote_state_buf.len(); + vote_state_buf.truncate(len - 1); + let mut test_vote_state = VoteState::default(); - let e = VoteState::deserialize_into(&[], &mut test_vote_state).unwrap_err(); + VoteState::deserialize_into(&vote_state_buf, &mut test_vote_state).unwrap_err(); + assert_eq!(test_vote_state, VoteState::default()); + } + + #[test] + fn test_vote_deserialize_into_uninit() { + // base case + let target_vote_state = VoteState::default(); + let vote_state_buf = + bincode::serialize(&VoteStateVersions::new_current(target_vote_state.clone())).unwrap(); + + let mut test_vote_state = MaybeUninit::uninit(); + VoteState::deserialize_into_uninit(&vote_state_buf, &mut test_vote_state).unwrap(); + let test_vote_state = unsafe { test_vote_state.assume_init() }; + + assert_eq!(target_vote_state, test_vote_state); + + // variant + // provide 4x the minimum struct size in bytes to ensure we typically touch every field + let struct_bytes_x4 = std::mem::size_of::() * 4; + for _ in 0..1000 { + let raw_data: Vec = (0..struct_bytes_x4).map(|_| rand::random::()).collect(); + let mut unstructured = Unstructured::new(&raw_data); + + let target_vote_state_versions = + VoteStateVersions::arbitrary(&mut unstructured).unwrap(); + let vote_state_buf = bincode::serialize(&target_vote_state_versions).unwrap(); + let target_vote_state = target_vote_state_versions.convert_to_current(); + + let mut test_vote_state = MaybeUninit::uninit(); + VoteState::deserialize_into_uninit(&vote_state_buf, &mut test_vote_state).unwrap(); + let test_vote_state = unsafe { test_vote_state.assume_init() }; + + assert_eq!(target_vote_state, test_vote_state); + } + } + + #[test] + fn test_vote_deserialize_into_uninit_nopanic() { + // base case + let mut test_vote_state = MaybeUninit::uninit(); + let e = VoteState::deserialize_into_uninit(&[], &mut test_vote_state).unwrap_err(); assert_eq!(e, InstructionError::InvalidAccountData); // variant @@ -1153,21 +1280,22 @@ mod tests { // it is extremely improbable, though theoretically possible, for random bytes to be syntactically valid // so we only check that the parser does not panic and that it succeeds or fails exactly in line with bincode - let mut test_vote_state = VoteState::default(); - let test_res = VoteState::deserialize_into(&raw_data, &mut test_vote_state); + let mut test_vote_state = MaybeUninit::uninit(); + let test_res = VoteState::deserialize_into_uninit(&raw_data, &mut test_vote_state); let bincode_res = bincode::deserialize::(&raw_data) .map(|versioned| versioned.convert_to_current()); if test_res.is_err() { assert!(bincode_res.is_err()); } else { + let test_vote_state = unsafe { test_vote_state.assume_init() }; assert_eq!(test_vote_state, bincode_res.unwrap()); } } } #[test] - fn test_vote_deserialize_into_ill_sized() { + fn test_vote_deserialize_into_uninit_ill_sized() { // provide 4x the minimum struct size in bytes to ensure we typically touch every field let struct_bytes_x4 = std::mem::size_of::() * 4; for _ in 0..1000 { @@ -1185,8 +1313,8 @@ mod tests { expanded_buf.resize(original_buf.len() + 8, 0); // truncated fails - let mut test_vote_state = VoteState::default(); - let test_res = VoteState::deserialize_into(&truncated_buf, &mut test_vote_state); + let mut test_vote_state = MaybeUninit::uninit(); + let test_res = VoteState::deserialize_into_uninit(&truncated_buf, &mut test_vote_state); let bincode_res = bincode::deserialize::(&truncated_buf) .map(|versioned| versioned.convert_to_current()); @@ -1194,11 +1322,12 @@ mod tests { assert!(bincode_res.is_err()); // expanded succeeds - let mut test_vote_state = VoteState::default(); - VoteState::deserialize_into(&expanded_buf, &mut test_vote_state).unwrap(); + let mut test_vote_state = MaybeUninit::uninit(); + VoteState::deserialize_into_uninit(&expanded_buf, &mut test_vote_state).unwrap(); let bincode_res = bincode::deserialize::(&expanded_buf) .map(|versioned| versioned.convert_to_current()); + let test_vote_state = unsafe { test_vote_state.assume_init() }; assert_eq!(test_vote_state, bincode_res.unwrap()); } } diff --git a/sdk/program/src/vote/state/vote_state_0_23_5.rs b/sdk/program/src/vote/state/vote_state_0_23_5.rs index eff88adca6dd75..efc3e89bcf0e37 100644 --- a/sdk/program/src/vote/state/vote_state_0_23_5.rs +++ b/sdk/program/src/vote/state/vote_state_0_23_5.rs @@ -75,8 +75,9 @@ mod tests { let target_vote_state_versions = VoteStateVersions::V0_23_5(Box::new(target_vote_state)); let vote_state_buf = bincode::serialize(&target_vote_state_versions).unwrap(); - let mut test_vote_state = VoteState::default(); - VoteState::deserialize_into(&vote_state_buf, &mut test_vote_state).unwrap(); + let mut test_vote_state = MaybeUninit::uninit(); + VoteState::deserialize_into_uninit(&vote_state_buf, &mut test_vote_state).unwrap(); + let test_vote_state = unsafe { test_vote_state.assume_init() }; assert_eq!( target_vote_state_versions.convert_to_current(), @@ -97,8 +98,9 @@ mod tests { let vote_state_buf = bincode::serialize(&target_vote_state_versions).unwrap(); let target_vote_state = target_vote_state_versions.convert_to_current(); - let mut test_vote_state = VoteState::default(); - VoteState::deserialize_into(&vote_state_buf, &mut test_vote_state).unwrap(); + let mut test_vote_state = MaybeUninit::uninit(); + VoteState::deserialize_into_uninit(&vote_state_buf, &mut test_vote_state).unwrap(); + let test_vote_state = unsafe { test_vote_state.assume_init() }; assert_eq!(target_vote_state, test_vote_state); } diff --git a/sdk/program/src/vote/state/vote_state_1_14_11.rs b/sdk/program/src/vote/state/vote_state_1_14_11.rs index 645e73dc353d3e..285272a3ab646f 100644 --- a/sdk/program/src/vote/state/vote_state_1_14_11.rs +++ b/sdk/program/src/vote/state/vote_state_1_14_11.rs @@ -94,8 +94,9 @@ mod tests { let target_vote_state_versions = VoteStateVersions::V1_14_11(Box::new(target_vote_state)); let vote_state_buf = bincode::serialize(&target_vote_state_versions).unwrap(); - let mut test_vote_state = VoteState::default(); - VoteState::deserialize_into(&vote_state_buf, &mut test_vote_state).unwrap(); + let mut test_vote_state = MaybeUninit::uninit(); + VoteState::deserialize_into_uninit(&vote_state_buf, &mut test_vote_state).unwrap(); + let test_vote_state = unsafe { test_vote_state.assume_init() }; assert_eq!( target_vote_state_versions.convert_to_current(), @@ -116,8 +117,9 @@ mod tests { let vote_state_buf = bincode::serialize(&target_vote_state_versions).unwrap(); let target_vote_state = target_vote_state_versions.convert_to_current(); - let mut test_vote_state = VoteState::default(); - VoteState::deserialize_into(&vote_state_buf, &mut test_vote_state).unwrap(); + let mut test_vote_state = MaybeUninit::uninit(); + VoteState::deserialize_into_uninit(&vote_state_buf, &mut test_vote_state).unwrap(); + let test_vote_state = unsafe { test_vote_state.assume_init() }; assert_eq!(target_vote_state, test_vote_state); } diff --git a/sdk/program/src/vote/state/vote_state_deserialize.rs b/sdk/program/src/vote/state/vote_state_deserialize.rs index 69fdf0636d9b57..268341513a72b8 100644 --- a/sdk/program/src/vote/state/vote_state_deserialize.rs +++ b/sdk/program/src/vote/state/vote_state_deserialize.rs @@ -1,36 +1,68 @@ use { + super::{MAX_EPOCH_CREDITS_HISTORY, MAX_LOCKOUT_HISTORY}, crate::{ + clock::Epoch, instruction::InstructionError, + pubkey::Pubkey, serialize_utils::cursor::*, - vote::state::{BlockTimestamp, LandedVote, Lockout, VoteState, MAX_ITEMS}, + vote::{ + authorized_voters::AuthorizedVoters, + state::{BlockTimestamp, LandedVote, Lockout, VoteState, MAX_ITEMS}, + }, }, - std::io::Cursor, + std::{collections::VecDeque, io::Cursor, ptr::addr_of_mut}, }; pub(super) fn deserialize_vote_state_into( cursor: &mut Cursor<&[u8]>, - vote_state: &mut VoteState, + vote_state: *mut VoteState, has_latency: bool, ) -> Result<(), InstructionError> { - vote_state.node_pubkey = read_pubkey(cursor)?; - vote_state.authorized_withdrawer = read_pubkey(cursor)?; - vote_state.commission = read_u8(cursor)?; - read_votes_into(cursor, vote_state, has_latency)?; - vote_state.root_slot = read_option_u64(cursor)?; - read_authorized_voters_into(cursor, vote_state)?; + // General safety note: we must use add_or_mut! to access the `vote_state` fields as the value + // is assumed to be _uninitialized_, so creating references to the state or any of its inner + // fields is UB. + + read_pubkey_into( + cursor, + // Safety: if vote_state is non-null, node_pubkey is guaranteed to be valid too + unsafe { addr_of_mut!((*vote_state).node_pubkey) }, + )?; + read_pubkey_into( + cursor, + // Safety: if vote_state is non-null, authorized_withdrawer is guaranteed to be valid too + unsafe { addr_of_mut!((*vote_state).authorized_withdrawer) }, + )?; + let commission = read_u8(cursor)?; + let votes = read_votes(cursor, has_latency)?; + let root_slot = read_option_u64(cursor)?; + let authorized_voters = read_authorized_voters(cursor)?; read_prior_voters_into(cursor, vote_state)?; - read_epoch_credits_into(cursor, vote_state)?; + let epoch_credits = read_epoch_credits(cursor)?; read_last_timestamp_into(cursor, vote_state)?; + // Safety: if vote_state is non-null, all the fields are guaranteed to be + // valid pointers. + // + // Heap allocated collections - votes, authorized_voters and epoch_credits - + // are guaranteed not to leak after this point as the VoteState is fully + // initialized and will be regularly dropped. + unsafe { + addr_of_mut!((*vote_state).commission).write(commission); + addr_of_mut!((*vote_state).votes).write(votes); + addr_of_mut!((*vote_state).root_slot).write(root_slot); + addr_of_mut!((*vote_state).authorized_voters).write(authorized_voters); + addr_of_mut!((*vote_state).epoch_credits).write(epoch_credits); + } + Ok(()) } -fn read_votes_into>( +fn read_votes>( cursor: &mut Cursor, - vote_state: &mut VoteState, has_latency: bool, -) -> Result<(), InstructionError> { - let vote_count = read_u64(cursor)?; +) -> Result, InstructionError> { + let vote_count = read_u64(cursor)? as usize; + let mut votes = VecDeque::with_capacity(vote_count.min(MAX_LOCKOUT_HISTORY)); for _ in 0..vote_count { let latency = if has_latency { read_u8(cursor)? } else { 0 }; @@ -39,73 +71,81 @@ fn read_votes_into>( let confirmation_count = read_u32(cursor)?; let lockout = Lockout::new_with_confirmation_count(slot, confirmation_count); - vote_state.votes.push_back(LandedVote { latency, lockout }); + votes.push_back(LandedVote { latency, lockout }); } - Ok(()) + Ok(votes) } -fn read_authorized_voters_into>( +fn read_authorized_voters>( cursor: &mut Cursor, - vote_state: &mut VoteState, -) -> Result<(), InstructionError> { +) -> Result { let authorized_voter_count = read_u64(cursor)?; + let mut authorized_voters = AuthorizedVoters::default(); for _ in 0..authorized_voter_count { let epoch = read_u64(cursor)?; let authorized_voter = read_pubkey(cursor)?; - - vote_state.authorized_voters.insert(epoch, authorized_voter); + authorized_voters.insert(epoch, authorized_voter); } - Ok(()) + Ok(authorized_voters) } fn read_prior_voters_into>( cursor: &mut Cursor, - vote_state: &mut VoteState, + vote_state: *mut VoteState, ) -> Result<(), InstructionError> { - for i in 0..MAX_ITEMS { - let prior_voter = read_pubkey(cursor)?; - let from_epoch = read_u64(cursor)?; - let until_epoch = read_u64(cursor)?; - - vote_state.prior_voters.buf[i] = (prior_voter, from_epoch, until_epoch); + // Safety: if vote_state is non-null, prior_voters is guaranteed to be valid too + unsafe { + let prior_voters = addr_of_mut!((*vote_state).prior_voters); + let prior_voters_buf = addr_of_mut!((*prior_voters).buf) as *mut (Pubkey, Epoch, Epoch); + + for i in 0..MAX_ITEMS { + let prior_voter = read_pubkey(cursor)?; + let from_epoch = read_u64(cursor)?; + let until_epoch = read_u64(cursor)?; + + prior_voters_buf + .add(i) + .write((prior_voter, from_epoch, until_epoch)); + } + + (*vote_state).prior_voters.idx = read_u64(cursor)? as usize; + (*vote_state).prior_voters.is_empty = read_bool(cursor)?; } - - vote_state.prior_voters.idx = read_u64(cursor)? as usize; - vote_state.prior_voters.is_empty = read_bool(cursor)?; - Ok(()) } -fn read_epoch_credits_into>( +fn read_epoch_credits>( cursor: &mut Cursor, - vote_state: &mut VoteState, -) -> Result<(), InstructionError> { - let epoch_credit_count = read_u64(cursor)?; +) -> Result, InstructionError> { + let epoch_credit_count = read_u64(cursor)? as usize; + let mut epoch_credits = Vec::with_capacity(epoch_credit_count.min(MAX_EPOCH_CREDITS_HISTORY)); for _ in 0..epoch_credit_count { let epoch = read_u64(cursor)?; let credits = read_u64(cursor)?; let prev_credits = read_u64(cursor)?; - - vote_state - .epoch_credits - .push((epoch, credits, prev_credits)); + epoch_credits.push((epoch, credits, prev_credits)); } - Ok(()) + Ok(epoch_credits) } fn read_last_timestamp_into>( cursor: &mut Cursor, - vote_state: &mut VoteState, + vote_state: *mut VoteState, ) -> Result<(), InstructionError> { let slot = read_u64(cursor)?; let timestamp = read_i64(cursor)?; - vote_state.last_timestamp = BlockTimestamp { slot, timestamp }; + let last_timestamp = BlockTimestamp { slot, timestamp }; + + // Safety: if vote_state is non-null, last_timestamp is guaranteed to be valid too + unsafe { + addr_of_mut!((*vote_state).last_timestamp).write(last_timestamp); + } Ok(()) }