diff --git a/stack_gvisor.go b/stack_gvisor.go index 6a1d0f3..ca41a37 100644 --- a/stack_gvisor.go +++ b/stack_gvisor.go @@ -70,7 +70,7 @@ func (t *GVisor) Start() error { if err != nil { return err } - linkEndpoint = &LinkEndpointFilter{linkEndpoint, t.broadcastAddr, bufio.NewVectorisedWriter(t.tun)} + linkEndpoint = &LinkEndpointFilter{linkEndpoint, t.broadcastAddr, t.tun} ipStack, err := newGVisorStack(linkEndpoint) if err != nil { return err diff --git a/stack_mixed.go b/stack_mixed.go index 17aa3d5..b52c996 100644 --- a/stack_mixed.go +++ b/stack_mixed.go @@ -13,17 +13,14 @@ import ( "github.com/sagernet/gvisor/pkg/tcpip/transport/udp" "github.com/sagernet/gvisor/pkg/waiter" "github.com/sagernet/sing-tun/internal/clashtcpip" - "github.com/sagernet/sing/common" "github.com/sagernet/sing/common/bufio" "github.com/sagernet/sing/common/canceler" E "github.com/sagernet/sing/common/exceptions" M "github.com/sagernet/sing/common/metadata" - N "github.com/sagernet/sing/common/network" ) type Mixed struct { *System - writer N.VectorisedWriter endpointIndependentNat bool stack *stack.Stack endpoint *channel.Endpoint @@ -38,7 +35,6 @@ func NewMixed( } return &Mixed{ System: system.(*System), - writer: bufio.NewVectorisedWriter(options.Tun), endpointIndependentNat: options.EndpointIndependentNat, }, nil } @@ -95,7 +91,6 @@ func (m *Mixed) tunLoop() { m.wintunLoop(winTun) return } - if batchTUN, isBatchTUN := m.tun.(BatchTUN); isBatchTUN { batchSize := batchTUN.BatchSize() if batchSize > 1 { @@ -118,7 +113,12 @@ func (m *Mixed) tunLoop() { } rawPacket := packetBuffer[:frontHeadroom+n] packet := packetBuffer[frontHeadroom+PacketOffset : frontHeadroom+n] - m.processPacket(rawPacket, packet) + if m.processPacket(packet) { + _, err = m.tun.Write(rawPacket) + if err != nil { + m.logger.Trace(E.Cause(err, "write packet")) + } + } } } @@ -132,7 +132,12 @@ func (m *Mixed) wintunLoop(winTun WinTun) { release() continue } - m.processPacket(packet, packet) + if m.processPacket(packet) { + _, err = winTun.Write(packet) + if err != nil { + m.logger.Trace(E.Cause(err, "write packet")) + } + } release() } } @@ -141,6 +146,7 @@ func (m *Mixed) batchLoop(linuxTUN BatchTUN, batchSize int) { frontHeadroom := m.tun.FrontHeadroom() packetBuffers := make([][]byte, batchSize) readBuffers := make([][]byte, batchSize) + writeBuffers := make([][]byte, batchSize) packetSizes := make([]int, batchSize) for i := range packetBuffers { packetBuffers[i] = make([]byte, m.mtu+frontHeadroom+PacketOffset) @@ -163,69 +169,85 @@ func (m *Mixed) batchLoop(linuxTUN BatchTUN, batchSize int) { continue } packetBuffer := packetBuffers[i] - rawPacket := packetBuffer[:frontHeadroom+packetSize] packet := packetBuffer[frontHeadroom+PacketOffset : frontHeadroom+packetSize] - m.processPacket(rawPacket, packet) + if m.processPacket(packet) { + writeBuffers = append(writeBuffers, packetBuffer[:frontHeadroom+packetSize]) + } + } + if len(writeBuffers) > 0 { + err = linuxTUN.BatchWrite(writeBuffers) + if err != nil { + m.logger.Trace(E.Cause(err, "batch write packet")) + } + writeBuffers = writeBuffers[:0] } } } -func (m *Mixed) processPacket(rawPacket []byte, packet []byte) { - var err error +func (m *Mixed) processPacket(packet []byte) bool { + var ( + writeBack bool + err error + ) switch ipVersion := packet[0] >> 4; ipVersion { case 4: - err = m.processIPv4(rawPacket, packet) + writeBack, err = m.processIPv4(packet) case 6: - err = m.processIPv6(rawPacket, packet) + writeBack, err = m.processIPv6(packet) default: err = E.New("ip: unknown version: ", ipVersion) } if err != nil { m.logger.Trace(err) + return false } + return writeBack } -func (m *Mixed) processIPv4(rawPacket []byte, packet clashtcpip.IPv4Packet) error { +func (m *Mixed) processIPv4(packet clashtcpip.IPv4Packet) (writeBack bool, err error) { + writeBack = true destination := packet.DestinationIP() if destination == m.broadcastAddr || !destination.IsGlobalUnicast() { - return common.Error(m.tun.Write(rawPacket)) + return } switch packet.Protocol() { case clashtcpip.TCP: - return m.processIPv4TCP(rawPacket, packet, packet.Payload()) + err = m.processIPv4TCP(packet, packet.Payload()) case clashtcpip.UDP: + writeBack = false pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - Payload: buffer.MakeWithData(packet), + Payload: buffer.MakeWithData(packet), + IsForwardedPacket: true, }) m.endpoint.InjectInbound(header.IPv4ProtocolNumber, pkt) pkt.DecRef() - return nil + return case clashtcpip.ICMP: - return m.processIPv4ICMP(rawPacket, packet, packet.Payload()) - default: - return common.Error(m.tun.Write(rawPacket)) + err = m.processIPv4ICMP(packet, packet.Payload()) } + return } -func (m *Mixed) processIPv6(rawPacket []byte, packet clashtcpip.IPv6Packet) error { +func (m *Mixed) processIPv6(packet clashtcpip.IPv6Packet) (writeBack bool, err error) { + writeBack = true if !packet.DestinationIP().IsGlobalUnicast() { - return common.Error(m.tun.Write(rawPacket)) + return } switch packet.Protocol() { case clashtcpip.TCP: - return m.processIPv6TCP(rawPacket, packet, packet.Payload()) + err = m.processIPv6TCP(packet, packet.Payload()) case clashtcpip.UDP: + writeBack = false pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - Payload: buffer.MakeWithData(packet), + Payload: buffer.MakeWithData(packet), + IsForwardedPacket: true, }) m.endpoint.InjectInbound(header.IPv6ProtocolNumber, pkt) pkt.DecRef() - return nil case clashtcpip.ICMPv6: - return m.processIPv6ICMP(rawPacket, packet, packet.Payload()) - default: - return common.Error(m.tun.Write(rawPacket)) + err = m.processIPv6ICMP(packet, packet.Payload()) } + return } func (m *Mixed) packetLoop() { @@ -234,7 +256,7 @@ func (m *Mixed) packetLoop() { if packet == nil { break } - bufio.WriteVectorised(m.writer, packet.AsSlices()) + bufio.WriteVectorised(m.tun, packet.AsSlices()) packet.DecRef() } } diff --git a/stack_system.go b/stack_system.go index 5e8240a..73a83ae 100644 --- a/stack_system.go +++ b/stack_system.go @@ -41,7 +41,6 @@ type System struct { udpNat *udpnat.Service[netip.AddrPort] bindInterface bool interfaceFinder control.InterfaceFinder - offload bool } type Session struct { @@ -167,7 +166,12 @@ func (s *System) tunLoop() { } rawPacket := packetBuffer[:frontHeadroom+n] packet := packetBuffer[frontHeadroom+PacketOffset : frontHeadroom+n] - s.processPacket(rawPacket, packet) + if s.processPacket(packet) { + _, err = s.tun.Write(rawPacket) + if err != nil { + s.logger.Trace(E.Cause(err, "write packet")) + } + } } } @@ -181,18 +185,24 @@ func (s *System) wintunLoop(winTun WinTun) { release() continue } - s.processPacket(packet, packet) + if s.processPacket(packet) { + _, err = winTun.Write(packet) + if err != nil { + s.logger.Trace(E.Cause(err, "write packet")) + } + } release() } } -func (m *System) batchLoop(linuxTUN BatchTUN, batchSize int) { - frontHeadroom := m.tun.FrontHeadroom() +func (s *System) batchLoop(linuxTUN BatchTUN, batchSize int) { + frontHeadroom := s.tun.FrontHeadroom() packetBuffers := make([][]byte, batchSize) readBuffers := make([][]byte, batchSize) + writeBuffers := make([][]byte, batchSize) packetSizes := make([]int, batchSize) for i := range packetBuffers { - packetBuffers[i] = make([]byte, m.mtu+frontHeadroom+PacketOffset) + packetBuffers[i] = make([]byte, s.mtu+frontHeadroom+PacketOffset) readBuffers[i] = packetBuffers[i][frontHeadroom:] } for { @@ -201,7 +211,7 @@ func (m *System) batchLoop(linuxTUN BatchTUN, batchSize int) { if E.IsClosed(err) { return } - m.logger.Error(E.Cause(err, "batch read packet")) + s.logger.Error(E.Cause(err, "batch read packet")) } if n == 0 { continue @@ -212,26 +222,39 @@ func (m *System) batchLoop(linuxTUN BatchTUN, batchSize int) { continue } packetBuffer := packetBuffers[i] - rawPacket := packetBuffer[:frontHeadroom+packetSize] packet := packetBuffer[frontHeadroom+PacketOffset : frontHeadroom+packetSize] - m.processPacket(rawPacket, packet) + if s.processPacket(packet) { + writeBuffers = append(writeBuffers, packetBuffer[:frontHeadroom+packetSize]) + } + } + if len(writeBuffers) > 0 { + err = linuxTUN.BatchWrite(writeBuffers) + if err != nil { + s.logger.Trace(E.Cause(err, "batch write packet")) + } + writeBuffers = writeBuffers[:0] } } } -func (s *System) processPacket(rawPacket []byte, packet []byte) { - var err error +func (s *System) processPacket(packet []byte) bool { + var ( + writeBack bool + err error + ) switch ipVersion := packet[0] >> 4; ipVersion { case 4: - err = s.processIPv4(rawPacket, packet) + writeBack, err = s.processIPv4(packet) case 6: - err = s.processIPv6(rawPacket, packet) + writeBack, err = s.processIPv6(packet) default: err = E.New("ip: unknown version: ", ipVersion) } if err != nil { s.logger.Trace(err) + return false } + return writeBack } func (s *System) acceptLoop(listener net.Listener) { @@ -275,44 +298,46 @@ func (s *System) acceptLoop(listener net.Listener) { } } -func (s *System) processIPv4(rawPacket []byte, packet clashtcpip.IPv4Packet) error { +func (s *System) processIPv4(packet clashtcpip.IPv4Packet) (writeBack bool, err error) { + writeBack = true destination := packet.DestinationIP() if destination == s.broadcastAddr || !destination.IsGlobalUnicast() { - return common.Error(s.tun.Write(rawPacket)) + return } switch packet.Protocol() { case clashtcpip.TCP: - return s.processIPv4TCP(rawPacket, packet, packet.Payload()) + err = s.processIPv4TCP(packet, packet.Payload()) case clashtcpip.UDP: - return s.processIPv4UDP(rawPacket, packet, packet.Payload()) + writeBack = false + err = s.processIPv4UDP(packet, packet.Payload()) case clashtcpip.ICMP: - return s.processIPv4ICMP(rawPacket, packet, packet.Payload()) - default: - return common.Error(s.tun.Write(rawPacket)) + err = s.processIPv4ICMP(packet, packet.Payload()) } + return } -func (s *System) processIPv6(rawPacket []byte, packet clashtcpip.IPv6Packet) error { +func (s *System) processIPv6(packet clashtcpip.IPv6Packet) (writeBack bool, err error) { + writeBack = true if !packet.DestinationIP().IsGlobalUnicast() { - return common.Error(s.tun.Write(rawPacket)) + return } switch packet.Protocol() { case clashtcpip.TCP: - return s.processIPv6TCP(rawPacket, packet, packet.Payload()) + err = s.processIPv6TCP(packet, packet.Payload()) case clashtcpip.UDP: - return s.processIPv6UDP(rawPacket, packet, packet.Payload()) + writeBack = false + err = s.processIPv6UDP(packet, packet.Payload()) case clashtcpip.ICMPv6: - return s.processIPv6ICMP(rawPacket, packet, packet.Payload()) - default: - return common.Error(s.tun.Write(rawPacket)) + err = s.processIPv6ICMP(packet, packet.Payload()) } + return } -func (s *System) processIPv4TCP(rawPacket []byte, packet clashtcpip.IPv4Packet, header clashtcpip.TCPPacket) error { +func (s *System) processIPv4TCP(packet clashtcpip.IPv4Packet, header clashtcpip.TCPPacket) error { source := netip.AddrPortFrom(packet.SourceIP(), header.SourcePort()) destination := netip.AddrPortFrom(packet.DestinationIP(), header.DestinationPort()) if !destination.Addr().IsGlobalUnicast() { - return common.Error(s.tun.Write(rawPacket)) + return nil } else if source.Addr() == s.inet4ServerAddress && source.Port() == s.tcpPort { session := s.tcpNat.LookupBack(destination.Port()) if session == nil { @@ -331,14 +356,14 @@ func (s *System) processIPv4TCP(rawPacket []byte, packet clashtcpip.IPv4Packet, } header.ResetChecksum(packet.PseudoSum()) packet.ResetChecksum() - return common.Error(s.tun.Write(rawPacket)) + return nil } -func (s *System) processIPv6TCP(rawPacket []byte, packet clashtcpip.IPv6Packet, header clashtcpip.TCPPacket) error { +func (s *System) processIPv6TCP(packet clashtcpip.IPv6Packet, header clashtcpip.TCPPacket) error { source := netip.AddrPortFrom(packet.SourceIP(), header.SourcePort()) destination := netip.AddrPortFrom(packet.DestinationIP(), header.DestinationPort()) if !destination.Addr().IsGlobalUnicast() { - return common.Error(s.tun.Write(rawPacket)) + return nil } else if source.Addr() == s.inet6ServerAddress && source.Port() == s.tcpPort6 { session := s.tcpNat.LookupBack(destination.Port()) if session == nil { @@ -357,10 +382,10 @@ func (s *System) processIPv6TCP(rawPacket []byte, packet clashtcpip.IPv6Packet, } header.ResetChecksum(packet.PseudoSum()) packet.ResetChecksum() - return common.Error(s.tun.Write(rawPacket)) + return nil } -func (s *System) processIPv4UDP(rawPacket []byte, packet clashtcpip.IPv4Packet, header clashtcpip.UDPPacket) error { +func (s *System) processIPv4UDP(packet clashtcpip.IPv4Packet, header clashtcpip.UDPPacket) error { if packet.Flags()&clashtcpip.FlagMoreFragment != 0 { return E.New("ipv4: fragment dropped") } @@ -373,7 +398,7 @@ func (s *System) processIPv4UDP(rawPacket []byte, packet clashtcpip.IPv4Packet, source := netip.AddrPortFrom(packet.SourceIP(), header.SourcePort()) destination := netip.AddrPortFrom(packet.DestinationIP(), header.DestinationPort()) if !destination.Addr().IsGlobalUnicast() { - return common.Error(s.tun.Write(rawPacket)) + return nil } data := buf.As(header.Payload()) if data.Len() == 0 { @@ -392,14 +417,14 @@ func (s *System) processIPv4UDP(rawPacket []byte, packet clashtcpip.IPv4Packet, return nil } -func (s *System) processIPv6UDP(rawPacket []byte, packet clashtcpip.IPv6Packet, header clashtcpip.UDPPacket) error { +func (s *System) processIPv6UDP(packet clashtcpip.IPv6Packet, header clashtcpip.UDPPacket) error { if !header.Valid() { return E.New("ipv6: udp: invalid packet") } source := netip.AddrPortFrom(packet.SourceIP(), header.SourcePort()) destination := netip.AddrPortFrom(packet.DestinationIP(), header.DestinationPort()) if !destination.Addr().IsGlobalUnicast() { - return common.Error(s.tun.Write(rawPacket)) + return nil } data := buf.As(header.Payload()) if data.Len() == 0 { @@ -418,7 +443,7 @@ func (s *System) processIPv6UDP(rawPacket []byte, packet clashtcpip.IPv6Packet, return nil } -func (s *System) processIPv4ICMP(rawPacket []byte, packet clashtcpip.IPv4Packet, header clashtcpip.ICMPPacket) error { +func (s *System) processIPv4ICMP(packet clashtcpip.IPv4Packet, header clashtcpip.ICMPPacket) error { if header.Type() != clashtcpip.ICMPTypePingRequest || header.Code() != 0 { return nil } @@ -428,10 +453,10 @@ func (s *System) processIPv4ICMP(rawPacket []byte, packet clashtcpip.IPv4Packet, packet.SetDestinationIP(sourceAddress) header.ResetChecksum() packet.ResetChecksum() - return common.Error(s.tun.Write(rawPacket)) + return nil } -func (s *System) processIPv6ICMP(rawPacket []byte, packet clashtcpip.IPv6Packet, header clashtcpip.ICMPv6Packet) error { +func (s *System) processIPv6ICMP(packet clashtcpip.IPv6Packet, header clashtcpip.ICMPv6Packet) error { if header.Type() != clashtcpip.ICMPv6EchoRequest || header.Code() != 0 { return nil } @@ -441,7 +466,7 @@ func (s *System) processIPv6ICMP(rawPacket []byte, packet clashtcpip.IPv6Packet, packet.SetDestinationIP(sourceAddress) header.ResetChecksum(packet.PseudoSum()) packet.ResetChecksum() - return common.Error(s.tun.Write(rawPacket)) + return nil } type systemUDPPacketWriter4 struct { diff --git a/tun.go b/tun.go index e6de7f1..abada3a 100644 --- a/tun.go +++ b/tun.go @@ -23,6 +23,7 @@ type Handler interface { type Tun interface { io.ReadWriter + N.VectorisedWriter N.FrontHeadroom Close() error } @@ -33,8 +34,10 @@ type WinTun interface { } type BatchTUN interface { + Tun BatchSize() int BatchRead(buffers [][]byte, readN []int) (n int, err error) + BatchWrite(buffers [][]byte) error } type Options struct { diff --git a/tun_darwin_gvisor.go b/tun_darwin_gvisor.go index b8431fb..a1f13ae 100644 --- a/tun_darwin_gvisor.go +++ b/tun_darwin_gvisor.go @@ -51,13 +51,13 @@ func (e *DarwinEndpoint) Attach(dispatcher stack.NetworkDispatcher) { } func (e *DarwinEndpoint) dispatchLoop() { - packetBuffer := make([]byte, e.tun.mtu+4) + packetBuffer := make([]byte, e.tun.mtu+PacketOffset) for { n, err := e.tun.tunFile.Read(packetBuffer) if err != nil { break } - packet := packetBuffer[4:n] + packet := packetBuffer[PacketOffset:n] var networkProtocol tcpip.NetworkProtocolNumber switch header.IPVersion(packet) { case header.IPv4Version: diff --git a/tun_linux.go b/tun_linux.go index 25801cf..ac237c1 100644 --- a/tun_linux.go +++ b/tun_linux.go @@ -1,19 +1,22 @@ package tun import ( - "io" "math/rand" "net" "net/netip" "os" "os/exec" "runtime" + "sync" "syscall" "unsafe" "github.com/sagernet/netlink" "github.com/sagernet/sing/common" + "github.com/sagernet/sing/common/buf" + "github.com/sagernet/sing/common/bufio" E "github.com/sagernet/sing/common/exceptions" + N "github.com/sagernet/sing/common/network" "github.com/sagernet/sing/common/rw" "github.com/sagernet/sing/common/shell" "github.com/sagernet/sing/common/x/list" @@ -26,16 +29,19 @@ var _ BatchTUN = (*NativeTun)(nil) type NativeTun struct { tunFd int tunFile *os.File + tunWriter N.VectorisedWriter interfaceCallback *list.Element[DefaultInterfaceUpdateCallback] options Options ruleIndex6 []int gsoEnabled bool gsoBuffer []byte + tcpGROAccess sync.Mutex tcp4GROTable *tcpGROTable tcp6GROTable *tcpGROTable } func New(options Options) (Tun, error) { + var nativeTun *NativeTun if options.FileDescriptor == 0 { tunFd, err := open(options.Name, options.GSO) if err != nil { @@ -45,26 +51,28 @@ func New(options Options) (Tun, error) { if err != nil { return nil, E.Errors(err, unix.Close(tunFd)) } - nativeTun := &NativeTun{ + nativeTun = &NativeTun{ tunFd: tunFd, tunFile: os.NewFile(uintptr(tunFd), "tun"), options: options, } - runtime.SetFinalizer(nativeTun.tunFile, nil) err = nativeTun.configure(tunLink) if err != nil { return nil, E.Errors(err, unix.Close(tunFd)) } - return nativeTun, nil } else { - nativeTun := &NativeTun{ + nativeTun = &NativeTun{ tunFd: options.FileDescriptor, tunFile: os.NewFile(uintptr(options.FileDescriptor), "tun"), options: options, } - runtime.SetFinalizer(nativeTun.tunFile, nil) - return nativeTun, nil } + var ok bool + nativeTun.tunWriter, ok = bufio.CreateVectorisedWriter(nativeTun.tunFile) + if !ok { + panic("create vectorised writer") + } + return nativeTun, nil } func (t *NativeTun) FrontHeadroom() int { @@ -74,14 +82,6 @@ func (t *NativeTun) FrontHeadroom() int { return 0 } -func (t *NativeTun) UpstreamWriter() io.Writer { - return t.tunFile -} - -func (t *NativeTun) WriterReplaceable() bool { - return !t.gsoEnabled -} - func (t *NativeTun) Read(p []byte) (n int, err error) { if t.gsoEnabled { n, err = t.tunFile.Read(t.gsoBuffer) @@ -105,27 +105,39 @@ func (t *NativeTun) Read(p []byte) (n int, err error) { func (t *NativeTun) Write(p []byte) (n int, err error) { if t.gsoEnabled { - defer func() { - t.tcp4GROTable.reset() - t.tcp6GROTable.reset() - }() - var toWrite []int - err = handleGRO([][]byte{p}, virtioNetHdrLen, t.tcp4GROTable, t.tcp6GROTable, &toWrite) + err = t.BatchWrite([][]byte{p}) if err != nil { return } - if len(toWrite) == 0 { - return - } + n = len(p) + return } return t.tunFile.Write(p) } +func (t *NativeTun) WriteVectorised(buffers []*buf.Buffer) error { + if t.gsoEnabled { + n := buf.LenMulti(buffers) + buffer := buf.NewSize(virtioNetHdrLen + n) + buffer.Truncate(virtioNetHdrLen) + buf.CopyMulti(buffer.Extend(n), buffers) + _, err := t.tunFile.Write(buffer.Bytes()) + buffer.Release() + return err + } else { + return t.tunWriter.WriteVectorised(buffers) + } +} + func (t *NativeTun) BatchSize() int { if !t.gsoEnabled { return 1 } - return idealBatchSize + batchSize := int(t.options.GSOMaxSize/t.options.MTU) * 2 + if batchSize > idealBatchSize { + batchSize = idealBatchSize + } + return batchSize } func (t *NativeTun) BatchRead(buffers [][]byte, readN []int) (n int, err error) { @@ -138,12 +150,34 @@ func (t *NativeTun) BatchRead(buffers [][]byte, readN []int) (n int, err error) if err != nil { return } + return } else { return 0, os.ErrInvalid } } +func (t *NativeTun) BatchWrite(buffers [][]byte) error { + t.tcpGROAccess.Lock() + defer func() { + t.tcp4GROTable.reset() + t.tcp6GROTable.reset() + t.tcpGROAccess.Unlock() + }() + var toWrite []int + err := handleGRO(buffers, virtioNetHdrLen, t.tcp4GROTable, t.tcp6GROTable, &toWrite) + if err != nil { + return err + } + for _, bufferIndex := range toWrite { + _, err = t.tunFile.Write(buffers[bufferIndex]) + if err != nil { + return err + } + } + return nil +} + var controlPath string func init() {