React on Unsolicited Notifications

Also improves the behavior when an `IOException` is encountered:
only exceptions directly related to LDAP are trapped by `with`.
This commit is contained in:
Matvey Aksenov 2015-04-04 16:42:28 +00:00
parent 9ab5760b8e
commit aea85536cf
5 changed files with 67 additions and 10 deletions

View File

@ -11,7 +11,8 @@ This library implements (the parts of) [RFC 4511][rfc4511]
:--------------------------- |:-----------:|:-----------:
Bind Operation | 4.2 | ✔
Unbind Operation | 4.3 | ✔
Notice of Disconnection | 4.4.1 | ✘
Unsolicited Notification | 4.4 | ✔
Notice of Disconnection | 4.4.1 | ✔
Search Operation | 4.5 | ✔\*
Modify Operation | 4.6 | ✔
Add Operation | 4.7 | ✔

View File

@ -60,6 +60,7 @@ test-suite spec
main-is:
Spec.hs
other-modules:
Ldap.ClientSpec
Ldap.Client.AddSpec
Ldap.Client.BindSpec
Ldap.Client.CompareSpec

View File

@ -336,10 +336,10 @@ instance FromAsn1 ProtocolServerOp where
Asn1.Start (Asn1.Container Asn1.Application 24) <- next
res <- fromAsn1
name <- optional $ do
Asn1.Other Asn1.Context 0 s <- next
Asn1.Other Asn1.Context 10 s <- next
return s
value <- optional $ do
Asn1.Other Asn1.Context 1 s <- next
Asn1.Other Asn1.Context 11 s <- next
return s
Asn1.End (Asn1.Container Asn1.Application 24) <- next
return (ExtendedResponse res (fmap LdapOid name) value)

View File

@ -1,5 +1,6 @@
{-# LANGUAGE CPP #-}
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE DeriveDataTypeable #-}
{-# LANGUAGE NamedFieldPuns #-}
module Ldap.Client
( Host(..)
@ -55,14 +56,15 @@ module Ldap.Client
import Control.Applicative ((<$>))
#endif
import qualified Control.Concurrent.Async as Async
import Control.Concurrent.STM (atomically)
import Control.Concurrent.STM (atomically, throwSTM)
import Control.Concurrent.STM.TMVar (putTMVar)
import Control.Concurrent.STM.TQueue (TQueue, newTQueueIO, writeTQueue, readTQueue)
import Control.Exception (Handler(..), bracket, throwIO, catches)
import Control.Exception (Exception, Handler(..), bracket, throwIO, catch, catches)
import Control.Monad (forever)
import qualified Data.ASN1.BinaryEncoding as Asn1
import qualified Data.ASN1.Encoding as Asn1
import qualified Data.ASN1.Error as Asn1
import Data.ByteString (ByteString)
import qualified Data.ByteString as ByteString
import qualified Data.ByteString.Lazy as ByteString.Lazy
import Data.Foldable (asum)
@ -70,10 +72,13 @@ import Data.Function (fix)
import Data.List.NonEmpty (NonEmpty((:|)))
import qualified Data.Map.Strict as Map
import Data.Monoid (Endo(appEndo))
import Data.String (fromString)
import Data.Text (Text)
import Data.Typeable (Typeable)
import Network.Connection (Connection)
import qualified Network.Connection as Conn
import qualified System.IO.Error as IO
import Prelude hiding (compare)
import qualified System.IO.Error as IO
import Ldap.Asn1.ToAsn1 (ToAsn1(toAsn1))
import Ldap.Asn1.FromAsn1 (FromAsn1, parseAsn1)
@ -109,8 +114,19 @@ data LdapError =
IOError IOError
| ParseError Asn1.ASN1Error
| ResponseError ResponseError
| DisconnectError Disconnect
deriving (Show, Eq)
newtype WrappedIOError = WrappedIOError IOError
deriving (Show, Eq, Typeable)
instance Exception WrappedIOError
data Disconnect = Disconnect Type.ResultCode Dn Text
deriving (Show, Eq, Typeable)
instance Exception Disconnect
-- | The entrypoint into LDAP.
with :: Host -> PortNumber -> (Ldap -> IO a) -> IO (Either LdapError a)
with host port f = do
@ -125,7 +141,7 @@ with host port f = do
Async.withAsync (f l) $ \u ->
fmap (Right . snd) (Async.waitAnyCancel [i, o, d, u])))
`catches`
[ Handler (return . Left . IOError)
[ Handler (\(WrappedIOError e) -> return (Left (IOError e)))
, Handler (return . Left . ParseError)
, Handler (return . Left . ResponseError)
]
@ -154,7 +170,7 @@ with host port f = do
}
input :: FromAsn1 a => TQueue a -> Connection -> IO b
input inq conn = flip fix [] $ \loop chunks -> do
input inq conn = wrap . flip fix [] $ \loop chunks -> do
chunk <- Conn.connectionGet conn 8192
case ByteString.length chunk of
0 -> throwIO (IO.mkIOError IO.eofErrorType "Ldap.Client.input" Nothing Nothing)
@ -174,7 +190,7 @@ input inq conn = flip fix [] $ \loop chunks -> do
loop []
output :: ToAsn1 a => TQueue a -> Connection -> IO b
output out conn = forever $ do
output out conn = wrap . forever $ do
msg <- atomically (readTQueue out)
Conn.connectionPut conn (encode (toAsn1 msg))
where
@ -203,16 +219,37 @@ dispatch Ldap { client } inq outq =
Type.DeleteResponse {} -> done mid op req
Type.ModifyDnResponse {} -> done mid op req
Type.CompareResponse {} -> done mid op req
Type.ExtendedResponse {} -> done mid op req
Type.ExtendedResponse {} -> probablyDisconnect mid op req
Type.IntermediateResponse {} -> saveUp mid op req
return (res, counter)
])
where
saveUp mid op res =
return (Map.adjust (\(stack, var) -> (op : stack, var)) mid res)
done mid op req =
case Map.lookup mid req of
Nothing -> return req
Just (stack, var) -> do
putTMVar var (op :| stack)
return (Map.delete mid req)
probablyDisconnect (Type.Id 0)
(Type.ExtendedResponse
(Type.LdapResult code
(Type.LdapDn (Type.LdapString dn))
(Type.LdapString reason)
_)
moid _)
req =
case moid of
Just (Type.LdapOid oid)
| oid == noticeOfDisconnection -> throwSTM (Disconnect code (Dn dn) reason)
_ -> return req
probablyDisconnect mid op req = done mid op req
noticeOfDisconnection :: ByteString
noticeOfDisconnection = fromString "1.3.6.1.4.1.1466.20036"
wrap :: IO a -> IO a
wrap m = m `catch` (throwIO . WrappedIOError)

18
test/Ldap/ClientSpec.hs Normal file
View File

@ -0,0 +1,18 @@
module Ldap.ClientSpec (spec) where
import Control.Exception (IOException, throwIO)
import Test.Hspec
import SpecHelper (locally)
spec :: Spec
spec =
context "exceptions" $
it "propagates unrelated IOExceptions through" $
locally (\_ -> throwIO unrelated)
`shouldThrow`
(== unrelated)
unrelated :: IOException
unrelated = userError "unrelated"