Skip to content

Commit

Permalink
connbinder add hook
Browse files Browse the repository at this point in the history
  • Loading branch information
snail007 committed Feb 17, 2025
1 parent 64d62d5 commit d6af061
Showing 1 changed file with 48 additions and 20 deletions.
68 changes: 48 additions & 20 deletions util/net/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -588,17 +588,27 @@ func (b *defaultBufferedConn) PeekMax(n int) (d []byte, err error) {
}

type ConnBinder struct {
src net.Conn
dst net.Conn
onSrcClose CloseHandler
onDstClose CloseHandler
onClose func()
readBufSize int
trafficBytes *int64
ctx Context
started bool
autoClose bool
err error
src net.Conn
dst net.Conn
onSrcClose CloseHandler
onDstClose CloseHandler
onClose func()
readBufSize int
trafficBytes *int64
ctx Context
started bool
autoClose bool
err error
afterSrcFirstRead func(b []byte) []byte
afterDstFirstRead func(b []byte) []byte
}

func (s *ConnBinder) SetAfterSrcFirstRead(afterSrcFirstRead func(b []byte) []byte) {
s.afterSrcFirstRead = afterSrcFirstRead
}

func (s *ConnBinder) SetAfterDstFirstRead(afterDstFirstRead func(b []byte) []byte) {
s.afterDstFirstRead = afterDstFirstRead
}

func (s *ConnBinder) setError(err error) {
Expand Down Expand Up @@ -640,24 +650,42 @@ func (s *ConnBinder) OnClose(onClose func()) *ConnBinder {
return s
}

func (s *ConnBinder) copy(src, dst net.Conn) error {
func (s *ConnBinder) copy(a, b net.Conn, aIsSrc bool) error {
buf := gbytes.GetPool(s.readBufSize).Get().([]byte)
defer func() {
if s.autoClose {
src.Close()
dst.Close()
a.Close()
b.Close()
}
gbytes.GetPool(s.readBufSize).Put(buf)
}()
isFirst := true
for {
n, err := src.Read(buf)
n, err := a.Read(buf)
if err != nil {
err = errors.Wrap(err, "failed to read from src: "+src.RemoteAddr().String())
err = errors.Wrap(err, "failed to read from src: "+a.RemoteAddr().String())
}
if n > 0 {
_, err = dst.Write(buf[:n])
if isFirst {
isFirst = false
if aIsSrc {
if s.afterSrcFirstRead != nil {
_, err = b.Write(s.afterSrcFirstRead(buf[:n]))
} else {
_, err = b.Write(buf[:n])
}
} else {
if s.afterDstFirstRead != nil {
_, err = b.Write(s.afterDstFirstRead(buf[:n]))
} else {
_, err = b.Write(buf[:n])
}
}
} else {
_, err = b.Write(buf[:n])
}
if err != nil {
err = errors.Wrap(err, "failed to write to dst: "+dst.RemoteAddr().String())
err = errors.Wrap(err, "failed to write to dst: "+b.RemoteAddr().String())
}
atomic.AddInt64(s.trafficBytes, int64(n))
}
Expand All @@ -675,12 +703,12 @@ func (s *ConnBinder) StartAndWait() {
g.Add(2)
go func() {
defer g.Done()
s.setError(s.copy(s.src, s.dst))
s.setError(s.copy(s.src, s.dst, true))
s.onSrcClose(s.ctx)
}()
go func() {
defer g.Done()
s.setError(s.copy(s.dst, s.src))
s.setError(s.copy(s.dst, s.src, false))
s.onDstClose(s.ctx)
}()
g.Wait()
Expand Down

0 comments on commit d6af061

Please sign in to comment.