From b510ec9425336acc62626e30090ba35fb124d57f Mon Sep 17 00:00:00 2001 From: John Starks Date: Thu, 5 Dec 2024 21:05:43 +0000 Subject: [PATCH 1/4] mesh_channel: hide RPC response receiver from callers In the future, RPC channels will typically (always?) use blocking senders to provide backpressure and better support early cancellation. To prepare for this, wrap the RPC response oneshot channel in a separate type which in the future can contain both a blocking send request and the response receiver. Also, hide the internal fields of `Rpc` to give us flexibility there in the future, too. --- openhcl/profiler_worker/src/lib.rs | 4 +- openhcl/underhill_core/src/diag.rs | 4 +- openhcl/underhill_core/src/dispatch/mod.rs | 13 +- .../src/dispatch/vtl2_settings_worker.rs | 3 +- openhcl/underhill_core/src/get_tracing.rs | 8 +- openhcl/underhill_core/src/lib.rs | 20 +-- openhcl/underhill_core/src/worker.rs | 2 +- openvmm/hvlite_core/src/worker/dispatch.rs | 20 +-- openvmm/hvlite_helpers/src/underhill.rs | 7 +- .../src/mapping_manager/va_mapper.rs | 3 +- openvmm/membacking/src/region_manager.rs | 4 +- openvmm/openvmm_entry/src/lib.rs | 10 +- openvmm/openvmm_entry/src/ttrpc/mod.rs | 9 +- petri/pipette_client/src/lib.rs | 3 +- petri/pipette_client/src/process.rs | 4 +- petri/pipette_client/src/send.rs | 25 ++- petri/src/vm/runtime.rs | 8 +- petri/src/worker.rs | 9 +- support/mesh/mesh_channel/src/error.rs | 33 ---- support/mesh/mesh_channel/src/rpc.rs | 158 +++++++++++++++--- support/mesh/mesh_worker/src/worker.rs | 19 +-- vm/devices/get/guest_crash_device/src/lib.rs | 9 +- .../get/guest_emulation_device/src/lib.rs | 31 ++-- .../src/process_loop.rs | 52 +++--- vm/devices/hyperv_ic/src/shutdown.rs | 9 +- .../net/net_consomme/consomme/src/lib.rs | 3 +- vm/devices/net/net_packet_capture/src/lib.rs | 4 +- vm/devices/net/netvsp/src/test.rs | 14 +- .../disk_nvme/nvme_driver/src/queue_pair.rs | 21 +-- .../storage/nvme/src/workers/coordinator.rs | 5 +- vm/devices/vmbus/vmbus_channel/src/channel.rs | 20 +-- vm/devices/vmbus/vmbus_channel/src/offer.rs | 19 ++- vm/devices/vmbus/vmbus_client/src/lib.rs | 73 ++++---- vm/devices/vmbus/vmbus_relay/src/hvsock.rs | 16 ++ vm/devices/vmbus/vmbus_relay/src/lib.rs | 4 +- vm/devices/vmbus/vmbus_server/src/lib.rs | 37 ++-- .../vmbus_server/src/proxyintegration.rs | 14 +- vm/vmgs/vmgs_broker/src/client.rs | 28 +++- vmm_core/src/partition_unit/vp_set.rs | 3 +- vmm_core/state_unit/src/lib.rs | 10 +- workers/debug_worker/src/gdb/mod.rs | 10 +- workers/debug_worker/src/gdb/targets/base.rs | 43 +++-- workers/debug_worker/src/lib.rs | 4 +- workers/vnc_worker/src/lib.rs | 6 +- 44 files changed, 454 insertions(+), 347 deletions(-) diff --git a/openhcl/profiler_worker/src/lib.rs b/openhcl/profiler_worker/src/lib.rs index 995fde7b8c..6d121f662a 100644 --- a/openhcl/profiler_worker/src/lib.rs +++ b/openhcl/profiler_worker/src/lib.rs @@ -97,8 +97,8 @@ impl Worker for ProfilerWorker { WorkerRpc::Stop => { break; } - WorkerRpc::Restart(response) => { - response.send(Err(RemoteError::new(anyhow::anyhow!("not supported")))); + WorkerRpc::Restart(rpc) => { + rpc.complete(Err(RemoteError::new(anyhow::anyhow!("not supported")))); } WorkerRpc::Inspect(_deferred) => {} }, diff --git a/openhcl/underhill_core/src/diag.rs b/openhcl/underhill_core/src/diag.rs index 0b5e7aabc1..2a3b72ae80 100644 --- a/openhcl/underhill_core/src/diag.rs +++ b/openhcl/underhill_core/src/diag.rs @@ -65,8 +65,8 @@ impl Worker for DiagWorker { }; match msg { WorkerRpc::Stop => break Ok(()), - WorkerRpc::Restart(response) => { - response.send(Err(RemoteError::new(anyhow::anyhow!("not supported")))); + WorkerRpc::Restart(rpc) => { + rpc.complete(Err(RemoteError::new(anyhow::anyhow!("not supported")))); } WorkerRpc::Inspect(_) => {} } diff --git a/openhcl/underhill_core/src/dispatch/mod.rs b/openhcl/underhill_core/src/dispatch/mod.rs index 6e503aeca2..e1e2e85c63 100644 --- a/openhcl/underhill_core/src/dispatch/mod.rs +++ b/openhcl/underhill_core/src/dispatch/mod.rs @@ -33,7 +33,6 @@ use hyperv_ic_resources::shutdown::ShutdownType; use igvm_defs::MemoryMapEntryType; use inspect::Inspect; use mesh::error::RemoteError; -use mesh::error::RemoteResult; use mesh::rpc::FailableRpc; use mesh::rpc::Rpc; use mesh::rpc::RpcSend; @@ -183,7 +182,7 @@ pub(crate) struct LoadedVm { } pub struct LoadedVmState { - pub restart_response: mesh::OneshotSender>, + pub restart_rpc: FailableRpc<(), T>, pub servicing_state: ServicingState, pub vm_rpc: mesh::Receiver, pub control_send: mesh::Sender, @@ -262,16 +261,16 @@ impl LoadedVm { Event::WorkerRpcGone => break None, Event::WorkerRpc(message) => match message { WorkerRpc::Stop => break None, - WorkerRpc::Restart(response) => { + WorkerRpc::Restart(rpc) => { let state = async { let running = self.stop().await; match self.save(None, false).await { - Ok(servicing_state) => Some((response, servicing_state)), + Ok(servicing_state) => Some((rpc, servicing_state)), Err(err) => { if running { self.start(None).await; } - response.send(Err(RemoteError::new(err))); + rpc.complete(Err(RemoteError::new(err))); None } } @@ -279,9 +278,9 @@ impl LoadedVm { .instrument(tracing::info_span!("restart")) .await; - if let Some((response, servicing_state)) = state { + if let Some((rpc, servicing_state)) = state { break Some(LoadedVmState { - restart_response: response, + restart_rpc: rpc, servicing_state, vm_rpc, control_send: self.control_send.lock().take().unwrap(), diff --git a/openhcl/underhill_core/src/dispatch/vtl2_settings_worker.rs b/openhcl/underhill_core/src/dispatch/vtl2_settings_worker.rs index 9654656c11..0774e8403d 100644 --- a/openhcl/underhill_core/src/dispatch/vtl2_settings_worker.rs +++ b/openhcl/underhill_core/src/dispatch/vtl2_settings_worker.rs @@ -18,6 +18,7 @@ use ide_resources::GuestMedia; use ide_resources::IdeControllerConfig; use ide_resources::IdeDeviceConfig; use ide_resources::IdePath; +use mesh::rpc::RpcError; use mesh::rpc::Rpc; use mesh::rpc::RpcSend; use mesh::CancelContext; @@ -61,7 +62,7 @@ use vm_resource::ResourceResolver; #[derive(Error, Debug)] enum Error<'a> { #[error("RPC error")] - Rpc(#[source] mesh::RecvError), + Rpc(#[source] RpcError), #[error("cannot add/remove storage controllers at runtime")] StorageCannotAddRemoveControllerAtRuntime, #[error("Striping devices don't support runtime change")] diff --git a/openhcl/underhill_core/src/get_tracing.rs b/openhcl/underhill_core/src/get_tracing.rs index 0ad3b0d1f7..c9c8c46136 100644 --- a/openhcl/underhill_core/src/get_tracing.rs +++ b/openhcl/underhill_core/src/get_tracing.rs @@ -152,13 +152,13 @@ impl GetTracingBackend { ) .merge(); - let flush_response = loop { + let flush_rpc = loop { let trace_type = streams.next().await.unwrap(); match trace_type { Event::Trace(data) => { write.send(&data).await.ok(); } - Event::Flush(Rpc((), response)) => break Some(response), + Event::Flush(rpc) => break Some(rpc), Event::Done => break None, } }; @@ -174,8 +174,8 @@ impl GetTracingBackend { // Wait for the host to read everything. write.wait_empty().await.ok(); - if let Some(resp) = flush_response { - resp.send(()); + if let Some(rpc) = flush_rpc { + rpc.complete(()); } else { break; } diff --git a/openhcl/underhill_core/src/lib.rs b/openhcl/underhill_core/src/lib.rs index f9356451d7..667213715b 100644 --- a/openhcl/underhill_core/src/lib.rs +++ b/openhcl/underhill_core/src/lib.rs @@ -470,7 +470,7 @@ async fn run_control( Control(ControlRequest), } - let mut restart_response = None; + let mut restart_rpc = None; loop { let event = { let mut stream = ( @@ -551,7 +551,7 @@ async fn run_control( }; let r = async { - if restart_response.is_some() { + if restart_rpc.is_some() { anyhow::bail!("previous restart still in progress"); } @@ -568,7 +568,7 @@ async fn run_control( rpc.complete(r.map_err(RemoteError::new)); } else { state = ControlState::Restarting; - restart_response = Some(rpc.1); + restart_rpc = Some(rpc); } } diag_server::DiagRequest::Pause(rpc) => { @@ -647,7 +647,7 @@ async fn run_control( } #[cfg(feature = "profiler")] diag_server::DiagRequest::Profile(rpc) => { - let Rpc(rpc_params, rpc_sender) = rpc; + let (rpc_params, rpc_sender) = rpc.split(); // Create profiler host if there is none created before if profiler_host.is_none() { match launch_mesh_host(mesh, "profiler", Some(tracing.tracer())) @@ -658,7 +658,7 @@ async fn run_control( profiler_host = Some(host); } Err(e) => { - rpc_sender.send(Err(RemoteError::new(e))); + rpc_sender.complete(Err(RemoteError::new(e))); continue; } } @@ -680,7 +680,7 @@ async fn run_control( profiler_worker = worker; } Err(e) => { - rpc_sender.send(Err(RemoteError::new(e))); + rpc_sender.complete(Err(RemoteError::new(e))); continue; } } @@ -695,7 +695,7 @@ async fn run_control( .and_then(|result| result.context("profiler worker failed")) .map_err(RemoteError::new); - rpc_sender.send(result); + rpc_sender.complete(result); }) .detach(); } @@ -703,9 +703,9 @@ async fn run_control( } Event::Worker(event) => match event { WorkerEvent::Started => { - if let Some(response) = restart_response.take() { + if let Some(response) = restart_rpc.take() { tracing::info!("restart complete"); - response.send(Ok(())); + response.complete(Ok(())); } else { tracing::info!("vm worker started"); } @@ -719,7 +719,7 @@ async fn run_control( } WorkerEvent::RestartFailed(err) => { tracing::error!(error = &err as &dyn std::error::Error, "restart failed"); - restart_response.take().unwrap().send(Err(err)); + restart_rpc.take().unwrap().complete(Err(err)); state = ControlState::Started; } }, diff --git a/openhcl/underhill_core/src/worker.rs b/openhcl/underhill_core/src/worker.rs index 1ec5e794e2..abede74c40 100644 --- a/openhcl/underhill_core/src/worker.rs +++ b/openhcl/underhill_core/src/worker.rs @@ -430,7 +430,7 @@ impl Worker for UnderhillVmWorker { }; tracing::info!("sending worker restart state"); - state.restart_response.send(Ok(RestartState { + state.restart_rpc.complete(Ok(RestartState { params, servicing_state: state.servicing_state, })) diff --git a/openvmm/hvlite_core/src/worker/dispatch.rs b/openvmm/hvlite_core/src/worker/dispatch.rs index 1f7c87bd55..035c9d858f 100644 --- a/openvmm/hvlite_core/src/worker/dispatch.rs +++ b/openvmm/hvlite_core/src/worker/dispatch.rs @@ -59,7 +59,6 @@ use memory_range::MemoryRange; use mesh::error::RemoteError; use mesh::payload::message::ProtobufMessage; use mesh::payload::Protobuf; -use mesh::rpc::Rpc; use mesh::MeshPayload; use mesh_worker::Worker; use mesh_worker::WorkerId; @@ -2477,7 +2476,7 @@ impl LoadedVm { pub async fn run( mut self, driver: &impl Spawn, - mut rpc: mesh::Receiver, + mut rpc_recv: mesh::Receiver, mut worker_rpc: mesh::Receiver>, ) { enum Event { @@ -2508,7 +2507,7 @@ impl LoadedVm { loop { let event: Event = { - let a = rpc.recv().map(Event::VmRpc); + let a = rpc_recv.recv().map(Event::VmRpc); let b = worker_rpc.recv().map(Event::WorkerRpc); (a, b).race().await }; @@ -2517,7 +2516,7 @@ impl LoadedVm { Event::WorkerRpc(Err(_)) => break, Event::WorkerRpc(Ok(message)) => match message { WorkerRpc::Stop => break, - WorkerRpc::Restart(response) => { + WorkerRpc::Restart(rpc) => { let mut stopped = false; // First run the non-destructive operations. let r = async { @@ -2532,8 +2531,8 @@ impl LoadedVm { .await; match r { Ok((shared_memory, saved_state)) => { - response.send(Ok(self - .serialize(rpc, shared_memory, saved_state) + rpc.complete(Ok(self + .serialize(rpc_recv, shared_memory, saved_state) .await)); return; @@ -2542,7 +2541,7 @@ impl LoadedVm { if stopped { self.state_units.start().await; } - response.send(Err(RemoteError::new(err))); + rpc.complete(Err(RemoteError::new(err))); } } } @@ -2618,16 +2617,17 @@ impl LoadedVm { }) .await } - VmRpc::ConnectHvsock(Rpc((mut ctx, service_id, vtl), response)) => { + VmRpc::ConnectHvsock(rpc) => { + let ((mut ctx, service_id, vtl), response) = rpc.split(); if let Some(relay) = self.hvsock_relay(vtl) { let fut = relay.connect(&mut ctx, service_id); driver .spawn("vmrpc-hvsock-connect", async move { - response.send(fut.await.map_err(RemoteError::new)) + response.complete(fut.await.map_err(RemoteError::new)) }) .detach(); } else { - response.send(Err(RemoteError::new(anyhow::anyhow!( + response.complete(Err(RemoteError::new(anyhow::anyhow!( "hvsock is not available" )))); } diff --git a/openvmm/hvlite_helpers/src/underhill.rs b/openvmm/hvlite_helpers/src/underhill.rs index 033f6bc568..8121c2a124 100644 --- a/openvmm/hvlite_helpers/src/underhill.rs +++ b/openvmm/hvlite_helpers/src/underhill.rs @@ -6,7 +6,6 @@ use anyhow::Context; use get_resources::ged::GuestEmulationRequest; use hvlite_defs::rpc::VmRpc; -use mesh::error::RemoteResultExt; use mesh::rpc::RpcSend; /// Replace the running version of Underhill. @@ -28,9 +27,8 @@ pub async fn service_underhill( // blocked while waiting for the guest. tracing::debug!("waiting for guest to send saved state"); let r = send - .call(GuestEmulationRequest::SaveGuestVtl2State, ()) + .call_failable(GuestEmulationRequest::SaveGuestVtl2State, ()) .await - .flatten() .context("failed to save VTL2 state"); if r.is_err() { @@ -51,9 +49,8 @@ pub async fn service_underhill( // // TODO: event driven, cancellable. tracing::debug!("waiting for VTL0 to start"); - send.call(GuestEmulationRequest::WaitForVtl0Start, ()) + send.call_failable(GuestEmulationRequest::WaitForVtl0Start, ()) .await - .flatten() .context("vtl0 start failed")?; Ok(()) diff --git a/openvmm/membacking/src/mapping_manager/va_mapper.rs b/openvmm/membacking/src/mapping_manager/va_mapper.rs index d70c211d29..27e4ef918e 100644 --- a/openvmm/membacking/src/mapping_manager/va_mapper.rs +++ b/openvmm/membacking/src/mapping_manager/va_mapper.rs @@ -32,6 +32,7 @@ use futures::executor::block_on; use guestmem::GuestMemoryAccess; use guestmem::PageFaultAction; use memory_range::MemoryRange; +use mesh::rpc::RpcError; use mesh::rpc::RpcSend; use parking_lot::Mutex; use sparse_mmap::SparseMapping; @@ -156,7 +157,7 @@ impl MapperTask { #[derive(Debug, Error)] pub enum VaMapperError { #[error("failed to communicate with the memory manager")] - MemoryManagerGone(#[source] mesh::RecvError), + MemoryManagerGone(#[source] RpcError), #[error("failed to reserve address space")] Reserve(#[source] std::io::Error), } diff --git a/openvmm/membacking/src/region_manager.rs b/openvmm/membacking/src/region_manager.rs index 4be04c53b0..aca030db43 100644 --- a/openvmm/membacking/src/region_manager.rs +++ b/openvmm/membacking/src/region_manager.rs @@ -628,9 +628,7 @@ impl RegionHandle { impl Drop for RegionHandle { fn drop(&mut self) { if let Some(id) = self.id { - let (send, _recv) = mesh::oneshot(); - self.req_send - .send(RegionRequest::RemoveRegion(Rpc(id, send))); + let _recv = self.req_send.call(RegionRequest::RemoveRegion, id); // Don't wait for the response. } } diff --git a/openvmm/openvmm_entry/src/lib.rs b/openvmm/openvmm_entry/src/lib.rs index be3da2dee6..01326627f6 100644 --- a/openvmm/openvmm_entry/src/lib.rs +++ b/openvmm/openvmm_entry/src/lib.rs @@ -72,10 +72,10 @@ use inspect::InspectMut; use inspect::InspectionBuilder; use io::Read; use mesh::error::RemoteError; +use mesh::rpc::RpcError; use mesh::rpc::Rpc; use mesh::rpc::RpcSend; use mesh::CancelContext; -use mesh::RecvError; use mesh_worker::launch_local_worker; use mesh_worker::WorkerEvent; use mesh_worker::WorkerHandle; @@ -2092,7 +2092,7 @@ async fn run_control(driver: &DefaultDriver, mesh: &VmmMesh, opt: Options) -> an }) .unwrap(); - let mut state_change_task = None::>>; + let mut state_change_task = None::>>; let mut pulse_save_restore_interval: Option = None; let mut pending_shutdown = None; @@ -2113,8 +2113,8 @@ async fn run_control(driver: &DefaultDriver, mesh: &VmmMesh, opt: Options) -> an PulseSaveRestore, Worker(WorkerEvent), VncWorker(WorkerEvent), - StateChange(Result), - ShutdownResult(Result), + StateChange(Result), + ShutdownResult(Result), } let mut console_command_recv = console_command_recv @@ -2360,7 +2360,7 @@ async fn run_control(driver: &DefaultDriver, mesh: &VmmMesh, opt: Options) -> an fn state_change( driver: impl Spawn, vm_rpc: &mesh::Sender, - state_change_task: &mut Option>>, + state_change_task: &mut Option>>, f: impl FnOnce(Rpc<(), U>) -> VmRpc, g: impl FnOnce(U) -> StateChange + 'static + Send, ) { diff --git a/openvmm/openvmm_entry/src/ttrpc/mod.rs b/openvmm/openvmm_entry/src/ttrpc/mod.rs index 3a699ac490..579138f771 100644 --- a/openvmm/openvmm_entry/src/ttrpc/mod.rs +++ b/openvmm/openvmm_entry/src/ttrpc/mod.rs @@ -33,7 +33,6 @@ use inspect_proto::InspectResponse2; use inspect_proto::InspectService; use inspect_proto::UpdateResponse2; use mesh::error::RemoteError; -use mesh::rpc::Rpc; use mesh::rpc::RpcSend; use mesh::CancelReason; use mesh::MeshPayload; @@ -197,7 +196,7 @@ impl VmService { }, request = recv.recv().fuse() => { match request { - Ok(WorkerRpc::Restart(response)) => response.send(Err(RemoteError::new(anyhow::anyhow!("not supported")))), + Ok(WorkerRpc::Restart(rpc)) => rpc.complete(Err(RemoteError::new(anyhow::anyhow!("not supported")))), Ok(WorkerRpc::Inspect(_)) => (), Ok(WorkerRpc::Stop) => { tracing::info!("ttrpc worker stopping"); @@ -605,14 +604,12 @@ impl VmService { } fn pause_vm(&mut self, vm: &Vm) -> impl Future> { - let (send, recv) = mesh::oneshot(); - vm.worker_rpc.send(VmRpc::Pause(Rpc((), send))); + let recv = vm.worker_rpc.call(VmRpc::Pause, ()); async move { recv.await.map(drop).context("pause failed") } } fn resume_vm(&mut self, vm: &Vm) -> impl Future> { - let (send, recv) = mesh::oneshot(); - vm.worker_rpc.send(VmRpc::Resume(Rpc((), send))); + let recv = vm.worker_rpc.call(VmRpc::Resume, ()); async move { recv.await.map(drop).context("resume failed") } } diff --git a/petri/pipette_client/src/lib.rs b/petri/pipette_client/src/lib.rs index 8bbc4f1c94..73fe2e6f50 100644 --- a/petri/pipette_client/src/lib.rs +++ b/petri/pipette_client/src/lib.rs @@ -22,6 +22,7 @@ use futures::AsyncWriteExt; use futures::StreamExt; use futures::TryFutureExt; use futures_concurrency::future::TryJoin; +use mesh::rpc::RpcError; use mesh_remote::PointToPointMesh; use pal_async::task::Spawn; use pal_async::task::Task; @@ -81,7 +82,7 @@ impl PipetteClient { } /// Pings the agent to check if it's alive. - pub async fn ping(&self) -> Result<(), mesh::RecvError> { + pub async fn ping(&self) -> Result<(), RpcError> { self.send.call(PipetteRequest::Ping, ()).await } diff --git a/petri/pipette_client/src/process.rs b/petri/pipette_client/src/process.rs index 2c0c1536f8..83344c9ae5 100644 --- a/petri/pipette_client/src/process.rs +++ b/petri/pipette_client/src/process.rs @@ -9,7 +9,6 @@ use futures::executor::block_on; use futures::io::AllowStdIo; use futures::AsyncReadExt; use futures_concurrency::future::Join; -use mesh::error::RemoteResultExt; use mesh::pipe::ReadPipe; use mesh::pipe::WritePipe; use pipette_protocol::EnvPair; @@ -167,9 +166,8 @@ impl<'a> Command<'a> { let response = self .client .send - .call(PipetteRequest::Execute, request) + .call_failable(PipetteRequest::Execute, request) .await - .flatten() .with_context(|| format!("failed to execute {}", self.program))?; Ok(Child { diff --git a/petri/pipette_client/src/send.rs b/petri/pipette_client/src/send.rs index e824b2f1ca..fbae46b5d8 100644 --- a/petri/pipette_client/src/send.rs +++ b/petri/pipette_client/src/send.rs @@ -4,6 +4,7 @@ //! A thin wrapper around a `mesh::Sender` that provides //! useful error handling semantics. +use mesh::rpc::RpcError; use mesh::rpc::Rpc; use mesh::rpc::RpcSend; use mesh::CancelContext; @@ -20,14 +21,30 @@ impl PipetteSender { /// A wrapper around [`mesh::Sender::call`] that will sleep for 5 seconds on failure, /// allowing any additional work occurring on the system to hopefully complete. /// See also [`petri::PetriVm::wait_for_halt_or`] - pub(crate) async fn call(&self, f: F, input: I) -> Result + pub(crate) async fn call(&self, f: F, input: I) -> Result where F: FnOnce(Rpc) -> PipetteRequest, R: 'static + Send, { - let (result_send, result_recv) = mesh::oneshot(); - self.0.send_rpc(f(Rpc(input, result_send))); - let result = result_recv.await; + let result = self.0.call(f, input).await; + if result.is_err() { + tracing::warn!("Pipette request channel failed, sleeping for 5 seconds to let outstanding work finish"); + let mut c = CancelContext::new().with_timeout(Duration::from_secs(5)); + let _ = c.cancelled().await; + } + result + } + + /// A wrapper around [`mesh::Sender::call_failable`] that will sleep for 5 seconds on failure, + /// allowing any additional work occurring on the system to hopefully complete. + /// See also [`petri::PetriVm::wait_for_halt_or`] + pub(crate) async fn call_failable(&self, f: F, input: I) -> Result> + where + F: FnOnce(Rpc>) -> PipetteRequest, + T: 'static + Send, + E: 'static + Send, + { + let result = self.0.call_failable(f, input).await; if result.is_err() { tracing::warn!("Pipette request channel failed, sleeping for 5 seconds to let outstanding work finish"); let mut c = CancelContext::new().with_timeout(Duration::from_secs(5)); diff --git a/petri/src/vm/runtime.rs b/petri/src/vm/runtime.rs index f686eee338..38a98462b5 100644 --- a/petri/src/vm/runtime.rs +++ b/petri/src/vm/runtime.rs @@ -12,6 +12,7 @@ use futures::FutureExt; use futures_concurrency::future::Race; use hvlite_defs::rpc::PulseSaveRestoreError; use hyperv_ic_resources::shutdown::ShutdownRpc; +use mesh::rpc::RpcError; use mesh::rpc::RpcSend; use mesh::CancelContext; use mesh::Receiver; @@ -423,14 +424,15 @@ impl PetriVmInner { async fn verify_save_restore(&self) -> anyhow::Result<()> { for i in 0..2 { - let result = self.worker.pulse_save_restore().await?; + let result = self.worker.pulse_save_restore().await; match result { Ok(()) => {} - Err(PulseSaveRestoreError::ResetNotSupported) => { + Err(RpcError::Channel(err)) => return Err(err.into()), + Err(RpcError::Call(PulseSaveRestoreError::ResetNotSupported)) => { tracing::warn!("Reset not supported, could not test save + restore."); break; } - Err(PulseSaveRestoreError::Other(err)) => { + Err(RpcError::Call(PulseSaveRestoreError::Other(err))) => { return Err(anyhow::Error::from(err)) .context(format!("Save + restore {i} failed.")); } diff --git a/petri/src/worker.rs b/petri/src/worker.rs index 71d62e5d2a..77aafe24f2 100644 --- a/petri/src/worker.rs +++ b/petri/src/worker.rs @@ -6,6 +6,7 @@ use hvlite_defs::rpc::PulseSaveRestoreError; use hvlite_defs::rpc::VmRpc; use hvlite_defs::worker::VmWorkerParameters; use hvlite_defs::worker::VM_WORKER; +use mesh::rpc::RpcError; use mesh::rpc::RpcSend; use mesh_worker::WorkerHandle; use mesh_worker::WorkerHost; @@ -42,7 +43,7 @@ impl Worker { )) } - pub(crate) async fn resume(&self) -> Result { + pub(crate) async fn resume(&self) -> Result { self.rpc.call(VmRpc::Resume, ()).await } @@ -51,10 +52,8 @@ impl Worker { Ok(()) } - pub(crate) async fn pulse_save_restore( - &self, - ) -> Result, mesh::RecvError> { - self.rpc.call(VmRpc::PulseSaveRestore, ()).await + pub(crate) async fn pulse_save_restore(&self) -> Result<(), RpcError> { + self.rpc.call_failable(VmRpc::PulseSaveRestore, ()).await } pub(crate) async fn restart_openhcl( diff --git a/support/mesh/mesh_channel/src/error.rs b/support/mesh/mesh_channel/src/error.rs index bf32364180..3d9df098c4 100644 --- a/support/mesh/mesh_channel/src/error.rs +++ b/support/mesh/mesh_channel/src/error.rs @@ -3,12 +3,10 @@ //! Remotable errors. -use crate::RecvError; use mesh_protobuf::EncodeAs; use mesh_protobuf::Protobuf; use std::fmt; use std::fmt::Display; -use thiserror::Error; /// An error that can be remoted across a mesh channel. /// @@ -115,34 +113,3 @@ impl std::error::Error for DecodedError { /// Alias for a [`Result`] with a [`RemoteError`] error. pub type RemoteResult = Result; - -/// An error from an RPC call, via -/// [`RpcSend::call_failable`](super::rpc::RpcSend::call_failable). -#[derive(Debug, Error)] -pub enum RpcError { - #[error(transparent)] - Call(E), - #[error(transparent)] - Channel(RecvError), -} - -/// Extension trait to [`Result`] for folding `Result, RecvError>` -/// to `RpcError`. -pub trait RemoteResultExt { - type Flattened; - - /// Flattens the result into `RpcError`. - fn flatten(self) -> Self::Flattened; -} - -impl RemoteResultExt for Result, RecvError> { - type Flattened = Result>; - - fn flatten(self) -> Self::Flattened { - match self { - Ok(Ok(t)) => Ok(t), - Ok(Err(e)) => Err(RpcError::Call(e)), - Err(e) => Err(RpcError::Channel(e)), - } - } -} diff --git a/support/mesh/mesh_channel/src/rpc.rs b/support/mesh/mesh_channel/src/rpc.rs index 589309f722..11f397ea92 100644 --- a/support/mesh/mesh_channel/src/rpc.rs +++ b/support/mesh/mesh_channel/src/rpc.rs @@ -5,17 +5,18 @@ use super::error::RemoteResult; use crate::error::RemoteError; -use crate::error::RemoteResultExt; -use crate::error::RpcError; use crate::oneshot; use crate::OneshotReceiver; use crate::OneshotSender; +use crate::RecvError; use mesh_node::message::MeshField; use mesh_protobuf::Protobuf; +use std::convert::Infallible; use std::future::Future; use std::pin::Pin; use std::task::ready; use std::task::Poll; +use thiserror::Error; /// An RPC message for a request with input of type `I` and output of type `R`. /// The receiver of the message should process the request and return results @@ -25,12 +26,31 @@ use std::task::Poll; bound = "I: 'static + MeshField + Send, R: 'static + MeshField + Send", resource = "mesh_node::resource::Resource" )] -pub struct Rpc(pub I, pub OneshotSender); +pub struct Rpc(I, OneshotSender); /// An RPC message with a failable result. pub type FailableRpc = Rpc>; impl Rpc { + /// Returns a new RPC message with `input` and no one listening for the + /// result. + pub fn detached(input: I) -> Self { + let (result_send, _) = oneshot(); + Rpc(input, result_send) + } + + /// Returns the input to the RPC. + pub fn input(&self) -> &I { + &self.0 + } + + /// Splits the RPC into its input and an input-less RPC. This is useful when + /// the input is needed in one place but the RPC will be completed in + /// another. + pub fn split(self) -> (I, Rpc<(), R>) { + (self.0, Rpc((), self.1)) + } + /// Handles an RPC request by calling `f` and sending the result to the /// initiator. pub fn handle_sync(self, f: F) @@ -101,12 +121,12 @@ impl Rpc> { } /// A trait implemented by objects that can send RPC requests. -pub trait RpcSend { +pub trait RpcSend: Sized { /// The message type for this sender. type Message; /// Send an RPC request. - fn send_rpc(&self, message: Self::Message); + fn send_rpc(self, message: Self::Message); /// Issues a request and returns a channel to receive the result. /// @@ -127,14 +147,14 @@ pub trait RpcSend { /// assert_eq!(send.call(Request::Add, (3, 4)).await.unwrap(), 7); /// } /// ``` - fn call(&self, f: F, input: I) -> OneshotReceiver + fn call(self, f: F, input: I) -> PendingRpc where F: FnOnce(Rpc) -> Self::Message, R: 'static + Send, { let (result_send, result_recv) = oneshot(); self.send_rpc(f(Rpc(input, result_send))); - result_recv + PendingRpc(result_recv) } /// Issues a request and returns an object to receive the result. @@ -142,55 +162,145 @@ pub trait RpcSend { /// This is like [`RpcSend::call`], but for RPCs that return a [`Result`]. /// The returned object combines the channel error and the call's error into /// a single [`RpcError`] type, which makes it easier to handle errors. - fn call_failable(&self, f: F, input: I) -> RpcResultReceiver> + fn call_failable(self, f: F, input: I) -> PendingFailableRpc where F: FnOnce(Rpc>) -> Self::Message, T: 'static + Send, E: 'static + Send, { - RpcResultReceiver(self.call(f, input)) + PendingFailableRpc(self.call(f, input)) + } +} + +/// A trait implemented by objects that can try to send RPC requests but may +/// fail. +pub trait TryRpcSend: Sized { + /// The message type for this sender. + type Message; + /// The error type returned when sending an RPC request fails. + type Error; + + /// Tries to send an RPC request. + fn try_send_rpc(self, message: Self::Message) -> Result<(), Self::Error>; + + /// Issues a request and returns a channel to receive the result. + /// + /// `f` maps an [`Rpc`] object to the message type and is often an enum + /// variant name. + /// + /// `input` is the input to the call. + /// + /// # Example + /// + /// ```rust + /// # use mesh_channel::rpc::{Rpc, RpcSend}; + /// # use mesh_channel::Sender; + /// enum Request { + /// Add(Rpc<(u32, u32), u32>), + /// } + /// async fn add(send: &Sender) { + /// assert_eq!(send.call(Request::Add, (3, 4)).await.unwrap(), 7); + /// } + /// ``` + fn try_call(self, f: F, input: I) -> Result, Self::Error> + where + F: FnOnce(Rpc) -> Self::Message, + R: 'static + Send, + { + let (result_send, result_recv) = oneshot(); + self.try_send_rpc(f(Rpc(input, result_send)))?; + Ok(PendingRpc(result_recv)) + } + + /// Issues a request and returns an object to receive the result. + /// + /// This is like [`TryRpcSend::try_call`], but for RPCs that return a + /// [`Result`]. The returned object combines the channel error and the + /// call's error into a single [`RpcError`] type, which makes it easier to + /// handle errors. + fn try_call_failable( + self, + f: F, + input: I, + ) -> Result, Self::Error> + where + F: FnOnce(Rpc>) -> Self::Message, + T: 'static + Send, + E: 'static + Send, + { + Ok(PendingFailableRpc(self.try_call(f, input)?)) + } +} + +/// An error from an RPC call, via +/// [`RpcSend::call_failable`] or [`RpcSend::call`]. +#[derive(Debug, Error)] +pub enum RpcError { + #[error(transparent)] + Call(E), + #[error(transparent)] + Channel(RecvError), +} + +/// The result future of an [`RpcSend::call`] call. +#[must_use] +#[derive(Debug)] +pub struct PendingRpc(OneshotReceiver); + +impl Future for PendingRpc { + type Output = Result>; + + fn poll(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll { + Poll::Ready(ready!(Pin::new(&mut self.get_mut().0).poll(cx)).map_err(RpcError::Channel)) } } /// The result future of an [`RpcSend::call_failable`] call. #[must_use] -pub struct RpcResultReceiver(OneshotReceiver); +#[derive(Debug)] +pub struct PendingFailableRpc(PendingRpc>); -impl Future for RpcResultReceiver> { +impl Future for PendingFailableRpc { type Output = Result>; fn poll(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll { - Poll::Ready(ready!(Pin::new(&mut self.get_mut().0).poll(cx)).flatten()) + let r = ready!(Pin::new(&mut self.get_mut().0).poll(cx)); + match r { + Ok(Ok(t)) => Ok(t), + Ok(Err(e)) => Err(RpcError::Call(e)), + Err(RpcError::Channel(e)) => Err(RpcError::Channel(e)), + } + .into() } } -#[cfg(feature = "newchan")] -impl RpcSend for mesh_channel_core::Sender { +impl RpcSend for OneshotSender { type Message = T; - fn send_rpc(&self, message: T) { + fn send_rpc(self, message: T) { self.send(message); } } -#[cfg(not(feature = "newchan_spsc"))] -impl RpcSend for crate::Sender { +#[cfg(feature = "newchan")] +impl RpcSend for &mesh_channel_core::Sender { type Message = T; - fn send_rpc(&self, message: T) { + fn send_rpc(self, message: T) { self.send(message); } } -#[cfg(not(feature = "newchan_mpsc"))] -impl RpcSend for crate::MpscSender { +#[cfg(not(feature = "newchan_spsc"))] +impl RpcSend for &crate::Sender { type Message = T; fn send_rpc(&self, message: T) { self.send(message); } } -impl RpcSend for &T { - type Message = T::Message; - fn send_rpc(&self, message: T::Message) { - (*self).send_rpc(message); +#[cfg(not(feature = "newchan_mpsc"))] +impl RpcSend for &crate::MpscSender { + type Message = T; + fn send_rpc(self, message: T) { + self.send(message); } } diff --git a/support/mesh/mesh_worker/src/worker.rs b/support/mesh/mesh_worker/src/worker.rs index 2840094048..ca3c43e6c8 100644 --- a/support/mesh/mesh_worker/src/worker.rs +++ b/support/mesh/mesh_worker/src/worker.rs @@ -11,8 +11,8 @@ use futures::StreamExt; use futures_concurrency::stream::Merge; use inspect::Inspect; use mesh::error::RemoteError; -use mesh::error::RemoteResult; -use mesh::error::RemoteResultExt; +use mesh::rpc::FailableRpc; +use mesh::rpc::RpcSend; use mesh::MeshPayload; use std::fmt; use std::marker::PhantomData; @@ -85,7 +85,7 @@ pub enum WorkerRpc { Stop, /// Tear down and send the state necessary to restart on the provided /// channel. - Restart(mesh::OneshotSender>), + Restart(FailableRpc<(), T>), /// Inspect the worker. Inspect(inspect::Deferred), } @@ -446,9 +446,8 @@ impl WorkerLaunchRequest { } } LaunchType::Restart { send, events } => { - let (state_send, state_recv) = mesh::oneshot(); - send.send(WorkerRpc::Restart(state_send)); - let state = match block_on(state_recv).flatten() { + let state_recv = send.call_failable(WorkerRpc::Restart, ()); + let state = match block_on(state_recv) { Ok(state) => state, Err(err) => { self.events @@ -763,8 +762,8 @@ mod tests { while let Ok(req) = recv.recv().await { match req { WorkerRpc::Stop => break, - WorkerRpc::Restart(state_send) => { - state_send.send(Ok(TestWorkerState { value: self.value })); + WorkerRpc::Restart(rpc) => { + rpc.complete(Ok(TestWorkerState { value: self.value })); break; } WorkerRpc::Inspect(_deferred) => (), @@ -795,8 +794,8 @@ mod tests { while let Ok(req) = recv.recv().await { match req { WorkerRpc::Stop => break, - WorkerRpc::Restart(state_send) => { - state_send.send(Ok(())); + WorkerRpc::Restart(rpc) => { + rpc.complete(Ok(())); break; } WorkerRpc::Inspect(_deferred) => (), diff --git a/vm/devices/get/guest_crash_device/src/lib.rs b/vm/devices/get/guest_crash_device/src/lib.rs index a9574976be..4709617604 100644 --- a/vm/devices/get/guest_crash_device/src/lib.rs +++ b/vm/devices/get/guest_crash_device/src/lib.rs @@ -16,9 +16,8 @@ use get_protocol::crash::CRASHDUMP_GUID; use guid::Guid; use inspect::Inspect; use inspect::InspectMut; -use mesh::error::RemoteResult; -use mesh::error::RemoteResultExt; use mesh::rpc::FailableRpc; +use mesh::rpc::PendingFailableRpc; use mesh::rpc::RpcSend; use std::fs::File; use std::io::Seek; @@ -93,7 +92,7 @@ enum ProtocolState { enum DumpState { OpeningFile { - recv: mesh::OneshotReceiver>, + recv: PendingFailableRpc, }, Writing { file: File, @@ -227,7 +226,7 @@ impl GuestCrashDevice { } crash::MessageType::REQUEST_NIX_DUMP_START_V1 => { let (send, recv) = mesh::oneshot(); - let recv = self.request_dump.call(|x| x, recv); + let recv = self.request_dump.call_failable(|x| x, recv); channel.state = ProtocolState::DumpRequested { activity_id: header.activity_id, done: send, @@ -245,7 +244,7 @@ impl GuestCrashDevice { let DumpState::OpeningFile { recv } = state else { unreachable!() }; - let status = match recv.await.flatten() { + let status = match recv.await { Ok(file) => { *state = DumpState::Writing { file, diff --git a/vm/devices/get/guest_emulation_device/src/lib.rs b/vm/devices/get/guest_emulation_device/src/lib.rs index 827bbe6c4d..8333b8b514 100644 --- a/vm/devices/get/guest_emulation_device/src/lib.rs +++ b/vm/devices/get/guest_emulation_device/src/lib.rs @@ -189,7 +189,7 @@ pub struct GuestEmulationDevice { #[inspect(skip)] guest_request_recv: mesh::Receiver, #[inspect(skip)] - waiting_for_vtl0_start: Vec>>, + waiting_for_vtl0_start: Vec>>, vmgs: Option, @@ -311,7 +311,7 @@ pub struct GedChannel { #[inspect(with = "Option::is_some")] vtl0_start_report: Option>, #[inspect(with = "Option::is_some")] - modify: Option>>, + modify: Option>>, // TODO: allow unused temporarily as a follow up change will use it to // implement AK cert renewal. #[inspect(skip)] @@ -320,7 +320,7 @@ pub struct GedChannel { } struct InProgressSave { - response: mesh::OneshotSender>, + rpc: Rpc<(), Result<(), SaveRestoreError>>, buffer: Vec, } @@ -491,22 +491,23 @@ impl GedChannel { ) -> Result<(), Error> { match guest_request { GuestEmulationRequest::WaitForConnect(rpc) => rpc.handle_sync(|()| ()), - GuestEmulationRequest::WaitForVtl0Start(Rpc((), response)) => { + GuestEmulationRequest::WaitForVtl0Start(rpc) => { if let Some(result) = self.vtl0_start_report.clone() { - response.send(result); + rpc.complete(result); } else { - state.waiting_for_vtl0_start.push(response); + state.waiting_for_vtl0_start.push(rpc); } } - GuestEmulationRequest::ModifyVtl2Settings(Rpc(data, response)) => { + GuestEmulationRequest::ModifyVtl2Settings(rpc) => { + let (data, response) = rpc.split(); if self.modify.is_some() { - response.send(Err(ModifyVtl2SettingsError::OperationInProgress)); + response.complete(Err(ModifyVtl2SettingsError::OperationInProgress)); return Ok(()); } // TODO: support larger payloads. if data.len() > MAX_PAYLOAD_SIZE { - response.send(Err(ModifyVtl2SettingsError::LargeSettingsNotSupported)); + response.complete(Err(ModifyVtl2SettingsError::LargeSettingsNotSupported)); return Ok(()); } @@ -524,7 +525,7 @@ impl GedChannel { self.modify = Some(response); } - GuestEmulationRequest::SaveGuestVtl2State(Rpc((), response)) => { + GuestEmulationRequest::SaveGuestVtl2State(rpc) => { let r = (|| { if self.save.is_some() { return Err(SaveRestoreError::OperationInProgress); @@ -551,11 +552,11 @@ impl GedChannel { match r { Ok(()) => { self.save = Some(InProgressSave { - response, + rpc, buffer: Vec::new(), }) } - Err(err) => response.send(Err(err)), + Err(err) => rpc.complete(Err(err)), } } }; @@ -903,7 +904,7 @@ impl GedChannel { if r.is_ok() { state.save_restore_buf = Some(save.buffer); } - save.response.send(r); + save.rpc.complete(r); } Ok(()) } @@ -1121,7 +1122,7 @@ impl GedChannel { _ => return Err(Error::InvalidFieldValue), }; for response in state.waiting_for_vtl0_start.drain(..) { - response.send(result.clone()); + response.complete(result.clone()); } self.vtl0_start_report = Some(result); Ok(()) @@ -1177,7 +1178,7 @@ impl GedChannel { } _ => return Err(Error::InvalidFieldValue), }; - modify.send(r); + modify.complete(r); Ok(()) } diff --git a/vm/devices/get/guest_emulation_transport/src/process_loop.rs b/vm/devices/get/guest_emulation_transport/src/process_loop.rs index 22b69c19ed..72b8c3b2e2 100644 --- a/vm/devices/get/guest_emulation_transport/src/process_loop.rs +++ b/vm/devices/get/guest_emulation_transport/src/process_loop.rs @@ -19,6 +19,9 @@ use guid::Guid; use inspect::Inspect; use inspect::InspectMut; use inspect_counters::Counter; +use mesh::rpc::Rpc; +use mesh::rpc::RpcError; +use mesh::rpc::TryRpcSend; use mesh::RecvError; use parking_lot::Mutex; use std::cmp::min; @@ -398,6 +401,15 @@ impl BufferedSender { } } +impl TryRpcSend for &mut BufferedSender { + type Message = T; + type Error = BufferedSenderFull; + + fn try_send_rpc(self, message: Self::Message) -> Result<(), Self::Error> { + self.send(message) + } +} + /// A variant of `Option>` for late-bound guest notification /// consumers that buffers a fixed-number of messages during the window between /// GET init and worker startup. @@ -508,7 +520,7 @@ struct GuestNotificationListeners { // pair of notifications that don't really act like "notifications" for the // foreseeable future... enum GuestNotificationResponse { - ModifyVtl2Settings(Result>, RecvError>), + ModifyVtl2Settings(Result<(), RpcError>>), } #[derive(Default, Inspect)] @@ -553,7 +565,7 @@ struct PipeChannels { enum WriteRequest { Message(Vec), - Flush(mesh::OneshotSender<()>), + Flush(Rpc<(), ()>), } impl HostRequestPipeAccess { @@ -780,7 +792,7 @@ impl ProcessLoop { } match self.write_recv.recv().await.unwrap() { WriteRequest::Message(message) => outgoing = message, - WriteRequest::Flush(send) => send.send(()), + WriteRequest::Flush(send) => send.complete(()), } } } @@ -953,7 +965,7 @@ impl ProcessLoop { /// for its response. fn push_basic_host_request_handler( &mut self, - req: mesh::rpc::Rpc, + req: Rpc, f: impl 'static + Send + FnOnce(I) -> Req, ) where Req: AsBytes + 'static + Send + Sync, @@ -973,10 +985,10 @@ impl ProcessLoop { match message { // GET infrastructure - not part of the GET protocol itself. // No direct interaction with the host. - Msg::FlushWrites(mesh::rpc::Rpc((), response)) => { + Msg::FlushWrites(rpc) => { self.pipe_channels .message_send - .send(WriteRequest::Flush(response)); + .send(WriteRequest::Flush(rpc)); } Msg::Inspect(req) => { req.inspect(self); @@ -1132,9 +1144,10 @@ impl ProcessLoop { Msg::SendServicingState(req) => self.push_host_request_handler(move |access| { req.handle_must_succeed(|data| request_send_servicing_state(access, data)) }), - Msg::CompleteStartVtl0(mesh::rpc::Rpc(input, res)) => { + Msg::CompleteStartVtl0(rpc) => { + let (input, res) = rpc.split(); self.complete_start_vtl0(input)?; - res.send(()); + res.complete(()); } // Host Notifications (don't require a response) @@ -1368,18 +1381,14 @@ impl ProcessLoop { vtl2_settings_buf: Vec, kind: get_protocol::GuestNotifications, ) -> Result<(), FatalError> { - let (result_send, result_recv) = mesh::oneshot(); - - let req = ModifyVtl2SettingsRequest(mesh::rpc::Rpc(vtl2_settings_buf, result_send)); - let res = result_recv + let res = self + .guest_notification_listeners + .vtl2_settings + .try_call_failable(ModifyVtl2SettingsRequest, vtl2_settings_buf) + .map_err(|_| FatalError::TooManyGuestNotifications(kind))? .map(GuestNotificationResponse::ModifyVtl2Settings) .boxed(); - self.guest_notification_listeners - .vtl2_settings - .send(req) - .map_err(|_| FatalError::TooManyGuestNotifications(kind))?; - self.guest_notification_responses.push(res); Ok(()) } @@ -1439,13 +1448,14 @@ impl ProcessLoop { fn complete_modify_vtl2_settings( &mut self, - result: Result>, RecvError>, + result: Result<(), RpcError>>, ) -> Result<(), FatalError> { - let errors = result.unwrap_or_else(|err| { - Err(vec![Vtl2SettingsErrorInfo::new( + let errors = result.map_err(|err| match err { + RpcError::Call(err) => err, + RpcError::Channel(err) => vec![Vtl2SettingsErrorInfo::new( underhill_config::Vtl2SettingsErrorCode::InternalFailure, err.to_string(), - )]) + )], }); let (status, errors_json) = match errors { diff --git a/vm/devices/hyperv_ic/src/shutdown.rs b/vm/devices/hyperv_ic/src/shutdown.rs index 80e4002b3c..43b91f2bfe 100644 --- a/vm/devices/hyperv_ic/src/shutdown.rs +++ b/vm/devices/hyperv_ic/src/shutdown.rs @@ -53,7 +53,7 @@ pub struct ShutdownChannel { pipe: MessagePipe, state: ChannelState, #[inspect(with = "Option::is_some")] - pending_shutdown: Option>, + pending_shutdown: Option>, } #[derive(Inspect)] @@ -153,8 +153,9 @@ impl ShutdownChannel { } ChannelState::Ready { ref mut state, .. } => match state { ReadyState::Ready => { - self.pending_shutdown = Some(rpc.1); - *state = ReadyState::SendShutdown(rpc.0); + let (input, rpc) = rpc.split(); + self.pending_shutdown = Some(rpc); + *state = ReadyState::SendShutdown(input); } ReadyState::SendShutdown { .. } | ReadyState::WaitShutdown => { rpc.complete(ShutdownResult::AlreadyInProgress) @@ -278,7 +279,7 @@ impl ShutdownChannel { ShutdownResult::Failed(status) }; if let Some(send) = self.pending_shutdown.take() { - send.send(result); + send.complete(result); } *state = ReadyState::Ready; } diff --git a/vm/devices/net/net_consomme/consomme/src/lib.rs b/vm/devices/net/net_consomme/consomme/src/lib.rs index 1a67433a23..d0d01eaa76 100644 --- a/vm/devices/net/net_consomme/consomme/src/lib.rs +++ b/vm/devices/net/net_consomme/consomme/src/lib.rs @@ -26,6 +26,7 @@ mod udp; mod windows; use inspect::InspectMut; +use mesh::rpc::RpcError; use mesh::rpc::Rpc; use mesh::rpc::RpcSend; use pal_async::driver::Driver; @@ -51,7 +52,7 @@ use thiserror::Error; pub enum ConsommeMessageError { /// Communication error with running instance. #[error("communication error")] - Mesh(mesh::RecvError), + Mesh(RpcError), /// Error executing request on current network instance. #[error("network err")] Network(DropReason), diff --git a/vm/devices/net/net_packet_capture/src/lib.rs b/vm/devices/net/net_packet_capture/src/lib.rs index b7dbb0ce1a..ef94c9bb89 100644 --- a/vm/devices/net/net_packet_capture/src/lib.rs +++ b/vm/devices/net/net_packet_capture/src/lib.rs @@ -310,7 +310,7 @@ impl Endpoint for PacketCaptureEndpoint { Message::PacketCaptureEndpointCommand( PacketCaptureEndpointCommand::PacketCapture(rpc), ) => { - let options = rpc.0; + let (options, response) = rpc.split(); let result = async { let id = &self.id; let start = match options.operation { @@ -341,7 +341,7 @@ impl Endpoint for PacketCaptureEndpoint { Err(e) => (Err(e), false), Ok(value) => (Ok(()), value), }; - rpc.1.send(result.map_err(RemoteError::new)); + response.complete(result.map_err(RemoteError::new)); if restart_required { break EndpointAction::RestartRequired; } diff --git a/vm/devices/net/netvsp/src/test.rs b/vm/devices/net/netvsp/src/test.rs index 802b19c784..252f8413ed 100644 --- a/vm/devices/net/netvsp/src/test.rs +++ b/vm/devices/net/netvsp/src/test.rs @@ -22,8 +22,8 @@ use hvdef::hypercall::HvGuestOsId; use hvdef::hypercall::HvGuestOsMicrosoft; use hvdef::hypercall::HvGuestOsMicrosoftIds; use mesh::rpc::Rpc; +use mesh::rpc::RpcError; use mesh::rpc::RpcSend; -use mesh::RecvError; use net_backend::null::NullEndpoint; use net_backend::DisconnectableEndpoint; use net_backend::Endpoint; @@ -393,12 +393,8 @@ impl TestNicDevice { req: impl FnOnce(Rpc) -> ChannelRequest, input: I, f: impl 'static + Send + FnOnce(R) -> ChannelResponse, - ) -> Result { - let (response, recv) = mesh::oneshot(); - self.offer_input - .request_send - .send((req)(Rpc(input, response))); - recv.await.map(f) + ) -> Result { + self.offer_input.request_send.call(req, input).await.map(f) } async fn connect_vmbus_channel(&mut self) -> TestNicChannel<'_> { @@ -1090,9 +1086,7 @@ impl TestVirtualFunctionState { pub async fn set_ready(&self, is_ready: bool) { let ready_callback = self.oneshot_ready_callback.lock().take(); if let Some(ready_callback) = ready_callback { - let (result_send, result_recv) = mesh::oneshot(); - ready_callback.send(Rpc(is_ready, result_send)); - result_recv.await.unwrap(); + ready_callback.call(|x| x, is_ready).await.unwrap(); } *self.is_ready.0.lock() = Some(is_ready); self.is_ready.1.notify(usize::MAX); diff --git a/vm/devices/storage/disk_nvme/nvme_driver/src/queue_pair.rs b/vm/devices/storage/disk_nvme/nvme_driver/src/queue_pair.rs index 07233e9875..44d8f355ae 100644 --- a/vm/devices/storage/disk_nvme/nvme_driver/src/queue_pair.rs +++ b/vm/devices/storage/disk_nvme/nvme_driver/src/queue_pair.rs @@ -22,6 +22,7 @@ use guestmem::GuestMemoryError; use inspect::Inspect; use inspect_counters::Counter; use mesh::rpc::Rpc; +use mesh::rpc::RpcError; use mesh::rpc::RpcSend; use mesh::Cancel; use mesh::CancelContext; @@ -92,11 +93,7 @@ impl PendingCommands { } /// Inserts a command into the pending list, updating it with a new CID. - fn insert( - &mut self, - command: &mut spec::Command, - respond: mesh::OneshotSender, - ) { + fn insert(&mut self, command: &mut spec::Command, respond: Rpc<(), spec::Completion>) { let entry = self.commands.vacant_entry(); assert!(entry.key() < Self::MAX_CIDS); assert_eq!(self.next_cid_high_bits % Self::CID_SEQ_OFFSET, Wrapping(0)); @@ -109,7 +106,7 @@ impl PendingCommands { }); } - fn remove(&mut self, cid: u16) -> mesh::OneshotSender { + fn remove(&mut self, cid: u16) -> Rpc<(), spec::Completion> { let command = self .commands .try_remove((cid & Self::CID_KEY_MASK) as usize) @@ -152,7 +149,6 @@ impl PendingCommands { commands: commands .iter() .map(|state| { - let (send, mut _recv) = mesh::oneshot::(); // To correctly restore Slab we need both the command index, // inherited from command's CID, and the command itself. ( @@ -160,7 +156,7 @@ impl PendingCommands { (state.command.cdw0.cid() & Self::CID_KEY_MASK) as usize, PendingCommand { command: state.command, - respond: send, + respond: Rpc::detached(()), }, ) }) @@ -336,7 +332,7 @@ impl QueuePair { #[allow(missing_docs)] pub enum RequestError { #[error("queue pair is gone")] - Gone(#[source] mesh::RecvError), + Gone(#[source] RpcError), #[error("nvme error")] Nvme(#[source] NvmeError), #[error("memory error")] @@ -579,7 +575,7 @@ struct PendingCommand { // Keep the command around for diagnostics. command: spec::Command, #[inspect(skip)] - respond: mesh::OneshotSender, + respond: Rpc<(), spec::Completion>, } enum Req { @@ -659,7 +655,8 @@ impl QueueHandler { match event { Event::Request(req) => match req { - Req::Command(Rpc(mut command, respond)) => { + Req::Command(rpc) => { + let (mut command, respond) = rpc.split(); self.commands.insert(&mut command, respond); self.sq.write(command).unwrap(); self.stats.issued.increment(); @@ -679,7 +676,7 @@ impl QueueHandler { self.drain_after_restore = false; } self.sq.update_head(completion.sqhd); - respond.send(completion); + respond.complete(completion); self.stats.completed.increment(); } } diff --git a/vm/devices/storage/nvme/src/workers/coordinator.rs b/vm/devices/storage/nvme/src/workers/coordinator.rs index fb1c1b402e..2ddb1ea153 100644 --- a/vm/devices/storage/nvme/src/workers/coordinator.rs +++ b/vm/devices/storage/nvme/src/workers/coordinator.rs @@ -17,6 +17,7 @@ use guestmem::GuestMemory; use guid::Guid; use inspect::Inspect; use inspect::InspectMut; +use mesh::rpc::PendingRpc; use mesh::rpc::Rpc; use mesh::rpc::RpcSend; use pal_async::task::Spawn; @@ -39,9 +40,9 @@ pub struct NvmeWorkers { #[derive(Debug)] enum EnableState { Disabled, - Enabling(mesh::OneshotReceiver<()>), + Enabling(PendingRpc<()>), Enabled, - Resetting(mesh::OneshotReceiver<()>), + Resetting(PendingRpc<()>), } impl InspectMut for NvmeWorkers { diff --git a/vm/devices/vmbus/vmbus_channel/src/channel.rs b/vm/devices/vmbus/vmbus_channel/src/channel.rs index 1bbb99cb40..a3f67f2895 100644 --- a/vm/devices/vmbus/vmbus_channel/src/channel.rs +++ b/vm/devices/vmbus/vmbus_channel/src/channel.rs @@ -427,9 +427,6 @@ pub enum ChannelRestoreError { /// Failed to enable subchannels. #[error("failed to enable subchannels")] EnablingSubchannels(#[source] anyhow::Error), - /// Failed to send restore request. - #[error("failed to send restore request")] - SendingRequest(#[source] RecvError), /// Failed to restore vmbus channel. #[error("failed to restore vmbus channel")] RestoreError(#[source] anyhow::Error), @@ -570,8 +567,8 @@ impl Device { self.handle_gpadl(gpadl.id, gpadl.count, gpadl.buf, channel_idx); true }), - ChannelRequest::TeardownGpadl(Rpc(id, response_send)) => { - self.handle_teardown_gpadl(id, response_send, channel_idx); + ChannelRequest::TeardownGpadl(rpc) => { + self.handle_teardown_gpadl(rpc, channel_idx); } ChannelRequest::Modify(rpc) => { rpc.handle(|req| async { @@ -644,16 +641,12 @@ impl Device { } } - fn handle_teardown_gpadl( - &mut self, - id: GpadlId, - response_send: mesh::OneshotSender<()>, - channel_idx: usize, - ) { + fn handle_teardown_gpadl(&mut self, rpc: Rpc, channel_idx: usize) { + let id = *rpc.input(); if let Some(f) = self.gpadl_map.remove( id, Box::new(move || { - response_send.send(()); + rpc.complete(()); }), ) { f() @@ -787,9 +780,8 @@ impl Device { let mut results = Vec::with_capacity(states.len()); for (channel_idx, open) in states.iter().copied().enumerate() { let result = self.server_requests[channel_idx] - .call(ChannelServerRequest::Restore, open) + .call_failable(ChannelServerRequest::Restore, open) .await - .map_err(ChannelRestoreError::SendingRequest)? .map_err(|err| ChannelRestoreError::RestoreError(err.into()))?; assert!(open == result.open_request.is_some()); diff --git a/vm/devices/vmbus/vmbus_channel/src/offer.rs b/vm/devices/vmbus/vmbus_channel/src/offer.rs index df8340b38c..7241ccfb0c 100644 --- a/vm/devices/vmbus/vmbus_channel/src/offer.rs +++ b/vm/devices/vmbus/vmbus_channel/src/offer.rs @@ -104,7 +104,8 @@ impl Offer { let mut open_done = None; while let Ok(request) = request_recv.recv().await { match request { - ChannelRequest::Open(Rpc(open_request, response_send)) => { + ChannelRequest::Open(rpc) => { + let (open_request, response_send) = rpc.split(); let done = Arc::new(AtomicBool::new(false)); send.send(OpenMessage { open_request, @@ -113,7 +114,8 @@ impl Offer { }); open_done = Some(done); } - ChannelRequest::Close(Rpc((), _response_send)) => { + ChannelRequest::Close(rpc) => { + let _response_send = rpc; // TODO: figure out if we should really just drop this here. open_done .take() .expect("channel must be open") @@ -137,11 +139,12 @@ impl Offer { } }) } - ChannelRequest::TeardownGpadl(Rpc(id, response_send)) => { + ChannelRequest::TeardownGpadl(rpc) => { + let (id, response_send) = rpc.split(); if let Some(f) = gpadls.remove( id, Box::new(move || { - response_send.send(()); + response_send.complete(()); }), ) { f(); @@ -217,18 +220,18 @@ struct OpenMessage { response: OpenResponse, } -struct OpenResponse(Option>); +struct OpenResponse(Option>); impl OpenResponse { fn respond(mut self, open: bool) { - self.0.take().unwrap().send(open) + self.0.take().unwrap().complete(open) } } impl Drop for OpenResponse { fn drop(&mut self) { - if let Some(send) = self.0.take() { - send.send(false); + if let Some(rpc) = self.0.take() { + rpc.complete(false); } } } diff --git a/vm/devices/vmbus/vmbus_client/src/lib.rs b/vm/devices/vmbus/vmbus_client/src/lib.rs index c9d2e7d364..f511180523 100644 --- a/vm/devices/vmbus/vmbus_client/src/lib.rs +++ b/vm/devices/vmbus/vmbus_client/src/lib.rs @@ -383,7 +383,7 @@ enum ChannelState { /// The channel has been offered to the client. Offered, /// The channel has requested the server to be opened. - Opening(mesh::OneshotSender), + Opening(Rpc<(), bool>), /// The channel has been successfully opened. Opened, } @@ -402,7 +402,7 @@ struct Channel { offer: protocol::OfferChannel, response_send: mesh::Sender, state: ChannelState, - modify_response_send: Option>, + modify_response_send: Option>, } impl std::fmt::Debug for Channel { @@ -491,7 +491,7 @@ impl ClientTask { FeatureFlags::new() }; - let request = &rpc.0; + let request = rpc.input(); tracing::debug!(version = ?version, ?feature_flags, "VmBus client connecting"); let target_info = protocol::TargetInfo::new(SINT, VTL, feature_flags); @@ -552,7 +552,7 @@ impl ClientTask { return; } - let message = protocol::ModifyConnection::from(request.0); + let message = protocol::ModifyConnection::from(*request.input()); self.modify_request = Some(request); self.inner.send(&message); } @@ -781,7 +781,7 @@ impl ClientTask { unreachable!("validated above"); }; - sender.send(gpadl_created) + sender.complete(gpadl_created) } fn handle_open_result(&mut self, result: protocol::OpenResult) { @@ -811,7 +811,7 @@ impl ClientTask { return; }; - rpc.send(channel_opened); + rpc.complete(channel_opened); } fn handle_gpadl_torndown(&mut self, request: protocol::GpadlTorndown) { @@ -893,7 +893,7 @@ impl ClientTask { return; }; - sender.send(response.status); + sender.complete(response.status); } fn handle_tl_connect_result(&mut self, response: protocol::TlConnectResult) { @@ -981,7 +981,7 @@ impl ClientTask { } tracing::info!(channel_id = channel_id.0, "opening channel on host"); - let request = &rpc.0; + let request = rpc.input(); let open_data = &request.open_data; let open_channel = protocol::OpenChannel { @@ -1014,15 +1014,16 @@ impl ClientTask { self.inner.send(&open_channel); } - self.inner.channels.get_mut(&channel_id).unwrap().state = ChannelState::Opening(rpc.1); + self.inner.channels.get_mut(&channel_id).unwrap().state = + ChannelState::Opening(rpc.split().1); } fn handle_gpadl(&mut self, channel_id: ChannelId, rpc: Rpc) { - let request = &rpc.0; + let (request, rpc) = rpc.split(); if self .inner .gpadls - .insert((channel_id, request.id), GpadlState::Offered(rpc.1)) + .insert((channel_id, request.id), GpadlState::Offered(rpc)) .is_some() { panic!( @@ -1129,12 +1130,12 @@ impl ClientTask { panic!("duplicate channel modify request {channel_id:?}"); } - channel.modify_response_send = Some(rpc.1); - let request = &rpc.0; + let (request, response) = rpc.split(); + channel.modify_response_send = Some(response); let payload = match request { ModifyRequest::TargetVp { target_vp } => protocol::ModifyChannel { channel_id, - target_vp: *target_vp, + target_vp, }, }; @@ -1311,7 +1312,7 @@ impl Inspect for ClientTask { #[derive(Debug)] enum GpadlState { /// GpadlHeader has been sent to the host. - Offered(mesh::OneshotSender), + Offered(Rpc<(), bool>), /// Host has responded with GpadlCreated. Created, /// GpadlTeardown message has been sent to the host. @@ -1795,8 +1796,8 @@ mod tests { let (server, mut client, _) = test_init(); let channel = server.get_channel(&mut client).await; - let (send, recv) = mesh::oneshot(); - channel.request_send.send(ChannelRequest::Open(Rpc( + let recv = channel.request_send.call( + ChannelRequest::Open, OpenRequest { open_data: OpenData { target_vp: 0, @@ -1808,8 +1809,7 @@ mod tests { }, flags: OpenChannelFlags::new(), }, - send, - ))); + ); assert_eq!( server.next().unwrap(), @@ -1846,8 +1846,8 @@ mod tests { let (server, mut client, _) = test_init(); let channel = server.get_channel(&mut client).await; - let (send, recv) = mesh::oneshot(); - channel.request_send.send(ChannelRequest::Open(Rpc( + let recv = channel.request_send.call( + ChannelRequest::Open, OpenRequest { open_data: OpenData { target_vp: 0, @@ -1859,8 +1859,7 @@ mod tests { }, flags: OpenChannelFlags::new(), }, - send, - ))); + ); assert_eq!( server.next().unwrap(), @@ -1899,11 +1898,10 @@ mod tests { // N.B. A real server requires the channel to be open before sending this, but the test // server doesn't care. - let (send, recv) = mesh::oneshot(); - channel.request_send.send(ChannelRequest::Modify(Rpc( + let recv = channel.request_send.call( + ChannelRequest::Modify, ModifyRequest::TargetVp { target_vp: 1 }, - send, - ))); + ); assert_eq!( server.next().unwrap(), @@ -2016,15 +2014,14 @@ mod tests { async fn test_gpadl_success() { let (server, mut client, _) = test_init(); let mut channel = server.get_channel(&mut client).await; - let (send, recv) = mesh::oneshot(); - channel.request_send.send(ChannelRequest::Gpadl(Rpc( + let recv = channel.request_send.call( + ChannelRequest::Gpadl, GpadlRequest { id: GpadlId(1), count: 1, buf: vec![5], }, - send, - ))); + ); assert_eq!( server.next().unwrap(), @@ -2079,15 +2076,14 @@ mod tests { async fn test_gpadl_fail() { let (server, mut client, _) = test_init(); let channel = server.get_channel(&mut client).await; - let (send, recv) = mesh::oneshot(); - channel.request_send.send(ChannelRequest::Gpadl(Rpc( + let recv = channel.request_send.call( + ChannelRequest::Gpadl, GpadlRequest { id: GpadlId(1), count: 1, buf: vec![7], }, - send, - ))); + ); assert_eq!( server.next().unwrap(), @@ -2121,15 +2117,14 @@ mod tests { let mut channel = server.get_channel(&mut client).await; let channel_id = ChannelId(0); let gpadl_id = GpadlId(1); - let (send, recv) = mesh::oneshot(); - channel.request_send.send(ChannelRequest::Gpadl(Rpc( + let recv = channel.request_send.call( + ChannelRequest::Gpadl, GpadlRequest { id: gpadl_id, count: 1, buf: vec![3], }, - send, - ))); + ); assert_eq!( server.next().unwrap(), diff --git a/vm/devices/vmbus/vmbus_relay/src/hvsock.rs b/vm/devices/vmbus/vmbus_relay/src/hvsock.rs index dc6b550c3b..e65fc85026 100644 --- a/vm/devices/vmbus/vmbus_relay/src/hvsock.rs +++ b/vm/devices/vmbus/vmbus_relay/src/hvsock.rs @@ -70,6 +70,22 @@ mod tests { use vmbus_server::Guid; use zerocopy::FromZeroes; + struct FakeRpcSender(Option); + + impl RpcSend for &mut FakeRpcSender { + type Message = T; + + fn send_rpc(self, message: Self::Message) { + self.0 = Some(message); + } + } + + fn dummy_rpc(t: T) -> Rpc { + let mut sender = FakeRpcSender(None); + let _ = sender.call(|x| x, t); + sender.0.unwrap() + } + #[test] fn test_check_result() { let mut tracker = HvsockRequestTracker::new(); diff --git a/vm/devices/vmbus/vmbus_relay/src/lib.rs b/vm/devices/vmbus/vmbus_relay/src/lib.rs index c8b26f79fd..82ac48152f 100644 --- a/vm/devices/vmbus/vmbus_relay/src/lib.rs +++ b/vm/devices/vmbus/vmbus_relay/src/lib.rs @@ -352,7 +352,7 @@ struct RelayChannel { /// State used to relay host-to-guest interrupts. interrupt_relay: Option, /// RPCs for gpadls that are waiting for a torndown message. - gpadls_tearing_down: HashMap>, + gpadls_tearing_down: HashMap>, } struct RelayChannelTask { @@ -434,7 +434,7 @@ impl RelayChannelTask { } fn handle_gpadl_teardown(&mut self, rpc: Rpc) { - let gpadl_id = rpc.0; + let (gpadl_id, rpc) = rpc.split(); tracing::trace!(gpadl_id = gpadl_id.0, "Tearing down GPADL"); let _ = &self diff --git a/vm/devices/vmbus/vmbus_server/src/lib.rs b/vm/devices/vmbus/vmbus_server/src/lib.rs index 4e9a7ce718..74828c4963 100644 --- a/vm/devices/vmbus/vmbus_server/src/lib.rs +++ b/vm/devices/vmbus/vmbus_server/src/lib.rs @@ -33,13 +33,11 @@ use futures::StreamExt; use guestmem::GuestMemory; use hvdef::Vtl; use inspect::Inspect; -use mesh::error::RemoteError; -use mesh::error::RemoteResult; -use mesh::error::RemoteResultExt; use mesh::payload::Protobuf; +use mesh::rpc::FailableRpc; use mesh::rpc::Rpc; +use mesh::rpc::RpcError; use mesh::rpc::RpcSend; -use mesh::RecvError; use pal_async::task::Spawn; use pal_async::task::Task; use pal_event::Event; @@ -207,7 +205,7 @@ pub struct OfferInfo { #[derive(mesh::MeshPayload)] pub enum OfferRequest { - Offer(OfferInfo, mesh::OneshotSender>), + Offer(FailableRpc), } impl Inspect for VmbusServer { @@ -623,13 +621,13 @@ struct ServerTaskInner { hvsock_send: mesh::Sender, channels: HashMap, channel_responses: FuturesUnordered< - Pin)>>>, + Pin)>>>, >, external_server_send: Option>, relay_send: mesh::Sender, channel_bitmap: Option>, shared_event_port: Option>, - reset_done: Option>, + reset_done: Option>, enable_mnf: bool, } @@ -722,7 +720,7 @@ impl ServerTask { &mut self, offer_id: OfferId, seq: u64, - response: Result, + response: Result, ) { // Validate the sequence to ensure the response is not for a revoked channel. let channel = self @@ -847,9 +845,9 @@ impl ServerTask { fn handle_request(&mut self, request: VmbusRequest) { tracing::debug!(?request, "handle_request"); match request { - VmbusRequest::Reset(Rpc((), done)) => { + VmbusRequest::Reset(rpc) => { assert!(self.inner.reset_done.is_none()); - self.inner.reset_done = Some(done); + self.inner.reset_done = Some(rpc); self.server.with_notifier(&mut self.inner).reset(); // TODO: clear pending messages and other requests. } @@ -1000,8 +998,8 @@ impl ServerTask { } r = self.offer_recv.select_next_some() => { match r { - OfferRequest::Offer(request, response) => { - response.send(self.handle_offer(request).map_err(RemoteError::new)) + OfferRequest::Offer(rpc) => { + rpc.handle_failable_sync(|request| { self.handle_offer(request) }) }, } } @@ -1156,10 +1154,9 @@ impl channels::Notifier for ServerTaskInner { req: impl FnOnce(Rpc) -> ChannelRequest, input: I, f: impl 'static + Send + FnOnce(R) -> ChannelResponse, - ) -> Pin)>>> + ) -> Pin)>>> { - let (response, recv) = mesh::oneshot(); - channel.send.send((req)(Rpc(input, response))); + let recv = channel.send.call(req, input); let seq = channel.seq; Box::pin(async move { let r = recv.await.map(f); @@ -1337,7 +1334,7 @@ impl channels::Notifier for ServerTaskInner { } let done = self.reset_done.take().expect("must have requested reset"); - done.send(()); + done.complete(()); } } @@ -1503,9 +1500,9 @@ impl VmbusServerControl { /// This is used by the relay to forward the host's parameters. pub async fn offer_core(&self, offer_info: OfferInfo) -> anyhow::Result { let flags = offer_info.params.flags; - let (send, recv) = mesh::oneshot(); - self.send.send(OfferRequest::Offer(offer_info, send)); - recv.await.flatten()?; + self.send + .call_failable(OfferRequest::Offer, offer_info) + .await?; Ok(OfferResources::new( self.mem.clone(), if flags.confidential_ring_buffer() || flags.confidential_external_memory() { @@ -1713,7 +1710,7 @@ mod tests { panic!("Wrong request"); }; - f(&rpc.0); + f(rpc.input()); rpc.complete(true); } diff --git a/vm/devices/vmbus/vmbus_server/src/proxyintegration.rs b/vm/devices/vmbus/vmbus_server/src/proxyintegration.rs index 3d0643cc59..e1c1997a46 100644 --- a/vm/devices/vmbus/vmbus_server/src/proxyintegration.rs +++ b/vm/devices/vmbus/vmbus_server/src/proxyintegration.rs @@ -17,8 +17,7 @@ use anyhow::Context; use futures::stream::SelectAll; use futures::StreamExt; use guestmem::GuestMemory; -use mesh::error::RemoteResultExt; -use mesh::rpc::Rpc; +use mesh::rpc::RpcSend; use mesh::Cancel; use mesh::CancelContext; use pal_async::driver::SpawnDriver; @@ -238,18 +237,17 @@ impl ProxyTask { }; let (request_send, request_recv) = mesh::channel(); let (server_request_send, server_request_recv) = mesh::channel(); - let (send, recv) = mesh::oneshot(); - self.server.send.send(OfferRequest::Offer( + let recv = self.server.send.call_failable( + OfferRequest::Offer, OfferInfo { params: offer.into(), event: Interrupt::from_event(incoming_event), request_send, server_request_recv, }, - send, - )); + ); - let (request_recv, server_request_send) = match recv.await.flatten() { + let (request_recv, server_request_send) = match recv.await { Ok(()) => (Some(request_recv), Some(server_request_send)), Err(err) => { // Currently there is no way to propagate this failure. @@ -356,7 +354,7 @@ impl ProxyTask { } // Modifying the target VP is handle by the server, there is nothing the proxy // driver needs to do. - ChannelRequest::Modify(Rpc(_, response)) => response.send(0), + ChannelRequest::Modify(rpc) => rpc.complete(0), } } None => { diff --git a/vm/vmgs/vmgs_broker/src/client.rs b/vm/vmgs/vmgs_broker/src/client.rs index ea713d08a7..c0ca80a21b 100644 --- a/vm/vmgs/vmgs_broker/src/client.rs +++ b/vm/vmgs/vmgs_broker/src/client.rs @@ -6,6 +6,7 @@ use crate::broker::VmgsBrokerRpc; use inspect::Inspect; +use mesh_channel::rpc::RpcError; use mesh_channel::rpc::RpcSend; use thiserror::Error; use tracing::instrument; @@ -18,12 +19,21 @@ use vmgs_format::FileId; pub enum VmgsClientError { /// VMGS broker is offline #[error("broker is offline")] - BrokerOffline(#[from] mesh_channel::RecvError), + BrokerOffline(#[from] RpcError), /// VMGS error #[error("vmgs error")] Vmgs(#[from] vmgs::Error), } +impl From> for VmgsClientError { + fn from(value: RpcError) -> Self { + match value { + RpcError::Call(e) => VmgsClientError::Vmgs(e), + RpcError::Channel(e) => VmgsClientError::BrokerOffline(RpcError::Channel(e)), + } + } +} + /// Client to interact with a backend-agnostic VMGS instance. #[derive(Clone)] pub struct VmgsClient { @@ -42,8 +52,8 @@ impl VmgsClient { pub async fn get_file_info(&self, file_id: FileId) -> Result { let res = self .control - .call(VmgsBrokerRpc::GetFileInfo, file_id) - .await??; + .call_failable(VmgsBrokerRpc::GetFileInfo, file_id) + .await?; Ok(res) } @@ -53,8 +63,8 @@ impl VmgsClient { pub async fn read_file(&self, file_id: FileId) -> Result, VmgsClientError> { let res = self .control - .call(VmgsBrokerRpc::ReadFile, file_id) - .await??; + .call_failable(VmgsBrokerRpc::ReadFile, file_id) + .await?; Ok(res) } @@ -66,8 +76,8 @@ impl VmgsClient { #[instrument(skip_all, fields(file_id))] pub async fn write_file(&self, file_id: FileId, buf: Vec) -> Result<(), VmgsClientError> { self.control - .call(VmgsBrokerRpc::WriteFile, (file_id, buf)) - .await??; + .call_failable(VmgsBrokerRpc::WriteFile, (file_id, buf)) + .await?; Ok(()) } @@ -83,8 +93,8 @@ impl VmgsClient { buf: Vec, ) -> Result<(), VmgsClientError> { self.control - .call(VmgsBrokerRpc::WriteFileEncrypted, (file_id, buf)) - .await??; + .call_failable(VmgsBrokerRpc::WriteFileEncrypted, (file_id, buf)) + .await?; Ok(()) } diff --git a/vmm_core/src/partition_unit/vp_set.rs b/vmm_core/src/partition_unit/vp_set.rs index db171a0dee..c56215df5f 100644 --- a/vmm_core/src/partition_unit/vp_set.rs +++ b/vmm_core/src/partition_unit/vp_set.rs @@ -19,6 +19,7 @@ use futures_concurrency::stream::Merge; use guestmem::GuestMemory; use hvdef::Vtl; use inspect::Inspect; +use mesh::rpc::RpcError; use mesh::rpc::Rpc; use mesh::rpc::RpcSend; use pal_async::local::block_with_io; @@ -890,7 +891,7 @@ pub struct RegisterSetError(&'static str, #[source] anyhow::Error); #[derive(Debug, Error)] #[error("the vp runner was dropped")] -struct RunnerGoneError(#[source] mesh::RecvError); +struct RunnerGoneError(#[source] RpcError); #[cfg(feature = "gdb")] impl VpSet { diff --git a/vmm_core/state_unit/src/lib.rs b/vmm_core/state_unit/src/lib.rs index db0ad3fad8..b1a3561b34 100644 --- a/vmm_core/state_unit/src/lib.rs +++ b/vmm_core/state_unit/src/lib.rs @@ -36,10 +36,11 @@ use futures::StreamExt; use futures_concurrency::stream::Merge; use inspect::Inspect; use inspect::InspectMut; -use mesh::oneshot; use mesh::payload::Protobuf; use mesh::rpc::FailableRpc; use mesh::rpc::Rpc; +use mesh::rpc::RpcError; +use mesh::rpc::RpcSend; use mesh::MeshPayload; use mesh::Receiver; use mesh::Sender; @@ -257,7 +258,7 @@ pub struct NameInUse(Arc); struct UnitRecvError { name: Arc, #[source] - source: mesh::RecvError, + source: RpcError, } #[derive(Debug, Clone)] @@ -893,7 +894,6 @@ fn state_change( request: impl FnOnce(Rpc) -> StateRequest, input: Option, ) -> impl Future, UnitRecvError>> { - let (response_send, response_recv) = oneshot(); let send = unit.send.clone(); async move { @@ -901,8 +901,8 @@ fn state_change( let span = tracing::info_span!("device_state_change", device = name.as_ref()); async move { let start = Instant::now(); - send.send((request)(Rpc(input, response_send))); - let r = response_recv + let r = send + .call(request, input) .await .map_err(|err| UnitRecvError { name, source: err }); tracing::debug!(duration = ?Instant::now() - start, "device state change complete"); diff --git a/workers/debug_worker/src/gdb/mod.rs b/workers/debug_worker/src/gdb/mod.rs index 51801d8e05..99b54ad7b7 100644 --- a/workers/debug_worker/src/gdb/mod.rs +++ b/workers/debug_worker/src/gdb/mod.rs @@ -4,7 +4,6 @@ use anyhow::Context; use futures::executor::block_on; use gdbstub::common::Tid; -use mesh::error::RemoteResultExt; use mesh::rpc::RpcSend; use std::num::NonZeroUsize; use vmm_core_defs::debug_rpc::DebugRequest; @@ -66,11 +65,10 @@ impl VmProxy { #[allow(dead_code)] // TODO: add monitor command to inspect physical memory? fn read_guest_physical_memory(&mut self, gpa: u64, data: &mut [u8]) -> anyhow::Result<()> { - let buf = block_on(self.req_chan.call( + let buf = block_on(self.req_chan.call_failable( DebugRequest::ReadMemory, (GuestAddress::Gpa(gpa), data.len()), )) - .flatten() .context("failed to read memory")?; data.copy_from_slice( buf.get(..data.len()) @@ -86,11 +84,10 @@ impl VmProxy { gva: u64, data: &mut [u8], ) -> anyhow::Result<()> { - let buf = block_on(self.req_chan.call( + let buf = block_on(self.req_chan.call_failable( DebugRequest::ReadMemory, (GuestAddress::Gva { vp: vp_index, gva }, data.len()), )) - .flatten() .context("failed to read memory")?; data.copy_from_slice( buf.get(..data.len()) @@ -106,11 +103,10 @@ impl VmProxy { gva: u64, data: &[u8], ) -> anyhow::Result<()> { - block_on(self.req_chan.call( + block_on(self.req_chan.call_failable( DebugRequest::WriteMemory, (GuestAddress::Gva { vp: vp_index, gva }, data.to_vec()), )) - .flatten() .context("failed to write memory")?; Ok(()) } diff --git a/workers/debug_worker/src/gdb/targets/base.rs b/workers/debug_worker/src/gdb/targets/base.rs index c88862754c..d0ec2d1263 100644 --- a/workers/debug_worker/src/gdb/targets/base.rs +++ b/workers/debug_worker/src/gdb/targets/base.rs @@ -14,7 +14,6 @@ use gdbstub::target::ext::base::multithread::MultiThreadSingleStep; use gdbstub::target::ext::base::multithread::MultiThreadSingleStepOps; use gdbstub::target::ext::base::single_register_access::SingleRegisterAccess; use gdbstub::target::TargetResult; -use mesh::error::RemoteResultExt; use mesh::rpc::RpcSend; use vmm_core_defs::debug_rpc::DebugRequest; use vmm_core_defs::debug_rpc::DebugState; @@ -23,9 +22,12 @@ impl MultiThreadBase for VmTarget<'_, T> { fn read_registers(&mut self, regs: &mut T::Registers, tid: Tid) -> TargetResult<(), Self> { let vp_index = self.0.tid_to_vp(tid).fatal()?; - let state = block_on(self.0.req_chan.call(DebugRequest::GetVpState, vp_index)) - .flatten() - .nonfatal()?; + let state = block_on( + self.0 + .req_chan + .call_failable(DebugRequest::GetVpState, vp_index), + ) + .nonfatal()?; T::registers(&state, regs)?; Ok(()) @@ -34,18 +36,20 @@ impl MultiThreadBase for VmTarget<'_, T> { fn write_registers(&mut self, regs: &T::Registers, tid: Tid) -> TargetResult<(), Self> { let vp_index = self.0.tid_to_vp(tid).fatal()?; - let mut state = block_on(self.0.req_chan.call(DebugRequest::GetVpState, vp_index)) - .flatten() - .nonfatal()?; + let mut state = block_on( + self.0 + .req_chan + .call_failable(DebugRequest::GetVpState, vp_index), + ) + .nonfatal()?; T::update_registers(&mut state, regs)?; block_on( self.0 .req_chan - .call(DebugRequest::SetVpState, (vp_index, state)), + .call_failable(DebugRequest::SetVpState, (vp_index, state)), ) - .flatten() .nonfatal()?; Ok(()) @@ -109,9 +113,12 @@ impl SingleRegisterAccess for VmTarget<'_, T> { ) -> TargetResult { let vp_index = self.0.tid_to_vp(tid).fatal()?; - let state = block_on(self.0.req_chan.call(DebugRequest::GetVpState, vp_index)) - .flatten() - .nonfatal()?; + let state = block_on( + self.0 + .req_chan + .call_failable(DebugRequest::GetVpState, vp_index), + ) + .nonfatal()?; Ok(T::register(&state, reg_id, buf)?) } @@ -119,18 +126,20 @@ impl SingleRegisterAccess for VmTarget<'_, T> { fn write_register(&mut self, tid: Tid, reg_id: T::RegId, val: &[u8]) -> TargetResult<(), Self> { let vp_index = self.0.tid_to_vp(tid).fatal()?; - let mut state = block_on(self.0.req_chan.call(DebugRequest::GetVpState, vp_index)) - .flatten() - .nonfatal()?; + let mut state = block_on( + self.0 + .req_chan + .call_failable(DebugRequest::GetVpState, vp_index), + ) + .nonfatal()?; T::update_register(&mut state, reg_id, val)?; block_on( self.0 .req_chan - .call(DebugRequest::SetVpState, (vp_index, state)), + .call_failable(DebugRequest::SetVpState, (vp_index, state)), ) - .flatten() .nonfatal()?; Ok(()) diff --git a/workers/debug_worker/src/lib.rs b/workers/debug_worker/src/lib.rs index 6166e200b9..b5671ccdf7 100644 --- a/workers/debug_worker/src/lib.rs +++ b/workers/debug_worker/src/lib.rs @@ -121,7 +121,7 @@ where Ok(message) => match message { WorkerRpc::Stop => return Ok(()), WorkerRpc::Inspect(deferred) => deferred.inspect(&mut server), - WorkerRpc::Restart(response) => { + WorkerRpc::Restart(rpc) => { let vm_proxy = match server.state { State::Listening { vm_proxy } => vm_proxy, State::Connected { task, abort, .. } => { @@ -148,7 +148,7 @@ where }, } }; - response.send(Ok(state)); + rpc.complete(Ok(state)); return Ok(()); } }, diff --git a/workers/vnc_worker/src/lib.rs b/workers/vnc_worker/src/lib.rs index f636f34876..afe15107a6 100644 --- a/workers/vnc_worker/src/lib.rs +++ b/workers/vnc_worker/src/lib.rs @@ -116,7 +116,7 @@ impl VncWorker { state: self.state, }; - let response = loop { + let rpc = loop { let r = futures::select! { // merge semantics r = rpc_recv.recv().fuse() => r, r = server.process(&driver).fuse() => break r.map(|_| None)?, @@ -130,7 +130,7 @@ impl VncWorker { Err(_) => break None, } }; - if let Some(response) = response { + if let Some(rpc) = rpc { let (view, input) = match server.state { State::Listening { view, input } => (view, input), State::Connected { task, abort, .. } => { @@ -144,7 +144,7 @@ impl VncWorker { framebuffer: view.0.access(), input_send: input.send, }; - response.send(Ok(state)); + rpc.complete(Ok(state)); } Ok(()) }) From 872dc4910a410b668f06c79586aab73c7210d4d5 Mon Sep 17 00:00:00 2001 From: John Starks Date: Mon, 9 Dec 2024 22:25:56 +0000 Subject: [PATCH 2/4] remove --- vm/devices/vmbus/vmbus_relay/src/hvsock.rs | 16 ---------------- 1 file changed, 16 deletions(-) diff --git a/vm/devices/vmbus/vmbus_relay/src/hvsock.rs b/vm/devices/vmbus/vmbus_relay/src/hvsock.rs index e65fc85026..dc6b550c3b 100644 --- a/vm/devices/vmbus/vmbus_relay/src/hvsock.rs +++ b/vm/devices/vmbus/vmbus_relay/src/hvsock.rs @@ -70,22 +70,6 @@ mod tests { use vmbus_server::Guid; use zerocopy::FromZeroes; - struct FakeRpcSender(Option); - - impl RpcSend for &mut FakeRpcSender { - type Message = T; - - fn send_rpc(self, message: Self::Message) { - self.0 = Some(message); - } - } - - fn dummy_rpc(t: T) -> Rpc { - let mut sender = FakeRpcSender(None); - let _ = sender.call(|x| x, t); - sender.0.unwrap() - } - #[test] fn test_check_result() { let mut tracker = HvsockRequestTracker::new(); From d168da1c6fa9c7fcbcc8185aca783ac2fee15f60 Mon Sep 17 00:00:00 2001 From: John Starks Date: Fri, 3 Jan 2025 06:14:44 +0000 Subject: [PATCH 3/4] fix --- vm/devices/get/guest_emulation_device/src/lib.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vm/devices/get/guest_emulation_device/src/lib.rs b/vm/devices/get/guest_emulation_device/src/lib.rs index 8333b8b514..6a2a8d5306 100644 --- a/vm/devices/get/guest_emulation_device/src/lib.rs +++ b/vm/devices/get/guest_emulation_device/src/lib.rs @@ -311,7 +311,7 @@ pub struct GedChannel { #[inspect(with = "Option::is_some")] vtl0_start_report: Option>, #[inspect(with = "Option::is_some")] - modify: Option>>, + modify: Option>>, // TODO: allow unused temporarily as a follow up change will use it to // implement AK cert renewal. #[inspect(skip)] From 9ea4eec7a2ee2b49d3aa2a8119fa69fea626f88e Mon Sep 17 00:00:00 2001 From: John Starks Date: Wed, 8 Jan 2025 17:02:00 +0000 Subject: [PATCH 4/4] fmt --- openhcl/underhill_core/src/dispatch/vtl2_settings_worker.rs | 2 +- openvmm/openvmm_entry/src/lib.rs | 2 +- petri/pipette_client/src/send.rs | 2 +- vm/devices/net/net_consomme/consomme/src/lib.rs | 2 +- vmm_core/src/partition_unit/vp_set.rs | 2 +- 5 files changed, 5 insertions(+), 5 deletions(-) diff --git a/openhcl/underhill_core/src/dispatch/vtl2_settings_worker.rs b/openhcl/underhill_core/src/dispatch/vtl2_settings_worker.rs index 0774e8403d..d3a01390be 100644 --- a/openhcl/underhill_core/src/dispatch/vtl2_settings_worker.rs +++ b/openhcl/underhill_core/src/dispatch/vtl2_settings_worker.rs @@ -18,8 +18,8 @@ use ide_resources::GuestMedia; use ide_resources::IdeControllerConfig; use ide_resources::IdeDeviceConfig; use ide_resources::IdePath; -use mesh::rpc::RpcError; use mesh::rpc::Rpc; +use mesh::rpc::RpcError; use mesh::rpc::RpcSend; use mesh::CancelContext; use nvme_resources::NamespaceDefinition; diff --git a/openvmm/openvmm_entry/src/lib.rs b/openvmm/openvmm_entry/src/lib.rs index 01326627f6..d2e2f333bf 100644 --- a/openvmm/openvmm_entry/src/lib.rs +++ b/openvmm/openvmm_entry/src/lib.rs @@ -72,8 +72,8 @@ use inspect::InspectMut; use inspect::InspectionBuilder; use io::Read; use mesh::error::RemoteError; -use mesh::rpc::RpcError; use mesh::rpc::Rpc; +use mesh::rpc::RpcError; use mesh::rpc::RpcSend; use mesh::CancelContext; use mesh_worker::launch_local_worker; diff --git a/petri/pipette_client/src/send.rs b/petri/pipette_client/src/send.rs index fbae46b5d8..53e0e10842 100644 --- a/petri/pipette_client/src/send.rs +++ b/petri/pipette_client/src/send.rs @@ -4,8 +4,8 @@ //! A thin wrapper around a `mesh::Sender` that provides //! useful error handling semantics. -use mesh::rpc::RpcError; use mesh::rpc::Rpc; +use mesh::rpc::RpcError; use mesh::rpc::RpcSend; use mesh::CancelContext; use pipette_protocol::PipetteRequest; diff --git a/vm/devices/net/net_consomme/consomme/src/lib.rs b/vm/devices/net/net_consomme/consomme/src/lib.rs index d0d01eaa76..29e7f410db 100644 --- a/vm/devices/net/net_consomme/consomme/src/lib.rs +++ b/vm/devices/net/net_consomme/consomme/src/lib.rs @@ -26,8 +26,8 @@ mod udp; mod windows; use inspect::InspectMut; -use mesh::rpc::RpcError; use mesh::rpc::Rpc; +use mesh::rpc::RpcError; use mesh::rpc::RpcSend; use pal_async::driver::Driver; use smoltcp::phy::Checksum; diff --git a/vmm_core/src/partition_unit/vp_set.rs b/vmm_core/src/partition_unit/vp_set.rs index c56215df5f..92529683b7 100644 --- a/vmm_core/src/partition_unit/vp_set.rs +++ b/vmm_core/src/partition_unit/vp_set.rs @@ -19,8 +19,8 @@ use futures_concurrency::stream::Merge; use guestmem::GuestMemory; use hvdef::Vtl; use inspect::Inspect; -use mesh::rpc::RpcError; use mesh::rpc::Rpc; +use mesh::rpc::RpcError; use mesh::rpc::RpcSend; use pal_async::local::block_with_io; use parking_lot::Mutex;