From 9a7e6008cbfff01d52de1d97eed27bf16e82deac Mon Sep 17 00:00:00 2001 From: lengcharles Date: Sat, 4 Feb 2023 14:07:17 +0800 Subject: [PATCH] feat(example): performance optimization and add concurrent lock support (#13) --- examples/chat/client.go | 12 +++++------- examples/chat/hub.go | 41 ++++++++++++++++++++++++++++++++++------- examples/chat/main.go | 3 +-- 3 files changed, 40 insertions(+), 16 deletions(-) diff --git a/examples/chat/client.go b/examples/chat/client.go index 34ebbc2..38d4189 100644 --- a/examples/chat/client.go +++ b/examples/chat/client.go @@ -37,8 +37,6 @@ var ( // Client is a middleman between the websocket connection and the hub. type Client struct { - hub *Hub - // The websocket connection. conn *websocket.Conn @@ -53,7 +51,7 @@ type Client struct { // reads from this goroutine. func (c *Client) readPump() { defer func() { - c.hub.unregister <- c + hub.unregister <- c c.conn.Close() }() c.conn.SetReadLimit(maxMessageSize) @@ -68,7 +66,7 @@ func (c *Client) readPump() { break } message = bytes.TrimSpace(bytes.Replace(message, newline, space, -1)) - c.hub.broadcast <- message + hub.broadcast <- message } } @@ -119,10 +117,10 @@ func (c *Client) writePump() { } // serveWs handles websocket requests from the peer. -func serveWs(ctx *app.RequestContext, hub *Hub) { +func serveWs(ctx *app.RequestContext) { err := upgrader.Upgrade(ctx, func(conn *websocket.Conn) { - client := &Client{hub: hub, conn: conn, send: make(chan []byte, 256)} - client.hub.register <- client + client := &Client{conn: conn, send: make(chan []byte, 256)} + hub.register <- client go client.writePump() client.readPump() diff --git a/examples/chat/hub.go b/examples/chat/hub.go index fb25735..d86f6bd 100644 --- a/examples/chat/hub.go +++ b/examples/chat/hub.go @@ -7,11 +7,16 @@ package main +import ( + "sync" +) + // Hub maintains the set of active clients and broadcasts messages to the // clients. type Hub struct { // Registered clients. - clients map[*Client]bool + clients map[*Client]struct{} + clientsLock sync.RWMutex // Inbound messages from the clients. broadcast chan []byte @@ -23,12 +28,14 @@ type Hub struct { unregister chan *Client } +var hub = newHub() + func newHub() *Hub { return &Hub{ broadcast: make(chan []byte), register: make(chan *Client), unregister: make(chan *Client), - clients: make(map[*Client]bool), + clients: make(map[*Client]struct{}), } } @@ -36,12 +43,9 @@ func (h *Hub) run() { for { select { case client := <-h.register: - h.clients[client] = true + h.Register(client) case client := <-h.unregister: - if _, ok := h.clients[client]; ok { - delete(h.clients, client) - close(client.send) - } + h.Unregister(client) case message := <-h.broadcast: for client := range h.clients { select { @@ -54,3 +58,26 @@ func (h *Hub) run() { } } } + +func (h *Hub) Register(client *Client) { + h.AddClient(client) +} + +func (h *Hub) AddClient(client *Client) { + h.clientsLock.Lock() + defer h.clientsLock.Unlock() + h.clients[client] = struct{}{} +} + +func (h *Hub) Unregister(client *Client) { + h.DelClient(client) +} + +func (h *Hub) DelClient(client *Client) { + h.clientsLock.Lock() + defer h.clientsLock.Unlock() + if _, ok := h.clients[client]; ok { + delete(h.clients, client) + close(client.send) + } +} diff --git a/examples/chat/main.go b/examples/chat/main.go index 5d0f32f..e61e889 100644 --- a/examples/chat/main.go +++ b/examples/chat/main.go @@ -40,7 +40,6 @@ func serveHome(_ context.Context, c *app.RequestContext) { } func main() { - hub := newHub() go hub.run() // server.Default() creates a Hertz with recovery middleware. // If you need a pure hertz, you can use server.New() @@ -49,7 +48,7 @@ func main() { h.GET("/", serveHome) h.GET("/ws", func(c context.Context, ctx *app.RequestContext) { - serveWs(ctx, hub) + serveWs(ctx) }) h.Spin()