From 508e0b835ac9e36c1a6e015acdd09925644168e5 Mon Sep 17 00:00:00 2001 From: Simon Oehrl Date: Wed, 11 Sep 2024 13:51:53 +0200 Subject: [PATCH] refactor: API & internals --- Cargo.toml | 2 +- examples/barrier.rs | 85 --- examples/barrier_group_coordinator.rs | 33 ++ examples/barrier_group_member.rs | 30 + examples/broadcast_group_receiver.rs | 33 ++ examples/broadcast_group_sender.rs | 64 +++ examples/broadcast_group_sender2.rs | 59 ++ examples/publish.rs | 30 - examples/subscribe.rs | 22 - rustfmt.toml | 1 + src/barrier.rs | 557 +++++++++++++++++++ src/broadcast.rs | 595 ++++++++++++++++++++ src/chunk.rs | 195 ++++--- src/chunk_socket.rs | 110 ++-- src/group.rs | 766 ++++++++++++++++++++++++++ src/lib.rs | 51 +- src/multiplex_socket.rs | 521 ++++++++++++++---- src/protocol.rs | 252 +++++++-- src/publisher.rs | 718 ------------------------ src/session.rs | 562 +++++++++++++++++++ src/subscriber.rs | 395 ------------- src/test.rs | 18 + src/utils.rs | 57 ++ 23 files changed, 3631 insertions(+), 1525 deletions(-) delete mode 100644 examples/barrier.rs create mode 100644 examples/barrier_group_coordinator.rs create mode 100644 examples/barrier_group_member.rs create mode 100644 examples/broadcast_group_receiver.rs create mode 100644 examples/broadcast_group_sender.rs create mode 100644 examples/broadcast_group_sender2.rs delete mode 100644 examples/publish.rs delete mode 100644 examples/subscribe.rs create mode 100644 rustfmt.toml create mode 100644 src/barrier.rs create mode 100644 src/broadcast.rs create mode 100644 src/group.rs delete mode 100644 src/publisher.rs create mode 100644 src/session.rs delete mode 100644 src/subscriber.rs create mode 100644 src/test.rs create mode 100644 src/utils.rs diff --git a/Cargo.toml b/Cargo.toml index 0fddb21..83ec614 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -7,7 +7,6 @@ edition = "2021" ahash = "0.8.11" crossbeam = { version = "0.8.4", features = ["crossbeam-channel"] } dashmap = "6.0.1" -env_logger = "0.11.5" log = "0.4.22" 
paste = "1.0.15" socket2 = "0.5.7" @@ -15,4 +14,5 @@ thiserror = "1.0.63" zerocopy = { version = "0.7.35", features = ["derive"] } [dev-dependencies] +env_logger = "0.11.5" rand = "0.8.5" diff --git a/examples/barrier.rs b/examples/barrier.rs deleted file mode 100644 index 211f588..0000000 --- a/examples/barrier.rs +++ /dev/null @@ -1,85 +0,0 @@ -use multicast::{ - publisher::{Publisher, PublisherConfig}, - subscriber::Subscriber, -}; -use rand::distributions::Distribution; - -fn client(name: &String) { - let mut s = Subscriber::connect("127.0.0.1:12345".parse().unwrap()).unwrap(); - - let mut barrier = s - .join_barrier_group(0) - .unwrap(); - log::info!("[{name}] joined barrier group"); - - let sleep_distribution = rand::distributions::Uniform::new( - std::time::Duration::from_millis(1), - std::time::Duration::from_millis(2), - ); - let mut rng = rand::thread_rng(); - - loop { - log::info!("[{name}] before"); - barrier.wait().unwrap(); - log::info!("[{name}] after"); - - // std::thread::sleep(sleep_distribution.sample(&mut rng)); - } -} - -fn server() { - let mut p = Publisher::new(PublisherConfig { - addr: "0.0.0.0:12345".parse().unwrap(), - multicast_addr: "224.0.0.0:5555".parse().unwrap(), - chunk_size: 1024, - }); - - let mut barrier = p - .create_barrier_group(multicast::publisher::BarrierGroupDesc { - retransmit_timeout: std::time::Duration::from_secs(1), - retransmit_count: 5, - }) - .unwrap(); - log::info!("[server] barrier group created"); - - let sleep_distribution = rand::distributions::Uniform::new( - std::time::Duration::from_millis(1), - std::time::Duration::from_millis(2), - ); - let mut rng = rand::thread_rng(); - - loop { - // std::thread::sleep(sleep_distribution.sample(&mut rng)); - - while let Ok(client) = barrier.try_accept() { - log::info!("new client: {client}"); - } - - log::info!("[server] before"); - barrier.wait(); - log::info!("[server] after"); - } -} - -fn main() { - env_logger::builder() - .filter_level(log::LevelFilter::Debug) - 
.init(); - - let args = std::env::args().collect::>(); - - if args.len() > 1 { - client(&args[1]); - } else { - let executable = std::env::current_exe().unwrap_or(format!("./{}", args[0]).into()); - - for client in ["alpha", "bravo", "charlie", "delta", "echo", "foxtrot"] { - std::process::Command::new(&executable) - .arg(client) - .spawn() - .unwrap(); - } - server(); - } -} - diff --git a/examples/barrier_group_coordinator.rs b/examples/barrier_group_coordinator.rs new file mode 100644 index 0000000..f1906eb --- /dev/null +++ b/examples/barrier_group_coordinator.rs @@ -0,0 +1,33 @@ +use std::{ + net::{IpAddr, Ipv4Addr, SocketAddr}, + time::Instant, +}; + +use multicast::session::{Coordinator, GroupId}; + +fn main() { + env_logger::init(); + + let mut args = std::env::args(); + let _ = args.next().unwrap(); + let bind_addr: SocketAddr = args.next().unwrap().parse().unwrap(); + let multicast_addr: SocketAddr = args.next().unwrap().parse().unwrap(); + let group_id: Option = args.next().map(|s| s.parse().unwrap()); + + let coordinator = Coordinator::start_session(bind_addr, multicast_addr).unwrap(); + + let vrm = coordinator.create_barrier_group(Some(0)).unwrap(); + let mut barrier_group_coordinator = coordinator.create_barrier_group(group_id).unwrap(); + barrier_group_coordinator.accept().unwrap(); + + for _ in 0..1000 { + barrier_group_coordinator.wait().unwrap(); + } + + let before = Instant::now(); + for _ in 0..1000 { + barrier_group_coordinator.wait().unwrap(); + } + let after = Instant::now(); + println!("1000 barriers took {:?}", after - before); +} diff --git a/examples/barrier_group_member.rs b/examples/barrier_group_member.rs new file mode 100644 index 0000000..cdb43f9 --- /dev/null +++ b/examples/barrier_group_member.rs @@ -0,0 +1,30 @@ +use std::{ + net::{IpAddr, Ipv4Addr, SocketAddr}, + time::Instant, +}; + +use multicast::session::{GroupId, Member}; + +fn main() { + env_logger::init(); + + let mut args = std::env::args(); + let _ = 
args.next().unwrap(); + let connect_addr: SocketAddr = args.next().unwrap().parse().unwrap(); + let group_id: GroupId = args.next().map(|s| s.parse().unwrap()).unwrap_or(0); + + let member = Member::join_session(connect_addr).unwrap(); + + let mut barrier_group_member = member.join_barrier_group(group_id).unwrap(); + + for _ in 0..1000 { + barrier_group_member.wait().unwrap(); + } + + let before = Instant::now(); + for _ in 0..1000 { + barrier_group_member.wait().unwrap(); + } + let after = Instant::now(); + println!("1000 barriers took {:?}", after - before); +} diff --git a/examples/broadcast_group_receiver.rs b/examples/broadcast_group_receiver.rs new file mode 100644 index 0000000..6e29d07 --- /dev/null +++ b/examples/broadcast_group_receiver.rs @@ -0,0 +1,33 @@ +use std::{ + io::Read, + net::{IpAddr, Ipv4Addr, SocketAddr}, + time::Instant, +}; + +use multicast::session::{GroupId, Member}; + +fn main() { + env_logger::init(); + + let mut args = std::env::args(); + let _ = args.next().unwrap(); + let connect_addr: SocketAddr = args.next().unwrap().parse().unwrap(); + let group_id: GroupId = args.next().map(|s| s.parse().unwrap()).unwrap_or(0); + + let member = Member::join_session(connect_addr).unwrap(); + + let mut receiver = member.join_broadcast_group(group_id).unwrap(); + + for _ in 0..1000 { + // let mut buf = String::new(); + receiver.recv().unwrap().read(); + } + + let before = Instant::now(); + for _ in 0..1000 { + // let mut buf = String::new(); + receiver.recv().unwrap().read(); + } + let after = Instant::now(); + println!("received in {:?}", after - before); +} diff --git a/examples/broadcast_group_sender.rs b/examples/broadcast_group_sender.rs new file mode 100644 index 0000000..d3f2a8f --- /dev/null +++ b/examples/broadcast_group_sender.rs @@ -0,0 +1,64 @@ +use std::{ + io::Write, + net::{IpAddr, Ipv4Addr, SocketAddr}, + time::Instant, +}; + +use multicast::session::{Coordinator, GroupId}; + +const LOREM_IPSUM: &[u8; 5992] = br" +Lorem ipsum dolor 
sit amet, consectetur adipiscing elit. Nullam posuere hendrerit sem, id malesuada nisl pulvinar et. Maecenas venenatis nisl at nibh faucibus, vitae tempor magna auctor. Vestibulum interdum mi diam, vel molestie justo condimentum eu. Pellentesque habitant morbi tristique senectus et netus et malesuada fames ac turpis egestas. Nam odio sem, gravida sed dignissim a, tempor dignissim magna. Duis a lacus a magna gravida fermentum sit amet vitae justo. Pellentesque gravida lacus eget ante ultrices dapibus. Curabitur et iaculis felis. Etiam mollis diam a est fermentum, id vestibulum arcu rutrum. Aenean lobortis fermentum dolor, a consectetur dui ultrices vitae. Suspendisse id ultrices diam. Praesent eget varius lorem. Sed dignissim, libero vel rhoncus cursus, nulla massa bibendum nulla, nec fermentum nisi risus bibendum magna. Integer quis mauris in odio vulputate dictum ut id lacus. + +Nam posuere quam metus, ac blandit odio feugiat eget. Etiam congue id risus eu sagittis. Etiam iaculis imperdiet odio, varius vestibulum dolor ullamcorper id. Praesent at ante vitae metus pretium bibendum quis eget orci. Duis quis velit luctus, pulvinar libero a, finibus augue. Integer elementum rhoncus urna consectetur finibus. Praesent vel urna nec mi feugiat rutrum. Vivamus at nunc metus. Morbi euismod mi condimentum ex venenatis luctus a ac velit. Donec rhoncus tortor nec augue feugiat feugiat. Donec et ultrices nibh, a molestie turpis. + +Proin a feugiat nisl, eu suscipit eros. Sed id elementum dolor. Sed sagittis enim ipsum, non elementum massa lobortis in. Curabitur et viverra risus. Maecenas ultrices tristique gravida. Quisque tempus, ex id fermentum fermentum, lacus lacus varius purus, a rhoncus augue diam eget metus. Nunc non convallis neque. Fusce ac semper metus. Sed blandit diam quis est porttitor, mattis pulvinar nulla malesuada. Nullam posuere nunc tincidunt imperdiet efficitur. Nunc laoreet maximus purus in luctus. 
Class aptent taciti sociosqu ad litora torquent per conubia nostra, per inceptos himenaeos. Sed a venenatis nisi. + +In dignissim vitae purus in mattis. In congue augue at elit molestie, in consectetur tellus maximus. In sodales et massa ac ultrices. Phasellus congue imperdiet arcu, vel cursus enim. Duis at justo tellus. Praesent non nisl sem. Duis sodales velit in felis faucibus, vel suscipit eros interdum. Aenean semper ante nec sapien condimentum luctus. Vivamus ac elit at quam maximus ultrices. Ut varius nisl sed ex posuere, nec tristique eros varius. + +Curabitur non dapibus est, sit amet dignissim velit. Mauris augue nisi, facilisis non tincidunt in, lacinia ut dui. Fusce fermentum ultrices orci, et commodo sapien congue quis. Pellentesque habitant morbi tristique senectus et netus et malesuada fames ac turpis egestas. Pellentesque sit amet ante velit. Etiam arcu nulla, varius id leo sit amet, bibendum mollis eros. Quisque blandit mattis tellus, a egestas nulla suscipit in. Nulla fringilla interdum suscipit. Duis ut nisi id neque efficitur lacinia. Sed iaculis eu ligula vitae dictum. + +Aliquam sit amet pellentesque lacus, sed mattis neque. Nulla elementum eros eget lorem efficitur scelerisque. In a nunc lectus. Vestibulum dui ante, fringilla quis ullamcorper quis, scelerisque non purus. Sed sit amet bibendum arcu, eu feugiat arcu. Nulla eu nisi a nisl lobortis porttitor nec sit amet risus. Nam et euismod nunc. Nullam vitae est eleifend, facilisis eros consequat, bibendum enim. Curabitur orci turpis, convallis ut purus ac, fringilla tempor massa. Pellentesque eu posuere felis, a auctor tellus. + +Praesent vel tempor enim, a semper eros. Donec tortor quam, pulvinar quis arcu ac, mollis blandit ipsum. Sed volutpat ante in mattis lobortis. Integer eget massa euismod, vulputate sapien ac, egestas est. Pellentesque habitant morbi tristique senectus et netus et malesuada fames ac turpis egestas. Integer congue nec enim at blandit. 
Nam ornare erat quam, in ultricies justo euismod non. Cras rhoncus sapien eget ligula tristique elementum. Nullam ullamcorper, odio non eleifend maximus, magna turpis convallis augue, vel pellentesque orci sem eu sapien. + +Vestibulum semper enim quis tincidunt fermentum. Nunc orci lectus, rutrum in urna vitae, commodo egestas urna. Aliquam erat volutpat. Mauris eu ante sed ex ultricies efficitur a ut nibh. Suspendisse potenti. Aenean nec odio vel dui vehicula suscipit id aliquam lectus. Praesent a metus maximus, pharetra felis id, varius nibh. Aliquam aliquet orci arcu, non interdum sem auctor sit amet. Praesent convallis nibh ut leo posuere, sed laoreet turpis convallis. + +Aliquam lobortis nunc sed orci posuere, ut tincidunt justo porttitor. Nam turpis turpis, varius in mi in, volutpat bibendum sem. Praesent lobortis nulla ut sollicitudin interdum. Praesent et arcu rutrum libero rutrum auctor vulputate sit amet turpis. Vestibulum ut lacinia massa. Maecenas at tempor massa. Aenean convallis nunc quis maximus malesuada. Praesent ac condimentum dui. Nullam quis iaculis turpis, ac pellentesque dui. Donec semper ligula in augue posuere convallis. Vestibulum vel velit scelerisque, facilisis justo at, tempor enim. Fusce eu ornare augue. + +Etiam non augue dapibus, pellentesque felis eu, efficitur leo. Praesent vehicula risus sed nunc aliquam lobortis. Sed eget efficitur augue. Donec mollis laoreet vehicula. Duis in finibus lorem. Duis sit amet risus sem. Vestibulum ante ipsum primis in faucibus orci luctus et ultrices posuere cubilia curae; Nulla consequat, enim vel dictum blandit, turpis turpis facilisis ligula, a porta nisi eros id quam. Donec eu finibus erat. Nullam tincidunt pellentesque orci, at placerat tortor laoreet vitae. Nulla facilisis suscipit est blandit venenatis. 
Vestibulum ante ipsum primis in faucibus orci luctus et ultrices posuere cubilia curae; Vestibulum ante ipsum primis in faucibus orci luctus et ultrices posuere cubilia curae; Vivamus id ipsum condimentum, ornare turpis id, convallis dolor. Quisque porta a diam quis sollicitudin. +"; + +fn main() { + env_logger::init(); + + let mut args = std::env::args(); + let _ = args.next().unwrap(); + let bind_addr: SocketAddr = args.next().unwrap().parse().unwrap(); + let multicast_addr: SocketAddr = args.next().unwrap().parse().unwrap(); + let group_id: Option = args.next().map(|s| s.parse().unwrap()); + + let coordinator = Coordinator::start_session(bind_addr, multicast_addr).unwrap(); + + let mut sender = coordinator.create_broadcast_group(group_id).unwrap(); + sender.accept().unwrap(); + + for _ in 0..1000 { + sender.write_message().write_all(LOREM_IPSUM).unwrap(); + } + sender.wait().unwrap(); + + let before = Instant::now(); + for _ in 0..1000 { + sender.write_message().write_all(LOREM_IPSUM).unwrap(); + } + sender.wait().unwrap(); + let after = Instant::now(); + println!( + "broadcasted {} bytes in {:?} ({:.3} MB/s)", + LOREM_IPSUM.len() * 1000, + after - before, + LOREM_IPSUM.len() as f64 * 1000.0 / (after - before).as_secs_f64() / 1000.0 / 1000.0 + ); + + std::thread::sleep(std::time::Duration::from_secs(10)); +} diff --git a/examples/broadcast_group_sender2.rs b/examples/broadcast_group_sender2.rs new file mode 100644 index 0000000..c030dd0 --- /dev/null +++ b/examples/broadcast_group_sender2.rs @@ -0,0 +1,59 @@ +use std::{ + io::Write, + net::{IpAddr, Ipv4Addr, SocketAddr}, + time::Instant, +}; + +use multicast::session::{Coordinator, GroupId}; + +const LOREM_IPSUM: &[u8; 5992] = br" +Lorem ipsum dolor sit amet, consectetur adipiscing elit. Nullam posuere hendrerit sem, id malesuada nisl pulvinar et. Maecenas venenatis nisl at nibh faucibus, vitae tempor magna auctor. Vestibulum interdum mi diam, vel molestie justo condimentum eu. 
Pellentesque habitant morbi tristique senectus et netus et malesuada fames ac turpis egestas. Nam odio sem, gravida sed dignissim a, tempor dignissim magna. Duis a lacus a magna gravida fermentum sit amet vitae justo. Pellentesque gravida lacus eget ante ultrices dapibus. Curabitur et iaculis felis. Etiam mollis diam a est fermentum, id vestibulum arcu rutrum. Aenean lobortis fermentum dolor, a consectetur dui ultrices vitae. Suspendisse id ultrices diam. Praesent eget varius lorem. Sed dignissim, libero vel rhoncus cursus, nulla massa bibendum nulla, nec fermentum nisi risus bibendum magna. Integer quis mauris in odio vulputate dictum ut id lacus. + +Nam posuere quam metus, ac blandit odio feugiat eget. Etiam congue id risus eu sagittis. Etiam iaculis imperdiet odio, varius vestibulum dolor ullamcorper id. Praesent at ante vitae metus pretium bibendum quis eget orci. Duis quis velit luctus, pulvinar libero a, finibus augue. Integer elementum rhoncus urna consectetur finibus. Praesent vel urna nec mi feugiat rutrum. Vivamus at nunc metus. Morbi euismod mi condimentum ex venenatis luctus a ac velit. Donec rhoncus tortor nec augue feugiat feugiat. Donec et ultrices nibh, a molestie turpis. + +Proin a feugiat nisl, eu suscipit eros. Sed id elementum dolor. Sed sagittis enim ipsum, non elementum massa lobortis in. Curabitur et viverra risus. Maecenas ultrices tristique gravida. Quisque tempus, ex id fermentum fermentum, lacus lacus varius purus, a rhoncus augue diam eget metus. Nunc non convallis neque. Fusce ac semper metus. Sed blandit diam quis est porttitor, mattis pulvinar nulla malesuada. Nullam posuere nunc tincidunt imperdiet efficitur. Nunc laoreet maximus purus in luctus. Class aptent taciti sociosqu ad litora torquent per conubia nostra, per inceptos himenaeos. Sed a venenatis nisi. + +In dignissim vitae purus in mattis. In congue augue at elit molestie, in consectetur tellus maximus. In sodales et massa ac ultrices. 
Phasellus congue imperdiet arcu, vel cursus enim. Duis at justo tellus. Praesent non nisl sem. Duis sodales velit in felis faucibus, vel suscipit eros interdum. Aenean semper ante nec sapien condimentum luctus. Vivamus ac elit at quam maximus ultrices. Ut varius nisl sed ex posuere, nec tristique eros varius. + +Curabitur non dapibus est, sit amet dignissim velit. Mauris augue nisi, facilisis non tincidunt in, lacinia ut dui. Fusce fermentum ultrices orci, et commodo sapien congue quis. Pellentesque habitant morbi tristique senectus et netus et malesuada fames ac turpis egestas. Pellentesque sit amet ante velit. Etiam arcu nulla, varius id leo sit amet, bibendum mollis eros. Quisque blandit mattis tellus, a egestas nulla suscipit in. Nulla fringilla interdum suscipit. Duis ut nisi id neque efficitur lacinia. Sed iaculis eu ligula vitae dictum. + +Aliquam sit amet pellentesque lacus, sed mattis neque. Nulla elementum eros eget lorem efficitur scelerisque. In a nunc lectus. Vestibulum dui ante, fringilla quis ullamcorper quis, scelerisque non purus. Sed sit amet bibendum arcu, eu feugiat arcu. Nulla eu nisi a nisl lobortis porttitor nec sit amet risus. Nam et euismod nunc. Nullam vitae est eleifend, facilisis eros consequat, bibendum enim. Curabitur orci turpis, convallis ut purus ac, fringilla tempor massa. Pellentesque eu posuere felis, a auctor tellus. + +Praesent vel tempor enim, a semper eros. Donec tortor quam, pulvinar quis arcu ac, mollis blandit ipsum. Sed volutpat ante in mattis lobortis. Integer eget massa euismod, vulputate sapien ac, egestas est. Pellentesque habitant morbi tristique senectus et netus et malesuada fames ac turpis egestas. Integer congue nec enim at blandit. Nam ornare erat quam, in ultricies justo euismod non. Cras rhoncus sapien eget ligula tristique elementum. Nullam ullamcorper, odio non eleifend maximus, magna turpis convallis augue, vel pellentesque orci sem eu sapien. + +Vestibulum semper enim quis tincidunt fermentum. 
Nunc orci lectus, rutrum in urna vitae, commodo egestas urna. Aliquam erat volutpat. Mauris eu ante sed ex ultricies efficitur a ut nibh. Suspendisse potenti. Aenean nec odio vel dui vehicula suscipit id aliquam lectus. Praesent a metus maximus, pharetra felis id, varius nibh. Aliquam aliquet orci arcu, non interdum sem auctor sit amet. Praesent convallis nibh ut leo posuere, sed laoreet turpis convallis. + +Aliquam lobortis nunc sed orci posuere, ut tincidunt justo porttitor. Nam turpis turpis, varius in mi in, volutpat bibendum sem. Praesent lobortis nulla ut sollicitudin interdum. Praesent et arcu rutrum libero rutrum auctor vulputate sit amet turpis. Vestibulum ut lacinia massa. Maecenas at tempor massa. Aenean convallis nunc quis maximus malesuada. Praesent ac condimentum dui. Nullam quis iaculis turpis, ac pellentesque dui. Donec semper ligula in augue posuere convallis. Vestibulum vel velit scelerisque, facilisis justo at, tempor enim. Fusce eu ornare augue. + +Etiam non augue dapibus, pellentesque felis eu, efficitur leo. Praesent vehicula risus sed nunc aliquam lobortis. Sed eget efficitur augue. Donec mollis laoreet vehicula. Duis in finibus lorem. Duis sit amet risus sem. Vestibulum ante ipsum primis in faucibus orci luctus et ultrices posuere cubilia curae; Nulla consequat, enim vel dictum blandit, turpis turpis facilisis ligula, a porta nisi eros id quam. Donec eu finibus erat. Nullam tincidunt pellentesque orci, at placerat tortor laoreet vitae. Nulla facilisis suscipit est blandit venenatis. Vestibulum ante ipsum primis in faucibus orci luctus et ultrices posuere cubilia curae; Vestibulum ante ipsum primis in faucibus orci luctus et ultrices posuere cubilia curae; Vivamus id ipsum condimentum, ornare turpis id, convallis dolor. Quisque porta a diam quis sollicitudin. 
+"; + +fn main() { + env_logger::init(); + + let mut args = std::env::args(); + let _ = args.next().unwrap(); + let bind_addr: SocketAddr = args.next().unwrap().parse().unwrap(); + let multicast_addr: SocketAddr = args.next().unwrap().parse().unwrap(); + let group_id: Option = args.next().map(|s| s.parse().unwrap()); + + let coordinator = Coordinator::start_session(bind_addr, multicast_addr).unwrap(); + + let mut sender = coordinator.create_broadcast_group(group_id).unwrap(); + + loop { + if let Some(addr) = sender.try_accept().unwrap() { + println!("accepted connection from {}", addr); + } + // while let Some(addr) = sender.try_accept().unwrap() { + // println!("accepted connection from {}", addr); + // } + + if sender.has_members() { + println!("broadcasting to members"); + sender.write_message().write_all(LOREM_IPSUM).unwrap(); + } else { + println!("no members to broadcast to"); + } + } +} diff --git a/examples/publish.rs b/examples/publish.rs deleted file mode 100644 index 96ab231..0000000 --- a/examples/publish.rs +++ /dev/null @@ -1,30 +0,0 @@ -use std::io::Write; - -use subscriptions::publisher::{Publisher, PublisherConfig}; - -fn main() { - env_logger::builder() - .filter_level(log::LevelFilter::Debug) - .init(); - - let publisher = Publisher::new(PublisherConfig { - addr: "0.0.0.0:12345".parse().unwrap(), - multicast_addr: "224.0.0.0:5555".parse().unwrap(), - chunk_size: 1024, - }); - log::info!("publisher created"); - - let mut offer = publisher.create_offer().unwrap(); - log::info!("created offer with id {}", offer.id()); - - loop { - if let Some(addr) = offer.accept() { - log::info!("{:?} subscribed to offer {}", addr, offer.id()); - offer.write_message().write_all(b"welcome").unwrap(); - } - - if offer.has_subscribers() { - offer.write_message().write_all(b"nice that you are still here").unwrap(); - } - } -} diff --git a/examples/subscribe.rs b/examples/subscribe.rs deleted file mode 100644 index ea5db33..0000000 --- a/examples/subscribe.rs +++ 
/dev/null @@ -1,22 +0,0 @@ -use std::io::Read; - -use subscriptions::subscriber::Subscriber; - -fn main() { - env_logger::builder() - .filter_level(log::LevelFilter::Debug) - .init(); - - let mut subscriber = Subscriber::connect("127.0.0.1:12345".parse().unwrap()).unwrap(); - log::info!("subscriber connected"); - - let mut subscription = subscriber.subscribe(0).unwrap(); - log::info!("subscribed to channel 0"); - - while let Ok(msg) = subscription.recv() { - let mut msg = msg.read(); - let mut s = String::new(); - msg.read_to_string(&mut s).unwrap(); - log::info!("received: {}", s); - } -} diff --git a/rustfmt.toml b/rustfmt.toml new file mode 100644 index 0000000..b2715b2 --- /dev/null +++ b/rustfmt.toml @@ -0,0 +1 @@ +wrap_comments = true diff --git a/src/barrier.rs b/src/barrier.rs new file mode 100644 index 0000000..0cc1d7c --- /dev/null +++ b/src/barrier.rs @@ -0,0 +1,557 @@ +use std::net::SocketAddr; + +use ahash::HashSet; +use socket2::SockAddr; + +use crate::{ + chunk::Chunk, + group::{ + GroupCoordinator, GroupCoordinatorState, GroupCoordinatorTypeImpl, GroupMember, + GroupMemberState, GroupMemberTypeImpl, + }, + protocol::{self, BarrierReleased, SequenceNumber}, + utils::{display_addr, sock_addr_to_socket_addr, ExponentialBackoff}, +}; + +#[derive(Debug, Default)] +pub(crate) struct BarrierGroupCoordinatorState { + arrived: HashSet, + ack_required: HashSet, +} + +impl BarrierGroupCoordinatorState { + fn swap_arrived_and_ack_required(&mut self) { + std::mem::swap(&mut self.arrived, &mut self.ack_required); + } +} + +impl GroupCoordinatorTypeImpl for BarrierGroupCoordinatorState { + const GROUP_TYPE: protocol::GroupType = protocol::GROUP_TYPE_BARRIER; + + fn process_join_cancelled(&mut self, addr: &SockAddr, _: &GroupCoordinatorState) { + self.arrived.remove(addr); + } + + fn process_member_disconnected(&mut self, addr: &SockAddr, _: &GroupCoordinatorState) { + self.arrived.remove(addr); + self.ack_required.remove(addr); + } + + fn process_chunk(&mut self, 
chunk: Chunk, addr: &SockAddr, group: &GroupCoordinatorState) { + match chunk { + Chunk::SessionJoin(_) + | Chunk::SessionWelcome(_) + | Chunk::SessionHeartbeat(_) + | Chunk::GroupJoin(_) + | Chunk::GroupWelcome(_) + | Chunk::GroupLeave(_) + | Chunk::GroupDisconnected(_) => { + unreachable!(); + } + Chunk::BroadcastMessage(_) + | Chunk::BroadcastFirstMessageFragment(_) + | Chunk::BroadcastMessageFragment(_) + | Chunk::BroadcastFinalMessageFragment(_) => { + log::trace!( + "IGNORED: broadcast message from {} received in barrier group", + display_addr(addr) + ); + } + Chunk::GroupAck(ack) => { + if ack.seq == group.seq.prev() { + if group.members.contains(addr) || group.member_requests.contains(addr) { + self.ack_required.remove(addr); + + log::trace!( + "received ack from {} with seq {}", + display_addr(addr), + ack.seq, + ); + } else { + log::trace!( + "IGNORED: received ack from non-member {} with seq {}", + display_addr(addr), + ack.seq, + ); + } + } else { + log::trace!( + "IGNORED: received ack from {} with unexpected seq {} (expected seq: {})", + display_addr(addr), + ack.seq, + group.seq.prev(), + ); + } + } + Chunk::BarrierReached(reached) => { + if reached.seq == group.seq { + if group.members.contains(addr) || group.member_requests.contains(addr) { + self.arrived.insert(addr.clone()); + + log::trace!( + "received barrier reached from {} with seq {}", + display_addr(addr), + reached.seq, + ); + } else { + log::trace!( + "IGNORED: received barrier reached from non-member {} with seq {}", + display_addr(addr), + reached.seq, + ); + } + } else { + log::trace!( + "IGNORED: received barrier reached from {} with non-current seq {} (current seq: {})", + display_addr(addr), + reached.seq, + group.seq, + ); + } + } + Chunk::BarrierReleased(released) => log::trace!( + "IGNORED: received barrier released from {} with seq {}", + display_addr(addr), + released.seq, + ), + } + } + + fn process_join_request(&mut self, _addr: &SockAddr, _group: &GroupCoordinatorState) {} 
+} + +pub struct BarrierGroupCoordinator { + pub(crate) group: GroupCoordinator<BarrierGroupCoordinatorState>, + // pub(crate) state: BarrierGroupCoordinatorState, +} + +impl BarrierGroupCoordinator { + // pub fn new(channel: GroupCoordinator) -> Self { + // Self { + // channel, + // arrived: HashSet::new(), + // } + // } + + fn all_members_arrived(&self) -> bool { + // Assert that arrived.len() == members.len() => arrived == members + debug_assert!( + self.group.state.members.len() != self.group.inner.arrived.len() + || self.group.state.members == self.group.inner.arrived + ); + self.group.state.members.len() == self.group.inner.arrived.len() + } + + pub fn has_members(&self) -> bool { + !self.group.state.members.is_empty() + } + + pub fn accept(&mut self) -> std::io::Result<SocketAddr> { + self.group.accept().and_then(sock_addr_to_socket_addr) + } + + pub fn try_accept(&mut self) -> std::io::Result<Option<SocketAddr>> { + let addr = self.group.try_accept()?; + if let Some(addr) = addr { + Ok(Some(sock_addr_to_socket_addr(addr)?)) + } else { + Ok(None) + } + } + + pub fn wait(&mut self) -> std::io::Result<()> { + if self.group.state.members.is_empty() { + log::trace!("wait: no members in group"); + return Ok(()); + } + + // Wait until everyone has arrived + log::trace!("waiting for all members to arrive"); + if !self.all_members_arrived() { + self.group.recv()?; + } + log::trace!("all members have arrived"); + + let release_seq = self.group.state.seq; + self.group.state.seq = self.group.state.seq.next(); + + debug_assert!(self.group.inner.ack_required.is_empty()); + self.group.inner.swap_arrived_and_ack_required(); + + // Send barrier released and wait for acks + for deadline in ExponentialBackoff::new() { + log::trace!("sending barrier released with seq {}", release_seq); + self.group.send_chunk_to_group(&BarrierReleased { + seq: release_seq, + group_id: self.group.channel.id(), + })?; + log::trace!("waiting for acks"); + + loop { + match self.group.recv_until(deadline) { + Ok(_) => { + if 
self.group.inner.ack_required.is_empty() { + log::trace!("all members have acknowledged the barrier release"); + return Ok(()); + } else { + log::trace!( + "still waiting for acks from {:?}", + self.group.inner.ack_required.iter().map(display_addr), + ); + } + } + Err(err) if err.kind() == std::io::ErrorKind::TimedOut => break, + Err(err) => return Err(err), + } + } + + let mut members_to_remove = HashSet::default(); + log::trace!("still wating for:"); + for addr in &self.group.inner.ack_required { + if !self.group.session_members.is_alive(addr) { + members_to_remove.insert(addr.clone()); + log::trace!("{}: dead", display_addr(addr)); + } else { + log::trace!("{}: still alive", display_addr(addr)); + } + } + for addr in members_to_remove { + self.group.remove(&addr)?; + } + } + unreachable!(); + } +} + +#[derive(Debug, Default)] +pub(crate) struct BarrierGroupMemberState { + next: protocol::SequenceNumber, + released: protocol::SequenceNumber, +} + +impl BarrierGroupMemberState {} + +impl GroupMemberTypeImpl for BarrierGroupMemberState { + const GROUP_TYPE: protocol::GroupType = protocol::GROUP_TYPE_BARRIER; + + fn process_group_join(&mut self, seq: SequenceNumber, _group: &GroupMemberState) { + self.next = seq; + self.released = seq.prev(); + } + + fn process_chunk(&mut self, chunk: Chunk, addr: &SockAddr, _group: &GroupMemberState) -> bool { + match chunk { + Chunk::SessionJoin(_) + | Chunk::SessionWelcome(_) + | Chunk::SessionHeartbeat(_) + | Chunk::GroupJoin(_) + | Chunk::GroupWelcome(_) + | Chunk::GroupLeave(_) + | Chunk::GroupDisconnected(_) => { + unreachable!(); + } + Chunk::BroadcastMessage(_) + | Chunk::BroadcastFirstMessageFragment(_) + | Chunk::BroadcastMessageFragment(_) + | Chunk::BroadcastFinalMessageFragment(_) => { + log::trace!( + "IGNORED: broadcast message from {} received in barrier group", + display_addr(addr) + ); + } + Chunk::BarrierReached(reached) => log::trace!( + "IGNORED: received barrier released from {} with seq {}", + 
display_addr(addr), + reached.seq, + ), + Chunk::GroupAck(ack) => { + if ack.seq == self.next { + log::trace!( + "received ack from {} with seq {}", + display_addr(addr), + ack.seq, + ); + self.next = self.next.next(); + } else { + log::trace!( + "IGNORED: received ack from {} with seq {} (expected seq: {})", + display_addr(addr), + ack.seq, + self.next, + ); + } + } + Chunk::BarrierReleased(released) => { + if released.seq == self.released.next() { + log::trace!( + "received barrier released from {} with seq {}", + display_addr(addr), + released.seq, + ); + self.released = released.seq; + } else { + log::trace!( + "IGNORED: received barrier released from {} with unexpected seq {} (expected seq {})", + display_addr(addr), + released.seq, + self.released.next(), + ); + } + if self.next == self.released { + log::trace!("reached ack from server got lost for seq {}", self.next); + self.next = self.next.next(); + } + } + } + + false + } +} + +pub struct BarrierGroupMember { + pub(crate) group: GroupMember<BarrierGroupMemberState>, +} + +impl BarrierGroupMember { + fn send_reached(&mut self) -> std::io::Result<SequenceNumber> { + let reached = self.group.inner()?.next; + + for deadline in ExponentialBackoff::new() { + self.group.send_chunk(&protocol::BarrierReached { + seq: reached, + group_id: self.group.id()?, + })?; + + // return Ok(reached); + loop { + match self.group.recv_until(deadline) { + Ok(_) => { + if self.group.inner()?.next == reached.next() { + return Ok(reached); + } + } + Err(err) if err.kind() == std::io::ErrorKind::TimedOut => break, + Err(err) => return Err(err), + } + } + } + + unreachable!(); + } + + pub fn wait(&mut self) -> std::io::Result<()> { + let reached = self.send_reached()?; + + while self.group.inner()?.released != reached { + self.group.recv()?; + } + + Ok(()) + } +} + +#[cfg(test)] +mod test { + use crate::session::{Coordinator, Member}; + use crate::test::*; + use std::{ + net::{IpAddr, Ipv4Addr, SocketAddr}, + thread, + }; + + #[test] + fn test_barrier_group() -> Result<()> 
{ + init_logger(); + + let port = crate::test::get_port(); + let bind_addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), port); + let connect_addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), port); + let multicast_addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(234, 0, 0, 0)), port); + + let coordinator = Coordinator::start_session(bind_addr, multicast_addr)?; + let member = Member::join_session(connect_addr)?; + + thread::scope(|s| { + s.spawn(|| { + let mut barrier_group_coordinator = + coordinator.create_barrier_group(Some(0)).unwrap(); + barrier_group_coordinator.accept().unwrap(); + + for _ in 0..10 { + barrier_group_coordinator.wait().unwrap(); + } + }); + + s.spawn(|| { + let mut barrier_group_member = member.join_barrier_group(0).unwrap(); + + for _ in 0..10 { + barrier_group_member.wait().unwrap(); + } + }); + }); + + Ok(()) + } +} + +// #[derive(Debug, Default)] +// struct BarrierGroupState { +// new_clients: HashSet, +// clients: HashSet, +// arrived: HashSet, +// seq: SequenceNumber, +// } + +// impl BarrierGroupState { +// /// Returns true if the client is connected to the barrier group and +// false otherwise. fn client_reached_barrier(&mut self, client: SocketAddr, +// seq: SequenceNumber) -> bool { if self.clients.contains(&client) { +// if self.seq == seq { +// self.arrived.insert(client); +// } +// true +// } else { +// false +// } +// } + +// /// Returns if all remotes have arrived at the barrier. 
+// fn all_remotes_arrived(&self) -> bool { +// } + +// /// Processes a single chunk +// fn process_chunk(&mut self, chunk: Chunk, addr: SocketAddr) -> bool { +// match chunk { +// Chunk::JoinBarrierGroup(_) => { +// self.new_clients.insert(addr); +// true +// } +// Chunk::BarrierReached(reached) => { +// if self.clients.contains(&addr) { +// self.client_reached_barrier(addr, reached.0.seq.into()); +// true +// } else { +// log::warn!("received barrier reached from non-client"); +// false +// } +// } +// Chunk::LeaveChannel(_) => { +// if self.clients.contains(&addr) { +// self.clients.remove(&addr); +// self.arrived.remove(&addr); +// } +// false +// } +// _ => { +// log::warn!("received invalid chunk: {chunk:?}"); +// self.clients.contains(&addr) +// } +// } +// } +// } + +// pub struct BarrierGroup { +// channel_id: ChannelId, +// desc: BarrierGroupDesc, +// state: BarrierGroupState, +// receiver: ChunkReceiver, +// socket: ChunkSocket, +// multicast_addr: SocketAddr, +// } + +// impl BarrierGroup { +// fn try_process(&mut self) -> bool { +// let mut processed = false; +// while let Ok(chunk) = self.receiver.try_recv() { +// if let (Ok(chunk), Some(addr)) = (chunk.validate(), +// chunk.addr().as_socket()) { if +// !self.state.process_chunk(chunk, addr) { let _ = self +// .socket +// +// .send_chunk_to(&ChannelDisconnected(self.channel_id.into()), &addr.into()); +// } +// processed = true; +// } +// } +// processed +// } + +// fn process(&mut self) { +// if let Ok(chunk) = self.receiver.recv() { +// if let (Ok(chunk), Some(addr)) = (chunk.validate(), +// chunk.addr().as_socket()) { if +// !self.state.process_chunk(chunk, addr) { let _ = self +// .socket +// +// .send_chunk_to(&ChannelDisconnected(self.channel_id.into()), &addr.into()); +// } +// } +// } +// self.try_process(); +// } + +// pub fn accept_client(&mut self, client: SocketAddr) -> Result<(), +// TransmitAndWaitError> { transmit_to_and_wait( +// &self.socket, +// &client, +// &ConfirmJoinChannel { +// 
header: ChannelHeader { +// channel_id: self.channel_id.into(), +// seq: self.state.seq.into(), +// }, +// }, +// self.desc.retransmit_timeout, +// self.desc.retransmit_count, +// &self.receiver, +// |chunk, addr| { +// if let Chunk::Ack(ack) = chunk { +// let ack_seq: u16 = ack.header.seq.into(); +// if ack_seq == self.state.seq && addr == client { +// log::debug!("client {} joined barrier group", +// client); self.state.clients.insert(addr); +// return true; +// } +// } else { +// self.state.process_chunk(chunk, addr); +// } +// false +// }, +// ) +// } + +// pub fn try_accept(&mut self) -> Result +// { self.try_process(); + +// if let Some(client) = self +// .state +// .new_clients +// .iter() +// .next() +// .copied() +// .and_then(|q| self.state.new_clients.take(&q)) +// { +// log::debug!("accepting client {}", client); +// self.accept_client(client)?; +// Ok(client) +// } else { +// Err(TransmitAndWaitError::RecvError(RecvTimeoutError::Timeout)) +// } +// } + +// pub fn has_remotes(&self) -> bool { +// !self.state.clients.is_empty() +// } + +// pub fn try_wait(&mut self) -> bool { +// self.try_process(); +// if self.state.all_remotes_arrived() { +// self.wait(); +// true +// } else { +// false +// } +// } + +// pub fn wait(&mut self) { +// } diff --git a/src/broadcast.rs b/src/broadcast.rs new file mode 100644 index 0000000..79604a0 --- /dev/null +++ b/src/broadcast.rs @@ -0,0 +1,595 @@ +use std::{ + collections::VecDeque, + io::{Read, Write}, + mem::ManuallyDrop, + net::SocketAddr, +}; + +use ahash::HashSet; +use socket2::SockAddr; +use zerocopy::FromBytes; + +use crate::{ + chunk::{Chunk, ChunkBuffer}, chunk_socket::ReceivedChunk, group::{ + GroupCoordinator, GroupCoordinatorState, GroupCoordinatorTypeImpl, GroupMember, + GroupMemberState, GroupMemberTypeImpl, + }, protocol::{ + self, BroadcastFinalMessageFragment, BroadcastFirstMessageFragment, BroadcastMessage, BroadcastMessageFragment, SequenceNumber, MESSAGE_PAYLOAD_OFFSET + }, session::GroupId, 
utils::{display_addr, sock_addr_to_socket_addr} +}; + +/// A chunk that has been sent but not yet acknowledged by all subscribers. +#[derive(Debug)] +struct UnacknowledgedChunk { + retransmit_time: std::time::Instant, + buffer: ChunkBuffer, + packet_size: usize, + missing_acks: HashSet, + retransmit_count: u32, +} + +#[derive(Debug, Default)] +pub(crate) struct BroadcastGroupSenderState { + // The first sequence number not acknowledged by all members. + seq_not_ack: SequenceNumber, + + // Chunks, that have been sent but not yet acknowledged by all members. + // + // The first element corresponds to self.seq_not_ack and the last element corresponds to + // self.seq_sent. + unacknowledged_chunks: VecDeque>, + + // The number of packets that are currently in flight. + packets_in_flight: usize, +} + +impl BroadcastGroupSenderState { + fn remove_addr(&mut self, addr: &SockAddr) { + for chunk in &mut self.unacknowledged_chunks { + if let Some(chunk) = chunk { + chunk.missing_acks.remove(addr); + } + } + } +} + +impl GroupCoordinatorTypeImpl for BroadcastGroupSenderState { + const GROUP_TYPE: protocol::GroupType = protocol::GROUP_TYPE_BARRIER; + + fn process_join_cancelled(&mut self, addr: &SockAddr, _: &GroupCoordinatorState) { + self.remove_addr(addr); + } + + fn process_member_disconnected(&mut self, addr: &SockAddr, _: &GroupCoordinatorState) { + self.remove_addr(addr); + } + + fn process_chunk(&mut self, chunk: Chunk, addr: &SockAddr, _: &GroupCoordinatorState) { + match chunk { + Chunk::SessionJoin(_) + | Chunk::SessionWelcome(_) + | Chunk::SessionHeartbeat(_) + | Chunk::GroupJoin(_) + | Chunk::GroupWelcome(_) + | Chunk::GroupLeave(_) + | Chunk::GroupDisconnected(_) => { + unreachable!(); + } + Chunk::BarrierReached(_) + | Chunk::BarrierReleased(_) + | Chunk::BroadcastMessage(_) + | Chunk::BroadcastFirstMessageFragment(_) + | Chunk::BroadcastMessageFragment(_) + | Chunk::BroadcastFinalMessageFragment(_) => { + log::trace!("IGNORED: {:?}", chunk); + } + 
Chunk::GroupAck(ack) => { + let offset = ack.seq - self.seq_not_ack; + if offset < self.unacknowledged_chunks.len() { + // TODO: check if this is actually the correct chunk: checksum? + if let Some(chunk) = &mut self.unacknowledged_chunks[offset as usize] { + if chunk.missing_acks.remove(addr) { + log::trace!( + "received ack for seq {} from {}", + ack.seq, + display_addr(addr) + ); + } else { + log::trace!( + "IGNORED: ack for seq {} from {} (already acked)", + ack.seq, + display_addr(addr) + ); + } + } else { + log::trace!( + "IGNORED: ack for seq {} from {} (already acked)", + ack.seq, + display_addr(addr) + ); + } + } else { + log::trace!( + "IGNORED: ack for seq {} from {} (seq_ack: {})", + ack.seq, + display_addr(addr), + self.seq_not_ack + ); + } + } + } + } +} + +pub struct BroadcastGroupSender { + pub(crate) group: GroupCoordinator, + pub(crate) initial_retransmit_delay: std::time::Duration, + pub(crate) max_retransmit_delay: std::time::Duration, + pub(crate) max_packets_in_flight: usize, +} + +impl BroadcastGroupSender { + pub(crate) fn new(group: GroupCoordinator) -> Self { + Self { + group, + initial_retransmit_delay: std::time::Duration::from_secs(1), + max_retransmit_delay: std::time::Duration::from_secs(10), + max_packets_in_flight: 10, + } + } + + fn process_unacknlowedged_chunks(&mut self) -> std::io::Result<()> { + let inner = &mut self.group.inner; + for maybe_chunk in &mut inner.unacknowledged_chunks { + if let Some(chunk) = maybe_chunk { + if chunk.missing_acks.is_empty() { + log::trace!("all members have acknowledged seq {}", inner.seq_not_ack); + // TODO: use take_if when it becomes stable + maybe_chunk.take(); + inner.packets_in_flight -= 1; + } else { + if chunk.retransmit_time < std::time::Instant::now() { + log::trace!("retransmitting packet"); + self.group.channel.send_chunk_buffer_to( + &chunk.buffer, + chunk.packet_size, + &self.group.multicast_addr, + )?; + + let retransmit_delay = + self.initial_retransmit_delay * 
2u32.pow(chunk.retransmit_count); + let retransmit_delay = + std::cmp::min(retransmit_delay, self.max_retransmit_delay); + chunk.retransmit_time = std::time::Instant::now() + retransmit_delay; + } + } + } + } + + // Remove acknowledged chunks from the front + while let Some(None) = inner.unacknowledged_chunks.front() { + log::debug!("removing acknowledged chunk"); + inner.unacknowledged_chunks.pop_front(); + inner.seq_not_ack = inner.seq_not_ack.next(); + } + + Ok(()) + } + + pub fn id(&self) -> GroupId { + self.group.id().into() + } + + pub fn has_members(&self) -> bool { + !self.group.state.members.is_empty() + } + + /// Processes incoming messages and sends retransmissions if necessary. + /// + /// This method must be called periodically to ensure that messages are sent + /// and received. It is also called by the `accept` and `try_accept` + /// methods and when writing a message. + pub fn process(&mut self) -> std::io::Result<()> { + self.group.recv()?; + self.process_unacknlowedged_chunks() + } + + /// Processes incoming messages and sends retransmissions if necessary. + /// + /// This method must be called periodically to ensure that messages are sent + /// and received. It is also called by the `accept` and `try_accept` + /// methods and when writing a message. 
+ pub fn try_process(&mut self) -> std::io::Result<()> { + self.group.try_recv()?; + self.process_unacknlowedged_chunks() + } + + pub fn accept(&mut self) -> std::io::Result { + self.process()?; + self.group.accept().and_then(sock_addr_to_socket_addr) + } + + pub fn try_accept(&mut self) -> std::io::Result> { + self.process_unacknlowedged_chunks()?; + + let addr = self.group.try_accept()?; + if let Some(addr) = addr { + let addr = sock_addr_to_socket_addr(addr)?; + Ok(Some(addr)) + } else { + Ok(None) + } + } + + fn send_message_buffer( + &mut self, + mut buffer: ChunkBuffer, + fragment_index: usize, + last: bool, + packet_size: usize, + ) -> std::io::Result<()> { + buffer.init(&BroadcastMessage { + seq: self.group.state.seq, + group_id: self.group.id(), + }); + + log::trace!("sending broadcast message {fragment_index} {last}"); + + match (fragment_index, last) { + (0, true) => {} + (0, false) => { + buffer[0] = protocol::CHUNK_ID_BROADCAST_FIRST_MESSAGE_FRAGMENT; + } + (_, false) => { + buffer[0] = protocol::CHUNK_ID_BROADCAST_MESSAGE_FRAGMENT; + } + (_, true) => { + buffer[0] = protocol::CHUNK_ID_BROADCAST_FINAL_MESSAGE_FRAGMENT; + } + } + + while self.group.inner.packets_in_flight >= self.max_packets_in_flight { + log::trace!( + "too many packets ({}) in flight (max: {})", + self.group.inner.packets_in_flight, + self.max_packets_in_flight + ); + self.process()?; + } + + self.group + .send_chunk_buffer_to_group(&buffer, packet_size)?; + self.group + .inner + .unacknowledged_chunks + .push_back(Some(UnacknowledgedChunk { + retransmit_time: std::time::Instant::now() + self.initial_retransmit_delay, + buffer, + packet_size, + missing_acks: self.group.state.members.clone(), + retransmit_count: 0, + })); + self.group.inner.packets_in_flight += 1; + self.group.state.seq = self.group.state.seq.next(); + + Ok(()) + } + + pub fn write_message(&mut self) -> MessageWriter { + MessageWriter { + buffer: ManuallyDrop::new(self.group.buffer_allocator().allocate()), + sender: 
self, + fragment_count: 0, + cursor: MESSAGE_PAYLOAD_OFFSET, + } + } + + pub fn wait(&mut self) -> Result<(), std::io::Error> { + while self.group.inner.packets_in_flight > 0 { + self.process()?; + } + Ok(()) + } +} + +pub struct MessageWriter<'a> { + sender: &'a mut BroadcastGroupSender, + buffer: ManuallyDrop, + fragment_count: usize, + cursor: usize, +} + +impl Drop for MessageWriter<'_> { + fn drop(&mut self) { + log::trace!("sending"); + let buffer = unsafe { ManuallyDrop::take(&mut self.buffer) }; + self.sender + .send_message_buffer(buffer, self.fragment_count, true, self.cursor) + .unwrap(); + } +} + +impl Write for MessageWriter<'_> { + fn write(&mut self, buf: &[u8]) -> std::io::Result { + let mut src_bytes = buf; + + while !src_bytes.is_empty() { + let mut remaining_buffer = &mut self.buffer[self.cursor..]; + if remaining_buffer.is_empty() { + let new_buffer = self.sender.group.buffer_allocator().allocate(); + self.sender.send_message_buffer( + std::mem::replace(&mut self.buffer, new_buffer), + self.fragment_count, + false, + self.cursor, + )?; + self.fragment_count += 1; + self.cursor = MESSAGE_PAYLOAD_OFFSET; + remaining_buffer = &mut self.buffer[self.cursor..]; + } + + let len = remaining_buffer.len().min(src_bytes.len()); + remaining_buffer[..len].copy_from_slice(&src_bytes[..len]); + self.cursor += len; + src_bytes = &src_bytes[len..]; + } + + Ok(buf.len()) + } + + fn flush(&mut self) -> std::io::Result<()> { + Ok(()) + } +} + +pub struct Message { + chunks: Vec, +} + +impl Message { + pub fn read(&self) -> MessageReader { + MessageReader { + chunks: &self.chunks, + cursor: MESSAGE_PAYLOAD_OFFSET, + } + } +} + +pub struct MessageReader<'a> { + chunks: &'a [ReceivedChunk], + cursor: usize, +} + +impl Read for MessageReader<'_> { + fn read(&mut self, buf: &mut [u8]) -> Result { + let total_len = buf.len(); + let mut dst_bytes = buf; + while !dst_bytes.is_empty() { + let mut remaining_buffer = + 
&self.chunks[0].buffer()[self.cursor..self.chunks[0].packet_size()]; + if remaining_buffer.is_empty() { + if self.chunks.len() == 1 { + return Ok(total_len - dst_bytes.len()); + } + self.chunks = &self.chunks[1..]; + self.cursor = MESSAGE_PAYLOAD_OFFSET; + remaining_buffer = + &self.chunks[0].buffer()[self.cursor..self.chunks[0].packet_size()]; + } + + let len = remaining_buffer.len().min(dst_bytes.len()); + dst_bytes[..len].copy_from_slice(&remaining_buffer[..len]); + self.cursor += len; + dst_bytes = &mut dst_bytes[len..]; + } + Ok(total_len) + } +} + +#[derive(Debug, Default)] +pub(crate) struct BroadcastGroupReceiverState { + next_seq: SequenceNumber, + chunks: VecDeque>, +} + +impl BroadcastGroupReceiverState {} + +impl GroupMemberTypeImpl for BroadcastGroupReceiverState { + const GROUP_TYPE: protocol::GroupType = protocol::GROUP_TYPE_BARRIER; + + fn process_group_join(&mut self, seq: SequenceNumber, _group: &GroupMemberState) { + self.next_seq = seq; + } + + fn process_chunk(&mut self, chunk: Chunk, addr: &SockAddr, _group: &GroupMemberState) -> bool { + match chunk { + Chunk::SessionJoin(_) + | Chunk::SessionWelcome(_) + | Chunk::SessionHeartbeat(_) + | Chunk::GroupJoin(_) + | Chunk::GroupWelcome(_) + | Chunk::GroupLeave(_) + | Chunk::GroupDisconnected(_) => { + unreachable!(); + } + Chunk::BarrierReached(_) | Chunk::BarrierReleased(_) | Chunk::GroupAck(_) => { + log::trace!( + "IGNORED: chunk from {} received in broadcast group", + display_addr(addr), + ); + false + } + Chunk::BroadcastMessage(BroadcastMessage { seq, .. }) + | Chunk::BroadcastFirstMessageFragment(BroadcastFirstMessageFragment { seq, .. }) + | Chunk::BroadcastMessageFragment(BroadcastMessageFragment { seq, .. }) + | Chunk::BroadcastFinalMessageFragment(BroadcastFinalMessageFragment { seq, .. 
}) => { + let offset = *seq - self.next_seq; + if offset > u16::MAX as usize / 2 { + // This is most likely an old packet, just ignore it + log::trace!( + "IGNORED: chunk from {} received in broadcast group (old packet)", + display_addr(addr), + ); + false + } else { + log::trace!( + "received chunk from {} in broadcast group: seq: {}", + display_addr(addr), + seq, + ); + + if offset as usize >= self.chunks.len() { + self.chunks.resize_with(offset as usize + 1, || None); + true + } else { + self.chunks[offset as usize].is_none() + } + } + } + } + } + + fn take_chunk(&mut self, chunk: ReceivedChunk, _: &GroupMemberState) { + let msg_seq = match BroadcastMessage::ref_from_prefix(&chunk.buffer()[1..]) { + Some(msg) => msg.seq, + None => { + unreachable!(); + } + }; + let offset = msg_seq - self.next_seq; + debug_assert!(offset < u16::MAX as usize / 2); + debug_assert!(offset < self.chunks.len()); + debug_assert!(self.chunks[offset].is_none()); + log::trace!("inserting chunk with seq {msg_seq} at offset {offset}"); + self.chunks[offset] = Some(chunk); + } +} + +pub struct BroadcastGroupReceiver { + pub(crate) group: GroupMember, +} + +impl BroadcastGroupReceiver { + pub fn recv(&mut self) -> std::io::Result { + loop { + let chunks = &mut self.group.inner_mut()?.chunks; + let mut chunk_count = 0; + + for (index, chunk) in chunks.iter().enumerate() { + match chunk { + Some(chunk) => match chunk.buffer()[0] { + protocol::CHUNK_ID_BROADCAST_MESSAGE => { + if index == 0 { + chunk_count = 1; + break; + } else { + panic!("unexpected chunk: {:?}", chunk); + } + } + protocol::CHUNK_ID_BROADCAST_FIRST_MESSAGE_FRAGMENT => { + if index == 0 { + } else { + panic!("unexpected chunk: {:?}", chunk); + } + } + protocol::CHUNK_ID_BROADCAST_MESSAGE_FRAGMENT => { + if index == 0 { + panic!("unexpected chunk: {:?}", chunk); + } else { + } + } + protocol::CHUNK_ID_BROADCAST_FINAL_MESSAGE_FRAGMENT => { + if index == 0 { + panic!("unexpected chunk: {:?}", chunk); + } else { + chunk_count = 
index + 1; + break; + } + } + _ => { + panic!("unexpected chunk: {:?}", chunk); + } + }, + None => { + break; + } + } + } + + if chunk_count == 0 { + self.group.recv()?; + } else { + let chunks = chunks + .drain(..chunk_count) + .map(|c| c.unwrap()) + .collect::>(); + + let seq = &mut self.group.inner_mut()?.next_seq; + for _ in 0..chunk_count { + *seq = seq.next() + } + return Ok(Message { chunks }); + } + } + } +} + +// pub struct Subscription { +// control_receiver: ChunkReceiver, +// multicast_receiver: ChunkReceiver, +// buffer_allocator: Arc, +// sequence: SequenceNumber, +// control_socket: ChunkSocket, + +// /// stores the chunks starting from the last received sequence number. +// chunks: VecDeque>, +// } + +// #[derive(thiserror::Error, Debug)] +// pub enum RecvError { +// #[error("I/O error: {0}")] +// Io(#[from] std::io::Error), + +// #[error("Disconnected")] +// Recv(#[from] crossbeam::channel::RecvError), +// } + +// impl Subscription { +// pub fn recv(&mut self) -> Result { +// loop { +// let chunk = self.multicast_receiver.recv()?; +// match chunk.validate() { +// Ok(Chunk::Message(msg, _)) => { +// self.control_socket.send_chunk(&Ack { +// header: ChannelHeader { +// seq: msg.header.seq, +// channel_id: msg.header.channel_id, +// }, +// })?; + +// let seq: u16 = msg.header.seq.into(); +// let offset = +// seq.wrapping_sub(self.sequence.wrapping_add(1)); if +// offset > u16::MAX / 2 { // This is most likely an old +// packet, just ignore it } else if seq != +// self.sequence.wrapping_add(1) { panic!( +// "unexpected sequence number: expected {}, got +// {}", self.sequence.wrapping_add(1), +// seq +// ); +// } else { +// log::debug!("received message: {:?}", msg); +// self.sequence = seq; +// return Ok(Message::SingleChunk(chunk)); +// } +// } +// Ok(chunk) => { +// log::debug!("ignore unexpected chunk: {:?}", chunk); +// } +// Err(err) => { +// log::error!("received invalid chunk: {}", err); +// } +// } +// } +// } +// } diff --git a/src/chunk.rs 
b/src/chunk.rs index fb4d94e..a538408 100644 --- a/src/chunk.rs +++ b/src/chunk.rs @@ -4,54 +4,97 @@ use std::{ }; use crossbeam::channel::{Receiver, Sender, TryRecvError}; +use zerocopy::network_endian::U16; use crate::protocol::{ - kind, Ack, BarrierReached, BarrierReleased, ChannelDisconnected, ChunkKindData, - ConfirmJoinChannel, Connect, ConnectionInfo, JoinBarrierGroup, JoinChannel, LeaveChannel, - Message, + self, BarrierReached, BarrierReleased, BroadcastFinalMessageFragment, BroadcastFirstMessageFragment, BroadcastMessage, BroadcastMessageFragment, ChunkHeader, ChunkIdentifier, GroupAck, GroupDisconnected, GroupJoin, GroupLeave, GroupWelcome, SequenceNumber, SessionHeartbeat, SessionJoin, SessionWelcome }; #[derive(Debug)] pub enum Chunk<'a> { - Connect(&'a Connect), - ConnectionInfo(&'a ConnectionInfo), - JoinChannel(&'a JoinChannel), - ConfirmJoinChannel(&'a ConfirmJoinChannel), - Ack(&'a Ack), - Message(&'a Message, &'a [u8]), - JoinBarrierGroup(&'a JoinBarrierGroup), + // Session related chunks + SessionJoin(&'a SessionJoin), + #[allow(unused)] // TODO: use #[expect(unused)] when it is stable + SessionWelcome(&'a SessionWelcome), + SessionHeartbeat(&'a SessionHeartbeat), + + // Group related chunks + GroupJoin(&'a GroupJoin), + GroupWelcome(&'a GroupWelcome), + GroupAck(&'a GroupAck), + GroupLeave(&'a GroupLeave), + GroupDisconnected(&'a GroupDisconnected), + + // Broadcast related chunks + BroadcastMessage(&'a BroadcastMessage), + BroadcastFirstMessageFragment(&'a BroadcastFirstMessageFragment), + BroadcastMessageFragment(&'a BroadcastMessageFragment), + BroadcastFinalMessageFragment(&'a BroadcastFinalMessageFragment), + + + // Barrier related chunk BarrierReached(&'a BarrierReached), BarrierReleased(&'a BarrierReleased), - LeaveChannel(&'a LeaveChannel), - ChannelDisconnected(&'a ChannelDisconnected), + // ConfirmJoinChannel(&'a ConfirmJoinChannel), + // Ack(&'a Ack), + // Message(&'a Message, &'a [u8]), + // JoinBarrierGroup(&'a JoinBarrierGroup), + 
// BarrierReached(&'a BarrierReached), + // BarrierReleased(&'a BarrierReleased), + // LeaveChannel(&'a LeaveChannel), } impl Chunk<'_> { - pub fn channel_id(&self) -> Option { + pub fn channel_id(&self) -> Option { match self { - Chunk::JoinChannel(join) => Some(join.channel_id.into()), - Chunk::ConfirmJoinChannel(confirm) => Some(confirm.header.channel_id.into()), - Chunk::Ack(ack) => Some(ack.header.channel_id.into()), - Chunk::Message(msg, _) => Some(msg.header.channel_id.into()), - Chunk::JoinBarrierGroup(join) => Some(join.0.into()), - Chunk::BarrierReached(b) => Some(b.0.channel_id.into()), - Chunk::BarrierReleased(b) => Some(b.0.channel_id.into()), - Chunk::LeaveChannel(c) => Some(c.0.into()), - Chunk::ChannelDisconnected(c) => Some(c.0.into()), - _ => None, + Chunk::GroupJoin(GroupJoin { group_id, .. }) => Some(*group_id), + Chunk::GroupDisconnected(GroupDisconnected(group_id)) => Some(*group_id), + Chunk::GroupWelcome(GroupWelcome { group_id, .. }) => Some(*group_id), + Chunk::GroupAck(GroupAck { group_id, .. }) => Some(*group_id), + Chunk::GroupLeave(GroupLeave(group_id)) => Some(*group_id), + Chunk::BarrierReached(BarrierReached { group_id, .. }) => Some(*group_id), + Chunk::BarrierReleased(BarrierReleased { group_id, .. }) => Some(*group_id), + Chunk::BroadcastMessage(BroadcastMessage { group_id, .. }) => Some(*group_id), + Chunk::BroadcastFirstMessageFragment(BroadcastFirstMessageFragment { group_id, .. }) => Some(*group_id), + Chunk::BroadcastMessageFragment(BroadcastMessageFragment { group_id, .. }) => Some(*group_id), + Chunk::BroadcastFinalMessageFragment(BroadcastFinalMessageFragment { group_id, .. }) => Some(*group_id), + + Chunk::SessionJoin(_) | Chunk::SessionWelcome(_) | Chunk::SessionHeartbeat(_) => None, + } + } + + pub fn requires_ack(&self) -> Option { + match self { + Chunk::GroupWelcome(GroupWelcome { seq, .. }) => Some(*seq), + Chunk::BarrierReached(BarrierReached { seq, .. 
}) => Some(*seq), + Chunk::BarrierReleased(BarrierReleased { seq, .. }) => Some(*seq), + Chunk::BroadcastMessage(BroadcastMessage { seq, .. }) => Some(*seq), + Chunk::BroadcastFirstMessageFragment(BroadcastFirstMessageFragment { seq, .. }) => Some(*seq), + Chunk::BroadcastMessageFragment(BroadcastMessageFragment { seq, .. }) => Some(*seq), + Chunk::BroadcastFinalMessageFragment(BroadcastFinalMessageFragment { seq, .. }) => Some(*seq), + + Chunk::GroupLeave(_) + | Chunk::GroupAck(_) + | Chunk::GroupDisconnected(_) + | Chunk::GroupJoin(_) + | Chunk::SessionJoin(_) + | Chunk::SessionWelcome(_) + | Chunk::SessionHeartbeat(_) => None, } } } /// A buffer for a single chunk. /// -/// This buffer is used to store the bytes of a single chunk. It is created using -/// `ChunkBufferAllocator::allocate` and is automatically returned to the allocator when it is -/// dropped. This avoids unnecessary allocations and deallocations. +/// This buffer is used to store the bytes of a single chunk. It is created +/// using `ChunkBufferAllocator::allocate` and is automatically returned to the +/// allocator when it is dropped. This avoids unnecessary allocations and +/// deallocations. /// -/// The first buffer indicates the kind of the chunk as defined in `crate::protocol::kind`. It is -/// followed by the data specific to its kind which is implemented by the `ChunkKindData` trait. -/// Optionally, it is followed by a payload. +/// The first buffer indicates the kind of the chunk as defined in +/// `crate::protocol::kind`. It is followed by the data specific to its kind +/// which is implemented by the `ChunkKindData` trait. Optionally, it is +/// followed by a payload. #[derive(Debug)] pub struct ChunkBuffer { bytes: ManuallyDrop>, @@ -60,8 +103,8 @@ pub struct ChunkBuffer { impl Drop for ChunkBuffer { fn drop(&mut self) { - // If send() returns an error, the allocator has been dropped. In this case we can also - // drop the buffer. 
+ // If send() returns an error, the allocator has been dropped. In this case we + // can also drop the buffer. let _ = self .allocator .send(unsafe { ManuallyDrop::take(&mut self.bytes) }); @@ -96,21 +139,21 @@ impl DerefMut for ChunkBuffer { #[derive(thiserror::Error, Debug)] pub enum ChunkValidationError { - #[error("Invalid chunk kind: {0}")] - InvalidChunkKind(u8), + #[error("Invalid chunk identifier: {0}")] + InvalidChunkId(ChunkIdentifier), #[error("Invalid packet size: expected {expected}, got {actual}")] InvalidPacketSize { expected: usize, actual: usize }, } impl ChunkBuffer { - pub fn init(&mut self, kind_data: &T) { - self.bytes[0] = T::kind(); + pub fn init(&mut self, kind_data: &T) { + self.bytes[0] = T::id().into(); kind_data.write_to_prefix(&mut self.bytes[1..]); } - pub fn init_with_payload(&mut self, kind_data: &T, payload: &[u8]) { - self.bytes[0] = T::kind(); + pub fn init_with_payload(&mut self, kind_data: &T, payload: &[u8]) { + self.bytes[0] = T::id().into(); kind_data.write_to_prefix(&mut self.bytes[1..]); let payload_offset = 1 + std::mem::size_of::(); self.bytes[payload_offset..payload_offset + payload.len()].copy_from_slice(payload); @@ -124,7 +167,7 @@ impl ChunkBuffer { &self.bytes[1 + kind_size..packet_size] } - fn get_kind_data_ref( + fn get_kind_data_ref( &self, packet_size: usize, ) -> Result<&T, ChunkValidationError> { @@ -139,7 +182,7 @@ impl ChunkBuffer { } } - fn get_kind_data_and_payload_ref( + fn get_kind_data_and_payload_ref( &self, packet_size: usize, ) -> Result<(&T, &[u8]), ChunkValidationError> { @@ -156,43 +199,49 @@ impl ChunkBuffer { pub fn validate(&self, packet_size: usize) -> Result { match self.bytes[0] { - kind::CONNECT => Ok(Chunk::Connect( - self.get_kind_data_ref::(packet_size)?, + protocol::CHUNK_ID_SESSION_JOIN => Ok(Chunk::SessionJoin( + self.get_kind_data_ref::(packet_size)?, )), - kind::CONNECTION_INFO => Ok(Chunk::ConnectionInfo( - self.get_kind_data_ref::(packet_size)?, + 
protocol::CHUNK_ID_SESSION_WELCOME => Ok(Chunk::SessionWelcome( + self.get_kind_data_ref::(packet_size)?, )), - kind::JOIN_CHANNEL => Ok(Chunk::JoinChannel( - self.get_kind_data_ref::(packet_size)?, + protocol::CHUNK_ID_SESSION_HEARTBEAT => Ok(Chunk::SessionHeartbeat( + self.get_kind_data_ref::(packet_size)?, )), - kind::CONFIRM_JOIN_CHANNEL => Ok(Chunk::ConfirmJoinChannel( - self.get_kind_data_ref::(packet_size)?, + protocol::CHUNK_ID_GROUP_JOIN => Ok(Chunk::GroupJoin( + self.get_kind_data_ref::(packet_size)?, )), - kind::ACK => Ok(Chunk::Ack(self.get_kind_data_ref::(packet_size)?)), - kind::MESSAGE => { - let (data, payload) = self.get_kind_data_and_payload_ref::(packet_size)?; - // return Err(ChunkValidationError::InvalidPacketSize { - // expected: 1 + std::mem::size_of::() + data_len as usize, - // actual: packet_size, - // }); - Ok(Chunk::Message(data, payload)) - } - kind::JOIN_BARRIER_GROUP => Ok(Chunk::JoinBarrierGroup( - self.get_kind_data_ref::(packet_size)?, + protocol::CHUNK_ID_GROUP_WELCOME => Ok(Chunk::GroupWelcome( + self.get_kind_data_ref::(packet_size)?, )), - kind::BARRIER_REACHED => Ok(Chunk::BarrierReached( - self.get_kind_data_ref::(packet_size)?, + protocol::CHUNK_ID_GROUP_ACK => Ok(Chunk::GroupAck( + self.get_kind_data_ref::(packet_size)?, )), - kind::BARRIER_RELEASED => Ok(Chunk::BarrierReleased( - self.get_kind_data_ref::(packet_size)?, + protocol::CHUNK_ID_GROUP_LEAVE => Ok(Chunk::GroupLeave( + self.get_kind_data_ref::(packet_size)?, + )), + protocol::CHUNK_ID_GROUP_DISCONNECTED => Ok(Chunk::GroupDisconnected( + self.get_kind_data_ref::(packet_size)?, )), - kind::LEAVE_CHANNEL => Ok(Chunk::LeaveChannel( - self.get_kind_data_ref::(packet_size)?, + protocol::CHUNK_ID_BROADCAST_MESSAGE => Ok(Chunk::BroadcastMessage( + self.get_kind_data_ref::(packet_size)?, )), - kind::CHANNEL_DISCONNECTED => Ok(Chunk::ChannelDisconnected( - self.get_kind_data_ref::(packet_size)?, + protocol::CHUNK_ID_BROADCAST_FIRST_MESSAGE_FRAGMENT => 
Ok(Chunk::BroadcastFirstMessageFragment( + self.get_kind_data_ref::(packet_size)?, )), - kind => Err(ChunkValidationError::InvalidChunkKind(kind)), + protocol::CHUNK_ID_BROADCAST_MESSAGE_FRAGMENT => Ok(Chunk::BroadcastMessageFragment( + self.get_kind_data_ref::(packet_size)?, + )), + protocol::CHUNK_ID_BROADCAST_FINAL_MESSAGE_FRAGMENT => Ok(Chunk::BroadcastFinalMessageFragment( + self.get_kind_data_ref::(packet_size)?, + )), + protocol::CHUNK_ID_BARRIER_REACHED => Ok(Chunk::BarrierReached( + self.get_kind_data_ref::(packet_size)?, + )), + protocol::CHUNK_ID_BARRIER_RELEASED => Ok(Chunk::BarrierReleased( + self.get_kind_data_ref::(packet_size)?, + )), + id => Err(ChunkValidationError::InvalidChunkId(id)), } } } @@ -214,13 +263,17 @@ impl ChunkBufferAllocator { } } + pub fn chunk_size(&self) -> usize { + self.chunk_size + } + pub fn with_initial_capacity(chunk_size: usize, capacity: usize) -> Self { let allocator = Self::new(chunk_size); for _ in 0..capacity { match allocator.sender.send(allocator.allocate_bytes()) { Ok(_) => (), - Err(_) => unreachable!(), // This cannot happen, as we on both, the sender and - // receiver. + Err(_) => unreachable!(), /* This cannot happen, as we on both, the sender and + * receiver. */ } } allocator @@ -240,8 +293,8 @@ impl ChunkBufferAllocator { bytes: ManuallyDrop::new(self.allocate_bytes()), allocator: self.sender.clone(), }, - Err(TryRecvError::Disconnected) => unreachable!(), // This cannot happen, as we own - // both, the sender and receiver. + Err(TryRecvError::Disconnected) => unreachable!(), /* This cannot happen, as we own + * both, the sender and receiver. 
*/ } } } diff --git a/src/chunk_socket.rs b/src/chunk_socket.rs index 1cdbfb8..d8169e6 100644 --- a/src/chunk_socket.rs +++ b/src/chunk_socket.rs @@ -1,10 +1,10 @@ -use std::sync::Arc; +use std::{io::IoSlice, sync::Arc}; use socket2::{SockAddr, Socket}; use crate::{ chunk::{Chunk, ChunkBuffer, ChunkBufferAllocator, ChunkValidationError}, - protocol::ChunkKindData, + protocol::ChunkHeader, }; #[derive(Debug)] @@ -41,19 +41,26 @@ impl ReceivedChunk { } /// A socket that sends and receives chunks. +#[derive(Debug, Clone)] pub struct ChunkSocket { - socket: Socket, + socket: Arc, buffer_allocator: Arc, + chunk_size: usize, } impl ChunkSocket { - pub fn new(socket: Socket, buffer_allocator: Arc) -> Self { + pub fn new(socket: Arc, buffer_allocator: Arc) -> Self { Self { socket, + chunk_size: buffer_allocator.chunk_size(), buffer_allocator, } } + pub fn buffer_allocator(&self) -> &Arc { + &self.buffer_allocator + } + pub fn receive_chunk(&self) -> Result { let mut buffer = self.buffer_allocator.allocate(); @@ -68,86 +75,119 @@ impl ChunkSocket { Ok(ReceivedChunk::new(buffer, addr, size)) } + #[inline] pub fn send_chunk_buffer( &self, buffer: &ChunkBuffer, packet_size: usize, ) -> Result<(), std::io::Error> { - self.socket.send(&buffer[..packet_size])?; + let sent_bytes = self.socket.send(&buffer[..packet_size])?; + debug_assert_eq!(sent_bytes, packet_size); Ok(()) } + #[inline] pub fn send_chunk_buffer_to( &self, buffer: &ChunkBuffer, packet_size: usize, addr: &SockAddr, ) -> Result<(), std::io::Error> { - self.socket.send_to(&buffer[..packet_size], addr)?; + let sent_bytes = self.socket.send_to(&buffer[..packet_size], addr)?; + debug_assert_eq!(sent_bytes, packet_size); Ok(()) } - pub fn send_chunk(&self, kind_data: &T) -> Result<(), std::io::Error> { - let mut buffer = self.buffer_allocator.allocate(); - buffer.init(kind_data); - let size = 1 + std::mem::size_of::(); - self.send_chunk_buffer(&buffer, size)?; + pub fn send_chunk(&self, kind_data: &T) -> Result<(), 
std::io::Error> { + let data_size = 1 + std::mem::size_of::(); + debug_assert!(data_size <= self.chunk_size); + + let kind_id = T::id(); + let kind_id_buf = [kind_id]; + let bufs = [ + IoSlice::new(&kind_id_buf), + IoSlice::new(kind_data.as_bytes()), + ]; + + let sent_bytes = self.socket.send_vectored(&bufs)?; + debug_assert_eq!(sent_bytes, data_size); + Ok(()) } - pub fn send_chunk_to( + pub fn send_chunk_to( &self, kind_data: &T, addr: &SockAddr, ) -> Result<(), std::io::Error> { - let mut buffer = self.buffer_allocator.allocate(); - buffer.init(kind_data); - let size = 1 + std::mem::size_of::(); - self.send_chunk_buffer_to(&buffer, size, addr)?; + let data_size = 1 + std::mem::size_of::(); + debug_assert!(data_size <= self.chunk_size); + + let kind_id = T::id(); + let kind_id_buf = [kind_id]; + let bufs = [ + IoSlice::new(&kind_id_buf), + IoSlice::new(kind_data.as_bytes()), + ]; + + let sent_bytes = self.socket.send_to_vectored(&bufs, addr)?; + debug_assert_eq!(sent_bytes, data_size); + Ok(()) } - pub fn send_chunk_with_payload( + pub fn send_chunk_with_payload( &self, kind_data: &T, payload: &[u8], ) -> Result<(), std::io::Error> { - let mut buffer = self.buffer_allocator.allocate(); - buffer.init(kind_data); - let size = 1 + std::mem::size_of::() + payload.len(); - if size > buffer.len() { + let data_size = 1 + std::mem::size_of::() + payload.len(); + if data_size > self.chunk_size { return Err(std::io::Error::new( std::io::ErrorKind::InvalidInput, "payload too large", )); } - self.send_chunk_buffer(&buffer, size)?; + + let kind_id = T::id(); + let kind_id_buf = [kind_id]; + let bufs = [ + IoSlice::new(&kind_id_buf), + IoSlice::new(kind_data.as_bytes()), + IoSlice::new(payload), + ]; + + let sent_bytes = self.socket.send_vectored(&bufs)?; + debug_assert_eq!(sent_bytes, data_size); + Ok(()) } - pub fn send_chunk_with_payload_to( + pub fn send_chunk_with_payload_to( &self, kind_data: &T, payload: &[u8], addr: &SockAddr, ) -> Result<(), std::io::Error> { - let 
mut buffer = self.buffer_allocator.allocate(); - buffer.init_with_payload(kind_data, payload); - let size = 1 + std::mem::size_of::() + payload.len(); - if size > buffer.len() { + let data_size = 1 + std::mem::size_of::() + payload.len(); + if data_size > self.chunk_size { return Err(std::io::Error::new( std::io::ErrorKind::InvalidInput, "payload too large", )); } - self.send_chunk_buffer_to(&buffer, size, addr)?; - Ok(()) - } - pub fn try_clone(&self) -> Result { - Ok(Self { - socket: self.socket.try_clone()?, - buffer_allocator: self.buffer_allocator.clone(), - }) + let kind_id = T::id(); + let kind_id_buf = [kind_id]; + let bufs = [ + IoSlice::new(&kind_id_buf), + IoSlice::new(kind_data.as_bytes()), + IoSlice::new(payload), + ]; + + let sent_bytes = self.socket.send_to_vectored(&bufs, addr)?; + debug_assert_eq!(sent_bytes, data_size); + + Ok(()) } } diff --git a/src/group.rs b/src/group.rs new file mode 100644 index 0000000..a38dc91 --- /dev/null +++ b/src/group.rs @@ -0,0 +1,766 @@ +use std::{sync::Arc, time::Instant}; + +use ahash::HashSet; +use crossbeam::{channel::RecvTimeoutError, select}; +use socket2::SockAddr; + +use crate::{ + chunk::{Chunk, ChunkBuffer, ChunkBufferAllocator}, + chunk_socket::ReceivedChunk, + multiplex_socket::{Channel, ProcessError}, + protocol::{ + self, ChunkHeader, GroupAck, GroupDisconnected, GroupId, GroupJoin, GroupLeave, GroupType, + GroupWelcome, + }, + session::MemberVitals, + utils::{display_addr, ExponentialBackoff}, +}; + +#[derive(Debug, Default)] +pub struct GroupCoordinatorState { + pub members: HashSet, + pub member_requests: HashSet, + pub seq: protocol::SequenceNumber, +} + +impl GroupCoordinatorState { + fn get_request(&mut self) -> Option<&SockAddr> { + self.member_requests.iter().next() + } +} + +pub trait GroupCoordinatorTypeImpl { + const GROUP_TYPE: GroupType; + + #[allow(unused_variables)] + fn process_join_request(&mut self, addr: &SockAddr, group: &GroupCoordinatorState) {} + + #[allow(unused_variables)] + fn 
process_join_cancelled(&mut self, addr: &SockAddr, group: &GroupCoordinatorState) {} + + #[allow(unused_variables)] + fn process_member_joined(&mut self, addr: &SockAddr, group: &GroupCoordinatorState) {} + + #[allow(unused_variables)] + fn process_member_disconnected(&mut self, addr: &SockAddr, group: &GroupCoordinatorState) {} + + #[allow(unused_variables)] + fn process_chunk(&mut self, chunk: Chunk, addr: &SockAddr, group: &GroupCoordinatorState) {} +} + +pub struct GroupCoordinator { + pub channel: Channel, + pub multicast_addr: SockAddr, + pub session_members: Arc, + pub state: GroupCoordinatorState, + pub inner: I, +} + +impl Drop for GroupCoordinator { + fn drop(&mut self) { + for addr in &self.state.members { + let _ = self + .channel + .send_chunk_to(&GroupDisconnected(self.channel.id().into()), addr); + } + } +} + +impl GroupCoordinator { + // This function was not very useful as all member removal events need slightly + // different handling of the channel disconnect message + /* + /// Removes a member from the channel. + /// + /// If it is present in the channel, the member is removed and a `MemberDisconnected` event is + /// emitted. If `disconnect_message` is provided, a `ChannelDisconnected` message is sent to + /// the member. If the member is not present in the channel, this is a no-op. 
+ fn remove_member( + &mut self, + addr: &SockAddr, + disconnect_message: Option<&[u8]>, + event_handler: F, + ) { + if self.state.members.remove(addr) { + event_handler(CoordinatorChannelEvent::MemberDisconnected(addr)); + if let Some(message) = disconnect_message { + let _ = self.channel.send_chunk_with_payload_to( + &ChannelDisconnected(self.channel.id().into()), + message, + addr, + ); + } + } + } + */ + + pub fn id(&self) -> GroupId { + self.channel.id() + } + + pub fn buffer_allocator(&self) -> &Arc { + self.channel.buffer_allocator() + } + + pub fn accept(&mut self) -> Result { + let addr = loop { + if let Some(addr) = self.state.get_request() { + break addr.clone(); + } else { + self.recv()?; + } + }; + + log::trace!( + "accept group join request for group {} from {}", + self.channel.id(), + display_addr(&addr), + ); + + 'outer: for deadline in ExponentialBackoff::new() { + log::trace!( + "send group welcome for group {} to {}", + self.channel.id(), + display_addr(&addr), + ); + self.channel.send_chunk_to( + &GroupWelcome { + group_id: self.channel.id(), + seq: self.state.seq, + }, + &addr, + )?; + + while !self.state.members.contains(&addr) { + if !self.session_members.is_alive(&addr) { + log::trace!( + "member {} is not alive anymore, ignore join request", + display_addr(&addr), + ); + self.state.member_requests.remove(&addr); + self.inner.process_join_cancelled(&addr, &self.state); + return Err(std::io::ErrorKind::ConnectionAborted.into()); + } + + match self.recv_until(deadline) { + Ok(_) => {} + Err(err) if err.kind() == std::io::ErrorKind::TimedOut => continue 'outer, + Err(err) => { + self.state.member_requests.remove(&addr); + self.inner.process_join_cancelled(&addr, &self.state); + return Err(err); + } + } + } + + log::trace!( + "join group({}) for {} was successful", + self.channel.id(), + display_addr(&addr), + ); + + debug_assert!(self.state.members.contains(&addr)); + self.state.member_requests.remove(&addr); + 
self.inner.process_member_joined(&addr, &self.state); + return Ok(addr); + } + + unreachable!(); + } + + pub fn try_accept(&mut self) -> Result, std::io::Error> { + self.try_recv()?; + if self.state.member_requests.is_empty() { + Ok(None) + } else { + self.accept().map(Some) + } + } + + pub fn remove(&mut self, addr: &SockAddr) -> Result<(), std::io::Error> { + if let Some(addr) = self.state.members.take(addr) { + self.inner.process_member_disconnected(&addr, &self.state); + self.channel.send_chunk_with_payload_to( + &GroupDisconnected(self.channel.id()), + b"timeout", + &addr, + )?; + } + Ok(()) + } + + fn process_join(&mut self, join: &GroupJoin, addr: &SockAddr) { + debug_assert_eq!( + self.channel.id(), + join.group_id, + "this is enforced by the socket" + ); + + log::trace!( + "received group join for group {} from {}", + self.channel.id(), + display_addr(addr), + ); + + if I::GROUP_TYPE != join.group_type { + // TODO: what to do in case the send fails? + let _ = self.channel.send_chunk_with_payload_to( + &GroupDisconnected(self.channel.id().into()), + b"invalid channel type", + addr, + ); + return; + } + + if self.state.members.contains(addr) { + // We received a join request from a member that already joined the channel. + // This can happen in two cases: + // 1. The member ran into a timeout after requesting to join and resend the join + // request, but ended up receiving both. In this case we can ignore this. + // 2. The member disconnected and we were not able to receive the leave message + // for some reason. In this case we should first handle a member leave event + // and then treating this as a new join. + // Currently, the two cases cannot be distinguished and just disconnect the + // member for simplicity. 
+ self.state.members.remove(addr); + self.inner.process_member_disconnected(addr, &self.state); + let _ = self.channel.send_chunk_with_payload_to( + &GroupDisconnected(self.channel.id().into()), + b"duplicate join", + addr, + ); + return; + } + + self.state.member_requests.insert(addr.clone()); + self.inner.process_join_request(addr, &self.state); + } + + fn process_ack(&mut self, ack: &GroupAck, addr: &SockAddr) { + if self.state.member_requests.contains(addr) { + if ack.seq == self.state.seq { + log::trace!( + "received ack for group welcome {} from {}", + self.channel.id(), + display_addr(addr), + ); + self.state.members.insert(addr.clone()); + } else { + log::trace!( + "INGORED: received ack for group welcome {} from {} with invalid seq {} (expected {})", + self.channel.id(), + display_addr(addr), + ack.seq, + self.state.seq, + ); + } + } else { + self.inner + .process_chunk(Chunk::GroupAck(ack), addr, &self.state); + } + } + + fn process_leave(&mut self, GroupLeave(channel_id): &GroupLeave, addr: &SockAddr) { + debug_assert_eq!( + self.channel.id(), + *channel_id, + "this is enforced by the socket" + ); + + if let Some(addr) = self.state.members.take(addr) { + self.inner.process_member_disconnected(&addr, &self.state); + } + let _ = self.channel.send_chunk_with_payload_to( + &GroupDisconnected(self.channel.id().into()), + b"leave", + addr, + ); + } + + fn process_disconnected( + &mut self, + GroupDisconnected(channel_id): &GroupDisconnected, + addr: &SockAddr, + ) { + debug_assert_eq!( + self.channel.id(), + *channel_id, + "this is enforced by the socket" + ); + + if let Some(addr) = self.state.members.take(addr) { + self.inner.process_member_disconnected(&addr, &self.state); + } + } + + /// Process all CHUNK_ID_CHANNEL_* chunks. + #[inline] + fn process_chunk(&mut self, chunk: Chunk, addr: &SockAddr) { + match chunk { + // Session chunks should not be forwarded to the channel. 
+ Chunk::SessionJoin(_) => unreachable!("received session join in channel"), + Chunk::SessionWelcome(_) => unreachable!("received session welcome in channel"), + Chunk::SessionHeartbeat(_) => unreachable!("received session heartbeat in channel"), + Chunk::GroupJoin(join) => self.process_join(join, &addr), + Chunk::GroupWelcome(welcome) => log::trace!( + "coordinator received channel welcome from {:?}: {:?}", + addr, + welcome + ), + Chunk::GroupAck(ack) => self.process_ack(ack, &addr), + Chunk::GroupLeave(leave) => self.process_leave(leave, &addr), + Chunk::GroupDisconnected(disconnected) => self.process_disconnected(disconnected, addr), + Chunk::BarrierReached(_) + | Chunk::BarrierReleased(_) + | Chunk::BroadcastMessage(_) + | Chunk::BroadcastFirstMessageFragment(_) + | Chunk::BroadcastMessageFragment(_) + | Chunk::BroadcastFinalMessageFragment(_) => { + self.inner.process_chunk(chunk, addr, &self.state); + } + } + } + + #[inline] + pub fn process_received_chunk(&mut self, chunk: ReceivedChunk) { + match chunk.validate() { + Ok(c) => { + self.process_chunk(c, chunk.addr()); + } + Err(e) => { + // Those should be filtered out by the socket. + unreachable!("received invalid chunk: {:?}", e); + } + } + } + + /// Waits for a chunk and processes it. + #[inline] + pub fn recv(&mut self) -> Result<(), std::io::Error> { + let chunk = self.channel.recv()?; + self.process_received_chunk(chunk); + Ok(()) + } + + /// + #[inline] + pub fn try_recv(&mut self) -> Result<(), std::io::Error> { + if let Some(chunk) = self.channel.try_recv()? { + self.process_received_chunk(chunk); + } + Ok(()) + } + + /// Waits for a chunk until the deadline and processes it if one is + /// received. 
+ #[inline] + pub fn recv_until(&mut self, deadline: std::time::Instant) -> Result<(), std::io::Error> { + let chunk = self.channel.recv_until(deadline)?; + self.process_received_chunk(chunk); + Ok(()) + } + + #[inline] + pub fn send_chunk_buffer_to_group( + &mut self, + buffer: &ChunkBuffer, + packet_size: usize, + ) -> Result<(), std::io::Error> { + log::trace!( + "sending chunk buffer to {}", + display_addr(&self.multicast_addr) + ); + self.channel + .send_chunk_buffer_to(buffer, packet_size, &self.multicast_addr) + } + + #[inline] + pub fn send_chunk_to_group( + &mut self, + header: &H, + ) -> Result<(), std::io::Error> { + self.channel.send_chunk_to(header, &self.multicast_addr) + } + + #[inline] + pub fn send_chunk_with_payload_to_group( + &mut self, + header: &H, + payload: &[u8], + ) -> Result<(), std::io::Error> { + self.channel + .send_chunk_with_payload_to(header, payload, &self.multicast_addr) + } +} + +#[derive(Debug, Default)] +pub struct GroupMemberState { + // pub seq: protocol::SequenceNumber, +} + +pub trait GroupMemberTypeImpl { + const GROUP_TYPE: GroupType; + + #[allow(unused_variables)] + fn process_group_join(&mut self, seq: protocol::SequenceNumber, group: &GroupMemberState) {} + + #[allow(unused_variables)] + fn process_disconnected(&mut self, group: &GroupMemberState) {} + + #[allow(unused_variables)] + fn process_chunk(&mut self, chunk: Chunk, addr: &SockAddr, group: &GroupMemberState) -> bool { + false + } + + #[allow(unused_variables)] + fn take_chunk(&mut self, chunk: ReceivedChunk, group: &GroupMemberState) {} +} + +pub struct ConnectedGroupMember { + group_id: protocol::GroupId, + coordinator_channel: Channel, + multicast_channel: Channel, + inner: I, + state: GroupMemberState, + coordinator_addr: SockAddr, +} + +impl Drop for ConnectedGroupMember { + fn drop(&mut self) { + let _ = self.leave(); + } +} + +impl ConnectedGroupMember { + // Attempts to leave the channel. 
+ // + // It will send a LeaveChannel request and waits for a response. However, it may + // time out while waiting for the response. + fn leave(&mut self) -> Result<(), std::io::Error> { + for _ in 0..3 { + self.coordinator_channel.send_chunk_to( + &GroupLeave(self.coordinator_channel.id()), + &self.coordinator_addr, + )?; + + match self.coordinator_channel.process_for( + std::time::Duration::from_secs(5), + |chunk, _| match chunk { + Chunk::GroupDisconnected(_) => Ok(Some(())), + _ => Ok(None), + }, + ) { + Ok(_) => return Ok(()), + Err(ProcessError::RecvError(RecvTimeoutError::Timeout)) => { + // The channel did not respond in time. + continue; + } + Err(ProcessError::RecvError(RecvTimeoutError::Disconnected)) => { + return Err(std::io::Error::new( + std::io::ErrorKind::ConnectionAborted, + "channel disconnected", + )); + } + Err(ProcessError::Callback(())) => { + // We do not return an error + unreachable!(); + } + } + } + Err(std::io::Error::new( + std::io::ErrorKind::TimedOut, + "connection timed out", + )) + } + + fn process_disconnected( + &mut self, + GroupDisconnected(group_id): &GroupDisconnected, + _: &SockAddr, + ) -> std::io::Result<()> { + debug_assert_eq!(self.group_id, *group_id, "this is enforced by the socket"); + Err(std::io::ErrorKind::ConnectionAborted.into()) + } + + /// Process all CHUNK_ID_CHANNEL_* chunks. + #[inline] + fn process_chunk(&mut self, chunk: Chunk, addr: &SockAddr) -> std::io::Result { + match chunk { + // Session chunks should not be forwarded to the channel. 
+ Chunk::SessionJoin(_) => unreachable!("received session join in channel"), + Chunk::SessionWelcome(_) => unreachable!("received session welcome in channel"), + Chunk::SessionHeartbeat(_) => unreachable!("received session heartbeat in channel"), + Chunk::GroupJoin(_) | Chunk::GroupWelcome(_) | Chunk::GroupLeave(_) => { + log::trace!( + "INGORED: received unexpected chunk from group {}: {:?}", + display_addr(addr), + chunk + ); + Ok(false) + } + Chunk::GroupDisconnected(disconnected) => { + self.process_disconnected(disconnected, addr).map(|_| false) + } + Chunk::GroupAck(_) + | Chunk::BroadcastMessage(_) + | Chunk::BroadcastFirstMessageFragment(_) + | Chunk::BroadcastMessageFragment(_) + | Chunk::BroadcastFinalMessageFragment(_) + | Chunk::BarrierReached(_) + | Chunk::BarrierReleased(_) => Ok(self.inner.process_chunk(chunk, addr, &self.state)), + } + } + + #[inline] + pub fn process_received_chunk(&mut self, chunk: ReceivedChunk) -> std::io::Result<()> { + match chunk.validate() { + Ok(c) => { + if self.process_chunk(c, chunk.addr())? { + self.inner.take_chunk(chunk, &self.state); + } + Ok(()) + } + Err(e) => { + // Those should be filtered out by the socket. + unreachable!("received invalid chunk: {:?}", e); + } + } + } + + /// Waits for a chunk and processes it. + #[inline] + pub fn recv(&mut self) -> Result<(), std::io::Error> { + select! { + recv(self.coordinator_channel.receiver()) -> chunk => { + let chunk = chunk.map_err(|_| std::io::Error::from(std::io::ErrorKind::ConnectionAborted))?; + self.process_received_chunk(chunk) + } + recv(self.multicast_channel.receiver()) -> chunk => { + let chunk = chunk.map_err(|_| std::io::Error::from(std::io::ErrorKind::ConnectionAborted))?; + self.process_received_chunk(chunk) + } + } + } + + /// Waits for a chunk and processes it. + #[inline] + pub fn try_recv(&mut self) -> Result<(), std::io::Error> { + while let Some(chunk) = self.coordinator_channel.try_recv()? 
{ + self.process_received_chunk(chunk)?; + } + + while let Some(chunk) = self.multicast_channel.try_recv()? { + self.process_received_chunk(chunk)?; + } + + Ok(()) + } + + /// Waits for a chunk until the deadline and processes it if one is + /// received. + #[inline] + pub fn recv_until(&mut self, deadline: std::time::Instant) -> Result<(), std::io::Error> { + let timeout = deadline - std::time::Instant::now(); + select! { + recv(self.coordinator_channel.receiver()) -> chunk => { + let chunk = chunk.map_err(|_| std::io::Error::from(std::io::ErrorKind::ConnectionAborted))?; + self.process_received_chunk(chunk) + } + recv(self.multicast_channel.receiver()) -> chunk => { + let chunk = chunk.map_err(|_| std::io::Error::from(std::io::ErrorKind::ConnectionAborted))?; + self.process_received_chunk(chunk) + } + default(timeout) => { + Err(std::io::Error::from(std::io::ErrorKind::TimedOut)) + } + } + } +} + +pub enum GroupMember { + Connected(ConnectedGroupMember), + Disconnected, +} + +impl GroupMember { + pub fn join( + coordinator_addr: SockAddr, + coordinator_channel: Channel, + multicast_channel: Channel, + inner: I, + ) -> Result { + let group_id = coordinator_channel.id(); + + for deadline in ExponentialBackoff::new() { + log::trace!( + "send join request with type {} to group {}", + I::GROUP_TYPE, + group_id, + ); + coordinator_channel.send_chunk_to( + &GroupJoin { + group_id, + group_type: I::GROUP_TYPE, + }, + &coordinator_addr, + )?; + + loop { + match coordinator_channel.recv_until(deadline) { + Ok(chunk) => match chunk.validate() { + Ok(Chunk::GroupWelcome(welcome)) => { + log::trace!("received welcome from group {}: {:?}", group_id, welcome); + let state = GroupMemberState {}; + let mut inner = inner; + inner.process_group_join(welcome.seq, &state); + return Ok(Self::Connected(ConnectedGroupMember { + group_id: group_id, + coordinator_channel, + multicast_channel, + inner: inner, + state, + coordinator_addr, + })); + } + Ok(Chunk::GroupDisconnected(_)) => { + 
log::trace!("received disconnect from group {}", group_id,); + return Err(std::io::Error::new( + std::io::ErrorKind::ConnectionAborted, + "channel disconnected", + )); + } + Ok( + Chunk::SessionJoin(_) + | Chunk::SessionWelcome(_) + | Chunk::SessionHeartbeat(_) + | Chunk::GroupJoin(_) + | Chunk::GroupAck(_) + | Chunk::GroupLeave(_) + | Chunk::BroadcastMessage(_) + | Chunk::BroadcastFirstMessageFragment(_) + | Chunk::BroadcastMessageFragment(_) + | Chunk::BroadcastFinalMessageFragment(_) + | Chunk::BarrierReached(_) + | Chunk::BarrierReleased(_), + ) => { + log::trace!( + "IGNORED: received unexpected chunk from group {}: {:?}", + group_id, + chunk + ); + } + Err(err) => { + log::trace!( + "IGNORED: received invalid chunk as join resquest response: {:?}", + err + ); + } + }, + Err(err) if err.kind() == std::io::ErrorKind::TimedOut => break, + Err(err) => return Err(err), + } + } + } + + Err(std::io::Error::new( + std::io::ErrorKind::TimedOut, + "connection timed out", + )) + } + + #[inline] + pub fn id(&self) -> std::io::Result { + match self { + Self::Connected(inner) => Ok(inner.group_id), + Self::Disconnected => Err(std::io::ErrorKind::NotConnected.into()), + } + } + + #[inline] + pub fn inner(&self) -> std::io::Result<&I> { + match self { + Self::Connected(inner) => Ok(&inner.inner), + Self::Disconnected => Err(std::io::ErrorKind::NotConnected.into()), + } + } + + #[inline] + pub fn inner_mut(&mut self) -> std::io::Result<&mut I> { + match self { + Self::Connected(inner) => Ok(&mut inner.inner), + Self::Disconnected => Err(std::io::ErrorKind::NotConnected.into()), + } + } + + #[inline] + pub fn recv(&mut self) -> std::io::Result<()> { + match self { + Self::Connected(inner) => inner.recv(), + Self::Disconnected => Err(std::io::ErrorKind::NotConnected.into()), + } + } + + #[inline] + pub fn recv_until(&mut self, dealine: Instant) -> std::io::Result<()> { + match self { + Self::Connected(inner) => inner.recv_until(dealine), + Self::Disconnected => 
Err(std::io::ErrorKind::NotConnected.into()), + } + } + + #[inline] + pub fn send_chunk(&self, header: &H) -> Result<(), std::io::Error> { + match self { + Self::Connected(inner) => inner + .coordinator_channel + .send_chunk_to(header, &inner.coordinator_addr), + Self::Disconnected => Err(std::io::ErrorKind::NotConnected.into()), + } + } +} + +// #[cfg(test)] +// mod test { +// use crate::session::{Coordinator, Member}; +// use crate::test::*; +// use std::{ +// net::{IpAddr, Ipv4Addr, SocketAddr}, +// thread, +// }; + +// #[test] +// fn immediate_session_close_after_group_join() -> Result<()> { +// init_logger(); + +// let port = crate::test::get_port(); +// let bind_addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, +// 1)), port); let multicast_addr = +// SocketAddr::new(IpAddr::V4(Ipv4Addr::new(224, 0, 0, 1)), port); + +// let coordinator = Coordinator::start_session(bind_addr, +// multicast_addr)?; let member = Member::join_session(bind_addr)?; + +// thread::scope(|s| { +// s.spawn(|| { +// let mut barrier_group_coordinator = +// coordinator.create_group(). +// barrier_group_coordinator.accept().unwrap(); + +// for _ in 0..10 { +// barrier_group_coordinator.wait().unwrap(); +// } +// }); + +// s.spawn(|| { +// let mut barrier_group_member = +// member.join_barrier_group(0).unwrap(); + +// for _ in 0..10 { +// barrier_group_member.wait().unwrap(); +// } +// }); +// }); + +// Ok(()) +// } +// } diff --git a/src/lib.rs b/src/lib.rs index 07ef91e..34ec96e 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,12 +1,51 @@ -#![allow(dead_code)] +//! High-performance 1-to-many communication and synchronization primitives +//! using UDP multicast. +//! +//! This crates provides a set of communication and synchronization primitives +//! similar to the collective communication routines found in MPI. In contrast +//! to MPI, this crates allows for more flexible communication patterns without +//! 
sacrificing performance and is designed to be used in mid-sized distributed +//! systems. The underlying protocol is session based which allows nodes to join +//! and leave at any time. One node explicitly takes over the role of the +//! session [coordinator](session::Coordinator) and is responsible for creating +//! the session. All other nodes must join the session as a +//! [member](session::Member). For more information on how to establish and join +//! sessions, see the [session] module. +//! +//! # Groups +//! Each session is divided into multiple groups that operate independently of +//! each other. Groups are always created by the coordinator and must be +//! explicitly joined by members. There are currently two different types of +//! groups. +//! +//! ## Barrier +//! +//! ## Broadcast +//! +//! # Debugging +//! If you encounter a lot of timeouts, hangs, or other issues, you can enable +//! logging to get more information about what is happening. This crate uses the +//! [log](https://crates.io/crates/log) crate for logging. So, choose a +//! logging implementation according to their +//! [documentation](https://docs.rs/log/latest/log/) and make sure to set the +//! log level to `trace` for this crate. In order to avoid any overhead caused +//! by logging ensure that the compile time filter is set to debug or higher +//! (see [Compile Time Filters](https://docs.rs/log/latest/log/#compile-time-filters) for more +//! information). +//! +//! # Future Work +//! - **Improve Member Management**: Currrently, the chunk allocation strategy +//! is suboptimal and can be improved to increase performance. 
pub(crate) mod chunk; pub(crate) mod chunk_socket; +pub(crate) mod group; pub(crate) mod multiplex_socket; pub(crate) mod protocol; +#[cfg(test)] +pub(crate) mod test; +pub(crate) mod utils; -pub mod publisher; -pub mod subscriber; - -pub type ChannelId = u16; -pub(crate) type SequenceNumber = u16; +pub mod barrier; +pub mod broadcast; +pub mod session; diff --git a/src/multiplex_socket.rs b/src/multiplex_socket.rs index 800a213..db67373 100644 --- a/src/multiplex_socket.rs +++ b/src/multiplex_socket.rs @@ -1,170 +1,481 @@ -use std::{net::SocketAddr, sync::Arc}; +use std::{net::SocketAddr, num::NonZeroUsize, sync::Arc, time::Duration}; use ahash::HashMap; -use crossbeam::channel::{Receiver, RecvError, RecvTimeoutError, Select, Sender}; -use socket2::Socket; +use crossbeam::channel::{Receiver, RecvTimeoutError, Select, Sender, TryRecvError, TrySendError}; +use dashmap::DashMap; +use socket2::{SockAddr, Socket}; use crate::{ - chunk::{Chunk, ChunkBufferAllocator}, + chunk::{Chunk, ChunkBuffer, ChunkBufferAllocator}, chunk_socket::{ChunkSocket, ReceivedChunk}, - protocol::ChunkKindData, + protocol::{self, ChunkHeader, GroupDisconnected}, + utils::display_addr, }; -type ChannelId = u16; - pub type ChunkSender = Sender; pub type ChunkReceiver = Receiver; -type ChannelListenerReceiver = Receiver<(ChannelId, ChunkSender)>; -type ChannelListenerSender = Sender<(ChannelId, ChunkSender)>; -struct ChannelConnections { - channel_receiver: ChannelListenerReceiver, - channels: HashMap, +pub enum CallbackReason<'a> { + UnhandledChunk { + chunk: Chunk<'a>, + addr: &'a SockAddr, + }, + Timeout, +} + +pub trait Callback: Fn(&MultiplexSocket, CallbackReason) + Send + 'static {} +impl Callback for F {} + +// #[derive(thiserror::Error, Debug)] +// pub enum SendError { +// #[error("I/O error: {0}")] +// Io(#[from] std::io::Error), + +// #[error("buffer is not a valid chunk")] +// InvalidChunk, +// } + +#[derive(thiserror::Error, Debug)] +enum ForwardChunkError { + #[error("channel is 
full")] + RecvBufferFull(ReceivedChunk), + + #[error("channel is disconnected")] + Disconnected(ReceivedChunk), +} + +#[derive(thiserror::Error, Debug)] +pub enum ProcessError { + #[error("I/O error: {0}")] + RecvError(#[from] RecvTimeoutError), + + #[error("callback error: {0}")] + Callback(E), } -impl ChannelConnections { - fn new(channel_receiver: ChannelListenerReceiver) -> Self { - Self { - channel_receiver, - channels: HashMap::default(), +pub struct Channel { + socket: Arc, + receiver: ChunkReceiver, + channel_id: protocol::GroupId, +} + +impl Drop for Channel { + fn drop(&mut self) { + self.socket.channels.remove(&self.channel_id); + // debug_assert!(self + // .receiver + // .try_recv() + // .is_err_and(|e| e == TryRecvError::Disconnected)); + } +} + +impl Channel { + #[inline] + pub fn id(&self) -> protocol::GroupId { + self.channel_id + } + + #[inline] + pub fn buffer_allocator(&self) -> &Arc { + self.socket.buffer_allocator() + } + + #[inline] + pub fn send_chunk(&self, kind_data: &H) -> Result<(), std::io::Error> { + self.socket.send_chunk(kind_data) + } + + #[inline] + pub fn send_chunk_buffer_to( + &self, + buffer: &ChunkBuffer, + packet_size: usize, + addr: &SockAddr, + ) -> Result<(), std::io::Error> { + self.socket.send_chunk_buffer_to(buffer, packet_size, addr) + } + + #[inline] + pub fn send_chunk_to( + &self, + kind_data: &H, + addr: &SockAddr, + ) -> Result<(), std::io::Error> { + self.socket.send_chunk_to(kind_data, addr) + } + + #[inline] + pub fn send_chunk_with_payload_to( + &self, + kind_data: &H, + payload: &[u8], + addr: &SockAddr, + ) -> Result<(), std::io::Error> { + self.socket + .send_chunk_with_payload_to(kind_data, payload, addr) + } + + #[inline] + pub fn receiver(&self) -> &ChunkReceiver { + &self.receiver + } + + #[inline] + pub fn recv(&self) -> Result { + self.receiver + .recv() + .map_err(|_| std::io::ErrorKind::NotConnected.into()) + } + + #[inline] + pub fn try_recv(&self) -> Result, std::io::Error> { + match 
self.receiver.try_recv() { + Ok(chunk) => Ok(Some(chunk)), + Err(TryRecvError::Empty) => Ok(None), + Err(TryRecvError::Disconnected) => Err(std::io::ErrorKind::NotConnected.into()), } } - fn send(&mut self, channel_id: ChannelId, chunk: ReceivedChunk) -> Result<(), RecvError> { - let mut chunk = Some(chunk); + #[inline] + pub fn recv_until( + &self, + deadline: std::time::Instant, + ) -> Result { + self.receiver + .recv_deadline(deadline) + .map_err(|err| match err { + RecvTimeoutError::Timeout => std::io::ErrorKind::TimedOut.into(), + RecvTimeoutError::Disconnected => std::io::ErrorKind::NotConnected.into(), + }) + } - while let Some(c) = chunk.take() { - match self.channels.get(&channel_id) { - Some(sender) => { - if let Err(err) = sender.send(c) { - // This sender is disconnected, remove it from the map and try again. - self.channels.remove(&channel_id); - chunk = Some(err.0); - } - } - None => { - let (id, sender) = self.channel_receiver.recv()?; - self.channels.insert(id, sender); - log::debug!("received channel {}", id); - chunk = Some(c); + // #[inline] + // pub fn recv_timeout(&self, timeout: Duration) -> Result { + // self.receiver.recv_timeout(timeout).map_err(|_| ()) + // } + + /// Processes chunks for a duration + pub fn process_for Result, E>>( + &self, + duration: std::time::Duration, + p: F, + ) -> Result> { + let deadline = std::time::Instant::now() + duration; + self.process_until(deadline, p) + } + + /// Processes chunks until the deadline is reached. 
+ pub fn process_until Result, E>>( + &self, + deadline: std::time::Instant, + p: F, + ) -> Result> { + loop { + let chunk = self.receiver.recv_deadline(deadline)?; + + if let Ok(c) = chunk.validate() { + match p(c, chunk.addr()) { + Ok(Some(v)) => return Ok(v), + Ok(None) => {} + Err(e) => return Err(ProcessError::Callback(e)), } + } else { + unreachable!("should be filtered out by the socket"); } } - - Ok(()) } -} -#[derive(thiserror::Error, Debug)] -pub enum SendError { - #[error("I/O error: {0}")] - Io(#[from] std::io::Error), - - #[error("buffer is not a valid chunk")] - InvalidChunk, + // pub fn wait_for_chunk Option>( + // &mut self, + // timeout: std::time::Duration, + // mut p: P, + // ) -> Result { + // let start = std::time::Instant::now(); + // let deadline = start + timeout; + + // loop { + // match self.receiver.recv_deadline(deadline) { + // Ok(chunk) => { + // if let Ok(c) = chunk.validate() { + // if let Some(val) = p(c, chunk.addr()) { + // return Ok(val); + // } + // } + // } + // Err(err) => return Err(err), + // } + // } + // } } pub struct MultiplexSocket { - socket: ChunkSocket, - channel_sender: ChannelListenerSender, + inner: ChunkSocket, + channels: DashMap, } impl MultiplexSocket { - fn receiver_thread( - socket: ChunkSocket, - channel_sender_receiver: ChannelListenerReceiver, - process_unchannelled_chunk: F, - ) { - let mut channel_connections = ChannelConnections::new(channel_sender_receiver); + fn forward_chunk( + &self, + channel_id: protocol::GroupId, + chunk: ReceivedChunk, + cache: &mut HashMap, + ) -> Result<(), ForwardChunkError> { + // First try to send the chunk to the cached channel. 
+ let chunk = if let Some(sender) = cache.get(&channel_id) { + match sender.try_send(chunk) { + Ok(_) => return Ok(()), + Err(TrySendError::Full(chunk)) => { + return Err(ForwardChunkError::RecvBufferFull(chunk)) + } + Err(TrySendError::Disconnected(chunk)) => { + cache.remove(&channel_id); + chunk + } + } + } else { + chunk + }; + + if let Some(sender) = self.channels.get(&channel_id) { + match sender.try_send(chunk) { + Ok(_) => { + cache.insert(channel_id, sender.clone()); + Ok(()) + } + Err(TrySendError::Full(chunk)) => { + return Err(ForwardChunkError::RecvBufferFull(chunk)) + } + Err(TrySendError::Disconnected(_)) => { + unreachable!(); + } + } + } else { + return Err(ForwardChunkError::Disconnected(chunk)); + } + } - loop { - match socket.receive_chunk() { + fn receiver_thread( + receiver_socket: Arc, + sender_socket: Arc, + callback: F, + ) { + // Caches the channels for faster lookup. + let mut channel_cache = HashMap::default(); + + // If we hold the only strong reference to the socket, we can exit the loop and + // drop the socket. 
+ let exit = if Arc::ptr_eq(&receiver_socket, &sender_socket) { + Arc::strong_count(&receiver_socket) <= 2 + } else { + Arc::strong_count(&receiver_socket) <= 1 + }; + + while !exit { + match receiver_socket.inner.receive_chunk() { Ok(chunk) => match chunk.validate() { Ok(c) => { + log::trace!( + "received chunk: {:?} from {}", + c, + display_addr(chunk.addr()) + ); if let Some(channel_id) = c.channel_id() { - if let Err(err) = channel_connections.send(channel_id, chunk) { - log::error!("failed to forward join channel: {}", err); + let ack = c.requires_ack().map(|c| (c, chunk.addr().clone())); + + match receiver_socket.forward_chunk( + channel_id.into(), + chunk, + &mut channel_cache, + ) { + Ok(_) => { + if let Some((ack, addr)) = ack { + log::trace!( + "sending ack({ack}) for channel {channel_id} to {}", + display_addr(&addr) + ); + + if let Err(err) = sender_socket.send_chunk_to( + &protocol::GroupAck { + group_id: channel_id.into(), + seq: ack, + }, + &addr, + ) { + log::error!("failed to send ack: {}", err); + } + } + } + Err(ForwardChunkError::RecvBufferFull(chunk)) => { + log::warn!( + "channel {} is full, dropping chunk: {:?}", + channel_id, + chunk + ); + } + Err(ForwardChunkError::Disconnected(chunk)) => { + receiver_socket + .send_chunk_to( + &GroupDisconnected(channel_id), + chunk.addr(), + ) + .ok(); + log::warn!( + "channel {} is disconnected, dropping chunk", + channel_id + ); + } } } else { - process_unchannelled_chunk(&socket, chunk) + callback( + &receiver_socket, + CallbackReason::UnhandledChunk { + chunk: c, + addr: chunk.addr(), + }, + ); } } Err(err) => log::error!("received invalid chunk: {}", err), }, - Err(err) => { - log::error!("failed to read from socket: {}", err); - } + Err(err) => match err.kind() { + std::io::ErrorKind::WouldBlock | std::io::ErrorKind::TimedOut => { + callback(&receiver_socket, CallbackReason::Timeout); + } + _ => { + log::error!("failed to read from socket: {}", err); + } + }, } } } - fn spawn_receiver_thread( - 
socket: ChunkSocket, - channel_sender_receiver: ChannelListenerReceiver, - process_unchannelled_chunk: F, + fn spawn_receiver_thread( + receiver_socket: Arc, + sender_socket: Arc, + callback: F, ) -> std::thread::JoinHandle<()> { std::thread::spawn(move || { - Self::receiver_thread(socket, channel_sender_receiver, process_unchannelled_chunk); + Self::receiver_thread(receiver_socket, sender_socket, callback); }) } - fn ignore_unchannelled_chunk(_: &ChunkSocket, c: ReceivedChunk) { - log::debug!("ignoring unchannelled chunk: {:?}", c); + pub fn new(socket: Socket, buffer_allocator: Arc) -> Arc { + Self::with_callback(socket, buffer_allocator, |_, _| {}) + } + + pub fn with_sender_socket(socket: Socket, sender_socket: Arc) -> Arc { + let result = Arc::new(Self { + inner: ChunkSocket::new(Arc::new(socket), sender_socket.inner.buffer_allocator().clone()), + channels: DashMap::default(), + }); + + Self::spawn_receiver_thread(result.clone(), sender_socket, |_, _| {}); + + result } - pub fn new( + pub fn with_callback( socket: Socket, buffer_allocator: Arc, - ) -> Result { - Self::with_unchannelled_handler(socket, buffer_allocator, Self::ignore_unchannelled_chunk) + callback: F, + ) -> Arc { + let result = Arc::new(Self { + inner: ChunkSocket::new(Arc::new(socket), buffer_allocator), + channels: DashMap::default(), + }); + + Self::spawn_receiver_thread(result.clone(), result.clone(), callback); + + result } - pub fn with_unchannelled_handler( + pub fn with_callback_and_sender_socket( socket: Socket, + sender_socket: Arc, buffer_allocator: Arc, - process_unchannelled_chunk: F, - ) -> Result { - let (channel_sender, channel_receiver) = crossbeam::channel::unbounded(); - - Self::spawn_receiver_thread( - ChunkSocket::new(socket.try_clone().unwrap(), buffer_allocator.clone()), - channel_receiver, - process_unchannelled_chunk, - ); - - Ok(Self { - socket: ChunkSocket::new(socket, buffer_allocator), - channel_sender, - }) + callback: F, + ) -> Arc { + let result = Arc::new(Self { 
+ inner: ChunkSocket::new(Arc::new(socket), buffer_allocator), + channels: DashMap::default(), + }); + + Self::spawn_receiver_thread(result.clone(), sender_socket, callback); + + result } - pub fn send_chunk(&self, kind_data: &T) -> Result<(), std::io::Error> { - self.socket.send_chunk(kind_data) + #[inline] + pub fn buffer_allocator(&self) -> &Arc { + self.inner.buffer_allocator() + } + + #[inline] + pub fn send_chunk_buffer_to( + &self, + buffer: &ChunkBuffer, + packet_size: usize, + addr: &SockAddr, + ) -> Result<(), std::io::Error> { + self.inner.send_chunk_buffer_to(buffer, packet_size, addr) + } + + #[inline] + pub fn send_chunk(&self, kind_data: &T) -> Result<(), std::io::Error> { + self.inner.send_chunk(kind_data) } - pub fn send_chunk_with_payload( + #[inline] + pub fn send_chunk_to( + &self, + kind_data: &T, + addr: &SockAddr, + ) -> Result<(), std::io::Error> { + self.inner.send_chunk_to(kind_data, addr) + } + + #[inline] + pub fn send_chunk_with_payload( &self, kind_data: &T, payload: &[u8], ) -> Result<(), std::io::Error> { - self.socket.send_chunk_with_payload(kind_data, payload) - } - - /// Registers a channel listener for the given channel id. - /// - /// Any previous listener for the same channel id is replaced. - pub fn listen_to_channel(&self, channel_id: ChannelId) -> ChunkReceiver { - let (sender, receiver) = crossbeam::channel::unbounded(); - if self.channel_sender.send((channel_id, sender)).is_err() { - // The channel cannot be disconnected as the receiver thread which holds the receiver - // only exits when the client is dropped. 
- unreachable!(); - } - receiver + self.inner.send_chunk_with_payload(kind_data, payload) } - pub fn socket(&self) -> &ChunkSocket { - &self.socket + #[inline] + pub fn send_chunk_with_payload_to( + &self, + kind_data: &T, + payload: &[u8], + addr: &SockAddr, + ) -> Result<(), std::io::Error> { + self.inner + .send_chunk_with_payload_to(kind_data, payload, addr) + } + + pub fn allocate_channel( + self: &Arc, + channel_id: protocol::GroupId, + receive_capacity: Option, + ) -> Option { + let mut result = None; + + self.channels.entry(channel_id).or_insert_with(|| { + let (sender, receiver) = if let Some(capacity) = receive_capacity { + crossbeam::channel::bounded(capacity.get()) + } else { + crossbeam::channel::unbounded() + }; + result = Some(receiver); + sender + }); + + result.map(|receiver| Channel { + receiver, + channel_id, + socket: self.clone(), + }) } } @@ -228,7 +539,7 @@ pub enum TransmitAndWaitError { SendError(#[from] std::io::Error), } -pub fn transmit_and_wait Option>( +pub fn transmit_and_wait Option>( socket: &ChunkSocket, kind_data: &C, retransmit_timeout: std::time::Duration, @@ -253,7 +564,7 @@ pub fn transmit_and_wait Opt Err(TransmitAndWaitError::RecvError(RecvTimeoutError::Timeout)) } -pub fn transmit_to_and_wait bool>( +pub fn transmit_to_and_wait bool>( socket: &ChunkSocket, addr: &SocketAddr, kind_data: &T, diff --git a/src/protocol.rs b/src/protocol.rs index cb651de..a79e878 100644 --- a/src/protocol.rs +++ b/src/protocol.rs @@ -1,40 +1,107 @@ +//! This module describes the used protocol. +//! +//! These internals are not part of the semver guarantees of this crate and may +//! change at any time. All programs participating in a multicast session must +//! use the same version of this crate. +//! +//! A ['Chunk'] describes the layout of the UDP payload and always has the +//! following form: +//! ```test +//! +----+--------+---------+ +//! | ID | Header | Payload | +//! +----+--------+---------+ +//! ``` +//! 
The ID is a single byte that identifies the type of the chunk. The header is +//! content and length depends on the chunk type. Some chunks types require and +//! additional payload which is placed after the header. +use std::{fmt::{self, Display}, net::{Ipv6Addr, SocketAddr}}; + use zerocopy::{byteorder::network_endian::*, AsBytes, FromBytes, FromZeroes, Unaligned}; -pub type SequenceNumber = U16; -pub type ChannelId = U16; +#[derive(Debug, FromBytes, AsBytes, FromZeroes, Unaligned, Clone, Copy, PartialEq, Eq, Default)] +#[repr(transparent)] +pub struct SequenceNumber(U16); + +impl From for SequenceNumber { + fn from(val: u16) -> Self { + Self(val.into()) + } +} + +impl Into for SequenceNumber { + fn into(self) -> u16 { + self.0.into() + } +} -pub type ChunkKind = u8; +impl std::ops::Sub for SequenceNumber { + type Output = usize; -pub mod kind { - use super::ChunkKind; + fn sub(self, rhs: Self) -> Self::Output { + let val: u16 = self.0.into(); + val.wrapping_sub(rhs.0.into()).into() + } +} - pub const CONNECT: ChunkKind = 0; - pub const CONNECTION_INFO: ChunkKind = 1; - pub const JOIN_CHANNEL: ChunkKind = 2; - pub const CONFIRM_JOIN_CHANNEL: ChunkKind = 3; - pub const ACK: ChunkKind = 4; - pub const MESSAGE: ChunkKind = 5; - // pub const MESSAGE_FRAGMENT: ChunkKind = 6; - // pub const FINAL_MESSAGE_FRAGMENT: ChunkKind = 7; - pub const JOIN_BARRIER_GROUP: ChunkKind = 8; - pub const BARRIER_REACHED: ChunkKind = 9; - pub const BARRIER_RELEASED: ChunkKind = 10; - pub const LEAVE_CHANNEL: ChunkKind = 11; - pub const CHANNEL_DISCONNECTED: ChunkKind = 12; +impl Display for SequenceNumber { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + ::fmt(&self.0.into(), f) + } } -pub const MESSAGE_PAYLOAD_OFFSET: usize = 1 + std::mem::size_of::(); +impl SequenceNumber { + pub fn next(self) -> Self { + let val: u16 = self.0.into(); + SequenceNumber(val.wrapping_add(1).into()) + } -pub trait ChunkKindData: AsBytes + FromBytes + FromZeroes + Unaligned { - fn kind() -> 
ChunkKind; + pub fn prev(self) -> Self { + let val: u16 = self.0.into(); + SequenceNumber(val.wrapping_sub(1).into()) + } } -macro_rules! impl_chunk_data { - ($kind:ident) => { +pub type GroupId = U16; +pub type GroupType = u8; +pub const GROUP_TYPE_BROADCAST: GroupType = 0; +pub const GROUP_TYPE_BARRIER: GroupType = 1; + +pub type ChunkIdentifier = u8; + +// Session management +pub const CHUNK_ID_SESSION_JOIN: ChunkIdentifier = 0; +pub const CHUNK_ID_SESSION_WELCOME: ChunkIdentifier = 1; +pub const CHUNK_ID_SESSION_HEARTBEAT: ChunkIdentifier = 2; + +// General channel management +pub const CHUNK_ID_GROUP_JOIN: ChunkIdentifier = 10; +pub const CHUNK_ID_GROUP_WELCOME: ChunkIdentifier = 11; +pub const CHUNK_ID_GROUP_ACK: ChunkIdentifier = 12; +pub const CHUNK_ID_GROUP_LEAVE: ChunkIdentifier = 13; +pub const CHUNK_ID_GROUP_DISCONNECTED: ChunkIdentifier = 14; + +// Broadcast group +pub const CHUNK_ID_BROADCAST_MESSAGE: ChunkIdentifier = 20; +pub const CHUNK_ID_BROADCAST_FIRST_MESSAGE_FRAGMENT: ChunkIdentifier = 21; +pub const CHUNK_ID_BROADCAST_MESSAGE_FRAGMENT: ChunkIdentifier = 22; +pub const CHUNK_ID_BROADCAST_FINAL_MESSAGE_FRAGMENT: ChunkIdentifier = 23; + +pub const CHUNK_ID_BARRIER_REACHED: ChunkIdentifier = 30; +pub const CHUNK_ID_BARRIER_RELEASED: ChunkIdentifier = 31; + +pub const WELCOME_PACKET_SIZE: usize = std::mem::size_of::() + 1; +pub const MESSAGE_PAYLOAD_OFFSET: usize = 1 + std::mem::size_of::(); + +pub trait ChunkHeader: AsBytes + FromBytes + FromZeroes + Unaligned { + fn id() -> ChunkIdentifier; +} + +macro_rules! impl_chunk_header { + ($header_type:ident) => { paste::paste! { - impl ChunkKindData for $kind { - fn kind() -> ChunkKind { - kind::[< $kind:snake:upper >] + impl ChunkHeader for $header_type { + fn id() -> ChunkIdentifier { + [< CHUNK_ID_ $header_type:snake:upper >] } } } @@ -43,74 +110,145 @@ macro_rules! 
impl_chunk_data { #[derive(Debug, FromBytes, AsBytes, FromZeroes, Unaligned)] #[repr(C)] -pub struct Connect {} -impl_chunk_data!(Connect); +pub struct UnvalidatedSocketAddr { + pub addr: [u8; 16], + pub port: U16, + pub flow_info: U32, + pub scope_id: U32, +} + +impl From<&UnvalidatedSocketAddr> for SocketAddr { + fn from(value: &UnvalidatedSocketAddr) -> Self { + let addr: Ipv6Addr = value.addr.into(); + match addr.to_canonical() { + std::net::IpAddr::V4(ip) => { + SocketAddr::V4(std::net::SocketAddrV4::new(ip, value.port.into())) + } + std::net::IpAddr::V6(ip) => SocketAddr::V6(std::net::SocketAddrV6::new( + ip, + value.port.into(), + value.flow_info.into(), + value.scope_id.into(), + )), + } + } +} + +impl From for UnvalidatedSocketAddr { + fn from(addr: SocketAddr) -> Self { + match addr { + SocketAddr::V4(addr_v4) => Self { + addr: addr_v4.ip().to_ipv6_mapped().octets(), + port: addr_v4.port().into(), + flow_info: 0.into(), + scope_id: 0.into(), + }, + SocketAddr::V6(addr_v6) => Self { + addr: addr_v6.ip().octets(), + port: addr_v6.port().into(), + flow_info: addr_v6.flowinfo().into(), + scope_id: addr_v6.scope_id().into(), + }, + } + } +} #[derive(Debug, FromBytes, AsBytes, FromZeroes, Unaligned)] #[repr(C)] -pub struct ConnectionInfo { - pub multicast_addr: [u8; 4], - pub multicast_port: U16, +pub struct SessionJoin; +impl_chunk_header!(SessionJoin); + +#[derive(Debug, FromBytes, AsBytes, FromZeroes, Unaligned)] +#[repr(C)] +pub struct SessionWelcome { + pub multicast_addr: UnvalidatedSocketAddr, pub chunk_size: U16, } -impl_chunk_data!(ConnectionInfo); +impl_chunk_header!(SessionWelcome); + +#[derive(Debug, FromBytes, AsBytes, FromZeroes, Unaligned)] +#[repr(C)] +pub struct SessionHeartbeat; +impl_chunk_header!(SessionHeartbeat); #[derive(Debug, FromBytes, AsBytes, FromZeroes, Unaligned)] #[repr(C)] -pub struct JoinChannel { - pub channel_id: ChannelId, +pub struct GroupJoin { + pub group_id: GroupId, + pub group_type: GroupType, } 
-impl_chunk_data!(JoinChannel); +impl_chunk_header!(GroupJoin); #[derive(Debug, FromBytes, AsBytes, FromZeroes, Unaligned)] #[repr(C)] -pub struct ChannelHeader { - pub channel_id: ChannelId, +pub struct GroupWelcome { + pub group_id: GroupId, pub seq: SequenceNumber, } +impl_chunk_header!(GroupWelcome); #[derive(Debug, FromBytes, AsBytes, FromZeroes, Unaligned)] #[repr(C)] -pub struct ConfirmJoinChannel { - pub header: ChannelHeader, +pub struct GroupAck { + pub group_id: GroupId, + pub seq: SequenceNumber, } -impl_chunk_data!(ConfirmJoinChannel); +impl_chunk_header!(GroupAck); #[derive(Debug, FromBytes, AsBytes, FromZeroes, Unaligned)] #[repr(C)] -pub struct Ack { - pub header: ChannelHeader, -} -impl_chunk_data!(Ack); +pub struct GroupLeave(pub GroupId); +impl_chunk_header!(GroupLeave); + +#[derive(Debug, FromBytes, AsBytes, FromZeroes, Unaligned)] +#[repr(C)] +pub struct GroupDisconnected(pub GroupId); +impl_chunk_header!(GroupDisconnected); #[derive(Debug, FromBytes, AsBytes, FromZeroes, Unaligned)] #[repr(C)] -pub struct Message { - pub header: ChannelHeader, +pub struct BarrierReached { + pub group_id: GroupId, + pub seq: SequenceNumber, } -impl_chunk_data!(Message); +impl_chunk_header!(BarrierReached); #[derive(Debug, FromBytes, AsBytes, FromZeroes, Unaligned)] #[repr(C)] -pub struct JoinBarrierGroup(pub ChannelId); -impl_chunk_data!(JoinBarrierGroup); +pub struct BarrierReleased { + pub group_id: GroupId, + pub seq: SequenceNumber, +} +impl_chunk_header!(BarrierReleased); #[derive(Debug, FromBytes, AsBytes, FromZeroes, Unaligned)] #[repr(C)] -pub struct BarrierReached(pub ChannelHeader); -impl_chunk_data!(BarrierReached); +pub struct BroadcastMessage { + pub group_id: GroupId, + pub seq: SequenceNumber, +} +impl_chunk_header!(BroadcastMessage); #[derive(Debug, FromBytes, AsBytes, FromZeroes, Unaligned)] #[repr(C)] -pub struct BarrierReleased(pub ChannelHeader); -impl_chunk_data!(BarrierReleased); +pub struct BroadcastFirstMessageFragment { + pub group_id: 
GroupId, + pub seq: SequenceNumber, +} +impl_chunk_header!(BroadcastFirstMessageFragment); #[derive(Debug, FromBytes, AsBytes, FromZeroes, Unaligned)] #[repr(C)] -pub struct LeaveChannel(pub ChannelId); -impl_chunk_data!(LeaveChannel); +pub struct BroadcastMessageFragment { + pub group_id: GroupId, + pub seq: SequenceNumber, +} +impl_chunk_header!(BroadcastMessageFragment); #[derive(Debug, FromBytes, AsBytes, FromZeroes, Unaligned)] #[repr(C)] -pub struct ChannelDisconnected(pub ChannelId); -impl_chunk_data!(ChannelDisconnected); +pub struct BroadcastFinalMessageFragment { + pub group_id: GroupId, + pub seq: SequenceNumber, +} +impl_chunk_header!(BroadcastFinalMessageFragment); diff --git a/src/publisher.rs b/src/publisher.rs deleted file mode 100644 index fd19e6f..0000000 --- a/src/publisher.rs +++ /dev/null @@ -1,718 +0,0 @@ -use std::{collections::VecDeque, io::Write, net::SocketAddr, sync::Arc}; - -use ahash::HashSet; -use crossbeam::channel::RecvTimeoutError; -use dashmap::DashSet; -use socket2::{Domain, Protocol, SockAddr, Socket, Type}; - -use crate::{ - chunk::{Chunk, ChunkBuffer, ChunkBufferAllocator}, - chunk_socket::{ChunkSocket, ReceivedChunk}, - multiplex_socket::{ - transmit_to_and_wait, ChunkReceiver, MultiplexSocket, TransmitAndWaitError, - }, - protocol::{ - BarrierReleased, ChannelDisconnected, ChannelHeader, ConfirmJoinChannel, ConnectionInfo, - Message, MESSAGE_PAYLOAD_OFFSET, - }, - ChannelId, SequenceNumber, -}; - -const RETRANSMIT_MILLIS: u64 = 100; - -/// A chunk that has been sent but not yet acknowledged by all subscribers. -struct UnacknowledgedChunk { - sent_time: std::time::Instant, - buffer: ChunkBuffer, - packet_size: usize, - missing_acks: HashSet, - retransmit_count: usize, -} - -pub struct Offer { - socket: ChunkSocket, - offer_id: ChannelId, - - // The sequence number of the last message sent. - seq_sent: SequenceNumber, - - // The sequence number of the last message acknowledged by all subscribers. 
- seq_ack: SequenceNumber, - - // Chunks, that have been sent but not yet acknowledged by all subscribers. - // - // The first element corresponds to self.seq_ack + 1 and the last element corresponds to - // self.seq_sent. - unacknowledged_chunks: VecDeque>, - - used_offer_ids: Arc>, - receiver: ChunkReceiver, - new_clients: HashSet, - clients: HashSet, - buffer_allocator: Arc, - multicast_addr: SockAddr, -} - -impl Drop for Offer { - fn drop(&mut self) { - self.used_offer_ids.remove(&self.offer_id); - } -} - -impl Offer { - // fn unacknowledged_chunk_mut( - // &mut self, - // seq: SequenceNumber, - // ) -> Option<&mut UnacknowledgedChunk> { - // let index = seq.wrapping_sub(self.seq_ack) as usize - 1; - // self.unacknowledged_chunks - // .get_mut(index) - // .and_then(|c| c.as_mut()) - // } - - fn unacknowledged_chunks_count(&self) -> usize { - self.unacknowledged_chunks - .iter() - .filter(|c| c.is_some()) - .count() - } - - fn process_chunk(&mut self, chunk: ReceivedChunk) { - match chunk.validate() { - Ok(Chunk::JoinChannel(_)) => { - log::debug!("received join channel {}", self.offer_id); - self.new_clients.insert(chunk.addr().as_socket().unwrap()); - } - Ok(Chunk::Ack(ack)) => { - let ack_seq: u16 = ack.header.seq.into(); - let offset = ack_seq.wrapping_sub(self.seq_ack).wrapping_sub(1); - log::debug!("received ack from {}", chunk.addr().as_socket().unwrap()); - log::debug!("received ack: {ack_seq} ({offset}) {}", self.seq_sent); - if offset > u16::MAX / 2 { - // this ack is probably from the past - } else if let Some(Some(c)) = self.unacknowledged_chunks.get_mut(offset as usize) { - log::debug!("removing ack from {:?}", c.missing_acks); - c.missing_acks.remove(&chunk.addr().as_socket().unwrap()); - } - } - Ok(chunk) => { - log::debug!("ignore unexpected chunk: {:?}", chunk); - } - Err(err) => { - log::error!("received invalid chunk: {}", err); - } - } - } - - fn process_pending_chunks(&mut self) { - while let Ok(chunk) = self.receiver.try_recv() { - 
self.process_chunk(chunk); - } - } - - fn wait_for_chunk(&mut self) { - if let Ok(chunk) = self.receiver.recv() { - self.process_chunk(chunk); - } - } - - fn wait_for_chunk_timeout(&mut self, timeout: std::time::Duration) { - if let Ok(chunk) = self.receiver.recv_timeout(timeout) { - self.process_chunk(chunk); - } - } - - fn process_unacknlowedged_chunks(&mut self) { - // Retransmit chunks, check for disconnects, and remove acknowledged chunks - let mut clients_to_remove = HashSet::default(); - for (offset, chunk) in self.unacknowledged_chunks.iter_mut().enumerate() { - if let Some(c) = chunk { - if c.missing_acks.is_empty() { - chunk.take(); - continue; - } - - let millis = (1 << c.retransmit_count) * RETRANSMIT_MILLIS; - if c.sent_time.elapsed().as_millis() > millis.into() { - if c.retransmit_count < 5 { - log::debug!( - "retransmitting chunk: {}", - self.seq_ack.wrapping_add(offset as u16).wrapping_add(1) - ); - self.socket - .send_chunk_buffer_to(&c.buffer, c.packet_size, &self.multicast_addr) - .unwrap(); - - // TODO: should we reset the sent time? 
- c.sent_time = std::time::Instant::now(); - c.retransmit_count += 1; - } else { - log::debug!( - "time out for chunk {} and subscribers: {:?}", - offset, - c.missing_acks - ); - clients_to_remove.extend(c.missing_acks.iter().cloned()); - } - } - } - } - if !clients_to_remove.is_empty() { - log::warn!("clients timed out: {:?}", clients_to_remove); - self.clients.retain(|c| !clients_to_remove.contains(c)); - - for c in &mut self.unacknowledged_chunks.iter_mut().flatten() { - c.missing_acks.retain(|a| !clients_to_remove.contains(a)); - } - } - - // Remove acknowledged chunks from the front - while let Some(None) = self.unacknowledged_chunks.front() { - log::debug!("removing acknowledged chunk"); - self.unacknowledged_chunks.pop_front(); - self.seq_ack = self.seq_ack.wrapping_add(1); - } - } - - fn process(&mut self) { - // Process pending acks etc - self.process_pending_chunks(); - - // Retransmit chunks - self.process_unacknlowedged_chunks(); - } - - fn process_blocking(&mut self) { - self.wait_for_chunk(); - self.process(); - } - - pub fn id(&self) -> ChannelId { - self.offer_id - } - - pub fn has_subscribers(&self) -> bool { - !self.clients.is_empty() - } - - pub fn accept(&mut self) -> Option { - self.process(); - - if let Some(client) = self - .new_clients - .iter() - .next() - .cloned() - .and_then(|q| self.new_clients.take(&q)) - { - let mut retries = 0; - - 'outer: while retries < 5 { - self.socket - .send_chunk_to( - &ConfirmJoinChannel { - header: ChannelHeader { - channel_id: self.offer_id.into(), - seq: self.seq_sent.into(), - }, - }, - &client.into(), - ) - .unwrap(); - - let start = std::time::Instant::now(); - - // Wait for ack - while start.elapsed().as_secs() < 1 { - if let Ok(chunk) = self.receiver.try_recv() { - match chunk.validate() { - Ok(Chunk::Ack(ack)) => { - if >::into( - ack.header.seq, - ) == self.seq_sent - && chunk.addr() == &client.into() - { - self.clients.insert(client); - break 'outer; - } - } - Ok(_) => { - 
self.process_chunk(chunk); - } - Err(err) => { - log::error!("received invalid chunk: {}", err); - } - } - } - } - - retries += 1; - log::debug!("retrying join channel"); - } - - Some(client) - } else { - None - } - } - - fn send_ack_chunk_buffer( - &mut self, - chunk: ChunkBuffer, - packet_size: usize, - ) -> Result<(), std::io::Error> { - self.process(); - - while self.unacknowledged_chunks_count() > 100 { - log::debug!("too many unacknowledged chunks, blockin!"); - self.wait_for_chunk_timeout(std::time::Duration::from_millis(RETRANSMIT_MILLIS)); - self.process(); - } - - self.socket - .send_chunk_buffer_to(&chunk, packet_size, &self.multicast_addr)?; - // TODO: verify that the ack in the chunk is the same as self.seq_sent + 1 - self.unacknowledged_chunks - .push_back(Some(UnacknowledgedChunk { - sent_time: std::time::Instant::now(), - buffer: chunk, - packet_size, - missing_acks: self.clients.clone(), - retransmit_count: 0, - })); - Ok(()) - } - - pub fn write_message(&mut self) -> MessageWriter { - MessageWriter { - offer: self, - buffer: None, - cursor: MESSAGE_PAYLOAD_OFFSET, - } - } - - pub fn flush(&mut self) { - while self.seq_ack < self.seq_sent { - self.process_blocking(); - } - } -} - -pub struct MessageWriter<'a> { - offer: &'a mut Offer, - buffer: Option, - cursor: usize, -} - -impl Drop for MessageWriter<'_> { - fn drop(&mut self) { - // Ignore errors - let _ = self.flush(); - } -} - -impl Write for MessageWriter<'_> { - fn write(&mut self, buf: &[u8]) -> std::io::Result { - let mut src_bytes = buf; - - while !src_bytes.is_empty() { - let mut buffer = self.buffer.take().unwrap_or_else(|| { - let mut buffer = self.offer.buffer_allocator.allocate(); - let seq = self.offer.seq_sent.wrapping_add(1); - buffer.init::(&Message { - header: ChannelHeader { - channel_id: self.offer.offer_id.into(), - seq: seq.into(), - }, - }); - self.offer.seq_sent = seq; - buffer - }); - - let remaining_buffer = &mut buffer[self.cursor..]; - let len = 
remaining_buffer.len().min(src_bytes.len()); - remaining_buffer[..len].copy_from_slice(&src_bytes[..len]); - self.cursor += len; - - if self.cursor == buffer.len() { - self.offer.send_ack_chunk_buffer(buffer, self.cursor)?; - self.cursor = MESSAGE_PAYLOAD_OFFSET; - } else { - self.buffer = Some(buffer); - } - src_bytes = &src_bytes[len..]; - } - - Ok(buf.len()) - } - - fn flush(&mut self) -> std::io::Result<()> { - if let Some(buffer) = self.buffer.take() { - self.offer.send_ack_chunk_buffer(buffer, self.cursor)?; - self.cursor = MESSAGE_PAYLOAD_OFFSET; - } - Ok(()) - } -} - -pub struct BarrierGroupDesc { - /// The amount of time to wait before retransmitting the barrier release message. - pub retransmit_timeout: std::time::Duration, - - /// The number of times the barrier release message is retransmitted, before closing the - /// connection to clients that have not acknowledged the release. - pub retransmit_count: usize, -} - -#[derive(Debug, Default)] -struct BarrierGroupState { - new_clients: HashSet, - clients: HashSet, - arrived: HashSet, - seq: SequenceNumber, -} - -impl BarrierGroupState { - /// Returns true if the client is connected to the barrier group and false otherwise. - fn client_reached_barrier(&mut self, client: SocketAddr, seq: SequenceNumber) -> bool { - if self.clients.contains(&client) { - if self.seq == seq { - self.arrived.insert(client); - } - true - } else { - false - } - } - - /// Returns if all remotes have arrived at the barrier. 
- fn all_remotes_arrived(&self) -> bool { - debug_assert!(self.clients.len() != self.arrived.len() || self.clients == self.arrived); - self.clients.len() == self.arrived.len() - } - - /// Processes a single chunk - fn process_chunk(&mut self, chunk: Chunk, addr: SocketAddr) -> bool { - match chunk { - Chunk::JoinBarrierGroup(_) => { - self.new_clients.insert(addr); - true - } - Chunk::BarrierReached(reached) => { - if self.clients.contains(&addr) { - self.client_reached_barrier(addr, reached.0.seq.into()); - true - } else { - log::warn!("received barrier reached from non-client"); - false - } - } - Chunk::LeaveChannel(_) => { - if self.clients.contains(&addr) { - self.clients.remove(&addr); - self.arrived.remove(&addr); - } - false - } - _ => { - log::warn!("received invalid chunk: {chunk:?}"); - self.clients.contains(&addr) - } - } - } -} - -pub struct BarrierGroup { - channel_id: ChannelId, - desc: BarrierGroupDesc, - state: BarrierGroupState, - receiver: ChunkReceiver, - socket: ChunkSocket, - multicast_addr: SocketAddr, -} - -impl BarrierGroup { - fn try_process(&mut self) -> bool { - let mut processed = false; - while let Ok(chunk) = self.receiver.try_recv() { - if let (Ok(chunk), Some(addr)) = (chunk.validate(), chunk.addr().as_socket()) { - if !self.state.process_chunk(chunk, addr) { - let _ = self - .socket - .send_chunk_to(&ChannelDisconnected(self.channel_id.into()), &addr.into()); - } - processed = true; - } - } - processed - } - - fn process(&mut self) { - if let Ok(chunk) = self.receiver.recv() { - if let (Ok(chunk), Some(addr)) = (chunk.validate(), chunk.addr().as_socket()) { - if !self.state.process_chunk(chunk, addr) { - let _ = self - .socket - .send_chunk_to(&ChannelDisconnected(self.channel_id.into()), &addr.into()); - } - } - } - self.try_process(); - } - - pub fn accept_client(&mut self, client: SocketAddr) -> Result<(), TransmitAndWaitError> { - transmit_to_and_wait( - &self.socket, - &client, - &ConfirmJoinChannel { - header: ChannelHeader { 
- channel_id: self.channel_id.into(), - seq: self.state.seq.into(), - }, - }, - self.desc.retransmit_timeout, - self.desc.retransmit_count, - &self.receiver, - |chunk, addr| { - if let Chunk::Ack(ack) = chunk { - let ack_seq: u16 = ack.header.seq.into(); - if ack_seq == self.state.seq && addr == client { - log::debug!("client {} joined barrier group", client); - self.state.clients.insert(addr); - return true; - } - } else { - self.state.process_chunk(chunk, addr); - } - false - }, - ) - } - - pub fn try_accept(&mut self) -> Result { - self.try_process(); - - if let Some(client) = self - .state - .new_clients - .iter() - .next() - .copied() - .and_then(|q| self.state.new_clients.take(&q)) - { - log::debug!("accepting client {}", client); - self.accept_client(client)?; - Ok(client) - } else { - Err(TransmitAndWaitError::RecvError(RecvTimeoutError::Timeout)) - } - } - - pub fn has_remotes(&self) -> bool { - !self.state.clients.is_empty() - } - - pub fn try_wait(&mut self) -> bool { - self.try_process(); - if self.state.all_remotes_arrived() { - self.wait(); - true - } else { - false - } - } - - pub fn wait(&mut self) { - // Wait until everyone has arrived - if !self.state.all_remotes_arrived() { - self.process(); - } - - // Release the barrier - let release_seq = self.state.seq.into(); - // Alread increment the seq here, to allow remotes to allready confirm the next barrier - // while we are waiting for the acks. 
- self.state.seq = self.state.seq.wrapping_add(1); - let mut missing_acks = self.state.arrived.clone(); - self.state.arrived.clear(); - - let release_time = std::time::Instant::now(); - - let _ = transmit_to_and_wait( - &self.socket, - &self.multicast_addr, - &BarrierReleased(ChannelHeader { - channel_id: self.channel_id.into(), - seq: release_seq, - }), - self.desc.retransmit_timeout, - self.desc.retransmit_count, - &self.receiver, - |chunk, addr| { - match chunk { - Chunk::Ack(ack) => { - if self.state.clients.contains(&addr) { - if release_seq == ack.header.seq { - missing_acks.remove(&addr); - } - } else { - log::warn!("received ack from non-client"); - let _ = self.socket.send_chunk_to( - &ChannelDisconnected(self.channel_id.into()), - &addr.into(), - ); - } - } - Chunk::BarrierReached(reached) => { - if self.state.clients.contains(&addr) { - let reached_seq: u16 = reached.0.seq.into(); - if reached_seq == self.state.seq { - missing_acks.remove(&addr); - self.state.arrived.insert(addr); - } - } else { - log::warn!("received barrier reached from non-client"); - let _ = self.socket.send_chunk_to( - &ChannelDisconnected(self.channel_id.into()), - &addr.into(), - ); - } - } - _ => { - if !self.state.process_chunk(chunk, addr) { - missing_acks.remove(&addr); - let _ = self.socket.send_chunk_to( - &ChannelDisconnected(self.channel_id.into()), - &addr.into(), - ); - } - } - } - - missing_acks.is_empty() - }, - ); - log::debug!( - "barrier released confirmation time: {:?}", - release_time.elapsed() - ); - - if !missing_acks.is_empty() { - log::warn!("clients timed out: {:?}", missing_acks); - for c in &missing_acks { - let _ = self - .socket - .send_chunk_to(&ChannelDisconnected(self.channel_id.into()), &(*c).into()); - self.state.clients.remove(c); - debug_assert!(!self.state.arrived.contains(c)); - } - } - } -} - -pub struct PublisherConfig { - pub addr: std::net::SocketAddrV4, - pub multicast_addr: std::net::SocketAddrV4, - pub chunk_size: u16, -} - -struct 
ClientConnection { - addr: std::net::SocketAddr, -} - -pub struct Publisher { - used_channel_ids: Arc>, - socket: MultiplexSocket, - multicast_addr: std::net::SocketAddr, - buffer_allocator: Arc, -} - -#[derive(thiserror::Error, Debug)] -pub enum CreateChannelError { - #[error("offer limit reached")] - ChannelLimitReached, - - #[error("I/O error: {0}")] - Io(#[from] std::io::Error), -} - -impl Publisher { - pub fn new(config: PublisherConfig) -> Self { - let socket = Socket::new(Domain::IPV4, Type::DGRAM, Some(Protocol::UDP)).unwrap(); - socket.bind(&config.addr.into()).unwrap(); - log::debug!("bound to {}", config.addr); - let buffer_allocator = Arc::new(ChunkBufferAllocator::new(config.chunk_size.into())); - // let multicast_addr = SocketAddr::V4(config.multicast_addr); - - let handle_unchannelled = move |socket: &ChunkSocket, chunk: ReceivedChunk| { - if let Ok(Chunk::Connect(_)) = chunk.validate() { - if let Err(err) = socket.send_chunk_to( - &ConnectionInfo { - chunk_size: config.chunk_size.into(), - multicast_addr: config.multicast_addr.ip().octets(), - multicast_port: config.multicast_addr.port().into(), - }, - chunk.addr(), - ) { - log::error!("failed to send connection info: {}", err); - } - } - }; - - Publisher { - used_channel_ids: Arc::new(DashSet::new()), - socket: MultiplexSocket::with_unchannelled_handler( - socket, - buffer_allocator.clone(), - handle_unchannelled, - ) - .unwrap(), - buffer_allocator, - multicast_addr: config.multicast_addr.into(), - } - } - - pub fn create_barrier_group( - &self, - desc: BarrierGroupDesc, - ) -> Result { - for offer_id in 0..=ChannelId::MAX { - if self.used_channel_ids.insert(offer_id) { - let receiver = self.socket.listen_to_channel(offer_id); - return Ok(BarrierGroup { - channel_id: offer_id, - state: BarrierGroupState::default(), - receiver, - desc, - socket: self.socket.socket().try_clone()?, - multicast_addr: self.multicast_addr, - }); - } - } - - Err(CreateChannelError::ChannelLimitReached) - } - - pub fn 
create_offer(&self) -> Result { - for offer_id in 0..=ChannelId::MAX { - if self.used_channel_ids.insert(offer_id) { - let receiver = self.socket.listen_to_channel(offer_id); - return Ok(Offer { - socket: self.socket.socket().try_clone()?, - offer_id, - seq_sent: 0, - seq_ack: 0, - unacknowledged_chunks: VecDeque::new(), - used_offer_ids: self.used_channel_ids.clone(), - receiver, - new_clients: HashSet::default(), - clients: HashSet::default(), - buffer_allocator: self.buffer_allocator.clone(), - multicast_addr: self.multicast_addr.into(), - }); - } - } - - Err(CreateChannelError::ChannelLimitReached) - } -} diff --git a/src/session.rs b/src/session.rs new file mode 100644 index 0000000..1da281e --- /dev/null +++ b/src/session.rs @@ -0,0 +1,562 @@ +use std::{ + net::{IpAddr, Ipv4Addr, SocketAddr}, + num::NonZeroUsize, + sync::Arc, +}; + +use ahash::HashSet; +use dashmap::DashMap; +use socket2::{Domain, Protocol, SockAddr, Socket, Type}; +use zerocopy::FromBytes; + +use crate::{ + barrier::{ + BarrierGroupCoordinator, BarrierGroupCoordinatorState, BarrierGroupMember, + BarrierGroupMemberState, + }, broadcast::{BroadcastGroupReceiver, BroadcastGroupReceiverState, BroadcastGroupSender, BroadcastGroupSenderState}, chunk::{Chunk, ChunkBufferAllocator}, group::{GroupCoordinator, GroupCoordinatorTypeImpl, GroupMember, GroupMemberTypeImpl}, multiplex_socket::{Callback, CallbackReason, MultiplexSocket}, protocol::{ + self, GroupType, SessionHeartbeat, SessionJoin, SessionWelcome, CHUNK_ID_SESSION_WELCOME, + WELCOME_PACKET_SIZE, + }, utils::ExponentialBackoff +}; + +pub(crate) struct MemberVitals { + timeout: std::time::Duration, + received_heartbeats: DashMap, +} + +impl MemberVitals { + pub fn new(timeout: std::time::Duration) -> Self { + Self { + timeout, + received_heartbeats: DashMap::new(), + } + } + + pub fn update_heartbeat(&self, addr: SockAddr) { + self.received_heartbeats + .insert(addr, std::time::Instant::now()); + } + + pub fn is_alive(&self, addr: &SockAddr) 
-> bool {
+        self.received_heartbeats
+            .get(addr)
+            .map(|instant| instant.elapsed() < self.timeout)
+            .unwrap_or(false)
+    }
+}
+
+/// Indicates an error within the multicast coordinator configuration.
+#[derive(thiserror::Error, Debug)]
+pub enum ConfigError {
+    /// The chunk size is too small.
+    ///
+    /// The minimum chunk size is 508 bytes.
+    #[error("Chunk size too small (minimum is 508)")]
+    ChunkSizeTooSmall,
+
+    /// The heartbeat interval is zero.
+    #[error("Heartbeat interval must be greater than zero")]
+    HeartbeatIntervalZero,
+
+    /// The client timeout is zero.
+    #[error("Client timeout must be greater than zero")]
+    ClientTimeoutZero,
+
+    /// The client timeout is less than the heartbeat interval.
+    #[error("Client timeout must be greater than the heartbeat interval")]
+    ClientTimeoutLessThanHeartbeat,
+}
+
+/// Additional configuration for the multicast coordinator.
+pub struct CoordinatorConfig {
+    /// The chunk size represents the maximum payload size of a socket.
+    ///
+    /// To avoid fragmentation, the chunk size plus the size of the protocol
+    /// header should not exceed the maximum transmission unit (MTU) of the
+    /// network. So, for a typical MTU of 1500, this should be set to a
+    /// maximum of 1472.
+    ///
+    /// This value must be at least 508 as this defines the maximum size of a
+    /// packet that is deliverable over the internet.
+    ///
+    /// The default is 1472.
+    pub chunk_size: u16,
+
+    /// The interval at which the clients should send heartbeats.
+    ///
+    /// This must be less than the client timeout.
+    ///
+    /// The default is 1 second.
+    pub heartbeat_interval: std::time::Duration,
+
+    /// The time to wait for a client to send a heartbeat before considering it
+    /// disconnected.
+    ///
+    /// The default is 5 seconds.
+    pub client_timeout: std::time::Duration,
+}
+
+impl CoordinatorConfig {
+    /// Validates the configuration.
+    ///
+    /// See [`CoordinatorConfig`] and [`ConfigError`] for details.
+    pub fn validate(&self) -> Result<(), ConfigError> {
+        if self.chunk_size < 508 {
+            return Err(ConfigError::ChunkSizeTooSmall);
+        }
+
+        if self.heartbeat_interval.is_zero() {
+            return Err(ConfigError::HeartbeatIntervalZero);
+        }
+
+        if self.client_timeout.is_zero() {
+            return Err(ConfigError::ClientTimeoutZero);
+        }
+
+        if self.client_timeout <= self.heartbeat_interval {
+            return Err(ConfigError::ClientTimeoutLessThanHeartbeat);
+        }
+        Ok(())
+    }
+}
+
+impl Default for CoordinatorConfig {
+    fn default() -> Self {
+        Self {
+            chunk_size: 1472,
+            heartbeat_interval: std::time::Duration::from_secs(1),
+            client_timeout: std::time::Duration::from_secs(5),
+        }
+    }
+}
+
+/// The identifier of a channel.
+pub type GroupId = u16;
+
+/// Indicates an error during channel creation.
+#[derive(thiserror::Error, Debug)]
+pub enum GroupCreateError {
+    /// The channel ID is already in use.
+    #[error("Group ID {0} is already in use")]
+    GroupIdInUse(GroupId),
+
+    #[error("Group IDs exhausted")]
+    GroupIdsExhausted,
+}
+
+/// Indicates an error when initializing the multicast session.
+#[derive(thiserror::Error, Debug)]
+pub enum StartSessionError {
+    /// The configuration is invalid.
+    ///
+    /// See [`ConfigError`] for details.
+    #[error("Configuration error: {0}")]
+    ConfigError(#[from] ConfigError),
+
+    /// The multicast address is not a valid multicast address.
+    ///
+    /// See [`Coordinator::start_session`] for details.
+    #[error("Invalid multicast address: {0}")]
+    NotAMulticastAddress(SocketAddr),
+
+    /// Failed to create the socket.
+    #[error("Failed to create socket: {0}")]
+    SocketCreateError(std::io::Error),
+
+    /// Failed to bind the socket.
+    #[error("Failed to bind socket: {0}")]
+    SocketBindError(std::io::Error),
+}
+
+/// The server instance for multicast communication.
+pub struct Coordinator {
+    client_vitals: Arc<MemberVitals>,
+    socket: Arc<MultiplexSocket>,
+    multicast_address: SockAddr,
+}
+
+impl Coordinator {
+    /// Starts a session using the given bind address and multicast address.
+    ///
+    /// The bind address is the address that members need to join the session
+    /// (using [`Member::join_session`]). The multicast address is used for
+    /// communication. For IPv4, this must be in the range of `224.0.0.0/4`,
+    /// i.e., the most significant octet must be in the range of 224 to 239 as
+    /// described in [IETF RFC 5771](https://tools.ietf.org/html/rfc5771). For
+    /// IPv6, it must be in the range of `ff00::/8` as described in [IETF RFC
+    /// 4291](https://datatracker.ietf.org/doc/html/rfc4291).
+    ///
+    /// See [`Self::start_session_with_config`] for additional configuration
+    /// options.
+    pub fn start_session(
+        bind_address: SocketAddr,
+        multicast_address: SocketAddr,
+    ) -> Result<Self, StartSessionError> {
+        Self::start_session_with_config(
+            bind_address,
+            multicast_address,
+            CoordinatorConfig::default(),
+        )
+    }
+
+    /// Binds the server to the given address and multicast address with
+    /// additional configuration.
+    ///
+    /// See [`CoordinatorConfig`] for the available configuration options
+    /// and [`Self::start_session`] for more information regarding the address
+    /// parameters.
+ pub fn start_session_with_config( + bind_address: SocketAddr, + multicast_address: SocketAddr, + config: CoordinatorConfig, + ) -> Result { + config.validate()?; + + if !multicast_address.ip().is_multicast() { + return Err(StartSessionError::NotAMulticastAddress(multicast_address)); + } + + let client_vitals = Arc::new(MemberVitals::new(config.client_timeout)); + + let chunk_allocator = Arc::new(ChunkBufferAllocator::new(config.chunk_size.into())); + + let socket = Socket::new( + Domain::for_address(bind_address), + Type::DGRAM, + Some(Protocol::UDP), + ) + .map_err(StartSessionError::SocketCreateError)?; + + socket + .bind(&bind_address.into()) + .map_err(StartSessionError::SocketBindError)?; + + let socket = MultiplexSocket::with_callback( + socket, + chunk_allocator, + Self::create_callback(multicast_address, config.chunk_size, client_vitals.clone()), + ); + + Ok(Self { + client_vitals, + socket, + multicast_address: multicast_address.into(), + }) + } + + /// Create the callback for unhandled chunks + fn create_callback( + multicast_address: SocketAddr, + chunk_size: u16, + member_vitals: Arc, + ) -> impl Callback { + move |socket, reason| match reason { + CallbackReason::UnhandledChunk { + addr, + chunk: Chunk::SessionHeartbeat(SessionHeartbeat), + } => { + member_vitals.update_heartbeat(addr.clone()); + } + CallbackReason::UnhandledChunk { + addr, + chunk: Chunk::SessionJoin(SessionJoin), + } => { + // If the message is not sent, the member will not join the session which is + // probably what we want + if socket + .send_chunk_to( + &SessionWelcome { + multicast_addr: multicast_address.into(), + chunk_size: chunk_size.into(), + }, + addr, + ) + .is_ok() + { + member_vitals.update_heartbeat(addr.clone()); + } + } + _ => {} + } + } + + fn create_group( + &self, + desired_channel_id: Option, + receive_capacity: Option, + inner: I, + ) -> Result, GroupCreateError> { + let channel = if let Some(channel_id) = desired_channel_id { + self.socket + 
.allocate_channel(channel_id.into(), receive_capacity)
+                .ok_or(GroupCreateError::GroupIdInUse(channel_id))?
+        } else {
+            let mut channel = None;
+            for channel_id in 1..=GroupId::MAX {
+                if let Some(c) = self
+                    .socket
+                    .allocate_channel(channel_id.into(), receive_capacity)
+                {
+                    channel = Some(c);
+                    break;
+                }
+            }
+            channel.ok_or(GroupCreateError::GroupIdsExhausted)?
+        };
+
+        Ok(GroupCoordinator {
+            channel,
+            state: Default::default(),
+            inner,
+            multicast_addr: self.multicast_address.clone(),
+            session_members: self.client_vitals.clone(),
+        })
+    }
+
+    pub fn create_barrier_group(
+        &self,
+        desired_channel_id: Option<GroupId>,
+    ) -> Result<BarrierGroupCoordinator, GroupCreateError> {
+        let group = self.create_group(
+            desired_channel_id,
+            Some(1024.try_into().unwrap()),
+            BarrierGroupCoordinatorState::default(),
+        )?;
+        Ok(BarrierGroupCoordinator { group })
+    }
+
+    pub fn create_broadcast_group(
+        &self,
+        desired_channel_id: Option<GroupId>,
+    ) -> Result<BroadcastGroupSender, GroupCreateError> {
+        let group = self.create_group(
+            desired_channel_id,
+            Some(1024.try_into().unwrap()),
+            BroadcastGroupSenderState::default(),
+        )?;
+        Ok(BroadcastGroupSender::new(group))
+    }
+}
+
+/// Indicates an error when joining a multicast session.
+#[derive(thiserror::Error, Debug)]
+pub enum JoinSessionError {
+    /// Failed to create the socket.
+    #[error("Failed to create socket: {0}")]
+    SocketCreateError(std::io::Error),
+
+    /// Failed to bind multicast socket.
+    #[error("Failed to bind multicast socket: {0}")]
+    SocketBindError(std::io::Error),
+
+    /// Failed to connect to the coordinator address.
+    #[error("Failed to connect socket: {0}")]
+    SocketConnectError(std::io::Error),
+
+    /// Failed to send join session request.
+    #[error("Failed to send join session request: {0}")]
+    SendError(std::io::Error),
+
+    /// Failed to receive welcome packet.
+    #[error("Failed to receive welcome packet: {0}")]
+    RecvError(std::io::Error),
+
+    /// Invalid response from coordinator.
+ #[error("Invalid response from coordinator")] + InvalidResponse, +} + +/// Indicates an error when joining a group. +#[derive(thiserror::Error, Debug)] +pub enum JoinGroupError { + /// THe group was already joined. + #[error("Group ID {0} is already joined")] + AlreadyJoined(GroupId), + + #[error("Group ID {0} is not joined")] + IoError(#[from] std::io::Error), +} + +/// A member of a multicast session. +pub struct Member { + coordinator_socket: Arc, + multicast_socket: Arc, + coordinator_address: SockAddr, +} + +impl Member { + fn receive_welcome( + coordinator_socket: &Socket, + coordinator_address: &SockAddr, + ) -> Result<[u8; WELCOME_PACKET_SIZE], JoinSessionError> { + for deadline in ExponentialBackoff::new() { + coordinator_socket + .send_to(&[0], &coordinator_address) + .map_err(JoinSessionError::SendError)?; + + coordinator_socket + .set_read_timeout(Some(deadline - std::time::Instant::now())) + .map_err(JoinSessionError::RecvError)?; + + // TODO: use [`std::mem::MaybeUninit::uninit_array`] once it is stable + let mut buffer = [std::mem::MaybeUninit::::uninit(); WELCOME_PACKET_SIZE]; + let packet_size = match coordinator_socket.recv(&mut buffer) { + Ok(packet_size) => packet_size, + Err(err) + if err.kind() == std::io::ErrorKind::WouldBlock + || err.kind() == std::io::ErrorKind::TimedOut => + { + continue + } + Err(err) => return Err(JoinSessionError::RecvError(err)), + }; + if packet_size != WELCOME_PACKET_SIZE { + return Err(JoinSessionError::InvalidResponse); + } + // TODO: use [`std::mem::MaybeUninit::array_assume_init`] once it is stable + let packet = unsafe { std::mem::transmute::<_, &[u8; WELCOME_PACKET_SIZE]>(&buffer) }; + if packet[0] != CHUNK_ID_SESSION_WELCOME { + return Err(JoinSessionError::InvalidResponse); + } else { + return Ok(*packet); + } + } + unreachable!() + } + + /// Attempts to join a multicast session. + /// + /// The `addr` parameter corresponds to the `bind_addr` parameter of + /// [`Coordinator::start_session`]. 
+ pub fn join_session(addr: SocketAddr) -> Result { + let coordinator_socket = + Socket::new(Domain::for_address(addr), Type::DGRAM, Some(Protocol::UDP)) + .map_err(JoinSessionError::SocketCreateError)?; + + let coordinator_address = addr.into(); + + // Don't connect, otherwise sent_to will not work + // coordinator_socket + // .connect(&addr.into()) + // .map_err(JoinSessionError::SocketConnectError)?; + + let packet = Self::receive_welcome(&coordinator_socket, &coordinator_address)?; + + let welcome = match SessionWelcome::ref_from(&packet[1..]) { + Some(welcome) => welcome, + + // The alignment of `Welcome` is one and the size matches, so this should never happen + None => unreachable!(), + }; + + coordinator_socket + .set_read_timeout(None) + .map_err(JoinSessionError::RecvError)?; + + let multicast_addr: SocketAddr = (&welcome.multicast_addr) + .try_into() + .map_err(|_| JoinSessionError::InvalidResponse)?; + + if !multicast_addr.ip().is_multicast() { + return Err(JoinSessionError::InvalidResponse); + } + + let multicast_socket = Socket::new(Domain::IPV4, Type::DGRAM, Some(Protocol::UDP)) + .map_err(JoinSessionError::SocketCreateError)?; + multicast_socket + .set_reuse_address(true) + .map_err(JoinSessionError::SocketCreateError)?; + multicast_socket + .bind(&multicast_addr.into()) + .map_err(JoinSessionError::SocketBindError)?; + + match multicast_addr.ip() { + IpAddr::V4(ip) => { + multicast_socket + .join_multicast_v4(&ip.into(), &Ipv4Addr::UNSPECIFIED) + .map_err(JoinSessionError::SocketBindError)?; + } + IpAddr::V6(ip) => { + multicast_socket + .join_multicast_v6(&ip.into(), 0) + .map_err(JoinSessionError::SocketBindError)?; + } + } + + let buffer_allocator = Arc::new(ChunkBufferAllocator::new(welcome.chunk_size.into())); + + let coordinator_address_copy = coordinator_address.clone(); + let coordinator_socket = MultiplexSocket::with_callback( + coordinator_socket, + buffer_allocator.clone(), + move |socket, reason| match reason { + CallbackReason::Timeout 
=> { + // We ran into a timeout, which is set to the heartbeat interval + // TODO: what to do in case of an error? + let _ = socket.send_chunk_to(&SessionHeartbeat, &coordinator_address_copy); + } + _ => {} + }, + ); + let multicast_socket = + MultiplexSocket::with_sender_socket(multicast_socket, coordinator_socket.clone()); + + Ok(Self { + coordinator_socket, + multicast_socket, + coordinator_address, + }) + } + + fn join_group( + &self, + channel_id: GroupId, + inner: I, + ) -> Result, JoinGroupError> { + let coordinator_channel = self + .coordinator_socket + .allocate_channel(channel_id.into(), Some(1024.try_into().unwrap())) + .ok_or(JoinGroupError::AlreadyJoined(channel_id))?; + + let multicast_channel = self + .multicast_socket + .allocate_channel(channel_id.into(), Some(1024.try_into().unwrap())) + .ok_or(JoinGroupError::AlreadyJoined(channel_id))?; + + GroupMember::join( + self.coordinator_address.clone(), + coordinator_channel, + multicast_channel, + inner, + ) + .map_err(JoinGroupError::IoError) + } + + pub fn join_barrier_group( + &self, + channel_id: GroupId, + ) -> Result { + let group = self.join_group(channel_id, BarrierGroupMemberState::default())?; + Ok(BarrierGroupMember { group }) + } + + pub fn join_broadcast_group( + &self, + channel_id: GroupId, + ) -> Result { + let group = self.join_group(channel_id, BroadcastGroupReceiverState::default())?; + Ok(BroadcastGroupReceiver { group }) + } +} + +#[test] +fn join_session() -> Result<(), Box> { + let port = crate::test::get_port(); + let bind_addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), port); + let multicast_addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(224, 0, 0, 1)), port); + + let _coordinator = Coordinator::start_session(bind_addr, multicast_addr)?; + let _member = Member::join_session(bind_addr)?; + + Ok(()) +} diff --git a/src/subscriber.rs b/src/subscriber.rs deleted file mode 100644 index c502354..0000000 --- a/src/subscriber.rs +++ /dev/null @@ -1,395 +0,0 @@ -use 
std::{ - collections::VecDeque, - io::Read, - net::{Ipv4Addr, SocketAddr, SocketAddrV4}, - sync::Arc, - time::Duration, -}; - -use ahash::HashSet; -use crossbeam::channel::RecvTimeoutError; -use socket2::{Domain, Protocol, Socket, Type}; -use zerocopy::FromBytes; - -use crate::{ - chunk::{Chunk, ChunkBufferAllocator}, - chunk_socket::{ChunkSocket, ReceivedChunk}, - multiplex_socket::{transmit_and_wait, ChunkReceiver, MultiplexSocket}, - protocol::{ - kind, Ack, BarrierReached, ChannelHeader, ChunkKindData, ConnectionInfo, JoinBarrierGroup, - JoinChannel, MESSAGE_PAYLOAD_OFFSET, - }, - ChannelId, SequenceNumber, -}; - -#[derive(thiserror::Error, Debug)] -pub enum ConnectionError { - #[error("failed to connect to server")] - Io(#[from] std::io::Error), - - #[error("failed to receive connection info")] - ConnectionInfo, -} - -#[derive(thiserror::Error, Debug)] -pub enum JoinChannelError { - #[error("I/O error: {0}")] - Io(#[from] std::io::Error), - - #[error("operation timed out")] - Timeout, - - #[error("already joined channel")] - AlreadyJoined, -} - -pub enum Message { - SingleChunk(ReceivedChunk), -} - -impl Message { - pub fn read(&self) -> MessageReader { - match self { - Message::SingleChunk(chunk) => MessageReader::Single { - chunk, - cursor: MESSAGE_PAYLOAD_OFFSET, - }, - } - } -} - -pub enum MessageReader<'a> { - Single { - chunk: &'a ReceivedChunk, - cursor: usize, - }, -} - -impl Read for MessageReader<'_> { - fn read(&mut self, buf: &mut [u8]) -> Result { - match self { - MessageReader::Single { chunk, cursor } => { - let packet_size = chunk.packet_size(); - let buffer = chunk.buffer(); - let remaining_buffer = &buffer[*cursor..packet_size]; - let len = remaining_buffer.len().min(buf.len()); - buf[..len].copy_from_slice(&remaining_buffer[..len]); - *cursor += len; - Ok(len) - } - } - } -} - -pub struct Subscription { - control_receiver: ChunkReceiver, - multicast_receiver: ChunkReceiver, - buffer_allocator: Arc, - sequence: SequenceNumber, - 
control_socket: ChunkSocket, - - /// stores the chunks starting from the last received sequence number. - chunks: VecDeque>, -} - -#[derive(thiserror::Error, Debug)] -pub enum RecvError { - #[error("I/O error: {0}")] - Io(#[from] std::io::Error), - - #[error("Disconnected")] - Recv(#[from] crossbeam::channel::RecvError), -} - -impl Subscription { - pub fn recv(&mut self) -> Result { - loop { - let chunk = self.multicast_receiver.recv()?; - match chunk.validate() { - Ok(Chunk::Message(msg, _)) => { - self.control_socket.send_chunk(&Ack { - header: ChannelHeader { - seq: msg.header.seq, - channel_id: msg.header.channel_id, - }, - })?; - - let seq: u16 = msg.header.seq.into(); - let offset = seq.wrapping_sub(self.sequence.wrapping_add(1)); - if offset > u16::MAX / 2 { - // This is most likely an old packet, just ignore it - } else if seq != self.sequence.wrapping_add(1) { - panic!( - "unexpected sequence number: expected {}, got {}", - self.sequence.wrapping_add(1), - seq - ); - } else { - log::debug!("received message: {:?}", msg); - self.sequence = seq; - return Ok(Message::SingleChunk(chunk)); - } - } - Ok(chunk) => { - log::debug!("ignore unexpected chunk: {:?}", chunk); - } - Err(err) => { - log::error!("received invalid chunk: {}", err); - } - } - } - } -} - -#[derive(Debug)] -pub struct Disconnected; - -const CONNECTION_INFO_PACKET_SIZE: usize = std::mem::size_of::() + 1; - -pub enum Barrier { - Connected { - channel_id: ChannelId, - control_receiver: ChunkReceiver, - multicast_receiver: ChunkReceiver, - seq: SequenceNumber, - control_socket: ChunkSocket, - timeout: std::time::Duration, - }, - Disconnected, -} - -impl Barrier { - pub fn wait(&mut self) -> Result<(), Disconnected> { - match self { - Barrier::Connected { - channel_id, - control_receiver, - multicast_receiver, - seq, - control_socket, - timeout, - } => { - #[derive(Debug)] - enum Result { - BarrierReached, - Disconnected, - } - - let result = transmit_and_wait( - control_socket, - 
&BarrierReached(ChannelHeader { - seq: (*seq).into(), - channel_id: (*channel_id).into(), - }), - *timeout, - 5, - &[multicast_receiver, control_receiver], - |chunk, _| match chunk { - Chunk::BarrierReleased(release) => { - let _ = control_socket.send_chunk(&Ack { - header: ChannelHeader { - seq: release.0.seq, - channel_id: release.0.channel_id, - }, - }); - - let release_seq: u16 = release.0.seq.into(); - if release_seq == *seq { - Some(Result::BarrierReached) - } else { - None - } - } - Chunk::ChannelDisconnected(_) => Some(Result::Disconnected), - _ => None, - }, - ); - - match result { - Ok(Result::BarrierReached) => { - *seq = seq.wrapping_add(1); - Ok(()) - } - _ => { - log::error!("disconnected: {result:?}"); - *self = Barrier::Disconnected; - Err(Disconnected) - } - } - } - Barrier::Disconnected => Err(Disconnected), - } - } -} - -struct Channel { - control_receiver: ChunkReceiver, - multicast_receiver: ChunkReceiver, - seq: SequenceNumber, -} - -pub struct Subscriber { - joined_channels: HashSet, - control_socket: MultiplexSocket, - multicast_socket: MultiplexSocket, - buffer_allocator: Arc, - resend_timeout: Duration, - timeout: Duration, -} - -impl Subscriber { - pub fn connect(addr: SocketAddr) -> Result { - let control_socket = Socket::new(Domain::IPV4, Type::DGRAM, Some(Protocol::UDP))?; - control_socket.connect(&addr.into())?; - - let multicast_socket; - let buffer_allocator; - - let mut buffer = [0; CONNECTION_INFO_PACKET_SIZE]; - let maybe_uninit_buffer = unsafe { - std::mem::transmute::<&mut [u8], &mut [std::mem::MaybeUninit]>(&mut buffer) - }; - - loop { - if let Ok(1) = control_socket.send(&[0]) { - match control_socket.recv(maybe_uninit_buffer) { - Ok(CONNECTION_INFO_PACKET_SIZE) => { - if buffer[0] != kind::CONNECTION_INFO { - log::error!("received invalid chunk: {:?}", buffer); - continue; - } - let conn_info = match ConnectionInfo::ref_from(&buffer[1..]) { - Some(conn_info) => conn_info, - None => unreachable!(), - }; - 
log::debug!("received connection info: {:?}", conn_info); - multicast_socket = - Socket::new(Domain::IPV4, Type::DGRAM, Some(Protocol::UDP))?; - multicast_socket.set_reuse_address(true)?; - multicast_socket.bind( - &SocketAddrV4::new( - conn_info.multicast_addr.into(), - conn_info.multicast_port.into(), - ) - .into(), - )?; - multicast_socket.join_multicast_v4( - &conn_info.multicast_addr.into(), - &Ipv4Addr::UNSPECIFIED, - )?; - buffer_allocator = ChunkBufferAllocator::new(conn_info.chunk_size.into()); - break; - } - Ok(c) => { - log::error!("Received invalid chunk: {:?}", c); - } - Err(err) => { - log::error!("Failed to receive connection info: {}", err); - } - } - } - - std::thread::sleep(std::time::Duration::from_secs(1)); - } - - let buffer_allocator = Arc::new(buffer_allocator); - let control_socket = MultiplexSocket::new(control_socket, buffer_allocator.clone())?; - let multicast_socket = MultiplexSocket::new(multicast_socket, buffer_allocator.clone())?; - - Ok(Self { - control_socket, - multicast_socket, - buffer_allocator, - joined_channels: HashSet::default(), - timeout: Duration::from_secs(5), - resend_timeout: Duration::from_secs(1), - }) - } - - fn join_channel( - &mut self, - id: ChannelId, - join_msg: &T, - ) -> Result { - let deadline = std::time::Instant::now() + self.timeout; - - if self.joined_channels.contains(&id) { - return Err(JoinChannelError::AlreadyJoined); - } - - let control_receiver = self.control_socket.listen_to_channel(id); - - while std::time::Instant::now() < deadline { - self.control_socket.send_chunk(join_msg)?; - let receive_deadline = std::time::Instant::now() + self.resend_timeout; - - loop { - match control_receiver.recv_deadline(receive_deadline) { - Ok(chunk) => match chunk.validate() { - Ok(Chunk::ConfirmJoinChannel(confirm)) => { - self.joined_channels.insert(id); - log::info!("joined channel: {:?}", confirm); - self.control_socket.send_chunk(&Ack { - header: ChannelHeader { - channel_id: id.into(), - seq: 
confirm.header.seq, - }, - })?; - return Ok(Channel { - control_receiver, - multicast_receiver: self.multicast_socket.listen_to_channel(id), - seq: confirm.header.seq.into(), - }); - } - Ok(c) => { - log::debug!("ignore chunk: {:?}", c); - } - Err(err) => { - log::error!("received invalid chunk: {}", err); - } - }, - Err(RecvTimeoutError::Timeout) => { - // Resend the join channel packet, if the timeout has not been reached. - break; - } - Err(RecvTimeoutError::Disconnected) => { - unreachable!(); - } - } - } - } - - Err(JoinChannelError::Timeout) - } - - pub fn join_barrier_group(&mut self, id: ChannelId) -> Result { - let channel = self.join_channel(id, &JoinBarrierGroup(id.into()))?; - - Ok(Barrier::Connected { - control_receiver: channel.control_receiver, - multicast_receiver: channel.multicast_receiver, - seq: channel.seq, - control_socket: self.control_socket.socket().try_clone()?, - channel_id: id, - timeout: std::time::Duration::from_secs(2), - }) - } - - pub fn subscribe(&mut self, offer_id: ChannelId) -> Result { - let channel = self.join_channel( - offer_id, - &JoinChannel { - channel_id: offer_id.into(), - }, - )?; - - Ok(Subscription { - control_receiver: channel.control_receiver, - multicast_receiver: channel.multicast_receiver, - buffer_allocator: self.buffer_allocator.clone(), - sequence: channel.seq, - chunks: VecDeque::new(), - control_socket: self.control_socket.socket().try_clone()?, - }) - } -} diff --git a/src/test.rs b/src/test.rs new file mode 100644 index 0000000..3c29eb9 --- /dev/null +++ b/src/test.rs @@ -0,0 +1,18 @@ +use std::sync::atomic::AtomicU16; + +static NEXT_PORT: AtomicU16 = AtomicU16::new(55555); +pub fn get_port() -> u16 { + let port = NEXT_PORT.fetch_add(1, std::sync::atomic::Ordering::Relaxed); + if port == 0 { + panic!("No more ports available"); + } + port +} + +pub type Result = std::result::Result>; + +pub fn init_logger() { + let _ = env_logger::Builder::new() + .filter_level(log::LevelFilter::Trace) + .try_init(); +} 
diff --git a/src/utils.rs b/src/utils.rs new file mode 100644 index 0000000..07805e5 --- /dev/null +++ b/src/utils.rs @@ -0,0 +1,57 @@ +use std::{ + fmt::{self, Display}, + net::SocketAddr, + time::{Duration, Instant}, +}; + +use socket2::SockAddr; + +pub struct DisplaySockAddr<'a>(&'a SockAddr); + +impl Display for DisplaySockAddr<'_> { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self.0.as_socket() { + Some(addr) => write!(f, "{}", addr), + None => write!(f, "{:?}", self.0), + } + } +} + +pub fn display_addr(addr: &SockAddr) -> DisplaySockAddr { + DisplaySockAddr(addr) +} + +pub struct ExponentialBackoff { + current_wait_time: Duration, + max_wait_time: Duration, +} + +impl ExponentialBackoff { + pub fn new() -> Self { + Self { + current_wait_time: Duration::from_millis(100), + max_wait_time: Duration::from_secs(1), + } + } +} + +impl Iterator for ExponentialBackoff { + type Item = Instant; + + fn next(&mut self) -> Option { + let now = Instant::now(); + let wait_time = self.current_wait_time; + self.current_wait_time *= 2; + if self.current_wait_time > self.max_wait_time { + self.current_wait_time = self.max_wait_time; + } + Some(now + wait_time) + } +} + +pub fn sock_addr_to_socket_addr(addr: SockAddr) -> Result { + match addr.as_socket() { + Some(addr) => Ok(addr), + None => Err(std::io::ErrorKind::AddrNotAvailable.into()), + } +}