diff --git a/internal/core/websocket/websocket.go b/internal/core/websocket/websocket.go index d5045db..82d026c 100644 --- a/internal/core/websocket/websocket.go +++ b/internal/core/websocket/websocket.go @@ -23,6 +23,9 @@ const ( // Maximum message size allowed from peer. maxMessageSize = 512 + + // Maximum undelivered messages + maxUndeliveredMsg = 100 ) var ( @@ -46,56 +49,68 @@ type Client struct { hub *Hub // Reference to the Hub conn *websocket.Conn // WebSocket connection send chan []byte // Buffered channel for outbound messages + undeliveredMsg [][]byte // Queue for undelivered messages +} + +func (c *Client) addUndeliveredMsg(message []byte) { + if len(c.undeliveredMsg) >= maxUndeliveredMsg { + // Deleting the oldest message to free up space + c.undeliveredMsg = c.undeliveredMsg[1:] + } + c.undeliveredMsg = append(c.undeliveredMsg, message) } // readPump listens for incoming messages. func (c *Client) readPump() { - defer func() { - c.hub.unregister <- c - c.conn.Close() - }() - c.conn.SetReadLimit(maxMessageSize) - c.conn.SetReadDeadline(time.Now().Add(pongWait)) - c.conn.SetPongHandler(func(string) error { c.conn.SetReadDeadline(time.Now().Add(pongWait)); return nil }) - - for { - _, message, err := c.conn.ReadMessage() - if err != nil { - if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) { - log.Printf("error: %v", err) - } - break - } - message = bytes.TrimSpace(bytes.Replace(message, newline, space, -1)) - - // Handle different events - switch { - case bytes.HasPrefix(message, []byte("join_room:")): - roomName := string(message[len("join_room:"):]) - c.hub.HandleJoinRoom(c, roomName) - c.send <- []byte("join_room_success:" + roomName) - - case bytes.HasPrefix(message, []byte("room_message:")): - roomNameAndMessage := bytes.SplitN(message[len("room_message:"):], []byte(" "), 2) - roomName := string(roomNameAndMessage[0]) - roomMessage := roomNameAndMessage[1] - c.hub.handleRoomBroadcast(roomName, roomMessage) - - case bytes.HasPrefix(message, []byte("leave_room:")): - roomName := string(message[len("leave_room:"):]) - c.hub.handleLeaveRoom(c, roomName) - c.send <- []byte("leave_room_success:" + roomName) - - case bytes.HasPrefix(message, []byte("private_message:")): - receiverAndMessage := bytes.SplitN(message[len("private_message:"):], []byte(" "), 2) - receiver := string(receiverAndMessage[0]) - privateMessage := receiverAndMessage[1] - c.hub.handlePrivateMessage(receiver, privateMessage) - - default: - c.hub.broadcast <- message - } - } + defer func() { + c.hub.unregister <- c + c.conn.Close() + }() + c.conn.SetReadLimit(maxMessageSize) + c.conn.SetReadDeadline(time.Now().Add(pongWait)) + c.conn.SetPongHandler(func(string) error { + c.conn.SetReadDeadline(time.Now().Add(pongWait)) + return nil + }) + + for { + _, message, err := c.conn.ReadMessage() + if err != nil { + if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) { + log.Printf("error: %v", err) + } + break + } + + message = bytes.TrimSpace(bytes.Replace(message, newline, space, -1)) + + switch { + case bytes.HasPrefix(message, []byte("join_room:")): + roomName := string(message[len("join_room:"):]) + c.hub.HandleJoinRoom(c, roomName) + c.send <- []byte("join_room_success:" + roomName) + + case bytes.HasPrefix(message, []byte("room_message:")): + roomNameAndMessage := bytes.SplitN(message[len("room_message:"):], []byte(" "), 2) + roomName := string(roomNameAndMessage[0]) + roomMessage := roomNameAndMessage[1] + c.hub.handleRoomBroadcast(roomName, roomMessage) + + case bytes.HasPrefix(message, []byte("leave_room:")): + roomName := string(message[len("leave_room:"):]) + c.hub.handleLeaveRoom(c, roomName) + c.send <- []byte("leave_room_success:" + roomName) + + case bytes.HasPrefix(message, []byte("private_message:")): + receiverAndMessage := bytes.SplitN(message[len("private_message:"):], []byte(" "), 2) + receiver := string(receiverAndMessage[0]) + privateMessage := receiverAndMessage[1] + c.hub.handlePrivateMessage(receiver, privateMessage) + + default: + c.hub.broadcast <- message + } + } } // writePump sends messages to the client. @@ -108,6 +123,7 @@ func (c *Client) writePump() { for { select { case message, ok := <-c.send: + c.conn.SetWriteDeadline(time.Now().Add(writeWait)) if !ok { c.conn.WriteMessage(websocket.CloseMessage, []byte{}) @@ -116,6 +132,8 @@ func (c *Client) writePump() { w, err := c.conn.NextWriter(websocket.TextMessage) if err != nil { + // If the connection is broken, add the message to the unread queue + c.undeliveredMsg = c.addUndeliveredMsg(message) return } w.Write(message) @@ -230,6 +248,7 @@ func serveWs(hub *Hub, w http.ResponseWriter, r *http.Request) { // Reconnect existing client client = hub.clients[clientID] client.conn = conn + client.sendUndeliveredMsg() // function that sends unread messages } else { // New client connection clientID = uuid.NewString() @@ -239,6 +258,7 @@ func serveWs(hub *Hub, w http.ResponseWriter, r *http.Request) { send: make(chan []byte, 256), id: clientID, name: "root", + undeliveredMsg: [][]byte{}, } } @@ -247,6 +267,15 @@ func serveWs(hub *Hub, w http.ResponseWriter, r *http.Request) { go client.readPump() } +// Send all unread messages to the client after reconnection. +func (c *Client) sendUndeliveredMsg() { + for _, msg := range c.undeliveredMsg { + c.send <- msg + } + // Clearing the queue of unread messages after sending + c.undeliveredMsg = [][]byte{} +} + // WebSocketServer manages the WebSocket server. type WebSocketServer struct{}