From e5c5db184b3826b42f48dd96a9f3b18e4f4289c7 Mon Sep 17 00:00:00 2001 From: Matt Kurjanowicz Date: Thu, 9 Jan 2025 10:04:52 -0800 Subject: [PATCH] vmbus_async: don't panic on empty external data (#621) Don't `panic` when a GpaDirect packet contains a pointer to empty GPA ranges. This data is guest controlled and not a programmer error, so `panic` is not correct. Consumers of `read_external_ranges` already handle errors gracefully, so this is not a significant change for them. As a drive-by code improvement, also change some cases of `#[from]` to `#[source]` in error definitions, and add the requisite `.map_err` changes. --- vm/devices/net/netvsp/src/lib.rs | 90 +++++++++++------------ vm/devices/net/netvsp/src/test.rs | 7 +- vm/devices/storage/storvsp/src/lib.rs | 9 +-- vm/devices/vmbus/vmbus_async/src/queue.rs | 83 +++++++++++++++++++-- 4 files changed, 127 insertions(+), 62 deletions(-) diff --git a/vm/devices/net/netvsp/src/lib.rs b/vm/devices/net/netvsp/src/lib.rs index 249baf2166..90a2e2952e 100644 --- a/vm/devices/net/netvsp/src/lib.rs +++ b/vm/devices/net/netvsp/src/lib.rs @@ -76,6 +76,7 @@ use task_control::StopTask; use task_control::TaskControl; use thiserror::Error; use vmbus_async::queue; +use vmbus_async::queue::ExternalDataError; use vmbus_async::queue::IncomingPacket; use vmbus_async::queue::Queue; use vmbus_channel::bus::OfferParams; @@ -1311,8 +1312,9 @@ impl Nic { .clone(); driver.retarget_vp(open_request.open_data.target_vp); - let raw = gpadl_channel(&driver, &self.resources, open_request, channel_idx)?; - let mut queue = Queue::new(raw)?; + let raw = gpadl_channel(&driver, &self.resources, open_request, channel_idx) + .map_err(OpenError::Ring)?; + let mut queue = Queue::new(raw).map_err(OpenError::Queue)?; let guest_os_id = self.adapter.get_guest_os_id.as_ref().map(|f| f()); let can_use_ring_size_opt = can_use_ring_opt(&mut queue, guest_os_id); let worker = Worker { @@ -1395,7 +1397,7 @@ enum NetRestoreError { #[error("send/receive buffer invalid gpadl ID")] UnknownGpadlId(#[from] UnknownGpadlId), #[error("failed to restore channels")] - Channel(#[from] ChannelRestoreError), + Channel(#[source] ChannelRestoreError), #[error(transparent)] ReceiveBuffer(#[from] BufferError), #[error(transparent)] @@ -1436,7 +1438,10 @@ impl Nic { // state (vmbus believes the channel is open/active). There // are a number of failure paths after this point because this // call also restores vmbus device state, like the GPADL map. - let requests = control.restore(&open).await?; + let requests = control + .restore(&open) + .await + .map_err(NetRestoreError::Channel)?; match state.primary { saved_state::Primary::Version => { @@ -1586,7 +1591,10 @@ impl Nic { } } } else { - control.restore(&[false]).await?; + control + .restore(&[false]) + .await + .map_err(NetRestoreError::Channel)?; } Ok(()) } @@ -1789,11 +1797,9 @@ impl Nic { #[derive(Debug, Error)] enum WorkerError { #[error("packet error")] - Packet(#[from] PacketError), + Packet(#[source] PacketError), #[error("unexpected packet order: {0}")] - UnexpectedPacketOrder(#[from] PacketOrderError), - #[error("invalid gpadl id")] - InvalidGpadlId(#[from] UnknownGpadlId), + UnexpectedPacketOrder(#[source] PacketOrderError), #[error("unknown rndis message type: {0}")] UnknownRndisMessageType(u32), #[error("memory access error")] @@ -1845,9 +1851,9 @@ impl From for WorkerError { #[derive(Debug, Error)] enum OpenError { #[error("error establishing ring buffer")] - Ring(#[from] vmbus_channel::gpadl_ring::Error), + Ring(#[source] vmbus_channel::gpadl_ring::Error), #[error("error establishing vmbus queue")] - Queue(#[from] queue::Error), + Queue(#[source] queue::Error), } #[derive(Debug, Error)] @@ -1855,7 +1861,9 @@ enum PacketError { #[error("UnknownType {0}")] UnknownType(u32), #[error("Access")] - Access(#[from] AccessError), + Access(#[source] AccessError), + #[error("ExternalData")] + ExternalData(#[source] ExternalDataError), #[error("InvalidSendBufferIndex")] InvalidSendBufferIndex, } @@ -2011,7 +2019,9 @@ fn parse_packet<'a, T: RingMem>( Ok(Packet { data, transaction_id: packet.transaction_id(), - external_data: packet.read_external_ranges().map_err(PacketError::Access)?, + external_data: packet + .read_external_ranges() + .map_err(PacketError::ExternalData)?, send_buffer_suballocation, }) } @@ -2224,7 +2234,7 @@ impl NetChannel { control.control_messages_len += reader.len(); control.control_messages.push_back(ControlMessage { message_type, - data: reader.read_all().map_err(WorkerError::Access)?.into(), + data: reader.read_all()?.into(), }); false @@ -2251,10 +2261,7 @@ impl NetChannel { .find(|r| !r.is_empty()) .ok_or(WorkerError::RndisMessageTooSmall)?; let mut data = reader.into_inner(); - let request: rndisprot::Packet = headers - .reader(mem) - .read_plain() - .map_err(WorkerError::Access)?; + let request: rndisprot::Packet = headers.reader(mem).read_plain()?; if request.num_oob_data_elements != 0 || request.oob_data_length != 0 || request.oob_data_offset != 0 @@ -2287,8 +2294,7 @@ impl NetChannel { ) .ok_or(WorkerError::RndisMessageTooSmall)?; while !ppi.is_empty() { - let h: rndisprot::PerPacketInfo = - ppi.reader(mem).read_plain().map_err(WorkerError::Access)?; + let h: rndisprot::PerPacketInfo = ppi.reader(mem).read_plain()?; if h.size == 0 { return Err(WorkerError::RndisMessageTooSmall); } @@ -2300,8 +2306,7 @@ impl NetChannel { .ok_or(WorkerError::RndisMessageTooSmall)?; match h.typ { rndisprot::PPI_TCP_IP_CHECKSUM => { - let n: rndisprot::TxTcpIpChecksumInfo = - d.reader(mem).read_plain().map_err(WorkerError::Access)?; + let n: rndisprot::TxTcpIpChecksumInfo = d.reader(mem).read_plain()?; metadata.offload_tcp_checksum = (n.is_ipv4() || n.is_ipv6()) && n.tcp_checksum(); @@ -2332,8 +2337,7 @@ impl NetChannel { } } rndisprot::PPI_LSO => { - let n: rndisprot::TcpLsoInfo = - d.reader(mem).read_plain().map_err(WorkerError::Access)?; + let n: rndisprot::TcpLsoInfo = d.reader(mem).read_plain()?; metadata.offload_tcp_segmentation = true; metadata.offload_tcp_checksum = true; @@ -2644,8 +2648,7 @@ impl NetChannel { return Err(WorkerError::InvalidRndisState); } - let request: rndisprot::InitializeRequest = - reader.read_plain().map_err(WorkerError::Access)?; + let request: rndisprot::InitializeRequest = reader.read_plain()?; tracing::trace!( ?request, @@ -2672,8 +2675,7 @@ impl NetChannel { af_list_offset: 0, af_list_size: 0, }, - ) - .map_err(WorkerError::Access)?; + )?; self.send_rndis_control_message(buffers, id, message_length)?; if let PrimaryChannelGuestVfState::Available { vfid } = primary.guest_vf_state { if self.guest_vf_is_available( @@ -2703,8 +2705,7 @@ impl NetChannel { } } rndisprot::MESSAGE_TYPE_QUERY_MSG => { - let request: rndisprot::QueryRequest = - reader.read_plain().map_err(WorkerError::Access)?; + let request: rndisprot::QueryRequest = reader.read_plain()?; tracing::trace!(?request, "handling control message MESSAGE_TYPE_QUERY_MSG"); @@ -2734,13 +2735,11 @@ impl NetChannel { information_buffer_offset: size_of::() as u32, information_buffer_length: tx as u32, }, - ) - .map_err(WorkerError::Access)?; + )?; self.send_rndis_control_message(buffers, id, message_length)?; } rndisprot::MESSAGE_TYPE_SET_MSG => { - let request: rndisprot::SetRequest = - reader.read_plain().map_err(WorkerError::Access)?; + let request: rndisprot::SetRequest = reader.read_plain()?; tracing::trace!(?request, "handling control message MESSAGE_TYPE_SET_MSG"); @@ -2767,8 +2766,7 @@ impl NetChannel { request_id: request.request_id, status, }, - ) - .map_err(WorkerError::Access)?; + )?; self.send_rndis_control_message(buffers, id, message_length)?; } rndisprot::MESSAGE_TYPE_RESET_MSG => { @@ -2778,8 +2776,7 @@ impl NetChannel { return Err(WorkerError::RndisMessageTypeNotImplemented) } rndisprot::MESSAGE_TYPE_KEEPALIVE_MSG => { - let request: rndisprot::KeepaliveRequest = - reader.read_plain().map_err(WorkerError::Access)?; + let request: rndisprot::KeepaliveRequest = reader.read_plain()?; tracing::trace!( ?request, @@ -2794,8 +2791,7 @@ impl NetChannel { request_id: request.request_id, status: rndisprot::STATUS_SUCCESS, }, - ) - .map_err(WorkerError::Access)?; + )?; self.send_rndis_control_message(buffers, id, message_length)?; } rndisprot::MESSAGE_TYPE_SET_EX_MSG => { @@ -4355,7 +4351,9 @@ impl NetChannel { ) -> Result>, WorkerError> { let (mut read, _) = self.queue.split(); let packet = match read.try_read() { - Ok(packet) => parse_packet(&packet, send_buffer, version)?, + Ok(packet) => { + parse_packet(&packet, send_buffer, version).map_err(WorkerError::Packet)? + } Err(queue::TryReadError::Empty) => return Ok(None), Err(queue::TryReadError::Queue(err)) => return Err(err.into()), }; @@ -4371,7 +4369,8 @@ impl NetChannel { ) -> Result, WorkerError> { let (mut read, _) = self.queue.split(); let mut packet_ref = read.read().await?; - let packet = parse_packet(&packet_ref, send_buffer, version)?; + let packet = + parse_packet(&packet_ref, send_buffer, version).map_err(WorkerError::Packet)?; if matches!(packet.data, PacketData::RndisPacket(_)) { // In WorkerState::Init if an rndis packet is received, assume it is MESSAGE_TYPE_INITIALIZE_MSG tracing::trace!(target: "netvsp/vmbus", "detected rndis initialization message"); @@ -5162,8 +5161,7 @@ impl NetChannel { // message. while reader.len() > 0 { let mut this_reader = reader.clone(); - let header: rndisprot::MessageHeader = - this_reader.read_plain().map_err(WorkerError::Access)?; + let header: rndisprot::MessageHeader = this_reader.read_plain()?; if self.handle_rndis_message( buffers, state, @@ -5174,9 +5172,7 @@ impl NetChannel { )? { num_packets += 1; } - reader - .skip(header.message_length as usize) - .map_err(WorkerError::Access)?; + reader.skip(header.message_length as usize)?; } Ok(num_packets) diff --git a/vm/devices/net/netvsp/src/test.rs b/vm/devices/net/netvsp/src/test.rs index 252f8413ed..c7297a2248 100644 --- a/vm/devices/net/netvsp/src/test.rs +++ b/vm/devices/net/netvsp/src/test.rs @@ -649,11 +649,10 @@ impl<'a> TestNicChannel<'a> { let external_ranges = if let Some(id) = data.transfer_buffer_id() { assert_eq!(id, 0); - data.read_transfer_ranges(recv_buf.iter()) + data.read_transfer_ranges(recv_buf.iter()).unwrap() } else { - data.read_external_ranges() - } - .unwrap(); + data.read_external_ranges().unwrap() + }; let mut direct_reader = PagedRanges::new(external_ranges.iter()).reader(&mem); diff --git a/vm/devices/storage/storvsp/src/lib.rs b/vm/devices/storage/storvsp/src/lib.rs index d3dbed8a90..239e721d0c 100644 --- a/vm/devices/storage/storvsp/src/lib.rs +++ b/vm/devices/storage/storvsp/src/lib.rs @@ -64,6 +64,7 @@ use thiserror::Error; use tracing_helpers::ErrorValueExt; use unicycle::FuturesUnordered; use vmbus_async::queue; +use vmbus_async::queue::ExternalDataError; use vmbus_async::queue::IncomingPacket; use vmbus_async::queue::OutgoingPacket; use vmbus_async::queue::Queue; @@ -294,9 +295,9 @@ enum PacketError { #[error("Invalid data transfer length")] InvalidDataTransferLength, #[error("Access error: {0}")] - Access(AccessError), + Access(#[source] AccessError), #[error("Range error")] - Range, + Range(#[source] ExternalDataError), } #[derive(Debug, Default, Clone)] @@ -391,9 +392,7 @@ fn parse_packet( let request_buf = &mut full_request.request.as_bytes_mut()[..request_size]; reader.read(request_buf).map_err(PacketError::Access)?; - let buf = packet - .read_external_ranges() - .map_err(|_| PacketError::Range)?; + let buf = packet.read_external_ranges().map_err(PacketError::Range)?; full_request.external_data = Range::new(buf, &full_request.request) .ok_or(PacketError::InvalidDataTransferLength)?; diff --git a/vm/devices/vmbus/vmbus_async/src/queue.rs b/vm/devices/vmbus/vmbus_async/src/queue.rs index bed0f0089a..154a3cab8a 100644 --- a/vm/devices/vmbus/vmbus_async/src/queue.rs +++ b/vm/devices/vmbus/vmbus_async/src/queue.rs @@ -101,6 +101,23 @@ pub enum TryWriteError { Queue(#[source] Error), } +/// An error returned by `read_external_ranges` +#[derive(Debug, Error)] +pub enum ExternalDataError { + /// The packet is corrupted in some way (e.g. it does not specify a reasonable set of GPA ranges). + #[error("invalid gpa ranges")] + GpaRange(#[source] vmbus_ring::gparange::Error), + + /// The packet specifies memory that this vmbus cannot read, for some reason. + #[error("access error")] + Access(#[source] AccessError), + + /// Caller used `read_external_ranges` when the packet contains a buffer id, + /// and the caller should have called `read_transfer_ranges` + #[error("external data should have been read by calling read_transfer_ranges")] + WrongExternalDataType, +} + /// An incoming packet batch reader. pub struct ReadBatch<'a, M: RingMem> { core: &'a Core, @@ -269,18 +286,21 @@ impl DataPacket<'_, T> { } /// Reads the GPA direct range descriptors from the packet. - pub fn read_external_ranges(&self) -> Result, AccessError> { + pub fn read_external_ranges(&self) -> Result, ExternalDataError> { if self.buffer_id.is_some() { - return Err(AccessError::OutOfRange(0, 0)); + return Err(ExternalDataError::WrongExternalDataType); } else if self.external_data.0 == 0 { - return Ok(MultiPagedRangeBuf::new(0, GpnList::new()).unwrap()); + return Ok(MultiPagedRangeBuf::empty()); } let mut reader = self.external_data.1.reader(self.ring); let len = reader.len() / 8; let mut buf = zeroed_gpn_list(len); - reader.read(buf.as_bytes_mut())?; - Ok(MultiPagedRangeBuf::new(self.external_data.0 as usize, buf).unwrap()) + reader + .read(buf.as_bytes_mut()) + .map_err(ExternalDataError::Access)?; + MultiPagedRangeBuf::new(self.external_data.0 as usize, buf) + .map_err(ExternalDataError::GpaRange) } /// Reads the transfer buffer ID from the packet, or None if this is not a transfer packet. @@ -298,7 +318,7 @@ impl DataPacket<'_, T> { I: Iterator>, { if self.external_data.0 == 0 { - return Ok(MultiPagedRangeBuf::new(0, GpnList::new()).unwrap()); + return Ok(MultiPagedRangeBuf::empty()); } let buf: MultiPagedRangeBuf = transfer_buf.collect(); @@ -803,6 +823,57 @@ mod tests { .unwrap(); } + #[async_test] + async fn test_gpa_direct_empty_external_data() { + use guestmem::ranges::PagedRange; + + let (mut host_queue, mut guest_queue) = connected_queues(16384); + + let gpa1: Vec = vec![]; + let gpas = vec![PagedRange::new(0, 0, &gpa1).unwrap()]; + + let payload: &[u8] = &[0xf; 24]; + guest_queue + .split() + .1 + .write(OutgoingPacket { + transaction_id: 0, + packet_type: OutgoingPacketType::GpaDirect(&gpas), + payload: &[payload], + }) + .await + .unwrap(); + host_queue + .split() + .0 + .read_batch() + .await + .unwrap() + .packets() + .next() + .map(|p| match p.unwrap() { + IncomingPacket::Data(data) => { + // Check the payload + let mut in_payload = [0_u8; 24]; + assert_eq!(payload.len(), data.reader().len()); + data.reader().read(&mut in_payload).unwrap(); + assert_eq!(in_payload, payload); + + // Check the external ranges + assert_eq!(data.external_range_count(), 1); + let external_data_result = data.read_external_ranges(); + assert_eq!(data.read_external_ranges().is_err(), true); + match external_data_result { + Err(ExternalDataError::GpaRange(_)) => Ok(()), + _ => Err("should be out of range"), + } + } + _ => Err("should be data"), + }) + .unwrap() + .unwrap(); + } + #[async_test] async fn test_transfer_pages() { use guestmem::ranges::PagedRange;