Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement websocket keepalive #64

Merged
merged 2 commits into from
Dec 27, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 13 additions & 3 deletions acceptor.go
Original file line number Diff line number Diff line change
@@ -69,8 +69,10 @@ func (w *WebSocketAcceptor) Spec(subProtocol string) (serializers.Serializer, er
return serializer, nil
}

func (w *WebSocketAcceptor) Accept(conn net.Conn) (BaseSession, error) {
config := DefaultWebSocketServerConfig()
func (w *WebSocketAcceptor) Accept(conn net.Conn, config *WebSocketServerConfig) (BaseSession, error) {
if config == nil {
config = DefaultWebSocketServerConfig()
}
config.SubProtocols = w.protocols()
peer, err := UpgradeWebSocket(conn, config)
if err != nil {
@@ -161,7 +163,15 @@ func UpgradeWebSocket(conn net.Conn, config *WebSocketServerConfig) (Peer, error
}

isBinary := hs.Protocol != JsonWebsocketProtocol
peer, err := NewWebSocketPeer(conn, hs.Protocol, isBinary, true)

peerConfig := WSPeerConfig{
Protocol: hs.Protocol,
Binary: isBinary,
Server: true,
KeepAliveInterval: config.KeepAliveInterval,
KeepAliveTimeout: config.KeepAliveTimeout,
}
peer, err := NewWebSocketPeer(conn, peerConfig)
if err != nil {
return nil, fmt.Errorf("failed to init reader/writer: %w", err)
}
2 changes: 1 addition & 1 deletion acceptor_test.go
Original file line number Diff line number Diff line change
@@ -25,7 +25,7 @@ func TestAccept(t *testing.T) {
require.NotNil(t, conn)

acceptor := xconn.WebSocketAcceptor{}
session, err := acceptor.Accept(conn)
session, err := acceptor.Accept(conn, nil)
require.NoError(t, err)
require.NotNil(t, session)

21 changes: 16 additions & 5 deletions client.go
Original file line number Diff line number Diff line change
@@ -13,7 +13,9 @@ type Client struct {
SerializerSpec WSSerializerSpec
NetDial func(ctx context.Context, network, addr string) (net.Conn, error)

DialTimeout time.Duration
DialTimeout time.Duration
KeepAliveInterval time.Duration
KeepAliveTimeout time.Duration
}

func (c *Client) Connect(ctx context.Context, url string, realm string) (*Session, error) {
@@ -24,11 +26,17 @@ func (c *Client) Connect(ctx context.Context, url string, realm string) (*Sessio
joiner := &WebSocketJoiner{
Authenticator: c.Authenticator,
SerializerSpec: c.SerializerSpec,
DialTimeout: c.DialTimeout,
NetDial: c.NetDial,
}

base, err := joiner.Join(ctx, url, realm)
dialerConfig := &WSDialerConfig{
SubProtocol: c.SerializerSpec.SubProtocol(),
DialTimeout: c.DialTimeout,
NetDial: c.NetDial,
KeepAliveInterval: c.KeepAliveInterval,
KeepAliveTimeout: c.KeepAliveTimeout,
}

base, err := joiner.Join(ctx, url, realm, dialerConfig)
if err != nil {
return nil, err
}
@@ -38,7 +46,10 @@ func (c *Client) Connect(ctx context.Context, url string, realm string) (*Sessio

func Connect(ctx context.Context, url string, realm string) (*Session, error) {
joiner := &WebSocketJoiner{}
base, err := joiner.Join(ctx, url, realm)
dialerConfig := &WSDialerConfig{
SubProtocol: JsonWebsocketProtocol,
}
base, err := joiner.Join(ctx, url, realm, dialerConfig)
if err != nil {
return nil, err
}
2 changes: 1 addition & 1 deletion cmd/xconn/main.go
Original file line number Diff line number Diff line change
@@ -116,7 +116,7 @@ func Run(args []string) error {
throttle = internal.NewThrottle(transport.RateLimit.Rate,
time.Duration(transport.RateLimit.Interval)*time.Second, strategy)
}
server := xconn.NewServer(router, authenticator, throttle)
server := xconn.NewServer(router, authenticator, &xconn.ServerConfig{Throttle: throttle})
if slices.Contains(transport.Serializers, "protobuf") {
if err := server.RegisterSpec(xconn.ProtobufSerializerSpec); err != nil {
return err
18 changes: 0 additions & 18 deletions helpers.go
Original file line number Diff line number Diff line change
@@ -3,28 +3,10 @@ package xconn
import (
"fmt"

"github.com/gobwas/ws/wsutil"

"github.com/xconnio/wampproto-go/messages"
"github.com/xconnio/wampproto-go/serializers"
)

func ClientSideWSReaderWriter(binary bool) (ReaderFunc, WriterFunc, error) {
if !binary {
return wsutil.ReadServerText, wsutil.WriteClientText, nil
}

return wsutil.ReadServerBinary, wsutil.WriteClientBinary, nil
}

func ServerSideWSReaderWriter(binary bool) (ReaderFunc, WriterFunc, error) {
if !binary {
return wsutil.ReadClientText, wsutil.WriteServerText, nil
}

return wsutil.ReadClientBinary, wsutil.WriteServerBinary, nil
}

func ReadMessage(peer Peer, serializer serializers.Serializer) (messages.Message, error) {
payload, err := peer.Read()
if err != nil {
32 changes: 15 additions & 17 deletions joiner.go
Original file line number Diff line number Diff line change
@@ -17,12 +17,9 @@ import (
type WebSocketJoiner struct {
SerializerSpec WSSerializerSpec
Authenticator auth.ClientAuthenticator
NetDial func(ctx context.Context, network, addr string) (net.Conn, error)

DialTimeout time.Duration
}

func (w *WebSocketJoiner) Join(ctx context.Context, url, realm string) (BaseSession, error) {
func (w *WebSocketJoiner) Join(ctx context.Context, url, realm string, config *WSDialerConfig) (BaseSession, error) {
parsedURL, err := netURL.Parse(url)
if err != nil {
return nil, err
@@ -36,25 +33,18 @@ func (w *WebSocketJoiner) Join(ctx context.Context, url, realm string) (BaseSess
w.Authenticator = auth.NewAnonymousAuthenticator("", nil)
}

if w.DialTimeout == 0 {
w.DialTimeout = time.Second * 10
}

dialConfig := WSDialerConfig{
SubProtocol: w.SerializerSpec.SubProtocol(),
DialTimeout: w.DialTimeout,
NetDial: w.NetDial,
}

peer, err := DialWebSocket(ctx, parsedURL, dialConfig)
peer, err := DialWebSocket(ctx, parsedURL, config)
if err != nil {
return nil, err
}

return Join(peer, realm, w.SerializerSpec.Serializer(), w.Authenticator)
}

func DialWebSocket(ctx context.Context, url *netURL.URL, config WSDialerConfig) (Peer, error) {
func DialWebSocket(ctx context.Context, url *netURL.URL, config *WSDialerConfig) (Peer, error) {
if config == nil {
config = &WSDialerConfig{SubProtocol: JsonWebsocketProtocol}
}
wsDialer := ws.Dialer{
Protocols: []string{config.SubProtocol},
}
@@ -83,7 +73,15 @@ func DialWebSocket(ctx context.Context, url *netURL.URL, config WSDialerConfig)
}

isBinary := config.SubProtocol != JsonWebsocketProtocol
return NewWebSocketPeer(conn, config.SubProtocol, isBinary, false)

peerConfig := WSPeerConfig{
Protocol: config.SubProtocol,
Binary: isBinary,
Server: false,
KeepAliveInterval: config.KeepAliveInterval,
KeepAliveTimeout: config.KeepAliveTimeout,
}
return NewWebSocketPeer(conn, peerConfig)
}

func Join(cl Peer, realm string, serializer serializers.Serializer,
4 changes: 3 additions & 1 deletion joiner_test.go
Original file line number Diff line number Diff line change
@@ -37,7 +37,9 @@ func TestJoin(t *testing.T) {
address := fmt.Sprintf("ws://%s/ws", listener.Addr().String())

var joiner xconn.WebSocketJoiner
base, err := joiner.Join(context.Background(), address, "realm1")
base, err := joiner.Join(context.Background(), address, "realm1", &xconn.WSDialerConfig{
SubProtocol: xconn.JsonWebsocketProtocol,
})
require.NoError(t, err)
require.NotNil(t, base)

128 changes: 104 additions & 24 deletions peer.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,17 @@
package xconn

import (
"crypto/rand"
"errors"
"fmt"
"io"
"log"
"net"
"sync"
"time"

"github.com/gobwas/ws"
"github.com/gobwas/ws/wsutil"

"github.com/xconnio/wampproto-go/messages"
"github.com/xconnio/wampproto-go/serializers"
@@ -88,48 +97,119 @@ func (b *baseSession) Close() error {
return b.client.NetConn().Close()
}

func NewWebSocketPeer(conn net.Conn, protocol string, binary, server bool) (Peer, error) {
var wsReader ReaderFunc
var wsWriter WriterFunc
var err error
if server {
wsReader, wsWriter, err = ServerSideWSReaderWriter(binary)
} else {
wsReader, wsWriter, err = ClientSideWSReaderWriter(binary)
func NewWebSocketPeer(conn net.Conn, peerConfig WSPeerConfig) (Peer, error) {
peer := &WebSocketPeer{
transportType: TransportWebSocket,
protocol: peerConfig.Protocol,
conn: conn,
pingCh: make(chan struct{}, 1),
binary: peerConfig.Binary,
server: peerConfig.Server,
}

if err != nil {
return nil, err
if peerConfig.KeepAliveInterval != 0 {
// Start ping-pong handling
go peer.startPinger(peerConfig.KeepAliveInterval, peerConfig.KeepAliveTimeout)
}

return &WebSocketPeer{
transportType: TransportWebSocket,
protocol: protocol,
conn: conn,
wsReader: wsReader,
wsWriter: wsWriter,
}, nil
return peer, nil
}

type WebSocketPeer struct {
transportType TransportType
protocol string
conn net.Conn
wsReader ReaderFunc
wsWriter WriterFunc

wm sync.Mutex
pingCh chan struct{}
binary bool
server bool

sync.Mutex
}

func (c *WebSocketPeer) startPinger(keepaliveInterval time.Duration, keepaliveTimeout time.Duration) {
ticker := time.NewTicker(keepaliveInterval)
defer ticker.Stop()

if keepaliveTimeout == 0 {
keepaliveTimeout = 10 * time.Second
}
for {
<-ticker.C
// Send a ping
randomBytes := make([]byte, 4)
_, err := rand.Read(randomBytes)
if err != nil {
fmt.Println("failed to generate random bytes:", err)
}
if err := c.writeOpFunc(c.conn, ws.OpPing, randomBytes); err != nil {
log.Printf("failed to send ping: %v\n", err)
_ = c.conn.Close()
return
}

select {
case <-c.pingCh:
case <-time.After(keepaliveTimeout):
log.Println("ping timeout, closing connection")
_ = c.conn.Close()
return
}
}
}

func (c *WebSocketPeer) peerState() ws.State {
if c.server {
return ws.StateServerSide
}
return ws.StateClientSide
}

func (c *WebSocketPeer) writeOpFunc(w io.Writer, op ws.OpCode, p []byte) error {
if c.server {
return wsutil.WriteServerMessage(w, op, p)
}

return wsutil.WriteClientMessage(w, op, p)
}

func (c *WebSocketPeer) Read() ([]byte, error) {
return c.wsReader(c.conn)
header, reader, err := wsutil.NextReader(c.conn, c.peerState())
if err != nil {
return nil, err
}

payload := make([]byte, header.Length)
_, err = reader.Read(payload)
if err != nil && !errors.Is(err, io.EOF) {
return nil, err
}

switch header.OpCode {
case ws.OpText, ws.OpBinary:
return payload, nil
case ws.OpPing:
if err = c.writeOpFunc(c.conn, ws.OpPong, payload); err != nil {
return nil, fmt.Errorf("failed to send pong: %w", err)
}
case ws.OpPong:
c.pingCh <- struct{}{}
case ws.OpClose:
_ = c.conn.Close()
return nil, fmt.Errorf("connection closed")
}

return c.Read()
}

func (c *WebSocketPeer) Write(bytes []byte) error {
c.wm.Lock()
defer c.wm.Unlock()
c.Lock()
defer c.Unlock()
if c.binary {
return c.writeOpFunc(c.conn, ws.OpBinary, bytes)
}

return c.wsWriter(c.conn, bytes)
return c.writeOpFunc(c.conn, ws.OpText, bytes)
}

func (c *WebSocketPeer) Type() TransportType {
Loading
Loading