React on Unsolicited Notifications

Also improves the behavior when an `IOException` is encountered:
only exceptions directly related to LDAP are trapped by `with`.
This commit is contained in:
Matvey Aksenov 2015-04-04 16:42:28 +00:00
parent 9ab5760b8e
commit aea85536cf
5 changed files with 67 additions and 10 deletions

View File

@ -11,7 +11,8 @@ This library implements (the parts of) [RFC 4511][rfc4511]
:--------------------------- |:-----------:|:-----------: :--------------------------- |:-----------:|:-----------:
Bind Operation | 4.2 | ✔ Bind Operation | 4.2 | ✔
Unbind Operation | 4.3 | ✔ Unbind Operation | 4.3 | ✔
Notice of Disconnection | 4.4.1 | ✘ Unsolicited Notification | 4.4 | ✔
Notice of Disconnection | 4.4.1 | ✔
Search Operation | 4.5 | ✔\* Search Operation | 4.5 | ✔\*
Modify Operation | 4.6 | ✔ Modify Operation | 4.6 | ✔
Add Operation | 4.7 | ✔ Add Operation | 4.7 | ✔

View File

@ -60,6 +60,7 @@ test-suite spec
main-is: main-is:
Spec.hs Spec.hs
other-modules: other-modules:
Ldap.ClientSpec
Ldap.Client.AddSpec Ldap.Client.AddSpec
Ldap.Client.BindSpec Ldap.Client.BindSpec
Ldap.Client.CompareSpec Ldap.Client.CompareSpec

View File

@ -336,10 +336,10 @@ instance FromAsn1 ProtocolServerOp where
Asn1.Start (Asn1.Container Asn1.Application 24) <- next Asn1.Start (Asn1.Container Asn1.Application 24) <- next
res <- fromAsn1 res <- fromAsn1
name <- optional $ do name <- optional $ do
Asn1.Other Asn1.Context 0 s <- next Asn1.Other Asn1.Context 10 s <- next
return s return s
value <- optional $ do value <- optional $ do
Asn1.Other Asn1.Context 1 s <- next Asn1.Other Asn1.Context 11 s <- next
return s return s
Asn1.End (Asn1.Container Asn1.Application 24) <- next Asn1.End (Asn1.Container Asn1.Application 24) <- next
return (ExtendedResponse res (fmap LdapOid name) value) return (ExtendedResponse res (fmap LdapOid name) value)

View File

@ -1,5 +1,6 @@
{-# LANGUAGE CPP #-} {-# LANGUAGE CPP #-}
{-# LANGUAGE BangPatterns #-} {-# LANGUAGE BangPatterns #-}
{-# LANGUAGE DeriveDataTypeable #-}
{-# LANGUAGE NamedFieldPuns #-} {-# LANGUAGE NamedFieldPuns #-}
module Ldap.Client module Ldap.Client
( Host(..) ( Host(..)
@ -55,14 +56,15 @@ module Ldap.Client
import Control.Applicative ((<$>)) import Control.Applicative ((<$>))
#endif #endif
import qualified Control.Concurrent.Async as Async import qualified Control.Concurrent.Async as Async
import Control.Concurrent.STM (atomically) import Control.Concurrent.STM (atomically, throwSTM)
import Control.Concurrent.STM.TMVar (putTMVar) import Control.Concurrent.STM.TMVar (putTMVar)
import Control.Concurrent.STM.TQueue (TQueue, newTQueueIO, writeTQueue, readTQueue) import Control.Concurrent.STM.TQueue (TQueue, newTQueueIO, writeTQueue, readTQueue)
import Control.Exception (Handler(..), bracket, throwIO, catches) import Control.Exception (Exception, Handler(..), bracket, throwIO, catch, catches)
import Control.Monad (forever) import Control.Monad (forever)
import qualified Data.ASN1.BinaryEncoding as Asn1 import qualified Data.ASN1.BinaryEncoding as Asn1
import qualified Data.ASN1.Encoding as Asn1 import qualified Data.ASN1.Encoding as Asn1
import qualified Data.ASN1.Error as Asn1 import qualified Data.ASN1.Error as Asn1
import Data.ByteString (ByteString)
import qualified Data.ByteString as ByteString import qualified Data.ByteString as ByteString
import qualified Data.ByteString.Lazy as ByteString.Lazy import qualified Data.ByteString.Lazy as ByteString.Lazy
import Data.Foldable (asum) import Data.Foldable (asum)
@ -70,10 +72,13 @@ import Data.Function (fix)
import Data.List.NonEmpty (NonEmpty((:|))) import Data.List.NonEmpty (NonEmpty((:|)))
import qualified Data.Map.Strict as Map import qualified Data.Map.Strict as Map
import Data.Monoid (Endo(appEndo)) import Data.Monoid (Endo(appEndo))
import Data.String (fromString)
import Data.Text (Text)
import Data.Typeable (Typeable)
import Network.Connection (Connection) import Network.Connection (Connection)
import qualified Network.Connection as Conn import qualified Network.Connection as Conn
import qualified System.IO.Error as IO
import Prelude hiding (compare) import Prelude hiding (compare)
import qualified System.IO.Error as IO
import Ldap.Asn1.ToAsn1 (ToAsn1(toAsn1)) import Ldap.Asn1.ToAsn1 (ToAsn1(toAsn1))
import Ldap.Asn1.FromAsn1 (FromAsn1, parseAsn1) import Ldap.Asn1.FromAsn1 (FromAsn1, parseAsn1)
@ -109,8 +114,19 @@ data LdapError =
IOError IOError IOError IOError
| ParseError Asn1.ASN1Error | ParseError Asn1.ASN1Error
| ResponseError ResponseError | ResponseError ResponseError
| DisconnectError Disconnect
deriving (Show, Eq) deriving (Show, Eq)
newtype WrappedIOError = WrappedIOError IOError
deriving (Show, Eq, Typeable)
instance Exception WrappedIOError
data Disconnect = Disconnect Type.ResultCode Dn Text
deriving (Show, Eq, Typeable)
instance Exception Disconnect
-- | The entrypoint into LDAP. -- | The entrypoint into LDAP.
with :: Host -> PortNumber -> (Ldap -> IO a) -> IO (Either LdapError a) with :: Host -> PortNumber -> (Ldap -> IO a) -> IO (Either LdapError a)
with host port f = do with host port f = do
@ -125,7 +141,7 @@ with host port f = do
Async.withAsync (f l) $ \u -> Async.withAsync (f l) $ \u ->
fmap (Right . snd) (Async.waitAnyCancel [i, o, d, u]))) fmap (Right . snd) (Async.waitAnyCancel [i, o, d, u])))
`catches` `catches`
[ Handler (return . Left . IOError) [ Handler (\(WrappedIOError e) -> return (Left (IOError e)))
, Handler (return . Left . ParseError) , Handler (return . Left . ParseError)
, Handler (return . Left . ResponseError) , Handler (return . Left . ResponseError)
] ]
@ -154,7 +170,7 @@ with host port f = do
} }
input :: FromAsn1 a => TQueue a -> Connection -> IO b input :: FromAsn1 a => TQueue a -> Connection -> IO b
input inq conn = flip fix [] $ \loop chunks -> do input inq conn = wrap . flip fix [] $ \loop chunks -> do
chunk <- Conn.connectionGet conn 8192 chunk <- Conn.connectionGet conn 8192
case ByteString.length chunk of case ByteString.length chunk of
0 -> throwIO (IO.mkIOError IO.eofErrorType "Ldap.Client.input" Nothing Nothing) 0 -> throwIO (IO.mkIOError IO.eofErrorType "Ldap.Client.input" Nothing Nothing)
@ -174,7 +190,7 @@ input inq conn = flip fix [] $ \loop chunks -> do
loop [] loop []
output :: ToAsn1 a => TQueue a -> Connection -> IO b output :: ToAsn1 a => TQueue a -> Connection -> IO b
output out conn = forever $ do output out conn = wrap . forever $ do
msg <- atomically (readTQueue out) msg <- atomically (readTQueue out)
Conn.connectionPut conn (encode (toAsn1 msg)) Conn.connectionPut conn (encode (toAsn1 msg))
where where
@ -203,16 +219,37 @@ dispatch Ldap { client } inq outq =
Type.DeleteResponse {} -> done mid op req Type.DeleteResponse {} -> done mid op req
Type.ModifyDnResponse {} -> done mid op req Type.ModifyDnResponse {} -> done mid op req
Type.CompareResponse {} -> done mid op req Type.CompareResponse {} -> done mid op req
Type.ExtendedResponse {} -> done mid op req Type.ExtendedResponse {} -> probablyDisconnect mid op req
Type.IntermediateResponse {} -> saveUp mid op req Type.IntermediateResponse {} -> saveUp mid op req
return (res, counter) return (res, counter)
]) ])
where where
saveUp mid op res = saveUp mid op res =
return (Map.adjust (\(stack, var) -> (op : stack, var)) mid res) return (Map.adjust (\(stack, var) -> (op : stack, var)) mid res)
done mid op req = done mid op req =
case Map.lookup mid req of case Map.lookup mid req of
Nothing -> return req Nothing -> return req
Just (stack, var) -> do Just (stack, var) -> do
putTMVar var (op :| stack) putTMVar var (op :| stack)
return (Map.delete mid req) return (Map.delete mid req)
probablyDisconnect (Type.Id 0)
(Type.ExtendedResponse
(Type.LdapResult code
(Type.LdapDn (Type.LdapString dn))
(Type.LdapString reason)
_)
moid _)
req =
case moid of
Just (Type.LdapOid oid)
| oid == noticeOfDisconnection -> throwSTM (Disconnect code (Dn dn) reason)
_ -> return req
probablyDisconnect mid op req = done mid op req
noticeOfDisconnection :: ByteString
noticeOfDisconnection = fromString "1.3.6.1.4.1.1466.20036"
wrap :: IO a -> IO a
wrap m = m `catch` (throwIO . WrappedIOError)

18
test/Ldap/ClientSpec.hs Normal file
View File

@ -0,0 +1,18 @@
module Ldap.ClientSpec (spec) where
import Control.Exception (IOException, throwIO)
import Test.Hspec
import SpecHelper (locally)
spec :: Spec
spec =
context "exceptions" $
it "propagates unrelated IOExceptions through" $
locally (\_ -> throwIO unrelated)
`shouldThrow`
(== unrelated)
unrelated :: IOException
unrelated = userError "unrelated"