Skip to content

Commit

Permalink
refactor: stabilize barrier
Browse files Browse the repository at this point in the history
  • Loading branch information
soehrl committed Sep 4, 2024
1 parent 9de34c0 commit d8ef856
Show file tree
Hide file tree
Showing 7 changed files with 420 additions and 398 deletions.
18 changes: 9 additions & 9 deletions examples/barrier.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@ fn client(name: &String) {
log::info!("[{name}] joined barrier group");

let sleep_distribution = rand::distributions::Uniform::new(
std::time::Duration::from_secs(1),
std::time::Duration::from_secs(2),
std::time::Duration::from_millis(1),
std::time::Duration::from_millis(2),
);
let mut rng = rand::thread_rng();

Expand All @@ -23,7 +23,7 @@ fn client(name: &String) {
barrier.wait().unwrap();
log::info!("[{name}] after");

std::thread::sleep(sleep_distribution.sample(&mut rng));
// std::thread::sleep(sleep_distribution.sample(&mut rng));
}
}

Expand All @@ -36,22 +36,22 @@ fn server() {

let mut barrier = p
.create_barrier_group(multicast::publisher::BarrierGroupDesc {
timeout: std::time::Duration::from_secs(1),
retries: 5,
retransmit_timeout: std::time::Duration::from_secs(1),
retransmit_count: 5,
})
.unwrap();
log::info!("[server] barrier group created");

let sleep_distribution = rand::distributions::Uniform::new(
std::time::Duration::from_secs(1),
std::time::Duration::from_secs(2),
std::time::Duration::from_millis(1),
std::time::Duration::from_millis(2),
);
let mut rng = rand::thread_rng();

loop {
std::thread::sleep(sleep_distribution.sample(&mut rng));
// std::thread::sleep(sleep_distribution.sample(&mut rng));

while let Some(client) = barrier.try_accept() {
while let Ok(client) = barrier.try_accept() {
log::info!("new client: {client}");
}

Expand Down
15 changes: 13 additions & 2 deletions src/chunk.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,9 @@ use std::{
use crossbeam::channel::{Receiver, Sender, TryRecvError};

use crate::protocol::{
kind, Ack, BarrierReached, BarrierReleased, ChunkKindData, ConfirmJoinChannel, Connect,
ConnectionInfo, JoinBarrierGroup, JoinChannel, Message,
kind, Ack, BarrierReached, BarrierReleased, ChannelDisconnected, ChunkKindData,
ConfirmJoinChannel, Connect, ConnectionInfo, JoinBarrierGroup, JoinChannel, LeaveChannel,
Message,
};

#[derive(Debug)]
Expand All @@ -21,6 +22,8 @@ pub enum Chunk<'a> {
JoinBarrierGroup(&'a JoinBarrierGroup),
BarrierReached(&'a BarrierReached),
BarrierReleased(&'a BarrierReleased),
LeaveChannel(&'a LeaveChannel),
ChannelDisconnected(&'a ChannelDisconnected),
}

impl Chunk<'_> {
Expand All @@ -33,6 +36,8 @@ impl Chunk<'_> {
Chunk::JoinBarrierGroup(join) => Some(join.0.into()),
Chunk::BarrierReached(b) => Some(b.0.channel_id.into()),
Chunk::BarrierReleased(b) => Some(b.0.channel_id.into()),
Chunk::LeaveChannel(c) => Some(c.0.into()),
Chunk::ChannelDisconnected(c) => Some(c.0.into()),
_ => None,
}
}
Expand Down Expand Up @@ -181,6 +186,12 @@ impl ChunkBuffer {
kind::BARRIER_RELEASED => Ok(Chunk::BarrierReleased(
self.get_kind_data_ref::<BarrierReleased>(packet_size)?,
)),
kind::LEAVE_CHANNEL => Ok(Chunk::LeaveChannel(
self.get_kind_data_ref::<LeaveChannel>(packet_size)?,
)),
kind::CHANNEL_DISCONNECTED => Ok(Chunk::ChannelDisconnected(
self.get_kind_data_ref::<ChannelDisconnected>(packet_size)?,
)),
kind => Err(ChunkValidationError::InvalidChunkKind(kind)),
}
}
Expand Down
2 changes: 1 addition & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,5 +8,5 @@ pub(crate) mod protocol;
pub mod publisher;
pub mod subscriber;

pub type OfferId = u16;
pub type ChannelId = u16;
pub(crate) type SequenceNumber = u16;
99 changes: 94 additions & 5 deletions src/multiplex_socket.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use std::sync::Arc;
use std::{net::SocketAddr, sync::Arc};

use ahash::HashMap;
use crossbeam::channel::{Receiver, RecvError, RecvTimeoutError, Sender};
use crossbeam::channel::{Receiver, RecvError, RecvTimeoutError, Select, Sender};
use socket2::Socket;

use crate::{
Expand Down Expand Up @@ -168,7 +168,36 @@ impl MultiplexSocket {
}
}

pub fn wait_for_chunk<P: FnMut(Chunk) -> bool>(
/// Waits until a chunk arriving on any of the receivers in `r` is accepted by `p`.
///
/// Every received chunk that validates successfully and whose source address
/// is an IP socket address is handed to `p` together with that address. The
/// first `Some(value)` returned by `p` ends the wait and is returned as
/// `Ok(value)`. Chunks that fail validation, carry no socket address, or are
/// rejected by `p` (`None`) are dropped and the wait continues.
///
/// `timeout` bounds the whole call, not each individual receive.
///
/// # Errors
///
/// * `RecvTimeoutError::Timeout` if no chunk was accepted before the deadline.
/// * `RecvTimeoutError::Disconnected` if the selected receiver turned out to be
///   disconnected.
pub fn wait_for_chunk2<T, P: FnMut(Chunk, SocketAddr) -> Option<T>>(
    r: &[&ChunkReceiver],
    timeout: std::time::Duration,
    mut p: P,
) -> Result<T, RecvTimeoutError> {
    let deadline = std::time::Instant::now() + timeout;

    // Register every receiver once; the operation index reported by the
    // selector is the position of the receiver inside `r`.
    let mut sel = Select::new();
    for receiver in r {
        sel.recv(receiver);
    }

    loop {
        // Select against the absolute deadline instead of recomputing
        // `deadline - Instant::now()` on each iteration: that subtraction is
        // only guaranteed not to panic by the current saturating behaviour of
        // `Instant`, which the standard library reserves the right to change.
        match sel.select_deadline(deadline) {
            Ok(operation) => {
                let index = operation.index();
                let chunk = operation.recv(r[index])?;
                if let (Ok(chunk), Some(addr)) = (chunk.validate(), chunk.addr().as_socket()) {
                    if let Some(v) = p(chunk, addr) {
                        return Ok(v);
                    }
                }
            }
            Err(_) => return Err(RecvTimeoutError::Timeout),
        }
    }
}

pub fn wait_for_chunk<P: FnMut(Chunk, SocketAddr) -> bool>(
r: &ChunkReceiver,
timeout: std::time::Duration,
mut p: P,
Expand All @@ -179,8 +208,8 @@ pub fn wait_for_chunk<P: FnMut(Chunk) -> bool>(
loop {
match r.recv_deadline(deadline) {
Ok(chunk) => {
if let Ok(chunk) = chunk.validate() {
if p(chunk) {
if let (Ok(chunk), Some(addr)) = (chunk.validate(), chunk.addr().as_socket()) {
if p(chunk, addr) {
return Ok(());
}
}
Expand All @@ -189,3 +218,63 @@ pub fn wait_for_chunk<P: FnMut(Chunk) -> bool>(
}
}
}

/// Error returned by [`transmit_and_wait`] and [`transmit_to_and_wait`].
#[derive(thiserror::Error, Debug)]
pub enum TransmitAndWaitError {
/// Waiting for the response failed: either every attempt timed out or the
/// receiving channel was disconnected.
#[error("receive error: {0}")]
RecvError(#[from] RecvTimeoutError),

/// Sending the chunk on the socket failed with an I/O error.
#[error("send error: {0}")]
SendError(#[from] std::io::Error),
}

/// Sends `kind_data` on `socket` and waits for a response accepted by `p`,
/// retransmitting when no response arrives in time.
///
/// The chunk is sent once and then up to `retransmit_count` more times. After
/// each send, chunks arriving on any receiver in `r` are offered to `p` for at
/// most `retransmit_timeout`; the first `Some(value)` returned by `p` is
/// returned as `Ok(value)`.
///
/// # Errors
///
/// * [`TransmitAndWaitError::SendError`] if sending the chunk fails.
/// * [`TransmitAndWaitError::RecvError`] with `Disconnected` if a receiver is
///   disconnected, or with `Timeout` once every attempt has timed out.
pub fn transmit_and_wait<C: ChunkKindData, T, P: FnMut(Chunk, SocketAddr) -> Option<T>>(
    socket: &ChunkSocket,
    kind_data: &C,
    retransmit_timeout: std::time::Duration,
    retransmit_count: usize,
    r: &[&ChunkReceiver],
    mut p: P,
) -> Result<T, TransmitAndWaitError> {
    // Inclusive range: one initial transmission plus `retransmit_count`
    // retransmissions, without the overflow that `retransmit_count + 1`
    // would hit for `usize::MAX`.
    for _ in 0..=retransmit_count {
        socket.send_chunk(kind_data)?;

        match wait_for_chunk2(r, retransmit_timeout, &mut p) {
            Ok(v) => return Ok(v),
            // No acceptable response in time: fall through and retransmit.
            Err(RecvTimeoutError::Timeout) => {}
            // A disconnected receiver cannot recover, so give up immediately.
            Err(err @ RecvTimeoutError::Disconnected) => return Err(err.into()),
        }
    }

    Err(TransmitAndWaitError::RecvError(RecvTimeoutError::Timeout))
}

/// Sends `kind_data` to `addr` on `socket` and waits until `p` accepts a
/// response, retransmitting when no response arrives in time.
///
/// The chunk is sent once and then up to `retransmit_count` more times. After
/// each send, `wait_for_chunk` is given `retransmit_timeout` to receive a
/// chunk on `r` for which `p` returns `true`.
///
/// # Errors
///
/// * [`TransmitAndWaitError::SendError`] if sending the chunk fails.
/// * [`TransmitAndWaitError::RecvError`] with `Disconnected` if the receiver
///   is disconnected, or with `Timeout` once every attempt has timed out.
pub fn transmit_to_and_wait<T: ChunkKindData, P: FnMut(Chunk, SocketAddr) -> bool>(
    socket: &ChunkSocket,
    addr: &SocketAddr,
    kind_data: &T,
    retransmit_timeout: std::time::Duration,
    retransmit_count: usize,
    r: &ChunkReceiver,
    mut p: P,
) -> Result<(), TransmitAndWaitError> {
    // Inclusive range: one initial transmission plus `retransmit_count`
    // retransmissions, without the overflow that `retransmit_count + 1`
    // would hit for `usize::MAX`.
    for _ in 0..=retransmit_count {
        socket.send_chunk_to(kind_data, &(*addr).into())?;

        match wait_for_chunk(r, retransmit_timeout, &mut p) {
            Ok(()) => return Ok(()),
            // No acceptable response in time: fall through and retransmit.
            Err(RecvTimeoutError::Timeout) => {}
            // A disconnected receiver cannot recover, so give up immediately.
            Err(err @ RecvTimeoutError::Disconnected) => return Err(err.into()),
        }
    }

    Err(TransmitAndWaitError::RecvError(RecvTimeoutError::Timeout))
}
12 changes: 12 additions & 0 deletions src/protocol.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ pub mod kind {
pub const JOIN_BARRIER_GROUP: ChunkKind = 8;
pub const BARRIER_REACHED: ChunkKind = 9;
pub const BARRIER_RELEASED: ChunkKind = 10;
pub const LEAVE_CHANNEL: ChunkKind = 11;
pub const CHANNEL_DISCONNECTED: ChunkKind = 12;
}

pub const MESSAGE_PAYLOAD_OFFSET: usize = 1 + std::mem::size_of::<Message>();
Expand Down Expand Up @@ -102,3 +104,13 @@ impl_chunk_data!(BarrierReached);
#[repr(C)]
pub struct BarrierReleased(pub ChannelHeader);
impl_chunk_data!(BarrierReleased);

/// Chunk payload consisting of a single [`ChannelId`].
///
/// Chunks of kind `kind::LEAVE_CHANNEL` are decoded into this type.
// NOTE(review): presumably sent by a participant to announce that it leaves
// the identified channel — confirm against the sending code.
#[derive(Debug, FromBytes, AsBytes, FromZeroes, Unaligned)]
#[repr(C)]
pub struct LeaveChannel(pub ChannelId);
// Macro defined elsewhere in this module; presumably provides the
// `ChunkKindData` implementation for this payload — TODO confirm.
impl_chunk_data!(LeaveChannel);

/// Chunk payload consisting of a single [`ChannelId`].
///
/// Chunks of kind `kind::CHANNEL_DISCONNECTED` are decoded into this type.
// NOTE(review): presumably a notification that the identified channel has
// been disconnected — confirm against the sending code.
#[derive(Debug, FromBytes, AsBytes, FromZeroes, Unaligned)]
#[repr(C)]
pub struct ChannelDisconnected(pub ChannelId);
// Macro defined elsewhere in this module; presumably provides the
// `ChunkKindData` implementation for this payload — TODO confirm.
impl_chunk_data!(ChannelDisconnected);
Loading

0 comments on commit d8ef856

Please sign in to comment.