ldap-client/src/Ldap/Client.hs
Matvey Aksenov 4b8863c1e4 Support Add
2015-04-01 20:10:10 +00:00

500 lines
16 KiB
Haskell

{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE DeriveDataTypeable #-}
{-# LANGUAGE NamedFieldPuns #-}
module Ldap.Client
( Host(..)
, PortNumber
, Ldap
, LdapError(..)
, Type.ResultCode(..)
, Async
, with
-- * Bind Request
, Dn(..)
, Password(..)
, BindError(..)
, bind
, bindEither
, bindAsync
, bindAsyncSTM
-- * Search Request
, Type.Scope(..)
, Attr(..)
, SearchEntry(..)
, SearchError(..)
, search
, searchEither
, searchAsync
, searchAsyncSTM
, Search
, defaultSearch
, scope
, size
, time
, typesOnly
, derefAliases
, Filter(..)
-- * Unbind Request
, unbindAsync
, unbindAsyncSTM
-- * Add Request
, AttrList
, AddError(..)
, add
, addEither
, addAsync
, addAsyncSTM
-- * Waiting for Request Completion
, wait
, waitSTM
) where
import qualified Control.Concurrent.Async as Async
import Control.Concurrent.STM (STM, atomically)
import Control.Concurrent.STM.TMVar (TMVar, newEmptyTMVar, putTMVar, readTMVar)
import Control.Concurrent.STM.TQueue (TQueue, newTQueueIO, writeTQueue, readTQueue)
import Control.Exception (Exception, Handler(..), bracket, throwIO, catches)
import Control.Monad (forever, void)
import qualified Data.ASN1.BinaryEncoding as Asn1
import qualified Data.ASN1.Encoding as Asn1
import qualified Data.ASN1.Error as Asn1
import Data.ByteString (ByteString)
import qualified Data.ByteString as ByteString
import qualified Data.ByteString.Lazy as ByteString.Lazy
import Data.Foldable (traverse_, asum)
import Data.Function (fix)
import Data.Int (Int32)
import Data.List.NonEmpty (NonEmpty((:|)))
import qualified Data.List.NonEmpty as NonEmpty
import qualified Data.Map.Strict as Map
import Data.Maybe (mapMaybe)
import Data.Monoid (Endo(appEndo))
import Data.Semigroup (Semigroup((<>)))
import Data.Set (Set)
import qualified Data.Set as Set
import Data.Text (Text)
import Data.Typeable (Typeable)
import Network.Connection (Connection)
import qualified Network.Connection as Conn
import Network (PortNumber)
import qualified System.IO.Error as IO
import Ldap.Asn1.ToAsn1 (ToAsn1(toAsn1))
import Ldap.Asn1.FromAsn1 (FromAsn1, parseAsn1)
import qualified Ldap.Asn1.Type as Type
-- | The server to connect to, together with the transport to use.
data Host
  = Plain String   -- ^ Connect to the given hostname without TLS.
  | Secure String  -- ^ Connect to the given hostname over TLS.
  deriving (Show, Eq, Ord)
-- | A handle to an LDAP session.
--
-- Requests are submitted by writing 'ClientMessage's to the 'client' queue.
data Ldap = Ldap
  { client :: TQueue ClientMessage  -- ^ Queue of requests awaiting dispatch.
  } deriving (Eq)
-- | A request submitted by the user, paired with the variable that
-- will be filled with the server's response to it.
data ClientMessage = New Request (TMVar Response)

-- | An operation sent from the client to the server.
type Request = Type.ProtocolClientOp

-- | A single operation received from the server.
type InMessage = Type.ProtocolServerOp

-- | Everything the server sent back for one request.
type Response = NonEmpty InMessage
-- | Create a fresh 'Ldap' handle with an empty request queue.
newLdap :: IO Ldap
newLdap = fmap Ldap newTQueueIO
-- | Everything that 'with' catches and reports as a 'Left'.
data LdapError
  = IOError IOError              -- ^ An I/O exception, e.g. the connection was closed.
  | ParseError Asn1.ASN1Error    -- ^ The incoming byte stream could not be decoded.
  | BindError BindError          -- ^ Thrown by 'bind'.
  | SearchError SearchError      -- ^ Thrown by 'search'.
  | AddError AddError            -- ^ Thrown by 'add'.
  deriving (Show, Eq)

-- | The entrypoint into LDAP.
--
-- Connects to the given host and port, starts the worker threads that
-- read from the connection ('input'), write to it ('output') and route
-- messages ('dispatch'), and runs the user action alongside them.
-- Returns as soon as any of the four threads finishes; the remaining
-- ones are cancelled.
--
-- 'IOError's, ASN.1 decoding errors, and the 'BindError', 'SearchError'
-- and 'AddError' exceptions thrown by 'bind', 'search' and 'add' are
-- caught and returned as a 'Left'.  Any other exception propagates.
--
-- NOTE(review): the release action 'unbindAsync' only enqueues an
-- @UnbindRequest@; by then the worker threads started with
-- 'Async.withAsync' have already been cancelled, so it looks like the
-- request is never written to the connection -- confirm.
with :: Host -> PortNumber -> (Ldap -> IO a) -> IO (Either LdapError a)
with host port f = session `catches` handlers
 where
  session = do
    context <- Conn.initConnectionContext
    bracket (Conn.connectTo context params) Conn.connectionClose $ \conn ->
      bracket newLdap unbindAsync $ \l -> do
        inq  <- newTQueueIO
        outq <- newTQueueIO
        Async.withAsync (input inq conn) $ \i ->
          Async.withAsync (output outq conn) $ \o ->
            Async.withAsync (dispatch l inq outq) $ \d ->
              Async.withAsync (f l) $ \u ->
                fmap (Right . snd) (Async.waitAnyCancel [i, o, d, u])

  -- One handler per 'LdapError' constructor.
  handlers =
    [ Handler (return . Left . IOError)
    , Handler (return . Left . ParseError)
    , Handler (return . Left . BindError)
    , Handler (return . Left . SearchError)
    , Handler (return . Left . AddError)
    ]

  params = Conn.ConnectionParams
    { Conn.connectionHostname  = hostname
    , Conn.connectionPort      = port
    , Conn.connectionUseSecure = tls
    , Conn.connectionUseSocks  = Nothing
    }

  -- Hostname and TLS settings both derive from the 'Host' constructor.
  (hostname, tls) =
    case host of
      Plain h  -> (h, Nothing)
      Secure h -> (h, Just Conn.TLSSettingsSimple
        { Conn.settingDisableCertificateValidation = False
        , Conn.settingDisableSession = False
        , Conn.settingUseServerName = False
        })
-- | Reader thread: pull bytes off the connection, decode them as DER,
-- and push every message that parses onto the given queue.
--
-- Chunks are accumulated until the decoder stops reporting
-- 'Asn1.ParsingPartial'; the buffer is then reset.  Throws an EOF
-- 'IOError' when the connection yields an empty chunk and rethrows any
-- other ASN.1 decoding error.  Never returns normally.
input :: FromAsn1 a => TQueue a -> Connection -> IO b
input inq conn = receive []
 where
  -- @pending@ holds the chunks read so far, newest first.
  receive pending = do
    chunk <- Conn.connectionGet conn 8192
    if ByteString.null chunk
      then throwIO (IO.mkIOError IO.eofErrorType "Ldap.Client.input" Nothing Nothing)
      else do
        let pending' = chunk : pending
            bytes    = ByteString.Lazy.fromChunks (reverse pending')
        case Asn1.decodeASN1 Asn1.DER bytes of
          Left Asn1.ParsingPartial -> receive pending'
          Left e -> throwIO e
          Right asn1 -> do
            enqueue asn1
            receive []

  -- Parse messages off the front of the ASN.1 stream until parsing
  -- fails; whatever is left over at that point is dropped.
  enqueue asn1 =
    case parseAsn1 asn1 of
      Nothing -> return ()
      Just (rest, a) -> do
        atomically (writeTQueue inq a)
        enqueue rest
-- | Writer thread: take messages off the queue one at a time,
-- DER-encode them and write them to the connection.  Never returns
-- normally.
output :: ToAsn1 a => TQueue a -> Connection -> IO b
output out conn = forever (atomically (readTQueue out) >>= send)
 where
  send msg =
    Conn.connectionPut conn (Asn1.encodeASN1' Asn1.DER (appEndo (toAsn1 msg) []))
-- | Router thread: assign message ids to outgoing requests and match
-- incoming server messages to the request waiting for them.
--
-- The loop state is:
--
-- * @got@: per message id, the @SearchResultEntry@ messages received so
--   far (newest first);
--
-- * @results@: per message id, the variable to fill once the request
--   completes;
--
-- * @counter@: the id to give to the next request, starting at 1.
--
-- Never returns normally.
dispatch
  :: Ldap
  -> TQueue (Type.LdapMessage InMessage)
  -> TQueue (Type.LdapMessage Request)
  -> IO a
dispatch Ldap { client } inq outq = go Map.empty Map.empty 1
 where
  go !got !results !counter = do
    (got', results', counter') <- atomically (asum [fromClient, fromServer])
    go got' results' counter'
   where
    -- A new request from the user: tag it with a fresh id and forward it.
    fromClient = do
      New new var <- readTQueue client
      writeTQueue outq (Type.LdapMessage (Type.Id counter) new Nothing)
      return (got, Map.insert (Type.Id counter) var results, counter + 1)

    -- A message from the server: either buffer it or complete a request.
    fromServer = do
      Type.LdapMessage mid op _ <- readTQueue inq
      case op of
        Type.BindResponse {} ->
          complete mid op []
        Type.SearchResultEntry {} ->
          return (Map.insertWith (++) mid [op] got, results, counter)
        Type.SearchResultReference {} ->
          -- References are ignored.
          return (got, results, counter)
        Type.SearchResultDone {} ->
          complete mid op (Map.findWithDefault [] mid got)
        Type.AddResponse {} ->
          complete mid op []

    -- Hand the final message (plus anything buffered) to the waiting
    -- variable, if there is one, and forget about the message id.
    complete mid op stack = do
      traverse_ (\var -> putTMVar var (op :| stack)) (Map.lookup mid results)
      return (Map.delete mid got, Map.delete mid results, counter)
-- | A request in flight.  Wraps the STM action that yields the outcome
-- of the request: either an error @e@ or a result @a@.
data Async e a = Async (STM (Either e a))

-- | Maps over the successful result, leaving errors untouched.
instance Functor (Async e) where
  fmap f (Async outcome) = Async (fmap f <$> outcome)
-- | A distinguished name, e.g. the entry to bind as or a search base.
newtype Dn
  = Dn Text
  deriving (Show, Eq)
-- | The password sent with a simple bind request.
newtype Password
  = Password ByteString
  deriving (Show, Eq)
-- | The ways a bind request can fail.
data BindError
  = BindInvalidResponse Response    -- ^ The response was not a lone @BindResponse@.
  | BindErrorCode Type.ResultCode   -- ^ The server answered with a non-success code.
  deriving (Show, Eq, Typeable)

instance Exception BindError
-- | Perform a simple bind and block until the server answers.
--
-- Throws 'BindError' on failure.  The exception is caught by the
-- enclosing 'with', so it does not escape into the rest of the program.
bind :: Ldap -> Dn -> Password -> IO ()
bind l username password =
  bindEither l username password >>= raise
-- | Like 'bind', but returns the failure as a 'Left' instead of
-- throwing it.
bindEither :: Ldap -> Dn -> Password -> IO (Either BindError ())
bindEither l username password =
  bindAsync l username password >>= wait
-- | Submit a bind request without waiting for the response; use 'wait'
-- on the result to get it.
bindAsync :: Ldap -> Dn -> Password -> IO (Async BindError ())
bindAsync l username password =
  atomically $ bindAsyncSTM l username password
-- | 'bindAsync' as an 'STM' action, so that it can be composed with
-- other transactions.
bindAsyncSTM :: Ldap -> Dn -> Password -> STM (Async BindError ())
bindAsyncSTM l username password =
  sendRequest l bindResult request
 where
  request = bindRequest username password
-- | Build a simple-authentication bind request for the given
-- name and password.
bindRequest :: Dn -> Password -> Request
bindRequest (Dn username) (Password password) =
  Type.BindRequest version name credentials
 where
  version     = 3  -- protocol version sent in the request
  name        = Type.LdapDn (Type.LdapString username)
  credentials = Type.Simple password
-- | Interpret the server's response to a bind request.
--
-- Only a single @BindResponse@ is accepted; its result code decides
-- between success and 'BindErrorCode'.  Anything else is reported as
-- 'BindInvalidResponse'.
bindResult :: Response -> Either BindError ()
bindResult (Type.BindResponse (Type.LdapResult code _ _ _) _ :| []) =
  case code of
    Type.Success -> Right ()
    _            -> Left (BindErrorCode code)
bindResult res = Left (BindInvalidResponse res)
-- | The ways a search request can fail.
data SearchError
  = SearchInvalidResponse Response    -- ^ The response did not start with @SearchResultDone@.
  | SearchErrorCode Type.ResultCode   -- ^ The server answered with a code not treated as success.
  deriving (Show, Eq, Typeable)

instance Exception SearchError
-- | Run a search and block until the server has finished answering.
--
-- Throws 'SearchError' on failure; the enclosing 'with' catches it.
search
  :: Ldap
  -> Dn          -- ^ search base
  -> Mod Search  -- ^ option overrides applied to 'defaultSearch'
  -> Filter
  -> [Attr]      -- ^ attributes to request
  -> IO [SearchEntry]
search l base opts flt attributes =
  searchEither l base opts flt attributes >>= raise
-- | Like 'search', but returns the failure as a 'Left' instead of
-- throwing it.
searchEither
  :: Ldap
  -> Dn
  -> Mod Search
  -> Filter
  -> [Attr]
  -> IO (Either SearchError [SearchEntry])
searchEither l base opts flt attributes =
  searchAsync l base opts flt attributes >>= wait
-- | Submit a search request without waiting for the response; use
-- 'wait' on the result to get it.
searchAsync
  :: Ldap
  -> Dn
  -> Mod Search
  -> Filter
  -> [Attr]
  -> IO (Async SearchError [SearchEntry])
searchAsync l base opts flt attributes =
  atomically $ searchAsyncSTM l base opts flt attributes
-- | 'searchAsync' as an 'STM' action, so that it can be composed with
-- other transactions.
searchAsyncSTM
  :: Ldap
  -> Dn
  -> Mod Search
  -> Filter
  -> [Attr]
  -> STM (Async SearchError [SearchEntry])
searchAsyncSTM l base opts flt attributes =
  sendRequest l searchResult request
 where
  request = searchRequest base opts flt attributes
-- | Interpret the server's response to a search request.
--
-- The response must start with @SearchResultDone@.  The result codes
-- @Success@, @AdminLimitExceeded@ and @SizeLimitExceeded@ all yield the
-- entries received; any other code becomes 'SearchErrorCode'.
-- Messages following the first that are not @SearchResultEntry@ are
-- skipped.
searchResult :: Response -> Either SearchError [SearchEntry]
searchResult (Type.SearchResultDone (Type.LdapResult code _ _ _) :| xs) =
  case code of
    Type.Success            -> Right entries
    Type.AdminLimitExceeded -> Right entries
    Type.SizeLimitExceeded  -> Right entries
    _                       -> Left (SearchErrorCode code)
 where
  -- The pattern on the left of '<-' drops non-entry messages.
  entries =
    [ SearchEntry (Dn dn) (map attribute ys)
    | Type.SearchResultEntry (Type.LdapDn (Type.LdapString dn))
                             (Type.PartialAttributeList ys) <- xs
    ]
  attribute (Type.PartialAttribute (Type.AttributeDescription (Type.LdapString x)) vs) =
    (Attr x, Set.map (\(Type.AttributeValue v) -> v) vs)
searchResult res = Left (SearchInvalidResponse res)
-- | Build a search request.
--
-- The options are obtained by applying the given modifier to
-- 'defaultSearch'; the 'Filter' is translated constructor by
-- constructor into its wire representation.
--
-- A substring filter with no components at all,
-- @attr ':=*' (Nothing, [], Nothing)@, is sent as a presence filter,
-- since a substring filter must carry at least one component.
-- (This case used to crash in 'NonEmpty.fromList'.)
searchRequest :: Dn -> Mod Search -> Filter -> [Attr] -> Request
searchRequest (Dn base) (Mod m) flt attributes =
  Type.SearchRequest (Type.LdapDn (Type.LdapString base))
                     _scope
                     _derefAliases
                     _size
                     _time
                     _typesOnly
                     (fromFilter flt)
                     (Type.AttributeSelection (map (Type.LdapString . unAttr) attributes))
 where
  Search { _scope, _derefAliases, _size, _time, _typesOnly } =
    m defaultSearch

  -- Wire representation of an attribute name.
  description (Attr x) = Type.AttributeDescription (Type.LdapString x)

  -- Attribute/value pair shared by all the comparison filters.
  assertion attr value =
    Type.AttributeValueAssertion (description attr) (Type.AssertionValue value)

  fromFilter (Not x) = Type.Not (fromFilter x)
  fromFilter (And xs) = Type.And (fmap fromFilter xs)
  fromFilter (Or xs) = Type.Or (fmap fromFilter xs)
  fromFilter (Present attr) = Type.Present (description attr)
  fromFilter (attr := y) = Type.EqualityMatch (assertion attr y)
  fromFilter (attr :>= y) = Type.GreaterOrEqual (assertion attr y)
  fromFilter (attr :<= y) = Type.LessOrEqual (assertion attr y)
  fromFilter (attr :~= y) = Type.ApproxMatch (assertion attr y)
  fromFilter (attr :=* (mi, xs, mf)) =
    case NonEmpty.nonEmpty substrings of
      -- No components: degrade to a presence test instead of crashing.
      Nothing -> Type.Present (description attr)
      Just ss -> Type.Substrings (Type.SubstringFilter (description attr) ss)
   where
    substrings = concat
      [ maybe [] (\i -> [Type.Initial (Type.AssertionValue i)]) mi
      , fmap (Type.Any . Type.AssertionValue) xs
      , maybe [] (\f -> [Type.Final (Type.AssertionValue f)]) mf
      ]
  fromFilter ((mx, mr, b) ::= y) =
    Type.ExtensibleMatch
      (Type.MatchingRuleAssertion
        (fmap (\(Attr r) -> Type.MatchingRuleId (Type.LdapString r)) mr)
        (fmap description mx)
        (Type.AssertionValue y)
        b)
-- | Search options.  Each field is passed on unchanged as the
-- corresponding argument of the search request; use the 'Mod' setters
-- ('scope', 'size', 'time', 'typesOnly', 'derefAliases') to change them.
data Search = Search
  { _scope        :: Type.Scope
  , _derefAliases :: Type.DerefAliases
  , _size         :: Int32
  , _time         :: Int32
  , _typesOnly    :: Bool
  } deriving (Show, Eq)
-- | The options a search starts from before any 'Mod' is applied:
-- base-object scope, aliases never dereferenced, size and time set
-- to 0, and attribute values included.
defaultSearch :: Search
defaultSearch = Search
  { _scope        = Type.BaseObject
  , _derefAliases = Type.NeverDerefAliases
  , _size         = 0
  , _time         = 0
  , _typesOnly    = False
  }
-- | Set the search scope.
scope :: Type.Scope -> Mod Search
scope x = Mod $ \s -> s { _scope = x }

-- | Set the '_size' option of the search.
size :: Int32 -> Mod Search
size x = Mod $ \s -> s { _size = x }

-- | Set the '_time' option of the search.
time :: Int32 -> Mod Search
time x = Mod $ \s -> s { _time = x }

-- | Set the '_typesOnly' option of the search.
typesOnly :: Bool -> Mod Search
typesOnly x = Mod $ \s -> s { _typesOnly = x }

-- | Set the alias dereferencing policy.
derefAliases :: Type.DerefAliases -> Mod Search
derefAliases x = Mod $ \s -> s { _derefAliases = x }
-- | A modification of a value, e.g. of the 'Search' options.
-- Modifications compose left to right: in @f '<>' g@, @f@ is applied
-- first and @g@ second, so later settings win.
newtype Mod a = Mod (a -> a)

-- | Required as a superclass of 'Monoid' by recent versions of @base@.
instance Semigroup (Mod a) where
  Mod f <> Mod g = Mod (g . f)

-- | 'mempty' leaves the value untouched.
instance Monoid (Mod a) where
  mempty = Mod id
  mappend = (<>)
-- | Search filters.
data Filter
  = Not Filter                -- ^ Negation.
  | And (NonEmpty Filter)     -- ^ Conjunction.
  | Or (NonEmpty Filter)      -- ^ Disjunction.
  | Present Attr              -- ^ Presence of the attribute.
  | Attr := ByteString        -- ^ Equality match.
  | Attr :>= ByteString       -- ^ Greater-or-equal match.
  | Attr :<= ByteString       -- ^ Less-or-equal match.
  | Attr :~= ByteString       -- ^ Approximate match.
  | Attr :=* (Maybe ByteString, [ByteString], Maybe ByteString)
    -- ^ Substring match: optional initial part, any number of middle
    -- parts, optional final part.
  | (Maybe Attr, Maybe Attr, Bool) ::= ByteString
    -- ^ Extensible match: optional attribute, optional matching rule,
    -- and a flag passed through to the request as is.
-- | An attribute name.
newtype Attr
  = Attr Text
  deriving (Show, Eq)
-- Unwrap an 'Attr'.  This is a standalone function rather than a
-- record field so that the derived 'Show' instance of 'Attr' stays
-- compact.
unAttr :: Attr -> Text
unAttr attr = case attr of
  Attr name -> name
-- | One entry returned by a search: its distinguished name and its
-- attributes, each with the set of its values.
data SearchEntry
  = SearchEntry Dn [(Attr, Set ByteString)]
  deriving (Show, Eq)
-- | Submit an @UnbindRequest@.
--
-- Unlike the other @*Async@ functions this does not return an 'Async':
-- the server never responds to an @UnbindRequest@, so a call to 'wait'
-- on such an 'Async' would only ever have resulted in an exception.
unbindAsync :: Ldap -> IO ()
unbindAsync l =
  atomically (unbindAsyncSTM l)
-- | 'unbindAsync' as an 'STM' action.
--
-- Does not return an 'Async': the server never responds to an
-- @UnbindRequest@, so a call to 'wait' on such an 'Async' would only
-- ever have resulted in an exception.
unbindAsyncSTM :: Ldap -> STM ()
unbindAsyncSTM l =
  sendRequest l noResponse Type.UnbindRequest >> return ()
 where
  -- Never evaluated unless somebody waits on the discarded 'Async'.
  noResponse = error "Ldap.Client: do not wait for the response to UnbindRequest"
-- | A list of attributes, each paired with a container @f@ of its
-- values (e.g. 'NonEmpty' for 'add').
type AttrList f = [(Attr, f ByteString)]
-- | The ways an add request can fail.
data AddError
  = AddInvalidResponse Response    -- ^ The response was not a lone @AddResponse@.
  | AddErrorCode Type.ResultCode   -- ^ The server answered with a non-success code.
  deriving (Show, Eq, Typeable)

instance Exception AddError
-- | Add an entry with the given attributes and block until the server
-- answers.  Throws 'AddError' on failure.
add :: Ldap -> Dn -> AttrList NonEmpty -> IO ()
add l dn as =
  addEither l dn as >>= raise
-- | Like 'add', but returns the failure as a 'Left' instead of
-- throwing it.
addEither :: Ldap -> Dn -> AttrList NonEmpty -> IO (Either AddError ())
addEither l dn as =
  addAsync l dn as >>= wait
-- | Submit an add request without waiting for the response; use 'wait'
-- on the result to get it.
addAsync :: Ldap -> Dn -> AttrList NonEmpty -> IO (Async AddError ())
addAsync l dn as =
  atomically $ addAsyncSTM l dn as
-- | 'addAsync' as an 'STM' action, so that it can be composed with
-- other transactions.
addAsyncSTM :: Ldap -> Dn -> AttrList NonEmpty -> STM (Async AddError ())
addAsyncSTM l (Dn dn) as = sendRequest l addResult request
 where
  request =
    Type.AddRequest (Type.LdapDn (Type.LdapString dn))
                    (Type.AttributeList attributes)
  attributes =
    [ Type.Attribute (Type.AttributeDescription (Type.LdapString name))
                     (fmap Type.AttributeValue values)
    | (Attr name, values) <- as
    ]
-- | Interpret the server's response to an add request.
--
-- Only a single @AddResponse@ is accepted; its result code decides
-- between success and 'AddErrorCode'.  Anything else is reported as
-- 'AddInvalidResponse'.
addResult :: NonEmpty Type.ProtocolServerOp -> Either AddError ()
addResult (Type.AddResponse (Type.LdapResult code _ _ _) :| []) =
  case code of
    Type.Success -> Right ()
    _            -> Left (AddErrorCode code)
addResult res = Left (AddInvalidResponse res)
-- | Block until the request has completed and return its outcome.
wait :: Async e a -> IO (Either e a)
wait request = atomically (waitSTM request)
-- | 'wait' as an 'STM' action: the transaction wrapped by the 'Async'.
waitSTM :: Async e a -> STM (Either e a)
waitSTM request = case request of
  Async outcome -> outcome
-- | Queue a request for dispatch and return an 'Async' that, once the
-- response has arrived, yields it run through the given interpreter.
sendRequest :: Ldap -> (Response -> Either e a) -> Request -> STM (Async e a)
sendRequest l p msg = do
  var <- newEmptyTMVar
  writeTQueue (client l) (New msg var)
  return (Async (p <$> readTMVar var))
-- | Put a request, together with the variable its response should be
-- written to, onto the client queue.
writeRequest :: Ldap -> TMVar Response -> Request -> STM ()
writeRequest l var msg =
  writeTQueue (client l) (New msg var)
-- | Turn a 'Left' into a thrown exception and a 'Right' into a plain
-- result.
raise :: Exception e => Either e a -> IO a
raise (Left e)  = throwIO e
raise (Right a) = return a