diff --git a/transaction-relayer/src/main.rs b/transaction-relayer/src/main.rs index c430e502..bb87e139 100644 --- a/transaction-relayer/src/main.rs +++ b/transaction-relayer/src/main.rs @@ -2,6 +2,7 @@ use std::{ collections::HashSet, fs, net::{IpAddr, Ipv4Addr, SocketAddr}, + ops::Range, path::PathBuf, str::FromStr, sync::{ @@ -206,35 +207,59 @@ struct Sockets { } fn get_sockets(args: &Args) -> Sockets { - let tpu_quic_sockets = (0..args.num_tpu_quic_servers) - .flat_map(|i| { - multi_bind_in_range( + assert!(args.num_tpu_quic_servers < u16::MAX as usize); + assert!(args.num_tpu_fwd_quic_servers < u16::MAX as usize); + + let tpu_ports = Range { + start: args.tpu_quic_port, + end: args + .tpu_quic_port + .checked_add(args.num_tpu_quic_servers as u16) + .unwrap(), + }; + let tpu_fwd_ports = Range { + start: args.tpu_quic_fwd_port, + end: args + .tpu_quic_fwd_port + .checked_add(args.num_tpu_fwd_quic_servers as u16) + .unwrap(), + }; + + for tpu_port in tpu_ports.start..tpu_ports.end { + assert!(!tpu_fwd_ports.contains(&tpu_port)); + } + + let (tpu_p, tpu_quic_sockets): (Vec<_>, Vec<_>) = (0..args.num_tpu_quic_servers) + .map(|i| { + let (port, mut sock) = multi_bind_in_range( IpAddr::V4(Ipv4Addr::from([0, 0, 0, 0])), - ( - args.tpu_quic_port + i as u16, - args.tpu_quic_port + 1 + i as u16, - ), + (tpu_ports.start + i as u16, tpu_ports.start + 1 + i as u16), 1, ) - .unwrap() - .1 + .unwrap(); + + (port, sock.pop().unwrap()) }) - .collect::>(); + .unzip(); - let tpu_fwd_quic_sockets = (0..args.num_tpu_fwd_quic_servers) - .flat_map(|i| { - multi_bind_in_range( + let (tpu_fwd_p, tpu_fwd_quic_sockets): (Vec<_>, Vec<_>) = (0..args.num_tpu_fwd_quic_servers) + .map(|i| { + let (port, mut sock) = multi_bind_in_range( IpAddr::V4(Ipv4Addr::from([0, 0, 0, 0])), ( - args.tpu_quic_fwd_port + i as u16, - args.tpu_quic_fwd_port + 1 + i as u16, + tpu_fwd_ports.start + i as u16, + tpu_fwd_ports.start + 1 + i as u16, ), 1, ) - .unwrap() - .1 + .unwrap(); + + (port, sock.pop().unwrap()) }) - .collect::>(); + .unzip(); + + assert_eq!(tpu_ports.collect::>(), tpu_p); + assert_eq!(tpu_fwd_ports.collect::>(), tpu_fwd_p); Sockets { tpu_sockets: TpuSockets { @@ -418,8 +443,7 @@ fn main() { delay_packet_receiver, leader_cache.handle(), public_ip, - (args.tpu_quic_port..args.tpu_quic_port + args.num_tpu_quic_servers as u16) - .collect(), + (args.tpu_quic_port..args.tpu_quic_port + args.num_tpu_quic_servers as u16).collect(), (args.tpu_quic_fwd_port..args.tpu_quic_fwd_port + args.num_tpu_fwd_quic_servers as u16) .collect(), health_manager.handle(),