500 lines
16 KiB
Haskell
500 lines
16 KiB
Haskell
{-# LANGUAGE BangPatterns #-}
|
|
{-# LANGUAGE DeriveDataTypeable #-}
|
|
{-# LANGUAGE NamedFieldPuns #-}
|
|
module Ldap.Client
|
|
( Host(..)
|
|
, PortNumber
|
|
, Ldap
|
|
, LdapError(..)
|
|
, Type.ResultCode(..)
|
|
, Async
|
|
, with
|
|
-- * Bind Request
|
|
, Dn(..)
|
|
, Password(..)
|
|
, BindError(..)
|
|
, bind
|
|
, bindEither
|
|
, bindAsync
|
|
, bindAsyncSTM
|
|
-- * Search Request
|
|
, Type.Scope(..)
|
|
, Attr(..)
|
|
, SearchEntry(..)
|
|
, SearchError(..)
|
|
, search
|
|
, searchEither
|
|
, searchAsync
|
|
, searchAsyncSTM
|
|
, Search
|
|
, defaultSearch
|
|
, scope
|
|
, size
|
|
, time
|
|
, typesOnly
|
|
, derefAliases
|
|
, Filter(..)
|
|
-- * Unbind Request
|
|
, unbindAsync
|
|
, unbindAsyncSTM
|
|
-- * Add Request
|
|
, AttrList
|
|
, AddError(..)
|
|
, add
|
|
, addEither
|
|
, addAsync
|
|
, addAsyncSTM
|
|
-- * Waiting for Request Completion
|
|
, wait
|
|
, waitSTM
|
|
) where
|
|
|
|
import qualified Control.Concurrent.Async as Async
|
|
import Control.Concurrent.STM (STM, atomically)
|
|
import Control.Concurrent.STM.TMVar (TMVar, newEmptyTMVar, putTMVar, readTMVar)
|
|
import Control.Concurrent.STM.TQueue (TQueue, newTQueueIO, writeTQueue, readTQueue)
|
|
import Control.Exception (Exception, Handler(..), bracket, throwIO, catches)
|
|
import Control.Monad (forever, void)
|
|
import qualified Data.ASN1.BinaryEncoding as Asn1
|
|
import qualified Data.ASN1.Encoding as Asn1
|
|
import qualified Data.ASN1.Error as Asn1
|
|
import Data.ByteString (ByteString)
|
|
import qualified Data.ByteString as ByteString
|
|
import qualified Data.ByteString.Lazy as ByteString.Lazy
|
|
import Data.Foldable (traverse_, asum)
|
|
import Data.Function (fix)
|
|
import Data.Int (Int32)
|
|
import Data.List.NonEmpty (NonEmpty((:|)))
|
|
import qualified Data.List.NonEmpty as NonEmpty
|
|
import qualified Data.Map.Strict as Map
|
|
import Data.Maybe (mapMaybe)
|
|
import Data.Monoid (Endo(appEndo))
|
|
import Data.Set (Set)
|
|
import qualified Data.Set as Set
|
|
import Data.Text (Text)
|
|
import Data.Typeable (Typeable)
|
|
import Network.Connection (Connection)
|
|
import qualified Network.Connection as Conn
|
|
import Network (PortNumber)
|
|
import qualified System.IO.Error as IO
|
|
|
|
import Ldap.Asn1.ToAsn1 (ToAsn1(toAsn1))
|
|
import Ldap.Asn1.FromAsn1 (FromAsn1, parseAsn1)
|
|
import qualified Ldap.Asn1.Type as Type
|
|
|
|
|
|
data Host =
|
|
Plain String
|
|
| Secure String
|
|
deriving (Show, Eq, Ord)
|
|
|
|
data Ldap = Ldap
|
|
{ client :: TQueue ClientMessage
|
|
} deriving (Eq)
|
|
|
|
data ClientMessage = New Request (TMVar (NonEmpty Type.ProtocolServerOp))
|
|
type Request = Type.ProtocolClientOp
|
|
type InMessage = Type.ProtocolServerOp
|
|
type Response = NonEmpty InMessage
|
|
|
|
newLdap :: IO Ldap
|
|
newLdap = Ldap
|
|
<$> newTQueueIO
|
|
|
|
data LdapError =
|
|
IOError IOError
|
|
| ParseError Asn1.ASN1Error
|
|
| BindError BindError
|
|
| SearchError SearchError
|
|
deriving (Show, Eq)
|
|
|
|
-- | The entrypoint into LDAP.
|
|
with :: Host -> PortNumber -> (Ldap -> IO a) -> IO (Either LdapError a)
|
|
with host port f = do
|
|
context <- Conn.initConnectionContext
|
|
bracket (Conn.connectTo context params) Conn.connectionClose (\conn ->
|
|
bracket newLdap unbindAsync (\l -> do
|
|
inq <- newTQueueIO
|
|
outq <- newTQueueIO
|
|
Async.withAsync (input inq conn) $ \i ->
|
|
Async.withAsync (output outq conn) $ \o ->
|
|
Async.withAsync (dispatch l inq outq) $ \d ->
|
|
Async.withAsync (f l) $ \u ->
|
|
fmap (Right . snd) (Async.waitAnyCancel [i, o, d, u])))
|
|
`catches`
|
|
[ Handler (return . Left . IOError)
|
|
, Handler (return . Left . ParseError)
|
|
, Handler (return . Left . BindError)
|
|
, Handler (return . Left . SearchError)
|
|
]
|
|
where
|
|
params = Conn.ConnectionParams
|
|
{ Conn.connectionHostname =
|
|
case host of
|
|
Plain h -> h
|
|
Secure h -> h
|
|
, Conn.connectionPort = port
|
|
, Conn.connectionUseSecure =
|
|
case host of
|
|
Plain _ -> Nothing
|
|
Secure _ -> Just Conn.TLSSettingsSimple
|
|
{ Conn.settingDisableCertificateValidation = False
|
|
, Conn.settingDisableSession = False
|
|
, Conn.settingUseServerName = False
|
|
}
|
|
, Conn.connectionUseSocks = Nothing
|
|
}
|
|
|
|
input :: FromAsn1 a => TQueue a -> Connection -> IO b
|
|
input inq conn = flip fix [] $ \loop chunks -> do
|
|
chunk <- Conn.connectionGet conn 8192
|
|
case ByteString.length chunk of
|
|
0 -> throwIO (IO.mkIOError IO.eofErrorType "Ldap.Client.input" Nothing Nothing)
|
|
_ -> do
|
|
let chunks' = chunk : chunks
|
|
case Asn1.decodeASN1 Asn1.DER (ByteString.Lazy.fromChunks (reverse chunks')) of
|
|
Left Asn1.ParsingPartial
|
|
-> loop chunks'
|
|
Left e -> throwIO e
|
|
Right asn1 -> do
|
|
flip fix asn1 $ \loop' asn1' ->
|
|
case parseAsn1 asn1' of
|
|
Nothing -> return ()
|
|
Just (asn1'', a) -> do
|
|
atomically (writeTQueue inq a)
|
|
loop' asn1''
|
|
loop []
|
|
|
|
output :: ToAsn1 a => TQueue a -> Connection -> IO b
|
|
output out conn = forever $ do
|
|
msg <- atomically (readTQueue out)
|
|
Conn.connectionPut conn (encode (toAsn1 msg))
|
|
where
|
|
encode x = Asn1.encodeASN1' Asn1.DER (appEndo x [])
|
|
|
|
dispatch :: Ldap -> TQueue (Type.LdapMessage InMessage) -> TQueue (Type.LdapMessage Request) -> IO a
|
|
dispatch Ldap { client } inq outq =
|
|
flip fix (Map.empty, Map.empty, 1) $ \loop (!got, !results, !counter) -> do
|
|
loop =<< atomically (asum
|
|
[ do New new var <- readTQueue client
|
|
writeTQueue outq (Type.LdapMessage (Type.Id counter) new Nothing)
|
|
return (got, Map.insert (Type.Id counter) var results, counter + 1)
|
|
, do Type.LdapMessage mid op _ <- readTQueue inq
|
|
case op of
|
|
Type.BindResponse {} -> do
|
|
traverse_ (\var -> putTMVar var (op :| [])) (Map.lookup mid results)
|
|
return (Map.delete mid got, Map.delete mid results, counter)
|
|
Type.SearchResultEntry {} -> do
|
|
return (Map.insertWith (++) mid [op] got, results, counter)
|
|
Type.SearchResultReference {} -> do
|
|
return (got, results, counter)
|
|
Type.SearchResultDone {} -> do
|
|
let stack = Map.findWithDefault [] mid got
|
|
traverse_ (\var -> putTMVar var (op :| stack)) (Map.lookup mid results)
|
|
return (Map.delete mid got, Map.delete mid results, counter)
|
|
Type.AddResponse {} -> do
|
|
traverse_ (\var -> putTMVar var (op :| [])) (Map.lookup mid results)
|
|
return (Map.delete mid got, Map.delete mid results, counter)
|
|
])
|
|
|
|
|
|
data Async e a = Async (STM (Either e a))
|
|
|
|
instance Functor (Async e) where
|
|
fmap f (Async stm) = Async (fmap (fmap f) stm)
|
|
|
|
|
|
newtype Dn = Dn Text
|
|
deriving (Show, Eq)
|
|
newtype Password = Password ByteString
|
|
deriving (Show, Eq)
|
|
|
|
data BindError =
|
|
BindInvalidResponse Response
|
|
| BindErrorCode Type.ResultCode
|
|
deriving (Show, Eq, Typeable)
|
|
|
|
instance Exception BindError
|
|
|
|
-- | Throws 'BindError' on failure. Don't worry, the nearest 'with'
|
|
-- will catch it, so it won't destroy your program.
|
|
bind :: Ldap -> Dn -> Password -> IO ()
|
|
bind l username password =
|
|
raise =<< bindEither l username password
|
|
|
|
bindEither :: Ldap -> Dn -> Password -> IO (Either BindError ())
|
|
bindEither l username password =
|
|
wait =<< bindAsync l username password
|
|
|
|
bindAsync :: Ldap -> Dn -> Password -> IO (Async BindError ())
|
|
bindAsync l username password =
|
|
atomically (bindAsyncSTM l username password)
|
|
|
|
bindAsyncSTM :: Ldap -> Dn -> Password -> STM (Async BindError ())
|
|
bindAsyncSTM l username password =
|
|
sendRequest l bindResult (bindRequest username password)
|
|
|
|
bindRequest :: Dn -> Password -> Request
|
|
bindRequest (Dn username) (Password password) =
|
|
Type.BindRequest ldapVersion
|
|
(Type.LdapDn (Type.LdapString username))
|
|
(Type.Simple password)
|
|
where
|
|
ldapVersion = 3
|
|
|
|
bindResult :: Response -> Either BindError ()
|
|
bindResult (Type.BindResponse (Type.LdapResult code _ _ _) _ :| [])
|
|
| Type.Success <- code = Right ()
|
|
| otherwise = Left (BindErrorCode code)
|
|
bindResult res = Left (BindInvalidResponse res)
|
|
|
|
|
|
data SearchError =
|
|
SearchInvalidResponse Response
|
|
| SearchErrorCode Type.ResultCode
|
|
deriving (Show, Eq, Typeable)
|
|
|
|
instance Exception SearchError
|
|
|
|
search
|
|
:: Ldap
|
|
-> Dn
|
|
-> Mod Search
|
|
-> Filter
|
|
-> [Attr]
|
|
-> IO [SearchEntry]
|
|
search l base opts flt attributes =
|
|
raise =<< searchEither l base opts flt attributes
|
|
|
|
searchEither
|
|
:: Ldap
|
|
-> Dn
|
|
-> Mod Search
|
|
-> Filter
|
|
-> [Attr]
|
|
-> IO (Either SearchError [SearchEntry])
|
|
searchEither l base opts flt attributes =
|
|
wait =<< searchAsync l base opts flt attributes
|
|
|
|
searchAsync
|
|
:: Ldap
|
|
-> Dn
|
|
-> Mod Search
|
|
-> Filter
|
|
-> [Attr]
|
|
-> IO (Async SearchError [SearchEntry])
|
|
searchAsync l base opts flt attributes =
|
|
atomically (searchAsyncSTM l base opts flt attributes)
|
|
|
|
searchAsyncSTM
|
|
:: Ldap
|
|
-> Dn
|
|
-> Mod Search
|
|
-> Filter
|
|
-> [Attr]
|
|
-> STM (Async SearchError [SearchEntry])
|
|
searchAsyncSTM l base opts flt attributes =
|
|
sendRequest l searchResult (searchRequest base opts flt attributes)
|
|
|
|
searchResult :: Response -> Either SearchError [SearchEntry]
|
|
searchResult (Type.SearchResultDone (Type.LdapResult code _ _ _) :| xs)
|
|
| Type.Success <- code = Right (mapMaybe g xs)
|
|
| Type.AdminLimitExceeded <- code = Right (mapMaybe g xs)
|
|
| Type.SizeLimitExceeded <- code = Right (mapMaybe g xs)
|
|
| otherwise = Left (SearchErrorCode code)
|
|
where
|
|
g (Type.SearchResultEntry (Type.LdapDn (Type.LdapString dn))
|
|
(Type.PartialAttributeList ys)) =
|
|
Just (SearchEntry (Dn dn) (map h ys))
|
|
g _ = Nothing
|
|
h (Type.PartialAttribute (Type.AttributeDescription (Type.LdapString x))
|
|
y) = (Attr x, Set.map j y)
|
|
j (Type.AttributeValue x) = x
|
|
searchResult res = Left (SearchInvalidResponse res)
|
|
|
|
searchRequest :: Dn -> Mod Search -> Filter -> [Attr] -> Request
|
|
searchRequest (Dn base) (Mod m) flt attributes =
|
|
Type.SearchRequest (Type.LdapDn (Type.LdapString base))
|
|
_scope
|
|
_derefAliases
|
|
_size
|
|
_time
|
|
_typesOnly
|
|
(fromFilter flt)
|
|
(Type.AttributeSelection (map (Type.LdapString . unAttr) attributes))
|
|
where
|
|
Search { _scope, _derefAliases, _size, _time, _typesOnly } =
|
|
m defaultSearch
|
|
fromFilter (Not x) = Type.Not (fromFilter x)
|
|
fromFilter (And xs) = Type.And (fmap fromFilter xs)
|
|
fromFilter (Or xs) = Type.Or (fmap fromFilter xs)
|
|
fromFilter (Present (Attr x)) =
|
|
Type.Present (Type.AttributeDescription (Type.LdapString x))
|
|
fromFilter (Attr x := y) =
|
|
Type.EqualityMatch
|
|
(Type.AttributeValueAssertion (Type.AttributeDescription (Type.LdapString x))
|
|
(Type.AssertionValue y))
|
|
fromFilter (Attr x :>= y) =
|
|
Type.GreaterOrEqual
|
|
(Type.AttributeValueAssertion (Type.AttributeDescription (Type.LdapString x))
|
|
(Type.AssertionValue y))
|
|
fromFilter (Attr x :<= y) =
|
|
Type.LessOrEqual
|
|
(Type.AttributeValueAssertion (Type.AttributeDescription (Type.LdapString x))
|
|
(Type.AssertionValue y))
|
|
fromFilter (Attr x :~= y) =
|
|
Type.ApproxMatch
|
|
(Type.AttributeValueAssertion (Type.AttributeDescription (Type.LdapString x))
|
|
(Type.AssertionValue y))
|
|
fromFilter (Attr x :=* (mi, xs, mf)) =
|
|
Type.Substrings
|
|
(Type.SubstringFilter (Type.AttributeDescription (Type.LdapString x))
|
|
(NonEmpty.fromList (concat
|
|
[ maybe [] (\i -> [Type.Initial (Type.AssertionValue i)]) mi
|
|
, fmap (Type.Any . Type.AssertionValue) xs
|
|
, maybe [] (\f -> [Type.Final (Type.AssertionValue f)]) mf
|
|
])))
|
|
fromFilter ((mx, mr, b) ::= y) =
|
|
Type.ExtensibleMatch
|
|
(Type.MatchingRuleAssertion (fmap (\(Attr r) -> Type.MatchingRuleId (Type.LdapString r)) mr)
|
|
(fmap (\(Attr x) -> Type.AttributeDescription (Type.LdapString x)) mx)
|
|
(Type.AssertionValue y)
|
|
b)
|
|
|
|
data Search = Search
|
|
{ _scope :: Type.Scope
|
|
, _derefAliases :: Type.DerefAliases
|
|
, _size :: Int32
|
|
, _time :: Int32
|
|
, _typesOnly :: Bool
|
|
} deriving (Show, Eq)
|
|
|
|
defaultSearch :: Search
|
|
defaultSearch = Search
|
|
{ _scope = Type.BaseObject
|
|
, _size = 0
|
|
, _time = 0
|
|
, _typesOnly = False
|
|
, _derefAliases = Type.NeverDerefAliases
|
|
}
|
|
|
|
scope :: Type.Scope -> Mod Search
|
|
scope x = Mod (\y -> y { _scope = x })
|
|
|
|
size :: Int32 -> Mod Search
|
|
size x = Mod (\y -> y { _size = x })
|
|
|
|
time :: Int32 -> Mod Search
|
|
time x = Mod (\y -> y { _time = x })
|
|
|
|
typesOnly :: Bool -> Mod Search
|
|
typesOnly x = Mod (\y -> y { _typesOnly = x })
|
|
|
|
derefAliases :: Type.DerefAliases -> Mod Search
|
|
derefAliases x = Mod (\y -> y { _derefAliases = x })
|
|
|
|
newtype Mod a = Mod (a -> a)
|
|
|
|
instance Monoid (Mod a) where
|
|
mempty = Mod id
|
|
Mod f `mappend` Mod g = Mod (g . f)
|
|
|
|
data Filter =
|
|
Not Filter
|
|
| And (NonEmpty Filter)
|
|
| Or (NonEmpty Filter)
|
|
| Present Attr
|
|
| Attr := ByteString
|
|
| Attr :>= ByteString
|
|
| Attr :<= ByteString
|
|
| Attr :~= ByteString
|
|
| Attr :=* (Maybe ByteString, [ByteString], Maybe ByteString)
|
|
| (Maybe Attr, Maybe Attr, Bool) ::= ByteString
|
|
|
|
newtype Attr = Attr Text
|
|
deriving (Show, Eq)
|
|
|
|
-- 'Attr' unwrapper. This is a separate function not to turn 'Attr''s
|
|
-- 'Show' instance into complete and utter shit.
|
|
unAttr :: Attr -> Text
|
|
unAttr (Attr a) = a
|
|
|
|
data SearchEntry = SearchEntry Dn [(Attr, Set ByteString)]
|
|
deriving (Show, Eq)
|
|
|
|
|
|
-- | Note that 'unbindAsync' does not return an 'Async',
|
|
-- because LDAP server never responds to @UnbindRequest@s, hence
|
|
-- a call to 'wait' on a hypothetical 'Async' would have resulted
|
|
-- in an exception anyway.
|
|
unbindAsync :: Ldap -> IO ()
|
|
unbindAsync =
|
|
atomically . unbindAsyncSTM
|
|
|
|
-- | Note that 'unbindAsyncSTM' does not return an 'Async',
|
|
-- because LDAP server never responds to @UnbindRequest@s, hence
|
|
-- a call to 'wait' on a hypothetical 'Async' would have resulted
|
|
-- in an exception anyway.
|
|
unbindAsyncSTM :: Ldap -> STM ()
|
|
unbindAsyncSTM l =
|
|
void (sendRequest l die Type.UnbindRequest)
|
|
where
|
|
die = error "Ldap.Client: do not wait for the response to UnbindRequest"
|
|
|
|
|
|
type AttrList f = [(Attr, f ByteString)]
|
|
|
|
data AddError =
|
|
AddInvalidResponse Response
|
|
| AddErrorCode Type.ResultCode
|
|
deriving (Show, Eq, Typeable)
|
|
|
|
instance Exception AddError
|
|
|
|
add :: Ldap -> Dn -> AttrList NonEmpty -> IO ()
|
|
add l dn as =
|
|
raise =<< addEither l dn as
|
|
|
|
addEither :: Ldap -> Dn -> AttrList NonEmpty -> IO (Either AddError ())
|
|
addEither l dn as =
|
|
wait =<< addAsync l dn as
|
|
|
|
addAsync :: Ldap -> Dn -> AttrList NonEmpty -> IO (Async AddError ())
|
|
addAsync l dn as =
|
|
atomically (addAsyncSTM l dn as)
|
|
|
|
addAsyncSTM :: Ldap -> Dn -> AttrList NonEmpty -> STM (Async AddError ())
|
|
addAsyncSTM l (Dn dn) as =
|
|
sendRequest l addResult
|
|
(Type.AddRequest (Type.LdapDn (Type.LdapString dn))
|
|
(Type.AttributeList (map f as)))
|
|
where
|
|
f (Attr x, xs) = Type.Attribute (Type.AttributeDescription (Type.LdapString x))
|
|
(fmap Type.AttributeValue xs)
|
|
|
|
addResult :: NonEmpty Type.ProtocolServerOp -> Either AddError ()
|
|
addResult (Type.AddResponse (Type.LdapResult code _ _ _) :| [])
|
|
| Type.Success <- code = Right ()
|
|
| otherwise = Left (AddErrorCode code)
|
|
addResult res = Left (AddInvalidResponse res)
|
|
|
|
|
|
wait :: Async e a -> IO (Either e a)
|
|
wait = atomically . waitSTM
|
|
|
|
waitSTM :: Async e a -> STM (Either e a)
|
|
waitSTM (Async stm) = stm
|
|
|
|
|
|
sendRequest :: Ldap -> (Response -> Either e a) -> Request -> STM (Async e a)
|
|
sendRequest l p msg =
|
|
do var <- newEmptyTMVar
|
|
writeRequest l var msg
|
|
return (Async (fmap p (readTMVar var)))
|
|
|
|
writeRequest :: Ldap -> TMVar Response -> Request -> STM ()
|
|
writeRequest Ldap { client } var msg = writeTQueue client (New msg var)
|
|
|
|
raise :: Exception e => Either e a -> IO a
|
|
raise = either throwIO return
|