diff --git a/example/login.hs b/example/login.hs index 10dbccb..332e330 100644 --- a/example/login.hs +++ b/example/login.hs @@ -31,7 +31,7 @@ import qualified System.IO as IO -- base data Conf = Conf { host :: String - , port :: PortNumber + , port :: Int , dn :: Dn , password :: Password , base :: Dn diff --git a/src/Ldap/Client.hs b/src/Ldap/Client.hs index 50493de..3bc7169 100644 --- a/src/Ldap/Client.hs +++ b/src/Ldap/Client.hs @@ -2,6 +2,8 @@ {-# LANGUAGE BangPatterns #-} {-# LANGUAGE DeriveDataTypeable #-} {-# LANGUAGE NamedFieldPuns #-} +{-# LANGUAGE ScopedTypeVariables #-} + -- | This module is intended to be imported qualified -- -- @ @@ -9,11 +11,17 @@ -- @ module Ldap.Client ( with + , with' + , runsIn + , runsInEither + , open + , close , Host(..) , defaultTlsSettings , insecureTlsSettings , PortNumber , Ldap + , LdapH , LdapError(..) , ResponseError(..) , Type.ResultCode(..) @@ -66,8 +74,9 @@ import qualified Control.Concurrent.Async as Async import Control.Concurrent.STM (atomically, throwSTM) import Control.Concurrent.STM.TMVar (putTMVar) import Control.Concurrent.STM.TQueue (TQueue, newTQueueIO, writeTQueue, readTQueue) -import Control.Exception (Exception, Handler(..), bracket, throwIO, catch, catches) +import Control.Exception (Exception, bracket, throwIO, SomeException, fromException, throw, Handler(..)) import Control.Monad (forever) +import Data.Void (Void) import qualified Data.ASN1.BinaryEncoding as Asn1 import qualified Data.ASN1.Encoding as Asn1 import qualified Data.ASN1.Error as Asn1 @@ -114,50 +123,99 @@ import Ldap.Client.Extended (Oid(..), extended, noticeOfDisconnectionO {-# ANN module ("HLint: ignore Use first" :: String) #-} -newLdap :: IO Ldap -newLdap = Ldap - <$> newTQueueIO - -- | Various failures that can happen when working with LDAP. -data LdapError = - IOError !IOError -- ^ Network failure. +data LdapError + = IOError !IOError -- ^ Network failure. | ParseError !Asn1.ASN1Error -- ^ Invalid ASN.1 data received from the server. | ResponseError !ResponseError -- ^ An LDAP operation failed. | DisconnectError !Disconnect -- ^ Notice of Disconnection has been received. deriving (Show, Eq) -newtype WrappedIOError = WrappedIOError IOError - deriving (Show, Eq, Typeable) - -instance Exception WrappedIOError +instance Exception LdapError data Disconnect = Disconnect !Type.ResultCode !Dn !Text deriving (Show, Eq, Typeable) instance Exception Disconnect +newtype LdapH = LdapH Ldap + +-- | Provide a 'LdapH' to a function needing an 'Ldap' handle. +runsIn :: (Ldap -> IO a) + -> LdapH + -> IO a +runsIn act (LdapH ldap) = do + actor <- Async.async (act ldap) + r <- Async.waitEitherCatch (workers ldap) actor + case r of + Left (Right _a) -> error "Unreachable" + Left (Left e) -> throwIO =<< catchesHandler workerErr e + Right (Right r') -> pure r' + Right (Left e) -> throwIO =<< catchesHandler respErr e + +-- | Provide a 'LdapH' to a function needing an 'Ldap' handle +runsInEither :: (Ldap -> IO a) + -> LdapH + -> IO (Either LdapError a) +runsInEither act (LdapH ldap) = do + actor <- Async.async (act ldap) + r <- Async.waitEitherCatch (workers ldap) actor + case r of + Left (Right _a) -> error "Unreachable" + Left (Left e) -> do Left <$> catchesHandler workerErr e + Right (Right r') -> pure (Right r') + Right (Left e) -> do Left <$> catchesHandler respErr e + + +workerErr :: [Handler LdapError] +workerErr = [ Handler (\(ex :: IOError) -> pure (IOError ex)) + , Handler (\(ex :: Asn1.ASN1Error) -> pure (ParseError ex)) + , Handler (\(ex :: Disconnect) -> pure (DisconnectError ex)) + ] + +respErr :: [Handler LdapError] +respErr = [ Handler (\(ex :: ResponseError) -> pure (ResponseError ex)) + ] + +catchesHandler :: [Handler a] -> SomeException -> IO a +catchesHandler handlers e = foldr tryHandler (throw e) handlers + where tryHandler (Handler handler) res + = case fromException e of + Just e' -> handler e' + Nothing -> res + -- | The entrypoint into LDAP. --- --- It catches all LDAP-related exceptions. +with' :: Host -> PortNumber -> (Ldap -> IO a) -> IO a +with' host port act = bracket (open host port) close (runsIn act) + with :: Host -> PortNumber -> (Ldap -> IO a) -> IO (Either LdapError a) -with host port f = do +with host port act = bracket (open host port) close (runsInEither act) + +-- | Creates an LDAP handle. This action is useful for creating your own resource +-- management, such as with 'resource-pool'. The handle must be manually closed +-- with 'close'. +open :: Host -> PortNumber -> IO (LdapH) +open host port = do context <- Conn.initConnectionContext - bracket (Conn.connectTo context params) Conn.connectionClose (\conn -> - bracket newLdap unbindAsync (\l -> do - inq <- newTQueueIO - outq <- newTQueueIO - as <- traverse Async.async - [ input inq conn - , output outq conn - , dispatch l inq outq - , f l - ] - fmap (Right . snd) (Async.waitAnyCancel as))) - `catches` - [ Handler (\(WrappedIOError e) -> return (Left (IOError e))) - , Handler (return . Left . ParseError) - , Handler (return . Left . ResponseError) - ] + conn <- Conn.connectTo context params + reqQ <- newTQueueIO + inQ <- newTQueueIO + outQ <- newTQueueIO + + -- The input worker that reads data off the network. + (inW :: Async.Async Void) <- Async.async (input inQ conn) + + -- The output worker that sends data onto the network. + (outW :: Async.Async Void) <- Async.async (output outQ conn) + + -- The dispatch worker that sends data between the three queues. + (dispW :: Async.Async Void) <- Async.async (dispatch reqQ inQ outQ) + + -- We use this to propagate exceptions between the workers. The `workers` Async is just a tool to + -- exchange exceptions between the entire worker group and another thread. + workers <- Async.async (snd <$> Async.waitAnyCancel [inW, outW, dispW]) + + pure (LdapH (Ldap reqQ workers conn)) where params = Conn.ConnectionParams { Conn.connectionHostname = @@ -172,6 +230,14 @@ with host port f = do , Conn.connectionUseSocks = Nothing } +-- | Closes an LDAP connection. +-- This is to be used in together with 'open'. +close :: LdapH -> IO () +close (LdapH ldap) = do + unbindAsync ldap + Conn.connectionClose (conn ldap) + Async.cancel (workers ldap) + defaultTlsSettings :: Conn.TLSSettings defaultTlsSettings = Conn.TLSSettingsSimple { Conn.settingDisableCertificateValidation = False @@ -186,84 +252,85 @@ insecureTlsSettings = Conn.TLSSettingsSimple , Conn.settingUseServerName = False } +-- | Reads Asn1 BER encoded chunks off a connection into a TQueue. input :: FromAsn1 a => TQueue a -> Connection -> IO b -input inq conn = wrap . flip fix [] $ \loop chunks -> do - chunk <- Conn.connectionGet conn 8192 - case ByteString.length chunk of - 0 -> throwIO (IO.mkIOError IO.eofErrorType "Ldap.Client.input" Nothing Nothing) - _ -> do - let chunks' = chunk : chunks - case Asn1.decodeASN1 Asn1.BER (ByteString.Lazy.fromChunks (reverse chunks')) of - Left Asn1.ParsingPartial - -> loop chunks' - Left e -> throwIO e - Right asn1 -> do - flip fix asn1 $ \loop' asn1' -> - case parseAsn1 asn1' of - Nothing -> return () - Just (asn1'', a) -> do - atomically (writeTQueue inq a) - loop' asn1'' - loop [] +input inq conn = loop [] + where + loop chunks = do + chunk <- Conn.connectionGet conn 8192 + case ByteString.length chunk of + 0 -> throwIO (IO.mkIOError IO.eofErrorType "Ldap.Client.input" Nothing Nothing) + _ -> do + let chunks' = chunk : chunks + case Asn1.decodeASN1 Asn1.BER (ByteString.Lazy.fromChunks (reverse chunks')) of + Left Asn1.ParsingPartial + -> loop chunks' + Left e -> throwIO e + Right asn1 -> do + flip fix asn1 $ \loop' asn1' -> + case parseAsn1 asn1' of + Nothing -> return () + Just (asn1'', a) -> do + atomically (writeTQueue inq a) + loop' asn1'' + loop [] +-- | Transmits Asn1 DER encoded data from a TQueue into a Connection. output :: ToAsn1 a => TQueue a -> Connection -> IO b -output out conn = wrap . forever $ do +output out conn = forever $ do msg <- atomically (readTQueue out) Conn.connectionPut conn (encode (toAsn1 msg)) where encode x = Asn1.encodeASN1' Asn1.DER (appEndo x []) dispatch - :: Ldap + :: TQueue ClientMessage -> TQueue (Type.LdapMessage Type.ProtocolServerOp) -> TQueue (Type.LdapMessage Request) -> IO a -dispatch Ldap { client } inq outq = - flip fix (Map.empty, 1) $ \loop (!req, !counter) -> - loop =<< atomically (asum - [ do New new var <- readTQueue client - writeTQueue outq (Type.LdapMessage (Type.Id counter) new Nothing) - return (Map.insert (Type.Id counter) ([], var) req, counter + 1) - , do Type.LdapMessage mid op _ - <- readTQueue inq - res <- case op of - Type.BindResponse {} -> done mid op req - Type.SearchResultEntry {} -> saveUp mid op req - Type.SearchResultReference {} -> return req - Type.SearchResultDone {} -> done mid op req - Type.ModifyResponse {} -> done mid op req - Type.AddResponse {} -> done mid op req - Type.DeleteResponse {} -> done mid op req - Type.ModifyDnResponse {} -> done mid op req - Type.CompareResponse {} -> done mid op req - Type.ExtendedResponse {} -> probablyDisconnect mid op req - Type.IntermediateResponse {} -> saveUp mid op req - return (res, counter) - ]) - where - saveUp mid op res = - return (Map.adjust (\(stack, var) -> (op : stack, var)) mid res) +dispatch reqq inq outq = loop (Map.empty, 1) + where + saveUp mid op res = return (Map.adjust (\(stack, var) -> (op : stack, var)) mid res) - done mid op req = - case Map.lookup mid req of - Nothing -> return req - Just (stack, var) -> do - putTMVar var (op :| stack) - return (Map.delete mid req) + loop (!req, !counter) = + loop =<< atomically (asum + [ do New new var <- readTQueue reqq + writeTQueue outq (Type.LdapMessage (Type.Id counter) new Nothing) + return (Map.insert (Type.Id counter) ([], var) req, counter + 1) + , do Type.LdapMessage mid op _ + <- readTQueue inq + res <- case op of + Type.BindResponse {} -> done mid op req + Type.SearchResultEntry {} -> saveUp mid op req + Type.SearchResultReference {} -> return req + Type.SearchResultDone {} -> done mid op req + Type.ModifyResponse {} -> done mid op req + Type.AddResponse {} -> done mid op req + Type.DeleteResponse {} -> done mid op req + Type.ModifyDnResponse {} -> done mid op req + Type.CompareResponse {} -> done mid op req + Type.ExtendedResponse {} -> probablyDisconnect mid op req + Type.IntermediateResponse {} -> saveUp mid op req + return (res, counter) + ]) - probablyDisconnect (Type.Id 0) - (Type.ExtendedResponse - (Type.LdapResult code - (Type.LdapDn (Type.LdapString dn)) - (Type.LdapString reason) - _) - moid _) - req = - case moid of - Just (Type.LdapOid oid) - | Oid oid == noticeOfDisconnectionOid -> throwSTM (Disconnect code (Dn dn) reason) - _ -> return req - probablyDisconnect mid op req = done mid op req + done mid op req = + case Map.lookup mid req of + Nothing -> return req + Just (stack, var) -> do + putTMVar var (op :| stack) + return (Map.delete mid req) -wrap :: IO a -> IO a -wrap m = m `catch` (throwIO . WrappedIOError) + probablyDisconnect (Type.Id 0) + (Type.ExtendedResponse + (Type.LdapResult code + (Type.LdapDn (Type.LdapString dn)) + (Type.LdapString reason) + _) + moid _) + req = + case moid of + Just (Type.LdapOid oid) + | Oid oid == noticeOfDisconnectionOid -> throwSTM (Disconnect code (Dn dn) reason) + _ -> return req + probablyDisconnect mid op req = done mid op req diff --git a/src/Ldap/Client/Add.hs b/src/Ldap/Client/Add.hs index 158439a..efe7710 100644 --- a/src/Ldap/Client/Add.hs +++ b/src/Ldap/Client/Add.hs @@ -31,7 +31,7 @@ import Ldap.Client.Internal -- | Perform the Add operation synchronously. Raises 'ResponseError' on failures. add :: Ldap -> Dn -> AttrList NonEmpty -> IO () add l dn as = - raise =<< addEither l dn as + eitherToIO =<< addEither l dn as -- | Perform the Add operation synchronously. Returns @Left e@ where -- @e@ is a 'ResponseError' on failures. diff --git a/src/Ldap/Client/Bind.hs b/src/Ldap/Client/Bind.hs index 07abf16..ddcf021 100644 --- a/src/Ldap/Client/Bind.hs +++ b/src/Ldap/Client/Bind.hs @@ -42,7 +42,7 @@ newtype Password = Password ByteString -- | Perform the Bind operation synchronously. Raises 'ResponseError' on failures. bind :: Ldap -> Dn -> Password -> IO () bind l username password = - raise =<< bindEither l username password + eitherToIO =<< bindEither l username password -- | Perform the Bind operation synchronously. Returns @Left e@ where -- @e@ is a 'ResponseError' on failures. @@ -82,7 +82,7 @@ bindResult req res = Left (ResponseInvalid req res) -- | Perform a SASL EXTERNAL Bind operation synchronously. Raises 'ResponseError' on failures. externalBind :: Ldap -> Dn -> Maybe Text -> IO () externalBind l username mCredentials = - raise =<< externalBindEither l username mCredentials + eitherToIO =<< externalBindEither l username mCredentials -- | Perform a SASL EXTERNAL Bind operation synchronously. Returns @Left e@ where -- @e@ is a 'ResponseError' on failures. diff --git a/src/Ldap/Client/Compare.hs b/src/Ldap/Client/Compare.hs index 9d54fac..5507fc4 100644 --- a/src/Ldap/Client/Compare.hs +++ b/src/Ldap/Client/Compare.hs @@ -33,7 +33,7 @@ import qualified Ldap.Asn1.Type as Type -- | Perform the Compare operation synchronously. Raises 'ResponseError' on failures. compare :: Ldap -> Dn -> Attr -> AttrValue -> IO Bool compare l dn k v = - raise =<< compareEither l dn k v + eitherToIO =<< compareEither l dn k v -- | Perform the Compare operation synchronously. Returns @Left e@ where -- @e@ is a 'ResponseError' on failures. diff --git a/src/Ldap/Client/Delete.hs b/src/Ldap/Client/Delete.hs index c877ce7..a0476c6 100644 --- a/src/Ldap/Client/Delete.hs +++ b/src/Ldap/Client/Delete.hs @@ -31,7 +31,7 @@ import Ldap.Client.Internal -- | Perform the Delete operation synchronously. Raises 'ResponseError' on failures. delete :: Ldap -> Dn -> IO () delete l dn = - raise =<< deleteEither l dn + eitherToIO =<< deleteEither l dn -- | Perform the Delete operation synchronously. Returns @Left e@ where -- @e@ is a 'ResponseError' on failures. diff --git a/src/Ldap/Client/Extended.hs b/src/Ldap/Client/Extended.hs index d5bcabb..96674d3 100644 --- a/src/Ldap/Client/Extended.hs +++ b/src/Ldap/Client/Extended.hs @@ -54,7 +54,7 @@ instance IsString Oid where -- | Perform the Extended operation synchronously. Raises 'ResponseError' on failures. extended :: Ldap -> Oid -> Maybe ByteString -> IO () extended l oid mv = - raise =<< extendedEither l oid mv + eitherToIO =<< extendedEither l oid mv -- | Perform the Extended operation synchronously. Returns @Left e@ where -- @e@ is a 'ResponseError' on failures. @@ -92,7 +92,7 @@ extendedResult req res = Left (ResponseInvalid req res) -- | An example of @Extended Operation@, cf. 'extended'. startTls :: Ldap -> IO () startTls = - raise <=< startTlsEither + eitherToIO <=< startTlsEither -- | An example of @Extended Operation@, cf. 'extendedEither'. startTlsEither :: Ldap -> IO (Either ResponseError ()) diff --git a/src/Ldap/Client/Internal.hs b/src/Ldap/Client/Internal.hs index 3d298df..16189f3 100644 --- a/src/Ldap/Client/Internal.hs +++ b/src/Ldap/Client/Internal.hs @@ -16,7 +16,7 @@ module Ldap.Client.Internal , Response , ResponseError(..) , Request - , raise + , eitherToIO , sendRequest , Dn(..) , Attr(..) @@ -27,6 +27,7 @@ module Ldap.Client.Internal , unbindAsyncSTM ) where +import qualified Control.Concurrent.Async as Async (Async) import Control.Concurrent.STM (STM, atomically) import Control.Concurrent.STM.TMVar (TMVar, newEmptyTMVar, readTMVar) import Control.Concurrent.STM.TQueue (TQueue, writeTQueue) @@ -41,7 +42,8 @@ import Network.Socket (PortNumber) #else import Network (PortNumber) #endif -import Network.Connection (TLSSettings) +import Network.Connection (TLSSettings, Connection) +import Data.Void (Void) import qualified Ldap.Asn1.Type as Type @@ -52,10 +54,12 @@ data Host = | Tls String TLSSettings -- ^ LDAP over TLS. deriving (Show) --- | A token. All functions that interact with the Directory require one. -newtype Ldap = Ldap - { client :: TQueue ClientMessage - } deriving (Eq) +-- | An LDAP connection handle +data Ldap = Ldap + { reqQ :: !(TQueue ClientMessage) -- ^ Request queue for client messages to be send. + , workers :: !(Async.Async Void) -- ^ Workers group for communicating with the server. + , conn :: !Connection -- ^ Network connection to the server. + } data ClientMessage = New !Request !(TMVar (NonEmpty Type.ProtocolServerOp)) type Request = Type.ProtocolClientOp @@ -116,11 +120,10 @@ sendRequest l p msg = return (Async (fmap p (readTMVar var))) writeRequest :: Ldap -> TMVar Response -> Request -> STM () -writeRequest Ldap { client } var msg = writeTQueue client (New msg var) - -raise :: Exception e => Either e a -> IO a -raise = either throwIO return +writeRequest Ldap { reqQ } var msg = writeTQueue reqQ (New msg var) +eitherToIO :: Exception e => Either e a -> IO a +eitherToIO = either throwIO pure -- | Terminate the connection to the Directory. -- diff --git a/src/Ldap/Client/Modify.hs b/src/Ldap/Client/Modify.hs index e5520d9..fa666a3 100644 --- a/src/Ldap/Client/Modify.hs +++ b/src/Ldap/Client/Modify.hs @@ -48,7 +48,7 @@ data Operation = -- | Perform the Modify operation synchronously. Raises 'ResponseError' on failures. modify :: Ldap -> Dn -> [Operation] -> IO () modify l dn as = - raise =<< modifyEither l dn as + eitherToIO =<< modifyEither l dn as -- | Perform the Modify operation synchronously. Returns @Left e@ where -- @e@ is a 'ResponseError' on failures. @@ -98,7 +98,7 @@ newtype RelativeDn = RelativeDn Text -- | Perform the Modify DN operation synchronously. Raises 'ResponseError' on failures. modifyDn :: Ldap -> Dn -> RelativeDn -> Bool -> Maybe Dn -> IO () modifyDn l dn rdn del new = - raise =<< modifyDnEither l dn rdn del new + eitherToIO =<< modifyDnEither l dn rdn del new -- | Perform the Modify DN operation synchronously. Returns @Left e@ where -- @e@ is a 'ResponseError' on failures. diff --git a/src/Ldap/Client/Search.hs b/src/Ldap/Client/Search.hs index fcc21dc..ca4d742 100644 --- a/src/Ldap/Client/Search.hs +++ b/src/Ldap/Client/Search.hs @@ -52,7 +52,7 @@ import Ldap.Client.Internal -- | Perform the Search operation synchronously. Raises 'ResponseError' on failures. search :: Ldap -> Dn -> Mod Search -> Filter -> [Attr] -> IO [SearchEntry] search l base opts flt attributes = - raise =<< searchEither l base opts flt attributes + eitherToIO =<< searchEither l base opts flt attributes -- | Perform the Search operation synchronously. Returns @Left e@ where -- @e@ is a 'ResponseError' on failures.