diff --git a/internal/core/websocket/websocket.go b/internal/core/websocket/websocket.go index 43a215c..53bbff6 100644 --- a/internal/core/websocket/websocket.go +++ b/internal/core/websocket/websocket.go @@ -2,7 +2,9 @@ package websocket import ( "bytes" + "compress/gzip" "flag" + "io" "log" "net/http" "time" @@ -36,12 +38,40 @@ var ( var upgrader = websocket.Upgrader{ ReadBufferSize: 1024, WriteBufferSize: 1024, + EnableCompression: true, // Enable WebSocket compression CheckOrigin: func(r *http.Request) bool { // Allow connections from any origin return true }, } +// compressGzip compresses data using gzip. +func compressGzip(data []byte) ([]byte, error) { + var buf bytes.Buffer + gz := gzip.NewWriter(&buf) + if _, err := gz.Write(data); err != nil { + return nil, err + } + if err := gz.Close(); err != nil { + return nil, err + } + return buf.Bytes(), nil +} + +// decompressGzip decompresses gzip data. +func decompressGzip(data []byte) ([]byte, error) { + reader, err := gzip.NewReader(bytes.NewReader(data)) + if err != nil { + return nil, err + } + defer reader.Close() + var buf bytes.Buffer + if _, err := io.Copy(&buf, reader); err != nil { + return nil, err + } + return buf.Bytes(), nil +} + // Client represents a connection. type Client struct { name string @@ -60,7 +90,16 @@ func (c *Client) addUndeliveredMsg(message []byte) { c.undeliveredMsg = append(c.undeliveredMsg, message) } -// readPump listens for incoming messages. +// Send all unread messages to the client after reconnection. +func (c *Client) sendUndeliveredMsg() { + for _, msg := range c.undeliveredMsg { + c.send <- msg + } + // Clearing the queue of unread messages after sending + c.undeliveredMsg = [][]byte{} +} + +// readPump listens for incoming messages and decompresses them. func (c *Client) readPump() { defer func() { c.hub.unregister <- c @@ -82,7 +121,14 @@ func (c *Client) readPump() { break } - message = bytes.TrimSpace(bytes.Replace(message, newline, space, -1)) + // Attempt to decompress message + decompressedMessage, err := decompressGzip(message) + if err != nil { + log.Println("Failed to decompress message:", err) + continue + } + + message = bytes.TrimSpace(bytes.Replace(decompressedMessage, newline, space, -1)) switch { case bytes.HasPrefix(message, []byte("join_room:")): @@ -113,7 +159,7 @@ func (c *Client) readPump() { } } -// writePump sends messages to the client. +// writePump sends compressed messages to the client. func (c *Client) writePump() { ticker := time.NewTicker(pingPeriod) defer func() { @@ -123,20 +169,26 @@ func (c *Client) writePump() { for { select { case message, ok := <-c.send: - c.conn.SetWriteDeadline(time.Now().Add(writeWait)) if !ok { c.conn.WriteMessage(websocket.CloseMessage, []byte{}) return } - w, err := c.conn.NextWriter(websocket.TextMessage) + // Compress the message before sending + compressedMessage, err := compressGzip(message) + if err != nil { + log.Println("Failed to compress message:", err) + continue + } + + w, err := c.conn.NextWriter(websocket.BinaryMessage) if err != nil { // If the connection is broken, add the message to the unread queue c.addUndeliveredMsg(message) return } - w.Write(message) + w.Write(compressedMessage) // Add queued messages to current WebSocket message n := len(c.send) @@ -181,6 +233,12 @@ func (h *Hub) joinRoom(client *Client, room string) { } } +// Handle join room. +func (h *Hub) HandleJoinRoom(client *Client, roomName string) { + h.createRoom(roomName) + h.joinRoom(client, roomName) +} + // Leave a room. func (h *Hub) handleLeaveRoom(client *Client, room string) { if roomClients, ok := h.rooms[room]; ok { @@ -209,12 +267,6 @@ func (h *Hub) handlePrivateMessage(receiverName string, message []byte) { } } -// Handle join room. -func (h *Hub) HandleJoinRoom(client *Client, roomName string) { - h.createRoom(roomName) - h.joinRoom(client, roomName) -} - // Run starts the Hub. func (h *Hub) Run() { for { @@ -267,15 +319,6 @@ func serveWs(hub *Hub, w http.ResponseWriter, r *http.Request) { go client.readPump() } -// Send all unread messages to the client after reconnection. -func (c *Client) sendUndeliveredMsg() { - for _, msg := range c.undeliveredMsg { - c.send <- msg - } - // Clearing the queue of unread messages after sending - c.undeliveredMsg = [][]byte{} -} - // WebSocketServer manages the WebSocket server. type WebSocketServer struct{}