diff --git a/src/Database/Esqueleto/Internal/Sql.hs b/src/Database/Esqueleto/Internal/Sql.hs index e082071..5136825 100644 --- a/src/Database/Esqueleto/Internal/Sql.hs +++ b/src/Database/Esqueleto/Internal/Sql.hs @@ -10,7 +10,7 @@ module Database.Esqueleto.Internal.Sql import Control.Applicative (Applicative(..), (<$>)) import Control.Arrow ((***), first) import Control.Exception (throw, throwIO) -import Control.Monad (ap) +import Control.Monad (ap, MonadPlus(..)) import Control.Monad.IO.Class (MonadIO(..)) import Control.Monad.Logger (MonadLogger) import Control.Monad.Trans.Resource (MonadResourceBase) @@ -79,10 +79,18 @@ collectOnClauses = go [] go acc (f:fs) = go (f:acc) fs go acc [] = return $ reverse acc - findMatching (FromJoin l k r Nothing : acc) expr = - return (FromJoin l k r (Just expr) : acc) - findMatching (f : acc) expr = (f:) <$> findMatching acc expr - findMatching [] expr = Left expr + findMatching (f : acc) expr = + case tryMatch expr f of + Just f' -> return (f' : acc) + Nothing -> (f:) <$> findMatching acc expr + findMatching [] expr = Left expr + + tryMatch expr (FromJoin l k r Nothing) = + return (FromJoin l k r (Just expr)) + tryMatch expr (FromJoin l k r j@(Just _)) = + ((\r' -> FromJoin l k r' j) <$> tryMatch expr r) `mplus` + ((\l' -> FromJoin l' k r j) <$> tryMatch expr l) + tryMatch _ _ = mzero -- | A complete @WHERE@ clause.