From a38646dee7e77e826cc218d45a2818a86959cf23 Mon Sep 17 00:00:00 2001 From: Edsko de Vries Date: Sat, 23 Nov 2024 08:27:19 +0100 Subject: [PATCH] Fix leak in H2 manager See `ManagedThreads`. Closes #154. --- Network/HTTP2/H2/Manager.hs | 50 ++++++++++++++++++++++++++++--------- 1 file changed, 38 insertions(+), 12 deletions(-) diff --git a/Network/HTTP2/H2/Manager.hs b/Network/HTTP2/H2/Manager.hs index 017665c7..2c29341e 100644 --- a/Network/HTTP2/H2/Manager.hs +++ b/Network/HTTP2/H2/Manager.hs @@ -19,7 +19,7 @@ import Control.Concurrent.STM import Control.Exception import qualified Control.Exception as E import Data.Foldable -import Data.Map (Map) +import Data.Map.Strict (Map) import qualified Data.Map.Strict as Map import qualified System.TimeManager as T @@ -28,9 +28,14 @@ import Imports ---------------------------------------------------------------- -- | Manager to manage the thread and the timer. -data Manager = Manager T.Manager (TVar ManagedThreads) +data Manager = Manager T.Manager ManagedThreads -type ManagedThreads = Map ThreadId TimeoutHandle +-- | The set of managed threads +-- +-- This is a newtype to ensure that this is always updated strictly. +newtype ManagedThreads = WrapManagedThreads + { unwrapManagedThreads :: TVar (Map ThreadId TimeoutHandle) + } ---------------------------------------------------------------- @@ -49,7 +54,7 @@ cancelTimeout ThreadWithoutTimeout = return () -- by 'setAction'. This allows that the action can include -- the manager itself. start :: T.Manager -> IO Manager -start timmgr = Manager timmgr <$> newTVarIO Map.empty +start timmgr = Manager timmgr <$> newManagedThreads ---------------------------------------------------------------- @@ -70,10 +75,7 @@ stopAfter :: Manager -> IO a -> (Maybe SomeException -> IO ()) -> IO a stopAfter (Manager _timmgr var) action cleanup = do mask $ \unmask -> do ma <- try $ unmask action - m <- atomically $ do - m0 <- readTVar var - writeTVar var Map.empty - return m0 + m <- atomically $ modifyManagedThreads var (\ts -> (Map.empty, ts)) forM_ (Map.elems m) cancelTimeout let er = either Just (const Nothing) ma forM_ (Map.keys m) $ \tid -> @@ -102,17 +104,17 @@ forkManagedUnmask (Manager _timmgr var) label io = void $ mask_ $ forkIOWithUnmask $ \unmask -> E.handle ignore $ do labelMe label tid <- myThreadId - atomically $ modifyTVar var $ Map.insert tid ThreadWithoutTimeout + atomically $ modifyManagedThreads_ var $ Map.insert tid ThreadWithoutTimeout -- We catch the exception and do not rethrow it: we don't want the -- exception printed to stderr. io unmask `catch` ignore - atomically $ modifyTVar var $ Map.delete tid + atomically $ modifyManagedThreads_ var $ Map.delete tid where ignore (E.SomeException _) = return () waitCounter0 :: Manager -> IO () waitCounter0 (Manager _timmgr var) = atomically $ do - m <- readTVar var + m <- getManagedThreads var check (Map.size m == 0) ---------------------------------------------------------------- @@ -122,5 +124,29 @@ withTimeout (Manager timmgr var) action = T.withHandleKillThread timmgr (return ()) $ \th -> do tid <- myThreadId -- overriding ThreadWithoutTimeout - atomically $ modifyTVar var $ Map.insert tid $ ThreadWithTimeout th + atomically $ modifyManagedThreads_ var $ Map.insert tid $ ThreadWithTimeout th action th + +---------------------------------------------------------------- + +newManagedThreads :: IO ManagedThreads +newManagedThreads = WrapManagedThreads <$> newTVarIO Map.empty + +getManagedThreads :: ManagedThreads -> STM (Map ThreadId TimeoutHandle) +getManagedThreads = readTVar . unwrapManagedThreads + +modifyManagedThreads + :: ManagedThreads + -> (Map ThreadId TimeoutHandle -> (Map ThreadId TimeoutHandle, a)) + -> STM a +modifyManagedThreads (WrapManagedThreads var) f = do + threads <- readTVar var + let (threads', result) = f threads + writeTVar var $! threads' -- strict update + return result + +modifyManagedThreads_ + :: ManagedThreads + -> (Map ThreadId TimeoutHandle -> Map ThreadId TimeoutHandle) + -> STM () +modifyManagedThreads_ var f = modifyManagedThreads var (\ts -> (f ts, ()))