diff --git a/plugins/websocket/protocol.go b/plugins/websocket/protocol.go index 08a5253..eee56bd 100644 --- a/plugins/websocket/protocol.go +++ b/plugins/websocket/protocol.go @@ -5,9 +5,13 @@ import ( "github.com/Allenxuxu/gev/log" "github.com/Allenxuxu/gev/plugins/websocket/ws" "github.com/Allenxuxu/ringbuffer" + "github.com/gobwas/pool/pbytes" ) -const upgradedKey = "gev_ws_upgraded" +const ( + upgradedKey = "gev_ws_upgraded" + headerbufferKey = "gev_header_buf" +) // Protocol websocket type Protocol struct { @@ -30,8 +34,10 @@ func (p *Protocol) UnPacket(c *connection.Connection, buffer *ringbuffer.RingBuf return } c.Set(upgradedKey, true) + c.Set(headerbufferKey, pbytes.Get(0, ws.MaxHeaderSize-2)) } else { - header, err := ws.VirtualReadHeader(buffer) + bts, _ := c.Get(headerbufferKey) + header, err := ws.VirtualReadHeader(bts.([]byte), buffer) if err != nil { if err != ws.ErrHeaderNotReady { log.Error(err) diff --git a/plugins/websocket/wrap.go b/plugins/websocket/wrap.go index 058da54..8fd04e4 100644 --- a/plugins/websocket/wrap.go +++ b/plugins/websocket/wrap.go @@ -5,6 +5,7 @@ import ( "github.com/Allenxuxu/gev/log" "github.com/Allenxuxu/gev/plugins/websocket/ws" "github.com/Allenxuxu/gev/plugins/websocket/ws/util" + "github.com/gobwas/pool/pbytes" ) // WSHandler WebSocket Server 注册接口 @@ -91,4 +92,8 @@ func (s *HandlerWrap) OnMessage(c *connection.Connection, ctx interface{}, paylo // OnClose wrap func (s *HandlerWrap) OnClose(c *connection.Connection) { s.wsHandler.OnClose(c) + + if bts, ok := c.Get(headerbufferKey); ok { + pbytes.Put(bts.([]byte)) + } } diff --git a/plugins/websocket/ws/read.go b/plugins/websocket/ws/read.go index 78a1292..1556026 100644 --- a/plugins/websocket/ws/read.go +++ b/plugins/websocket/ws/read.go @@ -6,7 +6,6 @@ import ( "github.com/Allenxuxu/ringbuffer" "github.com/Allenxuxu/toolkit/convert" - "github.com/gobwas/pool/pbytes" ) // Errors used by frame reader. @@ -17,14 +16,13 @@ var ( ) // VirtualReadHeader reads a frame header from r. -func VirtualReadHeader(in *ringbuffer.RingBuffer) (h Header, err error) { +func VirtualReadHeader(bts []byte, in *ringbuffer.RingBuffer) (h Header, err error) { if in.Length() < 6 { err = ErrHeaderNotReady return } - bts := pbytes.Get(2, MaxHeaderSize-2) - defer pbytes.Put(bts) + bts = bts[:2] // Prepare to hold first 2 bytes to choose size of next read. _, _ = in.VirtualRead(bts)