diff --git a/src/v5/client.rs b/src/v5/client.rs index d41ee6d8..5e7bf748 100644 --- a/src/v5/client.rs +++ b/src/v5/client.rs @@ -1,6 +1,6 @@ use std::task::{Context, Poll}; use std::time::Duration; -use std::{fmt, io, marker::PhantomData, pin::Pin, rc::Rc}; +use std::{fmt, io, marker::PhantomData, num::NonZeroU16, pin::Pin, rc::Rc}; use bytes::Bytes; use bytestring::ByteString; @@ -11,7 +11,7 @@ use ntex::channel::mpsc; use ntex::service::{boxed, IntoService, IntoServiceFactory, Service, ServiceFactory}; use ntex_codec::{AsyncRead, AsyncWrite}; -use crate::error::{DecodeError, EncodeError, MqttError}; +use crate::error::{DecodeError, EncodeError, MqttError, ProtocolError}; use crate::handshake::{Handshake, HandshakeResult}; use crate::service::Builder; @@ -26,7 +26,7 @@ use super::{codec, Session}; #[derive(Clone)] pub struct Client { keep_alive: u16, - inflight: usize, + max_receive: u16, connect: codec::Connect, _t: PhantomData<(Io, St)>, } @@ -39,7 +39,7 @@ where pub fn new(client_id: ByteString) -> Self { Client { keep_alive: 30, - inflight: 15, + max_receive: 16, connect: codec::Connect { client_id, ..Default::default() }, _t: PhantomData, } @@ -79,11 +79,12 @@ where self } - /// Number of in-flight concurrent messages. + /// Set `receive max` /// - /// in-flight is set to 15 messages - pub fn inflight(mut self, val: usize) -> Self { - self.inflight = val; + /// Number of in-flight publish packets. By default receive max is set to 15 packets. + /// To disable timeout set value to 0. + pub fn receive_max(mut self, val: u16) -> Self { + self.max_receive = val; self } @@ -106,7 +107,7 @@ where .map_err(|_| unreachable!()), ), keep_alive: self.keep_alive.into(), - inflight: self.inflight, + max_receive: self.max_receive, _t: PhantomData, } } @@ -116,7 +117,7 @@ pub struct ServiceBuilder { state: Rc, packet: codec::Connect, keep_alive: u64, - inflight: usize, + max_receive: u16, control: boxed::BoxServiceFactory< Session, ControlPacket, @@ -153,13 +154,16 @@ where connect: self.state, packet: self.packet, keep_alive: self.keep_alive, - inflight: self.inflight, + max_receive: self.max_receive, _t: PhantomData, }) .build(factory( - service.into_factory().map_err(MqttError::Service).map_init_err(MqttError::Service), + service.into_factory() + .map_err(::from) + .map_init_err(MqttError::Service), self.control, 0, + 16, )) } } @@ -168,7 +172,7 @@ struct ConnectService { connect: Rc, packet: codec::Connect, keep_alive: u64, - inflight: usize, + max_receive: u16, _t: PhantomData<(Io, St)>, } @@ -199,7 +203,10 @@ where let srv = self.connect.clone(); let packet = self.packet.clone(); let keep_alive = Duration::from_secs(self.keep_alive as u64); - let inflight = self.inflight; + let max_receive = self.max_receive; + if max_receive > 0 { + packet.receive_max = Some(NonZeroU16::new(max_receive).unwrap()) + } // send Connect packet async move { @@ -219,8 +226,11 @@ where match packet { codec::Packet::ConnectAck(packet) => { let (tx, rx) = mpsc::channel(); - let sink = MqttSink::new(tx); - let ack = ConnectAck { sink, inflight, packet, io: framed }; + let sink = MqttSink::new( + tx, + packet.receive_max.map(|v| v.get()).unwrap_or(16) as usize, + ); + let ack = ConnectAck { sink, packet, io: framed }; Ok(srv .as_ref() .call(ack) @@ -228,7 +238,10 @@ where .map_err(MqttError::Service) .map(move |ack| ack.io.out(rx).state(ack.state))?) } - p => Err(MqttError::Unexpected(p.packet_type(), "Expected CONNECT-ACK packet")), + p => Err(MqttError::Protocol(ProtocolError::Unexpected( + p.packet_type(), + "Expected CONNECT-ACK packet", + ))), } } .boxed_local() @@ -238,7 +251,6 @@ where pub struct ConnectAck { io: HandshakeResult>, sink: MqttSink, - inflight: usize, packet: codec::ConnectAck, } @@ -264,7 +276,7 @@ impl ConnectAck { #[inline] /// Set connection state and create result object pub fn state(self, state: St) -> ConnectAckResult { - ConnectAckResult { io: self.io, state: Session::new(state, self.sink, self.inflight) } + ConnectAckResult { io: self.io, state: Session::new(state, self.sink) } } } diff --git a/src/v5/mod.rs b/src/v5/mod.rs index c862f0d3..1a96946f 100644 --- a/src/v5/mod.rs +++ b/src/v5/mod.rs @@ -1,6 +1,6 @@ //! MQTT 3.1.1 Client/Server framework -//pub mod client; +// pub mod client; pub mod codec; mod connect; pub mod control; @@ -13,7 +13,7 @@ mod sink; pub type Session = crate::Session; -//pub use self::client::Client; +// pub use self::client::Client; pub use self::connect::{Connect, ConnectAck}; pub use self::control::{ControlPacket, ControlResult}; pub use self::publish::{Publish, PublishAck}; diff --git a/src/v5/sink.rs b/src/v5/sink.rs index 621711c1..be4a698a 100644 --- a/src/v5/sink.rs +++ b/src/v5/sink.rs @@ -1,5 +1,4 @@ -use std::task::{Context, Poll}; -use std::{cell::RefCell, collections::VecDeque, fmt, num::NonZeroU16, pin::Pin, rc::Rc}; +use std::{cell::RefCell, collections::VecDeque, fmt, num::NonZeroU16, rc::Rc}; use bytes::Bytes; use bytestring::ByteString;