Skip to content

Commit

Permalink
Introduce VoteState::deserialize_into_uninit (#2272)
Browse files Browse the repository at this point in the history
* VoteState::deserialize_into: take &mut MaybeUninit<VoteState>

Deserializing into MaybeUninit<VoteState> saves the extra cost of
initializing into a value initialized with VoteState::default()
  • Loading branch information
alessandrod authored Aug 2, 2024
1 parent c7192d3 commit e32e126
Show file tree
Hide file tree
Showing 5 changed files with 272 additions and 72 deletions.
31 changes: 29 additions & 2 deletions sdk/program/src/serialize_utils/cursor.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,12 @@
use {
crate::{instruction::InstructionError, pubkey::Pubkey},
std::io::{Cursor, Read},
crate::{
instruction::InstructionError,
pubkey::{Pubkey, PUBKEY_BYTES},
},
std::{
io::{BufRead as _, Cursor, Read},
ptr,
},
};

pub(crate) fn read_u8<T: AsRef<[u8]>>(cursor: &mut Cursor<T>) -> Result<u8, InstructionError> {
Expand Down Expand Up @@ -50,6 +56,27 @@ pub(crate) fn read_i64<T: AsRef<[u8]>>(cursor: &mut Cursor<T>) -> Result<i64, In
Ok(i64::from_le_bytes(buf))
}

pub(crate) fn read_pubkey_into(
cursor: &mut Cursor<&[u8]>,
pubkey: *mut Pubkey,
) -> Result<(), InstructionError> {
match cursor.fill_buf() {
Ok(buf) if buf.len() >= PUBKEY_BYTES => {
// Safety: `buf` is guaranteed to be at least `PUBKEY_BYTES` bytes
// long. Pubkey a #[repr(transparent)] wrapper around a byte array,
// so this is a byte to byte copy and it's safe.
unsafe {
ptr::copy_nonoverlapping(buf.as_ptr(), pubkey as *mut u8, PUBKEY_BYTES);
}

cursor.consume(PUBKEY_BYTES);
}
_ => return Err(InstructionError::InvalidAccountData),
}

Ok(())
}

pub(crate) fn read_pubkey<T: AsRef<[u8]>>(
cursor: &mut Cursor<T>,
) -> Result<Pubkey, InstructionError> {
Expand Down
165 changes: 147 additions & 18 deletions sdk/program/src/vote/state/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,12 @@ use {
},
bincode::{serialize_into, ErrorKind},
serde_derive::{Deserialize, Serialize},
std::{collections::VecDeque, fmt::Debug, io::Cursor},
std::{
collections::VecDeque,
fmt::Debug,
io::Cursor,
mem::{self, MaybeUninit},
},
};

mod vote_state_0_23_5;
Expand Down Expand Up @@ -479,13 +484,81 @@ impl VoteState {
}
}

/// Deserializes the input `VoteStateVersions` buffer directly into a provided `VoteState` struct
/// Deserializes the input `VoteStateVersions` buffer directly into the provided `VoteState`.
///
/// In a SBPF context, V0_23_5 is not supported, but in non-SBPF, all versions are supported for
/// compatibility with `bincode::deserialize`.
///
/// In a BPF context, V0_23_5 is not supported, but in non-BPF, all versions are supported for
/// compatibility with `bincode::deserialize`
/// On success, `vote_state` reflects the state of the input data. On failure, `vote_state` is
/// reset to `VoteState::default()`.
pub fn deserialize_into(
input: &[u8],
vote_state: &mut VoteState,
) -> Result<(), InstructionError> {
// Rebind vote_state to *mut VoteState so that the &mut binding isn't
// accessible anymore, preventing accidental use after this point.
//
// NOTE: switch to ptr::from_mut() once platform-tools moves to rustc >= 1.76
let vote_state = vote_state as *mut VoteState;

// Safety: vote_state is valid to_drop (see drop_in_place() docs). After
// dropping, the pointer is treated as uninitialized and only accessed
// through ptr::write, which is safe as per drop_in_place docs.
unsafe {
std::ptr::drop_in_place(vote_state);
}

// This is to reset vote_state to VoteState::default() if deserialize fails or panics.
struct DropGuard {
vote_state: *mut VoteState,
}

impl Drop for DropGuard {
fn drop(&mut self) {
// Safety:
//
// Deserialize failed or panicked so at this point vote_state is uninitialized. We
// must write a new _valid_ value into it or after returning (or unwinding) from
// this function the caller is left with an uninitialized `&mut VoteState`, which is
// UB (references must always be valid).
//
// This is always safe and doesn't leak memory because deserialize_into_ptr() writes
// into the fields that heap alloc only when it returns Ok().
unsafe {
self.vote_state.write(VoteState::default());
}
}
}

let guard = DropGuard { vote_state };

let res = VoteState::deserialize_into_ptr(input, vote_state);
if res.is_ok() {
mem::forget(guard);
}

res
}

/// Deserializes the input `VoteStateVersions` buffer directly into the provided
/// `MaybeUninit<VoteState>`.
///
/// In a SBPF context, V0_23_5 is not supported, but in non-SBPF, all versions are supported for
/// compatibility with `bincode::deserialize`.
///
/// On success, `vote_state` is fully initialized and can be converted to `VoteState` using
/// [MaybeUninit::assume_init]. On failure, `vote_state` may still be uninitialized and must not
/// be converted to `VoteState`.
pub fn deserialize_into_uninit(
input: &[u8],
vote_state: &mut MaybeUninit<VoteState>,
) -> Result<(), InstructionError> {
VoteState::deserialize_into_ptr(input, vote_state.as_mut_ptr())
}

fn deserialize_into_ptr(
input: &[u8],
vote_state: *mut VoteState,
) -> Result<(), InstructionError> {
let mut cursor = Cursor::new(input);

Expand All @@ -496,10 +569,18 @@ impl VoteState {
0 => {
#[cfg(not(target_os = "solana"))]
{
*vote_state = bincode::deserialize::<VoteStateVersions>(input)
.map(|versioned| versioned.convert_to_current())
.map_err(|_| InstructionError::InvalidAccountData)?;

// Safety: vote_state is valid as it comes from `&mut MaybeUninit<VoteState>` or
// `&mut VoteState`. In the first case, the value is uninitialized so we write()
// to avoid dropping invalid data; in the latter case, we `drop_in_place()`
// before writing so the value has already been dropped and we just write a new
// one in place.
unsafe {
vote_state.write(
bincode::deserialize::<VoteStateVersions>(input)
.map(|versioned| versioned.convert_to_current())
.map_err(|_| InstructionError::InvalidAccountData)?,
);
}
Ok(())
}
#[cfg(target_os = "solana")]
Expand Down Expand Up @@ -1129,10 +1210,56 @@ mod tests {
}

#[test]
fn test_vote_deserialize_into_nopanic() {
// base case
fn test_vote_deserialize_into_error() {
let target_vote_state = VoteState::new_rand_for_tests(Pubkey::new_unique(), 42);
let mut vote_state_buf =
bincode::serialize(&VoteStateVersions::new_current(target_vote_state.clone())).unwrap();
let len = vote_state_buf.len();
vote_state_buf.truncate(len - 1);

let mut test_vote_state = VoteState::default();
let e = VoteState::deserialize_into(&[], &mut test_vote_state).unwrap_err();
VoteState::deserialize_into(&vote_state_buf, &mut test_vote_state).unwrap_err();
assert_eq!(test_vote_state, VoteState::default());
}

#[test]
fn test_vote_deserialize_into_uninit() {
// base case
let target_vote_state = VoteState::default();
let vote_state_buf =
bincode::serialize(&VoteStateVersions::new_current(target_vote_state.clone())).unwrap();

let mut test_vote_state = MaybeUninit::uninit();
VoteState::deserialize_into_uninit(&vote_state_buf, &mut test_vote_state).unwrap();
let test_vote_state = unsafe { test_vote_state.assume_init() };

assert_eq!(target_vote_state, test_vote_state);

// variant
// provide 4x the minimum struct size in bytes to ensure we typically touch every field
let struct_bytes_x4 = std::mem::size_of::<VoteState>() * 4;
for _ in 0..1000 {
let raw_data: Vec<u8> = (0..struct_bytes_x4).map(|_| rand::random::<u8>()).collect();
let mut unstructured = Unstructured::new(&raw_data);

let target_vote_state_versions =
VoteStateVersions::arbitrary(&mut unstructured).unwrap();
let vote_state_buf = bincode::serialize(&target_vote_state_versions).unwrap();
let target_vote_state = target_vote_state_versions.convert_to_current();

let mut test_vote_state = MaybeUninit::uninit();
VoteState::deserialize_into_uninit(&vote_state_buf, &mut test_vote_state).unwrap();
let test_vote_state = unsafe { test_vote_state.assume_init() };

assert_eq!(target_vote_state, test_vote_state);
}
}

#[test]
fn test_vote_deserialize_into_uninit_nopanic() {
// base case
let mut test_vote_state = MaybeUninit::uninit();
let e = VoteState::deserialize_into_uninit(&[], &mut test_vote_state).unwrap_err();
assert_eq!(e, InstructionError::InvalidAccountData);

// variant
Expand All @@ -1153,21 +1280,22 @@ mod tests {

// it is extremely improbable, though theoretically possible, for random bytes to be syntactically valid
// so we only check that the parser does not panic and that it succeeds or fails exactly in line with bincode
let mut test_vote_state = VoteState::default();
let test_res = VoteState::deserialize_into(&raw_data, &mut test_vote_state);
let mut test_vote_state = MaybeUninit::uninit();
let test_res = VoteState::deserialize_into_uninit(&raw_data, &mut test_vote_state);
let bincode_res = bincode::deserialize::<VoteStateVersions>(&raw_data)
.map(|versioned| versioned.convert_to_current());

if test_res.is_err() {
assert!(bincode_res.is_err());
} else {
let test_vote_state = unsafe { test_vote_state.assume_init() };
assert_eq!(test_vote_state, bincode_res.unwrap());
}
}
}

#[test]
fn test_vote_deserialize_into_ill_sized() {
fn test_vote_deserialize_into_uninit_ill_sized() {
// provide 4x the minimum struct size in bytes to ensure we typically touch every field
let struct_bytes_x4 = std::mem::size_of::<VoteState>() * 4;
for _ in 0..1000 {
Expand All @@ -1185,20 +1313,21 @@ mod tests {
expanded_buf.resize(original_buf.len() + 8, 0);

// truncated fails
let mut test_vote_state = VoteState::default();
let test_res = VoteState::deserialize_into(&truncated_buf, &mut test_vote_state);
let mut test_vote_state = MaybeUninit::uninit();
let test_res = VoteState::deserialize_into_uninit(&truncated_buf, &mut test_vote_state);
let bincode_res = bincode::deserialize::<VoteStateVersions>(&truncated_buf)
.map(|versioned| versioned.convert_to_current());

assert!(test_res.is_err());
assert!(bincode_res.is_err());

// expanded succeeds
let mut test_vote_state = VoteState::default();
VoteState::deserialize_into(&expanded_buf, &mut test_vote_state).unwrap();
let mut test_vote_state = MaybeUninit::uninit();
VoteState::deserialize_into_uninit(&expanded_buf, &mut test_vote_state).unwrap();
let bincode_res = bincode::deserialize::<VoteStateVersions>(&expanded_buf)
.map(|versioned| versioned.convert_to_current());

let test_vote_state = unsafe { test_vote_state.assume_init() };
assert_eq!(test_vote_state, bincode_res.unwrap());
}
}
Expand Down
10 changes: 6 additions & 4 deletions sdk/program/src/vote/state/vote_state_0_23_5.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,9 @@ mod tests {
let target_vote_state_versions = VoteStateVersions::V0_23_5(Box::new(target_vote_state));
let vote_state_buf = bincode::serialize(&target_vote_state_versions).unwrap();

let mut test_vote_state = VoteState::default();
VoteState::deserialize_into(&vote_state_buf, &mut test_vote_state).unwrap();
let mut test_vote_state = MaybeUninit::uninit();
VoteState::deserialize_into_uninit(&vote_state_buf, &mut test_vote_state).unwrap();
let test_vote_state = unsafe { test_vote_state.assume_init() };

assert_eq!(
target_vote_state_versions.convert_to_current(),
Expand All @@ -97,8 +98,9 @@ mod tests {
let vote_state_buf = bincode::serialize(&target_vote_state_versions).unwrap();
let target_vote_state = target_vote_state_versions.convert_to_current();

let mut test_vote_state = VoteState::default();
VoteState::deserialize_into(&vote_state_buf, &mut test_vote_state).unwrap();
let mut test_vote_state = MaybeUninit::uninit();
VoteState::deserialize_into_uninit(&vote_state_buf, &mut test_vote_state).unwrap();
let test_vote_state = unsafe { test_vote_state.assume_init() };

assert_eq!(target_vote_state, test_vote_state);
}
Expand Down
10 changes: 6 additions & 4 deletions sdk/program/src/vote/state/vote_state_1_14_11.rs
Original file line number Diff line number Diff line change
Expand Up @@ -94,8 +94,9 @@ mod tests {
let target_vote_state_versions = VoteStateVersions::V1_14_11(Box::new(target_vote_state));
let vote_state_buf = bincode::serialize(&target_vote_state_versions).unwrap();

let mut test_vote_state = VoteState::default();
VoteState::deserialize_into(&vote_state_buf, &mut test_vote_state).unwrap();
let mut test_vote_state = MaybeUninit::uninit();
VoteState::deserialize_into_uninit(&vote_state_buf, &mut test_vote_state).unwrap();
let test_vote_state = unsafe { test_vote_state.assume_init() };

assert_eq!(
target_vote_state_versions.convert_to_current(),
Expand All @@ -116,8 +117,9 @@ mod tests {
let vote_state_buf = bincode::serialize(&target_vote_state_versions).unwrap();
let target_vote_state = target_vote_state_versions.convert_to_current();

let mut test_vote_state = VoteState::default();
VoteState::deserialize_into(&vote_state_buf, &mut test_vote_state).unwrap();
let mut test_vote_state = MaybeUninit::uninit();
VoteState::deserialize_into_uninit(&vote_state_buf, &mut test_vote_state).unwrap();
let test_vote_state = unsafe { test_vote_state.assume_init() };

assert_eq!(target_vote_state, test_vote_state);
}
Expand Down
Loading

0 comments on commit e32e126

Please sign in to comment.