diff --git a/stack_mixed.go b/stack_mixed.go index 17aa3d5..6d729e0 100644 --- a/stack_mixed.go +++ b/stack_mixed.go @@ -13,7 +13,6 @@ import ( "github.com/sagernet/gvisor/pkg/tcpip/transport/udp" "github.com/sagernet/gvisor/pkg/waiter" "github.com/sagernet/sing-tun/internal/clashtcpip" - "github.com/sagernet/sing/common" "github.com/sagernet/sing/common/bufio" "github.com/sagernet/sing/common/canceler" E "github.com/sagernet/sing/common/exceptions" @@ -95,7 +94,6 @@ func (m *Mixed) tunLoop() { m.wintunLoop(winTun) return } - if batchTUN, isBatchTUN := m.tun.(BatchTUN); isBatchTUN { batchSize := batchTUN.BatchSize() if batchSize > 1 { @@ -118,7 +116,12 @@ func (m *Mixed) tunLoop() { } rawPacket := packetBuffer[:frontHeadroom+n] packet := packetBuffer[frontHeadroom+PacketOffset : frontHeadroom+n] - m.processPacket(rawPacket, packet) + if m.processPacket(packet) { + _, err = m.tun.Write(rawPacket) + if err != nil { + m.logger.Trace(E.Cause(err, "write packet")) + } + } } } @@ -132,7 +135,12 @@ func (m *Mixed) wintunLoop(winTun WinTun) { release() continue } - m.processPacket(packet, packet) + if m.processPacket(packet) { + _, err = winTun.Write(packet) + if err != nil { + m.logger.Trace(E.Cause(err, "write packet")) + } + } release() } } @@ -141,6 +149,7 @@ func (m *Mixed) batchLoop(linuxTUN BatchTUN, batchSize int) { frontHeadroom := m.tun.FrontHeadroom() packetBuffers := make([][]byte, batchSize) readBuffers := make([][]byte, batchSize) + writeBuffers := make([][]byte, batchSize) packetSizes := make([]int, batchSize) for i := range packetBuffers { packetBuffers[i] = make([]byte, m.mtu+frontHeadroom+PacketOffset) @@ -163,69 +172,83 @@ func (m *Mixed) batchLoop(linuxTUN BatchTUN, batchSize int) { continue } packetBuffer := packetBuffers[i] - rawPacket := packetBuffer[:frontHeadroom+packetSize] packet := packetBuffer[frontHeadroom+PacketOffset : frontHeadroom+packetSize] - m.processPacket(rawPacket, packet) + if m.processPacket(packet) { + writeBuffers = append(writeBuffers, packetBuffer[:frontHeadroom+packetSize]) + } + } + if len(writeBuffers) > 0 { + err = linuxTUN.BatchWrite(writeBuffers) + if err != nil { + m.logger.Trace(E.Cause(err, "batch write packet")) + } + writeBuffers = writeBuffers[:0] } } } -func (m *Mixed) processPacket(rawPacket []byte, packet []byte) { - var err error +func (m *Mixed) processPacket(packet []byte) bool { + var ( + writeBack bool + err error + ) switch ipVersion := packet[0] >> 4; ipVersion { case 4: - err = m.processIPv4(rawPacket, packet) + writeBack, err = m.processIPv4(packet) case 6: - err = m.processIPv6(rawPacket, packet) + writeBack, err = m.processIPv6(packet) default: err = E.New("ip: unknown version: ", ipVersion) } if err != nil { m.logger.Trace(err) + return false } + return writeBack } -func (m *Mixed) processIPv4(rawPacket []byte, packet clashtcpip.IPv4Packet) error { +func (m *Mixed) processIPv4(packet clashtcpip.IPv4Packet) (writeBack bool, err error) { + writeBack = true destination := packet.DestinationIP() if destination == m.broadcastAddr || !destination.IsGlobalUnicast() { - return common.Error(m.tun.Write(rawPacket)) + return } switch packet.Protocol() { case clashtcpip.TCP: - return m.processIPv4TCP(rawPacket, packet, packet.Payload()) + err = m.processIPv4TCP(packet, packet.Payload()) case clashtcpip.UDP: + writeBack = false pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ Payload: buffer.MakeWithData(packet), }) m.endpoint.InjectInbound(header.IPv4ProtocolNumber, pkt) pkt.DecRef() - return nil + return case clashtcpip.ICMP: - return m.processIPv4ICMP(rawPacket, packet, packet.Payload()) - default: - return common.Error(m.tun.Write(rawPacket)) + err = m.processIPv4ICMP(packet, packet.Payload()) } + return } -func (m *Mixed) processIPv6(rawPacket []byte, packet clashtcpip.IPv6Packet) error { +func (m *Mixed) processIPv6(packet clashtcpip.IPv6Packet) (writeBack bool, err error) { + writeBack = true if !packet.DestinationIP().IsGlobalUnicast() { - return common.Error(m.tun.Write(rawPacket)) + return } switch packet.Protocol() { case clashtcpip.TCP: - return m.processIPv6TCP(rawPacket, packet, packet.Payload()) + err = m.processIPv6TCP(packet, packet.Payload()) case clashtcpip.UDP: + writeBack = false pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ Payload: buffer.MakeWithData(packet), }) m.endpoint.InjectInbound(header.IPv6ProtocolNumber, pkt) pkt.DecRef() - return nil case clashtcpip.ICMPv6: - return m.processIPv6ICMP(rawPacket, packet, packet.Payload()) - default: - return common.Error(m.tun.Write(rawPacket)) + err = m.processIPv6ICMP(packet, packet.Payload()) } + return } func (m *Mixed) packetLoop() { diff --git a/stack_system.go b/stack_system.go index 5e8240a..c614331 100644 --- a/stack_system.go +++ b/stack_system.go @@ -2,6 +2,7 @@ package tun import ( "context" + F "github.com/sagernet/sing/common/format" "net" "net/netip" "syscall" @@ -148,6 +149,7 @@ func (s *System) tunLoop() { if batchTUN, isBatchTUN := s.tun.(BatchTUN); isBatchTUN { batchSize := batchTUN.BatchSize() if batchSize > 1 { + println(F.ToString("batch size: ", batchSize)) s.batchLoop(batchTUN, batchSize) return } @@ -167,7 +169,12 @@ func (s *System) tunLoop() { } rawPacket := packetBuffer[:frontHeadroom+n] packet := packetBuffer[frontHeadroom+PacketOffset : frontHeadroom+n] - s.processPacket(rawPacket, packet) + if s.processPacket(packet) { + _, err = s.tun.Write(rawPacket) + if err != nil { + s.logger.Trace(E.Cause(err, "write packet")) + } + } } } @@ -181,18 +188,24 @@ func (s *System) wintunLoop(winTun WinTun) { release() continue } - s.processPacket(packet, packet) + if s.processPacket(packet) { + _, err = winTun.Write(packet) + if err != nil { + s.logger.Trace(E.Cause(err, "write packet")) + } + } release() } } -func (m *System) batchLoop(linuxTUN BatchTUN, batchSize int) { - frontHeadroom := m.tun.FrontHeadroom() +func (s *System) batchLoop(linuxTUN BatchTUN, batchSize int) { + frontHeadroom := s.tun.FrontHeadroom() packetBuffers := make([][]byte, batchSize) readBuffers := make([][]byte, batchSize) + writeBuffers := make([][]byte, batchSize) packetSizes := make([]int, batchSize) for i := range packetBuffers { - packetBuffers[i] = make([]byte, m.mtu+frontHeadroom+PacketOffset) + packetBuffers[i] = make([]byte, s.mtu+frontHeadroom+PacketOffset) readBuffers[i] = packetBuffers[i][frontHeadroom:] } for { @@ -201,7 +214,7 @@ func (m *System) batchLoop(linuxTUN BatchTUN, batchSize int) { if E.IsClosed(err) { return } - m.logger.Error(E.Cause(err, "batch read packet")) + s.logger.Error(E.Cause(err, "batch read packet")) } if n == 0 { continue @@ -212,26 +225,39 @@ func (m *System) batchLoop(linuxTUN BatchTUN, batchSize int) { continue } packetBuffer := packetBuffers[i] - rawPacket := packetBuffer[:frontHeadroom+packetSize] packet := packetBuffer[frontHeadroom+PacketOffset : frontHeadroom+packetSize] - m.processPacket(rawPacket, packet) + if s.processPacket(packet) { + writeBuffers = append(writeBuffers, packetBuffer[:frontHeadroom+packetSize]) + } + } + if len(writeBuffers) > 0 { + err = linuxTUN.BatchWrite(writeBuffers) + if err != nil { + s.logger.Trace(E.Cause(err, "batch write packet")) + } + writeBuffers = writeBuffers[:0] } } } -func (s *System) processPacket(rawPacket []byte, packet []byte) { - var err error +func (s *System) processPacket(packet []byte) bool { + var ( + writeBack bool + err error + ) switch ipVersion := packet[0] >> 4; ipVersion { case 4: - err = s.processIPv4(rawPacket, packet) + writeBack, err = s.processIPv4(packet) case 6: - err = s.processIPv6(rawPacket, packet) + writeBack, err = s.processIPv6(packet) default: err = E.New("ip: unknown version: ", ipVersion) } if err != nil { s.logger.Trace(err) + return false } + return writeBack } func (s *System) acceptLoop(listener net.Listener) { @@ -275,44 +301,46 @@ func (s *System) acceptLoop(listener net.Listener) { } } -func (s *System) processIPv4(rawPacket []byte, packet clashtcpip.IPv4Packet) error { +func (s *System) processIPv4(packet clashtcpip.IPv4Packet) (writeBack bool, err error) { + writeBack = true destination := packet.DestinationIP() if destination == s.broadcastAddr || !destination.IsGlobalUnicast() { - return common.Error(s.tun.Write(rawPacket)) + return } switch packet.Protocol() { case clashtcpip.TCP: - return s.processIPv4TCP(rawPacket, packet, packet.Payload()) + err = s.processIPv4TCP(packet, packet.Payload()) case clashtcpip.UDP: - return s.processIPv4UDP(rawPacket, packet, packet.Payload()) + writeBack = false + err = s.processIPv4UDP(packet, packet.Payload()) case clashtcpip.ICMP: - return s.processIPv4ICMP(rawPacket, packet, packet.Payload()) - default: - return common.Error(s.tun.Write(rawPacket)) + err = s.processIPv4ICMP(packet, packet.Payload()) } + return } -func (s *System) processIPv6(rawPacket []byte, packet clashtcpip.IPv6Packet) error { +func (s *System) processIPv6(packet clashtcpip.IPv6Packet) (writeBack bool, err error) { + writeBack = true if !packet.DestinationIP().IsGlobalUnicast() { - return common.Error(s.tun.Write(rawPacket)) + return } switch packet.Protocol() { case clashtcpip.TCP: - return s.processIPv6TCP(rawPacket, packet, packet.Payload()) + err = s.processIPv6TCP(packet, packet.Payload()) case clashtcpip.UDP: - return s.processIPv6UDP(rawPacket, packet, packet.Payload()) + writeBack = false + err = s.processIPv6UDP(packet, packet.Payload()) case clashtcpip.ICMPv6: - return s.processIPv6ICMP(rawPacket, packet, packet.Payload()) - default: - return common.Error(s.tun.Write(rawPacket)) + err = s.processIPv6ICMP(packet, packet.Payload()) } + return } -func (s *System) processIPv4TCP(rawPacket []byte, packet clashtcpip.IPv4Packet, header clashtcpip.TCPPacket) error { +func (s *System) processIPv4TCP(packet clashtcpip.IPv4Packet, header clashtcpip.TCPPacket) error { source := netip.AddrPortFrom(packet.SourceIP(), header.SourcePort()) destination := netip.AddrPortFrom(packet.DestinationIP(), header.DestinationPort()) if !destination.Addr().IsGlobalUnicast() { - return common.Error(s.tun.Write(rawPacket)) + return nil } else if source.Addr() == s.inet4ServerAddress && source.Port() == s.tcpPort { session := s.tcpNat.LookupBack(destination.Port()) if session == nil { @@ -331,14 +359,14 @@ func (s *System) processIPv4TCP(rawPacket []byte, packet clashtcpip.IPv4Packet, } header.ResetChecksum(packet.PseudoSum()) packet.ResetChecksum() - return common.Error(s.tun.Write(rawPacket)) + return nil } -func (s *System) processIPv6TCP(rawPacket []byte, packet clashtcpip.IPv6Packet, header clashtcpip.TCPPacket) error { +func (s *System) processIPv6TCP(packet clashtcpip.IPv6Packet, header clashtcpip.TCPPacket) error { source := netip.AddrPortFrom(packet.SourceIP(), header.SourcePort()) destination := netip.AddrPortFrom(packet.DestinationIP(), header.DestinationPort()) if !destination.Addr().IsGlobalUnicast() { - return common.Error(s.tun.Write(rawPacket)) + return nil } else if source.Addr() == s.inet6ServerAddress && source.Port() == s.tcpPort6 { session := s.tcpNat.LookupBack(destination.Port()) if session == nil { @@ -357,10 +385,10 @@ func (s *System) processIPv6TCP(rawPacket []byte, packet clashtcpip.IPv6Packet, } header.ResetChecksum(packet.PseudoSum()) packet.ResetChecksum() - return common.Error(s.tun.Write(rawPacket)) + return nil } -func (s *System) processIPv4UDP(rawPacket []byte, packet clashtcpip.IPv4Packet, header clashtcpip.UDPPacket) error { +func (s *System) processIPv4UDP(packet clashtcpip.IPv4Packet, header clashtcpip.UDPPacket) error { if packet.Flags()&clashtcpip.FlagMoreFragment != 0 { return E.New("ipv4: fragment dropped") } @@ -373,7 +401,7 @@ func (s *System) processIPv4UDP(rawPacket []byte, packet clashtcpip.IPv4Packet, source := netip.AddrPortFrom(packet.SourceIP(), header.SourcePort()) destination := netip.AddrPortFrom(packet.DestinationIP(), header.DestinationPort()) if !destination.Addr().IsGlobalUnicast() { - return common.Error(s.tun.Write(rawPacket)) + return nil } data := buf.As(header.Payload()) if data.Len() == 0 { @@ -392,14 +420,14 @@ func (s *System) processIPv4UDP(rawPacket []byte, packet clashtcpip.IPv4Packet, return nil } -func (s *System) processIPv6UDP(rawPacket []byte, packet clashtcpip.IPv6Packet, header clashtcpip.UDPPacket) error { +func (s *System) processIPv6UDP(packet clashtcpip.IPv6Packet, header clashtcpip.UDPPacket) error { if !header.Valid() { return E.New("ipv6: udp: invalid packet") } source := netip.AddrPortFrom(packet.SourceIP(), header.SourcePort()) destination := netip.AddrPortFrom(packet.DestinationIP(), header.DestinationPort()) if !destination.Addr().IsGlobalUnicast() { - return common.Error(s.tun.Write(rawPacket)) + return nil } data := buf.As(header.Payload()) if data.Len() == 0 { @@ -418,7 +446,7 @@ func (s *System) processIPv6UDP(rawPacket []byte, packet clashtcpip.IPv6Packet, return nil } -func (s *System) processIPv4ICMP(rawPacket []byte, packet clashtcpip.IPv4Packet, header clashtcpip.ICMPPacket) error { +func (s *System) processIPv4ICMP(packet clashtcpip.IPv4Packet, header clashtcpip.ICMPPacket) error { if header.Type() != clashtcpip.ICMPTypePingRequest || header.Code() != 0 { return nil } @@ -428,10 +456,10 @@ func (s *System) processIPv4ICMP(rawPacket []byte, packet clashtcpip.IPv4Packet, packet.SetDestinationIP(sourceAddress) header.ResetChecksum() packet.ResetChecksum() - return common.Error(s.tun.Write(rawPacket)) + return nil } -func (s *System) processIPv6ICMP(rawPacket []byte, packet clashtcpip.IPv6Packet, header clashtcpip.ICMPv6Packet) error { +func (s *System) processIPv6ICMP(packet clashtcpip.IPv6Packet, header clashtcpip.ICMPv6Packet) error { if header.Type() != clashtcpip.ICMPv6EchoRequest || header.Code() != 0 { return nil } @@ -441,7 +469,7 @@ func (s *System) processIPv6ICMP(rawPacket []byte, packet clashtcpip.IPv6Packet, packet.SetDestinationIP(sourceAddress) header.ResetChecksum(packet.PseudoSum()) packet.ResetChecksum() - return common.Error(s.tun.Write(rawPacket)) + return nil } type systemUDPPacketWriter4 struct { diff --git a/tun.go b/tun.go index e6de7f1..abada3a 100644 --- a/tun.go +++ b/tun.go @@ -23,6 +23,7 @@ type Handler interface { type Tun interface { io.ReadWriter + N.VectorisedWriter N.FrontHeadroom Close() error } @@ -33,8 +34,10 @@ type WinTun interface { } type BatchTUN interface { + Tun BatchSize() int BatchRead(buffers [][]byte, readN []int) (n int, err error) + BatchWrite(buffers [][]byte) error } type Options struct { diff --git a/tun_linux.go b/tun_linux.go index 25801cf..37b3f82 100644 --- a/tun_linux.go +++ b/tun_linux.go @@ -1,7 +1,7 @@ package tun import ( - "io" + F "github.com/sagernet/sing/common/format" "math/rand" "net" "net/netip" @@ -13,7 +13,10 @@ import ( "github.com/sagernet/netlink" "github.com/sagernet/sing/common" + "github.com/sagernet/sing/common/buf" + "github.com/sagernet/sing/common/bufio" E "github.com/sagernet/sing/common/exceptions" + N "github.com/sagernet/sing/common/network" "github.com/sagernet/sing/common/rw" "github.com/sagernet/sing/common/shell" "github.com/sagernet/sing/common/x/list" @@ -26,6 +29,7 @@ var _ BatchTUN = (*NativeTun)(nil) type NativeTun struct { tunFd int tunFile *os.File + tunWriter N.VectorisedWriter interfaceCallback *list.Element[DefaultInterfaceUpdateCallback] options Options ruleIndex6 []int @@ -36,6 +40,7 @@ type NativeTun struct { } func New(options Options) (Tun, error) { + var nativeTun *NativeTun if options.FileDescriptor == 0 { tunFd, err := open(options.Name, options.GSO) if err != nil { @@ -45,26 +50,28 @@ func New(options Options) (Tun, error) { if err != nil { return nil, E.Errors(err, unix.Close(tunFd)) } - nativeTun := &NativeTun{ + nativeTun = &NativeTun{ tunFd: tunFd, tunFile: os.NewFile(uintptr(tunFd), "tun"), options: options, } - runtime.SetFinalizer(nativeTun.tunFile, nil) err = nativeTun.configure(tunLink) if err != nil { return nil, E.Errors(err, unix.Close(tunFd)) } - return nativeTun, nil } else { - nativeTun := &NativeTun{ + nativeTun = &NativeTun{ tunFd: options.FileDescriptor, tunFile: os.NewFile(uintptr(options.FileDescriptor), "tun"), options: options, } - runtime.SetFinalizer(nativeTun.tunFile, nil) - return nativeTun, nil } + var ok bool + nativeTun.tunWriter, ok = bufio.CreateVectorisedWriter(nativeTun.tunFile) + if !ok { + panic("create vectorised writer") + } + return nativeTun, nil } func (t *NativeTun) FrontHeadroom() int { @@ -74,14 +81,6 @@ func (t *NativeTun) FrontHeadroom() int { return 0 } -func (t *NativeTun) UpstreamWriter() io.Writer { - return t.tunFile -} - -func (t *NativeTun) WriterReplaceable() bool { - return !t.gsoEnabled -} - func (t *NativeTun) Read(p []byte) (n int, err error) { if t.gsoEnabled { n, err = t.tunFile.Read(t.gsoBuffer) @@ -105,27 +104,35 @@ func (t *NativeTun) Read(p []byte) (n int, err error) { func (t *NativeTun) Write(p []byte) (n int, err error) { if t.gsoEnabled { - defer func() { - t.tcp4GROTable.reset() - t.tcp6GROTable.reset() - }() - var toWrite []int - err = handleGRO([][]byte{p}, virtioNetHdrLen, t.tcp4GROTable, t.tcp6GROTable, &toWrite) + err = t.BatchWrite([][]byte{p}) if err != nil { return } - if len(toWrite) == 0 { - return - } + n = len(p) + return } return t.tunFile.Write(p) } +func (t *NativeTun) WriteVectorised(buffers []*buf.Buffer) error { + if t.gsoEnabled { + n := buf.LenMulti(buffers) + buffer := buf.NewSize(virtioNetHdrLen + n) + buffer.Truncate(virtioNetHdrLen) + buf.CopyMulti(buffer.Extend(n), buffers) + _, err := t.tunFile.Write(buffer.Bytes()) + buffer.Release() + return err + } else { + return t.WriteVectorised(buffers) + } +} + func (t *NativeTun) BatchSize() int { if !t.gsoEnabled { return 1 } - return idealBatchSize + return idealBatchSize //int(t.options.GSOMaxSize/t.options.MTU) + 1 } func (t *NativeTun) BatchRead(buffers [][]byte, readN []int) (n int, err error) { @@ -138,12 +145,35 @@ func (t *NativeTun) BatchRead(buffers [][]byte, readN []int) (n int, err error) if err != nil { return } + if n > 0 { + println(F.ToString("batch read ", n)) + } return } else { return 0, os.ErrInvalid } } +func (t *NativeTun) BatchWrite(buffers [][]byte) error { + defer func() { + t.tcp4GROTable.reset() + t.tcp6GROTable.reset() + }() + var toWrite []int + err := handleGRO(buffers, virtioNetHdrLen, t.tcp4GROTable, t.tcp6GROTable, &toWrite) + if err != nil { + return err + } + println(F.ToString("batch write: ", len(buffers)), " to ", len(toWrite)) + for _, bufferIndex := range toWrite { + _, err = t.tunFile.Write(buffers[bufferIndex]) + if err != nil { + return err + } + } + return nil +} + var controlPath string func init() { diff --git a/tun_linux_offload.go b/tun_linux_offload.go index 8d3e4fb..97e2fa6 100644 --- a/tun_linux_offload.go +++ b/tun_linux_offload.go @@ -21,6 +21,8 @@ import ( "golang.org/x/sys/unix" ) +var ErrTooManySegments = E.New("too many segments") + const ( tcpFlagsOffset = 13 idealBatchSize = 128 diff --git a/tun_linux_offload_errors.go b/tun_linux_offload_errors.go deleted file mode 100644 index 8e5db90..0000000 --- a/tun_linux_offload_errors.go +++ /dev/null @@ -1,5 +0,0 @@ -package tun - -import E "github.com/sagernet/sing/common/exceptions" - -var ErrTooManySegments = E.New("too many segments")