Skip to content

Commit

Permalink
removing unliftio
Browse files Browse the repository at this point in the history
  • Loading branch information
kazu-yamamoto committed Aug 27, 2024
1 parent 87e24e4 commit ad2e457
Show file tree
Hide file tree
Showing 45 changed files with 97 additions and 101 deletions.
6 changes: 3 additions & 3 deletions Network/QUIC/Client/Reader.hs
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,11 @@ module Network.QUIC.Client.Reader (
clientSocket,
) where

import Control.Concurrent
import qualified Control.Exception as E
import Data.List (intersect)
import Network.Socket (Socket, close, getSocketName)
import qualified Network.Socket.ByteString as NSB
import UnliftIO.Concurrent
import qualified UnliftIO.Exception as E

import Network.QUIC.Connection
import Network.QUIC.Connector
Expand All @@ -34,7 +34,7 @@ readerClient s0 conn = handleLogUnit logAction $ do
loop
where
wait = do
bound <- E.handleAny (\_ -> return False) $ do
bound <- E.handle (\(E.SomeException _) -> return False) $ do
_ <- getSocketName s0
return True
unless bound $ do
Expand Down
8 changes: 4 additions & 4 deletions Network/QUIC/Client/Run.hs
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@ module Network.QUIC.Client.Run (
migrate,
) where

import Control.Concurrent
import Control.Concurrent.Async
import qualified Control.Exception as E
import qualified Network.Socket as NS
import UnliftIO.Async
import UnliftIO.Concurrent
import qualified UnliftIO.Exception as E

import Network.QUIC.Client.Reader
import Network.QUIC.Closer
Expand Down Expand Up @@ -92,7 +92,7 @@ runClient conf client0 isICVN verInfo = do
case er of
Left () -> E.throwIO MustNotReached
Right r -> return r
ex <- E.trySyncOrAsync runThreads
ex <- E.try runThreads
sendFinal conn
closure conn ldcc ex
where
Expand Down
4 changes: 2 additions & 2 deletions Network/QUIC/Closer.hs
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@

module Network.QUIC.Closer (closure) where

import Control.Concurrent
import qualified Control.Exception as E
import Foreign.Marshal.Alloc
import Foreign.Ptr
import qualified Network.Socket as NS
import UnliftIO.Concurrent
import qualified UnliftIO.Exception as E

import Network.QUIC.Config
import Network.QUIC.Connection
Expand Down
4 changes: 2 additions & 2 deletions Network/QUIC/Connection/Crypto.hs
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@ module Network.QUIC.Connection.Crypto (
setCurrentKeyPhase,
) where

import Control.Concurrent.STM
import Network.TLS.QUIC
import UnliftIO.STM

import Network.QUIC.Connection.Misc
import Network.QUIC.Connection.Types
Expand Down Expand Up @@ -63,7 +63,7 @@ putOffCrypto Connection{..} lvl rpkt =
waitEncryptionLevel :: Connection -> EncryptionLevel -> IO ()
waitEncryptionLevel Connection{..} lvl = atomically $ do
l <- readTVar $ encryptionLevel connState
checkSTM (l >= lvl)
check (l >= lvl)

----------------------------------------------------------------

Expand Down
6 changes: 3 additions & 3 deletions Network/QUIC/Connection/Migration.hs
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,9 @@ module Network.QUIC.Connection.Migration (
validatePath,
) where

import Control.Concurrent.STM
import qualified Data.IntMap.Strict as IntMap
import qualified Data.Map.Strict as Map
import UnliftIO.STM

import Network.QUIC.Connection.Misc
import Network.QUIC.Connection.Queue
Expand Down Expand Up @@ -123,7 +123,7 @@ waitPeerCID conn@Connection{..} = do
let ref = peerCIDDB
db <- readTVar ref
mncid <- pickPeerCID conn
checkSTM $ isJust mncid
check $ isJust mncid
let u = usedCIDInfo db
setPeerCID conn (fromJust mncid) True
return u
Expand Down Expand Up @@ -313,7 +313,7 @@ isPathValidating Connection{..} = do
waitResponse :: Connection -> IO ()
waitResponse Connection{..} = atomically $ do
state <- readTVar migrationState
checkSTM (state == RecvResponse)
check (state == RecvResponse)
writeTVar migrationState NonMigration

checkResponse :: Connection -> PathData -> IO ()
Expand Down
4 changes: 2 additions & 2 deletions Network/QUIC/Connection/Misc.hs
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,10 @@ module Network.QUIC.Connection.Misc (
abortConnection,
) where

import Control.Concurrent
import qualified Control.Exception as E
import Network.Socket (Socket)
import System.Mem.Weak
import UnliftIO.Concurrent
import qualified UnliftIO.Exception as E

import Network.QUIC.Connection.Queue
import Network.QUIC.Connection.Timeout
Expand Down
2 changes: 1 addition & 1 deletion Network/QUIC/Connection/Queue.hs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
module Network.QUIC.Connection.Queue where

import UnliftIO.STM
import Control.Concurrent.STM

import Network.QUIC.Connection.Types
import Network.QUIC.Imports
Expand Down
2 changes: 1 addition & 1 deletion Network/QUIC/Connection/Role.hs
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,9 @@ module Network.QUIC.Connection.Role (
getCertificateChain,
) where

import Control.Concurrent
import qualified Crypto.Token as CT
import Data.X509 (CertificateChain)
import UnliftIO.Concurrent

import Network.QUIC.Connection.Misc
import Network.QUIC.Connection.Types
Expand Down
10 changes: 5 additions & 5 deletions Network/QUIC/Connection/State.hs
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@ module Network.QUIC.Connection.State (
checkAntiAmplificationFree,
) where

import Control.Concurrent.STM
import Network.Control
import UnliftIO.STM

import Network.QUIC.Connection.Types
import Network.QUIC.Connector
Expand Down Expand Up @@ -62,20 +62,20 @@ isConnection1RTTReady Connection{..} = atomically $ do
wait0RTTReady :: Connection -> IO ()
wait0RTTReady Connection{..} = atomically $ do
cs <- readTVar $ connectionState connState
checkSTM (cs >= ReadyFor0RTT)
check (cs >= ReadyFor0RTT)

-- | Waiting until 1-RTT data can be sent.
wait1RTTReady :: Connection -> IO ()
wait1RTTReady Connection{..} = atomically $ do
cs <- readTVar $ connectionState connState
checkSTM (cs >= ReadyFor1RTT)
check (cs >= ReadyFor1RTT)

-- | For clients, waiting until HANDSHAKE_DONE is received.
-- For servers, waiting until a TLS stack reports that the handshake is complete.
waitEstablished :: Connection -> IO ()
waitEstablished Connection{..} = atomically $ do
cs <- readTVar $ connectionState connState
checkSTM (cs >= Established)
check (cs >= Established)

----------------------------------------------------------------

Expand Down Expand Up @@ -134,7 +134,7 @@ waitAntiAmplificationFree conn@Connection{..} siz = do
ok <- checkAntiAmplificationFree conn siz
unless ok $ do
beforeAntiAmp connLDCC
atomically (checkAntiAmplificationFreeSTM conn siz >>= checkSTM)
atomically (checkAntiAmplificationFreeSTM conn siz >>= check)

-- setLossDetectionTimer is called eventually.

Expand Down
14 changes: 7 additions & 7 deletions Network/QUIC/Connection/Stream.hs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ module Network.QUIC.Connection.Stream (
checkStreamIdRoom,
) where

import UnliftIO.STM
import Control.Concurrent.STM

import Network.QUIC.Connection.Misc
import Network.QUIC.Connection.Types
Expand Down Expand Up @@ -45,7 +45,7 @@ get tvar = atomically $ do
conc@Concurrency{..} <- readTVar tvar
let streamType = currentStream .&. 0b11
StreamIdBase base = maxStreams
checkSTM (currentStream < base * 4 + streamType)
check (currentStream < base * 4 + streamType)
let currentStream' = currentStream + 4
writeTVar tvar conc{currentStream = currentStream'}
return currentStream
Expand All @@ -68,15 +68,15 @@ updatePeerStreamId conn sid = do
|| (isServer conn && isClientInitiatedBidirectional sid)
)
$ do
atomicModifyIORef'' (peerStreamId conn) check
atomicModifyIORef'' (peerStreamId conn) checkConc
when
( (isClient conn && isServerInitiatedUnidirectional sid)
|| (isServer conn && isClientInitiatedUnidirectional sid)
)
$ do
atomicModifyIORef'' (peerUniStreamId conn) check
atomicModifyIORef'' (peerUniStreamId conn) checkConc
where
check conc@Concurrency{..}
checkConc conc@Concurrency{..}
| currentStream < sid = conc{currentStream = sid}
| otherwise = conc

Expand Down Expand Up @@ -106,9 +106,9 @@ checkStreamIdRoom conn dir = do
let ref
| dir == Bidirectional = peerStreamId conn
| otherwise = peerUniStreamId conn
atomicModifyIORef' ref check
atomicModifyIORef' ref checkConc
where
check conc@Concurrency{..} =
checkConc conc@Concurrency{..} =
let StreamIdBase base = maxStreams
initialStreams = initialMaxStreamsBidi $ getMyParameters conn
cbase = currentStream !>>. 2
Expand Down
8 changes: 4 additions & 4 deletions Network/QUIC/Connection/Timeout.hs
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@ module Network.QUIC.Connection.Timeout (
delay,
) where

import Control.Concurrent
import qualified Control.Exception as E
import Network.QUIC.Event
import qualified System.Timeout as ST
import UnliftIO.Concurrent
import qualified UnliftIO.Exception as E

import Network.QUIC.Connection.Types
import Network.QUIC.Connector
Expand All @@ -25,7 +25,7 @@ fire conn (Microseconds microseconds) action = do
where
action' = do
alive <- getAlive conn
when alive action `E.catchSyncOrAsync` ignore
when alive action `E.catch` ignore

cfire :: Connection -> Microseconds -> TimeoutCallback -> IO (IO ())
cfire conn (Microseconds microseconds) action = do
Expand All @@ -36,7 +36,7 @@ cfire conn (Microseconds microseconds) action = do
where
action' = do
alive <- getAlive conn
when alive action `E.catchSyncOrAsync` ignore
when alive action `E.catch` ignore

delay :: Microseconds -> IO ()
delay (Microseconds microseconds) = threadDelay microseconds
4 changes: 2 additions & 2 deletions Network/QUIC/Connection/Types.hs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@

module Network.QUIC.Connection.Types where

import Control.Concurrent
import Control.Concurrent.STM
import qualified Crypto.Token as CT
import Data.Array.IO
import Data.ByteString.Internal
Expand All @@ -18,8 +20,6 @@ import Foreign.Ptr (nullPtr)
import Network.Control (Rate, RxFlow, TxFlow, newRate, newRxFlow, newTxFlow)
import Network.Socket (Cmsg, SockAddr, Socket)
import Network.TLS.QUIC
import UnliftIO.Concurrent
import UnliftIO.STM

import Network.QUIC.Config
import Network.QUIC.Connector
Expand Down
2 changes: 1 addition & 1 deletion Network/QUIC/Connector.hs
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
module Network.QUIC.Connector where

import Control.Concurrent.STM
import Data.IORef
import Network.QUIC.Types
import UnliftIO.STM

class Connector a where
getRole :: a -> Role
Expand Down
4 changes: 2 additions & 2 deletions Network/QUIC/Crypto/Keys.hs
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,10 @@ module Network.QUIC.Crypto.Keys (
headerProtectionKey,
) where

import qualified Control.Exception as E
import Network.TLS hiding (Version)
import Network.TLS.Extra.Cipher
import Network.TLS.QUIC
import qualified UnliftIO.Exception as E

import Network.QUIC.Crypto.Types
import Network.QUIC.Imports
Expand All @@ -35,7 +35,7 @@ initialSalt Version1 =
"\x38\x76\x2c\xf7\xf5\x59\x34\xb3\x4d\x17\x9a\xe6\xa4\xc8\x0c\xad\xcc\xbb\x7f\x0a"
initialSalt Version2 =
"\x0d\xed\xe3\xde\xf7\x00\xa6\xdb\x81\x93\x81\xbe\x6e\x26\x9d\xcb\xf9\xbd\x2e\xd9"
initialSalt (Version v) = E.impureThrow $ VersionIsUnknown v
initialSalt (Version v) = E.throw $ VersionIsUnknown v

initialSecrets :: Version -> CID -> TrafficSecrets InitialSecret
initialSecrets v c = (clientInitialSecret v c, serverInitialSecret v c)
Expand Down
6 changes: 3 additions & 3 deletions Network/QUIC/Exception.hs
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,15 @@ module Network.QUIC.Exception (
handleLogUnit,
) where

import qualified Control.Exception as E
import qualified GHC.IO.Exception as E
import qualified System.IO.Error as E
import qualified UnliftIO.Exception as E

import Network.QUIC.Logger

-- Catch all exceptions including asynchronous ones.
handleLogUnit :: DebugLogger -> IO () -> IO ()
handleLogUnit logAction action = action `E.catchSyncOrAsync` handler
handleLogUnit logAction action = action `E.catch` handler
where
handler :: E.SomeException -> IO ()
handler se = case E.fromException se of
Expand All @@ -25,7 +25,7 @@ handleLogUnit logAction action = action `E.catchSyncOrAsync` handler

-- Log and throw an exception
handleLogT :: DebugLogger -> IO a -> IO a
handleLogT logAction action = action `E.catchAny` handler
handleLogT logAction action = action `E.catch` handler
where
handler (E.SomeException e) = do
logAction $ bhow e
Expand Down
2 changes: 1 addition & 1 deletion Network/QUIC/Handshake.hs
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@

module Network.QUIC.Handshake where

import qualified Control.Exception as E
import Data.List (intersect)
import qualified Network.TLS as TLS
import Network.TLS.QUIC
import qualified UnliftIO.Exception as E

import Network.QUIC.Config
import Network.QUIC.Connection
Expand Down
6 changes: 3 additions & 3 deletions Network/QUIC/IO.hs
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
module Network.QUIC.IO where

import Control.Concurrent.STM
import qualified Control.Exception as E
import qualified Data.ByteString as BS
import Network.Control
import qualified UnliftIO.Exception as E
import UnliftIO.STM

import Network.QUIC.Connection
import Network.QUIC.Connector
Expand Down Expand Up @@ -110,7 +110,7 @@ checkBlocked s len wait = atomically $ do
connWindow = txWindowSize connFlow
minFlow = min strmWindow connWindow
n = min len minFlow
when wait $ checkSTM (n > 0)
when wait $ check (n > 0)
if n > 0
then return $ Right n
else do
Expand Down
2 changes: 1 addition & 1 deletion Network/QUIC/Packet/Decode.hs
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@ module Network.QUIC.Packet.Decode (
decodeStatelessResetToken,
) where

import qualified Control.Exception as E
import qualified Data.ByteString as BS
import qualified Data.ByteString.Short as Short
import qualified UnliftIO.Exception as E

import Network.QUIC.Imports
import Network.QUIC.Packet.Header
Expand Down
2 changes: 1 addition & 1 deletion Network/QUIC/Receiver.hs
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@ module Network.QUIC.Receiver (
receiver,
) where

import qualified Control.Exception as E
import qualified Data.ByteString as BS
import Network.Control
import Network.TLS (AlertDescription (..))
import qualified UnliftIO.Exception as E

import Network.QUIC.Config
import Network.QUIC.Connection
Expand Down
Loading

0 comments on commit ad2e457

Please sign in to comment.