From e56c2b41c9146a4c62f077da0100fa8da4750587 Mon Sep 17 00:00:00 2001 From: Matvey Aksenov Date: Wed, 1 Apr 2015 21:42:14 +0000 Subject: [PATCH] Shuffle things around --- example/login.hs | 23 +- src/Ldap/Client.hs | 397 +++------------------------------ src/Ldap/Client/Bind.hs | 77 +++++++ src/Ldap/Client/Internal.hs | 172 ++++++++++++++ src/Ldap/Client/Search.hs | 197 ++++++++++++++++ test/Ldap/Client/BindSpec.hs | 31 +++ test/Ldap/Client/SearchSpec.hs | 134 +++++++++++ test/Ldap/ClientSpec.hs | 172 ++------------ test/SpecHelper.hs | 81 ++++++- 9 files changed, 739 insertions(+), 545 deletions(-) create mode 100644 src/Ldap/Client/Bind.hs create mode 100644 src/Ldap/Client/Internal.hs create mode 100644 src/Ldap/Client/Search.hs create mode 100644 test/Ldap/Client/BindSpec.hs create mode 100644 test/Ldap/Client/SearchSpec.hs diff --git a/example/login.hs b/example/login.hs index 36dadb7..7cbdf29 100644 --- a/example/login.hs +++ b/example/login.hs @@ -23,23 +23,18 @@ import Data.Text (Text) -- text import qualified Data.Text.Encoding as Text -- text import qualified Data.Text.IO as Text -- text import Env -- envparse -import qualified Ldap.Client as Ldap -- ldap-client -import Ldap.Client -- ldap-client - ( LdapError - , Scope(..) - , Filter(..) - , Attr(..) - ) +import Ldap.Client as Ldap -- ldap-client +import qualified Ldap.Client.Bind as Ldap -- ldap-client import System.Exit (die) -- base import qualified System.IO as IO -- base data Conf = Conf { host :: String - , port :: Ldap.PortNumber - , dn :: Ldap.Dn - , password :: Ldap.Password - , base :: Ldap.Dn + , port :: PortNumber + , dn :: Dn + , password :: Password + , base :: Dn } deriving (Show, Eq) getConf :: IO Conf @@ -65,20 +60,20 @@ login conf = fix $ \loop -> do uid <- prompt "Username: " us <- Ldap.search l (base conf) - (Ldap.scope WholeSubtree <> Ldap.typesOnly True) + (scope WholeSubtree <> typesOnly True) (And [ Attr "objectClass" := "Person" , Attr "uid" := Text.encodeUtf8 uid ]) [] case us of - Ldap.SearchEntry udn _ : _ -> + SearchEntry udn _ : _ -> fix $ \loop' -> do pwd <- bracket_ hideOutput showOutput (do pwd <- prompt ("Password for ‘" <> uid <> "’: ") Text.putStr "\n" return pwd) - res <- Ldap.bindEither l udn (Ldap.Password (Text.encodeUtf8 pwd)) + res <- Ldap.bindEither l udn (Password (Text.encodeUtf8 pwd)) case res of Left _ -> do again <- question "Invalid password. Try again? [y/n] " when again loop' diff --git a/src/Ldap/Client.hs b/src/Ldap/Client.hs index 1105beb..77fcea6 100644 --- a/src/Ldap/Client.hs +++ b/src/Ldap/Client.hs @@ -14,29 +14,19 @@ module Ldap.Client , Password(..) , BindError(..) , bind - , bindEither - , bindAsync - , bindAsyncSTM -- * Search Request - , Type.Scope(..) , Attr(..) - , SearchEntry(..) , SearchError(..) , search - , searchEither - , searchAsync - , searchAsyncSTM , Search - , defaultSearch , scope + , Type.Scope(..) , size , time , typesOnly , derefAliases , Filter(..) - -- * Unbind Request - , unbindAsync - , unbindAsyncSTM + , SearchEntry(..) -- * Add Request , AttrList , AddError(..) @@ -56,53 +46,44 @@ module Ldap.Client ) where import qualified Control.Concurrent.Async as Async -import Control.Concurrent.STM (STM, atomically) -import Control.Concurrent.STM.TMVar (TMVar, newEmptyTMVar, putTMVar, readTMVar) +import Control.Concurrent.STM (atomically) +import Control.Concurrent.STM.TMVar (putTMVar) import Control.Concurrent.STM.TQueue (TQueue, newTQueueIO, writeTQueue, readTQueue) -import Control.Exception (Exception, Handler(..), bracket, throwIO, catches) -import Control.Monad (forever, void) +import Control.Exception (Handler(..), bracket, throwIO, catches) +import Control.Monad (forever) import qualified Data.ASN1.BinaryEncoding as Asn1 import qualified Data.ASN1.Encoding as Asn1 import qualified Data.ASN1.Error as Asn1 -import Data.ByteString (ByteString) import qualified Data.ByteString as ByteString import qualified Data.ByteString.Lazy as ByteString.Lazy import Data.Foldable (traverse_, asum) import Data.Function (fix) -import Data.Int (Int32) import Data.List.NonEmpty (NonEmpty((:|))) -import qualified Data.List.NonEmpty as NonEmpty import qualified Data.Map.Strict as Map -import Data.Maybe (mapMaybe) import Data.Monoid (Endo(appEndo)) -import Data.Set (Set) -import qualified Data.Set as Set -import Data.Text (Text) -import Data.Typeable (Typeable) import Network.Connection (Connection) import qualified Network.Connection as Conn -import Network (PortNumber) import qualified System.IO.Error as IO import Ldap.Asn1.ToAsn1 (ToAsn1(toAsn1)) import Ldap.Asn1.FromAsn1 (FromAsn1, parseAsn1) import qualified Ldap.Asn1.Type as Type +import Ldap.Client.Internal +import Ldap.Client.Bind (BindError(..), bind, unbindAsync) +import Ldap.Client.Search + ( SearchError(..) + , search + , Search + , scope + , size + , time + , typesOnly + , derefAliases + , Filter(..) + , SearchEntry(..) + ) -data Host = - Plain String - | Secure String - deriving (Show, Eq, Ord) - -data Ldap = Ldap - { client :: TQueue ClientMessage - } deriving (Eq) - -data ClientMessage = New Request (TMVar (NonEmpty Type.ProtocolServerOp)) -type Request = Type.ProtocolClientOp -type InMessage = Type.ProtocolServerOp -type Response = NonEmpty InMessage - newLdap :: IO Ldap newLdap = Ldap <$> newTQueueIO @@ -182,7 +163,11 @@ output out conn = forever $ do where encode x = Asn1.encodeASN1' Asn1.DER (appEndo x []) -dispatch :: Ldap -> TQueue (Type.LdapMessage InMessage) -> TQueue (Type.LdapMessage Request) -> IO a +dispatch + :: Ldap + -> TQueue (Type.LdapMessage Type.ProtocolServerOp) + -> TQueue (Type.LdapMessage Request) + -> IO a dispatch Ldap { client } inq outq = flip fix (Map.empty, Map.empty, 1) $ \loop (!got, !results, !counter) -> do loop =<< atomically (asum @@ -209,335 +194,3 @@ dispatch Ldap { client } inq outq = traverse_ (\var -> putTMVar var (op :| [])) (Map.lookup mid results) return (Map.delete mid got, Map.delete mid results, counter) ]) - - -data Async e a = Async (STM (Either e a)) - -instance Functor (Async e) where - fmap f (Async stm) = Async (fmap (fmap f) stm) - - -newtype Dn = Dn Text - deriving (Show, Eq) -newtype Password = Password ByteString - deriving (Show, Eq) - -data BindError = - BindInvalidResponse Response - | BindErrorCode Type.ResultCode - deriving (Show, Eq, Typeable) - -instance Exception BindError - --- | Throws 'BindError' on failure. Don't worry, the nearest 'with' --- will catch it, so it won't destroy your program. -bind :: Ldap -> Dn -> Password -> IO () -bind l username password = - raise =<< bindEither l username password - -bindEither :: Ldap -> Dn -> Password -> IO (Either BindError ()) -bindEither l username password = - wait =<< bindAsync l username password - -bindAsync :: Ldap -> Dn -> Password -> IO (Async BindError ()) -bindAsync l username password = - atomically (bindAsyncSTM l username password) - -bindAsyncSTM :: Ldap -> Dn -> Password -> STM (Async BindError ()) -bindAsyncSTM l username password = - sendRequest l bindResult (bindRequest username password) - -bindRequest :: Dn -> Password -> Request -bindRequest (Dn username) (Password password) = - Type.BindRequest ldapVersion - (Type.LdapDn (Type.LdapString username)) - (Type.Simple password) - where - ldapVersion = 3 - -bindResult :: Response -> Either BindError () -bindResult (Type.BindResponse (Type.LdapResult code _ _ _) _ :| []) - | Type.Success <- code = Right () - | otherwise = Left (BindErrorCode code) -bindResult res = Left (BindInvalidResponse res) - - -data SearchError = - SearchInvalidResponse Response - | SearchErrorCode Type.ResultCode - deriving (Show, Eq, Typeable) - -instance Exception SearchError - -search - :: Ldap - -> Dn - -> Mod Search - -> Filter - -> [Attr] - -> IO [SearchEntry] -search l base opts flt attributes = - raise =<< searchEither l base opts flt attributes - -searchEither - :: Ldap - -> Dn - -> Mod Search - -> Filter - -> [Attr] - -> IO (Either SearchError [SearchEntry]) -searchEither l base opts flt attributes = - wait =<< searchAsync l base opts flt attributes - -searchAsync - :: Ldap - -> Dn - -> Mod Search - -> Filter - -> [Attr] - -> IO (Async SearchError [SearchEntry]) -searchAsync l base opts flt attributes = - atomically (searchAsyncSTM l base opts flt attributes) - -searchAsyncSTM - :: Ldap - -> Dn - -> Mod Search - -> Filter - -> [Attr] - -> STM (Async SearchError [SearchEntry]) -searchAsyncSTM l base opts flt attributes = - sendRequest l searchResult (searchRequest base opts flt attributes) - -searchResult :: Response -> Either SearchError [SearchEntry] -searchResult (Type.SearchResultDone (Type.LdapResult code _ _ _) :| xs) - | Type.Success <- code = Right (mapMaybe g xs) - | Type.AdminLimitExceeded <- code = Right (mapMaybe g xs) - | Type.SizeLimitExceeded <- code = Right (mapMaybe g xs) - | otherwise = Left (SearchErrorCode code) - where - g (Type.SearchResultEntry (Type.LdapDn (Type.LdapString dn)) - (Type.PartialAttributeList ys)) = - Just (SearchEntry (Dn dn) (map h ys)) - g _ = Nothing - h (Type.PartialAttribute (Type.AttributeDescription (Type.LdapString x)) - y) = (Attr x, Set.map j y) - j (Type.AttributeValue x) = x -searchResult res = Left (SearchInvalidResponse res) - -searchRequest :: Dn -> Mod Search -> Filter -> [Attr] -> Request -searchRequest (Dn base) (Mod m) flt attributes = - Type.SearchRequest (Type.LdapDn (Type.LdapString base)) - _scope - _derefAliases - _size - _time - _typesOnly - (fromFilter flt) - (Type.AttributeSelection (map (Type.LdapString . unAttr) attributes)) - where - Search { _scope, _derefAliases, _size, _time, _typesOnly } = - m defaultSearch - fromFilter (Not x) = Type.Not (fromFilter x) - fromFilter (And xs) = Type.And (fmap fromFilter xs) - fromFilter (Or xs) = Type.Or (fmap fromFilter xs) - fromFilter (Present (Attr x)) = - Type.Present (Type.AttributeDescription (Type.LdapString x)) - fromFilter (Attr x := y) = - Type.EqualityMatch - (Type.AttributeValueAssertion (Type.AttributeDescription (Type.LdapString x)) - (Type.AssertionValue y)) - fromFilter (Attr x :>= y) = - Type.GreaterOrEqual - (Type.AttributeValueAssertion (Type.AttributeDescription (Type.LdapString x)) - (Type.AssertionValue y)) - fromFilter (Attr x :<= y) = - Type.LessOrEqual - (Type.AttributeValueAssertion (Type.AttributeDescription (Type.LdapString x)) - (Type.AssertionValue y)) - fromFilter (Attr x :~= y) = - Type.ApproxMatch - (Type.AttributeValueAssertion (Type.AttributeDescription (Type.LdapString x)) - (Type.AssertionValue y)) - fromFilter (Attr x :=* (mi, xs, mf)) = - Type.Substrings - (Type.SubstringFilter (Type.AttributeDescription (Type.LdapString x)) - (NonEmpty.fromList (concat - [ maybe [] (\i -> [Type.Initial (Type.AssertionValue i)]) mi - , fmap (Type.Any . Type.AssertionValue) xs - , maybe [] (\f -> [Type.Final (Type.AssertionValue f)]) mf - ]))) - fromFilter ((mx, mr, b) ::= y) = - Type.ExtensibleMatch - (Type.MatchingRuleAssertion (fmap (\(Attr r) -> Type.MatchingRuleId (Type.LdapString r)) mr) - (fmap (\(Attr x) -> Type.AttributeDescription (Type.LdapString x)) mx) - (Type.AssertionValue y) - b) - -data Search = Search - { _scope :: Type.Scope - , _derefAliases :: Type.DerefAliases - , _size :: Int32 - , _time :: Int32 - , _typesOnly :: Bool - } deriving (Show, Eq) - -defaultSearch :: Search -defaultSearch = Search - { _scope = Type.BaseObject - , _size = 0 - , _time = 0 - , _typesOnly = False - , _derefAliases = Type.NeverDerefAliases - } - -scope :: Type.Scope -> Mod Search -scope x = Mod (\y -> y { _scope = x }) - -size :: Int32 -> Mod Search -size x = Mod (\y -> y { _size = x }) - -time :: Int32 -> Mod Search -time x = Mod (\y -> y { _time = x }) - -typesOnly :: Bool -> Mod Search -typesOnly x = Mod (\y -> y { _typesOnly = x }) - -derefAliases :: Type.DerefAliases -> Mod Search -derefAliases x = Mod (\y -> y { _derefAliases = x }) - -newtype Mod a = Mod (a -> a) - -instance Monoid (Mod a) where - mempty = Mod id - Mod f `mappend` Mod g = Mod (g . f) - -data Filter = - Not Filter - | And (NonEmpty Filter) - | Or (NonEmpty Filter) - | Present Attr - | Attr := ByteString - | Attr :>= ByteString - | Attr :<= ByteString - | Attr :~= ByteString - | Attr :=* (Maybe ByteString, [ByteString], Maybe ByteString) - | (Maybe Attr, Maybe Attr, Bool) ::= ByteString - -newtype Attr = Attr Text - deriving (Show, Eq) - --- 'Attr' unwrapper. This is a separate function not to turn 'Attr''s --- 'Show' instance into complete and utter shit. -unAttr :: Attr -> Text -unAttr (Attr a) = a - -data SearchEntry = SearchEntry Dn [(Attr, Set ByteString)] - deriving (Show, Eq) - - --- | Note that 'unbindAsync' does not return an 'Async', --- because LDAP server never responds to @UnbindRequest@s, hence --- a call to 'wait' on a hypothetical 'Async' would have resulted --- in an exception anyway. -unbindAsync :: Ldap -> IO () -unbindAsync = - atomically . unbindAsyncSTM - --- | Note that 'unbindAsyncSTM' does not return an 'Async', --- because LDAP server never responds to @UnbindRequest@s, hence --- a call to 'wait' on a hypothetical 'Async' would have resulted --- in an exception anyway. -unbindAsyncSTM :: Ldap -> STM () -unbindAsyncSTM l = - void (sendRequest l die Type.UnbindRequest) - where - die = error "Ldap.Client: do not wait for the response to UnbindRequest" - - -type AttrList f = [(Attr, f ByteString)] - -data AddError = - AddInvalidResponse Response - | AddErrorCode Type.ResultCode - deriving (Show, Eq, Typeable) - -instance Exception AddError - -add :: Ldap -> Dn -> AttrList NonEmpty -> IO () -add l dn as = - raise =<< addEither l dn as - -addEither :: Ldap -> Dn -> AttrList NonEmpty -> IO (Either AddError ()) -addEither l dn as = - wait =<< addAsync l dn as - -addAsync :: Ldap -> Dn -> AttrList NonEmpty -> IO (Async AddError ()) -addAsync l dn as = - atomically (addAsyncSTM l dn as) - -addAsyncSTM :: Ldap -> Dn -> AttrList NonEmpty -> STM (Async AddError ()) -addAsyncSTM l (Dn dn) as = - sendRequest l addResult - (Type.AddRequest (Type.LdapDn (Type.LdapString dn)) - (Type.AttributeList (map f as))) - where - f (Attr x, xs) = Type.Attribute (Type.AttributeDescription (Type.LdapString x)) - (fmap Type.AttributeValue xs) - -addResult :: Response -> Either AddError () -addResult (Type.AddResponse (Type.LdapResult code _ _ _) :| []) - | Type.Success <- code = Right () - | otherwise = Left (AddErrorCode code) -addResult res = Left (AddInvalidResponse res) - - -data DeleteError = - DeleteInvalidResponse Response - | DeleteErrorCode Type.ResultCode - deriving (Show, Eq, Typeable) - -instance Exception DeleteError - -delete :: Ldap -> Dn -> IO () -delete l dn = - raise =<< deleteEither l dn - -deleteEither :: Ldap -> Dn -> IO (Either DeleteError ()) -deleteEither l dn = - wait =<< deleteAsync l dn - -deleteAsync :: Ldap -> Dn -> IO (Async DeleteError ()) -deleteAsync l dn = - atomically (deleteAsyncSTM l dn) - -deleteAsyncSTM :: Ldap -> Dn -> STM (Async DeleteError ()) -deleteAsyncSTM l (Dn dn) = - sendRequest l deleteResult - (Type.DeleteRequest (Type.LdapDn (Type.LdapString dn))) - -deleteResult :: Response -> Either DeleteError () -deleteResult (Type.DeleteResponse (Type.LdapResult code _ _ _) :| []) - | Type.Success <- code = Right () - | otherwise = Left (DeleteErrorCode code) -deleteResult res = Left (DeleteInvalidResponse res) - - -wait :: Async e a -> IO (Either e a) -wait = atomically . waitSTM - -waitSTM :: Async e a -> STM (Either e a) -waitSTM (Async stm) = stm - - -sendRequest :: Ldap -> (Response -> Either e a) -> Request -> STM (Async e a) -sendRequest l p msg = - do var <- newEmptyTMVar - writeRequest l var msg - return (Async (fmap p (readTMVar var))) - -writeRequest :: Ldap -> TMVar Response -> Request -> STM () -writeRequest Ldap { client } var msg = writeTQueue client (New msg var) - -raise :: Exception e => Either e a -> IO a -raise = either throwIO return diff --git a/src/Ldap/Client/Bind.hs b/src/Ldap/Client/Bind.hs new file mode 100644 index 0000000..38550a5 --- /dev/null +++ b/src/Ldap/Client/Bind.hs @@ -0,0 +1,77 @@ +module Ldap.Client.Bind + ( BindError(..) + , bind + , bindEither + , bindAsync + , bindAsyncSTM + , unbindAsync + , unbindAsyncSTM + ) where + +import Control.Exception (Exception) +import Control.Monad (void) +import Control.Monad.STM (STM, atomically) +import Data.List.NonEmpty (NonEmpty((:|))) +import Data.Typeable (Typeable) + +import Ldap.Client.Internal +import qualified Ldap.Asn1.Type as Type + + +data BindError = + BindInvalidResponse Response + | BindErrorCode Type.ResultCode + deriving (Show, Eq, Typeable) + +instance Exception BindError + +-- | Throws 'BindError' on failure. Don't worry, the nearest 'with' +-- will catch it, so it won't destroy your program. +bind :: Ldap -> Dn -> Password -> IO () +bind l username password = + raise =<< bindEither l username password + +bindEither :: Ldap -> Dn -> Password -> IO (Either BindError ()) +bindEither l username password = + wait =<< bindAsync l username password + +bindAsync :: Ldap -> Dn -> Password -> IO (Async BindError ()) +bindAsync l username password = + atomically (bindAsyncSTM l username password) + +bindAsyncSTM :: Ldap -> Dn -> Password -> STM (Async BindError ()) +bindAsyncSTM l username password = + sendRequest l bindResult (bindRequest username password) + +bindRequest :: Dn -> Password -> Request +bindRequest (Dn username) (Password password) = + Type.BindRequest ldapVersion + (Type.LdapDn (Type.LdapString username)) + (Type.Simple password) + where + ldapVersion = 3 + +bindResult :: Response -> Either BindError () +bindResult (Type.BindResponse (Type.LdapResult code _ _ _) _ :| []) + | Type.Success <- code = Right () + | otherwise = Left (BindErrorCode code) +bindResult res = Left (BindInvalidResponse res) + + +-- | Note that 'unbindAsync' does not return an 'Async', +-- because LDAP server never responds to @UnbindRequest@s, hence +-- a call to 'wait' on a hypothetical 'Async' would have resulted +-- in an exception anyway. +unbindAsync :: Ldap -> IO () +unbindAsync = + atomically . unbindAsyncSTM + +-- | Note that 'unbindAsyncSTM' does not return an 'Async', +-- because LDAP server never responds to @UnbindRequest@s, hence +-- a call to 'wait' on a hypothetical 'Async' would have resulted +-- in an exception anyway. +unbindAsyncSTM :: Ldap -> STM () +unbindAsyncSTM l = + void (sendRequest l die Type.UnbindRequest) + where + die = error "Ldap.Client: do not wait for the response to UnbindRequest" diff --git a/src/Ldap/Client/Internal.hs b/src/Ldap/Client/Internal.hs new file mode 100644 index 0000000..ea87d52 --- /dev/null +++ b/src/Ldap/Client/Internal.hs @@ -0,0 +1,172 @@ +{-# LANGUAGE BangPatterns #-} +{-# LANGUAGE DeriveDataTypeable #-} +{-# LANGUAGE NamedFieldPuns #-} +module Ldap.Client.Internal + ( Host(..) + , PortNumber + , Ldap(..) + , ClientMessage(..) + , Type.ResultCode(..) + , Async + -- * Add Request + , AttrList + , AddError(..) + , add + , addEither + , addAsync + , addAsyncSTM + -- * Delete Request + , DeleteError(..) + , delete + , deleteEither + , deleteAsync + , deleteAsyncSTM + -- * Waiting for Request Completion + , wait + , waitSTM + -- * Misc + , Response + , Request + , raise + , sendRequest + , Dn(..) + , Password(..) + , Attr(..) + , unAttr + ) where + +import Control.Concurrent.STM (STM, atomically) +import Control.Concurrent.STM.TMVar (TMVar, newEmptyTMVar, readTMVar) +import Control.Concurrent.STM.TQueue (TQueue, writeTQueue) +import Control.Exception (Exception, throwIO) +import Data.ByteString (ByteString) +import Data.List.NonEmpty (NonEmpty((:|))) +import Data.Text (Text) +import Data.Typeable (Typeable) +import Network (PortNumber) + +import qualified Ldap.Asn1.Type as Type + + +data Host = + Plain String + | Secure String + deriving (Show, Eq, Ord) + +data Ldap = Ldap + { client :: TQueue ClientMessage + } deriving (Eq) + +data ClientMessage = New Request (TMVar (NonEmpty Type.ProtocolServerOp)) +type Request = Type.ProtocolClientOp +type InMessage = Type.ProtocolServerOp +type Response = NonEmpty InMessage + +data Async e a = Async (STM (Either e a)) + +instance Functor (Async e) where + fmap f (Async stm) = Async (fmap (fmap f) stm) + + +newtype Dn = Dn Text + deriving (Show, Eq) +newtype Password = Password ByteString + deriving (Show, Eq) + + + + +newtype Attr = Attr Text + deriving (Show, Eq) + +type AttrList f = [(Attr, f ByteString)] + +data AddError = + AddInvalidResponse Response + | AddErrorCode Type.ResultCode + deriving (Show, Eq, Typeable) + +instance Exception AddError + +add :: Ldap -> Dn -> AttrList NonEmpty -> IO () +add l dn as = + raise =<< addEither l dn as + +addEither :: Ldap -> Dn -> AttrList NonEmpty -> IO (Either AddError ()) +addEither l dn as = + wait =<< addAsync l dn as + +addAsync :: Ldap -> Dn -> AttrList NonEmpty -> IO (Async AddError ()) +addAsync l dn as = + atomically (addAsyncSTM l dn as) + +addAsyncSTM :: Ldap -> Dn -> AttrList NonEmpty -> STM (Async AddError ()) +addAsyncSTM l (Dn dn) as = + sendRequest l addResult + (Type.AddRequest (Type.LdapDn (Type.LdapString dn)) + (Type.AttributeList (map f as))) + where + f (Attr x, xs) = Type.Attribute (Type.AttributeDescription (Type.LdapString x)) + (fmap Type.AttributeValue xs) + +addResult :: Response -> Either AddError () +addResult (Type.AddResponse (Type.LdapResult code _ _ _) :| []) + | Type.Success <- code = Right () + | otherwise = Left (AddErrorCode code) +addResult res = Left (AddInvalidResponse res) + +-- 'Attr' unwrapper. This is a separate function not to turn 'Attr''s +-- 'Show' instance into complete and utter shit. +unAttr :: Attr -> Text +unAttr (Attr a) = a + + +data DeleteError = + DeleteInvalidResponse Response + | DeleteErrorCode Type.ResultCode + deriving (Show, Eq, Typeable) + +instance Exception DeleteError + +delete :: Ldap -> Dn -> IO () +delete l dn = + raise =<< deleteEither l dn + +deleteEither :: Ldap -> Dn -> IO (Either DeleteError ()) +deleteEither l dn = + wait =<< deleteAsync l dn + +deleteAsync :: Ldap -> Dn -> IO (Async DeleteError ()) +deleteAsync l dn = + atomically (deleteAsyncSTM l dn) + +deleteAsyncSTM :: Ldap -> Dn -> STM (Async DeleteError ()) +deleteAsyncSTM l (Dn dn) = + sendRequest l deleteResult + (Type.DeleteRequest (Type.LdapDn (Type.LdapString dn))) + +deleteResult :: Response -> Either DeleteError () +deleteResult (Type.DeleteResponse (Type.LdapResult code _ _ _) :| []) + | Type.Success <- code = Right () + | otherwise = Left (DeleteErrorCode code) +deleteResult res = Left (DeleteInvalidResponse res) + + +wait :: Async e a -> IO (Either e a) +wait = atomically . waitSTM + +waitSTM :: Async e a -> STM (Either e a) +waitSTM (Async stm) = stm + + +sendRequest :: Ldap -> (Response -> Either e a) -> Request -> STM (Async e a) +sendRequest l p msg = + do var <- newEmptyTMVar + writeRequest l var msg + return (Async (fmap p (readTMVar var))) + +writeRequest :: Ldap -> TMVar Response -> Request -> STM () +writeRequest Ldap { client } var msg = writeTQueue client (New msg var) + +raise :: Exception e => Either e a -> IO a +raise = either throwIO return diff --git a/src/Ldap/Client/Search.hs b/src/Ldap/Client/Search.hs new file mode 100644 index 0000000..cbfbe0f --- /dev/null +++ b/src/Ldap/Client/Search.hs @@ -0,0 +1,197 @@ +{-# LANGUAGE NamedFieldPuns #-} +module Ldap.Client.Search + ( SearchError(..) + , search + , searchEither + , searchAsync + , searchAsyncSTM + , Search + , Type.Scope(..) + , scope + , size + , time + , typesOnly + , derefAliases + , Filter(..) + , SearchEntry(..) + ) where + +import Control.Exception (Exception) +import Control.Monad.STM (STM, atomically) +import Data.ByteString (ByteString) +import Data.Int (Int32) +import Data.List.NonEmpty (NonEmpty((:|))) +import qualified Data.List.NonEmpty as NonEmpty +import Data.Maybe (mapMaybe) +import Data.Set (Set) +import qualified Data.Set as Set +import Data.Typeable (Typeable) + +import qualified Ldap.Asn1.Type as Type +import Ldap.Client.Internal + + +data SearchError = + SearchInvalidResponse Response + | SearchErrorCode Type.ResultCode + deriving (Show, Eq, Typeable) + +instance Exception SearchError + +search + :: Ldap + -> Dn + -> Mod Search + -> Filter + -> [Attr] + -> IO [SearchEntry] +search l base opts flt attributes = + raise =<< searchEither l base opts flt attributes + +searchEither + :: Ldap + -> Dn + -> Mod Search + -> Filter + -> [Attr] + -> IO (Either SearchError [SearchEntry]) +searchEither l base opts flt attributes = + wait =<< searchAsync l base opts flt attributes + +searchAsync + :: Ldap + -> Dn + -> Mod Search + -> Filter + -> [Attr] + -> IO (Async SearchError [SearchEntry]) +searchAsync l base opts flt attributes = + atomically (searchAsyncSTM l base opts flt attributes) + +searchAsyncSTM + :: Ldap + -> Dn + -> Mod Search + -> Filter + -> [Attr] + -> STM (Async SearchError [SearchEntry]) +searchAsyncSTM l base opts flt attributes = + sendRequest l searchResult (searchRequest base opts flt attributes) + +searchResult :: Response -> Either SearchError [SearchEntry] +searchResult (Type.SearchResultDone (Type.LdapResult code _ _ _) :| xs) + | Type.Success <- code = Right (mapMaybe g xs) + | Type.AdminLimitExceeded <- code = Right (mapMaybe g xs) + | Type.SizeLimitExceeded <- code = Right (mapMaybe g xs) + | otherwise = Left (SearchErrorCode code) + where + g (Type.SearchResultEntry (Type.LdapDn (Type.LdapString dn)) + (Type.PartialAttributeList ys)) = + Just (SearchEntry (Dn dn) (map h ys)) + g _ = Nothing + h (Type.PartialAttribute (Type.AttributeDescription (Type.LdapString x)) + y) = (Attr x, Set.map j y) + j (Type.AttributeValue x) = x +searchResult res = Left (SearchInvalidResponse res) + +searchRequest :: Dn -> Mod Search -> Filter -> [Attr] -> Request +searchRequest (Dn base) (Mod m) flt attributes = + Type.SearchRequest (Type.LdapDn (Type.LdapString base)) + _scope + _derefAliases + _size + _time + _typesOnly + (fromFilter flt) + (Type.AttributeSelection (map (Type.LdapString . unAttr) attributes)) + where + Search { _scope, _derefAliases, _size, _time, _typesOnly } = + m defaultSearch + fromFilter (Not x) = Type.Not (fromFilter x) + fromFilter (And xs) = Type.And (fmap fromFilter xs) + fromFilter (Or xs) = Type.Or (fmap fromFilter xs) + fromFilter (Present (Attr x)) = + Type.Present (Type.AttributeDescription (Type.LdapString x)) + fromFilter (Attr x := y) = + Type.EqualityMatch + (Type.AttributeValueAssertion (Type.AttributeDescription (Type.LdapString x)) + (Type.AssertionValue y)) + fromFilter (Attr x :>= y) = + Type.GreaterOrEqual + (Type.AttributeValueAssertion (Type.AttributeDescription (Type.LdapString x)) + (Type.AssertionValue y)) + fromFilter (Attr x :<= y) = + Type.LessOrEqual + (Type.AttributeValueAssertion (Type.AttributeDescription (Type.LdapString x)) + (Type.AssertionValue y)) + fromFilter (Attr x :~= y) = + Type.ApproxMatch + (Type.AttributeValueAssertion (Type.AttributeDescription (Type.LdapString x)) + (Type.AssertionValue y)) + fromFilter (Attr x :=* (mi, xs, mf)) = + Type.Substrings + (Type.SubstringFilter (Type.AttributeDescription (Type.LdapString x)) + (NonEmpty.fromList (concat + [ maybe [] (\i -> [Type.Initial (Type.AssertionValue i)]) mi + , fmap (Type.Any . Type.AssertionValue) xs + , maybe [] (\f -> [Type.Final (Type.AssertionValue f)]) mf + ]))) + fromFilter ((mx, mr, b) ::= y) = + Type.ExtensibleMatch + (Type.MatchingRuleAssertion (fmap (\(Attr r) -> Type.MatchingRuleId (Type.LdapString r)) mr) + (fmap (\(Attr x) -> Type.AttributeDescription (Type.LdapString x)) mx) + (Type.AssertionValue y) + b) + +data Search = Search + { _scope :: Type.Scope + , _derefAliases :: Type.DerefAliases + , _size :: Int32 + , _time :: Int32 + , _typesOnly :: Bool + } deriving (Show, Eq) + +defaultSearch :: Search +defaultSearch = Search + { _scope = Type.BaseObject + , _size = 0 + , _time = 0 + , _typesOnly = False + , _derefAliases = Type.NeverDerefAliases + } + +scope :: Type.Scope -> Mod Search +scope x = Mod (\y -> y { _scope = x }) + +size :: Int32 -> Mod Search +size x = Mod (\y -> y { _size = x }) + +time :: Int32 -> Mod Search +time x = Mod (\y -> y { _time = x }) + +typesOnly :: Bool -> Mod Search +typesOnly x = Mod (\y -> y { _typesOnly = x }) + +derefAliases :: Type.DerefAliases -> Mod Search +derefAliases x = Mod (\y -> y { _derefAliases = x }) + +newtype Mod a = Mod (a -> a) + +instance Monoid (Mod a) where + mempty = Mod id + Mod f `mappend` Mod g = Mod (g . f) + +data Filter = + Not Filter + | And (NonEmpty Filter) + | Or (NonEmpty Filter) + | Present Attr + | Attr := ByteString + | Attr :>= ByteString + | Attr :<= ByteString + | Attr :~= ByteString + | Attr :=* (Maybe ByteString, [ByteString], Maybe ByteString) + | (Maybe Attr, Maybe Attr, Bool) ::= ByteString + +data SearchEntry = SearchEntry Dn (AttrList Set) + deriving (Show, Eq) diff --git a/test/Ldap/Client/BindSpec.hs b/test/Ldap/Client/BindSpec.hs new file mode 100644 index 0000000..5e9eb84 --- /dev/null +++ b/test/Ldap/Client/BindSpec.hs @@ -0,0 +1,31 @@ +{-# LANGUAGE OverloadedStrings #-} +module Ldap.Client.BindSpec (spec) where + +import Test.Hspec +import Ldap.Client as Ldap + +import SpecHelper (locally) + + +spec :: Spec +spec = do + it "binds as admin" $ do + res <- locally $ \l -> do + Ldap.bind l (Dn "cn=admin") (Password "secret") + res `shouldBe` Right () + + it "tries to bind as admin with the wrong password, unsuccessfully" $ do + res <- locally $ \l -> do + Ldap.bind l (Dn "cn=admin") (Password "public") + res `shouldBe` Left (Ldap.BindError (Ldap.BindErrorCode Ldap.InvalidCredentials)) + + it "binds as pikachu" $ do + res <- locally $ \l -> do + Ldap.bind l (Dn "cn=admin") (Password "secret") + Ldap.SearchEntry udn _ : [] + <- Ldap.search l (Dn "o=localhost") + (scope WholeSubtree) + (Attr "cn" := "pikachu") + [] + Ldap.bind l udn (Password "i-choose-you") + res `shouldBe` Right () diff --git a/test/Ldap/Client/SearchSpec.hs b/test/Ldap/Client/SearchSpec.hs new file mode 100644 index 0000000..4fadee6 --- /dev/null +++ b/test/Ldap/Client/SearchSpec.hs @@ -0,0 +1,134 @@ +{-# LANGUAGE OverloadedLists #-} +{-# LANGUAGE OverloadedStrings #-} +module Ldap.Client.SearchSpec (spec) where + +import Data.Monoid ((<>)) +import Test.Hspec +import Ldap.Client as Ldap + +import SpecHelper + ( locally + , dns + , bulbasaur + , ivysaur + , venusaur + , charmander + , charmeleon + , charizard + , squirtle + , wartortle + , blastoise + , caterpie + , metapod + , butterfree + , pikachu + ) + + +spec :: Spec +spec = do + let go l f = Ldap.search l (Dn "o=localhost") + (Ldap.scope WholeSubtree <> Ldap.typesOnly True) + f + [] + + it "cannot search as ‘pikachu’" $ do + res <- locally $ \l -> do + Ldap.bind l pikachu (Password "i-choose-you") + go l (Present (Attr "password")) + res `shouldBe` Left (Ldap.SearchError (Ldap.SearchErrorCode Ldap.InsufficientAccessRights)) + + it "‘present’ filter" $ do + res <- locally $ \l -> do + res <- go l (Present (Attr "password")) + dns res `shouldBe` [pikachu] + res `shouldBe` Right () + + it "‘equality match’ filter" $ do + res <- locally $ \l -> do + res <- go l (Attr "type" := "flying") + dns res `shouldMatchList` + [ butterfree + , charizard + ] + res `shouldBe` Right () + + it "‘and’ filter" $ do + res <- locally $ \l -> do + res <- go l (And [ Attr "type" := "fire" + , Attr "evolution" := "1" + ]) + dns res `shouldBe` [charmeleon] + res `shouldBe` Right () + + it "‘or’ filter" $ do + res <- locally $ \l -> do + res <- go l (Or [ Attr "type" := "fire" + , Attr "evolution" := "1" + ]) + dns res `shouldMatchList` + [ ivysaur + , charizard + , charmeleon + , charmander + , wartortle + , metapod + ] + res `shouldBe` Right () + + it "‘ge’ filter" $ do + res <- locally $ \l -> do + res <- go l (Attr "evolution" :>= "2") + dns res `shouldMatchList` + [ venusaur + , charizard + , blastoise + , butterfree + ] + res `shouldBe` Right () + + it "‘le’ filter" $ do + res <- locally $ \l -> do + res <- go l (Attr "evolution" :<= "0") + dns res `shouldMatchList` + [ bulbasaur + , charmander + , squirtle + , caterpie + , pikachu + ] + res `shouldBe` Right () + + it "‘not’ filter" $ do + res <- locally $ \l -> do + res <- go l (Not (Or [ Attr "type" := "fire" + , Attr "evolution" :>= "1" + ])) + dns res `shouldMatchList` + [ bulbasaur + , squirtle + , caterpie + , pikachu + ] + res `shouldBe` Right () + + it "‘substrings’ filter" $ do + res <- locally $ \l -> do + x <- go l (Attr "cn" :=* (Just "char", [], Nothing)) + dns x `shouldMatchList` + [ charmander + , charmeleon + , charizard + ] + y <- go l (Attr "cn" :=* (Nothing, [], Just "saur")) + dns y `shouldMatchList` + [ bulbasaur + , ivysaur + , venusaur + ] + z <- go l (Attr "cn" :=* (Nothing, ["a", "o"], Just "e")) + dns z `shouldMatchList` + [ blastoise + , wartortle + ] + res `shouldBe` Right () diff --git a/test/Ldap/ClientSpec.hs b/test/Ldap/ClientSpec.hs index 121468a..e1500da 100644 --- a/test/Ldap/ClientSpec.hs +++ b/test/Ldap/ClientSpec.hs @@ -5,143 +5,24 @@ module Ldap.ClientSpec (spec) where import Data.Monoid ((<>)) import Test.Hspec -import Ldap.Client (Dn(..), Password(..), Filter(..), Scope(..), Attr(..)) +import Ldap.Client (Dn(..), Filter(..), Scope(..), Attr(..)) import qualified Ldap.Client as Ldap -import SpecHelper (port) +import SpecHelper + ( locally + , dns + , pikachu + , vulpix + , oddish + ) spec :: Spec spec = do - - let locally = Ldap.with localhost port - search l f = Ldap.search l (Dn "o=localhost") - (Ldap.scope WholeSubtree <> Ldap.typesOnly True) - f - [] - - context "bind" $ do - - it "binds as admin" $ do - res <- locally $ \l -> do - Ldap.bind l (Dn "cn=admin") (Password "secret") - res `shouldBe` Right () - - it "tries to bind as admin with the wrong password, unsuccessfully" $ do - res <- locally $ \l -> do - Ldap.bind l (Dn "cn=admin") (Password "public") - res `shouldBe` Left (Ldap.BindError (Ldap.BindErrorCode Ldap.InvalidCredentials)) - - it "binds as pikachu" $ do - res <- locally $ \l -> do - Ldap.bind l (Dn "cn=admin") (Password "secret") - Ldap.SearchEntry udn _ : [] - <- search l (Attr "cn" := "pikachu") - Ldap.bind l udn (Password "i-choose-you") - res `shouldBe` Right () - - context "search" $ do - - it "cannot search as ‘pikachu’" $ do - res <- locally $ \l -> do - Ldap.bind l pikachu (Password "i-choose-you") - search l (Present (Attr "password")) - res `shouldBe` Left (Ldap.SearchError (Ldap.SearchErrorCode Ldap.InsufficientAccessRights)) - - it "‘present’ filter" $ do - res <- locally $ \l -> do - res <- search l (Present (Attr "password")) - dns res `shouldBe` [pikachu] - res `shouldBe` Right () - - it "‘equality match’ filter" $ do - res <- locally $ \l -> do - res <- search l (Attr "type" := "flying") - dns res `shouldMatchList` - [ butterfree - , charizard - ] - res `shouldBe` Right () - - it "‘and’ filter" $ do - res <- locally $ \l -> do - res <- search l (And [ Attr "type" := "fire" - , Attr "evolution" := "1" - ]) - dns res `shouldBe` [charmeleon] - res `shouldBe` Right () - - it "‘or’ filter" $ do - res <- locally $ \l -> do - res <- search l (Or [ Attr "type" := "fire" - , Attr "evolution" := "1" - ]) - dns res `shouldMatchList` - [ ivysaur - , charizard - , charmeleon - , charmander - , wartortle - , metapod - ] - res `shouldBe` Right () - - it "‘ge’ filter" $ do - res <- locally $ \l -> do - res <- search l (Attr "evolution" :>= "2") - dns res `shouldMatchList` - [ venusaur - , charizard - , blastoise - , butterfree - ] - res `shouldBe` Right () - - it "‘le’ filter" $ do - res <- locally $ \l -> do - res <- search l (Attr "evolution" :<= "0") - dns res `shouldMatchList` - [ bulbasaur - , charmander - , squirtle - , caterpie - , pikachu - ] - res `shouldBe` Right () - - it "‘not’ filter" $ do - res <- locally $ \l -> do - res <- search l (Not (Or [ Attr "type" := "fire" - , Attr "evolution" :>= "1" - ])) - dns res `shouldMatchList` - [ bulbasaur - , squirtle - , caterpie - , pikachu - ] - res `shouldBe` Right () - - it "‘substrings’ filter" $ do - res <- locally $ \l -> do - x <- search l (Attr "cn" :=* (Just "char", [], Nothing)) - dns x `shouldMatchList` - [ charmander - , charmeleon - , charizard - ] - y <- search l (Attr "cn" :=* (Nothing, [], Just "saur")) - dns y `shouldMatchList` - [ bulbasaur - , ivysaur - , venusaur - ] - z <- search l (Attr "cn" :=* (Nothing, ["a", "o"], Just "e")) - dns z `shouldMatchList` - [ blastoise - , wartortle - ] - res `shouldBe` Right () + let go l f = Ldap.search l (Dn "o=localhost") + (Ldap.scope WholeSubtree <> Ldap.typesOnly True) + f + [] context "add" $ do @@ -152,7 +33,7 @@ spec = do , (Attr "evolution", ["0"]) , (Attr "type", ["fire"]) ] - res <- search l (Attr "cn" := "vulpix") + res <- go l (Attr "cn" := "vulpix") dns res `shouldBe` [vulpix] res `shouldBe` Right () @@ -161,7 +42,7 @@ spec = do it "deletes an entry" $ do res <- locally $ \l -> do Ldap.delete l pikachu - res <- search l (Attr "cn" := "pikachu") + res <- go l (Attr "cn" := "pikachu") dns res `shouldBe` [] res `shouldBe` Right () @@ -169,28 +50,3 @@ spec = do res <- locally $ \l -> do Ldap.delete l oddish res `shouldBe` Left (Ldap.DeleteError (Ldap.DeleteErrorCode Ldap.NoSuchObject)) - - where - bulbasaur = Dn "cn=bulbasaur,o=localhost" - ivysaur = Dn "cn=ivysaur,o=localhost" - venusaur = Dn "cn=venusaur,o=localhost" - charmander = Dn "cn=charmander,o=localhost" - charmeleon = Dn "cn=charmeleon,o=localhost" - charizard = Dn "cn=charizard,o=localhost" - squirtle = Dn "cn=squirtle,o=localhost" - wartortle = Dn "cn=wartortle,o=localhost" - blastoise = Dn "cn=blastoise,o=localhost" - caterpie = Dn "cn=caterpie,o=localhost" - metapod = Dn "cn=metapod,o=localhost" - butterfree = Dn "cn=butterfree,o=localhost" - pikachu = Dn "cn=pikachu,o=localhost" - vulpix = Dn "cn=vulpix,o=localhost" - oddish = Dn "cn=oddish,o=localhost" - -localhost :: Ldap.Host -localhost = Ldap.Plain "localhost" - -dns :: [Ldap.SearchEntry] -> [Dn] -dns (Ldap.SearchEntry dn _ : es) = dn : dns es -dns [] = [] -dns _ = error "?" diff --git a/test/SpecHelper.hs b/test/SpecHelper.hs index 5c061ba..b9440c9 100644 --- a/test/SpecHelper.hs +++ b/test/SpecHelper.hs @@ -1,4 +1,83 @@ -module SpecHelper (port) where +{-# LANGUAGE OverloadedStrings #-} +module SpecHelper + ( locally + , port + , dns + -- * Users + , bulbasaur + , ivysaur + , venusaur + , charmander + , charmeleon + , charizard + , squirtle + , wartortle + , blastoise + , caterpie + , metapod + , butterfree + , pikachu + , vulpix + , oddish + ) where + +import Ldap.Client as Ldap + + +locally :: (Ldap -> IO a) -> IO (Either LdapError a) +locally = Ldap.with localhost port + +localhost :: Host +localhost = Plain "localhost" port :: Num a => a port = 24620 + +dns :: [SearchEntry] -> [Dn] +dns (SearchEntry dn _ : es) = dn : dns es +dns [] = [] + +bulbasaur :: Dn +bulbasaur = Dn "cn=bulbasaur,o=localhost" + +ivysaur :: Dn +ivysaur = Dn "cn=ivysaur,o=localhost" + +venusaur :: Dn +venusaur = Dn "cn=venusaur,o=localhost" + +charmander :: Dn +charmander = Dn "cn=charmander,o=localhost" + +charmeleon :: Dn +charmeleon = Dn "cn=charmeleon,o=localhost" + +charizard :: Dn +charizard = Dn "cn=charizard,o=localhost" + +squirtle :: Dn +squirtle = Dn "cn=squirtle,o=localhost" + +wartortle :: Dn +wartortle = Dn "cn=wartortle,o=localhost" + +blastoise :: Dn +blastoise = Dn "cn=blastoise,o=localhost" + +caterpie :: Dn +caterpie = Dn "cn=caterpie,o=localhost" + +metapod :: Dn +metapod = Dn "cn=metapod,o=localhost" + +butterfree :: Dn +butterfree = Dn "cn=butterfree,o=localhost" + +pikachu :: Dn +pikachu = Dn "cn=pikachu,o=localhost" + +vulpix :: Dn +vulpix = Dn "cn=vulpix,o=localhost" + +oddish :: Dn +oddish = Dn "cn=oddish,o=localhost"