diff --git a/src/Control/Concurrent/Async/Lifted/Safe/Utils.hs b/src/Control/Concurrent/Async/Lifted/Safe/Utils.hs index f7f395b64..27dc86127 100644 --- a/src/Control/Concurrent/Async/Lifted/Safe/Utils.hs +++ b/src/Control/Concurrent/Async/Lifted/Safe/Utils.hs @@ -1,15 +1,17 @@ module Control.Concurrent.Async.Lifted.Safe.Utils - ( allocateLinkedAsync + ( allocateAsync, allocateLinkedAsync ) where import ClassyPrelude hiding (cancel) +import Control.Lens import Control.Concurrent.Async.Lifted.Safe import Control.Monad.Trans.Resource -allocateLinkedAsync :: forall m a. - MonadResource m - => IO a -> m (Async a) -allocateLinkedAsync act = allocate (async act) cancel >>= (\(_k, a) -> a <$ link a) +allocateLinkedAsync, allocateAsync :: forall m a. + MonadResource m + => IO a -> m (Async a) +allocateAsync = fmap (view _2) . flip allocate cancel . async +allocateLinkedAsync = uncurry (<$) . (id &&& link) <=< allocateAsync diff --git a/src/Ldap/Client/Pool.hs b/src/Ldap/Client/Pool.hs index 6682d7c98..4c1d6fdfa 100644 --- a/src/Ldap/Client/Pool.hs +++ b/src/Ldap/Client/Pool.hs @@ -10,6 +10,8 @@ module Ldap.Client.Pool import ClassyPrelude +import Control.Lens + import Ldap.Client (Ldap, LdapError) import qualified Ldap.Client as Ldap @@ -22,11 +24,17 @@ import Data.Dynamic import System.Timeout.Lifted +import Control.Concurrent.Async.Lifted.Safe +import Control.Concurrent.Async.Lifted.Safe.Utils +import Control.Monad.Trans.Resource (MonadResource) +import qualified Control.Monad.Trans.Resource as Resource + type LdapPool = Pool LdapExecutor data LdapExecutor = LdapExecutor { ldapExec :: forall a. Typeable a => (Ldap -> IO a) -> IO (Either LdapPoolError a) , ldapDestroy :: TMVar () + , ldapAsync :: Async () } instance Exception LdapError @@ -41,7 +49,7 @@ withLdap :: (MonadBaseControl IO m, MonadIO m, Typeable a) => LdapPool -> (Ldap withLdap pool act = withResource pool $ \LdapExecutor{..} -> liftIO $ ldapExec act -createLdapPool :: ( MonadLoggerIO m, MonadIO m ) +createLdapPool :: ( MonadLoggerIO m, MonadResource m ) => Ldap.Host -> Ldap.PortNumber -> Int -- ^ Stripes @@ -53,15 +61,15 @@ createLdapPool host port stripes timeoutConn (round . (* 1e6) -> timeoutAct) lim logFunc <- askLoggerIO let - mkExecutor :: IO LdapExecutor - mkExecutor = do - ldapDestroy <- newEmptyTMVarIO - ldapAct <- newEmptyTMVarIO + mkExecutor :: Resource.InternalState -> IO LdapExecutor + mkExecutor rSt = Resource.runInternalState ?? rSt $ do + ldapDestroy <- liftIO newEmptyTMVarIO + ldapAct <- liftIO newEmptyTMVarIO let ldapExec :: forall a. Typeable a => (Ldap -> IO a) -> IO (Either LdapPoolError a) ldapExec act = do - ldapAnswer <- newEmptyTMVarIO :: IO (TMVar (Either SomeException Dynamic)) + ldapAnswer <- liftIO newEmptyTMVarIO :: IO (TMVar (Either SomeException Dynamic)) atomically $ putTMVar ldapAct (fmap toDyn . act, ldapAnswer) either throwIO (return . Right . flip fromDyn (error "Could not cast dynamic")) =<< atomically (takeTMVar ldapAnswer) `catches` @@ -91,10 +99,10 @@ createLdapPool host port stripes timeoutConn (round . (* 1e6) -> timeoutAct) lim ] go Nothing ldap - withTimeout $ do - setup <- newEmptyTMVarIO + ldapAsync <- withTimeout $ do + setup <- liftIO newEmptyTMVarIO - void . fork . flip runLoggingT logFunc $ do + ldapAsync <- allocateAsync . flip runLoggingT logFunc $ do $logInfoS "LdapExecutor" "Starting" res <- liftIO . Ldap.with host port $ flip runLoggingT logFunc . go (Just setup) case res of @@ -105,11 +113,16 @@ createLdapPool host port stripes timeoutConn (round . (* 1e6) -> timeoutAct) lim maybe (return ()) throwM =<< atomically (takeTMVar setup) + return ldapAsync + return LdapExecutor{..} delExecutor :: LdapExecutor -> IO () - delExecutor LdapExecutor{..} = atomically . void $ tryPutTMVar ldapDestroy () - liftIO $ createPool mkExecutor delExecutor stripes timeoutConn limit + delExecutor LdapExecutor{..} = do + atomically . void $ tryPutTMVar ldapDestroy () + wait ldapAsync + rSt <- view _2 <$> Resource.allocate Resource.createInternalState Resource.closeInternalState + liftIO $ createPool (mkExecutor rSt) delExecutor stripes timeoutConn limit where withTimeout :: forall m a. (MonadBaseControl IO m, MonadThrow m) => m a -> m a withTimeout = maybe (throwM LdapPoolTimeout) return <=< timeout timeoutAct