Skip to content

Commit

Permalink
Implement websocket ping pong
Browse files Browse the repository at this point in the history
  • Loading branch information
muzzammilshahid committed Dec 26, 2024
1 parent e43d826 commit 3fe558a
Show file tree
Hide file tree
Showing 11 changed files with 145 additions and 65 deletions.
11 changes: 7 additions & 4 deletions acceptor.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"fmt"
"net"
"sync"
"time"

"github.com/gobwas/ws"
"golang.org/x/exp/maps"
Expand Down Expand Up @@ -69,10 +70,11 @@ func (w *WebSocketAcceptor) Spec(subProtocol string) (serializers.Serializer, er
return serializer, nil
}

func (w *WebSocketAcceptor) Accept(conn net.Conn) (BaseSession, error) {
func (w *WebSocketAcceptor) Accept(conn net.Conn, keepaliveInterval time.Duration,
keepaliveTimeout time.Duration) (BaseSession, error) {
config := DefaultWebSocketServerConfig()
config.SubProtocols = w.protocols()
peer, err := UpgradeWebSocket(conn, config)
peer, err := UpgradeWebSocket(conn, config, keepaliveInterval, keepaliveTimeout)
if err != nil {
return nil, fmt.Errorf("failed to init reader/writer: %w", err)
}
Expand Down Expand Up @@ -134,7 +136,8 @@ Welcomed:
return details, nil
}

func UpgradeWebSocket(conn net.Conn, config *WebSocketServerConfig) (Peer, error) {
func UpgradeWebSocket(conn net.Conn, config *WebSocketServerConfig, keepaliveInterval time.Duration,
keepaliveTimeout time.Duration) (Peer, error) {
wsUpgrader := ws.Upgrader{
Protocol: func(protoBytes []byte) bool {
if config == nil {
Expand All @@ -161,7 +164,7 @@ func UpgradeWebSocket(conn net.Conn, config *WebSocketServerConfig) (Peer, error
}

isBinary := hs.Protocol != JsonWebsocketProtocol
peer, err := NewWebSocketPeer(conn, hs.Protocol, isBinary, true)
peer, err := NewWebSocketPeer(conn, hs.Protocol, isBinary, true, keepaliveInterval, keepaliveTimeout)
if err != nil {
return nil, fmt.Errorf("failed to init reader/writer: %w", err)
}
Expand Down
3 changes: 2 additions & 1 deletion acceptor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"fmt"
"net"
"testing"
"time"

"github.com/gammazero/nexus/v3/client"
"github.com/stretchr/testify/require"
Expand All @@ -25,7 +26,7 @@ func TestAccept(t *testing.T) {
require.NotNil(t, conn)

acceptor := xconn.WebSocketAcceptor{}
session, err := acceptor.Accept(conn)
session, err := acceptor.Accept(conn, time.Duration(0), time.Duration(0))
require.NoError(t, err)
require.NotNil(t, session)

Expand Down
8 changes: 5 additions & 3 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@ type Client struct {
SerializerSpec WSSerializerSpec
NetDial func(ctx context.Context, network, addr string) (net.Conn, error)

DialTimeout time.Duration
DialTimeout time.Duration
KeepaliveInterval time.Duration
KeepaliveTimeout time.Duration
}

func (c *Client) Connect(ctx context.Context, url string, realm string) (*Session, error) {
Expand All @@ -28,7 +30,7 @@ func (c *Client) Connect(ctx context.Context, url string, realm string) (*Sessio
NetDial: c.NetDial,
}

base, err := joiner.Join(ctx, url, realm)
base, err := joiner.Join(ctx, url, realm, c.KeepaliveInterval, c.KeepaliveTimeout)
if err != nil {
return nil, err
}
Expand All @@ -38,7 +40,7 @@ func (c *Client) Connect(ctx context.Context, url string, realm string) (*Sessio

func Connect(ctx context.Context, url string, realm string) (*Session, error) {
joiner := &WebSocketJoiner{}
base, err := joiner.Join(ctx, url, realm)
base, err := joiner.Join(ctx, url, realm, time.Duration(0), time.Duration(0))
if err != nil {
return nil, err
}
Expand Down
2 changes: 1 addition & 1 deletion cmd/xconn/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ func Run(args []string) error {
throttle = internal.NewThrottle(transport.RateLimit.Rate,
time.Duration(transport.RateLimit.Interval)*time.Second, strategy)
}
server := xconn.NewServer(router, authenticator, throttle)
server := xconn.NewServer(router, authenticator, throttle, time.Duration(0), time.Duration(0))
if slices.Contains(transport.Serializers, "protobuf") {
if err := server.RegisterSpec(xconn.ProtobufSerializerSpec); err != nil {
return err
Expand Down
3 changes: 2 additions & 1 deletion examples/server/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"log"
"os"
"os/signal"
"time"

"github.com/xconnio/xconn-go"
)
Expand All @@ -26,7 +27,7 @@ func main() {
r.AddRealm(*realm)
defer r.Close()

server := xconn.NewServer(r, nil, nil)
server := xconn.NewServer(r, nil, nil, time.Duration(0), time.Duration(0))
closer, err := server.Start(*host, *port)
if err != nil {
log.Fatal("Failed to start server:", err)
Expand Down
18 changes: 0 additions & 18 deletions helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,28 +3,10 @@ package xconn
import (
"fmt"

"github.com/gobwas/ws/wsutil"

"github.com/xconnio/wampproto-go/messages"
"github.com/xconnio/wampproto-go/serializers"
)

func ClientSideWSReaderWriter(binary bool) (ReaderFunc, WriterFunc, error) {
if !binary {
return wsutil.ReadServerText, wsutil.WriteClientText, nil
}

return wsutil.ReadServerBinary, wsutil.WriteClientBinary, nil
}

func ServerSideWSReaderWriter(binary bool) (ReaderFunc, WriterFunc, error) {
if !binary {
return wsutil.ReadClientText, wsutil.WriteServerText, nil
}

return wsutil.ReadClientBinary, wsutil.WriteServerBinary, nil
}

func ReadMessage(peer Peer, serializer serializers.Serializer) (messages.Message, error) {
payload, err := peer.Read()
if err != nil {
Expand Down
10 changes: 6 additions & 4 deletions joiner.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@ type WebSocketJoiner struct {
DialTimeout time.Duration
}

func (w *WebSocketJoiner) Join(ctx context.Context, url, realm string) (BaseSession, error) {
func (w *WebSocketJoiner) Join(ctx context.Context, url, realm string, keepaliveInterval time.Duration,
keepaliveTimeout time.Duration) (BaseSession, error) {
parsedURL, err := netURL.Parse(url)
if err != nil {
return nil, err
Expand All @@ -46,15 +47,16 @@ func (w *WebSocketJoiner) Join(ctx context.Context, url, realm string) (BaseSess
NetDial: w.NetDial,
}

peer, err := DialWebSocket(ctx, parsedURL, dialConfig)
peer, err := DialWebSocket(ctx, parsedURL, dialConfig, keepaliveInterval, keepaliveTimeout)
if err != nil {
return nil, err
}

return Join(peer, realm, w.SerializerSpec.Serializer(), w.Authenticator)
}

func DialWebSocket(ctx context.Context, url *netURL.URL, config WSDialerConfig) (Peer, error) {
func DialWebSocket(ctx context.Context, url *netURL.URL, config WSDialerConfig, keepaliveInterval time.Duration,
keepaliveTimeout time.Duration) (Peer, error) {
wsDialer := ws.Dialer{
Protocols: []string{config.SubProtocol},
}
Expand Down Expand Up @@ -83,7 +85,7 @@ func DialWebSocket(ctx context.Context, url *netURL.URL, config WSDialerConfig)
}

isBinary := config.SubProtocol != JsonWebsocketProtocol
return NewWebSocketPeer(conn, config.SubProtocol, isBinary, false)
return NewWebSocketPeer(conn, config.SubProtocol, isBinary, false, keepaliveInterval, keepaliveTimeout)
}

func Join(cl Peer, realm string, serializer serializers.Serializer,
Expand Down
3 changes: 2 additions & 1 deletion joiner_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"fmt"
"net"
"testing"
"time"

"github.com/gammazero/nexus/v3/router"
"github.com/gammazero/nexus/v3/wamp"
Expand Down Expand Up @@ -37,7 +38,7 @@ func TestJoin(t *testing.T) {
address := fmt.Sprintf("ws://%s/ws", listener.Addr().String())

var joiner xconn.WebSocketJoiner
base, err := joiner.Join(context.Background(), address, "realm1")
base, err := joiner.Join(context.Background(), address, "realm1", time.Duration(0), time.Duration(0))
require.NoError(t, err)
require.NotNil(t, base)

Expand Down
125 changes: 101 additions & 24 deletions peer.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,16 @@
package xconn

import (
"errors"
"fmt"
"io"
"log"
"net"
"sync"
"time"

"github.com/gobwas/ws"
"github.com/gobwas/ws/wsutil"

"github.com/xconnio/wampproto-go/messages"
"github.com/xconnio/wampproto-go/serializers"
Expand Down Expand Up @@ -88,48 +96,117 @@ func (b *baseSession) Close() error {
return b.client.NetConn().Close()
}

func NewWebSocketPeer(conn net.Conn, protocol string, binary, server bool) (Peer, error) {
var wsReader ReaderFunc
var wsWriter WriterFunc
var err error
if server {
wsReader, wsWriter, err = ServerSideWSReaderWriter(binary)
} else {
wsReader, wsWriter, err = ClientSideWSReaderWriter(binary)
func NewWebSocketPeer(conn net.Conn, protocol string, binary, server bool, keepaliveInterval time.Duration,
keepaliveTimeout time.Duration) (Peer, error) {
peer := &WebSocketPeer{
transportType: TransportWebSocket,
protocol: protocol,
conn: conn,
pingCh: make(chan struct{}, 1),
binary: binary,
server: server,
}

if err != nil {
return nil, err
if keepaliveInterval != 0*time.Second {
// Start ping-pong handling
go peer.startPinger(keepaliveInterval, keepaliveTimeout)
}

return &WebSocketPeer{
transportType: TransportWebSocket,
protocol: protocol,
conn: conn,
wsReader: wsReader,
wsWriter: wsWriter,
}, nil
return peer, nil
}

type WebSocketPeer struct {
transportType TransportType
protocol string
conn net.Conn
wsReader ReaderFunc
wsWriter WriterFunc

wm sync.Mutex
pingCh chan struct{}
binary bool
server bool

sync.Mutex
}

func (c *WebSocketPeer) startPinger(keepaliveInterval time.Duration, keepaliveTimeout time.Duration) {
ticker := time.NewTicker(keepaliveInterval)
defer ticker.Stop()

if keepaliveTimeout == 0*time.Second {
keepaliveTimeout = 5 * time.Second
}
for {
<-ticker.C
// Send a ping
err := c.writeOpFunc(c.conn, ws.OpPing, []byte("ping"))
if err != nil {
log.Printf("failed to send ping: %v\n", err)
_ = c.conn.Close()
return
}

select {
case <-c.pingCh:
case <-time.After(keepaliveTimeout):
log.Println("pong timeout, closing connection")
_ = c.conn.Close()
return
}
}
}

func (c *WebSocketPeer) peerState() ws.State {
if c.server {
return ws.StateServerSide
}
return ws.StateClientSide
}

func (c *WebSocketPeer) writeOpFunc(w io.Writer, op ws.OpCode, p []byte) error {
if c.server {
return wsutil.WriteServerMessage(w, op, p)
}

return wsutil.WriteClientMessage(w, op, p)
}

func (c *WebSocketPeer) Read() ([]byte, error) {
return c.wsReader(c.conn)
header, reader, err := wsutil.NextReader(c.conn, c.peerState())
if err != nil {
return nil, err
}

payload := make([]byte, header.Length)
_, err = reader.Read(payload)
if err != nil && !errors.Is(err, io.EOF) {
return nil, err
}

switch header.OpCode {
case ws.OpText, ws.OpBinary:
return payload, nil
case ws.OpPing:
err = c.writeOpFunc(c.conn, ws.OpPong, payload)
if err != nil {
return nil, fmt.Errorf("failed to send pong: %w", err)
}
case ws.OpPong:
c.pingCh <- struct{}{}
case ws.OpClose:
_ = c.conn.Close()
return nil, fmt.Errorf("connection closed")
}

return c.Read()
}

func (c *WebSocketPeer) Write(bytes []byte) error {
c.wm.Lock()
defer c.wm.Unlock()
c.Lock()
defer c.Unlock()
if c.binary {
return c.writeOpFunc(c.conn, ws.OpBinary, bytes)
}

return c.wsWriter(c.conn, bytes)
return c.writeOpFunc(c.conn, ws.OpText, bytes)
}

func (c *WebSocketPeer) Type() TransportType {
Expand Down
Loading

0 comments on commit 3fe558a

Please sign in to comment.