diff --git a/connection/connection.go b/connection/connection.go index 470fcb5..238e4db 100644 --- a/connection/connection.go +++ b/connection/connection.go @@ -139,7 +139,6 @@ func (c *Connection) Close() error { // ShutdownWrite 关闭可写端,等待读取完接收缓冲区所有数据 func (c *Connection) ShutdownWrite() error { - c.connected.Set(false) return unix.Shutdown(c.fd, unix.SHUT_WR) } diff --git a/example/websocket/wsserver_test.go b/example/websocket/wsserver_test.go index f086587..b9a08c2 100644 --- a/example/websocket/wsserver_test.go +++ b/example/websocket/wsserver_test.go @@ -2,22 +2,30 @@ package main import ( "bytes" + "fmt" "io" "math/rand" "testing" "time" + "github.com/Allenxuxu/gev/log" + "github.com/Allenxuxu/gev" "github.com/Allenxuxu/gev/connection" "github.com/Allenxuxu/gev/plugins/websocket/ws" "github.com/Allenxuxu/gev/plugins/websocket/ws/util" "github.com/Allenxuxu/toolkit/sync" + "github.com/Allenxuxu/toolkit/sync/atomic" + "github.com/stretchr/testify/assert" "golang.org/x/net/websocket" ) -type wsExample struct{} +type wsExample struct { + ClientNum atomic.Int64 +} func (s *wsExample) OnConnect(c *connection.Connection) { + s.ClientNum.Add(1) //log.Println(" OnConnect : ", c.PeerAddr()) } func (s *wsExample) OnMessage(c *connection.Connection, data []byte) (messageType ws.MessageType, out []byte) { @@ -53,6 +61,7 @@ func (s *wsExample) OnMessage(c *connection.Connection, data []byte) (messageTyp } func (s *wsExample) OnClose(c *connection.Connection) { + s.ClientNum.Add(-1) //log.Println("OnClose") } @@ -62,8 +71,7 @@ func TestWebSocketServer_Start(t *testing.T) { s, err := NewWebSocketServer(handler, &ws.Upgrader{}, gev.Address(":1834"), - gev.NumLoops(8), - gev.ReusePort(true)) + gev.NumLoops(8)) if err != nil { t.Fatal(err) } @@ -117,3 +125,49 @@ func startWebSocketClient(addr string) { } } } + +func TestWebSocketServer_CloseConnection(t *testing.T) { + rand.Seed(time.Now().UnixNano()) + handler := new(wsExample) + + s, err := NewWebSocketServer(handler, &ws.Upgrader{}, + gev.Address(":2021"), + gev.NumLoops(8)) + if err != nil { + t.Fatal(err) + } + + go func() { + time.Sleep(time.Second) + + var ( + err error + n = 100 + toClose = 50 + conn = make([]*websocket.Conn, n) + addr = "ws://localhost" + s.Options().Address + ) + + log.SetLevel(log.LevelDebug) + for i := 0; i < n; i++ { + conn[i], err = websocket.Dial(addr, "", addr) + if err != nil { + panic(fmt.Errorf("%d %s", i, err.Error())) + } + + } + assert.Equal(t, n, int(handler.ClientNum.Get())) + + for i := 0; i < toClose; i++ { + if err := conn[i].Close(); err != nil { + panic(err) + } + } + time.Sleep(time.Second) + assert.Equal(t, n-toClose, int(handler.ClientNum.Get())) + + s.Stop() + }() + + s.Start() +}