From ac77bdcfade06fdac0b6067a9474d9ce62d1599b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Sat, 19 Oct 2024 15:02:25 +0800 Subject: [PATCH] Add lazy conn support for gVisor --- stack_gvisor.go | 51 +++------- stack_gvisor_err.go | 50 ---------- stack_gvisor_lazy.go | 233 +++++++++++++++++++++++++++++++++++++++++++ stack_gvisor_udp.go | 71 +------------ stack_mixed.go | 3 +- 5 files changed, 248 insertions(+), 160 deletions(-) delete mode 100644 stack_gvisor_err.go create mode 100644 stack_gvisor_lazy.go diff --git a/stack_gvisor.go b/stack_gvisor.go index 795d963..fcdeeba 100644 --- a/stack_gvisor.go +++ b/stack_gvisor.go @@ -76,43 +76,17 @@ func (t *GVisor) Start() error { return err } tcpForwarder := tcp.NewForwarder(ipStack, 0, 1024, func(r *tcp.ForwarderRequest) { - var wq waiter.Queue - handshakeCtx, cancel := context.WithCancel(context.Background()) - go func() { - select { - case <-t.ctx.Done(): - wq.Notify(wq.Events()) - case <-handshakeCtx.Done(): - } - }() - endpoint, err := r.CreateEndpoint(&wq) - cancel() - if err != nil { - r.Complete(true) - return - } - r.Complete(false) - endpoint.SocketOptions().SetKeepAlive(true) - keepAliveIdle := tcpip.KeepaliveIdleOption(15 * time.Second) - endpoint.SetSockOpt(&keepAliveIdle) - keepAliveInterval := tcpip.KeepaliveIntervalOption(15 * time.Second) - endpoint.SetSockOpt(&keepAliveInterval) - tcpConn := gonet.NewTCPConn(&wq, endpoint) - lAddr := tcpConn.RemoteAddr() - rAddr := tcpConn.LocalAddr() - if lAddr == nil || rAddr == nil { - tcpConn.Close() - return + var metadata M.Metadata + metadata.Source = M.SocksaddrFrom(AddrFromAddress(r.ID().RemoteAddress), r.ID().RemotePort) + metadata.Destination = M.SocksaddrFrom(AddrFromAddress(r.ID().LocalAddress), r.ID().LocalPort) + conn := &gLazyConn{ + parentCtx: t.ctx, + stack: t.stack, + request: r, + localAddr: metadata.Source.TCPAddr(), + remoteAddr: metadata.Destination.TCPAddr(), } - go func() { - var metadata M.Metadata - metadata.Source = M.SocksaddrFromNet(lAddr) - metadata.Destination = M.SocksaddrFromNet(rAddr) - hErr := t.handler.NewConnection(t.ctx, &gTCPConn{tcpConn}, metadata) - if hErr != nil { - endpoint.Abort() - } - }() + _ = t.handler.NewConnection(t.ctx, conn, metadata) }) ipStack.SetTransportProtocolHandler(tcp.ProtocolNumber, tcpForwarder.HandlePacket) if !t.endpointIndependentNat { @@ -129,12 +103,11 @@ func (t *GVisor) Start() error { endpoint.Abort() return } - gConn := &gUDPConn{UDPConn: udpConn} go func() { var metadata M.Metadata metadata.Source = M.SocksaddrFromNet(lAddr) metadata.Destination = M.SocksaddrFromNet(rAddr) - ctx, conn := canceler.NewPacketConn(t.ctx, bufio.NewUnbindPacketConnWithAddr(gConn, metadata.Destination), time.Duration(t.udpTimeout)*time.Second) + ctx, conn := canceler.NewPacketConn(t.ctx, bufio.NewUnbindPacketConnWithAddr(udpConn, metadata.Destination), time.Duration(t.udpTimeout)*time.Second) hErr := t.handler.NewPacketConnection(ctx, conn, metadata) if hErr != nil { endpoint.Abort() @@ -191,7 +164,7 @@ func newGVisorStack(ep stack.LinkEndpoint) (*stack.Stack, error) { }) tErr := ipStack.CreateNIC(defaultNIC, ep) if tErr != nil { - return nil, E.New("create nic: ", wrapStackError(tErr)) + return nil, E.New("create nic: ", gonet.TranslateNetstackError(tErr)) } ipStack.SetRouteTable([]tcpip.Route{ {Destination: header.IPv4EmptySubnet, NIC: defaultNIC}, diff --git a/stack_gvisor_err.go b/stack_gvisor_err.go deleted file mode 100644 index 51ce23c..0000000 --- a/stack_gvisor_err.go +++ /dev/null @@ -1,50 +0,0 @@ -//go:build with_gvisor - -package tun - -import ( - "net" - - "github.com/sagernet/gvisor/pkg/tcpip" - "github.com/sagernet/gvisor/pkg/tcpip/adapters/gonet" - E "github.com/sagernet/sing/common/exceptions" -) - -type gTCPConn struct { - *gonet.TCPConn -} - -func (c *gTCPConn) Upstream() any { - return c.TCPConn -} - -func (c *gTCPConn) Write(b []byte) (n int, err error) { - n, err = c.TCPConn.Write(b) - if err == nil { - return - } - err = wrapError(err) - return -} - -func wrapStackError(err tcpip.Error) error { - switch err.(type) { - case *tcpip.ErrClosedForSend, - *tcpip.ErrClosedForReceive, - *tcpip.ErrAborted: - return net.ErrClosed - } - return E.New(err.String()) -} - -func wrapError(err error) error { - if opErr, isOpErr := err.(*net.OpError); isOpErr { - switch opErr.Err.Error() { - case "endpoint is closed for send", - "endpoint is closed for receive", - "operation aborted": - return net.ErrClosed - } - } - return err -} diff --git a/stack_gvisor_lazy.go b/stack_gvisor_lazy.go new file mode 100644 index 0000000..44e3010 --- /dev/null +++ b/stack_gvisor_lazy.go @@ -0,0 +1,233 @@ +//go:build with_gvisor + +package tun + +import ( + "context" + "errors" + "net" + "os" + "sync" + "syscall" + "time" + + "github.com/sagernet/gvisor/pkg/tcpip" + "github.com/sagernet/gvisor/pkg/tcpip/adapters/gonet" + "github.com/sagernet/gvisor/pkg/tcpip/header" + "github.com/sagernet/gvisor/pkg/tcpip/stack" + "github.com/sagernet/gvisor/pkg/tcpip/transport/tcp" + "github.com/sagernet/gvisor/pkg/waiter" +) + +type gLazyConn struct { + tcpConn *gonet.TCPConn + parentCtx context.Context + stack *stack.Stack + request *tcp.ForwarderRequest + localAddr net.Addr + remoteAddr net.Addr + handshakeAccess sync.Mutex + handshakeDone bool + handshakeErr error +} + +func (c *gLazyConn) HandshakeContext(ctx context.Context) error { + if c.handshakeDone { + return nil + } + c.handshakeAccess.Lock() + defer c.handshakeAccess.Unlock() + if c.handshakeDone { + return nil + } + defer func() { + c.handshakeDone = true + }() + var ( + wq waiter.Queue + endpoint tcpip.Endpoint + ) + handshakeCtx, cancel := context.WithCancel(ctx) + go func() { + select { + case <-c.parentCtx.Done(): + wq.Notify(wq.Events()) + case <-handshakeCtx.Done(): + } + }() + endpoint, err := c.request.CreateEndpoint(&wq) + cancel() + if err != nil { + gErr := gonet.TranslateNetstackError(err) + c.handshakeErr = gErr + c.request.Complete(true) + return gErr + } + c.request.Complete(false) + endpoint.SocketOptions().SetKeepAlive(true) + keepAliveIdle := tcpip.KeepaliveIdleOption(15 * time.Second) + endpoint.SetSockOpt(&keepAliveIdle) + keepAliveInterval := tcpip.KeepaliveIntervalOption(15 * time.Second) + endpoint.SetSockOpt(&keepAliveInterval) + tcpConn := gonet.NewTCPConn(&wq, endpoint) + c.tcpConn = tcpConn + return nil +} + +func (c *gLazyConn) HandshakeFailure(err error) error { + if c.handshakeDone { + return nil + } + c.handshakeAccess.Lock() + defer c.handshakeAccess.Unlock() + if c.handshakeDone { + return nil + } + wErr := gWriteUnreachable(c.stack, c.request.Packet(), err) + c.request.Complete(wErr == os.ErrInvalid) + c.handshakeDone = true + c.handshakeErr = err + return nil +} + +func (c *gLazyConn) HandshakeSuccess() error { + return c.HandshakeContext(context.Background()) +} + +func (c *gLazyConn) Read(b []byte) (n int, err error) { + if !c.handshakeDone { + err = c.HandshakeContext(context.Background()) + if err != nil { + return + } + } else if c.handshakeErr != nil { + return 0, c.handshakeErr + } + return c.tcpConn.Read(b) +} + +func (c *gLazyConn) Write(b []byte) (n int, err error) { + if !c.handshakeDone { + err = c.HandshakeContext(context.Background()) + if err != nil { + return + } + } else if c.handshakeErr != nil { + return 0, c.handshakeErr + } + return c.tcpConn.Write(b) +} + +func (c *gLazyConn) LocalAddr() net.Addr { + return c.localAddr +} + +func (c *gLazyConn) RemoteAddr() net.Addr { + return c.remoteAddr +} + +func (c *gLazyConn) SetDeadline(t time.Time) error { + if !c.handshakeDone { + err := c.HandshakeContext(context.Background()) + if err != nil { + return err + } + } else if c.handshakeErr != nil { + return c.handshakeErr + } + return c.tcpConn.SetDeadline(t) +} + +func (c *gLazyConn) SetReadDeadline(t time.Time) error { + if !c.handshakeDone { + err := c.HandshakeContext(context.Background()) + if err != nil { + return err + } + } else if c.handshakeErr != nil { + return c.handshakeErr + } + return c.tcpConn.SetReadDeadline(t) +} + +func (c *gLazyConn) SetWriteDeadline(t time.Time) error { + if !c.handshakeDone { + err := c.HandshakeContext(context.Background()) + if err != nil { + return err + } + } else if c.handshakeErr != nil { + return c.handshakeErr + } + return c.tcpConn.SetWriteDeadline(t) +} + +func (c *gLazyConn) Close() error { + c.handshakeAccess.Lock() + defer c.handshakeAccess.Unlock() + if !c.handshakeDone { + c.request.Complete(true) + c.handshakeErr = net.ErrClosed + return nil + } else if c.handshakeErr != nil { + return nil + } + return c.tcpConn.Close() +} + +func (c *gLazyConn) CloseRead() error { + c.handshakeAccess.Lock() + defer c.handshakeAccess.Unlock() + if !c.handshakeDone { + c.request.Complete(true) + c.handshakeErr = net.ErrClosed + return nil + } else if c.handshakeErr != nil { + return nil + } + return c.tcpConn.CloseRead() +} + +func (c *gLazyConn) CloseWrite() error { + c.handshakeAccess.Lock() + defer c.handshakeAccess.Unlock() + if !c.handshakeDone { + c.request.Complete(true) + c.handshakeErr = net.ErrClosed + return nil + } else if c.handshakeErr != nil { + return nil + } + return c.tcpConn.CloseRead() +} + +func gWriteUnreachable(gStack *stack.Stack, packet *stack.PacketBuffer, err error) error { + if errors.Is(err, syscall.ENETUNREACH) { + if packet.NetworkProtocolNumber == header.IPv4ProtocolNumber { + return gWriteUnreachable4(gStack, packet, stack.RejectIPv4WithICMPPortUnreachable) + } else { + return gWriteUnreachable6(gStack, packet, stack.RejectIPv6WithICMPNoRoute) + } + } else if errors.Is(err, syscall.EHOSTUNREACH) { + if packet.NetworkProtocolNumber == header.IPv4ProtocolNumber { + return gWriteUnreachable4(gStack, packet, stack.RejectIPv4WithICMPHostProhibited) + } else { + return gWriteUnreachable6(gStack, packet, stack.RejectIPv6WithICMPNoRoute) + } + } else if errors.Is(err, syscall.ECONNREFUSED) { + if packet.NetworkProtocolNumber == header.IPv4ProtocolNumber { + return gWriteUnreachable4(gStack, packet, stack.RejectIPv4WithICMPPortUnreachable) + } else { + return gWriteUnreachable6(gStack, packet, stack.RejectIPv6WithICMPPortUnreachable) + } + } + return os.ErrInvalid +} + +func gWriteUnreachable4(gStack *stack.Stack, packet *stack.PacketBuffer, icmpCode stack.RejectIPv4WithICMPType) error { + return gonet.TranslateNetstackError(gStack.NetworkProtocolInstance(header.IPv4ProtocolNumber).(stack.RejectIPv4WithHandler).SendRejectionError(packet, icmpCode, true)) +} + +func gWriteUnreachable6(gStack *stack.Stack, packet *stack.PacketBuffer, icmpCode stack.RejectIPv6WithICMPType) error { + return gonet.TranslateNetstackError(gStack.NetworkProtocolInstance(header.IPv6ProtocolNumber).(stack.RejectIPv6WithHandler).SendRejectionError(packet, icmpCode, true)) +} diff --git a/stack_gvisor_udp.go b/stack_gvisor_udp.go index 7d85185..a97eff4 100644 --- a/stack_gvisor_udp.go +++ b/stack_gvisor_udp.go @@ -4,12 +4,10 @@ package tun import ( "context" - "errors" "math" "net/netip" "os" "sync" - "syscall" "github.com/sagernet/gvisor/pkg/buffer" "github.com/sagernet/gvisor/pkg/tcpip" @@ -103,7 +101,7 @@ func (w *UDPBackWriter) WritePacket(packetBuffer *buf.Buffer, destination M.Sock false, ) if err != nil { - return wrapStackError(err) + return gonet.TranslateNetstackError(err) } defer route.Release() @@ -140,74 +138,9 @@ func (w *UDPBackWriter) WritePacket(packetBuffer *buf.Buffer, destination M.Sock }, packet) if err != nil { route.Stats().UDP.PacketSendErrors.Increment() - return wrapStackError(err) + return gonet.TranslateNetstackError(err) } route.Stats().UDP.PacketsSent.Increment() return nil } - -type gUDPConn struct { - *gonet.UDPConn -} - -func (c *gUDPConn) Read(b []byte) (n int, err error) { - n, err = c.UDPConn.Read(b) - if err == nil { - return - } - err = wrapError(err) - return -} - -func (c *gUDPConn) Write(b []byte) (n int, err error) { - n, err = c.UDPConn.Write(b) - if err == nil { - return - } - err = wrapError(err) - return -} - -func (c *gUDPConn) Close() error { - return c.UDPConn.Close() -} - -func gWriteUnreachable(gStack *stack.Stack, packet *stack.PacketBuffer, err error) (retErr error) { - if errors.Is(err, syscall.ENETUNREACH) { - if packet.NetworkProtocolNumber == header.IPv4ProtocolNumber { - return gWriteUnreachable4(gStack, packet, stack.RejectIPv4WithICMPNetUnreachable) - } else { - return gWriteUnreachable6(gStack, packet, stack.RejectIPv6WithICMPNoRoute) - } - } else if errors.Is(err, syscall.EHOSTUNREACH) { - if packet.NetworkProtocolNumber == header.IPv4ProtocolNumber { - return gWriteUnreachable4(gStack, packet, stack.RejectIPv4WithICMPHostUnreachable) - } else { - return gWriteUnreachable6(gStack, packet, stack.RejectIPv6WithICMPNoRoute) - } - } else if errors.Is(err, syscall.ECONNREFUSED) { - if packet.NetworkProtocolNumber == header.IPv4ProtocolNumber { - return gWriteUnreachable4(gStack, packet, stack.RejectIPv4WithICMPPortUnreachable) - } else { - return gWriteUnreachable6(gStack, packet, stack.RejectIPv6WithICMPPortUnreachable) - } - } - return nil -} - -func gWriteUnreachable4(gStack *stack.Stack, packet *stack.PacketBuffer, icmpCode stack.RejectIPv4WithICMPType) error { - err := gStack.NetworkProtocolInstance(header.IPv4ProtocolNumber).(stack.RejectIPv4WithHandler).SendRejectionError(packet, icmpCode, true) - if err != nil { - return wrapStackError(err) - } - return nil -} - -func gWriteUnreachable6(gStack *stack.Stack, packet *stack.PacketBuffer, icmpCode stack.RejectIPv6WithICMPType) error { - err := gStack.NetworkProtocolInstance(header.IPv6ProtocolNumber).(stack.RejectIPv6WithHandler).SendRejectionError(packet, icmpCode, true) - if err != nil { - return wrapStackError(err) - } - return nil -} diff --git a/stack_mixed.go b/stack_mixed.go index c1abbb7..4872c1f 100644 --- a/stack_mixed.go +++ b/stack_mixed.go @@ -63,12 +63,11 @@ func (m *Mixed) Start() error { endpoint.Abort() return } - gConn := &gUDPConn{UDPConn: udpConn} go func() { var metadata M.Metadata metadata.Source = M.SocksaddrFromNet(lAddr) metadata.Destination = M.SocksaddrFromNet(rAddr) - ctx, conn := canceler.NewPacketConn(m.ctx, bufio.NewUnbindPacketConnWithAddr(gConn, metadata.Destination), time.Duration(m.udpTimeout)*time.Second) + ctx, conn := canceler.NewPacketConn(m.ctx, bufio.NewUnbindPacketConnWithAddr(udpConn, metadata.Destination), time.Duration(m.udpTimeout)*time.Second) hErr := m.handler.NewPacketConnection(ctx, conn, metadata) if hErr != nil { endpoint.Abort()