diff --git a/crates/virtio-queue/src/mock.rs b/crates/virtio-queue/src/mock.rs index 4e130142..baf5f91a 100644 --- a/crates/virtio-queue/src/mock.rs +++ b/crates/virtio-queue/src/mock.rs @@ -4,14 +4,18 @@ //! Utilities used by unit tests and benchmarks for mocking the driver side //! of the virtio protocol. +#![allow(missing_docs)] + +use std::cmp::min; use std::marker::PhantomData; use std::mem::size_of; +use vm_memory::guest_memory::Error as GuestMemError; use vm_memory::{ Address, ByteValued, Bytes, GuestAddress, GuestAddressSpace, GuestMemory, GuestUsize, }; -use crate::defs::{VIRTQ_DESC_F_INDIRECT, VIRTQ_DESC_F_NEXT}; +use crate::defs::{VIRTQ_DESC_F_INDIRECT, VIRTQ_DESC_F_NEXT, VIRTQ_DESC_F_WRITE}; use crate::{Descriptor, Queue}; /// Wrapper struct used for accesing a particular address of a GuestMemory area. @@ -179,6 +183,77 @@ impl<'a, M: GuestMemory> DescriptorTable<'a, M> { (self.len as usize * size_of::()) as u64 } + /// Takes a vector of MockDescriptors, converts them to real descriptors by adding + /// all the missing fields and stores them in the descriptor table. + /// + /// # Example: + /// ``` + /// use vm_memory::{GuestAddress, GuestMemoryMmap}; + /// use virtio_queue::mock::{DescriptorTable, MockDescriptor, MockDescriptorChain}; + /// + /// // This creates the descriptor chain: [2, 4, 10, 11]. + /// let v = vec![MockDescriptor::new().with_index(2).with_addr(0x1000).with_len(0x1000), + /// MockDescriptor::new().with_index(4).with_len(0x100), + /// MockDescriptor::new().with_index(10).with_addr(0x3000).with_len(0x1000).indirect(), + /// MockDescriptor::new().with_len(0x100).writeable()]; + /// + /// let direct_chain = MockDescriptorChain::new(v); + /// + /// let m = &GuestMemoryMmap::<()>::from_ranges(&[(GuestAddress(0), 0x10000)]).unwrap(); + /// let dt = DescriptorTable::new(m, GuestAddress(0x1000), 11); + /// dt.add_chain(direct_chain); + /// ``` + pub fn add_chain(&self, mdc: MockDescriptorChain) -> u16 { + let mut prev = MockDescriptor::new(); + for (i, md) in mdc.chain.iter().enumerate() { + let addr = md.addr().unwrap_or_else(|| match i { + 0 => 0, + _ => prev.addr().unwrap() + prev.len().unwrap() as u64, + }); + + let len = md.len().unwrap_or(0x1000); + + let mut flags: u16; + if i == (mdc.len() - 1) as usize { + flags = 0; + } else { + flags = VIRTQ_DESC_F_NEXT; + } + + if md.is_indirect() { + flags |= VIRTQ_DESC_F_INDIRECT; + } + + if md.is_writeable() { + flags |= VIRTQ_DESC_F_WRITE; + } + + let index = md.index().unwrap_or_else(|| match i { + 0 => 0, + _ => prev.index().unwrap(), + }); + + let next = if i == (mdc.len() - 1) as usize { + 0 + } else { + mdc.chain[i + 1].index().unwrap_or(index + 1) + }; + + let desc = Descriptor::new(addr, len, flags, next); + self.store(index, desc); + + prev = MockDescriptor { + addr: Some(addr), + index: Some(index), + len: Some(len), + ..*md + }; + } + + // Returns the index of the first descriptor + mdc.chain[0].index().unwrap_or(0) + } + /// Create a chain of descriptors pub fn build_chain(&mut self, len: u16) -> u16 { let indices = self @@ -377,3 +452,148 @@ impl<'a, M: GuestMemory> MockSplitQueue<'a, M> { q } } + +pub struct MockDescriptorChain { + pub chain: Vec, +} + +impl MockDescriptorChain { + pub fn new(chain: Vec) -> Self { + Self { chain } + } + + pub fn with_len(len: u16) -> Self { + let mut chain = Vec::with_capacity(len as usize); + for i in 0..len { + let md = MockDescriptor::new() + .with_index(i) + .with_addr((0x1000 * i) as u64) + .with_len(0x1000); + + chain.push(md); + } + + MockDescriptorChain::new(chain) + } + + pub fn len(&self) -> u16 { + self.chain.len() as u16 + } + + pub fn write_slice(&self, buffer: &[u8], mem: &M) -> Result<(), GuestMemError> { + let mut prev = MockDescriptor::new(); + let (mut start, mut end) = (0, 0); + let mut to_write = buffer.len(); + + for (index, md) in self.chain.iter().enumerate() { + if to_write == 0 { + return Ok(()); + } + + let addr = md.addr().unwrap_or_else(|| match index { + 0 => 0, + _ => prev.addr().unwrap() + prev.len().unwrap() as u64, + }); + let len = md.len().unwrap_or(0x1000); + + start = match index { + 0 => addr as usize, + _ => end, + }; + end = min(start + len as usize, start + to_write); + + mem.write(&buffer[start..end], GuestAddress(addr))?; + to_write -= end - start; + + prev = MockDescriptor { + addr: Some(addr), + len: Some(len), + ..*md + }; + } + + Ok(()) + } +} + +#[derive(Clone, Copy)] +pub struct MockDescriptor { + index: Option, + addr: Option, + len: Option, + writeable: bool, + indirect: bool, +} + +impl MockDescriptor { + pub fn new() -> Self { + Self { + index: None, + addr: None, + len: None, + writeable: false, + indirect: false, + } + } + + pub fn with_index(self, index: u16) -> Self { + Self { + index: Some(index), + ..self + } + } + + pub fn with_addr(self, addr: u64) -> Self { + Self { + addr: Some(addr), + ..self + } + } + + pub fn with_len(self, len: u32) -> Self { + Self { + len: Some(len), + ..self + } + } + + pub fn writeable(self) -> Self { + Self { + writeable: true, + ..self + } + } + + pub fn indirect(self) -> Self { + Self { + indirect: true, + ..self + } + } + + pub fn index(&self) -> Option { + self.index + } + + pub fn addr(&self) -> Option { + self.addr + } + + pub fn len(&self) -> Option { + self.len + } + + pub fn is_writeable(&self) -> bool { + self.writeable + } + + pub fn is_indirect(&self) -> bool { + self.indirect + } +} + +impl Default for MockDescriptor { + fn default() -> Self { + Self::new() + } +}