Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Various cleanups #155

Merged
merged 2 commits into from
Nov 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CHANGES.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Changes

## [0.12.8] - 2023-11-11
## [0.12.8] - 2023-11-12

* Use new ntex-io apis

Expand Down
6 changes: 3 additions & 3 deletions src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ pub enum HandshakeError<E> {
}

/// Protocol level errors
#[derive(Debug, thiserror::Error)]
#[derive(Debug, Copy, Clone, PartialEq, Eq, thiserror::Error)]
pub enum ProtocolError {
/// MQTT decoding error
#[error("Decoding error: {0:?}")]
Expand All @@ -52,13 +52,13 @@ pub enum ProtocolError {
ReadTimeout,
}

#[derive(Debug, thiserror::Error)]
#[derive(Debug, Copy, Clone, PartialEq, Eq, thiserror::Error)]
#[error(transparent)]
pub struct ProtocolViolationError {
inner: ViolationInner,
}

#[derive(Debug, thiserror::Error)]
#[derive(Debug, Copy, Clone, PartialEq, Eq, thiserror::Error)]
enum ViolationInner {
#[error("{message}")]
Common { reason: DisconnectReasonCode, message: &'static str },
Expand Down
114 changes: 51 additions & 63 deletions src/io.rs
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,6 @@ enum IoDispatcherState {
}

pub(crate) enum IoDispatcherError<S, U> {
KeepAlive,
Encoder(U),
Service(S),
}
Expand Down Expand Up @@ -257,106 +256,98 @@ where
// decode incoming bytes stream
match inner.io.poll_recv_decode(this.codec, cx) {
Ok(decoded) => {
// update keep-alive timer
inner.update_timer(&decoded);
if let Some(el) = decoded.item {
Some(DispatchItem::Item(el))
DispatchItem::Item(el)
} else {
return Poll::Pending;
}
}
Err(RecvError::Stop) => {
log::trace!("dispatcher is instructed to stop");
inner.st = IoDispatcherState::Stop;
None
continue;
}
Err(RecvError::KeepAlive) => {
// check keepalive timeout
log::trace!("keepalive timeout");
log::trace!("keep-alive error, stopping dispatcher");
inner.st = IoDispatcherState::Stop;
let mut state = inner.state.borrow_mut();
if state.error.is_none() {
state.error = Some(IoDispatcherError::KeepAlive);
if inner.flags.contains(Flags::READ_TIMEOUT) {
DispatchItem::ReadTimeout
} else {
DispatchItem::KeepAliveTimeout
}
Some(DispatchItem::KeepAliveTimeout)
}
Err(RecvError::WriteBackpressure) => {
if let Err(err) = ready!(inner.io.poll_flush(cx, false)) {
inner.st = IoDispatcherState::Stop;
Some(DispatchItem::Disconnect(Some(err)))
DispatchItem::Disconnect(Some(err))
} else {
continue;
}
}
Err(RecvError::Decoder(err)) => {
inner.st = IoDispatcherState::Stop;
Some(DispatchItem::DecoderError(err))
DispatchItem::DecoderError(err)
}
Err(RecvError::PeerGone(err)) => {
inner.st = IoDispatcherState::Stop;
Some(DispatchItem::Disconnect(err))
DispatchItem::Disconnect(err)
}
}
}
PollService::Item(item) => Some(item),
PollService::Item(item) => item,
PollService::Continue => continue,
};

// call service
if let Some(item) = item {
// optimize first call
if this.response.is_none() {
this.response.set(Some(this.service.call_static(item)));
let res = this.response.as_mut().as_pin_mut().unwrap().poll(cx);

let mut state = inner.state.borrow_mut();

if let Poll::Ready(res) = res {
// check if current result is only response
if state.queue.is_empty() {
match res {
Err(err) => {
state.error = Some(err.into());
}
Ok(Some(item)) => {
if let Err(err) = inner.io.encode(item, this.codec)
{
state.error =
Some(IoDispatcherError::Encoder(err));
}
// optimize first call
if this.response.is_none() {
this.response.set(Some(this.service.call_static(item)));

let res = this.response.as_mut().as_pin_mut().unwrap().poll(cx);
let mut state = inner.state.borrow_mut();

if let Poll::Ready(res) = res {
// check if current result is only response
if state.queue.is_empty() {
match res {
Err(err) => {
state.error = Some(err.into());
}
Ok(Some(item)) => {
if let Err(err) = inner.io.encode(item, this.codec) {
state.error = Some(IoDispatcherError::Encoder(err));
}
Ok(None) => (),
}
} else {
*this.response_idx =
state.base.wrapping_add(state.queue.len());
state.queue.push_back(ServiceResult::Ready(res));
Ok(None) => (),
}
this.response.set(None);
} else {
*this.response_idx = state.base.wrapping_add(state.queue.len());
state.queue.push_back(ServiceResult::Pending);
state.queue.push_back(ServiceResult::Ready(res));
}
this.response.set(None);
} else {
let mut state = inner.state.borrow_mut();
let response_idx = state.base.wrapping_add(state.queue.len());
*this.response_idx = state.base.wrapping_add(state.queue.len());
state.queue.push_back(ServiceResult::Pending);

let st = inner.io.get_ref();
let codec = this.codec.clone();
let state = inner.state.clone();
let fut = this.service.call_static(item);
ntex::rt::spawn(async move {
let item = fut.await;
state.borrow_mut().handle_result(
item,
response_idx,
&st,
&codec,
true,
);
});
}
} else {
let mut state = inner.state.borrow_mut();
let response_idx = state.base.wrapping_add(state.queue.len());
state.queue.push_back(ServiceResult::Pending);

let st = inner.io.get_ref();
let codec = this.codec.clone();
let state = inner.state.clone();
let fut = this.service.call_static(item);
ntex::rt::spawn(async move {
let item = fut.await;
state.borrow_mut().handle_result(
item,
response_idx,
&st,
&codec,
true,
);
});
}
}
// drain service responses and shutdown io
Expand Down Expand Up @@ -443,9 +434,6 @@ where
state.error = Some(IoDispatcherError::Service(err));
PollService::Continue
}
IoDispatcherError::KeepAlive => {
PollService::Item(DispatchItem::KeepAliveTimeout)
}
}
} else {
PollService::Ready
Expand Down
16 changes: 7 additions & 9 deletions src/v3/dispatcher.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use std::cell::RefCell;
use std::task::{Context, Poll};
use std::task::{ready, Context, Poll};
use std::{future::Future, marker::PhantomData, num::NonZeroU16, pin::Pin, rc::Rc};

use ntex::io::DispatchItem;
Expand Down Expand Up @@ -386,8 +386,8 @@ where
let mut this = self.as_mut().project();

match this.state.as_mut().project() {
PublishResponseStateProject::Publish { fut } => match fut.poll(cx) {
Poll::Ready(Ok(_)) => {
PublishResponseStateProject::Publish { fut } => match ready!(fut.poll(cx)) {
Ok(_) => {
log::trace!("Publish result for packet {:?} is ready", this.packet_id);

if let Some(packet_id) = this.packet_id {
Expand All @@ -399,7 +399,7 @@ where
Poll::Ready(Ok(None))
}
}
Poll::Ready(Err(e)) => {
Err(e) => {
this.state.set(PublishResponseState::Control {
fut: ControlResponse::new(
ControlMessage::error(e.into()),
Expand All @@ -409,7 +409,6 @@ where
});
self.poll(cx)
}
Poll::Pending => Poll::Pending,
},
PublishResponseStateProject::Control { fut } => fut.poll(cx),
}
Expand Down Expand Up @@ -453,8 +452,8 @@ where
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = self.as_mut().project();

match this.fut.poll(cx) {
Poll::Ready(Ok(item)) => {
match ready!(this.fut.poll(cx)) {
Ok(item) => {
let packet = match item.result {
ControlResultKind::Ping => Some(codec::Packet::PingResponse),
ControlResultKind::Subscribe(res) => {
Expand All @@ -478,7 +477,7 @@ where
};
Poll::Ready(Ok(packet))
}
Poll::Ready(Err(err)) => {
Err(err) => {
// do not handle nested error
if *this.error {
Poll::Ready(Err(err))
Expand All @@ -496,7 +495,6 @@ where
}
}
}
Poll::Pending => Poll::Pending,
}
}
}
Expand Down
12 changes: 12 additions & 0 deletions src/v3/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,18 @@ where
self
}

/// Set read rate parameters for single frame.
///
/// Set max timeout for reading single frame. If the client
/// sends `rate` amount of data, increase the timeout by 1 second for every.
/// But no more than `max_timeout` timeout.
///
/// By default frame read rate is disabled.
pub fn frame_read_rate(self, timeout: Seconds, max_timeout: Seconds, rate: u16) -> Self {
self.config.set_frame_read_rate(timeout, max_timeout, rate);
self
}

/// Set max allowed QoS.
///
/// If peer sends publish with higher qos then ProtocolError::MaxQoSViolated(..)
Expand Down
12 changes: 12 additions & 0 deletions src/v5/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,18 @@ where
self
}

/// Set read rate parameters for single frame.
///
/// Set max timeout for reading single frame. If the client
/// sends `rate` amount of data, increase the timeout by 1 second for every.
/// But no more than `max_timeout` timeout.
///
/// By default frame read rate is disabled.
pub fn frame_read_rate(self, timeout: Seconds, max_timeout: Seconds, rate: u16) -> Self {
self.config.set_frame_read_rate(timeout, max_timeout, rate);
self
}

/// Set max inbound frame size.
///
/// If max size is set to `0`, size is unlimited.
Expand Down
72 changes: 69 additions & 3 deletions tests/test_server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@ use std::{cell::RefCell, future::Future, num::NonZeroU16, pin::Pin, rc::Rc, time

use ntex::service::{fn_service, Pipeline, ServiceFactory};
use ntex::time::{sleep, Millis, Seconds};
use ntex::util::{join_all, lazy, ByteString, Bytes, Ready};
use ntex::{server, service::chain_factory};
use ntex::util::{join_all, lazy, ByteString, Bytes, BytesMut, Ready};
use ntex::{codec::Encoder, server, service::chain_factory};

use ntex_mqtt::v3::{
client, codec, ControlMessage, Handshake, HandshakeAck, MqttServer, Publish, Session,
Expand Down Expand Up @@ -447,7 +447,6 @@ fn ssl_acceptor() -> openssl::ssl::SslAcceptor {
#[ntex::test]
async fn test_large_publish_openssl() -> std::io::Result<()> {
use openssl::ssl::{SslConnector, SslMethod, SslVerifyMode};
env_logger::init();

let srv = server::test_server(move || {
chain_factory(server::openssl::Acceptor::new(ssl_acceptor()).map_err(|_| ())).and_then(
Expand Down Expand Up @@ -613,3 +612,70 @@ async fn test_sink_publish_noblock() -> std::io::Result<()> {
sink.close();
Ok(())
}

// Slow frame rate
#[ntex::test]
async fn test_frame_read_rate() -> std::io::Result<()> {
let _ = env_logger::try_init();
let check = Arc::new(AtomicBool::new(false));
let check2 = check.clone();

let srv = server::test_server(move || {
let check = check2.clone();

MqttServer::new(handshake)
.frame_read_rate(Seconds(1), Seconds(2), 10)
.publish(|_| Ready::Ok(()))
.control(move |msg| {
let check = check.clone();
match msg {
ControlMessage::ProtocolError(msg) => {
if msg.get_ref() == &ProtocolError::ReadTimeout {
check.store(true, Relaxed);
}
Ready::Ok(msg.ack())
}
_ => Ready::Ok(msg.disconnect()),
}
})
.finish()
.map_err(|_| ())
.map_init_err(|_| ())
});

let io = srv.connect().await.unwrap();
let codec = codec::Codec::default();
io.encode(codec::Connect::default().client_id("user").into(), &codec).unwrap();
io.recv(&codec).await.unwrap();

let p = codec::Publish {
dup: false,
retain: false,
qos: codec::QoS::AtLeastOnce,
topic: ByteString::from("test"),
packet_id: Some(NonZeroU16::new(3).unwrap()),
payload: Bytes::from(vec![b'*'; 270 * 1024]),
}
.into();

let mut buf = BytesMut::new();
codec.encode(p, &mut buf).unwrap();

io.write(&buf[..5]).unwrap();
buf.split_to(5);
sleep(Millis(100)).await;
io.write(&buf[..10]).unwrap();
buf.split_to(10);
sleep(Millis(500)).await;
assert!(!check.load(Relaxed));

io.write(&buf[..12]).unwrap();
buf.split_to(12);
sleep(Millis(500)).await;
assert!(!check.load(Relaxed));

sleep(Millis(1200)).await;
assert!(check.load(Relaxed));

Ok(())
}
Loading