Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix message acks on wrong rabbitmq channels #4358

Merged
merged 15 commits into from
Dec 11, 2024
Merged
1 change: 1 addition & 0 deletions changelog.d/3-bug-fixes/rabbitmq-acks
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Cannon does not attempt to restore a rabbitmq channel after it disconnects. This fixes a potential issue where a client would be able to ack a message on the wrong channel.
97 changes: 81 additions & 16 deletions integration/test/Test/Events.hs
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,19 @@ import API.Galley
import API.Gundeck
import qualified Control.Concurrent.Timeout as Timeout
import Control.Monad.Codensity
import Control.Monad.RWS (asks)
import Control.Monad.Trans.Class
import Control.Retry
import Data.ByteString.Conversion (toByteString')
import qualified Data.Text as Text
import Data.Timeout
import Network.AMQP.Extended
import Network.RabbitMqAdmin
import qualified Network.WebSockets as WS
import Notifications
import SetupHelpers
import Testlib.Prelude hiding (assertNoEvent)
import Testlib.ResourcePool (acquireResources)
import UnliftIO hiding (handle)

testConsumeEventsOneWebSocket :: (HasCallStack) => App ()
Expand All @@ -38,10 +42,10 @@ testConsumeEventsOneWebSocket = do
e %. "data.event.payload.0.type" `shouldMatch` "user.client-add"
e %. "data.event.payload.0.client.id" `shouldMatch` clientId
e %. "data.delivery_tag"
assertNoEvent ws
assertNoEvent_ ws

sendAck ws deliveryTag False
assertNoEvent ws
assertNoEvent_ ws

handle <- randomHandle
putHandle alice handle >>= assertSuccess
Expand Down Expand Up @@ -80,7 +84,7 @@ testConsumeEventsForDifferentUsers = do
e %. "data.event.payload.0.type" `shouldMatch` "user.client-add"
e %. "data.event.payload.0.client.id" `shouldMatch` clientId
e %. "data.delivery_tag"
assertNoEvent ws
assertNoEvent_ ws
sendAck ws deliveryTag False

testConsumeEventsWhileHavingLegacyClients :: (HasCallStack) => App ()
Expand Down Expand Up @@ -137,7 +141,7 @@ testConsumeEventsAcks = do
sendAck ws deliveryTag False

runCodensity (createEventsWebSocket alice clientId) $ \ws -> do
assertNoEvent ws
assertNoEvent_ ws

testConsumeEventsMultipleAcks :: (HasCallStack) => App ()
testConsumeEventsMultipleAcks = do
Expand All @@ -161,7 +165,7 @@ testConsumeEventsMultipleAcks = do
sendAck ws deliveryTag True

runCodensity (createEventsWebSocket alice clientId) $ \ws -> do
assertNoEvent ws
assertNoEvent_ ws

testConsumeEventsAckNewEventWithoutAckingOldOne :: (HasCallStack) => App ()
testConsumeEventsAckNewEventWithoutAckingOldOne = do
Expand Down Expand Up @@ -195,7 +199,7 @@ testConsumeEventsAckNewEventWithoutAckingOldOne = do
sendAck ws deliveryTagClientAdd False

runCodensity (createEventsWebSocket alice clientId) $ \ws -> do
assertNoEvent ws
assertNoEvent_ ws

testEventsDeadLettered :: (HasCallStack) => App ()
testEventsDeadLettered = do
Expand Down Expand Up @@ -229,7 +233,7 @@ testEventsDeadLettered = do
ackEvent ws e

-- We've consumed the whole queue.
assertNoEvent ws
assertNoEvent_ ws

testTransientEventsDoNotTriggerDeadLetters :: (HasCallStack) => App ()
testTransientEventsDoNotTriggerDeadLetters = do
Expand Down Expand Up @@ -257,7 +261,7 @@ testTransientEventsDoNotTriggerDeadLetters = do
sendTypingStatus alice selfConvId "started" >>= assertSuccess

runCodensity (createEventsWebSocket alice clientId) $ \ws -> do
assertNoEvent ws
assertNoEvent_ ws

testTransientEvents :: (HasCallStack) => App ()
testTransientEvents = do
Expand Down Expand Up @@ -296,7 +300,7 @@ testTransientEvents = do
e %. "data.event.payload.0.user.handle" `shouldMatch` handle
ackEvent ws e

assertNoEvent ws
assertNoEvent_ ws

testChannelLimit :: (HasCallStack) => App ()
testChannelLimit = withModifiedBackend
Expand All @@ -318,16 +322,46 @@ testChannelLimit = withModifiedBackend
lowerCodensity $ do
for_ clients $ \c -> do
ws <- createEventsWebSocket alice c
e <- Codensity $ \k -> assertEvent ws k
lift $ do
lift $ assertEvent ws $ \e -> do
e %. "data.event.payload.0.type" `shouldMatch` "user.client-add"
e %. "data.event.payload.0.client.id" `shouldMatch` c
e %. "data.delivery_tag"

-- the first client fails to connect because the server runs out of channels
do
ws <- createEventsWebSocket alice client0
lift $ assertNoEvent ws
lift $ assertNoEvent_ ws

testChannelKilled :: (HasCallStack) => App ()
testChannelKilled = lowerCodensity $ do
pool <- lift $ asks (.resourcePool)
[backend] <- acquireResources 1 pool
domain <- startDynamicBackend backend mempty
alice <- lift $ randomUser domain def
[c1, c2] <-
lift
$ replicateM 2
$ addClient alice def {acapabilities = Just ["consumable-notifications"]}
>>= getJSON 201
>>= (%. "id")
>>= asString

ws <- createEventsWebSocket alice c1
lift $ do
assertEvent ws $ \e -> do
e %. "data.event.payload.0.type" `shouldMatch` "user.client-add"
e %. "data.event.payload.0.client.id" `shouldMatch` c1
ackEvent ws e

assertEvent ws $ \e -> do
e %. "data.event.payload.0.type" `shouldMatch` "user.client-add"
e %. "data.event.payload.0.client.id" `shouldMatch` c2

recoverAll
(constantDelay 500_000 <> limitRetries 10)
(const (killConnection backend))

noEvent <- assertNoEvent ws
noEvent `shouldMatch` WebSocketDied

----------------------------------------------------------------------
-- helpers
Expand Down Expand Up @@ -422,15 +456,24 @@ assertEvent ws expectations = do
addFailureContext ("event:\n" <> pretty)
$ expectations e

assertNoEvent :: (HasCallStack) => EventWebSocket -> App ()
data NoEvent = NoEvent | WebSocketDied

instance ToJSON NoEvent where
toJSON NoEvent = toJSON "no-event"
toJSON WebSocketDied = toJSON "web-socket-died"

assertNoEvent :: (HasCallStack) => EventWebSocket -> App NoEvent
assertNoEvent ws = do
timeout 1_000_000 (readChan ws.events) >>= \case
Nothing -> pure ()
Just (Left _) -> pure ()
Nothing -> pure NoEvent
Just (Left _) -> pure WebSocketDied
Just (Right e) -> do
eventJSON <- prettyJSON e
assertFailure $ "Did not expect event: \n" <> eventJSON

assertNoEvent_ :: (HasCallStack) => EventWebSocket -> App ()
assertNoEvent_ = void . assertNoEvent

consumeAllEvents :: EventWebSocket -> App ()
consumeAllEvents ws = do
timeout 1_000_000 (readChan ws.events) >>= \case
Expand All @@ -442,3 +485,25 @@ consumeAllEvents ws = do
Just (Right e) -> do
ackEvent ws e
consumeAllEvents ws

killConnection :: (HasCallStack) => BackendResource -> App ()
killConnection backend = do
rc <- asks (.rabbitMQConfig)
let opts =
RabbitMqAdminOpts
{ host = rc.host,
port = 0,
adminPort = fromIntegral rc.adminPort,
vHost = Text.pack backend.berVHost,
tls = Just $ RabbitMqTlsOpts Nothing True
}
servantClient <- liftIO $ mkRabbitMqAdminClientEnv opts
name <- do
connections <- liftIO $ listConnectionsByVHost servantClient opts.vHost
connection <-
assertOne
[ c | c <- connections, c.userProvidedName == Just (Text.pack "pool 0")
]
pure connection.name

void $ liftIO $ deleteConnection servantClient name
30 changes: 29 additions & 1 deletion libs/extended/src/Network/RabbitMqAdmin.hs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
-- | Perhaps this module should be a separate package and published to hackage.
module Network.RabbitMqAdmin where

import Data.Aeson
import Data.Aeson as Aeson
import Imports
import Servant
import Servant.Client
Expand Down Expand Up @@ -31,6 +31,19 @@ data AdminAPI route = AdminAPI
:> "queues"
:> Capture "vhost" VHost
:> Capture "queue" QueueName
:> DeleteNoContent,
listConnectionsByVHost ::
route
:- "api"
:> "vhosts"
:> Capture "vhost" Text
:> "connections"
:> Get '[JSON] [Connection],
deleteConnection ::
route
:- "api"
:> "connections"
:> Capture "name" Text
:> DeleteNoContent
}
deriving (Generic)
Expand All @@ -43,13 +56,28 @@ data AuthenticatedAPI route = AuthenticatedAPI
}
deriving (Generic)

jsonOptions :: Aeson.Options
jsonOptions = defaultOptions {fieldLabelModifier = camelTo2 '_'}

data Queue = Queue {name :: Text, vhost :: Text}
deriving (Show, Eq, Generic)

instance FromJSON Queue

instance ToJSON Queue

data Connection = Connection
{ userProvidedName :: Maybe Text,
name :: Text
}
deriving (Eq, Show, Generic)

instance FromJSON Connection where
parseJSON = genericParseJSON jsonOptions

instance ToJSON Connection where
toJSON = genericToJSON jsonOptions

adminClient :: BasicAuthData -> AdminAPI (AsClientT ClientM)
adminClient ba = fromServant $ clientWithAuth.api ba
where
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -348,7 +348,9 @@ mockApi :: MockRabbitMqAdmin -> AdminAPI (AsServerT Servant.Handler)
mockApi mockAdmin =
AdminAPI
{ listQueuesByVHost = mockListQueuesByVHost mockAdmin,
deleteQueue = mockListDeleteQueue mockAdmin
deleteQueue = mockListDeleteQueue mockAdmin,
listConnectionsByVHost = mockListConnectionsByVHost mockAdmin,
deleteConnection = mockDeleteConnection mockAdmin
}

mockListQueuesByVHost :: MockRabbitMqAdmin -> Text -> Servant.Handler [Queue]
Expand All @@ -362,6 +364,12 @@ mockListDeleteQueue :: MockRabbitMqAdmin -> Text -> Text -> Servant.Handler NoCo
mockListDeleteQueue _ _ _ = do
pure NoContent

mockListConnectionsByVHost :: MockRabbitMqAdmin -> Text -> Servant.Handler [Connection]
mockListConnectionsByVHost _ _ = pure []

mockDeleteConnection :: MockRabbitMqAdmin -> Text -> Servant.Handler NoContent
mockDeleteConnection _ _ = pure NoContent

mockRabbitMqAdminApp :: MockRabbitMqAdmin -> Application
mockRabbitMqAdminApp mockAdmin = genericServe (mockApi mockAdmin)

Expand Down
24 changes: 19 additions & 5 deletions services/cannon/src/Cannon/RabbitMq.hs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import Control.Retry
import Data.ByteString.Conversion
import Data.List.Extra
import Data.Map qualified as Map
import Data.Text qualified as T
import Data.Timeout
import Imports hiding (threadDelay)
import Network.AMQP qualified as Q
Expand Down Expand Up @@ -59,7 +60,8 @@ data RabbitMqPool key = RabbitMqPool
data RabbitMqPoolOptions = RabbitMqPoolOptions
{ maxConnections :: Int,
maxChannels :: Int,
endpoint :: AmqpEndpoint
endpoint :: AmqpEndpoint,
retryEnabled :: Bool
}

createRabbitMqPool :: (Ord key) => RabbitMqPoolOptions -> Logger -> Codensity IO (RabbitMqPool key)
Expand Down Expand Up @@ -163,6 +165,9 @@ createConnection pool = mask_ $ do
void . async $ do
v <- race (takeMVar closedVar) (readMVar pool.deadVar)
when (isRight v) $
-- close connection and ignore exceptions
-- close connection and ignore exceptions

pcapriotti marked this conversation as resolved.
Show resolved Hide resolved
-- close connection and ignore exceptions
catch @SomeException (Q.closeConnection conn) $
\_ -> pure ()
Expand All @@ -176,6 +181,11 @@ createConnection pool = mask_ $ do

openConnection :: RabbitMqPool key -> IO Q.Connection
openConnection pool = do
-- This might not be the correct connection ID that will eventually be
-- assigned to this connection, since there are potential races with other
-- connections being opened at the same time. However, this is only used to
-- name the connection, and we only rely on names for tests, so it is fine.
connId <- readTVarIO pool.nextId
(username, password) <- readCredsFromEnv
recovering
rabbitMqRetryPolicy
Expand All @@ -199,7 +209,9 @@ openConnection pool = do
],
Q.coVHost = pool.opts.endpoint.vHost,
Q.coAuth = [Q.plain username password],
Q.coTLSSettings = fmap Q.TLSCustom mTlsSettings
Q.coTLSSettings = fmap Q.TLSCustom mTlsSettings,
-- the name is used by tests to identify pool connections
Q.coName = Just ("pool " <> T.pack (show connId))
}
)

Expand Down Expand Up @@ -233,11 +245,11 @@ createChannel pool queue key = do
(_, Just (Q.ConnectionClosedException {})) -> do
Log.info pool.logger $
Log.msg (Log.val "RabbitMQ connection was closed unexpectedly")
pure True
pure pool.opts.retryEnabled
_ -> do
unless (fromException e == Just AsyncCancelled) $
logException pool.logger "RabbitMQ channel closed" e
pure True
pure pool.opts.retryEnabled
putMVar closedVar retry

let manageChannel = do
Expand All @@ -258,7 +270,9 @@ createChannel pool queue key = do
putMVar inner chan
void $ liftIO $ Q.consumeMsgs chan queue Q.Ack $ \(message, envelope) -> do
putMVar msgVar (Just (message, envelope))
takeMVar closedVar
retry <- takeMVar closedVar
void $ takeMVar inner
pure retry

when retry manageChannel

Expand Down
3 changes: 2 additions & 1 deletion services/cannon/src/Cannon/Types.hs
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,8 @@ mkEnv external o cs l d conns p g t endpoint = do
RabbitMqPoolOptions
{ endpoint = endpoint,
maxConnections = o ^. rabbitMqMaxConnections,
maxChannels = o ^. rabbitMqMaxChannels
maxChannels = o ^. rabbitMqMaxChannels,
retryEnabled = False
}
pool <- createRabbitMqPool poolOpts l
let wsEnv =
Expand Down
Loading