diff --git a/src/Database/Esqueleto.hs b/src/Database/Esqueleto.hs index caf67c6..088d349 100644 --- a/src/Database/Esqueleto.hs +++ b/src/Database/Esqueleto.hs @@ -16,12 +16,19 @@ -- @ module Database.Esqueleto ( -- * @esqueleto@'s Language - Esqueleto( where_, orderBy, asc, desc, sub, (^.), val - , isNothing, just, nothing, not_, (==.), (>=.) + Esqueleto( where_, on, orderBy, asc, desc, sub, (^.), (?.) + , val, isNothing, just, nothing, not_, (==.), (>=.) , (>.), (<=.), (<.), (!=.), (&&.), (||.) , (+.), (-.), (/.), (*.) ) , from , OrderBy + -- ** Joins + , InnerJoin(..) + , CrossJoin(..) + , LeftOuterJoin(..) + , RightOuterJoin(..) + , FullOuterJoin(..) + , OnClauseWithoutMatchingJoinException(..) -- * SQL backend , SqlQuery diff --git a/src/Database/Esqueleto/Internal/Language.hs b/src/Database/Esqueleto/Internal/Language.hs index 991bea4..44bdabd 100644 --- a/src/Database/Esqueleto/Internal/Language.hs +++ b/src/Database/Esqueleto/Internal/Language.hs @@ -1,11 +1,22 @@ -{-# LANGUAGE FlexibleInstances, MultiParamTypeClasses, FunctionalDependencies, TypeFamilies, EmptyDataDecls #-} +{-# LANGUAGE FlexibleContexts, FlexibleInstances, MultiParamTypeClasses, FunctionalDependencies, TypeFamilies, EmptyDataDecls, UndecidableInstances, DeriveDataTypeable #-} module Database.Esqueleto.Internal.Language ( Esqueleto(..) , from + , InnerJoin(..) + , CrossJoin(..) + , LeftOuterJoin(..) + , RightOuterJoin(..) + , FullOuterJoin(..) + , JoinKind(..) + , IsJoinKind(..) + , OnClauseWithoutMatchingJoinException(..) + , PreprocessedFrom , OrderBy ) where import Control.Applicative (Applicative(..), (<$>)) +import Control.Exception (Exception) +import Data.Typeable (Typeable) import Database.Persist.GenericSql import Database.Persist.Store @@ -13,14 +24,94 @@ import Database.Persist.Store -- | Finally tagless representation of @esqueleto@'s EDSL. class (Functor query, Applicative query, Monad query) => Esqueleto query expr backend | query -> expr backend, expr -> query backend where - -- | (Internal) Single entity version of 'from'. - fromSingle :: ( PersistEntity val - , PersistEntityBackend val ~ backend) - => query (expr (Entity val)) + -- | (Internal) Start a 'from' query with an entity. 'from' + -- does two kinds of magic using 'fromStart', 'fromJoin' and + -- 'fromFinish': + -- + -- 1. The simple but tedious magic of allowing tuples to be + -- used. + -- + -- 2. The more advanced magic of creating @JOIN@s. The + -- @JOIN@ is processed from right to left. The rightmost + -- entity of the @JOIN@ is created with 'fromStart'. Each + -- @JOIN@ step is then translated into a call to 'fromJoin'. + -- In the end, 'fromFinish' is called to materialize the + -- @JOIN@. + fromStart + :: ( PersistEntity a + , PersistEntityBackend a ~ backend ) + => query (expr (PreprocessedFrom (expr (Entity a)))) + -- | (Internal) Same as 'fromStart', but entity may be missing. + fromStartMaybe + :: ( PersistEntity a + , PersistEntityBackend a ~ backend ) + => query (expr (PreprocessedFrom (expr (Maybe (Entity a))))) + -- | (Internal) Do a @JOIN@. + fromJoin + :: ( PersistEntity a + , PersistEntityBackend a ~ backend + , IsJoinKind join ) + => expr (PreprocessedFrom b) + -> query (expr (PreprocessedFrom (join (expr (Entity a)) b))) + -- | (Internal) Finish a @JOIN@. + fromFinish + :: expr (PreprocessedFrom a) + -> query a -- | @WHERE@ clause: restrict the query's result. where_ :: expr (Single Bool) -> query () + -- | @ON@ clause: restrict the a @JOIN@'s result. The @ON@ + -- clause will be applied to the /last/ @JOIN@ that does not + -- have an @ON@ clause yet. If there are no @JOIN@s without + -- @ON@ clauses (either because you didn't do any @JOIN@, or + -- because all @JOIN@s already have their own @ON@ clauses), a + -- runtime exception 'OnClauseWithoutMatchingJoinException' is + -- thrown. @ON@ clauses are optional when doing @JOIN@s. + -- + -- On the simple case of doing just one @JOIN@, for example + -- + -- @ + -- select $ + -- from $ \(foo `InnerJoin` bar) -> do + -- on (foo ^. FooId ==. bar ^. BarFooId) + -- ... + -- @ + -- + -- there's no ambiguity and the rules above just mean that + -- you're allowed to call 'on' only once (as in SQL). If you + -- have many joins, then the 'on's are applied on the /reverse/ + -- order that the @JOIN@s appear. For example: + -- + -- @ + -- select $ + -- from $ \(foo `InnerJoin` bar `InnerJoin` baz) -> do + -- on (baz ^. BazId ==. bar ^. BarBazId) + -- on (foo ^. FooId ==. bar ^. BarFooId) + -- ... + -- @ + -- + -- The order is /reversed/ in order to improve composability. + -- For example, consider @query1@ and @query2@ below: + -- + -- @ + -- let query1 = + -- from $ \(foo `InnerJoin` bar) -> do + -- on (foo ^. FooId ==. bar ^. BarFooId) + -- + -- query2 = + -- from $ \(mbaz `LeftOuterJoin` quux) -> do + -- return (mbaz ?. BazName, quux) + -- + -- test1 = (,) <$> query1 <*> query2 + -- test2 = flip (,) <$> query2 <*> query1 + -- @ + -- + -- If the order was *not* reversed, then @test2@ would be + -- broken: @query1@'s 'on' would refer to @query2@'s + -- 'LeftOuterJoin'. + on :: expr (Single Bool) -> query () + -- | @ORDER BY@ clause. See also 'asc' and 'desc'. orderBy :: [expr OrderBy] -> query () @@ -37,6 +128,10 @@ class (Functor query, Applicative query, Monad query) => (^.) :: (PersistEntity val, PersistField typ) => expr (Entity val) -> EntityField val typ -> expr (Single typ) + -- | Project a field of an entity that may be null. + (?.) :: (PersistEntity val, PersistField typ) => + expr (Maybe (Entity val)) -> EntityField val typ -> expr (Single (Maybe typ)) + -- | Lift a constant value from Haskell-land to the query. val :: PersistField typ => typ -> expr (Single typ) @@ -75,6 +170,83 @@ infixl 6 +., -. infix 4 ==., >=., >., <=., <., !=. infixr 3 &&. infixr 2 ||. +infixr 2 `InnerJoin`, `CrossJoin`, `LeftOuterJoin`, `RightOuterJoin`, `FullOuterJoin` + + +-- | Data type that represents an @INNER JOIN@ (see 'LeftOuterJoin' for an example). +data InnerJoin a b = a `InnerJoin` b + +-- | Data type that represents an @CROSS JOIN@ (see 'LeftOuterJoin' for an example). +data CrossJoin a b = a `CrossJoin` b + +-- | Data type that represents an @LEFT OUTER JOIN@. For example, +-- +-- @ +-- select $ +-- from $ \(person `LeftOuterJoin` pet) -> +-- ... +-- @ +-- +-- is translated into +-- +-- @ +-- SELECT ... +-- FROM Person AS TB LEFT OUTER JOIN Pet AS TA +-- ... +-- @ +data LeftOuterJoin a b = a `LeftOuterJoin` b + +-- | Data type that represents an @RIGHT OUTER JOIN@ (see 'LeftOuterJoin' for an example). +data RightOuterJoin a b = a `RightOuterJoin` b + +-- | Data type that represents an @FULL OUTER JOIN@ (see 'LeftOuterJoin' for an example). +data FullOuterJoin a b = a `FullOuterJoin` b + + +-- | (Internal) A kind of @JOIN@. +data JoinKind = + InnerJoinKind -- ^ @INNER JOIN@ + | CrossJoinKind -- ^ @CROSS JOIN@ + | LeftOuterJoinKind -- ^ @LEFT OUTER JOIN@ + | RightOuterJoinKind -- ^ @RIGHT OUTER JOIN@ + | FullOuterJoinKind -- ^ @FULL OUTER JOIN@ + + +-- | (Internal) Functions that operate on types (that should be) +-- of kind 'JoinKind'. +class IsJoinKind join where + -- | (Internal) @smartJoin a b@ is a @JOIN@ of the correct kind. + smartJoin :: a -> b -> join a b + -- | (Internal) Reify a @JoinKind@ from a @JOIN@. This + -- function is non-strict. + reifyJoinKind :: join a b -> JoinKind +instance IsJoinKind InnerJoin where + smartJoin a b = a `InnerJoin` b + reifyJoinKind _ = InnerJoinKind +instance IsJoinKind CrossJoin where + smartJoin a b = a `CrossJoin` b + reifyJoinKind _ = CrossJoinKind +instance IsJoinKind LeftOuterJoin where + smartJoin a b = a `LeftOuterJoin` b + reifyJoinKind _ = LeftOuterJoinKind +instance IsJoinKind RightOuterJoin where + smartJoin a b = a `RightOuterJoin` b + reifyJoinKind _ = RightOuterJoinKind +instance IsJoinKind FullOuterJoin where + smartJoin a b = a `FullOuterJoin` b + reifyJoinKind _ = FullOuterJoinKind + + +-- | Exception thrown whenever 'on' is used to create an @ON@ +-- clause but no matching @JOIN@ is found. +data OnClauseWithoutMatchingJoinException = + OnClauseWithoutMatchingJoinException String + deriving (Eq, Ord, Show, Typeable) +instance Exception OnClauseWithoutMatchingJoinException where + + +-- | (Internal) Phantom type used to process 'from' (see 'fromStart'). +data PreprocessedFrom a -- | Phantom type used by 'orderBy', 'asc' and 'desc'. @@ -98,14 +270,45 @@ from :: From query expr backend a => (a -> query b) -> query b from = (from_ >>=) +-- | (Internal) Class that implements the tuple 'from' magic (see +-- 'fromStart'). class Esqueleto query expr backend => From query expr backend a where from_ :: query a instance ( Esqueleto query expr backend - , PersistEntity val - , PersistEntityBackend val ~ backend + , FromPreprocess query expr backend (expr (Entity val)) ) => From query expr backend (expr (Entity val)) where - from_ = fromSingle + from_ = fromPreprocess >>= fromFinish + +instance ( Esqueleto query expr backend + , FromPreprocess query expr backend (expr (Maybe (Entity val))) + ) => From query expr backend (expr (Maybe (Entity val))) where + from_ = fromPreprocess >>= fromFinish + +instance ( Esqueleto query expr backend + , FromPreprocess query expr backend (InnerJoin (expr (Entity val)) b) + ) => From query expr backend (InnerJoin (expr (Entity val)) b) where + from_ = fromPreprocess >>= fromFinish + +instance ( Esqueleto query expr backend + , FromPreprocess query expr backend (CrossJoin (expr (Entity val)) b) + ) => From query expr backend (CrossJoin (expr (Entity val)) b) where + from_ = fromPreprocess >>= fromFinish + +instance ( Esqueleto query expr backend + , FromPreprocess query expr backend (LeftOuterJoin (expr (Entity val)) b) + ) => From query expr backend (LeftOuterJoin (expr (Entity val)) b) where + from_ = fromPreprocess >>= fromFinish + +instance ( Esqueleto query expr backend + , FromPreprocess query expr backend (RightOuterJoin (expr (Entity val)) b) + ) => From query expr backend (RightOuterJoin (expr (Entity val)) b) where + from_ = fromPreprocess >>= fromFinish + +instance ( Esqueleto query expr backend + , FromPreprocess query expr backend (FullOuterJoin (expr (Entity val)) b) + ) => From query expr backend (FullOuterJoin (expr (Entity val)) b) where + from_ = fromPreprocess >>= fromFinish instance ( From query expr backend a , From query expr backend b @@ -162,3 +365,30 @@ instance ( From query expr backend a , From query expr backend h ) => From query expr backend (a, b, c, d, e, f, g, h) where from_ = (,,,,,,,) <$> from_ <*> from_ <*> from_ <*> from_ <*> from_ <*> from_ <*> from_ <*> from_ + + + +-- | (Internal) Class that implements the @JOIN@ 'from' magic +-- (see 'fromStart'). +class Esqueleto query expr backend => FromPreprocess query expr backend a where + fromPreprocess :: query (expr (PreprocessedFrom a)) + +instance ( Esqueleto query expr backend + , PersistEntity val + , PersistEntityBackend val ~ backend + ) => FromPreprocess query expr backend (expr (Entity val)) where + fromPreprocess = fromStart + +instance ( Esqueleto query expr backend + , PersistEntity val + , PersistEntityBackend val ~ backend + ) => FromPreprocess query expr backend (expr (Maybe (Entity val))) where + fromPreprocess = fromStartMaybe + +instance ( Esqueleto query expr backend + , PersistEntity val + , PersistEntityBackend val ~ backend + , IsJoinKind join + , FromPreprocess query expr backend b + ) => FromPreprocess query expr backend (join (expr (Entity val)) b) where + fromPreprocess = fromPreprocess >>= fromJoin diff --git a/src/Database/Esqueleto/Internal/Sql.hs b/src/Database/Esqueleto/Internal/Sql.hs index e901b32..0230ab5 100644 --- a/src/Database/Esqueleto/Internal/Sql.hs +++ b/src/Database/Esqueleto/Internal/Sql.hs @@ -9,7 +9,7 @@ module Database.Esqueleto.Internal.Sql import Control.Applicative (Applicative(..), (<$>)) import Control.Arrow ((***), first) -import Control.Exception (throwIO) +import Control.Exception (throw, throwIO) import Control.Monad (ap) import Control.Monad.IO.Class (MonadIO(..)) import Control.Monad.Logger (MonadLogger) @@ -62,7 +62,27 @@ instance Monoid SideData where -- | A part of a @FROM@ clause. -data FromClause = From Ident EntityDef +data FromClause = + FromStart Ident EntityDef + | FromJoin Ident EntityDef JoinKind FromClause (Maybe (SqlExpr (Single Bool))) + | OnClause (SqlExpr (Single Bool)) + + +-- | Collect 'OnClause's on 'FromJoin's. Returns the first +-- unmatched 'OnClause's data on error. Returns a list without +-- 'OnClauses' on success. +collectOnClauses :: [FromClause] -> Either (SqlExpr (Single Bool)) [FromClause] +collectOnClauses = go [] + where + go [] (f@(FromStart _ _):fs) = fmap (f:) (go [] fs) -- fast path + go acc (OnClause expr :fs) = findMatching acc expr >>= flip go fs + go acc (f:fs) = go (f:acc) fs + go acc [] = return $ reverse acc + + findMatching (FromJoin i e k f Nothing : acc) expr = + return (FromJoin i e k f (Just expr) : acc) + findMatching (f : acc) expr = (f:) <$> findMatching acc expr + findMatching [] expr = Left expr -- | A complete @WHERE@ clause. @@ -98,25 +118,50 @@ idents _ = -- | An expression on the SQL backend. data SqlExpr a where EEntity :: Ident -> SqlExpr (Entity val) + EMaybe :: SqlExpr a -> SqlExpr (Maybe a) ERaw :: (Escape -> (TLB.Builder, [PersistValue])) -> SqlExpr (Single a) EOrderBy :: OrderByType -> SqlExpr (Single a) -> SqlExpr OrderBy + EPreprocessedFrom :: a -> FromClause -> SqlExpr (PreprocessedFrom a) data OrderByType = ASC | DESC type Escape = DBName -> TLB.Builder + instance Esqueleto SqlQuery SqlExpr SqlPersist where - fromSingle = Q $ do + fromStart = Q $ do ident <- S.supply - let from_ = From ident $ entityDef (getVal ret) - ret = EEntity ident - getVal :: SqlExpr (Entity val) -> val - getVal = error "SqlQuery/getVal: never here" + let ret = EEntity ident + from_ = FromStart ident $ entityDef (getVal ret) + return (EPreprocessedFrom ret from_) + + fromStartMaybe = maybelize <$> fromStart + where + maybelize :: SqlExpr (PreprocessedFrom (SqlExpr (Entity a))) + -> SqlExpr (PreprocessedFrom (SqlExpr (Maybe (Entity a)))) + maybelize (EPreprocessedFrom ret from_) = EPreprocessedFrom (EMaybe ret) from_ + maybelize _ = error "Esqueleto/Sql/fromStartMaybe: never here (see GHC #6124)" + + fromJoin (EPreprocessedFrom rhsRet rhsFrom) = Q $ do + lhsIdent <- S.supply + let lhsRet = EEntity lhsIdent + ret = smartJoin lhsRet rhsRet + from_ = FromJoin lhsIdent (entityDef $ getVal lhsRet) -- LHS + (reifyJoinKind ret) -- JOIN + rhsFrom -- RHS + Nothing -- ON + return (EPreprocessedFrom ret from_) + fromJoin _ = error "Esqueleto/Sql/fromJoin: never here (see GHC #6124)" + + fromFinish (EPreprocessedFrom ret from_) = Q $ do W.tell mempty { sdFromClause = [from_] } return ret + fromFinish _ = error "Esqueleto/Sql/fromFinish: never here (see GHC #6124)" where_ expr = Q $ W.tell mempty { sdWhereClause = Where expr } + on expr = Q $ W.tell mempty { sdFromClause = [OnClause expr] } + orderBy exprs = Q $ W.tell mempty { sdOrderByClause = exprs } asc = EOrderBy ASC desc = EOrderBy DESC @@ -127,6 +172,13 @@ instance Esqueleto SqlQuery SqlExpr SqlPersist where where name esc = esc . fieldDB . persistFieldDef _ ^. _ = error "Esqueleto/Sql/(^.): never here (see GHC #6124)" + EMaybe r ?. field = maybelize (r ^. field) + where + maybelize :: SqlExpr (Single a) -> SqlExpr (Single (Maybe a)) + maybelize (ERaw f) = ERaw f + maybelize _ = error "Esqueleto/Sql/(?.): never here 1 (see GHC #6124)" + _ ?. _ = error "Esqueleto/Sql/(?.): never here 2 (see GHC #6124)" + val = ERaw . const . (,) "?" . return . toPersistValue isNothing (ERaw f) = ERaw $ first ((<> " IS NULL") . parens) . f @@ -153,6 +205,9 @@ instance Esqueleto SqlQuery SqlExpr SqlPersist where (*.) = binop " * " +getVal :: SqlExpr (Entity val) -> val +getVal = error "SqlQuery/getVal: never here" + fromDBName :: Connection -> DBName -> TLB.Builder fromDBName conn = TLB.fromText . escapeName conn @@ -237,9 +292,35 @@ makeSelect esc ret = first ("SELECT " <>) (sqlSelectCols esc ret) makeFrom :: Escape -> [FromClause] -> (TLB.Builder, [PersistValue]) makeFrom _ [] = mempty -makeFrom esc fs = ("\nFROM " <> uncommas (map mk fs), mempty) +makeFrom esc fs = ret where - mk (From (I i) def) = esc (entityDB def) <> (" AS " <> i) + ret = case collectOnClauses fs of + Left expr -> throw $ mkExc expr + Right fs' -> first ("\nFROM " <>) $ uncommas' (map mk fs') + + mk (FromStart (I i) def) = base i def + mk (FromJoin (I i) def kind rest monClause) = + mconcat [ base i def + , (fromKind kind, mempty) + , mk rest + , maybe mempty makeOnClause monClause ] + mk (OnClause _) = error "Esqueleto/Sql/makeFrom: never here (is collectOnClauses working?)" + + base i def = (esc (entityDB def) <> (" AS " <> i), mempty) + + fromKind InnerJoinKind = " INNER JOIN " + fromKind CrossJoinKind = " CROSS JOIN " + fromKind LeftOuterJoinKind = " LEFT OUTER JOIN " + fromKind RightOuterJoinKind = " RIGHT OUTER JOIN " + fromKind FullOuterJoinKind = " FULL OUTER JOIN " + + makeOnClause (ERaw f) = first (" ON " <>) (f esc) + makeOnClause _ = error "Esqueleto/Sql/makeFrom/makeOnClause: never here (see GHC #6124)" + + mkExc (ERaw f) = + OnClauseWithoutMatchingJoinException $ + TL.unpack $ TLB.toLazyText $ fst (f esc) + mkExc _ = OnClauseWithoutMatchingJoinException "???" makeWhere :: Escape -> WhereClause -> (TLB.Builder, [PersistValue]) @@ -309,6 +390,17 @@ instance PersistEntity a => SqlSelect (SqlExpr (Entity a)) (Entity a) where getEntityVal :: SqlExpr (Entity a) -> a getEntityVal = error "Esqueleto/Sql/getEntityVal" +instance PersistEntity a => SqlSelect (SqlExpr (Maybe (Entity a))) (Maybe (Entity a)) where + sqlSelectCols escape (EMaybe ent) = sqlSelectCols escape ent + sqlSelectCols _ _ = error "Esqueleto/Sql/sqlSelectCols[Maybe Entity]: never here (see GHC #6124)" + sqlSelectColCount = sqlSelectColCount . fromEMaybe + where + fromEMaybe :: SqlExpr (Maybe e) -> SqlExpr e + fromEMaybe = error "Esqueleto/Sql/sqlSelectColCount[Maybe Entity]/fromEMaybe" + sqlSelectProcessRow cols + | all (== PersistNull) cols = return Nothing + | otherwise = Just <$> sqlSelectProcessRow cols + instance PersistField a => SqlSelect (SqlExpr (Single a)) (Single a) where sqlSelectCols esc (ERaw f) = let (b, vals) = f esc in (parens b, vals) diff --git a/test/Test.hs b/test/Test.hs index 5564c61..fab2227 100644 --- a/test/Test.hs +++ b/test/Test.hs @@ -107,6 +107,61 @@ main = do , (Single (personName p2), Single (personName p1)) , (Single (personName p2), Single (personName p2)) ] + it "works with a LEFT OUTER JOIN" $ + run $ do + p1e <- insert' p1 + p2e <- insert' p2 + p3e <- insert' p3 + p4e <- insert' p4 + b12e <- insert' $ BlogPost "b" (entityKey p1e) + b11e <- insert' $ BlogPost "a" (entityKey p1e) + b31e <- insert' $ BlogPost "c" (entityKey p3e) + ret <- select $ + from $ \(p `LeftOuterJoin` mb) -> do + on (just (p ^. PersonId) ==. mb ?. BlogPostAuthorId) + orderBy [ asc (p ^. PersonName), asc (mb ?. BlogPostTitle) ] + return (p, mb) + liftIO $ ret `shouldBe` [ (p1e, Just b11e) + , (p1e, Just b12e) + , (p4e, Nothing) + , (p3e, Just b31e) + , (p2e, Nothing) ] + + it "throws an error for using on without joins" $ + run (do + p1e <- insert' p1 + p2e <- insert' p2 + p3e <- insert' p3 + p4e <- insert' p4 + b12e <- insert' $ BlogPost "b" (entityKey p1e) + b11e <- insert' $ BlogPost "a" (entityKey p1e) + b31e <- insert' $ BlogPost "c" (entityKey p3e) + ret <- select $ + from $ \(p, mb) -> do + on (just (p ^. PersonId) ==. mb ?. BlogPostAuthorId) + orderBy [ asc (p ^. PersonName), asc (mb ?. BlogPostTitle) ] + return (p, mb) + return () + ) `shouldThrow` (\(OnClauseWithoutMatchingJoinException _) -> True) + + it "throws an error for using too many ons" $ + run (do + p1e <- insert' p1 + p2e <- insert' p2 + p3e <- insert' p3 + p4e <- insert' p4 + b12e <- insert' $ BlogPost "b" (entityKey p1e) + b11e <- insert' $ BlogPost "a" (entityKey p1e) + b31e <- insert' $ BlogPost "c" (entityKey p3e) + ret <- select $ + from $ \(p `FullOuterJoin` mb) -> do + on (just (p ^. PersonId) ==. mb ?. BlogPostAuthorId) + on (just (p ^. PersonId) ==. mb ?. BlogPostAuthorId) + orderBy [ asc (p ^. PersonName), asc (mb ?. BlogPostTitle) ] + return (p, mb) + return () + ) `shouldThrow` (\(OnClauseWithoutMatchingJoinException _) -> True) + describe "select/where_" $ do it "works for a simple example with (==.)" $ run $ do