diff --git a/CHANGES.md b/CHANGES.md index 78a84fe..999ecb2 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -1,5 +1,11 @@ # Changes +## [4.4.0] - 2024-11-10 + +* Check service readiness once per decoded item + +* Run un-readiness check in separate task + ## [4.3.1] - 2024-11-05 * Do not rely on not_ready(), always check service readiness diff --git a/Cargo.toml b/Cargo.toml index b5b9172..1f929c6 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "ntex-mqtt" -version = "4.3.1" +version = "4.4.0" authors = ["ntex contributors "] description = "Client and Server framework for MQTT v5 and v3.1.1 protocols" documentation = "https://docs.rs/ntex-mqtt" @@ -18,10 +18,11 @@ features = ["ntex/tokio"] ntex-io = "2" ntex-net = "2" ntex-util = "2.5" -ntex-service = "3.3" +ntex-service = "3.3.3" ntex-bytes = "0.1" ntex-codec = "0.6" ntex-router = "0.5" +ntex-rt = "0.4" bitflags = "2" log = "0.4" pin-project-lite = "0.2" diff --git a/src/io.rs b/src/io.rs index 271779d..1cb3c66 100644 --- a/src/io.rs +++ b/src/io.rs @@ -1,13 +1,14 @@ //! Framed transport dispatcher +use std::future::{poll_fn, Future}; use std::task::{ready, Context, Poll}; -use std::{cell::RefCell, collections::VecDeque, future::Future, pin::Pin, rc::Rc}; +use std::{cell::Cell, cell::RefCell, collections::VecDeque, pin::Pin, rc::Rc}; use ntex_codec::{Decoder, Encoder}; use ntex_io::{ Decoded, DispatchItem, DispatcherConfig, IoBoxed, IoRef, IoStatusUpdate, RecvError, }; use ntex_service::{IntoService, Pipeline, PipelineBinding, PipelineCall, Service}; -use ntex_util::time::Seconds; +use ntex_util::{task::LocalWaker, time::Seconds}; type Response = ::Item; @@ -28,12 +29,13 @@ pin_project_lite::pin_project! { bitflags::bitflags! { #[derive(Copy, Clone, Eq, PartialEq, Debug)] struct Flags: u8 { - const READY_ERR = 0b000001; - const IO_ERR = 0b000010; - const KA_ENABLED = 0b000100; - const KA_TIMEOUT = 0b001000; - const READ_TIMEOUT = 0b010000; - const READY = 0b100000; + const READY_ERR = 0b0000001; + const IO_ERR = 0b0000010; + const KA_ENABLED = 0b0000100; + const KA_TIMEOUT = 0b0001000; + const READ_TIMEOUT = 0b0010000; + const READY = 0b0100000; + const READY_TASK = 0b1000000; } } @@ -43,7 +45,7 @@ struct DispatcherInner>, U: Encoder + Decoder + 'stat codec: U, service: PipelineBinding>, st: IoDispatcherState, - state: Rc>>, + state: Rc>, config: DispatcherConfig, read_remains: u32, read_remains_prev: u32, @@ -55,9 +57,11 @@ struct DispatcherInner>, U: Encoder + Decoder + 'stat } struct DispatcherState>, U: Encoder + Decoder> { - error: Option::Error>>, - base: usize, - queue: VecDeque>>, + error: Cell::Error>>>, + base: Cell, + ready: Cell, + queue: RefCell>>>, + waker: LocalWaker, } enum ServiceResult { @@ -116,11 +120,13 @@ where // register keepalive timer io.set_disconnect_timeout(config.disconnect_timeout()); - let state = Rc::new(RefCell::new(DispatcherState { - error: None, - base: 0, - queue: VecDeque::new(), - })); + let state = Rc::new(DispatcherState { + error: Cell::new(None), + base: Cell::new(0), + ready: Cell::new(false), + queue: RefCell::new(VecDeque::new()), + waker: LocalWaker::default(), + }); let keepalive_timeout = config.keepalive_timeout(); Dispatcher { @@ -169,53 +175,54 @@ where ::Item: 'static, { fn handle_result( - &mut self, + &self, item: Result, response_idx: usize, io: &IoRef, codec: &U, wake: bool, ) { - let idx = response_idx.wrapping_sub(self.base); + let mut queue = self.queue.borrow_mut(); + let idx = response_idx.wrapping_sub(self.base.get()); // handle first response if idx == 0 { - let _ = self.queue.pop_front(); - self.base = self.base.wrapping_add(1); + let _ = queue.pop_front(); + self.base.set(self.base.get().wrapping_add(1)); match item { Err(err) => { - self.error = Some(err.into()); + self.error.set(Some(err.into())); } Ok(Some(item)) => { if let Err(err) = io.encode(item, codec) { - self.error = Some(IoDispatcherError::Encoder(err)); + self.error.set(Some(IoDispatcherError::Encoder(err))); } } Ok(None) => (), } // check remaining response - while let Some(item) = self.queue.front_mut().and_then(|v| v.take()) { - let _ = self.queue.pop_front(); - self.base = self.base.wrapping_add(1); + while let Some(item) = queue.front_mut().and_then(|v| v.take()) { + let _ = queue.pop_front(); + self.base.set(self.base.get().wrapping_add(1)); match item { Err(err) => { - self.error = Some(err.into()); + self.error.set(Some(err.into())); } Ok(Some(item)) => { if let Err(err) = io.encode(item, codec) { - self.error = Some(IoDispatcherError::Encoder(err)); + self.error.set(Some(IoDispatcherError::Encoder(err))); } } Ok(None) => (), } } - if wake && self.queue.is_empty() { + if wake && queue.is_empty() { io.wake() } } else { - self.queue[idx] = ServiceResult::Ready(item); + queue[idx] = ServiceResult::Ready(item); } } } @@ -232,10 +239,12 @@ where let mut this = self.as_mut().project(); let inner = &mut this.inner; + inner.state.waker.register(cx.waker()); + // handle service response future if let Some(fut) = inner.response.as_mut() { if let Poll::Ready(item) = Pin::new(fut).poll(cx) { - inner.state.borrow_mut().handle_result( + inner.state.handle_result( item, inner.response_idx, inner.io.as_ref(), @@ -246,6 +255,12 @@ where } } + // start ready task + if inner.flags.contains(Flags::READY_TASK) { + inner.flags.insert(Flags::READY_TASK); + ntex_rt::spawn(not_ready(inner.state.clone(), inner.service.clone())); + } + loop { match inner.st { IoDispatcherState::Processing => { @@ -295,6 +310,7 @@ where PollService::Continue => continue, }; + inner.state.ready.set(false); inner.call_service(cx, item); } // handle write back-pressure @@ -328,7 +344,7 @@ where } } - if inner.state.borrow().queue.is_empty() { + if inner.state.queue.borrow().is_empty() { if inner.io.poll_shutdown(cx).is_ready() { log::trace!("{}: io shutdown completed", inner.io.tag()); inner.st = IoDispatcherState::Shutdown; @@ -361,7 +377,7 @@ where Poll::Ready( if let Some(IoDispatcherError::Service(err)) = - inner.state.borrow_mut().error.take() + inner.state.error.take() { Err(err) } else { @@ -384,37 +400,37 @@ where ::Item: 'static, { fn call_service(&mut self, cx: &mut Context<'_>, item: DispatchItem) { - let mut state = self.state.borrow_mut(); let mut fut = self.service.call_nowait(item); + let mut queue = self.state.queue.borrow_mut(); // optimize first call if self.response.is_none() { if let Poll::Ready(res) = Pin::new(&mut fut).poll(cx) { // check if current result is only response - if state.queue.is_empty() { + if queue.is_empty() { match res { Err(err) => { - state.error = Some(err.into()); + self.state.error.set(Some(err.into())); } Ok(Some(item)) => { if let Err(err) = self.io.encode(item, &self.codec) { - state.error = Some(IoDispatcherError::Encoder(err)); + self.state.error.set(Some(IoDispatcherError::Encoder(err))); } } Ok(None) => (), } } else { - self.response_idx = state.base.wrapping_add(state.queue.len()); - state.queue.push_back(ServiceResult::Ready(res)); + queue.push_back(ServiceResult::Ready(res)); + self.response_idx = self.state.base.get().wrapping_add(queue.len()); } } else { self.response = Some(fut); - self.response_idx = state.base.wrapping_add(state.queue.len()); - state.queue.push_back(ServiceResult::Pending); + self.response_idx = self.state.base.get().wrapping_add(queue.len()); + queue.push_back(ServiceResult::Pending); } } else { - let response_idx = state.base.wrapping_add(state.queue.len()); - state.queue.push_back(ServiceResult::Pending); + let response_idx = self.state.base.get().wrapping_add(queue.len()); + queue.push_back(ServiceResult::Pending); let st = self.io.get_ref(); let codec = self.codec.clone(); @@ -422,15 +438,14 @@ where ntex_util::spawn(async move { let item = fut.await; - state.borrow_mut().handle_result(item, response_idx, &st, &codec, true); + state.handle_result(item, response_idx, &st, &codec, true); }); } } fn check_error(&mut self) -> PollService { // check for errors - let mut state = self.state.borrow_mut(); - if let Some(err) = state.error.take() { + if let Some(err) = self.state.error.take() { log::trace!("{}: Error occured, stopping dispatcher", self.io.tag()); self.st = IoDispatcherState::Stop; match err { @@ -438,7 +453,7 @@ where PollService::Item(DispatchItem::EncoderError(err)) } IoDispatcherError::Service(err) => { - state.error = Some(IoDispatcherError::Service(err)); + self.state.error.set(Some(IoDispatcherError::Service(err))); PollService::Continue } } @@ -448,9 +463,13 @@ where } fn poll_service(&mut self, cx: &mut Context<'_>) -> Poll> { + if self.state.ready.get() { + return Poll::Ready(self.check_error()); + } + match self.service.poll_ready(cx) { Poll::Ready(Ok(_)) => { - let _ = self.service.poll_not_ready(cx); + self.state.ready.set(true); Poll::Ready(self.check_error()) } // pause io read task @@ -498,7 +517,7 @@ where log::error!("{}: Service readiness check failed, stopping", self.io.tag()); self.st = IoDispatcherState::Stop; self.flags.insert(Flags::READY_ERR); - self.state.borrow_mut().error = Some(IoDispatcherError::Service(err)); + self.state.error.set(Some(IoDispatcherError::Service(err))); Poll::Ready(PollService::Item(DispatchItem::Disconnect(None))) } } @@ -576,6 +595,30 @@ where } } +async fn not_ready( + slf: Rc>, + pl: PipelineBinding>, +) where + S: Service, Response = Option>> + 'static, + U: Encoder + Decoder + 'static, +{ + loop { + if !pl.is_shutdown() { + if let Err(err) = poll_fn(|cx| pl.poll_ready(cx)).await { + slf.error.set(Some(IoDispatcherError::Service(err))); + break; + } + if !pl.is_shutdown() { + poll_fn(|cx| pl.poll_not_ready(cx)).await; + slf.ready.set(false); + slf.waker.wake(); + continue; + } + } + break; + } +} + #[cfg(test)] mod tests { use std::{cell::Cell, io, sync::Arc, sync::Mutex}; @@ -616,11 +659,13 @@ mod tests { let keepalive_timeout = config.keepalive_timeout(); let rio = io.get_ref(); - let state = Rc::new(RefCell::new(DispatcherState { - error: None, - base: 0, - queue: VecDeque::new(), - })); + let state = Rc::new(DispatcherState { + error: Cell::new(None), + base: Cell::new(0), + ready: Cell::new(false), + waker: LocalWaker::default(), + queue: RefCell::new(VecDeque::new()), + }); ( Dispatcher {