From 3a22cc21cab4a058f740c51a5446088bb94fc782 Mon Sep 17 00:00:00 2001 From: Zhang Jingqiang Date: Tue, 4 Jun 2024 15:36:24 +0800 Subject: [PATCH] add write_all_flush method --- g3bench/src/target/h1/opts.rs | 2 +- g3bench/src/target/h2/opts.rs | 2 +- .../keyless/cloudflare/connection/simplex.rs | 6 +- g3bench/src/target/keyless/cloudflare/opts.rs | 2 +- g3bench/src/target/openssl/opts.rs | 2 +- g3bench/src/target/rustls/opts.rs | 2 +- g3keymess/src/serve/task/simplex.rs | 8 +- g3proxy/src/escape/divert_tcp/mod.rs | 9 +-- .../src/escape/proxy_http/tcp_connect/mod.rs | 2 +- .../src/escape/proxy_https/tcp_connect/mod.rs | 2 +- g3proxy/src/inspect/http/v1/connect/mod.rs | 5 ++ g3proxy/src/inspect/http/v1/forward/mod.rs | 13 ++-- g3proxy/src/inspect/http/v1/upgrade/mod.rs | 5 ++ g3proxy/src/inspect/smtp/ending.rs | 8 +- g3proxy/src/inspect/smtp/ext.rs | 7 +- g3proxy/src/inspect/smtp/forward.rs | 16 +--- g3proxy/src/inspect/smtp/greeting.rs | 7 +- g3proxy/src/inspect/smtp/transaction.rs | 14 +--- .../module/http_forward/response/client.rs | 10 +-- .../src/serve/http_proxy/task/forward/task.rs | 13 ++-- g3proxy/src/serve/http_proxy/task/ftp/list.rs | 4 +- .../serve/http_rproxy/task/forward/task.rs | 8 +- .../src/module/keyless/protocol/response.rs | 3 +- lib/g3-daemon/src/control/text.rs | 15 ++-- lib/g3-daemon/src/register/task.rs | 10 +-- lib/g3-fluentd/src/config.rs | 4 + lib/g3-ftp-client/src/control/command.rs | 7 +- lib/g3-io-ext/src/io/ext/limited_write_ext.rs | 8 ++ lib/g3-io-ext/src/io/ext/mod.rs | 1 + lib/g3-io-ext/src/io/ext/write_all_flush.rs | 75 +++++++++++++++++++ lib/g3-socks/src/v4a/reply.rs | 7 +- lib/g3-socks/src/v4a/request.rs | 7 +- lib/g3-socks/src/v5/auth.rs | 16 ++-- lib/g3-socks/src/v5/reply.rs | 7 +- lib/g3-socks/src/v5/request.rs | 6 +- 35 files changed, 188 insertions(+), 125 deletions(-) create mode 100644 lib/g3-io-ext/src/io/ext/write_all_flush.rs diff --git a/g3bench/src/target/h1/opts.rs b/g3bench/src/target/h1/opts.rs index 3e016de60..9db60ef2b 100644 --- a/g3bench/src/target/h1/opts.rs +++ b/g3bench/src/target/h1/opts.rs @@ -140,7 +140,7 @@ impl BenchHttpArgs { if let Some(data) = self.proxy_protocol.data() { stream - .write_all(data) + .write_all(data) // no need to flush data .await .map_err(|e| anyhow!("failed to send proxy protocol data: {e:?}"))?; } diff --git a/g3bench/src/target/h2/opts.rs b/g3bench/src/target/h2/opts.rs index 1c6817e15..a4fd9ff45 100644 --- a/g3bench/src/target/h2/opts.rs +++ b/g3bench/src/target/h2/opts.rs @@ -136,7 +136,7 @@ impl BenchH2Args { if let Some(data) = self.proxy_protocol.data() { stream - .write_all(data) + .write_all(data) // no need to flush data .await .map_err(|e| anyhow!("failed to write proxy protocol data: {e:?}"))?; } diff --git a/g3bench/src/target/keyless/cloudflare/connection/simplex.rs b/g3bench/src/target/keyless/cloudflare/connection/simplex.rs index ce5fa239a..9300b2095 100644 --- a/g3bench/src/target/keyless/cloudflare/connection/simplex.rs +++ b/g3bench/src/target/keyless/cloudflare/connection/simplex.rs @@ -17,7 +17,9 @@ use std::net::SocketAddr; use futures_util::FutureExt; -use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; +use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite}; + +use g3_io_ext::LimitedWriteExt; use super::{KeylessLocalError, KeylessRequest, KeylessResponse, KeylessResponseError}; @@ -62,7 +64,7 @@ impl SimplexTransfer { self.next_req_id = self.next_req_id.wrapping_add(1); self.writer - .write_all(req.as_bytes()) + .write_all_flush(req.as_bytes()) .await .map_err(KeylessLocalError::WriteFailed)?; diff --git a/g3bench/src/target/keyless/cloudflare/opts.rs b/g3bench/src/target/keyless/cloudflare/opts.rs index 71a7ea780..b03934120 100644 --- a/g3bench/src/target/keyless/cloudflare/opts.rs +++ b/g3bench/src/target/keyless/cloudflare/opts.rs @@ -145,7 +145,7 @@ impl KeylessCloudflareArgs { if let Some(data) = self.proxy_protocol.data() { stream - .write_all(data) + .write_all(data) // no need to flush data .await .map_err(|e| anyhow!("failed to write proxy protocol data: {e:?}"))?; } diff --git a/g3bench/src/target/openssl/opts.rs b/g3bench/src/target/openssl/opts.rs index f999c26f9..dfad75a61 100644 --- a/g3bench/src/target/openssl/opts.rs +++ b/g3bench/src/target/openssl/opts.rs @@ -97,7 +97,7 @@ impl BenchOpensslArgs { if let Some(data) = self.proxy_protocol.data() { stream - .write_all(data) + .write_all(data) // no need to flush data .await .map_err(|e| anyhow!("failed to write proxy protocol data: {e:?}"))?; } diff --git a/g3bench/src/target/rustls/opts.rs b/g3bench/src/target/rustls/opts.rs index 8bb9fcff8..dc3c74168 100644 --- a/g3bench/src/target/rustls/opts.rs +++ b/g3bench/src/target/rustls/opts.rs @@ -97,7 +97,7 @@ impl BenchRustlsArgs { if let Some(data) = self.proxy_protocol.data() { stream - .write_all(data) + .write_all(data) // no need to flush data .await .map_err(|e| anyhow!("failed to write proxy protocol data: {e:?}"))?; } diff --git a/g3keymess/src/serve/task/simplex.rs b/g3keymess/src/serve/task/simplex.rs index b20a72c6d..724076c0f 100644 --- a/g3keymess/src/serve/task/simplex.rs +++ b/g3keymess/src/serve/task/simplex.rs @@ -15,10 +15,10 @@ */ use openssl::pkey::{PKey, Private}; -use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, BufReader}; +use tokio::io::{AsyncRead, AsyncWrite, BufReader}; use tokio::sync::broadcast; -use g3_io_ext::LimitedBufReadExt; +use g3_io_ext::{LimitedBufReadExt, LimitedWriteExt}; use g3_types::ext::DurationExt; use super::{KeylessTask, WrappedKeylessRequest}; @@ -161,10 +161,10 @@ impl KeylessTask { RequestErrorLogContext { task_id: &self.id }.log(&self.ctx.request_logger, &rsp); writer - .write_all(rsp.message()) + .write_all_flush(rsp.message()) .await .map_err(ServerTaskError::WriteFailed)?; - writer.flush().await.map_err(ServerTaskError::WriteFailed)?; + Ok(()) } } diff --git a/g3proxy/src/escape/divert_tcp/mod.rs b/g3proxy/src/escape/divert_tcp/mod.rs index a339bd081..1e2efccc2 100644 --- a/g3proxy/src/escape/divert_tcp/mod.rs +++ b/g3proxy/src/escape/divert_tcp/mod.rs @@ -20,9 +20,10 @@ use std::sync::Arc; use anyhow::anyhow; use async_trait::async_trait; use slog::Logger; -use tokio::io::{AsyncWrite, AsyncWriteExt}; +use tokio::io::AsyncWrite; use g3_daemon::stat::remote::ArcTcpConnectionTaskRemoteStats; +use g3_io_ext::LimitedWriteExt; use g3_resolver::{ResolveError, ResolveLocalError}; use g3_types::collection::{SelectiveVec, SelectiveVecBuilder}; use g3_types::metrics::MetricsName; @@ -190,11 +191,7 @@ impl DivertTcpEscaper { let pp2_data = pp2_encoder.finalize(); writer - .write_all(pp2_data) - .await - .map_err(TcpConnectError::ProxyProtocolWriteFailed)?; - writer - .flush() + .write_all_flush(pp2_data) .await .map_err(TcpConnectError::ProxyProtocolWriteFailed)?; Ok(pp2_data.len()) diff --git a/g3proxy/src/escape/proxy_http/tcp_connect/mod.rs b/g3proxy/src/escape/proxy_http/tcp_connect/mod.rs index 128c450dc..3a4c8f5d6 100644 --- a/g3proxy/src/escape/proxy_http/tcp_connect/mod.rs +++ b/g3proxy/src/escape/proxy_http/tcp_connect/mod.rs @@ -310,7 +310,7 @@ impl ProxyHttpEscaper { .encode_tcp(task_notes.client_addr(), task_notes.server_addr()) .map_err(TcpConnectError::ProxyProtocolEncodeError)?; stream - .write_all(bytes) + .write_all(bytes) // no need to flush data .await .map_err(TcpConnectError::ProxyProtocolWriteFailed)?; } diff --git a/g3proxy/src/escape/proxy_https/tcp_connect/mod.rs b/g3proxy/src/escape/proxy_https/tcp_connect/mod.rs index ebd446433..7aa766434 100644 --- a/g3proxy/src/escape/proxy_https/tcp_connect/mod.rs +++ b/g3proxy/src/escape/proxy_https/tcp_connect/mod.rs @@ -314,7 +314,7 @@ impl ProxyHttpsEscaper { .encode_tcp(task_notes.client_addr(), task_notes.server_addr()) .map_err(TcpConnectError::ProxyProtocolEncodeError)?; stream - .write_all(bytes) + .write_all(bytes) // no need to flush data .await .map_err(TcpConnectError::ProxyProtocolWriteFailed)?; } diff --git a/g3proxy/src/inspect/http/v1/connect/mod.rs b/g3proxy/src/inspect/http/v1/connect/mod.rs index e65617eab..6b576a392 100644 --- a/g3proxy/src/inspect/http/v1/connect/mod.rs +++ b/g3proxy/src/inspect/http/v1/connect/mod.rs @@ -303,6 +303,11 @@ where LimitedCopyError::WriteFailed(e) => ServerTaskError::ClientTcpWriteFailed(e), })?; recv_body.save_connection().await; + } else { + clt_w + .flush() + .await + .map_err(ServerTaskError::ClientTcpWriteFailed)?; } Ok(()) diff --git a/g3proxy/src/inspect/http/v1/forward/mod.rs b/g3proxy/src/inspect/http/v1/forward/mod.rs index 33030cbde..0989da657 100644 --- a/g3proxy/src/inspect/http/v1/forward/mod.rs +++ b/g3proxy/src/inspect/http/v1/forward/mod.rs @@ -35,7 +35,7 @@ use g3_icap_client::reqmod::IcapReqmodClient; use g3_icap_client::respmod::h1::{ HttpResponseAdapter, RespmodAdaptationEndState, RespmodAdaptationRunState, }; -use g3_io_ext::{LimitedBufReadExt, LimitedCopy, LimitedCopyError}; +use g3_io_ext::{LimitedBufReadExt, LimitedCopy, LimitedCopyError, LimitedWriteExt}; use g3_slog_types::{LtDateTime, LtDuration, LtHttpMethod, LtHttpUri, LtUuid}; use g3_types::net::HttpHeaderMap; @@ -430,6 +430,11 @@ impl<'a, SC: ServerConfig> H1ForwardTask<'a, SC> { LimitedCopyError::WriteFailed(e) => ServerTaskError::ClientTcpWriteFailed(e), })?; recv_body.save_connection().await; + } else { + clt_w + .flush() + .await + .map_err(ServerTaskError::ClientTcpWriteFailed)?; } Ok(()) @@ -769,11 +774,7 @@ impl<'a, SC: ServerConfig> H1ForwardTask<'a, SC> { CW: AsyncWrite + Unpin, { clt_w - .write_all(&head_bytes) - .await - .map_err(ServerTaskError::ClientTcpWriteFailed)?; - clt_w - .flush() + .write_all_flush(&head_bytes) .await .map_err(ServerTaskError::ClientTcpWriteFailed) } diff --git a/g3proxy/src/inspect/http/v1/upgrade/mod.rs b/g3proxy/src/inspect/http/v1/upgrade/mod.rs index 8470ccc5f..0bbd09d44 100644 --- a/g3proxy/src/inspect/http/v1/upgrade/mod.rs +++ b/g3proxy/src/inspect/http/v1/upgrade/mod.rs @@ -310,6 +310,11 @@ where LimitedCopyError::WriteFailed(e) => ServerTaskError::ClientTcpWriteFailed(e), })?; recv_body.save_connection().await; + } else { + clt_w + .flush() + .await + .map_err(ServerTaskError::ClientTcpWriteFailed)?; } Ok(()) diff --git a/g3proxy/src/inspect/smtp/ending.rs b/g3proxy/src/inspect/smtp/ending.rs index d160b92c3..28bae5852 100644 --- a/g3proxy/src/inspect/smtp/ending.rs +++ b/g3proxy/src/inspect/smtp/ending.rs @@ -20,7 +20,7 @@ use std::time::Duration; use anyhow::anyhow; use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt}; -use g3_io_ext::{LineRecvBuf, RecvLineError}; +use g3_io_ext::{LimitedWriteExt, LineRecvBuf, RecvLineError}; use g3_smtp_proto::command::Command; use g3_smtp_proto::response::{ReplyCode, ResponseEncoder, ResponseParser}; @@ -39,11 +39,7 @@ impl EndQuitServer { W: AsyncWrite + Unpin, { ups_w - .write_all(b"QUIT\r\n") - .await - .map_err(ServerTaskError::UpstreamWriteFailed)?; - ups_w - .flush() + .write_all_flush(b"QUIT\r\n") .await .map_err(ServerTaskError::UpstreamWriteFailed)?; diff --git a/g3proxy/src/inspect/smtp/ext.rs b/g3proxy/src/inspect/smtp/ext.rs index b04986c3f..c056e5699 100644 --- a/g3proxy/src/inspect/smtp/ext.rs +++ b/g3proxy/src/inspect/smtp/ext.rs @@ -18,9 +18,9 @@ use std::io; use std::net::IpAddr; use anyhow::anyhow; -use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt}; +use tokio::io::{AsyncRead, AsyncWrite}; -use g3_io_ext::{LineRecvBuf, RecvLineError}; +use g3_io_ext::{LimitedWriteExt, LineRecvBuf, RecvLineError}; use g3_smtp_proto::command::Command; use g3_smtp_proto::response::{ResponseEncoder, ResponseParser}; @@ -198,6 +198,5 @@ async fn send_cmd(ups_w: &mut W, line: &[u8]) -> io::Result<()> where W: AsyncWrite + Unpin, { - ups_w.write_all(line).await?; - ups_w.flush().await + ups_w.write_all_flush(line).await } diff --git a/g3proxy/src/inspect/smtp/forward.rs b/g3proxy/src/inspect/smtp/forward.rs index 5137ef3e3..8ded0d95d 100644 --- a/g3proxy/src/inspect/smtp/forward.rs +++ b/g3proxy/src/inspect/smtp/forward.rs @@ -16,9 +16,9 @@ use std::net::IpAddr; -use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt}; +use tokio::io::{AsyncRead, AsyncWrite}; -use g3_io_ext::LineRecvBuf; +use g3_io_ext::{LimitedWriteExt, LineRecvBuf}; use g3_smtp_proto::command::{Command, MailParam}; use g3_smtp_proto::response::{ReplyCode, ResponseEncoder, ResponseParser}; @@ -169,11 +169,7 @@ impl Forward { .await?; clt_w - .write_all(line) - .await - .map_err(ServerTaskError::ClientTcpWriteFailed)?; - clt_w - .flush() + .write_all_flush(line) .await .map_err(ServerTaskError::ClientTcpWriteFailed)?; @@ -212,11 +208,7 @@ impl Forward { match recv_buf.read_line(clt_r).await { Ok(line) => { ups_w - .write_all(line) - .await - .map_err(ServerTaskError::UpstreamWriteFailed)?; - ups_w - .flush() + .write_all_flush(line) .await .map_err(ServerTaskError::UpstreamWriteFailed)?; recv_buf.consume_line(); diff --git a/g3proxy/src/inspect/smtp/greeting.rs b/g3proxy/src/inspect/smtp/greeting.rs index da9665b9c..01033b261 100644 --- a/g3proxy/src/inspect/smtp/greeting.rs +++ b/g3proxy/src/inspect/smtp/greeting.rs @@ -22,7 +22,7 @@ use anyhow::anyhow; use thiserror::Error; use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, BufWriter}; -use g3_io_ext::{LineRecvBuf, OnceBufReader, RecvLineError}; +use g3_io_ext::{LimitedWriteExt, LineRecvBuf, OnceBufReader, RecvLineError}; use g3_smtp_proto::response::{ReplyCode, ResponseEncoder, ResponseLineError, ResponseParser}; use g3_types::net::Host; @@ -70,7 +70,7 @@ impl Greeting { let msg = self.rsp.feed_line(line)?; self.total_to_write += line.len(); clt_w - .write_all(line) + .write_all_flush(line) .await .map_err(GreetingError::ClientWriteFailed)?; @@ -150,8 +150,7 @@ impl Greeting { _ => return, }; let rsp = ResponseEncoder::upstream_service_not_ready(self.local_ip, reason); - let _ = clt_w.write_all(rsp.as_bytes()).await; - let _ = clt_w.flush().await; + let _ = clt_w.write_all_flush(rsp.as_bytes()).await; let _ = clt_w.shutdown().await; } } diff --git a/g3proxy/src/inspect/smtp/transaction.rs b/g3proxy/src/inspect/smtp/transaction.rs index dc3e3f40a..bc6d00569 100644 --- a/g3proxy/src/inspect/smtp/transaction.rs +++ b/g3proxy/src/inspect/smtp/transaction.rs @@ -20,7 +20,7 @@ use anyhow::anyhow; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; use tokio::time::Instant; -use g3_io_ext::{LimitedCopy, LimitedCopyError}; +use g3_io_ext::{LimitedCopy, LimitedCopyError, LimitedWriteExt}; use g3_smtp_proto::command::{Command, MailParam, RecipientParam}; use g3_smtp_proto::io::TextDataReader; use g3_smtp_proto::response::{ReplyCode, ResponseEncoder, ResponseParser}; @@ -205,11 +205,7 @@ impl<'a, SC: ServerConfig> Transaction<'a, SC> { .await?; clt_w - .write_all(line) - .await - .map_err(ServerTaskError::ClientTcpWriteFailed)?; - clt_w - .flush() + .write_all_flush(line) .await .map_err(ServerTaskError::ClientTcpWriteFailed)?; @@ -331,11 +327,7 @@ impl<'a, SC: ServerConfig> Transaction<'a, SC> { UW: AsyncWrite + Unpin, { ups_w - .write_all(cmd_line) - .await - .map_err(ServerTaskError::UpstreamWriteFailed)?; - ups_w - .flush() + .write_all_flush(cmd_line) .await .map_err(ServerTaskError::UpstreamWriteFailed)?; diff --git a/g3proxy/src/module/http_forward/response/client.rs b/g3proxy/src/module/http_forward/response/client.rs index c3732e2b7..a15653675 100644 --- a/g3proxy/src/module/http_forward/response/client.rs +++ b/g3proxy/src/module/http_forward/response/client.rs @@ -24,6 +24,7 @@ use tokio::io::{AsyncWrite, AsyncWriteExt}; use g3_ftp_client::FtpConnectError; use g3_http::server::HttpRequestParseError; +use g3_io_ext::LimitedWriteExt; use g3_types::net::ConnectError; use crate::module::http_header; @@ -477,8 +478,7 @@ impl HttpProxyClientResponse { header.extend_from_slice(line.as_bytes()); } header.extend_from_slice(b"\r\n"); - writer.write_all(header.as_ref()).await?; - writer.flush().await?; + writer.write_all_flush(header.as_ref()).await?; Ok(()) } @@ -509,8 +509,7 @@ impl HttpProxyClientResponse { W: AsyncWrite + Unpin, { let s = format!("{version:?} 100 Continue\r\n\r\n"); - writer.write_all(s.as_bytes()).await?; - writer.flush().await?; + writer.write_all_flush(s.as_bytes()).await?; Ok(()) } @@ -546,8 +545,7 @@ impl HttpProxyClientResponse { // append body header.extend_from_slice(body.as_bytes()); - writer.write_all(header.as_ref()).await?; - writer.flush().await?; + writer.write_all_flush(header.as_ref()).await?; Ok(()) } diff --git a/g3proxy/src/serve/http_proxy/task/forward/task.rs b/g3proxy/src/serve/http_proxy/task/forward/task.rs index 5471baac7..bdca1c6fa 100644 --- a/g3proxy/src/serve/http_proxy/task/forward/task.rs +++ b/g3proxy/src/serve/http_proxy/task/forward/task.rs @@ -33,7 +33,7 @@ use g3_icap_client::reqmod::h1::{ use g3_icap_client::respmod::h1::{ HttpResponseAdapter, RespmodAdaptationEndState, RespmodAdaptationRunState, }; -use g3_io_ext::{LimitedBufReadExt, LimitedCopy, LimitedCopyError}; +use g3_io_ext::{LimitedBufReadExt, LimitedCopy, LimitedCopyError, LimitedWriteExt}; use g3_types::acl::AclAction; use g3_types::net::{HttpHeaderMap, ProxyRequestType}; @@ -991,6 +991,11 @@ impl<'a> HttpProxyForwardTask<'a> { LimitedCopyError::WriteFailed(e) => ServerTaskError::ClientTcpWriteFailed(e), })?; recv_body.save_connection().await; + } else { + clt_w + .flush() + .await + .map_err(ServerTaskError::ClientTcpWriteFailed)?; } Ok(()) @@ -1527,11 +1532,7 @@ impl<'a> HttpProxyForwardTask<'a> { { let buf = rsp.serialize(); clt_w - .write_all(buf.as_ref()) - .await - .map_err(ServerTaskError::ClientTcpWriteFailed)?; - clt_w - .flush() + .write_all_flush(buf.as_ref()) .await .map_err(ServerTaskError::ClientTcpWriteFailed) } diff --git a/g3proxy/src/serve/http_proxy/task/ftp/list.rs b/g3proxy/src/serve/http_proxy/task/ftp/list.rs index e9fdf1ae9..9a4d227ba 100644 --- a/g3proxy/src/serve/http_proxy/task/ftp/list.rs +++ b/g3proxy/src/serve/http_proxy/task/ftp/list.rs @@ -19,6 +19,7 @@ use std::io::{self, Write}; use tokio::io::{AsyncWrite, AsyncWriteExt, BufWriter}; use g3_ftp_client::FtpLineDataReceiver; +use g3_io_ext::LimitedWriteExt; const CHUNKED_BUF_HEAD_RESERVED: usize = (usize::BITS as usize >> 2) + 2; const CHUNKED_BUF_TAIL_RESERVED: usize = 2; @@ -113,8 +114,7 @@ where if self.buf_len > CHUNKED_BUF_HEAD_RESERVED { self.send_buf().await?; } - self.writer.write_all(b"0\r\n\r\n").await?; - self.writer.flush().await + self.writer.write_all_flush(b"0\r\n\r\n").await } #[inline] diff --git a/g3proxy/src/serve/http_rproxy/task/forward/task.rs b/g3proxy/src/serve/http_rproxy/task/forward/task.rs index 593f11d05..861a2a12d 100644 --- a/g3proxy/src/serve/http_rproxy/task/forward/task.rs +++ b/g3proxy/src/serve/http_rproxy/task/forward/task.rs @@ -24,7 +24,7 @@ use tokio::time::Instant; use g3_http::client::HttpForwardRemoteResponse; use g3_http::server::HttpProxyClientRequest; use g3_http::{HttpBodyReader, HttpBodyType}; -use g3_io_ext::{LimitedBufReadExt, LimitedCopy, LimitedCopyError}; +use g3_io_ext::{LimitedBufReadExt, LimitedCopy, LimitedCopyError, LimitedWriteExt}; use g3_types::acl::AclAction; use super::protocol::{HttpClientReader, HttpClientWriter, HttpRProxyRequest}; @@ -1049,11 +1049,7 @@ impl<'a> HttpRProxyForwardTask<'a> { { let buf = rsp.serialize(); clt_w - .write_all(buf.as_ref()) - .await - .map_err(ServerTaskError::ClientTcpWriteFailed)?; - clt_w - .flush() + .write_all_flush(buf.as_ref()) .await .map_err(ServerTaskError::ClientTcpWriteFailed) } diff --git a/g3tiles/src/module/keyless/protocol/response.rs b/g3tiles/src/module/keyless/protocol/response.rs index e40ffdb03..73ca4955f 100644 --- a/g3tiles/src/module/keyless/protocol/response.rs +++ b/g3tiles/src/module/keyless/protocol/response.rs @@ -100,8 +100,7 @@ impl KeylessInternalErrorResponse { where W: AsyncWrite + Unpin, { - writer.write_all(&self.buf).await?; - writer.flush().await + writer.write_all_flush(&self.buf).await } } diff --git a/lib/g3-daemon/src/control/text.rs b/lib/g3-daemon/src/control/text.rs index 9a199c6ef..acaf36914 100644 --- a/lib/g3-daemon/src/control/text.rs +++ b/lib/g3-daemon/src/control/text.rs @@ -17,11 +17,11 @@ use std::str::SplitWhitespace; use anyhow::anyhow; -use tokio::io::{AsyncBufRead, AsyncWrite, AsyncWriteExt}; +use tokio::io::{AsyncBufRead, AsyncWrite}; use tokio::time::Duration; use yaml_rust::Yaml; -use g3_io_ext::LimitedBufReadExt; +use g3_io_ext::{LimitedBufReadExt, LimitedWriteExt}; use super::{CtlProtoType, GeneralControllerConfig}; @@ -111,11 +111,12 @@ where async fn send_response(&mut self, response: &str) -> anyhow::Result<()> { let send_timeout = self.config.send_timeout; - let fut = async { - self.writer.write_all(response.as_bytes()).await?; - self.writer.flush().await - }; - match tokio::time::timeout(Duration::from_secs(send_timeout), fut).await? { + match tokio::time::timeout( + Duration::from_secs(send_timeout), + self.writer.write_all_flush(response.as_bytes()), + ) + .await? + { Ok(_) => Ok(()), Err(e) => Err(anyhow!("write: {e}")), } diff --git a/lib/g3-daemon/src/register/task.rs b/lib/g3-daemon/src/register/task.rs index 2e4f74d06..0764ce4d4 100644 --- a/lib/g3-daemon/src/register/task.rs +++ b/lib/g3-daemon/src/register/task.rs @@ -19,12 +19,12 @@ use std::sync::Arc; use anyhow::anyhow; use http::{Method, StatusCode}; use serde_json::{Map, Value}; -use tokio::io::{AsyncWriteExt, BufStream}; +use tokio::io::BufStream; use tokio::net::TcpStream; use g3_http::client::HttpForwardRemoteResponse; use g3_http::HttpBodyReader; -use g3_io_ext::LimitedBufReadExt; +use g3_io_ext::{LimitedBufReadExt, LimitedWriteExt}; use super::RegisterConfig; @@ -97,11 +97,7 @@ impl RegisterTask { async fn write_request(&mut self, data: &[u8]) -> anyhow::Result<()> { self.stream - .write_all(data) - .await - .map_err(|e| anyhow!("failed to write data: {e:?}"))?; - self.stream - .flush() + .write_all_flush(data) .await .map_err(|e| anyhow!("failed to write data: {e:?}")) } diff --git a/lib/g3-fluentd/src/config.rs b/lib/g3-fluentd/src/config.rs index bed939ab6..886b6c1f8 100644 --- a/lib/g3-fluentd/src/config.rs +++ b/lib/g3-fluentd/src/config.rs @@ -202,6 +202,10 @@ impl FluentdClientConfig { .write_all(ping_msg.as_slice()) .await .map_err(|e| anyhow!("failed to write ping msg: {e:?}"))?; + connection + .flush() + .await + .map_err(|e| anyhow!("failed to flush ping msg: {e:?}"))?; let mut pong_buf = Vec::with_capacity(1024); // TODO config let pong_len = connection diff --git a/lib/g3-ftp-client/src/control/command.rs b/lib/g3-ftp-client/src/control/command.rs index 070e516d7..c45ef6c0c 100644 --- a/lib/g3-ftp-client/src/control/command.rs +++ b/lib/g3-ftp-client/src/control/command.rs @@ -17,7 +17,9 @@ use std::fmt; use std::io; -use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt}; +use tokio::io::{AsyncRead, AsyncWrite}; + +use g3_io_ext::LimitedWriteExt; use super::FtpControlChannel; @@ -81,8 +83,7 @@ where #[cfg(feature = "log-raw-io")] crate::debug::log_cmd(unsafe { std::str::from_utf8_unchecked(buf).trim_end() }); - self.stream.write_all(buf).await?; - self.stream.flush().await?; + self.stream.write_all_flush(buf).await?; Ok(()) } diff --git a/lib/g3-io-ext/src/io/ext/limited_write_ext.rs b/lib/g3-io-ext/src/io/ext/limited_write_ext.rs index cb59c3f00..4dfcab621 100644 --- a/lib/g3-io-ext/src/io/ext/limited_write_ext.rs +++ b/lib/g3-io-ext/src/io/ext/limited_write_ext.rs @@ -18,6 +18,7 @@ use std::io::IoSlice; use tokio::io::AsyncWrite; +use super::write_all_flush::WriteAllFlush; use super::write_all_vectored::WriteAllVectored; pub trait LimitedWriteExt: AsyncWrite { @@ -30,6 +31,13 @@ pub trait LimitedWriteExt: AsyncWrite { { WriteAllVectored::new(self, bufs) } + + fn write_all_flush<'a>(&'a mut self, buf: &'a [u8]) -> WriteAllFlush<'a, Self> + where + Self: Unpin, + { + WriteAllFlush::new(self, buf) + } } impl LimitedWriteExt for W {} diff --git a/lib/g3-io-ext/src/io/ext/mod.rs b/lib/g3-io-ext/src/io/ext/mod.rs index 072212cf7..114935004 100644 --- a/lib/g3-io-ext/src/io/ext/mod.rs +++ b/lib/g3-io-ext/src/io/ext/mod.rs @@ -19,6 +19,7 @@ mod fill_wait_eof; mod limited_read_buf_until; mod limited_read_until; mod limited_skip_until; +mod write_all_flush; mod write_all_vectored; mod limited_buf_read_ext; diff --git a/lib/g3-io-ext/src/io/ext/write_all_flush.rs b/lib/g3-io-ext/src/io/ext/write_all_flush.rs new file mode 100644 index 000000000..4079f3f7b --- /dev/null +++ b/lib/g3-io-ext/src/io/ext/write_all_flush.rs @@ -0,0 +1,75 @@ +/* + * Copyright 2024 ByteDance and/or its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +use std::future::Future; +use std::io; +use std::mem; +use std::pin::Pin; +use std::task::{ready, Context, Poll}; + +use pin_project_lite::pin_project; +use tokio::io::AsyncWrite; + +pin_project! { + #[derive(Debug)] + #[must_use = "futures do nothing unless you `.await` or poll them"] + pub struct WriteAllFlush<'a, W: ?Sized> { + writer: &'a mut W, + buf: &'a [u8], + flush_done: bool, + } +} + +impl<'a, W> WriteAllFlush<'a, W> +where + W: AsyncWrite + Unpin + ?Sized, +{ + pub(crate) fn new(writer: &'a mut W, buf: &'a [u8]) -> Self { + WriteAllFlush { + writer, + buf, + flush_done: false, + } + } +} + +impl Future for WriteAllFlush<'_, W> +where + W: AsyncWrite + Unpin + ?Sized, +{ + type Output = io::Result<()>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let me = self.project(); + while !me.buf.is_empty() { + let n = ready!(Pin::new(&mut *me.writer).poll_write(cx, me.buf))?; + { + let (_, rest) = mem::take(&mut *me.buf).split_at(n); + *me.buf = rest; + } + if n == 0 { + return Poll::Ready(Err(io::ErrorKind::WriteZero.into())); + } + } + + if !*me.flush_done { + ready!(Pin::new(&mut *me.writer).poll_flush(cx))?; + *me.flush_done = true; + } + + Poll::Ready(Ok(())) + } +} diff --git a/lib/g3-socks/src/v4a/reply.rs b/lib/g3-socks/src/v4a/reply.rs index 47a38cbf0..626a560a2 100644 --- a/lib/g3-socks/src/v4a/reply.rs +++ b/lib/g3-socks/src/v4a/reply.rs @@ -17,7 +17,9 @@ use std::io; use std::net::{IpAddr, Ipv4Addr, SocketAddr}; -use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; +use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite}; + +use g3_io_ext::LimitedWriteExt; use super::{SocksNegotiationError, SocksReplyParseError}; @@ -91,8 +93,7 @@ impl SocksV4Reply { W: AsyncWrite + Unpin, { let buf: [u8; 8] = [0, self.code(), 0, 0, 0, 0, 0, 0]; - clt_w.write_all(&buf).await?; - clt_w.flush().await?; + clt_w.write_all_flush(&buf).await?; Ok(()) } diff --git a/lib/g3-socks/src/v4a/request.rs b/lib/g3-socks/src/v4a/request.rs index bbed48a0c..f8184c9b1 100644 --- a/lib/g3-socks/src/v4a/request.rs +++ b/lib/g3-socks/src/v4a/request.rs @@ -18,9 +18,9 @@ use std::io; use std::net::{IpAddr, Ipv4Addr}; use bytes::{BufMut, BytesMut}; -use tokio::io::{AsyncBufRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; +use tokio::io::{AsyncBufRead, AsyncReadExt, AsyncWrite}; -use g3_io_ext::LimitedBufReadExt; +use g3_io_ext::{LimitedBufReadExt, LimitedWriteExt}; use g3_types::net::{Host, UpstreamAddr}; use super::{SocksCommand, SocksNegotiationError, SocksRequestParseError}; @@ -140,7 +140,6 @@ impl SocksV4aRequest { buf } }; - writer.write_all(buf.as_ref()).await?; - writer.flush().await + writer.write_all_flush(buf.as_ref()).await } } diff --git a/lib/g3-socks/src/v5/auth.rs b/lib/g3-socks/src/v5/auth.rs index cbe617aab..3f910bdcd 100644 --- a/lib/g3-socks/src/v5/auth.rs +++ b/lib/g3-socks/src/v5/auth.rs @@ -20,6 +20,7 @@ use std::io; use bytes::{BufMut, BytesMut}; use tokio::io::{AsyncBufRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; +use g3_io_ext::LimitedWriteExt; use g3_types::auth::{Password, Username}; use g3_types::net::SocksAuth; @@ -71,8 +72,7 @@ where W: AsyncWrite + Unpin, { let msg = [0x05, method.code()]; - clt_w.write_all(&msg).await?; - clt_w.flush().await + clt_w.write_all_flush(&msg).await } async fn send_methods_to_remote(writer: &mut W, auth: &SocksAuth) -> io::Result<()> @@ -118,11 +118,7 @@ where buf.put_slice(password.as_original().as_bytes()); buf_stream - .write_all(buf.as_ref()) - .await - .map_err(SocksConnectError::WriteFailed)?; - buf_stream - .flush() + .write_all_flush(buf.as_ref()) .await .map_err(SocksConnectError::WriteFailed)?; @@ -180,8 +176,7 @@ where W: AsyncWrite + Unpin, { let buf = [0x01, 0x00]; - clt_w.write_all(&buf).await?; - clt_w.flush().await + clt_w.write_all_flush(&buf).await } pub async fn send_user_auth_failure(clt_w: &mut W) -> io::Result<()> @@ -189,6 +184,5 @@ where W: AsyncWrite + Unpin, { let buf = [0x01, 0x01]; - clt_w.write_all(&buf).await?; - clt_w.flush().await + clt_w.write_all_flush(&buf).await } diff --git a/lib/g3-socks/src/v5/reply.rs b/lib/g3-socks/src/v5/reply.rs index ee9a7f83b..362c43de4 100644 --- a/lib/g3-socks/src/v5/reply.rs +++ b/lib/g3-socks/src/v5/reply.rs @@ -18,7 +18,9 @@ use std::io; use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}; use bytes::{BufMut, BytesMut}; -use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; +use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite}; + +use g3_io_ext::LimitedWriteExt; use super::{SocksNegotiationError, SocksReplyParseError}; @@ -164,7 +166,6 @@ impl Socks5Reply { buf.put_slice(&[0x00, 0x00]); } } - clt_w.write_all(buf.as_ref()).await?; - clt_w.flush().await + clt_w.write_all_flush(buf.as_ref()).await } } diff --git a/lib/g3-socks/src/v5/request.rs b/lib/g3-socks/src/v5/request.rs index 82fccd894..36c7b0cf0 100644 --- a/lib/g3-socks/src/v5/request.rs +++ b/lib/g3-socks/src/v5/request.rs @@ -18,8 +18,9 @@ use std::io; use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}; use bytes::{BufMut, BytesMut}; -use tokio::io::{AsyncBufRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; +use tokio::io::{AsyncBufRead, AsyncReadExt, AsyncWrite}; +use g3_io_ext::LimitedWriteExt; use g3_types::net::{Host, UpstreamAddr}; use super::{SocksCommand, SocksNegotiationError, SocksRequestParseError}; @@ -121,7 +122,6 @@ impl Socks5Request { buf.put_u16(addr.port()); } } - writer.write_all(buf.as_ref()).await?; - writer.flush().await + writer.write_all_flush(buf.as_ref()).await } }