diff --git a/README.markdown b/README.markdown index 7722db4..28788f5 100644 --- a/README.markdown +++ b/README.markdown @@ -11,7 +11,8 @@ This library implements (the parts of) [RFC 4511][rfc4511] :--------------------------- |:-----------:|:-----------: Bind Operation | 4.2 | ✔ Unbind Operation | 4.3 | ✔ -Notice of Disconnection | 4.4.1 | ✘ +Unsolicited Notification | 4.4 | ✔ +Notice of Disconnection | 4.4.1 | ✔ Search Operation | 4.5 | ✔\* Modify Operation | 4.6 | ✔ Add Operation | 4.7 | ✔ diff --git a/ldap-client.cabal b/ldap-client.cabal index eeac1cd..c9b3e1a 100644 --- a/ldap-client.cabal +++ b/ldap-client.cabal @@ -60,6 +60,7 @@ test-suite spec main-is: Spec.hs other-modules: + Ldap.ClientSpec Ldap.Client.AddSpec Ldap.Client.BindSpec Ldap.Client.CompareSpec diff --git a/src/Ldap/Asn1/FromAsn1.hs b/src/Ldap/Asn1/FromAsn1.hs index 3c2c154..b0ed155 100644 --- a/src/Ldap/Asn1/FromAsn1.hs +++ b/src/Ldap/Asn1/FromAsn1.hs @@ -336,10 +336,10 @@ instance FromAsn1 ProtocolServerOp where Asn1.Start (Asn1.Container Asn1.Application 24) <- next res <- fromAsn1 name <- optional $ do - Asn1.Other Asn1.Context 0 s <- next + Asn1.Other Asn1.Context 10 s <- next return s value <- optional $ do - Asn1.Other Asn1.Context 1 s <- next + Asn1.Other Asn1.Context 11 s <- next return s Asn1.End (Asn1.Container Asn1.Application 24) <- next return (ExtendedResponse res (fmap LdapOid name) value) diff --git a/src/Ldap/Client.hs b/src/Ldap/Client.hs index dfc0b65..5f409ee 100644 --- a/src/Ldap/Client.hs +++ b/src/Ldap/Client.hs @@ -1,5 +1,6 @@ {-# LANGUAGE CPP #-} {-# LANGUAGE BangPatterns #-} +{-# LANGUAGE DeriveDataTypeable #-} {-# LANGUAGE NamedFieldPuns #-} module Ldap.Client ( Host(..) @@ -55,14 +56,15 @@ module Ldap.Client import Control.Applicative ((<$>)) #endif import qualified Control.Concurrent.Async as Async -import Control.Concurrent.STM (atomically) +import Control.Concurrent.STM (atomically, throwSTM) import Control.Concurrent.STM.TMVar (putTMVar) import Control.Concurrent.STM.TQueue (TQueue, newTQueueIO, writeTQueue, readTQueue) -import Control.Exception (Handler(..), bracket, throwIO, catches) +import Control.Exception (Exception, Handler(..), bracket, throwIO, catch, catches) import Control.Monad (forever) import qualified Data.ASN1.BinaryEncoding as Asn1 import qualified Data.ASN1.Encoding as Asn1 import qualified Data.ASN1.Error as Asn1 +import Data.ByteString (ByteString) import qualified Data.ByteString as ByteString import qualified Data.ByteString.Lazy as ByteString.Lazy import Data.Foldable (asum) @@ -70,10 +72,13 @@ import Data.Function (fix) import Data.List.NonEmpty (NonEmpty((:|))) import qualified Data.Map.Strict as Map import Data.Monoid (Endo(appEndo)) +import Data.String (fromString) +import Data.Text (Text) +import Data.Typeable (Typeable) import Network.Connection (Connection) import qualified Network.Connection as Conn -import qualified System.IO.Error as IO import Prelude hiding (compare) +import qualified System.IO.Error as IO import Ldap.Asn1.ToAsn1 (ToAsn1(toAsn1)) import Ldap.Asn1.FromAsn1 (FromAsn1, parseAsn1) @@ -109,8 +114,19 @@ data LdapError = IOError IOError | ParseError Asn1.ASN1Error | ResponseError ResponseError + | DisconnectError Disconnect deriving (Show, Eq) +newtype WrappedIOError = WrappedIOError IOError + deriving (Show, Eq, Typeable) + +instance Exception WrappedIOError + +data Disconnect = Disconnect Type.ResultCode Dn Text + deriving (Show, Eq, Typeable) + +instance Exception Disconnect + -- | The entrypoint into LDAP. with :: Host -> PortNumber -> (Ldap -> IO a) -> IO (Either LdapError a) with host port f = do @@ -125,7 +141,7 @@ with host port f = do Async.withAsync (f l) $ \u -> fmap (Right . snd) (Async.waitAnyCancel [i, o, d, u]))) `catches` - [ Handler (return . Left . IOError) + [ Handler (\(WrappedIOError e) -> return (Left (IOError e))) , Handler (return . Left . ParseError) , Handler (return . Left . ResponseError) ] @@ -154,7 +170,7 @@ with host port f = do } input :: FromAsn1 a => TQueue a -> Connection -> IO b -input inq conn = flip fix [] $ \loop chunks -> do +input inq conn = wrap . flip fix [] $ \loop chunks -> do chunk <- Conn.connectionGet conn 8192 case ByteString.length chunk of 0 -> throwIO (IO.mkIOError IO.eofErrorType "Ldap.Client.input" Nothing Nothing) @@ -174,7 +190,7 @@ input inq conn = flip fix [] $ \loop chunks -> do loop [] output :: ToAsn1 a => TQueue a -> Connection -> IO b -output out conn = forever $ do +output out conn = wrap . forever $ do msg <- atomically (readTQueue out) Conn.connectionPut conn (encode (toAsn1 msg)) where @@ -203,16 +219,37 @@ dispatch Ldap { client } inq outq = Type.DeleteResponse {} -> done mid op req Type.ModifyDnResponse {} -> done mid op req Type.CompareResponse {} -> done mid op req - Type.ExtendedResponse {} -> done mid op req + Type.ExtendedResponse {} -> probablyDisconnect mid op req Type.IntermediateResponse {} -> saveUp mid op req return (res, counter) ]) where saveUp mid op res = return (Map.adjust (\(stack, var) -> (op : stack, var)) mid res) + done mid op req = case Map.lookup mid req of Nothing -> return req Just (stack, var) -> do putTMVar var (op :| stack) return (Map.delete mid req) + + probablyDisconnect (Type.Id 0) + (Type.ExtendedResponse + (Type.LdapResult code + (Type.LdapDn (Type.LdapString dn)) + (Type.LdapString reason) + _) + moid _) + req = + case moid of + Just (Type.LdapOid oid) + | oid == noticeOfDisconnection -> throwSTM (Disconnect code (Dn dn) reason) + _ -> return req + probablyDisconnect mid op req = done mid op req + + noticeOfDisconnection :: ByteString + noticeOfDisconnection = fromString "1.3.6.1.4.1.1466.20036" + +wrap :: IO a -> IO a +wrap m = m `catch` (throwIO . WrappedIOError) diff --git a/test/Ldap/ClientSpec.hs b/test/Ldap/ClientSpec.hs new file mode 100644 index 0000000..418501e --- /dev/null +++ b/test/Ldap/ClientSpec.hs @@ -0,0 +1,18 @@ +module Ldap.ClientSpec (spec) where + +import Control.Exception (IOException, throwIO) +import Test.Hspec + +import SpecHelper (locally) + + +spec :: Spec +spec = + context "exceptions" $ + it "propagates unrelated ‘IOException’s through" $ + locally (\_ -> throwIO unrelated) + `shouldThrow` + (== unrelated) + +unrelated :: IOException +unrelated = userError "unrelated"