diff --git a/openhcl/profiler_worker/src/lib.rs b/openhcl/profiler_worker/src/lib.rs index 995fde7b8..6d121f662 100644 --- a/openhcl/profiler_worker/src/lib.rs +++ b/openhcl/profiler_worker/src/lib.rs @@ -97,8 +97,8 @@ impl Worker for ProfilerWorker { WorkerRpc::Stop => { break; } - WorkerRpc::Restart(response) => { - response.send(Err(RemoteError::new(anyhow::anyhow!("not supported")))); + WorkerRpc::Restart(rpc) => { + rpc.complete(Err(RemoteError::new(anyhow::anyhow!("not supported")))); } WorkerRpc::Inspect(_deferred) => {} }, diff --git a/openhcl/underhill_core/src/diag.rs b/openhcl/underhill_core/src/diag.rs index 0b5e7aabc..2a3b72ae8 100644 --- a/openhcl/underhill_core/src/diag.rs +++ b/openhcl/underhill_core/src/diag.rs @@ -65,8 +65,8 @@ impl Worker for DiagWorker { }; match msg { WorkerRpc::Stop => break Ok(()), - WorkerRpc::Restart(response) => { - response.send(Err(RemoteError::new(anyhow::anyhow!("not supported")))); + WorkerRpc::Restart(rpc) => { + rpc.complete(Err(RemoteError::new(anyhow::anyhow!("not supported")))); } WorkerRpc::Inspect(_) => {} } diff --git a/openhcl/underhill_core/src/dispatch/mod.rs b/openhcl/underhill_core/src/dispatch/mod.rs index 6e503aeca..e1e2e85c6 100644 --- a/openhcl/underhill_core/src/dispatch/mod.rs +++ b/openhcl/underhill_core/src/dispatch/mod.rs @@ -33,7 +33,6 @@ use hyperv_ic_resources::shutdown::ShutdownType; use igvm_defs::MemoryMapEntryType; use inspect::Inspect; use mesh::error::RemoteError; -use mesh::error::RemoteResult; use mesh::rpc::FailableRpc; use mesh::rpc::Rpc; use mesh::rpc::RpcSend; @@ -183,7 +182,7 @@ pub(crate) struct LoadedVm { } pub struct LoadedVmState { - pub restart_response: mesh::OneshotSender>, + pub restart_rpc: FailableRpc<(), T>, pub servicing_state: ServicingState, pub vm_rpc: mesh::Receiver, pub control_send: mesh::Sender, @@ -262,16 +261,16 @@ impl LoadedVm { Event::WorkerRpcGone => break None, Event::WorkerRpc(message) => match message { WorkerRpc::Stop => break None, - WorkerRpc::Restart(response) => { + WorkerRpc::Restart(rpc) => { let state = async { let running = self.stop().await; match self.save(None, false).await { - Ok(servicing_state) => Some((response, servicing_state)), + Ok(servicing_state) => Some((rpc, servicing_state)), Err(err) => { if running { self.start(None).await; } - response.send(Err(RemoteError::new(err))); + rpc.complete(Err(RemoteError::new(err))); None } } @@ -279,9 +278,9 @@ impl LoadedVm { .instrument(tracing::info_span!("restart")) .await; - if let Some((response, servicing_state)) = state { + if let Some((rpc, servicing_state)) = state { break Some(LoadedVmState { - restart_response: response, + restart_rpc: rpc, servicing_state, vm_rpc, control_send: self.control_send.lock().take().unwrap(), diff --git a/openhcl/underhill_core/src/dispatch/vtl2_settings_worker.rs b/openhcl/underhill_core/src/dispatch/vtl2_settings_worker.rs index 9654656c1..d3a01390b 100644 --- a/openhcl/underhill_core/src/dispatch/vtl2_settings_worker.rs +++ b/openhcl/underhill_core/src/dispatch/vtl2_settings_worker.rs @@ -19,6 +19,7 @@ use ide_resources::IdeControllerConfig; use ide_resources::IdeDeviceConfig; use ide_resources::IdePath; use mesh::rpc::Rpc; +use mesh::rpc::RpcError; use mesh::rpc::RpcSend; use mesh::CancelContext; use nvme_resources::NamespaceDefinition; @@ -61,7 +62,7 @@ use vm_resource::ResourceResolver; #[derive(Error, Debug)] enum Error<'a> { #[error("RPC error")] - Rpc(#[source] mesh::RecvError), + Rpc(#[source] RpcError), #[error("cannot add/remove storage controllers at runtime")] StorageCannotAddRemoveControllerAtRuntime, #[error("Striping devices don't support runtime change")] diff --git a/openhcl/underhill_core/src/get_tracing.rs b/openhcl/underhill_core/src/get_tracing.rs index 0ad3b0d1f..c9c8c4613 100644 --- a/openhcl/underhill_core/src/get_tracing.rs +++ b/openhcl/underhill_core/src/get_tracing.rs @@ -152,13 +152,13 @@ impl GetTracingBackend { ) .merge(); - let flush_response = loop { + let flush_rpc = loop { let trace_type = streams.next().await.unwrap(); match trace_type { Event::Trace(data) => { write.send(&data).await.ok(); } - Event::Flush(Rpc((), response)) => break Some(response), + Event::Flush(rpc) => break Some(rpc), Event::Done => break None, } }; @@ -174,8 +174,8 @@ impl GetTracingBackend { // Wait for the host to read everything. write.wait_empty().await.ok(); - if let Some(resp) = flush_response { - resp.send(()); + if let Some(rpc) = flush_rpc { + rpc.complete(()); } else { break; } diff --git a/openhcl/underhill_core/src/lib.rs b/openhcl/underhill_core/src/lib.rs index f9356451d..667213715 100644 --- a/openhcl/underhill_core/src/lib.rs +++ b/openhcl/underhill_core/src/lib.rs @@ -470,7 +470,7 @@ async fn run_control( Control(ControlRequest), } - let mut restart_response = None; + let mut restart_rpc = None; loop { let event = { let mut stream = ( @@ -551,7 +551,7 @@ async fn run_control( }; let r = async { - if restart_response.is_some() { + if restart_rpc.is_some() { anyhow::bail!("previous restart still in progress"); } @@ -568,7 +568,7 @@ async fn run_control( rpc.complete(r.map_err(RemoteError::new)); } else { state = ControlState::Restarting; - restart_response = Some(rpc.1); + restart_rpc = Some(rpc); } } diag_server::DiagRequest::Pause(rpc) => { @@ -647,7 +647,7 @@ async fn run_control( } #[cfg(feature = "profiler")] diag_server::DiagRequest::Profile(rpc) => { - let Rpc(rpc_params, rpc_sender) = rpc; + let (rpc_params, rpc_sender) = rpc.split(); // Create profiler host if there is none created before if profiler_host.is_none() { match launch_mesh_host(mesh, "profiler", Some(tracing.tracer())) @@ -658,7 +658,7 @@ async fn run_control( profiler_host = Some(host); } Err(e) => { - rpc_sender.send(Err(RemoteError::new(e))); + rpc_sender.complete(Err(RemoteError::new(e))); continue; } } @@ -680,7 +680,7 @@ async fn run_control( profiler_worker = worker; } Err(e) => { - rpc_sender.send(Err(RemoteError::new(e))); + rpc_sender.complete(Err(RemoteError::new(e))); continue; } } @@ -695,7 +695,7 @@ async fn run_control( .and_then(|result| result.context("profiler worker failed")) .map_err(RemoteError::new); - rpc_sender.send(result); + rpc_sender.complete(result); }) .detach(); } @@ -703,9 +703,9 @@ async fn run_control( } Event::Worker(event) => match event { WorkerEvent::Started => { - if let Some(response) = restart_response.take() { + if let Some(response) = restart_rpc.take() { tracing::info!("restart complete"); - response.send(Ok(())); + response.complete(Ok(())); } else { tracing::info!("vm worker started"); } @@ -719,7 +719,7 @@ async fn run_control( } WorkerEvent::RestartFailed(err) => { tracing::error!(error = &err as &dyn std::error::Error, "restart failed"); - restart_response.take().unwrap().send(Err(err)); + restart_rpc.take().unwrap().complete(Err(err)); state = ControlState::Started; } }, diff --git a/openhcl/underhill_core/src/worker.rs b/openhcl/underhill_core/src/worker.rs index 1ec5e794e..abede74c4 100644 --- a/openhcl/underhill_core/src/worker.rs +++ b/openhcl/underhill_core/src/worker.rs @@ -430,7 +430,7 @@ impl Worker for UnderhillVmWorker { }; tracing::info!("sending worker restart state"); - state.restart_response.send(Ok(RestartState { + state.restart_rpc.complete(Ok(RestartState { params, servicing_state: state.servicing_state, })) diff --git a/openvmm/hvlite_core/src/worker/dispatch.rs b/openvmm/hvlite_core/src/worker/dispatch.rs index 1f7c87bd5..035c9d858 100644 --- a/openvmm/hvlite_core/src/worker/dispatch.rs +++ b/openvmm/hvlite_core/src/worker/dispatch.rs @@ -59,7 +59,6 @@ use memory_range::MemoryRange; use mesh::error::RemoteError; use mesh::payload::message::ProtobufMessage; use mesh::payload::Protobuf; -use mesh::rpc::Rpc; use mesh::MeshPayload; use mesh_worker::Worker; use mesh_worker::WorkerId; @@ -2477,7 +2476,7 @@ impl LoadedVm { pub async fn run( mut self, driver: &impl Spawn, - mut rpc: mesh::Receiver, + mut rpc_recv: mesh::Receiver, mut worker_rpc: mesh::Receiver>, ) { enum Event { @@ -2508,7 +2507,7 @@ impl LoadedVm { loop { let event: Event = { - let a = rpc.recv().map(Event::VmRpc); + let a = rpc_recv.recv().map(Event::VmRpc); let b = worker_rpc.recv().map(Event::WorkerRpc); (a, b).race().await }; @@ -2517,7 +2516,7 @@ impl LoadedVm { Event::WorkerRpc(Err(_)) => break, Event::WorkerRpc(Ok(message)) => match message { WorkerRpc::Stop => break, - WorkerRpc::Restart(response) => { + WorkerRpc::Restart(rpc) => { let mut stopped = false; // First run the non-destructive operations. let r = async { @@ -2532,8 +2531,8 @@ impl LoadedVm { .await; match r { Ok((shared_memory, saved_state)) => { - response.send(Ok(self - .serialize(rpc, shared_memory, saved_state) + rpc.complete(Ok(self + .serialize(rpc_recv, shared_memory, saved_state) .await)); return; @@ -2542,7 +2541,7 @@ impl LoadedVm { if stopped { self.state_units.start().await; } - response.send(Err(RemoteError::new(err))); + rpc.complete(Err(RemoteError::new(err))); } } } @@ -2618,16 +2617,17 @@ impl LoadedVm { }) .await } - VmRpc::ConnectHvsock(Rpc((mut ctx, service_id, vtl), response)) => { + VmRpc::ConnectHvsock(rpc) => { + let ((mut ctx, service_id, vtl), response) = rpc.split(); if let Some(relay) = self.hvsock_relay(vtl) { let fut = relay.connect(&mut ctx, service_id); driver .spawn("vmrpc-hvsock-connect", async move { - response.send(fut.await.map_err(RemoteError::new)) + response.complete(fut.await.map_err(RemoteError::new)) }) .detach(); } else { - response.send(Err(RemoteError::new(anyhow::anyhow!( + response.complete(Err(RemoteError::new(anyhow::anyhow!( "hvsock is not available" )))); } diff --git a/openvmm/hvlite_helpers/src/underhill.rs b/openvmm/hvlite_helpers/src/underhill.rs index 033f6bc56..8121c2a12 100644 --- a/openvmm/hvlite_helpers/src/underhill.rs +++ b/openvmm/hvlite_helpers/src/underhill.rs @@ -6,7 +6,6 @@ use anyhow::Context; use get_resources::ged::GuestEmulationRequest; use hvlite_defs::rpc::VmRpc; -use mesh::error::RemoteResultExt; use mesh::rpc::RpcSend; /// Replace the running version of Underhill. @@ -28,9 +27,8 @@ pub async fn service_underhill( // blocked while waiting for the guest. tracing::debug!("waiting for guest to send saved state"); let r = send - .call(GuestEmulationRequest::SaveGuestVtl2State, ()) + .call_failable(GuestEmulationRequest::SaveGuestVtl2State, ()) .await - .flatten() .context("failed to save VTL2 state"); if r.is_err() { @@ -51,9 +49,8 @@ pub async fn service_underhill( // // TODO: event driven, cancellable. tracing::debug!("waiting for VTL0 to start"); - send.call(GuestEmulationRequest::WaitForVtl0Start, ()) + send.call_failable(GuestEmulationRequest::WaitForVtl0Start, ()) .await - .flatten() .context("vtl0 start failed")?; Ok(()) diff --git a/openvmm/membacking/src/mapping_manager/va_mapper.rs b/openvmm/membacking/src/mapping_manager/va_mapper.rs index d70c211d2..27e4ef918 100644 --- a/openvmm/membacking/src/mapping_manager/va_mapper.rs +++ b/openvmm/membacking/src/mapping_manager/va_mapper.rs @@ -32,6 +32,7 @@ use futures::executor::block_on; use guestmem::GuestMemoryAccess; use guestmem::PageFaultAction; use memory_range::MemoryRange; +use mesh::rpc::RpcError; use mesh::rpc::RpcSend; use parking_lot::Mutex; use sparse_mmap::SparseMapping; @@ -156,7 +157,7 @@ impl MapperTask { #[derive(Debug, Error)] pub enum VaMapperError { #[error("failed to communicate with the memory manager")] - MemoryManagerGone(#[source] mesh::RecvError), + MemoryManagerGone(#[source] RpcError), #[error("failed to reserve address space")] Reserve(#[source] std::io::Error), } diff --git a/openvmm/membacking/src/region_manager.rs b/openvmm/membacking/src/region_manager.rs index 4be04c53b..aca030db4 100644 --- a/openvmm/membacking/src/region_manager.rs +++ b/openvmm/membacking/src/region_manager.rs @@ -628,9 +628,7 @@ impl RegionHandle { impl Drop for RegionHandle { fn drop(&mut self) { if let Some(id) = self.id { - let (send, _recv) = mesh::oneshot(); - self.req_send - .send(RegionRequest::RemoveRegion(Rpc(id, send))); + let _recv = self.req_send.call(RegionRequest::RemoveRegion, id); // Don't wait for the response. } } diff --git a/openvmm/openvmm_entry/src/lib.rs b/openvmm/openvmm_entry/src/lib.rs index be3da2dee..d2e2f333b 100644 --- a/openvmm/openvmm_entry/src/lib.rs +++ b/openvmm/openvmm_entry/src/lib.rs @@ -73,9 +73,9 @@ use inspect::InspectionBuilder; use io::Read; use mesh::error::RemoteError; use mesh::rpc::Rpc; +use mesh::rpc::RpcError; use mesh::rpc::RpcSend; use mesh::CancelContext; -use mesh::RecvError; use mesh_worker::launch_local_worker; use mesh_worker::WorkerEvent; use mesh_worker::WorkerHandle; @@ -2092,7 +2092,7 @@ async fn run_control(driver: &DefaultDriver, mesh: &VmmMesh, opt: Options) -> an }) .unwrap(); - let mut state_change_task = None::>>; + let mut state_change_task = None::>>; let mut pulse_save_restore_interval: Option = None; let mut pending_shutdown = None; @@ -2113,8 +2113,8 @@ async fn run_control(driver: &DefaultDriver, mesh: &VmmMesh, opt: Options) -> an PulseSaveRestore, Worker(WorkerEvent), VncWorker(WorkerEvent), - StateChange(Result), - ShutdownResult(Result), + StateChange(Result), + ShutdownResult(Result), } let mut console_command_recv = console_command_recv @@ -2360,7 +2360,7 @@ async fn run_control(driver: &DefaultDriver, mesh: &VmmMesh, opt: Options) -> an fn state_change( driver: impl Spawn, vm_rpc: &mesh::Sender, - state_change_task: &mut Option>>, + state_change_task: &mut Option>>, f: impl FnOnce(Rpc<(), U>) -> VmRpc, g: impl FnOnce(U) -> StateChange + 'static + Send, ) { diff --git a/openvmm/openvmm_entry/src/ttrpc/mod.rs b/openvmm/openvmm_entry/src/ttrpc/mod.rs index 3a699ac49..579138f77 100644 --- a/openvmm/openvmm_entry/src/ttrpc/mod.rs +++ b/openvmm/openvmm_entry/src/ttrpc/mod.rs @@ -33,7 +33,6 @@ use inspect_proto::InspectResponse2; use inspect_proto::InspectService; use inspect_proto::UpdateResponse2; use mesh::error::RemoteError; -use mesh::rpc::Rpc; use mesh::rpc::RpcSend; use mesh::CancelReason; use mesh::MeshPayload; @@ -197,7 +196,7 @@ impl VmService { }, request = recv.recv().fuse() => { match request { - Ok(WorkerRpc::Restart(response)) => response.send(Err(RemoteError::new(anyhow::anyhow!("not supported")))), + Ok(WorkerRpc::Restart(rpc)) => rpc.complete(Err(RemoteError::new(anyhow::anyhow!("not supported")))), Ok(WorkerRpc::Inspect(_)) => (), Ok(WorkerRpc::Stop) => { tracing::info!("ttrpc worker stopping"); @@ -605,14 +604,12 @@ impl VmService { } fn pause_vm(&mut self, vm: &Vm) -> impl Future> { - let (send, recv) = mesh::oneshot(); - vm.worker_rpc.send(VmRpc::Pause(Rpc((), send))); + let recv = vm.worker_rpc.call(VmRpc::Pause, ()); async move { recv.await.map(drop).context("pause failed") } } fn resume_vm(&mut self, vm: &Vm) -> impl Future> { - let (send, recv) = mesh::oneshot(); - vm.worker_rpc.send(VmRpc::Resume(Rpc((), send))); + let recv = vm.worker_rpc.call(VmRpc::Resume, ()); async move { recv.await.map(drop).context("resume failed") } } diff --git a/petri/pipette_client/src/lib.rs b/petri/pipette_client/src/lib.rs index 8bbc4f1c9..73fe2e6f5 100644 --- a/petri/pipette_client/src/lib.rs +++ b/petri/pipette_client/src/lib.rs @@ -22,6 +22,7 @@ use futures::AsyncWriteExt; use futures::StreamExt; use futures::TryFutureExt; use futures_concurrency::future::TryJoin; +use mesh::rpc::RpcError; use mesh_remote::PointToPointMesh; use pal_async::task::Spawn; use pal_async::task::Task; @@ -81,7 +82,7 @@ impl PipetteClient { } /// Pings the agent to check if it's alive. - pub async fn ping(&self) -> Result<(), mesh::RecvError> { + pub async fn ping(&self) -> Result<(), RpcError> { self.send.call(PipetteRequest::Ping, ()).await } diff --git a/petri/pipette_client/src/process.rs b/petri/pipette_client/src/process.rs index 2c0c1536f..83344c9ae 100644 --- a/petri/pipette_client/src/process.rs +++ b/petri/pipette_client/src/process.rs @@ -9,7 +9,6 @@ use futures::executor::block_on; use futures::io::AllowStdIo; use futures::AsyncReadExt; use futures_concurrency::future::Join; -use mesh::error::RemoteResultExt; use mesh::pipe::ReadPipe; use mesh::pipe::WritePipe; use pipette_protocol::EnvPair; @@ -167,9 +166,8 @@ impl<'a> Command<'a> { let response = self .client .send - .call(PipetteRequest::Execute, request) + .call_failable(PipetteRequest::Execute, request) .await - .flatten() .with_context(|| format!("failed to execute {}", self.program))?; Ok(Child { diff --git a/petri/pipette_client/src/send.rs b/petri/pipette_client/src/send.rs index e824b2f1c..53e0e1084 100644 --- a/petri/pipette_client/src/send.rs +++ b/petri/pipette_client/src/send.rs @@ -5,6 +5,7 @@ //! useful error handling semantics. use mesh::rpc::Rpc; +use mesh::rpc::RpcError; use mesh::rpc::RpcSend; use mesh::CancelContext; use pipette_protocol::PipetteRequest; @@ -20,14 +21,30 @@ impl PipetteSender { /// A wrapper around [`mesh::Sender::call`] that will sleep for 5 seconds on failure, /// allowing any additional work occurring on the system to hopefully complete. /// See also [`petri::PetriVm::wait_for_halt_or`] - pub(crate) async fn call(&self, f: F, input: I) -> Result + pub(crate) async fn call(&self, f: F, input: I) -> Result where F: FnOnce(Rpc) -> PipetteRequest, R: 'static + Send, { - let (result_send, result_recv) = mesh::oneshot(); - self.0.send_rpc(f(Rpc(input, result_send))); - let result = result_recv.await; + let result = self.0.call(f, input).await; + if result.is_err() { + tracing::warn!("Pipette request channel failed, sleeping for 5 seconds to let outstanding work finish"); + let mut c = CancelContext::new().with_timeout(Duration::from_secs(5)); + let _ = c.cancelled().await; + } + result + } + + /// A wrapper around [`mesh::Sender::call_failable`] that will sleep for 5 seconds on failure, + /// allowing any additional work occurring on the system to hopefully complete. + /// See also [`petri::PetriVm::wait_for_halt_or`] + pub(crate) async fn call_failable(&self, f: F, input: I) -> Result> + where + F: FnOnce(Rpc>) -> PipetteRequest, + T: 'static + Send, + E: 'static + Send, + { + let result = self.0.call_failable(f, input).await; if result.is_err() { tracing::warn!("Pipette request channel failed, sleeping for 5 seconds to let outstanding work finish"); let mut c = CancelContext::new().with_timeout(Duration::from_secs(5)); diff --git a/petri/src/vm/runtime.rs b/petri/src/vm/runtime.rs index f686eee33..38a98462b 100644 --- a/petri/src/vm/runtime.rs +++ b/petri/src/vm/runtime.rs @@ -12,6 +12,7 @@ use futures::FutureExt; use futures_concurrency::future::Race; use hvlite_defs::rpc::PulseSaveRestoreError; use hyperv_ic_resources::shutdown::ShutdownRpc; +use mesh::rpc::RpcError; use mesh::rpc::RpcSend; use mesh::CancelContext; use mesh::Receiver; @@ -423,14 +424,15 @@ impl PetriVmInner { async fn verify_save_restore(&self) -> anyhow::Result<()> { for i in 0..2 { - let result = self.worker.pulse_save_restore().await?; + let result = self.worker.pulse_save_restore().await; match result { Ok(()) => {} - Err(PulseSaveRestoreError::ResetNotSupported) => { + Err(RpcError::Channel(err)) => return Err(err.into()), + Err(RpcError::Call(PulseSaveRestoreError::ResetNotSupported)) => { tracing::warn!("Reset not supported, could not test save + restore."); break; } - Err(PulseSaveRestoreError::Other(err)) => { + Err(RpcError::Call(PulseSaveRestoreError::Other(err))) => { return Err(anyhow::Error::from(err)) .context(format!("Save + restore {i} failed.")); } diff --git a/petri/src/worker.rs b/petri/src/worker.rs index 71d62e5d2..77aafe24f 100644 --- a/petri/src/worker.rs +++ b/petri/src/worker.rs @@ -6,6 +6,7 @@ use hvlite_defs::rpc::PulseSaveRestoreError; use hvlite_defs::rpc::VmRpc; use hvlite_defs::worker::VmWorkerParameters; use hvlite_defs::worker::VM_WORKER; +use mesh::rpc::RpcError; use mesh::rpc::RpcSend; use mesh_worker::WorkerHandle; use mesh_worker::WorkerHost; @@ -42,7 +43,7 @@ impl Worker { )) } - pub(crate) async fn resume(&self) -> Result { + pub(crate) async fn resume(&self) -> Result { self.rpc.call(VmRpc::Resume, ()).await } @@ -51,10 +52,8 @@ impl Worker { Ok(()) } - pub(crate) async fn pulse_save_restore( - &self, - ) -> Result, mesh::RecvError> { - self.rpc.call(VmRpc::PulseSaveRestore, ()).await + pub(crate) async fn pulse_save_restore(&self) -> Result<(), RpcError> { + self.rpc.call_failable(VmRpc::PulseSaveRestore, ()).await } pub(crate) async fn restart_openhcl( diff --git a/support/mesh/mesh_channel/src/error.rs b/support/mesh/mesh_channel/src/error.rs index bf3236418..3d9df098c 100644 --- a/support/mesh/mesh_channel/src/error.rs +++ b/support/mesh/mesh_channel/src/error.rs @@ -3,12 +3,10 @@ //! Remotable errors. -use crate::RecvError; use mesh_protobuf::EncodeAs; use mesh_protobuf::Protobuf; use std::fmt; use std::fmt::Display; -use thiserror::Error; /// An error that can be remoted across a mesh channel. /// @@ -115,34 +113,3 @@ impl std::error::Error for DecodedError { /// Alias for a [`Result`] with a [`RemoteError`] error. pub type RemoteResult = Result; - -/// An error from an RPC call, via -/// [`RpcSend::call_failable`](super::rpc::RpcSend::call_failable). -#[derive(Debug, Error)] -pub enum RpcError { - #[error(transparent)] - Call(E), - #[error(transparent)] - Channel(RecvError), -} - -/// Extension trait to [`Result`] for folding `Result, RecvError>` -/// to `RpcError`. -pub trait RemoteResultExt { - type Flattened; - - /// Flattens the result into `RpcError`. - fn flatten(self) -> Self::Flattened; -} - -impl RemoteResultExt for Result, RecvError> { - type Flattened = Result>; - - fn flatten(self) -> Self::Flattened { - match self { - Ok(Ok(t)) => Ok(t), - Ok(Err(e)) => Err(RpcError::Call(e)), - Err(e) => Err(RpcError::Channel(e)), - } - } -} diff --git a/support/mesh/mesh_channel/src/rpc.rs b/support/mesh/mesh_channel/src/rpc.rs index 589309f72..11f397ea9 100644 --- a/support/mesh/mesh_channel/src/rpc.rs +++ b/support/mesh/mesh_channel/src/rpc.rs @@ -5,17 +5,18 @@ use super::error::RemoteResult; use crate::error::RemoteError; -use crate::error::RemoteResultExt; -use crate::error::RpcError; use crate::oneshot; use crate::OneshotReceiver; use crate::OneshotSender; +use crate::RecvError; use mesh_node::message::MeshField; use mesh_protobuf::Protobuf; +use std::convert::Infallible; use std::future::Future; use std::pin::Pin; use std::task::ready; use std::task::Poll; +use thiserror::Error; /// An RPC message for a request with input of type `I` and output of type `R`. /// The receiver of the message should process the request and return results @@ -25,12 +26,31 @@ use std::task::Poll; bound = "I: 'static + MeshField + Send, R: 'static + MeshField + Send", resource = "mesh_node::resource::Resource" )] -pub struct Rpc(pub I, pub OneshotSender); +pub struct Rpc(I, OneshotSender); /// An RPC message with a failable result. pub type FailableRpc = Rpc>; impl Rpc { + /// Returns a new RPC message with `input` and no one listening for the + /// result. + pub fn detached(input: I) -> Self { + let (result_send, _) = oneshot(); + Rpc(input, result_send) + } + + /// Returns the input to the RPC. + pub fn input(&self) -> &I { + &self.0 + } + + /// Splits the RPC into its input and an input-less RPC. This is useful when + /// the input is needed in one place but the RPC will be completed in + /// another. + pub fn split(self) -> (I, Rpc<(), R>) { + (self.0, Rpc((), self.1)) + } + /// Handles an RPC request by calling `f` and sending the result to the /// initiator. pub fn handle_sync(self, f: F) @@ -101,12 +121,12 @@ impl Rpc> { } /// A trait implemented by objects that can send RPC requests. -pub trait RpcSend { +pub trait RpcSend: Sized { /// The message type for this sender. type Message; /// Send an RPC request. - fn send_rpc(&self, message: Self::Message); + fn send_rpc(self, message: Self::Message); /// Issues a request and returns a channel to receive the result. /// @@ -127,14 +147,14 @@ pub trait RpcSend { /// assert_eq!(send.call(Request::Add, (3, 4)).await.unwrap(), 7); /// } /// ``` - fn call(&self, f: F, input: I) -> OneshotReceiver + fn call(self, f: F, input: I) -> PendingRpc where F: FnOnce(Rpc) -> Self::Message, R: 'static + Send, { let (result_send, result_recv) = oneshot(); self.send_rpc(f(Rpc(input, result_send))); - result_recv + PendingRpc(result_recv) } /// Issues a request and returns an object to receive the result. @@ -142,55 +162,145 @@ pub trait RpcSend { /// This is like [`RpcSend::call`], but for RPCs that return a [`Result`]. /// The returned object combines the channel error and the call's error into /// a single [`RpcError`] type, which makes it easier to handle errors. - fn call_failable(&self, f: F, input: I) -> RpcResultReceiver> + fn call_failable(self, f: F, input: I) -> PendingFailableRpc where F: FnOnce(Rpc>) -> Self::Message, T: 'static + Send, E: 'static + Send, { - RpcResultReceiver(self.call(f, input)) + PendingFailableRpc(self.call(f, input)) + } +} + +/// A trait implemented by objects that can try to send RPC requests but may +/// fail. +pub trait TryRpcSend: Sized { + /// The message type for this sender. + type Message; + /// The error type returned when sending an RPC request fails. + type Error; + + /// Tries to send an RPC request. + fn try_send_rpc(self, message: Self::Message) -> Result<(), Self::Error>; + + /// Issues a request and returns a channel to receive the result. + /// + /// `f` maps an [`Rpc`] object to the message type and is often an enum + /// variant name. + /// + /// `input` is the input to the call. + /// + /// # Example + /// + /// ```rust + /// # use mesh_channel::rpc::{Rpc, RpcSend}; + /// # use mesh_channel::Sender; + /// enum Request { + /// Add(Rpc<(u32, u32), u32>), + /// } + /// async fn add(send: &Sender) { + /// assert_eq!(send.call(Request::Add, (3, 4)).await.unwrap(), 7); + /// } + /// ``` + fn try_call(self, f: F, input: I) -> Result, Self::Error> + where + F: FnOnce(Rpc) -> Self::Message, + R: 'static + Send, + { + let (result_send, result_recv) = oneshot(); + self.try_send_rpc(f(Rpc(input, result_send)))?; + Ok(PendingRpc(result_recv)) + } + + /// Issues a request and returns an object to receive the result. + /// + /// This is like [`TryRpcSend::try_call`], but for RPCs that return a + /// [`Result`]. The returned object combines the channel error and the + /// call's error into a single [`RpcError`] type, which makes it easier to + /// handle errors. + fn try_call_failable( + self, + f: F, + input: I, + ) -> Result, Self::Error> + where + F: FnOnce(Rpc>) -> Self::Message, + T: 'static + Send, + E: 'static + Send, + { + Ok(PendingFailableRpc(self.try_call(f, input)?)) + } +} + +/// An error from an RPC call, via +/// [`RpcSend::call_failable`] or [`RpcSend::call`]. +#[derive(Debug, Error)] +pub enum RpcError { + #[error(transparent)] + Call(E), + #[error(transparent)] + Channel(RecvError), +} + +/// The result future of an [`RpcSend::call`] call. +#[must_use] +#[derive(Debug)] +pub struct PendingRpc(OneshotReceiver); + +impl Future for PendingRpc { + type Output = Result>; + + fn poll(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll { + Poll::Ready(ready!(Pin::new(&mut self.get_mut().0).poll(cx)).map_err(RpcError::Channel)) } } /// The result future of an [`RpcSend::call_failable`] call. #[must_use] -pub struct RpcResultReceiver(OneshotReceiver); +#[derive(Debug)] +pub struct PendingFailableRpc(PendingRpc>); -impl Future for RpcResultReceiver> { +impl Future for PendingFailableRpc { type Output = Result>; fn poll(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll { - Poll::Ready(ready!(Pin::new(&mut self.get_mut().0).poll(cx)).flatten()) + let r = ready!(Pin::new(&mut self.get_mut().0).poll(cx)); + match r { + Ok(Ok(t)) => Ok(t), + Ok(Err(e)) => Err(RpcError::Call(e)), + Err(RpcError::Channel(e)) => Err(RpcError::Channel(e)), + } + .into() } } -#[cfg(feature = "newchan")] -impl RpcSend for mesh_channel_core::Sender { +impl RpcSend for OneshotSender { type Message = T; - fn send_rpc(&self, message: T) { + fn send_rpc(self, message: T) { self.send(message); } } -#[cfg(not(feature = "newchan_spsc"))] -impl RpcSend for crate::Sender { +#[cfg(feature = "newchan")] +impl RpcSend for &mesh_channel_core::Sender { type Message = T; - fn send_rpc(&self, message: T) { + fn send_rpc(self, message: T) { self.send(message); } } -#[cfg(not(feature = "newchan_mpsc"))] -impl RpcSend for crate::MpscSender { +#[cfg(not(feature = "newchan_spsc"))] +impl RpcSend for &crate::Sender { type Message = T; fn send_rpc(&self, message: T) { self.send(message); } } -impl RpcSend for &T { - type Message = T::Message; - fn send_rpc(&self, message: T::Message) { - (*self).send_rpc(message); +#[cfg(not(feature = "newchan_mpsc"))] +impl RpcSend for &crate::MpscSender { + type Message = T; + fn send_rpc(self, message: T) { + self.send(message); } } diff --git a/support/mesh/mesh_worker/src/worker.rs b/support/mesh/mesh_worker/src/worker.rs index 284009404..ca3c43e6c 100644 --- a/support/mesh/mesh_worker/src/worker.rs +++ b/support/mesh/mesh_worker/src/worker.rs @@ -11,8 +11,8 @@ use futures::StreamExt; use futures_concurrency::stream::Merge; use inspect::Inspect; use mesh::error::RemoteError; -use mesh::error::RemoteResult; -use mesh::error::RemoteResultExt; +use mesh::rpc::FailableRpc; +use mesh::rpc::RpcSend; use mesh::MeshPayload; use std::fmt; use std::marker::PhantomData; @@ -85,7 +85,7 @@ pub enum WorkerRpc { Stop, /// Tear down and send the state necessary to restart on the provided /// channel. - Restart(mesh::OneshotSender>), + Restart(FailableRpc<(), T>), /// Inspect the worker. Inspect(inspect::Deferred), } @@ -446,9 +446,8 @@ impl WorkerLaunchRequest { } } LaunchType::Restart { send, events } => { - let (state_send, state_recv) = mesh::oneshot(); - send.send(WorkerRpc::Restart(state_send)); - let state = match block_on(state_recv).flatten() { + let state_recv = send.call_failable(WorkerRpc::Restart, ()); + let state = match block_on(state_recv) { Ok(state) => state, Err(err) => { self.events @@ -763,8 +762,8 @@ mod tests { while let Ok(req) = recv.recv().await { match req { WorkerRpc::Stop => break, - WorkerRpc::Restart(state_send) => { - state_send.send(Ok(TestWorkerState { value: self.value })); + WorkerRpc::Restart(rpc) => { + rpc.complete(Ok(TestWorkerState { value: self.value })); break; } WorkerRpc::Inspect(_deferred) => (), @@ -795,8 +794,8 @@ mod tests { while let Ok(req) = recv.recv().await { match req { WorkerRpc::Stop => break, - WorkerRpc::Restart(state_send) => { - state_send.send(Ok(())); + WorkerRpc::Restart(rpc) => { + rpc.complete(Ok(())); break; } WorkerRpc::Inspect(_deferred) => (), diff --git a/vm/devices/get/guest_crash_device/src/lib.rs b/vm/devices/get/guest_crash_device/src/lib.rs index a9574976b..470961760 100644 --- a/vm/devices/get/guest_crash_device/src/lib.rs +++ b/vm/devices/get/guest_crash_device/src/lib.rs @@ -16,9 +16,8 @@ use get_protocol::crash::CRASHDUMP_GUID; use guid::Guid; use inspect::Inspect; use inspect::InspectMut; -use mesh::error::RemoteResult; -use mesh::error::RemoteResultExt; use mesh::rpc::FailableRpc; +use mesh::rpc::PendingFailableRpc; use mesh::rpc::RpcSend; use std::fs::File; use std::io::Seek; @@ -93,7 +92,7 @@ enum ProtocolState { enum DumpState { OpeningFile { - recv: mesh::OneshotReceiver>, + recv: PendingFailableRpc, }, Writing { file: File, @@ -227,7 +226,7 @@ impl GuestCrashDevice { } crash::MessageType::REQUEST_NIX_DUMP_START_V1 => { let (send, recv) = mesh::oneshot(); - let recv = self.request_dump.call(|x| x, recv); + let recv = self.request_dump.call_failable(|x| x, recv); channel.state = ProtocolState::DumpRequested { activity_id: header.activity_id, done: send, @@ -245,7 +244,7 @@ impl GuestCrashDevice { let DumpState::OpeningFile { recv } = state else { unreachable!() }; - let status = match recv.await.flatten() { + let status = match recv.await { Ok(file) => { *state = DumpState::Writing { file, diff --git a/vm/devices/get/guest_emulation_device/src/lib.rs b/vm/devices/get/guest_emulation_device/src/lib.rs index 827bbe6c4..6a2a8d530 100644 --- a/vm/devices/get/guest_emulation_device/src/lib.rs +++ b/vm/devices/get/guest_emulation_device/src/lib.rs @@ -189,7 +189,7 @@ pub struct GuestEmulationDevice { #[inspect(skip)] guest_request_recv: mesh::Receiver, #[inspect(skip)] - waiting_for_vtl0_start: Vec>>, + waiting_for_vtl0_start: Vec>>, vmgs: Option, @@ -311,7 +311,7 @@ pub struct GedChannel { #[inspect(with = "Option::is_some")] vtl0_start_report: Option>, #[inspect(with = "Option::is_some")] - modify: Option>>, + modify: Option>>, // TODO: allow unused temporarily as a follow up change will use it to // implement AK cert renewal. #[inspect(skip)] @@ -320,7 +320,7 @@ pub struct GedChannel { } struct InProgressSave { - response: mesh::OneshotSender>, + rpc: Rpc<(), Result<(), SaveRestoreError>>, buffer: Vec, } @@ -491,22 +491,23 @@ impl GedChannel { ) -> Result<(), Error> { match guest_request { GuestEmulationRequest::WaitForConnect(rpc) => rpc.handle_sync(|()| ()), - GuestEmulationRequest::WaitForVtl0Start(Rpc((), response)) => { + GuestEmulationRequest::WaitForVtl0Start(rpc) => { if let Some(result) = self.vtl0_start_report.clone() { - response.send(result); + rpc.complete(result); } else { - state.waiting_for_vtl0_start.push(response); + state.waiting_for_vtl0_start.push(rpc); } } - GuestEmulationRequest::ModifyVtl2Settings(Rpc(data, response)) => { + GuestEmulationRequest::ModifyVtl2Settings(rpc) => { + let (data, response) = rpc.split(); if self.modify.is_some() { - response.send(Err(ModifyVtl2SettingsError::OperationInProgress)); + response.complete(Err(ModifyVtl2SettingsError::OperationInProgress)); return Ok(()); } // TODO: support larger payloads. if data.len() > MAX_PAYLOAD_SIZE { - response.send(Err(ModifyVtl2SettingsError::LargeSettingsNotSupported)); + response.complete(Err(ModifyVtl2SettingsError::LargeSettingsNotSupported)); return Ok(()); } @@ -524,7 +525,7 @@ impl GedChannel { self.modify = Some(response); } - GuestEmulationRequest::SaveGuestVtl2State(Rpc((), response)) => { + GuestEmulationRequest::SaveGuestVtl2State(rpc) => { let r = (|| { if self.save.is_some() { return Err(SaveRestoreError::OperationInProgress); @@ -551,11 +552,11 @@ impl GedChannel { match r { Ok(()) => { self.save = Some(InProgressSave { - response, + rpc, buffer: Vec::new(), }) } - Err(err) => response.send(Err(err)), + Err(err) => rpc.complete(Err(err)), } } }; @@ -903,7 +904,7 @@ impl GedChannel { if r.is_ok() { state.save_restore_buf = Some(save.buffer); } - save.response.send(r); + save.rpc.complete(r); } Ok(()) } @@ -1121,7 +1122,7 @@ impl GedChannel { _ => return Err(Error::InvalidFieldValue), }; for response in state.waiting_for_vtl0_start.drain(..) { - response.send(result.clone()); + response.complete(result.clone()); } self.vtl0_start_report = Some(result); Ok(()) @@ -1177,7 +1178,7 @@ impl GedChannel { } _ => return Err(Error::InvalidFieldValue), }; - modify.send(r); + modify.complete(r); Ok(()) } diff --git a/vm/devices/get/guest_emulation_transport/src/process_loop.rs b/vm/devices/get/guest_emulation_transport/src/process_loop.rs index 22b69c19e..72b8c3b2e 100644 --- a/vm/devices/get/guest_emulation_transport/src/process_loop.rs +++ b/vm/devices/get/guest_emulation_transport/src/process_loop.rs @@ -19,6 +19,9 @@ use guid::Guid; use inspect::Inspect; use inspect::InspectMut; use inspect_counters::Counter; +use mesh::rpc::Rpc; +use mesh::rpc::RpcError; +use mesh::rpc::TryRpcSend; use mesh::RecvError; use parking_lot::Mutex; use std::cmp::min; @@ -398,6 +401,15 @@ impl BufferedSender { } } +impl TryRpcSend for &mut BufferedSender { + type Message = T; + type Error = BufferedSenderFull; + + fn try_send_rpc(self, message: Self::Message) -> Result<(), Self::Error> { + self.send(message) + } +} + /// A variant of `Option>` for late-bound guest notification /// consumers that buffers a fixed-number of messages during the window between /// GET init and worker startup. @@ -508,7 +520,7 @@ struct GuestNotificationListeners { // pair of notifications that don't really act like "notifications" for the // foreseeable future... enum GuestNotificationResponse { - ModifyVtl2Settings(Result>, RecvError>), + ModifyVtl2Settings(Result<(), RpcError>>), } #[derive(Default, Inspect)] @@ -553,7 +565,7 @@ struct PipeChannels { enum WriteRequest { Message(Vec), - Flush(mesh::OneshotSender<()>), + Flush(Rpc<(), ()>), } impl HostRequestPipeAccess { @@ -780,7 +792,7 @@ impl ProcessLoop { } match self.write_recv.recv().await.unwrap() { WriteRequest::Message(message) => outgoing = message, - WriteRequest::Flush(send) => send.send(()), + WriteRequest::Flush(send) => send.complete(()), } } } @@ -953,7 +965,7 @@ impl ProcessLoop { /// for its response. fn push_basic_host_request_handler( &mut self, - req: mesh::rpc::Rpc, + req: Rpc, f: impl 'static + Send + FnOnce(I) -> Req, ) where Req: AsBytes + 'static + Send + Sync, @@ -973,10 +985,10 @@ impl ProcessLoop { match message { // GET infrastructure - not part of the GET protocol itself. // No direct interaction with the host. - Msg::FlushWrites(mesh::rpc::Rpc((), response)) => { + Msg::FlushWrites(rpc) => { self.pipe_channels .message_send - .send(WriteRequest::Flush(response)); + .send(WriteRequest::Flush(rpc)); } Msg::Inspect(req) => { req.inspect(self); @@ -1132,9 +1144,10 @@ impl ProcessLoop { Msg::SendServicingState(req) => self.push_host_request_handler(move |access| { req.handle_must_succeed(|data| request_send_servicing_state(access, data)) }), - Msg::CompleteStartVtl0(mesh::rpc::Rpc(input, res)) => { + Msg::CompleteStartVtl0(rpc) => { + let (input, res) = rpc.split(); self.complete_start_vtl0(input)?; - res.send(()); + res.complete(()); } // Host Notifications (don't require a response) @@ -1368,18 +1381,14 @@ impl ProcessLoop { vtl2_settings_buf: Vec, kind: get_protocol::GuestNotifications, ) -> Result<(), FatalError> { - let (result_send, result_recv) = mesh::oneshot(); - - let req = ModifyVtl2SettingsRequest(mesh::rpc::Rpc(vtl2_settings_buf, result_send)); - let res = result_recv + let res = self + .guest_notification_listeners + .vtl2_settings + .try_call_failable(ModifyVtl2SettingsRequest, vtl2_settings_buf) + .map_err(|_| FatalError::TooManyGuestNotifications(kind))? .map(GuestNotificationResponse::ModifyVtl2Settings) .boxed(); - self.guest_notification_listeners - .vtl2_settings - .send(req) - .map_err(|_| FatalError::TooManyGuestNotifications(kind))?; - self.guest_notification_responses.push(res); Ok(()) } @@ -1439,13 +1448,14 @@ impl ProcessLoop { fn complete_modify_vtl2_settings( &mut self, - result: Result>, RecvError>, + result: Result<(), RpcError>>, ) -> Result<(), FatalError> { - let errors = result.unwrap_or_else(|err| { - Err(vec![Vtl2SettingsErrorInfo::new( + let errors = result.map_err(|err| match err { + RpcError::Call(err) => err, + RpcError::Channel(err) => vec![Vtl2SettingsErrorInfo::new( underhill_config::Vtl2SettingsErrorCode::InternalFailure, err.to_string(), - )]) + )], }); let (status, errors_json) = match errors { diff --git a/vm/devices/hyperv_ic/src/shutdown.rs b/vm/devices/hyperv_ic/src/shutdown.rs index 80e4002b3..43b91f2bf 100644 --- a/vm/devices/hyperv_ic/src/shutdown.rs +++ b/vm/devices/hyperv_ic/src/shutdown.rs @@ -53,7 +53,7 @@ pub struct ShutdownChannel { pipe: MessagePipe, state: ChannelState, #[inspect(with = "Option::is_some")] - pending_shutdown: Option>, + pending_shutdown: Option>, } #[derive(Inspect)] @@ -153,8 +153,9 @@ impl ShutdownChannel { } ChannelState::Ready { ref mut state, .. } => match state { ReadyState::Ready => { - self.pending_shutdown = Some(rpc.1); - *state = ReadyState::SendShutdown(rpc.0); + let (input, rpc) = rpc.split(); + self.pending_shutdown = Some(rpc); + *state = ReadyState::SendShutdown(input); } ReadyState::SendShutdown { .. } | ReadyState::WaitShutdown => { rpc.complete(ShutdownResult::AlreadyInProgress) @@ -278,7 +279,7 @@ impl ShutdownChannel { ShutdownResult::Failed(status) }; if let Some(send) = self.pending_shutdown.take() { - send.send(result); + send.complete(result); } *state = ReadyState::Ready; } diff --git a/vm/devices/net/net_consomme/consomme/src/lib.rs b/vm/devices/net/net_consomme/consomme/src/lib.rs index 1a67433a2..29e7f410d 100644 --- a/vm/devices/net/net_consomme/consomme/src/lib.rs +++ b/vm/devices/net/net_consomme/consomme/src/lib.rs @@ -27,6 +27,7 @@ mod windows; use inspect::InspectMut; use mesh::rpc::Rpc; +use mesh::rpc::RpcError; use mesh::rpc::RpcSend; use pal_async::driver::Driver; use smoltcp::phy::Checksum; @@ -51,7 +52,7 @@ use thiserror::Error; pub enum ConsommeMessageError { /// Communication error with running instance. #[error("communication error")] - Mesh(mesh::RecvError), + Mesh(RpcError), /// Error executing request on current network instance. #[error("network err")] Network(DropReason), diff --git a/vm/devices/net/net_packet_capture/src/lib.rs b/vm/devices/net/net_packet_capture/src/lib.rs index b7dbb0ce1..ef94c9bb8 100644 --- a/vm/devices/net/net_packet_capture/src/lib.rs +++ b/vm/devices/net/net_packet_capture/src/lib.rs @@ -310,7 +310,7 @@ impl Endpoint for PacketCaptureEndpoint { Message::PacketCaptureEndpointCommand( PacketCaptureEndpointCommand::PacketCapture(rpc), ) => { - let options = rpc.0; + let (options, response) = rpc.split(); let result = async { let id = &self.id; let start = match options.operation { @@ -341,7 +341,7 @@ impl Endpoint for PacketCaptureEndpoint { Err(e) => (Err(e), false), Ok(value) => (Ok(()), value), }; - rpc.1.send(result.map_err(RemoteError::new)); + response.complete(result.map_err(RemoteError::new)); if restart_required { break EndpointAction::RestartRequired; } diff --git a/vm/devices/net/netvsp/src/test.rs b/vm/devices/net/netvsp/src/test.rs index 802b19c78..252f8413e 100644 --- a/vm/devices/net/netvsp/src/test.rs +++ b/vm/devices/net/netvsp/src/test.rs @@ -22,8 +22,8 @@ use hvdef::hypercall::HvGuestOsId; use hvdef::hypercall::HvGuestOsMicrosoft; use hvdef::hypercall::HvGuestOsMicrosoftIds; use mesh::rpc::Rpc; +use mesh::rpc::RpcError; use mesh::rpc::RpcSend; -use mesh::RecvError; use net_backend::null::NullEndpoint; use net_backend::DisconnectableEndpoint; use net_backend::Endpoint; @@ -393,12 +393,8 @@ impl TestNicDevice { req: impl FnOnce(Rpc) -> ChannelRequest, input: I, f: impl 'static + Send + FnOnce(R) -> ChannelResponse, - ) -> Result { - let (response, recv) = mesh::oneshot(); - self.offer_input - .request_send - .send((req)(Rpc(input, response))); - recv.await.map(f) + ) -> Result { + self.offer_input.request_send.call(req, input).await.map(f) } async fn connect_vmbus_channel(&mut self) -> TestNicChannel<'_> { @@ -1090,9 +1086,7 @@ impl TestVirtualFunctionState { pub async fn set_ready(&self, is_ready: bool) { let ready_callback = self.oneshot_ready_callback.lock().take(); if let Some(ready_callback) = ready_callback { - let (result_send, result_recv) = mesh::oneshot(); - ready_callback.send(Rpc(is_ready, result_send)); - result_recv.await.unwrap(); + ready_callback.call(|x| x, is_ready).await.unwrap(); } *self.is_ready.0.lock() = Some(is_ready); self.is_ready.1.notify(usize::MAX); diff --git a/vm/devices/storage/disk_nvme/nvme_driver/src/queue_pair.rs b/vm/devices/storage/disk_nvme/nvme_driver/src/queue_pair.rs index 07233e987..44d8f355a 100644 --- a/vm/devices/storage/disk_nvme/nvme_driver/src/queue_pair.rs +++ b/vm/devices/storage/disk_nvme/nvme_driver/src/queue_pair.rs @@ -22,6 +22,7 @@ use guestmem::GuestMemoryError; use inspect::Inspect; use inspect_counters::Counter; use mesh::rpc::Rpc; +use mesh::rpc::RpcError; use mesh::rpc::RpcSend; use mesh::Cancel; use mesh::CancelContext; @@ -92,11 +93,7 @@ impl PendingCommands { } /// Inserts a command into the pending list, updating it with a new CID. - fn insert( - &mut self, - command: &mut spec::Command, - respond: mesh::OneshotSender, - ) { + fn insert(&mut self, command: &mut spec::Command, respond: Rpc<(), spec::Completion>) { let entry = self.commands.vacant_entry(); assert!(entry.key() < Self::MAX_CIDS); assert_eq!(self.next_cid_high_bits % Self::CID_SEQ_OFFSET, Wrapping(0)); @@ -109,7 +106,7 @@ impl PendingCommands { }); } - fn remove(&mut self, cid: u16) -> mesh::OneshotSender { + fn remove(&mut self, cid: u16) -> Rpc<(), spec::Completion> { let command = self .commands .try_remove((cid & Self::CID_KEY_MASK) as usize) @@ -152,7 +149,6 @@ impl PendingCommands { commands: commands .iter() .map(|state| { - let (send, mut _recv) = mesh::oneshot::(); // To correctly restore Slab we need both the command index, // inherited from command's CID, and the command itself. ( @@ -160,7 +156,7 @@ impl PendingCommands { (state.command.cdw0.cid() & Self::CID_KEY_MASK) as usize, PendingCommand { command: state.command, - respond: send, + respond: Rpc::detached(()), }, ) }) @@ -336,7 +332,7 @@ impl QueuePair { #[allow(missing_docs)] pub enum RequestError { #[error("queue pair is gone")] - Gone(#[source] mesh::RecvError), + Gone(#[source] RpcError), #[error("nvme error")] Nvme(#[source] NvmeError), #[error("memory error")] @@ -579,7 +575,7 @@ struct PendingCommand { // Keep the command around for diagnostics. command: spec::Command, #[inspect(skip)] - respond: mesh::OneshotSender, + respond: Rpc<(), spec::Completion>, } enum Req { @@ -659,7 +655,8 @@ impl QueueHandler { match event { Event::Request(req) => match req { - Req::Command(Rpc(mut command, respond)) => { + Req::Command(rpc) => { + let (mut command, respond) = rpc.split(); self.commands.insert(&mut command, respond); self.sq.write(command).unwrap(); self.stats.issued.increment(); @@ -679,7 +676,7 @@ impl QueueHandler { self.drain_after_restore = false; } self.sq.update_head(completion.sqhd); - respond.send(completion); + respond.complete(completion); self.stats.completed.increment(); } } diff --git a/vm/devices/storage/nvme/src/workers/coordinator.rs b/vm/devices/storage/nvme/src/workers/coordinator.rs index fb1c1b402..2ddb1ea15 100644 --- a/vm/devices/storage/nvme/src/workers/coordinator.rs +++ b/vm/devices/storage/nvme/src/workers/coordinator.rs @@ -17,6 +17,7 @@ use guestmem::GuestMemory; use guid::Guid; use inspect::Inspect; use inspect::InspectMut; +use mesh::rpc::PendingRpc; use mesh::rpc::Rpc; use mesh::rpc::RpcSend; use pal_async::task::Spawn; @@ -39,9 +40,9 @@ pub struct NvmeWorkers { #[derive(Debug)] enum EnableState { Disabled, - Enabling(mesh::OneshotReceiver<()>), + Enabling(PendingRpc<()>), Enabled, - Resetting(mesh::OneshotReceiver<()>), + Resetting(PendingRpc<()>), } impl InspectMut for NvmeWorkers { diff --git a/vm/devices/vmbus/vmbus_channel/src/channel.rs b/vm/devices/vmbus/vmbus_channel/src/channel.rs index 1bbb99cb4..a3f67f289 100644 --- a/vm/devices/vmbus/vmbus_channel/src/channel.rs +++ b/vm/devices/vmbus/vmbus_channel/src/channel.rs @@ -427,9 +427,6 @@ pub enum ChannelRestoreError { /// Failed to enable subchannels. #[error("failed to enable subchannels")] EnablingSubchannels(#[source] anyhow::Error), - /// Failed to send restore request. - #[error("failed to send restore request")] - SendingRequest(#[source] RecvError), /// Failed to restore vmbus channel. #[error("failed to restore vmbus channel")] RestoreError(#[source] anyhow::Error), @@ -570,8 +567,8 @@ impl Device { self.handle_gpadl(gpadl.id, gpadl.count, gpadl.buf, channel_idx); true }), - ChannelRequest::TeardownGpadl(Rpc(id, response_send)) => { - self.handle_teardown_gpadl(id, response_send, channel_idx); + ChannelRequest::TeardownGpadl(rpc) => { + self.handle_teardown_gpadl(rpc, channel_idx); } ChannelRequest::Modify(rpc) => { rpc.handle(|req| async { @@ -644,16 +641,12 @@ impl Device { } } - fn handle_teardown_gpadl( - &mut self, - id: GpadlId, - response_send: mesh::OneshotSender<()>, - channel_idx: usize, - ) { + fn handle_teardown_gpadl(&mut self, rpc: Rpc, channel_idx: usize) { + let id = *rpc.input(); if let Some(f) = self.gpadl_map.remove( id, Box::new(move || { - response_send.send(()); + rpc.complete(()); }), ) { f() @@ -787,9 +780,8 @@ impl Device { let mut results = Vec::with_capacity(states.len()); for (channel_idx, open) in states.iter().copied().enumerate() { let result = self.server_requests[channel_idx] - .call(ChannelServerRequest::Restore, open) + .call_failable(ChannelServerRequest::Restore, open) .await - .map_err(ChannelRestoreError::SendingRequest)? .map_err(|err| ChannelRestoreError::RestoreError(err.into()))?; assert!(open == result.open_request.is_some()); diff --git a/vm/devices/vmbus/vmbus_channel/src/offer.rs b/vm/devices/vmbus/vmbus_channel/src/offer.rs index df8340b38..7241ccfb0 100644 --- a/vm/devices/vmbus/vmbus_channel/src/offer.rs +++ b/vm/devices/vmbus/vmbus_channel/src/offer.rs @@ -104,7 +104,8 @@ impl Offer { let mut open_done = None; while let Ok(request) = request_recv.recv().await { match request { - ChannelRequest::Open(Rpc(open_request, response_send)) => { + ChannelRequest::Open(rpc) => { + let (open_request, response_send) = rpc.split(); let done = Arc::new(AtomicBool::new(false)); send.send(OpenMessage { open_request, @@ -113,7 +114,8 @@ impl Offer { }); open_done = Some(done); } - ChannelRequest::Close(Rpc((), _response_send)) => { + ChannelRequest::Close(rpc) => { + let _response_send = rpc; // TODO: figure out if we should really just drop this here. open_done .take() .expect("channel must be open") @@ -137,11 +139,12 @@ impl Offer { } }) } - ChannelRequest::TeardownGpadl(Rpc(id, response_send)) => { + ChannelRequest::TeardownGpadl(rpc) => { + let (id, response_send) = rpc.split(); if let Some(f) = gpadls.remove( id, Box::new(move || { - response_send.send(()); + response_send.complete(()); }), ) { f(); @@ -217,18 +220,18 @@ struct OpenMessage { response: OpenResponse, } -struct OpenResponse(Option>); +struct OpenResponse(Option>); impl OpenResponse { fn respond(mut self, open: bool) { - self.0.take().unwrap().send(open) + self.0.take().unwrap().complete(open) } } impl Drop for OpenResponse { fn drop(&mut self) { - if let Some(send) = self.0.take() { - send.send(false); + if let Some(rpc) = self.0.take() { + rpc.complete(false); } } } diff --git a/vm/devices/vmbus/vmbus_client/src/lib.rs b/vm/devices/vmbus/vmbus_client/src/lib.rs index c9d2e7d36..f51118052 100644 --- a/vm/devices/vmbus/vmbus_client/src/lib.rs +++ b/vm/devices/vmbus/vmbus_client/src/lib.rs @@ -383,7 +383,7 @@ enum ChannelState { /// The channel has been offered to the client. Offered, /// The channel has requested the server to be opened. - Opening(mesh::OneshotSender), + Opening(Rpc<(), bool>), /// The channel has been successfully opened. Opened, } @@ -402,7 +402,7 @@ struct Channel { offer: protocol::OfferChannel, response_send: mesh::Sender, state: ChannelState, - modify_response_send: Option>, + modify_response_send: Option>, } impl std::fmt::Debug for Channel { @@ -491,7 +491,7 @@ impl ClientTask { FeatureFlags::new() }; - let request = &rpc.0; + let request = rpc.input(); tracing::debug!(version = ?version, ?feature_flags, "VmBus client connecting"); let target_info = protocol::TargetInfo::new(SINT, VTL, feature_flags); @@ -552,7 +552,7 @@ impl ClientTask { return; } - let message = protocol::ModifyConnection::from(request.0); + let message = protocol::ModifyConnection::from(*request.input()); self.modify_request = Some(request); self.inner.send(&message); } @@ -781,7 +781,7 @@ impl ClientTask { unreachable!("validated above"); }; - sender.send(gpadl_created) + sender.complete(gpadl_created) } fn handle_open_result(&mut self, result: protocol::OpenResult) { @@ -811,7 +811,7 @@ impl ClientTask { return; }; - rpc.send(channel_opened); + rpc.complete(channel_opened); } fn handle_gpadl_torndown(&mut self, request: protocol::GpadlTorndown) { @@ -893,7 +893,7 @@ impl ClientTask { return; }; - sender.send(response.status); + sender.complete(response.status); } fn handle_tl_connect_result(&mut self, response: protocol::TlConnectResult) { @@ -981,7 +981,7 @@ impl ClientTask { } tracing::info!(channel_id = channel_id.0, "opening channel on host"); - let request = &rpc.0; + let request = rpc.input(); let open_data = &request.open_data; let open_channel = protocol::OpenChannel { @@ -1014,15 +1014,16 @@ impl ClientTask { self.inner.send(&open_channel); } - self.inner.channels.get_mut(&channel_id).unwrap().state = ChannelState::Opening(rpc.1); + self.inner.channels.get_mut(&channel_id).unwrap().state = + ChannelState::Opening(rpc.split().1); } fn handle_gpadl(&mut self, channel_id: ChannelId, rpc: Rpc) { - let request = &rpc.0; + let (request, rpc) = rpc.split(); if self .inner .gpadls - .insert((channel_id, request.id), GpadlState::Offered(rpc.1)) + .insert((channel_id, request.id), GpadlState::Offered(rpc)) .is_some() { panic!( @@ -1129,12 +1130,12 @@ impl ClientTask { panic!("duplicate channel modify request {channel_id:?}"); } - channel.modify_response_send = Some(rpc.1); - let request = &rpc.0; + let (request, response) = rpc.split(); + channel.modify_response_send = Some(response); let payload = match request { ModifyRequest::TargetVp { target_vp } => protocol::ModifyChannel { channel_id, - target_vp: *target_vp, + target_vp, }, }; @@ -1311,7 +1312,7 @@ impl Inspect for ClientTask { #[derive(Debug)] enum GpadlState { /// GpadlHeader has been sent to the host. - Offered(mesh::OneshotSender), + Offered(Rpc<(), bool>), /// Host has responded with GpadlCreated. Created, /// GpadlTeardown message has been sent to the host. @@ -1795,8 +1796,8 @@ mod tests { let (server, mut client, _) = test_init(); let channel = server.get_channel(&mut client).await; - let (send, recv) = mesh::oneshot(); - channel.request_send.send(ChannelRequest::Open(Rpc( + let recv = channel.request_send.call( + ChannelRequest::Open, OpenRequest { open_data: OpenData { target_vp: 0, @@ -1808,8 +1809,7 @@ mod tests { }, flags: OpenChannelFlags::new(), }, - send, - ))); + ); assert_eq!( server.next().unwrap(), @@ -1846,8 +1846,8 @@ mod tests { let (server, mut client, _) = test_init(); let channel = server.get_channel(&mut client).await; - let (send, recv) = mesh::oneshot(); - channel.request_send.send(ChannelRequest::Open(Rpc( + let recv = channel.request_send.call( + ChannelRequest::Open, OpenRequest { open_data: OpenData { target_vp: 0, @@ -1859,8 +1859,7 @@ mod tests { }, flags: OpenChannelFlags::new(), }, - send, - ))); + ); assert_eq!( server.next().unwrap(), @@ -1899,11 +1898,10 @@ mod tests { // N.B. A real server requires the channel to be open before sending this, but the test // server doesn't care. - let (send, recv) = mesh::oneshot(); - channel.request_send.send(ChannelRequest::Modify(Rpc( + let recv = channel.request_send.call( + ChannelRequest::Modify, ModifyRequest::TargetVp { target_vp: 1 }, - send, - ))); + ); assert_eq!( server.next().unwrap(), @@ -2016,15 +2014,14 @@ mod tests { async fn test_gpadl_success() { let (server, mut client, _) = test_init(); let mut channel = server.get_channel(&mut client).await; - let (send, recv) = mesh::oneshot(); - channel.request_send.send(ChannelRequest::Gpadl(Rpc( + let recv = channel.request_send.call( + ChannelRequest::Gpadl, GpadlRequest { id: GpadlId(1), count: 1, buf: vec![5], }, - send, - ))); + ); assert_eq!( server.next().unwrap(), @@ -2079,15 +2076,14 @@ mod tests { async fn test_gpadl_fail() { let (server, mut client, _) = test_init(); let channel = server.get_channel(&mut client).await; - let (send, recv) = mesh::oneshot(); - channel.request_send.send(ChannelRequest::Gpadl(Rpc( + let recv = channel.request_send.call( + ChannelRequest::Gpadl, GpadlRequest { id: GpadlId(1), count: 1, buf: vec![7], }, - send, - ))); + ); assert_eq!( server.next().unwrap(), @@ -2121,15 +2117,14 @@ mod tests { let mut channel = server.get_channel(&mut client).await; let channel_id = ChannelId(0); let gpadl_id = GpadlId(1); - let (send, recv) = mesh::oneshot(); - channel.request_send.send(ChannelRequest::Gpadl(Rpc( + let recv = channel.request_send.call( + ChannelRequest::Gpadl, GpadlRequest { id: gpadl_id, count: 1, buf: vec![3], }, - send, - ))); + ); assert_eq!( server.next().unwrap(), diff --git a/vm/devices/vmbus/vmbus_relay/src/lib.rs b/vm/devices/vmbus/vmbus_relay/src/lib.rs index c8b26f79f..82ac48152 100644 --- a/vm/devices/vmbus/vmbus_relay/src/lib.rs +++ b/vm/devices/vmbus/vmbus_relay/src/lib.rs @@ -352,7 +352,7 @@ struct RelayChannel { /// State used to relay host-to-guest interrupts. interrupt_relay: Option, /// RPCs for gpadls that are waiting for a torndown message. - gpadls_tearing_down: HashMap>, + gpadls_tearing_down: HashMap>, } struct RelayChannelTask { @@ -434,7 +434,7 @@ impl RelayChannelTask { } fn handle_gpadl_teardown(&mut self, rpc: Rpc) { - let gpadl_id = rpc.0; + let (gpadl_id, rpc) = rpc.split(); tracing::trace!(gpadl_id = gpadl_id.0, "Tearing down GPADL"); let _ = &self diff --git a/vm/devices/vmbus/vmbus_server/src/lib.rs b/vm/devices/vmbus/vmbus_server/src/lib.rs index 4e9a7ce71..74828c496 100644 --- a/vm/devices/vmbus/vmbus_server/src/lib.rs +++ b/vm/devices/vmbus/vmbus_server/src/lib.rs @@ -33,13 +33,11 @@ use futures::StreamExt; use guestmem::GuestMemory; use hvdef::Vtl; use inspect::Inspect; -use mesh::error::RemoteError; -use mesh::error::RemoteResult; -use mesh::error::RemoteResultExt; use mesh::payload::Protobuf; +use mesh::rpc::FailableRpc; use mesh::rpc::Rpc; +use mesh::rpc::RpcError; use mesh::rpc::RpcSend; -use mesh::RecvError; use pal_async::task::Spawn; use pal_async::task::Task; use pal_event::Event; @@ -207,7 +205,7 @@ pub struct OfferInfo { #[derive(mesh::MeshPayload)] pub enum OfferRequest { - Offer(OfferInfo, mesh::OneshotSender>), + Offer(FailableRpc), } impl Inspect for VmbusServer { @@ -623,13 +621,13 @@ struct ServerTaskInner { hvsock_send: mesh::Sender, channels: HashMap, channel_responses: FuturesUnordered< - Pin)>>>, + Pin)>>>, >, external_server_send: Option>, relay_send: mesh::Sender, channel_bitmap: Option>, shared_event_port: Option>, - reset_done: Option>, + reset_done: Option>, enable_mnf: bool, } @@ -722,7 +720,7 @@ impl ServerTask { &mut self, offer_id: OfferId, seq: u64, - response: Result, + response: Result, ) { // Validate the sequence to ensure the response is not for a revoked channel. let channel = self @@ -847,9 +845,9 @@ impl ServerTask { fn handle_request(&mut self, request: VmbusRequest) { tracing::debug!(?request, "handle_request"); match request { - VmbusRequest::Reset(Rpc((), done)) => { + VmbusRequest::Reset(rpc) => { assert!(self.inner.reset_done.is_none()); - self.inner.reset_done = Some(done); + self.inner.reset_done = Some(rpc); self.server.with_notifier(&mut self.inner).reset(); // TODO: clear pending messages and other requests. } @@ -1000,8 +998,8 @@ impl ServerTask { } r = self.offer_recv.select_next_some() => { match r { - OfferRequest::Offer(request, response) => { - response.send(self.handle_offer(request).map_err(RemoteError::new)) + OfferRequest::Offer(rpc) => { + rpc.handle_failable_sync(|request| { self.handle_offer(request) }) }, } } @@ -1156,10 +1154,9 @@ impl channels::Notifier for ServerTaskInner { req: impl FnOnce(Rpc) -> ChannelRequest, input: I, f: impl 'static + Send + FnOnce(R) -> ChannelResponse, - ) -> Pin)>>> + ) -> Pin)>>> { - let (response, recv) = mesh::oneshot(); - channel.send.send((req)(Rpc(input, response))); + let recv = channel.send.call(req, input); let seq = channel.seq; Box::pin(async move { let r = recv.await.map(f); @@ -1337,7 +1334,7 @@ impl channels::Notifier for ServerTaskInner { } let done = self.reset_done.take().expect("must have requested reset"); - done.send(()); + done.complete(()); } } @@ -1503,9 +1500,9 @@ impl VmbusServerControl { /// This is used by the relay to forward the host's parameters. pub async fn offer_core(&self, offer_info: OfferInfo) -> anyhow::Result { let flags = offer_info.params.flags; - let (send, recv) = mesh::oneshot(); - self.send.send(OfferRequest::Offer(offer_info, send)); - recv.await.flatten()?; + self.send + .call_failable(OfferRequest::Offer, offer_info) + .await?; Ok(OfferResources::new( self.mem.clone(), if flags.confidential_ring_buffer() || flags.confidential_external_memory() { @@ -1713,7 +1710,7 @@ mod tests { panic!("Wrong request"); }; - f(&rpc.0); + f(rpc.input()); rpc.complete(true); } diff --git a/vm/devices/vmbus/vmbus_server/src/proxyintegration.rs b/vm/devices/vmbus/vmbus_server/src/proxyintegration.rs index 3d0643cc5..e1c1997a4 100644 --- a/vm/devices/vmbus/vmbus_server/src/proxyintegration.rs +++ b/vm/devices/vmbus/vmbus_server/src/proxyintegration.rs @@ -17,8 +17,7 @@ use anyhow::Context; use futures::stream::SelectAll; use futures::StreamExt; use guestmem::GuestMemory; -use mesh::error::RemoteResultExt; -use mesh::rpc::Rpc; +use mesh::rpc::RpcSend; use mesh::Cancel; use mesh::CancelContext; use pal_async::driver::SpawnDriver; @@ -238,18 +237,17 @@ impl ProxyTask { }; let (request_send, request_recv) = mesh::channel(); let (server_request_send, server_request_recv) = mesh::channel(); - let (send, recv) = mesh::oneshot(); - self.server.send.send(OfferRequest::Offer( + let recv = self.server.send.call_failable( + OfferRequest::Offer, OfferInfo { params: offer.into(), event: Interrupt::from_event(incoming_event), request_send, server_request_recv, }, - send, - )); + ); - let (request_recv, server_request_send) = match recv.await.flatten() { + let (request_recv, server_request_send) = match recv.await { Ok(()) => (Some(request_recv), Some(server_request_send)), Err(err) => { // Currently there is no way to propagate this failure. @@ -356,7 +354,7 @@ impl ProxyTask { } // Modifying the target VP is handle by the server, there is nothing the proxy // driver needs to do. - ChannelRequest::Modify(Rpc(_, response)) => response.send(0), + ChannelRequest::Modify(rpc) => rpc.complete(0), } } None => { diff --git a/vm/vmgs/vmgs_broker/src/client.rs b/vm/vmgs/vmgs_broker/src/client.rs index ea713d08a..c0ca80a21 100644 --- a/vm/vmgs/vmgs_broker/src/client.rs +++ b/vm/vmgs/vmgs_broker/src/client.rs @@ -6,6 +6,7 @@ use crate::broker::VmgsBrokerRpc; use inspect::Inspect; +use mesh_channel::rpc::RpcError; use mesh_channel::rpc::RpcSend; use thiserror::Error; use tracing::instrument; @@ -18,12 +19,21 @@ use vmgs_format::FileId; pub enum VmgsClientError { /// VMGS broker is offline #[error("broker is offline")] - BrokerOffline(#[from] mesh_channel::RecvError), + BrokerOffline(#[from] RpcError), /// VMGS error #[error("vmgs error")] Vmgs(#[from] vmgs::Error), } +impl From> for VmgsClientError { + fn from(value: RpcError) -> Self { + match value { + RpcError::Call(e) => VmgsClientError::Vmgs(e), + RpcError::Channel(e) => VmgsClientError::BrokerOffline(RpcError::Channel(e)), + } + } +} + /// Client to interact with a backend-agnostic VMGS instance. #[derive(Clone)] pub struct VmgsClient { @@ -42,8 +52,8 @@ impl VmgsClient { pub async fn get_file_info(&self, file_id: FileId) -> Result { let res = self .control - .call(VmgsBrokerRpc::GetFileInfo, file_id) - .await??; + .call_failable(VmgsBrokerRpc::GetFileInfo, file_id) + .await?; Ok(res) } @@ -53,8 +63,8 @@ impl VmgsClient { pub async fn read_file(&self, file_id: FileId) -> Result, VmgsClientError> { let res = self .control - .call(VmgsBrokerRpc::ReadFile, file_id) - .await??; + .call_failable(VmgsBrokerRpc::ReadFile, file_id) + .await?; Ok(res) } @@ -66,8 +76,8 @@ impl VmgsClient { #[instrument(skip_all, fields(file_id))] pub async fn write_file(&self, file_id: FileId, buf: Vec) -> Result<(), VmgsClientError> { self.control - .call(VmgsBrokerRpc::WriteFile, (file_id, buf)) - .await??; + .call_failable(VmgsBrokerRpc::WriteFile, (file_id, buf)) + .await?; Ok(()) } @@ -83,8 +93,8 @@ impl VmgsClient { buf: Vec, ) -> Result<(), VmgsClientError> { self.control - .call(VmgsBrokerRpc::WriteFileEncrypted, (file_id, buf)) - .await??; + .call_failable(VmgsBrokerRpc::WriteFileEncrypted, (file_id, buf)) + .await?; Ok(()) } diff --git a/vmm_core/src/partition_unit/vp_set.rs b/vmm_core/src/partition_unit/vp_set.rs index db171a0de..92529683b 100644 --- a/vmm_core/src/partition_unit/vp_set.rs +++ b/vmm_core/src/partition_unit/vp_set.rs @@ -20,6 +20,7 @@ use guestmem::GuestMemory; use hvdef::Vtl; use inspect::Inspect; use mesh::rpc::Rpc; +use mesh::rpc::RpcError; use mesh::rpc::RpcSend; use pal_async::local::block_with_io; use parking_lot::Mutex; @@ -890,7 +891,7 @@ pub struct RegisterSetError(&'static str, #[source] anyhow::Error); #[derive(Debug, Error)] #[error("the vp runner was dropped")] -struct RunnerGoneError(#[source] mesh::RecvError); +struct RunnerGoneError(#[source] RpcError); #[cfg(feature = "gdb")] impl VpSet { diff --git a/vmm_core/state_unit/src/lib.rs b/vmm_core/state_unit/src/lib.rs index db0ad3fad..b1a3561b3 100644 --- a/vmm_core/state_unit/src/lib.rs +++ b/vmm_core/state_unit/src/lib.rs @@ -36,10 +36,11 @@ use futures::StreamExt; use futures_concurrency::stream::Merge; use inspect::Inspect; use inspect::InspectMut; -use mesh::oneshot; use mesh::payload::Protobuf; use mesh::rpc::FailableRpc; use mesh::rpc::Rpc; +use mesh::rpc::RpcError; +use mesh::rpc::RpcSend; use mesh::MeshPayload; use mesh::Receiver; use mesh::Sender; @@ -257,7 +258,7 @@ pub struct NameInUse(Arc); struct UnitRecvError { name: Arc, #[source] - source: mesh::RecvError, + source: RpcError, } #[derive(Debug, Clone)] @@ -893,7 +894,6 @@ fn state_change( request: impl FnOnce(Rpc) -> StateRequest, input: Option, ) -> impl Future, UnitRecvError>> { - let (response_send, response_recv) = oneshot(); let send = unit.send.clone(); async move { @@ -901,8 +901,8 @@ fn state_change( let span = tracing::info_span!("device_state_change", device = name.as_ref()); async move { let start = Instant::now(); - send.send((request)(Rpc(input, response_send))); - let r = response_recv + let r = send + .call(request, input) .await .map_err(|err| UnitRecvError { name, source: err }); tracing::debug!(duration = ?Instant::now() - start, "device state change complete"); diff --git a/workers/debug_worker/src/gdb/mod.rs b/workers/debug_worker/src/gdb/mod.rs index 51801d8e0..99b54ad7b 100644 --- a/workers/debug_worker/src/gdb/mod.rs +++ b/workers/debug_worker/src/gdb/mod.rs @@ -4,7 +4,6 @@ use anyhow::Context; use futures::executor::block_on; use gdbstub::common::Tid; -use mesh::error::RemoteResultExt; use mesh::rpc::RpcSend; use std::num::NonZeroUsize; use vmm_core_defs::debug_rpc::DebugRequest; @@ -66,11 +65,10 @@ impl VmProxy { #[allow(dead_code)] // TODO: add monitor command to inspect physical memory? fn read_guest_physical_memory(&mut self, gpa: u64, data: &mut [u8]) -> anyhow::Result<()> { - let buf = block_on(self.req_chan.call( + let buf = block_on(self.req_chan.call_failable( DebugRequest::ReadMemory, (GuestAddress::Gpa(gpa), data.len()), )) - .flatten() .context("failed to read memory")?; data.copy_from_slice( buf.get(..data.len()) @@ -86,11 +84,10 @@ impl VmProxy { gva: u64, data: &mut [u8], ) -> anyhow::Result<()> { - let buf = block_on(self.req_chan.call( + let buf = block_on(self.req_chan.call_failable( DebugRequest::ReadMemory, (GuestAddress::Gva { vp: vp_index, gva }, data.len()), )) - .flatten() .context("failed to read memory")?; data.copy_from_slice( buf.get(..data.len()) @@ -106,11 +103,10 @@ impl VmProxy { gva: u64, data: &[u8], ) -> anyhow::Result<()> { - block_on(self.req_chan.call( + block_on(self.req_chan.call_failable( DebugRequest::WriteMemory, (GuestAddress::Gva { vp: vp_index, gva }, data.to_vec()), )) - .flatten() .context("failed to write memory")?; Ok(()) } diff --git a/workers/debug_worker/src/gdb/targets/base.rs b/workers/debug_worker/src/gdb/targets/base.rs index c88862754..d0ec2d126 100644 --- a/workers/debug_worker/src/gdb/targets/base.rs +++ b/workers/debug_worker/src/gdb/targets/base.rs @@ -14,7 +14,6 @@ use gdbstub::target::ext::base::multithread::MultiThreadSingleStep; use gdbstub::target::ext::base::multithread::MultiThreadSingleStepOps; use gdbstub::target::ext::base::single_register_access::SingleRegisterAccess; use gdbstub::target::TargetResult; -use mesh::error::RemoteResultExt; use mesh::rpc::RpcSend; use vmm_core_defs::debug_rpc::DebugRequest; use vmm_core_defs::debug_rpc::DebugState; @@ -23,9 +22,12 @@ impl MultiThreadBase for VmTarget<'_, T> { fn read_registers(&mut self, regs: &mut T::Registers, tid: Tid) -> TargetResult<(), Self> { let vp_index = self.0.tid_to_vp(tid).fatal()?; - let state = block_on(self.0.req_chan.call(DebugRequest::GetVpState, vp_index)) - .flatten() - .nonfatal()?; + let state = block_on( + self.0 + .req_chan + .call_failable(DebugRequest::GetVpState, vp_index), + ) + .nonfatal()?; T::registers(&state, regs)?; Ok(()) @@ -34,18 +36,20 @@ impl MultiThreadBase for VmTarget<'_, T> { fn write_registers(&mut self, regs: &T::Registers, tid: Tid) -> TargetResult<(), Self> { let vp_index = self.0.tid_to_vp(tid).fatal()?; - let mut state = block_on(self.0.req_chan.call(DebugRequest::GetVpState, vp_index)) - .flatten() - .nonfatal()?; + let mut state = block_on( + self.0 + .req_chan + .call_failable(DebugRequest::GetVpState, vp_index), + ) + .nonfatal()?; T::update_registers(&mut state, regs)?; block_on( self.0 .req_chan - .call(DebugRequest::SetVpState, (vp_index, state)), + .call_failable(DebugRequest::SetVpState, (vp_index, state)), ) - .flatten() .nonfatal()?; Ok(()) @@ -109,9 +113,12 @@ impl SingleRegisterAccess for VmTarget<'_, T> { ) -> TargetResult { let vp_index = self.0.tid_to_vp(tid).fatal()?; - let state = block_on(self.0.req_chan.call(DebugRequest::GetVpState, vp_index)) - .flatten() - .nonfatal()?; + let state = block_on( + self.0 + .req_chan + .call_failable(DebugRequest::GetVpState, vp_index), + ) + .nonfatal()?; Ok(T::register(&state, reg_id, buf)?) } @@ -119,18 +126,20 @@ impl SingleRegisterAccess for VmTarget<'_, T> { fn write_register(&mut self, tid: Tid, reg_id: T::RegId, val: &[u8]) -> TargetResult<(), Self> { let vp_index = self.0.tid_to_vp(tid).fatal()?; - let mut state = block_on(self.0.req_chan.call(DebugRequest::GetVpState, vp_index)) - .flatten() - .nonfatal()?; + let mut state = block_on( + self.0 + .req_chan + .call_failable(DebugRequest::GetVpState, vp_index), + ) + .nonfatal()?; T::update_register(&mut state, reg_id, val)?; block_on( self.0 .req_chan - .call(DebugRequest::SetVpState, (vp_index, state)), + .call_failable(DebugRequest::SetVpState, (vp_index, state)), ) - .flatten() .nonfatal()?; Ok(()) diff --git a/workers/debug_worker/src/lib.rs b/workers/debug_worker/src/lib.rs index 6166e200b..b5671ccdf 100644 --- a/workers/debug_worker/src/lib.rs +++ b/workers/debug_worker/src/lib.rs @@ -121,7 +121,7 @@ where Ok(message) => match message { WorkerRpc::Stop => return Ok(()), WorkerRpc::Inspect(deferred) => deferred.inspect(&mut server), - WorkerRpc::Restart(response) => { + WorkerRpc::Restart(rpc) => { let vm_proxy = match server.state { State::Listening { vm_proxy } => vm_proxy, State::Connected { task, abort, .. } => { @@ -148,7 +148,7 @@ where }, } }; - response.send(Ok(state)); + rpc.complete(Ok(state)); return Ok(()); } }, diff --git a/workers/vnc_worker/src/lib.rs b/workers/vnc_worker/src/lib.rs index f636f3487..afe15107a 100644 --- a/workers/vnc_worker/src/lib.rs +++ b/workers/vnc_worker/src/lib.rs @@ -116,7 +116,7 @@ impl VncWorker { state: self.state, }; - let response = loop { + let rpc = loop { let r = futures::select! { // merge semantics r = rpc_recv.recv().fuse() => r, r = server.process(&driver).fuse() => break r.map(|_| None)?, @@ -130,7 +130,7 @@ impl VncWorker { Err(_) => break None, } }; - if let Some(response) = response { + if let Some(rpc) = rpc { let (view, input) = match server.state { State::Listening { view, input } => (view, input), State::Connected { task, abort, .. } => { @@ -144,7 +144,7 @@ impl VncWorker { framebuffer: view.0.access(), input_send: input.send, }; - response.send(Ok(state)); + rpc.complete(Ok(state)); } Ok(()) })