Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

code cleanup and simplify implementation #11

Merged
merged 15 commits into from
Nov 27, 2023
22 changes: 16 additions & 6 deletions metrics/connection.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package metrics

import (
"strconv"
"github.com/prometheus/client_golang/prometheus"
)

Expand All @@ -25,12 +24,19 @@ var (
Help: "Number of current connections",
})

countSendBlocking = prometheus.NewCounterVec(prometheus.CounterOpts{
countSendBlocking = prometheus.NewCounter(prometheus.CounterOpts{
Namespace: promNamespace,
Subsystem: promSubsystem,
Name: "send_blockings",
Help: "Number of send blocking connections",
}, []string{"sendbuflen"})
})

countMessageFromClosed = prometheus.NewCounter(prometheus.CounterOpts{
Namespace: promNamespace,
Subsystem: promSubsystem,
Name: "sending_on_closed",
Help: "Number of sending on closed connections",
})
)

func IncNewConnection() {
Expand All @@ -41,9 +47,12 @@ func IncClosedConnection() {
countClosedConnections.Inc()
}

func IncSendBlocking(sendbufLen int) {
lenstr := strconv.Itoa(sendbufLen)
countSendBlocking.With(prometheus.Labels{"sendbuflen": lenstr}).Inc()
func IncSendBlocking() {
countSendBlocking.Inc()
}

func IncMessageFromClosed() {
countMessageFromClosed.Inc()
}

func SetCurrentConnections(num int) {
Expand All @@ -54,4 +63,5 @@ func init() {
prometheus.MustRegister(countNewConnections)
prometheus.MustRegister(countClosedConnections)
prometheus.MustRegister(gaugeCurrentConnections)
prometheus.MustRegister(countSendBlocking)
}
1 change: 0 additions & 1 deletion relay/pendingSession.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ type SortedPendingSessions struct {
type pendingSession struct {
expireTime time.Time
topic string
dapp *client
}

func (pq PendingSessions) Len() int { return len(pq) }
Expand Down
64 changes: 13 additions & 51 deletions relay/wsconn.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@ import (
"encoding/json"
"fmt"
"strings"
"sync/atomic"
"time"

"github.com/RabbyHub/derelay/log"
"github.com/RabbyHub/derelay/metrics"
Expand All @@ -19,16 +17,13 @@ type client struct {
conn *websocket.Conn
ws *WsServer

id string // randomly generate, just for logging
active bool // heartbeat related
terminated atomic.Bool //
role RoleType // dapp or wallet
session string // session id
pubTopics *TopicSet
subTopics *TopicSet
id string // randomly generate, just for logging
role RoleType // dapp or wallet
session string // session id
pubTopics *TopicSet
subTopics *TopicSet

sendbuf chan SocketMessage // send buffer
ping chan struct{}
quit chan struct{}
}

Expand All @@ -43,25 +38,6 @@ func (c *client) MarshalLogObject(encoder zapcore.ObjectEncoder) error {
return nil
}

func (c *client) heartbeat() {

c.conn.SetPongHandler(func(appData string) error {
c.active = true
return nil
})

for {
if !c.active {
c.terminate(fmt.Errorf("heartbeat fail"))
return
}
c.active = false

_ = c.conn.WriteControl(websocket.PingMessage, []byte("ping"), time.Now().Add(5*time.Second))
<-time.After(10 * time.Second)
}
}

func (c *client) read() {
for {
_, m, err := c.conn.ReadMessage()
Expand All @@ -88,27 +64,16 @@ func (c *client) read() {
func (c *client) write() {
for {
select {
case message, more := <-c.sendbuf:
if !more {
return
}
case message := <-c.sendbuf:
m := new(bytes.Buffer)
if err := json.NewEncoder(m).Encode(message); err != nil {
log.Warn("sending malformed text message", zap.Error(err))
return
continue
}
err := c.conn.WriteMessage(websocket.TextMessage, m.Bytes())
if err != nil {
log.Error("client write error", err, zap.Any("client", c), zap.Any("message", message))
c.terminate(err)
return
}
case _, more := <-c.ping:
if !more {
return
}
if err := c.conn.WriteControl(websocket.PingMessage, []byte("ping"), time.Now().Add(5*time.Second)); err != nil {
log.Error("client ping error", err)
continue
}
case <-c.quit:
return
Expand All @@ -121,16 +86,13 @@ func (c *client) send(message SocketMessage) {
select {
case c.sendbuf <- message:
default:
metrics.IncSendBlocking(len(c.sendbuf))
log.Error("sending to client blocked", fmt.Errorf("sendbuf full"), zap.Any("client", c), zap.Any("len(sendbuf)", len(c.sendbuf)), zap.Any("message", message))
metrics.IncSendBlocking()
log.Error("client sendbuf full", fmt.Errorf(""), zap.Any("client", c), zap.Any("len(sendbuf)", len(c.sendbuf)), zap.Any("message", message))
}
}

func (c *client) terminate(reason error) {
if c.terminated.CompareAndSwap(false, true) {
c.active = false
c.quit <- struct{}{}
c.conn.Close()
c.ws.unregister <- ClientUnregisterEvent{client: c, reason: reason}
}
c.quit <- struct{}{}
c.conn.Close()
c.ws.unregister <- ClientUnregisterEvent{client: c, reason: reason}
}
12 changes: 5 additions & 7 deletions relay/wsconn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,11 @@ func TestSendChanWithNoReceiver(t *testing.T) {
go func() {
defer wg.Done()
for {
select {
case i := <-send:
fmt.Printf("received: %v\n", i)
if i == 3 {
fmt.Printf("receiving routine exit\n")
return
}
i := <-send
fmt.Printf("received: %v\n", i)
if i == 3 {
fmt.Printf("receiving routine exit\n")
return
}
}
}()
Expand Down
8 changes: 0 additions & 8 deletions relay/wshandlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,6 @@ import (

type WsMessageHandler func(*WsServer, SocketMessage)

var (
messageHandlers = map[MessageType]WsMessageHandler{
Pub: (*WsServer).pubMessage,
Sub: (*WsServer).subMessage,
Ping: (*WsServer).handlePingMessage,
}
)

func (ws *WsServer) pubMessage(message SocketMessage) {
topic := message.Topic
publisher := message.client
Expand Down
22 changes: 9 additions & 13 deletions relay/wsserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,6 @@ type WsServer struct {
register chan *client
unregister chan ClientUnregisterEvent

// session maintenance
pendingSessions *SortedPendingSessions

redisConn *redis.Client
redisSubConn *redis.PubSub

Expand All @@ -43,15 +40,14 @@ func NewWSServer(config *config.Config) *WsServer {
ws := &WsServer{
config: &config.WsServerConfig, // config

clients: make(map[*client]struct{}),
register: make(chan *client, 1024),
unregister: make(chan ClientUnregisterEvent, 1024),
pendingSessions: NewSortedPendingSessions(),
clients: make(map[*client]struct{}),
register: make(chan *client, 4096),
unregister: make(chan ClientUnregisterEvent, 4096),

publishers: NewTopicClientSet(),
subscribers: NewTopicClientSet(),

localCh: make(chan SocketMessage, 1024),
localCh: make(chan SocketMessage, 2),
}
ws.redisConn = redis.NewClient(&redis.Options{
Addr: config.RedisServerConfig.ServerAddr,
Expand All @@ -73,20 +69,18 @@ func (ws *WsServer) NewClientConn(w http.ResponseWriter, r *http.Request) {
client := &client{
conn: conn,
id: generateRandomBytes16(),
active: true,
ws: ws,
pubTopics: NewTopicSet(),
subTopics: NewTopicSet(),
sendbuf: make(chan SocketMessage, 8),
ping: make(chan struct{}, 8),
sendbuf: make(chan SocketMessage, 256),
quit: make(chan struct{}),
}

ws.register <- client

go client.read()
go client.write()
go client.heartbeat()
//go client.heartbeat()
}

func (ws *WsServer) Run() {
Expand All @@ -97,6 +91,9 @@ func (ws *WsServer) Run() {
for {
select {
case message := <-ws.localCh:
if _, ok := ws.clients[message.client]; !ok {
metrics.IncMessageFromClosed()
}
// local message could be "pub", "sub" or "ack" or "ping"
// pub/sub message handler may contain time-consuming operations(e.g. read/write redis)
// so put them in separate goroutine to avoid blocking wsserver main loop
Expand Down Expand Up @@ -131,7 +128,6 @@ func (ws *WsServer) Run() {
log.Info("forward to subscriber", zap.Any("client", subscriber), zap.Any("message", message))
subscriber.send(message)
}

continue
}

Expand Down