Skip to content
This repository was archived by the owner on Dec 28, 2024. It is now read-only.

Commit

Permalink
fix: hub registration race conditions
Browse files Browse the repository at this point in the history
  • Loading branch information
palkan committed Mar 4, 2021
1 parent 1eb99c4 commit fd3d32d
Show file tree
Hide file tree
Showing 5 changed files with 94 additions and 53 deletions.
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,12 @@

## master

- Fix race conditions in Hub. ([@palkan][])

Use a single channel for register/unregister and subscribe/unsubscribe to make order of
execution deterministic. Since `select .. case` chooses channels randomly, we may hit the situation when registration is added
after disconnection (_un-registration_).

- Add `sid=xxx` to RPC logs. ([@palkan][])

## 1.0.3 (2021-01-05)
Expand Down
85 changes: 62 additions & 23 deletions node/hub.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,20 @@ import (
"github.com/apex/log"
)

// SubscriptionInfo contains information about session-channel(-stream) subscription
type SubscriptionInfo struct {
// HubSubscription contains information about session-channel(-stream) subscription
type HubSubscription struct {
event string
session string
stream string
identifier string
}

// HubRegistration represents registration event ("add" or "remove")
type HubRegistration struct {
event string
session *Session
}

// Hub stores all the sessions and the corresponding subscriptions info
type Hub struct {
// Registered sessions
Expand All @@ -38,16 +45,10 @@ type Hub struct {
disconnect chan *common.RemoteDisconnectMessage

// Register requests from the sessions
register chan *Session

// Unregister requests from sessions
unregister chan *Session
register chan HubRegistration

// Subscribe requests to streams
subscribe chan *SubscriptionInfo

// Unsubscribe requests from streams
unsubscribe chan *SubscriptionInfo
subscribe chan HubSubscription

// Control channel to shutdown hub
shutdown chan struct{}
Expand All @@ -64,10 +65,8 @@ func NewHub() *Hub {
return &Hub{
broadcast: make(chan *common.StreamMessage, 256),
disconnect: make(chan *common.RemoteDisconnectMessage, 128),
register: make(chan *Session, 128),
unregister: make(chan *Session, 2048),
subscribe: make(chan *SubscriptionInfo, 128),
unsubscribe: make(chan *SubscriptionInfo, 128),
register: make(chan HubRegistration, 2048),
subscribe: make(chan HubSubscription, 128),
sessions: make(map[string]*Session),
identifiers: make(map[string]map[string]bool),
streams: make(map[string]map[string]map[string]bool),
Expand All @@ -82,17 +81,17 @@ func (h *Hub) Run() {
h.done.Add(1)
for {
select {
case s := <-h.register:
h.addSession(s)

case s := <-h.unregister:
h.removeSession(s)
case r := <-h.register:
if r.event == "add" {
h.addSession(r.session)
} else {
h.removeSession(r.session)
}

case subinfo := <-h.subscribe:
h.subscribeSession(subinfo.session, subinfo.stream, subinfo.identifier)

case subinfo := <-h.unsubscribe:
if subinfo.stream == "" {
if subinfo.event == "add" {
h.subscribeSession(subinfo.session, subinfo.stream, subinfo.identifier)
} else if subinfo.event == "removeAll" {
h.unsubscribeSessionFromChannel(subinfo.session, subinfo.identifier)
} else {
h.unsubscribeSession(subinfo.session, subinfo.stream, subinfo.identifier)
Expand All @@ -111,6 +110,46 @@ func (h *Hub) Run() {
}
}

// AddSession enqueues sessions registration
func (h *Hub) AddSession(s *Session) {
h.register <- HubRegistration{event: "add", session: s}
}

// RemoveSession enqueues session un-registration
func (h *Hub) RemoveSession(s *Session) {
h.register <- HubRegistration{event: "remove", session: s}
}

// AddSubscription enqueues adding a subscription for session-identifier pair to the hub
func (h *Hub) AddSubscription(sid string, identifier string, stream string) {
h.subscribe <- HubSubscription{event: "add", session: sid, identifier: identifier, stream: stream}
}

// RemoveSubscription enqueues removing a subscription for session-identifier pair from the hub
func (h *Hub) RemoveSubscription(sid string, identifier string, stream string) {
h.subscribe <- HubSubscription{event: "remove", session: sid, identifier: identifier, stream: stream}
}

// RemoveAllSubscriptions enqueues removing all subscription for session-identifier pair from the hub
func (h *Hub) RemoveAllSubscriptions(sid string, identifier string) {
h.subscribe <- HubSubscription{event: "removeAll", session: sid, identifier: identifier}
}

// Broadcast enqueues data broadcasting to a stream
func (h *Hub) Broadcast(stream string, data string) {
h.broadcast <- &common.StreamMessage{Stream: stream, Data: data}
}

// BroadcastMessage enqueues broadcasting a pre-built StreamMessage
func (h *Hub) BroadcastMessage(msg *common.StreamMessage) {
h.broadcast <- msg
}

// RemoteDisconnect enqueues remote disconnect command
func (h *Hub) RemoteDisconnect(msg *common.RemoteDisconnectMessage) {
h.disconnect <- msg
}

// Shutdown sends shutdown command to hub
func (h *Hub) Shutdown() {
h.shutdown <- struct{}{}
Expand Down
23 changes: 11 additions & 12 deletions node/hub_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import (
"testing"
"time"

"github.com/anycable/anycable-go/common"
"github.com/stretchr/testify/assert"
)

Expand All @@ -23,7 +22,7 @@ func TestUnsubscribeRaceConditions(t *testing.T) {
hub.addSession(session2)
hub.subscribeSession("321", "test", "test_channel")

hub.broadcast <- &common.StreamMessage{Stream: "test", Data: "hello"}
hub.Broadcast("test", "hello")

done := make(chan bool)
timer := time.After(100 * time.Millisecond)
Expand All @@ -43,13 +42,13 @@ func TestUnsubscribeRaceConditions(t *testing.T) {
assert.Equal(t, 2, hub.Size(), "Connections size must be equal 2")

go func() {
hub.broadcast <- &common.StreamMessage{Stream: "test", Data: "pong"}
hub.unregister <- session
hub.broadcast <- &common.StreamMessage{Stream: "test", Data: "ping"}
hub.Broadcast("test", "pong")
hub.RemoveSession(session)
hub.Broadcast("test", "ping")
}()

go func() {
hub.broadcast <- &common.StreamMessage{Stream: "test", Data: "bye-bye"}
hub.Broadcast("test", "bye-bye")
}()

timer2 := time.After(2500 * time.Millisecond)
Expand Down Expand Up @@ -82,7 +81,7 @@ func TestUnsubscribeSession(t *testing.T) {
hub.subscribeSession("123", "test", "test_channel")
hub.subscribeSession("123", "test2", "test_channel")

hub.broadcast <- &common.StreamMessage{Stream: "test", Data: "\"hello\""}
hub.Broadcast("test", "\"hello\"")

timer := time.After(100 * time.Millisecond)
select {
Expand All @@ -94,7 +93,7 @@ func TestUnsubscribeSession(t *testing.T) {

hub.unsubscribeSession("123", "test", "test_channel")

hub.broadcast <- &common.StreamMessage{Stream: "test", Data: "\"goodbye\""}
hub.Broadcast("test", "\"goodbye\"")

timer = time.After(100 * time.Millisecond)
select {
Expand All @@ -103,7 +102,7 @@ func TestUnsubscribeSession(t *testing.T) {
t.Fatalf("Session shouldn't have received any messages but received: %v", string(msg.payload))
}

hub.broadcast <- &common.StreamMessage{Stream: "test2", Data: "\"bye\""}
hub.Broadcast("test2", "\"bye\"")

timer = time.After(100 * time.Millisecond)
select {
Expand All @@ -115,7 +114,7 @@ func TestUnsubscribeSession(t *testing.T) {

hub.unsubscribeSessionFromAllChannels("123")

hub.broadcast <- &common.StreamMessage{Stream: "test2", Data: "\"goodbye\""}
hub.Broadcast("test2", "\"goodbye\"")

timer = time.After(100 * time.Millisecond)
select {
Expand All @@ -137,7 +136,7 @@ func TestSubscribeSession(t *testing.T) {
t.Run("Subscribe to a single channel", func(t *testing.T) {
hub.subscribeSession("123", "test", "test_channel")

hub.broadcast <- &common.StreamMessage{Stream: "test", Data: "\"hello\""}
hub.Broadcast("test", "\"hello\"")

timer := time.After(100 * time.Millisecond)
select {
Expand All @@ -152,7 +151,7 @@ func TestSubscribeSession(t *testing.T) {
hub.subscribeSession("123", "test", "test_channel")
hub.subscribeSession("123", "test", "test_channel2")

hub.broadcast <- &common.StreamMessage{Stream: "test", Data: "\"hello twice\""}
hub.Broadcast("test", "\"hello twice\"")

done := make(chan bool)
received := []string{}
Expand Down
14 changes: 7 additions & 7 deletions node/node.go
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ func (n *Node) Authenticate(s *Session) (err error) {
s.Identifiers = res.Identifier
s.connected = true

n.hub.register <- s
n.hub.AddSession(s)
} else {
n.Metrics.Counter(metricsFailedAuths).Inc()
}
Expand Down Expand Up @@ -274,12 +274,12 @@ func (n *Node) Perform(s *Session, msg *common.Message) (err error) {
func (n *Node) Broadcast(msg *common.StreamMessage) {
n.Metrics.Counter(metricsBroadcastMsg).Inc()
n.log.Debugf("Incoming pubsub message: %v", msg)
n.hub.broadcast <- msg
n.hub.BroadcastMessage(msg)
}

// Disconnect adds session to disconnector queue and unregister session from hub
func (n *Node) Disconnect(s *Session) error {
n.hub.unregister <- s
n.hub.RemoveSession(s)
return n.disconnector.Enqueue(s)
}

Expand Down Expand Up @@ -307,7 +307,7 @@ func (n *Node) DisconnectNow(s *Session) error {
func (n *Node) RemoteDisconnect(msg *common.RemoteDisconnectMessage) {
n.Metrics.Counter(metricsBroadcastMsg).Inc()
n.log.Debugf("Incoming pubsub command: %v", msg)
n.hub.disconnect <- msg
n.hub.RemoteDisconnect(msg)
}

func transmit(s *Session, transmissions []string) {
Expand All @@ -322,16 +322,16 @@ func (n *Node) handleCommandReply(s *Session, msg *common.Message, reply *common
}

if reply.StopAllStreams {
n.hub.unsubscribe <- &SubscriptionInfo{session: s.UID, identifier: msg.Identifier}
n.hub.RemoveAllSubscriptions(s.UID, msg.Identifier)
} else if reply.StoppedStreams != nil {
for _, stream := range reply.StoppedStreams {
n.hub.unsubscribe <- &SubscriptionInfo{session: s.UID, stream: stream, identifier: msg.Identifier}
n.hub.RemoveSubscription(s.UID, msg.Identifier, stream)
}
}

if reply.Streams != nil {
for _, stream := range reply.Streams {
n.hub.subscribe <- &SubscriptionInfo{session: s.UID, stream: stream, identifier: msg.Identifier}
n.hub.AddSubscription(s.UID, msg.Identifier, stream)
}
}

Expand Down
19 changes: 8 additions & 11 deletions node/node_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -104,12 +104,11 @@ func TestSubscribe(t *testing.T) {
// Adds subscription to session
assert.Contains(t, session.subscriptions, "stream", "Session subsription must be set")

var subscription SubscriptionInfo
var subscription HubSubscription

// Expected to subscribe session to hub
select {
case sub := <-node.hub.subscribe:
subscription = *sub
case subscription = <-node.hub.subscribe:
default:
assert.Fail(t, "Expected hub to receive subscribe message but none was sent")
}
Expand Down Expand Up @@ -150,12 +149,11 @@ func TestUnsubscribe(t *testing.T) {
// Removes subscription from session
assert.NotContains(t, session.subscriptions, "test_channel", "Shouldn't contain test_channel")

var subscription SubscriptionInfo
var subscription HubSubscription

// Expected to subscribe session to hub
select {
case sub := <-node.hub.unsubscribe:
subscription = *sub
case subscription = <-node.hub.subscribe:
default:
assert.Fail(t, "Expected hub to receive unsubscribe message but none was sent")
}
Expand Down Expand Up @@ -223,12 +221,11 @@ func TestPerform(t *testing.T) {
t.Run("With stopped streams", func(t *testing.T) {
assert.Nil(t, node.Perform(session, &common.Message{Identifier: "test_channel", Data: "stop_stream"}))

var subscription SubscriptionInfo
var subscription HubSubscription

// Expected to subscribe session to hub
select {
case sub := <-node.hub.unsubscribe:
subscription = *sub
case subscription = <-node.hub.subscribe:
default:
assert.Fail(t, "Expected hub to receive unsubscribe message but none was sent")
return
Expand Down Expand Up @@ -258,8 +255,8 @@ func TestDisconnect(t *testing.T) {

// Expected to unregister session
select {
case s := <-node.hub.unregister:
assert.Equal(t, session, s, "Expected to disconnect session")
case info := <-node.hub.register:
assert.Equal(t, session, info.session, "Expected to disconnect session")
default:
assert.Fail(t, "Expected hub to receive unregister message but none was sent")
}
Expand Down

0 comments on commit fd3d32d

Please sign in to comment.