Skip to content

Commit

Permalink
chore: Adjust Snapshot2 API for better usage (nervosnetwork#412)
Browse files Browse the repository at this point in the history
Co-authored-by: Xuejie Xiao <[email protected]>
  • Loading branch information
mohanson and xxuejie committed Mar 14, 2024
1 parent 460509e commit 36f28af
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 18 deletions.
23 changes: 15 additions & 8 deletions src/snapshot2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,11 @@ const PAGE_SIZE: u64 = RISCV_PAGESIZE as u64;
/// we can leverage DataSource for snapshot optimizations: data that is already
/// locatable in the DataSource will not need to be included in the snapshot
/// again, all we need is an id to locate it, together with a pair of
/// offset / length to cut in to the correct slices.
/// offset / length to cut in to the correct slices. Just like CKB's syscall design,
/// an extra u64 value is included here to return the remaining full length of data
/// starting from offset, without considering `length` parameter
pub trait DataSource<I: Clone + PartialEq> {
fn load_data(&self, id: &I, offset: u64, length: u64) -> Result<Bytes, Error>;
fn load_data(&self, id: &I, offset: u64, length: u64) -> Result<(Bytes, u64), Error>;
}

#[derive(Clone, Debug)]
Expand Down Expand Up @@ -68,7 +70,7 @@ impl<I: Clone + PartialEq, D: DataSource<I>> Snapshot2Context<I, D> {
if address % PAGE_SIZE != 0 {
return Err(Error::MemPageUnalignedAccess);
}
let data = self.data_source().load_data(id, *offset, *length)?;
let (data, _) = self.data_source().load_data(id, *offset, *length)?;
if data.len() as u64 % PAGE_SIZE != 0 {
return Err(Error::MemPageUnalignedAccess);
}
Expand Down Expand Up @@ -107,18 +109,21 @@ impl<I: Clone + PartialEq, D: DataSource<I>> Snapshot2Context<I, D> {
}

/// Similar to Memory::store_bytes, but this method also tracks memory
/// pages whose entire content comes from DataSource
/// pages whose entire content comes from DataSource. It returns 2 values:
/// the actual written bytes, and the full length of data starting from offset,
/// but ignoring `length` parameter.
pub fn store_bytes<M: SupportMachine>(
&mut self,
machine: &mut M,
addr: u64,
id: &I,
offset: u64,
length: u64,
) -> Result<(), Error> {
let data = self.data_source.load_data(id, offset, length)?;
) -> Result<(u64, u64), Error> {
let (data, full_length) = self.data_source.load_data(id, offset, length)?;
machine.memory_mut().store_bytes(addr, &data)?;
self.track_pages(machine, addr, data.len() as u64, id, offset)
self.track_pages(machine, addr, data.len() as u64, id, offset)?;
Ok((data.len() as u64, full_length))
}

/// Due to the design of ckb-vm right now, load_program function does not
Expand Down Expand Up @@ -222,7 +227,9 @@ impl<I: Clone + PartialEq, D: DataSource<I>> Snapshot2Context<I, D> {
self.track_pages(machine, start, length, id, offset + action.source.start)
}

fn track_pages<M: SupportMachine>(
/// This is only made public for advanced usages, but make sure to exercise more
/// cautions when calling it!
pub fn track_pages<M: SupportMachine>(
&mut self,
machine: &mut M,
start: u64,
Expand Down
22 changes: 12 additions & 10 deletions tests/test_resume2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -324,15 +324,17 @@ const DATA_ID: u64 = 0x2000;
struct TestSource(HashMap<u64, Bytes>);

impl DataSource<u64> for TestSource {
fn load_data(&self, id: &u64, offset: u64, length: u64) -> Result<Bytes, Error> {
fn load_data(&self, id: &u64, offset: u64, length: u64) -> Result<(Bytes, u64), Error> {
match self.0.get(id) {
Some(data) => Ok(data.slice(
offset as usize..(if length > 0 {
(offset + length) as usize
Some(data) => {
let end = if length > 0 {
offset + length
} else {
data.len()
}),
)),
data.len() as u64
};
let full_length = end - offset;
Ok((data.slice(offset as usize..end as usize), full_length))
}
None => Err(Error::Unexpected(format!(
"Id {} is missing in source!",
id
Expand Down Expand Up @@ -443,7 +445,7 @@ impl Machine {
use Machine::*;
match self {
Asm(inner, context) => {
let program = context
let (program, _) = context
.lock()
.unwrap()
.data_source()
Expand All @@ -460,7 +462,7 @@ impl Machine {
Ok(bytes)
}
Interpreter(inner, context) => {
let program = context
let (program, _) = context
.lock()
.unwrap()
.data_source()
Expand All @@ -477,7 +479,7 @@ impl Machine {
Ok(bytes)
}
InterpreterWithTrace(inner, context) => {
let program = context
let (program, _) = context
.lock()
.unwrap()
.data_source()
Expand Down

0 comments on commit 36f28af

Please sign in to comment.