From 69147df696cdfc8c2ca86d434504a136eb1ab52c Mon Sep 17 00:00:00 2001 From: Zhang Jingqiang Date: Fri, 7 Jun 2024 14:28:20 +0800 Subject: [PATCH] g3-io-ext: add LimitedBufCopy --- g3proxy/Cargo.toml | 2 +- lib/g3-io-ext/Cargo.toml | 3 +- lib/g3-io-ext/src/io/buf/copy.rs | 182 ++++++++++++++++++++++++++ lib/g3-io-ext/src/io/buf/mod.rs | 3 + lib/g3-io-ext/src/io/mod.rs | 2 +- lib/g3-io-ext/src/udp/relay/remote.rs | 2 + 6 files changed, 191 insertions(+), 3 deletions(-) create mode 100644 lib/g3-io-ext/src/io/buf/copy.rs diff --git a/g3proxy/Cargo.toml b/g3proxy/Cargo.toml index 2e7251a25..78871a427 100644 --- a/g3proxy/Cargo.toml +++ b/g3proxy/Cargo.toml @@ -70,7 +70,7 @@ g3-slog-types = { workspace = true, features = ["http"] } g3-yaml = { workspace = true, features = ["resolve", "rustls", "openssl", "acl-rule", "http", "ftp-client", "route", "dpi", "audit", "histogram", "geoip"] } g3-json = { workspace = true, features = ["acl-rule", "resolve", "http", "rustls", "openssl", "histogram"] } g3-msgpack.workspace = true -g3-io-ext.workspace = true +g3-io-ext = { workspace = true, features = ["resolver"] } g3-resolver.workspace = true g3-xcrypt.workspace = true g3-ftp-client.workspace = true diff --git a/lib/g3-io-ext/Cargo.toml b/lib/g3-io-ext/Cargo.toml index d6c0b46f7..d3af7d074 100644 --- a/lib/g3-io-ext/Cargo.toml +++ b/lib/g3-io-ext/Cargo.toml @@ -21,7 +21,7 @@ ahash.workspace = true smallvec.workspace = true quinn = { workspace = true, optional = true } g3-types.workspace = true -g3-resolver.workspace = true +g3-resolver = { workspace = true, optional = true } [target.'cfg(unix)'.dependencies] rustix = { workspace = true, features = ["std", "net"] } @@ -39,3 +39,4 @@ governor = { workspace = true, features = ["std", "jitter"] } [features] default = [] quic = ["dep:quinn"] +resolver = ["dep:g3-resolver"] diff --git a/lib/g3-io-ext/src/io/buf/copy.rs b/lib/g3-io-ext/src/io/buf/copy.rs new file mode 100644 index 000000000..cdbacdc21 --- /dev/null +++ b/lib/g3-io-ext/src/io/buf/copy.rs @@ -0,0 +1,182 @@ +/* + * Copyright 2024 ByteDance and/or its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +use std::future::{poll_fn, Future}; +use std::io; +use std::pin::Pin; +use std::task::{ready, Context, Poll}; + +use tokio::io::{AsyncBufRead, AsyncWrite, AsyncWriteExt}; + +use crate::{LimitedCopyConfig, LimitedCopyError}; + +pub struct LimitedBufCopy<'a, R: ?Sized, W: ?Sized> { + reader: &'a mut R, + writer: &'a mut W, + yield_size: usize, + total_write: u64, + buf_size: usize, + read_done: bool, + need_flush: bool, + active: bool, +} + +impl<'a, R, W> LimitedBufCopy<'a, R, W> +where + R: AsyncBufRead + Unpin + ?Sized, + W: AsyncWrite + Unpin + ?Sized, +{ + pub fn new(reader: &'a mut R, writer: &'a mut W, config: &LimitedCopyConfig) -> Self { + LimitedBufCopy { + reader, + writer, + yield_size: config.yield_size(), + total_write: 0, + buf_size: 0, + read_done: false, + need_flush: false, + active: false, + } + } + + #[inline] + pub fn no_cached_data(&self) -> bool { + self.buf_size == 0 + } + + #[inline] + pub fn finished(&self) -> bool { + self.read_done + } + + #[inline] + pub fn copied_size(&self) -> u64 { + self.total_write + } + + #[inline] + pub fn is_active(&self) -> bool { + self.active + } + + #[inline] + pub fn is_idle(&self) -> bool { + !self.active + } + + #[inline] + pub fn reset_active(&mut self) { + self.active = false; + } + + fn poll_write_cache(&mut self, cx: &mut Context<'_>) -> Poll> { + loop { + match Pin::new(&mut self.reader).poll_fill_buf(cx) { + Poll::Ready(Ok(buf)) => { + self.buf_size = buf.len(); + if buf.is_empty() { + self.read_done = true; + return Poll::Ready(Ok(())); + } + let i = ready!(Pin::new(&mut self.writer).poll_write(cx, buf)) + .map_err(LimitedCopyError::WriteFailed)?; + self.need_flush = true; + self.active = true; + self.buf_size -= i; + self.total_write += i as u64; + Pin::new(&mut *self.reader).consume(i); + } + Poll::Ready(Err(e)) => return Poll::Ready(Err(LimitedCopyError::ReadFailed(e))), + Poll::Pending => return Poll::Ready(Ok(())), + } + } + } + + pub async fn write_flush(&mut self) -> Result<(), LimitedCopyError> { + if self.read_done { + return Ok(()); + } + + if self.buf_size > 0 { + poll_fn(|cx| self.poll_write_cache(cx)).await?; + } + + if self.need_flush { + self.writer + .flush() + .await + .map_err(LimitedCopyError::WriteFailed)?; + } + + Ok(()) + } +} + +impl<'a, R, W> Future for LimitedBufCopy<'a, R, W> +where + R: AsyncBufRead + Unpin + ?Sized, + W: AsyncWrite + Unpin + ?Sized, +{ + type Output = Result; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let mut copy_this_round = 0; + loop { + let me = &mut *self; + let buffer = match Pin::new(&mut *me.reader).poll_fill_buf(cx) { + Poll::Ready(Ok(buffer)) => { + me.buf_size = buffer.len(); + if buffer.is_empty() { + if self.need_flush { + ready!(Pin::new(&mut self.writer).poll_flush(cx)) + .map_err(LimitedCopyError::WriteFailed)?; + } + self.read_done = true; + return Poll::Ready(Ok(self.total_write)); + } + buffer + } + Poll::Ready(Err(e)) => return Poll::Ready(Err(LimitedCopyError::ReadFailed(e))), + Poll::Pending => { + if self.need_flush { + ready!(Pin::new(&mut self.writer).poll_flush(cx)) + .map_err(LimitedCopyError::WriteFailed)?; + } + return Poll::Pending; + } + }; + + let i = ready!(Pin::new(&mut *me.writer).poll_write(cx, buffer)) + .map_err(LimitedCopyError::WriteFailed)?; + if i == 0 { + return Poll::Ready(Err(LimitedCopyError::WriteFailed( + io::ErrorKind::WriteZero.into(), + ))); + } + self.need_flush = true; + self.active = true; + self.buf_size -= i; + self.total_write += i as u64; + Pin::new(&mut *self.reader).consume(i); + + copy_this_round += i; + if copy_this_round >= self.yield_size { + cx.waker().wake_by_ref(); + return Poll::Pending; + } + } + } +} diff --git a/lib/g3-io-ext/src/io/buf/mod.rs b/lib/g3-io-ext/src/io/buf/mod.rs index 59eecde25..0f1889b3c 100644 --- a/lib/g3-io-ext/src/io/buf/mod.rs +++ b/lib/g3-io-ext/src/io/buf/mod.rs @@ -23,4 +23,7 @@ pub use limited::LimitedBufReader; mod once; pub use once::OnceBufReader; +mod copy; +pub use copy::LimitedBufCopy; + const DEFAULT_BUF_SIZE: usize = 8 * 1024; diff --git a/lib/g3-io-ext/src/io/mod.rs b/lib/g3-io-ext/src/io/mod.rs index 5d3fdf36c..45083abf7 100644 --- a/lib/g3-io-ext/src/io/mod.rs +++ b/lib/g3-io-ext/src/io/mod.rs @@ -29,7 +29,7 @@ pub use limited_write::{ }; mod buf; -pub use buf::{FlexBufReader, LimitedBufReader, OnceBufReader}; +pub use buf::{FlexBufReader, LimitedBufCopy, LimitedBufReader, OnceBufReader}; mod line_recv_buf; pub use line_recv_buf::{LineRecvBuf, RecvLineError}; diff --git a/lib/g3-io-ext/src/udp/relay/remote.rs b/lib/g3-io-ext/src/udp/relay/remote.rs index 900726ee1..ff1bc1b91 100644 --- a/lib/g3-io-ext/src/udp/relay/remote.rs +++ b/lib/g3-io-ext/src/udp/relay/remote.rs @@ -20,6 +20,7 @@ use std::task::{Context, Poll}; use thiserror::Error; +#[cfg(feature = "resolver")] use g3_resolver::ResolveError; use g3_types::net::UpstreamAddr; @@ -46,6 +47,7 @@ pub enum UdpRelayRemoteError { InvalidPacket(SocketAddr, String), #[error("address not supported")] AddressNotSupported, + #[cfg(feature = "resolver")] #[error("domain not resolved: {0}")] DomainNotResolved(#[from] ResolveError), #[error("forbidden target ip address: {0}")]