Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement websocket keepalive #64

Merged
merged 2 commits into from
Dec 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 13 additions & 3 deletions acceptor.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,10 @@ func (w *WebSocketAcceptor) Spec(subProtocol string) (serializers.Serializer, er
return serializer, nil
}

func (w *WebSocketAcceptor) Accept(conn net.Conn) (BaseSession, error) {
config := DefaultWebSocketServerConfig()
func (w *WebSocketAcceptor) Accept(conn net.Conn, config *WebSocketServerConfig) (BaseSession, error) {
if config == nil {
config = DefaultWebSocketServerConfig()
}
config.SubProtocols = w.protocols()
om26er marked this conversation as resolved.
Show resolved Hide resolved
peer, err := UpgradeWebSocket(conn, config)
if err != nil {
Expand Down Expand Up @@ -161,7 +163,15 @@ func UpgradeWebSocket(conn net.Conn, config *WebSocketServerConfig) (Peer, error
}

isBinary := hs.Protocol != JsonWebsocketProtocol
peer, err := NewWebSocketPeer(conn, hs.Protocol, isBinary, true)

peerConfig := WSPeerConfig{
Protocol: hs.Protocol,
Binary: isBinary,
Server: true,
KeepAliveInterval: config.KeepAliveInterval,
KeepAliveTimeout: config.KeepAliveTimeout,
}
peer, err := NewWebSocketPeer(conn, peerConfig)
if err != nil {
return nil, fmt.Errorf("failed to init reader/writer: %w", err)
}
Expand Down
2 changes: 1 addition & 1 deletion acceptor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ func TestAccept(t *testing.T) {
require.NotNil(t, conn)

acceptor := xconn.WebSocketAcceptor{}
session, err := acceptor.Accept(conn)
session, err := acceptor.Accept(conn, nil)
require.NoError(t, err)
require.NotNil(t, session)

Expand Down
21 changes: 16 additions & 5 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@ type Client struct {
SerializerSpec WSSerializerSpec
NetDial func(ctx context.Context, network, addr string) (net.Conn, error)

DialTimeout time.Duration
DialTimeout time.Duration
KeepAliveInterval time.Duration
KeepAliveTimeout time.Duration
}

func (c *Client) Connect(ctx context.Context, url string, realm string) (*Session, error) {
Expand All @@ -24,11 +26,17 @@ func (c *Client) Connect(ctx context.Context, url string, realm string) (*Sessio
joiner := &WebSocketJoiner{
Authenticator: c.Authenticator,
SerializerSpec: c.SerializerSpec,
DialTimeout: c.DialTimeout,
NetDial: c.NetDial,
}

base, err := joiner.Join(ctx, url, realm)
dialerConfig := &WSDialerConfig{
SubProtocol: c.SerializerSpec.SubProtocol(),
DialTimeout: c.DialTimeout,
NetDial: c.NetDial,
KeepAliveInterval: c.KeepAliveInterval,
KeepAliveTimeout: c.KeepAliveTimeout,
}

base, err := joiner.Join(ctx, url, realm, dialerConfig)
if err != nil {
return nil, err
}
Expand All @@ -38,7 +46,10 @@ func (c *Client) Connect(ctx context.Context, url string, realm string) (*Sessio

func Connect(ctx context.Context, url string, realm string) (*Session, error) {
joiner := &WebSocketJoiner{}
base, err := joiner.Join(ctx, url, realm)
dialerConfig := &WSDialerConfig{
SubProtocol: JsonWebsocketProtocol,
}
base, err := joiner.Join(ctx, url, realm, dialerConfig)
if err != nil {
return nil, err
}
Expand Down
2 changes: 1 addition & 1 deletion cmd/xconn/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ func Run(args []string) error {
throttle = internal.NewThrottle(transport.RateLimit.Rate,
time.Duration(transport.RateLimit.Interval)*time.Second, strategy)
}
server := xconn.NewServer(router, authenticator, throttle)
server := xconn.NewServer(router, authenticator, &xconn.ServerConfig{Throttle: throttle})
if slices.Contains(transport.Serializers, "protobuf") {
if err := server.RegisterSpec(xconn.ProtobufSerializerSpec); err != nil {
return err
Expand Down
18 changes: 0 additions & 18 deletions helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,28 +3,10 @@ package xconn
import (
"fmt"

"github.com/gobwas/ws/wsutil"

"github.com/xconnio/wampproto-go/messages"
"github.com/xconnio/wampproto-go/serializers"
)

func ClientSideWSReaderWriter(binary bool) (ReaderFunc, WriterFunc, error) {
if !binary {
return wsutil.ReadServerText, wsutil.WriteClientText, nil
}

return wsutil.ReadServerBinary, wsutil.WriteClientBinary, nil
}

func ServerSideWSReaderWriter(binary bool) (ReaderFunc, WriterFunc, error) {
if !binary {
return wsutil.ReadClientText, wsutil.WriteServerText, nil
}

return wsutil.ReadClientBinary, wsutil.WriteServerBinary, nil
}

func ReadMessage(peer Peer, serializer serializers.Serializer) (messages.Message, error) {
payload, err := peer.Read()
if err != nil {
Expand Down
32 changes: 15 additions & 17 deletions joiner.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,9 @@ import (
type WebSocketJoiner struct {
SerializerSpec WSSerializerSpec
Authenticator auth.ClientAuthenticator
NetDial func(ctx context.Context, network, addr string) (net.Conn, error)

DialTimeout time.Duration
}

func (w *WebSocketJoiner) Join(ctx context.Context, url, realm string) (BaseSession, error) {
func (w *WebSocketJoiner) Join(ctx context.Context, url, realm string, config *WSDialerConfig) (BaseSession, error) {
parsedURL, err := netURL.Parse(url)
if err != nil {
return nil, err
Expand All @@ -36,25 +33,18 @@ func (w *WebSocketJoiner) Join(ctx context.Context, url, realm string) (BaseSess
w.Authenticator = auth.NewAnonymousAuthenticator("", nil)
}

if w.DialTimeout == 0 {
w.DialTimeout = time.Second * 10
}

dialConfig := WSDialerConfig{
SubProtocol: w.SerializerSpec.SubProtocol(),
DialTimeout: w.DialTimeout,
NetDial: w.NetDial,
}

peer, err := DialWebSocket(ctx, parsedURL, dialConfig)
peer, err := DialWebSocket(ctx, parsedURL, config)
if err != nil {
return nil, err
}

return Join(peer, realm, w.SerializerSpec.Serializer(), w.Authenticator)
}

func DialWebSocket(ctx context.Context, url *netURL.URL, config WSDialerConfig) (Peer, error) {
func DialWebSocket(ctx context.Context, url *netURL.URL, config *WSDialerConfig) (Peer, error) {
if config == nil {
config = &WSDialerConfig{SubProtocol: JsonWebsocketProtocol}
}
wsDialer := ws.Dialer{
Protocols: []string{config.SubProtocol},
}
Expand Down Expand Up @@ -83,7 +73,15 @@ func DialWebSocket(ctx context.Context, url *netURL.URL, config WSDialerConfig)
}

isBinary := config.SubProtocol != JsonWebsocketProtocol
return NewWebSocketPeer(conn, config.SubProtocol, isBinary, false)

peerConfig := WSPeerConfig{
Protocol: config.SubProtocol,
Binary: isBinary,
Server: false,
KeepAliveInterval: config.KeepAliveInterval,
KeepAliveTimeout: config.KeepAliveTimeout,
}
return NewWebSocketPeer(conn, peerConfig)
}

func Join(cl Peer, realm string, serializer serializers.Serializer,
Expand Down
4 changes: 3 additions & 1 deletion joiner_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,9 @@ func TestJoin(t *testing.T) {
address := fmt.Sprintf("ws://%s/ws", listener.Addr().String())

var joiner xconn.WebSocketJoiner
base, err := joiner.Join(context.Background(), address, "realm1")
base, err := joiner.Join(context.Background(), address, "realm1", &xconn.WSDialerConfig{
SubProtocol: xconn.JsonWebsocketProtocol,
})
require.NoError(t, err)
require.NotNil(t, base)

Expand Down
128 changes: 104 additions & 24 deletions peer.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,17 @@
package xconn

import (
"crypto/rand"
"errors"
"fmt"
"io"
"log"
"net"
"sync"
"time"

"github.com/gobwas/ws"
"github.com/gobwas/ws/wsutil"

"github.com/xconnio/wampproto-go/messages"
"github.com/xconnio/wampproto-go/serializers"
Expand Down Expand Up @@ -88,48 +97,119 @@ func (b *baseSession) Close() error {
return b.client.NetConn().Close()
}

func NewWebSocketPeer(conn net.Conn, protocol string, binary, server bool) (Peer, error) {
var wsReader ReaderFunc
var wsWriter WriterFunc
var err error
if server {
wsReader, wsWriter, err = ServerSideWSReaderWriter(binary)
} else {
wsReader, wsWriter, err = ClientSideWSReaderWriter(binary)
func NewWebSocketPeer(conn net.Conn, peerConfig WSPeerConfig) (Peer, error) {
peer := &WebSocketPeer{
transportType: TransportWebSocket,
protocol: peerConfig.Protocol,
conn: conn,
pingCh: make(chan struct{}, 1),
binary: peerConfig.Binary,
server: peerConfig.Server,
}

if err != nil {
return nil, err
if peerConfig.KeepAliveInterval != 0 {
// Start ping-pong handling
go peer.startPinger(peerConfig.KeepAliveInterval, peerConfig.KeepAliveTimeout)
}

return &WebSocketPeer{
transportType: TransportWebSocket,
protocol: protocol,
conn: conn,
wsReader: wsReader,
wsWriter: wsWriter,
}, nil
return peer, nil
}

type WebSocketPeer struct {
transportType TransportType
protocol string
conn net.Conn
wsReader ReaderFunc
wsWriter WriterFunc

wm sync.Mutex
pingCh chan struct{}
binary bool
server bool

sync.Mutex
}

func (c *WebSocketPeer) startPinger(keepaliveInterval time.Duration, keepaliveTimeout time.Duration) {
ticker := time.NewTicker(keepaliveInterval)
defer ticker.Stop()

if keepaliveTimeout == 0 {
keepaliveTimeout = 10 * time.Second
}
for {
<-ticker.C
// Send a ping
randomBytes := make([]byte, 4)
_, err := rand.Read(randomBytes)
if err != nil {
fmt.Println("failed to generate random bytes:", err)
}
if err := c.writeOpFunc(c.conn, ws.OpPing, randomBytes); err != nil {
log.Printf("failed to send ping: %v\n", err)
_ = c.conn.Close()
return
}

select {
case <-c.pingCh:
case <-time.After(keepaliveTimeout):
log.Println("ping timeout, closing connection")
_ = c.conn.Close()
return
}
}
}

func (c *WebSocketPeer) peerState() ws.State {
if c.server {
return ws.StateServerSide
}
return ws.StateClientSide
}

func (c *WebSocketPeer) writeOpFunc(w io.Writer, op ws.OpCode, p []byte) error {
if c.server {
return wsutil.WriteServerMessage(w, op, p)
}

return wsutil.WriteClientMessage(w, op, p)
}

func (c *WebSocketPeer) Read() ([]byte, error) {
return c.wsReader(c.conn)
header, reader, err := wsutil.NextReader(c.conn, c.peerState())
if err != nil {
return nil, err
}

payload := make([]byte, header.Length)
_, err = reader.Read(payload)
if err != nil && !errors.Is(err, io.EOF) {
return nil, err
}

switch header.OpCode {
case ws.OpText, ws.OpBinary:
return payload, nil
case ws.OpPing:
if err = c.writeOpFunc(c.conn, ws.OpPong, payload); err != nil {
return nil, fmt.Errorf("failed to send pong: %w", err)
}
case ws.OpPong:
c.pingCh <- struct{}{}
case ws.OpClose:
_ = c.conn.Close()
return nil, fmt.Errorf("connection closed")
}

return c.Read()
}

func (c *WebSocketPeer) Write(bytes []byte) error {
c.wm.Lock()
defer c.wm.Unlock()
c.Lock()
defer c.Unlock()
if c.binary {
return c.writeOpFunc(c.conn, ws.OpBinary, bytes)
}

return c.wsWriter(c.conn, bytes)
return c.writeOpFunc(c.conn, ws.OpText, bytes)
}

func (c *WebSocketPeer) Type() TransportType {
Expand Down
Loading
Loading