Skip to content

Commit

Permalink
vmbus_async: don't panic on empty external data (#621)
Browse files Browse the repository at this point in the history
Don't `panic` when a GpaDirect packet contains a pointer to empty GPA
ranges. This data is guest controlled and not a programmer error, so
`panic` is not correct. Consumers of `read_external_ranges` already
handle errors gracefully, so this is not a significant change for them.

As a drive-by code improvement, also change some cases of `#[from]` to
`#[source]` in error definitions, and add the requisite `.map_err`
changes.
  • Loading branch information
mattkur authored Jan 9, 2025
1 parent 353d20a commit e5c5db1
Show file tree
Hide file tree
Showing 4 changed files with 127 additions and 62 deletions.
90 changes: 43 additions & 47 deletions vm/devices/net/netvsp/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ use task_control::StopTask;
use task_control::TaskControl;
use thiserror::Error;
use vmbus_async::queue;
use vmbus_async::queue::ExternalDataError;
use vmbus_async::queue::IncomingPacket;
use vmbus_async::queue::Queue;
use vmbus_channel::bus::OfferParams;
Expand Down Expand Up @@ -1311,8 +1312,9 @@ impl Nic {
.clone();
driver.retarget_vp(open_request.open_data.target_vp);

let raw = gpadl_channel(&driver, &self.resources, open_request, channel_idx)?;
let mut queue = Queue::new(raw)?;
let raw = gpadl_channel(&driver, &self.resources, open_request, channel_idx)
.map_err(OpenError::Ring)?;
let mut queue = Queue::new(raw).map_err(OpenError::Queue)?;
let guest_os_id = self.adapter.get_guest_os_id.as_ref().map(|f| f());
let can_use_ring_size_opt = can_use_ring_opt(&mut queue, guest_os_id);
let worker = Worker {
Expand Down Expand Up @@ -1395,7 +1397,7 @@ enum NetRestoreError {
#[error("send/receive buffer invalid gpadl ID")]
UnknownGpadlId(#[from] UnknownGpadlId),
#[error("failed to restore channels")]
Channel(#[from] ChannelRestoreError),
Channel(#[source] ChannelRestoreError),
#[error(transparent)]
ReceiveBuffer(#[from] BufferError),
#[error(transparent)]
Expand Down Expand Up @@ -1436,7 +1438,10 @@ impl Nic {
// state (vmbus believes the channel is open/active). There
// are a number of failure paths after this point because this
// call also restores vmbus device state, like the GPADL map.
let requests = control.restore(&open).await?;
let requests = control
.restore(&open)
.await
.map_err(NetRestoreError::Channel)?;

match state.primary {
saved_state::Primary::Version => {
Expand Down Expand Up @@ -1586,7 +1591,10 @@ impl Nic {
}
}
} else {
control.restore(&[false]).await?;
control
.restore(&[false])
.await
.map_err(NetRestoreError::Channel)?;
}
Ok(())
}
Expand Down Expand Up @@ -1789,11 +1797,9 @@ impl Nic {
#[derive(Debug, Error)]
enum WorkerError {
#[error("packet error")]
Packet(#[from] PacketError),
Packet(#[source] PacketError),
#[error("unexpected packet order: {0}")]
UnexpectedPacketOrder(#[from] PacketOrderError),
#[error("invalid gpadl id")]
InvalidGpadlId(#[from] UnknownGpadlId),
UnexpectedPacketOrder(#[source] PacketOrderError),
#[error("unknown rndis message type: {0}")]
UnknownRndisMessageType(u32),
#[error("memory access error")]
Expand Down Expand Up @@ -1845,17 +1851,19 @@ impl From<task_control::Cancelled> for WorkerError {
#[derive(Debug, Error)]
enum OpenError {
#[error("error establishing ring buffer")]
Ring(#[from] vmbus_channel::gpadl_ring::Error),
Ring(#[source] vmbus_channel::gpadl_ring::Error),
#[error("error establishing vmbus queue")]
Queue(#[from] queue::Error),
Queue(#[source] queue::Error),
}

#[derive(Debug, Error)]
enum PacketError {
#[error("UnknownType {0}")]
UnknownType(u32),
#[error("Access")]
Access(#[from] AccessError),
Access(#[source] AccessError),
#[error("ExternalData")]
ExternalData(#[source] ExternalDataError),
#[error("InvalidSendBufferIndex")]
InvalidSendBufferIndex,
}
Expand Down Expand Up @@ -2011,7 +2019,9 @@ fn parse_packet<'a, T: RingMem>(
Ok(Packet {
data,
transaction_id: packet.transaction_id(),
external_data: packet.read_external_ranges().map_err(PacketError::Access)?,
external_data: packet
.read_external_ranges()
.map_err(PacketError::ExternalData)?,
send_buffer_suballocation,
})
}
Expand Down Expand Up @@ -2224,7 +2234,7 @@ impl<T: RingMem> NetChannel<T> {
control.control_messages_len += reader.len();
control.control_messages.push_back(ControlMessage {
message_type,
data: reader.read_all().map_err(WorkerError::Access)?.into(),
data: reader.read_all()?.into(),
});

false
Expand All @@ -2251,10 +2261,7 @@ impl<T: RingMem> NetChannel<T> {
.find(|r| !r.is_empty())
.ok_or(WorkerError::RndisMessageTooSmall)?;
let mut data = reader.into_inner();
let request: rndisprot::Packet = headers
.reader(mem)
.read_plain()
.map_err(WorkerError::Access)?;
let request: rndisprot::Packet = headers.reader(mem).read_plain()?;
if request.num_oob_data_elements != 0
|| request.oob_data_length != 0
|| request.oob_data_offset != 0
Expand Down Expand Up @@ -2287,8 +2294,7 @@ impl<T: RingMem> NetChannel<T> {
)
.ok_or(WorkerError::RndisMessageTooSmall)?;
while !ppi.is_empty() {
let h: rndisprot::PerPacketInfo =
ppi.reader(mem).read_plain().map_err(WorkerError::Access)?;
let h: rndisprot::PerPacketInfo = ppi.reader(mem).read_plain()?;
if h.size == 0 {
return Err(WorkerError::RndisMessageTooSmall);
}
Expand All @@ -2300,8 +2306,7 @@ impl<T: RingMem> NetChannel<T> {
.ok_or(WorkerError::RndisMessageTooSmall)?;
match h.typ {
rndisprot::PPI_TCP_IP_CHECKSUM => {
let n: rndisprot::TxTcpIpChecksumInfo =
d.reader(mem).read_plain().map_err(WorkerError::Access)?;
let n: rndisprot::TxTcpIpChecksumInfo = d.reader(mem).read_plain()?;

metadata.offload_tcp_checksum =
(n.is_ipv4() || n.is_ipv6()) && n.tcp_checksum();
Expand Down Expand Up @@ -2332,8 +2337,7 @@ impl<T: RingMem> NetChannel<T> {
}
}
rndisprot::PPI_LSO => {
let n: rndisprot::TcpLsoInfo =
d.reader(mem).read_plain().map_err(WorkerError::Access)?;
let n: rndisprot::TcpLsoInfo = d.reader(mem).read_plain()?;

metadata.offload_tcp_segmentation = true;
metadata.offload_tcp_checksum = true;
Expand Down Expand Up @@ -2644,8 +2648,7 @@ impl<T: RingMem> NetChannel<T> {
return Err(WorkerError::InvalidRndisState);
}

let request: rndisprot::InitializeRequest =
reader.read_plain().map_err(WorkerError::Access)?;
let request: rndisprot::InitializeRequest = reader.read_plain()?;

tracing::trace!(
?request,
Expand All @@ -2672,8 +2675,7 @@ impl<T: RingMem> NetChannel<T> {
af_list_offset: 0,
af_list_size: 0,
},
)
.map_err(WorkerError::Access)?;
)?;
self.send_rndis_control_message(buffers, id, message_length)?;
if let PrimaryChannelGuestVfState::Available { vfid } = primary.guest_vf_state {
if self.guest_vf_is_available(
Expand Down Expand Up @@ -2703,8 +2705,7 @@ impl<T: RingMem> NetChannel<T> {
}
}
rndisprot::MESSAGE_TYPE_QUERY_MSG => {
let request: rndisprot::QueryRequest =
reader.read_plain().map_err(WorkerError::Access)?;
let request: rndisprot::QueryRequest = reader.read_plain()?;

tracing::trace!(?request, "handling control message MESSAGE_TYPE_QUERY_MSG");

Expand Down Expand Up @@ -2734,13 +2735,11 @@ impl<T: RingMem> NetChannel<T> {
information_buffer_offset: size_of::<rndisprot::QueryComplete>() as u32,
information_buffer_length: tx as u32,
},
)
.map_err(WorkerError::Access)?;
)?;
self.send_rndis_control_message(buffers, id, message_length)?;
}
rndisprot::MESSAGE_TYPE_SET_MSG => {
let request: rndisprot::SetRequest =
reader.read_plain().map_err(WorkerError::Access)?;
let request: rndisprot::SetRequest = reader.read_plain()?;

tracing::trace!(?request, "handling control message MESSAGE_TYPE_SET_MSG");

Expand All @@ -2767,8 +2766,7 @@ impl<T: RingMem> NetChannel<T> {
request_id: request.request_id,
status,
},
)
.map_err(WorkerError::Access)?;
)?;
self.send_rndis_control_message(buffers, id, message_length)?;
}
rndisprot::MESSAGE_TYPE_RESET_MSG => {
Expand All @@ -2778,8 +2776,7 @@ impl<T: RingMem> NetChannel<T> {
return Err(WorkerError::RndisMessageTypeNotImplemented)
}
rndisprot::MESSAGE_TYPE_KEEPALIVE_MSG => {
let request: rndisprot::KeepaliveRequest =
reader.read_plain().map_err(WorkerError::Access)?;
let request: rndisprot::KeepaliveRequest = reader.read_plain()?;

tracing::trace!(
?request,
Expand All @@ -2794,8 +2791,7 @@ impl<T: RingMem> NetChannel<T> {
request_id: request.request_id,
status: rndisprot::STATUS_SUCCESS,
},
)
.map_err(WorkerError::Access)?;
)?;
self.send_rndis_control_message(buffers, id, message_length)?;
}
rndisprot::MESSAGE_TYPE_SET_EX_MSG => {
Expand Down Expand Up @@ -4355,7 +4351,9 @@ impl<T: 'static + RingMem> NetChannel<T> {
) -> Result<Option<Packet<'a>>, WorkerError> {
let (mut read, _) = self.queue.split();
let packet = match read.try_read() {
Ok(packet) => parse_packet(&packet, send_buffer, version)?,
Ok(packet) => {
parse_packet(&packet, send_buffer, version).map_err(WorkerError::Packet)?
}
Err(queue::TryReadError::Empty) => return Ok(None),
Err(queue::TryReadError::Queue(err)) => return Err(err.into()),
};
Expand All @@ -4371,7 +4369,8 @@ impl<T: 'static + RingMem> NetChannel<T> {
) -> Result<Packet<'a>, WorkerError> {
let (mut read, _) = self.queue.split();
let mut packet_ref = read.read().await?;
let packet = parse_packet(&packet_ref, send_buffer, version)?;
let packet =
parse_packet(&packet_ref, send_buffer, version).map_err(WorkerError::Packet)?;
if matches!(packet.data, PacketData::RndisPacket(_)) {
// In WorkerState::Init if an rndis packet is received, assume it is MESSAGE_TYPE_INITIALIZE_MSG
tracing::trace!(target: "netvsp/vmbus", "detected rndis initialization message");
Expand Down Expand Up @@ -5162,8 +5161,7 @@ impl<T: 'static + RingMem> NetChannel<T> {
// message.
while reader.len() > 0 {
let mut this_reader = reader.clone();
let header: rndisprot::MessageHeader =
this_reader.read_plain().map_err(WorkerError::Access)?;
let header: rndisprot::MessageHeader = this_reader.read_plain()?;
if self.handle_rndis_message(
buffers,
state,
Expand All @@ -5174,9 +5172,7 @@ impl<T: 'static + RingMem> NetChannel<T> {
)? {
num_packets += 1;
}
reader
.skip(header.message_length as usize)
.map_err(WorkerError::Access)?;
reader.skip(header.message_length as usize)?;
}

Ok(num_packets)
Expand Down
7 changes: 3 additions & 4 deletions vm/devices/net/netvsp/src/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -649,11 +649,10 @@ impl<'a> TestNicChannel<'a> {
let external_ranges = if let Some(id) = data.transfer_buffer_id() {
assert_eq!(id, 0);

data.read_transfer_ranges(recv_buf.iter())
data.read_transfer_ranges(recv_buf.iter()).unwrap()
} else {
data.read_external_ranges()
}
.unwrap();
data.read_external_ranges().unwrap()
};
let mut direct_reader =
PagedRanges::new(external_ranges.iter()).reader(&mem);

Expand Down
9 changes: 4 additions & 5 deletions vm/devices/storage/storvsp/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ use thiserror::Error;
use tracing_helpers::ErrorValueExt;
use unicycle::FuturesUnordered;
use vmbus_async::queue;
use vmbus_async::queue::ExternalDataError;
use vmbus_async::queue::IncomingPacket;
use vmbus_async::queue::OutgoingPacket;
use vmbus_async::queue::Queue;
Expand Down Expand Up @@ -294,9 +295,9 @@ enum PacketError {
#[error("Invalid data transfer length")]
InvalidDataTransferLength,
#[error("Access error: {0}")]
Access(AccessError),
Access(#[source] AccessError),
#[error("Range error")]
Range,
Range(#[source] ExternalDataError),
}

#[derive(Debug, Default, Clone)]
Expand Down Expand Up @@ -391,9 +392,7 @@ fn parse_packet<T: RingMem>(
let request_buf = &mut full_request.request.as_bytes_mut()[..request_size];
reader.read(request_buf).map_err(PacketError::Access)?;

let buf = packet
.read_external_ranges()
.map_err(|_| PacketError::Range)?;
let buf = packet.read_external_ranges().map_err(PacketError::Range)?;

full_request.external_data = Range::new(buf, &full_request.request)
.ok_or(PacketError::InvalidDataTransferLength)?;
Expand Down
Loading

0 comments on commit e5c5db1

Please sign in to comment.