Skip to content

Commit

Permalink
fix(ws): deadlock on unsubscribe when epoll disabled (#982)
Browse files Browse the repository at this point in the history
We have used the unsubscription channel of the epoll implementation for
both paths, although epoll can be disabled. This led to a deadlock.

Related Cosmo PR: wundergraph/cosmo#1380
  • Loading branch information
StarpTech authored Nov 16, 2024
1 parent aff0506 commit 2fad683
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 61 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,26 @@ import (

const ackWaitTimeout = 30 * time.Second

// epollState holds all state that is owned and mutated by the epoll run loop.
// It is kept separate from subscriptionClient so that none of these fields
// exist (and none of these channels are sent on) when epoll is disabled.
type epollState struct {
// connections is a map of fd -> connection to keep track of all active connections
connections map[int]*connection
// hasConnections is set to true when the first connection is added and back to
// false once the last connection is removed; readable without holding the loop
hasConnections atomic.Bool
// triggers is a map of subscription id -> fd to easily look up the connection for a subscription id
triggers map[uint64]int

// clientUnsubscribe is a channel to signal to the epoll run loop that a client needs to be unsubscribed
clientUnsubscribe chan uint64
// addConn is a channel to signal to the epoll run loop that a new connection needs to be added
addConn chan *connection
// waitForEventsTicker is the ticker for the epoll run loop
// it is used to prevent busy waiting and to limit the CPU usage
// instead of polling the epoll instance all the time, we wait until the next tick to throttle the epoll loop
waitForEventsTicker *time.Ticker

// waitForEventsTick is the channel to receive the tick from the waitForEventsTicker
// it is nil while there are no connections, so the epoll loop select never fires
waitForEventsTick <-chan time.Time
}

// subscriptionClient allows running multiple subscriptions via the same WebSocket or SSE connection
// It takes care of de-duplicating connections to the same origin under certain circumstances
// If Hash(URL,Body,Headers) results in the same value, an existing connection is re-used
Expand All @@ -47,23 +67,7 @@ type subscriptionClient struct {

epoll epoller.Poller
epollConfig EpollConfiguration

// connections is a map of fd -> connection to keep track of all active connections
connections map[int]*connection
hasConnections atomic.Bool
// triggers is a map of subscription id -> fd to easily look up the connection for a subscription id
triggers map[uint64]int

// clientUnsubscribe is a channel to signal to the epoll run loop that a client needs to be unsubscribed
clientUnsubscribe chan uint64
// addConn is a channel to signal to the epoll run loop that a new connection needs to be added
addConn chan *connection
// waitForEventsTicker is the ticker for the epoll run loop
// it is used to prevent busy waiting and to limit the CPU usage
// instead of polling the epoll instance all the time, we wait until the next tick to throttle the epoll loop
waitForEventsTicker *time.Ticker
// waitForEventsTick is the channel to receive the tick from the waitForEventsTicker
waitForEventsTick <-chan time.Time
epollState *epollState
}

func (c *subscriptionClient) SubscribeAsync(ctx *resolve.Context, id uint64, options GraphQLSubscriptionOptions, updater resolve.SubscriptionUpdater) error {
Expand All @@ -81,7 +85,12 @@ func (c *subscriptionClient) SubscribeAsync(ctx *resolve.Context, id uint64, opt
}

// Unsubscribe signals the epoll run loop that the subscription with the given
// id must be torn down.
//
// When epoll is disabled there is no run loop consuming the clientUnsubscribe
// channel, so sending on it would block forever; we return early in that case
// to prevent a deadlock. The send must therefore only happen AFTER the guard.
func (c *subscriptionClient) Unsubscribe(id uint64) {
	// if we don't have epoll, we don't have a channel consumer of the clientUnsubscribe channel
	// we have to return to prevent a deadlock
	if c.epoll == nil {
		return
	}
	c.epollState.clientUnsubscribe <- id
}

type InvalidWsSubprotocolError struct {
Expand Down Expand Up @@ -195,16 +204,19 @@ func NewGraphQLSubscriptionClient(httpClient, streamingClient *http.Client, engi
epollConfig: op.epollConfiguration,
}
if !op.epollConfiguration.Disable {
client.connections = make(map[int]*connection)
client.triggers = make(map[uint64]int)
client.clientUnsubscribe = make(chan uint64, op.epollConfiguration.BufferSize)
client.addConn = make(chan *connection, op.epollConfiguration.BufferSize)
// this is not needed, but we want to make it explicit that we're starting with nil as the tick channel
// reading from nil channels blocks forever, which allows us to prevent the epoll loop from starting
// once we add the first connection, we start the ticker and set the tick channel
// after the last connection is removed, we set the tick channel to nil again
// this way we can start and stop the epoll loop dynamically
client.waitForEventsTick = nil
client.epollState = &epollState{
connections: make(map[int]*connection),
triggers: make(map[uint64]int),
clientUnsubscribe: make(chan uint64, op.epollConfiguration.BufferSize),
addConn: make(chan *connection, op.epollConfiguration.BufferSize),
// this is not needed, but we want to make it explicit that we're starting with nil as the tick channel
// reading from nil channels blocks forever, which allows us to prevent the epoll loop from starting
// once we add the first connection, we start the ticker and set the tick channel
// after the last connection is removed, we set the tick channel to nil again
// this way we can start and stop the epoll loop dynamically
waitForEventsTick: nil,
}

// ignore error is ok, it means that epoll is not supported, which is handled gracefully by the client
epoll, _ := epoller.NewPoller(op.epollConfiguration.BufferSize, op.epollConfiguration.TickInterval)
if epoll != nil {
Expand Down Expand Up @@ -323,7 +335,7 @@ func (c *subscriptionClient) asyncSubscribeWS(requestContext, engineContext cont
fd := epoller.SocketFD(conn.conn)
conn.id, conn.fd = id, fd
// submit the connection to the epoll run loop
c.addConn <- conn
c.epollState.addConn <- conn
return nil
}

Expand Down Expand Up @@ -636,16 +648,16 @@ func (c *subscriptionClient) runEpoll(ctx context.Context) {
// if the engine context is done, we close the epoll loop
case <-done:
return
case conn := <-c.addConn:
case conn := <-c.epollState.addConn:
c.handleAddConn(conn)
case id := <-c.clientUnsubscribe:
case id := <-c.epollState.clientUnsubscribe:
c.handleClientUnsubscribe(id)
// while len(c.connections) == 0, this channel is nil, so we will never try to wait for epoll events
// this is important to prevent busy waiting
// once we add the first connection, we start the ticker and set the tick channel
// the ticker ensures that we don't poll the epoll instance all the time,
// but at most every TickInterval
case <-c.waitForEventsTick:
case <-c.epollState.waitForEventsTick:
events, err := c.epoll.Wait(c.epollConfig.WaitForNumEvents)
if err != nil {
c.log.Error("epoll.Wait", abstractlogger.Error(err))
Expand All @@ -656,7 +668,7 @@ func (c *subscriptionClient) runEpoll(ctx context.Context) {

for i := range events {
fd := epoller.SocketFD(events[i])
conn, ok := c.connections[fd]
conn, ok := c.epollState.connections[fd]
if !ok {
// Should never happen
panic(fmt.Sprintf("connection with fd %d not found", fd))
Expand Down Expand Up @@ -684,9 +696,9 @@ func (c *subscriptionClient) runEpoll(ctx context.Context) {
}
// we decrease the number of events we're waiting for to eventually break the loop
waitForEvents--
case conn := <-c.addConn:
case conn := <-c.epollState.addConn:
c.handleAddConn(conn)
case id := <-c.clientUnsubscribe:
case id := <-c.epollState.clientUnsubscribe:
c.handleClientUnsubscribe(id)
case <-done:
return
Expand All @@ -698,10 +710,10 @@ func (c *subscriptionClient) runEpoll(ctx context.Context) {

func (c *subscriptionClient) close() {
defer c.log.Debug("subscriptionClient.close", abstractlogger.String("reason", "epoll closed by context"))
if c.waitForEventsTicker != nil {
c.waitForEventsTicker.Stop()
if c.epollState.waitForEventsTicker != nil {
c.epollState.waitForEventsTicker.Stop()
}
for _, conn := range c.connections {
for _, conn := range c.epollState.connections {
_ = c.epoll.Remove(conn.conn)
conn.handler.ServerClose()
}
Expand All @@ -719,52 +731,52 @@ func (c *subscriptionClient) handleAddConn(conn *connection) {
conn.handler.ServerClose()
return
}
c.connections[conn.fd] = conn
c.triggers[conn.id] = conn.fd
c.epollState.connections[conn.fd] = conn
c.epollState.triggers[conn.id] = conn.fd
// when we previously had 0 connections, we will have 1 connection now
// this means we need to start the ticker so that we get epoll events
if len(c.connections) == 1 {
c.waitForEventsTicker = time.NewTicker(c.epollConfig.TickInterval)
c.waitForEventsTick = c.waitForEventsTicker.C
c.hasConnections.Store(true)
if len(c.epollState.connections) == 1 {
c.epollState.waitForEventsTicker = time.NewTicker(c.epollConfig.TickInterval)
c.epollState.waitForEventsTick = c.epollState.waitForEventsTicker.C
c.epollState.hasConnections.Store(true)
}
}

func (c *subscriptionClient) handleClientUnsubscribe(id uint64) {
fd, ok := c.triggers[id]
fd, ok := c.epollState.triggers[id]
if !ok {
return
}
delete(c.triggers, id)
conn, ok := c.connections[fd]
delete(c.epollState.triggers, id)
conn, ok := c.epollState.connections[fd]
if !ok {
return
}
delete(c.connections, fd)
delete(c.epollState.connections, fd)
_ = c.epoll.Remove(conn.conn)
conn.handler.ClientClose()
// if we have no connections left, we stop the ticker
if len(c.connections) == 0 {
c.waitForEventsTicker.Stop()
c.waitForEventsTick = nil
c.hasConnections.Store(false)
if len(c.epollState.connections) == 0 {
c.epollState.waitForEventsTicker.Stop()
c.epollState.waitForEventsTick = nil
c.epollState.hasConnections.Store(false)
}
}

func (c *subscriptionClient) handleServerUnsubscribe(fd int) {
conn, ok := c.connections[fd]
conn, ok := c.epollState.connections[fd]
if !ok {
return
}
delete(c.connections, fd)
delete(c.triggers, conn.id)
delete(c.epollState.connections, fd)
delete(c.epollState.triggers, conn.id)
_ = c.epoll.Remove(conn.conn)
conn.handler.ServerClose()
// if we have no connections left, we stop the ticker
if len(c.connections) == 0 {
c.waitForEventsTicker.Stop()
c.waitForEventsTick = nil
c.hasConnections.Store(false)
if len(c.epollState.connections) == 0 {
c.epollState.waitForEventsTicker.Stop()
c.epollState.waitForEventsTick = nil
c.epollState.hasConnections.Store(false)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -907,7 +907,7 @@ func TestAsyncSubscribe(t *testing.T) {
return true
}, time.Second*5, time.Millisecond*10, "server did not close")
time.Sleep(time.Second)
assert.Equal(t, false, client.hasConnections.Load())
assert.Equal(t, false, client.epollState.hasConnections.Load())
})
t.Run("forever timeout", func(t *testing.T) {
t.Parallel()
Expand Down Expand Up @@ -1103,7 +1103,7 @@ func TestAsyncSubscribe(t *testing.T) {
return true
}, time.Second, time.Millisecond*10, "server did not close")
serverCancel()
assert.Equal(t, false, client.hasConnections.Load())
assert.Equal(t, false, client.epollState.hasConnections.Load())
})
t.Run("error object", func(t *testing.T) {
t.Parallel()
Expand Down

0 comments on commit 2fad683

Please sign in to comment.