From 3c8a51819be6b262b6411c7f23445132dafc02d1 Mon Sep 17 00:00:00 2001 From: Muzzammil Shahid Date: Fri, 27 Dec 2024 18:46:45 +0500 Subject: [PATCH 1/2] Implement websocket keepalive --- acceptor.go | 13 +++-- acceptor_test.go | 2 +- client.go | 21 ++++++-- cmd/xconn/main.go | 2 +- helpers.go | 18 ------- joiner.go | 29 ++++++----- joiner_test.go | 4 +- peer.go | 123 +++++++++++++++++++++++++++++++++++++--------- server.go | 27 +++++++--- types.go | 27 ++++++++-- 10 files changed, 187 insertions(+), 79 deletions(-) diff --git a/acceptor.go b/acceptor.go index 3c13547..1cdaba4 100644 --- a/acceptor.go +++ b/acceptor.go @@ -69,8 +69,7 @@ func (w *WebSocketAcceptor) Spec(subProtocol string) (serializers.Serializer, er return serializer, nil } -func (w *WebSocketAcceptor) Accept(conn net.Conn) (BaseSession, error) { - config := DefaultWebSocketServerConfig() +func (w *WebSocketAcceptor) Accept(conn net.Conn, config *WebSocketServerConfig) (BaseSession, error) { config.SubProtocols = w.protocols() peer, err := UpgradeWebSocket(conn, config) if err != nil { @@ -161,7 +160,15 @@ func UpgradeWebSocket(conn net.Conn, config *WebSocketServerConfig) (Peer, error } isBinary := hs.Protocol != JsonWebsocketProtocol - peer, err := NewWebSocketPeer(conn, hs.Protocol, isBinary, true) + + peerConfig := WSPeerConfig{ + Protocol: hs.Protocol, + Binary: isBinary, + Server: true, + KeepAliveInterval: config.KeepAliveInterval, + KeepAliveTimeout: config.KeepAliveTimeout, + } + peer, err := NewWebSocketPeer(conn, peerConfig) if err != nil { return nil, fmt.Errorf("failed to init reader/writer: %w", err) } diff --git a/acceptor_test.go b/acceptor_test.go index bf7b556..ea6da98 100644 --- a/acceptor_test.go +++ b/acceptor_test.go @@ -25,7 +25,7 @@ func TestAccept(t *testing.T) { require.NotNil(t, conn) acceptor := xconn.WebSocketAcceptor{} - session, err := acceptor.Accept(conn) + session, err := acceptor.Accept(conn, xconn.DefaultWebSocketServerConfig()) require.NoError(t, err) require.NotNil(t, session) diff --git a/client.go b/client.go index 5970054..d3d27f2 100644 --- a/client.go +++ b/client.go @@ -13,7 +13,9 @@ type Client struct { SerializerSpec WSSerializerSpec NetDial func(ctx context.Context, network, addr string) (net.Conn, error) - DialTimeout time.Duration + DialTimeout time.Duration + KeepAliveInterval time.Duration + KeepAliveTimeout time.Duration } func (c *Client) Connect(ctx context.Context, url string, realm string) (*Session, error) { @@ -24,11 +26,17 @@ func (c *Client) Connect(ctx context.Context, url string, realm string) (*Sessio joiner := &WebSocketJoiner{ Authenticator: c.Authenticator, SerializerSpec: c.SerializerSpec, - DialTimeout: c.DialTimeout, - NetDial: c.NetDial, } - base, err := joiner.Join(ctx, url, realm) + dialerConfig := &WSDialerConfig{ + SubProtocol: c.SerializerSpec.SubProtocol(), + DialTimeout: c.DialTimeout, + NetDial: c.NetDial, + KeepAliveInterval: c.KeepAliveInterval, + KeepAliveTimeout: c.KeepAliveTimeout, + } + + base, err := joiner.Join(ctx, url, realm, dialerConfig) if err != nil { return nil, err } @@ -38,7 +46,10 @@ func (c *Client) Connect(ctx context.Context, url string, realm string) (*Sessio func Connect(ctx context.Context, url string, realm string) (*Session, error) { joiner := &WebSocketJoiner{} - base, err := joiner.Join(ctx, url, realm) + dialerConfig := &WSDialerConfig{ + SubProtocol: JsonWebsocketProtocol, + } + base, err := joiner.Join(ctx, url, realm, dialerConfig) if err != nil { return nil, err } diff --git a/cmd/xconn/main.go b/cmd/xconn/main.go index 1f41dd0..5ba5eba 100644 --- a/cmd/xconn/main.go +++ b/cmd/xconn/main.go @@ -116,7 +116,7 @@ func Run(args []string) error { throttle = internal.NewThrottle(transport.RateLimit.Rate, time.Duration(transport.RateLimit.Interval)*time.Second, strategy) } - server := xconn.NewServer(router, authenticator, throttle) + server := xconn.NewServer(router, authenticator, &xconn.ServerConfig{Throttle: throttle}) if slices.Contains(transport.Serializers, "protobuf") { if err := server.RegisterSpec(xconn.ProtobufSerializerSpec); err != nil { return err diff --git a/helpers.go b/helpers.go index 4911adc..b524db1 100644 --- a/helpers.go +++ b/helpers.go @@ -3,28 +3,10 @@ package xconn import ( "fmt" - "github.com/gobwas/ws/wsutil" - "github.com/xconnio/wampproto-go/messages" "github.com/xconnio/wampproto-go/serializers" ) -func ClientSideWSReaderWriter(binary bool) (ReaderFunc, WriterFunc, error) { - if !binary { - return wsutil.ReadServerText, wsutil.WriteClientText, nil - } - - return wsutil.ReadServerBinary, wsutil.WriteClientBinary, nil -} - -func ServerSideWSReaderWriter(binary bool) (ReaderFunc, WriterFunc, error) { - if !binary { - return wsutil.ReadClientText, wsutil.WriteServerText, nil - } - - return wsutil.ReadClientBinary, wsutil.WriteServerBinary, nil -} - func ReadMessage(peer Peer, serializer serializers.Serializer) (messages.Message, error) { payload, err := peer.Read() if err != nil { diff --git a/joiner.go b/joiner.go index 8b57771..6a443e8 100644 --- a/joiner.go +++ b/joiner.go @@ -17,12 +17,9 @@ import ( type WebSocketJoiner struct { SerializerSpec WSSerializerSpec Authenticator auth.ClientAuthenticator - NetDial func(ctx context.Context, network, addr string) (net.Conn, error) - - DialTimeout time.Duration } -func (w *WebSocketJoiner) Join(ctx context.Context, url, realm string) (BaseSession, error) { +func (w *WebSocketJoiner) Join(ctx context.Context, url, realm string, config *WSDialerConfig) (BaseSession, error) { parsedURL, err := netURL.Parse(url) if err != nil { return nil, err @@ -36,17 +33,11 @@ func (w *WebSocketJoiner) Join(ctx context.Context, url, realm string) (BaseSess w.Authenticator = auth.NewAnonymousAuthenticator("", nil) } - if w.DialTimeout == 0 { - w.DialTimeout = time.Second * 10 - } - - dialConfig := WSDialerConfig{ - SubProtocol: w.SerializerSpec.SubProtocol(), - DialTimeout: w.DialTimeout, - NetDial: w.NetDial, + if config.DialTimeout == 0 { + config.DialTimeout = time.Second * 10 } - peer, err := DialWebSocket(ctx, parsedURL, dialConfig) + peer, err := DialWebSocket(ctx, parsedURL, config) if err != nil { return nil, err } @@ -54,7 +45,7 @@ func (w *WebSocketJoiner) Join(ctx context.Context, url, realm string) (BaseSess return Join(peer, realm, w.SerializerSpec.Serializer(), w.Authenticator) } -func DialWebSocket(ctx context.Context, url *netURL.URL, config WSDialerConfig) (Peer, error) { +func DialWebSocket(ctx context.Context, url *netURL.URL, config *WSDialerConfig) (Peer, error) { wsDialer := ws.Dialer{ Protocols: []string{config.SubProtocol}, } @@ -83,7 +74,15 @@ func DialWebSocket(ctx context.Context, url *netURL.URL, config WSDialerConfig) } isBinary := config.SubProtocol != JsonWebsocketProtocol - return NewWebSocketPeer(conn, config.SubProtocol, isBinary, false) + + peerConfig := WSPeerConfig{ + Protocol: config.SubProtocol, + Binary: isBinary, + Server: false, + KeepAliveInterval: config.KeepAliveInterval, + KeepAliveTimeout: config.KeepAliveTimeout, + } + return NewWebSocketPeer(conn, peerConfig) } func Join(cl Peer, realm string, serializer serializers.Serializer, diff --git a/joiner_test.go b/joiner_test.go index a6abc59..51c3468 100644 --- a/joiner_test.go +++ b/joiner_test.go @@ -37,7 +37,9 @@ func TestJoin(t *testing.T) { address := fmt.Sprintf("ws://%s/ws", listener.Addr().String()) var joiner xconn.WebSocketJoiner - base, err := joiner.Join(context.Background(), address, "realm1") + base, err := joiner.Join(context.Background(), address, "realm1", &xconn.WSDialerConfig{ + SubProtocol: xconn.JsonWebsocketProtocol, + }) require.NoError(t, err) require.NotNil(t, base) diff --git a/peer.go b/peer.go index b63ee9e..379dcf6 100644 --- a/peer.go +++ b/peer.go @@ -1,8 +1,16 @@ package xconn import ( + "errors" + "fmt" + "io" + "log" "net" "sync" + "time" + + "github.com/gobwas/ws" + "github.com/gobwas/ws/wsutil" "github.com/xconnio/wampproto-go/messages" "github.com/xconnio/wampproto-go/serializers" @@ -88,48 +96,115 @@ func (b *baseSession) Close() error { return b.client.NetConn().Close() } -func NewWebSocketPeer(conn net.Conn, protocol string, binary, server bool) (Peer, error) { - var wsReader ReaderFunc - var wsWriter WriterFunc - var err error - if server { - wsReader, wsWriter, err = ServerSideWSReaderWriter(binary) - } else { - wsReader, wsWriter, err = ClientSideWSReaderWriter(binary) +func NewWebSocketPeer(conn net.Conn, peerConfig WSPeerConfig) (Peer, error) { + peer := &WebSocketPeer{ + transportType: TransportWebSocket, + protocol: peerConfig.Protocol, + conn: conn, + pingCh: make(chan struct{}, 1), + binary: peerConfig.Binary, + server: peerConfig.Server, } - if err != nil { - return nil, err + if peerConfig.KeepAliveInterval != 0 { + // Start ping-pong handling + go peer.startPinger(peerConfig.KeepAliveInterval, peerConfig.KeepAliveTimeout) } - return &WebSocketPeer{ - transportType: TransportWebSocket, - protocol: protocol, - conn: conn, - wsReader: wsReader, - wsWriter: wsWriter, - }, nil + return peer, nil } type WebSocketPeer struct { transportType TransportType protocol string conn net.Conn - wsReader ReaderFunc - wsWriter WriterFunc - wm sync.Mutex + pingCh chan struct{} + binary bool + server bool + + sync.Mutex +} + +func (c *WebSocketPeer) startPinger(keepaliveInterval time.Duration, keepaliveTimeout time.Duration) { + ticker := time.NewTicker(keepaliveInterval) + defer ticker.Stop() + + if keepaliveTimeout == 0 { + keepaliveTimeout = 10 * time.Second + } + for { + <-ticker.C + // Send a ping + err := c.writeOpFunc(c.conn, ws.OpPing, []byte("ping")) + if err != nil { + log.Printf("failed to send ping: %v\n", err) + _ = c.conn.Close() + return + } + + select { + case <-c.pingCh: + case <-time.After(keepaliveTimeout): + log.Println("ping timeout, closing connection") + _ = c.conn.Close() + return + } + } +} + +func (c *WebSocketPeer) peerState() ws.State { + if c.server { + return ws.StateServerSide + } + return ws.StateClientSide +} + +func (c *WebSocketPeer) writeOpFunc(w io.Writer, op ws.OpCode, p []byte) error { + if c.server { + return wsutil.WriteServerMessage(w, op, p) + } + + return wsutil.WriteClientMessage(w, op, p) } func (c *WebSocketPeer) Read() ([]byte, error) { - return c.wsReader(c.conn) + header, reader, err := wsutil.NextReader(c.conn, c.peerState()) + if err != nil { + return nil, err + } + + payload := make([]byte, header.Length) + _, err = reader.Read(payload) + if err != nil && !errors.Is(err, io.EOF) { + return nil, err + } + + switch header.OpCode { + case ws.OpText, ws.OpBinary: + return payload, nil + case ws.OpPing: + if err = c.writeOpFunc(c.conn, ws.OpPong, payload); err != nil { + return nil, fmt.Errorf("failed to send pong: %w", err) + } + case ws.OpPong: + c.pingCh <- struct{}{} + case ws.OpClose: + _ = c.conn.Close() + return nil, fmt.Errorf("connection closed") + } + + return c.Read() } func (c *WebSocketPeer) Write(bytes []byte) error { - c.wm.Lock() - defer c.wm.Unlock() + c.Lock() + defer c.Unlock() + if c.binary { + return c.writeOpFunc(c.conn, ws.OpBinary, bytes) + } - return c.wsWriter(c.conn, bytes) + return c.writeOpFunc(c.conn, ws.OpText, bytes) } func (c *WebSocketPeer) Type() TransportType { diff --git a/server.go b/server.go index c441fe8..c848cd3 100644 --- a/server.go +++ b/server.go @@ -6,6 +6,7 @@ import ( "log" "net" "os" + "time" "github.com/projectdiscovery/ratelimit" @@ -14,21 +15,30 @@ import ( ) type Server struct { - router *Router - acceptor *WebSocketAcceptor - throttle *internal.Throttle + router *Router + acceptor *WebSocketAcceptor + throttle *internal.Throttle + keepAliveInterval time.Duration + keepAliveTimeout time.Duration } -func NewServer(router *Router, authenticator auth.ServerAuthenticator, throttle *internal.Throttle) *Server { +func NewServer(router *Router, authenticator auth.ServerAuthenticator, config *ServerConfig) *Server { acceptor := &WebSocketAcceptor{ Authenticator: authenticator, } - return &Server{ + server := &Server{ router: router, acceptor: acceptor, - throttle: throttle, } + + if config != nil { + server.throttle = config.Throttle + server.keepAliveInterval = config.KeepAliveInterval + server.keepAliveTimeout = config.KeepAliveTimeout + } + + return server } func (s *Server) RegisterSpec(spec WSSerializerSpec) error { @@ -51,7 +61,10 @@ func (s *Server) Start(host string, port int) (io.Closer, error) { } func (s *Server) HandleClient(conn net.Conn) { - base, err := s.acceptor.Accept(conn) + config := DefaultWebSocketServerConfig() + config.KeepAliveInterval = s.keepAliveInterval + config.KeepAliveTimeout = s.keepAliveTimeout + base, err := s.acceptor.Accept(conn, config) if err != nil { return } diff --git a/types.go b/types.go index 2e083c0..e4de94f 100644 --- a/types.go +++ b/types.go @@ -11,6 +11,7 @@ import ( "github.com/xconnio/wampproto-go/messages" "github.com/xconnio/wampproto-go/serializers" wampprotobuf "github.com/xconnio/wampproto-protobuf/go" + "github.com/xconnio/xconn-go/internal" ) type ( @@ -58,13 +59,31 @@ type Peer interface { } type WSDialerConfig struct { - SubProtocol string - DialTimeout time.Duration - NetDial func(ctx context.Context, network, addr string) (net.Conn, error) + SubProtocol string + DialTimeout time.Duration + NetDial func(ctx context.Context, network, addr string) (net.Conn, error) + KeepAliveInterval time.Duration + KeepAliveTimeout time.Duration } type WebSocketServerConfig struct { - SubProtocols []string + SubProtocols []string + KeepAliveInterval time.Duration + KeepAliveTimeout time.Duration +} + +type WSPeerConfig struct { + Protocol string + Binary bool + Server bool + KeepAliveInterval time.Duration + KeepAliveTimeout time.Duration +} + +type ServerConfig struct { + Throttle *internal.Throttle + KeepAliveInterval time.Duration + KeepAliveTimeout time.Duration } func DefaultWebSocketServerConfig() *WebSocketServerConfig { From 4b7ec503bce9d28d246e3f1becc3b269688e535b Mon Sep 17 00:00:00 2001 From: Muzzammil Shahid Date: Fri, 27 Dec 2024 19:01:12 +0500 Subject: [PATCH 2/2] Implement websocket keepalive --- acceptor.go | 3 +++ acceptor_test.go | 2 +- joiner.go | 7 +++---- peer.go | 7 ++++++- 4 files changed, 13 insertions(+), 6 deletions(-) diff --git a/acceptor.go b/acceptor.go index 1cdaba4..98db720 100644 --- a/acceptor.go +++ b/acceptor.go @@ -70,6 +70,9 @@ func (w *WebSocketAcceptor) Spec(subProtocol string) (serializers.Serializer, er } func (w *WebSocketAcceptor) Accept(conn net.Conn, config *WebSocketServerConfig) (BaseSession, error) { + if config == nil { + config = DefaultWebSocketServerConfig() + } config.SubProtocols = w.protocols() peer, err := UpgradeWebSocket(conn, config) if err != nil { diff --git a/acceptor_test.go b/acceptor_test.go index ea6da98..68a7b52 100644 --- a/acceptor_test.go +++ b/acceptor_test.go @@ -25,7 +25,7 @@ func TestAccept(t *testing.T) { require.NotNil(t, conn) acceptor := xconn.WebSocketAcceptor{} - session, err := acceptor.Accept(conn, xconn.DefaultWebSocketServerConfig()) + session, err := acceptor.Accept(conn, nil) require.NoError(t, err) require.NotNil(t, session) diff --git a/joiner.go b/joiner.go index 6a443e8..b97b3a9 100644 --- a/joiner.go +++ b/joiner.go @@ -33,10 +33,6 @@ func (w *WebSocketJoiner) Join(ctx context.Context, url, realm string, config *W w.Authenticator = auth.NewAnonymousAuthenticator("", nil) } - if config.DialTimeout == 0 { - config.DialTimeout = time.Second * 10 - } - peer, err := DialWebSocket(ctx, parsedURL, config) if err != nil { return nil, err @@ -46,6 +42,9 @@ func (w *WebSocketJoiner) Join(ctx context.Context, url, realm string, config *W } func DialWebSocket(ctx context.Context, url *netURL.URL, config *WSDialerConfig) (Peer, error) { + if config == nil { + config = &WSDialerConfig{SubProtocol: JsonWebsocketProtocol} + } wsDialer := ws.Dialer{ Protocols: []string{config.SubProtocol}, } diff --git a/peer.go b/peer.go index 379dcf6..53e796d 100644 --- a/peer.go +++ b/peer.go @@ -1,6 +1,7 @@ package xconn import ( + "crypto/rand" "errors" "fmt" "io" @@ -136,8 +137,12 @@ func (c *WebSocketPeer) startPinger(keepaliveInterval time.Duration, keepaliveTi for { <-ticker.C // Send a ping - err := c.writeOpFunc(c.conn, ws.OpPing, []byte("ping")) + randomBytes := make([]byte, 4) + _, err := rand.Read(randomBytes) if err != nil { + fmt.Println("failed to generate random bytes:", err) + } + if err := c.writeOpFunc(c.conn, ws.OpPing, randomBytes); err != nil { log.Printf("failed to send ping: %v\n", err) _ = c.conn.Close() return