Skip to content

Commit

Permalink
Update BatchTUN API for WireGuard
Browse files Browse the repository at this point in the history
  • Loading branch information
nekohasekai committed Dec 15, 2023
1 parent 0e13875 commit 3195f6f
Show file tree
Hide file tree
Showing 5 changed files with 31 additions and 35 deletions.
10 changes: 4 additions & 6 deletions stack_mixed.go
Original file line number Diff line number Diff line change
Expand Up @@ -145,15 +145,13 @@ func (m *Mixed) wintunLoop(winTun WinTun) {
func (m *Mixed) batchLoop(linuxTUN BatchTUN, batchSize int) {
frontHeadroom := m.tun.FrontHeadroom()
packetBuffers := make([][]byte, batchSize)
readBuffers := make([][]byte, batchSize)
writeBuffers := make([][]byte, batchSize)
packetSizes := make([]int, batchSize)
for i := range packetBuffers {
packetBuffers[i] = make([]byte, m.mtu+frontHeadroom+PacketOffset)
readBuffers[i] = packetBuffers[i][frontHeadroom:]
packetBuffers[i] = make([]byte, m.mtu+frontHeadroom)
}
for {
n, err := linuxTUN.BatchRead(readBuffers, packetSizes)
n, err := linuxTUN.BatchRead(packetBuffers, frontHeadroom, packetSizes)
if err != nil {
if E.IsClosed(err) {
return
Expand All @@ -169,13 +167,13 @@ func (m *Mixed) batchLoop(linuxTUN BatchTUN, batchSize int) {
continue
}
packetBuffer := packetBuffers[i]
packet := packetBuffer[frontHeadroom+PacketOffset : frontHeadroom+packetSize]
packet := packetBuffer[frontHeadroom : frontHeadroom+packetSize]
if m.processPacket(packet) {
writeBuffers = append(writeBuffers, packetBuffer[:frontHeadroom+packetSize])
}
}
if len(writeBuffers) > 0 {
err = linuxTUN.BatchWrite(writeBuffers)
err = linuxTUN.BatchWrite(writeBuffers, frontHeadroom)
if err != nil {
m.logger.Trace(E.Cause(err, "batch write packet"))
}
Expand Down
10 changes: 4 additions & 6 deletions stack_system.go
Original file line number Diff line number Diff line change
Expand Up @@ -198,15 +198,13 @@ func (s *System) wintunLoop(winTun WinTun) {
func (s *System) batchLoop(linuxTUN BatchTUN, batchSize int) {
frontHeadroom := s.tun.FrontHeadroom()
packetBuffers := make([][]byte, batchSize)
readBuffers := make([][]byte, batchSize)
writeBuffers := make([][]byte, batchSize)
packetSizes := make([]int, batchSize)
for i := range packetBuffers {
packetBuffers[i] = make([]byte, s.mtu+frontHeadroom+PacketOffset)
readBuffers[i] = packetBuffers[i][frontHeadroom:]
packetBuffers[i] = make([]byte, s.mtu+frontHeadroom)
}
for {
n, err := linuxTUN.BatchRead(readBuffers, packetSizes)
n, err := linuxTUN.BatchRead(packetBuffers, frontHeadroom, packetSizes)
if err != nil {
if E.IsClosed(err) {
return
Expand All @@ -222,13 +220,13 @@ func (s *System) batchLoop(linuxTUN BatchTUN, batchSize int) {
continue
}
packetBuffer := packetBuffers[i]
packet := packetBuffer[frontHeadroom+PacketOffset : frontHeadroom+packetSize]
packet := packetBuffer[frontHeadroom : frontHeadroom+packetSize]
if s.processPacket(packet) {
writeBuffers = append(writeBuffers, packetBuffer[:frontHeadroom+packetSize])
}
}
if len(writeBuffers) > 0 {
err = linuxTUN.BatchWrite(writeBuffers)
err = linuxTUN.BatchWrite(writeBuffers, frontHeadroom)
if err != nil {
s.logger.Trace(E.Cause(err, "batch write packet"))
}
Expand Down
4 changes: 2 additions & 2 deletions tun.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,8 @@ type WinTun interface {
type BatchTUN interface {
Tun
BatchSize() int
BatchRead(buffers [][]byte, readN []int) (n int, err error)
BatchWrite(buffers [][]byte) error
BatchRead(buffers [][]byte, offset int, readN []int) (n int, err error)
BatchWrite(buffers [][]byte, offset int) error
}

type Options struct {
Expand Down
34 changes: 15 additions & 19 deletions tun_linux.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ type NativeTun struct {
ruleIndex6 []int
gsoEnabled bool
gsoBuffer []byte
gsoToWrite []int
gsoReadAccess sync.Mutex
tcpGROAccess sync.Mutex
tcp4GROTable *tcpGROTable
tcp6GROTable *tcpGROTable
Expand Down Expand Up @@ -105,7 +107,7 @@ func (t *NativeTun) Read(p []byte) (n int, err error) {

func (t *NativeTun) Write(p []byte) (n int, err error) {
if t.gsoEnabled {
err = t.BatchWrite([][]byte{p})
err = t.BatchWrite([][]byte{p}, virtioNetHdrLen)
if err != nil {
return
}
Expand Down Expand Up @@ -140,37 +142,31 @@ func (t *NativeTun) BatchSize() int {
return batchSize
}

func (t *NativeTun) BatchRead(buffers [][]byte, readN []int) (n int, err error) {
if t.gsoEnabled {
n, err = t.tunFile.Read(t.gsoBuffer)
if err != nil {
return
}
n, err = handleVirtioRead(t.gsoBuffer[:n], buffers, readN, 0)
if err != nil {
return
}

func (t *NativeTun) BatchRead(buffers [][]byte, offset int, readN []int) (n int, err error) {
t.gsoReadAccess.Lock()
defer t.gsoReadAccess.Unlock()
n, err = t.tunFile.Read(t.gsoBuffer)
if err != nil {
return
} else {
return 0, os.ErrInvalid
}
return handleVirtioRead(t.gsoBuffer[:n], buffers, readN, offset)
}

func (t *NativeTun) BatchWrite(buffers [][]byte) error {
func (t *NativeTun) BatchWrite(buffers [][]byte, offset int) error {
t.tcpGROAccess.Lock()
defer func() {
t.tcp4GROTable.reset()
t.tcp6GROTable.reset()
t.tcpGROAccess.Unlock()
}()
var toWrite []int
err := handleGRO(buffers, virtioNetHdrLen, t.tcp4GROTable, t.tcp6GROTable, &toWrite)
t.gsoToWrite = t.gsoToWrite[:0]
err := handleGRO(buffers, offset, t.tcp4GROTable, t.tcp6GROTable, &t.gsoToWrite)
if err != nil {
return err
}
for _, bufferIndex := range toWrite {
_, err = t.tunFile.Write(buffers[bufferIndex])
offset -= virtioNetHdrLen
for _, bufferIndex := range t.gsoToWrite {
_, err = t.tunFile.Write(buffers[bufferIndex][offset:])
if err != nil {
return err
}
Expand Down
8 changes: 6 additions & 2 deletions tun_linux_offload.go
Original file line number Diff line number Diff line change
Expand Up @@ -750,8 +750,12 @@ func checksumNoFold(b []byte, initial uint64) uint64 {
}

func checksumFold(b []byte, initial uint64) uint16 {
r := clashtcpip.Checksum(uint32(initial), b)
return binary.BigEndian.Uint16(r[:])
ac := checksumNoFold(b, initial)
ac = (ac >> 16) + (ac & 0xffff)
ac = (ac >> 16) + (ac & 0xffff)
ac = (ac >> 16) + (ac & 0xffff)
ac = (ac >> 16) + (ac & 0xffff)
return uint16(ac)
}

func pseudoHeaderChecksumNoFold(protocol uint8, srcAddr, dstAddr []byte, totalLen uint16) uint64 {
Expand Down

0 comments on commit 3195f6f

Please sign in to comment.