Skip to content

Commit

Permalink
fix: properly close UDP associations
Browse files Browse the repository at this point in the history
  • Loading branch information
sbruens committed Feb 10, 2025
1 parent 0387dfb commit 3c158cf
Showing 1 changed file with 101 additions and 28 deletions.
129 changes: 101 additions & 28 deletions service/udp.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (
"log/slog"
"net"
"net/netip"
"os"
"runtime/debug"
"sync"
"time"
Expand Down Expand Up @@ -160,9 +161,16 @@ func (h *associationHandler) HandleAssociation(ctx context.Context, clientConn n
default:
}
clientProxyBytes, err := clientConn.Read(readBuf)
if errors.Is(err, net.ErrClosed) || errors.Is(err, io.EOF) {
debugUDP(l, "Client closed connection")
break
if err != nil {
if errors.Is(err, net.ErrClosed) || errors.Is(err, io.EOF) {
debugUDP(l, "Client closed connection")
break
}
if errors.Is(err, os.ErrDeadlineExceeded) {
debugUDP(l, "Read deadline exceeded")
break
}
l.Warn("Failed to read from client. Continuing to listen.", "err", err)
}
pkt := readBuf[:clientProxyBytes]
debugUDP(l, "Outbound packet.", slog.Int("bytes", clientProxyBytes))
Expand Down Expand Up @@ -199,7 +207,10 @@ func (h *associationHandler) HandleAssociation(ctx context.Context, clientConn n
return onet.NewConnectionError("ERR_CREATE_SOCKET", "Failed to create a `PacketConn`", err)
}
l = l.With(slog.Any("tgtListener", targetConn.LocalAddr()))
go relayTargetToClient(targetConn, clientConn, cryptoKey, assocMetrics, l)
go func() {
relayTargetToClient(targetConn, clientConn, cryptoKey, assocMetrics, l)
clientConn.SetReadDeadline(time.Now())
}()
} else {
unpackStart := time.Now()
textData, err := shadowsocks.Unpack(nil, pkt, cryptoKey)
Expand Down Expand Up @@ -297,6 +308,7 @@ func PacketServe(clientConn net.PacketConn, assocHandle AssociationHandleFunc, m
clientAddr: clientAddr,
readCh: make(chan *packet, 5),
doneCh: make(chan struct{}),
timeoutCh: make(chan struct{}),
}
if err != nil {
slog.Error("Failed to handle association", slog.Any("err", err))
Expand Down Expand Up @@ -344,49 +356,110 @@ type association struct {
clientAddr net.Addr
readCh chan *packet
doneCh chan struct{}

readDeadline time.Time
readTimeout *time.Timer
timeoutCh chan struct{} // Channel to signal timeout changes.
}

var _ net.Conn = (*association)(nil)

func (c *association) Read(p []byte) (int, error) {
pkt, ok := <-c.readCh
if !ok {
return 0, net.ErrClosed
}
n := copy(p, pkt.payload)
pkt.done()
if n < len(pkt.payload) {
return n, io.ErrShortBuffer
func (a *association) Read(p []byte) (int, error) {
for {
now := time.Now()
if !a.readDeadline.IsZero() && now.After(a.readDeadline) {
if a.readTimeout != nil {
if !a.readTimeout.Stop() {
<-a.readTimeout.C
}
}
return 0, os.ErrDeadlineExceeded
}

select {
case pkt, ok := <-a.readCh:
if !ok {
return 0, net.ErrClosed
}
if a.readTimeout != nil {
if !a.readTimeout.Stop() {
<-a.readTimeout.C
}
}
n := copy(p, pkt.payload)
pkt.done()
if n < len(pkt.payload) {
return n, io.ErrShortBuffer
}
return n, nil

case <-a.timeoutCh:
// The timeout has changed. The next loop iteration will handle it.
case <-func() <-chan time.Time {
if a.readTimeout == nil {
return nil
}
return a.readTimeout.C
}():
if !a.readDeadline.IsZero() && now.After(a.readDeadline) {
return 0, os.ErrDeadlineExceeded
}
}
}
return n, nil
}

func (c *association) Write(b []byte) (n int, err error) {
return c.pc.WriteTo(b, c.clientAddr)
func (a *association) Write(b []byte) (n int, err error) {
return a.pc.WriteTo(b, a.clientAddr)
}

func (c *association) Close() error {
close(c.readCh)
return c.pc.Close()
func (a *association) Close() error {
close(a.readCh)
return nil
}

func (c *association) LocalAddr() net.Addr {
return c.pc.LocalAddr()
func (a *association) LocalAddr() net.Addr {
return a.pc.LocalAddr()
}

func (c *association) RemoteAddr() net.Addr {
return c.clientAddr
func (a *association) RemoteAddr() net.Addr {
return a.clientAddr
}

func (c *association) SetDeadline(t time.Time) error {
return errors.ErrUnsupported
func (a *association) SetDeadline(t time.Time) error {
e1 := a.SetReadDeadline(t)
e2 := a.SetWriteDeadline(t)
if e1 != nil {
return e1
}
return e2
}

func (c *association) SetReadDeadline(t time.Time) error {
return errors.ErrUnsupported
func (a *association) SetReadDeadline(t time.Time) error {
a.readDeadline = t

if a.readTimeout != nil {
if !a.readTimeout.Stop() {
<-a.readTimeout.C
}
}

if !t.IsZero() {
if a.readTimeout == nil {
a.readTimeout = time.NewTimer(time.Until(t))
} else {
a.readTimeout.Reset(time.Until(t))
}
}

select {
case a.timeoutCh <- struct{}{}:
default:
}

return nil
}

func (c *association) SetWriteDeadline(t time.Time) error {
func (a *association) SetWriteDeadline(t time.Time) error {
return errors.ErrUnsupported
}

Expand Down

0 comments on commit 3c158cf

Please sign in to comment.