Skip to content

Commit

Permalink
Implement websocket keepalive
Browse files Browse the repository at this point in the history
  • Loading branch information
muzzammilshahid committed Dec 27, 2024
1 parent e43d826 commit ba42f2f
Show file tree
Hide file tree
Showing 10 changed files with 177 additions and 67 deletions.
13 changes: 10 additions & 3 deletions acceptor.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,7 @@ func (w *WebSocketAcceptor) Spec(subProtocol string) (serializers.Serializer, er
return serializer, nil
}

func (w *WebSocketAcceptor) Accept(conn net.Conn) (BaseSession, error) {
config := DefaultWebSocketServerConfig()
func (w *WebSocketAcceptor) Accept(conn net.Conn, config *WebSocketServerConfig) (BaseSession, error) {
config.SubProtocols = w.protocols()
peer, err := UpgradeWebSocket(conn, config)
if err != nil {
Expand Down Expand Up @@ -161,7 +160,15 @@ func UpgradeWebSocket(conn net.Conn, config *WebSocketServerConfig) (Peer, error
}

isBinary := hs.Protocol != JsonWebsocketProtocol
peer, err := NewWebSocketPeer(conn, hs.Protocol, isBinary, true)

peerConfig := WSPeerConfig{
Protocol: hs.Protocol,
Binary: isBinary,
Server: true,
KeepAliveInterval: config.KeepAliveInterval,
KeepAliveTimeout: config.KeepAliveTimeout,
}
peer, err := NewWebSocketPeer(conn, peerConfig)
if err != nil {
return nil, fmt.Errorf("failed to init reader/writer: %w", err)
}
Expand Down
2 changes: 1 addition & 1 deletion acceptor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ func TestAccept(t *testing.T) {
require.NotNil(t, conn)

acceptor := xconn.WebSocketAcceptor{}
session, err := acceptor.Accept(conn)
session, err := acceptor.Accept(conn, xconn.DefaultWebSocketServerConfig())
require.NoError(t, err)
require.NotNil(t, session)

Expand Down
8 changes: 5 additions & 3 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@ type Client struct {
SerializerSpec WSSerializerSpec
NetDial func(ctx context.Context, network, addr string) (net.Conn, error)

DialTimeout time.Duration
DialTimeout time.Duration
KeepAliveInterval time.Duration
KeepAliveTimeout time.Duration
}

func (c *Client) Connect(ctx context.Context, url string, realm string) (*Session, error) {
Expand All @@ -28,7 +30,7 @@ func (c *Client) Connect(ctx context.Context, url string, realm string) (*Sessio
NetDial: c.NetDial,
}

base, err := joiner.Join(ctx, url, realm)
base, err := joiner.Join(ctx, url, realm, c.KeepAliveInterval, c.KeepAliveTimeout)
if err != nil {
return nil, err
}
Expand All @@ -38,7 +40,7 @@ func (c *Client) Connect(ctx context.Context, url string, realm string) (*Sessio

func Connect(ctx context.Context, url string, realm string) (*Session, error) {
joiner := &WebSocketJoiner{}
base, err := joiner.Join(ctx, url, realm)
base, err := joiner.Join(ctx, url, realm, time.Duration(0), time.Duration(0))
if err != nil {
return nil, err
}
Expand Down
2 changes: 1 addition & 1 deletion cmd/xconn/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ func Run(args []string) error {
throttle = internal.NewThrottle(transport.RateLimit.Rate,
time.Duration(transport.RateLimit.Interval)*time.Second, strategy)
}
server := xconn.NewServer(router, authenticator, throttle)
server := xconn.NewServer(router, authenticator, &xconn.ServerConfig{Throttle: throttle})
if slices.Contains(transport.Serializers, "protobuf") {
if err := server.RegisterSpec(xconn.ProtobufSerializerSpec); err != nil {
return err
Expand Down
18 changes: 0 additions & 18 deletions helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,28 +3,10 @@ package xconn
import (
"fmt"

"github.com/gobwas/ws/wsutil"

"github.com/xconnio/wampproto-go/messages"
"github.com/xconnio/wampproto-go/serializers"
)

func ClientSideWSReaderWriter(binary bool) (ReaderFunc, WriterFunc, error) {
if !binary {
return wsutil.ReadServerText, wsutil.WriteClientText, nil
}

return wsutil.ReadServerBinary, wsutil.WriteClientBinary, nil
}

func ServerSideWSReaderWriter(binary bool) (ReaderFunc, WriterFunc, error) {
if !binary {
return wsutil.ReadClientText, wsutil.WriteServerText, nil
}

return wsutil.ReadClientBinary, wsutil.WriteServerBinary, nil
}

func ReadMessage(peer Peer, serializer serializers.Serializer) (messages.Message, error) {
payload, err := peer.Read()
if err != nil {
Expand Down
21 changes: 16 additions & 5 deletions joiner.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@ type WebSocketJoiner struct {
DialTimeout time.Duration
}

func (w *WebSocketJoiner) Join(ctx context.Context, url, realm string) (BaseSession, error) {
func (w *WebSocketJoiner) Join(ctx context.Context, url, realm string, keepaliveInterval time.Duration,
keepaliveTimeout time.Duration) (BaseSession, error) {
parsedURL, err := netURL.Parse(url)
if err != nil {
return nil, err
Expand All @@ -41,9 +42,11 @@ func (w *WebSocketJoiner) Join(ctx context.Context, url, realm string) (BaseSess
}

dialConfig := WSDialerConfig{
SubProtocol: w.SerializerSpec.SubProtocol(),
DialTimeout: w.DialTimeout,
NetDial: w.NetDial,
SubProtocol: w.SerializerSpec.SubProtocol(),
DialTimeout: w.DialTimeout,
NetDial: w.NetDial,
KeepAliveInterval: keepaliveInterval,
KeepAliveTimeout: keepaliveTimeout,
}

peer, err := DialWebSocket(ctx, parsedURL, dialConfig)
Expand Down Expand Up @@ -83,7 +86,15 @@ func DialWebSocket(ctx context.Context, url *netURL.URL, config WSDialerConfig)
}

isBinary := config.SubProtocol != JsonWebsocketProtocol
return NewWebSocketPeer(conn, config.SubProtocol, isBinary, false)

peerConfig := WSPeerConfig{
Protocol: config.SubProtocol,
Binary: isBinary,
Server: false,
KeepAliveInterval: config.KeepAliveInterval,
KeepAliveTimeout: config.KeepAliveTimeout,
}
return NewWebSocketPeer(conn, peerConfig)
}

func Join(cl Peer, realm string, serializer serializers.Serializer,
Expand Down
3 changes: 2 additions & 1 deletion joiner_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"fmt"
"net"
"testing"
"time"

"github.com/gammazero/nexus/v3/router"
"github.com/gammazero/nexus/v3/wamp"
Expand Down Expand Up @@ -37,7 +38,7 @@ func TestJoin(t *testing.T) {
address := fmt.Sprintf("ws://%s/ws", listener.Addr().String())

var joiner xconn.WebSocketJoiner
base, err := joiner.Join(context.Background(), address, "realm1")
base, err := joiner.Join(context.Background(), address, "realm1", time.Duration(0), time.Duration(0))
require.NoError(t, err)
require.NotNil(t, base)

Expand Down
123 changes: 99 additions & 24 deletions peer.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,16 @@
package xconn

import (
"errors"
"fmt"
"io"
"log"
"net"
"sync"
"time"

"github.com/gobwas/ws"
"github.com/gobwas/ws/wsutil"

"github.com/xconnio/wampproto-go/messages"
"github.com/xconnio/wampproto-go/serializers"
Expand Down Expand Up @@ -88,48 +96,115 @@ func (b *baseSession) Close() error {
return b.client.NetConn().Close()
}

func NewWebSocketPeer(conn net.Conn, protocol string, binary, server bool) (Peer, error) {
var wsReader ReaderFunc
var wsWriter WriterFunc
var err error
if server {
wsReader, wsWriter, err = ServerSideWSReaderWriter(binary)
} else {
wsReader, wsWriter, err = ClientSideWSReaderWriter(binary)
func NewWebSocketPeer(conn net.Conn, peerConfig WSPeerConfig) (Peer, error) {
peer := &WebSocketPeer{
transportType: TransportWebSocket,
protocol: peerConfig.Protocol,
conn: conn,
pingCh: make(chan struct{}, 1),
binary: peerConfig.Binary,
server: peerConfig.Server,
}

if err != nil {
return nil, err
if peerConfig.KeepAliveInterval != 0 {
// Start ping-pong handling
go peer.startPinger(peerConfig.KeepAliveInterval, peerConfig.KeepAliveTimeout)
}

return &WebSocketPeer{
transportType: TransportWebSocket,
protocol: protocol,
conn: conn,
wsReader: wsReader,
wsWriter: wsWriter,
}, nil
return peer, nil
}

type WebSocketPeer struct {
transportType TransportType
protocol string
conn net.Conn
wsReader ReaderFunc
wsWriter WriterFunc

wm sync.Mutex
pingCh chan struct{}
binary bool
server bool

sync.Mutex
}

func (c *WebSocketPeer) startPinger(keepaliveInterval time.Duration, keepaliveTimeout time.Duration) {
ticker := time.NewTicker(keepaliveInterval)
defer ticker.Stop()

if keepaliveTimeout == 0 {
keepaliveTimeout = 5 * time.Second
}
for {
<-ticker.C
// Send a ping
err := c.writeOpFunc(c.conn, ws.OpPing, []byte("ping"))
if err != nil {
log.Printf("failed to send ping: %v\n", err)
_ = c.conn.Close()
return
}

select {
case <-c.pingCh:
case <-time.After(keepaliveTimeout):
log.Println("ping timeout, closing connection")
_ = c.conn.Close()
return
}
}
}

func (c *WebSocketPeer) peerState() ws.State {
if c.server {
return ws.StateServerSide
}
return ws.StateClientSide
}

func (c *WebSocketPeer) writeOpFunc(w io.Writer, op ws.OpCode, p []byte) error {
if c.server {
return wsutil.WriteServerMessage(w, op, p)
}

return wsutil.WriteClientMessage(w, op, p)
}

func (c *WebSocketPeer) Read() ([]byte, error) {
return c.wsReader(c.conn)
header, reader, err := wsutil.NextReader(c.conn, c.peerState())
if err != nil {
return nil, err
}

payload := make([]byte, header.Length)
_, err = reader.Read(payload)
if err != nil && !errors.Is(err, io.EOF) {
return nil, err
}

switch header.OpCode {
case ws.OpText, ws.OpBinary:
return payload, nil
case ws.OpPing:
if err = c.writeOpFunc(c.conn, ws.OpPong, payload); err != nil {
return nil, fmt.Errorf("failed to send pong: %w", err)
}
case ws.OpPong:
c.pingCh <- struct{}{}
case ws.OpClose:
_ = c.conn.Close()
return nil, fmt.Errorf("connection closed")
}

return c.Read()
}

func (c *WebSocketPeer) Write(bytes []byte) error {
c.wm.Lock()
defer c.wm.Unlock()
c.Lock()
defer c.Unlock()
if c.binary {
return c.writeOpFunc(c.conn, ws.OpBinary, bytes)
}

return c.wsWriter(c.conn, bytes)
return c.writeOpFunc(c.conn, ws.OpText, bytes)
}

func (c *WebSocketPeer) Type() TransportType {
Expand Down
27 changes: 20 additions & 7 deletions server.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"log"
"net"
"os"
"time"

"github.com/projectdiscovery/ratelimit"

Expand All @@ -14,21 +15,30 @@ import (
)

type Server struct {
router *Router
acceptor *WebSocketAcceptor
throttle *internal.Throttle
router *Router
acceptor *WebSocketAcceptor
throttle *internal.Throttle
keepAliveInterval time.Duration
keepAliveTimeout time.Duration
}

func NewServer(router *Router, authenticator auth.ServerAuthenticator, throttle *internal.Throttle) *Server {
func NewServer(router *Router, authenticator auth.ServerAuthenticator, config *ServerConfig) *Server {
acceptor := &WebSocketAcceptor{
Authenticator: authenticator,
}

return &Server{
server := &Server{
router: router,
acceptor: acceptor,
throttle: throttle,
}

if config != nil {
server.throttle = config.Throttle
server.keepAliveInterval = config.KeepAliveInterval
server.keepAliveTimeout = config.KeepAliveTimeout
}

return server
}

func (s *Server) RegisterSpec(spec WSSerializerSpec) error {
Expand All @@ -51,7 +61,10 @@ func (s *Server) Start(host string, port int) (io.Closer, error) {
}

func (s *Server) HandleClient(conn net.Conn) {
base, err := s.acceptor.Accept(conn)
config := DefaultWebSocketServerConfig()
config.KeepAliveInterval = s.keepAliveInterval
config.KeepAliveTimeout = s.keepAliveTimeout
base, err := s.acceptor.Accept(conn, config)
if err != nil {
return
}
Expand Down
Loading

0 comments on commit ba42f2f

Please sign in to comment.