diff --git a/internal/core/websocket/websocket.go b/internal/core/websocket/websocket.go index fc83e56..d5045db 100644 --- a/internal/core/websocket/websocket.go +++ b/internal/core/websocket/websocket.go @@ -7,6 +7,7 @@ import ( "net/http" "time" + "github.com/google/uuid" "github.com/gorilla/websocket" ) @@ -38,21 +39,16 @@ var upgrader = websocket.Upgrader{ }, } -// Client is a middleman between the websocket connection and the hub. +// Client represents a connection. type Client struct { name string - hub *Hub - // The websocket connection. - conn *websocket.Conn - // Buffered channel of outbound messages. - send chan []byte + id string // Unique client ID for reconnection + hub *Hub // Reference to the Hub + conn *websocket.Conn // WebSocket connection + send chan []byte // Buffered channel for outbound messages } -// readPump pumps messages from the websocket connection to the hub. -// -// The application runs readPump in a per-connection goroutine. The application -// ensures that there is at most one reader on a connection by executing all -// reads from this goroutine. +// readPump listens for incoming messages. func (c *Client) readPump() { defer func() { c.hub.unregister <- c @@ -61,6 +57,7 @@ func (c *Client) readPump() { c.conn.SetReadLimit(maxMessageSize) c.conn.SetReadDeadline(time.Now().Add(pongWait)) c.conn.SetPongHandler(func(string) error { c.conn.SetReadDeadline(time.Now().Add(pongWait)); return nil }) + for { _, message, err := c.conn.ReadMessage() if err != nil { @@ -71,81 +68,37 @@ func (c *Client) readPump() { } message = bytes.TrimSpace(bytes.Replace(message, newline, space, -1)) - // Parse and handle different events - if bytes.HasPrefix(message, []byte("join_room:")) { + // Handle different events + switch { + case bytes.HasPrefix(message, []byte("join_room:")): roomName := string(message[len("join_room:"):]) c.hub.HandleJoinRoom(c, roomName) - // Send confirmation back to client - confirmationMessage := "join_room_success:" + roomName - c.send <- []byte(confirmationMessage) - } else if bytes.HasPrefix(message, []byte("room_message:")) { - // Handle room message event + c.send <- []byte("join_room_success:" + roomName) + + case bytes.HasPrefix(message, []byte("room_message:")): roomNameAndMessage := bytes.SplitN(message[len("room_message:"):], []byte(" "), 2) roomName := string(roomNameAndMessage[0]) roomMessage := roomNameAndMessage[1] c.hub.handleRoomBroadcast(roomName, roomMessage) - } else if bytes.HasPrefix(message, []byte("leave_room:")) { - // Handle leave room event + + case bytes.HasPrefix(message, []byte("leave_room:")): roomName := string(message[len("leave_room:"):]) c.hub.handleLeaveRoom(c, roomName) - // Send confirmation back to client - confirmationMessage := "leave_room_success:" + roomName - c.send <- []byte(confirmationMessage) - } else if bytes.HasPrefix(message, []byte("private_message:")) { - // Handle private message event + c.send <- []byte("leave_room_success:" + roomName) + + case bytes.HasPrefix(message, []byte("private_message:")): receiverAndMessage := bytes.SplitN(message[len("private_message:"):], []byte(" "), 2) receiver := string(receiverAndMessage[0]) privateMessage := receiverAndMessage[1] c.hub.handlePrivateMessage(receiver, privateMessage) - } else { - // Handle global messages or unhandled events - c.hub.broadcast <- message - } - } -} - -func (h *Hub) handleRoomBroadcast(roomName string, message []byte) { - if room, ok := h.rooms[roomName]; ok { - for client := range room { - select { - case client.send <- message: - default: - close(client.send) - delete(room, client) - } - } - } -} - -func (h *Hub) handleLeaveRoom(client *Client, roomName string) { - if room, ok := h.rooms[roomName]; ok { - if _, exists := room[client]; exists { - delete(room, client) - if len(room) == 0 { - delete(h.rooms, roomName) - } - } - } -} -func (h *Hub) handlePrivateMessage(receiverName string, message []byte) { - for client := range h.clients { - if client.name == receiverName { - select { - case client.send <- message: - default: - close(client.send) - delete(h.clients, client) - } + default: + c.hub.broadcast <- message } } } -// writePump pumps messages from the hub to the websocket connection. -// -// A goroutine running writePump is started for each connection. The -// application ensures that there is at most one writer to a connection by -// executing all writes from this goroutine. +// writePump sends messages to the client. func (c *Client) writePump() { ticker := time.NewTicker(pingPeriod) defer func() { @@ -157,7 +110,6 @@ func (c *Client) writePump() { case message, ok := <-c.send: c.conn.SetWriteDeadline(time.Now().Add(writeWait)) if !ok { - // The hub closed the channel. c.conn.WriteMessage(websocket.CloseMessage, []byte{}) return } @@ -168,7 +120,7 @@ func (c *Client) writePump() { } w.Write(message) - // Add queued chat messages to the current websocket message. + // Add queued messages to current WebSocket message n := len(c.send) for i := 0; i < n; i++ { w.Write(newline) @@ -178,6 +130,7 @@ func (c *Client) writePump() { if err := w.Close(); err != nil { return } + case <-ticker.C: c.conn.SetWriteDeadline(time.Now().Add(writeWait)) if err := c.conn.WriteMessage(websocket.PingMessage, nil); err != nil { @@ -187,126 +140,139 @@ func (c *Client) writePump() { } } -// serveWs handles websocket requests from the peer. -func serveWs(hub *Hub, w http.ResponseWriter, r *http.Request) { - conn, err := upgrader.Upgrade(w, r, nil) - if err != nil { - log.Println(err) - return - } - client := &Client{hub: hub, conn: conn, send: make(chan []byte, 256), name: "root"} - client.hub.register <- client - - // Allow collection of memory referenced by the caller by doing all work in - // new goroutines. - go client.writePump() - go client.readPump() -} - -// Hub maintains the set of active clients and broadcasts messages to the -// clients. +// Hub manages clients and rooms. type Hub struct { - // Registered clients. - clients map[*Client]bool - - // Inbound messages from the clients. - broadcast chan []byte - - // Register requests from the clients. - register chan *Client - - // Unregister requests from clients. + clients map[string]*Client // Track clients by ID for reconnection + broadcast chan []byte + register chan *Client unregister chan *Client - - // Maps room names to clients - rooms map[string]map[*Client]bool + rooms map[string]map[*Client]bool } +// Create a new room. func (h *Hub) createRoom(name string) { if _, exists := h.rooms[name]; !exists { h.rooms[name] = make(map[*Client]bool) } } +// Join a room. func (h *Hub) joinRoom(client *Client, room string) { if _, exists := h.rooms[room]; exists { h.rooms[room][client] = true } } -func (h *Hub) LeaveRoom(client *Client, room string) { - if _, exists := h.rooms[room]; exists { - delete(h.rooms[room], client) +// Leave a room. +func (h *Hub) handleLeaveRoom(client *Client, room string) { + if roomClients, ok := h.rooms[room]; ok { + delete(roomClients, client) + if len(roomClients) == 0 { + delete(h.rooms, room) + } } } -func (h *Hub) BroadcastToRoom(room string, message []byte) { - if clients, exists := h.rooms[room]; exists { +// Broadcast message to a room. +func (h *Hub) handleRoomBroadcast(roomName string, message []byte) { + if clients, ok := h.rooms[roomName]; ok { for client := range clients { client.send <- message } } } -func (c *Client) Emit(event string, data []byte) { - message := append([]byte(event+": "), data...) - c.send <- message +// Handle private message. +func (h *Hub) handlePrivateMessage(receiverName string, message []byte) { + for _, client := range h.clients { + if client.name == receiverName { + client.send <- message + } + } } +// Handle join room. func (h *Hub) HandleJoinRoom(client *Client, roomName string) { h.createRoom(roomName) h.joinRoom(client, roomName) } -func newHub() *Hub { - return &Hub{ - broadcast: make(chan []byte), - register: make(chan *Client), - unregister: make(chan *Client), - clients: make(map[*Client]bool), - rooms: make(map[string]map[*Client]bool), - } -} - +// Run starts the Hub. func (h *Hub) Run() { for { select { case client := <-h.register: - h.clients[client] = true + h.clients[client.id] = client case client := <-h.unregister: - if _, ok := h.clients[client]; ok { - delete(h.clients, client) + if _, ok := h.clients[client.id]; ok { + delete(h.clients, client.id) close(client.send) } case message := <-h.broadcast: - for client := range h.clients { - select { - case client.send <- message: - default: - close(client.send) - delete(h.clients, client) - } + for _, client := range h.clients { + client.send <- message } } } } +// Serve WebSocket connection and handle reconnections. +func serveWs(hub *Hub, w http.ResponseWriter, r *http.Request) { + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + log.Println(err) + return + } + + clientID := r.URL.Query().Get("client_id") + var client *Client + if clientID != "" && hub.clients[clientID] != nil { + // Reconnect existing client + client = hub.clients[clientID] + client.conn = conn + } else { + // New client connection + clientID = uuid.NewString() + client = &Client{ + hub: hub, + conn: conn, + send: make(chan []byte, 256), + id: clientID, + name: "root", + } + } + + client.hub.register <- client + go client.writePump() + go client.readPump() +} + +// WebSocketServer manages the WebSocket server. type WebSocketServer struct{} +// NewWebSocketServer creates a new server. func NewWebSocketServer() *WebSocketServer { return &WebSocketServer{} } +// NewWsServer starts the WebSocket server. func (wss *WebSocketServer) NewWsServer(addr string) { - var _addr = flag.String("addr", ":8080", "http service address") + var _addr = flag.String("addr", addr, "http service address") flag.Parse() - hub := newHub() + hub := &Hub{ + broadcast: make(chan []byte), + register: make(chan *Client), + unregister: make(chan *Client), + clients: make(map[string]*Client), + rooms: make(map[string]map[*Client]bool), + } go hub.Run() + http.HandleFunc("/ws", func(w http.ResponseWriter, r *http.Request) { serveWs(hub, w, r) }) err := http.ListenAndServe(*_addr, nil) if err != nil { - log.Fatal("WSS ListenAndServe: ", err) + log.Fatal("WebSocket server error:", err) } }