Skip to content

Commit

Permalink
g3-io-ext: add LimitedBufCopy
Browse files Browse the repository at this point in the history
  • Loading branch information
zh-jq-b committed Jun 7, 2024
1 parent 9846351 commit 69147df
Show file tree
Hide file tree
Showing 6 changed files with 191 additions and 3 deletions.
2 changes: 1 addition & 1 deletion g3proxy/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ g3-slog-types = { workspace = true, features = ["http"] }
g3-yaml = { workspace = true, features = ["resolve", "rustls", "openssl", "acl-rule", "http", "ftp-client", "route", "dpi", "audit", "histogram", "geoip"] }
g3-json = { workspace = true, features = ["acl-rule", "resolve", "http", "rustls", "openssl", "histogram"] }
g3-msgpack.workspace = true
g3-io-ext.workspace = true
g3-io-ext = { workspace = true, features = ["resolver"] }
g3-resolver.workspace = true
g3-xcrypt.workspace = true
g3-ftp-client.workspace = true
Expand Down
3 changes: 2 additions & 1 deletion lib/g3-io-ext/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ ahash.workspace = true
smallvec.workspace = true
quinn = { workspace = true, optional = true }
g3-types.workspace = true
g3-resolver.workspace = true
g3-resolver = { workspace = true, optional = true }

[target.'cfg(unix)'.dependencies]
rustix = { workspace = true, features = ["std", "net"] }
Expand All @@ -39,3 +39,4 @@ governor = { workspace = true, features = ["std", "jitter"] }
[features]
default = []
quic = ["dep:quinn"]
resolver = ["dep:g3-resolver"]
182 changes: 182 additions & 0 deletions lib/g3-io-ext/src/io/buf/copy.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,182 @@
/*
* Copyright 2024 ByteDance and/or its affiliates.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

use std::future::{poll_fn, Future};
use std::io;
use std::pin::Pin;
use std::task::{ready, Context, Poll};

use tokio::io::{AsyncBufRead, AsyncWrite, AsyncWriteExt};

use crate::{LimitedCopyConfig, LimitedCopyError};

pub struct LimitedBufCopy<'a, R: ?Sized, W: ?Sized> {
reader: &'a mut R,
writer: &'a mut W,
yield_size: usize,
total_write: u64,
buf_size: usize,
read_done: bool,
need_flush: bool,
active: bool,
}

impl<'a, R, W> LimitedBufCopy<'a, R, W>
where
R: AsyncBufRead + Unpin + ?Sized,
W: AsyncWrite + Unpin + ?Sized,
{
pub fn new(reader: &'a mut R, writer: &'a mut W, config: &LimitedCopyConfig) -> Self {
LimitedBufCopy {
reader,
writer,
yield_size: config.yield_size(),
total_write: 0,
buf_size: 0,
read_done: false,
need_flush: false,
active: false,
}
}

#[inline]
pub fn no_cached_data(&self) -> bool {
self.buf_size == 0
}

#[inline]
pub fn finished(&self) -> bool {
self.read_done
}

#[inline]
pub fn copied_size(&self) -> u64 {
self.total_write
}

#[inline]
pub fn is_active(&self) -> bool {
self.active
}

#[inline]
pub fn is_idle(&self) -> bool {
!self.active
}

#[inline]
pub fn reset_active(&mut self) {
self.active = false;
}

fn poll_write_cache(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), LimitedCopyError>> {
loop {
match Pin::new(&mut self.reader).poll_fill_buf(cx) {
Poll::Ready(Ok(buf)) => {
self.buf_size = buf.len();
if buf.is_empty() {
self.read_done = true;
return Poll::Ready(Ok(()));
}
let i = ready!(Pin::new(&mut self.writer).poll_write(cx, buf))
.map_err(LimitedCopyError::WriteFailed)?;
self.need_flush = true;
self.active = true;
self.buf_size -= i;
self.total_write += i as u64;
Pin::new(&mut *self.reader).consume(i);
}
Poll::Ready(Err(e)) => return Poll::Ready(Err(LimitedCopyError::ReadFailed(e))),
Poll::Pending => return Poll::Ready(Ok(())),
}
}
}

pub async fn write_flush(&mut self) -> Result<(), LimitedCopyError> {
if self.read_done {
return Ok(());
}

if self.buf_size > 0 {
poll_fn(|cx| self.poll_write_cache(cx)).await?;
}

if self.need_flush {
self.writer
.flush()
.await
.map_err(LimitedCopyError::WriteFailed)?;
}

Ok(())
}
}

impl<'a, R, W> Future for LimitedBufCopy<'a, R, W>
where
R: AsyncBufRead + Unpin + ?Sized,
W: AsyncWrite + Unpin + ?Sized,
{
type Output = Result<u64, LimitedCopyError>;

fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let mut copy_this_round = 0;
loop {
let me = &mut *self;
let buffer = match Pin::new(&mut *me.reader).poll_fill_buf(cx) {
Poll::Ready(Ok(buffer)) => {
me.buf_size = buffer.len();
if buffer.is_empty() {
if self.need_flush {
ready!(Pin::new(&mut self.writer).poll_flush(cx))
.map_err(LimitedCopyError::WriteFailed)?;
}
self.read_done = true;
return Poll::Ready(Ok(self.total_write));
}
buffer
}
Poll::Ready(Err(e)) => return Poll::Ready(Err(LimitedCopyError::ReadFailed(e))),
Poll::Pending => {
if self.need_flush {
ready!(Pin::new(&mut self.writer).poll_flush(cx))
.map_err(LimitedCopyError::WriteFailed)?;
}
return Poll::Pending;
}
};

let i = ready!(Pin::new(&mut *me.writer).poll_write(cx, buffer))
.map_err(LimitedCopyError::WriteFailed)?;
if i == 0 {
return Poll::Ready(Err(LimitedCopyError::WriteFailed(
io::ErrorKind::WriteZero.into(),
)));
}
self.need_flush = true;
self.active = true;
self.buf_size -= i;
self.total_write += i as u64;
Pin::new(&mut *self.reader).consume(i);

copy_this_round += i;
if copy_this_round >= self.yield_size {
cx.waker().wake_by_ref();
return Poll::Pending;
}
}
}
}
3 changes: 3 additions & 0 deletions lib/g3-io-ext/src/io/buf/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,4 +23,7 @@ pub use limited::LimitedBufReader;
mod once;
pub use once::OnceBufReader;

mod copy;
pub use copy::LimitedBufCopy;

const DEFAULT_BUF_SIZE: usize = 8 * 1024;
2 changes: 1 addition & 1 deletion lib/g3-io-ext/src/io/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ pub use limited_write::{
};

mod buf;
pub use buf::{FlexBufReader, LimitedBufReader, OnceBufReader};
pub use buf::{FlexBufReader, LimitedBufCopy, LimitedBufReader, OnceBufReader};

mod line_recv_buf;
pub use line_recv_buf::{LineRecvBuf, RecvLineError};
Expand Down
2 changes: 2 additions & 0 deletions lib/g3-io-ext/src/udp/relay/remote.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ use std::task::{Context, Poll};

use thiserror::Error;

#[cfg(feature = "resolver")]
use g3_resolver::ResolveError;
use g3_types::net::UpstreamAddr;

Expand All @@ -46,6 +47,7 @@ pub enum UdpRelayRemoteError {
InvalidPacket(SocketAddr, String),
#[error("address not supported")]
AddressNotSupported,
#[cfg(feature = "resolver")]
#[error("domain not resolved: {0}")]
DomainNotResolved(#[from] ResolveError),
#[error("forbidden target ip address: {0}")]
Expand Down

0 comments on commit 69147df

Please sign in to comment.