Skip to content

Commit

Permalink
fix(hydroflow): clean up temp tcp networking code, fix race condition
Browse files Browse the repository at this point in the history
 #1458 (#1446)

consolidate into one task to prevent races
  • Loading branch information
MingweiSamuel committed Oct 28, 2024
1 parent 47cb703 commit b961233
Showing 1 changed file with 130 additions and 77 deletions.
207 changes: 130 additions & 77 deletions hydroflow/src/util/tcp.rs
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
#![cfg(not(target_arch = "wasm32"))]

use std::cell::RefCell;
use std::collections::hash_map::Entry::{Occupied, Vacant};
use std::collections::HashMap;
use std::fmt::Debug;
use std::net::SocketAddr;
use std::pin::pin;
use std::rc::Rc;

use futures::{SinkExt, StreamExt};
use tokio::net::tcp::{OwnedReadHalf, OwnedWriteHalf};
use tokio::net::{TcpListener, TcpSocket, TcpStream};
use tokio::select;
use tokio::task::spawn_local;
use tokio_stream::StreamMap;
use tokio_util::codec::{
BytesCodec, Decoder, Encoder, FramedRead, FramedWrite, LengthDelimitedCodec, LinesCodec,
};
Expand Down Expand Up @@ -74,107 +74,160 @@ pub type TcpFramedSink<T> = Sender<(T, SocketAddr)>;
/// Receiving side of a framed TCP connection pool: yields `(decoded item, peer address)` on
/// success, or the codec's decode/IO error. Returned by [`bind_tcp`] and [`connect_tcp`].
pub type TcpFramedStream<Codec: Decoder> =
Receiver<Result<(<Codec as Decoder>::Item, SocketAddr), <Codec as Decoder>::Error>>;

// TODO(mingwei): this temporary code should be replaced with a properly thought out networking system.
/// Create a listening tcp socket, and then as new connections come in, receive their data and forward it to a queue.
pub async fn bind_tcp<T: 'static, Codec: 'static + Clone + Decoder + Encoder<T>>(
pub async fn bind_tcp<Item, Codec>(
endpoint: SocketAddr,
codec: Codec,
) -> Result<(TcpFramedSink<T>, TcpFramedStream<Codec>, SocketAddr), std::io::Error> {
) -> Result<(TcpFramedSink<Item>, TcpFramedStream<Codec>, SocketAddr), std::io::Error>
where
Item: 'static,
Codec: 'static + Clone + Decoder + Encoder<Item>,
<Codec as Encoder<Item>>::Error: Debug,
{
let listener = TcpListener::bind(endpoint).await?;

let bound_endpoint = listener.local_addr()?;

let (tx_egress, mut rx_egress) = unsync_channel(None);
let (tx_ingress, rx_ingress) = unsync_channel(None);

let clients = Rc::new(RefCell::new(HashMap::new()));

spawn_local({
let clients = clients.clone();

async move {
while let Some((payload, addr)) = rx_egress.next().await {
let client = clients.borrow_mut().remove(&addr);

if let Some(mut sender) = client {
let _ = SinkExt::send(&mut sender, payload).await;
clients.borrow_mut().insert(addr, sender);
}
}
}
});
let (send_egress, mut recv_egress) = unsync_channel::<(Item, SocketAddr)>(None);
let (send_ingres, recv_ingres) = unsync_channel(None);

spawn_local(async move {
let send_ingress = send_ingres;
// Map of `addr -> peers`, to send messages to.
let mut peers_send = HashMap::new();
// `StreamMap` of `addr -> peers`, to receive messages from. Automatically removes streams
// when they disconnect.
let mut peers_recv = StreamMap::<SocketAddr, FramedRead<OwnedReadHalf, Codec>>::new();

loop {
let (stream, peer_addr) = if let Ok((stream, _)) = listener.accept().await {
if let Ok(peer_addr) = stream.peer_addr() {
(stream, peer_addr)
} else {
continue;
// Calling methods in a loop, futures must be cancel-safe.
select! {
// `biased` means the cases will be prioritized in the order they are listed.
// First we accept any new connections
// This is not strictly neccessary, but lets us do our internal work (send outgoing
// messages) before accepting more work (receiving more messages, accepting new
// clients).
biased;
// Send outgoing messages.
msg_send = recv_egress.next() => {
let Some((payload, peer_addr)) = msg_send else {
// `None` if the send side has been dropped (no more send messages will ever come).
continue;
};
let Some(stream) = peers_send.get_mut(&peer_addr) else {
tracing::warn!("Dropping message to non-connected peer: {}", peer_addr);
continue;
};
if let Err(err) = SinkExt::send(stream, payload).await {
tracing::error!("IO or codec error sending message to peer {}, disconnecting: {:?}", peer_addr, err);
peers_send.remove(&peer_addr); // `Drop` disconnects.
};
}
} else {
continue;
};

let mut tx_ingress = tx_ingress.clone();

let (send, recv) = tcp_framed(stream, codec.clone());

// TODO: Using peer_addr here as the key is a little bit sketchy.
// It's possible that a client could send a message, disconnect, then another client connects from the same IP address (and the same src port), and then the response could be sent to that new client.
// This can be solved by using monotonically increasing IDs for each new client, but would break the similarity with the UDP versions of this function.
clients.borrow_mut().insert(peer_addr, send);

spawn_local({
let clients = clients.clone();
async move {
let mapped = recv.map(|x| Ok(x.map(|x| (x, peer_addr))));
let _ = tx_ingress.send_all(&mut pin!(mapped)).await;

clients.borrow_mut().remove(&peer_addr);
// Receive incoming messages.
msg_recv = peers_recv.next(), if !peers_recv.is_empty() => {
// If `peers_recv` is empty then `next()` will immediately return `None` which
// would cause the loop to spin.
let Some((peer_addr, payload_result)) = msg_recv else {
continue; // => `peers_recv.is_empty()`.
};
if let Err(err) = send_ingress.send(payload_result.map(|payload| (payload, peer_addr))).await {
tracing::error!("Error passing along received message: {:?}", err);
}
}
});
// Accept new clients.
new_peer = listener.accept() => {
let Ok((stream, _addr)) = new_peer else {
continue;
};
let Ok(peer_addr) = stream.peer_addr() else {
continue;
};
let (peer_send, peer_recv) = tcp_framed(stream, codec.clone());

// TODO: Using peer_addr here as the key is a little bit sketchy.
// It's possible that a peer could send a message, disconnect, then another peer connects from the
// same IP address (and the same src port), and then the response could be sent to that new client.
// This can be solved by using monotonically increasing IDs for each new peer, but would break the
// similarity with the UDP versions of this function.
peers_send.insert(peer_addr, peer_send);
peers_recv.insert(peer_addr, peer_recv);
}
}
}
});

Ok((tx_egress, rx_ingress, bound_endpoint))
Ok((send_egress, recv_ingres, bound_endpoint))
}

/// The inverse of [`bind_tcp`].
///
/// When messages enqueued into the returned sender, tcp sockets will be created and connected as
/// necessary to send out the requests. As the responses come back, they will be forwarded to the
/// returned receiver.
pub fn connect_tcp<T: 'static, Codec: 'static + Clone + Decoder + Encoder<T>>(
codec: Codec,
) -> (TcpFramedSink<T>, TcpFramedStream<Codec>) {
let (tx_egress, mut rx_egress) = unsync_channel(None);
let (tx_ingress, rx_ingress) = unsync_channel(None);
pub fn connect_tcp<Item, Codec>(codec: Codec) -> (TcpFramedSink<Item>, TcpFramedStream<Codec>)
where
Item: 'static,
Codec: 'static + Clone + Decoder + Encoder<Item>,
<Codec as Encoder<Item>>::Error: Debug,
{
let (send_egress, mut recv_egress) = unsync_channel(None);
let (send_ingres, recv_ingres) = unsync_channel(None);

spawn_local(async move {
let mut streams = HashMap::new();
let send_ingres = send_ingres;
// Map of `addr -> peers`, to send messages to.
let mut peers_send = HashMap::new();
// `StreamMap` of `addr -> peers`, to receive messages from. Automatically removes streams
// when they disconnect.
let mut peers_recv = StreamMap::new();

while let Some((payload, addr)) = rx_egress.next().await {
let stream = match streams.entry(addr) {
Occupied(entry) => entry.into_mut(),
Vacant(entry) => {
let socket = TcpSocket::new_v4().unwrap();
let stream = socket.connect(addr).await.unwrap();

let (send, recv) = tcp_framed(stream, codec.clone());

let mut tx_ingress = tx_ingress.clone();
spawn_local(async move {
let mapped = recv.map(|x| Ok(x.map(|x| (x, addr))));
let _ = tx_ingress.send_all(&mut pin!(mapped)).await;
});

entry.insert(send)
loop {
// Calling methods in a loop, futures must be cancel-safe.
select! {
// `biased` means the cases will be prioritized in the order they are listed.
// This is not strictly neccessary, but lets us do our internal work (send outgoing
// messages) before accepting more work (receiving more messages).
biased;
// Send outgoing messages.
msg_send = recv_egress.next() => {
let Some((payload, peer_addr)) = msg_send else {
// `None` if the send side has been dropped (no more send messages will ever come).
continue;
};

let stream = match peers_send.entry(peer_addr) {
Occupied(entry) => entry.into_mut(),
Vacant(entry) => {
let socket = TcpSocket::new_v4().unwrap();
let stream = socket.connect(peer_addr).await.unwrap();

let (peer_send, peer_recv) = tcp_framed(stream, codec.clone());

peers_recv.insert(peer_addr, peer_recv);
entry.insert(peer_send)
}
};

if let Err(err) = stream.send(payload).await {
tracing::error!("IO or codec error sending message to peer {}, disconnecting: {:?}", peer_addr, err);
peers_send.remove(&peer_addr); // `Drop` disconnects.
}
}
};

let _ = stream.send(payload).await;
// Receive incoming messages.
msg_recv = peers_recv.next(), if !peers_recv.is_empty() => {
// If `peers_recv` is empty then `next()` will immediately return `None` which
// would cause the loop to spin.
let Some((peer_addr, payload_result)) = msg_recv else {
continue; // => `peers_recv.is_empty()`.
};
if let Err(err) = send_ingres.send(payload_result.map(|payload| (payload, peer_addr))).await {
tracing::error!("Error passing along received message: {:?}", err);
}
}
}
}
});

(tx_egress, rx_ingress)
(send_egress, recv_ingres)
}

0 comments on commit b961233

Please sign in to comment.