Skip to content

Commit

Permalink
udpnat2: New synced udp nat service
Browse files Browse the repository at this point in the history
  • Loading branch information
nekohasekai committed Oct 22, 2024
1 parent b07fb48 commit 50fa1ab
Show file tree
Hide file tree
Showing 12 changed files with 307 additions and 40 deletions.
9 changes: 1 addition & 8 deletions common/bufio/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,14 +35,7 @@ func (w *ExtendedUDPConn) ReadPacket(buffer *buf.Buffer) (M.Socksaddr, error) {

func (w *ExtendedUDPConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
defer buffer.Release()
if destination.IsFqdn() {
udpAddr, err := net.ResolveUDPAddr("udp", destination.String())
if err != nil {
return err
}
return common.Error(w.UDPConn.WriteTo(buffer.Bytes(), udpAddr))
}
return common.Error(w.UDPConn.WriteToUDP(buffer.Bytes(), destination.UDPAddr()))
return common.Error(w.UDPConn.WriteToUDPAddrPort(buffer.Bytes(), destination.AddrPort()))
}

func (w *ExtendedUDPConn) Upstream() any {
Expand Down
2 changes: 1 addition & 1 deletion common/network/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ type UDPHandler interface {
}

type UDPHandlerEx interface {
NewPacket(ctx context.Context, conn PacketConn, buffer *buf.Buffer, source M.Socksaddr, destination M.Socksaddr) error
NewPacketEx(buffer *buf.Buffer, source M.Socksaddr, destination M.Socksaddr)
}

// Deprecated: Use UDPConnectionHandlerEx instead.
Expand Down
22 changes: 17 additions & 5 deletions common/network/direct.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,27 @@ func (o ReadWaitOptions) NeedHeadroom() bool {
return o.FrontHeadroom > 0 || o.RearHeadroom > 0
}

func (o ReadWaitOptions) Copy(buffer *buf.Buffer) *buf.Buffer {
if o.FrontHeadroom > buffer.Start() ||
o.RearHeadroom > buffer.FreeLen() {
newBuffer := o.newBuffer(buf.UDPBufferSize, false)
newBuffer.Write(buffer.Bytes())
buffer.Release()
return newBuffer
} else {
return buffer
}
}

func (o ReadWaitOptions) NewBuffer() *buf.Buffer {
return o.newBuffer(buf.BufferSize)
return o.newBuffer(buf.BufferSize, true)
}

func (o ReadWaitOptions) NewPacketBuffer() *buf.Buffer {
return o.newBuffer(buf.UDPBufferSize)
return o.newBuffer(buf.UDPBufferSize, true)
}

func (o ReadWaitOptions) newBuffer(defaultBufferSize int) *buf.Buffer {
func (o ReadWaitOptions) newBuffer(defaultBufferSize int, reserve bool) *buf.Buffer {
var bufferSize int
if o.MTU > 0 {
bufferSize = o.MTU + o.FrontHeadroom + o.RearHeadroom
Expand All @@ -36,9 +48,9 @@ func (o ReadWaitOptions) newBuffer(defaultBufferSize int) *buf.Buffer {
}
buffer := buf.NewSize(bufferSize)
if o.FrontHeadroom > 0 {
buffer.Resize(o.FrontHeadroom, 0)
buffer.Advance(o.FrontHeadroom)
}
if o.RearHeadroom > 0 {
if o.RearHeadroom > 0 && reserve {
buffer.Reserve(o.RearHeadroom)
}
return buffer
Expand Down
6 changes: 0 additions & 6 deletions common/udpnat/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -131,8 +131,6 @@ func (s *Service[T]) NewContextPacketEx(ctx context.Context, key T, buffer *buf.
s.nat.Delete(key)
}
}()
} else {
c.localAddr = source
}
if common.Done(c.ctx) {
s.nat.Delete(key)
Expand Down Expand Up @@ -215,10 +213,6 @@ func (c *conn) SetWriteDeadline(t time.Time) error {
return os.ErrInvalid
}

func (c *conn) NeedAdditionalReadDeadline() bool {
return true
}

func (c *conn) Upstream() any {
return c.source
}
90 changes: 90 additions & 0 deletions common/udpnat2/conn.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
package udpnat

import (
"io"
"net"
"os"
"time"

"github.com/sagernet/sing/common/buf"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/common/pipe"
)

type natConn struct {
writer N.PacketWriter
localAddr M.Socksaddr
packetChan chan *Packet
doneChan chan struct{}
readDeadline pipe.Deadline
readWaitOptions N.ReadWaitOptions
}

func (c *natConn) ReadPacket(buffer *buf.Buffer) (addr M.Socksaddr, err error) {
select {
case p := <-c.packetChan:
_, err = buffer.ReadOnceFrom(p.Buffer)
destination := p.Destination
p.Buffer.Release()
PutPacket(p)
return destination, err
case <-c.doneChan:
return M.Socksaddr{}, io.ErrClosedPipe
case <-c.readDeadline.Wait():
return M.Socksaddr{}, os.ErrDeadlineExceeded
}
}

func (c *natConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
return c.writer.WritePacket(buffer, destination)
}

func (c *natConn) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) {
c.readWaitOptions = options
return false
}

func (c *natConn) WaitReadPacket() (buffer *buf.Buffer, destination M.Socksaddr, err error) {
select {
case packet := <-c.packetChan:
buffer = c.readWaitOptions.Copy(packet.Buffer)
destination = packet.Destination
PutPacket(packet)
return
case <-c.doneChan:
return nil, M.Socksaddr{}, io.ErrClosedPipe
case <-c.readDeadline.Wait():
return nil, M.Socksaddr{}, os.ErrDeadlineExceeded
}
}

func (c *natConn) Close() error {
select {
case <-c.doneChan:
default:
close(c.doneChan)
}
return nil
}

func (c *natConn) LocalAddr() net.Addr {
return c.localAddr
}

func (c *natConn) RemoteAddr() net.Addr {
return M.Socksaddr{}
}

func (c *natConn) SetDeadline(t time.Time) error {
return os.ErrInvalid
}

func (c *natConn) SetReadDeadline(t time.Time) error {
c.readDeadline.Set(t)
return nil
}

func (c *natConn) SetWriteDeadline(t time.Time) error {
return os.ErrInvalid
}
28 changes: 28 additions & 0 deletions common/udpnat2/packet.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
package udpnat

import (
M "github.com/sagernet/sing/common/metadata"
"sync"

"github.com/sagernet/sing/common/buf"
)

var packetPool = sync.Pool{
New: func() any {
return new(Packet)
},
}

type Packet struct {
Buffer *buf.Buffer
Destination M.Socksaddr
}

func NewPacket() *Packet {
return packetPool.Get().(*Packet)
}

func PutPacket(packet *Packet) {
*packet = Packet{}
packetPool.Put(packet)
}
93 changes: 93 additions & 0 deletions common/udpnat2/service.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
package udpnat

import (
"context"
"net/netip"
"time"

"github.com/sagernet/sing/common"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/common/pipe"
"github.com/sagernet/sing/contrab/freelru"
"github.com/sagernet/sing/contrab/maphash"
)

type Service struct {
nat *freelru.LRU[netip.AddrPort, *natConn]
handler N.UDPConnectionHandlerEx
prepare PrepareFunc
metrics Metrics
}

type PrepareFunc func(source M.Socksaddr, destination M.Socksaddr, userData any) (bool, context.Context, N.PacketWriter, N.CloseHandlerFunc)

type Metrics struct {
Creates uint64
Rejects uint64
Inputs uint64
Drops uint64
}

func New(handler N.UDPConnectionHandlerEx, prepare PrepareFunc, timeout time.Duration) *Service {
nat := common.Must1(freelru.New[netip.AddrPort, *natConn](1024, maphash.NewHasher[netip.AddrPort]().Hash32))
nat.SetLifetime(timeout)
nat.SetHealthCheck(func(port netip.AddrPort, conn *natConn) bool {
select {
case <-conn.doneChan:
return false
default:
return true
}
})
nat.SetOnEvict(func(_ netip.AddrPort, conn *natConn) {
conn.Close()
})
return &Service{
nat: nat,
handler: handler,
prepare: prepare,
}
}

func (s *Service) NewPacket(bufferSlices [][]byte, source M.Socksaddr, destination M.Socksaddr, userData any) {
conn, loaded := s.nat.Get(source.AddrPort())
if !loaded {
ok, ctx, writer, onClose := s.prepare(source, destination, userData)
if !ok {
s.metrics.Rejects++
return
}
conn = &natConn{
writer: writer,
localAddr: source,
packetChan: make(chan *Packet, 64),
doneChan: make(chan struct{}),
readDeadline: pipe.MakeDeadline(),
}
s.nat.Add(source.AddrPort(), conn)
s.handler.NewPacketConnectionEx(ctx, conn, source, destination, onClose)
s.metrics.Creates++
}
packet := NewPacket()
buffer := conn.readWaitOptions.NewPacketBuffer()
for _, bufferSlice := range bufferSlices {
buffer.Write(bufferSlice)
}
*packet = Packet{
Buffer: buffer,
Destination: destination,
}
select {
case conn.packetChan <- packet:
s.metrics.Inputs++
default:
packet.Buffer.Release()
PutPacket(packet)
s.metrics.Drops++
}
}

func (s *Service) Metrics() Metrics {
return s.metrics
}
Loading

0 comments on commit 50fa1ab

Please sign in to comment.