Skip to content

Commit

Permalink
Merge pull request #64 from muzzammilshahid/ping-pong
Browse files Browse the repository at this point in the history
Implement websocket keepalive
  • Loading branch information
muzzammilshahid authored Dec 27, 2024
2 parents e43d826 + 4b7ec50 commit c2b4b5e
Show file tree
Hide file tree
Showing 10 changed files with 196 additions and 81 deletions.
16 changes: 13 additions & 3 deletions acceptor.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,10 @@ func (w *WebSocketAcceptor) Spec(subProtocol string) (serializers.Serializer, er
return serializer, nil
}

func (w *WebSocketAcceptor) Accept(conn net.Conn) (BaseSession, error) {
config := DefaultWebSocketServerConfig()
func (w *WebSocketAcceptor) Accept(conn net.Conn, config *WebSocketServerConfig) (BaseSession, error) {
if config == nil {
config = DefaultWebSocketServerConfig()
}
config.SubProtocols = w.protocols()
peer, err := UpgradeWebSocket(conn, config)
if err != nil {
Expand Down Expand Up @@ -161,7 +163,15 @@ func UpgradeWebSocket(conn net.Conn, config *WebSocketServerConfig) (Peer, error
}

isBinary := hs.Protocol != JsonWebsocketProtocol
peer, err := NewWebSocketPeer(conn, hs.Protocol, isBinary, true)

peerConfig := WSPeerConfig{
Protocol: hs.Protocol,
Binary: isBinary,
Server: true,
KeepAliveInterval: config.KeepAliveInterval,
KeepAliveTimeout: config.KeepAliveTimeout,
}
peer, err := NewWebSocketPeer(conn, peerConfig)
if err != nil {
return nil, fmt.Errorf("failed to init reader/writer: %w", err)
}
Expand Down
2 changes: 1 addition & 1 deletion acceptor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ func TestAccept(t *testing.T) {
require.NotNil(t, conn)

acceptor := xconn.WebSocketAcceptor{}
session, err := acceptor.Accept(conn)
session, err := acceptor.Accept(conn, nil)
require.NoError(t, err)
require.NotNil(t, session)

Expand Down
21 changes: 16 additions & 5 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@ type Client struct {
SerializerSpec WSSerializerSpec
NetDial func(ctx context.Context, network, addr string) (net.Conn, error)

DialTimeout time.Duration
DialTimeout time.Duration
KeepAliveInterval time.Duration
KeepAliveTimeout time.Duration
}

func (c *Client) Connect(ctx context.Context, url string, realm string) (*Session, error) {
Expand All @@ -24,11 +26,17 @@ func (c *Client) Connect(ctx context.Context, url string, realm string) (*Sessio
joiner := &WebSocketJoiner{
Authenticator: c.Authenticator,
SerializerSpec: c.SerializerSpec,
DialTimeout: c.DialTimeout,
NetDial: c.NetDial,
}

base, err := joiner.Join(ctx, url, realm)
dialerConfig := &WSDialerConfig{
SubProtocol: c.SerializerSpec.SubProtocol(),
DialTimeout: c.DialTimeout,
NetDial: c.NetDial,
KeepAliveInterval: c.KeepAliveInterval,
KeepAliveTimeout: c.KeepAliveTimeout,
}

base, err := joiner.Join(ctx, url, realm, dialerConfig)
if err != nil {
return nil, err
}
Expand All @@ -38,7 +46,10 @@ func (c *Client) Connect(ctx context.Context, url string, realm string) (*Sessio

func Connect(ctx context.Context, url string, realm string) (*Session, error) {
joiner := &WebSocketJoiner{}
base, err := joiner.Join(ctx, url, realm)
dialerConfig := &WSDialerConfig{
SubProtocol: JsonWebsocketProtocol,
}
base, err := joiner.Join(ctx, url, realm, dialerConfig)
if err != nil {
return nil, err
}
Expand Down
2 changes: 1 addition & 1 deletion cmd/xconn/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ func Run(args []string) error {
throttle = internal.NewThrottle(transport.RateLimit.Rate,
time.Duration(transport.RateLimit.Interval)*time.Second, strategy)
}
server := xconn.NewServer(router, authenticator, throttle)
server := xconn.NewServer(router, authenticator, &xconn.ServerConfig{Throttle: throttle})
if slices.Contains(transport.Serializers, "protobuf") {
if err := server.RegisterSpec(xconn.ProtobufSerializerSpec); err != nil {
return err
Expand Down
18 changes: 0 additions & 18 deletions helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,28 +3,10 @@ package xconn
import (
"fmt"

"github.com/gobwas/ws/wsutil"

"github.com/xconnio/wampproto-go/messages"
"github.com/xconnio/wampproto-go/serializers"
)

func ClientSideWSReaderWriter(binary bool) (ReaderFunc, WriterFunc, error) {
if !binary {
return wsutil.ReadServerText, wsutil.WriteClientText, nil
}

return wsutil.ReadServerBinary, wsutil.WriteClientBinary, nil
}

func ServerSideWSReaderWriter(binary bool) (ReaderFunc, WriterFunc, error) {
if !binary {
return wsutil.ReadClientText, wsutil.WriteServerText, nil
}

return wsutil.ReadClientBinary, wsutil.WriteServerBinary, nil
}

func ReadMessage(peer Peer, serializer serializers.Serializer) (messages.Message, error) {
payload, err := peer.Read()
if err != nil {
Expand Down
32 changes: 15 additions & 17 deletions joiner.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,9 @@ import (
type WebSocketJoiner struct {
SerializerSpec WSSerializerSpec
Authenticator auth.ClientAuthenticator
NetDial func(ctx context.Context, network, addr string) (net.Conn, error)

DialTimeout time.Duration
}

func (w *WebSocketJoiner) Join(ctx context.Context, url, realm string) (BaseSession, error) {
func (w *WebSocketJoiner) Join(ctx context.Context, url, realm string, config *WSDialerConfig) (BaseSession, error) {
parsedURL, err := netURL.Parse(url)
if err != nil {
return nil, err
Expand All @@ -36,25 +33,18 @@ func (w *WebSocketJoiner) Join(ctx context.Context, url, realm string) (BaseSess
w.Authenticator = auth.NewAnonymousAuthenticator("", nil)
}

if w.DialTimeout == 0 {
w.DialTimeout = time.Second * 10
}

dialConfig := WSDialerConfig{
SubProtocol: w.SerializerSpec.SubProtocol(),
DialTimeout: w.DialTimeout,
NetDial: w.NetDial,
}

peer, err := DialWebSocket(ctx, parsedURL, dialConfig)
peer, err := DialWebSocket(ctx, parsedURL, config)
if err != nil {
return nil, err
}

return Join(peer, realm, w.SerializerSpec.Serializer(), w.Authenticator)
}

func DialWebSocket(ctx context.Context, url *netURL.URL, config WSDialerConfig) (Peer, error) {
func DialWebSocket(ctx context.Context, url *netURL.URL, config *WSDialerConfig) (Peer, error) {
if config == nil {
config = &WSDialerConfig{SubProtocol: JsonWebsocketProtocol}
}
wsDialer := ws.Dialer{
Protocols: []string{config.SubProtocol},
}
Expand Down Expand Up @@ -83,7 +73,15 @@ func DialWebSocket(ctx context.Context, url *netURL.URL, config WSDialerConfig)
}

isBinary := config.SubProtocol != JsonWebsocketProtocol
return NewWebSocketPeer(conn, config.SubProtocol, isBinary, false)

peerConfig := WSPeerConfig{
Protocol: config.SubProtocol,
Binary: isBinary,
Server: false,
KeepAliveInterval: config.KeepAliveInterval,
KeepAliveTimeout: config.KeepAliveTimeout,
}
return NewWebSocketPeer(conn, peerConfig)
}

func Join(cl Peer, realm string, serializer serializers.Serializer,
Expand Down
4 changes: 3 additions & 1 deletion joiner_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,9 @@ func TestJoin(t *testing.T) {
address := fmt.Sprintf("ws://%s/ws", listener.Addr().String())

var joiner xconn.WebSocketJoiner
base, err := joiner.Join(context.Background(), address, "realm1")
base, err := joiner.Join(context.Background(), address, "realm1", &xconn.WSDialerConfig{
SubProtocol: xconn.JsonWebsocketProtocol,
})
require.NoError(t, err)
require.NotNil(t, base)

Expand Down
128 changes: 104 additions & 24 deletions peer.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,17 @@
package xconn

import (
"crypto/rand"
"errors"
"fmt"
"io"
"log"
"net"
"sync"
"time"

"github.com/gobwas/ws"
"github.com/gobwas/ws/wsutil"

"github.com/xconnio/wampproto-go/messages"
"github.com/xconnio/wampproto-go/serializers"
Expand Down Expand Up @@ -88,48 +97,119 @@ func (b *baseSession) Close() error {
return b.client.NetConn().Close()
}

func NewWebSocketPeer(conn net.Conn, protocol string, binary, server bool) (Peer, error) {
var wsReader ReaderFunc
var wsWriter WriterFunc
var err error
if server {
wsReader, wsWriter, err = ServerSideWSReaderWriter(binary)
} else {
wsReader, wsWriter, err = ClientSideWSReaderWriter(binary)
func NewWebSocketPeer(conn net.Conn, peerConfig WSPeerConfig) (Peer, error) {
peer := &WebSocketPeer{
transportType: TransportWebSocket,
protocol: peerConfig.Protocol,
conn: conn,
pingCh: make(chan struct{}, 1),
binary: peerConfig.Binary,
server: peerConfig.Server,
}

if err != nil {
return nil, err
if peerConfig.KeepAliveInterval != 0 {
// Start ping-pong handling
go peer.startPinger(peerConfig.KeepAliveInterval, peerConfig.KeepAliveTimeout)
}

return &WebSocketPeer{
transportType: TransportWebSocket,
protocol: protocol,
conn: conn,
wsReader: wsReader,
wsWriter: wsWriter,
}, nil
return peer, nil
}

type WebSocketPeer struct {
transportType TransportType
protocol string
conn net.Conn
wsReader ReaderFunc
wsWriter WriterFunc

wm sync.Mutex
pingCh chan struct{}
binary bool
server bool

sync.Mutex
}

func (c *WebSocketPeer) startPinger(keepaliveInterval time.Duration, keepaliveTimeout time.Duration) {
ticker := time.NewTicker(keepaliveInterval)
defer ticker.Stop()

if keepaliveTimeout == 0 {
keepaliveTimeout = 10 * time.Second
}
for {
<-ticker.C
// Send a ping
randomBytes := make([]byte, 4)
_, err := rand.Read(randomBytes)
if err != nil {
fmt.Println("failed to generate random bytes:", err)
}
if err := c.writeOpFunc(c.conn, ws.OpPing, randomBytes); err != nil {
log.Printf("failed to send ping: %v\n", err)
_ = c.conn.Close()
return
}

select {
case <-c.pingCh:
case <-time.After(keepaliveTimeout):
log.Println("ping timeout, closing connection")
_ = c.conn.Close()
return
}
}
}

func (c *WebSocketPeer) peerState() ws.State {
if c.server {
return ws.StateServerSide
}
return ws.StateClientSide
}

func (c *WebSocketPeer) writeOpFunc(w io.Writer, op ws.OpCode, p []byte) error {
if c.server {
return wsutil.WriteServerMessage(w, op, p)
}

return wsutil.WriteClientMessage(w, op, p)
}

func (c *WebSocketPeer) Read() ([]byte, error) {
return c.wsReader(c.conn)
header, reader, err := wsutil.NextReader(c.conn, c.peerState())
if err != nil {
return nil, err
}

payload := make([]byte, header.Length)
_, err = reader.Read(payload)
if err != nil && !errors.Is(err, io.EOF) {
return nil, err
}

switch header.OpCode {
case ws.OpText, ws.OpBinary:
return payload, nil
case ws.OpPing:
if err = c.writeOpFunc(c.conn, ws.OpPong, payload); err != nil {
return nil, fmt.Errorf("failed to send pong: %w", err)
}
case ws.OpPong:
c.pingCh <- struct{}{}
case ws.OpClose:
_ = c.conn.Close()
return nil, fmt.Errorf("connection closed")
}

return c.Read()
}

func (c *WebSocketPeer) Write(bytes []byte) error {
c.wm.Lock()
defer c.wm.Unlock()
c.Lock()
defer c.Unlock()
if c.binary {
return c.writeOpFunc(c.conn, ws.OpBinary, bytes)
}

return c.wsWriter(c.conn, bytes)
return c.writeOpFunc(c.conn, ws.OpText, bytes)
}

func (c *WebSocketPeer) Type() TransportType {
Expand Down
Loading

0 comments on commit c2b4b5e

Please sign in to comment.