Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SDK: Slot Hashes: Add queries using sol_get_sysvar #1622

Merged
merged 10 commits into from
Jun 12, 2024
Merged
29 changes: 29 additions & 0 deletions sdk/program/src/sysvar/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,35 @@ macro_rules! impl_sysvar_get {
};
}

/// Handler for retrieving a slice of sysvar data from the `sol_get_sysvar`
/// syscall.
fn get_sysvar(
dst: &mut [u8],
sysvar_id: &Pubkey,
offset: u64,
length: u64,
) -> Result<(), ProgramError> {
// Check that the provided destination buffer is large enough to hold the
// requested data.
if dst.len() < length as usize {
return Err(ProgramError::InvalidArgument);
}

let sysvar_id = sysvar_id as *const _ as *const u8;
let var_addr = dst as *mut _ as *mut u8;

#[cfg(target_os = "solana")]
let result = unsafe { crate::syscalls::sol_get_sysvar(sysvar_id, var_addr, offset, length) };

#[cfg(not(target_os = "solana"))]
let result = crate::program_stubs::sol_get_sysvar(sysvar_id, var_addr, offset, length);

match result {
crate::entrypoint::SUCCESS => Ok(()),
e => Err(e.into()),
}
}

#[cfg(test)]
mod tests {
use {
Expand Down
163 changes: 161 additions & 2 deletions sdk/program/src/sysvar/slot_hashes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,17 @@
//! ```

pub use crate::slot_hashes::SlotHashes;
use crate::{account_info::AccountInfo, program_error::ProgramError, sysvar::Sysvar};
use {
crate::{
account_info::AccountInfo,
clock::Slot,
hash::Hash,
program_error::ProgramError,
slot_hashes::MAX_ENTRIES,
sysvar::{get_sysvar, Sysvar, SysvarId},
},
bytemuck::{Pod, Zeroable},
};

crate::declare_sysvar_id!("SysvarS1otHashes111111111111111111111111111", SlotHashes);

Expand All @@ -62,13 +72,109 @@ impl Sysvar for SlotHashes {
}
}

#[derive(Copy, Clone, Default, Pod, Zeroable)]
#[repr(C)]
struct PodSlotHash {
slot: Slot,
hash: Hash,
}

/// API for querying the `SlotHashes` sysvar.
pub struct SlotHashesSysvar;

impl SlotHashesSysvar {
/// Get a value from the sysvar entries by its key.
/// Returns `None` if the key is not found.
pub fn get(slot: &Slot) -> Result<Option<Hash>, ProgramError> {
get_pod_slot_hashes().map(|pod_hashes| {
pod_hashes
.binary_search_by(|PodSlotHash { slot: this, .. }| slot.cmp(this))
.map(|idx| pod_hashes[idx].hash)
.ok()
})
}

/// Get the position of an entry in the sysvar by its key.
/// Returns `None` if the key is not found.
pub fn position(slot: &Slot) -> Result<Option<usize>, ProgramError> {
get_pod_slot_hashes().map(|pod_hashes| {
pod_hashes
.binary_search_by(|PodSlotHash { slot: this, .. }| slot.cmp(this))
.ok()
})
}
}

fn get_pod_slot_hashes() -> Result<Vec<PodSlotHash>, ProgramError> {
let mut pod_hashes = vec![PodSlotHash::default(); MAX_ENTRIES];
{
let data = bytemuck::try_cast_slice_mut::<PodSlotHash, u8>(&mut pod_hashes)
.map_err(|_| ProgramError::InvalidAccountData)?;

// Ensure the created buffer is aligned to 8.
if data.as_ptr().align_offset(8) != 0 {
return Err(ProgramError::InvalidAccountData);
}

let offset = 8; // Vector length as `u64`.
let length = (SlotHashes::size_of() as u64).saturating_sub(offset);
get_sysvar(data, &SlotHashes::id(), offset, length)?;
}
Ok(pod_hashes)
}

#[cfg(test)]
mod tests {
use {
super::*,
crate::{clock::Slot, hash::Hash, slot_hashes::MAX_ENTRIES},
crate::{
clock::Slot,
entrypoint::SUCCESS,
hash::{hash, Hash},
program_stubs::{set_syscall_stubs, SyscallStubs},
slot_hashes::{SlotHash, MAX_ENTRIES},
},
};

struct MockSlotHashesSyscall {
slot_hashes: SlotHashes,
}

impl SyscallStubs for MockSlotHashesSyscall {
#[allow(clippy::arithmetic_side_effects)]
fn sol_get_sysvar(
&self,
_sysvar_id_addr: *const u8,
var_addr: *mut u8,
offset: u64,
length: u64,
) -> u64 {
// The syscall tests for `sol_get_sysvar` should ensure the following:
//
// - The provided `sysvar_id_addr` can be translated into a valid
// sysvar ID for a sysvar contained in the sysvar cache, of which
// `SlotHashes` is one.
// - Length and memory checks on `offset` and `length`.
//
// Therefore this mockup can simply just unsafely use the provided
// `offset` and `length` to copy the serialized `SlotHashes` into
// the provided `var_addr`.
let data = bincode::serialize(&self.slot_hashes).unwrap();
let slice = unsafe { std::slice::from_raw_parts_mut(var_addr, length as usize) };
slice.copy_from_slice(&data[offset as usize..(offset + length) as usize]);
SUCCESS
}
}

fn mock_get_sysvar_syscall(slot_hashes: &[SlotHash]) {
static ONCE: std::sync::Once = std::sync::Once::new();
ONCE.call_once(|| {
set_syscall_stubs(Box::new(MockSlotHashesSyscall {
slot_hashes: SlotHashes::new(slot_hashes),
}));
});
}

#[test]
fn test_size_of() {
assert_eq!(
Expand All @@ -81,4 +187,57 @@ mod tests {
.unwrap() as usize
);
}

#[test]
fn test_slot_hashes_sysvar() {
let mut slot_hashes = vec![];
for i in 0..MAX_ENTRIES {
slot_hashes.push((
i as u64,
hash(&[(i >> 24) as u8, (i >> 16) as u8, (i >> 8) as u8, i as u8]),
));
}

mock_get_sysvar_syscall(&slot_hashes);

let check_slot_hashes = SlotHashes::new(&slot_hashes);

// `get`:
assert_eq!(
SlotHashesSysvar::get(&0).unwrap().as_ref(),
check_slot_hashes.get(&0),
);
assert_eq!(
SlotHashesSysvar::get(&256).unwrap().as_ref(),
check_slot_hashes.get(&256),
);
assert_eq!(
SlotHashesSysvar::get(&511).unwrap().as_ref(),
check_slot_hashes.get(&511),
);
// `None`.
assert_eq!(
SlotHashesSysvar::get(&600).unwrap().as_ref(),
check_slot_hashes.get(&600),
);

// `position`:
assert_eq!(
SlotHashesSysvar::position(&0).unwrap(),
check_slot_hashes.position(&0),
);
assert_eq!(
SlotHashesSysvar::position(&256).unwrap(),
check_slot_hashes.position(&256),
);
assert_eq!(
SlotHashesSysvar::position(&511).unwrap(),
check_slot_hashes.position(&511),
);
// `None`.
assert_eq!(
SlotHashesSysvar::position(&600).unwrap(),
check_slot_hashes.position(&600),
);
}
}
Loading