Skip to content

Commit

Permalink
g3proxy: use less io split
Browse files Browse the repository at this point in the history
  • Loading branch information
zh-jq-b committed Jun 4, 2024
1 parent 0646fb6 commit 170436e
Show file tree
Hide file tree
Showing 12 changed files with 86 additions and 134 deletions.
43 changes: 16 additions & 27 deletions g3proxy/src/escape/proxy_float/peer/http/http_connect/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,13 @@ use std::sync::Arc;

use anyhow::anyhow;
use tokio::io::{AsyncRead, AsyncWrite, BufReader};
use tokio::net::tcp;
use tokio::net::TcpStream;

use g3_daemon::stat::remote::{
ArcTcpConnectionTaskRemoteStats, TcpConnectionTaskRemoteStatsWrapper,
};
use g3_http::connect::{HttpConnectRequest, HttpConnectResponse};
use g3_io_ext::{LimitedReader, LimitedWriter};
use g3_io_ext::{LimitedReader, LimitedStream, LimitedWriter};
use g3_openssl::SslConnector;
use g3_types::net::{Host, OpensslClientConfig};

Expand All @@ -40,41 +40,30 @@ impl ProxyFloatHttpPeer {
&'a self,
tcp_notes: &'a mut TcpConnectTaskNotes,
task_notes: &'a ServerTaskNotes,
) -> Result<
(
BufReader<LimitedReader<tcp::OwnedReadHalf>>,
LimitedWriter<tcp::OwnedWriteHalf>,
),
TcpConnectError,
> {
let (r, mut w) = self.tcp_new_connection(tcp_notes, task_notes).await?;
) -> Result<BufReader<LimitedStream<TcpStream>>, TcpConnectError> {
let mut stream = self.tcp_new_connection(tcp_notes, task_notes).await?;

let req =
HttpConnectRequest::new(&tcp_notes.upstream, &self.shared_config.append_http_headers);
req.send(&mut w)
req.send(&mut stream)
.await
.map_err(TcpConnectError::NegotiationWriteFailed)?;

let mut r = BufReader::new(r);
let _ = HttpConnectResponse::recv(&mut r, self.http_connect_rsp_hdr_max_size).await?;
let mut buf_stream = BufReader::new(stream);
let _ =
HttpConnectResponse::recv(&mut buf_stream, self.http_connect_rsp_hdr_max_size).await?;

// TODO detect and set outgoing_addr and target_addr for supported remote proxies
// set with the registered public ip by default

Ok((r, w))
Ok(buf_stream)
}

pub(super) async fn timed_http_connect_tcp_connect_to<'a>(
&'a self,
tcp_notes: &'a mut TcpConnectTaskNotes,
task_notes: &'a ServerTaskNotes,
) -> Result<
(
BufReader<LimitedReader<tcp::OwnedReadHalf>>,
LimitedWriter<tcp::OwnedWriteHalf>,
),
TcpConnectError,
> {
) -> Result<BufReader<LimitedStream<TcpStream>>, TcpConnectError> {
tokio::time::timeout(
self.escaper_config.peer_negotiation_timeout,
self.http_connect_tcp_connect_to(tcp_notes, task_notes),
Expand All @@ -89,12 +78,12 @@ impl ProxyFloatHttpPeer {
task_notes: &'a ServerTaskNotes,
task_stats: ArcTcpConnectionTaskRemoteStats,
) -> TcpConnectResult {
let (mut r, mut w) = self
let mut buf_stream = self
.timed_http_connect_tcp_connect_to(tcp_notes, task_notes)
.await?;

// add in read buffered data
let r_buffer_size = r.buffer().len() as u64;
let r_buffer_size = buf_stream.buffer().len() as u64;
task_stats.add_read_bytes(r_buffer_size);
let mut wrapper_stats = TcpConnectRemoteWrapperStats::new(&self.escaper_stats, task_stats);
let user_stats = self.fetch_user_upstream_io_stats(task_notes);
Expand All @@ -105,9 +94,9 @@ impl ProxyFloatHttpPeer {
let wrapper_stats = Arc::new(wrapper_stats);

// reset underlying io stats
r.get_mut().reset_stats(wrapper_stats.clone() as _);
w.reset_stats(wrapper_stats as _);
buf_stream.get_mut().reset_stats(wrapper_stats.clone());

let (r, w) = tokio::io::split(buf_stream);
Ok((Box::new(r), Box::new(w)))
}

Expand All @@ -119,14 +108,14 @@ impl ProxyFloatHttpPeer {
tls_name: &'a Host,
tls_application: TlsApplication,
) -> Result<impl AsyncRead + AsyncWrite, TcpConnectError> {
let (ups_r, ups_w) = self
let buf_stream = self
.timed_http_connect_tcp_connect_to(tcp_notes, task_notes)
.await?;

let ssl = tls_config
.build_ssl(tls_name, tcp_notes.upstream.port())
.map_err(TcpConnectError::InternalTlsClientError)?;
let connector = SslConnector::new(ssl, tokio::io::join(ups_r, ups_w))
let connector = SslConnector::new(ssl, buf_stream.into_inner())
.map_err(|e| TcpConnectError::InternalTlsClientError(anyhow::Error::new(e)))?;

match tokio::time::timeout(tls_config.handshake_timeout, connector.connect()).await {
Expand Down
3 changes: 2 additions & 1 deletion g3proxy/src/escape/proxy_float/peer/http/http_forward/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,8 @@ impl ProxyFloatHttpPeer {
task_notes: &'a ServerTaskNotes,
task_stats: ArcHttpForwardTaskRemoteStats,
) -> Result<BoxHttpForwardConnection, TcpConnectError> {
let (ups_r, mut ups_w) = self.tcp_new_connection(tcp_notes, task_notes).await?;
let stream = self.tcp_new_connection(tcp_notes, task_notes).await?;
let (ups_r, mut ups_w) = stream.into_split_tcp();

let mut w_wrapper_stats =
HttpForwardRemoteWrapperStats::new(&self.escaper_stats, &task_stats);
Expand Down
26 changes: 7 additions & 19 deletions g3proxy/src/escape/proxy_float/peer/http/tcp_connect/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,10 @@

use std::net::{IpAddr, SocketAddr};

use tokio::net::{tcp, TcpStream};
use tokio::net::TcpStream;
use tokio::time::Instant;

use g3_io_ext::{LimitedReader, LimitedWriter};
use g3_io_ext::LimitedStream;
use g3_types::net::ConnectError;

use super::ProxyFloatHttpPeer;
Expand Down Expand Up @@ -105,30 +105,18 @@ impl ProxyFloatHttpPeer {
&'a self,
tcp_notes: &'a mut TcpConnectTaskNotes,
task_notes: &'a ServerTaskNotes,
) -> Result<
(
LimitedReader<tcp::OwnedReadHalf>,
LimitedWriter<tcp::OwnedWriteHalf>,
),
TcpConnectError,
> {
) -> Result<LimitedStream<TcpStream>, TcpConnectError> {
let stream = self.tcp_connect_to(tcp_notes, task_notes).await?;
let (r, w) = stream.into_split();

let limit_config = &self.shared_config.tcp_conn_speed_limit;
let r = LimitedReader::new(
r,
let stream = LimitedStream::new(
stream,
limit_config.shift_millis,
limit_config.max_south,
self.escaper_stats.clone() as _,
);
let w = LimitedWriter::new(
w,
limit_config.shift_millis,
limit_config.max_north,
self.escaper_stats.clone() as _,
self.escaper_stats.clone(),
);

Ok((r, w))
Ok(stream)
}
}
25 changes: 13 additions & 12 deletions g3proxy/src/escape/proxy_float/peer/https/http_connect/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,29 +37,30 @@ impl ProxyFloatHttpsPeer {
&'a self,
tcp_notes: &'a mut TcpConnectTaskNotes,
task_notes: &'a ServerTaskNotes,
) -> Result<(BufReader<impl AsyncRead>, impl AsyncWrite), TcpConnectError> {
let (r, mut w) = self.tls_handshake_with(tcp_notes, task_notes).await?;
) -> Result<BufReader<impl AsyncRead + AsyncWrite>, TcpConnectError> {
let mut stream = self.tls_handshake_with(tcp_notes, task_notes).await?;

let req =
HttpConnectRequest::new(&tcp_notes.upstream, &self.shared_config.append_http_headers);
req.send(&mut w)
req.send(&mut stream)
.await
.map_err(TcpConnectError::NegotiationWriteFailed)?;

let mut r = BufReader::new(r);
let _ = HttpConnectResponse::recv(&mut r, self.http_connect_rsp_hdr_max_size).await?;
let mut buf_stream = BufReader::new(stream);
let _ =
HttpConnectResponse::recv(&mut buf_stream, self.http_connect_rsp_hdr_max_size).await?;

// TODO detect and set outgoing_addr and target_addr for supported remote proxies
// set with the registered public ip by default

Ok((r, w))
Ok(buf_stream)
}

pub(super) async fn timed_http_connect_tcp_connect_to<'a>(
&'a self,
tcp_notes: &'a mut TcpConnectTaskNotes,
task_notes: &'a ServerTaskNotes,
) -> Result<(BufReader<impl AsyncRead>, impl AsyncWrite), TcpConnectError> {
) -> Result<BufReader<impl AsyncRead + AsyncWrite>, TcpConnectError> {
tokio::time::timeout(
self.escaper_config.peer_negotiation_timeout,
self.http_connect_tcp_connect_to(tcp_notes, task_notes),
Expand All @@ -74,13 +75,13 @@ impl ProxyFloatHttpsPeer {
task_notes: &'a ServerTaskNotes,
task_stats: ArcTcpConnectionTaskRemoteStats,
) -> TcpConnectResult {
let (r, w) = self
let buf_stream = self
.timed_http_connect_tcp_connect_to(tcp_notes, task_notes)
.await?;

// add task and user stats
// add in read buffered data
let r_buffer_size = r.buffer().len() as u64;
let r_buffer_size = buf_stream.buffer().len() as u64;
task_stats.add_read_bytes(r_buffer_size);
let mut wrapper_stats = TcpConnectionTaskRemoteStatsWrapper::new(task_stats);
let user_stats = self.fetch_user_upstream_io_stats(task_notes);
Expand All @@ -90,6 +91,7 @@ impl ProxyFloatHttpsPeer {
wrapper_stats.push_other_stats(user_stats);
let wrapper_stats = Arc::new(wrapper_stats);

let (r, w) = tokio::io::split(buf_stream);
let r = LimitedReader::new_unlimited(r, wrapper_stats.clone() as _);
let w = LimitedWriter::new_unlimited(w, wrapper_stats as _);

Expand All @@ -104,14 +106,14 @@ impl ProxyFloatHttpsPeer {
tls_name: &'a Host,
tls_application: TlsApplication,
) -> Result<impl AsyncRead + AsyncWrite, TcpConnectError> {
let (ups_r, ups_w) = self
let buf_stream = self
.timed_http_connect_tcp_connect_to(tcp_notes, task_notes)
.await?;

let ssl = tls_config
.build_ssl(tls_name, tcp_notes.upstream.port())
.map_err(TcpConnectError::InternalTlsClientError)?;
let connector = SslConnector::new(ssl, tokio::io::join(ups_r, ups_w))
let connector = SslConnector::new(ssl, buf_stream.into_inner())
.map_err(|e| TcpConnectError::InternalTlsClientError(anyhow::Error::new(e)))?;

match tokio::time::timeout(tls_config.handshake_timeout, connector.connect()).await {
Expand Down Expand Up @@ -160,7 +162,6 @@ impl ProxyFloatHttpsPeer {
TlsApplication::TcpStream,
)
.await?;

let (ups_r, ups_w) = tokio::io::split(tls_stream);

// add task and user stats
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,8 @@ impl ProxyFloatHttpsPeer {
task_notes: &'a ServerTaskNotes,
task_stats: ArcHttpForwardTaskRemoteStats,
) -> Result<BoxHttpForwardConnection, TcpConnectError> {
let (ups_r, ups_w) = self.tls_handshake_with(tcp_notes, task_notes).await?;
let stream = self.tls_handshake_with(tcp_notes, task_notes).await?;
let (ups_r, ups_w) = tokio::io::split(stream);

// add task and user stats
let mut wrapper_stats = HttpForwardTaskRemoteWrapperStats::new(task_stats);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ impl ProxyFloatHttpsPeer {
&'a self,
tcp_notes: &'a mut TcpConnectTaskNotes,
task_notes: &'a ServerTaskNotes,
) -> Result<(impl AsyncRead, impl AsyncWrite), TcpConnectError> {
) -> Result<impl AsyncRead + AsyncWrite, TcpConnectError> {
let stream = self.tcp_new_connection(tcp_notes, task_notes).await?;

let ssl = self
Expand All @@ -42,10 +42,7 @@ impl ProxyFloatHttpsPeer {
.map_err(|e| TcpConnectError::InternalTlsClientError(anyhow::Error::new(e)))?;

match tokio::time::timeout(self.tls_config.handshake_timeout, connector.connect()).await {
Ok(Ok(stream)) => {
let (r, w) = tokio::io::split(stream);
Ok((r, w))
}
Ok(Ok(stream)) => Ok(stream),
Ok(Err(e)) => {
let e = anyhow::Error::new(e);
let tls_peer = UpstreamAddr::from_ip_and_port(self.addr.ip(), self.addr.port());
Expand Down
Loading

0 comments on commit 170436e

Please sign in to comment.