From 95328de17ecf4a8f8793bdab0c6007599269c8ab Mon Sep 17 00:00:00 2001 From: zjx158094 Date: Sun, 25 Apr 2021 20:32:24 +0800 Subject: [PATCH] make batchConn more generic --- batchconn.go | 5 + batchconn_generic.go | 28 +++++ batchconn_linux.go | 80 +++++++++++++++ readloop.go | 93 ++++++++++++++++- readloop_generic.go | 11 -- readloop_linux.go | 111 -------------------- sess.go | 13 +-- sess_test.go | 239 ++++++++++++++++++++++++++++++++++++++----- tx.go | 35 +++++++ tx_generic.go | 11 -- tx_linux.go | 51 --------- 11 files changed, 452 insertions(+), 225 deletions(-) create mode 100644 batchconn_generic.go create mode 100644 batchconn_linux.go delete mode 100644 readloop_generic.go delete mode 100644 readloop_linux.go delete mode 100644 tx_generic.go delete mode 100644 tx_linux.go diff --git a/batchconn.go b/batchconn.go index 6c307010..c26aff16 100644 --- a/batchconn.go +++ b/batchconn.go @@ -10,3 +10,8 @@ type batchConn interface { WriteBatch(ms []ipv4.Message, flags int) (int, error) ReadBatch(ms []ipv4.Message, flags int) (int, error) } + +type batchErrDetector interface { + ReadBatchUnavailable(err error) bool + WriteBatchUnavailable(err error) bool +} diff --git a/batchconn_generic.go b/batchconn_generic.go new file mode 100644 index 00000000..42aa6714 --- /dev/null +++ b/batchconn_generic.go @@ -0,0 +1,28 @@ +// +build !linux + +package kcp + +import ( + "net" +) + +func toBatchConn(c net.PacketConn) batchConn { + if xconn, ok := c.(batchConn); ok { + return xconn + } + return nil +} + +func readBatchUnavailable(xconn batchConn, err error) bool { + if detector, ok := xconn.(batchErrDetector); ok { + return detector.ReadBatchUnavailable(err) + } + return false +} + +func writeBatchUnavailable(xconn batchConn, err error) bool { + if detector, ok := xconn.(batchErrDetector); ok { + return detector.WriteBatchUnavailable(err) + } + return false +} diff --git a/batchconn_linux.go b/batchconn_linux.go new file mode 100644 index 00000000..6f0493a3 --- /dev/null +++ b/batchconn_linux.go @@ -0,0 +1,80 @@ +// +build linux + +package kcp + +import ( + "net" + "os" + + "golang.org/x/net/ipv4" + "golang.org/x/net/ipv6" +) + +func toBatchConn(c net.PacketConn) batchConn { + if xconn, ok := c.(batchConn); ok { + return xconn + } + if _, ok := c.(*net.UDPConn); ok { + var xconn batchConn + addr, err := net.ResolveUDPAddr("udp", c.LocalAddr().String()) + if err == nil { + if addr.IP.To4() != nil { + xconn = ipv4.NewPacketConn(c) + } else { + xconn = ipv6.NewPacketConn(c) + } + } + return xconn + } + return nil +} + +func isPacketConn(xconn batchConn) bool { + if _, ok := xconn.(*ipv4.PacketConn); ok { + return true + } + if _, ok := xconn.(*ipv6.PacketConn); ok { + return true + } + return false +} + +func readBatchUnavailable(xconn batchConn, err error) bool { + if isPacketConn(xconn) { + // compatibility issue: + // for linux kernel<=2.6.32, support for sendmmsg is not available + // an error of type os.SyscallError will be returned + if operr, ok := err.(*net.OpError); ok { + if se, ok := operr.Err.(*os.SyscallError); ok { + if se.Syscall == "recvmmsg" { + return true + } + } + } + return false + } + if detector, ok := xconn.(batchErrDetector); ok { + return detector.ReadBatchUnavailable(err) + } + return false +} + +func writeBatchUnavailable(xconn batchConn, err error) bool { + if isPacketConn(xconn) { + // compatibility issue: + // for linux kernel<=2.6.32, support for sendmmsg is not available + // an error of type os.SyscallError will be returned + if operr, ok := err.(*net.OpError); ok { + if se, ok := operr.Err.(*os.SyscallError); ok { + if se.Syscall == "sendmmsg" { + return true + } + } + } + return false + } + if detector, ok := xconn.(batchErrDetector); ok { + return detector.WriteBatchUnavailable(err) + } + return false +} diff --git a/readloop.go b/readloop.go index 697395ab..608cd6af 100644 --- a/readloop.go +++ b/readloop.go @@ -4,19 +4,42 @@ import ( "sync/atomic" "github.com/pkg/errors" + "golang.org/x/net/ipv4" ) +func (s *UDPSession) readLoop() { + // default version + if s.xconn == nil { + s.defaultReadLoop() + return + } + s.batchReadLoop() +} + +func (l *Listener) monitor() { + xconn := toBatchConn(l.conn) + + // default version + if xconn == nil { + l.defaultMonitor() + return + } + l.batchMonitor(xconn) +} + func (s *UDPSession) defaultReadLoop() { buf := make([]byte, mtuLimit) var src string for { if n, addr, err := s.conn.ReadFrom(buf); err == nil { // make sure the packet is from the same source - if src == "" { // set source address - src = addr.String() - } else if addr.String() != src { - atomic.AddUint64(&DefaultSnmp.InErrs, 1) - continue + if addr.String() != src { + if len(src) == 0 { // set source address + src = addr.String() + } else { + atomic.AddUint64(&DefaultSnmp.InErrs, 1) + continue + } } s.packetInput(buf[:n]) } else { @@ -37,3 +60,63 @@ func (l *Listener) defaultMonitor() { } } } + +func (s *UDPSession) batchReadLoop() { + // x/net version + var src string + msgs := make([]ipv4.Message, batchSize) + for k := range msgs { + msgs[k].Buffers = [][]byte{make([]byte, mtuLimit)} + } + + for { + if count, err := s.xconn.ReadBatch(msgs, 0); err == nil { + for i := 0; i < count; i++ { + msg := &msgs[i] + // make sure the packet is from the same source + if msg.Addr.String() != src { + if len(src) == 0 { // set source address if nil + src = msg.Addr.String() + } else { + atomic.AddUint64(&DefaultSnmp.InErrs, 1) + continue + } + } + + // source and size has validated + s.packetInput(msg.Buffers[0][:msg.N]) + } + } else { + if readBatchUnavailable(s.xconn, err) { + s.defaultReadLoop() + return + } + s.notifyReadError(errors.WithStack(err)) + return + } + } +} + +func (l *Listener) batchMonitor(xconn batchConn) { + // x/net version + msgs := make([]ipv4.Message, batchSize) + for k := range msgs { + msgs[k].Buffers = [][]byte{make([]byte, mtuLimit)} + } + + for { + if count, err := xconn.ReadBatch(msgs, 0); err == nil { + for i := 0; i < count; i++ { + msg := &msgs[i] + l.packetInput(msg.Buffers[0][:msg.N], msg.Addr) + } + } else { + if readBatchUnavailable(xconn, err) { + l.defaultMonitor() + return + } + l.notifyReadError(errors.WithStack(err)) + return + } + } +} diff --git a/readloop_generic.go b/readloop_generic.go deleted file mode 100644 index 5dbe4f44..00000000 --- a/readloop_generic.go +++ /dev/null @@ -1,11 +0,0 @@ -// +build !linux - -package kcp - -func (s *UDPSession) readLoop() { - s.defaultReadLoop() -} - -func (l *Listener) monitor() { - l.defaultMonitor() -} diff --git a/readloop_linux.go b/readloop_linux.go deleted file mode 100644 index be194afb..00000000 --- a/readloop_linux.go +++ /dev/null @@ -1,111 +0,0 @@ -// +build linux - -package kcp - -import ( - "net" - "os" - "sync/atomic" - - "github.com/pkg/errors" - "golang.org/x/net/ipv4" - "golang.org/x/net/ipv6" -) - -// the read loop for a client session -func (s *UDPSession) readLoop() { - // default version - if s.xconn == nil { - s.defaultReadLoop() - return - } - - // x/net version - var src string - msgs := make([]ipv4.Message, batchSize) - for k := range msgs { - msgs[k].Buffers = [][]byte{make([]byte, mtuLimit)} - } - - for { - if count, err := s.xconn.ReadBatch(msgs, 0); err == nil { - for i := 0; i < count; i++ { - msg := &msgs[i] - // make sure the packet is from the same source - if src == "" { // set source address if nil - src = msg.Addr.String() - } else if msg.Addr.String() != src { - atomic.AddUint64(&DefaultSnmp.InErrs, 1) - continue - } - - // source and size has validated - s.packetInput(msg.Buffers[0][:msg.N]) - } - } else { - // compatibility issue: - // for linux kernel<=2.6.32, support for sendmmsg is not available - // an error of type os.SyscallError will be returned - if operr, ok := err.(*net.OpError); ok { - if se, ok := operr.Err.(*os.SyscallError); ok { - if se.Syscall == "recvmmsg" { - s.defaultReadLoop() - return - } - } - } - s.notifyReadError(errors.WithStack(err)) - return - } - } -} - -// monitor incoming data for all connections of server -func (l *Listener) monitor() { - var xconn batchConn - if _, ok := l.conn.(*net.UDPConn); ok { - addr, err := net.ResolveUDPAddr("udp", l.conn.LocalAddr().String()) - if err == nil { - if addr.IP.To4() != nil { - xconn = ipv4.NewPacketConn(l.conn) - } else { - xconn = ipv6.NewPacketConn(l.conn) - } - } - } - - // default version - if xconn == nil { - l.defaultMonitor() - return - } - - // x/net version - msgs := make([]ipv4.Message, batchSize) - for k := range msgs { - msgs[k].Buffers = [][]byte{make([]byte, mtuLimit)} - } - - for { - if count, err := xconn.ReadBatch(msgs, 0); err == nil { - for i := 0; i < count; i++ { - msg := &msgs[i] - l.packetInput(msg.Buffers[0][:msg.N], msg.Addr) - } - } else { - // compatibility issue: - // for linux kernel<=2.6.32, support for sendmmsg is not available - // an error of type os.SyscallError will be returned - if operr, ok := err.(*net.OpError); ok { - if se, ok := operr.Err.(*os.SyscallError); ok { - if se.Syscall == "recvmmsg" { - l.defaultMonitor() - return - } - } - } - l.notifyReadError(errors.WithStack(err)) - return - } - } -} diff --git a/sess.go b/sess.go index 2dedd745..4d0507c3 100644 --- a/sess.go +++ b/sess.go @@ -139,17 +139,8 @@ func newUDPSession(conv uint32, dataShards, parityShards int, l *Listener, conn sess.block = block sess.recvbuf = make([]byte, mtuLimit) - // cast to writebatch conn - if _, ok := conn.(*net.UDPConn); ok { - addr, err := net.ResolveUDPAddr("udp", conn.LocalAddr().String()) - if err == nil { - if addr.IP.To4() != nil { - sess.xconn = ipv4.NewPacketConn(conn) - } else { - sess.xconn = ipv6.NewPacketConn(conn) - } - } - } + // cast to batchConn, can be nil + sess.xconn = toBatchConn(conn) // FEC codec initialization sess.fecDecoder = newFECDecoder(dataShards, parityShards) diff --git a/sess_test.go b/sess_test.go index fbe3ad1a..fe2b7de4 100644 --- a/sess_test.go +++ b/sess_test.go @@ -2,6 +2,7 @@ package kcp import ( "crypto/sha1" + "errors" "fmt" "io" "log" @@ -14,6 +15,7 @@ import ( "time" "golang.org/x/crypto/pbkdf2" + "golang.org/x/net/ipv4" ) var baseport = uint32(10000) @@ -105,12 +107,16 @@ func listenTinyBufferEcho(port int) (net.Listener, error) { return ListenWithOptions(fmt.Sprintf("127.0.0.1:%v", port), block, 10, 3) } -func listenSink(port int) (net.Listener, error) { +func listenNoEncryption(port int) (net.Listener, error) { return ListenWithOptions(fmt.Sprintf("127.0.0.1:%v", port), nil, 0, 0) } -func echoServer(port int) net.Listener { - l, err := listenEcho(port) +func server( + port int, + listen func(port int) (net.Listener, error), + handle func(*UDPSession), +) net.Listener { + l, err := listen(port) if err != nil { panic(err) } @@ -129,35 +135,19 @@ func echoServer(port int) net.Listener { // coverage test s.(*UDPSession).SetReadBuffer(4 * 1024 * 1024) s.(*UDPSession).SetWriteBuffer(4 * 1024 * 1024) - go handleEcho(s.(*UDPSession)) + go handle(s.(*UDPSession)) } }() return l } -func sinkServer(port int) net.Listener { - l, err := listenSink(port) - if err != nil { - panic(err) - } - - go func() { - kcplistener := l.(*Listener) - kcplistener.SetReadBuffer(4 * 1024 * 1024) - kcplistener.SetWriteBuffer(4 * 1024 * 1024) - kcplistener.SetDSCP(46) - for { - s, err := l.Accept() - if err != nil { - return - } - - go handleSink(s.(*UDPSession)) - } - }() +func echoServer(port int) net.Listener { + return server(port, listenEcho, handleEcho) +} - return l +func sinkServer(port int) net.Listener { + return server(port, listenNoEncryption, handleSink) } func tinyBufferEchoServer(port int) net.Listener { @@ -700,3 +690,202 @@ func TestUDPSessionNonOwnedPacketConn(t *testing.T) { t.Fatal("non-owned PacketConn closed after UDPSession.Close()") } } + +type customBatchConn struct { + *net.UDPConn + calledWriteBatch bool + calledReadBatch bool + + disableWriteBatch bool + disableReadBatch bool + simulateWriteBatchErr bool + simulateReadBatchErr bool +} + +func (c *customBatchConn) WriteBatch(ms []ipv4.Message, flags int) (int, error) { + c.calledWriteBatch = true + if c.disableWriteBatch { + return 0, errors.New("unsupported") + } + if c.simulateWriteBatchErr { + return 0, errors.New("unknown err") + } + n := 0 + for k := range ms { + if _, err := c.WriteTo(ms[k].Buffers[0], ms[k].Addr); err == nil { + n++ + } else { + return n, err + } + } + return n, nil +} + +func (c *customBatchConn) ReadBatch(ms []ipv4.Message, flags int) (int, error) { + c.calledReadBatch = true + if c.disableReadBatch { + return 0, errors.New("unsupported") + } + if c.simulateReadBatchErr { + return 0, errors.New("unknown err") + } + succ := 0 + n, addr, err := c.ReadFrom(ms[0].Buffers[0]) + if err != nil { + return succ, err + } + ms[0].N = n + ms[0].Addr = addr + succ++ + return succ, nil +} + +func (c *customBatchConn) ReadBatchUnavailable(err error) bool { + return err.Error() == "unsupported" +} + +func (c *customBatchConn) WriteBatchUnavailable(err error) bool { + return err.Error() == "unsupported" +} + +func TestCustomBatchConn(t *testing.T) { + l := server(0, listenNoEncryption, handleEcho) + defer l.Close() + + // Create a net.PacketConn not owned by the UDPSession. + c, err := net.ListenPacket("udp", "127.0.0.1:0") + if err != nil { + panic(err) + } + defer c.Close() + pconn := &customBatchConn{UDPConn: c.(*net.UDPConn)} + defer pconn.Close() + + client, err := NewConn2(l.Addr(), nil, 0, 0, pconn) + if err != nil { + panic(err) + } + defer client.Close() + + wBuf := []byte("hello") + _, err = client.Write(wBuf) + if err != nil { + t.Fatalf("Write() should not fail, err: %v", err) + } + + buf := make([]byte, 100) + n, err := client.Read(buf) + if err != nil { + t.Fatalf("Read() should not fail, err: %v", err) + } + if n != len(wBuf) { + t.Fatalf("should read %d bytes, actual n: %d", len(wBuf), n) + } + if string(wBuf) != string(buf[:n]) { + t.Fatalf("read content should be '%s', actual: '%s'", string(wBuf), string(buf[:n])) + } + + if !pconn.calledWriteBatch { + t.Fatalf("expect to call WriteBatch()") + } + if !pconn.calledReadBatch { + t.Fatalf("expect to call ReadBatch()") + } +} + +func TestCustomBatchConnFallback(t *testing.T) { + l := server(0, listenNoEncryption, handleEcho) + defer l.Close() + + // Create a net.PacketConn not owned by the UDPSession. + c, err := net.ListenPacket("udp", "127.0.0.1:0") + if err != nil { + panic(err) + } + defer c.Close() + pconn := &customBatchConn{UDPConn: c.(*net.UDPConn)} + defer pconn.Close() + + // disabled batch ops, it should fallback to normal Read()/Write() + pconn.disableReadBatch = true + pconn.disableWriteBatch = true + + client, err := NewConn2(l.Addr(), nil, 0, 0, pconn) + if err != nil { + panic(err) + } + defer client.Close() + + wBuf := []byte("hello") + _, err = client.Write(wBuf) + if err != nil { + t.Fatalf("Write() should not fail, err: %v", err) + } + + buf := make([]byte, 100) + n, err := client.Read(buf) + if err != nil { + t.Fatalf("Read() should not fail, err: %v", err) + } + if n != len(wBuf) { + t.Fatalf("should read %d bytes, actual n: %d", len(wBuf), n) + } + if string(wBuf) != string(buf[:n]) { + t.Fatalf("read content should be '%s', actual: '%s'", string(wBuf), string(buf[:n])) + } + + if !pconn.calledWriteBatch { + t.Fatalf("expect to call WriteBatch()") + } + if !pconn.calledReadBatch { + t.Fatalf("expect to call ReadBatch()") + } +} + +func TestBatchErrDetectorForRealErr(t *testing.T) { + l := server(0, listenNoEncryption, handleEcho) + defer l.Close() + + // Create a net.PacketConn not owned by the UDPSession. + c, err := net.ListenPacket("udp", "127.0.0.1:0") + if err != nil { + panic(err) + } + defer c.Close() + pconn := &customBatchConn{UDPConn: c.(*net.UDPConn)} + defer pconn.Close() + + pconn.simulateReadBatchErr = true + pconn.simulateWriteBatchErr = true + + client, err := NewConn2(l.Addr(), nil, 0, 0, pconn) + if err != nil { + panic(err) + } + defer client.Close() + + client.SetWriteDelay(false) + + wBuf := []byte("hello") + + // no error for the first time + _, err = client.Write(wBuf) + if err != nil { + t.Fatalf("Write() should not fail, err: %v", err) + } + + // wait for the notification + time.Sleep(2 * time.Duration(client.kcp.interval) * time.Millisecond) + + // error for the second time + _, err = client.Write(wBuf) + if err == nil { + t.Fatalf("Write() should fail") + } + + buf := make([]byte, 100) + _, err = client.Read(buf) + if err == nil { + t.Fatalf("Read() should fail") + } +} diff --git a/tx.go b/tx.go index 3397b82e..e0763244 100644 --- a/tx.go +++ b/tx.go @@ -7,6 +7,15 @@ import ( "golang.org/x/net/ipv4" ) +func (s *UDPSession) tx(txqueue []ipv4.Message) { + // default version + if s.xconn == nil || s.xconnWriteError != nil { + s.defaultTx(txqueue) + return + } + s.batchTx(txqueue) +} + func (s *UDPSession) defaultTx(txqueue []ipv4.Message) { nbytes := 0 npkts := 0 @@ -22,3 +31,29 @@ func (s *UDPSession) defaultTx(txqueue []ipv4.Message) { atomic.AddUint64(&DefaultSnmp.OutPkts, uint64(npkts)) atomic.AddUint64(&DefaultSnmp.OutBytes, uint64(nbytes)) } + +func (s *UDPSession) batchTx(txqueue []ipv4.Message) { + // x/net version + nbytes := 0 + npkts := 0 + for len(txqueue) > 0 { + if n, err := s.xconn.WriteBatch(txqueue, 0); err == nil { + for k := range txqueue[:n] { + nbytes += len(txqueue[k].Buffers[0]) + } + npkts += n + txqueue = txqueue[n:] + } else { + if writeBatchUnavailable(s.xconn, err) { + s.xconnWriteError = err + s.defaultTx(txqueue) + return + } + s.notifyWriteError(errors.WithStack(err)) + break + } + } + + atomic.AddUint64(&DefaultSnmp.OutPkts, uint64(npkts)) + atomic.AddUint64(&DefaultSnmp.OutBytes, uint64(nbytes)) +} diff --git a/tx_generic.go b/tx_generic.go deleted file mode 100644 index 0b4f3494..00000000 --- a/tx_generic.go +++ /dev/null @@ -1,11 +0,0 @@ -// +build !linux - -package kcp - -import ( - "golang.org/x/net/ipv4" -) - -func (s *UDPSession) tx(txqueue []ipv4.Message) { - s.defaultTx(txqueue) -} diff --git a/tx_linux.go b/tx_linux.go deleted file mode 100644 index 4f19df56..00000000 --- a/tx_linux.go +++ /dev/null @@ -1,51 +0,0 @@ -// +build linux - -package kcp - -import ( - "net" - "os" - "sync/atomic" - - "github.com/pkg/errors" - "golang.org/x/net/ipv4" -) - -func (s *UDPSession) tx(txqueue []ipv4.Message) { - // default version - if s.xconn == nil || s.xconnWriteError != nil { - s.defaultTx(txqueue) - return - } - - // x/net version - nbytes := 0 - npkts := 0 - for len(txqueue) > 0 { - if n, err := s.xconn.WriteBatch(txqueue, 0); err == nil { - for k := range txqueue[:n] { - nbytes += len(txqueue[k].Buffers[0]) - } - npkts += n - txqueue = txqueue[n:] - } else { - // compatibility issue: - // for linux kernel<=2.6.32, support for sendmmsg is not available - // an error of type os.SyscallError will be returned - if operr, ok := err.(*net.OpError); ok { - if se, ok := operr.Err.(*os.SyscallError); ok { - if se.Syscall == "sendmmsg" { - s.xconnWriteError = se - s.defaultTx(txqueue) - return - } - } - } - s.notifyWriteError(errors.WithStack(err)) - break - } - } - - atomic.AddUint64(&DefaultSnmp.OutPkts, uint64(npkts)) - atomic.AddUint64(&DefaultSnmp.OutBytes, uint64(nbytes)) -}